Files
user-system/internal/api/middleware/auth.go
long-agent 7b047e2f11 perf: Sprint 19 P0/P1 性能优化落地
P0(高优先级):
- P0-1: 确认数据库复合索引已存在(GORM tag),composite_index_test 验证通过
- P0-2: 连接池调优 MaxIdleConns 5→10, ConnMaxLifetime 30min→5min
- P0-3: Redis 智能探测(ProbeRedis),无 Redis 自动降级到纯内存模式

P1(中优先级):
- P1-1: GZIP 压缩中间件(compress/gzip 标准库,零新依赖)
- P1-2: 权限缓存 TTL 30min→5min
- P1-3: Argon2id 启动自适应校准(CalibrateArgon2id)

历史优化(含本次提交):
- L1Cache O(n)→O(1) LRU 重构
- Auth 中间件 DB 查询合并 + 5s L1 缓存
- Logger 异步化(4096 缓冲通道)

验证: go build/vet/test 41/41 PASS, govulncheck 无漏洞
2026-04-18 22:57:44 +08:00

307 lines
8.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/sync/singleflight"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/domain"
apierrors "github.com/user-management-system/internal/pkg/errors"
)
// Narrow repository abstractions consumed by AuthMiddleware. The middleware
// depends only on these small interfaces (dependency inversion) rather than on
// concrete repository types, which also makes it easy to substitute fakes.
type (
	// authUserRepository looks up a single user by ID.
	authUserRepository interface {
		GetByID(ctx context.Context, id int64) (*domain.User, error)
	}

	// authUserRoleRepository returns a user's roles together with the
	// permissions granted through them, in a single call.
	authUserRoleRepository interface {
		GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error)
	}
)
// AuthMiddleware authenticates requests from a Bearer access token and stores
// the caller's identity, role codes and permission codes in the gin context.
//
// It contains a singleflight.Group (which embeds a mutex), so an AuthMiddleware
// must not be copied after first use; all methods use pointer receivers.
type AuthMiddleware struct {
	// jwt validates and parses access tokens.
	jwt *auth.JWT
	// userRepo backs the user-state check; when nil that check is skipped.
	userRepo authUserRepository
	// userRoleRepo backs role/permission loading; when nil no roles are loaded.
	userRoleRepo authUserRoleRepository
	// l1Cache is the in-process cache for blacklisted JTIs, user state and
	// permission sets.
	l1Cache *cache.L1Cache
	// cacheManager is an optional second-level cache consulted for blacklist
	// lookups on an L1 miss; attached via SetCacheManager.
	cacheManager *cache.CacheManager
	// sfGroup collapses concurrent L2 blacklist lookups for the same key.
	sfGroup singleflight.Group
}
// NewAuthMiddleware wires an AuthMiddleware from its collaborators.
//
// The second-level cache is not a constructor argument; attach it afterwards
// with SetCacheManager if one is available. The middleware methods call
// l1Cache without a nil check, so callers should pass a non-nil cache.
func NewAuthMiddleware(
	jwt *auth.JWT,
	userRepo authUserRepository,
	userRoleRepo authUserRoleRepository,
	l1Cache *cache.L1Cache,
) *AuthMiddleware {
	mw := new(AuthMiddleware)
	mw.jwt = jwt
	mw.userRepo = userRepo
	mw.userRoleRepo = userRoleRepo
	mw.l1Cache = l1Cache
	return mw
}
// SetCacheManager attaches the optional second-level cache that
// isJTIBlacklisted consults when a blacklist lookup misses the L1 cache.
// Passing nil disables the L2 lookup.
// NOTE(review): the assignment is unsynchronized — call this during setup,
// before the middleware starts serving requests.
func (m *AuthMiddleware) SetCacheManager(cm *cache.CacheManager) {
	m.cacheManager = cm
}
// Required returns a handler that only lets a request through when it carries
// a valid, non-blacklisted access token belonging to a user who passes
// validateUserState. Any failure aborts the chain with 401 UNAUTHORIZED.
//
// On success it sets "user_id", "username", "token_jti", "role_codes" and
// "permission_codes" on the gin context before calling the next handler.
func (m *AuthMiddleware) Required() gin.HandlerFunc {
	return func(c *gin.Context) {
		// deny writes the 401 body and stops the handler chain.
		deny := func(msg string) {
			c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", msg))
			c.Abort()
		}

		token := m.extractToken(c)
		if token == "" {
			deny("未提供认证令牌")
			return
		}

		claims, err := m.jwt.ValidateAccessToken(token)
		if err != nil {
			deny("无效的认证令牌")
			return
		}

		ctx := c.Request.Context()
		if m.isJTIBlacklisted(ctx, claims.JTI) {
			deny("令牌已失效,请重新登录")
			return
		}

		// A single cached lookup covers both the password-change and the
		// account-status checks.
		if reason := m.validateUserState(ctx, claims.UserID, claims.PCE); reason != "" {
			deny(reason)
			return
		}

		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("token_jti", claims.JTI)

		roles, perms := m.loadUserRolesAndPerms(ctx, claims.UserID)
		c.Set("role_codes", roles)
		c.Set("permission_codes", perms)

		c.Next()
	}
}
// Optional returns a handler that never rejects a request. When the request
// carries a valid, non-blacklisted token whose user passes validateUserState,
// it sets the same context keys as Required ("user_id", "username",
// "token_jti", "role_codes", "permission_codes"). Otherwise it leaves the
// context untouched. The next handler is always invoked.
func (m *AuthMiddleware) Optional() gin.HandlerFunc {
	return func(c *gin.Context) {
		ctx := c.Request.Context()

		if token := m.extractToken(c); token != "" {
			claims, err := m.jwt.ValidateAccessToken(token)
			// Short-circuit order matters: claims is only used when err == nil.
			authenticated := err == nil &&
				!m.isJTIBlacklisted(ctx, claims.JTI) &&
				m.validateUserState(ctx, claims.UserID, claims.PCE) == ""

			if authenticated {
				c.Set("user_id", claims.UserID)
				c.Set("username", claims.Username)
				c.Set("token_jti", claims.JTI)

				roles, perms := m.loadUserRolesAndPerms(ctx, claims.UserID)
				c.Set("role_codes", roles)
				c.Set("permission_codes", perms)
			}
		}

		c.Next()
	}
}
// isJTIBlacklisted reports whether the token ID jti has been revoked.
//
// The L1 cache is checked first. On an L1 miss, and only when a cache manager
// is attached, the L2 cache is queried through singleflight so that concurrent
// misses for the same key share one L2 call. A non-nil value from L2 counts as
// a hit and is written back to L1 for 5 minutes. An empty jti is never
// considered blacklisted.
//
// NOTE(review): the second return value of cacheManager.Get is discarded, and
// the singleflight closure uses the ctx of whichever caller started the
// flight. If that ctx is cancelled, the shared result is presumably a miss for
// every waiting caller (fail-open) — confirm against CacheManager.Get.
func (m *AuthMiddleware) isJTIBlacklisted(ctx context.Context, jti string) bool {
	if jti == "" {
		return false
	}
	key := "jwt_blacklist:" + jti

	if _, hit := m.l1Cache.Get(key); hit {
		return true
	}
	if m.cacheManager == nil {
		return false
	}

	v, err, _ := m.sfGroup.Do(key, func() (interface{}, error) {
		value, _ := m.cacheManager.Get(ctx, key)
		return value, nil
	})
	if err != nil || v == nil {
		return false
	}

	// Promote the L2 hit so subsequent checks are served from memory.
	m.l1Cache.Set(key, true, 5*time.Minute)
	return true
}
// validateUserState checks, with at most one repository lookup, whether the
// token's user may still authenticate. It replaces the two previously separate
// checks isPasswordChangedSinceTokenIssued and isUserActive.
//
// tokenPCE is the password-changed-at Unix timestamp carried in the JWT. Users
// who had never changed their password at issuance carry 0 or a negative value
// (the Unix time of Go's zero time).
//
// Returns "" when the user is allowed, otherwise a user-facing denial message.
// When no user repository is configured the check is skipped. A repository
// error, or a missing user, denies access (fail closed).
//
// The looked-up state is cached in L1 for 5 seconds per user. Under high
// concurrency (e.g. 100 VU × 10 req/s against the same hot user IDs) this
// collapses most middleware calls into cache hits. Use InvalidateUserStateCache
// to apply a status or password change immediately.
func (m *AuthMiddleware) validateUserState(ctx context.Context, userID int64, tokenPCE int64) string {
	if m.userRepo == nil {
		return ""
	}

	stateCacheKey := fmt.Sprintf("user_state:%d", userID)
	if cached, ok := m.l1Cache.Get(stateCacheKey); ok {
		if state, ok := cached.(userStateEntry); ok {
			return state.denyReason(tokenPCE)
		}
	}

	// Cache miss — single DB round-trip.
	user, err := m.userRepo.GetByID(ctx, userID)
	if err != nil || user == nil {
		// Guard against repositories that report "not found" as (nil, nil).
		return "账号不可用,请重新登录"
	}

	state := userStateEntry{active: user.Status == domain.UserStatusActive}
	if !user.PasswordChangedAt.IsZero() {
		state.passwordChangedAt = user.PasswordChangedAt.Unix()
	}

	// 5 s is short enough for account lock/disable to take effect promptly.
	m.l1Cache.Set(stateCacheKey, state, 5*time.Second)
	return state.denyReason(tokenPCE)
}

