4837 Total CVEs
26 Years
GitHub
README.md
Rendering markdown...
POC / CVE-2025-53770.py PY
#!/usr/bin/env python3
"""
CVE-2025-53770 SharePoint Vulnerability Scanner.

A comprehensive security scanner designed to identify SharePoint instances vulnerable
to CVE-2025-53770, which involves a deserialization vulnerability in SharePoint's
ExcelDataSet component that allows remote code execution and machine key extraction.

This scanner incorporates real-world attack patterns observed in active exploitation
campaigns and provides detailed confidence scoring based on machine key extraction,
secondary payload deployment, and SharePoint component processing indicators.

"""

import json
import csv
import ssl
import socket
import re
import time
import random
import logging
import hashlib
import asyncio
from datetime import datetime, timedelta
from concurrent.futures import ThreadPoolExecutor, as_completed
from urllib.parse import urlparse
from pathlib import Path
from typing import Dict, List, Optional, Union, Any, Tuple
from dataclasses import dataclass, asdict
from contextlib import contextmanager, suppress
import threading
import warnings

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
import urllib3


@dataclass
class ScanResult:
    """Outcome of scanning one host/endpoint combination.

    The container fields (``vulnerability_indicators`` and the ``*_info`` /
    ``security_headers`` / ``scan_metrics`` dicts) default to ``None`` in the
    generated ``__init__`` and are replaced with fresh, per-instance empty
    containers in ``__post_init__``. Passing ``None`` explicitly has the same
    effect, so ``ScanResult(**asdict(existing))`` round-trips cleanly.
    """

    host: str
    url: str
    scan_time: str  # ISO-8601 timestamp string supplied by the caller
    vulnerable: bool = False
    status_code: Optional[int] = None
    response_size: int = 0  # length of the decoded response body, in characters
    error: Optional[str] = None
    response_time: Optional[float] = None  # seconds
    request_size: int = 0
    detection_confidence: str = "none"  # "none" | "low" | "medium" | "high" | "critical"
    confidence_score: int = 0
    vulnerability_indicators: Optional[List[str]] = None
    ssl_info: Optional[Dict[str, Any]] = None
    sharepoint_info: Optional[Dict[str, Any]] = None
    security_headers: Optional[Dict[str, Any]] = None
    sharepoint_version_hint: str = "Unknown"
    endpoint_tested: str = ""
    cached_result: bool = False  # True when the result was served from ScanCache
    scan_metrics: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        # Mutable defaults cannot be declared directly on a dataclass field
        # (they would be shared between instances), so allocate them here.
        if self.vulnerability_indicators is None:
            self.vulnerability_indicators = []
        if self.ssl_info is None:
            self.ssl_info = {}
        if self.sharepoint_info is None:
            self.sharepoint_info = {}
        if self.security_headers is None:
            self.security_headers = {}
        if self.scan_metrics is None:
            self.scan_metrics = {}


