"""
压力测试脚本（增强版）
功能：记录每个请求的详细时间戳，生成分析报告
"""
import aiohttp
import asyncio
import time
import sys
import csv
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

# User message sent with every request. It asks for a report of 5000+ Chinese
# characters, so each request is expected to yield a long completion (bounded
# by "max_tokens" in REQUEST_PAYLOAD). Note: the triple-quoted literal includes
# a leading and a trailing newline.
prompt = \
"""
你是一位有多年丰富经验的国资国企领域研究专家，请撰写一份国有控股混合所有制企业员工持股法律风险研究报告，字数要求在5000字以上，涵盖引言、相关研究综述、研究方法、研究结果与讨论、结论与展望等主要模块。
"""

# Global configuration
# Endpoint under test. The path and payload shape follow the OpenAI
# chat-completions schema; presumably an OpenAI-compatible local server — confirm.
API_URL = 'http://localhost:8000/v1/chat/completions'
# JSON body POSTed by fetch() for every request; identical across all requests.
REQUEST_PAYLOAD = {
    "model": "qwen",
    "messages": [
        # System message sets the assistant persona (a Chinese-language domain assistant).
        {"role": "system", "content": "你是国资国企数智化分析平台人工智能助手国资小研，是由项目组基于国产开源大语言模型结合国资国企领域知识训练开发的垂直领域大模型。你的任务是针对国资国企研究领域的问题和要求提供适当的答复和支持。"},
        {"role": "user", "content": prompt}
    ],
    # NOTE(review): "do_sample" is not an OpenAI chat-completions field;
    # presumably the serving backend accepts or ignores it — confirm.
    "do_sample": True,
    "temperature": 0.95,
    "top_p": 0.7,
    "n": 1,
    "max_tokens": 8192,
    # Non-streaming: fetch() reads the whole JSON response at once, including
    # usage.completion_tokens.
    "stream": False

}


async def fetch(session, url, request_id, payload=None):
    """Send one chat-completion request and record its timing.

    Args:
        session: aiohttp ``ClientSession`` (or compatible object) used to POST.
        url: Endpoint to POST to.
        request_id: Identifier copied into the returned record.
        payload: JSON body to send; defaults to the module-level
            ``REQUEST_PAYLOAD``.

    Returns:
        dict with keys:
          * ``request_id``
          * ``start_time`` and ``end_time`` (``time.perf_counter()`` values)
          * ``tokens``: ``usage.completion_tokens`` from the response, or 0 on failure
          * ``status``: ``'success'`` or ``'failed: <ExceptionType>: <message>'``

        ``end_time`` is recorded for failed requests as well, so every record
        has a finite duration.
    """
    if payload is None:
        payload = REQUEST_PAYLOAD

    timestamps = {
        'request_id': request_id,
        'start_time': time.perf_counter(),  # high-resolution monotonic clock
        'end_time': None,
        'tokens': 0,
        'status': 'success'
    }

    try:
        async with session.post(url, json=payload) as response:
            # Turns any 4xx/5xx into an exception carrying status and reason.
            response.raise_for_status()
            response_json = await response.json()
            timestamps['tokens'] = response_json['usage']['completion_tokens']
    except Exception as e:  # deliberate: a benchmark records failures rather than aborting
        # Include the type name: some exceptions (e.g. asyncio.TimeoutError)
        # have an empty str(), which would leave the failure undiagnosable.
        timestamps['status'] = f'failed: {type(e).__name__}: {e}'
    finally:
        # Always stamp the end so downstream duration math never sees None.
        timestamps['end_time'] = time.perf_counter()

    return timestamps


async def bound_fetch(sem, session, url, pbar, request_id):
    """Run a single fetch() while holding one permit of *sem*.

    The semaphore bounds how many requests are in flight at once. After the
    request returns, the progress bar is advanced by one. fetch() itself
    converts exceptions into a failed record, so failed requests also tick the bar.
    """
    await sem.acquire()
    try:
        outcome = await fetch(session, url, request_id)
        pbar.update(1)
    finally:
        sem.release()
    return outcome


async def run(load_url, max_concurrent, total_requests, *, timeout=3600):
    """Issue *total_requests* requests to *load_url*, at most *max_concurrent* at once.

    Args:
        load_url: Endpoint passed to every ``bound_fetch`` call.
        max_concurrent: Maximum number of in-flight requests (must be >= 1).
        total_requests: Number of requests to issue; ids run 1..total_requests.
        timeout: aiohttp ``ClientTimeout.total`` in seconds for each request.

    Returns:
        List of per-request result dicts, ordered by ``request_id``.

    Raises:
        ValueError: if *max_concurrent* < 1. A zero-permit semaphore would
            make every request wait forever.
    """
    if max_concurrent < 1:
        raise ValueError(f"max_concurrent must be >= 1, got {max_concurrent}")

    sem = asyncio.Semaphore(max_concurrent)
    client_timeout = aiohttp.ClientTimeout(total=timeout)
    # aiohttp's default connector allows only 100 simultaneous connections,
    # which would silently cap the real concurrency below max_concurrent.
    # Size the pool to match the semaphore instead. The session owns and
    # closes this connector.
    connector = aiohttp.TCPConnector(limit=max_concurrent)
    async with aiohttp.ClientSession(timeout=client_timeout, connector=connector) as session:
        with tqdm(total=total_requests, desc="Processing") as pbar:
            results = await asyncio.gather(*(
                bound_fetch(sem, session, load_url, pbar, req_id)
                for req_id in range(1, total_requests + 1)
            ))

    # gather() already preserves submission order; sorting keeps the
    # request_id ordering explicit for callers.
    return sorted(results, key=lambda x: x['request_id'])


