/**
 * assertAccess 静态契约检查
 *
 * 规则对齐（09-iam-security.md §7.1 + §8 禁止事项 #10）：
 * 所有 Service 方法里的 prisma.<model>.update / delete / updateMany / deleteMany 调用，
 * 必须在同一函数体内看到 assertAccess 调用，或整方法带 @SkipAssertAccess('理由') 装饰器。
 *
 * 仅检查本次新增/修改的文件里的新违规；存量文件一律放过。
 *
 * 用法：
 *   npx ts-node --transpile-only --project testing/scripts/tsconfig.json testing/scripts/assert-access-check.ts --staged
 *   npx ts-node --transpile-only --project testing/scripts/tsconfig.json testing/scripts/assert-access-check.ts --base origin/develop
 */

import { execFileSync, execSync } from 'child_process';
import * as fs from 'fs';
import * as path from 'path';

// Directory two levels above this script's own directory; used as the cwd for
// git commands and as the base for resolving the relative paths git reports.
const ROOT_DIR = path.resolve(__dirname, '..', '..');

// Only the Service layer is scanned (business logic and persistence live there).
const SERVICE_PATH_PATTERN = /backend\/src\/.*\.service\.ts$/;

/** One reported finding: a Prisma write call in a method that lacks an access check. */
interface Violation {
  /** Path of the offending file, as listed by `git diff --name-only`. */
  file: string;
  /** 1-based line of the enclosing method's declaration (not of the call itself). */
  line: number;
  /** Name of the enclosing method. */
  method: string;
  /** Matched call text, e.g. `prisma.user.update(`. */
  prismaCall: string;
}

/**
 * Parses the command-line flags of this script.
 *
 * - `--staged`      compare the git index against HEAD (pre-commit usage).
 * - `--base <ref>`  compare HEAD against `<ref>`; defaults to `origin/develop`.
 *
 * @returns the selected mode and the base ref (always a non-empty string).
 */
function parseArgs(): { mode: 'staged' | 'base'; base: string } {
  const args = process.argv.slice(2);
  const mode = args.includes('--staged') ? 'staged' : 'base';
  const baseIdx = args.indexOf('--base');
  const requested = baseIdx >= 0 ? args[baseIdx + 1] : undefined;
  // `--base` given as the last argument, or directly followed by another flag,
  // has no usable value: fall back to the default instead of handing
  // `undefined` (or an option-looking string) to git.
  const base = requested && !requested.startsWith('-') ? requested : 'origin/develop';
  return { mode, base };
}

/**
 * Lists added/modified Service files for the current check.
 *
 * @param mode `'staged'` diffs the index (`--cached`); any other value diffs
 *             `<base>...HEAD`.
 * @param base git ref used when `mode` is not `'staged'`.
 * @returns repository-relative paths matching SERVICE_PATH_PATTERN; an empty
 *          array when nothing matches or when git fails.
 */
function getChangedServiceFiles(mode: string, base: string): string[] {
  // Arguments are passed as an array (no shell), so an unusual ref name cannot
  // be interpreted as shell syntax.
  const gitArgs =
    mode === 'staged'
      ? ['diff', '--cached', '--name-only', '--diff-filter=AM']
      : ['diff', `${base}...HEAD`, '--name-only', '--diff-filter=AM'];
  try {
    return execFileSync('git', gitArgs, { cwd: ROOT_DIR, encoding: 'utf-8' })
      .split('\n')
      .filter((f) => f && SERVICE_PATH_PATTERN.test(f));
  } catch (err: unknown) {
    // Still best-effort (the caller treats "no files" as a pass), but say so
    // explicitly: a mistyped base ref must not look like a clean run.
    const reason = err instanceof Error ? err.message : String(err);
    console.error(`⚠️ assertAccess 检查：获取变更文件列表失败，本次检查被跳过（${reason}）`);
    return [];
  }
}

