fix(gateway): fail closed on secret and cors defaults

This commit is contained in:
Your Name
2026-04-17 20:00:43 +08:00
parent cccb76b72b
commit 0b8de726a8
3 changed files with 129 additions and 24 deletions

View File

@@ -24,6 +24,9 @@ func BuildServer(cfg *config.Config) (*http.Server, error) {
if err := config.ValidateAuthConfig(normalized.Auth); err != nil {
return nil, err
}
if err := validateStartupSecurity(normalized); err != nil {
return nil, err
}
r, err := buildRouter(&normalized)
if err != nil {
@@ -236,28 +239,14 @@ func normalizeConfig(cfg config.Config) config.Config {
}
}
}
// P0-1: Fail startup in production if encryption key is not explicitly set
if strings.EqualFold(cfg.Auth.Env, "production") || strings.EqualFold(cfg.Auth.Env, "prod") || strings.EqualFold(cfg.Auth.Env, "online") {
if _, isDefault := checkEncryptionKeyIsDefault(); isDefault {
panic("FATAL: PASSWORD_ENCRYPTION_KEY environment variable must be explicitly set in production environment. Using the default key is not allowed.")
}
}
return cfg
}
// buildCORSConfig builds CORS config from normalized config
// In production (Env=production/prod/online), rejects wildcard if CORSAllowOrigins not explicitly set
func buildCORSConfig(cfg config.Config) middleware.CORSConfig {
corsOrigins := cfg.Auth.CORSAllowOrigins
if len(corsOrigins) == 0 {
corsOrigins = []string{"*"}
}
// P0-2: Warn in production if using wildcard
if strings.EqualFold(cfg.Auth.Env, "production") || strings.EqualFold(cfg.Auth.Env, "prod") || strings.EqualFold(cfg.Auth.Env, "online") {
if len(corsOrigins) == 1 && corsOrigins[0] == "*" {
panic("FATAL: CORS_ALLOW_ORIGINS must be explicitly set in production environment. Using wildcard '*' is not allowed.")
}
}
return middleware.CORSConfig{
AllowOrigins: corsOrigins,
AllowMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
@@ -268,10 +257,42 @@ func buildCORSConfig(cfg config.Config) middleware.CORSConfig {
}
}
func checkEncryptionKeyIsDefault() (string, bool) {
envKey := os.Getenv("PASSWORD_ENCRYPTION_KEY")
defaultKey := "default-key-32-bytes-long!!!!!!!"
return envKey, envKey == "" || envKey == defaultKey
func validateStartupSecurity(cfg config.Config) error {
if !isProductionEnv(cfg.Auth.Env) {
return nil
}
if isDefaultEncryptionKey() {
return fmt.Errorf("PASSWORD_ENCRYPTION_KEY must be explicitly set in production environment")
}
if usesWildcardCORS(cfg.Auth.CORSAllowOrigins) {
return fmt.Errorf("CORS_ALLOW_ORIGINS must be explicitly set in production environment")
}
return nil
}
func isProductionEnv(env string) bool {
switch strings.ToLower(strings.TrimSpace(env)) {
case "production", "prod", "online":
return true
default:
return false
}
}
func isDefaultEncryptionKey() bool {
envKey := strings.TrimSpace(os.Getenv("PASSWORD_ENCRYPTION_KEY"))
return envKey == "" || envKey == configDefaultEncryptionKey()
}
func configDefaultEncryptionKey() string {
return "default-key-32-bytes-long!!!!!!!"
}
func usesWildcardCORS(origins []string) bool {
if len(origins) == 0 {
return true
}
return len(origins) == 1 && strings.TrimSpace(origins[0]) == "*"
}
func limitHandler(limiter *ratelimit.Middleware, next http.Handler) http.Handler {

View File

@@ -2,6 +2,7 @@ package app
import (
"net/http"
"strings"
"testing"
"lijiaoqiao/gateway/internal/config"
@@ -64,6 +65,88 @@ func TestBuildMux_HealthRouteRemainsOpen(t *testing.T) {
}
}
func TestBuildServer_ProductionRejectsDefaultEncryptionKey(t *testing.T) {
t.Setenv("PASSWORD_ENCRYPTION_KEY", "")
_, err := buildServerWithoutPanic(t, newProductionServerConfig())
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "PASSWORD_ENCRYPTION_KEY") {
t.Fatalf("expected PASSWORD_ENCRYPTION_KEY error, got %v", err)
}
}
func TestBuildServer_ProductionRejectsWildcardCORS(t *testing.T) {
t.Setenv("PASSWORD_ENCRYPTION_KEY", "0123456789abcdef0123456789abcdef")
cfg := newProductionServerConfig()
cfg.Auth.CORSAllowOrigins = []string{"*"}
_, err := buildServerWithoutPanic(t, cfg)
if err == nil {
t.Fatal("expected error")
}
if !strings.Contains(err.Error(), "CORS_ALLOW_ORIGINS") {
t.Fatalf("expected CORS_ALLOW_ORIGINS error, got %v", err)
}
}
func TestBuildServer_DevelopmentAllowsDefaultSecurityFallbacks(t *testing.T) {
t.Setenv("PASSWORD_ENCRYPTION_KEY", "")
cfg := &config.Config{
Auth: config.AuthConfig{
Env: "dev",
TokenRuntimeMode: "inmemory",
},
Providers: []config.ProviderConfig{{
Name: "openai",
Type: "openai",
BaseURL: "https://api.openai.com",
APIKey: "secret",
Models: []string{"gpt-4o"},
}},
}
server, err := buildServerWithoutPanic(t, cfg)
if err != nil {
t.Fatalf("expected dev config to succeed, got %v", err)
}
if server == nil {
t.Fatal("expected server")
}
}
func buildServerWithoutPanic(t *testing.T, cfg *config.Config) (_ *http.Server, err error) {
t.Helper()
defer func() {
if recovered := recover(); recovered != nil {
t.Fatalf("BuildServer panicked: %v", recovered)
}
}()
return BuildServer(cfg)
}
func newProductionServerConfig() *config.Config {
return &config.Config{
Auth: config.AuthConfig{
Env: "production",
TokenRuntimeMode: "remote_introspection",
TokenRuntimeURL: "http://127.0.0.1:18081",
},
Providers: []config.ProviderConfig{{
Name: "openai",
Type: "openai",
BaseURL: "https://api.openai.com",
APIKey: "secret",
Models: []string{"gpt-4o"},
}},
}
}
type testResponseRecorder struct {
header http.Header
code int

View File

@@ -12,10 +12,7 @@ import (
"time"
)
// Encryption key should be provided via environment variable or secure key management
// In production, use a proper key management system (KMS)
// Must be 16, 24, or 32 bytes for AES-128, AES-192, or AES-256
var encryptionKey = []byte(getEnv("PASSWORD_ENCRYPTION_KEY", "default-key-32-bytes-long!!!!!!!"))
const defaultEncryptionKey = "default-key-32-bytes-long!!!!!!!"
// Config 网关配置
type Config struct {
@@ -263,13 +260,17 @@ func getEnv(key, defaultValue string) string {
return defaultValue
}
func currentEncryptionKey() []byte {
return []byte(getEnv("PASSWORD_ENCRYPTION_KEY", defaultEncryptionKey))
}
// encryptPassword 使用AES-GCM加密密码
func encryptPassword(plaintext string) (string, error) {
if plaintext == "" {
return "", nil
}
block, err := aes.NewCipher(encryptionKey)
block, err := aes.NewCipher(currentEncryptionKey())
if err != nil {
return "", err
}
@@ -303,7 +304,7 @@ func decryptPassword(encrypted string) (string, error) {
return encrypted, nil
}
block, err := aes.NewCipher(encryptionKey)
block, err := aes.NewCipher(currentEncryptionKey())
if err != nil {
return "", err
}