chore: initial snapshot for gitea/github upload
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,62 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccount_IsAnthropicAPIKeyPassthroughEnabled(t *testing.T) {
|
||||
t.Run("Anthropic API Key 开启", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"anthropic_passthrough": true,
|
||||
},
|
||||
}
|
||||
require.True(t, account.IsAnthropicAPIKeyPassthroughEnabled())
|
||||
})
|
||||
|
||||
t.Run("Anthropic API Key 关闭", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"anthropic_passthrough": false,
|
||||
},
|
||||
}
|
||||
require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled())
|
||||
})
|
||||
|
||||
t.Run("字段类型非法默认关闭", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"anthropic_passthrough": "true",
|
||||
},
|
||||
}
|
||||
require.False(t, account.IsAnthropicAPIKeyPassthroughEnabled())
|
||||
})
|
||||
|
||||
t.Run("非 Anthropic API Key 账号始终关闭", func(t *testing.T) {
|
||||
oauth := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"anthropic_passthrough": true,
|
||||
},
|
||||
}
|
||||
require.False(t, oauth.IsAnthropicAPIKeyPassthroughEnabled())
|
||||
|
||||
openai := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"anthropic_passthrough": true,
|
||||
},
|
||||
}
|
||||
require.False(t, openai.IsAnthropicAPIKeyPassthroughEnabled())
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,160 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetBaseURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account Account
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "non-apikey type returns empty",
|
||||
account: Account{
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAnthropic,
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "apikey without base_url returns default anthropic",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAnthropic,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: "https://api.anthropic.com",
|
||||
},
|
||||
{
|
||||
name: "apikey with custom base_url",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAnthropic,
|
||||
Credentials: map[string]any{"base_url": "https://custom.example.com"},
|
||||
},
|
||||
expected: "https://custom.example.com",
|
||||
},
|
||||
{
|
||||
name: "antigravity apikey auto-appends /antigravity",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
|
||||
},
|
||||
expected: "https://upstream.example.com/antigravity",
|
||||
},
|
||||
{
|
||||
name: "antigravity apikey trims trailing slash before appending",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{"base_url": "https://upstream.example.com/"},
|
||||
},
|
||||
expected: "https://upstream.example.com/antigravity",
|
||||
},
|
||||
{
|
||||
name: "antigravity non-apikey returns empty",
|
||||
account: Account{
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.account.GetBaseURL()
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetBaseURL() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetGeminiBaseURL(t *testing.T) {
|
||||
const defaultGeminiURL = "https://generativelanguage.googleapis.com"
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
account Account
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "apikey without base_url returns default",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: defaultGeminiURL,
|
||||
},
|
||||
{
|
||||
name: "apikey with custom base_url",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
Credentials: map[string]any{"base_url": "https://custom-gemini.example.com"},
|
||||
},
|
||||
expected: "https://custom-gemini.example.com",
|
||||
},
|
||||
{
|
||||
name: "antigravity apikey auto-appends /antigravity",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
|
||||
},
|
||||
expected: "https://upstream.example.com/antigravity",
|
||||
},
|
||||
{
|
||||
name: "antigravity apikey trims trailing slash",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{"base_url": "https://upstream.example.com/"},
|
||||
},
|
||||
expected: "https://upstream.example.com/antigravity",
|
||||
},
|
||||
{
|
||||
name: "antigravity oauth does NOT append /antigravity",
|
||||
account: Account{
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{"base_url": "https://upstream.example.com"},
|
||||
},
|
||||
expected: "https://upstream.example.com",
|
||||
},
|
||||
{
|
||||
name: "oauth without base_url returns default",
|
||||
account: Account{
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: defaultGeminiURL,
|
||||
},
|
||||
{
|
||||
name: "nil credentials returns default",
|
||||
account: Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformGemini,
|
||||
},
|
||||
expected: defaultGeminiURL,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := tt.account.GetGeminiBaseURL(defaultGeminiURL)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetGeminiBaseURL() = %q, want %q", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccount_BillingRateMultiplier_DefaultsToOneWhenNil(t *testing.T) {
|
||||
var a Account
|
||||
require.NoError(t, json.Unmarshal([]byte(`{"id":1,"name":"acc","status":"active"}`), &a))
|
||||
require.Nil(t, a.RateMultiplier)
|
||||
require.Equal(t, 1.0, a.BillingRateMultiplier())
|
||||
}
|
||||
|
||||
func TestAccount_BillingRateMultiplier_AllowsZero(t *testing.T) {
|
||||
v := 0.0
|
||||
a := Account{RateMultiplier: &v}
|
||||
require.Equal(t, 0.0, a.BillingRateMultiplier())
|
||||
}
|
||||
|
||||
func TestAccount_BillingRateMultiplier_NegativeFallsBackToOne(t *testing.T) {
|
||||
v := -1.0
|
||||
a := Account{RateMultiplier: &v}
|
||||
require.Equal(t, 1.0, a.BillingRateMultiplier())
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// AccountExpiryService periodically pauses expired accounts when auto-pause is enabled.
|
||||
type AccountExpiryService struct {
|
||||
accountRepo AccountRepository
|
||||
interval time.Duration
|
||||
stopCh chan struct{}
|
||||
stopOnce sync.Once
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func NewAccountExpiryService(accountRepo AccountRepository, interval time.Duration) *AccountExpiryService {
|
||||
return &AccountExpiryService{
|
||||
accountRepo: accountRepo,
|
||||
interval: interval,
|
||||
stopCh: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountExpiryService) Start() {
|
||||
if s == nil || s.accountRepo == nil || s.interval <= 0 {
|
||||
return
|
||||
}
|
||||
s.wg.Add(1)
|
||||
go func() {
|
||||
defer s.wg.Done()
|
||||
ticker := time.NewTicker(s.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
s.runOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
s.runOnce()
|
||||
case <-s.stopCh:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *AccountExpiryService) Stop() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.stopOnce.Do(func() {
|
||||
close(s.stopCh)
|
||||
})
|
||||
s.wg.Wait()
|
||||
}
|
||||
|
||||
func (s *AccountExpiryService) runOnce() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
updated, err := s.accountRepo.AutoPauseExpiredAccounts(ctx, time.Now())
|
||||
if err != nil {
|
||||
log.Printf("[AccountExpiry] Auto pause expired accounts failed: %v", err)
|
||||
return
|
||||
}
|
||||
if updated > 0 {
|
||||
log.Printf("[AccountExpiry] Auto paused %d expired accounts", updated)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
type AccountGroup struct {
|
||||
AccountID int64
|
||||
GroupID int64
|
||||
Priority int
|
||||
CreatedAt time.Time
|
||||
|
||||
Account *Account
|
||||
Group *Group
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccount_IsInterceptWarmupEnabled(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
credentials map[string]any
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil credentials",
|
||||
credentials: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "empty map",
|
||||
credentials: map[string]any{},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "field not present",
|
||||
credentials: map[string]any{"access_token": "tok"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "field is true",
|
||||
credentials: map[string]any{"intercept_warmup_requests": true},
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "field is false",
|
||||
credentials: map[string]any{"intercept_warmup_requests": false},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "field is string true",
|
||||
credentials: map[string]any{"intercept_warmup_requests": "true"},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "field is int 1",
|
||||
credentials: map[string]any{"intercept_warmup_requests": 1},
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "field is nil",
|
||||
credentials: map[string]any{"intercept_warmup_requests": nil},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &Account{Credentials: tt.credentials}
|
||||
result := a.IsInterceptWarmupEnabled()
|
||||
require.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func intPtrHelper(v int) *int { return &v }
|
||||
|
||||
func TestEffectiveLoadFactor_NilAccount(t *testing.T) {
|
||||
var a *Account
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NilLoadFactor_PositiveConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 5}
|
||||
require.Equal(t, 5, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NilLoadFactor_ZeroConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 0}
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_PositiveLoadFactor(t *testing.T) {
|
||||
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(20)}
|
||||
require.Equal(t, 20, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_ZeroLoadFactor_FallbackToConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 5, LoadFactor: intPtrHelper(0)}
|
||||
require.Equal(t, 5, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_NegativeLoadFactor_FallbackToConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 3, LoadFactor: intPtrHelper(-1)}
|
||||
require.Equal(t, 3, a.EffectiveLoadFactor())
|
||||
}
|
||||
|
||||
func TestEffectiveLoadFactor_ZeroLoadFactor_ZeroConcurrency(t *testing.T) {
|
||||
a := &Account{Concurrency: 0, LoadFactor: intPtrHelper(0)}
|
||||
require.Equal(t, 1, a.EffectiveLoadFactor())
|
||||
}
|
||||
@@ -0,0 +1,315 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAccount_IsOpenAIPassthroughEnabled(t *testing.T) {
|
||||
t.Run("新字段开启", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"openai_passthrough": true,
|
||||
},
|
||||
}
|
||||
require.True(t, account.IsOpenAIPassthroughEnabled())
|
||||
})
|
||||
|
||||
t.Run("兼容旧字段", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_passthrough": true,
|
||||
},
|
||||
}
|
||||
require.True(t, account.IsOpenAIPassthroughEnabled())
|
||||
})
|
||||
|
||||
t.Run("非OpenAI账号始终关闭", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_passthrough": true,
|
||||
},
|
||||
}
|
||||
require.False(t, account.IsOpenAIPassthroughEnabled())
|
||||
})
|
||||
|
||||
t.Run("空额外配置默认关闭", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
require.False(t, account.IsOpenAIPassthroughEnabled())
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccount_IsOpenAIOAuthPassthroughEnabled(t *testing.T) {
|
||||
t.Run("仅OAuth类型允许返回开启", func(t *testing.T) {
|
||||
oauthAccount := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_passthrough": true,
|
||||
},
|
||||
}
|
||||
require.True(t, oauthAccount.IsOpenAIOAuthPassthroughEnabled())
|
||||
|
||||
apiKeyAccount := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"openai_passthrough": true,
|
||||
},
|
||||
}
|
||||
require.False(t, apiKeyAccount.IsOpenAIOAuthPassthroughEnabled())
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccount_IsCodexCLIOnlyEnabled(t *testing.T) {
|
||||
t.Run("OpenAI OAuth 开启", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"codex_cli_only": true,
|
||||
},
|
||||
}
|
||||
require.True(t, account.IsCodexCLIOnlyEnabled())
|
||||
})
|
||||
|
||||
t.Run("OpenAI OAuth 关闭", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"codex_cli_only": false,
|
||||
},
|
||||
}
|
||||
require.False(t, account.IsCodexCLIOnlyEnabled())
|
||||
})
|
||||
|
||||
t.Run("字段缺失默认关闭", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{},
|
||||
}
|
||||
require.False(t, account.IsCodexCLIOnlyEnabled())
|
||||
})
|
||||
|
||||
t.Run("类型非法默认关闭", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"codex_cli_only": "true",
|
||||
},
|
||||
}
|
||||
require.False(t, account.IsCodexCLIOnlyEnabled())
|
||||
})
|
||||
|
||||
t.Run("非 OAuth 账号始终关闭", func(t *testing.T) {
|
||||
apiKeyAccount := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"codex_cli_only": true,
|
||||
},
|
||||
}
|
||||
require.False(t, apiKeyAccount.IsCodexCLIOnlyEnabled())
|
||||
|
||||
otherPlatform := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"codex_cli_only": true,
|
||||
},
|
||||
}
|
||||
require.False(t, otherPlatform.IsCodexCLIOnlyEnabled())
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccount_IsOpenAIResponsesWebSocketV2Enabled(t *testing.T) {
|
||||
t.Run("OAuth使用OAuth专用开关", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
|
||||
})
|
||||
|
||||
t.Run("API Key使用API Key专用开关", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
|
||||
})
|
||||
|
||||
t.Run("OAuth账号不会读取API Key专用开关", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
|
||||
})
|
||||
|
||||
t.Run("分类型新键优先于兼容键", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_enabled": false,
|
||||
"responses_websockets_v2_enabled": true,
|
||||
"openai_ws_enabled": true,
|
||||
},
|
||||
}
|
||||
require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
|
||||
})
|
||||
|
||||
t.Run("分类型键缺失时回退兼容键", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.True(t, account.IsOpenAIResponsesWebSocketV2Enabled())
|
||||
})
|
||||
|
||||
t.Run("非OpenAI账号默认关闭", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.False(t, account.IsOpenAIResponsesWebSocketV2Enabled())
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccount_ResolveOpenAIResponsesWebSocketV2Mode(t *testing.T) {
|
||||
t.Run("default fallback to ctx_pool", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(""))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode("invalid"))
|
||||
})
|
||||
|
||||
t.Run("oauth mode field has highest priority", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModePassthrough,
|
||||
"openai_oauth_responses_websockets_v2_enabled": false,
|
||||
"responses_websockets_v2_enabled": false,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModePassthrough, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||
})
|
||||
|
||||
t.Run("legacy enabled maps to ctx_pool", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
})
|
||||
|
||||
t.Run("shared/dedicated mode strings are compatible with ctx_pool", func(t *testing.T) {
|
||||
shared := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeShared,
|
||||
},
|
||||
}
|
||||
dedicated := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeShared, shared.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeDedicated, dedicated.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeOff))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeShared))
|
||||
require.Equal(t, OpenAIWSIngressModeCtxPool, normalizeOpenAIWSIngressDefaultMode(OpenAIWSIngressModeDedicated))
|
||||
})
|
||||
|
||||
t.Run("legacy disabled maps to off", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeAPIKey,
|
||||
Extra: map[string]any{
|
||||
"openai_apikey_responses_websockets_v2_enabled": false,
|
||||
"responses_websockets_v2_enabled": true,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeCtxPool))
|
||||
})
|
||||
|
||||
t.Run("non openai always off", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_mode": OpenAIWSIngressModeDedicated,
|
||||
},
|
||||
}
|
||||
require.Equal(t, OpenAIWSIngressModeOff, account.ResolveOpenAIResponsesWebSocketV2Mode(OpenAIWSIngressModeDedicated))
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccount_OpenAIWSExtraFlags(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_ws_force_http": true,
|
||||
"openai_ws_allow_store_recovery": true,
|
||||
},
|
||||
}
|
||||
require.True(t, account.IsOpenAIWSForceHTTPEnabled())
|
||||
require.True(t, account.IsOpenAIWSAllowStoreRecoveryEnabled())
|
||||
|
||||
off := &Account{Platform: PlatformOpenAI, Type: AccountTypeOAuth, Extra: map[string]any{}}
|
||||
require.False(t, off.IsOpenAIWSForceHTTPEnabled())
|
||||
require.False(t, off.IsOpenAIWSAllowStoreRecoveryEnabled())
|
||||
|
||||
var nilAccount *Account
|
||||
require.False(t, nilAccount.IsOpenAIWSAllowStoreRecoveryEnabled())
|
||||
|
||||
nonOpenAI := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_ws_allow_store_recovery": true,
|
||||
},
|
||||
}
|
||||
require.False(t, nonOpenAI.IsOpenAIWSAllowStoreRecoveryEnabled())
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetPoolModeRetryCount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
account *Account
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "default_when_not_pool_mode",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "default_when_missing_retry_count",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "supports_float64_from_json_credentials",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": float64(5),
|
||||
},
|
||||
},
|
||||
expected: 5,
|
||||
},
|
||||
{
|
||||
name: "supports_json_number",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": json.Number("4"),
|
||||
},
|
||||
},
|
||||
expected: 4,
|
||||
},
|
||||
{
|
||||
name: "supports_string_value",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": "2",
|
||||
},
|
||||
},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "negative_value_is_clamped_to_zero",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": -1,
|
||||
},
|
||||
},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "oversized_value_is_clamped_to_max",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": 99,
|
||||
},
|
||||
},
|
||||
expected: maxPoolModeRetryCount,
|
||||
},
|
||||
{
|
||||
name: "invalid_value_falls_back_to_default",
|
||||
account: &Account{
|
||||
Type: AccountTypeAPIKey,
|
||||
Platform: PlatformOpenAI,
|
||||
Credentials: map[string]any{
|
||||
"pool_mode": true,
|
||||
"pool_mode_retry_count": "oops",
|
||||
},
|
||||
},
|
||||
expected: defaultPoolModeRetryCount,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.expected, tt.account.GetPoolModeRetryCount())
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,516 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// nextFixedDailyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNextFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 06:00 UTC, reset hour = 9
|
||||
after := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_AtResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// Exactly at reset hour → should return tomorrow
|
||||
after := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// After reset hour → should return tomorrow
|
||||
after := time.Date(2026, 3, 14, 15, 30, 0, 0, tz)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_MidnightReset(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// Reset at hour 0 (midnight), currently 23:59
|
||||
after := time.Date(2026, 3, 14, 23, 59, 0, 0, tz)
|
||||
got := nextFixedDailyReset(0, tz, after)
|
||||
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedDailyReset_NonUTCTimezone(t *testing.T) {
|
||||
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2026-03-14 07:00 UTC = 2026-03-14 15:00 CST, reset hour = 9 (CST)
|
||||
after := time.Date(2026, 3, 14, 7, 0, 0, 0, time.UTC)
|
||||
got := nextFixedDailyReset(9, tz, after)
|
||||
// Already past 9:00 CST today → tomorrow 9:00 CST = 2026-03-15 01:00 UTC
|
||||
want := time.Date(2026, 3, 15, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// lastFixedDailyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLastFixedDailyReset_BeforeResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 6, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// Before today's 9:00 → yesterday 9:00
|
||||
want := time.Date(2026, 3, 13, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedDailyReset_AtResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// At exactly 9:00 → today 9:00
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedDailyReset_AfterResetHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
now := time.Date(2026, 3, 14, 15, 0, 0, 0, tz)
|
||||
got := lastFixedDailyReset(9, tz, now)
|
||||
// After 9:00 → today 9:00
|
||||
want := time.Date(2026, 3, 14, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// nextFixedWeeklyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayAhead(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 is Saturday (day=6), target = Monday (day=1), hour = 9
|
||||
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday = 2026-03-16
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_BeforeHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, before 9:00
|
||||
after := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Today at 9:00
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_AtHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, exactly at 9:00
|
||||
after := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday at 9:00
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayToday_AfterHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, after 9:00
|
||||
after := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday at 9:00
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_TargetDayPast(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||
after := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(1, 9, tz, after)
|
||||
// Next Monday = 2026-03-23
|
||||
want := time.Date(2026, 3, 23, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestNextFixedWeeklyReset_Sunday(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-14 is Saturday (day=6), target = Sunday (day=0)
|
||||
after := time.Date(2026, 3, 14, 10, 0, 0, 0, tz)
|
||||
got := nextFixedWeeklyReset(0, 0, tz, after)
|
||||
// Next Sunday = 2026-03-15
|
||||
want := time.Date(2026, 3, 15, 0, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// lastFixedWeeklyReset
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestLastFixedWeeklyReset_SameDay_AfterHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday (day=1), target = Monday, hour = 9, now = 15:00
|
||||
now := time.Date(2026, 3, 16, 15, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Today at 9:00
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedWeeklyReset_SameDay_BeforeHour(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-16 is Monday, target = Monday, hour = 9, now = 06:00
|
||||
now := time.Date(2026, 3, 16, 6, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Last Monday at 9:00 = 2026-03-09
|
||||
want := time.Date(2026, 3, 9, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
func TestLastFixedWeeklyReset_DifferentDay(t *testing.T) {
|
||||
tz := time.UTC
|
||||
// 2026-03-18 is Wednesday (day=3), target = Monday (day=1)
|
||||
now := time.Date(2026, 3, 18, 10, 0, 0, 0, tz)
|
||||
got := lastFixedWeeklyReset(1, 9, tz, now)
|
||||
// Last Monday = 2026-03-16
|
||||
want := time.Date(2026, 3, 16, 9, 0, 0, 0, tz)
|
||||
assert.Equal(t, want, got)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isFixedDailyPeriodExpired
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(time.Time{}))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_NotExpired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started after the most recent reset → not expired
|
||||
// (This test uses a time very close to "now", which is after the last reset)
|
||||
periodStart := time.Now().Add(-1 * time.Minute)
|
||||
assert.False(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_Expired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 3 days ago → definitely expired
|
||||
periodStart := time.Now().Add(-72 * time.Hour)
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedDailyPeriodExpired_InvalidTimezone(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "Invalid/Timezone",
|
||||
}}
|
||||
// Invalid timezone falls back to UTC
|
||||
periodStart := time.Now().Add(-72 * time.Hour)
|
||||
assert.True(t, a.isFixedDailyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// isFixedWeeklyPeriodExpired
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_ZeroPeriodStart(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
assert.True(t, a.isFixedWeeklyPeriodExpired(time.Time{}))
|
||||
}
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_NotExpired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 1 minute ago → not expired
|
||||
periodStart := time.Now().Add(-1 * time.Minute)
|
||||
assert.False(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
func TestIsFixedWeeklyPeriodExpired_Expired(t *testing.T) {
|
||||
a := &Account{Extra: map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}}
|
||||
// Period started 10 days ago → definitely expired
|
||||
periodStart := time.Now().Add(-240 * time.Hour)
|
||||
assert.True(t, a.isFixedWeeklyPeriodExpired(periodStart))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ValidateQuotaResetConfig
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestValidateQuotaResetConfig_NilExtra(t *testing.T) {
|
||||
assert.NoError(t, ValidateQuotaResetConfig(nil))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_EmptyExtra(t *testing.T) {
|
||||
assert.NoError(t, ValidateQuotaResetConfig(map[string]any{}))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_ValidFixed(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1),
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "Asia/Shanghai",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_ValidRolling(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidTimezone(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_reset_timezone": "Not/A/Timezone",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_reset_timezone")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyMode(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "invalid",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_mode")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyHour_TooHigh(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(24),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidDailyHour_Negative(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(-1),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_daily_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyMode(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_mode": "unknown",
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_mode")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyDay_TooHigh(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_day": float64(7),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyDay_Negative(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_day": float64(-1),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_day")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_InvalidWeeklyHour(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_hour": float64(25),
|
||||
}
|
||||
err := ValidateQuotaResetConfig(extra)
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "quota_weekly_reset_hour")
|
||||
}
|
||||
|
||||
func TestValidateQuotaResetConfig_BoundaryValues(t *testing.T) {
|
||||
// All boundary values should be valid
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_hour": float64(23),
|
||||
"quota_weekly_reset_day": float64(0), // Sunday
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra))
|
||||
|
||||
extra2 := map[string]any{
|
||||
"quota_daily_reset_hour": float64(0),
|
||||
"quota_weekly_reset_day": float64(6), // Saturday
|
||||
"quota_weekly_reset_hour": float64(23),
|
||||
}
|
||||
assert.NoError(t, ValidateQuotaResetConfig(extra2))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// ComputeQuotaResetAt
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestComputeQuotaResetAt_RollingMode_NoResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||
assert.False(t, hasDailyResetAt, "rolling mode should not set quota_daily_reset_at")
|
||||
assert.False(t, hasWeeklyResetAt, "rolling mode should not set quota_weekly_reset_at")
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_RollingMode_ClearsExistingResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "rolling",
|
||||
"quota_weekly_reset_mode": "rolling",
|
||||
"quota_daily_reset_at": "2026-03-14T09:00:00Z",
|
||||
"quota_weekly_reset_at": "2026-03-16T09:00:00Z",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
_, hasDailyResetAt := extra["quota_daily_reset_at"]
|
||||
_, hasWeeklyResetAt := extra["quota_weekly_reset_at"]
|
||||
assert.False(t, hasDailyResetAt, "rolling mode should remove quota_daily_reset_at")
|
||||
assert.False(t, hasWeeklyResetAt, "rolling mode should remove quota_weekly_reset_at")
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedDaily_SetsResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok, "quota_daily_reset_at should be set")
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Reset time should be in the future
|
||||
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||
// Reset hour should be 9 UTC
|
||||
assert.Equal(t, 9, resetAt.UTC().Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedWeekly_SetsResetAt(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_weekly_reset_mode": "fixed",
|
||||
"quota_weekly_reset_day": float64(1), // Monday
|
||||
"quota_weekly_reset_hour": float64(0),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_weekly_reset_at"].(string)
|
||||
require.True(t, ok, "quota_weekly_reset_at should be set")
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Reset time should be in the future
|
||||
assert.True(t, resetAt.After(time.Now()), "reset_at should be in the future")
|
||||
// Reset day should be Monday
|
||||
assert.Equal(t, time.Monday, resetAt.UTC().Weekday())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_FixedDaily_WithTimezone(t *testing.T) {
|
||||
tz, err := time.LoadLocation("Asia/Shanghai")
|
||||
require.NoError(t, err)
|
||||
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(9),
|
||||
"quota_reset_timezone": "Asia/Shanghai",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// In Shanghai timezone, the hour should be 9
|
||||
assert.Equal(t, 9, resetAt.In(tz).Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_DefaultTimezone(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(12),
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Default timezone is UTC
|
||||
assert.Equal(t, 12, resetAt.UTC().Hour())
|
||||
}
|
||||
|
||||
func TestComputeQuotaResetAt_InvalidHour_ClampedToZero(t *testing.T) {
|
||||
extra := map[string]any{
|
||||
"quota_daily_reset_mode": "fixed",
|
||||
"quota_daily_reset_hour": float64(99),
|
||||
"quota_reset_timezone": "UTC",
|
||||
}
|
||||
ComputeQuotaResetAt(extra)
|
||||
resetAtStr, ok := extra["quota_daily_reset_at"].(string)
|
||||
require.True(t, ok)
|
||||
|
||||
resetAt, err := time.Parse(time.RFC3339, resetAtStr)
|
||||
require.NoError(t, err)
|
||||
// Invalid hour → clamped to 0
|
||||
assert.Equal(t, 0, resetAt.UTC().Hour())
|
||||
}
|
||||
@@ -0,0 +1,120 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetBaseRPM(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
extra map[string]any
|
||||
expected int
|
||||
}{
|
||||
{"nil extra", nil, 0},
|
||||
{"no key", map[string]any{}, 0},
|
||||
{"zero", map[string]any{"base_rpm": 0}, 0},
|
||||
{"int value", map[string]any{"base_rpm": 15}, 15},
|
||||
{"float value", map[string]any{"base_rpm": 15.0}, 15},
|
||||
{"string value", map[string]any{"base_rpm": "15"}, 15},
|
||||
{"negative value", map[string]any{"base_rpm": -5}, 0},
|
||||
{"int64 value", map[string]any{"base_rpm": int64(20)}, 20},
|
||||
{"json.Number value", map[string]any{"base_rpm": json.Number("25")}, 25},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &Account{Extra: tt.extra}
|
||||
if got := a.GetBaseRPM(); got != tt.expected {
|
||||
t.Errorf("GetBaseRPM() = %d, want %d", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRPMStrategy(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
extra map[string]any
|
||||
expected string
|
||||
}{
|
||||
{"nil extra", nil, "tiered"},
|
||||
{"no key", map[string]any{}, "tiered"},
|
||||
{"tiered", map[string]any{"rpm_strategy": "tiered"}, "tiered"},
|
||||
{"sticky_exempt", map[string]any{"rpm_strategy": "sticky_exempt"}, "sticky_exempt"},
|
||||
{"invalid", map[string]any{"rpm_strategy": "foobar"}, "tiered"},
|
||||
{"empty string fallback", map[string]any{"rpm_strategy": ""}, "tiered"},
|
||||
{"numeric value fallback", map[string]any{"rpm_strategy": 123}, "tiered"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &Account{Extra: tt.extra}
|
||||
if got := a.GetRPMStrategy(); got != tt.expected {
|
||||
t.Errorf("GetRPMStrategy() = %q, want %q", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckRPMSchedulability(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
extra map[string]any
|
||||
currentRPM int
|
||||
expected WindowCostSchedulability
|
||||
}{
|
||||
{"disabled", map[string]any{}, 100, WindowCostSchedulable},
|
||||
{"green zone", map[string]any{"base_rpm": 15}, 10, WindowCostSchedulable},
|
||||
{"yellow zone tiered", map[string]any{"base_rpm": 15}, 15, WindowCostStickyOnly},
|
||||
{"red zone tiered", map[string]any{"base_rpm": 15}, 18, WindowCostNotSchedulable},
|
||||
{"sticky_exempt at limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 15, WindowCostStickyOnly},
|
||||
{"sticky_exempt over limit", map[string]any{"base_rpm": 15, "rpm_strategy": "sticky_exempt"}, 100, WindowCostStickyOnly},
|
||||
{"custom buffer", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 14, WindowCostStickyOnly},
|
||||
{"custom buffer red", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 15, WindowCostNotSchedulable},
|
||||
{"base_rpm=1 green", map[string]any{"base_rpm": 1}, 0, WindowCostSchedulable},
|
||||
{"base_rpm=1 yellow (at limit)", map[string]any{"base_rpm": 1}, 1, WindowCostStickyOnly},
|
||||
{"base_rpm=1 red (at limit+buffer)", map[string]any{"base_rpm": 1}, 2, WindowCostNotSchedulable},
|
||||
{"negative currentRPM", map[string]any{"base_rpm": 15}, -1, WindowCostSchedulable},
|
||||
{"base_rpm negative disabled", map[string]any{"base_rpm": -5}, 10, WindowCostSchedulable},
|
||||
{"very high currentRPM", map[string]any{"base_rpm": 10}, 9999, WindowCostNotSchedulable},
|
||||
{"sticky_exempt very high currentRPM", map[string]any{"base_rpm": 10, "rpm_strategy": "sticky_exempt"}, 9999, WindowCostStickyOnly},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &Account{Extra: tt.extra}
|
||||
if got := a.CheckRPMSchedulability(tt.currentRPM); got != tt.expected {
|
||||
t.Errorf("CheckRPMSchedulability(%d) = %d, want %d", tt.currentRPM, got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRPMStickyBuffer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
extra map[string]any
|
||||
expected int
|
||||
}{
|
||||
{"nil extra", nil, 0},
|
||||
{"no keys", map[string]any{}, 0},
|
||||
{"base_rpm=0", map[string]any{"base_rpm": 0}, 0},
|
||||
{"base_rpm=1 min buffer 1", map[string]any{"base_rpm": 1}, 1},
|
||||
{"base_rpm=4 min buffer 1", map[string]any{"base_rpm": 4}, 1},
|
||||
{"base_rpm=5 buffer 1", map[string]any{"base_rpm": 5}, 1},
|
||||
{"base_rpm=10 buffer 2", map[string]any{"base_rpm": 10}, 2},
|
||||
{"base_rpm=15 buffer 3", map[string]any{"base_rpm": 15}, 3},
|
||||
{"base_rpm=100 buffer 20", map[string]any{"base_rpm": 100}, 20},
|
||||
{"custom buffer=5", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 5}, 5},
|
||||
{"custom buffer=0 fallback to default", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": 0}, 2},
|
||||
{"custom buffer negative fallback", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": -1}, 2},
|
||||
{"custom buffer with float", map[string]any{"base_rpm": 10, "rpm_sticky_buffer": float64(7)}, 7},
|
||||
{"json.Number base_rpm", map[string]any{"base_rpm": json.Number("10")}, 2},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
a := &Account{Extra: tt.extra}
|
||||
if got := a.GetRPMStickyBuffer(); got != tt.expected {
|
||||
t.Errorf("GetRPMStickyBuffer() = %d, want %d", got, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,397 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAccountNotFound = infraerrors.NotFound("ACCOUNT_NOT_FOUND", "account not found")
|
||||
ErrAccountNilInput = infraerrors.BadRequest("ACCOUNT_NIL_INPUT", "account input cannot be nil")
|
||||
)
|
||||
|
||||
type AccountRepository interface {
|
||||
Create(ctx context.Context, account *Account) error
|
||||
GetByID(ctx context.Context, id int64) (*Account, error)
|
||||
// GetByIDs fetches accounts by IDs in a single query.
|
||||
// It should return all accounts found (missing IDs are ignored).
|
||||
GetByIDs(ctx context.Context, ids []int64) ([]*Account, error)
|
||||
// ExistsByID 检查账号是否存在,仅返回布尔值,用于删除前的轻量级存在性检查
|
||||
ExistsByID(ctx context.Context, id int64) (bool, error)
|
||||
// GetByCRSAccountID finds an account previously synced from CRS.
|
||||
// Returns (nil, nil) if not found.
|
||||
GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error)
|
||||
// FindByExtraField 根据 extra 字段中的键值对查找账号(限定 platform='sora')
|
||||
// 用于查找通过 linked_openai_account_id 关联的 Sora 账号
|
||||
FindByExtraField(ctx context.Context, key string, value any) ([]Account, error)
|
||||
// ListCRSAccountIDs returns a map of crs_account_id -> local account ID
|
||||
// for all accounts that have been synced from CRS.
|
||||
ListCRSAccountIDs(ctx context.Context) (map[string]int64, error)
|
||||
Update(ctx context.Context, account *Account) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error)
|
||||
ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error)
|
||||
ListByGroup(ctx context.Context, groupID int64) ([]Account, error)
|
||||
ListActive(ctx context.Context) ([]Account, error)
|
||||
ListByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
|
||||
UpdateLastUsed(ctx context.Context, id int64) error
|
||||
BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error
|
||||
SetError(ctx context.Context, id int64, errorMsg string) error
|
||||
ClearError(ctx context.Context, id int64) error
|
||||
SetSchedulable(ctx context.Context, id int64, schedulable bool) error
|
||||
AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error)
|
||||
BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error
|
||||
|
||||
ListSchedulable(ctx context.Context) ([]Account, error)
|
||||
ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error)
|
||||
ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error)
|
||||
ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
|
||||
ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error)
|
||||
ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error)
|
||||
ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error)
|
||||
|
||||
SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error
|
||||
SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error
|
||||
SetOverloaded(ctx context.Context, id int64, until time.Time) error
|
||||
SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error
|
||||
ClearTempUnschedulable(ctx context.Context, id int64) error
|
||||
ClearRateLimit(ctx context.Context, id int64) error
|
||||
ClearAntigravityQuotaScopes(ctx context.Context, id int64) error
|
||||
ClearModelRateLimits(ctx context.Context, id int64) error
|
||||
UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error
|
||||
UpdateExtra(ctx context.Context, id int64, updates map[string]any) error
|
||||
BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error)
|
||||
// IncrementQuotaUsed 原子递增 API Key 账号的配额用量(总/日/周)
|
||||
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error
|
||||
// ResetQuotaUsed 重置 API Key 账号所有维度的配额用量为 0
|
||||
ResetQuotaUsed(ctx context.Context, id int64) error
|
||||
}
|
||||
|
||||
// AccountBulkUpdate describes the fields that can be updated in a bulk operation.
|
||||
// Nil pointers mean "do not change".
|
||||
type AccountBulkUpdate struct {
|
||||
Name *string
|
||||
ProxyID *int64
|
||||
Concurrency *int
|
||||
Priority *int
|
||||
RateMultiplier *float64
|
||||
LoadFactor *int
|
||||
Status *string
|
||||
Schedulable *bool
|
||||
Credentials map[string]any
|
||||
Extra map[string]any
|
||||
}
|
||||
|
||||
// CreateAccountRequest 创建账号请求
|
||||
type CreateAccountRequest struct {
|
||||
Name string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Platform string `json:"platform"`
|
||||
Type string `json:"type"`
|
||||
Credentials map[string]any `json:"credentials"`
|
||||
Extra map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
Priority int `json:"priority"`
|
||||
GroupIDs []int64 `json:"group_ids"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
}
|
||||
|
||||
// UpdateAccountRequest 更新账号请求
|
||||
type UpdateAccountRequest struct {
|
||||
Name *string `json:"name"`
|
||||
Notes *string `json:"notes"`
|
||||
Credentials *map[string]any `json:"credentials"`
|
||||
Extra *map[string]any `json:"extra"`
|
||||
ProxyID *int64 `json:"proxy_id"`
|
||||
Concurrency *int `json:"concurrency"`
|
||||
Priority *int `json:"priority"`
|
||||
Status *string `json:"status"`
|
||||
GroupIDs *[]int64 `json:"group_ids"`
|
||||
ExpiresAt *time.Time `json:"expires_at"`
|
||||
AutoPauseOnExpired *bool `json:"auto_pause_on_expired"`
|
||||
}
|
||||
|
||||
// AccountService 账号管理服务
|
||||
type AccountService struct {
|
||||
accountRepo AccountRepository
|
||||
groupRepo GroupRepository
|
||||
}
|
||||
|
||||
type groupExistenceBatchChecker interface {
|
||||
ExistsByIDs(ctx context.Context, ids []int64) (map[int64]bool, error)
|
||||
}
|
||||
|
||||
// NewAccountService 创建账号服务实例
|
||||
func NewAccountService(accountRepo AccountRepository, groupRepo GroupRepository) *AccountService {
|
||||
return &AccountService{
|
||||
accountRepo: accountRepo,
|
||||
groupRepo: groupRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// Create 创建账号
|
||||
func (s *AccountService) Create(ctx context.Context, req CreateAccountRequest) (*Account, error) {
|
||||
// 验证分组是否存在(如果指定了分组)
|
||||
if len(req.GroupIDs) > 0 {
|
||||
if err := s.validateGroupIDsExist(ctx, req.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 创建账号
|
||||
account := &Account{
|
||||
Name: req.Name,
|
||||
Notes: normalizeAccountNotes(req.Notes),
|
||||
Platform: req.Platform,
|
||||
Type: req.Type,
|
||||
Credentials: req.Credentials,
|
||||
Extra: req.Extra,
|
||||
ProxyID: req.ProxyID,
|
||||
Concurrency: req.Concurrency,
|
||||
Priority: req.Priority,
|
||||
Status: StatusActive,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
}
|
||||
if req.AutoPauseOnExpired != nil {
|
||||
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
|
||||
} else {
|
||||
account.AutoPauseOnExpired = true
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Create(ctx, account); err != nil {
|
||||
return nil, fmt.Errorf("create account: %w", err)
|
||||
}
|
||||
|
||||
// 绑定分组
|
||||
if len(req.GroupIDs) > 0 {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, req.GroupIDs); err != nil {
|
||||
return nil, fmt.Errorf("bind groups: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取账号
|
||||
func (s *AccountService) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// List 获取账号列表
|
||||
func (s *AccountService) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
accounts, pagination, err := s.accountRepo.List(ctx, params)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list accounts: %w", err)
|
||||
}
|
||||
return accounts, pagination, nil
|
||||
}
|
||||
|
||||
// ListByPlatform 根据平台获取账号列表
|
||||
func (s *AccountService) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
accounts, err := s.accountRepo.ListByPlatform(ctx, platform)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list accounts by platform: %w", err)
|
||||
}
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
// ListByGroup 根据分组获取账号列表
|
||||
func (s *AccountService) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
accounts, err := s.accountRepo.ListByGroup(ctx, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list accounts by group: %w", err)
|
||||
}
|
||||
return accounts, nil
|
||||
}
|
||||
|
||||
// Update 更新账号
|
||||
func (s *AccountService) Update(ctx context.Context, id int64, req UpdateAccountRequest) (*Account, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != nil {
|
||||
account.Name = *req.Name
|
||||
}
|
||||
if req.Notes != nil {
|
||||
account.Notes = normalizeAccountNotes(req.Notes)
|
||||
}
|
||||
|
||||
if req.Credentials != nil {
|
||||
account.Credentials = *req.Credentials
|
||||
}
|
||||
|
||||
if req.Extra != nil {
|
||||
account.Extra = *req.Extra
|
||||
}
|
||||
|
||||
if req.ProxyID != nil {
|
||||
account.ProxyID = req.ProxyID
|
||||
}
|
||||
|
||||
if req.Concurrency != nil {
|
||||
account.Concurrency = *req.Concurrency
|
||||
}
|
||||
|
||||
if req.Priority != nil {
|
||||
account.Priority = *req.Priority
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
account.Status = *req.Status
|
||||
}
|
||||
if req.ExpiresAt != nil {
|
||||
account.ExpiresAt = req.ExpiresAt
|
||||
}
|
||||
if req.AutoPauseOnExpired != nil {
|
||||
account.AutoPauseOnExpired = *req.AutoPauseOnExpired
|
||||
}
|
||||
|
||||
// 先验证分组是否存在(在任何写操作之前)
|
||||
if req.GroupIDs != nil {
|
||||
if err := s.validateGroupIDsExist(ctx, *req.GroupIDs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 执行更新
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return nil, fmt.Errorf("update account: %w", err)
|
||||
}
|
||||
|
||||
// 绑定分组
|
||||
if req.GroupIDs != nil {
|
||||
if err := s.accountRepo.BindGroups(ctx, account.ID, *req.GroupIDs); err != nil {
|
||||
return nil, fmt.Errorf("bind groups: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return account, nil
|
||||
}
|
||||
|
||||
// Delete 删除账号
|
||||
// 优化:使用 ExistsByID 替代 GetByID 进行存在性检查,
|
||||
// 避免加载完整账号对象及其关联数据,提升删除操作的性能
|
||||
func (s *AccountService) Delete(ctx context.Context, id int64) error {
|
||||
// 使用轻量级的存在性检查,而非加载完整账号对象
|
||||
exists, err := s.accountRepo.ExistsByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check account: %w", err)
|
||||
}
|
||||
// 明确返回账号不存在错误,便于调用方区分错误类型
|
||||
if !exists {
|
||||
return ErrAccountNotFound
|
||||
}
|
||||
|
||||
if err := s.accountRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete account: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccountService) validateGroupIDsExist(ctx context.Context, groupIDs []int64) error {
|
||||
if len(groupIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
if s.groupRepo == nil {
|
||||
return fmt.Errorf("group repository not configured")
|
||||
}
|
||||
|
||||
if batchChecker, ok := s.groupRepo.(groupExistenceBatchChecker); ok {
|
||||
existsByID, err := batchChecker.ExistsByIDs(ctx, groupIDs)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check groups exists: %w", err)
|
||||
}
|
||||
for _, groupID := range groupIDs {
|
||||
if groupID <= 0 {
|
||||
return fmt.Errorf("get group: %w", ErrGroupNotFound)
|
||||
}
|
||||
if !existsByID[groupID] {
|
||||
return fmt.Errorf("get group: %w", ErrGroupNotFound)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for _, groupID := range groupIDs {
|
||||
_, err := s.groupRepo.GetByID(ctx, groupID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateStatus 更新账号状态
|
||||
func (s *AccountService) UpdateStatus(ctx context.Context, id int64, status string, errorMessage string) error {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
account.Status = status
|
||||
account.ErrorMessage = errorMessage
|
||||
|
||||
if err := s.accountRepo.Update(ctx, account); err != nil {
|
||||
return fmt.Errorf("update account: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateLastUsed 更新最后使用时间
|
||||
func (s *AccountService) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
if err := s.accountRepo.UpdateLastUsed(ctx, id); err != nil {
|
||||
return fmt.Errorf("update last used: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetCredential 获取账号凭证(安全访问)
|
||||
func (s *AccountService) GetCredential(ctx context.Context, id int64, key string) (string, error) {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
return account.GetCredential(key), nil
|
||||
}
|
||||
|
||||
// TestCredentials 测试账号凭证是否有效(需要实现具体平台的测试逻辑)
|
||||
func (s *AccountService) TestCredentials(ctx context.Context, id int64) error {
|
||||
account, err := s.accountRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get account: %w", err)
|
||||
}
|
||||
|
||||
// 根据平台执行不同的测试逻辑
|
||||
switch account.Platform {
|
||||
case PlatformAnthropic:
|
||||
// TODO: 测试Anthropic API凭证
|
||||
return nil
|
||||
case PlatformOpenAI:
|
||||
// TODO: 测试OpenAI API凭证
|
||||
return nil
|
||||
case PlatformGemini:
|
||||
// TODO: 测试Gemini API凭证
|
||||
return nil
|
||||
default:
|
||||
return fmt.Errorf("unsupported platform: %s", account.Platform)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,271 @@
|
||||
//go:build unit
|
||||
|
||||
// 账号服务删除方法的单元测试
|
||||
// 测试 AccountService.Delete 方法在各种场景下的行为
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// accountRepoStub 是 AccountRepository 接口的测试桩实现。
|
||||
// 用于隔离测试 AccountService.Delete 方法,避免依赖真实数据库。
|
||||
//
|
||||
// 设计说明:
|
||||
// - exists: 模拟 ExistsByID 返回的存在性结果
|
||||
// - existsErr: 模拟 ExistsByID 返回的错误
|
||||
// - deleteErr: 模拟 Delete 返回的错误
|
||||
// - deletedIDs: 记录被调用删除的账号 ID,用于断言验证
|
||||
type accountRepoStub struct {
|
||||
exists bool // ExistsByID 的返回值
|
||||
existsErr error // ExistsByID 的错误返回值
|
||||
deleteErr error // Delete 的错误返回值
|
||||
deletedIDs []int64 // 记录已删除的账号 ID 列表
|
||||
}
|
||||
|
||||
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
|
||||
|
||||
func (s *accountRepoStub) Create(ctx context.Context, account *Account) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) GetByIDs(ctx context.Context, ids []int64) ([]*Account, error) {
|
||||
panic("unexpected GetByIDs call")
|
||||
}
|
||||
|
||||
// ExistsByID 返回预设的存在性检查结果。
|
||||
// 这是 Delete 方法调用的第一个仓储方法,用于验证账号是否存在。
|
||||
func (s *accountRepoStub) ExistsByID(ctx context.Context, id int64) (bool, error) {
|
||||
return s.exists, s.existsErr
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) GetByCRSAccountID(ctx context.Context, crsAccountID string) (*Account, error) {
|
||||
panic("unexpected GetByCRSAccountID call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) FindByExtraField(ctx context.Context, key string, value any) ([]Account, error) {
|
||||
panic("unexpected FindByExtraField call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListCRSAccountIDs(ctx context.Context) (map[string]int64, error) {
|
||||
panic("unexpected ListCRSAccountIDs call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
// Delete 记录被删除的账号 ID 并返回预设的错误。
|
||||
// 通过 deletedIDs 可以验证删除操作是否被正确调用。
|
||||
func (s *accountRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
s.deletedIDs = append(s.deletedIDs, id)
|
||||
return s.deleteErr
|
||||
}
|
||||
|
||||
// 以下是接口要求实现但本测试不关心的方法
|
||||
|
||||
func (s *accountRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]Account, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListByGroup(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
panic("unexpected ListByGroup call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListActive(ctx context.Context) ([]Account, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
panic("unexpected ListByPlatform call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) UpdateLastUsed(ctx context.Context, id int64) error {
|
||||
panic("unexpected UpdateLastUsed call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) BatchUpdateLastUsed(ctx context.Context, updates map[int64]time.Time) error {
|
||||
panic("unexpected BatchUpdateLastUsed call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetError(ctx context.Context, id int64, errorMsg string) error {
|
||||
panic("unexpected SetError call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearError(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearError call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetSchedulable(ctx context.Context, id int64, schedulable bool) error {
|
||||
panic("unexpected SetSchedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) AutoPauseExpiredAccounts(ctx context.Context, now time.Time) (int64, error) {
|
||||
panic("unexpected AutoPauseExpiredAccounts call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) BindGroups(ctx context.Context, accountID int64, groupIDs []int64) error {
|
||||
panic("unexpected BindGroups call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulable(ctx context.Context) ([]Account, error) {
|
||||
panic("unexpected ListSchedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableByGroupID(ctx context.Context, groupID int64) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableByGroupID call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableByPlatform call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatform(ctx context.Context, groupID int64, platform string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableByGroupIDAndPlatform call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableByPlatforms call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableByGroupIDAndPlatforms(ctx context.Context, groupID int64, platforms []string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableByGroupIDAndPlatforms call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableUngroupedByPlatform(ctx context.Context, platform string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableUngroupedByPlatform call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ListSchedulableUngroupedByPlatforms(ctx context.Context, platforms []string) ([]Account, error) {
|
||||
panic("unexpected ListSchedulableUngroupedByPlatforms call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetRateLimited(ctx context.Context, id int64, resetAt time.Time) error {
|
||||
panic("unexpected SetRateLimited call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetModelRateLimit(ctx context.Context, id int64, scope string, resetAt time.Time) error {
|
||||
panic("unexpected SetModelRateLimit call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetOverloaded(ctx context.Context, id int64, until time.Time) error {
|
||||
panic("unexpected SetOverloaded call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) SetTempUnschedulable(ctx context.Context, id int64, until time.Time, reason string) error {
|
||||
panic("unexpected SetTempUnschedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearTempUnschedulable(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearTempUnschedulable call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearRateLimit(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearRateLimit call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearAntigravityQuotaScopes(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearAntigravityQuotaScopes call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ClearModelRateLimits(ctx context.Context, id int64) error {
|
||||
panic("unexpected ClearModelRateLimits call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) UpdateSessionWindow(ctx context.Context, id int64, start, end *time.Time, status string) error {
|
||||
panic("unexpected UpdateSessionWindow call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) UpdateExtra(ctx context.Context, id int64, updates map[string]any) error {
|
||||
panic("unexpected UpdateExtra call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) BulkUpdate(ctx context.Context, ids []int64, updates AccountBulkUpdate) (int64, error) {
|
||||
panic("unexpected BulkUpdate call")
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStub) ResetQuotaUsed(ctx context.Context, id int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestAccountService_Delete_NotFound 测试删除不存在的账号时返回正确的错误。
|
||||
// 预期行为:
|
||||
// - ExistsByID 返回 false(账号不存在)
|
||||
// - 返回 ErrAccountNotFound 错误
|
||||
// - Delete 方法不被调用(deletedIDs 为空)
|
||||
func TestAccountService_Delete_NotFound(t *testing.T) {
|
||||
repo := &accountRepoStub{exists: false}
|
||||
svc := &AccountService{accountRepo: repo}
|
||||
|
||||
err := svc.Delete(context.Background(), 55)
|
||||
require.ErrorIs(t, err, ErrAccountNotFound)
|
||||
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
|
||||
}
|
||||
|
||||
// TestAccountService_Delete_CheckError 测试存在性检查失败时的错误处理。
|
||||
// 预期行为:
|
||||
// - ExistsByID 返回数据库错误
|
||||
// - 返回包含 "check account" 的错误信息
|
||||
// - Delete 方法不被调用
|
||||
func TestAccountService_Delete_CheckError(t *testing.T) {
|
||||
repo := &accountRepoStub{existsErr: errors.New("db down")}
|
||||
svc := &AccountService{accountRepo: repo}
|
||||
|
||||
err := svc.Delete(context.Background(), 55)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "check account") // 验证错误信息包含上下文
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
}
|
||||
|
||||
// TestAccountService_Delete_DeleteError 测试删除操作失败时的错误处理。
|
||||
// 预期行为:
|
||||
// - ExistsByID 返回 true(账号存在)
|
||||
// - Delete 被调用但返回错误
|
||||
// - 返回包含 "delete account" 的错误信息
|
||||
// - deletedIDs 记录了尝试删除的 ID
|
||||
func TestAccountService_Delete_DeleteError(t *testing.T) {
|
||||
repo := &accountRepoStub{
|
||||
exists: true,
|
||||
deleteErr: errors.New("delete failed"),
|
||||
}
|
||||
svc := &AccountService{accountRepo: repo}
|
||||
|
||||
err := svc.Delete(context.Background(), 55)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "delete account")
|
||||
require.Equal(t, []int64{55}, repo.deletedIDs) // 验证删除操作被调用
|
||||
}
|
||||
|
||||
// TestAccountService_Delete_Success 测试删除操作成功的场景。
|
||||
// 预期行为:
|
||||
// - ExistsByID 返回 true(账号存在)
|
||||
// - Delete 成功执行
|
||||
// - 返回 nil 错误
|
||||
// - deletedIDs 记录了被删除的 ID
|
||||
func TestAccountService_Delete_Success(t *testing.T) {
|
||||
repo := &accountRepoStub{exists: true}
|
||||
svc := &AccountService{accountRepo: repo}
|
||||
|
||||
err := svc.Delete(context.Background(), 55)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{55}, repo.deletedIDs) // 验证正确的 ID 被删除
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,59 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateGeminiTestPayload_ImageModel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
payload := createGeminiTestPayload("gemini-2.5-flash-image", "draw a tiny robot")
|
||||
|
||||
var parsed struct {
|
||||
Contents []struct {
|
||||
Parts []struct {
|
||||
Text string `json:"text"`
|
||||
} `json:"parts"`
|
||||
} `json:"contents"`
|
||||
GenerationConfig struct {
|
||||
ResponseModalities []string `json:"responseModalities"`
|
||||
ImageConfig struct {
|
||||
AspectRatio string `json:"aspectRatio"`
|
||||
} `json:"imageConfig"`
|
||||
} `json:"generationConfig"`
|
||||
}
|
||||
|
||||
require.NoError(t, json.Unmarshal(payload, &parsed))
|
||||
require.Len(t, parsed.Contents, 1)
|
||||
require.Len(t, parsed.Contents[0].Parts, 1)
|
||||
require.Equal(t, "draw a tiny robot", parsed.Contents[0].Parts[0].Text)
|
||||
require.Equal(t, []string{"TEXT", "IMAGE"}, parsed.GenerationConfig.ResponseModalities)
|
||||
require.Equal(t, "1:1", parsed.GenerationConfig.ImageConfig.AspectRatio)
|
||||
}
|
||||
|
||||
func TestProcessGeminiStream_EmitsImageEvent(t *testing.T) {
|
||||
t.Parallel()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
ctx, recorder := newSoraTestContext()
|
||||
svc := &AccountTestService{}
|
||||
|
||||
stream := strings.NewReader("data: {\"candidates\":[{\"content\":{\"parts\":[{\"text\":\"ok\"},{\"inlineData\":{\"mimeType\":\"image/png\",\"data\":\"QUJD\"}}]}}]}\n\ndata: [DONE]\n\n")
|
||||
|
||||
err := svc.processGeminiStream(ctx, stream)
|
||||
require.NoError(t, err)
|
||||
|
||||
body := recorder.Body.String()
|
||||
require.Contains(t, body, "\"type\":\"content\"")
|
||||
require.Contains(t, body, "\"text\":\"ok\"")
|
||||
require.Contains(t, body, "\"type\":\"image\"")
|
||||
require.Contains(t, body, "\"image_url\":\"data:image/png;base64,QUJD\"")
|
||||
require.Contains(t, body, "\"mime_type\":\"image/png\"")
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type openAIAccountTestRepo struct {
|
||||
mockAccountRepoForGemini
|
||||
updatedExtra map[string]any
|
||||
rateLimitedID int64
|
||||
rateLimitedAt *time.Time
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
r.updatedExtra = updates
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *openAIAccountTestRepo) SetRateLimited(_ context.Context, id int64, resetAt time.Time) error {
|
||||
r.rateLimitedID = id
|
||||
r.rateLimitedAt = &resetAt
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAISuccessPersistsSnapshotFromHeaders(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, recorder := newSoraTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusOK, "")
|
||||
resp.Body = io.NopCloser(strings.NewReader(`data: {"type":"response.completed"}
|
||||
|
||||
`))
|
||||
resp.Header.Set("x-codex-primary-used-percent", "88")
|
||||
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
resp.Header.Set("x-codex-primary-window-minutes", "10080")
|
||||
resp.Header.Set("x-codex-secondary-used-percent", "42")
|
||||
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
resp.Header.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 89,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 42.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
require.Equal(t, 88.0, repo.updatedExtra["codex_7d_used_percent"])
|
||||
require.Contains(t, recorder.Body.String(), "test_complete")
|
||||
}
|
||||
|
||||
func TestAccountTestService_OpenAI429PersistsSnapshotAndRateLimit(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
ctx, _ := newSoraTestContext()
|
||||
|
||||
resp := newJSONResponse(http.StatusTooManyRequests, `{"error":{"type":"usage_limit_reached","message":"limit reached"}}`)
|
||||
resp.Header.Set("x-codex-primary-used-percent", "100")
|
||||
resp.Header.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
resp.Header.Set("x-codex-primary-window-minutes", "10080")
|
||||
resp.Header.Set("x-codex-secondary-used-percent", "100")
|
||||
resp.Header.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
resp.Header.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
repo := &openAIAccountTestRepo{}
|
||||
upstream := &queuedHTTPUpstream{responses: []*http.Response{resp}}
|
||||
svc := &AccountTestService{accountRepo: repo, httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 88,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{"access_token": "test-token"},
|
||||
}
|
||||
|
||||
err := svc.testOpenAIAccountConnection(ctx, account, "gpt-5.4")
|
||||
require.Error(t, err)
|
||||
require.NotEmpty(t, repo.updatedExtra)
|
||||
require.Equal(t, 100.0, repo.updatedExtra["codex_5h_used_percent"])
|
||||
require.Equal(t, int64(88), repo.rateLimitedID)
|
||||
require.NotNil(t, repo.rateLimitedAt)
|
||||
require.NotNil(t, account.RateLimitResetAt)
|
||||
if account.RateLimitResetAt != nil && repo.rateLimitedAt != nil {
|
||||
require.WithinDuration(t, *repo.rateLimitedAt, *account.RateLimitResetAt, time.Second)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,319 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type queuedHTTPUpstream struct {
|
||||
responses []*http.Response
|
||||
requests []*http.Request
|
||||
tlsFlags []bool
|
||||
}
|
||||
|
||||
func (u *queuedHTTPUpstream) Do(_ *http.Request, _ string, _ int64, _ int) (*http.Response, error) {
|
||||
return nil, fmt.Errorf("unexpected Do call")
|
||||
}
|
||||
|
||||
func (u *queuedHTTPUpstream) DoWithTLS(req *http.Request, _ string, _ int64, _ int, enableTLSFingerprint bool) (*http.Response, error) {
|
||||
u.requests = append(u.requests, req)
|
||||
u.tlsFlags = append(u.tlsFlags, enableTLSFingerprint)
|
||||
if len(u.responses) == 0 {
|
||||
return nil, fmt.Errorf("no mocked response")
|
||||
}
|
||||
resp := u.responses[0]
|
||||
u.responses = u.responses[1:]
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func newJSONResponse(status int, body string) *http.Response {
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Header: make(http.Header),
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
}
|
||||
}
|
||||
|
||||
func newJSONResponseWithHeader(status int, body, key, value string) *http.Response {
|
||||
resp := newJSONResponse(status, body)
|
||||
resp.Header.Set(key, value)
|
||||
return resp
|
||||
}
|
||||
|
||||
func newSoraTestContext() (*gin.Context, *httptest.ResponseRecorder) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
rec := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(rec)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/api/v1/admin/accounts/1/test", nil)
|
||||
return c, rec
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_WithSubscription(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
|
||||
newJSONResponse(http.StatusOK, `{"data":[{"plan":{"id":"chatgpt_plus","title":"ChatGPT Plus"},"end_ts":"2026-12-31T00:00:00Z"}]}`),
|
||||
newJSONResponse(http.StatusOK, `{"invite_code":"inv_abc","redeemed_count":3,"total_count":50}`),
|
||||
newJSONResponse(http.StatusOK, `{"rate_limit_and_credit_balance":{"estimated_num_videos_remaining":27,"rate_limit_reached":false,"access_resets_in_seconds":46833}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
cfg: &config.Config{
|
||||
Gateway: config.GatewayConfig{
|
||||
TLSFingerprint: config.TLSFingerprintConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
Sora: config.SoraConfig{
|
||||
Client: config.SoraClientConfig{
|
||||
DisableTLSFingerprint: false,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, upstream.requests, 4)
|
||||
require.Equal(t, soraMeAPIURL, upstream.requests[0].URL.String())
|
||||
require.Equal(t, soraBillingAPIURL, upstream.requests[1].URL.String())
|
||||
require.Equal(t, soraInviteMineURL, upstream.requests[2].URL.String())
|
||||
require.Equal(t, soraRemainingURL, upstream.requests[3].URL.String())
|
||||
require.Equal(t, "Bearer test_token", upstream.requests[0].Header.Get("Authorization"))
|
||||
require.Equal(t, "Bearer test_token", upstream.requests[1].Header.Get("Authorization"))
|
||||
require.Equal(t, []bool{true, true, true, true}, upstream.tlsFlags)
|
||||
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"type":"test_start"`)
|
||||
require.Contains(t, body, "Sora connection OK - Email: demo@example.com")
|
||||
require.Contains(t, body, "Subscription: ChatGPT Plus | chatgpt_plus | end=2026-12-31T00:00:00Z")
|
||||
require.Contains(t, body, "Sora2: supported | invite=inv_abc | used=3/50")
|
||||
require.Contains(t, body, "Sora2 remaining: 27 | reset_in=46833s")
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"status":"success"`)
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_SubscriptionFailedStillSuccess(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
|
||||
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
|
||||
newJSONResponse(http.StatusUnauthorized, `{"error":{"message":"Unauthorized"}}`),
|
||||
newJSONResponse(http.StatusForbidden, `{"error":{"message":"forbidden"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Len(t, upstream.requests, 4)
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Sora connection OK - User: demo-user")
|
||||
require.Contains(t, body, "Subscription check returned 403")
|
||||
require.Contains(t, body, "Sora2 invite check returned 401")
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"status":"partial_success"`)
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponseWithHeader(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`, "cf-ray", "9cff2d62d83bb98d"),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Cloudflare challenge")
|
||||
require.Contains(t, err.Error(), "cf-ray: 9cff2d62d83bb98d")
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"type":"error"`)
|
||||
require.Contains(t, body, "Cloudflare challenge")
|
||||
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_CloudflareChallenge429WithHeader(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponseWithHeader(http.StatusTooManyRequests, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body></body></html>`, "cf-mitigated", "challenge"),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "Cloudflare challenge")
|
||||
require.Contains(t, err.Error(), "HTTP 429")
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Cloudflare challenge")
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_TokenInvalidated(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusUnauthorized, `{"error":{"code":"token_invalidated","message":"Token invalid"}}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "token_invalidated")
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"status":"failed"`)
|
||||
require.Contains(t, body, "token_invalidated")
|
||||
require.NotContains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_RateLimited(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"email":"demo@example.com"}`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{
|
||||
httpUpstream: upstream,
|
||||
soraTestCooldown: time.Hour,
|
||||
}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c1, _ := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c1, account)
|
||||
require.NoError(t, err)
|
||||
|
||||
c2, rec2 := newSoraTestContext()
|
||||
err = svc.testSoraAccountConnection(c2, account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "测试过于频繁")
|
||||
body := rec2.Body.String()
|
||||
require.Contains(t, body, `"type":"sora_test_result"`)
|
||||
require.Contains(t, body, `"code":"test_rate_limited"`)
|
||||
require.Contains(t, body, `"status":"failed"`)
|
||||
require.NotContains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestAccountTestService_testSoraAccountConnection_SubscriptionCloudflareChallengeWithRay(t *testing.T) {
|
||||
upstream := &queuedHTTPUpstream{
|
||||
responses: []*http.Response{
|
||||
newJSONResponse(http.StatusOK, `{"name":"demo-user"}`),
|
||||
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
|
||||
newJSONResponse(http.StatusForbidden, `<!DOCTYPE html><html><head><title>Just a moment...</title></head><body><script>window._cf_chl_opt={cRay: '9cff2d62d83bb98d'};</script><noscript>Enable JavaScript and cookies to continue</noscript></body></html>`),
|
||||
},
|
||||
}
|
||||
svc := &AccountTestService{httpUpstream: upstream}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformSora,
|
||||
Type: AccountTypeOAuth,
|
||||
Concurrency: 1,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test_token",
|
||||
},
|
||||
}
|
||||
|
||||
c, rec := newSoraTestContext()
|
||||
err := svc.testSoraAccountConnection(c, account)
|
||||
|
||||
require.NoError(t, err)
|
||||
body := rec.Body.String()
|
||||
require.Contains(t, body, "Subscription check blocked by Cloudflare challenge (HTTP 403)")
|
||||
require.Contains(t, body, "Sora2 invite check blocked by Cloudflare challenge (HTTP 403)")
|
||||
require.Contains(t, body, "cf-ray: 9cff2d62d83bb98d")
|
||||
require.Contains(t, body, `"type":"test_complete","success":true`)
|
||||
}
|
||||
|
||||
func TestSanitizeProxyURLForLog(t *testing.T) {
|
||||
require.Equal(t, "http://proxy.example.com:8080", sanitizeProxyURLForLog("http://user:pass@proxy.example.com:8080"))
|
||||
require.Equal(t, "", sanitizeProxyURLForLog(""))
|
||||
require.Equal(t, "<invalid_proxy_url>", sanitizeProxyURLForLog("://invalid"))
|
||||
}
|
||||
|
||||
func TestExtractSoraEgressIPHint(t *testing.T) {
|
||||
h := make(http.Header)
|
||||
h.Set("x-openai-public-ip", "203.0.113.10")
|
||||
require.Equal(t, "203.0.113.10", extractSoraEgressIPHint(h))
|
||||
|
||||
h2 := make(http.Header)
|
||||
h2.Set("x-envoy-external-address", "198.51.100.9")
|
||||
require.Equal(t, "198.51.100.9", extractSoraEgressIPHint(h2))
|
||||
|
||||
require.Equal(t, "unknown", extractSoraEgressIPHint(nil))
|
||||
require.Equal(t, "unknown", extractSoraEgressIPHint(http.Header{}))
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,150 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type accountUsageCodexProbeRepo struct {
|
||||
stubOpenAIAccountRepo
|
||||
updateExtraCh chan map[string]any
|
||||
rateLimitCh chan time.Time
|
||||
}
|
||||
|
||||
func (r *accountUsageCodexProbeRepo) UpdateExtra(_ context.Context, _ int64, updates map[string]any) error {
|
||||
if r.updateExtraCh != nil {
|
||||
copied := make(map[string]any, len(updates))
|
||||
for k, v := range updates {
|
||||
copied[k] = v
|
||||
}
|
||||
r.updateExtraCh <- copied
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *accountUsageCodexProbeRepo) SetRateLimited(_ context.Context, _ int64, resetAt time.Time) error {
|
||||
if r.rateLimitCh != nil {
|
||||
r.rateLimitCh <- resetAt
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestShouldRefreshOpenAICodexSnapshot(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
rateLimitedUntil := time.Now().Add(5 * time.Minute)
|
||||
now := time.Now()
|
||||
usage := &UsageInfo{
|
||||
FiveHour: &UsageProgress{Utilization: 0},
|
||||
SevenDay: &UsageProgress{Utilization: 0},
|
||||
}
|
||||
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{RateLimitResetAt: &rateLimitedUntil}, usage, now) {
|
||||
t.Fatal("expected rate-limited account to force codex snapshot refresh")
|
||||
}
|
||||
|
||||
if shouldRefreshOpenAICodexSnapshot(&Account{}, usage, now) {
|
||||
t.Fatal("expected complete non-rate-limited usage to skip codex snapshot refresh")
|
||||
}
|
||||
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{}, &UsageInfo{FiveHour: nil, SevenDay: &UsageProgress{}}, now) {
|
||||
t.Fatal("expected missing 5h snapshot to require refresh")
|
||||
}
|
||||
|
||||
staleAt := now.Add(-(openAIProbeCacheTTL + time.Minute)).Format(time.RFC3339)
|
||||
if !shouldRefreshOpenAICodexSnapshot(&Account{
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
Extra: map[string]any{
|
||||
"openai_oauth_responses_websockets_v2_enabled": true,
|
||||
"codex_usage_updated_at": staleAt,
|
||||
},
|
||||
}, usage, now) {
|
||||
t.Fatal("expected stale ws snapshot to trigger refresh")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractOpenAICodexProbeUpdatesAccepts429WithCodexHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||
headers.Set("x-codex-secondary-used-percent", "100")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
updates, err := extractOpenAICodexProbeUpdates(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
|
||||
if err != nil {
|
||||
t.Fatalf("extractOpenAICodexProbeUpdates() error = %v", err)
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
t.Fatal("expected codex probe updates from 429 headers")
|
||||
}
|
||||
if got := updates["codex_5h_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_5h_used_percent = %v, want 100", got)
|
||||
}
|
||||
if got := updates["codex_7d_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractOpenAICodexProbeSnapshotAccepts429WithResetAt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := make(http.Header)
|
||||
headers.Set("x-codex-primary-used-percent", "100")
|
||||
headers.Set("x-codex-primary-reset-after-seconds", "604800")
|
||||
headers.Set("x-codex-primary-window-minutes", "10080")
|
||||
headers.Set("x-codex-secondary-used-percent", "100")
|
||||
headers.Set("x-codex-secondary-reset-after-seconds", "18000")
|
||||
headers.Set("x-codex-secondary-window-minutes", "300")
|
||||
|
||||
updates, resetAt, err := extractOpenAICodexProbeSnapshot(&http.Response{StatusCode: http.StatusTooManyRequests, Header: headers})
|
||||
if err != nil {
|
||||
t.Fatalf("extractOpenAICodexProbeSnapshot() error = %v", err)
|
||||
}
|
||||
if len(updates) == 0 {
|
||||
t.Fatal("expected codex probe updates from 429 headers")
|
||||
}
|
||||
if resetAt == nil {
|
||||
t.Fatal("expected resetAt from exhausted codex headers")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountUsageService_PersistOpenAICodexProbeSnapshotSetsRateLimit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
repo := &accountUsageCodexProbeRepo{
|
||||
updateExtraCh: make(chan map[string]any, 1),
|
||||
rateLimitCh: make(chan time.Time, 1),
|
||||
}
|
||||
svc := &AccountUsageService{accountRepo: repo}
|
||||
resetAt := time.Now().Add(2 * time.Hour).UTC().Truncate(time.Second)
|
||||
|
||||
svc.persistOpenAICodexProbeSnapshot(321, map[string]any{
|
||||
"codex_7d_used_percent": 100.0,
|
||||
"codex_7d_reset_at": resetAt.Format(time.RFC3339),
|
||||
}, &resetAt)
|
||||
|
||||
select {
|
||||
case updates := <-repo.updateExtraCh:
|
||||
if got := updates["codex_7d_used_percent"]; got != 100.0 {
|
||||
t.Fatalf("codex_7d_used_percent = %v, want 100", got)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("waiting for codex probe extra persistence timed out")
|
||||
}
|
||||
|
||||
select {
|
||||
case got := <-repo.rateLimitCh:
|
||||
if got.Before(resetAt.Add(-time.Second)) || got.After(resetAt.Add(time.Second)) {
|
||||
t.Fatalf("rate limit resetAt = %v, want around %v", got, resetAt)
|
||||
}
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("waiting for codex probe rate limit persistence timed out")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,455 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestMatchWildcard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pattern string
|
||||
str string
|
||||
expected bool
|
||||
}{
|
||||
// 精确匹配
|
||||
{"exact match", "claude-sonnet-4-5", "claude-sonnet-4-5", true},
|
||||
{"exact mismatch", "claude-sonnet-4-5", "claude-opus-4-5", false},
|
||||
|
||||
// 通配符匹配
|
||||
{"wildcard prefix match", "claude-*", "claude-sonnet-4-5", true},
|
||||
{"wildcard prefix match 2", "claude-*", "claude-opus-4-5-thinking", true},
|
||||
{"wildcard prefix mismatch", "claude-*", "gemini-3-flash", false},
|
||||
{"wildcard partial match", "gemini-3*", "gemini-3-flash", true},
|
||||
{"wildcard partial match 2", "gemini-3*", "gemini-3-pro-image", true},
|
||||
{"wildcard partial mismatch", "gemini-3*", "gemini-2.5-flash", false},
|
||||
|
||||
// 边界情况
|
||||
{"empty pattern exact", "", "", true},
|
||||
{"empty pattern mismatch", "", "claude", false},
|
||||
{"single star", "*", "anything", true},
|
||||
{"star at end only", "abc*", "abcdef", true},
|
||||
{"star at end empty suffix", "abc*", "abc", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := matchWildcard(tt.pattern, tt.str)
|
||||
if result != tt.expected {
|
||||
t.Errorf("matchWildcard(%q, %q) = %v, want %v", tt.pattern, tt.str, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMatchWildcardMappingResult(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mapping map[string]string
|
||||
requestedModel string
|
||||
expected string
|
||||
matched bool
|
||||
}{
|
||||
// 精确匹配优先于通配符
|
||||
{
|
||||
name: "exact match takes precedence",
|
||||
mapping: map[string]string{
|
||||
"claude-sonnet-4-5": "claude-sonnet-4-5-exact",
|
||||
"claude-*": "claude-default",
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5-exact",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 最长通配符优先
|
||||
{
|
||||
name: "longer wildcard takes precedence",
|
||||
mapping: map[string]string{
|
||||
"claude-*": "claude-default",
|
||||
"claude-sonnet-*": "claude-sonnet-default",
|
||||
"claude-sonnet-4*": "claude-sonnet-4-series",
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-series",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 单个通配符
|
||||
{
|
||||
name: "single wildcard",
|
||||
mapping: map[string]string{
|
||||
"claude-*": "claude-mapped",
|
||||
},
|
||||
requestedModel: "claude-opus-4-5",
|
||||
expected: "claude-mapped",
|
||||
matched: true,
|
||||
},
|
||||
|
||||
// 无匹配返回原始模型
|
||||
{
|
||||
name: "no match returns original",
|
||||
mapping: map[string]string{
|
||||
"claude-*": "claude-mapped",
|
||||
},
|
||||
requestedModel: "gemini-3-flash",
|
||||
expected: "gemini-3-flash",
|
||||
matched: false,
|
||||
},
|
||||
|
||||
// 空映射返回原始模型
|
||||
{
|
||||
name: "empty mapping returns original",
|
||||
mapping: map[string]string{},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5",
|
||||
matched: false,
|
||||
},
|
||||
|
||||
// Gemini 模型映射
|
||||
{
|
||||
name: "gemini wildcard mapping",
|
||||
mapping: map[string]string{
|
||||
"gemini-3*": "gemini-3-pro-high",
|
||||
"gemini-2.5*": "gemini-2.5-flash",
|
||||
},
|
||||
requestedModel: "gemini-3-flash-preview",
|
||||
expected: "gemini-3-pro-high",
|
||||
matched: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, matched := matchWildcardMappingResult(tt.mapping, tt.requestedModel)
|
||||
if result != tt.expected || matched != tt.matched {
|
||||
t.Errorf("matchWildcardMappingResult(%v, %q) = (%q, %v), want (%q, %v)", tt.mapping, tt.requestedModel, result, matched, tt.expected, tt.matched)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountIsModelSupported(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
credentials map[string]any
|
||||
requestedModel string
|
||||
expected bool
|
||||
}{
|
||||
// 无映射 = 允许所有
|
||||
{
|
||||
name: "no mapping allows all",
|
||||
credentials: nil,
|
||||
requestedModel: "any-model",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "empty mapping allows all",
|
||||
credentials: map[string]any{},
|
||||
requestedModel: "any-model",
|
||||
expected: true,
|
||||
},
|
||||
|
||||
// 精确匹配
|
||||
{
|
||||
name: "exact match supported",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-5": "target-model",
|
||||
},
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "exact match not supported",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-5": "target-model",
|
||||
},
|
||||
},
|
||||
requestedModel: "claude-opus-4-5",
|
||||
expected: false,
|
||||
},
|
||||
|
||||
// 通配符匹配
|
||||
{
|
||||
name: "wildcard match supported",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-*": "claude-sonnet-4-5",
|
||||
},
|
||||
},
|
||||
requestedModel: "claude-opus-4-5-thinking",
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard match not supported",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-*": "claude-sonnet-4-5",
|
||||
},
|
||||
},
|
||||
requestedModel: "gemini-3-flash",
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: tt.credentials,
|
||||
}
|
||||
result := account.IsModelSupported(tt.requestedModel)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsModelSupported(%q) = %v, want %v", tt.requestedModel, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetMappedModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
credentials map[string]any
|
||||
requestedModel string
|
||||
expected string
|
||||
}{
|
||||
// 无映射 = 返回原始模型
|
||||
{
|
||||
name: "no mapping returns original",
|
||||
credentials: nil,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
|
||||
// 精确匹配
|
||||
{
|
||||
name: "exact match",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-5": "target-model",
|
||||
},
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "target-model",
|
||||
},
|
||||
|
||||
// 通配符匹配(最长优先)
|
||||
{
|
||||
name: "wildcard longest match",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-*": "claude-default",
|
||||
"claude-sonnet-*": "claude-sonnet-mapped",
|
||||
},
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-mapped",
|
||||
},
|
||||
|
||||
// 无匹配返回原始模型
|
||||
{
|
||||
name: "no match returns original",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-*": "gemini-mapped",
|
||||
},
|
||||
},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: tt.credentials,
|
||||
}
|
||||
result := account.GetMappedModel(tt.requestedModel)
|
||||
if result != tt.expected {
|
||||
t.Errorf("GetMappedModel(%q) = %q, want %q", tt.requestedModel, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountResolveMappedModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
credentials map[string]any
|
||||
requestedModel string
|
||||
expectedModel string
|
||||
expectedMatch bool
|
||||
}{
|
||||
{
|
||||
name: "no mapping reports unmatched",
|
||||
credentials: nil,
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
{
|
||||
name: "exact passthrough mapping still counts as matched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.4": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard passthrough mapping still counts as matched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-*": "gpt-5.4",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: true,
|
||||
},
|
||||
{
|
||||
name: "missing mapping reports unmatched",
|
||||
credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gpt-5.2": "gpt-5.2",
|
||||
},
|
||||
},
|
||||
requestedModel: "gpt-5.4",
|
||||
expectedModel: "gpt-5.4",
|
||||
expectedMatch: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: tt.credentials,
|
||||
}
|
||||
mappedModel, matched := account.ResolveMappedModel(tt.requestedModel)
|
||||
if mappedModel != tt.expectedModel || matched != tt.expectedMatch {
|
||||
t.Fatalf("ResolveMappedModel(%q) = (%q, %v), want (%q, %v)", tt.requestedModel, mappedModel, matched, tt.expectedModel, tt.expectedMatch)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_AntigravityEnsuresGeminiDefaultPassthroughs(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-3-pro-high": "gemini-3.1-pro-high",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mapping := account.GetModelMapping()
|
||||
if mapping["gemini-3-flash"] != "gemini-3-flash" {
|
||||
t.Fatalf("expected gemini-3-flash passthrough to be auto-filled, got: %q", mapping["gemini-3-flash"])
|
||||
}
|
||||
if mapping["gemini-3.1-pro-high"] != "gemini-3.1-pro-high" {
|
||||
t.Fatalf("expected gemini-3.1-pro-high passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-high"])
|
||||
}
|
||||
if mapping["gemini-3.1-pro-low"] != "gemini-3.1-pro-low" {
|
||||
t.Fatalf("expected gemini-3.1-pro-low passthrough to be auto-filled, got: %q", mapping["gemini-3.1-pro-low"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_AntigravityRespectsWildcardOverride(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"gemini-3*": "gemini-3.1-pro-high",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mapping := account.GetModelMapping()
|
||||
if _, exists := mapping["gemini-3-flash"]; exists {
|
||||
t.Fatalf("did not expect explicit gemini-3-flash passthrough when wildcard already exists")
|
||||
}
|
||||
if _, exists := mapping["gemini-3.1-pro-high"]; exists {
|
||||
t.Fatalf("did not expect explicit gemini-3.1-pro-high passthrough when wildcard already exists")
|
||||
}
|
||||
if _, exists := mapping["gemini-3.1-pro-low"]; exists {
|
||||
t.Fatalf("did not expect explicit gemini-3.1-pro-low passthrough when wildcard already exists")
|
||||
}
|
||||
if mapped := account.GetMappedModel("gemini-3-flash"); mapped != "gemini-3.1-pro-high" {
|
||||
t.Fatalf("expected wildcard mapping to stay effective, got: %q", mapped)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_CacheInvalidatesOnCredentialsReplace(t *testing.T) {
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-3-5-sonnet": "upstream-a",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
first := account.GetModelMapping()
|
||||
if first["claude-3-5-sonnet"] != "upstream-a" {
|
||||
t.Fatalf("unexpected first mapping: %v", first)
|
||||
}
|
||||
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-3-5-sonnet": "upstream-b",
|
||||
},
|
||||
}
|
||||
second := account.GetModelMapping()
|
||||
if second["claude-3-5-sonnet"] != "upstream-b" {
|
||||
t.Fatalf("expected cache invalidated after credentials replace, got: %v", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_CacheInvalidatesOnMappingLenChange(t *testing.T) {
|
||||
rawMapping := map[string]any{
|
||||
"claude-sonnet": "sonnet-a",
|
||||
}
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": rawMapping,
|
||||
},
|
||||
}
|
||||
|
||||
first := account.GetModelMapping()
|
||||
if len(first) != 1 {
|
||||
t.Fatalf("unexpected first mapping length: %d", len(first))
|
||||
}
|
||||
|
||||
rawMapping["claude-opus"] = "opus-b"
|
||||
second := account.GetModelMapping()
|
||||
if second["claude-opus"] != "opus-b" {
|
||||
t.Fatalf("expected cache invalidated after mapping len change, got: %v", second)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAccountGetModelMapping_CacheInvalidatesOnInPlaceValueChange(t *testing.T) {
|
||||
rawMapping := map[string]any{
|
||||
"claude-sonnet": "sonnet-a",
|
||||
}
|
||||
account := &Account{
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": rawMapping,
|
||||
},
|
||||
}
|
||||
|
||||
first := account.GetModelMapping()
|
||||
if first["claude-sonnet"] != "sonnet-a" {
|
||||
t.Fatalf("unexpected first mapping: %v", first)
|
||||
}
|
||||
|
||||
rawMapping["claude-sonnet"] = "sonnet-b"
|
||||
second := account.GetModelMapping()
|
||||
if second["claude-sonnet"] != "sonnet-b" {
|
||||
t.Fatalf("expected cache invalidated after in-place value change, got: %v", second)
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,503 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Stubs
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// userRepoStubForGroupUpdate implements UserRepository for AdminUpdateAPIKeyGroupID tests.
|
||||
type userRepoStubForGroupUpdate struct {
|
||||
addGroupErr error
|
||||
addGroupCalled bool
|
||||
addedUserID int64
|
||||
addedGroupID int64
|
||||
}
|
||||
|
||||
func (s *userRepoStubForGroupUpdate) AddGroupToAllowedGroups(_ context.Context, userID int64, groupID int64) error {
|
||||
s.addGroupCalled = true
|
||||
s.addedUserID = userID
|
||||
s.addedGroupID = groupID
|
||||
return s.addGroupErr
|
||||
}
|
||||
|
||||
func (s *userRepoStubForGroupUpdate) Create(context.Context, *User) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) GetByID(context.Context, int64) (*User, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) GetByEmail(context.Context, string) (*User, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) GetFirstAdmin(context.Context) (*User, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) Update(context.Context, *User) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) UpdateBalance(context.Context, int64, float64) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) DeductBalance(context.Context, int64, float64) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) UpdateConcurrency(context.Context, int64, int) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) ExistsByEmail(context.Context, string) (bool, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) RemoveGroupFromAllowedGroups(context.Context, int64) (int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) UpdateTotpSecret(context.Context, int64, *string) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *userRepoStubForGroupUpdate) EnableTotp(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *userRepoStubForGroupUpdate) DisableTotp(context.Context, int64) error { panic("unexpected") }
|
||||
|
||||
// apiKeyRepoStubForGroupUpdate implements APIKeyRepository for AdminUpdateAPIKeyGroupID tests.
|
||||
type apiKeyRepoStubForGroupUpdate struct {
|
||||
key *APIKey
|
||||
getErr error
|
||||
updateErr error
|
||||
updated *APIKey // captures what was passed to Update
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStubForGroupUpdate) GetByID(_ context.Context, _ int64) (*APIKey, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
clone := *s.key
|
||||
return &clone, nil
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) Update(_ context.Context, key *APIKey) error {
|
||||
if s.updateErr != nil {
|
||||
return s.updateErr
|
||||
}
|
||||
clone := *key
|
||||
s.updated = &clone
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unused methods – panic on unexpected call.
|
||||
func (s *apiKeyRepoStubForGroupUpdate) Create(context.Context, *APIKey) error { panic("unexpected") }
|
||||
func (s *apiKeyRepoStubForGroupUpdate) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) GetByKey(context.Context, string) (*APIKey, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *apiKeyRepoStubForGroupUpdate) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) CountByUserID(context.Context, int64) (int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) ExistsByKey(context.Context, string) (bool, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) CountByGroupID(context.Context, int64) (int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) ListKeysByUserID(context.Context, int64) ([]string, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) ListKeysByGroupID(context.Context, int64) ([]string, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) UpdateLastUsed(context.Context, int64, time.Time) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) IncrementRateLimitUsage(context.Context, int64, float64) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) ResetRateLimitWindows(context.Context, int64) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *apiKeyRepoStubForGroupUpdate) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
// groupRepoStubForGroupUpdate implements GroupRepository for AdminUpdateAPIKeyGroupID tests.
|
||||
type groupRepoStubForGroupUpdate struct {
|
||||
group *Group
|
||||
getErr error
|
||||
lastGetByIDArg int64
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForGroupUpdate) GetByID(_ context.Context, id int64) (*Group, error) {
|
||||
s.lastGetByIDArg = id
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
clone := *s.group
|
||||
return &clone, nil
|
||||
}
|
||||
|
||||
// Unused methods – panic on unexpected call.
|
||||
func (s *groupRepoStubForGroupUpdate) Create(context.Context, *Group) error { panic("unexpected") }
|
||||
func (s *groupRepoStubForGroupUpdate) GetByIDLite(context.Context, int64) (*Group, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *groupRepoStubForGroupUpdate) Update(context.Context, *Group) error { panic("unexpected") }
|
||||
func (s *groupRepoStubForGroupUpdate) Delete(context.Context, int64) error { panic("unexpected") }
|
||||
func (s *groupRepoStubForGroupUpdate) DeleteCascade(context.Context, int64) ([]int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *groupRepoStubForGroupUpdate) List(context.Context, pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *groupRepoStubForGroupUpdate) ListWithFilters(context.Context, pagination.PaginationParams, string, string, string, *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *groupRepoStubForGroupUpdate) ListActive(context.Context) ([]Group, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *groupRepoStubForGroupUpdate) ListActiveByPlatform(context.Context, string) ([]Group, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *groupRepoStubForGroupUpdate) ExistsByName(context.Context, string) (bool, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *groupRepoStubForGroupUpdate) GetAccountCount(context.Context, int64) (int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *groupRepoStubForGroupUpdate) DeleteAccountGroupsByGroupID(context.Context, int64) (int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *groupRepoStubForGroupUpdate) GetAccountIDsByGroupIDs(context.Context, []int64) ([]int64, error) {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *groupRepoStubForGroupUpdate) BindAccountsToGroup(context.Context, int64, []int64) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
func (s *groupRepoStubForGroupUpdate) UpdateSortOrders(context.Context, []GroupSortOrderUpdate) error {
|
||||
panic("unexpected")
|
||||
}
|
||||
|
||||
type userSubRepoStubForGroupUpdate struct {
|
||||
userSubRepoNoop
|
||||
getActiveSub *UserSubscription
|
||||
getActiveErr error
|
||||
called bool
|
||||
calledUserID int64
|
||||
calledGroupID int64
|
||||
}
|
||||
|
||||
func (s *userSubRepoStubForGroupUpdate) GetActiveByUserIDAndGroupID(_ context.Context, userID, groupID int64) (*UserSubscription, error) {
|
||||
s.called = true
|
||||
s.calledUserID = userID
|
||||
s.calledGroupID = groupID
|
||||
if s.getActiveErr != nil {
|
||||
return nil, s.getActiveErr
|
||||
}
|
||||
if s.getActiveSub == nil {
|
||||
return nil, ErrSubscriptionNotFound
|
||||
}
|
||||
clone := *s.getActiveSub
|
||||
return &clone, nil
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_KeyNotFound(t *testing.T) {
|
||||
repo := &apiKeyRepoStubForGroupUpdate{getErr: ErrAPIKeyNotFound}
|
||||
svc := &adminServiceImpl{apiKeyRepo: repo}
|
||||
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 999, int64Ptr(1))
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_NilGroupID_NoOp(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5)}
|
||||
repo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
svc := &adminServiceImpl{apiKeyRepo: repo}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), got.APIKey.ID)
|
||||
// Update should NOT have been called (updated stays nil)
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(5), Group: &Group{ID: 5, Name: "Old"}}
|
||||
repo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: repo, authCacheInvalidator: cache}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, got.APIKey.GroupID, "group_id should be nil after unbind")
|
||||
require.Nil(t, got.APIKey.Group, "group object should be nil after unbind")
|
||||
require.NotNil(t, repo.updated, "Update should have been called")
|
||||
require.Nil(t, repo.updated.GroupID)
|
||||
require.Equal(t, []string{"sk-test"}, cache.keys, "cache should be invalidated")
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_BindActiveGroup(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
require.Equal(t, int64(10), *got.APIKey.GroupID)
|
||||
require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
|
||||
require.Equal(t, []string{"sk-test"}, cache.keys)
|
||||
// M3: verify correct group ID was passed to repo
|
||||
require.Equal(t, int64(10), groupRepo.lastGetByIDArg)
|
||||
// C1 fix: verify Group object is populated
|
||||
require.NotNil(t, got.APIKey.Group)
|
||||
require.Equal(t, "Pro", got.APIKey.Group.Name)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_SameGroup_Idempotent(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Pro"}}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
require.Equal(t, int64(10), *got.APIKey.GroupID)
|
||||
// Update is still called (current impl doesn't short-circuit on same group)
|
||||
require.NotNil(t, apiKeyRepo.updated)
|
||||
require.Equal(t, []string{"sk-test"}, cache.keys)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotFound(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test"}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{getErr: ErrGroupNotFound}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
|
||||
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(99))
|
||||
require.ErrorIs(t, err, ErrGroupNotFound)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_GroupNotActive(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test"}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 5, Status: StatusDisabled}}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
|
||||
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(5))
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "GROUP_NOT_ACTIVE", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_UpdateFails(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: int64Ptr(3)}
|
||||
repo := &apiKeyRepoStubForGroupUpdate{key: existing, updateErr: errors.New("db write error")}
|
||||
svc := &adminServiceImpl{apiKeyRepo: repo}
|
||||
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "update api key")
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_NegativeGroupID(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test"}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo}
|
||||
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(-5))
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "INVALID_GROUP_ID", infraerrors.Reason(err))
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_PointerIsolation(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Pro", Status: StatusActive}}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, authCacheInvalidator: cache}
|
||||
|
||||
inputGID := int64(10)
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, &inputGID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
// Mutating the input pointer must NOT affect the stored value
|
||||
inputGID = 999
|
||||
require.Equal(t, int64(10), *got.APIKey.GroupID)
|
||||
require.Equal(t, int64(10), *apiKeyRepo.updated.GroupID)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_NilCacheInvalidator(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, Key: "sk-test"}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 7, Status: StatusActive}}
|
||||
// authCacheInvalidator is nil – should not panic
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(7))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
require.Equal(t, int64(7), *got.APIKey.GroupID)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests: AllowedGroup auto-sync
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AddsAllowedGroup(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
require.Equal(t, int64(10), *got.APIKey.GroupID)
|
||||
// 验证 AddGroupToAllowedGroups 被调用,且参数正确
|
||||
require.True(t, userRepo.addGroupCalled)
|
||||
require.Equal(t, int64(42), userRepo.addedUserID)
|
||||
require.Equal(t, int64(10), userRepo.addedGroupID)
|
||||
// 验证 result 标记了自动授权
|
||||
require.True(t, got.AutoGrantedGroupAccess)
|
||||
require.NotNil(t, got.GrantedGroupID)
|
||||
require.Equal(t, int64(10), *got.GrantedGroupID)
|
||||
require.Equal(t, "Exclusive", got.GrantedGroupName)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_NonExclusiveGroup_NoAllowedGroupUpdate(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Public", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeStandard}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, authCacheInvalidator: cache}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
// 非专属分组不触发 AddGroupToAllowedGroups
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
require.False(t, got.AutoGrantedGroupAccess)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_Blocked(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
userSubRepo := &userSubRepoStubForGroupUpdate{getActiveErr: ErrSubscriptionNotFound}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
|
||||
|
||||
// 无有效订阅时应拒绝绑定
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "SUBSCRIPTION_REQUIRED", infraerrors.Reason(err))
|
||||
require.True(t, userSubRepo.called)
|
||||
require.Equal(t, int64(42), userSubRepo.calledUserID)
|
||||
require.Equal(t, int64(10), userSubRepo.calledGroupID)
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_RequiresRepo(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: false, SubscriptionType: SubscriptionTypeSubscription}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
|
||||
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "SUBSCRIPTION_REPOSITORY_UNAVAILABLE", infraerrors.Reason(err))
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_SubscriptionGroup_AllowsActiveSubscription(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Sub", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeSubscription}}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
userSubRepo := &userSubRepoStubForGroupUpdate{
|
||||
getActiveSub: &UserSubscription{ID: 99, UserID: 42, GroupID: 10},
|
||||
}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo, userSubRepo: userSubRepo}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.NoError(t, err)
|
||||
require.True(t, userSubRepo.called)
|
||||
require.NotNil(t, got.APIKey.GroupID)
|
||||
require.Equal(t, int64(10), *got.APIKey.GroupID)
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_ExclusiveGroup_AllowedGroupAddFails_ReturnsError(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: nil}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
groupRepo := &groupRepoStubForGroupUpdate{group: &Group{ID: 10, Name: "Exclusive", Status: StatusActive, IsExclusive: true, SubscriptionType: SubscriptionTypeStandard}}
|
||||
userRepo := &userRepoStubForGroupUpdate{addGroupErr: errors.New("db error")}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, groupRepo: groupRepo, userRepo: userRepo}
|
||||
|
||||
// 严格模式:AddGroupToAllowedGroups 失败时,整体操作报错
|
||||
_, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(10))
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "add group to user allowed groups")
|
||||
require.True(t, userRepo.addGroupCalled)
|
||||
// apiKey 不应被更新
|
||||
require.Nil(t, apiKeyRepo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_AdminUpdateAPIKeyGroupID_Unbind_NoAllowedGroupUpdate(t *testing.T) {
|
||||
existing := &APIKey{ID: 1, UserID: 42, Key: "sk-test", GroupID: int64Ptr(10), Group: &Group{ID: 10, Name: "Exclusive"}}
|
||||
apiKeyRepo := &apiKeyRepoStubForGroupUpdate{key: existing}
|
||||
userRepo := &userRepoStubForGroupUpdate{}
|
||||
cache := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{apiKeyRepo: apiKeyRepo, userRepo: userRepo, authCacheInvalidator: cache}
|
||||
|
||||
got, err := svc.AdminUpdateAPIKeyGroupID(context.Background(), 1, int64Ptr(0))
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, got.APIKey.GroupID)
|
||||
// 解绑时不修改 allowed_groups
|
||||
require.False(t, userRepo.addGroupCalled)
|
||||
require.False(t, got.AutoGrantedGroupAccess)
|
||||
}
|
||||
@@ -0,0 +1,172 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type accountRepoStubForBulkUpdate struct {
|
||||
accountRepoStub
|
||||
bulkUpdateErr error
|
||||
bulkUpdateIDs []int64
|
||||
bindGroupErrByID map[int64]error
|
||||
bindGroupsCalls []int64
|
||||
getByIDsAccounts []*Account
|
||||
getByIDsErr error
|
||||
getByIDsCalled bool
|
||||
getByIDsIDs []int64
|
||||
getByIDAccounts map[int64]*Account
|
||||
getByIDErrByID map[int64]error
|
||||
getByIDCalled []int64
|
||||
listByGroupData map[int64][]Account
|
||||
listByGroupErr map[int64]error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) BulkUpdate(_ context.Context, ids []int64, _ AccountBulkUpdate) (int64, error) {
|
||||
s.bulkUpdateIDs = append([]int64{}, ids...)
|
||||
if s.bulkUpdateErr != nil {
|
||||
return 0, s.bulkUpdateErr
|
||||
}
|
||||
return int64(len(ids)), nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) BindGroups(_ context.Context, accountID int64, _ []int64) error {
|
||||
s.bindGroupsCalls = append(s.bindGroupsCalls, accountID)
|
||||
if err, ok := s.bindGroupErrByID[accountID]; ok {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) GetByIDs(_ context.Context, ids []int64) ([]*Account, error) {
|
||||
s.getByIDsCalled = true
|
||||
s.getByIDsIDs = append([]int64{}, ids...)
|
||||
if s.getByIDsErr != nil {
|
||||
return nil, s.getByIDsErr
|
||||
}
|
||||
return s.getByIDsAccounts, nil
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) GetByID(_ context.Context, id int64) (*Account, error) {
|
||||
s.getByIDCalled = append(s.getByIDCalled, id)
|
||||
if err, ok := s.getByIDErrByID[id]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if account, ok := s.getByIDAccounts[id]; ok {
|
||||
return account, nil
|
||||
}
|
||||
return nil, errors.New("account not found")
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForBulkUpdate) ListByGroup(_ context.Context, groupID int64) ([]Account, error) {
|
||||
if err, ok := s.listByGroupErr[groupID]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if rows, ok := s.listByGroupData[groupID]; ok {
|
||||
return rows, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_AllSuccessIDs 验证批量更新成功时返回 success_ids/failed_ids。
|
||||
func TestAdminService_BulkUpdateAccounts_AllSuccessIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
schedulable := true
|
||||
input := &BulkUpdateAccountsInput{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
Schedulable: &schedulable,
|
||||
}
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, result.Success)
|
||||
require.Equal(t, 0, result.Failed)
|
||||
require.ElementsMatch(t, []int64{1, 2, 3}, result.SuccessIDs)
|
||||
require.Empty(t, result.FailedIDs)
|
||||
require.Len(t, result.Results, 3)
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_PartialFailureIDs 验证部分失败时 success_ids/failed_ids 正确。
|
||||
func TestAdminService_BulkUpdateAccounts_PartialFailureIDs(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{
|
||||
bindGroupErrByID: map[int64]error{
|
||||
2: errors.New("bind failed"),
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{
|
||||
accountRepo: repo,
|
||||
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "g10"}},
|
||||
}
|
||||
|
||||
groupIDs := []int64{10}
|
||||
schedulable := false
|
||||
input := &BulkUpdateAccountsInput{
|
||||
AccountIDs: []int64{1, 2, 3},
|
||||
GroupIDs: &groupIDs,
|
||||
Schedulable: &schedulable,
|
||||
SkipMixedChannelCheck: true,
|
||||
}
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2, result.Success)
|
||||
require.Equal(t, 1, result.Failed)
|
||||
require.ElementsMatch(t, []int64{1, 3}, result.SuccessIDs)
|
||||
require.ElementsMatch(t, []int64{2}, result.FailedIDs)
|
||||
require.Len(t, result.Results, 3)
|
||||
}
|
||||
|
||||
func TestAdminService_BulkUpdateAccounts_NilGroupRepoReturnsError(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
groupIDs := []int64{10}
|
||||
input := &BulkUpdateAccountsInput{
|
||||
AccountIDs: []int64{1},
|
||||
GroupIDs: &groupIDs,
|
||||
}
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "group repository not configured")
|
||||
}
|
||||
|
||||
// TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict verifies
|
||||
// that the global pre-check detects a conflict with existing group members and returns an
|
||||
// error before any DB write is performed.
|
||||
func TestAdminService_BulkUpdateAccounts_MixedChannelPreCheckBlocksOnExistingConflict(t *testing.T) {
|
||||
repo := &accountRepoStubForBulkUpdate{
|
||||
getByIDsAccounts: []*Account{
|
||||
{ID: 1, Platform: PlatformAntigravity},
|
||||
},
|
||||
// Group 10 already contains an Anthropic account.
|
||||
listByGroupData: map[int64][]Account{
|
||||
10: {{ID: 99, Platform: PlatformAnthropic}},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{
|
||||
accountRepo: repo,
|
||||
groupRepo: &groupRepoStubForAdmin{getByID: &Group{ID: 10, Name: "target-group"}},
|
||||
}
|
||||
|
||||
groupIDs := []int64{10}
|
||||
input := &BulkUpdateAccountsInput{
|
||||
AccountIDs: []int64{1},
|
||||
GroupIDs: &groupIDs,
|
||||
}
|
||||
|
||||
result, err := svc.BulkUpdateAccounts(context.Background(), input)
|
||||
require.Nil(t, result)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "mixed channel")
|
||||
// No BindGroups should have been called since the check runs before any write.
|
||||
require.Empty(t, repo.bindGroupsCalls)
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAdminService_CreateUser_Success(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 10}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
input := &CreateUserInput{
|
||||
Email: "user@test.com",
|
||||
Password: "strong-pass",
|
||||
Username: "tester",
|
||||
Notes: "note",
|
||||
Balance: 12.5,
|
||||
Concurrency: 7,
|
||||
AllowedGroups: []int64{3, 5},
|
||||
}
|
||||
|
||||
user, err := svc.CreateUser(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.Equal(t, int64(10), user.ID)
|
||||
require.Equal(t, input.Email, user.Email)
|
||||
require.Equal(t, input.Username, user.Username)
|
||||
require.Equal(t, input.Notes, user.Notes)
|
||||
require.Equal(t, input.Balance, user.Balance)
|
||||
require.Equal(t, input.Concurrency, user.Concurrency)
|
||||
require.Equal(t, input.AllowedGroups, user.AllowedGroups)
|
||||
require.Equal(t, RoleUser, user.Role)
|
||||
require.Equal(t, StatusActive, user.Status)
|
||||
require.True(t, user.CheckPassword(input.Password))
|
||||
require.Len(t, repo.created, 1)
|
||||
require.Equal(t, user, repo.created[0])
|
||||
}
|
||||
|
||||
func TestAdminService_CreateUser_EmailExists(t *testing.T) {
|
||||
repo := &userRepoStub{createErr: ErrEmailExists}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
_, err := svc.CreateUser(context.Background(), &CreateUserInput{
|
||||
Email: "dup@test.com",
|
||||
Password: "password",
|
||||
})
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
require.Empty(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateUser_CreateError(t *testing.T) {
|
||||
createErr := errors.New("db down")
|
||||
repo := &userRepoStub{createErr: createErr}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
_, err := svc.CreateUser(context.Background(), &CreateUserInput{
|
||||
Email: "user@test.com",
|
||||
Password: "password",
|
||||
})
|
||||
require.ErrorIs(t, err, createErr)
|
||||
require.Empty(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateUser_AssignsDefaultSubscriptions(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 21}
|
||||
assigner := &defaultSubscriptionAssignerStub{}
|
||||
cfg := &config.Config{
|
||||
Default: config.DefaultConfig{
|
||||
UserBalance: 0,
|
||||
UserConcurrency: 1,
|
||||
},
|
||||
}
|
||||
settingService := NewSettingService(&settingRepoStub{values: map[string]string{
|
||||
SettingKeyDefaultSubscriptions: `[{"group_id":5,"validity_days":30}]`,
|
||||
}}, cfg)
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: repo,
|
||||
settingService: settingService,
|
||||
defaultSubAssigner: assigner,
|
||||
}
|
||||
|
||||
_, err := svc.CreateUser(context.Background(), &CreateUserInput{
|
||||
Email: "new-user@test.com",
|
||||
Password: "password",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, assigner.calls, 1)
|
||||
require.Equal(t, int64(21), assigner.calls[0].UserID)
|
||||
require.Equal(t, int64(5), assigner.calls[0].GroupID)
|
||||
require.Equal(t, 30, assigner.calls[0].ValidityDays)
|
||||
}
|
||||
@@ -0,0 +1,542 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type userRepoStub struct {
|
||||
user *User
|
||||
getErr error
|
||||
createErr error
|
||||
deleteErr error
|
||||
exists bool
|
||||
existsErr error
|
||||
nextID int64
|
||||
created []*User
|
||||
deletedIDs []int64
|
||||
}
|
||||
|
||||
func (s *userRepoStub) Create(ctx context.Context, user *User) error {
|
||||
if s.createErr != nil {
|
||||
return s.createErr
|
||||
}
|
||||
if s.nextID != 0 && user.ID == 0 {
|
||||
user.ID = s.nextID
|
||||
}
|
||||
s.created = append(s.created, user)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *userRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
if s.user == nil {
|
||||
return nil, ErrUserNotFound
|
||||
}
|
||||
return s.user, nil
|
||||
}
|
||||
|
||||
func (s *userRepoStub) GetByEmail(ctx context.Context, email string) (*User, error) {
|
||||
panic("unexpected GetByEmail call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) GetFirstAdmin(ctx context.Context) (*User, error) {
|
||||
panic("unexpected GetFirstAdmin call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) Update(ctx context.Context, user *User) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
s.deletedIDs = append(s.deletedIDs, id)
|
||||
return s.deleteErr
|
||||
}
|
||||
|
||||
func (s *userRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, filters UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) UpdateBalance(ctx context.Context, id int64, amount float64) error {
|
||||
panic("unexpected UpdateBalance call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) DeductBalance(ctx context.Context, id int64, amount float64) error {
|
||||
panic("unexpected DeductBalance call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) UpdateConcurrency(ctx context.Context, id int64, amount int) error {
|
||||
panic("unexpected UpdateConcurrency call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) ExistsByEmail(ctx context.Context, email string) (bool, error) {
|
||||
if s.existsErr != nil {
|
||||
return false, s.existsErr
|
||||
}
|
||||
return s.exists, nil
|
||||
}
|
||||
|
||||
func (s *userRepoStub) RemoveGroupFromAllowedGroups(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected RemoveGroupFromAllowedGroups call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) AddGroupToAllowedGroups(ctx context.Context, userID int64, groupID int64) error {
|
||||
panic("unexpected AddGroupToAllowedGroups call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) UpdateTotpSecret(ctx context.Context, userID int64, encryptedSecret *string) error {
|
||||
panic("unexpected UpdateTotpSecret call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) EnableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected EnableTotp call")
|
||||
}
|
||||
|
||||
func (s *userRepoStub) DisableTotp(ctx context.Context, userID int64) error {
|
||||
panic("unexpected DisableTotp call")
|
||||
}
|
||||
|
||||
type groupRepoStub struct {
|
||||
affectedUserIDs []int64
|
||||
deleteErr error
|
||||
deleteCalls []int64
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) Create(ctx context.Context, group *Group) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetByIDLite(ctx context.Context, id int64) (*Group, error) {
|
||||
panic("unexpected GetByIDLite call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) Update(ctx context.Context, group *Group) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) DeleteCascade(ctx context.Context, id int64) ([]int64, error) {
|
||||
s.deleteCalls = append(s.deleteCalls, id)
|
||||
return s.affectedUserIDs, s.deleteErr
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) ListActive(ctx context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) ListActiveByPlatform(ctx context.Context, platform string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) ExistsByName(ctx context.Context, name string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetAccountCount(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) DeleteAccountGroupsByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) BindAccountsToGroup(ctx context.Context, groupID int64, accountIDs []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) GetAccountIDsByGroupIDs(ctx context.Context, groupIDs []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStub) UpdateSortOrders(ctx context.Context, updates []GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type proxyRepoStub struct {
|
||||
deleteErr error
|
||||
countErr error
|
||||
accountCount int64
|
||||
deletedIDs []int64
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) Create(ctx context.Context, proxy *Proxy) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) GetByID(ctx context.Context, id int64) (*Proxy, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListByIDs(ctx context.Context, ids []int64) ([]Proxy, error) {
|
||||
panic("unexpected ListByIDs call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) Update(ctx context.Context, proxy *Proxy) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
s.deletedIDs = append(s.deletedIDs, id)
|
||||
return s.deleteErr
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]Proxy, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListActive(ctx context.Context) ([]Proxy, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListActiveWithAccountCount(ctx context.Context) ([]ProxyWithAccountCount, error) {
|
||||
panic("unexpected ListActiveWithAccountCount call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListWithFiltersAndAccountCount(ctx context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFiltersAndAccountCount call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ExistsByHostPortAuth(ctx context.Context, host string, port int, username, password string) (bool, error) {
|
||||
panic("unexpected ExistsByHostPortAuth call")
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) CountAccountsByProxyID(ctx context.Context, proxyID int64) (int64, error) {
|
||||
if s.countErr != nil {
|
||||
return 0, s.countErr
|
||||
}
|
||||
return s.accountCount, nil
|
||||
}
|
||||
|
||||
func (s *proxyRepoStub) ListAccountSummariesByProxyID(ctx context.Context, proxyID int64) ([]ProxyAccountSummary, error) {
|
||||
panic("unexpected ListAccountSummariesByProxyID call")
|
||||
}
|
||||
|
||||
type redeemRepoStub struct {
|
||||
deleteErrByID map[int64]error
|
||||
deletedIDs []int64
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) CreateBatch(ctx context.Context, codes []RedeemCode) error {
|
||||
panic("unexpected CreateBatch call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) GetByID(ctx context.Context, id int64) (*RedeemCode, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) GetByCode(ctx context.Context, code string) (*RedeemCode, error) {
|
||||
panic("unexpected GetByCode call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) Update(ctx context.Context, code *RedeemCode) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
s.deletedIDs = append(s.deletedIDs, id)
|
||||
if s.deleteErrByID != nil {
|
||||
if err, ok := s.deleteErrByID[id]; ok {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) Use(ctx context.Context, id, userID int64) error {
|
||||
panic("unexpected Use call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) List(ctx context.Context, params pagination.PaginationParams) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) ListWithFilters(ctx context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) ListByUser(ctx context.Context, userID int64, limit int) ([]RedeemCode, error) {
|
||||
panic("unexpected ListByUser call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) ListByUserPaginated(ctx context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserPaginated call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStub) SumPositiveBalanceByUser(ctx context.Context, userID int64) (float64, error) {
|
||||
panic("unexpected SumPositiveBalanceByUser call")
|
||||
}
|
||||
|
||||
type subscriptionInvalidateCall struct {
|
||||
userID int64
|
||||
groupID int64
|
||||
}
|
||||
|
||||
type billingCacheStub struct {
|
||||
invalidations chan subscriptionInvalidateCall
|
||||
}
|
||||
|
||||
func newBillingCacheStub(buffer int) *billingCacheStub {
|
||||
return &billingCacheStub{invalidations: make(chan subscriptionInvalidateCall, buffer)}
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
panic("unexpected GetUserBalance call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
panic("unexpected SetUserBalance call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
panic("unexpected DeductUserBalance call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
panic("unexpected InvalidateUserBalance call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
|
||||
panic("unexpected GetSubscriptionCache call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
|
||||
panic("unexpected SetSubscriptionCache call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
panic("unexpected UpdateSubscriptionUsage call")
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
s.invalidations <- subscriptionInvalidateCall{userID: userID, groupID: groupID}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
|
||||
panic("unexpected GetAPIKeyRateLimit call")
|
||||
}
|
||||
func (s *billingCacheStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
|
||||
panic("unexpected SetAPIKeyRateLimit call")
|
||||
}
|
||||
func (s *billingCacheStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||
panic("unexpected UpdateAPIKeyRateLimitUsage call")
|
||||
}
|
||||
func (s *billingCacheStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||||
panic("unexpected InvalidateAPIKeyRateLimit call")
|
||||
}
|
||||
|
||||
func waitForInvalidations(t *testing.T, ch <-chan subscriptionInvalidateCall, expected int) []subscriptionInvalidateCall {
|
||||
t.Helper()
|
||||
calls := make([]subscriptionInvalidateCall, 0, expected)
|
||||
timeout := time.After(2 * time.Second)
|
||||
for len(calls) < expected {
|
||||
select {
|
||||
case call := <-ch:
|
||||
calls = append(calls, call)
|
||||
case <-timeout:
|
||||
t.Fatalf("timeout waiting for %d invalidations, got %d", expected, len(calls))
|
||||
}
|
||||
}
|
||||
return calls
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteUser_Success(t *testing.T) {
|
||||
repo := &userRepoStub{user: &User{ID: 7, Role: RoleUser}}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
err := svc.DeleteUser(context.Background(), 7)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{7}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteUser_NotFound(t *testing.T) {
|
||||
repo := &userRepoStub{getErr: ErrUserNotFound}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
err := svc.DeleteUser(context.Background(), 404)
|
||||
require.ErrorIs(t, err, ErrUserNotFound)
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteUser_AdminGuard(t *testing.T) {
|
||||
repo := &userRepoStub{user: &User{ID: 1, Role: RoleAdmin}}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
err := svc.DeleteUser(context.Background(), 1)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "cannot delete admin user")
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteUser_DeleteError(t *testing.T) {
|
||||
deleteErr := errors.New("delete failed")
|
||||
repo := &userRepoStub{
|
||||
user: &User{ID: 9, Role: RoleUser},
|
||||
deleteErr: deleteErr,
|
||||
}
|
||||
svc := &adminServiceImpl{userRepo: repo}
|
||||
|
||||
err := svc.DeleteUser(context.Background(), 9)
|
||||
require.ErrorIs(t, err, deleteErr)
|
||||
require.Equal(t, []int64{9}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteGroup_Success_WithCacheInvalidation(t *testing.T) {
|
||||
cache := newBillingCacheStub(2)
|
||||
repo := &groupRepoStub{affectedUserIDs: []int64{11, 12}}
|
||||
svc := &adminServiceImpl{
|
||||
groupRepo: repo,
|
||||
billingCacheService: &BillingCacheService{cache: cache},
|
||||
}
|
||||
|
||||
err := svc.DeleteGroup(context.Background(), 5)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{5}, repo.deleteCalls)
|
||||
|
||||
calls := waitForInvalidations(t, cache.invalidations, 2)
|
||||
require.ElementsMatch(t, []subscriptionInvalidateCall{
|
||||
{userID: 11, groupID: 5},
|
||||
{userID: 12, groupID: 5},
|
||||
}, calls)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteGroup_NotFound(t *testing.T) {
|
||||
repo := &groupRepoStub{deleteErr: ErrGroupNotFound}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
err := svc.DeleteGroup(context.Background(), 99)
|
||||
require.ErrorIs(t, err, ErrGroupNotFound)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteGroup_Error(t *testing.T) {
|
||||
deleteErr := errors.New("delete failed")
|
||||
repo := &groupRepoStub{deleteErr: deleteErr}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
err := svc.DeleteGroup(context.Background(), 42)
|
||||
require.ErrorIs(t, err, deleteErr)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteProxy_Success(t *testing.T) {
|
||||
repo := &proxyRepoStub{}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
err := svc.DeleteProxy(context.Background(), 7)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{7}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteProxy_Idempotent(t *testing.T) {
|
||||
repo := &proxyRepoStub{}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
err := svc.DeleteProxy(context.Background(), 404)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{404}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteProxy_InUse(t *testing.T) {
|
||||
repo := &proxyRepoStub{accountCount: 2}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
err := svc.DeleteProxy(context.Background(), 77)
|
||||
require.ErrorIs(t, err, ErrProxyInUse)
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteProxy_Error(t *testing.T) {
|
||||
deleteErr := errors.New("delete failed")
|
||||
repo := &proxyRepoStub{deleteErr: deleteErr}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
err := svc.DeleteProxy(context.Background(), 33)
|
||||
require.ErrorIs(t, err, deleteErr)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteRedeemCode_Success(t *testing.T) {
|
||||
repo := &redeemRepoStub{}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
err := svc.DeleteRedeemCode(context.Background(), 10)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{10}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteRedeemCode_Idempotent(t *testing.T) {
|
||||
repo := &redeemRepoStub{}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
err := svc.DeleteRedeemCode(context.Background(), 999)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{999}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_DeleteRedeemCode_Error(t *testing.T) {
|
||||
deleteErr := errors.New("delete failed")
|
||||
repo := &redeemRepoStub{deleteErrByID: map[int64]error{1: deleteErr}}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
err := svc.DeleteRedeemCode(context.Background(), 1)
|
||||
require.ErrorIs(t, err, deleteErr)
|
||||
require.Equal(t, []int64{1}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_BatchDeleteRedeemCodes_Success(t *testing.T) {
|
||||
repo := &redeemRepoStub{}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
deleted, err := svc.BatchDeleteRedeemCodes(context.Background(), []int64{1, 2, 3})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3), deleted)
|
||||
require.Equal(t, []int64{1, 2, 3}, repo.deletedIDs)
|
||||
}
|
||||
|
||||
func TestAdminService_BatchDeleteRedeemCodes_PartialFailures(t *testing.T) {
|
||||
repo := &redeemRepoStub{
|
||||
deleteErrByID: map[int64]error{
|
||||
2: errors.New("db error"),
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
deleted, err := svc.BatchDeleteRedeemCodes(context.Background(), []int64{1, 2, 3})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), deleted)
|
||||
require.Equal(t, []int64{1, 2, 3}, repo.deletedIDs)
|
||||
}
|
||||
@@ -0,0 +1,176 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// userGroupRateRepoStubForGroupRate implements UserGroupRateRepository for group rate tests.
|
||||
type userGroupRateRepoStubForGroupRate struct {
|
||||
getByGroupIDData map[int64][]UserGroupRateEntry
|
||||
getByGroupIDErr error
|
||||
|
||||
deletedGroupIDs []int64
|
||||
deleteByGroupErr error
|
||||
|
||||
syncedGroupID int64
|
||||
syncedEntries []GroupRateMultiplierInput
|
||||
syncGroupErr error
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetByUserID(_ context.Context, _ int64) (map[int64]float64, error) {
|
||||
panic("unexpected GetByUserID call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetByUserAndGroup(_ context.Context, _, _ int64) (*float64, error) {
|
||||
panic("unexpected GetByUserAndGroup call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) GetByGroupID(_ context.Context, groupID int64) ([]UserGroupRateEntry, error) {
|
||||
if s.getByGroupIDErr != nil {
|
||||
return nil, s.getByGroupIDErr
|
||||
}
|
||||
return s.getByGroupIDData[groupID], nil
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) SyncUserGroupRates(_ context.Context, _ int64, _ map[int64]*float64) error {
|
||||
panic("unexpected SyncUserGroupRates call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) SyncGroupRateMultipliers(_ context.Context, groupID int64, entries []GroupRateMultiplierInput) error {
|
||||
s.syncedGroupID = groupID
|
||||
s.syncedEntries = entries
|
||||
return s.syncGroupErr
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) DeleteByGroupID(_ context.Context, groupID int64) error {
|
||||
s.deletedGroupIDs = append(s.deletedGroupIDs, groupID)
|
||||
return s.deleteByGroupErr
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForGroupRate) DeleteByUserID(_ context.Context, _ int64) error {
|
||||
panic("unexpected DeleteByUserID call")
|
||||
}
|
||||
|
||||
func TestAdminService_GetGroupRateMultipliers(t *testing.T) {
|
||||
t.Run("returns entries for group", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
getByGroupIDData: map[int64][]UserGroupRateEntry{
|
||||
10: {
|
||||
{UserID: 1, UserName: "alice", UserEmail: "alice@test.com", RateMultiplier: 1.5},
|
||||
{UserID: 2, UserName: "bob", UserEmail: "bob@test.com", RateMultiplier: 0.8},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, entries, 2)
|
||||
require.Equal(t, int64(1), entries[0].UserID)
|
||||
require.Equal(t, "alice", entries[0].UserName)
|
||||
require.Equal(t, 1.5, entries[0].RateMultiplier)
|
||||
require.Equal(t, int64(2), entries[1].UserID)
|
||||
require.Equal(t, 0.8, entries[1].RateMultiplier)
|
||||
})
|
||||
|
||||
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||
|
||||
entries, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, entries)
|
||||
})
|
||||
|
||||
t.Run("returns empty slice for group with no entries", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
getByGroupIDData: map[int64][]UserGroupRateEntry{},
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
entries, err := svc.GetGroupRateMultipliers(context.Background(), 99)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, entries)
|
||||
})
|
||||
|
||||
t.Run("propagates repo error", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
getByGroupIDErr: errors.New("db error"),
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
_, err := svc.GetGroupRateMultipliers(context.Background(), 10)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "db error")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ClearGroupRateMultipliers(t *testing.T) {
|
||||
t.Run("deletes by group ID", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{42}, repo.deletedGroupIDs)
|
||||
})
|
||||
|
||||
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||
|
||||
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("propagates repo error", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
deleteByGroupErr: errors.New("delete failed"),
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
err := svc.ClearGroupRateMultipliers(context.Background(), 42)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "delete failed")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_BatchSetGroupRateMultipliers(t *testing.T) {
|
||||
t.Run("syncs entries to repo", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
entries := []GroupRateMultiplierInput{
|
||||
{UserID: 1, RateMultiplier: 1.5},
|
||||
{UserID: 2, RateMultiplier: 0.8},
|
||||
}
|
||||
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, entries)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), repo.syncedGroupID)
|
||||
require.Equal(t, entries, repo.syncedEntries)
|
||||
})
|
||||
|
||||
t.Run("returns nil when repo is nil", func(t *testing.T) {
|
||||
svc := &adminServiceImpl{userGroupRateRepo: nil}
|
||||
|
||||
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, nil)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("propagates repo error", func(t *testing.T) {
|
||||
repo := &userGroupRateRepoStubForGroupRate{
|
||||
syncGroupErr: errors.New("sync failed"),
|
||||
}
|
||||
svc := &adminServiceImpl{userGroupRateRepo: repo}
|
||||
|
||||
err := svc.BatchSetGroupRateMultipliers(context.Background(), 10, []GroupRateMultiplierInput{
|
||||
{UserID: 1, RateMultiplier: 1.0},
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "sync failed")
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,787 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// groupRepoStubForAdmin 用于测试 AdminService 的 GroupRepository Stub
|
||||
type groupRepoStubForAdmin struct {
|
||||
created *Group // 记录 Create 调用的参数
|
||||
updated *Group // 记录 Update 调用的参数
|
||||
getByID *Group // GetByID 返回值
|
||||
getErr error // GetByID 返回的错误
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersPlatform string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersIsExclusive *bool
|
||||
listWithFiltersGroups []Group
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Create(_ context.Context, g *Group) error {
|
||||
s.created = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Update(_ context.Context, g *Group) error {
|
||||
s.updated = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetByID(_ context.Context, _ int64) (*Group, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetByIDLite(_ context.Context, _ int64) (*Group, error) {
|
||||
if s.getErr != nil {
|
||||
return nil, s.getErr
|
||||
}
|
||||
return s.getByID, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, status, search string, isExclusive *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
s.listWithFiltersIsExclusive = isExclusive
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersGroups)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersGroups, result, nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForAdmin) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestAdminService_CreateGroup_WithImagePricing 测试创建分组时 ImagePrice 字段正确传递
|
||||
func TestAdminService_CreateGroup_WithImagePricing(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
price1K := 0.10
|
||||
price2K := 0.15
|
||||
price4K := 0.30
|
||||
|
||||
input := &CreateGroupInput{
|
||||
Name: "test-group",
|
||||
Description: "Test group",
|
||||
Platform: PlatformAntigravity,
|
||||
RateMultiplier: 1.0,
|
||||
ImagePrice1K: &price1K,
|
||||
ImagePrice2K: &price2K,
|
||||
ImagePrice4K: &price4K,
|
||||
}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证 repo 收到了正确的字段
|
||||
require.NotNil(t, repo.created)
|
||||
require.NotNil(t, repo.created.ImagePrice1K)
|
||||
require.NotNil(t, repo.created.ImagePrice2K)
|
||||
require.NotNil(t, repo.created.ImagePrice4K)
|
||||
require.InDelta(t, 0.10, *repo.created.ImagePrice1K, 0.0001)
|
||||
require.InDelta(t, 0.15, *repo.created.ImagePrice2K, 0.0001)
|
||||
require.InDelta(t, 0.30, *repo.created.ImagePrice4K, 0.0001)
|
||||
}
|
||||
|
||||
// TestAdminService_CreateGroup_NilImagePricing 测试 ImagePrice 为 nil 时正常创建
|
||||
func TestAdminService_CreateGroup_NilImagePricing(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
input := &CreateGroupInput{
|
||||
Name: "test-group",
|
||||
Description: "Test group",
|
||||
Platform: PlatformAntigravity,
|
||||
RateMultiplier: 1.0,
|
||||
// ImagePrice 字段全部为 nil
|
||||
}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证 ImagePrice 字段为 nil
|
||||
require.NotNil(t, repo.created)
|
||||
require.Nil(t, repo.created.ImagePrice1K)
|
||||
require.Nil(t, repo.created.ImagePrice2K)
|
||||
require.Nil(t, repo.created.ImagePrice4K)
|
||||
}
|
||||
|
||||
// TestAdminService_UpdateGroup_WithImagePricing 测试更新分组时 ImagePrice 字段正确更新
|
||||
func TestAdminService_UpdateGroup_WithImagePricing(t *testing.T) {
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
price1K := 0.12
|
||||
price2K := 0.18
|
||||
price4K := 0.36
|
||||
|
||||
input := &UpdateGroupInput{
|
||||
ImagePrice1K: &price1K,
|
||||
ImagePrice2K: &price2K,
|
||||
ImagePrice4K: &price4K,
|
||||
}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证 repo 收到了更新后的字段
|
||||
require.NotNil(t, repo.updated)
|
||||
require.NotNil(t, repo.updated.ImagePrice1K)
|
||||
require.NotNil(t, repo.updated.ImagePrice2K)
|
||||
require.NotNil(t, repo.updated.ImagePrice4K)
|
||||
require.InDelta(t, 0.12, *repo.updated.ImagePrice1K, 0.0001)
|
||||
require.InDelta(t, 0.18, *repo.updated.ImagePrice2K, 0.0001)
|
||||
require.InDelta(t, 0.36, *repo.updated.ImagePrice4K, 0.0001)
|
||||
}
|
||||
|
||||
// TestAdminService_UpdateGroup_PartialImagePricing 测试仅更新部分 ImagePrice 字段
|
||||
func TestAdminService_UpdateGroup_PartialImagePricing(t *testing.T) {
|
||||
oldPrice2K := 0.15
|
||||
existingGroup := &Group{
|
||||
ID: 1,
|
||||
Name: "existing-group",
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
ImagePrice2K: &oldPrice2K, // 已有 2K 价格
|
||||
}
|
||||
repo := &groupRepoStubForAdmin{getByID: existingGroup}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
// 只更新 1K 价格
|
||||
price1K := 0.10
|
||||
input := &UpdateGroupInput{
|
||||
ImagePrice1K: &price1K,
|
||||
// ImagePrice2K 和 ImagePrice4K 为 nil,不更新
|
||||
}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), 1, input)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
|
||||
// 验证:1K 被更新,2K 保持原值,4K 仍为 nil
|
||||
require.NotNil(t, repo.updated)
|
||||
require.NotNil(t, repo.updated.ImagePrice1K)
|
||||
require.InDelta(t, 0.10, *repo.updated.ImagePrice1K, 0.0001)
|
||||
require.NotNil(t, repo.updated.ImagePrice2K)
|
||||
require.InDelta(t, 0.15, *repo.updated.ImagePrice2K, 0.0001) // 原值保持
|
||||
require.Nil(t, repo.updated.ImagePrice4K)
|
||||
}
|
||||
|
||||
func TestAdminService_ListGroups_WithSearch(t *testing.T) {
|
||||
// 测试:
|
||||
// 1. search 参数正常传递到 repository 层
|
||||
// 2. search 为空字符串时的行为
|
||||
// 3. search 与其他过滤条件组合使用
|
||||
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{{ID: 1, Name: "alpha"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 1},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 1, 20, "", "", "alpha", nil)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), total)
|
||||
require.Equal(t, []Group{{ID: 1, Name: "alpha"}}, groups)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "alpha", repo.listWithFiltersSearch)
|
||||
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
|
||||
t.Run("search 为空字符串时传递空字符串", func(t *testing.T) {
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 0},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 2, 10, "", "", "", nil)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, groups)
|
||||
require.Equal(t, int64(0), total)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "", repo.listWithFiltersSearch)
|
||||
require.Nil(t, repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
|
||||
t.Run("search 与其他过滤条件组合使用", func(t *testing.T) {
|
||||
isExclusive := true
|
||||
repo := &groupRepoStubForAdmin{
|
||||
listWithFiltersGroups: []Group{{ID: 2, Name: "beta"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 42},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
groups, total, err := svc.ListGroups(context.Background(), 3, 50, PlatformAntigravity, StatusActive, "beta", &isExclusive)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(42), total)
|
||||
require.Equal(t, []Group{{ID: 2, Name: "beta"}}, groups)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
|
||||
require.Equal(t, PlatformAntigravity, repo.listWithFiltersPlatform)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "beta", repo.listWithFiltersSearch)
|
||||
require.NotNil(t, repo.listWithFiltersIsExclusive)
|
||||
require.True(t, *repo.listWithFiltersIsExclusive)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ValidateFallbackGroup_DetectsCycle(t *testing.T) {
|
||||
groupID := int64(1)
|
||||
fallbackID := int64(2)
|
||||
repo := &groupRepoStubForFallbackCycle{
|
||||
groups: map[int64]*Group{
|
||||
groupID: {
|
||||
ID: groupID,
|
||||
FallbackGroupID: &fallbackID,
|
||||
},
|
||||
fallbackID: {
|
||||
ID: fallbackID,
|
||||
FallbackGroupID: &groupID,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
err := svc.validateFallbackGroup(context.Background(), groupID, fallbackID)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group cycle")
|
||||
}
|
||||
|
||||
type groupRepoStubForFallbackCycle struct {
|
||||
groups map[int64]*Group
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Create(_ context.Context, _ *Group) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Update(_ context.Context, _ *Group) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
return s.GetByIDLite(ctx, id)
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||
if g, ok := s.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForFallbackCycle) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type groupRepoStubForInvalidRequestFallback struct {
|
||||
groups map[int64]*Group
|
||||
created *Group
|
||||
updated *Group
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) Create(_ context.Context, g *Group) error {
|
||||
s.created = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) Update(_ context.Context, g *Group) error {
|
||||
s.updated = g
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetByID(ctx context.Context, id int64) (*Group, error) {
|
||||
return s.GetByIDLite(ctx, id)
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetByIDLite(_ context.Context, id int64) (*Group, error) {
|
||||
if g, ok := s.groups[id]; ok {
|
||||
return g, nil
|
||||
}
|
||||
return nil, ErrGroupNotFound
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) Delete(_ context.Context, _ int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) DeleteCascade(_ context.Context, _ int64) ([]int64, error) {
|
||||
panic("unexpected DeleteCascade call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) List(_ context.Context, _ pagination.PaginationParams) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected List call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ListWithFilters(_ context.Context, _ pagination.PaginationParams, _, _, _ string, _ *bool) ([]Group, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListWithFilters call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ListActive(_ context.Context) ([]Group, error) {
|
||||
panic("unexpected ListActive call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ListActiveByPlatform(_ context.Context, _ string) ([]Group, error) {
|
||||
panic("unexpected ListActiveByPlatform call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) ExistsByName(_ context.Context, _ string) (bool, error) {
|
||||
panic("unexpected ExistsByName call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetAccountCount(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected GetAccountCount call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) DeleteAccountGroupsByGroupID(_ context.Context, _ int64) (int64, error) {
|
||||
panic("unexpected DeleteAccountGroupsByGroupID call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) GetAccountIDsByGroupIDs(_ context.Context, _ []int64) ([]int64, error) {
|
||||
panic("unexpected GetAccountIDsByGroupIDs call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) BindAccountsToGroup(_ context.Context, _ int64, _ []int64) error {
|
||||
panic("unexpected BindAccountsToGroup call")
|
||||
}
|
||||
|
||||
func (s *groupRepoStubForInvalidRequestFallback) UpdateSortOrders(_ context.Context, _ []GroupSortOrderUpdate) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsUnsupportedPlatform(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformOpenAI,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsSubscription(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeSubscription,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
fallback *Group
|
||||
wantMessage string
|
||||
}{
|
||||
{
|
||||
name: "openai_target",
|
||||
fallback: &Group{ID: 10, Platform: PlatformOpenAI, SubscriptionType: SubscriptionTypeStandard},
|
||||
wantMessage: "fallback group must be anthropic platform",
|
||||
},
|
||||
{
|
||||
name: "antigravity_target",
|
||||
fallback: &Group{ID: 10, Platform: PlatformAntigravity, SubscriptionType: SubscriptionTypeStandard},
|
||||
wantMessage: "fallback group must be anthropic platform",
|
||||
},
|
||||
{
|
||||
name: "subscription_group",
|
||||
fallback: &Group{ID: 10, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
|
||||
wantMessage: "fallback group cannot be subscription type",
|
||||
},
|
||||
{
|
||||
name: "nested_fallback",
|
||||
fallback: &Group{
|
||||
ID: 10,
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: func() *int64 { v := int64(99); return &v }(),
|
||||
},
|
||||
wantMessage: "fallback group cannot have invalid request fallback configured",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
fallbackID := tc.fallback.ID
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: tc.fallback,
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tc.wantMessage)
|
||||
require.Nil(t, repo.created)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackNotFound(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group not found")
|
||||
require.Nil(t, repo.created)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAntigravity,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.created)
|
||||
require.Equal(t, fallbackID, *repo.created.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_CreateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
|
||||
zero := int64(0)
|
||||
repo := &groupRepoStubForInvalidRequestFallback{}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.CreateGroup(context.Background(), &CreateGroupInput{
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
FallbackGroupIDOnInvalidRequest: &zero,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.created)
|
||||
require.Nil(t, repo.created.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackPlatformMismatch(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
Platform: PlatformOpenAI,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "invalid request fallback only supported for anthropic or antigravity groups")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackSubscriptionMismatch(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
SubscriptionType: SubscriptionTypeSubscription,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "subscription groups cannot set invalid request fallback")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackClearsOnZero(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
clear := int64(0)
|
||||
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
Platform: PlatformOpenAI,
|
||||
FallbackGroupIDOnInvalidRequest: &clear,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Nil(t, repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackRejectsFallbackGroup(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeSubscription},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
_, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "fallback group cannot be subscription type")
|
||||
require.Nil(t, repo.updated)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackSetSuccess(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAnthropic,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateGroup_InvalidRequestFallbackAllowsAntigravity(t *testing.T) {
|
||||
fallbackID := int64(10)
|
||||
existing := &Group{
|
||||
ID: 1,
|
||||
Name: "g1",
|
||||
Platform: PlatformAntigravity,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
Status: StatusActive,
|
||||
}
|
||||
repo := &groupRepoStubForInvalidRequestFallback{
|
||||
groups: map[int64]*Group{
|
||||
existing.ID: existing,
|
||||
fallbackID: {ID: fallbackID, Platform: PlatformAnthropic, SubscriptionType: SubscriptionTypeStandard},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{groupRepo: repo}
|
||||
|
||||
group, err := svc.UpdateGroup(context.Background(), existing.ID, &UpdateGroupInput{
|
||||
FallbackGroupIDOnInvalidRequest: &fallbackID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, group)
|
||||
require.NotNil(t, repo.updated)
|
||||
require.Equal(t, fallbackID, *repo.updated.FallbackGroupIDOnInvalidRequest)
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type userRepoStubForListUsers struct {
|
||||
userRepoStub
|
||||
users []User
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *userRepoStubForListUsers) ListWithFilters(_ context.Context, params pagination.PaginationParams, _ UserListFilters) ([]User, *pagination.PaginationResult, error) {
|
||||
if s.err != nil {
|
||||
return nil, nil, s.err
|
||||
}
|
||||
out := make([]User, len(s.users))
|
||||
copy(out, s.users)
|
||||
return out, &pagination.PaginationResult{
|
||||
Total: int64(len(out)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type userGroupRateRepoStubForListUsers struct {
|
||||
batchCalls int
|
||||
singleCall []int64
|
||||
|
||||
batchErr error
|
||||
batchData map[int64]map[int64]float64
|
||||
|
||||
singleErr map[int64]error
|
||||
singleData map[int64]map[int64]float64
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) GetByUserIDs(_ context.Context, _ []int64) (map[int64]map[int64]float64, error) {
|
||||
s.batchCalls++
|
||||
if s.batchErr != nil {
|
||||
return nil, s.batchErr
|
||||
}
|
||||
return s.batchData, nil
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) GetByUserID(_ context.Context, userID int64) (map[int64]float64, error) {
|
||||
s.singleCall = append(s.singleCall, userID)
|
||||
if err, ok := s.singleErr[userID]; ok {
|
||||
return nil, err
|
||||
}
|
||||
if rates, ok := s.singleData[userID]; ok {
|
||||
return rates, nil
|
||||
}
|
||||
return map[int64]float64{}, nil
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) GetByUserAndGroup(_ context.Context, userID, groupID int64) (*float64, error) {
|
||||
panic("unexpected GetByUserAndGroup call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) SyncUserGroupRates(_ context.Context, userID int64, rates map[int64]*float64) error {
|
||||
panic("unexpected SyncUserGroupRates call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) GetByGroupID(_ context.Context, _ int64) ([]UserGroupRateEntry, error) {
|
||||
panic("unexpected GetByGroupID call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) SyncGroupRateMultipliers(_ context.Context, _ int64, _ []GroupRateMultiplierInput) error {
|
||||
panic("unexpected SyncGroupRateMultipliers call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) DeleteByGroupID(_ context.Context, _ int64) error {
|
||||
panic("unexpected DeleteByGroupID call")
|
||||
}
|
||||
|
||||
func (s *userGroupRateRepoStubForListUsers) DeleteByUserID(_ context.Context, userID int64) error {
|
||||
panic("unexpected DeleteByUserID call")
|
||||
}
|
||||
|
||||
func TestAdminService_ListUsers_BatchRateFallbackToSingle(t *testing.T) {
|
||||
userRepo := &userRepoStubForListUsers{
|
||||
users: []User{
|
||||
{ID: 101, Username: "u1"},
|
||||
{ID: 202, Username: "u2"},
|
||||
},
|
||||
}
|
||||
rateRepo := &userGroupRateRepoStubForListUsers{
|
||||
batchErr: errors.New("batch unavailable"),
|
||||
singleData: map[int64]map[int64]float64{
|
||||
101: {11: 1.1},
|
||||
202: {22: 2.2},
|
||||
},
|
||||
}
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: userRepo,
|
||||
userGroupRateRepo: rateRepo,
|
||||
}
|
||||
|
||||
users, total, err := svc.ListUsers(context.Background(), 1, 20, UserListFilters{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), total)
|
||||
require.Len(t, users, 2)
|
||||
require.Equal(t, 1, rateRepo.batchCalls)
|
||||
require.ElementsMatch(t, []int64{101, 202}, rateRepo.singleCall)
|
||||
require.Equal(t, 1.1, users[0].GroupRates[11])
|
||||
require.Equal(t, 2.2, users[1].GroupRates[22])
|
||||
}
|
||||
@@ -0,0 +1,123 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type updateAccountOveragesRepoStub struct {
|
||||
mockAccountRepoForGemini
|
||||
account *Account
|
||||
updateCalls int
|
||||
}
|
||||
|
||||
func (r *updateAccountOveragesRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
func (r *updateAccountOveragesRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
r.updateCalls++
|
||||
r.account = account
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestUpdateAccount_DisableOveragesClearsAICreditsKey(t *testing.T) {
|
||||
accountID := int64(101)
|
||||
repo := &updateAccountOveragesRepoStub{
|
||||
account: &Account{
|
||||
ID: accountID,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
"mixed_scheduling": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||
Extra: map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.False(t, updated.IsOveragesEnabled())
|
||||
|
||||
// 关闭 overages 后,AICredits key 应被清除
|
||||
rawLimits, ok := repo.account.Extra[modelRateLimitsKey].(map[string]any)
|
||||
if ok {
|
||||
_, exists := rawLimits[creditsExhaustedKey]
|
||||
require.False(t, exists, "关闭 overages 时应清除 AICredits 限流 key")
|
||||
}
|
||||
// 普通模型限流应保留
|
||||
require.True(t, ok)
|
||||
_, exists := rawLimits["claude-sonnet-4-5"]
|
||||
require.True(t, exists, "普通模型限流应保留")
|
||||
}
|
||||
|
||||
func TestUpdateAccount_EnableOveragesClearsModelRateLimitsBeforePersist(t *testing.T) {
|
||||
accountID := int64(102)
|
||||
repo := &updateAccountOveragesRepoStub{
|
||||
account: &Account{
|
||||
ID: accountID,
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeOAuth,
|
||||
Status: StatusActive,
|
||||
Extra: map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
updated, err := svc.UpdateAccount(context.Background(), accountID, &UpdateAccountInput{
|
||||
Extra: map[string]any{
|
||||
"mixed_scheduling": true,
|
||||
"allow_overages": true,
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated)
|
||||
require.Equal(t, 1, repo.updateCalls)
|
||||
require.True(t, updated.IsOveragesEnabled())
|
||||
|
||||
_, exists := repo.account.Extra[modelRateLimitsKey]
|
||||
require.False(t, exists, "开启 overages 时应在持久化前清掉旧模型限流")
|
||||
}
|
||||
@@ -0,0 +1,95 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestFinalizeProxyQualityResult_ScoreAndGrade(t *testing.T) {
|
||||
result := &ProxyQualityCheckResult{
|
||||
PassedCount: 2,
|
||||
WarnCount: 1,
|
||||
FailedCount: 1,
|
||||
ChallengeCount: 1,
|
||||
}
|
||||
|
||||
finalizeProxyQualityResult(result)
|
||||
|
||||
require.Equal(t, 38, result.Score)
|
||||
require.Equal(t, "F", result.Grade)
|
||||
require.Contains(t, result.Summary, "通过 2 项")
|
||||
require.Contains(t, result.Summary, "告警 1 项")
|
||||
require.Contains(t, result.Summary, "失败 1 项")
|
||||
require.Contains(t, result.Summary, "挑战 1 项")
|
||||
}
|
||||
|
||||
func TestRunProxyQualityTarget_SoraChallenge(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/html")
|
||||
w.Header().Set("cf-ray", "test-ray-123")
|
||||
w.WriteHeader(http.StatusForbidden)
|
||||
_, _ = w.Write([]byte("<!DOCTYPE html><title>Just a moment...</title><script>window._cf_chl_opt={};</script>"))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
target := proxyQualityTarget{
|
||||
Target: "sora",
|
||||
URL: server.URL,
|
||||
Method: http.MethodGet,
|
||||
AllowedStatuses: map[int]struct{}{
|
||||
http.StatusUnauthorized: {},
|
||||
},
|
||||
}
|
||||
|
||||
item := runProxyQualityTarget(context.Background(), server.Client(), target)
|
||||
require.Equal(t, "challenge", item.Status)
|
||||
require.Equal(t, http.StatusForbidden, item.HTTPStatus)
|
||||
require.Equal(t, "test-ray-123", item.CFRay)
|
||||
}
|
||||
|
||||
func TestRunProxyQualityTarget_AllowedStatusPass(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte(`{"models":[]}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
target := proxyQualityTarget{
|
||||
Target: "gemini",
|
||||
URL: server.URL,
|
||||
Method: http.MethodGet,
|
||||
AllowedStatuses: map[int]struct{}{
|
||||
http.StatusOK: {},
|
||||
},
|
||||
}
|
||||
|
||||
item := runProxyQualityTarget(context.Background(), server.Client(), target)
|
||||
require.Equal(t, "pass", item.Status)
|
||||
require.Equal(t, http.StatusOK, item.HTTPStatus)
|
||||
}
|
||||
|
||||
func TestRunProxyQualityTarget_AllowedStatusWarnForUnauthorized(t *testing.T) {
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_, _ = w.Write([]byte(`{"error":"unauthorized"}`))
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
target := proxyQualityTarget{
|
||||
Target: "openai",
|
||||
URL: server.URL,
|
||||
Method: http.MethodGet,
|
||||
AllowedStatuses: map[int]struct{}{
|
||||
http.StatusUnauthorized: {},
|
||||
},
|
||||
}
|
||||
|
||||
item := runProxyQualityTarget(context.Background(), server.Client(), target)
|
||||
require.Equal(t, "warn", item.Status)
|
||||
require.Equal(t, http.StatusUnauthorized, item.HTTPStatus)
|
||||
require.Contains(t, item.Message, "目标可达")
|
||||
}
|
||||
@@ -0,0 +1,246 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type accountRepoStubForAdminList struct {
|
||||
accountRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersPlatform string
|
||||
listWithFiltersType string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersAccounts []Account
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *accountRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, platform, accountType, status, search string, groupID int64) ([]Account, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersPlatform = platform
|
||||
s.listWithFiltersType = accountType
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersAccounts)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersAccounts, result, nil
|
||||
}
|
||||
|
||||
type proxyRepoStubForAdminList struct {
|
||||
proxyRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersProtocol string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersProxies []Proxy
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
|
||||
listWithFiltersAndAccountCountCalls int
|
||||
listWithFiltersAndAccountCountParams pagination.PaginationParams
|
||||
listWithFiltersAndAccountCountProtocol string
|
||||
listWithFiltersAndAccountCountStatus string
|
||||
listWithFiltersAndAccountCountSearch string
|
||||
listWithFiltersAndAccountCountProxies []ProxyWithAccountCount
|
||||
listWithFiltersAndAccountCountResult *pagination.PaginationResult
|
||||
listWithFiltersAndAccountCountErr error
|
||||
}
|
||||
|
||||
func (s *proxyRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]Proxy, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersProtocol = protocol
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersProxies)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersProxies, result, nil
|
||||
}
|
||||
|
||||
func (s *proxyRepoStubForAdminList) ListWithFiltersAndAccountCount(_ context.Context, params pagination.PaginationParams, protocol, status, search string) ([]ProxyWithAccountCount, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersAndAccountCountCalls++
|
||||
s.listWithFiltersAndAccountCountParams = params
|
||||
s.listWithFiltersAndAccountCountProtocol = protocol
|
||||
s.listWithFiltersAndAccountCountStatus = status
|
||||
s.listWithFiltersAndAccountCountSearch = search
|
||||
|
||||
if s.listWithFiltersAndAccountCountErr != nil {
|
||||
return nil, nil, s.listWithFiltersAndAccountCountErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersAndAccountCountResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersAndAccountCountProxies)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersAndAccountCountProxies, result, nil
|
||||
}
|
||||
|
||||
type redeemRepoStubForAdminList struct {
|
||||
redeemRepoStub
|
||||
|
||||
listWithFiltersCalls int
|
||||
listWithFiltersParams pagination.PaginationParams
|
||||
listWithFiltersType string
|
||||
listWithFiltersStatus string
|
||||
listWithFiltersSearch string
|
||||
listWithFiltersCodes []RedeemCode
|
||||
listWithFiltersResult *pagination.PaginationResult
|
||||
listWithFiltersErr error
|
||||
}
|
||||
|
||||
func (s *redeemRepoStubForAdminList) ListWithFilters(_ context.Context, params pagination.PaginationParams, codeType, status, search string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
s.listWithFiltersCalls++
|
||||
s.listWithFiltersParams = params
|
||||
s.listWithFiltersType = codeType
|
||||
s.listWithFiltersStatus = status
|
||||
s.listWithFiltersSearch = search
|
||||
|
||||
if s.listWithFiltersErr != nil {
|
||||
return nil, nil, s.listWithFiltersErr
|
||||
}
|
||||
|
||||
result := s.listWithFiltersResult
|
||||
if result == nil {
|
||||
result = &pagination.PaginationResult{
|
||||
Total: int64(len(s.listWithFiltersCodes)),
|
||||
Page: params.Page,
|
||||
PageSize: params.PageSize,
|
||||
}
|
||||
}
|
||||
|
||||
return s.listWithFiltersCodes, result, nil
|
||||
}
|
||||
|
||||
func (s *redeemRepoStubForAdminList) ListByUserPaginated(_ context.Context, userID int64, params pagination.PaginationParams, codeType string) ([]RedeemCode, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserPaginated call")
|
||||
}
|
||||
|
||||
func (s *redeemRepoStubForAdminList) SumPositiveBalanceByUser(_ context.Context, userID int64) (float64, error) {
|
||||
panic("unexpected SumPositiveBalanceByUser call")
|
||||
}
|
||||
|
||||
func TestAdminService_ListAccounts_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &accountRepoStubForAdminList{
|
||||
listWithFiltersAccounts: []Account{{ID: 1, Name: "acc"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 10},
|
||||
}
|
||||
svc := &adminServiceImpl{accountRepo: repo}
|
||||
|
||||
accounts, total, err := svc.ListAccounts(context.Background(), 1, 20, PlatformGemini, AccountTypeOAuth, StatusActive, "acc", 0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(10), total)
|
||||
require.Equal(t, []Account{{ID: 1, Name: "acc"}}, accounts)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, PlatformGemini, repo.listWithFiltersPlatform)
|
||||
require.Equal(t, AccountTypeOAuth, repo.listWithFiltersType)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "acc", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListProxies_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &proxyRepoStubForAdminList{
|
||||
listWithFiltersProxies: []Proxy{{ID: 2, Name: "p1"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 7},
|
||||
}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
proxies, total, err := svc.ListProxies(context.Background(), 3, 50, "http", StatusActive, "p1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(7), total)
|
||||
require.Equal(t, []Proxy{{ID: 2, Name: "p1"}}, proxies)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 3, PageSize: 50}, repo.listWithFiltersParams)
|
||||
require.Equal(t, "http", repo.listWithFiltersProtocol)
|
||||
require.Equal(t, StatusActive, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "p1", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListProxiesWithAccountCount_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &proxyRepoStubForAdminList{
|
||||
listWithFiltersAndAccountCountProxies: []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}},
|
||||
listWithFiltersAndAccountCountResult: &pagination.PaginationResult{Total: 9},
|
||||
}
|
||||
svc := &adminServiceImpl{proxyRepo: repo}
|
||||
|
||||
proxies, total, err := svc.ListProxiesWithAccountCount(context.Background(), 2, 10, "socks5", StatusDisabled, "p2")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(9), total)
|
||||
require.Equal(t, []ProxyWithAccountCount{{Proxy: Proxy{ID: 3, Name: "p2"}, AccountCount: 5}}, proxies)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersAndAccountCountCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 2, PageSize: 10}, repo.listWithFiltersAndAccountCountParams)
|
||||
require.Equal(t, "socks5", repo.listWithFiltersAndAccountCountProtocol)
|
||||
require.Equal(t, StatusDisabled, repo.listWithFiltersAndAccountCountStatus)
|
||||
require.Equal(t, "p2", repo.listWithFiltersAndAccountCountSearch)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdminService_ListRedeemCodes_WithSearch(t *testing.T) {
|
||||
t.Run("search 参数正常传递到 repository 层", func(t *testing.T) {
|
||||
repo := &redeemRepoStubForAdminList{
|
||||
listWithFiltersCodes: []RedeemCode{{ID: 4, Code: "ABC"}},
|
||||
listWithFiltersResult: &pagination.PaginationResult{Total: 3},
|
||||
}
|
||||
svc := &adminServiceImpl{redeemCodeRepo: repo}
|
||||
|
||||
codes, total, err := svc.ListRedeemCodes(context.Background(), 1, 20, RedeemTypeBalance, StatusUnused, "ABC")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(3), total)
|
||||
require.Equal(t, []RedeemCode{{ID: 4, Code: "ABC"}}, codes)
|
||||
|
||||
require.Equal(t, 1, repo.listWithFiltersCalls)
|
||||
require.Equal(t, pagination.PaginationParams{Page: 1, PageSize: 20}, repo.listWithFiltersParams)
|
||||
require.Equal(t, RedeemTypeBalance, repo.listWithFiltersType)
|
||||
require.Equal(t, StatusUnused, repo.listWithFiltersStatus)
|
||||
require.Equal(t, "ABC", repo.listWithFiltersSearch)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type balanceUserRepoStub struct {
|
||||
*userRepoStub
|
||||
updateErr error
|
||||
updated []*User
|
||||
}
|
||||
|
||||
func (s *balanceUserRepoStub) Update(ctx context.Context, user *User) error {
|
||||
if s.updateErr != nil {
|
||||
return s.updateErr
|
||||
}
|
||||
if user == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *user
|
||||
s.updated = append(s.updated, &clone)
|
||||
if s.userRepoStub != nil {
|
||||
s.userRepoStub.user = &clone
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type balanceRedeemRepoStub struct {
|
||||
*redeemRepoStub
|
||||
created []*RedeemCode
|
||||
}
|
||||
|
||||
func (s *balanceRedeemRepoStub) Create(ctx context.Context, code *RedeemCode) error {
|
||||
if code == nil {
|
||||
return nil
|
||||
}
|
||||
clone := *code
|
||||
s.created = append(s.created, &clone)
|
||||
return nil
|
||||
}
|
||||
|
||||
type authCacheInvalidatorStub struct {
|
||||
userIDs []int64
|
||||
groupIDs []int64
|
||||
keys []string
|
||||
}
|
||||
|
||||
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByKey(ctx context.Context, key string) {
|
||||
s.keys = append(s.keys, key)
|
||||
}
|
||||
|
||||
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
|
||||
s.userIDs = append(s.userIDs, userID)
|
||||
}
|
||||
|
||||
func (s *authCacheInvalidatorStub) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
|
||||
s.groupIDs = append(s.groupIDs, groupID)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateUserBalance_InvalidatesAuthCache(t *testing.T) {
|
||||
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
|
||||
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
|
||||
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: repo,
|
||||
redeemCodeRepo: redeemRepo,
|
||||
authCacheInvalidator: invalidator,
|
||||
}
|
||||
|
||||
_, err := svc.UpdateUserBalance(context.Background(), 7, 5, "add", "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{7}, invalidator.userIDs)
|
||||
require.Len(t, redeemRepo.created, 1)
|
||||
}
|
||||
|
||||
func TestAdminService_UpdateUserBalance_NoChangeNoInvalidate(t *testing.T) {
|
||||
baseRepo := &userRepoStub{user: &User{ID: 7, Balance: 10}}
|
||||
repo := &balanceUserRepoStub{userRepoStub: baseRepo}
|
||||
redeemRepo := &balanceRedeemRepoStub{redeemRepoStub: &redeemRepoStub{}}
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &adminServiceImpl{
|
||||
userRepo: repo,
|
||||
redeemCodeRepo: redeemRepo,
|
||||
authCacheInvalidator: invalidator,
|
||||
}
|
||||
|
||||
_, err := svc.UpdateUserBalance(context.Background(), 7, 10, "set", "")
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, invalidator.userIDs)
|
||||
require.Empty(t, redeemRepo.created)
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementStatusDraft = domain.AnnouncementStatusDraft
|
||||
AnnouncementStatusActive = domain.AnnouncementStatusActive
|
||||
AnnouncementStatusArchived = domain.AnnouncementStatusArchived
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementNotifyModeSilent = domain.AnnouncementNotifyModeSilent
|
||||
AnnouncementNotifyModePopup = domain.AnnouncementNotifyModePopup
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementConditionTypeSubscription = domain.AnnouncementConditionTypeSubscription
|
||||
AnnouncementConditionTypeBalance = domain.AnnouncementConditionTypeBalance
|
||||
)
|
||||
|
||||
const (
|
||||
AnnouncementOperatorIn = domain.AnnouncementOperatorIn
|
||||
AnnouncementOperatorGT = domain.AnnouncementOperatorGT
|
||||
AnnouncementOperatorGTE = domain.AnnouncementOperatorGTE
|
||||
AnnouncementOperatorLT = domain.AnnouncementOperatorLT
|
||||
AnnouncementOperatorLTE = domain.AnnouncementOperatorLTE
|
||||
AnnouncementOperatorEQ = domain.AnnouncementOperatorEQ
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAnnouncementNotFound = domain.ErrAnnouncementNotFound
|
||||
ErrAnnouncementInvalidTarget = domain.ErrAnnouncementInvalidTarget
|
||||
)
|
||||
|
||||
type AnnouncementTargeting = domain.AnnouncementTargeting
|
||||
|
||||
type AnnouncementConditionGroup = domain.AnnouncementConditionGroup
|
||||
|
||||
type AnnouncementCondition = domain.AnnouncementCondition
|
||||
|
||||
type Announcement = domain.Announcement
|
||||
|
||||
type AnnouncementListFilters struct {
|
||||
Status string
|
||||
Search string
|
||||
}
|
||||
|
||||
type AnnouncementRepository interface {
|
||||
Create(ctx context.Context, a *Announcement) error
|
||||
GetByID(ctx context.Context, id int64) (*Announcement, error)
|
||||
Update(ctx context.Context, a *Announcement) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error)
|
||||
ListActive(ctx context.Context, now time.Time) ([]Announcement, error)
|
||||
}
|
||||
|
||||
type AnnouncementReadRepository interface {
|
||||
MarkRead(ctx context.Context, announcementID, userID int64, readAt time.Time) error
|
||||
GetReadMapByUser(ctx context.Context, userID int64, announcementIDs []int64) (map[int64]time.Time, error)
|
||||
GetReadMapByUsers(ctx context.Context, announcementID int64, userIDs []int64) (map[int64]time.Time, error)
|
||||
CountByAnnouncementID(ctx context.Context, announcementID int64) (int64, error)
|
||||
}
|
||||
@@ -0,0 +1,406 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
type AnnouncementService struct {
|
||||
announcementRepo AnnouncementRepository
|
||||
readRepo AnnouncementReadRepository
|
||||
userRepo UserRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
}
|
||||
|
||||
func NewAnnouncementService(
|
||||
announcementRepo AnnouncementRepository,
|
||||
readRepo AnnouncementReadRepository,
|
||||
userRepo UserRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
) *AnnouncementService {
|
||||
return &AnnouncementService{
|
||||
announcementRepo: announcementRepo,
|
||||
readRepo: readRepo,
|
||||
userRepo: userRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
}
|
||||
}
|
||||
|
||||
type CreateAnnouncementInput struct {
|
||||
Title string
|
||||
Content string
|
||||
Status string
|
||||
NotifyMode string
|
||||
Targeting AnnouncementTargeting
|
||||
StartsAt *time.Time
|
||||
EndsAt *time.Time
|
||||
ActorID *int64 // 管理员用户ID
|
||||
}
|
||||
|
||||
type UpdateAnnouncementInput struct {
|
||||
Title *string
|
||||
Content *string
|
||||
Status *string
|
||||
NotifyMode *string
|
||||
Targeting *AnnouncementTargeting
|
||||
StartsAt **time.Time
|
||||
EndsAt **time.Time
|
||||
ActorID *int64 // 管理员用户ID
|
||||
}
|
||||
|
||||
type UserAnnouncement struct {
|
||||
Announcement Announcement
|
||||
ReadAt *time.Time
|
||||
}
|
||||
|
||||
type AnnouncementUserReadStatus struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Email string `json:"email"`
|
||||
Username string `json:"username"`
|
||||
Balance float64 `json:"balance"`
|
||||
Eligible bool `json:"eligible"`
|
||||
ReadAt *time.Time `json:"read_at,omitempty"`
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) Create(ctx context.Context, input *CreateAnnouncementInput) (*Announcement, error) {
|
||||
if input == nil {
|
||||
return nil, fmt.Errorf("create announcement: nil input")
|
||||
}
|
||||
|
||||
title := strings.TrimSpace(input.Title)
|
||||
content := strings.TrimSpace(input.Content)
|
||||
if title == "" || len(title) > 200 {
|
||||
return nil, fmt.Errorf("create announcement: invalid title")
|
||||
}
|
||||
if content == "" {
|
||||
return nil, fmt.Errorf("create announcement: content is required")
|
||||
}
|
||||
|
||||
status := strings.TrimSpace(input.Status)
|
||||
if status == "" {
|
||||
status = AnnouncementStatusDraft
|
||||
}
|
||||
if !isValidAnnouncementStatus(status) {
|
||||
return nil, fmt.Errorf("create announcement: invalid status")
|
||||
}
|
||||
|
||||
targeting, err := domain.AnnouncementTargeting(input.Targeting).NormalizeAndValidate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
notifyMode := strings.TrimSpace(input.NotifyMode)
|
||||
if notifyMode == "" {
|
||||
notifyMode = AnnouncementNotifyModeSilent
|
||||
}
|
||||
if !isValidAnnouncementNotifyMode(notifyMode) {
|
||||
return nil, fmt.Errorf("create announcement: invalid notify_mode")
|
||||
}
|
||||
|
||||
if input.StartsAt != nil && input.EndsAt != nil {
|
||||
if !input.StartsAt.Before(*input.EndsAt) {
|
||||
return nil, fmt.Errorf("create announcement: starts_at must be before ends_at")
|
||||
}
|
||||
}
|
||||
|
||||
a := &Announcement{
|
||||
Title: title,
|
||||
Content: content,
|
||||
Status: status,
|
||||
NotifyMode: notifyMode,
|
||||
Targeting: targeting,
|
||||
StartsAt: input.StartsAt,
|
||||
EndsAt: input.EndsAt,
|
||||
}
|
||||
if input.ActorID != nil && *input.ActorID > 0 {
|
||||
a.CreatedBy = input.ActorID
|
||||
a.UpdatedBy = input.ActorID
|
||||
}
|
||||
|
||||
if err := s.announcementRepo.Create(ctx, a); err != nil {
|
||||
return nil, fmt.Errorf("create announcement: %w", err)
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) Update(ctx context.Context, id int64, input *UpdateAnnouncementInput) (*Announcement, error) {
|
||||
if input == nil {
|
||||
return nil, fmt.Errorf("update announcement: nil input")
|
||||
}
|
||||
|
||||
a, err := s.announcementRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if input.Title != nil {
|
||||
title := strings.TrimSpace(*input.Title)
|
||||
if title == "" || len(title) > 200 {
|
||||
return nil, fmt.Errorf("update announcement: invalid title")
|
||||
}
|
||||
a.Title = title
|
||||
}
|
||||
if input.Content != nil {
|
||||
content := strings.TrimSpace(*input.Content)
|
||||
if content == "" {
|
||||
return nil, fmt.Errorf("update announcement: content is required")
|
||||
}
|
||||
a.Content = content
|
||||
}
|
||||
if input.Status != nil {
|
||||
status := strings.TrimSpace(*input.Status)
|
||||
if !isValidAnnouncementStatus(status) {
|
||||
return nil, fmt.Errorf("update announcement: invalid status")
|
||||
}
|
||||
a.Status = status
|
||||
}
|
||||
|
||||
if input.NotifyMode != nil {
|
||||
notifyMode := strings.TrimSpace(*input.NotifyMode)
|
||||
if !isValidAnnouncementNotifyMode(notifyMode) {
|
||||
return nil, fmt.Errorf("update announcement: invalid notify_mode")
|
||||
}
|
||||
a.NotifyMode = notifyMode
|
||||
}
|
||||
|
||||
if input.Targeting != nil {
|
||||
targeting, err := domain.AnnouncementTargeting(*input.Targeting).NormalizeAndValidate()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.Targeting = targeting
|
||||
}
|
||||
|
||||
if input.StartsAt != nil {
|
||||
a.StartsAt = *input.StartsAt
|
||||
}
|
||||
if input.EndsAt != nil {
|
||||
a.EndsAt = *input.EndsAt
|
||||
}
|
||||
|
||||
if a.StartsAt != nil && a.EndsAt != nil {
|
||||
if !a.StartsAt.Before(*a.EndsAt) {
|
||||
return nil, fmt.Errorf("update announcement: starts_at must be before ends_at")
|
||||
}
|
||||
}
|
||||
|
||||
if input.ActorID != nil && *input.ActorID > 0 {
|
||||
a.UpdatedBy = input.ActorID
|
||||
}
|
||||
|
||||
if err := s.announcementRepo.Update(ctx, a); err != nil {
|
||||
return nil, fmt.Errorf("update announcement: %w", err)
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) Delete(ctx context.Context, id int64) error {
|
||||
if err := s.announcementRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete announcement: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) GetByID(ctx context.Context, id int64) (*Announcement, error) {
|
||||
return s.announcementRepo.GetByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) List(ctx context.Context, params pagination.PaginationParams, filters AnnouncementListFilters) ([]Announcement, *pagination.PaginationResult, error) {
|
||||
return s.announcementRepo.List(ctx, params, filters)
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) ListForUser(ctx context.Context, userID int64, unreadOnly bool) ([]UserAnnouncement, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active subscriptions: %w", err)
|
||||
}
|
||||
activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
|
||||
for i := range activeSubs {
|
||||
activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
anns, err := s.announcementRepo.ListActive(ctx, now)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active announcements: %w", err)
|
||||
}
|
||||
|
||||
visible := make([]Announcement, 0, len(anns))
|
||||
ids := make([]int64, 0, len(anns))
|
||||
for i := range anns {
|
||||
a := anns[i]
|
||||
if !a.IsActiveAt(now) {
|
||||
continue
|
||||
}
|
||||
if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
|
||||
continue
|
||||
}
|
||||
visible = append(visible, a)
|
||||
ids = append(ids, a.ID)
|
||||
}
|
||||
|
||||
if len(visible) == 0 {
|
||||
return []UserAnnouncement{}, nil
|
||||
}
|
||||
|
||||
readMap, err := s.readRepo.GetReadMapByUser(ctx, userID, ids)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get read map: %w", err)
|
||||
}
|
||||
|
||||
out := make([]UserAnnouncement, 0, len(visible))
|
||||
for i := range visible {
|
||||
a := visible[i]
|
||||
readAt, ok := readMap[a.ID]
|
||||
if unreadOnly && ok {
|
||||
continue
|
||||
}
|
||||
var ptr *time.Time
|
||||
if ok {
|
||||
t := readAt
|
||||
ptr = &t
|
||||
}
|
||||
out = append(out, UserAnnouncement{
|
||||
Announcement: a,
|
||||
ReadAt: ptr,
|
||||
})
|
||||
}
|
||||
|
||||
// 未读优先、同状态按创建时间倒序
|
||||
sort.Slice(out, func(i, j int) bool {
|
||||
ai, aj := out[i], out[j]
|
||||
if (ai.ReadAt == nil) != (aj.ReadAt == nil) {
|
||||
return ai.ReadAt == nil
|
||||
}
|
||||
return ai.Announcement.ID > aj.Announcement.ID
|
||||
})
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) MarkRead(ctx context.Context, userID, announcementID int64) error {
|
||||
// 安全:仅允许标记当前用户“可见”的公告
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
a, err := s.announcementRepo.GetByID(ctx, announcementID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if !a.IsActiveAt(now) {
|
||||
return ErrAnnouncementNotFound
|
||||
}
|
||||
|
||||
activeSubs, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("list active subscriptions: %w", err)
|
||||
}
|
||||
activeGroupIDs := make(map[int64]struct{}, len(activeSubs))
|
||||
for i := range activeSubs {
|
||||
activeGroupIDs[activeSubs[i].GroupID] = struct{}{}
|
||||
}
|
||||
|
||||
if !a.Targeting.Matches(user.Balance, activeGroupIDs) {
|
||||
return ErrAnnouncementNotFound
|
||||
}
|
||||
|
||||
if err := s.readRepo.MarkRead(ctx, announcementID, userID, now); err != nil {
|
||||
return fmt.Errorf("mark read: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AnnouncementService) ListUserReadStatus(
|
||||
ctx context.Context,
|
||||
announcementID int64,
|
||||
params pagination.PaginationParams,
|
||||
search string,
|
||||
) ([]AnnouncementUserReadStatus, *pagination.PaginationResult, error) {
|
||||
ann, err := s.announcementRepo.GetByID(ctx, announcementID)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
filters := UserListFilters{
|
||||
Search: strings.TrimSpace(search),
|
||||
}
|
||||
|
||||
users, page, err := s.userRepo.ListWithFilters(ctx, params, filters)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list users: %w", err)
|
||||
}
|
||||
|
||||
userIDs := make([]int64, 0, len(users))
|
||||
for i := range users {
|
||||
userIDs = append(userIDs, users[i].ID)
|
||||
}
|
||||
|
||||
readMap, err := s.readRepo.GetReadMapByUsers(ctx, announcementID, userIDs)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get read map: %w", err)
|
||||
}
|
||||
|
||||
out := make([]AnnouncementUserReadStatus, 0, len(users))
|
||||
for i := range users {
|
||||
u := users[i]
|
||||
subs, err := s.userSubRepo.ListActiveByUserID(ctx, u.ID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list active subscriptions: %w", err)
|
||||
}
|
||||
activeGroupIDs := make(map[int64]struct{}, len(subs))
|
||||
for j := range subs {
|
||||
activeGroupIDs[subs[j].GroupID] = struct{}{}
|
||||
}
|
||||
|
||||
readAt, ok := readMap[u.ID]
|
||||
var ptr *time.Time
|
||||
if ok {
|
||||
t := readAt
|
||||
ptr = &t
|
||||
}
|
||||
|
||||
out = append(out, AnnouncementUserReadStatus{
|
||||
UserID: u.ID,
|
||||
Email: u.Email,
|
||||
Username: u.Username,
|
||||
Balance: u.Balance,
|
||||
Eligible: domain.AnnouncementTargeting(ann.Targeting).Matches(u.Balance, activeGroupIDs),
|
||||
ReadAt: ptr,
|
||||
})
|
||||
}
|
||||
|
||||
return out, page, nil
|
||||
}
|
||||
|
||||
func isValidAnnouncementStatus(status string) bool {
|
||||
switch status {
|
||||
case AnnouncementStatusDraft, AnnouncementStatusActive, AnnouncementStatusArchived:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isValidAnnouncementNotifyMode(mode string) bool {
|
||||
switch mode {
|
||||
case AnnouncementNotifyModeSilent, AnnouncementNotifyModePopup:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAnnouncementTargeting_Matches_EmptyMatchesAll(t *testing.T) {
|
||||
var targeting AnnouncementTargeting
|
||||
require.True(t, targeting.Matches(0, nil))
|
||||
require.True(t, targeting.Matches(123.45, map[int64]struct{}{1: {}}))
|
||||
}
|
||||
|
||||
func TestAnnouncementTargeting_NormalizeAndValidate_RejectsEmptyGroup(t *testing.T) {
|
||||
targeting := AnnouncementTargeting{
|
||||
AnyOf: []AnnouncementConditionGroup{
|
||||
{AllOf: nil},
|
||||
},
|
||||
}
|
||||
_, err := targeting.NormalizeAndValidate()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
|
||||
}
|
||||
|
||||
func TestAnnouncementTargeting_NormalizeAndValidate_RejectsInvalidCondition(t *testing.T) {
|
||||
targeting := AnnouncementTargeting{
|
||||
AnyOf: []AnnouncementConditionGroup{
|
||||
{
|
||||
AllOf: []AnnouncementCondition{
|
||||
{Type: "balance", Operator: "between", Value: 10},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
_, err := targeting.NormalizeAndValidate()
|
||||
require.Error(t, err)
|
||||
require.ErrorIs(t, err, ErrAnnouncementInvalidTarget)
|
||||
}
|
||||
|
||||
func TestAnnouncementTargeting_Matches_AndOrSemantics(t *testing.T) {
|
||||
targeting := AnnouncementTargeting{
|
||||
AnyOf: []AnnouncementConditionGroup{
|
||||
{
|
||||
AllOf: []AnnouncementCondition{
|
||||
{Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorGTE, Value: 100},
|
||||
{Type: AnnouncementConditionTypeSubscription, Operator: AnnouncementOperatorIn, GroupIDs: []int64{10}},
|
||||
},
|
||||
},
|
||||
{
|
||||
AllOf: []AnnouncementCondition{
|
||||
{Type: AnnouncementConditionTypeBalance, Operator: AnnouncementOperatorLT, Value: 5},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// 命中第 2 组(balance < 5)
|
||||
require.True(t, targeting.Matches(4.99, nil))
|
||||
require.False(t, targeting.Matches(5, nil))
|
||||
|
||||
// 命中第 1 组(balance >= 100 AND 订阅 in [10])
|
||||
require.False(t, targeting.Matches(100, map[int64]struct{}{}))
|
||||
require.False(t, targeting.Matches(99.9, map[int64]struct{}{10: {}}))
|
||||
require.True(t, targeting.Matches(100, map[int64]struct{}{10: {}}))
|
||||
}
|
||||
@@ -0,0 +1,79 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Anthropic 会话 Fallback 相关常量
|
||||
const (
|
||||
// anthropicSessionTTLSeconds Anthropic 会话缓存 TTL(5 分钟)
|
||||
anthropicSessionTTLSeconds = 300
|
||||
|
||||
// anthropicDigestSessionKeyPrefix Anthropic 摘要 fallback 会话 key 前缀
|
||||
anthropicDigestSessionKeyPrefix = "anthropic:digest:"
|
||||
)
|
||||
|
||||
// AnthropicSessionTTL 返回 Anthropic 会话缓存 TTL
|
||||
func AnthropicSessionTTL() time.Duration {
|
||||
return anthropicSessionTTLSeconds * time.Second
|
||||
}
|
||||
|
||||
// BuildAnthropicDigestChain 根据 Anthropic 请求生成摘要链
|
||||
// 格式: s:<hash>-u:<hash>-a:<hash>-u:<hash>-...
|
||||
// s = system, u = user, a = assistant
|
||||
func BuildAnthropicDigestChain(parsed *ParsedRequest) string {
|
||||
if parsed == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var parts []string
|
||||
|
||||
// 1. system prompt
|
||||
if parsed.System != nil {
|
||||
systemData, _ := json.Marshal(parsed.System)
|
||||
if len(systemData) > 0 && string(systemData) != "null" {
|
||||
parts = append(parts, "s:"+shortHash(systemData))
|
||||
}
|
||||
}
|
||||
|
||||
// 2. messages
|
||||
for _, msg := range parsed.Messages {
|
||||
msgMap, ok := msg.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
role, _ := msgMap["role"].(string)
|
||||
prefix := rolePrefix(role)
|
||||
content := msgMap["content"]
|
||||
contentData, _ := json.Marshal(content)
|
||||
parts = append(parts, prefix+":"+shortHash(contentData))
|
||||
}
|
||||
|
||||
return strings.Join(parts, "-")
|
||||
}
|
||||
|
||||
// rolePrefix 将 Anthropic 的 role 映射为单字符前缀
|
||||
func rolePrefix(role string) string {
|
||||
switch role {
|
||||
case "assistant":
|
||||
return "a"
|
||||
default:
|
||||
return "u"
|
||||
}
|
||||
}
|
||||
|
||||
// GenerateAnthropicDigestSessionKey 生成 Anthropic 摘要 fallback 的 sessionKey
|
||||
// 组合 prefixHash 前 8 位 + uuid 前 8 位,确保不同会话产生不同的 sessionKey
|
||||
func GenerateAnthropicDigestSessionKey(prefixHash, uuid string) string {
|
||||
prefix := prefixHash
|
||||
if len(prefixHash) >= 8 {
|
||||
prefix = prefixHash[:8]
|
||||
}
|
||||
uuidPart := uuid
|
||||
if len(uuid) >= 8 {
|
||||
uuidPart = uuid[:8]
|
||||
}
|
||||
return anthropicDigestSessionKeyPrefix + prefix + ":" + uuidPart
|
||||
}
|
||||
@@ -0,0 +1,320 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildAnthropicDigestChain_NilRequest(t *testing.T) {
|
||||
result := BuildAnthropicDigestChain(nil)
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string for nil request, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_EmptyMessages(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
Messages: []any{},
|
||||
}
|
||||
result := BuildAnthropicDigestChain(parsed)
|
||||
if result != "" {
|
||||
t.Errorf("expected empty string for empty messages, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_SingleUserMessage(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
result := BuildAnthropicDigestChain(parsed)
|
||||
parts := splitChain(result)
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
|
||||
}
|
||||
if !strings.HasPrefix(parts[0], "u:") {
|
||||
t.Errorf("expected prefix 'u:', got: %s", parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_UserAndAssistant(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
map[string]any{"role": "assistant", "content": "hi there"},
|
||||
},
|
||||
}
|
||||
result := BuildAnthropicDigestChain(parsed)
|
||||
parts := splitChain(result)
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("expected 2 parts, got %d: %s", len(parts), result)
|
||||
}
|
||||
if !strings.HasPrefix(parts[0], "u:") {
|
||||
t.Errorf("part[0] expected prefix 'u:', got: %s", parts[0])
|
||||
}
|
||||
if !strings.HasPrefix(parts[1], "a:") {
|
||||
t.Errorf("part[1] expected prefix 'a:', got: %s", parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_WithSystemString(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
System: "You are a helpful assistant",
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
result := BuildAnthropicDigestChain(parsed)
|
||||
parts := splitChain(result)
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
|
||||
}
|
||||
if !strings.HasPrefix(parts[0], "s:") {
|
||||
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
|
||||
}
|
||||
if !strings.HasPrefix(parts[1], "u:") {
|
||||
t.Errorf("part[1] expected prefix 'u:', got: %s", parts[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_WithSystemContentBlocks(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
System: []any{
|
||||
map[string]any{"type": "text", "text": "You are a helpful assistant"},
|
||||
},
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
result := BuildAnthropicDigestChain(parsed)
|
||||
parts := splitChain(result)
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("expected 2 parts (s + u), got %d: %s", len(parts), result)
|
||||
}
|
||||
if !strings.HasPrefix(parts[0], "s:") {
|
||||
t.Errorf("part[0] expected prefix 's:', got: %s", parts[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_ConversationPrefixRelationship(t *testing.T) {
|
||||
// 核心测试:验证对话增长时链的前缀关系
|
||||
// 上一轮的完整链一定是下一轮链的前缀
|
||||
system := "You are a helpful assistant"
|
||||
|
||||
// 第 1 轮: system + user
|
||||
round1 := &ParsedRequest{
|
||||
System: system,
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
chain1 := BuildAnthropicDigestChain(round1)
|
||||
|
||||
// 第 2 轮: system + user + assistant + user
|
||||
round2 := &ParsedRequest{
|
||||
System: system,
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
map[string]any{"role": "assistant", "content": "hi there"},
|
||||
map[string]any{"role": "user", "content": "how are you?"},
|
||||
},
|
||||
}
|
||||
chain2 := BuildAnthropicDigestChain(round2)
|
||||
|
||||
// 第 3 轮: system + user + assistant + user + assistant + user
|
||||
round3 := &ParsedRequest{
|
||||
System: system,
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
map[string]any{"role": "assistant", "content": "hi there"},
|
||||
map[string]any{"role": "user", "content": "how are you?"},
|
||||
map[string]any{"role": "assistant", "content": "I'm doing well"},
|
||||
map[string]any{"role": "user", "content": "great"},
|
||||
},
|
||||
}
|
||||
chain3 := BuildAnthropicDigestChain(round3)
|
||||
|
||||
t.Logf("Chain1: %s", chain1)
|
||||
t.Logf("Chain2: %s", chain2)
|
||||
t.Logf("Chain3: %s", chain3)
|
||||
|
||||
// chain1 是 chain2 的前缀
|
||||
if !strings.HasPrefix(chain2, chain1) {
|
||||
t.Errorf("chain1 should be prefix of chain2:\n chain1: %s\n chain2: %s", chain1, chain2)
|
||||
}
|
||||
|
||||
// chain2 是 chain3 的前缀
|
||||
if !strings.HasPrefix(chain3, chain2) {
|
||||
t.Errorf("chain2 should be prefix of chain3:\n chain2: %s\n chain3: %s", chain2, chain3)
|
||||
}
|
||||
|
||||
// chain1 也是 chain3 的前缀(传递性)
|
||||
if !strings.HasPrefix(chain3, chain1) {
|
||||
t.Errorf("chain1 should be prefix of chain3:\n chain1: %s\n chain3: %s", chain1, chain3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_DifferentSystemProducesDifferentChain(t *testing.T) {
|
||||
parsed1 := &ParsedRequest{
|
||||
System: "System A",
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
parsed2 := &ParsedRequest{
|
||||
System: "System B",
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
},
|
||||
}
|
||||
|
||||
chain1 := BuildAnthropicDigestChain(parsed1)
|
||||
chain2 := BuildAnthropicDigestChain(parsed2)
|
||||
|
||||
if chain1 == chain2 {
|
||||
t.Error("Different system prompts should produce different chains")
|
||||
}
|
||||
|
||||
// 但 user 部分的 hash 应该相同
|
||||
parts1 := splitChain(chain1)
|
||||
parts2 := splitChain(chain2)
|
||||
if parts1[1] != parts2[1] {
|
||||
t.Error("Same user message should produce same hash regardless of system")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_DifferentContentProducesDifferentChain(t *testing.T) {
|
||||
parsed1 := &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
map[string]any{"role": "assistant", "content": "ORIGINAL reply"},
|
||||
map[string]any{"role": "user", "content": "next"},
|
||||
},
|
||||
}
|
||||
parsed2 := &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
map[string]any{"role": "assistant", "content": "TAMPERED reply"},
|
||||
map[string]any{"role": "user", "content": "next"},
|
||||
},
|
||||
}
|
||||
|
||||
chain1 := BuildAnthropicDigestChain(parsed1)
|
||||
chain2 := BuildAnthropicDigestChain(parsed2)
|
||||
|
||||
if chain1 == chain2 {
|
||||
t.Error("Different content should produce different chains")
|
||||
}
|
||||
|
||||
parts1 := splitChain(chain1)
|
||||
parts2 := splitChain(chain2)
|
||||
// 第一个 user message hash 应该相同
|
||||
if parts1[0] != parts2[0] {
|
||||
t.Error("First user message hash should be the same")
|
||||
}
|
||||
// assistant reply hash 应该不同
|
||||
if parts1[1] == parts2[1] {
|
||||
t.Error("Assistant reply hash should differ")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_Deterministic(t *testing.T) {
|
||||
parsed := &ParsedRequest{
|
||||
System: "test system",
|
||||
Messages: []any{
|
||||
map[string]any{"role": "user", "content": "hello"},
|
||||
map[string]any{"role": "assistant", "content": "hi"},
|
||||
},
|
||||
}
|
||||
|
||||
chain1 := BuildAnthropicDigestChain(parsed)
|
||||
chain2 := BuildAnthropicDigestChain(parsed)
|
||||
|
||||
if chain1 != chain2 {
|
||||
t.Errorf("BuildAnthropicDigestChain not deterministic: %s vs %s", chain1, chain2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAnthropicDigestSessionKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prefixHash string
|
||||
uuid string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal 16 char hash with uuid",
|
||||
prefixHash: "abcdefgh12345678",
|
||||
uuid: "550e8400-e29b-41d4-a716-446655440000",
|
||||
want: "anthropic:digest:abcdefgh:550e8400",
|
||||
},
|
||||
{
|
||||
name: "exactly 8 chars",
|
||||
prefixHash: "12345678",
|
||||
uuid: "abcdefgh",
|
||||
want: "anthropic:digest:12345678:abcdefgh",
|
||||
},
|
||||
{
|
||||
name: "short values",
|
||||
prefixHash: "abc",
|
||||
uuid: "xyz",
|
||||
want: "anthropic:digest:abc:xyz",
|
||||
},
|
||||
{
|
||||
name: "empty values",
|
||||
prefixHash: "",
|
||||
uuid: "",
|
||||
want: "anthropic:digest::",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GenerateAnthropicDigestSessionKey(tt.prefixHash, tt.uuid)
|
||||
if got != tt.want {
|
||||
t.Errorf("GenerateAnthropicDigestSessionKey(%q, %q) = %q, want %q", tt.prefixHash, tt.uuid, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// 验证不同 uuid 产生不同 sessionKey
|
||||
t.Run("different uuid different key", func(t *testing.T) {
|
||||
hash := "sameprefix123456"
|
||||
result1 := GenerateAnthropicDigestSessionKey(hash, "uuid0001-session-a")
|
||||
result2 := GenerateAnthropicDigestSessionKey(hash, "uuid0002-session-b")
|
||||
if result1 == result2 {
|
||||
t.Errorf("Different UUIDs should produce different session keys: %s vs %s", result1, result2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAnthropicSessionTTL(t *testing.T) {
|
||||
ttl := AnthropicSessionTTL()
|
||||
if ttl.Seconds() != 300 {
|
||||
t.Errorf("expected 300 seconds, got: %v", ttl.Seconds())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildAnthropicDigestChain_ContentBlocks(t *testing.T) {
|
||||
// 测试 content 为 content blocks 数组的情况
|
||||
parsed := &ParsedRequest{
|
||||
Messages: []any{
|
||||
map[string]any{
|
||||
"role": "user",
|
||||
"content": []any{
|
||||
map[string]any{"type": "text", "text": "describe this image"},
|
||||
map[string]any{"type": "image", "source": map[string]any{"type": "base64"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
result := BuildAnthropicDigestChain(parsed)
|
||||
parts := splitChain(result)
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("expected 1 part, got %d: %s", len(parts), result)
|
||||
}
|
||||
if !strings.HasPrefix(parts[0], "u:") {
|
||||
t.Errorf("expected prefix 'u:', got: %s", parts[0])
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,234 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
// creditsExhaustedKey 是 model_rate_limits 中标记积分耗尽的特殊 key。
|
||||
// 与普通模型限流完全同构:通过 SetModelRateLimit / isRateLimitActiveForKey 读写。
|
||||
creditsExhaustedKey = "AICredits"
|
||||
creditsExhaustedDuration = 5 * time.Hour
|
||||
)
|
||||
|
||||
type antigravity429Category string
|
||||
|
||||
const (
|
||||
antigravity429Unknown antigravity429Category = "unknown"
|
||||
antigravity429RateLimited antigravity429Category = "rate_limited"
|
||||
antigravity429QuotaExhausted antigravity429Category = "quota_exhausted"
|
||||
)
|
||||
|
||||
var (
|
||||
antigravityQuotaExhaustedKeywords = []string{
|
||||
"quota_exhausted",
|
||||
"quota exhausted",
|
||||
}
|
||||
|
||||
creditsExhaustedKeywords = []string{
|
||||
"google_one_ai",
|
||||
"insufficient credit",
|
||||
"insufficient credits",
|
||||
"not enough credit",
|
||||
"not enough credits",
|
||||
"credit exhausted",
|
||||
"credits exhausted",
|
||||
"credit balance",
|
||||
"minimumcreditamountforusage",
|
||||
"minimum credit amount for usage",
|
||||
"minimum credit",
|
||||
}
|
||||
)
|
||||
|
||||
// isCreditsExhausted 检查账号的 AICredits 限流 key 是否生效(积分是否耗尽)。
|
||||
func (a *Account) isCreditsExhausted() bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
return a.isRateLimitActiveForKey(creditsExhaustedKey)
|
||||
}
|
||||
|
||||
// setCreditsExhausted 标记账号积分耗尽:写入 model_rate_limits["AICredits"] + 更新缓存。
|
||||
func (s *AntigravityGatewayService) setCreditsExhausted(ctx context.Context, account *Account) {
|
||||
if account == nil || account.ID == 0 {
|
||||
return
|
||||
}
|
||||
resetAt := time.Now().Add(creditsExhaustedDuration)
|
||||
if err := s.accountRepo.SetModelRateLimit(ctx, account.ID, creditsExhaustedKey, resetAt); err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "set credits exhausted failed: account=%d err=%v", account.ID, err)
|
||||
return
|
||||
}
|
||||
s.updateAccountModelRateLimitInCache(ctx, account, creditsExhaustedKey, resetAt)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "credits_exhausted_marked account=%d reset_at=%s",
|
||||
account.ID, resetAt.UTC().Format(time.RFC3339))
|
||||
}
|
||||
|
||||
// clearCreditsExhausted 清除账号的 AICredits 限流 key。
|
||||
func (s *AntigravityGatewayService) clearCreditsExhausted(ctx context.Context, account *Account) {
|
||||
if account == nil || account.ID == 0 || account.Extra == nil {
|
||||
return
|
||||
}
|
||||
rawLimits, ok := account.Extra[modelRateLimitsKey].(map[string]any)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if _, exists := rawLimits[creditsExhaustedKey]; !exists {
|
||||
return
|
||||
}
|
||||
delete(rawLimits, creditsExhaustedKey)
|
||||
account.Extra[modelRateLimitsKey] = rawLimits
|
||||
if err := s.accountRepo.UpdateExtra(ctx, account.ID, map[string]any{
|
||||
modelRateLimitsKey: rawLimits,
|
||||
}); err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "clear credits exhausted failed: account=%d err=%v", account.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// classifyAntigravity429 将 Antigravity 的 429 响应归类为配额耗尽、限流或未知。
|
||||
func classifyAntigravity429(body []byte) antigravity429Category {
|
||||
if len(body) == 0 {
|
||||
return antigravity429Unknown
|
||||
}
|
||||
lowerBody := strings.ToLower(string(body))
|
||||
for _, keyword := range antigravityQuotaExhaustedKeywords {
|
||||
if strings.Contains(lowerBody, keyword) {
|
||||
return antigravity429QuotaExhausted
|
||||
}
|
||||
}
|
||||
if info := parseAntigravitySmartRetryInfo(body); info != nil && !info.IsModelCapacityExhausted {
|
||||
return antigravity429RateLimited
|
||||
}
|
||||
return antigravity429Unknown
|
||||
}
|
||||
|
||||
// injectEnabledCreditTypes 在已序列化的 v1internal JSON body 中注入 AI Credits 类型。
|
||||
func injectEnabledCreditTypes(body []byte) []byte {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return nil
|
||||
}
|
||||
payload["enabledCreditTypes"] = []string{"GOOGLE_ONE_AI"}
|
||||
result, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// resolveCreditsOveragesModelKey 解析当前请求对应的 overages 状态模型 key。
|
||||
func resolveCreditsOveragesModelKey(ctx context.Context, account *Account, upstreamModelName, requestedModel string) string {
|
||||
modelKey := strings.TrimSpace(upstreamModelName)
|
||||
if modelKey != "" {
|
||||
return modelKey
|
||||
}
|
||||
if account == nil {
|
||||
return ""
|
||||
}
|
||||
modelKey = resolveFinalAntigravityModelKey(ctx, account, requestedModel)
|
||||
if strings.TrimSpace(modelKey) != "" {
|
||||
return modelKey
|
||||
}
|
||||
return resolveAntigravityModelKey(requestedModel)
|
||||
}
|
||||
|
||||
// shouldMarkCreditsExhausted 判断一次 credits 请求失败是否应标记为 credits 耗尽。
|
||||
func shouldMarkCreditsExhausted(resp *http.Response, respBody []byte, reqErr error) bool {
|
||||
if reqErr != nil || resp == nil {
|
||||
return false
|
||||
}
|
||||
if resp.StatusCode >= 500 || resp.StatusCode == http.StatusRequestTimeout {
|
||||
return false
|
||||
}
|
||||
if isURLLevelRateLimit(respBody) {
|
||||
return false
|
||||
}
|
||||
if info := parseAntigravitySmartRetryInfo(respBody); info != nil {
|
||||
return false
|
||||
}
|
||||
bodyLower := strings.ToLower(string(respBody))
|
||||
for _, keyword := range creditsExhaustedKeywords {
|
||||
if strings.Contains(bodyLower, keyword) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
type creditsOveragesRetryResult struct {
|
||||
handled bool
|
||||
resp *http.Response
|
||||
}
|
||||
|
||||
// attemptCreditsOveragesRetry 在确认免费配额耗尽后,尝试注入 AI Credits 继续请求。
|
||||
func (s *AntigravityGatewayService) attemptCreditsOveragesRetry(
|
||||
p antigravityRetryLoopParams,
|
||||
baseURL string,
|
||||
modelName string,
|
||||
waitDuration time.Duration,
|
||||
originalStatusCode int,
|
||||
respBody []byte,
|
||||
) *creditsOveragesRetryResult {
|
||||
creditsBody := injectEnabledCreditTypes(p.body)
|
||||
if creditsBody == nil {
|
||||
return &creditsOveragesRetryResult{handled: false}
|
||||
}
|
||||
modelKey := resolveCreditsOveragesModelKey(p.ctx, p.account, modelName, p.requestedModel)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=429 credit_overages_retry model=%s account=%d (injecting enabledCreditTypes)",
|
||||
p.prefix, modelKey, p.account.ID)
|
||||
|
||||
creditsReq, err := antigravity.NewAPIRequestWithURL(p.ctx, baseURL, p.action, p.accessToken, creditsBody)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d build_request_err=%v",
|
||||
p.prefix, modelKey, p.account.ID, err)
|
||||
return &creditsOveragesRetryResult{handled: true}
|
||||
}
|
||||
|
||||
creditsResp, err := p.httpUpstream.Do(creditsReq, p.proxyURL, p.account.ID, p.account.Concurrency)
|
||||
if err == nil && creditsResp != nil && creditsResp.StatusCode < 400 {
|
||||
s.clearCreditsExhausted(p.ctx, p.account)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s status=%d credit_overages_success model=%s account=%d",
|
||||
p.prefix, creditsResp.StatusCode, modelKey, p.account.ID)
|
||||
return &creditsOveragesRetryResult{handled: true, resp: creditsResp}
|
||||
}
|
||||
|
||||
s.handleCreditsRetryFailure(p.ctx, p.prefix, modelKey, p.account, creditsResp, err)
|
||||
return &creditsOveragesRetryResult{handled: true}
|
||||
}
|
||||
|
||||
func (s *AntigravityGatewayService) handleCreditsRetryFailure(
|
||||
ctx context.Context,
|
||||
prefix string,
|
||||
modelKey string,
|
||||
account *Account,
|
||||
creditsResp *http.Response,
|
||||
reqErr error,
|
||||
) {
|
||||
var creditsRespBody []byte
|
||||
creditsStatusCode := 0
|
||||
if creditsResp != nil {
|
||||
creditsStatusCode = creditsResp.StatusCode
|
||||
if creditsResp.Body != nil {
|
||||
creditsRespBody, _ = io.ReadAll(io.LimitReader(creditsResp.Body, 64<<10))
|
||||
_ = creditsResp.Body.Close()
|
||||
}
|
||||
}
|
||||
|
||||
if shouldMarkCreditsExhausted(creditsResp, creditsRespBody, reqErr) && account != nil {
|
||||
s.setCreditsExhausted(ctx, account)
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=true status=%d body=%s",
|
||||
prefix, modelKey, account.ID, creditsStatusCode, truncateForLog(creditsRespBody, 200))
|
||||
return
|
||||
}
|
||||
if account != nil {
|
||||
logger.LegacyPrintf("service.antigravity_gateway", "%s credit_overages_failed model=%s account=%d marked_exhausted=false status=%d err=%v body=%s",
|
||||
prefix, modelKey, account.ID, creditsStatusCode, reqErr, truncateForLog(creditsRespBody, 200))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,538 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClassifyAntigravity429(t *testing.T) {
|
||||
t.Run("明确配额耗尽", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
|
||||
require.Equal(t, antigravity429QuotaExhausted, classifyAntigravity429(body))
|
||||
})
|
||||
|
||||
t.Run("结构化限流", func(t *testing.T) {
|
||||
body := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.5s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
require.Equal(t, antigravity429RateLimited, classifyAntigravity429(body))
|
||||
})
|
||||
|
||||
t.Run("未知429", func(t *testing.T) {
|
||||
body := []byte(`{"error":{"message":"too many requests"}}`)
|
||||
require.Equal(t, antigravity429Unknown, classifyAntigravity429(body))
|
||||
})
|
||||
}
|
||||
|
||||
func TestIsCreditsExhausted_UsesAICreditsKey(t *testing.T) {
|
||||
t.Run("无 AICredits key 则积分可用", func(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
},
|
||||
}
|
||||
require.False(t, account.isCreditsExhausted())
|
||||
})
|
||||
|
||||
t.Run("AICredits key 生效则积分耗尽", func(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.True(t, account.isCreditsExhausted())
|
||||
})
|
||||
|
||||
t.Run("AICredits key 过期则积分可用", func(t *testing.T) {
|
||||
account := &Account{
|
||||
ID: 3,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": time.Now().Add(-6 * time.Hour).UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(-1 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
require.False(t, account.isCreditsExhausted())
|
||||
})
|
||||
}
|
||||
|
||||
func TestHandleSmartRetry_QuotaExhausted_UsesCreditsAndStoresIndependentState(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Name: "acc-101",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
},
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": map[string]any{
|
||||
"claude-opus-4-6": "claude-sonnet-4-5",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
respBody := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","message":"QUOTA_EXHAUSTED"}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-opus-4-6","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
requestedModel: "claude-opus-4-6",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp)
|
||||
require.Nil(t, result.switchError)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||
require.Empty(t, repo.modelRateLimitCalls, "overages 成功后不应写入普通 model_rate_limits")
|
||||
}
|
||||
|
||||
func TestHandleSmartRetry_RateLimited_DoesNotUseCredits(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Name: "acc-102",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
},
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, []string{"https://ag-1.test"})
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
require.NotContains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
require.Empty(t, repo.modelRateLimitCalls)
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_ModelRateLimited_InjectsCredits(t *testing.T) {
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvailability
|
||||
}()
|
||||
|
||||
antigravity.BaseURLs = []string{"https://ag-1.test"}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"ok":true}`)),
|
||||
},
|
||||
},
|
||||
errors: []error{nil},
|
||||
}
|
||||
// 模型已限流 + overages 启用 + 无 AICredits key → 应直接注入积分
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Name: "acc-103",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
require.Len(t, upstream.requestBodies, 1)
|
||||
require.Contains(t, string(upstream.requestBodies[0]), "enabledCreditTypes")
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_CreditsExhausted_DoesNotInject(t *testing.T) {
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvailability
|
||||
}()
|
||||
|
||||
antigravity.BaseURLs = []string{"https://ag-1.test"}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
|
||||
// 模型限流 + overages 启用 + AICredits key 生效 → 不应注入积分,应切号
|
||||
account := &Account{
|
||||
ID: 104,
|
||||
Name: "acc-104",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
|
||||
},
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
_, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
// 模型限流 + 积分耗尽 → 应触发切号错误
|
||||
require.Error(t, err)
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, err, &switchErr)
|
||||
}
|
||||
|
||||
func TestAntigravityRetryLoop_CreditErrorMarksExhausted(t *testing.T) {
|
||||
oldBaseURLs := append([]string(nil), antigravity.BaseURLs...)
|
||||
oldAvailability := antigravity.DefaultURLAvailability
|
||||
defer func() {
|
||||
antigravity.BaseURLs = oldBaseURLs
|
||||
antigravity.DefaultURLAvailability = oldAvailability
|
||||
}()
|
||||
|
||||
antigravity.BaseURLs = []string{"https://ag-1.test"}
|
||||
antigravity.DefaultURLAvailability = antigravity.NewURLAvailability(time.Minute)
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
upstream := &queuedHTTPUpstreamStub{
|
||||
responses: []*http.Response{
|
||||
{
|
||||
StatusCode: http.StatusForbidden,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"error":{"message":"Insufficient GOOGLE_ONE_AI credits"}}`)),
|
||||
},
|
||||
},
|
||||
errors: []error{nil},
|
||||
}
|
||||
// 模型限流 + overages 启用 + 积分可用 → 注入积分但上游返回积分不足
|
||||
account := &Account{
|
||||
ID: 105,
|
||||
Name: "acc-105",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Status: StatusActive,
|
||||
Schedulable: true,
|
||||
Extra: map[string]any{
|
||||
"allow_overages": true,
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": time.Now().UTC().Format(time.RFC3339),
|
||||
"rate_limit_reset_at": time.Now().Add(30 * time.Minute).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: context.Background(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"model":"claude-sonnet-4-5","request":{}}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, result)
|
||||
// 验证 AICredits key 已通过 SetModelRateLimit 写入数据库
|
||||
require.Len(t, repo.modelRateLimitCalls, 1, "应通过 SetModelRateLimit 写入 AICredits key")
|
||||
require.Equal(t, creditsExhaustedKey, repo.modelRateLimitCalls[0].modelKey)
|
||||
}
|
||||
|
||||
func TestShouldMarkCreditsExhausted(t *testing.T) {
|
||||
t.Run("reqErr 不为 nil 时不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusForbidden}
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), io.ErrUnexpectedEOF))
|
||||
})
|
||||
|
||||
t.Run("resp 为 nil 时不标记", func(t *testing.T) {
|
||||
require.False(t, shouldMarkCreditsExhausted(nil, []byte(`{"error":"Insufficient credits"}`), nil))
|
||||
})
|
||||
|
||||
t.Run("5xx 响应不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusInternalServerError}
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
|
||||
})
|
||||
|
||||
t.Run("408 RequestTimeout 不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusRequestTimeout}
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, []byte(`{"error":"Insufficient credits"}`), nil))
|
||||
})
|
||||
|
||||
t.Run("URL 级限流不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
|
||||
body := []byte(`{"error":{"message":"Resource has been exhausted"}}`)
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||
})
|
||||
|
||||
t.Run("结构化限流不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusTooManyRequests}
|
||||
body := []byte(`{"error":{"status":"RESOURCE_EXHAUSTED","details":[{"@type":"type.googleapis.com/google.rpc.ErrorInfo","reason":"RATE_LIMIT_EXCEEDED"},{"@type":"type.googleapis.com/google.rpc.RetryInfo","retryDelay":"0.5s"}]}}`)
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||
})
|
||||
|
||||
t.Run("含 credits 关键词时标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusForbidden}
|
||||
for _, keyword := range []string{
|
||||
"Insufficient GOOGLE_ONE_AI credits",
|
||||
"insufficient credit balance",
|
||||
"not enough credits for this request",
|
||||
"Credits exhausted",
|
||||
"minimumCreditAmountForUsage requirement not met",
|
||||
} {
|
||||
body := []byte(`{"error":{"message":"` + keyword + `"}}`)
|
||||
require.True(t, shouldMarkCreditsExhausted(resp, body, nil), "should mark for keyword: %s", keyword)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("无 credits 关键词时不标记", func(t *testing.T) {
|
||||
resp := &http.Response{StatusCode: http.StatusForbidden}
|
||||
body := []byte(`{"error":{"message":"permission denied"}}`)
|
||||
require.False(t, shouldMarkCreditsExhausted(resp, body, nil))
|
||||
})
|
||||
}
|
||||
|
||||
func TestInjectEnabledCreditTypes(t *testing.T) {
|
||||
t.Run("正常 JSON 注入成功", func(t *testing.T) {
|
||||
body := []byte(`{"model":"claude-sonnet-4-5","request":{}}`)
|
||||
result := injectEnabledCreditTypes(body)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, string(result), `"enabledCreditTypes"`)
|
||||
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
|
||||
})
|
||||
|
||||
t.Run("非法 JSON 返回 nil", func(t *testing.T) {
|
||||
require.Nil(t, injectEnabledCreditTypes([]byte(`not json`)))
|
||||
})
|
||||
|
||||
t.Run("空 body 返回 nil", func(t *testing.T) {
|
||||
require.Nil(t, injectEnabledCreditTypes([]byte{}))
|
||||
})
|
||||
|
||||
t.Run("已有 enabledCreditTypes 会被覆盖", func(t *testing.T) {
|
||||
body := []byte(`{"enabledCreditTypes":["OLD"],"model":"test"}`)
|
||||
result := injectEnabledCreditTypes(body)
|
||||
require.NotNil(t, result)
|
||||
require.Contains(t, string(result), `GOOGLE_ONE_AI`)
|
||||
require.NotContains(t, string(result), `OLD`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestClearCreditsExhausted(t *testing.T) {
|
||||
t.Run("account 为 nil 不操作", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
svc.clearCreditsExhausted(context.Background(), nil)
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
})
|
||||
|
||||
t.Run("Extra 为 nil 不操作", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
svc.clearCreditsExhausted(context.Background(), &Account{ID: 1})
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
})
|
||||
|
||||
t.Run("无 modelRateLimitsKey 不操作", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
svc.clearCreditsExhausted(context.Background(), &Account{
|
||||
ID: 1,
|
||||
Extra: map[string]any{"some_key": "value"},
|
||||
})
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
})
|
||||
|
||||
t.Run("无 AICredits key 不操作", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
svc.clearCreditsExhausted(context.Background(), &Account{
|
||||
ID: 1,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.Empty(t, repo.extraUpdateCalls)
|
||||
})
|
||||
|
||||
t.Run("有 AICredits key 时删除并调用 UpdateExtra", func(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
svc := &AntigravityGatewayService{accountRepo: repo}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": "2099-03-15T00:00:00Z",
|
||||
},
|
||||
creditsExhaustedKey: map[string]any{
|
||||
"rate_limited_at": "2026-03-15T00:00:00Z",
|
||||
"rate_limit_reset_at": time.Now().Add(5 * time.Hour).UTC().Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
svc.clearCreditsExhausted(context.Background(), account)
|
||||
require.Len(t, repo.extraUpdateCalls, 1)
|
||||
// AICredits key 应被删除
|
||||
rawLimits := account.Extra[modelRateLimitsKey].(map[string]any)
|
||||
_, exists := rawLimits[creditsExhaustedKey]
|
||||
require.False(t, exists, "AICredits key 应被删除")
|
||||
// 普通模型限流应保留
|
||||
_, exists = rawLimits["claude-sonnet-4-5"]
|
||||
require.True(t, exists, "普通模型限流应保留")
|
||||
})
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,123 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIsImageGenerationModel_GeminiProImage 测试 gemini-3-pro-image 识别
|
||||
func TestIsImageGenerationModel_GeminiProImage(t *testing.T) {
|
||||
require.True(t, isImageGenerationModel("gemini-3-pro-image"))
|
||||
require.True(t, isImageGenerationModel("gemini-3-pro-image-preview"))
|
||||
require.True(t, isImageGenerationModel("models/gemini-3-pro-image"))
|
||||
}
|
||||
|
||||
// TestIsImageGenerationModel_GeminiFlashImage 测试 gemini-2.5-flash-image 识别
|
||||
func TestIsImageGenerationModel_GeminiFlashImage(t *testing.T) {
|
||||
require.True(t, isImageGenerationModel("gemini-2.5-flash-image"))
|
||||
require.True(t, isImageGenerationModel("gemini-2.5-flash-image-preview"))
|
||||
}
|
||||
|
||||
// TestIsImageGenerationModel_RegularModel 测试普通模型不被识别为图片模型
|
||||
func TestIsImageGenerationModel_RegularModel(t *testing.T) {
|
||||
require.False(t, isImageGenerationModel("claude-3-opus"))
|
||||
require.False(t, isImageGenerationModel("claude-sonnet-4-20250514"))
|
||||
require.False(t, isImageGenerationModel("gpt-4o"))
|
||||
require.False(t, isImageGenerationModel("gemini-2.5-pro")) // 非图片模型
|
||||
require.False(t, isImageGenerationModel("gemini-2.5-flash"))
|
||||
// 验证不会误匹配包含关键词的自定义模型名
|
||||
require.False(t, isImageGenerationModel("my-gemini-3-pro-image-test"))
|
||||
require.False(t, isImageGenerationModel("custom-gemini-2.5-flash-image-wrapper"))
|
||||
}
|
||||
|
||||
// TestIsImageGenerationModel_CaseInsensitive 测试大小写不敏感
|
||||
func TestIsImageGenerationModel_CaseInsensitive(t *testing.T) {
|
||||
require.True(t, isImageGenerationModel("GEMINI-3-PRO-IMAGE"))
|
||||
require.True(t, isImageGenerationModel("Gemini-3-Pro-Image"))
|
||||
require.True(t, isImageGenerationModel("GEMINI-2.5-FLASH-IMAGE"))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_ValidSizes 测试有效尺寸解析
|
||||
func TestExtractImageSize_ValidSizes(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
// 1K
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1K"}}}`)
|
||||
require.Equal(t, "1K", svc.extractImageSize(body))
|
||||
|
||||
// 2K
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"2K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 4K
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4K"}}}`)
|
||||
require.Equal(t, "4K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_CaseInsensitive 测试大小写不敏感
|
||||
func TestExtractImageSize_CaseInsensitive(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"1k"}}}`)
|
||||
require.Equal(t, "1K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"4k"}}}`)
|
||||
require.Equal(t, "4K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_Default 测试无 imageConfig 返回默认 2K
|
||||
func TestExtractImageSize_Default(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
// 无 generationConfig
|
||||
body := []byte(`{"contents":[]}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 有 generationConfig 但无 imageConfig
|
||||
body = []byte(`{"generationConfig":{"temperature":0.7}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 有 imageConfig 但无 imageSize
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_InvalidJSON 测试非法 JSON 返回默认 2K
|
||||
func TestExtractImageSize_InvalidJSON(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`not valid json`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"broken":`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_EmptySize 测试空 imageSize 返回默认 2K
|
||||
func TestExtractImageSize_EmptySize(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":""}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
// 空格
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":" "}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
|
||||
// TestExtractImageSize_InvalidSize 测试无效尺寸返回默认 2K
|
||||
func TestExtractImageSize_InvalidSize(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
body := []byte(`{"generationConfig":{"imageConfig":{"imageSize":"3K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"8K"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
|
||||
body = []byte(`{"generationConfig":{"imageConfig":{"imageSize":"invalid"}}}`)
|
||||
require.Equal(t, "2K", svc.extractImageSize(body))
|
||||
}
|
||||
@@ -0,0 +1,285 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAntigravityGatewayService_GetMappedModel(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
accountMapping map[string]string
|
||||
expected string
|
||||
}{
|
||||
// 1. 账户级映射优先
|
||||
{
|
||||
name: "账户映射优先",
|
||||
requestedModel: "claude-3-5-sonnet-20241022",
|
||||
accountMapping: map[string]string{"claude-3-5-sonnet-20241022": "custom-model"},
|
||||
expected: "custom-model",
|
||||
},
|
||||
{
|
||||
name: "账户映射 - 可覆盖默认映射的模型",
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
accountMapping: map[string]string{"claude-sonnet-4-5": "my-custom-sonnet"},
|
||||
expected: "my-custom-sonnet",
|
||||
},
|
||||
{
|
||||
name: "账户映射 - 可覆盖未知模型",
|
||||
requestedModel: "claude-opus-4",
|
||||
accountMapping: map[string]string{"claude-opus-4": "my-opus"},
|
||||
expected: "my-opus",
|
||||
},
|
||||
|
||||
// 2. 默认映射(DefaultAntigravityModelMapping)
|
||||
{
|
||||
name: "默认映射 - claude-opus-4-6 → claude-opus-4-6-thinking",
|
||||
requestedModel: "claude-opus-4-6",
|
||||
accountMapping: nil,
|
||||
expected: "claude-opus-4-6-thinking",
|
||||
},
|
||||
{
|
||||
name: "默认映射 - claude-opus-4-5-20251101 → claude-opus-4-6-thinking",
|
||||
requestedModel: "claude-opus-4-5-20251101",
|
||||
accountMapping: nil,
|
||||
expected: "claude-opus-4-6-thinking",
|
||||
},
|
||||
{
|
||||
name: "默认映射 - claude-opus-4-5-thinking → claude-opus-4-6-thinking",
|
||||
requestedModel: "claude-opus-4-5-thinking",
|
||||
accountMapping: nil,
|
||||
expected: "claude-opus-4-6-thinking",
|
||||
},
|
||||
{
|
||||
name: "默认映射 - claude-haiku-4-5 → claude-sonnet-4-5",
|
||||
requestedModel: "claude-haiku-4-5",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "默认映射 - claude-haiku-4-5-20251001 → claude-sonnet-4-5",
|
||||
requestedModel: "claude-haiku-4-5-20251001",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "默认映射 - claude-sonnet-4-5-20250929 → claude-sonnet-4-5",
|
||||
requestedModel: "claude-sonnet-4-5-20250929",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
|
||||
// 3. 默认映射中的透传(映射到自己)
|
||||
{
|
||||
name: "默认映射透传 - claude-sonnet-4-6",
|
||||
requestedModel: "claude-sonnet-4-6",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-6",
|
||||
},
|
||||
{
|
||||
name: "默认映射透传 - claude-sonnet-4-5",
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "默认映射透传 - claude-opus-4-6-thinking",
|
||||
requestedModel: "claude-opus-4-6-thinking",
|
||||
accountMapping: nil,
|
||||
expected: "claude-opus-4-6-thinking",
|
||||
},
|
||||
{
|
||||
name: "默认映射透传 - claude-sonnet-4-5-thinking",
|
||||
requestedModel: "claude-sonnet-4-5-thinking",
|
||||
accountMapping: nil,
|
||||
expected: "claude-sonnet-4-5-thinking",
|
||||
},
|
||||
{
|
||||
name: "默认映射透传 - gemini-2.5-flash",
|
||||
requestedModel: "gemini-2.5-flash",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-2.5-flash",
|
||||
},
|
||||
{
|
||||
name: "默认映射透传 - gemini-2.5-pro",
|
||||
requestedModel: "gemini-2.5-pro",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-2.5-pro",
|
||||
},
|
||||
{
|
||||
name: "默认映射透传 - gemini-3-flash",
|
||||
requestedModel: "gemini-3-flash",
|
||||
accountMapping: nil,
|
||||
expected: "gemini-3-flash",
|
||||
},
|
||||
|
||||
// 4. 未在默认映射中的模型返回空字符串(不支持)
|
||||
{
|
||||
name: "未知模型 - claude-unknown 返回空",
|
||||
requestedModel: "claude-unknown",
|
||||
accountMapping: nil,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "未知模型 - claude-3-5-sonnet-20241022 返回空(未在默认映射)",
|
||||
requestedModel: "claude-3-5-sonnet-20241022",
|
||||
accountMapping: nil,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "未知模型 - claude-3-opus-20240229 返回空",
|
||||
requestedModel: "claude-3-opus-20240229",
|
||||
accountMapping: nil,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "未知模型 - claude-opus-4 返回空",
|
||||
requestedModel: "claude-opus-4",
|
||||
accountMapping: nil,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "未知模型 - gemini-future-model 返回空",
|
||||
requestedModel: "gemini-future-model",
|
||||
accountMapping: nil,
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
if tt.accountMapping != nil {
|
||||
// GetModelMapping 期望 model_mapping 是 map[string]any 格式
|
||||
mappingAny := make(map[string]any)
|
||||
for k, v := range tt.accountMapping {
|
||||
mappingAny[k] = v
|
||||
}
|
||||
account.Credentials = map[string]any{
|
||||
"model_mapping": mappingAny,
|
||||
}
|
||||
}
|
||||
|
||||
got := svc.getMappedModel(account, tt.requestedModel)
|
||||
require.Equal(t, tt.expected, got, "model: %s", tt.requestedModel)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_GetMappedModel_EdgeCases(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
requestedModel string
|
||||
expected string
|
||||
}{
|
||||
// 空字符串和非 claude/gemini 前缀返回空字符串
|
||||
{"空字符串", "", ""},
|
||||
{"非claude/gemini前缀 - gpt", "gpt-4", ""},
|
||||
{"非claude/gemini前缀 - llama", "llama-3", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{Platform: PlatformAntigravity}
|
||||
got := svc.getMappedModel(account, tt.requestedModel)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAntigravityGatewayService_IsModelSupported(t *testing.T) {
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
// 直接支持
|
||||
{"直接支持 - claude-sonnet-4-5", "claude-sonnet-4-5", true},
|
||||
{"直接支持 - gemini-3-flash", "gemini-3-flash", true},
|
||||
|
||||
// 可映射(有明确前缀映射)
|
||||
{"可映射 - claude-opus-4-6", "claude-opus-4-6", true},
|
||||
|
||||
// 前缀透传(claude 和 gemini 前缀)
|
||||
{"Gemini前缀", "gemini-unknown", true},
|
||||
{"Claude前缀", "claude-unknown", true},
|
||||
|
||||
// 不支持
|
||||
{"不支持 - gpt-4", "gpt-4", false},
|
||||
{"不支持 - 空字符串", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := svc.IsModelSupported(tt.model)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMapAntigravityModel_WildcardTargetEqualsRequest 测试通配符映射目标恰好等于请求模型名的 edge case
|
||||
// 例如 {"claude-*": "claude-sonnet-4-5"},请求 "claude-sonnet-4-5" 时应该通过
|
||||
func TestMapAntigravityModel_WildcardTargetEqualsRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelMapping map[string]any
|
||||
requestedModel string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "wildcard target equals request model",
|
||||
modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5"},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "wildcard target differs from request model",
|
||||
modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5"},
|
||||
requestedModel: "claude-opus-4-6",
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "wildcard no match",
|
||||
modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5"},
|
||||
requestedModel: "gpt-4o",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "explicit passthrough same name",
|
||||
modelMapping: map[string]any{"claude-sonnet-4-5": "claude-sonnet-4-5"},
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "multiple wildcards target equals one request",
|
||||
modelMapping: map[string]any{"claude-*": "claude-sonnet-4-5", "gemini-*": "gemini-2.5-flash"},
|
||||
requestedModel: "gemini-2.5-flash",
|
||||
expected: "gemini-2.5-flash",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Credentials: map[string]any{
|
||||
"model_mapping": tt.modelMapping,
|
||||
},
|
||||
}
|
||||
got := mapAntigravityModel(account, tt.requestedModel)
|
||||
require.Equal(t, tt.expected, got, "mapAntigravityModel(%q) = %q, want %q", tt.requestedModel, got, tt.expected)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,436 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
type AntigravityOAuthService struct {
|
||||
sessionStore *antigravity.SessionStore
|
||||
proxyRepo ProxyRepository
|
||||
}
|
||||
|
||||
func NewAntigravityOAuthService(proxyRepo ProxyRepository) *AntigravityOAuthService {
|
||||
return &AntigravityOAuthService{
|
||||
sessionStore: antigravity.NewSessionStore(),
|
||||
proxyRepo: proxyRepo,
|
||||
}
|
||||
}
|
||||
|
||||
// AntigravityAuthURLResult is the result of generating an authorization URL
|
||||
type AntigravityAuthURLResult struct {
|
||||
AuthURL string `json:"auth_url"`
|
||||
SessionID string `json:"session_id"`
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
// GenerateAuthURL 生成 Google OAuth 授权链接
|
||||
func (s *AntigravityOAuthService) GenerateAuthURL(ctx context.Context, proxyID *int64) (*AntigravityAuthURLResult, error) {
|
||||
state, err := antigravity.GenerateState()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成 state 失败: %w", err)
|
||||
}
|
||||
|
||||
codeVerifier, err := antigravity.GenerateCodeVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成 code_verifier 失败: %w", err)
|
||||
}
|
||||
|
||||
sessionID, err := antigravity.GenerateSessionID()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("生成 session_id 失败: %w", err)
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
session := &antigravity.OAuthSession{
|
||||
State: state,
|
||||
CodeVerifier: codeVerifier,
|
||||
ProxyURL: proxyURL,
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
s.sessionStore.Set(sessionID, session)
|
||||
|
||||
codeChallenge := antigravity.GenerateCodeChallenge(codeVerifier)
|
||||
authURL := antigravity.BuildAuthorizationURL(state, codeChallenge)
|
||||
|
||||
return &AntigravityAuthURLResult{
|
||||
AuthURL: authURL,
|
||||
SessionID: sessionID,
|
||||
State: state,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AntigravityExchangeCodeInput 交换 code 的输入
|
||||
type AntigravityExchangeCodeInput struct {
|
||||
SessionID string
|
||||
State string
|
||||
Code string
|
||||
ProxyID *int64
|
||||
}
|
||||
|
||||
// AntigravityTokenInfo token 信息
|
||||
type AntigravityTokenInfo struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
ExpiresIn int64 `json:"expires_in"`
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
TokenType string `json:"token_type"`
|
||||
Email string `json:"email,omitempty"`
|
||||
ProjectID string `json:"project_id,omitempty"`
|
||||
ProjectIDMissing bool `json:"-"` // LoadCodeAssist 未返回 project_id
|
||||
}
|
||||
|
||||
// ExchangeCode 用 authorization code 交换 token
|
||||
func (s *AntigravityOAuthService) ExchangeCode(ctx context.Context, input *AntigravityExchangeCodeInput) (*AntigravityTokenInfo, error) {
|
||||
session, ok := s.sessionStore.Get(input.SessionID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session 不存在或已过期")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(input.State) == "" || input.State != session.State {
|
||||
return nil, fmt.Errorf("state 无效")
|
||||
}
|
||||
|
||||
// 确定代理 URL
|
||||
proxyURL := session.ProxyURL
|
||||
if input.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *input.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
client, err := antigravity.NewClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create antigravity client failed: %w", err)
|
||||
}
|
||||
|
||||
// 交换 token
|
||||
tokenResp, err := client.ExchangeCode(ctx, input.Code, session.CodeVerifier)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token 交换失败: %w", err)
|
||||
}
|
||||
|
||||
// 删除 session
|
||||
s.sessionStore.Delete(input.SessionID)
|
||||
|
||||
// 计算过期时间(减去 5 分钟安全窗口)
|
||||
expiresAt := time.Now().Unix() + tokenResp.ExpiresIn - 300
|
||||
|
||||
result := &AntigravityTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
TokenType: tokenResp.TokenType,
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
userInfo, err := client.GetUserInfo(ctx, tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
|
||||
} else {
|
||||
result.Email = userInfo.Email
|
||||
}
|
||||
|
||||
// 获取 project_id(部分账户类型可能没有),失败时重试
|
||||
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenResp.AccessToken, proxyURL, 3)
|
||||
if loadErr != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
|
||||
result.ProjectIDMissing = true
|
||||
} else {
|
||||
result.ProjectID = projectID
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// RefreshToken 刷新 token
|
||||
func (s *AntigravityOAuthService) RefreshToken(ctx context.Context, refreshToken, proxyURL string) (*AntigravityTokenInfo, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= 3; attempt++ {
|
||||
if attempt > 0 {
|
||||
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
|
||||
if backoff > 30*time.Second {
|
||||
backoff = 30 * time.Second
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
client, err := antigravity.NewClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create antigravity client failed: %w", err)
|
||||
}
|
||||
tokenResp, err := client.RefreshToken(ctx, refreshToken)
|
||||
if err == nil {
|
||||
now := time.Now()
|
||||
expiresAt := now.Unix() + tokenResp.ExpiresIn - 300
|
||||
fmt.Printf("[AntigravityOAuth] Token refreshed: expires_in=%d, expires_at=%d (%s)\n",
|
||||
tokenResp.ExpiresIn, expiresAt, time.Unix(expiresAt, 0).Format("2006-01-02 15:04:05"))
|
||||
return &AntigravityTokenInfo{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
ExpiresIn: tokenResp.ExpiresIn,
|
||||
ExpiresAt: expiresAt,
|
||||
TokenType: tokenResp.TokenType,
|
||||
}, nil
|
||||
}
|
||||
|
||||
if isNonRetryableAntigravityOAuthError(err) {
|
||||
return nil, err
|
||||
}
|
||||
lastErr = err
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("token 刷新失败 (重试后): %w", lastErr)
|
||||
}
|
||||
|
||||
// ValidateRefreshToken 用 refresh token 验证并获取完整的 token 信息(含 email 和 project_id)
|
||||
func (s *AntigravityOAuthService) ValidateRefreshToken(ctx context.Context, refreshToken string, proxyID *int64) (*AntigravityTokenInfo, error) {
|
||||
var proxyURL string
|
||||
if proxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *proxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
// 刷新 token
|
||||
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 获取用户信息(email)
|
||||
client, err := antigravity.NewClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create antigravity client failed: %w", err)
|
||||
}
|
||||
userInfo, err := client.GetUserInfo(ctx, tokenInfo.AccessToken)
|
||||
if err != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取用户信息失败: %v\n", err)
|
||||
} else {
|
||||
tokenInfo.Email = userInfo.Email
|
||||
}
|
||||
|
||||
// 获取 project_id(容错,失败不阻塞)
|
||||
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
|
||||
if loadErr != nil {
|
||||
fmt.Printf("[AntigravityOAuth] 警告: 获取 project_id 失败(重试后): %v\n", loadErr)
|
||||
tokenInfo.ProjectIDMissing = true
|
||||
} else {
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
func isNonRetryableAntigravityOAuthError(err error) bool {
|
||||
msg := err.Error()
|
||||
nonRetryable := []string{
|
||||
"invalid_grant",
|
||||
"invalid_client",
|
||||
"unauthorized_client",
|
||||
"access_denied",
|
||||
}
|
||||
for _, needle := range nonRetryable {
|
||||
if strings.Contains(msg, needle) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RefreshAccountToken 刷新账户的 token
|
||||
func (s *AntigravityOAuthService) RefreshAccountToken(ctx context.Context, account *Account) (*AntigravityTokenInfo, error) {
|
||||
if account.Platform != PlatformAntigravity || account.Type != AccountTypeOAuth {
|
||||
return nil, fmt.Errorf("非 Antigravity OAuth 账户")
|
||||
}
|
||||
|
||||
refreshToken := account.GetCredential("refresh_token")
|
||||
if strings.TrimSpace(refreshToken) == "" {
|
||||
return nil, fmt.Errorf("无可用的 refresh_token")
|
||||
}
|
||||
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
|
||||
tokenInfo, err := s.RefreshToken(ctx, refreshToken, proxyURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 保留原有的 email
|
||||
existingEmail := strings.TrimSpace(account.GetCredential("email"))
|
||||
if existingEmail != "" {
|
||||
tokenInfo.Email = existingEmail
|
||||
}
|
||||
|
||||
// 每次刷新都调用 LoadCodeAssist 获取 project_id,失败时重试
|
||||
existingProjectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
projectID, loadErr := s.loadProjectIDWithRetry(ctx, tokenInfo.AccessToken, proxyURL, 3)
|
||||
|
||||
if loadErr != nil {
|
||||
// LoadCodeAssist 失败,保留原有 project_id
|
||||
tokenInfo.ProjectID = existingProjectID
|
||||
// 只有从未获取过 project_id 且本次也获取失败时,才标记为真正缺失
|
||||
// 如果之前有 project_id,本次只是临时故障,不应标记为错误
|
||||
if existingProjectID == "" {
|
||||
tokenInfo.ProjectIDMissing = true
|
||||
}
|
||||
} else {
|
||||
tokenInfo.ProjectID = projectID
|
||||
}
|
||||
|
||||
return tokenInfo, nil
|
||||
}
|
||||
|
||||
// loadProjectIDWithRetry 带重试机制获取 project_id
|
||||
// 返回 project_id 和错误,失败时会重试指定次数
|
||||
func (s *AntigravityOAuthService) loadProjectIDWithRetry(ctx context.Context, accessToken, proxyURL string, maxRetries int) (string, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt <= maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// 指数退避:1s, 2s, 4s
|
||||
backoff := time.Duration(1<<uint(attempt-1)) * time.Second
|
||||
if backoff > 8*time.Second {
|
||||
backoff = 8 * time.Second
|
||||
}
|
||||
time.Sleep(backoff)
|
||||
}
|
||||
|
||||
client, err := antigravity.NewClient(proxyURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create antigravity client failed: %w", err)
|
||||
}
|
||||
loadResp, loadRaw, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
|
||||
if err == nil && loadResp != nil && loadResp.CloudAICompanionProject != "" {
|
||||
return loadResp.CloudAICompanionProject, nil
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
if projectID, onboardErr := tryOnboardProjectID(ctx, client, accessToken, loadRaw); onboardErr == nil && projectID != "" {
|
||||
return projectID, nil
|
||||
} else if onboardErr != nil {
|
||||
lastErr = onboardErr
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 记录错误
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
} else if loadResp == nil {
|
||||
lastErr = fmt.Errorf("LoadCodeAssist 返回空响应")
|
||||
} else {
|
||||
lastErr = fmt.Errorf("LoadCodeAssist 返回空 project_id")
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("获取 project_id 失败 (重试 %d 次后): %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func tryOnboardProjectID(ctx context.Context, client *antigravity.Client, accessToken string, loadRaw map[string]any) (string, error) {
|
||||
tierID := resolveDefaultTierID(loadRaw)
|
||||
if tierID == "" {
|
||||
return "", fmt.Errorf("loadCodeAssist 未返回可用的默认 tier")
|
||||
}
|
||||
|
||||
projectID, err := client.OnboardUser(ctx, accessToken, tierID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("onboardUser 失败 (tier=%s): %w", tierID, err)
|
||||
}
|
||||
return projectID, nil
|
||||
}
|
||||
|
||||
func resolveDefaultTierID(loadRaw map[string]any) string {
|
||||
if len(loadRaw) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
rawTiers, ok := loadRaw["allowedTiers"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
tiers, ok := rawTiers.([]any)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, rawTier := range tiers {
|
||||
tier, ok := rawTier.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if isDefault, _ := tier["isDefault"].(bool); !isDefault {
|
||||
continue
|
||||
}
|
||||
if id, ok := tier["id"].(string); ok {
|
||||
id = strings.TrimSpace(id)
|
||||
if id != "" {
|
||||
return id
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// FillProjectID 仅获取 project_id,不刷新 OAuth token
|
||||
func (s *AntigravityOAuthService) FillProjectID(ctx context.Context, account *Account, accessToken string) (string, error) {
|
||||
var proxyURL string
|
||||
if account.ProxyID != nil {
|
||||
proxy, err := s.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err == nil && proxy != nil {
|
||||
proxyURL = proxy.URL()
|
||||
}
|
||||
}
|
||||
return s.loadProjectIDWithRetry(ctx, accessToken, proxyURL, 3)
|
||||
}
|
||||
|
||||
// BuildAccountCredentials 构建账户凭证
|
||||
func (s *AntigravityOAuthService) BuildAccountCredentials(tokenInfo *AntigravityTokenInfo) map[string]any {
|
||||
creds := map[string]any{
|
||||
"access_token": tokenInfo.AccessToken,
|
||||
"expires_at": strconv.FormatInt(tokenInfo.ExpiresAt, 10),
|
||||
}
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
creds["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
if tokenInfo.TokenType != "" {
|
||||
creds["token_type"] = tokenInfo.TokenType
|
||||
}
|
||||
if tokenInfo.Email != "" {
|
||||
creds["email"] = tokenInfo.Email
|
||||
}
|
||||
if tokenInfo.ProjectID != "" {
|
||||
creds["project_id"] = tokenInfo.ProjectID
|
||||
}
|
||||
return creds
|
||||
}
|
||||
|
||||
// Stop 停止服务
|
||||
func (s *AntigravityOAuthService) Stop() {
|
||||
s.sessionStore.Stop()
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestResolveDefaultTierID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
loadRaw map[string]any
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "nil loadRaw",
|
||||
loadRaw: nil,
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "missing allowedTiers",
|
||||
loadRaw: map[string]any{
|
||||
"paidTier": map[string]any{"id": "g1-pro-tier"},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty allowedTiers",
|
||||
loadRaw: map[string]any{"allowedTiers": []any{}},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "tier missing id field",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"isDefault": true},
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "allowedTiers but no default",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": "free-tier", "isDefault": false},
|
||||
map[string]any{"id": "standard-tier", "isDefault": false},
|
||||
},
|
||||
},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "default tier found",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": "free-tier", "isDefault": true},
|
||||
map[string]any{"id": "standard-tier", "isDefault": false},
|
||||
},
|
||||
},
|
||||
want: "free-tier",
|
||||
},
|
||||
{
|
||||
name: "default tier id with spaces",
|
||||
loadRaw: map[string]any{
|
||||
"allowedTiers": []any{
|
||||
map[string]any{"id": " standard-tier ", "isDefault": true},
|
||||
},
|
||||
},
|
||||
want: "standard-tier",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := resolveDefaultTierID(tc.loadRaw)
|
||||
if got != tc.want {
|
||||
t.Fatalf("resolveDefaultTierID() = %q, want %q", got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,272 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
const (
|
||||
forbiddenTypeValidation = "validation"
|
||||
forbiddenTypeViolation = "violation"
|
||||
forbiddenTypeForbidden = "forbidden"
|
||||
|
||||
// 机器可读的错误码
|
||||
errorCodeForbidden = "forbidden"
|
||||
errorCodeUnauthenticated = "unauthenticated"
|
||||
errorCodeRateLimited = "rate_limited"
|
||||
errorCodeNetworkError = "network_error"
|
||||
)
|
||||
|
||||
// AntigravityQuotaFetcher 从 Antigravity API 获取额度
|
||||
type AntigravityQuotaFetcher struct {
|
||||
proxyRepo ProxyRepository
|
||||
}
|
||||
|
||||
// NewAntigravityQuotaFetcher 创建 AntigravityQuotaFetcher
|
||||
func NewAntigravityQuotaFetcher(proxyRepo ProxyRepository) *AntigravityQuotaFetcher {
|
||||
return &AntigravityQuotaFetcher{proxyRepo: proxyRepo}
|
||||
}
|
||||
|
||||
// CanFetch 检查是否可以获取此账户的额度
|
||||
func (f *AntigravityQuotaFetcher) CanFetch(account *Account) bool {
|
||||
if account.Platform != PlatformAntigravity {
|
||||
return false
|
||||
}
|
||||
accessToken := account.GetCredential("access_token")
|
||||
return accessToken != ""
|
||||
}
|
||||
|
||||
// FetchQuota 获取 Antigravity 账户额度信息
|
||||
func (f *AntigravityQuotaFetcher) FetchQuota(ctx context.Context, account *Account, proxyURL string) (*QuotaResult, error) {
|
||||
accessToken := account.GetCredential("access_token")
|
||||
projectID := account.GetCredential("project_id")
|
||||
|
||||
client, err := antigravity.NewClient(proxyURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create antigravity client failed: %w", err)
|
||||
}
|
||||
|
||||
// 调用 API 获取配额
|
||||
modelsResp, modelsRaw, err := client.FetchAvailableModels(ctx, accessToken, projectID)
|
||||
if err != nil {
|
||||
// 403 Forbidden: 不报错,返回 is_forbidden 标记
|
||||
var forbiddenErr *antigravity.ForbiddenError
|
||||
if errors.As(err, &forbiddenErr) {
|
||||
now := time.Now()
|
||||
fbType := classifyForbiddenType(forbiddenErr.Body)
|
||||
return &QuotaResult{
|
||||
UsageInfo: &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
IsForbidden: true,
|
||||
ForbiddenReason: forbiddenErr.Body,
|
||||
ForbiddenType: fbType,
|
||||
ValidationURL: extractValidationURL(forbiddenErr.Body),
|
||||
NeedsVerify: fbType == forbiddenTypeValidation,
|
||||
IsBanned: fbType == forbiddenTypeViolation,
|
||||
ErrorCode: errorCodeForbidden,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 调用 LoadCodeAssist 获取订阅等级和 AI Credits 余额(非关键路径,失败不影响主流程)
|
||||
tierRaw, tierNormalized, loadResp := f.fetchSubscriptionTier(ctx, client, accessToken)
|
||||
|
||||
// 转换为 UsageInfo
|
||||
usageInfo := f.buildUsageInfo(modelsResp, tierRaw, tierNormalized, loadResp)
|
||||
|
||||
return &QuotaResult{
|
||||
UsageInfo: usageInfo,
|
||||
Raw: modelsRaw,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// fetchSubscriptionTier 获取账号订阅等级,失败返回空字符串。
|
||||
// 同时返回 LoadCodeAssistResponse,以便提取 AI Credits 余额。
|
||||
func (f *AntigravityQuotaFetcher) fetchSubscriptionTier(ctx context.Context, client *antigravity.Client, accessToken string) (raw, normalized string, loadResp *antigravity.LoadCodeAssistResponse) {
|
||||
loadResp, _, err := client.LoadCodeAssist(ctx, accessToken)
|
||||
if err != nil {
|
||||
slog.Warn("failed to fetch subscription tier", "error", err)
|
||||
return "", "", nil
|
||||
}
|
||||
if loadResp == nil {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
raw = loadResp.GetTier() // 已有方法:paidTier > currentTier
|
||||
normalized = normalizeTier(raw)
|
||||
return raw, normalized, loadResp
|
||||
}
|
||||
|
||||
// normalizeTier 将原始 tier 字符串归一化为 FREE/PRO/ULTRA/UNKNOWN
|
||||
func normalizeTier(raw string) string {
|
||||
if raw == "" {
|
||||
return ""
|
||||
}
|
||||
lower := strings.ToLower(raw)
|
||||
switch {
|
||||
case strings.Contains(lower, "ultra"):
|
||||
return "ULTRA"
|
||||
case strings.Contains(lower, "pro"):
|
||||
return "PRO"
|
||||
case strings.Contains(lower, "free"):
|
||||
return "FREE"
|
||||
default:
|
||||
return "UNKNOWN"
|
||||
}
|
||||
}
|
||||
|
||||
// buildUsageInfo 将 API 响应转换为 UsageInfo。
|
||||
func (f *AntigravityQuotaFetcher) buildUsageInfo(modelsResp *antigravity.FetchAvailableModelsResponse, tierRaw, tierNormalized string, loadResp *antigravity.LoadCodeAssistResponse) *UsageInfo {
|
||||
now := time.Now()
|
||||
info := &UsageInfo{
|
||||
UpdatedAt: &now,
|
||||
AntigravityQuota: make(map[string]*AntigravityModelQuota),
|
||||
AntigravityQuotaDetails: make(map[string]*AntigravityModelDetail),
|
||||
SubscriptionTier: tierNormalized,
|
||||
SubscriptionTierRaw: tierRaw,
|
||||
}
|
||||
|
||||
// 遍历所有模型,填充 AntigravityQuota 和 AntigravityQuotaDetails
|
||||
for modelName, modelInfo := range modelsResp.Models {
|
||||
if modelInfo.QuotaInfo == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
// remainingFraction 是剩余比例 (0.0-1.0),转换为使用率百分比
|
||||
utilization := int((1.0 - modelInfo.QuotaInfo.RemainingFraction) * 100)
|
||||
|
||||
info.AntigravityQuota[modelName] = &AntigravityModelQuota{
|
||||
Utilization: utilization,
|
||||
ResetTime: modelInfo.QuotaInfo.ResetTime,
|
||||
}
|
||||
|
||||
// 填充模型详细能力信息
|
||||
detail := &AntigravityModelDetail{
|
||||
DisplayName: modelInfo.DisplayName,
|
||||
SupportsImages: modelInfo.SupportsImages,
|
||||
SupportsThinking: modelInfo.SupportsThinking,
|
||||
ThinkingBudget: modelInfo.ThinkingBudget,
|
||||
Recommended: modelInfo.Recommended,
|
||||
MaxTokens: modelInfo.MaxTokens,
|
||||
MaxOutputTokens: modelInfo.MaxOutputTokens,
|
||||
SupportedMimeTypes: modelInfo.SupportedMimeTypes,
|
||||
}
|
||||
info.AntigravityQuotaDetails[modelName] = detail
|
||||
}
|
||||
|
||||
// 废弃模型转发规则
|
||||
if len(modelsResp.DeprecatedModelIDs) > 0 {
|
||||
info.ModelForwardingRules = make(map[string]string, len(modelsResp.DeprecatedModelIDs))
|
||||
for oldID, deprecated := range modelsResp.DeprecatedModelIDs {
|
||||
info.ModelForwardingRules[oldID] = deprecated.NewModelID
|
||||
}
|
||||
}
|
||||
|
||||
// 同时设置 FiveHour 用于兼容展示(取主要模型)
|
||||
priorityModels := []string{"claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"}
|
||||
for _, modelName := range priorityModels {
|
||||
if modelInfo, ok := modelsResp.Models[modelName]; ok && modelInfo.QuotaInfo != nil {
|
||||
utilization := (1.0 - modelInfo.QuotaInfo.RemainingFraction) * 100
|
||||
progress := &UsageProgress{
|
||||
Utilization: utilization,
|
||||
}
|
||||
if modelInfo.QuotaInfo.ResetTime != "" {
|
||||
if resetTime, err := time.Parse(time.RFC3339, modelInfo.QuotaInfo.ResetTime); err == nil {
|
||||
progress.ResetsAt = &resetTime
|
||||
progress.RemainingSeconds = int(time.Until(resetTime).Seconds())
|
||||
}
|
||||
}
|
||||
info.FiveHour = progress
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if loadResp != nil {
|
||||
for _, credit := range loadResp.GetAvailableCredits() {
|
||||
info.AICredits = append(info.AICredits, AICredit{
|
||||
CreditType: credit.CreditType,
|
||||
Amount: credit.GetAmount(),
|
||||
MinimumBalance: credit.GetMinimumAmount(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return info
|
||||
}
|
||||
|
||||
// GetProxyURL 获取账户的代理 URL
|
||||
func (f *AntigravityQuotaFetcher) GetProxyURL(ctx context.Context, account *Account) string {
|
||||
if account.ProxyID == nil || f.proxyRepo == nil {
|
||||
return ""
|
||||
}
|
||||
proxy, err := f.proxyRepo.GetByID(ctx, *account.ProxyID)
|
||||
if err != nil || proxy == nil {
|
||||
return ""
|
||||
}
|
||||
return proxy.URL()
|
||||
}
|
||||
|
||||
// classifyForbiddenType 根据 403 响应体判断禁止类型
|
||||
func classifyForbiddenType(body string) string {
|
||||
lower := strings.ToLower(body)
|
||||
switch {
|
||||
case strings.Contains(lower, "validation_required") ||
|
||||
strings.Contains(lower, "verify your account") ||
|
||||
strings.Contains(lower, "validation_url"):
|
||||
return forbiddenTypeValidation
|
||||
case strings.Contains(lower, "terms of service") ||
|
||||
strings.Contains(lower, "violation"):
|
||||
return forbiddenTypeViolation
|
||||
default:
|
||||
return forbiddenTypeForbidden
|
||||
}
|
||||
}
|
||||
|
||||
// urlPattern 用于从 403 响应体中提取 URL(降级方案)
|
||||
var urlPattern = regexp.MustCompile(`https://[^\s"'\\]+`)
|
||||
|
||||
// extractValidationURL 从 403 响应 JSON 中提取验证/申诉链接
|
||||
func extractValidationURL(body string) string {
|
||||
// 1. 尝试结构化 JSON 提取: /error/details[*]/metadata/validation_url 或 appeal_url
|
||||
var parsed struct {
|
||||
Error struct {
|
||||
Details []struct {
|
||||
Metadata map[string]string `json:"metadata"`
|
||||
} `json:"details"`
|
||||
} `json:"error"`
|
||||
}
|
||||
if json.Unmarshal([]byte(body), &parsed) == nil {
|
||||
for _, detail := range parsed.Error.Details {
|
||||
if u := detail.Metadata["validation_url"]; u != "" {
|
||||
return u
|
||||
}
|
||||
if u := detail.Metadata["appeal_url"]; u != "" {
|
||||
return u
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 降级:正则匹配 URL
|
||||
lower := strings.ToLower(body)
|
||||
if !strings.Contains(lower, "validation") &&
|
||||
!strings.Contains(lower, "verify") &&
|
||||
!strings.Contains(lower, "appeal") {
|
||||
return ""
|
||||
}
|
||||
// 先解码常见转义再匹配
|
||||
normalized := strings.ReplaceAll(body, `\u0026`, "&")
|
||||
if m := urlPattern.FindString(normalized); m != "" {
|
||||
return m
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -0,0 +1,522 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/antigravity"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// normalizeTier
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestNormalizeTier(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
expected string
|
||||
}{
|
||||
{name: "empty string", raw: "", expected: ""},
|
||||
{name: "free-tier", raw: "free-tier", expected: "FREE"},
|
||||
{name: "g1-pro-tier", raw: "g1-pro-tier", expected: "PRO"},
|
||||
{name: "g1-ultra-tier", raw: "g1-ultra-tier", expected: "ULTRA"},
|
||||
{name: "unknown-something", raw: "unknown-something", expected: "UNKNOWN"},
|
||||
{name: "Google AI Pro contains pro keyword", raw: "Google AI Pro", expected: "PRO"},
|
||||
{name: "case insensitive FREE", raw: "FREE-TIER", expected: "FREE"},
|
||||
{name: "case insensitive Ultra", raw: "Ultra Plan", expected: "ULTRA"},
|
||||
{name: "arbitrary unrecognized string", raw: "enterprise-custom", expected: "UNKNOWN"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := normalizeTier(tt.raw)
|
||||
require.Equal(t, tt.expected, got, "normalizeTier(%q)", tt.raw)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// buildUsageInfo
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func aqfBoolPtr(v bool) *bool { return &v }
|
||||
func aqfIntPtr(v int) *int { return &v }
|
||||
|
||||
func TestBuildUsageInfo_BasicModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.75,
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
DisplayName: "Claude Sonnet 4",
|
||||
SupportsImages: aqfBoolPtr(true),
|
||||
SupportsThinking: aqfBoolPtr(false),
|
||||
ThinkingBudget: aqfIntPtr(0),
|
||||
Recommended: aqfBoolPtr(true),
|
||||
MaxTokens: aqfIntPtr(200000),
|
||||
MaxOutputTokens: aqfIntPtr(16384),
|
||||
SupportedMimeTypes: map[string]bool{
|
||||
"image/png": true,
|
||||
"image/jpeg": true,
|
||||
},
|
||||
},
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
ResetTime: "2026-03-08T15:00:00Z",
|
||||
},
|
||||
DisplayName: "Gemini 2.5 Pro",
|
||||
MaxTokens: aqfIntPtr(1000000),
|
||||
MaxOutputTokens: aqfIntPtr(65536),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", nil)
|
||||
|
||||
// 基本字段
|
||||
require.NotNil(t, info.UpdatedAt, "UpdatedAt should be set")
|
||||
require.Equal(t, "PRO", info.SubscriptionTier)
|
||||
require.Equal(t, "g1-pro-tier", info.SubscriptionTierRaw)
|
||||
|
||||
// AntigravityQuota
|
||||
require.Len(t, info.AntigravityQuota, 2)
|
||||
|
||||
sonnetQuota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, sonnetQuota)
|
||||
require.Equal(t, 25, sonnetQuota.Utilization) // (1 - 0.75) * 100 = 25
|
||||
require.Equal(t, "2026-03-08T12:00:00Z", sonnetQuota.ResetTime)
|
||||
|
||||
geminiQuota := info.AntigravityQuota["gemini-2.5-pro"]
|
||||
require.NotNil(t, geminiQuota)
|
||||
require.Equal(t, 50, geminiQuota.Utilization) // (1 - 0.50) * 100 = 50
|
||||
require.Equal(t, "2026-03-08T15:00:00Z", geminiQuota.ResetTime)
|
||||
|
||||
// AntigravityQuotaDetails
|
||||
require.Len(t, info.AntigravityQuotaDetails, 2)
|
||||
|
||||
sonnetDetail := info.AntigravityQuotaDetails["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, sonnetDetail)
|
||||
require.Equal(t, "Claude Sonnet 4", sonnetDetail.DisplayName)
|
||||
require.Equal(t, aqfBoolPtr(true), sonnetDetail.SupportsImages)
|
||||
require.Equal(t, aqfBoolPtr(false), sonnetDetail.SupportsThinking)
|
||||
require.Equal(t, aqfIntPtr(0), sonnetDetail.ThinkingBudget)
|
||||
require.Equal(t, aqfBoolPtr(true), sonnetDetail.Recommended)
|
||||
require.Equal(t, aqfIntPtr(200000), sonnetDetail.MaxTokens)
|
||||
require.Equal(t, aqfIntPtr(16384), sonnetDetail.MaxOutputTokens)
|
||||
require.Equal(t, map[string]bool{"image/png": true, "image/jpeg": true}, sonnetDetail.SupportedMimeTypes)
|
||||
|
||||
geminiDetail := info.AntigravityQuotaDetails["gemini-2.5-pro"]
|
||||
require.NotNil(t, geminiDetail)
|
||||
require.Equal(t, "Gemini 2.5 Pro", geminiDetail.DisplayName)
|
||||
require.Nil(t, geminiDetail.SupportsImages)
|
||||
require.Nil(t, geminiDetail.SupportsThinking)
|
||||
require.Equal(t, aqfIntPtr(1000000), geminiDetail.MaxTokens)
|
||||
require.Equal(t, aqfIntPtr(65536), geminiDetail.MaxOutputTokens)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_DeprecatedModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 1.0,
|
||||
},
|
||||
},
|
||||
},
|
||||
DeprecatedModelIDs: map[string]antigravity.DeprecatedModelInfo{
|
||||
"claude-3-sonnet-20240229": {NewModelID: "claude-sonnet-4-20250514"},
|
||||
"claude-3-haiku-20240307": {NewModelID: "claude-haiku-3.5-latest"},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.Len(t, info.ModelForwardingRules, 2)
|
||||
require.Equal(t, "claude-sonnet-4-20250514", info.ModelForwardingRules["claude-3-sonnet-20240229"])
|
||||
require.Equal(t, "claude-haiku-3.5-latest", info.ModelForwardingRules["claude-3-haiku-20240307"])
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_NoDeprecatedModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"some-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{RemainingFraction: 0.9},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.Nil(t, info.ModelForwardingRules, "ModelForwardingRules should be nil when no deprecated models")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_EmptyModels(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.NotNil(t, info.AntigravityQuota)
|
||||
require.Empty(t, info.AntigravityQuota)
|
||||
require.NotNil(t, info.AntigravityQuotaDetails)
|
||||
require.Empty(t, info.AntigravityQuotaDetails)
|
||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_ModelWithNilQuotaInfo(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"model-without-quota": {
|
||||
DisplayName: "No Quota Model",
|
||||
// QuotaInfo is nil
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info)
|
||||
require.Empty(t, info.AntigravityQuota, "models with nil QuotaInfo should be skipped")
|
||||
require.Empty(t, info.AntigravityQuotaDetails, "models with nil QuotaInfo should be skipped from details too")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourPriorityOrder(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// priorityModels = ["claude-sonnet-4-20250514", "claude-sonnet-4", "gemini-2.5-pro"]
|
||||
// When the first priority model exists, it should be used for FiveHour
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.40,
|
||||
ResetTime: "2026-03-08T18:00:00Z",
|
||||
},
|
||||
},
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.80,
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info.FiveHour, "FiveHour should be set when a priority model exists")
|
||||
// claude-sonnet-4-20250514 is first in priority list, so it should be used
|
||||
expectedUtilization := (1.0 - 0.80) * 100 // 20
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
require.NotNil(t, info.FiveHour.ResetsAt, "ResetsAt should be parsed from ResetTime")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourFallbackToClaude4(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// Only claude-sonnet-4 exists (second in priority list), not claude-sonnet-4-20250514
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.60,
|
||||
ResetTime: "2026-03-08T14:00:00Z",
|
||||
},
|
||||
},
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.30,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.60) * 100 // 40
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourFallbackToGemini(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// Only gemini-2.5-pro exists (third in priority list)
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"gemini-2.5-pro": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.30,
|
||||
},
|
||||
},
|
||||
"other-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.90,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
expectedUtilization := (1.0 - 0.30) * 100 // 70
|
||||
require.InDelta(t, expectedUtilization, info.FiveHour.Utilization, 0.01)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourNoPriorityModel(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
// None of the priority models exist
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"some-other-model": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.Nil(t, info.FiveHour, "FiveHour should be nil when no priority model exists")
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FiveHourWithEmptyResetTime(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.50,
|
||||
ResetTime: "", // empty reset time
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
require.NotNil(t, info.FiveHour)
|
||||
require.Nil(t, info.FiveHour.ResetsAt, "ResetsAt should be nil when ResetTime is empty")
|
||||
require.Equal(t, 0, info.FiveHour.RemainingSeconds)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_FullUtilization(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 0.0, // fully used
|
||||
ResetTime: "2026-03-08T12:00:00Z",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
require.Equal(t, 100, quota.Utilization)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_ZeroUtilization(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{
|
||||
"claude-sonnet-4-20250514": {
|
||||
QuotaInfo: &antigravity.ModelQuotaInfo{
|
||||
RemainingFraction: 1.0, // fully available
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "", "", nil)
|
||||
quota := info.AntigravityQuota["claude-sonnet-4-20250514"]
|
||||
require.NotNil(t, quota)
|
||||
require.Equal(t, 0, quota.Utilization)
|
||||
}
|
||||
|
||||
func TestBuildUsageInfo_AICredits(t *testing.T) {
|
||||
fetcher := &AntigravityQuotaFetcher{}
|
||||
modelsResp := &antigravity.FetchAvailableModelsResponse{
|
||||
Models: map[string]antigravity.ModelInfo{},
|
||||
}
|
||||
loadResp := &antigravity.LoadCodeAssistResponse{
|
||||
PaidTier: &antigravity.PaidTierInfo{
|
||||
ID: "g1-pro-tier",
|
||||
AvailableCredits: []antigravity.AvailableCredit{
|
||||
{
|
||||
CreditType: "GOOGLE_ONE_AI",
|
||||
CreditAmount: "25",
|
||||
MinimumCreditAmountForUsage: "5",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
info := fetcher.buildUsageInfo(modelsResp, "g1-pro-tier", "PRO", loadResp)
|
||||
|
||||
require.Len(t, info.AICredits, 1)
|
||||
require.Equal(t, "GOOGLE_ONE_AI", info.AICredits[0].CreditType)
|
||||
require.Equal(t, 25.0, info.AICredits[0].Amount)
|
||||
require.Equal(t, 5.0, info.AICredits[0].MinimumBalance)
|
||||
}
|
||||
|
||||
func TestFetchQuota_ForbiddenReturnsIsForbidden(t *testing.T) {
|
||||
// 模拟 FetchQuota 遇到 403 时的行为:
|
||||
// FetchAvailableModels 返回 ForbiddenError → FetchQuota 应返回 is_forbidden=true
|
||||
forbiddenErr := &antigravity.ForbiddenError{
|
||||
StatusCode: 403,
|
||||
Body: "Access denied",
|
||||
}
|
||||
|
||||
// 验证 ForbiddenError 满足 errors.As
|
||||
var target *antigravity.ForbiddenError
|
||||
require.True(t, errors.As(forbiddenErr, &target))
|
||||
require.Equal(t, 403, target.StatusCode)
|
||||
require.Equal(t, "Access denied", target.Body)
|
||||
require.Contains(t, forbiddenErr.Error(), "403")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// classifyForbiddenType
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestClassifyForbiddenType(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "VALIDATION_REQUIRED keyword",
|
||||
body: `{"error":{"message":"VALIDATION_REQUIRED"}}`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "verify your account",
|
||||
body: `Please verify your account to continue`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "contains validation_url field",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://..."}}]}}`,
|
||||
expected: "validation",
|
||||
},
|
||||
{
|
||||
name: "terms of service violation",
|
||||
body: `Your account has been suspended for Terms of Service violation`,
|
||||
expected: "violation",
|
||||
},
|
||||
{
|
||||
name: "violation keyword",
|
||||
body: `Account suspended due to policy violation`,
|
||||
expected: "violation",
|
||||
},
|
||||
{
|
||||
name: "generic 403",
|
||||
body: `Access denied`,
|
||||
expected: "forbidden",
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
expected: "forbidden",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := classifyForbiddenType(tt.body)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// extractValidationURL
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestExtractValidationURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "structured validation_url",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://accounts.google.com/verify?token=abc"}}]}}`,
|
||||
expected: "https://accounts.google.com/verify?token=abc",
|
||||
},
|
||||
{
|
||||
name: "structured appeal_url",
|
||||
body: `{"error":{"details":[{"metadata":{"appeal_url":"https://support.google.com/appeal/123"}}]}}`,
|
||||
expected: "https://support.google.com/appeal/123",
|
||||
},
|
||||
{
|
||||
name: "validation_url takes priority over appeal_url",
|
||||
body: `{"error":{"details":[{"metadata":{"validation_url":"https://v.com","appeal_url":"https://a.com"}}]}}`,
|
||||
expected: "https://v.com",
|
||||
},
|
||||
{
|
||||
name: "fallback regex with verify keyword",
|
||||
body: `Please verify your account at https://accounts.google.com/verify`,
|
||||
expected: "https://accounts.google.com/verify",
|
||||
},
|
||||
{
|
||||
name: "no URL in generic forbidden",
|
||||
body: `Access denied`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "empty body",
|
||||
body: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "URL present but no validation keywords",
|
||||
body: `Error at https://example.com/something`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "unicode escaped ampersand",
|
||||
body: `validation required: https://accounts.google.com/verify?a=1\u0026b=2`,
|
||||
expected: "https://accounts.google.com/verify?a=1&b=2",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := extractValidationURL(tt.body)
|
||||
require.Equal(t, tt.expected, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,57 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
func normalizeAntigravityModelName(model string) string {
|
||||
normalized := strings.ToLower(strings.TrimSpace(model))
|
||||
normalized = strings.TrimPrefix(normalized, "models/")
|
||||
return normalized
|
||||
}
|
||||
|
||||
// resolveAntigravityModelKey 根据请求的模型名解析限流 key
|
||||
// 返回空字符串表示无法解析
|
||||
func resolveAntigravityModelKey(requestedModel string) string {
|
||||
return normalizeAntigravityModelName(requestedModel)
|
||||
}
|
||||
|
||||
// IsSchedulableForModel 结合模型级限流判断是否可调度。
|
||||
// 保持旧签名以兼容既有调用方;默认使用 context.Background()。
|
||||
func (a *Account) IsSchedulableForModel(requestedModel string) bool {
|
||||
return a.IsSchedulableForModelWithContext(context.Background(), requestedModel)
|
||||
}
|
||||
|
||||
func (a *Account) IsSchedulableForModelWithContext(ctx context.Context, requestedModel string) bool {
|
||||
if a == nil {
|
||||
return false
|
||||
}
|
||||
if !a.IsSchedulable() {
|
||||
return false
|
||||
}
|
||||
if a.isModelRateLimitedWithContext(ctx, requestedModel) {
|
||||
// Antigravity + overages 启用 + 积分未耗尽 → 放行(有积分可用)
|
||||
if a.Platform == PlatformAntigravity && a.IsOveragesEnabled() && !a.isCreditsExhausted() {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// GetRateLimitRemainingTime 获取限流剩余时间(模型级限流)
|
||||
// 返回 0 表示未限流或已过期
|
||||
func (a *Account) GetRateLimitRemainingTime(requestedModel string) time.Duration {
|
||||
return a.GetRateLimitRemainingTimeWithContext(context.Background(), requestedModel)
|
||||
}
|
||||
|
||||
// GetRateLimitRemainingTimeWithContext 获取限流剩余时间(模型级限流)
|
||||
// 返回 0 表示未限流或已过期
|
||||
func (a *Account) GetRateLimitRemainingTimeWithContext(ctx context.Context, requestedModel string) time.Duration {
|
||||
if a == nil {
|
||||
return 0
|
||||
}
|
||||
return a.GetModelRateLimitRemainingTimeWithContext(ctx, requestedModel)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,904 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 辅助函数:构造带 SingleAccountRetry 标记的 context
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func ctxWithSingleAccountRetry() context.Context {
|
||||
return context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 1. isSingleAccountRetry 测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestIsSingleAccountRetry_True(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, true)
|
||||
require.True(t, isSingleAccountRetry(ctx))
|
||||
}
|
||||
|
||||
func TestIsSingleAccountRetry_False_NoValue(t *testing.T) {
|
||||
require.False(t, isSingleAccountRetry(context.Background()))
|
||||
}
|
||||
|
||||
func TestIsSingleAccountRetry_False_ExplicitFalse(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, false)
|
||||
require.False(t, isSingleAccountRetry(ctx))
|
||||
}
|
||||
|
||||
func TestIsSingleAccountRetry_False_WrongType(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), ctxkey.SingleAccountRetry, "true")
|
||||
require.False(t, isSingleAccountRetry(ctx))
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 2. 常量验证
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func TestSingleAccountRetryConstants(t *testing.T) {
|
||||
require.Equal(t, 3, antigravitySingleAccountSmartRetryMaxAttempts,
|
||||
"单账号原地重试最多 3 次")
|
||||
require.Equal(t, 15*time.Second, antigravitySingleAccountSmartRetryMaxWait,
|
||||
"单次最大等待 15s")
|
||||
require.Equal(t, 30*time.Second, antigravitySingleAccountSmartRetryTotalMaxWait,
|
||||
"总累计等待不超过 30s")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 3. handleSmartRetry + 503 + SingleAccountRetry → 走 handleSingleAccountRetryInPlace
|
||||
// (而非设模型限流 + 切换账号)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace
|
||||
// 核心场景:503 + retryDelay >= 7s + SingleAccountRetry 标记
|
||||
// → 不设模型限流、不切换账号,改为原地重试
|
||||
func TestHandleSmartRetry_503_LongDelay_SingleAccountRetry_RetryInPlace(t *testing.T) {
|
||||
// 原地重试成功
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 1,
|
||||
Name: "acc-single",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
// 503 + 39s >= 7s 阈值 + MODEL_CAPACITY_EXHAUSTED
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
|
||||
],
|
||||
"message": "No capacity available for model gemini-3-pro-high on the server"
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(), // 关键:设置单账号标记
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
// 关键断言:返回 resp(原地重试成功),而非 switchError(切换账号)
|
||||
require.NotNil(t, result.resp, "should return successful response from in-place retry")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.Nil(t, result.switchError, "should NOT return switchError in single account mode")
|
||||
require.Nil(t, result.err)
|
||||
|
||||
// 验证未设模型限流(单账号模式不应设限流)
|
||||
require.Len(t, repo.modelRateLimitCalls, 0,
|
||||
"should NOT set model rate limit in single account retry mode")
|
||||
|
||||
// 验证确实调用了 upstream(原地重试)
|
||||
require.GreaterOrEqual(t, len(upstream.calls), 1, "should have made at least one retry call")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches
|
||||
// 对照组:503 + retryDelay >= 7s + 无 SingleAccountRetry 标记
|
||||
// → 照常设模型限流 + 切换账号
|
||||
func TestHandleSmartRetry_503_LongDelay_NoSingleAccountRetry_StillSwitches(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 2,
|
||||
Name: "acc-multi",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
// 503 + 39s >= 7s 阈值(使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED,
|
||||
// 因为 MODEL_CAPACITY_EXHAUSTED 走独立的重试路径,不触发 shouldRateLimitModel)
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro-high"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "39s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(), // 关键:无单账号标记
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
// 对照:多账号模式返回 switchError
|
||||
require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503")
|
||||
require.Nil(t, result.resp, "should not return resp when switchError is set")
|
||||
|
||||
// 对照:多账号模式应设模型限流
|
||||
require.Len(t, repo.modelRateLimitCalls, 1,
|
||||
"multi-account mode SHOULD set model rate limit")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches
|
||||
// 边界情况:429(非 503)+ SingleAccountRetry 标记
|
||||
// → 单账号原地重试仅针对 503,429 依然走切换账号逻辑
|
||||
func TestHandleSmartRetry_429_LongDelay_SingleAccountRetry_StillSwitches(t *testing.T) {
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 3,
|
||||
Name: "acc-429",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
// 429 + 15s >= 7s 阈值
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "claude-sonnet-4-5"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "15s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusTooManyRequests, // 429,不是 503
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(), // 有单账号标记
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
// 429 即使有单账号标记,也应走切换账号
|
||||
require.NotNil(t, result.switchError, "429 should still return switchError even with SingleAccountRetry")
|
||||
require.Len(t, repo.modelRateLimitCalls, 1,
|
||||
"429 should still set model rate limit even with SingleAccountRetry")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 4. handleSmartRetry + 503 + 短延迟 + SingleAccountRetry → 智能重试耗尽后不设限流
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit
|
||||
// 503 + retryDelay < 7s + SingleAccountRetry → 智能重试耗尽后直接返回 503,不设限流
|
||||
func TestHandleSmartRetry_503_ShortDelay_SingleAccountRetry_NoRateLimit(t *testing.T) {
|
||||
// 智能重试也返回 503
|
||||
failRespBody := `{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`
|
||||
failResp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{failResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 4,
|
||||
Name: "acc-short-503",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
// 0.1s < 7s 阈值
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
// 关键断言:单账号 503 模式下,智能重试耗尽后直接返回 503 响应,不切换
|
||||
require.NotNil(t, result.resp, "should return 503 response directly for single account mode")
|
||||
require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode)
|
||||
require.Nil(t, result.switchError, "should NOT switch account in single account mode")
|
||||
|
||||
// 关键断言:不设模型限流
|
||||
require.Len(t, repo.modelRateLimitCalls, 0,
|
||||
"should NOT set model rate limit for 503 in single account mode")
|
||||
}
|
||||
|
||||
// TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit
|
||||
// 对照组:503 + retryDelay < 7s + 无 SingleAccountRetry → 智能重试耗尽后照常设限流
|
||||
// 使用 RATE_LIMIT_EXCEEDED 而非 MODEL_CAPACITY_EXHAUSTED,因为后者走独立的 60 次重试路径
|
||||
func TestHandleSmartRetry_503_ShortDelay_NoSingleAccountRetry_SetsRateLimit(t *testing.T) {
|
||||
failRespBody := `{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`
|
||||
failResp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(failRespBody)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{failResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 5,
|
||||
Name: "acc-multi-503",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
}
|
||||
|
||||
respBody := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "RESOURCE_EXHAUSTED",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-flash"}, "reason": "RATE_LIMIT_EXCEEDED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(respBody)),
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: context.Background(), // 无单账号标记
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
availableURLs := []string{"https://ag-1.test"}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSmartRetry(params, resp, respBody, "https://ag-1.test", 0, availableURLs)
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
// 对照:多账号模式应返回 switchError
|
||||
require.NotNil(t, result.switchError, "multi-account mode should return switchError for 503")
|
||||
// 对照:多账号模式应设模型限流
|
||||
require.Len(t, repo.modelRateLimitCalls, 1,
|
||||
"multi-account mode should set model rate limit")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 5. handleSingleAccountRetryInPlace 直接测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestHandleSingleAccountRetryInPlace_Success 原地重试成功
|
||||
func TestHandleSingleAccountRetryInPlace_Success(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 10,
|
||||
Name: "acc-inplace-ok",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp, "should return successful response")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.Nil(t, result.switchError, "should not switch account on success")
|
||||
require.Nil(t, result.err)
|
||||
}
|
||||
|
||||
// TestHandleSingleAccountRetryInPlace_AllRetriesFail 所有重试都失败,返回 503(不设限流)
|
||||
func TestHandleSingleAccountRetryInPlace_AllRetriesFail(t *testing.T) {
|
||||
// 构造 3 个 503 响应(对应 3 次原地重试)
|
||||
var responses []*http.Response
|
||||
var errors []error
|
||||
for i := 0; i < antigravitySingleAccountSmartRetryMaxAttempts; i++ {
|
||||
responses = append(responses, &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`)),
|
||||
})
|
||||
errors = append(errors, nil)
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: responses,
|
||||
errors: errors,
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 11,
|
||||
Name: "acc-inplace-fail",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
origBody := []byte(`{"error":{"code":503,"status":"UNAVAILABLE"}}`)
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{"X-Test": {"original"}},
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSingleAccountRetryInPlace(params, resp, origBody, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
// 关键:返回 503 resp,不返回 switchError
|
||||
require.NotNil(t, result.resp, "should return 503 response directly")
|
||||
require.Equal(t, http.StatusServiceUnavailable, result.resp.StatusCode)
|
||||
require.Nil(t, result.switchError, "should NOT return switchError - let Handler handle it")
|
||||
require.Nil(t, result.err)
|
||||
|
||||
// 验证确实重试了指定次数
|
||||
require.Len(t, upstream.calls, antigravitySingleAccountSmartRetryMaxAttempts,
|
||||
"should have made exactly maxAttempts retry calls")
|
||||
}
|
||||
|
||||
// TestHandleSingleAccountRetryInPlace_WaitDurationClamped 等待时间被限制在 [min, max] 范围
|
||||
func TestHandleSingleAccountRetryInPlace_WaitDurationClamped(t *testing.T) {
|
||||
// 用短延迟的成功响应,只验证不 panic
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{successResp},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 12,
|
||||
Name: "acc-clamp",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
|
||||
// 等待时间过大应被 clamp 到 antigravitySingleAccountSmartRetryMaxWait
|
||||
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 999*time.Second, "gemini-3-pro")
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp)
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
}
|
||||
|
||||
// TestHandleSingleAccountRetryInPlace_ContextCanceled context 取消时立即返回
|
||||
func TestHandleSingleAccountRetryInPlace_ContextCanceled(t *testing.T) {
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{nil},
|
||||
errors: []error{nil},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 13,
|
||||
Name: "acc-cancel",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ctx = context.WithValue(ctx, ctxkey.SingleAccountRetry, true)
|
||||
cancel() // 立即取消
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctx,
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.Error(t, result.err, "should return context error")
|
||||
// 不应调用 upstream(因为在等待阶段就被取消了)
|
||||
require.Len(t, upstream.calls, 0, "should not call upstream when context is canceled")
|
||||
}
|
||||
|
||||
// TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry 网络错误时继续重试
|
||||
func TestHandleSingleAccountRetryInPlace_NetworkError_ContinuesRetry(t *testing.T) {
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
// 第1次网络错误(nil resp),第2次成功
|
||||
responses: []*http.Response{nil, successResp},
|
||||
errors: []error{nil, nil},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 14,
|
||||
Name: "acc-net-retry",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp, "should return successful response after network error recovery")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.Len(t, upstream.calls, 2, "first call fails (network error), second succeeds")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 6. antigravityRetryLoop 预检查:单账号模式跳过限流
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit
|
||||
// 预检查中,如果有 SingleAccountRetry 标记,即使账号已限流也跳过直接发请求
|
||||
func TestAntigravityRetryLoop_PreCheck_SingleAccountRetry_SkipsRateLimit(t *testing.T) {
|
||||
// 创建一个已设模型限流的账号
|
||||
upstream := &recordingOKUpstream{}
|
||||
account := &Account{
|
||||
ID: 20,
|
||||
Name: "acc-rate-limited",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err, "should not return error")
|
||||
require.NotNil(t, result, "should return result")
|
||||
require.NotNil(t, result.resp, "should have response")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
// 关键:尽管限流了,有 SingleAccountRetry 标记时仍然到达了 upstream
|
||||
require.Equal(t, 1, upstream.calls, "should have reached upstream despite rate limit")
|
||||
}
|
||||
|
||||
// TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit
|
||||
// 对照组:无 SingleAccountRetry + 已限流 → 预检查返回 switchError
|
||||
func TestAntigravityRetryLoop_PreCheck_NoSingleAccountRetry_SwitchesOnRateLimit(t *testing.T) {
|
||||
upstream := &recordingOKUpstream{}
|
||||
account := &Account{
|
||||
ID: 21,
|
||||
Name: "acc-rate-limited-multi",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
Extra: map[string]any{
|
||||
modelRateLimitsKey: map[string]any{
|
||||
"claude-sonnet-4-5": map[string]any{
|
||||
"rate_limit_reset_at": time.Now().Add(30 * time.Second).Format(time.RFC3339),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: context.Background(), // 无单账号标记
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
requestedModel: "claude-sonnet-4-5",
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.Nil(t, result, "should not return result on rate limit switch")
|
||||
require.NotNil(t, err, "should return error")
|
||||
|
||||
var switchErr *AntigravityAccountSwitchError
|
||||
require.ErrorAs(t, err, &switchErr, "should return AntigravityAccountSwitchError")
|
||||
require.Equal(t, account.ID, switchErr.OriginalAccountID)
|
||||
require.Equal(t, "claude-sonnet-4-5", switchErr.RateLimitedModel)
|
||||
|
||||
// upstream 不应被调用(预检查就短路了)
|
||||
require.Equal(t, 0, upstream.calls, "upstream should NOT be called when pre-check blocks")
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// 7. 端到端集成场景测试
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
// TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E
|
||||
// 端到端场景:503 + 单账号 + 原地重试第2次成功
|
||||
func TestHandleSmartRetry_503_SingleAccount_RetryInPlace_ThenSuccess_E2E(t *testing.T) {
|
||||
// 第1次原地重试仍返回 503,第2次成功
|
||||
fail503Body := `{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "0.1s"}
|
||||
]
|
||||
}
|
||||
}`
|
||||
resp503 := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(fail503Body)),
|
||||
}
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
responses: []*http.Response{resp503, successResp},
|
||||
errors: []error{nil, nil},
|
||||
}
|
||||
|
||||
account := &Account{
|
||||
ID: 30,
|
||||
Name: "acc-e2e",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
}
|
||||
|
||||
params := antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result := svc.handleSingleAccountRetryInPlace(params, resp, nil, "https://ag-1.test", 1*time.Second, "gemini-3-pro")
|
||||
|
||||
require.NotNil(t, result)
|
||||
require.Equal(t, smartRetryActionBreakWithResp, result.action)
|
||||
require.NotNil(t, result.resp, "should return successful response after 2nd attempt")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
require.Nil(t, result.switchError)
|
||||
require.Len(t, upstream.calls, 2, "first 503, second OK")
|
||||
}
|
||||
|
||||
// TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E
|
||||
// 通过 antigravityRetryLoop → handleSmartRetry → handleSingleAccountRetryInPlace 完整链路
|
||||
func TestAntigravityRetryLoop_503_SingleAccount_InPlaceRetryUsed_E2E(t *testing.T) {
|
||||
// 初始请求返回 503 + 长延迟
|
||||
initial503Body := []byte(`{
|
||||
"error": {
|
||||
"code": 503,
|
||||
"status": "UNAVAILABLE",
|
||||
"details": [
|
||||
{"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"model": "gemini-3-pro"}, "reason": "MODEL_CAPACITY_EXHAUSTED"},
|
||||
{"@type": "type.googleapis.com/google.rpc.RetryInfo", "retryDelay": "10s"}
|
||||
],
|
||||
"message": "No capacity available"
|
||||
}
|
||||
}`)
|
||||
initial503Resp := &http.Response{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(initial503Body)),
|
||||
}
|
||||
|
||||
// 原地重试成功
|
||||
successResp := &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader(`{"result":"ok"}`)),
|
||||
}
|
||||
|
||||
upstream := &mockSmartRetryUpstream{
|
||||
// 第1次调用(retryLoop 主循环)返回 503
|
||||
// 第2次调用(handleSingleAccountRetryInPlace 原地重试)返回 200
|
||||
responses: []*http.Response{initial503Resp, successResp},
|
||||
errors: []error{nil, nil},
|
||||
}
|
||||
|
||||
repo := &stubAntigravityAccountRepo{}
|
||||
account := &Account{
|
||||
ID: 31,
|
||||
Name: "acc-e2e-loop",
|
||||
Type: AccountTypeOAuth,
|
||||
Platform: PlatformAntigravity,
|
||||
Schedulable: true,
|
||||
Status: StatusActive,
|
||||
Concurrency: 1,
|
||||
}
|
||||
|
||||
svc := &AntigravityGatewayService{}
|
||||
result, err := svc.antigravityRetryLoop(antigravityRetryLoopParams{
|
||||
ctx: ctxWithSingleAccountRetry(),
|
||||
prefix: "[test]",
|
||||
account: account,
|
||||
accessToken: "token",
|
||||
action: "generateContent",
|
||||
body: []byte(`{"input":"test"}`),
|
||||
httpUpstream: upstream,
|
||||
accountRepo: repo,
|
||||
handleError: func(ctx context.Context, prefix string, account *Account, statusCode int, headers http.Header, body []byte, requestedModel string, groupID int64, sessionHash string, isStickySession bool) *handleModelRateLimitResult {
|
||||
return nil
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err, "should not return error on successful retry")
|
||||
require.NotNil(t, result, "should return result")
|
||||
require.NotNil(t, result.resp, "should return response")
|
||||
require.Equal(t, http.StatusOK, result.resp.StatusCode)
|
||||
|
||||
// 验证未设模型限流
|
||||
require.Len(t, repo.modelRateLimitCalls, 0,
|
||||
"should NOT set model rate limit in single account retry mode")
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,68 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestApplyThinkingModelSuffix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
mappedModel string
|
||||
thinkingEnabled bool
|
||||
expected string
|
||||
}{
|
||||
// Thinking 未开启:保持原样
|
||||
{
|
||||
name: "thinking disabled - claude-sonnet-4-5 unchanged",
|
||||
mappedModel: "claude-sonnet-4-5",
|
||||
thinkingEnabled: false,
|
||||
expected: "claude-sonnet-4-5",
|
||||
},
|
||||
{
|
||||
name: "thinking disabled - other model unchanged",
|
||||
mappedModel: "claude-opus-4-6-thinking",
|
||||
thinkingEnabled: false,
|
||||
expected: "claude-opus-4-6-thinking",
|
||||
},
|
||||
|
||||
// Thinking 开启 + claude-sonnet-4-5:自动添加后缀
|
||||
{
|
||||
name: "thinking enabled - claude-sonnet-4-5 becomes thinking version",
|
||||
mappedModel: "claude-sonnet-4-5",
|
||||
thinkingEnabled: true,
|
||||
expected: "claude-sonnet-4-5-thinking",
|
||||
},
|
||||
|
||||
// Thinking 开启 + 其他模型:保持原样
|
||||
{
|
||||
name: "thinking enabled - claude-sonnet-4-5-thinking unchanged",
|
||||
mappedModel: "claude-sonnet-4-5-thinking",
|
||||
thinkingEnabled: true,
|
||||
expected: "claude-sonnet-4-5-thinking",
|
||||
},
|
||||
{
|
||||
name: "thinking enabled - claude-opus-4-6-thinking unchanged",
|
||||
mappedModel: "claude-opus-4-6-thinking",
|
||||
thinkingEnabled: true,
|
||||
expected: "claude-opus-4-6-thinking",
|
||||
},
|
||||
{
|
||||
name: "thinking enabled - gemini model unchanged",
|
||||
mappedModel: "gemini-3-flash",
|
||||
thinkingEnabled: true,
|
||||
expected: "gemini-3-flash",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := applyThinkingModelSuffix(tt.mappedModel, tt.thinkingEnabled)
|
||||
if result != tt.expected {
|
||||
t.Errorf("applyThinkingModelSuffix(%q, %v) = %q, want %q",
|
||||
tt.mappedModel, tt.thinkingEnabled, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,185 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityTokenRefreshSkew = 3 * time.Minute
|
||||
antigravityTokenCacheSkew = 5 * time.Minute
|
||||
antigravityBackfillCooldown = 5 * time.Minute
|
||||
)
|
||||
|
||||
// AntigravityTokenCache token cache interface.
|
||||
type AntigravityTokenCache = GeminiTokenCache
|
||||
|
||||
// AntigravityTokenProvider manages access_token for antigravity accounts.
|
||||
type AntigravityTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache AntigravityTokenCache
|
||||
antigravityOAuthService *AntigravityOAuthService
|
||||
backfillCooldown sync.Map // key: accountID -> last attempt time
|
||||
refreshAPI *OAuthRefreshAPI
|
||||
executor OAuthRefreshExecutor
|
||||
refreshPolicy ProviderRefreshPolicy
|
||||
}
|
||||
|
||||
func NewAntigravityTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache AntigravityTokenCache,
|
||||
antigravityOAuthService *AntigravityOAuthService,
|
||||
) *AntigravityTokenProvider {
|
||||
return &AntigravityTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
antigravityOAuthService: antigravityOAuthService,
|
||||
refreshPolicy: AntigravityProviderRefreshPolicy(),
|
||||
}
|
||||
}
|
||||
|
||||
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
||||
func (p *AntigravityTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||
p.refreshAPI = api
|
||||
p.executor = executor
|
||||
}
|
||||
|
||||
// SetRefreshPolicy injects caller-side refresh policy.
|
||||
func (p *AntigravityTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||
p.refreshPolicy = policy
|
||||
}
|
||||
|
||||
// GetAccessToken returns a valid access_token.
|
||||
func (p *AntigravityTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAntigravity {
|
||||
return "", errors.New("not an antigravity account")
|
||||
}
|
||||
|
||||
// upstream accounts use static api_key and never refresh oauth token.
|
||||
if account.Type == AccountTypeUpstream {
|
||||
apiKey := account.GetCredential("api_key")
|
||||
if apiKey == "" {
|
||||
return "", errors.New("upstream account missing api_key in credentials")
|
||||
}
|
||||
return apiKey, nil
|
||||
}
|
||||
if account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an antigravity oauth account")
|
||||
}
|
||||
|
||||
cacheKey := AntigravityTokenCacheKey(account)
|
||||
|
||||
// 1) Try cache first.
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2) Refresh if needed (pre-expiry skew).
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= antigravityTokenRefreshSkew
|
||||
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, antigravityTokenRefreshSkew)
|
||||
if err != nil {
|
||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||
return "", err
|
||||
}
|
||||
} else if result.LockHeld {
|
||||
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
// default policy: continue with existing token.
|
||||
} else {
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
} else if needsRefresh && p.tokenCache != nil {
|
||||
// Backward-compatible test path when refreshAPI is not injected.
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// Backfill project_id online when missing, with cooldown to avoid hammering.
|
||||
if strings.TrimSpace(account.GetCredential("project_id")) == "" && p.antigravityOAuthService != nil {
|
||||
if p.shouldAttemptBackfill(account.ID) {
|
||||
p.markBackfillAttempted(account.ID)
|
||||
if projectID, err := p.antigravityOAuthService.FillProjectID(ctx, account, accessToken); err == nil && projectID != "" {
|
||||
account.Credentials["project_id"] = projectID
|
||||
if updateErr := p.accountRepo.Update(ctx, account); updateErr != nil {
|
||||
slog.Warn("antigravity_project_id_backfill_persist_failed",
|
||||
"account_id", account.ID,
|
||||
"error", updateErr,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3) Populate cache with TTL.
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
slog.Debug("antigravity_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > antigravityTokenCacheSkew:
|
||||
ttl = until - antigravityTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// shouldAttemptBackfill checks backfill cooldown.
|
||||
func (p *AntigravityTokenProvider) shouldAttemptBackfill(accountID int64) bool {
|
||||
if v, ok := p.backfillCooldown.Load(accountID); ok {
|
||||
if lastAttempt, ok := v.(time.Time); ok {
|
||||
return time.Since(lastAttempt) > antigravityBackfillCooldown
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *AntigravityTokenProvider) markBackfillAttempted(accountID int64) {
|
||||
p.backfillCooldown.Store(accountID, time.Now())
|
||||
}
|
||||
|
||||
func AntigravityTokenCacheKey(account *Account) string {
|
||||
projectID := strings.TrimSpace(account.GetCredential("project_id"))
|
||||
if projectID != "" {
|
||||
return "ag:" + projectID
|
||||
}
|
||||
return "ag:account:" + strconv.FormatInt(account.ID, 10)
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAntigravityTokenProvider_GetAccessToken_Upstream(t *testing.T) {
|
||||
provider := &AntigravityTokenProvider{}
|
||||
|
||||
t.Run("upstream account with valid api_key", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeUpstream,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "sk-test-key-12345",
|
||||
},
|
||||
}
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "sk-test-key-12345", token)
|
||||
})
|
||||
|
||||
t.Run("upstream account missing api_key", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeUpstream,
|
||||
Credentials: map[string]any{},
|
||||
}
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "upstream account missing api_key")
|
||||
require.Empty(t, token)
|
||||
})
|
||||
|
||||
t.Run("upstream account with empty api_key", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeUpstream,
|
||||
Credentials: map[string]any{
|
||||
"api_key": "",
|
||||
},
|
||||
}
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "upstream account missing api_key")
|
||||
require.Empty(t, token)
|
||||
})
|
||||
|
||||
t.Run("upstream account with nil credentials", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeUpstream,
|
||||
}
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "upstream account missing api_key")
|
||||
require.Empty(t, token)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAntigravityTokenProvider_GetAccessToken_Guards(t *testing.T) {
|
||||
provider := &AntigravityTokenProvider{}
|
||||
|
||||
t.Run("nil account", func(t *testing.T) {
|
||||
token, err := provider.GetAccessToken(context.Background(), nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "account is nil")
|
||||
require.Empty(t, token)
|
||||
})
|
||||
|
||||
t.Run("non-antigravity platform", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an antigravity account")
|
||||
require.Empty(t, token)
|
||||
})
|
||||
|
||||
t.Run("unsupported account type", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAntigravity,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an antigravity oauth account")
|
||||
require.Empty(t, token)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,90 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// antigravityRefreshWindow Antigravity token 提前刷新窗口:15分钟
|
||||
// Google OAuth token 有效期55分钟,提前15分钟刷新
|
||||
antigravityRefreshWindow = 15 * time.Minute
|
||||
)
|
||||
|
||||
// AntigravityTokenRefresher 实现 TokenRefresher 接口
|
||||
type AntigravityTokenRefresher struct {
|
||||
antigravityOAuthService *AntigravityOAuthService
|
||||
}
|
||||
|
||||
func NewAntigravityTokenRefresher(antigravityOAuthService *AntigravityOAuthService) *AntigravityTokenRefresher {
|
||||
return &AntigravityTokenRefresher{
|
||||
antigravityOAuthService: antigravityOAuthService,
|
||||
}
|
||||
}
|
||||
|
||||
// CacheKey 返回用于分布式锁的缓存键
|
||||
func (r *AntigravityTokenRefresher) CacheKey(account *Account) string {
|
||||
return AntigravityTokenCacheKey(account)
|
||||
}
|
||||
|
||||
// CanRefresh 检查是否可以刷新此账户
|
||||
func (r *AntigravityTokenRefresher) CanRefresh(account *Account) bool {
|
||||
return account.Platform == PlatformAntigravity && account.Type == AccountTypeOAuth
|
||||
}
|
||||
|
||||
// NeedsRefresh 检查账户是否需要刷新
|
||||
// Antigravity 使用固定的15分钟刷新窗口,忽略全局配置
|
||||
func (r *AntigravityTokenRefresher) NeedsRefresh(account *Account, _ time.Duration) bool {
|
||||
if !r.CanRefresh(account) {
|
||||
return false
|
||||
}
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil {
|
||||
return false
|
||||
}
|
||||
timeUntilExpiry := time.Until(*expiresAt)
|
||||
needsRefresh := timeUntilExpiry < antigravityRefreshWindow
|
||||
if needsRefresh {
|
||||
fmt.Printf("[AntigravityTokenRefresher] Account %d needs refresh: expires_at=%s, time_until_expiry=%v, window=%v\n",
|
||||
account.ID, expiresAt.Format("2006-01-02 15:04:05"), timeUntilExpiry, antigravityRefreshWindow)
|
||||
}
|
||||
return needsRefresh
|
||||
}
|
||||
|
||||
// Refresh 执行 token 刷新
|
||||
func (r *AntigravityTokenRefresher) Refresh(ctx context.Context, account *Account) (map[string]any, error) {
|
||||
tokenInfo, err := r.antigravityOAuthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
newCredentials := r.antigravityOAuthService.BuildAccountCredentials(tokenInfo)
|
||||
// 合并旧的 credentials,保留新 credentials 中不存在的字段
|
||||
newCredentials = MergeCredentials(account.Credentials, newCredentials)
|
||||
|
||||
// 特殊处理 project_id:如果新值为空但旧值非空,保留旧值
|
||||
// 这确保了即使 LoadCodeAssist 失败,project_id 也不会丢失
|
||||
if newProjectID, _ := newCredentials["project_id"].(string); newProjectID == "" {
|
||||
if oldProjectID := strings.TrimSpace(account.GetCredential("project_id")); oldProjectID != "" {
|
||||
newCredentials["project_id"] = oldProjectID
|
||||
}
|
||||
}
|
||||
|
||||
// 如果 project_id 获取失败,只记录警告,不返回错误
|
||||
// LoadCodeAssist 失败可能是临时网络问题,应该允许重试而不是立即标记为不可重试错误
|
||||
// Token 刷新本身是成功的(access_token 和 refresh_token 已更新)
|
||||
if tokenInfo.ProjectIDMissing {
|
||||
if tokenInfo.ProjectID != "" {
|
||||
// 有旧的 project_id,本次获取失败,保留旧值
|
||||
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 临时失败,保留旧 project_id", account.ID)
|
||||
} else {
|
||||
// 从未获取过 project_id,本次也失败,但不返回错误以允许下次重试
|
||||
log.Printf("[AntigravityTokenRefresher] Account %d: LoadCodeAssist 失败,project_id 缺失,但 token 已更新,将在下次刷新时重试", account.ID)
|
||||
}
|
||||
}
|
||||
|
||||
return newCredentials, nil
|
||||
}
|
||||
@@ -0,0 +1,143 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
)
|
||||
|
||||
// API Key status constants
|
||||
const (
|
||||
StatusAPIKeyActive = "active"
|
||||
StatusAPIKeyDisabled = "disabled"
|
||||
StatusAPIKeyQuotaExhausted = "quota_exhausted"
|
||||
StatusAPIKeyExpired = "expired"
|
||||
)
|
||||
|
||||
// Rate limit window durations
|
||||
const (
|
||||
RateLimitWindow5h = 5 * time.Hour
|
||||
RateLimitWindow1d = 24 * time.Hour
|
||||
RateLimitWindow7d = 7 * 24 * time.Hour
|
||||
)
|
||||
|
||||
// IsWindowExpired returns true if the window starting at windowStart has exceeded the given duration.
|
||||
// A nil windowStart is treated as expired — no initialized window means any accumulated usage is stale.
|
||||
func IsWindowExpired(windowStart *time.Time, duration time.Duration) bool {
|
||||
return windowStart == nil || time.Since(*windowStart) >= duration
|
||||
}
|
||||
|
||||
type APIKey struct {
|
||||
ID int64
|
||||
UserID int64
|
||||
Key string
|
||||
Name string
|
||||
GroupID *int64
|
||||
Status string
|
||||
IPWhitelist []string
|
||||
IPBlacklist []string
|
||||
// 预编译的 IP 规则,用于认证热路径避免重复 ParseIP/ParseCIDR。
|
||||
CompiledIPWhitelist *ip.CompiledIPRules `json:"-"`
|
||||
CompiledIPBlacklist *ip.CompiledIPRules `json:"-"`
|
||||
LastUsedAt *time.Time
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
User *User
|
||||
Group *Group
|
||||
|
||||
// Quota fields
|
||||
Quota float64 // Quota limit in USD (0 = unlimited)
|
||||
QuotaUsed float64 // Used quota amount
|
||||
ExpiresAt *time.Time // Expiration time (nil = never expires)
|
||||
|
||||
// Rate limit fields
|
||||
RateLimit5h float64 // Rate limit in USD per 5h (0 = unlimited)
|
||||
RateLimit1d float64 // Rate limit in USD per 1d (0 = unlimited)
|
||||
RateLimit7d float64 // Rate limit in USD per 7d (0 = unlimited)
|
||||
Usage5h float64 // Used amount in current 5h window
|
||||
Usage1d float64 // Used amount in current 1d window
|
||||
Usage7d float64 // Used amount in current 7d window
|
||||
Window5hStart *time.Time // Start of current 5h window
|
||||
Window1dStart *time.Time // Start of current 1d window
|
||||
Window7dStart *time.Time // Start of current 7d window
|
||||
}
|
||||
|
||||
func (k *APIKey) IsActive() bool {
|
||||
return k.Status == StatusActive
|
||||
}
|
||||
|
||||
// HasRateLimits returns true if any rate limit window is configured
|
||||
func (k *APIKey) HasRateLimits() bool {
|
||||
return k.RateLimit5h > 0 || k.RateLimit1d > 0 || k.RateLimit7d > 0
|
||||
}
|
||||
|
||||
// IsExpired checks if the API key has expired
|
||||
func (k *APIKey) IsExpired() bool {
|
||||
if k.ExpiresAt == nil {
|
||||
return false
|
||||
}
|
||||
return time.Now().After(*k.ExpiresAt)
|
||||
}
|
||||
|
||||
// IsQuotaExhausted checks if the API key quota is exhausted
|
||||
func (k *APIKey) IsQuotaExhausted() bool {
|
||||
if k.Quota <= 0 {
|
||||
return false // unlimited
|
||||
}
|
||||
return k.QuotaUsed >= k.Quota
|
||||
}
|
||||
|
||||
// GetQuotaRemaining returns remaining quota (-1 for unlimited)
|
||||
func (k *APIKey) GetQuotaRemaining() float64 {
|
||||
if k.Quota <= 0 {
|
||||
return -1 // unlimited
|
||||
}
|
||||
remaining := k.Quota - k.QuotaUsed
|
||||
if remaining < 0 {
|
||||
return 0
|
||||
}
|
||||
return remaining
|
||||
}
|
||||
|
||||
// GetDaysUntilExpiry returns days until expiry (-1 for never expires)
|
||||
func (k *APIKey) GetDaysUntilExpiry() int {
|
||||
if k.ExpiresAt == nil {
|
||||
return -1 // never expires
|
||||
}
|
||||
duration := time.Until(*k.ExpiresAt)
|
||||
if duration < 0 {
|
||||
return 0
|
||||
}
|
||||
return int(duration.Hours() / 24)
|
||||
}
|
||||
|
||||
// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired.
|
||||
func (k *APIKey) EffectiveUsage5h() float64 {
|
||||
if IsWindowExpired(k.Window5hStart, RateLimitWindow5h) {
|
||||
return 0
|
||||
}
|
||||
return k.Usage5h
|
||||
}
|
||||
|
||||
// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired.
|
||||
func (k *APIKey) EffectiveUsage1d() float64 {
|
||||
if IsWindowExpired(k.Window1dStart, RateLimitWindow1d) {
|
||||
return 0
|
||||
}
|
||||
return k.Usage1d
|
||||
}
|
||||
|
||||
// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired.
|
||||
func (k *APIKey) EffectiveUsage7d() float64 {
|
||||
if IsWindowExpired(k.Window7dStart, RateLimitWindow7d) {
|
||||
return 0
|
||||
}
|
||||
return k.Usage7d
|
||||
}
|
||||
|
||||
// APIKeyListFilters holds optional filtering parameters for listing API keys.
|
||||
type APIKeyListFilters struct {
|
||||
Search string
|
||||
Status string
|
||||
GroupID *int64 // nil=不筛选, 0=无分组, >0=指定分组
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
package service
|
||||
|
||||
import "time"
|
||||
|
||||
// APIKeyAuthSnapshot API Key 认证缓存快照(仅包含认证所需字段)
|
||||
type APIKeyAuthSnapshot struct {
|
||||
APIKeyID int64 `json:"api_key_id"`
|
||||
UserID int64 `json:"user_id"`
|
||||
GroupID *int64 `json:"group_id,omitempty"`
|
||||
Status string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist,omitempty"`
|
||||
IPBlacklist []string `json:"ip_blacklist,omitempty"`
|
||||
User APIKeyAuthUserSnapshot `json:"user"`
|
||||
Group *APIKeyAuthGroupSnapshot `json:"group,omitempty"`
|
||||
|
||||
// Quota fields for API Key independent quota feature
|
||||
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
|
||||
QuotaUsed float64 `json:"quota_used"` // Used quota amount
|
||||
|
||||
// Expiration field for API Key expiration feature
|
||||
ExpiresAt *time.Time `json:"expires_at,omitempty"` // Expiration time (nil = never expires)
|
||||
|
||||
// Rate limit configuration (only limits, not usage - usage read from Redis at check time)
|
||||
RateLimit5h float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d float64 `json:"rate_limit_7d"`
|
||||
}
|
||||
|
||||
// APIKeyAuthUserSnapshot 用户快照
|
||||
type APIKeyAuthUserSnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
Status string `json:"status"`
|
||||
Role string `json:"role"`
|
||||
Balance float64 `json:"balance"`
|
||||
Concurrency int `json:"concurrency"`
|
||||
}
|
||||
|
||||
// APIKeyAuthGroupSnapshot 分组快照
|
||||
type APIKeyAuthGroupSnapshot struct {
|
||||
ID int64 `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Platform string `json:"platform"`
|
||||
Status string `json:"status"`
|
||||
SubscriptionType string `json:"subscription_type"`
|
||||
RateMultiplier float64 `json:"rate_multiplier"`
|
||||
DailyLimitUSD *float64 `json:"daily_limit_usd,omitempty"`
|
||||
WeeklyLimitUSD *float64 `json:"weekly_limit_usd,omitempty"`
|
||||
MonthlyLimitUSD *float64 `json:"monthly_limit_usd,omitempty"`
|
||||
ImagePrice1K *float64 `json:"image_price_1k,omitempty"`
|
||||
ImagePrice2K *float64 `json:"image_price_2k,omitempty"`
|
||||
ImagePrice4K *float64 `json:"image_price_4k,omitempty"`
|
||||
SoraImagePrice360 *float64 `json:"sora_image_price_360,omitempty"`
|
||||
SoraImagePrice540 *float64 `json:"sora_image_price_540,omitempty"`
|
||||
SoraVideoPricePerRequest *float64 `json:"sora_video_price_per_request,omitempty"`
|
||||
SoraVideoPricePerRequestHD *float64 `json:"sora_video_price_per_request_hd,omitempty"`
|
||||
ClaudeCodeOnly bool `json:"claude_code_only"`
|
||||
FallbackGroupID *int64 `json:"fallback_group_id,omitempty"`
|
||||
FallbackGroupIDOnInvalidRequest *int64 `json:"fallback_group_id_on_invalid_request,omitempty"`
|
||||
|
||||
// Model routing is used by gateway account selection, so it must be part of auth cache snapshot.
|
||||
// Only anthropic groups use these fields; others may leave them empty.
|
||||
ModelRouting map[string][]int64 `json:"model_routing,omitempty"`
|
||||
ModelRoutingEnabled bool `json:"model_routing_enabled"`
|
||||
MCPXMLInject bool `json:"mcp_xml_inject"`
|
||||
|
||||
// 支持的模型系列(仅 antigravity 平台使用)
|
||||
SupportedModelScopes []string `json:"supported_model_scopes,omitempty"`
|
||||
|
||||
// OpenAI Messages 调度配置(仅 openai 平台使用)
|
||||
AllowMessagesDispatch bool `json:"allow_messages_dispatch"`
|
||||
DefaultMappedModel string `json:"default_mapped_model,omitempty"`
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheEntry 缓存条目,支持负缓存
|
||||
type APIKeyAuthCacheEntry struct {
|
||||
NotFound bool `json:"not_found"`
|
||||
Snapshot *APIKeyAuthSnapshot `json:"snapshot,omitempty"`
|
||||
}
|
||||
@@ -0,0 +1,313 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand/v2"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/dgraph-io/ristretto"
|
||||
)
|
||||
|
||||
type apiKeyAuthCacheConfig struct {
|
||||
l1Size int
|
||||
l1TTL time.Duration
|
||||
l2TTL time.Duration
|
||||
negativeTTL time.Duration
|
||||
jitterPercent int
|
||||
singleflight bool
|
||||
}
|
||||
|
||||
func newAPIKeyAuthCacheConfig(cfg *config.Config) apiKeyAuthCacheConfig {
|
||||
if cfg == nil {
|
||||
return apiKeyAuthCacheConfig{}
|
||||
}
|
||||
auth := cfg.APIKeyAuth
|
||||
return apiKeyAuthCacheConfig{
|
||||
l1Size: auth.L1Size,
|
||||
l1TTL: time.Duration(auth.L1TTLSeconds) * time.Second,
|
||||
l2TTL: time.Duration(auth.L2TTLSeconds) * time.Second,
|
||||
negativeTTL: time.Duration(auth.NegativeTTLSeconds) * time.Second,
|
||||
jitterPercent: auth.JitterPercent,
|
||||
singleflight: auth.Singleflight,
|
||||
}
|
||||
}
|
||||
|
||||
func (c apiKeyAuthCacheConfig) l1Enabled() bool {
|
||||
return c.l1Size > 0 && c.l1TTL > 0
|
||||
}
|
||||
|
||||
func (c apiKeyAuthCacheConfig) l2Enabled() bool {
|
||||
return c.l2TTL > 0
|
||||
}
|
||||
|
||||
func (c apiKeyAuthCacheConfig) negativeEnabled() bool {
|
||||
return c.negativeTTL > 0
|
||||
}
|
||||
|
||||
// jitterTTL 为缓存 TTL 添加抖动,避免多个请求在同一时刻同时过期触发集中回源。
|
||||
// 这里直接使用 rand/v2 的顶层函数:并发安全,无需全局互斥锁。
|
||||
func (c apiKeyAuthCacheConfig) jitterTTL(ttl time.Duration) time.Duration {
|
||||
if ttl <= 0 {
|
||||
return ttl
|
||||
}
|
||||
if c.jitterPercent <= 0 {
|
||||
return ttl
|
||||
}
|
||||
percent := c.jitterPercent
|
||||
if percent > 100 {
|
||||
percent = 100
|
||||
}
|
||||
delta := float64(percent) / 100
|
||||
randVal := rand.Float64()
|
||||
factor := 1 - delta + randVal*(2*delta)
|
||||
if factor <= 0 {
|
||||
return ttl
|
||||
}
|
||||
return time.Duration(float64(ttl) * factor)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) initAuthCache(cfg *config.Config) {
|
||||
s.authCfg = newAPIKeyAuthCacheConfig(cfg)
|
||||
if !s.authCfg.l1Enabled() {
|
||||
return
|
||||
}
|
||||
cache, err := ristretto.NewCache(&ristretto.Config{
|
||||
NumCounters: int64(s.authCfg.l1Size) * 10,
|
||||
MaxCost: int64(s.authCfg.l1Size),
|
||||
BufferItems: 64,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.authCacheL1 = cache
|
||||
}
|
||||
|
||||
// StartAuthCacheInvalidationSubscriber starts the Pub/Sub subscriber for L1 cache invalidation.
|
||||
// This should be called after the service is fully initialized.
|
||||
func (s *APIKeyService) StartAuthCacheInvalidationSubscriber(ctx context.Context) {
|
||||
if s.cache == nil || s.authCacheL1 == nil {
|
||||
return
|
||||
}
|
||||
if err := s.cache.SubscribeAuthCacheInvalidation(ctx, func(cacheKey string) {
|
||||
s.authCacheL1.Del(cacheKey)
|
||||
}); err != nil {
|
||||
// Log but don't fail - L1 cache will still work, just without cross-instance invalidation
|
||||
println("[Service] Warning: failed to start auth cache invalidation subscriber:", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *APIKeyService) authCacheKey(key string) string {
|
||||
sum := sha256.Sum256([]byte(key))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func (s *APIKeyService) getAuthCacheEntry(ctx context.Context, cacheKey string) (*APIKeyAuthCacheEntry, bool) {
|
||||
if s.authCacheL1 != nil {
|
||||
if val, ok := s.authCacheL1.Get(cacheKey); ok {
|
||||
if entry, ok := val.(*APIKeyAuthCacheEntry); ok {
|
||||
return entry, true
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.cache == nil || !s.authCfg.l2Enabled() {
|
||||
return nil, false
|
||||
}
|
||||
entry, err := s.cache.GetAuthCache(ctx, cacheKey)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
s.setAuthCacheL1(cacheKey, entry)
|
||||
return entry, true
|
||||
}
|
||||
|
||||
func (s *APIKeyService) setAuthCacheL1(cacheKey string, entry *APIKeyAuthCacheEntry) {
|
||||
if s.authCacheL1 == nil || entry == nil {
|
||||
return
|
||||
}
|
||||
ttl := s.authCfg.l1TTL
|
||||
if entry.NotFound && s.authCfg.negativeTTL > 0 && s.authCfg.negativeTTL < ttl {
|
||||
ttl = s.authCfg.negativeTTL
|
||||
}
|
||||
ttl = s.authCfg.jitterTTL(ttl)
|
||||
_ = s.authCacheL1.SetWithTTL(cacheKey, entry, 1, ttl)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) setAuthCacheEntry(ctx context.Context, cacheKey string, entry *APIKeyAuthCacheEntry, ttl time.Duration) {
|
||||
if entry == nil {
|
||||
return
|
||||
}
|
||||
s.setAuthCacheL1(cacheKey, entry)
|
||||
if s.cache == nil || !s.authCfg.l2Enabled() {
|
||||
return
|
||||
}
|
||||
_ = s.cache.SetAuthCache(ctx, cacheKey, entry, s.authCfg.jitterTTL(ttl))
|
||||
}
|
||||
|
||||
func (s *APIKeyService) deleteAuthCache(ctx context.Context, cacheKey string) {
|
||||
if s.authCacheL1 != nil {
|
||||
s.authCacheL1.Del(cacheKey)
|
||||
}
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
_ = s.cache.DeleteAuthCache(ctx, cacheKey)
|
||||
// Publish invalidation message to other instances
|
||||
_ = s.cache.PublishAuthCacheInvalidation(ctx, cacheKey)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) loadAuthCacheEntry(ctx context.Context, key, cacheKey string) (*APIKeyAuthCacheEntry, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrAPIKeyNotFound) {
|
||||
entry := &APIKeyAuthCacheEntry{NotFound: true}
|
||||
if s.authCfg.negativeEnabled() {
|
||||
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.negativeTTL)
|
||||
}
|
||||
return entry, nil
|
||||
}
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
apiKey.Key = key
|
||||
snapshot := s.snapshotFromAPIKey(apiKey)
|
||||
if snapshot == nil {
|
||||
return nil, fmt.Errorf("get api key: %w", ErrAPIKeyNotFound)
|
||||
}
|
||||
entry := &APIKeyAuthCacheEntry{Snapshot: snapshot}
|
||||
s.setAuthCacheEntry(ctx, cacheKey, entry, s.authCfg.l2TTL)
|
||||
return entry, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) applyAuthCacheEntry(key string, entry *APIKeyAuthCacheEntry) (*APIKey, bool, error) {
|
||||
if entry == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
if entry.NotFound {
|
||||
return nil, true, ErrAPIKeyNotFound
|
||||
}
|
||||
if entry.Snapshot == nil {
|
||||
return nil, false, nil
|
||||
}
|
||||
return s.snapshotToAPIKey(key, entry.Snapshot), true, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) snapshotFromAPIKey(apiKey *APIKey) *APIKeyAuthSnapshot {
|
||||
if apiKey == nil || apiKey.User == nil {
|
||||
return nil
|
||||
}
|
||||
snapshot := &APIKeyAuthSnapshot{
|
||||
APIKeyID: apiKey.ID,
|
||||
UserID: apiKey.UserID,
|
||||
GroupID: apiKey.GroupID,
|
||||
Status: apiKey.Status,
|
||||
IPWhitelist: apiKey.IPWhitelist,
|
||||
IPBlacklist: apiKey.IPBlacklist,
|
||||
Quota: apiKey.Quota,
|
||||
QuotaUsed: apiKey.QuotaUsed,
|
||||
ExpiresAt: apiKey.ExpiresAt,
|
||||
RateLimit5h: apiKey.RateLimit5h,
|
||||
RateLimit1d: apiKey.RateLimit1d,
|
||||
RateLimit7d: apiKey.RateLimit7d,
|
||||
User: APIKeyAuthUserSnapshot{
|
||||
ID: apiKey.User.ID,
|
||||
Status: apiKey.User.Status,
|
||||
Role: apiKey.User.Role,
|
||||
Balance: apiKey.User.Balance,
|
||||
Concurrency: apiKey.User.Concurrency,
|
||||
},
|
||||
}
|
||||
if apiKey.Group != nil {
|
||||
snapshot.Group = &APIKeyAuthGroupSnapshot{
|
||||
ID: apiKey.Group.ID,
|
||||
Name: apiKey.Group.Name,
|
||||
Platform: apiKey.Group.Platform,
|
||||
Status: apiKey.Group.Status,
|
||||
SubscriptionType: apiKey.Group.SubscriptionType,
|
||||
RateMultiplier: apiKey.Group.RateMultiplier,
|
||||
DailyLimitUSD: apiKey.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: apiKey.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: apiKey.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: apiKey.Group.ImagePrice1K,
|
||||
ImagePrice2K: apiKey.Group.ImagePrice2K,
|
||||
ImagePrice4K: apiKey.Group.ImagePrice4K,
|
||||
SoraImagePrice360: apiKey.Group.SoraImagePrice360,
|
||||
SoraImagePrice540: apiKey.Group.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: apiKey.Group.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: apiKey.Group.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: apiKey.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: apiKey.Group.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: apiKey.Group.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: apiKey.Group.ModelRouting,
|
||||
ModelRoutingEnabled: apiKey.Group.ModelRoutingEnabled,
|
||||
MCPXMLInject: apiKey.Group.MCPXMLInject,
|
||||
SupportedModelScopes: apiKey.Group.SupportedModelScopes,
|
||||
AllowMessagesDispatch: apiKey.Group.AllowMessagesDispatch,
|
||||
DefaultMappedModel: apiKey.Group.DefaultMappedModel,
|
||||
}
|
||||
}
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (s *APIKeyService) snapshotToAPIKey(key string, snapshot *APIKeyAuthSnapshot) *APIKey {
|
||||
if snapshot == nil {
|
||||
return nil
|
||||
}
|
||||
apiKey := &APIKey{
|
||||
ID: snapshot.APIKeyID,
|
||||
UserID: snapshot.UserID,
|
||||
GroupID: snapshot.GroupID,
|
||||
Key: key,
|
||||
Status: snapshot.Status,
|
||||
IPWhitelist: snapshot.IPWhitelist,
|
||||
IPBlacklist: snapshot.IPBlacklist,
|
||||
Quota: snapshot.Quota,
|
||||
QuotaUsed: snapshot.QuotaUsed,
|
||||
ExpiresAt: snapshot.ExpiresAt,
|
||||
RateLimit5h: snapshot.RateLimit5h,
|
||||
RateLimit1d: snapshot.RateLimit1d,
|
||||
RateLimit7d: snapshot.RateLimit7d,
|
||||
User: &User{
|
||||
ID: snapshot.User.ID,
|
||||
Status: snapshot.User.Status,
|
||||
Role: snapshot.User.Role,
|
||||
Balance: snapshot.User.Balance,
|
||||
Concurrency: snapshot.User.Concurrency,
|
||||
},
|
||||
}
|
||||
if snapshot.Group != nil {
|
||||
apiKey.Group = &Group{
|
||||
ID: snapshot.Group.ID,
|
||||
Name: snapshot.Group.Name,
|
||||
Platform: snapshot.Group.Platform,
|
||||
Status: snapshot.Group.Status,
|
||||
Hydrated: true,
|
||||
SubscriptionType: snapshot.Group.SubscriptionType,
|
||||
RateMultiplier: snapshot.Group.RateMultiplier,
|
||||
DailyLimitUSD: snapshot.Group.DailyLimitUSD,
|
||||
WeeklyLimitUSD: snapshot.Group.WeeklyLimitUSD,
|
||||
MonthlyLimitUSD: snapshot.Group.MonthlyLimitUSD,
|
||||
ImagePrice1K: snapshot.Group.ImagePrice1K,
|
||||
ImagePrice2K: snapshot.Group.ImagePrice2K,
|
||||
ImagePrice4K: snapshot.Group.ImagePrice4K,
|
||||
SoraImagePrice360: snapshot.Group.SoraImagePrice360,
|
||||
SoraImagePrice540: snapshot.Group.SoraImagePrice540,
|
||||
SoraVideoPricePerRequest: snapshot.Group.SoraVideoPricePerRequest,
|
||||
SoraVideoPricePerRequestHD: snapshot.Group.SoraVideoPricePerRequestHD,
|
||||
ClaudeCodeOnly: snapshot.Group.ClaudeCodeOnly,
|
||||
FallbackGroupID: snapshot.Group.FallbackGroupID,
|
||||
FallbackGroupIDOnInvalidRequest: snapshot.Group.FallbackGroupIDOnInvalidRequest,
|
||||
ModelRouting: snapshot.Group.ModelRouting,
|
||||
ModelRoutingEnabled: snapshot.Group.ModelRoutingEnabled,
|
||||
MCPXMLInject: snapshot.Group.MCPXMLInject,
|
||||
SupportedModelScopes: snapshot.Group.SupportedModelScopes,
|
||||
AllowMessagesDispatch: snapshot.Group.AllowMessagesDispatch,
|
||||
DefaultMappedModel: snapshot.Group.DefaultMappedModel,
|
||||
}
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey
|
||||
}
|
||||
@@ -0,0 +1,48 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
// InvalidateAuthCacheByKey 清除指定 API Key 的认证缓存
|
||||
func (s *APIKeyService) InvalidateAuthCacheByKey(ctx context.Context, key string) {
|
||||
if key == "" {
|
||||
return
|
||||
}
|
||||
cacheKey := s.authCacheKey(key)
|
||||
s.deleteAuthCache(ctx, cacheKey)
|
||||
}
|
||||
|
||||
// InvalidateAuthCacheByUserID 清除用户相关的 API Key 认证缓存
|
||||
func (s *APIKeyService) InvalidateAuthCacheByUserID(ctx context.Context, userID int64) {
|
||||
if userID <= 0 {
|
||||
return
|
||||
}
|
||||
keys, err := s.apiKeyRepo.ListKeysByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.deleteAuthCacheByKeys(ctx, keys)
|
||||
}
|
||||
|
||||
// InvalidateAuthCacheByGroupID 清除分组相关的 API Key 认证缓存
|
||||
func (s *APIKeyService) InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64) {
|
||||
if groupID <= 0 {
|
||||
return
|
||||
}
|
||||
keys, err := s.apiKeyRepo.ListKeysByGroupID(ctx, groupID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
s.deleteAuthCacheByKeys(ctx, keys)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) deleteAuthCacheByKeys(ctx context.Context, keys []string) {
|
||||
if len(keys) == 0 {
|
||||
return
|
||||
}
|
||||
for _, key := range keys {
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
s.deleteAuthCache(ctx, s.authCacheKey(key))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,245 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestIsWindowExpired(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
start *time.Time
|
||||
duration time.Duration
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil window start (treated as expired)",
|
||||
start: nil,
|
||||
duration: RateLimitWindow5h,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "active window (started 1h ago, 5h window)",
|
||||
start: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
duration: RateLimitWindow5h,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "expired window (started 6h ago, 5h window)",
|
||||
start: rateLimitTimePtr(now.Add(-6 * time.Hour)),
|
||||
duration: RateLimitWindow5h,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "exactly at boundary (started 5h ago, 5h window)",
|
||||
start: rateLimitTimePtr(now.Add(-5 * time.Hour)),
|
||||
duration: RateLimitWindow5h,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "active 1d window (started 12h ago)",
|
||||
start: rateLimitTimePtr(now.Add(-12 * time.Hour)),
|
||||
duration: RateLimitWindow1d,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "expired 1d window (started 25h ago)",
|
||||
start: rateLimitTimePtr(now.Add(-25 * time.Hour)),
|
||||
duration: RateLimitWindow1d,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "active 7d window (started 3d ago)",
|
||||
start: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)),
|
||||
duration: RateLimitWindow7d,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "expired 7d window (started 8d ago)",
|
||||
start: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)),
|
||||
duration: RateLimitWindow7d,
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := IsWindowExpired(tt.start, tt.duration)
|
||||
if got != tt.want {
|
||||
t.Errorf("IsWindowExpired() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKey_EffectiveUsage(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
key APIKey
|
||||
want5h float64
|
||||
want1d float64
|
||||
want7d float64
|
||||
}{
|
||||
{
|
||||
name: "all windows active",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
Usage7d: 50.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-3 * 24 * time.Hour)),
|
||||
},
|
||||
want5h: 5.0,
|
||||
want1d: 10.0,
|
||||
want7d: 50.0,
|
||||
},
|
||||
{
|
||||
name: "all windows expired",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
Usage7d: 50.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-25 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-8 * 24 * time.Hour)),
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "nil window starts return 0 (stale usage reset)",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
Usage7d: 50.0,
|
||||
Window5hStart: nil,
|
||||
Window1dStart: nil,
|
||||
Window7dStart: nil,
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "mixed: 5h expired, 1d active, 7d nil",
|
||||
key: APIKey{
|
||||
Usage5h: 5.0,
|
||||
Usage1d: 10.0,
|
||||
Usage7d: 50.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-6 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-12 * time.Hour)),
|
||||
Window7dStart: nil,
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 10.0,
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "zero usage with active windows",
|
||||
key: APIKey{
|
||||
Usage5h: 0,
|
||||
Usage1d: 0,
|
||||
Usage7d: 0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-1 * time.Hour)),
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.key.EffectiveUsage5h(); got != tt.want5h {
|
||||
t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h)
|
||||
}
|
||||
if got := tt.key.EffectiveUsage1d(); got != tt.want1d {
|
||||
t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d)
|
||||
}
|
||||
if got := tt.key.EffectiveUsage7d(); got != tt.want7d {
|
||||
t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyRateLimitData_EffectiveUsage(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
data APIKeyRateLimitData
|
||||
want5h float64
|
||||
want1d float64
|
||||
want7d float64
|
||||
}{
|
||||
{
|
||||
name: "all windows active",
|
||||
data: APIKeyRateLimitData{
|
||||
Usage5h: 3.0,
|
||||
Usage1d: 8.0,
|
||||
Usage7d: 40.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-2 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-10 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-2 * 24 * time.Hour)),
|
||||
},
|
||||
want5h: 3.0,
|
||||
want1d: 8.0,
|
||||
want7d: 40.0,
|
||||
},
|
||||
{
|
||||
name: "all windows expired",
|
||||
data: APIKeyRateLimitData{
|
||||
Usage5h: 3.0,
|
||||
Usage1d: 8.0,
|
||||
Usage7d: 40.0,
|
||||
Window5hStart: rateLimitTimePtr(now.Add(-10 * time.Hour)),
|
||||
Window1dStart: rateLimitTimePtr(now.Add(-48 * time.Hour)),
|
||||
Window7dStart: rateLimitTimePtr(now.Add(-10 * 24 * time.Hour)),
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
{
|
||||
name: "nil window starts return 0 (stale usage reset)",
|
||||
data: APIKeyRateLimitData{
|
||||
Usage5h: 3.0,
|
||||
Usage1d: 8.0,
|
||||
Usage7d: 40.0,
|
||||
Window5hStart: nil,
|
||||
Window1dStart: nil,
|
||||
Window7dStart: nil,
|
||||
},
|
||||
want5h: 0,
|
||||
want1d: 0,
|
||||
want7d: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := tt.data.EffectiveUsage5h(); got != tt.want5h {
|
||||
t.Errorf("EffectiveUsage5h() = %v, want %v", got, tt.want5h)
|
||||
}
|
||||
if got := tt.data.EffectiveUsage1d(); got != tt.want1d {
|
||||
t.Errorf("EffectiveUsage1d() = %v, want %v", got, tt.want1d)
|
||||
}
|
||||
if got := tt.data.EffectiveUsage7d(); got != tt.want7d {
|
||||
t.Errorf("EffectiveUsage7d() = %v, want %v", got, tt.want7d)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func rateLimitTimePtr(t time.Time) *time.Time {
|
||||
return &t
|
||||
}
|
||||
@@ -0,0 +1,881 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/timezone"
|
||||
"github.com/dgraph-io/ristretto"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrAPIKeyNotFound = infraerrors.NotFound("API_KEY_NOT_FOUND", "api key not found")
|
||||
ErrGroupNotAllowed = infraerrors.Forbidden("GROUP_NOT_ALLOWED", "user is not allowed to bind this group")
|
||||
ErrAPIKeyExists = infraerrors.Conflict("API_KEY_EXISTS", "api key already exists")
|
||||
ErrAPIKeyTooShort = infraerrors.BadRequest("API_KEY_TOO_SHORT", "api key must be at least 16 characters")
|
||||
ErrAPIKeyInvalidChars = infraerrors.BadRequest("API_KEY_INVALID_CHARS", "api key can only contain letters, numbers, underscores, and hyphens")
|
||||
ErrAPIKeyRateLimited = infraerrors.TooManyRequests("API_KEY_RATE_LIMITED", "too many failed attempts, please try again later")
|
||||
ErrInvalidIPPattern = infraerrors.BadRequest("INVALID_IP_PATTERN", "invalid IP or CIDR pattern")
|
||||
// ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key has expired")
|
||||
ErrAPIKeyExpired = infraerrors.Forbidden("API_KEY_EXPIRED", "api key 已过期")
|
||||
// ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key quota exhausted")
|
||||
ErrAPIKeyQuotaExhausted = infraerrors.TooManyRequests("API_KEY_QUOTA_EXHAUSTED", "api key 额度已用完")
|
||||
|
||||
// Rate limit errors
|
||||
ErrAPIKeyRateLimit5hExceeded = infraerrors.TooManyRequests("API_KEY_RATE_5H_EXCEEDED", "api key 5小时限额已用完")
|
||||
ErrAPIKeyRateLimit1dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_1D_EXCEEDED", "api key 日限额已用完")
|
||||
ErrAPIKeyRateLimit7dExceeded = infraerrors.TooManyRequests("API_KEY_RATE_7D_EXCEEDED", "api key 7天限额已用完")
|
||||
)
|
||||
|
||||
const (
|
||||
apiKeyMaxErrorsPerHour = 20
|
||||
apiKeyLastUsedMinTouch = 30 * time.Second
|
||||
// DB 写失败后的短退避,避免请求路径持续同步重试造成写风暴与高延迟。
|
||||
apiKeyLastUsedFailBackoff = 5 * time.Second
|
||||
)
|
||||
|
||||
type APIKeyRepository interface {
|
||||
Create(ctx context.Context, key *APIKey) error
|
||||
GetByID(ctx context.Context, id int64) (*APIKey, error)
|
||||
// GetKeyAndOwnerID 仅获取 API Key 的 key 与所有者 ID,用于删除等轻量场景
|
||||
GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error)
|
||||
GetByKey(ctx context.Context, key string) (*APIKey, error)
|
||||
// GetByKeyForAuth 认证专用查询,返回最小字段集
|
||||
GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error)
|
||||
Update(ctx context.Context, key *APIKey) error
|
||||
Delete(ctx context.Context, id int64) error
|
||||
|
||||
ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error)
|
||||
VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error)
|
||||
CountByUserID(ctx context.Context, userID int64) (int64, error)
|
||||
ExistsByKey(ctx context.Context, key string) (bool, error)
|
||||
ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error)
|
||||
SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error)
|
||||
ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
CountByGroupID(ctx context.Context, groupID int64) (int64, error)
|
||||
ListKeysByUserID(ctx context.Context, userID int64) ([]string, error)
|
||||
ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error)
|
||||
|
||||
// Quota methods
|
||||
IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error)
|
||||
UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error
|
||||
|
||||
// Rate limit methods
|
||||
IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error
|
||||
ResetRateLimitWindows(ctx context.Context, id int64) error
|
||||
GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error)
|
||||
}
|
||||
|
||||
// APIKeyRateLimitData holds rate limit usage and window state for an API key.
|
||||
type APIKeyRateLimitData struct {
|
||||
Usage5h float64
|
||||
Usage1d float64
|
||||
Usage7d float64
|
||||
Window5hStart *time.Time
|
||||
Window1dStart *time.Time
|
||||
Window7dStart *time.Time
|
||||
}
|
||||
|
||||
// EffectiveUsage5h returns the 5h window usage, or 0 if the window has expired.
|
||||
func (d *APIKeyRateLimitData) EffectiveUsage5h() float64 {
|
||||
if IsWindowExpired(d.Window5hStart, RateLimitWindow5h) {
|
||||
return 0
|
||||
}
|
||||
return d.Usage5h
|
||||
}
|
||||
|
||||
// EffectiveUsage1d returns the 1d window usage, or 0 if the window has expired.
|
||||
func (d *APIKeyRateLimitData) EffectiveUsage1d() float64 {
|
||||
if IsWindowExpired(d.Window1dStart, RateLimitWindow1d) {
|
||||
return 0
|
||||
}
|
||||
return d.Usage1d
|
||||
}
|
||||
|
||||
// EffectiveUsage7d returns the 7d window usage, or 0 if the window has expired.
|
||||
func (d *APIKeyRateLimitData) EffectiveUsage7d() float64 {
|
||||
if IsWindowExpired(d.Window7dStart, RateLimitWindow7d) {
|
||||
return 0
|
||||
}
|
||||
return d.Usage7d
|
||||
}
|
||||
|
||||
// APIKeyQuotaUsageState captures the latest quota fields after an atomic quota update.
|
||||
// It is intentionally small so repositories can return it from a single SQL statement.
|
||||
type APIKeyQuotaUsageState struct {
|
||||
QuotaUsed float64
|
||||
Quota float64
|
||||
Key string
|
||||
Status string
|
||||
}
|
||||
|
||||
// APIKeyCache defines cache operations for API key service
|
||||
type APIKeyCache interface {
|
||||
GetCreateAttemptCount(ctx context.Context, userID int64) (int, error)
|
||||
IncrementCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
DeleteCreateAttemptCount(ctx context.Context, userID int64) error
|
||||
|
||||
IncrementDailyUsage(ctx context.Context, apiKey string) error
|
||||
SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error
|
||||
|
||||
GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||
SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error
|
||||
DeleteAuthCache(ctx context.Context, key string) error
|
||||
|
||||
// Pub/Sub for L1 cache invalidation across instances
|
||||
PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error
|
||||
SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error
|
||||
}
|
||||
|
||||
// APIKeyAuthCacheInvalidator 提供认证缓存失效能力
|
||||
type APIKeyAuthCacheInvalidator interface {
|
||||
InvalidateAuthCacheByKey(ctx context.Context, key string)
|
||||
InvalidateAuthCacheByUserID(ctx context.Context, userID int64)
|
||||
InvalidateAuthCacheByGroupID(ctx context.Context, groupID int64)
|
||||
}
|
||||
|
||||
// CreateAPIKeyRequest 创建API Key请求
|
||||
type CreateAPIKeyRequest struct {
|
||||
Name string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
CustomKey *string `json:"custom_key"` // 可选的自定义key
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单
|
||||
|
||||
// Quota fields
|
||||
Quota float64 `json:"quota"` // Quota limit in USD (0 = unlimited)
|
||||
ExpiresInDays *int `json:"expires_in_days"` // Days until expiry (nil = never expires)
|
||||
|
||||
// Rate limit fields (0 = unlimited)
|
||||
RateLimit5h float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d float64 `json:"rate_limit_7d"`
|
||||
}
|
||||
|
||||
// UpdateAPIKeyRequest 更新API Key请求
|
||||
type UpdateAPIKeyRequest struct {
|
||||
Name *string `json:"name"`
|
||||
GroupID *int64 `json:"group_id"`
|
||||
Status *string `json:"status"`
|
||||
IPWhitelist []string `json:"ip_whitelist"` // IP 白名单(空数组清空)
|
||||
IPBlacklist []string `json:"ip_blacklist"` // IP 黑名单(空数组清空)
|
||||
|
||||
// Quota fields
|
||||
Quota *float64 `json:"quota"` // Quota limit in USD (nil = no change, 0 = unlimited)
|
||||
ExpiresAt *time.Time `json:"expires_at"` // Expiration time (nil = no change)
|
||||
ClearExpiration bool `json:"-"` // Clear expiration (internal use)
|
||||
ResetQuota *bool `json:"reset_quota"` // Reset quota_used to 0
|
||||
|
||||
// Rate limit fields (nil = no change, 0 = unlimited)
|
||||
RateLimit5h *float64 `json:"rate_limit_5h"`
|
||||
RateLimit1d *float64 `json:"rate_limit_1d"`
|
||||
RateLimit7d *float64 `json:"rate_limit_7d"`
|
||||
ResetRateLimitUsage *bool `json:"reset_rate_limit_usage"` // Reset all usage counters to 0
|
||||
}
|
||||
|
||||
// APIKeyService API Key服务
|
||||
// RateLimitCacheInvalidator invalidates rate limit cache entries on manual reset.
|
||||
type RateLimitCacheInvalidator interface {
|
||||
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
|
||||
}
|
||||
|
||||
type APIKeyService struct {
|
||||
apiKeyRepo APIKeyRepository
|
||||
userRepo UserRepository
|
||||
groupRepo GroupRepository
|
||||
userSubRepo UserSubscriptionRepository
|
||||
userGroupRateRepo UserGroupRateRepository
|
||||
cache APIKeyCache
|
||||
rateLimitCacheInvalid RateLimitCacheInvalidator // optional: invalidate Redis rate limit cache
|
||||
cfg *config.Config
|
||||
authCacheL1 *ristretto.Cache
|
||||
authCfg apiKeyAuthCacheConfig
|
||||
authGroup singleflight.Group
|
||||
lastUsedTouchL1 sync.Map // keyID -> nextAllowedAt(time.Time)
|
||||
lastUsedTouchSF singleflight.Group
|
||||
}
|
||||
|
||||
// NewAPIKeyService 创建API Key服务实例
|
||||
func NewAPIKeyService(
|
||||
apiKeyRepo APIKeyRepository,
|
||||
userRepo UserRepository,
|
||||
groupRepo GroupRepository,
|
||||
userSubRepo UserSubscriptionRepository,
|
||||
userGroupRateRepo UserGroupRateRepository,
|
||||
cache APIKeyCache,
|
||||
cfg *config.Config,
|
||||
) *APIKeyService {
|
||||
svc := &APIKeyService{
|
||||
apiKeyRepo: apiKeyRepo,
|
||||
userRepo: userRepo,
|
||||
groupRepo: groupRepo,
|
||||
userSubRepo: userSubRepo,
|
||||
userGroupRateRepo: userGroupRateRepo,
|
||||
cache: cache,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.initAuthCache(cfg)
|
||||
return svc
|
||||
}
|
||||
|
||||
// SetRateLimitCacheInvalidator sets the optional rate limit cache invalidator.
|
||||
// Called after construction (e.g. in wire) to avoid circular dependencies.
|
||||
func (s *APIKeyService) SetRateLimitCacheInvalidator(inv RateLimitCacheInvalidator) {
|
||||
s.rateLimitCacheInvalid = inv
|
||||
}
|
||||
|
||||
func (s *APIKeyService) compileAPIKeyIPRules(apiKey *APIKey) {
|
||||
if apiKey == nil {
|
||||
return
|
||||
}
|
||||
apiKey.CompiledIPWhitelist = ip.CompileIPRules(apiKey.IPWhitelist)
|
||||
apiKey.CompiledIPBlacklist = ip.CompileIPRules(apiKey.IPBlacklist)
|
||||
}
|
||||
|
||||
// GenerateKey 生成随机API Key
|
||||
func (s *APIKeyService) GenerateKey() (string, error) {
|
||||
// 生成32字节随机数据
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("generate random bytes: %w", err)
|
||||
}
|
||||
|
||||
// 转换为十六进制字符串并添加前缀
|
||||
prefix := s.cfg.Default.APIKeyPrefix
|
||||
if prefix == "" {
|
||||
prefix = "sk-"
|
||||
}
|
||||
|
||||
key := prefix + hex.EncodeToString(bytes)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// ValidateCustomKey 验证自定义API Key格式
|
||||
func (s *APIKeyService) ValidateCustomKey(key string) error {
|
||||
// 检查长度
|
||||
if len(key) < 16 {
|
||||
return ErrAPIKeyTooShort
|
||||
}
|
||||
|
||||
// 检查字符:只允许字母、数字、下划线、连字符
|
||||
for _, c := range key {
|
||||
if (c >= 'a' && c <= 'z') ||
|
||||
(c >= 'A' && c <= 'Z') ||
|
||||
(c >= '0' && c <= '9') ||
|
||||
c == '_' || c == '-' {
|
||||
continue
|
||||
}
|
||||
return ErrAPIKeyInvalidChars
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkAPIKeyRateLimit 检查用户创建自定义Key的错误次数是否超限
|
||||
func (s *APIKeyService) checkAPIKeyRateLimit(ctx context.Context, userID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
count, err := s.cache.GetCreateAttemptCount(ctx, userID)
|
||||
if err != nil {
|
||||
// Redis 出错时不阻止用户操作
|
||||
return nil
|
||||
}
|
||||
|
||||
if count >= apiKeyMaxErrorsPerHour {
|
||||
return ErrAPIKeyRateLimited
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// incrementAPIKeyErrorCount 增加用户创建自定义Key的错误计数
|
||||
func (s *APIKeyService) incrementAPIKeyErrorCount(ctx context.Context, userID int64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
_ = s.cache.IncrementCreateAttemptCount(ctx, userID)
|
||||
}
|
||||
|
||||
// canUserBindGroup 检查用户是否可以绑定指定分组
|
||||
// 对于订阅类型分组:检查用户是否有有效订阅
|
||||
// 对于标准类型分组:使用原有的 AllowedGroups 和 IsExclusive 逻辑
|
||||
func (s *APIKeyService) canUserBindGroup(ctx context.Context, user *User, group *Group) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
_, err := s.userSubRepo.GetActiveByUserIDAndGroupID(ctx, user.ID, group.ID)
|
||||
return err == nil // 有有效订阅则允许
|
||||
}
|
||||
// 标准类型分组:使用原有逻辑
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
|
||||
// Create 创建API Key
|
||||
func (s *APIKeyService) Create(ctx context.Context, userID int64, req CreateAPIKeyRequest) (*APIKey, error) {
|
||||
// 验证用户存在
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 验证 IP 白名单格式
|
||||
if len(req.IPWhitelist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 IP 黑名单格式
|
||||
if len(req.IPBlacklist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证分组权限(如果指定了分组)
|
||||
if req.GroupID != nil {
|
||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
// 检查用户是否可以绑定该分组
|
||||
if !s.canUserBindGroup(ctx, user, group) {
|
||||
return nil, ErrGroupNotAllowed
|
||||
}
|
||||
}
|
||||
|
||||
var key string
|
||||
|
||||
// 判断是否使用自定义Key
|
||||
if req.CustomKey != nil && *req.CustomKey != "" {
|
||||
// 检查限流(仅对自定义key进行限流)
|
||||
if err := s.checkAPIKeyRateLimit(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证自定义Key格式
|
||||
if err := s.ValidateCustomKey(*req.CustomKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 检查Key是否已存在
|
||||
exists, err := s.apiKeyRepo.ExistsByKey(ctx, *req.CustomKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check key exists: %w", err)
|
||||
}
|
||||
if exists {
|
||||
// Key已存在,增加错误计数
|
||||
s.incrementAPIKeyErrorCount(ctx, userID)
|
||||
return nil, ErrAPIKeyExists
|
||||
}
|
||||
|
||||
key = *req.CustomKey
|
||||
} else {
|
||||
// 生成随机API Key
|
||||
var err error
|
||||
key, err = s.GenerateKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate key: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 创建API Key记录
|
||||
apiKey := &APIKey{
|
||||
UserID: userID,
|
||||
Key: key,
|
||||
Name: req.Name,
|
||||
GroupID: req.GroupID,
|
||||
Status: StatusActive,
|
||||
IPWhitelist: req.IPWhitelist,
|
||||
IPBlacklist: req.IPBlacklist,
|
||||
Quota: req.Quota,
|
||||
QuotaUsed: 0,
|
||||
RateLimit5h: req.RateLimit5h,
|
||||
RateLimit1d: req.RateLimit1d,
|
||||
RateLimit7d: req.RateLimit7d,
|
||||
}
|
||||
|
||||
// Set expiration time if specified
|
||||
if req.ExpiresInDays != nil && *req.ExpiresInDays > 0 {
|
||||
expiresAt := time.Now().AddDate(0, 0, *req.ExpiresInDays)
|
||||
apiKey.ExpiresAt = &expiresAt
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Create(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("create api key: %w", err)
|
||||
}
|
||||
|
||||
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// List 获取用户的API Key列表
|
||||
func (s *APIKeyService) List(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
keys, pagination, err := s.apiKeyRepo.ListByUserID(ctx, userID, params, filters)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("list api keys: %w", err)
|
||||
}
|
||||
return keys, pagination, nil
|
||||
}
|
||||
|
||||
func (s *APIKeyService) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
if len(apiKeyIDs) == 0 {
|
||||
return []int64{}, nil
|
||||
}
|
||||
|
||||
validIDs, err := s.apiKeyRepo.VerifyOwnership(ctx, userID, apiKeyIDs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("verify api key ownership: %w", err)
|
||||
}
|
||||
return validIDs, nil
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取API Key
|
||||
func (s *APIKeyService) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// GetByKey 根据Key字符串获取API Key(用于认证)
|
||||
func (s *APIKeyService) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
cacheKey := s.authCacheKey(key)
|
||||
|
||||
if entry, ok := s.getAuthCacheEntry(ctx, cacheKey); ok {
|
||||
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
if s.authCfg.singleflight {
|
||||
value, err, _ := s.authGroup.Do(cacheKey, func() (any, error) {
|
||||
return s.loadAuthCacheEntry(ctx, key, cacheKey)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entry, _ := value.(*APIKeyAuthCacheEntry)
|
||||
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
} else {
|
||||
entry, err := s.loadAuthCacheEntry(ctx, key, cacheKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if apiKey, used, err := s.applyAuthCacheEntry(key, entry); used {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
}
|
||||
|
||||
apiKey, err := s.apiKeyRepo.GetByKeyForAuth(ctx, key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
apiKey.Key = key
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// Update 更新API Key
|
||||
func (s *APIKeyService) Update(ctx context.Context, id int64, userID int64, req UpdateAPIKeyRequest) (*APIKey, error) {
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
// 验证所有权
|
||||
if apiKey.UserID != userID {
|
||||
return nil, ErrInsufficientPerms
|
||||
}
|
||||
|
||||
// 验证 IP 白名单格式
|
||||
if len(req.IPWhitelist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPWhitelist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证 IP 黑名单格式
|
||||
if len(req.IPBlacklist) > 0 {
|
||||
if invalid := ip.ValidateIPPatterns(req.IPBlacklist); len(invalid) > 0 {
|
||||
return nil, fmt.Errorf("%w: %v", ErrInvalidIPPattern, invalid)
|
||||
}
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != nil {
|
||||
apiKey.Name = *req.Name
|
||||
}
|
||||
|
||||
if req.GroupID != nil {
|
||||
// 验证分组权限
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
group, err := s.groupRepo.GetByID(ctx, *req.GroupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group: %w", err)
|
||||
}
|
||||
|
||||
if !s.canUserBindGroup(ctx, user, group) {
|
||||
return nil, ErrGroupNotAllowed
|
||||
}
|
||||
|
||||
apiKey.GroupID = req.GroupID
|
||||
}
|
||||
|
||||
if req.Status != nil {
|
||||
apiKey.Status = *req.Status
|
||||
// 如果状态改变,清除Redis缓存
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteCreateAttemptCount(ctx, apiKey.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
// Update quota fields
|
||||
if req.Quota != nil {
|
||||
apiKey.Quota = *req.Quota
|
||||
// If quota is increased and status was quota_exhausted, reactivate
|
||||
if apiKey.Status == StatusAPIKeyQuotaExhausted && *req.Quota > apiKey.QuotaUsed {
|
||||
apiKey.Status = StatusActive
|
||||
}
|
||||
}
|
||||
if req.ResetQuota != nil && *req.ResetQuota {
|
||||
apiKey.QuotaUsed = 0
|
||||
// If resetting quota and status was quota_exhausted, reactivate
|
||||
if apiKey.Status == StatusAPIKeyQuotaExhausted {
|
||||
apiKey.Status = StatusActive
|
||||
}
|
||||
}
|
||||
if req.ClearExpiration {
|
||||
apiKey.ExpiresAt = nil
|
||||
// If clearing expiry and status was expired, reactivate
|
||||
if apiKey.Status == StatusAPIKeyExpired {
|
||||
apiKey.Status = StatusActive
|
||||
}
|
||||
} else if req.ExpiresAt != nil {
|
||||
apiKey.ExpiresAt = req.ExpiresAt
|
||||
// If extending expiry and status was expired, reactivate
|
||||
if apiKey.Status == StatusAPIKeyExpired && time.Now().Before(*req.ExpiresAt) {
|
||||
apiKey.Status = StatusActive
|
||||
}
|
||||
}
|
||||
|
||||
// 更新 IP 限制(空数组会清空设置)
|
||||
apiKey.IPWhitelist = req.IPWhitelist
|
||||
apiKey.IPBlacklist = req.IPBlacklist
|
||||
|
||||
// Update rate limit configuration
|
||||
if req.RateLimit5h != nil {
|
||||
apiKey.RateLimit5h = *req.RateLimit5h
|
||||
}
|
||||
if req.RateLimit1d != nil {
|
||||
apiKey.RateLimit1d = *req.RateLimit1d
|
||||
}
|
||||
if req.RateLimit7d != nil {
|
||||
apiKey.RateLimit7d = *req.RateLimit7d
|
||||
}
|
||||
resetRateLimit := req.ResetRateLimitUsage != nil && *req.ResetRateLimitUsage
|
||||
if resetRateLimit {
|
||||
apiKey.Usage5h = 0
|
||||
apiKey.Usage1d = 0
|
||||
apiKey.Usage7d = 0
|
||||
apiKey.Window5hStart = nil
|
||||
apiKey.Window1dStart = nil
|
||||
apiKey.Window7dStart = nil
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil, fmt.Errorf("update api key: %w", err)
|
||||
}
|
||||
|
||||
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
s.compileAPIKeyIPRules(apiKey)
|
||||
|
||||
// Invalidate Redis rate limit cache so reset takes effect immediately
|
||||
if resetRateLimit && s.rateLimitCacheInvalid != nil {
|
||||
_ = s.rateLimitCacheInvalid.InvalidateAPIKeyRateLimit(ctx, apiKey.ID)
|
||||
}
|
||||
|
||||
return apiKey, nil
|
||||
}
|
||||
|
||||
// Delete 删除API Key
|
||||
func (s *APIKeyService) Delete(ctx context.Context, id int64, userID int64) error {
|
||||
key, ownerID, err := s.apiKeyRepo.GetKeyAndOwnerID(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get api key: %w", err)
|
||||
}
|
||||
|
||||
// 验证当前用户是否为该 API Key 的所有者
|
||||
if ownerID != userID {
|
||||
return ErrInsufficientPerms
|
||||
}
|
||||
|
||||
// 清除Redis缓存(使用 userID 而非 apiKey.UserID)
|
||||
if s.cache != nil {
|
||||
_ = s.cache.DeleteCreateAttemptCount(ctx, userID)
|
||||
}
|
||||
s.InvalidateAuthCacheByKey(ctx, key)
|
||||
|
||||
if err := s.apiKeyRepo.Delete(ctx, id); err != nil {
|
||||
return fmt.Errorf("delete api key: %w", err)
|
||||
}
|
||||
s.lastUsedTouchL1.Delete(id)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ValidateKey 验证API Key是否有效(用于认证中间件)
|
||||
func (s *APIKeyService) ValidateKey(ctx context.Context, key string) (*APIKey, *User, error) {
|
||||
// 获取API Key
|
||||
apiKey, err := s.GetByKey(ctx, key)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
// 检查API Key状态
|
||||
if !apiKey.IsActive() {
|
||||
return nil, nil, infraerrors.Unauthorized("API_KEY_INACTIVE", "api key is not active")
|
||||
}
|
||||
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, apiKey.UserID)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 检查用户状态
|
||||
if !user.IsActive() {
|
||||
return nil, nil, ErrUserNotActive
|
||||
}
|
||||
|
||||
return apiKey, user, nil
|
||||
}
|
||||
|
||||
// TouchLastUsed 通过防抖更新 api_keys.last_used_at,减少高频写放大。
|
||||
// 该操作为尽力而为,不应阻塞主请求链路。
|
||||
func (s *APIKeyService) TouchLastUsed(ctx context.Context, keyID int64) error {
|
||||
if keyID <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
|
||||
if nextAllowedAt, ok := v.(time.Time); ok && now.Before(nextAllowedAt) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
_, err, _ := s.lastUsedTouchSF.Do(strconv.FormatInt(keyID, 10), func() (any, error) {
|
||||
latest := time.Now()
|
||||
if v, ok := s.lastUsedTouchL1.Load(keyID); ok {
|
||||
if nextAllowedAt, ok := v.(time.Time); ok && latest.Before(nextAllowedAt) {
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
if err := s.apiKeyRepo.UpdateLastUsed(ctx, keyID, latest); err != nil {
|
||||
s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedFailBackoff))
|
||||
return nil, fmt.Errorf("touch api key last used: %w", err)
|
||||
}
|
||||
s.lastUsedTouchL1.Store(keyID, latest.Add(apiKeyLastUsedMinTouch))
|
||||
return nil, nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// IncrementUsage 增加API Key使用次数(可选:用于统计)
|
||||
func (s *APIKeyService) IncrementUsage(ctx context.Context, keyID int64) error {
|
||||
// 使用Redis计数器
|
||||
if s.cache != nil {
|
||||
cacheKey := fmt.Sprintf("apikey:usage:%d:%s", keyID, timezone.Now().Format("2006-01-02"))
|
||||
if err := s.cache.IncrementDailyUsage(ctx, cacheKey); err != nil {
|
||||
return fmt.Errorf("increment usage: %w", err)
|
||||
}
|
||||
// 设置24小时过期
|
||||
_ = s.cache.SetDailyUsageExpiry(ctx, cacheKey, 24*time.Hour)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetAvailableGroups 获取用户有权限绑定的分组列表
|
||||
// 返回用户可以选择的分组:
|
||||
// - 标准类型分组:公开的(非专属)或用户被明确允许的
|
||||
// - 订阅类型分组:用户有有效订阅的
|
||||
func (s *APIKeyService) GetAvailableGroups(ctx context.Context, userID int64) ([]Group, error) {
|
||||
// 获取用户信息
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user: %w", err)
|
||||
}
|
||||
|
||||
// 获取所有活跃分组
|
||||
allGroups, err := s.groupRepo.ListActive(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active groups: %w", err)
|
||||
}
|
||||
|
||||
// 获取用户的所有有效订阅
|
||||
activeSubscriptions, err := s.userSubRepo.ListActiveByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list active subscriptions: %w", err)
|
||||
}
|
||||
|
||||
// 构建订阅分组 ID 集合
|
||||
subscribedGroupIDs := make(map[int64]bool)
|
||||
for _, sub := range activeSubscriptions {
|
||||
subscribedGroupIDs[sub.GroupID] = true
|
||||
}
|
||||
|
||||
// 过滤出用户有权限的分组
|
||||
availableGroups := make([]Group, 0)
|
||||
for _, group := range allGroups {
|
||||
if s.canUserBindGroupInternal(user, &group, subscribedGroupIDs) {
|
||||
availableGroups = append(availableGroups, group)
|
||||
}
|
||||
}
|
||||
|
||||
return availableGroups, nil
|
||||
}
|
||||
|
||||
// canUserBindGroupInternal 内部方法,检查用户是否可以绑定分组(使用预加载的订阅数据)
|
||||
func (s *APIKeyService) canUserBindGroupInternal(user *User, group *Group, subscribedGroupIDs map[int64]bool) bool {
|
||||
// 订阅类型分组:需要有效订阅
|
||||
if group.IsSubscriptionType() {
|
||||
return subscribedGroupIDs[group.ID]
|
||||
}
|
||||
// 标准类型分组:使用原有逻辑
|
||||
return user.CanBindGroup(group.ID, group.IsExclusive)
|
||||
}
|
||||
|
||||
func (s *APIKeyService) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
keys, err := s.apiKeyRepo.SearchAPIKeys(ctx, userID, keyword, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("search api keys: %w", err)
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// GetUserGroupRates 获取用户的专属分组倍率配置
|
||||
// 返回 map[groupID]rateMultiplier
|
||||
func (s *APIKeyService) GetUserGroupRates(ctx context.Context, userID int64) (map[int64]float64, error) {
|
||||
if s.userGroupRateRepo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
rates, err := s.userGroupRateRepo.GetByUserID(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user group rates: %w", err)
|
||||
}
|
||||
return rates, nil
|
||||
}
|
||||
|
||||
// CheckAPIKeyQuotaAndExpiry checks if the API key is valid for use (not expired, quota not exhausted)
|
||||
// Returns nil if valid, error if invalid
|
||||
func (s *APIKeyService) CheckAPIKeyQuotaAndExpiry(apiKey *APIKey) error {
|
||||
// Check expiration
|
||||
if apiKey.IsExpired() {
|
||||
return ErrAPIKeyExpired
|
||||
}
|
||||
|
||||
// Check quota
|
||||
if apiKey.IsQuotaExhausted() {
|
||||
return ErrAPIKeyQuotaExhausted
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UpdateQuotaUsed updates the quota_used field after a request
|
||||
// Also checks if quota is exhausted and updates status accordingly
|
||||
func (s *APIKeyService) UpdateQuotaUsed(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||
if cost <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
type quotaStateReader interface {
|
||||
IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error)
|
||||
}
|
||||
|
||||
if repo, ok := s.apiKeyRepo.(quotaStateReader); ok {
|
||||
state, err := repo.IncrementQuotaUsedAndGetState(ctx, apiKeyID, cost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("increment quota used: %w", err)
|
||||
}
|
||||
if state != nil && state.Status == StatusAPIKeyQuotaExhausted && strings.TrimSpace(state.Key) != "" {
|
||||
s.InvalidateAuthCacheByKey(ctx, state.Key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Use repository to atomically increment quota_used
|
||||
newQuotaUsed, err := s.apiKeyRepo.IncrementQuotaUsed(ctx, apiKeyID, cost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("increment quota used: %w", err)
|
||||
}
|
||||
|
||||
// Check if quota is now exhausted and update status if needed
|
||||
apiKey, err := s.apiKeyRepo.GetByID(ctx, apiKeyID)
|
||||
if err != nil {
|
||||
return nil // Don't fail the request, just log
|
||||
}
|
||||
|
||||
// If quota is set and now exhausted, update status
|
||||
if apiKey.Quota > 0 && newQuotaUsed >= apiKey.Quota {
|
||||
apiKey.Status = StatusAPIKeyQuotaExhausted
|
||||
if err := s.apiKeyRepo.Update(ctx, apiKey); err != nil {
|
||||
return nil // Don't fail the request
|
||||
}
|
||||
// Invalidate cache so next request sees the new status
|
||||
s.InvalidateAuthCacheByKey(ctx, apiKey.Key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRateLimitData returns rate limit usage and window state for an API key.
|
||||
func (s *APIKeyService) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
|
||||
return s.apiKeyRepo.GetRateLimitData(ctx, id)
|
||||
}
|
||||
|
||||
// UpdateRateLimitUsage atomically increments rate limit usage counters in the DB.
|
||||
func (s *APIKeyService) UpdateRateLimitUsage(ctx context.Context, apiKeyID int64, cost float64) error {
|
||||
if cost <= 0 {
|
||||
return nil
|
||||
}
|
||||
return s.apiKeyRepo.IncrementRateLimitUsage(ctx, apiKeyID, cost)
|
||||
}
|
||||
@@ -0,0 +1,448 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/redis/go-redis/v9"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type authRepoStub struct {
|
||||
getByKeyForAuth func(ctx context.Context, key string) (*APIKey, error)
|
||||
listKeysByUserID func(ctx context.Context, userID int64) ([]string, error)
|
||||
listKeysByGroupID func(ctx context.Context, groupID int64) ([]string, error)
|
||||
}
|
||||
|
||||
func (s *authRepoStub) Create(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||
panic("unexpected GetKeyAndOwnerID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
panic("unexpected GetByKey call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
|
||||
if s.getByKeyForAuth == nil {
|
||||
panic("unexpected GetByKeyForAuth call")
|
||||
}
|
||||
return s.getByKeyForAuth(ctx, key)
|
||||
}
|
||||
|
||||
func (s *authRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
panic("unexpected VerifyOwnership call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
panic("unexpected CountByUserID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
panic("unexpected ExistsByKey call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByGroupID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
panic("unexpected SearchAPIKeys call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected ClearGroupIDByGroupID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected CountByGroupID call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||
if s.listKeysByUserID == nil {
|
||||
panic("unexpected ListKeysByUserID call")
|
||||
}
|
||||
return s.listKeysByUserID(ctx, userID)
|
||||
}
|
||||
|
||||
func (s *authRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||
if s.listKeysByGroupID == nil {
|
||||
panic("unexpected ListKeysByGroupID call")
|
||||
}
|
||||
return s.listKeysByGroupID(ctx, groupID)
|
||||
}
|
||||
|
||||
func (s *authRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||
panic("unexpected IncrementQuotaUsed call")
|
||||
}
|
||||
|
||||
func (s *authRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
panic("unexpected UpdateLastUsed call")
|
||||
}
|
||||
func (s *authRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
panic("unexpected IncrementRateLimitUsage call")
|
||||
}
|
||||
func (s *authRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
panic("unexpected ResetRateLimitWindows call")
|
||||
}
|
||||
func (s *authRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
|
||||
panic("unexpected GetRateLimitData call")
|
||||
}
|
||||
|
||||
type authCacheStub struct {
|
||||
getAuthCache func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error)
|
||||
setAuthKeys []string
|
||||
deleteAuthKeys []string
|
||||
}
|
||||
|
||||
func (s *authCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
if s.getAuthCache == nil {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
return s.getAuthCache(ctx, key)
|
||||
}
|
||||
|
||||
func (s *authCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||
s.setAuthKeys = append(s.setAuthKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *authCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_UsesL2Cache(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
return nil, errors.New("unexpected repo call")
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
groupID := int64(9)
|
||||
cacheEntry := &APIKeyAuthCacheEntry{
|
||||
Snapshot: &APIKeyAuthSnapshot{
|
||||
APIKeyID: 1,
|
||||
UserID: 2,
|
||||
GroupID: &groupID,
|
||||
Status: StatusActive,
|
||||
User: APIKeyAuthUserSnapshot{
|
||||
ID: 2,
|
||||
Status: StatusActive,
|
||||
Role: RoleUser,
|
||||
Balance: 10,
|
||||
Concurrency: 3,
|
||||
},
|
||||
Group: &APIKeyAuthGroupSnapshot{
|
||||
ID: groupID,
|
||||
Name: "g",
|
||||
Platform: PlatformAnthropic,
|
||||
Status: StatusActive,
|
||||
SubscriptionType: SubscriptionTypeStandard,
|
||||
RateMultiplier: 1,
|
||||
ModelRoutingEnabled: true,
|
||||
ModelRouting: map[string][]int64{
|
||||
"claude-opus-*": {1, 2},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return cacheEntry, nil
|
||||
}
|
||||
|
||||
apiKey, err := svc.GetByKey(context.Background(), "k1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), apiKey.ID)
|
||||
require.Equal(t, int64(2), apiKey.User.ID)
|
||||
require.Equal(t, groupID, apiKey.Group.ID)
|
||||
require.True(t, apiKey.Group.ModelRoutingEnabled)
|
||||
require.Equal(t, map[string][]int64{"claude-opus-*": {1, 2}}, apiKey.Group.ModelRouting)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_NegativeCache(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
return nil, errors.New("unexpected repo call")
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return &APIKeyAuthCacheEntry{NotFound: true}, nil
|
||||
}
|
||||
|
||||
_, err := svc.GetByKey(context.Background(), "missing")
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_CacheMissStoresL2(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
return &APIKey{
|
||||
ID: 5,
|
||||
UserID: 7,
|
||||
Status: StatusActive,
|
||||
User: &User{
|
||||
ID: 7,
|
||||
Status: StatusActive,
|
||||
Role: RoleUser,
|
||||
Balance: 12,
|
||||
Concurrency: 2,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
|
||||
apiKey, err := svc.GetByKey(context.Background(), "k2")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(5), apiKey.ID)
|
||||
require.Len(t, cache.setAuthKeys, 1)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_UsesL1Cache(t *testing.T) {
|
||||
var calls int32
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
return &APIKey{
|
||||
ID: 21,
|
||||
UserID: 3,
|
||||
Status: StatusActive,
|
||||
User: &User{
|
||||
ID: 3,
|
||||
Status: StatusActive,
|
||||
Role: RoleUser,
|
||||
Balance: 5,
|
||||
Concurrency: 2,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L1Size: 1000,
|
||||
L1TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
require.NotNil(t, svc.authCacheL1)
|
||||
|
||||
_, err := svc.GetByKey(context.Background(), "k-l1")
|
||||
require.NoError(t, err)
|
||||
svc.authCacheL1.Wait()
|
||||
cacheKey := svc.authCacheKey("k-l1")
|
||||
_, ok := svc.authCacheL1.Get(cacheKey)
|
||||
require.True(t, ok)
|
||||
_, err = svc.GetByKey(context.Background(), "k-l1")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
|
||||
}
|
||||
|
||||
func TestAPIKeyService_InvalidateAuthCacheByUserID(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
|
||||
return []string{"k1", "k2"}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByUserID(context.Background(), 7)
|
||||
require.Len(t, cache.deleteAuthKeys, 2)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_InvalidateAuthCacheByGroupID(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
listKeysByGroupID: func(ctx context.Context, groupID int64) ([]string, error) {
|
||||
return []string{"k1", "k2"}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByGroupID(context.Background(), 9)
|
||||
require.Len(t, cache.deleteAuthKeys, 2)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_InvalidateAuthCacheByKey(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
listKeysByUserID: func(ctx context.Context, userID int64) ([]string, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
svc.InvalidateAuthCacheByKey(context.Background(), "k1")
|
||||
require.Len(t, cache.deleteAuthKeys, 1)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_CachesNegativeOnRepoMiss(t *testing.T) {
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
return nil, ErrAPIKeyNotFound
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
L2TTLSeconds: 60,
|
||||
NegativeTTLSeconds: 30,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
cache.getAuthCache = func(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, redis.Nil
|
||||
}
|
||||
|
||||
_, err := svc.GetByKey(context.Background(), "missing")
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
require.Len(t, cache.setAuthKeys, 1)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_GetByKey_SingleflightCollapses(t *testing.T) {
|
||||
var calls int32
|
||||
cache := &authCacheStub{}
|
||||
repo := &authRepoStub{
|
||||
getByKeyForAuth: func(ctx context.Context, key string) (*APIKey, error) {
|
||||
atomic.AddInt32(&calls, 1)
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
return &APIKey{
|
||||
ID: 11,
|
||||
UserID: 2,
|
||||
Status: StatusActive,
|
||||
User: &User{
|
||||
ID: 2,
|
||||
Status: StatusActive,
|
||||
Role: RoleUser,
|
||||
Balance: 1,
|
||||
Concurrency: 1,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
cfg := &config.Config{
|
||||
APIKeyAuth: config.APIKeyAuthCacheConfig{
|
||||
Singleflight: true,
|
||||
},
|
||||
}
|
||||
svc := NewAPIKeyService(repo, nil, nil, nil, nil, cache, cfg)
|
||||
|
||||
start := make(chan struct{})
|
||||
wg := sync.WaitGroup{}
|
||||
errs := make([]error, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
_, err := svc.GetByKey(context.Background(), "k1")
|
||||
errs[idx] = err
|
||||
}(i)
|
||||
}
|
||||
close(start)
|
||||
wg.Wait()
|
||||
|
||||
for _, err := range errs {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&calls))
|
||||
}
|
||||
@@ -0,0 +1,291 @@
|
||||
//go:build unit
|
||||
|
||||
// API Key 服务删除方法的单元测试
|
||||
// 测试 APIKeyService.Delete 方法在各种场景下的行为,
|
||||
// 包括权限验证、缓存清理和错误处理
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// apiKeyRepoStub 是 APIKeyRepository 接口的测试桩实现。
|
||||
// 用于隔离测试 APIKeyService.Delete 方法,避免依赖真实数据库。
|
||||
//
|
||||
// 设计说明:
|
||||
// - apiKey/getByIDErr: 模拟 GetKeyAndOwnerID 返回的记录与错误
|
||||
// - deleteErr: 模拟 Delete 返回的错误
|
||||
// - deletedIDs: 记录被调用删除的 API Key ID,用于断言验证
|
||||
type apiKeyRepoStub struct {
|
||||
apiKey *APIKey // GetKeyAndOwnerID 的返回值
|
||||
getByIDErr error // GetKeyAndOwnerID 的错误返回值
|
||||
deleteErr error // Delete 的错误返回值
|
||||
deletedIDs []int64 // 记录已删除的 API Key ID 列表
|
||||
updateLastUsed func(ctx context.Context, id int64, usedAt time.Time) error
|
||||
touchedIDs []int64
|
||||
touchedUsedAts []time.Time
|
||||
}
|
||||
|
||||
// 以下方法在本测试中不应被调用,使用 panic 确保测试失败时能快速定位问题
|
||||
|
||||
func (s *apiKeyRepoStub) Create(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByID(ctx context.Context, id int64) (*APIKey, error) {
|
||||
if s.getByIDErr != nil {
|
||||
return nil, s.getByIDErr
|
||||
}
|
||||
if s.apiKey != nil {
|
||||
clone := *s.apiKey
|
||||
return &clone, nil
|
||||
}
|
||||
panic("unexpected GetByID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetKeyAndOwnerID(ctx context.Context, id int64) (string, int64, error) {
|
||||
if s.getByIDErr != nil {
|
||||
return "", 0, s.getByIDErr
|
||||
}
|
||||
if s.apiKey != nil {
|
||||
return s.apiKey.Key, s.apiKey.UserID, nil
|
||||
}
|
||||
return "", 0, ErrAPIKeyNotFound
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByKey(ctx context.Context, key string) (*APIKey, error) {
|
||||
panic("unexpected GetByKey call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetByKeyForAuth(ctx context.Context, key string) (*APIKey, error) {
|
||||
panic("unexpected GetByKeyForAuth call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) Update(ctx context.Context, key *APIKey) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
|
||||
// Delete 记录被删除的 API Key ID 并返回预设的错误。
|
||||
// 通过 deletedIDs 可以验证删除操作是否被正确调用。
|
||||
func (s *apiKeyRepoStub) Delete(ctx context.Context, id int64) error {
|
||||
s.deletedIDs = append(s.deletedIDs, id)
|
||||
return s.deleteErr
|
||||
}
|
||||
|
||||
// 以下是接口要求实现但本测试不关心的方法
|
||||
|
||||
func (s *apiKeyRepoStub) ListByUserID(ctx context.Context, userID int64, params pagination.PaginationParams, filters APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) VerifyOwnership(ctx context.Context, userID int64, apiKeyIDs []int64) ([]int64, error) {
|
||||
panic("unexpected VerifyOwnership call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) CountByUserID(ctx context.Context, userID int64) (int64, error) {
|
||||
panic("unexpected CountByUserID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ExistsByKey(ctx context.Context, key string) (bool, error) {
|
||||
panic("unexpected ExistsByKey call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ListByGroupID(ctx context.Context, groupID int64, params pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByGroupID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) SearchAPIKeys(ctx context.Context, userID int64, keyword string, limit int) ([]APIKey, error) {
|
||||
panic("unexpected SearchAPIKeys call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ClearGroupIDByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected ClearGroupIDByGroupID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) CountByGroupID(ctx context.Context, groupID int64) (int64, error) {
|
||||
panic("unexpected CountByGroupID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ListKeysByUserID(ctx context.Context, userID int64) ([]string, error) {
|
||||
panic("unexpected ListKeysByUserID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ListKeysByGroupID(ctx context.Context, groupID int64) ([]string, error) {
|
||||
panic("unexpected ListKeysByGroupID call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) IncrementQuotaUsed(ctx context.Context, id int64, amount float64) (float64, error) {
|
||||
panic("unexpected IncrementQuotaUsed call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
s.touchedIDs = append(s.touchedIDs, id)
|
||||
s.touchedUsedAts = append(s.touchedUsedAts, usedAt)
|
||||
if s.updateLastUsed != nil {
|
||||
return s.updateLastUsed(ctx, id, usedAt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) IncrementRateLimitUsage(ctx context.Context, id int64, cost float64) error {
|
||||
panic("unexpected IncrementRateLimitUsage call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) ResetRateLimitWindows(ctx context.Context, id int64) error {
|
||||
panic("unexpected ResetRateLimitWindows call")
|
||||
}
|
||||
|
||||
func (s *apiKeyRepoStub) GetRateLimitData(ctx context.Context, id int64) (*APIKeyRateLimitData, error) {
|
||||
panic("unexpected GetRateLimitData call")
|
||||
}
|
||||
|
||||
// apiKeyCacheStub 是 APIKeyCache 接口的测试桩实现。
|
||||
// 用于验证删除操作时缓存清理逻辑是否被正确调用。
|
||||
//
|
||||
// 设计说明:
|
||||
// - invalidated: 记录被清除缓存的用户 ID 列表
|
||||
type apiKeyCacheStub struct {
|
||||
invalidated []int64 // 记录调用 DeleteCreateAttemptCount 时传入的用户 ID
|
||||
deleteAuthKeys []string // 记录调用 DeleteAuthCache 时传入的缓存 key
|
||||
}
|
||||
|
||||
// GetCreateAttemptCount 返回 0,表示用户未超过创建次数限制
|
||||
func (s *apiKeyCacheStub) GetCreateAttemptCount(ctx context.Context, userID int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// IncrementCreateAttemptCount 空实现,本测试不验证此行为
|
||||
func (s *apiKeyCacheStub) IncrementCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteCreateAttemptCount 记录被清除缓存的用户 ID。
|
||||
// 删除 API Key 时会调用此方法清除用户的创建尝试计数缓存。
|
||||
func (s *apiKeyCacheStub) DeleteCreateAttemptCount(ctx context.Context, userID int64) error {
|
||||
s.invalidated = append(s.invalidated, userID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// IncrementDailyUsage 空实现,本测试不验证此行为
|
||||
func (s *apiKeyCacheStub) IncrementDailyUsage(ctx context.Context, apiKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetDailyUsageExpiry 空实现,本测试不验证此行为
|
||||
func (s *apiKeyCacheStub) SetDailyUsageExpiry(ctx context.Context, apiKey string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) GetAuthCache(ctx context.Context, key string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) SetAuthCache(ctx context.Context, key string, entry *APIKeyAuthCacheEntry, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) DeleteAuthCache(ctx context.Context, key string) error {
|
||||
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) PublishAuthCacheInvalidation(ctx context.Context, cacheKey string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *apiKeyCacheStub) SubscribeAuthCacheInvalidation(ctx context.Context, handler func(cacheKey string)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_OwnerMismatch 测试非所有者尝试删除时返回权限错误。
|
||||
// 预期行为:
|
||||
// - GetKeyAndOwnerID 返回所有者 ID 为 1
|
||||
// - 调用者 userID 为 2(不匹配)
|
||||
// - 返回 ErrInsufficientPerms 错误
|
||||
// - Delete 方法不被调用
|
||||
// - 缓存不被清除
|
||||
func TestApiKeyService_Delete_OwnerMismatch(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
apiKey: &APIKey{ID: 10, UserID: 1, Key: "k"},
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 10, 2) // API Key ID=10, 调用者 userID=2
|
||||
require.ErrorIs(t, err, ErrInsufficientPerms)
|
||||
require.Empty(t, repo.deletedIDs) // 验证删除操作未被调用
|
||||
require.Empty(t, cache.invalidated) // 验证缓存未被清除
|
||||
require.Empty(t, cache.deleteAuthKeys)
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_Success 测试所有者成功删除 API Key 的场景。
|
||||
// 预期行为:
|
||||
// - GetKeyAndOwnerID 返回所有者 ID 为 7
|
||||
// - 调用者 userID 为 7(匹配)
|
||||
// - Delete 成功执行
|
||||
// - 缓存被正确清除(使用 ownerID)
|
||||
// - 返回 nil 错误
|
||||
func TestApiKeyService_Delete_Success(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
apiKey: &APIKey{ID: 42, UserID: 7, Key: "k"},
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
svc.lastUsedTouchL1.Store(int64(42), time.Now())
|
||||
|
||||
err := svc.Delete(context.Background(), 42, 7) // API Key ID=42, 调用者 userID=7
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{42}, repo.deletedIDs) // 验证正确的 API Key 被删除
|
||||
require.Equal(t, []int64{7}, cache.invalidated) // 验证所有者的缓存被清除
|
||||
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
|
||||
_, exists := svc.lastUsedTouchL1.Load(int64(42))
|
||||
require.False(t, exists, "delete should clear touch debounce cache")
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_NotFound 测试删除不存在的 API Key 时返回正确的错误。
|
||||
// 预期行为:
|
||||
// - GetKeyAndOwnerID 返回 ErrAPIKeyNotFound 错误
|
||||
// - 返回 ErrAPIKeyNotFound 错误(被 fmt.Errorf 包装)
|
||||
// - Delete 方法不被调用
|
||||
// - 缓存不被清除
|
||||
func TestApiKeyService_Delete_NotFound(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{getByIDErr: ErrAPIKeyNotFound}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 99, 1)
|
||||
require.ErrorIs(t, err, ErrAPIKeyNotFound)
|
||||
require.Empty(t, repo.deletedIDs)
|
||||
require.Empty(t, cache.invalidated)
|
||||
require.Empty(t, cache.deleteAuthKeys)
|
||||
}
|
||||
|
||||
// TestApiKeyService_Delete_DeleteFails 测试删除操作失败时的错误处理。
|
||||
// 预期行为:
|
||||
// - GetKeyAndOwnerID 返回正确的所有者 ID
|
||||
// - 所有权验证通过
|
||||
// - 缓存被清除(在删除之前)
|
||||
// - Delete 被调用但返回错误
|
||||
// - 返回包含 "delete api key" 的错误信息
|
||||
func TestApiKeyService_Delete_DeleteFails(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
apiKey: &APIKey{ID: 42, UserID: 3, Key: "k"},
|
||||
deleteErr: errors.New("delete failed"),
|
||||
}
|
||||
cache := &apiKeyCacheStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo, cache: cache}
|
||||
|
||||
err := svc.Delete(context.Background(), 3, 3) // API Key ID=3, 调用者 userID=3
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "delete api key")
|
||||
require.Equal(t, []int64{3}, repo.deletedIDs) // 验证删除操作被调用
|
||||
require.Equal(t, []int64{3}, cache.invalidated) // 验证缓存已被清除(即使删除失败)
|
||||
require.Equal(t, []string{svc.authCacheKey("k")}, cache.deleteAuthKeys)
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/pagination"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type quotaStateRepoStub struct {
|
||||
quotaBaseAPIKeyRepoStub
|
||||
stateCalls int
|
||||
state *APIKeyQuotaUsageState
|
||||
stateErr error
|
||||
}
|
||||
|
||||
func (s *quotaStateRepoStub) IncrementQuotaUsedAndGetState(ctx context.Context, id int64, amount float64) (*APIKeyQuotaUsageState, error) {
|
||||
s.stateCalls++
|
||||
if s.stateErr != nil {
|
||||
return nil, s.stateErr
|
||||
}
|
||||
if s.state == nil {
|
||||
return nil, nil
|
||||
}
|
||||
out := *s.state
|
||||
return &out, nil
|
||||
}
|
||||
|
||||
type quotaStateCacheStub struct {
|
||||
deleteAuthKeys []string
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) GetCreateAttemptCount(context.Context, int64) (int, error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) IncrementCreateAttemptCount(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) DeleteCreateAttemptCount(context.Context, int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) IncrementDailyUsage(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) SetDailyUsageExpiry(context.Context, string, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) GetAuthCache(context.Context, string) (*APIKeyAuthCacheEntry, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) SetAuthCache(context.Context, string, *APIKeyAuthCacheEntry, time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) DeleteAuthCache(_ context.Context, key string) error {
|
||||
s.deleteAuthKeys = append(s.deleteAuthKeys, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) PublishAuthCacheInvalidation(context.Context, string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *quotaStateCacheStub) SubscribeAuthCacheInvalidation(context.Context, func(string)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type quotaBaseAPIKeyRepoStub struct {
|
||||
getByIDCalls int
|
||||
}
|
||||
|
||||
func (s *quotaBaseAPIKeyRepoStub) Create(context.Context, *APIKey) error {
|
||||
panic("unexpected Create call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) GetByID(context.Context, int64) (*APIKey, error) {
|
||||
s.getByIDCalls++
|
||||
return nil, nil
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) GetKeyAndOwnerID(context.Context, int64) (string, int64, error) {
|
||||
panic("unexpected GetKeyAndOwnerID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) GetByKey(context.Context, string) (*APIKey, error) {
|
||||
panic("unexpected GetByKey call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) GetByKeyForAuth(context.Context, string) (*APIKey, error) {
|
||||
panic("unexpected GetByKeyForAuth call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) Update(context.Context, *APIKey) error {
|
||||
panic("unexpected Update call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) Delete(context.Context, int64) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ListByUserID(context.Context, int64, pagination.PaginationParams, APIKeyListFilters) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByUserID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) VerifyOwnership(context.Context, int64, []int64) ([]int64, error) {
|
||||
panic("unexpected VerifyOwnership call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) CountByUserID(context.Context, int64) (int64, error) {
|
||||
panic("unexpected CountByUserID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ExistsByKey(context.Context, string) (bool, error) {
|
||||
panic("unexpected ExistsByKey call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ListByGroupID(context.Context, int64, pagination.PaginationParams) ([]APIKey, *pagination.PaginationResult, error) {
|
||||
panic("unexpected ListByGroupID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) SearchAPIKeys(context.Context, int64, string, int) ([]APIKey, error) {
|
||||
panic("unexpected SearchAPIKeys call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ClearGroupIDByGroupID(context.Context, int64) (int64, error) {
|
||||
panic("unexpected ClearGroupIDByGroupID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) CountByGroupID(context.Context, int64) (int64, error) {
|
||||
panic("unexpected CountByGroupID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ListKeysByUserID(context.Context, int64) ([]string, error) {
|
||||
panic("unexpected ListKeysByUserID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ListKeysByGroupID(context.Context, int64) ([]string, error) {
|
||||
panic("unexpected ListKeysByGroupID call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) IncrementQuotaUsed(context.Context, int64, float64) (float64, error) {
|
||||
panic("unexpected IncrementQuotaUsed call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) UpdateLastUsed(context.Context, int64, time.Time) error {
|
||||
panic("unexpected UpdateLastUsed call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) IncrementRateLimitUsage(context.Context, int64, float64) error {
|
||||
panic("unexpected IncrementRateLimitUsage call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) ResetRateLimitWindows(context.Context, int64) error {
|
||||
panic("unexpected ResetRateLimitWindows call")
|
||||
}
|
||||
func (s *quotaBaseAPIKeyRepoStub) GetRateLimitData(context.Context, int64) (*APIKeyRateLimitData, error) {
|
||||
panic("unexpected GetRateLimitData call")
|
||||
}
|
||||
|
||||
func TestAPIKeyService_UpdateQuotaUsed_UsesAtomicStatePath(t *testing.T) {
|
||||
repo := "aStateRepoStub{
|
||||
state: &APIKeyQuotaUsageState{
|
||||
QuotaUsed: 12,
|
||||
Quota: 10,
|
||||
Key: "sk-test-quota",
|
||||
Status: StatusAPIKeyQuotaExhausted,
|
||||
},
|
||||
}
|
||||
cache := "aStateCacheStub{}
|
||||
svc := &APIKeyService{
|
||||
apiKeyRepo: repo,
|
||||
cache: cache,
|
||||
}
|
||||
|
||||
err := svc.UpdateQuotaUsed(context.Background(), 101, 2)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, repo.stateCalls)
|
||||
require.Equal(t, 0, repo.getByIDCalls, "fast path should not re-read API key by id")
|
||||
require.Equal(t, []string{svc.authCacheKey("sk-test-quota")}, cache.deleteAuthKeys)
|
||||
}
|
||||
@@ -0,0 +1,160 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_InvalidKeyID(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
return errors.New("should not be called")
|
||||
},
|
||||
}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
require.NoError(t, svc.TouchLastUsed(context.Background(), 0))
|
||||
require.NoError(t, svc.TouchLastUsed(context.Background(), -1))
|
||||
require.Empty(t, repo.touchedIDs)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_FirstTouchSucceeds(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
err := svc.TouchLastUsed(context.Background(), 123)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []int64{123}, repo.touchedIDs)
|
||||
require.Len(t, repo.touchedUsedAts, 1)
|
||||
require.False(t, repo.touchedUsedAts[0].IsZero())
|
||||
|
||||
cached, ok := svc.lastUsedTouchL1.Load(int64(123))
|
||||
require.True(t, ok, "successful touch should update debounce cache")
|
||||
_, isTime := cached.(time.Time)
|
||||
require.True(t, isTime)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_DebouncedWithinWindow(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
require.NoError(t, svc.TouchLastUsed(context.Background(), 123))
|
||||
require.NoError(t, svc.TouchLastUsed(context.Background(), 123))
|
||||
|
||||
require.Equal(t, []int64{123}, repo.touchedIDs, "second touch within debounce window should not hit repository")
|
||||
}
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_ExpiredDebounceTouchesAgain(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
require.NoError(t, svc.TouchLastUsed(context.Background(), 123))
|
||||
|
||||
// 强制将 debounce 时间回拨到窗口之外,触发第二次写库。
|
||||
svc.lastUsedTouchL1.Store(int64(123), time.Now().Add(-apiKeyLastUsedMinTouch-time.Second))
|
||||
|
||||
require.NoError(t, svc.TouchLastUsed(context.Background(), 123))
|
||||
require.Len(t, repo.touchedIDs, 2)
|
||||
require.Equal(t, int64(123), repo.touchedIDs[0])
|
||||
require.Equal(t, int64(123), repo.touchedIDs[1])
|
||||
}
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_RepoError(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
return errors.New("db write failed")
|
||||
},
|
||||
}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
err := svc.TouchLastUsed(context.Background(), 123)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "touch api key last used")
|
||||
require.Equal(t, []int64{123}, repo.touchedIDs)
|
||||
|
||||
cached, ok := svc.lastUsedTouchL1.Load(int64(123))
|
||||
require.True(t, ok, "failed touch should still update retry debounce cache")
|
||||
_, isTime := cached.(time.Time)
|
||||
require.True(t, isTime)
|
||||
}
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_RepoErrorDebounced(t *testing.T) {
|
||||
repo := &apiKeyRepoStub{
|
||||
updateLastUsed: func(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
return errors.New("db write failed")
|
||||
},
|
||||
}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
firstErr := svc.TouchLastUsed(context.Background(), 456)
|
||||
require.Error(t, firstErr)
|
||||
require.ErrorContains(t, firstErr, "touch api key last used")
|
||||
|
||||
secondErr := svc.TouchLastUsed(context.Background(), 456)
|
||||
require.NoError(t, secondErr, "failed touch should be debounced and skip immediate retry")
|
||||
require.Equal(t, []int64{456}, repo.touchedIDs, "debounced retry should not hit repository again")
|
||||
}
|
||||
|
||||
type touchSingleflightRepo struct {
|
||||
*apiKeyRepoStub
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
blockCh chan struct{}
|
||||
}
|
||||
|
||||
func (r *touchSingleflightRepo) UpdateLastUsed(ctx context.Context, id int64, usedAt time.Time) error {
|
||||
r.mu.Lock()
|
||||
r.calls++
|
||||
r.mu.Unlock()
|
||||
<-r.blockCh
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAPIKeyService_TouchLastUsed_ConcurrentFirstTouchDeduplicated(t *testing.T) {
|
||||
repo := &touchSingleflightRepo{
|
||||
apiKeyRepoStub: &apiKeyRepoStub{},
|
||||
blockCh: make(chan struct{}),
|
||||
}
|
||||
svc := &APIKeyService{apiKeyRepo: repo}
|
||||
|
||||
const workers = 20
|
||||
startCh := make(chan struct{})
|
||||
errCh := make(chan error, workers)
|
||||
var wg sync.WaitGroup
|
||||
|
||||
for i := 0; i < workers; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-startCh
|
||||
errCh <- svc.TouchLastUsed(context.Background(), 321)
|
||||
}()
|
||||
}
|
||||
|
||||
close(startCh)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
return repo.calls >= 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
|
||||
close(repo.blockCh)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
repo.mu.Lock()
|
||||
defer repo.mu.Unlock()
|
||||
require.Equal(t, 1, repo.calls, "并发首次 touch 只应写库一次")
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUsageService_InvalidateUsageCaches(t *testing.T) {
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &UsageService{authCacheInvalidator: invalidator}
|
||||
|
||||
svc.invalidateUsageCaches(context.Background(), 7, false)
|
||||
require.Empty(t, invalidator.userIDs)
|
||||
|
||||
svc.invalidateUsageCaches(context.Background(), 7, true)
|
||||
require.Equal(t, []int64{7}, invalidator.userIDs)
|
||||
}
|
||||
|
||||
func TestRedeemService_InvalidateRedeemCaches_AuthCache(t *testing.T) {
|
||||
invalidator := &authCacheInvalidatorStub{}
|
||||
svc := &RedeemService{authCacheInvalidator: invalidator}
|
||||
|
||||
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeBalance})
|
||||
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeConcurrency})
|
||||
groupID := int64(3)
|
||||
svc.invalidateRedeemCaches(context.Background(), 11, &RedeemCode{Type: RedeemTypeSubscription, GroupID: &groupID})
|
||||
|
||||
require.Equal(t, []int64{11, 11, 11}, invalidator.userIDs)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,146 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newAuthServiceForPendingOAuthTest() *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret-pending-oauth",
|
||||
ExpireHour: 1,
|
||||
},
|
||||
}
|
||||
return NewAuthService(nil, nil, nil, nil, cfg, nil, nil, nil, nil, nil, nil)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_ValidToken 验证正常签发的 pending token 可以被成功解析。
|
||||
func TestVerifyPendingOAuthToken_ValidToken(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
token, err := svc.CreatePendingOAuthToken("user@example.com", "alice")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
|
||||
email, username, err := svc.VerifyPendingOAuthToken(token)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "user@example.com", email)
|
||||
require.Equal(t, "alice", username)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_RegularJWTRejected 用普通 access token 尝试验证,应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_RegularJWTRejected(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
// 签发一个普通 access token(JWTClaims,无 Purpose 字段)
|
||||
accessToken, err := svc.GenerateToken(&User{
|
||||
ID: 1,
|
||||
Email: "user@example.com",
|
||||
Role: RoleUser,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(accessToken)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_WrongPurpose 手动构造 purpose 字段不匹配的 JWT,应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_WrongPurpose(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: "user@example.com",
|
||||
Username: "alice",
|
||||
Purpose: "some_other_purpose",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_MissingPurpose 手动构造无 purpose 字段的 JWT(模拟旧 token),应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_MissingPurpose(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
now := time.Now()
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: "user@example.com",
|
||||
Username: "alice",
|
||||
Purpose: "", // 旧 token 无此字段,反序列化后为零值
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(10 * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_ExpiredToken 过期 token 应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_ExpiredToken(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
|
||||
past := time.Now().Add(-1 * time.Hour)
|
||||
claims := &pendingOAuthClaims{
|
||||
Email: "user@example.com",
|
||||
Username: "alice",
|
||||
Purpose: pendingOAuthPurpose,
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(past),
|
||||
IssuedAt: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
|
||||
NotBefore: jwt.NewNumericDate(past.Add(-10 * time.Minute)),
|
||||
},
|
||||
}
|
||||
tok := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
|
||||
tokenStr, err := tok.SignedString([]byte(svc.cfg.JWT.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
_, _, err = svc.VerifyPendingOAuthToken(tokenStr)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_WrongSecret 不同密钥签发的 token 应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_WrongSecret(t *testing.T) {
|
||||
other := NewAuthService(nil, nil, nil, nil, &config.Config{
|
||||
JWT: config.JWTConfig{Secret: "other-secret"},
|
||||
}, nil, nil, nil, nil, nil, nil)
|
||||
|
||||
token, err := other.CreatePendingOAuthToken("user@example.com", "alice")
|
||||
require.NoError(t, err)
|
||||
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
_, _, err = svc.VerifyPendingOAuthToken(token)
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
|
||||
// TestVerifyPendingOAuthToken_TooLong 超长 token 应返回 ErrInvalidToken。
|
||||
func TestVerifyPendingOAuthToken_TooLong(t *testing.T) {
|
||||
svc := newAuthServiceForPendingOAuthTest()
|
||||
giant := make([]byte, maxTokenLength+1)
|
||||
for i := range giant {
|
||||
giant[i] = 'a'
|
||||
}
|
||||
_, _, err := svc.VerifyPendingOAuthToken(string(giant))
|
||||
require.ErrorIs(t, err, ErrInvalidToken)
|
||||
}
|
||||
@@ -0,0 +1,466 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type settingRepoStub struct {
|
||||
values map[string]string
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) Get(ctx context.Context, key string) (*Setting, error) {
|
||||
panic("unexpected Get call")
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) GetValue(ctx context.Context, key string) (string, error) {
|
||||
if s.err != nil {
|
||||
return "", s.err
|
||||
}
|
||||
if v, ok := s.values[key]; ok {
|
||||
return v, nil
|
||||
}
|
||||
return "", ErrSettingNotFound
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) Set(ctx context.Context, key, value string) error {
|
||||
panic("unexpected Set call")
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) GetMultiple(ctx context.Context, keys []string) (map[string]string, error) {
|
||||
panic("unexpected GetMultiple call")
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) SetMultiple(ctx context.Context, settings map[string]string) error {
|
||||
panic("unexpected SetMultiple call")
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) GetAll(ctx context.Context) (map[string]string, error) {
|
||||
panic("unexpected GetAll call")
|
||||
}
|
||||
|
||||
func (s *settingRepoStub) Delete(ctx context.Context, key string) error {
|
||||
panic("unexpected Delete call")
|
||||
}
|
||||
|
||||
type emailCacheStub struct {
|
||||
data *VerificationCodeData
|
||||
err error
|
||||
}
|
||||
|
||||
type defaultSubscriptionAssignerStub struct {
|
||||
calls []AssignSubscriptionInput
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *defaultSubscriptionAssignerStub) AssignOrExtendSubscription(_ context.Context, input *AssignSubscriptionInput) (*UserSubscription, bool, error) {
|
||||
if input != nil {
|
||||
s.calls = append(s.calls, *input)
|
||||
}
|
||||
if s.err != nil {
|
||||
return nil, false, s.err
|
||||
}
|
||||
return &UserSubscription{UserID: input.UserID, GroupID: input.GroupID}, false, nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) GetVerificationCode(ctx context.Context, email string) (*VerificationCodeData, error) {
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.data, nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) SetVerificationCode(ctx context.Context, email string, data *VerificationCodeData, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) DeleteVerificationCode(ctx context.Context, email string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) GetPasswordResetToken(ctx context.Context, email string) (*PasswordResetTokenData, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) SetPasswordResetToken(ctx context.Context, email string, data *PasswordResetTokenData, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) DeletePasswordResetToken(ctx context.Context, email string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) IsPasswordResetEmailInCooldown(ctx context.Context, email string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *emailCacheStub) SetPasswordResetEmailCooldown(ctx context.Context, email string, ttl time.Duration) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newAuthService(repo *userRepoStub, settings map[string]string, emailCache EmailCache) *AuthService {
|
||||
cfg := &config.Config{
|
||||
JWT: config.JWTConfig{
|
||||
Secret: "test-secret",
|
||||
ExpireHour: 1,
|
||||
},
|
||||
Default: config.DefaultConfig{
|
||||
UserBalance: 3.5,
|
||||
UserConcurrency: 2,
|
||||
},
|
||||
}
|
||||
|
||||
var settingService *SettingService
|
||||
if settings != nil {
|
||||
settingService = NewSettingService(&settingRepoStub{values: settings}, cfg)
|
||||
}
|
||||
|
||||
var emailService *EmailService
|
||||
if emailCache != nil {
|
||||
emailService = NewEmailService(&settingRepoStub{values: settings}, emailCache)
|
||||
}
|
||||
|
||||
return NewAuthService(
|
||||
nil, // entClient
|
||||
repo,
|
||||
nil, // redeemRepo
|
||||
nil, // refreshTokenCache
|
||||
cfg,
|
||||
settingService,
|
||||
emailService,
|
||||
nil,
|
||||
nil,
|
||||
nil, // promoService
|
||||
nil, // defaultSubAssigner
|
||||
)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_Disabled(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "false",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_DisabledByDefault(t *testing.T) {
|
||||
// 当 settings 为 nil(设置项不存在)时,注册应该默认关闭
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrRegDisabled)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyEnabledButServiceNotConfigured(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
// 邮件验证开启但 emailCache 为 nil(emailService 未配置)
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
// 应返回服务不可用错误,而不是允许绕过验证
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "any-code", "", "")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyRequired(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
cache := &emailCacheStub{} // 配置 emailService
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "", "", "")
|
||||
require.ErrorIs(t, err, ErrEmailVerifyRequired)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailVerifyInvalid(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
cache := &emailCacheStub{
|
||||
data: &VerificationCodeData{Code: "expected", Attempts: 0},
|
||||
}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
}, cache)
|
||||
|
||||
_, _, err := service.RegisterWithVerification(context.Background(), "user@test.com", "password", "wrong", "", "")
|
||||
require.ErrorIs(t, err, ErrInvalidVerifyCode)
|
||||
require.ErrorContains(t, err, "verify code")
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailExists(t *testing.T) {
|
||||
repo := &userRepoStub{exists: true}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CheckEmailError(t *testing.T) {
|
||||
repo := &userRepoStub{existsErr: errors.New("db down")}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_ReservedEmail(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "linuxdo-123@linuxdo-connect.invalid", "password")
|
||||
require.ErrorIs(t, err, ErrEmailReserved)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailSuffixNotAllowed(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@other.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
|
||||
appErr := infraerrors.FromError(err)
|
||||
require.Contains(t, appErr.Message, "@example.com")
|
||||
require.Contains(t, appErr.Message, "@company.com")
|
||||
require.Equal(t, "EMAIL_SUFFIX_NOT_ALLOWED", appErr.Reason)
|
||||
require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
|
||||
require.Equal(t, "@example.com,@company.com", appErr.Metadata["allowed_suffixes"])
|
||||
}
|
||||
|
||||
func TestAuthService_Register_EmailSuffixAllowed(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 8}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["example.com"]`,
|
||||
}, nil)
|
||||
|
||||
_, user, err := service.Register(context.Background(), "user@example.com", "password")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.Equal(t, int64(8), user.ID)
|
||||
}
|
||||
|
||||
func TestAuthService_SendVerifyCode_EmailSuffixNotAllowed(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyRegistrationEmailSuffixWhitelist: `["@example.com","@company.com"]`,
|
||||
}, nil)
|
||||
|
||||
err := service.SendVerifyCode(context.Background(), "user@other.com")
|
||||
require.ErrorIs(t, err, ErrEmailSuffixNotAllowed)
|
||||
appErr := infraerrors.FromError(err)
|
||||
require.Contains(t, appErr.Message, "@example.com")
|
||||
require.Contains(t, appErr.Message, "@company.com")
|
||||
require.Equal(t, "2", appErr.Metadata["allowed_suffix_count"])
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CreateError(t *testing.T) {
|
||||
repo := &userRepoStub{createErr: errors.New("create failed")}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrServiceUnavailable)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_CreateEmailExistsRace(t *testing.T) {
|
||||
// 模拟竞态条件:ExistsByEmail 返回 false,但 Create 时因唯一约束失败
|
||||
repo := &userRepoStub{createErr: ErrEmailExists}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
_, _, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.ErrorIs(t, err, ErrEmailExists)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_Success(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 5}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, nil)
|
||||
|
||||
token, user, err := service.Register(context.Background(), "user@test.com", "password")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
require.NotNil(t, user)
|
||||
require.Equal(t, int64(5), user.ID)
|
||||
require.Equal(t, "user@test.com", user.Email)
|
||||
require.Equal(t, RoleUser, user.Role)
|
||||
require.Equal(t, StatusActive, user.Status)
|
||||
require.Equal(t, 3.5, user.Balance)
|
||||
require.Equal(t, 2, user.Concurrency)
|
||||
require.Len(t, repo.created, 1)
|
||||
require.True(t, user.CheckPassword("password"))
|
||||
}
|
||||
|
||||
func TestAuthService_ValidateToken_ExpiredReturnsClaimsWithError(t *testing.T) {
|
||||
repo := &userRepoStub{}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
// 创建用户并生成 token
|
||||
user := &User{
|
||||
ID: 1,
|
||||
Email: "test@test.com",
|
||||
Role: RoleUser,
|
||||
Status: StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
token, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证有效 token
|
||||
claims, err := service.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, claims)
|
||||
require.Equal(t, int64(1), claims.UserID)
|
||||
|
||||
// 模拟过期 token(通过创建一个过期很久的 token)
|
||||
service.cfg.JWT.ExpireHour = -1 // 设置为负数使 token 立即过期
|
||||
expiredToken, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
service.cfg.JWT.ExpireHour = 1 // 恢复
|
||||
|
||||
// 验证过期 token 应返回 claims 和 ErrTokenExpired
|
||||
claims, err = service.ValidateToken(expiredToken)
|
||||
require.ErrorIs(t, err, ErrTokenExpired)
|
||||
require.NotNil(t, claims, "claims should not be nil when token is expired")
|
||||
require.Equal(t, int64(1), claims.UserID)
|
||||
require.Equal(t, "test@test.com", claims.Email)
|
||||
}
|
||||
|
||||
func TestAuthService_RefreshToken_ExpiredTokenNoPanic(t *testing.T) {
|
||||
user := &User{
|
||||
ID: 1,
|
||||
Email: "test@test.com",
|
||||
Role: RoleUser,
|
||||
Status: StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
repo := &userRepoStub{user: user}
|
||||
service := newAuthService(repo, nil, nil)
|
||||
|
||||
// 创建过期 token
|
||||
service.cfg.JWT.ExpireHour = -1
|
||||
expiredToken, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
service.cfg.JWT.ExpireHour = 1
|
||||
|
||||
// RefreshToken 使用过期 token 不应 panic
|
||||
require.NotPanics(t, func() {
|
||||
newToken, err := service.RefreshToken(context.Background(), expiredToken)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, newToken)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthService_GetAccessTokenExpiresIn_FallbackToExpireHour(t *testing.T) {
|
||||
service := newAuthService(&userRepoStub{}, nil, nil)
|
||||
service.cfg.JWT.ExpireHour = 24
|
||||
service.cfg.JWT.AccessTokenExpireMinutes = 0
|
||||
|
||||
require.Equal(t, 24*3600, service.GetAccessTokenExpiresIn())
|
||||
}
|
||||
|
||||
func TestAuthService_GetAccessTokenExpiresIn_MinutesHasPriority(t *testing.T) {
|
||||
service := newAuthService(&userRepoStub{}, nil, nil)
|
||||
service.cfg.JWT.ExpireHour = 24
|
||||
service.cfg.JWT.AccessTokenExpireMinutes = 90
|
||||
|
||||
require.Equal(t, 90*60, service.GetAccessTokenExpiresIn())
|
||||
}
|
||||
|
||||
func TestAuthService_GenerateToken_UsesExpireHourWhenMinutesZero(t *testing.T) {
|
||||
service := newAuthService(&userRepoStub{}, nil, nil)
|
||||
service.cfg.JWT.ExpireHour = 24
|
||||
service.cfg.JWT.AccessTokenExpireMinutes = 0
|
||||
|
||||
user := &User{
|
||||
ID: 1,
|
||||
Email: "test@test.com",
|
||||
Role: RoleUser,
|
||||
Status: StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
|
||||
token, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := service.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, claims)
|
||||
require.NotNil(t, claims.IssuedAt)
|
||||
require.NotNil(t, claims.ExpiresAt)
|
||||
|
||||
require.WithinDuration(t, claims.IssuedAt.Time.Add(24*time.Hour), claims.ExpiresAt.Time, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestAuthService_GenerateToken_UsesMinutesWhenConfigured(t *testing.T) {
|
||||
service := newAuthService(&userRepoStub{}, nil, nil)
|
||||
service.cfg.JWT.ExpireHour = 24
|
||||
service.cfg.JWT.AccessTokenExpireMinutes = 90
|
||||
|
||||
user := &User{
|
||||
ID: 2,
|
||||
Email: "test2@test.com",
|
||||
Role: RoleUser,
|
||||
Status: StatusActive,
|
||||
TokenVersion: 1,
|
||||
}
|
||||
|
||||
token, err := service.GenerateToken(user)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := service.ValidateToken(token)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, claims)
|
||||
require.NotNil(t, claims.IssuedAt)
|
||||
require.NotNil(t, claims.ExpiresAt)
|
||||
|
||||
require.WithinDuration(t, claims.IssuedAt.Time.Add(90*time.Minute), claims.ExpiresAt.Time, 2*time.Second)
|
||||
}
|
||||
|
||||
func TestAuthService_Register_AssignsDefaultSubscriptions(t *testing.T) {
|
||||
repo := &userRepoStub{nextID: 42}
|
||||
assigner := &defaultSubscriptionAssignerStub{}
|
||||
service := newAuthService(repo, map[string]string{
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
SettingKeyDefaultSubscriptions: `[{"group_id":11,"validity_days":30},{"group_id":12,"validity_days":7}]`,
|
||||
}, nil)
|
||||
service.defaultSubAssigner = assigner
|
||||
|
||||
_, user, err := service.Register(context.Background(), "default-sub@test.com", "password")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, user)
|
||||
require.Len(t, assigner.calls, 2)
|
||||
require.Equal(t, int64(42), assigner.calls[0].UserID)
|
||||
require.Equal(t, int64(11), assigner.calls[0].GroupID)
|
||||
require.Equal(t, 30, assigner.calls[0].ValidityDays)
|
||||
require.Equal(t, int64(12), assigner.calls[1].GroupID)
|
||||
require.Equal(t, 7, assigner.calls[1].ValidityDays)
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type turnstileVerifierSpy struct {
|
||||
called int
|
||||
lastToken string
|
||||
result *TurnstileVerifyResponse
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *turnstileVerifierSpy) VerifyToken(_ context.Context, _ string, token, _ string) (*TurnstileVerifyResponse, error) {
|
||||
s.called++
|
||||
s.lastToken = token
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
if s.result != nil {
|
||||
return s.result, nil
|
||||
}
|
||||
return &TurnstileVerifyResponse{Success: true}, nil
|
||||
}
|
||||
|
||||
func newAuthServiceForRegisterTurnstileTest(settings map[string]string, verifier TurnstileVerifier) *AuthService {
|
||||
cfg := &config.Config{
|
||||
Server: config.ServerConfig{
|
||||
Mode: "release",
|
||||
},
|
||||
Turnstile: config.TurnstileConfig{
|
||||
Required: true,
|
||||
},
|
||||
}
|
||||
|
||||
settingService := NewSettingService(&settingRepoStub{values: settings}, cfg)
|
||||
turnstileService := NewTurnstileService(settingService, verifier)
|
||||
|
||||
return NewAuthService(
|
||||
nil, // entClient
|
||||
&userRepoStub{},
|
||||
nil, // redeemRepo
|
||||
nil, // refreshTokenCache
|
||||
cfg,
|
||||
settingService,
|
||||
nil, // emailService
|
||||
turnstileService,
|
||||
nil, // emailQueueService
|
||||
nil, // promoService
|
||||
nil, // defaultSubAssigner
|
||||
)
|
||||
}
|
||||
|
||||
func TestAuthService_VerifyTurnstileForRegister_SkipWhenEmailVerifyCodeProvided(t *testing.T) {
|
||||
verifier := &turnstileVerifierSpy{}
|
||||
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
SettingKeyTurnstileEnabled: "true",
|
||||
SettingKeyTurnstileSecretKey: "secret",
|
||||
SettingKeyRegistrationEnabled: "true",
|
||||
}, verifier)
|
||||
|
||||
err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "123456")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, verifier.called)
|
||||
}
|
||||
|
||||
func TestAuthService_VerifyTurnstileForRegister_RequireWhenVerifyCodeMissing(t *testing.T) {
|
||||
verifier := &turnstileVerifierSpy{}
|
||||
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
|
||||
SettingKeyEmailVerifyEnabled: "true",
|
||||
SettingKeyTurnstileEnabled: "true",
|
||||
SettingKeyTurnstileSecretKey: "secret",
|
||||
}, verifier)
|
||||
|
||||
err := service.VerifyTurnstileForRegister(context.Background(), "", "127.0.0.1", "")
|
||||
require.ErrorIs(t, err, ErrTurnstileVerificationFailed)
|
||||
}
|
||||
|
||||
func TestAuthService_VerifyTurnstileForRegister_NoSkipWhenEmailVerifyDisabled(t *testing.T) {
|
||||
verifier := &turnstileVerifierSpy{}
|
||||
service := newAuthServiceForRegisterTurnstileTest(map[string]string{
|
||||
SettingKeyEmailVerifyEnabled: "false",
|
||||
SettingKeyTurnstileEnabled: "true",
|
||||
SettingKeyTurnstileSecretKey: "secret",
|
||||
}, verifier)
|
||||
|
||||
err := service.VerifyTurnstileForRegister(context.Background(), "turnstile-token", "127.0.0.1", "123456")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, verifier.called)
|
||||
require.Equal(t, "turnstile-token", verifier.lastToken)
|
||||
}
|
||||
@@ -0,0 +1,770 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/robfig/cron/v3"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
settingKeyBackupS3Config = "backup_s3_config"
|
||||
settingKeyBackupSchedule = "backup_schedule"
|
||||
settingKeyBackupRecords = "backup_records"
|
||||
|
||||
maxBackupRecords = 100
|
||||
)
|
||||
|
||||
var (
|
||||
ErrBackupS3NotConfigured = infraerrors.BadRequest("BACKUP_S3_NOT_CONFIGURED", "backup S3 storage is not configured")
|
||||
ErrBackupNotFound = infraerrors.NotFound("BACKUP_NOT_FOUND", "backup record not found")
|
||||
ErrBackupInProgress = infraerrors.Conflict("BACKUP_IN_PROGRESS", "a backup is already in progress")
|
||||
ErrRestoreInProgress = infraerrors.Conflict("RESTORE_IN_PROGRESS", "a restore is already in progress")
|
||||
ErrBackupRecordsCorrupt = infraerrors.InternalServer("BACKUP_RECORDS_CORRUPT", "backup records data is corrupted")
|
||||
ErrBackupS3ConfigCorrupt = infraerrors.InternalServer("BACKUP_S3_CONFIG_CORRUPT", "backup S3 config data is corrupted")
|
||||
)
|
||||
|
||||
// ─── 接口定义 ───
|
||||
|
||||
// DBDumper abstracts database dump/restore operations
|
||||
type DBDumper interface {
|
||||
Dump(ctx context.Context) (io.ReadCloser, error)
|
||||
Restore(ctx context.Context, data io.Reader) error
|
||||
}
|
||||
|
||||
// BackupObjectStore abstracts object storage for backup files
|
||||
type BackupObjectStore interface {
|
||||
Upload(ctx context.Context, key string, body io.Reader, contentType string) (sizeBytes int64, err error)
|
||||
Download(ctx context.Context, key string) (io.ReadCloser, error)
|
||||
Delete(ctx context.Context, key string) error
|
||||
PresignURL(ctx context.Context, key string, expiry time.Duration) (string, error)
|
||||
HeadBucket(ctx context.Context) error
|
||||
}
|
||||
|
||||
// BackupObjectStoreFactory creates an object store from S3 config
|
||||
type BackupObjectStoreFactory func(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error)
|
||||
|
||||
// ─── 数据模型 ───
|
||||
|
||||
// BackupS3Config S3 兼容存储配置(支持 Cloudflare R2)
|
||||
type BackupS3Config struct {
|
||||
Endpoint string `json:"endpoint"` // e.g. https://<account_id>.r2.cloudflarestorage.com
|
||||
Region string `json:"region"` // R2 用 "auto"
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key,omitempty"` //nolint:revive // field name follows AWS convention
|
||||
Prefix string `json:"prefix"` // S3 key 前缀,如 "backups/"
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
}
|
||||
|
||||
// IsConfigured 检查必要字段是否已配置
|
||||
func (c *BackupS3Config) IsConfigured() bool {
|
||||
return c.Bucket != "" && c.AccessKeyID != "" && c.SecretAccessKey != ""
|
||||
}
|
||||
|
||||
// BackupScheduleConfig 定时备份配置
|
||||
type BackupScheduleConfig struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
CronExpr string `json:"cron_expr"` // cron 表达式,如 "0 2 * * *" 每天凌晨2点
|
||||
RetainDays int `json:"retain_days"` // 备份文件过期天数,默认14,0=不自动清理
|
||||
RetainCount int `json:"retain_count"` // 最多保留份数,0=不限制
|
||||
}
|
||||
|
||||
// BackupRecord 备份记录
|
||||
type BackupRecord struct {
|
||||
ID string `json:"id"`
|
||||
Status string `json:"status"` // pending, running, completed, failed
|
||||
BackupType string `json:"backup_type"` // postgres
|
||||
FileName string `json:"file_name"`
|
||||
S3Key string `json:"s3_key"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
TriggeredBy string `json:"triggered_by"` // manual, scheduled
|
||||
ErrorMsg string `json:"error_message,omitempty"`
|
||||
StartedAt string `json:"started_at"`
|
||||
FinishedAt string `json:"finished_at,omitempty"`
|
||||
ExpiresAt string `json:"expires_at,omitempty"` // 过期时间
|
||||
}
|
||||
|
||||
// BackupService 数据库备份恢复服务
|
||||
type BackupService struct {
|
||||
settingRepo SettingRepository
|
||||
dbCfg *config.DatabaseConfig
|
||||
encryptor SecretEncryptor
|
||||
storeFactory BackupObjectStoreFactory
|
||||
dumper DBDumper
|
||||
|
||||
mu sync.Mutex
|
||||
store BackupObjectStore
|
||||
s3Cfg *BackupS3Config
|
||||
backingUp bool
|
||||
restoring bool
|
||||
|
||||
recordsMu sync.Mutex // 保护 records 的 load/save 操作
|
||||
|
||||
cronMu sync.Mutex
|
||||
cronSched *cron.Cron
|
||||
cronEntryID cron.EntryID
|
||||
}
|
||||
|
||||
func NewBackupService(
|
||||
settingRepo SettingRepository,
|
||||
cfg *config.Config,
|
||||
encryptor SecretEncryptor,
|
||||
storeFactory BackupObjectStoreFactory,
|
||||
dumper DBDumper,
|
||||
) *BackupService {
|
||||
return &BackupService{
|
||||
settingRepo: settingRepo,
|
||||
dbCfg: &cfg.Database,
|
||||
encryptor: encryptor,
|
||||
storeFactory: storeFactory,
|
||||
dumper: dumper,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动定时备份调度器
|
||||
func (s *BackupService) Start() {
|
||||
s.cronSched = cron.New()
|
||||
s.cronSched.Start()
|
||||
|
||||
// 加载已有的定时配置
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
schedule, err := s.GetSchedule(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 加载定时备份配置失败: %v", err)
|
||||
return
|
||||
}
|
||||
if schedule.Enabled && schedule.CronExpr != "" {
|
||||
if err := s.applyCronSchedule(schedule); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 应用定时备份配置失败: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stop 停止定时备份
|
||||
func (s *BackupService) Stop() {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
if s.cronSched != nil {
|
||||
s.cronSched.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
// ─── S3 配置管理 ───
|
||||
|
||||
func (s *BackupService) GetS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||
cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cfg == nil {
|
||||
return &BackupS3Config{}, nil
|
||||
}
|
||||
// 脱敏返回
|
||||
cfg.SecretAccessKey = ""
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) UpdateS3Config(ctx context.Context, cfg BackupS3Config) (*BackupS3Config, error) {
|
||||
// 如果没提供 secret,保留原有值
|
||||
if cfg.SecretAccessKey == "" {
|
||||
old, _ := s.loadS3Config(ctx)
|
||||
if old != nil {
|
||||
cfg.SecretAccessKey = old.SecretAccessKey
|
||||
}
|
||||
} else {
|
||||
// 加密 SecretAccessKey
|
||||
encrypted, err := s.encryptor.Encrypt(cfg.SecretAccessKey)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encrypt secret: %w", err)
|
||||
}
|
||||
cfg.SecretAccessKey = encrypted
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal s3 config: %w", err)
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, settingKeyBackupS3Config, string(data)); err != nil {
|
||||
return nil, fmt.Errorf("save s3 config: %w", err)
|
||||
}
|
||||
|
||||
// 清除缓存的 S3 客户端
|
||||
s.mu.Lock()
|
||||
s.store = nil
|
||||
s.s3Cfg = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
cfg.SecretAccessKey = ""
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) TestS3Connection(ctx context.Context, cfg BackupS3Config) error {
|
||||
// 如果没提供 secret,用已保存的
|
||||
if cfg.SecretAccessKey == "" {
|
||||
old, _ := s.loadS3Config(ctx)
|
||||
if old != nil {
|
||||
cfg.SecretAccessKey = old.SecretAccessKey
|
||||
}
|
||||
}
|
||||
|
||||
if cfg.Bucket == "" || cfg.AccessKeyID == "" || cfg.SecretAccessKey == "" {
|
||||
return fmt.Errorf("incomplete S3 config: bucket, access_key_id, secret_access_key are required")
|
||||
}
|
||||
|
||||
store, err := s.storeFactory(ctx, &cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return store.HeadBucket(ctx)
|
||||
}
|
||||
|
||||
// ─── 定时备份管理 ───
|
||||
|
||||
func (s *BackupService) GetSchedule(ctx context.Context) (*BackupScheduleConfig, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupSchedule)
|
||||
if err != nil || raw == "" {
|
||||
return &BackupScheduleConfig{}, nil
|
||||
}
|
||||
var cfg BackupScheduleConfig
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
return &BackupScheduleConfig{}, nil
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) UpdateSchedule(ctx context.Context, cfg BackupScheduleConfig) (*BackupScheduleConfig, error) {
|
||||
if cfg.Enabled && cfg.CronExpr == "" {
|
||||
return nil, infraerrors.BadRequest("INVALID_CRON", "cron expression is required when schedule is enabled")
|
||||
}
|
||||
// 验证 cron 表达式
|
||||
if cfg.CronExpr != "" {
|
||||
parser := cron.NewParser(cron.Minute | cron.Hour | cron.Dom | cron.Month | cron.Dow)
|
||||
if _, err := parser.Parse(cfg.CronExpr); err != nil {
|
||||
return nil, infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("invalid cron expression: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshal schedule config: %w", err)
|
||||
}
|
||||
if err := s.settingRepo.Set(ctx, settingKeyBackupSchedule, string(data)); err != nil {
|
||||
return nil, fmt.Errorf("save schedule config: %w", err)
|
||||
}
|
||||
|
||||
// 应用或停止定时任务
|
||||
if cfg.Enabled {
|
||||
if err := s.applyCronSchedule(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
s.removeCronSchedule()
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) applyCronSchedule(cfg *BackupScheduleConfig) error {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
|
||||
if s.cronSched == nil {
|
||||
return fmt.Errorf("cron scheduler not initialized")
|
||||
}
|
||||
|
||||
// 移除旧任务
|
||||
if s.cronEntryID != 0 {
|
||||
s.cronSched.Remove(s.cronEntryID)
|
||||
s.cronEntryID = 0
|
||||
}
|
||||
|
||||
entryID, err := s.cronSched.AddFunc(cfg.CronExpr, func() {
|
||||
s.runScheduledBackup()
|
||||
})
|
||||
if err != nil {
|
||||
return infraerrors.BadRequest("INVALID_CRON", fmt.Sprintf("failed to schedule: %v", err))
|
||||
}
|
||||
s.cronEntryID = entryID
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已启用: %s", cfg.CronExpr)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) removeCronSchedule() {
|
||||
s.cronMu.Lock()
|
||||
defer s.cronMu.Unlock()
|
||||
if s.cronSched != nil && s.cronEntryID != 0 {
|
||||
s.cronSched.Remove(s.cronEntryID)
|
||||
s.cronEntryID = 0
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份已停用")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BackupService) runScheduledBackup() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
// 读取定时备份配置中的过期天数
|
||||
schedule, _ := s.GetSchedule(ctx)
|
||||
expireDays := 14 // 默认14天过期
|
||||
if schedule != nil && schedule.RetainDays > 0 {
|
||||
expireDays = schedule.RetainDays
|
||||
}
|
||||
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 开始执行定时备份, 过期天数: %d", expireDays)
|
||||
record, err := s.CreateBackup(ctx, "scheduled", expireDays)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份失败: %v", err)
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 定时备份完成: id=%s size=%d", record.ID, record.SizeBytes)
|
||||
|
||||
// 清理过期备份(复用已加载的 schedule)
|
||||
if schedule == nil {
|
||||
return
|
||||
}
|
||||
if err := s.cleanupOldBackups(ctx, schedule); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 清理过期备份失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ─── 备份/恢复核心 ───
|
||||
|
||||
// CreateBackup 创建全量数据库备份并上传到 S3(流式处理)
|
||||
// expireDays: 备份过期天数,0=永不过期,默认14天
|
||||
func (s *BackupService) CreateBackup(ctx context.Context, triggeredBy string, expireDays int) (*BackupRecord, error) {
|
||||
s.mu.Lock()
|
||||
if s.backingUp {
|
||||
s.mu.Unlock()
|
||||
return nil, ErrBackupInProgress
|
||||
}
|
||||
s.backingUp = true
|
||||
s.mu.Unlock()
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.backingUp = false
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s3Cfg == nil || !s3Cfg.IsConfigured() {
|
||||
return nil, ErrBackupS3NotConfigured
|
||||
}
|
||||
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("init object store: %w", err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
backupID := uuid.New().String()[:8]
|
||||
fileName := fmt.Sprintf("%s_%s.sql.gz", s.dbCfg.DBName, now.Format("20060102_150405"))
|
||||
s3Key := s.buildS3Key(s3Cfg, fileName)
|
||||
|
||||
var expiresAt string
|
||||
if expireDays > 0 {
|
||||
expiresAt = now.AddDate(0, 0, expireDays).Format(time.RFC3339)
|
||||
}
|
||||
|
||||
record := &BackupRecord{
|
||||
ID: backupID,
|
||||
Status: "running",
|
||||
BackupType: "postgres",
|
||||
FileName: fileName,
|
||||
S3Key: s3Key,
|
||||
TriggeredBy: triggeredBy,
|
||||
StartedAt: now.Format(time.RFC3339),
|
||||
ExpiresAt: expiresAt,
|
||||
}
|
||||
|
||||
// 流式执行: pg_dump -> gzip -> S3 upload
|
||||
dumpReader, err := s.dumper.Dump(ctx)
|
||||
if err != nil {
|
||||
record.Status = "failed"
|
||||
record.ErrorMsg = fmt.Sprintf("pg_dump failed: %v", err)
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
_ = s.saveRecord(ctx, record)
|
||||
return record, fmt.Errorf("pg_dump: %w", err)
|
||||
}
|
||||
|
||||
// 使用 io.Pipe 将 gzip 压缩数据流式传递给 S3 上传
|
||||
pr, pw := io.Pipe()
|
||||
var gzipErr error
|
||||
go func() {
|
||||
gzWriter := gzip.NewWriter(pw)
|
||||
_, gzipErr = io.Copy(gzWriter, dumpReader)
|
||||
if closeErr := gzWriter.Close(); closeErr != nil && gzipErr == nil {
|
||||
gzipErr = closeErr
|
||||
}
|
||||
if closeErr := dumpReader.Close(); closeErr != nil && gzipErr == nil {
|
||||
gzipErr = closeErr
|
||||
}
|
||||
if gzipErr != nil {
|
||||
_ = pw.CloseWithError(gzipErr)
|
||||
} else {
|
||||
_ = pw.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
contentType := "application/gzip"
|
||||
sizeBytes, err := objectStore.Upload(ctx, s3Key, pr, contentType)
|
||||
if err != nil {
|
||||
record.Status = "failed"
|
||||
errMsg := fmt.Sprintf("S3 upload failed: %v", err)
|
||||
if gzipErr != nil {
|
||||
errMsg = fmt.Sprintf("gzip/dump failed: %v", gzipErr)
|
||||
}
|
||||
record.ErrorMsg = errMsg
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
_ = s.saveRecord(ctx, record)
|
||||
return record, fmt.Errorf("backup upload: %w", err)
|
||||
}
|
||||
|
||||
record.SizeBytes = sizeBytes
|
||||
record.Status = "completed"
|
||||
record.FinishedAt = time.Now().Format(time.RFC3339)
|
||||
if err := s.saveRecord(ctx, record); err != nil {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 保存备份记录失败: %v", err)
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// RestoreBackup 从 S3 下载备份并流式恢复到数据库
|
||||
func (s *BackupService) RestoreBackup(ctx context.Context, backupID string) error {
|
||||
s.mu.Lock()
|
||||
if s.restoring {
|
||||
s.mu.Unlock()
|
||||
return ErrRestoreInProgress
|
||||
}
|
||||
s.restoring = true
|
||||
s.mu.Unlock()
|
||||
defer func() {
|
||||
s.mu.Lock()
|
||||
s.restoring = false
|
||||
s.mu.Unlock()
|
||||
}()
|
||||
|
||||
record, err := s.GetBackupRecord(ctx, backupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if record.Status != "completed" {
|
||||
return infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "can only restore from a completed backup")
|
||||
}
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init object store: %w", err)
|
||||
}
|
||||
|
||||
// 从 S3 流式下载
|
||||
body, err := objectStore.Download(ctx, record.S3Key)
|
||||
if err != nil {
|
||||
return fmt.Errorf("S3 download failed: %w", err)
|
||||
}
|
||||
defer func() { _ = body.Close() }()
|
||||
|
||||
// 流式解压 gzip -> psql(不将全部数据加载到内存)
|
||||
gzReader, err := gzip.NewReader(body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("gzip reader: %w", err)
|
||||
}
|
||||
defer func() { _ = gzReader.Close() }()
|
||||
|
||||
// 流式恢复
|
||||
if err := s.dumper.Restore(ctx, gzReader); err != nil {
|
||||
return fmt.Errorf("pg restore: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ─── 备份记录管理 ───
|
||||
|
||||
func (s *BackupService) ListBackups(ctx context.Context) ([]BackupRecord, error) {
|
||||
records, err := s.loadRecords(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 倒序返回(最新在前)
|
||||
sort.Slice(records, func(i, j int) bool {
|
||||
return records[i].StartedAt > records[j].StartedAt
|
||||
})
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) GetBackupRecord(ctx context.Context, backupID string) (*BackupRecord, error) {
|
||||
records, err := s.loadRecords(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range records {
|
||||
if records[i].ID == backupID {
|
||||
return &records[i], nil
|
||||
}
|
||||
}
|
||||
return nil, ErrBackupNotFound
|
||||
}
|
||||
|
||||
func (s *BackupService) DeleteBackup(ctx context.Context, backupID string) error {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, err := s.loadRecordsLocked(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var found *BackupRecord
|
||||
var remaining []BackupRecord
|
||||
for i := range records {
|
||||
if records[i].ID == backupID {
|
||||
found = &records[i]
|
||||
} else {
|
||||
remaining = append(remaining, records[i])
|
||||
}
|
||||
}
|
||||
if found == nil {
|
||||
return ErrBackupNotFound
|
||||
}
|
||||
|
||||
// 从 S3 删除
|
||||
if found.S3Key != "" && found.Status == "completed" {
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err == nil && s3Cfg != nil && s3Cfg.IsConfigured() {
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err == nil {
|
||||
_ = objectStore.Delete(ctx, found.S3Key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return s.saveRecordsLocked(ctx, remaining)
|
||||
}
|
||||
|
||||
// GetBackupDownloadURL 获取备份文件预签名下载 URL
|
||||
func (s *BackupService) GetBackupDownloadURL(ctx context.Context, backupID string) (string, error) {
|
||||
record, err := s.GetBackupRecord(ctx, backupID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if record.Status != "completed" {
|
||||
return "", infraerrors.BadRequest("BACKUP_NOT_COMPLETED", "backup is not completed")
|
||||
}
|
||||
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
url, err := objectStore.PresignURL(ctx, record.S3Key, 1*time.Hour)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("presign url: %w", err)
|
||||
}
|
||||
return url, nil
|
||||
}
|
||||
|
||||
// ─── 内部方法 ───
|
||||
|
||||
func (s *BackupService) loadS3Config(ctx context.Context) (*BackupS3Config, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupS3Config)
|
||||
if err != nil || raw == "" {
|
||||
return nil, nil //nolint:nilnil // no config is a valid state
|
||||
}
|
||||
var cfg BackupS3Config
|
||||
if err := json.Unmarshal([]byte(raw), &cfg); err != nil {
|
||||
return nil, ErrBackupS3ConfigCorrupt
|
||||
}
|
||||
// 解密 SecretAccessKey
|
||||
if cfg.SecretAccessKey != "" {
|
||||
decrypted, err := s.encryptor.Decrypt(cfg.SecretAccessKey)
|
||||
if err != nil {
|
||||
// 兼容未加密的旧数据:如果解密失败,保持原值
|
||||
logger.LegacyPrintf("service.backup", "[Backup] S3 SecretAccessKey 解密失败(可能是旧的未加密数据): %v", err)
|
||||
} else {
|
||||
cfg.SecretAccessKey = decrypted
|
||||
}
|
||||
}
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) getOrCreateStore(ctx context.Context, cfg *BackupS3Config) (BackupObjectStore, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.store != nil && s.s3Cfg != nil {
|
||||
return s.store, nil
|
||||
}
|
||||
|
||||
if cfg == nil {
|
||||
return nil, ErrBackupS3NotConfigured
|
||||
}
|
||||
|
||||
store, err := s.storeFactory(ctx, cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.store = store
|
||||
s.s3Cfg = cfg
|
||||
return store, nil
|
||||
}
|
||||
|
||||
func (s *BackupService) buildS3Key(cfg *BackupS3Config, fileName string) string {
|
||||
prefix := strings.TrimRight(cfg.Prefix, "/")
|
||||
if prefix == "" {
|
||||
prefix = "backups"
|
||||
}
|
||||
return fmt.Sprintf("%s/%s/%s", prefix, time.Now().Format("2006/01/02"), fileName)
|
||||
}
|
||||
|
||||
// loadRecords 加载备份记录,区分"无数据"和"数据损坏"
|
||||
func (s *BackupService) loadRecords(ctx context.Context) ([]BackupRecord, error) {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
return s.loadRecordsLocked(ctx)
|
||||
}
|
||||
|
||||
// loadRecordsLocked 在已持有 recordsMu 锁的情况下加载记录
|
||||
func (s *BackupService) loadRecordsLocked(ctx context.Context) ([]BackupRecord, error) {
|
||||
raw, err := s.settingRepo.GetValue(ctx, settingKeyBackupRecords)
|
||||
if err != nil || raw == "" {
|
||||
return nil, nil //nolint:nilnil // no records is a valid state
|
||||
}
|
||||
var records []BackupRecord
|
||||
if err := json.Unmarshal([]byte(raw), &records); err != nil {
|
||||
return nil, ErrBackupRecordsCorrupt
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
// saveRecordsLocked 在已持有 recordsMu 锁的情况下保存记录
|
||||
func (s *BackupService) saveRecordsLocked(ctx context.Context, records []BackupRecord) error {
|
||||
data, err := json.Marshal(records)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return s.settingRepo.Set(ctx, settingKeyBackupRecords, string(data))
|
||||
}
|
||||
|
||||
// saveRecord 保存单条记录(带互斥锁保护)
|
||||
func (s *BackupService) saveRecord(ctx context.Context, record *BackupRecord) error {
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, _ := s.loadRecordsLocked(ctx)
|
||||
|
||||
// 更新已有记录或追加
|
||||
found := false
|
||||
for i := range records {
|
||||
if records[i].ID == record.ID {
|
||||
records[i] = *record
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
records = append(records, *record)
|
||||
}
|
||||
|
||||
// 限制记录数量
|
||||
if len(records) > maxBackupRecords {
|
||||
records = records[len(records)-maxBackupRecords:]
|
||||
}
|
||||
|
||||
return s.saveRecordsLocked(ctx, records)
|
||||
}
|
||||
|
||||
func (s *BackupService) cleanupOldBackups(ctx context.Context, schedule *BackupScheduleConfig) error {
|
||||
if schedule == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.recordsMu.Lock()
|
||||
defer s.recordsMu.Unlock()
|
||||
|
||||
records, err := s.loadRecordsLocked(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 按时间倒序
|
||||
sort.Slice(records, func(i, j int) bool {
|
||||
return records[i].StartedAt > records[j].StartedAt
|
||||
})
|
||||
|
||||
var toDelete []BackupRecord
|
||||
var toKeep []BackupRecord
|
||||
|
||||
for i, r := range records {
|
||||
shouldDelete := false
|
||||
|
||||
// 按保留份数清理
|
||||
if schedule.RetainCount > 0 && i >= schedule.RetainCount {
|
||||
shouldDelete = true
|
||||
}
|
||||
|
||||
// 按保留天数清理
|
||||
if schedule.RetainDays > 0 && r.StartedAt != "" {
|
||||
startedAt, err := time.Parse(time.RFC3339, r.StartedAt)
|
||||
if err == nil && time.Since(startedAt) > time.Duration(schedule.RetainDays)*24*time.Hour {
|
||||
shouldDelete = true
|
||||
}
|
||||
}
|
||||
|
||||
if shouldDelete && r.Status == "completed" {
|
||||
toDelete = append(toDelete, r)
|
||||
} else {
|
||||
toKeep = append(toKeep, r)
|
||||
}
|
||||
}
|
||||
|
||||
// 删除 S3 上的文件
|
||||
for _, r := range toDelete {
|
||||
if r.S3Key != "" {
|
||||
_ = s.deleteS3Object(ctx, r.S3Key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(toDelete) > 0 {
|
||||
logger.LegacyPrintf("service.backup", "[Backup] 自动清理了 %d 个过期备份", len(toDelete))
|
||||
return s.saveRecordsLocked(ctx, toKeep)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *BackupService) deleteS3Object(ctx context.Context, key string) error {
|
||||
s3Cfg, err := s.loadS3Config(ctx)
|
||||
if err != nil || s3Cfg == nil {
|
||||
return nil
|
||||
}
|
||||
objectStore, err := s.getOrCreateStore(ctx, s3Cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return objectStore.Delete(ctx, key)
|
||||
}
|
||||
@@ -0,0 +1,528 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// ─── Mocks ───
|
||||
|
||||
type mockSettingRepo struct {
|
||||
mu sync.Mutex
|
||||
data map[string]string
|
||||
}
|
||||
|
||||
func newMockSettingRepo() *mockSettingRepo {
|
||||
return &mockSettingRepo{data: make(map[string]string)}
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Get(_ context.Context, key string) (*Setting, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
v, ok := m.data[key]
|
||||
if !ok {
|
||||
return nil, ErrSettingNotFound
|
||||
}
|
||||
return &Setting{Key: key, Value: v}, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetValue(_ context.Context, key string) (string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
v, ok := m.data[key]
|
||||
if !ok {
|
||||
return "", nil
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Set(_ context.Context, key, value string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetMultiple(_ context.Context, keys []string) (map[string]string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
result := make(map[string]string)
|
||||
for _, k := range keys {
|
||||
if v, ok := m.data[k]; ok {
|
||||
result[k] = v
|
||||
}
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) SetMultiple(_ context.Context, settings map[string]string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
for k, v := range settings {
|
||||
m.data[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) GetAll(_ context.Context) (map[string]string, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
result := make(map[string]string, len(m.data))
|
||||
for k, v := range m.data {
|
||||
result[k] = v
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (m *mockSettingRepo) Delete(_ context.Context, key string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.data, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
// plainEncryptor 仅做 base64-like 包装,用于测试
|
||||
type plainEncryptor struct{}
|
||||
|
||||
func (e *plainEncryptor) Encrypt(plaintext string) (string, error) {
|
||||
return "ENC:" + plaintext, nil
|
||||
}
|
||||
|
||||
func (e *plainEncryptor) Decrypt(ciphertext string) (string, error) {
|
||||
if strings.HasPrefix(ciphertext, "ENC:") {
|
||||
return strings.TrimPrefix(ciphertext, "ENC:"), nil
|
||||
}
|
||||
return ciphertext, fmt.Errorf("not encrypted")
|
||||
}
|
||||
|
||||
type mockDumper struct {
|
||||
dumpData []byte
|
||||
dumpErr error
|
||||
restored []byte
|
||||
restErr error
|
||||
}
|
||||
|
||||
func (m *mockDumper) Dump(_ context.Context) (io.ReadCloser, error) {
|
||||
if m.dumpErr != nil {
|
||||
return nil, m.dumpErr
|
||||
}
|
||||
return io.NopCloser(bytes.NewReader(m.dumpData)), nil
|
||||
}
|
||||
|
||||
func (m *mockDumper) Restore(_ context.Context, data io.Reader) error {
|
||||
if m.restErr != nil {
|
||||
return m.restErr
|
||||
}
|
||||
d, err := io.ReadAll(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.restored = d
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockObjectStore struct {
|
||||
objects map[string][]byte
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func newMockObjectStore() *mockObjectStore {
|
||||
return &mockObjectStore{objects: make(map[string][]byte)}
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Upload(_ context.Context, key string, body io.Reader, _ string) (int64, error) {
|
||||
data, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.objects[key] = data
|
||||
m.mu.Unlock()
|
||||
return int64(len(data)), nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Download(_ context.Context, key string) (io.ReadCloser, error) {
|
||||
m.mu.Lock()
|
||||
data, ok := m.objects[key]
|
||||
m.mu.Unlock()
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("not found: %s", key)
|
||||
}
|
||||
return io.NopCloser(bytes.NewReader(data)), nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) Delete(_ context.Context, key string) error {
|
||||
m.mu.Lock()
|
||||
delete(m.objects, key)
|
||||
m.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) PresignURL(_ context.Context, key string, _ time.Duration) (string, error) {
|
||||
return "https://presigned.example.com/" + key, nil
|
||||
}
|
||||
|
||||
func (m *mockObjectStore) HeadBucket(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func newTestBackupService(repo *mockSettingRepo, dumper *mockDumper, store *mockObjectStore) *BackupService {
|
||||
cfg := &config.Config{
|
||||
Database: config.DatabaseConfig{
|
||||
Host: "localhost",
|
||||
Port: 5432,
|
||||
User: "test",
|
||||
DBName: "testdb",
|
||||
},
|
||||
}
|
||||
factory := func(_ context.Context, _ *BackupS3Config) (BackupObjectStore, error) {
|
||||
return store, nil
|
||||
}
|
||||
return NewBackupService(repo, cfg, &plainEncryptor{}, factory, dumper)
|
||||
}
|
||||
|
||||
func seedS3Config(t *testing.T, repo *mockSettingRepo) {
|
||||
t.Helper()
|
||||
cfg := BackupS3Config{
|
||||
Bucket: "test-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "ENC:secret123",
|
||||
Prefix: "backups",
|
||||
}
|
||||
data, _ := json.Marshal(cfg)
|
||||
require.NoError(t, repo.Set(context.Background(), settingKeyBackupS3Config, string(data)))
|
||||
}
|
||||
|
||||
// ─── Tests ───
|
||||
|
||||
func TestBackupService_S3ConfigEncryption(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 保存配置 -> SecretAccessKey 应被加密
|
||||
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "my-secret",
|
||||
Prefix: "backups",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 直接读取数据库中存储的值,应该是加密后的
|
||||
raw, _ := repo.GetValue(context.Background(), settingKeyBackupS3Config)
|
||||
var stored BackupS3Config
|
||||
require.NoError(t, json.Unmarshal([]byte(raw), &stored))
|
||||
require.Equal(t, "ENC:my-secret", stored.SecretAccessKey)
|
||||
|
||||
// 通过 GetS3Config 获取应该脱敏
|
||||
cfg, err := svc.GetS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, cfg.SecretAccessKey)
|
||||
require.Equal(t, "my-bucket", cfg.Bucket)
|
||||
|
||||
// loadS3Config 内部应解密
|
||||
internal, err := svc.loadS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "my-secret", internal.SecretAccessKey)
|
||||
}
|
||||
|
||||
func TestBackupService_S3ConfigKeepExistingSecret(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 先保存一个有 secret 的配置
|
||||
_, err := svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID",
|
||||
SecretAccessKey: "original-secret",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 再更新时不提供 secret,应保留原值
|
||||
_, err = svc.UpdateS3Config(context.Background(), BackupS3Config{
|
||||
Bucket: "my-bucket",
|
||||
AccessKeyID: "AKID-NEW",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
internal, err := svc.loadS3Config(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "original-secret", internal.SecretAccessKey)
|
||||
require.Equal(t, "AKID-NEW", internal.AccessKeyID)
|
||||
}
|
||||
|
||||
func TestBackupService_SaveRecordConcurrency(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
var wg sync.WaitGroup
|
||||
n := 20
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
record := &BackupRecord{
|
||||
ID: fmt.Sprintf("rec-%d", idx),
|
||||
Status: "completed",
|
||||
StartedAt: time.Now().Format(time.RFC3339),
|
||||
}
|
||||
_ = svc.saveRecord(context.Background(), record)
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, records, n)
|
||||
}
|
||||
|
||||
func TestBackupService_LoadRecords_Empty(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, records) // 无数据时返回 nil
|
||||
}
|
||||
|
||||
func TestBackupService_LoadRecords_Corrupted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
_ = repo.Set(context.Background(), settingKeyBackupRecords, "not valid json{{{")
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
records, err := svc.loadRecords(context.Background())
|
||||
require.Error(t, err) // 损坏数据应返回错误
|
||||
require.Nil(t, records)
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_Streaming(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "completed", record.Status)
|
||||
require.Greater(t, record.SizeBytes, int64(0))
|
||||
require.NotEmpty(t, record.S3Key)
|
||||
|
||||
// 验证 S3 上确实有文件
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 1)
|
||||
store.mu.Unlock()
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_DumpFailure(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumper := &mockDumper{dumpErr: fmt.Errorf("pg_dump failed")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, "failed", record.Status)
|
||||
require.Contains(t, record.ErrorMsg, "pg_dump")
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_NoS3Config(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.ErrorIs(t, err, ErrBackupS3NotConfigured)
|
||||
}
|
||||
|
||||
func TestBackupService_CreateBackup_ConcurrentBlocked(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
// 使用一个慢速 dumper 来模拟正在进行的备份
|
||||
dumper := &mockDumper{dumpData: []byte("data")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
// 手动设置 backingUp 标志
|
||||
svc.mu.Lock()
|
||||
svc.backingUp = true
|
||||
svc.mu.Unlock()
|
||||
|
||||
_, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.ErrorIs(t, err, ErrBackupInProgress)
|
||||
}
|
||||
|
||||
func TestBackupService_RestoreBackup_Streaming(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "-- PostgreSQL dump\nCREATE TABLE test (id int);\n"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
// 先创建一个备份
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 恢复
|
||||
err = svc.RestoreBackup(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 验证 psql 收到的数据是否与原始 dump 内容一致
|
||||
require.Equal(t, dumpContent, string(dumper.restored))
|
||||
}
|
||||
|
||||
func TestBackupService_RestoreBackup_NotCompleted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
// 手动插入一条 failed 记录
|
||||
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||
ID: "fail-1",
|
||||
Status: "failed",
|
||||
})
|
||||
|
||||
err := svc.RestoreBackup(context.Background(), "fail-1")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_DeleteBackup(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumpContent := "data"
|
||||
dumper := &mockDumper{dumpData: []byte(dumpContent)}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S3 中应有文件
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 1)
|
||||
store.mu.Unlock()
|
||||
|
||||
// 删除
|
||||
err = svc.DeleteBackup(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// S3 中文件应被删除
|
||||
store.mu.Lock()
|
||||
require.Len(t, store.objects, 0)
|
||||
store.mu.Unlock()
|
||||
|
||||
// 记录应不存在
|
||||
_, err = svc.GetBackupRecord(context.Background(), record.ID)
|
||||
require.ErrorIs(t, err, ErrBackupNotFound)
|
||||
}
|
||||
|
||||
func TestBackupService_GetDownloadURL(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
seedS3Config(t, repo)
|
||||
|
||||
dumper := &mockDumper{dumpData: []byte("data")}
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, dumper, store)
|
||||
|
||||
record, err := svc.CreateBackup(context.Background(), "manual", 14)
|
||||
require.NoError(t, err)
|
||||
|
||||
url, err := svc.GetBackupDownloadURL(context.Background(), record.ID)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, url, "https://presigned.example.com/")
|
||||
}
|
||||
|
||||
func TestBackupService_ListBackups_Sorted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
now := time.Now()
|
||||
for i := 0; i < 3; i++ {
|
||||
_ = svc.saveRecord(context.Background(), &BackupRecord{
|
||||
ID: fmt.Sprintf("rec-%d", i),
|
||||
Status: "completed",
|
||||
StartedAt: now.Add(time.Duration(i) * time.Hour).Format(time.RFC3339),
|
||||
})
|
||||
}
|
||||
|
||||
records, err := svc.ListBackups(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Len(t, records, 3)
|
||||
// 最新在前
|
||||
require.Equal(t, "rec-2", records[0].ID)
|
||||
require.Equal(t, "rec-0", records[2].ID)
|
||||
}
|
||||
|
||||
func TestBackupService_TestS3Connection(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
store := newMockObjectStore()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, store)
|
||||
|
||||
err := svc.TestS3Connection(context.Background(), BackupS3Config{
|
||||
Bucket: "test",
|
||||
AccessKeyID: "ak",
|
||||
SecretAccessKey: "sk",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_TestS3Connection_Incomplete(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
err := svc.TestS3Connection(context.Background(), BackupS3Config{
|
||||
Bucket: "test",
|
||||
})
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "incomplete")
|
||||
}
|
||||
|
||||
func TestBackupService_Schedule_CronValidation(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
svc.cronSched = nil // 未初始化 cron
|
||||
|
||||
// 启用但 cron 为空
|
||||
_, err := svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
|
||||
Enabled: true,
|
||||
CronExpr: "",
|
||||
})
|
||||
require.Error(t, err)
|
||||
|
||||
// 无效的 cron 表达式
|
||||
_, err = svc.UpdateSchedule(context.Background(), BackupScheduleConfig{
|
||||
Enabled: true,
|
||||
CronExpr: "invalid",
|
||||
})
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestBackupService_LoadS3Config_Corrupted(t *testing.T) {
|
||||
repo := newMockSettingRepo()
|
||||
_ = repo.Set(context.Background(), settingKeyBackupS3Config, "not json!!!!")
|
||||
svc := newTestBackupService(repo, &mockDumper{}, newMockObjectStore())
|
||||
|
||||
cfg, err := svc.loadS3Config(context.Background())
|
||||
require.Error(t, err)
|
||||
require.Nil(t, cfg)
|
||||
}
|
||||
@@ -0,0 +1,607 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/domain"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const defaultBedrockRegion = "us-east-1"
|
||||
|
||||
var bedrockCrossRegionPrefixes = []string{"us.", "eu.", "apac.", "jp.", "au.", "us-gov.", "global."}
|
||||
|
||||
// BedrockCrossRegionPrefix 根据 AWS Region 返回 Bedrock 跨区域推理的模型 ID 前缀
|
||||
// 参考: https://docs.aws.amazon.com/bedrock/latest/userguide/inference-profiles-support.html
|
||||
func BedrockCrossRegionPrefix(region string) string {
|
||||
switch {
|
||||
case strings.HasPrefix(region, "us-gov"):
|
||||
return "us-gov" // GovCloud 使用独立的 us-gov 前缀
|
||||
case strings.HasPrefix(region, "us-"):
|
||||
return "us"
|
||||
case strings.HasPrefix(region, "eu-"):
|
||||
return "eu"
|
||||
case region == "ap-northeast-1":
|
||||
return "jp" // 日本区域使用独立的 jp 前缀(AWS 官方定义)
|
||||
case region == "ap-southeast-2":
|
||||
return "au" // 澳大利亚区域使用独立的 au 前缀(AWS 官方定义)
|
||||
case strings.HasPrefix(region, "ap-"):
|
||||
return "apac" // 其余亚太区域使用通用 apac 前缀
|
||||
case strings.HasPrefix(region, "ca-"):
|
||||
return "us" // 加拿大区域使用 us 前缀的跨区域推理
|
||||
case strings.HasPrefix(region, "sa-"):
|
||||
return "us" // 南美区域使用 us 前缀的跨区域推理
|
||||
default:
|
||||
return "us"
|
||||
}
|
||||
}
|
||||
|
||||
// AdjustBedrockModelRegionPrefix 将模型 ID 的区域前缀替换为与当前 AWS Region 匹配的前缀
|
||||
// 例如 region=eu-west-1 时,"us.anthropic.claude-opus-4-6-v1" → "eu.anthropic.claude-opus-4-6-v1"
|
||||
// 特殊值 region="global" 强制使用 global. 前缀
|
||||
func AdjustBedrockModelRegionPrefix(modelID, region string) string {
|
||||
var targetPrefix string
|
||||
if region == "global" {
|
||||
targetPrefix = "global"
|
||||
} else {
|
||||
targetPrefix = BedrockCrossRegionPrefix(region)
|
||||
}
|
||||
|
||||
for _, p := range bedrockCrossRegionPrefixes {
|
||||
if strings.HasPrefix(modelID, p) {
|
||||
if p == targetPrefix+"." {
|
||||
return modelID // 前缀已匹配,无需替换
|
||||
}
|
||||
return targetPrefix + "." + modelID[len(p):]
|
||||
}
|
||||
}
|
||||
|
||||
// 模型 ID 没有已知区域前缀(如 "anthropic.claude-..."),不做修改
|
||||
return modelID
|
||||
}
|
||||
|
||||
func bedrockRuntimeRegion(account *Account) string {
|
||||
if account == nil {
|
||||
return defaultBedrockRegion
|
||||
}
|
||||
if region := account.GetCredential("aws_region"); region != "" {
|
||||
return region
|
||||
}
|
||||
return defaultBedrockRegion
|
||||
}
|
||||
|
||||
func shouldForceBedrockGlobal(account *Account) bool {
|
||||
return account != nil && account.GetCredential("aws_force_global") == "true"
|
||||
}
|
||||
|
||||
func isRegionalBedrockModelID(modelID string) bool {
|
||||
for _, prefix := range bedrockCrossRegionPrefixes {
|
||||
if strings.HasPrefix(modelID, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isLikelyBedrockModelID(modelID string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(modelID))
|
||||
if lower == "" {
|
||||
return false
|
||||
}
|
||||
if strings.HasPrefix(lower, "arn:") {
|
||||
return true
|
||||
}
|
||||
for _, prefix := range []string{
|
||||
"anthropic.",
|
||||
"amazon.",
|
||||
"meta.",
|
||||
"mistral.",
|
||||
"cohere.",
|
||||
"ai21.",
|
||||
"deepseek.",
|
||||
"stability.",
|
||||
"writer.",
|
||||
"nova.",
|
||||
} {
|
||||
if strings.HasPrefix(lower, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return isRegionalBedrockModelID(lower)
|
||||
}
|
||||
|
||||
func normalizeBedrockModelID(modelID string) (normalized string, shouldAdjustRegion bool, ok bool) {
|
||||
modelID = strings.TrimSpace(modelID)
|
||||
if modelID == "" {
|
||||
return "", false, false
|
||||
}
|
||||
if mapped, exists := domain.DefaultBedrockModelMapping[modelID]; exists {
|
||||
return mapped, true, true
|
||||
}
|
||||
if isRegionalBedrockModelID(modelID) {
|
||||
return modelID, true, true
|
||||
}
|
||||
if isLikelyBedrockModelID(modelID) {
|
||||
return modelID, false, true
|
||||
}
|
||||
return "", false, false
|
||||
}
|
||||
|
||||
// ResolveBedrockModelID resolves a requested Claude model into a Bedrock model ID.
|
||||
// It applies account model_mapping first, then default Bedrock aliases, and finally
|
||||
// adjusts Anthropic cross-region prefixes to match the account region.
|
||||
func ResolveBedrockModelID(account *Account, requestedModel string) (string, bool) {
|
||||
if account == nil {
|
||||
return "", false
|
||||
}
|
||||
|
||||
mappedModel := account.GetMappedModel(requestedModel)
|
||||
modelID, shouldAdjustRegion, ok := normalizeBedrockModelID(mappedModel)
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if shouldAdjustRegion {
|
||||
targetRegion := bedrockRuntimeRegion(account)
|
||||
if shouldForceBedrockGlobal(account) {
|
||||
targetRegion = "global"
|
||||
}
|
||||
modelID = AdjustBedrockModelRegionPrefix(modelID, targetRegion)
|
||||
}
|
||||
return modelID, true
|
||||
}
|
||||
|
||||
// BuildBedrockURL 构建 Bedrock InvokeModel 的 URL
|
||||
// stream=true 时使用 invoke-with-response-stream 端点
|
||||
// modelID 中的特殊字符会被 URL 编码(与 litellm 的 urllib.parse.quote(safe="") 对齐)
|
||||
func BuildBedrockURL(region, modelID string, stream bool) string {
|
||||
if region == "" {
|
||||
region = defaultBedrockRegion
|
||||
}
|
||||
encodedModelID := url.PathEscape(modelID)
|
||||
// url.PathEscape 不编码冒号(RFC 允许 path 中出现 ":"),
|
||||
// 但 AWS Bedrock 期望模型 ID 中的冒号被编码为 %3A
|
||||
encodedModelID = strings.ReplaceAll(encodedModelID, ":", "%3A")
|
||||
if stream {
|
||||
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke-with-response-stream", region, encodedModelID)
|
||||
}
|
||||
return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/invoke", region, encodedModelID)
|
||||
}
|
||||
|
||||
// PrepareBedrockRequestBody 处理请求体以适配 Bedrock API
|
||||
// 1. 注入 anthropic_version
|
||||
// 2. 注入 anthropic_beta(从客户端 anthropic-beta 头解析)
|
||||
// 3. 移除 Bedrock 不支持的字段(model, stream, output_format, output_config)
|
||||
// 4. 移除工具定义中的 custom 字段(Claude Code 会发送 custom: {defer_loading: true})
|
||||
// 5. 清理 cache_control 中 Bedrock 不支持的字段(scope, ttl)
|
||||
func PrepareBedrockRequestBody(body []byte, modelID string, betaHeader string) ([]byte, error) {
|
||||
betaTokens := ResolveBedrockBetaTokens(betaHeader, body, modelID)
|
||||
return PrepareBedrockRequestBodyWithTokens(body, modelID, betaTokens)
|
||||
}
|
||||
|
||||
// PrepareBedrockRequestBodyWithTokens prepares a Bedrock request using pre-resolved beta tokens.
|
||||
func PrepareBedrockRequestBodyWithTokens(body []byte, modelID string, betaTokens []string) ([]byte, error) {
|
||||
var err error
|
||||
|
||||
// 注入 anthropic_version(Bedrock 要求)
|
||||
body, err = sjson.SetBytes(body, "anthropic_version", "bedrock-2023-05-31")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("inject anthropic_version: %w", err)
|
||||
}
|
||||
|
||||
// 注入 anthropic_beta(Bedrock Invoke 通过请求体传递 beta 头,而非 HTTP 头)
|
||||
// 1. 从客户端 anthropic-beta header 解析
|
||||
// 2. 根据请求体内容自动补齐必要的 beta token
|
||||
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() + _get_tool_search_beta_header_for_bedrock()
|
||||
if len(betaTokens) > 0 {
|
||||
body, err = sjson.SetBytes(body, "anthropic_beta", betaTokens)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("inject anthropic_beta: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 移除 model 字段(Bedrock 通过 URL 指定模型)
|
||||
body, err = sjson.DeleteBytes(body, "model")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("remove model field: %w", err)
|
||||
}
|
||||
|
||||
// 移除 stream 字段(Bedrock 通过不同端点控制流式,不接受请求体中的 stream 字段)
|
||||
body, err = sjson.DeleteBytes(body, "stream")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("remove stream field: %w", err)
|
||||
}
|
||||
|
||||
// 转换 output_format(Bedrock Invoke 不支持此字段,但可将 schema 内联到最后一条 user message)
|
||||
// 参考 litellm: _convert_output_format_to_inline_schema()
|
||||
body = convertOutputFormatToInlineSchema(body)
|
||||
|
||||
// 移除 output_config 字段(Bedrock Invoke 不支持)
|
||||
body, err = sjson.DeleteBytes(body, "output_config")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("remove output_config field: %w", err)
|
||||
}
|
||||
|
||||
// 移除工具定义中的 custom 字段
|
||||
// Claude Code (v2.1.69+) 在 tool 定义中发送 custom: {defer_loading: true},
|
||||
// Anthropic API 接受但 Bedrock 会拒绝并报 "Extra inputs are not permitted"
|
||||
body = removeCustomFieldFromTools(body)
|
||||
|
||||
// 清理 cache_control 中 Bedrock 不支持的字段
|
||||
body = sanitizeBedrockCacheControl(body, modelID)
|
||||
|
||||
return body, nil
|
||||
}
|
||||
|
||||
// ResolveBedrockBetaTokens computes the final Bedrock beta token list before policy filtering.
|
||||
func ResolveBedrockBetaTokens(betaHeader string, body []byte, modelID string) []string {
|
||||
betaTokens := parseAnthropicBetaHeader(betaHeader)
|
||||
betaTokens = autoInjectBedrockBetaTokens(betaTokens, body, modelID)
|
||||
return filterBedrockBetaTokens(betaTokens)
|
||||
}
|
||||
|
||||
// convertOutputFormatToInlineSchema 将 output_format 中的 JSON schema 内联到最后一条 user message
|
||||
// Bedrock Invoke 不支持 output_format 参数,litellm 的做法是将 schema 追加到用户消息中
|
||||
// 参考: litellm AmazonAnthropicClaudeMessagesConfig._convert_output_format_to_inline_schema()
|
||||
func convertOutputFormatToInlineSchema(body []byte) []byte {
|
||||
outputFormat := gjson.GetBytes(body, "output_format")
|
||||
if !outputFormat.Exists() || !outputFormat.IsObject() {
|
||||
return body
|
||||
}
|
||||
|
||||
// 先从请求体中移除 output_format
|
||||
body, _ = sjson.DeleteBytes(body, "output_format")
|
||||
|
||||
schema := outputFormat.Get("schema")
|
||||
if !schema.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
// 找到最后一条 user message
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body
|
||||
}
|
||||
msgArr := messages.Array()
|
||||
lastUserIdx := -1
|
||||
for i := len(msgArr) - 1; i >= 0; i-- {
|
||||
if msgArr[i].Get("role").String() == "user" {
|
||||
lastUserIdx = i
|
||||
break
|
||||
}
|
||||
}
|
||||
if lastUserIdx < 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
// 将 schema 序列化为 JSON 文本追加到该 message 的 content 数组
|
||||
schemaJSON, err := json.Marshal(json.RawMessage(schema.Raw))
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
content := msgArr[lastUserIdx].Get("content")
|
||||
basePath := fmt.Sprintf("messages.%d.content", lastUserIdx)
|
||||
|
||||
if content.IsArray() {
|
||||
// 追加一个 text block 到 content 数组末尾
|
||||
idx := len(content.Array())
|
||||
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.type", basePath, idx), "text")
|
||||
body, _ = sjson.SetBytes(body, fmt.Sprintf("%s.%d.text", basePath, idx), string(schemaJSON))
|
||||
} else if content.Type == gjson.String {
|
||||
// content 是纯字符串,转换为数组格式
|
||||
originalText := content.String()
|
||||
body, _ = sjson.SetBytes(body, basePath, []map[string]string{
|
||||
{"type": "text", "text": originalText},
|
||||
{"type": "text", "text": string(schemaJSON)},
|
||||
})
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// removeCustomFieldFromTools 移除 tools 数组中每个工具定义的 custom 字段
|
||||
func removeCustomFieldFromTools(body []byte) []byte {
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if !tools.Exists() || !tools.IsArray() {
|
||||
return body
|
||||
}
|
||||
var err error
|
||||
for i := range tools.Array() {
|
||||
body, err = sjson.DeleteBytes(body, fmt.Sprintf("tools.%d.custom", i))
|
||||
if err != nil {
|
||||
// 删除失败不影响整体流程,跳过
|
||||
continue
|
||||
}
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// claudeVersionRe 匹配 Claude 模型 ID 中的版本号部分
|
||||
// 支持 claude-{tier}-{major}-{minor} 和 claude-{tier}-{major}.{minor} 格式
|
||||
var claudeVersionRe = regexp.MustCompile(`claude-(?:haiku|sonnet|opus)-(\d+)[-.](\d+)`)
|
||||
|
||||
// isBedrockClaude45OrNewer 判断 Bedrock 模型 ID 是否为 Claude 4.5 或更新版本
|
||||
// Claude 4.5+ 支持 cache_control 中的 ttl 字段("5m" 和 "1h")
|
||||
func isBedrockClaude45OrNewer(modelID string) bool {
|
||||
lower := strings.ToLower(modelID)
|
||||
matches := claudeVersionRe.FindStringSubmatch(lower)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
major, _ := strconv.Atoi(matches[1])
|
||||
minor, _ := strconv.Atoi(matches[2])
|
||||
return major > 4 || (major == 4 && minor >= 5)
|
||||
}
|
||||
|
||||
// sanitizeBedrockCacheControl 清理 system 和 messages 中 cache_control 里
|
||||
// Bedrock 不支持的字段:
|
||||
// - scope:Bedrock 不支持(如 "global" 跨请求缓存)
|
||||
// - ttl:仅 Claude 4.5+ 支持 "5m" 和 "1h",旧模型需要移除
|
||||
func sanitizeBedrockCacheControl(body []byte, modelID string) []byte {
|
||||
isClaude45 := isBedrockClaude45OrNewer(modelID)
|
||||
|
||||
// 清理 system 数组中的 cache_control
|
||||
systemArr := gjson.GetBytes(body, "system")
|
||||
if systemArr.Exists() && systemArr.IsArray() {
|
||||
for i, item := range systemArr.Array() {
|
||||
if !item.IsObject() {
|
||||
continue
|
||||
}
|
||||
cc := item.Get("cache_control")
|
||||
if !cc.Exists() || !cc.IsObject() {
|
||||
continue
|
||||
}
|
||||
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("system.%d.cache_control", i), cc, isClaude45)
|
||||
}
|
||||
}
|
||||
|
||||
// 清理 messages 中的 cache_control
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.Exists() || !messages.IsArray() {
|
||||
return body
|
||||
}
|
||||
for mi, msg := range messages.Array() {
|
||||
if !msg.IsObject() {
|
||||
continue
|
||||
}
|
||||
content := msg.Get("content")
|
||||
if !content.Exists() || !content.IsArray() {
|
||||
continue
|
||||
}
|
||||
for ci, block := range content.Array() {
|
||||
if !block.IsObject() {
|
||||
continue
|
||||
}
|
||||
cc := block.Get("cache_control")
|
||||
if !cc.Exists() || !cc.IsObject() {
|
||||
continue
|
||||
}
|
||||
body = deleteCacheControlUnsupportedFields(body, fmt.Sprintf("messages.%d.content.%d.cache_control", mi, ci), cc, isClaude45)
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// deleteCacheControlUnsupportedFields 删除给定 cache_control 路径下 Bedrock 不支持的字段
|
||||
func deleteCacheControlUnsupportedFields(body []byte, basePath string, cc gjson.Result, isClaude45 bool) []byte {
|
||||
// Bedrock 不支持 scope(如 "global")
|
||||
if cc.Get("scope").Exists() {
|
||||
body, _ = sjson.DeleteBytes(body, basePath+".scope")
|
||||
}
|
||||
|
||||
// ttl:仅 Claude 4.5+ 支持 "5m" 和 "1h",其余情况移除
|
||||
ttl := cc.Get("ttl")
|
||||
if ttl.Exists() {
|
||||
shouldRemove := true
|
||||
if isClaude45 {
|
||||
v := ttl.String()
|
||||
if v == "5m" || v == "1h" {
|
||||
shouldRemove = false
|
||||
}
|
||||
}
|
||||
if shouldRemove {
|
||||
body, _ = sjson.DeleteBytes(body, basePath+".ttl")
|
||||
}
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// parseAnthropicBetaHeader 解析 anthropic-beta 头的逗号分隔字符串为 token 列表
|
||||
func parseAnthropicBetaHeader(header string) []string {
|
||||
header = strings.TrimSpace(header)
|
||||
if header == "" {
|
||||
return nil
|
||||
}
|
||||
if strings.HasPrefix(header, "[") && strings.HasSuffix(header, "]") {
|
||||
var parsed []any
|
||||
if err := json.Unmarshal([]byte(header), &parsed); err == nil {
|
||||
tokens := make([]string, 0, len(parsed))
|
||||
for _, item := range parsed {
|
||||
token := strings.TrimSpace(fmt.Sprint(item))
|
||||
if token != "" {
|
||||
tokens = append(tokens, token)
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
}
|
||||
var tokens []string
|
||||
for _, part := range strings.Split(header, ",") {
|
||||
t := strings.TrimSpace(part)
|
||||
if t != "" {
|
||||
tokens = append(tokens, t)
|
||||
}
|
||||
}
|
||||
return tokens
|
||||
}
|
||||
|
||||
// bedrockSupportedBetaTokens 是 Bedrock Invoke 支持的 beta 头白名单
|
||||
// 参考: litellm/litellm/llms/bedrock/common_utils.py (anthropic_beta_headers_config.json)
|
||||
// 更新策略: 当 AWS Bedrock 新增支持的 beta token 时需同步更新此白名单
|
||||
var bedrockSupportedBetaTokens = map[string]bool{
|
||||
"computer-use-2025-01-24": true,
|
||||
"computer-use-2025-11-24": true,
|
||||
"context-1m-2025-08-07": true,
|
||||
"context-management-2025-06-27": true,
|
||||
"compact-2026-01-12": true,
|
||||
"interleaved-thinking-2025-05-14": true,
|
||||
"tool-search-tool-2025-10-19": true,
|
||||
"tool-examples-2025-10-29": true,
|
||||
}
|
||||
|
||||
// bedrockBetaTokenTransforms 定义 Bedrock Invoke 特有的 beta 头转换规则
|
||||
// Anthropic 直接 API 使用通用头,Bedrock Invoke 需要特定的替代头
|
||||
var bedrockBetaTokenTransforms = map[string]string{
|
||||
"advanced-tool-use-2025-11-20": "tool-search-tool-2025-10-19",
|
||||
}
|
||||
|
||||
// autoInjectBedrockBetaTokens 根据请求体内容自动补齐必要的 beta token
|
||||
// 参考 litellm: AnthropicModelInfo.get_anthropic_beta_list() 和
|
||||
// AmazonAnthropicClaudeMessagesConfig._get_tool_search_beta_header_for_bedrock()
|
||||
//
|
||||
// 客户端(特别是非 Claude Code 客户端)可能只在 body 中启用了功能而不在 header 中带对应 beta token,
|
||||
// 这里通过检测请求体特征自动补齐,确保 Bedrock Invoke 不会因缺少必要 beta 头而 400。
|
||||
func autoInjectBedrockBetaTokens(tokens []string, body []byte, modelID string) []string {
|
||||
seen := make(map[string]bool, len(tokens))
|
||||
for _, t := range tokens {
|
||||
seen[t] = true
|
||||
}
|
||||
|
||||
inject := func(token string) {
|
||||
if !seen[token] {
|
||||
tokens = append(tokens, token)
|
||||
seen[token] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 检测 thinking / interleaved thinking
|
||||
// 请求体中有 "thinking" 字段 → 需要 interleaved-thinking beta
|
||||
if gjson.GetBytes(body, "thinking").Exists() {
|
||||
inject("interleaved-thinking-2025-05-14")
|
||||
}
|
||||
|
||||
// 检测 computer_use 工具
|
||||
// tools 中有 type="computer_20xxxxxx" 的工具 → 需要 computer-use beta
|
||||
tools := gjson.GetBytes(body, "tools")
|
||||
if tools.Exists() && tools.IsArray() {
|
||||
toolSearchUsed := false
|
||||
programmaticToolCallingUsed := false
|
||||
inputExamplesUsed := false
|
||||
for _, tool := range tools.Array() {
|
||||
toolType := tool.Get("type").String()
|
||||
if strings.HasPrefix(toolType, "computer_20") {
|
||||
inject("computer-use-2025-11-24")
|
||||
}
|
||||
if isBedrockToolSearchType(toolType) {
|
||||
toolSearchUsed = true
|
||||
}
|
||||
if hasCodeExecutionAllowedCallers(tool) {
|
||||
programmaticToolCallingUsed = true
|
||||
}
|
||||
if hasInputExamples(tool) {
|
||||
inputExamplesUsed = true
|
||||
}
|
||||
}
|
||||
if programmaticToolCallingUsed || inputExamplesUsed {
|
||||
// programmatic tool calling 和 input examples 需要 advanced-tool-use,
|
||||
// 后续 filterBedrockBetaTokens 会将其转换为 Bedrock 特定的 tool-search-tool
|
||||
inject("advanced-tool-use-2025-11-20")
|
||||
}
|
||||
if toolSearchUsed && bedrockModelSupportsToolSearch(modelID) {
|
||||
// 纯 tool search(无 programmatic/inputExamples)时直接注入 Bedrock 特定头,
|
||||
// 跳过 advanced-tool-use → tool-search-tool 的转换步骤(与 litellm 对齐)
|
||||
if !programmaticToolCallingUsed && !inputExamplesUsed {
|
||||
inject("tool-search-tool-2025-10-19")
|
||||
} else {
|
||||
inject("advanced-tool-use-2025-11-20")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
func isBedrockToolSearchType(toolType string) bool {
|
||||
return toolType == "tool_search_tool_regex_20251119" || toolType == "tool_search_tool_bm25_20251119"
|
||||
}
|
||||
|
||||
func hasCodeExecutionAllowedCallers(tool gjson.Result) bool {
|
||||
allowedCallers := tool.Get("allowed_callers")
|
||||
if containsStringInJSONArray(allowedCallers, "code_execution_20250825") {
|
||||
return true
|
||||
}
|
||||
return containsStringInJSONArray(tool.Get("function.allowed_callers"), "code_execution_20250825")
|
||||
}
|
||||
|
||||
func hasInputExamples(tool gjson.Result) bool {
|
||||
if arr := tool.Get("input_examples"); arr.Exists() && arr.IsArray() && len(arr.Array()) > 0 {
|
||||
return true
|
||||
}
|
||||
arr := tool.Get("function.input_examples")
|
||||
return arr.Exists() && arr.IsArray() && len(arr.Array()) > 0
|
||||
}
|
||||
|
||||
func containsStringInJSONArray(result gjson.Result, target string) bool {
|
||||
if !result.Exists() || !result.IsArray() {
|
||||
return false
|
||||
}
|
||||
for _, item := range result.Array() {
|
||||
if item.String() == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// bedrockModelSupportsToolSearch 判断 Bedrock 模型是否支持 tool search
|
||||
// 目前仅 Claude Opus/Sonnet 4.5+ 支持,Haiku 不支持
|
||||
func bedrockModelSupportsToolSearch(modelID string) bool {
|
||||
lower := strings.ToLower(modelID)
|
||||
matches := claudeVersionRe.FindStringSubmatch(lower)
|
||||
if matches == nil {
|
||||
return false
|
||||
}
|
||||
// Haiku 不支持 tool search
|
||||
if strings.Contains(lower, "haiku") {
|
||||
return false
|
||||
}
|
||||
major, _ := strconv.Atoi(matches[1])
|
||||
minor, _ := strconv.Atoi(matches[2])
|
||||
return major > 4 || (major == 4 && minor >= 5)
|
||||
}
|
||||
|
||||
// filterBedrockBetaTokens 过滤并转换 beta token 列表,仅保留 Bedrock Invoke 支持的 token
|
||||
// 1. 应用转换规则(如 advanced-tool-use → tool-search-tool)
|
||||
// 2. 过滤掉 Bedrock 不支持的 token(如 output-128k, files-api, structured-outputs 等)
|
||||
// 3. 自动关联 tool-examples(当 tool-search-tool 存在时)
|
||||
func filterBedrockBetaTokens(tokens []string) []string {
|
||||
seen := make(map[string]bool, len(tokens))
|
||||
var result []string
|
||||
|
||||
for _, t := range tokens {
|
||||
// 应用转换规则
|
||||
if replacement, ok := bedrockBetaTokenTransforms[t]; ok {
|
||||
t = replacement
|
||||
}
|
||||
// 只保留白名单中的 token,且去重
|
||||
if bedrockSupportedBetaTokens[t] && !seen[t] {
|
||||
result = append(result, t)
|
||||
seen[t] = true
|
||||
}
|
||||
}
|
||||
|
||||
// 自动关联: tool-search-tool 存在时,确保 tool-examples 也存在
|
||||
if seen["tool-search-tool-2025-10-19"] && !seen["tool-examples-2025-10-29"] {
|
||||
result = append(result, "tool-examples-2025-10-29")
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,659 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestPrepareBedrockRequestBody_BasicFields(t *testing.T) {
|
||||
input := `{"model":"claude-opus-4-6","stream":true,"max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
// anthropic_version 应被注入
|
||||
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
|
||||
// model 和 stream 应被移除
|
||||
assert.False(t, gjson.GetBytes(result, "model").Exists())
|
||||
assert.False(t, gjson.GetBytes(result, "stream").Exists())
|
||||
// max_tokens 应保留
|
||||
assert.Equal(t, int64(1024), gjson.GetBytes(result, "max_tokens").Int())
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_OutputFormatInlineSchema(t *testing.T) {
|
||||
t.Run("schema inlined into last user message array content", func(t *testing.T) {
|
||||
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}},"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||
// schema 应内联到最后一条 user message 的 content 数组末尾
|
||||
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
|
||||
require.Len(t, contentArr, 2)
|
||||
assert.Equal(t, "text", contentArr[1].Get("type").String())
|
||||
assert.Contains(t, contentArr[1].Get("text").String(), `"name":"string"`)
|
||||
})
|
||||
|
||||
t.Run("schema inlined into string content", func(t *testing.T) {
|
||||
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"result":"number"}},"messages":[{"role":"user","content":"compute this"}]}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
|
||||
require.Len(t, contentArr, 2)
|
||||
assert.Equal(t, "compute this", contentArr[0].Get("text").String())
|
||||
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"number"`)
|
||||
})
|
||||
|
||||
t.Run("no schema field just removes output_format", func(t *testing.T) {
|
||||
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json"},"messages":[{"role":"user","content":"hi"}]}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||
})
|
||||
|
||||
t.Run("no messages just removes output_format", func(t *testing.T) {
|
||||
input := `{"model":"claude-sonnet-4-5","output_format":{"type":"json","schema":{"name":"string"}}}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_RemoveOutputConfig(t *testing.T) {
|
||||
input := `{"model":"claude-sonnet-4-5","output_config":{"max_tokens":100},"messages":[]}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-sonnet-4-5-v1", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
|
||||
}
|
||||
|
||||
func TestRemoveCustomFieldFromTools(t *testing.T) {
|
||||
input := `{
|
||||
"tools": [
|
||||
{"name":"tool1","custom":{"defer_loading":true},"description":"desc1"},
|
||||
{"name":"tool2","description":"desc2"},
|
||||
{"name":"tool3","custom":{"defer_loading":true,"other":123},"description":"desc3"}
|
||||
]
|
||||
}`
|
||||
result := removeCustomFieldFromTools([]byte(input))
|
||||
|
||||
tools := gjson.GetBytes(result, "tools").Array()
|
||||
require.Len(t, tools, 3)
|
||||
// custom 应被移除
|
||||
assert.False(t, tools[0].Get("custom").Exists())
|
||||
// name/description 应保留
|
||||
assert.Equal(t, "tool1", tools[0].Get("name").String())
|
||||
assert.Equal(t, "desc1", tools[0].Get("description").String())
|
||||
// 没有 custom 的工具不受影响
|
||||
assert.Equal(t, "tool2", tools[1].Get("name").String())
|
||||
// 第三个工具的 custom 也应被移除
|
||||
assert.False(t, tools[2].Get("custom").Exists())
|
||||
assert.Equal(t, "tool3", tools[2].Get("name").String())
|
||||
}
|
||||
|
||||
func TestRemoveCustomFieldFromTools_NoTools(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}]}`
|
||||
result := removeCustomFieldFromTools([]byte(input))
|
||||
// 无 tools 时不改变原始数据
|
||||
assert.JSONEq(t, input, string(result))
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCacheControl_RemoveScope(t *testing.T) {
|
||||
input := `{
|
||||
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","scope":"global"}}],
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","scope":"global"}}]}]
|
||||
}`
|
||||
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
|
||||
|
||||
// scope 应被移除
|
||||
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
|
||||
assert.False(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.scope").Exists())
|
||||
// type 应保留
|
||||
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
|
||||
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "messages.0.content.0.cache_control.type").String())
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCacheControl_TTL_OldModel(t *testing.T) {
|
||||
input := `{
|
||||
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
|
||||
}`
|
||||
// 旧模型(Claude 3.5)不支持 ttl
|
||||
result := sanitizeBedrockCacheControl([]byte(input), "anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
|
||||
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCacheControl_TTL_Claude45_Supported(t *testing.T) {
|
||||
input := `{
|
||||
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"5m"}}]
|
||||
}`
|
||||
// Claude 4.5+ 支持 "5m" 和 "1h"
|
||||
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
|
||||
|
||||
assert.True(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
|
||||
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCacheControl_TTL_Claude45_UnsupportedValue(t *testing.T) {
|
||||
input := `{
|
||||
"system": [{"type":"text","text":"sys","cache_control":{"type":"ephemeral","ttl":"10m"}}]
|
||||
}`
|
||||
// Claude 4.5 不支持 "10m"
|
||||
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-sonnet-4-5-20250929-v1:0")
|
||||
|
||||
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.ttl").Exists())
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCacheControl_TTL_Claude46(t *testing.T) {
|
||||
input := `{
|
||||
"messages": [{"role":"user","content":[{"type":"text","text":"hi","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
|
||||
}`
|
||||
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
|
||||
|
||||
assert.True(t, gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").Exists())
|
||||
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
|
||||
}
|
||||
|
||||
func TestSanitizeBedrockCacheControl_NoCacheControl(t *testing.T) {
|
||||
input := `{"system":[{"type":"text","text":"sys"}],"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`
|
||||
result := sanitizeBedrockCacheControl([]byte(input), "us.anthropic.claude-opus-4-6-v1")
|
||||
// 无 cache_control 时不改变原始数据
|
||||
assert.JSONEq(t, input, string(result))
|
||||
}
|
||||
|
||||
func TestIsBedrockClaude45OrNewer(t *testing.T) {
|
||||
tests := []struct {
|
||||
modelID string
|
||||
expect bool
|
||||
}{
|
||||
{"us.anthropic.claude-opus-4-6-v1", true},
|
||||
{"us.anthropic.claude-sonnet-4-6", true},
|
||||
{"us.anthropic.claude-sonnet-4-5-20250929-v1:0", true},
|
||||
{"us.anthropic.claude-opus-4-5-20251101-v1:0", true},
|
||||
{"us.anthropic.claude-haiku-4-5-20251001-v1:0", true},
|
||||
{"anthropic.claude-3-5-sonnet-20241022-v2:0", false},
|
||||
{"anthropic.claude-3-opus-20240229-v1:0", false},
|
||||
{"anthropic.claude-3-haiku-20240307-v1:0", false},
|
||||
// 未来版本应自动支持
|
||||
{"us.anthropic.claude-sonnet-5-0-v1", true},
|
||||
{"us.anthropic.claude-opus-4-7-v1", true},
|
||||
// 旧版本
|
||||
{"anthropic.claude-opus-4-1-v1", false},
|
||||
{"anthropic.claude-sonnet-4-0-v1", false},
|
||||
// 非 Claude 模型
|
||||
{"amazon.nova-pro-v1", false},
|
||||
{"meta.llama3-70b", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.modelID, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expect, isBedrockClaude45OrNewer(tt.modelID))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_FullIntegration(t *testing.T) {
|
||||
// 模拟一个完整的 Claude Code 请求
|
||||
input := `{
|
||||
"model": "claude-opus-4-6",
|
||||
"stream": true,
|
||||
"max_tokens": 16384,
|
||||
"output_format": {"type": "json", "schema": {"result": "string"}},
|
||||
"output_config": {"max_tokens": 100},
|
||||
"system": [{"type": "text", "text": "You are helpful", "cache_control": {"type": "ephemeral", "scope": "global", "ttl": "5m"}}],
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "hello", "cache_control": {"type": "ephemeral", "ttl": "1h"}}]}
|
||||
],
|
||||
"tools": [
|
||||
{"name": "bash", "description": "Run bash", "custom": {"defer_loading": true}, "input_schema": {"type": "object"}},
|
||||
{"name": "read", "description": "Read file", "input_schema": {"type": "object"}}
|
||||
]
|
||||
}`
|
||||
|
||||
betaHeader := "interleaved-thinking-2025-05-14, context-1m-2025-08-07, compact-2026-01-12"
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", betaHeader)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 基本字段
|
||||
assert.Equal(t, "bedrock-2023-05-31", gjson.GetBytes(result, "anthropic_version").String())
|
||||
assert.False(t, gjson.GetBytes(result, "model").Exists())
|
||||
assert.False(t, gjson.GetBytes(result, "stream").Exists())
|
||||
assert.Equal(t, int64(16384), gjson.GetBytes(result, "max_tokens").Int())
|
||||
|
||||
// anthropic_beta 应包含所有 beta tokens
|
||||
betaArr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, betaArr, 3)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", betaArr[0].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", betaArr[1].String())
|
||||
assert.Equal(t, "compact-2026-01-12", betaArr[2].String())
|
||||
|
||||
// output_format 应被移除,schema 内联到最后一条 user message
|
||||
assert.False(t, gjson.GetBytes(result, "output_format").Exists())
|
||||
assert.False(t, gjson.GetBytes(result, "output_config").Exists())
|
||||
// content 数组:原始 text block + 内联 schema block
|
||||
contentArr := gjson.GetBytes(result, "messages.0.content").Array()
|
||||
require.Len(t, contentArr, 2)
|
||||
assert.Equal(t, "hello", contentArr[0].Get("text").String())
|
||||
assert.Contains(t, contentArr[1].Get("text").String(), `"result":"string"`)
|
||||
|
||||
// tools 中的 custom 应被移除
|
||||
assert.False(t, gjson.GetBytes(result, "tools.0.custom").Exists())
|
||||
assert.Equal(t, "bash", gjson.GetBytes(result, "tools.0.name").String())
|
||||
assert.Equal(t, "read", gjson.GetBytes(result, "tools.1.name").String())
|
||||
|
||||
// cache_control: scope 应被移除,ttl 在 Claude 4.6 上保留合法值
|
||||
assert.False(t, gjson.GetBytes(result, "system.0.cache_control.scope").Exists())
|
||||
assert.Equal(t, "ephemeral", gjson.GetBytes(result, "system.0.cache_control.type").String())
|
||||
assert.Equal(t, "5m", gjson.GetBytes(result, "system.0.cache_control.ttl").String())
|
||||
assert.Equal(t, "1h", gjson.GetBytes(result, "messages.0.content.0.cache_control.ttl").String())
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_BetaHeader(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
|
||||
|
||||
t.Run("empty beta header", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
|
||||
require.NoError(t, err)
|
||||
assert.False(t, gjson.GetBytes(result, "anthropic_beta").Exists())
|
||||
})
|
||||
|
||||
t.Run("single beta token", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 1)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
})
|
||||
|
||||
t.Run("multiple beta tokens with spaces", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "interleaved-thinking-2025-05-14 , context-1m-2025-08-07 ")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 2)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
|
||||
})
|
||||
|
||||
t.Run("json array beta header", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", `["interleaved-thinking-2025-05-14","context-1m-2025-08-07"]`)
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 2)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
assert.Equal(t, "context-1m-2025-08-07", arr[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseAnthropicBetaHeader(t *testing.T) {
|
||||
assert.Nil(t, parseAnthropicBetaHeader(""))
|
||||
assert.Equal(t, []string{"a"}, parseAnthropicBetaHeader("a"))
|
||||
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a,b"))
|
||||
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader("a , b "))
|
||||
assert.Equal(t, []string{"a", "b", "c"}, parseAnthropicBetaHeader("a,b,c"))
|
||||
assert.Equal(t, []string{"a", "b"}, parseAnthropicBetaHeader(`["a","b"]`))
|
||||
}
|
||||
|
||||
func TestFilterBedrockBetaTokens(t *testing.T) {
|
||||
t.Run("supported tokens pass through", func(t *testing.T) {
|
||||
tokens := []string{"interleaved-thinking-2025-05-14", "context-1m-2025-08-07", "compact-2026-01-12"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Equal(t, tokens, result)
|
||||
})
|
||||
|
||||
t.Run("unsupported tokens are filtered out", func(t *testing.T) {
|
||||
tokens := []string{"interleaved-thinking-2025-05-14", "output-128k-2025-02-19", "files-api-2025-04-14", "structured-outputs-2025-11-13"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
|
||||
})
|
||||
|
||||
t.Run("advanced-tool-use transforms to tool-search-tool", func(t *testing.T) {
|
||||
tokens := []string{"advanced-tool-use-2025-11-20"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||
// tool-examples 自动关联
|
||||
assert.Contains(t, result, "tool-examples-2025-10-29")
|
||||
})
|
||||
|
||||
t.Run("tool-search-tool auto-associates tool-examples", func(t *testing.T) {
|
||||
tokens := []string{"tool-search-tool-2025-10-19"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||
assert.Contains(t, result, "tool-examples-2025-10-29")
|
||||
})
|
||||
|
||||
t.Run("no duplication when tool-examples already present", func(t *testing.T) {
|
||||
tokens := []string{"tool-search-tool-2025-10-19", "tool-examples-2025-10-29"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
count := 0
|
||||
for _, t := range result {
|
||||
if t == "tool-examples-2025-10-29" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
|
||||
t.Run("empty input returns nil", func(t *testing.T) {
|
||||
result := filterBedrockBetaTokens(nil)
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("all unsupported returns nil", func(t *testing.T) {
|
||||
result := filterBedrockBetaTokens([]string{"output-128k-2025-02-19", "effort-2025-11-24"})
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("duplicate tokens are deduplicated", func(t *testing.T) {
|
||||
tokens := []string{"context-1m-2025-08-07", "context-1m-2025-08-07"}
|
||||
result := filterBedrockBetaTokens(tokens)
|
||||
assert.Equal(t, []string{"context-1m-2025-08-07"}, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_BetaFiltering(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`
|
||||
|
||||
t.Run("unsupported beta tokens are filtered", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
|
||||
"interleaved-thinking-2025-05-14, output-128k-2025-02-19, files-api-2025-04-14")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 1)
|
||||
assert.Equal(t, "interleaved-thinking-2025-05-14", arr[0].String())
|
||||
})
|
||||
|
||||
t.Run("advanced-tool-use transformed in full pipeline", func(t *testing.T) {
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1",
|
||||
"advanced-tool-use-2025-11-20")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
require.Len(t, arr, 2)
|
||||
assert.Equal(t, "tool-search-tool-2025-10-19", arr[0].String())
|
||||
assert.Equal(t, "tool-examples-2025-10-29", arr[1].String())
|
||||
})
|
||||
}
|
||||
|
||||
func TestBedrockCrossRegionPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
region string
|
||||
expect string
|
||||
}{
|
||||
// US regions
|
||||
{"us-east-1", "us"},
|
||||
{"us-east-2", "us"},
|
||||
{"us-west-1", "us"},
|
||||
{"us-west-2", "us"},
|
||||
// GovCloud
|
||||
{"us-gov-east-1", "us-gov"},
|
||||
{"us-gov-west-1", "us-gov"},
|
||||
// EU regions
|
||||
{"eu-west-1", "eu"},
|
||||
{"eu-west-2", "eu"},
|
||||
{"eu-west-3", "eu"},
|
||||
{"eu-central-1", "eu"},
|
||||
{"eu-central-2", "eu"},
|
||||
{"eu-north-1", "eu"},
|
||||
{"eu-south-1", "eu"},
|
||||
// APAC regions
|
||||
{"ap-northeast-1", "jp"},
|
||||
{"ap-northeast-2", "apac"},
|
||||
{"ap-southeast-1", "apac"},
|
||||
{"ap-southeast-2", "au"},
|
||||
{"ap-south-1", "apac"},
|
||||
// Canada / South America fallback to us
|
||||
{"ca-central-1", "us"},
|
||||
{"sa-east-1", "us"},
|
||||
// Unknown defaults to us
|
||||
{"me-south-1", "us"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.region, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expect, BedrockCrossRegionPrefix(tt.region))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveBedrockModelID(t *testing.T) {
|
||||
t.Run("default alias resolves and adjusts region", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "eu-west-1",
|
||||
},
|
||||
}
|
||||
|
||||
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-5")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", modelID)
|
||||
})
|
||||
|
||||
t.Run("custom alias mapping reuses default bedrock mapping", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "ap-southeast-2",
|
||||
"model_mapping": map[string]any{
|
||||
"claude-*": "claude-opus-4-6",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modelID, ok := ResolveBedrockModelID(account, "claude-opus-4-6-thinking")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "au.anthropic.claude-opus-4-6-v1", modelID)
|
||||
})
|
||||
|
||||
t.Run("force global rewrites anthropic regional model id", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "us-east-1",
|
||||
"aws_force_global": "true",
|
||||
"model_mapping": map[string]any{
|
||||
"claude-sonnet-4-6": "us.anthropic.claude-sonnet-4-6",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
modelID, ok := ResolveBedrockModelID(account, "claude-sonnet-4-6")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "global.anthropic.claude-sonnet-4-6", modelID)
|
||||
})
|
||||
|
||||
t.Run("direct bedrock model id passes through", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "us-east-1",
|
||||
},
|
||||
}
|
||||
|
||||
modelID, ok := ResolveBedrockModelID(account, "anthropic.claude-haiku-4-5-20251001-v1:0")
|
||||
require.True(t, ok)
|
||||
assert.Equal(t, "anthropic.claude-haiku-4-5-20251001-v1:0", modelID)
|
||||
})
|
||||
|
||||
t.Run("unsupported alias returns false", func(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_region": "us-east-1",
|
||||
},
|
||||
}
|
||||
|
||||
_, ok := ResolveBedrockModelID(account, "claude-3-5-sonnet-20241022")
|
||||
assert.False(t, ok)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAutoInjectBedrockBetaTokens(t *testing.T) {
|
||||
t.Run("inject interleaved-thinking when thinking present", func(t *testing.T) {
|
||||
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
|
||||
})
|
||||
|
||||
t.Run("no duplicate when already present", func(t *testing.T) {
|
||||
body := []byte(`{"thinking":{"type":"enabled","budget_tokens":10000},"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens([]string{"interleaved-thinking-2025-05-14"}, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
count := 0
|
||||
for _, t := range result {
|
||||
if t == "interleaved-thinking-2025-05-14" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
assert.Equal(t, 1, count)
|
||||
})
|
||||
|
||||
t.Run("inject computer-use when computer tool present", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"type":"computer_20250124","name":"computer","display_width_px":1024}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "computer-use-2025-11-24")
|
||||
})
|
||||
|
||||
t.Run("inject advanced-tool-use for programmatic tool calling", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||
})
|
||||
|
||||
t.Run("inject advanced-tool-use for input examples", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"bash","input_examples":[{"cmd":"ls"}]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||
})
|
||||
|
||||
t.Run("inject tool-search-tool directly for pure tool search (no programmatic/inputExamples)", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
|
||||
// 纯 tool search 场景直接注入 Bedrock 特定头,不走 advanced-tool-use 转换
|
||||
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
|
||||
})
|
||||
|
||||
t.Run("inject advanced-tool-use when tool search combined with programmatic calling", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"},{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-sonnet-4-6")
|
||||
// 混合场景使用 advanced-tool-use(后续由 filter 转换为 tool-search-tool)
|
||||
assert.Contains(t, result, "advanced-tool-use-2025-11-20")
|
||||
})
|
||||
|
||||
t.Run("do not inject tool-search beta for unsupported models", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"type":"tool_search_tool_regex_20251119","name":"search"}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
assert.NotContains(t, result, "advanced-tool-use-2025-11-20")
|
||||
assert.NotContains(t, result, "tool-search-tool-2025-10-19")
|
||||
})
|
||||
|
||||
t.Run("no injection for regular tools", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"bash","description":"run bash","input_schema":{"type":"object"}}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("no injection when no features detected", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hi"}],"max_tokens":100}`)
|
||||
result := autoInjectBedrockBetaTokens(nil, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Empty(t, result)
|
||||
})
|
||||
|
||||
t.Run("preserves existing tokens", func(t *testing.T) {
|
||||
body := []byte(`{"thinking":{"type":"enabled"},"messages":[{"role":"user","content":"hi"}]}`)
|
||||
existing := []string{"context-1m-2025-08-07", "compact-2026-01-12"}
|
||||
result := autoInjectBedrockBetaTokens(existing, body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "context-1m-2025-08-07")
|
||||
assert.Contains(t, result, "compact-2026-01-12")
|
||||
assert.Contains(t, result, "interleaved-thinking-2025-05-14")
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveBedrockBetaTokens(t *testing.T) {
|
||||
t.Run("body-only tool features resolve to final bedrock tokens", func(t *testing.T) {
|
||||
body := []byte(`{"tools":[{"name":"bash","allowed_callers":["code_execution_20250825"]}],"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := ResolveBedrockBetaTokens("", body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Contains(t, result, "tool-search-tool-2025-10-19")
|
||||
assert.Contains(t, result, "tool-examples-2025-10-29")
|
||||
})
|
||||
|
||||
t.Run("unsupported client beta tokens are filtered out", func(t *testing.T) {
|
||||
body := []byte(`{"messages":[{"role":"user","content":"hi"}]}`)
|
||||
result := ResolveBedrockBetaTokens("interleaved-thinking-2025-05-14,files-api-2025-04-14", body, "us.anthropic.claude-opus-4-6-v1")
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestPrepareBedrockRequestBody_AutoBetaInjection(t *testing.T) {
|
||||
t.Run("thinking in body auto-injects beta without header", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
found := false
|
||||
for _, v := range arr {
|
||||
if v.String() == "interleaved-thinking-2025-05-14" {
|
||||
found = true
|
||||
}
|
||||
}
|
||||
assert.True(t, found, "interleaved-thinking should be auto-injected")
|
||||
})
|
||||
|
||||
t.Run("header tokens merged with auto-injected tokens", func(t *testing.T) {
|
||||
input := `{"messages":[{"role":"user","content":"hi"}],"max_tokens":100,"thinking":{"type":"enabled","budget_tokens":10000}}`
|
||||
result, err := PrepareBedrockRequestBody([]byte(input), "us.anthropic.claude-opus-4-6-v1", "context-1m-2025-08-07")
|
||||
require.NoError(t, err)
|
||||
arr := gjson.GetBytes(result, "anthropic_beta").Array()
|
||||
names := make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
names[i] = v.String()
|
||||
}
|
||||
assert.Contains(t, names, "context-1m-2025-08-07")
|
||||
assert.Contains(t, names, "interleaved-thinking-2025-05-14")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAdjustBedrockModelRegionPrefix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
modelID string
|
||||
region string
|
||||
expect string
|
||||
}{
|
||||
// US region — no change needed
|
||||
{"us region keeps us prefix", "us.anthropic.claude-opus-4-6-v1", "us-east-1", "us.anthropic.claude-opus-4-6-v1"},
|
||||
// EU region — replace us → eu
|
||||
{"eu region replaces prefix", "us.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
|
||||
{"eu region sonnet", "us.anthropic.claude-sonnet-4-6", "eu-central-1", "eu.anthropic.claude-sonnet-4-6"},
|
||||
// APAC region — jp and au have dedicated prefixes per AWS docs
|
||||
{"jp region (ap-northeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-northeast-1", "jp.anthropic.claude-sonnet-4-5-20250929-v1:0"},
|
||||
{"au region (ap-southeast-2)", "us.anthropic.claude-haiku-4-5-20251001-v1:0", "ap-southeast-2", "au.anthropic.claude-haiku-4-5-20251001-v1:0"},
|
||||
{"apac region (ap-southeast-1)", "us.anthropic.claude-sonnet-4-5-20250929-v1:0", "ap-southeast-1", "apac.anthropic.claude-sonnet-4-5-20250929-v1:0"},
|
||||
// eu → us (user manually set eu prefix, moved to us region)
|
||||
{"eu to us", "eu.anthropic.claude-opus-4-6-v1", "us-west-2", "us.anthropic.claude-opus-4-6-v1"},
|
||||
// global prefix — replace to match region
|
||||
{"global to eu", "global.anthropic.claude-opus-4-6-v1", "eu-west-1", "eu.anthropic.claude-opus-4-6-v1"},
|
||||
// No known prefix — leave unchanged
|
||||
{"no prefix unchanged", "anthropic.claude-3-5-sonnet-20241022-v2:0", "eu-west-1", "anthropic.claude-3-5-sonnet-20241022-v2:0"},
|
||||
// GovCloud — uses independent us-gov prefix
|
||||
{"govcloud from us", "us.anthropic.claude-opus-4-6-v1", "us-gov-east-1", "us-gov.anthropic.claude-opus-4-6-v1"},
|
||||
{"govcloud already correct", "us-gov.anthropic.claude-opus-4-6-v1", "us-gov-west-1", "us-gov.anthropic.claude-opus-4-6-v1"},
|
||||
// Force global (special region value)
|
||||
{"force global from us", "us.anthropic.claude-opus-4-6-v1", "global", "global.anthropic.claude-opus-4-6-v1"},
|
||||
{"force global from eu", "eu.anthropic.claude-sonnet-4-6", "global", "global.anthropic.claude-sonnet-4-6"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.expect, AdjustBedrockModelRegionPrefix(tt.modelID, tt.region))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4"
|
||||
)
|
||||
|
||||
// BedrockSigner 使用 AWS SigV4 对 Bedrock 请求签名
|
||||
type BedrockSigner struct {
|
||||
credentials aws.Credentials
|
||||
region string
|
||||
signer *v4.Signer
|
||||
}
|
||||
|
||||
// NewBedrockSigner 创建 BedrockSigner
|
||||
func NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region string) *BedrockSigner {
|
||||
return &BedrockSigner{
|
||||
credentials: aws.Credentials{
|
||||
AccessKeyID: accessKeyID,
|
||||
SecretAccessKey: secretAccessKey,
|
||||
SessionToken: sessionToken,
|
||||
},
|
||||
region: region,
|
||||
signer: v4.NewSigner(),
|
||||
}
|
||||
}
|
||||
|
||||
// NewBedrockSignerFromAccount 从 Account 凭证创建 BedrockSigner
|
||||
func NewBedrockSignerFromAccount(account *Account) (*BedrockSigner, error) {
|
||||
accessKeyID := account.GetCredential("aws_access_key_id")
|
||||
if accessKeyID == "" {
|
||||
return nil, fmt.Errorf("aws_access_key_id not found in credentials")
|
||||
}
|
||||
secretAccessKey := account.GetCredential("aws_secret_access_key")
|
||||
if secretAccessKey == "" {
|
||||
return nil, fmt.Errorf("aws_secret_access_key not found in credentials")
|
||||
}
|
||||
region := account.GetCredential("aws_region")
|
||||
if region == "" {
|
||||
region = defaultBedrockRegion
|
||||
}
|
||||
sessionToken := account.GetCredential("aws_session_token") // 可选
|
||||
|
||||
return NewBedrockSigner(accessKeyID, secretAccessKey, sessionToken, region), nil
|
||||
}
|
||||
|
||||
// SignRequest 对 HTTP 请求进行 SigV4 签名
|
||||
// 重要约束:调用此方法前,req 应只包含 AWS 相关的 header(如 Content-Type、Accept)。
|
||||
// 非 AWS header(如 anthropic-beta)会参与签名计算,如果 Bedrock 服务端不识别这些 header,
|
||||
// 签名验证可能失败。litellm 通过 _filter_headers_for_aws_signature 实现头过滤,
|
||||
// 当前实现中 buildUpstreamRequestBedrock 仅设置了 Content-Type 和 Accept,因此是安全的。
|
||||
func (s *BedrockSigner) SignRequest(ctx context.Context, req *http.Request, body []byte) error {
|
||||
payloadHash := sha256Hash(body)
|
||||
return s.signer.SignHTTP(ctx, s.credentials, req, payloadHash, "bedrock", s.region, time.Now())
|
||||
}
|
||||
|
||||
func sha256Hash(data []byte) string {
|
||||
h := sha256.Sum256(data)
|
||||
return hex.EncodeToString(h[:])
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewBedrockSignerFromAccount_DefaultRegion(t *testing.T) {
|
||||
account := &Account{
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeBedrock,
|
||||
Credentials: map[string]any{
|
||||
"aws_access_key_id": "test-akid",
|
||||
"aws_secret_access_key": "test-secret",
|
||||
},
|
||||
}
|
||||
|
||||
signer, err := NewBedrockSignerFromAccount(account)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, signer)
|
||||
assert.Equal(t, defaultBedrockRegion, signer.region)
|
||||
}
|
||||
|
||||
func TestFilterBetaTokens(t *testing.T) {
|
||||
tokens := []string{"interleaved-thinking-2025-05-14", "tool-search-tool-2025-10-19"}
|
||||
filterSet := map[string]struct{}{
|
||||
"tool-search-tool-2025-10-19": {},
|
||||
}
|
||||
|
||||
assert.Equal(t, []string{"interleaved-thinking-2025-05-14"}, filterBetaTokens(tokens, filterSet))
|
||||
assert.Equal(t, tokens, filterBetaTokens(tokens, nil))
|
||||
assert.Nil(t, filterBetaTokens(nil, filterSet))
|
||||
}
|
||||
@@ -0,0 +1,414 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// handleBedrockStreamingResponse 处理 Bedrock InvokeModelWithResponseStream 的 EventStream 响应
|
||||
// Bedrock 返回 AWS EventStream 二进制格式,每个事件的 payload 中 chunk.bytes 是 base64 编码的
|
||||
// Claude SSE 事件 JSON。本方法解码后转换为标准 SSE 格式写入客户端。
|
||||
func (s *GatewayService) handleBedrockStreamingResponse(
|
||||
ctx context.Context,
|
||||
resp *http.Response,
|
||||
c *gin.Context,
|
||||
account *Account,
|
||||
startTime time.Time,
|
||||
model string,
|
||||
) (*streamingResult, error) {
|
||||
w := c.Writer
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
return nil, errors.New("streaming not supported")
|
||||
}
|
||||
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("X-Accel-Buffering", "no")
|
||||
if v := resp.Header.Get("x-amzn-requestid"); v != "" {
|
||||
c.Header("x-request-id", v)
|
||||
}
|
||||
|
||||
usage := &ClaudeUsage{}
|
||||
var firstTokenMs *int
|
||||
clientDisconnected := false
|
||||
|
||||
// Bedrock EventStream 使用 application/vnd.amazon.eventstream 二进制格式。
|
||||
// 每个帧结构:total_length(4) + headers_length(4) + prelude_crc(4) + headers + payload + message_crc(4)
|
||||
// 但更实用的方式是使用行扫描找 JSON chunks,因为 Bedrock 的响应在二进制帧中。
|
||||
// 我们使用 EventStream decoder 来正确解析。
|
||||
decoder := newBedrockEventStreamDecoder(resp.Body)
|
||||
|
||||
type decodeEvent struct {
|
||||
payload []byte
|
||||
err error
|
||||
}
|
||||
events := make(chan decodeEvent, 16)
|
||||
done := make(chan struct{})
|
||||
sendEvent := func(ev decodeEvent) bool {
|
||||
select {
|
||||
case events <- ev:
|
||||
return true
|
||||
case <-done:
|
||||
return false
|
||||
}
|
||||
}
|
||||
var lastReadAt atomic.Int64
|
||||
lastReadAt.Store(time.Now().UnixNano())
|
||||
|
||||
go func() {
|
||||
defer close(events)
|
||||
for {
|
||||
payload, err := decoder.Decode()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
_ = sendEvent(decodeEvent{err: err})
|
||||
return
|
||||
}
|
||||
lastReadAt.Store(time.Now().UnixNano())
|
||||
if !sendEvent(decodeEvent{payload: payload}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
defer close(done)
|
||||
|
||||
streamInterval := time.Duration(0)
|
||||
if s.cfg != nil && s.cfg.Gateway.StreamDataIntervalTimeout > 0 {
|
||||
streamInterval = time.Duration(s.cfg.Gateway.StreamDataIntervalTimeout) * time.Second
|
||||
}
|
||||
var intervalTicker *time.Ticker
|
||||
if streamInterval > 0 {
|
||||
intervalTicker = time.NewTicker(streamInterval)
|
||||
defer intervalTicker.Stop()
|
||||
}
|
||||
var intervalCh <-chan time.Time
|
||||
if intervalTicker != nil {
|
||||
intervalCh = intervalTicker.C
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case ev, ok := <-events:
|
||||
if !ok {
|
||||
if !clientDisconnected {
|
||||
flusher.Flush()
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: clientDisconnected}, nil
|
||||
}
|
||||
if ev.err != nil {
|
||||
if clientDisconnected {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
if errors.Is(ev.err, context.Canceled) || errors.Is(ev.err, context.DeadlineExceeded) {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("bedrock stream read error: %w", ev.err)
|
||||
}
|
||||
|
||||
// payload 是 JSON,提取 chunk.bytes(base64 编码的 Claude SSE 事件数据)
|
||||
sseData := extractBedrockChunkData(ev.payload)
|
||||
if sseData == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if firstTokenMs == nil {
|
||||
ms := int(time.Since(startTime).Milliseconds())
|
||||
firstTokenMs = &ms
|
||||
}
|
||||
|
||||
// 转换 Bedrock 特有的 amazon-bedrock-invocationMetrics 为标准 Anthropic usage 格式
|
||||
// 同时移除该字段避免透传给客户端
|
||||
sseData = transformBedrockInvocationMetrics(sseData)
|
||||
|
||||
// 解析 SSE 事件数据提取 usage
|
||||
s.parseSSEUsagePassthrough(string(sseData), usage)
|
||||
|
||||
// 确定 SSE event type
|
||||
eventType := gjson.GetBytes(sseData, "type").String()
|
||||
|
||||
// 写入标准 SSE 格式
|
||||
if !clientDisconnected {
|
||||
var writeErr error
|
||||
if eventType != "" {
|
||||
_, writeErr = fmt.Fprintf(w, "event: %s\ndata: %s\n\n", eventType, sseData)
|
||||
} else {
|
||||
_, writeErr = fmt.Fprintf(w, "data: %s\n\n", sseData)
|
||||
}
|
||||
if writeErr != nil {
|
||||
clientDisconnected = true
|
||||
logger.LegacyPrintf("service.gateway", "[Bedrock] Client disconnected during streaming, continue draining for usage: account=%d", account.ID)
|
||||
} else {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
case <-intervalCh:
|
||||
lastRead := time.Unix(0, lastReadAt.Load())
|
||||
if time.Since(lastRead) < streamInterval {
|
||||
continue
|
||||
}
|
||||
if clientDisconnected {
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs, clientDisconnect: true}, nil
|
||||
}
|
||||
logger.LegacyPrintf("service.gateway", "[Bedrock] Stream data interval timeout: account=%d model=%s interval=%s", account.ID, model, streamInterval)
|
||||
if s.rateLimitService != nil {
|
||||
s.rateLimitService.HandleStreamTimeout(ctx, account, model)
|
||||
}
|
||||
return &streamingResult{usage: usage, firstTokenMs: firstTokenMs}, fmt.Errorf("stream data interval timeout")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// extractBedrockChunkData 从 Bedrock EventStream payload 中提取 Claude SSE 事件数据
|
||||
// Bedrock payload 格式:{"bytes":"<base64-encoded-json>"}
|
||||
func extractBedrockChunkData(payload []byte) []byte {
|
||||
b64 := gjson.GetBytes(payload, "bytes").String()
|
||||
if b64 == "" {
|
||||
return nil
|
||||
}
|
||||
decoded, err := base64.StdEncoding.DecodeString(b64)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return decoded
|
||||
}
|
||||
|
||||
// transformBedrockInvocationMetrics 将 Bedrock 特有的 amazon-bedrock-invocationMetrics
|
||||
// 转换为标准 Anthropic usage 格式,并从 SSE 数据中移除该字段。
|
||||
//
|
||||
// Bedrock Invoke 返回的 message_delta 事件可能包含:
|
||||
//
|
||||
// {"type":"message_delta","delta":{...},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}
|
||||
//
|
||||
// 转换为:
|
||||
//
|
||||
// {"type":"message_delta","delta":{...},"usage":{"input_tokens":150,"output_tokens":42}}
|
||||
func transformBedrockInvocationMetrics(data []byte) []byte {
|
||||
metrics := gjson.GetBytes(data, "amazon-bedrock-invocationMetrics")
|
||||
if !metrics.Exists() || !metrics.IsObject() {
|
||||
return data
|
||||
}
|
||||
|
||||
// 移除 Bedrock 特有字段
|
||||
data, _ = sjson.DeleteBytes(data, "amazon-bedrock-invocationMetrics")
|
||||
|
||||
// 如果已有标准 usage 字段,不覆盖
|
||||
if gjson.GetBytes(data, "usage").Exists() {
|
||||
return data
|
||||
}
|
||||
|
||||
// 转换 camelCase → snake_case 写入 usage
|
||||
inputTokens := metrics.Get("inputTokenCount")
|
||||
outputTokens := metrics.Get("outputTokenCount")
|
||||
if inputTokens.Exists() {
|
||||
data, _ = sjson.SetBytes(data, "usage.input_tokens", inputTokens.Int())
|
||||
}
|
||||
if outputTokens.Exists() {
|
||||
data, _ = sjson.SetBytes(data, "usage.output_tokens", outputTokens.Int())
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// bedrockEventStreamDecoder 解码 AWS EventStream 二进制帧
|
||||
// EventStream 帧格式:
|
||||
//
|
||||
// [total_byte_length: 4 bytes]
|
||||
// [headers_byte_length: 4 bytes]
|
||||
// [prelude_crc: 4 bytes]
|
||||
// [headers: variable]
|
||||
// [payload: variable]
|
||||
// [message_crc: 4 bytes]
|
||||
type bedrockEventStreamDecoder struct {
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func newBedrockEventStreamDecoder(r io.Reader) *bedrockEventStreamDecoder {
|
||||
return &bedrockEventStreamDecoder{
|
||||
reader: bufio.NewReaderSize(r, 64*1024),
|
||||
}
|
||||
}
|
||||
|
||||
// Decode 读取下一个 EventStream 帧并返回 chunk 类型事件的 payload
|
||||
func (d *bedrockEventStreamDecoder) Decode() ([]byte, error) {
|
||||
for {
|
||||
// 读取 prelude: total_length(4) + headers_length(4) + prelude_crc(4) = 12 bytes
|
||||
prelude := make([]byte, 12)
|
||||
if _, err := io.ReadFull(d.reader, prelude); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证 prelude CRC(AWS EventStream 使用标准 CRC32 / IEEE)
|
||||
preludeCRC := bedrockReadUint32(prelude[8:12])
|
||||
if crc32.Checksum(prelude[0:8], crc32IEEETable) != preludeCRC {
|
||||
return nil, fmt.Errorf("eventstream prelude CRC mismatch")
|
||||
}
|
||||
|
||||
totalLength := bedrockReadUint32(prelude[0:4])
|
||||
headersLength := bedrockReadUint32(prelude[4:8])
|
||||
|
||||
if totalLength < 16 { // minimum: 12 prelude + 4 message_crc
|
||||
return nil, fmt.Errorf("invalid eventstream frame: total_length=%d", totalLength)
|
||||
}
|
||||
|
||||
// 读取 headers + payload + message_crc
|
||||
remaining := int(totalLength) - 12
|
||||
if remaining <= 0 {
|
||||
continue
|
||||
}
|
||||
data := make([]byte, remaining)
|
||||
if _, err := io.ReadFull(d.reader, data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 验证 message CRC(覆盖 prelude + headers + payload)
|
||||
messageCRC := bedrockReadUint32(data[len(data)-4:])
|
||||
h := crc32.New(crc32IEEETable)
|
||||
_, _ = h.Write(prelude)
|
||||
_, _ = h.Write(data[:len(data)-4])
|
||||
if h.Sum32() != messageCRC {
|
||||
return nil, fmt.Errorf("eventstream message CRC mismatch")
|
||||
}
|
||||
|
||||
// 解析 headers
|
||||
headers := data[:headersLength]
|
||||
payload := data[headersLength : len(data)-4] // 去掉 message_crc
|
||||
|
||||
// 从 headers 中提取 :event-type
|
||||
eventType := extractEventStreamHeaderValue(headers, ":event-type")
|
||||
|
||||
// 只处理 chunk 事件
|
||||
if eventType == "chunk" {
|
||||
// payload 是完整的 JSON,包含 bytes 字段
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
// 检查异常事件
|
||||
exceptionType := extractEventStreamHeaderValue(headers, ":exception-type")
|
||||
if exceptionType != "" {
|
||||
return nil, fmt.Errorf("bedrock exception: %s: %s", exceptionType, string(payload))
|
||||
}
|
||||
|
||||
messageType := extractEventStreamHeaderValue(headers, ":message-type")
|
||||
if messageType == "exception" || messageType == "error" {
|
||||
return nil, fmt.Errorf("bedrock error: %s", string(payload))
|
||||
}
|
||||
|
||||
// 跳过其他事件类型(如 initial-response)
|
||||
}
|
||||
}
|
||||
|
||||
// extractEventStreamHeaderValue 从 EventStream headers 二进制数据中提取指定 header 的字符串值
|
||||
// EventStream header 格式:
|
||||
//
|
||||
// [name_length: 1 byte][name: variable][value_type: 1 byte][value: variable]
|
||||
//
|
||||
// value_type = 7 表示 string 类型,前 2 bytes 为长度
|
||||
func extractEventStreamHeaderValue(headers []byte, targetName string) string {
|
||||
pos := 0
|
||||
for pos < len(headers) {
|
||||
if pos >= len(headers) {
|
||||
break
|
||||
}
|
||||
nameLen := int(headers[pos])
|
||||
pos++
|
||||
if pos+nameLen > len(headers) {
|
||||
break
|
||||
}
|
||||
name := string(headers[pos : pos+nameLen])
|
||||
pos += nameLen
|
||||
|
||||
if pos >= len(headers) {
|
||||
break
|
||||
}
|
||||
valueType := headers[pos]
|
||||
pos++
|
||||
|
||||
switch valueType {
|
||||
case 7: // string
|
||||
if pos+2 > len(headers) {
|
||||
return ""
|
||||
}
|
||||
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
|
||||
pos += 2
|
||||
if pos+valueLen > len(headers) {
|
||||
return ""
|
||||
}
|
||||
value := string(headers[pos : pos+valueLen])
|
||||
pos += valueLen
|
||||
if name == targetName {
|
||||
return value
|
||||
}
|
||||
case 0: // bool true
|
||||
if name == targetName {
|
||||
return "true"
|
||||
}
|
||||
case 1: // bool false
|
||||
if name == targetName {
|
||||
return "false"
|
||||
}
|
||||
case 2: // byte
|
||||
pos++
|
||||
if name == targetName {
|
||||
return ""
|
||||
}
|
||||
case 3: // short
|
||||
pos += 2
|
||||
if name == targetName {
|
||||
return ""
|
||||
}
|
||||
case 4: // int
|
||||
pos += 4
|
||||
if name == targetName {
|
||||
return ""
|
||||
}
|
||||
case 5: // long
|
||||
pos += 8
|
||||
if name == targetName {
|
||||
return ""
|
||||
}
|
||||
case 6: // bytes
|
||||
if pos+2 > len(headers) {
|
||||
return ""
|
||||
}
|
||||
valueLen := int(bedrockReadUint16(headers[pos : pos+2]))
|
||||
pos += 2 + valueLen
|
||||
case 8: // timestamp
|
||||
pos += 8
|
||||
case 9: // uuid
|
||||
pos += 16
|
||||
default:
|
||||
return "" // 未知类型,无法继续解析
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// crc32IEEETable is the CRC32 / IEEE table used by AWS EventStream.
|
||||
var crc32IEEETable = crc32.MakeTable(crc32.IEEE)
|
||||
|
||||
func bedrockReadUint32(b []byte) uint32 {
|
||||
return uint32(b[0])<<24 | uint32(b[1])<<16 | uint32(b[2])<<8 | uint32(b[3])
|
||||
}
|
||||
|
||||
func bedrockReadUint16(b []byte) uint16 {
|
||||
return uint16(b[0])<<8 | uint16(b[1])
|
||||
}
|
||||
@@ -0,0 +1,261 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestExtractBedrockChunkData(t *testing.T) {
|
||||
t.Run("valid base64 payload", func(t *testing.T) {
|
||||
original := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`
|
||||
b64 := base64.StdEncoding.EncodeToString([]byte(original))
|
||||
payload := []byte(`{"bytes":"` + b64 + `"}`)
|
||||
|
||||
result := extractBedrockChunkData(payload)
|
||||
require.NotNil(t, result)
|
||||
assert.JSONEq(t, original, string(result))
|
||||
})
|
||||
|
||||
t.Run("empty bytes field", func(t *testing.T) {
|
||||
result := extractBedrockChunkData([]byte(`{"bytes":""}`))
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("no bytes field", func(t *testing.T) {
|
||||
result := extractBedrockChunkData([]byte(`{"other":"value"}`))
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
|
||||
t.Run("invalid base64", func(t *testing.T) {
|
||||
result := extractBedrockChunkData([]byte(`{"bytes":"not-valid-base64!!!"}`))
|
||||
assert.Nil(t, result)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTransformBedrockInvocationMetrics(t *testing.T) {
|
||||
t.Run("converts metrics to usage", func(t *testing.T) {
|
||||
input := `{"type":"message_delta","delta":{"stop_reason":"end_turn"},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
|
||||
result := transformBedrockInvocationMetrics([]byte(input))
|
||||
|
||||
// amazon-bedrock-invocationMetrics should be removed
|
||||
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
|
||||
// usage should be set
|
||||
assert.Equal(t, int64(150), gjson.GetBytes(result, "usage.input_tokens").Int())
|
||||
assert.Equal(t, int64(42), gjson.GetBytes(result, "usage.output_tokens").Int())
|
||||
// original fields preserved
|
||||
assert.Equal(t, "message_delta", gjson.GetBytes(result, "type").String())
|
||||
assert.Equal(t, "end_turn", gjson.GetBytes(result, "delta.stop_reason").String())
|
||||
})
|
||||
|
||||
t.Run("no metrics present", func(t *testing.T) {
|
||||
input := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}`
|
||||
result := transformBedrockInvocationMetrics([]byte(input))
|
||||
assert.JSONEq(t, input, string(result))
|
||||
})
|
||||
|
||||
t.Run("does not overwrite existing usage", func(t *testing.T) {
|
||||
input := `{"type":"message_delta","usage":{"output_tokens":100},"amazon-bedrock-invocationMetrics":{"inputTokenCount":150,"outputTokenCount":42}}`
|
||||
result := transformBedrockInvocationMetrics([]byte(input))
|
||||
|
||||
// metrics removed but existing usage preserved
|
||||
assert.False(t, gjson.GetBytes(result, "amazon-bedrock-invocationMetrics").Exists())
|
||||
assert.Equal(t, int64(100), gjson.GetBytes(result, "usage.output_tokens").Int())
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractEventStreamHeaderValue(t *testing.T) {
|
||||
// Build a header with :event-type = "chunk" (string type = 7)
|
||||
buildStringHeader := func(name, value string) []byte {
|
||||
var buf bytes.Buffer
|
||||
// name length (1 byte)
|
||||
_ = buf.WriteByte(byte(len(name)))
|
||||
// name
|
||||
_, _ = buf.WriteString(name)
|
||||
// value type (7 = string)
|
||||
_ = buf.WriteByte(7)
|
||||
// value length (2 bytes, big-endian)
|
||||
_ = binary.Write(&buf, binary.BigEndian, uint16(len(value)))
|
||||
// value
|
||||
_, _ = buf.WriteString(value)
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
t.Run("find string header", func(t *testing.T) {
|
||||
headers := buildStringHeader(":event-type", "chunk")
|
||||
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
|
||||
})
|
||||
|
||||
t.Run("header not found", func(t *testing.T) {
|
||||
headers := buildStringHeader(":event-type", "chunk")
|
||||
assert.Equal(t, "", extractEventStreamHeaderValue(headers, ":message-type"))
|
||||
})
|
||||
|
||||
t.Run("multiple headers", func(t *testing.T) {
|
||||
var buf bytes.Buffer
|
||||
_, _ = buf.Write(buildStringHeader(":content-type", "application/json"))
|
||||
_, _ = buf.Write(buildStringHeader(":event-type", "chunk"))
|
||||
_, _ = buf.Write(buildStringHeader(":message-type", "event"))
|
||||
|
||||
headers := buf.Bytes()
|
||||
assert.Equal(t, "chunk", extractEventStreamHeaderValue(headers, ":event-type"))
|
||||
assert.Equal(t, "application/json", extractEventStreamHeaderValue(headers, ":content-type"))
|
||||
assert.Equal(t, "event", extractEventStreamHeaderValue(headers, ":message-type"))
|
||||
})
|
||||
|
||||
t.Run("empty headers", func(t *testing.T) {
|
||||
assert.Equal(t, "", extractEventStreamHeaderValue([]byte{}, ":event-type"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestBedrockEventStreamDecoder(t *testing.T) {
|
||||
crc32IeeeTab := crc32.MakeTable(crc32.IEEE)
|
||||
|
||||
// Build a valid EventStream frame with correct CRC32/IEEE checksums.
|
||||
buildFrame := func(eventType string, payload []byte) []byte {
|
||||
// Build headers
|
||||
var headersBuf bytes.Buffer
|
||||
// :event-type header
|
||||
_ = headersBuf.WriteByte(byte(len(":event-type")))
|
||||
_, _ = headersBuf.WriteString(":event-type")
|
||||
_ = headersBuf.WriteByte(7) // string type
|
||||
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len(eventType)))
|
||||
_, _ = headersBuf.WriteString(eventType)
|
||||
// :message-type header
|
||||
_ = headersBuf.WriteByte(byte(len(":message-type")))
|
||||
_, _ = headersBuf.WriteString(":message-type")
|
||||
_ = headersBuf.WriteByte(7)
|
||||
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("event")))
|
||||
_, _ = headersBuf.WriteString("event")
|
||||
|
||||
headers := headersBuf.Bytes()
|
||||
headersLen := uint32(len(headers))
|
||||
// total = 12 (prelude) + headers + payload + 4 (message_crc)
|
||||
totalLen := uint32(12 + len(headers) + len(payload) + 4)
|
||||
|
||||
// Prelude: total_length(4) + headers_length(4)
|
||||
var preludeBuf bytes.Buffer
|
||||
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
|
||||
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
|
||||
preludeBytes := preludeBuf.Bytes()
|
||||
preludeCRC := crc32.Checksum(preludeBytes, crc32IeeeTab)
|
||||
|
||||
// Build frame: prelude + prelude_crc + headers + payload
|
||||
var frame bytes.Buffer
|
||||
_, _ = frame.Write(preludeBytes)
|
||||
_ = binary.Write(&frame, binary.BigEndian, preludeCRC)
|
||||
_, _ = frame.Write(headers)
|
||||
_, _ = frame.Write(payload)
|
||||
|
||||
// Message CRC covers everything before itself
|
||||
messageCRC := crc32.Checksum(frame.Bytes(), crc32IeeeTab)
|
||||
_ = binary.Write(&frame, binary.BigEndian, messageCRC)
|
||||
return frame.Bytes()
|
||||
}
|
||||
|
||||
t.Run("decode chunk event", func(t *testing.T) {
|
||||
payload := []byte(`{"bytes":"dGVzdA=="}`) // base64("test")
|
||||
frame := buildFrame("chunk", payload)
|
||||
|
||||
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
|
||||
result, err := decoder.Decode()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, payload, result)
|
||||
})
|
||||
|
||||
t.Run("skip non-chunk events", func(t *testing.T) {
|
||||
// Write initial-response followed by chunk
|
||||
var buf bytes.Buffer
|
||||
_, _ = buf.Write(buildFrame("initial-response", []byte(`{}`)))
|
||||
chunkPayload := []byte(`{"bytes":"aGVsbG8="}`)
|
||||
_, _ = buf.Write(buildFrame("chunk", chunkPayload))
|
||||
|
||||
decoder := newBedrockEventStreamDecoder(&buf)
|
||||
result, err := decoder.Decode()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, chunkPayload, result)
|
||||
})
|
||||
|
||||
t.Run("EOF on empty input", func(t *testing.T) {
|
||||
decoder := newBedrockEventStreamDecoder(bytes.NewReader(nil))
|
||||
_, err := decoder.Decode()
|
||||
assert.Equal(t, io.EOF, err)
|
||||
})
|
||||
|
||||
t.Run("corrupted prelude CRC", func(t *testing.T) {
|
||||
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
|
||||
// Corrupt the prelude CRC (bytes 8-11)
|
||||
frame[8] ^= 0xFF
|
||||
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
|
||||
_, err := decoder.Decode()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "prelude CRC mismatch")
|
||||
})
|
||||
|
||||
t.Run("corrupted message CRC", func(t *testing.T) {
|
||||
frame := buildFrame("chunk", []byte(`{"bytes":"dGVzdA=="}`))
|
||||
// Corrupt the message CRC (last 4 bytes)
|
||||
frame[len(frame)-1] ^= 0xFF
|
||||
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame))
|
||||
_, err := decoder.Decode()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "message CRC mismatch")
|
||||
})
|
||||
|
||||
t.Run("castagnoli encoded frame is rejected", func(t *testing.T) {
|
||||
castagnoliTab := crc32.MakeTable(crc32.Castagnoli)
|
||||
payload := []byte(`{"bytes":"dGVzdA=="}`)
|
||||
|
||||
var headersBuf bytes.Buffer
|
||||
_ = headersBuf.WriteByte(byte(len(":event-type")))
|
||||
_, _ = headersBuf.WriteString(":event-type")
|
||||
_ = headersBuf.WriteByte(7)
|
||||
_ = binary.Write(&headersBuf, binary.BigEndian, uint16(len("chunk")))
|
||||
_, _ = headersBuf.WriteString("chunk")
|
||||
|
||||
headers := headersBuf.Bytes()
|
||||
headersLen := uint32(len(headers))
|
||||
totalLen := uint32(12 + len(headers) + len(payload) + 4)
|
||||
|
||||
var preludeBuf bytes.Buffer
|
||||
_ = binary.Write(&preludeBuf, binary.BigEndian, totalLen)
|
||||
_ = binary.Write(&preludeBuf, binary.BigEndian, headersLen)
|
||||
preludeBytes := preludeBuf.Bytes()
|
||||
|
||||
var frame bytes.Buffer
|
||||
_, _ = frame.Write(preludeBytes)
|
||||
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(preludeBytes, castagnoliTab))
|
||||
_, _ = frame.Write(headers)
|
||||
_, _ = frame.Write(payload)
|
||||
_ = binary.Write(&frame, binary.BigEndian, crc32.Checksum(frame.Bytes(), castagnoliTab))
|
||||
|
||||
decoder := newBedrockEventStreamDecoder(bytes.NewReader(frame.Bytes()))
|
||||
_, err := decoder.Decode()
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "prelude CRC mismatch")
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildBedrockURL(t *testing.T) {
|
||||
t.Run("stream URL with colon in model ID", func(t *testing.T) {
|
||||
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-opus-4-5-20251101-v1:0", true)
|
||||
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-opus-4-5-20251101-v1%3A0/invoke-with-response-stream", url)
|
||||
})
|
||||
|
||||
t.Run("non-stream URL with colon in model ID", func(t *testing.T) {
|
||||
url := BuildBedrockURL("eu-west-1", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0", false)
|
||||
assert.Equal(t, "https://bedrock-runtime.eu-west-1.amazonaws.com/model/eu.anthropic.claude-sonnet-4-5-20250929-v1%3A0/invoke", url)
|
||||
})
|
||||
|
||||
t.Run("model ID without colon", func(t *testing.T) {
|
||||
url := BuildBedrockURL("us-east-1", "us.anthropic.claude-sonnet-4-6", true)
|
||||
assert.Equal(t, "https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-sonnet-4-6/invoke-with-response-stream", url)
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// SubscriptionCacheData represents cached subscription data
|
||||
type SubscriptionCacheData struct {
|
||||
Status string
|
||||
ExpiresAt time.Time
|
||||
DailyUsage float64
|
||||
WeeklyUsage float64
|
||||
MonthlyUsage float64
|
||||
Version int64
|
||||
}
|
||||
@@ -0,0 +1,861 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
infraerrors "github.com/Wei-Shaw/sub2api/internal/pkg/errors"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
// 错误定义
|
||||
// 注:ErrInsufficientBalance在redeem_service.go中定义
|
||||
// 注:ErrDailyLimitExceeded/ErrWeeklyLimitExceeded/ErrMonthlyLimitExceeded在subscription_service.go中定义
|
||||
var (
|
||||
ErrSubscriptionInvalid = infraerrors.Forbidden("SUBSCRIPTION_INVALID", "subscription is invalid or expired")
|
||||
ErrBillingServiceUnavailable = infraerrors.ServiceUnavailable("BILLING_SERVICE_ERROR", "Billing service temporarily unavailable. Please retry later.")
|
||||
)
|
||||
|
||||
// subscriptionCacheData 订阅缓存数据结构(内部使用)
|
||||
type subscriptionCacheData struct {
|
||||
Status string
|
||||
ExpiresAt time.Time
|
||||
DailyUsage float64
|
||||
WeeklyUsage float64
|
||||
MonthlyUsage float64
|
||||
Version int64
|
||||
}
|
||||
|
||||
// 缓存写入任务类型
|
||||
type cacheWriteKind int
|
||||
|
||||
const (
|
||||
cacheWriteSetBalance cacheWriteKind = iota
|
||||
cacheWriteSetSubscription
|
||||
cacheWriteUpdateSubscriptionUsage
|
||||
cacheWriteDeductBalance
|
||||
cacheWriteUpdateRateLimitUsage
|
||||
)
|
||||
|
||||
// 异步缓存写入工作池配置
|
||||
//
|
||||
// 性能优化说明:
|
||||
// 原实现在请求热路径中使用 goroutine 异步更新缓存,存在以下问题:
|
||||
// 1. 每次请求创建新 goroutine,高并发下产生大量短生命周期 goroutine
|
||||
// 2. 无法控制并发数量,可能导致 Redis 连接耗尽
|
||||
// 3. goroutine 创建/销毁带来额外开销
|
||||
//
|
||||
// 新实现使用固定大小的工作池:
|
||||
// 1. 预创建 10 个 worker goroutine,避免频繁创建销毁
|
||||
// 2. 使用带缓冲的 channel(1000)作为任务队列,平滑写入峰值
|
||||
// 3. 非阻塞写入,队列满时关键任务同步回退,非关键任务丢弃并告警
|
||||
// 4. 统一超时控制,避免慢操作阻塞工作池
|
||||
const (
|
||||
cacheWriteWorkerCount = 10 // 工作协程数量
|
||||
cacheWriteBufferSize = 1000 // 任务队列缓冲大小
|
||||
cacheWriteTimeout = 2 * time.Second // 单个写入操作超时
|
||||
cacheWriteDropLogInterval = 5 * time.Second // 丢弃日志节流间隔
|
||||
balanceLoadTimeout = 3 * time.Second
|
||||
)
|
||||
|
||||
// cacheWriteTask 缓存写入任务
|
||||
type cacheWriteTask struct {
|
||||
kind cacheWriteKind
|
||||
userID int64
|
||||
groupID int64
|
||||
apiKeyID int64
|
||||
balance float64
|
||||
amount float64
|
||||
subscriptionData *subscriptionCacheData
|
||||
}
|
||||
|
||||
// apiKeyRateLimitLoader defines the interface for loading rate limit data from DB.
|
||||
type apiKeyRateLimitLoader interface {
|
||||
GetRateLimitData(ctx context.Context, keyID int64) (*APIKeyRateLimitData, error)
|
||||
}
|
||||
|
||||
// BillingCacheService 计费缓存服务
|
||||
// 负责余额和订阅数据的缓存管理,提供高性能的计费资格检查
|
||||
type BillingCacheService struct {
|
||||
cache BillingCache
|
||||
userRepo UserRepository
|
||||
subRepo UserSubscriptionRepository
|
||||
apiKeyRateLimitLoader apiKeyRateLimitLoader
|
||||
cfg *config.Config
|
||||
circuitBreaker *billingCircuitBreaker
|
||||
|
||||
cacheWriteChan chan cacheWriteTask
|
||||
cacheWriteWg sync.WaitGroup
|
||||
cacheWriteStopOnce sync.Once
|
||||
cacheWriteMu sync.RWMutex
|
||||
stopped atomic.Bool
|
||||
balanceLoadSF singleflight.Group
|
||||
// 丢弃日志节流计数器(减少高负载下日志噪音)
|
||||
cacheWriteDropFullCount uint64
|
||||
cacheWriteDropFullLastLog int64
|
||||
cacheWriteDropClosedCount uint64
|
||||
cacheWriteDropClosedLastLog int64
|
||||
}
|
||||
|
||||
// NewBillingCacheService 创建计费缓存服务
|
||||
func NewBillingCacheService(cache BillingCache, userRepo UserRepository, subRepo UserSubscriptionRepository, apiKeyRepo APIKeyRepository, cfg *config.Config) *BillingCacheService {
|
||||
svc := &BillingCacheService{
|
||||
cache: cache,
|
||||
userRepo: userRepo,
|
||||
subRepo: subRepo,
|
||||
apiKeyRateLimitLoader: apiKeyRepo,
|
||||
cfg: cfg,
|
||||
}
|
||||
svc.circuitBreaker = newBillingCircuitBreaker(cfg.Billing.CircuitBreaker)
|
||||
svc.startCacheWriteWorkers()
|
||||
return svc
|
||||
}
|
||||
|
||||
// Stop 关闭缓存写入工作池
|
||||
func (s *BillingCacheService) Stop() {
|
||||
s.cacheWriteStopOnce.Do(func() {
|
||||
s.stopped.Store(true)
|
||||
|
||||
s.cacheWriteMu.Lock()
|
||||
ch := s.cacheWriteChan
|
||||
if ch != nil {
|
||||
close(ch)
|
||||
}
|
||||
s.cacheWriteMu.Unlock()
|
||||
|
||||
if ch == nil {
|
||||
return
|
||||
}
|
||||
s.cacheWriteWg.Wait()
|
||||
|
||||
s.cacheWriteMu.Lock()
|
||||
if s.cacheWriteChan == ch {
|
||||
s.cacheWriteChan = nil
|
||||
}
|
||||
s.cacheWriteMu.Unlock()
|
||||
})
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) startCacheWriteWorkers() {
|
||||
ch := make(chan cacheWriteTask, cacheWriteBufferSize)
|
||||
s.cacheWriteChan = ch
|
||||
for i := 0; i < cacheWriteWorkerCount; i++ {
|
||||
s.cacheWriteWg.Add(1)
|
||||
go s.cacheWriteWorker(ch)
|
||||
}
|
||||
}
|
||||
|
||||
// enqueueCacheWrite 尝试将任务入队,队列满时返回 false(并记录告警)。
|
||||
func (s *BillingCacheService) enqueueCacheWrite(task cacheWriteTask) (enqueued bool) {
|
||||
if s.stopped.Load() {
|
||||
s.logCacheWriteDrop(task, "closed")
|
||||
return false
|
||||
}
|
||||
|
||||
s.cacheWriteMu.RLock()
|
||||
defer s.cacheWriteMu.RUnlock()
|
||||
|
||||
if s.cacheWriteChan == nil {
|
||||
s.logCacheWriteDrop(task, "closed")
|
||||
return false
|
||||
}
|
||||
|
||||
select {
|
||||
case s.cacheWriteChan <- task:
|
||||
return true
|
||||
default:
|
||||
// 队列满时不阻塞主流程,交由调用方决定是否同步回退。
|
||||
s.logCacheWriteDrop(task, "full")
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) cacheWriteWorker(ch <-chan cacheWriteTask) {
|
||||
defer s.cacheWriteWg.Done()
|
||||
for task := range ch {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
switch task.kind {
|
||||
case cacheWriteSetBalance:
|
||||
s.setBalanceCache(ctx, task.userID, task.balance)
|
||||
case cacheWriteSetSubscription:
|
||||
s.setSubscriptionCache(ctx, task.userID, task.groupID, task.subscriptionData)
|
||||
case cacheWriteUpdateSubscriptionUsage:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.UpdateSubscriptionUsage(ctx, task.userID, task.groupID, task.amount); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache failed for user %d group %d: %v", task.userID, task.groupID, err)
|
||||
}
|
||||
}
|
||||
case cacheWriteDeductBalance:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.DeductUserBalance(ctx, task.userID, task.amount); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache failed for user %d: %v", task.userID, err)
|
||||
}
|
||||
}
|
||||
case cacheWriteUpdateRateLimitUsage:
|
||||
if s.cache != nil {
|
||||
if err := s.cache.UpdateAPIKeyRateLimitUsage(ctx, task.apiKeyID, task.amount); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: update rate limit usage cache failed for api key %d: %v", task.apiKeyID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
cancel()
|
||||
}
|
||||
}
|
||||
|
||||
// cacheWriteKindName 用于日志中的任务类型标识,便于排查丢弃原因。
|
||||
func cacheWriteKindName(kind cacheWriteKind) string {
|
||||
switch kind {
|
||||
case cacheWriteSetBalance:
|
||||
return "set_balance"
|
||||
case cacheWriteSetSubscription:
|
||||
return "set_subscription"
|
||||
case cacheWriteUpdateSubscriptionUsage:
|
||||
return "update_subscription_usage"
|
||||
case cacheWriteDeductBalance:
|
||||
return "deduct_balance"
|
||||
case cacheWriteUpdateRateLimitUsage:
|
||||
return "update_rate_limit_usage"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
|
||||
// logCacheWriteDrop 使用节流方式记录丢弃情况,并汇总丢弃数量。
|
||||
func (s *BillingCacheService) logCacheWriteDrop(task cacheWriteTask, reason string) {
|
||||
var (
|
||||
countPtr *uint64
|
||||
lastPtr *int64
|
||||
)
|
||||
switch reason {
|
||||
case "full":
|
||||
countPtr = &s.cacheWriteDropFullCount
|
||||
lastPtr = &s.cacheWriteDropFullLastLog
|
||||
case "closed":
|
||||
countPtr = &s.cacheWriteDropClosedCount
|
||||
lastPtr = &s.cacheWriteDropClosedLastLog
|
||||
default:
|
||||
return
|
||||
}
|
||||
|
||||
atomic.AddUint64(countPtr, 1)
|
||||
now := time.Now().UnixNano()
|
||||
last := atomic.LoadInt64(lastPtr)
|
||||
if now-last < int64(cacheWriteDropLogInterval) {
|
||||
return
|
||||
}
|
||||
if !atomic.CompareAndSwapInt64(lastPtr, last, now) {
|
||||
return
|
||||
}
|
||||
dropped := atomic.SwapUint64(countPtr, 0)
|
||||
if dropped == 0 {
|
||||
return
|
||||
}
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: cache write queue %s, dropped %d tasks in last %s (latest kind=%s user %d group %d)",
|
||||
reason,
|
||||
dropped,
|
||||
cacheWriteDropLogInterval,
|
||||
cacheWriteKindName(task.kind),
|
||||
task.userID,
|
||||
task.groupID,
|
||||
)
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 余额缓存方法
|
||||
// ============================================
|
||||
|
||||
// GetUserBalance 获取用户余额(优先从缓存读取)
|
||||
func (s *BillingCacheService) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
if s.cache == nil {
|
||||
// Redis不可用,直接查询数据库
|
||||
return s.getUserBalanceFromDB(ctx, userID)
|
||||
}
|
||||
|
||||
// 尝试从缓存读取
|
||||
balance, err := s.cache.GetUserBalance(ctx, userID)
|
||||
if err == nil {
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
// 缓存未命中:singleflight 合并同一 userID 的并发回源请求。
|
||||
value, err, _ := s.balanceLoadSF.Do(strconv.FormatInt(userID, 10), func() (any, error) {
|
||||
loadCtx, cancel := context.WithTimeout(context.Background(), balanceLoadTimeout)
|
||||
defer cancel()
|
||||
|
||||
balance, err := s.getUserBalanceFromDB(loadCtx, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetBalance,
|
||||
userID: userID,
|
||||
balance: balance,
|
||||
})
|
||||
return balance, nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
balance, ok := value.(float64)
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("unexpected balance type: %T", value)
|
||||
}
|
||||
return balance, nil
|
||||
}
|
||||
|
||||
// getUserBalanceFromDB 从数据库获取用户余额
|
||||
func (s *BillingCacheService) getUserBalanceFromDB(ctx context.Context, userID int64) (float64, error) {
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("get user balance: %w", err)
|
||||
}
|
||||
return user.Balance, nil
|
||||
}
|
||||
|
||||
// setBalanceCache 设置余额缓存
|
||||
func (s *BillingCacheService) setBalanceCache(ctx context.Context, userID int64, balance float64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
if err := s.cache.SetUserBalance(ctx, userID, balance); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: set balance cache failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// DeductBalanceCache 扣减余额缓存(同步调用)
|
||||
func (s *BillingCacheService) DeductBalanceCache(ctx context.Context, userID int64, amount float64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.DeductUserBalance(ctx, userID, amount)
|
||||
}
|
||||
|
||||
// QueueDeductBalance 异步扣减余额缓存
|
||||
func (s *BillingCacheService) QueueDeductBalance(userID int64, amount float64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
// 队列满时同步回退,避免关键扣减被静默丢弃。
|
||||
if s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteDeductBalance,
|
||||
userID: userID,
|
||||
amount: amount,
|
||||
}) {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
if err := s.DeductBalanceCache(ctx, userID, amount); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: deduct balance cache fallback failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateUserBalance 失效用户余额缓存
|
||||
func (s *BillingCacheService) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
if err := s.cache.InvalidateUserBalance(ctx, userID); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate balance cache failed for user %d: %v", userID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 订阅缓存方法
|
||||
// ============================================
|
||||
|
||||
// GetSubscriptionStatus 获取订阅状态(优先从缓存读取)
|
||||
func (s *BillingCacheService) GetSubscriptionStatus(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||
if s.cache == nil {
|
||||
return s.getSubscriptionFromDB(ctx, userID, groupID)
|
||||
}
|
||||
|
||||
// 尝试从缓存读取
|
||||
cacheData, err := s.cache.GetSubscriptionCache(ctx, userID, groupID)
|
||||
if err == nil && cacheData != nil {
|
||||
return s.convertFromPortsData(cacheData), nil
|
||||
}
|
||||
|
||||
// 缓存未命中,从数据库读取
|
||||
data, err := s.getSubscriptionFromDB(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 异步建立缓存
|
||||
_ = s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteSetSubscription,
|
||||
userID: userID,
|
||||
groupID: groupID,
|
||||
subscriptionData: data,
|
||||
})
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) convertFromPortsData(data *SubscriptionCacheData) *subscriptionCacheData {
|
||||
return &subscriptionCacheData{
|
||||
Status: data.Status,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
DailyUsage: data.DailyUsage,
|
||||
WeeklyUsage: data.WeeklyUsage,
|
||||
MonthlyUsage: data.MonthlyUsage,
|
||||
Version: data.Version,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *BillingCacheService) convertToPortsData(data *subscriptionCacheData) *SubscriptionCacheData {
|
||||
return &SubscriptionCacheData{
|
||||
Status: data.Status,
|
||||
ExpiresAt: data.ExpiresAt,
|
||||
DailyUsage: data.DailyUsage,
|
||||
WeeklyUsage: data.WeeklyUsage,
|
||||
MonthlyUsage: data.MonthlyUsage,
|
||||
Version: data.Version,
|
||||
}
|
||||
}
|
||||
|
||||
// getSubscriptionFromDB 从数据库获取订阅数据
|
||||
func (s *BillingCacheService) getSubscriptionFromDB(ctx context.Context, userID, groupID int64) (*subscriptionCacheData, error) {
|
||||
sub, err := s.subRepo.GetActiveByUserIDAndGroupID(ctx, userID, groupID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get subscription: %w", err)
|
||||
}
|
||||
|
||||
return &subscriptionCacheData{
|
||||
Status: sub.Status,
|
||||
ExpiresAt: sub.ExpiresAt,
|
||||
DailyUsage: sub.DailyUsageUSD,
|
||||
WeeklyUsage: sub.WeeklyUsageUSD,
|
||||
MonthlyUsage: sub.MonthlyUsageUSD,
|
||||
Version: sub.UpdatedAt.Unix(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// setSubscriptionCache 设置订阅缓存
|
||||
func (s *BillingCacheService) setSubscriptionCache(ctx context.Context, userID, groupID int64, data *subscriptionCacheData) {
|
||||
if s.cache == nil || data == nil {
|
||||
return
|
||||
}
|
||||
if err := s.cache.SetSubscriptionCache(ctx, userID, groupID, s.convertToPortsData(data)); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: set subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateSubscriptionUsage 更新订阅用量缓存(同步调用)
|
||||
func (s *BillingCacheService) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, costUSD float64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD)
|
||||
}
|
||||
|
||||
// QueueUpdateSubscriptionUsage 异步更新订阅用量缓存
|
||||
func (s *BillingCacheService) QueueUpdateSubscriptionUsage(userID, groupID int64, costUSD float64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
// 队列满时同步回退,确保订阅用量及时更新。
|
||||
if s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteUpdateSubscriptionUsage,
|
||||
userID: userID,
|
||||
groupID: groupID,
|
||||
amount: costUSD,
|
||||
}) {
|
||||
return
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
if err := s.UpdateSubscriptionUsage(ctx, userID, groupID, costUSD); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: update subscription cache fallback failed for user %d group %d: %v", userID, groupID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// InvalidateSubscription 失效指定订阅缓存
|
||||
func (s *BillingCacheService) InvalidateSubscription(ctx context.Context, userID, groupID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
if err := s.cache.InvalidateSubscriptionCache(ctx, userID, groupID); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate subscription cache failed for user %d group %d: %v", userID, groupID, err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// API Key 限速缓存方法
|
||||
// ============================================
|
||||
|
||||
// checkAPIKeyRateLimits checks rate limit windows for an API key.
|
||||
// It loads usage from Redis cache (falling back to DB on cache miss),
|
||||
// resets expired windows in-memory and triggers async DB reset,
|
||||
// and returns an error if any window limit is exceeded.
|
||||
func (s *BillingCacheService) checkAPIKeyRateLimits(ctx context.Context, apiKey *APIKey) error {
|
||||
if s.cache == nil {
|
||||
// No cache: fall back to reading from DB directly
|
||||
if s.apiKeyRateLimitLoader == nil {
|
||||
return nil
|
||||
}
|
||||
data, err := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
|
||||
if err != nil {
|
||||
return nil // Don't block requests on DB errors
|
||||
}
|
||||
return s.evaluateRateLimits(ctx, apiKey, data.Usage5h, data.Usage1d, data.Usage7d,
|
||||
data.Window5hStart, data.Window1dStart, data.Window7dStart)
|
||||
}
|
||||
|
||||
cacheData, err := s.cache.GetAPIKeyRateLimit(ctx, apiKey.ID)
|
||||
if err != nil {
|
||||
// Cache miss: load from DB and populate cache
|
||||
if s.apiKeyRateLimitLoader == nil {
|
||||
return nil
|
||||
}
|
||||
dbData, dbErr := s.apiKeyRateLimitLoader.GetRateLimitData(ctx, apiKey.ID)
|
||||
if dbErr != nil {
|
||||
return nil // Don't block requests on DB errors
|
||||
}
|
||||
// Build cache entry from DB data
|
||||
cacheEntry := &APIKeyRateLimitCacheData{
|
||||
Usage5h: dbData.Usage5h,
|
||||
Usage1d: dbData.Usage1d,
|
||||
Usage7d: dbData.Usage7d,
|
||||
}
|
||||
if dbData.Window5hStart != nil {
|
||||
cacheEntry.Window5h = dbData.Window5hStart.Unix()
|
||||
}
|
||||
if dbData.Window1dStart != nil {
|
||||
cacheEntry.Window1d = dbData.Window1dStart.Unix()
|
||||
}
|
||||
if dbData.Window7dStart != nil {
|
||||
cacheEntry.Window7d = dbData.Window7dStart.Unix()
|
||||
}
|
||||
_ = s.cache.SetAPIKeyRateLimit(ctx, apiKey.ID, cacheEntry)
|
||||
cacheData = cacheEntry
|
||||
}
|
||||
|
||||
var w5h, w1d, w7d *time.Time
|
||||
if cacheData.Window5h > 0 {
|
||||
t := time.Unix(cacheData.Window5h, 0)
|
||||
w5h = &t
|
||||
}
|
||||
if cacheData.Window1d > 0 {
|
||||
t := time.Unix(cacheData.Window1d, 0)
|
||||
w1d = &t
|
||||
}
|
||||
if cacheData.Window7d > 0 {
|
||||
t := time.Unix(cacheData.Window7d, 0)
|
||||
w7d = &t
|
||||
}
|
||||
return s.evaluateRateLimits(ctx, apiKey, cacheData.Usage5h, cacheData.Usage1d, cacheData.Usage7d, w5h, w1d, w7d)
|
||||
}
|
||||
|
||||
// evaluateRateLimits checks usage against limits, triggering async resets for expired windows.
|
||||
func (s *BillingCacheService) evaluateRateLimits(ctx context.Context, apiKey *APIKey, usage5h, usage1d, usage7d float64, w5h, w1d, w7d *time.Time) error {
|
||||
needsReset := false
|
||||
|
||||
// Reset expired windows in-memory for check purposes
|
||||
if IsWindowExpired(w5h, RateLimitWindow5h) {
|
||||
usage5h = 0
|
||||
needsReset = true
|
||||
}
|
||||
if IsWindowExpired(w1d, RateLimitWindow1d) {
|
||||
usage1d = 0
|
||||
needsReset = true
|
||||
}
|
||||
if IsWindowExpired(w7d, RateLimitWindow7d) {
|
||||
usage7d = 0
|
||||
needsReset = true
|
||||
}
|
||||
|
||||
// Trigger async DB reset if any window expired
|
||||
if needsReset {
|
||||
keyID := apiKey.ID
|
||||
go func() {
|
||||
resetCtx, cancel := context.WithTimeout(context.Background(), cacheWriteTimeout)
|
||||
defer cancel()
|
||||
if s.apiKeyRateLimitLoader != nil {
|
||||
// Use the repo directly - reset then reload cache
|
||||
if loader, ok := s.apiKeyRateLimitLoader.(interface {
|
||||
ResetRateLimitWindows(ctx context.Context, id int64) error
|
||||
}); ok {
|
||||
if err := loader.ResetRateLimitWindows(resetCtx, keyID); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: reset rate limit windows failed for api key %d: %v", keyID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Invalidate cache so next request loads fresh data
|
||||
if s.cache != nil {
|
||||
if err := s.cache.InvalidateAPIKeyRateLimit(resetCtx, keyID); err != nil {
|
||||
logger.LegacyPrintf("service.billing_cache", "Warning: invalidate rate limit cache failed for api key %d: %v", keyID, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Check limits
|
||||
if apiKey.RateLimit5h > 0 && usage5h >= apiKey.RateLimit5h {
|
||||
return ErrAPIKeyRateLimit5hExceeded
|
||||
}
|
||||
if apiKey.RateLimit1d > 0 && usage1d >= apiKey.RateLimit1d {
|
||||
return ErrAPIKeyRateLimit1dExceeded
|
||||
}
|
||||
if apiKey.RateLimit7d > 0 && usage7d >= apiKey.RateLimit7d {
|
||||
return ErrAPIKeyRateLimit7dExceeded
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// QueueUpdateAPIKeyRateLimitUsage asynchronously updates rate limit usage in the cache.
|
||||
func (s *BillingCacheService) QueueUpdateAPIKeyRateLimitUsage(apiKeyID int64, cost float64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
s.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteUpdateRateLimitUsage,
|
||||
apiKeyID: apiKeyID,
|
||||
amount: cost,
|
||||
})
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// 统一检查方法
|
||||
// ============================================
|
||||
|
||||
// CheckBillingEligibility 检查用户是否有资格发起请求
|
||||
// 余额模式:检查缓存余额 > 0
|
||||
// 订阅模式:检查缓存用量未超过限额(Group限额从参数传入)
|
||||
func (s *BillingCacheService) CheckBillingEligibility(ctx context.Context, user *User, apiKey *APIKey, group *Group, subscription *UserSubscription) error {
|
||||
// 简易模式:跳过所有计费检查
|
||||
if s.cfg.RunMode == config.RunModeSimple {
|
||||
return nil
|
||||
}
|
||||
if s.circuitBreaker != nil && !s.circuitBreaker.Allow() {
|
||||
return ErrBillingServiceUnavailable
|
||||
}
|
||||
|
||||
// 判断计费模式
|
||||
isSubscriptionMode := group != nil && group.IsSubscriptionType() && subscription != nil
|
||||
|
||||
if isSubscriptionMode {
|
||||
if err := s.checkSubscriptionEligibility(ctx, user.ID, group, subscription); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := s.checkBalanceEligibility(ctx, user.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Check API Key rate limits (applies to both billing modes)
|
||||
if apiKey != nil && apiKey.HasRateLimits() {
|
||||
if err := s.checkAPIKeyRateLimits(ctx, apiKey); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkBalanceEligibility 检查余额模式资格
|
||||
func (s *BillingCacheService) checkBalanceEligibility(ctx context.Context, userID int64) error {
|
||||
balance, err := s.GetUserBalance(ctx, userID)
|
||||
if err != nil {
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnFailure(err)
|
||||
}
|
||||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing balance check failed for user %d: %v", userID, err)
|
||||
return ErrBillingServiceUnavailable.WithCause(err)
|
||||
}
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnSuccess()
|
||||
}
|
||||
|
||||
if balance <= 0 {
|
||||
return ErrInsufficientBalance
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkSubscriptionEligibility 检查订阅模式资格
|
||||
func (s *BillingCacheService) checkSubscriptionEligibility(ctx context.Context, userID int64, group *Group, subscription *UserSubscription) error {
|
||||
// 获取订阅缓存数据
|
||||
subData, err := s.GetSubscriptionStatus(ctx, userID, group.ID)
|
||||
if err != nil {
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnFailure(err)
|
||||
}
|
||||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing subscription check failed for user %d group %d: %v", userID, group.ID, err)
|
||||
return ErrBillingServiceUnavailable.WithCause(err)
|
||||
}
|
||||
if s.circuitBreaker != nil {
|
||||
s.circuitBreaker.OnSuccess()
|
||||
}
|
||||
|
||||
// 检查订阅状态
|
||||
if subData.Status != SubscriptionStatusActive {
|
||||
return ErrSubscriptionInvalid
|
||||
}
|
||||
|
||||
// 检查是否过期
|
||||
if time.Now().After(subData.ExpiresAt) {
|
||||
return ErrSubscriptionInvalid
|
||||
}
|
||||
|
||||
// 检查限额(使用传入的Group限额配置)
|
||||
if group.HasDailyLimit() && subData.DailyUsage >= *group.DailyLimitUSD {
|
||||
return ErrDailyLimitExceeded
|
||||
}
|
||||
|
||||
if group.HasWeeklyLimit() && subData.WeeklyUsage >= *group.WeeklyLimitUSD {
|
||||
return ErrWeeklyLimitExceeded
|
||||
}
|
||||
|
||||
if group.HasMonthlyLimit() && subData.MonthlyUsage >= *group.MonthlyLimitUSD {
|
||||
return ErrMonthlyLimitExceeded
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type billingCircuitBreakerState int
|
||||
|
||||
const (
|
||||
billingCircuitClosed billingCircuitBreakerState = iota
|
||||
billingCircuitOpen
|
||||
billingCircuitHalfOpen
|
||||
)
|
||||
|
||||
type billingCircuitBreaker struct {
|
||||
mu sync.Mutex
|
||||
state billingCircuitBreakerState
|
||||
failures int
|
||||
openedAt time.Time
|
||||
failureThreshold int
|
||||
resetTimeout time.Duration
|
||||
halfOpenRequests int
|
||||
halfOpenRemaining int
|
||||
}
|
||||
|
||||
func newBillingCircuitBreaker(cfg config.CircuitBreakerConfig) *billingCircuitBreaker {
|
||||
if !cfg.Enabled {
|
||||
return nil
|
||||
}
|
||||
resetTimeout := time.Duration(cfg.ResetTimeoutSeconds) * time.Second
|
||||
if resetTimeout <= 0 {
|
||||
resetTimeout = 30 * time.Second
|
||||
}
|
||||
halfOpen := cfg.HalfOpenRequests
|
||||
if halfOpen <= 0 {
|
||||
halfOpen = 1
|
||||
}
|
||||
threshold := cfg.FailureThreshold
|
||||
if threshold <= 0 {
|
||||
threshold = 5
|
||||
}
|
||||
return &billingCircuitBreaker{
|
||||
state: billingCircuitClosed,
|
||||
failureThreshold: threshold,
|
||||
resetTimeout: resetTimeout,
|
||||
halfOpenRequests: halfOpen,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) Allow() bool {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
switch b.state {
|
||||
case billingCircuitClosed:
|
||||
return true
|
||||
case billingCircuitOpen:
|
||||
if time.Since(b.openedAt) < b.resetTimeout {
|
||||
return false
|
||||
}
|
||||
b.state = billingCircuitHalfOpen
|
||||
b.halfOpenRemaining = b.halfOpenRequests
|
||||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker entering half-open state")
|
||||
fallthrough
|
||||
case billingCircuitHalfOpen:
|
||||
if b.halfOpenRemaining <= 0 {
|
||||
return false
|
||||
}
|
||||
b.halfOpenRemaining--
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) OnFailure(err error) {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
switch b.state {
|
||||
case billingCircuitOpen:
|
||||
return
|
||||
case billingCircuitHalfOpen:
|
||||
b.state = billingCircuitOpen
|
||||
b.openedAt = time.Now()
|
||||
b.halfOpenRemaining = 0
|
||||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after half-open failure: %v", err)
|
||||
return
|
||||
default:
|
||||
b.failures++
|
||||
if b.failures >= b.failureThreshold {
|
||||
b.state = billingCircuitOpen
|
||||
b.openedAt = time.Now()
|
||||
b.halfOpenRemaining = 0
|
||||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker opened after %d failures: %v", b.failures, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *billingCircuitBreaker) OnSuccess() {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
previousState := b.state
|
||||
previousFailures := b.failures
|
||||
|
||||
b.state = billingCircuitClosed
|
||||
b.failures = 0
|
||||
b.halfOpenRemaining = 0
|
||||
|
||||
// 只有状态真正发生变化时才记录日志
|
||||
if previousState != billingCircuitClosed {
|
||||
logger.LegacyPrintf("service.billing_cache", "ALERT: billing circuit breaker closed (was %s)", circuitStateString(previousState))
|
||||
} else if previousFailures > 0 {
|
||||
logger.LegacyPrintf("service.billing_cache", "INFO: billing circuit breaker failures reset from %d", previousFailures)
|
||||
}
|
||||
}
|
||||
|
||||
func circuitStateString(state billingCircuitBreakerState) string {
|
||||
switch state {
|
||||
case billingCircuitClosed:
|
||||
return "closed"
|
||||
case billingCircuitOpen:
|
||||
return "open"
|
||||
case billingCircuitHalfOpen:
|
||||
return "half-open"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,131 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type billingCacheMissStub struct {
|
||||
setBalanceCalls atomic.Int64
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
return 0, errors.New("cache miss")
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
s.setBalanceCalls.Add(1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
|
||||
return nil, errors.New("cache miss")
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
|
||||
return nil, errors.New("cache miss")
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *billingCacheMissStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type balanceLoadUserRepoStub struct {
|
||||
mockUserRepo
|
||||
calls atomic.Int64
|
||||
delay time.Duration
|
||||
balance float64
|
||||
}
|
||||
|
||||
func (s *balanceLoadUserRepoStub) GetByID(ctx context.Context, id int64) (*User, error) {
|
||||
s.calls.Add(1)
|
||||
if s.delay > 0 {
|
||||
select {
|
||||
case <-time.After(s.delay):
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
return &User{ID: id, Balance: s.balance}, nil
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceGetUserBalance_Singleflight(t *testing.T) {
|
||||
cache := &billingCacheMissStub{}
|
||||
userRepo := &balanceLoadUserRepoStub{
|
||||
delay: 80 * time.Millisecond,
|
||||
balance: 12.34,
|
||||
}
|
||||
svc := NewBillingCacheService(cache, userRepo, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
const goroutines = 16
|
||||
start := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
errCh := make(chan error, goroutines)
|
||||
balCh := make(chan float64, goroutines)
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-start
|
||||
bal, err := svc.GetUserBalance(context.Background(), 99)
|
||||
errCh <- err
|
||||
balCh <- bal
|
||||
}()
|
||||
}
|
||||
|
||||
close(start)
|
||||
wg.Wait()
|
||||
close(errCh)
|
||||
close(balCh)
|
||||
|
||||
for err := range errCh {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
for bal := range balCh {
|
||||
require.Equal(t, 12.34, bal)
|
||||
}
|
||||
|
||||
require.Equal(t, int64(1), userRepo.calls.Load(), "并发穿透应被 singleflight 合并")
|
||||
require.Eventually(t, func() bool {
|
||||
return cache.setBalanceCalls.Load() >= 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type billingCacheWorkerStub struct {
|
||||
balanceUpdates int64
|
||||
subscriptionUpdates int64
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetUserBalance(ctx context.Context, userID int64) (float64, error) {
|
||||
return 0, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetUserBalance(ctx context.Context, userID int64, balance float64) error {
|
||||
atomic.AddInt64(&b.balanceUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) DeductUserBalance(ctx context.Context, userID int64, amount float64) error {
|
||||
atomic.AddInt64(&b.balanceUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) InvalidateUserBalance(ctx context.Context, userID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error {
|
||||
atomic.AddInt64(&b.subscriptionUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error {
|
||||
atomic.AddInt64(&b.subscriptionUpdates, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *billingCacheWorkerStub) InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceQueueHighLoad(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
|
||||
t.Cleanup(svc.Stop)
|
||||
|
||||
start := time.Now()
|
||||
for i := 0; i < cacheWriteBufferSize*2; i++ {
|
||||
svc.QueueDeductBalance(1, 1)
|
||||
}
|
||||
require.Less(t, time.Since(start), 2*time.Second)
|
||||
|
||||
svc.QueueUpdateSubscriptionUsage(1, 2, 1.5)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt64(&cache.balanceUpdates) > 0
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt64(&cache.subscriptionUpdates) > 0
|
||||
}, 2*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestBillingCacheServiceEnqueueAfterStopReturnsFalse(t *testing.T) {
|
||||
cache := &billingCacheWorkerStub{}
|
||||
svc := NewBillingCacheService(cache, nil, nil, nil, &config.Config{})
|
||||
svc.Stop()
|
||||
|
||||
enqueued := svc.enqueueCacheWrite(cacheWriteTask{
|
||||
kind: cacheWriteDeductBalance,
|
||||
userID: 1,
|
||||
amount: 1,
|
||||
})
|
||||
require.False(t, enqueued)
|
||||
}
|
||||
@@ -0,0 +1,763 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
)
|
||||
|
||||
// APIKeyRateLimitCacheData holds rate limit usage data cached in Redis.
|
||||
type APIKeyRateLimitCacheData struct {
|
||||
Usage5h float64 `json:"usage_5h"`
|
||||
Usage1d float64 `json:"usage_1d"`
|
||||
Usage7d float64 `json:"usage_7d"`
|
||||
Window5h int64 `json:"window_5h"` // unix timestamp, 0 = not started
|
||||
Window1d int64 `json:"window_1d"`
|
||||
Window7d int64 `json:"window_7d"`
|
||||
}
|
||||
|
||||
// BillingCache defines cache operations for billing service
|
||||
type BillingCache interface {
|
||||
// Balance operations
|
||||
GetUserBalance(ctx context.Context, userID int64) (float64, error)
|
||||
SetUserBalance(ctx context.Context, userID int64, balance float64) error
|
||||
DeductUserBalance(ctx context.Context, userID int64, amount float64) error
|
||||
InvalidateUserBalance(ctx context.Context, userID int64) error
|
||||
|
||||
// Subscription operations
|
||||
GetSubscriptionCache(ctx context.Context, userID, groupID int64) (*SubscriptionCacheData, error)
|
||||
SetSubscriptionCache(ctx context.Context, userID, groupID int64, data *SubscriptionCacheData) error
|
||||
UpdateSubscriptionUsage(ctx context.Context, userID, groupID int64, cost float64) error
|
||||
InvalidateSubscriptionCache(ctx context.Context, userID, groupID int64) error
|
||||
|
||||
// API Key rate limit operations
|
||||
GetAPIKeyRateLimit(ctx context.Context, keyID int64) (*APIKeyRateLimitCacheData, error)
|
||||
SetAPIKeyRateLimit(ctx context.Context, keyID int64, data *APIKeyRateLimitCacheData) error
|
||||
UpdateAPIKeyRateLimitUsage(ctx context.Context, keyID int64, cost float64) error
|
||||
InvalidateAPIKeyRateLimit(ctx context.Context, keyID int64) error
|
||||
}
|
||||
|
||||
// ModelPricing 模型价格配置(per-token价格,与LiteLLM格式一致)
|
||||
type ModelPricing struct {
|
||||
InputPricePerToken float64 // 每token输入价格 (USD)
|
||||
InputPricePerTokenPriority float64 // priority service tier 下每token输入价格 (USD)
|
||||
OutputPricePerToken float64 // 每token输出价格 (USD)
|
||||
OutputPricePerTokenPriority float64 // priority service tier 下每token输出价格 (USD)
|
||||
CacheCreationPricePerToken float64 // 缓存创建每token价格 (USD)
|
||||
CacheReadPricePerToken float64 // 缓存读取每token价格 (USD)
|
||||
CacheReadPricePerTokenPriority float64 // priority service tier 下缓存读取每token价格 (USD)
|
||||
CacheCreation5mPrice float64 // 5分钟缓存创建每token价格 (USD)
|
||||
CacheCreation1hPrice float64 // 1小时缓存创建每token价格 (USD)
|
||||
SupportsCacheBreakdown bool // 是否支持详细的缓存分类
|
||||
LongContextInputThreshold int // 超过阈值后按整次会话提升输入价格
|
||||
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
|
||||
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
|
||||
}
|
||||
|
||||
const (
|
||||
openAIGPT54LongContextInputThreshold = 272000
|
||||
openAIGPT54LongContextInputMultiplier = 2.0
|
||||
openAIGPT54LongContextOutputMultiplier = 1.5
|
||||
)
|
||||
|
||||
func normalizeBillingServiceTier(serviceTier string) string {
|
||||
return strings.ToLower(strings.TrimSpace(serviceTier))
|
||||
}
|
||||
|
||||
func usePriorityServiceTierPricing(serviceTier string, pricing *ModelPricing) bool {
|
||||
if pricing == nil || normalizeBillingServiceTier(serviceTier) != "priority" {
|
||||
return false
|
||||
}
|
||||
return pricing.InputPricePerTokenPriority > 0 || pricing.OutputPricePerTokenPriority > 0 || pricing.CacheReadPricePerTokenPriority > 0
|
||||
}
|
||||
|
||||
func serviceTierCostMultiplier(serviceTier string) float64 {
|
||||
switch normalizeBillingServiceTier(serviceTier) {
|
||||
case "priority":
|
||||
return 2.0
|
||||
case "flex":
|
||||
return 0.5
|
||||
default:
|
||||
return 1.0
|
||||
}
|
||||
}
|
||||
|
||||
// UsageTokens 使用的token数量
|
||||
type UsageTokens struct {
|
||||
InputTokens int
|
||||
OutputTokens int
|
||||
CacheCreationTokens int
|
||||
CacheReadTokens int
|
||||
CacheCreation5mTokens int
|
||||
CacheCreation1hTokens int
|
||||
}
|
||||
|
||||
// CostBreakdown 费用明细
|
||||
type CostBreakdown struct {
|
||||
InputCost float64
|
||||
OutputCost float64
|
||||
CacheCreationCost float64
|
||||
CacheReadCost float64
|
||||
TotalCost float64
|
||||
ActualCost float64 // 应用倍率后的实际费用
|
||||
}
|
||||
|
||||
// BillingService 计费服务
|
||||
type BillingService struct {
|
||||
cfg *config.Config
|
||||
pricingService *PricingService
|
||||
fallbackPrices map[string]*ModelPricing // 硬编码回退价格
|
||||
}
|
||||
|
||||
// NewBillingService 创建计费服务实例
|
||||
func NewBillingService(cfg *config.Config, pricingService *PricingService) *BillingService {
|
||||
s := &BillingService{
|
||||
cfg: cfg,
|
||||
pricingService: pricingService,
|
||||
fallbackPrices: make(map[string]*ModelPricing),
|
||||
}
|
||||
|
||||
// 初始化硬编码回退价格(当动态价格不可用时使用)
|
||||
s.initFallbackPricing()
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// initFallbackPricing 初始化硬编码回退价格(当动态价格不可用时使用)
|
||||
// 价格单位:USD per token(与LiteLLM格式一致)
|
||||
func (s *BillingService) initFallbackPricing() {
|
||||
// Claude 4.5 Opus
|
||||
s.fallbackPrices["claude-opus-4.5"] = &ModelPricing{
|
||||
InputPricePerToken: 5e-6, // $5 per MTok
|
||||
OutputPricePerToken: 25e-6, // $25 per MTok
|
||||
CacheCreationPricePerToken: 6.25e-6, // $6.25 per MTok
|
||||
CacheReadPricePerToken: 0.5e-6, // $0.50 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 4 Sonnet
|
||||
s.fallbackPrices["claude-sonnet-4"] = &ModelPricing{
|
||||
InputPricePerToken: 3e-6, // $3 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
|
||||
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 3.5 Sonnet
|
||||
s.fallbackPrices["claude-3-5-sonnet"] = &ModelPricing{
|
||||
InputPricePerToken: 3e-6, // $3 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
CacheCreationPricePerToken: 3.75e-6, // $3.75 per MTok
|
||||
CacheReadPricePerToken: 0.3e-6, // $0.30 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 3.5 Haiku
|
||||
s.fallbackPrices["claude-3-5-haiku"] = &ModelPricing{
|
||||
InputPricePerToken: 1e-6, // $1 per MTok
|
||||
OutputPricePerToken: 5e-6, // $5 per MTok
|
||||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheReadPricePerToken: 0.1e-6, // $0.10 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 3 Opus
|
||||
s.fallbackPrices["claude-3-opus"] = &ModelPricing{
|
||||
InputPricePerToken: 15e-6, // $15 per MTok
|
||||
OutputPricePerToken: 75e-6, // $75 per MTok
|
||||
CacheCreationPricePerToken: 18.75e-6, // $18.75 per MTok
|
||||
CacheReadPricePerToken: 1.5e-6, // $1.50 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 3 Haiku
|
||||
s.fallbackPrices["claude-3-haiku"] = &ModelPricing{
|
||||
InputPricePerToken: 0.25e-6, // $0.25 per MTok
|
||||
OutputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheCreationPricePerToken: 0.3e-6, // $0.30 per MTok
|
||||
CacheReadPricePerToken: 0.03e-6, // $0.03 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// Claude 4.6 Opus (与4.5同价)
|
||||
s.fallbackPrices["claude-opus-4.6"] = s.fallbackPrices["claude-opus-4.5"]
|
||||
|
||||
// Gemini 3.1 Pro
|
||||
s.fallbackPrices["gemini-3.1-pro"] = &ModelPricing{
|
||||
InputPricePerToken: 2e-6, // $2 per MTok
|
||||
OutputPricePerToken: 12e-6, // $12 per MTok
|
||||
CacheCreationPricePerToken: 2e-6, // $2 per MTok
|
||||
CacheReadPricePerToken: 0.2e-6, // $0.20 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
|
||||
// OpenAI GPT-5.1(本地兜底,防止动态定价不可用时拒绝计费)
|
||||
s.fallbackPrices["gpt-5.1"] = &ModelPricing{
|
||||
InputPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
InputPricePerTokenPriority: 2.5e-6, // $2.5 per MTok
|
||||
OutputPricePerToken: 10e-6, // $10 per MTok
|
||||
OutputPricePerTokenPriority: 20e-6, // $20 per MTok
|
||||
CacheCreationPricePerToken: 1.25e-6, // $1.25 per MTok
|
||||
CacheReadPricePerToken: 0.125e-6,
|
||||
CacheReadPricePerTokenPriority: 0.25e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// OpenAI GPT-5.4(业务指定价格)
|
||||
s.fallbackPrices["gpt-5.4"] = &ModelPricing{
|
||||
InputPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
InputPricePerTokenPriority: 5e-6, // $5 per MTok
|
||||
OutputPricePerToken: 15e-6, // $15 per MTok
|
||||
OutputPricePerTokenPriority: 30e-6, // $30 per MTok
|
||||
CacheCreationPricePerToken: 2.5e-6, // $2.5 per MTok
|
||||
CacheReadPricePerToken: 0.25e-6, // $0.25 per MTok
|
||||
CacheReadPricePerTokenPriority: 0.5e-6, // $0.5 per MTok
|
||||
SupportsCacheBreakdown: false,
|
||||
LongContextInputThreshold: openAIGPT54LongContextInputThreshold,
|
||||
LongContextInputMultiplier: openAIGPT54LongContextInputMultiplier,
|
||||
LongContextOutputMultiplier: openAIGPT54LongContextOutputMultiplier,
|
||||
}
|
||||
// OpenAI GPT-5.2(本地兜底)
|
||||
s.fallbackPrices["gpt-5.2"] = &ModelPricing{
|
||||
InputPricePerToken: 1.75e-6,
|
||||
InputPricePerTokenPriority: 3.5e-6,
|
||||
OutputPricePerToken: 14e-6,
|
||||
OutputPricePerTokenPriority: 28e-6,
|
||||
CacheCreationPricePerToken: 1.75e-6,
|
||||
CacheReadPricePerToken: 0.175e-6,
|
||||
CacheReadPricePerTokenPriority: 0.35e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
// Codex 族兜底统一按 GPT-5.1 Codex 价格计费
|
||||
s.fallbackPrices["gpt-5.1-codex"] = &ModelPricing{
|
||||
InputPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
InputPricePerTokenPriority: 3e-6, // $3 per MTok
|
||||
OutputPricePerToken: 12e-6, // $12 per MTok
|
||||
OutputPricePerTokenPriority: 24e-6, // $24 per MTok
|
||||
CacheCreationPricePerToken: 1.5e-6, // $1.5 per MTok
|
||||
CacheReadPricePerToken: 0.15e-6,
|
||||
CacheReadPricePerTokenPriority: 0.3e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.2-codex"] = &ModelPricing{
|
||||
InputPricePerToken: 1.75e-6,
|
||||
InputPricePerTokenPriority: 3.5e-6,
|
||||
OutputPricePerToken: 14e-6,
|
||||
OutputPricePerTokenPriority: 28e-6,
|
||||
CacheCreationPricePerToken: 1.75e-6,
|
||||
CacheReadPricePerToken: 0.175e-6,
|
||||
CacheReadPricePerTokenPriority: 0.35e-6,
|
||||
SupportsCacheBreakdown: false,
|
||||
}
|
||||
s.fallbackPrices["gpt-5.3-codex"] = s.fallbackPrices["gpt-5.1-codex"]
|
||||
}
|
||||
|
||||
// getFallbackPricing 根据模型系列获取回退价格
|
||||
func (s *BillingService) getFallbackPricing(model string) *ModelPricing {
|
||||
modelLower := strings.ToLower(model)
|
||||
|
||||
// 按模型系列匹配
|
||||
if strings.Contains(modelLower, "opus") {
|
||||
if strings.Contains(modelLower, "4.6") || strings.Contains(modelLower, "4-6") {
|
||||
return s.fallbackPrices["claude-opus-4.6"]
|
||||
}
|
||||
if strings.Contains(modelLower, "4.5") || strings.Contains(modelLower, "4-5") {
|
||||
return s.fallbackPrices["claude-opus-4.5"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-opus"]
|
||||
}
|
||||
if strings.Contains(modelLower, "sonnet") {
|
||||
if strings.Contains(modelLower, "4") && !strings.Contains(modelLower, "3") {
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-5-sonnet"]
|
||||
}
|
||||
if strings.Contains(modelLower, "haiku") {
|
||||
if strings.Contains(modelLower, "3-5") || strings.Contains(modelLower, "3.5") {
|
||||
return s.fallbackPrices["claude-3-5-haiku"]
|
||||
}
|
||||
return s.fallbackPrices["claude-3-haiku"]
|
||||
}
|
||||
// Claude 未知型号统一回退到 Sonnet,避免计费中断。
|
||||
if strings.Contains(modelLower, "claude") {
|
||||
return s.fallbackPrices["claude-sonnet-4"]
|
||||
}
|
||||
if strings.Contains(modelLower, "gemini-3.1-pro") || strings.Contains(modelLower, "gemini-3-1-pro") {
|
||||
return s.fallbackPrices["gemini-3.1-pro"]
|
||||
}
|
||||
|
||||
// OpenAI 仅匹配已知 GPT-5/Codex 族,避免未知 OpenAI 型号误计价。
|
||||
if strings.Contains(modelLower, "gpt-5") || strings.Contains(modelLower, "codex") {
|
||||
normalized := normalizeCodexModel(modelLower)
|
||||
switch normalized {
|
||||
case "gpt-5.4":
|
||||
return s.fallbackPrices["gpt-5.4"]
|
||||
case "gpt-5.2":
|
||||
return s.fallbackPrices["gpt-5.2"]
|
||||
case "gpt-5.2-codex":
|
||||
return s.fallbackPrices["gpt-5.2-codex"]
|
||||
case "gpt-5.3-codex":
|
||||
return s.fallbackPrices["gpt-5.3-codex"]
|
||||
case "gpt-5.1-codex", "gpt-5.1-codex-max", "gpt-5.1-codex-mini", "codex-mini-latest":
|
||||
return s.fallbackPrices["gpt-5.1-codex"]
|
||||
case "gpt-5.1":
|
||||
return s.fallbackPrices["gpt-5.1"]
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetModelPricing 获取模型价格配置
|
||||
func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
|
||||
// 标准化模型名称(转小写)
|
||||
model = strings.ToLower(model)
|
||||
|
||||
// 1. 优先从动态价格服务获取
|
||||
if s.pricingService != nil {
|
||||
litellmPricing := s.pricingService.GetModelPricing(model)
|
||||
if litellmPricing != nil {
|
||||
// 启用 5m/1h 分类计费的条件:
|
||||
// 1. 存在 1h 价格
|
||||
// 2. 1h 价格 > 5m 价格(防止 LiteLLM 数据错误导致少收费)
|
||||
price5m := litellmPricing.CacheCreationInputTokenCost
|
||||
price1h := litellmPricing.CacheCreationInputTokenCostAbove1hr
|
||||
enableBreakdown := price1h > 0 && price1h > price5m
|
||||
return s.applyModelSpecificPricingPolicy(model, &ModelPricing{
|
||||
InputPricePerToken: litellmPricing.InputCostPerToken,
|
||||
InputPricePerTokenPriority: litellmPricing.InputCostPerTokenPriority,
|
||||
OutputPricePerToken: litellmPricing.OutputCostPerToken,
|
||||
OutputPricePerTokenPriority: litellmPricing.OutputCostPerTokenPriority,
|
||||
CacheCreationPricePerToken: litellmPricing.CacheCreationInputTokenCost,
|
||||
CacheReadPricePerToken: litellmPricing.CacheReadInputTokenCost,
|
||||
CacheReadPricePerTokenPriority: litellmPricing.CacheReadInputTokenCostPriority,
|
||||
CacheCreation5mPrice: price5m,
|
||||
CacheCreation1hPrice: price1h,
|
||||
SupportsCacheBreakdown: enableBreakdown,
|
||||
LongContextInputThreshold: litellmPricing.LongContextInputTokenThreshold,
|
||||
LongContextInputMultiplier: litellmPricing.LongContextInputCostMultiplier,
|
||||
LongContextOutputMultiplier: litellmPricing.LongContextOutputCostMultiplier,
|
||||
}), nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. 使用硬编码回退价格
|
||||
fallback := s.getFallbackPricing(model)
|
||||
if fallback != nil {
|
||||
log.Printf("[Billing] Using fallback pricing for model: %s", model)
|
||||
return s.applyModelSpecificPricingPolicy(model, fallback), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("pricing not found for model: %s", model)
|
||||
}
|
||||
|
||||
// CalculateCost 计算使用费用
|
||||
func (s *BillingService) CalculateCost(model string, tokens UsageTokens, rateMultiplier float64) (*CostBreakdown, error) {
|
||||
return s.CalculateCostWithServiceTier(model, tokens, rateMultiplier, "")
|
||||
}
|
||||
|
||||
func (s *BillingService) CalculateCostWithServiceTier(model string, tokens UsageTokens, rateMultiplier float64, serviceTier string) (*CostBreakdown, error) {
|
||||
pricing, err := s.GetModelPricing(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
breakdown := &CostBreakdown{}
|
||||
inputPricePerToken := pricing.InputPricePerToken
|
||||
outputPricePerToken := pricing.OutputPricePerToken
|
||||
cacheReadPricePerToken := pricing.CacheReadPricePerToken
|
||||
tierMultiplier := 1.0
|
||||
if usePriorityServiceTierPricing(serviceTier, pricing) {
|
||||
if pricing.InputPricePerTokenPriority > 0 {
|
||||
inputPricePerToken = pricing.InputPricePerTokenPriority
|
||||
}
|
||||
if pricing.OutputPricePerTokenPriority > 0 {
|
||||
outputPricePerToken = pricing.OutputPricePerTokenPriority
|
||||
}
|
||||
if pricing.CacheReadPricePerTokenPriority > 0 {
|
||||
cacheReadPricePerToken = pricing.CacheReadPricePerTokenPriority
|
||||
}
|
||||
} else {
|
||||
tierMultiplier = serviceTierCostMultiplier(serviceTier)
|
||||
}
|
||||
if s.shouldApplySessionLongContextPricing(tokens, pricing) {
|
||||
inputPricePerToken *= pricing.LongContextInputMultiplier
|
||||
outputPricePerToken *= pricing.LongContextOutputMultiplier
|
||||
}
|
||||
|
||||
// 计算输入token费用(使用per-token价格)
|
||||
breakdown.InputCost = float64(tokens.InputTokens) * inputPricePerToken
|
||||
|
||||
// 计算输出token费用
|
||||
breakdown.OutputCost = float64(tokens.OutputTokens) * outputPricePerToken
|
||||
|
||||
// 计算缓存费用
|
||||
if pricing.SupportsCacheBreakdown && (pricing.CacheCreation5mPrice > 0 || pricing.CacheCreation1hPrice > 0) {
|
||||
// 支持详细缓存分类的模型(5分钟/1小时缓存,价格为 per-token)
|
||||
if tokens.CacheCreation5mTokens == 0 && tokens.CacheCreation1hTokens == 0 && tokens.CacheCreationTokens > 0 {
|
||||
// API 未返回 ephemeral 明细,回退到全部按 5m 单价计费
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreation5mPrice
|
||||
} else {
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreation5mTokens)*pricing.CacheCreation5mPrice +
|
||||
float64(tokens.CacheCreation1hTokens)*pricing.CacheCreation1hPrice
|
||||
}
|
||||
} else {
|
||||
// 标准缓存创建价格(per-token)
|
||||
breakdown.CacheCreationCost = float64(tokens.CacheCreationTokens) * pricing.CacheCreationPricePerToken
|
||||
}
|
||||
|
||||
breakdown.CacheReadCost = float64(tokens.CacheReadTokens) * cacheReadPricePerToken
|
||||
|
||||
if tierMultiplier != 1.0 {
|
||||
breakdown.InputCost *= tierMultiplier
|
||||
breakdown.OutputCost *= tierMultiplier
|
||||
breakdown.CacheCreationCost *= tierMultiplier
|
||||
breakdown.CacheReadCost *= tierMultiplier
|
||||
}
|
||||
|
||||
// 计算总费用
|
||||
breakdown.TotalCost = breakdown.InputCost + breakdown.OutputCost +
|
||||
breakdown.CacheCreationCost + breakdown.CacheReadCost
|
||||
|
||||
// 应用倍率计算实际费用
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
breakdown.ActualCost = breakdown.TotalCost * rateMultiplier
|
||||
|
||||
return breakdown, nil
|
||||
}
|
||||
|
||||
func (s *BillingService) applyModelSpecificPricingPolicy(model string, pricing *ModelPricing) *ModelPricing {
|
||||
if pricing == nil {
|
||||
return nil
|
||||
}
|
||||
if !isOpenAIGPT54Model(model) {
|
||||
return pricing
|
||||
}
|
||||
if pricing.LongContextInputThreshold > 0 && pricing.LongContextInputMultiplier > 0 && pricing.LongContextOutputMultiplier > 0 {
|
||||
return pricing
|
||||
}
|
||||
cloned := *pricing
|
||||
if cloned.LongContextInputThreshold <= 0 {
|
||||
cloned.LongContextInputThreshold = openAIGPT54LongContextInputThreshold
|
||||
}
|
||||
if cloned.LongContextInputMultiplier <= 0 {
|
||||
cloned.LongContextInputMultiplier = openAIGPT54LongContextInputMultiplier
|
||||
}
|
||||
if cloned.LongContextOutputMultiplier <= 0 {
|
||||
cloned.LongContextOutputMultiplier = openAIGPT54LongContextOutputMultiplier
|
||||
}
|
||||
return &cloned
|
||||
}
|
||||
|
||||
func (s *BillingService) shouldApplySessionLongContextPricing(tokens UsageTokens, pricing *ModelPricing) bool {
|
||||
if pricing == nil || pricing.LongContextInputThreshold <= 0 {
|
||||
return false
|
||||
}
|
||||
if pricing.LongContextInputMultiplier <= 1 && pricing.LongContextOutputMultiplier <= 1 {
|
||||
return false
|
||||
}
|
||||
totalInputTokens := tokens.InputTokens + tokens.CacheReadTokens
|
||||
return totalInputTokens > pricing.LongContextInputThreshold
|
||||
}
|
||||
|
||||
func isOpenAIGPT54Model(model string) bool {
|
||||
normalized := normalizeCodexModel(strings.TrimSpace(strings.ToLower(model)))
|
||||
return normalized == "gpt-5.4"
|
||||
}
|
||||
|
||||
// CalculateCostWithConfig 使用配置中的默认倍率计算费用
|
||||
func (s *BillingService) CalculateCostWithConfig(model string, tokens UsageTokens) (*CostBreakdown, error) {
|
||||
multiplier := s.cfg.Default.RateMultiplier
|
||||
if multiplier <= 0 {
|
||||
multiplier = 1.0
|
||||
}
|
||||
return s.CalculateCost(model, tokens, multiplier)
|
||||
}
|
||||
|
||||
// CalculateCostWithLongContext 计算费用,支持长上下文双倍计费
|
||||
// threshold: 阈值(如 200000),超过此值的部分按 extraMultiplier 倍计费
|
||||
// extraMultiplier: 超出部分的倍率(如 2.0 表示双倍)
|
||||
//
|
||||
// 示例:缓存 210k + 输入 10k = 220k,阈值 200k,倍率 2.0
|
||||
// 拆分为:范围内 (200k, 0) + 范围外 (10k, 10k)
|
||||
// 范围内正常计费,范围外 × 2 计费
|
||||
func (s *BillingService) CalculateCostWithLongContext(model string, tokens UsageTokens, rateMultiplier float64, threshold int, extraMultiplier float64) (*CostBreakdown, error) {
|
||||
// 未启用长上下文计费,直接走正常计费
|
||||
if threshold <= 0 || extraMultiplier <= 1 {
|
||||
return s.CalculateCost(model, tokens, rateMultiplier)
|
||||
}
|
||||
|
||||
// 计算总输入 token(缓存读取 + 新输入)
|
||||
total := tokens.CacheReadTokens + tokens.InputTokens
|
||||
if total <= threshold {
|
||||
return s.CalculateCost(model, tokens, rateMultiplier)
|
||||
}
|
||||
|
||||
// 拆分成范围内和范围外
|
||||
var inRangeCacheTokens, inRangeInputTokens int
|
||||
var outRangeCacheTokens, outRangeInputTokens int
|
||||
|
||||
if tokens.CacheReadTokens >= threshold {
|
||||
// 缓存已超过阈值:范围内只有缓存,范围外是超出的缓存+全部输入
|
||||
inRangeCacheTokens = threshold
|
||||
inRangeInputTokens = 0
|
||||
outRangeCacheTokens = tokens.CacheReadTokens - threshold
|
||||
outRangeInputTokens = tokens.InputTokens
|
||||
} else {
|
||||
// 缓存未超过阈值:范围内是全部缓存+部分输入,范围外是剩余输入
|
||||
inRangeCacheTokens = tokens.CacheReadTokens
|
||||
inRangeInputTokens = threshold - tokens.CacheReadTokens
|
||||
outRangeCacheTokens = 0
|
||||
outRangeInputTokens = tokens.InputTokens - inRangeInputTokens
|
||||
}
|
||||
|
||||
// 范围内部分:正常计费
|
||||
inRangeTokens := UsageTokens{
|
||||
InputTokens: inRangeInputTokens,
|
||||
OutputTokens: tokens.OutputTokens, // 输出只算一次
|
||||
CacheCreationTokens: tokens.CacheCreationTokens,
|
||||
CacheReadTokens: inRangeCacheTokens,
|
||||
CacheCreation5mTokens: tokens.CacheCreation5mTokens,
|
||||
CacheCreation1hTokens: tokens.CacheCreation1hTokens,
|
||||
}
|
||||
inRangeCost, err := s.CalculateCost(model, inRangeTokens, rateMultiplier)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 范围外部分:× extraMultiplier 计费
|
||||
outRangeTokens := UsageTokens{
|
||||
InputTokens: outRangeInputTokens,
|
||||
CacheReadTokens: outRangeCacheTokens,
|
||||
}
|
||||
outRangeCost, err := s.CalculateCost(model, outRangeTokens, rateMultiplier*extraMultiplier)
|
||||
if err != nil {
|
||||
return inRangeCost, fmt.Errorf("out-range cost: %w", err)
|
||||
}
|
||||
|
||||
// 合并成本
|
||||
return &CostBreakdown{
|
||||
InputCost: inRangeCost.InputCost + outRangeCost.InputCost,
|
||||
OutputCost: inRangeCost.OutputCost,
|
||||
CacheCreationCost: inRangeCost.CacheCreationCost,
|
||||
CacheReadCost: inRangeCost.CacheReadCost + outRangeCost.CacheReadCost,
|
||||
TotalCost: inRangeCost.TotalCost + outRangeCost.TotalCost,
|
||||
ActualCost: inRangeCost.ActualCost + outRangeCost.ActualCost,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ListSupportedModels 列出所有支持的模型(现在总是返回true,因为有模糊匹配)
|
||||
func (s *BillingService) ListSupportedModels() []string {
|
||||
models := make([]string, 0)
|
||||
// 返回回退价格支持的模型系列
|
||||
for model := range s.fallbackPrices {
|
||||
models = append(models, model)
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
// IsModelSupported 检查模型是否支持(现在总是返回true,因为有模糊匹配回退)
|
||||
func (s *BillingService) IsModelSupported(model string) bool {
|
||||
// 所有Claude模型都有回退价格支持
|
||||
modelLower := strings.ToLower(model)
|
||||
return strings.Contains(modelLower, "claude") ||
|
||||
strings.Contains(modelLower, "opus") ||
|
||||
strings.Contains(modelLower, "sonnet") ||
|
||||
strings.Contains(modelLower, "haiku")
|
||||
}
|
||||
|
||||
// GetEstimatedCost 估算费用(用于前端展示)
|
||||
func (s *BillingService) GetEstimatedCost(model string, estimatedInputTokens, estimatedOutputTokens int) (float64, error) {
|
||||
tokens := UsageTokens{
|
||||
InputTokens: estimatedInputTokens,
|
||||
OutputTokens: estimatedOutputTokens,
|
||||
}
|
||||
|
||||
breakdown, err := s.CalculateCostWithConfig(model, tokens)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return breakdown.ActualCost, nil
|
||||
}
|
||||
|
||||
// GetPricingServiceStatus 获取价格服务状态
|
||||
func (s *BillingService) GetPricingServiceStatus() map[string]any {
|
||||
if s.pricingService != nil {
|
||||
return s.pricingService.GetStatus()
|
||||
}
|
||||
return map[string]any{
|
||||
"model_count": len(s.fallbackPrices),
|
||||
"last_updated": "using fallback",
|
||||
"local_hash": "N/A",
|
||||
}
|
||||
}
|
||||
|
||||
// ForceUpdatePricing 强制更新价格数据
|
||||
func (s *BillingService) ForceUpdatePricing() error {
|
||||
if s.pricingService != nil {
|
||||
return s.pricingService.ForceUpdate()
|
||||
}
|
||||
return fmt.Errorf("pricing service not initialized")
|
||||
}
|
||||
|
||||
// ImagePriceConfig 图片计费配置
|
||||
type ImagePriceConfig struct {
|
||||
Price1K *float64 // 1K 尺寸价格(nil 表示使用默认值)
|
||||
Price2K *float64 // 2K 尺寸价格(nil 表示使用默认值)
|
||||
Price4K *float64 // 4K 尺寸价格(nil 表示使用默认值)
|
||||
}
|
||||
|
||||
// SoraPriceConfig Sora 按次计费配置
|
||||
type SoraPriceConfig struct {
|
||||
ImagePrice360 *float64
|
||||
ImagePrice540 *float64
|
||||
VideoPricePerRequest *float64
|
||||
VideoPricePerRequestHD *float64
|
||||
}
|
||||
|
||||
// CalculateImageCost 计算图片生成费用
|
||||
// model: 请求的模型名称(用于获取 LiteLLM 默认价格)
|
||||
// imageSize: 图片尺寸 "1K", "2K", "4K"
|
||||
// imageCount: 生成的图片数量
|
||||
// groupConfig: 分组配置的价格(可能为 nil,表示使用默认值)
|
||||
// rateMultiplier: 费率倍数
|
||||
func (s *BillingService) CalculateImageCost(model string, imageSize string, imageCount int, groupConfig *ImagePriceConfig, rateMultiplier float64) *CostBreakdown {
|
||||
if imageCount <= 0 {
|
||||
return &CostBreakdown{}
|
||||
}
|
||||
|
||||
// 获取单价
|
||||
unitPrice := s.getImageUnitPrice(model, imageSize, groupConfig)
|
||||
|
||||
// 计算总费用
|
||||
totalCost := unitPrice * float64(imageCount)
|
||||
|
||||
// 应用倍率
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
actualCost := totalCost * rateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateSoraImageCost 计算 Sora 图片按次费用
|
||||
func (s *BillingService) CalculateSoraImageCost(imageSize string, imageCount int, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
|
||||
if imageCount <= 0 {
|
||||
return &CostBreakdown{}
|
||||
}
|
||||
|
||||
unitPrice := 0.0
|
||||
if groupConfig != nil {
|
||||
switch imageSize {
|
||||
case "540":
|
||||
if groupConfig.ImagePrice540 != nil {
|
||||
unitPrice = *groupConfig.ImagePrice540
|
||||
}
|
||||
default:
|
||||
if groupConfig.ImagePrice360 != nil {
|
||||
unitPrice = *groupConfig.ImagePrice360
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
totalCost := unitPrice * float64(imageCount)
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
actualCost := totalCost * rateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}
|
||||
}
|
||||
|
||||
// CalculateSoraVideoCost 计算 Sora 视频按次费用
|
||||
func (s *BillingService) CalculateSoraVideoCost(model string, groupConfig *SoraPriceConfig, rateMultiplier float64) *CostBreakdown {
|
||||
unitPrice := 0.0
|
||||
if groupConfig != nil {
|
||||
modelLower := strings.ToLower(model)
|
||||
if strings.Contains(modelLower, "sora2pro-hd") {
|
||||
if groupConfig.VideoPricePerRequestHD != nil {
|
||||
unitPrice = *groupConfig.VideoPricePerRequestHD
|
||||
}
|
||||
}
|
||||
if unitPrice <= 0 && groupConfig.VideoPricePerRequest != nil {
|
||||
unitPrice = *groupConfig.VideoPricePerRequest
|
||||
}
|
||||
}
|
||||
|
||||
totalCost := unitPrice
|
||||
if rateMultiplier <= 0 {
|
||||
rateMultiplier = 1.0
|
||||
}
|
||||
actualCost := totalCost * rateMultiplier
|
||||
|
||||
return &CostBreakdown{
|
||||
TotalCost: totalCost,
|
||||
ActualCost: actualCost,
|
||||
}
|
||||
}
|
||||
|
||||
// getImageUnitPrice 获取图片单价
|
||||
func (s *BillingService) getImageUnitPrice(model string, imageSize string, groupConfig *ImagePriceConfig) float64 {
|
||||
// 优先使用分组配置的价格
|
||||
if groupConfig != nil {
|
||||
switch imageSize {
|
||||
case "1K":
|
||||
if groupConfig.Price1K != nil {
|
||||
return *groupConfig.Price1K
|
||||
}
|
||||
case "2K":
|
||||
if groupConfig.Price2K != nil {
|
||||
return *groupConfig.Price2K
|
||||
}
|
||||
case "4K":
|
||||
if groupConfig.Price4K != nil {
|
||||
return *groupConfig.Price4K
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 回退到 LiteLLM 默认价格
|
||||
return s.getDefaultImagePrice(model, imageSize)
|
||||
}
|
||||
|
||||
// getDefaultImagePrice 获取 LiteLLM 默认图片价格
|
||||
func (s *BillingService) getDefaultImagePrice(model string, imageSize string) float64 {
|
||||
basePrice := 0.0
|
||||
|
||||
// 从 PricingService 获取 output_cost_per_image
|
||||
if s.pricingService != nil {
|
||||
pricing := s.pricingService.GetModelPricing(model)
|
||||
if pricing != nil && pricing.OutputCostPerImage > 0 {
|
||||
basePrice = pricing.OutputCostPerImage
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有找到价格,使用硬编码默认值($0.134,来自 gemini-3-pro-image-preview)
|
||||
if basePrice <= 0 {
|
||||
basePrice = 0.134
|
||||
}
|
||||
|
||||
// 2K 尺寸 1.5 倍,4K 尺寸翻倍
|
||||
if imageSize == "2K" {
|
||||
return basePrice * 1.5
|
||||
}
|
||||
if imageSize == "4K" {
|
||||
return basePrice * 2
|
||||
}
|
||||
|
||||
return basePrice
|
||||
}
|
||||
@@ -0,0 +1,149 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestCalculateImageCost_DefaultPricing 测试无分组配置时使用默认价格
|
||||
func TestCalculateImageCost_DefaultPricing(t *testing.T) {
|
||||
svc := &BillingService{} // pricingService 为 nil,使用硬编码默认值
|
||||
|
||||
// 2K 尺寸,默认价格 $0.134 * 1.5 = $0.201
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.201, cost.ActualCost, 0.0001)
|
||||
|
||||
// 多张图片
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 3, nil, 1.0)
|
||||
require.InDelta(t, 0.603, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_GroupCustomPricing 测试分组自定义价格
|
||||
func TestCalculateImageCost_GroupCustomPricing(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
price1K := 0.10
|
||||
price2K := 0.15
|
||||
price4K := 0.30
|
||||
groupConfig := &ImagePriceConfig{
|
||||
Price1K: &price1K,
|
||||
Price2K: &price2K,
|
||||
Price4K: &price4K,
|
||||
}
|
||||
|
||||
// 1K 使用分组价格
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 2, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.20, cost.TotalCost, 0.0001)
|
||||
|
||||
// 2K 使用分组价格
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.15, cost.TotalCost, 0.0001)
|
||||
|
||||
// 4K 使用分组价格
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.30, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_4KDoublePrice 测试 4K 默认价格翻倍
|
||||
func TestCalculateImageCost_4KDoublePrice(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
// 4K 尺寸,默认价格翻倍 $0.134 * 2 = $0.268
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_RateMultiplier 测试费率倍数
|
||||
func TestCalculateImageCost_RateMultiplier(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
// 费率倍数 1.5x
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.5)
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001) // TotalCost = 0.134 * 1.5
|
||||
require.InDelta(t, 0.3015, cost.ActualCost, 0.0001) // ActualCost = 0.201 * 1.5
|
||||
|
||||
// 费率倍数 2.0x
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 2, nil, 2.0)
|
||||
require.InDelta(t, 0.402, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.804, cost.ActualCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_ZeroCount 测试 imageCount=0
|
||||
func TestCalculateImageCost_ZeroCount(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 0, nil, 1.0)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
require.Equal(t, 0.0, cost.ActualCost)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_NegativeCount 测试 imageCount=-1
|
||||
func TestCalculateImageCost_NegativeCount(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", -1, nil, 1.0)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
require.Equal(t, 0.0, cost.ActualCost)
|
||||
}
|
||||
|
||||
// TestCalculateImageCost_ZeroRateMultiplier 测试费率倍数为 0 时默认使用 1.0
|
||||
func TestCalculateImageCost_ZeroRateMultiplier(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 0)
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
|
||||
require.InDelta(t, 0.201, cost.ActualCost, 0.0001) // 0 倍率当作 1.0 处理
|
||||
}
|
||||
|
||||
// TestGetImageUnitPrice_GroupPriorityOverDefault 测试分组价格优先于默认价格
|
||||
func TestGetImageUnitPrice_GroupPriorityOverDefault(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
price2K := 0.20
|
||||
groupConfig := &ImagePriceConfig{
|
||||
Price2K: &price2K,
|
||||
}
|
||||
|
||||
// 分组配置了 2K 价格,应该使用分组价格而不是默认的 $0.134
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.20, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestGetImageUnitPrice_PartialGroupConfig 测试分组部分配置时回退默认
|
||||
func TestGetImageUnitPrice_PartialGroupConfig(t *testing.T) {
|
||||
svc := &BillingService{}
|
||||
|
||||
// 只配置 1K 价格
|
||||
price1K := 0.10
|
||||
groupConfig := &ImagePriceConfig{
|
||||
Price1K: &price1K,
|
||||
}
|
||||
|
||||
// 1K 使用分组价格
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.10, cost.TotalCost, 0.0001)
|
||||
|
||||
// 2K 回退默认价格 $0.201 (1.5倍)
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
|
||||
|
||||
// 4K 回退默认价格 $0.268 (翻倍)
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "4K", 1, groupConfig, 1.0)
|
||||
require.InDelta(t, 0.268, cost.TotalCost, 0.0001)
|
||||
}
|
||||
|
||||
// TestGetDefaultImagePrice_FallbackHardcoded 测试 PricingService 无数据时使用硬编码默认值
|
||||
func TestGetDefaultImagePrice_FallbackHardcoded(t *testing.T) {
|
||||
svc := &BillingService{} // pricingService 为 nil
|
||||
|
||||
// 1K 默认价格 $0.134,2K 默认价格 $0.201 (1.5倍)
|
||||
cost := svc.CalculateImageCost("gemini-3-pro-image", "1K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.134, cost.TotalCost, 0.0001)
|
||||
|
||||
cost = svc.CalculateImageCost("gemini-3-pro-image", "2K", 1, nil, 1.0)
|
||||
require.InDelta(t, 0.201, cost.TotalCost, 0.0001)
|
||||
}
|
||||
@@ -0,0 +1,710 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestBillingService() *BillingService {
|
||||
return NewBillingService(&config.Config{}, nil)
|
||||
}
|
||||
|
||||
func TestCalculateCost_BasicComputation(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
// 使用 claude-sonnet-4 的回退价格:Input $3/MTok, Output $15/MTok
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 500,
|
||||
}
|
||||
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 1000 * 3e-6 = 0.003, 500 * 15e-6 = 0.0075
|
||||
expectedInput := 1000 * 3e-6
|
||||
expectedOutput := 500 * 15e-6
|
||||
require.InDelta(t, expectedInput, cost.InputCost, 1e-10)
|
||||
require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10)
|
||||
require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10)
|
||||
require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCost_WithCacheTokens(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 500,
|
||||
CacheCreationTokens: 2000,
|
||||
CacheReadTokens: 3000,
|
||||
}
|
||||
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedCacheCreation := 2000 * 3.75e-6
|
||||
expectedCacheRead := 3000 * 0.3e-6
|
||||
require.InDelta(t, expectedCacheCreation, cost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, expectedCacheRead, cost.CacheReadCost, 1e-10)
|
||||
|
||||
expectedTotal := cost.InputCost + cost.OutputCost + expectedCacheCreation + expectedCacheRead
|
||||
require.InDelta(t, expectedTotal, cost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCost_RateMultiplier(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
||||
|
||||
cost1x, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
cost2x, err := svc.CalculateCost("claude-sonnet-4", tokens, 2.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// TotalCost 不受倍率影响,ActualCost 翻倍
|
||||
require.InDelta(t, cost1x.TotalCost, cost2x.TotalCost, 1e-10)
|
||||
require.InDelta(t, cost1x.ActualCost*2, cost2x.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCost_ZeroMultiplierDefaultsToOne(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000}
|
||||
|
||||
costZero, err := svc.CalculateCost("claude-sonnet-4", tokens, 0)
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costZero.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCost_NegativeMultiplierDefaultsToOne(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000}
|
||||
|
||||
costNeg, err := svc.CalculateCost("claude-sonnet-4", tokens, -1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
costOne, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, costOne.ActualCost, costNeg.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_FallbackMatchesByFamily(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tests := []struct {
|
||||
model string
|
||||
expectedInput float64
|
||||
}{
|
||||
{"claude-opus-4.5-20250101", 5e-6},
|
||||
{"claude-3-opus-20240229", 15e-6},
|
||||
{"claude-sonnet-4-20250514", 3e-6},
|
||||
{"claude-3-5-sonnet-20241022", 3e-6},
|
||||
{"claude-3-5-haiku-20241022", 1e-6},
|
||||
{"claude-3-haiku-20240307", 0.25e-6},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
pricing, err := svc.GetModelPricing(tt.model)
|
||||
require.NoError(t, err, "模型 %s", tt.model)
|
||||
require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12, "模型 %s 输入价格", tt.model)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelPricing_CaseInsensitive(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
p1, err := svc.GetModelPricing("Claude-Sonnet-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
p2, err := svc.GetModelPricing("claude-sonnet-4")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, p1.InputPricePerToken, p2.InputPricePerToken)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_UnknownClaudeModelFallsBackToSonnet(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
// 不包含 opus/sonnet/haiku 关键词的 Claude 模型会走默认 Sonnet 价格
|
||||
pricing, err := svc.GetModelPricing("claude-unknown-model")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 3e-6, pricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_UnknownOpenAIModelReturnsError(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-unknown-model")
|
||||
require.Error(t, err)
|
||||
require.Nil(t, pricing)
|
||||
require.Contains(t, err.Error(), "pricing not found")
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGPT51Fallback(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-5.1")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 1.25e-6, pricing.InputPricePerToken, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGPT54Fallback(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.Equal(t, 272000, pricing.LongContextInputThreshold)
|
||||
require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateCost_OpenAIGPT54LongContextAppliesWholeSessionMultipliers(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 300000,
|
||||
OutputTokens: 4000,
|
||||
}
|
||||
|
||||
cost, err := svc.CalculateCost("gpt-5.4-2026-03-05", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedInput := float64(tokens.InputTokens) * 2.5e-6 * 2.0
|
||||
expectedOutput := float64(tokens.OutputTokens) * 15e-6 * 1.5
|
||||
require.InDelta(t, expectedInput, cost.InputCost, 1e-10)
|
||||
require.InDelta(t, expectedOutput, cost.OutputCost, 1e-10)
|
||||
require.InDelta(t, expectedInput+expectedOutput, cost.TotalCost, 1e-10)
|
||||
require.InDelta(t, expectedInput+expectedOutput, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestGetFallbackPricing_FamilyMatching(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expectedInput float64
|
||||
expectNilPricing bool
|
||||
}{
|
||||
{name: "empty model", model: " ", expectNilPricing: true},
|
||||
{name: "claude opus 4.6", model: "claude-opus-4.6-20260201", expectedInput: 5e-6},
|
||||
{name: "claude opus 4.5 alt separator", model: "claude-opus-4-5-20260101", expectedInput: 5e-6},
|
||||
{name: "claude generic model fallback sonnet", model: "claude-foo-bar", expectedInput: 3e-6},
|
||||
{name: "gemini explicit fallback", model: "gemini-3-1-pro", expectedInput: 2e-6},
|
||||
{name: "gemini unknown no fallback", model: "gemini-2.0-pro", expectNilPricing: true},
|
||||
{name: "openai gpt5.1", model: "gpt-5.1", expectedInput: 1.25e-6},
|
||||
{name: "openai gpt5.4", model: "gpt-5.4", expectedInput: 2.5e-6},
|
||||
{name: "openai gpt5.3 codex", model: "gpt-5.3-codex", expectedInput: 1.5e-6},
|
||||
{name: "openai gpt5.1 codex max alias", model: "gpt-5.1-codex-max", expectedInput: 1.5e-6},
|
||||
{name: "openai codex mini latest alias", model: "codex-mini-latest", expectedInput: 1.5e-6},
|
||||
{name: "openai unknown no fallback", model: "gpt-unknown-model", expectNilPricing: true},
|
||||
{name: "non supported family", model: "qwen-max", expectNilPricing: true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pricing := svc.getFallbackPricing(tt.model)
|
||||
if tt.expectNilPricing {
|
||||
require.Nil(t, pricing)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, pricing)
|
||||
require.InDelta(t, tt.expectedInput, pricing.InputPricePerToken, 1e-12)
|
||||
})
|
||||
}
|
||||
}
|
||||
func TestCalculateCostWithLongContext_BelowThreshold(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 50000,
|
||||
OutputTokens: 1000,
|
||||
CacheReadTokens: 100000,
|
||||
}
|
||||
// 总输入 150k < 200k 阈值,应走正常计费
|
||||
cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithLongContext_AboveThreshold_CacheExceedsThreshold(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
// 缓存 210k + 输入 10k = 220k > 200k 阈值
|
||||
// 缓存已超阈值:范围内 200k 缓存,范围外 10k 缓存 + 10k 输入
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 10000,
|
||||
OutputTokens: 1000,
|
||||
CacheReadTokens: 210000,
|
||||
}
|
||||
cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 范围内:200k cache + 0 input + 1k output
|
||||
inRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{
|
||||
InputTokens: 0,
|
||||
OutputTokens: 1000,
|
||||
CacheReadTokens: 200000,
|
||||
}, 1.0)
|
||||
|
||||
// 范围外:10k cache + 10k input,倍率 2.0
|
||||
outRange, _ := svc.CalculateCost("claude-sonnet-4", UsageTokens{
|
||||
InputTokens: 10000,
|
||||
CacheReadTokens: 10000,
|
||||
}, 2.0)
|
||||
|
||||
require.InDelta(t, inRange.ActualCost+outRange.ActualCost, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithLongContext_AboveThreshold_CacheBelowThreshold(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
// 缓存 100k + 输入 150k = 250k > 200k 阈值
|
||||
// 缓存未超阈值:范围内 100k 缓存 + 100k 输入,范围外 50k 输入
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 150000,
|
||||
OutputTokens: 1000,
|
||||
CacheReadTokens: 100000,
|
||||
}
|
||||
cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 2.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.True(t, cost.ActualCost > 0, "费用应大于 0")
|
||||
|
||||
// 正常费用不含长上下文
|
||||
normalCost, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.True(t, cost.ActualCost > normalCost.ActualCost, "长上下文费用应高于正常费用")
|
||||
}
|
||||
|
||||
func TestCalculateCostWithLongContext_DisabledThreshold(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 300000, CacheReadTokens: 0}
|
||||
|
||||
// threshold <= 0 应禁用长上下文计费
|
||||
cost1, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 0, 2.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
cost2, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, cost2.ActualCost, cost1.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithLongContext_ExtraMultiplierLessEqualOne(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{InputTokens: 300000}
|
||||
|
||||
// extraMultiplier <= 1 应禁用长上下文计费
|
||||
cost, err := svc.CalculateCostWithLongContext("claude-sonnet-4", tokens, 1.0, 200000, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
normalCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, normalCost.ActualCost, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateImageCost(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
price := 0.134
|
||||
cfg := &ImagePriceConfig{Price1K: &price}
|
||||
cost := svc.CalculateImageCost("gpt-image-1", "1K", 3, cfg, 1.0)
|
||||
|
||||
require.InDelta(t, 0.134*3, cost.TotalCost, 1e-10)
|
||||
require.InDelta(t, 0.134*3, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateSoraVideoCost(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
price := 0.5
|
||||
cfg := &SoraPriceConfig{VideoPricePerRequest: &price}
|
||||
cost := svc.CalculateSoraVideoCost("sora-video", cfg, 1.0)
|
||||
|
||||
require.InDelta(t, 0.5, cost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateSoraVideoCost_HDModel(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
hdPrice := 1.0
|
||||
normalPrice := 0.5
|
||||
cfg := &SoraPriceConfig{
|
||||
VideoPricePerRequest: &normalPrice,
|
||||
VideoPricePerRequestHD: &hdPrice,
|
||||
}
|
||||
cost := svc.CalculateSoraVideoCost("sora2pro-hd", cfg, 1.0)
|
||||
require.InDelta(t, 1.0, cost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestIsModelSupported(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
require.True(t, svc.IsModelSupported("claude-sonnet-4"))
|
||||
require.True(t, svc.IsModelSupported("Claude-Opus-4.5"))
|
||||
require.True(t, svc.IsModelSupported("claude-3-haiku"))
|
||||
require.False(t, svc.IsModelSupported("gpt-4o"))
|
||||
require.False(t, svc.IsModelSupported("gemini-pro"))
|
||||
}
|
||||
|
||||
func TestCalculateCost_ZeroTokens(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
cost, err := svc.CalculateCost("claude-sonnet-4", UsageTokens{}, 1.0)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
require.Equal(t, 0.0, cost.ActualCost)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithConfig(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Default.RateMultiplier = 1.5
|
||||
svc := NewBillingService(cfg, nil)
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000, OutputTokens: 500}
|
||||
cost, err := svc.CalculateCostWithConfig("claude-sonnet-4", tokens)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.5)
|
||||
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithConfig_ZeroMultiplier(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.Default.RateMultiplier = 0
|
||||
svc := NewBillingService(cfg, nil)
|
||||
|
||||
tokens := UsageTokens{InputTokens: 1000}
|
||||
cost, err := svc.CalculateCostWithConfig("claude-sonnet-4", tokens)
|
||||
require.NoError(t, err)
|
||||
|
||||
// 倍率 <=0 时默认 1.0
|
||||
expected, _ := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.InDelta(t, expected.ActualCost, cost.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestGetEstimatedCost(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
est, err := svc.GetEstimatedCost("claude-sonnet-4", 1000, 500)
|
||||
require.NoError(t, err)
|
||||
require.True(t, est > 0)
|
||||
}
|
||||
|
||||
func TestListSupportedModels(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
models := svc.ListSupportedModels()
|
||||
require.NotEmpty(t, models)
|
||||
require.GreaterOrEqual(t, len(models), 6)
|
||||
}
|
||||
|
||||
func TestGetPricingServiceStatus_NilService(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
status := svc.GetPricingServiceStatus()
|
||||
require.NotNil(t, status)
|
||||
require.Equal(t, "using fallback", status["last_updated"])
|
||||
}
|
||||
|
||||
func TestForceUpdatePricing_NilService(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
err := svc.ForceUpdatePricing()
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not initialized")
|
||||
}
|
||||
|
||||
func TestCalculateSoraImageCost(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
price360 := 0.05
|
||||
price540 := 0.08
|
||||
cfg := &SoraPriceConfig{ImagePrice360: &price360, ImagePrice540: &price540}
|
||||
|
||||
cost := svc.CalculateSoraImageCost("360", 2, cfg, 1.0)
|
||||
require.InDelta(t, 0.10, cost.TotalCost, 1e-10)
|
||||
|
||||
cost540 := svc.CalculateSoraImageCost("540", 1, cfg, 2.0)
|
||||
require.InDelta(t, 0.08, cost540.TotalCost, 1e-10)
|
||||
require.InDelta(t, 0.16, cost540.ActualCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateSoraImageCost_ZeroCount(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
cost := svc.CalculateSoraImageCost("360", 0, nil, 1.0)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
}
|
||||
|
||||
func TestCalculateSoraVideoCost_NilConfig(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
cost := svc.CalculateSoraVideoCost("sora-video", nil, 1.0)
|
||||
require.Equal(t, 0.0, cost.TotalCost)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithLongContext_PropagatesError(t *testing.T) {
|
||||
// 使用空的 fallback prices 让 GetModelPricing 失败
|
||||
svc := &BillingService{
|
||||
cfg: &config.Config{},
|
||||
fallbackPrices: make(map[string]*ModelPricing),
|
||||
}
|
||||
|
||||
tokens := UsageTokens{InputTokens: 300000, CacheReadTokens: 0}
|
||||
_, err := svc.CalculateCostWithLongContext("unknown-model", tokens, 1.0, 200000, 2.0)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "pricing not found")
|
||||
}
|
||||
|
||||
func TestCalculateCost_SupportsCacheBreakdown(t *testing.T) {
|
||||
svc := &BillingService{
|
||||
cfg: &config.Config{},
|
||||
fallbackPrices: map[string]*ModelPricing{
|
||||
"claude-sonnet-4": {
|
||||
InputPricePerToken: 3e-6,
|
||||
OutputPricePerToken: 15e-6,
|
||||
SupportsCacheBreakdown: true,
|
||||
CacheCreation5mPrice: 4e-6, // per token
|
||||
CacheCreation1hPrice: 5e-6, // per token
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 1000,
|
||||
OutputTokens: 500,
|
||||
CacheCreation5mTokens: 100000,
|
||||
CacheCreation1hTokens: 50000,
|
||||
}
|
||||
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected5m := float64(tokens.CacheCreation5mTokens) * 4e-6
|
||||
expected1h := float64(tokens.CacheCreation1hTokens) * 5e-6
|
||||
require.InDelta(t, expected5m+expected1h, cost.CacheCreationCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCost_LargeTokenCount(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
tokens := UsageTokens{
|
||||
InputTokens: 1_000_000,
|
||||
OutputTokens: 1_000_000,
|
||||
}
|
||||
cost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Input: 1M * 3e-6 = $3, Output: 1M * 15e-6 = $15
|
||||
require.InDelta(t, 3.0, cost.InputCost, 1e-6)
|
||||
require.InDelta(t, 15.0, cost.OutputCost, 1e-6)
|
||||
require.False(t, math.IsNaN(cost.TotalCost))
|
||||
require.False(t, math.IsInf(cost.TotalCost, 0))
|
||||
}
|
||||
|
||||
func TestServiceTierCostMultiplier(t *testing.T) {
|
||||
require.InDelta(t, 2.0, serviceTierCostMultiplier("priority"), 1e-12)
|
||||
require.InDelta(t, 2.0, serviceTierCostMultiplier(" Priority "), 1e-12)
|
||||
require.InDelta(t, 0.5, serviceTierCostMultiplier("flex"), 1e-12)
|
||||
require.InDelta(t, 1.0, serviceTierCostMultiplier(""), 1e-12)
|
||||
require.InDelta(t, 1.0, serviceTierCostMultiplier("default"), 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_OpenAIPriorityUsesPriorityPricing(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheReadTokens: 20}
|
||||
|
||||
baseCost, err := svc.CalculateCost("gpt-5.1-codex", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
priorityCost, err := svc.CalculateCostWithServiceTier("gpt-5.1-codex", tokens, 1.0, "priority")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_FlexAppliesHalfMultiplier(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20}
|
||||
|
||||
baseCost, err := svc.CalculateCost("gpt-5.4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
flexCost, err := svc.CalculateCostWithServiceTier("gpt-5.4", tokens, 1.0, "flex")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*0.5, flexCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*0.5, flexCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheCreationCost*0.5, flexCost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*0.5, flexCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*0.5, flexCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWithoutExplicitPriorityPrice(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
tokens := UsageTokens{InputTokens: 120, OutputTokens: 30, CacheCreationTokens: 12, CacheReadTokens: 8}
|
||||
|
||||
baseCost, err := svc.CalculateCost("claude-sonnet-4", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
priorityCost, err := svc.CalculateCostWithServiceTier("claude-sonnet-4", tokens, 1.0, "priority")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestBillingServiceGetModelPricing_UsesDynamicPriorityFields(t *testing.T) {
|
||||
pricingSvc := &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"gpt-5.4": {
|
||||
InputCostPerToken: 2.5e-6,
|
||||
InputCostPerTokenPriority: 5e-6,
|
||||
OutputCostPerToken: 15e-6,
|
||||
OutputCostPerTokenPriority: 30e-6,
|
||||
CacheCreationInputTokenCost: 2.5e-6,
|
||||
CacheReadInputTokenCost: 0.25e-6,
|
||||
CacheReadInputTokenCostPriority: 0.5e-6,
|
||||
LongContextInputTokenThreshold: 272000,
|
||||
LongContextInputCostMultiplier: 2.0,
|
||||
LongContextOutputCostMultiplier: 1.5,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewBillingService(&config.Config{}, pricingSvc)
|
||||
|
||||
pricing, err := svc.GetModelPricing("gpt-5.4")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 2.5e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 15e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 30e-6, pricing.OutputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 0.25e-6, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.InDelta(t, 0.5e-6, pricing.CacheReadPricePerTokenPriority, 1e-12)
|
||||
require.Equal(t, 272000, pricing.LongContextInputThreshold)
|
||||
require.InDelta(t, 2.0, pricing.LongContextInputMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.5, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
|
||||
func TestBillingServiceGetModelPricing_OpenAIFallbackGpt52Variants(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
gpt52, err := svc.GetModelPricing("gpt-5.2")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gpt52)
|
||||
require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12)
|
||||
|
||||
gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex")
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, gpt52Codex)
|
||||
require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12)
|
||||
}
|
||||
|
||||
func TestCalculateCostWithServiceTier_PriorityFallsBackToTierMultiplierWhenExplicitPriceMissing(t *testing.T) {
|
||||
svc := NewBillingService(&config.Config{}, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"custom-no-priority": {
|
||||
InputCostPerToken: 1e-6,
|
||||
OutputCostPerToken: 2e-6,
|
||||
CacheCreationInputTokenCost: 0.5e-6,
|
||||
CacheReadInputTokenCost: 0.25e-6,
|
||||
},
|
||||
},
|
||||
})
|
||||
tokens := UsageTokens{InputTokens: 100, OutputTokens: 50, CacheCreationTokens: 40, CacheReadTokens: 20}
|
||||
|
||||
baseCost, err := svc.CalculateCost("custom-no-priority", tokens, 1.0)
|
||||
require.NoError(t, err)
|
||||
|
||||
priorityCost, err := svc.CalculateCostWithServiceTier("custom-no-priority", tokens, 1.0, "priority")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.InDelta(t, baseCost.InputCost*2, priorityCost.InputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.OutputCost*2, priorityCost.OutputCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheCreationCost*2, priorityCost.CacheCreationCost, 1e-10)
|
||||
require.InDelta(t, baseCost.CacheReadCost*2, priorityCost.CacheReadCost, 1e-10)
|
||||
require.InDelta(t, baseCost.TotalCost*2, priorityCost.TotalCost, 1e-10)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_OpenAIGpt52FallbacksExposePriorityPrices(t *testing.T) {
|
||||
svc := newTestBillingService()
|
||||
|
||||
gpt52, err := svc.GetModelPricing("gpt-5.2")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1.75e-6, gpt52.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 14e-6, gpt52.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 28e-6, gpt52.OutputPricePerTokenPriority, 1e-12)
|
||||
|
||||
gpt52Codex, err := svc.GetModelPricing("gpt-5.2-codex")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1.75e-6, gpt52Codex.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 3.5e-6, gpt52Codex.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 14e-6, gpt52Codex.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 28e-6, gpt52Codex.OutputPricePerTokenPriority, 1e-12)
|
||||
}
|
||||
|
||||
func TestGetModelPricing_MapsDynamicPriorityFieldsIntoBillingPricing(t *testing.T) {
|
||||
svc := NewBillingService(&config.Config{}, &PricingService{
|
||||
pricingData: map[string]*LiteLLMModelPricing{
|
||||
"dynamic-tier-model": {
|
||||
InputCostPerToken: 1e-6,
|
||||
InputCostPerTokenPriority: 2e-6,
|
||||
OutputCostPerToken: 3e-6,
|
||||
OutputCostPerTokenPriority: 6e-6,
|
||||
CacheCreationInputTokenCost: 4e-6,
|
||||
CacheCreationInputTokenCostAbove1hr: 5e-6,
|
||||
CacheReadInputTokenCost: 7e-7,
|
||||
CacheReadInputTokenCostPriority: 8e-7,
|
||||
LongContextInputTokenThreshold: 999,
|
||||
LongContextInputCostMultiplier: 1.5,
|
||||
LongContextOutputCostMultiplier: 1.25,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
pricing, err := svc.GetModelPricing("dynamic-tier-model")
|
||||
require.NoError(t, err)
|
||||
require.InDelta(t, 1e-6, pricing.InputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 2e-6, pricing.InputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 3e-6, pricing.OutputPricePerToken, 1e-12)
|
||||
require.InDelta(t, 6e-6, pricing.OutputPricePerTokenPriority, 1e-12)
|
||||
require.InDelta(t, 4e-6, pricing.CacheCreation5mPrice, 1e-12)
|
||||
require.InDelta(t, 5e-6, pricing.CacheCreation1hPrice, 1e-12)
|
||||
require.True(t, pricing.SupportsCacheBreakdown)
|
||||
require.InDelta(t, 7e-7, pricing.CacheReadPricePerToken, 1e-12)
|
||||
require.InDelta(t, 8e-7, pricing.CacheReadPricePerTokenPriority, 1e-12)
|
||||
require.Equal(t, 999, pricing.LongContextInputThreshold)
|
||||
require.InDelta(t, 1.5, pricing.LongContextInputMultiplier, 1e-12)
|
||||
require.InDelta(t, 1.25, pricing.LongContextOutputMultiplier, 1e-12)
|
||||
}
|
||||
@@ -0,0 +1,282 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newTestValidator() *ClaudeCodeValidator {
|
||||
return NewClaudeCodeValidator()
|
||||
}
|
||||
|
||||
// validClaudeCodeBody 构造一个完整有效的 Claude Code 请求体
|
||||
func validClaudeCodeBody() map[string]any {
|
||||
return map[string]any{
|
||||
"model": "claude-sonnet-4-20250514",
|
||||
"system": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
},
|
||||
},
|
||||
"metadata": map[string]any{
|
||||
"user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_" + "12345678-1234-1234-1234-123456789abc",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_ClaudeCLIUserAgent(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ua string
|
||||
want bool
|
||||
}{
|
||||
{"标准版本号", "claude-cli/1.0.0", true},
|
||||
{"多位版本号", "claude-cli/12.34.56", true},
|
||||
{"大写开头", "Claude-CLI/1.0.0", true},
|
||||
{"非 claude-cli", "curl/7.64.1", false},
|
||||
{"空 User-Agent", "", false},
|
||||
{"部分匹配", "not-claude-cli/1.0.0", false},
|
||||
{"缺少版本号", "claude-cli/", false},
|
||||
{"版本格式不对", "claude-cli/1.0", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
require.Equal(t, tt.want, v.ValidateUserAgent(tt.ua), "UA: %q", tt.ua)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_NonMessagesPath_UAOnly(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
// 非 messages 路径只检查 UA
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
|
||||
result := v.Validate(req, nil)
|
||||
require.True(t, result, "非 messages 路径只需 UA 匹配")
|
||||
}
|
||||
|
||||
func TestValidate_NonMessagesPath_InvalidUA(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
req := httptest.NewRequest("GET", "/v1/models", nil)
|
||||
req.Header.Set("User-Agent", "curl/7.64.1")
|
||||
|
||||
result := v.Validate(req, nil)
|
||||
require.False(t, result, "UA 不匹配时应返回 false")
|
||||
}
|
||||
|
||||
func TestValidate_MessagesPath_FullValid(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
req.Header.Set("X-App", "claude-code")
|
||||
req.Header.Set("anthropic-beta", "max-tokens-3-5-sonnet-2024-07-15")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
result := v.Validate(req, validClaudeCodeBody())
|
||||
require.True(t, result, "完整有效请求应通过")
|
||||
}
|
||||
|
||||
func TestValidate_MessagesPath_MissingHeaders(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
body := validClaudeCodeBody()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
missingHeader string
|
||||
}{
|
||||
{"缺少 X-App", "X-App"},
|
||||
{"缺少 anthropic-beta", "anthropic-beta"},
|
||||
{"缺少 anthropic-version", "anthropic-version"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
req.Header.Set("X-App", "claude-code")
|
||||
req.Header.Set("anthropic-beta", "beta")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
req.Header.Del(tt.missingHeader)
|
||||
|
||||
result := v.Validate(req, body)
|
||||
require.False(t, result, "缺少 %s 应返回 false", tt.missingHeader)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_MessagesPath_InvalidMetadataUserID(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
metadata map[string]any
|
||||
}{
|
||||
{"缺少 metadata", nil},
|
||||
{"缺少 user_id", map[string]any{"other": "value"}},
|
||||
{"空 user_id", map[string]any{"user_id": ""}},
|
||||
{"格式错误", map[string]any{"user_id": "invalid-format"}},
|
||||
{"hex 长度不足", map[string]any{"user_id": "user_abc_account__session_uuid"}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
req.Header.Set("X-App", "claude-code")
|
||||
req.Header.Set("anthropic-beta", "beta")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
body := map[string]any{
|
||||
"model": "claude-sonnet-4",
|
||||
"system": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
},
|
||||
},
|
||||
}
|
||||
if tt.metadata != nil {
|
||||
body["metadata"] = tt.metadata
|
||||
}
|
||||
|
||||
result := v.Validate(req, body)
|
||||
require.False(t, result, "metadata.user_id: %v", tt.metadata)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidate_MessagesPath_InvalidSystemPrompt(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
req.Header.Set("X-App", "claude-code")
|
||||
req.Header.Set("anthropic-beta", "beta")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
body := map[string]any{
|
||||
"model": "claude-sonnet-4",
|
||||
"system": []any{
|
||||
map[string]any{
|
||||
"type": "text",
|
||||
"text": "Generate JSON data for testing database migrations.",
|
||||
},
|
||||
},
|
||||
"metadata": map[string]any{
|
||||
"user_id": "user_" + "a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2c3d4e5f6a1b2" + "_account__session_12345678-1234-1234-1234-123456789abc",
|
||||
},
|
||||
}
|
||||
|
||||
result := v.Validate(req, body)
|
||||
require.False(t, result, "无关系统提示词应返回 false")
|
||||
}
|
||||
|
||||
func TestValidate_MaxTokensOneHaikuBypass(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
// 不设置 X-App 等头,通过 context 标记为 haiku 探测请求
|
||||
ctx := context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true)
|
||||
req = req.WithContext(ctx)
|
||||
|
||||
// 即使 body 不包含 system prompt,也应通过
|
||||
result := v.Validate(req, map[string]any{"model": "claude-3-haiku", "max_tokens": 1})
|
||||
require.True(t, result, "max_tokens=1+haiku 探测请求应绕过严格验证")
|
||||
}
|
||||
|
||||
func TestSystemPromptSimilarity(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
prompt string
|
||||
want bool
|
||||
}{
|
||||
{"精确匹配", "You are Claude Code, Anthropic's official CLI for Claude.", true},
|
||||
{"带多余空格", "You are Claude Code, Anthropic's official CLI for Claude.", true},
|
||||
{"Agent SDK 模板", "You are a Claude agent, built on Anthropic's Claude Agent SDK.", true},
|
||||
{"文件搜索专家模板", "You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.", true},
|
||||
{"对话摘要模板", "You are a helpful AI assistant tasked with summarizing conversations.", true},
|
||||
{"交互式 CLI 模板", "You are an interactive CLI tool that helps users", true},
|
||||
{"无关文本", "Write me a poem about cats", false},
|
||||
{"空文本", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
body := map[string]any{
|
||||
"model": "claude-sonnet-4",
|
||||
"system": []any{
|
||||
map[string]any{"type": "text", "text": tt.prompt},
|
||||
},
|
||||
}
|
||||
result := v.IncludesClaudeCodeSystemPrompt(body)
|
||||
require.Equal(t, tt.want, result, "提示词: %q", tt.prompt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiceCoefficient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
a string
|
||||
b string
|
||||
want float64
|
||||
tol float64
|
||||
}{
|
||||
{"相同字符串", "hello", "hello", 1.0, 0.001},
|
||||
{"完全不同", "abc", "xyz", 0.0, 0.001},
|
||||
{"空字符串", "", "hello", 0.0, 0.001},
|
||||
{"单字符", "a", "b", 0.0, 0.001},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := diceCoefficient(tt.a, tt.b)
|
||||
require.InDelta(t, tt.want, result, tt.tol)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsClaudeCodeClient_Context(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// 默认应为 false
|
||||
require.False(t, IsClaudeCodeClient(ctx))
|
||||
|
||||
// 设置为 true
|
||||
ctx = SetClaudeCodeClient(ctx, true)
|
||||
require.True(t, IsClaudeCodeClient(ctx))
|
||||
|
||||
// 设置为 false
|
||||
ctx = SetClaudeCodeClient(ctx, false)
|
||||
require.False(t, IsClaudeCodeClient(ctx))
|
||||
}
|
||||
|
||||
func TestValidate_NilBody_MessagesPath(t *testing.T) {
|
||||
v := newTestValidator()
|
||||
|
||||
req := httptest.NewRequest("POST", "/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.0.0")
|
||||
req.Header.Set("X-App", "claude-code")
|
||||
req.Header.Set("anthropic-beta", "beta")
|
||||
req.Header.Set("anthropic-version", "2023-06-01")
|
||||
|
||||
result := v.Validate(req, nil)
|
||||
require.False(t, result, "nil body 的 messages 请求应返回 false")
|
||||
}
|
||||
@@ -0,0 +1,328 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
)
|
||||
|
||||
// ClaudeCodeValidator 验证请求是否来自 Claude Code 客户端
|
||||
// 完全学习自 claude-relay-service 项目的验证逻辑
|
||||
type ClaudeCodeValidator struct{}
|
||||
|
||||
var (
|
||||
// User-Agent 匹配: claude-cli/x.x.x (仅支持官方 CLI,大小写不敏感)
|
||||
claudeCodeUAPattern = regexp.MustCompile(`(?i)^claude-cli/\d+\.\d+\.\d+`)
|
||||
|
||||
// 带捕获组的版本提取正则
|
||||
claudeCodeUAVersionPattern = regexp.MustCompile(`(?i)^claude-cli/(\d+\.\d+\.\d+)`)
|
||||
|
||||
// metadata.user_id 格式: user_{64位hex}_account__session_{uuid}
|
||||
userIDPattern = regexp.MustCompile(`^user_[a-fA-F0-9]{64}_account__session_[\w-]+$`)
|
||||
|
||||
// System prompt 相似度阈值(默认 0.5,和 claude-relay-service 一致)
|
||||
systemPromptThreshold = 0.5
|
||||
)
|
||||
|
||||
// Claude Code 官方 System Prompt 模板
|
||||
// 从 claude-relay-service/src/utils/contents.js 提取
|
||||
var claudeCodeSystemPrompts = []string{
|
||||
// claudeOtherSystemPrompt1 - Primary
|
||||
"You are Claude Code, Anthropic's official CLI for Claude.",
|
||||
|
||||
// claudeOtherSystemPrompt3 - Agent SDK
|
||||
"You are a Claude agent, built on Anthropic's Claude Agent SDK.",
|
||||
|
||||
// claudeOtherSystemPrompt4 - Compact Agent SDK
|
||||
"You are Claude Code, Anthropic's official CLI for Claude, running within the Claude Agent SDK.",
|
||||
|
||||
// exploreAgentSystemPrompt
|
||||
"You are a file search specialist for Claude Code, Anthropic's official CLI for Claude.",
|
||||
|
||||
// claudeOtherSystemPromptCompact - Compact (用于对话摘要)
|
||||
"You are a helpful AI assistant tasked with summarizing conversations.",
|
||||
|
||||
// claudeOtherSystemPrompt2 - Secondary (长提示词的关键部分)
|
||||
"You are an interactive CLI tool that helps users",
|
||||
}
|
||||
|
||||
// NewClaudeCodeValidator 创建验证器实例
|
||||
func NewClaudeCodeValidator() *ClaudeCodeValidator {
|
||||
return &ClaudeCodeValidator{}
|
||||
}
|
||||
|
||||
// Validate 验证请求是否来自 Claude Code CLI
|
||||
// 采用与 claude-relay-service 完全一致的验证策略:
|
||||
//
|
||||
// Step 1: User-Agent 检查 (必需) - 必须是 claude-cli/x.x.x
|
||||
// Step 2: 对于非 messages 路径,只要 UA 匹配就通过
|
||||
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过(UA 已验证)
|
||||
// Step 4: 对于 messages 路径,进行严格验证:
|
||||
// - System prompt 相似度检查
|
||||
// - X-App header 检查
|
||||
// - anthropic-beta header 检查
|
||||
// - anthropic-version header 检查
|
||||
// - metadata.user_id 格式验证
|
||||
func (v *ClaudeCodeValidator) Validate(r *http.Request, body map[string]any) bool {
|
||||
// Step 1: User-Agent 检查
|
||||
ua := r.Header.Get("User-Agent")
|
||||
if !claudeCodeUAPattern.MatchString(ua) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Step 2: 非 messages 路径,只要 UA 匹配就通过
|
||||
path := r.URL.Path
|
||||
if !strings.Contains(path, "messages") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Step 3: 检查 max_tokens=1 + haiku 探测请求绕过
|
||||
// 这类请求用于 Claude Code 验证 API 连通性,不携带 system prompt
|
||||
if isMaxTokensOneHaiku, ok := IsMaxTokensOneHaikuRequestFromContext(r.Context()); ok && isMaxTokensOneHaiku {
|
||||
return true // 绕过 system prompt 检查,UA 已在 Step 1 验证
|
||||
}
|
||||
|
||||
// Step 4: messages 路径,进行严格验证
|
||||
|
||||
// 4.1 检查 system prompt 相似度
|
||||
if !v.hasClaudeCodeSystemPrompt(body) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 4.2 检查必需的 headers(值不为空即可)
|
||||
xApp := r.Header.Get("X-App")
|
||||
if xApp == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
anthropicBeta := r.Header.Get("anthropic-beta")
|
||||
if anthropicBeta == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
anthropicVersion := r.Header.Get("anthropic-version")
|
||||
if anthropicVersion == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// 4.3 验证 metadata.user_id
|
||||
if body == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
metadata, ok := body["metadata"].(map[string]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
userID, ok := metadata["user_id"].(string)
|
||||
if !ok || userID == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
if !userIDPattern.MatchString(userID) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// hasClaudeCodeSystemPrompt 检查请求是否包含 Claude Code 系统提示词
|
||||
// 使用字符串相似度匹配(Dice coefficient)
|
||||
func (v *ClaudeCodeValidator) hasClaudeCodeSystemPrompt(body map[string]any) bool {
|
||||
if body == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查 model 字段
|
||||
if _, ok := body["model"].(string); !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取 system 字段
|
||||
systemEntries, ok := body["system"].([]any)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查每个 system entry
|
||||
for _, entry := range systemEntries {
|
||||
entryMap, ok := entry.(map[string]any)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
text, ok := entryMap["text"].(string)
|
||||
if !ok || text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// 计算与所有模板的最佳相似度
|
||||
bestScore := v.bestSimilarityScore(text)
|
||||
if bestScore >= systemPromptThreshold {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// bestSimilarityScore 计算文本与所有 Claude Code 模板的最佳相似度
|
||||
func (v *ClaudeCodeValidator) bestSimilarityScore(text string) float64 {
|
||||
normalizedText := normalizePrompt(text)
|
||||
bestScore := 0.0
|
||||
|
||||
for _, template := range claudeCodeSystemPrompts {
|
||||
normalizedTemplate := normalizePrompt(template)
|
||||
score := diceCoefficient(normalizedText, normalizedTemplate)
|
||||
if score > bestScore {
|
||||
bestScore = score
|
||||
}
|
||||
}
|
||||
|
||||
return bestScore
|
||||
}
|
||||
|
||||
// normalizePrompt 标准化提示词文本(去除多余空白)
|
||||
func normalizePrompt(text string) string {
|
||||
// 将所有空白字符替换为单个空格,并去除首尾空白
|
||||
return strings.Join(strings.Fields(text), " ")
|
||||
}
|
||||
|
||||
// diceCoefficient 计算两个字符串的 Dice 系数(Sørensen–Dice coefficient)
|
||||
// 这是 string-similarity 库使用的算法
|
||||
// 公式: 2 * |intersection| / (|bigrams(a)| + |bigrams(b)|)
|
||||
func diceCoefficient(a, b string) float64 {
|
||||
if a == b {
|
||||
return 1.0
|
||||
}
|
||||
|
||||
if len(a) < 2 || len(b) < 2 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 生成 bigrams
|
||||
bigramsA := getBigrams(a)
|
||||
bigramsB := getBigrams(b)
|
||||
|
||||
if len(bigramsA) == 0 || len(bigramsB) == 0 {
|
||||
return 0.0
|
||||
}
|
||||
|
||||
// 计算交集大小
|
||||
intersection := 0
|
||||
for bigram, countA := range bigramsA {
|
||||
if countB, exists := bigramsB[bigram]; exists {
|
||||
if countA < countB {
|
||||
intersection += countA
|
||||
} else {
|
||||
intersection += countB
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 计算总 bigram 数量
|
||||
totalA := 0
|
||||
for _, count := range bigramsA {
|
||||
totalA += count
|
||||
}
|
||||
totalB := 0
|
||||
for _, count := range bigramsB {
|
||||
totalB += count
|
||||
}
|
||||
|
||||
return float64(2*intersection) / float64(totalA+totalB)
|
||||
}
|
||||
|
||||
// getBigrams 获取字符串的所有 bigrams(相邻字符对)
|
||||
func getBigrams(s string) map[string]int {
|
||||
bigrams := make(map[string]int)
|
||||
runes := []rune(strings.ToLower(s))
|
||||
|
||||
for i := 0; i < len(runes)-1; i++ {
|
||||
bigram := string(runes[i : i+2])
|
||||
bigrams[bigram]++
|
||||
}
|
||||
|
||||
return bigrams
|
||||
}
|
||||
|
||||
// ValidateUserAgent 仅验证 User-Agent(用于不需要解析请求体的场景)
|
||||
func (v *ClaudeCodeValidator) ValidateUserAgent(ua string) bool {
|
||||
return claudeCodeUAPattern.MatchString(ua)
|
||||
}
|
||||
|
||||
// IncludesClaudeCodeSystemPrompt 检查请求体是否包含 Claude Code 系统提示词
|
||||
// 只要存在匹配的系统提示词就返回 true(用于宽松检测)
|
||||
func (v *ClaudeCodeValidator) IncludesClaudeCodeSystemPrompt(body map[string]any) bool {
|
||||
return v.hasClaudeCodeSystemPrompt(body)
|
||||
}
|
||||
|
||||
// IsClaudeCodeClient 从 context 中获取 Claude Code 客户端标识
|
||||
func IsClaudeCodeClient(ctx context.Context) bool {
|
||||
if v, ok := ctx.Value(ctxkey.IsClaudeCodeClient).(bool); ok {
|
||||
return v
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// SetClaudeCodeClient 将 Claude Code 客户端标识设置到 context 中
|
||||
func SetClaudeCodeClient(ctx context.Context, isClaudeCode bool) context.Context {
|
||||
return context.WithValue(ctx, ctxkey.IsClaudeCodeClient, isClaudeCode)
|
||||
}
|
||||
|
||||
// ExtractVersion 从 User-Agent 中提取 Claude Code 版本号
|
||||
// 返回 "2.1.22" 形式的版本号,如果不匹配返回空字符串
|
||||
func (v *ClaudeCodeValidator) ExtractVersion(ua string) string {
|
||||
matches := claudeCodeUAVersionPattern.FindStringSubmatch(ua)
|
||||
if len(matches) >= 2 {
|
||||
return matches[1]
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetClaudeCodeVersion 将 Claude Code 版本号设置到 context 中
|
||||
func SetClaudeCodeVersion(ctx context.Context, version string) context.Context {
|
||||
return context.WithValue(ctx, ctxkey.ClaudeCodeVersion, version)
|
||||
}
|
||||
|
||||
// GetClaudeCodeVersion 从 context 中获取 Claude Code 版本号
|
||||
func GetClaudeCodeVersion(ctx context.Context) string {
|
||||
if v, ok := ctx.Value(ctxkey.ClaudeCodeVersion).(string); ok {
|
||||
return v
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// CompareVersions 比较两个 semver 版本号
|
||||
// 返回: -1 (a < b), 0 (a == b), 1 (a > b)
|
||||
func CompareVersions(a, b string) int {
|
||||
aParts := parseSemver(a)
|
||||
bParts := parseSemver(b)
|
||||
for i := 0; i < 3; i++ {
|
||||
if aParts[i] < bParts[i] {
|
||||
return -1
|
||||
}
|
||||
if aParts[i] > bParts[i] {
|
||||
return 1
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// parseSemver 解析 semver 版本号为 [major, minor, patch]
|
||||
func parseSemver(v string) [3]int {
|
||||
v = strings.TrimPrefix(v, "v")
|
||||
parts := strings.Split(v, ".")
|
||||
result := [3]int{0, 0, 0}
|
||||
for i := 0; i < len(parts) && i < 3; i++ {
|
||||
if parsed, err := strconv.Atoi(parts[i]); err == nil {
|
||||
result[i] = parsed
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
@@ -0,0 +1,106 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ctxkey"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestClaudeCodeValidator_ProbeBypass(t *testing.T) {
|
||||
validator := NewClaudeCodeValidator()
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true))
|
||||
|
||||
ok := validator.Validate(req, map[string]any{
|
||||
"model": "claude-haiku-4-5",
|
||||
"max_tokens": 1,
|
||||
})
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestClaudeCodeValidator_ProbeBypassRequiresUA(t *testing.T) {
|
||||
validator := NewClaudeCodeValidator()
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "curl/8.0.0")
|
||||
req = req.WithContext(context.WithValue(req.Context(), ctxkey.IsMaxTokensOneHaikuRequest, true))
|
||||
|
||||
ok := validator.Validate(req, map[string]any{
|
||||
"model": "claude-haiku-4-5",
|
||||
"max_tokens": 1,
|
||||
})
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestClaudeCodeValidator_MessagesWithoutProbeStillNeedStrictValidation(t *testing.T) {
|
||||
validator := NewClaudeCodeValidator()
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/messages", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")
|
||||
|
||||
ok := validator.Validate(req, map[string]any{
|
||||
"model": "claude-haiku-4-5",
|
||||
"max_tokens": 1,
|
||||
})
|
||||
require.False(t, ok)
|
||||
}
|
||||
|
||||
func TestClaudeCodeValidator_NonMessagesPathUAOnly(t *testing.T) {
|
||||
validator := NewClaudeCodeValidator()
|
||||
req := httptest.NewRequest(http.MethodPost, "http://example.com/v1/models", nil)
|
||||
req.Header.Set("User-Agent", "claude-cli/1.2.3 (darwin; arm64)")
|
||||
|
||||
ok := validator.Validate(req, nil)
|
||||
require.True(t, ok)
|
||||
}
|
||||
|
||||
func TestExtractVersion(t *testing.T) {
|
||||
v := NewClaudeCodeValidator()
|
||||
tests := []struct {
|
||||
ua string
|
||||
want string
|
||||
}{
|
||||
{"claude-cli/2.1.22 (darwin; arm64)", "2.1.22"},
|
||||
{"claude-cli/1.0.0", "1.0.0"},
|
||||
{"Claude-CLI/3.10.5 (linux; x86_64)", "3.10.5"}, // 大小写不敏感
|
||||
{"curl/8.0.0", ""}, // 非 Claude CLI
|
||||
{"", ""}, // 空字符串
|
||||
{"claude-cli/", ""}, // 无版本号
|
||||
{"claude-cli/2.1.22-beta", "2.1.22"}, // 带后缀仍提取主版本号
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := v.ExtractVersion(tt.ua)
|
||||
require.Equal(t, tt.want, got, "ExtractVersion(%q)", tt.ua)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompareVersions(t *testing.T) {
|
||||
tests := []struct {
|
||||
a, b string
|
||||
want int
|
||||
}{
|
||||
{"2.1.0", "2.1.0", 0}, // 相等
|
||||
{"2.1.1", "2.1.0", 1}, // patch 更大
|
||||
{"2.0.0", "2.1.0", -1}, // minor 更小
|
||||
{"3.0.0", "2.99.99", 1}, // major 更大
|
||||
{"1.0.0", "2.0.0", -1}, // major 更小
|
||||
{"0.0.1", "0.0.0", 1}, // patch 差异
|
||||
{"", "1.0.0", -1}, // 空字符串 vs 正常版本
|
||||
{"v2.1.0", "2.1.0", 0}, // v 前缀处理
|
||||
}
|
||||
for _, tt := range tests {
|
||||
got := CompareVersions(tt.a, tt.b)
|
||||
require.Equal(t, tt.want, got, "CompareVersions(%q, %q)", tt.a, tt.b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetGetClaudeCodeVersion(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
require.Equal(t, "", GetClaudeCodeVersion(ctx), "empty context should return empty string")
|
||||
|
||||
ctx = SetClaudeCodeVersion(ctx, "2.1.63")
|
||||
require.Equal(t, "2.1.63", GetClaudeCodeVersion(ctx))
|
||||
}
|
||||
@@ -0,0 +1,159 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
claudeTokenRefreshSkew = 3 * time.Minute
|
||||
claudeTokenCacheSkew = 5 * time.Minute
|
||||
claudeLockWaitTime = 200 * time.Millisecond
|
||||
)
|
||||
|
||||
// ClaudeTokenCache token cache interface.
|
||||
type ClaudeTokenCache = GeminiTokenCache
|
||||
|
||||
// ClaudeTokenProvider manages access_token for Claude OAuth accounts.
|
||||
type ClaudeTokenProvider struct {
|
||||
accountRepo AccountRepository
|
||||
tokenCache ClaudeTokenCache
|
||||
oauthService *OAuthService
|
||||
refreshAPI *OAuthRefreshAPI
|
||||
executor OAuthRefreshExecutor
|
||||
refreshPolicy ProviderRefreshPolicy
|
||||
}
|
||||
|
||||
func NewClaudeTokenProvider(
|
||||
accountRepo AccountRepository,
|
||||
tokenCache ClaudeTokenCache,
|
||||
oauthService *OAuthService,
|
||||
) *ClaudeTokenProvider {
|
||||
return &ClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: tokenCache,
|
||||
oauthService: oauthService,
|
||||
refreshPolicy: ClaudeProviderRefreshPolicy(),
|
||||
}
|
||||
}
|
||||
|
||||
// SetRefreshAPI injects unified OAuth refresh API and executor.
|
||||
func (p *ClaudeTokenProvider) SetRefreshAPI(api *OAuthRefreshAPI, executor OAuthRefreshExecutor) {
|
||||
p.refreshAPI = api
|
||||
p.executor = executor
|
||||
}
|
||||
|
||||
// SetRefreshPolicy injects caller-side refresh policy.
|
||||
func (p *ClaudeTokenProvider) SetRefreshPolicy(policy ProviderRefreshPolicy) {
|
||||
p.refreshPolicy = policy
|
||||
}
|
||||
|
||||
// GetAccessToken returns a valid access_token.
|
||||
func (p *ClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an anthropic oauth account")
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// 1) Try cache first.
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit", "account_id", account.ID)
|
||||
return token, nil
|
||||
} else if err != nil {
|
||||
slog.Warn("claude_token_cache_get_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("claude_token_cache_miss", "account_id", account.ID)
|
||||
|
||||
// 2) Refresh if needed (pre-expiry skew).
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||
refreshFailed := false
|
||||
|
||||
if needsRefresh && p.refreshAPI != nil && p.executor != nil {
|
||||
result, err := p.refreshAPI.RefreshIfNeeded(ctx, account, p.executor, claudeTokenRefreshSkew)
|
||||
if err != nil {
|
||||
if p.refreshPolicy.OnRefreshError == ProviderRefreshErrorReturn {
|
||||
return "", err
|
||||
}
|
||||
slog.Warn("claude_token_refresh_failed", "account_id", account.ID, "error", err)
|
||||
refreshFailed = true
|
||||
} else if result.LockHeld {
|
||||
if p.refreshPolicy.OnLockHeld == ProviderLockHeldWaitForCache && p.tokenCache != nil {
|
||||
time.Sleep(claudeLockWaitTime)
|
||||
if token, cacheErr := p.tokenCache.GetAccessToken(ctx, cacheKey); cacheErr == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
} else {
|
||||
account = result.Account
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
} else if needsRefresh && p.tokenCache != nil {
|
||||
// Backward-compatible test path when refreshAPI is not injected.
|
||||
locked, lockErr := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if lockErr == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
} else if lockErr != nil {
|
||||
slog.Warn("claude_token_lock_failed", "account_id", account.ID, "error", lockErr)
|
||||
} else {
|
||||
time.Sleep(claudeLockWaitTime)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && strings.TrimSpace(token) != "" {
|
||||
slog.Debug("claude_token_cache_hit_after_wait", "account_id", account.ID)
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3) Populate cache with TTL.
|
||||
if p.tokenCache != nil {
|
||||
latestAccount, isStale := CheckTokenVersion(ctx, account, p.accountRepo)
|
||||
if isStale && latestAccount != nil {
|
||||
slog.Debug("claude_token_version_stale_use_latest", "account_id", account.ID)
|
||||
accessToken = latestAccount.GetCredential("access_token")
|
||||
if strings.TrimSpace(accessToken) == "" {
|
||||
return "", errors.New("access_token not found after version check")
|
||||
}
|
||||
} else {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
if p.refreshPolicy.FailureTTL > 0 {
|
||||
ttl = p.refreshPolicy.FailureTTL
|
||||
} else {
|
||||
ttl = time.Minute
|
||||
}
|
||||
slog.Debug("claude_token_cache_short_ttl", "account_id", account.ID, "reason", "refresh_failed")
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
switch {
|
||||
case until > claudeTokenCacheSkew:
|
||||
ttl = until - claudeTokenCacheSkew
|
||||
case until > 0:
|
||||
ttl = until
|
||||
default:
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
if err := p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl); err != nil {
|
||||
slog.Warn("claude_token_cache_set_failed", "account_id", account.ID, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
@@ -0,0 +1,939 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// claudeTokenCacheStub implements ClaudeTokenCache for testing
|
||||
type claudeTokenCacheStub struct {
|
||||
mu sync.Mutex
|
||||
tokens map[string]string
|
||||
getErr error
|
||||
setErr error
|
||||
deleteErr error
|
||||
lockAcquired bool
|
||||
lockErr error
|
||||
releaseLockErr error
|
||||
getCalled int32
|
||||
setCalled int32
|
||||
lockCalled int32
|
||||
unlockCalled int32
|
||||
simulateLockRace bool
|
||||
}
|
||||
|
||||
func newClaudeTokenCacheStub() *claudeTokenCacheStub {
|
||||
return &claudeTokenCacheStub{
|
||||
tokens: make(map[string]string),
|
||||
lockAcquired: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) GetAccessToken(ctx context.Context, cacheKey string) (string, error) {
|
||||
atomic.AddInt32(&s.getCalled, 1)
|
||||
if s.getErr != nil {
|
||||
return "", s.getErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
return s.tokens[cacheKey], nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) SetAccessToken(ctx context.Context, cacheKey string, token string, ttl time.Duration) error {
|
||||
atomic.AddInt32(&s.setCalled, 1)
|
||||
if s.setErr != nil {
|
||||
return s.setErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.tokens[cacheKey] = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) DeleteAccessToken(ctx context.Context, cacheKey string) error {
|
||||
if s.deleteErr != nil {
|
||||
return s.deleteErr
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.tokens, cacheKey)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) AcquireRefreshLock(ctx context.Context, cacheKey string, ttl time.Duration) (bool, error) {
|
||||
atomic.AddInt32(&s.lockCalled, 1)
|
||||
if s.lockErr != nil {
|
||||
return false, s.lockErr
|
||||
}
|
||||
if s.simulateLockRace {
|
||||
return false, nil
|
||||
}
|
||||
return s.lockAcquired, nil
|
||||
}
|
||||
|
||||
func (s *claudeTokenCacheStub) ReleaseRefreshLock(ctx context.Context, cacheKey string) error {
|
||||
atomic.AddInt32(&s.unlockCalled, 1)
|
||||
return s.releaseLockErr
|
||||
}
|
||||
|
||||
// claudeAccountRepoStub is a minimal stub implementing only the methods used by ClaudeTokenProvider
|
||||
type claudeAccountRepoStub struct {
|
||||
account *Account
|
||||
getErr error
|
||||
updateErr error
|
||||
getCalled int32
|
||||
updateCalled int32
|
||||
}
|
||||
|
||||
func (r *claudeAccountRepoStub) GetByID(ctx context.Context, id int64) (*Account, error) {
|
||||
atomic.AddInt32(&r.getCalled, 1)
|
||||
if r.getErr != nil {
|
||||
return nil, r.getErr
|
||||
}
|
||||
return r.account, nil
|
||||
}
|
||||
|
||||
func (r *claudeAccountRepoStub) Update(ctx context.Context, account *Account) error {
|
||||
atomic.AddInt32(&r.updateCalled, 1)
|
||||
if r.updateErr != nil {
|
||||
return r.updateErr
|
||||
}
|
||||
r.account = account
|
||||
return nil
|
||||
}
|
||||
|
||||
// claudeOAuthServiceStub implements OAuthService methods for testing
|
||||
type claudeOAuthServiceStub struct {
|
||||
tokenInfo *TokenInfo
|
||||
refreshErr error
|
||||
refreshCalled int32
|
||||
}
|
||||
|
||||
func (s *claudeOAuthServiceStub) RefreshAccountToken(ctx context.Context, account *Account) (*TokenInfo, error) {
|
||||
atomic.AddInt32(&s.refreshCalled, 1)
|
||||
if s.refreshErr != nil {
|
||||
return nil, s.refreshErr
|
||||
}
|
||||
return s.tokenInfo, nil
|
||||
}
|
||||
|
||||
// testClaudeTokenProvider is a test version that uses the stub OAuth service
|
||||
type testClaudeTokenProvider struct {
|
||||
accountRepo *claudeAccountRepoStub
|
||||
tokenCache *claudeTokenCacheStub
|
||||
oauthService *claudeOAuthServiceStub
|
||||
}
|
||||
|
||||
func (p *testClaudeTokenProvider) GetAccessToken(ctx context.Context, account *Account) (string, error) {
|
||||
if account == nil {
|
||||
return "", errors.New("account is nil")
|
||||
}
|
||||
if account.Platform != PlatformAnthropic || account.Type != AccountTypeOAuth {
|
||||
return "", errors.New("not an anthropic oauth account")
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// 1. Check cache
|
||||
if p.tokenCache != nil {
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check if refresh needed
|
||||
expiresAt := account.GetCredentialAsTime("expires_at")
|
||||
needsRefresh := expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew
|
||||
refreshFailed := false
|
||||
if needsRefresh && p.tokenCache != nil {
|
||||
locked, err := p.tokenCache.AcquireRefreshLock(ctx, cacheKey, 30*time.Second)
|
||||
if err == nil && locked {
|
||||
defer func() { _ = p.tokenCache.ReleaseRefreshLock(ctx, cacheKey) }()
|
||||
|
||||
// Check cache again after acquiring lock
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// Get fresh account from DB
|
||||
fresh, err := p.accountRepo.GetByID(ctx, account.ID)
|
||||
if err == nil && fresh != nil {
|
||||
account = fresh
|
||||
}
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
if expiresAt == nil || time.Until(*expiresAt) <= claudeTokenRefreshSkew {
|
||||
if p.oauthService == nil {
|
||||
refreshFailed = true // 无法刷新,标记失败
|
||||
} else {
|
||||
tokenInfo, err := p.oauthService.RefreshAccountToken(ctx, account)
|
||||
if err != nil {
|
||||
refreshFailed = true // 刷新失败,标记以使用短 TTL
|
||||
} else {
|
||||
// Build new credentials
|
||||
newCredentials := make(map[string]any)
|
||||
for k, v := range account.Credentials {
|
||||
newCredentials[k] = v
|
||||
}
|
||||
newCredentials["access_token"] = tokenInfo.AccessToken
|
||||
newCredentials["token_type"] = tokenInfo.TokenType
|
||||
newCredentials["expires_at"] = time.Now().Add(time.Duration(tokenInfo.ExpiresIn) * time.Second).Format(time.RFC3339)
|
||||
if tokenInfo.RefreshToken != "" {
|
||||
newCredentials["refresh_token"] = tokenInfo.RefreshToken
|
||||
}
|
||||
account.Credentials = newCredentials
|
||||
_ = p.accountRepo.Update(ctx, account)
|
||||
expiresAt = account.GetCredentialAsTime("expires_at")
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if p.tokenCache.simulateLockRace {
|
||||
// Wait and retry cache
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
if token, err := p.tokenCache.GetAccessToken(ctx, cacheKey); err == nil && token != "" {
|
||||
return token, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
accessToken := account.GetCredential("access_token")
|
||||
if accessToken == "" {
|
||||
return "", errors.New("access_token not found in credentials")
|
||||
}
|
||||
|
||||
// 3. Store in cache
|
||||
if p.tokenCache != nil {
|
||||
ttl := 30 * time.Minute
|
||||
if refreshFailed {
|
||||
ttl = time.Minute // 刷新失败时使用短 TTL
|
||||
} else if expiresAt != nil {
|
||||
until := time.Until(*expiresAt)
|
||||
if until > claudeTokenCacheSkew {
|
||||
ttl = until - claudeTokenCacheSkew
|
||||
} else if until > 0 {
|
||||
ttl = until
|
||||
} else {
|
||||
ttl = time.Minute
|
||||
}
|
||||
}
|
||||
_ = p.tokenCache.SetAccessToken(ctx, cacheKey, accessToken, ttl)
|
||||
}
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheHit(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
account := &Account{
|
||||
ID: 100,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "db-token",
|
||||
},
|
||||
}
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
cache.tokens[cacheKey] = "cached-token"
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "cached-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalled))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalled))
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheMiss_FromCredentials(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
// Token expires in far future, no refresh needed
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 101,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "credential-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "credential-token", token)
|
||||
|
||||
// Should have stored in cache
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
require.Equal(t, "credential-token", cache.tokens[cacheKey])
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_TokenRefresh(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
ExpiresAt: time.Now().Add(time.Hour).Unix(),
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 102,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&oauthService.refreshCalled))
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_LockRaceCondition(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.simulateLockRace = true
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 103,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "race-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
// Simulate another worker already refreshed and cached
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_NilAccount(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), nil)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "account is nil")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_WrongPlatform(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 104,
|
||||
Platform: PlatformOpenAI,
|
||||
Type: AccountTypeOAuth,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_WrongAccountType(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 105,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeAPIKey,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_SetupTokenType(t *testing.T) {
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
account := &Account{
|
||||
ID: 106,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeSetupToken,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "not an anthropic oauth account")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_NilCache(t *testing.T) {
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 107,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "nocache-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, nil, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "nocache-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheGetError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.getErr = errors.New("redis connection failed")
|
||||
|
||||
// Token doesn't need refresh
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 108,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
// Should gracefully degrade and return from credentials
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_CacheSetError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.setErr = errors.New("redis write failed")
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 109,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "still-works-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
// Should still work even if cache set fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "still-works-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_MissingAccessToken(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 110,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// missing access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_RefreshError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
refreshErr: errors.New("oauth refresh failed"),
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 111,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if refresh fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_OAuthServiceNotConfigured(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 112,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: nil, // not configured
|
||||
}
|
||||
|
||||
// Now with fallback behavior, should return existing token even if oauth service not configured
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "old-token", token) // Fallback to existing token
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_TTLCalculation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
expiresIn time.Duration
|
||||
}{
|
||||
{
|
||||
name: "far_future_expiry",
|
||||
expiresIn: 1 * time.Hour,
|
||||
},
|
||||
{
|
||||
name: "medium_expiry",
|
||||
expiresIn: 10 * time.Minute,
|
||||
},
|
||||
{
|
||||
name: "near_expiry",
|
||||
expiresIn: 6 * time.Minute,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
expiresAt := time.Now().Add(tt.expiresIn).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 200,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "test-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
|
||||
_, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify token was cached
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
require.Equal(t, "test-token", cache.tokens[cacheKey])
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_AccountRepoGetError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{
|
||||
getErr: errors.New("db connection failed"),
|
||||
}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 113,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Should still work, just using the passed-in account
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_AccountUpdateError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{
|
||||
updateErr: errors.New("db write failed"),
|
||||
}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 114,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"refresh_token": "old-refresh",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
// Should still return token even if update fails
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "refreshed-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_RefreshPreservesExistingCredentials(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "new-access-token",
|
||||
RefreshToken: "new-refresh-token",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 115,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-access-token",
|
||||
"refresh_token": "old-refresh-token",
|
||||
"expires_at": expiresAt,
|
||||
"custom_field": "should-be-preserved",
|
||||
"organization": "test-org",
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "new-access-token", token)
|
||||
|
||||
// Verify existing fields are preserved
|
||||
require.Equal(t, "should-be-preserved", accountRepo.account.Credentials["custom_field"])
|
||||
require.Equal(t, "test-org", accountRepo.account.Credentials["organization"])
|
||||
// Verify new fields are updated
|
||||
require.Equal(t, "new-access-token", accountRepo.account.Credentials["access_token"])
|
||||
require.Equal(t, "new-refresh-token", accountRepo.account.Credentials["refresh_token"])
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_DoubleCheckCacheAfterLock(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
accountRepo := &claudeAccountRepoStub{}
|
||||
oauthService := &claudeOAuthServiceStub{
|
||||
tokenInfo: &TokenInfo{
|
||||
AccessToken: "refreshed-token",
|
||||
RefreshToken: "new-refresh",
|
||||
TokenType: "Bearer",
|
||||
ExpiresIn: 3600,
|
||||
},
|
||||
}
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 116,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "old-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
accountRepo.account = account
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
|
||||
// After lock is acquired, cache should have the token (simulating another worker)
|
||||
go func() {
|
||||
time.Sleep(5 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "cached-by-other-worker"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := &testClaudeTokenProvider{
|
||||
accountRepo: accountRepo,
|
||||
tokenCache: cache,
|
||||
oauthService: oauthService,
|
||||
}
|
||||
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
// Tests for real provider - to increase coverage
|
||||
func TestClaudeTokenProvider_Real_LockFailedWait(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon (within refresh skew) to trigger lock attempt
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 300,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
// Set token in cache after lock wait period (simulate other worker refreshing)
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
go func() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "refreshed-by-other"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_CacheHitAfterWait(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockAcquired = false // Lock acquisition fails
|
||||
|
||||
// Token expires soon
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 301,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "original-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
cacheKey := ClaudeTokenCacheKey(account)
|
||||
// Set token in cache immediately after wait starts
|
||||
go func() {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cache.mu.Lock()
|
||||
cache.tokens[cacheKey] = "winner-token"
|
||||
cache.mu.Unlock()
|
||||
}()
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_NoExpiresAt(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockAcquired = false // Prevent entering refresh logic
|
||||
|
||||
// Token with nil expires_at (no expiry set)
|
||||
account := &Account{
|
||||
ID: 302,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "no-expiry-token",
|
||||
},
|
||||
}
|
||||
|
||||
// After lock wait, return token from credentials
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "no-expiry-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_WhitespaceToken(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cacheKey := "claude:account:303"
|
||||
cache.tokens[cacheKey] = " " // Whitespace only - should be treated as empty
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 303,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "real-token",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "real-token", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_EmptyCredentialToken(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 304,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": " ", // Whitespace only
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_LockError(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
cache.lockErr = errors.New("redis lock failed")
|
||||
|
||||
// Token expires soon (within refresh skew)
|
||||
expiresAt := time.Now().Add(1 * time.Minute).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 305,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"access_token": "fallback-on-lock-error",
|
||||
"expires_at": expiresAt,
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "fallback-on-lock-error", token)
|
||||
}
|
||||
|
||||
func TestClaudeTokenProvider_Real_NilCredentials(t *testing.T) {
|
||||
cache := newClaudeTokenCacheStub()
|
||||
|
||||
expiresAt := time.Now().Add(1 * time.Hour).Format(time.RFC3339)
|
||||
account := &Account{
|
||||
ID: 306,
|
||||
Platform: PlatformAnthropic,
|
||||
Type: AccountTypeOAuth,
|
||||
Credentials: map[string]any{
|
||||
"expires_at": expiresAt,
|
||||
// No access_token
|
||||
},
|
||||
}
|
||||
|
||||
provider := NewClaudeTokenProvider(nil, cache, nil)
|
||||
token, err := provider.GetAccessToken(context.Background(), account)
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), "access_token not found")
|
||||
require.Empty(t, token)
|
||||
}
|
||||
@@ -0,0 +1,360 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
// ConcurrencyCache 定义并发控制的缓存接口
|
||||
// 使用有序集合存储槽位,按时间戳清理过期条目
|
||||
type ConcurrencyCache interface {
|
||||
// 账号槽位管理
|
||||
// 键格式: concurrency:account:{accountID}(有序集合,成员为 requestID)
|
||||
AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseAccountSlot(ctx context.Context, accountID int64, requestID string) error
|
||||
GetAccountConcurrency(ctx context.Context, accountID int64) (int, error)
|
||||
GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error)
|
||||
|
||||
// 账号等待队列(账号级)
|
||||
IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error)
|
||||
DecrementAccountWaitCount(ctx context.Context, accountID int64) error
|
||||
GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error)
|
||||
|
||||
// 用户槽位管理
|
||||
// 键格式: concurrency:user:{userID}(有序集合,成员为 requestID)
|
||||
AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int, requestID string) (bool, error)
|
||||
ReleaseUserSlot(ctx context.Context, userID int64, requestID string) error
|
||||
GetUserConcurrency(ctx context.Context, userID int64) (int, error)
|
||||
|
||||
// 等待队列计数(只在首次创建时设置 TTL)
|
||||
IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error)
|
||||
DecrementWaitCount(ctx context.Context, userID int64) error
|
||||
|
||||
// 批量负载查询(只读)
|
||||
GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error)
|
||||
GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error)
|
||||
|
||||
// 清理过期槽位(后台任务)
|
||||
CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error
|
||||
|
||||
// 启动时清理旧进程遗留槽位与等待计数
|
||||
CleanupStaleProcessSlots(ctx context.Context, activeRequestPrefix string) error
|
||||
}
|
||||
|
||||
var (
|
||||
requestIDPrefix = initRequestIDPrefix()
|
||||
requestIDCounter atomic.Uint64
|
||||
)
|
||||
|
||||
func initRequestIDPrefix() string {
|
||||
b := make([]byte, 8)
|
||||
if _, err := rand.Read(b); err == nil {
|
||||
return "r" + strconv.FormatUint(binary.BigEndian.Uint64(b), 36)
|
||||
}
|
||||
fallback := uint64(time.Now().UnixNano()) ^ (uint64(os.Getpid()) << 16)
|
||||
return "r" + strconv.FormatUint(fallback, 36)
|
||||
}
|
||||
|
||||
func RequestIDPrefix() string {
|
||||
return requestIDPrefix
|
||||
}
|
||||
|
||||
func generateRequestID() string {
|
||||
seq := requestIDCounter.Add(1)
|
||||
return requestIDPrefix + "-" + strconv.FormatUint(seq, 36)
|
||||
}
|
||||
|
||||
func (s *ConcurrencyService) CleanupStaleProcessSlots(ctx context.Context) error {
|
||||
if s == nil || s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.CleanupStaleProcessSlots(ctx, RequestIDPrefix())
|
||||
}
|
||||
|
||||
const (
|
||||
// Default extra wait slots beyond concurrency limit
|
||||
defaultExtraWaitSlots = 20
|
||||
)
|
||||
|
||||
// ConcurrencyService manages concurrent request limiting for accounts and users
|
||||
type ConcurrencyService struct {
|
||||
cache ConcurrencyCache
|
||||
}
|
||||
|
||||
// NewConcurrencyService creates a new ConcurrencyService
|
||||
func NewConcurrencyService(cache ConcurrencyCache) *ConcurrencyService {
|
||||
return &ConcurrencyService{cache: cache}
|
||||
}
|
||||
|
||||
// AcquireResult represents the result of acquiring a concurrency slot
|
||||
type AcquireResult struct {
|
||||
Acquired bool
|
||||
ReleaseFunc func() // Must be called when done (typically via defer)
|
||||
}
|
||||
|
||||
type AccountWithConcurrency struct {
|
||||
ID int64
|
||||
MaxConcurrency int
|
||||
}
|
||||
|
||||
type UserWithConcurrency struct {
|
||||
ID int64
|
||||
MaxConcurrency int
|
||||
}
|
||||
|
||||
type AccountLoadInfo struct {
|
||||
AccountID int64
|
||||
CurrentConcurrency int
|
||||
WaitingCount int
|
||||
LoadRate int // 0-100+ (percent)
|
||||
}
|
||||
|
||||
type UserLoadInfo struct {
|
||||
UserID int64
|
||||
CurrentConcurrency int
|
||||
WaitingCount int
|
||||
LoadRate int // 0-100+ (percent)
|
||||
}
|
||||
|
||||
// AcquireAccountSlot attempts to acquire a concurrency slot for an account.
|
||||
// If the account is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireAccountSlot(ctx context.Context, accountID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
// If maxConcurrency is 0 or negative, no limit
|
||||
if maxConcurrency <= 0 {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {}, // no-op
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Generate unique request ID for this slot
|
||||
requestID := generateRequestID()
|
||||
|
||||
acquired, err := s.cache.AcquireAccountSlot(ctx, accountID, maxConcurrency, requestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseAccountSlot(bgCtx, accountID, requestID); err != nil {
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: failed to release account slot for %d (req=%s): %v", accountID, requestID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &AcquireResult{
|
||||
Acquired: false,
|
||||
ReleaseFunc: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// AcquireUserSlot attempts to acquire a concurrency slot for a user.
|
||||
// If the user is at max concurrency, it waits until a slot is available or timeout.
|
||||
// Returns a release function that MUST be called when the request completes.
|
||||
func (s *ConcurrencyService) AcquireUserSlot(ctx context.Context, userID int64, maxConcurrency int) (*AcquireResult, error) {
|
||||
// If maxConcurrency is 0 or negative, no limit
|
||||
if maxConcurrency <= 0 {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {}, // no-op
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Generate unique request ID for this slot
|
||||
requestID := generateRequestID()
|
||||
|
||||
acquired, err := s.cache.AcquireUserSlot(ctx, userID, maxConcurrency, requestID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if acquired {
|
||||
return &AcquireResult{
|
||||
Acquired: true,
|
||||
ReleaseFunc: func() {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
if err := s.cache.ReleaseUserSlot(bgCtx, userID, requestID); err != nil {
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: failed to release user slot for %d (req=%s): %v", userID, requestID, err)
|
||||
}
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
return &AcquireResult{
|
||||
Acquired: false,
|
||||
ReleaseFunc: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ============================================
|
||||
// Wait Queue Count Methods
|
||||
// ============================================
|
||||
|
||||
// IncrementWaitCount attempts to increment the wait queue counter for a user.
|
||||
// Returns true if successful, false if the wait queue is full.
|
||||
// maxWait should be user.Concurrency + defaultExtraWaitSlots
|
||||
func (s *ConcurrencyService) IncrementWaitCount(ctx context.Context, userID int64, maxWait int) (bool, error) {
|
||||
if s.cache == nil {
|
||||
// Redis not available, allow request
|
||||
return true, nil
|
||||
}
|
||||
|
||||
result, err := s.cache.IncrementWaitCount(ctx, userID, maxWait)
|
||||
if err != nil {
|
||||
// On error, allow the request to proceed (fail open)
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for user %d: %v", userID, err)
|
||||
return true, nil
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecrementWaitCount decrements the wait queue counter for a user.
|
||||
// Should be called when a request completes or exits the wait queue.
|
||||
func (s *ConcurrencyService) DecrementWaitCount(ctx context.Context, userID int64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Use background context to ensure decrement even if original context is cancelled
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.cache.DecrementWaitCount(bgCtx, userID); err != nil {
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for user %d: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// IncrementAccountWaitCount increments the wait queue counter for an account.
|
||||
func (s *ConcurrencyService) IncrementAccountWaitCount(ctx context.Context, accountID int64, maxWait int) (bool, error) {
|
||||
if s.cache == nil {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
result, err := s.cache.IncrementAccountWaitCount(ctx, accountID, maxWait)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: increment wait count failed for account %d: %v", accountID, err)
|
||||
return true, nil
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecrementAccountWaitCount decrements the wait queue counter for an account.
|
||||
func (s *ConcurrencyService) DecrementAccountWaitCount(ctx context.Context, accountID int64) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := s.cache.DecrementAccountWaitCount(bgCtx, accountID); err != nil {
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: decrement wait count failed for account %d: %v", accountID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// GetAccountWaitingCount gets current wait queue count for an account.
|
||||
func (s *ConcurrencyService) GetAccountWaitingCount(ctx context.Context, accountID int64) (int, error) {
|
||||
if s.cache == nil {
|
||||
return 0, nil
|
||||
}
|
||||
return s.cache.GetAccountWaitingCount(ctx, accountID)
|
||||
}
|
||||
|
||||
// CalculateMaxWait calculates the maximum wait queue size for a user
|
||||
// maxWait = userConcurrency + defaultExtraWaitSlots
|
||||
func CalculateMaxWait(userConcurrency int) int {
|
||||
if userConcurrency <= 0 {
|
||||
userConcurrency = 1
|
||||
}
|
||||
return userConcurrency + defaultExtraWaitSlots
|
||||
}
|
||||
|
||||
// GetAccountsLoadBatch returns load info for multiple accounts.
|
||||
func (s *ConcurrencyService) GetAccountsLoadBatch(ctx context.Context, accounts []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
if s.cache == nil {
|
||||
return map[int64]*AccountLoadInfo{}, nil
|
||||
}
|
||||
return s.cache.GetAccountsLoadBatch(ctx, accounts)
|
||||
}
|
||||
|
||||
// GetUsersLoadBatch returns load info for multiple users.
|
||||
func (s *ConcurrencyService) GetUsersLoadBatch(ctx context.Context, users []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
||||
if s.cache == nil {
|
||||
return map[int64]*UserLoadInfo{}, nil
|
||||
}
|
||||
return s.cache.GetUsersLoadBatch(ctx, users)
|
||||
}
|
||||
|
||||
// CleanupExpiredAccountSlots removes expired slots for one account (background task).
|
||||
func (s *ConcurrencyService) CleanupExpiredAccountSlots(ctx context.Context, accountID int64) error {
|
||||
if s.cache == nil {
|
||||
return nil
|
||||
}
|
||||
return s.cache.CleanupExpiredAccountSlots(ctx, accountID)
|
||||
}
|
||||
|
||||
// StartSlotCleanupWorker starts a background cleanup worker for expired account slots.
|
||||
func (s *ConcurrencyService) StartSlotCleanupWorker(accountRepo AccountRepository, interval time.Duration) {
|
||||
if s == nil || s.cache == nil || accountRepo == nil || interval <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
runCleanup := func() {
|
||||
listCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
accounts, err := accountRepo.ListSchedulable(listCtx)
|
||||
cancel()
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: list schedulable accounts failed: %v", err)
|
||||
return
|
||||
}
|
||||
for _, account := range accounts {
|
||||
accountCtx, accountCancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||
err := s.cache.CleanupExpiredAccountSlots(accountCtx, account.ID)
|
||||
accountCancel()
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.concurrency", "Warning: cleanup expired slots failed for account %d: %v", account.ID, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
runCleanup()
|
||||
for range ticker.C {
|
||||
runCleanup()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// GetAccountConcurrencyBatch gets current concurrency counts for multiple accounts
|
||||
// Returns a map of accountID -> current concurrency count
|
||||
func (s *ConcurrencyService) GetAccountConcurrencyBatch(ctx context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
if len(accountIDs) == 0 {
|
||||
return map[int64]int{}, nil
|
||||
}
|
||||
if s.cache == nil {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
result[accountID] = 0
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
return s.cache.GetAccountConcurrencyBatch(ctx, accountIDs)
|
||||
}
|
||||
@@ -0,0 +1,337 @@
|
||||
//go:build unit
|
||||
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// stubConcurrencyCacheForTest 用于并发服务单元测试的缓存桩
|
||||
type stubConcurrencyCacheForTest struct {
|
||||
acquireResult bool
|
||||
acquireErr error
|
||||
releaseErr error
|
||||
concurrency int
|
||||
concurrencyErr error
|
||||
waitAllowed bool
|
||||
waitErr error
|
||||
waitCount int
|
||||
waitCountErr error
|
||||
loadBatch map[int64]*AccountLoadInfo
|
||||
loadBatchErr error
|
||||
usersLoadBatch map[int64]*UserLoadInfo
|
||||
usersLoadErr error
|
||||
cleanupErr error
|
||||
|
||||
// 记录调用
|
||||
releasedAccountIDs []int64
|
||||
releasedRequestIDs []string
|
||||
}
|
||||
|
||||
var _ ConcurrencyCache = (*stubConcurrencyCacheForTest)(nil)
|
||||
|
||||
func (c *stubConcurrencyCacheForTest) AcquireAccountSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) {
|
||||
return c.acquireResult, c.acquireErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) ReleaseAccountSlot(_ context.Context, accountID int64, requestID string) error {
|
||||
c.releasedAccountIDs = append(c.releasedAccountIDs, accountID)
|
||||
c.releasedRequestIDs = append(c.releasedRequestIDs, requestID)
|
||||
return c.releaseErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetAccountConcurrency(_ context.Context, _ int64) (int, error) {
|
||||
return c.concurrency, c.concurrencyErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetAccountConcurrencyBatch(_ context.Context, accountIDs []int64) (map[int64]int, error) {
|
||||
result := make(map[int64]int, len(accountIDs))
|
||||
for _, accountID := range accountIDs {
|
||||
if c.concurrencyErr != nil {
|
||||
return nil, c.concurrencyErr
|
||||
}
|
||||
result[accountID] = c.concurrency
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) IncrementAccountWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
|
||||
return c.waitAllowed, c.waitErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) DecrementAccountWaitCount(_ context.Context, _ int64) error {
|
||||
return nil
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetAccountWaitingCount(_ context.Context, _ int64) (int, error) {
|
||||
return c.waitCount, c.waitCountErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) AcquireUserSlot(_ context.Context, _ int64, _ int, _ string) (bool, error) {
|
||||
return c.acquireResult, c.acquireErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) ReleaseUserSlot(_ context.Context, _ int64, _ string) error {
|
||||
return c.releaseErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetUserConcurrency(_ context.Context, _ int64) (int, error) {
|
||||
return c.concurrency, c.concurrencyErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) IncrementWaitCount(_ context.Context, _ int64, _ int) (bool, error) {
|
||||
return c.waitAllowed, c.waitErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) DecrementWaitCount(_ context.Context, _ int64) error {
|
||||
return nil
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetAccountsLoadBatch(_ context.Context, _ []AccountWithConcurrency) (map[int64]*AccountLoadInfo, error) {
|
||||
return c.loadBatch, c.loadBatchErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) GetUsersLoadBatch(_ context.Context, _ []UserWithConcurrency) (map[int64]*UserLoadInfo, error) {
|
||||
return c.usersLoadBatch, c.usersLoadErr
|
||||
}
|
||||
func (c *stubConcurrencyCacheForTest) CleanupExpiredAccountSlots(_ context.Context, _ int64) error {
|
||||
return c.cleanupErr
|
||||
}
|
||||
|
||||
func (c *stubConcurrencyCacheForTest) CleanupStaleProcessSlots(_ context.Context, _ string) error {
|
||||
return c.cleanupErr
|
||||
}
|
||||
|
||||
type trackingConcurrencyCache struct {
|
||||
stubConcurrencyCacheForTest
|
||||
cleanupPrefix string
|
||||
}
|
||||
|
||||
func (c *trackingConcurrencyCache) CleanupStaleProcessSlots(_ context.Context, prefix string) error {
|
||||
c.cleanupPrefix = prefix
|
||||
return c.cleanupErr
|
||||
}
|
||||
|
||||
func TestCleanupStaleProcessSlots_NilCache(t *testing.T) {
|
||||
svc := &ConcurrencyService{cache: nil}
|
||||
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
|
||||
}
|
||||
|
||||
func TestCleanupStaleProcessSlots_DelegatesPrefix(t *testing.T) {
|
||||
cache := &trackingConcurrencyCache{}
|
||||
svc := NewConcurrencyService(cache)
|
||||
require.NoError(t, svc.CleanupStaleProcessSlots(context.Background()))
|
||||
require.Equal(t, RequestIDPrefix(), cache.cleanupPrefix)
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlot_Success(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{acquireResult: true}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
result, err := svc.AcquireAccountSlot(context.Background(), 1, 5)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Acquired)
|
||||
require.NotNil(t, result.ReleaseFunc)
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlot_Failure(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{acquireResult: false}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
result, err := svc.AcquireAccountSlot(context.Background(), 1, 5)
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.Acquired)
|
||||
require.Nil(t, result.ReleaseFunc)
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlot_UnlimitedConcurrency(t *testing.T) {
|
||||
svc := NewConcurrencyService(&stubConcurrencyCacheForTest{})
|
||||
|
||||
for _, maxConcurrency := range []int{0, -1} {
|
||||
result, err := svc.AcquireAccountSlot(context.Background(), 1, maxConcurrency)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Acquired, "maxConcurrency=%d 应无限制通过", maxConcurrency)
|
||||
require.NotNil(t, result.ReleaseFunc, "ReleaseFunc 应为 no-op 函数")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlot_CacheError(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{acquireErr: errors.New("redis down")}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
result, err := svc.AcquireAccountSlot(context.Background(), 1, 5)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
}
|
||||
|
||||
func TestAcquireAccountSlot_ReleaseDecrements(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{acquireResult: true}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
result, err := svc.AcquireAccountSlot(context.Background(), 42, 5)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Acquired)
|
||||
|
||||
// 调用 ReleaseFunc 应释放槽位
|
||||
result.ReleaseFunc()
|
||||
|
||||
require.Len(t, cache.releasedAccountIDs, 1)
|
||||
require.Equal(t, int64(42), cache.releasedAccountIDs[0])
|
||||
require.Len(t, cache.releasedRequestIDs, 1)
|
||||
require.NotEmpty(t, cache.releasedRequestIDs[0], "requestID 不应为空")
|
||||
}
|
||||
|
||||
func TestAcquireUserSlot_IndependentFromAccount(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{acquireResult: true}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
// 用户槽位获取应独立于账户槽位
|
||||
result, err := svc.AcquireUserSlot(context.Background(), 100, 3)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Acquired)
|
||||
require.NotNil(t, result.ReleaseFunc)
|
||||
}
|
||||
|
||||
func TestAcquireUserSlot_UnlimitedConcurrency(t *testing.T) {
|
||||
svc := NewConcurrencyService(&stubConcurrencyCacheForTest{})
|
||||
|
||||
result, err := svc.AcquireUserSlot(context.Background(), 1, 0)
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.Acquired)
|
||||
}
|
||||
|
||||
func TestGenerateRequestID_UsesStablePrefixAndMonotonicCounter(t *testing.T) {
|
||||
id1 := generateRequestID()
|
||||
id2 := generateRequestID()
|
||||
require.NotEmpty(t, id1)
|
||||
require.NotEmpty(t, id2)
|
||||
|
||||
p1 := strings.Split(id1, "-")
|
||||
p2 := strings.Split(id2, "-")
|
||||
require.Len(t, p1, 2)
|
||||
require.Len(t, p2, 2)
|
||||
require.Equal(t, p1[0], p2[0], "同一进程前缀应保持一致")
|
||||
|
||||
n1, err := strconv.ParseUint(p1[1], 36, 64)
|
||||
require.NoError(t, err)
|
||||
n2, err := strconv.ParseUint(p2[1], 36, 64)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, n1+1, n2, "计数器应单调递增")
|
||||
}
|
||||
|
||||
func TestGetAccountsLoadBatch_ReturnsCorrectData(t *testing.T) {
|
||||
expected := map[int64]*AccountLoadInfo{
|
||||
1: {AccountID: 1, CurrentConcurrency: 3, WaitingCount: 0, LoadRate: 60},
|
||||
2: {AccountID: 2, CurrentConcurrency: 5, WaitingCount: 2, LoadRate: 100},
|
||||
}
|
||||
cache := &stubConcurrencyCacheForTest{loadBatch: expected}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
accounts := []AccountWithConcurrency{
|
||||
{ID: 1, MaxConcurrency: 5},
|
||||
{ID: 2, MaxConcurrency: 5},
|
||||
}
|
||||
result, err := svc.GetAccountsLoadBatch(context.Background(), accounts)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
func TestGetAccountsLoadBatch_NilCache(t *testing.T) {
|
||||
svc := &ConcurrencyService{cache: nil}
|
||||
|
||||
result, err := svc.GetAccountsLoadBatch(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, result)
|
||||
}
|
||||
|
||||
func TestIncrementWaitCount_Success(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{waitAllowed: true}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
|
||||
require.NoError(t, err)
|
||||
require.True(t, allowed)
|
||||
}
|
||||
|
||||
func TestIncrementWaitCount_QueueFull(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{waitAllowed: false}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
|
||||
require.NoError(t, err)
|
||||
require.False(t, allowed)
|
||||
}
|
||||
|
||||
func TestIncrementWaitCount_FailOpen(t *testing.T) {
|
||||
// Redis 错误时应 fail-open(允许请求通过)
|
||||
cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis timeout")}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
|
||||
require.NoError(t, err, "Redis 错误不应传播")
|
||||
require.True(t, allowed, "Redis 错误时应 fail-open")
|
||||
}
|
||||
|
||||
func TestIncrementWaitCount_NilCache(t *testing.T) {
|
||||
svc := &ConcurrencyService{cache: nil}
|
||||
|
||||
allowed, err := svc.IncrementWaitCount(context.Background(), 1, 25)
|
||||
require.NoError(t, err)
|
||||
require.True(t, allowed, "nil cache 应 fail-open")
|
||||
}
|
||||
|
||||
func TestCalculateMaxWait(t *testing.T) {
|
||||
tests := []struct {
|
||||
concurrency int
|
||||
expected int
|
||||
}{
|
||||
{5, 25}, // 5 + 20
|
||||
{1, 21}, // 1 + 20
|
||||
{0, 21}, // min(1) + 20
|
||||
{-1, 21}, // min(1) + 20
|
||||
{10, 30}, // 10 + 20
|
||||
}
|
||||
for _, tt := range tests {
|
||||
result := CalculateMaxWait(tt.concurrency)
|
||||
require.Equal(t, tt.expected, result, "CalculateMaxWait(%d)", tt.concurrency)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountWaitingCount(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{waitCount: 5}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
count, err := svc.GetAccountWaitingCount(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 5, count)
|
||||
}
|
||||
|
||||
func TestGetAccountWaitingCount_NilCache(t *testing.T) {
|
||||
svc := &ConcurrencyService{cache: nil}
|
||||
|
||||
count, err := svc.GetAccountWaitingCount(context.Background(), 1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, count)
|
||||
}
|
||||
|
||||
func TestGetAccountConcurrencyBatch(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{concurrency: 3}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
result, err := svc.GetAccountConcurrencyBatch(context.Background(), []int64{1, 2, 3})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, result, 3)
|
||||
for _, id := range []int64{1, 2, 3} {
|
||||
require.Equal(t, 3, result[id])
|
||||
}
|
||||
}
|
||||
|
||||
func TestIncrementAccountWaitCount_FailOpen(t *testing.T) {
|
||||
cache := &stubConcurrencyCacheForTest{waitErr: errors.New("redis error")}
|
||||
svc := NewConcurrencyService(cache)
|
||||
|
||||
allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10)
|
||||
require.NoError(t, err, "Redis 错误不应传播")
|
||||
require.True(t, allowed, "Redis 错误时应 fail-open")
|
||||
}
|
||||
|
||||
func TestIncrementAccountWaitCount_NilCache(t *testing.T) {
|
||||
svc := &ConcurrencyService{cache: nil}
|
||||
|
||||
allowed, err := svc.IncrementAccountWaitCount(context.Background(), 1, 10)
|
||||
require.NoError(t, err)
|
||||
require.True(t, allowed)
|
||||
}
|
||||
@@ -0,0 +1,112 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildSelectedSet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ids []string
|
||||
wantNil bool
|
||||
wantSize int
|
||||
}{
|
||||
{
|
||||
name: "nil input returns nil (backward compatible: create all)",
|
||||
ids: nil,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "empty slice returns empty map (create none)",
|
||||
ids: []string{},
|
||||
wantNil: false,
|
||||
wantSize: 0,
|
||||
},
|
||||
{
|
||||
name: "single ID",
|
||||
ids: []string{"abc-123"},
|
||||
wantNil: false,
|
||||
wantSize: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple IDs",
|
||||
ids: []string{"a", "b", "c"},
|
||||
wantNil: false,
|
||||
wantSize: 3,
|
||||
},
|
||||
{
|
||||
name: "duplicate IDs are deduplicated",
|
||||
ids: []string{"a", "a", "b"},
|
||||
wantNil: false,
|
||||
wantSize: 2,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := buildSelectedSet(tt.ids)
|
||||
if tt.wantNil {
|
||||
if got != nil {
|
||||
t.Errorf("buildSelectedSet(%v) = %v, want nil", tt.ids, got)
|
||||
}
|
||||
return
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("buildSelectedSet(%v) = nil, want non-nil map", tt.ids)
|
||||
}
|
||||
if len(got) != tt.wantSize {
|
||||
t.Errorf("buildSelectedSet(%v) has %d entries, want %d", tt.ids, len(got), tt.wantSize)
|
||||
}
|
||||
// Verify all unique IDs are present
|
||||
for _, id := range tt.ids {
|
||||
if _, ok := got[id]; !ok {
|
||||
t.Errorf("buildSelectedSet(%v) missing key %q", tt.ids, id)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldCreateAccount(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
crsID string
|
||||
selectedSet map[string]struct{}
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil set allows all (backward compatible)",
|
||||
crsID: "any-id",
|
||||
selectedSet: nil,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty set blocks all",
|
||||
crsID: "any-id",
|
||||
selectedSet: map[string]struct{}{},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "ID in set is allowed",
|
||||
crsID: "abc-123",
|
||||
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "ID not in set is blocked",
|
||||
crsID: "xyz-789",
|
||||
selectedSet: map[string]struct{}{"abc-123": {}, "def-456": {}},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := shouldCreateAccount(tt.crsID, tt.selectedSet)
|
||||
if got != tt.want {
|
||||
t.Errorf("shouldCreateAccount(%q, %v) = %v, want %v",
|
||||
tt.crsID, tt.selectedSet, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,322 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDashboardAggregationTimeout = 2 * time.Minute
|
||||
defaultDashboardAggregationBackfillTimeout = 30 * time.Minute
|
||||
dashboardAggregationRetentionInterval = 6 * time.Hour
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrDashboardBackfillDisabled 当配置禁用回填时返回。
|
||||
ErrDashboardBackfillDisabled = errors.New("仪表盘聚合回填已禁用")
|
||||
// ErrDashboardBackfillTooLarge 当回填跨度超过限制时返回。
|
||||
ErrDashboardBackfillTooLarge = errors.New("回填时间跨度过大")
|
||||
errDashboardAggregationRunning = errors.New("聚合作业正在运行")
|
||||
)
|
||||
|
||||
// DashboardAggregationRepository 定义仪表盘预聚合仓储接口。
|
||||
type DashboardAggregationRepository interface {
|
||||
AggregateRange(ctx context.Context, start, end time.Time) error
|
||||
// RecomputeRange 重新计算指定时间范围内的聚合数据(包含活跃用户等派生表)。
|
||||
// 设计目的:当 usage_logs 被批量删除/回滚后,确保聚合表可恢复一致性。
|
||||
RecomputeRange(ctx context.Context, start, end time.Time) error
|
||||
GetAggregationWatermark(ctx context.Context) (time.Time, error)
|
||||
UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error
|
||||
CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error
|
||||
CleanupUsageLogs(ctx context.Context, cutoff time.Time) error
|
||||
CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error
|
||||
EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error
|
||||
}
|
||||
|
||||
// DashboardAggregationService 负责定时聚合与回填。
|
||||
type DashboardAggregationService struct {
|
||||
repo DashboardAggregationRepository
|
||||
timingWheel *TimingWheelService
|
||||
cfg config.DashboardAggregationConfig
|
||||
running int32
|
||||
lastRetentionCleanup atomic.Value // time.Time
|
||||
}
|
||||
|
||||
// NewDashboardAggregationService 创建聚合服务。
|
||||
func NewDashboardAggregationService(repo DashboardAggregationRepository, timingWheel *TimingWheelService, cfg *config.Config) *DashboardAggregationService {
|
||||
var aggCfg config.DashboardAggregationConfig
|
||||
if cfg != nil {
|
||||
aggCfg = cfg.DashboardAgg
|
||||
}
|
||||
return &DashboardAggregationService{
|
||||
repo: repo,
|
||||
timingWheel: timingWheel,
|
||||
cfg: aggCfg,
|
||||
}
|
||||
}
|
||||
|
||||
// Start 启动定时聚合作业(重启生效配置)。
|
||||
func (s *DashboardAggregationService) Start() {
|
||||
if s == nil || s.repo == nil || s.timingWheel == nil {
|
||||
return
|
||||
}
|
||||
if !s.cfg.Enabled {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业已禁用")
|
||||
return
|
||||
}
|
||||
|
||||
interval := time.Duration(s.cfg.IntervalSeconds) * time.Second
|
||||
if interval <= 0 {
|
||||
interval = time.Minute
|
||||
}
|
||||
|
||||
if s.cfg.RecomputeDays > 0 {
|
||||
go s.recomputeRecentDays()
|
||||
}
|
||||
|
||||
s.timingWheel.ScheduleRecurring("dashboard:aggregation", interval, func() {
|
||||
s.runScheduledAggregation()
|
||||
})
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合作业启动 (interval=%v, lookback=%ds)", interval, s.cfg.LookbackSeconds)
|
||||
if !s.cfg.BackfillEnabled {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填已禁用,如需补齐保留窗口以外历史数据请手动回填")
|
||||
}
|
||||
}
|
||||
|
||||
// TriggerBackfill 触发回填(异步)。
|
||||
func (s *DashboardAggregationService) TriggerBackfill(start, end time.Time) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return errors.New("聚合服务未初始化")
|
||||
}
|
||||
if !s.cfg.BackfillEnabled {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填被拒绝: backfill_enabled=false")
|
||||
return ErrDashboardBackfillDisabled
|
||||
}
|
||||
if !end.After(start) {
|
||||
return errors.New("回填时间范围无效")
|
||||
}
|
||||
if s.cfg.BackfillMaxDays > 0 {
|
||||
maxRange := time.Duration(s.cfg.BackfillMaxDays) * 24 * time.Hour
|
||||
if end.Sub(start) > maxRange {
|
||||
return ErrDashboardBackfillTooLarge
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||
defer cancel()
|
||||
if err := s.backfillRange(ctx, start, end); err != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填失败: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
// TriggerRecomputeRange 触发指定范围的重新计算(异步)。
|
||||
// 与 TriggerBackfill 不同:
|
||||
// - 不依赖 backfill_enabled(这是内部一致性修复)
|
||||
// - 不更新 watermark(避免影响正常增量聚合游标)
|
||||
func (s *DashboardAggregationService) TriggerRecomputeRange(start, end time.Time) error {
|
||||
if s == nil || s.repo == nil {
|
||||
return errors.New("聚合服务未初始化")
|
||||
}
|
||||
if !s.cfg.Enabled {
|
||||
return errors.New("聚合服务已禁用")
|
||||
}
|
||||
if !end.After(start) {
|
||||
return errors.New("重新计算时间范围无效")
|
||||
}
|
||||
|
||||
go func() {
|
||||
const maxRetries = 3
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||
err := s.recomputeRange(ctx, start, end)
|
||||
cancel()
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, errDashboardAggregationRunning) {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算失败: %v", err)
|
||||
return
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算放弃: 聚合作业持续占用")
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) recomputeRecentDays() {
|
||||
days := s.cfg.RecomputeDays
|
||||
if days <= 0 {
|
||||
return
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
start := now.AddDate(0, 0, -days)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationBackfillTimeout)
|
||||
defer cancel()
|
||||
if err := s.backfillRange(ctx, start, now); err != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 启动重算失败: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) recomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return errDashboardAggregationRunning
|
||||
}
|
||||
defer atomic.StoreInt32(&s.running, 0)
|
||||
|
||||
jobStart := time.Now().UTC()
|
||||
if err := s.repo.RecomputeRange(ctx, start, end); err != nil {
|
||||
return err
|
||||
}
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 重新计算完成 (start=%s end=%s duration=%s)",
|
||||
start.UTC().Format(time.RFC3339),
|
||||
end.UTC().Format(time.RFC3339),
|
||||
time.Since(jobStart).String(),
|
||||
)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) runScheduledAggregation() {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return
|
||||
}
|
||||
defer atomic.StoreInt32(&s.running, 0)
|
||||
|
||||
jobStart := time.Now().UTC()
|
||||
ctx, cancel := context.WithTimeout(context.Background(), defaultDashboardAggregationTimeout)
|
||||
defer cancel()
|
||||
|
||||
now := time.Now().UTC()
|
||||
last, err := s.repo.GetAggregationWatermark(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 读取水位失败: %v", err)
|
||||
last = time.Unix(0, 0).UTC()
|
||||
}
|
||||
|
||||
lookback := time.Duration(s.cfg.LookbackSeconds) * time.Second
|
||||
epoch := time.Unix(0, 0).UTC()
|
||||
start := last.Add(-lookback)
|
||||
if !last.After(epoch) {
|
||||
retentionDays := s.cfg.Retention.UsageLogsDays
|
||||
if retentionDays <= 0 {
|
||||
retentionDays = 1
|
||||
}
|
||||
start = truncateToDayUTC(now.AddDate(0, 0, -retentionDays))
|
||||
} else if start.After(now) {
|
||||
start = now.Add(-lookback)
|
||||
}
|
||||
|
||||
if err := s.aggregateRange(ctx, start, now); err != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
updateErr := s.repo.UpdateAggregationWatermark(ctx, now)
|
||||
if updateErr != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr)
|
||||
}
|
||||
slog.Debug("[DashboardAggregation] 聚合完成",
|
||||
"start", start.Format(time.RFC3339),
|
||||
"end", now.Format(time.RFC3339),
|
||||
"duration", time.Since(jobStart).String(),
|
||||
"watermark_updated", updateErr == nil,
|
||||
)
|
||||
|
||||
s.maybeCleanupRetention(ctx, now)
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) backfillRange(ctx context.Context, start, end time.Time) error {
|
||||
if !atomic.CompareAndSwapInt32(&s.running, 0, 1) {
|
||||
return errDashboardAggregationRunning
|
||||
}
|
||||
defer atomic.StoreInt32(&s.running, 0)
|
||||
|
||||
jobStart := time.Now().UTC()
|
||||
startUTC := start.UTC()
|
||||
endUTC := end.UTC()
|
||||
if !endUTC.After(startUTC) {
|
||||
return errors.New("回填时间范围无效")
|
||||
}
|
||||
|
||||
cursor := truncateToDayUTC(startUTC)
|
||||
for cursor.Before(endUTC) {
|
||||
windowEnd := cursor.Add(24 * time.Hour)
|
||||
if windowEnd.After(endUTC) {
|
||||
windowEnd = endUTC
|
||||
}
|
||||
if err := s.aggregateRange(ctx, cursor, windowEnd); err != nil {
|
||||
return err
|
||||
}
|
||||
cursor = windowEnd
|
||||
}
|
||||
|
||||
updateErr := s.repo.UpdateAggregationWatermark(ctx, endUTC)
|
||||
if updateErr != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 更新水位失败: %v", updateErr)
|
||||
}
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 回填聚合完成 (start=%s end=%s duration=%s watermark_updated=%t)",
|
||||
startUTC.Format(time.RFC3339),
|
||||
endUTC.Format(time.RFC3339),
|
||||
time.Since(jobStart).String(),
|
||||
updateErr == nil,
|
||||
)
|
||||
|
||||
s.maybeCleanupRetention(ctx, endUTC)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) aggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
if !end.After(start) {
|
||||
return nil
|
||||
}
|
||||
if err := s.repo.EnsureUsageLogsPartitions(ctx, end); err != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 分区检查失败: %v", err)
|
||||
}
|
||||
return s.repo.AggregateRange(ctx, start, end)
|
||||
}
|
||||
|
||||
func (s *DashboardAggregationService) maybeCleanupRetention(ctx context.Context, now time.Time) {
|
||||
lastAny := s.lastRetentionCleanup.Load()
|
||||
if lastAny != nil {
|
||||
if last, ok := lastAny.(time.Time); ok && now.Sub(last) < dashboardAggregationRetentionInterval {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
hourlyCutoff := now.AddDate(0, 0, -s.cfg.Retention.HourlyDays)
|
||||
dailyCutoff := now.AddDate(0, 0, -s.cfg.Retention.DailyDays)
|
||||
usageCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageLogsDays)
|
||||
dedupCutoff := now.AddDate(0, 0, -s.cfg.Retention.UsageBillingDedupDays)
|
||||
|
||||
aggErr := s.repo.CleanupAggregates(ctx, hourlyCutoff, dailyCutoff)
|
||||
if aggErr != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] 聚合保留清理失败: %v", aggErr)
|
||||
}
|
||||
usageErr := s.repo.CleanupUsageLogs(ctx, usageCutoff)
|
||||
if usageErr != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_logs 保留清理失败: %v", usageErr)
|
||||
}
|
||||
dedupErr := s.repo.CleanupUsageBillingDedup(ctx, dedupCutoff)
|
||||
if dedupErr != nil {
|
||||
logger.LegacyPrintf("service.dashboard_aggregation", "[DashboardAggregation] usage_billing_dedup 保留清理失败: %v", dedupErr)
|
||||
}
|
||||
if aggErr == nil && usageErr == nil && dedupErr == nil {
|
||||
s.lastRetentionCleanup.Store(now)
|
||||
}
|
||||
}
|
||||
|
||||
func truncateToDayUTC(t time.Time) time.Time {
|
||||
t = t.UTC()
|
||||
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC)
|
||||
}
|
||||
@@ -0,0 +1,168 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type dashboardAggregationRepoTestStub struct {
|
||||
aggregateCalls int
|
||||
recomputeCalls int
|
||||
cleanupUsageCalls int
|
||||
cleanupDedupCalls int
|
||||
ensurePartitionCalls int
|
||||
lastStart time.Time
|
||||
lastEnd time.Time
|
||||
watermark time.Time
|
||||
aggregateErr error
|
||||
cleanupAggregatesErr error
|
||||
cleanupUsageErr error
|
||||
cleanupDedupErr error
|
||||
ensurePartitionErr error
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
s.aggregateCalls++
|
||||
s.lastStart = start
|
||||
s.lastEnd = end
|
||||
return s.aggregateErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
s.recomputeCalls++
|
||||
return s.AggregateRange(ctx, start, end)
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
return s.watermark, nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||
return s.cleanupAggregatesErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
s.cleanupUsageCalls++
|
||||
return s.cleanupUsageErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
s.cleanupDedupCalls++
|
||||
return s.cleanupDedupErr
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoTestStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
s.ensurePartitionCalls++
|
||||
return s.ensurePartitionErr
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_RunScheduledAggregation_EpochUsesRetentionStart(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{watermark: time.Unix(0, 0).UTC()}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
IntervalSeconds: 60,
|
||||
LookbackSeconds: 120,
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.runScheduledAggregation()
|
||||
|
||||
require.Equal(t, 1, repo.aggregateCalls)
|
||||
require.False(t, repo.lastEnd.IsZero())
|
||||
require.Equal(t, truncateToDayUTC(repo.lastEnd.AddDate(0, 0, -1)), repo.lastStart)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_CleanupRetentionFailure_DoesNotRecord(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{cleanupAggregatesErr: errors.New("清理失败")}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||
|
||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||
require.Equal(t, 1, repo.cleanupUsageCalls)
|
||||
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_CleanupDedupFailure_DoesNotRecord(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{cleanupDedupErr: errors.New("dedup cleanup failed")}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.maybeCleanupRetention(context.Background(), time.Now().UTC())
|
||||
|
||||
require.Nil(t, svc.lastRetentionCleanup.Load())
|
||||
require.Equal(t, 1, repo.cleanupDedupCalls)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_PartitionFailure_DoesNotAggregate(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{ensurePartitionErr: errors.New("partition failed")}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
IntervalSeconds: 60,
|
||||
LookbackSeconds: 120,
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 1,
|
||||
UsageBillingDedupDays: 2,
|
||||
HourlyDays: 1,
|
||||
DailyDays: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
svc.runScheduledAggregation()
|
||||
|
||||
require.Equal(t, 1, repo.ensurePartitionCalls)
|
||||
require.Equal(t, 1, repo.aggregateCalls)
|
||||
}
|
||||
|
||||
func TestDashboardAggregationService_TriggerBackfill_TooLarge(t *testing.T) {
|
||||
repo := &dashboardAggregationRepoTestStub{}
|
||||
svc := &DashboardAggregationService{
|
||||
repo: repo,
|
||||
cfg: config.DashboardAggregationConfig{
|
||||
BackfillEnabled: true,
|
||||
BackfillMaxDays: 1,
|
||||
},
|
||||
}
|
||||
|
||||
start := time.Now().AddDate(0, 0, -3)
|
||||
end := time.Now()
|
||||
err := svc.TriggerBackfill(start, end)
|
||||
require.ErrorIs(t, err, ErrDashboardBackfillTooLarge)
|
||||
require.Equal(t, 0, repo.aggregateCalls)
|
||||
}
|
||||
@@ -0,0 +1,352 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultDashboardStatsFreshTTL = 15 * time.Second
|
||||
defaultDashboardStatsCacheTTL = 30 * time.Second
|
||||
defaultDashboardStatsRefreshTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// ErrDashboardStatsCacheMiss 标记仪表盘缓存未命中。
|
||||
var ErrDashboardStatsCacheMiss = errors.New("仪表盘缓存未命中")
|
||||
|
||||
// DashboardStatsCache 定义仪表盘统计缓存接口。
|
||||
type DashboardStatsCache interface {
|
||||
GetDashboardStats(ctx context.Context) (string, error)
|
||||
SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error
|
||||
DeleteDashboardStats(ctx context.Context) error
|
||||
}
|
||||
|
||||
type dashboardStatsRangeFetcher interface {
|
||||
GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error)
|
||||
}
|
||||
|
||||
type dashboardStatsCacheEntry struct {
|
||||
Stats *usagestats.DashboardStats `json:"stats"`
|
||||
UpdatedAt int64 `json:"updated_at"`
|
||||
}
|
||||
|
||||
// DashboardService 提供管理员仪表盘统计服务。
|
||||
type DashboardService struct {
|
||||
usageRepo UsageLogRepository
|
||||
aggRepo DashboardAggregationRepository
|
||||
cache DashboardStatsCache
|
||||
cacheFreshTTL time.Duration
|
||||
cacheTTL time.Duration
|
||||
refreshTimeout time.Duration
|
||||
refreshing int32
|
||||
aggEnabled bool
|
||||
aggInterval time.Duration
|
||||
aggLookback time.Duration
|
||||
aggUsageDays int
|
||||
}
|
||||
|
||||
func NewDashboardService(usageRepo UsageLogRepository, aggRepo DashboardAggregationRepository, cache DashboardStatsCache, cfg *config.Config) *DashboardService {
|
||||
freshTTL := defaultDashboardStatsFreshTTL
|
||||
cacheTTL := defaultDashboardStatsCacheTTL
|
||||
refreshTimeout := defaultDashboardStatsRefreshTimeout
|
||||
aggEnabled := true
|
||||
aggInterval := time.Minute
|
||||
aggLookback := 2 * time.Minute
|
||||
aggUsageDays := 90
|
||||
if cfg != nil {
|
||||
if !cfg.Dashboard.Enabled {
|
||||
cache = nil
|
||||
}
|
||||
if cfg.Dashboard.StatsFreshTTLSeconds > 0 {
|
||||
freshTTL = time.Duration(cfg.Dashboard.StatsFreshTTLSeconds) * time.Second
|
||||
}
|
||||
if cfg.Dashboard.StatsTTLSeconds > 0 {
|
||||
cacheTTL = time.Duration(cfg.Dashboard.StatsTTLSeconds) * time.Second
|
||||
}
|
||||
if cfg.Dashboard.StatsRefreshTimeoutSeconds > 0 {
|
||||
refreshTimeout = time.Duration(cfg.Dashboard.StatsRefreshTimeoutSeconds) * time.Second
|
||||
}
|
||||
aggEnabled = cfg.DashboardAgg.Enabled
|
||||
if cfg.DashboardAgg.IntervalSeconds > 0 {
|
||||
aggInterval = time.Duration(cfg.DashboardAgg.IntervalSeconds) * time.Second
|
||||
}
|
||||
if cfg.DashboardAgg.LookbackSeconds > 0 {
|
||||
aggLookback = time.Duration(cfg.DashboardAgg.LookbackSeconds) * time.Second
|
||||
}
|
||||
if cfg.DashboardAgg.Retention.UsageLogsDays > 0 {
|
||||
aggUsageDays = cfg.DashboardAgg.Retention.UsageLogsDays
|
||||
}
|
||||
}
|
||||
if aggRepo == nil {
|
||||
aggEnabled = false
|
||||
}
|
||||
return &DashboardService{
|
||||
usageRepo: usageRepo,
|
||||
aggRepo: aggRepo,
|
||||
cache: cache,
|
||||
cacheFreshTTL: freshTTL,
|
||||
cacheTTL: cacheTTL,
|
||||
refreshTimeout: refreshTimeout,
|
||||
aggEnabled: aggEnabled,
|
||||
aggInterval: aggInterval,
|
||||
aggLookback: aggLookback,
|
||||
aggUsageDays: aggUsageDays,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
if s.cache != nil {
|
||||
cached, fresh, err := s.getCachedDashboardStats(ctx)
|
||||
if err == nil && cached != nil {
|
||||
s.refreshAggregationStaleness(cached)
|
||||
if !fresh {
|
||||
s.refreshDashboardStatsAsync()
|
||||
}
|
||||
return cached, nil
|
||||
}
|
||||
if err != nil && !errors.Is(err, ErrDashboardStatsCacheMiss) {
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存读取失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
stats, err := s.refreshDashboardStats(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get dashboard stats: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUsageTrendWithFilters(ctx context.Context, startTime, endTime time.Time, granularity string, userID, apiKeyID, accountID, groupID int64, model string, requestType *int16, stream *bool, billingType *int8) ([]usagestats.TrendDataPoint, error) {
|
||||
trend, err := s.usageRepo.GetUsageTrendWithFilters(ctx, startTime, endTime, granularity, userID, apiKeyID, accountID, groupID, model, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get usage trend with filters: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetModelStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.ModelStat, error) {
|
||||
stats, err := s.usageRepo.GetModelStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get model stats with filters: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetGroupStatsWithFilters(ctx context.Context, startTime, endTime time.Time, userID, apiKeyID, accountID, groupID int64, requestType *int16, stream *bool, billingType *int8) ([]usagestats.GroupStat, error) {
|
||||
stats, err := s.usageRepo.GetGroupStatsWithFilters(ctx, startTime, endTime, userID, apiKeyID, accountID, groupID, requestType, stream, billingType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group stats with filters: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) getCachedDashboardStats(ctx context.Context) (*usagestats.DashboardStats, bool, error) {
|
||||
data, err := s.cache.GetDashboardStats(ctx)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
var entry dashboardStatsCacheEntry
|
||||
if err := json.Unmarshal([]byte(data), &entry); err != nil {
|
||||
s.evictDashboardStatsCache(err)
|
||||
return nil, false, ErrDashboardStatsCacheMiss
|
||||
}
|
||||
if entry.Stats == nil {
|
||||
s.evictDashboardStatsCache(errors.New("仪表盘缓存缺少统计数据"))
|
||||
return nil, false, ErrDashboardStatsCacheMiss
|
||||
}
|
||||
|
||||
age := time.Since(time.Unix(entry.UpdatedAt, 0))
|
||||
return entry.Stats, age <= s.cacheFreshTTL, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) refreshDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
stats, err := s.fetchDashboardStats(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.applyAggregationStatus(ctx, stats)
|
||||
cacheCtx, cancel := s.cacheOperationContext()
|
||||
defer cancel()
|
||||
s.saveDashboardStatsCache(cacheCtx, stats)
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) refreshDashboardStatsAsync() {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
if !atomic.CompareAndSwapInt32(&s.refreshing, 0, 1) {
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer atomic.StoreInt32(&s.refreshing, 0)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), s.refreshTimeout)
|
||||
defer cancel()
|
||||
|
||||
stats, err := s.fetchDashboardStats(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异步刷新失败: %v", err)
|
||||
return
|
||||
}
|
||||
s.applyAggregationStatus(ctx, stats)
|
||||
cacheCtx, cancel := s.cacheOperationContext()
|
||||
defer cancel()
|
||||
s.saveDashboardStatsCache(cacheCtx, stats)
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *DashboardService) fetchDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
if !s.aggEnabled {
|
||||
if fetcher, ok := s.usageRepo.(dashboardStatsRangeFetcher); ok {
|
||||
now := time.Now().UTC()
|
||||
start := truncateToDayUTC(now.AddDate(0, 0, -s.aggUsageDays))
|
||||
return fetcher.GetDashboardStatsWithRange(ctx, start, now)
|
||||
}
|
||||
}
|
||||
return s.usageRepo.GetDashboardStats(ctx)
|
||||
}
|
||||
|
||||
func (s *DashboardService) saveDashboardStatsCache(ctx context.Context, stats *usagestats.DashboardStats) {
|
||||
if s.cache == nil || stats == nil {
|
||||
return
|
||||
}
|
||||
|
||||
entry := dashboardStatsCacheEntry{
|
||||
Stats: stats,
|
||||
UpdatedAt: time.Now().Unix(),
|
||||
}
|
||||
data, err := json.Marshal(entry)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存序列化失败: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.cache.SetDashboardStats(ctx, string(data), s.cacheTTL); err != nil {
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存写入失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardService) evictDashboardStatsCache(reason error) {
|
||||
if s.cache == nil {
|
||||
return
|
||||
}
|
||||
cacheCtx, cancel := s.cacheOperationContext()
|
||||
defer cancel()
|
||||
|
||||
if err := s.cache.DeleteDashboardStats(cacheCtx); err != nil {
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存清理失败: %v", err)
|
||||
}
|
||||
if reason != nil {
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 仪表盘缓存异常,已清理: %v", reason)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DashboardService) cacheOperationContext() (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), s.refreshTimeout)
|
||||
}
|
||||
|
||||
func (s *DashboardService) applyAggregationStatus(ctx context.Context, stats *usagestats.DashboardStats) {
|
||||
if stats == nil {
|
||||
return
|
||||
}
|
||||
updatedAt := s.fetchAggregationUpdatedAt(ctx)
|
||||
stats.StatsUpdatedAt = updatedAt.UTC().Format(time.RFC3339)
|
||||
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
|
||||
}
|
||||
|
||||
func (s *DashboardService) refreshAggregationStaleness(stats *usagestats.DashboardStats) {
|
||||
if stats == nil {
|
||||
return
|
||||
}
|
||||
updatedAt := parseStatsUpdatedAt(stats.StatsUpdatedAt)
|
||||
stats.StatsStale = s.isAggregationStale(updatedAt, time.Now().UTC())
|
||||
}
|
||||
|
||||
func (s *DashboardService) fetchAggregationUpdatedAt(ctx context.Context) time.Time {
|
||||
if s.aggRepo == nil {
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
updatedAt, err := s.aggRepo.GetAggregationWatermark(ctx)
|
||||
if err != nil {
|
||||
logger.LegacyPrintf("service.dashboard", "[Dashboard] 读取聚合水位失败: %v", err)
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
if updatedAt.IsZero() {
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
return updatedAt.UTC()
|
||||
}
|
||||
|
||||
func (s *DashboardService) isAggregationStale(updatedAt, now time.Time) bool {
|
||||
if !s.aggEnabled {
|
||||
return true
|
||||
}
|
||||
epoch := time.Unix(0, 0).UTC()
|
||||
if !updatedAt.After(epoch) {
|
||||
return true
|
||||
}
|
||||
threshold := s.aggInterval + s.aggLookback
|
||||
return now.Sub(updatedAt) > threshold
|
||||
}
|
||||
|
||||
func parseStatsUpdatedAt(raw string) time.Time {
|
||||
if raw == "" {
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
parsed, err := time.Parse(time.RFC3339, raw)
|
||||
if err != nil {
|
||||
return time.Unix(0, 0).UTC()
|
||||
}
|
||||
return parsed.UTC()
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetAPIKeyUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.APIKeyUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetAPIKeyUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get api key usage trend: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUserUsageTrend(ctx context.Context, startTime, endTime time.Time, granularity string, limit int) ([]usagestats.UserUsageTrendPoint, error) {
|
||||
trend, err := s.usageRepo.GetUserUsageTrend(ctx, startTime, endTime, granularity, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user usage trend: %w", err)
|
||||
}
|
||||
return trend, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetUserSpendingRanking(ctx context.Context, startTime, endTime time.Time, limit int) (*usagestats.UserSpendingRankingResponse, error) {
|
||||
ranking, err := s.usageRepo.GetUserSpendingRanking(ctx, startTime, endTime, limit)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user spending ranking: %w", err)
|
||||
}
|
||||
return ranking, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchUserUsageStats(ctx context.Context, userIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchUserUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchUserUsageStats(ctx, userIDs, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch user usage stats: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (s *DashboardService) GetBatchAPIKeyUsageStats(ctx context.Context, apiKeyIDs []int64, startTime, endTime time.Time) (map[int64]*usagestats.BatchAPIKeyUsageStats, error) {
|
||||
stats, err := s.usageRepo.GetBatchAPIKeyUsageStats(ctx, apiKeyIDs, startTime, endTime)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get batch api key usage stats: %w", err)
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
@@ -0,0 +1,395 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/Wei-Shaw/sub2api/internal/config"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/usagestats"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type usageRepoStub struct {
|
||||
UsageLogRepository
|
||||
stats *usagestats.DashboardStats
|
||||
rangeStats *usagestats.DashboardStats
|
||||
err error
|
||||
rangeErr error
|
||||
calls int32
|
||||
rangeCalls int32
|
||||
rangeStart time.Time
|
||||
rangeEnd time.Time
|
||||
onCall chan struct{}
|
||||
}
|
||||
|
||||
func (s *usageRepoStub) GetDashboardStats(ctx context.Context) (*usagestats.DashboardStats, error) {
|
||||
atomic.AddInt32(&s.calls, 1)
|
||||
if s.onCall != nil {
|
||||
select {
|
||||
case s.onCall <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
if s.err != nil {
|
||||
return nil, s.err
|
||||
}
|
||||
return s.stats, nil
|
||||
}
|
||||
|
||||
func (s *usageRepoStub) GetDashboardStatsWithRange(ctx context.Context, start, end time.Time) (*usagestats.DashboardStats, error) {
|
||||
atomic.AddInt32(&s.rangeCalls, 1)
|
||||
s.rangeStart = start
|
||||
s.rangeEnd = end
|
||||
if s.rangeErr != nil {
|
||||
return nil, s.rangeErr
|
||||
}
|
||||
if s.rangeStats != nil {
|
||||
return s.rangeStats, nil
|
||||
}
|
||||
return s.stats, nil
|
||||
}
|
||||
|
||||
type dashboardCacheStub struct {
|
||||
get func(ctx context.Context) (string, error)
|
||||
set func(ctx context.Context, data string, ttl time.Duration) error
|
||||
del func(ctx context.Context) error
|
||||
getCalls int32
|
||||
setCalls int32
|
||||
delCalls int32
|
||||
lastSetMu sync.Mutex
|
||||
lastSet string
|
||||
}
|
||||
|
||||
func (c *dashboardCacheStub) GetDashboardStats(ctx context.Context) (string, error) {
|
||||
atomic.AddInt32(&c.getCalls, 1)
|
||||
if c.get != nil {
|
||||
return c.get(ctx)
|
||||
}
|
||||
return "", ErrDashboardStatsCacheMiss
|
||||
}
|
||||
|
||||
func (c *dashboardCacheStub) SetDashboardStats(ctx context.Context, data string, ttl time.Duration) error {
|
||||
atomic.AddInt32(&c.setCalls, 1)
|
||||
c.lastSetMu.Lock()
|
||||
c.lastSet = data
|
||||
c.lastSetMu.Unlock()
|
||||
if c.set != nil {
|
||||
return c.set(ctx, data, ttl)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dashboardCacheStub) DeleteDashboardStats(ctx context.Context) error {
|
||||
atomic.AddInt32(&c.delCalls, 1)
|
||||
if c.del != nil {
|
||||
return c.del(ctx)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type dashboardAggregationRepoStub struct {
|
||||
watermark time.Time
|
||||
err error
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) AggregateRange(ctx context.Context, start, end time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) RecomputeRange(ctx context.Context, start, end time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) GetAggregationWatermark(ctx context.Context) (time.Time, error) {
|
||||
if s.err != nil {
|
||||
return time.Time{}, s.err
|
||||
}
|
||||
return s.watermark, nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) UpdateAggregationWatermark(ctx context.Context, aggregatedAt time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) CleanupAggregates(ctx context.Context, hourlyCutoff, dailyCutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) CleanupUsageLogs(ctx context.Context, cutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) CleanupUsageBillingDedup(ctx context.Context, cutoff time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *dashboardAggregationRepoStub) EnsureUsageLogsPartitions(ctx context.Context, now time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *dashboardCacheStub) readLastEntry(t *testing.T) dashboardStatsCacheEntry {
|
||||
t.Helper()
|
||||
c.lastSetMu.Lock()
|
||||
data := c.lastSet
|
||||
c.lastSetMu.Unlock()
|
||||
|
||||
var entry dashboardStatsCacheEntry
|
||||
err := json.Unmarshal([]byte(data), &entry)
|
||||
require.NoError(t, err)
|
||||
return entry
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheHitFresh(t *testing.T) {
|
||||
stats := &usagestats.DashboardStats{
|
||||
TotalUsers: 10,
|
||||
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||
StatsStale: true,
|
||||
}
|
||||
entry := dashboardStatsCacheEntry{
|
||||
Stats: stats,
|
||||
UpdatedAt: time.Now().Unix(),
|
||||
}
|
||||
payload, err := json.Marshal(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return string(payload), nil
|
||||
},
|
||||
}
|
||||
repo := &usageRepoStub{
|
||||
stats: &usagestats.DashboardStats{TotalUsers: 99},
|
||||
}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, stats, got)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheMiss_StoresCache(t *testing.T) {
|
||||
stats := &usagestats.DashboardStats{
|
||||
TotalUsers: 7,
|
||||
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||
StatsStale: true,
|
||||
}
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return "", ErrDashboardStatsCacheMiss
|
||||
},
|
||||
}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, stats, got)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.getCalls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.setCalls))
|
||||
entry := cache.readLastEntry(t)
|
||||
require.Equal(t, stats, entry.Stats)
|
||||
require.WithinDuration(t, time.Now(), time.Unix(entry.UpdatedAt, 0), time.Second)
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheDisabled_SkipsCache(t *testing.T) {
|
||||
stats := &usagestats.DashboardStats{
|
||||
TotalUsers: 3,
|
||||
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||
StatsStale: true,
|
||||
}
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return "", nil
|
||||
},
|
||||
}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, stats, got)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.getCalls))
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&cache.setCalls))
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheHitStale_TriggersAsyncRefresh(t *testing.T) {
|
||||
staleStats := &usagestats.DashboardStats{
|
||||
TotalUsers: 11,
|
||||
StatsUpdatedAt: time.Unix(0, 0).UTC().Format(time.RFC3339),
|
||||
StatsStale: true,
|
||||
}
|
||||
entry := dashboardStatsCacheEntry{
|
||||
Stats: staleStats,
|
||||
UpdatedAt: time.Now().Add(-defaultDashboardStatsFreshTTL * 2).Unix(),
|
||||
}
|
||||
payload, err := json.Marshal(entry)
|
||||
require.NoError(t, err)
|
||||
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return string(payload), nil
|
||||
},
|
||||
}
|
||||
refreshCh := make(chan struct{}, 1)
|
||||
repo := &usageRepoStub{
|
||||
stats: &usagestats.DashboardStats{TotalUsers: 22},
|
||||
onCall: refreshCh,
|
||||
}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, staleStats, got)
|
||||
|
||||
select {
|
||||
case <-refreshCh:
|
||||
case <-time.After(1 * time.Second):
|
||||
t.Fatal("等待异步刷新超时")
|
||||
}
|
||||
require.Eventually(t, func() bool {
|
||||
return atomic.LoadInt32(&cache.setCalls) >= 1
|
||||
}, 1*time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheParseError_EvictsAndRefetches(t *testing.T) {
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return "not-json", nil
|
||||
},
|
||||
}
|
||||
stats := &usagestats.DashboardStats{TotalUsers: 9}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, stats, got)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&repo.calls))
|
||||
}
|
||||
|
||||
func TestDashboardService_CacheParseError_RepoFailure(t *testing.T) {
|
||||
cache := &dashboardCacheStub{
|
||||
get: func(ctx context.Context) (string, error) {
|
||||
return "not-json", nil
|
||||
},
|
||||
}
|
||||
repo := &usageRepoStub{err: errors.New("db down")}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: true},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, cache, cfg)
|
||||
|
||||
_, err := svc.GetDashboardStats(context.Background())
|
||||
require.Error(t, err)
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&cache.delCalls))
|
||||
}
|
||||
|
||||
func TestDashboardService_StatsUpdatedAtEpochWhenMissing(t *testing.T) {
|
||||
stats := &usagestats.DashboardStats{}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: time.Unix(0, 0).UTC()}
|
||||
cfg := &config.Config{Dashboard: config.DashboardCacheConfig{Enabled: false}}
|
||||
svc := NewDashboardService(repo, aggRepo, nil, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "1970-01-01T00:00:00Z", got.StatsUpdatedAt)
|
||||
require.True(t, got.StatsStale)
|
||||
}
|
||||
|
||||
func TestDashboardService_StatsStaleFalseWhenFresh(t *testing.T) {
|
||||
aggNow := time.Now().UTC().Truncate(time.Second)
|
||||
stats := &usagestats.DashboardStats{}
|
||||
repo := &usageRepoStub{stats: stats}
|
||||
aggRepo := &dashboardAggregationRepoStub{watermark: aggNow}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: true,
|
||||
IntervalSeconds: 60,
|
||||
LookbackSeconds: 120,
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, aggRepo, nil, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, aggNow.Format(time.RFC3339), got.StatsUpdatedAt)
|
||||
require.False(t, got.StatsStale)
|
||||
}
|
||||
|
||||
func TestDashboardService_AggDisabled_UsesUsageLogsFallback(t *testing.T) {
|
||||
expected := &usagestats.DashboardStats{TotalUsers: 42}
|
||||
repo := &usageRepoStub{
|
||||
rangeStats: expected,
|
||||
err: errors.New("should not call aggregated stats"),
|
||||
}
|
||||
cfg := &config.Config{
|
||||
Dashboard: config.DashboardCacheConfig{Enabled: false},
|
||||
DashboardAgg: config.DashboardAggregationConfig{
|
||||
Enabled: false,
|
||||
Retention: config.DashboardAggregationRetentionConfig{
|
||||
UsageLogsDays: 7,
|
||||
},
|
||||
},
|
||||
}
|
||||
svc := NewDashboardService(repo, nil, nil, cfg)
|
||||
|
||||
got, err := svc.GetDashboardStats(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(42), got.TotalUsers)
|
||||
require.Equal(t, int32(0), atomic.LoadInt32(&repo.calls))
|
||||
require.Equal(t, int32(1), atomic.LoadInt32(&repo.rangeCalls))
|
||||
require.False(t, repo.rangeEnd.IsZero())
|
||||
require.Equal(t, truncateToDayUTC(repo.rangeEnd.AddDate(0, 0, -7)), repo.rangeStart)
|
||||
}
|
||||
@@ -0,0 +1,252 @@
|
||||
package service
|
||||
|
||||
import "context"
|
||||
|
||||
type DataManagementPostgresConfig struct {
|
||||
Host string `json:"host"`
|
||||
Port int32 `json:"port"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password,omitempty"`
|
||||
PasswordConfigured bool `json:"password_configured"`
|
||||
Database string `json:"database"`
|
||||
SSLMode string `json:"ssl_mode"`
|
||||
ContainerName string `json:"container_name"`
|
||||
}
|
||||
|
||||
type DataManagementRedisConfig struct {
|
||||
Addr string `json:"addr"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password,omitempty"`
|
||||
PasswordConfigured bool `json:"password_configured"`
|
||||
DB int32 `json:"db"`
|
||||
ContainerName string `json:"container_name"`
|
||||
}
|
||||
|
||||
type DataManagementS3Config struct {
|
||||
Enabled bool `json:"enabled"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
Region string `json:"region"`
|
||||
Bucket string `json:"bucket"`
|
||||
AccessKeyID string `json:"access_key_id"`
|
||||
SecretAccessKey string `json:"secret_access_key,omitempty"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
Prefix string `json:"prefix"`
|
||||
ForcePathStyle bool `json:"force_path_style"`
|
||||
UseSSL bool `json:"use_ssl"`
|
||||
}
|
||||
|
||||
type DataManagementConfig struct {
|
||||
SourceMode string `json:"source_mode"`
|
||||
BackupRoot string `json:"backup_root"`
|
||||
SQLitePath string `json:"sqlite_path,omitempty"`
|
||||
RetentionDays int32 `json:"retention_days"`
|
||||
KeepLast int32 `json:"keep_last"`
|
||||
ActivePostgresID string `json:"active_postgres_profile_id"`
|
||||
ActiveRedisID string `json:"active_redis_profile_id"`
|
||||
Postgres DataManagementPostgresConfig `json:"postgres"`
|
||||
Redis DataManagementRedisConfig `json:"redis"`
|
||||
S3 DataManagementS3Config `json:"s3"`
|
||||
ActiveS3ProfileID string `json:"active_s3_profile_id"`
|
||||
}
|
||||
|
||||
type DataManagementTestS3Result struct {
|
||||
OK bool `json:"ok"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
type DataManagementCreateBackupJobInput struct {
|
||||
BackupType string
|
||||
UploadToS3 bool
|
||||
TriggeredBy string
|
||||
IdempotencyKey string
|
||||
S3ProfileID string
|
||||
PostgresID string
|
||||
RedisID string
|
||||
}
|
||||
|
||||
type DataManagementListBackupJobsInput struct {
|
||||
PageSize int32
|
||||
PageToken string
|
||||
Status string
|
||||
BackupType string
|
||||
}
|
||||
|
||||
type DataManagementArtifactInfo struct {
|
||||
LocalPath string `json:"local_path"`
|
||||
SizeBytes int64 `json:"size_bytes"`
|
||||
SHA256 string `json:"sha256"`
|
||||
}
|
||||
|
||||
type DataManagementS3ObjectInfo struct {
|
||||
Bucket string `json:"bucket"`
|
||||
Key string `json:"key"`
|
||||
ETag string `json:"etag"`
|
||||
}
|
||||
|
||||
type DataManagementBackupJob struct {
|
||||
JobID string `json:"job_id"`
|
||||
BackupType string `json:"backup_type"`
|
||||
Status string `json:"status"`
|
||||
TriggeredBy string `json:"triggered_by"`
|
||||
IdempotencyKey string `json:"idempotency_key,omitempty"`
|
||||
UploadToS3 bool `json:"upload_to_s3"`
|
||||
S3ProfileID string `json:"s3_profile_id,omitempty"`
|
||||
PostgresID string `json:"postgres_profile_id,omitempty"`
|
||||
RedisID string `json:"redis_profile_id,omitempty"`
|
||||
StartedAt string `json:"started_at,omitempty"`
|
||||
FinishedAt string `json:"finished_at,omitempty"`
|
||||
ErrorMessage string `json:"error_message,omitempty"`
|
||||
Artifact DataManagementArtifactInfo `json:"artifact"`
|
||||
S3Object DataManagementS3ObjectInfo `json:"s3"`
|
||||
}
|
||||
|
||||
type DataManagementSourceProfile struct {
|
||||
SourceType string `json:"source_type"`
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
Config DataManagementSourceConfig `json:"config"`
|
||||
PasswordConfigured bool `json:"password_configured"`
|
||||
CreatedAt string `json:"created_at,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
type DataManagementSourceConfig struct {
|
||||
Host string `json:"host"`
|
||||
Port int32 `json:"port"`
|
||||
User string `json:"user"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Database string `json:"database"`
|
||||
SSLMode string `json:"ssl_mode"`
|
||||
Addr string `json:"addr"`
|
||||
Username string `json:"username"`
|
||||
DB int32 `json:"db"`
|
||||
ContainerName string `json:"container_name"`
|
||||
}
|
||||
|
||||
type DataManagementCreateSourceProfileInput struct {
|
||||
SourceType string
|
||||
ProfileID string
|
||||
Name string
|
||||
Config DataManagementSourceConfig
|
||||
SetActive bool
|
||||
}
|
||||
|
||||
type DataManagementUpdateSourceProfileInput struct {
|
||||
SourceType string
|
||||
ProfileID string
|
||||
Name string
|
||||
Config DataManagementSourceConfig
|
||||
}
|
||||
|
||||
type DataManagementS3Profile struct {
|
||||
ProfileID string `json:"profile_id"`
|
||||
Name string `json:"name"`
|
||||
IsActive bool `json:"is_active"`
|
||||
S3 DataManagementS3Config `json:"s3"`
|
||||
SecretAccessKeyConfigured bool `json:"secret_access_key_configured"`
|
||||
CreatedAt string `json:"created_at,omitempty"`
|
||||
UpdatedAt string `json:"updated_at,omitempty"`
|
||||
}
|
||||
|
||||
type DataManagementCreateS3ProfileInput struct {
|
||||
ProfileID string
|
||||
Name string
|
||||
S3 DataManagementS3Config
|
||||
SetActive bool
|
||||
}
|
||||
|
||||
type DataManagementUpdateS3ProfileInput struct {
|
||||
ProfileID string
|
||||
Name string
|
||||
S3 DataManagementS3Config
|
||||
}
|
||||
|
||||
type DataManagementListBackupJobsResult struct {
|
||||
Items []DataManagementBackupJob `json:"items"`
|
||||
NextPageToken string `json:"next_page_token,omitempty"`
|
||||
}
|
||||
|
||||
func (s *DataManagementService) GetConfig(ctx context.Context) (DataManagementConfig, error) {
|
||||
_ = ctx
|
||||
return DataManagementConfig{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) UpdateConfig(ctx context.Context, cfg DataManagementConfig) (DataManagementConfig, error) {
|
||||
_, _ = ctx, cfg
|
||||
return DataManagementConfig{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) ListSourceProfiles(ctx context.Context, sourceType string) ([]DataManagementSourceProfile, error) {
|
||||
_, _ = ctx, sourceType
|
||||
return nil, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) CreateSourceProfile(ctx context.Context, input DataManagementCreateSourceProfileInput) (DataManagementSourceProfile, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementSourceProfile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) UpdateSourceProfile(ctx context.Context, input DataManagementUpdateSourceProfileInput) (DataManagementSourceProfile, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementSourceProfile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) DeleteSourceProfile(ctx context.Context, sourceType, profileID string) error {
|
||||
_, _, _ = ctx, sourceType, profileID
|
||||
return s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) SetActiveSourceProfile(ctx context.Context, sourceType, profileID string) (DataManagementSourceProfile, error) {
|
||||
_, _, _ = ctx, sourceType, profileID
|
||||
return DataManagementSourceProfile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) ValidateS3(ctx context.Context, cfg DataManagementS3Config) (DataManagementTestS3Result, error) {
|
||||
_, _ = ctx, cfg
|
||||
return DataManagementTestS3Result{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) ListS3Profiles(ctx context.Context) ([]DataManagementS3Profile, error) {
|
||||
_ = ctx
|
||||
return nil, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) CreateS3Profile(ctx context.Context, input DataManagementCreateS3ProfileInput) (DataManagementS3Profile, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementS3Profile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) UpdateS3Profile(ctx context.Context, input DataManagementUpdateS3ProfileInput) (DataManagementS3Profile, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementS3Profile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) DeleteS3Profile(ctx context.Context, profileID string) error {
|
||||
_, _ = ctx, profileID
|
||||
return s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) SetActiveS3Profile(ctx context.Context, profileID string) (DataManagementS3Profile, error) {
|
||||
_, _ = ctx, profileID
|
||||
return DataManagementS3Profile{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) CreateBackupJob(ctx context.Context, input DataManagementCreateBackupJobInput) (DataManagementBackupJob, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementBackupJob{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) ListBackupJobs(ctx context.Context, input DataManagementListBackupJobsInput) (DataManagementListBackupJobsResult, error) {
|
||||
_, _ = ctx, input
|
||||
return DataManagementListBackupJobsResult{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) GetBackupJob(ctx context.Context, jobID string) (DataManagementBackupJob, error) {
|
||||
_, _ = ctx, jobID
|
||||
return DataManagementBackupJob{}, s.deprecatedError()
|
||||
}
|
||||
|
||||
func (s *DataManagementService) deprecatedError() error {
|
||||
return ErrDataManagementDeprecated.WithMetadata(map[string]string{"socket_path": s.SocketPath()})
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user