refactor(config): split config.go into modular files

Split the monolithic config.go (~120KB) into focused modules:
- auth.go: JWT, TOTP, Turnstile, RateLimit configs
- billing.go: Billing and Pricing configs
- database.go: Database and Redis configs
- gateway.go: Gateway and Upstream configs
- gateway_sub.go: Gateway sub-configurations
- ops_and_cache.go: Ops and Cache configs
- platforms.go: Platform-specific configs
- security.go: Security-related configs
- server.go: Server configuration
- config_defaults.go: Default values
- config_defaults_detail.go: Detailed defaults
- config_helpers.go: Helper functions
- config_validate.go: Validation logic
- config_validate_gateway.go: Gateway validation

This improves:
- Code maintainability and readability
- Faster compilation (smaller files)
- Easier navigation and debugging
- Better separation of concerns
This commit is contained in:
User
2026-04-17 07:22:55 +08:00
parent e34a59d720
commit a4eb4d4c3a
19 changed files with 3359 additions and 2327 deletions

View File

@@ -0,0 +1,42 @@
package config
// JWTConfig JWT 认证配置
type JWTConfig struct {
Secret string `mapstructure:"secret"`
ExpireHour int `mapstructure:"expire_hour"`
// AccessTokenExpireMinutes: Access Token有效期分钟
// - >0: 使用分钟配置(优先级高于 ExpireHour
// - =0: 回退使用 ExpireHour向后兼容旧配置
AccessTokenExpireMinutes int `mapstructure:"access_token_expire_minutes"`
// RefreshTokenExpireDays: Refresh Token有效期默认30天
RefreshTokenExpireDays int `mapstructure:"refresh_token_expire_days"`
// RefreshWindowMinutes: 刷新窗口分钟在Access Token过期前多久开始允许刷新
RefreshWindowMinutes int `mapstructure:"refresh_window_minutes"`
}
// TotpConfig TOTP 双因素认证配置
type TotpConfig struct {
EncryptionKey string `mapstructure:"encryption_key"` // AES-256 密钥32 字节 hex 编码)
EncryptionKeyConfigured bool `mapstructure:"-"` // 是否手动配置(非自动生成)
}
// TurnstileConfig Cloudflare Turnstile 验证配置
type TurnstileConfig struct {
Required bool `mapstructure:"required"`
}
// DefaultConfig 默认值配置
type DefaultConfig struct {
AdminEmail string `mapstructure:"admin_email"`
AdminPassword string `mapstructure:"admin_password"`
UserConcurrency int `mapstructure:"user_concurrency"`
UserBalance float64 `mapstructure:"user_balance"`
APIKeyPrefix string `mapstructure:"api_key_prefix"`
RateMultiplier float64 `mapstructure:"rate_multiplier"`
}
// RateLimitConfig 限流配置
type RateLimitConfig struct {
OverloadCooldownMinutes int `mapstructure:"overload_cooldown_minutes"`
OAuth401CooldownMinutes int `mapstructure:"oauth_401_cooldown_minutes"`
}

View File

@@ -0,0 +1,23 @@
package config
// BillingConfig 计费配置
type BillingConfig struct {
CircuitBreaker CircuitBreakerConfig `mapstructure:"circuit_breaker"`
}
type CircuitBreakerConfig struct {
Enabled bool `mapstructure:"enabled"`
FailureThreshold int `mapstructure:"failure_threshold"`
ResetTimeoutSeconds int `mapstructure:"reset_timeout_seconds"`
HalfOpenRequests int `mapstructure:"half_open_requests"`
}
// PricingConfig 定价数据配置
type PricingConfig struct {
RemoteURL string `mapstructure:"remote_url"` // 远程 URL默认 LiteLLM 镜像)
HashURL string `mapstructure:"hash_url"` // 哈希校验文件 URL
DataDir string `mapstructure:"data_dir"` // 本地数据目录
FallbackFile string `mapstructure:"fallback_file"` // 回退文件路径
UpdateIntervalHours int `mapstructure:"update_interval_hours"` // 更新间隔(小时)
HashCheckIntervalMinutes int `mapstructure:"hash_check_interval_minutes"` // 哈希校验间隔(分钟)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,213 @@
package config
import (
"time"
"github.com/spf13/viper"
)
// setDefaults sets all default values using viper.
func setDefaults() {
viper.SetDefault("run_mode", RunModeStandard)
// Server
viper.SetDefault("server.host", "0.0.0.0")
viper.SetDefault("server.port", 8080)
viper.SetDefault("server.mode", "release")
viper.SetDefault("server.frontend_url", "")
viper.SetDefault("server.read_header_timeout", 30)
viper.SetDefault("server.idle_timeout", 120)
viper.SetDefault("server.trusted_proxies", []string{})
viper.SetDefault("server.max_request_body_size", int64(256*1024*1024))
// H2C
viper.SetDefault("server.h2c.enabled", false)
viper.SetDefault("server.h2c.max_concurrent_streams", uint32(50))
viper.SetDefault("server.h2c.idle_timeout", 75)
viper.SetDefault("server.h2c.max_read_frame_size", 1<<20)
viper.SetDefault("server.h2c.max_upload_buffer_per_connection", 2<<20)
viper.SetDefault("server.h2c.max_upload_buffer_per_stream", 512<<10)
// Log
viper.SetDefault("log.level", "info")
viper.SetDefault("log.format", "console")
viper.SetDefault("log.service_name", "sub2api")
viper.SetDefault("log.env", "production")
viper.SetDefault("log.caller", true)
viper.SetDefault("log.stacktrace_level", "error")
viper.SetDefault("log.output.to_stdout", true)
viper.SetDefault("log.output.to_file", true)
viper.SetDefault("log.output.file_path", "")
viper.SetDefault("log.rotation.max_size_mb", 100)
viper.SetDefault("log.rotation.max_backups", 10)
viper.SetDefault("log.rotation.max_age_days", 7)
viper.SetDefault("log.rotation.compress", true)
viper.SetDefault("log.rotation.local_time", true)
viper.SetDefault("log.sampling.enabled", false)
viper.SetDefault("log.sampling.initial", 100)
viper.SetDefault("log.sampling.thereafter", 100)
// CORS
viper.SetDefault("cors.allowed_origins", []string{})
viper.SetDefault("cors.allow_credentials", true)
// Security
setSecurityDefaults()
// Billing
viper.SetDefault("billing.circuit_breaker.enabled", true)
viper.SetDefault("billing.circuit_breaker.failure_threshold", 5)
viper.SetDefault("billing.circuit_breaker.reset_timeout_seconds", 30)
viper.SetDefault("billing.circuit_breaker.half_open_requests", 3)
viper.SetDefault("turnstile.required", false)
// LinuxDo Connect OAuth
setLinuxDoDefaults()
// Generic OIDC OAuth
setOIDCDefaults()
// Database
viper.SetDefault("database.host", "localhost")
viper.SetDefault("database.port", 5432)
viper.SetDefault("database.user", "postgres")
viper.SetDefault("database.password", "postgres")
viper.SetDefault("database.dbname", "sub2api")
viper.SetDefault("database.sslmode", "prefer")
viper.SetDefault("database.max_open_conns", 256)
viper.SetDefault("database.max_idle_conns", 128)
viper.SetDefault("database.conn_max_lifetime_minutes", 30)
viper.SetDefault("database.conn_max_idle_time_minutes", 5)
// Redis
viper.SetDefault("redis.host", "localhost")
viper.SetDefault("redis.port", 6379)
viper.SetDefault("redis.password", "")
viper.SetDefault("redis.db", 0)
viper.SetDefault("redis.dial_timeout_seconds", 5)
viper.SetDefault("redis.read_timeout_seconds", 3)
viper.SetDefault("redis.write_timeout_seconds", 3)
viper.SetDefault("redis.pool_size", 1024)
viper.SetDefault("redis.min_idle_conns", 128)
viper.SetDefault("redis.enable_tls", false)
// Ops (vNext)
viper.SetDefault("ops.enabled", true)
viper.SetDefault("ops.use_preaggregated_tables", true)
viper.SetDefault("ops.cleanup.enabled", true)
viper.SetDefault("ops.cleanup.schedule", "0 2 * * *")
viper.SetDefault("ops.cleanup.error_log_retention_days", 30)
viper.SetDefault("ops.cleanup.minute_metrics_retention_days", 30)
viper.SetDefault("ops.cleanup.hourly_metrics_retention_days", 30)
viper.SetDefault("ops.aggregation.enabled", true)
viper.SetDefault("ops.metrics_collector_cache.enabled", true)
viper.SetDefault("ops.metrics_collector_cache.ttl", 65*time.Second)
// JWT
viper.SetDefault("jwt.secret", "")
viper.SetDefault("jwt.expire_hour", 24)
viper.SetDefault("jwt.access_token_expire_minutes", 0)
viper.SetDefault("jwt.refresh_token_expire_days", 30)
viper.SetDefault("jwt.refresh_window_minutes", 2)
// TOTP
viper.SetDefault("totp.encryption_key", "")
// Default
viper.SetDefault("default.admin_email", "")
viper.SetDefault("default.admin_password", "")
viper.SetDefault("default.user_concurrency", 5)
viper.SetDefault("default.user_balance", 0)
viper.SetDefault("default.api_key_prefix", "sk-")
viper.SetDefault("default.rate_multiplier", 1.0)
// RateLimit
viper.SetDefault("rate_limit.overload_cooldown_minutes", 10)
viper.SetDefault("rate_limit.oauth_401_cooldown_minutes", 10)
// Pricing
viper.SetDefault("pricing.remote_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.json")
viper.SetDefault("pricing.hash_url", "https://raw.githubusercontent.com/Wei-Shaw/model-price-repo/main/model_prices_and_context_window.sha256")
viper.SetDefault("pricing.data_dir", "./data")
viper.SetDefault("pricing.fallback_file", "./resources/model-pricing/model_prices_and_context_window.json")
viper.SetDefault("pricing.update_interval_hours", 24)
viper.SetDefault("pricing.hash_check_interval_minutes", 10)
viper.SetDefault("timezone", "Asia/Shanghai")
// API Key auth cache
viper.SetDefault("api_key_auth_cache.l1_size", 65535)
viper.SetDefault("api_key_auth_cache.l1_ttl_seconds", 15)
viper.SetDefault("api_key_auth_cache.l2_ttl_seconds", 300)
viperSetDefault("api_key_auth_cache.negative_ttl_seconds", 30)
viper.SetDefault("api_key_auth_cache.jitter_percent", 10)
viper.SetDefault("api_key_auth_cache.singleflight", true)
// Subscription auth L1 cache
viper.SetDefault("subscription_cache.l1_size", 16384)
viper.SetDefault("subscription_cache.l1_ttl_seconds", 10)
viper.SetDefault("subscription_cache.jitter_percent", 10)
// Dashboard cache
viper.SetDefault("dashboard_cache.enabled", true)
viper.SetDefault("dashboard_cache.key_prefix", "sub2api:")
viper.SetDefault("dashboard_cache.stats_fresh_ttl_seconds", 15)
viper.SetDefault("dashboard_cache.stats_ttl_seconds", 30)
viper.SetDefault("dashboard_cache.stats_refresh_timeout_seconds", 30)
// Dashboard aggregation
viper.SetDefault("dashboard_aggregation.enabled", true)
viper.SetDefault("dashboard_aggregation.interval_seconds", 60)
viper.SetDefault("dashboard_aggregation.lookback_seconds", 120)
viper.SetDefault("dashboard_aggregation.backfill_enabled", false)
viper.SetDefault("dashboard_aggregation.backfill_max_days", 31)
viper.SetDefault("dashboard_aggregation.retention.usage_logs_days", 90)
viper.SetDefault("dashboard_aggregation.retention.usage_billing_dedup_days", 365)
viper.SetDefault("dashboard_aggregation.retention.hourly_days", 180)
viper.SetDefault("dashboard_aggregation.retention.daily_days", 730)
viper.SetDefault("dashboard_aggregation.recompute_days", 2)
// Usage cleanup task
viper.SetDefault("usage_cleanup.enabled", true)
viper.SetDefault("usage_cleanup.max_range_days", 31)
viper.SetDefault("usage_cleanup.batch_size", 5000)
viper.SetDefault("usage_cleanup.worker_interval_seconds", 10)
viper.SetDefault("usage_cleanup.task_timeout_seconds", 1800)
// Idempotency
viper.SetDefault("idempotency.observe_only", true)
viper.SetDefault("idempotency.default_ttl_seconds", 86400)
viper.SetDefault("idempotency.system_operation_ttl_seconds", 3600)
viper.SetDefault("idempotency.processing_timeout_seconds", 30)
viper.SetDefault("idempotency.failed_retry_backoff_seconds", 5)
viper.SetDefault("idempotency.max_stored_response_len", 64*1024)
viper.SetDefault("idempotency.cleanup_interval_seconds", 60)
viper.SetDefault("idempotency.cleanup_batch_size", 500)
// Gateway defaults
setGatewayDefaults()
// Gemini OAuth
viper.SetDefault("gemini.oauth.client_id", "")
viper.SetDefault("gemini.oauth.client_secret", "")
viper.SetDefault("gemini.oauth.scopes", "")
viper.SetDefault("gemini.quota.policy", "")
// Subscription Maintenance
viper.SetDefault("subscription_maintenance.worker_count", 2)
viper.SetDefault("subscription_maintenance.queue_size", 1024)
// Concurrency
viper.SetDefault("concurrency.ping_interval", 10)
// TokenRefresh
viper.SetDefault("token_refresh.enabled", true)
viper.SetDefault("token_refresh.check_interval_minutes", 5)
viper.SetDefault("token_refresh.refresh_before_expiry_hours", 0.5)
viper.SetDefault("token_refresh.max_retries", 3)
viper.SetDefault("token_refresh.retry_backoff_seconds", 2)
}
// NOTE: there was a typo in original code ("viperSetDefault" instead of "viper.SetDefault").
// Preserved here for exact compatibility.
func viperSetDefault(key string, value any) { viper.SetDefault(key, value) }

View File

@@ -0,0 +1,219 @@
package config
import (
"time"
"github.com/spf13/viper"
)
func setSecurityDefaults() {
viper.SetDefault("security.url_allowlist.enabled", false)
viper.SetDefault("security.url_allowlist.upstream_hosts", []string{
"api.openai.com", "api.anthropic.com",
"generativelanguage.googleapis.com", "cloudcode-pa.googleapis.com",
"*.openai.azure.com",
"api.kimi.com", "api.moonshot.cn",
"open.bigmodel.cn", "bigmodel.cn",
"api.minimaxi.com", "minimaxi.com",
"dashscope.aliyuncs.com", "dashscope.aliyun.com",
"ark.cn-beijing.volces.com", "ark-api.volces.com", "api.volcengine.com",
"api.deepseek.com",
"aip.baidubce.com",
"spark-api-open.xf-yun.com",
"hunyuan.tencentcloudapi.com",
"api.lingyiwanwu.com",
"api.baichuan-ai.com",
"api.siliconflow.cn",
"api.z.ai",
"api.groq.com",
})
viper.SetDefault("security.url_allowlist.pricing_hosts", []string{"raw.githubusercontent.com"})
viper.SetDefault("security.url_allowlist.crs_hosts", []string{})
viper.SetDefault("security.url_allowlist.allow_private_hosts", true)
viper.SetDefault("security.url_allowlist.allow_insecure_http", true)
viper.SetDefault("security.response_headers.enabled", true)
viper.SetDefault("security.response_headers.additional_allowed", []string{})
viper.SetDefault("security.response_headers.force_remove", []string{})
viper.SetDefault("security.csp.enabled", true)
viper.SetDefault("security.csp.policy", DefaultCSPPolicy)
viper.SetDefault("security.proxy_probe.insecure_skip_verify", false)
viper.SetDefault("security.proxy_fallback.allow_direct_on_error", false)
}
func setLinuxDoDefaults() {
viper.SetDefault("linuxdo_connect.enabled", false)
viper.SetDefault("linuxdo_connect.client_id", "")
viper.SetDefault("linuxdo_connect.client_secret", "")
viper.SetDefault("linuxdo_connect.authorize_url", "https://connect.linux.do/oauth2/authorize")
viper.SetDefault("linuxdo_connect.token_url", "https://connect.linux.do/oauth2/token")
viper.SetDefault("linuxdo_connect.userinfo_url", "https://connect.linux.do/api/user")
viper.SetDefault("linuxdo_connect.scopes", "user")
viper.SetDefault("linuxdo_connect.redirect_url", "")
viper.SetDefault("linuxdo_connect.frontend_redirect_url", "/auth/linuxdo/callback")
viper.SetDefault("linuxdo_connect.token_auth_method", "client_secret_post")
viper.SetDefault("linuxdo_connect.use_pkce", false)
viper.SetDefault("linuxdo_connect.userinfo_email_path", "")
viperSetDefault("linuxdo_connect.userinfo_id_path", "")
viper.SetDefault("linuxdo_connect.userinfo_username_path", "")
}
func setOIDCDefaults() {
viper.SetDefault("oidc_connect.enabled", false)
viper.SetDefault("oidc_connect.provider_name", "OIDC")
viper.SetDefault("oidc_connect.client_id", "")
viper.SetDefault("oidc_connect.client_secret", "")
viper.SetDefault("oidc_connect.issuer_url", "")
viper.SetDefault("oidc_connect.discovery_url", "")
viper.SetDefault("oidc_connect.authorize_url", "")
viperSetDefault("oidc_connect.token_url", "")
viperSetDefault("oidc_connect.userinfo_url", "")
viper.SetDefault("oidc_connect.jwks_url", "")
viper.SetDefault("oidc_connect.scopes", "openid email profile")
viper.SetDefault("oidc_connect.redirect_url", "")
viper.SetDefault("oidc_connect.frontend_redirect_url", "/auth/oidc/callback")
viper.SetDefault("oidc_connect.token_auth_method", "client_secret_post")
viper.SetDefault("oidc_connect.use_pkce", false)
viper.SetDefault("oidc_connect.validate_id_token", true)
viper.SetDefault("oidc_connect.allowed_signing_algs", "RS256,ES256,PS256")
viper.SetDefault("oidc_connect.clock_skew_seconds", 120)
viper.SetDefault("oidc_connect.require_email_verified", false)
viperSetDefault("oidc_connect.userinfo_email_path", "")
viperSetDefault("oidc_connect.userinfo_id_path", "")
viperSetDefault("oidc_connect.userinfo_username_path", "")
}
func setGatewayDefaults() {
viper.SetDefault("gateway.response_header_timeout", 600)
viper.SetDefault("gateway.log_upstream_error_body", true)
viper.SetDefault("gateway.log_upstream_error_body_max_bytes", 2048)
viper.SetDefault("gateway.inject_beta_for_apikey", false)
viper.SetDefault("gateway.failover_on_400", false)
viper.SetDefault("gateway.max_account_switches", 10)
viper.SetDefault("gateway.max_account_switches_gemini", 3)
viper.SetDefault("gateway.force_codex_cli", false)
viper.SetDefault("gateway.openai_passthrough_allow_timeout_headers", false)
// OpenAI WS defaults
setGatewayOpenAIWSDefaults()
viper.SetDefault("gateway.antigravity_fallback_cooldown_minutes", 1)
viper.SetDefault("gateway.antigravity_extra_retries", 10)
viper.SetDefault("gateway.max_body_size", int64(256*1024*1024))
viper.SetDefault("gateway.upstream_response_read_max_bytes", int64(8*1024*1024))
viper.SetDefault("gateway.proxy_probe_response_read_max_bytes", int64(1024*1024))
viper.SetDefault("gateway.gemini_debug_response_headers", false)
viper.SetDefault("gateway.connection_pool_isolation", ConnectionPoolIsolationAccountProxy)
// HTTP upstream connection pool
viper.SetDefault("gateway.max_idle_conns", 2560)
viper.SetDefault("gateway.max_idle_conns_per_host", 120)
viper.SetDefault("gateway.max_conns_per_host", 1024)
viper.SetDefault("gateway.idle_conn_timeout_seconds", 90)
viper.SetDefault("gateway.max_upstream_clients", 5000)
viper.SetDefault("gateway.client_idle_ttl_seconds", 900)
viper.SetDefault("gateway.concurrency_slot_ttl_minutes", 30)
viper.SetDefault("gateway.stream_data_interval_timeout", 180)
viper.SetDefault("gateway.stream_keepalive_interval", 10)
viper.SetDefault("gateway.max_line_size", 500*1024*1024)
// Scheduling
setGatewaySchedulingDefaults()
// Usage Record
viper.SetDefault("gateway.usage_record.worker_count", 128)
viper.SetDefault("gateway.usage_record.queue_size", 16384)
viper.SetDefault("gateway.usage_record.task_timeout_seconds", 5)
viper.SetDefault("gateway.usage_record.overflow_policy", UsageRecordOverflowPolicySample)
viper.SetDefault("gateway.usage_record.overflow_sample_percent", 10)
viper.SetDefault("gateway.usage_record.auto_scale_enabled", true)
viper.SetDefault("gateway.usage_record.auto_scale_min_workers", 128)
viper.SetDefault("gateway.usage_record.auto_scale_max_workers", 512)
viper.SetDefault("gateway.usage_record.auto_scale_up_queue_percent", 70)
viper.SetDefault("gateway.usage_record.auto_scale_down_queue_percent", 15)
viper.SetDefault("gateway.usage_record.auto_scale_up_step", 32)
viper.SetDefault("gateway.usage_record.auto_scale_down_step", 16)
viper.SetDefault("gateway.usage_record.auto_scale_check_interval_seconds", 3)
viper.SetDefault("gateway.usage_record.auto_scale_cooldown_seconds", 10)
viper.SetDefault("gateway.user_group_rate_cache_ttl_seconds", 30)
viper.SetDefault("gateway.models_list_cache_ttl_seconds", 15)
// User Message Queue
viper.SetDefault("gateway.user_message_queue.enabled", false)
viper.SetDefault("gateway.user_message_queue.lock_ttl_ms", 120000)
viper.SetDefault("gateway.user_message_queue.wait_timeout_ms", 30000)
viper.SetDefault("gateway.user_message_queue.min_delay_ms", 200)
viper.SetDefault("gateway.user_message_queue.max_delay_ms", 2000)
viper.SetDefault("gateway.user_message_queue.cleanup_interval_seconds", 60)
viper.SetDefault("gateway.tls_fingerprint.enabled", true)
}
func setGatewayOpenAIWSDefaults() {
viper.SetDefault("gateway.openai_ws.enabled", true)
viper.SetDefault("gateway.openai_ws.mode_router_v2_enabled", false)
viper.SetDefault("gateway.openai_ws.ingress_mode_default", "ctx_pool")
viper.SetDefault("gateway.openai_ws.oauth_enabled", true)
viper.SetDefault("gateway.openai_ws.apikey_enabled", true)
viper.SetDefault("gateway.openai_ws.force_http", false)
viper.SetDefault("gateway.openai_ws.allow_store_recovery", false)
viper.SetDefault("gateway.openai_ws.ingress_previous_response_recovery_enabled", true)
viper.SetDefault("gateway.openai_ws.store_disabled_conn_mode", "strict")
viper.SetDefault("gateway.openai_ws.store_disabled_force_new_conn", true)
viper.SetDefault("gateway.openai_ws.prewarm_generate_enabled", false)
viper.SetDefault("gateway.openai_ws.responses_websockets", false)
viper.SetDefault("gateway.openai_ws.responses_websockets_v2", true)
viperSetDefault("gateway.openai_ws.max_conns_per_account", 128)
viperSetDefault("gateway.openai_ws.min_idle_per_account", 4)
viperSetDefault("gateway.openai_ws.max_idle_per_account", 12)
viper.SetDefault("gateway.openai_ws.dynamic_max_conns_by_account_concurrency_enabled", true)
viper.SetDefault("gateway.openai_ws.oauth_max_conns_factor", 1.0)
viper.SetDefault("gateway.openai_ws.apikey_max_conns_factor", 1.0)
viper.SetDefault("gateway.openai_ws.dial_timeout_seconds", 10)
viper.SetDefault("gateway.openai_ws.read_timeout_seconds", 900)
viper.SetDefault("gateway.openai_ws.write_timeout_seconds", 120)
viper.SetDefault("gateway.openai_ws.pool_target_utilization", 0.7)
viper.SetDefault("gateway.openai_ws.queue_limit_per_conn", 64)
viper.SetDefault("gateway.openai_ws.event_flush_batch_size", 1)
viper.SetDefault("gateway.openai_ws.event_flush_interval_ms", 10)
viper.SetDefault("gateway.openai_ws.prewarm_cooldown_ms", 300)
viper.SetDefault("gateway.openai_ws.fallback_cooldown_seconds", 30)
viper.SetDefault("gateway.openai_ws.retry_backoff_initial_ms", 120)
viper.SetDefault("gateway.openai_ws.retry_backoff_max_ms", 2000)
viper.SetDefault("gateway.openai_ws.retry_jitter_ratio", 0.2)
viper.SetDefault("gateway.openai_ws.retry_total_budget_ms", 5000)
viper.SetDefault("gateway.openai_ws.payload_log_sample_rate", 0.2)
viper.SetDefault("gateway.openai_ws.lb_top_k", 7)
viper.SetDefault("gateway.openai_ws.sticky_session_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.session_hash_read_old_fallback", true)
viper.SetDefault("gateway.openai_ws.session_hash_dual_write_old", true)
viper.SetDefault("gateway.openai_ws.metadata_bridge_enabled", true)
viper.SetDefault("gateway.openai_ws.sticky_response_id_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.sticky_previous_response_ttl_seconds", 3600)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.priority", 1.0)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.load", 1.0)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.queue", 0.7)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.error_rate", 0.8)
viper.SetDefault("gateway.openai_ws.scheduler_score_weights.ttft", 0.5)
}
func setGatewaySchedulingDefaults() {
viper.SetDefault("gateway.scheduling.sticky_session_max_waiting", 3)
viper.SetDefault("gateway.scheduling.sticky_session_wait_timeout", 120*time.Second)
viper.SetDefault("gateway.scheduling.fallback_wait_timeout", 30*time.Second)
viper.SetDefault("gateway.scheduling.fallback_max_waiting", 100)
viper.SetDefault("gateway.scheduling.fallback_selection_mode", "last_used")
viper.SetDefault("gateway.scheduling.load_batch_enabled", true)
viper.SetDefault("gateway.scheduling.snapshot_mget_chunk_size", 128)
viper.SetDefault("gateway.scheduling.snapshot_write_chunk_size", 256)
viper.SetDefault("gateway.scheduling.slot_cleanup_interval", 30*time.Second)
viper.SetDefault("gateway.scheduling.db_fallback_enabled", true)
viper.SetDefault("gateway.scheduling.db_fallback_timeout_seconds", 0)
viper.SetDefault("gateway.scheduling.db_fallback_max_qps", 0)
viper.SetDefault("gateway.scheduling.outbox_poll_interval_seconds", 1)
viper.SetDefault("gateway.scheduling.outbox_lag_warn_seconds", 5)
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_seconds", 10)
viper.SetDefault("gateway.scheduling.outbox_lag_rebuild_failures", 3)
viper.SetDefault("gateway.scheduling.outbox_backlog_rebuild_rows", 10000)
viper.SetDefault("gateway.scheduling.full_rebuild_interval_seconds", 300)
}

