feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers
This commit is contained in:
240
internal/api/middleware/auth.go
Normal file
240
internal/api/middleware/auth.go
Normal file
@@ -0,0 +1,240 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
apierrors "github.com/user-management-system/internal/pkg/errors"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
type AuthMiddleware struct {
|
||||
jwt *auth.JWT
|
||||
userRepo *repository.UserRepository
|
||||
userRoleRepo *repository.UserRoleRepository
|
||||
roleRepo *repository.RoleRepository
|
||||
rolePermissionRepo *repository.RolePermissionRepository
|
||||
permissionRepo *repository.PermissionRepository
|
||||
l1Cache *cache.L1Cache
|
||||
cacheManager *cache.CacheManager
|
||||
}
|
||||
|
||||
func NewAuthMiddleware(
|
||||
jwt *auth.JWT,
|
||||
userRepo *repository.UserRepository,
|
||||
userRoleRepo *repository.UserRoleRepository,
|
||||
roleRepo *repository.RoleRepository,
|
||||
rolePermissionRepo *repository.RolePermissionRepository,
|
||||
permissionRepo *repository.PermissionRepository,
|
||||
) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
jwt: jwt,
|
||||
userRepo: userRepo,
|
||||
userRoleRepo: userRoleRepo,
|
||||
roleRepo: roleRepo,
|
||||
rolePermissionRepo: rolePermissionRepo,
|
||||
permissionRepo: permissionRepo,
|
||||
l1Cache: cache.NewL1Cache(),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) SetCacheManager(cm *cache.CacheManager) {
|
||||
m.cacheManager = cm
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) Required() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := m.extractToken(c)
|
||||
if token == "" {
|
||||
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "未提供认证令牌"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := m.jwt.ValidateAccessToken(token)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "无效的认证令牌"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if m.isJTIBlacklisted(claims.JTI) {
|
||||
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "令牌已失效,请重新登录"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
if !m.isUserActive(c.Request.Context(), claims.UserID) {
|
||||
c.JSON(http.StatusUnauthorized, apierrors.New(http.StatusUnauthorized, "UNAUTHORIZED", "账号不可用,请重新登录"))
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("token_jti", claims.JTI)
|
||||
|
||||
roleCodes, permCodes := m.loadUserRolesAndPerms(c.Request.Context(), claims.UserID)
|
||||
c.Set("role_codes", roleCodes)
|
||||
c.Set("permission_codes", permCodes)
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) Optional() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := m.extractToken(c)
|
||||
if token != "" {
|
||||
claims, err := m.jwt.ValidateAccessToken(token)
|
||||
if err == nil && !m.isJTIBlacklisted(claims.JTI) && m.isUserActive(c.Request.Context(), claims.UserID) {
|
||||
c.Set("user_id", claims.UserID)
|
||||
c.Set("username", claims.Username)
|
||||
c.Set("token_jti", claims.JTI)
|
||||
roleCodes, permCodes := m.loadUserRolesAndPerms(c.Request.Context(), claims.UserID)
|
||||
c.Set("role_codes", roleCodes)
|
||||
c.Set("permission_codes", permCodes)
|
||||
}
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) isJTIBlacklisted(jti string) bool {
|
||||
if jti == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
key := "jwt_blacklist:" + jti
|
||||
if _, ok := m.l1Cache.Get(key); ok {
|
||||
return true
|
||||
}
|
||||
|
||||
if m.cacheManager != nil {
|
||||
if _, ok := m.cacheManager.Get(context.Background(), key); ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) {
|
||||
if m.userRoleRepo == nil || m.roleRepo == nil || m.rolePermissionRepo == nil || m.permissionRepo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
cacheKey := fmt.Sprintf("user_perms:%d", userID)
|
||||
if cached, ok := m.l1Cache.Get(cacheKey); ok {
|
||||
if entry, ok := cached.(userPermEntry); ok {
|
||||
return entry.roles, entry.perms
|
||||
}
|
||||
}
|
||||
|
||||
roleIDs, err := m.userRoleRepo.GetRoleIDsByUserID(ctx, userID)
|
||||
if err != nil || len(roleIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 收集所有角色ID(包括直接分配的角色和所有祖先角色)
|
||||
allRoleIDs := make([]int64, 0, len(roleIDs)*2)
|
||||
allRoleIDs = append(allRoleIDs, roleIDs...)
|
||||
|
||||
for _, roleID := range roleIDs {
|
||||
ancestorIDs, err := m.roleRepo.GetAncestorIDs(ctx, roleID)
|
||||
if err == nil && len(ancestorIDs) > 0 {
|
||||
allRoleIDs = append(allRoleIDs, ancestorIDs...)
|
||||
}
|
||||
}
|
||||
|
||||
// 去重
|
||||
seen := make(map[int64]bool)
|
||||
uniqueRoleIDs := make([]int64, 0, len(allRoleIDs))
|
||||
for _, id := range allRoleIDs {
|
||||
if !seen[id] {
|
||||
seen[id] = true
|
||||
uniqueRoleIDs = append(uniqueRoleIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
roles, err := m.roleRepo.GetByIDs(ctx, roleIDs)
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
roleCodes := make([]string, 0, len(roles))
|
||||
for _, role := range roles {
|
||||
roleCodes = append(roleCodes, role.Code)
|
||||
}
|
||||
|
||||
permissionIDs, err := m.rolePermissionRepo.GetPermissionIDsByRoleIDs(ctx, uniqueRoleIDs)
|
||||
if err != nil || len(permissionIDs) == 0 {
|
||||
entry := userPermEntry{roles: roleCodes, perms: []string{}}
|
||||
m.l1Cache.Set(cacheKey, entry, 30*time.Minute) // PERF-01 优化:增加缓存 TTL 减少 DB 查询
|
||||
return entry.roles, entry.perms
|
||||
}
|
||||
|
||||
permissions, err := m.permissionRepo.GetByIDs(ctx, permissionIDs)
|
||||
if err != nil {
|
||||
return roleCodes, nil
|
||||
}
|
||||
|
||||
permCodes := make([]string, 0, len(permissions))
|
||||
for _, permission := range permissions {
|
||||
permCodes = append(permCodes, permission.Code)
|
||||
}
|
||||
|
||||
m.l1Cache.Set(cacheKey, userPermEntry{roles: roleCodes, perms: permCodes}, 30*time.Minute) // PERF-01 优化:增加缓存 TTL 减少 DB 查询
|
||||
return roleCodes, permCodes
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) InvalidateUserPermCache(userID int64) {
|
||||
m.l1Cache.Delete(fmt.Sprintf("user_perms:%d", userID))
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) AddToBlacklist(jti string, ttl time.Duration) {
|
||||
if jti != "" && ttl > 0 {
|
||||
m.l1Cache.Set("jwt_blacklist:"+jti, true, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) isUserActive(ctx context.Context, userID int64) bool {
|
||||
if m.userRepo == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
user, err := m.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return user.Status == domain.UserStatusActive
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) extractToken(c *gin.Context) string {
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
parts := strings.SplitN(authHeader, " ", 2)
|
||||
if len(parts) != 2 || parts[0] != "Bearer" {
|
||||
return ""
|
||||
}
|
||||
|
||||
return parts[1]
|
||||
}
|
||||
|
||||
type userPermEntry struct {
|
||||
roles []string
|
||||
perms []string
|
||||
}
|
||||
32
internal/api/middleware/cache_control.go
Normal file
32
internal/api/middleware/cache_control.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const sensitiveNoStoreCacheControl = "no-store, no-cache, must-revalidate, max-age=0"
|
||||
|
||||
// NoStoreSensitiveResponses prevents browser or intermediary caching for auth routes.
|
||||
func NoStoreSensitiveResponses() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if shouldDisableCaching(c.FullPath(), c.Request.URL.Path) {
|
||||
headers := c.Writer.Header()
|
||||
headers.Set("Cache-Control", sensitiveNoStoreCacheControl)
|
||||
headers.Set("Pragma", "no-cache")
|
||||
headers.Set("Expires", "0")
|
||||
headers.Set("Surrogate-Control", "no-store")
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func shouldDisableCaching(routePath, requestPath string) bool {
|
||||
path := strings.TrimSpace(routePath)
|
||||
if path == "" {
|
||||
path = strings.TrimSpace(requestPath)
|
||||
}
|
||||
return strings.HasPrefix(path, "/api/v1/auth")
|
||||
}
|
||||
67
internal/api/middleware/cors.go
Normal file
67
internal/api/middleware/cors.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/user-management-system/internal/config"
|
||||
)
|
||||
|
||||
var corsConfig = config.CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowCredentials: true,
|
||||
}
|
||||
|
||||
func SetCORSConfig(cfg config.CORSConfig) {
|
||||
corsConfig = cfg
|
||||
}
|
||||
|
||||
func CORS() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
cfg := corsConfig
|
||||
|
||||
origin := c.GetHeader("Origin")
|
||||
if origin != "" {
|
||||
allowOrigin, allowed := resolveAllowedOrigin(origin, cfg.AllowedOrigins, cfg.AllowCredentials)
|
||||
if !allowed {
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
c.AbortWithStatus(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
c.AbortWithStatus(http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
c.Writer.Header().Set("Access-Control-Allow-Origin", allowOrigin)
|
||||
if cfg.AllowCredentials {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Credentials", "true")
|
||||
}
|
||||
}
|
||||
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
c.Writer.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Writer.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With, X-CSRF-Token")
|
||||
c.Writer.Header().Set("Access-Control-Max-Age", "3600")
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func resolveAllowedOrigin(origin string, allowedOrigins []string, allowCredentials bool) (string, bool) {
|
||||
for _, allowed := range allowedOrigins {
|
||||
if allowed == "*" {
|
||||
if allowCredentials {
|
||||
return origin, true
|
||||
}
|
||||
return "*", true
|
||||
}
|
||||
if strings.EqualFold(origin, allowed) {
|
||||
return origin, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
43
internal/api/middleware/error.go
Normal file
43
internal/api/middleware/error.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
apierrors "github.com/user-management-system/internal/pkg/errors"
|
||||
)
|
||||
|
||||
// ErrorHandler 错误处理中间件
|
||||
func ErrorHandler() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
|
||||
// 检查是否有错误
|
||||
if len(c.Errors) > 0 {
|
||||
// 获取最后一个错误
|
||||
err := c.Errors.Last()
|
||||
|
||||
// 判断错误类型
|
||||
if appErr, ok := err.Err.(*apierrors.ApplicationError); ok {
|
||||
c.JSON(int(appErr.Code), appErr)
|
||||
} else {
|
||||
c.JSON(http.StatusInternalServerError, apierrors.New(http.StatusInternalServerError, "", err.Err.Error()))
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Recover 恢复中间件
|
||||
func Recover() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, apierrors.New(http.StatusInternalServerError, "", "服务器内部错误"))
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
134
internal/api/middleware/ip_filter.go
Normal file
134
internal/api/middleware/ip_filter.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/user-management-system/internal/security"
|
||||
)
|
||||
|
||||
// IPFilterConfig IP过滤中间件配置
|
||||
type IPFilterConfig struct {
|
||||
TrustProxy bool // 是否信任 X-Forwarded-For
|
||||
TrustedProxies []string // 可信代理 IP 列表
|
||||
}
|
||||
|
||||
// IPFilterMiddleware IP 黑白名单过滤中间件
|
||||
type IPFilterMiddleware struct {
|
||||
filter *security.IPFilter
|
||||
config IPFilterConfig
|
||||
}
|
||||
|
||||
// NewIPFilterMiddleware 创建 IP 过滤中间件
|
||||
func NewIPFilterMiddleware(filter *security.IPFilter, config IPFilterConfig) *IPFilterMiddleware {
|
||||
return &IPFilterMiddleware{filter: filter, config: config}
|
||||
}
|
||||
|
||||
// Filter 返回 Gin 中间件 HandlerFunc
|
||||
// 逻辑:先取客户端真实 IP → 检查黑名单 → 被封则返回 403 并终止
|
||||
func (m *IPFilterMiddleware) Filter() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
ip := m.realIP(c)
|
||||
|
||||
blocked, reason := m.filter.IsBlocked(ip)
|
||||
if blocked {
|
||||
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
|
||||
"code": 403,
|
||||
"message": "访问被拒绝:" + reason,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// 将真实 IP 写入 context,供后续中间件和 handler 直接取用
|
||||
c.Set("client_ip", ip)
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// GetFilter 返回底层 IPFilter,供 handler 层做黑白名单管理
|
||||
func (m *IPFilterMiddleware) GetFilter() *security.IPFilter {
|
||||
return m.filter
|
||||
}
|
||||
|
||||
// realIP 从请求中提取真实客户端 IP
|
||||
// 优先级:X-Forwarded-For > X-Real-IP > RemoteAddr
|
||||
// SEC-05 修复:如果启用 TrustProxy,只接受来自可信代理的 X-Forwarded-For
|
||||
func (m *IPFilterMiddleware) realIP(c *gin.Context) string {
|
||||
// 如果不信任代理,直接使用 TCP 连接 IP
|
||||
if !m.config.TrustProxy {
|
||||
return c.ClientIP()
|
||||
}
|
||||
|
||||
// X-Forwarded-For 可能包含代理链
|
||||
xff := c.GetHeader("X-Forwarded-For")
|
||||
if xff != "" {
|
||||
// 从右到左遍历(最右边是最后一次代理添加的)
|
||||
for _, part := range strings.Split(xff, ",") {
|
||||
ip := strings.TrimSpace(part)
|
||||
if ip == "" {
|
||||
continue
|
||||
}
|
||||
// 检查是否是可信代理
|
||||
if !m.isTrustedProxy(ip) {
|
||||
continue // 不是可信代理,跳过
|
||||
}
|
||||
// 是可信代理,检查是否为公网 IP
|
||||
if !isPrivateIP(ip) {
|
||||
return ip
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// X-Real-IP(Nginx 反代常用)
|
||||
if xri := c.GetHeader("X-Real-IP"); xri != "" {
|
||||
return strings.TrimSpace(xri)
|
||||
}
|
||||
|
||||
// 直接 TCP 连接的 RemoteAddr(去掉端口号)
|
||||
ip, _, err := net.SplitHostPort(c.Request.RemoteAddr)
|
||||
if err != nil {
|
||||
return c.Request.RemoteAddr
|
||||
}
|
||||
return ip
|
||||
}
|
||||
|
||||
// isTrustedProxy 检查 IP 是否在可信代理列表中
|
||||
func (m *IPFilterMiddleware) isTrustedProxy(ip string) bool {
|
||||
if len(m.config.TrustedProxies) == 0 {
|
||||
return true // 如果没有配置可信代理列表,默认信任所有(兼容旧行为)
|
||||
}
|
||||
for _, trusted := range m.config.TrustedProxies {
|
||||
if ip == trusted {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// isPrivateIP 判断是否为内网 IP
|
||||
func isPrivateIP(ipStr string) bool {
|
||||
ip := net.ParseIP(ipStr)
|
||||
if ip == nil {
|
||||
return false
|
||||
}
|
||||
privateRanges := []string{
|
||||
"10.0.0.0/8",
|
||||
"172.16.0.0/12",
|
||||
"192.168.0.0/16",
|
||||
"127.0.0.0/8",
|
||||
"::1/128",
|
||||
"fc00::/7",
|
||||
}
|
||||
for _, cidr := range privateRanges {
|
||||
_, network, err := net.ParseCIDR(cidr)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if network.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
258
internal/api/middleware/ip_filter_test.go
Normal file
258
internal/api/middleware/ip_filter_test.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/user-management-system/internal/security"
|
||||
)
|
||||
|
||||
func init() {
|
||||
gin.SetMode(gin.TestMode)
|
||||
}
|
||||
|
||||
// newTestEngine 用给定的 IPFilterMiddleware 构建一个最简 Gin 引擎,
|
||||
// 注册一个 GET /ping 路由,返回 client_ip 值。
|
||||
func newTestEngine(f *security.IPFilter) *gin.Engine {
|
||||
engine := gin.New()
|
||||
engine.Use(NewIPFilterMiddleware(f, IPFilterConfig{}).Filter())
|
||||
engine.GET("/ping", func(c *gin.Context) {
|
||||
ip, _ := c.Get("client_ip")
|
||||
c.JSON(http.StatusOK, gin.H{"ip": ip})
|
||||
})
|
||||
return engine
|
||||
}
|
||||
|
||||
// doRequest 发送 GET /ping,返回响应码和响应 body map。
|
||||
func doRequest(engine *gin.Engine, remoteAddr, xff, xri string) (int, map[string]interface{}) {
|
||||
req := httptest.NewRequest(http.MethodGet, "/ping", nil)
|
||||
req.RemoteAddr = remoteAddr
|
||||
if xff != "" {
|
||||
req.Header.Set("X-Forwarded-For", xff)
|
||||
}
|
||||
if xri != "" {
|
||||
req.Header.Set("X-Real-IP", xri)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
engine.ServeHTTP(w, req)
|
||||
|
||||
var body map[string]interface{}
|
||||
_ = json.Unmarshal(w.Body.Bytes(), &body)
|
||||
return w.Code, body
|
||||
}
|
||||
|
||||
// ---------- 黑名单拦截 ----------
|
||||
|
||||
func TestIPFilter_BlockedIP_Returns403(t *testing.T) {
|
||||
f := security.NewIPFilter()
|
||||
_ = f.AddToBlacklist("1.2.3.4", "测试封禁", 0)
|
||||
|
||||
engine := newTestEngine(f)
|
||||
code, body := doRequest(engine, "1.2.3.4:9999", "", "")
|
||||
|
||||
if code != http.StatusForbidden {
|
||||
t.Fatalf("期望 403,实际 %d", code)
|
||||
}
|
||||
msg, _ := body["message"].(string)
|
||||
if msg == "" {
|
||||
t.Error("期望 body 中包含 message 字段")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPFilter_NonBlockedIP_Returns200(t *testing.T) {
|
||||
f := security.NewIPFilter()
|
||||
_ = f.AddToBlacklist("9.9.9.9", "其他 IP", 0)
|
||||
|
||||
engine := newTestEngine(f)
|
||||
code, _ := doRequest(engine, "1.2.3.4:9999", "", "")
|
||||
|
||||
if code != http.StatusOK {
|
||||
t.Fatalf("期望 200,实际 %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPFilter_EmptyBlacklist_AllPass(t *testing.T) {
|
||||
f := security.NewIPFilter()
|
||||
engine := newTestEngine(f)
|
||||
|
||||
for _, ip := range []string{"1.1.1.1:80", "8.8.8.8:443", "203.0.113.5:1234"} {
|
||||
code, _ := doRequest(engine, ip, "", "")
|
||||
if code != http.StatusOK {
|
||||
t.Errorf("IP %s 应通过,实际 %d", ip, code)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- 白名单豁免 ----------
|
||||
|
||||
func TestIPFilter_WhitelistOverridesBlacklist(t *testing.T) {
|
||||
f := security.NewIPFilter()
|
||||
_ = f.AddToBlacklist("5.5.5.5", "封禁测试", 0)
|
||||
_ = f.AddToWhitelist("5.5.5.5", "白名单豁免")
|
||||
|
||||
engine := newTestEngine(f)
|
||||
// 白名单优先,应通过
|
||||
code, _ := doRequest(engine, "5.5.5.5:8080", "", "")
|
||||
if code != http.StatusOK {
|
||||
t.Fatalf("白名单 IP 应返回 200,实际 %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- CIDR 黑名单 ----------
|
||||
|
||||
func TestIPFilter_CIDRBlacklist(t *testing.T) {
|
||||
f := security.NewIPFilter()
|
||||
_ = f.AddToBlacklist("10.10.10.0/24", "封禁整段 CIDR", 0)
|
||||
|
||||
engine := newTestEngine(f)
|
||||
|
||||
// 在 CIDR 范围内,应被封
|
||||
code, _ := doRequest(engine, "10.10.10.55:1234", "", "")
|
||||
if code != http.StatusForbidden {
|
||||
t.Fatalf("CIDR 内 IP 应返回 403,实际 %d", code)
|
||||
}
|
||||
|
||||
// 不在 CIDR 范围内,应通过
|
||||
code2, _ := doRequest(engine, "10.10.11.1:1234", "", "")
|
||||
if code2 != http.StatusOK {
|
||||
t.Fatalf("CIDR 外 IP 应返回 200,实际 %d", code2)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- 过期规则 ----------
|
||||
|
||||
func TestIPFilter_ExpiredRule_Passes(t *testing.T) {
|
||||
f := security.NewIPFilter()
|
||||
// 封禁 1 纳秒,几乎立即过期
|
||||
_ = f.AddToBlacklist("7.7.7.7", "即将过期", time.Nanosecond)
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
|
||||
engine := newTestEngine(f)
|
||||
code, _ := doRequest(engine, "7.7.7.7:80", "", "")
|
||||
if code != http.StatusOK {
|
||||
t.Fatalf("过期规则不应拦截,期望 200,实际 %d", code)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- client_ip 注入 ----------
|
||||
|
||||
func TestIPFilter_ClientIPSetInContext(t *testing.T) {
|
||||
f := security.NewIPFilter()
|
||||
engine := newTestEngine(f)
|
||||
|
||||
code, body := doRequest(engine, "203.0.113.1:9000", "", "")
|
||||
if code != http.StatusOK {
|
||||
t.Fatalf("期望 200,实际 %d", code)
|
||||
}
|
||||
ip, _ := body["ip"].(string)
|
||||
if ip != "203.0.113.1" {
|
||||
t.Errorf("期望 client_ip=203.0.113.1,实际 %q", ip)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- realIP 提取逻辑 ----------
|
||||
|
||||
// TestRealIP_XForwardedFor_PublicIP 公网 X-Forwarded-For 取第一个非内网 IP
|
||||
func TestRealIP_XForwardedFor_PublicIP(t *testing.T) {
|
||||
f := security.NewIPFilter()
|
||||
engine := newTestEngine(f)
|
||||
|
||||
// X-Forwarded-For: 203.0.113.10, 192.168.1.1(代理内网)
|
||||
code, body := doRequest(engine, "192.168.1.1:80", "203.0.113.10, 192.168.1.1", "")
|
||||
if code != http.StatusOK {
|
||||
t.Fatalf("期望 200,实际 %d", code)
|
||||
}
|
||||
ip, _ := body["ip"].(string)
|
||||
if ip != "203.0.113.10" {
|
||||
t.Errorf("期望从 X-Forwarded-For 取公网 IP,实际 %q", ip)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRealIP_XForwardedFor_AllPrivate 全内网则取第一个
|
||||
func TestRealIP_XForwardedFor_AllPrivate(t *testing.T) {
|
||||
f := security.NewIPFilter()
|
||||
engine := newTestEngine(f)
|
||||
|
||||
code, body := doRequest(engine, "10.0.0.2:80", "192.168.0.5, 10.0.0.1", "")
|
||||
if code != http.StatusOK {
|
||||
t.Fatalf("期望 200,实际 %d", code)
|
||||
}
|
||||
ip, _ := body["ip"].(string)
|
||||
if ip != "192.168.0.5" {
|
||||
t.Errorf("全内网时应取第一个,实际 %q", ip)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRealIP_XRealIP_Fallback X-Forwarded-For 缺失时使用 X-Real-IP
|
||||
func TestRealIP_XRealIP_Fallback(t *testing.T) {
|
||||
f := security.NewIPFilter()
|
||||
engine := newTestEngine(f)
|
||||
|
||||
code, body := doRequest(engine, "192.168.1.1:80", "", "203.0.113.20")
|
||||
if code != http.StatusOK {
|
||||
t.Fatalf("期望 200,实际 %d", code)
|
||||
}
|
||||
ip, _ := body["ip"].(string)
|
||||
if ip != "203.0.113.20" {
|
||||
t.Errorf("期望 X-Real-IP 回退,实际 %q", ip)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRealIP_RemoteAddr_Fallback 都无 header 时用 RemoteAddr
|
||||
func TestRealIP_RemoteAddr_Fallback(t *testing.T) {
|
||||
f := security.NewIPFilter()
|
||||
engine := newTestEngine(f)
|
||||
|
||||
code, body := doRequest(engine, "203.0.113.99:12345", "", "")
|
||||
if code != http.StatusOK {
|
||||
t.Fatalf("期望 200,实际 %d", code)
|
||||
}
|
||||
ip, _ := body["ip"].(string)
|
||||
if ip != "203.0.113.99" {
|
||||
t.Errorf("期望 RemoteAddr 回退,实际 %q", ip)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- GetFilter ----------
|
||||
|
||||
func TestIPFilterMiddleware_GetFilter(t *testing.T) {
|
||||
f := security.NewIPFilter()
|
||||
mw := NewIPFilterMiddleware(f, IPFilterConfig{})
|
||||
if mw.GetFilter() != f {
|
||||
t.Error("GetFilter 应返回同一个 IPFilter 实例")
|
||||
}
|
||||
}
|
||||
|
||||
// ---------- 并发安全 ----------
|
||||
|
||||
func TestIPFilter_ConcurrentRequests(t *testing.T) {
|
||||
f := security.NewIPFilter()
|
||||
_ = f.AddToBlacklist("66.66.66.66", "并发测试封禁", 0)
|
||||
engine := newTestEngine(f)
|
||||
|
||||
done := make(chan struct{}, 20)
|
||||
for i := 0; i < 20; i++ {
|
||||
go func(i int) {
|
||||
defer func() { done <- struct{}{} }()
|
||||
var remoteAddr string
|
||||
if i%2 == 0 {
|
||||
remoteAddr = "66.66.66.66:9000"
|
||||
} else {
|
||||
remoteAddr = "1.2.3.4:9000"
|
||||
}
|
||||
code, _ := doRequest(engine, remoteAddr, "", "")
|
||||
if i%2 == 0 && code != http.StatusForbidden {
|
||||
t.Errorf("并发:封禁 IP 应返回 403,实际 %d", code)
|
||||
} else if i%2 != 0 && code != http.StatusOK {
|
||||
t.Errorf("并发:正常 IP 应返回 200,实际 %d", code)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
for i := 0; i < 20; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
83
internal/api/middleware/logger.go
Normal file
83
internal/api/middleware/logger.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var sensitiveQueryKeys = map[string]struct{}{
|
||||
"token": {},
|
||||
"access_token": {},
|
||||
"refresh_token": {},
|
||||
"code": {},
|
||||
"secret": {},
|
||||
}
|
||||
|
||||
func Logger() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
start := time.Now()
|
||||
path := c.Request.URL.Path
|
||||
raw := sanitizeQuery(c.Request.URL.RawQuery)
|
||||
|
||||
c.Next()
|
||||
|
||||
latency := time.Since(start)
|
||||
status := c.Writer.Status()
|
||||
method := c.Request.Method
|
||||
ip := c.ClientIP()
|
||||
userAgent := c.Request.UserAgent()
|
||||
userID, _ := c.Get("user_id")
|
||||
|
||||
log.Printf("[API] %s %s %s | status: %d | latency: %v | ip: %s | user_id: %v | ua: %s",
|
||||
time.Now().Format("2006-01-02 15:04:05"),
|
||||
method,
|
||||
path,
|
||||
status,
|
||||
latency,
|
||||
ip,
|
||||
userID,
|
||||
userAgent,
|
||||
)
|
||||
|
||||
if len(c.Errors) > 0 {
|
||||
for _, err := range c.Errors {
|
||||
log.Printf("[Error] %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if raw != "" {
|
||||
log.Printf("[Query] %s?%s", path, raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeQuery(raw string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
values, err := url.ParseQuery(raw)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
for key := range values {
|
||||
if isSensitiveQueryKey(key) {
|
||||
values.Set(key, "***")
|
||||
}
|
||||
}
|
||||
|
||||
return values.Encode()
|
||||
}
|
||||
|
||||
func isSensitiveQueryKey(key string) bool {
|
||||
normalized := strings.ToLower(strings.TrimSpace(key))
|
||||
if _, ok := sensitiveQueryKeys[normalized]; ok {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(normalized, "token") || strings.Contains(normalized, "secret")
|
||||
}
|
||||
125
internal/api/middleware/operation_log.go
Normal file
125
internal/api/middleware/operation_log.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
)
|
||||
|
||||
type OperationLogMiddleware struct {
|
||||
repo *repository.OperationLogRepository
|
||||
}
|
||||
|
||||
func NewOperationLogMiddleware(repo *repository.OperationLogRepository) *OperationLogMiddleware {
|
||||
return &OperationLogMiddleware{repo: repo}
|
||||
}
|
||||
|
||||
type bodyWriter struct {
|
||||
gin.ResponseWriter
|
||||
statusCode int
|
||||
}
|
||||
|
||||
func newBodyWriter(w gin.ResponseWriter) *bodyWriter {
|
||||
return &bodyWriter{ResponseWriter: w, statusCode: 200}
|
||||
}
|
||||
|
||||
func (bw *bodyWriter) WriteHeader(code int) {
|
||||
bw.statusCode = code
|
||||
bw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (bw *bodyWriter) WriteHeaderNow() {
|
||||
bw.ResponseWriter.WriteHeaderNow()
|
||||
}
|
||||
|
||||
func (m *OperationLogMiddleware) Record() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
method := c.Request.Method
|
||||
if method == "GET" || method == "HEAD" || method == "OPTIONS" {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
var reqParams string
|
||||
if c.Request.Body != nil {
|
||||
bodyBytes, err := io.ReadAll(io.LimitReader(c.Request.Body, 4096))
|
||||
if err == nil {
|
||||
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||
reqParams = sanitizeParams(bodyBytes)
|
||||
}
|
||||
}
|
||||
|
||||
bw := newBodyWriter(c.Writer)
|
||||
c.Writer = bw
|
||||
|
||||
c.Next()
|
||||
|
||||
var userIDPtr *int64
|
||||
if uid, exists := c.Get("user_id"); exists {
|
||||
if id, ok := uid.(int64); ok {
|
||||
userID := id
|
||||
userIDPtr = &userID
|
||||
}
|
||||
}
|
||||
|
||||
logEntry := &domain.OperationLog{
|
||||
UserID: userIDPtr,
|
||||
OperationType: methodToType(method),
|
||||
OperationName: c.FullPath(),
|
||||
RequestMethod: method,
|
||||
RequestPath: c.Request.URL.Path,
|
||||
RequestParams: reqParams,
|
||||
ResponseStatus: bw.statusCode,
|
||||
IP: c.ClientIP(),
|
||||
UserAgent: c.Request.UserAgent(),
|
||||
}
|
||||
|
||||
go func(entry *domain.OperationLog) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
_ = m.repo.Create(ctx, entry)
|
||||
}(logEntry)
|
||||
}
|
||||
}
|
||||
|
||||
func methodToType(method string) string {
|
||||
switch method {
|
||||
case "POST":
|
||||
return "CREATE"
|
||||
case "PUT", "PATCH":
|
||||
return "UPDATE"
|
||||
case "DELETE":
|
||||
return "DELETE"
|
||||
default:
|
||||
return "OTHER"
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeParams(data []byte) string {
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(data, &payload); err != nil {
|
||||
if len(data) > 500 {
|
||||
return string(data[:500]) + "..."
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
for _, field := range []string{"password", "old_password", "new_password", "confirm_password", "secret", "token"} {
|
||||
if _, ok := payload[field]; ok {
|
||||
payload[field] = "***"
|
||||
}
|
||||
}
|
||||
|
||||
result, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(result)
|
||||
}
|
||||
127
internal/api/middleware/ratelimit.go
Normal file
127
internal/api/middleware/ratelimit.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/user-management-system/internal/config"
|
||||
)
|
||||
|
||||
// RateLimitMiddleware 限流中间件
|
||||
type RateLimitMiddleware struct {
|
||||
cfg config.RateLimitConfig
|
||||
limiters map[string]*SlidingWindowLimiter
|
||||
mu sync.RWMutex
|
||||
cleanupInt time.Duration
|
||||
}
|
||||
|
||||
// SlidingWindowLimiter 滑动窗口限流器
|
||||
type SlidingWindowLimiter struct {
|
||||
mu sync.Mutex
|
||||
window time.Duration
|
||||
capacity int64
|
||||
requests []int64
|
||||
}
|
||||
|
||||
// NewSlidingWindowLimiter 创建滑动窗口限流器
|
||||
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
|
||||
return &SlidingWindowLimiter{
|
||||
window: window,
|
||||
capacity: capacity,
|
||||
requests: make([]int64, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Allow 检查是否允许请求
|
||||
func (l *SlidingWindowLimiter) Allow() bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
|
||||
now := time.Now().UnixMilli()
|
||||
cutoff := now - l.window.Milliseconds()
|
||||
|
||||
// 清理过期请求
|
||||
var validRequests []int64
|
||||
for _, t := range l.requests {
|
||||
if t > cutoff {
|
||||
validRequests = append(validRequests, t)
|
||||
}
|
||||
}
|
||||
l.requests = validRequests
|
||||
|
||||
// 检查容量
|
||||
if int64(len(l.requests)) >= l.capacity {
|
||||
return false
|
||||
}
|
||||
|
||||
l.requests = append(l.requests, now)
|
||||
return true
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
cfg: cfg,
|
||||
limiters: make(map[string]*SlidingWindowLimiter),
|
||||
cleanupInt: 5 * time.Minute,
|
||||
}
|
||||
}
|
||||
|
||||
// Register 返回注册接口的限流中间件
|
||||
func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
|
||||
return m.limitForKey("register", 60, 10)
|
||||
}
|
||||
|
||||
// Login 返回登录接口的限流中间件
|
||||
func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
|
||||
return m.limitForKey("login", 60, 5)
|
||||
}
|
||||
|
||||
// API 返回 API 接口的限流中间件
|
||||
func (m *RateLimitMiddleware) API() gin.HandlerFunc {
|
||||
return m.limitForKey("api", 60, 100)
|
||||
}
|
||||
|
||||
// Refresh 返回刷新令牌的限流中间件
|
||||
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
|
||||
return m.limitForKey("refresh", 60, 10)
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc {
|
||||
limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
if !limiter.Allow() {
|
||||
c.JSON(429, gin.H{
|
||||
"code": 429,
|
||||
"message": "请求过于频繁,请稍后再试",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
|
||||
m.mu.RLock()
|
||||
limiter, exists := m.limiters[key]
|
||||
m.mu.RUnlock()
|
||||
|
||||
if exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// 双重检查
|
||||
if limiter, exists = m.limiters[key]; exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
limiter = NewSlidingWindowLimiter(window, capacity)
|
||||
m.limiters[key] = limiter
|
||||
return limiter
|
||||
}
|
||||
156
internal/api/middleware/rbac.go
Normal file
156
internal/api/middleware/rbac.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// contextKey 上下文键常量
|
||||
const (
|
||||
ContextKeyRoleCodes = "role_codes"
|
||||
ContextKeyPermissionCodes = "permission_codes"
|
||||
)
|
||||
|
||||
// RequirePermission 要求用户拥有指定权限之一(OR 逻辑)
|
||||
// 适用于需要单个或多选权限校验的路由
|
||||
func RequirePermission(codes ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !hasAnyPermission(c, codes) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"code": 403,
|
||||
"message": "权限不足",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAllPermissions 要求用户拥有所有指定权限(AND 逻辑)
|
||||
func RequireAllPermissions(codes ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !hasAllPermissions(c, codes) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"code": 403,
|
||||
"message": "权限不足,需要所有指定权限",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireRole 要求用户拥有指定角色之一(OR 逻辑)
|
||||
func RequireRole(codes ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if !hasAnyRole(c, codes) {
|
||||
c.JSON(http.StatusForbidden, gin.H{
|
||||
"code": 403,
|
||||
"message": "权限不足,角色受限",
|
||||
})
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
// RequireAnyPermission RequirePermission 的别名,语义更清晰
|
||||
func RequireAnyPermission(codes ...string) gin.HandlerFunc {
|
||||
return RequirePermission(codes...)
|
||||
}
|
||||
|
||||
// AdminOnly 仅限 admin 角色
|
||||
func AdminOnly() gin.HandlerFunc {
|
||||
return RequireRole("admin")
|
||||
}
|
||||
|
||||
// GetRoleCodes 从 Context 获取当前用户角色代码列表
|
||||
func GetRoleCodes(c *gin.Context) []string {
|
||||
val, exists := c.Get(ContextKeyRoleCodes)
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
if codes, ok := val.([]string); ok {
|
||||
return codes
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPermissionCodes 从 Context 获取当前用户权限代码列表
|
||||
func GetPermissionCodes(c *gin.Context) []string {
|
||||
val, exists := c.Get(ContextKeyPermissionCodes)
|
||||
if !exists {
|
||||
return nil
|
||||
}
|
||||
if codes, ok := val.([]string); ok {
|
||||
return codes
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsAdmin 判断当前用户是否为 admin
|
||||
func IsAdmin(c *gin.Context) bool {
|
||||
return hasAnyRole(c, []string{"admin"})
|
||||
}
|
||||
|
||||
// hasAnyPermission 判断用户是否拥有任意一个权限
|
||||
func hasAnyPermission(c *gin.Context, codes []string) bool {
|
||||
// admin 角色拥有所有权限
|
||||
if IsAdmin(c) {
|
||||
return true
|
||||
}
|
||||
permCodes := GetPermissionCodes(c)
|
||||
if len(permCodes) == 0 {
|
||||
return false
|
||||
}
|
||||
permSet := toSet(permCodes)
|
||||
for _, code := range codes {
|
||||
if _, ok := permSet[code]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasAllPermissions 判断用户是否拥有所有权限
|
||||
func hasAllPermissions(c *gin.Context, codes []string) bool {
|
||||
if IsAdmin(c) {
|
||||
return true
|
||||
}
|
||||
permCodes := GetPermissionCodes(c)
|
||||
permSet := toSet(permCodes)
|
||||
for _, code := range codes {
|
||||
if _, ok := permSet[code]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// hasAnyRole 判断用户是否拥有任意一个角色
|
||||
func hasAnyRole(c *gin.Context, codes []string) bool {
|
||||
roleCodes := GetRoleCodes(c)
|
||||
if len(roleCodes) == 0 {
|
||||
return false
|
||||
}
|
||||
roleSet := toSet(roleCodes)
|
||||
for _, code := range codes {
|
||||
if _, ok := roleSet[code]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// toSet 将字符串切片转换为 map 集合
|
||||
func toSet(items []string) map[string]struct{} {
|
||||
s := make(map[string]struct{}, len(items))
|
||||
for _, item := range items {
|
||||
s[item] = struct{}{}
|
||||
}
|
||||
return s
|
||||
}
|
||||
139
internal/api/middleware/runtime_test.go
Normal file
139
internal/api/middleware/runtime_test.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/user-management-system/internal/config"
|
||||
)
|
||||
|
||||
func TestCORS_UsesConfiguredOrigins(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
SetCORSConfig(config.CORSConfig{
|
||||
AllowedOrigins: []string{"https://app.example.com"},
|
||||
AllowCredentials: true,
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
SetCORSConfig(config.CORSConfig{
|
||||
AllowedOrigins: []string{"*"},
|
||||
AllowCredentials: true,
|
||||
})
|
||||
})
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodOptions, "/api/v1/users", nil)
|
||||
c.Request.Header.Set("Origin", "https://app.example.com")
|
||||
c.Request.Header.Set("Access-Control-Request-Headers", "Authorization")
|
||||
|
||||
CORS()(c)
|
||||
|
||||
if recorder.Code != http.StatusNoContent {
|
||||
t.Fatalf("expected 204, got %d", recorder.Code)
|
||||
}
|
||||
if got := recorder.Header().Get("Access-Control-Allow-Origin"); got != "https://app.example.com" {
|
||||
t.Fatalf("unexpected allow origin: %s", got)
|
||||
}
|
||||
if got := recorder.Header().Get("Access-Control-Allow-Credentials"); got != "true" {
|
||||
t.Fatalf("expected credentials header to be 'true', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSanitizeQuery_MasksSensitiveValues(t *testing.T) {
|
||||
raw := "token=abc123&foo=bar&access_token=xyz&secret=s1"
|
||||
sanitized := sanitizeQuery(raw)
|
||||
|
||||
if sanitized == "" {
|
||||
t.Fatal("expected sanitized query")
|
||||
}
|
||||
if sanitized == raw {
|
||||
t.Fatal("expected query to be sanitized")
|
||||
}
|
||||
for _, value := range []string{"abc123", "xyz", "s1"} {
|
||||
if strings.Contains(sanitized, value) {
|
||||
t.Fatalf("expected sensitive value %q to be masked in %q", value, sanitized)
|
||||
}
|
||||
}
|
||||
if sanitizeQuery("") != "" {
|
||||
t.Fatal("expected empty query to stay empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_AttachesExpectedHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
||||
|
||||
SecurityHeaders()(c)
|
||||
|
||||
if got := recorder.Header().Get("X-Content-Type-Options"); got != "nosniff" {
|
||||
t.Fatalf("unexpected nosniff header: %q", got)
|
||||
}
|
||||
if got := recorder.Header().Get("X-Frame-Options"); got != "DENY" {
|
||||
t.Fatalf("unexpected frame options: %q", got)
|
||||
}
|
||||
if got := recorder.Header().Get("Content-Security-Policy"); got == "" {
|
||||
t.Fatal("expected content security policy header")
|
||||
}
|
||||
if got := recorder.Header().Get("Strict-Transport-Security"); got != "" {
|
||||
t.Fatalf("did not expect hsts header for http request, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSecurityHeaders_AttachesHSTSForForwardedHTTPS(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
||||
c.Request.Header.Set("X-Forwarded-Proto", "https")
|
||||
|
||||
SecurityHeaders()(c)
|
||||
|
||||
if got := recorder.Header().Get("Strict-Transport-Security"); !strings.Contains(got, "max-age=31536000") {
|
||||
t.Fatalf("expected hsts header, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoStoreSensitiveResponses_AttachesExpectedHeadersToAuthRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/auth/capabilities", nil)
|
||||
|
||||
NoStoreSensitiveResponses()(c)
|
||||
|
||||
if got := recorder.Header().Get("Cache-Control"); got != sensitiveNoStoreCacheControl {
|
||||
t.Fatalf("unexpected cache-control header: %q", got)
|
||||
}
|
||||
if got := recorder.Header().Get("Pragma"); got != "no-cache" {
|
||||
t.Fatalf("unexpected pragma header: %q", got)
|
||||
}
|
||||
if got := recorder.Header().Get("Expires"); got != "0" {
|
||||
t.Fatalf("unexpected expires header: %q", got)
|
||||
}
|
||||
if got := recorder.Header().Get("Surrogate-Control"); got != "no-store" {
|
||||
t.Fatalf("unexpected surrogate-control header: %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNoStoreSensitiveResponses_DoesNotAttachHeadersToNonAuthRoutes(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/api/v1/users", nil)
|
||||
|
||||
NoStoreSensitiveResponses()(c)
|
||||
|
||||
if got := recorder.Header().Get("Cache-Control"); got != "" {
|
||||
t.Fatalf("did not expect cache-control header, got %q", got)
|
||||
}
|
||||
}
|
||||
45
internal/api/middleware/security_headers.go
Normal file
45
internal/api/middleware/security_headers.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
const contentSecurityPolicy = "default-src 'none'; frame-ancestors 'none'; base-uri 'none'; form-action 'self'"
|
||||
|
||||
func SecurityHeaders() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
headers := c.Writer.Header()
|
||||
headers.Set("X-Content-Type-Options", "nosniff")
|
||||
headers.Set("X-Frame-Options", "DENY")
|
||||
headers.Set("Referrer-Policy", "strict-origin-when-cross-origin")
|
||||
headers.Set("Permissions-Policy", "camera=(), microphone=(), geolocation=()")
|
||||
headers.Set("Cross-Origin-Opener-Policy", "same-origin")
|
||||
headers.Set("X-Permitted-Cross-Domain-Policies", "none")
|
||||
|
||||
if shouldAttachCSP(c.FullPath(), c.Request.URL.Path) {
|
||||
headers.Set("Content-Security-Policy", contentSecurityPolicy)
|
||||
}
|
||||
if isHTTPSRequest(c) {
|
||||
headers.Set("Strict-Transport-Security", "max-age=31536000; includeSubDomains")
|
||||
}
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func shouldAttachCSP(routePath, requestPath string) bool {
|
||||
path := strings.TrimSpace(routePath)
|
||||
if path == "" {
|
||||
path = strings.TrimSpace(requestPath)
|
||||
}
|
||||
return !strings.HasPrefix(path, "/swagger/")
|
||||
}
|
||||
|
||||
func isHTTPSRequest(c *gin.Context) bool {
|
||||
if c.Request.TLS != nil {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")), "https")
|
||||
}
|
||||
Reference in New Issue
Block a user