import { Injectable, Logger, UnauthorizedException } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { JwtService } from '@nestjs/jwt';
import { randomBytes, randomUUID } from 'crypto';
import { RedisService } from '@core/cache/redis/redis.service';

/**
 * Token 服务：access JWT + refresh 不透明 token + 黑名单 + active_jtis 列表。
 *
 * 规则对齐（09-iam-security.md §2.2 / §5.3.5）：
 * - Access Token：JWT，30d，payload 仅 `{ sub, jti, exp }`
 * - Refresh Token：不透明字符串，Redis 存 30d，一次性使用 + 旋转
 * - 黑名单：JTI 做 key，TTL = token 剩余有效期
 * - `user:${userId}:active_jtis`：支持解雇场景一键失效所有 token
 * - 撤权 SLA 不受 access TTL 影响（黑名单 + Redis 5min 权限缓存即时生效）
 */
@Injectable()
export class TokenService {
  private readonly logger = new Logger(TokenService.name);
  private readonly accessTtlSec: number;
  private readonly refreshTtlSec: number;

  constructor(
    private readonly jwtService: JwtService,
    private readonly config: ConfigService,
    private readonly redis: RedisService,
  ) {
    this.accessTtlSec = parseTtl(
      this.config.get<string>('jwt.accessTtl') || '30d',
    );
    this.refreshTtlSec = parseTtl(
      this.config.get<string>('jwt.refreshTtl') || '30d',
    );
  }

  getAccessTtl(): number {
    return this.accessTtlSec;
  }

  getRefreshTtl(): number {
    return this.refreshTtlSec;
  }

  /**
   * 签发一对 access + refresh token。
   * JWT payload 只含 `sub` / `jti`，不携带权限或用户身份字段。
   */
  async issuePair(userId: string): Promise<{ accessToken: string; refreshToken: string }> {
    const jti = randomUUID();
    const accessToken = this.jwtService.sign(
      { sub: userId, jti },
      { expiresIn: `${this.accessTtlSec}s` },
    );

    const refreshToken = randomBytes(48).toString('base64url');
    await this.redis.setJson(
      `refresh_token:${refreshToken}`,
      this.refreshTtlSec,
      { userId, issuedAt: Date.now() },
    );

    await this.addActiveJti(userId, jti);

    return { accessToken, refreshToken };
  }

  /**
   * 验证 refresh token 并旋转（删除旧的，签发新的一对）。
   *
   * 并发竞态防护：先 GET → 再 DEL，DEL 返回 0 说明并发请求已抢先删除并签发，
   * 当前请求拒绝（避免同一 refreshToken 被两个请求各自换出新 access）。
   */
  async rotateRefresh(
    oldRefreshToken: string,
  ): Promise<{ accessToken: string; refreshToken: string; userId: string }> {
    const key = `refresh_token:${oldRefreshToken}`;
    const payload = await this.redis.getJson<{ userId: string; issuedAt: number }>(key);
    if (!payload) throw new UnauthorizedException('Refresh token 无效或已过期');

    // del 返回 1 才说明本请求成功"占有"删除权（race winner），返回 0 = 已被并发请求删除
    const deleted = await this.redis.del(key);
    if (deleted === 0) {
      throw new UnauthorizedException('Refresh token 已被并发使用');
    }

    const { userId } = payload;
    const pair = await this.issuePair(userId);
    return { ...pair, userId };
  }

  /**
   * JTI 是否在黑名单。
   */
  async isBlacklisted(jti: string): Promise<boolean> {
    return this.redis.exists(`token_blacklist:${jti}`);
  }

  /**
   * 将 JTI 加入黑名单，TTL 与 access token 一致（足够覆盖剩余有效期）。
   */
  async blacklist(jti: string, userId: string): Promise<void> {
    await this.redis.setEx(`token_blacklist:${jti}`, this.accessTtlSec, '1');
    await this.redis.sRem(`user:${userId}:active_jtis`, jti);
  }

  /**
   * 主动失效用户所有 token（解雇 / 紧急情况）。
   */
  async revokeAllForUser(userId: string): Promise<number> {
    const jtis = await this.redis.sMembers(`user:${userId}:active_jtis`);
    // 并行黑名单写入 + 一次性删除 active_jtis 集合
    await Promise.all(
      jtis.map((jti) =>
        this.redis.setEx(`token_blacklist:${jti}`, this.accessTtlSec, '1'),
      ),
    );
    await this.redis.del(`user:${userId}:active_jtis`);
    return jtis.length;
  }

  private async addActiveJti(userId: string, jti: string): Promise<void> {
    await this.redis.sAdd(`user:${userId}:active_jtis`, jti);
    // active_jtis 集合本身 TTL 取 refresh 周期（覆盖所有可能在用的 access）
    await this.redis.expire(`user:${userId}:active_jtis`, this.refreshTtlSec);
  }
}

function parseTtl(v: string): number {
  const m = /^(\d+)([smhd])$/.exec(v.trim());
  if (!m) return 30 * 24 * 60 * 60; // 默认 30d
  const n = parseInt(m[1], 10);
  switch (m[2]) {
    case 's': return n;
    case 'm': return n * 60;
    case 'h': return n * 60 * 60;
    case 'd': return n * 24 * 60 * 60;
    default: return 30 * 24 * 60 * 60;
  }
}
