#!/usr/bin/env python3
"""Generate backend test report from Jest JSON + test file + API doc.

This script is intentionally lightweight and heuristic-based. It extracts:
- Test cases from `it()` blocks
- API usage from request().get/post/put/patch/delete
- Inputs from send()/query()
- Outputs from expect()/expect(status)
- Coverage from docs/modules/<module>/07-api.md
"""

import argparse
import datetime as dt
import json
import os
import re
import textwrap
from typing import Dict, List, Optional, Tuple

# HTTP verb method names (lowercase) that extract_api_calls() treats as API
# calls when it sees `.<verb>(` followed by a string literal.
API_METHODS = ("get", "post", "put", "patch", "delete")


def read_text(path: str) -> str:
    """Return the entire contents of the file at *path*, decoded as UTF-8."""
    with open(path, "r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents


def now_timestamp() -> str:
    """Return the current local time formatted as ``YYYY-MM-DD HH:MM:SS``."""
    moment = dt.datetime.now()
    return moment.strftime("%Y-%m-%d %H:%M:%S")


def today_date() -> str:
    """Return the current local date formatted as ``YYYY-MM-DD``."""
    moment = dt.datetime.now()
    return moment.strftime("%Y-%m-%d")


def find_matching_brace(text: str, start: int) -> int:
    """Return the index of the ``}`` that balances the first ``{`` at/after *start*.

    Scanning begins at *start* and tracks brace depth. Braces inside
    single-quoted, double-quoted or backtick-quoted strings are ignored, as is
    everything inside ``//`` line comments and ``/* ... */`` block comments, so
    an apostrophe or brace in a comment cannot derail the scan.

    Args:
        text: Source text to scan.
        start: Index at which scanning starts (normally the opening brace).

    Returns:
        The index of the closing brace that brings the depth back to zero, or
        ``-1`` if no such brace exists.
    """
    size = len(text)
    depth = 0
    in_string = None
    i = start
    while i < size:
        ch = text[i]
        if in_string:
            if ch == "\\":
                # Skip the escaped character so an escaped quote does not
                # terminate the string.
                i += 2
                continue
            if ch == in_string:
                in_string = None
            i += 1
            continue
        if ch == "/" and i + 1 < size and text[i + 1] in ("/", "*"):
            if text[i + 1] == "/":
                # Line comment: resume after the end of the line.
                end = text.find("\n", i + 2)
                i = size if end == -1 else end + 1
            else:
                # Block comment: resume after the closing "*/".
                end = text.find("*/", i + 2)
                i = size if end == -1 else end + 2
            continue
        if ch in ("'", '"', "`"):
            in_string = ch
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return i
        i += 1
    return -1


def parse_string_literal(text: str, start: int) -> Tuple[str, int]:
    """Read the quoted literal whose opening quote is at ``text[start]``.

    Backslash escapes are copied through verbatim (the backslash and the
    character after it are both kept), so an escaped quote does not end the
    literal.

    Returns:
        ``(content, next_index)`` where *content* is the text between the
        quotes and *next_index* is the position just past the closing quote.
        If the literal is never closed, *content* is everything up to the end
        of *text* and *next_index* is ``len(text)``.
    """
    delimiter = text[start]
    size = len(text)
    pieces = []
    pos = start + 1
    while pos < size:
        current = text[pos]
        if current == "\\" and pos + 1 < size:
            # Keep the escape pair as-is and step over both characters.
            pieces.append(text[pos:pos + 2])
            pos += 2
        elif current == delimiter:
            return "".join(pieces), pos + 1
        else:
            pieces.append(current)
            pos += 1
    return "".join(pieces), pos


def parse_paren_content(text: str, start: int) -> Tuple[str, int]:
    """Return the text inside the parenthesised group opening at ``text[start]``.

    Nested parentheses are balanced. Parentheses and quotes that appear inside
    string literals (single, double or backtick quoted), ``//`` line comments
    or ``/* ... */`` block comments are not interpreted, so a stray apostrophe
    in a comment cannot break the match. Comment text is still part of the
    returned content.

    Args:
        text: Source text to scan.
        start: Index expected to hold the opening ``(``.

    Returns:
        ``(content, next_index)``. For a balanced group, *content* is the
        stripped text between the outer parentheses and *next_index* is the
        position just past the closing ``)``. If ``text[start]`` is not ``(``,
        returns ``("", start)``. If the group is never closed, returns the
        stripped remainder of *text* from *start* (opening parenthesis
        included) and ``len(text)``.
    """
    if text[start] != "(":
        return "", start
    size = len(text)
    depth = 0
    in_string = None
    i = start
    while i < size:
        ch = text[i]
        if in_string:
            if ch == "\\":
                # Step over the escaped character so it cannot close the string.
                i += 2
                continue
            if ch == in_string:
                in_string = None
            i += 1
            continue
        if ch == "/" and i + 1 < size and text[i + 1] in ("/", "*"):
            if text[i + 1] == "/":
                # Line comment: resume after the end of the line.
                end = text.find("\n", i + 2)
                i = size if end == -1 else end + 1
            else:
                # Block comment: resume after the closing "*/".
                end = text.find("*/", i + 2)
                i = size if end == -1 else end + 2
            continue
        if ch in ("'", '"', "`"):
            in_string = ch
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return text[start + 1:i].strip(), i + 1
        i += 1
    # Unbalanced group: hand back whatever was scanned.
    return text[start:].strip(), size


def extract_api_calls(text: str) -> List[Tuple[int, str, str]]:
    """Find ``.<verb>("...")`` calls whose first argument is a string literal.

    The verbs searched for are those listed in ``API_METHODS``. Calls whose
    first argument does not start with a quote character are skipped.

    Returns:
        A list of ``(position, METHOD, literal)`` tuples in order of
        appearance, where *position* is the index of the match in *text*,
        *METHOD* is the upper-cased verb and *literal* is the raw content of
        the string argument.
    """
    verbs = "|".join(API_METHODS)
    call_re = re.compile(r"\.\s*(" + verbs + r")\s*\(")
    quote_chars = ("'", '"', "`")
    size = len(text)
    found = []
    for hit in call_re.finditer(text):
        # Skip whitespace between "(" and the first argument.
        cursor = hit.end()
        while cursor < size and text[cursor].isspace():
            cursor += 1
        if cursor < size and text[cursor] in quote_chars:
            literal, _end = parse_string_literal(text, cursor)
            found.append((hit.start(), hit.group(1).upper(), literal))
    return found


def extract_call_args(text: str, call_name: str) -> List[str]:
    """Collect the raw argument text of every ``.<call_name>(...)`` call.

    Returns:
        The stripped argument text of each call, in order of appearance.
        Calls with empty argument text are omitted.
    """
    call_re = re.compile(r"\.%s\s*\(" % re.escape(call_name))
    collected = []
    for hit in call_re.finditer(text):
        # The match ends just past "(", so step back onto the parenthesis.
        inner, _next = parse_paren_content(text, hit.end() - 1)
        inner = inner.strip()
        if inner:
            collected.append(inner)
    return collected


def extract_expect_lines(text: str) -> List[str]:
    """Return every line of *text* that contains ``expect(``, stripped of
    surrounding whitespace, in original order."""
    return [raw.strip() for raw in text.splitlines() if "expect(" in raw]


def extract_expect_status(text: str) -> List[str]:
    """Return the three-digit codes passed to ``.expect(NNN)`` calls, as
    strings, in order of appearance."""
    # With a single capture group, findall yields the captured text directly.
    return re.findall(r"\.expect\((\d{3})\)", text)


def extract_error_codes(text: str) -> List[str]:
    """Return the quoted values asserted via ``...error.code).toBe('<code>')``,
    in order of appearance."""
    code_re = re.compile(r"error\.code\)\.toBe\(['\"]([^'\"]+)['\"]\)")
    return [hit.group(1) for hit in code_re.finditer(text)]


def parse_functions(text: str) -> Dict[str, List[Tuple[str, str]]]:
    """Map each ``function name(...) { ... }`` declaration to the API calls in its body.

    Only declarations using the ``function`` keyword (optionally ``async``)
    are recognised. Declarations whose body contains no detected API call, or
    whose braces cannot be matched, are left out.

    Returns:
        ``{function_name: [(METHOD, path), ...]}``.
    """
    helpers = {}
    for decl in re.finditer(r"(?:async\s+)?function\s+(\w+)\s*\(", text):
        # The match ends just past "(", so step back onto the parenthesis.
        _params, after_params = parse_paren_content(text, decl.end() - 1)
        body_open = text.find("{", after_params)
        if body_open == -1:
            continue
        body_close = find_matching_brace(text, body_open)
        if body_close == -1:
            continue
        body = text[body_open:body_close + 1]
        endpoints = [(method, path) for _pos, method, path in extract_api_calls(body)]
        if endpoints:
            helpers[decl.group(1)] = endpoints
    return helpers


def parse_test_cases(text: str, helper_api: Dict[str, List[Tuple[str, str]]]) -> List[Dict[str, object]]:
    """Extract one record per ``it(...)`` / ``test(...)`` block in *text*.

    A block is recognised when the call's first argument is a string literal
    (the title) and is followed by ``=>`` and a brace-delimited body. Blocks
    whose body braces cannot be matched are skipped.

    Args:
        text: Contents of the test file.
        helper_api: Mapping of helper function name to the ``(METHOD, path)``
            pairs that helper calls, as produced by ``parse_functions``.

    Returns:
        A list of dicts with keys ``title``, ``block`` (body text including
        braces), ``apis`` (``(METHOD, path, source)`` tuples ordered by
        position in the body, where *source* is ``"direct"`` or
        ``"helper:<name>"``), ``inputs`` (``send`` arguments followed by
        ``query`` arguments), ``outputs`` (lines containing ``expect(``),
        ``status_codes`` and ``error_codes``.
    """
    quote_chars = ("'", '"', "`")
    size = len(text)
    # Compile each helper-call pattern once instead of once per test block.
    helper_patterns = [
        (name, endpoints, re.compile(r"\b%s\s*\(" % re.escape(name)))
        for name, endpoints in helper_api.items()
    ]

    results = []
    for match in re.finditer(r"\b(it|test)\s*\(", text):
        idx = match.end()
        while idx < size and text[idx].isspace():
            idx += 1
        if idx >= size or text[idx] not in quote_chars:
            continue
        title, title_end = parse_string_literal(text, idx)
        arrow_idx = text.find("=>", title_end)
        if arrow_idx == -1:
            continue
        brace_start = text.find("{", arrow_idx)
        if brace_start == -1:
            continue
        brace_end = find_matching_brace(text, brace_start)
        if brace_end == -1:
            continue
        block = text[brace_start:brace_end + 1]

        api_items = [
            (pos, method, path, "direct")
            for pos, method, path in extract_api_calls(block)
        ]
        for name, endpoints, helper_re in helper_patterns:
            call = helper_re.search(block)
            if call is None:
                continue
            # Order by where the helper is actually called. A plain substring
            # search could land inside a longer identifier that merely
            # contains the helper's name.
            source = "helper:%s" % name
            api_items.extend(
                (call.start(), method, path, source) for method, path in endpoints
            )

        api_items.sort(key=lambda item: item[0])
        api_list = [(method, path, source) for _pos, method, path, source in api_items]

        inputs = []
        inputs.extend(extract_call_args(block, "send"))
        inputs.extend(extract_call_args(block, "query"))

        results.append({
            "title": title,
            "block": block,
            "apis": api_list,
            "inputs": inputs,
            "outputs": extract_expect_lines(block),
            "status_codes": extract_expect_status(block),
            "error_codes": extract_error_codes(block),
        })
    return results


def parse_api_doc(doc_text: str, base_url: str) -> List[Tuple[str, str]]:
    """Extract documented endpoints from Markdown table rows.

    A row is recognised when its first cell is an upper-case HTTP method
    (GET, POST, PUT, PATCH or DELETE) and its second cell is a path starting
    with ``/``. Either cell may be wrapped in Markdown inline-code backticks;
    the backticks are not part of the extracted value.

    Args:
        doc_text: Contents of the API document.
        base_url: Prefix prepended to every path; a trailing ``/`` is dropped
            before joining.

    Returns:
        ``(METHOD, full_path)`` tuples in document order.
    """
    row_re = re.compile(r"\|\s*`?(GET|POST|PUT|PATCH|DELETE)`?\s*\|\s*([^|]+)\|")
    prefix = base_url.rstrip("/")
    endpoints = []
    for raw_line in doc_text.splitlines():
        line = raw_line.strip()
        if not line.startswith("|"):
            continue
        match = row_re.match(line)
        if not match:
            continue
        # Drop inline-code formatting so `/path` is treated the same as /path.
        path = match.group(2).strip().strip("`").strip()
        if not path.startswith("/"):
            continue
        endpoints.append((match.group(1), prefix + path))
    return endpoints


def build_doc_regex(path: str) -> re.Pattern:
    """Compile an anchored regex for a documented path.

    The path is regex-escaped, then every ``:name`` placeholder is replaced
    by a pattern matching one non-empty path segment.
    """
    literal = re.escape(path)
    with_wildcards = re.sub(r":([^/]+)", r"[^/]+", literal)
    return re.compile("^" + with_wildcards + "$")


def normalize_used_path(path: str) -> str:
    """Reduce a path taken from test code to a form comparable with the docs.

    Everything from the first ``?`` onward is dropped, and each ``${...}``
    template placeholder is replaced by the literal segment ``DUMMY``.
    """
    without_query, _sep, _query = path.partition("?")
    return re.sub(r"\$\{[^}]+\}", "DUMMY", without_query)


def match_doc_endpoint(method: str, used_path: str, doc_endpoints: List[Tuple[str, str]]) -> Optional[Tuple[str, str]]:
    """Find the documented endpoint that a used ``(method, path)`` refers to.

    Only documented endpoints with the same method are considered. Exact
    matches against placeholder-free documented paths win over pattern
    matches, so a fixed path such as ``/my`` is not captured by ``/:id``.

    Returns:
        The matching ``(method, path)`` pair from the docs, or ``None``.
    """
    target = normalize_used_path(used_path)
    candidates = [
        (doc_method, doc_path)
        for doc_method, doc_path in doc_endpoints
        if doc_method == method
    ]

    # Pass 1: literal equality with documented paths that have no placeholder.
    for doc_method, doc_path in candidates:
        if ":" not in doc_path and target == doc_path:
            return doc_method, doc_path

    # Pass 2: pattern match, letting ":name" stand for any single segment.
    for doc_method, doc_path in candidates:
        if build_doc_regex(doc_path).match(target):
            return doc_method, doc_path

    return None


def unique_list(items: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Return *items* with duplicates removed, keeping first-seen order."""
    # dict keys are unique and preserve insertion order.
    return list(dict.fromkeys(items))


def format_block(lines: List[str]) -> str:
    """Wrap *lines* in a fenced ``text`` code block, one entry per line.

    Returns the placeholder "无" when *lines* is empty.
    """
    if lines:
        return "```text\n%s\n```" % "\n".join(lines)
    return "无"


def format_inputs(inputs: List[str]) -> str:
    """Render request payloads as one fenced ``text`` block.

    Each payload is stripped and payloads are separated by a blank line.
    Returns the placeholder "无" when *inputs* is empty.
    """
    if inputs:
        body = "\n\n".join(payload.strip() for payload in inputs)
        return "```text\n%s\n```" % body
    return "无"


def format_api_list(apis: List[Tuple[str, str, str]]) -> str:
    """Render ``(method, path, source)`` tuples as a Markdown bullet list.

    Returns the placeholder "无" when *apis* is empty.
    """
    if apis:
        return "\n".join(f"- {method} {path} ({source})" for method, path, source in apis)
    return "无"


def format_outputs(outputs: List[str], status_codes: List[str], error_codes: List[str]) -> str:
    """Summarise a test case's expectations as a fenced ``text`` block.

    The block lists, in this order and only when non-empty: the sorted unique
    status codes, the sorted unique error codes, and the raw assertion lines.
    Returns the placeholder "无" when all three inputs are empty.
    """
    sections = []
    if status_codes:
        unique_statuses = sorted(set(status_codes))
        sections.append("状态码: %s" % ", ".join(unique_statuses))
    if error_codes:
        unique_errors = sorted(set(error_codes))
        sections.append("错误码: %s" % ", ".join(unique_errors))
    if outputs:
        sections.append("断言:\n" + "\n".join(outputs))
    return format_block(sections) if sections else "无"


def load_jest_results(path: Optional[str]) -> Dict[str, Dict[str, str]]:
    """Load per-test results from a Jest JSON report, keyed by test title.

    Reads ``testResults[].assertionResults[]`` from the JSON document. Entries
    without a title are skipped; if several entries share a title, the last
    one wins.

    Returns:
        ``{title: {"status": ..., "failure": ...}}`` where *failure* is the
        entry's failure messages joined by newlines (empty when there are
        none). An empty dict is returned when *path* is falsy or does not
        exist.
    """
    if not path or not os.path.exists(path):
        return {}
    payload = json.loads(read_text(path))
    by_title = {}
    for suite in payload.get("testResults", []):
        for assertion in suite.get("assertionResults", []):
            title = assertion.get("title")
            if not title:
                continue
            messages = assertion.get("failureMessages")
            by_title[title] = {
                "status": assertion.get("status", "unknown"),
                "failure": "\n".join(messages) if messages else "",
            }
    return by_title


def render_report(
    module: str,
    report_type: str,
    command: str,
    env_name: str,
    branch: str,
    commit: str,
    test_cases: List[Dict[str, object]],
    doc_endpoints: List[Tuple[str, str]],
    jest_results: Dict[str, Dict[str, str]],
) -> str:
    """Render the Markdown test report.

    Args:
        module: Module name; used in the title and the report ID.
        report_type: Report/test type label shown in the summary.
        command: Command line shown in the summary table.
        env_name: Environment label shown in the summary table.
        branch: Git branch shown in the summary table.
        commit: Git commit shown in the summary table.
        test_cases: Records as produced by ``parse_test_cases``.
        doc_endpoints: Documented ``(METHOD, path)`` tuples used as the
            coverage baseline.
        jest_results: ``{title: {"status", "failure"}}`` as produced by
            ``load_jest_results``; test cases are looked up by title, and a
            title with no entry is reported as unknown.

    Returns:
        The complete report as a single newline-joined string.
    """
    # Read the clock once so the report ID and the timestamps cannot straddle
    # a second or midnight boundary.
    now = dt.datetime.now()
    timestamp = now.strftime("%Y-%m-%d %H:%M:%S")
    report_id = f"{module}-{now.strftime('%Y-%m-%d')}-{now.strftime('%H%M%S')}"

    used_endpoints = unique_list([
        (method, path)
        for case in test_cases
        for method, path, _source in case["apis"]
    ])

    # Resolve each used endpoint to its documented counterpart exactly once;
    # a documented endpoint is covered when some used endpoint resolves to it.
    resolved = {
        match_doc_endpoint(method, path, doc_endpoints)
        for method, path in used_endpoints
    }
    covered_doc = [endpoint for endpoint in doc_endpoints if endpoint in resolved]
    uncovered_doc = [endpoint for endpoint in doc_endpoints if endpoint not in resolved]

    total_count = len(doc_endpoints)
    covered_count = len(covered_doc)
    coverage_pct = (covered_count / total_count * 100) if total_count else 0

    # Any status other than "passed"/"failed" (including a missing Jest entry)
    # is counted as unknown.
    statuses = [
        jest_results.get(case["title"], {}).get("status", "unknown")
        for case in test_cases
    ]
    passed = statuses.count("passed")
    failed = statuses.count("failed")
    unknown = len(statuses) - passed - failed

    pass_rate = (passed / len(test_cases) * 100) if test_cases else 0

    lines = []
    lines.append(f"# 测试报告 - {module}")
    lines.append("")
    # The two trailing spaces force a Markdown line break inside the quote.
    lines.append(f"> **报告ID**: `{report_id}`  ")
    lines.append(f"> **生成时间**: {timestamp}  ")
    lines.append(f"> **报告类型**: {report_type}")
    lines.append("")
    lines.append("---")
    lines.append("")
    lines.append("## 执行摘要")
    lines.append("")
    lines.append("### 执行摘要")
    lines.append("")
    lines.append("| 字段 | 内容 |")
    lines.append("|------|------|")
    lines.append(f"| 报告ID | `{report_id}` |")
    lines.append(f"| 环境 | {env_name} |")
    lines.append(f"| 分支 | `{branch}` |")
    lines.append(f"| Git Commit | `{commit}` |")
    lines.append(f"| 通过率 | {pass_rate:.2f}% |")
    lines.append(f"| 失败用例 | {failed} |")
    lines.append(f"| 未知状态用例 | {unknown} |")
    lines.append(f"| 执行命令 | `{command}` |")
    lines.append(f"| 测试类型 | {report_type} |")
    lines.append("")
    lines.append("### 覆盖度摘要")
    lines.append("")
    lines.append("| 字段 | 内容 |")
    lines.append("|------|------|")
    lines.append(f"| 文档 API 总数 | {total_count} |")
    lines.append(f"| 覆盖 API 数量 | {covered_count} |")
    lines.append(f"| 覆盖率 | {coverage_pct:.2f}% |")
    lines.append("| 覆盖口径 | 以 07-api.md 的接口清单为基准，匹配测试中出现的 HTTP 方法 + 路径 |")
    lines.append("")
    lines.append("### 覆盖清单")
    lines.append("")
    lines.append("**已覆盖 API**")
    if covered_doc:
        lines.extend(f"- {method} {path}" for method, path in covered_doc)
    else:
        lines.append("- 无")
    lines.append("")
    lines.append("**未覆盖 API**")
    if uncovered_doc:
        lines.extend(f"- {method} {path}" for method, path in uncovered_doc)
    else:
        lines.append("- 无")
    lines.append("")
    lines.append("---")
    lines.append("")
    lines.append("## 执行详情")
    lines.append("")
    lines.append("### 执行上下文")
    lines.append("")
    lines.append("| 项目 | 信息 |")
    lines.append("|------|------|")
    lines.append(f"| 执行时间 | {timestamp} |")
    lines.append("| 执行器 | Jest + Supertest |")
    lines.append("| 环境信息 | N/A（建议补充 OS / Node / DB） |")
    lines.append("")
    lines.append("### 契约一致性检查")
    lines.append("")
    lines.append("- 检查口径：基于测试断言中的状态码与错误码，未做字段级响应结构自动比对。")
    lines.append("- 如需字段级比对，请在测试中增加字段断言，或补充专用对比脚本。")
    lines.append("")
    lines.append("### 用例执行记录")
    lines.append("")
    for case, status in zip(test_cases, statuses):
        title = case["title"]
        failure = jest_results.get(title, {}).get("failure", "")
        if status == "passed":
            status_label = "✅ 通过"
        elif status == "failed":
            status_label = "❌ 失败"
        else:
            status_label = "⚪ 未知"

        lines.append(f"#### {title}")
        lines.append("")
        lines.append(f"- 结果: {status_label}")
        if failure:
            lines.append("- 失败信息:")
            lines.append(format_block([failure]))
        lines.append("- 输入:")
        lines.append(format_inputs(case["inputs"]))
        lines.append("- 使用的 API 清单:")
        lines.append(format_api_list(case["apis"]))
        lines.append("- 输出/断言:")
        lines.append(format_outputs(case["outputs"], case["status_codes"], case["error_codes"]))
        lines.append("- 执行流程:")
        lines.append("  - 基于测试代码中的 API 调用顺序与断言执行")
        lines.append("")

    return "\n".join(lines)


def main() -> None:
    """Command-line entry point: parse arguments, build the report, write it.

    Reads the test file and API doc given on the command line, optionally
    merges Jest JSON results, renders the Markdown report and writes it to
    ``--output`` (default: ``testing/reports/<module>-<date>-<type>-report.md``).
    The output directory is created if it does not exist.
    """
    parser = argparse.ArgumentParser(description="Generate backend test report")
    parser.add_argument("--module", required=True)
    parser.add_argument("--type", required=True, choices=["unit", "integration"])
    parser.add_argument("--test-file", required=True)
    parser.add_argument("--api-doc", required=True)
    parser.add_argument("--jest-json")
    parser.add_argument("--output")
    parser.add_argument("--command", default="")
    parser.add_argument("--env", default="test")
    parser.add_argument("--branch", default="unknown")
    parser.add_argument("--commit", default="unknown")
    parser.add_argument("--base-url", default="/api/v1/performance")

    args = parser.parse_args()

    test_text = read_text(args.test_file)
    api_doc_text = read_text(args.api_doc)

    helper_api = parse_functions(test_text)
    test_cases = parse_test_cases(test_text, helper_api)
    doc_endpoints = parse_api_doc(api_doc_text, args.base_url)
    jest_results = load_jest_results(args.jest_json)

    report = render_report(
        module=args.module,
        report_type=args.type,
        command=args.command,
        env_name=args.env,
        branch=args.branch,
        commit=args.commit,
        test_cases=test_cases,
        doc_endpoints=doc_endpoints,
        jest_results=jest_results,
    )

    if args.output:
        output_path = args.output
    else:
        output_name = f"{args.module}-{today_date()}-{args.type}-report.md"
        output_path = os.path.join("testing", "reports", output_name)

    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    output_dir = os.path.dirname(output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(report)

    print(f"Report generated: {output_path}")


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
