package middleware

import (
	"net/http"
	"strings"
	"sync"

	"github.com/gin-gonic/gin"
	"github.com/user-management-system/internal/config"
)

// corsMu guards corsConfig. SetCORSConfig may be called while handlers
// returned by CORS are serving requests on other goroutines, and those
// handlers read the active policy on every request.
var corsMu sync.RWMutex

// corsConfig is the active CORS policy. The defaults are deliberately
// restrictive: no origins are allowed and credentials are disabled, so
// cross-origin access must be configured explicitly via SetCORSConfig.
var corsConfig = config.CORSConfig{
	AllowedOrigins:   []string{}, // empty by default; must be configured explicitly
	AllowCredentials: false,      // credentials off by default; must be enabled explicitly
}

// init rejects a dangerous default configuration at package load time: a "*"
// origin combined with AllowCredentials. With the shipped (empty) defaults this
// never fires; it guards against the defaults above being edited carelessly.
// Configurations supplied later through SetCORSConfig are not checked here.
func init() {
	if hasWildcardWithCredentials(corsConfig) {
		panic("CORS 配置错误: AllowedOrigins 包含 '*' 且 AllowCredentials 为 true 是危险组合")
	}
}

// hasWildcardWithCredentials reports whether cfg allows every origin ("*")
// while also allowing credentials.
func hasWildcardWithCredentials(cfg config.CORSConfig) bool {
	if !cfg.AllowCredentials {
		return false
	}
	for _, origin := range cfg.AllowedOrigins {
		if origin == "*" {
			return true
		}
	}
	return false
}

// SetCORSConfig replaces the active CORS policy. Because handlers returned by
// CORS read the policy on every request, the change also applies to handlers
// created before this call.
//
// Unlike the package defaults, a configuration combining "*" with
// AllowCredentials is accepted here without panicking; production deployments
// should avoid it. AllowedOrigins is copied so that later mutation of the
// caller's slice cannot affect (or race with) request handling.
func SetCORSConfig(cfg config.CORSConfig) {
	cfg.AllowedOrigins = append([]string(nil), cfg.AllowedOrigins...)

	corsMu.Lock()
	corsConfig = cfg
	corsMu.Unlock()
}

// currentCORSConfig returns a snapshot of the active CORS policy. The returned
// AllowedOrigins slice is never mutated in place (SetCORSConfig swaps in a
// fresh copy), so callers may read it without holding corsMu.
func currentCORSConfig() config.CORSConfig {
	corsMu.RLock()
	defer corsMu.RUnlock()
	return corsConfig
}

// CORS returns middleware that enforces the configured CORS policy.
//
// Requests carrying an Origin header that is not allowed are aborted with
// 403 Forbidden. For allowed origins, Access-Control-Allow-Origin is set, as
// is Access-Control-Allow-Credentials when credentials are enabled. Every
// OPTIONS request is answered directly with 204 No Content and fixed preflight
// headers; it never reaches downstream handlers.
//
// NOTE(review): browsers also send Origin on same-origin POST/PUT/DELETE
// requests, so the site's own origin presumably must be listed in
// AllowedOrigins too — confirm against deployment config.
func CORS() gin.HandlerFunc {
	return func(c *gin.Context) {
		cfg := currentCORSConfig()
		header := c.Writer.Header()

		// The response depends on the Origin request header (echoed back,
		// rejected, or absent), so shared caches must key on it. Add rather
		// than Set to preserve any Vary values set by other middleware.
		header.Add("Vary", "Origin")

		if origin := c.GetHeader("Origin"); origin != "" {
			allowOrigin, allowed := resolveAllowedOrigin(origin, cfg.AllowedOrigins, cfg.AllowCredentials)
			if !allowed {
				// Preflight and actual requests from a disallowed origin are
				// rejected the same way.
				c.AbortWithStatus(http.StatusForbidden)
				return
			}
			header.Set("Access-Control-Allow-Origin", allowOrigin)
			if cfg.AllowCredentials {
				header.Set("Access-Control-Allow-Credentials", "true")
			}
		}

		if c.Request.Method == http.MethodOptions {
			header.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
			header.Set("Access-Control-Allow-Headers", "Authorization, Content-Type, X-Requested-With, X-CSRF-Token")
			header.Set("Access-Control-Max-Age", "3600") // seconds a browser may cache the preflight result
			c.AbortWithStatus(http.StatusNoContent)
			return
		}

		c.Next()
	}
}

// resolveAllowedOrigin checks origin against allowedOrigins and returns the
// value to send in Access-Control-Allow-Origin together with whether the
// origin is allowed.
//
// Entries are examined in order and compared case-insensitively. A "*" entry
// allows any origin: it yields "*" when credentials are disabled, and echoes
// origin back when credentials are enabled (browsers refuse a literal "*"
// alongside Access-Control-Allow-Credentials: true).
//
// SECURITY: "*" combined with allowCredentials therefore lets any site make
// credentialed requests; avoid that combination outside development.
func resolveAllowedOrigin(origin string, allowedOrigins []string, allowCredentials bool) (string, bool) {
	for _, allowed := range allowedOrigins {
		if allowed == "*" {
			if allowCredentials {
				return origin, true
			}
			return "*", true
		}
		if strings.EqualFold(origin, allowed) {
			return origin, true
		}
	}
	return "", false
}