class ConfigManager:
    """Load scanner configuration from a JSON file and expose accessors.

    When the file is missing, unreadable, not valid JSON, or its top level is
    not a JSON object, the built-in defaults from ``_get_default_config`` are
    used instead. Regex-based detection rules are compiled once at
    construction; rules with an invalid regex are logged and skipped.
    """

    def __init__(self, config_path: str = "config.json"):
        self.config_path = config_path
        self.config = self._load_config()
        self._compile_patterns()

    def _load_config(self) -> Dict[str, Any]:
        """Load configuration from ``self.config_path``, falling back to defaults.

        Returns:
            The parsed JSON object, or the default configuration when the file
            cannot be read or does not contain a JSON object.
        """
        try:
            with open(self.config_path, "r", encoding="utf-8") as f:
                loaded = json.load(f)
        except FileNotFoundError:
            logging.warning(f"Config file {self.config_path} not found, using defaults")
            return self._get_default_config()
        except json.JSONDecodeError as e:
            logging.error(f"Invalid JSON in config file: {e}")
            return self._get_default_config()
        except (OSError, UnicodeDecodeError) as e:
            # e.g. permission denied, path is a directory, non-UTF-8 bytes
            logging.error(f"Could not read config file {self.config_path}: {e}")
            return self._get_default_config()

        # Every accessor below calls .get() on the config, so anything other
        # than a JSON object would crash later; reject it up front.
        if not isinstance(loaded, dict):
            logging.error(
                f"Config file {self.config_path} must contain a JSON object, "
                f"got {type(loaded).__name__}; using defaults"
            )
            return self._get_default_config()
        return loaded

    def _get_default_config(self) -> Dict[str, Any]:
        """Default configuration if file is missing"""
        return {
            "detection_rules": {
                "critical_patterns": [
                    {
                        "name": "machine_key_extraction",
                        "pattern": r"[A-F0-9]{128,256}\|[A-Z0-9]+\|[A-F0-9]{48,96}\|[A-Z0-9]+\|Framework[0-9A-Z]+",
                        "score": 95,
                        "description": "Full machine key extraction response detected",
                        "case_insensitive": True
                    }
                ],
                "high_patterns": [],
                "medium_patterns": [
                    {
                        "name": "sharepoint_components",
                        "patterns": ["Scorecard", "ExcelDataSet"],
                        "score": 25,
                        "description": "SharePoint vulnerable components"
                    }
                ],
                "low_patterns": []
            },
            "confidence_thresholds": {
                "critical": 85,
                "high": 75,
                "medium": 60,
                "low": 50,
            },
            "scan_settings": {
                "default_timeout": 10,
                "default_threads": 10,
                "ssl_verification": True,
            },
            "endpoints": [
                "/_layouts/15/ToolPane.aspx?DisplayMode=Edit&a=/ToolPane.aspx"
            ],
        }

    def _compile_patterns(self) -> None:
        """Pre-compile regex rules into ``self.compiled_patterns``.

        Rules with a ``"pattern"`` key get an extra ``"compiled_pattern"``
        entry; other rules (e.g. literal ``"patterns"`` lists) are kept as-is.
        Invalid regexes and non-object rule entries are logged and dropped
        instead of aborting construction.
        """
        self.compiled_patterns = {}
        detection_rules = self.config.get("detection_rules", {})
        if not isinstance(detection_rules, dict):
            logging.error("'detection_rules' must be a JSON object; ignoring it")
            return

        for category, patterns in detection_rules.items():
            compiled_rules = []
            for pattern_config in patterns or []:
                if not isinstance(pattern_config, dict):
                    logging.error(
                        f"Skipping non-object rule in {category}: {pattern_config!r}"
                    )
                    continue
                if "pattern" not in pattern_config:
                    compiled_rules.append(pattern_config)
                    continue

                flags = (
                    re.IGNORECASE
                    if pattern_config.get("case_insensitive", False)
                    else 0
                )
                try:
                    compiled = re.compile(pattern_config["pattern"], flags)
                except (re.error, TypeError) as e:
                    logging.error(
                        f"Skipping rule {pattern_config.get('name', '<unnamed>')!r} "
                        f"in {category}: invalid regex ({e})"
                    )
                    continue
                compiled_rules.append({**pattern_config, "compiled_pattern": compiled})
            self.compiled_patterns[category] = compiled_rules

    def get_detection_rules(self) -> Dict[str, List[Dict]]:
        """Get compiled detection rules, keyed by category name."""
        return self.compiled_patterns

    def get_confidence_thresholds(self) -> Dict[str, int]:
        """Get confidence thresholds"""
        return self.config.get("confidence_thresholds", {})

    def get_scan_settings(self) -> Dict[str, Any]:
        """Get scan settings"""
        return self.config.get("scan_settings", {})

    def get_endpoints(self) -> List[str]:
        """Get endpoints to test"""
        return self.config.get("endpoints", [])

    def get_payload_config(self) -> Dict[str, str]:
        """Get payload configuration"""
        return self.config.get("payload_config", {})

    def get_request_headers(self) -> Dict[str, str]:
        """Get request headers configuration"""
        return self.config.get("request_headers", {})

    def get_rate_limiting(self) -> Dict[str, Any]:
        """Get rate limiting configuration"""
        return self.config.get("rate_limiting", {})

    def get_caching(self) -> Dict[str, Any]:
        """Get caching configuration"""
        return self.config.get("caching", {})

    def get_metrics(self) -> Dict[str, Any]:
        """Get metrics configuration"""
        return self.config.get("metrics", {})


class RateLimiter:
    """Thread-safe token-bucket rate limiter.

    Tokens refill continuously at ``requests_per_second`` up to a cap of
    ``burst_size``; each request consumes one token. The bucket starts full.

    Raises:
        ValueError: if ``requests_per_second`` is not positive or
            ``burst_size`` is less than 1 (either would make
            ``wait_for_token`` divide by zero or never return).
    """

    def __init__(self, requests_per_second: float = 10, burst_size: int = 20):
        if requests_per_second <= 0:
            raise ValueError("requests_per_second must be positive")
        if burst_size < 1:
            raise ValueError("burst_size must be at least 1")
        self.requests_per_second = requests_per_second
        self.burst_size = burst_size
        self.tokens = burst_size
        # Monotonic clock: unaffected by wall-clock jumps (NTP sync, manual
        # changes), which would otherwise add or remove tokens spuriously.
        self.last_update = time.monotonic()
        self.lock = threading.Lock()

    def _refill(self) -> None:
        """Credit tokens accrued since the last update. Caller must hold ``self.lock``."""
        now = time.monotonic()
        elapsed = now - self.last_update
        self.tokens = min(self.burst_size, self.tokens + elapsed * self.requests_per_second)
        self.last_update = now

    def acquire(self) -> bool:
        """Take one token without blocking.

        Returns:
            True if a token was available and consumed, False otherwise.
        """
        with self.lock:
            self._refill()
            if self.tokens >= 1:
                self.tokens -= 1
                return True
            return False

    def wait_for_token(self):
        """Block until a token is available, then consume it."""
        while True:
            with self.lock:
                self._refill()
                if self.tokens >= 1:
                    self.tokens -= 1
                    return
                # Sleep just long enough for the missing fraction of a token
                # to accrue, rather than polling at a fixed interval.
                wait = (1 - self.tokens) / self.requests_per_second
            time.sleep(wait)


