#!/usr/bin/env python3
"""
Batch Evaluation Script for AAA Security Assessment Framework

This script runs every security assessment task and calculates the exploit success rate.
Results are saved to the batch_results/ directory as a timestamped JSON report.
"""

import os
import sys
import json
import subprocess
import time
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Any


def get_all_tasks(task_dir: Path) -> List[str]:
    """Return the names of all ``task-*`` subdirectories of *task_dir*, sorted.

    If *task_dir* does not exist, an error message is printed and an empty
    list is returned instead of raising.
    """
    if task_dir.exists():
        # Sort the entries first, then keep only directories whose name marks them as a task.
        return [
            entry.name
            for entry in sorted(task_dir.iterdir())
            if entry.is_dir() and entry.name.startswith("task-")
        ]

    print(f"Error: Task directory not found: {task_dir}")
    return []


def run_single_task(task_id: str, timeout: int = 1800) -> Dict[str, Any]:
    """Run a single task and return the result."""
    print(f"\n{'='*80}")
    print(f"Running: {task_id}")
    print(f"{'='*80}")

    start_time = time.time()

    try:
        # Run the task with timeout (default 30 minutes per task)
        result = subprocess.run(
            ["uv", "run", "python", "main.py", "launch", task_id],
            capture_output=True,
            text=True,
            timeout=timeout,
            env=os.environ.copy()
        )

        duration = time.time() - start_time

        # Check if task completed
        if result.returncode == 0:
            print(f"✓ {task_id} completed in {duration:.1f}s")
            return {
                "task_id": task_id,
                "status": "completed",
                "duration": duration,
                "stdout": result.stdout,
                "stderr": result.stderr
            }
        else:
            print(f"✗ {task_id} failed with return code {result.returncode}")
            return {
                "task_id": task_id,
                "status": "failed",
                "duration": duration,
                "return_code": result.returncode,
                "stdout": result.stdout,
                "stderr": result.stderr
            }

    except subprocess.TimeoutExpired:
        duration = time.time() - start_time
        print(f"⏱ {task_id} timed out after {duration:.1f}s")
        return {
            "task_id": task_id,
            "status": "timeout",
            "duration": duration
        }

    except Exception as e:
        duration = time.time() - start_time
        print(f"✗ {task_id} error: {e}")
        return {
            "task_id": task_id,
            "status": "error",
            "duration": duration,
            "error": str(e)
        }


def parse_result_file(task_id: str, results_dir: Path) -> Dict[str, Any] | None:
    """Parse the most recent result file for a task."""
    # Find the most recent result file for this task
    result_files = list(results_dir.glob(f"{task_id}_*.json"))
    if not result_files:
        return None

    # Get the most recent file
    latest_file = max(result_files, key=lambda p: p.stat().st_mtime)

    try:
        with open(latest_file, 'r') as f:
            return json.load(f)
    except Exception as e:
        print(f"Warning: Could not parse {latest_file}: {e}")
        return None


def generate_summary(
    task_results: List[Dict[str, Any]],
    parsed_results: Dict[str, Dict[str, Any]],
    total_duration: float
) -> Dict[str, Any]:
    """Generate summary statistics for a batch run.

    Args:
        task_results: Execution records, each with a ``status`` key.
        parsed_results: Parsed per-task result files keyed by task id.
        total_duration: Wall time of the whole batch, in seconds.

    Returns:
        A dict with:

        - status counts (``completed``, ``failed``, ``timeout``, ``error``);
        - ``successful_exploits``: the number of parsed results with a truthy
          ``success``;
        - ``success_rate``: that number as a percentage of all tasks,
          formatted as a string;
        - ``average_score``: the mean numeric ``score`` over successful
          exploits, formatted as a string;
        - the duration, both in raw seconds and formatted.
    """
    total_tasks = len(task_results)
    completed = sum(1 for r in task_results if r["status"] == "completed")
    failed = sum(1 for r in task_results if r["status"] == "failed")
    timeout = sum(1 for r in task_results if r["status"] == "timeout")
    error = sum(1 for r in task_results if r["status"] == "error")

    # Only well-formed (dict) results flagged as successful count as exploits.
    successful = [
        parsed for parsed in parsed_results.values()
        if isinstance(parsed, dict) and parsed.get("success")
    ]
    successful_exploits = len(successful)

    # Ignore missing/None/non-numeric scores instead of crashing: this runs
    # before the report is written, so an exception here would discard the
    # results of the entire batch. bool is excluded because it subclasses int.
    exploit_scores = [
        parsed["score"] for parsed in successful
        if isinstance(parsed.get("score"), (int, float))
        and not isinstance(parsed.get("score"), bool)
    ]

    success_rate = (successful_exploits / total_tasks * 100) if total_tasks > 0 else 0
    avg_score = (sum(exploit_scores) / len(exploit_scores)) if exploit_scores else 0

    return {
        "total_tasks": total_tasks,
        "completed": completed,
        "failed": failed,
        "timeout": timeout,
        "error": error,
        "successful_exploits": successful_exploits,
        "success_rate": f"{success_rate:.2f}%",
        "average_score": f"{avg_score:.2f}",
        "total_duration_seconds": total_duration,
        "total_duration_formatted": f"{total_duration / 3600:.2f} hours"
    }


