package router import ( "context" "math" "testing" "time" "lijiaoqiao/gateway/internal/adapter" ) // mockProvider 实现adapter.ProviderAdapter接口 type mockProvider struct { name string models []string healthy bool } func (m *mockProvider) ChatCompletion(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (*adapter.CompletionResponse, error) { return nil, nil } func (m *mockProvider) ChatCompletionStream(ctx context.Context, model string, messages []adapter.Message, options adapter.CompletionOptions) (<-chan *adapter.StreamChunk, error) { return nil, nil } func (m *mockProvider) GetUsage(response *adapter.CompletionResponse) adapter.Usage { return adapter.Usage{} } func (m *mockProvider) MapError(err error) adapter.ProviderError { return adapter.ProviderError{} } func (m *mockProvider) HealthCheck(ctx context.Context) bool { return m.healthy } func (m *mockProvider) ProviderName() string { return m.name } func (m *mockProvider) SupportedModels() []string { return m.models } func TestNewRouter(t *testing.T) { r := NewRouter(StrategyLatency) if r == nil { t.Fatal("expected non-nil router") } if r.strategy != StrategyLatency { t.Errorf("expected strategy latency, got %s", r.strategy) } if len(r.providers) != 0 { t.Errorf("expected 0 providers, got %d", len(r.providers)) } } func TestRegisterProvider(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("test", prov) if len(r.providers) != 1 { t.Errorf("expected 1 provider, got %d", len(r.providers)) } health := r.health["test"] if health == nil { t.Fatal("expected health to be registered") } if health.Name != "test" { t.Errorf("expected name test, got %s", health.Name) } if !health.Available { t.Error("expected provider to be available") } } func TestSelectProvider_NoProviders(t *testing.T) { r := NewRouter(StrategyLatency) _, err := r.SelectProvider(context.Background(), "gpt-4") if err == nil { t.Fatal("expected error") } } func TestSelectProvider_BasicSelection(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("test", prov) selected, err := r.SelectProvider(context.Background(), "gpt-4") if err != nil { t.Fatalf("unexpected error: %v", err) } if selected.ProviderName() != "test" { t.Errorf("expected provider test, got %s", selected.ProviderName()) } } func TestSelectProvider_ModelNotSupported(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-3.5"}, healthy: true} r.RegisterProvider("test", prov) _, err := r.SelectProvider(context.Background(), "gpt-4") if err == nil { t.Fatal("expected error") } } func TestSelectProvider_ProviderUnavailable(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("test", prov) // 通过UpdateHealth标记为不可用 r.UpdateHealth("test", false) _, err := r.SelectProvider(context.Background(), "gpt-4") if err == nil { t.Fatal("expected error") } } func TestSelectProvider_WildcardModel(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"*"}, healthy: true} r.RegisterProvider("test", prov) selected, err := r.SelectProvider(context.Background(), "any-model") if err != nil { t.Fatalf("unexpected error: %v", err) } if selected.ProviderName() != "test" { t.Errorf("expected provider test, got %s", selected.ProviderName()) } } func TestSelectProvider_MultipleProviders(t *testing.T) { r := NewRouter(StrategyLatency) prov1 := &mockProvider{name: "fast", models: []string{"gpt-4"}, healthy: true} prov2 := &mockProvider{name: "slow", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("fast", prov1) r.RegisterProvider("slow", prov2) // 记录初始延迟 r.health["fast"].LatencyMs = 10 r.health["slow"].LatencyMs = 100 selected, err := r.SelectProvider(context.Background(), "gpt-4") if err != nil { t.Fatalf("unexpected error: %v", err) } if selected.ProviderName() != "fast" { t.Errorf("expected fastest provider, got %s", selected.ProviderName()) } } func TestRecordResult_Success(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("test", prov) // 初始状态 initialLatency := r.health["test"].LatencyMs r.RecordResult(context.Background(), "test", true, 50) if r.health["test"].LatencyMs == initialLatency { // 首次更新 } if r.health["test"].FailureRate != 0 { t.Errorf("expected failure rate 0, got %f", r.health["test"].FailureRate) } } func TestRecordResult_Failure(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("test", prov) r.RecordResult(context.Background(), "test", false, 100) if r.health["test"].FailureRate == 0 { t.Error("expected failure rate to increase") } } func TestRecordResult_MultipleFailures(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("test", prov) // 多次失败直到失败率超过0.5 // 公式: newRate = oldRate * 0.9 + 0.1 // 需要7次才能超过0.5 (0.469 -> 0.522) for i := 0; i < 7; i++ { r.RecordResult(context.Background(), "test", false, 100) } // 失败率超过0.5应该标记为不可用 if r.health["test"].Available { t.Error("expected provider to be marked unavailable") } } func TestUpdateHealth(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("test", prov) r.UpdateHealth("test", false) if r.health["test"].Available { t.Error("expected provider to be unavailable") } } func TestGetHealthStatus(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("test", prov) status := r.GetHealthStatus() if len(status) != 1 { t.Errorf("expected 1 health status, got %d", len(status)) } health := status["test"] if health == nil { t.Fatal("expected health for test") } if health.Available != true { t.Error("expected available") } } func TestGetHealthStatus_Empty(t *testing.T) { r := NewRouter(StrategyLatency) status := r.GetHealthStatus() if len(status) != 0 { t.Errorf("expected 0 health statuses, got %d", len(status)) } } func TestSelectByLatency_EqualLatency(t *testing.T) { r := NewRouter(StrategyLatency) prov1 := &mockProvider{name: "p1", models: []string{"gpt-4"}, healthy: true} prov2 := &mockProvider{name: "p2", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("p1", prov1) r.RegisterProvider("p2", prov2) // 相同的延迟 r.health["p1"].LatencyMs = 50 r.health["p2"].LatencyMs = 50 selected, err := r.selectByLatency([]string{"p1", "p2"}) if err != nil { t.Fatalf("unexpected error: %v", err) } // 应该返回其中一个 if selected.ProviderName() != "p1" && selected.ProviderName() != "p2" { t.Errorf("unexpected provider: %s", selected.ProviderName()) } } func TestSelectByLatency_NoProviders(t *testing.T) { r := NewRouter(StrategyLatency) _, err := r.selectByLatency([]string{}) if err == nil { t.Fatal("expected error") } } func TestSelectByWeight(t *testing.T) { r := NewRouter(StrategyLatency) prov1 := &mockProvider{name: "p1", models: []string{"gpt-4"}, healthy: true} prov2 := &mockProvider{name: "p2", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("p1", prov1) r.RegisterProvider("p2", prov2) r.health["p1"].Weight = 3.0 r.health["p2"].Weight = 1.0 // 测试能正常返回结果 selected, err := r.selectByWeight([]string{"p1", "p2"}) if err != nil { t.Fatalf("unexpected error: %v", err) } // 应该返回其中一个 if selected.ProviderName() != "p1" && selected.ProviderName() != "p2" { t.Errorf("unexpected provider: %s", selected.ProviderName()) } // 注意:由于实现中randVal = time.Now().UnixNano()/MaxInt64 * totalWeight // 在大多数系统上这个值较小,可能总是选中第一个provider。 // 这是实现的一个已知限制。 } func TestSelectByWeight_SingleProvider(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "p1", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("p1", prov) r.health["p1"].Weight = 2.0 selected, err := r.selectByWeight([]string{"p1"}) if err != nil { t.Fatalf("unexpected error: %v", err) } if selected.ProviderName() != "p1" { t.Errorf("expected p1, got %s", selected.ProviderName()) } } func TestSelectByAvailability(t *testing.T) { r := NewRouter(StrategyLatency) prov1 := &mockProvider{name: "p1", models: []string{"gpt-4"}, healthy: true} prov2 := &mockProvider{name: "p2", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("p1", prov1) r.RegisterProvider("p2", prov2) r.health["p1"].FailureRate = 0.3 r.health["p2"].FailureRate = 0.1 selected, err := r.selectByAvailability([]string{"p1", "p2"}) if err != nil { t.Fatalf("unexpected error: %v", err) } if selected.ProviderName() != "p2" { t.Errorf("expected provider with lower failure rate, got %s", selected.ProviderName()) } } func TestGetFallbackProviders(t *testing.T) { r := NewRouter(StrategyLatency) prov1 := &mockProvider{name: "primary", models: []string{"gpt-4"}, healthy: true} prov2 := &mockProvider{name: "fallback", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("primary", prov1) r.RegisterProvider("fallback", prov2) fallbacks, err := r.GetFallbackProviders(context.Background(), "gpt-4") if err != nil { t.Fatalf("unexpected error: %v", err) } if len(fallbacks) != 1 { t.Errorf("expected 1 fallback, got %d", len(fallbacks)) } if fallbacks[0].ProviderName() != "fallback" { t.Errorf("expected fallback, got %s", fallbacks[0].ProviderName()) } } func TestGetFallbackProviders_AllUnavailable(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "primary", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("primary", prov) fallbacks, err := r.GetFallbackProviders(context.Background(), "gpt-4") if err != nil { t.Fatalf("unexpected error: %v", err) } if len(fallbacks) != 0 { t.Errorf("expected 0 fallbacks, got %d", len(fallbacks)) } } func TestRecordResult_LatencyUpdate(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("test", prov) // 首次记录 r.RecordResult(context.Background(), "test", true, 100) if r.health["test"].LatencyMs != 100 { t.Errorf("expected latency 100, got %d", r.health["test"].LatencyMs) } // 第二次记录,使用指数移动平均 (7/8 * 100 + 1/8 * 200 = 87.5 + 25 = 112.5) r.RecordResult(context.Background(), "test", true, 200) expectedLatency := int64((100*7 + 200) / 8) if r.health["test"].LatencyMs != expectedLatency { t.Errorf("expected latency %d, got %d", expectedLatency, r.health["test"].LatencyMs) } } func TestRecordResult_UnknownProvider(t *testing.T) { r := NewRouter(StrategyLatency) // 不应该panic r.RecordResult(context.Background(), "unknown", true, 100) } func TestUpdateHealth_UnknownProvider(t *testing.T) { r := NewRouter(StrategyLatency) // 不应该panic r.UpdateHealth("unknown", false) } func TestIsProviderAvailable(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-4", "gpt-3.5"}, healthy: true} r.RegisterProvider("test", prov) tests := []struct { model string available bool }{ {"gpt-4", true}, {"gpt-3.5", true}, {"claude", false}, } for _, tt := range tests { if got := r.isProviderAvailable("test", tt.model); got != tt.available { t.Errorf("isProviderAvailable(%s) = %v, want %v", tt.model, got, tt.available) } } } func TestIsProviderAvailable_UnknownProvider(t *testing.T) { r := NewRouter(StrategyLatency) if r.isProviderAvailable("unknown", "gpt-4") { t.Error("expected false for unknown provider") } } func TestIsProviderAvailable_Unhealthy(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("test", prov) // 通过UpdateHealth标记为不可用 r.UpdateHealth("test", false) if r.isProviderAvailable("test", "gpt-4") { t.Error("expected false for unhealthy provider") } } func TestProviderHealth_Struct(t *testing.T) { health := &ProviderHealth{ Name: "test", Available: true, LatencyMs: 50, FailureRate: 0.1, Weight: 1.0, LastCheckTime: time.Now(), } if health.Name != "test" { t.Errorf("expected name test, got %s", health.Name) } if !health.Available { t.Error("expected available") } if health.LatencyMs != 50 { t.Errorf("expected latency 50, got %d", health.LatencyMs) } if health.FailureRate != 0.1 { t.Errorf("expected failure rate 0.1, got %f", health.FailureRate) } if health.Weight != 1.0 { t.Errorf("expected weight 1.0, got %f", health.Weight) } } func TestLoadBalancerStrategy_Constants(t *testing.T) { if StrategyLatency != "latency" { t.Errorf("expected latency, got %s", StrategyLatency) } if StrategyRoundRobin != "round_robin" { t.Errorf("expected round_robin, got %s", StrategyRoundRobin) } if StrategyWeighted != "weighted" { t.Errorf("expected weighted, got %s", StrategyWeighted) } if StrategyAvailability != "availability" { t.Errorf("expected availability, got %s", StrategyAvailability) } } func TestSelectProvider_AllStrategies(t *testing.T) { strategies := []LoadBalancerStrategy{StrategyLatency, StrategyWeighted, StrategyAvailability} for _, strategy := range strategies { r := NewRouter(strategy) prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("test", prov) selected, err := r.SelectProvider(context.Background(), "gpt-4") if err != nil { t.Errorf("strategy %s: unexpected error: %v", strategy, err) } if selected.ProviderName() != "test" { t.Errorf("strategy %s: expected provider test, got %s", strategy, selected.ProviderName()) } } } // 确保FailureRate永远不会超过1.0 func TestRecordResult_FailureRateCapped(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("test", prov) // 多次失败 for i := 0; i < 20; i++ { r.RecordResult(context.Background(), "test", false, 100) } if r.health["test"].FailureRate > 1.0 { t.Errorf("failure rate should be capped at 1.0, got %f", r.health["test"].FailureRate) } } // 确保LatencyMs永远不会变成负数 func TestRecordResult_LatencyNeverNegative(t *testing.T) { r := NewRouter(StrategyLatency) prov := &mockProvider{name: "test", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("test", prov) // 提供负延迟 r.RecordResult(context.Background(), "test", true, -100) if r.health["test"].LatencyMs < 0 { t.Errorf("latency should never be negative, got %d", r.health["test"].LatencyMs) } } // 确保math.MaxInt64不会溢出 func TestSelectByLatency_MaxInt64(t *testing.T) { r := NewRouter(StrategyLatency) prov1 := &mockProvider{name: "p1", models: []string{"gpt-4"}, healthy: true} prov2 := &mockProvider{name: "p2", models: []string{"gpt-4"}, healthy: true} r.RegisterProvider("p1", prov1) r.RegisterProvider("p2", prov2) // p1设置为较大值,p2设置为MaxInt64 r.health["p1"].LatencyMs = math.MaxInt64 - 1 r.health["p2"].LatencyMs = math.MaxInt64 selected, err := r.selectByLatency([]string{"p1", "p2"}) if err != nil { t.Fatalf("unexpected error: %v", err) } // p1的延迟更低,应该被选中 if selected.ProviderName() != "p1" { t.Errorf("expected provider p1 (lower latency), got %s", selected.ProviderName()) } }