chore: prepare repository for publishing
Some checks failed
CI / test (push) Has been cancelled

This commit is contained in:
phamnazage-jpg
2026-05-13 14:42:45 +08:00
parent 55e506b2b5
commit 77e6610fd2
118 changed files with 27373 additions and 1009 deletions

View File

@@ -0,0 +1,129 @@
// internal/collectors/collector.go
// Collector 接口定义:所有数据源采集器的统一抽象
package collectors
import (
"context"
"time"
)
// Result 采集结果
type Result struct {
Models []ModelInfo
Meta CollectionMeta
}
// CollectionMeta 采集元信息
type CollectionMeta struct {
Source string
Count int
Duration time.Duration
Timestamp time.Time
BatchID string
CollectorVersion string
}
// ModelInfo 标准模型信息(与 fetch_openrouter.go 兼容)
type ModelInfo struct {
ID string
Name string
Provider string
ProviderID string
Version string
Modality string
ContextLength int
Capabilities []string
Pricing ModelPricing
Description string
IsFree bool
SourceURL string
}
// ModelPricing 标准定价信息
type ModelPricing struct {
Input float64
Output float64
}
// Collector 采集器接口
type Collector interface {
// Name 返回采集器名称
Name() string
// Collect 执行采集,返回标准模型列表
Collect(ctx context.Context) (Result, error)
// Schedule 返回推荐调度周期(如 "0 8 * * *"
Schedule() string
// Timeout 返回单次采集超时时间
Timeout() time.Duration
// RetryCount 返回最大重试次数
RetryCount() int
}
// BaseCollector 提供默认实现的嵌入类型
type BaseCollector struct {
name string
schedule string
timeout time.Duration
retryCount int
version string
}
func (b *BaseCollector) Name() string { return b.name }
func (b *BaseCollector) Schedule() string { return b.schedule }
func (b *BaseCollector) Timeout() time.Duration { return b.timeout }
func (b *BaseCollector) RetryCount() int { return b.retryCount }
func (b *BaseCollector) Version() string { return b.version }
// NewBaseCollector 创建基础采集器配置
func NewBaseCollector(name, schedule string, timeout time.Duration, retry int, version string) BaseCollector {
return BaseCollector{
name: name,
schedule: schedule,
timeout: timeout,
retryCount: retry,
version: version,
}
}
// CollectorRegistry 采集器注册表
type CollectorRegistry struct {
collectors map[string]Collector
}
// NewRegistry 创建采集器注册表
func NewRegistry() *CollectorRegistry {
return &CollectorRegistry{collectors: make(map[string]Collector)}
}
// Register 注册采集器
func (r *CollectorRegistry) Register(c Collector) {
r.collectors[c.Name()] = c
}
// Get 获取采集器
func (r *CollectorRegistry) Get(name string) (Collector, bool) {
c, ok := r.collectors[name]
return c, ok
}
// All 返回所有已注册采集器
func (r *CollectorRegistry) All() []Collector {
cs := make([]Collector, 0, len(r.collectors))
for _, c := range r.collectors {
cs = append(cs, c)
}
return cs
}
// Names 返回所有已注册采集器名称
func (r *CollectorRegistry) Names() []string {
names := make([]string, 0, len(r.collectors))
for n := range r.collectors {
names = append(names, n)
}
return names
}

View File