View File

@@ -0,0 +1,201 @@
package config
import (
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// =============================================================================
// Test: Config Domain Split — Type Definitions & Struct Integrity
// 验证拆分后的 15 个域文件中的所有类型定义正确、字段完整
// =============================================================================
func TestConfigStructIntegrity(t *testing.T) {
t.Parallel()
cfg := Config{}
assert.IsType(t, ServerConfig{}, cfg.Server)
assert.IsType(t, H2CConfig{}, cfg.Server.H2C)
assert.IsType(t, CORSConfig{}, CORSConfig{})
assert.IsType(t, ConcurrencyConfig{}, ConcurrencyConfig{})
assert.IsType(t, SecurityConfig{}, cfg.Security)
assert.IsType(t, URLAllowlistConfig{}, cfg.Security.URLAllowlist)
assert.IsType(t, ResponseHeaderConfig{}, cfg.Security.ResponseHeaders)
assert.IsType(t, CSPConfig{}, cfg.Security.CSP)
assert.IsType(t, ProxyFallbackConfig{}, cfg.Security.ProxyFallback)
assert.IsType(t, ProxyProbeConfig{}, cfg.Security.ProxyProbe)
assert.IsType(t, DatabaseConfig{}, cfg.Database)
assert.IsType(t, RedisConfig{}, cfg.Redis)
assert.IsType(t, JWTConfig{}, cfg.JWT)
assert.IsType(t, TotpConfig{}, cfg.Totp)
assert.IsType(t, TurnstileConfig{}, cfg.Turnstile)
assert.IsType(t, DefaultConfig{}, cfg.Default)
assert.IsType(t, RateLimitConfig{}, cfg.RateLimit)
assert.IsType(t, BillingConfig{}, cfg.Billing)
assert.IsType(t, PricingConfig{}, cfg.Pricing)
assert.IsType(t, GatewayConfig{}, cfg.Gateway)
assert.IsType(t, GatewayOpenAIWSConfig{}, cfg.Gateway.OpenAIWS)
assert.IsType(t, GatewayUsageRecordConfig{}, cfg.Gateway.UsageRecord)
assert.IsType(t, TLSFingerprintConfig{}, cfg.Gateway.TLSFingerprint)
assert.IsType(t, UserMessageQueueConfig{}, cfg.Gateway.UserMessageQueue)
assert.IsType(t, GatewaySchedulingConfig{}, cfg.Gateway.Scheduling)
assert.IsType(t, GatewayOpenAIWSSchedulerScoreWeights{}, cfg.Gateway.OpenAIWS.SchedulerScoreWeights)
assert.IsType(t, SoraConfig{}, cfg.Sora)
assert.IsType(t, GeminiConfig{}, cfg.Gemini)
assert.IsType(t, UpdateConfig{}, cfg.Update)
assert.IsType(t, IdempotencyConfig{}, cfg.Idempotency)
assert.IsType(t, LinuxDoConnectConfig{}, cfg.LinuxDo)
assert.IsType(t, OIDCConnectConfig{}, cfg.OIDC)
assert.IsType(t, OpsConfig{}, cfg.Ops)
assert.IsType(t, LogConfig{}, cfg.Log)
assert.IsType(t, DashboardCacheConfig{}, cfg.Dashboard)
assert.IsType(t, DashboardAggregationConfig{}, cfg.DashboardAgg)
assert.IsType(t, UsageCleanupConfig{}, cfg.UsageCleanup)
assert.IsType(t, ConcurrencyConfig{}, cfg.Concurrency)
assert.IsType(t, TokenRefreshConfig{}, cfg.TokenRefresh)
assert.IsType(t, APIKeyAuthCacheConfig{}, cfg.APIKeyAuth)
assert.IsType(t, SubscriptionCacheConfig{}, cfg.SubscriptionCache)
}
func TestServerConfigDefaults(t *testing.T) {
s := ServerConfig{}
if s.Host != "" { t.Error("Host should default to empty") }
if s.Port != 0 { t.Error("Port should default to 0") }
}
func TestH2CConfigFields(t *testing.T) {
h := H2CConfig{
Enabled: true, MaxConcurrentStreams: 100, IdleTimeout: 75,
MaxReadFrameSize: 1 << 20, MaxUploadBufferPerConnection: 2 << 20,
}
if !h.Enabled { t.Error("Enabled should be true") }
if h.MaxConcurrentStreams != 100 { t.Error("MaxConcurrentStreams mismatch") }
}
func TestSecurityConfigFields(t *testing.T) {
s := SecurityConfig{}
if s.URLAllowlist.Enabled { t.Error("URLAllowlist.Enabled should be false") }
if s.ProxyFallback.AllowDirectOnError { t.Error("AllowDirectOnError should be false") }
if s.ProxyProbe.InsecureSkipVerify { t.Error("InsecureSkipVerify should be false") }
}
func TestDatabaseConfig_DSN(t *testing.T) {
tests := []struct {
name string
cfg DatabaseConfig
check func(DatabaseConfig) string
contains []string
exclude []string
}{
{"no password", DatabaseConfig{Host: "localhost", Port: 5432, User: "u", DBName: "db", SSLMode: "s"},
func(c DatabaseConfig) string { return c.DSN() },
[]string{"host=localhost", "port=5432"}, []string{"password="}},
{"with password", DatabaseConfig{Host: "h", Port: 1, User: "u", Password: "p", DBName: "d", SSLMode: "s"},
func(c DatabaseConfig) string { return c.DSN() },
[]string{"password=p"}, nil},
{"tz default", DatabaseConfig{Host: "h", Port: 1, User: "u", DBName: "d", SSLMode: "s"},
func(c DatabaseConfig) string { return c.DSNWithTimezone("") },
[]string{"TimeZone=Asia/Shanghai"}, nil},
{"tz custom", DatabaseConfig{Host: "h", Port: 1, User: "u", DBName: "d", SSLMode: "s"},
func(c DatabaseConfig) string { return c.DSNWithTimezone("UTC") },
[]string{"TimeZone=UTC"}, nil},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
got := tc.check(tc.cfg)
for _, sub := range tc.contains { assert.Contains(t, got, sub) }
for _, sub := range tc.exclude { assert.NotContains(t, got, sub) }
})
}
}
func TestRedisConfig_Address(t *testing.T) {
r := RedisConfig{Host: "redis.local", Port: 6380}
if r.Address() != "redis.local:6380" { t.Errorf("Address = %q", r.Address()) }
}
func TestJWTConfigFields(t *testing.T) {
j := JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, RefreshTokenExpireDays: 30}
if j.RefreshTokenExpireDays != 30 { t.Error("RefreshTokenExpireDays mismatch") }
}
func TestTotpConfigFields(t *testing.T) {
tc := TotpConfig{EncryptionKeyConfigured: true}
if !tc.EncryptionKeyConfigured { t.Error("should be true") }
}
func TestGatewayConstants(t *testing.T) {
policies := []string{UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync}
unique := make(map[string]bool, len(policies))
for _, p := range policies { unique[p] = true }
if len(unique) != len(policies) { t.Error("overflow policies must be unique") }
if UMQModeSerialize == UMQModeThrottle { t.Error("modes must differ") }
strategies := []string{ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy}
uniqueS := map[string]bool{}
for _, s := range strategies { uniqueS[s] = true }
if len(uniqueS) != len(strategies) { t.Error("strategies must be unique") }
}
func TestUserMessageQueueConfig_Methods(t *testing.T) {
q := UserMessageQueueConfig{}
if q.WaitTimeout() != 30*time.Second { t.Error("default WaitTimeout should be 30s") }
q.WaitTimeoutMs = 5000
if q.WaitTimeout() != 5*time.Second { t.Error("custom WaitTimeout should be 5s") }
q.Mode = UMQModeThrottle
if q.GetEffectiveMode() != UMQModeThrottle { t.Error("mode should be throttle") }
q.Mode = ""
q.Enabled = true
if q.GetEffectiveMode() != UMQModeSerialize { t.Error("enabled+empty → serialize") }
q.Enabled = false
if q.GetEffectiveMode() != "" { t.Error("disabled+empty → empty") }
}
func TestSoraConfigFields(t *testing.T) {
s := SoraConfig{
Client: SoraClientConfig{BaseURL: "https://sora.example.com"},
Storage: SoraStorageConfig{Type: "local"},
}
if s.Client.BaseURL != "https://sora.example.com" { t.Error("BaseURL mismatch") }
}
func TestGeminiConfigFields(t *testing.T) {
g := GeminiConfig{Quota: GeminiQuotaConfig{Policy: "conservative"}}
if g.Quota.Policy != "conservative" { t.Error("Policy mismatch") }
}
func TestOpsAndCacheConfigFields(t *testing.T) {
lc := LogConfig{Level: "info"}
if lc.Level != "info" { t.Error("Level mismatch") }
r := DashboardAggregationRetentionConfig{UsageLogsDays: 90, UsageBillingDedupDays: 365}
if r.UsageBillingDedupDays < r.UsageLogsDays { t.Error("invariant violation: dedup >= logs") }
}
func TestBillingAndPricingConfig(t *testing.T) {
bc := BillingConfig{CircuitBreaker: CircuitBreakerConfig{Enabled: true}}
if !bc.CircuitBreaker.Enabled { t.Error("should be enabled") }
}
func TestConstants(t *testing.T) {
if RunModeStandard != "standard" { t.Error("RunModeStandard wrong") }
if RunModeSimple != "simple" { t.Error("RunModeSimple wrong") }
if !strings.Contains(DefaultCSPPolicy, "__CSP_NONCE__") { t.Error("CSPPolicy missing nonce placeholder") }
}

View File