/**
 * Returns the content of `file` as stored at `ref` (`git show <ref>:<file>`).
 *
 * @param ref  git ref; an empty string produces the object name `:<file>`.
 * @param file repository-relative path.
 * @returns the file content, or `''` when git fails (for example because the
 *          path does not exist at that ref).
 */
function gitShow(ref: string, file: string): string {
  try {
    // No shell is involved, so paths containing spaces or shell metacharacters
    // are passed to git verbatim.
    return execFileSync('git', ['show', `${ref}:${file}`], {
      cwd: ROOT_DIR,
      encoding: 'utf-8',
      // stderr is discarded: a missing path is an expected outcome here.
      stdio: ['pipe', 'pipe', 'ignore'],
    });
  } catch {
    return '';
  }
}

/**
 * Loads the "new" version of a file for the check.
 *
 * In `'staged'` mode the content comes from git via an empty ref (object name
 * `:<file>`, which git resolves against the index); in any other mode it is
 * read from the working tree. `base` is accepted but not used.
 *
 * @returns the file content, or `''` when it cannot be obtained.
 */
function readContent(file: string, mode: string, base: string): string {
  if (mode !== 'staged') {
    const absolutePath = path.join(ROOT_DIR, file);
    if (!fs.existsSync(absolutePath)) {
      return '';
    }
    return fs.readFileSync(absolutePath, 'utf-8');
  }
  return gitShow('', file);
}

/**
 * Loads the baseline version of a file: `HEAD` in `'staged'` mode, otherwise
 * the given `base` ref. Yields `''` when git cannot provide the file.
 */
function readBaseContent(file: string, mode: string, base: string): string {
  const baselineRef = mode === 'staged' ? 'HEAD' : base;
  return gitShow(baselineRef, file);
}

/**
 * One method-like scope located by extractMethods(), which scans the source
 * text with simple pattern and brace matching rather than a real TypeScript
 * parser.
 */
interface MethodScope {
  /** Identifier captured from the declaration line. */
  name: string;
  /** Source text from the declaration line through the line where brace matching ends. */
  body: string;
  /** The `@...` decorator lines found directly above the declaration, joined by newlines ('' if none). */
  decorators: string;
  /** 1-based line number of the declaration line. */
  startLine: number;
}

const BLANK_CHARS = ' \t\r\n';

/**
 * Heuristic: does the `/` at `slashIdx` start a regex literal (as opposed to a
 * division)? Decided from the previous significant character or keyword.
 */
function startsRegexLiteral(text: string, slashIdx: number): boolean {
  let j = slashIdx - 1;
  while (j >= 0 && BLANK_CHARS.includes(text[j])) j--;
  if (j < 0) return true;
  if ('(,=:[!&|?{};+-*%<>~^'.includes(text[j])) return true;
  return /(?:^|[^\w$])(?:return|typeof|case|in|of|void|delete|throw|new|instanceof)$/.test(
    text.slice(Math.max(0, j - 10), j + 1),
  );
}

/**
 * If a comment, string, template literal or regex literal starts at `pos`,
 * returns the index just past it; otherwise returns `pos` unchanged.
 */
function skipLiteralOrComment(text: string, pos: number): number {
  const ch = text[pos];
  if (ch === '/') {
    const next = text[pos + 1];
    if (next === '/') {
      const eol = text.indexOf('\n', pos);
      return eol === -1 ? text.length : eol;
    }
    if (next === '*') {
      const close = text.indexOf('*/', pos + 2);
      return close === -1 ? text.length : close + 2;
    }
    if (!startsRegexLiteral(text, pos)) return pos;
    let inClass = false;
    for (let i = pos + 1; i < text.length && text[i] !== '\n'; i++) {
      const c = text[i];
      if (c === '\\') i++;
      else if (c === '[') inClass = true;
      else if (c === ']') inClass = false;
      else if (c === '/' && !inClass) return i + 1;
    }
    // No terminator on this line: treat the slash as an ordinary character.
    return pos;
  }
  if (ch === "'" || ch === '"') {
    let i = pos + 1;
    // An unterminated string ends at the line break so damage stays local.
    while (i < text.length && text[i] !== ch && text[i] !== '\n') {
      i += text[i] === '\\' ? 2 : 1;
    }
    return Math.min(i + 1, text.length);
  }
  if (ch === '`') {
    let i = pos + 1;
    while (i < text.length && text[i] !== '`') {
      if (text[i] === '\\') {
        i += 2;
      } else if (text[i] === '$' && text[i + 1] === '{') {
        // `${ ... }` holds code again: skip to its matching brace.
        const close = findClosing(text, i + 1);
        if (close === -1) return text.length;
        i = close + 1;
      } else {
        i++;
      }
    }
    return Math.min(i + 1, text.length);
  }
  return pos;
}

