This commit is contained in:
108
scripts/fetch_multi_source_test.go
Normal file
108
scripts/fetch_multi_source_test.go
Normal file
@@ -0,0 +1,108 @@
|
||||
//go:build llm_script
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fakeSource struct {
|
||||
name string
|
||||
prices []ModelPricing
|
||||
err error
|
||||
}
|
||||
|
||||
func (s fakeSource) Name() string { return s.name }
|
||||
|
||||
func (s fakeSource) FetchPricing() ([]ModelPricing, error) { return s.prices, s.err }
|
||||
|
||||
func (s fakeSource) SourceType() string { return "official" }
|
||||
|
||||
func TestBuildSourcesFiltersRequestedNames(t *testing.T) {
|
||||
sources, err := buildSources("", []string{"moonshot", "openai"})
|
||||
if err != nil {
|
||||
t.Fatalf("buildSources returned error: %v", err)
|
||||
}
|
||||
|
||||
if len(sources) != 2 {
|
||||
t.Fatalf("expected 2 sources, got %d", len(sources))
|
||||
}
|
||||
|
||||
if sources[0].Name() != "Moonshot" || sources[1].Name() != "OpenAI" {
|
||||
t.Fatalf("unexpected source order: %s, %s", sources[0].Name(), sources[1].Name())
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSourcesRejectsUnknownNames(t *testing.T) {
|
||||
_, err := buildSources("", []string{"moonshot", "unknown"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unknown source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRunCollectorDryRunSkipsDatabaseWrite(t *testing.T) {
|
||||
cfg := runConfig{DryRun: true}
|
||||
var out bytes.Buffer
|
||||
writeCalled := false
|
||||
|
||||
err := runCollector(
|
||||
cfg,
|
||||
[]DataSource{
|
||||
fakeSource{
|
||||
name: "Moonshot",
|
||||
prices: []ModelPricing{
|
||||
{ModelID: "kimi-k2.6", ProviderCountry: "CN", Currency: "CNY"},
|
||||
{ModelID: "kimi-k2-0905-preview", ProviderCountry: "CN", Currency: "CNY"},
|
||||
},
|
||||
},
|
||||
fakeSource{
|
||||
name: "OpenAI",
|
||||
prices: []ModelPricing{
|
||||
{ModelID: "gpt-5.5", ProviderCountry: "US", Currency: "USD"},
|
||||
},
|
||||
},
|
||||
},
|
||||
func([]ModelPricing) error {
|
||||
writeCalled = true
|
||||
return nil
|
||||
},
|
||||
&out,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("runCollector returned error: %v", err)
|
||||
}
|
||||
|
||||
if writeCalled {
|
||||
t.Fatal("expected dry-run to skip database write")
|
||||
}
|
||||
|
||||
output := out.String()
|
||||
if output == "" {
|
||||
t.Fatal("expected dry-run summary output")
|
||||
}
|
||||
if !bytes.Contains(out.Bytes(), []byte("sources=2")) {
|
||||
t.Fatalf("expected sources summary, got %q", output)
|
||||
}
|
||||
if !bytes.Contains(out.Bytes(), []byte("models=3")) {
|
||||
t.Fatalf("expected model summary, got %q", output)
|
||||
}
|
||||
if !bytes.Contains(out.Bytes(), []byte("currencies=CNY:2,USD:1")) {
|
||||
t.Fatalf("expected currency summary, got %q", output)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPricingMetadataClassifiesSourceType(t *testing.T) {
|
||||
freeTier := pricingMetadata(ModelPricing{OperatorType: "official", IsFree: true})
|
||||
if freeTier.SourceType != "free_tier" {
|
||||
t.Fatalf("expected free_tier, got %q", freeTier.SourceType)
|
||||
}
|
||||
if freeTier.FreeQuota == "" {
|
||||
t.Fatal("expected free tier quota description")
|
||||
}
|
||||
|
||||
reseller := pricingMetadata(ModelPricing{OperatorType: "reseller"})
|
||||
if reseller.SourceType != "reseller" {
|
||||
t.Fatalf("expected reseller, got %q", reseller.SourceType)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user