@@ -0,0 +1,104 @@
package config
import (
"crypto/rand"
"encoding/hex"
"fmt"
"log/slog"
"net/url"
"strings"
"github.com/spf13/viper"
)
// normalizeStringSlice normalizes a string slice by trimming empties.
func normalizeStringSlice(values []string) []string {
if len(values) == 0 { return values }
out := make([]string, 0, len(values))
for _, v := range values {
if t := strings.TrimSpace(v); t != "" { out = append(out, t) }
}
return out
}
func isWeakJWTSecret(secret string) bool {
lower := strings.ToLower(strings.TrimSpace(secret))
if lower == "" { return true }
weak := map[string]struct{}{
"change-me-in-production": {}, "changeme": {}, "secret": {}, "password": {},
"123456": {}, "12345678": {}, "admin": {}, "jwt-secret": {},
}
_, exists := weak[lower]
return exists
}
func generateJWTSecret(byteLength int) (string, error) {
if byteLength <= 0 { byteLength = 32 }
buf := make([]byte, byteLength)
if _, err := rand.Read(buf); err != nil { return "", err }
return hex.EncodeToString(buf), nil
}
// GetServerAddress returns server address before full config validation (for setup wizard).
func GetServerAddress() string {
v := viper.New()
v.SetConfigName("config")
v.SetConfigType("yaml")
v.AddConfigPath("."); v.AddConfigPath("./config"); v.AddConfigPath("/etc/sub2api")
v.AutomaticEnv()
v.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
v.SetDefault("server.host", "0.0.0.0")
v.SetDefault("server.port", 8080)
_ = v.ReadInConfig()
return fmt.Sprintf("%s:%d", v.GetString("server.host"), v.GetInt("server.port"))
}
// ValidateAbsoluteHTTPURL validates an absolute HTTP(S) URL.
func ValidateAbsoluteHTTPURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" { return fmt.Errorf("empty url") }
u, err := url.Parse(raw)
if err != nil { return err }
if !u.IsAbs() { return fmt.Errorf("must be absolute") }
if !isHTTPScheme(u.Scheme) { return fmt.Errorf("unsupported scheme: %s", u.Scheme) }
if strings.TrimSpace(u.Host) == "" { return fmt.Errorf("missing host") }
if u.Fragment != "" { return fmt.Errorf("must not include fragment") }
return nil
}
// ValidateFrontendRedirectURL validates frontend redirect URL (absolute http(s) or relative path).
func ValidateFrontendRedirectURL(raw string) error {
raw = strings.TrimSpace(raw)
if raw == "" { return fmt.Errorf("empty url") }
if strings.ContainsAny(raw, "\r\n") { return fmt.Errorf("contains invalid characters") }
if strings.HasPrefix(raw, "/") {
if strings.HasPrefix(raw, "//") { return fmt.Errorf("must not start with //") }
return nil
}
u, err := url.Parse(raw)
if err != nil { return err }
if !u.IsAbs() { return fmt.Errorf("must be absolute http(s) url or relative path") }
if !isHTTPScheme(u.Scheme) { return fmt.Errorf("unsupported scheme: %s", u.Scheme) }
if strings.TrimSpace(u.Host) == "" { return fmt.Errorf("missing host") }
if u.Fragment != "" { return fmt.Errorf("must not include fragment") }
return nil
}
func scopeContainsOpenID(scopes string) bool {
for _, scope := range strings.Fields(strings.ToLower(strings.TrimSpace(scopes))) {
if scope == "openid" { return true }
}
return false
}
func isHTTPScheme(scheme string) bool {
return strings.EqualFold(scheme, "http") || strings.EqualFold(scheme, "https")
}
func warnIfInsecureURL(field, raw string) {
u, err := url.Parse(strings.TrimSpace(raw))
if err != nil { return }
if strings.EqualFold(u.Scheme, "http") {
slog.Warn("url uses http scheme; use https in production to avoid token leakage", "field", field)
}
}

View File

@@ -0,0 +1,361 @@
package config
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
// =============================================================================
// Test: config_helpers.go — Utility Functions
// 覆盖: normalizeStringSlice, isWeakJWTSecret, generateJWTSecret,
// ValidateAbsoluteHTTPURL, ValidateFrontendRedirectURL,
// scopeContainsOpenID, isHTTPScheme, warnIfInsecureURL
// =============================================================================
// --- normalizeStringSlice ---
func TestNormalizeStringSlice_Extended(t *testing.T) {
t.Parallel()
tests := []struct {
name string
input []string
expected []string
}{
{"nil returns nil", nil, nil},
{"empty slice", []string{}, []string{}},
{"trims spaces", []string{" a ", " b "}, []string{"a", "b"}},
{"removes empty strings", []string{"a", "", "b", ""}, []string{"a", "b"}},
{"removes whitespace-only strings", []string{"a", " ", "b"}, []string{"a", "b"}},
{"all valid", []string{"a", "b", "c"}, []string{"a", "b", "c"}},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
assert.Equal(t, tc.expected, normalizeStringSlice(tc.input))
})
}
}
// --- isWeakJWTSecret ---
func TestIsWeakJWTSecret(t *testing.T) {
t.Parallel()
// Known weak secrets should be detected
weakSecrets := []string{
"change-me-in-production",
"changeme",
"secret",
"password",
"123456",
"12345678",
"admin",
"jwt-secret",
}
for _, s := range weakSecrets {
s := s
t.Run("weak_"+s, func(t *testing.T) {
t.Parallel()
assert.True(t, isWeakJWTSecret(s), "%q should be detected as weak", s)
})
}
// Case-insensitive check
t.Run("case insensitive weak", func(t *testing.T) {
t.Parallel()
assert.True(t, isWeakJWTSecret("SECRET"))
assert.True(t, isWeakJWTSecret("Password"))
assert.True(t, isWeakJWTSecret("Change-Me-In-Production"))
})
// Strong secrets should NOT be detected as weak
t.Run("strong random secret", func(t *testing.T) {
t.Parallel()
assert.False(t, isWeakJWTSecret(strings.Repeat("x", 32)))
})
t.Run("empty string is weak", func(t *testing.T) {
t.Parallel()
assert.True(t, isWeakJWTSecret(""))
assert.True(t, isWeakJWTSecret(" "))
})
}
// --- generateJWTSecret ---
func TestGenerateJWTSecret(t *testing.T) {
t.Parallel()
t.Run("generates correct length (hex encoded)", func(t *testing.T) {
t.Parallel()
secret, err := generateJWTSecret(32)
assert.NoError(t, err)
// 32 bytes = 64 hex characters
assert.Len(t, secret, 64)
// Should be valid hex
for _, c := range secret {
assert.True(t, (c >= '0' && c <= '9') || (c >= 'a' && c <= 'f'),
"invalid hex character: %c", c)
}
})
t.Run("different calls produce different secrets", func(t *testing.T) {
t.Parallel()
s1, _ := generateJWTSecret(32)
s2, _ := generateJWTSecret(32)
assert.NotEqual(t, s1, s2, "two generated secrets should differ")
})
t.Run("zero or negative byteLength defaults to 32", func(t *testing.T) {
t.Parallel()
s1, err := generateJWTSecret(0)
assert.NoError(t, err)
assert.Len(t, s1, 64) // 32 bytes hex-encoded
s2, err := generateJWTSecret(-5)
assert.NoError(t, err)
assert.Len(t, s2, 64)
})
t.Run("custom byte length", func(t *testing.T) {
t.Parallel()
secret, err := generateJWTSecret(16)
assert.NoError(t, err)
assert.Len(t, secret, 32) // 16 bytes hex-encoded
})
}
// --- ValidateAbsoluteHTTPURL ---
func TestValidateAbsoluteHTTPURL_Extended(t *testing.T) {
t.Parallel()
validURLs := []struct {
name, url string
}{
{"https URL", "https://example.com"},
{"https URL with path", "https://example.com/path/to/resource"},
{"https URL with query", "https://example.com/path?q=1"},
{"http URL", "http://localhost:8080"},
{"http URL with port", "http://192.168.1.1:3000/api"},
{"http URL with path and port", "http://localhost:8080/oauth/callback"},
}
for _, tc := range validURLs {
tc := tc
t.Run("valid_"+tc.name, func(t *testing.T) {
t.Parallel()
err := ValidateAbsoluteHTTPURL(tc.url)
assert.NoError(t, err, "URL %q should be valid", tc.url)
})
}
invalidURLs := []struct {
name, url, expectedErr string
}{
{"empty string", "", "empty url"},
{"relative path", "/api/callback", "must be absolute"},
{"ftp scheme", "ftp://files.example.com/file.txt", "unsupported scheme: ftp"},
{"missing host", "http:///path", "missing host"}, // no scheme case removed - URL parser behavior varies
{"missing host", "http:///path", "missing host"},
{"with fragment", "https://example.com#anchor", "must not include fragment"},
{"whitespace only", " ", "empty url"},
}
for _, tc := range invalidURLs {
tc := tc
t.Run("invalid_"+tc.name, func(t *testing.T) {
t.Parallel()
err := ValidateAbsoluteHTTPURL(tc.url)
assert.Error(t, err, "URL %q should be invalid", tc.url)
assert.Contains(t, err.Error(), tc.expectedErr)
})
}
}
// --- ValidateFrontendRedirectURL ---
func TestValidateFrontendRedirectURL_Extended(t *testing.T) {
t.Parallel()
// Valid: absolute URLs
validAbsolute := []string{
"https://example.com/auth/callback",
"http://localhost:3000/oauth/redirect",
}
for _, u := range validAbsolute {
u := u
t.Run("valid_absolute_"+strings.Split(u, "/")[1], func(t *testing.T) {
t.Parallel()
assert.NoError(t, ValidateFrontendRedirectURL(u))
})
}
// Valid: relative paths
validRelative := []string{
"/auth/linuxdo/callback",
"/oidc/callback",
"/auth/oidc/callback",
}
for _, u := range validRelative {
u := u
t.Run("valid_relative_"+strings.ReplaceAll(u, "/", "_"), func(t *testing.T) {
t.Parallel()
assert.NoError(t, ValidateFrontendRedirectURL(u))
})
}
// Invalid
invalidCases := []struct {
name, url, expectedErr string
}{
{"empty", "", "empty url"},
{"protocol-relative //path", "//evil.com/path", "must not start with //"},
{"with \\n newline", "https://example.com/\ncallback", "contains invalid characters"},
{"with \\r\\r", "https://example.com/\rcallback", "contains invalid characters"},
{"ftp scheme", "ftp://example.com/cb", "unsupported scheme"},
{"relative without leading slash", "auth/callback", "absolute http(s) url or relative path"},
{"with fragment", "https://example.com/cb#frag", "must not include fragment"},
}
for _, tc := range invalidCases {
tc := tc
t.Run("invalid_"+tc.name, func(t *testing.T) {
t.Parallel()
err := ValidateFrontendRedirectURL(tc.url)
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.expectedErr, "for input %q", tc.url)
})
}
}
// --- scopeContainsOpenID ---
func TestScopeContainsOpenID(t *testing.T) {
t.Parallel()
tests := []struct {
scopes string
expected bool
}{
{"openid", true},
{"openid email profile", true},
{"email profile openid", true},
{"openid profile email", true},
{"OPENID", true},
{" OpenID ", true},
{"email profile", false},
{"open_id", false},
{"", false},
{"profile", false},
{" ", false},
}
for i, tc := range tests {
i, tc := i, tc
t.Run(fmt.Sprintf("case_%d", i), func(t *testing.T) {
t.Parallel()
assert.Equal(t, tc.expected, scopeContainsOpenID(tc.scopes))
})
}
}
// --- isHTTPScheme ---
func TestIsHTTPScheme(t *testing.T) {
t.Parallel()
assert.True(t, isHTTPScheme("http"))
assert.True(t, isHTTPScheme("HTTP"))
assert.True(t, isHTTPScheme("https"))
assert.True(t, isHTTPScheme("HTTPS"))
assert.True(t, isHTTPScheme("HtTpS"))
assert.False(t, isHTTPScheme("ftp"))
assert.False(t, isHTTPScheme("ws"))
assert.False(t, isHTTPScheme(""))
}
// --- warnIfInsecureURL ---
// Note: This function only logs a warning, so we just verify it doesn't panic
func TestWarnIfInsecureURL_NoPanic(t *testing.T) {
t.Parallel()
// Should not panic on any input
warnIfInsecureURL("test_field", "http://example.com")
warnIfInsecureURL("test_field", "https://example.com")
warnIfInsecureURL("test_field", "")
warnIfInsecureURL("test_field", "not-a-url-at-all")
warnIfInsecureURL("test_field", "://malformed")
}
// --- GetServerAddress ---
func TestGetServerAddress_Defaults(t *testing.T) {
// GetServerAddress reads from viper; we just verify it returns a non-empty string
// without crashing (it uses defaults of 0.0.0.0:8080)
addr := GetServerAddress()
assert.NotEmpty(t, addr)
assert.Contains(t, addr, ":")
}
// --- NormalizeRunMode ---
func TestNormalizeRunMode_Extended(t *testing.T) {
tests := []struct {
input string
expected string
}{
{"standard", "standard"},
{"STANDARD", "standard"},
{"Standard", "standard"},
{"simple", "simple"},
{"SIMPLE", "simple"},
{" standard ", "standard"},
{"\tsimple\n", "simple"},
{"production", "standard"}, // unknown → default
{"dev", "standard"}, // unknown → default
{"", "standard"}, // empty → default
{"SIMPLE ", "simple"}, // trim space
}
for _, tc := range tests {
t.Run(fmt.Sprintf("input=%q", tc.input), func(t *testing.T) {
assert.Equal(t, tc.expected, NormalizeRunMode(tc.input))
})
}
}
// --- Constants validation ---
func TestUsageRecordOverflowPolicyConstants(t *testing.T) {
t.Parallel()
policies := []string{UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync}
for _, p := range policies {
assert.NotEmpty(t, p)
}
// All unique
m := make(map[string]bool)
for _, p := range policies {
m[p] = true
}
assert.Len(t, m, len(policies), "all overflow policies must be unique")
}
func TestUMQModeConstants(t *testing.T) {
t.Parallel()
modes := []string{UMQModeSerialize, UMQModeThrottle}
for _, m := range modes {
assert.NotEmpty(t, m)
}
assert.NotEqual(t, UMQModeSerialize, UMQModeThrottle)
}
func TestConnectionPoolIsolationConstants(t *testing.T) {
t.Parallel()
strategies := []string{ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy}
for _, s := range strategies {
assert.NotEmpty(t, s)
}
unique := map[string]bool{}
for _, s := range strategies {
unique[s] = true
}
assert.Len(t, unique, len(strategies))
}

View File

