fix: fail closed on invalid cors config

This commit is contained in:
Your Name
2026-05-28 16:53:33 +08:00
parent 547fdab0b2
commit 7eb5f9c7d4
3 changed files with 49 additions and 16 deletions

View File

@@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"errors"
"net/http" "net/http"
"strings" "strings"
@@ -11,22 +12,24 @@ import (
var corsConfig = config.CORSConfig{ var corsConfig = config.CORSConfig{
AllowedOrigins: []string{}, // 默认为空,必须显式配置 AllowedOrigins: []string{}, // 默认为空,必须显式配置
AllowCredentials: false, // 默认关闭凭证,必须显式启用 AllowCredentials: false, // 默认关闭凭证,必须显式启用
} }
// init 在包初始化时检测危险的 CORS 配置组合 func validateCORSConfig(cfg config.CORSConfig) error {
func init() { for _, origin := range cfg.AllowedOrigins {
// 检测危险的通配符 + Credentials 组合 if origin == "*" && cfg.AllowCredentials {
for _, origin := range corsConfig.AllowedOrigins { return errors.New("CORS 配置错误: AllowedOrigins 包含 '*' 时不能启用 AllowCredentials")
if origin == "*" && corsConfig.AllowCredentials {
panic("CORS 配置错误: AllowedOrigins 包含 '*' 且 AllowCredentials 为 true 是危险组合")
} }
} }
return nil
} }
func SetCORSConfig(cfg config.CORSConfig) { func SetCORSConfig(cfg config.CORSConfig) error {
// 注意显式配置危险组合时不会panic但生产环境应避免使用 if err := validateCORSConfig(cfg); err != nil {
return err
}
corsConfig = cfg corsConfig = cfg
return nil
} }
func CORS() gin.HandlerFunc { func CORS() gin.HandlerFunc {

View File

@@ -14,15 +14,16 @@ import (
func TestCORS_UsesConfiguredOrigins(t *testing.T) { func TestCORS_UsesConfiguredOrigins(t *testing.T) {
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
SetCORSConfig(config.CORSConfig{ if err := SetCORSConfig(config.CORSConfig{
AllowedOrigins: []string{"https://app.example.com"}, AllowedOrigins: []string{"https://app.example.com"},
AllowCredentials: true, AllowCredentials: true,
}) }); err != nil {
t.Fatalf("SetCORSConfig should accept explicit origin with credentials: %v", err)
}
t.Cleanup(func() { t.Cleanup(func() {
SetCORSConfig(config.CORSConfig{ if err := SetCORSConfig(config.CORSConfig{}); err != nil {
AllowedOrigins: []string{"*"}, t.Fatalf("reset cors config failed: %v", err)
AllowCredentials: true, }
})
}) })
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -44,6 +45,33 @@ func TestCORS_UsesConfiguredOrigins(t *testing.T) {
} }
} }
func TestSetCORSConfig_RejectsWildcardWithCredentials(t *testing.T) {
gin.SetMode(gin.TestMode)
if err := SetCORSConfig(config.CORSConfig{}); err != nil {
t.Fatalf("failed to initialize baseline cors config: %v", err)
}
err := SetCORSConfig(config.CORSConfig{
AllowedOrigins: []string{"*"},
AllowCredentials: true,
})
if err == nil {
t.Fatal("expected wildcard+credentials cors config to be rejected")
}
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodOptions, "/api/v1/users", nil)
c.Request.Header.Set("Origin", "https://evil.example.com")
c.Request.Header.Set("Access-Control-Request-Headers", "Authorization")
CORS()(c)
if recorder.Code != http.StatusForbidden {
t.Fatalf("expected previous safe config to remain active and reject origin, got %d", recorder.Code)
}
}
func TestSanitizeQuery_MasksSensitiveValues(t *testing.T) { func TestSanitizeQuery_MasksSensitiveValues(t *testing.T) {
raw := "token=abc123&foo=bar&access_token=xyz&secret=s1" raw := "token=abc123&foo=bar&access_token=xyz&secret=s1"
sanitized := sanitizeQuery(raw) sanitized := sanitizeQuery(raw)

View File

@@ -137,7 +137,9 @@ func Serve(cfg *config.Config) error {
themeService := service.NewThemeService(themeRepo) themeService := service.NewThemeService(themeRepo)
// 设置 CORS 配置 // 设置 CORS 配置
middleware.SetCORSConfig(cfg.CORS) if err := middleware.SetCORSConfig(cfg.CORS); err != nil {
return fmt.Errorf("invalid cors config: %w", err)
}
// 初始化中间件 // 初始化中间件
rateLimitMiddleware := middleware.NewRateLimitMiddleware(cfg.RateLimit) rateLimitMiddleware := middleware.NewRateLimitMiddleware(cfg.RateLimit)