// denyReason applies the auth checks to a cached user state. It returns "" if
// a token with password-changed-at tokenPCE is still acceptable, otherwise the
// denial message. The password check takes precedence over the status check.
//
// A token is stale when the user's password changed after the value recorded
// in the token. This includes tokens minted before the user's first-ever
// password change (tokenPCE <= 0). Skipping those, as an earlier revision did,
// let pre-reset sessions survive the user's first password reset.
// NOTE(review): assumes the issuer always embeds PasswordChangedAt.Unix() (or
// 0) as PCE. Legacy tokens without the claim are now rejected for users who
// have changed their password.
func (s userStateEntry) denyReason(tokenPCE int64) string {
	if s.passwordChangedAt > 0 && tokenPCE < s.passwordChangedAt {
		return "密码已更新,请重新登录"
	}
	if !s.active {
		return "账号不可用,请重新登录"
	}
	return ""
}
// InvalidateUserStateCache evicts the cached auth state for userID, so the next
// authenticated request for that user re-reads it from the repository. Call it
// after changing a user's status or password so the change applies immediately
// instead of when the short-lived cache entry expires.
func (m *AuthMiddleware) InvalidateUserStateCache(userID int64) {
	key := fmt.Sprintf("user_state:%d", userID)
	m.l1Cache.Delete(key)
}
// isPasswordChangedSinceTokenIssued reports whether the user's password was
// changed after the password-changed-at timestamp (tokenPCE) recorded in the
// token. It returns false when tokenPCE is 0, when no user repository is
// configured, on lookup errors, when the user is missing, or when the user has
// never changed their password.
//
// Deprecated: use validateUserState, which combines this check with the
// account-status check and caches the result.
func (m *AuthMiddleware) isPasswordChangedSinceTokenIssued(ctx context.Context, userID int64, tokenPCE int64) bool {
	if tokenPCE == 0 {
		return false
	}
	if m.userRepo == nil {
		return false
	}
	user, err := m.userRepo.GetByID(ctx, userID)
	// user == nil guards repositories that report "not found" as (nil, nil).
	if err != nil || user == nil || user.PasswordChangedAt.IsZero() {
		return false
	}
	return tokenPCE < user.PasswordChangedAt.Unix()
}
// loadUserRolesAndPerms returns the role codes and permission codes for userID.
//
// Results are cached in L1 under "user_perms:<userID>" for 5 minutes (P1-2:
// lowered from 30 minutes). Role/permission changes therefore take effect
// within 5 minutes; call InvalidateUserPermCache(userID) to apply them
// immediately.
//
// An empty role set is cached too. Otherwise every request from a user without
// roles would hit the database. Repository errors are not cached, so the next
// request retries.
//
// Returns (nil, nil) when no role repository is configured, on error, or when
// the user has no roles. The returned slices are shared with the cache entry
// and must not be mutated by callers.
func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) {
	if m.userRoleRepo == nil {
		return nil, nil
	}

	cacheKey := fmt.Sprintf("user_perms:%d", userID)
	if cached, ok := m.l1Cache.Get(cacheKey); ok {
		if entry, ok := cached.(userPermEntry); ok {
			return entry.roles, entry.perms
		}
	}

	// One optimized JOIN query fetches both the roles and their permissions.
	roles, permissions, err := m.userRoleRepo.GetUserRolesAndPermissions(ctx, userID)
	if err != nil {
		return nil, nil
	}

	// Zero value (nil slices) represents "no roles" and is cached as such.
	var entry userPermEntry
	if len(roles) > 0 {
		entry.roles = make([]string, 0, len(roles))
		for _, role := range roles {
			if role != nil {
				entry.roles = append(entry.roles, role.Code)
			}
		}
		entry.perms = make([]string, 0, len(permissions))
		for _, perm := range permissions {
			if perm != nil {
				entry.perms = append(entry.perms, perm.Code)
			}
		}
	}

	m.l1Cache.Set(cacheKey, entry, 5*time.Minute)
	return entry.roles, entry.perms
}
// InvalidateUserPermCache evicts the cached role/permission codes for userID,
// so the next authenticated request reloads them from the role repository.
// Call it after changing a user's role assignments or a role's permissions.
func (m *AuthMiddleware) InvalidateUserPermCache(userID int64) {
	key := fmt.Sprintf("user_perms:%d", userID)
	m.l1Cache.Delete(key)
}
// AddToBlacklist marks the token ID jti as revoked for ttl. Calls with an
// empty jti or a non-positive ttl are ignored.
//
// Only the in-process L1 cache is written here.
// NOTE(review): if several instances share the L2 cache, another instance sees
// the revocation only when something else writes "jwt_blacklist:<jti>" to L2
// — verify the caller does so.
func (m *AuthMiddleware) AddToBlacklist(jti string, ttl time.Duration) {
	if jti == "" || ttl <= 0 {
		return
	}
	m.l1Cache.Set("jwt_blacklist:"+jti, true, ttl)
}
// isUserActive reports whether userID refers to an existing user whose status
// is domain.UserStatusActive. It returns true when no user repository is
// configured, and false on lookup errors or when the user is missing.
//
// Deprecated: use validateUserState, which combines this check with the
// password-change check and caches the result.
func (m *AuthMiddleware) isUserActive(ctx context.Context, userID int64) bool {
	if m.userRepo == nil {
		return true
	}
	user, err := m.userRepo.GetByID(ctx, userID)
	// user == nil guards repositories that report "not found" as (nil, nil);
	// previously that case dereferenced a nil pointer.
	if err != nil || user == nil {
		return false
	}
	return user.Status == domain.UserStatusActive
}
// extractToken returns the bearer token from the Authorization header, or ""
// when the header is absent, uses a different scheme, or carries no token.
//
// The scheme is matched case-insensitively, because HTTP authentication scheme
// names are case-insensitive (RFC 9110 §11.1), so "bearer <token>" is
// accepted. Whitespace around the token is trimmed, so extra spaces after the
// scheme do not corrupt the token.
func (m *AuthMiddleware) extractToken(c *gin.Context) string {
	scheme, token, ok := strings.Cut(c.GetHeader("Authorization"), " ")
	if !ok || !strings.EqualFold(scheme, "Bearer") {
		return ""
	}
	return strings.TrimSpace(token)
}
// L1 cache payloads used by AuthMiddleware.
type (
	// userPermEntry holds a user's role codes and permission codes as cached
	// under the "user_perms:<userID>" key.
	userPermEntry struct {
		roles []string
		perms []string
	}

	// userStateEntry caches the minimal user state needed for auth checks.
	// Its short TTL (5 s) lets account lock/disable take effect within seconds.
	userStateEntry struct {
		active            bool
		passwordChangedAt int64 // Unix seconds; 0 means the password was never changed
	}
)