# Original idea from Hitcon 2022 web2pdf challenge: https://blog.splitline.tw/hitcon-ctf-2022/#%F0%9F%93%83-web2pdf-web
# Example usage: python extract_pdf_images.py Ticket-831767.pdf -v
import base64
import binascii
import re
import zlib
from argparse import ArgumentParser

import fitz  # PyMuPDF
from PIL import Image

# Command-line interface: a positional PDF path plus an optional verbosity flag.
parser = ArgumentParser(
    description=(
        "Extract file contents embedded in bitmap images in PDF file. "
        "File content may be plaintext or base64/zlib compressed. "
        "Extracted bitmap images and file contents are written to disk"
    )
)
parser.add_argument('file', help='PDF file path')
parser.add_argument('-v', '--verbose', help='Verbose output', action='store_true')
args = parser.parse_args()

# Module-level settings read by the functions and the script body of this file.
PDF_FILE_PATH, VERBOSE = args.file, args.verbose

def decompress(data: bytes, chunk_size: int = 1024) -> bytes:
    """Best-effort raw-DEFLATE decompression.

    ``data`` is fed to a raw DEFLATE decompressor (``wbits=-15``: no zlib
    header or trailer) ``chunk_size`` bytes at a time. If the stream is
    corrupt, decoding stops at the first error and every byte that decoded
    before the corrupt input is returned. Input that is not DEFLATE data at
    all yields ``b''``.

    :param data: bytes that may hold a raw DEFLATE stream.
    :param chunk_size: number of input bytes fed per step (must be >= 1).
    :return: the decompressed bytes recovered before any error.
    """
    decompressor = zlib.decompressobj(wbits=-15)
    pieces = []
    for i in range(0, len(data), chunk_size):
        chunk = data[i:i + chunk_size]
        # Python's zlib binding discards all output of a decompress() call
        # that raises, so snapshot the stream state from before this chunk.
        checkpoint = decompressor.copy()
        try:
            pieces.append(decompressor.decompress(chunk))
        except zlib.error as e:
            if VERBOSE:
                print(f"Zlib error encountered at chunk {i}: {e}. Stopping decompression.")
            # Replay the failing chunk one byte at a time from the snapshot to
            # salvage everything that decodes before the corrupt byte.
            for pos in range(len(chunk)):
                try:
                    pieces.append(checkpoint.decompress(chunk[pos:pos + 1]))
                except zlib.error:
                    break
            else:
                pieces.append(checkpoint.flush())
            return b"".join(pieces)

    # No error: return the full output.
    return b"".join(pieces) + decompressor.flush()

def decodeb64(encoded_data: bytes, min_b64_output_bytes: int = 12) -> bytes:
    """
    Best-effort Base64 decode with a plain-text fallback.

    ``encoded_data`` is stripped of surrounding whitespace. It is then decoded
    in 4-character blocks with the strict Base64 alphabet (``validate=True``)
    until the input is exhausted or a block fails to decode.

    * If fewer than ``min_b64_output_bytes`` bytes were decoded, the input is
      treated as not Base64. This applies whether decoding stopped early or ran
      to the end. The stripped input is returned with all non-printable ASCII
      bytes removed.
    * Otherwise the decoded bytes are returned. This is either the full decode,
      or the partial decode up to (but not including) the first invalid block.

    :param encoded_data: candidate Base64 bytes.
    :param min_b64_output_bytes: minimum decoded length for the input to be
        accepted as Base64.
    :return: decoded bytes, or the cleaned printable text of the input.
    """
    encoded_data = encoded_data.strip()
    # bytearray appends are amortized O(1); repeated ``bytes +=`` is quadratic.
    decoded_output = bytearray()
    block_size = 4

    for i in range(0, len(encoded_data), block_size):
        block = encoded_data[i:i + block_size]
        try:
            decoded_output += base64.b64decode(block, validate=True)
        except binascii.Error:
            # Too little decoded so far: treat the input as plain text.
            if len(decoded_output) < min_b64_output_bytes:
                if VERBOSE:
                    print(f"B64 decode failed after only {len(decoded_output)} bytes. Falling back to plain text.")
                return _clean_unprintable_bytes(encoded_data)

            # Enough decoded: return the partial binary output.
            if VERBOSE:
                print(f"B64 decode failed after {len(decoded_output)} bytes (>= {min_b64_output_bytes}). Returning partial output.")
            return bytes(decoded_output)

    # Decoded every block, but the result may still be too short to trust.
    if len(decoded_output) < min_b64_output_bytes:
        if VERBOSE:
            print(f"Full B64 decode resulted in only {len(decoded_output)} bytes (< {min_b64_output_bytes}). Falling back to plain text.")
        return _clean_unprintable_bytes(encoded_data)

    return bytes(decoded_output)