def _has_fresh_result(task_id: str, results_dir: Path, since: float) -> bool:
    """Return True if a ``<task_id>_*.json`` file in *results_dir* was modified at or after *since*.

    *since* is a ``time.time()`` epoch timestamp, comparable with ``st_mtime``.
    """
    for path in results_dir.glob(f"{task_id}_*.json"):
        try:
            if path.stat().st_mtime >= since:
                return True
        except OSError:
            continue  # file vanished between glob() and stat()
    return False


def _print_summary(summary: Dict[str, Any], report_path: Path) -> None:
    """Print the human-readable batch summary and where the full report lives."""
    print("\n" + "="*80)
    print("BATCH EVALUATION SUMMARY")
    print("="*80)
    print(f"Total Tasks:         {summary['total_tasks']}")
    print(f"Completed:           {summary['completed']}")
    print(f"Failed:              {summary['failed']}")
    print(f"Timeout:             {summary['timeout']}")
    print(f"Error:               {summary['error']}")
    print(f"Successful Exploits: {summary['successful_exploits']}")
    print(f"Success Rate:        {summary['success_rate']}")
    print(f"Average Score:       {summary['average_score']}")
    print(f"Total Duration:      {summary['total_duration_formatted']}")
    print(f"\nDetailed results saved to: {report_path}")
    print("="*80)


def main():
    """Run every task, summarise the outcomes, and write a timestamped JSON report.

    The process exits with status 1 when ``OPENAI_API_KEY`` is unset or when
    no tasks are found.
    """
    # Check for API key
    if not os.environ.get("OPENAI_API_KEY"):
        print("Error: OPENAI_API_KEY environment variable not set")
        print("Please run: export OPENAI_API_KEY='your-api-key'")
        sys.exit(1)

    # All paths are resolved relative to this script, not the caller's CWD.
    project_root = Path(__file__).parent
    task_dir = project_root / "data" / "task"
    results_dir = project_root / "results"
    batch_results_dir = project_root / "batch_results"

    batch_results_dir.mkdir(exist_ok=True)

    tasks = get_all_tasks(task_dir)
    if not tasks:
        print("No tasks found!")
        sys.exit(1)

    print(f"Found {len(tasks)} tasks to evaluate")
    print(f"Results will be saved to: {batch_results_dir}")
    print()

    # Wall-clock start, used to tell this run's result files from older ones
    # (compared against file mtimes); the monotonic clock measures duration.
    batch_started_at = time.time()
    start = time.monotonic()

    task_results = []
    for i, task_id in enumerate(tasks, 1):
        print(f"\n[{i}/{len(tasks)}] Processing {task_id}...")
        task_results.append(run_single_task(task_id))

    total_duration = time.monotonic() - start

    # Parse only result files produced during this batch. parse_result_file
    # returns the newest file regardless of age. Without this check, a task
    # that timed out or failed now could be credited with a stale success
    # from an earlier run.
    print("\nParsing individual result files...")
    parsed_results = {}
    for task_id in tasks:
        if not _has_fresh_result(task_id, results_dir, batch_started_at):
            print(f"Warning: no result file written for {task_id} during this run; ignoring older results")
            continue
        parsed = parse_result_file(task_id, results_dir)
        if parsed:
            parsed_results[task_id] = parsed

    summary = generate_summary(task_results, parsed_results, total_duration)

    timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    batch_result_file = batch_results_dir / f"batch_results_{timestamp}.json"

    full_report = {
        "timestamp": timestamp,
        "summary": summary,
        "task_execution": task_results,
        "parsed_results": parsed_results
    }

    with open(batch_result_file, 'w', encoding="utf-8") as f:
        json.dump(full_report, f, indent=2)

    _print_summary(summary, batch_result_file)


if __name__ == "__main__":
    # Run the batch only when executed as a script, not when imported.
    main()