@@ -0,0 +1,127 @@
// internal/collectors/collector_test.go
package collectors
import (
"context"
"errors"
"testing"
"time"
)
// mockCollector 用于测试的模拟采集器
type mockCollector struct {
BaseCollector
collectFunc func(ctx context.Context) (Result, error)
}
func (m *mockCollector) Collect(ctx context.Context) (Result, error) {
return m.collectFunc(ctx)
}
func TestCollectorInterface(t *testing.T) {
c := &mockCollector{
BaseCollector: NewBaseCollector("test", "0 8 * * *", 30*time.Second, 3, "v1.0"),
collectFunc: func(ctx context.Context) (Result, error) {
return Result{
Models: []ModelInfo{{ID: "test/model-1", Name: "Test Model"}},
Meta: CollectionMeta{Source: "test", Count: 1},
}, nil
},
}
// 测试接口方法
if c.Name() != "test" {
t.Errorf("Name() = %q, want %q", c.Name(), "test")
}
if c.Schedule() != "0 8 * * *" {
t.Errorf("Schedule() = %q, want %q", c.Schedule(), "0 8 * * *")
}
if c.Timeout() != 30*time.Second {
t.Errorf("Timeout() = %v, want %v", c.Timeout(), 30*time.Second)
}
if c.RetryCount() != 3 {
t.Errorf("RetryCount() = %d, want %d", c.RetryCount(), 3)
}
// 测试 Collect
ctx := context.Background()
result, err := c.Collect(ctx)
if err != nil {
t.Fatalf("Collect() error = %v", err)
}
if len(result.Models) != 1 {
t.Errorf("len(Models) = %d, want 1", len(result.Models))
}
if result.Meta.Count != 1 {
t.Errorf("Meta.Count = %d, want 1", result.Meta.Count)
}
}
func TestCollectorRegistry(t *testing.T) {
reg := NewRegistry()
c1 := &mockCollector{
BaseCollector: NewBaseCollector("openrouter", "0 8 * * *", 30*time.Second, 3, "v1.0"),
collectFunc: func(ctx context.Context) (Result, error) { return Result{}, nil },
}
c2 := &mockCollector{
BaseCollector: NewBaseCollector("siliconflow", "0 9 * * *", 30*time.Second, 3, "v1.0"),
collectFunc: func(ctx context.Context) (Result, error) { return Result{}, nil },
}
reg.Register(c1)
reg.Register(c2)
// 测试 Get
got, ok := reg.Get("openrouter")
if !ok {
t.Fatal("Get(openrouter) not found")
}
if got.Name() != "openrouter" {
t.Errorf("Get() Name = %q, want %q", got.Name(), "openrouter")
}
// 测试 Names
names := reg.Names()
if len(names) != 2 {
t.Errorf("Names() len = %d, want 2", len(names))
}
// 测试 All
all := reg.All()
if len(all) != 2 {
t.Errorf("All() len = %d, want 2", len(all))
}
// 测试不存在的采集器
_, ok = reg.Get("nonexistent")
if ok {
t.Error("Get(nonexistent) should return false")
}
}
func TestCollectorTimeout(t *testing.T) {
c := &mockCollector{
BaseCollector: NewBaseCollector("slow", "0 8 * * *", 100*time.Millisecond, 0, "v1.0"),
collectFunc: func(ctx context.Context) (Result, error) {
// 模拟耗时操作
select {
case <-time.After(200 * time.Millisecond):
return Result{}, nil
case <-ctx.Done():
return Result{}, ctx.Err()
}
},
}
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, err := c.Collect(ctx)
if err == nil {
t.Error("Expected timeout error, got nil")
}
if !errors.Is(err, context.DeadlineExceeded) {
t.Errorf("Expected DeadlineExceeded, got %v", err)
}
}

View File