# --- Helper Function for Plain Text Cleaning ---
def _clean_unprintable_bytes(data: bytes) -> bytes:
    """Decodes bytes to string and removes all non-printable ASCII characters."""
    
    # Convert the raw bytes to a string.
    # Using 'ascii' and 'ignore' errors to handle non-ASCII bytes gracefully.
    text_input = data.decode('ascii', errors='ignore')

    # Regex to find and remove non-standard printable characters.
    # [^\x20-\x7E\n\r\t] matches anything outside the ASCII printable range.
    RE_UNPRINTABLE = re.compile(r'[^\x20-\x7E\n\r\t]')
    
    cleaned_text = RE_UNPRINTABLE.sub('', text_input)
    
    # Return the cleaned string (type is 'str')
    return cleaned_text.encode()


def extract_data(filename):
    """Recover a payload hidden in an image file and write it next to the file.

    The bytes of ``filename`` that follow the first ``b'\\x1b$)C'`` marker are
    kept, with NUL bytes removed. They are passed through ``decodeb64`` and then
    ``decompress``. The decompressed bytes are preferred, then the Base64
    result. If both are empty, the marker-trimmed bytes are kept. A non-empty
    result is written to ``<filename>.extracted``.

    :param filename: path of the image file to scan.
    :return: the extracted bytes (``b''`` if the marker is absent), or
        ``None`` if an unexpected error occurred.
    """
    # ESC $ ) C is the ISO-2022-KR designator sequence. It is presumably the
    # prefix added by the PHP iconv filter chain in the referenced exploit;
    # confirm against the payload generator.
    marker = b'\x1b$)C'
    try:
        # Read everything up front so the file is closed before decoding.
        with open(filename, 'rb') as f:
            raw = f.read()

        _, found, payload = raw.partition(marker)
        if not found:
            if VERBOSE:
                print(f'Marker {marker!r} not found in {filename}; nothing to extract.')
            return b''
        data = payload.replace(b'\x00', b'')

        b64_decoded_data = decodeb64(data)
        decompressed_data = decompress(b64_decoded_data)
        if decompressed_data:
            data = decompressed_data
        elif b64_decoded_data:
            data = b64_decoded_data

        if VERBOSE:
            print(data)

        if data:
            extracted_filename = filename + '.extracted'
            with open(extracted_filename, 'wb') as out:
                out.write(data)
            print(f'Wrote extracted data to: {extracted_filename}')
        return data

    except Exception as e:
        # Top-level boundary for one image: report and let the caller continue.
        print(f'Unexpected error extracting data: {e}')
        return None

try:
    # The document is used as a context manager so it is closed even if
    # iterating its pages raises.
    with fitz.open(PDF_FILE_PATH) as pdf_file:
        for page_number, page in enumerate(pdf_file, start=1):
            image_list = page.get_images(full=True)
            if not image_list:
                continue

            if VERBOSE:
                print(f"Found {len(image_list)} images on page {page_number}")

            for image_index, img in enumerate(image_list, start=1):
                xref = img[0]  # first field of each get_images() entry is the image xref

                try:
                    # Use PyMuPDF's pixmap to get the raw image data
                    pix = fitz.Pixmap(pdf_file, xref)

                    # BMP output has no alpha channel: drop it without altering
                    # the color bytes. Converting the colorspace would keep alpha.
                    if pix.alpha:
                        pix = fitz.Pixmap(pix, 0)
                    # frombytes("RGB") needs exactly 3 bytes per pixel, so convert
                    # gray, CMYK and other colorspaces to RGB first.
                    if pix.n != 3:
                        pix = fitz.Pixmap(fitz.csRGB, pix)

                    pil_image = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

                    image_filename = f"page{page_number}_img{image_index}.bmp"
                    pil_image.save(image_filename, "BMP")
                    if VERBOSE:
                        print(f"Saved original image as BMP: {image_filename}")

                    extract_data(image_filename)

                except Exception as e:
                    print(f"Error processing image {image_index} on page {page_number}: {e}")
                print()

except Exception as e:
    print(f"An error occurred: {e}")