/**
 * Index of the bracket closing the `(`, `[` or `{` at `openIdx`, ignoring
 * brackets inside comments and literals; -1 when it is never closed.
 */
function findClosing(text: string, openIdx: number): number {
  const open = text[openIdx];
  const close = open === '(' ? ')' : open === '[' ? ']' : '}';
  let depth = 0;
  let i = openIdx;
  while (i < text.length) {
    const next = skipLiteralOrComment(text, i);
    if (next > i) {
      i = next;
      continue;
    }
    const ch = text[i];
    if (ch === open) {
      depth++;
    } else if (ch === close) {
      depth--;
      if (depth === 0) return i;
    }
    i++;
  }
  return -1;
}

/**
 * Starting right after a parameter list, returns the index of the `{` that
 * opens the method body (skipping an optional `: ReturnType`), or -1 when the
 * text is not a method definition (call statement, overload signature, ...).
 */
function findBodyOpen(text: string, from: number): number {
  let i = from;
  while (i < text.length && BLANK_CHARS.includes(text[i])) i++;
  if (text[i] === '{') return i;
  if (text[i] !== ':') return -1;
  i++;

  let angleDepth = 0;
  // Last two significant characters, used to tell a type literal `{` from the body `{`.
  let prev = ':';
  let prevPrev = '';
  while (i < text.length) {
    const next = skipLiteralOrComment(text, i);
    if (next > i) {
      i = next;
      prevPrev = prev;
      prev = '"';
      continue;
    }
    const ch = text[i];
    if (BLANK_CHARS.includes(ch)) {
      i++;
      continue;
    }
    if (ch === '(' || ch === '[') {
      // Function-type parameters or tuple/index types: skip as a unit.
      const close = findClosing(text, i);
      if (close === -1) return -1;
      i = close + 1;
      prevPrev = prev;
      prev = text[close];
      continue;
    }
    if (ch === '{') {
      const afterArrow = prev === '>' && prevPrev === '=';
      const typePosition = angleDepth > 0 || afterArrow || ':|&'.includes(prev);
      if (!typePosition) return i;
      const close = findClosing(text, i);
      if (close === -1) return -1;
      i = close + 1;
      prevPrev = prev;
      prev = '}';
      continue;
    }
    if (ch === '<') {
      angleDepth++;
    } else if (ch === '>') {
      if (prev !== '=') angleDepth = Math.max(0, angleDepth - 1);
    } else if (ch === ';' || ch === '}' || ch === ')' || ch === ']') {
      return -1;
    } else if (ch === ',' && angleDepth === 0) {
      return -1;
    } else if (ch === '=' && text[i + 1] !== '>') {
      return -1;
    }
    prevPrev = prev;
    prev = ch;
    i++;
  }
  return -1;
}