@@ -0,0 +1,115 @@
// internal/collectors/provider_mapper.go
// ProviderMapper: 将 OpenRouter 模型 ID 映射为标准厂商/模型名称
package collectors
import (
"fmt"
"strings"
)
// ProviderInfo 标准厂商信息
type ProviderInfo struct {
ID string // 标准ID: "openai", "anthropic", "deepseek"...
Name string // 英文名
NameCN string // 中文名
Country string // "US" / "CN" / "EU"
}
// ModelMapping 模型映射结果
type ModelMapping struct {
Provider ProviderInfo
ModelName string // 纯模型名,不含厂商前缀
RawID string // 原始 OpenRouter ID
IsFree bool // 是否免费版(:free 后缀)
}
// providerNameMap 标准厂商名称映射表
// key 为标准ID也兼容 OpenRouter 原始格式作为别名)
var providerNameMap = map[string]ProviderInfo{
"openai": {ID: "openai", Name: "OpenAI", NameCN: "OpenAI", Country: "US"},
"anthropic": {ID: "anthropic", Name: "Anthropic", NameCN: "Anthropic", Country: "US"},
"google": {ID: "google", Name: "Google", NameCN: "谷歌", Country: "US"},
"meta": {ID: "meta", Name: "Meta", NameCN: "Meta", Country: "US"},
"xai": {ID: "xai", Name: "xAI", NameCN: "xAI", Country: "US"},
"x-ai": {ID: "xai", Name: "xAI", NameCN: "xAI", Country: "US"}, // OpenRouter别名
"deepseek": {ID: "deepseek", Name: "DeepSeek", NameCN: "深度求索", Country: "CN"},
"qwen": {ID: "qwen", Name: "Qwen", NameCN: "通义千问", Country: "CN"},
"alibaba": {ID: "alibaba", Name: "Alibaba", NameCN: "阿里巴巴", Country: "CN"},
"moonshot": {ID: "moonshot", Name: "Moonshot AI", NameCN: "月之暗面", Country: "CN"},
"moonshotai": {ID: "moonshot", Name: "Moonshot AI", NameCN: "月之暗面", Country: "CN"}, // OpenRouter别名
"zhipu": {ID: "zhipu", Name: "Zhipu AI", NameCN: "智谱AI", Country: "CN"},
"zhipuai": {ID: "zhipu", Name: "Zhipu AI", NameCN: "智谱AI", Country: "CN"}, // OpenRouter别名
"bytedance": {ID: "bytedance", Name: "ByteDance", NameCN: "字节跳动", Country: "CN"},
"baidu": {ID: "baidu", Name: "Baidu", NameCN: "百度", Country: "CN"},
"tencent": {ID: "tencent", Name: "Tencent", NameCN: "腾讯", Country: "CN"},
"mistral": {ID: "mistral", Name: "Mistral AI", NameCN: "Mistral", Country: "EU"},
"cohere": {ID: "cohere", Name: "Cohere", NameCN: "Cohere", Country: "US"},
"ai21": {ID: "ai21", Name: "AI21 Labs", NameCN: "AI21", Country: "US"},
"perplexity": {ID: "perplexity", Name: "Perplexity", NameCN: "Perplexity", Country: "US"},
"nvidia": {ID: "nvidia", Name: "NVIDIA", NameCN: "英伟达", Country: "US"},
"microsoft": {ID: "microsoft", Name: "Microsoft", NameCN: "微软", Country: "US"},
"openrouter": {ID: "openrouter", Name: "OpenRouter", NameCN: "OpenRouter", Country: "US"},
}
// MapOpenRouterID 将 OpenRouter 模型 ID 映射为标准信息
// OpenRouter ID 格式: "provider/model-name" 或 "provider/model-name:free"
func MapOpenRouterID(rawID string) (ModelMapping, error) {
if rawID == "" {
return ModelMapping{}, fmt.Errorf("empty model ID")
}
// 检测 :free 后缀
isFree := false
modelPart := rawID
if strings.HasSuffix(rawID, ":free") {
isFree = true
modelPart = rawID[:len(rawID)-5]
}
// 分割 provider / model
parts := strings.SplitN(modelPart, "/", 2)
if len(parts) < 2 {
return ModelMapping{}, fmt.Errorf("invalid model ID format: %s", rawID)
}
providerKey := strings.ToLower(parts[0])
modelName := parts[1]
// 查找厂商信息
provider, ok := providerNameMap[providerKey]
if !ok {
// 未识别厂商,返回通用信息
provider = ProviderInfo{
ID: providerKey,
Name: providerKey,
NameCN: providerKey,
Country: "unknown",
}
}
return ModelMapping{
Provider: provider,
ModelName: modelName,
RawID: rawID,
IsFree: isFree,
}, nil
}
// GetAllProviderNames 返回所有已注册的厂商ID列表用于测试覆盖度检查
func GetAllProviderNames() []string {
names := make([]string, 0, len(providerNameMap))
for k := range providerNameMap {
names = append(names, k)
}
return names
}
// RegisterProvider 动态注册新厂商(用于扩展)
func RegisterProvider(key string, info ProviderInfo) {
providerNameMap[strings.ToLower(key)] = info
}
// ProviderCount 返回已注册厂商数量
func ProviderCount() int {
return len(providerNameMap)
}

