package config import ( "fmt" "os" "strconv" "strings" "time" "github.com/spf13/viper" ) // Config 应用配置 type Config struct { Server ServerConfig Database DatabaseConfig Redis RedisConfig Token TokenConfig Audit AuditConfig } // ServerConfig HTTP服务配置 type ServerConfig struct { Addr string ReadTimeout time.Duration WriteTimeout time.Duration IdleTimeout time.Duration ShutdownTimeout time.Duration } // DatabaseConfig PostgreSQL配置 type DatabaseConfig struct { Host string Port int User string Password string Database string MaxOpenConns int MaxIdleConns int ConnMaxLifetime time.Duration ConnMaxIdleTime time.Duration } // RedisConfig Redis配置 type RedisConfig struct { Host string Port int Password string DB int PoolSize int } // TokenConfig Token运行时配置 type TokenConfig struct { SecretKey string Issuer string AccessTokenTTL time.Duration RefreshTokenTTL time.Duration RevocationCacheTTL time.Duration } // AuditConfig 审计配置 type AuditConfig struct { BufferSize int FlushInterval time.Duration ExportTimeout time.Duration } // DSN 返回数据库连接字符串(包含明文密码,仅限内部使用) func (d *DatabaseConfig) DSN() string { return fmt.Sprintf("postgres://%s:%s@%s:%d/%s?sslmode=disable", d.User, d.Password, d.Host, d.Port, d.Database) } // SafeDSN 返回脱敏的数据库连接字符串(密码被替换为***),用于日志记录 // P2-05: 避免在日志中泄露数据库密码 func (d *DatabaseConfig) SafeDSN() string { return fmt.Sprintf("postgres://%s:***@%s:%d/%s?sslmode=disable", d.User, d.Host, d.Port, d.Database) } // Addr 返回Redis地址 func (r *RedisConfig) Addr() string { return fmt.Sprintf("%s:%d", r.Host, r.Port) } // Load 加载配置 func Load(env string) (*Config, error) { v := viper.New() // 设置环境变量前缀 v.SetEnvPrefix("SUPPLY_API") v.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) // 默认配置 setDefaults(v) // 加载配置文件 configFile := fmt.Sprintf("config.%s.yaml", env) v.SetConfigName(configFile) v.SetConfigType("yaml") v.AddConfigPath(".") v.AddConfigPath("./config") // 允许环境变量覆盖 v.AutomaticEnv() // 读取配置文件 if err := v.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); !ok { return nil, fmt.Errorf("failed to read config: %w", err) } // 配置文件不存在时,使用环境变量 } // 绑定环境变量 bindEnvVars(v) var cfg Config // Server配置 cfg.Server.Addr = v.GetString("server.addr") cfg.Server.ReadTimeout = v.GetDuration("server.read_timeout") cfg.Server.WriteTimeout = v.GetDuration("server.write_timeout") cfg.Server.IdleTimeout = v.GetDuration("server.idle_timeout") cfg.Server.ShutdownTimeout = v.GetDuration("server.shutdown_timeout") // Database配置 cfg.Database.Host = v.GetString("database.host") cfg.Database.Port = v.GetInt("database.port") cfg.Database.User = v.GetString("database.user") cfg.Database.Password = v.GetString("database.password") cfg.Database.Database = v.GetString("database.database") cfg.Database.MaxOpenConns = v.GetInt("database.max_open_conns") cfg.Database.MaxIdleConns = v.GetInt("database.max_idle_conns") cfg.Database.ConnMaxLifetime = v.GetDuration("database.conn_max_lifetime") cfg.Database.ConnMaxIdleTime = v.GetDuration("database.conn_max_idle_time") // Redis配置 cfg.Redis.Host = v.GetString("redis.host") cfg.Redis.Port = v.GetInt("redis.port") cfg.Redis.Password = v.GetString("redis.password") cfg.Redis.DB = v.GetInt("redis.db") cfg.Redis.PoolSize = v.GetInt("redis.pool_size") // Token配置 cfg.Token.SecretKey = v.GetString("token.secret_key") cfg.Token.Issuer = v.GetString("token.issuer") cfg.Token.AccessTokenTTL = v.GetDuration("token.access_token_ttl") cfg.Token.RefreshTokenTTL = v.GetDuration("token.refresh_token_ttl") cfg.Token.RevocationCacheTTL = v.GetDuration("token.revocation_cache_ttl") // Audit配置 cfg.Audit.BufferSize = v.GetInt("audit.buffer_size") cfg.Audit.FlushInterval = v.GetDuration("audit.flush_interval") cfg.Audit.ExportTimeout = v.GetDuration("audit.export_timeout") return &cfg, nil } // setDefaults 设置默认值 func setDefaults(v *viper.Viper) { // Server defaults v.SetDefault("server.addr", ":18082") v.SetDefault("server.read_timeout", 10*time.Second) v.SetDefault("server.write_timeout", 15*time.Second) v.SetDefault("server.idle_timeout", 30*time.Second) v.SetDefault("server.shutdown_timeout", 5*time.Second) // Database defaults v.SetDefault("database.host", "localhost") v.SetDefault("database.port", 5432) v.SetDefault("database.user", "postgres") v.SetDefault("database.password", "") v.SetDefault("database.database", "supply_db") v.SetDefault("database.max_open_conns", 25) v.SetDefault("database.max_idle_conns", 5) v.SetDefault("database.conn_max_lifetime", 1*time.Hour) v.SetDefault("database.conn_max_idle_time", 10*time.Minute) // Redis defaults v.SetDefault("redis.host", "localhost") v.SetDefault("redis.port", 6379) v.SetDefault("redis.password", "") v.SetDefault("redis.db", 0) v.SetDefault("redis.pool_size", 10) // Token defaults v.SetDefault("token.issuer", "lijiaoqiao/supply-api") v.SetDefault("token.access_token_ttl", 1*time.Hour) v.SetDefault("token.refresh_token_ttl", 7*24*time.Hour) v.SetDefault("token.revocation_cache_ttl", 30*time.Second) // Audit defaults v.SetDefault("audit.buffer_size", 1000) v.SetDefault("audit.flush_interval", 5*time.Second) v.SetDefault("audit.export_timeout", 30*time.Second) } // bindEnvVars 绑定环境变量 func bindEnvVars(v *viper.Viper) { _ = v.BindEnv("server.addr", "SUPPLY_API_ADDR") _ = v.BindEnv("server.read_timeout", "SUPPLY_API_READ_TIMEOUT") _ = v.BindEnv("server.write_timeout", "SUPPLY_API_WRITE_TIMEOUT") _ = v.BindEnv("database.host", "SUPPLY_DB_HOST") _ = v.BindEnv("database.port", "SUPPLY_DB_PORT") _ = v.BindEnv("database.user", "SUPPLY_DB_USER") _ = v.BindEnv("database.password", "SUPPLY_DB_PASSWORD") _ = v.BindEnv("database.database", "SUPPLY_DB_NAME") _ = v.BindEnv("database.max_open_conns", "SUPPLY_DB_MAX_OPEN_CONNS") _ = v.BindEnv("database.max_idle_conns", "SUPPLY_DB_MAX_IDLE_CONNS") _ = v.BindEnv("redis.host", "SUPPLY_REDIS_HOST") _ = v.BindEnv("redis.port", "SUPPLY_REDIS_PORT") _ = v.BindEnv("redis.password", "SUPPLY_REDIS_PASSWORD") _ = v.BindEnv("redis.db", "SUPPLY_REDIS_DB") _ = v.BindEnv("token.secret_key", "SUPPLY_TOKEN_SECRET_KEY") } // MustLoad 加载配置,失败时panic func MustLoad(env string) *Config { cfg, err := Load(env) if err != nil { panic("failed to load config: " + err.Error()) } return cfg } // GetEnvInt 获取环境变量int值 func GetEnvInt(key string, defaultVal int) int { if v := os.Getenv(key); v != "" { if i, err := strconv.Atoi(v); err == nil { return i } } return defaultVal } // GetEnvDuration 获取环境变量duration值 func GetEnvDuration(key string, defaultVal time.Duration) time.Duration { if v := os.Getenv(key); v != "" { if d, err := time.ParseDuration(v); err == nil { return d } } return defaultVal }