class ScanCache:
    """Thread-safe, JSON-file-backed cache of ScanResult objects.

    Entries are keyed by an MD5 digest of ``"host:endpoint"``. Entries older
    than ``duration_seconds`` are treated as misses and evicted lazily on
    lookup. The whole cache is rewritten to ``cache_file`` on every ``put``.
    """

    def __init__(self, cache_file: str = "scan_cache.json", duration_seconds: int = 3600):
        self.cache_file = cache_file
        self.duration = timedelta(seconds=duration_seconds)
        # key -> {"timestamp": datetime, "result": dict produced by asdict()}
        self.cache = {}
        self.lock = threading.Lock()
        self._load_cache()

    def _load_cache(self):
        """Populate ``self.cache`` from ``self.cache_file`` if it exists.

        Malformed entries are skipped one by one, so a single bad record does
        not discard the rest of the cache.
        """
        try:
            path = Path(self.cache_file)
            if not path.exists():
                return
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError) as e:
            logging.warning(f"Failed to load cache: {e}")
            return

        if not isinstance(data, dict):
            logging.warning(f"Failed to load cache: {self.cache_file} does not contain a JSON object")
            return

        for key, value in data.items():
            try:
                # Stored as ISO-8601 text; convert back for age comparisons.
                timestamp = datetime.fromisoformat(value['timestamp'])
            except (TypeError, KeyError, ValueError) as e:
                logging.warning(f"Skipping malformed cache entry {key}: {e}")
                continue
            # Result stays a plain dict; it is turned into a ScanResult on get().
            self.cache[key] = {**value, 'timestamp': timestamp}

    def _save_cache(self):
        """Persist ``self.cache`` to disk. Called with ``self.lock`` held.

        Writes to a sibling temporary file and then renames it over the real
        file, so a crash mid-write leaves the previous cache intact instead of
        a truncated JSON file.
        """
        try:
            cache_data = {
                key: {**value, 'timestamp': value['timestamp'].isoformat()}
                for key, value in self.cache.items()
            }
            target = Path(self.cache_file)
            tmp_path = target.with_name(target.name + ".tmp")
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(cache_data, f, indent=2)
            tmp_path.replace(target)
        except Exception as e:
            # Caching is best-effort; a failed save must not abort the scan.
            logging.warning(f"Failed to save cache: {e}")

    def _get_cache_key(self, host: str, endpoint: str) -> str:
        """Generate cache key for host/endpoint combination (not a security use of MD5)."""
        return hashlib.md5(f"{host}:{endpoint}".encode()).hexdigest()

    def get(self, host: str, endpoint: str) -> Optional[ScanResult]:
        """Return a fresh cached result for host/endpoint, or None on a miss.

        Expired entries, and entries that can no longer be turned into a
        ScanResult (e.g. written by a version with different fields), are
        evicted and reported as misses. Returned results have
        ``cached_result`` set to True.
        """
        with self.lock:
            key = self._get_cache_key(host, endpoint)
            entry = self.cache.get(key)
            if entry is None:
                return None

            if datetime.now() - entry['timestamp'] >= self.duration:
                del self.cache[key]
                return None

            result_data = entry.get('result')
            if isinstance(result_data, ScanResult):
                result = result_data  # Handle old format
            elif isinstance(result_data, dict):
                try:
                    result = ScanResult(**result_data)
                except TypeError as e:
                    # Unknown/missing fields: schema changed since the entry was written.
                    logging.warning(f"Discarding incompatible cache entry for {host}: {e}")
                    del self.cache[key]
                    return None
            else:
                del self.cache[key]
                return None

            result.cached_result = True
            return result

    def put(self, host: str, endpoint: str, result: ScanResult):
        """Store *result* for host/endpoint, timestamped now, and persist the cache."""
        with self.lock:
            key = self._get_cache_key(host, endpoint)
            self.cache[key] = {
                'timestamp': datetime.now(),
                'result': asdict(result)  # Convert dataclass to dict for JSON serialization
            }
            self._save_cache()


class MetricsCollector:
    """Thread-safe aggregation of scan counters and timing statistics."""

    def __init__(self):
        self.metrics = {
            'total_scans': 0,
            'successful_scans': 0,
            'vulnerable_found': 0,
            'average_response_time': 0.0,
            'cache_hits': 0,
            'rate_limited_requests': 0,
            'ssl_errors': 0,
            'scan_start_time': None,
            'scan_end_time': None
        }
        self.response_times = []
        # Running sum of response_times so each average update is O(1)
        # instead of re-summing the whole list on every scan.
        self._response_time_total = 0.0
        self.lock = threading.Lock()

    def start_scan(self):
        """Mark scan start time"""
        with self.lock:
            self.metrics['scan_start_time'] = datetime.now()

    def end_scan(self):
        """Mark scan end time"""
        with self.lock:
            self.metrics['scan_end_time'] = datetime.now()

    def record_scan(self, result: "ScanResult"):
        """Update counters from one scan result.

        Results without an error count as successful; only those contribute
        a (truthy) response time to the running average. Any error whose text
        contains "ssl" (case-insensitive) is counted as an SSL error.
        """
        with self.lock:
            m = self.metrics
            m['total_scans'] += 1

            if not result.error:
                m['successful_scans'] += 1

                if result.response_time:
                    self.response_times.append(result.response_time)
                    self._response_time_total += result.response_time
                    m['average_response_time'] = (
                        self._response_time_total / len(self.response_times)
                    )

            if result.vulnerable:
                m['vulnerable_found'] += 1

            if result.cached_result:
                m['cache_hits'] += 1

            if 'ssl' in str(result.error).lower():
                m['ssl_errors'] += 1

    def record_rate_limit(self):
        """Record rate limiting event"""
        with self.lock:
            self.metrics['rate_limited_requests'] += 1

    def get_metrics(self) -> Dict[str, Any]:
        """Return a snapshot of the counters.

        When both start and end times are set, the snapshot also includes
        ``total_scan_duration`` (seconds) and ``scans_per_second``.
        """
        with self.lock:
            metrics = self.metrics.copy()
            start, end = metrics['scan_start_time'], metrics['scan_end_time']
            if start and end:
                seconds = (end - start).total_seconds()
                metrics['total_scan_duration'] = seconds
                metrics['scans_per_second'] = metrics['total_scans'] / seconds if seconds > 0 else 0
            return metrics


