import { Injectable } from '@nestjs/common';
import { createHash } from 'crypto';
import { PrismaService } from '@core/database/prisma/prisma.service';
import type { AgentTrajectoryEvent, AgentTrajectoryEventType } from '@prisma/client';

/**
 * Input for TrajectoryService.append(): one event to add to the tail of a session's hash chain.
 */
export interface AppendEventInput {
  /** Owning organization; also used to scope the lookup of the current chain tail. */
  organizationId: string;
  /** Session whose hash chain the event is appended to. */
  sessionId: string;
  /** Optional turn identifier; stored as null when omitted. */
  turnId?: string;
  /** Kind of trajectory event (type imported from @prisma/client). */
  eventType: AgentTrajectoryEventType;
  /**
   * Event body. Stored as-is in the payload column and hashed via canonical JSON,
   * so it is expected to contain only plain JSON values.
   */
  payload: unknown;
}

/**
 * Append-only, hash-chained audit trail for AgentTrajectoryEvent rows (PR4c).
 *
 * Invariants:
 *   - Within one session, sequenceInSession runs 1, 2, 3, … with no gaps. A DB unique
 *     constraint is expected to reject duplicate sequence numbers (not visible here —
 *     confirm in the Prisma schema); verifyChain() reports gaps.
 *   - eventHash = sha256( (prevEventHash ?? '') followed by canonicalJson(payload) ), hex.
 *   - Editing a single payload, a prevEventHash link or the sequence numbering makes the
 *     recomputed chain disagree, and verifyChain() fails loudly on the first broken event.
 *
 * Limits of the chain (it is not a full-row MAC):
 *   - Only prevEventHash and payload are hashed; eventType, turnId and other columns are
 *     not covered, so edits to them are NOT detected.
 *   - Deleting events from the end leaves a valid, shorter chain.
 *   - The hash uses no secret key: whoever rewrites a row AND every later row with freshly
 *     computed hashes produces a chain that still verifies.
 *
 * canonicalJson(payload) sorts object keys so that the same data always yields the same
 * string (and therefore the same hash).
 *
 * Only key state transitions needed for audit / replay are recorded, not minor events:
 *   TURN_STARTED / ROUTING_DECIDED / PROVIDER_INVOKED / MESSAGE_APPENDED /
 *   TURN_DONE / TOOL_CALL / TOOL_RESULT
 */
@Injectable()
export class TrajectoryService {
  constructor(private readonly prisma: PrismaService) {}

  /**
   * Appends one event to the end of the session's hash chain.
   *
   * Reads the current tail event (scoped to organizationId), then inserts the new event with
   * sequenceInSession = tail + 1 (1 for an empty chain) and prevEventHash = the tail's
   * eventHash (null for an empty chain).
   *
   * The tail read and the insert are separate queries, so two concurrent appends to the same
   * session can compute the same sequence number; the second insert is presumably rejected by
   * the DB unique constraint and that error propagates to the caller.
   *
   * @param input the event to append
   * @returns the created row
   */
  async append(input: AppendEventInput): Promise<AgentTrajectoryEvent> {
    // Defensive: the tail lookup also filters by organizationId, so even if a caller passes a
    // sessionId belonging to another org, the event is never chained onto that org's hash chain.
    // Together with messages.service.assertSessionOwnership this forms a two-layer guard.
    const last = await this.prisma.agentTrajectoryEvent.findFirst({
      where: { sessionId: input.sessionId, organizationId: input.organizationId },
      orderBy: { sequenceInSession: 'desc' },
      select: { sequenceInSession: true, eventHash: true },
    });
    const nextSeq = (last?.sequenceInSession ?? 0) + 1;
    const prevHash = last?.eventHash ?? null;
    const eventHash = computeHash(prevHash, input.payload);

    return this.prisma.agentTrajectoryEvent.create({
      data: {
        organizationId: input.organizationId,
        sessionId: input.sessionId,
        turnId: input.turnId ?? null,
        eventType: input.eventType,
        payload: input.payload as never,
        sequenceInSession: nextSeq,
        prevEventHash: prevHash,
        eventHash,
      },
    });
  }

  /**
   * Lists all events of a session (scoped to organizationId) in ascending sequence order.
   *
   * @param sessionId session whose events are returned
   * @param organizationId owning organization; events of other orgs are never returned
   * @returns the events, oldest first (empty when none exist)
   */
  async listForSession(
    sessionId: string,
    organizationId: string,
  ): Promise<AgentTrajectoryEvent[]> {
    return this.prisma.agentTrajectoryEvent.findMany({
      where: { sessionId, organizationId },
      orderBy: { sequenceInSession: 'asc' },
    });
  }