/**
 * Finds method-like scopes in TypeScript source without a real parser.
 *
 * A candidate starts on a line of the form `[modifiers] name[<T>](`. The
 * parameter list and body are then matched across lines with a scanner that
 * ignores brackets inside comments, strings, template literals and (by
 * heuristic) regex literals. Compared with a single-line regex this also
 * recognises signatures wrapped over several lines, nested parentheses in
 * parameters, generic methods and the `override` modifier, and it no longer
 * strips a leading modifier keyword from names such as `asyncJob`.
 *
 * Known limits: arrow-function properties and accessors are not recognised,
 * and a decorator whose arguments span several lines is not collected.
 *
 * @param source full text of a TypeScript file.
 * @returns one entry per recognised scope, in source order. `body` spans whole
 *          lines from the declaration line to the line holding the closing
 *          brace; if the body is never closed it extends to the end of the
 *          file, so that no code is silently left unchecked.
 */
function extractMethods(source: string): MethodScope[] {
  const lines = source.split('\n');
  const methods: MethodScope[] = [];
  const headerStart =
    /^\s*(?:(?:public|private|protected|async|static|override)\s+)*([a-zA-Z_$][\w$]*)\s*(?:<[^()]*>)?\s*\(/;
  // Constructors and control-flow keywords that look like `name(`.
  const nonMethodNames = ['constructor', 'if', 'for', 'while', 'switch', 'catch', 'return'];

  // lineStarts[k] is the offset in `source` of the first character of line k.
  const lineStarts: number[] = [];
  let offset = 0;
  for (const line of lines) {
    lineStarts.push(offset);
    offset += line.length + 1;
  }

  for (let i = 0; i < lines.length; i++) {
    const m = headerStart.exec(lines[i]);
    if (!m) continue;
    const name = m[1];
    if (nonMethodNames.includes(name)) continue;

    // The header pattern ends on the opening parenthesis of the parameter list.
    const parenOpen = lineStarts[i] + m[0].length - 1;
    const parenClose = findClosing(source, parenOpen);
    if (parenClose === -1) continue;
    const bodyOpen = findBodyOpen(source, parenClose + 1);
    if (bodyOpen === -1) continue;

    const bodyClose = findClosing(source, bodyOpen);
    const endOffset = bodyClose === -1 ? source.length : bodyClose;
    let end = i;
    while (end + 1 < lines.length && lineStarts[end + 1] <= endOffset) end++;

    // Collect the decorators directly above, looking through blank lines and comments.
    const decoratorLines: string[] = [];
    for (let j = i - 1; j >= 0; j--) {
      const trimmed = lines[j].trim();
      const isDecorator = trimmed.startsWith('@');
      const isFiller =
        trimmed === '' ||
        trimmed.startsWith('//') ||
        trimmed.startsWith('/*') ||
        trimmed.startsWith('*');
      if (!isDecorator && !isFiller) break;
      if (isDecorator) decoratorLines.unshift(lines[j]);
    }

    methods.push({
      name,
      body: lines.slice(i, end + 1).join('\n'),
      decorators: decoratorLines.join('\n'),
      startLine: i + 1,
    });
  }
  return methods;
}

/**
 * Tells whether a decorator block contains `@SkipAssertAccess('<reason>')`
 * with a non-empty string-literal reason.
 *
 * The closing quote must be the same character as the opening one, so a
 * reason may itself contain the other quote characters (for example
 * `"owner's own record"`) as well as backslash escapes.
 *
 * @param decorators decorator lines of one method, joined by newlines.
 */
function hasSkipAssertAccess(decorators: string): boolean {
  return /@SkipAssertAccess\s*\(\s*(['"`])(?:\\.|(?!\1)[^\\])+\1\s*\)/.test(decorators);
}

/**
 * Collects every `prisma.<model>.<op>(` occurrence in `body` where `<op>` is
 * update, delete, updateMany or deleteMany.
 *
 * @returns the matched call prefixes in order of appearance (empty if none).
 */
function findPrismaWriteCalls(body: string): string[] {
  const writeCall = /prisma\.\w+\.(update|delete|updateMany|deleteMany)\s*\(/g;
  // With the global flag, String#match yields the full text of every match.
  const hits = body.match(writeCall);
  return hits === null ? [] : Array.from(hits);
}

/**
 * Tells whether a method body contains an `assertAccess(...)` call.
 *
 * Comments are removed before matching so that a commented-out call does not
 * satisfy the check. String and template literals are matched by the same
 * pattern and kept as-is, which prevents a `//` inside a string (a URL, for
 * instance) from being mistaken for the start of a comment.
 *
 * @param body source text of one method.
 */
function methodHasAssertAccess(body: string): boolean {
  const withoutComments = body.replace(
    /\/\*[\s\S]*?\*\/|\/\/[^\n]*|'(?:\\.|[^'\\\n])*'|"(?:\\.|[^"\\\n])*"|`(?:\\[\s\S]|[^`\\])*`/g,
    (token) => (token.startsWith('//') || token.startsWith('/*') ? ' ' : token),
  );
  return /\bassertAccess\s*\(/.test(withoutComments);
}

/**
 * Reports Prisma write calls in new or modified methods that have neither an
 * `assertAccess(...)` call in the same body nor a `@SkipAssertAccess('reason')`
 * decorator.
 *
 * A method counts as unchanged (and is skipped) when the baseline contains a
 * method with the same name and an identical body.
 *
 * @param file        path used in the reported violations.
 * @param newContent  source text being checked.
 * @param baseContent baseline source text ('' when the file is new).
 * @returns one violation per offending write call; `line` is the 1-based line
 *          of the enclosing method's declaration.
 */
function checkFile(file: string, newContent: string, baseContent: string): Violation[] {
  const violations: Violation[] = [];

  // Several baseline methods may share a name (overloads, or methods of
  // different classes in one file), so every body is kept per name instead of
  // letting the last one overwrite the others.
  const baseBodiesByName = new Map<string, Set<string>>();
  for (const baseMethod of extractMethods(baseContent)) {
    const bodies = baseBodiesByName.get(baseMethod.name);
    if (bodies) {
      bodies.add(baseMethod.body);
    } else {
      baseBodiesByName.set(baseMethod.name, new Set([baseMethod.body]));
    }
  }

  for (const method of extractMethods(newContent)) {
    // Pre-existing, untouched methods are out of scope for this check.
    const knownBodies = baseBodiesByName.get(method.name);
    if (knownBodies && knownBodies.has(method.body)) continue;

    const writeCalls = findPrismaWriteCalls(method.body);
    if (writeCalls.length === 0) continue;

    if (hasSkipAssertAccess(method.decorators)) continue;
    if (methodHasAssertAccess(method.body)) continue;

    for (const call of writeCalls) {
      violations.push({
        file,
        line: method.startLine,
        method: method.name,
        prismaCall: call,
      });
    }
  }

  return violations;
}

/**
 * Script entry point: gathers the changed Service files, checks each one and
 * terminates the process with status 0 when nothing is found, or prints every
 * violation to stderr and terminates with status 1.
 */
function main() {
  const { mode, base } = parseArgs();
  const changedFiles = getChangedServiceFiles(mode, base);
  if (changedFiles.length === 0) {
    process.exit(0);
  }

  const found: Violation[] = [];
  changedFiles.forEach((changedFile) => {
    const current = readContent(changedFile, mode, base);
    // Files whose new content cannot be read are skipped.
    if (current) {
      const previous = readBaseContent(changedFile, mode, base);
      checkFile(changedFile, current, previous).forEach((item) => found.push(item));
    }
  });

  if (found.length === 0) {
    process.exit(0);
  }

  console.error('\n❌ assertAccess 静态契约检查失败\n');
  found.forEach(({ file, line, method, prismaCall }) => {
    console.error(`  ${file}:${line}  method ${method}()`);
    console.error(`    → 含 ${prismaCall} 但未调用 assertAccess，也没有 @SkipAssertAccess('理由')`);
  });
  console.error(`\n规则详见 docs/standards/09-iam-security.md §7.1 / §8 禁止事项 #10`);
  console.error(`若确实无 IDOR 风险，在方法上加 @SkipAssertAccess('具体理由') 显式豁免\n`);
  process.exit(1);
}

main();
