# Written by Jeremy Karst 2025
# This software is licensed under the MIT License with Additional Terms.
# See LICENSE.md for more information.

# Built-in Dependencies
import argparse
import random
import statistics
import threading
import time
from concurrent.futures import ThreadPoolExecutor

# External Dependencies
import requests
from rich.console import Console
from rich.table import Table
from rich.progress import Progress


def make_request(base_url, session):
    """Run one full signature flow against the server and time each step.

    The flow fetches the font list, registers a signature under a random
    name and font, verifies the returned hash, and downloads the signature
    PNG. The font fetch is setup for the register call and is not included
    in any of the reported timings.

    Args:
        base_url: Root URL of the server, e.g. ``"http://localhost:8080"``.
        session: ``requests.Session`` used for every HTTP call.

    Returns:
        A dict with ``register_time``, ``verify_time`` and ``image_time``
        (elapsed seconds for each step) and ``register_status``,
        ``verify_status`` and ``image_status`` (HTTP status codes), or
        ``None`` if any step failed.
    """
    try:
        # Setup (untimed): pick a font the server actually offers.
        fonts_response = session.get(f"{base_url}/fonts")
        fonts_response.raise_for_status()
        available_fonts = list(fonts_response.json().keys())
        if not available_fonts:
            # random.choice([]) would raise IndexError, which is not a
            # RequestException and would abort the whole worker thread;
            # report it as a failed flow instead.
            print("Request failed: server returned no fonts")
            return None

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so system clock adjustments cannot distort the measured durations.
        start_time = time.perf_counter()

        # 1. Register a new signature
        name = f"Test User {random.randint(100000000, 999999999)}"
        register_data = {
            'name': name,
            'font': random.choice(available_fonts),
            'timezone': '0',
            'invert': 'false'
        }
        register_response = session.get(f"{base_url}/register", params=register_data)
        register_response.raise_for_status()
        signature_hash = register_response.text
        register_time = time.perf_counter()

        # 2. Verify the signature
        verify_response = session.get(f"{base_url}/verify/{signature_hash}")
        verify_response.raise_for_status()
        verify_time = time.perf_counter()

        # 3. Get the signature image
        image_response = session.get(f"{base_url}/images/{signature_hash}.png")
        image_response.raise_for_status()
        image_time = time.perf_counter()

        return {
            'register_time': register_time - start_time,
            'register_status': register_response.status_code,
            'verify_time': verify_time - register_time,
            'verify_status': verify_response.status_code,
            'image_time': image_time - verify_time,
            'image_status': image_response.status_code
        }
    except requests.RequestException as e:
        print(f"Request failed: {e}")
        return None

def run_benchmark(base_url, num_threads, requests_per_thread, progress=None):
    """Run ``requests_per_thread`` signature flows on each of ``num_threads`` threads.

    Args:
        base_url: Root URL of the server under test.
        num_threads: Number of concurrent worker threads.
        requests_per_thread: Number of full flows (see ``make_request``) each
            worker performs sequentially.
        progress: Optional ``rich.progress.Progress`` to report into. When
            ``None`` (the default), a private progress bar is created and
            shown for the duration of the run.

    Returns:
        A dict of success/failure counts and timing statistics in seconds, or
        ``None`` if no flow succeeded. ``total_time`` is the sum of all
        per-step latencies, not the wall-clock duration of the run.
    """
    register_times = []
    verify_times = []
    image_times = []
    errors = 0
    # Guards the collectors above: `errors += 1` is a read-modify-write that
    # can lose updates across threads, and the three lists must stay aligned.
    results_lock = threading.Lock()

    def worker(tracker, task_id):
        """Perform this thread's flows and record their outcomes."""
        nonlocal errors
        # One session per thread: requests.Session is not documented as
        # thread-safe, and a shared session's connection pool (10 per host by
        # default) discards connections once more threads than that are busy.
        # The `with` also closes the session's connections when done.
        with requests.Session() as session:
            for _ in range(requests_per_thread):
                result = make_request(base_url, session)
                with results_lock:
                    if result is None:
                        errors += 1
                    else:
                        register_times.append(result['register_time'])
                        verify_times.append(result['verify_time'])
                        image_times.append(result['image_time'])
                tracker.update(task_id, advance=1)

    def run_workers(tracker):
        """Start all workers reporting into ``tracker`` and wait for them."""
        task_id = tracker.add_task(
            f"[cyan]Running {num_threads} threads...",
            total=num_threads * requests_per_thread,
        )
        with ThreadPoolExecutor(max_workers=num_threads) as executor:
            futures = [executor.submit(worker, tracker, task_id) for _ in range(num_threads)]
            # result() re-raises any unexpected exception from a worker.
            for future in futures:
                future.result()

    # Honour a caller-supplied progress display; otherwise own one.
    if progress is None:
        with Progress() as own_progress:
            run_workers(own_progress)
    else:
        run_workers(progress)

    if not image_times:
        return None

    all_times = register_times + verify_times + image_times
    return {
        'successful_register_requests': len(register_times),
        'successful_verify_requests': len(verify_times),
        'successful_image_requests': len(image_times),
        'failed_requests': errors,
        'median_register_response_time': statistics.median(register_times),
        'median_verify_response_time': statistics.median(verify_times),
        'median_image_response_time': statistics.median(image_times),
        'min_response_time': min(all_times),
        'max_response_time': max(all_times),
        'total_time': sum(all_times)
    }

