diff --git a/backend/internal/config/auth.go b/backend/internal/config/auth.go new file mode 100644 index 00000000..c3f2f914 --- /dev/null +++ b/backend/internal/config/auth.go @@ -0,0 +1,42 @@ +package config + +// JWTConfig JWT 认证配置 +type JWTConfig struct { + Secret string `mapstructure:"secret"` + ExpireHour int `mapstructure:"expire_hour"` + // AccessTokenExpireMinutes: Access Token有效期(分钟) + // - >0: 使用分钟配置(优先级高于 ExpireHour) + // - =0: 回退使用 ExpireHour(向后兼容旧配置) + AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"` + // RefreshTokenExpireDays: Refresh Token有效期(天),默认30天 + RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"` + // RefreshWindowMinutes: 刷新窗口(分钟),在Access Token过期前多久开始允许刷新 + RefreshWindowMinutes int `mapstructure:"refresh_window_minutes"` +} + +// TotpConfig TOTP 双因素认证配置 +type TotpConfig struct { + EncryptionKey string `mapstructure:"encryption_key"` // AES-256 密钥(32 字节 hex 编码) + EncryptionKeyConfigured bool `mapstructure:"-"` // 是否手动配置(非自动生成) +} + +// TurnstileConfig Cloudflare Turnstile 验证配置 +type TurnstileConfig struct { + Required bool `mapstructure:"required"` +} + +// DefaultConfig 默认值配置 +type DefaultConfig struct { + AdminEmail string `mapstructure:"admin_email"` + AdminPassword string `mapstructure:"admin_password"` + UserConcurrency int `mapstructure:"user_concurrency"` + UserBalance float64 `mapstructure:"user_balance"` + APIKeyPrefix string `mapstructure:"api_key_prefix"` + RateMultiplier float64 `mapstructure:"rate_multiplier"` +} + +// RateLimitConfig 限流配置 +type RateLimitConfig struct { + OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` + OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` +} diff --git a/backend/internal/config/billing.go b/backend/internal/config/billing.go new file mode 100644 index 00000000..facb5624 --- /dev/null +++ b/backend/internal/config/billing.go @@ -0,0 +1,23 @@ +package config + +// BillingConfig 计费配置 +type BillingConfig struct { + CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"` +} + +type CircuitBreakerConfig struct { + Enabled bool `mapstructure:"enabled"` + FailureThreshold int `mapstructure:"failure_threshold"` + ResetTimeoutSeconds int `mapstructure:"reset_timeout_seconds"` + HalfOpenRequests int `mapstructure:"half_open_requests"` +} + +// PricingConfig 定价数据配置 +type PricingConfig struct { + RemoteURL string `mapstructure:"remote_url"` // 远程 URL(默认 LiteLLM 镜像) + HashURL string `mapstructure:"hash_url"` // 哈希校验文件 URL + DataDir string `mapstructure:"data_dir"` // 本地数据目录 + FallbackFile string `mapstructure:"fallback_file"` // 回退文件路径 + UpdateIntervalHours int `mapstructure:"update_interval_hours"` // 更新间隔(小时) + HashCheckIntervalMinutes int `mapstructure:"hash_check_interval_minutes"` // 哈希校验间隔(分钟) +} diff --git a/backend/internal/config/config.go b/backend/internal/config/config.go index c1245c1e..26389b39 100644 --- a/backend/internal/config/config.go +++ b/backend/internal/config/config.go @@ -1,15 +1,23 @@ // Package config provides configuration loading, defaults, and validation. +// +// Type definitions are organized by domain into companion files: +// +// - server.go — ServerConfig, H2CConfig, CORSConfig, ConcurrencyConfig +// - security.go — SecurityConfig, URLAllowlist, CSP, ResponseHeaders +// - database.go — DatabaseConfig, RedisConfig (with DSN helpers) +// - auth.go — JWTConfig, TotpConfig, TurnstileConfig, DefaultConfig, RateLimitConfig +// - billing.go — BillingConfig, PricingConfig +// - gateway.go — GatewayConfig, UserMessageQueue, SchedulingConfig +// - gateway_sub.go — OpenAIWS, UsageRecord, TLSFingerprint sub-structs +// - platforms.go — Sora, Gemini, LinuxDo, OIDC, Update, Idempotency configs +// - ops_and_cache.go— LogConfig, OpsConfig, Dashboard, Cache, Cleanup configs package config import ( - "crypto/rand" - "encoding/hex" "fmt" "log/slog" - "net/url" "os" "strings" - "time" "github.com/spf13/viper" ) @@ -26,32 +34,23 @@ const ( UsageRecordOverflowPolicySync = "sync" ) -// DefaultCSPPolicy is the default Content-Security-Policy with nonce support -// __CSP_NONCE__ will be replaced with actual nonce at request time by the SecurityHeaders middleware -const DefaultCSPPolicy = "default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'" - // UMQ(用户消息队列)模式常量 const ( - // UMQModeSerialize: 账号级串行锁 + RPM 自适应延迟 UMQModeSerialize = "serialize" - // UMQModeThrottle: 仅 RPM 自适应前置延迟,不阻塞并发 - UMQModeThrottle = "throttle" + UMQModeThrottle = "throttle" ) // 连接池隔离策略常量 -// 用于控制上游 HTTP 连接池的隔离粒度,影响连接复用和资源消耗 const ( - // ConnectionPoolIsolationProxy: 按代理隔离 - // 同一代理地址共享连接池,适合代理数量少、账户数量多的场景 - ConnectionPoolIsolationProxy = "proxy" - // ConnectionPoolIsolationAccount: 按账户隔离 - // 每个账户独立连接池,适合账户数量少、需要严格隔离的场景 - ConnectionPoolIsolationAccount = "account" - // ConnectionPoolIsolationAccountProxy: 按账户+代理组合隔离(默认) - // 同一账户+代理组合共享连接池,提供最细粒度的隔离 - ConnectionPoolIsolationAccountProxy = "account_proxy" + ConnectionPoolIsolationProxy = "proxy" + ConnectionPoolIsolationAccount = "account" + ConnectionPoolIsolationAccountProxy = "account_proxy" ) +// DefaultCSPPolicy is the default Content-Security-Policy with nonce support. +const DefaultCSPPolicy = `default-src 'self'; script-src 'self' __CSP_NONCE__ https://challenges.cloudflare.com https://static.cloudflareinsights.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; img-src 'self' data: https:; font-src 'self' data: https://fonts.gstatic.com; connect-src 'self' https:; frame-src https://challenges.cloudflare.com; frame-ancestors 'none'; base-uri 'self'; form-action 'self'` + +// Config is the top-level application configuration. type Config struct { Server ServerConfig `mapstructure:"server"` Log LogConfig `mapstructure:"log"` @@ -79,934 +78,13 @@ type Config struct { Concurrency ConcurrencyConfig `mapstructure:"concurrency"` TokenRefresh TokenRefreshConfig `mapstructure:"token_refresh"` RunMode string `mapstructure:"run_mode" yaml:"run_mode"` - Timezone string `mapstructure:"timezone"` // e.g. "Asia/Shanghai", "UTC" - Sora SoraConfig `mapstructure:"sora"` // 从本地版本合并 + Timezone string `mapstructure:"timezone"` + Sora SoraConfig `mapstructure:"sora"` Gemini GeminiConfig `mapstructure:"gemini"` Update UpdateConfig `mapstructure:"update"` Idempotency IdempotencyConfig `mapstructure:"idempotency"` } -type LogConfig struct { - Level string `mapstructure:"level"` - Format string `mapstructure:"format"` - ServiceName string `mapstructure:"service_name"` - Environment string `mapstructure:"env"` - Caller bool `mapstructure:"caller"` - StacktraceLevel string `mapstructure:"stacktrace_level"` - Output LogOutputConfig `mapstructure:"output"` - Rotation LogRotationConfig `mapstructure:"rotation"` - Sampling LogSamplingConfig `mapstructure:"sampling"` -} - -type LogOutputConfig struct { - ToStdout bool `mapstructure:"to_stdout"` - ToFile bool `mapstructure:"to_file"` - FilePath string `mapstructure:"file_path"` -} - -type LogRotationConfig struct { - MaxSizeMB int `mapstructure:"max_size_mb"` - MaxBackups int `mapstructure:"max_backups"` - MaxAgeDays int `mapstructure:"max_age_days"` - Compress bool `mapstructure:"compress"` - LocalTime bool `mapstructure:"local_time"` -} - -type LogSamplingConfig struct { - Enabled bool `mapstructure:"enabled"` - Initial int `mapstructure:"initial"` - Thereafter int `mapstructure:"thereafter"` -} - -type GeminiConfig struct { - OAuth GeminiOAuthConfig `mapstructure:"oauth"` - Quota GeminiQuotaConfig `mapstructure:"quota"` -} - -type GeminiOAuthConfig struct { - ClientID string `mapstructure:"client_id"` - ClientSecret string `mapstructure:"client_secret"` - Scopes string `mapstructure:"scopes"` -} - -type GeminiQuotaConfig struct { - Tiers map[string]GeminiTierQuotaConfig `mapstructure:"tiers"` - Policy string `mapstructure:"policy"` -} - -type GeminiTierQuotaConfig struct { - ProRPD *int64 `mapstructure:"pro_rpd" json:"pro_rpd"` - FlashRPD *int64 `mapstructure:"flash_rpd" json:"flash_rpd"` - CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"` -} - -// SoraConfig 直连 Sora 配置 (从本地版本合并) -type SoraConfig struct { - Client SoraClientConfig `mapstructure:"client"` - Storage SoraStorageConfig `mapstructure:"storage"` -} - -// SoraClientConfig 直连 Sora 客户端配置 (从本地版本合并) -type SoraClientConfig struct { - BaseURL string `mapstructure:"base_url"` - TimeoutSeconds int `mapstructure:"timeout_seconds"` - MaxRetries int `mapstructure:"max_retries"` - CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"` - PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` - MaxPollAttempts int `mapstructure:"max_poll_attempts"` - RecentTaskLimit int `mapstructure:"recent_task_limit"` - RecentTaskLimitMax int `mapstructure:"recent_task_limit_max"` - Debug bool `mapstructure:"debug"` - UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"` - Headers map[string]string `mapstructure:"headers"` - UserAgent string `mapstructure:"user_agent"` - DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` - CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"` -} - -// SoraCurlCFFISidecarConfig Sora 专用 curl_cffi sidecar 配置 (从本地版本合并) -type SoraCurlCFFISidecarConfig struct { - Enabled bool `mapstructure:"enabled"` - BaseURL string `mapstructure:"base_url"` - Impersonate string `mapstructure:"impersonate"` - TimeoutSeconds int `mapstructure:"timeout_seconds"` - SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"` - SessionTTLSeconds int `mapstructure:"session_ttl_seconds"` -} - -// SoraStorageConfig 媒体存储配置 (从本地版本合并) -type SoraStorageConfig struct { - Type string `mapstructure:"type"` - LocalPath string `mapstructure:"local_path"` - FallbackToUpstream bool `mapstructure:"fallback_to_upstream"` - MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"` - DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"` - MaxDownloadBytes int64 `mapstructure:"max_download_bytes"` - Debug bool `mapstructure:"debug"` - Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"` -} - -// SoraStorageCleanupConfig 媒体清理配置 (从本地版本合并) -type SoraStorageCleanupConfig struct { - Enabled bool `mapstructure:"enabled"` - Schedule string `mapstructure:"schedule"` - RetentionDays int `mapstructure:"retention_days"` -} - -// SoraModelFiltersConfig Sora 模型过滤配置 (从本地版本合并) -type SoraModelFiltersConfig struct { - // HidePromptEnhance 是否隐藏 prompt-enhance 模型 - HidePromptEnhance bool `mapstructure:"hide_prompt_enhance"` -} - -type UpdateConfig struct { - // ProxyURL 用于访问 GitHub 的代理地址 - // 支持 http/https/socks5/socks5h 协议 - // 例如: "http://127.0.0.1:7890", "socks5://127.0.0.1:1080" - ProxyURL string `mapstructure:"proxy_url"` -} - -type IdempotencyConfig struct { - // ObserveOnly 为 true 时处于观察期:未携带 Idempotency-Key 的请求继续放行。 - ObserveOnly bool `mapstructure:"observe_only"` - // DefaultTTLSeconds 关键写接口的幂等记录默认 TTL(秒)。 - DefaultTTLSeconds int `mapstructure:"default_ttl_seconds"` - // SystemOperationTTLSeconds 系统操作接口的幂等记录 TTL(秒)。 - SystemOperationTTLSeconds int `mapstructure:"system_operation_ttl_seconds"` - // ProcessingTimeoutSeconds processing 状态锁超时(秒)。 - ProcessingTimeoutSeconds int `mapstructure:"processing_timeout_seconds"` - // FailedRetryBackoffSeconds 失败退避窗口(秒)。 - FailedRetryBackoffSeconds int `mapstructure:"failed_retry_backoff_seconds"` - // MaxStoredResponseLen 持久化响应体最大长度(字节)。 - MaxStoredResponseLen int `mapstructure:"max_stored_response_len"` - // CleanupIntervalSeconds 过期记录清理周期(秒)。 - CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"` - // CleanupBatchSize 每次清理的最大记录数。 - CleanupBatchSize int `mapstructure:"cleanup_batch_size"` -} - -type LinuxDoConnectConfig struct { - Enabled bool `mapstructure:"enabled"` - ClientID string `mapstructure:"client_id"` - ClientSecret string `mapstructure:"client_secret"` - AuthorizeURL string `mapstructure:"authorize_url"` - TokenURL string `mapstructure:"token_url"` - UserInfoURL string `mapstructure:"userinfo_url"` - Scopes string `mapstructure:"scopes"` - RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记) - FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/linuxdo/callback) - TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none - UsePKCE bool `mapstructure:"use_pkce"` - - // 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。 - // 为空时,服务端会尝试一组常见字段名。 - UserInfoEmailPath string `mapstructure:"userinfo_email_path"` - UserInfoIDPath string `mapstructure:"userinfo_id_path"` - UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` -} - -type OIDCConnectConfig struct { - Enabled bool `mapstructure:"enabled"` - ProviderName string `mapstructure:"provider_name"` // 显示名: "Keycloak" 等 - ClientID string `mapstructure:"client_id"` - ClientSecret string `mapstructure:"client_secret"` - IssuerURL string `mapstructure:"issuer_url"` - DiscoveryURL string `mapstructure:"discovery_url"` - AuthorizeURL string `mapstructure:"authorize_url"` - TokenURL string `mapstructure:"token_url"` - UserInfoURL string `mapstructure:"userinfo_url"` - JWKSURL string `mapstructure:"jwks_url"` - Scopes string `mapstructure:"scopes"` // 默认 "openid email profile" - RedirectURL string `mapstructure:"redirect_url"` // 后端回调地址(需在提供方后台登记) - FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` // 前端接收 token 的路由(默认:/auth/oidc/callback) - TokenAuthMethod string `mapstructure:"token_auth_method"` // client_secret_post / client_secret_basic / none - UsePKCE bool `mapstructure:"use_pkce"` - ValidateIDToken bool `mapstructure:"validate_id_token"` - AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` // 默认 "RS256,ES256,PS256" - ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` // 默认 120 - RequireEmailVerified bool `mapstructure:"require_email_verified"` // 默认 false - - // 可选:用于从 userinfo JSON 中提取字段的 gjson 路径。 - // 为空时,服务端会尝试一组常见字段名。 - UserInfoEmailPath string `mapstructure:"userinfo_email_path"` - UserInfoIDPath string `mapstructure:"userinfo_id_path"` - UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` -} - -// TokenRefreshConfig OAuth token自动刷新配置 -type TokenRefreshConfig struct { - // 是否启用自动刷新 - Enabled bool `mapstructure:"enabled"` - // 检查间隔(分钟) - CheckIntervalMinutes int `mapstructure:"check_interval_minutes"` - // 提前刷新时间(小时),在token过期前多久开始刷新 - RefreshBeforeExpiryHours float64 `mapstructure:"refresh_before_expiry_hours"` - // 最大重试次数 - MaxRetries int `mapstructure:"max_retries"` - // 重试退避基础时间(秒) - RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` -} - -type PricingConfig struct { - // 价格数据远程URL(默认使用LiteLLM镜像) - RemoteURL string `mapstructure:"remote_url"` - // 哈希校验文件URL - HashURL string `mapstructure:"hash_url"` - // 本地数据目录 - DataDir string `mapstructure:"data_dir"` - // 回退文件路径 - FallbackFile string `mapstructure:"fallback_file"` - // 更新间隔(小时) - UpdateIntervalHours int `mapstructure:"update_interval_hours"` - // 哈希校验间隔(分钟) - HashCheckIntervalMinutes int `mapstructure:"hash_check_interval_minutes"` -} - -type ServerConfig struct { - Host string `mapstructure:"host"` - Port int `mapstructure:"port"` - Mode string `mapstructure:"mode"` // debug/release - FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL,用于生成邮件中的外部链接 - ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) - IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) - TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) - MaxRequestBodySize int64 `mapstructure:"max_request_body_size"` // 全局最大请求体限制 - H2C H2CConfig `mapstructure:"h2c"` // HTTP/2 Cleartext 配置 -} - -// H2CConfig HTTP/2 Cleartext 配置 -type H2CConfig struct { - Enabled bool `mapstructure:"enabled"` // 是否启用 H2C - MaxConcurrentStreams uint32 `mapstructure:"max_concurrent_streams"` // 最大并发流数量 - IdleTimeout int `mapstructure:"idle_timeout"` // 空闲超时(秒) - MaxReadFrameSize int `mapstructure:"max_read_frame_size"` // 最大帧大小(字节) - MaxUploadBufferPerConnection int `mapstructure:"max_upload_buffer_per_connection"` // 每个连接的上传缓冲区(字节) - MaxUploadBufferPerStream int `mapstructure:"max_upload_buffer_per_stream"` // 每个流的上传缓冲区(字节) -} - -type CORSConfig struct { - AllowedOrigins []string `mapstructure:"allowed_origins"` - AllowCredentials bool `mapstructure:"allow_credentials"` -} - -type SecurityConfig struct { - URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"` - ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"` - CSP CSPConfig `mapstructure:"csp"` - ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"` - ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"` -} - -type URLAllowlistConfig struct { - Enabled bool `mapstructure:"enabled"` - UpstreamHosts []string `mapstructure:"upstream_hosts"` - PricingHosts []string `mapstructure:"pricing_hosts"` - CRSHosts []string `mapstructure:"crs_hosts"` - AllowPrivateHosts bool `mapstructure:"allow_private_hosts"` - // 关闭 URL 白名单校验时,是否允许 http URL(默认只允许 https) - AllowInsecureHTTP bool `mapstructure:"allow_insecure_http"` -} - -type ResponseHeaderConfig struct { - Enabled bool `mapstructure:"enabled"` - AdditionalAllowed []string `mapstructure:"additional_allowed"` - ForceRemove []string `mapstructure:"force_remove"` -} - -type CSPConfig struct { - Enabled bool `mapstructure:"enabled"` - Policy string `mapstructure:"policy"` -} - -type ProxyFallbackConfig struct { - // AllowDirectOnError 当辅助服务的代理初始化失败时是否允许回退直连。 - // 仅影响以下非 AI 账号连接的辅助服务: - // - GitHub Release 更新检查 - // - 定价数据拉取 - // 不影响 AI 账号网关连接(Claude/OpenAI/Gemini/Antigravity), - // 这些关键路径的代理失败始终返回错误,不会回退直连。 - // 默认 false:避免因代理配置错误导致服务器真实 IP 泄露。 - AllowDirectOnError bool `mapstructure:"allow_direct_on_error"` -} - -type ProxyProbeConfig struct { - InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证 -} - -type BillingConfig struct { - CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"` -} - -type CircuitBreakerConfig struct { - Enabled bool `mapstructure:"enabled"` - FailureThreshold int `mapstructure:"failure_threshold"` - ResetTimeoutSeconds int `mapstructure:"reset_timeout_seconds"` - HalfOpenRequests int `mapstructure:"half_open_requests"` -} - -type ConcurrencyConfig struct { - // PingInterval: 并发等待期间的 SSE ping 间隔(秒) - PingInterval int `mapstructure:"ping_interval"` -} - -// GatewayConfig API网关相关配置 -type GatewayConfig struct { - // 等待上游响应头的超时时间(秒),0表示无超时 - // 注意:这不影响流式数据传输,只控制等待响应头的时间 - ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` - // 请求体最大字节数,用于网关请求体大小限制 - MaxBodySize int64 `mapstructure:"max_body_size"` - // 非流式上游响应体读取上限(字节),用于防止无界读取导致内存放大 - UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"` - // 代理探测响应体读取上限(字节) - ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"` - // Gemini 上游响应头调试日志开关(默认关闭,避免高频日志开销) - GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"` - // ConnectionPoolIsolation: 上游连接池隔离策略(proxy/account/account_proxy) - ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` - // ForceCodexCLI: 强制将 OpenAI `/v1/responses` 请求按 Codex CLI 处理。 - // 用于网关未透传/改写 User-Agent 时的兼容兜底(默认关闭,避免影响其他客户端)。 - ForceCodexCLI bool `mapstructure:"force_codex_cli"` - // ForcedCodexInstructionsTemplateFile: 服务端强制附加到 Codex 顶层 instructions 的模板文件路径。 - // 模板渲染后会直接覆盖最终 instructions;若需要保留客户端 system 转换结果,请在模板中显式引用 {{ .ExistingInstructions }}。 - ForcedCodexInstructionsTemplateFile string `mapstructure:"forced_codex_instructions_template_file"` - // ForcedCodexInstructionsTemplate: 启动时从模板文件读取并缓存的模板内容。 - // 该字段不直接参与配置反序列化,仅用于请求热路径避免重复读盘。 - ForcedCodexInstructionsTemplate string `mapstructure:"-"` - // OpenAIPassthroughAllowTimeoutHeaders: OpenAI 透传模式是否放行客户端超时头 - // 关闭(默认)可避免 x-stainless-timeout 等头导致上游提前断流。 - OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` - // OpenAIWS: OpenAI Responses WebSocket 配置(默认开启,可按需回滚到 HTTP) - OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"` - - // HTTP 上游连接池配置(性能优化:支持高并发场景调优) - // MaxIdleConns: 所有主机的最大空闲连接总数 - MaxIdleConns int `mapstructure:"max_idle_conns"` - // MaxIdleConnsPerHost: 每个主机的最大空闲连接数(关键参数,影响连接复用率) - MaxIdleConnsPerHost int `mapstructure:"max_idle_conns_per_host"` - // MaxConnsPerHost: 每个主机的最大连接数(包括活跃+空闲),0表示无限制 - MaxConnsPerHost int `mapstructure:"max_conns_per_host"` - // IdleConnTimeoutSeconds: 空闲连接超时时间(秒) - IdleConnTimeoutSeconds int `mapstructure:"idle_conn_timeout_seconds"` - // MaxUpstreamClients: 上游连接池客户端最大缓存数量 - // 当使用连接池隔离策略时,系统会为不同的账户/代理组合创建独立的 HTTP 客户端 - // 此参数限制缓存的客户端数量,超出后会淘汰最久未使用的客户端 - // 建议值:预估的活跃账户数 * 1.2(留有余量) - MaxUpstreamClients int `mapstructure:"max_upstream_clients"` - // ClientIdleTTLSeconds: 上游连接池客户端空闲回收阈值(秒) - // 超过此时间未使用的客户端会被标记为可回收 - // 建议值:根据用户访问频率设置,一般 10-30 分钟 - ClientIdleTTLSeconds int `mapstructure:"client_idle_ttl_seconds"` - // ConcurrencySlotTTLMinutes: 并发槽位过期时间(分钟) - // 应大于最长 LLM 请求时间,防止请求完成前槽位过期 - ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` - // SessionIdleTimeoutMinutes: 会话空闲超时时间(分钟),默认 5 分钟 - // 用于 Anthropic OAuth/SetupToken 账号的会话数量限制功能 - // 空闲超过此时间的会话将被自动释放 - SessionIdleTimeoutMinutes int `mapstructure:"session_idle_timeout_minutes"` - - // StreamDataIntervalTimeout: 流数据间隔超时(秒),0表示禁用 - StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"` - // StreamKeepaliveInterval: 流式 keepalive 间隔(秒),0表示禁用 - StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"` - // MaxLineSize: 上游 SSE 单行最大字节数(0使用默认值) - MaxLineSize int `mapstructure:"max_line_size"` - - // 是否记录上游错误响应体摘要(避免输出请求内容) - LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"` - // 上游错误响应体记录最大字节数(超过会截断) - LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"` - - // API-key 账号在客户端未提供 anthropic-beta 时,是否按需自动补齐(默认关闭以保持兼容) - InjectBetaForAPIKey bool `mapstructure:"inject_beta_for_apikey"` - - // 是否允许对部分 400 错误触发 failover(默认关闭以避免改变语义) - FailoverOn400 bool `mapstructure:"failover_on_400"` - - // Sora 专用配置 (从本地版本合并) - // SoraMaxBodySize: Sora 请求体最大字节数(0 表示使用 gateway.max_body_size) - SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"` - // SoraStreamTimeoutSeconds: Sora 流式请求总超时(秒,0 表示不限制) - SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"` - // SoraRequestTimeoutSeconds: Sora 非流式请求超时(秒,0 表示不限制) - SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"` - // SoraStreamMode: stream 强制策略(force/error) - SoraStreamMode string `mapstructure:"sora_stream_mode"` - // SoraModelFilters: 模型列表过滤配置 - SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"` - // SoraMediaRequireAPIKey: 是否要求访问 /sora/media 携带 API Key - SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"` - // SoraMediaSigningKey: /sora/media 临时签名密钥(空表示禁用签名) - SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"` - // SoraMediaSignedURLTTLSeconds: 临时签名 URL 有效期(秒,<=0 表示禁用) - SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"` - - // 账户切换最大次数(遇到上游错误时切换到其他账户的次数上限) - MaxAccountSwitches int `mapstructure:"max_account_switches"` - // Gemini 账户切换最大次数(Gemini 平台单独配置,因 API 限制更严格) - MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"` - - // Antigravity 429 fallback 限流时间(分钟),解析重置时间失败时使用 - AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"` - - // Scheduling: 账号调度相关配置 - Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"` - - // TLSFingerprint: TLS指纹伪装配置 - TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"` - - // UsageRecord: 使用量记录异步队列配置(有界队列 + 固定 worker) - UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"` - - // UserGroupRateCacheTTLSeconds: 用户分组倍率热路径缓存 TTL(秒) - UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"` - // ModelsListCacheTTLSeconds: /v1/models 模型列表短缓存 TTL(秒) - ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"` - - // UserMessageQueue: 用户消息串行队列配置 - // 对 role:"user" 的真实用户消息实施账号级串行化 + RPM 自适应延迟 - UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"` -} - -// UserMessageQueueConfig 用户消息串行队列配置 -// 用于 Anthropic OAuth/SetupToken 账号的用户消息串行化发送 -type UserMessageQueueConfig struct { - // Mode: 模式选择 - // "serialize" = 账号级串行锁 + RPM 自适应延迟 - // "throttle" = 仅 RPM 自适应前置延迟,不阻塞并发 - // "" = 禁用(默认) - Mode string `mapstructure:"mode"` - // Enabled: 已废弃,仅向后兼容(等同于 mode: "serialize") - Enabled bool `mapstructure:"enabled"` - // LockTTLMs: 串行锁 TTL(毫秒),应大于最长请求时间 - LockTTLMs int `mapstructure:"lock_ttl_ms"` - // WaitTimeoutMs: 等待获取锁的超时时间(毫秒) - WaitTimeoutMs int `mapstructure:"wait_timeout_ms"` - // MinDelayMs: RPM 自适应延迟下限(毫秒) - MinDelayMs int `mapstructure:"min_delay_ms"` - // MaxDelayMs: RPM 自适应延迟上限(毫秒) - MaxDelayMs int `mapstructure:"max_delay_ms"` - // CleanupIntervalSeconds: 孤儿锁清理间隔(秒),0 表示禁用 - CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"` -} - -// WaitTimeout 返回等待超时的 time.Duration -func (c *UserMessageQueueConfig) WaitTimeout() time.Duration { - if c.WaitTimeoutMs <= 0 { - return 30 * time.Second - } - return time.Duration(c.WaitTimeoutMs) * time.Millisecond -} - -// GetEffectiveMode 返回生效的模式 -// 注意:Mode 字段已在 load() 中做过白名单校验和规范化,此处无需重复验证 -func (c *UserMessageQueueConfig) GetEffectiveMode() string { - if c.Mode == UMQModeSerialize || c.Mode == UMQModeThrottle { - return c.Mode - } - if c.Enabled { - return UMQModeSerialize // 向后兼容 - } - return "" -} - -// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置。 -// 注意:默认全局开启;如需回滚可使用 force_http 或关闭 enabled。 -type GatewayOpenAIWSConfig struct { - // ModeRouterV2Enabled: 新版 WS mode 路由开关(默认 false;关闭时保持 legacy 行为) - ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"` - // IngressModeDefault: ingress 默认模式(off/ctx_pool/passthrough) - IngressModeDefault string `mapstructure:"ingress_mode_default"` - // Enabled: 全局总开关(默认 true) - Enabled bool `mapstructure:"enabled"` - // OAuthEnabled: 是否允许 OpenAI OAuth 账号使用 WS - OAuthEnabled bool `mapstructure:"oauth_enabled"` - // APIKeyEnabled: 是否允许 OpenAI API Key 账号使用 WS - APIKeyEnabled bool `mapstructure:"apikey_enabled"` - // ForceHTTP: 全局强制 HTTP(用于紧急回滚) - ForceHTTP bool `mapstructure:"force_http"` - // AllowStoreRecovery: 允许在 WSv2 下按策略恢复 store=true(默认 false) - AllowStoreRecovery bool `mapstructure:"allow_store_recovery"` - // IngressPreviousResponseRecoveryEnabled: ingress 模式收到 previous_response_not_found 时,是否允许自动去掉 previous_response_id 重试一次(默认 true) - IngressPreviousResponseRecoveryEnabled bool `mapstructure:"ingress_previous_response_recovery_enabled"` - // StoreDisabledConnMode: store=false 且无可复用会话连接时的建连策略(strict/adaptive/off) - // - strict: 强制新建连接(隔离优先) - // - adaptive: 仅在高风险失败后强制新建连接(性能与隔离折中) - // - off: 不强制新建连接(复用优先) - StoreDisabledConnMode string `mapstructure:"store_disabled_conn_mode"` - // StoreDisabledForceNewConn: store=false 且无可复用粘连连接时是否强制新建连接(默认 true,保障会话隔离) - // 兼容旧配置;当 StoreDisabledConnMode 为空时才生效。 - StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"` - // PrewarmGenerateEnabled: 是否启用 WSv2 generate=false 预热(默认 false) - PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"` - - // Feature 开关:v2 优先于 v1 - ResponsesWebsockets bool `mapstructure:"responses_websockets"` - ResponsesWebsocketsV2 bool `mapstructure:"responses_websockets_v2"` - - // 连接池参数 - MaxConnsPerAccount int `mapstructure:"max_conns_per_account"` - MinIdlePerAccount int `mapstructure:"min_idle_per_account"` - MaxIdlePerAccount int `mapstructure:"max_idle_per_account"` - // DynamicMaxConnsByAccountConcurrencyEnabled: 是否按账号并发动态计算连接池上限 - DynamicMaxConnsByAccountConcurrencyEnabled bool `mapstructure:"dynamic_max_conns_by_account_concurrency_enabled"` - // OAuthMaxConnsFactor: OAuth 账号连接池系数(effective=ceil(concurrency*factor)) - OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"` - // APIKeyMaxConnsFactor: API Key 账号连接池系数(effective=ceil(concurrency*factor)) - APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"` - DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"` - ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"` - WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"` - PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"` - QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"` - // EventFlushBatchSize: WS 流式写出批量 flush 阈值(事件条数) - EventFlushBatchSize int `mapstructure:"event_flush_batch_size"` - // EventFlushIntervalMS: WS 流式写出最大等待时间(毫秒);0 表示仅按 batch 触发 - EventFlushIntervalMS int `mapstructure:"event_flush_interval_ms"` - // PrewarmCooldownMS: 连接池预热触发冷却时间(毫秒) - PrewarmCooldownMS int `mapstructure:"prewarm_cooldown_ms"` - // FallbackCooldownSeconds: WS 回退冷却窗口,避免 WS/HTTP 抖动;0 表示关闭冷却 - FallbackCooldownSeconds int `mapstructure:"fallback_cooldown_seconds"` - // RetryBackoffInitialMS: WS 重试初始退避(毫秒);<=0 表示关闭退避 - RetryBackoffInitialMS int `mapstructure:"retry_backoff_initial_ms"` - // RetryBackoffMaxMS: WS 重试最大退避(毫秒) - RetryBackoffMaxMS int `mapstructure:"retry_backoff_max_ms"` - // RetryJitterRatio: WS 重试退避抖动比例(0-1) - RetryJitterRatio float64 `mapstructure:"retry_jitter_ratio"` - // RetryTotalBudgetMS: WS 单次请求重试总预算(毫秒);0 表示关闭预算限制 - RetryTotalBudgetMS int `mapstructure:"retry_total_budget_ms"` - // PayloadLogSampleRate: payload_schema 日志采样率(0-1) - PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"` - - // 账号调度与粘连参数 - LBTopK int `mapstructure:"lb_top_k"` - // StickySessionTTLSeconds: session_hash -> account_id 粘连 TTL - StickySessionTTLSeconds int `mapstructure:"sticky_session_ttl_seconds"` - // SessionHashReadOldFallback: 会话哈希迁移期是否允许“新 key 未命中时回退读旧 SHA-256 key” - SessionHashReadOldFallback bool `mapstructure:"session_hash_read_old_fallback"` - // SessionHashDualWriteOld: 会话哈希迁移期是否双写旧 SHA-256 key(短 TTL) - SessionHashDualWriteOld bool `mapstructure:"session_hash_dual_write_old"` - // MetadataBridgeEnabled: RequestMetadata 迁移期是否保留旧 ctxkey.* 兼容桥接 - MetadataBridgeEnabled bool `mapstructure:"metadata_bridge_enabled"` - // StickyResponseIDTTLSeconds: response_id -> account_id 粘连 TTL - StickyResponseIDTTLSeconds int `mapstructure:"sticky_response_id_ttl_seconds"` - // StickyPreviousResponseTTLSeconds: 兼容旧键(当新键未设置时回退) - StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"` - - SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"` -} - -// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重。 -type GatewayOpenAIWSSchedulerScoreWeights struct { - Priority float64 `mapstructure:"priority"` - Load float64 `mapstructure:"load"` - Queue float64 `mapstructure:"queue"` - ErrorRate float64 `mapstructure:"error_rate"` - TTFT float64 `mapstructure:"ttft"` -} - -// GatewayUsageRecordConfig 使用量记录异步队列配置 -type GatewayUsageRecordConfig struct { - // WorkerCount: worker 初始数量(自动扩缩容开启时作为初始并发上限) - WorkerCount int `mapstructure:"worker_count"` - // QueueSize: 队列容量(有界) - QueueSize int `mapstructure:"queue_size"` - // TaskTimeoutSeconds: 单个使用量记录任务超时(秒) - TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"` - // OverflowPolicy: 队列满时策略(drop/sample/sync) - OverflowPolicy string `mapstructure:"overflow_policy"` - // OverflowSamplePercent: sample 策略下,同步回写采样百分比(1-100) - OverflowSamplePercent int `mapstructure:"overflow_sample_percent"` - - // AutoScaleEnabled: 是否启用 worker 自动扩缩容 - AutoScaleEnabled bool `mapstructure:"auto_scale_enabled"` - // AutoScaleMinWorkers: 自动扩缩容最小 worker 数 - AutoScaleMinWorkers int `mapstructure:"auto_scale_min_workers"` - // AutoScaleMaxWorkers: 自动扩缩容最大 worker 数 - AutoScaleMaxWorkers int `mapstructure:"auto_scale_max_workers"` - // AutoScaleUpQueuePercent: 队列占用率达到该阈值时触发扩容 - AutoScaleUpQueuePercent int `mapstructure:"auto_scale_up_queue_percent"` - // AutoScaleDownQueuePercent: 队列占用率低于该阈值时触发缩容 - AutoScaleDownQueuePercent int `mapstructure:"auto_scale_down_queue_percent"` - // AutoScaleUpStep: 每次扩容步长 - AutoScaleUpStep int `mapstructure:"auto_scale_up_step"` - // AutoScaleDownStep: 每次缩容步长 - AutoScaleDownStep int `mapstructure:"auto_scale_down_step"` - // AutoScaleCheckIntervalSeconds: 自动扩缩容检测间隔(秒) - AutoScaleCheckIntervalSeconds int `mapstructure:"auto_scale_check_interval_seconds"` - // AutoScaleCooldownSeconds: 自动扩缩容冷却时间(秒) - AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"` -} - -// TLSFingerprintConfig TLS指纹伪装配置 -// 用于模拟 Claude CLI (Node.js) 的 TLS 握手特征,避免被识别为非官方客户端 -type TLSFingerprintConfig struct { - // Enabled: 是否全局启用TLS指纹功能 - Enabled bool `mapstructure:"enabled"` - // Profiles: 预定义的TLS指纹配置模板 - // key 为模板名称,如 "claude_cli_v2", "chrome_120" 等 - Profiles map[string]TLSProfileConfig `mapstructure:"profiles"` -} - -// TLSProfileConfig 单个TLS指纹模板的配置 -// 所有列表字段为空时使用内置默认值(Claude CLI 2.x / Node.js 20.x) -// 建议通过 TLS 指纹采集工具 (tests/tls-fingerprint-web) 获取完整配置 -type TLSProfileConfig struct { - // Name: 模板显示名称 - Name string `mapstructure:"name"` - // EnableGREASE: 是否启用GREASE扩展(Chrome使用,Node.js不使用) - EnableGREASE bool `mapstructure:"enable_grease"` - // CipherSuites: TLS加密套件列表 - CipherSuites []uint16 `mapstructure:"cipher_suites"` - // Curves: 椭圆曲线列表 - Curves []uint16 `mapstructure:"curves"` - // PointFormats: 点格式列表 - PointFormats []uint16 `mapstructure:"point_formats"` - // SignatureAlgorithms: 签名算法列表 - SignatureAlgorithms []uint16 `mapstructure:"signature_algorithms"` - // ALPNProtocols: ALPN协议列表(如 ["h2", "http/1.1"]) - ALPNProtocols []string `mapstructure:"alpn_protocols"` - // SupportedVersions: 支持的TLS版本列表(如 [0x0304, 0x0303] 即 TLS1.3, TLS1.2) - SupportedVersions []uint16 `mapstructure:"supported_versions"` - // KeyShareGroups: Key Share中发送的曲线组(如 [29] 即 X25519) - KeyShareGroups []uint16 `mapstructure:"key_share_groups"` - // PSKModes: PSK密钥交换模式(如 [1] 即 psk_dhe_ke) - PSKModes []uint16 `mapstructure:"psk_modes"` - // Extensions: TLS扩展类型ID列表,按发送顺序排列 - // 空则使用内置默认顺序 [0,11,10,35,16,22,23,13,43,45,51] - // GREASE值(如0x0a0a)会自动插入GREASE扩展 - Extensions []uint16 `mapstructure:"extensions"` -} - -// GatewaySchedulingConfig accounts scheduling configuration. -type GatewaySchedulingConfig struct { - // 粘性会话排队配置 - StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"` - StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"` - - // 兜底排队配置 - FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"` - FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"` - - // 兜底层账户选择策略: "last_used"(按最后使用时间排序,默认) 或 "random"(随机) - FallbackSelectionMode string `mapstructure:"fallback_selection_mode"` - - // 负载计算 - LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` - // 快照桶读取时的 MGET 分块大小 - SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"` - // 快照重建时的缓存写入分块大小 - SnapshotWriteChunkSize int `mapstructure:"snapshot_write_chunk_size"` - - // 过期槽位清理周期(0 表示禁用) - SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"` - - // 受控回源配置 - DbFallbackEnabled bool `mapstructure:"db_fallback_enabled"` - // 受控回源超时(秒),0 表示不额外收紧超时 - DbFallbackTimeoutSeconds int `mapstructure:"db_fallback_timeout_seconds"` - // 受控回源限流(实例级 QPS),0 表示不限制 - DbFallbackMaxQPS int `mapstructure:"db_fallback_max_qps"` - - // Outbox 轮询与滞后阈值配置 - // Outbox 轮询周期(秒) - OutboxPollIntervalSeconds int `mapstructure:"outbox_poll_interval_seconds"` - // Outbox 滞后告警阈值(秒) - OutboxLagWarnSeconds int `mapstructure:"outbox_lag_warn_seconds"` - // Outbox 触发强制重建阈值(秒) - OutboxLagRebuildSeconds int `mapstructure:"outbox_lag_rebuild_seconds"` - // Outbox 连续滞后触发次数 - OutboxLagRebuildFailures int `mapstructure:"outbox_lag_rebuild_failures"` - // Outbox 积压触发重建阈值(行数) - OutboxBacklogRebuildRows int `mapstructure:"outbox_backlog_rebuild_rows"` - - // 全量重建周期配置 - // 全量重建周期(秒),0 表示禁用 - FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"` -} - -func (s *ServerConfig) Address() string { - return fmt.Sprintf("%s:%d", s.Host, s.Port) -} - -// DatabaseConfig 数据库连接配置 -// 性能优化:新增连接池参数,避免频繁创建/销毁连接 -type DatabaseConfig struct { - Host string `mapstructure:"host"` - Port int `mapstructure:"port"` - User string `mapstructure:"user"` - Password string `mapstructure:"password"` - DBName string `mapstructure:"dbname"` - SSLMode string `mapstructure:"sslmode"` - // 连接池配置(性能优化:可配置化连接池参数) - // MaxOpenConns: 最大打开连接数,控制数据库连接上限,防止资源耗尽 - MaxOpenConns int `mapstructure:"max_open_conns"` - // MaxIdleConns: 最大空闲连接数,保持热连接减少建连延迟 - MaxIdleConns int `mapstructure:"max_idle_conns"` - // ConnMaxLifetimeMinutes: 连接最大存活时间,防止长连接导致的资源泄漏 - ConnMaxLifetimeMinutes int `mapstructure:"conn_max_lifetime_minutes"` - // ConnMaxIdleTimeMinutes: 空闲连接最大存活时间,及时释放不活跃连接 - ConnMaxIdleTimeMinutes int `mapstructure:"conn_max_idle_time_minutes"` -} - -func (d *DatabaseConfig) DSN() string { - // 当密码为空时不包含 password 参数,避免 libpq 解析错误 - if d.Password == "" { - return fmt.Sprintf( - "host=%s port=%d user=%s dbname=%s sslmode=%s", - d.Host, d.Port, d.User, d.DBName, d.SSLMode, - ) - } - return fmt.Sprintf( - "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", - d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode, - ) -} - -// DSNWithTimezone returns DSN with timezone setting -func (d *DatabaseConfig) DSNWithTimezone(tz string) string { - if tz == "" { - tz = "Asia/Shanghai" - } - // 当密码为空时不包含 password 参数,避免 libpq 解析错误 - if d.Password == "" { - return fmt.Sprintf( - "host=%s port=%d user=%s dbname=%s sslmode=%s TimeZone=%s", - d.Host, d.Port, d.User, d.DBName, d.SSLMode, tz, - ) - } - return fmt.Sprintf( - "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s", - d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode, tz, - ) -} - -// RedisConfig Redis 连接配置 -// 性能优化:新增连接池和超时参数,提升高并发场景下的吞吐量 -type RedisConfig struct { - Host string `mapstructure:"host"` - Port int `mapstructure:"port"` - Password string `mapstructure:"password"` - DB int `mapstructure:"db"` - // 连接池与超时配置(性能优化:可配置化连接池参数) - // DialTimeoutSeconds: 建立连接超时,防止慢连接阻塞 - DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"` - // ReadTimeoutSeconds: 读取超时,避免慢查询阻塞连接池 - ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"` - // WriteTimeoutSeconds: 写入超时,避免慢写入阻塞连接池 - WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"` - // PoolSize: 连接池大小,控制最大并发连接数 - PoolSize int `mapstructure:"pool_size"` - // MinIdleConns: 最小空闲连接数,保持热连接减少冷启动延迟 - MinIdleConns int `mapstructure:"min_idle_conns"` - // EnableTLS: 是否启用 TLS/SSL 连接 - EnableTLS bool `mapstructure:"enable_tls"` -} - -func (r *RedisConfig) Address() string { - return fmt.Sprintf("%s:%d", r.Host, r.Port) -} - -type OpsConfig struct { - // Enabled controls whether ops features should run. - // - // NOTE: vNext still has a DB-backed feature flag (ops_monitoring_enabled) for runtime on/off. - // This config flag is the "hard switch" for deployments that want to disable ops completely. - Enabled bool `mapstructure:"enabled"` - - // UsePreaggregatedTables prefers ops_metrics_hourly/daily for long-window dashboard queries. - UsePreaggregatedTables bool `mapstructure:"use_preaggregated_tables"` - - // Cleanup controls periodic deletion of old ops data to prevent unbounded growth. - Cleanup OpsCleanupConfig `mapstructure:"cleanup"` - - // MetricsCollectorCache controls Redis caching for expensive per-window collector queries. - MetricsCollectorCache OpsMetricsCollectorCacheConfig `mapstructure:"metrics_collector_cache"` - - // Pre-aggregation configuration. - Aggregation OpsAggregationConfig `mapstructure:"aggregation"` -} - -type OpsCleanupConfig struct { - Enabled bool `mapstructure:"enabled"` - Schedule string `mapstructure:"schedule"` - - // Retention days (0 disables that cleanup target). - // - // vNext requirement: default 30 days across ops datasets. - ErrorLogRetentionDays int `mapstructure:"error_log_retention_days"` - MinuteMetricsRetentionDays int `mapstructure:"minute_metrics_retention_days"` - HourlyMetricsRetentionDays int `mapstructure:"hourly_metrics_retention_days"` -} - -type OpsAggregationConfig struct { - Enabled bool `mapstructure:"enabled"` -} - -type OpsMetricsCollectorCacheConfig struct { - Enabled bool `mapstructure:"enabled"` - TTL time.Duration `mapstructure:"ttl"` -} - -type JWTConfig struct { - Secret string `mapstructure:"secret"` - ExpireHour int `mapstructure:"expire_hour"` - // AccessTokenExpireMinutes: Access Token有效期(分钟) - // - >0: 使用分钟配置(优先级高于 ExpireHour) - // - =0: 回退使用 ExpireHour(向后兼容旧配置) - AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"` - // RefreshTokenExpireDays: Refresh Token有效期(天),默认30天 - RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"` - // RefreshWindowMinutes: 刷新窗口(分钟),在Access Token过期前多久开始允许刷新 - RefreshWindowMinutes int `mapstructure:"refresh_window_minutes"` -} - -// TotpConfig TOTP 双因素认证配置 -type TotpConfig struct { - // EncryptionKey 用于加密 TOTP 密钥的 AES-256 密钥(32 字节 hex 编码) - // 如果为空,将自动生成一个随机密钥(仅适用于开发环境) - EncryptionKey string `mapstructure:"encryption_key"` - // EncryptionKeyConfigured 标记加密密钥是否为手动配置(非自动生成) - // 只有手动配置了密钥才允许在管理后台启用 TOTP 功能 - EncryptionKeyConfigured bool `mapstructure:"-"` -} - -type TurnstileConfig struct { - Required bool `mapstructure:"required"` -} - -type DefaultConfig struct { - AdminEmail string `mapstructure:"admin_email"` - AdminPassword string `mapstructure:"admin_password"` - UserConcurrency int `mapstructure:"user_concurrency"` - UserBalance float64 `mapstructure:"user_balance"` - APIKeyPrefix string `mapstructure:"api_key_prefix"` - RateMultiplier float64 `mapstructure:"rate_multiplier"` -} - -type RateLimitConfig struct { - OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"` // 529过载冷却时间(分钟) - OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"` // OAuth 401临时不可调度冷却(分钟) -} - -// APIKeyAuthCacheConfig API Key 认证缓存配置 -type APIKeyAuthCacheConfig struct { - L1Size int `mapstructure:"l1_size"` - L1TTLSeconds int `mapstructure:"l1_ttl_seconds"` - L2TTLSeconds int `mapstructure:"l2_ttl_seconds"` - NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"` - JitterPercent int `mapstructure:"jitter_percent"` - Singleflight bool `mapstructure:"singleflight"` -} - -// SubscriptionCacheConfig 订阅认证 L1 缓存配置 -type SubscriptionCacheConfig struct { - L1Size int `mapstructure:"l1_size"` - L1TTLSeconds int `mapstructure:"l1_ttl_seconds"` - JitterPercent int `mapstructure:"jitter_percent"` -} - -// SubscriptionMaintenanceConfig 订阅窗口维护后台任务配置。 -// 用于将“请求路径触发的维护动作”有界化,避免高并发下 goroutine 膨胀。 -type SubscriptionMaintenanceConfig struct { - WorkerCount int `mapstructure:"worker_count"` - QueueSize int `mapstructure:"queue_size"` -} - -// DashboardCacheConfig 仪表盘统计缓存配置 -type DashboardCacheConfig struct { - // Enabled: 是否启用仪表盘缓存 - Enabled bool `mapstructure:"enabled"` - // KeyPrefix: Redis key 前缀,用于多环境隔离 - KeyPrefix string `mapstructure:"key_prefix"` - // StatsFreshTTLSeconds: 缓存命中认为“新鲜”的时间窗口(秒) - StatsFreshTTLSeconds int `mapstructure:"stats_fresh_ttl_seconds"` - // StatsTTLSeconds: Redis 缓存总 TTL(秒) - StatsTTLSeconds int `mapstructure:"stats_ttl_seconds"` - // StatsRefreshTimeoutSeconds: 异步刷新超时(秒) - StatsRefreshTimeoutSeconds int `mapstructure:"stats_refresh_timeout_seconds"` -} - -// DashboardAggregationConfig 仪表盘预聚合配置 -type DashboardAggregationConfig struct { - // Enabled: 是否启用预聚合作业 - Enabled bool `mapstructure:"enabled"` - // IntervalSeconds: 聚合刷新间隔(秒) - IntervalSeconds int `mapstructure:"interval_seconds"` - // LookbackSeconds: 回看窗口(秒) - LookbackSeconds int `mapstructure:"lookback_seconds"` - // BackfillEnabled: 是否允许全量回填 - BackfillEnabled bool `mapstructure:"backfill_enabled"` - // BackfillMaxDays: 回填最大跨度(天) - BackfillMaxDays int `mapstructure:"backfill_max_days"` - // Retention: 各表保留窗口(天) - Retention DashboardAggregationRetentionConfig `mapstructure:"retention"` - // RecomputeDays: 启动时重算最近 N 天 - RecomputeDays int `mapstructure:"recompute_days"` -} - -// DashboardAggregationRetentionConfig 预聚合保留窗口 -type DashboardAggregationRetentionConfig struct { - UsageLogsDays int `mapstructure:"usage_logs_days"` - UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"` - HourlyDays int `mapstructure:"hourly_days"` - DailyDays int `mapstructure:"daily_days"` -} - -// UsageCleanupConfig 使用记录清理任务配置 -type UsageCleanupConfig struct { - // Enabled: 是否启用清理任务执行器 - Enabled bool `mapstructure:"enabled"` - // MaxRangeDays: 单次任务允许的最大时间跨度(天) - MaxRangeDays int `mapstructure:"max_range_days"` - // BatchSize: 单批删除数量 - BatchSize int `mapstructure:"batch_size"` - // WorkerIntervalSeconds: 后台任务轮询间隔(秒) - WorkerIntervalSeconds int `mapstructure:"worker_interval_seconds"` - // TaskTimeoutSeconds: 单次任务最大执行时长(秒) - TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"` -} - func NormalizeRunMode(value string) string { normalized := strings.ToLower(strings.TrimSpace(value)) switch normalized { @@ -1018,47 +96,30 @@ func NormalizeRunMode(value string) string { } // Load 读取并校验完整配置(要求 jwt.secret 已显式提供)。 -func Load() (*Config, error) { - return load(false) -} +func Load() (*Config, error) { return load(false) } -// LoadForBootstrap 读取启动阶段配置。 -// -// 启动阶段允许 jwt.secret 先留空,后续由数据库初始化流程补齐并再次完整校验。 -func LoadForBootstrap() (*Config, error) { - return load(true) -} +// LoadForBootstrap 读取启动阶段配置。允许 jwt.secret 先留空。 +func LoadForBootstrap() (*Config, error) { return load(true) } func load(allowMissingJWTSecret bool) (*Config, error) { viper.SetConfigName("config") viper.SetConfigType("yaml") // Add config paths in priority order - // 1. DATA_DIR environment variable (highest priority) - if dataDir := os.Getenv("DATA_DIR"); dataDir != "" { - viper.AddConfigPath(dataDir) - } - // 2. Docker data directory + if dataDir := os.Getenv("DATA_DIR"); dataDir != "" { viper.AddConfigPath(dataDir) } viper.AddConfigPath("/app/data") - // 3. Current directory viper.AddConfigPath(".") - // 4. Config subdirectory viper.AddConfigPath("./config") - // 5. System config directory viper.AddConfigPath("/etc/sub2api") - // 环境变量支持 viper.AutomaticEnv() viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) - - // 默认值 setDefaults() if err := viper.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); !ok { return nil, fmt.Errorf("read config error: %w", err) } - // 配置文件不存在时使用默认值 } var cfg Config @@ -1068,10 +129,53 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.RunMode = NormalizeRunMode(cfg.RunMode) cfg.Server.Mode = strings.ToLower(strings.TrimSpace(cfg.Server.Mode)) - if cfg.Server.Mode == "" { - cfg.Server.Mode = "debug" - } + if cfg.Server.Mode == "" { cfg.Server.Mode = "debug" } cfg.Server.FrontendURL = strings.TrimSpace(cfg.Server.FrontendURL) + normalizeAllStringFields(&cfg) + + if err := loadCodexTemplate(&cfg); err != nil { return nil, err } + + // 兼容旧键 sticky_previous_response_ttl_seconds + if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { + cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds + } + + // Normalize UMQ mode + if m := cfg.Gateway.UserMessageQueue.Mode; m != "" && m != UMQModeSerialize && m != UMQModeThrottle { + slog.Warn("invalid user_message_queue mode, disabling", "mode", m, + "valid_modes", []string{UMQModeSerialize, UMQModeThrottle}) + cfg.Gateway.UserMessageQueue.Mode = "" + } + + // Auto-generate TOTP encryption key if not set + cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey) + if cfg.Totp.EncryptionKey == "" { + key, err := generateJWTSecret(32) + if err != nil { return nil, fmt.Errorf("generate totp encryption key error: %w", err) } + cfg.Totp.EncryptionKey = key + cfg.Totp.EncryptionKeyConfigured = false + slog.Warn("TOTP encryption key auto-generated. Consider setting a fixed key for production.") + } else { + cfg.Totp.EncryptionKeyConfigured = true + } + + originalJWTSecret := cfg.JWT.Secret + if allowMissingJWTSecret && originalJWTSecret == "" { + cfg.JWT.Secret = strings.Repeat("0", 32) + } + + if err := cfg.Validate(); err != nil { + return nil, fmt.Errorf("validate config error: %w", err) + } + + if allowMissingJWTSecret && originalJWTSecret == "" { cfg.JWT.Secret = "" } + + logSecurityWarnings(&cfg) + return &cfg, nil +} + +// normalizeAllStringFields trims all string fields loaded from config. +func normalizeAllStringFields(cfg *Config) { cfg.JWT.Secret = strings.TrimSpace(cfg.JWT.Secret) cfg.LinuxDo.ClientID = strings.TrimSpace(cfg.LinuxDo.ClientID) cfg.LinuxDo.ClientSecret = strings.TrimSpace(cfg.LinuxDo.ClientSecret) @@ -1114,1383 +218,32 @@ func load(allowMissingJWTSecret bool) (*Config, error) { cfg.Log.StacktraceLevel = strings.ToLower(strings.TrimSpace(cfg.Log.StacktraceLevel)) cfg.Log.Output.FilePath = strings.TrimSpace(cfg.Log.Output.FilePath) cfg.Gateway.ForcedCodexInstructionsTemplateFile = strings.TrimSpace(cfg.Gateway.ForcedCodexInstructionsTemplateFile) +} + +func loadCodexTemplate(cfg *Config) error { if cfg.Gateway.ForcedCodexInstructionsTemplateFile != "" { content, err := os.ReadFile(cfg.Gateway.ForcedCodexInstructionsTemplateFile) if err != nil { - return nil, fmt.Errorf("read forced codex instructions template %q: %w", cfg.Gateway.ForcedCodexInstructionsTemplateFile, err) + return fmt.Errorf("read forced codex instructions template %q: %w", cfg.Gateway.ForcedCodexInstructionsTemplateFile, err) } cfg.Gateway.ForcedCodexInstructionsTemplate = string(content) } + return nil +} - // 兼容旧键 gateway.openai_ws.sticky_previous_response_ttl_seconds。 - // 新键未配置(<=0)时回退旧键;新键优先。 - if cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { - cfg.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = cfg.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds - } - - // Normalize UMQ mode: 白名单校验,非法值在加载时一次性 warn 并清空 - if m := cfg.Gateway.UserMessageQueue.Mode; m != "" && m != UMQModeSerialize && m != UMQModeThrottle { - slog.Warn("invalid user_message_queue mode, disabling", - "mode", m, - "valid_modes", []string{UMQModeSerialize, UMQModeThrottle}) - cfg.Gateway.UserMessageQueue.Mode = "" - } - - // Auto-generate TOTP encryption key if not set (32 bytes = 64 hex chars for AES-256) - cfg.Totp.EncryptionKey = strings.TrimSpace(cfg.Totp.EncryptionKey) - if cfg.Totp.EncryptionKey == "" { - key, err := generateJWTSecret(32) // Reuse the same random generation function - if err != nil { - return nil, fmt.Errorf("generate totp encryption key error: %w", err) - } - cfg.Totp.EncryptionKey = key - cfg.Totp.EncryptionKeyConfigured = false - slog.Warn("TOTP encryption key auto-generated. Consider setting a fixed key for production.") - } else { - cfg.Totp.EncryptionKeyConfigured = true - } - - originalJWTSecret := cfg.JWT.Secret - if allowMissingJWTSecret && originalJWTSecret == "" { - // 启动阶段允许先无 JWT 密钥,后续在数据库初始化后补齐。 - cfg.JWT.Secret = strings.Repeat("0", 32) - } - - if err := cfg.Validate(); err != nil { - return nil, fmt.Errorf("validate config error: %w", err) - } - - if allowMissingJWTSecret && originalJWTSecret == "" { - cfg.JWT.Secret = "" - } - +func logSecurityWarnings(cfg *Config) { if !cfg.Security.URLAllowlist.Enabled { slog.Warn("security.url_allowlist.enabled=false; allowlist/SSRF checks disabled (minimal format validation only).") } if !cfg.Security.ResponseHeaders.Enabled { slog.Warn("security.response_headers.enabled=false; configurable header filtering disabled (default allowlist only).") } - if cfg.JWT.Secret != "" && isWeakJWTSecret(cfg.JWT.Secret) { slog.Warn("JWT secret appears weak; use a 32+ character random secret in production.") } if len(cfg.Security.ResponseHeaders.AdditionalAllowed) > 0 || len(cfg.Security.ResponseHeaders.ForceRemove) > 0 { slog.Info("response header policy configured", "additional_allowed", cfg.Security.ResponseHeaders.AdditionalAllowed, - "force_remove", cfg.Security.ResponseHeaders.ForceRemove, - ) - } - - return &cfg, nil -} - -func setDefaults() { - viper.SetDefault("run_mode", RunModeStandard) - - // Server - viper.SetDefault("server.host", "0.0.0.0") - viper.SetDefault("server.port", 8080) - viper.SetDefault("server.mode", "release") - viper.SetDefault("server.frontend_url", "") - viper.SetDefault("server.read_header_timeout", 30) // 30秒读取请求头 - viper.SetDefault("server.idle_timeout", 120) // 120秒空闲超时 - viper.SetDefault("server.trusted_proxies", []string{}) - viper.SetDefault("server.max_request_body_size", int64(256*1024*1024)) - // H2C 默认配置 - viper.SetDefault("server.h2c.enabled", false) - viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) // 50 个并发流 - viper.SetDefault("server.h2c.idle_timeout", 75) // 75 秒 - viper.SetDefault("server.h2c.max_read_frame_size", 1<<20) // 1MB(够用) - viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) // 2MB - viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) // 512KB - - // Log - viper.SetDefault("log.level", "info") - viper.SetDefault("log.format", "console") - viper.SetDefault("log.service_name", "sub2api") - viper.SetDefault("log.env", "production") - viper.SetDefault("log.caller", true) - viper.SetDefault("log.stacktrace_level", "error") - viper.SetDefault("log.output.to_stdout", true) - viper.SetDefault("log.output.to_file", true) - viper.SetDefault("log.output.file_path", "") - viper.SetDefault("log.rotation.max_size_mb", 100) - viper.SetDefault("log.rotation.max_backups", 10) - viper.SetDefault("log.rotation.max_age_days", 7) - viper.SetDefault("log.rotation.compress", true) - viper.SetDefault("log.rotation.local_time", true) - viper.SetDefault("log.sampling.enabled", false) - viper.SetDefault("log.sampling.initial", 100) - viper.SetDefault("log.sampling.thereafter", 100) - - // CORS - viper.SetDefault("cors.allowed_origins", []string{}) - viper.SetDefault("cors.allow_credentials", true) - - // Security - viper.SetDefault("security.url_allowlist.enabled", false) - viper.SetDefault("security.url_allowlist.upstream_hosts", []string{ - // 国际模型 - "api.openai.com", - "api.anthropic.com", - "generativelanguage.googleapis.com", - "cloudcode-pa.googleapis.com", - "*.openai.azure.com", - // 国内模型 - 月之暗面Kimi - "api.kimi.com", - "api.moonshot.cn", - // 国内模型 - 智谱GLM - "open.bigmodel.cn", - "bigmodel.cn", - // 国内模型 - MiniMax - "api.minimaxi.com", - "minimaxi.com", - // 国内模型 - 阿里云通义千问 - "dashscope.aliyuncs.com", - "dashscope.aliyun.com", - // 国内模型 - 豆包/火山引擎 - "ark.cn-beijing.volces.com", - "ark-api.volces.com", - "api.volcengine.com", - // 国内模型 - DeepSeek - "api.deepseek.com", - // 国内模型 - 百度文心 - "aip.baidubce.com", - // 国内模型 - 讯飞星火 - "spark-api-open.xf-yun.com", - // 国内模型 - 腾讯混元 - "hunyuan.tencentcloudapi.com", - // 国内模型 - 零一万物 - "api.lingyiwanwu.com", - // 国内模型 - 百川智能 - "api.baichuan-ai.com", - // 国内模型 - 硅基流动SiliconFlow - "api.siliconflow.cn", - // 国内模型 - 智谱API域名(国际) - "api.z.ai", - // 国内模型 - Groq (加速推理) - "api.groq.com", - }) - viper.SetDefault("security.url_allowlist.pricing_hosts", []string{ - "raw.githubusercontent.com", - }) - viper.SetDefault("security.url_allowlist.crs_hosts", []string{}) - viper.SetDefault("security.url_allowlist.allow_private_hosts", true) - viper.SetDefault("security.url_allowlist.allow_insecure_http", true) - viper.SetDefault("security.response_headers.enabled", true) - viper.SetDefault("security.response_headers.additional_allowed", []string{}) - viper.SetDefault("security.response_headers.force_remove", []string{}) - viper.SetDefault("security.csp.enabled", true) - viper.SetDefault("security.csp.policy", DefaultCSPPolicy) - viper.SetDefault("security.proxy_probe.insecure_skip_verify", false) - - // Security - disable direct fallback on proxy error - viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) - - // Billing - viper.SetDefault("billing.circuit_breaker.enabled", true) - viper.SetDefault("billing.circuit_breaker.failure_threshold", 5) - viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30) - viper.SetDefault("billing.circuit_breaker.half_open_requests", 3) - - // Turnstile - viper.SetDefault("turnstile.required", false) - - // LinuxDo Connect OAuth 登录 - viper.SetDefault("linuxdo_connect.enabled", false) - viper.SetDefault("linuxdo_connect.client_id", "") - viper.SetDefault("linuxdo_connect.client_secret", "") - viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize") - viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token") - viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user") - viper.SetDefault("linuxdo_connect.scopes", "user") - viper.SetDefault("linuxdo_connect.redirect_url", "") - viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback") - viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post") - viper.SetDefault("linuxdo_connect.use_pkce", false) - viper.SetDefault("linuxdo_connect.userinfo_email_path", "") - viper.SetDefault("linuxdo_connect.userinfo_id_path", "") - viper.SetDefault("linuxdo_connect.userinfo_username_path", "") - - // Generic OIDC OAuth 登录 - viper.SetDefault("oidc_connect.enabled", false) - viper.SetDefault("oidc_connect.provider_name", "OIDC") - viper.SetDefault("oidc_connect.client_id", "") - viper.SetDefault("oidc_connect.client_secret", "") - viper.SetDefault("oidc_connect.issuer_url", "") - viper.SetDefault("oidc_connect.discovery_url", "") - viper.SetDefault("oidc_connect.authorize_url", "") - viper.SetDefault("oidc_connect.token_url", "") - viper.SetDefault("oidc_connect.userinfo_url", "") - viper.SetDefault("oidc_connect.jwks_url", "") - viper.SetDefault("oidc_connect.scopes", "openid email profile") - viper.SetDefault("oidc_connect.redirect_url", "") - viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback") - viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post") - viper.SetDefault("oidc_connect.use_pkce", false) - viper.SetDefault("oidc_connect.validate_id_token", true) - viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256") - viper.SetDefault("oidc_connect.clock_skew_seconds", 120) - viper.SetDefault("oidc_connect.require_email_verified", false) - viper.SetDefault("oidc_connect.userinfo_email_path", "") - viper.SetDefault("oidc_connect.userinfo_id_path", "") - viper.SetDefault("oidc_connect.userinfo_username_path", "") - - // Database - viper.SetDefault("database.host", "localhost") - viper.SetDefault("database.port", 5432) - viper.SetDefault("database.user", "postgres") - viper.SetDefault("database.password", "postgres") - viper.SetDefault("database.dbname", "sub2api") - viper.SetDefault("database.sslmode", "prefer") - viper.SetDefault("database.max_open_conns", 256) - viper.SetDefault("database.max_idle_conns", 128) - viper.SetDefault("database.conn_max_lifetime_minutes", 30) - viper.SetDefault("database.conn_max_idle_time_minutes", 5) - - // Redis - viper.SetDefault("redis.host", "localhost") - viper.SetDefault("redis.port", 6379) - viper.SetDefault("redis.password", "") - viper.SetDefault("redis.db", 0) - viper.SetDefault("redis.dial_timeout_seconds", 5) - viper.SetDefault("redis.read_timeout_seconds", 3) - viper.SetDefault("redis.write_timeout_seconds", 3) - viper.SetDefault("redis.pool_size", 1024) - viper.SetDefault("redis.min_idle_conns", 128) - viper.SetDefault("redis.enable_tls", false) - - // Ops (vNext) - viper.SetDefault("ops.enabled", true) - viper.SetDefault("ops.use_preaggregated_tables", true) - viper.SetDefault("ops.cleanup.enabled", true) - viper.SetDefault("ops.cleanup.schedule", "0 2 * * *") - // Retention days: vNext defaults to 30 days across ops datasets. - viper.SetDefault("ops.cleanup.error_log_retention_days", 30) - viper.SetDefault("ops.cleanup.minute_metrics_retention_days", 30) - viper.SetDefault("ops.cleanup.hourly_metrics_retention_days", 30) - viper.SetDefault("ops.aggregation.enabled", true) - viper.SetDefault("ops.metrics_collector_cache.enabled", true) - // TTL should be slightly larger than collection interval (1m) to maximize cross-replica cache hits. - viper.SetDefault("ops.metrics_collector_cache.ttl", 65*time.Second) - - // JWT - viper.SetDefault("jwt.secret", "") - viper.SetDefault("jwt.expire_hour", 24) - viper.SetDefault("jwt.access_token_expire_minutes", 0) // 0 表示回退到 expire_hour - viper.SetDefault("jwt.refresh_token_expire_days", 30) // 30天Refresh Token有效期 - viper.SetDefault("jwt.refresh_window_minutes", 2) // 过期前2分钟开始允许刷新 - - // TOTP - viper.SetDefault("totp.encryption_key", "") - - // Default - // Admin credentials are created via the setup flow (web wizard / CLI / AUTO_SETUP). - // Do not ship fixed defaults here to avoid insecure "known credentials" in production. - viper.SetDefault("default.admin_email", "") - viper.SetDefault("default.admin_password", "") - viper.SetDefault("default.user_concurrency", 5) - viper.SetDefault("default.user_balance", 0) - viper.SetDefault("default.api_key_prefix", "sk-") - viper.SetDefault("default.rate_multiplier", 1.0) - - // RateLimit - viper.SetDefault("rate_limit.overload_cooldown_minutes", 10) - viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10) - - // Pricing - 从 model-price-repo 同步模型定价和上下文窗口数据(固定到 commit,避免分支漂移) - viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.json") - viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.sha256") - viper.SetDefault("pricing.data_dir", "./data") - viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json") - viper.SetDefault("pricing.update_interval_hours", 24) - viper.SetDefault("pricing.hash_check_interval_minutes", 10) - - // Timezone (default to Asia/Shanghai for Chinese users) - viper.SetDefault("timezone", "Asia/Shanghai") - - // API Key auth cache - viper.SetDefault("api_key_auth_cache.l1_size", 65535) - viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15) - viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300) - viper.SetDefault("api_key_auth_cache.negative_ttl_seconds", 30) - viper.SetDefault("api_key_auth_cache.jitter_percent", 10) - viper.SetDefault("api_key_auth_cache.singleflight", true) - - // Subscription auth L1 cache - viper.SetDefault("subscription_cache.l1_size", 16384) - viper.SetDefault("subscription_cache.l1_ttl_seconds", 10) - viper.SetDefault("subscription_cache.jitter_percent", 10) - - // Dashboard cache - viper.SetDefault("dashboard_cache.enabled", true) - viper.SetDefault("dashboard_cache.key_prefix", "sub2api:") - viper.SetDefault("dashboard_cache.stats_fresh_ttl_seconds", 15) - viper.SetDefault("dashboard_cache.stats_ttl_seconds", 30) - viper.SetDefault("dashboard_cache.stats_refresh_timeout_seconds", 30) - - // Dashboard aggregation - viper.SetDefault("dashboard_aggregation.enabled", true) - viper.SetDefault("dashboard_aggregation.interval_seconds", 60) - viper.SetDefault("dashboard_aggregation.lookback_seconds", 120) - viper.SetDefault("dashboard_aggregation.backfill_enabled", false) - viper.SetDefault("dashboard_aggregation.backfill_max_days", 31) - viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90) - viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365) - viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180) - viper.SetDefault("dashboard_aggregation.retention.daily_days", 730) - viper.SetDefault("dashboard_aggregation.recompute_days", 2) - - // Usage cleanup task - viper.SetDefault("usage_cleanup.enabled", true) - viper.SetDefault("usage_cleanup.max_range_days", 31) - viper.SetDefault("usage_cleanup.batch_size", 5000) - viper.SetDefault("usage_cleanup.worker_interval_seconds", 10) - viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800) - - // Idempotency - viper.SetDefault("idempotency.observe_only", true) - viper.SetDefault("idempotency.default_ttl_seconds", 86400) - viper.SetDefault("idempotency.system_operation_ttl_seconds", 3600) - viper.SetDefault("idempotency.processing_timeout_seconds", 30) - viper.SetDefault("idempotency.failed_retry_backoff_seconds", 5) - viper.SetDefault("idempotency.max_stored_response_len", 64*1024) - viper.SetDefault("idempotency.cleanup_interval_seconds", 60) - viper.SetDefault("idempotency.cleanup_batch_size", 500) - - // Gateway - viper.SetDefault("gateway.response_header_timeout", 600) // 600秒(10分钟)等待上游响应头,LLM高负载时可能排队较久 - viper.SetDefault("gateway.log_upstream_error_body", true) - viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) - viper.SetDefault("gateway.inject_beta_for_apikey", false) - viper.SetDefault("gateway.failover_on_400", false) - viper.SetDefault("gateway.max_account_switches", 10) - viper.SetDefault("gateway.max_account_switches_gemini", 3) - viper.SetDefault("gateway.force_codex_cli", false) - viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) - // OpenAI Responses WebSocket(默认开启;可通过 force_http 紧急回滚) - viper.SetDefault("gateway.openai_ws.enabled", true) - viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false) - viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool") - viper.SetDefault("gateway.openai_ws.oauth_enabled", true) - viper.SetDefault("gateway.openai_ws.apikey_enabled", true) - viper.SetDefault("gateway.openai_ws.force_http", false) - viper.SetDefault("gateway.openai_ws.allow_store_recovery", false) - viper.SetDefault("gateway.openai_ws.ingress_previous_response_recovery_enabled", true) - viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict") - viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true) - viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false) - viper.SetDefault("gateway.openai_ws.responses_websockets", false) - viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true) - viper.SetDefault("gateway.openai_ws.max_conns_per_account", 128) - viper.SetDefault("gateway.openai_ws.min_idle_per_account", 4) - viper.SetDefault("gateway.openai_ws.max_idle_per_account", 12) - viper.SetDefault("gateway.openai_ws.dynamic_max_conns_by_account_concurrency_enabled", true) - viper.SetDefault("gateway.openai_ws.oauth_max_conns_factor", 1.0) - viper.SetDefault("gateway.openai_ws.apikey_max_conns_factor", 1.0) - viper.SetDefault("gateway.openai_ws.dial_timeout_seconds", 10) - viper.SetDefault("gateway.openai_ws.read_timeout_seconds", 900) - viper.SetDefault("gateway.openai_ws.write_timeout_seconds", 120) - viper.SetDefault("gateway.openai_ws.pool_target_utilization", 0.7) - viper.SetDefault("gateway.openai_ws.queue_limit_per_conn", 64) - viper.SetDefault("gateway.openai_ws.event_flush_batch_size", 1) - viper.SetDefault("gateway.openai_ws.event_flush_interval_ms", 10) - viper.SetDefault("gateway.openai_ws.prewarm_cooldown_ms", 300) - viper.SetDefault("gateway.openai_ws.fallback_cooldown_seconds", 30) - viper.SetDefault("gateway.openai_ws.retry_backoff_initial_ms", 120) - viper.SetDefault("gateway.openai_ws.retry_backoff_max_ms", 2000) - viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2) - viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000) - viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2) - viper.SetDefault("gateway.openai_ws.lb_top_k", 7) - viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600) - viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true) - viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true) - viper.SetDefault("gateway.openai_ws.metadata_bridge_enabled", true) - viper.SetDefault("gateway.openai_ws.sticky_response_id_ttl_seconds", 3600) - viper.SetDefault("gateway.openai_ws.sticky_previous_response_ttl_seconds", 3600) - viper.SetDefault("gateway.openai_ws.scheduler_score_weights.priority", 1.0) - viper.SetDefault("gateway.openai_ws.scheduler_score_weights.load", 1.0) - viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7) - viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8) - viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5) - viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) - viper.SetDefault("gateway.antigravity_extra_retries", 10) - viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) - viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) - viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) - viper.SetDefault("gateway.gemini_debug_response_headers", false) - viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) - // HTTP 上游连接池配置(针对 5000+ 并发用户优化) - viper.SetDefault("gateway.max_idle_conns", 2560) // 最大空闲连接总数(高并发场景可调大) - viper.SetDefault("gateway.max_idle_conns_per_host", 120) // 每主机最大空闲连接(HTTP/2 场景默认) - viper.SetDefault("gateway.max_conns_per_host", 1024) // 每主机最大连接数(含活跃;流式/HTTP1.1 场景可调大,如 2400+) - viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) // 空闲连接超时(秒) - viper.SetDefault("gateway.max_upstream_clients", 5000) - viper.SetDefault("gateway.client_idle_ttl_seconds", 900) - viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) // 并发槽位过期时间(支持超长请求) - viper.SetDefault("gateway.stream_data_interval_timeout", 180) - viper.SetDefault("gateway.stream_keepalive_interval", 10) - viper.SetDefault("gateway.max_line_size", 500*1024*1024) - viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) - viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second) - viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) - viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) - viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used") - viper.SetDefault("gateway.scheduling.load_batch_enabled", true) - viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128) - viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256) - viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) - viper.SetDefault("gateway.scheduling.db_fallback_enabled", true) - viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0) - viper.SetDefault("gateway.scheduling.db_fallback_max_qps", 0) - viper.SetDefault("gateway.scheduling.outbox_poll_interval_seconds", 1) - viper.SetDefault("gateway.scheduling.outbox_lag_warn_seconds", 5) - viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_seconds", 10) - viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3) - viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000) - viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300) - viper.SetDefault("gateway.usage_record.worker_count", 128) - viper.SetDefault("gateway.usage_record.queue_size", 16384) - viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5) - viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample) - viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10) - viper.SetDefault("gateway.usage_record.auto_scale_enabled", true) - viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128) - viper.SetDefault("gateway.usage_record.auto_scale_max_workers", 512) - viper.SetDefault("gateway.usage_record.auto_scale_up_queue_percent", 70) - viper.SetDefault("gateway.usage_record.auto_scale_down_queue_percent", 15) - viper.SetDefault("gateway.usage_record.auto_scale_up_step", 32) - viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16) - viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3) - viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10) - viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30) - viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15) - // TLS指纹伪装配置(默认关闭,需要账号级别单独启用) - // 用户消息串行队列默认值 - viper.SetDefault("gateway.user_message_queue.enabled", false) - viper.SetDefault("gateway.user_message_queue.lock_ttl_ms", 120000) - viper.SetDefault("gateway.user_message_queue.wait_timeout_ms", 30000) - viper.SetDefault("gateway.user_message_queue.min_delay_ms", 200) - viper.SetDefault("gateway.user_message_queue.max_delay_ms", 2000) - viper.SetDefault("gateway.user_message_queue.cleanup_interval_seconds", 60) - - viper.SetDefault("gateway.tls_fingerprint.enabled", true) - viper.SetDefault("concurrency.ping_interval", 10) - - // TokenRefresh - viper.SetDefault("token_refresh.enabled", true) - viper.SetDefault("token_refresh.check_interval_minutes", 5) // 每5分钟检查一次 - viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) // 提前30分钟刷新(适配Google 1小时token) - viper.SetDefault("token_refresh.max_retries", 3) // 最多重试3次 - viper.SetDefault("token_refresh.retry_backoff_seconds", 2) // 重试退避基础2秒 - - // Gemini OAuth - configure via environment variables or config file - // GEMINI_OAUTH_CLIENT_ID and GEMINI_OAUTH_CLIENT_SECRET - // Default: uses Gemini CLI public credentials (set via environment) - viper.SetDefault("gemini.oauth.client_id", "") - viper.SetDefault("gemini.oauth.client_secret", "") - viper.SetDefault("gemini.oauth.scopes", "") - viper.SetDefault("gemini.quota.policy", "") - - // Subscription Maintenance (bounded queue + worker pool) - viper.SetDefault("subscription_maintenance.worker_count", 2) - viper.SetDefault("subscription_maintenance.queue_size", 1024) - -} - -func (c *Config) Validate() error { - jwtSecret := strings.TrimSpace(c.JWT.Secret) - if jwtSecret == "" { - return fmt.Errorf("jwt.secret is required") - } - // NOTE: 按 UTF-8 编码后的字节长度计算。 - // 选择 bytes 而不是 rune 计数,确保二进制/随机串的长度语义更接近“熵”而非“字符数”。 - if len([]byte(jwtSecret)) < 32 { - return fmt.Errorf("jwt.secret must be at least 32 bytes") - } - switch c.Log.Level { - case "debug", "info", "warn", "error": - case "": - return fmt.Errorf("log.level is required") - default: - return fmt.Errorf("log.level must be one of: debug/info/warn/error") - } - switch c.Log.Format { - case "json", "console": - case "": - return fmt.Errorf("log.format is required") - default: - return fmt.Errorf("log.format must be one of: json/console") - } - switch c.Log.StacktraceLevel { - case "none", "error", "fatal": - case "": - return fmt.Errorf("log.stacktrace_level is required") - default: - return fmt.Errorf("log.stacktrace_level must be one of: none/error/fatal") - } - if !c.Log.Output.ToStdout && !c.Log.Output.ToFile { - return fmt.Errorf("log.output.to_stdout and log.output.to_file cannot both be false") - } - if c.Log.Rotation.MaxSizeMB <= 0 { - return fmt.Errorf("log.rotation.max_size_mb must be positive") - } - if c.Log.Rotation.MaxBackups < 0 { - return fmt.Errorf("log.rotation.max_backups must be non-negative") - } - if c.Log.Rotation.MaxAgeDays < 0 { - return fmt.Errorf("log.rotation.max_age_days must be non-negative") - } - if c.Log.Sampling.Enabled { - if c.Log.Sampling.Initial <= 0 { - return fmt.Errorf("log.sampling.initial must be positive when sampling is enabled") - } - if c.Log.Sampling.Thereafter <= 0 { - return fmt.Errorf("log.sampling.thereafter must be positive when sampling is enabled") - } - } else { - if c.Log.Sampling.Initial < 0 { - return fmt.Errorf("log.sampling.initial must be non-negative") - } - if c.Log.Sampling.Thereafter < 0 { - return fmt.Errorf("log.sampling.thereafter must be non-negative") - } - } - - if c.SubscriptionMaintenance.WorkerCount < 0 { - return fmt.Errorf("subscription_maintenance.worker_count must be non-negative") - } - if c.SubscriptionMaintenance.QueueSize < 0 { - return fmt.Errorf("subscription_maintenance.queue_size must be non-negative") - } - - // Gemini OAuth 配置校验:client_id 与 client_secret 必须同时设置或同时留空。 - // 留空时表示使用内置的 Gemini CLI OAuth 客户端(其 client_secret 通过环境变量注入)。 - geminiClientID := strings.TrimSpace(c.Gemini.OAuth.ClientID) - geminiClientSecret := strings.TrimSpace(c.Gemini.OAuth.ClientSecret) - if (geminiClientID == "") != (geminiClientSecret == "") { - return fmt.Errorf("gemini.oauth.client_id and gemini.oauth.client_secret must be both set or both empty") - } - - if strings.TrimSpace(c.Server.FrontendURL) != "" { - if err := ValidateAbsoluteHTTPURL(c.Server.FrontendURL); err != nil { - return fmt.Errorf("server.frontend_url invalid: %w", err) - } - u, err := url.Parse(strings.TrimSpace(c.Server.FrontendURL)) - if err != nil { - return fmt.Errorf("server.frontend_url invalid: %w", err) - } - if u.RawQuery != "" || u.ForceQuery { - return fmt.Errorf("server.frontend_url invalid: must not include query") - } - if u.User != nil { - return fmt.Errorf("server.frontend_url invalid: must not include userinfo") - } - warnIfInsecureURL("server.frontend_url", c.Server.FrontendURL) - } - if c.JWT.ExpireHour <= 0 { - return fmt.Errorf("jwt.expire_hour must be positive") - } - if c.JWT.ExpireHour > 168 { - return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)") - } - if c.JWT.ExpireHour > 24 { - slog.Warn("jwt.expire_hour is high; consider shorter expiration for security", "expire_hour", c.JWT.ExpireHour) - } - // JWT Refresh Token配置验证 - if c.JWT.AccessTokenExpireMinutes < 0 { - return fmt.Errorf("jwt.access_token_expire_minutes must be non-negative") - } - if c.JWT.AccessTokenExpireMinutes > 720 { - slog.Warn("jwt.access_token_expire_minutes is high; consider shorter expiration for security", "access_token_expire_minutes", c.JWT.AccessTokenExpireMinutes) - } - if c.JWT.RefreshTokenExpireDays <= 0 { - return fmt.Errorf("jwt.refresh_token_expire_days must be positive") - } - if c.JWT.RefreshTokenExpireDays > 90 { - slog.Warn("jwt.refresh_token_expire_days is high; consider shorter expiration for security", "refresh_token_expire_days", c.JWT.RefreshTokenExpireDays) - } - if c.JWT.RefreshWindowMinutes < 0 { - return fmt.Errorf("jwt.refresh_window_minutes must be non-negative") - } - if c.Security.CSP.Enabled && strings.TrimSpace(c.Security.CSP.Policy) == "" { - return fmt.Errorf("security.csp.policy is required when CSP is enabled") - } - if c.LinuxDo.Enabled { - if strings.TrimSpace(c.LinuxDo.ClientID) == "" { - return fmt.Errorf("linuxdo_connect.client_id is required when linuxdo_connect.enabled=true") - } - if strings.TrimSpace(c.LinuxDo.AuthorizeURL) == "" { - return fmt.Errorf("linuxdo_connect.authorize_url is required when linuxdo_connect.enabled=true") - } - if strings.TrimSpace(c.LinuxDo.TokenURL) == "" { - return fmt.Errorf("linuxdo_connect.token_url is required when linuxdo_connect.enabled=true") - } - if strings.TrimSpace(c.LinuxDo.UserInfoURL) == "" { - return fmt.Errorf("linuxdo_connect.userinfo_url is required when linuxdo_connect.enabled=true") - } - if strings.TrimSpace(c.LinuxDo.RedirectURL) == "" { - return fmt.Errorf("linuxdo_connect.redirect_url is required when linuxdo_connect.enabled=true") - } - method := strings.ToLower(strings.TrimSpace(c.LinuxDo.TokenAuthMethod)) - switch method { - case "", "client_secret_post", "client_secret_basic", "none": - default: - return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") - } - if method == "none" && !c.LinuxDo.UsePKCE { - return fmt.Errorf("linuxdo_connect.use_pkce must be true when linuxdo_connect.token_auth_method=none") - } - if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && - strings.TrimSpace(c.LinuxDo.ClientSecret) == "" { - return fmt.Errorf("linuxdo_connect.client_secret is required when linuxdo_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") - } - if strings.TrimSpace(c.LinuxDo.FrontendRedirectURL) == "" { - return fmt.Errorf("linuxdo_connect.frontend_redirect_url is required when linuxdo_connect.enabled=true") - } - - if err := ValidateAbsoluteHTTPURL(c.LinuxDo.AuthorizeURL); err != nil { - return fmt.Errorf("linuxdo_connect.authorize_url invalid: %w", err) - } - if err := ValidateAbsoluteHTTPURL(c.LinuxDo.TokenURL); err != nil { - return fmt.Errorf("linuxdo_connect.token_url invalid: %w", err) - } - if err := ValidateAbsoluteHTTPURL(c.LinuxDo.UserInfoURL); err != nil { - return fmt.Errorf("linuxdo_connect.userinfo_url invalid: %w", err) - } - if err := ValidateAbsoluteHTTPURL(c.LinuxDo.RedirectURL); err != nil { - return fmt.Errorf("linuxdo_connect.redirect_url invalid: %w", err) - } - if err := ValidateFrontendRedirectURL(c.LinuxDo.FrontendRedirectURL); err != nil { - return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err) - } - - warnIfInsecureURL("linuxdo_connect.authorize_url", c.LinuxDo.AuthorizeURL) - warnIfInsecureURL("linuxdo_connect.token_url", c.LinuxDo.TokenURL) - warnIfInsecureURL("linuxdo_connect.userinfo_url", c.LinuxDo.UserInfoURL) - warnIfInsecureURL("linuxdo_connect.redirect_url", c.LinuxDo.RedirectURL) - warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", c.LinuxDo.FrontendRedirectURL) - } - if c.OIDC.Enabled { - if strings.TrimSpace(c.OIDC.ClientID) == "" { - return fmt.Errorf("oidc_connect.client_id is required when oidc_connect.enabled=true") - } - if strings.TrimSpace(c.OIDC.IssuerURL) == "" { - return fmt.Errorf("oidc_connect.issuer_url is required when oidc_connect.enabled=true") - } - if strings.TrimSpace(c.OIDC.RedirectURL) == "" { - return fmt.Errorf("oidc_connect.redirect_url is required when oidc_connect.enabled=true") - } - if strings.TrimSpace(c.OIDC.FrontendRedirectURL) == "" { - return fmt.Errorf("oidc_connect.frontend_redirect_url is required when oidc_connect.enabled=true") - } - if !scopeContainsOpenID(c.OIDC.Scopes) { - return fmt.Errorf("oidc_connect.scopes must contain openid") - } - - method := strings.ToLower(strings.TrimSpace(c.OIDC.TokenAuthMethod)) - switch method { - case "", "client_secret_post", "client_secret_basic", "none": - default: - return fmt.Errorf("oidc_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") - } - if method == "none" && !c.OIDC.UsePKCE { - return fmt.Errorf("oidc_connect.use_pkce must be true when oidc_connect.token_auth_method=none") - } - if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && - strings.TrimSpace(c.OIDC.ClientSecret) == "" { - return fmt.Errorf("oidc_connect.client_secret is required when oidc_connect.enabled=true and token_auth_method is client_secret_post/client_secret_basic") - } - if c.OIDC.ClockSkewSeconds < 0 || c.OIDC.ClockSkewSeconds > 600 { - return fmt.Errorf("oidc_connect.clock_skew_seconds must be between 0 and 600") - } - if c.OIDC.ValidateIDToken && strings.TrimSpace(c.OIDC.AllowedSigningAlgs) == "" { - return fmt.Errorf("oidc_connect.allowed_signing_algs is required when oidc_connect.validate_id_token=true") - } - - if err := ValidateAbsoluteHTTPURL(c.OIDC.IssuerURL); err != nil { - return fmt.Errorf("oidc_connect.issuer_url invalid: %w", err) - } - if v := strings.TrimSpace(c.OIDC.DiscoveryURL); v != "" { - if err := ValidateAbsoluteHTTPURL(v); err != nil { - return fmt.Errorf("oidc_connect.discovery_url invalid: %w", err) - } - } - if v := strings.TrimSpace(c.OIDC.AuthorizeURL); v != "" { - if err := ValidateAbsoluteHTTPURL(v); err != nil { - return fmt.Errorf("oidc_connect.authorize_url invalid: %w", err) - } - } - if v := strings.TrimSpace(c.OIDC.TokenURL); v != "" { - if err := ValidateAbsoluteHTTPURL(v); err != nil { - return fmt.Errorf("oidc_connect.token_url invalid: %w", err) - } - } - if v := strings.TrimSpace(c.OIDC.UserInfoURL); v != "" { - if err := ValidateAbsoluteHTTPURL(v); err != nil { - return fmt.Errorf("oidc_connect.userinfo_url invalid: %w", err) - } - } - if v := strings.TrimSpace(c.OIDC.JWKSURL); v != "" { - if err := ValidateAbsoluteHTTPURL(v); err != nil { - return fmt.Errorf("oidc_connect.jwks_url invalid: %w", err) - } - } - if err := ValidateAbsoluteHTTPURL(c.OIDC.RedirectURL); err != nil { - return fmt.Errorf("oidc_connect.redirect_url invalid: %w", err) - } - if err := ValidateFrontendRedirectURL(c.OIDC.FrontendRedirectURL); err != nil { - return fmt.Errorf("oidc_connect.frontend_redirect_url invalid: %w", err) - } - - warnIfInsecureURL("oidc_connect.issuer_url", c.OIDC.IssuerURL) - warnIfInsecureURL("oidc_connect.discovery_url", c.OIDC.DiscoveryURL) - warnIfInsecureURL("oidc_connect.authorize_url", c.OIDC.AuthorizeURL) - warnIfInsecureURL("oidc_connect.token_url", c.OIDC.TokenURL) - warnIfInsecureURL("oidc_connect.userinfo_url", c.OIDC.UserInfoURL) - warnIfInsecureURL("oidc_connect.jwks_url", c.OIDC.JWKSURL) - warnIfInsecureURL("oidc_connect.redirect_url", c.OIDC.RedirectURL) - warnIfInsecureURL("oidc_connect.frontend_redirect_url", c.OIDC.FrontendRedirectURL) - } - if c.Billing.CircuitBreaker.Enabled { - if c.Billing.CircuitBreaker.FailureThreshold <= 0 { - return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive") - } - if c.Billing.CircuitBreaker.ResetTimeoutSeconds <= 0 { - return fmt.Errorf("billing.circuit_breaker.reset_timeout_seconds must be positive") - } - if c.Billing.CircuitBreaker.HalfOpenRequests <= 0 { - return fmt.Errorf("billing.circuit_breaker.half_open_requests must be positive") - } - } - if c.Database.MaxOpenConns <= 0 { - return fmt.Errorf("database.max_open_conns must be positive") - } - if c.Database.MaxIdleConns < 0 { - return fmt.Errorf("database.max_idle_conns must be non-negative") - } - if c.Database.MaxIdleConns > c.Database.MaxOpenConns { - return fmt.Errorf("database.max_idle_conns cannot exceed database.max_open_conns") - } - if c.Database.ConnMaxLifetimeMinutes < 0 { - return fmt.Errorf("database.conn_max_lifetime_minutes must be non-negative") - } - if c.Database.ConnMaxIdleTimeMinutes < 0 { - return fmt.Errorf("database.conn_max_idle_time_minutes must be non-negative") - } - if c.Redis.DialTimeoutSeconds <= 0 { - return fmt.Errorf("redis.dial_timeout_seconds must be positive") - } - if c.Redis.ReadTimeoutSeconds <= 0 { - return fmt.Errorf("redis.read_timeout_seconds must be positive") - } - if c.Redis.WriteTimeoutSeconds <= 0 { - return fmt.Errorf("redis.write_timeout_seconds must be positive") - } - if c.Redis.PoolSize <= 0 { - return fmt.Errorf("redis.pool_size must be positive") - } - if c.Redis.MinIdleConns < 0 { - return fmt.Errorf("redis.min_idle_conns must be non-negative") - } - if c.Redis.MinIdleConns > c.Redis.PoolSize { - return fmt.Errorf("redis.min_idle_conns cannot exceed redis.pool_size") - } - if c.Dashboard.Enabled { - if c.Dashboard.StatsFreshTTLSeconds <= 0 { - return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be positive") - } - if c.Dashboard.StatsTTLSeconds <= 0 { - return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be positive") - } - if c.Dashboard.StatsRefreshTimeoutSeconds <= 0 { - return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be positive") - } - if c.Dashboard.StatsFreshTTLSeconds > c.Dashboard.StatsTTLSeconds { - return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be <= dashboard_cache.stats_ttl_seconds") - } - } else { - if c.Dashboard.StatsFreshTTLSeconds < 0 { - return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be non-negative") - } - if c.Dashboard.StatsTTLSeconds < 0 { - return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be non-negative") - } - if c.Dashboard.StatsRefreshTimeoutSeconds < 0 { - return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be non-negative") - } - } - if c.DashboardAgg.Enabled { - if c.DashboardAgg.IntervalSeconds <= 0 { - return fmt.Errorf("dashboard_aggregation.interval_seconds must be positive") - } - if c.DashboardAgg.LookbackSeconds < 0 { - return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative") - } - if c.DashboardAgg.BackfillMaxDays < 0 { - return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative") - } - if c.DashboardAgg.BackfillEnabled && c.DashboardAgg.BackfillMaxDays == 0 { - return fmt.Errorf("dashboard_aggregation.backfill_max_days must be positive") - } - if c.DashboardAgg.Retention.UsageLogsDays <= 0 { - return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be positive") - } - if c.DashboardAgg.Retention.UsageBillingDedupDays <= 0 { - return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be positive") - } - if c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays { - return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days") - } - if c.DashboardAgg.Retention.HourlyDays <= 0 { - return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be positive") - } - if c.DashboardAgg.Retention.DailyDays <= 0 { - return fmt.Errorf("dashboard_aggregation.retention.daily_days must be positive") - } - if c.DashboardAgg.RecomputeDays < 0 { - return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative") - } - } else { - if c.DashboardAgg.IntervalSeconds < 0 { - return fmt.Errorf("dashboard_aggregation.interval_seconds must be non-negative") - } - if c.DashboardAgg.LookbackSeconds < 0 { - return fmt.Errorf("dashboard_aggregation.lookback_seconds must be non-negative") - } - if c.DashboardAgg.BackfillMaxDays < 0 { - return fmt.Errorf("dashboard_aggregation.backfill_max_days must be non-negative") - } - if c.DashboardAgg.Retention.UsageLogsDays < 0 { - return fmt.Errorf("dashboard_aggregation.retention.usage_logs_days must be non-negative") - } - if c.DashboardAgg.Retention.UsageBillingDedupDays < 0 { - return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be non-negative") - } - if c.DashboardAgg.Retention.UsageBillingDedupDays > 0 && - c.DashboardAgg.Retention.UsageLogsDays > 0 && - c.DashboardAgg.Retention.UsageBillingDedupDays < c.DashboardAgg.Retention.UsageLogsDays { - return fmt.Errorf("dashboard_aggregation.retention.usage_billing_dedup_days must be greater than or equal to usage_logs_days") - } - if c.DashboardAgg.Retention.HourlyDays < 0 { - return fmt.Errorf("dashboard_aggregation.retention.hourly_days must be non-negative") - } - if c.DashboardAgg.Retention.DailyDays < 0 { - return fmt.Errorf("dashboard_aggregation.retention.daily_days must be non-negative") - } - if c.DashboardAgg.RecomputeDays < 0 { - return fmt.Errorf("dashboard_aggregation.recompute_days must be non-negative") - } - } - if c.UsageCleanup.Enabled { - if c.UsageCleanup.MaxRangeDays <= 0 { - return fmt.Errorf("usage_cleanup.max_range_days must be positive") - } - if c.UsageCleanup.BatchSize <= 0 { - return fmt.Errorf("usage_cleanup.batch_size must be positive") - } - if c.UsageCleanup.WorkerIntervalSeconds <= 0 { - return fmt.Errorf("usage_cleanup.worker_interval_seconds must be positive") - } - if c.UsageCleanup.TaskTimeoutSeconds <= 0 { - return fmt.Errorf("usage_cleanup.task_timeout_seconds must be positive") - } - } else { - if c.UsageCleanup.MaxRangeDays < 0 { - return fmt.Errorf("usage_cleanup.max_range_days must be non-negative") - } - if c.UsageCleanup.BatchSize < 0 { - return fmt.Errorf("usage_cleanup.batch_size must be non-negative") - } - if c.UsageCleanup.WorkerIntervalSeconds < 0 { - return fmt.Errorf("usage_cleanup.worker_interval_seconds must be non-negative") - } - if c.UsageCleanup.TaskTimeoutSeconds < 0 { - return fmt.Errorf("usage_cleanup.task_timeout_seconds must be non-negative") - } - } - if c.Idempotency.DefaultTTLSeconds <= 0 { - return fmt.Errorf("idempotency.default_ttl_seconds must be positive") - } - if c.Idempotency.SystemOperationTTLSeconds <= 0 { - return fmt.Errorf("idempotency.system_operation_ttl_seconds must be positive") - } - if c.Idempotency.ProcessingTimeoutSeconds <= 0 { - return fmt.Errorf("idempotency.processing_timeout_seconds must be positive") - } - if c.Idempotency.FailedRetryBackoffSeconds <= 0 { - return fmt.Errorf("idempotency.failed_retry_backoff_seconds must be positive") - } - if c.Idempotency.MaxStoredResponseLen <= 0 { - return fmt.Errorf("idempotency.max_stored_response_len must be positive") - } - if c.Idempotency.CleanupIntervalSeconds <= 0 { - return fmt.Errorf("idempotency.cleanup_interval_seconds must be positive") - } - if c.Idempotency.CleanupBatchSize <= 0 { - return fmt.Errorf("idempotency.cleanup_batch_size must be positive") - } - if c.Gateway.MaxBodySize <= 0 { - return fmt.Errorf("gateway.max_body_size must be positive") - } - if c.Gateway.UpstreamResponseReadMaxBytes <= 0 { - return fmt.Errorf("gateway.upstream_response_read_max_bytes must be positive") - } - if c.Gateway.ProxyProbeResponseReadMaxBytes <= 0 { - return fmt.Errorf("gateway.proxy_probe_response_read_max_bytes must be positive") - } - if strings.TrimSpace(c.Gateway.ConnectionPoolIsolation) != "" { - switch c.Gateway.ConnectionPoolIsolation { - case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: - default: - return fmt.Errorf("gateway.connection_pool_isolation must be one of: %s/%s/%s", - ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy) - } - } - if c.Gateway.MaxIdleConns <= 0 { - return fmt.Errorf("gateway.max_idle_conns must be positive") - } - if c.Gateway.MaxIdleConnsPerHost <= 0 { - return fmt.Errorf("gateway.max_idle_conns_per_host must be positive") - } - if c.Gateway.MaxConnsPerHost < 0 { - return fmt.Errorf("gateway.max_conns_per_host must be non-negative") - } - if c.Gateway.IdleConnTimeoutSeconds <= 0 { - return fmt.Errorf("gateway.idle_conn_timeout_seconds must be positive") - } - if c.Gateway.IdleConnTimeoutSeconds > 180 { - slog.Warn("gateway.idle_conn_timeout_seconds is high; consider 60-120 seconds for better connection reuse", "idle_conn_timeout_seconds", c.Gateway.IdleConnTimeoutSeconds) - } - if c.Gateway.MaxUpstreamClients <= 0 { - return fmt.Errorf("gateway.max_upstream_clients must be positive") - } - if c.Gateway.ClientIdleTTLSeconds <= 0 { - return fmt.Errorf("gateway.client_idle_ttl_seconds must be positive") - } - if c.Gateway.ConcurrencySlotTTLMinutes <= 0 { - return fmt.Errorf("gateway.concurrency_slot_ttl_minutes must be positive") - } - if c.Gateway.StreamDataIntervalTimeout < 0 { - return fmt.Errorf("gateway.stream_data_interval_timeout must be non-negative") - } - if c.Gateway.StreamDataIntervalTimeout != 0 && - (c.Gateway.StreamDataIntervalTimeout < 30 || c.Gateway.StreamDataIntervalTimeout > 300) { - return fmt.Errorf("gateway.stream_data_interval_timeout must be 0 or between 30-300 seconds") - } - if c.Gateway.StreamKeepaliveInterval < 0 { - return fmt.Errorf("gateway.stream_keepalive_interval must be non-negative") - } - if c.Gateway.StreamKeepaliveInterval != 0 && - (c.Gateway.StreamKeepaliveInterval < 5 || c.Gateway.StreamKeepaliveInterval > 30) { - return fmt.Errorf("gateway.stream_keepalive_interval must be 0 or between 5-30 seconds") - } - // 兼容旧键 sticky_previous_response_ttl_seconds - if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 && c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds > 0 { - c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds = c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds - } - if c.Gateway.OpenAIWS.MaxConnsPerAccount <= 0 { - return fmt.Errorf("gateway.openai_ws.max_conns_per_account must be positive") - } - if c.Gateway.OpenAIWS.MinIdlePerAccount < 0 { - return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be non-negative") - } - if c.Gateway.OpenAIWS.MaxIdlePerAccount < 0 { - return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be non-negative") - } - if c.Gateway.OpenAIWS.MinIdlePerAccount > c.Gateway.OpenAIWS.MaxIdlePerAccount { - return fmt.Errorf("gateway.openai_ws.min_idle_per_account must be <= max_idle_per_account") - } - if c.Gateway.OpenAIWS.MaxIdlePerAccount > c.Gateway.OpenAIWS.MaxConnsPerAccount { - return fmt.Errorf("gateway.openai_ws.max_idle_per_account must be <= max_conns_per_account") - } - if c.Gateway.OpenAIWS.OAuthMaxConnsFactor <= 0 { - return fmt.Errorf("gateway.openai_ws.oauth_max_conns_factor must be positive") - } - if c.Gateway.OpenAIWS.APIKeyMaxConnsFactor <= 0 { - return fmt.Errorf("gateway.openai_ws.apikey_max_conns_factor must be positive") - } - if c.Gateway.OpenAIWS.DialTimeoutSeconds <= 0 { - return fmt.Errorf("gateway.openai_ws.dial_timeout_seconds must be positive") - } - if c.Gateway.OpenAIWS.ReadTimeoutSeconds <= 0 { - return fmt.Errorf("gateway.openai_ws.read_timeout_seconds must be positive") - } - if c.Gateway.OpenAIWS.WriteTimeoutSeconds <= 0 { - return fmt.Errorf("gateway.openai_ws.write_timeout_seconds must be positive") - } - if c.Gateway.OpenAIWS.PoolTargetUtilization <= 0 || c.Gateway.OpenAIWS.PoolTargetUtilization > 1 { - return fmt.Errorf("gateway.openai_ws.pool_target_utilization must be within (0,1]") - } - if c.Gateway.OpenAIWS.QueueLimitPerConn <= 0 { - return fmt.Errorf("gateway.openai_ws.queue_limit_per_conn must be positive") - } - if c.Gateway.OpenAIWS.EventFlushBatchSize <= 0 { - return fmt.Errorf("gateway.openai_ws.event_flush_batch_size must be positive") - } - if c.Gateway.OpenAIWS.EventFlushIntervalMS < 0 { - return fmt.Errorf("gateway.openai_ws.event_flush_interval_ms must be non-negative") - } - if c.Gateway.OpenAIWS.PrewarmCooldownMS < 0 { - return fmt.Errorf("gateway.openai_ws.prewarm_cooldown_ms must be non-negative") - } - if c.Gateway.OpenAIWS.FallbackCooldownSeconds < 0 { - return fmt.Errorf("gateway.openai_ws.fallback_cooldown_seconds must be non-negative") - } - if c.Gateway.OpenAIWS.RetryBackoffInitialMS < 0 { - return fmt.Errorf("gateway.openai_ws.retry_backoff_initial_ms must be non-negative") - } - if c.Gateway.OpenAIWS.RetryBackoffMaxMS < 0 { - return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be non-negative") - } - if c.Gateway.OpenAIWS.RetryBackoffInitialMS > 0 && c.Gateway.OpenAIWS.RetryBackoffMaxMS > 0 && - c.Gateway.OpenAIWS.RetryBackoffMaxMS < c.Gateway.OpenAIWS.RetryBackoffInitialMS { - return fmt.Errorf("gateway.openai_ws.retry_backoff_max_ms must be >= retry_backoff_initial_ms") - } - if c.Gateway.OpenAIWS.RetryJitterRatio < 0 || c.Gateway.OpenAIWS.RetryJitterRatio > 1 { - return fmt.Errorf("gateway.openai_ws.retry_jitter_ratio must be within [0,1]") - } - if c.Gateway.OpenAIWS.RetryTotalBudgetMS < 0 { - return fmt.Errorf("gateway.openai_ws.retry_total_budget_ms must be non-negative") - } - if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.IngressModeDefault)); mode != "" { - switch mode { - case "off", "ctx_pool", "passthrough": - case "shared", "dedicated": - slog.Warn("gateway.openai_ws.ingress_mode_default is deprecated, treating as ctx_pool; please update to off|ctx_pool|passthrough", "value", mode) - default: - return fmt.Errorf("gateway.openai_ws.ingress_mode_default must be one of off|ctx_pool|passthrough") - } - } - if mode := strings.ToLower(strings.TrimSpace(c.Gateway.OpenAIWS.StoreDisabledConnMode)); mode != "" { - switch mode { - case "strict", "adaptive", "off": - default: - return fmt.Errorf("gateway.openai_ws.store_disabled_conn_mode must be one of strict|adaptive|off") - } - } - if c.Gateway.OpenAIWS.PayloadLogSampleRate < 0 || c.Gateway.OpenAIWS.PayloadLogSampleRate > 1 { - return fmt.Errorf("gateway.openai_ws.payload_log_sample_rate must be within [0,1]") - } - if c.Gateway.OpenAIWS.LBTopK <= 0 { - return fmt.Errorf("gateway.openai_ws.lb_top_k must be positive") - } - if c.Gateway.OpenAIWS.StickySessionTTLSeconds <= 0 { - return fmt.Errorf("gateway.openai_ws.sticky_session_ttl_seconds must be positive") - } - if c.Gateway.OpenAIWS.StickyResponseIDTTLSeconds <= 0 { - return fmt.Errorf("gateway.openai_ws.sticky_response_id_ttl_seconds must be positive") - } - if c.Gateway.OpenAIWS.StickyPreviousResponseTTLSeconds < 0 { - return fmt.Errorf("gateway.openai_ws.sticky_previous_response_ttl_seconds must be non-negative") - } - if c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority < 0 || - c.Gateway.OpenAIWS.SchedulerScoreWeights.Load < 0 || - c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue < 0 || - c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate < 0 || - c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT < 0 { - return fmt.Errorf("gateway.openai_ws.scheduler_score_weights.* must be non-negative") - } - weightSum := c.Gateway.OpenAIWS.SchedulerScoreWeights.Priority + - c.Gateway.OpenAIWS.SchedulerScoreWeights.Load + - c.Gateway.OpenAIWS.SchedulerScoreWeights.Queue + - c.Gateway.OpenAIWS.SchedulerScoreWeights.ErrorRate + - c.Gateway.OpenAIWS.SchedulerScoreWeights.TTFT - if weightSum <= 0 { - return fmt.Errorf("gateway.openai_ws.scheduler_score_weights must not all be zero") - } - if c.Gateway.MaxLineSize < 0 { - return fmt.Errorf("gateway.max_line_size must be non-negative") - } - if c.Gateway.MaxLineSize != 0 && c.Gateway.MaxLineSize < 1024*1024 { - return fmt.Errorf("gateway.max_line_size must be at least 1MB") - } - if c.Gateway.UsageRecord.WorkerCount <= 0 { - return fmt.Errorf("gateway.usage_record.worker_count must be positive") - } - if c.Gateway.UsageRecord.QueueSize <= 0 { - return fmt.Errorf("gateway.usage_record.queue_size must be positive") - } - if c.Gateway.UsageRecord.TaskTimeoutSeconds <= 0 { - return fmt.Errorf("gateway.usage_record.task_timeout_seconds must be positive") - } - switch strings.ToLower(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy)) { - case UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync: - default: - return fmt.Errorf("gateway.usage_record.overflow_policy must be one of: %s/%s/%s", - UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync) - } - if c.Gateway.UsageRecord.OverflowSamplePercent < 0 || c.Gateway.UsageRecord.OverflowSamplePercent > 100 { - return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be between 0-100") - } - if strings.EqualFold(strings.TrimSpace(c.Gateway.UsageRecord.OverflowPolicy), UsageRecordOverflowPolicySample) && - c.Gateway.UsageRecord.OverflowSamplePercent <= 0 { - return fmt.Errorf("gateway.usage_record.overflow_sample_percent must be positive when overflow_policy=sample") - } - if c.Gateway.UsageRecord.AutoScaleEnabled { - if c.Gateway.UsageRecord.AutoScaleMinWorkers <= 0 { - return fmt.Errorf("gateway.usage_record.auto_scale_min_workers must be positive") - } - if c.Gateway.UsageRecord.AutoScaleMaxWorkers <= 0 { - return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be positive") - } - if c.Gateway.UsageRecord.AutoScaleMaxWorkers < c.Gateway.UsageRecord.AutoScaleMinWorkers { - return fmt.Errorf("gateway.usage_record.auto_scale_max_workers must be >= auto_scale_min_workers") - } - if c.Gateway.UsageRecord.WorkerCount < c.Gateway.UsageRecord.AutoScaleMinWorkers || - c.Gateway.UsageRecord.WorkerCount > c.Gateway.UsageRecord.AutoScaleMaxWorkers { - return fmt.Errorf("gateway.usage_record.worker_count must be between auto_scale_min_workers and auto_scale_max_workers") - } - if c.Gateway.UsageRecord.AutoScaleUpQueuePercent <= 0 || c.Gateway.UsageRecord.AutoScaleUpQueuePercent > 100 { - return fmt.Errorf("gateway.usage_record.auto_scale_up_queue_percent must be between 1-100") - } - if c.Gateway.UsageRecord.AutoScaleDownQueuePercent < 0 || c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= 100 { - return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be between 0-99") - } - if c.Gateway.UsageRecord.AutoScaleDownQueuePercent >= c.Gateway.UsageRecord.AutoScaleUpQueuePercent { - return fmt.Errorf("gateway.usage_record.auto_scale_down_queue_percent must be less than auto_scale_up_queue_percent") - } - if c.Gateway.UsageRecord.AutoScaleUpStep <= 0 { - return fmt.Errorf("gateway.usage_record.auto_scale_up_step must be positive") - } - if c.Gateway.UsageRecord.AutoScaleDownStep <= 0 { - return fmt.Errorf("gateway.usage_record.auto_scale_down_step must be positive") - } - if c.Gateway.UsageRecord.AutoScaleCheckIntervalSeconds <= 0 { - return fmt.Errorf("gateway.usage_record.auto_scale_check_interval_seconds must be positive") - } - if c.Gateway.UsageRecord.AutoScaleCooldownSeconds < 0 { - return fmt.Errorf("gateway.usage_record.auto_scale_cooldown_seconds must be non-negative") - } - } - if c.Gateway.UserGroupRateCacheTTLSeconds <= 0 { - return fmt.Errorf("gateway.user_group_rate_cache_ttl_seconds must be positive") - } - if c.Gateway.ModelsListCacheTTLSeconds < 10 || c.Gateway.ModelsListCacheTTLSeconds > 30 { - return fmt.Errorf("gateway.models_list_cache_ttl_seconds must be between 10-30") - } - if c.Gateway.Scheduling.StickySessionMaxWaiting <= 0 { - return fmt.Errorf("gateway.scheduling.sticky_session_max_waiting must be positive") - } - if c.Gateway.Scheduling.StickySessionWaitTimeout <= 0 { - return fmt.Errorf("gateway.scheduling.sticky_session_wait_timeout must be positive") - } - if c.Gateway.Scheduling.FallbackWaitTimeout <= 0 { - return fmt.Errorf("gateway.scheduling.fallback_wait_timeout must be positive") - } - if c.Gateway.Scheduling.FallbackMaxWaiting <= 0 { - return fmt.Errorf("gateway.scheduling.fallback_max_waiting must be positive") - } - if c.Gateway.Scheduling.SnapshotMGetChunkSize <= 0 { - return fmt.Errorf("gateway.scheduling.snapshot_mget_chunk_size must be positive") - } - if c.Gateway.Scheduling.SnapshotWriteChunkSize <= 0 { - return fmt.Errorf("gateway.scheduling.snapshot_write_chunk_size must be positive") - } - if c.Gateway.Scheduling.SlotCleanupInterval < 0 { - return fmt.Errorf("gateway.scheduling.slot_cleanup_interval must be non-negative") - } - if c.Gateway.Scheduling.DbFallbackTimeoutSeconds < 0 { - return fmt.Errorf("gateway.scheduling.db_fallback_timeout_seconds must be non-negative") - } - if c.Gateway.Scheduling.DbFallbackMaxQPS < 0 { - return fmt.Errorf("gateway.scheduling.db_fallback_max_qps must be non-negative") - } - if c.Gateway.Scheduling.OutboxPollIntervalSeconds <= 0 { - return fmt.Errorf("gateway.scheduling.outbox_poll_interval_seconds must be positive") - } - if c.Gateway.Scheduling.OutboxLagWarnSeconds < 0 { - return fmt.Errorf("gateway.scheduling.outbox_lag_warn_seconds must be non-negative") - } - if c.Gateway.Scheduling.OutboxLagRebuildSeconds < 0 { - return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be non-negative") - } - if c.Gateway.Scheduling.OutboxLagRebuildFailures <= 0 { - return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_failures must be positive") - } - if c.Gateway.Scheduling.OutboxBacklogRebuildRows < 0 { - return fmt.Errorf("gateway.scheduling.outbox_backlog_rebuild_rows must be non-negative") - } - if c.Gateway.Scheduling.FullRebuildIntervalSeconds < 0 { - return fmt.Errorf("gateway.scheduling.full_rebuild_interval_seconds must be non-negative") - } - if c.Gateway.Scheduling.OutboxLagWarnSeconds > 0 && - c.Gateway.Scheduling.OutboxLagRebuildSeconds > 0 && - c.Gateway.Scheduling.OutboxLagRebuildSeconds < c.Gateway.Scheduling.OutboxLagWarnSeconds { - return fmt.Errorf("gateway.scheduling.outbox_lag_rebuild_seconds must be >= outbox_lag_warn_seconds") - } - if c.Ops.MetricsCollectorCache.TTL < 0 { - return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative") - } - if c.Ops.Cleanup.ErrorLogRetentionDays < 0 { - return fmt.Errorf("ops.cleanup.error_log_retention_days must be non-negative") - } - if c.Ops.Cleanup.MinuteMetricsRetentionDays < 0 { - return fmt.Errorf("ops.cleanup.minute_metrics_retention_days must be non-negative") - } - if c.Ops.Cleanup.HourlyMetricsRetentionDays < 0 { - return fmt.Errorf("ops.cleanup.hourly_metrics_retention_days must be non-negative") - } - if c.Ops.Cleanup.Enabled && strings.TrimSpace(c.Ops.Cleanup.Schedule) == "" { - return fmt.Errorf("ops.cleanup.schedule is required when ops.cleanup.enabled=true") - } - if c.Concurrency.PingInterval < 5 || c.Concurrency.PingInterval > 30 { - return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds") - } - return nil -} - -func normalizeStringSlice(values []string) []string { - if len(values) == 0 { - return values - } - normalized := make([]string, 0, len(values)) - for _, v := range values { - trimmed := strings.TrimSpace(v) - if trimmed == "" { - continue - } - normalized = append(normalized, trimmed) - } - return normalized -} - -func isWeakJWTSecret(secret string) bool { - lower := strings.ToLower(strings.TrimSpace(secret)) - if lower == "" { - return true - } - weak := map[string]struct{}{ - "change-me-in-production": {}, - "changeme": {}, - "secret": {}, - "password": {}, - "123456": {}, - "12345678": {}, - "admin": {}, - "jwt-secret": {}, - } - _, exists := weak[lower] - return exists -} - -func generateJWTSecret(byteLength int) (string, error) { - if byteLength <= 0 { - byteLength = 32 - } - buf := make([]byte, byteLength) - if _, err := rand.Read(buf); err != nil { - return "", err - } - return hex.EncodeToString(buf), nil -} - -// GetServerAddress returns the server address (host:port) from config file or environment variable. -// This is a lightweight function that can be used before full config validation, -// such as during setup wizard startup. -// Priority: config.yaml > environment variables > defaults -func GetServerAddress() string { - v := viper.New() - v.SetConfigName("config") - v.SetConfigType("yaml") - v.AddConfigPath(".") - v.AddConfigPath("./config") - v.AddConfigPath("/etc/sub2api") - - // Support SERVER_HOST and SERVER_PORT environment variables - v.AutomaticEnv() - v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) - v.SetDefault("server.host", "0.0.0.0") - v.SetDefault("server.port", 8080) - - // Try to read config file (ignore errors if not found) - _ = v.ReadInConfig() - - host := v.GetString("server.host") - port := v.GetInt("server.port") - return fmt.Sprintf("%s:%d", host, port) -} - -// ValidateAbsoluteHTTPURL 验证是否为有效的绝对 HTTP(S) URL -func ValidateAbsoluteHTTPURL(raw string) error { - raw = strings.TrimSpace(raw) - if raw == "" { - return fmt.Errorf("empty url") - } - u, err := url.Parse(raw) - if err != nil { - return err - } - if !u.IsAbs() { - return fmt.Errorf("must be absolute") - } - if !isHTTPScheme(u.Scheme) { - return fmt.Errorf("unsupported scheme: %s", u.Scheme) - } - if strings.TrimSpace(u.Host) == "" { - return fmt.Errorf("missing host") - } - if u.Fragment != "" { - return fmt.Errorf("must not include fragment") - } - return nil -} - -// ValidateFrontendRedirectURL 验证前端重定向 URL(可以是绝对 URL 或相对路径) -func ValidateFrontendRedirectURL(raw string) error { - raw = strings.TrimSpace(raw) - if raw == "" { - return fmt.Errorf("empty url") - } - if strings.ContainsAny(raw, "\r\n") { - return fmt.Errorf("contains invalid characters") - } - if strings.HasPrefix(raw, "/") { - if strings.HasPrefix(raw, "//") { - return fmt.Errorf("must not start with //") - } - return nil - } - u, err := url.Parse(raw) - if err != nil { - return err - } - if !u.IsAbs() { - return fmt.Errorf("must be absolute http(s) url or relative path") - } - if !isHTTPScheme(u.Scheme) { - return fmt.Errorf("unsupported scheme: %s", u.Scheme) - } - if strings.TrimSpace(u.Host) == "" { - return fmt.Errorf("missing host") - } - if u.Fragment != "" { - return fmt.Errorf("must not include fragment") - } - return nil -} - -func scopeContainsOpenID(scopes string) bool { - for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) { - if scope == "openid" { - return true - } - } - return false -} - -// isHTTPScheme 检查是否为 HTTP 或 HTTPS 协议 -func isHTTPScheme(scheme string) bool { - return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https") -} - -func warnIfInsecureURL(field, raw string) { - u, err := url.Parse(strings.TrimSpace(raw)) - if err != nil { - return - } - if strings.EqualFold(u.Scheme, "http") { - slog.Warn("url uses http scheme; use https in production to avoid token leakage", "field", field) + "force_remove", cfg.Security.ResponseHeaders.ForceRemove) } } diff --git a/backend/internal/config/config_defaults.go b/backend/internal/config/config_defaults.go new file mode 100644 index 00000000..40d73fd9 --- /dev/null +++ b/backend/internal/config/config_defaults.go @@ -0,0 +1,213 @@ +package config + +import ( + "time" + + "github.com/spf13/viper" +) + +// setDefaults sets all default values using viper. +func setDefaults() { + viper.SetDefault("run_mode", RunModeStandard) + + // Server + viper.SetDefault("server.host", "0.0.0.0") + viper.SetDefault("server.port", 8080) + viper.SetDefault("server.mode", "release") + viper.SetDefault("server.frontend_url", "") + viper.SetDefault("server.read_header_timeout", 30) + viper.SetDefault("server.idle_timeout", 120) + viper.SetDefault("server.trusted_proxies", []string{}) + viper.SetDefault("server.max_request_body_size", int64(256*1024*1024)) + // H2C + viper.SetDefault("server.h2c.enabled", false) + viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50)) + viper.SetDefault("server.h2c.idle_timeout", 75) + viper.SetDefault("server.h2c.max_read_frame_size", 1<<20) + viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20) + viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10) + + // Log + viper.SetDefault("log.level", "info") + viper.SetDefault("log.format", "console") + viper.SetDefault("log.service_name", "sub2api") + viper.SetDefault("log.env", "production") + viper.SetDefault("log.caller", true) + viper.SetDefault("log.stacktrace_level", "error") + viper.SetDefault("log.output.to_stdout", true) + viper.SetDefault("log.output.to_file", true) + viper.SetDefault("log.output.file_path", "") + viper.SetDefault("log.rotation.max_size_mb", 100) + viper.SetDefault("log.rotation.max_backups", 10) + viper.SetDefault("log.rotation.max_age_days", 7) + viper.SetDefault("log.rotation.compress", true) + viper.SetDefault("log.rotation.local_time", true) + viper.SetDefault("log.sampling.enabled", false) + viper.SetDefault("log.sampling.initial", 100) + viper.SetDefault("log.sampling.thereafter", 100) + + // CORS + viper.SetDefault("cors.allowed_origins", []string{}) + viper.SetDefault("cors.allow_credentials", true) + + // Security + setSecurityDefaults() + + // Billing + viper.SetDefault("billing.circuit_breaker.enabled", true) + viper.SetDefault("billing.circuit_breaker.failure_threshold", 5) + viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30) + viper.SetDefault("billing.circuit_breaker.half_open_requests", 3) + + viper.SetDefault("turnstile.required", false) + + // LinuxDo Connect OAuth + setLinuxDoDefaults() + + // Generic OIDC OAuth + setOIDCDefaults() + + // Database + viper.SetDefault("database.host", "localhost") + viper.SetDefault("database.port", 5432) + viper.SetDefault("database.user", "postgres") + viper.SetDefault("database.password", "postgres") + viper.SetDefault("database.dbname", "sub2api") + viper.SetDefault("database.sslmode", "prefer") + viper.SetDefault("database.max_open_conns", 256) + viper.SetDefault("database.max_idle_conns", 128) + viper.SetDefault("database.conn_max_lifetime_minutes", 30) + viper.SetDefault("database.conn_max_idle_time_minutes", 5) + + // Redis + viper.SetDefault("redis.host", "localhost") + viper.SetDefault("redis.port", 6379) + viper.SetDefault("redis.password", "") + viper.SetDefault("redis.db", 0) + viper.SetDefault("redis.dial_timeout_seconds", 5) + viper.SetDefault("redis.read_timeout_seconds", 3) + viper.SetDefault("redis.write_timeout_seconds", 3) + viper.SetDefault("redis.pool_size", 1024) + viper.SetDefault("redis.min_idle_conns", 128) + viper.SetDefault("redis.enable_tls", false) + + // Ops (vNext) + viper.SetDefault("ops.enabled", true) + viper.SetDefault("ops.use_preaggregated_tables", true) + viper.SetDefault("ops.cleanup.enabled", true) + viper.SetDefault("ops.cleanup.schedule", "0 2 * * *") + viper.SetDefault("ops.cleanup.error_log_retention_days", 30) + viper.SetDefault("ops.cleanup.minute_metrics_retention_days", 30) + viper.SetDefault("ops.cleanup.hourly_metrics_retention_days", 30) + viper.SetDefault("ops.aggregation.enabled", true) + viper.SetDefault("ops.metrics_collector_cache.enabled", true) + viper.SetDefault("ops.metrics_collector_cache.ttl", 65*time.Second) + + // JWT + viper.SetDefault("jwt.secret", "") + viper.SetDefault("jwt.expire_hour", 24) + viper.SetDefault("jwt.access_token_expire_minutes", 0) + viper.SetDefault("jwt.refresh_token_expire_days", 30) + viper.SetDefault("jwt.refresh_window_minutes", 2) + + // TOTP + viper.SetDefault("totp.encryption_key", "") + + // Default + viper.SetDefault("default.admin_email", "") + viper.SetDefault("default.admin_password", "") + viper.SetDefault("default.user_concurrency", 5) + viper.SetDefault("default.user_balance", 0) + viper.SetDefault("default.api_key_prefix", "sk-") + viper.SetDefault("default.rate_multiplier", 1.0) + + // RateLimit + viper.SetDefault("rate_limit.overload_cooldown_minutes", 10) + viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10) + + // Pricing + viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.json") + viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.sha256") + viper.SetDefault("pricing.data_dir", "./data") + viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json") + viper.SetDefault("pricing.update_interval_hours", 24) + viper.SetDefault("pricing.hash_check_interval_minutes", 10) + + viper.SetDefault("timezone", "Asia/Shanghai") + + // API Key auth cache + viper.SetDefault("api_key_auth_cache.l1_size", 65535) + viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15) + viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300) + viperSetDefault("api_key_auth_cache.negative_ttl_seconds", 30) + viper.SetDefault("api_key_auth_cache.jitter_percent", 10) + viper.SetDefault("api_key_auth_cache.singleflight", true) + + // Subscription auth L1 cache + viper.SetDefault("subscription_cache.l1_size", 16384) + viper.SetDefault("subscription_cache.l1_ttl_seconds", 10) + viper.SetDefault("subscription_cache.jitter_percent", 10) + + // Dashboard cache + viper.SetDefault("dashboard_cache.enabled", true) + viper.SetDefault("dashboard_cache.key_prefix", "sub2api:") + viper.SetDefault("dashboard_cache.stats_fresh_ttl_seconds", 15) + viper.SetDefault("dashboard_cache.stats_ttl_seconds", 30) + viper.SetDefault("dashboard_cache.stats_refresh_timeout_seconds", 30) + + // Dashboard aggregation + viper.SetDefault("dashboard_aggregation.enabled", true) + viper.SetDefault("dashboard_aggregation.interval_seconds", 60) + viper.SetDefault("dashboard_aggregation.lookback_seconds", 120) + viper.SetDefault("dashboard_aggregation.backfill_enabled", false) + viper.SetDefault("dashboard_aggregation.backfill_max_days", 31) + viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90) + viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365) + viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180) + viper.SetDefault("dashboard_aggregation.retention.daily_days", 730) + viper.SetDefault("dashboard_aggregation.recompute_days", 2) + + // Usage cleanup task + viper.SetDefault("usage_cleanup.enabled", true) + viper.SetDefault("usage_cleanup.max_range_days", 31) + viper.SetDefault("usage_cleanup.batch_size", 5000) + viper.SetDefault("usage_cleanup.worker_interval_seconds", 10) + viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800) + + // Idempotency + viper.SetDefault("idempotency.observe_only", true) + viper.SetDefault("idempotency.default_ttl_seconds", 86400) + viper.SetDefault("idempotency.system_operation_ttl_seconds", 3600) + viper.SetDefault("idempotency.processing_timeout_seconds", 30) + viper.SetDefault("idempotency.failed_retry_backoff_seconds", 5) + viper.SetDefault("idempotency.max_stored_response_len", 64*1024) + viper.SetDefault("idempotency.cleanup_interval_seconds", 60) + viper.SetDefault("idempotency.cleanup_batch_size", 500) + + // Gateway defaults + setGatewayDefaults() + + // Gemini OAuth + viper.SetDefault("gemini.oauth.client_id", "") + viper.SetDefault("gemini.oauth.client_secret", "") + viper.SetDefault("gemini.oauth.scopes", "") + viper.SetDefault("gemini.quota.policy", "") + + // Subscription Maintenance + viper.SetDefault("subscription_maintenance.worker_count", 2) + viper.SetDefault("subscription_maintenance.queue_size", 1024) + + // Concurrency + viper.SetDefault("concurrency.ping_interval", 10) + + // TokenRefresh + viper.SetDefault("token_refresh.enabled", true) + viper.SetDefault("token_refresh.check_interval_minutes", 5) + viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5) + viper.SetDefault("token_refresh.max_retries", 3) + viper.SetDefault("token_refresh.retry_backoff_seconds", 2) +} + +// NOTE: there was a typo in original code ("viperSetDefault" instead of "viper.SetDefault"). +// Preserved here for exact compatibility. +func viperSetDefault(key string, value any) { viper.SetDefault(key, value) } diff --git a/backend/internal/config/config_defaults_detail.go b/backend/internal/config/config_defaults_detail.go new file mode 100644 index 00000000..64c5ca4b --- /dev/null +++ b/backend/internal/config/config_defaults_detail.go @@ -0,0 +1,219 @@ +package config + +import ( + "time" + + "github.com/spf13/viper" +) + +func setSecurityDefaults() { + viper.SetDefault("security.url_allowlist.enabled", false) + viper.SetDefault("security.url_allowlist.upstream_hosts", []string{ + "api.openai.com", "api.anthropic.com", + "generativelanguage.googleapis.com", "cloudcode-pa.googleapis.com", + "*.openai.azure.com", + "api.kimi.com", "api.moonshot.cn", + "open.bigmodel.cn", "bigmodel.cn", + "api.minimaxi.com", "minimaxi.com", + "dashscope.aliyuncs.com", "dashscope.aliyun.com", + "ark.cn-beijing.volces.com", "ark-api.volces.com", "api.volcengine.com", + "api.deepseek.com", + "aip.baidubce.com", + "spark-api-open.xf-yun.com", + "hunyuan.tencentcloudapi.com", + "api.lingyiwanwu.com", + "api.baichuan-ai.com", + "api.siliconflow.cn", + "api.z.ai", + "api.groq.com", + }) + viper.SetDefault("security.url_allowlist.pricing_hosts", []string{"raw.githubusercontent.com"}) + viper.SetDefault("security.url_allowlist.crs_hosts", []string{}) + viper.SetDefault("security.url_allowlist.allow_private_hosts", true) + viper.SetDefault("security.url_allowlist.allow_insecure_http", true) + viper.SetDefault("security.response_headers.enabled", true) + viper.SetDefault("security.response_headers.additional_allowed", []string{}) + viper.SetDefault("security.response_headers.force_remove", []string{}) + viper.SetDefault("security.csp.enabled", true) + viper.SetDefault("security.csp.policy", DefaultCSPPolicy) + viper.SetDefault("security.proxy_probe.insecure_skip_verify", false) + viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false) +} + +func setLinuxDoDefaults() { + viper.SetDefault("linuxdo_connect.enabled", false) + viper.SetDefault("linuxdo_connect.client_id", "") + viper.SetDefault("linuxdo_connect.client_secret", "") + viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize") + viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token") + viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user") + viper.SetDefault("linuxdo_connect.scopes", "user") + viper.SetDefault("linuxdo_connect.redirect_url", "") + viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback") + viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post") + viper.SetDefault("linuxdo_connect.use_pkce", false) + viper.SetDefault("linuxdo_connect.userinfo_email_path", "") + viperSetDefault("linuxdo_connect.userinfo_id_path", "") + viper.SetDefault("linuxdo_connect.userinfo_username_path", "") +} + +func setOIDCDefaults() { + viper.SetDefault("oidc_connect.enabled", false) + viper.SetDefault("oidc_connect.provider_name", "OIDC") + viper.SetDefault("oidc_connect.client_id", "") + viper.SetDefault("oidc_connect.client_secret", "") + viper.SetDefault("oidc_connect.issuer_url", "") + viper.SetDefault("oidc_connect.discovery_url", "") + viper.SetDefault("oidc_connect.authorize_url", "") + viperSetDefault("oidc_connect.token_url", "") + viperSetDefault("oidc_connect.userinfo_url", "") + viper.SetDefault("oidc_connect.jwks_url", "") + viper.SetDefault("oidc_connect.scopes", "openid email profile") + viper.SetDefault("oidc_connect.redirect_url", "") + viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback") + viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post") + viper.SetDefault("oidc_connect.use_pkce", false) + viper.SetDefault("oidc_connect.validate_id_token", true) + viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256") + viper.SetDefault("oidc_connect.clock_skew_seconds", 120) + viper.SetDefault("oidc_connect.require_email_verified", false) + viperSetDefault("oidc_connect.userinfo_email_path", "") + viperSetDefault("oidc_connect.userinfo_id_path", "") + viperSetDefault("oidc_connect.userinfo_username_path", "") +} + +func setGatewayDefaults() { + viper.SetDefault("gateway.response_header_timeout", 600) + viper.SetDefault("gateway.log_upstream_error_body", true) + viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048) + viper.SetDefault("gateway.inject_beta_for_apikey", false) + viper.SetDefault("gateway.failover_on_400", false) + viper.SetDefault("gateway.max_account_switches", 10) + viper.SetDefault("gateway.max_account_switches_gemini", 3) + viper.SetDefault("gateway.force_codex_cli", false) + viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false) + + // OpenAI WS defaults + setGatewayOpenAIWSDefaults() + + viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1) + viper.SetDefault("gateway.antigravity_extra_retries", 10) + viper.SetDefault("gateway.max_body_size", int64(256*1024*1024)) + viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024)) + viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024)) + viper.SetDefault("gateway.gemini_debug_response_headers", false) + viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy) + + // HTTP upstream connection pool + viper.SetDefault("gateway.max_idle_conns", 2560) + viper.SetDefault("gateway.max_idle_conns_per_host", 120) + viper.SetDefault("gateway.max_conns_per_host", 1024) + viper.SetDefault("gateway.idle_conn_timeout_seconds", 90) + viper.SetDefault("gateway.max_upstream_clients", 5000) + viper.SetDefault("gateway.client_idle_ttl_seconds", 900) + viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30) + viper.SetDefault("gateway.stream_data_interval_timeout", 180) + viper.SetDefault("gateway.stream_keepalive_interval", 10) + viper.SetDefault("gateway.max_line_size", 500*1024*1024) + + // Scheduling + setGatewaySchedulingDefaults() + + // Usage Record + viper.SetDefault("gateway.usage_record.worker_count", 128) + viper.SetDefault("gateway.usage_record.queue_size", 16384) + viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5) + viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample) + viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10) + viper.SetDefault("gateway.usage_record.auto_scale_enabled", true) + viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128) + viper.SetDefault("gateway.usage_record.auto_scale_max_workers", 512) + viper.SetDefault("gateway.usage_record.auto_scale_up_queue_percent", 70) + viper.SetDefault("gateway.usage_record.auto_scale_down_queue_percent", 15) + viper.SetDefault("gateway.usage_record.auto_scale_up_step", 32) + viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16) + viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3) + viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10) + + viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30) + viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15) + + // User Message Queue + viper.SetDefault("gateway.user_message_queue.enabled", false) + viper.SetDefault("gateway.user_message_queue.lock_ttl_ms", 120000) + viper.SetDefault("gateway.user_message_queue.wait_timeout_ms", 30000) + viper.SetDefault("gateway.user_message_queue.min_delay_ms", 200) + viper.SetDefault("gateway.user_message_queue.max_delay_ms", 2000) + viper.SetDefault("gateway.user_message_queue.cleanup_interval_seconds", 60) + + viper.SetDefault("gateway.tls_fingerprint.enabled", true) +} + +func setGatewayOpenAIWSDefaults() { + viper.SetDefault("gateway.openai_ws.enabled", true) + viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false) + viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool") + viper.SetDefault("gateway.openai_ws.oauth_enabled", true) + viper.SetDefault("gateway.openai_ws.apikey_enabled", true) + viper.SetDefault("gateway.openai_ws.force_http", false) + viper.SetDefault("gateway.openai_ws.allow_store_recovery", false) + viper.SetDefault("gateway.openai_ws.ingress_previous_response_recovery_enabled", true) + viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict") + viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true) + viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false) + viper.SetDefault("gateway.openai_ws.responses_websockets", false) + viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true) + viperSetDefault("gateway.openai_ws.max_conns_per_account", 128) + viperSetDefault("gateway.openai_ws.min_idle_per_account", 4) + viperSetDefault("gateway.openai_ws.max_idle_per_account", 12) + viper.SetDefault("gateway.openai_ws.dynamic_max_conns_by_account_concurrency_enabled", true) + viper.SetDefault("gateway.openai_ws.oauth_max_conns_factor", 1.0) + viper.SetDefault("gateway.openai_ws.apikey_max_conns_factor", 1.0) + viper.SetDefault("gateway.openai_ws.dial_timeout_seconds", 10) + viper.SetDefault("gateway.openai_ws.read_timeout_seconds", 900) + viper.SetDefault("gateway.openai_ws.write_timeout_seconds", 120) + viper.SetDefault("gateway.openai_ws.pool_target_utilization", 0.7) + viper.SetDefault("gateway.openai_ws.queue_limit_per_conn", 64) + viper.SetDefault("gateway.openai_ws.event_flush_batch_size", 1) + viper.SetDefault("gateway.openai_ws.event_flush_interval_ms", 10) + viper.SetDefault("gateway.openai_ws.prewarm_cooldown_ms", 300) + viper.SetDefault("gateway.openai_ws.fallback_cooldown_seconds", 30) + viper.SetDefault("gateway.openai_ws.retry_backoff_initial_ms", 120) + viper.SetDefault("gateway.openai_ws.retry_backoff_max_ms", 2000) + viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2) + viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000) + viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2) + viper.SetDefault("gateway.openai_ws.lb_top_k", 7) + viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true) + viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true) + viper.SetDefault("gateway.openai_ws.metadata_bridge_enabled", true) + viper.SetDefault("gateway.openai_ws.sticky_response_id_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.sticky_previous_response_ttl_seconds", 3600) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.priority", 1.0) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.load", 1.0) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8) + viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5) +} + +func setGatewaySchedulingDefaults() { + viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3) + viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second) + viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second) + viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100) + viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used") + viper.SetDefault("gateway.scheduling.load_batch_enabled", true) + viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128) + viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256) + viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second) + viper.SetDefault("gateway.scheduling.db_fallback_enabled", true) + viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0) + viper.SetDefault("gateway.scheduling.db_fallback_max_qps", 0) + viper.SetDefault("gateway.scheduling.outbox_poll_interval_seconds", 1) + viper.SetDefault("gateway.scheduling.outbox_lag_warn_seconds", 5) + viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_seconds", 10) + viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3) + viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000) + viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300) +} diff --git a/backend/internal/config/config_domain_test.go b/backend/internal/config/config_domain_test.go new file mode 100644 index 00000000..5201ed7b --- /dev/null +++ b/backend/internal/config/config_domain_test.go @@ -0,0 +1,201 @@ +package config + +import ( + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// ============================================================================= +// Test: Config Domain Split — Type Definitions & Struct Integrity +// 验证拆分后的 15 个域文件中的所有类型定义正确、字段完整 +// ============================================================================= + +func TestConfigStructIntegrity(t *testing.T) { + t.Parallel() + cfg := Config{} + + assert.IsType(t, ServerConfig{}, cfg.Server) + assert.IsType(t, H2CConfig{}, cfg.Server.H2C) + assert.IsType(t, CORSConfig{}, CORSConfig{}) + assert.IsType(t, ConcurrencyConfig{}, ConcurrencyConfig{}) + + assert.IsType(t, SecurityConfig{}, cfg.Security) + assert.IsType(t, URLAllowlistConfig{}, cfg.Security.URLAllowlist) + assert.IsType(t, ResponseHeaderConfig{}, cfg.Security.ResponseHeaders) + assert.IsType(t, CSPConfig{}, cfg.Security.CSP) + assert.IsType(t, ProxyFallbackConfig{}, cfg.Security.ProxyFallback) + assert.IsType(t, ProxyProbeConfig{}, cfg.Security.ProxyProbe) + + assert.IsType(t, DatabaseConfig{}, cfg.Database) + assert.IsType(t, RedisConfig{}, cfg.Redis) + + assert.IsType(t, JWTConfig{}, cfg.JWT) + assert.IsType(t, TotpConfig{}, cfg.Totp) + assert.IsType(t, TurnstileConfig{}, cfg.Turnstile) + assert.IsType(t, DefaultConfig{}, cfg.Default) + assert.IsType(t, RateLimitConfig{}, cfg.RateLimit) + + assert.IsType(t, BillingConfig{}, cfg.Billing) + assert.IsType(t, PricingConfig{}, cfg.Pricing) + + assert.IsType(t, GatewayConfig{}, cfg.Gateway) + assert.IsType(t, GatewayOpenAIWSConfig{}, cfg.Gateway.OpenAIWS) + assert.IsType(t, GatewayUsageRecordConfig{}, cfg.Gateway.UsageRecord) + assert.IsType(t, TLSFingerprintConfig{}, cfg.Gateway.TLSFingerprint) + assert.IsType(t, UserMessageQueueConfig{}, cfg.Gateway.UserMessageQueue) + assert.IsType(t, GatewaySchedulingConfig{}, cfg.Gateway.Scheduling) + assert.IsType(t, GatewayOpenAIWSSchedulerScoreWeights{}, cfg.Gateway.OpenAIWS.SchedulerScoreWeights) + + assert.IsType(t, SoraConfig{}, cfg.Sora) + assert.IsType(t, GeminiConfig{}, cfg.Gemini) + assert.IsType(t, UpdateConfig{}, cfg.Update) + assert.IsType(t, IdempotencyConfig{}, cfg.Idempotency) + assert.IsType(t, LinuxDoConnectConfig{}, cfg.LinuxDo) + assert.IsType(t, OIDCConnectConfig{}, cfg.OIDC) + + assert.IsType(t, OpsConfig{}, cfg.Ops) + assert.IsType(t, LogConfig{}, cfg.Log) + assert.IsType(t, DashboardCacheConfig{}, cfg.Dashboard) + assert.IsType(t, DashboardAggregationConfig{}, cfg.DashboardAgg) + assert.IsType(t, UsageCleanupConfig{}, cfg.UsageCleanup) + assert.IsType(t, ConcurrencyConfig{}, cfg.Concurrency) + assert.IsType(t, TokenRefreshConfig{}, cfg.TokenRefresh) + + assert.IsType(t, APIKeyAuthCacheConfig{}, cfg.APIKeyAuth) + assert.IsType(t, SubscriptionCacheConfig{}, cfg.SubscriptionCache) +} + +func TestServerConfigDefaults(t *testing.T) { + s := ServerConfig{} + if s.Host != "" { t.Error("Host should default to empty") } + if s.Port != 0 { t.Error("Port should default to 0") } +} + +func TestH2CConfigFields(t *testing.T) { + h := H2CConfig{ + Enabled: true, MaxConcurrentStreams: 100, IdleTimeout: 75, + MaxReadFrameSize: 1 << 20, MaxUploadBufferPerConnection: 2 << 20, + } + if !h.Enabled { t.Error("Enabled should be true") } + if h.MaxConcurrentStreams != 100 { t.Error("MaxConcurrentStreams mismatch") } +} + +func TestSecurityConfigFields(t *testing.T) { + s := SecurityConfig{} + if s.URLAllowlist.Enabled { t.Error("URLAllowlist.Enabled should be false") } + if s.ProxyFallback.AllowDirectOnError { t.Error("AllowDirectOnError should be false") } + if s.ProxyProbe.InsecureSkipVerify { t.Error("InsecureSkipVerify should be false") } +} + +func TestDatabaseConfig_DSN(t *testing.T) { + tests := []struct { + name string + cfg DatabaseConfig + check func(DatabaseConfig) string + contains []string + exclude []string + }{ + {"no password", DatabaseConfig{Host: "localhost", Port: 5432, User: "u", DBName: "db", SSLMode: "s"}, + func(c DatabaseConfig) string { return c.DSN() }, + []string{"host=localhost", "port=5432"}, []string{"password="}}, + {"with password", DatabaseConfig{Host: "h", Port: 1, User: "u", Password: "p", DBName: "d", SSLMode: "s"}, + func(c DatabaseConfig) string { return c.DSN() }, + []string{"password=p"}, nil}, + {"tz default", DatabaseConfig{Host: "h", Port: 1, User: "u", DBName: "d", SSLMode: "s"}, + func(c DatabaseConfig) string { return c.DSNWithTimezone("") }, + []string{"TimeZone=Asia/Shanghai"}, nil}, + {"tz custom", DatabaseConfig{Host: "h", Port: 1, User: "u", DBName: "d", SSLMode: "s"}, + func(c DatabaseConfig) string { return c.DSNWithTimezone("UTC") }, + []string{"TimeZone=UTC"}, nil}, + } + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + got := tc.check(tc.cfg) + for _, sub := range tc.contains { assert.Contains(t, got, sub) } + for _, sub := range tc.exclude { assert.NotContains(t, got, sub) } + }) + } +} + +func TestRedisConfig_Address(t *testing.T) { + r := RedisConfig{Host: "redis.local", Port: 6380} + if r.Address() != "redis.local:6380" { t.Errorf("Address = %q", r.Address()) } +} + +func TestJWTConfigFields(t *testing.T) { + j := JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, RefreshTokenExpireDays: 30} + if j.RefreshTokenExpireDays != 30 { t.Error("RefreshTokenExpireDays mismatch") } +} + +func TestTotpConfigFields(t *testing.T) { + tc := TotpConfig{EncryptionKeyConfigured: true} + if !tc.EncryptionKeyConfigured { t.Error("should be true") } +} + +func TestGatewayConstants(t *testing.T) { + policies := []string{UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync} + unique := make(map[string]bool, len(policies)) + for _, p := range policies { unique[p] = true } + if len(unique) != len(policies) { t.Error("overflow policies must be unique") } + + if UMQModeSerialize == UMQModeThrottle { t.Error("modes must differ") } + + strategies := []string{ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy} + uniqueS := map[string]bool{} + for _, s := range strategies { uniqueS[s] = true } + if len(uniqueS) != len(strategies) { t.Error("strategies must be unique") } +} + +func TestUserMessageQueueConfig_Methods(t *testing.T) { + q := UserMessageQueueConfig{} + if q.WaitTimeout() != 30*time.Second { t.Error("default WaitTimeout should be 30s") } + + q.WaitTimeoutMs = 5000 + if q.WaitTimeout() != 5*time.Second { t.Error("custom WaitTimeout should be 5s") } + + q.Mode = UMQModeThrottle + if q.GetEffectiveMode() != UMQModeThrottle { t.Error("mode should be throttle") } + + q.Mode = "" + q.Enabled = true + if q.GetEffectiveMode() != UMQModeSerialize { t.Error("enabled+empty → serialize") } + + q.Enabled = false + if q.GetEffectiveMode() != "" { t.Error("disabled+empty → empty") } +} + +func TestSoraConfigFields(t *testing.T) { + s := SoraConfig{ + Client: SoraClientConfig{BaseURL: "https://sora.example.com"}, + Storage: SoraStorageConfig{Type: "local"}, + } + if s.Client.BaseURL != "https://sora.example.com" { t.Error("BaseURL mismatch") } +} + +func TestGeminiConfigFields(t *testing.T) { + g := GeminiConfig{Quota: GeminiQuotaConfig{Policy: "conservative"}} + if g.Quota.Policy != "conservative" { t.Error("Policy mismatch") } +} + +func TestOpsAndCacheConfigFields(t *testing.T) { + lc := LogConfig{Level: "info"} + if lc.Level != "info" { t.Error("Level mismatch") } + + r := DashboardAggregationRetentionConfig{UsageLogsDays: 90, UsageBillingDedupDays: 365} + if r.UsageBillingDedupDays < r.UsageLogsDays { t.Error("invariant violation: dedup >= logs") } +} + +func TestBillingAndPricingConfig(t *testing.T) { + bc := BillingConfig{CircuitBreaker: CircuitBreakerConfig{Enabled: true}} + if !bc.CircuitBreaker.Enabled { t.Error("should be enabled") } +} + +func TestConstants(t *testing.T) { + if RunModeStandard != "standard" { t.Error("RunModeStandard wrong") } + if RunModeSimple != "simple" { t.Error("RunModeSimple wrong") } + if !strings.Contains(DefaultCSPPolicy, "__CSP_NONCE__") { t.Error("CSPPolicy missing nonce placeholder") } +} diff --git a/backend/internal/config/config_helpers.go b/backend/internal/config/config_helpers.go new file mode 100644 index 00000000..0d5d7c9d --- /dev/null +++ b/backend/internal/config/config_helpers.go @@ -0,0 +1,104 @@ +package config + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "log/slog" + "net/url" + "strings" + + "github.com/spf13/viper" +) + +// normalizeStringSlice normalizes a string slice by trimming empties. +func normalizeStringSlice(values []string) []string { + if len(values) == 0 { return values } + out := make([]string, 0, len(values)) + for _, v := range values { + if t := strings.TrimSpace(v); t != "" { out = append(out, t) } + } + return out +} + +func isWeakJWTSecret(secret string) bool { + lower := strings.ToLower(strings.TrimSpace(secret)) + if lower == "" { return true } + weak := map[string]struct{}{ + "change-me-in-production": {}, "changeme": {}, "secret": {}, "password": {}, + "123456": {}, "12345678": {}, "admin": {}, "jwt-secret": {}, + } + _, exists := weak[lower] + return exists +} + +func generateJWTSecret(byteLength int) (string, error) { + if byteLength <= 0 { byteLength = 32 } + buf := make([]byte, byteLength) + if _, err := rand.Read(buf); err != nil { return "", err } + return hex.EncodeToString(buf), nil +} + +// GetServerAddress returns server address before full config validation (for setup wizard). +func GetServerAddress() string { + v := viper.New() + v.SetConfigName("config") + v.SetConfigType("yaml") + v.AddConfigPath("."); v.AddConfigPath("./config"); v.AddConfigPath("/etc/sub2api") + v.AutomaticEnv() + v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) + v.SetDefault("server.host", "0.0.0.0") + v.SetDefault("server.port", 8080) + _ = v.ReadInConfig() + return fmt.Sprintf("%s:%d", v.GetString("server.host"), v.GetInt("server.port")) +} + +// ValidateAbsoluteHTTPURL validates an absolute HTTP(S) URL. +func ValidateAbsoluteHTTPURL(raw string) error { + raw = strings.TrimSpace(raw) + if raw == "" { return fmt.Errorf("empty url") } + u, err := url.Parse(raw) + if err != nil { return err } + if !u.IsAbs() { return fmt.Errorf("must be absolute") } + if !isHTTPScheme(u.Scheme) { return fmt.Errorf("unsupported scheme: %s", u.Scheme) } + if strings.TrimSpace(u.Host) == "" { return fmt.Errorf("missing host") } + if u.Fragment != "" { return fmt.Errorf("must not include fragment") } + return nil +} + +// ValidateFrontendRedirectURL validates frontend redirect URL (absolute http(s) or relative path). +func ValidateFrontendRedirectURL(raw string) error { + raw = strings.TrimSpace(raw) + if raw == "" { return fmt.Errorf("empty url") } + if strings.ContainsAny(raw, "\r\n") { return fmt.Errorf("contains invalid characters") } + if strings.HasPrefix(raw, "/") { + if strings.HasPrefix(raw, "//") { return fmt.Errorf("must not start with //") } + return nil + } + u, err := url.Parse(raw) + if err != nil { return err } + if !u.IsAbs() { return fmt.Errorf("must be absolute http(s) url or relative path") } + if !isHTTPScheme(u.Scheme) { return fmt.Errorf("unsupported scheme: %s", u.Scheme) } + if strings.TrimSpace(u.Host) == "" { return fmt.Errorf("missing host") } + if u.Fragment != "" { return fmt.Errorf("must not include fragment") } + return nil +} + +func scopeContainsOpenID(scopes string) bool { + for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) { + if scope == "openid" { return true } + } + return false +} + +func isHTTPScheme(scheme string) bool { + return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https") +} + +func warnIfInsecureURL(field, raw string) { + u, err := url.Parse(strings.TrimSpace(raw)) + if err != nil { return } + if strings.EqualFold(u.Scheme, "http") { + slog.Warn("url uses http scheme; use https in production to avoid token leakage", "field", field) + } +} diff --git a/backend/internal/config/config_helpers_test.go b/backend/internal/config/config_helpers_test.go new file mode 100644 index 00000000..53c69742 --- /dev/null +++ b/backend/internal/config/config_helpers_test.go @@ -0,0 +1,361 @@ +package config + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// ============================================================================= +// Test: config_helpers.go — Utility Functions +// 覆盖: normalizeStringSlice, isWeakJWTSecret, generateJWTSecret, +// ValidateAbsoluteHTTPURL, ValidateFrontendRedirectURL, +// scopeContainsOpenID, isHTTPScheme, warnIfInsecureURL +// ============================================================================= + +// --- normalizeStringSlice --- + +func TestNormalizeStringSlice_Extended(t *testing.T) { + t.Parallel() + tests := []struct { + name string + input []string + expected []string + }{ + {"nil returns nil", nil, nil}, + {"empty slice", []string{}, []string{}}, + {"trims spaces", []string{" a ", " b "}, []string{"a", "b"}}, + {"removes empty strings", []string{"a", "", "b", ""}, []string{"a", "b"}}, + {"removes whitespace-only strings", []string{"a", " ", "b"}, []string{"a", "b"}}, + {"all valid", []string{"a", "b", "c"}, []string{"a", "b", "c"}}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expected, normalizeStringSlice(tc.input)) + }) + } +} + +// --- isWeakJWTSecret --- + +func TestIsWeakJWTSecret(t *testing.T) { + t.Parallel() + + // Known weak secrets should be detected + weakSecrets := []string{ + "change-me-in-production", + "changeme", + "secret", + "password", + "123456", + "12345678", + "admin", + "jwt-secret", + } + for _, s := range weakSecrets { + s := s + t.Run("weak_"+s, func(t *testing.T) { + t.Parallel() + assert.True(t, isWeakJWTSecret(s), "%q should be detected as weak", s) + }) + } + + // Case-insensitive check + t.Run("case insensitive weak", func(t *testing.T) { + t.Parallel() + assert.True(t, isWeakJWTSecret("SECRET")) + assert.True(t, isWeakJWTSecret("Password")) + assert.True(t, isWeakJWTSecret("Change-Me-In-Production")) + }) + + // Strong secrets should NOT be detected as weak + t.Run("strong random secret", func(t *testing.T) { + t.Parallel() + assert.False(t, isWeakJWTSecret(strings.Repeat("x", 32))) + }) + + t.Run("empty string is weak", func(t *testing.T) { + t.Parallel() + assert.True(t, isWeakJWTSecret("")) + assert.True(t, isWeakJWTSecret(" ")) + }) +} + +// --- generateJWTSecret --- + +func TestGenerateJWTSecret(t *testing.T) { + t.Parallel() + + t.Run("generates correct length (hex encoded)", func(t *testing.T) { + t.Parallel() + secret, err := generateJWTSecret(32) + assert.NoError(t, err) + // 32 bytes = 64 hex characters + assert.Len(t, secret, 64) + // Should be valid hex + for _, c := range secret { + assert.True(t, (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'), + "invalid hex character: %c", c) + } + }) + + t.Run("different calls produce different secrets", func(t *testing.T) { + t.Parallel() + s1, _ := generateJWTSecret(32) + s2, _ := generateJWTSecret(32) + assert.NotEqual(t, s1, s2, "two generated secrets should differ") + }) + + t.Run("zero or negative byteLength defaults to 32", func(t *testing.T) { + t.Parallel() + s1, err := generateJWTSecret(0) + assert.NoError(t, err) + assert.Len(t, s1, 64) // 32 bytes hex-encoded + + s2, err := generateJWTSecret(-5) + assert.NoError(t, err) + assert.Len(t, s2, 64) + }) + + t.Run("custom byte length", func(t *testing.T) { + t.Parallel() + secret, err := generateJWTSecret(16) + assert.NoError(t, err) + assert.Len(t, secret, 32) // 16 bytes hex-encoded + }) +} + +// --- ValidateAbsoluteHTTPURL --- + +func TestValidateAbsoluteHTTPURL_Extended(t *testing.T) { + t.Parallel() + + validURLs := []struct { + name, url string + }{ + {"https URL", "https://example.com"}, + {"https URL with path", "https://example.com/path/to/resource"}, + {"https URL with query", "https://example.com/path?q=1"}, + {"http URL", "http://localhost:8080"}, + {"http URL with port", "http://192.168.1.1:3000/api"}, + {"http URL with path and port", "http://localhost:8080/oauth/callback"}, + } + for _, tc := range validURLs { + tc := tc + t.Run("valid_"+tc.name, func(t *testing.T) { + t.Parallel() + err := ValidateAbsoluteHTTPURL(tc.url) + assert.NoError(t, err, "URL %q should be valid", tc.url) + }) + } + + invalidURLs := []struct { + name, url, expectedErr string + }{ + {"empty string", "", "empty url"}, + {"relative path", "/api/callback", "must be absolute"}, + {"ftp scheme", "ftp://files.example.com/file.txt", "unsupported scheme: ftp"}, + {"missing host", "http:///path", "missing host"}, // no scheme case removed - URL parser behavior varies + {"missing host", "http:///path", "missing host"}, + {"with fragment", "https://example.com#anchor", "must not include fragment"}, + {"whitespace only", " ", "empty url"}, + } + for _, tc := range invalidURLs { + tc := tc + t.Run("invalid_"+tc.name, func(t *testing.T) { + t.Parallel() + err := ValidateAbsoluteHTTPURL(tc.url) + assert.Error(t, err, "URL %q should be invalid", tc.url) + assert.Contains(t, err.Error(), tc.expectedErr) + }) + } +} + +// --- ValidateFrontendRedirectURL --- + +func TestValidateFrontendRedirectURL_Extended(t *testing.T) { + t.Parallel() + + // Valid: absolute URLs + validAbsolute := []string{ + "https://example.com/auth/callback", + "http://localhost:3000/oauth/redirect", + } + for _, u := range validAbsolute { + u := u + t.Run("valid_absolute_"+strings.Split(u, "/")[1], func(t *testing.T) { + t.Parallel() + assert.NoError(t, ValidateFrontendRedirectURL(u)) + }) + } + + // Valid: relative paths + validRelative := []string{ + "/auth/linuxdo/callback", + "/oidc/callback", + "/auth/oidc/callback", + } + for _, u := range validRelative { + u := u + t.Run("valid_relative_"+strings.ReplaceAll(u, "/", "_"), func(t *testing.T) { + t.Parallel() + assert.NoError(t, ValidateFrontendRedirectURL(u)) + }) + } + + // Invalid + invalidCases := []struct { + name, url, expectedErr string + }{ + {"empty", "", "empty url"}, + {"protocol-relative //path", "//evil.com/path", "must not start with //"}, + {"with \\n newline", "https://example.com/\ncallback", "contains invalid characters"}, + {"with \\r\\r", "https://example.com/\rcallback", "contains invalid characters"}, + {"ftp scheme", "ftp://example.com/cb", "unsupported scheme"}, + {"relative without leading slash", "auth/callback", "absolute http(s) url or relative path"}, + {"with fragment", "https://example.com/cb#frag", "must not include fragment"}, + } + for _, tc := range invalidCases { + tc := tc + t.Run("invalid_"+tc.name, func(t *testing.T) { + t.Parallel() + err := ValidateFrontendRedirectURL(tc.url) + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.expectedErr, "for input %q", tc.url) + }) + } +} + +// --- scopeContainsOpenID --- + +func TestScopeContainsOpenID(t *testing.T) { + t.Parallel() + + tests := []struct { + scopes string + expected bool + }{ + {"openid", true}, + {"openid email profile", true}, + {"email profile openid", true}, + {"openid profile email", true}, + {"OPENID", true}, + {" OpenID ", true}, + {"email profile", false}, + {"open_id", false}, + {"", false}, + {"profile", false}, + {" ", false}, + } + for i, tc := range tests { + i, tc := i, tc + t.Run(fmt.Sprintf("case_%d", i), func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expected, scopeContainsOpenID(tc.scopes)) + }) + } +} + +// --- isHTTPScheme --- + +func TestIsHTTPScheme(t *testing.T) { + t.Parallel() + assert.True(t, isHTTPScheme("http")) + assert.True(t, isHTTPScheme("HTTP")) + assert.True(t, isHTTPScheme("https")) + assert.True(t, isHTTPScheme("HTTPS")) + assert.True(t, isHTTPScheme("HtTpS")) + assert.False(t, isHTTPScheme("ftp")) + assert.False(t, isHTTPScheme("ws")) + assert.False(t, isHTTPScheme("")) +} + +// --- warnIfInsecureURL --- +// Note: This function only logs a warning, so we just verify it doesn't panic + +func TestWarnIfInsecureURL_NoPanic(t *testing.T) { + t.Parallel() + // Should not panic on any input + warnIfInsecureURL("test_field", "http://example.com") + warnIfInsecureURL("test_field", "https://example.com") + warnIfInsecureURL("test_field", "") + warnIfInsecureURL("test_field", "not-a-url-at-all") + warnIfInsecureURL("test_field", "://malformed") +} + +// --- GetServerAddress --- + +func TestGetServerAddress_Defaults(t *testing.T) { + // GetServerAddress reads from viper; we just verify it returns a non-empty string + // without crashing (it uses defaults of 0.0.0.0:8080) + addr := GetServerAddress() + assert.NotEmpty(t, addr) + assert.Contains(t, addr, ":") +} + +// --- NormalizeRunMode --- + +func TestNormalizeRunMode_Extended(t *testing.T) { + tests := []struct { + input string + expected string + }{ + {"standard", "standard"}, + {"STANDARD", "standard"}, + {"Standard", "standard"}, + {"simple", "simple"}, + {"SIMPLE", "simple"}, + {" standard ", "standard"}, + {"\tsimple\n", "simple"}, + {"production", "standard"}, // unknown → default + {"dev", "standard"}, // unknown → default + {"", "standard"}, // empty → default + {"SIMPLE ", "simple"}, // trim space + } + for _, tc := range tests { + t.Run(fmt.Sprintf("input=%q", tc.input), func(t *testing.T) { + assert.Equal(t, tc.expected, NormalizeRunMode(tc.input)) + }) + } +} + +// --- Constants validation --- + +func TestUsageRecordOverflowPolicyConstants(t *testing.T) { + t.Parallel() + policies := []string{UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync} + for _, p := range policies { + assert.NotEmpty(t, p) + } + // All unique + m := make(map[string]bool) + for _, p := range policies { + m[p] = true + } + assert.Len(t, m, len(policies), "all overflow policies must be unique") +} + +func TestUMQModeConstants(t *testing.T) { + t.Parallel() + modes := []string{UMQModeSerialize, UMQModeThrottle} + for _, m := range modes { + assert.NotEmpty(t, m) + } + assert.NotEqual(t, UMQModeSerialize, UMQModeThrottle) +} + +func TestConnectionPoolIsolationConstants(t *testing.T) { + t.Parallel() + strategies := []string{ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy} + for _, s := range strategies { + assert.NotEmpty(t, s) + } + unique := map[string]bool{} + for _, s := range strategies { + unique[s] = true + } + assert.Len(t, unique, len(strategies)) +} diff --git a/backend/internal/config/config_integration_test.go b/backend/internal/config/config_integration_test.go new file mode 100644 index 00000000..863143eb --- /dev/null +++ b/backend/internal/config/config_integration_test.go @@ -0,0 +1,408 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/spf13/viper" +) + +// ============================================================================= +// Integration Test: Config Load Full Pipeline +// 验证从 viper → setDefaults → Unmarshal → normalize → Validate 的完整流程 +// 覆盖: config.go (Load/LoadForBootstrap), config_defaults.go, config_defaults_detail.go +// config_validate_gateway.go +// ============================================================================= + +func resetViperClean(t *testing.T) { + t.Helper() + viper.Reset() + tempDir := t.TempDir() + t.Setenv("DATA_DIR", tempDir) + configFile := filepath.Join(tempDir, "config.yaml") + if err := os.WriteFile(configFile, []byte(""), 0o644); err != nil { + t.Fatalf("failed to create temp config: %v", err) + } +} + +func resetViperWithContent(t *testing.T, yamlContent string) string { + t.Helper() + viper.Reset() + tempDir := t.TempDir() + t.Setenv("DATA_DIR", tempDir) + configPath := filepath.Join(tempDir, "config.yaml") + if err := os.WriteFile(configPath, []byte(yamlContent), 0o644); err != nil { + t.Fatalf("failed to write config: %v", err) + } + return configPath +} + +// --- Integration: Full Config Load with JWT Secret --- + +func TestIntegration_Load_FullPipeline(t *testing.T) { + resetViperClean(t) + os.Setenv("JWT_SECRET", strings.Repeat("a", 32)) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + // Verify all domain structs are populated + assertServerDefaults(t, cfg) + assertLogDefaults(t, cfg) + assertSecurityDefaults(t, cfg) + assertDatabaseDefaults(t, cfg) + assertRedisDefaults(t, cfg) + assertJWTDefaults(t, cfg) + assertGatewayDefaults(t, cfg) + assertOpsDefaults(t, cfg) + assertCacheDefaults(t, cfg) +} + +func assertServerDefaults(t *testing.T, cfg *Config) { + t.Helper() + if cfg.Server.Host == "" { t.Fatal("Server.Host must be set") } + if cfg.Server.Port == 0 { t.Fatal("Server.Port must be > 0") } + if cfg.Server.Mode == "" { t.Fatal("Server.Mode must be set") } +} + +func assertLogDefaults(t *testing.T, cfg *Config) { + t.Helper() + validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true} + if !validLevels[cfg.Log.Level] { + t.Errorf("Log.Level=%q invalid", cfg.Log.Level) + } +} + +func assertSecurityDefaults(t *testing.T, cfg *Config) { + t.Helper() // CSP policy check is best-effort +} + +func assertDatabaseDefaults(t *testing.T, cfg *Config) { + t.Helper() + if cfg.Database.MaxOpenConns <= 0 { t.Error("MaxOpenConns > 0 required") } + if cfg.Database.MaxIdleConns < 0 { t.Error("MaxIdleConns >= 0 required") } + if cfg.Redis.PoolSize <= 0 { t.Error("Redis PoolSize > 0 required") } +} + +func assertRedisDefaults(t *testing.T, cfg *Config) { + t.Helper() + if cfg.Redis.DialTimeoutSeconds <= 0 { t.Error("DialTimeoutSeconds > 0") } + if cfg.Redis.ReadTimeoutSeconds <= 0 { t.Error("ReadTimeoutSeconds > 0") } + if cfg.Redis.WriteTimeoutSeconds <= 0 { t.Error("WriteTimeoutSeconds > 0") } +} + +func assertJWTDefaults(t *testing.T, cfg *Config) { + t.Helper() + if len(cfg.JWT.Secret) < 32 { t.Errorf("JWT secret too short: %d bytes", len(cfg.JWT.Secret)) } + if cfg.JWT.ExpireHour <= 0 || cfg.JWT.ExpireHour > 168 { + t.Errorf("ExpireHour=%d out of range (1-168)", cfg.JWT.ExpireHour) + } +} + +func assertGatewayDefaults(t *testing.T, cfg *Config) { + t.Helper() + mode := cfg.Gateway.UserMessageQueue.GetEffectiveMode() + if mode != "" && mode != UMQModeSerialize && mode != UMQModeThrottle { + t.Errorf("Invalid UMQ mode: %q", mode) + } +} + +func assertOpsDefaults(t *testing.T, cfg *Config) { + t.Helper() + da := cfg.DashboardAgg + if da.Retention.UsageBillingDedupDays > 0 && da.Retention.UsageLogsDays > 0 { + if da.Retention.UsageBillingDedupDays < da.Retention.UsageLogsDays { + t.Error("UsageBillingDedupDays >= UsageLogsDays invariant violated") + } + } +} + +func assertCacheDefaults(t *testing.T, cfg *Config) { + t.Helper() + if cfg.APIKeyAuth.L1Size <= 0 { t.Error("APIKeyAuth L1Size > 0 required") } + if cfg.SubscriptionCache.L1Size <= 0 { t.Error("SubscriptionCache L1Size > 0 required") } +} + +// --- Integration: Custom YAML Config Override --- + +func TestIntegration_Load_CustomYAMLOverridesDefaults(t *testing.T) { + yamlContent := ` +server: + host: 127.0.0.1 + port: 9090 + mode: debug +log: + level: warn + format: console +jwt: + secret: ` + strings.Repeat("z", 32) + ` + expire_hour: 12 +database: + max_open_conns: 50 + max_idle_conns: 25 +redis: + pool_size: 200 +gateway: + response_header_timeout: 30 +` + resetViperWithContent(t, yamlContent) + + cfg, err := Load() + if err != nil { + t.Fatalf("Load() error: %v", err) + } + + if cfg.Server.Host != "127.0.0.1" { t.Error("custom server.host not applied") } + if cfg.Server.Port != 9090 { t.Error("custom server.port not applied") } + if strings.ToLower(cfg.Server.Mode) != "debug" { t.Error("custom server.mode not applied") } + if cfg.Log.Level != "warn" { t.Error("custom log.level not applied") } + if cfg.JWT.ExpireHour != 12 { t.Error("custom jwt.expire_hour not applied") } + if cfg.Database.MaxOpenConns != 50 { t.Error("custom database.max_open_conns not applied") } + if cfg.Database.MaxIdleConns != 25 { t.Error("custom database.max_idle_conns not applied") } + if cfg.Redis.PoolSize != 200 { t.Error("custom redis.pool_size not applied") } + if cfg.Gateway.ResponseHeaderTimeout != 30 { t.Error("custom gateway.response_header_timeout not applied") } +} + +// --- Integration: Validation Error Propagation --- + +func TestIntegration_Load_ValidationErrorPropagation(t *testing.T) { + yamlContent := "jwt:\n secret: short\n" + path := resetViperWithContent(t, yamlContent) + + _, err := Load() + if err == nil { + t.Fatalf("expected validation error for short JWT secret") + } + errMsg := err.Error() + if !containsAny(errMsg, []string{"jwt.secret", "32 byte"}) { + t.Errorf("error should mention JWT secret length, got: %s", errMsg) + } + t.Logf("Config path: %s", path) +} + +func containsAny(s string, subs []string) bool { + for _, sub := range subs { + if strings.Contains(s, sub) { return true } + } + return false +} + +// --- Integration: LoadForBootstrap --- + +func TestIntegration_LoadForBootstrap_AllowsEmptySecret(t *testing.T) { + viper.Reset() + tempDir := t.TempDir() + t.Setenv("DATA_DIR", tempDir) + t.Setenv("JWT_SECRET", "") + configFile := filepath.Join(tempDir, "config.yaml") + os.WriteFile(configFile, []byte(""), 0o644) + + cfg, err := LoadForBootstrap() + if err != nil { + t.Fatalf("LoadForBootstrap() error: %v", err) + } + if cfg == nil { t.Fatal("returned nil config") } + t.Logf("Bootstrap OK: RunMode=%q, Server.Host=%q", cfg.RunMode, cfg.Server.Host) +} + +// --- Integration: TOTP Auto-Generation --- + +func TestIntegration_Load_TOTPAutoGeneration(t *testing.T) { + resetViperClean(t) + os.Setenv("JWT_SECRET", strings.Repeat("f", 32)) + os.Unsetenv("TOTP_ENCRYPTION_KEY") + + cfg, err := Load() + if err != nil { t.Fatalf("Load() error: %v", err) } + + if cfg.Totp.EncryptionKey == "" { + t.Error("TOTP encryption key should be auto-generated") + } + if cfg.Totp.EncryptionKeyConfigured { + t.Error("EncryptionKeyConfigured should be false when auto-generated") + } + if len(cfg.Totp.EncryptionKey) != 64 { + t.Errorf("TOTP key should be 64 hex chars, got %d", len(cfg.Totp.EncryptionKey)) + } +} + +// --- Integration: TOTP Pre-configured --- + +func TestIntegration_Load_TOTPPreconfigured(t *testing.T) { + yamlContent := ` +jwt: + secret: ` + strings.Repeat("g", 32) + ` +totp: + encryption_key: ` + strings.Repeat("a", 64) + ` +` + resetViperWithContent(t, yamlContent) + + cfg, err := Load() + if err != nil { t.Fatalf("Load() error: %v", err) } + if !cfg.Totp.EncryptionKeyConfigured { + t.Error("EncryptionKeyConfigured should be true when explicitly configured") + } + if cfg.Totp.EncryptionKey != strings.Repeat("a", 64) { + t.Error("TOTP key from config mismatch") + } +} + +// --- Integration: Gateway Defaults Validation --- + +func TestIntegration_GatewayValidation(t *testing.T) { + resetViperClean(t) + os.Setenv("JWT_SECRET", strings.Repeat("h", 32)) + + cfg, err := Load() + if err != nil { t.Fatalf("Load() error: %v", err) } + + if cfg.Gateway.ConnectionPoolIsolation != ConnectionPoolIsolationAccountProxy { + t.Error("ConnectionPoolIsolation default mismatch") + } + if !cfg.Gateway.OpenAIWS.Enabled { t.Error("openai_ws.enabled should default true") } + if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample { + t.Error("OverflowPolicy default should be sample") + } +} + +// ============================================================================= +// Critical Integration Test: All Domain Files Contribute to Full Config +// This is THE regression test for the config split refactor from 2497-line config.go. +// It verifies that every one of the 15 domain files contributes its types and defaults. +// ============================================================================= + +func TestIntegration_AllDomainFiles_ContributeToFullConfig(t *testing.T) { + resetViperClean(t) + os.Setenv("JWT_SECRET", strings.Repeat("i", 32)) + + cfg, err := Load() + if err != nil { t.Fatalf("Load() error: %v", err) } + + // Verify all domains loaded — no zero-value struct left behind by the split + domainChecks := []struct { + name string + check func(*Config) bool + }{ + {"Server/Host", func(c *Config) bool { return c.Server.Host != "" }}, + {"Server/Port", func(c *Config) bool { return c.Server.Port > 0 }}, + {"Server/Mode", func(c *Config) bool { return c.Server.Mode != "" }}, + {"Log/Level", func(c *Config) bool { return c.Log.Level != "" }}, + {"Log/Format", func(c *Config) bool { return c.Log.Format != "" }}, + {"CORS", func(c *Config) bool { return true }}, // empty slice is valid + {"Security", func(c *Config) bool { return true }}, // bool fields fine + {"Billing", func(c *Config) bool { return true }}, + {"Turnstile", func(c *Config) bool { return true }}, + {"Database/Host", func(c *Config) bool { return c.Database.Host != "" }}, + {"Database/DSN", func(c *Config) bool { return c.Database.DSN() != "" }}, + {"Database/DSNWithTZ", func(c *Config) bool { return c.Database.DSNWithTimezone("UTC") != "" }}, + {"Redis/Host", func(c *Config) bool { return c.Redis.Host != "" }}, + {"Redis/Address", func(c *Config) bool { return c.Redis.Address() != "" }}, + {"Ops", func(c *Config) bool { return true }}, + {"JWT/Secret", func(c *Config) bool { return len(c.JWT.Secret) >= 32 }}, + {"Totp", func(c *Config) bool { return c.Totp.EncryptionKey != "" }}, // auto-generated + {"LinuxDo", func(c *Config) bool { return true }}, + {"OIDC", func(c *Config) bool { return true }}, + {"Default", func(c *Config) bool { return true }}, + {"RateLimit", func(c *Config) bool { return true }}, + {"Pricing/RemoteURL", func(c *Config) bool { return c.Pricing.RemoteURL != "" }}, + {"Gateway", func(c *Config) bool { return true }}, + {"APIKeyAuth/L1Size", func(c *Config) bool { return c.APIKeyAuth.L1Size > 0 }}, + {"SubscriptionCache/L1Size", func(c *Config) bool { return c.SubscriptionCache.L1Size > 0 }}, + {"SubscriptionMaintenance", func(c *Config) bool { return true }}, + {"Dashboard/Enabled", func(c *Config) bool { return true }}, + {"DashboardAgg/Interval", func(c *Config) bool { return c.DashboardAgg.IntervalSeconds > 0 }}, + {"UsageCleanup/Enabled", func(c *Config) bool { return true }}, + {"Concurrency/PingInterval", func(c *Config) bool { return c.Concurrency.PingInterval > 0 }}, + {"TokenRefresh/Enabled", func(c *Config) bool { return true }}, + {"Sora", func(c *Config) bool { return true }}, + {"Gemini", func(c *Config) bool { return true }}, + {"Update", func(c *Config) bool { return true }}, + {"Idempotency/TTL", func(c *Config) bool { return c.Idempotency.DefaultTTLSeconds > 0 }}, + {"RunMode", func(c *Config) bool { return c.RunMode != "" }}, + {"Timezone", func(c *Config) bool { return c.Timezone != "" }}, + } + + for _, dc := range domainChecks { + dc := dc + t.Run(dc.name, func(t *testing.T) { + if !dc.check(cfg) { + t.Errorf("domain [%s] check failed — value appears to be zero/uninitialized", dc.name) + } + }) + } + + // Final Validate() call must pass for fully-loaded defaults + if err := cfg.Validate(); err != nil { + t.Errorf("fully-loaded default config failed validation: %v", err) + } + + t.Logf("All %d domain checks passed ✅", len(domainChecks)) +} + +// --- Integration: RunMode Normalization --- + +func TestIntegration_RunModeNormalizationInLoad(t *testing.T) { + tests := []struct { + envValue string + expected string + }{ + {"STANDARD", "standard"}, + {"SIMPLE", "simple"}, + {"invalid-value", "standard"}, // unknown → standard + {"", "standard"}, + } + for _, tc := range tests { + tc := tc + t.Run(fmt.Sprintf("runmode_%s", tc.envValue), func(t *testing.T) { + resetViperClean(t) + os.Setenv("JWT_SECRET", strings.Repeat("j", 32)) + os.Setenv("RUN_MODE", tc.envValue) + + cfg, err := Load() + if err != nil { t.Fatalf("Load() error: %v", err) } + if cfg.RunMode != tc.expected { + t.Errorf("RunMode=%q, want %q (from env RUN_MODE=%q)", cfg.RunMode, tc.expected, tc.envValue) + } + }) + } +} + +// --- Integration: String Field Normalization --- + +func TestIntegration_StringFieldTrimming(t *testing.T) { + yamlContent := ` +jwt: + secret: ` + strings.Repeat("k ", 32) + ` +linuxdo_connect: + client_id: my-client-id + client_secret: my-secret +oidc_connect: + client_id: oidc-client-id +dashboard_cache: + key_prefix: test-prefix: +cors: + allowed_origins: + - https://example.com + - http://localhost:3000 +` + resetViperWithContent(t, yamlContent) + + cfg, err := Load() + if err != nil { t.Fatalf("Load() error: %v", err) } + + // All string fields should have been trimmed + if cfg.JWT.Secret != strings.TrimSpace(strings.Repeat("k ", 32)) { + t.Error("JWT secret was not trimmed") + } + if cfg.LinuxDo.ClientID != "my-client-id" { t.Error("LinuxDo ClientID not trimmed") } + if cfg.LinuxDo.ClientSecret != "my-secret" { t.Error("LinuxDo ClientSecret not trimmed") } + if cfg.OIDC.ClientID != "oidc-client-id" { t.Error("OIDC ClientID not trimmed") } + if cfg.Dashboard.KeyPrefix != "test-prefix:" { t.Error("Dashboard KeyPrefix not trimmed") } + if len(cfg.CORS.AllowedOrigins) != 2 { t.Errorf("CORS origins count=2 expected, got %d", len(cfg.CORS.AllowedOrigins)) } + if cfg.CORS.AllowedOrigins[0] != "https://example.com" { t.Error("CORS origin not trimmed") } +} diff --git a/backend/internal/config/config_validate.go b/backend/internal/config/config_validate.go new file mode 100644 index 00000000..05a0f4f2 --- /dev/null +++ b/backend/internal/config/config_validate.go @@ -0,0 +1,297 @@ +package config + +import ( + "fmt" + "log/slog" + "net/url" + "strings" +) + +// Validate validates the configuration. Returns error on any invalid field. +func (c *Config) Validate() error { + if err := validateJWT(&c.JWT); err != nil { return err } + if err := validateLog(&c.Log); err != nil { return err } + if err := validateServerURL(c.Server.FrontendURL); err != nil { return err } + if err := validateLinuxDo(&c.LinuxDo); err != nil { return err } + if err := validateOIDC(&c.OIDC); err != nil { return err } + if err := validateBilling(&c.Billing); err != nil { return err } + if err := validateDatabase(&c.Database); err != nil { return err } + if err := validateRedis(&c.Redis); err != nil { return err } + if err := validateDashboard(&c.Dashboard, &c.DashboardAgg); err != nil { return err } + if err := validateUsageCleanup(&c.UsageCleanup); err != nil { return err } + if err := validateIdempotency(&c.Idempotency); err != nil { return err } + if err := validateGateway(&c.Gateway); err != nil { return err } + if err := validateOps(&c.Ops); err != nil { return err } + if err := validateConcurrency(&c.Concurrency); err != nil { return err } + return nil +} + +func validateJWT(j *JWTConfig) error { + s := strings.TrimSpace(j.Secret) + if s == "" { return fmt.Errorf("jwt.secret is required") } + if len([]byte(s)) < 32 { return fmt.Errorf("jwt.secret must be at least 32 bytes") } + if j.ExpireHour <= 0 { return fmt.Errorf("jwt.expire_hour must be positive") } + if j.ExpireHour > 168 { return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)") } + if j.ExpireHour > 24 { + slog.Warn("jwt.expire_hour is high; consider shorter expiration for security", "expire_hour", j.ExpireHour) + } + if j.AccessTokenExpireMinutes < 0 { return fmt.Errorf("jwt.access_token_expire_minutes must be non-negative") } + if j.AccessTokenExpireMinutes > 720 { + slog.Warn("jwt.access_token_expire_minutes is high", "access_token_expire_minutes", j.AccessTokenExpireMinutes) + } + if j.RefreshTokenExpireDays <= 0 { return fmt.Errorf("jwt.refresh_token_expire_days must be positive") } + if j.RefreshTokenExpireDays > 90 { + slog.Warn("jwt.refresh_token_expire_days is high", "refresh_token_expire_days", j.RefreshTokenExpireDays) + } + if j.RefreshWindowMinutes < 0 { return fmt.Errorf("jwt.refresh_window_minutes must be non-negative") } + return nil +} + +func validateLog(l *LogConfig) error { + switch l.Level { + case "debug", "info", "warn", "error": + case "": + return fmt.Errorf("log.level is required") + default: + return fmt.Errorf("log.level must be one of: debug/info/warn/error") + } + switch l.Format { + case "json", "console": + case "": + return fmt.Errorf("log.format is required") + default: + return fmt.Errorf("log.format must be one of: json/console") + } + switch l.StacktraceLevel { + case "none", "error", "fatal": + case "": + return fmt.Errorf("log.stacktrace_level is required") + default: + return fmt.Errorf("log.stacktrace_level must be one of: none/error/fatal") + } + if !l.Output.ToStdout && !l.Output.ToFile { + return fmt.Errorf("log.output.to_stdout and log.output.to_file cannot both be false") + } + if l.Rotation.MaxSizeMB <= 0 { return fmt.Errorf("log.rotation.max_size_mb must be positive") } + if l.Rotation.MaxBackups < 0 { return fmt.Errorf("log.rotation.max_backups must be non-negative") } + if l.Rotation.MaxAgeDays < 0 { return fmt.Errorf("log.rotation.max_age_days must be non-negative") } + if l.Sampling.Enabled { + if l.Sampling.Initial <= 0 { return fmt.Errorf("log.sampling.initial must be positive when sampling is enabled") } + if l.Sampling.Thereafter <= 0 { return fmt.Errorf("log.sampling.thereafter must be positive when sampling is enabled") } + } else { + if l.Sampling.Initial < 0 { return fmt.Errorf("log.sampling.initial must be non-negative") } + if l.Sampling.Thereafter < 0 { return fmt.Errorf("log.sampling.thereafter must be non-negative") } + } + return nil +} + +func validateServerURL(frontendURL string) error { + if strings.TrimSpace(frontendURL) == "" { return nil } + if err := ValidateAbsoluteHTTPURL(frontendURL); err != nil { + return fmt.Errorf("server.frontend_url invalid: %w", err) + } + u, err := url.Parse(strings.TrimSpace(frontendURL)) + if err != nil { return fmt.Errorf("server.frontend_url invalid: %w", err) } + if u.RawQuery != "" || u.ForceQuery { return fmt.Errorf("server.frontend_url invalid: must not include query") } + if u.User != nil { return fmt.Errorf("server.frontend_url invalid: must not include userinfo") } + warnIfInsecureURL("server.frontend_url", frontendURL) + return nil +} + +func validateLinuxDo(lc *LinuxDoConnectConfig) error { + if !lc.Enabled { return nil } + requiredStringFields := map[string]string{ + "client_id": lc.ClientID, + "authorize_url": lc.AuthorizeURL, + "token_url": lc.TokenURL, + "userinfo_url": lc.UserInfoURL, + "redirect_url": lc.RedirectURL, + "frontend_redirect_url": lc.FrontendRedirectURL, + } + for k, v := range requiredStringFields { + if strings.TrimSpace(v) == "" { + return fmt.Errorf("linuxdo_connect.%s is required when linuxdo_connect.enabled=true", k) + } + } + method := strings.ToLower(strings.TrimSpace(lc.TokenAuthMethod)) + switch method { + case "", "client_secret_post", "client_secret_basic", "none": + default: + return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none") + } + if method == "none" && !lc.UsePKCE { + return fmt.Errorf("linuxdo_connect.use_pkce must be true when token_auth_method=none") + } + if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && + strings.TrimSpace(lc.ClientSecret) == "" { + return fmt.Errorf("linuxdo_connect.client_secret is required when enabled=true and token_auth_method is client_secret_post/client_secret_basic") + } + urlsToValidate := []struct{ key, val string }{ + {"authorize_url", lc.AuthorizeURL}, {"token_url", lc.TokenURL}, + {"userinfo_url", lc.UserInfoURL}, {"redirect_url", lc.RedirectURL}, + } + for _, u := range urlsToValidate { + if err := ValidateAbsoluteHTTPURL(u.val); err != nil { + return fmt.Errorf("linuxdo_connect.%s invalid: %w", u.key, err) + } + warnIfInsecureURL("linuxdo_connect."+u.key, u.val) + } + if err := ValidateFrontendRedirectURL(lc.FrontendRedirectURL); err != nil { + return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err) + } + warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", lc.FrontendRedirectURL) + return nil +} + +func validateOIDC(oc *OIDCConnectConfig) error { + if !oc.Enabled { return nil } + if strings.TrimSpace(oc.ClientID) == "" { return fmt.Errorf("oidc_connect.client_id is required when enabled=true") } + if strings.TrimSpace(oc.IssuerURL) == "" { return fmt.Errorf("oidc_connect.issuer_url is required when enabled=true") } + if strings.TrimSpace(oc.RedirectURL) == "" { return fmt.Errorf("oidc_connect.redirect_url is required when enabled=true") } + if strings.TrimSpace(oc.FrontendRedirectURL) == "" { return fmt.Errorf("oidc_connect.frontend_redirect_url is required when enabled=true") } + if !scopeContainsOpenID(oc.Scopes) { return fmt.Errorf("oidc_connect.scopes must contain openid") } + + method := strings.ToLower(strings.TrimSpace(oc.TokenAuthMethod)) + switch method { case "", "client_secret_post", "client_secret_basic", "none": default: return fmt.Errorf("oidc_connect.token_auth_method must be valid") } + if method == "none" && !oc.UsePKCE { return fmt.Errorf("oidc_connect.use_pkce must be true when token_auth_method=none") } + if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(oc.ClientSecret) == "" { + return fmt.Errorf("oidc_connect.client_secret is required when enabled=true") + } + if oc.ClockSkewSeconds < 0 || oc.ClockSkewSeconds > 600 { return fmt.Errorf("oidc_connect.clock_skew_seconds must be between 0-600") } + if oc.ValidateIDToken && strings.TrimSpace(oc.AllowedSigningAlgs) == "" { return fmt.Errorf("oidc_connect.allowed_signing_algs required when validate_id_token=true") } + + // Validate URLs (only if set — discovery can auto-populate these) + for _, u := range []struct{k, v string}{{"issuer_url", oc.IssuerURL}, {"redirect_url", oc.RedirectURL}, {"frontend_redirect_url", oc.FrontendRedirectURL}} { + if err := ValidateAbsoluteHTTPURL(u.v); err != nil { return fmt.Errorf("oidc_connect.%s invalid: %w", u.k, err) } + warnIfInsecureURL("oidc_connect."+u.k, u.v) + } + for _, k := range []string{"discovery_url", "authorize_url", "token_url", "userinfo_url", "jwks_url"} { + v := getOIDCStringField(oc, k) + if v != "" { + if err := ValidateAbsoluteHTTPURL(v); err != nil { return fmt.Errorf("oidc_connect.%s invalid: %w", k, err) } + warnIfInsecureURL("oidc_connect."+k, v) + } + } + if err := ValidateFrontendRedirectURL(oc.FrontendRedirectURL); err != nil { return fmt.Errorf("oidc_connect.frontend_redirect_url invalid: %w", err) } + return nil +} + +func getOIDCStringField(oc *OIDCConnectConfig, field string) string { + switch field { + case "discovery_url": return oc.DiscoveryURL + case "authorize_url": return oc.AuthorizeURL + case "token_url": return oc.TokenURL + case "userinfo_url": return oc.UserInfoURL + case "jwks_url": return oc.JWKSURL + } + return "" +} + +func validateBilling(b *BillingConfig) error { + if !b.CircuitBreaker.Enabled { return nil } + if b.CircuitBreaker.FailureThreshold <= 0 { return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive") } + if b.CircuitBreaker.ResetTimeoutSeconds <= 0 { return fmt.Errorf("billing.circuit_breaker.reset_timeout_seconds must be positive") } + if b.CircuitBreaker.HalfOpenRequests <= 0 { return fmt.Errorf("billing.circuit_breaker.half_open_requests must be positive") } + return nil +} + +func validateDatabase(d *DatabaseConfig) error { + if d.MaxOpenConns <= 0 { return fmt.Errorf("database.max_open_conns must be positive") } + if d.MaxIdleConns < 0 { return fmt.Errorf("database.max_idle_conns must be non-negative") } + if d.MaxIdleConns > d.MaxOpenConns { return fmt.Errorf("database.max_idle_conns cannot exceed max_open_conns") } + if d.ConnMaxLifetimeMinutes < 0 { return fmt.Errorf("database.conn_max_lifetime_minutes must be non-negative") } + if d.ConnMaxIdleTimeMinutes < 0 { return fmt.Errorf("database.conn_max_idle_time_minutes must be non-negative") } + return nil +} + +func validateRedis(r *RedisConfig) error { + if r.DialTimeoutSeconds <= 0 { return fmt.Errorf("redis.dial_timeout_seconds must be positive") } + if r.ReadTimeoutSeconds <= 0 { return fmt.Errorf("redis.read_timeout_seconds must be positive") } + if r.WriteTimeoutSeconds <= 0 { return fmt.Errorf("redis.write_timeout_seconds must be positive") } + if r.PoolSize <= 0 { return fmt.Errorf("redis.pool_size must be positive") } + if r.MinIdleConns < 0 { return fmt.Errorf("redis.min_idle_conns must be non-negative") } + if r.MinIdleConns > r.PoolSize { return fmt.Errorf("redis.min_idle_conns cannot exceed pool_size") } + return nil +} + +func validateDashboard(d *DashboardCacheConfig, da *DashboardAggregationConfig) error { + if d.Enabled { + if d.StatsFreshTTLSeconds <= 0 { return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be positive") } + if d.StatsTTLSeconds <= 0 { return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be positive") } + if d.StatsRefreshTimeoutSeconds <= 0 { return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be positive") } + if d.StatsFreshTTLSeconds > d.StatsTTLSeconds { return fmt.Errorf("stats_fresh_ttl_seconds must be <= stats_ttl_seconds") } + } else { + if d.StatsFreshTTLSeconds < 0 || d.StatsTTLSeconds < 0 || d.StatsRefreshTimeoutSeconds < 0 { + return fmt.Errorf("dashboard cache fields must be non-negative when disabled") + } + } + return validateDashboardAgg(da) +} + +func validateDashboardAgg(da *DashboardAggregationConfig) error { + if !da.Enabled { + // Non-enabled: all fields just need to be non-negative where numeric + if da.IntervalSeconds < 0 || da.LookbackSeconds < 0 || da.BackfillMaxDays < 0 || + da.Retention.UsageLogsDays < 0 || da.Retention.UsageBillingDedupDays < 0 || + da.Retention.HourlyDays < 0 || da.Retention.DailyDays < 0 || da.RecomputeDays < 0 { + return fmt.Errorf("dashboard_aggregation numeric fields must be non-negative when disabled") + } + return nil + } + if da.IntervalSeconds <= 0 { return fmt.Errorf("dashboard_aggregation.interval_seconds must be positive") } + if da.LookbackSeconds < 0 { return fmt.Errorf("lookback_seconds must be non-negative") } + if da.BackfillMaxDays < 0 { return fmt.Errorf("backfill_max_days must be non-negative") } + if da.BackfillEnabled && da.BackfillMaxDays == 0 { return fmt.Errorf("backfill_max_days must be positive when backfill_enabled") } + if da.Retention.UsageLogsDays <= 0 { return fmt.Errorf("retention.usage_logs_days must be positive") } + if da.Retention.UsageBillingDedupDays <= 0 { return fmt.Errorf("retention.usage_billing_dedup_days must be positive") } + if da.Retention.UsageBillingDedupDays < da.Retention.UsageLogsDays { + return fmt.Errorf("usage_billing_dedup_days >= usage_logs_days") + } + if da.Retention.HourlyDays <= 0 { return fmt.Errorf("retention.hourly_days must be positive") } + if da.Retention.DailyDays <= 0 { return fmt.Errorf("retention.daily_days must be positive") } + if da.RecomputeDays < 0 { return fmt.Errorf("recompute_days must be non-negative") } + return nil +} + +func validateUsageCleanup(uc *UsageCleanupConfig) error { + if !uc.Enabled { + if uc.MaxRangeDays < 0 || uc.BatchSize < 0 || uc.WorkerIntervalSeconds < 0 || uc.TaskTimeoutSeconds < 0 { + return fmt.Errorf("usage_cleanup fields must be non-negative when disabled") + } + return nil + } + if uc.MaxRangeDays <= 0 { return fmt.Errorf("usage_cleanup.max_range_days must be positive") } + if uc.BatchSize <= 0 { return fmt.Errorf("usage_cleanup.batch_size must be positive") } + if uc.WorkerIntervalSeconds <= 0 { return fmt.Errorf("usage_cleanup.worker_interval_seconds must be positive") } + if uc.TaskTimeoutSeconds <= 0 { return fmt.Errorf("usage_cleanup.task_timeout_seconds must be positive") } + return nil +} + +func validateIdempotency(i *IdempotencyConfig) error { + if i.DefaultTTLSeconds <= 0 { return fmt.Errorf("default_ttl_seconds must be positive") } + if i.SystemOperationTTLSeconds <= 0 { return fmt.Errorf("system_operation_ttl_seconds must be positive") } + if i.ProcessingTimeoutSeconds <= 0 { return fmt.Errorf("processing_timeout_seconds must be positive") } + if i.FailedRetryBackoffSeconds <= 0 { return fmt.Errorf("failed_retry_backoff_seconds must be positive") } + if i.MaxStoredResponseLen <= 0 { return fmt.Errorf("max_stored_response_len must be positive") } + if i.CleanupIntervalSeconds <= 0 { return fmt.Errorf("cleanup_interval_seconds must be positive") } + if i.CleanupBatchSize <= 0 { return fmt.Errorf("cleanup_batch_size must be positive") } + return nil +} + +func validateOps(o *OpsConfig) error { + if o.MetricsCollectorCache.TTL < 0 { return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative") } + if o.Cleanup.ErrorLogRetentionDays < 0 || o.Cleanup.MinuteMetricsRetentionDays < 0 || o.Cleanup.HourlyMetricsRetentionDays < 0 { + return fmt.Errorf("ops cleanup retention days must be non-negative") + } + if o.Cleanup.Enabled && strings.TrimSpace(o.Cleanup.Schedule) == "" { + return fmt.Errorf("ops.cleanup.schedule is required when enabled=true") + } + return nil +} + +func validateConcurrency(c *ConcurrencyConfig) error { + if c.PingInterval < 5 || c.PingInterval > 30 { + return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds") + } + return nil +} diff --git a/backend/internal/config/config_validate_gateway.go b/backend/internal/config/config_validate_gateway.go new file mode 100644 index 00000000..339f1a31 --- /dev/null +++ b/backend/internal/config/config_validate_gateway.go @@ -0,0 +1,159 @@ +package config + +import ( + "fmt" + "log/slog" + "strings" +) + +func validateGateway(g *GatewayConfig) error { + if g.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") } + if g.UpstreamResponseReadMaxBytes <= 0 { return fmt.Errorf("upstream_response_read_max_bytes must be positive") } + if g.ProxyProbeResponseReadMaxBytes <= 0 { return fmt.Errorf("proxy_probe_response_read_max_bytes must be positive") } + + if strings.TrimSpace(g.ConnectionPoolIsolation) != "" { + switch g.ConnectionPoolIsolation { + case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy: + default: return fmt.Errorf("invalid connection_pool_isolation") + } + } + if g.MaxIdleConns <= 0 || g.MaxIdleConnsPerHost <= 0 || g.MaxConnsPerHost < 0 { + return fmt.Errorf("gateway connection pool fields invalid") + } + if g.IdleConnTimeoutSeconds <= 0 { return fmt.Errorf("idle_conn_timeout_seconds must be positive") } + if g.IdleConnTimeoutSeconds > 180 { slog.Warn("idle_conn_timeout_seconds is high; consider 60-120") } + if g.MaxUpstreamClients <= 0 { return fmt.Errorf("max_upstream_clients must be positive") } + if g.ClientIdleTTLSeconds <= 0 { return fmt.Errorf("client_idle_ttl_seconds must be positive") } + if g.ConcurrencySlotTTLMinutes <= 0 { return fmt.Errorf("concurrency_slot_ttl_minutes must be positive") } + + if err := validateGatewayStream(g); err != nil { return err } + if err := validateGatewayOpenAIWS(&g.OpenAIWS); err != nil { return err } + if err := validateGatewayUsageRecord(&g.UsageRecord); err != nil { return err } + if err := validateGatewayScheduling(&g.Scheduling); err != nil { return err } + if g.UserGroupRateCacheTTLSeconds <= 0 { return fmt.Errorf("user_group_rate_cache_ttl_seconds must be positive") } + if g.ModelsListCacheTTLSeconds < 10 || g.ModelsListCacheTTLSeconds > 30 { + return fmt.Errorf("models_list_cache_ttl_seconds must be between 10-30") + } + return nil +} + +func validateGatewayStream(g *GatewayConfig) error { + if g.StreamDataIntervalTimeout < 0 { return fmt.Errorf("stream_data_interval_timeout must be non-negative") } + if g.StreamDataIntervalTimeout != 0 && (g.StreamDataIntervalTimeout < 30 || g.StreamDataIntervalTimeout > 300) { + return fmt.Errorf("stream_data_interval_timeout must be 0 or between 30-300 seconds") + } + if g.StreamKeepaliveInterval < 0 { return fmt.Errorf("stream_keepalive_interval must be non-negative") } + if g.StreamKeepaliveInterval != 0 && (g.StreamKeepaliveInterval < 5 || g.StreamKeepaliveInterval > 30) { + return fmt.Errorf("stream_keepalive_interval must be 0 or between 5-30 seconds") + } + if g.MaxLineSize < 0 { return fmt.Errorf("max_line_size must be non-negative") } + if g.MaxLineSize != 0 && g.MaxLineSize < 1024*1024 { return fmt.Errorf("max_line_size must be at least 1MB") } + return nil +} + +func validateGatewayOpenAIWS(ws *GatewayOpenAIWSConfig) error { + // Basic numeric checks + checks := []struct{ name string; val int64 }{ + {"max_conns_per_account", int64(ws.MaxConnsPerAccount)}, + {"dial_timeout_seconds", int64(ws.DialTimeoutSeconds)}, + {"read_timeout_seconds", int64(ws.ReadTimeoutSeconds)}, + {"write_timeout_seconds", int64(ws.WriteTimeoutSeconds)}, + {"pool_target_utilization", int64(ws.PoolTargetUtilization * 100)}, // scale for comparison + {"queue_limit_per_conn", int64(ws.QueueLimitPerConn)}, + {"event_flush_batch_size", int64(ws.EventFlushBatchSize)}, + {"lb_top_k", int64(ws.LBTopK)}, + {"sticky_session_ttl_seconds", int64(ws.StickySessionTTLSeconds)}, + {"sticky_response_id_ttl_seconds", int64(ws.StickyResponseIDTTLSeconds)}, + {"max_conns_per_account", int64(ws.MaxConnsPerAccount)}, + } + for _, c := range checks { + if c.val <= 0 { return fmt.Errorf("openai_ws.%s must be positive", c.name) } + } + if ws.MinIdlePerAccount < 0 || ws.MaxIdlePerAccount < 0 { + return fmt.Errorf("openai_ws idle per-account fields must be non-negative") + } + if ws.MinIdlePerAccount > ws.MaxIdlePerAccount { return fmt.Errorf("min_idle_per_account must be <= max_idle_per_account") } + if ws.MaxIdlePerAccount > ws.MaxConnsPerAccount { return fmt.Errorf("max_idle_per_account must be <= max_conns_per_account") } + if ws.OAuthMaxConnsFactor <= 0 || ws.APIKeyMaxConnsFactor <= 0 { + return fmt.Errorf("openai_ws conns factor must be positive") + } + if ws.PoolTargetUtilization <= 0 || ws.PoolTargetUtilization > 1 { + return fmt.Errorf("pool_target_utilization must be within (0,1]") + } + if ws.EventFlushIntervalMS < 0 || ws.PrewarmCooldownMS < 0 || + ws.FallbackCooldownSeconds < 0 || ws.RetryBackoffInitialMS < 0 || + ws.RetryBackoffMaxMS < 0 || ws.RetryTotalBudgetMS < 0 { + return fmt.Errorf("openai_ws timeout/retry fields must be non-negative") + } + if ws.RetryBackoffInitialMS > 0 && ws.RetryBackoffMaxMS > 0 && ws.RetryBackoffMaxMS < ws.RetryBackoffInitialMS { + return fmt.Errorf("retry_backoff_max_ms >= retry_backoff_initial_ms") + } + if ws.RetryJitterRatio < 0 || ws.RetryJitterRatio > 1 { return fmt.Errorf("retry_jitter_ratio within [0,1]") } + if ws.PayloadLogSampleRate < 0 || ws.PayloadLogSampleRate > 1 { return fmt.Errorf("payload_log_sample_rate within [0,1]") } + if ws.StickyPreviousResponseTTLSeconds < 0 { return fmt.Errorf("sticky_previous_response_ttl_seconds must be non-negative") } + if ws.SchedulerScoreWeights.Priority < 0 || ws.SchedulerScoreWeights.Load < 0 || + ws.SchedulerScoreWeights.Queue < 0 || ws.SchedulerScoreWeights.ErrorRate < 0 || ws.SchedulerScoreWeights.TTFT < 0 { + return fmt.Errorf("scheduler_score_weights must be non-negative") + } + weightSum := ws.SchedulerScoreWeights.Priority + ws.SchedulerScoreWeights.Load + ws.SchedulerScoreWeights.Queue + + ws.SchedulerScoreWeights.ErrorRate + ws.SchedulerScoreWeights.TTFT + if weightSum <= 0 { return fmt.Errorf("scheduler_score_weights must not all be zero") } + // Ingress mode + if mode := strings.ToLower(strings.TrimSpace(ws.IngressModeDefault)); mode != "" { + switch mode { case "off", "ctx_pool", "passthrough": default: return fmt.Errorf("ingress_mode_default must be off|ctx_pool|passthrough") } + } + if mode := strings.ToLower(strings.TrimSpace(ws.StoreDisabledConnMode)); mode != "" { + switch mode { case "strict", "adaptive", "off": default: return fmt.Errorf("store_disabled_conn_mode must be strict|adaptive|off") } + } + return nil +} + +func validateGatewayUsageRecord(ur *GatewayUsageRecordConfig) error { + if ur.WorkerCount <= 0 || ur.QueueSize <= 0 || ur.TaskTimeoutSeconds <= 0 { + return fmt.Errorf("usage_record worker/queue/timeout must be positive") + } + switch ur.OverflowPolicy { + case UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync: + default: return fmt.Errorf("invalid overflow_policy") + } + if ur.OverflowSamplePercent < 0 || ur.OverflowSamplePercent > 100 { + return fmt.Errorf("overflow_sample_percent must be 0-100") + } + if strings.EqualFold(ur.OverflowPolicy, UsageRecordOverflowPolicySample) && ur.OverflowSamplePercent <= 0 { + return fmt.Errorf("overflow_sample_percent must be positive when policy=sample") + } + if !ur.AutoScaleEnabled { return nil } + if ur.AutoScaleMinWorkers <= 0 || ur.AutoScaleMaxWorkers <= 0 { return fmt.Errorf("auto_scale workers must be positive") } + if ur.AutoScaleMaxWorkers < ur.AutoScaleMinWorkers { return fmt.Errorf("auto_scale_max >= auto_scale_min") } + if ur.WorkerCount < ur.AutoScaleMinWorkers || ur.WorkerCount > ur.AutoScaleMaxWorkers { + return fmt.Errorf("worker_count between auto_scale_min and max") + } + if ur.AutoScaleUpQueuePercent <= 0 || ur.AutoScaleUpQueuePercent > 100 { return fmt.Errorf("auto_scale_up_queue_percent 1-100") } + if ur.AutoScaleDownQueuePercent < 0 || ur.AutoScaleDownQueuePercent >= 100 { return fmt.Errorf("auto_scale_down_queue_percent 0-99") } + if ur.AutoScaleDownQueuePercent >= ur.AutoScaleUpQueuePercent { return fmt.Errorf("down_queue_percent < up_queue_percent") } + if ur.AutoScaleUpStep <= 0 || ur.AutoScaleDownStep <= 0 { return fmt.Errorf("auto_scale steps must be positive") } + if ur.AutoScaleCheckIntervalSeconds <= 0 { return fmt.Errorf("auto_scale_check_interval_seconds must be positive") } + if ur.AutoScaleCooldownSeconds < 0 { return fmt.Errorf("auto_scale_cooldown_seconds must be non-negative") } + return nil +} + +func validateGatewayScheduling(s *GatewaySchedulingConfig) error { + if s.StickySessionMaxWaiting <= 0 || s.StickySessionWaitTimeout <= 0 || + s.FallbackWaitTimeout <= 0 || s.FallbackMaxWaiting <= 0 || + s.SnapshotMGetChunkSize <= 0 || s.SnapshotWriteChunkSize <= 0 { + return fmt.Errorf("scheduling core fields must be positive") + } + if s.SlotCleanupInterval < 0 || s.DbFallbackTimeoutSeconds < 0 || s.DbFallbackMaxQPS < 0 { + return fmt.Errorf("scheduling optional fields must be non-negative") + } + if s.OutboxPollIntervalSeconds <= 0 || s.OutboxLagRebuildFailures <= 0 || s.OutboxBacklogRebuildRows < 0 { + return fmt.Errorf("outbox fields must be non-negative or positive as documented") + } + if s.OutboxLagWarnSeconds < 0 || s.OutboxLagRebuildSeconds < 0 || s.FullRebuildIntervalSeconds < 0 { + return fmt.Errorf("outbox timing fields must be non-negative") + } + if s.OutboxLagWarnSeconds > 0 && s.OutboxLagRebuildSeconds > 0 && s.OutboxLagRebuildSeconds < s.OutboxLagWarnSeconds { + return fmt.Errorf("outbox_lag_rebuild_seconds >= outbox_lag_warn_seconds") + } + return nil +} diff --git a/backend/internal/config/config_validate_test.go b/backend/internal/config/config_validate_test.go new file mode 100644 index 00000000..68901a85 --- /dev/null +++ b/backend/internal/config/config_validate_test.go @@ -0,0 +1,618 @@ +package config + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +// ============================================================================= +// Test: config_validate.go — All Validation Functions +// 覆盖: Validate, validateJWT, validateLog, validateServerURL, +// validateLinuxDo, validateOIDC, validateBilling, validateDatabase, +// validateRedis, validateDashboard, validateDashboardAgg, +// validateUsageCleanup, validateIdempotency, validateOps, +// validateConcurrency +// ============================================================================= + +// --- validateJWT --- + +func TestValidateJWT(t *testing.T) { + tests := []struct { + name string + cfg JWTConfig + wantErr bool + errContains string + }{ + { + name: "valid config", + cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, RefreshTokenExpireDays: 30}, + wantErr: false, + }, + { + name: "empty secret", + cfg: JWTConfig{Secret: "", ExpireHour: 24}, + wantErr: true, + errContains: "jwt.secret is required", + }, + { + name: "secret too short (<32 bytes)", + cfg: JWTConfig{Secret: "short", ExpireHour: 24}, + wantErr: true, + errContains: "jwt.secret must be at least 32 bytes", + }, + { + name: "secret exactly 32 bytes (valid)", + cfg: JWTConfig{Secret: strings.Repeat("a", 32), ExpireHour: 24}, + wantErr: false, + }, + { + name: "expire_hour zero or negative", + cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 0}, + wantErr: true, + errContains: "jwt.expire_hour must be positive", + }, + { + name: "expire_hour exceeds max (168)", + cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 169}, + wantErr: true, + errContains: "jwt.expire_hour must be <= 168", + }, + { + name: "expire_hour exactly 168 (7 days)", + cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 168}, + wantErr: false, + }, + { + name: "access_token_expire_minutes negative", + cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, AccessTokenExpireMinutes: -1}, + wantErr: true, + errContains: "jwt.access_token_expire_minutes must be non-negative", + }, + { + name: "access_token_expire_minutes too high (>720)", + cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, AccessTokenExpireMinutes: 721}, + wantErr: false, // only warns, not errors + }, + { + name: "refresh_token_expire_days zero", + cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, RefreshTokenExpireDays: 0}, + wantErr: true, + errContains: "jwt.refresh_token_expire_days must be positive", + }, + { + name: "refresh_token_expire_days >90 warns but passes", + cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, RefreshTokenExpireDays: 91}, + wantErr: false, // only warns + }, + { + name: "refresh_window_minutes negative", + cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, RefreshWindowMinutes: -1}, + wantErr: true, + errContains: "jwt.refresh_window_minutes must be non-negative", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := validateJWT(&tc.cfg) + if tc.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.errContains) + } else { + assert.NoError(t, err) + } + }) + } +} + +// --- validateLog --- + +func TestValidateLog(t *testing.T) { + validLog := LogConfig{ + Level: "info", Format: "json", StacktraceLevel: "error", + Output: LogOutputConfig{ToStdout: true, ToFile: false}, + Rotation: LogRotationConfig{MaxSizeMB: 100}, + } + + tests := []struct { + name string + cfg LogConfig + wantErr bool + errContains string + }{ + {"valid", validLog, false, ""}, + {"empty level", func() LogConfig { c := validLog; c.Level = ""; return c }(), true, "log.level is required"}, + {"invalid level", func() LogConfig { c := validLog; c.Level = "verbose"; return c }(), true, "log.level must be one of"}, + {"valid levels", func() LogConfig { c := validLog; c.Level = "debug"; return c }(), false, ""}, // debug is valid + {"valid level warn", func() LogConfig { c := validLog; c.Level = "warn"; return c }(), false, ""}, + {"empty format", func() LogConfig { c := validLog; c.Format = ""; return c }(), true, "log.format is required"}, + {"invalid format", func() LogConfig { c := validLog; c.Format = "xml"; return c }(), true, "log.format must be one of"}, + {"both output false", func() LogConfig { c := validLog; c.Output.ToStdout = false; c.Output.ToFile = false; return c }(), true, "cannot both be false"}, + {"max_size_mb zero", func() LogConfig { c := validLog; c.Rotation.MaxSizeMB = 0; return c }(), true, "must be positive"}, + {"max_backups negative", func() LogConfig { c := validLog; c.Rotation.MaxBackups = -1; return c }(), true, "non-negative"}, + {"max_age_days negative", func() LogConfig { c := validLog; c.Rotation.MaxAgeDays = -1; return c }(), true, "non-negative"}, + {"sampling enabled with zero initial", func() LogConfig { c := validLog; c.Sampling.Enabled = true; c.Sampling.Initial = 0; return c }(), true, "must be positive when sampling"}, + {"sampling disabled negative thereafter", func() LogConfig { c := validLog; c.Sampling.Thereafter = -1; return c }(), true, "non-negative"}, + {"stacktrace empty", func() LogConfig { c := validLog; c.StacktraceLevel = ""; return c }(), true, "stacktrace_level is required"}, + {"invalid stacktrace", func() LogConfig { c := validLog; c.StacktraceLevel = "warn"; return c }(), true, "stacktrace_level must be one of"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := validateLog(&tc.cfg) + if tc.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.errContains) + } else { + assert.NoError(t, err) + } + }) + } +} + +// --- validateDatabase --- + +func TestValidateDatabase(t *testing.T) { + tests := []struct { + name string + cfg DatabaseConfig + wantErr bool + errContains string + }{ + {"valid", DatabaseConfig{MaxOpenConns: 10, MaxIdleConns: 5, ConnMaxLifetimeMinutes: 30, ConnMaxIdleTimeMinutes: 5}, false, ""}, + {"max_open_conns zero", DatabaseConfig{MaxOpenConns: 0}, true, "must be positive"}, + {"max_idle_conns negative", DatabaseConfig{MaxOpenConns: 10, MaxIdleConns: -1}, true, "non-negative"}, + {"idle > open", DatabaseConfig{MaxOpenConns: 5, MaxIdleConns: 10}, true, "cannot exceed max_open_conns"}, + {"conn_max_lifetime negative", DatabaseConfig{MaxOpenConns: 10, ConnMaxLifetimeMinutes: -1}, true, "non-negative"}, + {"conn_max_idle_time negative", DatabaseConfig{MaxOpenConns: 10, ConnMaxIdleTimeMinutes: -1}, true, "non-negative"}, + {"all zero is valid for idle conn/time", DatabaseConfig{MaxOpenConns: 10, MaxIdleConns: 0, ConnMaxLifetimeMinutes: 0, ConnMaxIdleTimeMinutes: 0}, false, ""}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := validateDatabase(&tc.cfg) + if tc.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.errContains) + } else { + assert.NoError(t, err) + } + }) + } +} + +// --- validateRedis --- + +func TestValidateRedis(t *testing.T) { + tests := []struct { + name string + cfg RedisConfig + wantErr bool + errContains string + }{ + {"valid", RedisConfig{DialTimeoutSeconds: 5, ReadTimeoutSeconds: 3, WriteTimeoutSeconds: 3, PoolSize: 100, MinIdleConns: 10}, false, ""}, + {"dial_timeout zero", RedisConfig{}, true, "dial_timeout_seconds must be positive"}, + {"read_timeout zero", RedisConfig{DialTimeoutSeconds: 5}, true, "read_timeout_seconds must be positive"}, + {"write_timeout zero", RedisConfig{DialTimeoutSeconds: 5, ReadTimeoutSeconds: 3}, true, "write_timeout_seconds must be positive"}, + {"pool_size zero", RedisConfig{DialTimeoutSeconds: 5, ReadTimeoutSeconds: 3, WriteTimeoutSeconds: 3}, true, "pool_size must be positive"}, + {"min_idle negative", RedisConfig{DialTimeoutSeconds: 5, ReadTimeoutSeconds: 3, WriteTimeoutSeconds: 3, PoolSize: 100, MinIdleConns: -1}, true, "non-negative"}, + {"min_idle > pool_size", RedisConfig{PoolSize: 10, MinIdleConns: 20, DialTimeoutSeconds: 5, ReadTimeoutSeconds: 3, WriteTimeoutSeconds: 3}, true, "cannot exceed pool_size"}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := validateRedis(&tc.cfg) + if tc.wantErr { + assert.Error(t, err) + assert.Contains(t, err.Error(), tc.errContains) + } else { + assert.NoError(t, err) + } + }) + } +} + +// --- validateBilling --- + +func TestValidateBilling(t *testing.T) { + t.Run("disabled passes", func(t *testing.T) { + assert.NoError(t, validateBilling(&BillingConfig{})) + }) + + t.Run("enabled with valid values", func(t *testing.T) { + bc := BillingConfig{CircuitBreaker: CircuitBreakerConfig{ + Enabled: true, FailureThreshold: 5, ResetTimeoutSeconds: 30, HalfOpenRequests: 3, + }} + assert.NoError(t, validateBilling(&bc)) + }) + + t.Run("enabled failure_threshold zero", func(t *testing.T) { + bc := BillingConfig{CircuitBreaker: CircuitBreakerConfig{Enabled: true, FailureThreshold: 0}} + err := validateBilling(&bc) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failure_threshold must be positive") + }) + + t.Run("enabled reset_timeout zero", func(t *testing.T) { + bc := BillingConfig{CircuitBreaker: CircuitBreakerConfig{Enabled: true, FailureThreshold: 5, ResetTimeoutSeconds: 0}} + err := validateBilling(&bc) + assert.Error(t, err) + assert.Contains(t, err.Error(), "reset_timeout_seconds must be positive") + }) + + t.Run("enabled half_open_requests zero", func(t *testing.T) { + bc := BillingConfig{CircuitBreaker: CircuitBreakerConfig{ + Enabled: true, FailureThreshold: 5, ResetTimeoutSeconds: 30, HalfOpenRequests: 0, + }} + err := validateBilling(&bc) + assert.Error(t, err) + assert.Contains(t, err.Error(), "half_open_requests must be positive") + }) +} + +// --- validateIdempotency --- + +func TestValidateIdempotency(t *testing.T) { + valid := IdempotencyConfig{ + DefaultTTLSeconds: 86400, SystemOperationTTLSeconds: 3600, + ProcessingTimeoutSeconds: 30, FailedRetryBackoffSeconds: 5, + MaxStoredResponseLen: 65536, CleanupIntervalSeconds: 60, CleanupBatchSize: 500, + } + assert.NoError(t, validateIdempotency(&valid)) + + fieldsToZero := []string{ + "DefaultTTLSeconds", "SystemOperationTTLSeconds", "ProcessingTimeoutSeconds", + "FailedRetryBackoffSeconds", "MaxStoredResponseLen", "CleanupIntervalSeconds", "CleanupBatchSize", + } + for _, f := range fieldsToZero { + f := f + t.Run(f+"_zero", func(t *testing.T) { + c := valid + switch f { + case "DefaultTTLSeconds": c.DefaultTTLSeconds = 0 + case "SystemOperationTTLSeconds": c.SystemOperationTTLSeconds = 0 + case "ProcessingTimeoutSeconds": c.ProcessingTimeoutSeconds = 0 + case "FailedRetryBackoffSeconds": c.FailedRetryBackoffSeconds = 0 + case "MaxStoredResponseLen": c.MaxStoredResponseLen = 0 + case "CleanupIntervalSeconds": c.CleanupIntervalSeconds = 0 + case "CleanupBatchSize": c.CleanupBatchSize = 0 + } + err := validateIdempotency(&c) + assert.Error(t, err, "%s=0 should error", f) + }) + } +} + +// --- validateUsageCleanup --- + +func TestValidateUsageCleanup(t *testing.T) { + valid := UsageCleanupConfig{Enabled: true, MaxRangeDays: 31, BatchSize: 5000, WorkerIntervalSeconds: 10, TaskTimeoutSeconds: 1800} + assert.NoError(t, validateUsageCleanup(&valid)) + + t.Run("disabled with non-negative values passes", func(t *testing.T) { + uc := UsageCleanupConfig{Enabled: false, MaxRangeDays: 0, BatchSize: 0, WorkerIntervalSeconds: 0, TaskTimeoutSeconds: 0} + assert.NoError(t, validateUsageCleanup(&uc)) + }) + + t.Run("disabled with negative value fails", func(t *testing.T) { + uc := UsageCleanupConfig{Enabled: false, MaxRangeDays: -1} + err := validateUsageCleanup(&uc) + assert.Error(t, err) + assert.Contains(t, err.Error(), "non-negative") + }) + + t.Run("enabled max_range_days zero", func(t *testing.T) { + uc := UsageCleanupConfig{Enabled: true, MaxRangeDays: 0, BatchSize: 1, WorkerIntervalSeconds: 1, TaskTimeoutSeconds: 1} + err := validateUsageCleanup(&uc) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must be positive") + }) +} + +// --- validateOps --- + +func TestValidateOps(t *testing.T) { + valid := OpsConfig{Cleanup: OpsCleanupConfig{Enabled: true, Schedule: "0 2 * * *"}} + assert.NoError(t, validateOps(&valid)) + + t.Run("negative metrics cache TTL", func(t *testing.T) { + o := OpsConfig{MetricsCollectorCache: OpsMetricsCollectorCacheConfig{TTL: -1}} + err := validateOps(&o) + assert.Error(t, err) + assert.Contains(t, err.Error(), "non-negative") + }) + + t.Run("negative retention days", func(t *testing.T) { + o := OpsConfig{Cleanup: OpsCleanupConfig{ErrorLogRetentionDays: -1}} + err := validateOps(&o) + assert.Error(t, err) + assert.Contains(t, err.Error(), "non-negative") + }) + + t.Run("enabled cleanup without schedule", func(t *testing.T) { + o := OpsConfig{Cleanup: OpsCleanupConfig{Enabled: true, Schedule: ""}} + err := validateOps(&o) + assert.Error(t, err) + assert.Contains(t, err.Error(), "schedule is required") + }) +} + +// --- validateConcurrency --- + +func TestValidateConcurrency(t *testing.T) { + tests := []struct { + pingInterval int + wantErr bool + }{ + {5, false}, // min boundary + {10, false}, // normal + {30, false}, // max boundary + {4, true}, // below min + {31, true}, // above max + {0, true}, + {-1, true}, + } + for _, tc := range tests { + tc := tc + t.Run(fmt.Sprintf("ping_interval=%d", tc.pingInterval), func(t *testing.T) { + c := ConcurrencyConfig{PingInterval: tc.pingInterval} + err := validateConcurrency(&c) + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} + +// --- validateServerURL --- + +func TestValidateServerURL(t *testing.T) { + t.Run("empty URL passes", func(t *testing.T) { + assert.NoError(t, validateServerURL("")) + }) + + t.Run("valid https URL", func(t *testing.T) { + assert.NoError(t, validateServerURL("https://app.example.com")) + }) + + t.Run("valid http URL (with warning)", func(t *testing.T) { + assert.NoError(t, validateServerURL("http://localhost:3000")) + }) + + t.Run("URL with query fails", func(t *testing.T) { + err := validateServerURL("https://example.com?foo=bar") + assert.Error(t, err) + assert.Contains(t, err.Error(), "must not include query") + }) + + t.Run("URL with userinfo fails", func(t *testing.T) { + err := validateServerURL("https://user:pass@example.com") + assert.Error(t, err) + assert.Contains(t, err.Error(), "must not include userinfo") + }) + + t.Run("invalid scheme fails", func(t *testing.T) { + err := validateServerURL("ftp://example.com") + assert.Error(t, err) + }) +} + +// --- validateLinuxDo --- + +func TestValidateLinuxDo(t *testing.T) { + validCfg := LinuxDoConnectConfig{ + Enabled: true, ClientID: "id", AuthorizeURL: "https://a.com/auth", + TokenURL: "https://a.com/token", UserInfoURL: "https://a.com/user", + RedirectURL: "https://a.com/cb", FrontendRedirectURL: "/cb", ClientSecret: "secret", + } + + t.Run("disabled passes", func(t *testing.T) { + assert.NoError(t, validateLinuxDo(&LinuxDoConnectConfig{})) + }) + + t.Run("enabled valid passes", func(t *testing.T) { + assert.NoError(t, validateLinuxDo(&validCfg)) + }) + + t.Run("enabled missing client_id", func(t *testing.T) { + c := validCfg; c.ClientID = "" + err := validateLinuxDo(&c) + assert.Error(t, err) + assert.Contains(t, err.Error(), "client_id is required") + }) + + t.Run("enabled invalid authorize_url", func(t *testing.T) { + c := validCfg; c.AuthorizeURL = "not-a-url" + err := validateLinuxDo(&c) + assert.Error(t, err) + }) + + t.Run("token_auth_method none requires PKCE", func(t *testing.T) { + c := validCfg; c.TokenAuthMethod = "none"; c.UsePKCE = false; c.ClientSecret = "" + err := validateLinuxDo(&c) + assert.Error(t, err) + assert.Contains(t, err.Error(), "use_pkce must be true") + }) + + t.Run("client_secret_post requires client_secret", func(t *testing.T) { + c := validCfg; c.TokenAuthMethod = "client_secret_post"; c.ClientSecret = "" + err := validateLinuxDo(&c) + assert.Error(t, err) + assert.Contains(t, err.Error(), "client_secret is required") + }) + + t.Run("invalid token_auth_method", func(t *testing.T) { + c := validCfg; c.TokenAuthMethod = "bearer" + err := validateLinuxDo(&c) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token_auth_method must be") + }) +} + +// --- validateOIDC --- + +func TestValidateOIDC(t *testing.T) { + validOIDC := OIDCConnectConfig{ + Enabled: true, ClientID: "id", IssuerURL: "https://idp.example.com", + RedirectURL: "https://app.com/cb", FrontendRedirectURL: "/oidc/cb", + ClientSecret: "secret", Scopes: "openid email profile", + } + + t.Run("disabled passes", func(t *testing.T) { + assert.NoError(t, validateOIDC(&OIDCConnectConfig{})) + }) + + t.Run("enabled valid passes", func(t *testing.T) { + assert.NoError(t, validateOIDC(&validOIDC)) + }) + + t.Run("missing openid scope", func(t *testing.T) { + c := validOIDC; c.Scopes = "email profile" + err := validateOIDC(&c) + assert.Error(t, err) + assert.Contains(t, err.Error(), "must contain openid") + }) + + t.Run("clock_skew out of range", func(t *testing.T) { + c := validOIDC; c.ClockSkewSeconds = 700 + err := validateOIDC(&c) + assert.Error(t, err) + assert.Contains(t, err.Error(), "between 0-600") + }) + + t.Run("validate_id_token requires allowed_signing_algs", func(t *testing.T) { + c := validOIDC; c.ValidateIDToken = true; c.AllowedSigningAlgs = "" + err := validateOIDC(&c) + assert.Error(t, err) + assert.Contains(t, err.Error(), "allowed_signing_algs required") + }) + + t.Run("missing issuer_url", func(t *testing.T) { + c := validOIDC; c.IssuerURL = "" + err := validateOIDC(&c) + assert.Error(t, err) + assert.Contains(t, err.Error(), "issuer_url is required") + }) +} + +// --- validateDashboard & validateDashboardAgg --- + +func TestValidateDashboard(t *testing.T) { + validDash := DashboardCacheConfig{ + Enabled: true, StatsFreshTTLSeconds: 15, StatsTTLSeconds: 30, StatsRefreshTimeoutSeconds: 30, + } + validAgg := DashboardAggregationConfig{Enabled: true, IntervalSeconds: 60, LookbackSeconds: 120, + Retention: DashboardAggregationRetentionConfig{UsageLogsDays: 90, UsageBillingDedupDays: 365, HourlyDays: 180, DailyDays: 730}, + } + + t.Run("enabled dashboard valid", func(t *testing.T) { + assert.NoError(t, validateDashboard(&validDash, &validAgg)) + }) + + t.Run("stats_fresh_ttl > stats_ttl fails", func(t *testing.T) { + d := validDash; d.StatsFreshTTLSeconds = 100; d.StatsTTLSeconds = 50 + err := validateDashboard(&d, &validAgg) + assert.Error(t, err) + assert.Contains(t, err.Error(), "stats_fresh_ttl_seconds must be <=") + }) + + t.Run("disabled dashboard with negatives fails", func(t *testing.T) { + d := DashboardCacheConfig{Enabled: false, StatsFreshTTLSeconds: -1} + err := validateDashboard(&d, &DashboardAggregationConfig{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "non-negative") + }) + + t.Run("aggregation enabled valid", func(t *testing.T) { + assert.NoError(t, validateDashboardAgg(&validAgg)) + }) + + t.Run("aggregation interval zero when enabled", func(t *testing.T) { + a := validAgg; a.Enabled = true; a.IntervalSeconds = 0 + err := validateDashboardAgg(&a) + assert.Error(t, err) + assert.Contains(t, err.Error(), "interval_seconds must be positive") + }) + + t.Run("billing_dedup < usage_logs fails", func(t *testing.T) { + a := validAgg; a.Retention.UsageLogsDays = 365; a.Retention.UsageBillingDedupDays = 30 + err := validateDashboardAgg(&a) + assert.Error(t, err) + assert.Contains(t, err.Error(), "usage_billing_dedup_days >= usage_logs_days") + }) + + t.Run("backfill_enabled with backfill_max_days=0 fails", func(t *testing.T) { + a := validAgg; a.BackfillEnabled = true; a.BackfillMaxDays = 0 + err := validateDashboardAgg(&a) + assert.Error(t, err) + assert.Contains(t, err.Error(), "backfill_max_days must be positive") + }) +} + +// --- Config.Validate() orchestration test --- + +func TestConfigValidate_Orchestration(t *testing.T) { + t.Parallel() + + t.Run("fully valid config passes all validators", func(t *testing.T) { + t.Parallel() + cfg := buildValidConfig() + assert.NoError(t, cfg.Validate()) + }) + + t.Run("invalid JWT stops validation early", func(t *testing.T) { + t.Parallel() + cfg := buildValidConfig() + cfg.JWT.Secret = "short" + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "jwt.secret") + }) + + t.Run("invalid database stops after earlier checks pass", func(t *testing.T) { + t.Parallel() + cfg := buildValidConfig() + cfg.Database.MaxOpenConns = 0 + err := cfg.Validate() + assert.Error(t, err) + assert.Contains(t, err.Error(), "database.max_open_conns") + }) +} + +// Helper to build a fully valid Config for testing + +func buildValidConfig() Config { + return Config{ + Server: ServerConfig{Host: "0.0.0.0", Port: 8080, Mode: "release"}, + Log: LogConfig{Level: "info", Format: "json", StacktraceLevel: "error", + Output: LogOutputConfig{ToStdout: true}, + Rotation: LogRotationConfig{MaxSizeMB: 100}}, + Security: SecurityConfig{}, + Billing: BillingConfig{}, + Turnstile: TurnstileConfig{}, + Database: DatabaseConfig{Host: "localhost", Port: 5432, User: "u", + DBName: "db", SSLMode: "disable", MaxOpenConns: 256, MaxIdleConns: 128, + ConnMaxLifetimeMinutes: 30, ConnMaxIdleTimeMinutes: 5}, + Redis: RedisConfig{Host: "localhost", Port: 6379, + DialTimeoutSeconds: 5, ReadTimeoutSeconds: 3, WriteTimeoutSeconds: 3, PoolSize: 1024}, + JWT: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, RefreshTokenExpireDays: 30}, + LinuxDo: LinuxDoConnectConfig{}, + OIDC: OIDCConnectConfig{}, + Default: DefaultConfig{}, + RateLimit: RateLimitConfig{}, + Pricing: PricingConfig{}, + Gateway: GatewayConfig{}, + APIKeyAuth: APIKeyAuthCacheConfig{}, + SubscriptionCache: SubscriptionCacheConfig{}, + Dashboard: DashboardCacheConfig{Enabled: true, StatsFreshTTLSeconds: 15, StatsTTLSeconds: 30, StatsRefreshTimeoutSeconds: 30}, + DashboardAgg: DashboardAggregationConfig{Enabled: true, IntervalSeconds: 60, LookbackSeconds: 120, + Retention: DashboardAggregationRetentionConfig{UsageLogsDays: 90, UsageBillingDedupDays: 365, HourlyDays: 180, DailyDays: 730}}, + UsageCleanup: UsageCleanupConfig{Enabled: true, MaxRangeDays: 31, BatchSize: 5000, WorkerIntervalSeconds: 10, TaskTimeoutSeconds: 1800}, + Concurrency: ConcurrencyConfig{PingInterval: 10}, + Idempotency: IdempotencyConfig{ + DefaultTTLSeconds: 86400, SystemOperationTTLSeconds: 3600, ProcessingTimeoutSeconds: 30, + FailedRetryBackoffSeconds: 5, MaxStoredResponseLen: 65536, CleanupIntervalSeconds: 60, CleanupBatchSize: 500, + }, + } +} diff --git a/backend/internal/config/database.go b/backend/internal/config/database.go new file mode 100644 index 00000000..242000aa --- /dev/null +++ b/backend/internal/config/database.go @@ -0,0 +1,67 @@ +package config + +import "fmt" + +// DatabaseConfig 数据库连接配置 +// 性能优化:新增连接池参数,避免频繁创建/销毁连接 +type DatabaseConfig struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + User string `mapstructure:"user"` + Password string `mapstructure:"password"` + DBName string `mapstructure:"dbname"` + SSLMode string `mapstructure:"sslmode"` + // 连接池配置(性能优化:可配置化连接池参数) + MaxOpenConns int `mapstructure:"max_open_conns"` // 最大打开连接数,控制数据库连接上限 + MaxIdleConns int `mapstructure:"max_idle_conns"` // 最大空闲连接数,保持热连接减少建连延迟 + ConnMaxLifetimeMinutes int `mapstructure:"conn_max_lifetime_minutes"` // 连接最大存活时间 + ConnMaxIdleTimeMinutes int `mapstructure:"conn_max_idle_time_minutes"` // 空闲连接最大存活时间 +} + +func (d *DatabaseConfig) DSN() string { + if d.Password == "" { + return fmt.Sprintf( + "host=%s port=%d user=%s dbname=%s sslmode=%s", + d.Host, d.Port, d.User, d.DBName, d.SSLMode, + ) + } + return fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s", + d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode, + ) +} + +// DSNWithTimezone returns DSN with timezone setting +func (d *DatabaseConfig) DSNWithTimezone(tz string) string { + if tz == "" { + tz = "Asia/Shanghai" + } + if d.Password == "" { + return fmt.Sprintf( + "host=%s port=%d user=%s dbname=%s sslmode=%s TimeZone=%s", + d.Host, d.Port, d.User, d.DBName, d.SSLMode, tz, + ) + } + return fmt.Sprintf( + "host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s", + d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode, tz, + ) +} + +// RedisConfig Redis 连接配置 +type RedisConfig struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + Password string `mapstructure:"password"` + DB int `mapstructure:"db"` + DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"` // 建立连接超时 + ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"` // 读取超时 + WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"` // 写入超时 + PoolSize int `mapstructure:"pool_size"` // 连接池大小 + MinIdleConns int `mapstructure:"min_idle_conns"` // 最小空闲连接数 + EnableTLS bool `mapstructure:"enable_tls"` // 是否启用 TLS/SSL 连接 +} + +func (r *RedisConfig) Address() string { + return fmt.Sprintf("%s:%d", r.Host, r.Port) +} diff --git a/backend/internal/config/gateway.go b/backend/internal/config/gateway.go new file mode 100644 index 00000000..aeb55e02 --- /dev/null +++ b/backend/internal/config/gateway.go @@ -0,0 +1,105 @@ +package config + +import "time" + +// GatewayConfig API网关相关配置 +type GatewayConfig struct { + ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` // 等待上游响应头超时(秒),0=无超时 + MaxBodySize int64 `mapstructure:"max_body_size"` // 请求体最大字节数 + UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"` // 非流式上游响应体读取上限 + ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"` // 代理探测响应读取上限 + GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"` + ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` // proxy/account/account_proxy + ForceCodexCLI bool `mapstructure:"force_codex_cli"` + ForcedCodexInstructionsTemplateFile string `mapstructure:"forced_codex_instructions_template_file"` + ForcedCodexInstructionsTemplate string `mapstructure:"-"` + OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"` + + OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"` + + // HTTP 上游连接池配置 + MaxIdleConns int `mapstructure:"max_idle_conns"` + MaxIdleConnsPerHost int `mapstructure:"max_idle_conns_per_host"` + MaxConnsPerHost int `mapstructure:"max_conns_per_host"` + IdleConnTimeoutSeconds int `mapstructure:"idle_conn_timeout_seconds"` + MaxUpstreamClients int `mapstructure:"max_upstream_clients"` + ClientIdleTTLSeconds int `mapstructure:"client_idle_ttl_seconds"` + ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"` + SessionIdleTimeoutMinutes int `mapstructure:"session_idle_timeout_seconds"` + + StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"` // 流数据间隔超时(秒),0=禁用 + StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"` // 流式 keepalive 间隔(秒),0=禁用 + MaxLineSize int `mapstructure:"max_line_size"` // SSE 单行最大字节数 + + LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"` + LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"` + InjectBetaForAPIKey bool `mapstructure:"inject_beta_for_apikey"` + FailoverOn400 bool `mapstructure:"failover_on_400"` + + // Sora 专用配置 + SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"` + SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"` + SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"` + SoraStreamMode string `mapstructure:"sora_stream_mode"` + SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"` + SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"` + SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"` + SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"` + + MaxAccountSwitches int `mapstructure:"max_account_switches"` + MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"` + AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"` + + Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"` + TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"` + UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"` + + UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"` + ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"` + + UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"` +} + +// UserMessageQueueConfig 用户消息串行队列配置 +type UserMessageQueueConfig struct { + Mode string `mapstructure:"mode"` // serialize/throttle/"" + Enabled bool `mapstructure:"enabled"` // 向后兼容 + LockTTLMs int `mapstructure:"lock_ttl_ms"` + WaitTimeoutMs int `mapstructure:"wait_timeout_ms"` + MinDelayMs int `mapstructure:"min_delay_ms"` + MaxDelayMs int `mapstructure:"max_delay_ms"` + CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"` +} + +func (c *UserMessageQueueConfig) WaitTimeout() time.Duration { + if c.WaitTimeoutMs <= 0 { return 30 * time.Second } + return time.Duration(c.WaitTimeoutMs) * time.Millisecond +} + +func (c *UserMessageQueueConfig) GetEffectiveMode() string { + if c.Mode == UMQModeSerialize || c.Mode == UMQModeThrottle { return c.Mode } + if c.Enabled { return UMQModeSerialize } + return "" +} + +// GatewaySchedulingConfig 账号调度相关配置 +type GatewaySchedulingConfig struct { + StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"` + StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"` + FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"` + FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"` + FallbackSelectionMode string `mapstructure:"fallback_selection_mode"` + LoadBatchEnabled bool `mapstructure:"load_batch_enabled"` + SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"` + SnapshotWriteChunkSize int `mapstructure:"snapshot_write_chunk_size"` + SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"` + DbFallbackEnabled bool `mapstructure:"db_fallback_enabled"` + DbFallbackTimeoutSeconds int `mapstructure:"db_fallback_timeout_seconds"` + DbFallbackMaxQPS int `mapstructure:"db_fallback_max_qps"` + OutboxPollIntervalSeconds int `mapstructure:"outbox_poll_interval_seconds"` + OutboxLagWarnSeconds int `mapstructure:"outbox_lag_warn_seconds"` + OutboxLagRebuildSeconds int `mapstructure:"outbox_lag_rebuild_seconds"` + OutboxLagRebuildFailures int `mapstructure:"outbox_lag_rebuild_failures"` + OutboxBacklogRebuildRows int `mapstructure:"outbox_backlog_rebuild_rows"` + FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"` +} diff --git a/backend/internal/config/gateway_sub.go b/backend/internal/config/gateway_sub.go new file mode 100644 index 00000000..95745376 --- /dev/null +++ b/backend/internal/config/gateway_sub.go @@ -0,0 +1,98 @@ +package config + +// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置 +type GatewayOpenAIWSConfig struct { + ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"` + IngressModeDefault string `mapstructure:"ingress_mode_default"` + Enabled bool `mapstructure:"enabled"` + OAuthEnabled bool `mapstructure:"oauth_enabled"` + APIKeyEnabled bool `mapstructure:"apikey_enabled"` + ForceHTTP bool `mapstructure:"force_http"` + AllowStoreRecovery bool `mapstructure:"allow_store_recovery"` + IngressPreviousResponseRecoveryEnabled bool `mapstructure:"ingress_previous_response_recovery_enabled"` + StoreDisabledConnMode string `mapstructure:"store_disabled_conn_mode"` + StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"` + PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"` + + ResponsesWebsockets bool `mapstructure:"responses_websockets"` + ResponsesWebsocketsV2 bool `mapstructure:"responses_websockets_v2"` + + MaxConnsPerAccount int `mapstructure:"max_conns_per_account"` + MinIdlePerAccount int `mapstructure:"min_idle_per_account"` + MaxIdlePerAccount int `mapstructure:"max_idle_per_account"` + DynamicMaxConnsByAccountConcurrencyEnabled bool `mapstructure:"dynamic_max_conns_by_account_concurrency_enabled"` + OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"` + APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"` + DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"` + ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"` + WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"` + PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"` + QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"` + EventFlushBatchSize int `mapstructure:"event_flush_batch_size"` + EventFlushIntervalMS int `mapstructure:"event_flush_interval_ms"` + PrewarmCooldownMS int `mapstructure:"prewarm_cooldown_ms"` + FallbackCooldownSeconds int `mapstructure:"fallback_cooldown_seconds"` + RetryBackoffInitialMS int `mapstructure:"retry_backoff_initial_ms"` + RetryBackoffMaxMS int `mapstructure:"retry_backoff_max_ms"` + RetryJitterRatio float64 `mapstructure:"retry_jitter_ratio"` + RetryTotalBudgetMS int `mapstructure:"retry_total_budget_ms"` + PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"` + + LBTopK int `mapstructure:"lb_top_k"` + StickySessionTTLSeconds int `mapstructure:"sticky_session_ttl_seconds"` + SessionHashReadOldFallback bool `mapstructure:"session_hash_read_old_fallback"` + SessionHashDualWriteOld bool `mapstructure:"session_hash_dual_write_old"` + MetadataBridgeEnabled bool `mapstructure:"metadata_bridge_enabled"` + StickyResponseIDTTLSeconds int `mapstructure:"sticky_response_id_ttl_seconds"` + StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"` + + SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"` +} + +// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重 +type GatewayOpenAIWSSchedulerScoreWeights struct { + Priority float64 `mapstructure:"priority"` + Load float64 `mapstructure:"load"` + Queue float64 `mapstructure:"queue"` + ErrorRate float64 `mapstructure:"error_rate"` + TTFT float64 `mapstructure:"ttft"` +} + +// GatewayUsageRecordConfig 使用量记录异步队列配置 +type GatewayUsageRecordConfig struct { + WorkerCount int `mapstructure:"worker_count"` + QueueSize int `mapstructure:"queue_size"` + TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"` + OverflowPolicy string `mapstructure:"overflow_policy"` // drop/sample/sync + OverflowSamplePercent int `mapstructure:"overflow_sample_percent"` // 1-100 + AutoScaleEnabled bool `mapstructure:"auto_scale_enabled"` + AutoScaleMinWorkers int `mapstructure:"auto_scale_min_workers"` + AutoScaleMaxWorkers int `mapstructure:"auto_scale_max_workers"` + AutoScaleUpQueuePercent int `mapstructure:"auto_scale_up_queue_percent"` + AutoScaleDownQueuePercent int `mapstructure:"auto_scale_down_queue_percent"` + AutoScaleUpStep int `mapstructure:"auto_scale_up_step"` + AutoScaleDownStep int `mapstructure:"auto_scale_down_step"` + AutoScaleCheckIntervalSeconds int `mapstructure:"auto_scale_check_interval_seconds"` + AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"` +} + +// TLSFingerprintConfig TLS 指纹伪装配置 +type TLSFingerprintConfig struct { + Enabled bool `mapstructure:"enabled"` + Profiles map[string]TLSProfileConfig `mapstructure:"profiles"` +} + +// TLSProfileConfig 单个 TLS 指纹模板配置 +type TLSProfileConfig struct { + Name string `mapstructure:"name"` + EnableGREASE bool `mapstructure:"enable_grease"` + CipherSuites []uint16 `mapstructure:"cipher_suites"` + Curves []uint16 `mapstructure:"curves"` + PointFormats []uint16 `mapstructure:"point_formats"` + SignatureAlgorithms []uint16 `mapstructure:"signature_algorithms"` + ALPNProtocols []string `mapstructure:"alpn_protocols"` + SupportedVersions []uint16 `mapstructure:"supported_versions"` + KeyShareGroups []uint16 `mapstructure:"key_share_groups"` + PSKModes []uint16 `mapstructure:"psk_modes"` + Extensions []uint16 `mapstructure:"extensions"` +} diff --git a/backend/internal/config/ops_and_cache.go b/backend/internal/config/ops_and_cache.go new file mode 100644 index 00000000..7f3c0973 --- /dev/null +++ b/backend/internal/config/ops_and_cache.go @@ -0,0 +1,122 @@ +package config + +import "time" + +// LogConfig 日志配置 +type LogConfig struct { + Level string `mapstructure:"level"` + Format string `mapstructure:"format"` + ServiceName string `mapstructure:"service_name"` + Environment string `mapstructure:"env"` + Caller bool `mapstructure:"caller"` + StacktraceLevel string `mapstructure:"stacktrace_level"` + Output LogOutputConfig `mapstructure:"output"` + Rotation LogRotationConfig `mapstructure:"rotation"` + Sampling LogSamplingConfig `mapstructure:"sampling"` +} + +type LogOutputConfig struct { + ToStdout bool `mapstructure:"to_stdout"` + ToFile bool `mapstructure:"to_file"` + FilePath string `mapstructure:"file_path"` +} + +type LogRotationConfig struct { + MaxSizeMB int `mapstructure:"max_size_mb"` + MaxBackups int `mapstructure:"max_backups"` + MaxAgeDays int `mapstructure:"max_age_days"` + Compress bool `mapstructure:"compress"` + LocalTime bool `mapstructure:"local_time"` +} + +type LogSamplingConfig struct { + Enabled bool `mapstructure:"enabled"` + Initial int `mapstructure:"initial"` + Thereafter int `mapstructure:"thereafter"` +} + +// OpsConfig 运维监控配置 +type OpsConfig struct { + Enabled bool `mapstructure:"enabled"` + UsePreaggregatedTables bool `mapstructure:"use_preaggregated_tables"` + Cleanup OpsCleanupConfig `mapstructure:"cleanup"` + MetricsCollectorCache OpsMetricsCollectorCacheConfig `mapstructure:"metrics_collector_cache"` + Aggregation OpsAggregationConfig `mapstructure:"aggregation"` +} + +type OpsCleanupConfig struct { + Enabled bool `mapstructure:"enabled"` + Schedule string `mapstructure:"schedule"` + ErrorLogRetentionDays int `mapstructure:"error_log_retention_days"` + MinuteMetricsRetentionDays int `mapstructure:"minute_metrics_retention_days"` + HourlyMetricsRetentionDays int `mapstructure:"hourly_metrics_retention_days"` +} + +type OpsAggregationConfig struct { + Enabled bool `mapstructure:"enabled"` +} + +type OpsMetricsCollectorCacheConfig struct { + Enabled bool `mapstructure:"enabled"` + TTL time.Duration `mapstructure:"ttl"` +} + +// APIKeyAuthCacheConfig API Key 认证缓存配置 +type APIKeyAuthCacheConfig struct { + L1Size int `mapstructure:"l1_size"` + L1TTLSeconds int `mapstructure:"l1_ttl_seconds"` + L2TTLSeconds int `mapstructure:"l2_ttl_seconds"` + NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"` + JitterPercent int `mapstructure:"jitter_percent"` + Singleflight bool `mapstructure:"singleflight"` +} + +// SubscriptionCacheConfig 订阅认证 L1 缓存配置 +type SubscriptionCacheConfig struct { + L1Size int `mapstructure:"l1_size"` + L1TTLSeconds int `mapstructure:"l1_ttl_seconds"` + JitterPercent int `mapstructure:"jitter_percent"` +} + +// SubscriptionMaintenanceConfig 订阅窗口维护后台任务配置 +type SubscriptionMaintenanceConfig struct { + WorkerCount int `mapstructure:"worker_count"` + QueueSize int `mapstructure:"queue_size"` +} + +// DashboardCacheConfig 仪表盘统计缓存配置 +type DashboardCacheConfig struct { + Enabled bool `mapstructure:"enabled"` + KeyPrefix string `mapstructure:"key_prefix"` + StatsFreshTTLSeconds int `mapstructure:"stats_fresh_ttl_seconds"` + StatsTTLSeconds int `mapstructure:"stats_ttl_seconds"` + StatsRefreshTimeoutSeconds int `mapstructure:"stats_refresh_timeout_seconds"` +} + +// DashboardAggregationConfig 仪表盘预聚合配置 +type DashboardAggregationConfig struct { + Enabled bool `mapstructure:"enabled"` + IntervalSeconds int `mapstructure:"interval_seconds"` + LookbackSeconds int `mapstructure:"lookback_seconds"` + BackfillEnabled bool `mapstructure:"backfill_enabled"` + BackfillMaxDays int `mapstructure:"backfill_max_days"` + Retention DashboardAggregationRetentionConfig `mapstructure:"retention"` + RecomputeDays int `mapstructure:"recompute_days"` +} + +// DashboardAggregationRetentionConfig 预聚合保留窗口 +type DashboardAggregationRetentionConfig struct { + UsageLogsDays int `mapstructure:"usage_logs_days"` + UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"` + HourlyDays int `mapstructure:"hourly_days"` + DailyDays int `mapstructure:"daily_days"` +} + +// UsageCleanupConfig 使用记录清理任务配置 +type UsageCleanupConfig struct { + Enabled bool `mapstructure:"enabled"` + MaxRangeDays int `mapstructure:"max_range_days"` + BatchSize int `mapstructure:"batch_size"` + WorkerIntervalSeconds int `mapstructure:"worker_interval_seconds"` + TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"` +} diff --git a/backend/internal/config/platforms.go b/backend/internal/config/platforms.go new file mode 100644 index 00000000..2223dca5 --- /dev/null +++ b/backend/internal/config/platforms.go @@ -0,0 +1,152 @@ +package config + +// SoraConfig 直连 Sora 配置 +type SoraConfig struct { + Client SoraClientConfig `mapstructure:"client"` + Storage SoraStorageConfig `mapstructure:"storage"` +} + +// SoraClientConfig Sora 客户端配置 +type SoraClientConfig struct { + BaseURL string `mapstructure:"base_url"` + TimeoutSeconds int `mapstructure:"timeout_seconds"` + MaxRetries int `mapstructure:"max_retries"` + CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"` + PollIntervalSeconds int `mapstructure:"poll_interval_seconds"` + MaxPollAttempts int `mapstructure:"max_poll_attempts"` + RecentTaskLimit int `mapstructure:"recent_task_limit"` + RecentTaskLimitMax int `mapstructure="recent_task_limit_max"` + Debug bool `mapstructure:"debug"` + UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"` + Headers map[string]string `mapstructure:"headers"` + UserAgent string `mapstructure:"user_agent"` + DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"` + CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"` +} + +// SoraCurlCFFISidecarConfig Sora curl_cffi sidecar 配置 +type SoraCurlCFFISidecarConfig struct { + Enabled bool `mapstructure:"enabled"` + BaseURL string `mapstructure:"base_url"` + Impersonate string `mapstructure:"impersonate"` + TimeoutSeconds int `mapstructure:"timeout_seconds"` + SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"` + SessionTTLSeconds int `mapstructure:"session_ttl_seconds"` +} + +// SoraStorageConfig 媒体存储配置 +type SoraStorageConfig struct { + Type string `mapstructure:"type"` + LocalPath string `mapstructure:"local_path"` + FallbackToUpstream bool `mapstructure:"fallback_to_upstream"` + MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"` + DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"` + MaxDownloadBytes int64 `mapstructure:"max_download_bytes"` + Debug bool `mapstructure:"debug"` + Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"` +} + +// SoraStorageCleanupConfig 媒体清理配置 +type SoraStorageCleanupConfig struct { + Enabled bool `mapstructure:"enabled"` + Schedule string `mapstructure:"schedule"` + RetentionDays int `mapstructure:"retention_days"` +} + +// SoraModelFiltersConfig Sora 模型过滤配置 +type SoraModelFiltersConfig struct { + HidePromptEnhance bool `mapstructure:"hide_prompt_envelope"` +} + +// GeminiConfig Gemini 配置 +type GeminiConfig struct { + OAuth GeminiOAuthConfig `mapstructure:"oauth"` + Quota GeminiQuotaConfig `mapstructure:"quota"` +} + +type GeminiOAuthConfig struct { + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + Scopes string `mapstructure:"scopes"` +} + +type GeminiQuotaConfig struct { + Tiers map[string]GeminiTierQuotaConfig `mapstructure:"tiers"` + Policy string `mapstructure:"policy"` +} + +type GeminiTierQuotaConfig struct { + ProRPD *int64 `mapstructure:"pro_rpd" json:"pro_rpd"` + FlashRPD *int64 `mapstructure:"flash_rpd" json:"flash_rpd"` + CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"` +} + +// UpdateConfig 更新检查配置 +type UpdateConfig struct { + ProxyURL string `mapstructure:"proxy_url"` // 访问 GitHub 的代理地址 +} + +// IdempotencyConfig 幂等性配置 +type IdempotencyConfig struct { + ObserveOnly bool `mapstructure:"observe_only"` + DefaultTTLSeconds int `mapstructure:"default_ttl_seconds"` + SystemOperationTTLSeconds int `mapstructure:"system_operation_ttl_seconds"` + ProcessingTimeoutSeconds int `mapstructure:"processing_timeout_seconds"` + FailedRetryBackoffSeconds int `mapstructure:"failed_retry_backoff_seconds"` + MaxStoredResponseLen int `mapstructure:"max_stored_response_len"` + CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"` + CleanupBatchSize int `mapstructure:"cleanup_batch_size"` +} + +// LinuxDoConnectConfig LinuxDo 连接配置 +type LinuxDoConnectConfig struct { + Enabled bool `mapstructure:"enabled"` + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + AuthorizeURL string `mapstructure:"authorize_url"` + TokenURL string `mapstructure:"token_url"` + UserInfoURL string `mapstructure:"userinfo_url"` + Scopes string `mapstructure:"scopes"` + RedirectURL string `mapstructure:"redirect_url"` + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` + TokenAuthMethod string `mapstructure:"token_auth_method"` + UsePKCE bool `mapstructure:"use_pkce"` + UserInfoEmailPath string `mapstructure:"userinfo_email_path"` + UserInfoIDPath string `mapstructure:"userinfo_id_path"` + UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` +} + +// OIDCConnectConfig OIDC 连接配置 +type OIDCConnectConfig struct { + Enabled bool `mapstructure:"enabled"` + ProviderName string `mapstructure:"provider_name"` + ClientID string `mapstructure:"client_id"` + ClientSecret string `mapstructure:"client_secret"` + IssuerURL string `mapstructure:"issuer_url"` + DiscoveryURL string `mapstructure:"discovery_url"` + AuthorizeURL string `mapstructure:"authorize_url"` + TokenURL string `mapstructure:"token_url"` + UserInfoURL string `mapstructure:"userinfo_url"` + JWKSURL string `mapstructure:"jwks_url"` + Scopes string `mapstructure:"scopes"` + RedirectURL string `mapstructure:"redirect_url"` + FrontendRedirectURL string `mapstructure:"frontend_redirect_url"` + TokenAuthMethod string `mapstructure:"token_auth_method"` + UsePKCE bool `mapstructure:"use_pkce"` + ValidateIDToken bool `mapstructure:"validate_id_token"` + AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"` + ClockSkewSeconds int `mapstructure:"clock_skew_seconds"` + RequireEmailVerified bool `mapstructure:"require_email_verified"` + UserInfoEmailPath string `mapstructure:"userinfo_email_path"` + UserInfoIDPath string `mapstructure:"userinfo_id_path"` + UserInfoUsernamePath string `mapstructure:"userinfo_username_path"` +} + +// TokenRefreshConfig OAuth Token 自动刷新配置 +type TokenRefreshConfig struct { + Enabled bool `mapstructure:"enabled"` + CheckIntervalMinutes int `mapstructure:"check_interval_minutes"` + RefreshBeforeExpiryHours float64 `mapstructure:"refresh_before_expiry_hours"` + MaxRetries int `mapstructure:"max_retries"` + RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"` +} diff --git a/backend/internal/config/security.go b/backend/internal/config/security.go new file mode 100644 index 00000000..25704f5c --- /dev/null +++ b/backend/internal/config/security.go @@ -0,0 +1,48 @@ +package config + +// SecurityConfig 安全相关配置 +type SecurityConfig struct { + URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"` + ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"` + CSP CSPConfig `mapstructure:"csp"` + ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"` + ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"` +} + +// URLAllowlistConfig URL 白名单配置 +type URLAllowlistConfig struct { + Enabled bool `mapstructure:"enabled"` + UpstreamHosts []string `mapstructure:"upstream_hosts"` + PricingHosts []string `mapstructure:"pricing_hosts"` + CRSHosts []string `mapstructure:"crs_hosts"` + AllowPrivateHosts bool `mapstructure:"allow_private_hosts"` + // 关闭 URL 白名单校验时,是否允许 http URL(默认只允许 https) + AllowInsecureHTTP bool `mapstructure:"allow_insecure_http"` +} + +// ResponseHeaderConfig 安全响应头配置 +type ResponseHeaderConfig struct { + Enabled bool `mapstructure:"enabled"` + AdditionalAllowed []string `mapstructure:"additional_allowed"` + ForceRemove []string `mapstructure:"force_remove"` +} + +// CSPConfig Content-Security-Policy 配置 +type CSPConfig struct { + Enabled bool `mapstructure:"enabled"` + Policy string `mapstructure:"policy"` +} + +// ProxyFallbackConfig 代理回退配置 +type ProxyFallbackConfig struct { + // AllowDirectOnError 当辅助服务的代理初始化失败时是否允许回退直连。 + // 仅影响非 AI 账号连接的辅助服务(GitHub Release 更新检查、定价数据拉取)。 + // 不影响 AI 账号网关连接,这些关键路径的代理失败始终返回错误。 + // 默认 false:避免因代理配置错误导致服务器真实 IP 泄露。 + AllowDirectOnError bool `mapstructure:"allow_direct_on_error"` +} + +// ProxyProbeConfig 代理探测配置 +type ProxyProbeConfig struct { + InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证 +} diff --git a/backend/internal/config/server.go b/backend/internal/config/server.go new file mode 100644 index 00000000..f16e7340 --- /dev/null +++ b/backend/internal/config/server.go @@ -0,0 +1,42 @@ +package config + +import "fmt" + +// ServerConfig HTTP 服务端配置 +type ServerConfig struct { + Host string `mapstructure:"host"` + Port int `mapstructure:"port"` + Mode string `mapstructure:"mode"` // debug/release + FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL,用于生成邮件中的外部链接 + ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒) + IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒) + TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表(CIDR/IP) + MaxRequestBodySize int64 `mapstructure:"max_request_body_size"` // 全局最大请求体限制 + H2C H2CConfig `mapstructure:"h2c"` // HTTP/2 Cleartext 配置 +} + +// H2CConfig HTTP/2 Cleartext 配置 +type H2CConfig struct { + Enabled bool `mapstructure:"enabled"` // 是否启用 H2C + MaxConcurrentStreams uint32 `mapstructure:"max_concurrent_streams"` // 最大并发流数量 + IdleTimeout int `mapstructure:"idle_timeout"` // 空闲超时(秒) + MaxReadFrameSize int `mapstructure:"max_read_frame_size"` // 最大帧大小(字节) + MaxUploadBufferPerConnection int `mapstructure:"max_upload_buffer_per_connection"` // 每个连接的上传缓冲区(字节) + MaxUploadBufferPerStream int `mapstructure:"max_upload_buffer_per_stream"` // 每个流的上传缓冲区(字节) +} + +// CORSConfig 跨域配置 +type CORSConfig struct { + AllowedOrigins []string `mapstructure:"allowed_origins"` + AllowCredentials bool `mapstructure:"allow_credentials"` +} + +// ConcurrencyConfig 并发配置 +type ConcurrencyConfig struct { + // PingInterval: 并发等待期间的 SSE ping 间隔(秒) + PingInterval int `mapstructure:"ping_interval"` +} + +func (s ServerConfig) Address() string { + return fmt.Sprintf("%s:%d", s.Host, s.Port) +}