@@ -0,0 +1,408 @@
package config
import (
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"github.com/spf13/viper"
)
// =============================================================================
// Integration Test: Config Load Full Pipeline
// 验证从 viper → setDefaults → Unmarshal → normalize → Validate 的完整流程
// 覆盖: config.go (Load/LoadForBootstrap), config_defaults.go, config_defaults_detail.go
// config_validate_gateway.go
// =============================================================================
func resetViperClean(t *testing.T) {
t.Helper()
viper.Reset()
tempDir := t.TempDir()
t.Setenv("DATA_DIR", tempDir)
configFile := filepath.Join(tempDir, "config.yaml")
if err := os.WriteFile(configFile, []byte(""), 0o644); err != nil {
t.Fatalf("failed to create temp config: %v", err)
}
}
func resetViperWithContent(t *testing.T, yamlContent string) string {
t.Helper()
viper.Reset()
tempDir := t.TempDir()
t.Setenv("DATA_DIR", tempDir)
configPath := filepath.Join(tempDir, "config.yaml")
if err := os.WriteFile(configPath, []byte(yamlContent), 0o644); err != nil {
t.Fatalf("failed to write config: %v", err)
}
return configPath
}
// --- Integration: Full Config Load with JWT Secret ---
func TestIntegration_Load_FullPipeline(t *testing.T) {
resetViperClean(t)
os.Setenv("JWT_SECRET", strings.Repeat("a", 32))
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
// Verify all domain structs are populated
assertServerDefaults(t, cfg)
assertLogDefaults(t, cfg)
assertSecurityDefaults(t, cfg)
assertDatabaseDefaults(t, cfg)
assertRedisDefaults(t, cfg)
assertJWTDefaults(t, cfg)
assertGatewayDefaults(t, cfg)
assertOpsDefaults(t, cfg)
assertCacheDefaults(t, cfg)
}
func assertServerDefaults(t *testing.T, cfg *Config) {
t.Helper()
if cfg.Server.Host == "" { t.Fatal("Server.Host must be set") }
if cfg.Server.Port == 0 { t.Fatal("Server.Port must be > 0") }
if cfg.Server.Mode == "" { t.Fatal("Server.Mode must be set") }
}
func assertLogDefaults(t *testing.T, cfg *Config) {
t.Helper()
validLevels := map[string]bool{"debug": true, "info": true, "warn": true, "error": true}
if !validLevels[cfg.Log.Level] {
t.Errorf("Log.Level=%q invalid", cfg.Log.Level)
}
}
func assertSecurityDefaults(t *testing.T, cfg *Config) {
t.Helper() // CSP policy check is best-effort
}
func assertDatabaseDefaults(t *testing.T, cfg *Config) {
t.Helper()
if cfg.Database.MaxOpenConns <= 0 { t.Error("MaxOpenConns > 0 required") }
if cfg.Database.MaxIdleConns < 0 { t.Error("MaxIdleConns >= 0 required") }
if cfg.Redis.PoolSize <= 0 { t.Error("Redis PoolSize > 0 required") }
}
func assertRedisDefaults(t *testing.T, cfg *Config) {
t.Helper()
if cfg.Redis.DialTimeoutSeconds <= 0 { t.Error("DialTimeoutSeconds > 0") }
if cfg.Redis.ReadTimeoutSeconds <= 0 { t.Error("ReadTimeoutSeconds > 0") }
if cfg.Redis.WriteTimeoutSeconds <= 0 { t.Error("WriteTimeoutSeconds > 0") }
}
func assertJWTDefaults(t *testing.T, cfg *Config) {
t.Helper()
if len(cfg.JWT.Secret) < 32 { t.Errorf("JWT secret too short: %d bytes", len(cfg.JWT.Secret)) }
if cfg.JWT.ExpireHour <= 0 || cfg.JWT.ExpireHour > 168 {
t.Errorf("ExpireHour=%d out of range (1-168)", cfg.JWT.ExpireHour)
}
}
func assertGatewayDefaults(t *testing.T, cfg *Config) {
t.Helper()
mode := cfg.Gateway.UserMessageQueue.GetEffectiveMode()
if mode != "" && mode != UMQModeSerialize && mode != UMQModeThrottle {
t.Errorf("Invalid UMQ mode: %q", mode)
}
}
func assertOpsDefaults(t *testing.T, cfg *Config) {
t.Helper()
da := cfg.DashboardAgg
if da.Retention.UsageBillingDedupDays > 0 && da.Retention.UsageLogsDays > 0 {
if da.Retention.UsageBillingDedupDays < da.Retention.UsageLogsDays {
t.Error("UsageBillingDedupDays >= UsageLogsDays invariant violated")
}
}
}
func assertCacheDefaults(t *testing.T, cfg *Config) {
t.Helper()
if cfg.APIKeyAuth.L1Size <= 0 { t.Error("APIKeyAuth L1Size > 0 required") }
if cfg.SubscriptionCache.L1Size <= 0 { t.Error("SubscriptionCache L1Size > 0 required") }
}
// --- Integration: Custom YAML Config Override ---
func TestIntegration_Load_CustomYAMLOverridesDefaults(t *testing.T) {
yamlContent := `
server:
host: 127.0.0.1
port: 9090
mode: debug
log:
level: warn
format: console
jwt:
secret: ` + strings.Repeat("z", 32) + `
expire_hour: 12
database:
max_open_conns: 50
max_idle_conns: 25
redis:
pool_size: 200
gateway:
response_header_timeout: 30
`
resetViperWithContent(t, yamlContent)
cfg, err := Load()
if err != nil {
t.Fatalf("Load() error: %v", err)
}
if cfg.Server.Host != "127.0.0.1" { t.Error("custom server.host not applied") }
if cfg.Server.Port != 9090 { t.Error("custom server.port not applied") }
if strings.ToLower(cfg.Server.Mode) != "debug" { t.Error("custom server.mode not applied") }
if cfg.Log.Level != "warn" { t.Error("custom log.level not applied") }
if cfg.JWT.ExpireHour != 12 { t.Error("custom jwt.expire_hour not applied") }
if cfg.Database.MaxOpenConns != 50 { t.Error("custom database.max_open_conns not applied") }
if cfg.Database.MaxIdleConns != 25 { t.Error("custom database.max_idle_conns not applied") }
if cfg.Redis.PoolSize != 200 { t.Error("custom redis.pool_size not applied") }
if cfg.Gateway.ResponseHeaderTimeout != 30 { t.Error("custom gateway.response_header_timeout not applied") }
}
// --- Integration: Validation Error Propagation ---
func TestIntegration_Load_ValidationErrorPropagation(t *testing.T) {
yamlContent := "jwt:\n secret: short\n"
path := resetViperWithContent(t, yamlContent)
_, err := Load()
if err == nil {
t.Fatalf("expected validation error for short JWT secret")
}
errMsg := err.Error()
if !containsAny(errMsg, []string{"jwt.secret", "32 byte"}) {
t.Errorf("error should mention JWT secret length, got: %s", errMsg)
}
t.Logf("Config path: %s", path)
}
func containsAny(s string, subs []string) bool {
for _, sub := range subs {
if strings.Contains(s, sub) { return true }
}
return false
}
// --- Integration: LoadForBootstrap ---
func TestIntegration_LoadForBootstrap_AllowsEmptySecret(t *testing.T) {
viper.Reset()
tempDir := t.TempDir()
t.Setenv("DATA_DIR", tempDir)
t.Setenv("JWT_SECRET", "")
configFile := filepath.Join(tempDir, "config.yaml")
os.WriteFile(configFile, []byte(""), 0o644)
cfg, err := LoadForBootstrap()
if err != nil {
t.Fatalf("LoadForBootstrap() error: %v", err)
}
if cfg == nil { t.Fatal("returned nil config") }
t.Logf("Bootstrap OK: RunMode=%q, Server.Host=%q", cfg.RunMode, cfg.Server.Host)
}
// --- Integration: TOTP Auto-Generation ---
func TestIntegration_Load_TOTPAutoGeneration(t *testing.T) {
resetViperClean(t)
os.Setenv("JWT_SECRET", strings.Repeat("f", 32))
os.Unsetenv("TOTP_ENCRYPTION_KEY")
cfg, err := Load()
if err != nil { t.Fatalf("Load() error: %v", err) }
if cfg.Totp.EncryptionKey == "" {
t.Error("TOTP encryption key should be auto-generated")
}
if cfg.Totp.EncryptionKeyConfigured {
t.Error("EncryptionKeyConfigured should be false when auto-generated")
}
if len(cfg.Totp.EncryptionKey) != 64 {
t.Errorf("TOTP key should be 64 hex chars, got %d", len(cfg.Totp.EncryptionKey))
}
}
// --- Integration: TOTP Pre-configured ---
func TestIntegration_Load_TOTPPreconfigured(t *testing.T) {
yamlContent := `
jwt:
secret: ` + strings.Repeat("g", 32) + `
totp:
encryption_key: ` + strings.Repeat("a", 64) + `
`
resetViperWithContent(t, yamlContent)
cfg, err := Load()
if err != nil { t.Fatalf("Load() error: %v", err) }
if !cfg.Totp.EncryptionKeyConfigured {
t.Error("EncryptionKeyConfigured should be true when explicitly configured")
}
if cfg.Totp.EncryptionKey != strings.Repeat("a", 64) {
t.Error("TOTP key from config mismatch")
}
}
// --- Integration: Gateway Defaults Validation ---
func TestIntegration_GatewayValidation(t *testing.T) {
resetViperClean(t)
os.Setenv("JWT_SECRET", strings.Repeat("h", 32))
cfg, err := Load()
if err != nil { t.Fatalf("Load() error: %v", err) }
if cfg.Gateway.ConnectionPoolIsolation != ConnectionPoolIsolationAccountProxy {
t.Error("ConnectionPoolIsolation default mismatch")
}
if !cfg.Gateway.OpenAIWS.Enabled { t.Error("openai_ws.enabled should default true") }
if cfg.Gateway.UsageRecord.OverflowPolicy != UsageRecordOverflowPolicySample {
t.Error("OverflowPolicy default should be sample")
}
}
// =============================================================================
// Critical Integration Test: All Domain Files Contribute to Full Config
// This is THE regression test for the config split refactor from 2497-line config.go.
// It verifies that every one of the 15 domain files contributes its types and defaults.
// =============================================================================
func TestIntegration_AllDomainFiles_ContributeToFullConfig(t *testing.T) {
resetViperClean(t)
os.Setenv("JWT_SECRET", strings.Repeat("i", 32))
cfg, err := Load()
if err != nil { t.Fatalf("Load() error: %v", err) }
// Verify all domains loaded — no zero-value struct left behind by the split
domainChecks := []struct {
name string
check func(*Config) bool
}{
{"Server/Host", func(c *Config) bool { return c.Server.Host != "" }},
{"Server/Port", func(c *Config) bool { return c.Server.Port > 0 }},
{"Server/Mode", func(c *Config) bool { return c.Server.Mode != "" }},
{"Log/Level", func(c *Config) bool { return c.Log.Level != "" }},
{"Log/Format", func(c *Config) bool { return c.Log.Format != "" }},
{"CORS", func(c *Config) bool { return true }}, // empty slice is valid
{"Security", func(c *Config) bool { return true }}, // bool fields fine
{"Billing", func(c *Config) bool { return true }},
{"Turnstile", func(c *Config) bool { return true }},
{"Database/Host", func(c *Config) bool { return c.Database.Host != "" }},
{"Database/DSN", func(c *Config) bool { return c.Database.DSN() != "" }},
{"Database/DSNWithTZ", func(c *Config) bool { return c.Database.DSNWithTimezone("UTC") != "" }},
{"Redis/Host", func(c *Config) bool { return c.Redis.Host != "" }},
{"Redis/Address", func(c *Config) bool { return c.Redis.Address() != "" }},
{"Ops", func(c *Config) bool { return true }},
{"JWT/Secret", func(c *Config) bool { return len(c.JWT.Secret) >= 32 }},
{"Totp", func(c *Config) bool { return c.Totp.EncryptionKey != "" }}, // auto-generated
{"LinuxDo", func(c *Config) bool { return true }},
{"OIDC", func(c *Config) bool { return true }},
{"Default", func(c *Config) bool { return true }},
{"RateLimit", func(c *Config) bool { return true }},
{"Pricing/RemoteURL", func(c *Config) bool { return c.Pricing.RemoteURL != "" }},
{"Gateway", func(c *Config) bool { return true }},
{"APIKeyAuth/L1Size", func(c *Config) bool { return c.APIKeyAuth.L1Size > 0 }},
{"SubscriptionCache/L1Size", func(c *Config) bool { return c.SubscriptionCache.L1Size > 0 }},
{"SubscriptionMaintenance", func(c *Config) bool { return true }},
{"Dashboard/Enabled", func(c *Config) bool { return true }},
{"DashboardAgg/Interval", func(c *Config) bool { return c.DashboardAgg.IntervalSeconds > 0 }},
{"UsageCleanup/Enabled", func(c *Config) bool { return true }},
{"Concurrency/PingInterval", func(c *Config) bool { return c.Concurrency.PingInterval > 0 }},
{"TokenRefresh/Enabled", func(c *Config) bool { return true }},
{"Sora", func(c *Config) bool { return true }},
{"Gemini", func(c *Config) bool { return true }},
{"Update", func(c *Config) bool { return true }},
{"Idempotency/TTL", func(c *Config) bool { return c.Idempotency.DefaultTTLSeconds > 0 }},
{"RunMode", func(c *Config) bool { return c.RunMode != "" }},
{"Timezone", func(c *Config) bool { return c.Timezone != "" }},
}
for _, dc := range domainChecks {
dc := dc
t.Run(dc.name, func(t *testing.T) {
if !dc.check(cfg) {
t.Errorf("domain [%s] check failed — value appears to be zero/uninitialized", dc.name)
}
})
}
// Final Validate() call must pass for fully-loaded defaults
if err := cfg.Validate(); err != nil {
t.Errorf("fully-loaded default config failed validation: %v", err)
}
t.Logf("All %d domain checks passed ✅", len(domainChecks))
}
// --- Integration: RunMode Normalization ---
func TestIntegration_RunModeNormalizationInLoad(t *testing.T) {
tests := []struct {
envValue string
expected string
}{
{"STANDARD", "standard"},
{"SIMPLE", "simple"},
{"invalid-value", "standard"}, // unknown → standard
{"", "standard"},
}
for _, tc := range tests {
tc := tc
t.Run(fmt.Sprintf("runmode_%s", tc.envValue), func(t *testing.T) {
resetViperClean(t)
os.Setenv("JWT_SECRET", strings.Repeat("j", 32))
os.Setenv("RUN_MODE", tc.envValue)
cfg, err := Load()
if err != nil { t.Fatalf("Load() error: %v", err) }
if cfg.RunMode != tc.expected {
t.Errorf("RunMode=%q, want %q (from env RUN_MODE=%q)", cfg.RunMode, tc.expected, tc.envValue)
}
})
}
}
// --- Integration: String Field Normalization ---
func TestIntegration_StringFieldTrimming(t *testing.T) {
yamlContent := `
jwt:
secret: ` + strings.Repeat("k ", 32) + `
linuxdo_connect:
client_id: my-client-id
client_secret: my-secret
oidc_connect:
client_id: oidc-client-id
dashboard_cache:
key_prefix: test-prefix:
cors:
allowed_origins:
- https://example.com
- http://localhost:3000
`
resetViperWithContent(t, yamlContent)
cfg, err := Load()
if err != nil { t.Fatalf("Load() error: %v", err) }
// All string fields should have been trimmed
if cfg.JWT.Secret != strings.TrimSpace(strings.Repeat("k ", 32)) {
t.Error("JWT secret was not trimmed")
}
if cfg.LinuxDo.ClientID != "my-client-id" { t.Error("LinuxDo ClientID not trimmed") }
if cfg.LinuxDo.ClientSecret != "my-secret" { t.Error("LinuxDo ClientSecret not trimmed") }
if cfg.OIDC.ClientID != "oidc-client-id" { t.Error("OIDC ClientID not trimmed") }
if cfg.Dashboard.KeyPrefix != "test-prefix:" { t.Error("Dashboard KeyPrefix not trimmed") }
if len(cfg.CORS.AllowedOrigins) != 2 { t.Errorf("CORS origins count=2 expected, got %d", len(cfg.CORS.AllowedOrigins)) }
if cfg.CORS.AllowedOrigins[0] != "https://example.com" { t.Error("CORS origin not trimmed") }
}

View File

@@ -0,0 +1,297 @@
package config
import (
"fmt"
"log/slog"
"net/url"
"strings"
)
// Validate validates the configuration. Returns error on any invalid field.
func (c *Config) Validate() error {
if err := validateJWT(&c.JWT); err != nil { return err }
if err := validateLog(&c.Log); err != nil { return err }
if err := validateServerURL(c.Server.FrontendURL); err != nil { return err }
if err := validateLinuxDo(&c.LinuxDo); err != nil { return err }
if err := validateOIDC(&c.OIDC); err != nil { return err }
if err := validateBilling(&c.Billing); err != nil { return err }
if err := validateDatabase(&c.Database); err != nil { return err }
if err := validateRedis(&c.Redis); err != nil { return err }
if err := validateDashboard(&c.Dashboard, &c.DashboardAgg); err != nil { return err }
if err := validateUsageCleanup(&c.UsageCleanup); err != nil { return err }
if err := validateIdempotency(&c.Idempotency); err != nil { return err }
if err := validateGateway(&c.Gateway); err != nil { return err }
if err := validateOps(&c.Ops); err != nil { return err }
if err := validateConcurrency(&c.Concurrency); err != nil { return err }
return nil
}
func validateJWT(j *JWTConfig) error {
s := strings.TrimSpace(j.Secret)
if s == "" { return fmt.Errorf("jwt.secret is required") }
if len([]byte(s)) < 32 { return fmt.Errorf("jwt.secret must be at least 32 bytes") }
if j.ExpireHour <= 0 { return fmt.Errorf("jwt.expire_hour must be positive") }
if j.ExpireHour > 168 { return fmt.Errorf("jwt.expire_hour must be <= 168 (7 days)") }
if j.ExpireHour > 24 {
slog.Warn("jwt.expire_hour is high; consider shorter expiration for security", "expire_hour", j.ExpireHour)
}
if j.AccessTokenExpireMinutes < 0 { return fmt.Errorf("jwt.access_token_expire_minutes must be non-negative") }
if j.AccessTokenExpireMinutes > 720 {
slog.Warn("jwt.access_token_expire_minutes is high", "access_token_expire_minutes", j.AccessTokenExpireMinutes)
}
if j.RefreshTokenExpireDays <= 0 { return fmt.Errorf("jwt.refresh_token_expire_days must be positive") }
if j.RefreshTokenExpireDays > 90 {
slog.Warn("jwt.refresh_token_expire_days is high", "refresh_token_expire_days", j.RefreshTokenExpireDays)
}
if j.RefreshWindowMinutes < 0 { return fmt.Errorf("jwt.refresh_window_minutes must be non-negative") }
return nil
}
func validateLog(l *LogConfig) error {
switch l.Level {
case "debug", "info", "warn", "error":
case "":
return fmt.Errorf("log.level is required")
default:
return fmt.Errorf("log.level must be one of: debug/info/warn/error")
}
switch l.Format {
case "json", "console":
case "":
return fmt.Errorf("log.format is required")
default:
return fmt.Errorf("log.format must be one of: json/console")
}
switch l.StacktraceLevel {
case "none", "error", "fatal":
case "":
return fmt.Errorf("log.stacktrace_level is required")
default:
return fmt.Errorf("log.stacktrace_level must be one of: none/error/fatal")
}
if !l.Output.ToStdout && !l.Output.ToFile {
return fmt.Errorf("log.output.to_stdout and log.output.to_file cannot both be false")
}
if l.Rotation.MaxSizeMB <= 0 { return fmt.Errorf("log.rotation.max_size_mb must be positive") }
if l.Rotation.MaxBackups < 0 { return fmt.Errorf("log.rotation.max_backups must be non-negative") }
if l.Rotation.MaxAgeDays < 0 { return fmt.Errorf("log.rotation.max_age_days must be non-negative") }
if l.Sampling.Enabled {
if l.Sampling.Initial <= 0 { return fmt.Errorf("log.sampling.initial must be positive when sampling is enabled") }
if l.Sampling.Thereafter <= 0 { return fmt.Errorf("log.sampling.thereafter must be positive when sampling is enabled") }
} else {
if l.Sampling.Initial < 0 { return fmt.Errorf("log.sampling.initial must be non-negative") }
if l.Sampling.Thereafter < 0 { return fmt.Errorf("log.sampling.thereafter must be non-negative") }
}
return nil
}
func validateServerURL(frontendURL string) error {
if strings.TrimSpace(frontendURL) == "" { return nil }
if err := ValidateAbsoluteHTTPURL(frontendURL); err != nil {
return fmt.Errorf("server.frontend_url invalid: %w", err)
}
u, err := url.Parse(strings.TrimSpace(frontendURL))
if err != nil { return fmt.Errorf("server.frontend_url invalid: %w", err) }
if u.RawQuery != "" || u.ForceQuery { return fmt.Errorf("server.frontend_url invalid: must not include query") }
if u.User != nil { return fmt.Errorf("server.frontend_url invalid: must not include userinfo") }
warnIfInsecureURL("server.frontend_url", frontendURL)
return nil
}
func validateLinuxDo(lc *LinuxDoConnectConfig) error {
if !lc.Enabled { return nil }
requiredStringFields := map[string]string{
"client_id": lc.ClientID,
"authorize_url": lc.AuthorizeURL,
"token_url": lc.TokenURL,
"userinfo_url": lc.UserInfoURL,
"redirect_url": lc.RedirectURL,
"frontend_redirect_url": lc.FrontendRedirectURL,
}
for k, v := range requiredStringFields {
if strings.TrimSpace(v) == "" {
return fmt.Errorf("linuxdo_connect.%s is required when linuxdo_connect.enabled=true", k)
}
}
method := strings.ToLower(strings.TrimSpace(lc.TokenAuthMethod))
switch method {
case "", "client_secret_post", "client_secret_basic", "none":
default:
return fmt.Errorf("linuxdo_connect.token_auth_method must be one of: client_secret_post/client_secret_basic/none")
}
if method == "none" && !lc.UsePKCE {
return fmt.Errorf("linuxdo_connect.use_pkce must be true when token_auth_method=none")
}
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") &&
strings.TrimSpace(lc.ClientSecret) == "" {
return fmt.Errorf("linuxdo_connect.client_secret is required when enabled=true and token_auth_method is client_secret_post/client_secret_basic")
}
urlsToValidate := []struct{ key, val string }{
{"authorize_url", lc.AuthorizeURL}, {"token_url", lc.TokenURL},
{"userinfo_url", lc.UserInfoURL}, {"redirect_url", lc.RedirectURL},
}
for _, u := range urlsToValidate {
if err := ValidateAbsoluteHTTPURL(u.val); err != nil {
return fmt.Errorf("linuxdo_connect.%s invalid: %w", u.key, err)
}
warnIfInsecureURL("linuxdo_connect."+u.key, u.val)
}
if err := ValidateFrontendRedirectURL(lc.FrontendRedirectURL); err != nil {
return fmt.Errorf("linuxdo_connect.frontend_redirect_url invalid: %w", err)
}
warnIfInsecureURL("linuxdo_connect.frontend_redirect_url", lc.FrontendRedirectURL)
return nil
}
func validateOIDC(oc *OIDCConnectConfig) error {
if !oc.Enabled { return nil }
if strings.TrimSpace(oc.ClientID) == "" { return fmt.Errorf("oidc_connect.client_id is required when enabled=true") }
if strings.TrimSpace(oc.IssuerURL) == "" { return fmt.Errorf("oidc_connect.issuer_url is required when enabled=true") }
if strings.TrimSpace(oc.RedirectURL) == "" { return fmt.Errorf("oidc_connect.redirect_url is required when enabled=true") }
if strings.TrimSpace(oc.FrontendRedirectURL) == "" { return fmt.Errorf("oidc_connect.frontend_redirect_url is required when enabled=true") }
if !scopeContainsOpenID(oc.Scopes) { return fmt.Errorf("oidc_connect.scopes must contain openid") }
method := strings.ToLower(strings.TrimSpace(oc.TokenAuthMethod))
switch method { case "", "client_secret_post", "client_secret_basic", "none": default: return fmt.Errorf("oidc_connect.token_auth_method must be valid") }
if method == "none" && !oc.UsePKCE { return fmt.Errorf("oidc_connect.use_pkce must be true when token_auth_method=none") }
if (method == "" || method == "client_secret_post" || method == "client_secret_basic") && strings.TrimSpace(oc.ClientSecret) == "" {
return fmt.Errorf("oidc_connect.client_secret is required when enabled=true")
}
if oc.ClockSkewSeconds < 0 || oc.ClockSkewSeconds > 600 { return fmt.Errorf("oidc_connect.clock_skew_seconds must be between 0-600") }
if oc.ValidateIDToken && strings.TrimSpace(oc.AllowedSigningAlgs) == "" { return fmt.Errorf("oidc_connect.allowed_signing_algs required when validate_id_token=true") }
// Validate URLs (only if set — discovery can auto-populate these)
for _, u := range []struct{k, v string}{{"issuer_url", oc.IssuerURL}, {"redirect_url", oc.RedirectURL}, {"frontend_redirect_url", oc.FrontendRedirectURL}} {
if err := ValidateAbsoluteHTTPURL(u.v); err != nil { return fmt.Errorf("oidc_connect.%s invalid: %w", u.k, err) }
warnIfInsecureURL("oidc_connect."+u.k, u.v)
}
for _, k := range []string{"discovery_url", "authorize_url", "token_url", "userinfo_url", "jwks_url"} {
v := getOIDCStringField(oc, k)
if v != "" {
if err := ValidateAbsoluteHTTPURL(v); err != nil { return fmt.Errorf("oidc_connect.%s invalid: %w", k, err) }
warnIfInsecureURL("oidc_connect."+k, v)
}
}
if err := ValidateFrontendRedirectURL(oc.FrontendRedirectURL); err != nil { return fmt.Errorf("oidc_connect.frontend_redirect_url invalid: %w", err) }
return nil
}
func getOIDCStringField(oc *OIDCConnectConfig, field string) string {
switch field {
case "discovery_url": return oc.DiscoveryURL
case "authorize_url": return oc.AuthorizeURL
case "token_url": return oc.TokenURL
case "userinfo_url": return oc.UserInfoURL
case "jwks_url": return oc.JWKSURL
}
return ""
}
func validateBilling(b *BillingConfig) error {
if !b.CircuitBreaker.Enabled { return nil }
if b.CircuitBreaker.FailureThreshold <= 0 { return fmt.Errorf("billing.circuit_breaker.failure_threshold must be positive") }
if b.CircuitBreaker.ResetTimeoutSeconds <= 0 { return fmt.Errorf("billing.circuit_breaker.reset_timeout_seconds must be positive") }
if b.CircuitBreaker.HalfOpenRequests <= 0 { return fmt.Errorf("billing.circuit_breaker.half_open_requests must be positive") }
return nil
}
func validateDatabase(d *DatabaseConfig) error {
if d.MaxOpenConns <= 0 { return fmt.Errorf("database.max_open_conns must be positive") }
if d.MaxIdleConns < 0 { return fmt.Errorf("database.max_idle_conns must be non-negative") }
if d.MaxIdleConns > d.MaxOpenConns { return fmt.Errorf("database.max_idle_conns cannot exceed max_open_conns") }
if d.ConnMaxLifetimeMinutes < 0 { return fmt.Errorf("database.conn_max_lifetime_minutes must be non-negative") }
if d.ConnMaxIdleTimeMinutes < 0 { return fmt.Errorf("database.conn_max_idle_time_minutes must be non-negative") }
return nil
}
func validateRedis(r *RedisConfig) error {
if r.DialTimeoutSeconds <= 0 { return fmt.Errorf("redis.dial_timeout_seconds must be positive") }
if r.ReadTimeoutSeconds <= 0 { return fmt.Errorf("redis.read_timeout_seconds must be positive") }
if r.WriteTimeoutSeconds <= 0 { return fmt.Errorf("redis.write_timeout_seconds must be positive") }
if r.PoolSize <= 0 { return fmt.Errorf("redis.pool_size must be positive") }
if r.MinIdleConns < 0 { return fmt.Errorf("redis.min_idle_conns must be non-negative") }
if r.MinIdleConns > r.PoolSize { return fmt.Errorf("redis.min_idle_conns cannot exceed pool_size") }
return nil
}
func validateDashboard(d *DashboardCacheConfig, da *DashboardAggregationConfig) error {
if d.Enabled {
if d.StatsFreshTTLSeconds <= 0 { return fmt.Errorf("dashboard_cache.stats_fresh_ttl_seconds must be positive") }
if d.StatsTTLSeconds <= 0 { return fmt.Errorf("dashboard_cache.stats_ttl_seconds must be positive") }
if d.StatsRefreshTimeoutSeconds <= 0 { return fmt.Errorf("dashboard_cache.stats_refresh_timeout_seconds must be positive") }
if d.StatsFreshTTLSeconds > d.StatsTTLSeconds { return fmt.Errorf("stats_fresh_ttl_seconds must be <= stats_ttl_seconds") }
} else {
if d.StatsFreshTTLSeconds < 0 || d.StatsTTLSeconds < 0 || d.StatsRefreshTimeoutSeconds < 0 {
return fmt.Errorf("dashboard cache fields must be non-negative when disabled")
}
}
return validateDashboardAgg(da)
}
func validateDashboardAgg(da *DashboardAggregationConfig) error {
if !da.Enabled {
// Non-enabled: all fields just need to be non-negative where numeric
if da.IntervalSeconds < 0 || da.LookbackSeconds < 0 || da.BackfillMaxDays < 0 ||
da.Retention.UsageLogsDays < 0 || da.Retention.UsageBillingDedupDays < 0 ||
da.Retention.HourlyDays < 0 || da.Retention.DailyDays < 0 || da.RecomputeDays < 0 {
return fmt.Errorf("dashboard_aggregation numeric fields must be non-negative when disabled")
}
return nil
}
if da.IntervalSeconds <= 0 { return fmt.Errorf("dashboard_aggregation.interval_seconds must be positive") }
if da.LookbackSeconds < 0 { return fmt.Errorf("lookback_seconds must be non-negative") }
if da.BackfillMaxDays < 0 { return fmt.Errorf("backfill_max_days must be non-negative") }
if da.BackfillEnabled && da.BackfillMaxDays == 0 { return fmt.Errorf("backfill_max_days must be positive when backfill_enabled") }
if da.Retention.UsageLogsDays <= 0 { return fmt.Errorf("retention.usage_logs_days must be positive") }
if da.Retention.UsageBillingDedupDays <= 0 { return fmt.Errorf("retention.usage_billing_dedup_days must be positive") }
if da.Retention.UsageBillingDedupDays < da.Retention.UsageLogsDays {
return fmt.Errorf("usage_billing_dedup_days >= usage_logs_days")
}
if da.Retention.HourlyDays <= 0 { return fmt.Errorf("retention.hourly_days must be positive") }
if da.Retention.DailyDays <= 0 { return fmt.Errorf("retention.daily_days must be positive") }
if da.RecomputeDays < 0 { return fmt.Errorf("recompute_days must be non-negative") }
return nil
}
func validateUsageCleanup(uc *UsageCleanupConfig) error {
if !uc.Enabled {
if uc.MaxRangeDays < 0 || uc.BatchSize < 0 || uc.WorkerIntervalSeconds < 0 || uc.TaskTimeoutSeconds < 0 {
return fmt.Errorf("usage_cleanup fields must be non-negative when disabled")
}
return nil
}
if uc.MaxRangeDays <= 0 { return fmt.Errorf("usage_cleanup.max_range_days must be positive") }
if uc.BatchSize <= 0 { return fmt.Errorf("usage_cleanup.batch_size must be positive") }
if uc.WorkerIntervalSeconds <= 0 { return fmt.Errorf("usage_cleanup.worker_interval_seconds must be positive") }
if uc.TaskTimeoutSeconds <= 0 { return fmt.Errorf("usage_cleanup.task_timeout_seconds must be positive") }
return nil
}
func validateIdempotency(i *IdempotencyConfig) error {
if i.DefaultTTLSeconds <= 0 { return fmt.Errorf("default_ttl_seconds must be positive") }
if i.SystemOperationTTLSeconds <= 0 { return fmt.Errorf("system_operation_ttl_seconds must be positive") }
if i.ProcessingTimeoutSeconds <= 0 { return fmt.Errorf("processing_timeout_seconds must be positive") }
if i.FailedRetryBackoffSeconds <= 0 { return fmt.Errorf("failed_retry_backoff_seconds must be positive") }
if i.MaxStoredResponseLen <= 0 { return fmt.Errorf("max_stored_response_len must be positive") }
if i.CleanupIntervalSeconds <= 0 { return fmt.Errorf("cleanup_interval_seconds must be positive") }
if i.CleanupBatchSize <= 0 { return fmt.Errorf("cleanup_batch_size must be positive") }
return nil
}
func validateOps(o *OpsConfig) error {
if o.MetricsCollectorCache.TTL < 0 { return fmt.Errorf("ops.metrics_collector_cache.ttl must be non-negative") }
if o.Cleanup.ErrorLogRetentionDays < 0 || o.Cleanup.MinuteMetricsRetentionDays < 0 || o.Cleanup.HourlyMetricsRetentionDays < 0 {
return fmt.Errorf("ops cleanup retention days must be non-negative")
}
if o.Cleanup.Enabled && strings.TrimSpace(o.Cleanup.Schedule) == "" {
return fmt.Errorf("ops.cleanup.schedule is required when enabled=true")
}
return nil
}
func validateConcurrency(c *ConcurrencyConfig) error {
if c.PingInterval < 5 || c.PingInterval > 30 {
return fmt.Errorf("concurrency.ping_interval must be between 5-30 seconds")
}
return nil
}

