Files
lijiaoqiao/supply-api/cmd/supply-api/main.go
Your Name 50225f6822 fix: 修复4个安全漏洞 (HIGH-01, HIGH-02, MED-01, MED-02)
- HIGH-01: CheckScope空scope绕过权限检查
  * 修复: 空scope现在返回false拒绝访问

- HIGH-02: JWT算法验证不严格
  * 修复: 使用token.Method.Alg()严格验证只接受HS256

- MED-01: RequireAnyScope空scope列表逻辑错误
  * 修复: 空列表现在返回403拒绝访问

- MED-02: Token状态缓存未命中时默认返回active
  * 修复: 添加TokenStatusBackend接口,缓存未命中时必须查询后端

影响文件:
- supply-api/internal/iam/middleware/scope_auth.go
- supply-api/internal/middleware/auth.go
- supply-api/cmd/supply-api/main.go (适配新API)

测试覆盖:
- 添加4个新的安全测试用例
- 更新1个原有测试以反映正确的安全行为
2026-04-03 07:52:41 +08:00

480 lines
15 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package main
import (
"context"
"encoding/json"
"flag"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"lijiaoqiao/supply-api/internal/audit"
"lijiaoqiao/supply-api/internal/cache"
"lijiaoqiao/supply-api/internal/config"
"lijiaoqiao/supply-api/internal/domain"
"lijiaoqiao/supply-api/internal/httpapi"
"lijiaoqiao/supply-api/internal/middleware"
"lijiaoqiao/supply-api/internal/repository"
"lijiaoqiao/supply-api/internal/storage"
)
// main wires up and runs the supply-api HTTP server.
//
// Startup sequence: parse flags, load configuration, connect to the database
// and Redis (a failure of either is logged and the process continues with
// in-memory stores / no cache), build the domain services, register the HTTP
// routes, wrap them in the middleware chain, and serve until SIGINT or SIGTERM
// triggers a graceful shutdown.
func main() {
	// Parse command-line flags.
	env := flag.String("env", "dev", "environment: dev/staging/prod")
	configPath := flag.String("config", "", "config file path")
	flag.Parse()

	// Derive the default config file path from the environment name.
	// NOTE(review): config.Load below receives *env, not *configPath, so the
	// -config flag has no visible effect in this function — confirm intent.
	if *configPath == "" {
		*configPath = "./config/config." + *env + ".yaml"
	}

	// Load configuration.
	cfg, err := config.Load(*env)
	if err != nil {
		log.Fatalf("failed to load config: %v", err)
	}
	log.Printf("starting supply-api in %s mode", *env)

	// Bound the database connection attempt to 30 seconds.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	// Connect to the database; on failure fall back to in-memory stores.
	db, err := repository.NewDB(ctx, cfg.Database)
	if err != nil {
		log.Printf("warning: failed to connect to database: %v (using in-memory store)", err)
		db = nil
	} else {
		log.Printf("connected to database at %s:%d", cfg.Database.Host, cfg.Database.Port)
		defer db.Close()
	}

	// Connect to Redis; on failure continue without it.
	redisCache, err := cache.NewRedisCache(cfg.Redis)
	if err != nil {
		log.Printf("warning: failed to connect to redis: %v (caching disabled)", err)
		redisCache = nil
	} else {
		log.Printf("connected to redis at %s:%d", cfg.Redis.Host, cfg.Redis.Port)
		defer redisCache.Close()
	}

	// Audit store.
	auditStore := audit.NewMemoryAuditStore() // TODO: replace with a DB-backed implementation

	// Storage layer: DB-backed when a database is available, in-memory otherwise.
	var accountStore domain.AccountStore
	var packageStore domain.PackageStore
	var settlementStore domain.SettlementStore
	var earningStore domain.EarningStore
	if db != nil {
		accountRepo := repository.NewAccountRepository(db.Pool)
		packageRepo := repository.NewPackageRepository(db.Pool)
		settlementRepo := repository.NewSettlementRepository(db.Pool)

		// Wrap each repository in the adapter that satisfies the store interface.
		accountStore = &DBAccountStore{repo: accountRepo}
		packageStore = &DBPackageStore{repo: packageRepo}
		settlementStore = &DBSettlementStore{repo: settlementRepo}
		earningStore = &DBEarningStore{repo: settlementRepo} // reuses the settlement repository
	} else {
		// Fall back to in-memory stores (development mode).
		accountStore = NewInMemoryAccountStoreAdapter()
		packageStore = NewInMemoryPackageStoreAdapter()
		settlementStore = NewInMemorySettlementStoreAdapter()
		earningStore = NewInMemoryEarningStoreAdapter()
	}

	// Invariant checker; constructed but not yet used by anything in this function.
	invariantChecker := domain.NewInvariantChecker(accountStore, packageStore, settlementStore)
	_ = invariantChecker

	// Domain services.
	accountService := domain.NewAccountService(accountStore, auditStore)
	packageService := domain.NewPackageService(packageStore, accountStore, auditStore)
	settlementService := domain.NewSettlementService(settlementStore, earningStore, auditStore)
	earningService := domain.NewEarningService(earningStore)

	// Idempotency repository; only available with a database.
	var idempotencyRepo *repository.IdempotencyRepository
	if db != nil {
		idempotencyRepo = repository.NewIdempotencyRepository(db.Pool)
	}
	_ = idempotencyRepo // TODO: use for DB-backed idempotency in production

	// Token cache.
	// TODO: back the token cache with redisCache when it is available.
	tokenCache := middleware.NewTokenCache()

	// Authentication middleware. Disabled in dev mode.
	// NOTE(review): the last two constructor arguments are nil — confirm the
	// middleware tolerates that when authentication is enabled.
	authConfig := middleware.AuthConfig{
		SecretKey: cfg.Token.SecretKey,
		Issuer:    cfg.Token.Issuer,
		CacheTTL:  cfg.Token.RevocationCacheTTL,
		Enabled:   *env != "dev", // authentication is disabled in dev mode
	}
	authMiddleware := middleware.NewAuthMiddleware(authConfig, tokenCache, nil, nil)

	// Idempotency middleware; constructed but not yet part of the chain.
	idempotencyMiddleware := middleware.NewIdempotencyMiddleware(nil, middleware.IdempotencyConfig{
		TTL:     24 * time.Hour,
		Enabled: *env != "dev",
	})
	_ = idempotencyMiddleware // TODO: use for idempotent request handling in production

	// Idempotency store handed to the API handlers.
	idempotencyStore := storage.NewInMemoryIdempotencyStore()

	// HTTP API handlers.
	api := httpapi.NewSupplyAPI(
		accountService,
		packageService,
		settlementService,
		earningService,
		idempotencyStore,
		auditStore,
		1, // default supplier ID
		time.Now,
	)

	// Router.
	mux := http.NewServeMux()

	// Health check endpoints.
	mux.HandleFunc("/actuator/health", handleHealthCheck(db, redisCache))
	mux.HandleFunc("/actuator/health/live", handleLiveness)
	mux.HandleFunc("/actuator/health/ready", handleReadiness(db, redisCache))

	// Register the API routes exactly once: http.ServeMux panics when the same
	// pattern is registered a second time.
	api.Register(mux)

	// Assemble the middleware chain. Each wrapper encloses everything applied
	// before it, so the last wrapper applied is the first to see a request.
	// The wrappers are therefore applied innermost-first, i.e. in the reverse
	// of the intended execution order:
	//   1. RequestID      - request tracing
	//   2. Recovery       - panic recovery
	//   3. Logging        - request logging
	//   4. QueryKeyReject - reject externally supplied query keys (M-016)
	//   5. BearerExtract  - bearer token extraction
	//   6. TokenVerify    - JWT verification
	//   7. ScopeRoleAuthz - scope/role authorization (not applied here)
	//   8. Idempotent     - idempotency handling (not applied here)
	handler := http.Handler(mux)

	// Security middleware is enabled outside dev mode only.
	// NOTE(review): these wrappers enclose the whole mux, including the
	// /actuator/health endpoints, unless the middleware exempts them — confirm.
	if *env != "dev" {
		// 6. TokenVerify
		handler = authMiddleware.TokenVerifyMiddleware(handler)
		// 5. BearerExtract
		handler = authMiddleware.BearerExtractMiddleware(handler)
		// 4. QueryKeyReject
		handler = authMiddleware.QueryKeyRejectMiddleware(handler)
	}
	handler = middleware.Logging(handler)
	handler = middleware.Recovery(handler)
	handler = middleware.RequestID(handler)

	// HTTP server.
	srv := &http.Server{
		Addr:              cfg.Server.Addr,
		Handler:           handler,
		ReadHeaderTimeout: cfg.Server.ReadTimeout,
		ReadTimeout:       cfg.Server.ReadTimeout,
		WriteTimeout:      cfg.Server.WriteTimeout,
		IdleTimeout:       cfg.Server.IdleTimeout,
	}

	// Serve in the background so main can wait for a shutdown signal.
	// log.Fatalf exits the process immediately, so the deferred cleanup in
	// main does not run when listening fails.
	go func() {
		log.Printf("supply-api listening on %s", cfg.Server.Addr)
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			log.Fatalf("listen failed: %v", err)
		}
	}()

	// Graceful shutdown on SIGINT / SIGTERM.
	sigCh := make(chan os.Signal, 1)
	signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
	<-sigCh
	log.Println("shutting down...")

	shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), cfg.Server.ShutdownTimeout)
	defer shutdownCancel()
	if err := srv.Shutdown(shutdownCtx); err != nil {
		log.Printf("graceful shutdown failed: %v", err)
	}
	log.Println("shutdown complete")
}
// handleHealthCheck returns a handler that reports the state of the database
// and Redis dependencies as JSON.
//
// Each dependency is reported as "UP", "DOWN" (its HealthCheck returned an
// error) or "MISSING" (the handle passed in is nil). The response status is
// 503 when any dependency is "DOWN" and 200 otherwise; "MISSING" alone does
// not fail the check.
func handleHealthCheck(db *repository.DB, redisCache *cache.RedisCache) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		checks := map[string]string{
			"database": "UP",
			"redis":    "UP",
		}

		if db != nil {
			if err := db.HealthCheck(ctx); err != nil {
				checks["database"] = "DOWN"
			}
		} else {
			checks["database"] = "MISSING"
		}

		if redisCache != nil {
			if err := redisCache.HealthCheck(ctx); err != nil {
				checks["redis"] = "DOWN"
			}
		} else {
			checks["redis"] = "MISSING"
		}

		// A single DOWN dependency marks the whole service as unavailable.
		status, overall := http.StatusOK, "UP"
		for _, v := range checks {
			if v == "DOWN" {
				status, overall = http.StatusServiceUnavailable, "DOWN"
				break
			}
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(status)

		payload := map[string]interface{}{
			"status": overall,
			"checks": checks,
			"time":   time.Now().Format(time.RFC3339),
		}
		// The status line is already sent, so an encoding/write failure can
		// only be logged.
		if err := json.NewEncoder(w).Encode(payload); err != nil {
			log.Printf("health check: failed to write response: %v", err)
		}
	}
}
// handleLiveness is the liveness probe: it always answers 200 with a fixed
// JSON body and performs no dependency checks.
func handleLiveness(w http.ResponseWriter, r *http.Request) {
	const body = `{"status":"LIVE"}`

	headers := w.Header()
	headers.Set("Content-Type", "application/json")
	w.WriteHeader(http.StatusOK)
	// The write error is deliberately ignored, as nothing else is done with
	// the response afterwards.
	_, _ = w.Write([]byte(body))
}
// handleReadiness returns the readiness probe handler.
//
// The service is reported ready (200, {"status":"READY"}) only when db is
// non-nil and its HealthCheck succeeds; otherwise the handler answers 503 with
// {"status":"NOT_READY"}. redisCache is accepted but not consulted.
func handleReadiness(db *repository.DB, redisCache *cache.RedisCache) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Short-circuits, so HealthCheck is never called on a nil db.
		healthy := db != nil && db.HealthCheck(r.Context()) == nil

		code, body := http.StatusServiceUnavailable, `{"status":"NOT_READY"}`
		if healthy {
			code, body = http.StatusOK, `{"status":"READY"}`
		}

		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(code)
		// The write error is deliberately ignored, as nothing else is done
		// with the response afterwards.
		_, _ = w.Write([]byte(body))
	}
}
// ==================== In-memory store adapters (development mode) ====================