View File

@@ -0,0 +1,167 @@
// internal/collectors/provider_mapper_test.go
package collectors
import (
"testing"
)
func TestMapOpenRouterID(t *testing.T) {
tests := []struct {
name string
rawID string
wantErr bool
wantProvID string
wantProvCN string
wantModel string
wantFree bool
wantCountry string
}{
{
name: "OpenAI GPT-4o",
rawID: "openai/gpt-4o",
wantProvID: "openai",
wantProvCN: "OpenAI",
wantModel: "gpt-4o",
wantFree: false,
wantCountry: "US",
},
{
name: "Anthropic Claude free",
rawID: "anthropic/claude-3.5-sonnet:free",
wantProvID: "anthropic",
wantProvCN: "Anthropic",
wantModel: "claude-3.5-sonnet",
wantFree: true,
wantCountry: "US",
},
{
name: "DeepSeek V3",
rawID: "deepseek/deepseek-v3",
wantProvID: "deepseek",
wantProvCN: "深度求索",
wantModel: "deepseek-v3",
wantFree: false,
wantCountry: "CN",
},
{
name: "Moonshot Kimi",
rawID: "moonshotai/kimi-k2",
wantProvID: "moonshot",
wantProvCN: "月之暗面",
wantModel: "kimi-k2",
wantFree: false,
wantCountry: "CN",
},
{
name: "Unknown provider fallback",
rawID: "some-new-ai/model-x",
wantProvID: "some-new-ai",
wantProvCN: "some-new-ai",
wantModel: "model-x",
wantFree: false,
wantCountry: "unknown",
},
{
name: "Empty ID",
rawID: "",
wantErr: true,
},
{
name: "Invalid format no slash",
rawID: "invalid-id",
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := MapOpenRouterID(tt.rawID)
if (err != nil) != tt.wantErr {
t.Errorf("MapOpenRouterID(%q) error = %v, wantErr %v", tt.rawID, err, tt.wantErr)
return
}
if tt.wantErr {
return
}
if got.Provider.ID != tt.wantProvID {
t.Errorf("Provider.ID = %q, want %q", got.Provider.ID, tt.wantProvID)
}
if got.Provider.NameCN != tt.wantProvCN {
t.Errorf("Provider.NameCN = %q, want %q", got.Provider.NameCN, tt.wantProvCN)
}
if got.ModelName != tt.wantModel {
t.Errorf("ModelName = %q, want %q", got.ModelName, tt.wantModel)
}
if got.IsFree != tt.wantFree {
t.Errorf("IsFree = %v, want %v", got.IsFree, tt.wantFree)
}
if got.Provider.Country != tt.wantCountry {
t.Errorf("Country = %q, want %q", got.Provider.Country, tt.wantCountry)
}
})
}
}
func TestProviderMapCompleteness(t *testing.T) {
// 验证所有预定义的厂商映射
requiredProviders := []string{
"openai", "anthropic", "google", "meta", "xai",
"deepseek", "qwen", "moonshot", "zhipu", "bytedance",
"baidu", "tencent", "alibaba", "mistral", "cohere",
"ai21", "perplexity", "nvidia", "microsoft", "openrouter",
}
for _, id := range requiredProviders {
_, ok := providerNameMap[id]
if !ok {
t.Errorf("Required provider %q not found in providerNameMap", id)
}
}
// 验证总数 >= 20
if ProviderCount() < 20 {
t.Errorf("ProviderCount() = %d, want >= 20", ProviderCount())
}
}
func TestRegisterProvider(t *testing.T) {
// 注册新厂商
RegisterProvider("test-corp", ProviderInfo{
ID: "test-corp",
Name: "Test Corp",
NameCN: "测试公司",
Country: "CN",
})
got, err := MapOpenRouterID("test-corp/model-1")
if err != nil {
t.Fatalf("MapOpenRouterID after RegisterProvider failed: %v", err)
}
if got.Provider.NameCN != "测试公司" {
t.Errorf("After RegisterProvider, NameCN = %q, want %q", got.Provider.NameCN, "测试公司")
}
}
func TestGetAllProviderNames(t *testing.T) {
names := GetAllProviderNames()
if len(names) == 0 {
t.Error("GetAllProviderNames() returned empty slice")
}
// 验证包含 openai
found := false
for _, n := range names {
if n == "openai" {
found = true
break
}
}
if !found {
t.Error("GetAllProviderNames() missing 'openai'")
}
}
func BenchmarkMapOpenRouterID(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _ = MapOpenRouterID("openai/gpt-4o")
}
}