  /**
   * Verifies the hash-chain integrity of a single session.
   *
   * Walks the events in sequence order and stops at the first event whose sequence number,
   * prevEventHash link, or recomputed eventHash does not match.
   *
   * @param sessionId session to verify
   * @param organizationId owning organization (same scoping as listForSession)
   * @returns `{ ok: true }` when the chain is consistent (an empty chain also passes);
   *   otherwise `{ ok: false, firstBrokenSeq, reason }` describing the first broken event.
   */
  async verifyChain(
    sessionId: string,
    organizationId: string,
  ): Promise<{ ok: true } | { ok: false; firstBrokenSeq: number; reason: string }> {
    // Reuse the same scoped, ordered query as listForSession instead of duplicating it.
    const events = await this.listForSession(sessionId, organizationId);
    let prevHash: string | null = null;
    let expectedSeq = 1;
    for (const ev of events) {
      if (ev.sequenceInSession !== expectedSeq) {
        return {
          ok: false,
          firstBrokenSeq: ev.sequenceInSession,
          reason: `sequence gap: expected ${expectedSeq} got ${ev.sequenceInSession}`,
        };
      }
      if (ev.prevEventHash !== prevHash) {
        return {
          ok: false,
          firstBrokenSeq: ev.sequenceInSession,
          reason: 'prevEventHash mismatch',
        };
      }
      // Recompute from the stored payload; must match what append() hashed at write time.
      const expectedHash = computeHash(prevHash, ev.payload);
      if (ev.eventHash !== expectedHash) {
        return {
          ok: false,
          firstBrokenSeq: ev.sequenceInSession,
          reason: 'eventHash mismatch (payload tampered)',
        };
      }
      prevHash = ev.eventHash;
      expectedSeq += 1;
    }
    return { ok: true };
  }
}

/**
 * Computes one chain link's hash: SHA-256 over the previous link's hash (empty string for
 * the first event) followed by the canonical JSON of the payload, returned as hex.
 */
function computeHash(prevHash: string | null, payload: unknown): string {
  const hasher = createHash('sha256');
  const seed = prevHash ?? '';
  hasher.update(seed);
  hasher.update(canonicalJson(payload));
  return hasher.digest('hex');
}

/**
 * Deterministic JSON serialization used as the hash input for trajectory payloads.
 *
 * Object keys are emitted in sorted (UTF-16 code unit) order so equal data always yields the
 * same string. Apart from key order it mirrors JSON.stringify:
 *   - toJSON() is honoured (e.g. Date → ISO-8601 string);
 *   - object properties whose value is undefined / a function / a symbol are omitted;
 *   - such values inside arrays are emitted as null;
 *   - a top-level value JSON.stringify cannot represent yields the string 'null'.
 * This matters because append() hashes the in-memory payload while verifyChain() hashes the
 * value read back from the Json column. Assuming that column stores what JSON.stringify
 * produces (TODO confirm for the Prisma version in use), both sides now hash identical text.
 * For plain JSON values the output is unchanged from earlier versions.
 *
 * Not supported: BigInt (JSON.stringify throws TypeError) and circular references.
 *
 * @param value payload to serialize
 * @returns canonical JSON text
 */
function canonicalJson(value: unknown): string {
  // Returns undefined for values JSON.stringify would skip, so callers can omit / null them.
  const encode = (input: unknown): string | undefined => {
    let current: unknown = input;
    if (current !== null && typeof current === 'object') {
      const toJSON = (current as { toJSON?: unknown }).toJSON;
      if (typeof toJSON === 'function') current = toJSON.call(current) as unknown;
    }
    if (current === undefined || typeof current === 'function' || typeof current === 'symbol') {
      return undefined;
    }
    if (current === null || typeof current !== 'object') return JSON.stringify(current);
    if (Array.isArray(current)) {
      // JSON.stringify turns unrepresentable array elements into null (keeps positions).
      return '[' + current.map((item) => encode(item) ?? 'null').join(',') + ']';
    }
    const record = current as Record<string, unknown>;
    const parts: string[] = [];
    for (const key of Object.keys(record).sort()) {
      const encoded = encode(record[key]);
      // JSON.stringify drops properties whose value is unrepresentable.
      if (encoded !== undefined) parts.push(JSON.stringify(key) + ':' + encoded);
    }
    return '{' + parts.join(',') + '}';
  };
  return encode(value) ?? 'null';
}