// InMemoryAccountStoreAdapter wraps a storage.InMemoryAccountStore and
// forwards every call to it unchanged.
type InMemoryAccountStoreAdapter struct {
	store *storage.InMemoryAccountStore
}

// NewInMemoryAccountStoreAdapter returns an adapter around a newly created
// in-memory account store.
func NewInMemoryAccountStoreAdapter() *InMemoryAccountStoreAdapter {
	backing := storage.NewInMemoryAccountStore()
	return &InMemoryAccountStoreAdapter{store: backing}
}

// Create forwards the account to the underlying store's Create.
func (ad *InMemoryAccountStoreAdapter) Create(ctx context.Context, account *domain.Account) error {
	return ad.store.Create(ctx, account)
}

// GetByID forwards the supplier ID and account ID to the underlying store's GetByID.
func (ad *InMemoryAccountStoreAdapter) GetByID(ctx context.Context, supplierID, id int64) (*domain.Account, error) {
	return ad.store.GetByID(ctx, supplierID, id)
}

// Update forwards the account to the underlying store's Update.
func (ad *InMemoryAccountStoreAdapter) Update(ctx context.Context, account *domain.Account) error {
	return ad.store.Update(ctx, account)
}

// List forwards the supplier ID to the underlying store's List.
func (ad *InMemoryAccountStoreAdapter) List(ctx context.Context, supplierID int64) ([]*domain.Account, error) {
	return ad.store.List(ctx, supplierID)
}
// InMemoryPackageStoreAdapter wraps a storage.InMemoryPackageStore and
// forwards every call to it unchanged.
type InMemoryPackageStoreAdapter struct {
	store *storage.InMemoryPackageStore
}