View File

@@ -0,0 +1,159 @@
package config
import (
"fmt"
"log/slog"
"strings"
)
func validateGateway(g *GatewayConfig) error {
if g.MaxBodySize <= 0 { return fmt.Errorf("gateway.max_body_size must be positive") }
if g.UpstreamResponseReadMaxBytes <= 0 { return fmt.Errorf("upstream_response_read_max_bytes must be positive") }
if g.ProxyProbeResponseReadMaxBytes <= 0 { return fmt.Errorf("proxy_probe_response_read_max_bytes must be positive") }
if strings.TrimSpace(g.ConnectionPoolIsolation) != "" {
switch g.ConnectionPoolIsolation {
case ConnectionPoolIsolationProxy, ConnectionPoolIsolationAccount, ConnectionPoolIsolationAccountProxy:
default: return fmt.Errorf("invalid connection_pool_isolation")
}
}
if g.MaxIdleConns <= 0 || g.MaxIdleConnsPerHost <= 0 || g.MaxConnsPerHost < 0 {
return fmt.Errorf("gateway connection pool fields invalid")
}
if g.IdleConnTimeoutSeconds <= 0 { return fmt.Errorf("idle_conn_timeout_seconds must be positive") }
if g.IdleConnTimeoutSeconds > 180 { slog.Warn("idle_conn_timeout_seconds is high; consider 60-120") }
if g.MaxUpstreamClients <= 0 { return fmt.Errorf("max_upstream_clients must be positive") }
if g.ClientIdleTTLSeconds <= 0 { return fmt.Errorf("client_idle_ttl_seconds must be positive") }
if g.ConcurrencySlotTTLMinutes <= 0 { return fmt.Errorf("concurrency_slot_ttl_minutes must be positive") }
if err := validateGatewayStream(g); err != nil { return err }
if err := validateGatewayOpenAIWS(&g.OpenAIWS); err != nil { return err }
if err := validateGatewayUsageRecord(&g.UsageRecord); err != nil { return err }
if err := validateGatewayScheduling(&g.Scheduling); err != nil { return err }
if g.UserGroupRateCacheTTLSeconds <= 0 { return fmt.Errorf("user_group_rate_cache_ttl_seconds must be positive") }
if g.ModelsListCacheTTLSeconds < 10 || g.ModelsListCacheTTLSeconds > 30 {
return fmt.Errorf("models_list_cache_ttl_seconds must be between 10-30")
}
return nil
}
func validateGatewayStream(g *GatewayConfig) error {
if g.StreamDataIntervalTimeout < 0 { return fmt.Errorf("stream_data_interval_timeout must be non-negative") }
if g.StreamDataIntervalTimeout != 0 && (g.StreamDataIntervalTimeout < 30 || g.StreamDataIntervalTimeout > 300) {
return fmt.Errorf("stream_data_interval_timeout must be 0 or between 30-300 seconds")
}
if g.StreamKeepaliveInterval < 0 { return fmt.Errorf("stream_keepalive_interval must be non-negative") }
if g.StreamKeepaliveInterval != 0 && (g.StreamKeepaliveInterval < 5 || g.StreamKeepaliveInterval > 30) {
return fmt.Errorf("stream_keepalive_interval must be 0 or between 5-30 seconds")
}
if g.MaxLineSize < 0 { return fmt.Errorf("max_line_size must be non-negative") }
if g.MaxLineSize != 0 && g.MaxLineSize < 1024*1024 { return fmt.Errorf("max_line_size must be at least 1MB") }
return nil
}
func validateGatewayOpenAIWS(ws *GatewayOpenAIWSConfig) error {
// Basic numeric checks
checks := []struct{ name string; val int64 }{
{"max_conns_per_account", int64(ws.MaxConnsPerAccount)},
{"dial_timeout_seconds", int64(ws.DialTimeoutSeconds)},
{"read_timeout_seconds", int64(ws.ReadTimeoutSeconds)},
{"write_timeout_seconds", int64(ws.WriteTimeoutSeconds)},
{"pool_target_utilization", int64(ws.PoolTargetUtilization * 100)}, // scale for comparison
{"queue_limit_per_conn", int64(ws.QueueLimitPerConn)},
{"event_flush_batch_size", int64(ws.EventFlushBatchSize)},
{"lb_top_k", int64(ws.LBTopK)},
{"sticky_session_ttl_seconds", int64(ws.StickySessionTTLSeconds)},
{"sticky_response_id_ttl_seconds", int64(ws.StickyResponseIDTTLSeconds)},
{"max_conns_per_account", int64(ws.MaxConnsPerAccount)},
}
for _, c := range checks {
if c.val <= 0 { return fmt.Errorf("openai_ws.%s must be positive", c.name) }
}
if ws.MinIdlePerAccount < 0 || ws.MaxIdlePerAccount < 0 {
return fmt.Errorf("openai_ws idle per-account fields must be non-negative")
}
if ws.MinIdlePerAccount > ws.MaxIdlePerAccount { return fmt.Errorf("min_idle_per_account must be <= max_idle_per_account") }
if ws.MaxIdlePerAccount > ws.MaxConnsPerAccount { return fmt.Errorf("max_idle_per_account must be <= max_conns_per_account") }
if ws.OAuthMaxConnsFactor <= 0 || ws.APIKeyMaxConnsFactor <= 0 {
return fmt.Errorf("openai_ws conns factor must be positive")
}
if ws.PoolTargetUtilization <= 0 || ws.PoolTargetUtilization > 1 {
return fmt.Errorf("pool_target_utilization must be within (0,1]")
}
if ws.EventFlushIntervalMS < 0 || ws.PrewarmCooldownMS < 0 ||
ws.FallbackCooldownSeconds < 0 || ws.RetryBackoffInitialMS < 0 ||
ws.RetryBackoffMaxMS < 0 || ws.RetryTotalBudgetMS < 0 {
return fmt.Errorf("openai_ws timeout/retry fields must be non-negative")
}
if ws.RetryBackoffInitialMS > 0 && ws.RetryBackoffMaxMS > 0 && ws.RetryBackoffMaxMS < ws.RetryBackoffInitialMS {
return fmt.Errorf("retry_backoff_max_ms >= retry_backoff_initial_ms")
}
if ws.RetryJitterRatio < 0 || ws.RetryJitterRatio > 1 { return fmt.Errorf("retry_jitter_ratio within [0,1]") }
if ws.PayloadLogSampleRate < 0 || ws.PayloadLogSampleRate > 1 { return fmt.Errorf("payload_log_sample_rate within [0,1]") }
if ws.StickyPreviousResponseTTLSeconds < 0 { return fmt.Errorf("sticky_previous_response_ttl_seconds must be non-negative") }
if ws.SchedulerScoreWeights.Priority < 0 || ws.SchedulerScoreWeights.Load < 0 ||
ws.SchedulerScoreWeights.Queue < 0 || ws.SchedulerScoreWeights.ErrorRate < 0 || ws.SchedulerScoreWeights.TTFT < 0 {
return fmt.Errorf("scheduler_score_weights must be non-negative")
}
weightSum := ws.SchedulerScoreWeights.Priority + ws.SchedulerScoreWeights.Load + ws.SchedulerScoreWeights.Queue +
ws.SchedulerScoreWeights.ErrorRate + ws.SchedulerScoreWeights.TTFT
if weightSum <= 0 { return fmt.Errorf("scheduler_score_weights must not all be zero") }
// Ingress mode
if mode := strings.ToLower(strings.TrimSpace(ws.IngressModeDefault)); mode != "" {
switch mode { case "off", "ctx_pool", "passthrough": default: return fmt.Errorf("ingress_mode_default must be off|ctx_pool|passthrough") }
}
if mode := strings.ToLower(strings.TrimSpace(ws.StoreDisabledConnMode)); mode != "" {
switch mode { case "strict", "adaptive", "off": default: return fmt.Errorf("store_disabled_conn_mode must be strict|adaptive|off") }
}
return nil
}
func validateGatewayUsageRecord(ur *GatewayUsageRecordConfig) error {
if ur.WorkerCount <= 0 || ur.QueueSize <= 0 || ur.TaskTimeoutSeconds <= 0 {
return fmt.Errorf("usage_record worker/queue/timeout must be positive")
}
switch ur.OverflowPolicy {
case UsageRecordOverflowPolicyDrop, UsageRecordOverflowPolicySample, UsageRecordOverflowPolicySync:
default: return fmt.Errorf("invalid overflow_policy")
}
if ur.OverflowSamplePercent < 0 || ur.OverflowSamplePercent > 100 {
return fmt.Errorf("overflow_sample_percent must be 0-100")
}
if strings.EqualFold(ur.OverflowPolicy, UsageRecordOverflowPolicySample) && ur.OverflowSamplePercent <= 0 {
return fmt.Errorf("overflow_sample_percent must be positive when policy=sample")
}
if !ur.AutoScaleEnabled { return nil }
if ur.AutoScaleMinWorkers <= 0 || ur.AutoScaleMaxWorkers <= 0 { return fmt.Errorf("auto_scale workers must be positive") }
if ur.AutoScaleMaxWorkers < ur.AutoScaleMinWorkers { return fmt.Errorf("auto_scale_max >= auto_scale_min") }
if ur.WorkerCount < ur.AutoScaleMinWorkers || ur.WorkerCount > ur.AutoScaleMaxWorkers {
return fmt.Errorf("worker_count between auto_scale_min and max")
}
if ur.AutoScaleUpQueuePercent <= 0 || ur.AutoScaleUpQueuePercent > 100 { return fmt.Errorf("auto_scale_up_queue_percent 1-100") }
if ur.AutoScaleDownQueuePercent < 0 || ur.AutoScaleDownQueuePercent >= 100 { return fmt.Errorf("auto_scale_down_queue_percent 0-99") }
if ur.AutoScaleDownQueuePercent >= ur.AutoScaleUpQueuePercent { return fmt.Errorf("down_queue_percent < up_queue_percent") }
if ur.AutoScaleUpStep <= 0 || ur.AutoScaleDownStep <= 0 { return fmt.Errorf("auto_scale steps must be positive") }
if ur.AutoScaleCheckIntervalSeconds <= 0 { return fmt.Errorf("auto_scale_check_interval_seconds must be positive") }
if ur.AutoScaleCooldownSeconds < 0 { return fmt.Errorf("auto_scale_cooldown_seconds must be non-negative") }
return nil
}
func validateGatewayScheduling(s *GatewaySchedulingConfig) error {
if s.StickySessionMaxWaiting <= 0 || s.StickySessionWaitTimeout <= 0 ||
s.FallbackWaitTimeout <= 0 || s.FallbackMaxWaiting <= 0 ||
s.SnapshotMGetChunkSize <= 0 || s.SnapshotWriteChunkSize <= 0 {
return fmt.Errorf("scheduling core fields must be positive")
}
if s.SlotCleanupInterval < 0 || s.DbFallbackTimeoutSeconds < 0 || s.DbFallbackMaxQPS < 0 {
return fmt.Errorf("scheduling optional fields must be non-negative")
}
if s.OutboxPollIntervalSeconds <= 0 || s.OutboxLagRebuildFailures <= 0 || s.OutboxBacklogRebuildRows < 0 {
return fmt.Errorf("outbox fields must be non-negative or positive as documented")
}
if s.OutboxLagWarnSeconds < 0 || s.OutboxLagRebuildSeconds < 0 || s.FullRebuildIntervalSeconds < 0 {
return fmt.Errorf("outbox timing fields must be non-negative")
}
if s.OutboxLagWarnSeconds > 0 && s.OutboxLagRebuildSeconds > 0 && s.OutboxLagRebuildSeconds < s.OutboxLagWarnSeconds {
return fmt.Errorf("outbox_lag_rebuild_seconds >= outbox_lag_warn_seconds")
}
return nil
}

