//go:build llm_script package main import ( "encoding/json" "net/http" "net/http/httptest" "os" "path/filepath" "testing" "llm-intelligence/internal/retry" ) // Test 1: parseModels 正确解析 name、context_length、capabilities、pricing input/prompt 和 output/completion func TestParseModels(t *testing.T) { // 从样例文件读取,而非内联 JSON samplePath := filepath.Join("testdata", "openrouter_models_sample.json") raw, err := os.ReadFile(samplePath) if err != nil { t.Fatalf("读取样例文件失败: %v", err) } models, err := parseModels(raw) if err != nil { t.Fatalf("parseModels 失败: %v", err) } if len(models) != 3 { t.Fatalf("期望 3 条,实际 %d", len(models)) } // 第一条:完整字段 m := models[0] if m.ID != "openai/gpt-4o" { t.Errorf("ID 错误: %s", m.ID) } if m.Name != "GPT-4o" { t.Errorf("Name 错误: %s", m.Name) } if m.ContextLength != 128000 { t.Errorf("ContextLength 错误: %d", m.ContextLength) } if len(m.Capabilities) != 3 { t.Errorf("Capabilities 长度错误: %d", len(m.Capabilities)) } if m.Pricing.Input != 2.5 { t.Errorf("Pricing.Input 错误: %f", m.Pricing.Input) } if m.Pricing.Output != 10.0 { t.Errorf("Pricing.Output 错误: %f", m.Pricing.Output) } if modality := deriveModality(m); modality != "multimodal" { t.Errorf("deriveModality = %q, want %q", modality, "multimodal") } // 第二条:pricing 用 prompt/completion 别名回退 m2 := models[1] if m2.Pricing.Input != 0.1 { t.Errorf("Input 回退 prompt 失败: %f", m2.Pricing.Input) } if m2.Pricing.Output != 0.3 { t.Errorf("Output 回退 completion 失败: %f", m2.Pricing.Output) } // 第三条:空 pricing m3 := models[2] if m3.Pricing.Input != 0 || m3.Pricing.Output != 0 { t.Errorf("空 pricing 未返回 0: input=%f output=%f", m3.Pricing.Input, m3.Pricing.Output) } } func TestDeriveModality(t *testing.T) { tests := []struct { name string capabilities []string want string }{ {name: "vision first", capabilities: []string{"vision", "json_mode"}, want: "multimodal"}, {name: "audio", capabilities: []string{"audio_generation"}, want: "audio"}, {name: "code", capabilities: []string{"code_interpreter"}, want: "code"}, {name: "text fallback", capabilities: []string{"function_calling"}, want: "text"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := deriveModality(ModelInfo{Capabilities: tt.capabilities}); got != tt.want { t.Fatalf("deriveModality() = %q, want %q", got, tt.want) } }) } } func TestDeriveModalityInfersFromModelIdentityWithoutCapabilities(t *testing.T) { tests := []struct { name string model ModelInfo want string }{ { name: "omni id maps to multimodal", model: ModelInfo{ ID: "nvidia/nemotron-3-nano-omni-30b-a3b-reasoning:free", Description: "accepts text, image, video, and audio inputs", }, want: "multimodal", }, { name: "audio id maps to audio", model: ModelInfo{ ID: "openai/gpt-audio", Description: "audio model for natural sounding voices", }, want: "audio", }, { name: "vl id maps to multimodal", model: ModelInfo{ ID: "qwen/qwen3-vl-32b-instruct", Description: "vision-language model for text, images, and video", }, want: "multimodal", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if got := deriveModality(tt.model); got != tt.want { t.Fatalf("deriveModality(%+v) = %q, want %q", tt.model, got, tt.want) } }) } } // Test 2: run 无 API Key 时写入临时文件,JSON 含 total 和 models 字段 func TestRunNoAPIKey(t *testing.T) { tmpDir := t.TempDir() outPath := filepath.Join(tmpDir, "models.json") cfg := Config{OutPath: outPath} err := run(cfg) if err != nil { t.Fatalf("run 失败: %v", err) } data, err := os.ReadFile(outPath) if err != nil { t.Fatalf("读取输出文件失败: %v", err) } var result map[string]any if err := json.Unmarshal(data, &result); err != nil { t.Fatalf("JSON 解析失败: %v", err) } if _, ok := result["total"]; !ok { t.Error("JSON 缺少 total 字段") } if _, ok := result["models"]; !ok { t.Error("JSON 缺少 models 字段") } models, ok := result["models"].([]any) if !ok { t.Fatal("models 字段类型错误") } if len(models) == 0 { t.Error("models 为空") } } func TestFetchModelsFailsInStrictRealModeWithoutAPIKey(t *testing.T) { _, err := fetchModels(Config{StrictReal: true}) if err == nil { t.Fatal("strict real mode should fail without API key") } } func TestFetchModelsDoesNotRetryPermanentHTTPErrors(t *testing.T) { attempts := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ http.Error(w, "forbidden", http.StatusForbidden) })) defer server.Close() _, err := fetchModels(Config{ APIKey: "test-key", APIURL: server.URL, MaxRetries: 3, TimeoutSec: 1, StrictReal: true, }) if err == nil { t.Fatal("expected fetchModels to fail on 403") } if attempts != 1 { t.Fatalf("expected 1 attempt for permanent HTTP error, got %d", attempts) } } func TestFetchModelsRetriesServerErrors(t *testing.T) { attempts := 0 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { attempts++ if attempts < 3 { http.Error(w, "temporary", http.StatusBadGateway) return } w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"data":[{"id":"openai/gpt-4o","name":"GPT-4o","context_length":128000,"pricing":{"input":2.5,"output":10.0}}]}`)) })) defer server.Close() models, err := fetchModels(Config{ APIKey: "test-key", APIURL: server.URL, MaxRetries: 3, TimeoutSec: 1, StrictReal: true, }) if err != nil { t.Fatalf("expected retry success, got %v", err) } if len(models) != 1 { t.Fatalf("expected 1 model, got %d", len(models)) } if attempts != 3 { t.Fatalf("expected 3 attempts for temporary server error, got %d", attempts) } } func TestRunFailsInStrictRealModeWhenDBWriteFails(t *testing.T) { tmpDir := t.TempDir() outPath := filepath.Join(tmpDir, "models.json") server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") _, _ = w.Write([]byte(`{"data":[{"id":"openai/gpt-4o","name":"GPT-4o","context_length":128000,"pricing":{"input":2.5,"output":10.0}}]}`)) })) defer server.Close() err := run(Config{ APIKey: "test-key", APIURL: server.URL, OutPath: outPath, DBConn: "postgres://invalid@127.0.0.1:1/invalid?sslmode=disable", BatchSize: 10, TimeoutSec: 1, StrictReal: true, }) if err == nil { t.Fatal("strict real mode should fail when database write fails") } } func TestRetryHTTPStatusErrorClassification(t *testing.T) { if retry.IsRetryable(retry.HTTPStatusError{StatusCode: http.StatusForbidden}) { t.Fatal("403 should not be retryable") } if !retry.IsRetryable(retry.HTTPStatusError{StatusCode: http.StatusBadGateway}) { t.Fatal("502 should be retryable") } }