// NewInMemoryPackageStoreAdapter returns an adapter around a newly created
// in-memory package store.
func NewInMemoryPackageStoreAdapter() *InMemoryPackageStoreAdapter {
	backing := storage.NewInMemoryPackageStore()
	return &InMemoryPackageStoreAdapter{store: backing}
}

// Create forwards the package to the underlying store's Create.
func (ad *InMemoryPackageStoreAdapter) Create(ctx context.Context, pkg *domain.Package) error {
	return ad.store.Create(ctx, pkg)
}

// GetByID forwards the supplier ID and package ID to the underlying store's GetByID.
func (ad *InMemoryPackageStoreAdapter) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) {
	return ad.store.GetByID(ctx, supplierID, id)
}

// Update forwards the package to the underlying store's Update.
func (ad *InMemoryPackageStoreAdapter) Update(ctx context.Context, pkg *domain.Package) error {
	return ad.store.Update(ctx, pkg)
}

// List forwards the supplier ID to the underlying store's List.
func (ad *InMemoryPackageStoreAdapter) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
	return ad.store.List(ctx, supplierID)
}
// InMemorySettlementStoreAdapter wraps a storage.InMemorySettlementStore and
// forwards every call to it unchanged.
type InMemorySettlementStoreAdapter struct {
	store *storage.InMemorySettlementStore
}