class SharePointScanner:
    """Scanner for SharePoint hosts potentially affected by CVE-2025-53770.

    Combines a ``ConfigManager`` with optional rate limiting, on-disk result
    caching and metrics collection. Each optional component is enabled unless
    its config section sets ``"enabled": false``.
    """

    # Detection-rule categories scored by analyze_response, in the order
    # their indicators are reported.
    _RULE_CATEGORIES = (
        "critical_patterns",
        "high_patterns",
        "medium_patterns",
        "low_patterns",
    )

    def __init__(self, config_path: str = "config.json"):
        self.config_manager = ConfigManager(config_path)
        self.logger = logging.getLogger(__name__)

        # Initialize enhanced components (each enabled unless configured off)
        rate_config = self.config_manager.get_rate_limiting()
        if rate_config.get('enabled', True):
            self.rate_limiter = RateLimiter(
                requests_per_second=rate_config.get('requests_per_second', 10),
                burst_size=rate_config.get('burst_size', 20)
            )
        else:
            self.rate_limiter = None

        cache_config = self.config_manager.get_caching()
        if cache_config.get('enabled', True):
            self.cache = ScanCache(
                cache_file=cache_config.get('cache_file', 'scan_cache.json'),
                duration_seconds=cache_config.get('cache_duration_seconds', 3600)
            )
        else:
            self.cache = None

        metrics_config = self.config_manager.get_metrics()
        if metrics_config.get('enabled', True):
            self.metrics = MetricsCollector()
        else:
            self.metrics = None

        # Context-specific SSL warnings handling (not read within this class)
        self._ssl_context = None

    @staticmethod
    def get_ssl_info(host: str, port: int = 443, timeout: int = 10) -> Dict[str, Any]:
        """Return certificate and cipher details for ``host:port``.

        Uses ``ssl.create_default_context()`` (certificate verification on),
        so hosts with untrusted or mismatched certificates yield
        ``{"error": ...}``. Any other failure is reported the same way rather
        than raised.
        """
        try:
            context = ssl.create_default_context()
            with socket.create_connection((host, port), timeout=timeout) as sock:
                with context.wrap_socket(sock, server_hostname=host) as ssock:
                    cert = ssock.getpeercert()
                    return {
                        "subject": dict(x[0] for x in cert.get("subject", [])),
                        "issuer": dict(x[0] for x in cert.get("issuer", [])),
                        "version": cert.get("version"),
                        "serial_number": cert.get("serialNumber"),
                        "not_before": cert.get("notBefore"),
                        "not_after": cert.get("notAfter"),
                        "cipher": ssock.cipher(),
                        "compression": ssock.compression(),
                    }
        except Exception as e:
            return {"error": str(e)}

    @staticmethod
    def detect_sharepoint_version(
        response_headers: Dict[str, str], response_text: str
    ) -> Dict[str, Any]:
        """Collect server/ASP.NET headers, SharePoint headers and a version guess.

        The version guess comes from the first of several regexes that matches
        the response body; it is a heuristic, not an authoritative version.
        """
        version_info = {
            "sharepoint_version": "Unknown",
            "server_info": response_headers.get("Server", "Unknown"),
            "asp_net_version": response_headers.get("X-AspNet-Version", "Unknown"),
            "sharepoint_headers": {},
        }

        # Check for SharePoint-specific headers
        sp_headers = [
            "MicrosoftSharePointTeamServices",
            "X-SharePointError",
            "X-MS-SPError",
            "SPIisLatency",
            "SPRequestGuid",
            "X-MS-InvokeApp",
        ]

        for header in sp_headers:
            if header in response_headers:
                version_info["sharepoint_headers"][header] = response_headers[header]

        # Try to detect version from response text patterns (first match wins)
        version_patterns = [
            (r"SharePoint.*?(\d{4})", "SharePoint {version}"),
            (r"Microsoft.*?SharePoint.*?(\d+\.\d+)", "SharePoint {version}"),
            (r"_layouts/(\d+)/", "SharePoint {version} (Layout Version)"),
        ]

        for pattern, format_str in version_patterns:
            match = re.search(pattern, response_text, re.IGNORECASE)
            if match:
                version_info["sharepoint_version"] = format_str.format(
                    version=match.group(1)
                )
                break

        return version_info

    def create_session_with_retries(self) -> requests.Session:
        """Create a requests session with a retry strategy from scan settings.

        The caller owns the returned session and should close it (e.g. use it
        as a context manager).
        """
        session = requests.Session()
        settings = self.config_manager.get_scan_settings()

        # NOTE(review): urllib3's Retry defaults ``allowed_methods`` to
        # idempotent verbs, which excludes POST, so the status_forcelist
        # retries likely never apply to the POST in _scan_endpoint. Confirm
        # against the installed urllib3 version before relying on them.
        retry_strategy = Retry(
            total=settings.get("max_retries", 3),
            backoff_factor=settings.get("backoff_factor", 1),
            status_forcelist=[429, 500, 502, 503, 504],
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        return session

    @staticmethod
    def _match_rule(pattern_config: Dict[str, Any], text: str) -> Optional[str]:
        """Return the indicator description if one detection rule matches *text*, else None.

        Regex rules (``compiled_pattern``) match via ``search``; literal rules
        (``patterns``) match on the first substring found.
        """
        description = pattern_config.get(
            "description", pattern_config.get("name", "unnamed rule")
        )
        if "compiled_pattern" in pattern_config:
            if pattern_config["compiled_pattern"].search(text):
                return description
            return None
        for literal in pattern_config.get("patterns", []):
            if literal in text:
                return f"{description}: {literal}"
        return None

    def analyze_response(
        self, response: requests.Response, response_time: float, url: str
    ) -> Tuple[List[str], int, str]:
        """Score a response against all configured detection rules.

        Every rule category in ``_RULE_CATEGORIES`` is evaluated; each matching
        rule contributes its ``score`` once. Fixed bonuses are added for an
        HTTP 200 response and for a SharePoint 15 layout path.

        Returns:
            ``(indicators, confidence_score, version_hint)``. ``response_time``
            is accepted for interface compatibility but not used.
        """
        vulnerability_indicators = []
        confidence_score = 0
        detection_rules = self.config_manager.get_detection_rules()
        # requests decodes the body on every ``.text`` access; do it once.
        text = response.text

        for category in self._RULE_CATEGORIES:
            for pattern_config in detection_rules.get(category, []):
                indicator = self._match_rule(pattern_config, text)
                if indicator is not None:
                    vulnerability_indicators.append(indicator)
                    confidence_score += pattern_config.get("score", 0)

        # Response characteristics analysis
        if response.status_code == 200:
            if len(text) > 200:
                vulnerability_indicators.append(
                    f"HTTP 200 with substantial content ({len(text)} bytes)"
                )
                confidence_score += 15
            else:
                vulnerability_indicators.append("HTTP 200 response to crafted payload")
                confidence_score += 5

        # Version-specific indicators
        version_hint = "Unknown"
        if "WEBSER~1\\15\\TEMPLATE" in text or "_layouts/15/" in url:
            vulnerability_indicators.append(
                "SharePoint 2013/2016 layout structure detected"
            )
            version_hint = "2013-2016"
            confidence_score += 2

        return vulnerability_indicators, confidence_score, version_hint

    def determine_confidence_level(self, confidence_score: int) -> Tuple[bool, str]:
        """Map a score onto ``(is_vulnerable, level)`` using configured thresholds."""
        thresholds = self.config_manager.get_confidence_thresholds()

        if confidence_score >= thresholds.get("critical", 85):
            return True, "critical"
        elif confidence_score >= thresholds.get("high", 75):
            return True, "high"
        elif confidence_score >= thresholds.get("medium", 60):
            return True, "medium"
        elif confidence_score >= thresholds.get("low", 50):
            return True, "low"
        else:
            return False, "none"

    def scan_host(
        self,
        host: str,
        timeout: Optional[int] = None,
        additional_mode: bool = False,
        delay_range: Optional[Tuple[float, float]] = None,
        ssl_check: bool = False,
    ) -> ScanResult:
        """Scan every configured endpoint on *host* and return the most significant result.

        Results are ranked by confidence score; on a tie, a result without an
        error beats one with an error. If no endpoints are configured, an
        error result is returned.
        """
        if timeout is None:
            timeout = self.config_manager.get_scan_settings().get("default_timeout", 10)

        def _rank(r: ScanResult) -> Tuple[int, bool]:
            return (r.confidence_score, r.error is None)

        best_result: Optional[ScanResult] = None

        for endpoint in self.config_manager.get_endpoints():
            try:
                result = self._scan_endpoint(
                    host, endpoint, timeout, additional_mode, delay_range, ssl_check
                )
            except Exception as e:
                self.logger.error(
                    f"Error scanning {host} with endpoint {endpoint}: {e}"
                )
                if best_result is None:
                    best_result = ScanResult(
                        host=host,
                        url=f"https://{host}{endpoint}",
                        scan_time=datetime.now().isoformat(),
                        error=str(e),
                        endpoint_tested=endpoint,
                    )
                continue

            if best_result is None or _rank(result) > _rank(best_result):
                best_result = result

        return best_result or ScanResult(
            host=host,
            url=f"https://{host}",
            scan_time=datetime.now().isoformat(),
            error="No valid endpoints to test",
        )

    @staticmethod
    def _extract_security_headers(headers) -> Dict[str, Any]:
        """Return the present (non-empty) security-related response headers."""
        candidates = {
            "strict_transport_security": headers.get("Strict-Transport-Security"),
            "content_security_policy": headers.get("Content-Security-Policy"),
            "x_frame_options": headers.get("X-Frame-Options"),
            "x_content_type_options": headers.get("X-Content-Type-Options"),
            "x_xss_protection": headers.get("X-XSS-Protection"),
        }
        return {k: v for k, v in candidates.items() if v}

    def _scan_endpoint(
        self,
        host: str,
        endpoint: str,
        timeout: int,
        additional_mode: bool,
        delay_range: Optional[Tuple[float, float]],
        ssl_check: bool,
    ) -> ScanResult:
        """Probe one endpoint of *host* and score the response.

        A cached result (when caching is enabled) is returned before any
        delay, rate-limit wait or network traffic. Request and analysis
        errors are caught and stored in ``result.error`` instead of raised.
        """
        url = f"https://{host}{endpoint}"

        # Check the cache first so cached hits neither sleep nor consume
        # rate-limit tokens; record them so the 'cache_hits' metric counts.
        if self.cache:
            cached_result = self.cache.get(host, endpoint)
            if cached_result:
                self.logger.info(f"Using cached result for {host}")
                if self.metrics:
                    self.metrics.record_scan(cached_result)
                return cached_result

        # Apply delay if specified
        if delay_range:
            time.sleep(random.uniform(delay_range[0], delay_range[1]))

        # Apply rate limiting if enabled
        if self.rate_limiter and not self.rate_limiter.acquire():
            if self.metrics:
                self.metrics.record_rate_limit()
            self.rate_limiter.wait_for_token()

        # Prepare request with config-based payload and headers
        payload_config = self.config_manager.get_payload_config()
        if not payload_config:
            # Fallback to minimal payload if config is missing
            payload_config = {
                "MSOTlPn_Uri": "https://{host}/_controltemplates/15/AclEditor.ascx",
                "MSOTlPn_DWP": "<test payload>"
            }

        data = payload_config.copy()
        data["MSOTlPn_Uri"] = data["MSOTlPn_Uri"].format(host=host)

        headers_config = self.config_manager.get_request_headers()
        headers = headers_config.copy() if headers_config else {
            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:120.0) Gecko/20100101 Firefox/120.0",
            "Content-Type": "application/x-www-form-urlencoded"
        }
        if additional_mode:
            user_agents = self.config_manager.get_scan_settings().get(
                "user_agents", [headers.get("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36")]
            )
            headers["User-Agent"] = random.choice(user_agents)
            headers["Connection"] = "close"

        # Approximate request size (URL + serialized form data + headers)
        request_size = len(f"{url}{json.dumps(data)}{json.dumps(headers)}")

        result = ScanResult(
            host=host,
            url=url,
            scan_time=datetime.now().isoformat(),
            request_size=request_size,
            endpoint_tested=endpoint,
        )

        try:
            if ssl_check:
                result.ssl_info = self.get_ssl_info(urlparse(url).hostname)

            verify_ssl = self.config_manager.get_scan_settings().get("ssl_verification", True)

            # The session is closed (releasing its connection pool) as soon as
            # the response has been read; previously one leaked per request.
            with self.create_session_with_retries() as session:
                with self._suppress_ssl_warnings(verify_ssl):
                    start_time = time.time()
                    response = session.post(
                        url,
                        headers=headers,
                        data=data,
                        verify=verify_ssl,
                        timeout=timeout,
                        allow_redirects=True,
                    )
                    response_time = time.time() - start_time

            body = response.text  # decode once; reused below
            result.status_code = response.status_code
            result.response_size = len(body)
            result.response_time = response_time

            result.security_headers = self._extract_security_headers(response.headers)

            # Detect SharePoint version and configuration
            result.sharepoint_info = self.detect_sharepoint_version(
                response.headers, body
            )

            # Analyze for vulnerabilities
            indicators, confidence_score, version_hint = self.analyze_response(
                response, response_time, url
            )
            result.vulnerability_indicators = indicators
            result.confidence_score = confidence_score
            result.sharepoint_version_hint = version_hint

            is_vulnerable, confidence_level = self.determine_confidence_level(
                confidence_score
            )
            result.vulnerable = is_vulnerable
            result.detection_confidence = confidence_level

            if self.metrics and self.config_manager.get_metrics().get('track_performance', True):
                result.scan_metrics = {
                    'request_start_time': start_time,
                    'response_time': response_time,
                    'status_code': response.status_code,
                    'response_size': len(body),
                    'endpoint_used': endpoint
                }

            if is_vulnerable:
                self.logger.info(
                    f"VULNERABLE: {host} - Confidence: {confidence_level} ({confidence_score}%) - "
                    f"Status: {response.status_code}, Size: {len(body)}, Time: {response_time:.2f}s"
                )
            else:
                self.logger.info(
                    f"NOT VULNERABLE: {host} - Status: {response.status_code}, "
                    f"Size: {len(body)}, Time: {response_time:.2f}s"
                )

        except Exception as e:
            result.error = str(e)[:500]  # Limit error message length
            self.logger.error(f"Error scanning {host}: {result.error}")

        # Only successful scans are cached, so transient failures are retried.
        if self.cache and not result.error:
            self.cache.put(host, endpoint, result)

        if self.metrics:
            self.metrics.record_scan(result)

        return result

    def save_results(self, results: List[ScanResult], output_file: str) -> None:
        """Write *results* as JSON, CSV, or plain text, chosen by file extension.

        ``.json`` → list of objects; ``.csv`` → one row per result with list/
        dict fields JSON-encoded; anything else → human-readable text report.
        """
        if output_file.endswith(".json"):
            with open(output_file, "w", encoding="utf-8") as f:
                json.dump(
                    [asdict(result) for result in results],
                    f,
                    indent=2,
                    ensure_ascii=False,
                )

        elif output_file.endswith(".csv"):
            with open(output_file, "w", newline="", encoding="utf-8") as f:
                if results:
                    fieldnames = list(asdict(results[0]).keys())
                    writer = csv.DictWriter(f, fieldnames=fieldnames)
                    writer.writeheader()
                    for result in results:
                        # Convert complex fields to JSON strings for CSV
                        row = asdict(result)
                        for key, value in row.items():
                            if isinstance(value, (list, dict)):
                                row[key] = json.dumps(value)
                        writer.writerow(row)

        else:  # Text format
            with open(output_file, "w", encoding="utf-8") as f:
                for result in results:
                    if result.vulnerable:
                        status = f"VULNERABLE ({result.detection_confidence.upper()} confidence)"
                    else:
                        status = "NOT VULNERABLE"

                    f.write(f"{status}: {result.host}\n")
                    if result.error:
                        f.write(f"  Error: {result.error}\n")
                    elif result.status_code:
                        # response_time can be None (e.g. results rebuilt from
                        # external data); formatting None with :.2f would raise.
                        elapsed = (
                            f"{result.response_time:.2f}s"
                            if result.response_time is not None
                            else "n/a"
                        )
                        f.write(
                            f"  Status: {result.status_code}, Size: {result.response_size}, "
                            f"Time: {elapsed}\n"
                        )
                        f.write(f"  Request Size: {result.request_size} bytes\n")
                        f.write(f"  Endpoint: {result.endpoint_tested}\n")

                    if result.vulnerability_indicators:
                        f.write("  Vulnerability Indicators:\n")
                        for indicator in result.vulnerability_indicators:
                            f.write(f"    - {indicator}\n")

                    f.write(f"  Scanned: {result.scan_time}\n")

                    if result.cached_result:
                        f.write("  Source: Cached Result\n")

                    f.write("\n")

    @contextmanager
    def _suppress_ssl_warnings(self, verify_ssl: bool):
        """Silence urllib3's InsecureRequestWarning while verification is disabled.

        NOTE(review): ``warnings.catch_warnings`` mutates process-global
        warning filters and is not thread-safe; concurrent scans may see
        filters restored early or late.
        """
        if not verify_ssl:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore", urllib3.exceptions.InsecureRequestWarning)
                yield
        else:
            yield

    def get_scan_metrics(self) -> Optional[Dict[str, Any]]:
        """Get scan metrics if available"""
        return self.metrics.get_metrics() if self.metrics else None