def analyze_results(results):
    """Print a performance summary of *results* and draw timing charts.

    Args:
        results: Per-request dicts with keys ``request_id``, ``start_time``,
            ``end_time``, ``tokens`` and ``status``. ``end_time`` may be
            ``None`` for failed requests.

    Side effects:
        Prints the report to stdout. When at least one request succeeded, it
        also saves ``performance_analysis.png`` and calls ``plt.show()``.
        With no successful requests, stats and charts are skipped instead of
        raising.
    """
    success_requests = [r for r in results if r['status'] == 'success']
    response_times = [r['end_time'] - r['start_time'] for r in success_requests]
    total_tokens = sum(r['tokens'] for r in success_requests)

    # Wall-clock span runs from the earliest start to the latest recorded
    # finish. results is ordered by request_id, not completion time, so
    # results[-1] need not finish last and may have no end_time at all.
    finished = [r for r in results if r['end_time'] is not None]
    if success_requests:
        total_time = (max(r['end_time'] for r in finished)
                      - min(r['start_time'] for r in results))
    else:
        total_time = 0

    print("\n=== 性能分析报告 ===")
    print(f"总请求数: {len(results)}")
    print(f"成功请求: {len(success_requests)}")
    print(f"总耗时: {total_time:.2f}s")
    if total_time:
        print(f"Token生成速率: {total_tokens / total_time:.2f} tokens/s")
    else:
        print("Token生成速率: N/A")

    # numpy reductions raise on empty input, so bail out when nothing succeeded.
    if not response_times:
        print("\n无成功请求，跳过响应时间统计与图表生成。")
        return

    print("\n响应时间分布 (秒):")
    print(f"  Min   : {np.min(response_times):.2f}")
    print(f"  Avg   : {np.mean(response_times):.2f}")
    print(f"  P90   : {np.percentile(response_times, 90):.2f}")
    print(f"  P95   : {np.percentile(response_times, 95):.2f}")
    print(f"  Max   : {np.max(response_times):.2f}")

    _plot_results(results, response_times)


def _plot_results(results, response_times):
    """Draw the response-time histogram and per-request timeline, then save and show."""
    # NOTE(review): the Chinese titles/labels need a CJK-capable font configured
    # for matplotlib, otherwise they may render as empty boxes — confirm on the
    # machine running the test.
    t0 = min(r['start_time'] for r in results)
    # Only requests with a recorded end time can be drawn as bars.
    timeline = [r for r in results if r['end_time'] is not None]

    plt.figure(figsize=(12, 6))

    # Histogram of successful response times
    plt.subplot(1, 2, 1)
    plt.hist(response_times, bins=20, color='skyblue', edgecolor='black')
    plt.title('响应时间分布')
    plt.xlabel('时间 (秒)')
    plt.ylabel('请求数量')

    # Timeline: one bar per request, from its start to its end
    plt.subplot(1, 2, 2)
    plt.barh(
        y=[r['request_id'] for r in timeline],
        width=[r['end_time'] - r['start_time'] for r in timeline],
        left=[r['start_time'] - t0 for r in timeline],
        height=1,
        # Failed requests are drawn in red so they stand out.
        color=['tab:blue' if r['status'] == 'success' else 'tab:red'
               for r in timeline],
    )
    plt.title('请求时间线')
    plt.xlabel('时间轴 (秒)')
    plt.ylabel('请求序号')

    plt.tight_layout()
    plt.savefig('performance_analysis.png')
    plt.show()


def save_to_csv(results, path='request_details.csv'):
    """Write the per-request result dicts to a UTF-8 CSV file.

    Args:
        results: List of dicts that all share the key set of the first one.
            That key set becomes the CSV header.
        path: Output file; defaults to ``request_details.csv``.

    An empty *results* still truncates *path* to an empty file, so a stale
    file from a previous run is never left behind.
    """
    # Explicit UTF-8: status strings may carry non-ASCII error text.
    with open(path, 'w', newline='', encoding='utf-8') as f:
        if results:
            writer = csv.DictWriter(f, fieldnames=list(results[0].keys()))
            writer.writeheader()
            writer.writerows(results)
    print(f"\n详细结果已保存到 {path}")


if __name__ == '__main__':
    # Entry point: python bench.py <concurrency> <total_requests>
    usage = "Usage: python bench.py <并发数> <总请求数>"
    if len(sys.argv) != 3:
        print(usage)
        sys.exit(1)

    # Reject non-integer or non-positive arguments up front. Otherwise a
    # ValueError traceback, a hang (zero-permit semaphore) or an empty result
    # set would follow.
    try:
        CONCURRENCY = int(sys.argv[1])
        TOTAL_REQUESTS = int(sys.argv[2])
    except ValueError:
        print(usage)
        sys.exit(1)
    if CONCURRENCY < 1 or TOTAL_REQUESTS < 1:
        print("并发数和总请求数必须为正整数")
        print(usage)
        sys.exit(1)

    # Run the test; capture the start time for the metadata file.
    test_started = time.ctime()
    start = time.perf_counter()
    test_results = asyncio.run(run(API_URL, CONCURRENCY, TOTAL_REQUESTS))
    elapsed = time.perf_counter() - start

    # Write metadata before analysis. analyze_results() ends with plt.show(),
    # which may block until the chart window is closed, and a plotting error
    # would otherwise lose this file.
    with open('test_metadata.txt', 'w', encoding='utf-8') as f:
        f.write(f"测试时间: {test_started}\n")
        f.write(f"并发数: {CONCURRENCY}\n")
        f.write(f"总请求数: {TOTAL_REQUESTS}\n")
        f.write(f"客户端总耗时: {elapsed:.2f}秒\n")

    # Process results
    save_to_csv(test_results)
    analyze_results(test_results)