// NewInMemorySettlementStoreAdapter returns an adapter around a newly created
// in-memory settlement store.
func NewInMemorySettlementStoreAdapter() *InMemorySettlementStoreAdapter {
	backing := storage.NewInMemorySettlementStore()
	return &InMemorySettlementStoreAdapter{store: backing}
}

// Create forwards the settlement to the underlying store's Create.
func (ad *InMemorySettlementStoreAdapter) Create(ctx context.Context, s *domain.Settlement) error {
	return ad.store.Create(ctx, s)
}

// GetByID forwards the supplier ID and settlement ID to the underlying store's GetByID.
func (ad *InMemorySettlementStoreAdapter) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
	return ad.store.GetByID(ctx, supplierID, id)
}

// Update forwards the settlement to the underlying store's Update.
func (ad *InMemorySettlementStoreAdapter) Update(ctx context.Context, s *domain.Settlement) error {
	return ad.store.Update(ctx, s)
}

// List forwards the supplier ID to the underlying store's List.
func (ad *InMemorySettlementStoreAdapter) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
	return ad.store.List(ctx, supplierID)
}

// GetWithdrawableBalance forwards the supplier ID to the underlying store's
// GetWithdrawableBalance.
func (ad *InMemorySettlementStoreAdapter) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) {
	return ad.store.GetWithdrawableBalance(ctx, supplierID)
}
// InMemoryEarningStoreAdapter wraps a storage.InMemoryEarningStore and
// forwards every call to it unchanged.
type InMemoryEarningStoreAdapter struct {
	store *storage.InMemoryEarningStore
}

// NewInMemoryEarningStoreAdapter returns an adapter around a newly created
// in-memory earning store.
func NewInMemoryEarningStoreAdapter() *InMemoryEarningStoreAdapter {
	backing := storage.NewInMemoryEarningStore()
	return &InMemoryEarningStoreAdapter{store: backing}
}

// ListRecords forwards the supplier ID, date range and paging arguments to the
// underlying store's ListRecords and returns its three results as-is.
func (ad *InMemoryEarningStoreAdapter) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
	return ad.store.ListRecords(ctx, supplierID, startDate, endDate, page, pageSize)
}

// GetBillingSummary forwards the supplier ID and date range to the underlying
// store's GetBillingSummary.
func (ad *InMemoryEarningStoreAdapter) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
	return ad.store.GetBillingSummary(ctx, supplierID, startDate, endDate)
}
// ==================== DB-backed store adapters ====================

// DBAccountStore exposes a repository.AccountRepository through the account
// store method set used by the domain services.
type DBAccountStore struct {
	repo *repository.AccountRepository
}

// Create inserts the account through the repository, passing empty strings for
// the repository's three trailing arguments.
// NOTE(review): the meaning of those arguments is not visible in this file.
func (st *DBAccountStore) Create(ctx context.Context, account *domain.Account) error {
	return st.repo.Create(ctx, account, "", "", "")
}

