Files
lijiaoqiao/gateway/internal/app/bootstrap.go
Your Name ad776e4079 fix: P0/P1 security fixes across gateway, token-runtime, and supply-api
P0 fixes:
- platform-token-runtime: Add store.Save() after Refresh token update (P0-3)
- platform-token-runtime: Add sync.RWMutex to InMemoryRuntimeStore (P0-4)
- platform-token-runtime: Add bearer token auth to /audit-events endpoint (P0-5)
- gateway: Fail startup in production if PASSWORD_ENCRYPTION_KEY is unset or still uses the default key (P0-1)
- gateway: Require explicit CORS allowed origins (GATEWAY_CORS_ALLOW_ORIGINS) in production (P0-2)

P1 fixes:
- gateway: Add TrustedProxies config field + env var GATEWAY_TRUSTED_PROXIES (P1-5)
- gateway: Sanitize X-Request-ID header to prevent log injection (P1-6)
- gateway: Strip internal error details from error responses to clients (P1-7)
- supply-api: Upgrade deriveDEK from trivial byte-rotation to HKDF-SHA256 (P1-1)
- supply-api: Reject HS256/HS384/HS512 in production, require RSA (P1-2)

Code quality fixes:
- supply-api: Add BruteForceMaxAttempts + BruteForceLockoutDuration to AuthConfig (MED-12)
- supply-api: Add TrustedProxies to token_auth_middleware (IP spoofing protection)
- supply-api: Use shared pathutil.SplitPath instead of duplicate splitPath
- supply-api: Fix query_key_reject_middleware call sites with trustedProxies param
- gateway: Wire TrustedProxies into AuthMiddlewareConfig and extractClientIP
- gateway: Add CORSAllowOrigins to AuthConfig, wire into CORSMiddleware
- gateway: Fix CompletionsHandle to have context and RecordResult like ChatCompletions
- gateway: Add sanitizeRequestID helper for X-Request-ID log injection prevention
- gateway: Add os import for PASSWORD_ENCRYPTION_KEY check
- gateway: Add strings import to handler.go for sanitizeRequestID

Environment issues documented in TEST_ENVIRONMENT_ISSUES.md
2026-04-17 14:36:02 +08:00

285 lines
8.8 KiB
Go

