Files
user-system/internal/api/middleware/auth.go
long-agent 8095307d82 fix: P0/P1 security and quality fixes
P0-01: Add ESCAPE clause to LIKE queries in operation_log.go and device.go
P0-02: Add atomic Increment to L1Cache and L2Cache interfaces
P0-07: Add TOTP verification step after password login
P1-01: Sanitize error messages in error.go middleware
P1-03: Remove err.Error() from export error messages
P1-04: Add error return to CountByResultSince in login_log.go
P1-05: Add transactional DeleteCascade to RoleRepository
P1-06: Add PasswordChangedAt tracking for JWT token invalidation
P1-07: Wrap theme SetDefault in database transaction
P1-08: Use config values for database pool parameters
P1-09: Add rows.Err() checks in social_account_repo.go
P1-10: Validate sortOrder with map in user.go ORDER BY
P1-11: Add GORM tags to Announcement struct
P1-15: Add pageSize upper limit (100) to device and log handlers
2026-04-18 15:33:12 +08:00

243 lines
6.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package middleware
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"golang.org/x/sync/singleflight"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/domain"
apierrors "github.com/user-management-system/internal/pkg/errors"
)
// Consumer-side interfaces (dependency inversion): the middleware depends only
// on these narrow abstractions, never on the concrete repository types.
type (
	// authUserRepository loads a single user by ID. The middleware uses it for
	// the password-changed check and the account-status check.
	authUserRepository interface {
		GetByID(ctx context.Context, id int64) (*domain.User, error)
	}

	// authUserRoleRepository returns a user's roles and permissions in one
	// call. The middleware uses it to populate role/permission codes.
	authUserRoleRepository interface {
		GetUserRolesAndPermissions(ctx context.Context, userID int64) ([]*domain.Role, []*domain.Permission, error)
	}
)
// AuthMiddleware authenticates requests using JWT access tokens and attaches
// the caller's identity, roles and permission codes to the gin context.
type AuthMiddleware struct {
	// jwt validates access tokens and yields their claims.
	jwt *auth.JWT
	// userRepo backs the password-changed and account-active checks; may be nil,
	// in which case both checks are skipped.
	userRepo authUserRepository
	// userRoleRepo loads role/permission codes; may be nil, in which case no
	// codes are attached.
	userRoleRepo authUserRoleRepository
	// l1Cache holds the JTI blacklist and per-user permission entries.
	l1Cache *cache.L1Cache
	// cacheManager is an optional second lookup for blacklisted JTIs, set via
	// SetCacheManager.
	cacheManager *cache.CacheManager
	// sfGroup coalesces concurrent cacheManager lookups for the same key.
	sfGroup singleflight.Group
}
// NewAuthMiddleware builds an AuthMiddleware from its collaborators.
// The cache manager is not configured here; call SetCacheManager to enable the
// second-level blacklist lookup.
// NOTE(review): l1Cache is used unconditionally by the blacklist and
// permission-cache paths, so it presumably must be non-nil — confirm.
func NewAuthMiddleware(
	jwt *auth.JWT,
	userRepo authUserRepository,
	userRoleRepo authUserRoleRepository,
	l1Cache *cache.L1Cache,
) *AuthMiddleware {
	mw := new(AuthMiddleware)
	mw.jwt = jwt
	mw.userRepo = userRepo
	mw.userRoleRepo = userRoleRepo
	mw.l1Cache = l1Cache
	return mw
}
// SetCacheManager installs the cache manager consulted by the JTI blacklist
// check when the L1 cache misses. Passing nil disables that lookup.
func (m *AuthMiddleware) SetCacheManager(cm *cache.CacheManager) {
	m.cacheManager = cm
}
// Required returns a handler that rejects the request with HTTP 401 unless it
// carries a bearer access token that validates, is not blacklisted, predates
// no password change, and belongs to an active user. On success it stores
// user_id, username, token_jti, role_codes and permission_codes in the gin
// context and continues the chain.
// NOTE(review): the password-change check and the active check each call
// userRepo.GetByID, so one request may hit the user store twice.
func (m *AuthMiddleware) Required() gin.HandlerFunc {
	// reject writes the 401 error body and stops the handler chain.
	reject := func(c *gin.Context, msg string) {
		c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", msg))
		c.Abort()
	}

	return func(c *gin.Context) {
		token := m.extractToken(c)
		if token == "" {
			reject(c, "未提供认证令牌")
			return
		}

		claims, err := m.jwt.ValidateAccessToken(token)
		if err != nil {
			reject(c, "无效的认证令牌")
			return
		}

		ctx := c.Request.Context()
		// Checks run in order; the first failing one decides the message.
		switch {
		case m.isJTIBlacklisted(ctx, claims.JTI):
			reject(c, "令牌已失效,请重新登录")
			return
		case m.isPasswordChangedSinceTokenIssued(ctx, claims.UserID, claims.PCE):
			reject(c, "密码已更新,请重新登录")
			return
		case !m.isUserActive(ctx, claims.UserID):
			reject(c, "账号不可用,请重新登录")
			return
		}

		c.Set("user_id", claims.UserID)
		c.Set("username", claims.Username)
		c.Set("token_jti", claims.JTI)

		roles, perms := m.loadUserRolesAndPerms(ctx, claims.UserID)
		c.Set("role_codes", roles)
		c.Set("permission_codes", perms)

		c.Next()
	}
}
// Optional returns a handler that never rejects. When a bearer token is present
// and passes the same checks as Required, it stores user_id, username,
// token_jti, role_codes and permission_codes in the gin context. Otherwise the
// request proceeds unauthenticated. The chain always continues.
func (m *AuthMiddleware) Optional() gin.HandlerFunc {
	return func(c *gin.Context) {
		if token := m.extractToken(c); token != "" {
			if claims, err := m.jwt.ValidateAccessToken(token); err == nil {
				ctx := c.Request.Context()
				trusted := !m.isJTIBlacklisted(ctx, claims.JTI) &&
					!m.isPasswordChangedSinceTokenIssued(ctx, claims.UserID, claims.PCE) &&
					m.isUserActive(ctx, claims.UserID)
				if trusted {
					c.Set("user_id", claims.UserID)
					c.Set("username", claims.Username)
					c.Set("token_jti", claims.JTI)
					roles, perms := m.loadUserRolesAndPerms(ctx, claims.UserID)
					c.Set("role_codes", roles)
					c.Set("permission_codes", perms)
				}
			}
		}
		c.Next()
	}
}
// isJTIBlacklisted reports whether the token ID has been revoked. An empty jti
// is never blacklisted.
//
// The lookup order is:
//  1. l1Cache (key "jwt_blacklist:<jti>").
//  2. On an L1 miss with a cache manager configured, cacheManager. Concurrent
//     misses for the same key are collapsed into one lookup via singleflight to
//     avoid a cache stampede.
//
// A hit from step 2 is written back to l1Cache for 5 minutes.
// NOTE(review): the singleflight closure captures the ctx of whichever request
// started the flight. If that request is cancelled and the lookup then returns
// nothing, every coalesced caller treats the token as not blacklisted. Confirm
// cacheManager.Get's behaviour on a cancelled context.
// NOTE(review): the second return value of cacheManager.Get is discarded; a hit
// is inferred purely from a non-nil first value.
func (m *AuthMiddleware) isJTIBlacklisted(ctx context.Context, jti string) bool {
	if jti == "" {
		return false
	}
	key := "jwt_blacklist:" + jti

	if _, hit := m.l1Cache.Get(key); hit {
		return true
	}
	if m.cacheManager == nil {
		return false
	}

	shared, err, _ := m.sfGroup.Do(key, func() (interface{}, error) {
		v, _ := m.cacheManager.Get(ctx, key)
		return v, nil
	})
	if err != nil || shared == nil {
		return false
	}

	// Promote the hit so subsequent checks are served from L1.
	m.l1Cache.Set(key, true, 5*time.Minute)
	return true
}
// isPasswordChangedSinceTokenIssued reports whether the user's password was
// changed after the token was issued. It compares the token's PCE claim with
// the user's PasswordChangedAt, in Unix seconds.
//
// It returns false (does not block) when any of these hold:
//   - tokenPCE is 0: legacy tokens without the claim stay valid for backward
//     compatibility;
//   - no user repository is configured;
//   - the lookup fails, returns no user, or the user has no recorded password
//     change.
//
// The lookup-error case fails open here. Required and Optional call
// isUserActive afterwards, and it rejects on a lookup error.
// NOTE(review): if the token issuer writes PCE=0 for users who have never
// changed their password, those users' tokens will survive a later password
// change. Confirm how PCE is populated at issuance.
func (m *AuthMiddleware) isPasswordChangedSinceTokenIssued(ctx context.Context, userID int64, tokenPCE int64) bool {
	if tokenPCE == 0 || m.userRepo == nil {
		return false
	}
	user, err := m.userRepo.GetByID(ctx, userID)
	// Guard against a (nil, nil) "not found" result, which would otherwise
	// panic on the field access below.
	if err != nil || user == nil || user.PasswordChangedAt.IsZero() {
		return false
	}
	// A PCE earlier than the stored change time means the password changed
	// after this token was minted.
	return tokenPCE < user.PasswordChangedAt.Unix()
}
// loadUserRolesAndPerms returns the role codes and permission codes for userID.
//
// Results are cached in l1Cache under "user_perms:<userID>" for 30 minutes.
// The function returns (nil, nil), and caches nothing, when no role repository
// is configured, when the repository call fails, or when the user has no roles.
// NOTE(review): the returned slices are the same ones stored in the cache, so
// callers should treat them as read-only.
func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) {
	if m.userRoleRepo == nil {
		return nil, nil
	}

	key := fmt.Sprintf("user_perms:%d", userID)
	if v, hit := m.l1Cache.Get(key); hit {
		if entry, isEntry := v.(userPermEntry); isEntry {
			return entry.roles, entry.perms
		}
	}

	// Single joined query for both roles and permissions.
	roles, perms, err := m.userRoleRepo.GetUserRolesAndPermissions(ctx, userID)
	if err != nil || len(roles) == 0 {
		return nil, nil
	}

	entry := userPermEntry{
		roles: make([]string, len(roles)),
		perms: make([]string, len(perms)),
	}
	for i, r := range roles {
		entry.roles[i] = r.Code
	}
	for i, p := range perms {
		entry.perms[i] = p.Code
	}

	m.l1Cache.Set(key, entry, 30*time.Minute)
	return entry.roles, entry.perms
}
// InvalidateUserPermCache removes the cached role/permission codes for userID
// from l1Cache. The next authenticated request reloads them from the repository.
// NOTE(review): only this middleware's l1Cache is cleared. If several instances
// each hold their own L1, others keep stale codes until their 30-minute entry
// expires. Confirm deployment topology.
func (m *AuthMiddleware) InvalidateUserPermCache(userID int64) {
	key := fmt.Sprintf("user_perms:%d", userID)
	m.l1Cache.Delete(key)
}
// AddToBlacklist marks the token ID as revoked in l1Cache for ttl. An empty jti
// or a non-positive ttl is ignored.
// NOTE(review): this writes only to l1Cache, while isJTIBlacklisted also
// consults cacheManager. Revocations recorded here may not be visible to other
// instances unless callers also write to the second-level cache. Confirm.
func (m *AuthMiddleware) AddToBlacklist(jti string, ttl time.Duration) {
	if jti == "" || ttl <= 0 {
		return
	}
	m.l1Cache.Set("jwt_blacklist:"+jti, true, ttl)
}
// isUserActive reports whether userID refers to a user whose status is
// domain.UserStatusActive.
//
// With no user repository configured every user is considered active. A lookup
// error, or a repository that reports "not found" as (nil, nil), is treated as
// inactive (fail closed) instead of panicking on a nil user.
func (m *AuthMiddleware) isUserActive(ctx context.Context, userID int64) bool {
	if m.userRepo == nil {
		return true
	}
	user, err := m.userRepo.GetByID(ctx, userID)
	if err != nil || user == nil {
		return false
	}
	return user.Status == domain.UserStatusActive
}
// extractToken returns the bearer token from the Authorization header. It
// returns "" when the header is missing, has no credentials part, or uses a
// scheme other than Bearer.
//
// The scheme is matched case-insensitively, as HTTP requires for auth-schemes
// (RFC 9110 §11.1), so "bearer abc" is accepted like "Bearer abc". Surrounding
// whitespace is trimmed from both the header and the token, so "Bearer  abc"
// yields "abc" instead of " abc".
func (m *AuthMiddleware) extractToken(c *gin.Context) string {
	header := strings.TrimSpace(c.GetHeader("Authorization"))
	if header == "" {
		return ""
	}
	parts := strings.SplitN(header, " ", 2)
	if len(parts) != 2 || !strings.EqualFold(parts[0], "Bearer") {
		return ""
	}
	return strings.TrimSpace(parts[1])
}
// userPermEntry is the value cached per user by loadUserRolesAndPerms. It holds
// the user's role codes and permission codes, in repository order.
type userPermEntry struct {
	roles, perms []string
}