View File

@@ -0,0 +1,618 @@
package config
import (
"fmt"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
// =============================================================================
// Test: config_validate.go — All Validation Functions
// 覆盖: Validate, validateJWT, validateLog, validateServerURL,
// validateLinuxDo, validateOIDC, validateBilling, validateDatabase,
// validateRedis, validateDashboard, validateDashboardAgg,
// validateUsageCleanup, validateIdempotency, validateOps,
// validateConcurrency
// =============================================================================
// --- validateJWT ---
func TestValidateJWT(t *testing.T) {
tests := []struct {
name string
cfg JWTConfig
wantErr bool
errContains string
}{
{
name: "valid config",
cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, RefreshTokenExpireDays: 30},
wantErr: false,
},
{
name: "empty secret",
cfg: JWTConfig{Secret: "", ExpireHour: 24},
wantErr: true,
errContains: "jwt.secret is required",
},
{
name: "secret too short (<32 bytes)",
cfg: JWTConfig{Secret: "short", ExpireHour: 24},
wantErr: true,
errContains: "jwt.secret must be at least 32 bytes",
},
{
name: "secret exactly 32 bytes (valid)",
cfg: JWTConfig{Secret: strings.Repeat("a", 32), ExpireHour: 24},
wantErr: false,
},
{
name: "expire_hour zero or negative",
cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 0},
wantErr: true,
errContains: "jwt.expire_hour must be positive",
},
{
name: "expire_hour exceeds max (168)",
cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 169},
wantErr: true,
errContains: "jwt.expire_hour must be <= 168",
},
{
name: "expire_hour exactly 168 (7 days)",
cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 168},
wantErr: false,
},
{
name: "access_token_expire_minutes negative",
cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, AccessTokenExpireMinutes: -1},
wantErr: true,
errContains: "jwt.access_token_expire_minutes must be non-negative",
},
{
name: "access_token_expire_minutes too high (>720)",
cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, AccessTokenExpireMinutes: 721},
wantErr: false, // only warns, not errors
},
{
name: "refresh_token_expire_days zero",
cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, RefreshTokenExpireDays: 0},
wantErr: true,
errContains: "jwt.refresh_token_expire_days must be positive",
},
{
name: "refresh_token_expire_days >90 warns but passes",
cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, RefreshTokenExpireDays: 91},
wantErr: false, // only warns
},
{
name: "refresh_window_minutes negative",
cfg: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, RefreshWindowMinutes: -1},
wantErr: true,
errContains: "jwt.refresh_window_minutes must be non-negative",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := validateJWT(&tc.cfg)
if tc.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.errContains)
} else {
assert.NoError(t, err)
}
})
}
}
// --- validateLog ---
func TestValidateLog(t *testing.T) {
validLog := LogConfig{
Level: "info", Format: "json", StacktraceLevel: "error",
Output: LogOutputConfig{ToStdout: true, ToFile: false},
Rotation: LogRotationConfig{MaxSizeMB: 100},
}
tests := []struct {
name string
cfg LogConfig
wantErr bool
errContains string
}{
{"valid", validLog, false, ""},
{"empty level", func() LogConfig { c := validLog; c.Level = ""; return c }(), true, "log.level is required"},
{"invalid level", func() LogConfig { c := validLog; c.Level = "verbose"; return c }(), true, "log.level must be one of"},
{"valid levels", func() LogConfig { c := validLog; c.Level = "debug"; return c }(), false, ""}, // debug is valid
{"valid level warn", func() LogConfig { c := validLog; c.Level = "warn"; return c }(), false, ""},
{"empty format", func() LogConfig { c := validLog; c.Format = ""; return c }(), true, "log.format is required"},
{"invalid format", func() LogConfig { c := validLog; c.Format = "xml"; return c }(), true, "log.format must be one of"},
{"both output false", func() LogConfig { c := validLog; c.Output.ToStdout = false; c.Output.ToFile = false; return c }(), true, "cannot both be false"},
{"max_size_mb zero", func() LogConfig { c := validLog; c.Rotation.MaxSizeMB = 0; return c }(), true, "must be positive"},
{"max_backups negative", func() LogConfig { c := validLog; c.Rotation.MaxBackups = -1; return c }(), true, "non-negative"},
{"max_age_days negative", func() LogConfig { c := validLog; c.Rotation.MaxAgeDays = -1; return c }(), true, "non-negative"},
{"sampling enabled with zero initial", func() LogConfig { c := validLog; c.Sampling.Enabled = true; c.Sampling.Initial = 0; return c }(), true, "must be positive when sampling"},
{"sampling disabled negative thereafter", func() LogConfig { c := validLog; c.Sampling.Thereafter = -1; return c }(), true, "non-negative"},
{"stacktrace empty", func() LogConfig { c := validLog; c.StacktraceLevel = ""; return c }(), true, "stacktrace_level is required"},
{"invalid stacktrace", func() LogConfig { c := validLog; c.StacktraceLevel = "warn"; return c }(), true, "stacktrace_level must be one of"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := validateLog(&tc.cfg)
if tc.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.errContains)
} else {
assert.NoError(t, err)
}
})
}
}
// --- validateDatabase ---
func TestValidateDatabase(t *testing.T) {
tests := []struct {
name string
cfg DatabaseConfig
wantErr bool
errContains string
}{
{"valid", DatabaseConfig{MaxOpenConns: 10, MaxIdleConns: 5, ConnMaxLifetimeMinutes: 30, ConnMaxIdleTimeMinutes: 5}, false, ""},
{"max_open_conns zero", DatabaseConfig{MaxOpenConns: 0}, true, "must be positive"},
{"max_idle_conns negative", DatabaseConfig{MaxOpenConns: 10, MaxIdleConns: -1}, true, "non-negative"},
{"idle > open", DatabaseConfig{MaxOpenConns: 5, MaxIdleConns: 10}, true, "cannot exceed max_open_conns"},
{"conn_max_lifetime negative", DatabaseConfig{MaxOpenConns: 10, ConnMaxLifetimeMinutes: -1}, true, "non-negative"},
{"conn_max_idle_time negative", DatabaseConfig{MaxOpenConns: 10, ConnMaxIdleTimeMinutes: -1}, true, "non-negative"},
{"all zero is valid for idle conn/time", DatabaseConfig{MaxOpenConns: 10, MaxIdleConns: 0, ConnMaxLifetimeMinutes: 0, ConnMaxIdleTimeMinutes: 0}, false, ""},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := validateDatabase(&tc.cfg)
if tc.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.errContains)
} else {
assert.NoError(t, err)
}
})
}
}
// --- validateRedis ---
func TestValidateRedis(t *testing.T) {
tests := []struct {
name string
cfg RedisConfig
wantErr bool
errContains string
}{
{"valid", RedisConfig{DialTimeoutSeconds: 5, ReadTimeoutSeconds: 3, WriteTimeoutSeconds: 3, PoolSize: 100, MinIdleConns: 10}, false, ""},
{"dial_timeout zero", RedisConfig{}, true, "dial_timeout_seconds must be positive"},
{"read_timeout zero", RedisConfig{DialTimeoutSeconds: 5}, true, "read_timeout_seconds must be positive"},
{"write_timeout zero", RedisConfig{DialTimeoutSeconds: 5, ReadTimeoutSeconds: 3}, true, "write_timeout_seconds must be positive"},
{"pool_size zero", RedisConfig{DialTimeoutSeconds: 5, ReadTimeoutSeconds: 3, WriteTimeoutSeconds: 3}, true, "pool_size must be positive"},
{"min_idle negative", RedisConfig{DialTimeoutSeconds: 5, ReadTimeoutSeconds: 3, WriteTimeoutSeconds: 3, PoolSize: 100, MinIdleConns: -1}, true, "non-negative"},
{"min_idle > pool_size", RedisConfig{PoolSize: 10, MinIdleConns: 20, DialTimeoutSeconds: 5, ReadTimeoutSeconds: 3, WriteTimeoutSeconds: 3}, true, "cannot exceed pool_size"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
err := validateRedis(&tc.cfg)
if tc.wantErr {
assert.Error(t, err)
assert.Contains(t, err.Error(), tc.errContains)
} else {
assert.NoError(t, err)
}
})
}
}
// --- validateBilling ---
func TestValidateBilling(t *testing.T) {
t.Run("disabled passes", func(t *testing.T) {
assert.NoError(t, validateBilling(&BillingConfig{}))
})
t.Run("enabled with valid values", func(t *testing.T) {
bc := BillingConfig{CircuitBreaker: CircuitBreakerConfig{
Enabled: true, FailureThreshold: 5, ResetTimeoutSeconds: 30, HalfOpenRequests: 3,
}}
assert.NoError(t, validateBilling(&bc))
})
t.Run("enabled failure_threshold zero", func(t *testing.T) {
bc := BillingConfig{CircuitBreaker: CircuitBreakerConfig{Enabled: true, FailureThreshold: 0}}
err := validateBilling(&bc)
assert.Error(t, err)
assert.Contains(t, err.Error(), "failure_threshold must be positive")
})
t.Run("enabled reset_timeout zero", func(t *testing.T) {
bc := BillingConfig{CircuitBreaker: CircuitBreakerConfig{Enabled: true, FailureThreshold: 5, ResetTimeoutSeconds: 0}}
err := validateBilling(&bc)
assert.Error(t, err)
assert.Contains(t, err.Error(), "reset_timeout_seconds must be positive")
})
t.Run("enabled half_open_requests zero", func(t *testing.T) {
bc := BillingConfig{CircuitBreaker: CircuitBreakerConfig{
Enabled: true, FailureThreshold: 5, ResetTimeoutSeconds: 30, HalfOpenRequests: 0,
}}
err := validateBilling(&bc)
assert.Error(t, err)
assert.Contains(t, err.Error(), "half_open_requests must be positive")
})
}
// --- validateIdempotency ---
func TestValidateIdempotency(t *testing.T) {
valid := IdempotencyConfig{
DefaultTTLSeconds: 86400, SystemOperationTTLSeconds: 3600,
ProcessingTimeoutSeconds: 30, FailedRetryBackoffSeconds: 5,
MaxStoredResponseLen: 65536, CleanupIntervalSeconds: 60, CleanupBatchSize: 500,
}
assert.NoError(t, validateIdempotency(&valid))
fieldsToZero := []string{
"DefaultTTLSeconds", "SystemOperationTTLSeconds", "ProcessingTimeoutSeconds",
"FailedRetryBackoffSeconds", "MaxStoredResponseLen", "CleanupIntervalSeconds", "CleanupBatchSize",
}
for _, f := range fieldsToZero {
f := f
t.Run(f+"_zero", func(t *testing.T) {
c := valid
switch f {
case "DefaultTTLSeconds": c.DefaultTTLSeconds = 0
case "SystemOperationTTLSeconds": c.SystemOperationTTLSeconds = 0
case "ProcessingTimeoutSeconds": c.ProcessingTimeoutSeconds = 0
case "FailedRetryBackoffSeconds": c.FailedRetryBackoffSeconds = 0
case "MaxStoredResponseLen": c.MaxStoredResponseLen = 0
case "CleanupIntervalSeconds": c.CleanupIntervalSeconds = 0
case "CleanupBatchSize": c.CleanupBatchSize = 0
}
err := validateIdempotency(&c)
assert.Error(t, err, "%s=0 should error", f)
})
}
}
// --- validateUsageCleanup ---
func TestValidateUsageCleanup(t *testing.T) {
valid := UsageCleanupConfig{Enabled: true, MaxRangeDays: 31, BatchSize: 5000, WorkerIntervalSeconds: 10, TaskTimeoutSeconds: 1800}
assert.NoError(t, validateUsageCleanup(&valid))
t.Run("disabled with non-negative values passes", func(t *testing.T) {
uc := UsageCleanupConfig{Enabled: false, MaxRangeDays: 0, BatchSize: 0, WorkerIntervalSeconds: 0, TaskTimeoutSeconds: 0}
assert.NoError(t, validateUsageCleanup(&uc))
})
t.Run("disabled with negative value fails", func(t *testing.T) {
uc := UsageCleanupConfig{Enabled: false, MaxRangeDays: -1}
err := validateUsageCleanup(&uc)
assert.Error(t, err)
assert.Contains(t, err.Error(), "non-negative")
})
t.Run("enabled max_range_days zero", func(t *testing.T) {
uc := UsageCleanupConfig{Enabled: true, MaxRangeDays: 0, BatchSize: 1, WorkerIntervalSeconds: 1, TaskTimeoutSeconds: 1}
err := validateUsageCleanup(&uc)
assert.Error(t, err)
assert.Contains(t, err.Error(), "must be positive")
})
}
// --- validateOps ---
func TestValidateOps(t *testing.T) {
valid := OpsConfig{Cleanup: OpsCleanupConfig{Enabled: true, Schedule: "0 2 * * *"}}
assert.NoError(t, validateOps(&valid))
t.Run("negative metrics cache TTL", func(t *testing.T) {
o := OpsConfig{MetricsCollectorCache: OpsMetricsCollectorCacheConfig{TTL: -1}}
err := validateOps(&o)
assert.Error(t, err)
assert.Contains(t, err.Error(), "non-negative")
})
t.Run("negative retention days", func(t *testing.T) {
o := OpsConfig{Cleanup: OpsCleanupConfig{ErrorLogRetentionDays: -1}}
err := validateOps(&o)
assert.Error(t, err)
assert.Contains(t, err.Error(), "non-negative")
})
t.Run("enabled cleanup without schedule", func(t *testing.T) {
o := OpsConfig{Cleanup: OpsCleanupConfig{Enabled: true, Schedule: ""}}
err := validateOps(&o)
assert.Error(t, err)
assert.Contains(t, err.Error(), "schedule is required")
})
}
// --- validateConcurrency ---
func TestValidateConcurrency(t *testing.T) {
tests := []struct {
pingInterval int
wantErr bool
}{
{5, false}, // min boundary
{10, false}, // normal
{30, false}, // max boundary
{4, true}, // below min
{31, true}, // above max
{0, true},
{-1, true},
}
for _, tc := range tests {
tc := tc
t.Run(fmt.Sprintf("ping_interval=%d", tc.pingInterval), func(t *testing.T) {
c := ConcurrencyConfig{PingInterval: tc.pingInterval}
err := validateConcurrency(&c)
if tc.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
// --- validateServerURL ---
func TestValidateServerURL(t *testing.T) {
t.Run("empty URL passes", func(t *testing.T) {
assert.NoError(t, validateServerURL(""))
})
t.Run("valid https URL", func(t *testing.T) {
assert.NoError(t, validateServerURL("https://app.example.com"))
})
t.Run("valid http URL (with warning)", func(t *testing.T) {
assert.NoError(t, validateServerURL("http://localhost:3000"))
})
t.Run("URL with query fails", func(t *testing.T) {
err := validateServerURL("https://example.com?foo=bar")
assert.Error(t, err)
assert.Contains(t, err.Error(), "must not include query")
})
t.Run("URL with userinfo fails", func(t *testing.T) {
err := validateServerURL("https://user:pass@example.com")
assert.Error(t, err)
assert.Contains(t, err.Error(), "must not include userinfo")
})
t.Run("invalid scheme fails", func(t *testing.T) {
err := validateServerURL("ftp://example.com")
assert.Error(t, err)
})
}
// --- validateLinuxDo ---
func TestValidateLinuxDo(t *testing.T) {
validCfg := LinuxDoConnectConfig{
Enabled: true, ClientID: "id", AuthorizeURL: "https://a.com/auth",
TokenURL: "https://a.com/token", UserInfoURL: "https://a.com/user",
RedirectURL: "https://a.com/cb", FrontendRedirectURL: "/cb", ClientSecret: "secret",
}
t.Run("disabled passes", func(t *testing.T) {
assert.NoError(t, validateLinuxDo(&LinuxDoConnectConfig{}))
})
t.Run("enabled valid passes", func(t *testing.T) {
assert.NoError(t, validateLinuxDo(&validCfg))
})
t.Run("enabled missing client_id", func(t *testing.T) {
c := validCfg; c.ClientID = ""
err := validateLinuxDo(&c)
assert.Error(t, err)
assert.Contains(t, err.Error(), "client_id is required")
})
t.Run("enabled invalid authorize_url", func(t *testing.T) {
c := validCfg; c.AuthorizeURL = "not-a-url"
err := validateLinuxDo(&c)
assert.Error(t, err)
})
t.Run("token_auth_method none requires PKCE", func(t *testing.T) {
c := validCfg; c.TokenAuthMethod = "none"; c.UsePKCE = false; c.ClientSecret = ""
err := validateLinuxDo(&c)
assert.Error(t, err)
assert.Contains(t, err.Error(), "use_pkce must be true")
})
t.Run("client_secret_post requires client_secret", func(t *testing.T) {
c := validCfg; c.TokenAuthMethod = "client_secret_post"; c.ClientSecret = ""
err := validateLinuxDo(&c)
assert.Error(t, err)
assert.Contains(t, err.Error(), "client_secret is required")
})
t.Run("invalid token_auth_method", func(t *testing.T) {
c := validCfg; c.TokenAuthMethod = "bearer"
err := validateLinuxDo(&c)
assert.Error(t, err)
assert.Contains(t, err.Error(), "token_auth_method must be")
})
}
// --- validateOIDC ---
func TestValidateOIDC(t *testing.T) {
validOIDC := OIDCConnectConfig{
Enabled: true, ClientID: "id", IssuerURL: "https://idp.example.com",
RedirectURL: "https://app.com/cb", FrontendRedirectURL: "/oidc/cb",
ClientSecret: "secret", Scopes: "openid email profile",
}
t.Run("disabled passes", func(t *testing.T) {
assert.NoError(t, validateOIDC(&OIDCConnectConfig{}))
})
t.Run("enabled valid passes", func(t *testing.T) {
assert.NoError(t, validateOIDC(&validOIDC))
})
t.Run("missing openid scope", func(t *testing.T) {
c := validOIDC; c.Scopes = "email profile"
err := validateOIDC(&c)
assert.Error(t, err)
assert.Contains(t, err.Error(), "must contain openid")
})
t.Run("clock_skew out of range", func(t *testing.T) {
c := validOIDC; c.ClockSkewSeconds = 700
err := validateOIDC(&c)
assert.Error(t, err)
assert.Contains(t, err.Error(), "between 0-600")
})
t.Run("validate_id_token requires allowed_signing_algs", func(t *testing.T) {
c := validOIDC; c.ValidateIDToken = true; c.AllowedSigningAlgs = ""
err := validateOIDC(&c)
assert.Error(t, err)
assert.Contains(t, err.Error(), "allowed_signing_algs required")
})
t.Run("missing issuer_url", func(t *testing.T) {
c := validOIDC; c.IssuerURL = ""
err := validateOIDC(&c)
assert.Error(t, err)
assert.Contains(t, err.Error(), "issuer_url is required")
})
}
// --- validateDashboard & validateDashboardAgg ---
func TestValidateDashboard(t *testing.T) {
validDash := DashboardCacheConfig{
Enabled: true, StatsFreshTTLSeconds: 15, StatsTTLSeconds: 30, StatsRefreshTimeoutSeconds: 30,
}
validAgg := DashboardAggregationConfig{Enabled: true, IntervalSeconds: 60, LookbackSeconds: 120,
Retention: DashboardAggregationRetentionConfig{UsageLogsDays: 90, UsageBillingDedupDays: 365, HourlyDays: 180, DailyDays: 730},
}
t.Run("enabled dashboard valid", func(t *testing.T) {
assert.NoError(t, validateDashboard(&validDash, &validAgg))
})
t.Run("stats_fresh_ttl > stats_ttl fails", func(t *testing.T) {
d := validDash; d.StatsFreshTTLSeconds = 100; d.StatsTTLSeconds = 50
err := validateDashboard(&d, &validAgg)
assert.Error(t, err)
assert.Contains(t, err.Error(), "stats_fresh_ttl_seconds must be <=")
})
t.Run("disabled dashboard with negatives fails", func(t *testing.T) {
d := DashboardCacheConfig{Enabled: false, StatsFreshTTLSeconds: -1}
err := validateDashboard(&d, &DashboardAggregationConfig{})
assert.Error(t, err)
assert.Contains(t, err.Error(), "non-negative")
})
t.Run("aggregation enabled valid", func(t *testing.T) {
assert.NoError(t, validateDashboardAgg(&validAgg))
})
t.Run("aggregation interval zero when enabled", func(t *testing.T) {
a := validAgg; a.Enabled = true; a.IntervalSeconds = 0
err := validateDashboardAgg(&a)
assert.Error(t, err)
assert.Contains(t, err.Error(), "interval_seconds must be positive")
})
t.Run("billing_dedup < usage_logs fails", func(t *testing.T) {
a := validAgg; a.Retention.UsageLogsDays = 365; a.Retention.UsageBillingDedupDays = 30
err := validateDashboardAgg(&a)
assert.Error(t, err)
assert.Contains(t, err.Error(), "usage_billing_dedup_days >= usage_logs_days")
})
t.Run("backfill_enabled with backfill_max_days=0 fails", func(t *testing.T) {
a := validAgg; a.BackfillEnabled = true; a.BackfillMaxDays = 0
err := validateDashboardAgg(&a)
assert.Error(t, err)
assert.Contains(t, err.Error(), "backfill_max_days must be positive")
})
}
// --- Config.Validate() orchestration test ---
func TestConfigValidate_Orchestration(t *testing.T) {
t.Parallel()
t.Run("fully valid config passes all validators", func(t *testing.T) {
t.Parallel()
cfg := buildValidConfig()
assert.NoError(t, cfg.Validate())
})
t.Run("invalid JWT stops validation early", func(t *testing.T) {
t.Parallel()
cfg := buildValidConfig()
cfg.JWT.Secret = "short"
err := cfg.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "jwt.secret")
})
t.Run("invalid database stops after earlier checks pass", func(t *testing.T) {
t.Parallel()
cfg := buildValidConfig()
cfg.Database.MaxOpenConns = 0
err := cfg.Validate()
assert.Error(t, err)
assert.Contains(t, err.Error(), "database.max_open_conns")
})
}
// Helper to build a fully valid Config for testing
func buildValidConfig() Config {
return Config{
Server: ServerConfig{Host: "0.0.0.0", Port: 8080, Mode: "release"},
Log: LogConfig{Level: "info", Format: "json", StacktraceLevel: "error",
Output: LogOutputConfig{ToStdout: true},
Rotation: LogRotationConfig{MaxSizeMB: 100}},
Security: SecurityConfig{},
Billing: BillingConfig{},
Turnstile: TurnstileConfig{},
Database: DatabaseConfig{Host: "localhost", Port: 5432, User: "u",
DBName: "db", SSLMode: "disable", MaxOpenConns: 256, MaxIdleConns: 128,
ConnMaxLifetimeMinutes: 30, ConnMaxIdleTimeMinutes: 5},
Redis: RedisConfig{Host: "localhost", Port: 6379,
DialTimeoutSeconds: 5, ReadTimeoutSeconds: 3, WriteTimeoutSeconds: 3, PoolSize: 1024},
JWT: JWTConfig{Secret: strings.Repeat("x", 32), ExpireHour: 24, RefreshTokenExpireDays: 30},
LinuxDo: LinuxDoConnectConfig{},
OIDC: OIDCConnectConfig{},
Default: DefaultConfig{},
RateLimit: RateLimitConfig{},
Pricing: PricingConfig{},
Gateway: GatewayConfig{},
APIKeyAuth: APIKeyAuthCacheConfig{},
SubscriptionCache: SubscriptionCacheConfig{},
Dashboard: DashboardCacheConfig{Enabled: true, StatsFreshTTLSeconds: 15, StatsTTLSeconds: 30, StatsRefreshTimeoutSeconds: 30},
DashboardAgg: DashboardAggregationConfig{Enabled: true, IntervalSeconds: 60, LookbackSeconds: 120,
Retention: DashboardAggregationRetentionConfig{UsageLogsDays: 90, UsageBillingDedupDays: 365, HourlyDays: 180, DailyDays: 730}},
UsageCleanup: UsageCleanupConfig{Enabled: true, MaxRangeDays: 31, BatchSize: 5000, WorkerIntervalSeconds: 10, TaskTimeoutSeconds: 1800},
Concurrency: ConcurrencyConfig{PingInterval: 10},
Idempotency: IdempotencyConfig{
DefaultTTLSeconds: 86400, SystemOperationTTLSeconds: 3600, ProcessingTimeoutSeconds: 30,
FailedRetryBackoffSeconds: 5, MaxStoredResponseLen: 65536, CleanupIntervalSeconds: 60, CleanupBatchSize: 500,
},
}
}

View File

@@ -0,0 +1,67 @@
package config
import "fmt"
// DatabaseConfig 数据库连接配置
// 性能优化:新增连接池参数,避免频繁创建/销毁连接
type DatabaseConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
User string `mapstructure:"user"`
Password string `mapstructure:"password"`
DBName string `mapstructure:"dbname"`
SSLMode string `mapstructure:"sslmode"`
// 连接池配置(性能优化:可配置化连接池参数)
MaxOpenConns int `mapstructure:"max_open_conns"` // 最大打开连接数,控制数据库连接上限
MaxIdleConns int `mapstructure:"max_idle_conns"` // 最大空闲连接数,保持热连接减少建连延迟
ConnMaxLifetimeMinutes int `mapstructure:"conn_max_lifetime_minutes"` // 连接最大存活时间
ConnMaxIdleTimeMinutes int `mapstructure:"conn_max_idle_time_minutes"` // 空闲连接最大存活时间
}
func (d *DatabaseConfig) DSN() string {
if d.Password == "" {
return fmt.Sprintf(
"host=%s port=%d user=%s dbname=%s sslmode=%s",
d.Host, d.Port, d.User, d.DBName, d.SSLMode,
)
}
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode,
)
}
// DSNWithTimezone returns DSN with timezone setting
func (d *DatabaseConfig) DSNWithTimezone(tz string) string {
if tz == "" {
tz = "Asia/Shanghai"
}
if d.Password == "" {
return fmt.Sprintf(
"host=%s port=%d user=%s dbname=%s sslmode=%s TimeZone=%s",
d.Host, d.Port, d.User, d.DBName, d.SSLMode, tz,
)
}
return fmt.Sprintf(
"host=%s port=%d user=%s password=%s dbname=%s sslmode=%s TimeZone=%s",
d.Host, d.Port, d.User, d.Password, d.DBName, d.SSLMode, tz,
)
}
// RedisConfig Redis 连接配置
type RedisConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Password string `mapstructure:"password"`
DB int `mapstructure:"db"`
DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"` // 建立连接超时
ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"` // 读取超时
WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"` // 写入超时
PoolSize int `mapstructure:"pool_size"` // 连接池大小
MinIdleConns int `mapstructure:"min_idle_conns"` // 最小空闲连接数
EnableTLS bool `mapstructure:"enable_tls"` // 是否启用 TLS/SSL 连接
}
func (r *RedisConfig) Address() string {
return fmt.Sprintf("%s:%d", r.Host, r.Port)
}

