import { PricingTable } from './types';

/**
 * Prices `tokens` of a single token type under a two-tier rate
 * (all rates are USD per million tokens).
 *
 * Tokens up to and including `threshold` are charged at `baseRate`; only the
 * portion above `threshold` is charged at `tierRate`. Passing
 * `Number.POSITIVE_INFINITY` as `threshold` disables tiering.
 *
 * Simplification (from the original design note): tiering is decided within a
 * single event and separately for each token type — each type is compared to the
 * threshold on its own, with no accumulation across events. The original note
 * estimates the deviation at < 0.5% in typical use, and expects very long
 * contexts (e.g. MyBrain-style workloads) to be estimated more accurately.
 *
 * NOTE(review): Anthropic's published long-context pricing moves the *whole*
 * request (input and output) to the premium rate once input tokens (including
 * cache reads/writes) exceed 200K, rather than charging only the excess —
 * confirm that the marginal model used here is intended.
 */
function tieredCost(tokens: number, baseRate: number, tierRate: number, threshold: number): number {
  const toMillions = (count: number): number => count / 1_000_000;
  const withinBaseTier = tokens <= threshold;
  return withinBaseTier
    ? toMillions(tokens) * baseRate
    : toMillions(threshold) * baseRate + toMillions(tokens - threshold) * tierRate;
}

/**
 * Reads a rate (USD per million tokens) from a pricing entry.
 *
 * The entry value may be a number or a numeric string (strings are parsed with
 * `parseFloat`, so trailing junk after a number is ignored as before). Candidates
 * are tried in order — `entry[key]`, then `fallback`, then 0 — and the first one
 * that yields a finite number wins. Unlike a plain `??` chain, this means a
 * present-but-unparseable value (`""`, `"n/a"`) falls through instead of
 * producing `NaN` and poisoning the total cost.
 *
 * @param entry - pricing entry for one model
 * @param key - property name holding the rate
 * @param fallback - rate used when `entry[key]` is missing or not a finite number
 * @returns the parsed rate, or 0 when neither candidate is usable
 */
function rate(entry: object, key: string, fallback?: string | number): number {
  const own: unknown = Reflect.get(entry, key);
  for (const candidate of [own, fallback]) {
    if (candidate == null) continue;
    const parsed = typeof candidate === 'number' ? candidate : parseFloat(String(candidate));
    if (Number.isFinite(parsed)) return parsed;
  }
  return 0;
}

/**
 * Estimates the USD cost of one usage event for `model`.
 *
 * Each token type is priced independently with `tieredCost`: tokens up to the
 * table's `tierThresholdTokens` use the entry's base rate, tokens above it use the
 * matching `…Above200k` rate. When an entry has no `…Above200k` rate, the base
 * rate is reused, so such a model is priced flat. When the table has no valid
 * `tierThresholdTokens`, tiering is disabled entirely.
 *
 * @param model - model id looked up in `pricing.models`
 * @param inputTokens - tokens priced with `inputUsdPerMTok` / `inputUsdPerMTokAbove200k`
 * @param outputTokens - tokens priced with `outputUsdPerMTok` / `outputUsdPerMTokAbove200k`
 * @param cacheCreationTokens - tokens priced with the `cacheCreation…` rates
 * @param cacheReadTokens - tokens priced with the `cacheRead…` rates
 * @param pricing - pricing table; `null` yields a zero cost
 * @returns the cost formatted with 6 decimal places; `'0.000000'` when `pricing`
 *   is null or `model` is not listed
 */
export function estimateCostUsd(
  model: string,
  inputTokens: number,
  outputTokens: number,
  cacheCreationTokens: number,
  cacheReadTokens: number,
  pricing: PricingTable | null,
): string {
  if (!pricing) return '0.000000';
  const entry = pricing.models.find((m) => m.id === model);
  if (!entry) return '0.000000';

  // Read `tierThresholdTokens` reflectively instead of through `as any`
  // (presumably it is not declared on PricingTable). Numbers and non-blank numeric
  // strings are accepted; a missing, non-numeric, NaN or negative value disables
  // tiering (threshold = +Infinity) instead of turning the cost into NaN.
  const rawThreshold: unknown = Reflect.get(pricing, 'tierThresholdTokens');
  const parsedThreshold =
    typeof rawThreshold === 'number' ||
    (typeof rawThreshold === 'string' && rawThreshold.trim() !== '')
      ? Number(rawThreshold)
      : Number.NaN;
  const threshold = parsedThreshold >= 0 ? parsedThreshold : Number.POSITIVE_INFINITY;

  const baseIn = rate(entry, 'inputUsdPerMTok');
  const baseOut = rate(entry, 'outputUsdPerMTok');
  const baseCC = rate(entry, 'cacheCreationUsdPerMTok');
  const baseCR = rate(entry, 'cacheReadUsdPerMTok');
  // Tier rates default to the base rate, so an entry without them is priced flat.
  const tierIn = rate(entry, 'inputUsdPerMTokAbove200k', String(baseIn));
  const tierOut = rate(entry, 'outputUsdPerMTokAbove200k', String(baseOut));
  const tierCC = rate(entry, 'cacheCreationUsdPerMTokAbove200k', String(baseCC));
  const tierCR = rate(entry, 'cacheReadUsdPerMTokAbove200k', String(baseCR));

  const cost =
    tieredCost(inputTokens, baseIn, tierIn, threshold) +
    tieredCost(outputTokens, baseOut, tierOut, threshold) +
    tieredCost(cacheCreationTokens, baseCC, tierCC, threshold) +
    tieredCost(cacheReadTokens, baseCR, tierCR, threshold);

  return cost.toFixed(6);
}
