From 7eb5f9c7d46df5e99709db285b32a26e129a3a67 Mon Sep 17 00:00:00 2001 From: Your Name Date: Thu, 28 May 2026 16:53:33 +0800 Subject: [PATCH] fix: fail closed on invalid cors config --- internal/api/middleware/cors.go | 21 +++++++------ internal/api/middleware/runtime_test.go | 40 +++++++++++++++++++++---- internal/server/server.go | 4 ++- 3 files changed, 49 insertions(+), 16 deletions(-) diff --git a/internal/api/middleware/cors.go b/internal/api/middleware/cors.go index 328c192..4ca289e 100644 --- a/internal/api/middleware/cors.go +++ b/internal/api/middleware/cors.go @@ -1,6 +1,7 @@ package middleware import ( + "errors" "net/http" "strings" @@ -11,22 +12,24 @@ import ( var corsConfig = config.CORSConfig{ AllowedOrigins: []string{}, // 默认为空,必须显式配置 - AllowCredentials: false, // 默认关闭凭证,必须显式启用 + AllowCredentials: false, // 默认关闭凭证,必须显式启用 } -// init 在包初始化时检测危险的 CORS 配置组合 -func init() { - // 检测危险的通配符 + Credentials 组合 - for _, origin := range corsConfig.AllowedOrigins { - if origin == "*" && corsConfig.AllowCredentials { - panic("CORS 配置错误: AllowedOrigins 包含 '*' 且 AllowCredentials 为 true 是危险组合") +func validateCORSConfig(cfg config.CORSConfig) error { + for _, origin := range cfg.AllowedOrigins { + if origin == "*" && cfg.AllowCredentials { + return errors.New("CORS 配置错误: AllowedOrigins 包含 '*' 时不能启用 AllowCredentials") } } + return nil } -func SetCORSConfig(cfg config.CORSConfig) { - // 注意:显式配置危险组合时不会panic,但生产环境应避免使用 +func SetCORSConfig(cfg config.CORSConfig) error { + if err := validateCORSConfig(cfg); err != nil { + return err + } corsConfig = cfg + return nil } func CORS() gin.HandlerFunc { diff --git a/internal/api/middleware/runtime_test.go b/internal/api/middleware/runtime_test.go index 79e65bf..d8ac537 100644 --- a/internal/api/middleware/runtime_test.go +++ b/internal/api/middleware/runtime_test.go @@ -14,15 +14,16 @@ import ( func TestCORS_UsesConfiguredOrigins(t *testing.T) { gin.SetMode(gin.TestMode) - SetCORSConfig(config.CORSConfig{ + if err := SetCORSConfig(config.CORSConfig{ AllowedOrigins: []string{"https://app.example.com"}, AllowCredentials: true, - }) + }); err != nil { + t.Fatalf("SetCORSConfig should accept explicit origin with credentials: %v", err) + } t.Cleanup(func() { - SetCORSConfig(config.CORSConfig{ - AllowedOrigins: []string{"*"}, - AllowCredentials: true, - }) + if err := SetCORSConfig(config.CORSConfig{}); err != nil { + t.Fatalf("reset cors config failed: %v", err) + } }) recorder := httptest.NewRecorder() @@ -44,6 +45,33 @@ func TestCORS_UsesConfiguredOrigins(t *testing.T) { } } +func TestSetCORSConfig_RejectsWildcardWithCredentials(t *testing.T) { + gin.SetMode(gin.TestMode) + if err := SetCORSConfig(config.CORSConfig{}); err != nil { + t.Fatalf("failed to initialize baseline cors config: %v", err) + } + + err := SetCORSConfig(config.CORSConfig{ + AllowedOrigins: []string{"*"}, + AllowCredentials: true, + }) + if err == nil { + t.Fatal("expected wildcard+credentials cors config to be rejected") + } + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodOptions, "/api/v1/users", nil) + c.Request.Header.Set("Origin", "https://evil.example.com") + c.Request.Header.Set("Access-Control-Request-Headers", "Authorization") + + CORS()(c) + + if recorder.Code != http.StatusForbidden { + t.Fatalf("expected previous safe config to remain active and reject origin, got %d", recorder.Code) + } +} + func TestSanitizeQuery_MasksSensitiveValues(t *testing.T) { raw := "token=abc123&foo=bar&access_token=xyz&secret=s1" sanitized := sanitizeQuery(raw) diff --git a/internal/server/server.go b/internal/server/server.go index c586673..5d676ef 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -137,7 +137,9 @@ func Serve(cfg *config.Config) error { themeService := service.NewThemeService(themeRepo) // 设置 CORS 配置 - middleware.SetCORSConfig(cfg.CORS) + if err := middleware.SetCORSConfig(cfg.CORS); err != nil { + return fmt.Errorf("invalid cors config: %w", err) + } // 初始化中间件 rateLimitMiddleware := middleware.NewRateLimitMiddleware(cfg.RateLimit)