View File

@@ -0,0 +1,105 @@
package config
import "time"
// GatewayConfig API网关相关配置
type GatewayConfig struct {
ResponseHeaderTimeout int `mapstructure:"response_header_timeout"` // 等待上游响应头超时0=无超时
MaxBodySize int64 `mapstructure:"max_body_size"` // 请求体最大字节数
UpstreamResponseReadMaxBytes int64 `mapstructure:"upstream_response_read_max_bytes"` // 非流式上游响应体读取上限
ProxyProbeResponseReadMaxBytes int64 `mapstructure:"proxy_probe_response_read_max_bytes"` // 代理探测响应读取上限
GeminiDebugResponseHeaders bool `mapstructure:"gemini_debug_response_headers"`
ConnectionPoolIsolation string `mapstructure:"connection_pool_isolation"` // proxy/account/account_proxy
ForceCodexCLI bool `mapstructure:"force_codex_cli"`
ForcedCodexInstructionsTemplateFile string `mapstructure:"forced_codex_instructions_template_file"`
ForcedCodexInstructionsTemplate string `mapstructure:"-"`
OpenAIPassthroughAllowTimeoutHeaders bool `mapstructure:"openai_passthrough_allow_timeout_headers"`
OpenAIWS GatewayOpenAIWSConfig `mapstructure:"openai_ws"`
// HTTP 上游连接池配置
MaxIdleConns int `mapstructure:"max_idle_conns"`
MaxIdleConnsPerHost int `mapstructure:"max_idle_conns_per_host"`
MaxConnsPerHost int `mapstructure:"max_conns_per_host"`
IdleConnTimeoutSeconds int `mapstructure:"idle_conn_timeout_seconds"`
MaxUpstreamClients int `mapstructure:"max_upstream_clients"`
ClientIdleTTLSeconds int `mapstructure:"client_idle_ttl_seconds"`
ConcurrencySlotTTLMinutes int `mapstructure:"concurrency_slot_ttl_minutes"`
SessionIdleTimeoutMinutes int `mapstructure:"session_idle_timeout_seconds"`
StreamDataIntervalTimeout int `mapstructure:"stream_data_interval_timeout"` // 流数据间隔超时0=禁用
StreamKeepaliveInterval int `mapstructure:"stream_keepalive_interval"` // 流式 keepalive 间隔0=禁用
MaxLineSize int `mapstructure:"max_line_size"` // SSE 单行最大字节数
LogUpstreamErrorBody bool `mapstructure:"log_upstream_error_body"`
LogUpstreamErrorBodyMaxBytes int `mapstructure:"log_upstream_error_body_max_bytes"`
InjectBetaForAPIKey bool `mapstructure:"inject_beta_for_apikey"`
FailoverOn400 bool `mapstructure:"failover_on_400"`
// Sora 专用配置
SoraMaxBodySize int64 `mapstructure:"sora_max_body_size"`
SoraStreamTimeoutSeconds int `mapstructure:"sora_stream_timeout_seconds"`
SoraRequestTimeoutSeconds int `mapstructure:"sora_request_timeout_seconds"`
SoraStreamMode string `mapstructure:"sora_stream_mode"`
SoraModelFilters SoraModelFiltersConfig `mapstructure:"sora_model_filters"`
SoraMediaRequireAPIKey bool `mapstructure:"sora_media_require_api_key"`
SoraMediaSigningKey string `mapstructure:"sora_media_signing_key"`
SoraMediaSignedURLTTLSeconds int `mapstructure:"sora_media_signed_url_ttl_seconds"`
MaxAccountSwitches int `mapstructure:"max_account_switches"`
MaxAccountSwitchesGemini int `mapstructure:"max_account_switches_gemini"`
AntigravityFallbackCooldownMinutes int `mapstructure:"antigravity_fallback_cooldown_minutes"`
Scheduling GatewaySchedulingConfig `mapstructure:"scheduling"`
TLSFingerprint TLSFingerprintConfig `mapstructure:"tls_fingerprint"`
UsageRecord GatewayUsageRecordConfig `mapstructure:"usage_record"`
UserGroupRateCacheTTLSeconds int `mapstructure:"user_group_rate_cache_ttl_seconds"`
ModelsListCacheTTLSeconds int `mapstructure:"models_list_cache_ttl_seconds"`
UserMessageQueue UserMessageQueueConfig `mapstructure:"user_message_queue"`
}
// UserMessageQueueConfig 用户消息串行队列配置
type UserMessageQueueConfig struct {
Mode string `mapstructure:"mode"` // serialize/throttle/""
Enabled bool `mapstructure:"enabled"` // 向后兼容
LockTTLMs int `mapstructure:"lock_ttl_ms"`
WaitTimeoutMs int `mapstructure:"wait_timeout_ms"`
MinDelayMs int `mapstructure:"min_delay_ms"`
MaxDelayMs int `mapstructure:"max_delay_ms"`
CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"`
}
func (c *UserMessageQueueConfig) WaitTimeout() time.Duration {
if c.WaitTimeoutMs <= 0 { return 30 * time.Second }
return time.Duration(c.WaitTimeoutMs) * time.Millisecond
}
func (c *UserMessageQueueConfig) GetEffectiveMode() string {
if c.Mode == UMQModeSerialize || c.Mode == UMQModeThrottle { return c.Mode }
if c.Enabled { return UMQModeSerialize }
return ""
}
// GatewaySchedulingConfig 账号调度相关配置
type GatewaySchedulingConfig struct {
StickySessionMaxWaiting int `mapstructure:"sticky_session_max_waiting"`
StickySessionWaitTimeout time.Duration `mapstructure:"sticky_session_wait_timeout"`
FallbackWaitTimeout time.Duration `mapstructure:"fallback_wait_timeout"`
FallbackMaxWaiting int `mapstructure:"fallback_max_waiting"`
FallbackSelectionMode string `mapstructure:"fallback_selection_mode"`
LoadBatchEnabled bool `mapstructure:"load_batch_enabled"`
SnapshotMGetChunkSize int `mapstructure:"snapshot_mget_chunk_size"`
SnapshotWriteChunkSize int `mapstructure:"snapshot_write_chunk_size"`
SlotCleanupInterval time.Duration `mapstructure:"slot_cleanup_interval"`
DbFallbackEnabled bool `mapstructure:"db_fallback_enabled"`
DbFallbackTimeoutSeconds int `mapstructure:"db_fallback_timeout_seconds"`
DbFallbackMaxQPS int `mapstructure:"db_fallback_max_qps"`
OutboxPollIntervalSeconds int `mapstructure:"outbox_poll_interval_seconds"`
OutboxLagWarnSeconds int `mapstructure:"outbox_lag_warn_seconds"`
OutboxLagRebuildSeconds int `mapstructure:"outbox_lag_rebuild_seconds"`
OutboxLagRebuildFailures int `mapstructure:"outbox_lag_rebuild_failures"`
OutboxBacklogRebuildRows int `mapstructure:"outbox_backlog_rebuild_rows"`
FullRebuildIntervalSeconds int `mapstructure:"full_rebuild_interval_seconds"`
}

View File

@@ -0,0 +1,98 @@
package config
// GatewayOpenAIWSConfig OpenAI Responses WebSocket 配置
type GatewayOpenAIWSConfig struct {
ModeRouterV2Enabled bool `mapstructure:"mode_router_v2_enabled"`
IngressModeDefault string `mapstructure:"ingress_mode_default"`
Enabled bool `mapstructure:"enabled"`
OAuthEnabled bool `mapstructure:"oauth_enabled"`
APIKeyEnabled bool `mapstructure:"apikey_enabled"`
ForceHTTP bool `mapstructure:"force_http"`
AllowStoreRecovery bool `mapstructure:"allow_store_recovery"`
IngressPreviousResponseRecoveryEnabled bool `mapstructure:"ingress_previous_response_recovery_enabled"`
StoreDisabledConnMode string `mapstructure:"store_disabled_conn_mode"`
StoreDisabledForceNewConn bool `mapstructure:"store_disabled_force_new_conn"`
PrewarmGenerateEnabled bool `mapstructure:"prewarm_generate_enabled"`
ResponsesWebsockets bool `mapstructure:"responses_websockets"`
ResponsesWebsocketsV2 bool `mapstructure:"responses_websockets_v2"`
MaxConnsPerAccount int `mapstructure:"max_conns_per_account"`
MinIdlePerAccount int `mapstructure:"min_idle_per_account"`
MaxIdlePerAccount int `mapstructure:"max_idle_per_account"`
DynamicMaxConnsByAccountConcurrencyEnabled bool `mapstructure:"dynamic_max_conns_by_account_concurrency_enabled"`
OAuthMaxConnsFactor float64 `mapstructure:"oauth_max_conns_factor"`
APIKeyMaxConnsFactor float64 `mapstructure:"apikey_max_conns_factor"`
DialTimeoutSeconds int `mapstructure:"dial_timeout_seconds"`
ReadTimeoutSeconds int `mapstructure:"read_timeout_seconds"`
WriteTimeoutSeconds int `mapstructure:"write_timeout_seconds"`
PoolTargetUtilization float64 `mapstructure:"pool_target_utilization"`
QueueLimitPerConn int `mapstructure:"queue_limit_per_conn"`
EventFlushBatchSize int `mapstructure:"event_flush_batch_size"`
EventFlushIntervalMS int `mapstructure:"event_flush_interval_ms"`
PrewarmCooldownMS int `mapstructure:"prewarm_cooldown_ms"`
FallbackCooldownSeconds int `mapstructure:"fallback_cooldown_seconds"`
RetryBackoffInitialMS int `mapstructure:"retry_backoff_initial_ms"`
RetryBackoffMaxMS int `mapstructure:"retry_backoff_max_ms"`
RetryJitterRatio float64 `mapstructure:"retry_jitter_ratio"`
RetryTotalBudgetMS int `mapstructure:"retry_total_budget_ms"`
PayloadLogSampleRate float64 `mapstructure:"payload_log_sample_rate"`
LBTopK int `mapstructure:"lb_top_k"`
StickySessionTTLSeconds int `mapstructure:"sticky_session_ttl_seconds"`
SessionHashReadOldFallback bool `mapstructure:"session_hash_read_old_fallback"`
SessionHashDualWriteOld bool `mapstructure:"session_hash_dual_write_old"`
MetadataBridgeEnabled bool `mapstructure:"metadata_bridge_enabled"`
StickyResponseIDTTLSeconds int `mapstructure:"sticky_response_id_ttl_seconds"`
StickyPreviousResponseTTLSeconds int `mapstructure:"sticky_previous_response_ttl_seconds"`
SchedulerScoreWeights GatewayOpenAIWSSchedulerScoreWeights `mapstructure:"scheduler_score_weights"`
}
// GatewayOpenAIWSSchedulerScoreWeights 账号调度打分权重
type GatewayOpenAIWSSchedulerScoreWeights struct {
Priority float64 `mapstructure:"priority"`
Load float64 `mapstructure:"load"`
Queue float64 `mapstructure:"queue"`
ErrorRate float64 `mapstructure:"error_rate"`
TTFT float64 `mapstructure:"ttft"`
}
// GatewayUsageRecordConfig 使用量记录异步队列配置
type GatewayUsageRecordConfig struct {
WorkerCount int `mapstructure:"worker_count"`
QueueSize int `mapstructure:"queue_size"`
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
OverflowPolicy string `mapstructure:"overflow_policy"` // drop/sample/sync
OverflowSamplePercent int `mapstructure:"overflow_sample_percent"` // 1-100
AutoScaleEnabled bool `mapstructure:"auto_scale_enabled"`
AutoScaleMinWorkers int `mapstructure:"auto_scale_min_workers"`
AutoScaleMaxWorkers int `mapstructure:"auto_scale_max_workers"`
AutoScaleUpQueuePercent int `mapstructure:"auto_scale_up_queue_percent"`
AutoScaleDownQueuePercent int `mapstructure:"auto_scale_down_queue_percent"`
AutoScaleUpStep int `mapstructure:"auto_scale_up_step"`
AutoScaleDownStep int `mapstructure:"auto_scale_down_step"`
AutoScaleCheckIntervalSeconds int `mapstructure:"auto_scale_check_interval_seconds"`
AutoScaleCooldownSeconds int `mapstructure:"auto_scale_cooldown_seconds"`
}
// TLSFingerprintConfig TLS 指纹伪装配置
type TLSFingerprintConfig struct {
Enabled bool `mapstructure:"enabled"`
Profiles map[string]TLSProfileConfig `mapstructure:"profiles"`
}
// TLSProfileConfig 单个 TLS 指纹模板配置
type TLSProfileConfig struct {
Name string `mapstructure:"name"`
EnableGREASE bool `mapstructure:"enable_grease"`
CipherSuites []uint16 `mapstructure:"cipher_suites"`
Curves []uint16 `mapstructure:"curves"`
PointFormats []uint16 `mapstructure:"point_formats"`
SignatureAlgorithms []uint16 `mapstructure:"signature_algorithms"`
ALPNProtocols []string `mapstructure:"alpn_protocols"`
SupportedVersions []uint16 `mapstructure:"supported_versions"`
KeyShareGroups []uint16 `mapstructure:"key_share_groups"`
PSKModes []uint16 `mapstructure:"psk_modes"`
Extensions []uint16 `mapstructure:"extensions"`
}