def setup_logging(verbose: bool = False, log_file: Optional[str] = None) -> None:
    """Reconfigure the root logger for console and optional file output.

    Args:
        verbose: Log at INFO when true, otherwise at WARNING.
        log_file: If given, also append log records to this path (UTF-8).

    Existing root handlers are replaced (``force=True``), so calling this
    again fully reconfigures logging.
    """
    fmt = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    level = logging.INFO if verbose else logging.WARNING

    # The console handler always comes first; the file handler is optional.
    sinks = [logging.StreamHandler()]
    if log_file:
        sinks.append(logging.FileHandler(log_file, encoding="utf-8"))

    for sink in sinks:
        sink.setFormatter(logging.Formatter(fmt))

    logging.basicConfig(level=level, format=fmt, handlers=sinks, force=True)


def print_results_summary(results: List[ScanResult]) -> None:
    """Print a human-readable summary of a completed scan to stdout.

    Reports the number of hosts scanned, the number flagged vulnerable, and
    the percentage of scans that finished without an error. It then lists
    every vulnerable host grouped by ``detection_confidence``. The known
    levels come first in severity order (critical, high, medium, low). Any
    other level follows in first-seen order, so every host counted as
    vulnerable is also listed.

    Args:
        results: Completed scan results. May be empty.
    """
    total_hosts = len(results)
    vulnerable_results = [r for r in results if r.vulnerable]
    successful_scans = [r for r in results if not r.error]

    success_rate = (len(successful_scans) / total_hosts * 100) if total_hosts > 0 else 0

    print(f"\n{'=' * 60}")
    print("SCAN COMPLETE")
    print(f"{'=' * 60}")
    print(f"Total hosts scanned: {total_hosts}")
    print(f"Vulnerable hosts: {len(vulnerable_results)}")
    print(f"Success rate: {success_rate:.1f}%")

    if not vulnerable_results:
        return

    print("\nVULNERABLE HOSTS (CVE-2025-53770):")

    # Group by confidence level, keeping scan order within each group.
    confidence_groups: Dict[str, List[ScanResult]] = {}
    for result in vulnerable_results:
        confidence_groups.setdefault(result.detection_confidence, []).append(result)

    # Show known levels in priority order, then any unexpected level. Dropping
    # an unexpected level would silently hide hosts counted as vulnerable above.
    priority_order = ["critical", "high", "medium", "low"]
    ordered_levels = [level for level in priority_order if level in confidence_groups]
    ordered_levels += [level for level in confidence_groups if level not in priority_order]

    for confidence in ordered_levels:
        results_group = confidence_groups[confidence]
        is_critical = confidence == "critical"

        if is_critical:
            print(f"\n  CRITICAL - MACHINE KEY EXTRACTED ({len(results_group)} hosts):")
        else:
            print(f"\n  {confidence.upper()} CONFIDENCE ({len(results_group)} hosts):")

        for result in results_group:
            # response_time is Optional; formatting None with :.2f would raise TypeError.
            rt = result.response_time
            rt_text = f"{rt:.2f}s" if rt is not None else "N/A"
            print(
                f"    • {result.host} (Response time: {rt_text}, "
                f"Version: {result.sharepoint_version_hint})"
            )
            if is_critical:
                print("      WARNING: IMMEDIATE ACTION REQUIRED: Machine keys compromised")