// GetByID forwards the supplier ID and account ID to the repository's GetByID.
func (st *DBAccountStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Account, error) {
	return st.repo.GetByID(ctx, supplierID, id)
}

// Update forwards the account to the repository together with account.Version.
// NOTE(review): presumably the expected version for optimistic locking — confirm.
func (st *DBAccountStore) Update(ctx context.Context, account *domain.Account) error {
	return st.repo.Update(ctx, account, account.Version)
}

// List forwards the supplier ID to the repository's List.
func (st *DBAccountStore) List(ctx context.Context, supplierID int64) ([]*domain.Account, error) {
	return st.repo.List(ctx, supplierID)
}
// DBPackageStore exposes a repository.PackageRepository through the package
// store method set used by the domain services.
type DBPackageStore struct {
	repo *repository.PackageRepository
}

// Create inserts the package through the repository, passing empty strings for
// the repository's two trailing arguments.
// NOTE(review): the meaning of those arguments is not visible in this file.
func (st *DBPackageStore) Create(ctx context.Context, pkg *domain.Package) error {
	return st.repo.Create(ctx, pkg, "", "")
}

// GetByID forwards the supplier ID and package ID to the repository's GetByID.
func (st *DBPackageStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Package, error) {
	return st.repo.GetByID(ctx, supplierID, id)
}

// Update forwards the package to the repository together with pkg.Version.
// NOTE(review): presumably the expected version for optimistic locking — confirm.
func (st *DBPackageStore) Update(ctx context.Context, pkg *domain.Package) error {
	return st.repo.Update(ctx, pkg, pkg.Version)
}

// List forwards the supplier ID to the repository's List.
func (st *DBPackageStore) List(ctx context.Context, supplierID int64) ([]*domain.Package, error) {
	return st.repo.List(ctx, supplierID)
}
// DBSettlementStore exposes a repository.SettlementRepository through the
// settlement store method set used by the domain services.
type DBSettlementStore struct {
	repo *repository.SettlementRepository
}

// Create inserts the settlement through the repository, passing empty strings
// for the repository's three trailing arguments.
// NOTE(review): the meaning of those arguments is not visible in this file.
func (st *DBSettlementStore) Create(ctx context.Context, settlement *domain.Settlement) error {
	return st.repo.Create(ctx, settlement, "", "", "")
}

// GetByID forwards the supplier ID and settlement ID to the repository's GetByID.
func (st *DBSettlementStore) GetByID(ctx context.Context, supplierID, id int64) (*domain.Settlement, error) {
	return st.repo.GetByID(ctx, supplierID, id)
}

// Update forwards the settlement to the repository together with
// settlement.Version.
// NOTE(review): presumably the expected version for optimistic locking — confirm.
func (st *DBSettlementStore) Update(ctx context.Context, settlement *domain.Settlement) error {
	return st.repo.Update(ctx, settlement, settlement.Version)
}

// List forwards the supplier ID to the repository's List.
func (st *DBSettlementStore) List(ctx context.Context, supplierID int64) ([]*domain.Settlement, error) {
	return st.repo.List(ctx, supplierID)
}

// GetWithdrawableBalance is a stub: it ignores its arguments and always
// reports a balance of zero with a nil error.
// TODO: implement the real query (to be obtained via the account service).
func (st *DBSettlementStore) GetWithdrawableBalance(ctx context.Context, supplierID int64) (float64, error) {
	var balance float64 // placeholder until the real query exists
	return balance, nil
}
// DBEarningStore is the DB-backed earning store. It holds a settlement
// repository, but neither method below queries it yet.
type DBEarningStore struct {
	repo *repository.SettlementRepository
}

// ListRecords is a stub: it ignores its arguments and always returns a nil
// slice, a total of 0 and a nil error.
// TODO: implement the real query.
func (st *DBEarningStore) ListRecords(ctx context.Context, supplierID int64, startDate, endDate string, page, pageSize int) ([]*domain.EarningRecord, int, error) {
	var (
		records []*domain.EarningRecord
		total   int
	)
	return records, total, nil
}

// GetBillingSummary is a stub: it ignores its arguments and always returns a
// nil summary together with a nil error.
// NOTE(review): callers that dereference the summary without a nil check
// would panic — confirm how callers handle this.
// TODO: implement the real query.
func (st *DBEarningStore) GetBillingSummary(ctx context.Context, supplierID int64, startDate, endDate string) (*domain.BillingSummary, error) {
	var summary *domain.BillingSummary
	return summary, nil
}