170
internal/retry/retry.go Normal file
View File

@@ -0,0 +1,170 @@
// internal/retry/retry.go
// 指数退避重试机制
package retry
import (
"context"
"fmt"
"math"
"time"
)
// Strategy 重试策略
type Strategy struct {
MaxRetries int // 最大重试次数0=不重试)
BaseDelay time.Duration // 基础延迟
MaxDelay time.Duration // 最大延迟上限
Multiplier float64 // 乘数默认2.0
Jitter bool // 是否添加随机抖动
Retryable func(error) bool // 判断错误是否可重试
}
// DefaultStrategy 返回默认重试策略
func DefaultStrategy() Strategy {
return Strategy{
MaxRetries: 3,
BaseDelay: 1 * time.Second,
MaxDelay: 30 * time.Second,
Multiplier: 2.0,
Jitter: true,
Retryable: IsRetryable,
}
}
// IsRetryable 默认重试判定网络错误、超时、5xx状态码等可重试
func IsRetryable(err error) bool {
if err == nil {
return false
}
// 这里可以扩展更多错误类型判定
return true
}
// Do 执行带重试的操作
func Do(ctx context.Context, strategy Strategy, fn func() error) error {
var lastErr error
for attempt := 0; attempt <= strategy.MaxRetries; attempt++ {
if err := fn(); err != nil {
lastErr = err
// 不判断最后一次是否需要重试
if attempt == strategy.MaxRetries {
break
}
// 检查是否可重试
if strategy.Retryable != nil && !strategy.Retryable(err) {
return fmt.Errorf("non-retryable error on attempt %d: %w", attempt+1, err)
}
// 计算退避延迟
delay := calculateDelay(strategy, attempt)
// 检查上下文是否已取消
select {
case <-ctx.Done():
return fmt.Errorf("context cancelled after attempt %d: %w", attempt+1, ctx.Err())
case <-time.After(delay):
// 继续重试
}
} else {
return nil
}
}
return fmt.Errorf("all %d attempts failed, last error: %w", strategy.MaxRetries+1, lastErr)
}
// calculateDelay 计算指数退避延迟
func calculateDelay(s Strategy, attempt int) time.Duration {
// 指数退避: base * multiplier^attempt
delay := float64(s.BaseDelay) * math.Pow(s.Multiplier, float64(attempt))
// 添加上限
if max := float64(s.MaxDelay); delay > max {
delay = max
}
// 添加抖动±25%
if s.Jitter {
jitter := delay * 0.25
delay = delay - jitter + (jitter * 2 * float64(time.Now().Nanosecond()%1000) / 1000)
}
return time.Duration(delay)
}
// DoWithResult 执行带重试的操作并返回结果
func DoWithResult[T any](ctx context.Context, strategy Strategy, fn func() (T, error)) (T, error) {
var zero T
var lastErr error
for attempt := 0; attempt <= strategy.MaxRetries; attempt++ {
result, err := fn()
if err == nil {
return result, nil
}
lastErr = err
if attempt == strategy.MaxRetries {
break
}
if strategy.Retryable != nil && !strategy.Retryable(err) {
return zero, fmt.Errorf("non-retryable error on attempt %d: %w", attempt+1, err)
}
delay := calculateDelay(strategy, attempt)
select {
case <-ctx.Done():
return zero, fmt.Errorf("context cancelled after attempt %d: %w", attempt+1, ctx.Err())
case <-time.After(delay):
}
}
return zero, fmt.Errorf("all %d attempts failed, last error: %w", strategy.MaxRetries+1, lastErr)
}
// Metrics 重试统计
type Metrics struct {
Attempts int
Success bool
TotalDelay time.Duration
}
// DoWithMetrics 执行带重试并返回统计信息
func DoWithMetrics(ctx context.Context, strategy Strategy, fn func() error) (Metrics, error) {
m := Metrics{}
var lastErr error
start := time.Now()
for attempt := 0; attempt <= strategy.MaxRetries; attempt++ {
m.Attempts = attempt + 1
if err := fn(); err != nil {
lastErr = err
if attempt == strategy.MaxRetries {
break
}
if strategy.Retryable != nil && !strategy.Retryable(err) {
m.TotalDelay = time.Since(start)
return m, fmt.Errorf("non-retryable error on attempt %d: %w", attempt+1, err)
}
delay := calculateDelay(strategy, attempt)
select {
case <-ctx.Done():
m.TotalDelay = time.Since(start)
return m, fmt.Errorf("context cancelled after attempt %d: %w", attempt+1, ctx.Err())
case <-time.After(delay):
}
} else {
m.Success = true
m.TotalDelay = time.Since(start)
return m, nil
}
}
m.TotalDelay = time.Since(start)
return m, fmt.Errorf("all %d attempts failed, last error: %w", strategy.MaxRetries+1, lastErr)
}