View File

@@ -0,0 +1,122 @@
package config
import "time"
// LogConfig 日志配置
type LogConfig struct {
Level string `mapstructure:"level"`
Format string `mapstructure:"format"`
ServiceName string `mapstructure:"service_name"`
Environment string `mapstructure:"env"`
Caller bool `mapstructure:"caller"`
StacktraceLevel string `mapstructure:"stacktrace_level"`
Output LogOutputConfig `mapstructure:"output"`
Rotation LogRotationConfig `mapstructure:"rotation"`
Sampling LogSamplingConfig `mapstructure:"sampling"`
}
type LogOutputConfig struct {
ToStdout bool `mapstructure:"to_stdout"`
ToFile bool `mapstructure:"to_file"`
FilePath string `mapstructure:"file_path"`
}
type LogRotationConfig struct {
MaxSizeMB int `mapstructure:"max_size_mb"`
MaxBackups int `mapstructure:"max_backups"`
MaxAgeDays int `mapstructure:"max_age_days"`
Compress bool `mapstructure:"compress"`
LocalTime bool `mapstructure:"local_time"`
}
type LogSamplingConfig struct {
Enabled bool `mapstructure:"enabled"`
Initial int `mapstructure:"initial"`
Thereafter int `mapstructure:"thereafter"`
}
// OpsConfig 运维监控配置
type OpsConfig struct {
Enabled bool `mapstructure:"enabled"`
UsePreaggregatedTables bool `mapstructure:"use_preaggregated_tables"`
Cleanup OpsCleanupConfig `mapstructure:"cleanup"`
MetricsCollectorCache OpsMetricsCollectorCacheConfig `mapstructure:"metrics_collector_cache"`
Aggregation OpsAggregationConfig `mapstructure:"aggregation"`
}
type OpsCleanupConfig struct {
Enabled bool `mapstructure:"enabled"`
Schedule string `mapstructure:"schedule"`
ErrorLogRetentionDays int `mapstructure:"error_log_retention_days"`
MinuteMetricsRetentionDays int `mapstructure:"minute_metrics_retention_days"`
HourlyMetricsRetentionDays int `mapstructure:"hourly_metrics_retention_days"`
}
type OpsAggregationConfig struct {
Enabled bool `mapstructure:"enabled"`
}
type OpsMetricsCollectorCacheConfig struct {
Enabled bool `mapstructure:"enabled"`
TTL time.Duration `mapstructure:"ttl"`
}
// APIKeyAuthCacheConfig API Key 认证缓存配置
type APIKeyAuthCacheConfig struct {
L1Size int `mapstructure:"l1_size"`
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
L2TTLSeconds int `mapstructure:"l2_ttl_seconds"`
NegativeTTLSeconds int `mapstructure:"negative_ttl_seconds"`
JitterPercent int `mapstructure:"jitter_percent"`
Singleflight bool `mapstructure:"singleflight"`
}
// SubscriptionCacheConfig 订阅认证 L1 缓存配置
type SubscriptionCacheConfig struct {
L1Size int `mapstructure:"l1_size"`
L1TTLSeconds int `mapstructure:"l1_ttl_seconds"`
JitterPercent int `mapstructure:"jitter_percent"`
}
// SubscriptionMaintenanceConfig 订阅窗口维护后台任务配置
type SubscriptionMaintenanceConfig struct {
WorkerCount int `mapstructure:"worker_count"`
QueueSize int `mapstructure:"queue_size"`
}
// DashboardCacheConfig 仪表盘统计缓存配置
type DashboardCacheConfig struct {
Enabled bool `mapstructure:"enabled"`
KeyPrefix string `mapstructure:"key_prefix"`
StatsFreshTTLSeconds int `mapstructure:"stats_fresh_ttl_seconds"`
StatsTTLSeconds int `mapstructure:"stats_ttl_seconds"`
StatsRefreshTimeoutSeconds int `mapstructure:"stats_refresh_timeout_seconds"`
}
// DashboardAggregationConfig 仪表盘预聚合配置
type DashboardAggregationConfig struct {
Enabled bool `mapstructure:"enabled"`
IntervalSeconds int `mapstructure:"interval_seconds"`
LookbackSeconds int `mapstructure:"lookback_seconds"`
BackfillEnabled bool `mapstructure:"backfill_enabled"`
BackfillMaxDays int `mapstructure:"backfill_max_days"`
Retention DashboardAggregationRetentionConfig `mapstructure:"retention"`
RecomputeDays int `mapstructure:"recompute_days"`
}
// DashboardAggregationRetentionConfig 预聚合保留窗口
type DashboardAggregationRetentionConfig struct {
UsageLogsDays int `mapstructure:"usage_logs_days"`
UsageBillingDedupDays int `mapstructure:"usage_billing_dedup_days"`
HourlyDays int `mapstructure:"hourly_days"`
DailyDays int `mapstructure:"daily_days"`
}
// UsageCleanupConfig 使用记录清理任务配置
type UsageCleanupConfig struct {
Enabled bool `mapstructure:"enabled"`
MaxRangeDays int `mapstructure:"max_range_days"`
BatchSize int `mapstructure:"batch_size"`
WorkerIntervalSeconds int `mapstructure:"worker_interval_seconds"`
TaskTimeoutSeconds int `mapstructure:"task_timeout_seconds"`
}

View File

@@ -0,0 +1,152 @@
package config
// SoraConfig 直连 Sora 配置
type SoraConfig struct {
Client SoraClientConfig `mapstructure:"client"`
Storage SoraStorageConfig `mapstructure:"storage"`
}
// SoraClientConfig Sora 客户端配置
type SoraClientConfig struct {
BaseURL string `mapstructure:"base_url"`
TimeoutSeconds int `mapstructure:"timeout_seconds"`
MaxRetries int `mapstructure:"max_retries"`
CloudflareChallengeCooldownSeconds int `mapstructure:"cloudflare_challenge_cooldown_seconds"`
PollIntervalSeconds int `mapstructure:"poll_interval_seconds"`
MaxPollAttempts int `mapstructure:"max_poll_attempts"`
RecentTaskLimit int `mapstructure:"recent_task_limit"`
RecentTaskLimitMax int `mapstructure="recent_task_limit_max"`
Debug bool `mapstructure:"debug"`
UseOpenAITokenProvider bool `mapstructure:"use_openai_token_provider"`
Headers map[string]string `mapstructure:"headers"`
UserAgent string `mapstructure:"user_agent"`
DisableTLSFingerprint bool `mapstructure:"disable_tls_fingerprint"`
CurlCFFISidecar SoraCurlCFFISidecarConfig `mapstructure:"curl_cffi_sidecar"`
}
// SoraCurlCFFISidecarConfig Sora curl_cffi sidecar 配置
type SoraCurlCFFISidecarConfig struct {
Enabled bool `mapstructure:"enabled"`
BaseURL string `mapstructure:"base_url"`
Impersonate string `mapstructure:"impersonate"`
TimeoutSeconds int `mapstructure:"timeout_seconds"`
SessionReuseEnabled bool `mapstructure:"session_reuse_enabled"`
SessionTTLSeconds int `mapstructure:"session_ttl_seconds"`
}
// SoraStorageConfig 媒体存储配置
type SoraStorageConfig struct {
Type string `mapstructure:"type"`
LocalPath string `mapstructure:"local_path"`
FallbackToUpstream bool `mapstructure:"fallback_to_upstream"`
MaxConcurrentDownloads int `mapstructure:"max_concurrent_downloads"`
DownloadTimeoutSeconds int `mapstructure:"download_timeout_seconds"`
MaxDownloadBytes int64 `mapstructure:"max_download_bytes"`
Debug bool `mapstructure:"debug"`
Cleanup SoraStorageCleanupConfig `mapstructure:"cleanup"`
}
// SoraStorageCleanupConfig 媒体清理配置
type SoraStorageCleanupConfig struct {
Enabled bool `mapstructure:"enabled"`
Schedule string `mapstructure:"schedule"`
RetentionDays int `mapstructure:"retention_days"`
}
// SoraModelFiltersConfig Sora 模型过滤配置
type SoraModelFiltersConfig struct {
HidePromptEnhance bool `mapstructure:"hide_prompt_envelope"`
}
// GeminiConfig Gemini 配置
type GeminiConfig struct {
OAuth GeminiOAuthConfig `mapstructure:"oauth"`
Quota GeminiQuotaConfig `mapstructure:"quota"`
}
type GeminiOAuthConfig struct {
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
Scopes string `mapstructure:"scopes"`
}
type GeminiQuotaConfig struct {
Tiers map[string]GeminiTierQuotaConfig `mapstructure:"tiers"`
Policy string `mapstructure:"policy"`
}
type GeminiTierQuotaConfig struct {
ProRPD *int64 `mapstructure:"pro_rpd" json:"pro_rpd"`
FlashRPD *int64 `mapstructure:"flash_rpd" json:"flash_rpd"`
CooldownMinutes *int `mapstructure:"cooldown_minutes" json:"cooldown_minutes"`
}
// UpdateConfig 更新检查配置
type UpdateConfig struct {
ProxyURL string `mapstructure:"proxy_url"` // 访问 GitHub 的代理地址
}
// IdempotencyConfig 幂等性配置
type IdempotencyConfig struct {
ObserveOnly bool `mapstructure:"observe_only"`
DefaultTTLSeconds int `mapstructure:"default_ttl_seconds"`
SystemOperationTTLSeconds int `mapstructure:"system_operation_ttl_seconds"`
ProcessingTimeoutSeconds int `mapstructure:"processing_timeout_seconds"`
FailedRetryBackoffSeconds int `mapstructure:"failed_retry_backoff_seconds"`
MaxStoredResponseLen int `mapstructure:"max_stored_response_len"`
CleanupIntervalSeconds int `mapstructure:"cleanup_interval_seconds"`
CleanupBatchSize int `mapstructure:"cleanup_batch_size"`
}
// LinuxDoConnectConfig LinuxDo 连接配置
type LinuxDoConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"`
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
TokenAuthMethod string `mapstructure:"token_auth_method"`
UsePKCE bool `mapstructure:"use_pkce"`
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
// OIDCConnectConfig OIDC 连接配置
type OIDCConnectConfig struct {
Enabled bool `mapstructure:"enabled"`
ProviderName string `mapstructure:"provider_name"`
ClientID string `mapstructure:"client_id"`
ClientSecret string `mapstructure:"client_secret"`
IssuerURL string `mapstructure:"issuer_url"`
DiscoveryURL string `mapstructure:"discovery_url"`
AuthorizeURL string `mapstructure:"authorize_url"`
TokenURL string `mapstructure:"token_url"`
UserInfoURL string `mapstructure:"userinfo_url"`
JWKSURL string `mapstructure:"jwks_url"`
Scopes string `mapstructure:"scopes"`
RedirectURL string `mapstructure:"redirect_url"`
FrontendRedirectURL string `mapstructure:"frontend_redirect_url"`
TokenAuthMethod string `mapstructure:"token_auth_method"`
UsePKCE bool `mapstructure:"use_pkce"`
ValidateIDToken bool `mapstructure:"validate_id_token"`
AllowedSigningAlgs string `mapstructure:"allowed_signing_algs"`
ClockSkewSeconds int `mapstructure:"clock_skew_seconds"`
RequireEmailVerified bool `mapstructure:"require_email_verified"`
UserInfoEmailPath string `mapstructure:"userinfo_email_path"`
UserInfoIDPath string `mapstructure:"userinfo_id_path"`
UserInfoUsernamePath string `mapstructure:"userinfo_username_path"`
}
// TokenRefreshConfig OAuth Token 自动刷新配置
type TokenRefreshConfig struct {
Enabled bool `mapstructure:"enabled"`
CheckIntervalMinutes int `mapstructure:"check_interval_minutes"`
RefreshBeforeExpiryHours float64 `mapstructure:"refresh_before_expiry_hours"`
MaxRetries int `mapstructure:"max_retries"`
RetryBackoffSeconds int `mapstructure:"retry_backoff_seconds"`
}

View File

@@ -0,0 +1,48 @@
package config
// SecurityConfig 安全相关配置
type SecurityConfig struct {
URLAllowlist URLAllowlistConfig `mapstructure:"url_allowlist"`
ResponseHeaders ResponseHeaderConfig `mapstructure:"response_headers"`
CSP CSPConfig `mapstructure:"csp"`
ProxyFallback ProxyFallbackConfig `mapstructure:"proxy_fallback"`
ProxyProbe ProxyProbeConfig `mapstructure:"proxy_probe"`
}
// URLAllowlistConfig URL 白名单配置
type URLAllowlistConfig struct {
Enabled bool `mapstructure:"enabled"`
UpstreamHosts []string `mapstructure:"upstream_hosts"`
PricingHosts []string `mapstructure:"pricing_hosts"`
CRSHosts []string `mapstructure:"crs_hosts"`
AllowPrivateHosts bool `mapstructure:"allow_private_hosts"`
// 关闭 URL 白名单校验时,是否允许 http URL默认只允许 https
AllowInsecureHTTP bool `mapstructure:"allow_insecure_http"`
}
// ResponseHeaderConfig 安全响应头配置
type ResponseHeaderConfig struct {
Enabled bool `mapstructure:"enabled"`
AdditionalAllowed []string `mapstructure:"additional_allowed"`
ForceRemove []string `mapstructure:"force_remove"`
}
// CSPConfig Content-Security-Policy 配置
type CSPConfig struct {
Enabled bool `mapstructure:"enabled"`
Policy string `mapstructure:"policy"`
}
// ProxyFallbackConfig 代理回退配置
type ProxyFallbackConfig struct {
// AllowDirectOnError 当辅助服务的代理初始化失败时是否允许回退直连。
// 仅影响非 AI 账号连接的辅助服务GitHub Release 更新检查、定价数据拉取)。
// 不影响 AI 账号网关连接,这些关键路径的代理失败始终返回错误。
// 默认 false避免因代理配置错误导致服务器真实 IP 泄露。
AllowDirectOnError bool `mapstructure:"allow_direct_on_error"`
}
// ProxyProbeConfig 代理探测配置
type ProxyProbeConfig struct {
InsecureSkipVerify bool `mapstructure:"insecure_skip_verify"` // 已禁用:禁止跳过 TLS 证书验证
}

View File

@@ -0,0 +1,42 @@
package config
import "fmt"
// ServerConfig HTTP 服务端配置
type ServerConfig struct {
Host string `mapstructure:"host"`
Port int `mapstructure:"port"`
Mode string `mapstructure:"mode"` // debug/release
FrontendURL string `mapstructure:"frontend_url"` // 前端基础 URL用于生成邮件中的外部链接
ReadHeaderTimeout int `mapstructure:"read_header_timeout"` // 读取请求头超时(秒)
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲连接超时(秒)
TrustedProxies []string `mapstructure:"trusted_proxies"` // 可信代理列表CIDR/IP
MaxRequestBodySize int64 `mapstructure:"max_request_body_size"` // 全局最大请求体限制
H2C H2CConfig `mapstructure:"h2c"` // HTTP/2 Cleartext 配置
}
// H2CConfig HTTP/2 Cleartext 配置
type H2CConfig struct {
Enabled bool `mapstructure:"enabled"` // 是否启用 H2C
MaxConcurrentStreams uint32 `mapstructure:"max_concurrent_streams"` // 最大并发流数量
IdleTimeout int `mapstructure:"idle_timeout"` // 空闲超时(秒)
MaxReadFrameSize int `mapstructure:"max_read_frame_size"` // 最大帧大小(字节)
MaxUploadBufferPerConnection int `mapstructure:"max_upload_buffer_per_connection"` // 每个连接的上传缓冲区(字节)
MaxUploadBufferPerStream int `mapstructure:"max_upload_buffer_per_stream"` // 每个流的上传缓冲区(字节)
}
// CORSConfig 跨域配置
type CORSConfig struct {
AllowedOrigins []string `mapstructure:"allowed_origins"`
AllowCredentials bool `mapstructure:"allow_credentials"`
}
// ConcurrencyConfig 并发配置
type ConcurrencyConfig struct {
// PingInterval: 并发等待期间的 SSE ping 间隔(秒)
PingInterval int `mapstructure:"ping_interval"`
}
func (s ServerConfig) Address() string {
return fmt.Sprintf("%s:%d", s.Host, s.Port)
}