package middleware import ( "context" "fmt" "net/http" "strings" "time" "github.com/gin-gonic/gin" "golang.org/x/sync/singleflight" "github.com/user-management-system/internal/auth" "github.com/user-management-system/internal/cache" "github.com/user-management-system/internal/domain" apierrors "github.com/user-management-system/internal/pkg/errors" ) // Interfaces for dependency inversion (DIP) — middleware depends on these abstractions, not concrete types. type authUserRepository interface { GetByID(ctx context.Context, id int64) (*domain.User, error) } type authUserRoleRepository interface { GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error) } type AuthMiddleware struct { jwt *auth.JWT userRepo authUserRepository userRoleRepo authUserRoleRepository l1Cache *cache.L1Cache cacheManager *cache.CacheManager sfGroup singleflight.Group } func NewAuthMiddleware( jwt *auth.JWT, userRepo authUserRepository, userRoleRepo authUserRoleRepository, l1Cache *cache.L1Cache, ) *AuthMiddleware { return &AuthMiddleware{ jwt: jwt, userRepo: userRepo, userRoleRepo: userRoleRepo, l1Cache: l1Cache, } } func (m *AuthMiddleware) SetCacheManager(cm *cache.CacheManager) { m.cacheManager = cm } func (m *AuthMiddleware) Required() gin.HandlerFunc { return func(c *gin.Context) { token := m.extractToken(c) if token == "" { c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "未提供认证令牌")) c.Abort() return } claims, err := m.jwt.ValidateAccessToken(token) if err != nil { c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "无效的认证令牌")) c.Abort() return } if m.isJTIBlacklisted(c.Request.Context(), claims.JTI) { c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "令牌已失效,请重新登录")) c.Abort() return } // Perf: merge two separate DB round-trips (password-change check + active check) // into a single cached user-state validation. if denyReason := m.validateUserState(c.Request.Context(), claims.UserID, claims.PCE); denyReason != "" { c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", denyReason)) c.Abort() return } c.Set("user_id", claims.UserID) c.Set("username", claims.Username) c.Set("token_jti", claims.JTI) roleCodes, permCodes := m.loadUserRolesAndPerms(c.Request.Context(), claims.UserID) c.Set("role_codes", roleCodes) c.Set("permission_codes", permCodes) c.Next() } } func (m *AuthMiddleware) Optional() gin.HandlerFunc { return func(c *gin.Context) { token := m.extractToken(c) if token != "" { claims, err := m.jwt.ValidateAccessToken(token) if err == nil && !m.isJTIBlacklisted(c.Request.Context(), claims.JTI) && m.validateUserState(c.Request.Context(), claims.UserID, claims.PCE) == "" { c.Set("user_id", claims.UserID) c.Set("username", claims.Username) c.Set("token_jti", claims.JTI) roleCodes, permCodes := m.loadUserRolesAndPerms(c.Request.Context(), claims.UserID) c.Set("role_codes", roleCodes) c.Set("permission_codes", permCodes) } } c.Next() } } func (m *AuthMiddleware) isJTIBlacklisted(ctx context.Context, jti string) bool { if jti == "" { return false } key := "jwt_blacklist:" + jti // 先检查 L1 缓存 if _, ok := m.l1Cache.Get(key); ok { return true } // L1 miss 时使用 singleflight 防止缓存击穿 // 多个并发请求只会触发一次 L2 查询 if m.cacheManager != nil { val, err, _ := m.sfGroup.Do(key, func() (interface{}, error) { found, _ := m.cacheManager.Get(ctx, key) return found, nil }) if err == nil && val != nil { // 回写 L1 缓存 m.l1Cache.Set(key, true, 5*time.Minute) return true } } return false } // validateUserState performs a single cached DB lookup that replaces the two // previously separate checks: isPasswordChangedSinceTokenIssued + isUserActive. // // Returns "" on success, or an i18n-ready denial message on failure. // Results are cached for 5 seconds per user to reduce DB pressure under high // concurrency (e.g. 100 VU × 10 req/s = 1 000 auth middleware calls/s against // the same hot user IDs). func (m *AuthMiddleware) validateUserState(ctx context.Context, userID int64, tokenPCE int64) string { if m.userRepo == nil { return "" } // Check short-lived user-state cache (5 s TTL). stateCacheKey := fmt.Sprintf("user_state:%d", userID) if cached, ok := m.l1Cache.Get(stateCacheKey); ok { if state, ok := cached.(userStateEntry); ok { // tokenPCE > 0 means the JWT was issued for a user who had already // changed their password at least once. Zero/negative values come from // users whose PasswordChangedAt is still the Go zero-time, meaning they // have never changed it — skip the check in that case. if tokenPCE > 0 && state.passwordChangedAt > 0 && tokenPCE < state.passwordChangedAt { return "密码已更新,请重新登录" } if !state.active { return "账号不可用,请重新登录" } return "" } } // Cache miss — single DB round-trip. user, err := m.userRepo.GetByID(ctx, userID) if err != nil { return "账号不可用,请重新登录" } state := userStateEntry{ active: user.Status == domain.UserStatusActive, passwordChangedAt: 0, } if !user.PasswordChangedAt.IsZero() { state.passwordChangedAt = user.PasswordChangedAt.Unix() } // Cache for 5 seconds — short enough to reflect account lock/disable promptly. m.l1Cache.Set(stateCacheKey, state, 5*time.Second) // Same guard: tokenPCE <= 0 means no password-change time in the JWT → skip. if tokenPCE > 0 && state.passwordChangedAt > 0 && tokenPCE < state.passwordChangedAt { return "密码已更新,请重新登录" } if !state.active { return "账号不可用,请重新登录" } return "" } // InvalidateUserStateCache removes the user-state cache entry so the next // request picks up fresh data. Call this after status change or password reset. func (m *AuthMiddleware) InvalidateUserStateCache(userID int64) { m.l1Cache.Delete(fmt.Sprintf("user_state:%d", userID)) } // isPasswordChangedSinceTokenIssued 检查用户密码是否在令牌发放后已更改 // Deprecated: use validateUserState for combined check with caching. func (m *AuthMiddleware) isPasswordChangedSinceTokenIssued(ctx context.Context, userID int64, tokenPCE int64) bool { if tokenPCE == 0 { return false } if m.userRepo == nil { return false } user, err := m.userRepo.GetByID(ctx, userID) if err != nil || user.PasswordChangedAt.IsZero() { return false } return tokenPCE < user.PasswordChangedAt.Unix() } func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) { if m.userRoleRepo == nil { return nil, nil } cacheKey := fmt.Sprintf("user_perms:%d", userID) if cached, ok := m.l1Cache.Get(cacheKey); ok { if entry, ok := cached.(userPermEntry); ok { return entry.roles, entry.perms } } // 使用已优化的单次 JOIN 查询获取用户角色和权限 roles, permissions, err := m.userRoleRepo.GetUserRolesAndPermissions(ctx, userID) if err != nil || len(roles) == 0 { return nil, nil } roleCodes := make([]string, 0, len(roles)) for _, role := range roles { roleCodes = append(roleCodes, role.Code) } permCodes := make([]string, 0, len(permissions)) for _, perm := range permissions { permCodes = append(permCodes, perm.Code) } // P1-2 权限缓存 TTL 调优:5min(原 30min) // 理由:角色/权限变更后最长 5min 生效,与 userStateEntry TTL 保持一致。 // 若需立即生效,调用 InvalidateUserPermCache(userID) 主动驱逐。 m.l1Cache.Set(cacheKey, userPermEntry{roles: roleCodes, perms: permCodes}, 5*time.Minute) return roleCodes, permCodes } func (m *AuthMiddleware) InvalidateUserPermCache(userID int64) { m.l1Cache.Delete(fmt.Sprintf("user_perms:%d", userID)) } func (m *AuthMiddleware) AddToBlacklist(jti string, ttl time.Duration) { if jti != "" && ttl > 0 { m.l1Cache.Set("jwt_blacklist:"+jti, true, ttl) } } func (m *AuthMiddleware) isUserActive(ctx context.Context, userID int64) bool { if m.userRepo == nil { return true } user, err := m.userRepo.GetByID(ctx, userID) if err != nil { return false } return user.Status == domain.UserStatusActive } func (m *AuthMiddleware) extractToken(c *gin.Context) string { authHeader := c.GetHeader("Authorization") if authHeader == "" { return "" } parts := strings.SplitN(authHeader, " ", 2) if len(parts) != 2 || parts[0] != "Bearer" { return "" } return parts[1] } type userPermEntry struct { roles []string perms []string } // userStateEntry caches the minimal user state needed for auth checks. // TTL is 5 s so that account lock/disable takes effect within seconds. type userStateEntry struct { active bool passwordChangedAt int64 // Unix timestamp; 0 means never changed }