def _load_hosts(path: str) -> List[str]:
    """Return the target hosts listed in *path*, one per line.

    Surrounding whitespace is stripped. Blank lines and lines starting with
    ``#`` are skipped. I/O errors propagate to the caller.
    """
    with open(path, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        return [host for host in stripped_lines if host and not host.startswith("#")]


def _print_result_line(result: ScanResult, verbose: bool) -> None:
    """Print the one-line (plus optional indicator lines) progress entry for a scan."""
    if result.vulnerable:
        confidence_badge = f"[{result.detection_confidence.upper()}]"
        print(f"[+] VULNERABLE {confidence_badge}: {result.host}")
        if verbose and result.vulnerability_indicators:
            for indicator in result.vulnerability_indicators[:3]:  # Limit output
                print(f"    └─ {indicator}")
    elif result.error:
        print(f"[!] ERROR: {result.host} - {result.error}")
    else:
        status_msg = f"[-] Not vulnerable: {result.host}"
        # sharepoint_info may be None, and .get() returns None for a missing
        # key; indexing the dict in either case would raise. Only annotate
        # when a real version string is known.
        version = (result.sharepoint_info or {}).get("sharepoint_version")
        if verbose and version and version != "Unknown":
            status_msg += f" (SharePoint {version})"
        print(status_msg)


def _print_scan_metrics(metrics: Optional[Dict[str, Any]]) -> None:
    """Print the PERFORMANCE METRICS section for a scanner metrics dict.

    Prints nothing when *metrics* is None or empty.
    """
    if not metrics:
        return
    print(f"\n{'=' * 60}")
    print("PERFORMANCE METRICS")
    print(f"{'=' * 60}")
    print(f"Total scans: {metrics['total_scans']}")
    print(f"Successful scans: {metrics['successful_scans']}")
    print(f"Cache hits: {metrics['cache_hits']}")
    print(f"Average response time: {metrics['average_response_time']:.2f}s")
    if "total_scan_duration" in metrics:
        print(f"Total scan duration: {metrics['total_scan_duration']:.2f}s")
        print(f"Scans per second: {metrics['scans_per_second']:.2f}")
    if metrics["rate_limited_requests"] > 0:
        print(f"Rate limited requests: {metrics['rate_limited_requests']}")
    if metrics["ssl_errors"] > 0:
        print(f"SSL errors: {metrics['ssl_errors']}")


def main():
    """Command-line entry point.

    Parses arguments, loads the host list, runs ``scanner.scan_host`` for every
    host on a thread pool, and prints per-host progress. It then prints a
    summary and, if enabled, performance metrics. Results are saved when
    ``-o`` is given.

    Returns:
        0 on completion. 1 when the scanner cannot be initialized, or the host
        file cannot be read or contains no hosts. Invalid arguments, including
        an unopenable log file, exit via ``argparse`` with status 2.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="CVE-2025-53770 SharePoint Vulnerability Scanner",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  python3 scanner.py -i hosts.txt
  python3 scanner.py -i hosts.txt -o results.json -l scan.log -v
  python3 scanner.py -i hosts.txt --ssl-check --additional -t 20
        """,
    )

    parser.add_argument("-i", "--input", required=True, help="Path to host list file")
    parser.add_argument(
        "-o", "--output", help="Output file for results (supports .json, .csv, or .txt)"
    )
    parser.add_argument("-l", "--logfile", help="Log file path")
    parser.add_argument(
        "-t", "--threads", type=int, default=10, help="Number of concurrent threads"
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output")
    parser.add_argument(
        "--additional",
        action="store_true",
        help="Enable additional mode (random User-Agents, connection close)",
    )
    parser.add_argument(
        "--delay",
        type=float,
        nargs=2,
        metavar=("MIN", "MAX"),
        help="Random delay between requests (min max in seconds)",
    )
    parser.add_argument(
        "--ssl-check", action="store_true", help="Perform SSL certificate analysis"
    )
    parser.add_argument(
        "--timeout", type=int, default=None, help="Request timeout in seconds"
    )
    parser.add_argument(
        "--config", default="config.json", help="Configuration file path"
    )

    args = parser.parse_args()

    # ThreadPoolExecutor raises ValueError for max_workers <= 0. Report that as
    # a CLI error rather than a traceback. Delays are sleep durations and must
    # not be negative.
    if args.threads < 1:
        parser.error("--threads must be at least 1")
    if args.delay and min(args.delay) < 0:
        parser.error("--delay values must be non-negative")

    # Setup logging. FileHandler opens the log file immediately, so a bad path
    # fails here.
    try:
        setup_logging(args.verbose, args.logfile)
    except OSError as e:
        parser.error(f"cannot open log file {args.logfile}: {e}")
    logger = logging.getLogger(__name__)

    # Initialize scanner
    try:
        scanner = SharePointScanner(args.config)
    except Exception as e:
        logger.error(f"Failed to initialize scanner: {e}")
        return 1

    # Load hosts
    try:
        hosts = _load_hosts(args.input)
    except FileNotFoundError:
        logger.error(f"Host file not found: {args.input}")
        return 1
    except Exception as e:
        logger.error(f"Error reading host file: {e}")
        return 1

    if not hosts:
        logger.error("No valid hosts found in input file")
        return 1

    total = len(hosts)
    print(f"Starting scan of {total} hosts with {args.threads} threads...")
    print("Target CVE: CVE-2025-53770 (SharePoint ExcelDataSet deserialization)")

    if args.logfile:
        print(f"Logging to: {args.logfile}")
    if args.output:
        print(f"Results will be saved to: {args.output}")

    # Start metrics collection
    if scanner.metrics:
        scanner.metrics.start_scan()

    results = []
    completed = 0

    delay_range = tuple(args.delay) if args.delay else None

    # Scan hosts
    with ThreadPoolExecutor(max_workers=args.threads) as executor:
        futures = {
            executor.submit(
                scanner.scan_host,
                host,
                timeout=args.timeout,
                additional_mode=args.additional,
                delay_range=delay_range,
                ssl_check=args.ssl_check,
            ): host
            for host in hosts
        }

        for future in as_completed(futures):
            host = futures[future]
            # Count each finished future exactly once, whether it succeeded or
            # failed, so progress can never exceed the host total.
            completed += 1
            try:
                result = future.result()
                results.append(result)
                _print_result_line(result, args.verbose)
            except Exception as e:
                logger.error(f"Error processing scan result for {host}: {e}")

            # Progress indicator (also emitted after failures, so the final
            # N/N line always appears).
            if completed % 25 == 0 or completed == total:
                print(f"Progress: {completed}/{total} ({completed / total * 100:.1f}%)")

    # End metrics collection
    if scanner.metrics:
        scanner.metrics.end_scan()

    # Print summary
    print_results_summary(results)

    # Print performance metrics if available
    if scanner.metrics:
        _print_scan_metrics(scanner.get_scan_metrics())

    # Save results
    if args.output:
        try:
            scanner.save_results(results, args.output)
            print(f"\nDetailed results saved to: {args.output}")
        except Exception as e:
            logger.error(f"Error saving results: {e}")

    return 0


if __name__ == "__main__":
    # The builtin `exit` is injected by the `site` module and is missing under
    # `python -S` or in frozen builds. SystemExit works everywhere.
    raise SystemExit(main())