chore: initial snapshot for gitea/github upload

This commit is contained in:
Your Name
2026-03-26 16:04:46 +08:00
commit a699a1ac98
3497 changed files with 1586237 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,62 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAccount_IsAnthropicAPIKeyPassthroughEnabled(t *testing.T) {
t.Run("Anthropic API Key 开启", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"anthropic_passthrough": true,
},
}
require.True(t, account.IsAnthropicAPIKeyPassthroughEnabled())
})
t.Run("Anthropic API Key 关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"anthropic_passthrough": false,
},
}
require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled())
})
t.Run("字段类型非法默认关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"anthropic_passthrough": "true",
},
}
require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled())
})
t.Run("非 Anthropic API Key 账号始终关闭", func(t *testing.T) {
oauth := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"anthropic_passthrough": true,
},
}
require.False(t, oauth.IsAnthropicAPIKeyPassthroughEnabled())
openai := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"anthropic_passthrough": true,
},
}
require.False(t, openai.IsAnthropicAPIKeyPassthroughEnabled())
})
}

View File

@@ -0,0 +1,160 @@
//go:build unit
package service
import (
"testing"
)
func TestGetBaseURL(t *testing.T) {
tests := []struct {
name string
account Account
expected string
}{
{
name: "non-apikey type returns empty",
account: Account{
Type: AccountTypeOAuth,
Platform: PlatformAnthropic,
},
expected: "",
},
{
name: "apikey without base_url returns default anthropic",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAnthropic,
Credentials: map[string]any{},
},
expected: "https://api.anthropic.com",
},
{
name: "apikey with custom base_url",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAnthropic,
Credentials: map[string]any{"base_url": "https://custom.example.com"},
},
expected: "https://custom.example.com",
},
{
name: "antigravity apikey auto-appends /antigravity",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
},
expected: "https://upstream.example.com/antigravity",
},
{
name: "antigravity apikey trims trailing slash before appending",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com/"},
},
expected: "https://upstream.example.com/antigravity",
},
{
name: "antigravity non-apikey returns empty",
account: Account{
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
},
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetBaseURL()
if result != tt.expected {
t.Errorf("GetBaseURL() = %q, want %q", result, tt.expected)
}
})
}
}
func TestGetGeminiBaseURL(t *testing.T) {
const defaultGeminiURL = "https://generativelanguage.googleapis.com"
tests := []struct {
name string
account Account
expected string
}{
{
name: "apikey without base_url returns default",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{},
},
expected: defaultGeminiURL,
},
{
name: "apikey with custom base_url",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
Credentials: map[string]any{"base_url": "https://custom-gemini.example.com"},
},
expected: "https://custom-gemini.example.com",
},
{
name: "antigravity apikey auto-appends /antigravity",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
},
expected: "https://upstream.example.com/antigravity",
},
{
name: "antigravity apikey trims trailing slash",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com/"},
},
expected: "https://upstream.example.com/antigravity",
},
{
name: "antigravity oauth does NOT append /antigravity",
account: Account{
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
},
expected: "https://upstream.example.com",
},
{
name: "oauth without base_url returns default",
account: Account{
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Credentials: map[string]any{},
},
expected: defaultGeminiURL,
},
{
name: "nil credentials returns default",
account: Account{
Type: AccountTypeAPIKey,
Platform: PlatformGemini,
},
expected: defaultGeminiURL,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := tt.account.GetGeminiBaseURL(defaultGeminiURL)
if result != tt.expected {
t.Errorf("GetGeminiBaseURL() = %q, want %q", result, tt.expected)
}
})
}
}

View File

@@ -0,0 +1,27 @@
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil(t *testing.T) {
var a Account
require.NoError(t, json.Unmarshal([]byte(`{"id":1,"name":"acc","status":"active"}`), &a))
require.Nil(t, a.RateMultiplier)
require.Equal(t, 1.0, a.BillingRateMultiplier())
}
func TestAccount_BillingRateMultiplier_AllowsZero(t *testing.T) {
v := 0.0
a := Account{RateMultiplier: &v}
require.Equal(t, 0.0, a.BillingRateMultiplier())
}
func TestAccount_BillingRateMultiplier_NegativeFallsBackToOne(t *testing.T) {
v := -1.0
a := Account{RateMultiplier: &v}
require.Equal(t, 1.0, a.BillingRateMultiplier())
}

View File

@@ -0,0 +1,71 @@
package service
import (
"context"
"log"
"sync"
"time"
)
// AccountExpiryService periodically pauses expired accounts when auto-pause is enabled.
type AccountExpiryService struct {
accountRepo AccountRepository
interval time.Duration
stopCh chan struct{}
stopOnce sync.Once
wg sync.WaitGroup
}
func NewAccountExpiryService(accountRepo AccountRepository, interval time.Duration) *AccountExpiryService {
return &AccountExpiryService{
accountRepo: accountRepo,
interval: interval,
stopCh: make(chan struct{}),
}
}
func (s *AccountExpiryService) Start() {
if s == nil || s.accountRepo == nil || s.interval <= 0 {
return
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
ticker := time.NewTicker(s.interval)
defer ticker.Stop()
s.runOnce()
for {
select {
case <-ticker.C:
s.runOnce()
case <-s.stopCh:
return
}
}
}()
}
func (s *AccountExpiryService) Stop() {
if s == nil {
return
}
s.stopOnce.Do(func() {
close(s.stopCh)
})
s.wg.Wait()
}
func (s *AccountExpiryService) runOnce() {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
updated, err := s.accountRepo.AutoPauseExpiredAccounts(ctx, time.Now())
if err != nil {
log.Printf("[AccountExpiry] Auto pause expired accounts failed: %v", err)
return
}
if updated > 0 {
log.Printf("[AccountExpiry] Auto paused %d expired accounts", updated)
}
}

View File

@@ -0,0 +1,13 @@
package service
import "time"
type AccountGroup struct {
AccountID int64
GroupID int64
Priority int
CreatedAt time.Time
Account *Account
Group *Group
}

View File

@@ -0,0 +1,66 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAccount_IsInterceptWarmupEnabled(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
expected bool
}{
{
name: "nil credentials",
credentials: nil,
expected: false,
},
{
name: "empty map",
credentials: map[string]any{},
expected: false,
},
{
name: "field not present",
credentials: map[string]any{"access_token": "tok"},
expected: false,
},
{
name: "field is true",
credentials: map[string]any{"intercept_warmup_requests": true},
expected: true,
},
{
name: "field is false",
credentials: map[string]any{"intercept_warmup_requests": false},
expected: false,
},
{
name: "field is string true",
credentials: map[string]any{"intercept_warmup_requests": "true"},
expected: false,
},
{
name: "field is int 1",
credentials: map[string]any{"intercept_warmup_requests": 1},
expected: false,
},
{
name: "field is nil",
credentials: map[string]any{"intercept_warmup_requests": nil},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Account{Credentials: tt.credentials}
result := a.IsInterceptWarmupEnabled()
require.Equal(t, tt.expected, result)
})
}
}

View File

@@ -0,0 +1,46 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func intPtrHelper(v int) *int { return &v }
func TestEffectiveLoadFactor_NilAccount(t *testing.T) {
var a *Account
require.Equal(t, 1, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_NilLoadFactor_PositiveConcurrency(t *testing.T) {
a := &Account{Concurrency: 5}
require.Equal(t, 5, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_NilLoadFactor_ZeroConcurrency(t *testing.T) {
a := &Account{Concurrency: 0}
require.Equal(t, 1, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_PositiveLoadFactor(t *testing.T) {
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(20)}
require.Equal(t, 20, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_ZeroLoadFactor_FallbackToConcurrency(t *testing.T) {
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(0)}
require.Equal(t, 5, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_NegativeLoadFactor_FallbackToConcurrency(t *testing.T) {
a := &Account{Concurrency: 3, LoadFactor: intPtrHelper(-1)}
require.Equal(t, 3, a.EffectiveLoadFactor())
}
func TestEffectiveLoadFactor_ZeroLoadFactor_ZeroConcurrency(t *testing.T) {
a := &Account{Concurrency: 0, LoadFactor: intPtrHelper(0)}
require.Equal(t, 1, a.EffectiveLoadFactor())
}

View File

@@ -0,0 +1,315 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAccount_IsOpenAIPassthroughEnabled(t *testing.T) {
t.Run("新字段开启", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"openai_passthrough": true,
},
}
require.True(t, account.IsOpenAIPassthroughEnabled())
})
t.Run("兼容旧字段", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_passthrough": true,
},
}
require.True(t, account.IsOpenAIPassthroughEnabled())
})
t.Run("非OpenAI账号始终关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_passthrough": true,
},
}
require.False(t, account.IsOpenAIPassthroughEnabled())
})
t.Run("空额外配置默认关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
}
require.False(t, account.IsOpenAIPassthroughEnabled())
})
}
func TestAccount_IsOpenAIOAuthPassthroughEnabled(t *testing.T) {
t.Run("仅OAuth类型允许返回开启", func(t *testing.T) {
oauthAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_passthrough": true,
},
}
require.True(t, oauthAccount.IsOpenAIOAuthPassthroughEnabled())
apiKeyAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"openai_passthrough": true,
},
}
require.False(t, apiKeyAccount.IsOpenAIOAuthPassthroughEnabled())
})
}
func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) {
t.Run("OpenAI OAuth 开启", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"codex_cli_only": true,
},
}
require.True(t, account.IsCodexCLIOnlyEnabled())
})
t.Run("OpenAI OAuth 关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"codex_cli_only": false,
},
}
require.False(t, account.IsCodexCLIOnlyEnabled())
})
t.Run("字段缺失默认关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{},
}
require.False(t, account.IsCodexCLIOnlyEnabled())
})
t.Run("类型非法默认关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"codex_cli_only": "true",
},
}
require.False(t, account.IsCodexCLIOnlyEnabled())
})
t.Run("非 OAuth 账号始终关闭", func(t *testing.T) {
apiKeyAccount := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"codex_cli_only": true,
},
}
require.False(t, apiKeyAccount.IsCodexCLIOnlyEnabled())
otherPlatform := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"codex_cli_only": true,
},
}
require.False(t, otherPlatform.IsCodexCLIOnlyEnabled())
})
}
func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
t.Run("OAuth使用OAuth专用开关", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_enabled": true,
},
}
require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("API Key使用API Key专用开关", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": true,
},
}
require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("分类型新键优先于兼容键", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_enabled": false,
"responses_websockets_v2_enabled": true,
"openai_ws_enabled": true,
},
}
require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("分类型键缺失时回退兼容键", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
t.Run("非OpenAI账号默认关闭", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
})
}
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
t.Run("default fallback to ctx_pool", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{},
}
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
})
t.Run("oauth mode field has highest priority", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
"openai_oauth_responses_websockets_v2_enabled": false,
"responses_websockets_v2_enabled": false,
},
}
require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
})
t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"responses_websockets_v2_enabled": true,
},
}
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
})
t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) {
shared := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
},
}
dedicated := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
},
}
require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared))
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated))
})
t.Run("legacy disabled maps to off", func(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeAPIKey,
Extra: map[string]any{
"openai_apikey_responses_websockets_v2_enabled": false,
"responses_websockets_v2_enabled": true,
},
}
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
})
t.Run("non openai always off", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
},
}
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated))
})
}
func TestAccount_OpenAIWSExtraFlags(t *testing.T) {
account := &Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_ws_force_http": true,
"openai_ws_allow_store_recovery": true,
},
}
require.True(t, account.IsOpenAIWSForceHTTPEnabled())
require.True(t, account.IsOpenAIWSAllowStoreRecoveryEnabled())
off := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}}
require.False(t, off.IsOpenAIWSForceHTTPEnabled())
require.False(t, off.IsOpenAIWSAllowStoreRecoveryEnabled())
var nilAccount *Account
require.False(t, nilAccount.IsOpenAIWSAllowStoreRecoveryEnabled())
nonOpenAI := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_ws_allow_store_recovery": true,
},
}
require.False(t, nonOpenAI.IsOpenAIWSAllowStoreRecoveryEnabled())
}

View File

@@ -0,0 +1,117 @@
//go:build unit
package service
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestGetPoolModeRetryCount(t *testing.T) {
tests := []struct {
name string
account *Account
expected int
}{
{
name: "default_when_not_pool_mode",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{},
},
expected: defaultPoolModeRetryCount,
},
{
name: "default_when_missing_retry_count",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
},
},
expected: defaultPoolModeRetryCount,
},
{
name: "supports_float64_from_json_credentials",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": float64(5),
},
},
expected: 5,
},
{
name: "supports_json_number",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": json.Number("4"),
},
},
expected: 4,
},
{
name: "supports_string_value",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": "2",
},
},
expected: 2,
},
{
name: "negative_value_is_clamped_to_zero",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": -1,
},
},
expected: 0,
},
{
name: "oversized_value_is_clamped_to_max",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": 99,
},
},
expected: maxPoolModeRetryCount,
},
{
name: "invalid_value_falls_back_to_default",
account: &Account{
Type: AccountTypeAPIKey,
Platform: PlatformOpenAI,
Credentials: map[string]any{
"pool_mode": true,
"pool_mode_retry_count": "oops",
},
},
expected: defaultPoolModeRetryCount,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.expected, tt.account.GetPoolModeRetryCount())
})
}
}

View File

@@ -0,0 +1,516 @@
//go:build unit
package service
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// nextFixedDailyReset
// ---------------------------------------------------------------------------
func TestNextFixedDailyReset_BeforeResetHour(t *testing.T) {
tz := time.UTC
// 2026-03-14 06:00 UTC, reset hour = 9
after := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
got := nextFixedDailyReset(9, tz, after)
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_AtResetHour(t *testing.T) {
tz := time.UTC
// Exactly at reset hour → should return tomorrow
after := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
got := nextFixedDailyReset(9, tz, after)
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_AfterResetHour(t *testing.T) {
tz := time.UTC
// After reset hour → should return tomorrow
after := time.Date(2026, 3, 14, 15, 30, 0, 0, tz)
got := nextFixedDailyReset(9, tz, after)
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_MidnightReset(t *testing.T) {
tz := time.UTC
// Reset at hour 0 (midnight), currently 23:59
after := time.Date(2026, 3, 14, 23, 59, 0, 0, tz)
got := nextFixedDailyReset(0, tz, after)
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedDailyReset_NonUTCTimezone(t *testing.T) {
tz, err := time.LoadLocation("Asia/Shanghai")
require.NoError(t, err)
// 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST)
after := time.Date(2026, 3, 14, 7, 0, 0, 0, time.UTC)
got := nextFixedDailyReset(9, tz, after)
// Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// lastFixedDailyReset
// ---------------------------------------------------------------------------
func TestLastFixedDailyReset_BeforeResetHour(t *testing.T) {
tz := time.UTC
now := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
got := lastFixedDailyReset(9, tz, now)
// Before today's 9:00 → yesterday 9:00
want := time.Date(2026, 3, 13, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedDailyReset_AtResetHour(t *testing.T) {
tz := time.UTC
now := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
got := lastFixedDailyReset(9, tz, now)
// At exactly 9:00 → today 9:00
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedDailyReset_AfterResetHour(t *testing.T) {
tz := time.UTC
now := time.Date(2026, 3, 14, 15, 0, 0, 0, tz)
got := lastFixedDailyReset(9, tz, now)
// After 9:00 → today 9:00
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// nextFixedWeeklyReset
// ---------------------------------------------------------------------------
func TestNextFixedWeeklyReset_TargetDayAhead(t *testing.T) {
tz := time.UTC
// 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday = 2026-03-16
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayToday_BeforeHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00
after := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Today at 9:00
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayToday_AtHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, exactly at 9:00
after := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday at 9:00
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayToday_AfterHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00
after := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday at 9:00
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_TargetDayPast(t *testing.T) {
tz := time.UTC
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
after := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
got := nextFixedWeeklyReset(1, 9, tz, after)
// Next Monday = 2026-03-23
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestNextFixedWeeklyReset_Sunday(t *testing.T) {
tz := time.UTC
// 2026-03-14 is Saturday (day=6), target = Sunday (day=0)
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
got := nextFixedWeeklyReset(0, 0, tz, after)
// Next Sunday = 2026-03-15
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// lastFixedWeeklyReset
// ---------------------------------------------------------------------------
func TestLastFixedWeeklyReset_SameDay_AfterHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00
now := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
got := lastFixedWeeklyReset(1, 9, tz, now)
// Today at 9:00
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedWeeklyReset_SameDay_BeforeHour(t *testing.T) {
tz := time.UTC
// 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00
now := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
got := lastFixedWeeklyReset(1, 9, tz, now)
// Last Monday at 9:00 = 2026-03-09
want := time.Date(2026, 3, 9, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
func TestLastFixedWeeklyReset_DifferentDay(t *testing.T) {
tz := time.UTC
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
now := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
got := lastFixedWeeklyReset(1, 9, tz, now)
// Last Monday = 2026-03-16
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
assert.Equal(t, want, got)
}
// ---------------------------------------------------------------------------
// isFixedDailyPeriodExpired
// ---------------------------------------------------------------------------
func TestIsFixedDailyPeriodExpired_ZeroPeriodStart(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
assert.True(t, a.isFixedDailyPeriodExpired(time.Time{}))
}
func TestIsFixedDailyPeriodExpired_NotExpired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started after the most recent reset → not expired
// (This test uses a time very close to "now", which is after the last reset)
periodStart := time.Now().Add(-1 * time.Minute)
assert.False(t, a.isFixedDailyPeriodExpired(periodStart))
}
func TestIsFixedDailyPeriodExpired_Expired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started 3 days ago → definitely expired
periodStart := time.Now().Add(-72 * time.Hour)
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
}
func TestIsFixedDailyPeriodExpired_InvalidTimezone(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "Invalid/Timezone",
}}
// Invalid timezone falls back to UTC
periodStart := time.Now().Add(-72 * time.Hour)
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
}
// ---------------------------------------------------------------------------
// isFixedWeeklyPeriodExpired
// ---------------------------------------------------------------------------
func TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
assert.True(t, a.isFixedWeeklyPeriodExpired(time.Time{}))
}
func TestIsFixedWeeklyPeriodExpired_NotExpired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started 1 minute ago → not expired
periodStart := time.Now().Add(-1 * time.Minute)
assert.False(t, a.isFixedWeeklyPeriodExpired(periodStart))
}
func TestIsFixedWeeklyPeriodExpired_Expired(t *testing.T) {
a := &Account{Extra: map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}}
// Period started 10 days ago → definitely expired
periodStart := time.Now().Add(-240 * time.Hour)
assert.True(t, a.isFixedWeeklyPeriodExpired(periodStart))
}
// ---------------------------------------------------------------------------
// ValidateQuotaResetConfig
// ---------------------------------------------------------------------------
func TestValidateQuotaResetConfig_NilExtra(t *testing.T) {
assert.NoError(t, ValidateQuotaResetConfig(nil))
}
func TestValidateQuotaResetConfig_EmptyExtra(t *testing.T) {
assert.NoError(t, ValidateQuotaResetConfig(map[string]any{}))
}
func TestValidateQuotaResetConfig_ValidFixed(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1),
"quota_weekly_reset_hour": float64(0),
"quota_reset_timezone": "Asia/Shanghai",
}
assert.NoError(t, ValidateQuotaResetConfig(extra))
}
func TestValidateQuotaResetConfig_ValidRolling(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "rolling",
"quota_weekly_reset_mode": "rolling",
}
assert.NoError(t, ValidateQuotaResetConfig(extra))
}
func TestValidateQuotaResetConfig_InvalidTimezone(t *testing.T) {
extra := map[string]any{
"quota_reset_timezone": "Not/A/Timezone",
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_reset_timezone")
}
func TestValidateQuotaResetConfig_InvalidDailyMode(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "invalid",
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_daily_reset_mode")
}
func TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_hour": float64(24),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
}
func TestValidateQuotaResetConfig_InvalidDailyHour_Negative(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_hour": float64(-1),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
}
func TestValidateQuotaResetConfig_InvalidWeeklyMode(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_mode": "unknown",
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_mode")
}
func TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_day": float64(7),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
}
func TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_day": float64(-1),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
}
func TestValidateQuotaResetConfig_InvalidWeeklyHour(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_hour": float64(25),
}
err := ValidateQuotaResetConfig(extra)
require.Error(t, err)
assert.Contains(t, err.Error(), "quota_weekly_reset_hour")
}
func TestValidateQuotaResetConfig_BoundaryValues(t *testing.T) {
// All boundary values should be valid
extra := map[string]any{
"quota_daily_reset_hour": float64(23),
"quota_weekly_reset_day": float64(0), // Sunday
"quota_weekly_reset_hour": float64(0),
"quota_reset_timezone": "UTC",
}
assert.NoError(t, ValidateQuotaResetConfig(extra))
extra2 := map[string]any{
"quota_daily_reset_hour": float64(0),
"quota_weekly_reset_day": float64(6), // Saturday
"quota_weekly_reset_hour": float64(23),
}
assert.NoError(t, ValidateQuotaResetConfig(extra2))
}
// ---------------------------------------------------------------------------
// ComputeQuotaResetAt
// ---------------------------------------------------------------------------
func TestComputeQuotaResetAt_RollingMode_NoResetAt(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "rolling",
"quota_weekly_reset_mode": "rolling",
}
ComputeQuotaResetAt(extra)
_, hasDailyResetAt := extra["quota_daily_reset_at"]
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
assert.False(t, hasDailyResetAt, "rolling mode should not set quota_daily_reset_at")
assert.False(t, hasWeeklyResetAt, "rolling mode should not set quota_weekly_reset_at")
}
func TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "rolling",
"quota_weekly_reset_mode": "rolling",
"quota_daily_reset_at": "2026-03-14T09:00:00Z",
"quota_weekly_reset_at": "2026-03-16T09:00:00Z",
}
ComputeQuotaResetAt(extra)
_, hasDailyResetAt := extra["quota_daily_reset_at"]
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
assert.False(t, hasDailyResetAt, "rolling mode should remove quota_daily_reset_at")
assert.False(t, hasWeeklyResetAt, "rolling mode should remove quota_weekly_reset_at")
}
func TestComputeQuotaResetAt_FixedDaily_SetsResetAt(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "UTC",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok, "quota_daily_reset_at should be set")
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Reset time should be in the future
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
// Reset hour should be 9 UTC
assert.Equal(t, 9, resetAt.UTC().Hour())
}
func TestComputeQuotaResetAt_FixedWeekly_SetsResetAt(t *testing.T) {
extra := map[string]any{
"quota_weekly_reset_mode": "fixed",
"quota_weekly_reset_day": float64(1), // Monday
"quota_weekly_reset_hour": float64(0),
"quota_reset_timezone": "UTC",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_weekly_reset_at"].(string)
require.True(t, ok, "quota_weekly_reset_at should be set")
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Reset time should be in the future
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
// Reset day should be Monday
assert.Equal(t, time.Monday, resetAt.UTC().Weekday())
}
func TestComputeQuotaResetAt_FixedDaily_WithTimezone(t *testing.T) {
tz, err := time.LoadLocation("Asia/Shanghai")
require.NoError(t, err)
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(9),
"quota_reset_timezone": "Asia/Shanghai",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok)
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// In Shanghai timezone, the hour should be 9
assert.Equal(t, 9, resetAt.In(tz).Hour())
}
func TestComputeQuotaResetAt_DefaultTimezone(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(12),
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok)
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Default timezone is UTC
assert.Equal(t, 12, resetAt.UTC().Hour())
}
func TestComputeQuotaResetAt_InvalidHour_ClampedToZero(t *testing.T) {
extra := map[string]any{
"quota_daily_reset_mode": "fixed",
"quota_daily_reset_hour": float64(99),
"quota_reset_timezone": "UTC",
}
ComputeQuotaResetAt(extra)
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
require.True(t, ok)
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
require.NoError(t, err)
// Invalid hour → clamped to 0
assert.Equal(t, 0, resetAt.UTC().Hour())
}

View File

@@ -0,0 +1,120 @@
package service
import (
"encoding/json"
"testing"
)
func TestGetBaseRPM(t *testing.T) {
tests := []struct {
name string
extra map[string]any
expected int
}{
{"nil extra", nil, 0},
{"no key", map[string]any{}, 0},
{"zero", map[string]any{"base_rpm": 0}, 0},
{"int value", map[string]any{"base_rpm": 15}, 15},
{"float value", map[string]any{"base_rpm": 15.0}, 15},
{"string value", map[string]any{"base_rpm": "15"}, 15},
{"negative value", map[string]any{"base_rpm": -5}, 0},
{"int64 value", map[string]any{"base_rpm": int64(20)}, 20},
{"json.Number value", map[string]any{"base_rpm": json.Number("25")}, 25},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Account{Extra: tt.extra}
if got := a.GetBaseRPM(); got != tt.expected {
t.Errorf("GetBaseRPM() = %d, want %d", got, tt.expected)
}
})
}
}
func TestGetRPMStrategy(t *testing.T) {
tests := []struct {
name string
extra map[string]any
expected string
}{
{"nil extra", nil, "tiered"},
{"no key", map[string]any{}, "tiered"},
{"tiered", map[string]any{"rpm_strategy": "tiered"}, "tiered"},
{"sticky_exempt", map[string]any{"rpm_strategy": "sticky_exempt"}, "sticky_exempt"},
{"invalid", map[string]any{"rpm_strategy": "foobar"}, "tiered"},
{"empty string fallback", map[string]any{"rpm_strategy": ""}, "tiered"},
{"numeric value fallback", map[string]any{"rpm_strategy": 123}, "tiered"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Account{Extra: tt.extra}
if got := a.GetRPMStrategy(); got != tt.expected {
t.Errorf("GetRPMStrategy() = %q, want %q", got, tt.expected)
}
})
}
}
func TestCheckRPMSchedulability(t *testing.T) {
tests := []struct {
name string
extra map[string]any
currentRPM int
expected WindowCostSchedulability
}{
{"disabled", map[string]any{}, 100, WindowCostSchedulable},
{"green zone", map[string]any{"base_rpm": 15}, 10, WindowCostSchedulable},
{"yellow zone tiered", map[string]any{"base_rpm": 15}, 15, WindowCostStickyOnly},
{"red zone tiered", map[string]any{"base_rpm": 15}, 18, WindowCostNotSchedulable},
{"sticky_exempt at limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 15, WindowCostStickyOnly},
{"sticky_exempt over limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 100, WindowCostStickyOnly},
{"custom buffer", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 14, WindowCostStickyOnly},
{"custom buffer red", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 15, WindowCostNotSchedulable},
{"base_rpm=1 green", map[string]any{"base_rpm": 1}, 0, WindowCostSchedulable},
{"base_rpm=1 yellow (at limit)", map[string]any{"base_rpm": 1}, 1, WindowCostStickyOnly},
{"base_rpm=1 red (at limit+buffer)", map[string]any{"base_rpm": 1}, 2, WindowCostNotSchedulable},
{"negative currentRPM", map[string]any{"base_rpm": 15}, -1, WindowCostSchedulable},
{"base_rpm negative disabled", map[string]any{"base_rpm": -5}, 10, WindowCostSchedulable},
{"very high currentRPM", map[string]any{"base_rpm": 10}, 9999, WindowCostNotSchedulable},
{"sticky_exempt very high currentRPM", map[string]any{"base_rpm": 10, "rpm_strategy": "sticky_exempt"}, 9999, WindowCostStickyOnly},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Account{Extra: tt.extra}
if got := a.CheckRPMSchedulability(tt.currentRPM); got != tt.expected {
t.Errorf("CheckRPMSchedulability(%d) = %d, want %d", tt.currentRPM, got, tt.expected)
}
})
}
}
func TestGetRPMStickyBuffer(t *testing.T) {
tests := []struct {
name string
extra map[string]any
expected int
}{
{"nil extra", nil, 0},
{"no keys", map[string]any{}, 0},
{"base_rpm=0", map[string]any{"base_rpm": 0}, 0},
{"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1},
{"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1},
{"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1},
{"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2},
{"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3},
{"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20},
{"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5},
{"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2},
{"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2},
{"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7},
{"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
a := &Account{Extra: tt.extra}
if got := a.GetRPMStickyBuffer(); got != tt.expected {
t.Errorf("GetRPMStickyBuffer() = %d, want %d", got, tt.expected)
}
})
}
}

View File

@@ -0,0 +1,397 @@
package service
import (
"context"
"fmt"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
var (
ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
)
type AccountRepository interface {
Create(ctx context.Context, account *Account) error
GetByID(ctx context.Context, id int64) (*Account, error)
// GetByIDs fetches accounts by IDs in a single query.
// It should return all accounts found (missing IDs are ignored).
GetByIDs(ctx context.Context, ids []int64) ([]*Account, error)
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
ExistsByID(ctx context.Context, id int64) (bool, error)
// GetByCRSAccountID finds an account previously synced from CRS.
// Returns (nil, nil) if not found.
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora'
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
// for all accounts that have been synced from CRS.
ListCRSAccountIDs(ctx context.Context) (map[string]int64, error)
Update(ctx context.Context, account *Account) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
ListActive(ctx context.Context) ([]Account, error)
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
UpdateLastUsed(ctx context.Context, id int64) error
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
SetError(ctx context.Context, id int64, errorMsg string) error
ClearError(ctx context.Context, id int64) error
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
ListSchedulable(ctx context.Context) ([]Account, error)
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error)
ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
SetOverloaded(ctx context.Context, id int64, until time.Time) error
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
ClearTempUnschedulable(ctx context.Context, id int64) error
ClearRateLimit(ctx context.Context, id int64) error
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
ClearModelRateLimits(ctx context.Context, id int64) error
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量(总/日/周)
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
// ResetQuotaUsed 重置 API Key 账号所有维度的配额用量为 0
ResetQuotaUsed(ctx context.Context, id int64) error
}
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
// Nil pointers mean "do not change".
type AccountBulkUpdate struct {
Name *string
ProxyID *int64
Concurrency *int
Priority *int
RateMultiplier *float64
LoadFactor *int
Status *string
Schedulable *bool
Credentials map[string]any
Extra map[string]any
}
// CreateAccountRequest 创建账号请求
type CreateAccountRequest struct {
Name string `json:"name"`
Notes *string `json:"notes"`
Platform string `json:"platform"`
Type string `json:"type"`
Credentials map[string]any `json:"credentials"`
Extra map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency int `json:"concurrency"`
Priority int `json:"priority"`
GroupIDs []int64 `json:"group_ids"`
ExpiresAt *time.Time `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
}
// UpdateAccountRequest 更新账号请求
type UpdateAccountRequest struct {
Name *string `json:"name"`
Notes *string `json:"notes"`
Credentials *map[string]any `json:"credentials"`
Extra *map[string]any `json:"extra"`
ProxyID *int64 `json:"proxy_id"`
Concurrency *int `json:"concurrency"`
Priority *int `json:"priority"`
Status *string `json:"status"`
GroupIDs *[]int64 `json:"group_ids"`
ExpiresAt *time.Time `json:"expires_at"`
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
}
// AccountService 账号管理服务
type AccountService struct {
accountRepo AccountRepository
groupRepo GroupRepository
}
type groupExistenceBatchChecker interface {
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
}
// NewAccountService 创建账号服务实例
func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService {
return &AccountService{
accountRepo: accountRepo,
groupRepo: groupRepo,
}
}
// Create 创建账号
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
// 验证分组是否存在(如果指定了分组)
if len(req.GroupIDs) > 0 {
if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil {
return nil, err
}
}
// 创建账号
account := &Account{
Name: req.Name,
Notes: normalizeAccountNotes(req.Notes),
Platform: req.Platform,
Type: req.Type,
Credentials: req.Credentials,
Extra: req.Extra,
ProxyID: req.ProxyID,
Concurrency: req.Concurrency,
Priority: req.Priority,
Status: StatusActive,
ExpiresAt: req.ExpiresAt,
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
} else {
account.AutoPauseOnExpired = true
}
if err := s.accountRepo.Create(ctx, account); err != nil {
return nil, fmt.Errorf("create account: %w", err)
}
// 绑定分组
if len(req.GroupIDs) > 0 {
if err := s.accountRepo.BindGroups(ctx, account.ID, req.GroupIDs); err != nil {
return nil, fmt.Errorf("bind groups: %w", err)
}
}
return account, nil
}
// GetByID 根据ID获取账号
func (s *AccountService) GetByID(ctx context.Context, id int64) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get account: %w", err)
}
return account, nil
}
// List 获取账号列表
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
accounts, pagination, err := s.accountRepo.List(ctx, params)
if err != nil {
return nil, nil, fmt.Errorf("list accounts: %w", err)
}
return accounts, pagination, nil
}
// ListByPlatform 根据平台获取账号列表
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
accounts, err := s.accountRepo.ListByPlatform(ctx, platform)
if err != nil {
return nil, fmt.Errorf("list accounts by platform: %w", err)
}
return accounts, nil
}
// ListByGroup 根据分组获取账号列表
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
if err != nil {
return nil, fmt.Errorf("list accounts by group: %w", err)
}
return accounts, nil
}
// Update 更新账号
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*Account, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get account: %w", err)
}
// 更新字段
if req.Name != nil {
account.Name = *req.Name
}
if req.Notes != nil {
account.Notes = normalizeAccountNotes(req.Notes)
}
if req.Credentials != nil {
account.Credentials = *req.Credentials
}
if req.Extra != nil {
account.Extra = *req.Extra
}
if req.ProxyID != nil {
account.ProxyID = req.ProxyID
}
if req.Concurrency != nil {
account.Concurrency = *req.Concurrency
}
if req.Priority != nil {
account.Priority = *req.Priority
}
if req.Status != nil {
account.Status = *req.Status
}
if req.ExpiresAt != nil {
account.ExpiresAt = req.ExpiresAt
}
if req.AutoPauseOnExpired != nil {
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
}
// 先验证分组是否存在(在任何写操作之前)
if req.GroupIDs != nil {
if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil {
return nil, err
}
}
// 执行更新
if err := s.accountRepo.Update(ctx, account); err != nil {
return nil, fmt.Errorf("update account: %w", err)
}
// 绑定分组
if req.GroupIDs != nil {
if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil {
return nil, fmt.Errorf("bind groups: %w", err)
}
}
return account, nil
}
// Delete 删除账号
// 优化:使用 ExistsByID 替代 GetByID 进行存在性检查,
// 避免加载完整账号对象及其关联数据,提升删除操作的性能
func (s *AccountService) Delete(ctx context.Context, id int64) error {
// 使用轻量级的存在性检查,而非加载完整账号对象
exists, err := s.accountRepo.ExistsByID(ctx, id)
if err != nil {
return fmt.Errorf("check account: %w", err)
}
// 明确返回账号不存在错误,便于调用方区分错误类型
if !exists {
return ErrAccountNotFound
}
if err := s.accountRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete account: %w", err)
}
return nil
}
func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
if len(groupIDs) == 0 {
return nil
}
if s.groupRepo == nil {
return fmt.Errorf("group repository not configured")
}
if batchChecker, ok := s.groupRepo.(groupExistenceBatchChecker); ok {
existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs)
if err != nil {
return fmt.Errorf("check groups exists: %w", err)
}
for _, groupID := range groupIDs {
if groupID <= 0 {
return fmt.Errorf("get group: %w", ErrGroupNotFound)
}
if !existsByID[groupID] {
return fmt.Errorf("get group: %w", ErrGroupNotFound)
}
}
return nil
}
for _, groupID := range groupIDs {
_, err := s.groupRepo.GetByID(ctx, groupID)
if err != nil {
return fmt.Errorf("get group: %w", err)
}
}
return nil
}
// UpdateStatus 更新账号状态
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return fmt.Errorf("get account: %w", err)
}
account.Status = status
account.ErrorMessage = errorMessage
if err := s.accountRepo.Update(ctx, account); err != nil {
return fmt.Errorf("update account: %w", err)
}
return nil
}
// UpdateLastUsed 更新最后使用时间
func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error {
if err := s.accountRepo.UpdateLastUsed(ctx, id); err != nil {
return fmt.Errorf("update last used: %w", err)
}
return nil
}
// GetCredential 获取账号凭证(安全访问)
func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return "", fmt.Errorf("get account: %w", err)
}
return account.GetCredential(key), nil
}
// TestCredentials 测试账号凭证是否有效(需要实现具体平台的测试逻辑)
func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
account, err := s.accountRepo.GetByID(ctx, id)
if err != nil {
return fmt.Errorf("get account: %w", err)
}
// 根据平台执行不同的测试逻辑
switch account.Platform {
case PlatformAnthropic:
// TODO: 测试Anthropic API凭证
return nil
case PlatformOpenAI:
// TODO: 测试OpenAI API凭证
return nil
case PlatformGemini:
// TODO: 测试Gemini API凭证
return nil
default:
return fmt.Errorf("unsupported platform: %s", account.Platform)
}
}

View File

@@ -0,0 +1,271 @@
//go:build unit
// 账号服务删除方法的单元测试
// 测试 AccountService.Delete 方法在各种场景下的行为
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// accountRepoStub 是 AccountRepository 接口的测试桩实现。
// 用于隔离测试 AccountService.Delete 方法,避免依赖真实数据库。
//
// 设计说明:
// - exists: 模拟 ExistsByID 返回的存在性结果
// - existsErr: 模拟 ExistsByID 返回的错误
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的账号 ID用于断言验证
type accountRepoStub struct {
exists bool // ExistsByID 的返回值
existsErr error // ExistsByID 的错误返回值
deleteErr error // Delete 的错误返回值
deletedIDs []int64 // 记录已删除的账号 ID 列表
}
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
func (s *accountRepoStub) Create(ctx context.Context, account *Account) error {
panic("unexpected Create call")
}
func (s *accountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
panic("unexpected GetByID call")
}
func (s *accountRepoStub) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
panic("unexpected GetByIDs call")
}
// ExistsByID 返回预设的存在性检查结果。
// 这是 Delete 方法调用的第一个仓储方法,用于验证账号是否存在。
func (s *accountRepoStub) ExistsByID(ctx context.Context, id int64) (bool, error) {
return s.exists, s.existsErr
}
func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
panic("unexpected GetByCRSAccountID call")
}
func (s *accountRepoStub) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) {
panic("unexpected FindByExtraField call")
}
func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
panic("unexpected ListCRSAccountIDs call")
}
func (s *accountRepoStub) Update(ctx context.Context, account *Account) error {
panic("unexpected Update call")
}
// Delete 记录被删除的账号 ID 并返回预设的错误。
// 通过 deletedIDs 可以验证删除操作是否被正确调用。
func (s *accountRepoStub) Delete(ctx context.Context, id int64) error {
s.deletedIDs = append(s.deletedIDs, id)
return s.deleteErr
}
// 以下是接口要求实现但本测试不关心的方法
func (s *accountRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *accountRepoStub) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
panic("unexpected ListByGroup call")
}
func (s *accountRepoStub) ListActive(ctx context.Context) ([]Account, error) {
panic("unexpected ListActive call")
}
func (s *accountRepoStub) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
panic("unexpected ListByPlatform call")
}
func (s *accountRepoStub) UpdateLastUsed(ctx context.Context, id int64) error {
panic("unexpected UpdateLastUsed call")
}
func (s *accountRepoStub) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
panic("unexpected BatchUpdateLastUsed call")
}
func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
panic("unexpected SetError call")
}
func (s *accountRepoStub) ClearError(ctx context.Context, id int64) error {
panic("unexpected ClearError call")
}
func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
panic("unexpected SetSchedulable call")
}
func (s *accountRepoStub) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
panic("unexpected AutoPauseExpiredAccounts call")
}
func (s *accountRepoStub) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
panic("unexpected BindGroups call")
}
func (s *accountRepoStub) ListSchedulable(ctx context.Context) ([]Account, error) {
panic("unexpected ListSchedulable call")
}
func (s *accountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
panic("unexpected ListSchedulableByGroupID call")
}
func (s *accountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
panic("unexpected ListSchedulableByPlatform call")
}
func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
panic("unexpected ListSchedulableByGroupIDAndPlatform call")
}
func (s *accountRepoStub) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
panic("unexpected ListSchedulableByPlatforms call")
}
func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
panic("unexpected ListSchedulableByGroupIDAndPlatforms call")
}
func (s *accountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
panic("unexpected ListSchedulableUngroupedByPlatform call")
}
func (s *accountRepoStub) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
panic("unexpected ListSchedulableUngroupedByPlatforms call")
}
func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
panic("unexpected SetRateLimited call")
}
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
panic("unexpected SetModelRateLimit call")
}
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
panic("unexpected SetOverloaded call")
}
func (s *accountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
panic("unexpected SetTempUnschedulable call")
}
func (s *accountRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error {
panic("unexpected ClearTempUnschedulable call")
}
func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
panic("unexpected ClearRateLimit call")
}
func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
panic("unexpected ClearAntigravityQuotaScopes call")
}
func (s *accountRepoStub) ClearModelRateLimits(ctx context.Context, id int64) error {
panic("unexpected ClearModelRateLimits call")
}
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
panic("unexpected UpdateSessionWindow call")
}
func (s *accountRepoStub) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
panic("unexpected UpdateExtra call")
}
func (s *accountRepoStub) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
panic("unexpected BulkUpdate call")
}
func (s *accountRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
return nil
}
func (s *accountRepoStub) ResetQuotaUsed(ctx context.Context, id int64) error {
return nil
}
// TestAccountService_Delete_NotFound 测试删除不存在的账号时返回正确的错误。
// 预期行为:
// - ExistsByID 返回 false账号不存在
// - 返回 ErrAccountNotFound 错误
// - Delete 方法不被调用deletedIDs 为空)
func TestAccountService_Delete_NotFound(t *testing.T) {
repo := &accountRepoStub{exists: false}
svc := &AccountService{accountRepo: repo}
err := svc.Delete(context.Background(), 55)
require.ErrorIs(t, err, ErrAccountNotFound)
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
}
// TestAccountService_Delete_CheckError 测试存在性检查失败时的错误处理。
// 预期行为:
// - ExistsByID 返回数据库错误
// - 返回包含 "check account" 的错误信息
// - Delete 方法不被调用
func TestAccountService_Delete_CheckError(t *testing.T) {
repo := &accountRepoStub{existsErr: errors.New("db down")}
svc := &AccountService{accountRepo: repo}
err := svc.Delete(context.Background(), 55)
require.Error(t, err)
require.ErrorContains(t, err, "check account") // 验证错误信息包含上下文
require.Empty(t, repo.deletedIDs)
}
// TestAccountService_Delete_DeleteError 测试删除操作失败时的错误处理。
// 预期行为:
// - ExistsByID 返回 true账号存在
// - Delete 被调用但返回错误
// - 返回包含 "delete account" 的错误信息
// - deletedIDs 记录了尝试删除的 ID
func TestAccountService_Delete_DeleteError(t *testing.T) {
repo := &accountRepoStub{
exists: true,
deleteErr: errors.New("delete failed"),
}
svc := &AccountService{accountRepo: repo}
err := svc.Delete(context.Background(), 55)
require.Error(t, err)
require.ErrorContains(t, err, "delete account")
require.Equal(t, []int64{55}, repo.deletedIDs) // 验证删除操作被调用
}
// TestAccountService_Delete_Success 测试删除操作成功的场景。
// 预期行为:
// - ExistsByID 返回 true账号存在
// - Delete 成功执行
// - 返回 nil 错误
// - deletedIDs 记录了被删除的 ID
func TestAccountService_Delete_Success(t *testing.T) {
repo := &accountRepoStub{exists: true}
svc := &AccountService{accountRepo: repo}
err := svc.Delete(context.Background(), 55)
require.NoError(t, err)
require.Equal(t, []int64{55}, repo.deletedIDs) // 验证正确的 ID 被删除
}

View File

@@ -0,0 +1,59 @@
//go:build unit
package service
import (
"encoding/json"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
func TestCreateGeminiTestPayload_ImageModel(t *testing.T) {
t.Parallel()
payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot")
var parsed struct {
Contents []struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
} `json:"contents"`
GenerationConfig struct {
ResponseModalities []string `json:"responseModalities"`
ImageConfig struct {
AspectRatio string `json:"aspectRatio"`
} `json:"imageConfig"`
} `json:"generationConfig"`
}
require.NoError(t, json.Unmarshal(payload, &parsed))
require.Len(t, parsed.Contents, 1)
require.Len(t, parsed.Contents[0].Parts, 1)
require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text)
require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities)
require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio)
}
func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
t.Parallel()
gin.SetMode(gin.TestMode)
ctx, recorder := newSoraTestContext()
svc := &AccountTestService{}
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
err := svc.processGeminiStream(ctx, stream)
require.NoError(t, err)
body := recorder.Body.String()
require.Contains(t, body, "\"type\":\"content\"")
require.Contains(t, body, "\"text\":\"ok\"")
require.Contains(t, body, "\"type\":\"image\"")
require.Contains(t, body, "\"image_url\":\"data:image/png;base64,QUJD\"")
require.Contains(t, body, "\"mime_type\":\"image/png\"")
}

View File

@@ -0,0 +1,102 @@
//go:build unit
package service
import (
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type openAIAccountTestRepo struct {
mockAccountRepoForGemini
updatedExtra map[string]any
rateLimitedID int64
rateLimitedAt *time.Time
}
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
r.updatedExtra = updates
return nil
}
func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
r.rateLimitedID = id
r.rateLimitedAt = &resetAt
return nil
}
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, recorder := newSoraTestContext()
resp := newJSONResponse(http.StatusOK, "")
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
`))
resp.Header.Set("x-codex-primary-used-percent", "88")
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
resp.Header.Set("x-codex-primary-window-minutes", "10080")
resp.Header.Set("x-codex-secondary-used-percent", "42")
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
resp.Header.Set("x-codex-secondary-window-minutes", "300")
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 89,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
require.NoError(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
require.Equal(t, 88.0, repo.updatedExtra["codex_7d_used_percent"])
require.Contains(t, recorder.Body.String(), "test_complete")
}
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
gin.SetMode(gin.TestMode)
ctx, _ := newSoraTestContext()
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
resp.Header.Set("x-codex-primary-used-percent", "100")
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
resp.Header.Set("x-codex-primary-window-minutes", "10080")
resp.Header.Set("x-codex-secondary-used-percent", "100")
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
resp.Header.Set("x-codex-secondary-window-minutes", "300")
repo := &openAIAccountTestRepo{}
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
account := &Account{
ID: 88,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{"access_token": "test-token"},
}
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
require.Error(t, err)
require.NotEmpty(t, repo.updatedExtra)
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
require.Equal(t, int64(88), repo.rateLimitedID)
require.NotNil(t, repo.rateLimitedAt)
require.NotNil(t, account.RateLimitResetAt)
if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil {
require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second)
}
}

View File

@@ -0,0 +1,319 @@
package service
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/require"
)
type queuedHTTPUpstream struct {
responses []*http.Response
requests []*http.Request
tlsFlags []bool
}
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
return nil, fmt.Errorf("unexpected Do call")
}
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) {
u.requests = append(u.requests, req)
u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint)
if len(u.responses) == 0 {
return nil, fmt.Errorf("no mocked response")
}
resp := u.responses[0]
u.responses = u.responses[1:]
return resp, nil
}
func newJSONResponse(status int, body string) *http.Response {
return &http.Response{
StatusCode: status,
Header: make(http.Header),
Body: io.NopCloser(strings.NewReader(body)),
}
}
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
resp := newJSONResponse(status, body)
resp.Header.Set(key, value)
return resp
}
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
gin.SetMode(gin.TestMode)
rec := httptest.NewRecorder()
c, _ := gin.CreateTestContext(rec)
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
return c, rec
}
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
cfg: &config.Config{
Gateway: config.GatewayConfig{
TLSFingerprint: config.TLSFingerprintConfig{
Enabled: true,
},
},
Sora: config.SoraConfig{
Client: config.SoraClientConfig{
DisableTLSFingerprint: false,
},
},
},
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 4)
require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String())
require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String())
require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags)
body := rec.Body.String()
require.Contains(t, body, `"type":"test_start"`)
require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
require.Len(t, upstream.requests, 4)
body := rec.Body.String()
require.Contains(t, body, "Sora connection OK - User: demo-user")
require.Contains(t, body, "Subscription check returned 403")
require.Contains(t, body, "Sora2 invite check returned 401")
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"partial_success"`)
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
body := rec.Body.String()
require.Contains(t, body, `"type":"error"`)
require.Contains(t, body, "Cloudflare challenge")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
}
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponseWithHeader(http.StatusTooManyRequests, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`, "cf-mitigated", "challenge"),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "Cloudflare challenge")
require.Contains(t, err.Error(), "HTTP 429")
body := rec.Body.String()
require.Contains(t, body, "Cloudflare challenge")
}
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.Error(t, err)
require.Contains(t, err.Error(), "token_invalidated")
body := rec.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"status":"failed"`)
require.Contains(t, body, "token_invalidated")
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
},
}
svc := &AccountTestService{
httpUpstream: upstream,
soraTestCooldown: time.Hour,
}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c1, _ := newSoraTestContext()
err := svc.testSoraAccountConnection(c1, account)
require.NoError(t, err)
c2, rec2 := newSoraTestContext()
err = svc.testSoraAccountConnection(c2, account)
require.Error(t, err)
require.Contains(t, err.Error(), "测试过于频繁")
body := rec2.Body.String()
require.Contains(t, body, `"type":"sora_test_result"`)
require.Contains(t, body, `"code":"test_rate_limited"`)
require.Contains(t, body, `"status":"failed"`)
require.NotContains(t, body, `"type":"test_complete","success":true`)
}
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
upstream := &queuedHTTPUpstream{
responses: []*http.Response{
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
},
}
svc := &AccountTestService{httpUpstream: upstream}
account := &Account{
ID: 1,
Platform: PlatformSora,
Type: AccountTypeOAuth,
Concurrency: 1,
Credentials: map[string]any{
"access_token": "test_token",
},
}
c, rec := newSoraTestContext()
err := svc.testSoraAccountConnection(c, account)
require.NoError(t, err)
body := rec.Body.String()
require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)")
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
require.Contains(t, body, `"type":"test_complete","success":true`)
}
func TestSanitizeProxyURLForLog(t *testing.T) {
require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
require.Equal(t, "", sanitizeProxyURLForLog(""))
require.Equal(t, "<invalid_proxy_url>", sanitizeProxyURLForLog("://invalid"))
}
func TestExtractSoraEgressIPHint(t *testing.T) {
h := make(http.Header)
h.Set("x-openai-public-ip", "203.0.113.10")
require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))
h2 := make(http.Header)
h2.Set("x-envoy-external-address", "198.51.100.9")
require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))
require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
}

View File

@@ -0,0 +1,150 @@
package service
import (
"context"
"net/http"
"testing"
"time"
)
type accountUsageCodexProbeRepo struct {
stubOpenAIAccountRepo
updateExtraCh chan map[string]any
rateLimitCh chan time.Time
}
func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
if r.updateExtraCh != nil {
copied := make(map[string]any, len(updates))
for k, v := range updates {
copied[k] = v
}
r.updateExtraCh <- copied
}
return nil
}
func (r *accountUsageCodexProbeRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
if r.rateLimitCh != nil {
r.rateLimitCh <- resetAt
}
return nil
}
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
t.Parallel()
rateLimitedUntil := time.Now().Add(5 * time.Minute)
now := time.Now()
usage := &UsageInfo{
FiveHour: &UsageProgress{Utilization: 0},
SevenDay: &UsageProgress{Utilization: 0},
}
if !shouldRefreshOpenAICodexSnapshot(&Account{RateLimitResetAt: &rateLimitedUntil}, usage, now) {
t.Fatal("expected rate-limited account to force codex snapshot refresh")
}
if shouldRefreshOpenAICodexSnapshot(&Account{}, usage, now) {
t.Fatal("expected complete non-rate-limited usage to skip codex snapshot refresh")
}
if !shouldRefreshOpenAICodexSnapshot(&Account{}, &UsageInfo{FiveHour: nil, SevenDay: &UsageProgress{}}, now) {
t.Fatal("expected missing 5h snapshot to require refresh")
}
staleAt := now.Add(-(openAIProbeCacheTTL + time.Minute)).Format(time.RFC3339)
if !shouldRefreshOpenAICodexSnapshot(&Account{
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
Extra: map[string]any{
"openai_oauth_responses_websockets_v2_enabled": true,
"codex_usage_updated_at": staleAt,
},
}, usage, now) {
t.Fatal("expected stale ws snapshot to trigger refresh")
}
}
func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) {
t.Parallel()
headers := make(http.Header)
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "604800")
headers.Set("x-codex-primary-window-minutes", "10080")
headers.Set("x-codex-secondary-used-percent", "100")
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
headers.Set("x-codex-secondary-window-minutes", "300")
updates, err := extractOpenAICodexProbeUpdates(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
if err != nil {
t.Fatalf("extractOpenAICodexProbeUpdates() error = %v", err)
}
if len(updates) == 0 {
t.Fatal("expected codex probe updates from 429 headers")
}
if got := updates["codex_5h_used_percent"]; got != 100.0 {
t.Fatalf("codex_5h_used_percent = %v, want 100", got)
}
if got := updates["codex_7d_used_percent"]; got != 100.0 {
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
}
}
func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) {
t.Parallel()
headers := make(http.Header)
headers.Set("x-codex-primary-used-percent", "100")
headers.Set("x-codex-primary-reset-after-seconds", "604800")
headers.Set("x-codex-primary-window-minutes", "10080")
headers.Set("x-codex-secondary-used-percent", "100")
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
headers.Set("x-codex-secondary-window-minutes", "300")
updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
if err != nil {
t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err)
}
if len(updates) == 0 {
t.Fatal("expected codex probe updates from 429 headers")
}
if resetAt == nil {
t.Fatal("expected resetAt from exhausted codex headers")
}
}
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) {
t.Parallel()
repo := &accountUsageCodexProbeRepo{
updateExtraCh: make(chan map[string]any, 1),
rateLimitCh: make(chan time.Time, 1),
}
svc := &AccountUsageService{accountRepo: repo}
resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
"codex_7d_used_percent": 100.0,
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
}, &resetAt)
select {
case updates := <-repo.updateExtraCh:
if got := updates["codex_7d_used_percent"]; got != 100.0 {
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
}
case <-time.After(2 * time.Second):
t.Fatal("waiting for codex probe extra persistence timed out")
}
select {
case got := <-repo.rateLimitCh:
if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) {
t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt)
}
case <-time.After(2 * time.Second):
t.Fatal("waiting for codex probe rate limit persistence timed out")
}
}

View File

@@ -0,0 +1,455 @@
//go:build unit
package service
import (
"testing"
)
func TestMatchWildcard(t *testing.T) {
tests := []struct {
name string
pattern string
str string
expected bool
}{
// 精确匹配
{"exact match", "claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"exact mismatch", "claude-sonnet-4-5", "claude-opus-4-5", false},
// 通配符匹配
{"wildcard prefix match", "claude-*", "claude-sonnet-4-5", true},
{"wildcard prefix match 2", "claude-*", "claude-opus-4-5-thinking", true},
{"wildcard prefix mismatch", "claude-*", "gemini-3-flash", false},
{"wildcard partial match", "gemini-3*", "gemini-3-flash", true},
{"wildcard partial match 2", "gemini-3*", "gemini-3-pro-image", true},
{"wildcard partial mismatch", "gemini-3*", "gemini-2.5-flash", false},
// 边界情况
{"empty pattern exact", "", "", true},
{"empty pattern mismatch", "", "claude", false},
{"single star", "*", "anything", true},
{"star at end only", "abc*", "abcdef", true},
{"star at end empty suffix", "abc*", "abc", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := matchWildcard(tt.pattern, tt.str)
if result != tt.expected {
t.Errorf("matchWildcard(%q, %q) = %v, want %v", tt.pattern, tt.str, result, tt.expected)
}
})
}
}
func TestMatchWildcardMappingResult(t *testing.T) {
tests := []struct {
name string
mapping map[string]string
requestedModel string
expected string
matched bool
}{
// 精确匹配优先于通配符
{
name: "exact match takes precedence",
mapping: map[string]string{
"claude-sonnet-4-5": "claude-sonnet-4-5-exact",
"claude-*": "claude-default",
},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5-exact",
matched: true,
},
// 最长通配符优先
{
name: "longer wildcard takes precedence",
mapping: map[string]string{
"claude-*": "claude-default",
"claude-sonnet-*": "claude-sonnet-default",
"claude-sonnet-4*": "claude-sonnet-4-series",
},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-series",
matched: true,
},
// 单个通配符
{
name: "single wildcard",
mapping: map[string]string{
"claude-*": "claude-mapped",
},
requestedModel: "claude-opus-4-5",
expected: "claude-mapped",
matched: true,
},
// 无匹配返回原始模型
{
name: "no match returns original",
mapping: map[string]string{
"claude-*": "claude-mapped",
},
requestedModel: "gemini-3-flash",
expected: "gemini-3-flash",
matched: false,
},
// 空映射返回原始模型
{
name: "empty mapping returns original",
mapping: map[string]string{},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5",
matched: false,
},
// Gemini 模型映射
{
name: "gemini wildcard mapping",
mapping: map[string]string{
"gemini-3*": "gemini-3-pro-high",
"gemini-2.5*": "gemini-2.5-flash",
},
requestedModel: "gemini-3-flash-preview",
expected: "gemini-3-pro-high",
matched: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel)
if result != tt.expected || matched != tt.matched {
t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched)
}
})
}
}
func TestAccountIsModelSupported(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
requestedModel string
expected bool
}{
// 无映射 = 允许所有
{
name: "no mapping allows all",
credentials: nil,
requestedModel: "any-model",
expected: true,
},
{
name: "empty mapping allows all",
credentials: map[string]any{},
requestedModel: "any-model",
expected: true,
},
// 精确匹配
{
name: "exact match supported",
credentials: map[string]any{
"model_mapping": map[string]any{
"claude-sonnet-4-5": "target-model",
},
},
requestedModel: "claude-sonnet-4-5",
expected: true,
},
{
name: "exact match not supported",
credentials: map[string]any{
"model_mapping": map[string]any{
"claude-sonnet-4-5": "target-model",
},
},
requestedModel: "claude-opus-4-5",
expected: false,
},
// 通配符匹配
{
name: "wildcard match supported",
credentials: map[string]any{
"model_mapping": map[string]any{
"claude-*": "claude-sonnet-4-5",
},
},
requestedModel: "claude-opus-4-5-thinking",
expected: true,
},
{
name: "wildcard match not supported",
credentials: map[string]any{
"model_mapping": map[string]any{
"claude-*": "claude-sonnet-4-5",
},
},
requestedModel: "gemini-3-flash",
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Credentials: tt.credentials,
}
result := account.IsModelSupported(tt.requestedModel)
if result != tt.expected {
t.Errorf("IsModelSupported(%q) = %v, want %v", tt.requestedModel, result, tt.expected)
}
})
}
}
func TestAccountGetMappedModel(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
requestedModel string
expected string
}{
// 无映射 = 返回原始模型
{
name: "no mapping returns original",
credentials: nil,
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5",
},
// 精确匹配
{
name: "exact match",
credentials: map[string]any{
"model_mapping": map[string]any{
"claude-sonnet-4-5": "target-model",
},
},
requestedModel: "claude-sonnet-4-5",
expected: "target-model",
},
// 通配符匹配(最长优先)
{
name: "wildcard longest match",
credentials: map[string]any{
"model_mapping": map[string]any{
"claude-*": "claude-default",
"claude-sonnet-*": "claude-sonnet-mapped",
},
},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-mapped",
},
// 无匹配返回原始模型
{
name: "no match returns original",
credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-*": "gemini-mapped",
},
},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Credentials: tt.credentials,
}
result := account.GetMappedModel(tt.requestedModel)
if result != tt.expected {
t.Errorf("GetMappedModel(%q) = %q, want %q", tt.requestedModel, result, tt.expected)
}
})
}
}
func TestAccountResolveMappedModel(t *testing.T) {
tests := []struct {
name string
credentials map[string]any
requestedModel string
expectedModel string
expectedMatch bool
}{
{
name: "no mapping reports unmatched",
credentials: nil,
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
{
name: "exact passthrough mapping still counts as matched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.4": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "wildcard passthrough mapping still counts as matched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-*": "gpt-5.4",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: true,
},
{
name: "missing mapping reports unmatched",
credentials: map[string]any{
"model_mapping": map[string]any{
"gpt-5.2": "gpt-5.2",
},
},
requestedModel: "gpt-5.4",
expectedModel: "gpt-5.4",
expectedMatch: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Credentials: tt.credentials,
}
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
if mappedModel != tt.expectedModel || matched != tt.expectedMatch {
t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch)
}
})
}
}
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3-pro-high": "gemini-3.1-pro-high",
},
},
}
mapping := account.GetModelMapping()
if mapping["gemini-3-flash"] != "gemini-3-flash" {
t.Fatalf("expected gemini-3-flash passthrough to be auto-filled, got: %q", mapping["gemini-3-flash"])
}
if mapping["gemini-3.1-pro-high"] != "gemini-3.1-pro-high" {
t.Fatalf("expected gemini-3.1-pro-high passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-high"])
}
if mapping["gemini-3.1-pro-low"] != "gemini-3.1-pro-low" {
t.Fatalf("expected gemini-3.1-pro-low passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-low"])
}
}
func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": map[string]any{
"gemini-3*": "gemini-3.1-pro-high",
},
},
}
mapping := account.GetModelMapping()
if _, exists := mapping["gemini-3-flash"]; exists {
t.Fatalf("did not expect explicit gemini-3-flash passthrough when wildcard already exists")
}
if _, exists := mapping["gemini-3.1-pro-high"]; exists {
t.Fatalf("did not expect explicit gemini-3.1-pro-high passthrough when wildcard already exists")
}
if _, exists := mapping["gemini-3.1-pro-low"]; exists {
t.Fatalf("did not expect explicit gemini-3.1-pro-low passthrough when wildcard already exists")
}
if mapped := account.GetMappedModel("gemini-3-flash"); mapped != "gemini-3.1-pro-high" {
t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped)
}
}
func TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace(t *testing.T) {
account := &Account{
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-3-5-sonnet": "upstream-a",
},
},
}
first := account.GetModelMapping()
if first["claude-3-5-sonnet"] != "upstream-a" {
t.Fatalf("unexpected first mapping: %v", first)
}
account.Credentials = map[string]any{
"model_mapping": map[string]any{
"claude-3-5-sonnet": "upstream-b",
},
}
second := account.GetModelMapping()
if second["claude-3-5-sonnet"] != "upstream-b" {
t.Fatalf("expected cache invalidated after credentials replace, got: %v", second)
}
}
func TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange(t *testing.T) {
rawMapping := map[string]any{
"claude-sonnet": "sonnet-a",
}
account := &Account{
Credentials: map[string]any{
"model_mapping": rawMapping,
},
}
first := account.GetModelMapping()
if len(first) != 1 {
t.Fatalf("unexpected first mapping length: %d", len(first))
}
rawMapping["claude-opus"] = "opus-b"
second := account.GetModelMapping()
if second["claude-opus"] != "opus-b" {
t.Fatalf("expected cache invalidated after mapping len change, got: %v", second)
}
}
func TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange(t *testing.T) {
rawMapping := map[string]any{
"claude-sonnet": "sonnet-a",
}
account := &Account{
Credentials: map[string]any{
"model_mapping": rawMapping,
},
}
first := account.GetModelMapping()
if first["claude-sonnet"] != "sonnet-a" {
t.Fatalf("unexpected first mapping: %v", first)
}
rawMapping["claude-sonnet"] = "sonnet-b"
second := account.GetModelMapping()
if second["claude-sonnet"] != "sonnet-b" {
t.Fatalf("expected cache invalidated after in-place value change, got: %v", second)
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,503 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// Stubs
// ---------------------------------------------------------------------------
// userRepoStubForGroupUpdate implements UserRepository for AdminUpdateAPIKeyGroupID tests.
type userRepoStubForGroupUpdate struct {
addGroupErr error
addGroupCalled bool
addedUserID int64
addedGroupID int64
}
func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context, userID int64, groupID int64) error {
s.addGroupCalled = true
s.addedUserID = userID
s.addedGroupID = groupID
return s.addGroupErr
}
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
panic("unexpected")
}
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
type apiKeyRepoStubForGroupUpdate struct {
key *APIKey
getErr error
updateErr error
updated *APIKey // captures what was passed to Update
}
func (s *apiKeyRepoStubForGroupUpdate) GetByID(_ context.Context, _ int64) (*APIKey, error) {
if s.getErr != nil {
return nil, s.getErr
}
clone := *s.key
return &clone, nil
}
func (s *apiKeyRepoStubForGroupUpdate) Update(_ context.Context, key *APIKey) error {
if s.updateErr != nil {
return s.updateErr
}
clone := *key
s.updated = &clone
return nil
}
// Unused methods panic on unexpected call.
func (s *apiKeyRepoStubForGroupUpdate) Create(context.Context, *APIKey) error { panic("unexpected") }
func (s *apiKeyRepoStubForGroupUpdate) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetByKey(context.Context, string) (*APIKey, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) CountByUserID(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ExistsByKey(context.Context, string) (bool, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListKeysByUserID(context.Context, int64) ([]string, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ListKeysByGroupID(context.Context, int64) ([]string, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) IncrementRateLimitUsage(context.Context, int64, float64) error {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, int64) error {
panic("unexpected")
}
func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
panic("unexpected")
}
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
type groupRepoStubForGroupUpdate struct {
group *Group
getErr error
lastGetByIDArg int64
}
func (s *groupRepoStubForGroupUpdate) GetByID(_ context.Context, id int64) (*Group, error) {
s.lastGetByIDArg = id
if s.getErr != nil {
return nil, s.getErr
}
clone := *s.group
return &clone, nil
}
// Unused methods panic on unexpected call.
func (s *groupRepoStubForGroupUpdate) Create(context.Context, *Group) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) GetByIDLite(context.Context, int64) (*Group, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) Update(context.Context, *Group) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
func (s *groupRepoStubForGroupUpdate) DeleteCascade(context.Context, int64) ([]int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListActive(context.Context) ([]Group, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, string) ([]Group, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) BindAccountsToGroup(context.Context, int64, []int64) error {
panic("unexpected")
}
func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
panic("unexpected")
}
type userSubRepoStubForGroupUpdate struct {
userSubRepoNoop
getActiveSub *UserSubscription
getActiveErr error
called bool
calledUserID int64
calledGroupID int64
}
func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) {
s.called = true
s.calledUserID = userID
s.calledGroupID = groupID
if s.getActiveErr != nil {
return nil, s.getActiveErr
}
if s.getActiveSub == nil {
return nil, ErrSubscriptionNotFound
}
clone := *s.getActiveSub
return &clone, nil
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
func TestAdminService_AdminUpdateAPIKeyGroupID_KeyNotFound(t *testing.T) {
repo := &apiKeyRepoStubForGroupUpdate{getErr: ErrAPIKeyNotFound}
svc := &adminServiceImpl{apiKeyRepo: repo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 999, int64Ptr(1))
require.ErrorIs(t, err, ErrAPIKeyNotFound)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NilGroupID_NoOp(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5)}
repo := &apiKeyRepoStubForGroupUpdate{key: existing}
svc := &adminServiceImpl{apiKeyRepo: repo}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, nil)
require.NoError(t, err)
require.Equal(t, int64(1), got.APIKey.ID)
// Update should NOT have been called (updated stays nil)
require.Nil(t, repo.updated)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5), Group: &Group{ID: 5, Name: "Old"}}
repo := &apiKeyRepoStubForGroupUpdate{key: existing}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: repo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
require.NoError(t, err)
require.Nil(t, got.APIKey.GroupID, "group_id should be nil after unbind")
require.Nil(t, got.APIKey.Group, "group object should be nil after unbind")
require.NotNil(t, repo.updated, "Update should have been called")
require.Nil(t, repo.updated.GroupID)
require.Equal(t, []string{"sk-test"}, cache.keys, "cache should be invalidated")
}
func TestAdminService_AdminUpdateAPIKeyGroupID_BindActiveGroup(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
require.Equal(t, []string{"sk-test"}, cache.keys)
// M3: verify correct group ID was passed to repo
require.Equal(t, int64(10), groupRepo.lastGetByIDArg)
// C1 fix: verify Group object is populated
require.NotNil(t, got.APIKey.Group)
require.Equal(t, "Pro", got.APIKey.Group.Name)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SameGroup_Idempotent(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Pro"}}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
// Update is still called (current impl doesn't short-circuit on same group)
require.NotNil(t, apiKeyRepo.updated)
require.Equal(t, []string{"sk-test"}, cache.keys)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotFound(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{getErr: ErrGroupNotFound}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(99))
require.ErrorIs(t, err, ErrGroupNotFound)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotActive(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 5, Status: StatusDisabled}}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(5))
require.Error(t, err)
require.Equal(t, "GROUP_NOT_ACTIVE", infraerrors.Reason(err))
}
func TestAdminService_AdminUpdateAPIKeyGroupID_UpdateFails(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(3)}
repo := &apiKeyRepoStubForGroupUpdate{key: existing, updateErr: errors.New("db write error")}
svc := &adminServiceImpl{apiKeyRepo: repo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
require.Error(t, err)
require.Contains(t, err.Error(), "update api key")
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NegativeGroupID(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(-5))
require.Error(t, err)
require.Equal(t, "INVALID_GROUP_ID", infraerrors.Reason(err))
}
func TestAdminService_AdminUpdateAPIKeyGroupID_PointerIsolation(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
inputGID := int64(10)
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, &inputGID)
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
// Mutating the input pointer must NOT affect the stored value
inputGID = 999
require.Equal(t, int64(10), *got.APIKey.GroupID)
require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NilCacheInvalidator(t *testing.T) {
existing := &APIKey{ID: 1, Key: "sk-test"}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 7, Status: StatusActive}}
// authCacheInvalidator is nil should not panic
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(7))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(7), *got.APIKey.GroupID)
}
// ---------------------------------------------------------------------------
// Tests: AllowedGroup auto-sync
// ---------------------------------------------------------------------------
func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AddsAllowedGroup(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
userRepo := &userRepoStubForGroupUpdate{}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
// 验证 AddGroupToAllowedGroups 被调用,且参数正确
require.True(t, userRepo.addGroupCalled)
require.Equal(t, int64(42), userRepo.addedUserID)
require.Equal(t, int64(10), userRepo.addedGroupID)
// 验证 result 标记了自动授权
require.True(t, got.AutoGrantedGroupAccess)
require.NotNil(t, got.GrantedGroupID)
require.Equal(t, int64(10), *got.GrantedGroupID)
require.Equal(t, "Exclusive", got.GrantedGroupName)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupUpdate(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Public", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeStandard}}
userRepo := &userRepoStubForGroupUpdate{}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.NotNil(t, got.APIKey.GroupID)
// 非专属分组不触发 AddGroupToAllowedGroups
require.False(t, userRepo.addGroupCalled)
require.False(t, got.AutoGrantedGroupAccess)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
// 无有效订阅时应拒绝绑定
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Equal(t, "SUBSCRIPTION_REQUIRED", infraerrors.Reason(err))
require.True(t, userSubRepo.called)
require.Equal(t, int64(42), userSubRepo.calledUserID)
require.Equal(t, int64(10), userSubRepo.calledGroupID)
require.False(t, userRepo.addGroupCalled)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err))
require.False(t, userRepo.addGroupCalled)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
userRepo := &userRepoStubForGroupUpdate{}
userSubRepo := &userSubRepoStubForGroupUpdate{
getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10},
}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.NoError(t, err)
require.True(t, userSubRepo.called)
require.NotNil(t, got.APIKey.GroupID)
require.Equal(t, int64(10), *got.APIKey.GroupID)
require.False(t, userRepo.addGroupCalled)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AllowedGroupAddFails_ReturnsError(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
userRepo := &userRepoStubForGroupUpdate{addGroupErr: errors.New("db error")}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
// 严格模式AddGroupToAllowedGroups 失败时,整体操作报错
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
require.Error(t, err)
require.Contains(t, err.Error(), "add group to user allowed groups")
require.True(t, userRepo.addGroupCalled)
// apiKey 不应被更新
require.Nil(t, apiKeyRepo.updated)
}
func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind_NoAllowedGroupUpdate(t *testing.T) {
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Exclusive"}}
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
userRepo := &userRepoStubForGroupUpdate{}
cache := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, userRepo: userRepo, authCacheInvalidator: cache}
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
require.NoError(t, err)
require.Nil(t, got.APIKey.GroupID)
// 解绑时不修改 allowed_groups
require.False(t, userRepo.addGroupCalled)
require.False(t, got.AutoGrantedGroupAccess)
}

View File

@@ -0,0 +1,172 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
type accountRepoStubForBulkUpdate struct {
accountRepoStub
bulkUpdateErr error
bulkUpdateIDs []int64
bindGroupErrByID map[int64]error
bindGroupsCalls []int64
getByIDsAccounts []*Account
getByIDsErr error
getByIDsCalled bool
getByIDsIDs []int64
getByIDAccounts map[int64]*Account
getByIDErrByID map[int64]error
getByIDCalled []int64
listByGroupData map[int64][]Account
listByGroupErr map[int64]error
}
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
s.bulkUpdateIDs = append([]int64{}, ids...)
if s.bulkUpdateErr != nil {
return 0, s.bulkUpdateErr
}
return int64(len(ids)), nil
}
func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
s.bindGroupsCalls = append(s.bindGroupsCalls, accountID)
if err, ok := s.bindGroupErrByID[accountID]; ok {
return err
}
return nil
}
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
s.getByIDsCalled = true
s.getByIDsIDs = append([]int64{}, ids...)
if s.getByIDsErr != nil {
return nil, s.getByIDsErr
}
return s.getByIDsAccounts, nil
}
func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Account, error) {
s.getByIDCalled = append(s.getByIDCalled, id)
if err, ok := s.getByIDErrByID[id]; ok {
return nil, err
}
if account, ok := s.getByIDAccounts[id]; ok {
return account, nil
}
return nil, errors.New("account not found")
}
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
if err, ok := s.listByGroupErr[groupID]; ok {
return nil, err
}
if rows, ok := s.listByGroupData[groupID]; ok {
return rows, nil
}
return nil, nil
}
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}
svc := &adminServiceImpl{accountRepo: repo}
schedulable := true
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1, 2, 3},
Schedulable: &schedulable,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.NoError(t, err)
require.Equal(t, 3, result.Success)
require.Equal(t, 0, result.Failed)
require.ElementsMatch(t, []int64{1, 2, 3}, result.SuccessIDs)
require.Empty(t, result.FailedIDs)
require.Len(t, result.Results, 3)
}
// TestAdminService_BulkUpdateAccounts_PartialFailureIDs 验证部分失败时 success_ids/failed_ids 正确。
func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{
bindGroupErrByID: map[int64]error{
2: errors.New("bind failed"),
},
}
svc := &adminServiceImpl{
accountRepo: repo,
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "g10"}},
}
groupIDs := []int64{10}
schedulable := false
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1, 2, 3},
GroupIDs: &groupIDs,
Schedulable: &schedulable,
SkipMixedChannelCheck: true,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.NoError(t, err)
require.Equal(t, 2, result.Success)
require.Equal(t, 1, result.Failed)
require.ElementsMatch(t, []int64{1, 3}, result.SuccessIDs)
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
require.Len(t, result.Results, 3)
}
func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{}
svc := &adminServiceImpl{accountRepo: repo}
groupIDs := []int64{10}
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1},
GroupIDs: &groupIDs,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.Nil(t, result)
require.Error(t, err)
require.Contains(t, err.Error(), "group repository not configured")
}
// TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict verifies
// that the global pre-check detects a conflict with existing group members and returns an
// error before any DB write is performed.
func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict(t *testing.T) {
repo := &accountRepoStubForBulkUpdate{
getByIDsAccounts: []*Account{
{ID: 1, Platform: PlatformAntigravity},
},
// Group 10 already contains an Anthropic account.
listByGroupData: map[int64][]Account{
10: {{ID: 99, Platform: PlatformAnthropic}},
},
}
svc := &adminServiceImpl{
accountRepo: repo,
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "target-group"}},
}
groupIDs := []int64{10}
input := &BulkUpdateAccountsInput{
AccountIDs: []int64{1},
GroupIDs: &groupIDs,
}
result, err := svc.BulkUpdateAccounts(context.Background(), input)
require.Nil(t, result)
require.Error(t, err)
require.Contains(t, err.Error(), "mixed channel")
// No BindGroups should have been called since the check runs before any write.
require.Empty(t, repo.bindGroupsCalls)
}

View File

@@ -0,0 +1,97 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func TestAdminService_CreateUser_Success(t *testing.T) {
repo := &userRepoStub{nextID: 10}
svc := &adminServiceImpl{userRepo: repo}
input := &CreateUserInput{
Email: "user@test.com",
Password: "strong-pass",
Username: "tester",
Notes: "note",
Balance: 12.5,
Concurrency: 7,
AllowedGroups: []int64{3, 5},
}
user, err := svc.CreateUser(context.Background(), input)
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, int64(10), user.ID)
require.Equal(t, input.Email, user.Email)
require.Equal(t, input.Username, user.Username)
require.Equal(t, input.Notes, user.Notes)
require.Equal(t, input.Balance, user.Balance)
require.Equal(t, input.Concurrency, user.Concurrency)
require.Equal(t, input.AllowedGroups, user.AllowedGroups)
require.Equal(t, RoleUser, user.Role)
require.Equal(t, StatusActive, user.Status)
require.True(t, user.CheckPassword(input.Password))
require.Len(t, repo.created, 1)
require.Equal(t, user, repo.created[0])
}
func TestAdminService_CreateUser_EmailExists(t *testing.T) {
repo := &userRepoStub{createErr: ErrEmailExists}
svc := &adminServiceImpl{userRepo: repo}
_, err := svc.CreateUser(context.Background(), &CreateUserInput{
Email: "dup@test.com",
Password: "password",
})
require.ErrorIs(t, err, ErrEmailExists)
require.Empty(t, repo.created)
}
func TestAdminService_CreateUser_CreateError(t *testing.T) {
createErr := errors.New("db down")
repo := &userRepoStub{createErr: createErr}
svc := &adminServiceImpl{userRepo: repo}
_, err := svc.CreateUser(context.Background(), &CreateUserInput{
Email: "user@test.com",
Password: "password",
})
require.ErrorIs(t, err, createErr)
require.Empty(t, repo.created)
}
func TestAdminService_CreateUser_AssignsDefaultSubscriptions(t *testing.T) {
repo := &userRepoStub{nextID: 21}
assigner := &defaultSubscriptionAssignerStub{}
cfg := &config.Config{
Default: config.DefaultConfig{
UserBalance: 0,
UserConcurrency: 1,
},
}
settingService := NewSettingService(&settingRepoStub{values: map[string]string{
SettingKeyDefaultSubscriptions: `[{"group_id":5,"validity_days":30}]`,
}}, cfg)
svc := &adminServiceImpl{
userRepo: repo,
settingService: settingService,
defaultSubAssigner: assigner,
}
_, err := svc.CreateUser(context.Background(), &CreateUserInput{
Email: "new-user@test.com",
Password: "password",
})
require.NoError(t, err)
require.Len(t, assigner.calls, 1)
require.Equal(t, int64(21), assigner.calls[0].UserID)
require.Equal(t, int64(5), assigner.calls[0].GroupID)
require.Equal(t, 30, assigner.calls[0].ValidityDays)
}

View File

@@ -0,0 +1,542 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type userRepoStub struct {
user *User
getErr error
createErr error
deleteErr error
exists bool
existsErr error
nextID int64
created []*User
deletedIDs []int64
}
func (s *userRepoStub) Create(ctx context.Context, user *User) error {
if s.createErr != nil {
return s.createErr
}
if s.nextID != 0 && user.ID == 0 {
user.ID = s.nextID
}
s.created = append(s.created, user)
return nil
}
func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
if s.getErr != nil {
return nil, s.getErr
}
if s.user == nil {
return nil, ErrUserNotFound
}
return s.user, nil
}
func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) {
panic("unexpected GetByEmail call")
}
func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
panic("unexpected GetFirstAdmin call")
}
func (s *userRepoStub) Update(ctx context.Context, user *User) error {
panic("unexpected Update call")
}
func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
s.deletedIDs = append(s.deletedIDs, id)
return s.deleteErr
}
func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *userRepoStub) UpdateBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected UpdateBalance call")
}
func (s *userRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
panic("unexpected DeductBalance call")
}
func (s *userRepoStub) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
panic("unexpected UpdateConcurrency call")
}
func (s *userRepoStub) ExistsByEmail(ctx context.Context, email string) (bool, error) {
if s.existsErr != nil {
return false, s.existsErr
}
return s.exists, nil
}
func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected RemoveGroupFromAllowedGroups call")
}
func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
panic("unexpected AddGroupToAllowedGroups call")
}
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
panic("unexpected UpdateTotpSecret call")
}
func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
panic("unexpected EnableTotp call")
}
func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
panic("unexpected DisableTotp call")
}
type groupRepoStub struct {
affectedUserIDs []int64
deleteErr error
deleteCalls []int64
}
func (s *groupRepoStub) Create(ctx context.Context, group *Group) error {
panic("unexpected Create call")
}
func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
panic("unexpected GetByID call")
}
func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
panic("unexpected GetByIDLite call")
}
func (s *groupRepoStub) Update(ctx context.Context, group *Group) error {
panic("unexpected Update call")
}
func (s *groupRepoStub) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (s *groupRepoStub) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
s.deleteCalls = append(s.deleteCalls, id)
return s.affectedUserIDs, s.deleteErr
}
func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *groupRepoStub) ListActive(ctx context.Context) ([]Group, error) {
panic("unexpected ListActive call")
}
func (s *groupRepoStub) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
panic("unexpected ListActiveByPlatform call")
}
func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected GetAccountCount call")
}
func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (s *groupRepoStub) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
func (s *groupRepoStub) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
return nil
}
type proxyRepoStub struct {
deleteErr error
countErr error
accountCount int64
deletedIDs []int64
}
func (s *proxyRepoStub) Create(ctx context.Context, proxy *Proxy) error {
panic("unexpected Create call")
}
func (s *proxyRepoStub) GetByID(ctx context.Context, id int64) (*Proxy, error) {
panic("unexpected GetByID call")
}
func (s *proxyRepoStub) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) {
panic("unexpected ListByIDs call")
}
func (s *proxyRepoStub) Update(ctx context.Context, proxy *Proxy) error {
panic("unexpected Update call")
}
func (s *proxyRepoStub) Delete(ctx context.Context, id int64) error {
s.deletedIDs = append(s.deletedIDs, id)
return s.deleteErr
}
func (s *proxyRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *proxyRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *proxyRepoStub) ListActive(ctx context.Context) ([]Proxy, error) {
panic("unexpected ListActive call")
}
func (s *proxyRepoStub) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
panic("unexpected ListActiveWithAccountCount call")
}
func (s *proxyRepoStub) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
panic("unexpected ListWithFiltersAndAccountCount call")
}
func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
panic("unexpected ExistsByHostPortAuth call")
}
func (s *proxyRepoStub) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
if s.countErr != nil {
return 0, s.countErr
}
return s.accountCount, nil
}
func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
panic("unexpected ListAccountSummariesByProxyID call")
}
type redeemRepoStub struct {
deleteErrByID map[int64]error
deletedIDs []int64
}
func (s *redeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
panic("unexpected Create call")
}
func (s *redeemRepoStub) CreateBatch(ctx context.Context, codes []RedeemCode) error {
panic("unexpected CreateBatch call")
}
func (s *redeemRepoStub) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
panic("unexpected GetByID call")
}
func (s *redeemRepoStub) GetByCode(ctx context.Context, code string) (*RedeemCode, error) {
panic("unexpected GetByCode call")
}
func (s *redeemRepoStub) Update(ctx context.Context, code *RedeemCode) error {
panic("unexpected Update call")
}
func (s *redeemRepoStub) Delete(ctx context.Context, id int64) error {
s.deletedIDs = append(s.deletedIDs, id)
if s.deleteErrByID != nil {
if err, ok := s.deleteErrByID[id]; ok {
return err
}
}
return nil
}
func (s *redeemRepoStub) Use(ctx context.Context, id, userID int64) error {
panic("unexpected Use call")
}
func (s *redeemRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *redeemRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *redeemRepoStub) ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) {
panic("unexpected ListByUser call")
}
func (s *redeemRepoStub) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected ListByUserPaginated call")
}
func (s *redeemRepoStub) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
panic("unexpected SumPositiveBalanceByUser call")
}
type subscriptionInvalidateCall struct {
userID int64
groupID int64
}
type billingCacheStub struct {
invalidations chan subscriptionInvalidateCall
}
func newBillingCacheStub(buffer int) *billingCacheStub {
return &billingCacheStub{invalidations: make(chan subscriptionInvalidateCall, buffer)}
}
func (s *billingCacheStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
panic("unexpected GetUserBalance call")
}
func (s *billingCacheStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
panic("unexpected SetUserBalance call")
}
func (s *billingCacheStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
panic("unexpected DeductUserBalance call")
}
func (s *billingCacheStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
panic("unexpected InvalidateUserBalance call")
}
func (s *billingCacheStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
panic("unexpected GetSubscriptionCache call")
}
func (s *billingCacheStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
panic("unexpected SetSubscriptionCache call")
}
func (s *billingCacheStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
panic("unexpected UpdateSubscriptionUsage call")
}
func (s *billingCacheStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
s.invalidations <- subscriptionInvalidateCall{userID: userID, groupID: groupID}
return nil
}
func (s *billingCacheStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
panic("unexpected GetAPIKeyRateLimit call")
}
func (s *billingCacheStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
panic("unexpected SetAPIKeyRateLimit call")
}
func (s *billingCacheStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
panic("unexpected UpdateAPIKeyRateLimitUsage call")
}
func (s *billingCacheStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
panic("unexpected InvalidateAPIKeyRateLimit call")
}
func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall {
t.Helper()
calls := make([]subscriptionInvalidateCall, 0, expected)
timeout := time.After(2 * time.Second)
for len(calls) < expected {
select {
case call := <-ch:
calls = append(calls, call)
case <-timeout:
t.Fatalf("timeout waiting for %d invalidations, got %d", expected, len(calls))
}
}
return calls
}
func TestAdminService_DeleteUser_Success(t *testing.T) {
repo := &userRepoStub{user: &User{ID: 7, Role: RoleUser}}
svc := &adminServiceImpl{userRepo: repo}
err := svc.DeleteUser(context.Background(), 7)
require.NoError(t, err)
require.Equal(t, []int64{7}, repo.deletedIDs)
}
func TestAdminService_DeleteUser_NotFound(t *testing.T) {
repo := &userRepoStub{getErr: ErrUserNotFound}
svc := &adminServiceImpl{userRepo: repo}
err := svc.DeleteUser(context.Background(), 404)
require.ErrorIs(t, err, ErrUserNotFound)
require.Empty(t, repo.deletedIDs)
}
func TestAdminService_DeleteUser_AdminGuard(t *testing.T) {
repo := &userRepoStub{user: &User{ID: 1, Role: RoleAdmin}}
svc := &adminServiceImpl{userRepo: repo}
err := svc.DeleteUser(context.Background(), 1)
require.Error(t, err)
require.ErrorContains(t, err, "cannot delete admin user")
require.Empty(t, repo.deletedIDs)
}
func TestAdminService_DeleteUser_DeleteError(t *testing.T) {
deleteErr := errors.New("delete failed")
repo := &userRepoStub{
user: &User{ID: 9, Role: RoleUser},
deleteErr: deleteErr,
}
svc := &adminServiceImpl{userRepo: repo}
err := svc.DeleteUser(context.Background(), 9)
require.ErrorIs(t, err, deleteErr)
require.Equal(t, []int64{9}, repo.deletedIDs)
}
func TestAdminService_DeleteGroup_Success_WithCacheInvalidation(t *testing.T) {
cache := newBillingCacheStub(2)
repo := &groupRepoStub{affectedUserIDs: []int64{11, 12}}
svc := &adminServiceImpl{
groupRepo: repo,
billingCacheService: &BillingCacheService{cache: cache},
}
err := svc.DeleteGroup(context.Background(), 5)
require.NoError(t, err)
require.Equal(t, []int64{5}, repo.deleteCalls)
calls := waitForInvalidations(t, cache.invalidations, 2)
require.ElementsMatch(t, []subscriptionInvalidateCall{
{userID: 11, groupID: 5},
{userID: 12, groupID: 5},
}, calls)
}
func TestAdminService_DeleteGroup_NotFound(t *testing.T) {
repo := &groupRepoStub{deleteErr: ErrGroupNotFound}
svc := &adminServiceImpl{groupRepo: repo}
err := svc.DeleteGroup(context.Background(), 99)
require.ErrorIs(t, err, ErrGroupNotFound)
}
func TestAdminService_DeleteGroup_Error(t *testing.T) {
deleteErr := errors.New("delete failed")
repo := &groupRepoStub{deleteErr: deleteErr}
svc := &adminServiceImpl{groupRepo: repo}
err := svc.DeleteGroup(context.Background(), 42)
require.ErrorIs(t, err, deleteErr)
}
func TestAdminService_DeleteProxy_Success(t *testing.T) {
repo := &proxyRepoStub{}
svc := &adminServiceImpl{proxyRepo: repo}
err := svc.DeleteProxy(context.Background(), 7)
require.NoError(t, err)
require.Equal(t, []int64{7}, repo.deletedIDs)
}
func TestAdminService_DeleteProxy_Idempotent(t *testing.T) {
repo := &proxyRepoStub{}
svc := &adminServiceImpl{proxyRepo: repo}
err := svc.DeleteProxy(context.Background(), 404)
require.NoError(t, err)
require.Equal(t, []int64{404}, repo.deletedIDs)
}
func TestAdminService_DeleteProxy_InUse(t *testing.T) {
repo := &proxyRepoStub{accountCount: 2}
svc := &adminServiceImpl{proxyRepo: repo}
err := svc.DeleteProxy(context.Background(), 77)
require.ErrorIs(t, err, ErrProxyInUse)
require.Empty(t, repo.deletedIDs)
}
func TestAdminService_DeleteProxy_Error(t *testing.T) {
deleteErr := errors.New("delete failed")
repo := &proxyRepoStub{deleteErr: deleteErr}
svc := &adminServiceImpl{proxyRepo: repo}
err := svc.DeleteProxy(context.Background(), 33)
require.ErrorIs(t, err, deleteErr)
}
func TestAdminService_DeleteRedeemCode_Success(t *testing.T) {
repo := &redeemRepoStub{}
svc := &adminServiceImpl{redeemCodeRepo: repo}
err := svc.DeleteRedeemCode(context.Background(), 10)
require.NoError(t, err)
require.Equal(t, []int64{10}, repo.deletedIDs)
}
func TestAdminService_DeleteRedeemCode_Idempotent(t *testing.T) {
repo := &redeemRepoStub{}
svc := &adminServiceImpl{redeemCodeRepo: repo}
err := svc.DeleteRedeemCode(context.Background(), 999)
require.NoError(t, err)
require.Equal(t, []int64{999}, repo.deletedIDs)
}
func TestAdminService_DeleteRedeemCode_Error(t *testing.T) {
deleteErr := errors.New("delete failed")
repo := &redeemRepoStub{deleteErrByID: map[int64]error{1: deleteErr}}
svc := &adminServiceImpl{redeemCodeRepo: repo}
err := svc.DeleteRedeemCode(context.Background(), 1)
require.ErrorIs(t, err, deleteErr)
require.Equal(t, []int64{1}, repo.deletedIDs)
}
func TestAdminService_BatchDeleteRedeemCodes_Success(t *testing.T) {
repo := &redeemRepoStub{}
svc := &adminServiceImpl{redeemCodeRepo: repo}
deleted, err := svc.BatchDeleteRedeemCodes(context.Background(), []int64{1, 2, 3})
require.NoError(t, err)
require.Equal(t, int64(3), deleted)
require.Equal(t, []int64{1, 2, 3}, repo.deletedIDs)
}
func TestAdminService_BatchDeleteRedeemCodes_PartialFailures(t *testing.T) {
repo := &redeemRepoStub{
deleteErrByID: map[int64]error{
2: errors.New("db error"),
},
}
svc := &adminServiceImpl{redeemCodeRepo: repo}
deleted, err := svc.BatchDeleteRedeemCodes(context.Background(), []int64{1, 2, 3})
require.NoError(t, err)
require.Equal(t, int64(2), deleted)
require.Equal(t, []int64{1, 2, 3}, repo.deletedIDs)
}

View File

@@ -0,0 +1,176 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/require"
)
// userGroupRateRepoStubForGroupRate implements UserGroupRateRepository for group rate tests.
type userGroupRateRepoStubForGroupRate struct {
getByGroupIDData map[int64][]UserGroupRateEntry
getByGroupIDErr error
deletedGroupIDs []int64
deleteByGroupErr error
syncedGroupID int64
syncedEntries []GroupRateMultiplierInput
syncGroupErr error
}
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
panic("unexpected GetByUserID call")
}
func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, _, _ int64) (*float64, error) {
panic("unexpected GetByUserAndGroup call")
}
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
if s.getByGroupIDErr != nil {
return nil, s.getByGroupIDErr
}
return s.getByGroupIDData[groupID], nil
}
func (s *userGroupRateRepoStubForGroupRate) SyncUserGroupRates(_ context.Context, _ int64, _ map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call")
}
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
s.syncedGroupID = groupID
s.syncedEntries = entries
return s.syncGroupErr
}
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
return s.deleteByGroupErr
}
func (s *userGroupRateRepoStubForGroupRate) DeleteByUserID(_ context.Context, _ int64) error {
panic("unexpected DeleteByUserID call")
}
func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
t.Run("returns entries for group", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{
10: {
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
},
},
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
require.NoError(t, err)
require.Len(t, entries, 2)
require.Equal(t, int64(1), entries[0].UserID)
require.Equal(t, "alice", entries[0].UserName)
require.Equal(t, 1.5, entries[0].RateMultiplier)
require.Equal(t, int64(2), entries[1].UserID)
require.Equal(t, 0.8, entries[1].RateMultiplier)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
svc := &adminServiceImpl{userGroupRateRepo: nil}
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
require.NoError(t, err)
require.Nil(t, entries)
})
t.Run("returns empty slice for group with no entries", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDData: map[int64][]UserGroupRateEntry{},
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
entries, err := svc.GetGroupRateMultipliers(context.Background(), 99)
require.NoError(t, err)
require.Nil(t, entries)
})
t.Run("propagates repo error", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
getByGroupIDErr: errors.New("db error"),
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
_, err := svc.GetGroupRateMultipliers(context.Background(), 10)
require.Error(t, err)
require.Contains(t, err.Error(), "db error")
})
}
func TestAdminService_ClearGroupRateMultipliers(t *testing.T) {
t.Run("deletes by group ID", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
require.NoError(t, err)
require.Equal(t, []int64{42}, repo.deletedGroupIDs)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
svc := &adminServiceImpl{userGroupRateRepo: nil}
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
require.NoError(t, err)
})
t.Run("propagates repo error", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
deleteByGroupErr: errors.New("delete failed"),
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
require.Error(t, err)
require.Contains(t, err.Error(), "delete failed")
})
}
func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
t.Run("syncs entries to repo", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{}
svc := &adminServiceImpl{userGroupRateRepo: repo}
entries := []GroupRateMultiplierInput{
{UserID: 1, RateMultiplier: 1.5},
{UserID: 2, RateMultiplier: 0.8},
}
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, entries)
require.NoError(t, err)
require.Equal(t, int64(10), repo.syncedGroupID)
require.Equal(t, entries, repo.syncedEntries)
})
t.Run("returns nil when repo is nil", func(t *testing.T) {
svc := &adminServiceImpl{userGroupRateRepo: nil}
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, nil)
require.NoError(t, err)
})
t.Run("propagates repo error", func(t *testing.T) {
repo := &userGroupRateRepoStubForGroupRate{
syncGroupErr: errors.New("sync failed"),
}
svc := &adminServiceImpl{userGroupRateRepo: repo}
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, []GroupRateMultiplierInput{
{UserID: 1, RateMultiplier: 1.0},
})
require.Error(t, err)
require.Contains(t, err.Error(), "sync failed")
})
}

View File

@@ -0,0 +1,787 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
type groupRepoStubForAdmin struct {
created *Group // 记录 Create 调用的参数
updated *Group // 记录 Update 调用的参数
getByID *Group // GetByID 返回值
getErr error // GetByID 返回的错误
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersPlatform string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersIsExclusive *bool
listWithFiltersGroups []Group
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
}
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
s.created = g
return nil
}
func (s *groupRepoStubForAdmin) Update(_ context.Context, g *Group) error {
s.updated = g
return nil
}
func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, error) {
if s.getErr != nil {
return nil, s.getErr
}
return s.getByID, nil
}
func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) {
if s.getErr != nil {
return nil, s.getErr
}
return s.getByID, nil
}
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
func (s *groupRepoStubForAdmin) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
panic("unexpected DeleteCascade call")
}
func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
s.listWithFiltersIsExclusive = isExclusive
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersGroups)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersGroups, result, nil
}
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
panic("unexpected ListActive call")
}
func (s *groupRepoStubForAdmin) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
panic("unexpected ListActiveByPlatform call")
}
func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
panic("unexpected GetAccountCount call")
}
func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (s *groupRepoStubForAdmin) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error {
return nil
}
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
price1K := 0.10
price2K := 0.15
price4K := 0.30
input := &CreateGroupInput{
Name: "test-group",
Description: "Test group",
Platform: PlatformAntigravity,
RateMultiplier: 1.0,
ImagePrice1K: &price1K,
ImagePrice2K: &price2K,
ImagePrice4K: &price4K,
}
group, err := svc.CreateGroup(context.Background(), input)
require.NoError(t, err)
require.NotNil(t, group)
// 验证 repo 收到了正确的字段
require.NotNil(t, repo.created)
require.NotNil(t, repo.created.ImagePrice1K)
require.NotNil(t, repo.created.ImagePrice2K)
require.NotNil(t, repo.created.ImagePrice4K)
require.InDelta(t, 0.10, *repo.created.ImagePrice1K, 0.0001)
require.InDelta(t, 0.15, *repo.created.ImagePrice2K, 0.0001)
require.InDelta(t, 0.30, *repo.created.ImagePrice4K, 0.0001)
}
// TestAdminService_CreateGroup_NilImagePricing 测试 ImagePrice 为 nil 时正常创建
func TestAdminService_CreateGroup_NilImagePricing(t *testing.T) {
repo := &groupRepoStubForAdmin{}
svc := &adminServiceImpl{groupRepo: repo}
input := &CreateGroupInput{
Name: "test-group",
Description: "Test group",
Platform: PlatformAntigravity,
RateMultiplier: 1.0,
// ImagePrice 字段全部为 nil
}
group, err := svc.CreateGroup(context.Background(), input)
require.NoError(t, err)
require.NotNil(t, group)
// 验证 ImagePrice 字段为 nil
require.NotNil(t, repo.created)
require.Nil(t, repo.created.ImagePrice1K)
require.Nil(t, repo.created.ImagePrice2K)
require.Nil(t, repo.created.ImagePrice4K)
}
// TestAdminService_UpdateGroup_WithImagePricing 测试更新分组时 ImagePrice 字段正确更新
func TestAdminService_UpdateGroup_WithImagePricing(t *testing.T) {
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformAntigravity,
Status: StatusActive,
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
price1K := 0.12
price2K := 0.18
price4K := 0.36
input := &UpdateGroupInput{
ImagePrice1K: &price1K,
ImagePrice2K: &price2K,
ImagePrice4K: &price4K,
}
group, err := svc.UpdateGroup(context.Background(), 1, input)
require.NoError(t, err)
require.NotNil(t, group)
// 验证 repo 收到了更新后的字段
require.NotNil(t, repo.updated)
require.NotNil(t, repo.updated.ImagePrice1K)
require.NotNil(t, repo.updated.ImagePrice2K)
require.NotNil(t, repo.updated.ImagePrice4K)
require.InDelta(t, 0.12, *repo.updated.ImagePrice1K, 0.0001)
require.InDelta(t, 0.18, *repo.updated.ImagePrice2K, 0.0001)
require.InDelta(t, 0.36, *repo.updated.ImagePrice4K, 0.0001)
}
// TestAdminService_UpdateGroup_PartialImagePricing 测试仅更新部分 ImagePrice 字段
func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
oldPrice2K := 0.15
existingGroup := &Group{
ID: 1,
Name: "existing-group",
Platform: PlatformAntigravity,
Status: StatusActive,
ImagePrice2K: &oldPrice2K, // 已有 2K 价格
}
repo := &groupRepoStubForAdmin{getByID: existingGroup}
svc := &adminServiceImpl{groupRepo: repo}
// 只更新 1K 价格
price1K := 0.10
input := &UpdateGroupInput{
ImagePrice1K: &price1K,
// ImagePrice2K 和 ImagePrice4K 为 nil不更新
}
group, err := svc.UpdateGroup(context.Background(), 1, input)
require.NoError(t, err)
require.NotNil(t, group)
// 验证1K 被更新2K 保持原值4K 仍为 nil
require.NotNil(t, repo.updated)
require.NotNil(t, repo.updated.ImagePrice1K)
require.InDelta(t, 0.10, *repo.updated.ImagePrice1K, 0.0001)
require.NotNil(t, repo.updated.ImagePrice2K)
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
require.Nil(t, repo.updated.ImagePrice4K)
}
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
// 测试:
// 1. search 参数正常传递到 repository 层
// 2. search 为空字符串时的行为
// 3. search 与其他过滤条件组合使用
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
require.NoError(t, err)
require.Equal(t, int64(1), total)
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, "alpha", repo.listWithFiltersSearch)
require.Nil(t, repo.listWithFiltersIsExclusive)
})
t.Run("search 为空字符串时传递空字符串", func(t *testing.T) {
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{},
listWithFiltersResult: &pagination.PaginationResult{Total: 0},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
require.NoError(t, err)
require.Empty(t, groups)
require.Equal(t, int64(0), total)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams)
require.Equal(t, "", repo.listWithFiltersSearch)
require.Nil(t, repo.listWithFiltersIsExclusive)
})
t.Run("search 与其他过滤条件组合使用", func(t *testing.T) {
isExclusive := true
repo := &groupRepoStubForAdmin{
listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 42},
}
svc := &adminServiceImpl{groupRepo: repo}
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
require.NoError(t, err)
require.Equal(t, int64(42), total)
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "beta", repo.listWithFiltersSearch)
require.NotNil(t, repo.listWithFiltersIsExclusive)
require.True(t, *repo.listWithFiltersIsExclusive)
})
}
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
groupID := int64(1)
fallbackID := int64(2)
repo := &groupRepoStubForFallbackCycle{
groups: map[int64]*Group{
groupID: {
ID: groupID,
FallbackGroupID: &fallbackID,
},
fallbackID: {
ID: fallbackID,
FallbackGroupID: &groupID,
},
},
}
svc := &adminServiceImpl{groupRepo: repo}
err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID)
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group cycle")
}
type groupRepoStubForFallbackCycle struct {
groups map[int64]*Group
}
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
panic("unexpected Create call")
}
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
panic("unexpected Update call")
}
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
return s.GetByIDLite(ctx, id)
}
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
if g, ok := s.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
panic("unexpected DeleteCascade call")
}
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
panic("unexpected ListActive call")
}
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
panic("unexpected ListActiveByPlatform call")
}
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
panic("unexpected GetAccountCount call")
}
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
func (s *groupRepoStubForFallbackCycle) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error {
return nil
}
type groupRepoStubForInvalidRequestFallback struct {
groups map[int64]*Group
created *Group
updated *Group
}
func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error {
s.created = g
return nil
}
func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error {
s.updated = g
return nil
}
func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) {
return s.GetByIDLite(ctx, id)
}
func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) {
if g, ok := s.groups[id]; ok {
return g, nil
}
return nil, ErrGroupNotFound
}
func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error {
panic("unexpected Delete call")
}
func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
panic("unexpected DeleteCascade call")
}
func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected List call")
}
func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
panic("unexpected ListWithFilters call")
}
func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) {
panic("unexpected ListActive call")
}
func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
panic("unexpected ListActiveByPlatform call")
}
func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) {
panic("unexpected ExistsByName call")
}
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
panic("unexpected GetAccountCount call")
}
func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
panic("unexpected DeleteAccountGroupsByGroupID call")
}
func (s *groupRepoStubForInvalidRequestFallback) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
panic("unexpected GetAccountIDsByGroupIDs call")
}
func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
panic("unexpected BindAccountsToGroup call")
}
func (s *groupRepoStubForInvalidRequestFallback) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error {
return nil
}
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformOpenAI,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
require.Nil(t, repo.created)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeSubscription,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
require.Nil(t, repo.created)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
tests := []struct {
name string
fallback *Group
wantMessage string
}{
{
name: "openai_target",
fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard},
wantMessage: "fallback group must be anthropic platform",
},
{
name: "antigravity_target",
fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard},
wantMessage: "fallback group must be anthropic platform",
},
{
name: "subscription_group",
fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
wantMessage: "fallback group cannot be subscription type",
},
{
name: "nested_fallback",
fallback: &Group{
ID: 10,
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(),
},
wantMessage: "fallback group cannot have invalid request fallback configured",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
fallbackID := tc.fallback.ID
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: tc.fallback,
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), tc.wantMessage)
require.Nil(t, repo.created)
})
}
}
func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group not found")
require.Nil(t, repo.created)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
fallbackID := int64(10)
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAntigravity,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
zero := int64(0)
repo := &groupRepoStubForInvalidRequestFallback{}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
FallbackGroupIDOnInvalidRequest: &zero,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.created)
require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
FallbackGroupIDOnInvalidRequest: &fallbackID,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
Platform: PlatformOpenAI,
})
require.Error(t, err)
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
FallbackGroupIDOnInvalidRequest: &fallbackID,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
SubscriptionType: SubscriptionTypeSubscription,
})
require.Error(t, err)
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
FallbackGroupIDOnInvalidRequest: &fallbackID,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
clear := int64(0)
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
Platform: PlatformOpenAI,
FallbackGroupIDOnInvalidRequest: &clear,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
},
}
svc := &adminServiceImpl{groupRepo: repo}
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.Error(t, err)
require.Contains(t, err.Error(), "fallback group cannot be subscription type")
require.Nil(t, repo.updated)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAnthropic,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
}
func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
fallbackID := int64(10)
existing := &Group{
ID: 1,
Name: "g1",
Platform: PlatformAntigravity,
SubscriptionType: SubscriptionTypeStandard,
Status: StatusActive,
}
repo := &groupRepoStubForInvalidRequestFallback{
groups: map[int64]*Group{
existing.ID: existing,
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
},
}
svc := &adminServiceImpl{groupRepo: repo}
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
FallbackGroupIDOnInvalidRequest: &fallbackID,
})
require.NoError(t, err)
require.NotNil(t, group)
require.NotNil(t, repo.updated)
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
}

View File

@@ -0,0 +1,114 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type userRepoStubForListUsers struct {
userRepoStub
users []User
err error
}
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
if s.err != nil {
return nil, nil, s.err
}
out := make([]User, len(s.users))
copy(out, s.users)
return out, &pagination.PaginationResult{
Total: int64(len(out)),
Page: params.Page,
PageSize: params.PageSize,
}, nil
}
type userGroupRateRepoStubForListUsers struct {
batchCalls int
singleCall []int64
batchErr error
batchData map[int64]map[int64]float64
singleErr map[int64]error
singleData map[int64]map[int64]float64
}
func (s *userGroupRateRepoStubForListUsers) GetByUserIDs(_ context.Context, _ []int64) (map[int64]map[int64]float64, error) {
s.batchCalls++
if s.batchErr != nil {
return nil, s.batchErr
}
return s.batchData, nil
}
func (s *userGroupRateRepoStubForListUsers) GetByUserID(_ context.Context, userID int64) (map[int64]float64, error) {
s.singleCall = append(s.singleCall, userID)
if err, ok := s.singleErr[userID]; ok {
return nil, err
}
if rates, ok := s.singleData[userID]; ok {
return rates, nil
}
return map[int64]float64{}, nil
}
func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, userID, groupID int64) (*float64, error) {
panic("unexpected GetByUserAndGroup call")
}
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
panic("unexpected SyncUserGroupRates call")
}
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, _ int64) ([]UserGroupRateEntry, error) {
panic("unexpected GetByGroupID call")
}
func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.Context, _ int64, _ []GroupRateMultiplierInput) error {
panic("unexpected SyncGroupRateMultipliers call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
panic("unexpected DeleteByGroupID call")
}
func (s *userGroupRateRepoStubForListUsers) DeleteByUserID(_ context.Context, userID int64) error {
panic("unexpected DeleteByUserID call")
}
func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
userRepo := &userRepoStubForListUsers{
users: []User{
{ID: 101, Username: "u1"},
{ID: 202, Username: "u2"},
},
}
rateRepo := &userGroupRateRepoStubForListUsers{
batchErr: errors.New("batch unavailable"),
singleData: map[int64]map[int64]float64{
101: {11: 1.1},
202: {22: 2.2},
},
}
svc := &adminServiceImpl{
userRepo: userRepo,
userGroupRateRepo: rateRepo,
}
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{})
require.NoError(t, err)
require.Equal(t, int64(2), total)
require.Len(t, users, 2)
require.Equal(t, 1, rateRepo.batchCalls)
require.ElementsMatch(t, []int64{101, 202}, rateRepo.singleCall)
require.Equal(t, 1.1, users[0].GroupRates[11])
require.Equal(t, 2.2, users[1].GroupRates[22])
}

View File

@@ -0,0 +1,123 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type updateAccountOveragesRepoStub struct {
mockAccountRepoForGemini
account *Account
updateCalls int
}
func (r *updateAccountOveragesRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
return r.account, nil
}
func (r *updateAccountOveragesRepoStub) Update(ctx context.Context, account *Account) error {
r.updateCalls++
r.account = account
return nil
}
func TestUpdateAccount_DisableOveragesClearsAICreditsKey(t *testing.T) {
accountID := int64(101)
repo := &updateAccountOveragesRepoStub{
account: &Account{
ID: accountID,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Extra: map[string]any{
"allow_overages": true,
"mixed_scheduling": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
creditsExhaustedKey: map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
},
}
svc := &adminServiceImpl{accountRepo: repo}
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
Extra: map[string]any{
"mixed_scheduling": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
creditsExhaustedKey: map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 1, repo.updateCalls)
require.False(t, updated.IsOveragesEnabled())
// 关闭 overages 后AICredits key 应被清除
rawLimits, ok := repo.account.Extra[modelRateLimitsKey].(map[string]any)
if ok {
_, exists := rawLimits[creditsExhaustedKey]
require.False(t, exists, "关闭 overages 时应清除 AICredits 限流 key")
}
// 普通模型限流应保留
require.True(t, ok)
_, exists := rawLimits["claude-sonnet-4-5"]
require.True(t, exists, "普通模型限流应保留")
}
func TestUpdateAccount_EnableOveragesClearsModelRateLimitsBeforePersist(t *testing.T) {
accountID := int64(102)
repo := &updateAccountOveragesRepoStub{
account: &Account{
ID: accountID,
Platform: PlatformAntigravity,
Type: AccountTypeOAuth,
Status: StatusActive,
Extra: map[string]any{
"mixed_scheduling": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
},
},
},
}
svc := &adminServiceImpl{accountRepo: repo}
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
Extra: map[string]any{
"mixed_scheduling": true,
"allow_overages": true,
},
})
require.NoError(t, err)
require.NotNil(t, updated)
require.Equal(t, 1, repo.updateCalls)
require.True(t, updated.IsOveragesEnabled())
_, exists := repo.account.Extra[modelRateLimitsKey]
require.False(t, exists, "开启 overages 时应在持久化前清掉旧模型限流")
}

View File

@@ -0,0 +1,95 @@
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/require"
)
func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
result := &ProxyQualityCheckResult{
PassedCount: 2,
WarnCount: 1,
FailedCount: 1,
ChallengeCount: 1,
}
finalizeProxyQualityResult(result)
require.Equal(t, 38, result.Score)
require.Equal(t, "F", result.Grade)
require.Contains(t, result.Summary, "通过 2 项")
require.Contains(t, result.Summary, "告警 1 项")
require.Contains(t, result.Summary, "失败 1 项")
require.Contains(t, result.Summary, "挑战 1 项")
}
func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/html")
w.Header().Set("cf-ray", "test-ray-123")
w.WriteHeader(http.StatusForbidden)
_, _ = w.Write([]byte("<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>"))
}))
defer server.Close()
target := proxyQualityTarget{
Target: "sora",
URL: server.URL,
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
http.StatusUnauthorized: {},
},
}
item := runProxyQualityTarget(context.Background(), server.Client(), target)
require.Equal(t, "challenge", item.Status)
require.Equal(t, http.StatusForbidden, item.HTTPStatus)
require.Equal(t, "test-ray-123", item.CFRay)
}
func TestRunProxyQualityTarget_AllowedStatusPass(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(`{"models":[]}`))
}))
defer server.Close()
target := proxyQualityTarget{
Target: "gemini",
URL: server.URL,
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
http.StatusOK: {},
},
}
item := runProxyQualityTarget(context.Background(), server.Client(), target)
require.Equal(t, "pass", item.Status)
require.Equal(t, http.StatusOK, item.HTTPStatus)
}
func TestRunProxyQualityTarget_AllowedStatusWarnForUnauthorized(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
}))
defer server.Close()
target := proxyQualityTarget{
Target: "openai",
URL: server.URL,
Method: http.MethodGet,
AllowedStatuses: map[int]struct{}{
http.StatusUnauthorized: {},
},
}
item := runProxyQualityTarget(context.Background(), server.Client(), target)
require.Equal(t, "warn", item.Status)
require.Equal(t, http.StatusUnauthorized, item.HTTPStatus)
require.Contains(t, item.Message, "目标可达")
}

View File

@@ -0,0 +1,246 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type accountRepoStubForAdminList struct {
accountRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersPlatform string
listWithFiltersType string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersAccounts []Account
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
}
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersPlatform = platform
s.listWithFiltersType = accountType
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersAccounts)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersAccounts, result, nil
}
type proxyRepoStubForAdminList struct {
proxyRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersProtocol string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersProxies []Proxy
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
listWithFiltersAndAccountCountCalls int
listWithFiltersAndAccountCountParams pagination.PaginationParams
listWithFiltersAndAccountCountProtocol string
listWithFiltersAndAccountCountStatus string
listWithFiltersAndAccountCountSearch string
listWithFiltersAndAccountCountProxies []ProxyWithAccountCount
listWithFiltersAndAccountCountResult *pagination.PaginationResult
listWithFiltersAndAccountCountErr error
}
func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersProtocol = protocol
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersProxies)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersProxies, result, nil
}
func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
s.listWithFiltersAndAccountCountCalls++
s.listWithFiltersAndAccountCountParams = params
s.listWithFiltersAndAccountCountProtocol = protocol
s.listWithFiltersAndAccountCountStatus = status
s.listWithFiltersAndAccountCountSearch = search
if s.listWithFiltersAndAccountCountErr != nil {
return nil, nil, s.listWithFiltersAndAccountCountErr
}
result := s.listWithFiltersAndAccountCountResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersAndAccountCountProxies)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersAndAccountCountProxies, result, nil
}
type redeemRepoStubForAdminList struct {
redeemRepoStub
listWithFiltersCalls int
listWithFiltersParams pagination.PaginationParams
listWithFiltersType string
listWithFiltersStatus string
listWithFiltersSearch string
listWithFiltersCodes []RedeemCode
listWithFiltersResult *pagination.PaginationResult
listWithFiltersErr error
}
func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
s.listWithFiltersCalls++
s.listWithFiltersParams = params
s.listWithFiltersType = codeType
s.listWithFiltersStatus = status
s.listWithFiltersSearch = search
if s.listWithFiltersErr != nil {
return nil, nil, s.listWithFiltersErr
}
result := s.listWithFiltersResult
if result == nil {
result = &pagination.PaginationResult{
Total: int64(len(s.listWithFiltersCodes)),
Page: params.Page,
PageSize: params.PageSize,
}
}
return s.listWithFiltersCodes, result, nil
}
func (s *redeemRepoStubForAdminList) ListByUserPaginated(_ context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) {
panic("unexpected ListByUserPaginated call")
}
func (s *redeemRepoStubForAdminList) SumPositiveBalanceByUser(_ context.Context, userID int64) (float64, error) {
panic("unexpected SumPositiveBalanceByUser call")
}
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &accountRepoStubForAdminList{
listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 10},
}
svc := &adminServiceImpl{accountRepo: repo}
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0)
require.NoError(t, err)
require.Equal(t, int64(10), total)
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "acc", repo.listWithFiltersSearch)
})
}
func TestAdminService_ListProxies_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &proxyRepoStubForAdminList{
listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 7},
}
svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1")
require.NoError(t, err)
require.Equal(t, int64(7), total)
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
require.Equal(t, "http", repo.listWithFiltersProtocol)
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
require.Equal(t, "p1", repo.listWithFiltersSearch)
})
}
func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &proxyRepoStubForAdminList{
listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}},
listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9},
}
svc := &adminServiceImpl{proxyRepo: repo}
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2")
require.NoError(t, err)
require.Equal(t, int64(9), total)
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams)
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
})
}
func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
repo := &redeemRepoStubForAdminList{
listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}},
listWithFiltersResult: &pagination.PaginationResult{Total: 3},
}
svc := &adminServiceImpl{redeemCodeRepo: repo}
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC")
require.NoError(t, err)
require.Equal(t, int64(3), total)
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
require.Equal(t, 1, repo.listWithFiltersCalls)
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
require.Equal(t, "ABC", repo.listWithFiltersSearch)
})
}

View File

@@ -0,0 +1,97 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
type balanceUserRepoStub struct {
*userRepoStub
updateErr error
updated []*User
}
func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error {
if s.updateErr != nil {
return s.updateErr
}
if user == nil {
return nil
}
clone := *user
s.updated = append(s.updated, &clone)
if s.userRepoStub != nil {
s.userRepoStub.user = &clone
}
return nil
}
type balanceRedeemRepoStub struct {
*redeemRepoStub
created []*RedeemCode
}
func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
if code == nil {
return nil
}
clone := *code
s.created = append(s.created, &clone)
return nil
}
type authCacheInvalidatorStub struct {
userIDs []int64
groupIDs []int64
keys []string
}
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) {
s.keys = append(s.keys, key)
}
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
s.userIDs = append(s.userIDs, userID)
}
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
s.groupIDs = append(s.groupIDs, groupID)
}
func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) {
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: redeemRepo,
authCacheInvalidator: invalidator,
}
_, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "")
require.NoError(t, err)
require.Equal(t, []int64{7}, invalidator.userIDs)
require.Len(t, redeemRepo.created, 1)
}
func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) {
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
invalidator := &authCacheInvalidatorStub{}
svc := &adminServiceImpl{
userRepo: repo,
redeemCodeRepo: redeemRepo,
authCacheInvalidator: invalidator,
}
_, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "")
require.NoError(t, err)
require.Empty(t, invalidator.userIDs)
require.Empty(t, redeemRepo.created)
}

View File

@@ -0,0 +1,69 @@
package service
import (
"context"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
const (
AnnouncementStatusDraft = domain.AnnouncementStatusDraft
AnnouncementStatusActive = domain.AnnouncementStatusActive
AnnouncementStatusArchived = domain.AnnouncementStatusArchived
)
const (
AnnouncementNotifyModeSilent = domain.AnnouncementNotifyModeSilent
AnnouncementNotifyModePopup = domain.AnnouncementNotifyModePopup
)
const (
AnnouncementConditionTypeSubscription = domain.AnnouncementConditionTypeSubscription
AnnouncementConditionTypeBalance = domain.AnnouncementConditionTypeBalance
)
const (
AnnouncementOperatorIn = domain.AnnouncementOperatorIn
AnnouncementOperatorGT = domain.AnnouncementOperatorGT
AnnouncementOperatorGTE = domain.AnnouncementOperatorGTE
AnnouncementOperatorLT = domain.AnnouncementOperatorLT
AnnouncementOperatorLTE = domain.AnnouncementOperatorLTE
AnnouncementOperatorEQ = domain.AnnouncementOperatorEQ
)
var (
ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
)
type AnnouncementTargeting = domain.AnnouncementTargeting
type AnnouncementConditionGroup = domain.AnnouncementConditionGroup
type AnnouncementCondition = domain.AnnouncementCondition
type Announcement = domain.Announcement
type AnnouncementListFilters struct {
Status string
Search string
}
type AnnouncementRepository interface {
Create(ctx context.Context, a *Announcement) error
GetByID(ctx context.Context, id int64) (*Announcement, error)
Update(ctx context.Context, a *Announcement) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error)
ListActive(ctx context.Context, now time.Time) ([]Announcement, error)
}
type AnnouncementReadRepository interface {
MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error
GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error)
GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error)
CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error)
}

View File

@@ -0,0 +1,406 @@
package service
import (
"context"
"fmt"
"sort"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
)
type AnnouncementService struct {
announcementRepo AnnouncementRepository
readRepo AnnouncementReadRepository
userRepo UserRepository
userSubRepo UserSubscriptionRepository
}
func NewAnnouncementService(
announcementRepo AnnouncementRepository,
readRepo AnnouncementReadRepository,
userRepo UserRepository,
userSubRepo UserSubscriptionRepository,
) *AnnouncementService {
return &AnnouncementService{
announcementRepo: announcementRepo,
readRepo: readRepo,
userRepo: userRepo,
userSubRepo: userSubRepo,
}
}
type CreateAnnouncementInput struct {
Title string
Content string
Status string
NotifyMode string
Targeting AnnouncementTargeting
StartsAt *time.Time
EndsAt *time.Time
ActorID *int64 // 管理员用户ID
}
type UpdateAnnouncementInput struct {
Title *string
Content *string
Status *string
NotifyMode *string
Targeting *AnnouncementTargeting
StartsAt **time.Time
EndsAt **time.Time
ActorID *int64 // 管理员用户ID
}
type UserAnnouncement struct {
Announcement Announcement
ReadAt *time.Time
}
type AnnouncementUserReadStatus struct {
UserID int64 `json:"user_id"`
Email string `json:"email"`
Username string `json:"username"`
Balance float64 `json:"balance"`
Eligible bool `json:"eligible"`
ReadAt *time.Time `json:"read_at,omitempty"`
}
func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) {
if input == nil {
return nil, fmt.Errorf("create announcement: nil input")
}
title := strings.TrimSpace(input.Title)
content := strings.TrimSpace(input.Content)
if title == "" || len(title) > 200 {
return nil, fmt.Errorf("create announcement: invalid title")
}
if content == "" {
return nil, fmt.Errorf("create announcement: content is required")
}
status := strings.TrimSpace(input.Status)
if status == "" {
status = AnnouncementStatusDraft
}
if !isValidAnnouncementStatus(status) {
return nil, fmt.Errorf("create announcement: invalid status")
}
targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate()
if err != nil {
return nil, err
}
notifyMode := strings.TrimSpace(input.NotifyMode)
if notifyMode == "" {
notifyMode = AnnouncementNotifyModeSilent
}
if !isValidAnnouncementNotifyMode(notifyMode) {
return nil, fmt.Errorf("create announcement: invalid notify_mode")
}
if input.StartsAt != nil && input.EndsAt != nil {
if !input.StartsAt.Before(*input.EndsAt) {
return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
}
}
a := &Announcement{
Title: title,
Content: content,
Status: status,
NotifyMode: notifyMode,
Targeting: targeting,
StartsAt: input.StartsAt,
EndsAt: input.EndsAt,
}
if input.ActorID != nil && *input.ActorID > 0 {
a.CreatedBy = input.ActorID
a.UpdatedBy = input.ActorID
}
if err := s.announcementRepo.Create(ctx, a); err != nil {
return nil, fmt.Errorf("create announcement: %w", err)
}
return a, nil
}
func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) {
if input == nil {
return nil, fmt.Errorf("update announcement: nil input")
}
a, err := s.announcementRepo.GetByID(ctx, id)
if err != nil {
return nil, err
}
if input.Title != nil {
title := strings.TrimSpace(*input.Title)
if title == "" || len(title) > 200 {
return nil, fmt.Errorf("update announcement: invalid title")
}
a.Title = title
}
if input.Content != nil {
content := strings.TrimSpace(*input.Content)
if content == "" {
return nil, fmt.Errorf("update announcement: content is required")
}
a.Content = content
}
if input.Status != nil {
status := strings.TrimSpace(*input.Status)
if !isValidAnnouncementStatus(status) {
return nil, fmt.Errorf("update announcement: invalid status")
}
a.Status = status
}
if input.NotifyMode != nil {
notifyMode := strings.TrimSpace(*input.NotifyMode)
if !isValidAnnouncementNotifyMode(notifyMode) {
return nil, fmt.Errorf("update announcement: invalid notify_mode")
}
a.NotifyMode = notifyMode
}
if input.Targeting != nil {
targeting, err := domain.AnnouncementTargeting(*input.Targeting).NormalizeAndValidate()
if err != nil {
return nil, err
}
a.Targeting = targeting
}
if input.StartsAt != nil {
a.StartsAt = *input.StartsAt
}
if input.EndsAt != nil {
a.EndsAt = *input.EndsAt
}
if a.StartsAt != nil && a.EndsAt != nil {
if !a.StartsAt.Before(*a.EndsAt) {
return nil, fmt.Errorf("update announcement: starts_at must be before ends_at")
}
}
if input.ActorID != nil && *input.ActorID > 0 {
a.UpdatedBy = input.ActorID
}
if err := s.announcementRepo.Update(ctx, a); err != nil {
return nil, fmt.Errorf("update announcement: %w", err)
}
return a, nil
}
func (s *AnnouncementService) Delete(ctx context.Context, id int64) error {
if err := s.announcementRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete announcement: %w", err)
}
return nil
}
func (s *AnnouncementService) GetByID(ctx context.Context, id int64) (*Announcement, error) {
return s.announcementRepo.GetByID(ctx, id)
}
func (s *AnnouncementService) List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) {
return s.announcementRepo.List(ctx, params, filters)
}
func (s *AnnouncementService) ListForUser(ctx context.Context, userID int64, unreadOnly bool) ([]UserAnnouncement, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("list active subscriptions: %w", err)
}
activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
for i := range activeSubs {
activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
}
now := time.Now()
anns, err := s.announcementRepo.ListActive(ctx, now)
if err != nil {
return nil, fmt.Errorf("list active announcements: %w", err)
}
visible := make([]Announcement, 0, len(anns))
ids := make([]int64, 0, len(anns))
for i := range anns {
a := anns[i]
if !a.IsActiveAt(now) {
continue
}
if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
continue
}
visible = append(visible, a)
ids = append(ids, a.ID)
}
if len(visible) == 0 {
return []UserAnnouncement{}, nil
}
readMap, err := s.readRepo.GetReadMapByUser(ctx, userID, ids)
if err != nil {
return nil, fmt.Errorf("get read map: %w", err)
}
out := make([]UserAnnouncement, 0, len(visible))
for i := range visible {
a := visible[i]
readAt, ok := readMap[a.ID]
if unreadOnly && ok {
continue
}
var ptr *time.Time
if ok {
t := readAt
ptr = &t
}
out = append(out, UserAnnouncement{
Announcement: a,
ReadAt: ptr,
})
}
// 未读优先、同状态按创建时间倒序
sort.Slice(out, func(i, j int) bool {
ai, aj := out[i], out[j]
if (ai.ReadAt == nil) != (aj.ReadAt == nil) {
return ai.ReadAt == nil
}
return ai.Announcement.ID > aj.Announcement.ID
})
return out, nil
}
func (s *AnnouncementService) MarkRead(ctx context.Context, userID, announcementID int64) error {
// 安全:仅允许标记当前用户“可见”的公告
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return fmt.Errorf("get user: %w", err)
}
a, err := s.announcementRepo.GetByID(ctx, announcementID)
if err != nil {
return err
}
now := time.Now()
if !a.IsActiveAt(now) {
return ErrAnnouncementNotFound
}
activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return fmt.Errorf("list active subscriptions: %w", err)
}
activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
for i := range activeSubs {
activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
}
if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
return ErrAnnouncementNotFound
}
if err := s.readRepo.MarkRead(ctx, announcementID, userID, now); err != nil {
return fmt.Errorf("mark read: %w", err)
}
return nil
}
func (s *AnnouncementService) ListUserReadStatus(
ctx context.Context,
announcementID int64,
params pagination.PaginationParams,
search string,
) ([]AnnouncementUserReadStatus, *pagination.PaginationResult, error) {
ann, err := s.announcementRepo.GetByID(ctx, announcementID)
if err != nil {
return nil, nil, err
}
filters := UserListFilters{
Search: strings.TrimSpace(search),
}
users, page, err := s.userRepo.ListWithFilters(ctx, params, filters)
if err != nil {
return nil, nil, fmt.Errorf("list users: %w", err)
}
userIDs := make([]int64, 0, len(users))
for i := range users {
userIDs = append(userIDs, users[i].ID)
}
readMap, err := s.readRepo.GetReadMapByUsers(ctx, announcementID, userIDs)
if err != nil {
return nil, nil, fmt.Errorf("get read map: %w", err)
}
out := make([]AnnouncementUserReadStatus, 0, len(users))
for i := range users {
u := users[i]
subs, err := s.userSubRepo.ListActiveByUserID(ctx, u.ID)
if err != nil {
return nil, nil, fmt.Errorf("list active subscriptions: %w", err)
}
activeGroupIDs := make(map[int64]struct{}, len(subs))
for j := range subs {
activeGroupIDs[subs[j].GroupID] = struct{}{}
}
readAt, ok := readMap[u.ID]
var ptr *time.Time
if ok {
t := readAt
ptr = &t
}
out = append(out, AnnouncementUserReadStatus{
UserID: u.ID,
Email: u.Email,
Username: u.Username,
Balance: u.Balance,
Eligible: domain.AnnouncementTargeting(ann.Targeting).Matches(u.Balance, activeGroupIDs),
ReadAt: ptr,
})
}
return out, page, nil
}
func isValidAnnouncementStatus(status string) bool {
switch status {
case AnnouncementStatusDraft, AnnouncementStatusActive, AnnouncementStatusArchived:
return true
default:
return false
}
}
func isValidAnnouncementNotifyMode(mode string) bool {
switch mode {
case AnnouncementNotifyModeSilent, AnnouncementNotifyModePopup:
return true
default:
return false
}
}

View File

@@ -0,0 +1,66 @@
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAnnouncementTargeting_Matches_EmptyMatchesAll(t *testing.T) {
var targeting AnnouncementTargeting
require.True(t, targeting.Matches(0, nil))
require.True(t, targeting.Matches(123.45, map[int64]struct{}{1: {}}))
}
func TestAnnouncementTargeting_NormalizeAndValidate_RejectsEmptyGroup(t *testing.T) {
targeting := AnnouncementTargeting{
AnyOf: []AnnouncementConditionGroup{
{AllOf: nil},
},
}
_, err := targeting.NormalizeAndValidate()
require.Error(t, err)
require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
}
func TestAnnouncementTargeting_NormalizeAndValidate_RejectsInvalidCondition(t *testing.T) {
targeting := AnnouncementTargeting{
AnyOf: []AnnouncementConditionGroup{
{
AllOf: []AnnouncementCondition{
{Type: "balance", Operator: "between", Value: 10},
},
},
},
}
_, err := targeting.NormalizeAndValidate()
require.Error(t, err)
require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
}
func TestAnnouncementTargeting_Matches_AndOrSemantics(t *testing.T) {
targeting := AnnouncementTargeting{
AnyOf: []AnnouncementConditionGroup{
{
AllOf: []AnnouncementCondition{
{Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorGTE, Value: 100},
{Type: AnnouncementConditionTypeSubscription, Operator: AnnouncementOperatorIn, GroupIDs: []int64{10}},
},
},
{
AllOf: []AnnouncementCondition{
{Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorLT, Value: 5},
},
},
},
}
// 命中第 2 组balance < 5
require.True(t, targeting.Matches(4.99, nil))
require.False(t, targeting.Matches(5, nil))
// 命中第 1 组balance >= 100 AND 订阅 in [10]
require.False(t, targeting.Matches(100, map[int64]struct{}{}))
require.False(t, targeting.Matches(99.9, map[int64]struct{}{10: {}}))
require.True(t, targeting.Matches(100, map[int64]struct{}{10: {}}))
}

View File

@@ -0,0 +1,79 @@
package service
import (
"encoding/json"
"strings"
"time"
)
// Anthropic 会话 Fallback 相关常量
const (
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL5 分钟)
anthropicSessionTTLSeconds = 300
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
)
// AnthropicSessionTTL 返回 Anthropic 会话缓存 TTL
func AnthropicSessionTTL() time.Duration {
return anthropicSessionTTLSeconds * time.Second
}
// BuildAnthropicDigestChain 根据 Anthropic 请求生成摘要链
// 格式: s:<hash>-u:<hash>-a:<hash>-u:<hash>-...
// s = system, u = user, a = assistant
func BuildAnthropicDigestChain(parsed *ParsedRequest) string {
if parsed == nil {
return ""
}
var parts []string
// 1. system prompt
if parsed.System != nil {
systemData, _ := json.Marshal(parsed.System)
if len(systemData) > 0 && string(systemData) != "null" {
parts = append(parts, "s:"+shortHash(systemData))
}
}
// 2. messages
for _, msg := range parsed.Messages {
msgMap, ok := msg.(map[string]any)
if !ok {
continue
}
role, _ := msgMap["role"].(string)
prefix := rolePrefix(role)
content := msgMap["content"]
contentData, _ := json.Marshal(content)
parts = append(parts, prefix+":"+shortHash(contentData))
}
return strings.Join(parts, "-")
}
// rolePrefix 将 Anthropic 的 role 映射为单字符前缀
func rolePrefix(role string) string {
switch role {
case "assistant":
return "a"
default:
return "u"
}
}
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {
prefix := prefixHash
if len(prefixHash) >= 8 {
prefix = prefixHash[:8]
}
uuidPart := uuid
if len(uuid) >= 8 {
uuidPart = uuid[:8]
}
return anthropicDigestSessionKeyPrefix + prefix + ":" + uuidPart
}

View File

@@ -0,0 +1,320 @@
package service
import (
"strings"
"testing"
)
func TestBuildAnthropicDigestChain_NilRequest(t *testing.T) {
result := BuildAnthropicDigestChain(nil)
if result != "" {
t.Errorf("expected empty string for nil request, got: %s", result)
}
}
func TestBuildAnthropicDigestChain_EmptyMessages(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{},
}
result := BuildAnthropicDigestChain(parsed)
if result != "" {
t.Errorf("expected empty string for empty messages, got: %s", result)
}
}
func TestBuildAnthropicDigestChain_SingleUserMessage(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 1 {
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("expected prefix 'u:', got: %s", parts[0])
}
}
func TestBuildAnthropicDigestChain_UserAndAssistant(t *testing.T) {
parsed := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("part[0] expected prefix 'u:', got: %s", parts[0])
}
if !strings.HasPrefix(parts[1], "a:") {
t.Errorf("part[1] expected prefix 'a:', got: %s", parts[1])
}
}
func TestBuildAnthropicDigestChain_WithSystemString(t *testing.T) {
parsed := &ParsedRequest{
System: "You are a helpful assistant",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "s:") {
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
}
if !strings.HasPrefix(parts[1], "u:") {
t.Errorf("part[1] expected prefix 'u:', got: %s", parts[1])
}
}
func TestBuildAnthropicDigestChain_WithSystemContentBlocks(t *testing.T) {
parsed := &ParsedRequest{
System: []any{
map[string]any{"type": "text", "text": "You are a helpful assistant"},
},
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 2 {
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "s:") {
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
}
}
func TestBuildAnthropicDigestChain_ConversationPrefixRelationship(t *testing.T) {
// 核心测试:验证对话增长时链的前缀关系
// 上一轮的完整链一定是下一轮链的前缀
system := "You are a helpful assistant"
// 第 1 轮: system + user
round1 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
chain1 := BuildAnthropicDigestChain(round1)
// 第 2 轮: system + user + assistant + user
round2 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
map[string]any{"role": "user", "content": "how are you?"},
},
}
chain2 := BuildAnthropicDigestChain(round2)
// 第 3 轮: system + user + assistant + user + assistant + user
round3 := &ParsedRequest{
System: system,
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi there"},
map[string]any{"role": "user", "content": "how are you?"},
map[string]any{"role": "assistant", "content": "I'm doing well"},
map[string]any{"role": "user", "content": "great"},
},
}
chain3 := BuildAnthropicDigestChain(round3)
t.Logf("Chain1: %s", chain1)
t.Logf("Chain2: %s", chain2)
t.Logf("Chain3: %s", chain3)
// chain1 是 chain2 的前缀
if !strings.HasPrefix(chain2, chain1) {
t.Errorf("chain1 should be prefix of chain2:\n chain1: %s\n chain2: %s", chain1, chain2)
}
// chain2 是 chain3 的前缀
if !strings.HasPrefix(chain3, chain2) {
t.Errorf("chain2 should be prefix of chain3:\n chain2: %s\n chain3: %s", chain2, chain3)
}
// chain1 也是 chain3 的前缀(传递性)
if !strings.HasPrefix(chain3, chain1) {
t.Errorf("chain1 should be prefix of chain3:\n chain1: %s\n chain3: %s", chain1, chain3)
}
}
func TestBuildAnthropicDigestChain_DifferentSystemProducesDifferentChain(t *testing.T) {
parsed1 := &ParsedRequest{
System: "System A",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
parsed2 := &ParsedRequest{
System: "System B",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
},
}
chain1 := BuildAnthropicDigestChain(parsed1)
chain2 := BuildAnthropicDigestChain(parsed2)
if chain1 == chain2 {
t.Error("Different system prompts should produce different chains")
}
// 但 user 部分的 hash 应该相同
parts1 := splitChain(chain1)
parts2 := splitChain(chain2)
if parts1[1] != parts2[1] {
t.Error("Same user message should produce same hash regardless of system")
}
}
func TestBuildAnthropicDigestChain_DifferentContentProducesDifferentChain(t *testing.T) {
parsed1 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "ORIGINAL reply"},
map[string]any{"role": "user", "content": "next"},
},
}
parsed2 := &ParsedRequest{
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "TAMPERED reply"},
map[string]any{"role": "user", "content": "next"},
},
}
chain1 := BuildAnthropicDigestChain(parsed1)
chain2 := BuildAnthropicDigestChain(parsed2)
if chain1 == chain2 {
t.Error("Different content should produce different chains")
}
parts1 := splitChain(chain1)
parts2 := splitChain(chain2)
// 第一个 user message hash 应该相同
if parts1[0] != parts2[0] {
t.Error("First user message hash should be the same")
}
// assistant reply hash 应该不同
if parts1[1] == parts2[1] {
t.Error("Assistant reply hash should differ")
}
}
func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) {
parsed := &ParsedRequest{
System: "test system",
Messages: []any{
map[string]any{"role": "user", "content": "hello"},
map[string]any{"role": "assistant", "content": "hi"},
},
}
chain1 := BuildAnthropicDigestChain(parsed)
chain2 := BuildAnthropicDigestChain(parsed)
if chain1 != chain2 {
t.Errorf("BuildAnthropicDigestChain not deterministic: %s vs %s", chain1, chain2)
}
}
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
tests := []struct {
name string
prefixHash string
uuid string
want string
}{
{
name: "normal 16 char hash with uuid",
prefixHash: "abcdefgh12345678",
uuid: "550e8400-e29b-41d4-a716-446655440000",
want: "anthropic:digest:abcdefgh:550e8400",
},
{
name: "exactly 8 chars",
prefixHash: "12345678",
uuid: "abcdefgh",
want: "anthropic:digest:12345678:abcdefgh",
},
{
name: "short values",
prefixHash: "abc",
uuid: "xyz",
want: "anthropic:digest:abc:xyz",
},
{
name: "empty values",
prefixHash: "",
uuid: "",
want: "anthropic:digest::",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GenerateAnthropicDigestSessionKey(tt.prefixHash, tt.uuid)
if got != tt.want {
t.Errorf("GenerateAnthropicDigestSessionKey(%q, %q) = %q, want %q", tt.prefixHash, tt.uuid, got, tt.want)
}
})
}
// 验证不同 uuid 产生不同 sessionKey
t.Run("different uuid different key", func(t *testing.T) {
hash := "sameprefix123456"
result1 := GenerateAnthropicDigestSessionKey(hash, "uuid0001-session-a")
result2 := GenerateAnthropicDigestSessionKey(hash, "uuid0002-session-b")
if result1 == result2 {
t.Errorf("Different UUIDs should produce different session keys: %s vs %s", result1, result2)
}
})
}
func TestAnthropicSessionTTL(t *testing.T) {
ttl := AnthropicSessionTTL()
if ttl.Seconds() != 300 {
t.Errorf("expected 300 seconds, got: %v", ttl.Seconds())
}
}
func TestBuildAnthropicDigestChain_ContentBlocks(t *testing.T) {
// 测试 content 为 content blocks 数组的情况
parsed := &ParsedRequest{
Messages: []any{
map[string]any{
"role": "user",
"content": []any{
map[string]any{"type": "text", "text": "describe this image"},
map[string]any{"type": "image", "source": map[string]any{"type": "base64"}},
},
},
},
}
result := BuildAnthropicDigestChain(parsed)
parts := splitChain(result)
if len(parts) != 1 {
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
}
if !strings.HasPrefix(parts[0], "u:") {
t.Errorf("expected prefix 'u:', got: %s", parts[0])
}
}

View File

@@ -0,0 +1,234 @@
package service
import (
"context"
"encoding/json"
"io"
"net/http"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
const (
// creditsExhaustedKey 是 model_rate_limits 中标记积分耗尽的特殊 key。
// 与普通模型限流完全同构:通过 SetModelRateLimit / isRateLimitActiveForKey 读写。
creditsExhaustedKey = "AICredits"
creditsExhaustedDuration = 5 * time.Hour
)
type antigravity429Category string
const (
antigravity429Unknown antigravity429Category = "unknown"
antigravity429RateLimited antigravity429Category = "rate_limited"
antigravity429QuotaExhausted antigravity429Category = "quota_exhausted"
)
var (
antigravityQuotaExhaustedKeywords = []string{
"quota_exhausted",
"quota exhausted",
}
creditsExhaustedKeywords = []string{
"google_one_ai",
"insufficient credit",
"insufficient credits",
"not enough credit",
"not enough credits",
"credit exhausted",
"credits exhausted",
"credit balance",
"minimumcreditamountforusage",
"minimum credit amount for usage",
"minimum credit",
}
)
// isCreditsExhausted 检查账号的 AICredits 限流 key 是否生效(积分是否耗尽)。
func (a *Account) isCreditsExhausted() bool {
if a == nil {
return false
}
return a.isRateLimitActiveForKey(creditsExhaustedKey)
}
// setCreditsExhausted 标记账号积分耗尽:写入 model_rate_limits["AICredits"] + 更新缓存。
func (s *AntigravityGatewayService) setCreditsExhausted(ctx context.Context, account *Account) {
if account == nil || account.ID == 0 {
return
}
resetAt := time.Now().Add(creditsExhaustedDuration)
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, creditsExhaustedKey, resetAt); err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "set credits exhausted failed: account=%d err=%v", account.ID, err)
return
}
s.updateAccountModelRateLimitInCache(ctx, account, creditsExhaustedKey, resetAt)
logger.LegacyPrintf("service.antigravity_gateway", "credits_exhausted_marked account=%d reset_at=%s",
account.ID, resetAt.UTC().Format(time.RFC3339))
}
// clearCreditsExhausted 清除账号的 AICredits 限流 key。
func (s *AntigravityGatewayService) clearCreditsExhausted(ctx context.Context, account *Account) {
if account == nil || account.ID == 0 || account.Extra == nil {
return
}
rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any)
if !ok {
return
}
if _, exists := rawLimits[creditsExhaustedKey]; !exists {
return
}
delete(rawLimits, creditsExhaustedKey)
account.Extra[modelRateLimitsKey] = rawLimits
if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{
modelRateLimitsKey: rawLimits,
}); err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "clear credits exhausted failed: account=%d err=%v", account.ID, err)
}
}
// classifyAntigravity429 将 Antigravity 的 429 响应归类为配额耗尽、限流或未知。
func classifyAntigravity429(body []byte) antigravity429Category {
if len(body) == 0 {
return antigravity429Unknown
}
lowerBody := strings.ToLower(string(body))
for _, keyword := range antigravityQuotaExhaustedKeywords {
if strings.Contains(lowerBody, keyword) {
return antigravity429QuotaExhausted
}
}
if info := parseAntigravitySmartRetryInfo(body); info != nil && !info.IsModelCapacityExhausted {
return antigravity429RateLimited
}
return antigravity429Unknown
}
// injectEnabledCreditTypes 在已序列化的 v1internal JSON body 中注入 AI Credits 类型。
func injectEnabledCreditTypes(body []byte) []byte {
var payload map[string]any
if err := json.Unmarshal(body, &payload); err != nil {
return nil
}
payload["enabledCreditTypes"] = []string{"GOOGLE_ONE_AI"}
result, err := json.Marshal(payload)
if err != nil {
return nil
}
return result
}
// resolveCreditsOveragesModelKey 解析当前请求对应的 overages 状态模型 key。
func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstreamModelName, requestedModel string) string {
modelKey := strings.TrimSpace(upstreamModelName)
if modelKey != "" {
return modelKey
}
if account == nil {
return ""
}
modelKey = resolveFinalAntigravityModelKey(ctx, account, requestedModel)
if strings.TrimSpace(modelKey) != "" {
return modelKey
}
return resolveAntigravityModelKey(requestedModel)
}
// shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。
func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool {
if reqErr != nil || resp == nil {
return false
}
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout {
return false
}
if isURLLevelRateLimit(respBody) {
return false
}
if info := parseAntigravitySmartRetryInfo(respBody); info != nil {
return false
}
bodyLower := strings.ToLower(string(respBody))
for _, keyword := range creditsExhaustedKeywords {
if strings.Contains(bodyLower, keyword) {
return true
}
}
return false
}
type creditsOveragesRetryResult struct {
handled bool
resp *http.Response
}
// attemptCreditsOveragesRetry 在确认免费配额耗尽后,尝试注入 AI Credits 继续请求。
func (s *AntigravityGatewayService) attemptCreditsOveragesRetry(
p antigravityRetryLoopParams,
baseURL string,
modelName string,
waitDuration time.Duration,
originalStatusCode int,
respBody []byte,
) *creditsOveragesRetryResult {
creditsBody := injectEnabledCreditTypes(p.body)
if creditsBody == nil {
return &creditsOveragesRetryResult{handled: false}
}
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)",
p.prefix, modelKey, p.account.ID)
creditsReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, creditsBody)
if err != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d build_request_err=%v",
p.prefix, modelKey, p.account.ID, err)
return &creditsOveragesRetryResult{handled: true}
}
creditsResp, err := p.httpUpstream.Do(creditsReq, p.proxyURL, p.account.ID, p.account.Concurrency)
if err == nil && creditsResp != nil && creditsResp.StatusCode < 400 {
s.clearCreditsExhausted(p.ctx, p.account)
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d credit_overages_success model=%s account=%d",
p.prefix, creditsResp.StatusCode, modelKey, p.account.ID)
return &creditsOveragesRetryResult{handled: true, resp: creditsResp}
}
s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, creditsResp, err)
return &creditsOveragesRetryResult{handled: true}
}
func (s *AntigravityGatewayService) handleCreditsRetryFailure(
ctx context.Context,
prefix string,
modelKey string,
account *Account,
creditsResp *http.Response,
reqErr error,
) {
var creditsRespBody []byte
creditsStatusCode := 0
if creditsResp != nil {
creditsStatusCode = creditsResp.StatusCode
if creditsResp.Body != nil {
creditsRespBody, _ = io.ReadAll(io.LimitReader(creditsResp.Body, 64<<10))
_ = creditsResp.Body.Close()
}
}
if shouldMarkCreditsExhausted(creditsResp, creditsRespBody, reqErr) && account != nil {
s.setCreditsExhausted(ctx, account)
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=true status=%d body=%s",
prefix, modelKey, account.ID, creditsStatusCode, truncateForLog(creditsRespBody, 200))
return
}
if account != nil {
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=false status=%d err=%v body=%s",
prefix, modelKey, account.ID, creditsStatusCode, reqErr, truncateForLog(creditsRespBody, 200))
}
}

View File

@@ -0,0 +1,538 @@
//go:build unit
package service
import (
"bytes"
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
"github.com/stretchr/testify/require"
)
func TestClassifyAntigravity429(t *testing.T) {
t.Run("明确配额耗尽", func(t *testing.T) {
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body))
})
t.Run("结构化限流", func(t *testing.T) {
body := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
]
}
}`)
require.Equal(t, antigravity429RateLimited, classifyAntigravity429(body))
})
t.Run("未知429", func(t *testing.T) {
body := []byte(`{"error":{"message":"too many requests"}}`)
require.Equal(t, antigravity429Unknown, classifyAntigravity429(body))
})
}
func TestIsCreditsExhausted_UsesAICreditsKey(t *testing.T) {
t.Run("无 AICredits key 则积分可用", func(t *testing.T) {
account := &Account{
ID: 1,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
},
}
require.False(t, account.isCreditsExhausted())
})
t.Run("AICredits key 生效则积分耗尽", func(t *testing.T) {
account := &Account{
ID: 2,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
creditsExhaustedKey: map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
}
require.True(t, account.isCreditsExhausted())
})
t.Run("AICredits key 过期则积分可用", func(t *testing.T) {
account := &Account{
ID: 3,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
creditsExhaustedKey: map[string]any{
"rate_limited_at": time.Now().Add(-6 * time.Hour).UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
}
require.False(t, account.isCreditsExhausted())
})
}
func TestHandleSmartRetry_QuotaExhausted_UsesCreditsAndStoresIndependentState(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 101,
Name: "acc-101",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
},
Credentials: map[string]any{
"model_mapping": map[string]any{
"claude-opus-4-6": "claude-sonnet-4-5",
},
},
}
respBody := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-opus-4-6","request":{}}`),
httpUpstream: upstream,
accountRepo: repo,
requestedModel: "claude-opus-4-6",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp)
require.Nil(t, result.switchError)
require.Len(t, upstream.requestBodies, 1)
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
require.Empty(t, repo.modelRateLimitCalls, "overages 成功后不应写入普通 model_rate_limits")
}
func TestHandleSmartRetry_RateLimited_DoesNotUseCredits(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 102,
Name: "acc-102",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Extra: map[string]any{
"allow_overages": true,
},
}
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp)
require.Len(t, upstream.requestBodies, 1)
require.NotContains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
require.Empty(t, repo.extraUpdateCalls)
require.Empty(t, repo.modelRateLimitCalls)
}
func TestAntigravityRetryLoop_ModelRateLimited_InjectsCredits(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvailability
}()
antigravity.BaseURLs = []string{"https://ag-1.test"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
},
},
errors: []error{nil},
}
// 模型已限流 + overages 启用 + 无 AICredits key → 应直接注入积分
account := &Account{
ID: 103,
Name: "acc-103",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.NoError(t, err)
require.NotNil(t, result)
require.Len(t, upstream.requestBodies, 1)
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
}
func TestAntigravityRetryLoop_CreditsExhausted_DoesNotInject(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvailability
}()
antigravity.BaseURLs = []string{"https://ag-1.test"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
// 模型限流 + overages 启用 + AICredits key 生效 → 不应注入积分,应切号
account := &Account{
ID: 104,
Name: "acc-104",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
},
creditsExhaustedKey: map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{}
_, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
// 模型限流 + 积分耗尽 → 应触发切号错误
require.Error(t, err)
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr)
}
func TestAntigravityRetryLoop_CreditErrorMarksExhausted(t *testing.T) {
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
oldAvailability := antigravity.DefaultURLAvailability
defer func() {
antigravity.BaseURLs = oldBaseURLs
antigravity.DefaultURLAvailability = oldAvailability
}()
antigravity.BaseURLs = []string{"https://ag-1.test"}
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
repo := &stubAntigravityAccountRepo{}
upstream := &queuedHTTPUpstreamStub{
responses: []*http.Response{
{
StatusCode: http.StatusForbidden,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`)),
},
},
errors: []error{nil},
}
// 模型限流 + overages 启用 + 积分可用 → 注入积分但上游返回积分不足
account := &Account{
ID: 105,
Name: "acc-105",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Status: StatusActive,
Schedulable: true,
Extra: map[string]any{
"allow_overages": true,
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{accountRepo: repo}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
httpUpstream: upstream,
accountRepo: repo,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.NoError(t, err)
require.NotNil(t, result)
// 验证 AICredits key 已通过 SetModelRateLimit 写入数据库
require.Len(t, repo.modelRateLimitCalls, 1, "应通过 SetModelRateLimit 写入 AICredits key")
require.Equal(t, creditsExhaustedKey, repo.modelRateLimitCalls[0].modelKey)
}
func TestShouldMarkCreditsExhausted(t *testing.T) {
t.Run("reqErr 不为 nil 时不标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusForbidden}
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), io.ErrUnexpectedEOF))
})
t.Run("resp 为 nil 时不标记", func(t *testing.T) {
require.False(t, shouldMarkCreditsExhausted(nil, []byte(`{"error":"Insufficient credits"}`), nil))
})
t.Run("5xx 响应不标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusInternalServerError}
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
})
t.Run("408 RequestTimeout 不标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusRequestTimeout}
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
})
t.Run("URL 级限流不标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
body := []byte(`{"error":{"message":"Resource has been exhausted"}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
})
t.Run("结构化限流不标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
})
t.Run("含 credits 关键词时标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusForbidden}
for _, keyword := range []string{
"Insufficient GOOGLE_ONE_AI credits",
"insufficient credit balance",
"not enough credits for this request",
"Credits exhausted",
"minimumCreditAmountForUsage requirement not met",
} {
body := []byte(`{"error":{"message":"` + keyword + `"}}`)
require.True(t, shouldMarkCreditsExhausted(resp, body, nil), "should mark for keyword: %s", keyword)
}
})
t.Run("无 credits 关键词时不标记", func(t *testing.T) {
resp := &http.Response{StatusCode: http.StatusForbidden}
body := []byte(`{"error":{"message":"permission denied"}}`)
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
})
}
func TestInjectEnabledCreditTypes(t *testing.T) {
t.Run("正常 JSON 注入成功", func(t *testing.T) {
body := []byte(`{"model":"claude-sonnet-4-5","request":{}}`)
result := injectEnabledCreditTypes(body)
require.NotNil(t, result)
require.Contains(t, string(result), `"enabledCreditTypes"`)
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
})
t.Run("非法 JSON 返回 nil", func(t *testing.T) {
require.Nil(t, injectEnabledCreditTypes([]byte(`not json`)))
})
t.Run("空 body 返回 nil", func(t *testing.T) {
require.Nil(t, injectEnabledCreditTypes([]byte{}))
})
t.Run("已有 enabledCreditTypes 会被覆盖", func(t *testing.T) {
body := []byte(`{"enabledCreditTypes":["OLD"],"model":"test"}`)
result := injectEnabledCreditTypes(body)
require.NotNil(t, result)
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
require.NotContains(t, string(result), `OLD`)
})
}
func TestClearCreditsExhausted(t *testing.T) {
t.Run("account 为 nil 不操作", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
svc.clearCreditsExhausted(context.Background(), nil)
require.Empty(t, repo.extraUpdateCalls)
})
t.Run("Extra 为 nil 不操作", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
svc.clearCreditsExhausted(context.Background(), &Account{ID: 1})
require.Empty(t, repo.extraUpdateCalls)
})
t.Run("无 modelRateLimitsKey 不操作", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
svc.clearCreditsExhausted(context.Background(), &Account{
ID: 1,
Extra: map[string]any{"some_key": "value"},
})
require.Empty(t, repo.extraUpdateCalls)
})
t.Run("无 AICredits key 不操作", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
svc.clearCreditsExhausted(context.Background(), &Account{
ID: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
},
},
})
require.Empty(t, repo.extraUpdateCalls)
})
t.Run("有 AICredits key 时删除并调用 UpdateExtra", func(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
svc := &AntigravityGatewayService{accountRepo: repo}
account := &Account{
ID: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
},
creditsExhaustedKey: map[string]any{
"rate_limited_at": "2026-03-15T00:00:00Z",
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
},
},
},
}
svc.clearCreditsExhausted(context.Background(), account)
require.Len(t, repo.extraUpdateCalls, 1)
// AICredits key 应被删除
rawLimits := account.Extra[modelRateLimitsKey].(map[string]any)
_, exists := rawLimits[creditsExhaustedKey]
require.False(t, exists, "AICredits key 应被删除")
// 普通模型限流应保留
_, exists = rawLimits["claude-sonnet-4-5"]
require.True(t, exists, "普通模型限流应保留")
})
}

View File

@@ -0,0 +1,123 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestIsImageGenerationModel_GeminiProImage 测试 gemini-3-pro-image 识别
func TestIsImageGenerationModel_GeminiProImage(t *testing.T) {
require.True(t, isImageGenerationModel("gemini-3-pro-image"))
require.True(t, isImageGenerationModel("gemini-3-pro-image-preview"))
require.True(t, isImageGenerationModel("models/gemini-3-pro-image"))
}
// TestIsImageGenerationModel_GeminiFlashImage 测试 gemini-2.5-flash-image 识别
func TestIsImageGenerationModel_GeminiFlashImage(t *testing.T) {
require.True(t, isImageGenerationModel("gemini-2.5-flash-image"))
require.True(t, isImageGenerationModel("gemini-2.5-flash-image-preview"))
}
// TestIsImageGenerationModel_RegularModel 测试普通模型不被识别为图片模型
func TestIsImageGenerationModel_RegularModel(t *testing.T) {
require.False(t, isImageGenerationModel("claude-3-opus"))
require.False(t, isImageGenerationModel("claude-sonnet-4-20250514"))
require.False(t, isImageGenerationModel("gpt-4o"))
require.False(t, isImageGenerationModel("gemini-2.5-pro")) // 非图片模型
require.False(t, isImageGenerationModel("gemini-2.5-flash"))
// 验证不会误匹配包含关键词的自定义模型名
require.False(t, isImageGenerationModel("my-gemini-3-pro-image-test"))
require.False(t, isImageGenerationModel("custom-gemini-2.5-flash-image-wrapper"))
}
// TestIsImageGenerationModel_CaseInsensitive 测试大小写不敏感
func TestIsImageGenerationModel_CaseInsensitive(t *testing.T) {
require.True(t, isImageGenerationModel("GEMINI-3-PRO-IMAGE"))
require.True(t, isImageGenerationModel("Gemini-3-Pro-Image"))
require.True(t, isImageGenerationModel("GEMINI-2.5-FLASH-IMAGE"))
}
// TestExtractImageSize_ValidSizes 测试有效尺寸解析
func TestExtractImageSize_ValidSizes(t *testing.T) {
svc := &AntigravityGatewayService{}
// 1K
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1K"}}}`)
require.Equal(t, "1K", svc.extractImageSize(body))
// 2K
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"2K"}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
// 4K
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4K"}}}`)
require.Equal(t, "4K", svc.extractImageSize(body))
}
// TestExtractImageSize_CaseInsensitive 测试大小写不敏感
func TestExtractImageSize_CaseInsensitive(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1k"}}}`)
require.Equal(t, "1K", svc.extractImageSize(body))
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4k"}}}`)
require.Equal(t, "4K", svc.extractImageSize(body))
}
// TestExtractImageSize_Default 测试无 imageConfig 返回默认 2K
func TestExtractImageSize_Default(t *testing.T) {
svc := &AntigravityGatewayService{}
// 无 generationConfig
body := []byte(`{"contents":[]}`)
require.Equal(t, "2K", svc.extractImageSize(body))
// 有 generationConfig 但无 imageConfig
body = []byte(`{"generationConfig":{"temperature":0.7}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
// 有 imageConfig 但无 imageSize
body = []byte(`{"generationConfig":{"imageConfig":{}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
}
// TestExtractImageSize_InvalidJSON 测试非法 JSON 返回默认 2K
func TestExtractImageSize_InvalidJSON(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`not valid json`)
require.Equal(t, "2K", svc.extractImageSize(body))
body = []byte(`{"broken":`)
require.Equal(t, "2K", svc.extractImageSize(body))
}
// TestExtractImageSize_EmptySize 测试空 imageSize 返回默认 2K
func TestExtractImageSize_EmptySize(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":""}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
// 空格
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":" "}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
}
// TestExtractImageSize_InvalidSize 测试无效尺寸返回默认 2K
func TestExtractImageSize_InvalidSize(t *testing.T) {
svc := &AntigravityGatewayService{}
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"3K"}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"8K"}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"invalid"}}}`)
require.Equal(t, "2K", svc.extractImageSize(body))
}

View File

@@ -0,0 +1,285 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
requestedModel string
accountMapping map[string]string
expected string
}{
// 1. 账户级映射优先
{
name: "账户映射优先",
requestedModel: "claude-3-5-sonnet-20241022",
accountMapping: map[string]string{"claude-3-5-sonnet-20241022": "custom-model"},
expected: "custom-model",
},
{
name: "账户映射 - 可覆盖默认映射的模型",
requestedModel: "claude-sonnet-4-5",
accountMapping: map[string]string{"claude-sonnet-4-5": "my-custom-sonnet"},
expected: "my-custom-sonnet",
},
{
name: "账户映射 - 可覆盖未知模型",
requestedModel: "claude-opus-4",
accountMapping: map[string]string{"claude-opus-4": "my-opus"},
expected: "my-opus",
},
// 2. 默认映射DefaultAntigravityModelMapping
{
name: "默认映射 - claude-opus-4-6 → claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-6",
accountMapping: nil,
expected: "claude-opus-4-6-thinking",
},
{
name: "默认映射 - claude-opus-4-5-20251101 → claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-5-20251101",
accountMapping: nil,
expected: "claude-opus-4-6-thinking",
},
{
name: "默认映射 - claude-opus-4-5-thinking → claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-5-thinking",
accountMapping: nil,
expected: "claude-opus-4-6-thinking",
},
{
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
requestedModel: "claude-haiku-4-5-20251001",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5-20250929",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
// 3. 默认映射中的透传(映射到自己)
{
name: "默认映射透传 - claude-sonnet-4-6",
requestedModel: "claude-sonnet-4-6",
accountMapping: nil,
expected: "claude-sonnet-4-6",
},
{
name: "默认映射透传 - claude-sonnet-4-5",
requestedModel: "claude-sonnet-4-5",
accountMapping: nil,
expected: "claude-sonnet-4-5",
},
{
name: "默认映射透传 - claude-opus-4-6-thinking",
requestedModel: "claude-opus-4-6-thinking",
accountMapping: nil,
expected: "claude-opus-4-6-thinking",
},
{
name: "默认映射透传 - claude-sonnet-4-5-thinking",
requestedModel: "claude-sonnet-4-5-thinking",
accountMapping: nil,
expected: "claude-sonnet-4-5-thinking",
},
{
name: "默认映射透传 - gemini-2.5-flash",
requestedModel: "gemini-2.5-flash",
accountMapping: nil,
expected: "gemini-2.5-flash",
},
{
name: "默认映射透传 - gemini-2.5-pro",
requestedModel: "gemini-2.5-pro",
accountMapping: nil,
expected: "gemini-2.5-pro",
},
{
name: "默认映射透传 - gemini-3-flash",
requestedModel: "gemini-3-flash",
accountMapping: nil,
expected: "gemini-3-flash",
},
// 4. 未在默认映射中的模型返回空字符串(不支持)
{
name: "未知模型 - claude-unknown 返回空",
requestedModel: "claude-unknown",
accountMapping: nil,
expected: "",
},
{
name: "未知模型 - claude-3-5-sonnet-20241022 返回空(未在默认映射)",
requestedModel: "claude-3-5-sonnet-20241022",
accountMapping: nil,
expected: "",
},
{
name: "未知模型 - claude-3-opus-20240229 返回空",
requestedModel: "claude-3-opus-20240229",
accountMapping: nil,
expected: "",
},
{
name: "未知模型 - claude-opus-4 返回空",
requestedModel: "claude-opus-4",
accountMapping: nil,
expected: "",
},
{
name: "未知模型 - gemini-future-model 返回空",
requestedModel: "gemini-future-model",
accountMapping: nil,
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
}
if tt.accountMapping != nil {
// GetModelMapping 期望 model_mapping 是 map[string]any 格式
mappingAny := make(map[string]any)
for k, v := range tt.accountMapping {
mappingAny[k] = v
}
account.Credentials = map[string]any{
"model_mapping": mappingAny,
}
}
got := svc.getMappedModel(account, tt.requestedModel)
require.Equal(t, tt.expected, got, "model: %s", tt.requestedModel)
})
}
}
func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
requestedModel string
expected string
}{
// 空字符串和非 claude/gemini 前缀返回空字符串
{"空字符串", "", ""},
{"非claude/gemini前缀 - gpt", "gpt-4", ""},
{"非claude/gemini前缀 - llama", "llama-3", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{Platform: PlatformAntigravity}
got := svc.getMappedModel(account, tt.requestedModel)
require.Equal(t, tt.expected, got)
})
}
}
func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
svc := &AntigravityGatewayService{}
tests := []struct {
name string
model string
expected bool
}{
// 直接支持
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
// 可映射(有明确前缀映射)
{"可映射 - claude-opus-4-6", "claude-opus-4-6", true},
// 前缀透传claude 和 gemini 前缀)
{"Gemini前缀", "gemini-unknown", true},
{"Claude前缀", "claude-unknown", true},
// 不支持
{"不支持 - gpt-4", "gpt-4", false},
{"不支持 - 空字符串", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := svc.IsModelSupported(tt.model)
require.Equal(t, tt.expected, got)
})
}
}
// TestMapAntigravityModel_WildcardTargetEqualsRequest 测试通配符映射目标恰好等于请求模型名的 edge case
// 例如 {"claude-*": "claude-sonnet-4-5"},请求 "claude-sonnet-4-5" 时应该通过
func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) {
tests := []struct {
name string
modelMapping map[string]any
requestedModel string
expected string
}{
{
name: "wildcard target equals request model",
modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5"},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5",
},
{
name: "wildcard target differs from request model",
modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5"},
requestedModel: "claude-opus-4-6",
expected: "claude-sonnet-4-5",
},
{
name: "wildcard no match",
modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5"},
requestedModel: "gpt-4o",
expected: "",
},
{
name: "explicit passthrough same name",
modelMapping: map[string]any{"claude-sonnet-4-5": "claude-sonnet-4-5"},
requestedModel: "claude-sonnet-4-5",
expected: "claude-sonnet-4-5",
},
{
name: "multiple wildcards target equals one request",
modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5", "gemini-*": "gemini-2.5-flash"},
requestedModel: "gemini-2.5-flash",
expected: "gemini-2.5-flash",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Credentials: map[string]any{
"model_mapping": tt.modelMapping,
},
}
got := mapAntigravityModel(account, tt.requestedModel)
require.Equal(t, tt.expected, got, "mapAntigravityModel(%q) = %q, want %q", tt.requestedModel, got, tt.expected)
})
}
}

View File

@@ -0,0 +1,436 @@
package service
import (
"context"
"fmt"
"strconv"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
type AntigravityOAuthService struct {
sessionStore *antigravity.SessionStore
proxyRepo ProxyRepository
}
func NewAntigravityOAuthService(proxyRepo ProxyRepository) *AntigravityOAuthService {
return &AntigravityOAuthService{
sessionStore: antigravity.NewSessionStore(),
proxyRepo: proxyRepo,
}
}
// AntigravityAuthURLResult is the result of generating an authorization URL
type AntigravityAuthURLResult struct {
AuthURL string `json:"auth_url"`
SessionID string `json:"session_id"`
State string `json:"state"`
}
// GenerateAuthURL 生成 Google OAuth 授权链接
func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*AntigravityAuthURLResult, error) {
state, err := antigravity.GenerateState()
if err != nil {
return nil, fmt.Errorf("生成 state 失败: %w", err)
}
codeVerifier, err := antigravity.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("生成 code_verifier 失败: %w", err)
}
sessionID, err := antigravity.GenerateSessionID()
if err != nil {
return nil, fmt.Errorf("生成 session_id 失败: %w", err)
}
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
session := &antigravity.OAuthSession{
State: state,
CodeVerifier: codeVerifier,
ProxyURL: proxyURL,
CreatedAt: time.Now(),
}
s.sessionStore.Set(sessionID, session)
codeChallenge := antigravity.GenerateCodeChallenge(codeVerifier)
authURL := antigravity.BuildAuthorizationURL(state, codeChallenge)
return &AntigravityAuthURLResult{
AuthURL: authURL,
SessionID: sessionID,
State: state,
}, nil
}
// AntigravityExchangeCodeInput 交换 code 的输入
type AntigravityExchangeCodeInput struct {
SessionID string
State string
Code string
ProxyID *int64
}
// AntigravityTokenInfo token 信息
type AntigravityTokenInfo struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int64 `json:"expires_in"`
ExpiresAt int64 `json:"expires_at"`
TokenType string `json:"token_type"`
Email string `json:"email,omitempty"`
ProjectID string `json:"project_id,omitempty"`
ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id
}
// ExchangeCode 用 authorization code 交换 token
func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *AntigravityExchangeCodeInput) (*AntigravityTokenInfo, error) {
session, ok := s.sessionStore.Get(input.SessionID)
if !ok {
return nil, fmt.Errorf("session 不存在或已过期")
}
if strings.TrimSpace(input.State) == "" || input.State != session.State {
return nil, fmt.Errorf("state 无效")
}
// 确定代理 URL
proxyURL := session.ProxyURL
if input.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
// 交换 token
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
if err != nil {
return nil, fmt.Errorf("token 交换失败: %w", err)
}
// 删除 session
s.sessionStore.Delete(input.SessionID)
// 计算过期时间(减去 5 分钟安全窗口)
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
result := &AntigravityTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: expiresAt,
TokenType: tokenResp.TokenType,
}
// 获取用户信息
userInfo, err := client.GetUserInfo(ctx, tokenResp.AccessToken)
if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
} else {
result.Email = userInfo.Email
}
// 获取 project_id部分账户类型可能没有失败时重试
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
result.ProjectIDMissing = true
} else {
result.ProjectID = projectID
}
return result, nil
}
// RefreshToken 刷新 token
func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*AntigravityTokenInfo, error) {
var lastErr error
for attempt := 0; attempt <= 3; attempt++ {
if attempt > 0 {
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
if backoff > 30*time.Second {
backoff = 30 * time.Second
}
time.Sleep(backoff)
}
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
tokenResp, err := client.RefreshToken(ctx, refreshToken)
if err == nil {
now := time.Now()
expiresAt := now.Unix() + tokenResp.ExpiresIn - 300
fmt.Printf("[AntigravityOAuth] Token refreshed: expires_in=%d, expires_at=%d (%s)\n",
tokenResp.ExpiresIn, expiresAt, time.Unix(expiresAt, 0).Format("2006-01-02 15:04:05"))
return &AntigravityTokenInfo{
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
ExpiresIn: tokenResp.ExpiresIn,
ExpiresAt: expiresAt,
TokenType: tokenResp.TokenType,
}, nil
}
if isNonRetryableAntigravityOAuthError(err) {
return nil, err
}
lastErr = err
}
return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
}
// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id
func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64) (*AntigravityTokenInfo, error) {
var proxyURL string
if proxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
// 刷新 token
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
if err != nil {
return nil, err
}
// 获取用户信息email
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken)
if err != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
} else {
tokenInfo.Email = userInfo.Email
}
// 获取 project_id容错失败不阻塞
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if loadErr != nil {
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
tokenInfo.ProjectIDMissing = true
} else {
tokenInfo.ProjectID = projectID
}
return tokenInfo, nil
}
func isNonRetryableAntigravityOAuthError(err error) bool {
msg := err.Error()
nonRetryable := []string{
"invalid_grant",
"invalid_client",
"unauthorized_client",
"access_denied",
}
for _, needle := range nonRetryable {
if strings.Contains(msg, needle) {
return true
}
}
return false
}
// RefreshAccountToken 刷新账户的 token
func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*AntigravityTokenInfo, error) {
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
return nil, fmt.Errorf("非 Antigravity OAuth 账户")
}
refreshToken := account.GetCredential("refresh_token")
if strings.TrimSpace(refreshToken) == "" {
return nil, fmt.Errorf("无可用的 refresh_token")
}
var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
if err != nil {
return nil, err
}
// 保留原有的 email
existingEmail := strings.TrimSpace(account.GetCredential("email"))
if existingEmail != "" {
tokenInfo.Email = existingEmail
}
// 每次刷新都调用 LoadCodeAssist 获取 project_id失败时重试
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
if loadErr != nil {
// LoadCodeAssist 失败,保留原有 project_id
tokenInfo.ProjectID = existingProjectID
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
// 如果之前有 project_id本次只是临时故障不应标记为错误
if existingProjectID == "" {
tokenInfo.ProjectIDMissing = true
}
} else {
tokenInfo.ProjectID = projectID
}
return tokenInfo, nil
}
// loadProjectIDWithRetry 带重试机制获取 project_id
// 返回 project_id 和错误,失败时会重试指定次数
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
var lastErr error
for attempt := 0; attempt <= maxRetries; attempt++ {
if attempt > 0 {
// 指数退避1s, 2s, 4s
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
if backoff > 8*time.Second {
backoff = 8 * time.Second
}
time.Sleep(backoff)
}
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return "", fmt.Errorf("create antigravity client failed: %w", err)
}
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
return loadResp.CloudAICompanionProject, nil
}
if err == nil {
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
return projectID, nil
} else if onboardErr != nil {
lastErr = onboardErr
continue
}
}
// 记录错误
if err != nil {
lastErr = err
} else if loadResp == nil {
lastErr = fmt.Errorf("LoadCodeAssist 返回空响应")
} else {
lastErr = fmt.Errorf("LoadCodeAssist 返回空 project_id")
}
}
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
}
func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
tierID := resolveDefaultTierID(loadRaw)
if tierID == "" {
return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier")
}
projectID, err := client.OnboardUser(ctx, accessToken, tierID)
if err != nil {
return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err)
}
return projectID, nil
}
func resolveDefaultTierID(loadRaw map[string]any) string {
if len(loadRaw) == 0 {
return ""
}
rawTiers, ok := loadRaw["allowedTiers"]
if !ok {
return ""
}
tiers, ok := rawTiers.([]any)
if !ok {
return ""
}
for _, rawTier := range tiers {
tier, ok := rawTier.(map[string]any)
if !ok {
continue
}
if isDefault, _ := tier["isDefault"].(bool); !isDefault {
continue
}
if id, ok := tier["id"].(string); ok {
id = strings.TrimSpace(id)
if id != "" {
return id
}
}
}
return ""
}
// FillProjectID 仅获取 project_id不刷新 OAuth token
func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Account, accessToken string) (string, error) {
var proxyURL string
if account.ProxyID != nil {
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
if err == nil && proxy != nil {
proxyURL = proxy.URL()
}
}
return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3)
}
// BuildAccountCredentials 构建账户凭证
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
creds := map[string]any{
"access_token": tokenInfo.AccessToken,
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
}
if tokenInfo.RefreshToken != "" {
creds["refresh_token"] = tokenInfo.RefreshToken
}
if tokenInfo.TokenType != "" {
creds["token_type"] = tokenInfo.TokenType
}
if tokenInfo.Email != "" {
creds["email"] = tokenInfo.Email
}
if tokenInfo.ProjectID != "" {
creds["project_id"] = tokenInfo.ProjectID
}
return creds
}
// Stop 停止服务
func (s *AntigravityOAuthService) Stop() {
s.sessionStore.Stop()
}

View File

@@ -0,0 +1,82 @@
package service
import (
"testing"
)
func TestResolveDefaultTierID(t *testing.T) {
t.Parallel()
tests := []struct {
name string
loadRaw map[string]any
want string
}{
{
name: "nil loadRaw",
loadRaw: nil,
want: "",
},
{
name: "missing allowedTiers",
loadRaw: map[string]any{
"paidTier": map[string]any{"id": "g1-pro-tier"},
},
want: "",
},
{
name: "empty allowedTiers",
loadRaw: map[string]any{"allowedTiers": []any{}},
want: "",
},
{
name: "tier missing id field",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"isDefault": true},
},
},
want: "",
},
{
name: "allowedTiers but no default",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": "free-tier", "isDefault": false},
map[string]any{"id": "standard-tier", "isDefault": false},
},
},
want: "",
},
{
name: "default tier found",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": "free-tier", "isDefault": true},
map[string]any{"id": "standard-tier", "isDefault": false},
},
},
want: "free-tier",
},
{
name: "default tier id with spaces",
loadRaw: map[string]any{
"allowedTiers": []any{
map[string]any{"id": " standard-tier ", "isDefault": true},
},
},
want: "standard-tier",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := resolveDefaultTierID(tc.loadRaw)
if got != tc.want {
t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want)
}
})
}
}

View File

@@ -0,0 +1,272 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"log/slog"
"regexp"
"strings"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
const (
forbiddenTypeValidation = "validation"
forbiddenTypeViolation = "violation"
forbiddenTypeForbidden = "forbidden"
// 机器可读的错误码
errorCodeForbidden = "forbidden"
errorCodeUnauthenticated = "unauthenticated"
errorCodeRateLimited = "rate_limited"
errorCodeNetworkError = "network_error"
)
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
type AntigravityQuotaFetcher struct {
proxyRepo ProxyRepository
}
// NewAntigravityQuotaFetcher 创建 AntigravityQuotaFetcher
func NewAntigravityQuotaFetcher(proxyRepo ProxyRepository) *AntigravityQuotaFetcher {
return &AntigravityQuotaFetcher{proxyRepo: proxyRepo}
}
// CanFetch 检查是否可以获取此账户的额度
func (f *AntigravityQuotaFetcher) CanFetch(account *Account) bool {
if account.Platform != PlatformAntigravity {
return false
}
accessToken := account.GetCredential("access_token")
return accessToken != ""
}
// FetchQuota 获取 Antigravity 账户额度信息
func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error) {
accessToken := account.GetCredential("access_token")
projectID := account.GetCredential("project_id")
client, err := antigravity.NewClient(proxyURL)
if err != nil {
return nil, fmt.Errorf("create antigravity client failed: %w", err)
}
// 调用 API 获取配额
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
if err != nil {
// 403 Forbidden: 不报错,返回 is_forbidden 标记
var forbiddenErr *antigravity.ForbiddenError
if errors.As(err, &forbiddenErr) {
now := time.Now()
fbType := classifyForbiddenType(forbiddenErr.Body)
return &QuotaResult{
UsageInfo: &UsageInfo{
UpdatedAt: &now,
IsForbidden: true,
ForbiddenReason: forbiddenErr.Body,
ForbiddenType: fbType,
ValidationURL: extractValidationURL(forbiddenErr.Body),
NeedsVerify: fbType == forbiddenTypeValidation,
IsBanned: fbType == forbiddenTypeViolation,
ErrorCode: errorCodeForbidden,
},
}, nil
}
return nil, err
}
// 调用 LoadCodeAssist 获取订阅等级和 AI Credits 余额(非关键路径,失败不影响主流程)
tierRaw, tierNormalized, loadResp := f.fetchSubscriptionTier(ctx, client, accessToken)
// 转换为 UsageInfo
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized, loadResp)
return &QuotaResult{
UsageInfo: usageInfo,
Raw: modelsRaw,
}, nil
}
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串。
// 同时返回 LoadCodeAssistResponse以便提取 AI Credits 余额。
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string, loadResp *antigravity.LoadCodeAssistResponse) {
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
if err != nil {
slog.Warn("failed to fetch subscription tier", "error", err)
return "", "", nil
}
if loadResp == nil {
return "", "", nil
}
raw = loadResp.GetTier() // 已有方法paidTier > currentTier
normalized = normalizeTier(raw)
return raw, normalized, loadResp
}
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
func normalizeTier(raw string) string {
if raw == "" {
return ""
}
lower := strings.ToLower(raw)
switch {
case strings.Contains(lower, "ultra"):
return "ULTRA"
case strings.Contains(lower, "pro"):
return "PRO"
case strings.Contains(lower, "free"):
return "FREE"
default:
return "UNKNOWN"
}
}
// buildUsageInfo 将 API 响应转换为 UsageInfo。
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string, loadResp *antigravity.LoadCodeAssistResponse) *UsageInfo {
now := time.Now()
info := &UsageInfo{
UpdatedAt: &now,
AntigravityQuota: make(map[string]*AntigravityModelQuota),
AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail),
SubscriptionTier: tierNormalized,
SubscriptionTierRaw: tierRaw,
}
// 遍历所有模型,填充 AntigravityQuota 和 AntigravityQuotaDetails
for modelName, modelInfo := range modelsResp.Models {
if modelInfo.QuotaInfo == nil {
continue
}
// remainingFraction 是剩余比例 (0.0-1.0),转换为使用率百分比
utilization := int((1.0 - modelInfo.QuotaInfo.RemainingFraction) * 100)
info.AntigravityQuota[modelName] = &AntigravityModelQuota{
Utilization: utilization,
ResetTime: modelInfo.QuotaInfo.ResetTime,
}
// 填充模型详细能力信息
detail := &AntigravityModelDetail{
DisplayName: modelInfo.DisplayName,
SupportsImages: modelInfo.SupportsImages,
SupportsThinking: modelInfo.SupportsThinking,
ThinkingBudget: modelInfo.ThinkingBudget,
Recommended: modelInfo.Recommended,
MaxTokens: modelInfo.MaxTokens,
MaxOutputTokens: modelInfo.MaxOutputTokens,
SupportedMimeTypes: modelInfo.SupportedMimeTypes,
}
info.AntigravityQuotaDetails[modelName] = detail
}
// 废弃模型转发规则
if len(modelsResp.DeprecatedModelIDs) > 0 {
info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs))
for oldID, deprecated := range modelsResp.DeprecatedModelIDs {
info.ModelForwardingRules[oldID] = deprecated.NewModelID
}
}
// 同时设置 FiveHour 用于兼容展示(取主要模型)
priorityModels := []string{"claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"}
for _, modelName := range priorityModels {
if modelInfo, ok := modelsResp.Models[modelName]; ok && modelInfo.QuotaInfo != nil {
utilization := (1.0 - modelInfo.QuotaInfo.RemainingFraction) * 100
progress := &UsageProgress{
Utilization: utilization,
}
if modelInfo.QuotaInfo.ResetTime != "" {
if resetTime, err := time.Parse(time.RFC3339, modelInfo.QuotaInfo.ResetTime); err == nil {
progress.ResetsAt = &resetTime
progress.RemainingSeconds = int(time.Until(resetTime).Seconds())
}
}
info.FiveHour = progress
break
}
}
if loadResp != nil {
for _, credit := range loadResp.GetAvailableCredits() {
info.AICredits = append(info.AICredits, AICredit{
CreditType: credit.CreditType,
Amount: credit.GetAmount(),
MinimumBalance: credit.GetMinimumAmount(),
})
}
}
return info
}
// GetProxyURL 获取账户的代理 URL
func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Account) string {
if account.ProxyID == nil || f.proxyRepo == nil {
return ""
}
proxy, err := f.proxyRepo.GetByID(ctx, *account.ProxyID)
if err != nil || proxy == nil {
return ""
}
return proxy.URL()
}
// classifyForbiddenType 根据 403 响应体判断禁止类型
func classifyForbiddenType(body string) string {
lower := strings.ToLower(body)
switch {
case strings.Contains(lower, "validation_required") ||
strings.Contains(lower, "verify your account") ||
strings.Contains(lower, "validation_url"):
return forbiddenTypeValidation
case strings.Contains(lower, "terms of service") ||
strings.Contains(lower, "violation"):
return forbiddenTypeViolation
default:
return forbiddenTypeForbidden
}
}
// urlPattern 用于从 403 响应体中提取 URL降级方案
var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`)
// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接
func extractValidationURL(body string) string {
// 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url
var parsed struct {
Error struct {
Details []struct {
Metadata map[string]string `json:"metadata"`
} `json:"details"`
} `json:"error"`
}
if json.Unmarshal([]byte(body), &parsed) == nil {
for _, detail := range parsed.Error.Details {
if u := detail.Metadata["validation_url"]; u != "" {
return u
}
if u := detail.Metadata["appeal_url"]; u != "" {
return u
}
}
}
// 2. 降级:正则匹配 URL
lower := strings.ToLower(body)
if !strings.Contains(lower, "validation") &&
!strings.Contains(lower, "verify") &&
!strings.Contains(lower, "appeal") {
return ""
}
// 先解码常见转义再匹配
normalized := strings.ReplaceAll(body, `\u0026`, "&")
if m := urlPattern.FindString(normalized); m != "" {
return m
}
return ""
}

View File

@@ -0,0 +1,522 @@
//go:build unit
package service
import (
"errors"
"testing"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
)
// ---------------------------------------------------------------------------
// normalizeTier
// ---------------------------------------------------------------------------
func TestNormalizeTier(t *testing.T) {
tests := []struct {
name string
raw string
expected string
}{
{name: "empty string", raw: "", expected: ""},
{name: "free-tier", raw: "free-tier", expected: "FREE"},
{name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"},
{name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"},
{name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"},
{name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"},
{name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"},
{name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"},
{name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := normalizeTier(tt.raw)
require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw)
})
}
}
// ---------------------------------------------------------------------------
// buildUsageInfo
// ---------------------------------------------------------------------------
func aqfBoolPtr(v bool) *bool { return &v }
func aqfIntPtr(v int) *int { return &v }
func TestBuildUsageInfo_BasicModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.75,
ResetTime: "2026-03-08T12:00:00Z",
},
DisplayName: "Claude Sonnet 4",
SupportsImages: aqfBoolPtr(true),
SupportsThinking: aqfBoolPtr(false),
ThinkingBudget: aqfIntPtr(0),
Recommended: aqfBoolPtr(true),
MaxTokens: aqfIntPtr(200000),
MaxOutputTokens: aqfIntPtr(16384),
SupportedMimeTypes: map[string]bool{
"image/png": true,
"image/jpeg": true,
},
},
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.50,
ResetTime: "2026-03-08T15:00:00Z",
},
DisplayName: "Gemini 2.5 Pro",
MaxTokens: aqfIntPtr(1000000),
MaxOutputTokens: aqfIntPtr(65536),
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", nil)
// 基本字段
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
require.Equal(t, "PRO", info.SubscriptionTier)
require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw)
// AntigravityQuota
require.Len(t, info.AntigravityQuota, 2)
sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, sonnetQuota)
require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25
require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime)
geminiQuota := info.AntigravityQuota["gemini-2.5-pro"]
require.NotNil(t, geminiQuota)
require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50
require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime)
// AntigravityQuotaDetails
require.Len(t, info.AntigravityQuotaDetails, 2)
sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"]
require.NotNil(t, sonnetDetail)
require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName)
require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages)
require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking)
require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget)
require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended)
require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens)
require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens)
require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes)
geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"]
require.NotNil(t, geminiDetail)
require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName)
require.Nil(t, geminiDetail.SupportsImages)
require.Nil(t, geminiDetail.SupportsThinking)
require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens)
require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens)
}
func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 1.0,
},
},
},
DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{
"claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"},
"claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.Len(t, info.ModelForwardingRules, 2)
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"])
}
func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"some-model": {
QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
}
func TestBuildUsageInfo_EmptyModels(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info)
require.NotNil(t, info.AntigravityQuota)
require.Empty(t, info.AntigravityQuota)
require.NotNil(t, info.AntigravityQuotaDetails)
require.Empty(t, info.AntigravityQuotaDetails)
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
}
func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"model-without-quota": {
DisplayName: "No Quota Model",
// QuotaInfo is nil
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info)
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too")
}
func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"]
// When the first priority model exists, it should be used for FiveHour
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.40,
ResetTime: "2026-03-08T18:00:00Z",
},
},
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.80,
ResetTime: "2026-03-08T12:00:00Z",
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
// claude-sonnet-4-20250514 is first in priority list, so it should be used
expectedUtilization := (1.0 - 0.80) * 100 // 20
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime")
}
func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.60,
ResetTime: "2026-03-08T14:00:00Z",
},
},
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.30,
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info.FiveHour)
expectedUtilization := (1.0 - 0.60) * 100 // 40
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
}
func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// Only gemini-2.5-pro exists (third in priority list)
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"gemini-2.5-pro": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.30,
},
},
"other-model": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.90,
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info.FiveHour)
expectedUtilization := (1.0 - 0.30) * 100 // 70
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
}
func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
// None of the priority models exist
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"some-other-model": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.50,
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
}
func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.50,
ResetTime: "", // empty reset time
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
require.NotNil(t, info.FiveHour)
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
require.Equal(t, 0, info.FiveHour.RemainingSeconds)
}
func TestBuildUsageInfo_FullUtilization(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 0.0, // fully used
ResetTime: "2026-03-08T12:00:00Z",
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, quota)
require.Equal(t, 100, quota.Utilization)
}
func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{
"claude-sonnet-4-20250514": {
QuotaInfo: &antigravity.ModelQuotaInfo{
RemainingFraction: 1.0, // fully available
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
require.NotNil(t, quota)
require.Equal(t, 0, quota.Utilization)
}
func TestBuildUsageInfo_AICredits(t *testing.T) {
fetcher := &AntigravityQuotaFetcher{}
modelsResp := &antigravity.FetchAvailableModelsResponse{
Models: map[string]antigravity.ModelInfo{},
}
loadResp := &antigravity.LoadCodeAssistResponse{
PaidTier: &antigravity.PaidTierInfo{
ID: "g1-pro-tier",
AvailableCredits: []antigravity.AvailableCredit{
{
CreditType: "GOOGLE_ONE_AI",
CreditAmount: "25",
MinimumCreditAmountForUsage: "5",
},
},
},
}
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", loadResp)
require.Len(t, info.AICredits, 1)
require.Equal(t, "GOOGLE_ONE_AI", info.AICredits[0].CreditType)
require.Equal(t, 25.0, info.AICredits[0].Amount)
require.Equal(t, 5.0, info.AICredits[0].MinimumBalance)
}
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
// 模拟 FetchQuota 遇到 403 时的行为:
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
forbiddenErr := &antigravity.ForbiddenError{
StatusCode: 403,
Body: "Access denied",
}
// 验证 ForbiddenError 满足 errors.As
var target *antigravity.ForbiddenError
require.True(t, errors.As(forbiddenErr, &target))
require.Equal(t, 403, target.StatusCode)
require.Equal(t, "Access denied", target.Body)
require.Contains(t, forbiddenErr.Error(), "403")
}
// ---------------------------------------------------------------------------
// classifyForbiddenType
// ---------------------------------------------------------------------------
func TestClassifyForbiddenType(t *testing.T) {
tests := []struct {
name string
body string
expected string
}{
{
name: "VALIDATION_REQUIRED keyword",
body: `{"error":{"message":"VALIDATION_REQUIRED"}}`,
expected: "validation",
},
{
name: "verify your account",
body: `Please verify your account to continue`,
expected: "validation",
},
{
name: "contains validation_url field",
body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`,
expected: "validation",
},
{
name: "terms of service violation",
body: `Your account has been suspended for Terms of Service violation`,
expected: "violation",
},
{
name: "violation keyword",
body: `Account suspended due to policy violation`,
expected: "violation",
},
{
name: "generic 403",
body: `Access denied`,
expected: "forbidden",
},
{
name: "empty body",
body: "",
expected: "forbidden",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := classifyForbiddenType(tt.body)
require.Equal(t, tt.expected, got)
})
}
}
// ---------------------------------------------------------------------------
// extractValidationURL
// ---------------------------------------------------------------------------
func TestExtractValidationURL(t *testing.T) {
tests := []struct {
name string
body string
expected string
}{
{
name: "structured validation_url",
body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`,
expected: "https://accounts.google.com/verify?token=abc",
},
{
name: "structured appeal_url",
body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`,
expected: "https://support.google.com/appeal/123",
},
{
name: "validation_url takes priority over appeal_url",
body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`,
expected: "https://v.com",
},
{
name: "fallback regex with verify keyword",
body: `Please verify your account at https://accounts.google.com/verify`,
expected: "https://accounts.google.com/verify",
},
{
name: "no URL in generic forbidden",
body: `Access denied`,
expected: "",
},
{
name: "empty body",
body: "",
expected: "",
},
{
name: "URL present but no validation keywords",
body: `Error at https://example.com/something`,
expected: "",
},
{
name: "unicode escaped ampersand",
body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`,
expected: "https://accounts.google.com/verify?a=1&b=2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := extractValidationURL(tt.body)
require.Equal(t, tt.expected, got)
})
}
}

View File

@@ -0,0 +1,57 @@
package service
import (
"context"
"strings"
"time"
)
func normalizeAntigravityModelName(model string) string {
normalized := strings.ToLower(strings.TrimSpace(model))
normalized = strings.TrimPrefix(normalized, "models/")
return normalized
}
// resolveAntigravityModelKey 根据请求的模型名解析限流 key
// 返回空字符串表示无法解析
func resolveAntigravityModelKey(requestedModel string) string {
return normalizeAntigravityModelName(requestedModel)
}
// IsSchedulableForModel 结合模型级限流判断是否可调度。
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
}
func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requestedModel string) bool {
if a == nil {
return false
}
if !a.IsSchedulable() {
return false
}
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
// Antigravity + overages 启用 + 积分未耗尽 → 放行(有积分可用)
if a.Platform == PlatformAntigravity && a.IsOveragesEnabled() && !a.isCreditsExhausted() {
return true
}
return false
}
return true
}
// GetRateLimitRemainingTime 获取限流剩余时间(模型级限流)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
}
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型级限流)
// 返回 0 表示未限流或已过期
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
if a == nil {
return 0
}
return a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
}

View File

@@ -0,0 +1,904 @@
//go:build unit
package service
import (
"bytes"
"context"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
// ---------------------------------------------------------------------------
// 辅助函数:构造带 SingleAccountRetry 标记的 context
// ---------------------------------------------------------------------------
func ctxWithSingleAccountRetry() context.Context {
return context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true)
}
// ---------------------------------------------------------------------------
// 1. isSingleAccountRetry 测试
// ---------------------------------------------------------------------------
func TestIsSingleAccountRetry_True(t *testing.T) {
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true)
require.True(t, isSingleAccountRetry(ctx))
}
func TestIsSingleAccountRetry_False_NoValue(t *testing.T) {
require.False(t, isSingleAccountRetry(context.Background()))
}
func TestIsSingleAccountRetry_False_ExplicitFalse(t *testing.T) {
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, false)
require.False(t, isSingleAccountRetry(ctx))
}
func TestIsSingleAccountRetry_False_WrongType(t *testing.T) {
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, "true")
require.False(t, isSingleAccountRetry(ctx))
}
// ---------------------------------------------------------------------------
// 2. 常量验证
// ---------------------------------------------------------------------------
func TestSingleAccountRetryConstants(t *testing.T) {
require.Equal(t, 3, antigravitySingleAccountSmartRetryMaxAttempts,
"单账号原地重试最多 3 次")
require.Equal(t, 15*time.Second, antigravitySingleAccountSmartRetryMaxWait,
"单次最大等待 15s")
require.Equal(t, 30*time.Second, antigravitySingleAccountSmartRetryTotalMaxWait,
"总累计等待不超过 30s")
}
// ---------------------------------------------------------------------------
// 3. handleSmartRetry + 503 + SingleAccountRetry → 走 handleSingleAccountRetryInPlace
// (而非设模型限流 + 切换账号)
// ---------------------------------------------------------------------------
// TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace
// 核心场景503 + retryDelay >= 7s + SingleAccountRetry 标记
// → 不设模型限流、不切换账号,改为原地重试
func TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace(t *testing.T) {
// 原地重试成功
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 1,
Name: "acc-single",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
// 503 + 39s >= 7s 阈值 + MODEL_CAPACITY_EXHAUSTED
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
],
"message": "No capacity available for model gemini-3-pro-high on the server"
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(), // 关键:设置单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 关键断言:返回 resp原地重试成功而非 switchError切换账号
require.NotNil(t, result.resp, "should return successful response from in-place retry")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError, "should NOT return switchError in single account mode")
require.Nil(t, result.err)
// 验证未设模型限流(单账号模式不应设限流)
require.Len(t, repo.modelRateLimitCalls, 0,
"should NOT set model rate limit in single account retry mode")
// 验证确实调用了 upstream原地重试
require.GreaterOrEqual(t, len(upstream.calls), 1, "should have made at least one retry call")
}
// TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches
// 对照组503 + retryDelay >= 7s + 无 SingleAccountRetry 标记
// → 照常设模型限流 + 切换账号
func TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 2,
Name: "acc-multi",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 503 + 39s >= 7s 阈值(使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED
// 因为 MODEL_CAPACITY_EXHAUSTED 走独立的重试路径,不触发 shouldRateLimitModel
respBody := []byte(`{
"error": {
"code": 503,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(), // 关键:无单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 对照:多账号模式返回 switchError
require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503")
require.Nil(t, result.resp, "should not return resp when switchError is set")
// 对照:多账号模式应设模型限流
require.Len(t, repo.modelRateLimitCalls, 1,
"multi-account mode SHOULD set model rate limit")
}
// TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches
// 边界情况429非 503+ SingleAccountRetry 标记
// → 单账号原地重试仅针对 503429 依然走切换账号逻辑
func TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches(t *testing.T) {
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 3,
Name: "acc-429",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 429 + 15s >= 7s 阈值
respBody := []byte(`{
"error": {
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusTooManyRequests, // 429不是 503
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(), // 有单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 429 即使有单账号标记,也应走切换账号
require.NotNil(t, result.switchError, "429 should still return switchError even with SingleAccountRetry")
require.Len(t, repo.modelRateLimitCalls, 1,
"429 should still set model rate limit even with SingleAccountRetry")
}
// ---------------------------------------------------------------------------
// 4. handleSmartRetry + 503 + 短延迟 + SingleAccountRetry → 智能重试耗尽后不设限流
// ---------------------------------------------------------------------------
// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit
// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503不设限流
func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) {
// 智能重试也返回 503
failRespBody := `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 4,
Name: "acc-short-503",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
// 0.1s < 7s 阈值
respBody := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 关键断言:单账号 503 模式下,智能重试耗尽后直接返回 503 响应,不切换
require.NotNil(t, result.resp, "should return 503 response directly for single account mode")
require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode)
require.Nil(t, result.switchError, "should NOT switch account in single account mode")
// 关键断言:不设模型限流
require.Len(t, repo.modelRateLimitCalls, 0,
"should NOT set model rate limit for 503 in single account mode")
}
// TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit
// 对照组503 + retryDelay < 7s + 无 SingleAccountRetry → 智能重试耗尽后照常设限流
// 使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED因为后者走独立的 60 次重试路径
func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *testing.T) {
failRespBody := `{
"error": {
"code": 503,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
failResp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(failRespBody)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{failResp},
errors: []error{nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 5,
Name: "acc-multi-503",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
}
respBody := []byte(`{
"error": {
"code": 503,
"status": "RESOURCE_EXHAUSTED",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(respBody)),
}
params := antigravityRetryLoopParams{
ctx: context.Background(), // 无单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
}
availableURLs := []string{"https://ag-1.test"}
svc := &AntigravityGatewayService{}
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 对照:多账号模式应返回 switchError
require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503")
// 对照:多账号模式应设模型限流
require.Len(t, repo.modelRateLimitCalls, 1,
"multi-account mode should set model rate limit")
}
// ---------------------------------------------------------------------------
// 5. handleSingleAccountRetryInPlace 直接测试
// ---------------------------------------------------------------------------
// TestHandleSingleAccountRetryInPlace_Success 原地重试成功
func TestHandleSingleAccountRetryInPlace_Success(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
account := &Account{
ID: 10,
Name: "acc-inplace-ok",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError, "should not switch account on success")
require.Nil(t, result.err)
}
// TestHandleSingleAccountRetryInPlace_AllRetriesFail 所有重试都失败,返回 503不设限流
func TestHandleSingleAccountRetryInPlace_AllRetriesFail(t *testing.T) {
// 构造 3 个 503 响应(对应 3 次原地重试)
var responses []*http.Response
var errors []error
for i := 0; i < antigravitySingleAccountSmartRetryMaxAttempts; i++ {
responses = append(responses, &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`)),
})
errors = append(errors, nil)
}
upstream := &mockSmartRetryUpstream{
responses: responses,
errors: errors,
}
account := &Account{
ID: 11,
Name: "acc-inplace-fail",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
origBody := []byte(`{"error":{"code":503,"status":"UNAVAILABLE"}}`)
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{"X-Test": {"original"}},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, origBody, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
// 关键:返回 503 resp不返回 switchError
require.NotNil(t, result.resp, "should return 503 response directly")
require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode)
require.Nil(t, result.switchError, "should NOT return switchError - let Handler handle it")
require.Nil(t, result.err)
// 验证确实重试了指定次数
require.Len(t, upstream.calls, antigravitySingleAccountSmartRetryMaxAttempts,
"should have made exactly maxAttempts retry calls")
}
// TestHandleSingleAccountRetryInPlace_WaitDurationClamped 等待时间被限制在 [min, max] 范围
func TestHandleSingleAccountRetryInPlace_WaitDurationClamped(t *testing.T) {
// 用短延迟的成功响应,只验证不 panic
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{successResp},
errors: []error{nil},
}
account := &Account{
ID: 12,
Name: "acc-clamp",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
// 等待时间过大应被 clamp 到 antigravitySingleAccountSmartRetryMaxWait
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 999*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp)
require.Equal(t, http.StatusOK, result.resp.StatusCode)
}
// TestHandleSingleAccountRetryInPlace_ContextCanceled context 取消时立即返回
func TestHandleSingleAccountRetryInPlace_ContextCanceled(t *testing.T) {
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{nil},
errors: []error{nil},
}
account := &Account{
ID: 13,
Name: "acc-cancel",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
ctx, cancel := context.WithCancel(context.Background())
ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true)
cancel() // 立即取消
params := antigravityRetryLoopParams{
ctx: ctx,
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.Error(t, result.err, "should return context error")
// 不应调用 upstream因为在等待阶段就被取消了
require.Len(t, upstream.calls, 0, "should not call upstream when context is canceled")
}
// TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry 网络错误时继续重试
func TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry(t *testing.T) {
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
// 第1次网络错误nil resp第2次成功
responses: []*http.Response{nil, successResp},
errors: []error{nil, nil},
}
account := &Account{
ID: 14,
Name: "acc-net-retry",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response after network error recovery")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Len(t, upstream.calls, 2, "first call fails (network error), second succeeds")
}
// ---------------------------------------------------------------------------
// 6. antigravityRetryLoop 预检查:单账号模式跳过限流
// ---------------------------------------------------------------------------
// TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit
// 预检查中,如果有 SingleAccountRetry 标记,即使账号已限流也跳过直接发请求
func TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit(t *testing.T) {
// 创建一个已设模型限流的账号
upstream := &recordingOKUpstream{}
account := &Account{
ID: 20,
Name: "acc-rate-limited",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.NoError(t, err, "should not return error")
require.NotNil(t, result, "should return result")
require.NotNil(t, result.resp, "should have response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
// 关键:尽管限流了,有 SingleAccountRetry 标记时仍然到达了 upstream
require.Equal(t, 1, upstream.calls, "should have reached upstream despite rate limit")
}
// TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit
// 对照组:无 SingleAccountRetry + 已限流 → 预检查返回 switchError
func TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit(t *testing.T) {
upstream := &recordingOKUpstream{}
account := &Account{
ID: 21,
Name: "acc-rate-limited-multi",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
Extra: map[string]any{
modelRateLimitsKey: map[string]any{
"claude-sonnet-4-5": map[string]any{
"rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339),
},
},
},
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: context.Background(), // 无单账号标记
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
requestedModel: "claude-sonnet-4-5",
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.Nil(t, result, "should not return result on rate limit switch")
require.NotNil(t, err, "should return error")
var switchErr *AntigravityAccountSwitchError
require.ErrorAs(t, err, &switchErr, "should return AntigravityAccountSwitchError")
require.Equal(t, account.ID, switchErr.OriginalAccountID)
require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel)
// upstream 不应被调用(预检查就短路了)
require.Equal(t, 0, upstream.calls, "upstream should NOT be called when pre-check blocks")
}
// ---------------------------------------------------------------------------
// 7. 端到端集成场景测试
// ---------------------------------------------------------------------------
// TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E
// 端到端场景503 + 单账号 + 原地重试第2次成功
func TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E(t *testing.T) {
// 第1次原地重试仍返回 503第2次成功
fail503Body := `{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
]
}
}`
resp503 := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(fail503Body)),
}
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
responses: []*http.Response{resp503, successResp},
errors: []error{nil, nil},
}
account := &Account{
ID: 30,
Name: "acc-e2e",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Concurrency: 1,
}
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
}
params := antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
}
svc := &AntigravityGatewayService{}
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
require.NotNil(t, result)
require.Equal(t, smartRetryActionBreakWithResp, result.action)
require.NotNil(t, result.resp, "should return successful response after 2nd attempt")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
require.Nil(t, result.switchError)
require.Len(t, upstream.calls, 2, "first 503, second OK")
}
// TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E
// 通过 antigravityRetryLoop → handleSmartRetry → handleSingleAccountRetryInPlace 完整链路
func TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E(t *testing.T) {
// 初始请求返回 503 + 长延迟
initial503Body := []byte(`{
"error": {
"code": 503,
"status": "UNAVAILABLE",
"details": [
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "10s"}
],
"message": "No capacity available"
}
}`)
initial503Resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(initial503Body)),
}
// 原地重试成功
successResp := &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{},
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
}
upstream := &mockSmartRetryUpstream{
// 第1次调用retryLoop 主循环)返回 503
// 第2次调用handleSingleAccountRetryInPlace 原地重试)返回 200
responses: []*http.Response{initial503Resp, successResp},
errors: []error{nil, nil},
}
repo := &stubAntigravityAccountRepo{}
account := &Account{
ID: 31,
Name: "acc-e2e-loop",
Type: AccountTypeOAuth,
Platform: PlatformAntigravity,
Schedulable: true,
Status: StatusActive,
Concurrency: 1,
}
svc := &AntigravityGatewayService{}
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
ctx: ctxWithSingleAccountRetry(),
prefix: "[test]",
account: account,
accessToken: "token",
action: "generateContent",
body: []byte(`{"input":"test"}`),
httpUpstream: upstream,
accountRepo: repo,
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
return nil
},
})
require.NoError(t, err, "should not return error on successful retry")
require.NotNil(t, result, "should return result")
require.NotNil(t, result.resp, "should return response")
require.Equal(t, http.StatusOK, result.resp.StatusCode)
// 验证未设模型限流
require.Len(t, repo.modelRateLimitCalls, 0,
"should NOT set model rate limit in single account retry mode")
}

View File

@@ -0,0 +1,68 @@
//go:build unit
package service
import (
"testing"
)
func TestApplyThinkingModelSuffix(t *testing.T) {
tests := []struct {
name string
mappedModel string
thinkingEnabled bool
expected string
}{
// Thinking 未开启:保持原样
{
name: "thinking disabled - claude-sonnet-4-5 unchanged",
mappedModel: "claude-sonnet-4-5",
thinkingEnabled: false,
expected: "claude-sonnet-4-5",
},
{
name: "thinking disabled - other model unchanged",
mappedModel: "claude-opus-4-6-thinking",
thinkingEnabled: false,
expected: "claude-opus-4-6-thinking",
},
// Thinking 开启 + claude-sonnet-4-5自动添加后缀
{
name: "thinking enabled - claude-sonnet-4-5 becomes thinking version",
mappedModel: "claude-sonnet-4-5",
thinkingEnabled: true,
expected: "claude-sonnet-4-5-thinking",
},
// Thinking 开启 + 其他模型:保持原样
{
name: "thinking enabled - claude-sonnet-4-5-thinking unchanged",
mappedModel: "claude-sonnet-4-5-thinking",
thinkingEnabled: true,
expected: "claude-sonnet-4-5-thinking",
},
{
name: "thinking enabled - claude-opus-4-6-thinking unchanged",
mappedModel: "claude-opus-4-6-thinking",
thinkingEnabled: true,
expected: "claude-opus-4-6-thinking",
},
{
name: "thinking enabled - gemini model unchanged",
mappedModel: "gemini-3-flash",
thinkingEnabled: true,
expected: "gemini-3-flash",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := applyThinkingModelSuffix(tt.mappedModel, tt.thinkingEnabled)
if result != tt.expected {
t.Errorf("applyThinkingModelSuffix(%q, %v) = %q, want %q",
tt.mappedModel, tt.thinkingEnabled, result, tt.expected)
}
})
}
}

View File

@@ -0,0 +1,185 @@
package service
import (
"context"
"errors"
"log/slog"
"strconv"
"strings"
"sync"
"time"
)
const (
antigravityTokenRefreshSkew = 3 * time.Minute
antigravityTokenCacheSkew = 5 * time.Minute
antigravityBackfillCooldown = 5 * time.Minute
)
// AntigravityTokenCache token cache interface.
type AntigravityTokenCache = GeminiTokenCache
// AntigravityTokenProvider manages access_token for antigravity accounts.
type AntigravityTokenProvider struct {
accountRepo AccountRepository
tokenCache AntigravityTokenCache
antigravityOAuthService *AntigravityOAuthService
backfillCooldown sync.Map // key: accountID -> last attempt time
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
refreshPolicy ProviderRefreshPolicy
}
func NewAntigravityTokenProvider(
accountRepo AccountRepository,
tokenCache AntigravityTokenCache,
antigravityOAuthService *AntigravityOAuthService,
) *AntigravityTokenProvider {
return &AntigravityTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
antigravityOAuthService: antigravityOAuthService,
refreshPolicy: AntigravityProviderRefreshPolicy(),
}
}
// SetRefreshAPI injects unified OAuth refresh API and executor.
func (p *AntigravityTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
p.refreshAPI = api
p.executor = executor
}
// SetRefreshPolicy injects caller-side refresh policy.
func (p *AntigravityTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
p.refreshPolicy = policy
}
// GetAccessToken returns a valid access_token.
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAntigravity {
return "", errors.New("not an antigravity account")
}
// upstream accounts use static api_key and never refresh oauth token.
if account.Type == AccountTypeUpstream {
apiKey := account.GetCredential("api_key")
if apiKey == "" {
return "", errors.New("upstream account missing api_key in credentials")
}
return apiKey, nil
}
if account.Type != AccountTypeOAuth {
return "", errors.New("not an antigravity oauth account")
}
cacheKey := AntigravityTokenCacheKey(account)
// 1) Try cache first.
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
// 2) Refresh if needed (pre-expiry skew).
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, antigravityTokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
return token, nil
}
}
// default policy: continue with existing token.
} else {
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
}
} else if needsRefresh && p.tokenCache != nil {
// Backward-compatible test path when refreshAPI is not injected.
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
}
}
accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// Backfill project_id online when missing, with cooldown to avoid hammering.
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
if p.shouldAttemptBackfill(account.ID) {
p.markBackfillAttempted(account.ID)
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
account.Credentials["project_id"] = projectID
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
slog.Warn("antigravity_project_id_backfill_persist_failed",
"account_id", account.ID,
"error", updateErr,
)
}
}
}
}
// 3) Populate cache with TTL.
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
} else {
ttl := 30 * time.Minute
if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > antigravityTokenCacheSkew:
ttl = until - antigravityTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
}
return accessToken, nil
}
// shouldAttemptBackfill checks backfill cooldown.
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
if v, ok := p.backfillCooldown.Load(accountID); ok {
if lastAttempt, ok := v.(time.Time); ok {
return time.Since(lastAttempt) > antigravityBackfillCooldown
}
}
return true
}
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
p.backfillCooldown.Store(accountID, time.Now())
}
func AntigravityTokenCacheKey(account *Account) string {
projectID := strings.TrimSpace(account.GetCredential("project_id"))
if projectID != "" {
return "ag:" + projectID
}
return "ag:account:" + strconv.FormatInt(account.ID, 10)
}

View File

@@ -0,0 +1,97 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestAntigravityTokenProvider_GetAccessToken_Upstream(t *testing.T) {
provider := &AntigravityTokenProvider{}
t.Run("upstream account with valid api_key", func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Type: AccountTypeUpstream,
Credentials: map[string]any{
"api_key": "sk-test-key-12345",
},
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "sk-test-key-12345", token)
})
t.Run("upstream account missing api_key", func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Type: AccountTypeUpstream,
Credentials: map[string]any{},
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "upstream account missing api_key")
require.Empty(t, token)
})
t.Run("upstream account with empty api_key", func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Type: AccountTypeUpstream,
Credentials: map[string]any{
"api_key": "",
},
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "upstream account missing api_key")
require.Empty(t, token)
})
t.Run("upstream account with nil credentials", func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Type: AccountTypeUpstream,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "upstream account missing api_key")
require.Empty(t, token)
})
}
func TestAntigravityTokenProvider_GetAccessToken_Guards(t *testing.T) {
provider := &AntigravityTokenProvider{}
t.Run("nil account", func(t *testing.T) {
token, err := provider.GetAccessToken(context.Background(), nil)
require.Error(t, err)
require.Contains(t, err.Error(), "account is nil")
require.Empty(t, token)
})
t.Run("non-antigravity platform", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an antigravity account")
require.Empty(t, token)
})
t.Run("unsupported account type", func(t *testing.T) {
account := &Account{
Platform: PlatformAntigravity,
Type: AccountTypeAPIKey,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an antigravity oauth account")
require.Empty(t, token)
})
}

View File

@@ -0,0 +1,90 @@
package service
import (
"context"
"fmt"
"log"
"strings"
"time"
)
const (
// antigravityRefreshWindow Antigravity token 提前刷新窗口15分钟
// Google OAuth token 有效期55分钟提前15分钟刷新
antigravityRefreshWindow = 15 * time.Minute
)
// AntigravityTokenRefresher 实现 TokenRefresher 接口
type AntigravityTokenRefresher struct {
antigravityOAuthService *AntigravityOAuthService
}
func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthService) *AntigravityTokenRefresher {
return &AntigravityTokenRefresher{
antigravityOAuthService: antigravityOAuthService,
}
}
// CacheKey 返回用于分布式锁的缓存键
func (r *AntigravityTokenRefresher) CacheKey(account *Account) string {
return AntigravityTokenCacheKey(account)
}
// CanRefresh 检查是否可以刷新此账户
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
}
// NeedsRefresh 检查账户是否需要刷新
// Antigravity 使用固定的15分钟刷新窗口忽略全局配置
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
if !r.CanRefresh(account) {
return false
}
expiresAt := account.GetCredentialAsTime("expires_at")
if expiresAt == nil {
return false
}
timeUntilExpiry := time.Until(*expiresAt)
needsRefresh := timeUntilExpiry < antigravityRefreshWindow
if needsRefresh {
fmt.Printf("[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v\n",
account.ID, expiresAt.Format("2006-01-02 15:04:05"), timeUntilExpiry, antigravityRefreshWindow)
}
return needsRefresh
}
// Refresh 执行 token 刷新
func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
tokenInfo, err := r.antigravityOAuthService.RefreshAccountToken(ctx, account)
if err != nil {
return nil, err
}
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
// 合并旧的 credentials保留新 credentials 中不存在的字段
newCredentials = MergeCredentials(account.Credentials, newCredentials)
// 特殊处理 project_id如果新值为空但旧值非空保留旧值
// 这确保了即使 LoadCodeAssist 失败project_id 也不会丢失
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
newCredentials["project_id"] = oldProjectID
}
}
// 如果 project_id 获取失败,只记录警告,不返回错误
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
// Token 刷新本身是成功的access_token 和 refresh_token 已更新)
if tokenInfo.ProjectIDMissing {
if tokenInfo.ProjectID != "" {
// 有旧的 project_id本次获取失败保留旧值
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id", account.ID)
} else {
// 从未获取过 project_id本次也失败但不返回错误以允许下次重试
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败project_id 缺失,但 token 已更新,将在下次刷新时重试", account.ID)
}
}
return newCredentials, nil
}

View File

@@ -0,0 +1,143 @@
package service
import (
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
)
// API Key status constants
const (
StatusAPIKeyActive = "active"
StatusAPIKeyDisabled = "disabled"
StatusAPIKeyQuotaExhausted = "quota_exhausted"
StatusAPIKeyExpired = "expired"
)
// Rate limit window durations
const (
RateLimitWindow5h = 5 * time.Hour
RateLimitWindow1d = 24 * time.Hour
RateLimitWindow7d = 7 * 24 * time.Hour
)
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
// A nil windowStart is treated as expired — no initialized window means any accumulated usage is stale.
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
return windowStart == nil || time.Since(*windowStart) >= duration
}
type APIKey struct {
ID int64
UserID int64
Key string
Name string
GroupID *int64
Status string
IPWhitelist []string
IPBlacklist []string
// 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。
CompiledIPWhitelist *ip.CompiledIPRules `json:"-"`
CompiledIPBlacklist *ip.CompiledIPRules `json:"-"`
LastUsedAt *time.Time
CreatedAt time.Time
UpdatedAt time.Time
User *User
Group *Group
// Quota fields
Quota float64 // Quota limit in USD (0 = unlimited)
QuotaUsed float64 // Used quota amount
ExpiresAt *time.Time // Expiration time (nil = never expires)
// Rate limit fields
RateLimit5h float64 // Rate limit in USD per 5h (0 = unlimited)
RateLimit1d float64 // Rate limit in USD per 1d (0 = unlimited)
RateLimit7d float64 // Rate limit in USD per 7d (0 = unlimited)
Usage5h float64 // Used amount in current 5h window
Usage1d float64 // Used amount in current 1d window
Usage7d float64 // Used amount in current 7d window
Window5hStart *time.Time // Start of current 5h window
Window1dStart *time.Time // Start of current 1d window
Window7dStart *time.Time // Start of current 7d window
}
func (k *APIKey) IsActive() bool {
return k.Status == StatusActive
}
// HasRateLimits returns true if any rate limit window is configured
func (k *APIKey) HasRateLimits() bool {
return k.RateLimit5h > 0 || k.RateLimit1d > 0 || k.RateLimit7d > 0
}
// IsExpired checks if the API key has expired
func (k *APIKey) IsExpired() bool {
if k.ExpiresAt == nil {
return false
}
return time.Now().After(*k.ExpiresAt)
}
// IsQuotaExhausted checks if the API key quota is exhausted
func (k *APIKey) IsQuotaExhausted() bool {
if k.Quota <= 0 {
return false // unlimited
}
return k.QuotaUsed >= k.Quota
}
// GetQuotaRemaining returns remaining quota (-1 for unlimited)
func (k *APIKey) GetQuotaRemaining() float64 {
if k.Quota <= 0 {
return -1 // unlimited
}
remaining := k.Quota - k.QuotaUsed
if remaining < 0 {
return 0
}
return remaining
}
// GetDaysUntilExpiry returns days until expiry (-1 for never expires)
func (k *APIKey) GetDaysUntilExpiry() int {
if k.ExpiresAt == nil {
return -1 // never expires
}
duration := time.Until(*k.ExpiresAt)
if duration < 0 {
return 0
}
return int(duration.Hours() / 24)
}
// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired.
func (k *APIKey) EffectiveUsage5h() float64 {
if IsWindowExpired(k.Window5hStart, RateLimitWindow5h) {
return 0
}
return k.Usage5h
}
// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired.
func (k *APIKey) EffectiveUsage1d() float64 {
if IsWindowExpired(k.Window1dStart, RateLimitWindow1d) {
return 0
}
return k.Usage1d
}
// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired.
func (k *APIKey) EffectiveUsage7d() float64 {
if IsWindowExpired(k.Window7dStart, RateLimitWindow7d) {
return 0
}
return k.Usage7d
}
// APIKeyListFilters holds optional filtering parameters for listing API keys.
type APIKeyListFilters struct {
Search string
Status string
GroupID *int64 // nil=不筛选, 0=无分组, >0=指定分组
}

View File

@@ -0,0 +1,78 @@
package service
import "time"
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
type APIKeyAuthSnapshot struct {
APIKeyID int64 `json:"api_key_id"`
UserID int64 `json:"user_id"`
GroupID *int64 `json:"group_id,omitempty"`
Status string `json:"status"`
IPWhitelist []string `json:"ip_whitelist,omitempty"`
IPBlacklist []string `json:"ip_blacklist,omitempty"`
User APIKeyAuthUserSnapshot `json:"user"`
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
// Quota fields for API Key independent quota feature
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
QuotaUsed float64 `json:"quota_used"` // Used quota amount
// Expiration field for API Key expiration feature
ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires)
// Rate limit configuration (only limits, not usage - usage read from Redis at check time)
RateLimit5h float64 `json:"rate_limit_5h"`
RateLimit1d float64 `json:"rate_limit_1d"`
RateLimit7d float64 `json:"rate_limit_7d"`
}
// APIKeyAuthUserSnapshot 用户快照
type APIKeyAuthUserSnapshot struct {
ID int64 `json:"id"`
Status string `json:"status"`
Role string `json:"role"`
Balance float64 `json:"balance"`
Concurrency int `json:"concurrency"`
}
// APIKeyAuthGroupSnapshot 分组快照
type APIKeyAuthGroupSnapshot struct {
ID int64 `json:"id"`
Name string `json:"name"`
Platform string `json:"platform"`
Status string `json:"status"`
SubscriptionType string `json:"subscription_type"`
RateMultiplier float64 `json:"rate_multiplier"`
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
ClaudeCodeOnly bool `json:"claude_code_only"`
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
// Only anthropic groups use these fields; others may leave them empty.
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
ModelRoutingEnabled bool `json:"model_routing_enabled"`
MCPXMLInject bool `json:"mcp_xml_inject"`
// 支持的模型系列(仅 antigravity 平台使用)
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
// OpenAI Messages 调度配置(仅 openai 平台使用)
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
}
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
type APIKeyAuthCacheEntry struct {
NotFound bool `json:"not_found"`
Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"`
}

View File

@@ -0,0 +1,313 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"math/rand/v2"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/dgraph-io/ristretto"
)
type apiKeyAuthCacheConfig struct {
l1Size int
l1TTL time.Duration
l2TTL time.Duration
negativeTTL time.Duration
jitterPercent int
singleflight bool
}
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
if cfg == nil {
return apiKeyAuthCacheConfig{}
}
auth := cfg.APIKeyAuth
return apiKeyAuthCacheConfig{
l1Size: auth.L1Size,
l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second,
l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second,
negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second,
jitterPercent: auth.JitterPercent,
singleflight: auth.Singleflight,
}
}
func (c apiKeyAuthCacheConfig) l1Enabled() bool {
return c.l1Size > 0 && c.l1TTL > 0
}
func (c apiKeyAuthCacheConfig) l2Enabled() bool {
return c.l2TTL > 0
}
func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
return c.negativeTTL > 0
}
// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。
// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
if ttl <= 0 {
return ttl
}
if c.jitterPercent <= 0 {
return ttl
}
percent := c.jitterPercent
if percent > 100 {
percent = 100
}
delta := float64(percent) / 100
randVal := rand.Float64()
factor := 1 - delta + randVal*(2*delta)
if factor <= 0 {
return ttl
}
return time.Duration(float64(ttl) * factor)
}
func (s *APIKeyService) initAuthCache(cfg *config.Config) {
s.authCfg = newAPIKeyAuthCacheConfig(cfg)
if !s.authCfg.l1Enabled() {
return
}
cache, err := ristretto.NewCache(&ristretto.Config{
NumCounters: int64(s.authCfg.l1Size) * 10,
MaxCost: int64(s.authCfg.l1Size),
BufferItems: 64,
})
if err != nil {
return
}
s.authCacheL1 = cache
}
// StartAuthCacheInvalidationSubscriber starts the Pub/Sub subscriber for L1 cache invalidation.
// This should be called after the service is fully initialized.
func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) {
if s.cache == nil || s.authCacheL1 == nil {
return
}
if err := s.cache.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) {
s.authCacheL1.Del(cacheKey)
}); err != nil {
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
}
}
func (s *APIKeyService) authCacheKey(key string) string {
sum := sha256.Sum256([]byte(key))
return hex.EncodeToString(sum[:])
}
func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) {
if s.authCacheL1 != nil {
if val, ok := s.authCacheL1.Get(cacheKey); ok {
if entry, ok := val.(*APIKeyAuthCacheEntry); ok {
return entry, true
}
}
}
if s.cache == nil || !s.authCfg.l2Enabled() {
return nil, false
}
entry, err := s.cache.GetAuthCache(ctx, cacheKey)
if err != nil {
return nil, false
}
s.setAuthCacheL1(cacheKey, entry)
return entry, true
}
func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) {
if s.authCacheL1 == nil || entry == nil {
return
}
ttl := s.authCfg.l1TTL
if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl {
ttl = s.authCfg.negativeTTL
}
ttl = s.authCfg.jitterTTL(ttl)
_ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl)
}
func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) {
if entry == nil {
return
}
s.setAuthCacheL1(cacheKey, entry)
if s.cache == nil || !s.authCfg.l2Enabled() {
return
}
_ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl))
}
func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
if s.authCacheL1 != nil {
s.authCacheL1.Del(cacheKey)
}
if s.cache == nil {
return
}
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
// Publish invalidation message to other instances
_ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey)
}
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
if err != nil {
if errors.Is(err, ErrAPIKeyNotFound) {
entry := &APIKeyAuthCacheEntry{NotFound: true}
if s.authCfg.negativeEnabled() {
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL)
}
return entry, nil
}
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
snapshot := s.snapshotFromAPIKey(apiKey)
if snapshot == nil {
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
}
entry := &APIKeyAuthCacheEntry{Snapshot: snapshot}
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL)
return entry, nil
}
func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) {
if entry == nil {
return nil, false, nil
}
if entry.NotFound {
return nil, true, ErrAPIKeyNotFound
}
if entry.Snapshot == nil {
return nil, false, nil
}
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
}
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
if apiKey == nil || apiKey.User == nil {
return nil
}
snapshot := &APIKeyAuthSnapshot{
APIKeyID: apiKey.ID,
UserID: apiKey.UserID,
GroupID: apiKey.GroupID,
Status: apiKey.Status,
IPWhitelist: apiKey.IPWhitelist,
IPBlacklist: apiKey.IPBlacklist,
Quota: apiKey.Quota,
QuotaUsed: apiKey.QuotaUsed,
ExpiresAt: apiKey.ExpiresAt,
RateLimit5h: apiKey.RateLimit5h,
RateLimit1d: apiKey.RateLimit1d,
RateLimit7d: apiKey.RateLimit7d,
User: APIKeyAuthUserSnapshot{
ID: apiKey.User.ID,
Status: apiKey.User.Status,
Role: apiKey.User.Role,
Balance: apiKey.User.Balance,
Concurrency: apiKey.User.Concurrency,
},
}
if apiKey.Group != nil {
snapshot.Group = &APIKeyAuthGroupSnapshot{
ID: apiKey.Group.ID,
Name: apiKey.Group.Name,
Platform: apiKey.Group.Platform,
Status: apiKey.Group.Status,
SubscriptionType: apiKey.Group.SubscriptionType,
RateMultiplier: apiKey.Group.RateMultiplier,
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
ImagePrice1K: apiKey.Group.ImagePrice1K,
ImagePrice2K: apiKey.Group.ImagePrice2K,
ImagePrice4K: apiKey.Group.ImagePrice4K,
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
FallbackGroupID: apiKey.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
ModelRouting: apiKey.Group.ModelRouting,
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
MCPXMLInject: apiKey.Group.MCPXMLInject,
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
}
}
return snapshot
}
func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey {
if snapshot == nil {
return nil
}
apiKey := &APIKey{
ID: snapshot.APIKeyID,
UserID: snapshot.UserID,
GroupID: snapshot.GroupID,
Key: key,
Status: snapshot.Status,
IPWhitelist: snapshot.IPWhitelist,
IPBlacklist: snapshot.IPBlacklist,
Quota: snapshot.Quota,
QuotaUsed: snapshot.QuotaUsed,
ExpiresAt: snapshot.ExpiresAt,
RateLimit5h: snapshot.RateLimit5h,
RateLimit1d: snapshot.RateLimit1d,
RateLimit7d: snapshot.RateLimit7d,
User: &User{
ID: snapshot.User.ID,
Status: snapshot.User.Status,
Role: snapshot.User.Role,
Balance: snapshot.User.Balance,
Concurrency: snapshot.User.Concurrency,
},
}
if snapshot.Group != nil {
apiKey.Group = &Group{
ID: snapshot.Group.ID,
Name: snapshot.Group.Name,
Platform: snapshot.Group.Platform,
Status: snapshot.Group.Status,
Hydrated: true,
SubscriptionType: snapshot.Group.SubscriptionType,
RateMultiplier: snapshot.Group.RateMultiplier,
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
ImagePrice1K: snapshot.Group.ImagePrice1K,
ImagePrice2K: snapshot.Group.ImagePrice2K,
ImagePrice4K: snapshot.Group.ImagePrice4K,
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
FallbackGroupID: snapshot.Group.FallbackGroupID,
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
ModelRouting: snapshot.Group.ModelRouting,
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
MCPXMLInject: snapshot.Group.MCPXMLInject,
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
}
}
s.compileAPIKeyIPRules(apiKey)
return apiKey
}

View File

@@ -0,0 +1,48 @@
package service
import "context"
// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存
func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) {
if key == "" {
return
}
cacheKey := s.authCacheKey(key)
s.deleteAuthCache(ctx, cacheKey)
}
// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存
func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
if userID <= 0 {
return
}
keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID)
if err != nil {
return
}
s.deleteAuthCacheByKeys(ctx, keys)
}
// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存
func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
if groupID <= 0 {
return
}
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID)
if err != nil {
return
}
s.deleteAuthCacheByKeys(ctx, keys)
}
func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) {
if len(keys) == 0 {
return
}
for _, key := range keys {
if key == "" {
continue
}
s.deleteAuthCache(ctx, s.authCacheKey(key))
}
}

View File

@@ -0,0 +1,245 @@
package service
import (
"testing"
"time"
)
func TestIsWindowExpired(t *testing.T) {
now := time.Now()
tests := []struct {
name string
start *time.Time
duration time.Duration
want bool
}{
{
name: "nil window start (treated as expired)",
start: nil,
duration: RateLimitWindow5h,
want: true,
},
{
name: "active window (started 1h ago, 5h window)",
start: rateLimitTimePtr(now.Add(-1 * time.Hour)),
duration: RateLimitWindow5h,
want: false,
},
{
name: "expired window (started 6h ago, 5h window)",
start: rateLimitTimePtr(now.Add(-6 * time.Hour)),
duration: RateLimitWindow5h,
want: true,
},
{
name: "exactly at boundary (started 5h ago, 5h window)",
start: rateLimitTimePtr(now.Add(-5 * time.Hour)),
duration: RateLimitWindow5h,
want: true,
},
{
name: "active 1d window (started 12h ago)",
start: rateLimitTimePtr(now.Add(-12 * time.Hour)),
duration: RateLimitWindow1d,
want: false,
},
{
name: "expired 1d window (started 25h ago)",
start: rateLimitTimePtr(now.Add(-25 * time.Hour)),
duration: RateLimitWindow1d,
want: true,
},
{
name: "active 7d window (started 3d ago)",
start: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)),
duration: RateLimitWindow7d,
want: false,
},
{
name: "expired 7d window (started 8d ago)",
start: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)),
duration: RateLimitWindow7d,
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := IsWindowExpired(tt.start, tt.duration)
if got != tt.want {
t.Errorf("IsWindowExpired() = %v, want %v", got, tt.want)
}
})
}
}
func TestAPIKey_EffectiveUsage(t *testing.T) {
now := time.Now()
tests := []struct {
name string
key APIKey
want5h float64
want1d float64
want7d float64
}{
{
name: "all windows active",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
Usage7d: 50.0,
Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)),
},
want5h: 5.0,
want1d: 10.0,
want7d: 50.0,
},
{
name: "all windows expired",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
Usage7d: 50.0,
Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-25 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)),
},
want5h: 0,
want1d: 0,
want7d: 0,
},
{
name: "nil window starts return 0 (stale usage reset)",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
Usage7d: 50.0,
Window5hStart: nil,
Window1dStart: nil,
Window7dStart: nil,
},
want5h: 0,
want1d: 0,
want7d: 0,
},
{
name: "mixed: 5h expired, 1d active, 7d nil",
key: APIKey{
Usage5h: 5.0,
Usage1d: 10.0,
Usage7d: 50.0,
Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)),
Window7dStart: nil,
},
want5h: 0,
want1d: 10.0,
want7d: 0,
},
{
name: "zero usage with active windows",
key: APIKey{
Usage5h: 0,
Usage1d: 0,
Usage7d: 0,
Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
},
want5h: 0,
want1d: 0,
want7d: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.key.EffectiveUsage5h(); got != tt.want5h {
t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h)
}
if got := tt.key.EffectiveUsage1d(); got != tt.want1d {
t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d)
}
if got := tt.key.EffectiveUsage7d(); got != tt.want7d {
t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d)
}
})
}
}
func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
now := time.Now()
tests := []struct {
name string
data APIKeyRateLimitData
want5h float64
want1d float64
want7d float64
}{
{
name: "all windows active",
data: APIKeyRateLimitData{
Usage5h: 3.0,
Usage1d: 8.0,
Usage7d: 40.0,
Window5hStart: rateLimitTimePtr(now.Add(-2 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-10 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-2 * 24 * time.Hour)),
},
want5h: 3.0,
want1d: 8.0,
want7d: 40.0,
},
{
name: "all windows expired",
data: APIKeyRateLimitData{
Usage5h: 3.0,
Usage1d: 8.0,
Usage7d: 40.0,
Window5hStart: rateLimitTimePtr(now.Add(-10 * time.Hour)),
Window1dStart: rateLimitTimePtr(now.Add(-48 * time.Hour)),
Window7dStart: rateLimitTimePtr(now.Add(-10 * 24 * time.Hour)),
},
want5h: 0,
want1d: 0,
want7d: 0,
},
{
name: "nil window starts return 0 (stale usage reset)",
data: APIKeyRateLimitData{
Usage5h: 3.0,
Usage1d: 8.0,
Usage7d: 40.0,
Window5hStart: nil,
Window1dStart: nil,
Window7dStart: nil,
},
want5h: 0,
want1d: 0,
want7d: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := tt.data.EffectiveUsage5h(); got != tt.want5h {
t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h)
}
if got := tt.data.EffectiveUsage1d(); got != tt.want1d {
t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d)
}
if got := tt.data.EffectiveUsage7d(); got != tt.want7d {
t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d)
}
})
}
}
func rateLimitTimePtr(t time.Time) *time.Time {
return &t
}

View File

@@ -0,0 +1,881 @@
package service
import (
"context"
"crypto/rand"
"encoding/hex"
"fmt"
"strconv"
"strings"
"sync"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
"github.com/dgraph-io/ristretto"
"golang.org/x/sync/singleflight"
)
var (
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
// ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key has expired")
ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期")
// ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted")
ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完")
// Rate limit errors
ErrAPIKeyRateLimit5hExceeded = infraerrors.TooManyRequests("API_KEY_RATE_5H_EXCEEDED", "api key 5小时限额已用完")
ErrAPIKeyRateLimit1dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_1D_EXCEEDED", "api key 日限额已用完")
ErrAPIKeyRateLimit7dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_7D_EXCEEDED", "api key 7天限额已用完")
)
const (
apiKeyMaxErrorsPerHour = 20
apiKeyLastUsedMinTouch = 30 * time.Second
// DB 写失败后的短退避,避免请求路径持续同步重试造成写风暴与高延迟。
apiKeyLastUsedFailBackoff = 5 * time.Second
)
type APIKeyRepository interface {
Create(ctx context.Context, key *APIKey) error
GetByID(ctx context.Context, id int64) (*APIKey, error)
// GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID用于删除等轻量场景
GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error)
GetByKey(ctx context.Context, key string) (*APIKey, error)
// GetByKeyForAuth 认证专用查询,返回最小字段集
GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error)
Update(ctx context.Context, key *APIKey) error
Delete(ctx context.Context, id int64) error
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error)
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
CountByUserID(ctx context.Context, userID int64) (int64, error)
ExistsByKey(ctx context.Context, key string) (bool, error)
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
// Quota methods
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error)
UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error
// Rate limit methods
IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error
ResetRateLimitWindows(ctx context.Context, id int64) error
GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error)
}
// APIKeyRateLimitData holds rate limit usage and window state for an API key.
type APIKeyRateLimitData struct {
Usage5h float64
Usage1d float64
Usage7d float64
Window5hStart *time.Time
Window1dStart *time.Time
Window7dStart *time.Time
}
// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired.
func (d *APIKeyRateLimitData) EffectiveUsage5h() float64 {
if IsWindowExpired(d.Window5hStart, RateLimitWindow5h) {
return 0
}
return d.Usage5h
}
// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired.
func (d *APIKeyRateLimitData) EffectiveUsage1d() float64 {
if IsWindowExpired(d.Window1dStart, RateLimitWindow1d) {
return 0
}
return d.Usage1d
}
// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired.
func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
if IsWindowExpired(d.Window7dStart, RateLimitWindow7d) {
return 0
}
return d.Usage7d
}
// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update.
// It is intentionally small so repositories can return it from a single SQL statement.
type APIKeyQuotaUsageState struct {
QuotaUsed float64
Quota float64
Key string
Status string
}
// APIKeyCache defines cache operations for API key service
type APIKeyCache interface {
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
IncrementDailyUsage(ctx context.Context, apiKey string) error
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
DeleteAuthCache(ctx context.Context, key string) error
// Pub/Sub for L1 cache invalidation across instances
PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error
SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error
}
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
type APIKeyAuthCacheInvalidator interface {
InvalidateAuthCacheByKey(ctx context.Context, key string)
InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
}
// CreateAPIKeyRequest 创建API Key请求
type CreateAPIKeyRequest struct {
Name string `json:"name"`
GroupID *int64 `json:"group_id"`
CustomKey *string `json:"custom_key"` // 可选的自定义key
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
// Quota fields
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires)
// Rate limit fields (0 = unlimited)
RateLimit5h float64 `json:"rate_limit_5h"`
RateLimit1d float64 `json:"rate_limit_1d"`
RateLimit7d float64 `json:"rate_limit_7d"`
}
// UpdateAPIKeyRequest 更新API Key请求
type UpdateAPIKeyRequest struct {
Name *string `json:"name"`
GroupID *int64 `json:"group_id"`
Status *string `json:"status"`
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
// Quota fields
Quota *float64 `json:"quota"` // Quota limit in USD (nil = no change, 0 = unlimited)
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change)
ClearExpiration bool `json:"-"` // Clear expiration (internal use)
ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0
// Rate limit fields (nil = no change, 0 = unlimited)
RateLimit5h *float64 `json:"rate_limit_5h"`
RateLimit1d *float64 `json:"rate_limit_1d"`
RateLimit7d *float64 `json:"rate_limit_7d"`
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // Reset all usage counters to 0
}
// APIKeyService API Key服务
// RateLimitCacheInvalidator invalidates rate limit cache entries on manual reset.
type RateLimitCacheInvalidator interface {
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
}
type APIKeyService struct {
apiKeyRepo APIKeyRepository
userRepo UserRepository
groupRepo GroupRepository
userSubRepo UserSubscriptionRepository
userGroupRateRepo UserGroupRateRepository
cache APIKeyCache
rateLimitCacheInvalid RateLimitCacheInvalidator // optional: invalidate Redis rate limit cache
cfg *config.Config
authCacheL1 *ristretto.Cache
authCfg apiKeyAuthCacheConfig
authGroup singleflight.Group
lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
lastUsedTouchSF singleflight.Group
}
// NewAPIKeyService 创建API Key服务实例
func NewAPIKeyService(
apiKeyRepo APIKeyRepository,
userRepo UserRepository,
groupRepo GroupRepository,
userSubRepo UserSubscriptionRepository,
userGroupRateRepo UserGroupRateRepository,
cache APIKeyCache,
cfg *config.Config,
) *APIKeyService {
svc := &APIKeyService{
apiKeyRepo: apiKeyRepo,
userRepo: userRepo,
groupRepo: groupRepo,
userSubRepo: userSubRepo,
userGroupRateRepo: userGroupRateRepo,
cache: cache,
cfg: cfg,
}
svc.initAuthCache(cfg)
return svc
}
// SetRateLimitCacheInvalidator sets the optional rate limit cache invalidator.
// Called after construction (e.g. in wire) to avoid circular dependencies.
func (s *APIKeyService) SetRateLimitCacheInvalidator(inv RateLimitCacheInvalidator) {
s.rateLimitCacheInvalid = inv
}
func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) {
if apiKey == nil {
return
}
apiKey.CompiledIPWhitelist = ip.CompileIPRules(apiKey.IPWhitelist)
apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist)
}
// GenerateKey 生成随机API Key
func (s *APIKeyService) GenerateKey() (string, error) {
// 生成32字节随机数据
bytes := make([]byte, 32)
if _, err := rand.Read(bytes); err != nil {
return "", fmt.Errorf("generate random bytes: %w", err)
}
// 转换为十六进制字符串并添加前缀
prefix := s.cfg.Default.APIKeyPrefix
if prefix == "" {
prefix = "sk-"
}
key := prefix + hex.EncodeToString(bytes)
return key, nil
}
// ValidateCustomKey 验证自定义API Key格式
func (s *APIKeyService) ValidateCustomKey(key string) error {
// 检查长度
if len(key) < 16 {
return ErrAPIKeyTooShort
}
// 检查字符:只允许字母、数字、下划线、连字符
for _, c := range key {
if (c >= 'a' && c <= 'z') ||
(c >= 'A' && c <= 'Z') ||
(c >= '0' && c <= '9') ||
c == '_' || c == '-' {
continue
}
return ErrAPIKeyInvalidChars
}
return nil
}
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
if s.cache == nil {
return nil
}
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
if err != nil {
// Redis 出错时不阻止用户操作
return nil
}
if count >= apiKeyMaxErrorsPerHour {
return ErrAPIKeyRateLimited
}
return nil
}
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
if s.cache == nil {
return
}
_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
}
// canUserBindGroup 检查用户是否可以绑定指定分组
// 对于订阅类型分组:检查用户是否有有效订阅
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
return err == nil // 有有效订阅则允许
}
// 标准类型分组:使用原有逻辑
return user.CanBindGroup(group.ID, group.IsExclusive)
}
// Create 创建API Key
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
// 验证用户存在
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
// 验证 IP 白名单格式
if len(req.IPWhitelist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证 IP 黑名单格式
if len(req.IPBlacklist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证分组权限(如果指定了分组)
if req.GroupID != nil {
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
// 检查用户是否可以绑定该分组
if !s.canUserBindGroup(ctx, user, group) {
return nil, ErrGroupNotAllowed
}
}
var key string
// 判断是否使用自定义Key
if req.CustomKey != nil && *req.CustomKey != "" {
// 检查限流仅对自定义key进行限流
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
return nil, err
}
// 验证自定义Key格式
if err := s.ValidateCustomKey(*req.CustomKey); err != nil {
return nil, err
}
// 检查Key是否已存在
exists, err := s.apiKeyRepo.ExistsByKey(ctx, *req.CustomKey)
if err != nil {
return nil, fmt.Errorf("check key exists: %w", err)
}
if exists {
// Key已存在增加错误计数
s.incrementAPIKeyErrorCount(ctx, userID)
return nil, ErrAPIKeyExists
}
key = *req.CustomKey
} else {
// 生成随机API Key
var err error
key, err = s.GenerateKey()
if err != nil {
return nil, fmt.Errorf("generate key: %w", err)
}
}
// 创建API Key记录
apiKey := &APIKey{
UserID: userID,
Key: key,
Name: req.Name,
GroupID: req.GroupID,
Status: StatusActive,
IPWhitelist: req.IPWhitelist,
IPBlacklist: req.IPBlacklist,
Quota: req.Quota,
QuotaUsed: 0,
RateLimit5h: req.RateLimit5h,
RateLimit1d: req.RateLimit1d,
RateLimit7d: req.RateLimit7d,
}
// Set expiration time if specified
if req.ExpiresInDays != nil && *req.ExpiresInDays > 0 {
expiresAt := time.Now().AddDate(0, 0, *req.ExpiresInDays)
apiKey.ExpiresAt = &expiresAt
}
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
return nil, fmt.Errorf("create api key: %w", err)
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
// List 获取用户的API Key列表
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, filters)
if err != nil {
return nil, nil, fmt.Errorf("list api keys: %w", err)
}
return keys, pagination, nil
}
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
if len(apiKeyIDs) == 0 {
return []int64{}, nil
}
validIDs, err := s.apiKeyRepo.VerifyOwnership(ctx, userID, apiKeyIDs)
if err != nil {
return nil, fmt.Errorf("verify api key ownership: %w", err)
}
return validIDs, nil
}
// GetByID 根据ID获取API Key
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
// GetByKey 根据Key字符串获取API Key用于认证
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
cacheKey := s.authCacheKey(key)
if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok {
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
}
if s.authCfg.singleflight {
value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) {
return s.loadAuthCacheEntry(ctx, key, cacheKey)
})
if err != nil {
return nil, err
}
entry, _ := value.(*APIKeyAuthCacheEntry)
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
} else {
entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey)
if err != nil {
return nil, err
}
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
}
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
apiKey.Key = key
s.compileAPIKeyIPRules(apiKey)
return apiKey, nil
}
// Update 更新API Key
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("get api key: %w", err)
}
// 验证所有权
if apiKey.UserID != userID {
return nil, ErrInsufficientPerms
}
// 验证 IP 白名单格式
if len(req.IPWhitelist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 验证 IP 黑名单格式
if len(req.IPBlacklist) > 0 {
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
}
}
// 更新字段
if req.Name != nil {
apiKey.Name = *req.Name
}
if req.GroupID != nil {
// 验证分组权限
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
if err != nil {
return nil, fmt.Errorf("get group: %w", err)
}
if !s.canUserBindGroup(ctx, user, group) {
return nil, ErrGroupNotAllowed
}
apiKey.GroupID = req.GroupID
}
if req.Status != nil {
apiKey.Status = *req.Status
// 如果状态改变清除Redis缓存
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
}
}
// Update quota fields
if req.Quota != nil {
apiKey.Quota = *req.Quota
// If quota is increased and status was quota_exhausted, reactivate
if apiKey.Status == StatusAPIKeyQuotaExhausted && *req.Quota > apiKey.QuotaUsed {
apiKey.Status = StatusActive
}
}
if req.ResetQuota != nil && *req.ResetQuota {
apiKey.QuotaUsed = 0
// If resetting quota and status was quota_exhausted, reactivate
if apiKey.Status == StatusAPIKeyQuotaExhausted {
apiKey.Status = StatusActive
}
}
if req.ClearExpiration {
apiKey.ExpiresAt = nil
// If clearing expiry and status was expired, reactivate
if apiKey.Status == StatusAPIKeyExpired {
apiKey.Status = StatusActive
}
} else if req.ExpiresAt != nil {
apiKey.ExpiresAt = req.ExpiresAt
// If extending expiry and status was expired, reactivate
if apiKey.Status == StatusAPIKeyExpired && time.Now().Before(*req.ExpiresAt) {
apiKey.Status = StatusActive
}
}
// 更新 IP 限制(空数组会清空设置)
apiKey.IPWhitelist = req.IPWhitelist
apiKey.IPBlacklist = req.IPBlacklist
// Update rate limit configuration
if req.RateLimit5h != nil {
apiKey.RateLimit5h = *req.RateLimit5h
}
if req.RateLimit1d != nil {
apiKey.RateLimit1d = *req.RateLimit1d
}
if req.RateLimit7d != nil {
apiKey.RateLimit7d = *req.RateLimit7d
}
resetRateLimit := req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage
if resetRateLimit {
apiKey.Usage5h = 0
apiKey.Usage1d = 0
apiKey.Usage7d = 0
apiKey.Window5hStart = nil
apiKey.Window1dStart = nil
apiKey.Window7dStart = nil
}
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil, fmt.Errorf("update api key: %w", err)
}
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
s.compileAPIKeyIPRules(apiKey)
// Invalidate Redis rate limit cache so reset takes effect immediately
if resetRateLimit && s.rateLimitCacheInvalid != nil {
_ = s.rateLimitCacheInvalid.InvalidateAPIKeyRateLimit(ctx, apiKey.ID)
}
return apiKey, nil
}
// Delete 删除API Key
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id)
if err != nil {
return fmt.Errorf("get api key: %w", err)
}
// 验证当前用户是否为该 API Key 的所有者
if ownerID != userID {
return ErrInsufficientPerms
}
// 清除Redis缓存使用 userID 而非 apiKey.UserID
if s.cache != nil {
_ = s.cache.DeleteCreateAttemptCount(ctx, userID)
}
s.InvalidateAuthCacheByKey(ctx, key)
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
return fmt.Errorf("delete api key: %w", err)
}
s.lastUsedTouchL1.Delete(id)
return nil
}
// ValidateKey 验证API Key是否有效用于认证中间件
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
// 获取API Key
apiKey, err := s.GetByKey(ctx, key)
if err != nil {
return nil, nil, err
}
// 检查API Key状态
if !apiKey.IsActive() {
return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active")
}
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
if err != nil {
return nil, nil, fmt.Errorf("get user: %w", err)
}
// 检查用户状态
if !user.IsActive() {
return nil, nil, ErrUserNotActive
}
return apiKey, user, nil
}
// TouchLastUsed 通过防抖更新 api_keys.last_used_at减少高频写放大。
// 该操作为尽力而为,不应阻塞主请求链路。
func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error {
if keyID <= 0 {
return nil
}
now := time.Now()
if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) {
return nil
}
}
_, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) {
latest := time.Now()
if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) {
return nil, nil
}
}
if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil {
s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedFailBackoff))
return nil, fmt.Errorf("touch api key last used: %w", err)
}
s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedMinTouch))
return nil, nil
})
return err
}
// IncrementUsage 增加API Key使用次数可选用于统计
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
// 使用Redis计数器
if s.cache != nil {
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
return fmt.Errorf("increment usage: %w", err)
}
// 设置24小时过期
_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
}
return nil
}
// GetAvailableGroups 获取用户有权限绑定的分组列表
// 返回用户可以选择的分组:
// - 标准类型分组:公开的(非专属)或用户被明确允许的
// - 订阅类型分组:用户有有效订阅的
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
// 获取用户信息
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user: %w", err)
}
// 获取所有活跃分组
allGroups, err := s.groupRepo.ListActive(ctx)
if err != nil {
return nil, fmt.Errorf("list active groups: %w", err)
}
// 获取用户的所有有效订阅
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("list active subscriptions: %w", err)
}
// 构建订阅分组 ID 集合
subscribedGroupIDs := make(map[int64]bool)
for _, sub := range activeSubscriptions {
subscribedGroupIDs[sub.GroupID] = true
}
// 过滤出用户有权限的分组
availableGroups := make([]Group, 0)
for _, group := range allGroups {
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
availableGroups = append(availableGroups, group)
}
}
return availableGroups, nil
}
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
// 订阅类型分组:需要有效订阅
if group.IsSubscriptionType() {
return subscribedGroupIDs[group.ID]
}
// 标准类型分组:使用原有逻辑
return user.CanBindGroup(group.ID, group.IsExclusive)
}
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
if err != nil {
return nil, fmt.Errorf("search api keys: %w", err)
}
return keys, nil
}
// GetUserGroupRates 获取用户的专属分组倍率配置
// 返回 map[groupID]rateMultiplier
func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) {
if s.userGroupRateRepo == nil {
return nil, nil
}
rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID)
if err != nil {
return nil, fmt.Errorf("get user group rates: %w", err)
}
return rates, nil
}
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
// Returns nil if valid, error if invalid
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {
// Check expiration
if apiKey.IsExpired() {
return ErrAPIKeyExpired
}
// Check quota
if apiKey.IsQuotaExhausted() {
return ErrAPIKeyQuotaExhausted
}
return nil
}
// UpdateQuotaUsed updates the quota_used field after a request
// Also checks if quota is exhausted and updates status accordingly
func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
if cost <= 0 {
return nil
}
type quotaStateReader interface {
IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error)
}
if repo, ok := s.apiKeyRepo.(quotaStateReader); ok {
state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost)
if err != nil {
return fmt.Errorf("increment quota used: %w", err)
}
if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" {
s.InvalidateAuthCacheByKey(ctx, state.Key)
}
return nil
}
// Use repository to atomically increment quota_used
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
if err != nil {
return fmt.Errorf("increment quota used: %w", err)
}
// Check if quota is now exhausted and update status if needed
apiKey, err := s.apiKeyRepo.GetByID(ctx, apiKeyID)
if err != nil {
return nil // Don't fail the request, just log
}
// If quota is set and now exhausted, update status
if apiKey.Quota > 0 && newQuotaUsed >= apiKey.Quota {
apiKey.Status = StatusAPIKeyQuotaExhausted
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
return nil // Don't fail the request
}
// Invalidate cache so next request sees the new status
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
}
return nil
}
// GetRateLimitData returns rate limit usage and window state for an API key.
func (s *APIKeyService) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
return s.apiKeyRepo.GetRateLimitData(ctx, id)
}
// UpdateRateLimitUsage atomically increments rate limit usage counters in the DB.
func (s *APIKeyService) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
if cost <= 0 {
return nil
}
return s.apiKeyRepo.IncrementRateLimitUsage(ctx, apiKeyID, cost)
}

View File

@@ -0,0 +1,448 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)
type authRepoStub struct {
getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error)
listKeysByUserID func(ctx context.Context, userID int64) ([]string, error)
listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error)
}
func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error {
panic("unexpected Create call")
}
func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
panic("unexpected GetByID call")
}
func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
panic("unexpected GetKeyAndOwnerID call")
}
func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
if s.getByKeyForAuth == nil {
panic("unexpected GetByKeyForAuth call")
}
return s.getByKeyForAuth(ctx, key)
}
func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}
func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
panic("unexpected Delete call")
}
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
panic("unexpected VerifyOwnership call")
}
func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
panic("unexpected CountByUserID call")
}
func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
panic("unexpected ExistsByKey call")
}
func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected ClearGroupIDByGroupID call")
}
func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected CountByGroupID call")
}
func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
if s.listKeysByUserID == nil {
panic("unexpected ListKeysByUserID call")
}
return s.listKeysByUserID(ctx, userID)
}
func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
if s.listKeysByGroupID == nil {
panic("unexpected ListKeysByGroupID call")
}
return s.listKeysByGroupID(ctx, groupID)
}
func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
panic("unexpected IncrementQuotaUsed call")
}
func (s *authRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
panic("unexpected UpdateLastUsed call")
}
func (s *authRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
panic("unexpected IncrementRateLimitUsage call")
}
func (s *authRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error {
panic("unexpected ResetRateLimitWindows call")
}
func (s *authRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
panic("unexpected GetRateLimitData call")
}
type authCacheStub struct {
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
setAuthKeys []string
deleteAuthKeys []string
}
func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return nil
}
func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return nil
}
func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
if s.getAuthCache == nil {
return nil, redis.Nil
}
return s.getAuthCache(ctx, key)
}
func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
s.setAuthKeys = append(s.setAuthKeys, key)
return nil
}
func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
return nil
}
func (s *authCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
return nil
}
func (s *authCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
return nil
}
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return nil, errors.New("unexpected repo call")
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
groupID := int64(9)
cacheEntry := &APIKeyAuthCacheEntry{
Snapshot: &APIKeyAuthSnapshot{
APIKeyID: 1,
UserID: 2,
GroupID: &groupID,
Status: StatusActive,
User: APIKeyAuthUserSnapshot{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 10,
Concurrency: 3,
},
Group: &APIKeyAuthGroupSnapshot{
ID: groupID,
Name: "g",
Platform: PlatformAnthropic,
Status: StatusActive,
SubscriptionType: SubscriptionTypeStandard,
RateMultiplier: 1,
ModelRoutingEnabled: true,
ModelRouting: map[string][]int64{
"claude-opus-*": {1, 2},
},
},
},
}
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return cacheEntry, nil
}
apiKey, err := svc.GetByKey(context.Background(), "k1")
require.NoError(t, err)
require.Equal(t, int64(1), apiKey.ID)
require.Equal(t, int64(2), apiKey.User.ID)
require.Equal(t, groupID, apiKey.Group.ID)
require.True(t, apiKey.Group.ModelRoutingEnabled)
require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
}
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return nil, errors.New("unexpected repo call")
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return &APIKeyAuthCacheEntry{NotFound: true}, nil
}
_, err := svc.GetByKey(context.Background(), "missing")
require.ErrorIs(t, err, ErrAPIKeyNotFound)
}
func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return &APIKey{
ID: 5,
UserID: 7,
Status: StatusActive,
User: &User{
ID: 7,
Status: StatusActive,
Role: RoleUser,
Balance: 12,
Concurrency: 2,
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
apiKey, err := svc.GetByKey(context.Background(), "k2")
require.NoError(t, err)
require.Equal(t, int64(5), apiKey.ID)
require.Len(t, cache.setAuthKeys, 1)
}
func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
var calls int32
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
atomic.AddInt32(&calls, 1)
return &APIKey{
ID: 21,
UserID: 3,
Status: StatusActive,
User: &User{
ID: 3,
Status: StatusActive,
Role: RoleUser,
Balance: 5,
Concurrency: 2,
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L1Size: 1000,
L1TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
require.NotNil(t, svc.authCacheL1)
_, err := svc.GetByKey(context.Background(), "k-l1")
require.NoError(t, err)
svc.authCacheL1.Wait()
cacheKey := svc.authCacheKey("k-l1")
_, ok := svc.authCacheL1.Get(cacheKey)
require.True(t, ok)
_, err = svc.GetByKey(context.Background(), "k-l1")
require.NoError(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
}
func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
return []string{"k1", "k2"}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
require.Len(t, cache.deleteAuthKeys, 2)
}
func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) {
return []string{"k1", "k2"}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
require.Len(t, cache.deleteAuthKeys, 2)
}
func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
return nil, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
require.Len(t, cache.deleteAuthKeys, 1)
}
func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
return nil, ErrAPIKeyNotFound
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
L2TTLSeconds: 60,
NegativeTTLSeconds: 30,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, redis.Nil
}
_, err := svc.GetByKey(context.Background(), "missing")
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Len(t, cache.setAuthKeys, 1)
}
func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
var calls int32
cache := &authCacheStub{}
repo := &authRepoStub{
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
atomic.AddInt32(&calls, 1)
time.Sleep(50 * time.Millisecond)
return &APIKey{
ID: 11,
UserID: 2,
Status: StatusActive,
User: &User{
ID: 2,
Status: StatusActive,
Role: RoleUser,
Balance: 1,
Concurrency: 1,
},
}, nil
},
}
cfg := &config.Config{
APIKeyAuth: config.APIKeyAuthCacheConfig{
Singleflight: true,
},
}
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
start := make(chan struct{})
wg := sync.WaitGroup{}
errs := make([]error, 5)
for i := 0; i < 5; i++ {
wg.Add(1)
go func(idx int) {
defer wg.Done()
<-start
_, err := svc.GetByKey(context.Background(), "k1")
errs[idx] = err
}(i)
}
close(start)
wg.Wait()
for _, err := range errs {
require.NoError(t, err)
}
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
}

View File

@@ -0,0 +1,291 @@
//go:build unit
// API Key 服务删除方法的单元测试
// 测试 APIKeyService.Delete 方法在各种场景下的行为,
// 包括权限验证、缓存清理和错误处理
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
//
// 设计说明:
// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误
// - deleteErr: 模拟 Delete 返回的错误
// - deletedIDs: 记录被调用删除的 API Key ID用于断言验证
type apiKeyRepoStub struct {
apiKey *APIKey // GetKeyAndOwnerID 的返回值
getByIDErr error // GetKeyAndOwnerID 的错误返回值
deleteErr error // Delete 的错误返回值
deletedIDs []int64 // 记录已删除的 API Key ID 列表
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
touchedIDs []int64
touchedUsedAts []time.Time
}
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
panic("unexpected Create call")
}
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
if s.getByIDErr != nil {
return nil, s.getByIDErr
}
if s.apiKey != nil {
clone := *s.apiKey
return &clone, nil
}
panic("unexpected GetByID call")
}
func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
if s.getByIDErr != nil {
return "", 0, s.getByIDErr
}
if s.apiKey != nil {
return s.apiKey.Key, s.apiKey.UserID, nil
}
return "", 0, ErrAPIKeyNotFound
}
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
panic("unexpected GetByKeyForAuth call")
}
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
panic("unexpected Update call")
}
// Delete 记录被删除的 API Key ID 并返回预设的错误。
// 通过 deletedIDs 可以验证删除操作是否被正确调用。
func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
s.deletedIDs = append(s.deletedIDs, id)
return s.deleteErr
}
// 以下是接口要求实现但本测试不关心的方法
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
func (s *apiKeyRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
panic("unexpected VerifyOwnership call")
}
func (s *apiKeyRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
panic("unexpected CountByUserID call")
}
func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
panic("unexpected ExistsByKey call")
}
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected ClearGroupIDByGroupID call")
}
func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
panic("unexpected CountByGroupID call")
}
func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
panic("unexpected ListKeysByUserID call")
}
func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
panic("unexpected ListKeysByGroupID call")
}
func (s *apiKeyRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
panic("unexpected IncrementQuotaUsed call")
}
func (s *apiKeyRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
s.touchedIDs = append(s.touchedIDs, id)
s.touchedUsedAts = append(s.touchedUsedAts, usedAt)
if s.updateLastUsed != nil {
return s.updateLastUsed(ctx, id, usedAt)
}
return nil
}
func (s *apiKeyRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
panic("unexpected IncrementRateLimitUsage call")
}
func (s *apiKeyRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error {
panic("unexpected ResetRateLimitWindows call")
}
func (s *apiKeyRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
panic("unexpected GetRateLimitData call")
}
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
//
// 设计说明:
// - invalidated: 记录被清除缓存的用户 ID 列表
type apiKeyCacheStub struct {
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key
}
// GetCreateAttemptCount 返回 0表示用户未超过创建次数限制
func (s *apiKeyCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
return 0, nil
}
// IncrementCreateAttemptCount 空实现,本测试不验证此行为
func (s *apiKeyCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
return nil
}
// DeleteCreateAttemptCount 记录被清除缓存的用户 ID。
// 删除 API Key 时会调用此方法清除用户的创建尝试计数缓存。
func (s *apiKeyCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
s.invalidated = append(s.invalidated, userID)
return nil
}
// IncrementDailyUsage 空实现,本测试不验证此行为
func (s *apiKeyCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
return nil
}
// SetDailyUsageExpiry 空实现,本测试不验证此行为
func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
return nil
}
func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
return nil, nil
}
func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
return nil
}
func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
return nil
}
func (s *apiKeyCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
return nil
}
func (s *apiKeyCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
return nil
}
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
// 预期行为:
// - GetKeyAndOwnerID 返回所有者 ID 为 1
// - 调用者 userID 为 2不匹配
// - 返回 ErrInsufficientPerms 错误
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
repo := &apiKeyRepoStub{
apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"},
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
require.ErrorIs(t, err, ErrInsufficientPerms)
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
require.Empty(t, cache.invalidated) // 验证缓存未被清除
require.Empty(t, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
// 预期行为:
// - GetKeyAndOwnerID 返回所有者 ID 为 7
// - 调用者 userID 为 7匹配
// - Delete 成功执行
// - 缓存被正确清除(使用 ownerID
// - 返回 nil 错误
func TestApiKeyService_Delete_Success(t *testing.T) {
repo := &apiKeyRepoStub{
apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"},
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
svc.lastUsedTouchL1.Store(int64(42), time.Now())
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
require.NoError(t, err)
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
_, exists := svc.lastUsedTouchL1.Load(int64(42))
require.False(t, exists, "delete should clear touch debounce cache")
}
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
// 预期行为:
// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
// - Delete 方法不被调用
// - 缓存不被清除
func TestApiKeyService_Delete_NotFound(t *testing.T) {
repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 99, 1)
require.ErrorIs(t, err, ErrAPIKeyNotFound)
require.Empty(t, repo.deletedIDs)
require.Empty(t, cache.invalidated)
require.Empty(t, cache.deleteAuthKeys)
}
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
// 预期行为:
// - GetKeyAndOwnerID 返回正确的所有者 ID
// - 所有权验证通过
// - 缓存被清除(在删除之前)
// - Delete 被调用但返回错误
// - 返回包含 "delete api key" 的错误信息
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
repo := &apiKeyRepoStub{
apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"},
deleteErr: errors.New("delete failed"),
}
cache := &apiKeyCacheStub{}
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
require.Error(t, err)
require.ErrorContains(t, err, "delete api key")
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
}

View File

@@ -0,0 +1,170 @@
//go:build unit
package service
import (
"context"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
"github.com/stretchr/testify/require"
)
type quotaStateRepoStub struct {
quotaBaseAPIKeyRepoStub
stateCalls int
state *APIKeyQuotaUsageState
stateErr error
}
func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) {
s.stateCalls++
if s.stateErr != nil {
return nil, s.stateErr
}
if s.state == nil {
return nil, nil
}
out := *s.state
return &out, nil
}
type quotaStateCacheStub struct {
deleteAuthKeys []string
}
func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) {
return 0, nil
}
func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error {
return nil
}
func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error {
return nil
}
func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error {
return nil
}
func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error {
return nil
}
func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) {
return nil, nil
}
func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error {
return nil
}
func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error {
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
return nil
}
func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error {
return nil
}
func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error {
return nil
}
type quotaBaseAPIKeyRepoStub struct {
getByIDCalls int
}
func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error {
panic("unexpected Create call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) {
s.getByIDCalls++
return nil, nil
}
func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
panic("unexpected GetKeyAndOwnerID call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) {
panic("unexpected GetByKey call")
}
func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
panic("unexpected GetByKeyForAuth call")
}
func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error {
panic("unexpected Update call")
}
func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error {
panic("unexpected Delete call")
}
func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
panic("unexpected VerifyOwnership call")
}
func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) {
panic("unexpected CountByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) {
panic("unexpected ExistsByKey call")
}
func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
panic("unexpected ListByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
panic("unexpected SearchAPIKeys call")
}
func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
panic("unexpected ClearGroupIDByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) {
panic("unexpected CountByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) {
panic("unexpected ListKeysByUserID call")
}
func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) {
panic("unexpected ListKeysByGroupID call")
}
func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
panic("unexpected IncrementQuotaUsed call")
}
func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error {
panic("unexpected UpdateLastUsed call")
}
func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error {
panic("unexpected IncrementRateLimitUsage call")
}
func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error {
panic("unexpected ResetRateLimitWindows call")
}
func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
panic("unexpected GetRateLimitData call")
}
func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) {
repo := &quotaStateRepoStub{
state: &APIKeyQuotaUsageState{
QuotaUsed: 12,
Quota: 10,
Key: "sk-test-quota",
Status: StatusAPIKeyQuotaExhausted,
},
}
cache := &quotaStateCacheStub{}
svc := &APIKeyService{
apiKeyRepo: repo,
cache: cache,
}
err := svc.UpdateQuotaUsed(context.Background(), 101, 2)
require.NoError(t, err)
require.Equal(t, 1, repo.stateCalls)
require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id")
require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, cache.deleteAuthKeys)
}

View File

@@ -0,0 +1,160 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestAPIKeyService_TouchLastUsed_InvalidKeyID(t *testing.T) {
repo := &apiKeyRepoStub{
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
return errors.New("should not be called")
},
}
svc := &APIKeyService{apiKeyRepo: repo}
require.NoError(t, svc.TouchLastUsed(context.Background(), 0))
require.NoError(t, svc.TouchLastUsed(context.Background(), -1))
require.Empty(t, repo.touchedIDs)
}
func TestAPIKeyService_TouchLastUsed_FirstTouchSucceeds(t *testing.T) {
repo := &apiKeyRepoStub{}
svc := &APIKeyService{apiKeyRepo: repo}
err := svc.TouchLastUsed(context.Background(), 123)
require.NoError(t, err)
require.Equal(t, []int64{123}, repo.touchedIDs)
require.Len(t, repo.touchedUsedAts, 1)
require.False(t, repo.touchedUsedAts[0].IsZero())
cached, ok := svc.lastUsedTouchL1.Load(int64(123))
require.True(t, ok, "successful touch should update debounce cache")
_, isTime := cached.(time.Time)
require.True(t, isTime)
}
func TestAPIKeyService_TouchLastUsed_DebouncedWithinWindow(t *testing.T) {
repo := &apiKeyRepoStub{}
svc := &APIKeyService{apiKeyRepo: repo}
require.NoError(t, svc.TouchLastUsed(context.Background(), 123))
require.NoError(t, svc.TouchLastUsed(context.Background(), 123))
require.Equal(t, []int64{123}, repo.touchedIDs, "second touch within debounce window should not hit repository")
}
func TestAPIKeyService_TouchLastUsed_ExpiredDebounceTouchesAgain(t *testing.T) {
repo := &apiKeyRepoStub{}
svc := &APIKeyService{apiKeyRepo: repo}
require.NoError(t, svc.TouchLastUsed(context.Background(), 123))
// 强制将 debounce 时间回拨到窗口之外,触发第二次写库。
svc.lastUsedTouchL1.Store(int64(123), time.Now().Add(-apiKeyLastUsedMinTouch-time.Second))
require.NoError(t, svc.TouchLastUsed(context.Background(), 123))
require.Len(t, repo.touchedIDs, 2)
require.Equal(t, int64(123), repo.touchedIDs[0])
require.Equal(t, int64(123), repo.touchedIDs[1])
}
func TestAPIKeyService_TouchLastUsed_RepoError(t *testing.T) {
repo := &apiKeyRepoStub{
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
return errors.New("db write failed")
},
}
svc := &APIKeyService{apiKeyRepo: repo}
err := svc.TouchLastUsed(context.Background(), 123)
require.Error(t, err)
require.ErrorContains(t, err, "touch api key last used")
require.Equal(t, []int64{123}, repo.touchedIDs)
cached, ok := svc.lastUsedTouchL1.Load(int64(123))
require.True(t, ok, "failed touch should still update retry debounce cache")
_, isTime := cached.(time.Time)
require.True(t, isTime)
}
func TestAPIKeyService_TouchLastUsed_RepoErrorDebounced(t *testing.T) {
repo := &apiKeyRepoStub{
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
return errors.New("db write failed")
},
}
svc := &APIKeyService{apiKeyRepo: repo}
firstErr := svc.TouchLastUsed(context.Background(), 456)
require.Error(t, firstErr)
require.ErrorContains(t, firstErr, "touch api key last used")
secondErr := svc.TouchLastUsed(context.Background(), 456)
require.NoError(t, secondErr, "failed touch should be debounced and skip immediate retry")
require.Equal(t, []int64{456}, repo.touchedIDs, "debounced retry should not hit repository again")
}
type touchSingleflightRepo struct {
*apiKeyRepoStub
mu sync.Mutex
calls int
blockCh chan struct{}
}
func (r *touchSingleflightRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
r.mu.Lock()
r.calls++
r.mu.Unlock()
<-r.blockCh
return nil
}
func TestAPIKeyService_TouchLastUsed_ConcurrentFirstTouchDeduplicated(t *testing.T) {
repo := &touchSingleflightRepo{
apiKeyRepoStub: &apiKeyRepoStub{},
blockCh: make(chan struct{}),
}
svc := &APIKeyService{apiKeyRepo: repo}
const workers = 20
startCh := make(chan struct{})
errCh := make(chan error, workers)
var wg sync.WaitGroup
for i := 0; i < workers; i++ {
wg.Add(1)
go func() {
defer wg.Done()
<-startCh
errCh <- svc.TouchLastUsed(context.Background(), 321)
}()
}
close(startCh)
require.Eventually(t, func() bool {
repo.mu.Lock()
defer repo.mu.Unlock()
return repo.calls >= 1
}, time.Second, 10*time.Millisecond)
close(repo.blockCh)
wg.Wait()
close(errCh)
for err := range errCh {
require.NoError(t, err)
}
repo.mu.Lock()
defer repo.mu.Unlock()
require.Equal(t, 1, repo.calls, "并发首次 touch 只应写库一次")
}

View File

@@ -0,0 +1,33 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/stretchr/testify/require"
)
func TestUsageService_InvalidateUsageCaches(t *testing.T) {
invalidator := &authCacheInvalidatorStub{}
svc := &UsageService{authCacheInvalidator: invalidator}
svc.invalidateUsageCaches(context.Background(), 7, false)
require.Empty(t, invalidator.userIDs)
svc.invalidateUsageCaches(context.Background(), 7, true)
require.Equal(t, []int64{7}, invalidator.userIDs)
}
func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) {
invalidator := &authCacheInvalidatorStub{}
svc := &RedeemService{authCacheInvalidator: invalidator}
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance})
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency})
groupID := int64(3)
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeSubscription, GroupID: &groupID})
require.Equal(t, []int64{11, 11, 11}, invalidator.userIDs)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,146 @@
//go:build unit
package service
import (
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/require"
)
func newAuthServiceForPendingOAuthTest() *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret-pending-oauth",
ExpireHour: 1,
},
}
return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
}
// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
require.NotEmpty(t, token)
email, username, err := svc.VerifyPendingOAuthToken(token)
require.NoError(t, err)
require.Equal(t, "user@example.com", email)
require.Equal(t, "alice", username)
}
// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
// 签发一个普通 access tokenJWTClaims无 Purpose 字段)
accessToken, err := svc.GenerateToken(&User{
ID: 1,
Email: "user@example.com",
Role: RoleUser,
})
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(accessToken)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "some_other_purpose",
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT模拟旧 token应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
now := time.Now()
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: "", // 旧 token 无此字段,反序列化后为零值
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
IssuedAt: jwt.NewNumericDate(now),
NotBefore: jwt.NewNumericDate(now),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
past := time.Now().Add(-1 * time.Hour)
claims := &pendingOAuthClaims{
Email: "user@example.com",
Username: "alice",
Purpose: pendingOAuthPurpose,
RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: jwt.NewNumericDate(past),
IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
require.NoError(t, err)
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
other := NewAuthService(nil, nil, nil, nil, &config.Config{
JWT: config.JWTConfig{Secret: "other-secret"},
}, nil, nil, nil, nil, nil, nil)
token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
require.NoError(t, err)
svc := newAuthServiceForPendingOAuthTest()
_, _, err = svc.VerifyPendingOAuthToken(token)
require.ErrorIs(t, err, ErrInvalidToken)
}
// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
svc := newAuthServiceForPendingOAuthTest()
giant := make([]byte, maxTokenLength+1)
for i := range giant {
giant[i] = 'a'
}
_, _, err := svc.VerifyPendingOAuthToken(string(giant))
require.ErrorIs(t, err, ErrInvalidToken)
}

View File

@@ -0,0 +1,466 @@
//go:build unit
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/stretchr/testify/require"
)
type settingRepoStub struct {
values map[string]string
err error
}
func (s *settingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
panic("unexpected Get call")
}
func (s *settingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
if s.err != nil {
return "", s.err
}
if v, ok := s.values[key]; ok {
return v, nil
}
return "", ErrSettingNotFound
}
func (s *settingRepoStub) Set(ctx context.Context, key, value string) error {
panic("unexpected Set call")
}
func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
panic("unexpected GetMultiple call")
}
func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
panic("unexpected SetMultiple call")
}
func (s *settingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
panic("unexpected GetAll call")
}
func (s *settingRepoStub) Delete(ctx context.Context, key string) error {
panic("unexpected Delete call")
}
type emailCacheStub struct {
data *VerificationCodeData
err error
}
type defaultSubscriptionAssignerStub struct {
calls []AssignSubscriptionInput
err error
}
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
if input != nil {
s.calls = append(s.calls, *input)
}
if s.err != nil {
return nil, false, s.err
}
return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
}
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
if s.err != nil {
return nil, s.err
}
return s.data, nil
}
func (s *emailCacheStub) SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error {
return nil
}
func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email string) error {
return nil
}
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
return nil, nil
}
func (s *emailCacheStub) SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error {
return nil
}
func (s *emailCacheStub) DeletePasswordResetToken(ctx context.Context, email string) error {
return nil
}
func (s *emailCacheStub) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
return false
}
func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
return nil
}
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
cfg := &config.Config{
JWT: config.JWTConfig{
Secret: "test-secret",
ExpireHour: 1,
},
Default: config.DefaultConfig{
UserBalance: 3.5,
UserConcurrency: 2,
},
}
var settingService *SettingService
if settings != nil {
settingService = NewSettingService(&settingRepoStub{values: settings}, cfg)
}
var emailService *EmailService
if emailCache != nil {
emailService = NewEmailService(&settingRepoStub{values: settings}, emailCache)
}
return NewAuthService(
nil, // entClient
repo,
nil, // redeemRepo
nil, // refreshTokenCache
cfg,
settingService,
emailService,
nil,
nil,
nil, // promoService
nil, // defaultSubAssigner
)
}
func TestAuthService_Register_Disabled(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "false",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrRegDisabled)
}
func TestAuthService_Register_DisabledByDefault(t *testing.T) {
// 当 settings 为 nil设置项不存在注册应该默认关闭
repo := &userRepoStub{}
service := newAuthService(repo, nil, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrRegDisabled)
}
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
repo := &userRepoStub{}
// 邮件验证开启但 emailCache 为 nilemailService 未配置)
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
}, nil)
// 应返回服务不可用错误,而不是允许绕过验证
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
repo := &userRepoStub{}
cache := &emailCacheStub{} // 配置 emailService
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
require.ErrorIs(t, err, ErrEmailVerifyRequired)
}
func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
repo := &userRepoStub{}
cache := &emailCacheStub{
data: &VerificationCodeData{Code: "expected", Attempts: 0},
}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyEmailVerifyEnabled: "true",
}, cache)
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
require.ErrorIs(t, err, ErrInvalidVerifyCode)
require.ErrorContains(t, err, "verify code")
}
func TestAuthService_Register_EmailExists(t *testing.T) {
repo := &userRepoStub{exists: true}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists)
}
func TestAuthService_Register_CheckEmailError(t *testing.T) {
repo := &userRepoStub{existsErr: errors.New("db down")}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
func TestAuthService_Register_ReservedEmail(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password")
require.ErrorIs(t, err, ErrEmailReserved)
}
func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
}, nil)
_, _, err := service.Register(context.Background(), "user@other.com", "password")
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
appErr := infraerrors.FromError(err)
require.Contains(t, appErr.Message, "@example.com")
require.Contains(t, appErr.Message, "@company.com")
require.Equal(t, "EMAIL_SUFFIX_NOT_ALLOWED", appErr.Reason)
require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
require.Equal(t, "@example.com,@company.com", appErr.Metadata["allowed_suffixes"])
}
func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) {
repo := &userRepoStub{nextID: 8}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`,
}, nil)
_, user, err := service.Register(context.Background(), "user@example.com", "password")
require.NoError(t, err)
require.NotNil(t, user)
require.Equal(t, int64(8), user.ID)
}
func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
}, nil)
err := service.SendVerifyCode(context.Background(), "user@other.com")
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
appErr := infraerrors.FromError(err)
require.Contains(t, appErr.Message, "@example.com")
require.Contains(t, appErr.Message, "@company.com")
require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
}
func TestAuthService_Register_CreateError(t *testing.T) {
repo := &userRepoStub{createErr: errors.New("create failed")}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrServiceUnavailable)
}
func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
// 模拟竞态条件ExistsByEmail 返回 false但 Create 时因唯一约束失败
repo := &userRepoStub{createErr: ErrEmailExists}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
_, _, err := service.Register(context.Background(), "user@test.com", "password")
require.ErrorIs(t, err, ErrEmailExists)
}
func TestAuthService_Register_Success(t *testing.T) {
repo := &userRepoStub{nextID: 5}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
}, nil)
token, user, err := service.Register(context.Background(), "user@test.com", "password")
require.NoError(t, err)
require.NotEmpty(t, token)
require.NotNil(t, user)
require.Equal(t, int64(5), user.ID)
require.Equal(t, "user@test.com", user.Email)
require.Equal(t, RoleUser, user.Role)
require.Equal(t, StatusActive, user.Status)
require.Equal(t, 3.5, user.Balance)
require.Equal(t, 2, user.Concurrency)
require.Len(t, repo.created, 1)
require.True(t, user.CheckPassword("password"))
}
func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) {
repo := &userRepoStub{}
service := newAuthService(repo, nil, nil)
// 创建用户并生成 token
user := &User{
ID: 1,
Email: "test@test.com",
Role: RoleUser,
Status: StatusActive,
TokenVersion: 1,
}
token, err := service.GenerateToken(user)
require.NoError(t, err)
// 验证有效 token
claims, err := service.ValidateToken(token)
require.NoError(t, err)
require.NotNil(t, claims)
require.Equal(t, int64(1), claims.UserID)
// 模拟过期 token通过创建一个过期很久的 token
service.cfg.JWT.ExpireHour = -1 // 设置为负数使 token 立即过期
expiredToken, err := service.GenerateToken(user)
require.NoError(t, err)
service.cfg.JWT.ExpireHour = 1 // 恢复
// 验证过期 token 应返回 claims 和 ErrTokenExpired
claims, err = service.ValidateToken(expiredToken)
require.ErrorIs(t, err, ErrTokenExpired)
require.NotNil(t, claims, "claims should not be nil when token is expired")
require.Equal(t, int64(1), claims.UserID)
require.Equal(t, "test@test.com", claims.Email)
}
func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
user := &User{
ID: 1,
Email: "test@test.com",
Role: RoleUser,
Status: StatusActive,
TokenVersion: 1,
}
repo := &userRepoStub{user: user}
service := newAuthService(repo, nil, nil)
// 创建过期 token
service.cfg.JWT.ExpireHour = -1
expiredToken, err := service.GenerateToken(user)
require.NoError(t, err)
service.cfg.JWT.ExpireHour = 1
// RefreshToken 使用过期 token 不应 panic
require.NotPanics(t, func() {
newToken, err := service.RefreshToken(context.Background(), expiredToken)
require.NoError(t, err)
require.NotEmpty(t, newToken)
})
}
func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T) {
service := newAuthService(&userRepoStub{}, nil, nil)
service.cfg.JWT.ExpireHour = 24
service.cfg.JWT.AccessTokenExpireMinutes = 0
require.Equal(t, 24*3600, service.GetAccessTokenExpiresIn())
}
func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) {
service := newAuthService(&userRepoStub{}, nil, nil)
service.cfg.JWT.ExpireHour = 24
service.cfg.JWT.AccessTokenExpireMinutes = 90
require.Equal(t, 90*60, service.GetAccessTokenExpiresIn())
}
func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) {
service := newAuthService(&userRepoStub{}, nil, nil)
service.cfg.JWT.ExpireHour = 24
service.cfg.JWT.AccessTokenExpireMinutes = 0
user := &User{
ID: 1,
Email: "test@test.com",
Role: RoleUser,
Status: StatusActive,
TokenVersion: 1,
}
token, err := service.GenerateToken(user)
require.NoError(t, err)
claims, err := service.ValidateToken(token)
require.NoError(t, err)
require.NotNil(t, claims)
require.NotNil(t, claims.IssuedAt)
require.NotNil(t, claims.ExpiresAt)
require.WithinDuration(t, claims.IssuedAt.Time.Add(24*time.Hour), claims.ExpiresAt.Time, 2*time.Second)
}
func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) {
service := newAuthService(&userRepoStub{}, nil, nil)
service.cfg.JWT.ExpireHour = 24
service.cfg.JWT.AccessTokenExpireMinutes = 90
user := &User{
ID: 2,
Email: "test2@test.com",
Role: RoleUser,
Status: StatusActive,
TokenVersion: 1,
}
token, err := service.GenerateToken(user)
require.NoError(t, err)
claims, err := service.ValidateToken(token)
require.NoError(t, err)
require.NotNil(t, claims)
require.NotNil(t, claims.IssuedAt)
require.NotNil(t, claims.ExpiresAt)
require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second)
}
func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
repo := &userRepoStub{nextID: 42}
assigner := &defaultSubscriptionAssignerStub{}
service := newAuthService(repo, map[string]string{
SettingKeyRegistrationEnabled: "true",
SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
}, nil)
service.defaultSubAssigner = assigner
_, user, err := service.Register(context.Background(), "default-sub@test.com", "password")
require.NoError(t, err)
require.NotNil(t, user)
require.Len(t, assigner.calls, 2)
require.Equal(t, int64(42), assigner.calls[0].UserID)
require.Equal(t, int64(11), assigner.calls[0].GroupID)
require.Equal(t, 30, assigner.calls[0].ValidityDays)
require.Equal(t, int64(12), assigner.calls[1].GroupID)
require.Equal(t, 7, assigner.calls[1].ValidityDays)
}

View File

@@ -0,0 +1,98 @@
//go:build unit
package service
import (
"context"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type turnstileVerifierSpy struct {
called int
lastToken string
result *TurnstileVerifyResponse
err error
}
func (s *turnstileVerifierSpy) VerifyToken(_ context.Context, _ string, token, _ string) (*TurnstileVerifyResponse, error) {
s.called++
s.lastToken = token
if s.err != nil {
return nil, s.err
}
if s.result != nil {
return s.result, nil
}
return &TurnstileVerifyResponse{Success: true}, nil
}
func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier TurnstileVerifier) *AuthService {
cfg := &config.Config{
Server: config.ServerConfig{
Mode: "release",
},
Turnstile: config.TurnstileConfig{
Required: true,
},
}
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
turnstileService := NewTurnstileService(settingService, verifier)
return NewAuthService(
nil, // entClient
&userRepoStub{},
nil, // redeemRepo
nil, // refreshTokenCache
cfg,
settingService,
nil, // emailService
turnstileService,
nil, // emailQueueService
nil, // promoService
nil, // defaultSubAssigner
)
}
func TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided(t *testing.T) {
verifier := &turnstileVerifierSpy{}
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
SettingKeyEmailVerifyEnabled: "true",
SettingKeyTurnstileEnabled: "true",
SettingKeyTurnstileSecretKey: "secret",
SettingKeyRegistrationEnabled: "true",
}, verifier)
err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "123456")
require.NoError(t, err)
require.Equal(t, 0, verifier.called)
}
func TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing(t *testing.T) {
verifier := &turnstileVerifierSpy{}
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
SettingKeyEmailVerifyEnabled: "true",
SettingKeyTurnstileEnabled: "true",
SettingKeyTurnstileSecretKey: "secret",
}, verifier)
err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "")
require.ErrorIs(t, err, ErrTurnstileVerificationFailed)
}
func TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled(t *testing.T) {
verifier := &turnstileVerifierSpy{}
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
SettingKeyEmailVerifyEnabled: "false",
SettingKeyTurnstileEnabled: "true",
SettingKeyTurnstileSecretKey: "secret",
}, verifier)
err := service.VerifyTurnstileForRegister(context.Background(), "turnstile-token", "127.0.0.1", "123456")
require.NoError(t, err)
require.Equal(t, 1, verifier.called)
require.Equal(t, "turnstile-token", verifier.lastToken)
}

View File

@@ -0,0 +1,770 @@
package service
import (
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
"sort"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/robfig/cron/v3"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
const (
settingKeyBackupS3Config = "backup_s3_config"
settingKeyBackupSchedule = "backup_schedule"
settingKeyBackupRecords = "backup_records"
maxBackupRecords = 100
)
var (
ErrBackupS3NotConfigured = infraerrors.BadRequest("BACKUP_S3_NOT_CONFIGURED", "backup S3 storage is not configured")
ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found")
ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress")
ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress")
ErrBackupRecordsCorrupt = infraerrors.InternalServer("BACKUP_RECORDS_CORRUPT", "backup records data is corrupted")
ErrBackupS3ConfigCorrupt = infraerrors.InternalServer("BACKUP_S3_CONFIG_CORRUPT", "backup S3 config data is corrupted")
)
// ─── 接口定义 ───
// DBDumper abstracts database dump/restore operations
type DBDumper interface {
Dump(ctx context.Context) (io.ReadCloser, error)
Restore(ctx context.Context, data io.Reader) error
}
// BackupObjectStore abstracts object storage for backup files
type BackupObjectStore interface {
Upload(ctx context.Context, key string, body io.Reader, contentType string) (sizeBytes int64, err error)
Download(ctx context.Context, key string) (io.ReadCloser, error)
Delete(ctx context.Context, key string) error
PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error)
HeadBucket(ctx context.Context) error
}
// BackupObjectStoreFactory creates an object store from S3 config
type BackupObjectStoreFactory func(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error)
// ─── 数据模型 ───
// BackupS3Config S3 兼容存储配置(支持 Cloudflare R2
type BackupS3Config struct {
Endpoint string `json:"endpoint"` // e.g. https://<account_id>.r2.cloudflarestorage.com
Region string `json:"region"` // R2 用 "auto"
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key,omitempty"` //nolint:revive // field name follows AWS convention
Prefix string `json:"prefix"` // S3 key 前缀,如 "backups/"
ForcePathStyle bool `json:"force_path_style"`
}
// IsConfigured 检查必要字段是否已配置
func (c *BackupS3Config) IsConfigured() bool {
return c.Bucket != "" && c.AccessKeyID != "" && c.SecretAccessKey != ""
}
// BackupScheduleConfig 定时备份配置
type BackupScheduleConfig struct {
Enabled bool `json:"enabled"`
CronExpr string `json:"cron_expr"` // cron 表达式,如 "0 2 * * *" 每天凌晨2点
RetainDays int `json:"retain_days"` // 备份文件过期天数默认140=不自动清理
RetainCount int `json:"retain_count"` // 最多保留份数0=不限制
}
// BackupRecord 备份记录
type BackupRecord struct {
ID string `json:"id"`
Status string `json:"status"` // pending, running, completed, failed
BackupType string `json:"backup_type"` // postgres
FileName string `json:"file_name"`
S3Key string `json:"s3_key"`
SizeBytes int64 `json:"size_bytes"`
TriggeredBy string `json:"triggered_by"` // manual, scheduled
ErrorMsg string `json:"error_message,omitempty"`
StartedAt string `json:"started_at"`
FinishedAt string `json:"finished_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
}
// BackupService 数据库备份恢复服务
type BackupService struct {
settingRepo SettingRepository
dbCfg *config.DatabaseConfig
encryptor SecretEncryptor
storeFactory BackupObjectStoreFactory
dumper DBDumper
mu sync.Mutex
store BackupObjectStore
s3Cfg *BackupS3Config
backingUp bool
restoring bool
recordsMu sync.Mutex // 保护 records 的 load/save 操作
cronMu sync.Mutex
cronSched *cron.Cron
cronEntryID cron.EntryID
}
func NewBackupService(
settingRepo SettingRepository,
cfg *config.Config,
encryptor SecretEncryptor,
storeFactory BackupObjectStoreFactory,
dumper DBDumper,
) *BackupService {
return &BackupService{
settingRepo: settingRepo,
dbCfg: &cfg.Database,
encryptor: encryptor,
storeFactory: storeFactory,
dumper: dumper,
}
}
// Start 启动定时备份调度器
func (s *BackupService) Start() {
s.cronSched = cron.New()
s.cronSched.Start()
// 加载已有的定时配置
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
schedule, err := s.GetSchedule(ctx)
if err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 加载定时备份配置失败: %v", err)
return
}
if schedule.Enabled && schedule.CronExpr != "" {
if err := s.applyCronSchedule(schedule); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 应用定时备份配置失败: %v", err)
}
}
}
// Stop 停止定时备份
func (s *BackupService) Stop() {
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched != nil {
s.cronSched.Stop()
}
}
// ─── S3 配置管理 ───
func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) {
cfg, err := s.loadS3Config(ctx)
if err != nil {
return nil, err
}
if cfg == nil {
return &BackupS3Config{}, nil
}
// 脱敏返回
cfg.SecretAccessKey = ""
return cfg, nil
}
func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) {
// 如果没提供 secret保留原有值
if cfg.SecretAccessKey == "" {
old, _ := s.loadS3Config(ctx)
if old != nil {
cfg.SecretAccessKey = old.SecretAccessKey
}
} else {
// 加密 SecretAccessKey
encrypted, err := s.encryptor.Encrypt(cfg.SecretAccessKey)
if err != nil {
return nil, fmt.Errorf("encrypt secret: %w", err)
}
cfg.SecretAccessKey = encrypted
}
data, err := json.Marshal(cfg)
if err != nil {
return nil, fmt.Errorf("marshal s3 config: %w", err)
}
if err := s.settingRepo.Set(ctx, settingKeyBackupS3Config, string(data)); err != nil {
return nil, fmt.Errorf("save s3 config: %w", err)
}
// 清除缓存的 S3 客户端
s.mu.Lock()
s.store = nil
s.s3Cfg = nil
s.mu.Unlock()
cfg.SecretAccessKey = ""
return &cfg, nil
}
func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config) error {
// 如果没提供 secret用已保存的
if cfg.SecretAccessKey == "" {
old, _ := s.loadS3Config(ctx)
if old != nil {
cfg.SecretAccessKey = old.SecretAccessKey
}
}
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required")
}
store, err := s.storeFactory(ctx, &cfg)
if err != nil {
return err
}
return store.HeadBucket(ctx)
}
// ─── 定时备份管理 ───
func (s *BackupService) GetSchedule(ctx context.Context) (*BackupScheduleConfig, error) {
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupSchedule)
if err != nil || raw == "" {
return &BackupScheduleConfig{}, nil
}
var cfg BackupScheduleConfig
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
return &BackupScheduleConfig{}, nil
}
return &cfg, nil
}
func (s *BackupService) UpdateSchedule(ctx context.Context, cfg BackupScheduleConfig) (*BackupScheduleConfig, error) {
if cfg.Enabled && cfg.CronExpr == "" {
return nil, infraerrors.BadRequest("INVALID_CRON", "cron expression is required when schedule is enabled")
}
// 验证 cron 表达式
if cfg.CronExpr != "" {
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
if _, err := parser.Parse(cfg.CronExpr); err != nil {
return nil, infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("invalid cron expression: %v", err))
}
}
data, err := json.Marshal(cfg)
if err != nil {
return nil, fmt.Errorf("marshal schedule config: %w", err)
}
if err := s.settingRepo.Set(ctx, settingKeyBackupSchedule, string(data)); err != nil {
return nil, fmt.Errorf("save schedule config: %w", err)
}
// 应用或停止定时任务
if cfg.Enabled {
if err := s.applyCronSchedule(&cfg); err != nil {
return nil, err
}
} else {
s.removeCronSchedule()
}
return &cfg, nil
}
func (s *BackupService) applyCronSchedule(cfg *BackupScheduleConfig) error {
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched == nil {
return fmt.Errorf("cron scheduler not initialized")
}
// 移除旧任务
if s.cronEntryID != 0 {
s.cronSched.Remove(s.cronEntryID)
s.cronEntryID = 0
}
entryID, err := s.cronSched.AddFunc(cfg.CronExpr, func() {
s.runScheduledBackup()
})
if err != nil {
return infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("failed to schedule: %v", err))
}
s.cronEntryID = entryID
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已启用: %s", cfg.CronExpr)
return nil
}
func (s *BackupService) removeCronSchedule() {
s.cronMu.Lock()
defer s.cronMu.Unlock()
if s.cronSched != nil && s.cronEntryID != 0 {
s.cronSched.Remove(s.cronEntryID)
s.cronEntryID = 0
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已停用")
}
}
func (s *BackupService) runScheduledBackup() {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
defer cancel()
// 读取定时备份配置中的过期天数
schedule, _ := s.GetSchedule(ctx)
expireDays := 14 // 默认14天过期
if schedule != nil && schedule.RetainDays > 0 {
expireDays = schedule.RetainDays
}
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
if err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
return
}
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
// 清理过期备份(复用已加载的 schedule
if schedule == nil {
return
}
if err := s.cleanupOldBackups(ctx, schedule); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 清理过期备份失败: %v", err)
}
}
// ─── 备份/恢复核心 ───
// CreateBackup 创建全量数据库备份并上传到 S3流式处理
// expireDays: 备份过期天数0=永不过期默认14天
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
s.mu.Lock()
if s.backingUp {
s.mu.Unlock()
return nil, ErrBackupInProgress
}
s.backingUp = true
s.mu.Unlock()
defer func() {
s.mu.Lock()
s.backingUp = false
s.mu.Unlock()
}()
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return nil, err
}
if s3Cfg == nil || !s3Cfg.IsConfigured() {
return nil, ErrBackupS3NotConfigured
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return nil, fmt.Errorf("init object store: %w", err)
}
now := time.Now()
backupID := uuid.New().String()[:8]
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
s3Key := s.buildS3Key(s3Cfg, fileName)
var expiresAt string
if expireDays > 0 {
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
}
record := &BackupRecord{
ID: backupID,
Status: "running",
BackupType: "postgres",
FileName: fileName,
S3Key: s3Key,
TriggeredBy: triggeredBy,
StartedAt: now.Format(time.RFC3339),
ExpiresAt: expiresAt,
}
// 流式执行: pg_dump -> gzip -> S3 upload
dumpReader, err := s.dumper.Dump(ctx)
if err != nil {
record.Status = "failed"
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(ctx, record)
return record, fmt.Errorf("pg_dump: %w", err)
}
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
pr, pw := io.Pipe()
var gzipErr error
go func() {
gzWriter := gzip.NewWriter(pw)
_, gzipErr = io.Copy(gzWriter, dumpReader)
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
}
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
gzipErr = closeErr
}
if gzipErr != nil {
_ = pw.CloseWithError(gzipErr)
} else {
_ = pw.Close()
}
}()
contentType := "application/gzip"
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
if err != nil {
record.Status = "failed"
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
if gzipErr != nil {
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
}
record.ErrorMsg = errMsg
record.FinishedAt = time.Now().Format(time.RFC3339)
_ = s.saveRecord(ctx, record)
return record, fmt.Errorf("backup upload: %w", err)
}
record.SizeBytes = sizeBytes
record.Status = "completed"
record.FinishedAt = time.Now().Format(time.RFC3339)
if err := s.saveRecord(ctx, record); err != nil {
logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
}
return record, nil
}
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
s.mu.Lock()
if s.restoring {
s.mu.Unlock()
return ErrRestoreInProgress
}
s.restoring = true
s.mu.Unlock()
defer func() {
s.mu.Lock()
s.restoring = false
s.mu.Unlock()
}()
record, err := s.GetBackupRecord(ctx, backupID)
if err != nil {
return err
}
if record.Status != "completed" {
return infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
}
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return err
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return fmt.Errorf("init object store: %w", err)
}
// 从 S3 流式下载
body, err := objectStore.Download(ctx, record.S3Key)
if err != nil {
return fmt.Errorf("S3 download failed: %w", err)
}
defer func() { _ = body.Close() }()
// 流式解压 gzip -> psql不将全部数据加载到内存
gzReader, err := gzip.NewReader(body)
if err != nil {
return fmt.Errorf("gzip reader: %w", err)
}
defer func() { _ = gzReader.Close() }()
// 流式恢复
if err := s.dumper.Restore(ctx, gzReader); err != nil {
return fmt.Errorf("pg restore: %w", err)
}
return nil
}
// ─── 备份记录管理 ───
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
records, err := s.loadRecords(ctx)
if err != nil {
return nil, err
}
// 倒序返回(最新在前)
sort.Slice(records, func(i, j int) bool {
return records[i].StartedAt > records[j].StartedAt
})
return records, nil
}
func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (*BackupRecord, error) {
records, err := s.loadRecords(ctx)
if err != nil {
return nil, err
}
for i := range records {
if records[i].ID == backupID {
return &records[i], nil
}
}
return nil, ErrBackupNotFound
}
func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error {
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
records, err := s.loadRecordsLocked(ctx)
if err != nil {
return err
}
var found *BackupRecord
var remaining []BackupRecord
for i := range records {
if records[i].ID == backupID {
found = &records[i]
} else {
remaining = append(remaining, records[i])
}
}
if found == nil {
return ErrBackupNotFound
}
// 从 S3 删除
if found.S3Key != "" && found.Status == "completed" {
s3Cfg, err := s.loadS3Config(ctx)
if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() {
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err == nil {
_ = objectStore.Delete(ctx, found.S3Key)
}
}
}
return s.saveRecordsLocked(ctx, remaining)
}
// GetBackupDownloadURL 获取备份文件预签名下载 URL
func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID string) (string, error) {
record, err := s.GetBackupRecord(ctx, backupID)
if err != nil {
return "", err
}
if record.Status != "completed" {
return "", infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "backup is not completed")
}
s3Cfg, err := s.loadS3Config(ctx)
if err != nil {
return "", err
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return "", err
}
url, err := objectStore.PresignURL(ctx, record.S3Key, 1*time.Hour)
if err != nil {
return "", fmt.Errorf("presign url: %w", err)
}
return url, nil
}
// ─── 内部方法 ───
func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) {
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config)
if err != nil || raw == "" {
return nil, nil //nolint:nilnil // no config is a valid state
}
var cfg BackupS3Config
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
return nil, ErrBackupS3ConfigCorrupt
}
// 解密 SecretAccessKey
if cfg.SecretAccessKey != "" {
decrypted, err := s.encryptor.Decrypt(cfg.SecretAccessKey)
if err != nil {
// 兼容未加密的旧数据:如果解密失败,保持原值
logger.LegacyPrintf("service.backup", "[Backup] S3 SecretAccessKey 解密失败(可能是旧的未加密数据): %v", err)
} else {
cfg.SecretAccessKey = decrypted
}
}
return &cfg, nil
}
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.store != nil && s.s3Cfg != nil {
return s.store, nil
}
if cfg == nil {
return nil, ErrBackupS3NotConfigured
}
store, err := s.storeFactory(ctx, cfg)
if err != nil {
return nil, err
}
s.store = store
s.s3Cfg = cfg
return store, nil
}
func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string {
prefix := strings.TrimRight(cfg.Prefix, "/")
if prefix == "" {
prefix = "backups"
}
return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName)
}
// loadRecords 加载备份记录,区分"无数据"和"数据损坏"
func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) {
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
return s.loadRecordsLocked(ctx)
}
// loadRecordsLocked 在已持有 recordsMu 锁的情况下加载记录
func (s *BackupService) loadRecordsLocked(ctx context.Context) ([]BackupRecord, error) {
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords)
if err != nil || raw == "" {
return nil, nil //nolint:nilnil // no records is a valid state
}
var records []BackupRecord
if err := json.Unmarshal([]byte(raw), &records); err != nil {
return nil, ErrBackupRecordsCorrupt
}
return records, nil
}
// saveRecordsLocked 在已持有 recordsMu 锁的情况下保存记录
func (s *BackupService) saveRecordsLocked(ctx context.Context, records []BackupRecord) error {
data, err := json.Marshal(records)
if err != nil {
return err
}
return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data))
}
// saveRecord 保存单条记录(带互斥锁保护)
func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error {
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
records, _ := s.loadRecordsLocked(ctx)
// 更新已有记录或追加
found := false
for i := range records {
if records[i].ID == record.ID {
records[i] = *record
found = true
break
}
}
if !found {
records = append(records, *record)
}
// 限制记录数量
if len(records) > maxBackupRecords {
records = records[len(records)-maxBackupRecords:]
}
return s.saveRecordsLocked(ctx, records)
}
func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error {
if schedule == nil {
return nil
}
s.recordsMu.Lock()
defer s.recordsMu.Unlock()
records, err := s.loadRecordsLocked(ctx)
if err != nil {
return err
}
// 按时间倒序
sort.Slice(records, func(i, j int) bool {
return records[i].StartedAt > records[j].StartedAt
})
var toDelete []BackupRecord
var toKeep []BackupRecord
for i, r := range records {
shouldDelete := false
// 按保留份数清理
if schedule.RetainCount > 0 && i >= schedule.RetainCount {
shouldDelete = true
}
// 按保留天数清理
if schedule.RetainDays > 0 && r.StartedAt != "" {
startedAt, err := time.Parse(time.RFC3339, r.StartedAt)
if err == nil && time.Since(startedAt) > time.Duration(schedule.RetainDays)*24*time.Hour {
shouldDelete = true
}
}
if shouldDelete && r.Status == "completed" {
toDelete = append(toDelete, r)
} else {
toKeep = append(toKeep, r)
}
}
// 删除 S3 上的文件
for _, r := range toDelete {
if r.S3Key != "" {
_ = s.deleteS3Object(ctx, r.S3Key)
}
}
if len(toDelete) > 0 {
logger.LegacyPrintf("service.backup", "[Backup] 自动清理了 %d 个过期备份", len(toDelete))
return s.saveRecordsLocked(ctx, toKeep)
}
return nil
}
func (s *BackupService) deleteS3Object(ctx context.Context, key string) error {
s3Cfg, err := s.loadS3Config(ctx)
if err != nil || s3Cfg == nil {
return nil
}
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
if err != nil {
return err
}
return objectStore.Delete(ctx, key)
}

View File

@@ -0,0 +1,528 @@
//go:build unit
package service
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// ─── Mocks ───
type mockSettingRepo struct {
mu sync.Mutex
data map[string]string
}
func newMockSettingRepo() *mockSettingRepo {
return &mockSettingRepo{data: make(map[string]string)}
}
func (m *mockSettingRepo) Get(_ context.Context, key string) (*Setting, error) {
m.mu.Lock()
defer m.mu.Unlock()
v, ok := m.data[key]
if !ok {
return nil, ErrSettingNotFound
}
return &Setting{Key: key, Value: v}, nil
}
func (m *mockSettingRepo) GetValue(_ context.Context, key string) (string, error) {
m.mu.Lock()
defer m.mu.Unlock()
v, ok := m.data[key]
if !ok {
return "", nil
}
return v, nil
}
func (m *mockSettingRepo) Set(_ context.Context, key, value string) error {
m.mu.Lock()
defer m.mu.Unlock()
m.data[key] = value
return nil
}
func (m *mockSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
result := make(map[string]string)
for _, k := range keys {
if v, ok := m.data[k]; ok {
result[k] = v
}
}
return result, nil
}
func (m *mockSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
m.mu.Lock()
defer m.mu.Unlock()
for k, v := range settings {
m.data[k] = v
}
return nil
}
func (m *mockSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
m.mu.Lock()
defer m.mu.Unlock()
result := make(map[string]string, len(m.data))
for k, v := range m.data {
result[k] = v
}
return result, nil
}
func (m *mockSettingRepo) Delete(_ context.Context, key string) error {
m.mu.Lock()
defer m.mu.Unlock()
delete(m.data, key)
return nil
}
// plainEncryptor 仅做 base64-like 包装,用于测试
type plainEncryptor struct{}
func (e *plainEncryptor) Encrypt(plaintext string) (string, error) {
return "ENC:" + plaintext, nil
}
func (e *plainEncryptor) Decrypt(ciphertext string) (string, error) {
if strings.HasPrefix(ciphertext, "ENC:") {
return strings.TrimPrefix(ciphertext, "ENC:"), nil
}
return ciphertext, fmt.Errorf("not encrypted")
}
type mockDumper struct {
dumpData []byte
dumpErr error
restored []byte
restErr error
}
func (m *mockDumper) Dump(_ context.Context) (io.ReadCloser, error) {
if m.dumpErr != nil {
return nil, m.dumpErr
}
return io.NopCloser(bytes.NewReader(m.dumpData)), nil
}
func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
if m.restErr != nil {
return m.restErr
}
d, err := io.ReadAll(data)
if err != nil {
return err
}
m.restored = d
return nil
}
type mockObjectStore struct {
objects map[string][]byte
mu sync.Mutex
}
func newMockObjectStore() *mockObjectStore {
return &mockObjectStore{objects: make(map[string][]byte)}
}
func (m *mockObjectStore) Upload(_ context.Context, key string, body io.Reader, _ string) (int64, error) {
data, err := io.ReadAll(body)
if err != nil {
return 0, err
}
m.mu.Lock()
m.objects[key] = data
m.mu.Unlock()
return int64(len(data)), nil
}
func (m *mockObjectStore) Download(_ context.Context, key string) (io.ReadCloser, error) {
m.mu.Lock()
data, ok := m.objects[key]
m.mu.Unlock()
if !ok {
return nil, fmt.Errorf("not found: %s", key)
}
return io.NopCloser(bytes.NewReader(data)), nil
}
func (m *mockObjectStore) Delete(_ context.Context, key string) error {
m.mu.Lock()
delete(m.objects, key)
m.mu.Unlock()
return nil
}
func (m *mockObjectStore) PresignURL(_ context.Context, key string, _ time.Duration) (string, error) {
return "https://presigned.example.com/" + key, nil
}
func (m *mockObjectStore) HeadBucket(_ context.Context) error {
return nil
}
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
cfg := &config.Config{
Database: config.DatabaseConfig{
Host: "localhost",
Port: 5432,
User: "test",
DBName: "testdb",
},
}
factory := func(_ context.Context, _ *BackupS3Config) (BackupObjectStore, error) {
return store, nil
}
return NewBackupService(repo, cfg, &plainEncryptor{}, factory, dumper)
}
func seedS3Config(t *testing.T, repo *mockSettingRepo) {
t.Helper()
cfg := BackupS3Config{
Bucket: "test-bucket",
AccessKeyID: "AKID",
SecretAccessKey: "ENC:secret123",
Prefix: "backups",
}
data, _ := json.Marshal(cfg)
require.NoError(t, repo.Set(context.Background(), settingKeyBackupS3Config, string(data)))
}
// ─── Tests ───
func TestBackupService_S3ConfigEncryption(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// 保存配置 -> SecretAccessKey 应被加密
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
Bucket: "my-bucket",
AccessKeyID: "AKID",
SecretAccessKey: "my-secret",
Prefix: "backups",
})
require.NoError(t, err)
// 直接读取数据库中存储的值,应该是加密后的
raw, _ := repo.GetValue(context.Background(), settingKeyBackupS3Config)
var stored BackupS3Config
require.NoError(t, json.Unmarshal([]byte(raw), &stored))
require.Equal(t, "ENC:my-secret", stored.SecretAccessKey)
// 通过 GetS3Config 获取应该脱敏
cfg, err := svc.GetS3Config(context.Background())
require.NoError(t, err)
require.Empty(t, cfg.SecretAccessKey)
require.Equal(t, "my-bucket", cfg.Bucket)
// loadS3Config 内部应解密
internal, err := svc.loadS3Config(context.Background())
require.NoError(t, err)
require.Equal(t, "my-secret", internal.SecretAccessKey)
}
func TestBackupService_S3ConfigKeepExistingSecret(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// 先保存一个有 secret 的配置
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
Bucket: "my-bucket",
AccessKeyID: "AKID",
SecretAccessKey: "original-secret",
})
require.NoError(t, err)
// 再更新时不提供 secret应保留原值
_, err = svc.UpdateS3Config(context.Background(), BackupS3Config{
Bucket: "my-bucket",
AccessKeyID: "AKID-NEW",
})
require.NoError(t, err)
internal, err := svc.loadS3Config(context.Background())
require.NoError(t, err)
require.Equal(t, "original-secret", internal.SecretAccessKey)
require.Equal(t, "AKID-NEW", internal.AccessKeyID)
}
func TestBackupService_SaveRecordConcurrency(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
var wg sync.WaitGroup
n := 20
wg.Add(n)
for i := 0; i < n; i++ {
go func(idx int) {
defer wg.Done()
record := &BackupRecord{
ID: fmt.Sprintf("rec-%d", idx),
Status: "completed",
StartedAt: time.Now().Format(time.RFC3339),
}
_ = svc.saveRecord(context.Background(), record)
}(i)
}
wg.Wait()
records, err := svc.loadRecords(context.Background())
require.NoError(t, err)
require.Len(t, records, n)
}
func TestBackupService_LoadRecords_Empty(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
records, err := svc.loadRecords(context.Background())
require.NoError(t, err)
require.Nil(t, records) // 无数据时返回 nil
}
func TestBackupService_LoadRecords_Corrupted(t *testing.T) {
repo := newMockSettingRepo()
_ = repo.Set(context.Background(), settingKeyBackupRecords, "not valid json{{{")
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
records, err := svc.loadRecords(context.Background())
require.Error(t, err) // 损坏数据应返回错误
require.Nil(t, records)
}
func TestBackupService_CreateBackup_Streaming(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
require.Equal(t, "completed", record.Status)
require.Greater(t, record.SizeBytes, int64(0))
require.NotEmpty(t, record.S3Key)
// 验证 S3 上确实有文件
store.mu.Lock()
require.Len(t, store.objects, 1)
store.mu.Unlock()
}
func TestBackupService_CreateBackup_DumpFailure(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &mockDumper{dumpErr: fmt.Errorf("pg_dump failed")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.Error(t, err)
require.Equal(t, "failed", record.Status)
require.Contains(t, record.ErrorMsg, "pg_dump")
}
func TestBackupService_CreateBackup_NoS3Config(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
_, err := svc.CreateBackup(context.Background(), "manual", 14)
require.ErrorIs(t, err, ErrBackupS3NotConfigured)
}
func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
// 使用一个慢速 dumper 来模拟正在进行的备份
dumper := &mockDumper{dumpData: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
// 手动设置 backingUp 标志
svc.mu.Lock()
svc.backingUp = true
svc.mu.Unlock()
_, err := svc.CreateBackup(context.Background(), "manual", 14)
require.ErrorIs(t, err, ErrBackupInProgress)
}
func TestBackupService_RestoreBackup_Streaming(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
// 先创建一个备份
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// 恢复
err = svc.RestoreBackup(context.Background(), record.ID)
require.NoError(t, err)
// 验证 psql 收到的数据是否与原始 dump 内容一致
require.Equal(t, dumpContent, string(dumper.restored))
}
func TestBackupService_RestoreBackup_NotCompleted(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
// 手动插入一条 failed 记录
_ = svc.saveRecord(context.Background(), &BackupRecord{
ID: "fail-1",
Status: "failed",
})
err := svc.RestoreBackup(context.Background(), "fail-1")
require.Error(t, err)
}
func TestBackupService_DeleteBackup(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumpContent := "data"
dumper := &mockDumper{dumpData: []byte(dumpContent)}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
// S3 中应有文件
store.mu.Lock()
require.Len(t, store.objects, 1)
store.mu.Unlock()
// 删除
err = svc.DeleteBackup(context.Background(), record.ID)
require.NoError(t, err)
// S3 中文件应被删除
store.mu.Lock()
require.Len(t, store.objects, 0)
store.mu.Unlock()
// 记录应不存在
_, err = svc.GetBackupRecord(context.Background(), record.ID)
require.ErrorIs(t, err, ErrBackupNotFound)
}
func TestBackupService_GetDownloadURL(t *testing.T) {
repo := newMockSettingRepo()
seedS3Config(t, repo)
dumper := &mockDumper{dumpData: []byte("data")}
store := newMockObjectStore()
svc := newTestBackupService(repo, dumper, store)
record, err := svc.CreateBackup(context.Background(), "manual", 14)
require.NoError(t, err)
url, err := svc.GetBackupDownloadURL(context.Background(), record.ID)
require.NoError(t, err)
require.Contains(t, url, "https://presigned.example.com/")
}
func TestBackupService_ListBackups_Sorted(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
now := time.Now()
for i := 0; i < 3; i++ {
_ = svc.saveRecord(context.Background(), &BackupRecord{
ID: fmt.Sprintf("rec-%d", i),
Status: "completed",
StartedAt: now.Add(time.Duration(i) * time.Hour).Format(time.RFC3339),
})
}
records, err := svc.ListBackups(context.Background())
require.NoError(t, err)
require.Len(t, records, 3)
// 最新在前
require.Equal(t, "rec-2", records[0].ID)
require.Equal(t, "rec-0", records[2].ID)
}
func TestBackupService_TestS3Connection(t *testing.T) {
repo := newMockSettingRepo()
store := newMockObjectStore()
svc := newTestBackupService(repo, &mockDumper{}, store)
err := svc.TestS3Connection(context.Background(), BackupS3Config{
Bucket: "test",
AccessKeyID: "ak",
SecretAccessKey: "sk",
})
require.NoError(t, err)
}
func TestBackupService_TestS3Connection_Incomplete(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
err := svc.TestS3Connection(context.Background(), BackupS3Config{
Bucket: "test",
})
require.Error(t, err)
require.Contains(t, err.Error(), "incomplete")
}
func TestBackupService_Schedule_CronValidation(t *testing.T) {
repo := newMockSettingRepo()
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
svc.cronSched = nil // 未初始化 cron
// 启用但 cron 为空
_, err := svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
Enabled: true,
CronExpr: "",
})
require.Error(t, err)
// 无效的 cron 表达式
_, err = svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
Enabled: true,
CronExpr: "invalid",
})
require.Error(t, err)
}
func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
repo := newMockSettingRepo()
_ = repo.Set(context.Background(), settingKeyBackupS3Config, "not json!!!!")
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
cfg, err := svc.loadS3Config(context.Background())
require.Error(t, err)
require.Nil(t, cfg)
}

View File

@@ -0,0 +1,607 @@
package service
import (
"encoding/json"
"fmt"
"net/url"
"regexp"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/domain"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const defaultBedrockRegion = "us-east-1"
var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."}
// BedrockCrossRegionPrefix 根据 AWS Region 返回 Bedrock 跨区域推理的模型 ID 前缀
// 参考: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
func BedrockCrossRegionPrefix(region string) string {
switch {
case strings.HasPrefix(region, "us-gov"):
return "us-gov" // GovCloud 使用独立的 us-gov 前缀
case strings.HasPrefix(region, "us-"):
return "us"
case strings.HasPrefix(region, "eu-"):
return "eu"
case region == "ap-northeast-1":
return "jp" // 日本区域使用独立的 jp 前缀AWS 官方定义)
case region == "ap-southeast-2":
return "au" // 澳大利亚区域使用独立的 au 前缀AWS 官方定义)
case strings.HasPrefix(region, "ap-"):
return "apac" // 其余亚太区域使用通用 apac 前缀
case strings.HasPrefix(region, "ca-"):
return "us" // 加拿大区域使用 us 前缀的跨区域推理
case strings.HasPrefix(region, "sa-"):
return "us" // 南美区域使用 us 前缀的跨区域推理
default:
return "us"
}
}
// AdjustBedrockModelRegionPrefix 将模型 ID 的区域前缀替换为与当前 AWS Region 匹配的前缀
// 例如 region=eu-west-1 时,"us.anthropic.claude-opus-4-6-v1" → "eu.anthropic.claude-opus-4-6-v1"
// 特殊值 region="global" 强制使用 global. 前缀
func AdjustBedrockModelRegionPrefix(modelID, region string) string {
var targetPrefix string
if region == "global" {
targetPrefix = "global"
} else {
targetPrefix = BedrockCrossRegionPrefix(region)
}
for _, p := range bedrockCrossRegionPrefixes {
if strings.HasPrefix(modelID, p) {
if p == targetPrefix+"." {
return modelID // 前缀已匹配,无需替换
}
return targetPrefix + "." + modelID[len(p):]
}
}
// 模型 ID 没有已知区域前缀(如 "anthropic.claude-..."),不做修改
return modelID
}
func bedrockRuntimeRegion(account *Account) string {
if account == nil {
return defaultBedrockRegion
}
if region := account.GetCredential("aws_region"); region != "" {
return region
}
return defaultBedrockRegion
}
func shouldForceBedrockGlobal(account *Account) bool {
return account != nil && account.GetCredential("aws_force_global") == "true"
}
func isRegionalBedrockModelID(modelID string) bool {
for _, prefix := range bedrockCrossRegionPrefixes {
if strings.HasPrefix(modelID, prefix) {
return true
}
}
return false
}
func isLikelyBedrockModelID(modelID string) bool {
lower := strings.ToLower(strings.TrimSpace(modelID))
if lower == "" {
return false
}
if strings.HasPrefix(lower, "arn:") {
return true
}
for _, prefix := range []string{
"anthropic.",
"amazon.",
"meta.",
"mistral.",
"cohere.",
"ai21.",
"deepseek.",
"stability.",
"writer.",
"nova.",
} {
if strings.HasPrefix(lower, prefix) {
return true
}
}
return isRegionalBedrockModelID(lower)
}
func normalizeBedrockModelID(modelID string) (normalized string, shouldAdjustRegion bool, ok bool) {
modelID = strings.TrimSpace(modelID)
if modelID == "" {
return "", false, false
}
if mapped, exists := domain.DefaultBedrockModelMapping[modelID]; exists {
return mapped, true, true
}
if isRegionalBedrockModelID(modelID) {
return modelID, true, true
}
if isLikelyBedrockModelID(modelID) {
return modelID, false, true
}
return "", false, false
}
// ResolveBedrockModelID resolves a requested Claude model into a Bedrock model ID.
// It applies account model_mapping first, then default Bedrock aliases, and finally
// adjusts Anthropic cross-region prefixes to match the account region.
func ResolveBedrockModelID(account *Account, requestedModel string) (string, bool) {
if account == nil {
return "", false
}
mappedModel := account.GetMappedModel(requestedModel)
modelID, shouldAdjustRegion, ok := normalizeBedrockModelID(mappedModel)
if !ok {
return "", false
}
if shouldAdjustRegion {
targetRegion := bedrockRuntimeRegion(account)
if shouldForceBedrockGlobal(account) {
targetRegion = "global"
}
modelID = AdjustBedrockModelRegionPrefix(modelID, targetRegion)
}
return modelID, true
}
// BuildBedrockURL 构建 Bedrock InvokeModel 的 URL
// stream=true 时使用 invoke-with-response-stream 端点
// modelID 中的特殊字符会被 URL 编码(与 litellm 的 urllib.parse.quote(safe="") 对齐)
func BuildBedrockURL(region, modelID string, stream bool) string {
if region == "" {
region = defaultBedrockRegion
}
encodedModelID := url.PathEscape(modelID)
// url.PathEscape 不编码冒号RFC 允许 path 中出现 ":"
// 但 AWS Bedrock 期望模型 ID 中的冒号被编码为 %3A
encodedModelID = strings.ReplaceAll(encodedModelID, ":", "%3A")
if stream {
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", region, encodedModelID)
}
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", region, encodedModelID)
}
// PrepareBedrockRequestBody 处理请求体以适配 Bedrock API
// 1. 注入 anthropic_version
// 2. 注入 anthropic_beta从客户端 anthropic-beta 头解析)
// 3. 移除 Bedrock 不支持的字段model, stream, output_format, output_config
// 4. 移除工具定义中的 custom 字段Claude Code 会发送 custom: {defer_loading: true}
// 5. 清理 cache_control 中 Bedrock 不支持的字段scope, ttl
func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) {
betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens)
}
// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens.
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) {
var err error
// 注入 anthropic_versionBedrock 要求)
body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31")
if err != nil {
return nil, fmt.Errorf("inject anthropic_version: %w", err)
}
// 注入 anthropic_betaBedrock Invoke 通过请求体传递 beta 头,而非 HTTP 头)
// 1. 从客户端 anthropic-beta header 解析
// 2. 根据请求体内容自动补齐必要的 beta token
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() + _get_tool_search_beta_header_for_bedrock()
if len(betaTokens) > 0 {
body, err = sjson.SetBytes(body, "anthropic_beta", betaTokens)
if err != nil {
return nil, fmt.Errorf("inject anthropic_beta: %w", err)
}
}
// 移除 model 字段Bedrock 通过 URL 指定模型)
body, err = sjson.DeleteBytes(body, "model")
if err != nil {
return nil, fmt.Errorf("remove model field: %w", err)
}
// 移除 stream 字段Bedrock 通过不同端点控制流式,不接受请求体中的 stream 字段)
body, err = sjson.DeleteBytes(body, "stream")
if err != nil {
return nil, fmt.Errorf("remove stream field: %w", err)
}
// 转换 output_formatBedrock Invoke 不支持此字段,但可将 schema 内联到最后一条 user message
// 参考 litellm: _convert_output_format_to_inline_schema()
body = convertOutputFormatToInlineSchema(body)
// 移除 output_config 字段Bedrock Invoke 不支持)
body, err = sjson.DeleteBytes(body, "output_config")
if err != nil {
return nil, fmt.Errorf("remove output_config field: %w", err)
}
// 移除工具定义中的 custom 字段
// Claude Code (v2.1.69+) 在 tool 定义中发送 custom: {defer_loading: true}
// Anthropic API 接受但 Bedrock 会拒绝并报 "Extra inputs are not permitted"
body = removeCustomFieldFromTools(body)
// 清理 cache_control 中 Bedrock 不支持的字段
body = sanitizeBedrockCacheControl(body, modelID)
return body, nil
}
// ResolveBedrockBetaTokens computes the final Bedrock beta token list before policy filtering.
func ResolveBedrockBetaTokens(betaHeader string, body []byte, modelID string) []string {
betaTokens := parseAnthropicBetaHeader(betaHeader)
betaTokens = autoInjectBedrockBetaTokens(betaTokens, body, modelID)
return filterBedrockBetaTokens(betaTokens)
}
// convertOutputFormatToInlineSchema 将 output_format 中的 JSON schema 内联到最后一条 user message
// Bedrock Invoke 不支持 output_format 参数litellm 的做法是将 schema 追加到用户消息中
// 参考: litellm AmazonAnthropicClaudeMessagesConfig._convert_output_format_to_inline_schema()
func convertOutputFormatToInlineSchema(body []byte) []byte {
outputFormat := gjson.GetBytes(body, "output_format")
if !outputFormat.Exists() || !outputFormat.IsObject() {
return body
}
// 先从请求体中移除 output_format
body, _ = sjson.DeleteBytes(body, "output_format")
schema := outputFormat.Get("schema")
if !schema.Exists() {
return body
}
// 找到最后一条 user message
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
msgArr := messages.Array()
lastUserIdx := -1
for i := len(msgArr) - 1; i >= 0; i-- {
if msgArr[i].Get("role").String() == "user" {
lastUserIdx = i
break
}
}
if lastUserIdx < 0 {
return body
}
// 将 schema 序列化为 JSON 文本追加到该 message 的 content 数组
schemaJSON, err := json.Marshal(json.RawMessage(schema.Raw))
if err != nil {
return body
}
content := msgArr[lastUserIdx].Get("content")
basePath := fmt.Sprintf("messages.%d.content", lastUserIdx)
if content.IsArray() {
// 追加一个 text block 到 content 数组末尾
idx := len(content.Array())
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.type", basePath, idx), "text")
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.text", basePath, idx), string(schemaJSON))
} else if content.Type == gjson.String {
// content 是纯字符串,转换为数组格式
originalText := content.String()
body, _ = sjson.SetBytes(body, basePath, []map[string]string{
{"type": "text", "text": originalText},
{"type": "text", "text": string(schemaJSON)},
})
}
return body
}
// removeCustomFieldFromTools 移除 tools 数组中每个工具定义的 custom 字段
func removeCustomFieldFromTools(body []byte) []byte {
tools := gjson.GetBytes(body, "tools")
if !tools.Exists() || !tools.IsArray() {
return body
}
var err error
for i := range tools.Array() {
body, err = sjson.DeleteBytes(body, fmt.Sprintf("tools.%d.custom", i))
if err != nil {
// 删除失败不影响整体流程,跳过
continue
}
}
return body
}
// claudeVersionRe 匹配 Claude 模型 ID 中的版本号部分
// 支持 claude-{tier}-{major}-{minor} 和 claude-{tier}-{major}.{minor} 格式
var claudeVersionRe = regexp.MustCompile(`claude-(?:haiku|sonnet|opus)-(\d+)[-.](\d+)`)
// isBedrockClaude45OrNewer 判断 Bedrock 模型 ID 是否为 Claude 4.5 或更新版本
// Claude 4.5+ 支持 cache_control 中的 ttl 字段("5m" 和 "1h"
func isBedrockClaude45OrNewer(modelID string) bool {
lower := strings.ToLower(modelID)
matches := claudeVersionRe.FindStringSubmatch(lower)
if matches == nil {
return false
}
major, _ := strconv.Atoi(matches[1])
minor, _ := strconv.Atoi(matches[2])
return major > 4 || (major == 4 && minor >= 5)
}
// sanitizeBedrockCacheControl 清理 system 和 messages 中 cache_control 里
// Bedrock 不支持的字段:
// - scopeBedrock 不支持(如 "global" 跨请求缓存)
// - ttl仅 Claude 4.5+ 支持 "5m" 和 "1h",旧模型需要移除
func sanitizeBedrockCacheControl(body []byte, modelID string) []byte {
isClaude45 := isBedrockClaude45OrNewer(modelID)
// 清理 system 数组中的 cache_control
systemArr := gjson.GetBytes(body, "system")
if systemArr.Exists() && systemArr.IsArray() {
for i, item := range systemArr.Array() {
if !item.IsObject() {
continue
}
cc := item.Get("cache_control")
if !cc.Exists() || !cc.IsObject() {
continue
}
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("system.%d.cache_control", i), cc, isClaude45)
}
}
// 清理 messages 中的 cache_control
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body
}
for mi, msg := range messages.Array() {
if !msg.IsObject() {
continue
}
content := msg.Get("content")
if !content.Exists() || !content.IsArray() {
continue
}
for ci, block := range content.Array() {
if !block.IsObject() {
continue
}
cc := block.Get("cache_control")
if !cc.Exists() || !cc.IsObject() {
continue
}
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci), cc, isClaude45)
}
}
return body
}
// deleteCacheControlUnsupportedFields 删除给定 cache_control 路径下 Bedrock 不支持的字段
func deleteCacheControlUnsupportedFields(body []byte, basePath string, cc gjson.Result, isClaude45 bool) []byte {
// Bedrock 不支持 scope如 "global"
if cc.Get("scope").Exists() {
body, _ = sjson.DeleteBytes(body, basePath+".scope")
}
// ttl仅 Claude 4.5+ 支持 "5m" 和 "1h",其余情况移除
ttl := cc.Get("ttl")
if ttl.Exists() {
shouldRemove := true
if isClaude45 {
v := ttl.String()
if v == "5m" || v == "1h" {
shouldRemove = false
}
}
if shouldRemove {
body, _ = sjson.DeleteBytes(body, basePath+".ttl")
}
}
return body
}
// parseAnthropicBetaHeader 解析 anthropic-beta 头的逗号分隔字符串为 token 列表
func parseAnthropicBetaHeader(header string) []string {
header = strings.TrimSpace(header)
if header == "" {
return nil
}
if strings.HasPrefix(header, "[") && strings.HasSuffix(header, "]") {
var parsed []any
if err := json.Unmarshal([]byte(header), &parsed); err == nil {
tokens := make([]string, 0, len(parsed))
for _, item := range parsed {
token := strings.TrimSpace(fmt.Sprint(item))
if token != "" {
tokens = append(tokens, token)
}
}
return tokens
}
}
var tokens []string
for _, part := range strings.Split(header, ",") {
t := strings.TrimSpace(part)
if t != "" {
tokens = append(tokens, t)
}
}
return tokens
}
// bedrockSupportedBetaTokens 是 Bedrock Invoke 支持的 beta 头白名单
// 参考: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json)
// 更新策略: 当 AWS Bedrock 新增支持的 beta token 时需同步更新此白名单
var bedrockSupportedBetaTokens = map[string]bool{
"computer-use-2025-01-24": true,
"computer-use-2025-11-24": true,
"context-1m-2025-08-07": true,
"context-management-2025-06-27": true,
"compact-2026-01-12": true,
"interleaved-thinking-2025-05-14": true,
"tool-search-tool-2025-10-19": true,
"tool-examples-2025-10-29": true,
}
// bedrockBetaTokenTransforms 定义 Bedrock Invoke 特有的 beta 头转换规则
// Anthropic 直接 API 使用通用头Bedrock Invoke 需要特定的替代头
var bedrockBetaTokenTransforms = map[string]string{
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
}
// autoInjectBedrockBetaTokens 根据请求体内容自动补齐必要的 beta token
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() 和
// AmazonAnthropicClaudeMessagesConfig._get_tool_search_beta_header_for_bedrock()
//
// 客户端(特别是非 Claude Code 客户端)可能只在 body 中启用了功能而不在 header 中带对应 beta token
// 这里通过检测请求体特征自动补齐,确保 Bedrock Invoke 不会因缺少必要 beta 头而 400。
func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) []string {
seen := make(map[string]bool, len(tokens))
for _, t := range tokens {
seen[t] = true
}
inject := func(token string) {
if !seen[token] {
tokens = append(tokens, token)
seen[token] = true
}
}
// 检测 thinking / interleaved thinking
// 请求体中有 "thinking" 字段 → 需要 interleaved-thinking beta
if gjson.GetBytes(body, "thinking").Exists() {
inject("interleaved-thinking-2025-05-14")
}
// 检测 computer_use 工具
// tools 中有 type="computer_20xxxxxx" 的工具 → 需要 computer-use beta
tools := gjson.GetBytes(body, "tools")
if tools.Exists() && tools.IsArray() {
toolSearchUsed := false
programmaticToolCallingUsed := false
inputExamplesUsed := false
for _, tool := range tools.Array() {
toolType := tool.Get("type").String()
if strings.HasPrefix(toolType, "computer_20") {
inject("computer-use-2025-11-24")
}
if isBedrockToolSearchType(toolType) {
toolSearchUsed = true
}
if hasCodeExecutionAllowedCallers(tool) {
programmaticToolCallingUsed = true
}
if hasInputExamples(tool) {
inputExamplesUsed = true
}
}
if programmaticToolCallingUsed || inputExamplesUsed {
// programmatic tool calling 和 input examples 需要 advanced-tool-use
// 后续 filterBedrockBetaTokens 会将其转换为 Bedrock 特定的 tool-search-tool
inject("advanced-tool-use-2025-11-20")
}
if toolSearchUsed && bedrockModelSupportsToolSearch(modelID) {
// 纯 tool search无 programmatic/inputExamples时直接注入 Bedrock 特定头,
// 跳过 advanced-tool-use → tool-search-tool 的转换步骤(与 litellm 对齐)
if !programmaticToolCallingUsed && !inputExamplesUsed {
inject("tool-search-tool-2025-10-19")
} else {
inject("advanced-tool-use-2025-11-20")
}
}
}
return tokens
}
func isBedrockToolSearchType(toolType string) bool {
return toolType == "tool_search_tool_regex_20251119" || toolType == "tool_search_tool_bm25_20251119"
}
func hasCodeExecutionAllowedCallers(tool gjson.Result) bool {
allowedCallers := tool.Get("allowed_callers")
if containsStringInJSONArray(allowedCallers, "code_execution_20250825") {
return true
}
return containsStringInJSONArray(tool.Get("function.allowed_callers"), "code_execution_20250825")
}
func hasInputExamples(tool gjson.Result) bool {
if arr := tool.Get("input_examples"); arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 {
return true
}
arr := tool.Get("function.input_examples")
return arr.Exists() && arr.IsArray() && len(arr.Array()) > 0
}
func containsStringInJSONArray(result gjson.Result, target string) bool {
if !result.Exists() || !result.IsArray() {
return false
}
for _, item := range result.Array() {
if item.String() == target {
return true
}
}
return false
}
// bedrockModelSupportsToolSearch 判断 Bedrock 模型是否支持 tool search
// 目前仅 Claude Opus/Sonnet 4.5+ 支持Haiku 不支持
func bedrockModelSupportsToolSearch(modelID string) bool {
lower := strings.ToLower(modelID)
matches := claudeVersionRe.FindStringSubmatch(lower)
if matches == nil {
return false
}
// Haiku 不支持 tool search
if strings.Contains(lower, "haiku") {
return false
}
major, _ := strconv.Atoi(matches[1])
minor, _ := strconv.Atoi(matches[2])
return major > 4 || (major == 4 && minor >= 5)
}
// filterBedrockBetaTokens 过滤并转换 beta token 列表,仅保留 Bedrock Invoke 支持的 token
// 1. 应用转换规则(如 advanced-tool-use → tool-search-tool
// 2. 过滤掉 Bedrock 不支持的 token如 output-128k, files-api, structured-outputs 等)
// 3. 自动关联 tool-examples当 tool-search-tool 存在时)
func filterBedrockBetaTokens(tokens []string) []string {
seen := make(map[string]bool, len(tokens))
var result []string
for _, t := range tokens {
// 应用转换规则
if replacement, ok := bedrockBetaTokenTransforms[t]; ok {
t = replacement
}
// 只保留白名单中的 token且去重
if bedrockSupportedBetaTokens[t] && !seen[t] {
result = append(result, t)
seen[t] = true
}
}
// 自动关联: tool-search-tool 存在时,确保 tool-examples 也存在
if seen["tool-search-tool-2025-10-19"] && !seen["tool-examples-2025-10-29"] {
result = append(result, "tool-examples-2025-10-29")
}
return result
}

View File

@@ -0,0 +1,659 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestPrepareBedrockRequestBody_BasicFields(t *testing.T) {
input := `{"model":"claude-opus-4-6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
// anthropic_version 应被注入
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
// model 和 stream 应被移除
assert.False(t, gjson.GetBytes(result, "model").Exists())
assert.False(t, gjson.GetBytes(result, "stream").Exists())
// max_tokens 应保留
assert.Equal(t, int64(1024), gjson.GetBytes(result, "max_tokens").Int())
}
func TestPrepareBedrockRequestBody_OutputFormatInlineSchema(t *testing.T) {
t.Run("schema inlined into last user message array content", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}},"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
// schema 应内联到最后一条 user message 的 content 数组末尾
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "text", contentArr[1].Get("type").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"name":"string"`)
})
t.Run("schema inlined into string content", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"result":"number"}},"messages":[{"role":"user","content":"compute this"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "compute this", contentArr[0].Get("text").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"number"`)
})
t.Run("no schema field just removes output_format", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json"},"messages":[{"role":"user","content":"hi"}]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
})
t.Run("no messages just removes output_format", func(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
})
}
func TestPrepareBedrockRequestBody_RemoveOutputConfig(t *testing.T) {
input := `{"model":"claude-sonnet-4-5","output_config":{"max_tokens":100},"messages":[]}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
}
func TestRemoveCustomFieldFromTools(t *testing.T) {
input := `{
"tools": [
{"name":"tool1","custom":{"defer_loading":true},"description":"desc1"},
{"name":"tool2","description":"desc2"},
{"name":"tool3","custom":{"defer_loading":true,"other":123},"description":"desc3"}
]
}`
result := removeCustomFieldFromTools([]byte(input))
tools := gjson.GetBytes(result, "tools").Array()
require.Len(t, tools, 3)
// custom 应被移除
assert.False(t, tools[0].Get("custom").Exists())
// name/description 应保留
assert.Equal(t, "tool1", tools[0].Get("name").String())
assert.Equal(t, "desc1", tools[0].Get("description").String())
// 没有 custom 的工具不受影响
assert.Equal(t, "tool2", tools[1].Get("name").String())
// 第三个工具的 custom 也应被移除
assert.False(t, tools[2].Get("custom").Exists())
assert.Equal(t, "tool3", tools[2].Get("name").String())
}
func TestRemoveCustomFieldFromTools_NoTools(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}]}`
result := removeCustomFieldFromTools([]byte(input))
// 无 tools 时不改变原始数据
assert.JSONEq(t, input, string(result))
}
func TestSanitizeBedrockCacheControl_RemoveScope(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","scope":"global"}}],
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"global"}}]}]
}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
// scope 应被移除
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
assert.False(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.scope").Exists())
// type 应保留
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String())
}
func TestSanitizeBedrockCacheControl_TTL_OldModel(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
}`
// 旧模型Claude 3.5)不支持 ttl
result := sanitizeBedrockCacheControl([]byte(input), "anthropic.claude-3-5-sonnet-20241022-v2:0")
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
}
func TestSanitizeBedrockCacheControl_TTL_Claude45_Supported(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
}`
// Claude 4.5+ 支持 "5m" 和 "1h"
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
assert.True(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
}
func TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue(t *testing.T) {
input := `{
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"10m"}}]
}`
// Claude 4.5 不支持 "10m"
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
}
func TestSanitizeBedrockCacheControl_TTL_Claude46(t *testing.T) {
input := `{
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
assert.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").Exists())
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
func TestSanitizeBedrockCacheControl_NoCacheControl(t *testing.T) {
input := `{"system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
// 无 cache_control 时不改变原始数据
assert.JSONEq(t, input, string(result))
}
func TestIsBedrockClaude45OrNewer(t *testing.T) {
tests := []struct {
modelID string
expect bool
}{
{"us.anthropic.claude-opus-4-6-v1", true},
{"us.anthropic.claude-sonnet-4-6", true},
{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
{"us.anthropic.claude-haiku-4-5-20251001-v1:0", true},
{"anthropic.claude-3-5-sonnet-20241022-v2:0", false},
{"anthropic.claude-3-opus-20240229-v1:0", false},
{"anthropic.claude-3-haiku-20240307-v1:0", false},
// 未来版本应自动支持
{"us.anthropic.claude-sonnet-5-0-v1", true},
{"us.anthropic.claude-opus-4-7-v1", true},
// 旧版本
{"anthropic.claude-opus-4-1-v1", false},
{"anthropic.claude-sonnet-4-0-v1", false},
// 非 Claude 模型
{"amazon.nova-pro-v1", false},
{"meta.llama3-70b", false},
}
for _, tt := range tests {
t.Run(tt.modelID, func(t *testing.T) {
assert.Equal(t, tt.expect, isBedrockClaude45OrNewer(tt.modelID))
})
}
}
func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
// 模拟一个完整的 Claude Code 请求
input := `{
"model": "claude-opus-4-6",
"stream": true,
"max_tokens": 16384,
"output_format": {"type": "json", "schema": {"result": "string"}},
"output_config": {"max_tokens": 100},
"system": [{"type": "text", "text": "You are helpful", "cache_control": {"type": "ephemeral", "scope": "global", "ttl": "5m"}}],
"messages": [
{"role": "user", "content": [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral", "ttl": "1h"}}]}
],
"tools": [
{"name": "bash", "description": "Run bash", "custom": {"defer_loading": true}, "input_schema": {"type": "object"}},
{"name": "read", "description": "Read file", "input_schema": {"type": "object"}}
]
}`
betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12"
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader)
require.NoError(t, err)
// 基本字段
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
assert.False(t, gjson.GetBytes(result, "model").Exists())
assert.False(t, gjson.GetBytes(result, "stream").Exists())
assert.Equal(t, int64(16384), gjson.GetBytes(result, "max_tokens").Int())
// anthropic_beta 应包含所有 beta tokens
betaArr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, betaArr, 3)
assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String())
assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String())
assert.Equal(t, "compact-2026-01-12", betaArr[2].String())
// output_format 应被移除schema 内联到最后一条 user message
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
// content 数组:原始 text block + 内联 schema block
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
require.Len(t, contentArr, 2)
assert.Equal(t, "hello", contentArr[0].Get("text").String())
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"string"`)
// tools 中的 custom 应被移除
assert.False(t, gjson.GetBytes(result, "tools.0.custom").Exists())
assert.Equal(t, "bash", gjson.GetBytes(result, "tools.0.name").String())
assert.Equal(t, "read", gjson.GetBytes(result, "tools.1.name").String())
// cache_control: scope 应被移除ttl 在 Claude 4.6 上保留合法值
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
}
func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
t.Run("empty beta header", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
})
t.Run("single beta token", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 1)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
})
t.Run("multiple beta tokens with spaces", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
})
t.Run("json array beta header", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`)
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
})
}
func TestParseAnthropicBetaHeader(t *testing.T) {
assert.Nil(t, parseAnthropicBetaHeader(""))
assert.Equal(t, []string{"a"}, parseAnthropicBetaHeader("a"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a,b"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a , b "))
assert.Equal(t, []string{"a", "b", "c"}, parseAnthropicBetaHeader("a,b,c"))
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader(`["a","b"]`))
}
func TestFilterBedrockBetaTokens(t *testing.T) {
t.Run("supported tokens pass through", func(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, tokens, result)
})
t.Run("unsupported tokens are filtered out", func(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
})
t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) {
tokens := []string{"advanced-tool-use-2025-11-20"}
result := filterBedrockBetaTokens(tokens)
assert.Contains(t, result, "tool-search-tool-2025-10-19")
// tool-examples 自动关联
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("tool-search-tool auto-associates tool-examples", func(t *testing.T) {
tokens := []string{"tool-search-tool-2025-10-19"}
result := filterBedrockBetaTokens(tokens)
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("no duplication when tool-examples already present", func(t *testing.T) {
tokens := []string{"tool-search-tool-2025-10-19", "tool-examples-2025-10-29"}
result := filterBedrockBetaTokens(tokens)
count := 0
for _, t := range result {
if t == "tool-examples-2025-10-29" {
count++
}
}
assert.Equal(t, 1, count)
})
t.Run("empty input returns nil", func(t *testing.T) {
result := filterBedrockBetaTokens(nil)
assert.Nil(t, result)
})
t.Run("all unsupported returns nil", func(t *testing.T) {
result := filterBedrockBetaTokens([]string{"output-128k-2025-02-19", "effort-2025-11-24"})
assert.Nil(t, result)
})
t.Run("duplicate tokens are deduplicated", func(t *testing.T) {
tokens := []string{"context-1m-2025-08-07", "context-1m-2025-08-07"}
result := filterBedrockBetaTokens(tokens)
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
})
}
func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
t.Run("unsupported beta tokens are filtered", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
"interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 1)
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
})
t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) {
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
"advanced-tool-use-2025-11-20")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
require.Len(t, arr, 2)
assert.Equal(t, "tool-search-tool-2025-10-19", arr[0].String())
assert.Equal(t, "tool-examples-2025-10-29", arr[1].String())
})
}
func TestBedrockCrossRegionPrefix(t *testing.T) {
tests := []struct {
region string
expect string
}{
// US regions
{"us-east-1", "us"},
{"us-east-2", "us"},
{"us-west-1", "us"},
{"us-west-2", "us"},
// GovCloud
{"us-gov-east-1", "us-gov"},
{"us-gov-west-1", "us-gov"},
// EU regions
{"eu-west-1", "eu"},
{"eu-west-2", "eu"},
{"eu-west-3", "eu"},
{"eu-central-1", "eu"},
{"eu-central-2", "eu"},
{"eu-north-1", "eu"},
{"eu-south-1", "eu"},
// APAC regions
{"ap-northeast-1", "jp"},
{"ap-northeast-2", "apac"},
{"ap-southeast-1", "apac"},
{"ap-southeast-2", "au"},
{"ap-south-1", "apac"},
// Canada / South America fallback to us
{"ca-central-1", "us"},
{"sa-east-1", "us"},
// Unknown defaults to us
{"me-south-1", "us"},
}
for _, tt := range tests {
t.Run(tt.region, func(t *testing.T) {
assert.Equal(t, tt.expect, BedrockCrossRegionPrefix(tt.region))
})
}
}
func TestResolveBedrockModelID(t *testing.T) {
t.Run("default alias resolves and adjusts region", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "eu-west-1",
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5")
require.True(t, ok)
assert.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", modelID)
})
t.Run("custom alias mapping reuses default bedrock mapping", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "ap-southeast-2",
"model_mapping": map[string]any{
"claude-*": "claude-opus-4-6",
},
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-opus-4-6-thinking")
require.True(t, ok)
assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
})
t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
"aws_force_global": "true",
"model_mapping": map[string]any{
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
},
},
}
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-6")
require.True(t, ok)
assert.Equal(t, "global.anthropic.claude-sonnet-4-6", modelID)
})
t.Run("direct bedrock model id passes through", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
},
}
modelID, ok := ResolveBedrockModelID(account, "anthropic.claude-haiku-4-5-20251001-v1:0")
require.True(t, ok)
assert.Equal(t, "anthropic.claude-haiku-4-5-20251001-v1:0", modelID)
})
t.Run("unsupported alias returns false", func(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_region": "us-east-1",
},
}
_, ok := ResolveBedrockModelID(account, "claude-3-5-sonnet-20241022")
assert.False(t, ok)
})
}
func TestAutoInjectBedrockBetaTokens(t *testing.T) {
t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
})
t.Run("no duplicate when already present", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1")
count := 0
for _, t := range result {
if t == "interleaved-thinking-2025-05-14" {
count++
}
}
assert.Equal(t, 1, count)
})
t.Run("inject computer-use when computer tool present", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"computer_20250124","name":"computer","display_width_px":1024}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "computer-use-2025-11-24")
})
t.Run("inject advanced-tool-use for programmatic tool calling", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject advanced-tool-use for input examples", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","input_examples":[{"cmd":"ls"}]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject tool-search-tool directly for pure tool search (no programmatic/inputExamples)", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
// 纯 tool search 场景直接注入 Bedrock 特定头,不走 advanced-tool-use 转换
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("inject advanced-tool-use when tool search combined with programmatic calling", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"},{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
// 混合场景使用 advanced-tool-use后续由 filter 转换为 tool-search-tool
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
})
t.Run("do not inject tool-search beta for unsupported models", func(t *testing.T) {
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "anthropic.claude-3-5-sonnet-20241022-v2:0")
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
assert.NotContains(t, result, "tool-search-tool-2025-10-19")
})
t.Run("no injection for regular tools", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","description":"run bash","input_schema":{"type":"object"}}],"messages":[{"role":"user","content":"hi"}]}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Empty(t, result)
})
t.Run("no injection when no features detected", func(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`)
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
assert.Empty(t, result)
})
t.Run("preserves existing tokens", func(t *testing.T) {
body := []byte(`{"thinking":{"type":"enabled"},"messages":[{"role":"user","content":"hi"}]}`)
existing := []string{"context-1m-2025-08-07", "compact-2026-01-12"}
result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "context-1m-2025-08-07")
assert.Contains(t, result, "compact-2026-01-12")
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
})
}
func TestResolveBedrockBetaTokens(t *testing.T) {
t.Run("body-only tool features resolve to final bedrock tokens", func(t *testing.T) {
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
result := ResolveBedrockBetaTokens("", body, "us.anthropic.claude-opus-4-6-v1")
assert.Contains(t, result, "tool-search-tool-2025-10-19")
assert.Contains(t, result, "tool-examples-2025-10-29")
})
t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) {
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
})
}
func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
t.Run("thinking in body auto-injects beta without header", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
found := false
for _, v := range arr {
if v.String() == "interleaved-thinking-2025-05-14" {
found = true
}
}
assert.True(t, found, "interleaved-thinking should be auto-injected")
})
t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) {
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
require.NoError(t, err)
arr := gjson.GetBytes(result, "anthropic_beta").Array()
names := make([]string, len(arr))
for i, v := range arr {
names[i] = v.String()
}
assert.Contains(t, names, "context-1m-2025-08-07")
assert.Contains(t, names, "interleaved-thinking-2025-05-14")
})
}
func TestAdjustBedrockModelRegionPrefix(t *testing.T) {
tests := []struct {
name string
modelID string
region string
expect string
}{
// US region — no change needed
{"us region keeps us prefix", "us.anthropic.claude-opus-4-6-v1", "us-east-1", "us.anthropic.claude-opus-4-6-v1"},
// EU region — replace us → eu
{"eu region replaces prefix", "us.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
{"eu region sonnet", "us.anthropic.claude-sonnet-4-6", "eu-central-1", "eu.anthropic.claude-sonnet-4-6"},
// APAC region — jp and au have dedicated prefixes per AWS docs
{"jp region (ap-northeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-northeast-1", "jp.anthropic.claude-sonnet-4-5-20250929-v1:0"},
{"au region (ap-southeast-2)", "us.anthropic.claude-haiku-4-5-20251001-v1:0", "ap-southeast-2", "au.anthropic.claude-haiku-4-5-20251001-v1:0"},
{"apac region (ap-southeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-southeast-1", "apac.anthropic.claude-sonnet-4-5-20250929-v1:0"},
// eu → us (user manually set eu prefix, moved to us region)
{"eu to us", "eu.anthropic.claude-opus-4-6-v1", "us-west-2", "us.anthropic.claude-opus-4-6-v1"},
// global prefix — replace to match region
{"global to eu", "global.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
// No known prefix — leave unchanged
{"no prefix unchanged", "anthropic.claude-3-5-sonnet-20241022-v2:0", "eu-west-1", "anthropic.claude-3-5-sonnet-20241022-v2:0"},
// GovCloud — uses independent us-gov prefix
{"govcloud from us", "us.anthropic.claude-opus-4-6-v1", "us-gov-east-1", "us-gov.anthropic.claude-opus-4-6-v1"},
{"govcloud already correct", "us-gov.anthropic.claude-opus-4-6-v1", "us-gov-west-1", "us-gov.anthropic.claude-opus-4-6-v1"},
// Force global (special region value)
{"force global from us", "us.anthropic.claude-opus-4-6-v1", "global", "global.anthropic.claude-opus-4-6-v1"},
{"force global from eu", "eu.anthropic.claude-sonnet-4-6", "global", "global.anthropic.claude-sonnet-4-6"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.expect, AdjustBedrockModelRegionPrefix(tt.modelID, tt.region))
})
}
}

View File

@@ -0,0 +1,67 @@
package service
import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"net/http"
"time"
"github.com/aws/aws-sdk-go-v2/aws"
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
)
// BedrockSigner 使用 AWS SigV4 对 Bedrock 请求签名
type BedrockSigner struct {
credentials aws.Credentials
region string
signer *v4.Signer
}
// NewBedrockSigner 创建 BedrockSigner
func NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region string) *BedrockSigner {
return &BedrockSigner{
credentials: aws.Credentials{
AccessKeyID: accessKeyID,
SecretAccessKey: secretAccessKey,
SessionToken: sessionToken,
},
region: region,
signer: v4.NewSigner(),
}
}
// NewBedrockSignerFromAccount 从 Account 凭证创建 BedrockSigner
func NewBedrockSignerFromAccount(account *Account) (*BedrockSigner, error) {
accessKeyID := account.GetCredential("aws_access_key_id")
if accessKeyID == "" {
return nil, fmt.Errorf("aws_access_key_id not found in credentials")
}
secretAccessKey := account.GetCredential("aws_secret_access_key")
if secretAccessKey == "" {
return nil, fmt.Errorf("aws_secret_access_key not found in credentials")
}
region := account.GetCredential("aws_region")
if region == "" {
region = defaultBedrockRegion
}
sessionToken := account.GetCredential("aws_session_token") // 可选
return NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region), nil
}
// SignRequest 对 HTTP 请求进行 SigV4 签名
// 重要约束调用此方法前req 应只包含 AWS 相关的 header如 Content-Type、Accept
// 非 AWS header如 anthropic-beta会参与签名计算如果 Bedrock 服务端不识别这些 header
// 签名验证可能失败。litellm 通过 _filter_headers_for_aws_signature 实现头过滤,
// 当前实现中 buildUpstreamRequestBedrock 仅设置了 Content-Type 和 Accept因此是安全的。
func (s *BedrockSigner) SignRequest(ctx context.Context, req *http.Request, body []byte) error {
payloadHash := sha256Hash(body)
return s.signer.SignHTTP(ctx, s.credentials, req, payloadHash, "bedrock", s.region, time.Now())
}
func sha256Hash(data []byte) string {
h := sha256.Sum256(data)
return hex.EncodeToString(h[:])
}

View File

@@ -0,0 +1,35 @@
package service
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestNewBedrockSignerFromAccount_DefaultRegion(t *testing.T) {
account := &Account{
Platform: PlatformAnthropic,
Type: AccountTypeBedrock,
Credentials: map[string]any{
"aws_access_key_id": "test-akid",
"aws_secret_access_key": "test-secret",
},
}
signer, err := NewBedrockSignerFromAccount(account)
require.NoError(t, err)
require.NotNil(t, signer)
assert.Equal(t, defaultBedrockRegion, signer.region)
}
func TestFilterBetaTokens(t *testing.T) {
tokens := []string{"interleaved-thinking-2025-05-14", "tool-search-tool-2025-10-19"}
filterSet := map[string]struct{}{
"tool-search-tool-2025-10-19": {},
}
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, filterBetaTokens(tokens, filterSet))
assert.Equal(t, tokens, filterBetaTokens(tokens, nil))
assert.Nil(t, filterBetaTokens(nil, filterSet))
}

View File

@@ -0,0 +1,414 @@
package service
import (
"bufio"
"context"
"encoding/base64"
"errors"
"fmt"
"hash/crc32"
"io"
"net/http"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// handleBedrockStreamingResponse 处理 Bedrock InvokeModelWithResponseStream 的 EventStream 响应
// Bedrock 返回 AWS EventStream 二进制格式,每个事件的 payload 中 chunk.bytes 是 base64 编码的
// Claude SSE 事件 JSON。本方法解码后转换为标准 SSE 格式写入客户端。
func (s *GatewayService) handleBedrockStreamingResponse(
ctx context.Context,
resp *http.Response,
c *gin.Context,
account *Account,
startTime time.Time,
model string,
) (*streamingResult, error) {
w := c.Writer
flusher, ok := w.(http.Flusher)
if !ok {
return nil, errors.New("streaming not supported")
}
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("X-Accel-Buffering", "no")
if v := resp.Header.Get("x-amzn-requestid"); v != "" {
c.Header("x-request-id", v)
}
usage := &ClaudeUsage{}
var firstTokenMs *int
clientDisconnected := false
// Bedrock EventStream 使用 application/vnd.amazon.eventstream 二进制格式。
// 每个帧结构total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4)
// 但更实用的方式是使用行扫描找 JSON chunks因为 Bedrock 的响应在二进制帧中。
// 我们使用 EventStream decoder 来正确解析。
decoder := newBedrockEventStreamDecoder(resp.Body)
type decodeEvent struct {
payload []byte
err error
}
events := make(chan decodeEvent, 16)
done := make(chan struct{})
sendEvent := func(ev decodeEvent) bool {
select {
case events <- ev:
return true
case <-done:
return false
}
}
var lastReadAt atomic.Int64
lastReadAt.Store(time.Now().UnixNano())
go func() {
defer close(events)
for {
payload, err := decoder.Decode()
if err != nil {
if err == io.EOF {
return
}
_ = sendEvent(decodeEvent{err: err})
return
}
lastReadAt.Store(time.Now().UnixNano())
if !sendEvent(decodeEvent{payload: payload}) {
return
}
}
}()
defer close(done)
streamInterval := time.Duration(0)
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
}
var intervalTicker *time.Ticker
if streamInterval > 0 {
intervalTicker = time.NewTicker(streamInterval)
defer intervalTicker.Stop()
}
var intervalCh <-chan time.Time
if intervalTicker != nil {
intervalCh = intervalTicker.C
}
for {
select {
case ev, ok := <-events:
if !ok {
if !clientDisconnected {
flusher.Flush()
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
}
if ev.err != nil {
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err)
}
// payload 是 JSON提取 chunk.bytesbase64 编码的 Claude SSE 事件数据)
sseData := extractBedrockChunkData(ev.payload)
if sseData == nil {
continue
}
if firstTokenMs == nil {
ms := int(time.Since(startTime).Milliseconds())
firstTokenMs = &ms
}
// 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式
// 同时移除该字段避免透传给客户端
sseData = transformBedrockInvocationMetrics(sseData)
// 解析 SSE 事件数据提取 usage
s.parseSSEUsagePassthrough(string(sseData), usage)
// 确定 SSE event type
eventType := gjson.GetBytes(sseData, "type").String()
// 写入标准 SSE 格式
if !clientDisconnected {
var writeErr error
if eventType != "" {
_, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData)
} else {
_, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData)
}
if writeErr != nil {
clientDisconnected = true
logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID)
} else {
flusher.Flush()
}
}
case <-intervalCh:
lastRead := time.Unix(0, lastReadAt.Load())
if time.Since(lastRead) < streamInterval {
continue
}
if clientDisconnected {
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
}
logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
if s.rateLimitService != nil {
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
}
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
}
}
}
// extractBedrockChunkData 从 Bedrock EventStream payload 中提取 Claude SSE 事件数据
// Bedrock payload 格式:{"bytes":"<base64-encoded-json>"}
func extractBedrockChunkData(payload []byte) []byte {
b64 := gjson.GetBytes(payload, "bytes").String()
if b64 == "" {
return nil
}
decoded, err := base64.StdEncoding.DecodeString(b64)
if err != nil {
return nil
}
return decoded
}
// transformBedrockInvocationMetrics 将 Bedrock 特有的 amazon-bedrock-invocationMetrics
// 转换为标准 Anthropic usage 格式,并从 SSE 数据中移除该字段。
//
// Bedrock Invoke 返回的 message_delta 事件可能包含:
//
// {"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}
//
// 转换为:
//
// {"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}}
func transformBedrockInvocationMetrics(data []byte) []byte {
metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics")
if !metrics.Exists() || !metrics.IsObject() {
return data
}
// 移除 Bedrock 特有字段
data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics")
// 如果已有标准 usage 字段,不覆盖
if gjson.GetBytes(data, "usage").Exists() {
return data
}
// 转换 camelCase → snake_case 写入 usage
inputTokens := metrics.Get("inputTokenCount")
outputTokens := metrics.Get("outputTokenCount")
if inputTokens.Exists() {
data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int())
}
if outputTokens.Exists() {
data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int())
}
return data
}
// bedrockEventStreamDecoder 解码 AWS EventStream 二进制帧
// EventStream 帧格式:
//
// [total_byte_length: 4 bytes]
// [headers_byte_length: 4 bytes]
// [prelude_crc: 4 bytes]
// [headers: variable]
// [payload: variable]
// [message_crc: 4 bytes]
type bedrockEventStreamDecoder struct {
reader *bufio.Reader
}
func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder {
return &bedrockEventStreamDecoder{
reader: bufio.NewReaderSize(r, 64*1024),
}
}
// Decode 读取下一个 EventStream 帧并返回 chunk 类型事件的 payload
func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) {
for {
// 读取 prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes
prelude := make([]byte, 12)
if _, err := io.ReadFull(d.reader, prelude); err != nil {
return nil, err
}
// 验证 prelude CRCAWS EventStream 使用标准 CRC32 / IEEE
preludeCRC := bedrockReadUint32(prelude[8:12])
if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC {
return nil, fmt.Errorf("eventstream prelude CRC mismatch")
}
totalLength := bedrockReadUint32(prelude[0:4])
headersLength := bedrockReadUint32(prelude[4:8])
if totalLength < 16 { // minimum: 12 prelude + 4 message_crc
return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength)
}
// 读取 headers + payload + message_crc
remaining := int(totalLength) - 12
if remaining <= 0 {
continue
}
data := make([]byte, remaining)
if _, err := io.ReadFull(d.reader, data); err != nil {
return nil, err
}
// 验证 message CRC覆盖 prelude + headers + payload
messageCRC := bedrockReadUint32(data[len(data)-4:])
h := crc32.New(crc32IEEETable)
_, _ = h.Write(prelude)
_, _ = h.Write(data[:len(data)-4])
if h.Sum32() != messageCRC {
return nil, fmt.Errorf("eventstream message CRC mismatch")
}
// 解析 headers
headers := data[:headersLength]
payload := data[headersLength : len(data)-4] // 去掉 message_crc
// 从 headers 中提取 :event-type
eventType := extractEventStreamHeaderValue(headers, ":event-type")
// 只处理 chunk 事件
if eventType == "chunk" {
// payload 是完整的 JSON包含 bytes 字段
return payload, nil
}
// 检查异常事件
exceptionType := extractEventStreamHeaderValue(headers, ":exception-type")
if exceptionType != "" {
return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload))
}
messageType := extractEventStreamHeaderValue(headers, ":message-type")
if messageType == "exception" || messageType == "error" {
return nil, fmt.Errorf("bedrock error: %s", string(payload))
}
// 跳过其他事件类型(如 initial-response
}
}
// extractEventStreamHeaderValue 从 EventStream headers 二进制数据中提取指定 header 的字符串值
// EventStream header 格式:
//
// [name_length: 1 byte][name: variable][value_type: 1 byte][value: variable]
//
// value_type = 7 表示 string 类型,前 2 bytes 为长度
func extractEventStreamHeaderValue(headers []byte, targetName string) string {
pos := 0
for pos < len(headers) {
if pos >= len(headers) {
break
}
nameLen := int(headers[pos])
pos++
if pos+nameLen > len(headers) {
break
}
name := string(headers[pos : pos+nameLen])
pos += nameLen
if pos >= len(headers) {
break
}
valueType := headers[pos]
pos++
switch valueType {
case 7: // string
if pos+2 > len(headers) {
return ""
}
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
pos += 2
if pos+valueLen > len(headers) {
return ""
}
value := string(headers[pos : pos+valueLen])
pos += valueLen
if name == targetName {
return value
}
case 0: // bool true
if name == targetName {
return "true"
}
case 1: // bool false
if name == targetName {
return "false"
}
case 2: // byte
pos++
if name == targetName {
return ""
}
case 3: // short
pos += 2
if name == targetName {
return ""
}
case 4: // int
pos += 4
if name == targetName {
return ""
}
case 5: // long
pos += 8
if name == targetName {
return ""
}
case 6: // bytes
if pos+2 > len(headers) {
return ""
}
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
pos += 2 + valueLen
case 8: // timestamp
pos += 8
case 9: // uuid
pos += 16
default:
return "" // 未知类型,无法继续解析
}
}
return ""
}
// crc32IEEETable is the CRC32 / IEEE table used by AWS EventStream.
var crc32IEEETable = crc32.MakeTable(crc32.IEEE)
func bedrockReadUint32(b []byte) uint32 {
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
}
func bedrockReadUint16(b []byte) uint16 {
return uint16(b[0])<<8 | uint16(b[1])
}

View File

@@ -0,0 +1,261 @@
package service
import (
"bytes"
"encoding/base64"
"encoding/binary"
"hash/crc32"
"io"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
)
func TestExtractBedrockChunkData(t *testing.T) {
t.Run("valid base64 payload", func(t *testing.T) {
original := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`
b64 := base64.StdEncoding.EncodeToString([]byte(original))
payload := []byte(`{"bytes":"` + b64 + `"}`)
result := extractBedrockChunkData(payload)
require.NotNil(t, result)
assert.JSONEq(t, original, string(result))
})
t.Run("empty bytes field", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"bytes":""}`))
assert.Nil(t, result)
})
t.Run("no bytes field", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"other":"value"}`))
assert.Nil(t, result)
})
t.Run("invalid base64", func(t *testing.T) {
result := extractBedrockChunkData([]byte(`{"bytes":"not-valid-base64!!!"}`))
assert.Nil(t, result)
})
}
func TestTransformBedrockInvocationMetrics(t *testing.T) {
t.Run("converts metrics to usage", func(t *testing.T) {
input := `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
result := transformBedrockInvocationMetrics([]byte(input))
// amazon-bedrock-invocationMetrics should be removed
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
// usage should be set
assert.Equal(t, int64(150), gjson.GetBytes(result, "usage.input_tokens").Int())
assert.Equal(t, int64(42), gjson.GetBytes(result, "usage.output_tokens").Int())
// original fields preserved
assert.Equal(t, "message_delta", gjson.GetBytes(result, "type").String())
assert.Equal(t, "end_turn", gjson.GetBytes(result, "delta.stop_reason").String())
})
t.Run("no metrics present", func(t *testing.T) {
input := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}`
result := transformBedrockInvocationMetrics([]byte(input))
assert.JSONEq(t, input, string(result))
})
t.Run("does not overwrite existing usage", func(t *testing.T) {
input := `{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
result := transformBedrockInvocationMetrics([]byte(input))
// metrics removed but existing usage preserved
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
assert.Equal(t, int64(100), gjson.GetBytes(result, "usage.output_tokens").Int())
})
}
func TestExtractEventStreamHeaderValue(t *testing.T) {
// Build a header with :event-type = "chunk" (string type = 7)
buildStringHeader := func(name, value string) []byte {
var buf bytes.Buffer
// name length (1 byte)
_ = buf.WriteByte(byte(len(name)))
// name
_, _ = buf.WriteString(name)
// value type (7 = string)
_ = buf.WriteByte(7)
// value length (2 bytes, big-endian)
_ = binary.Write(&buf, binary.BigEndian, uint16(len(value)))
// value
_, _ = buf.WriteString(value)
return buf.Bytes()
}
t.Run("find string header", func(t *testing.T) {
headers := buildStringHeader(":event-type", "chunk")
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
})
t.Run("header not found", func(t *testing.T) {
headers := buildStringHeader(":event-type", "chunk")
assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type"))
})
t.Run("multiple headers", func(t *testing.T) {
var buf bytes.Buffer
_, _ = buf.Write(buildStringHeader(":content-type", "application/json"))
_, _ = buf.Write(buildStringHeader(":event-type", "chunk"))
_, _ = buf.Write(buildStringHeader(":message-type", "event"))
headers := buf.Bytes()
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type"))
assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type"))
})
t.Run("empty headers", func(t *testing.T) {
assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type"))
})
}
func TestBedrockEventStreamDecoder(t *testing.T) {
crc32IeeeTab := crc32.MakeTable(crc32.IEEE)
// Build a valid EventStream frame with correct CRC32/IEEE checksums.
buildFrame := func(eventType string, payload []byte) []byte {
// Build headers
var headersBuf bytes.Buffer
// :event-type header
_ = headersBuf.WriteByte(byte(len(":event-type")))
_, _ = headersBuf.WriteString(":event-type")
_ = headersBuf.WriteByte(7) // string type
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType)))
_, _ = headersBuf.WriteString(eventType)
// :message-type header
_ = headersBuf.WriteByte(byte(len(":message-type")))
_, _ = headersBuf.WriteString(":message-type")
_ = headersBuf.WriteByte(7)
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event")))
_, _ = headersBuf.WriteString("event")
headers := headersBuf.Bytes()
headersLen := uint32(len(headers))
// total = 12 (prelude) + headers + payload + 4 (message_crc)
totalLen := uint32(12 + len(headers) + len(payload) + 4)
// Prelude: total_length(4) + headers_length(4)
var preludeBuf bytes.Buffer
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
preludeBytes := preludeBuf.Bytes()
preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab)
// Build frame: prelude + prelude_crc + headers + payload
var frame bytes.Buffer
_, _ = frame.Write(preludeBytes)
_ = binary.Write(&frame, binary.BigEndian, preludeCRC)
_, _ = frame.Write(headers)
_, _ = frame.Write(payload)
// Message CRC covers everything before itself
messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab)
_ = binary.Write(&frame, binary.BigEndian, messageCRC)
return frame.Bytes()
}
t.Run("decode chunk event", func(t *testing.T) {
payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test")
frame := buildFrame("chunk", payload)
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
result, err := decoder.Decode()
require.NoError(t, err)
assert.Equal(t, payload, result)
})
t.Run("skip non-chunk events", func(t *testing.T) {
// Write initial-response followed by chunk
var buf bytes.Buffer
_, _ = buf.Write(buildFrame("initial-response", []byte(`{}`)))
chunkPayload := []byte(`{"bytes":"aGVsbG8="}`)
_, _ = buf.Write(buildFrame("chunk", chunkPayload))
decoder := newBedrockEventStreamDecoder(&buf)
result, err := decoder.Decode()
require.NoError(t, err)
assert.Equal(t, chunkPayload, result)
})
t.Run("EOF on empty input", func(t *testing.T) {
decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil))
_, err := decoder.Decode()
assert.Equal(t, io.EOF, err)
})
t.Run("corrupted prelude CRC", func(t *testing.T) {
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
// Corrupt the prelude CRC (bytes 8-11)
frame[8] ^= 0xFF
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "prelude CRC mismatch")
})
t.Run("corrupted message CRC", func(t *testing.T) {
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
// Corrupt the message CRC (last 4 bytes)
frame[len(frame)-1] ^= 0xFF
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "message CRC mismatch")
})
t.Run("castagnoli encoded frame is rejected", func(t *testing.T) {
castagnoliTab := crc32.MakeTable(crc32.Castagnoli)
payload := []byte(`{"bytes":"dGVzdA=="}`)
var headersBuf bytes.Buffer
_ = headersBuf.WriteByte(byte(len(":event-type")))
_, _ = headersBuf.WriteString(":event-type")
_ = headersBuf.WriteByte(7)
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk")))
_, _ = headersBuf.WriteString("chunk")
headers := headersBuf.Bytes()
headersLen := uint32(len(headers))
totalLen := uint32(12 + len(headers) + len(payload) + 4)
var preludeBuf bytes.Buffer
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
preludeBytes := preludeBuf.Bytes()
var frame bytes.Buffer
_, _ = frame.Write(preludeBytes)
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab))
_, _ = frame.Write(headers)
_, _ = frame.Write(payload)
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab))
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes()))
_, err := decoder.Decode()
require.Error(t, err)
assert.Contains(t, err.Error(), "prelude CRC mismatch")
})
}
func TestBuildBedrockURL(t *testing.T) {
t.Run("stream URL with colon in model ID", func(t *testing.T) {
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-opus-4-5-20251101-v1:0", true)
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream", url)
})
t.Run("non-stream URL with colon in model ID", func(t *testing.T) {
url := BuildBedrockURL("eu-west-1", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", false)
assert.Equal(t, "https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", url)
})
t.Run("model ID without colon", func(t *testing.T) {
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-sonnet-4-6", true)
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream", url)
})
}

View File

@@ -0,0 +1,15 @@
package service
import (
"time"
)
// SubscriptionCacheData represents cached subscription data
type SubscriptionCacheData struct {
Status string
ExpiresAt time.Time
DailyUsage float64
WeeklyUsage float64
MonthlyUsage float64
Version int64
}

View File

@@ -0,0 +1,861 @@
package service
import (
"context"
"fmt"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"golang.org/x/sync/singleflight"
)
// 错误定义
// 注ErrInsufficientBalance在redeem_service.go中定义
// 注ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
var (
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
)
// subscriptionCacheData 订阅缓存数据结构(内部使用)
type subscriptionCacheData struct {
Status string
ExpiresAt time.Time
DailyUsage float64
WeeklyUsage float64
MonthlyUsage float64
Version int64
}
// 缓存写入任务类型
type cacheWriteKind int
const (
cacheWriteSetBalance cacheWriteKind = iota
cacheWriteSetSubscription
cacheWriteUpdateSubscriptionUsage
cacheWriteDeductBalance
cacheWriteUpdateRateLimitUsage
)
// 异步缓存写入工作池配置
//
// 性能优化说明:
// 原实现在请求热路径中使用 goroutine 异步更新缓存,存在以下问题:
// 1. 每次请求创建新 goroutine高并发下产生大量短生命周期 goroutine
// 2. 无法控制并发数量,可能导致 Redis 连接耗尽
// 3. goroutine 创建/销毁带来额外开销
//
// 新实现使用固定大小的工作池:
// 1. 预创建 10 个 worker goroutine避免频繁创建销毁
// 2. 使用带缓冲的 channel1000作为任务队列平滑写入峰值
// 3. 非阻塞写入,队列满时关键任务同步回退,非关键任务丢弃并告警
// 4. 统一超时控制,避免慢操作阻塞工作池
const (
cacheWriteWorkerCount = 10 // 工作协程数量
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
balanceLoadTimeout = 3 * time.Second
)
// cacheWriteTask 缓存写入任务
type cacheWriteTask struct {
kind cacheWriteKind
userID int64
groupID int64
apiKeyID int64
balance float64
amount float64
subscriptionData *subscriptionCacheData
}
// apiKeyRateLimitLoader defines the interface for loading rate limit data from DB.
type apiKeyRateLimitLoader interface {
GetRateLimitData(ctx context.Context, keyID int64) (*APIKeyRateLimitData, error)
}
// BillingCacheService 计费缓存服务
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
type BillingCacheService struct {
cache BillingCache
userRepo UserRepository
subRepo UserSubscriptionRepository
apiKeyRateLimitLoader apiKeyRateLimitLoader
cfg *config.Config
circuitBreaker *billingCircuitBreaker
cacheWriteChan chan cacheWriteTask
cacheWriteWg sync.WaitGroup
cacheWriteStopOnce sync.Once
cacheWriteMu sync.RWMutex
stopped atomic.Bool
balanceLoadSF singleflight.Group
// 丢弃日志节流计数器(减少高负载下日志噪音)
cacheWriteDropFullCount uint64
cacheWriteDropFullLastLog int64
cacheWriteDropClosedCount uint64
cacheWriteDropClosedLastLog int64
}
// NewBillingCacheService 创建计费缓存服务
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService {
svc := &BillingCacheService{
cache: cache,
userRepo: userRepo,
subRepo: subRepo,
apiKeyRateLimitLoader: apiKeyRepo,
cfg: cfg,
}
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
svc.startCacheWriteWorkers()
return svc
}
// Stop 关闭缓存写入工作池
func (s *BillingCacheService) Stop() {
s.cacheWriteStopOnce.Do(func() {
s.stopped.Store(true)
s.cacheWriteMu.Lock()
ch := s.cacheWriteChan
if ch != nil {
close(ch)
}
s.cacheWriteMu.Unlock()
if ch == nil {
return
}
s.cacheWriteWg.Wait()
s.cacheWriteMu.Lock()
if s.cacheWriteChan == ch {
s.cacheWriteChan = nil
}
s.cacheWriteMu.Unlock()
})
}
func (s *BillingCacheService) startCacheWriteWorkers() {
ch := make(chan cacheWriteTask, cacheWriteBufferSize)
s.cacheWriteChan = ch
for i := 0; i < cacheWriteWorkerCount; i++ {
s.cacheWriteWg.Add(1)
go s.cacheWriteWorker(ch)
}
}
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false并记录告警
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
if s.stopped.Load() {
s.logCacheWriteDrop(task, "closed")
return false
}
s.cacheWriteMu.RLock()
defer s.cacheWriteMu.RUnlock()
if s.cacheWriteChan == nil {
s.logCacheWriteDrop(task, "closed")
return false
}
select {
case s.cacheWriteChan <- task:
return true
default:
// 队列满时不阻塞主流程,交由调用方决定是否同步回退。
s.logCacheWriteDrop(task, "full")
return false
}
}
func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
defer s.cacheWriteWg.Done()
for task := range ch {
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
switch task.kind {
case cacheWriteSetBalance:
s.setBalanceCache(ctx, task.userID, task.balance)
case cacheWriteSetSubscription:
s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData)
case cacheWriteUpdateSubscriptionUsage:
if s.cache != nil {
if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
}
}
case cacheWriteDeductBalance:
if s.cache != nil {
if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err)
}
}
case cacheWriteUpdateRateLimitUsage:
if s.cache != nil {
if err := s.cache.UpdateAPIKeyRateLimitUsage(ctx, task.apiKeyID, task.amount); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: update rate limit usage cache failed for api key %d: %v", task.apiKeyID, err)
}
}
}
cancel()
}
}
// cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。
func cacheWriteKindName(kind cacheWriteKind) string {
switch kind {
case cacheWriteSetBalance:
return "set_balance"
case cacheWriteSetSubscription:
return "set_subscription"
case cacheWriteUpdateSubscriptionUsage:
return "update_subscription_usage"
case cacheWriteDeductBalance:
return "deduct_balance"
case cacheWriteUpdateRateLimitUsage:
return "update_rate_limit_usage"
default:
return "unknown"
}
}
// logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。
func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) {
var (
countPtr *uint64
lastPtr *int64
)
switch reason {
case "full":
countPtr = &s.cacheWriteDropFullCount
lastPtr = &s.cacheWriteDropFullLastLog
case "closed":
countPtr = &s.cacheWriteDropClosedCount
lastPtr = &s.cacheWriteDropClosedLastLog
default:
return
}
atomic.AddUint64(countPtr, 1)
now := time.Now().UnixNano()
last := atomic.LoadInt64(lastPtr)
if now-last < int64(cacheWriteDropLogInterval) {
return
}
if !atomic.CompareAndSwapInt64(lastPtr, last, now) {
return
}
dropped := atomic.SwapUint64(countPtr, 0)
if dropped == 0 {
return
}
logger.LegacyPrintf("service.billing_cache", "Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
reason,
dropped,
cacheWriteDropLogInterval,
cacheWriteKindName(task.kind),
task.userID,
task.groupID,
)
}
// ============================================
// 余额缓存方法
// ============================================
// GetUserBalance 获取用户余额(优先从缓存读取)
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
if s.cache == nil {
// Redis不可用直接查询数据库
return s.getUserBalanceFromDB(ctx, userID)
}
// 尝试从缓存读取
balance, err := s.cache.GetUserBalance(ctx, userID)
if err == nil {
return balance, nil
}
// 缓存未命中singleflight 合并同一 userID 的并发回源请求。
value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) {
loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout)
defer cancel()
balance, err := s.getUserBalanceFromDB(loadCtx, userID)
if err != nil {
return nil, err
}
// 异步建立缓存
_ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetBalance,
userID: userID,
balance: balance,
})
return balance, nil
})
if err != nil {
return 0, err
}
balance, ok := value.(float64)
if !ok {
return 0, fmt.Errorf("unexpected balance type: %T", value)
}
return balance, nil
}
// getUserBalanceFromDB 从数据库获取用户余额
func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID int64) (float64, error) {
user, err := s.userRepo.GetByID(ctx, userID)
if err != nil {
return 0, fmt.Errorf("get user balance: %w", err)
}
return user.Balance, nil
}
// setBalanceCache 设置余额缓存
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
if s.cache == nil {
return
}
if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: set balance cache failed for user %d: %v", userID, err)
}
}
// DeductBalanceCache 扣减余额缓存(同步调用)
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
if s.cache == nil {
return nil
}
return s.cache.DeductUserBalance(ctx, userID, amount)
}
// QueueDeductBalance 异步扣减余额缓存
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
if s.cache == nil {
return
}
// 队列满时同步回退,避免关键扣减被静默丢弃。
if s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteDeductBalance,
userID: userID,
amount: amount,
}) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.DeductBalanceCache(ctx, userID, amount); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
}
}
// InvalidateUserBalance 失效用户余额缓存
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
if s.cache == nil {
return nil
}
if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate balance cache failed for user %d: %v", userID, err)
return err
}
return nil
}
// ============================================
// 订阅缓存方法
// ============================================
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
if s.cache == nil {
return s.getSubscriptionFromDB(ctx, userID, groupID)
}
// 尝试从缓存读取
cacheData, err := s.cache.GetSubscriptionCache(ctx, userID, groupID)
if err == nil && cacheData != nil {
return s.convertFromPortsData(cacheData), nil
}
// 缓存未命中,从数据库读取
data, err := s.getSubscriptionFromDB(ctx, userID, groupID)
if err != nil {
return nil, err
}
// 异步建立缓存
_ = s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteSetSubscription,
userID: userID,
groupID: groupID,
subscriptionData: data,
})
return data, nil
}
func (s *BillingCacheService) convertFromPortsData(data *SubscriptionCacheData) *subscriptionCacheData {
return &subscriptionCacheData{
Status: data.Status,
ExpiresAt: data.ExpiresAt,
DailyUsage: data.DailyUsage,
WeeklyUsage: data.WeeklyUsage,
MonthlyUsage: data.MonthlyUsage,
Version: data.Version,
}
}
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *SubscriptionCacheData {
return &SubscriptionCacheData{
Status: data.Status,
ExpiresAt: data.ExpiresAt,
DailyUsage: data.DailyUsage,
WeeklyUsage: data.WeeklyUsage,
MonthlyUsage: data.MonthlyUsage,
Version: data.Version,
}
}
// getSubscriptionFromDB 从数据库获取订阅数据
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
if err != nil {
return nil, fmt.Errorf("get subscription: %w", err)
}
return &subscriptionCacheData{
Status: sub.Status,
ExpiresAt: sub.ExpiresAt,
DailyUsage: sub.DailyUsageUSD,
WeeklyUsage: sub.WeeklyUsageUSD,
MonthlyUsage: sub.MonthlyUsageUSD,
Version: sub.UpdatedAt.Unix(),
}, nil
}
// setSubscriptionCache 设置订阅缓存
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
if s.cache == nil || data == nil {
return
}
if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
}
}
// UpdateSubscriptionUsage 更新订阅用量缓存(同步调用)
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
if s.cache == nil {
return nil
}
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
}
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
if s.cache == nil {
return
}
// 队列满时同步回退,确保订阅用量及时更新。
if s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteUpdateSubscriptionUsage,
userID: userID,
groupID: groupID,
amount: costUSD,
}) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
}
}
// InvalidateSubscription 失效指定订阅缓存
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
if s.cache == nil {
return nil
}
if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
return err
}
return nil
}
// ============================================
// API Key 限速缓存方法
// ============================================
// checkAPIKeyRateLimits checks rate limit windows for an API key.
// It loads usage from Redis cache (falling back to DB on cache miss),
// resets expired windows in-memory and triggers async DB reset,
// and returns an error if any window limit is exceeded.
func (s *BillingCacheService) checkAPIKeyRateLimits(ctx context.Context, apiKey *APIKey) error {
if s.cache == nil {
// No cache: fall back to reading from DB directly
if s.apiKeyRateLimitLoader == nil {
return nil
}
data, err := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
if err != nil {
return nil // Don't block requests on DB errors
}
return s.evaluateRateLimits(ctx, apiKey, data.Usage5h, data.Usage1d, data.Usage7d,
data.Window5hStart, data.Window1dStart, data.Window7dStart)
}
cacheData, err := s.cache.GetAPIKeyRateLimit(ctx, apiKey.ID)
if err != nil {
// Cache miss: load from DB and populate cache
if s.apiKeyRateLimitLoader == nil {
return nil
}
dbData, dbErr := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
if dbErr != nil {
return nil // Don't block requests on DB errors
}
// Build cache entry from DB data
cacheEntry := &APIKeyRateLimitCacheData{
Usage5h: dbData.Usage5h,
Usage1d: dbData.Usage1d,
Usage7d: dbData.Usage7d,
}
if dbData.Window5hStart != nil {
cacheEntry.Window5h = dbData.Window5hStart.Unix()
}
if dbData.Window1dStart != nil {
cacheEntry.Window1d = dbData.Window1dStart.Unix()
}
if dbData.Window7dStart != nil {
cacheEntry.Window7d = dbData.Window7dStart.Unix()
}
_ = s.cache.SetAPIKeyRateLimit(ctx, apiKey.ID, cacheEntry)
cacheData = cacheEntry
}
var w5h, w1d, w7d *time.Time
if cacheData.Window5h > 0 {
t := time.Unix(cacheData.Window5h, 0)
w5h = &t
}
if cacheData.Window1d > 0 {
t := time.Unix(cacheData.Window1d, 0)
w1d = &t
}
if cacheData.Window7d > 0 {
t := time.Unix(cacheData.Window7d, 0)
w7d = &t
}
return s.evaluateRateLimits(ctx, apiKey, cacheData.Usage5h, cacheData.Usage1d, cacheData.Usage7d, w5h, w1d, w7d)
}
// evaluateRateLimits checks usage against limits, triggering async resets for expired windows.
func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *APIKey, usage5h, usage1d, usage7d float64, w5h, w1d, w7d *time.Time) error {
needsReset := false
// Reset expired windows in-memory for check purposes
if IsWindowExpired(w5h, RateLimitWindow5h) {
usage5h = 0
needsReset = true
}
if IsWindowExpired(w1d, RateLimitWindow1d) {
usage1d = 0
needsReset = true
}
if IsWindowExpired(w7d, RateLimitWindow7d) {
usage7d = 0
needsReset = true
}
// Trigger async DB reset if any window expired
if needsReset {
keyID := apiKey.ID
go func() {
resetCtx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
defer cancel()
if s.apiKeyRateLimitLoader != nil {
// Use the repo directly - reset then reload cache
if loader, ok := s.apiKeyRateLimitLoader.(interface {
ResetRateLimitWindows(ctx context.Context, id int64) error
}); ok {
if err := loader.ResetRateLimitWindows(resetCtx, keyID); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: reset rate limit windows failed for api key %d: %v", keyID, err)
}
}
}
// Invalidate cache so next request loads fresh data
if s.cache != nil {
if err := s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID); err != nil {
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate rate limit cache failed for api key %d: %v", keyID, err)
}
}
}()
}
// Check limits
if apiKey.RateLimit5h > 0 && usage5h >= apiKey.RateLimit5h {
return ErrAPIKeyRateLimit5hExceeded
}
if apiKey.RateLimit1d > 0 && usage1d >= apiKey.RateLimit1d {
return ErrAPIKeyRateLimit1dExceeded
}
if apiKey.RateLimit7d > 0 && usage7d >= apiKey.RateLimit7d {
return ErrAPIKeyRateLimit7dExceeded
}
return nil
}
// QueueUpdateAPIKeyRateLimitUsage asynchronously updates rate limit usage in the cache.
func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, cost float64) {
if s.cache == nil {
return
}
s.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteUpdateRateLimitUsage,
apiKeyID: apiKeyID,
amount: cost,
})
}
// ============================================
// 统一检查方法
// ============================================
// CheckBillingEligibility 检查用户是否有资格发起请求
// 余额模式:检查缓存余额 > 0
// 订阅模式检查缓存用量未超过限额Group限额从参数传入
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
// 简易模式:跳过所有计费检查
if s.cfg.RunMode == config.RunModeSimple {
return nil
}
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
return ErrBillingServiceUnavailable
}
// 判断计费模式
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
if isSubscriptionMode {
if err := s.checkSubscriptionEligibility(ctx, user.ID, group, subscription); err != nil {
return err
}
} else {
if err := s.checkBalanceEligibility(ctx, user.ID); err != nil {
return err
}
}
// Check API Key rate limits (applies to both billing modes)
if apiKey != nil && apiKey.HasRateLimits() {
if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil {
return err
}
}
return nil
}
// checkBalanceEligibility 检查余额模式资格
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
balance, err := s.GetUserBalance(ctx, userID)
if err != nil {
if s.circuitBreaker != nil {
s.circuitBreaker.OnFailure(err)
}
logger.LegacyPrintf("service.billing_cache", "ALERT: billing balance check failed for user %d: %v", userID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
}
if balance <= 0 {
return ErrInsufficientBalance
}
return nil
}
// checkSubscriptionEligibility 检查订阅模式资格
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *Group, subscription *UserSubscription) error {
// 获取订阅缓存数据
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
if err != nil {
if s.circuitBreaker != nil {
s.circuitBreaker.OnFailure(err)
}
logger.LegacyPrintf("service.billing_cache", "ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
return ErrBillingServiceUnavailable.WithCause(err)
}
if s.circuitBreaker != nil {
s.circuitBreaker.OnSuccess()
}
// 检查订阅状态
if subData.Status != SubscriptionStatusActive {
return ErrSubscriptionInvalid
}
// 检查是否过期
if time.Now().After(subData.ExpiresAt) {
return ErrSubscriptionInvalid
}
// 检查限额使用传入的Group限额配置
if group.HasDailyLimit() && subData.DailyUsage >= *group.DailyLimitUSD {
return ErrDailyLimitExceeded
}
if group.HasWeeklyLimit() && subData.WeeklyUsage >= *group.WeeklyLimitUSD {
return ErrWeeklyLimitExceeded
}
if group.HasMonthlyLimit() && subData.MonthlyUsage >= *group.MonthlyLimitUSD {
return ErrMonthlyLimitExceeded
}
return nil
}
type billingCircuitBreakerState int
const (
billingCircuitClosed billingCircuitBreakerState = iota
billingCircuitOpen
billingCircuitHalfOpen
)
type billingCircuitBreaker struct {
mu sync.Mutex
state billingCircuitBreakerState
failures int
openedAt time.Time
failureThreshold int
resetTimeout time.Duration
halfOpenRequests int
halfOpenRemaining int
}
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
if !cfg.Enabled {
return nil
}
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
if resetTimeout <= 0 {
resetTimeout = 30 * time.Second
}
halfOpen := cfg.HalfOpenRequests
if halfOpen <= 0 {
halfOpen = 1
}
threshold := cfg.FailureThreshold
if threshold <= 0 {
threshold = 5
}
return &billingCircuitBreaker{
state: billingCircuitClosed,
failureThreshold: threshold,
resetTimeout: resetTimeout,
halfOpenRequests: halfOpen,
}
}
func (b *billingCircuitBreaker) Allow() bool {
b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitClosed:
return true
case billingCircuitOpen:
if time.Since(b.openedAt) < b.resetTimeout {
return false
}
b.state = billingCircuitHalfOpen
b.halfOpenRemaining = b.halfOpenRequests
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker entering half-open state")
fallthrough
case billingCircuitHalfOpen:
if b.halfOpenRemaining <= 0 {
return false
}
b.halfOpenRemaining--
return true
default:
return false
}
}
func (b *billingCircuitBreaker) OnFailure(err error) {
if b == nil {
return
}
b.mu.Lock()
defer b.mu.Unlock()
switch b.state {
case billingCircuitOpen:
return
case billingCircuitHalfOpen:
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after half-open failure: %v", err)
return
default:
b.failures++
if b.failures >= b.failureThreshold {
b.state = billingCircuitOpen
b.openedAt = time.Now()
b.halfOpenRemaining = 0
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
}
}
}
func (b *billingCircuitBreaker) OnSuccess() {
if b == nil {
return
}
b.mu.Lock()
defer b.mu.Unlock()
previousState := b.state
previousFailures := b.failures
b.state = billingCircuitClosed
b.failures = 0
b.halfOpenRemaining = 0
// 只有状态真正发生变化时才记录日志
if previousState != billingCircuitClosed {
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
} else if previousFailures > 0 {
logger.LegacyPrintf("service.billing_cache", "INFO: billing circuit breaker failures reset from %d", previousFailures)
}
}
func circuitStateString(state billingCircuitBreakerState) string {
switch state {
case billingCircuitClosed:
return "closed"
case billingCircuitOpen:
return "open"
case billingCircuitHalfOpen:
return "half-open"
default:
return "unknown"
}
}

View File

@@ -0,0 +1,131 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type billingCacheMissStub struct {
setBalanceCalls atomic.Int64
}
func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
return 0, errors.New("cache miss")
}
func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
s.setBalanceCalls.Add(1)
return nil
}
func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
return nil
}
func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
return nil
}
func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
return nil, errors.New("cache miss")
}
func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
return nil
}
func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
return nil
}
func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
return nil
}
func (s *billingCacheMissStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
return nil, errors.New("cache miss")
}
func (s *billingCacheMissStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
return nil
}
func (s *billingCacheMissStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
return nil
}
func (s *billingCacheMissStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
return nil
}
type balanceLoadUserRepoStub struct {
mockUserRepo
calls atomic.Int64
delay time.Duration
balance float64
}
func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
s.calls.Add(1)
if s.delay > 0 {
select {
case <-time.After(s.delay):
case <-ctx.Done():
return nil, ctx.Err()
}
}
return &User{ID: id, Balance: s.balance}, nil
}
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
cache := &billingCacheMissStub{}
userRepo := &balanceLoadUserRepoStub{
delay: 80 * time.Millisecond,
balance: 12.34,
}
svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
const goroutines = 16
start := make(chan struct{})
var wg sync.WaitGroup
errCh := make(chan error, goroutines)
balCh := make(chan float64, goroutines)
for i := 0; i < goroutines; i++ {
wg.Add(1)
go func() {
defer wg.Done()
<-start
bal, err := svc.GetUserBalance(context.Background(), 99)
errCh <- err
balCh <- bal
}()
}
close(start)
wg.Wait()
close(errCh)
close(balCh)
for err := range errCh {
require.NoError(t, err)
}
for bal := range balCh {
require.Equal(t, 12.34, bal)
}
require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并")
require.Eventually(t, func() bool {
return cache.setBalanceCalls.Load() >= 1
}, time.Second, 10*time.Millisecond)
}

View File

@@ -0,0 +1,104 @@
package service
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type billingCacheWorkerStub struct {
balanceUpdates int64
subscriptionUpdates int64
}
func (b *billingCacheWorkerStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
return 0, errors.New("not implemented")
}
func (b *billingCacheWorkerStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
atomic.AddInt64(&b.balanceUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
atomic.AddInt64(&b.balanceUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
return nil
}
func (b *billingCacheWorkerStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
return nil, errors.New("not implemented")
}
func (b *billingCacheWorkerStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
atomic.AddInt64(&b.subscriptionUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
atomic.AddInt64(&b.subscriptionUpdates, 1)
return nil
}
func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
return nil
}
func (b *billingCacheWorkerStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
return nil, errors.New("not implemented")
}
func (b *billingCacheWorkerStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
return nil
}
func (b *billingCacheWorkerStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
return nil
}
func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
return nil
}
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
t.Cleanup(svc.Stop)
start := time.Now()
for i := 0; i < cacheWriteBufferSize*2; i++ {
svc.QueueDeductBalance(1, 1)
}
require.Less(t, time.Since(start), 2*time.Second)
svc.QueueUpdateSubscriptionUsage(1, 2, 1.5)
require.Eventually(t, func() bool {
return atomic.LoadInt64(&cache.balanceUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
require.Eventually(t, func() bool {
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
}, 2*time.Second, 10*time.Millisecond)
}
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
cache := &billingCacheWorkerStub{}
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
svc.Stop()
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
kind: cacheWriteDeductBalance,
userID: 1,
amount: 1,
})
require.False(t, enqueued)
}

View File

@@ -0,0 +1,763 @@
package service
import (
"context"
"fmt"
"log"
"strings"
"github.com/Wei-Shaw/sub2api/internal/config"
)
// APIKeyRateLimitCacheData holds rate limit usage data cached in Redis.
type APIKeyRateLimitCacheData struct {
Usage5h float64 `json:"usage_5h"`
Usage1d float64 `json:"usage_1d"`
Usage7d float64 `json:"usage_7d"`
Window5h int64 `json:"window_5h"` // unix timestamp, 0 = not started
Window1d int64 `json:"window_1d"`
Window7d int64 `json:"window_7d"`
}
// BillingCache defines cache operations for billing service
type BillingCache interface {
// Balance operations
GetUserBalance(ctx context.Context, userID int64) (float64, error)
SetUserBalance(ctx context.Context, userID int64, balance float64) error
DeductUserBalance(ctx context.Context, userID int64, amount float64) error
InvalidateUserBalance(ctx context.Context, userID int64) error
// Subscription operations
GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error)
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
// API Key rate limit operations
GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error)
SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error
UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
}
// ModelPricing 模型价格配置per-token价格与LiteLLM格式一致
type ModelPricing struct {
InputPricePerToken float64 // 每token输入价格 (USD)
InputPricePerTokenPriority float64 // priority service tier 下每token输入价格 (USD)
OutputPricePerToken float64 // 每token输出价格 (USD)
OutputPricePerTokenPriority float64 // priority service tier 下每token输出价格 (USD)
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
CacheReadPricePerTokenPriority float64 // priority service tier 下缓存读取每token价格 (USD)
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
}
const (
openAIGPT54LongContextInputThreshold = 272000
openAIGPT54LongContextInputMultiplier = 2.0
openAIGPT54LongContextOutputMultiplier = 1.5
)
func normalizeBillingServiceTier(serviceTier string) string {
return strings.ToLower(strings.TrimSpace(serviceTier))
}
func usePriorityServiceTierPricing(serviceTier string, pricing *ModelPricing) bool {
if pricing == nil || normalizeBillingServiceTier(serviceTier) != "priority" {
return false
}
return pricing.InputPricePerTokenPriority > 0 || pricing.OutputPricePerTokenPriority > 0 || pricing.CacheReadPricePerTokenPriority > 0
}
func serviceTierCostMultiplier(serviceTier string) float64 {
switch normalizeBillingServiceTier(serviceTier) {
case "priority":
return 2.0
case "flex":
return 0.5
default:
return 1.0
}
}
// UsageTokens 使用的token数量
type UsageTokens struct {
InputTokens int
OutputTokens int
CacheCreationTokens int
CacheReadTokens int
CacheCreation5mTokens int
CacheCreation1hTokens int
}
// CostBreakdown 费用明细
type CostBreakdown struct {
InputCost float64
OutputCost float64
CacheCreationCost float64
CacheReadCost float64
TotalCost float64
ActualCost float64 // 应用倍率后的实际费用
}
// BillingService 计费服务
type BillingService struct {
cfg *config.Config
pricingService *PricingService
fallbackPrices map[string]*ModelPricing // 硬编码回退价格
}
// NewBillingService 创建计费服务实例
func NewBillingService(cfg *config.Config, pricingService *PricingService) *BillingService {
s := &BillingService{
cfg: cfg,
pricingService: pricingService,
fallbackPrices: make(map[string]*ModelPricing),
}
// 初始化硬编码回退价格(当动态价格不可用时使用)
s.initFallbackPricing()
return s
}
// initFallbackPricing 初始化硬编码回退价格(当动态价格不可用时使用)
// 价格单位USD per token与LiteLLM格式一致
func (s *BillingService) initFallbackPricing() {
// Claude 4.5 Opus
s.fallbackPrices["claude-opus-4.5"] = &ModelPricing{
InputPricePerToken: 5e-6, // $5 per MTok
OutputPricePerToken: 25e-6, // $25 per MTok
CacheCreationPricePerToken: 6.25e-6, // $6.25 per MTok
CacheReadPricePerToken: 0.5e-6, // $0.50 per MTok
SupportsCacheBreakdown: false,
}
// Claude 4 Sonnet
s.fallbackPrices["claude-sonnet-4"] = &ModelPricing{
InputPricePerToken: 3e-6, // $3 per MTok
OutputPricePerToken: 15e-6, // $15 per MTok
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
SupportsCacheBreakdown: false,
}
// Claude 3.5 Sonnet
s.fallbackPrices["claude-3-5-sonnet"] = &ModelPricing{
InputPricePerToken: 3e-6, // $3 per MTok
OutputPricePerToken: 15e-6, // $15 per MTok
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
SupportsCacheBreakdown: false,
}
// Claude 3.5 Haiku
s.fallbackPrices["claude-3-5-haiku"] = &ModelPricing{
InputPricePerToken: 1e-6, // $1 per MTok
OutputPricePerToken: 5e-6, // $5 per MTok
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
CacheReadPricePerToken: 0.1e-6, // $0.10 per MTok
SupportsCacheBreakdown: false,
}
// Claude 3 Opus
s.fallbackPrices["claude-3-opus"] = &ModelPricing{
InputPricePerToken: 15e-6, // $15 per MTok
OutputPricePerToken: 75e-6, // $75 per MTok
CacheCreationPricePerToken: 18.75e-6, // $18.75 per MTok
CacheReadPricePerToken: 1.5e-6, // $1.50 per MTok
SupportsCacheBreakdown: false,
}
// Claude 3 Haiku
s.fallbackPrices["claude-3-haiku"] = &ModelPricing{
InputPricePerToken: 0.25e-6, // $0.25 per MTok
OutputPricePerToken: 1.25e-6, // $1.25 per MTok
CacheCreationPricePerToken: 0.3e-6, // $0.30 per MTok
CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok
SupportsCacheBreakdown: false,
}
// Claude 4.6 Opus (与4.5同价)
s.fallbackPrices["claude-opus-4.6"] = s.fallbackPrices["claude-opus-4.5"]
// Gemini 3.1 Pro
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
InputPricePerToken: 2e-6, // $2 per MTok
OutputPricePerToken: 12e-6, // $12 per MTok
CacheCreationPricePerToken: 2e-6, // $2 per MTok
CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok
SupportsCacheBreakdown: false,
}
// OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
s.fallbackPrices["gpt-5.1"] = &ModelPricing{
InputPricePerToken: 1.25e-6, // $1.25 per MTok
InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok
OutputPricePerToken: 10e-6, // $10 per MTok
OutputPricePerTokenPriority: 20e-6, // $20 per MTok
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
CacheReadPricePerToken: 0.125e-6,
CacheReadPricePerTokenPriority: 0.25e-6,
SupportsCacheBreakdown: false,
}
// OpenAI GPT-5.4(业务指定价格)
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
InputPricePerToken: 2.5e-6, // $2.5 per MTok
InputPricePerTokenPriority: 5e-6, // $5 per MTok
OutputPricePerToken: 15e-6, // $15 per MTok
OutputPricePerTokenPriority: 30e-6, // $30 per MTok
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
CacheReadPricePerTokenPriority: 0.5e-6, // $0.5 per MTok
SupportsCacheBreakdown: false,
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
}
// OpenAI GPT-5.2(本地兜底)
s.fallbackPrices["gpt-5.2"] = &ModelPricing{
InputPricePerToken: 1.75e-6,
InputPricePerTokenPriority: 3.5e-6,
OutputPricePerToken: 14e-6,
OutputPricePerTokenPriority: 28e-6,
CacheCreationPricePerToken: 1.75e-6,
CacheReadPricePerToken: 0.175e-6,
CacheReadPricePerTokenPriority: 0.35e-6,
SupportsCacheBreakdown: false,
}
// Codex 族兜底统一按 GPT-5.1 Codex 价格计费
s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
InputPricePerToken: 1.5e-6, // $1.5 per MTok
InputPricePerTokenPriority: 3e-6, // $3 per MTok
OutputPricePerToken: 12e-6, // $12 per MTok
OutputPricePerTokenPriority: 24e-6, // $24 per MTok
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
CacheReadPricePerToken: 0.15e-6,
CacheReadPricePerTokenPriority: 0.3e-6,
SupportsCacheBreakdown: false,
}
s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{
InputPricePerToken: 1.75e-6,
InputPricePerTokenPriority: 3.5e-6,
OutputPricePerToken: 14e-6,
OutputPricePerTokenPriority: 28e-6,
CacheCreationPricePerToken: 1.75e-6,
CacheReadPricePerToken: 0.175e-6,
CacheReadPricePerTokenPriority: 0.35e-6,
SupportsCacheBreakdown: false,
}
s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
}
// getFallbackPricing 根据模型系列获取回退价格
func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
modelLower := strings.ToLower(model)
// 按模型系列匹配
if strings.Contains(modelLower, "opus") {
if strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6") {
return s.fallbackPrices["claude-opus-4.6"]
}
if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") {
return s.fallbackPrices["claude-opus-4.5"]
}
return s.fallbackPrices["claude-3-opus"]
}
if strings.Contains(modelLower, "sonnet") {
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
return s.fallbackPrices["claude-sonnet-4"]
}
return s.fallbackPrices["claude-3-5-sonnet"]
}
if strings.Contains(modelLower, "haiku") {
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
return s.fallbackPrices["claude-3-5-haiku"]
}
return s.fallbackPrices["claude-3-haiku"]
}
// Claude 未知型号统一回退到 Sonnet避免计费中断。
if strings.Contains(modelLower, "claude") {
return s.fallbackPrices["claude-sonnet-4"]
}
if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") {
return s.fallbackPrices["gemini-3.1-pro"]
}
// OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
normalized := normalizeCodexModel(modelLower)
switch normalized {
case "gpt-5.4":
return s.fallbackPrices["gpt-5.4"]
case "gpt-5.2":
return s.fallbackPrices["gpt-5.2"]
case "gpt-5.2-codex":
return s.fallbackPrices["gpt-5.2-codex"]
case "gpt-5.3-codex":
return s.fallbackPrices["gpt-5.3-codex"]
case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
return s.fallbackPrices["gpt-5.1-codex"]
case "gpt-5.1":
return s.fallbackPrices["gpt-5.1"]
}
}
return nil
}
// GetModelPricing 获取模型价格配置
func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
// 标准化模型名称(转小写)
model = strings.ToLower(model)
// 1. 优先从动态价格服务获取
if s.pricingService != nil {
litellmPricing := s.pricingService.GetModelPricing(model)
if litellmPricing != nil {
// 启用 5m/1h 分类计费的条件:
// 1. 存在 1h 价格
// 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费)
price5m := litellmPricing.CacheCreationInputTokenCost
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
enableBreakdown := price1h > 0 && price1h > price5m
return s.applyModelSpecificPricingPolicy(model, &ModelPricing{
InputPricePerToken: litellmPricing.InputCostPerToken,
InputPricePerTokenPriority: litellmPricing.InputCostPerTokenPriority,
OutputPricePerToken: litellmPricing.OutputCostPerToken,
OutputPricePerTokenPriority: litellmPricing.OutputCostPerTokenPriority,
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
CacheReadPricePerTokenPriority: litellmPricing.CacheReadInputTokenCostPriority,
CacheCreation5mPrice: price5m,
CacheCreation1hPrice: price1h,
SupportsCacheBreakdown: enableBreakdown,
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
}), nil
}
}
// 2. 使用硬编码回退价格
fallback := s.getFallbackPricing(model)
if fallback != nil {
log.Printf("[Billing] Using fallback pricing for model: %s", model)
return s.applyModelSpecificPricingPolicy(model, fallback), nil
}
return nil, fmt.Errorf("pricing not found for model: %s", model)
}
// CalculateCost 计算使用费用
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "")
}
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
pricing, err := s.GetModelPricing(model)
if err != nil {
return nil, err
}
breakdown := &CostBreakdown{}
inputPricePerToken := pricing.InputPricePerToken
outputPricePerToken := pricing.OutputPricePerToken
cacheReadPricePerToken := pricing.CacheReadPricePerToken
tierMultiplier := 1.0
if usePriorityServiceTierPricing(serviceTier, pricing) {
if pricing.InputPricePerTokenPriority > 0 {
inputPricePerToken = pricing.InputPricePerTokenPriority
}
if pricing.OutputPricePerTokenPriority > 0 {
outputPricePerToken = pricing.OutputPricePerTokenPriority
}
if pricing.CacheReadPricePerTokenPriority > 0 {
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority
}
} else {
tierMultiplier = serviceTierCostMultiplier(serviceTier)
}
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
inputPricePerToken *= pricing.LongContextInputMultiplier
outputPricePerToken *= pricing.LongContextOutputMultiplier
}
// 计算输入token费用使用per-token价格
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken
// 计算输出token费用
breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken
// 计算缓存费用
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
// 支持详细缓存分类的模型5分钟/1小时缓存价格为 per-token
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
} else {
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
}
} else {
// 标准缓存创建价格per-token
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
}
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken
if tierMultiplier != 1.0 {
breakdown.InputCost *= tierMultiplier
breakdown.OutputCost *= tierMultiplier
breakdown.CacheCreationCost *= tierMultiplier
breakdown.CacheReadCost *= tierMultiplier
}
// 计算总费用
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
breakdown.CacheCreationCost + breakdown.CacheReadCost
// 应用倍率计算实际费用
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
breakdown.ActualCost = breakdown.TotalCost * rateMultiplier
return breakdown, nil
}
func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing {
if pricing == nil {
return nil
}
if !isOpenAIGPT54Model(model) {
return pricing
}
if pricing.LongContextInputThreshold > 0 && pricing.LongContextInputMultiplier > 0 && pricing.LongContextOutputMultiplier > 0 {
return pricing
}
cloned := *pricing
if cloned.LongContextInputThreshold <= 0 {
cloned.LongContextInputThreshold = openAIGPT54LongContextInputThreshold
}
if cloned.LongContextInputMultiplier <= 0 {
cloned.LongContextInputMultiplier = openAIGPT54LongContextInputMultiplier
}
if cloned.LongContextOutputMultiplier <= 0 {
cloned.LongContextOutputMultiplier = openAIGPT54LongContextOutputMultiplier
}
return &cloned
}
func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens, pricing *ModelPricing) bool {
if pricing == nil || pricing.LongContextInputThreshold <= 0 {
return false
}
if pricing.LongContextInputMultiplier <= 1 && pricing.LongContextOutputMultiplier <= 1 {
return false
}
totalInputTokens := tokens.InputTokens + tokens.CacheReadTokens
return totalInputTokens > pricing.LongContextInputThreshold
}
func isOpenAIGPT54Model(model string) bool {
normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model)))
return normalized == "gpt-5.4"
}
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) {
multiplier := s.cfg.Default.RateMultiplier
if multiplier <= 0 {
multiplier = 1.0
}
return s.CalculateCost(model, tokens, multiplier)
}
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
// threshold: 阈值(如 200000超过此值的部分按 extraMultiplier 倍计费
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
//
// 示例:缓存 210k + 输入 10k = 220k阈值 200k倍率 2.0
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
// 范围内正常计费,范围外 × 2 计费
func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) {
// 未启用长上下文计费,直接走正常计费
if threshold <= 0 || extraMultiplier <= 1 {
return s.CalculateCost(model, tokens, rateMultiplier)
}
// 计算总输入 token缓存读取 + 新输入)
total := tokens.CacheReadTokens + tokens.InputTokens
if total <= threshold {
return s.CalculateCost(model, tokens, rateMultiplier)
}
// 拆分成范围内和范围外
var inRangeCacheTokens, inRangeInputTokens int
var outRangeCacheTokens, outRangeInputTokens int
if tokens.CacheReadTokens >= threshold {
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
inRangeCacheTokens = threshold
inRangeInputTokens = 0
outRangeCacheTokens = tokens.CacheReadTokens - threshold
outRangeInputTokens = tokens.InputTokens
} else {
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
inRangeCacheTokens = tokens.CacheReadTokens
inRangeInputTokens = threshold - tokens.CacheReadTokens
outRangeCacheTokens = 0
outRangeInputTokens = tokens.InputTokens - inRangeInputTokens
}
// 范围内部分:正常计费
inRangeTokens := UsageTokens{
InputTokens: inRangeInputTokens,
OutputTokens: tokens.OutputTokens, // 输出只算一次
CacheCreationTokens: tokens.CacheCreationTokens,
CacheReadTokens: inRangeCacheTokens,
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
}
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
if err != nil {
return nil, err
}
// 范围外部分:× extraMultiplier 计费
outRangeTokens := UsageTokens{
InputTokens: outRangeInputTokens,
CacheReadTokens: outRangeCacheTokens,
}
outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier)
if err != nil {
return inRangeCost, fmt.Errorf("out-range cost: %w", err)
}
// 合并成本
return &CostBreakdown{
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
OutputCost: inRangeCost.OutputCost,
CacheCreationCost: inRangeCost.CacheCreationCost,
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost,
}, nil
}
// ListSupportedModels 列出所有支持的模型现在总是返回true因为有模糊匹配
func (s *BillingService) ListSupportedModels() []string {
models := make([]string, 0)
// 返回回退价格支持的模型系列
for model := range s.fallbackPrices {
models = append(models, model)
}
return models
}
// IsModelSupported 检查模型是否支持现在总是返回true因为有模糊匹配回退
func (s *BillingService) IsModelSupported(model string) bool {
// 所有Claude模型都有回退价格支持
modelLower := strings.ToLower(model)
return strings.Contains(modelLower, "claude") ||
strings.Contains(modelLower, "opus") ||
strings.Contains(modelLower, "sonnet") ||
strings.Contains(modelLower, "haiku")
}
// GetEstimatedCost 估算费用(用于前端展示)
func (s *BillingService) GetEstimatedCost(model string, estimatedInputTokens, estimatedOutputTokens int) (float64, error) {
tokens := UsageTokens{
InputTokens: estimatedInputTokens,
OutputTokens: estimatedOutputTokens,
}
breakdown, err := s.CalculateCostWithConfig(model, tokens)
if err != nil {
return 0, err
}
return breakdown.ActualCost, nil
}
// GetPricingServiceStatus 获取价格服务状态
func (s *BillingService) GetPricingServiceStatus() map[string]any {
if s.pricingService != nil {
return s.pricingService.GetStatus()
}
return map[string]any{
"model_count": len(s.fallbackPrices),
"last_updated": "using fallback",
"local_hash": "N/A",
}
}
// ForceUpdatePricing 强制更新价格数据
func (s *BillingService) ForceUpdatePricing() error {
if s.pricingService != nil {
return s.pricingService.ForceUpdate()
}
return fmt.Errorf("pricing service not initialized")
}
// ImagePriceConfig 图片计费配置
type ImagePriceConfig struct {
Price1K *float64 // 1K 尺寸价格nil 表示使用默认值)
Price2K *float64 // 2K 尺寸价格nil 表示使用默认值)
Price4K *float64 // 4K 尺寸价格nil 表示使用默认值)
}
// SoraPriceConfig Sora 按次计费配置
type SoraPriceConfig struct {
ImagePrice360 *float64
ImagePrice540 *float64
VideoPricePerRequest *float64
VideoPricePerRequestHD *float64
}
// CalculateImageCost 计算图片生成费用
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
// imageSize: 图片尺寸 "1K", "2K", "4K"
// imageCount: 生成的图片数量
// groupConfig: 分组配置的价格(可能为 nil表示使用默认值
// rateMultiplier: 费率倍数
func (s *BillingService) CalculateImageCost(model string, imageSize string, imageCount int, groupConfig *ImagePriceConfig, rateMultiplier float64) *CostBreakdown {
if imageCount <= 0 {
return &CostBreakdown{}
}
// 获取单价
unitPrice := s.getImageUnitPrice(model, imageSize, groupConfig)
// 计算总费用
totalCost := unitPrice * float64(imageCount)
// 应用倍率
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
actualCost := totalCost * rateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}
}
// CalculateSoraImageCost 计算 Sora 图片按次费用
func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
if imageCount <= 0 {
return &CostBreakdown{}
}
unitPrice := 0.0
if groupConfig != nil {
switch imageSize {
case "540":
if groupConfig.ImagePrice540 != nil {
unitPrice = *groupConfig.ImagePrice540
}
default:
if groupConfig.ImagePrice360 != nil {
unitPrice = *groupConfig.ImagePrice360
}
}
}
totalCost := unitPrice * float64(imageCount)
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
actualCost := totalCost * rateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}
}
// CalculateSoraVideoCost 计算 Sora 视频按次费用
func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
unitPrice := 0.0
if groupConfig != nil {
modelLower := strings.ToLower(model)
if strings.Contains(modelLower, "sora2pro-hd") {
if groupConfig.VideoPricePerRequestHD != nil {
unitPrice = *groupConfig.VideoPricePerRequestHD
}
}
if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil {
unitPrice = *groupConfig.VideoPricePerRequest
}
}
totalCost := unitPrice
if rateMultiplier <= 0 {
rateMultiplier = 1.0
}
actualCost := totalCost * rateMultiplier
return &CostBreakdown{
TotalCost: totalCost,
ActualCost: actualCost,
}
}
// getImageUnitPrice 获取图片单价
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
// 优先使用分组配置的价格
if groupConfig != nil {
switch imageSize {
case "1K":
if groupConfig.Price1K != nil {
return *groupConfig.Price1K
}
case "2K":
if groupConfig.Price2K != nil {
return *groupConfig.Price2K
}
case "4K":
if groupConfig.Price4K != nil {
return *groupConfig.Price4K
}
}
}
// 回退到 LiteLLM 默认价格
return s.getDefaultImagePrice(model, imageSize)
}
// getDefaultImagePrice 获取 LiteLLM 默认图片价格
func (s *BillingService) getDefaultImagePrice(model string, imageSize string) float64 {
basePrice := 0.0
// 从 PricingService 获取 output_cost_per_image
if s.pricingService != nil {
pricing := s.pricingService.GetModelPricing(model)
if pricing != nil && pricing.OutputCostPerImage > 0 {
basePrice = pricing.OutputCostPerImage
}
}
// 如果没有找到价格,使用硬编码默认值($0.134,来自 gemini-3-pro-image-preview
if basePrice <= 0 {
basePrice = 0.134
}
// 2K 尺寸 1.5 倍4K 尺寸翻倍
if imageSize == "2K" {
return basePrice * 1.5
}
if imageSize == "4K" {
return basePrice * 2
}
return basePrice
}

View File

@@ -0,0 +1,149 @@
//go:build unit
package service
import (
"testing"
"github.com/stretchr/testify/require"
)
// TestCalculateImageCost_DefaultPricing 测试无分组配置时使用默认价格
func TestCalculateImageCost_DefaultPricing(t *testing.T) {
svc := &BillingService{} // pricingService 为 nil使用硬编码默认值
// 2K 尺寸,默认价格 $0.134 * 1.5 = $0.201
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
require.InDelta(t, 0.201, cost.ActualCost, 0.0001)
// 多张图片
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 3, nil, 1.0)
require.InDelta(t, 0.603, cost.TotalCost, 0.0001)
}
// TestCalculateImageCost_GroupCustomPricing 测试分组自定义价格
func TestCalculateImageCost_GroupCustomPricing(t *testing.T) {
svc := &BillingService{}
price1K := 0.10
price2K := 0.15
price4K := 0.30
groupConfig := &ImagePriceConfig{
Price1K: &price1K,
Price2K: &price2K,
Price4K: &price4K,
}
// 1K 使用分组价格
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 2, groupConfig, 1.0)
require.InDelta(t, 0.20, cost.TotalCost, 0.0001)
// 2K 使用分组价格
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
require.InDelta(t, 0.15, cost.TotalCost, 0.0001)
// 4K 使用分组价格
cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0)
require.InDelta(t, 0.30, cost.TotalCost, 0.0001)
}
// TestCalculateImageCost_4KDoublePrice 测试 4K 默认价格翻倍
func TestCalculateImageCost_4KDoublePrice(t *testing.T) {
svc := &BillingService{}
// 4K 尺寸,默认价格翻倍 $0.134 * 2 = $0.268
cost := svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, nil, 1.0)
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
}
// TestCalculateImageCost_RateMultiplier 测试费率倍数
func TestCalculateImageCost_RateMultiplier(t *testing.T) {
svc := &BillingService{}
// 费率倍数 1.5x
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5
require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5
// 费率倍数 2.0x
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 2, nil, 2.0)
require.InDelta(t, 0.402, cost.TotalCost, 0.0001)
require.InDelta(t, 0.804, cost.ActualCost, 0.0001)
}
// TestCalculateImageCost_ZeroCount 测试 imageCount=0
func TestCalculateImageCost_ZeroCount(t *testing.T) {
svc := &BillingService{}
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 0, nil, 1.0)
require.Equal(t, 0.0, cost.TotalCost)
require.Equal(t, 0.0, cost.ActualCost)
}
// TestCalculateImageCost_NegativeCount 测试 imageCount=-1
func TestCalculateImageCost_NegativeCount(t *testing.T) {
svc := &BillingService{}
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", -1, nil, 1.0)
require.Equal(t, 0.0, cost.TotalCost)
require.Equal(t, 0.0, cost.ActualCost)
}
// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0
func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
svc := &BillingService{}
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
}
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
func TestGetImageUnitPrice_GroupPriorityOverDefault(t *testing.T) {
svc := &BillingService{}
price2K := 0.20
groupConfig := &ImagePriceConfig{
Price2K: &price2K,
}
// 分组配置了 2K 价格,应该使用分组价格而不是默认的 $0.134
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
require.InDelta(t, 0.20, cost.TotalCost, 0.0001)
}
// TestGetImageUnitPrice_PartialGroupConfig 测试分组部分配置时回退默认
func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) {
svc := &BillingService{}
// 只配置 1K 价格
price1K := 0.10
groupConfig := &ImagePriceConfig{
Price1K: &price1K,
}
// 1K 使用分组价格
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, groupConfig, 1.0)
require.InDelta(t, 0.10, cost.TotalCost, 0.0001)
// 2K 回退默认价格 $0.201 (1.5倍)
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
// 4K 回退默认价格 $0.268 (翻倍)
cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0)
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
}
// TestGetDefaultImagePrice_FallbackHardcoded 测试 PricingService 无数据时使用硬编码默认值
func TestGetDefaultImagePrice_FallbackHardcoded(t *testing.T) {
svc := &BillingService{} // pricingService 为 nil
// 1K 默认价格 $0.1342K 默认价格 $0.201 (1.5倍)
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, nil, 1.0)
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
}

View File

@@ -0,0 +1,710 @@
//go:build unit
package service
import (
"math"
"testing"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
func newTestBillingService() *BillingService {
return NewBillingService(&config.Config{}, nil)
}
func TestCalculateCost_BasicComputation(t *testing.T) {
svc := newTestBillingService()
// 使用 claude-sonnet-4 的回退价格Input $3/MTok, Output $15/MTok
tokens := UsageTokens{
InputTokens: 1000,
OutputTokens: 500,
}
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
// 1000 * 3e-6 = 0.003, 500 * 15e-6 = 0.0075
expectedInput := 1000 * 3e-6
expectedOutput := 500 * 15e-6
require.InDelta(t, expectedInput, cost.InputCost, 1e-10)
require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10)
require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10)
require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10)
}
func TestCalculateCost_WithCacheTokens(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{
InputTokens: 1000,
OutputTokens: 500,
CacheCreationTokens: 2000,
CacheReadTokens: 3000,
}
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
expectedCacheCreation := 2000 * 3.75e-6
expectedCacheRead := 3000 * 0.3e-6
require.InDelta(t, expectedCacheCreation, cost.CacheCreationCost, 1e-10)
require.InDelta(t, expectedCacheRead, cost.CacheReadCost, 1e-10)
expectedTotal := cost.InputCost + cost.OutputCost + expectedCacheCreation + expectedCacheRead
require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10)
}
func TestCalculateCost_RateMultiplier(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
cost1x, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
cost2x, err := svc.CalculateCost("claude-sonnet-4", tokens, 2.0)
require.NoError(t, err)
// TotalCost 不受倍率影响ActualCost 翻倍
require.InDelta(t, cost1x.TotalCost, cost2x.TotalCost, 1e-10)
require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10)
}
func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 1000}
costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0)
require.NoError(t, err)
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
}
func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 1000}
costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0)
require.NoError(t, err)
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
}
func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) {
svc := newTestBillingService()
tests := []struct {
model string
expectedInput float64
}{
{"claude-opus-4.5-20250101", 5e-6},
{"claude-3-opus-20240229", 15e-6},
{"claude-sonnet-4-20250514", 3e-6},
{"claude-3-5-sonnet-20241022", 3e-6},
{"claude-3-5-haiku-20241022", 1e-6},
{"claude-3-haiku-20240307", 0.25e-6},
}
for _, tt := range tests {
pricing, err := svc.GetModelPricing(tt.model)
require.NoError(t, err, "模型 %s", tt.model)
require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12, "模型 %s 输入价格", tt.model)
}
}
func TestGetModelPricing_CaseInsensitive(t *testing.T) {
svc := newTestBillingService()
p1, err := svc.GetModelPricing("Claude-Sonnet-4")
require.NoError(t, err)
p2, err := svc.GetModelPricing("claude-sonnet-4")
require.NoError(t, err)
require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken)
}
func TestGetModelPricing_UnknownClaudeModelFallsBackToSonnet(t *testing.T) {
svc := newTestBillingService()
// 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格
pricing, err := svc.GetModelPricing("claude-unknown-model")
require.NoError(t, err)
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12)
}
func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricing("gpt-unknown-model")
require.Error(t, err)
require.Nil(t, pricing)
require.Contains(t, err.Error(), "pricing not found")
}
func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricing("gpt-5.1")
require.NoError(t, err)
require.NotNil(t, pricing)
require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12)
}
func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
svc := newTestBillingService()
pricing, err := svc.GetModelPricing("gpt-5.4")
require.NoError(t, err)
require.NotNil(t, pricing)
require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12)
require.Equal(t, 272000, pricing.LongContextInputThreshold)
require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12)
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
}
func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{
InputTokens: 300000,
OutputTokens: 4000,
}
cost, err := svc.CalculateCost("gpt-5.4-2026-03-05", tokens, 1.0)
require.NoError(t, err)
expectedInput := float64(tokens.InputTokens) * 2.5e-6 * 2.0
expectedOutput := float64(tokens.OutputTokens) * 15e-6 * 1.5
require.InDelta(t, expectedInput, cost.InputCost, 1e-10)
require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10)
require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10)
require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10)
}
func TestGetFallbackPricing_FamilyMatching(t *testing.T) {
svc := newTestBillingService()
tests := []struct {
name string
model string
expectedInput float64
expectNilPricing bool
}{
{name: "empty model", model: " ", expectNilPricing: true},
{name: "claude opus 4.6", model: "claude-opus-4.6-20260201", expectedInput: 5e-6},
{name: "claude opus 4.5 alt separator", model: "claude-opus-4-5-20260101", expectedInput: 5e-6},
{name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6},
{name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6},
{name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true},
{name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6},
{name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6},
{name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6},
{name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6},
{name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6},
{name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true},
{name: "non supported family", model: "qwen-max", expectNilPricing: true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pricing := svc.getFallbackPricing(tt.model)
if tt.expectNilPricing {
require.Nil(t, pricing)
return
}
require.NotNil(t, pricing)
require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12)
})
}
}
func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{
InputTokens: 50000,
OutputTokens: 1000,
CacheReadTokens: 100000,
}
// 总输入 150k < 200k 阈值,应走正常计费
cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0)
require.NoError(t, err)
normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10)
}
func TestCalculateCostWithLongContext_AboveThreshold_CacheExceedsThreshold(t *testing.T) {
svc := newTestBillingService()
// 缓存 210k + 输入 10k = 220k > 200k 阈值
// 缓存已超阈值:范围内 200k 缓存,范围外 10k 缓存 + 10k 输入
tokens := UsageTokens{
InputTokens: 10000,
OutputTokens: 1000,
CacheReadTokens: 210000,
}
cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0)
require.NoError(t, err)
// 范围内200k cache + 0 input + 1k output
inRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{
InputTokens: 0,
OutputTokens: 1000,
CacheReadTokens: 200000,
}, 1.0)
// 范围外10k cache + 10k input倍率 2.0
outRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{
InputTokens: 10000,
CacheReadTokens: 10000,
}, 2.0)
require.InDelta(t, inRange.ActualCost+outRange.ActualCost, cost.ActualCost, 1e-10)
}
func TestCalculateCostWithLongContext_AboveThreshold_CacheBelowThreshold(t *testing.T) {
svc := newTestBillingService()
// 缓存 100k + 输入 150k = 250k > 200k 阈值
// 缓存未超阈值:范围内 100k 缓存 + 100k 输入,范围外 50k 输入
tokens := UsageTokens{
InputTokens: 150000,
OutputTokens: 1000,
CacheReadTokens: 100000,
}
cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0)
require.NoError(t, err)
require.True(t, cost.ActualCost > 0, "费用应大于 0")
// 正常费用不含长上下文
normalCost, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.True(t, cost.ActualCost > normalCost.ActualCost, "长上下文费用应高于正常费用")
}
func TestCalculateCostWithLongContext_DisabledThreshold(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 300000, CacheReadTokens: 0}
// threshold <= 0 应禁用长上下文计费
cost1, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 0, 2.0)
require.NoError(t, err)
cost2, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
require.InDelta(t, cost2.ActualCost, cost1.ActualCost, 1e-10)
}
func TestCalculateCostWithLongContext_ExtraMultiplierLessEqualOne(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 300000}
// extraMultiplier <= 1 应禁用长上下文计费
cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 1.0)
require.NoError(t, err)
normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10)
}
func TestCalculateImageCost(t *testing.T) {
svc := newTestBillingService()
price := 0.134
cfg := &ImagePriceConfig{Price1K: &price}
cost := svc.CalculateImageCost("gpt-image-1", "1K", 3, cfg, 1.0)
require.InDelta(t, 0.134*3, cost.TotalCost, 1e-10)
require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
}
func TestCalculateSoraVideoCost(t *testing.T) {
svc := newTestBillingService()
price := 0.5
cfg := &SoraPriceConfig{VideoPricePerRequest: &price}
cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0)
require.InDelta(t, 0.5, cost.TotalCost, 1e-10)
}
func TestCalculateSoraVideoCost_HDModel(t *testing.T) {
svc := newTestBillingService()
hdPrice := 1.0
normalPrice := 0.5
cfg := &SoraPriceConfig{
VideoPricePerRequest: &normalPrice,
VideoPricePerRequestHD: &hdPrice,
}
cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0)
require.InDelta(t, 1.0, cost.TotalCost, 1e-10)
}
func TestIsModelSupported(t *testing.T) {
svc := newTestBillingService()
require.True(t, svc.IsModelSupported("claude-sonnet-4"))
require.True(t, svc.IsModelSupported("Claude-Opus-4.5"))
require.True(t, svc.IsModelSupported("claude-3-haiku"))
require.False(t, svc.IsModelSupported("gpt-4o"))
require.False(t, svc.IsModelSupported("gemini-pro"))
}
func TestCalculateCost_ZeroTokens(t *testing.T) {
svc := newTestBillingService()
cost, err := svc.CalculateCost("claude-sonnet-4", UsageTokens{}, 1.0)
require.NoError(t, err)
require.Equal(t, 0.0, cost.TotalCost)
require.Equal(t, 0.0, cost.ActualCost)
}
func TestCalculateCostWithConfig(t *testing.T) {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 1.5
svc := NewBillingService(cfg, nil)
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
cost, err := svc.CalculateCostWithConfig("claude-sonnet-4", tokens)
require.NoError(t, err)
expected, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.5)
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10)
}
func TestCalculateCostWithConfig_ZeroMultiplier(t *testing.T) {
cfg := &config.Config{}
cfg.Default.RateMultiplier = 0
svc := NewBillingService(cfg, nil)
tokens := UsageTokens{InputTokens: 1000}
cost, err := svc.CalculateCostWithConfig("claude-sonnet-4", tokens)
require.NoError(t, err)
// 倍率 <=0 时默认 1.0
expected, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10)
}
func TestGetEstimatedCost(t *testing.T) {
svc := newTestBillingService()
est, err := svc.GetEstimatedCost("claude-sonnet-4", 1000, 500)
require.NoError(t, err)
require.True(t, est > 0)
}
func TestListSupportedModels(t *testing.T) {
svc := newTestBillingService()
models := svc.ListSupportedModels()
require.NotEmpty(t, models)
require.GreaterOrEqual(t, len(models), 6)
}
func TestGetPricingServiceStatus_NilService(t *testing.T) {
svc := newTestBillingService()
status := svc.GetPricingServiceStatus()
require.NotNil(t, status)
require.Equal(t, "using fallback", status["last_updated"])
}
func TestForceUpdatePricing_NilService(t *testing.T) {
svc := newTestBillingService()
err := svc.ForceUpdatePricing()
require.Error(t, err)
require.Contains(t, err.Error(), "not initialized")
}
func TestCalculateSoraImageCost(t *testing.T) {
svc := newTestBillingService()
price360 := 0.05
price540 := 0.08
cfg := &SoraPriceConfig{ImagePrice360: &price360, ImagePrice540: &price540}
cost := svc.CalculateSoraImageCost("360", 2, cfg, 1.0)
require.InDelta(t, 0.10, cost.TotalCost, 1e-10)
cost540 := svc.CalculateSoraImageCost("540", 1, cfg, 2.0)
require.InDelta(t, 0.08, cost540.TotalCost, 1e-10)
require.InDelta(t, 0.16, cost540.ActualCost, 1e-10)
}
func TestCalculateSoraImageCost_ZeroCount(t *testing.T) {
svc := newTestBillingService()
cost := svc.CalculateSoraImageCost("360", 0, nil, 1.0)
require.Equal(t, 0.0, cost.TotalCost)
}
func TestCalculateSoraVideoCost_NilConfig(t *testing.T) {
svc := newTestBillingService()
cost := svc.CalculateSoraVideoCost("sora-video", nil, 1.0)
require.Equal(t, 0.0, cost.TotalCost)
}
func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) {
// 使用空的 fallback prices 让 GetModelPricing 失败
svc := &BillingService{
cfg: &config.Config{},
fallbackPrices: make(map[string]*ModelPricing),
}
tokens := UsageTokens{InputTokens: 300000, CacheReadTokens: 0}
_, err := svc.CalculateCostWithLongContext("unknown-model", tokens, 1.0, 200000, 2.0)
require.Error(t, err)
require.Contains(t, err.Error(), "pricing not found")
}
func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) {
svc := &BillingService{
cfg: &config.Config{},
fallbackPrices: map[string]*ModelPricing{
"claude-sonnet-4": {
InputPricePerToken: 3e-6,
OutputPricePerToken: 15e-6,
SupportsCacheBreakdown: true,
CacheCreation5mPrice: 4e-6, // per token
CacheCreation1hPrice: 5e-6, // per token
},
},
}
tokens := UsageTokens{
InputTokens: 1000,
OutputTokens: 500,
CacheCreation5mTokens: 100000,
CacheCreation1hTokens: 50000,
}
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
expected5m := float64(tokens.CacheCreation5mTokens) * 4e-6
expected1h := float64(tokens.CacheCreation1hTokens) * 5e-6
require.InDelta(t, expected5m+expected1h, cost.CacheCreationCost, 1e-10)
}
func TestCalculateCost_LargeTokenCount(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{
InputTokens: 1_000_000,
OutputTokens: 1_000_000,
}
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
// Input: 1M * 3e-6 = $3, Output: 1M * 15e-6 = $15
require.InDelta(t, 3.0, cost.InputCost, 1e-6)
require.InDelta(t, 15.0, cost.OutputCost, 1e-6)
require.False(t, math.IsNaN(cost.TotalCost))
require.False(t, math.IsInf(cost.TotalCost, 0))
}
func TestServiceTierCostMultiplier(t *testing.T) {
require.InDelta(t, 2.0, serviceTierCostMultiplier("priority"), 1e-12)
require.InDelta(t, 2.0, serviceTierCostMultiplier(" Priority "), 1e-12)
require.InDelta(t, 0.5, serviceTierCostMultiplier("flex"), 1e-12)
require.InDelta(t, 1.0, serviceTierCostMultiplier(""), 1e-12)
require.InDelta(t, 1.0, serviceTierCostMultiplier("default"), 1e-12)
}
func TestCalculateCostWithServiceTier_OpenAIPriorityUsesPriorityPricing(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheReadTokens: 20}
baseCost, err := svc.CalculateCost("gpt-5.1-codex", tokens, 1.0)
require.NoError(t, err)
priorityCost, err := svc.CalculateCostWithServiceTier("gpt-5.1-codex", tokens, 1.0, "priority")
require.NoError(t, err)
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
}
func TestCalculateCostWithServiceTier_FlexAppliesHalfMultiplier(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20}
baseCost, err := svc.CalculateCost("gpt-5.4", tokens, 1.0)
require.NoError(t, err)
flexCost, err := svc.CalculateCostWithServiceTier("gpt-5.4", tokens, 1.0, "flex")
require.NoError(t, err)
require.InDelta(t, baseCost.InputCost*0.5, flexCost.InputCost, 1e-10)
require.InDelta(t, baseCost.OutputCost*0.5, flexCost.OutputCost, 1e-10)
require.InDelta(t, baseCost.CacheCreationCost*0.5, flexCost.CacheCreationCost, 1e-10)
require.InDelta(t, baseCost.CacheReadCost*0.5, flexCost.CacheReadCost, 1e-10)
require.InDelta(t, baseCost.TotalCost*0.5, flexCost.TotalCost, 1e-10)
}
func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWithoutExplicitPriorityPrice(t *testing.T) {
svc := newTestBillingService()
tokens := UsageTokens{InputTokens: 120, OutputTokens: 30, CacheCreationTokens: 12, CacheReadTokens: 8}
baseCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
require.NoError(t, err)
priorityCost, err := svc.CalculateCostWithServiceTier("claude-sonnet-4", tokens, 1.0, "priority")
require.NoError(t, err)
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10)
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
}
func TestBillingServiceGetModelPricing_UsesDynamicPriorityFields(t *testing.T) {
pricingSvc := &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"gpt-5.4": {
InputCostPerToken: 2.5e-6,
InputCostPerTokenPriority: 5e-6,
OutputCostPerToken: 15e-6,
OutputCostPerTokenPriority: 30e-6,
CacheCreationInputTokenCost: 2.5e-6,
CacheReadInputTokenCost: 0.25e-6,
CacheReadInputTokenCostPriority: 0.5e-6,
LongContextInputTokenThreshold: 272000,
LongContextInputCostMultiplier: 2.0,
LongContextOutputCostMultiplier: 1.5,
},
},
}
svc := NewBillingService(&config.Config{}, pricingSvc)
pricing, err := svc.GetModelPricing("gpt-5.4")
require.NoError(t, err)
require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 5e-6, pricing.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 30e-6, pricing.OutputPricePerTokenPriority, 1e-12)
require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12)
require.InDelta(t, 0.5e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
require.Equal(t, 272000, pricing.LongContextInputThreshold)
require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12)
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
}
func TestBillingServiceGetModelPricing_OpenAIFallbackGpt52Variants(t *testing.T) {
svc := newTestBillingService()
gpt52, err := svc.GetModelPricing("gpt-5.2")
require.NoError(t, err)
require.NotNil(t, gpt52)
require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12)
require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12)
gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex")
require.NoError(t, err)
require.NotNil(t, gpt52Codex)
require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12)
require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12)
}
func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWhenExplicitPriceMissing(t *testing.T) {
svc := NewBillingService(&config.Config{}, &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"custom-no-priority": {
InputCostPerToken: 1e-6,
OutputCostPerToken: 2e-6,
CacheCreationInputTokenCost: 0.5e-6,
CacheReadInputTokenCost: 0.25e-6,
},
},
})
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20}
baseCost, err := svc.CalculateCost("custom-no-priority", tokens, 1.0)
require.NoError(t, err)
priorityCost, err := svc.CalculateCostWithServiceTier("custom-no-priority", tokens, 1.0, "priority")
require.NoError(t, err)
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10)
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
}
func TestGetModelPricing_OpenAIGpt52FallbacksExposePriorityPrices(t *testing.T) {
svc := newTestBillingService()
gpt52, err := svc.GetModelPricing("gpt-5.2")
require.NoError(t, err)
require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12)
require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 14e-6, gpt52.OutputPricePerToken, 1e-12)
require.InDelta(t, 28e-6, gpt52.OutputPricePerTokenPriority, 1e-12)
gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex")
require.NoError(t, err)
require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12)
require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 14e-6, gpt52Codex.OutputPricePerToken, 1e-12)
require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12)
}
func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.T) {
svc := NewBillingService(&config.Config{}, &PricingService{
pricingData: map[string]*LiteLLMModelPricing{
"dynamic-tier-model": {
InputCostPerToken: 1e-6,
InputCostPerTokenPriority: 2e-6,
OutputCostPerToken: 3e-6,
OutputCostPerTokenPriority: 6e-6,
CacheCreationInputTokenCost: 4e-6,
CacheCreationInputTokenCostAbove1hr: 5e-6,
CacheReadInputTokenCost: 7e-7,
CacheReadInputTokenCostPriority: 8e-7,
LongContextInputTokenThreshold: 999,
LongContextInputCostMultiplier: 1.5,
LongContextOutputCostMultiplier: 1.25,
},
},
})
pricing, err := svc.GetModelPricing("dynamic-tier-model")
require.NoError(t, err)
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
require.InDelta(t, 2e-6, pricing.InputPricePerTokenPriority, 1e-12)
require.InDelta(t, 3e-6, pricing.OutputPricePerToken, 1e-12)
require.InDelta(t, 6e-6, pricing.OutputPricePerTokenPriority, 1e-12)
require.InDelta(t, 4e-6, pricing.CacheCreation5mPrice, 1e-12)
require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12)
require.True(t, pricing.SupportsCacheBreakdown)
require.InDelta(t, 7e-7, pricing.CacheReadPricePerToken, 1e-12)
require.InDelta(t, 8e-7, pricing.CacheReadPricePerTokenPriority, 1e-12)
require.Equal(t, 999, pricing.LongContextInputThreshold)
require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12)
require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12)
}

View File

@@ -0,0 +1,282 @@
//go:build unit
package service
import (
"context"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func newTestValidator() *ClaudeCodeValidator {
return NewClaudeCodeValidator()
}
// validClaudeCodeBody 构造一个完整有效的 Claude Code 请求体
func validClaudeCodeBody() map[string]any {
return map[string]any{
"model": "claude-sonnet-4-20250514",
"system": []any{
map[string]any{
"type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
},
},
"metadata": map[string]any{
"user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_" + "12345678-1234-1234-1234-123456789abc",
},
}
}
func TestValidate_ClaudeCLIUserAgent(t *testing.T) {
v := newTestValidator()
tests := []struct {
name string
ua string
want bool
}{
{"标准版本号", "claude-cli/1.0.0", true},
{"多位版本号", "claude-cli/12.34.56", true},
{"大写开头", "Claude-CLI/1.0.0", true},
{"非 claude-cli", "curl/7.64.1", false},
{"空 User-Agent", "", false},
{"部分匹配", "not-claude-cli/1.0.0", false},
{"缺少版本号", "claude-cli/", false},
{"版本格式不对", "claude-cli/1.0", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
require.Equal(t, tt.want, v.ValidateUserAgent(tt.ua), "UA: %q", tt.ua)
})
}
}
func TestValidate_NonMessagesPath_UAOnly(t *testing.T) {
v := newTestValidator()
// 非 messages 路径只检查 UA
req := httptest.NewRequest("GET", "/v1/models", nil)
req.Header.Set("User-Agent", "claude-cli/1.0.0")
result := v.Validate(req, nil)
require.True(t, result, "非 messages 路径只需 UA 匹配")
}
func TestValidate_NonMessagesPath_InvalidUA(t *testing.T) {
v := newTestValidator()
req := httptest.NewRequest("GET", "/v1/models", nil)
req.Header.Set("User-Agent", "curl/7.64.1")
result := v.Validate(req, nil)
require.False(t, result, "UA 不匹配时应返回 false")
}
func TestValidate_MessagesPath_FullValid(t *testing.T) {
v := newTestValidator()
req := httptest.NewRequest("POST", "/v1/messages", nil)
req.Header.Set("User-Agent", "claude-cli/1.0.0")
req.Header.Set("X-App", "claude-code")
req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15")
req.Header.Set("anthropic-version", "2023-06-01")
result := v.Validate(req, validClaudeCodeBody())
require.True(t, result, "完整有效请求应通过")
}
func TestValidate_MessagesPath_MissingHeaders(t *testing.T) {
v := newTestValidator()
body := validClaudeCodeBody()
tests := []struct {
name string
missingHeader string
}{
{"缺少 X-App", "X-App"},
{"缺少 anthropic-beta", "anthropic-beta"},
{"缺少 anthropic-version", "anthropic-version"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/messages", nil)
req.Header.Set("User-Agent", "claude-cli/1.0.0")
req.Header.Set("X-App", "claude-code")
req.Header.Set("anthropic-beta", "beta")
req.Header.Set("anthropic-version", "2023-06-01")
req.Header.Del(tt.missingHeader)
result := v.Validate(req, body)
require.False(t, result, "缺少 %s 应返回 false", tt.missingHeader)
})
}
}
func TestValidate_MessagesPath_InvalidMetadataUserID(t *testing.T) {
v := newTestValidator()
tests := []struct {
name string
metadata map[string]any
}{
{"缺少 metadata", nil},
{"缺少 user_id", map[string]any{"other": "value"}},
{"空 user_id", map[string]any{"user_id": ""}},
{"格式错误", map[string]any{"user_id": "invalid-format"}},
{"hex 长度不足", map[string]any{"user_id": "user_abc_account__session_uuid"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/messages", nil)
req.Header.Set("User-Agent", "claude-cli/1.0.0")
req.Header.Set("X-App", "claude-code")
req.Header.Set("anthropic-beta", "beta")
req.Header.Set("anthropic-version", "2023-06-01")
body := map[string]any{
"model": "claude-sonnet-4",
"system": []any{
map[string]any{
"type": "text",
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
},
},
}
if tt.metadata != nil {
body["metadata"] = tt.metadata
}
result := v.Validate(req, body)
require.False(t, result, "metadata.user_id: %v", tt.metadata)
})
}
}
func TestValidate_MessagesPath_InvalidSystemPrompt(t *testing.T) {
v := newTestValidator()
req := httptest.NewRequest("POST", "/v1/messages", nil)
req.Header.Set("User-Agent", "claude-cli/1.0.0")
req.Header.Set("X-App", "claude-code")
req.Header.Set("anthropic-beta", "beta")
req.Header.Set("anthropic-version", "2023-06-01")
body := map[string]any{
"model": "claude-sonnet-4",
"system": []any{
map[string]any{
"type": "text",
"text": "Generate JSON data for testing database migrations.",
},
},
"metadata": map[string]any{
"user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_12345678-1234-1234-1234-123456789abc",
},
}
result := v.Validate(req, body)
require.False(t, result, "无关系统提示词应返回 false")
}
func TestValidate_MaxTokensOneHaikuBypass(t *testing.T) {
v := newTestValidator()
req := httptest.NewRequest("POST", "/v1/messages", nil)
req.Header.Set("User-Agent", "claude-cli/1.0.0")
// 不设置 X-App 等头,通过 context 标记为 haiku 探测请求
ctx := context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
req = req.WithContext(ctx)
// 即使 body 不包含 system prompt也应通过
result := v.Validate(req, map[string]any{"model": "claude-3-haiku", "max_tokens": 1})
require.True(t, result, "max_tokens=1+haiku 探测请求应绕过严格验证")
}
func TestSystemPromptSimilarity(t *testing.T) {
v := newTestValidator()
tests := []struct {
name string
prompt string
want bool
}{
{"精确匹配", "You are Claude Code, Anthropic's official CLI for Claude.", true},
{"带多余空格", "You are Claude Code, Anthropic's official CLI for Claude.", true},
{"Agent SDK 模板", "You are a Claude agent, built on Anthropic's Claude Agent SDK.", true},
{"文件搜索专家模板", "You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.", true},
{"对话摘要模板", "You are a helpful AI assistant tasked with summarizing conversations.", true},
{"交互式 CLI 模板", "You are an interactive CLI tool that helps users", true},
{"无关文本", "Write me a poem about cats", false},
{"空文本", "", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
body := map[string]any{
"model": "claude-sonnet-4",
"system": []any{
map[string]any{"type": "text", "text": tt.prompt},
},
}
result := v.IncludesClaudeCodeSystemPrompt(body)
require.Equal(t, tt.want, result, "提示词: %q", tt.prompt)
})
}
}
func TestDiceCoefficient(t *testing.T) {
tests := []struct {
name string
a string
b string
want float64
tol float64
}{
{"相同字符串", "hello", "hello", 1.0, 0.001},
{"完全不同", "abc", "xyz", 0.0, 0.001},
{"空字符串", "", "hello", 0.0, 0.001},
{"单字符", "a", "b", 0.0, 0.001},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := diceCoefficient(tt.a, tt.b)
require.InDelta(t, tt.want, result, tt.tol)
})
}
}
func TestIsClaudeCodeClient_Context(t *testing.T) {
ctx := context.Background()
// 默认应为 false
require.False(t, IsClaudeCodeClient(ctx))
// 设置为 true
ctx = SetClaudeCodeClient(ctx, true)
require.True(t, IsClaudeCodeClient(ctx))
// 设置为 false
ctx = SetClaudeCodeClient(ctx, false)
require.False(t, IsClaudeCodeClient(ctx))
}
func TestValidate_NilBody_MessagesPath(t *testing.T) {
v := newTestValidator()
req := httptest.NewRequest("POST", "/v1/messages", nil)
req.Header.Set("User-Agent", "claude-cli/1.0.0")
req.Header.Set("X-App", "claude-code")
req.Header.Set("anthropic-beta", "beta")
req.Header.Set("anthropic-version", "2023-06-01")
result := v.Validate(req, nil)
require.False(t, result, "nil body 的 messages 请求应返回 false")
}

View File

@@ -0,0 +1,328 @@
package service
import (
"context"
"net/http"
"regexp"
"strconv"
"strings"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
)
// ClaudeCodeValidator 验证请求是否来自 Claude Code 客户端
// 完全学习自 claude-relay-service 项目的验证逻辑
type ClaudeCodeValidator struct{}
var (
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI大小写不敏感)
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
// 带捕获组的版本提取正则
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
systemPromptThreshold = 0.5
)
// Claude Code 官方 System Prompt 模板
// 从 claude-relay-service/src/utils/contents.js 提取
var claudeCodeSystemPrompts = []string{
// claudeOtherSystemPrompt1 - Primary
"You are Claude Code, Anthropic's official CLI for Claude.",
// claudeOtherSystemPrompt3 - Agent SDK
"You are a Claude agent, built on Anthropic's Claude Agent SDK.",
// claudeOtherSystemPrompt4 - Compact Agent SDK
"You are Claude Code, Anthropic's official CLI for Claude, running within the Claude Agent SDK.",
// exploreAgentSystemPrompt
"You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.",
// claudeOtherSystemPromptCompact - Compact (用于对话摘要)
"You are a helpful AI assistant tasked with summarizing conversations.",
// claudeOtherSystemPrompt2 - Secondary (长提示词的关键部分)
"You are an interactive CLI tool that helps users",
}
// NewClaudeCodeValidator 创建验证器实例
func NewClaudeCodeValidator() *ClaudeCodeValidator {
return &ClaudeCodeValidator{}
}
// Validate 验证请求是否来自 Claude Code CLI
// 采用与 claude-relay-service 完全一致的验证策略:
//
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过UA 已验证)
// Step 4: 对于 messages 路径,进行严格验证:
// - System prompt 相似度检查
// - X-App header 检查
// - anthropic-beta header 检查
// - anthropic-version header 检查
// - metadata.user_id 格式验证
func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) bool {
// Step 1: User-Agent 检查
ua := r.Header.Get("User-Agent")
if !claudeCodeUAPattern.MatchString(ua) {
return false
}
// Step 2: 非 messages 路径,只要 UA 匹配就通过
path := r.URL.Path
if !strings.Contains(path, "messages") {
return true
}
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
if isMaxTokensOneHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(r.Context()); ok && isMaxTokensOneHaiku {
return true // 绕过 system prompt 检查UA 已在 Step 1 验证
}
// Step 4: messages 路径,进行严格验证
// 4.1 检查 system prompt 相似度
if !v.hasClaudeCodeSystemPrompt(body) {
return false
}
// 4.2 检查必需的 headers值不为空即可
xApp := r.Header.Get("X-App")
if xApp == "" {
return false
}
anthropicBeta := r.Header.Get("anthropic-beta")
if anthropicBeta == "" {
return false
}
anthropicVersion := r.Header.Get("anthropic-version")
if anthropicVersion == "" {
return false
}
// 4.3 验证 metadata.user_id
if body == nil {
return false
}
metadata, ok := body["metadata"].(map[string]any)
if !ok {
return false
}
userID, ok := metadata["user_id"].(string)
if !ok || userID == "" {
return false
}
if !userIDPattern.MatchString(userID) {
return false
}
return true
}
// hasClaudeCodeSystemPrompt 检查请求是否包含 Claude Code 系统提示词
// 使用字符串相似度匹配Dice coefficient
func (v *ClaudeCodeValidator) hasClaudeCodeSystemPrompt(body map[string]any) bool {
if body == nil {
return false
}
// 检查 model 字段
if _, ok := body["model"].(string); !ok {
return false
}
// 获取 system 字段
systemEntries, ok := body["system"].([]any)
if !ok {
return false
}
// 检查每个 system entry
for _, entry := range systemEntries {
entryMap, ok := entry.(map[string]any)
if !ok {
continue
}
text, ok := entryMap["text"].(string)
if !ok || text == "" {
continue
}
// 计算与所有模板的最佳相似度
bestScore := v.bestSimilarityScore(text)
if bestScore >= systemPromptThreshold {
return true
}
}
return false
}
// bestSimilarityScore 计算文本与所有 Claude Code 模板的最佳相似度
func (v *ClaudeCodeValidator) bestSimilarityScore(text string) float64 {
normalizedText := normalizePrompt(text)
bestScore := 0.0
for _, template := range claudeCodeSystemPrompts {
normalizedTemplate := normalizePrompt(template)
score := diceCoefficient(normalizedText, normalizedTemplate)
if score > bestScore {
bestScore = score
}
}
return bestScore
}
// normalizePrompt 标准化提示词文本(去除多余空白)
func normalizePrompt(text string) string {
// 将所有空白字符替换为单个空格,并去除首尾空白
return strings.Join(strings.Fields(text), " ")
}
// diceCoefficient 计算两个字符串的 Dice 系数SørensenDice coefficient
// 这是 string-similarity 库使用的算法
// 公式: 2 * |intersection| / (|bigrams(a)| + |bigrams(b)|)
func diceCoefficient(a, b string) float64 {
if a == b {
return 1.0
}
if len(a) < 2 || len(b) < 2 {
return 0.0
}
// 生成 bigrams
bigramsA := getBigrams(a)
bigramsB := getBigrams(b)
if len(bigramsA) == 0 || len(bigramsB) == 0 {
return 0.0
}
// 计算交集大小
intersection := 0
for bigram, countA := range bigramsA {
if countB, exists := bigramsB[bigram]; exists {
if countA < countB {
intersection += countA
} else {
intersection += countB
}
}
}
// 计算总 bigram 数量
totalA := 0
for _, count := range bigramsA {
totalA += count
}
totalB := 0
for _, count := range bigramsB {
totalB += count
}
return float64(2*intersection) / float64(totalA+totalB)
}
// getBigrams 获取字符串的所有 bigrams相邻字符对
func getBigrams(s string) map[string]int {
bigrams := make(map[string]int)
runes := []rune(strings.ToLower(s))
for i := 0; i < len(runes)-1; i++ {
bigram := string(runes[i : i+2])
bigrams[bigram]++
}
return bigrams
}
// ValidateUserAgent 仅验证 User-Agent用于不需要解析请求体的场景
func (v *ClaudeCodeValidator) ValidateUserAgent(ua string) bool {
return claudeCodeUAPattern.MatchString(ua)
}
// IncludesClaudeCodeSystemPrompt 检查请求体是否包含 Claude Code 系统提示词
// 只要存在匹配的系统提示词就返回 true用于宽松检测
func (v *ClaudeCodeValidator) IncludesClaudeCodeSystemPrompt(body map[string]any) bool {
return v.hasClaudeCodeSystemPrompt(body)
}
// IsClaudeCodeClient 从 context 中获取 Claude Code 客户端标识
func IsClaudeCodeClient(ctx context.Context) bool {
if v, ok := ctx.Value(ctxkey.IsClaudeCodeClient).(bool); ok {
return v
}
return false
}
// SetClaudeCodeClient 将 Claude Code 客户端标识设置到 context 中
func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context {
return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode)
}
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
if len(matches) >= 2 {
return matches[1]
}
return ""
}
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中
func SetClaudeCodeVersion(ctx context.Context, version string) context.Context {
return context.WithValue(ctx, ctxkey.ClaudeCodeVersion, version)
}
// GetClaudeCodeVersion 从 context 中获取 Claude Code 版本号
func GetClaudeCodeVersion(ctx context.Context) string {
if v, ok := ctx.Value(ctxkey.ClaudeCodeVersion).(string); ok {
return v
}
return ""
}
// CompareVersions 比较两个 semver 版本号
// 返回: -1 (a < b), 0 (a == b), 1 (a > b)
func CompareVersions(a, b string) int {
aParts := parseSemver(a)
bParts := parseSemver(b)
for i := 0; i < 3; i++ {
if aParts[i] < bParts[i] {
return -1
}
if aParts[i] > bParts[i] {
return 1
}
}
return 0
}
// parseSemver 解析 semver 版本号为 [major, minor, patch]
func parseSemver(v string) [3]int {
v = strings.TrimPrefix(v, "v")
parts := strings.Split(v, ".")
result := [3]int{0, 0, 0}
for i := 0; i < len(parts) && i < 3; i++ {
if parsed, err := strconv.Atoi(parts[i]); err == nil {
result[i] = parsed
}
}
return result
}

View File

@@ -0,0 +1,106 @@
package service
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
"github.com/stretchr/testify/require"
)
func TestClaudeCodeValidator_ProbeBypass(t *testing.T) {
validator := NewClaudeCodeValidator()
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")
req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true))
ok := validator.Validate(req, map[string]any{
"model": "claude-haiku-4-5",
"max_tokens": 1,
})
require.True(t, ok)
}
func TestClaudeCodeValidator_ProbeBypassRequiresUA(t *testing.T) {
validator := NewClaudeCodeValidator()
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
req.Header.Set("User-Agent", "curl/8.0.0")
req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true))
ok := validator.Validate(req, map[string]any{
"model": "claude-haiku-4-5",
"max_tokens": 1,
})
require.False(t, ok)
}
func TestClaudeCodeValidator_MessagesWithoutProbeStillNeedStrictValidation(t *testing.T) {
validator := NewClaudeCodeValidator()
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")
ok := validator.Validate(req, map[string]any{
"model": "claude-haiku-4-5",
"max_tokens": 1,
})
require.False(t, ok)
}
func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) {
validator := NewClaudeCodeValidator()
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/models", nil)
req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")
ok := validator.Validate(req, nil)
require.True(t, ok)
}
func TestExtractVersion(t *testing.T) {
v := NewClaudeCodeValidator()
tests := []struct {
ua string
want string
}{
{"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"},
{"claude-cli/1.0.0", "1.0.0"},
{"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感
{"curl/8.0.0", ""}, // 非 Claude CLI
{"", ""}, // 空字符串
{"claude-cli/", ""}, // 无版本号
{"claude-cli/2.1.22-beta", "2.1.22"}, // 带后缀仍提取主版本号
}
for _, tt := range tests {
got := v.ExtractVersion(tt.ua)
require.Equal(t, tt.want, got, "ExtractVersion(%q)", tt.ua)
}
}
func TestCompareVersions(t *testing.T) {
tests := []struct {
a, b string
want int
}{
{"2.1.0", "2.1.0", 0}, // 相等
{"2.1.1", "2.1.0", 1}, // patch 更大
{"2.0.0", "2.1.0", -1}, // minor 更小
{"3.0.0", "2.99.99", 1}, // major 更大
{"1.0.0", "2.0.0", -1}, // major 更小
{"0.0.1", "0.0.0", 1}, // patch 差异
{"", "1.0.0", -1}, // 空字符串 vs 正常版本
{"v2.1.0", "2.1.0", 0}, // v 前缀处理
}
for _, tt := range tests {
got := CompareVersions(tt.a, tt.b)
require.Equal(t, tt.want, got, "CompareVersions(%q, %q)", tt.a, tt.b)
}
}
func TestSetGetClaudeCodeVersion(t *testing.T) {
ctx := context.Background()
require.Equal(t, "", GetClaudeCodeVersion(ctx), "empty context should return empty string")
ctx = SetClaudeCodeVersion(ctx, "2.1.63")
require.Equal(t, "2.1.63", GetClaudeCodeVersion(ctx))
}

View File

@@ -0,0 +1,159 @@
package service
import (
"context"
"errors"
"log/slog"
"strings"
"time"
)
const (
claudeTokenRefreshSkew = 3 * time.Minute
claudeTokenCacheSkew = 5 * time.Minute
claudeLockWaitTime = 200 * time.Millisecond
)
// ClaudeTokenCache token cache interface.
type ClaudeTokenCache = GeminiTokenCache
// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
type ClaudeTokenProvider struct {
accountRepo AccountRepository
tokenCache ClaudeTokenCache
oauthService *OAuthService
refreshAPI *OAuthRefreshAPI
executor OAuthRefreshExecutor
refreshPolicy ProviderRefreshPolicy
}
func NewClaudeTokenProvider(
accountRepo AccountRepository,
tokenCache ClaudeTokenCache,
oauthService *OAuthService,
) *ClaudeTokenProvider {
return &ClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: tokenCache,
oauthService: oauthService,
refreshPolicy: ClaudeProviderRefreshPolicy(),
}
}
// SetRefreshAPI injects unified OAuth refresh API and executor.
func (p *ClaudeTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
p.refreshAPI = api
p.executor = executor
}
// SetRefreshPolicy injects caller-side refresh policy.
func (p *ClaudeTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
p.refreshPolicy = policy
}
// GetAccessToken returns a valid access_token.
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
return "", errors.New("not an anthropic oauth account")
}
cacheKey := ClaudeTokenCacheKey(account)
// 1) Try cache first.
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
return token, nil
} else if err != nil {
slog.Warn("claude_token_cache_get_failed", "account_id", account.ID, "error", err)
}
}
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
// 2) Refresh if needed (pre-expiry skew).
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
refreshFailed := false
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, claudeTokenRefreshSkew)
if err != nil {
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
return "", err
}
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
refreshFailed = true
} else if result.LockHeld {
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
time.Sleep(claudeLockWaitTime)
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
} else {
account = result.Account
expiresAt = account.GetCredentialAsTime("expires_at")
}
} else if needsRefresh && p.tokenCache != nil {
// Backward-compatible test path when refreshAPI is not injected.
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if lockErr == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
} else if lockErr != nil {
slog.Warn("claude_token_lock_failed", "account_id", account.ID, "error", lockErr)
} else {
time.Sleep(claudeLockWaitTime)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
return token, nil
}
}
}
accessToken := account.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found in credentials")
}
// 3) Populate cache with TTL.
if p.tokenCache != nil {
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
if isStale && latestAccount != nil {
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
accessToken = latestAccount.GetCredential("access_token")
if strings.TrimSpace(accessToken) == "" {
return "", errors.New("access_token not found after version check")
}
} else {
ttl := 30 * time.Minute
if refreshFailed {
if p.refreshPolicy.FailureTTL > 0 {
ttl = p.refreshPolicy.FailureTTL
} else {
ttl = time.Minute
}
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
} else if expiresAt != nil {
until := time.Until(*expiresAt)
switch {
case until > claudeTokenCacheSkew:
ttl = until - claudeTokenCacheSkew
case until > 0:
ttl = until
default:
ttl = time.Minute
}
}
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
}
}
}
return accessToken, nil
}

View File

@@ -0,0 +1,939 @@
//go:build unit
package service
import (
"context"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
// claudeTokenCacheStub implements ClaudeTokenCache for testing
type claudeTokenCacheStub struct {
mu sync.Mutex
tokens map[string]string
getErr error
setErr error
deleteErr error
lockAcquired bool
lockErr error
releaseLockErr error
getCalled int32
setCalled int32
lockCalled int32
unlockCalled int32
simulateLockRace bool
}
func newClaudeTokenCacheStub() *claudeTokenCacheStub {
return &claudeTokenCacheStub{
tokens: make(map[string]string),
lockAcquired: true,
}
}
func (s *claudeTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
atomic.AddInt32(&s.getCalled, 1)
if s.getErr != nil {
return "", s.getErr
}
s.mu.Lock()
defer s.mu.Unlock()
return s.tokens[cacheKey], nil
}
func (s *claudeTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
atomic.AddInt32(&s.setCalled, 1)
if s.setErr != nil {
return s.setErr
}
s.mu.Lock()
defer s.mu.Unlock()
s.tokens[cacheKey] = token
return nil
}
func (s *claudeTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
if s.deleteErr != nil {
return s.deleteErr
}
s.mu.Lock()
defer s.mu.Unlock()
delete(s.tokens, cacheKey)
return nil
}
func (s *claudeTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
atomic.AddInt32(&s.lockCalled, 1)
if s.lockErr != nil {
return false, s.lockErr
}
if s.simulateLockRace {
return false, nil
}
return s.lockAcquired, nil
}
func (s *claudeTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
atomic.AddInt32(&s.unlockCalled, 1)
return s.releaseLockErr
}
// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider
type claudeAccountRepoStub struct {
account *Account
getErr error
updateErr error
getCalled int32
updateCalled int32
}
func (r *claudeAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
atomic.AddInt32(&r.getCalled, 1)
if r.getErr != nil {
return nil, r.getErr
}
return r.account, nil
}
func (r *claudeAccountRepoStub) Update(ctx context.Context, account *Account) error {
atomic.AddInt32(&r.updateCalled, 1)
if r.updateErr != nil {
return r.updateErr
}
r.account = account
return nil
}
// claudeOAuthServiceStub implements OAuthService methods for testing
type claudeOAuthServiceStub struct {
tokenInfo *TokenInfo
refreshErr error
refreshCalled int32
}
func (s *claudeOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
atomic.AddInt32(&s.refreshCalled, 1)
if s.refreshErr != nil {
return nil, s.refreshErr
}
return s.tokenInfo, nil
}
// testClaudeTokenProvider is a test version that uses the stub OAuth service
type testClaudeTokenProvider struct {
accountRepo *claudeAccountRepoStub
tokenCache *claudeTokenCacheStub
oauthService *claudeOAuthServiceStub
}
func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
if account == nil {
return "", errors.New("account is nil")
}
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
return "", errors.New("not an anthropic oauth account")
}
cacheKey := ClaudeTokenCacheKey(account)
// 1. Check cache
if p.tokenCache != nil {
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
}
// 2. Check if refresh needed
expiresAt := account.GetCredentialAsTime("expires_at")
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
refreshFailed := false
if needsRefresh && p.tokenCache != nil {
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
if err == nil && locked {
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
// Check cache again after acquiring lock
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
// Get fresh account from DB
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
if err == nil && fresh != nil {
account = fresh
}
expiresAt = account.GetCredentialAsTime("expires_at")
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
if p.oauthService == nil {
refreshFailed = true // 无法刷新,标记失败
} else {
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
if err != nil {
refreshFailed = true // 刷新失败,标记以使用短 TTL
} else {
// Build new credentials
newCredentials := make(map[string]any)
for k, v := range account.Credentials {
newCredentials[k] = v
}
newCredentials["access_token"] = tokenInfo.AccessToken
newCredentials["token_type"] = tokenInfo.TokenType
newCredentials["expires_at"] = time.Now().Add(time.Duration(tokenInfo.ExpiresIn) * time.Second).Format(time.RFC3339)
if tokenInfo.RefreshToken != "" {
newCredentials["refresh_token"] = tokenInfo.RefreshToken
}
account.Credentials = newCredentials
_ = p.accountRepo.Update(ctx, account)
expiresAt = account.GetCredentialAsTime("expires_at")
}
}
}
} else if p.tokenCache.simulateLockRace {
// Wait and retry cache
time.Sleep(10 * time.Millisecond)
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
return token, nil
}
}
}
accessToken := account.GetCredential("access_token")
if accessToken == "" {
return "", errors.New("access_token not found in credentials")
}
// 3. Store in cache
if p.tokenCache != nil {
ttl := 30 * time.Minute
if refreshFailed {
ttl = time.Minute // 刷新失败时使用短 TTL
} else if expiresAt != nil {
until := time.Until(*expiresAt)
if until > claudeTokenCacheSkew {
ttl = until - claudeTokenCacheSkew
} else if until > 0 {
ttl = until
} else {
ttl = time.Minute
}
}
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
}
return accessToken, nil
}
func TestClaudeTokenProvider_CacheHit(t *testing.T) {
cache := newClaudeTokenCacheStub()
account := &Account{
ID: 100,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "db-token",
},
}
cacheKey := ClaudeTokenCacheKey(account)
cache.tokens[cacheKey] = "cached-token"
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "cached-token", token)
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
}
func TestClaudeTokenProvider_CacheMiss_FromCredentials(t *testing.T) {
cache := newClaudeTokenCacheStub()
// Token expires in far future, no refresh needed
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 101,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "credential-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "credential-token", token)
// Should have stored in cache
cacheKey := ClaudeTokenCacheKey(account)
require.Equal(t, "credential-token", cache.tokens[cacheKey])
}
func TestClaudeTokenProvider_TokenRefresh(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh-token",
TokenType: "Bearer",
ExpiresIn: 3600,
ExpiresAt: time.Now().Add(time.Hour).Unix(),
},
}
// Token expires soon (within refresh skew)
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 102,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refreshed-token", token)
require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
}
func TestClaudeTokenProvider_LockRaceCondition(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.simulateLockRace = true
accountRepo := &claudeAccountRepoStub{}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 103,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "race-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
// Simulate another worker already refreshed and cached
cacheKey := ClaudeTokenCacheKey(account)
go func() {
time.Sleep(5 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "winner-token"
cache.mu.Unlock()
}()
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
func TestClaudeTokenProvider_NilAccount(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
token, err := provider.GetAccessToken(context.Background(), nil)
require.Error(t, err)
require.Contains(t, err.Error(), "account is nil")
require.Empty(t, token)
}
func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
account := &Account{
ID: 104,
Platform: PlatformOpenAI,
Type: AccountTypeOAuth,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Empty(t, token)
}
func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
account := &Account{
ID: 105,
Platform: PlatformAnthropic,
Type: AccountTypeAPIKey,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Empty(t, token)
}
func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
provider := NewClaudeTokenProvider(nil, nil, nil)
account := &Account{
ID: 106,
Platform: PlatformAnthropic,
Type: AccountTypeSetupToken,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "not an anthropic oauth account")
require.Empty(t, token)
}
func TestClaudeTokenProvider_NilCache(t *testing.T) {
// Token doesn't need refresh
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 107,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "nocache-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, nil, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "nocache-token", token)
}
func TestClaudeTokenProvider_CacheGetError(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.getErr = errors.New("redis connection failed")
// Token doesn't need refresh
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 108,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
// Should gracefully degrade and return from credentials
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "fallback-token", token)
}
func TestClaudeTokenProvider_CacheSetError(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.setErr = errors.New("redis write failed")
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 109,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "still-works-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
// Should still work even if cache set fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "still-works-token", token)
}
func TestClaudeTokenProvider_MissingAccessToken(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 110,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"expires_at": expiresAt,
// missing access_token
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
func TestClaudeTokenProvider_RefreshError(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
refreshErr: errors.New("oauth refresh failed"),
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 111,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// Now with fallback behavior, should return existing token even if refresh fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "old-token", token) // Fallback to existing token
}
func TestClaudeTokenProvider_OAuthServiceNotConfigured(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 112,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: nil, // not configured
}
// Now with fallback behavior, should return existing token even if oauth service not configured
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "old-token", token) // Fallback to existing token
}
func TestClaudeTokenProvider_TTLCalculation(t *testing.T) {
tests := []struct {
name string
expiresIn time.Duration
}{
{
name: "far_future_expiry",
expiresIn: 1 * time.Hour,
},
{
name: "medium_expiry",
expiresIn: 10 * time.Minute,
},
{
name: "near_expiry",
expiresIn: 6 * time.Minute,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
account := &Account{
ID: 200,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "test-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
_, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
// Verify token was cached
cacheKey := ClaudeTokenCacheKey(account)
require.Equal(t, "test-token", cache.tokens[cacheKey])
})
}
}
func TestClaudeTokenProvider_AccountRepoGetError(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{
getErr: errors.New("db connection failed"),
}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh",
TokenType: "Bearer",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 113,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh",
"expires_at": expiresAt,
},
}
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// Should still work, just using the passed-in account
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refreshed-token", token)
}
func TestClaudeTokenProvider_AccountUpdateError(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{
updateErr: errors.New("db write failed"),
}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh",
TokenType: "Bearer",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 114,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"refresh_token": "old-refresh",
"expires_at": expiresAt,
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
// Should still return token even if update fails
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "refreshed-token", token)
}
func TestClaudeTokenProvider_RefreshPreservesExistingCredentials(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "new-access-token",
RefreshToken: "new-refresh-token",
TokenType: "Bearer",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 115,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-access-token",
"refresh_token": "old-refresh-token",
"expires_at": expiresAt,
"custom_field": "should-be-preserved",
"organization": "test-org",
},
}
accountRepo.account = account
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "new-access-token", token)
// Verify existing fields are preserved
require.Equal(t, "should-be-preserved", accountRepo.account.Credentials["custom_field"])
require.Equal(t, "test-org", accountRepo.account.Credentials["organization"])
// Verify new fields are updated
require.Equal(t, "new-access-token", accountRepo.account.Credentials["access_token"])
require.Equal(t, "new-refresh-token", accountRepo.account.Credentials["refresh_token"])
}
func TestClaudeTokenProvider_DoubleCheckCacheAfterLock(t *testing.T) {
cache := newClaudeTokenCacheStub()
accountRepo := &claudeAccountRepoStub{}
oauthService := &claudeOAuthServiceStub{
tokenInfo: &TokenInfo{
AccessToken: "refreshed-token",
RefreshToken: "new-refresh",
TokenType: "Bearer",
ExpiresIn: 3600,
},
}
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 116,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "old-token",
"expires_at": expiresAt,
},
}
accountRepo.account = account
cacheKey := ClaudeTokenCacheKey(account)
// After lock is acquired, cache should have the token (simulating another worker)
go func() {
time.Sleep(5 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "cached-by-other-worker"
cache.mu.Unlock()
}()
provider := &testClaudeTokenProvider{
accountRepo: accountRepo,
tokenCache: cache,
oauthService: oauthService,
}
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
// Tests for real provider - to increase coverage
func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockAcquired = false // Lock acquisition fails
// Token expires soon (within refresh skew) to trigger lock attempt
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 300,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-token",
"expires_at": expiresAt,
},
}
// Set token in cache after lock wait period (simulate other worker refreshing)
cacheKey := ClaudeTokenCacheKey(account)
go func() {
time.Sleep(100 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "refreshed-by-other"
cache.mu.Unlock()
}()
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockAcquired = false // Lock acquisition fails
// Token expires soon
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 301,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "original-token",
"expires_at": expiresAt,
},
}
cacheKey := ClaudeTokenCacheKey(account)
// Set token in cache immediately after wait starts
go func() {
time.Sleep(50 * time.Millisecond)
cache.mu.Lock()
cache.tokens[cacheKey] = "winner-token"
cache.mu.Unlock()
}()
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.NotEmpty(t, token)
}
func TestClaudeTokenProvider_Real_NoExpiresAt(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockAcquired = false // Prevent entering refresh logic
// Token with nil expires_at (no expiry set)
account := &Account{
ID: 302,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "no-expiry-token",
},
}
// After lock wait, return token from credentials
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "no-expiry-token", token)
}
func TestClaudeTokenProvider_Real_WhitespaceToken(t *testing.T) {
cache := newClaudeTokenCacheStub()
cacheKey := "claude:account:303"
cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 303,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "real-token",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "real-token", token)
}
func TestClaudeTokenProvider_Real_EmptyCredentialToken(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 304,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": " ", // Whitespace only
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}
func TestClaudeTokenProvider_Real_LockError(t *testing.T) {
cache := newClaudeTokenCacheStub()
cache.lockErr = errors.New("redis lock failed")
// Token expires soon (within refresh skew)
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
account := &Account{
ID: 305,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"access_token": "fallback-on-lock-error",
"expires_at": expiresAt,
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.NoError(t, err)
require.Equal(t, "fallback-on-lock-error", token)
}
func TestClaudeTokenProvider_Real_NilCredentials(t *testing.T) {
cache := newClaudeTokenCacheStub()
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
account := &Account{
ID: 306,
Platform: PlatformAnthropic,
Type: AccountTypeOAuth,
Credentials: map[string]any{
"expires_at": expiresAt,
// No access_token
},
}
provider := NewClaudeTokenProvider(nil, cache, nil)
token, err := provider.GetAccessToken(context.Background(), account)
require.Error(t, err)
require.Contains(t, err.Error(), "access_token not found")
require.Empty(t, token)
}

View File

@@ -0,0 +1,360 @@
package service
import (
"context"
"crypto/rand"
"encoding/binary"
"os"
"strconv"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
// ConcurrencyCache 定义并发控制的缓存接口
// 使用有序集合存储槽位,按时间戳清理过期条目
type ConcurrencyCache interface {
// 账号槽位管理
// 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
// 账号等待队列(账号级)
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
// 用户槽位管理
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
// 等待队列计数(只在首次创建时设置 TTL
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
DecrementWaitCount(ctx context.Context, userID int64) error
// 批量负载查询(只读)
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error)
// 清理过期槽位(后台任务)
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
// 启动时清理旧进程遗留槽位与等待计数
CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error
}
var (
requestIDPrefix = initRequestIDPrefix()
requestIDCounter atomic.Uint64
)
func initRequestIDPrefix() string {
b := make([]byte, 8)
if _, err := rand.Read(b); err == nil {
return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36)
}
fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16)
return "r" + strconv.FormatUint(fallback, 36)
}
func RequestIDPrefix() string {
return requestIDPrefix
}
func generateRequestID() string {
seq := requestIDCounter.Add(1)
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
}
func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error {
if s == nil || s.cache == nil {
return nil
}
return s.cache.CleanupStaleProcessSlots(ctx, RequestIDPrefix())
}
const (
// Default extra wait slots beyond concurrency limit
defaultExtraWaitSlots = 20
)
// ConcurrencyService manages concurrent request limiting for accounts and users
type ConcurrencyService struct {
cache ConcurrencyCache
}
// NewConcurrencyService creates a new ConcurrencyService
func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
return &ConcurrencyService{cache: cache}
}
// AcquireResult represents the result of acquiring a concurrency slot
type AcquireResult struct {
Acquired bool
ReleaseFunc func() // Must be called when done (typically via defer)
}
type AccountWithConcurrency struct {
ID int64
MaxConcurrency int
}
type UserWithConcurrency struct {
ID int64
MaxConcurrency int
}
type AccountLoadInfo struct {
AccountID int64
CurrentConcurrency int
WaitingCount int
LoadRate int // 0-100+ (percent)
}
type UserLoadInfo struct {
UserID int64
CurrentConcurrency int
WaitingCount int
LoadRate int // 0-100+ (percent)
}
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
// If the account is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
// If maxConcurrency is 0 or negative, no limit
if maxConcurrency <= 0 {
return &AcquireResult{
Acquired: true,
ReleaseFunc: func() {}, // no-op
}, nil
}
// Generate unique request ID for this slot
requestID := generateRequestID()
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency, requestID)
if err != nil {
return nil, err
}
if acquired {
return &AcquireResult{
Acquired: true,
ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil {
logger.LegacyPrintf("service.concurrency", "Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err)
}
},
}, nil
}
return &AcquireResult{
Acquired: false,
ReleaseFunc: nil,
}, nil
}
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
// If the user is at max concurrency, it waits until a slot is available or timeout.
// Returns a release function that MUST be called when the request completes.
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
// If maxConcurrency is 0 or negative, no limit
if maxConcurrency <= 0 {
return &AcquireResult{
Acquired: true,
ReleaseFunc: func() {}, // no-op
}, nil
}
// Generate unique request ID for this slot
requestID := generateRequestID()
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency, requestID)
if err != nil {
return nil, err
}
if acquired {
return &AcquireResult{
Acquired: true,
ReleaseFunc: func() {
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil {
logger.LegacyPrintf("service.concurrency", "Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err)
}
},
}, nil
}
return &AcquireResult{
Acquired: false,
ReleaseFunc: nil,
}, nil
}
// ============================================
// Wait Queue Count Methods
// ============================================
// IncrementWaitCount attempts to increment the wait queue counter for a user.
// Returns true if successful, false if the wait queue is full.
// maxWait should be user.Concurrency + defaultExtraWaitSlots
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
if s.cache == nil {
// Redis not available, allow request
return true, nil
}
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
if err != nil {
// On error, allow the request to proceed (fail open)
logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for user %d: %v", userID, err)
return true, nil
}
return result, nil
}
// DecrementWaitCount decrements the wait queue counter for a user.
// Should be called when a request completes or exits the wait queue.
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
if s.cache == nil {
return
}
// Use background context to ensure decrement even if original context is cancelled
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for user %d: %v", userID, err)
}
}
// IncrementAccountWaitCount increments the wait queue counter for an account.
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
if s.cache == nil {
return true, nil
}
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
if err != nil {
logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for account %d: %v", accountID, err)
return true, nil
}
return result, nil
}
// DecrementAccountWaitCount decrements the wait queue counter for an account.
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
if s.cache == nil {
return
}
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for account %d: %v", accountID, err)
}
}
// GetAccountWaitingCount gets current wait queue count for an account.
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
if s.cache == nil {
return 0, nil
}
return s.cache.GetAccountWaitingCount(ctx, accountID)
}
// CalculateMaxWait calculates the maximum wait queue size for a user
// maxWait = userConcurrency + defaultExtraWaitSlots
func CalculateMaxWait(userConcurrency int) int {
if userConcurrency <= 0 {
userConcurrency = 1
}
return userConcurrency + defaultExtraWaitSlots
}
// GetAccountsLoadBatch returns load info for multiple accounts.
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
if s.cache == nil {
return map[int64]*AccountLoadInfo{}, nil
}
return s.cache.GetAccountsLoadBatch(ctx, accounts)
}
// GetUsersLoadBatch returns load info for multiple users.
func (s *ConcurrencyService) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
if s.cache == nil {
return map[int64]*UserLoadInfo{}, nil
}
return s.cache.GetUsersLoadBatch(ctx, users)
}
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
if s.cache == nil {
return nil
}
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
}
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
return
}
runCleanup := func() {
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
accounts, err := accountRepo.ListSchedulable(listCtx)
cancel()
if err != nil {
logger.LegacyPrintf("service.concurrency", "Warning: list schedulable accounts failed: %v", err)
return
}
for _, account := range accounts {
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
accountCancel()
if err != nil {
logger.LegacyPrintf("service.concurrency", "Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
}
}
}
go func() {
ticker := time.NewTicker(interval)
defer ticker.Stop()
runCleanup()
for range ticker.C {
runCleanup()
}
}()
}
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
// Returns a map of accountID -> current concurrency count
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
if len(accountIDs) == 0 {
return map[int64]int{}, nil
}
if s.cache == nil {
result := make(map[int64]int, len(accountIDs))
for _, accountID := range accountIDs {
result[accountID] = 0
}
return result, nil
}
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
}

View File

@@ -0,0 +1,337 @@
//go:build unit
package service
import (
"context"
"errors"
"strconv"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩
type stubConcurrencyCacheForTest struct {
acquireResult bool
acquireErr error
releaseErr error
concurrency int
concurrencyErr error
waitAllowed bool
waitErr error
waitCount int
waitCountErr error
loadBatch map[int64]*AccountLoadInfo
loadBatchErr error
usersLoadBatch map[int64]*UserLoadInfo
usersLoadErr error
cleanupErr error
// 记录调用
releasedAccountIDs []int64
releasedRequestIDs []string
}
var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil)
func (c *stubConcurrencyCacheForTest) AcquireAccountSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) {
return c.acquireResult, c.acquireErr
}
func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, accountID int64, requestID string) error {
c.releasedAccountIDs = append(c.releasedAccountIDs, accountID)
c.releasedRequestIDs = append(c.releasedRequestIDs, requestID)
return c.releaseErr
}
func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) {
return c.concurrency, c.concurrencyErr
}
func (c *stubConcurrencyCacheForTest) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
result := make(map[int64]int, len(accountIDs))
for _, accountID := range accountIDs {
if c.concurrencyErr != nil {
return nil, c.concurrencyErr
}
result[accountID] = c.concurrency
}
return result, nil
}
func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
return c.waitAllowed, c.waitErr
}
func (c *stubConcurrencyCacheForTest) DecrementAccountWaitCount(_ context.Context, _ int64) error {
return nil
}
func (c *stubConcurrencyCacheForTest) GetAccountWaitingCount(_ context.Context, _ int64) (int, error) {
return c.waitCount, c.waitCountErr
}
func (c *stubConcurrencyCacheForTest) AcquireUserSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) {
return c.acquireResult, c.acquireErr
}
func (c *stubConcurrencyCacheForTest) ReleaseUserSlot(_ context.Context, _ int64, _ string) error {
return c.releaseErr
}
func (c *stubConcurrencyCacheForTest) GetUserConcurrency(_ context.Context, _ int64) (int, error) {
return c.concurrency, c.concurrencyErr
}
func (c *stubConcurrencyCacheForTest) IncrementWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
return c.waitAllowed, c.waitErr
}
func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ int64) error {
return nil
}
func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
return c.loadBatch, c.loadBatchErr
}
func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
return c.usersLoadBatch, c.usersLoadErr
}
func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
return c.cleanupErr
}
func (c *stubConcurrencyCacheForTest) CleanupStaleProcessSlots(_ context.Context, _ string) error {
return c.cleanupErr
}
type trackingConcurrencyCache struct {
stubConcurrencyCacheForTest
cleanupPrefix string
}
func (c *trackingConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, prefix string) error {
c.cleanupPrefix = prefix
return c.cleanupErr
}
func TestCleanupStaleProcessSlots_NilCache(t *testing.T) {
svc := &ConcurrencyService{cache: nil}
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
}
func TestCleanupStaleProcessSlots_DelegatesPrefix(t *testing.T) {
cache := &trackingConcurrencyCache{}
svc := NewConcurrencyService(cache)
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
require.Equal(t, RequestIDPrefix(), cache.cleanupPrefix)
}
func TestAcquireAccountSlot_Success(t *testing.T) {
cache := &stubConcurrencyCacheForTest{acquireResult: true}
svc := NewConcurrencyService(cache)
result, err := svc.AcquireAccountSlot(context.Background(), 1, 5)
require.NoError(t, err)
require.True(t, result.Acquired)
require.NotNil(t, result.ReleaseFunc)
}
func TestAcquireAccountSlot_Failure(t *testing.T) {
cache := &stubConcurrencyCacheForTest{acquireResult: false}
svc := NewConcurrencyService(cache)
result, err := svc.AcquireAccountSlot(context.Background(), 1, 5)
require.NoError(t, err)
require.False(t, result.Acquired)
require.Nil(t, result.ReleaseFunc)
}
func TestAcquireAccountSlot_UnlimitedConcurrency(t *testing.T) {
svc := NewConcurrencyService(&stubConcurrencyCacheForTest{})
for _, maxConcurrency := range []int{0, -1} {
result, err := svc.AcquireAccountSlot(context.Background(), 1, maxConcurrency)
require.NoError(t, err)
require.True(t, result.Acquired, "maxConcurrency=%d 应无限制通过", maxConcurrency)
require.NotNil(t, result.ReleaseFunc, "ReleaseFunc 应为 no-op 函数")
}
}
func TestAcquireAccountSlot_CacheError(t *testing.T) {
cache := &stubConcurrencyCacheForTest{acquireErr: errors.New("redis down")}
svc := NewConcurrencyService(cache)
result, err := svc.AcquireAccountSlot(context.Background(), 1, 5)
require.Error(t, err)
require.Nil(t, result)
}
func TestAcquireAccountSlot_ReleaseDecrements(t *testing.T) {
cache := &stubConcurrencyCacheForTest{acquireResult: true}
svc := NewConcurrencyService(cache)
result, err := svc.AcquireAccountSlot(context.Background(), 42, 5)
require.NoError(t, err)
require.True(t, result.Acquired)
// 调用 ReleaseFunc 应释放槽位
result.ReleaseFunc()
require.Len(t, cache.releasedAccountIDs, 1)
require.Equal(t, int64(42), cache.releasedAccountIDs[0])
require.Len(t, cache.releasedRequestIDs, 1)
require.NotEmpty(t, cache.releasedRequestIDs[0], "requestID 不应为空")
}
func TestAcquireUserSlot_IndependentFromAccount(t *testing.T) {
cache := &stubConcurrencyCacheForTest{acquireResult: true}
svc := NewConcurrencyService(cache)
// 用户槽位获取应独立于账户槽位
result, err := svc.AcquireUserSlot(context.Background(), 100, 3)
require.NoError(t, err)
require.True(t, result.Acquired)
require.NotNil(t, result.ReleaseFunc)
}
func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) {
svc := NewConcurrencyService(&stubConcurrencyCacheForTest{})
result, err := svc.AcquireUserSlot(context.Background(), 1, 0)
require.NoError(t, err)
require.True(t, result.Acquired)
}
func TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter(t *testing.T) {
id1 := generateRequestID()
id2 := generateRequestID()
require.NotEmpty(t, id1)
require.NotEmpty(t, id2)
p1 := strings.Split(id1, "-")
p2 := strings.Split(id2, "-")
require.Len(t, p1, 2)
require.Len(t, p2, 2)
require.Equal(t, p1[0], p2[0], "同一进程前缀应保持一致")
n1, err := strconv.ParseUint(p1[1], 36, 64)
require.NoError(t, err)
n2, err := strconv.ParseUint(p2[1], 36, 64)
require.NoError(t, err)
require.Equal(t, n1+1, n2, "计数器应单调递增")
}
func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) {
expected := map[int64]*AccountLoadInfo{
1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60},
2: {AccountID: 2, CurrentConcurrency: 5, WaitingCount: 2, LoadRate: 100},
}
cache := &stubConcurrencyCacheForTest{loadBatch: expected}
svc := NewConcurrencyService(cache)
accounts := []AccountWithConcurrency{
{ID: 1, MaxConcurrency: 5},
{ID: 2, MaxConcurrency: 5},
}
result, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
require.NoError(t, err)
require.Equal(t, expected, result)
}
func TestGetAccountsLoadBatch_NilCache(t *testing.T) {
svc := &ConcurrencyService{cache: nil}
result, err := svc.GetAccountsLoadBatch(context.Background(), nil)
require.NoError(t, err)
require.Empty(t, result)
}
func TestIncrementWaitCount_Success(t *testing.T) {
cache := &stubConcurrencyCacheForTest{waitAllowed: true}
svc := NewConcurrencyService(cache)
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
require.NoError(t, err)
require.True(t, allowed)
}
func TestIncrementWaitCount_QueueFull(t *testing.T) {
cache := &stubConcurrencyCacheForTest{waitAllowed: false}
svc := NewConcurrencyService(cache)
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
require.NoError(t, err)
require.False(t, allowed)
}
func TestIncrementWaitCount_FailOpen(t *testing.T) {
// Redis 错误时应 fail-open允许请求通过
cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis timeout")}
svc := NewConcurrencyService(cache)
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
require.NoError(t, err, "Redis 错误不应传播")
require.True(t, allowed, "Redis 错误时应 fail-open")
}
func TestIncrementWaitCount_NilCache(t *testing.T) {
svc := &ConcurrencyService{cache: nil}
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
require.NoError(t, err)
require.True(t, allowed, "nil cache 应 fail-open")
}
func TestCalculateMaxWait(t *testing.T) {
tests := []struct {
concurrency int
expected int
}{
{5, 25}, // 5 + 20
{1, 21}, // 1 + 20
{0, 21}, // min(1) + 20
{-1, 21}, // min(1) + 20
{10, 30}, // 10 + 20
}
for _, tt := range tests {
result := CalculateMaxWait(tt.concurrency)
require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency)
}
}
func TestGetAccountWaitingCount(t *testing.T) {
cache := &stubConcurrencyCacheForTest{waitCount: 5}
svc := NewConcurrencyService(cache)
count, err := svc.GetAccountWaitingCount(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, 5, count)
}
func TestGetAccountWaitingCount_NilCache(t *testing.T) {
svc := &ConcurrencyService{cache: nil}
count, err := svc.GetAccountWaitingCount(context.Background(), 1)
require.NoError(t, err)
require.Equal(t, 0, count)
}
func TestGetAccountConcurrencyBatch(t *testing.T) {
cache := &stubConcurrencyCacheForTest{concurrency: 3}
svc := NewConcurrencyService(cache)
result, err := svc.GetAccountConcurrencyBatch(context.Background(), []int64{1, 2, 3})
require.NoError(t, err)
require.Len(t, result, 3)
for _, id := range []int64{1, 2, 3} {
require.Equal(t, 3, result[id])
}
}
func TestIncrementAccountWaitCount_FailOpen(t *testing.T) {
cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis error")}
svc := NewConcurrencyService(cache)
allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10)
require.NoError(t, err, "Redis 错误不应传播")
require.True(t, allowed, "Redis 错误时应 fail-open")
}
func TestIncrementAccountWaitCount_NilCache(t *testing.T) {
svc := &ConcurrencyService{cache: nil}
allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10)
require.NoError(t, err)
require.True(t, allowed)
}

View File

@@ -0,0 +1,112 @@
package service
import (
"testing"
)
func TestBuildSelectedSet(t *testing.T) {
tests := []struct {
name string
ids []string
wantNil bool
wantSize int
}{
{
name: "nil input returns nil (backward compatible: create all)",
ids: nil,
wantNil: true,
},
{
name: "empty slice returns empty map (create none)",
ids: []string{},
wantNil: false,
wantSize: 0,
},
{
name: "single ID",
ids: []string{"abc-123"},
wantNil: false,
wantSize: 1,
},
{
name: "multiple IDs",
ids: []string{"a", "b", "c"},
wantNil: false,
wantSize: 3,
},
{
name: "duplicate IDs are deduplicated",
ids: []string{"a", "a", "b"},
wantNil: false,
wantSize: 2,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := buildSelectedSet(tt.ids)
if tt.wantNil {
if got != nil {
t.Errorf("buildSelectedSet(%v) = %v, want nil", tt.ids, got)
}
return
}
if got == nil {
t.Fatalf("buildSelectedSet(%v) = nil, want non-nil map", tt.ids)
}
if len(got) != tt.wantSize {
t.Errorf("buildSelectedSet(%v) has %d entries, want %d", tt.ids, len(got), tt.wantSize)
}
// Verify all unique IDs are present
for _, id := range tt.ids {
if _, ok := got[id]; !ok {
t.Errorf("buildSelectedSet(%v) missing key %q", tt.ids, id)
}
}
})
}
}
func TestShouldCreateAccount(t *testing.T) {
tests := []struct {
name string
crsID string
selectedSet map[string]struct{}
want bool
}{
{
name: "nil set allows all (backward compatible)",
crsID: "any-id",
selectedSet: nil,
want: true,
},
{
name: "empty set blocks all",
crsID: "any-id",
selectedSet: map[string]struct{}{},
want: false,
},
{
name: "ID in set is allowed",
crsID: "abc-123",
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
want: true,
},
{
name: "ID not in set is blocked",
crsID: "xyz-789",
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := shouldCreateAccount(tt.crsID, tt.selectedSet)
if got != tt.want {
t.Errorf("shouldCreateAccount(%q, %v) = %v, want %v",
tt.crsID, tt.selectedSet, got, tt.want)
}
})
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,322 @@
package service
import (
"context"
"errors"
"log/slog"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
)
const (
defaultDashboardAggregationTimeout = 2 * time.Minute
defaultDashboardAggregationBackfillTimeout = 30 * time.Minute
dashboardAggregationRetentionInterval = 6 * time.Hour
)
var (
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
errDashboardAggregationRunning = errors.New("聚合作业正在运行")
)
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
type DashboardAggregationRepository interface {
AggregateRange(ctx context.Context, start, end time.Time) error
// RecomputeRange 重新计算指定时间范围内的聚合数据(包含活跃用户等派生表)。
// 设计目的:当 usage_logs 被批量删除/回滚后,确保聚合表可恢复一致性。
RecomputeRange(ctx context.Context, start, end time.Time) error
GetAggregationWatermark(ctx context.Context) (time.Time, error)
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
}
// DashboardAggregationService 负责定时聚合与回填。
type DashboardAggregationService struct {
repo DashboardAggregationRepository
timingWheel *TimingWheelService
cfg config.DashboardAggregationConfig
running int32
lastRetentionCleanup atomic.Value // time.Time
}
// NewDashboardAggregationService 创建聚合服务。
func NewDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
var aggCfg config.DashboardAggregationConfig
if cfg != nil {
aggCfg = cfg.DashboardAgg
}
return &DashboardAggregationService{
repo: repo,
timingWheel: timingWheel,
cfg: aggCfg,
}
}
// Start 启动定时聚合作业(重启生效配置)。
func (s *DashboardAggregationService) Start() {
if s == nil || s.repo == nil || s.timingWheel == nil {
return
}
if !s.cfg.Enabled {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业已禁用")
return
}
interval := time.Duration(s.cfg.IntervalSeconds) * time.Second
if interval <= 0 {
interval = time.Minute
}
if s.cfg.RecomputeDays > 0 {
go s.recomputeRecentDays()
}
s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() {
s.runScheduledAggregation()
})
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
if !s.cfg.BackfillEnabled {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
}
}
// TriggerBackfill 触发回填(异步)。
func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) error {
if s == nil || s.repo == nil {
return errors.New("聚合服务未初始化")
}
if !s.cfg.BackfillEnabled {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填被拒绝: backfill_enabled=false")
return ErrDashboardBackfillDisabled
}
if !end.After(start) {
return errors.New("回填时间范围无效")
}
if s.cfg.BackfillMaxDays > 0 {
maxRange := time.Duration(s.cfg.BackfillMaxDays) * 24 * time.Hour
if end.Sub(start) > maxRange {
return ErrDashboardBackfillTooLarge
}
}
go func() {
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
defer cancel()
if err := s.backfillRange(ctx, start, end); err != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填失败: %v", err)
}
}()
return nil
}
// TriggerRecomputeRange 触发指定范围的重新计算(异步)。
// 与 TriggerBackfill 不同:
// - 不依赖 backfill_enabled这是内部一致性修复
// - 不更新 watermark避免影响正常增量聚合游标
func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time) error {
if s == nil || s.repo == nil {
return errors.New("聚合服务未初始化")
}
if !s.cfg.Enabled {
return errors.New("聚合服务已禁用")
}
if !end.After(start) {
return errors.New("重新计算时间范围无效")
}
go func() {
const maxRetries = 3
for i := 0; i < maxRetries; i++ {
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
err := s.recomputeRange(ctx, start, end)
cancel()
if err == nil {
return
}
if !errors.Is(err, errDashboardAggregationRunning) {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算失败: %v", err)
return
}
time.Sleep(5 * time.Second)
}
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算放弃: 聚合作业持续占用")
}()
return nil
}
func (s *DashboardAggregationService) recomputeRecentDays() {
days := s.cfg.RecomputeDays
if days <= 0 {
return
}
now := time.Now().UTC()
start := now.AddDate(0, 0, -days)
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
defer cancel()
if err := s.backfillRange(ctx, start, now); err != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 启动重算失败: %v", err)
return
}
}
func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, end time.Time) error {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return errDashboardAggregationRunning
}
defer atomic.StoreInt32(&s.running, 0)
jobStart := time.Now().UTC()
if err := s.repo.RecomputeRange(ctx, start, end); err != nil {
return err
}
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)",
start.UTC().Format(time.RFC3339),
end.UTC().Format(time.RFC3339),
time.Since(jobStart).String(),
)
return nil
}
func (s *DashboardAggregationService) runScheduledAggregation() {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return
}
defer atomic.StoreInt32(&s.running, 0)
jobStart := time.Now().UTC()
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationTimeout)
defer cancel()
now := time.Now().UTC()
last, err := s.repo.GetAggregationWatermark(ctx)
if err != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 读取水位失败: %v", err)
last = time.Unix(0, 0).UTC()
}
lookback := time.Duration(s.cfg.LookbackSeconds) * time.Second
epoch := time.Unix(0, 0).UTC()
start := last.Add(-lookback)
if !last.After(epoch) {
retentionDays := s.cfg.Retention.UsageLogsDays
if retentionDays <= 0 {
retentionDays = 1
}
start = truncateToDayUTC(now.AddDate(0, 0, -retentionDays))
} else if start.After(now) {
start = now.Add(-lookback)
}
if err := s.aggregateRange(ctx, start, now); err != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合失败: %v", err)
return
}
updateErr := s.repo.UpdateAggregationWatermark(ctx, now)
if updateErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr)
}
slog.Debug("[DashboardAggregation] 聚合完成",
"start", start.Format(time.RFC3339),
"end", now.Format(time.RFC3339),
"duration", time.Since(jobStart).String(),
"watermark_updated", updateErr == nil,
)
s.maybeCleanupRetention(ctx, now)
}
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
return errDashboardAggregationRunning
}
defer atomic.StoreInt32(&s.running, 0)
jobStart := time.Now().UTC()
startUTC := start.UTC()
endUTC := end.UTC()
if !endUTC.After(startUTC) {
return errors.New("回填时间范围无效")
}
cursor := truncateToDayUTC(startUTC)
for cursor.Before(endUTC) {
windowEnd := cursor.Add(24 * time.Hour)
if windowEnd.After(endUTC) {
windowEnd = endUTC
}
if err := s.aggregateRange(ctx, cursor, windowEnd); err != nil {
return err
}
cursor = windowEnd
}
updateErr := s.repo.UpdateAggregationWatermark(ctx, endUTC)
if updateErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr)
}
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
startUTC.Format(time.RFC3339),
endUTC.Format(time.RFC3339),
time.Since(jobStart).String(),
updateErr == nil,
)
s.maybeCleanupRetention(ctx, endUTC)
return nil
}
func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, end time.Time) error {
if !end.After(start) {
return nil
}
if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 分区检查失败: %v", err)
}
return s.repo.AggregateRange(ctx, start, end)
}
func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, now time.Time) {
lastAny := s.lastRetentionCleanup.Load()
if lastAny != nil {
if last, ok := lastAny.(time.Time); ok && now.Sub(last) < dashboardAggregationRetentionInterval {
return
}
}
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
if aggErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合保留清理失败: %v", aggErr)
}
usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff)
if usageErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
}
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
if dedupErr != nil {
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
}
if aggErr == nil && usageErr == nil && dedupErr == nil {
s.lastRetentionCleanup.Store(now)
}
}
func truncateToDayUTC(t time.Time) time.Time {
t = t.UTC()
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
}

View File

@@ -0,0 +1,168 @@
package service
import (
"context"
"errors"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/stretchr/testify/require"
)
type dashboardAggregationRepoTestStub struct {
aggregateCalls int
recomputeCalls int
cleanupUsageCalls int
cleanupDedupCalls int
ensurePartitionCalls int
lastStart time.Time
lastEnd time.Time
watermark time.Time
aggregateErr error
cleanupAggregatesErr error
cleanupUsageErr error
cleanupDedupErr error
ensurePartitionErr error
}
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
s.aggregateCalls++
s.lastStart = start
s.lastEnd = end
return s.aggregateErr
}
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
s.recomputeCalls++
return s.AggregateRange(ctx, start, end)
}
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
return s.watermark, nil
}
func (s *dashboardAggregationRepoTestStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
return nil
}
func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
return s.cleanupAggregatesErr
}
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
s.cleanupUsageCalls++
return s.cleanupUsageErr
}
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
s.cleanupDedupCalls++
return s.cleanupDedupErr
}
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
s.ensurePartitionCalls++
return s.ensurePartitionErr
}
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{watermark: time.Unix(0, 0).UTC()}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Enabled: true,
IntervalSeconds: 60,
LookbackSeconds: 120,
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.runScheduledAggregation()
require.Equal(t, 1, repo.aggregateCalls)
require.False(t, repo.lastEnd.IsZero())
require.Equal(t, truncateToDayUTC(repo.lastEnd.AddDate(0, 0, -1)), repo.lastStart)
}
func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{cleanupAggregatesErr: errors.New("清理失败")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
require.Nil(t, svc.lastRetentionCleanup.Load())
require.Equal(t, 1, repo.cleanupUsageCalls)
require.Equal(t, 1, repo.cleanupDedupCalls)
}
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
require.Nil(t, svc.lastRetentionCleanup.Load())
require.Equal(t, 1, repo.cleanupDedupCalls)
}
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
Enabled: true,
IntervalSeconds: 60,
LookbackSeconds: 120,
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 1,
UsageBillingDedupDays: 2,
HourlyDays: 1,
DailyDays: 1,
},
},
}
svc.runScheduledAggregation()
require.Equal(t, 1, repo.ensurePartitionCalls)
require.Equal(t, 1, repo.aggregateCalls)
}
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
repo := &dashboardAggregationRepoTestStub{}
svc := &DashboardAggregationService{
repo: repo,
cfg: config.DashboardAggregationConfig{
BackfillEnabled: true,
BackfillMaxDays: 1,
},
}
start := time.Now().AddDate(0, 0, -3)
end := time.Now()
err := svc.TriggerBackfill(start, end)
require.ErrorIs(t, err, ErrDashboardBackfillTooLarge)
require.Equal(t, 0, repo.aggregateCalls)
}

View File

@@ -0,0 +1,352 @@
package service
import (
"context"
"encoding/json"
"errors"
"fmt"
"sync/atomic"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
)
const (
defaultDashboardStatsFreshTTL = 15 * time.Second
defaultDashboardStatsCacheTTL = 30 * time.Second
defaultDashboardStatsRefreshTimeout = 30 * time.Second
)
// ErrDashboardStatsCacheMiss 标记仪表盘缓存未命中。
var ErrDashboardStatsCacheMiss = errors.New("仪表盘缓存未命中")
// DashboardStatsCache 定义仪表盘统计缓存接口。
type DashboardStatsCache interface {
GetDashboardStats(ctx context.Context) (string, error)
SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error
DeleteDashboardStats(ctx context.Context) error
}
type dashboardStatsRangeFetcher interface {
GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error)
}
type dashboardStatsCacheEntry struct {
Stats *usagestats.DashboardStats `json:"stats"`
UpdatedAt int64 `json:"updated_at"`
}
// DashboardService 提供管理员仪表盘统计服务。
type DashboardService struct {
usageRepo UsageLogRepository
aggRepo DashboardAggregationRepository
cache DashboardStatsCache
cacheFreshTTL time.Duration
cacheTTL time.Duration
refreshTimeout time.Duration
refreshing int32
aggEnabled bool
aggInterval time.Duration
aggLookback time.Duration
aggUsageDays int
}
func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregationRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService {
freshTTL := defaultDashboardStatsFreshTTL
cacheTTL := defaultDashboardStatsCacheTTL
refreshTimeout := defaultDashboardStatsRefreshTimeout
aggEnabled := true
aggInterval := time.Minute
aggLookback := 2 * time.Minute
aggUsageDays := 90
if cfg != nil {
if !cfg.Dashboard.Enabled {
cache = nil
}
if cfg.Dashboard.StatsFreshTTLSeconds > 0 {
freshTTL = time.Duration(cfg.Dashboard.StatsFreshTTLSeconds) * time.Second
}
if cfg.Dashboard.StatsTTLSeconds > 0 {
cacheTTL = time.Duration(cfg.Dashboard.StatsTTLSeconds) * time.Second
}
if cfg.Dashboard.StatsRefreshTimeoutSeconds > 0 {
refreshTimeout = time.Duration(cfg.Dashboard.StatsRefreshTimeoutSeconds) * time.Second
}
aggEnabled = cfg.DashboardAgg.Enabled
if cfg.DashboardAgg.IntervalSeconds > 0 {
aggInterval = time.Duration(cfg.DashboardAgg.IntervalSeconds) * time.Second
}
if cfg.DashboardAgg.LookbackSeconds > 0 {
aggLookback = time.Duration(cfg.DashboardAgg.LookbackSeconds) * time.Second
}
if cfg.DashboardAgg.Retention.UsageLogsDays > 0 {
aggUsageDays = cfg.DashboardAgg.Retention.UsageLogsDays
}
}
if aggRepo == nil {
aggEnabled = false
}
return &DashboardService{
usageRepo: usageRepo,
aggRepo: aggRepo,
cache: cache,
cacheFreshTTL: freshTTL,
cacheTTL: cacheTTL,
refreshTimeout: refreshTimeout,
aggEnabled: aggEnabled,
aggInterval: aggInterval,
aggLookback: aggLookback,
aggUsageDays: aggUsageDays,
}
}
func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
if s.cache != nil {
cached, fresh, err := s.getCachedDashboardStats(ctx)
if err == nil && cached != nil {
s.refreshAggregationStaleness(cached)
if !fresh {
s.refreshDashboardStatsAsync()
}
return cached, nil
}
if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) {
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存读取失败: %v", err)
}
}
stats, err := s.refreshDashboardStats(ctx)
if err != nil {
return nil, fmt.Errorf("get dashboard stats: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get usage trend with filters: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get model stats with filters: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
if err != nil {
return nil, fmt.Errorf("get group stats with filters: %w", err)
}
return stats, nil
}
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
data, err := s.cache.GetDashboardStats(ctx)
if err != nil {
return nil, false, err
}
var entry dashboardStatsCacheEntry
if err := json.Unmarshal([]byte(data), &entry); err != nil {
s.evictDashboardStatsCache(err)
return nil, false, ErrDashboardStatsCacheMiss
}
if entry.Stats == nil {
s.evictDashboardStatsCache(errors.New("仪表盘缓存缺少统计数据"))
return nil, false, ErrDashboardStatsCacheMiss
}
age := time.Since(time.Unix(entry.UpdatedAt, 0))
return entry.Stats, age <= s.cacheFreshTTL, nil
}
func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
stats, err := s.fetchDashboardStats(ctx)
if err != nil {
return nil, err
}
s.applyAggregationStatus(ctx, stats)
cacheCtx, cancel := s.cacheOperationContext()
defer cancel()
s.saveDashboardStatsCache(cacheCtx, stats)
return stats, nil
}
func (s *DashboardService) refreshDashboardStatsAsync() {
if s.cache == nil {
return
}
if !atomic.CompareAndSwapInt32(&s.refreshing, 0, 1) {
return
}
go func() {
defer atomic.StoreInt32(&s.refreshing, 0)
ctx, cancel := context.WithTimeout(context.Background(), s.refreshTimeout)
defer cancel()
stats, err := s.fetchDashboardStats(ctx)
if err != nil {
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异步刷新失败: %v", err)
return
}
s.applyAggregationStatus(ctx, stats)
cacheCtx, cancel := s.cacheOperationContext()
defer cancel()
s.saveDashboardStatsCache(cacheCtx, stats)
}()
}
func (s *DashboardService) fetchDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
if !s.aggEnabled {
if fetcher, ok := s.usageRepo.(dashboardStatsRangeFetcher); ok {
now := time.Now().UTC()
start := truncateToDayUTC(now.AddDate(0, 0, -s.aggUsageDays))
return fetcher.GetDashboardStatsWithRange(ctx, start, now)
}
}
return s.usageRepo.GetDashboardStats(ctx)
}
func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *usagestats.DashboardStats) {
if s.cache == nil || stats == nil {
return
}
entry := dashboardStatsCacheEntry{
Stats: stats,
UpdatedAt: time.Now().Unix(),
}
data, err := json.Marshal(entry)
if err != nil {
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存序列化失败: %v", err)
return
}
if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil {
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存写入失败: %v", err)
}
}
func (s *DashboardService) evictDashboardStatsCache(reason error) {
if s.cache == nil {
return
}
cacheCtx, cancel := s.cacheOperationContext()
defer cancel()
if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil {
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存清理失败: %v", err)
}
if reason != nil {
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异常,已清理: %v", reason)
}
}
func (s *DashboardService) cacheOperationContext() (context.Context, context.CancelFunc) {
return context.WithTimeout(context.Background(), s.refreshTimeout)
}
func (s *DashboardService) applyAggregationStatus(ctx context.Context, stats *usagestats.DashboardStats) {
if stats == nil {
return
}
updatedAt := s.fetchAggregationUpdatedAt(ctx)
stats.StatsUpdatedAt = updatedAt.UTC().Format(time.RFC3339)
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
}
func (s *DashboardService) refreshAggregationStaleness(stats *usagestats.DashboardStats) {
if stats == nil {
return
}
updatedAt := parseStatsUpdatedAt(stats.StatsUpdatedAt)
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
}
func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.Time {
if s.aggRepo == nil {
return time.Unix(0, 0).UTC()
}
updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx)
if err != nil {
logger.LegacyPrintf("service.dashboard", "[Dashboard] 读取聚合水位失败: %v", err)
return time.Unix(0, 0).UTC()
}
if updatedAt.IsZero() {
return time.Unix(0, 0).UTC()
}
return updatedAt.UTC()
}
func (s *DashboardService) isAggregationStale(updatedAt, now time.Time) bool {
if !s.aggEnabled {
return true
}
epoch := time.Unix(0, 0).UTC()
if !updatedAt.After(epoch) {
return true
}
threshold := s.aggInterval + s.aggLookback
return now.Sub(updatedAt) > threshold
}
func parseStatsUpdatedAt(raw string) time.Time {
if raw == "" {
return time.Unix(0, 0).UTC()
}
parsed, err := time.Parse(time.RFC3339, raw)
if err != nil {
return time.Unix(0, 0).UTC()
}
return parsed.UTC()
}
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {
return nil, fmt.Errorf("get api key usage trend: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
trend, err := s.usageRepo.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
if err != nil {
return nil, fmt.Errorf("get user usage trend: %w", err)
}
return trend, nil
}
func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
ranking, err := s.usageRepo.GetUserSpendingRanking(ctx, startTime, endTime, limit)
if err != nil {
return nil, fmt.Errorf("get user spending ranking: %w", err)
}
return ranking, nil
}
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get batch user usage stats: %w", err)
}
return stats, nil
}
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
if err != nil {
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
}
return stats, nil
}

View File

@@ -0,0 +1,395 @@
package service
import (
"context"
"encoding/json"
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/Wei-Shaw/sub2api/internal/config"
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
"github.com/stretchr/testify/require"
)
type usageRepoStub struct {
UsageLogRepository
stats *usagestats.DashboardStats
rangeStats *usagestats.DashboardStats
err error
rangeErr error
calls int32
rangeCalls int32
rangeStart time.Time
rangeEnd time.Time
onCall chan struct{}
}
func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
atomic.AddInt32(&s.calls, 1)
if s.onCall != nil {
select {
case s.onCall <- struct{}{}:
default:
}
}
if s.err != nil {
return nil, s.err
}
return s.stats, nil
}
func (s *usageRepoStub) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) {
atomic.AddInt32(&s.rangeCalls, 1)
s.rangeStart = start
s.rangeEnd = end
if s.rangeErr != nil {
return nil, s.rangeErr
}
if s.rangeStats != nil {
return s.rangeStats, nil
}
return s.stats, nil
}
type dashboardCacheStub struct {
get func(ctx context.Context) (string, error)
set func(ctx context.Context, data string, ttl time.Duration) error
del func(ctx context.Context) error
getCalls int32
setCalls int32
delCalls int32
lastSetMu sync.Mutex
lastSet string
}
func (c *dashboardCacheStub) GetDashboardStats(ctx context.Context) (string, error) {
atomic.AddInt32(&c.getCalls, 1)
if c.get != nil {
return c.get(ctx)
}
return "", ErrDashboardStatsCacheMiss
}
func (c *dashboardCacheStub) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
atomic.AddInt32(&c.setCalls, 1)
c.lastSetMu.Lock()
c.lastSet = data
c.lastSetMu.Unlock()
if c.set != nil {
return c.set(ctx, data, ttl)
}
return nil
}
func (c *dashboardCacheStub) DeleteDashboardStats(ctx context.Context) error {
atomic.AddInt32(&c.delCalls, 1)
if c.del != nil {
return c.del(ctx)
}
return nil
}
type dashboardAggregationRepoStub struct {
watermark time.Time
err error
}
func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
if s.err != nil {
return time.Time{}, s.err
}
return s.watermark, nil
}
func (s *dashboardAggregationRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
return nil
}
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
return nil
}
func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntry {
t.Helper()
c.lastSetMu.Lock()
data := c.lastSet
c.lastSetMu.Unlock()
var entry dashboardStatsCacheEntry
err := json.Unmarshal([]byte(data), &entry)
require.NoError(t, err)
return entry
}
func TestDashboardService_CacheHitFresh(t *testing.T) {
stats := &usagestats.DashboardStats{
TotalUsers: 10,
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
StatsStale: true,
}
entry := dashboardStatsCacheEntry{
Stats: stats,
UpdatedAt: time.Now().Unix(),
}
payload, err := json.Marshal(entry)
require.NoError(t, err)
cache := &dashboardCacheStub{
get: func(ctx context.Context) (string, error) {
return string(payload), nil
},
}
repo := &usageRepoStub{
stats: &usagestats.DashboardStats{TotalUsers: 99},
}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: true},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
},
}
svc := NewDashboardService(repo, aggRepo, cache, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, stats, got)
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
}
func TestDashboardService_CacheMiss_StoresCache(t *testing.T) {
stats := &usagestats.DashboardStats{
TotalUsers: 7,
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
StatsStale: true,
}
cache := &dashboardCacheStub{
get: func(ctx context.Context) (string, error) {
return "", ErrDashboardStatsCacheMiss
},
}
repo := &usageRepoStub{stats: stats}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: true},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
},
}
svc := NewDashboardService(repo, aggRepo, cache, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, stats, got)
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
require.Equal(t, int32(1), atomic.LoadInt32(&cache.setCalls))
entry := cache.readLastEntry(t)
require.Equal(t, stats, entry.Stats)
require.WithinDuration(t, time.Now(), time.Unix(entry.UpdatedAt, 0), time.Second)
}
func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) {
stats := &usagestats.DashboardStats{
TotalUsers: 3,
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
StatsStale: true,
}
cache := &dashboardCacheStub{
get: func(ctx context.Context) (string, error) {
return "", nil
},
}
repo := &usageRepoStub{stats: stats}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: false},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
},
}
svc := NewDashboardService(repo, aggRepo, cache, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, stats, got)
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalls))
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
}
func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) {
staleStats := &usagestats.DashboardStats{
TotalUsers: 11,
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
StatsStale: true,
}
entry := dashboardStatsCacheEntry{
Stats: staleStats,
UpdatedAt: time.Now().Add(-defaultDashboardStatsFreshTTL * 2).Unix(),
}
payload, err := json.Marshal(entry)
require.NoError(t, err)
cache := &dashboardCacheStub{
get: func(ctx context.Context) (string, error) {
return string(payload), nil
},
}
refreshCh := make(chan struct{}, 1)
repo := &usageRepoStub{
stats: &usagestats.DashboardStats{TotalUsers: 22},
onCall: refreshCh,
}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: true},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
},
}
svc := NewDashboardService(repo, aggRepo, cache, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, staleStats, got)
select {
case <-refreshCh:
case <-time.After(1 * time.Second):
t.Fatal("等待异步刷新超时")
}
require.Eventually(t, func() bool {
return atomic.LoadInt32(&cache.setCalls) >= 1
}, 1*time.Second, 10*time.Millisecond)
}
func TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) {
cache := &dashboardCacheStub{
get: func(ctx context.Context) (string, error) {
return "not-json", nil
},
}
stats := &usagestats.DashboardStats{TotalUsers: 9}
repo := &usageRepoStub{stats: stats}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: true},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
},
}
svc := NewDashboardService(repo, aggRepo, cache, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, stats, got)
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
}
func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) {
cache := &dashboardCacheStub{
get: func(ctx context.Context) (string, error) {
return "not-json", nil
},
}
repo := &usageRepoStub{err: errors.New("db down")}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: true},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
},
}
svc := NewDashboardService(repo, aggRepo, cache, cfg)
_, err := svc.GetDashboardStats(context.Background())
require.Error(t, err)
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
}
func TestDashboardService_StatsUpdatedAtEpochWhenMissing(t *testing.T) {
stats := &usagestats.DashboardStats{}
repo := &usageRepoStub{stats: stats}
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: false}}
svc := NewDashboardService(repo, aggRepo, nil, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, "1970-01-01T00:00:00Z", got.StatsUpdatedAt)
require.True(t, got.StatsStale)
}
func TestDashboardService_StatsStaleFalseWhenFresh(t *testing.T) {
aggNow := time.Now().UTC().Truncate(time.Second)
stats := &usagestats.DashboardStats{}
repo := &usageRepoStub{stats: stats}
aggRepo := &dashboardAggregationRepoStub{watermark: aggNow}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: false},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: true,
IntervalSeconds: 60,
LookbackSeconds: 120,
},
}
svc := NewDashboardService(repo, aggRepo, nil, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, aggNow.Format(time.RFC3339), got.StatsUpdatedAt)
require.False(t, got.StatsStale)
}
func TestDashboardService_AggDisabled_UsesUsageLogsFallback(t *testing.T) {
expected := &usagestats.DashboardStats{TotalUsers: 42}
repo := &usageRepoStub{
rangeStats: expected,
err: errors.New("should not call aggregated stats"),
}
cfg := &config.Config{
Dashboard: config.DashboardCacheConfig{Enabled: false},
DashboardAgg: config.DashboardAggregationConfig{
Enabled: false,
Retention: config.DashboardAggregationRetentionConfig{
UsageLogsDays: 7,
},
},
}
svc := NewDashboardService(repo, nil, nil, cfg)
got, err := svc.GetDashboardStats(context.Background())
require.NoError(t, err)
require.Equal(t, int64(42), got.TotalUsers)
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
require.Equal(t, int32(1), atomic.LoadInt32(&repo.rangeCalls))
require.False(t, repo.rangeEnd.IsZero())
require.Equal(t, truncateToDayUTC(repo.rangeEnd.AddDate(0, 0, -7)), repo.rangeStart)
}

View File

@@ -0,0 +1,252 @@
package service
import "context"
type DataManagementPostgresConfig struct {
Host string `json:"host"`
Port int32 `json:"port"`
User string `json:"user"`
Password string `json:"password,omitempty"`
PasswordConfigured bool `json:"password_configured"`
Database string `json:"database"`
SSLMode string `json:"ssl_mode"`
ContainerName string `json:"container_name"`
}
type DataManagementRedisConfig struct {
Addr string `json:"addr"`
Username string `json:"username"`
Password string `json:"password,omitempty"`
PasswordConfigured bool `json:"password_configured"`
DB int32 `json:"db"`
ContainerName string `json:"container_name"`
}
type DataManagementS3Config struct {
Enabled bool `json:"enabled"`
Endpoint string `json:"endpoint"`
Region string `json:"region"`
Bucket string `json:"bucket"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key,omitempty"`
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
Prefix string `json:"prefix"`
ForcePathStyle bool `json:"force_path_style"`
UseSSL bool `json:"use_ssl"`
}
type DataManagementConfig struct {
SourceMode string `json:"source_mode"`
BackupRoot string `json:"backup_root"`
SQLitePath string `json:"sqlite_path,omitempty"`
RetentionDays int32 `json:"retention_days"`
KeepLast int32 `json:"keep_last"`
ActivePostgresID string `json:"active_postgres_profile_id"`
ActiveRedisID string `json:"active_redis_profile_id"`
Postgres DataManagementPostgresConfig `json:"postgres"`
Redis DataManagementRedisConfig `json:"redis"`
S3 DataManagementS3Config `json:"s3"`
ActiveS3ProfileID string `json:"active_s3_profile_id"`
}
type DataManagementTestS3Result struct {
OK bool `json:"ok"`
Message string `json:"message"`
}
type DataManagementCreateBackupJobInput struct {
BackupType string
UploadToS3 bool
TriggeredBy string
IdempotencyKey string
S3ProfileID string
PostgresID string
RedisID string
}
type DataManagementListBackupJobsInput struct {
PageSize int32
PageToken string
Status string
BackupType string
}
type DataManagementArtifactInfo struct {
LocalPath string `json:"local_path"`
SizeBytes int64 `json:"size_bytes"`
SHA256 string `json:"sha256"`
}
type DataManagementS3ObjectInfo struct {
Bucket string `json:"bucket"`
Key string `json:"key"`
ETag string `json:"etag"`
}
type DataManagementBackupJob struct {
JobID string `json:"job_id"`
BackupType string `json:"backup_type"`
Status string `json:"status"`
TriggeredBy string `json:"triggered_by"`
IdempotencyKey string `json:"idempotency_key,omitempty"`
UploadToS3 bool `json:"upload_to_s3"`
S3ProfileID string `json:"s3_profile_id,omitempty"`
PostgresID string `json:"postgres_profile_id,omitempty"`
RedisID string `json:"redis_profile_id,omitempty"`
StartedAt string `json:"started_at,omitempty"`
FinishedAt string `json:"finished_at,omitempty"`
ErrorMessage string `json:"error_message,omitempty"`
Artifact DataManagementArtifactInfo `json:"artifact"`
S3Object DataManagementS3ObjectInfo `json:"s3"`
}
type DataManagementSourceProfile struct {
SourceType string `json:"source_type"`
ProfileID string `json:"profile_id"`
Name string `json:"name"`
IsActive bool `json:"is_active"`
Config DataManagementSourceConfig `json:"config"`
PasswordConfigured bool `json:"password_configured"`
CreatedAt string `json:"created_at,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
}
type DataManagementSourceConfig struct {
Host string `json:"host"`
Port int32 `json:"port"`
User string `json:"user"`
Password string `json:"password,omitempty"`
Database string `json:"database"`
SSLMode string `json:"ssl_mode"`
Addr string `json:"addr"`
Username string `json:"username"`
DB int32 `json:"db"`
ContainerName string `json:"container_name"`
}
type DataManagementCreateSourceProfileInput struct {
SourceType string
ProfileID string
Name string
Config DataManagementSourceConfig
SetActive bool
}
type DataManagementUpdateSourceProfileInput struct {
SourceType string
ProfileID string
Name string
Config DataManagementSourceConfig
}
type DataManagementS3Profile struct {
ProfileID string `json:"profile_id"`
Name string `json:"name"`
IsActive bool `json:"is_active"`
S3 DataManagementS3Config `json:"s3"`
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
CreatedAt string `json:"created_at,omitempty"`
UpdatedAt string `json:"updated_at,omitempty"`
}
type DataManagementCreateS3ProfileInput struct {
ProfileID string
Name string
S3 DataManagementS3Config
SetActive bool
}
type DataManagementUpdateS3ProfileInput struct {
ProfileID string
Name string
S3 DataManagementS3Config
}
type DataManagementListBackupJobsResult struct {
Items []DataManagementBackupJob `json:"items"`
NextPageToken string `json:"next_page_token,omitempty"`
}
func (s *DataManagementService) GetConfig(ctx context.Context) (DataManagementConfig, error) {
_ = ctx
return DataManagementConfig{}, s.deprecatedError()
}
func (s *DataManagementService) UpdateConfig(ctx context.Context, cfg DataManagementConfig) (DataManagementConfig, error) {
_, _ = ctx, cfg
return DataManagementConfig{}, s.deprecatedError()
}
func (s *DataManagementService) ListSourceProfiles(ctx context.Context, sourceType string) ([]DataManagementSourceProfile, error) {
_, _ = ctx, sourceType
return nil, s.deprecatedError()
}
func (s *DataManagementService) CreateSourceProfile(ctx context.Context, input DataManagementCreateSourceProfileInput) (DataManagementSourceProfile, error) {
_, _ = ctx, input
return DataManagementSourceProfile{}, s.deprecatedError()
}
func (s *DataManagementService) UpdateSourceProfile(ctx context.Context, input DataManagementUpdateSourceProfileInput) (DataManagementSourceProfile, error) {
_, _ = ctx, input
return DataManagementSourceProfile{}, s.deprecatedError()
}
func (s *DataManagementService) DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error {
_, _, _ = ctx, sourceType, profileID
return s.deprecatedError()
}
func (s *DataManagementService) SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (DataManagementSourceProfile, error) {
_, _, _ = ctx, sourceType, profileID
return DataManagementSourceProfile{}, s.deprecatedError()
}
func (s *DataManagementService) ValidateS3(ctx context.Context, cfg DataManagementS3Config) (DataManagementTestS3Result, error) {
_, _ = ctx, cfg
return DataManagementTestS3Result{}, s.deprecatedError()
}
func (s *DataManagementService) ListS3Profiles(ctx context.Context) ([]DataManagementS3Profile, error) {
_ = ctx
return nil, s.deprecatedError()
}
func (s *DataManagementService) CreateS3Profile(ctx context.Context, input DataManagementCreateS3ProfileInput) (DataManagementS3Profile, error) {
_, _ = ctx, input
return DataManagementS3Profile{}, s.deprecatedError()
}
func (s *DataManagementService) UpdateS3Profile(ctx context.Context, input DataManagementUpdateS3ProfileInput) (DataManagementS3Profile, error) {
_, _ = ctx, input
return DataManagementS3Profile{}, s.deprecatedError()
}
func (s *DataManagementService) DeleteS3Profile(ctx context.Context, profileID string) error {
_, _ = ctx, profileID
return s.deprecatedError()
}
func (s *DataManagementService) SetActiveS3Profile(ctx context.Context, profileID string) (DataManagementS3Profile, error) {
_, _ = ctx, profileID
return DataManagementS3Profile{}, s.deprecatedError()
}
func (s *DataManagementService) CreateBackupJob(ctx context.Context, input DataManagementCreateBackupJobInput) (DataManagementBackupJob, error) {
_, _ = ctx, input
return DataManagementBackupJob{}, s.deprecatedError()
}
func (s *DataManagementService) ListBackupJobs(ctx context.Context, input DataManagementListBackupJobsInput) (DataManagementListBackupJobsResult, error) {
_, _ = ctx, input
return DataManagementListBackupJobsResult{}, s.deprecatedError()
}
func (s *DataManagementService) GetBackupJob(ctx context.Context, jobID string) (DataManagementBackupJob, error) {
_, _ = ctx, jobID
return DataManagementBackupJob{}, s.deprecatedError()
}
func (s *DataManagementService) deprecatedError() error {
return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
}

Some files were not shown because too many files have changed in this diff Show More