fix: fail closed on invalid cors config
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -11,22 +12,24 @@ import (
|
|||||||
|
|
||||||
var corsConfig = config.CORSConfig{
|
var corsConfig = config.CORSConfig{
|
||||||
AllowedOrigins: []string{}, // 默认为空,必须显式配置
|
AllowedOrigins: []string{}, // 默认为空,必须显式配置
|
||||||
AllowCredentials: false, // 默认关闭凭证,必须显式启用
|
AllowCredentials: false, // 默认关闭凭证,必须显式启用
|
||||||
}
|
}
|
||||||
|
|
||||||
// init 在包初始化时检测危险的 CORS 配置组合
|
func validateCORSConfig(cfg config.CORSConfig) error {
|
||||||
func init() {
|
for _, origin := range cfg.AllowedOrigins {
|
||||||
// 检测危险的通配符 + Credentials 组合
|
if origin == "*" && cfg.AllowCredentials {
|
||||||
for _, origin := range corsConfig.AllowedOrigins {
|
return errors.New("CORS 配置错误: AllowedOrigins 包含 '*' 时不能启用 AllowCredentials")
|
||||||
if origin == "*" && corsConfig.AllowCredentials {
|
|
||||||
panic("CORS 配置错误: AllowedOrigins 包含 '*' 且 AllowCredentials 为 true 是危险组合")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetCORSConfig(cfg config.CORSConfig) {
|
func SetCORSConfig(cfg config.CORSConfig) error {
|
||||||
// 注意:显式配置危险组合时不会panic,但生产环境应避免使用
|
if err := validateCORSConfig(cfg); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
corsConfig = cfg
|
corsConfig = cfg
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CORS() gin.HandlerFunc {
|
func CORS() gin.HandlerFunc {
|
||||||
|
|||||||
@@ -14,15 +14,16 @@ import (
|
|||||||
|
|
||||||
func TestCORS_UsesConfiguredOrigins(t *testing.T) {
|
func TestCORS_UsesConfiguredOrigins(t *testing.T) {
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
SetCORSConfig(config.CORSConfig{
|
if err := SetCORSConfig(config.CORSConfig{
|
||||||
AllowedOrigins: []string{"https://app.example.com"},
|
AllowedOrigins: []string{"https://app.example.com"},
|
||||||
AllowCredentials: true,
|
AllowCredentials: true,
|
||||||
})
|
}); err != nil {
|
||||||
|
t.Fatalf("SetCORSConfig should accept explicit origin with credentials: %v", err)
|
||||||
|
}
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
SetCORSConfig(config.CORSConfig{
|
if err := SetCORSConfig(config.CORSConfig{}); err != nil {
|
||||||
AllowedOrigins: []string{"*"},
|
t.Fatalf("reset cors config failed: %v", err)
|
||||||
AllowCredentials: true,
|
}
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
@@ -44,6 +45,33 @@ func TestCORS_UsesConfiguredOrigins(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetCORSConfig_RejectsWildcardWithCredentials(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
if err := SetCORSConfig(config.CORSConfig{}); err != nil {
|
||||||
|
t.Fatalf("failed to initialize baseline cors config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := SetCORSConfig(config.CORSConfig{
|
||||||
|
AllowedOrigins: []string{"*"},
|
||||||
|
AllowCredentials: true,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected wildcard+credentials cors config to be rejected")
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodOptions, "/api/v1/users", nil)
|
||||||
|
c.Request.Header.Set("Origin", "https://evil.example.com")
|
||||||
|
c.Request.Header.Set("Access-Control-Request-Headers", "Authorization")
|
||||||
|
|
||||||
|
CORS()(c)
|
||||||
|
|
||||||
|
if recorder.Code != http.StatusForbidden {
|
||||||
|
t.Fatalf("expected previous safe config to remain active and reject origin, got %d", recorder.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSanitizeQuery_MasksSensitiveValues(t *testing.T) {
|
func TestSanitizeQuery_MasksSensitiveValues(t *testing.T) {
|
||||||
raw := "token=abc123&foo=bar&access_token=xyz&secret=s1"
|
raw := "token=abc123&foo=bar&access_token=xyz&secret=s1"
|
||||||
sanitized := sanitizeQuery(raw)
|
sanitized := sanitizeQuery(raw)
|
||||||
|
|||||||
@@ -137,7 +137,9 @@ func Serve(cfg *config.Config) error {
|
|||||||
themeService := service.NewThemeService(themeRepo)
|
themeService := service.NewThemeService(themeRepo)
|
||||||
|
|
||||||
// 设置 CORS 配置
|
// 设置 CORS 配置
|
||||||
middleware.SetCORSConfig(cfg.CORS)
|
if err := middleware.SetCORSConfig(cfg.CORS); err != nil {
|
||||||
|
return fmt.Errorf("invalid cors config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// 初始化中间件
|
// 初始化中间件
|
||||||
rateLimitMiddleware := middleware.NewRateLimitMiddleware(cfg.RateLimit)
|
rateLimitMiddleware := middleware.NewRateLimitMiddleware(cfg.RateLimit)
|
||||||
|
|||||||
Reference in New Issue
Block a user