def main():
    """Benchmark the server at increasing thread counts and print a summary table.

    Thread counts start at ``--start-threads`` and double each step, with the
    final step clamped to ``--max-threads``. Each table row reports the
    register / verify / image success rates and median latencies, and the
    number of successful HTTP requests per second of wall-clock time.
    """
    parser = argparse.ArgumentParser(description='Benchmark web server performance')
    parser.add_argument('--url', default='http://localhost:8080', help='Base URL of the server')
    parser.add_argument('--start-threads', type=int, default=1, help='Starting number of threads')
    parser.add_argument('--max-threads', type=int, default=100, help='Maximum number of threads')
    parser.add_argument('--requests-per-thread', type=int, default=20, help='Number of requests per thread')

    args = parser.parse_args()

    # Validate with parser.error (prints usage and exits) rather than assert,
    # which is stripped when Python runs with -O. A max below start would
    # otherwise be clamped to a possibly non-positive thread count.
    if args.start_threads <= 0:
        parser.error("Starting threads must be greater than 0")
    if args.max_threads < args.start_threads:
        parser.error("Maximum threads must be at least the starting number of threads")
    if args.requests_per_thread <= 0:
        parser.error("Requests per thread must be greater than 0")

    console = Console()
    results_table = Table(show_header=True, header_style="bold cyan")
    results_table.add_column("Threads")
    results_table.add_column("Total Requests")
    results_table.add_column("Success Rate")
    results_table.add_column("Median Response (s)")
    results_table.add_column("Req/sec")

    console.print(f"\n[bold green]Starting benchmark against {args.url}[/bold green]")
    console.print("[yellow]Note: Each request includes register, verify, and image generation[/yellow]\n")

    # Test thread counts from start to max, doubling every step; the last
    # step is clamped so that max_threads itself is always tested.
    thread_counts = [args.start_threads]
    while thread_counts[-1] < args.max_threads:
        thread_counts.append(thread_counts[-1] * 2)
    if thread_counts[-1] > args.max_threads:
        thread_counts[-1] = args.max_threads

    for num_threads in thread_counts:
        console.print(f"[yellow]Testing with {num_threads} threads...[/yellow]")

        flows = num_threads * args.requests_per_thread
        # Each flow issues three timed requests (register, verify, image);
        # report the same total whether or not the run succeeded.
        total_requests = flows * 3

        # perf_counter is monotonic, so clock adjustments cannot skew Req/sec.
        start_time = time.perf_counter()
        result = run_benchmark(args.url, num_threads, args.requests_per_thread)
        elapsed = time.perf_counter() - start_time

        if result:
            register_success_rate = result['successful_register_requests'] / flows * 100
            verify_success_rate = result['successful_verify_requests'] / flows * 100
            image_success_rate = result['successful_image_requests'] / flows * 100
            successful_requests = (
                result['successful_register_requests']
                + result['successful_verify_requests']
                + result['successful_image_requests']
            )
            requests_per_second = successful_requests / elapsed
            results_table.add_row(
                str(num_threads),
                str(total_requests),
                f"{register_success_rate:.1f}% / {verify_success_rate:.1f}% / {image_success_rate:.1f}%",
                f"{result['median_register_response_time']:.3f} / {result['median_verify_response_time']:.3f} / {result['median_image_response_time']:.3f}",
                f"{requests_per_second:.1f}"
            )
        else:
            results_table.add_row(
                str(num_threads),
                str(total_requests),
                "0%",
                "N/A",
                "N/A"
            )

    console.print("\n[bold]Benchmark Results:[/bold]")
    console.print(results_table)

if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()