package app
import (
	"fmt"
	"net"
	"net/http"
	"net/url"
	"os"
	"strings"
	"time"

	"lijiaoqiao/gateway/internal/config"
	"lijiaoqiao/gateway/internal/handler"
	"lijiaoqiao/gateway/internal/middleware"
	"lijiaoqiao/gateway/internal/ratelimit"
	"lijiaoqiao/gateway/internal/router"
)
// BuildServer assembles the gateway's *http.Server from cfg without starting it.
//
// It applies defaults via normalizeConfig and validates the auth settings. It
// enforces the production-only requirements: an explicit PASSWORD_ENCRYPTION_KEY
// and CORS origins that are set and contain no wildcard. It then wires the
// provider router, rate limiter, token runtime, audit emitter and CORS policy
// into the handler tree built by BuildMux.
//
// Every failure is returned as an error rather than a panic, including a
// production misconfiguration, so the caller can report it and exit cleanly.
// A nil cfg is also an error.
func BuildServer(cfg *config.Config) (*http.Server, error) {
	if cfg == nil {
		return nil, fmt.Errorf("config is required")
	}
	production := strings.EqualFold(cfg.Auth.Env, "production") ||
		strings.EqualFold(cfg.Auth.Env, "prod") ||
		strings.EqualFold(cfg.Auth.Env, "online")

	// normalizeConfig panics on a missing/default encryption key in production.
	// Check first so the misconfiguration surfaces as an ordinary error.
	if production {
		if _, isDefault := checkEncryptionKeyIsDefault(); isDefault {
			return nil, fmt.Errorf("PASSWORD_ENCRYPTION_KEY must be explicitly set in production environment; the default key is not allowed")
		}
	}

	normalized := normalizeConfig(*cfg)
	if err := config.ValidateAuthConfig(normalized.Auth); err != nil {
		return nil, err
	}

	// Reject missing or wildcard CORS origins in production up front. This
	// happens before any backend resource (such as the audit database) is
	// created, and instead of letting buildCORSConfig panic later.
	if production {
		origins := normalized.Auth.CORSAllowOrigins
		wildcard := len(origins) == 0
		for _, origin := range origins {
			if strings.TrimSpace(origin) == "*" {
				wildcard = true
				break
			}
		}
		if wildcard {
			return nil, fmt.Errorf("GATEWAY_CORS_ALLOW_ORIGINS must be explicitly set in production environment; wildcard '*' is not allowed")
		}
	}

	r, err := buildRouter(&normalized)
	if err != nil {
		return nil, err
	}
	limiter := buildLimiter(normalized.RateLimit)
	// Resolve the token runtime before the auditor, so an unsupported token
	// runtime mode is reported before a database-backed auditor is opened.
	tokenRuntime, err := buildTokenRuntime(normalized.Auth)
	if err != nil {
		return nil, err
	}
	auditor, err := buildAuditor(normalized)
	if err != nil {
		return nil, err
	}

	authConfig := middleware.AuthMiddlewareConfig{
		Verifier:       tokenRuntime,
		StatusResolver: tokenRuntime,
		Authorizer:     middleware.NewScopeRoleAuthorizer(),
		Auditor:        auditor,
		ProtectedPrefixes: []string{
			"/v1/chat/completions",
			"/v1/completions",
			"/api/v1/chat/completions",
			"/api/v1/completions",
			"/api/v1/supply",
			"/api/v1/platform",
		},
		ExcludedPrefixes: []string{"/health", "/healthz", "/metrics", "/readyz"},
		Now:              time.Now,
		TrustedProxies:   normalized.Auth.TrustedProxies,
	}
	// Named h (not handler) so the handler package stays addressable.
	h := handler.NewHandler(r)
	corsConfig := buildCORSConfig(normalized)

	// JoinHostPort brackets IPv6 literals ("::" -> "[::]:8080"). Strip any
	// brackets the operator already supplied so they are not doubled.
	host := strings.TrimSuffix(strings.TrimPrefix(normalized.Server.Host, "["), "]")
	server := &http.Server{
		Addr:         net.JoinHostPort(host, fmt.Sprint(normalized.Server.Port)),
		Handler:      BuildMux(h, limiter, authConfig, corsConfig),
		ReadTimeout:  normalized.Server.ReadTimeout,
		WriteTimeout: normalized.Server.WriteTimeout,
		IdleTimeout:  normalized.Server.IdleTimeout,
	}
	return server, nil
}
// BuildMux registers the gateway routes on a new ServeMux and wraps the whole
// mux in the CORS middleware built from corsConfig.
//
// The chat and text completion endpoints are each exposed under both the
// /v1 and /api/v1 prefixes. Requests to them pass through the rate limiter
// (outermost), then the token-auth chain, then the handler. /v1/models and the
// health probes (/health, /healthz, /readyz) are registered without either
// wrapper.
func BuildMux(h *handler.Handler, limiter *ratelimit.Middleware, authConfig middleware.AuthMiddlewareConfig, corsConfig middleware.CORSConfig) http.Handler {
	// Build each wrapped chain once at startup instead of re-wrapping it in a
	// fresh closure on every request.
	chat := limitHandler(limiter, middleware.BuildTokenAuthChain(authConfig, http.HandlerFunc(h.ChatCompletionsHandle)))
	completions := limitHandler(limiter, middleware.BuildTokenAuthChain(authConfig, http.HandlerFunc(h.CompletionsHandle)))

	mux := http.NewServeMux()
	mux.Handle("/v1/chat/completions", chat)
	mux.Handle("/api/v1/chat/completions", chat)
	mux.Handle("/v1/completions", completions)
	mux.Handle("/api/v1/completions", completions)
	mux.HandleFunc("/v1/models", h.ModelsHandle)
	mux.HandleFunc("/health", h.HealthHandle)
	mux.HandleFunc("/healthz", h.HealthHandle)
	mux.HandleFunc("/readyz", h.HealthHandle)
	return middleware.CORSMiddleware(corsConfig)(mux)
}
// buildRouter creates a router that uses the strategy named in
// cfg.Router.Strategy (resolved by resolveStrategy). It registers one provider
// per entry of cfg.Providers under that entry's Name, in configuration order.
// It fails if no providers are configured, or returns the first error from
// buildProvider.
func buildRouter(cfg *config.Config) (*router.Router, error) {
	providers := cfg.Providers
	if len(providers) == 0 {
		return nil, fmt.Errorf("at least one provider must be configured")
	}

	rt := router.NewRouter(resolveStrategy(cfg.Router.Strategy))
	for _, pc := range providers {
		p, buildErr := buildProvider(pc)
		if buildErr != nil {
			return nil, buildErr
		}
		rt.RegisterProvider(pc.Name, p)
	}
	return rt, nil
}
// buildLimiter returns the rate-limit middleware selected by cfg.Algorithm.
//
// If cfg.Algorithm is "sliding_window" (case-insensitive, surrounding
// whitespace ignored), it builds a sliding-window limiter over time.Minute
// with cfg.DefaultRPM as its limit. Any other value selects a token-bucket
// limiter parameterised by DefaultRPM, DefaultTPM and BurstMultiplier; this
// includes the "token_bucket" default set by normalizeConfig.
func buildLimiter(cfg config.RateLimitConfig) *ratelimit.Middleware {
	// Trim like resolveStrategy/buildTokenRuntime do, so a stray space in the
	// config does not silently fall back to the token bucket.
	if strings.EqualFold(strings.TrimSpace(cfg.Algorithm), "sliding_window") {
		limiter := ratelimit.NewSlidingWindowLimiter(time.Minute, cfg.DefaultRPM)
		return ratelimit.NewMiddleware(limiter)
	}
	limiter := ratelimit.NewTokenBucketLimiter(cfg.DefaultRPM, cfg.DefaultTPM, cfg.BurstMultiplier)
	return ratelimit.NewMiddleware(limiter)
}
// buildAuditor selects where authentication audit events are emitted.
//
// When cfg.Database.Host is blank it returns an in-memory emitter. Otherwise
// it builds a postgres:// connection URL from cfg.Database and opens a
// database-backed emitter. Any construction error is wrapped with context.
func buildAuditor(cfg config.Config) (middleware.AuditEmitter, error) {
	host := strings.TrimSpace(cfg.Database.Host)
	if host == "" {
		return middleware.NewMemoryAuditEmitter(), nil
	}
	// JoinHostPort adds brackets for IPv6 literals; strip any that were
	// already supplied so they are not doubled.
	host = strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")

	// Build the DSN with net/url instead of string formatting. This way user
	// names, passwords and database names containing reserved characters
	// ('@', ':', '/', '?', '#', '%', ...) are percent-encoded rather than
	// corrupting the URL.
	// NOTE(review): TLS to the audit database is disabled (sslmode=disable);
	// consider making sslmode configurable.
	dsn := (&url.URL{
		Scheme:   "postgres",
		User:     url.UserPassword(cfg.Database.User, cfg.Database.GetPassword()),
		Host:     net.JoinHostPort(host, fmt.Sprint(cfg.Database.Port)),
		Path:     "/" + cfg.Database.Database,
		RawQuery: "sslmode=disable",
	}).String()

	auditor, err := middleware.NewDatabaseAuditEmitter(dsn, time.Now)
	if err != nil {
		return nil, fmt.Errorf("create database audit emitter: %w", err)
	}
	return auditor, nil
}
// buildTokenRuntime constructs the token verifier / status resolver named by
// cfg.TokenRuntimeMode. The mode is matched case-insensitively with
// surrounding whitespace ignored:
//
//   - "" or "inmemory": middleware.NewInMemoryTokenRuntime.
//   - "remote_introspection": middleware.NewRemoteTokenRuntime against
//     cfg.TokenRuntimeURL, using a dedicated HTTP client with a timeout.
//
// It returns an error for an unsupported mode, or for remote_introspection
// without a URL.
func buildTokenRuntime(cfg config.AuthConfig) (interface {
	middleware.TokenVerifier
	middleware.TokenStatusResolver
}, error) {
	// Upper bound on one introspection HTTP exchange. http.DefaultClient has no
	// timeout, so a stalled introspection endpoint could otherwise block
	// indefinitely.
	const introspectionTimeout = 5 * time.Second

	switch strings.ToLower(strings.TrimSpace(cfg.TokenRuntimeMode)) {
	case "", "inmemory":
		return middleware.NewInMemoryTokenRuntime(time.Now), nil
	case "remote_introspection":
		if strings.TrimSpace(cfg.TokenRuntimeURL) == "" {
			return nil, fmt.Errorf("token runtime URL is required for remote_introspection mode")
		}
		client := &http.Client{Timeout: introspectionTimeout}
		return middleware.NewRemoteTokenRuntime(cfg.TokenRuntimeURL, client, time.Now), nil
	default:
		return nil, fmt.Errorf("unsupported token runtime mode: %s", cfg.TokenRuntimeMode)
	}
}
// resolveStrategy maps a configured strategy name to a router strategy. The
// name is compared after trimming whitespace and lower-casing. Anything that
// does not match round-robin, weighted or availability falls back to
// router.StrategyLatency.
func resolveStrategy(strategy string) router.LoadBalancerStrategy {
	wanted := strings.ToLower(strings.TrimSpace(strategy))
	candidates := []router.LoadBalancerStrategy{
		router.StrategyRoundRobin,
		router.StrategyWeighted,
		router.StrategyAvailability,
	}
	for _, candidate := range candidates {
		if wanted == string(candidate) {
			return candidate
		}
	}
	return router.StrategyLatency
}
// normalizeConfig returns a copy of cfg with defaults applied to unset settings:
//
//   - Server: host "0.0.0.0" (also when blank), port 8080, read and write
//     timeouts of 30s, idle timeout of 120s.
//   - Router.Strategy: router.StrategyLatency.
//   - RateLimit: 60 RPM, 60000 TPM, burst multiplier 1.5, algorithm "token_bucket".
//   - Auth.TokenRuntimeMode: "inmemory".
//
// When Auth.TrustedProxies or Auth.CORSAllowOrigins is empty, it is filled from
// an environment variable: GATEWAY_TRUSTED_PROXIES or GATEWAY_CORS_ALLOW_ORIGINS
// respectively. Each is a comma-separated list; entries are trimmed and blanks
// dropped.
//
// It panics if Auth.Env is a production environment ("production", "prod" or
// "online", compared case-insensitively) and PASSWORD_ENCRYPTION_KEY is unset
// or still the built-in default.
func normalizeConfig(cfg config.Config) config.Config {
	// splitEnv reads a comma-separated environment variable into its trimmed,
	// non-blank entries. It returns nil when the variable is unset or blank.
	splitEnv := func(name string) []string {
		value := strings.TrimSpace(os.Getenv(name))
		if value == "" {
			return nil
		}
		var entries []string
		for _, entry := range strings.Split(value, ",") {
			if entry = strings.TrimSpace(entry); entry != "" {
				entries = append(entries, entry)
			}
		}
		return entries
	}

	// Server defaults.
	if strings.TrimSpace(cfg.Server.Host) == "" {
		cfg.Server.Host = "0.0.0.0"
	}
	if cfg.Server.Port == 0 {
		cfg.Server.Port = 8080
	}
	if cfg.Server.ReadTimeout == 0 {
		cfg.Server.ReadTimeout = 30 * time.Second
	}
	if cfg.Server.WriteTimeout == 0 {
		cfg.Server.WriteTimeout = 30 * time.Second
	}
	if cfg.Server.IdleTimeout == 0 {
		cfg.Server.IdleTimeout = 120 * time.Second
	}

	// Routing and rate-limit defaults.
	if cfg.Router.Strategy == "" {
		cfg.Router.Strategy = string(router.StrategyLatency)
	}
	if cfg.RateLimit.DefaultRPM == 0 {
		cfg.RateLimit.DefaultRPM = 60
	}
	if cfg.RateLimit.DefaultTPM == 0 {
		cfg.RateLimit.DefaultTPM = 60000
	}
	if cfg.RateLimit.BurstMultiplier == 0 {
		cfg.RateLimit.BurstMultiplier = 1.5
	}
	if cfg.RateLimit.Algorithm == "" {
		cfg.RateLimit.Algorithm = "token_bucket"
	}

	// Auth defaults.
	if cfg.Auth.TokenRuntimeMode == "" {
		cfg.Auth.TokenRuntimeMode = "inmemory"
	}
	// Allow-lists come from the environment only when the config has none.
	// Appending zero entries leaves the existing (empty) slice untouched.
	if len(cfg.Auth.TrustedProxies) == 0 {
		cfg.Auth.TrustedProxies = append(cfg.Auth.TrustedProxies, splitEnv("GATEWAY_TRUSTED_PROXIES")...)
	}
	if len(cfg.Auth.CORSAllowOrigins) == 0 {
		cfg.Auth.CORSAllowOrigins = append(cfg.Auth.CORSAllowOrigins, splitEnv("GATEWAY_CORS_ALLOW_ORIGINS")...)
	}

	// P0-1: refuse to start in production with a missing or default encryption key.
	for _, prodName := range []string{"production", "prod", "online"} {
		if !strings.EqualFold(cfg.Auth.Env, prodName) {
			continue
		}
		if _, isDefault := checkEncryptionKeyIsDefault(); isDefault {
			panic("FATAL: PASSWORD_ENCRYPTION_KEY environment variable must be explicitly set in production environment. Using the default key is not allowed.")
		}
		break
	}
	return cfg
}
// buildCORSConfig derives the CORS middleware settings from the normalized config.
//
// Allowed origins come from Auth.CORSAllowOrigins; normalizeConfig fills this
// from GATEWAY_CORS_ALLOW_ORIGINS when the config leaves it empty. With no
// origins configured, the wildcard "*" is used.
//
// In a production environment (Env "production", "prod" or "online",
// case-insensitive) it panics if the resulting origin list contains "*"
// anywhere. That covers a defaulted wildcard and one listed explicitly
// alongside concrete origins.
func buildCORSConfig(cfg config.Config) middleware.CORSConfig {
	corsOrigins := cfg.Auth.CORSAllowOrigins
	if len(corsOrigins) == 0 {
		corsOrigins = []string{"*"}
	}
	production := strings.EqualFold(cfg.Auth.Env, "production") ||
		strings.EqualFold(cfg.Auth.Env, "prod") ||
		strings.EqualFold(cfg.Auth.Env, "online")
	// P0-2: refuse a wildcard origin in production. Every entry is checked, not
	// just a lone "*", so a "*" listed next to real origins cannot slip past.
	if production {
		for _, origin := range corsOrigins {
			if strings.TrimSpace(origin) == "*" {
				panic("FATAL: GATEWAY_CORS_ALLOW_ORIGINS must be explicitly set in production environment. Using wildcard '*' is not allowed.")
			}
		}
	}
	return middleware.CORSConfig{
		AllowOrigins:     corsOrigins,
		AllowMethods:     []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
		AllowHeaders:     []string{"Authorization", "Content-Type", "X-Request-ID", "X-Request-Key"},
		ExposeHeaders:    []string{"X-Request-ID"},
		AllowCredentials: false,
		MaxAge:           86400,
	}
}
// checkEncryptionKeyIsDefault reads PASSWORD_ENCRYPTION_KEY from the
// environment. It returns the raw value, plus true when that value is empty
// (unset) or equal to the built-in placeholder key.
func checkEncryptionKeyIsDefault() (string, bool) {
	const placeholderKey = "default-key-32-bytes-long!!!!!!!"
	key := os.Getenv("PASSWORD_ENCRYPTION_KEY")
	switch key {
	case "", placeholderKey:
		return key, true
	default:
		return key, false
	}
}
// limitHandler places next behind the rate limiter. With a nil limiter, rate
// limiting is disabled and next is returned unchanged.
func limitHandler(limiter *ratelimit.Middleware, next http.Handler) http.Handler {
	if limiter != nil {
		limited := func(w http.ResponseWriter, r *http.Request) {
			// limiter.Limit wraps next.ServeHTTP on each incoming request.
			limiter.Limit(next.ServeHTTP)(w, r)
		}
		return http.HandlerFunc(limited)
	}
	return next
}