View File

@@ -0,0 +1,245 @@
// internal/retry/retry_test.go
package retry
import (
"context"
"errors"
"testing"
"time"
)
func TestDo_Success(t *testing.T) {
strategy := DefaultStrategy()
callCount := 0
err := Do(context.Background(), strategy, func() error {
callCount++
return nil
})
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if callCount != 1 {
t.Errorf("expected 1 call, got %d", callCount)
}
}
func TestDo_RetryThenSuccess(t *testing.T) {
strategy := Strategy{
MaxRetries: 3,
BaseDelay: 10 * time.Millisecond,
MaxDelay: 100 * time.Millisecond,
Multiplier: 2.0,
Jitter: false,
Retryable: IsRetryable,
}
callCount := 0
err := Do(context.Background(), strategy, func() error {
callCount++
if callCount < 3 {
return errors.New("temporary error")
}
return nil
})
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if callCount != 3 {
t.Errorf("expected 3 calls, got %d", callCount)
}
}
func TestDo_MaxRetriesExceeded(t *testing.T) {
strategy := Strategy{
MaxRetries: 2,
BaseDelay: 5 * time.Millisecond,
MaxDelay: 50 * time.Millisecond,
Multiplier: 2.0,
Jitter: false,
Retryable: IsRetryable,
}
callCount := 0
expectedErr := errors.New("persistent error")
err := Do(context.Background(), strategy, func() error {
callCount++
return expectedErr
})
if err == nil {
t.Fatal("expected error, got nil")
}
if callCount != 3 { // initial + 2 retries
t.Errorf("expected 3 calls, got %d", callCount)
}
}
func TestDo_NonRetryableError(t *testing.T) {
strategy := Strategy{
MaxRetries: 3,
BaseDelay: 10 * time.Millisecond,
MaxDelay: 100 * time.Millisecond,
Multiplier: 2.0,
Jitter: false,
Retryable: func(err error) bool { return false }, // 任何错误都不重试
}
callCount := 0
err := Do(context.Background(), strategy, func() error {
callCount++
return errors.New("non-retryable")
})
if err == nil {
t.Fatal("expected error, got nil")
}
if callCount != 1 {
t.Errorf("expected 1 call (no retry), got %d", callCount)
}
}
func TestDo_ContextCancellation(t *testing.T) {
strategy := Strategy{
MaxRetries: 3,
BaseDelay: 1 * time.Second, // 长延迟确保上下文取消优先
MaxDelay: 5 * time.Second,
Multiplier: 2.0,
Jitter: false,
Retryable: IsRetryable,
}
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
callCount := 0
err := Do(ctx, strategy, func() error {
callCount++
return errors.New("error")
})
if err == nil {
t.Fatal("expected error, got nil")
}
if callCount < 1 {
t.Error("expected at least 1 call")
}
if !errors.Is(err, context.DeadlineExceeded) && !errors.Is(err, context.Canceled) {
t.Errorf("expected context error, got %v", err)
}
}
func TestDoWithResult(t *testing.T) {
strategy := Strategy{
MaxRetries: 2,
BaseDelay: 5 * time.Millisecond,
MaxDelay: 50 * time.Millisecond,
Multiplier: 2.0,
Jitter: false,
Retryable: IsRetryable,
}
callCount := 0
result, err := DoWithResult(context.Background(), strategy, func() (string, error) {
callCount++
if callCount < 2 {
return "", errors.New("temp error")
}
return "success", nil
})
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if result != "success" {
t.Errorf("expected 'success', got %q", result)
}
if callCount != 2 {
t.Errorf("expected 2 calls, got %d", callCount)
}
}
func TestDoWithMetrics(t *testing.T) {
strategy := Strategy{
MaxRetries: 2,
BaseDelay: 10 * time.Millisecond,
MaxDelay: 100 * time.Millisecond,
Multiplier: 2.0,
Jitter: false,
Retryable: IsRetryable,
}
// 成功场景
m, err := DoWithMetrics(context.Background(), strategy, func() error {
return nil
})
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if !m.Success {
t.Error("expected Success=true")
}
if m.Attempts != 1 {
t.Errorf("expected 1 attempt, got %d", m.Attempts)
}
// 失败场景
m2, err := DoWithMetrics(context.Background(), strategy, func() error {
return errors.New("always fails")
})
if err == nil {
t.Fatal("expected error, got nil")
}
if m2.Success {
t.Error("expected Success=false")
}
if m2.Attempts != 3 {
t.Errorf("expected 3 attempts, got %d", m2.Attempts)
}
}
func TestCalculateDelay(t *testing.T) {
strategy := Strategy{
BaseDelay: 1 * time.Second,
MaxDelay: 10 * time.Second,
Multiplier: 2.0,
Jitter: false,
}
tests := []struct {
attempt int
min time.Duration
max time.Duration
}{
{0, 1 * time.Second, 1 * time.Second},
{1, 2 * time.Second, 2 * time.Second},
{2, 4 * time.Second, 4 * time.Second},
{3, 8 * time.Second, 8 * time.Second},
{4, 10 * time.Second, 10 * time.Second}, // 达到上限
}
for _, tt := range tests {
delay := calculateDelay(strategy, tt.attempt)
if delay < tt.min || delay > tt.max {
t.Errorf("attempt %d: delay=%v, want [%v, %v]", tt.attempt, delay, tt.min, tt.max)
}
}
}
func BenchmarkDo(b *testing.B) {
strategy := Strategy{
MaxRetries: 0,
BaseDelay: 0,
MaxDelay: 0,
Multiplier: 0,
Jitter: false,
}
for i := 0; i < b.N; i++ {
_ = Do(context.Background(), strategy, func() error {
return nil
})
}
}