P0(高优先级): - P0-1: 确认数据库复合索引已存在(GORM tag),composite_index_test 验证通过 - P0-2: 连接池调优 MaxIdleConns 5→10, ConnMaxLifetime 30min→5min - P0-3: Redis 智能探测(ProbeRedis),无 Redis 自动降级到纯内存模式 P1(中优先级): - P1-1: GZIP 压缩中间件(compress/gzip 标准库,零新依赖) - P1-2: 权限缓存 TTL 30min→5min - P1-3: Argon2id 启动自适应校准(CalibrateArgon2id) 历史优化(含本次提交): - L1Cache O(n)→O(1) LRU 重构 - Auth 中间件 DB 查询合并 + 5s L1 缓存 - Logger 异步化(4096 缓冲通道) 验证: go build/vet/test 41/41 PASS, govulncheck 无漏洞
307 lines
8.9 KiB
Go
307 lines
8.9 KiB
Go
package middleware
|
||
|
||
import (
|
||
"context"
|
||
"fmt"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/gin-gonic/gin"
|
||
"golang.org/x/sync/singleflight"
|
||
|
||
"github.com/user-management-system/internal/auth"
|
||
"github.com/user-management-system/internal/cache"
|
||
"github.com/user-management-system/internal/domain"
|
||
apierrors "github.com/user-management-system/internal/pkg/errors"
|
||
)
|
||
|
||
// Interfaces for dependency inversion (DIP) — middleware depends on these abstractions, not concrete types.
|
||
type authUserRepository interface {
|
||
GetByID(ctx context.Context, id int64) (*domain.User, error)
|
||
}
|
||
|
||
type authUserRoleRepository interface {
|
||
GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error)
|
||
}
|
||
|
||
type AuthMiddleware struct {
|
||
jwt *auth.JWT
|
||
userRepo authUserRepository
|
||
userRoleRepo authUserRoleRepository
|
||
l1Cache *cache.L1Cache
|
||
cacheManager *cache.CacheManager
|
||
sfGroup singleflight.Group
|
||
}
|
||
|
||
func NewAuthMiddleware(
|
||
jwt *auth.JWT,
|
||
userRepo authUserRepository,
|
||
userRoleRepo authUserRoleRepository,
|
||
l1Cache *cache.L1Cache,
|
||
) *AuthMiddleware {
|
||
return &AuthMiddleware{
|
||
jwt: jwt,
|
||
userRepo: userRepo,
|
||
userRoleRepo: userRoleRepo,
|
||
l1Cache: l1Cache,
|
||
}
|
||
}
|
||
|
||
func (m *AuthMiddleware) SetCacheManager(cm *cache.CacheManager) {
|
||
m.cacheManager = cm
|
||
}
|
||
|
||
func (m *AuthMiddleware) Required() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
token := m.extractToken(c)
|
||
if token == "" {
|
||
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "未提供认证令牌"))
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
claims, err := m.jwt.ValidateAccessToken(token)
|
||
if err != nil {
|
||
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "无效的认证令牌"))
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
if m.isJTIBlacklisted(c.Request.Context(), claims.JTI) {
|
||
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "令牌已失效,请重新登录"))
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
// Perf: merge two separate DB round-trips (password-change check + active check)
|
||
// into a single cached user-state validation.
|
||
if denyReason := m.validateUserState(c.Request.Context(), claims.UserID, claims.PCE); denyReason != "" {
|
||
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", denyReason))
|
||
c.Abort()
|
||
return
|
||
}
|
||
|
||
c.Set("user_id", claims.UserID)
|
||
c.Set("username", claims.Username)
|
||
c.Set("token_jti", claims.JTI)
|
||
|
||
roleCodes, permCodes := m.loadUserRolesAndPerms(c.Request.Context(), claims.UserID)
|
||
c.Set("role_codes", roleCodes)
|
||
c.Set("permission_codes", permCodes)
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
func (m *AuthMiddleware) Optional() gin.HandlerFunc {
|
||
return func(c *gin.Context) {
|
||
token := m.extractToken(c)
|
||
if token != "" {
|
||
claims, err := m.jwt.ValidateAccessToken(token)
|
||
if err == nil && !m.isJTIBlacklisted(c.Request.Context(), claims.JTI) && m.validateUserState(c.Request.Context(), claims.UserID, claims.PCE) == "" {
|
||
c.Set("user_id", claims.UserID)
|
||
c.Set("username", claims.Username)
|
||
c.Set("token_jti", claims.JTI)
|
||
roleCodes, permCodes := m.loadUserRolesAndPerms(c.Request.Context(), claims.UserID)
|
||
c.Set("role_codes", roleCodes)
|
||
c.Set("permission_codes", permCodes)
|
||
}
|
||
}
|
||
|
||
c.Next()
|
||
}
|
||
}
|
||
|
||
func (m *AuthMiddleware) isJTIBlacklisted(ctx context.Context, jti string) bool {
|
||
if jti == "" {
|
||
return false
|
||
}
|
||
|
||
key := "jwt_blacklist:" + jti
|
||
|
||
// 先检查 L1 缓存
|
||
if _, ok := m.l1Cache.Get(key); ok {
|
||
return true
|
||
}
|
||
|
||
// L1 miss 时使用 singleflight 防止缓存击穿
|
||
// 多个并发请求只会触发一次 L2 查询
|
||
if m.cacheManager != nil {
|
||
val, err, _ := m.sfGroup.Do(key, func() (interface{}, error) {
|
||
found, _ := m.cacheManager.Get(ctx, key)
|
||
return found, nil
|
||
})
|
||
if err == nil && val != nil {
|
||
// 回写 L1 缓存
|
||
m.l1Cache.Set(key, true, 5*time.Minute)
|
||
return true
|
||
}
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
// validateUserState performs a single cached DB lookup that replaces the two
|
||
// previously separate checks: isPasswordChangedSinceTokenIssued + isUserActive.
|
||
//
|
||
// Returns "" on success, or an i18n-ready denial message on failure.
|
||
// Results are cached for 5 seconds per user to reduce DB pressure under high
|
||
// concurrency (e.g. 100 VU × 10 req/s = 1 000 auth middleware calls/s against
|
||
// the same hot user IDs).
|
||
func (m *AuthMiddleware) validateUserState(ctx context.Context, userID int64, tokenPCE int64) string {
|
||
if m.userRepo == nil {
|
||
return ""
|
||
}
|
||
|
||
// Check short-lived user-state cache (5 s TTL).
|
||
stateCacheKey := fmt.Sprintf("user_state:%d", userID)
|
||
if cached, ok := m.l1Cache.Get(stateCacheKey); ok {
|
||
if state, ok := cached.(userStateEntry); ok {
|
||
// tokenPCE > 0 means the JWT was issued for a user who had already
|
||
// changed their password at least once. Zero/negative values come from
|
||
// users whose PasswordChangedAt is still the Go zero-time, meaning they
|
||
// have never changed it — skip the check in that case.
|
||
if tokenPCE > 0 && state.passwordChangedAt > 0 && tokenPCE < state.passwordChangedAt {
|
||
return "密码已更新,请重新登录"
|
||
}
|
||
if !state.active {
|
||
return "账号不可用,请重新登录"
|
||
}
|
||
return ""
|
||
}
|
||
}
|
||
|
||
// Cache miss — single DB round-trip.
|
||
user, err := m.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return "账号不可用,请重新登录"
|
||
}
|
||
|
||
state := userStateEntry{
|
||
active: user.Status == domain.UserStatusActive,
|
||
passwordChangedAt: 0,
|
||
}
|
||
if !user.PasswordChangedAt.IsZero() {
|
||
state.passwordChangedAt = user.PasswordChangedAt.Unix()
|
||
}
|
||
|
||
// Cache for 5 seconds — short enough to reflect account lock/disable promptly.
|
||
m.l1Cache.Set(stateCacheKey, state, 5*time.Second)
|
||
|
||
// Same guard: tokenPCE <= 0 means no password-change time in the JWT → skip.
|
||
if tokenPCE > 0 && state.passwordChangedAt > 0 && tokenPCE < state.passwordChangedAt {
|
||
return "密码已更新,请重新登录"
|
||
}
|
||
if !state.active {
|
||
return "账号不可用,请重新登录"
|
||
}
|
||
return ""
|
||
}
|
||
|
||
// InvalidateUserStateCache removes the user-state cache entry so the next
|
||
// request picks up fresh data. Call this after status change or password reset.
|
||
func (m *AuthMiddleware) InvalidateUserStateCache(userID int64) {
|
||
m.l1Cache.Delete(fmt.Sprintf("user_state:%d", userID))
|
||
}
|
||
|
||
// isPasswordChangedSinceTokenIssued 检查用户密码是否在令牌发放后已更改
|
||
// Deprecated: use validateUserState for combined check with caching.
|
||
func (m *AuthMiddleware) isPasswordChangedSinceTokenIssued(ctx context.Context, userID int64, tokenPCE int64) bool {
|
||
if tokenPCE == 0 {
|
||
return false
|
||
}
|
||
if m.userRepo == nil {
|
||
return false
|
||
}
|
||
user, err := m.userRepo.GetByID(ctx, userID)
|
||
if err != nil || user.PasswordChangedAt.IsZero() {
|
||
return false
|
||
}
|
||
return tokenPCE < user.PasswordChangedAt.Unix()
|
||
}
|
||
|
||
func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) {
|
||
if m.userRoleRepo == nil {
|
||
return nil, nil
|
||
}
|
||
|
||
cacheKey := fmt.Sprintf("user_perms:%d", userID)
|
||
if cached, ok := m.l1Cache.Get(cacheKey); ok {
|
||
if entry, ok := cached.(userPermEntry); ok {
|
||
return entry.roles, entry.perms
|
||
}
|
||
}
|
||
|
||
// 使用已优化的单次 JOIN 查询获取用户角色和权限
|
||
roles, permissions, err := m.userRoleRepo.GetUserRolesAndPermissions(ctx, userID)
|
||
if err != nil || len(roles) == 0 {
|
||
return nil, nil
|
||
}
|
||
|
||
roleCodes := make([]string, 0, len(roles))
|
||
for _, role := range roles {
|
||
roleCodes = append(roleCodes, role.Code)
|
||
}
|
||
|
||
permCodes := make([]string, 0, len(permissions))
|
||
for _, perm := range permissions {
|
||
permCodes = append(permCodes, perm.Code)
|
||
}
|
||
|
||
// P1-2 权限缓存 TTL 调优:5min(原 30min)
|
||
// 理由:角色/权限变更后最长 5min 生效,与 userStateEntry TTL 保持一致。
|
||
// 若需立即生效,调用 InvalidateUserPermCache(userID) 主动驱逐。
|
||
m.l1Cache.Set(cacheKey, userPermEntry{roles: roleCodes, perms: permCodes}, 5*time.Minute)
|
||
return roleCodes, permCodes
|
||
}
|
||
|
||
func (m *AuthMiddleware) InvalidateUserPermCache(userID int64) {
|
||
m.l1Cache.Delete(fmt.Sprintf("user_perms:%d", userID))
|
||
}
|
||
|
||
func (m *AuthMiddleware) AddToBlacklist(jti string, ttl time.Duration) {
|
||
if jti != "" && ttl > 0 {
|
||
m.l1Cache.Set("jwt_blacklist:"+jti, true, ttl)
|
||
}
|
||
}
|
||
|
||
func (m *AuthMiddleware) isUserActive(ctx context.Context, userID int64) bool {
|
||
if m.userRepo == nil {
|
||
return true
|
||
}
|
||
|
||
user, err := m.userRepo.GetByID(ctx, userID)
|
||
if err != nil {
|
||
return false
|
||
}
|
||
|
||
return user.Status == domain.UserStatusActive
|
||
}
|
||
|
||
func (m *AuthMiddleware) extractToken(c *gin.Context) string {
|
||
authHeader := c.GetHeader("Authorization")
|
||
if authHeader == "" {
|
||
return ""
|
||
}
|
||
|
||
parts := strings.SplitN(authHeader, " ", 2)
|
||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||
return ""
|
||
}
|
||
|
||
return parts[1]
|
||
}
|
||
|
||
// userPermEntry is the value stored in the L1 cache under a "user_perms:<id>"
// key: the role codes and permission codes resolved for one user.
type userPermEntry struct {
	roles []string // role codes
	perms []string // permission codes
}
|
||
|
||
// userStateEntry is the minimal per-user state the auth checks need. It is
// cached with a 5-second TTL so that locking or disabling an account takes
// effect within seconds.
type userStateEntry struct {
	// active is true when the user's status is the active status.
	active bool
	// passwordChangedAt is the last password change as a Unix timestamp;
	// 0 means the password has never been changed.
	passwordChangedAt int64
}
|