//go:build llm_script && !scripts_pkg // fetch_openrouter.go - OpenRouter 模型数据采集器 v2.0 // Sprint 2 增强版:指数退避重试 + 批量插入 + ProviderMapper + audit_log + 价格变动检测 + slog package main import ( "bufio" "context" "database/sql" "encoding/json" "flag" "fmt" "io" "log/slog" "net/http" "os" "strings" "time" "llm-intelligence/internal/collectors" "llm-intelligence/internal/retry" _ "github.com/lib/pq" ) // Config 采集配置 type Config struct { APIKey string APIURL string OutPath string MaxRetries int TimeoutSec int BatchSize int DBConn string StrictReal bool } // ModelInfo 模型信息(与 collectors 包兼容) type ModelInfo struct { ID string `json:"id"` Name string `json:"name,omitempty"` Created int64 `json:"created,omitempty"` Description string `json:"description,omitempty"` ContextLength int `json:"context_length,omitempty"` Capabilities []string `json:"capabilities,omitempty"` Pricing ModelPricing `json:"pricing,omitempty"` } type ModelPricing struct { Input float64 `json:"input,omitempty"` Output float64 `json:"output,omitempty"` } var ( collectorVersion = "v2.0" logger *slog.Logger ) func init() { logger = slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{ Level: slog.LevelInfo, })) } func main() { cfg := parseArgs() start := time.Now() logger.Info("采集器启动", "collector", "openrouter", "version", collectorVersion, "batch_size", cfg.BatchSize) var runErr error if err := run(cfg); err != nil { logger.Error("采集失败", "error", err, "duration", time.Since(start)) runErr = err } duration := time.Since(start) // 写入采集统计 if cfg.DBConn != "" { if err := recordCollectorStats(cfg.DBConn, runErr, duration); err != nil { logger.Warn("采集统计写入失败", "error", err) } } if runErr != nil { os.Exit(1) } logger.Info("采集完成", "collector", "openrouter", "duration_ms", duration.Milliseconds()) } func parseArgs() Config { loadProjectEnv() apiKey := flag.String("api-key", "", "OpenRouter API Key") apiURL := flag.String("api-url", "https://openrouter.ai/api/v1/models", "API 地址") outPath := flag.String("out", "models.json", "输出文件路径") maxRetries := flag.Int("retry", 3, "最大重试次数") timeoutSec := flag.Int("timeout", 30, "请求超时(秒)") batchSize := flag.Int("batch", 100, "批量插入批次大小") dbConn := flag.String("db", os.Getenv("DATABASE_URL"), "PostgreSQL 连接字符串") strictReal := flag.Bool("strict-real", false, "严格真实模式:缺少 API Key 或数据库写入失败时返回错误") flag.Parse() return Config{ APIKey: *apiKey, APIURL: *apiURL, OutPath: *outPath, MaxRetries: *maxRetries, TimeoutSec: *timeoutSec, BatchSize: *batchSize, DBConn: *dbConn, StrictReal: *strictReal, } } func loadProjectEnv() { for _, path := range []string{".env.local", ".env"} { loadEnvFile(path) } } func loadEnvFile(path string) { f, err := os.Open(path) if err != nil { return } defer f.Close() scanner := bufio.NewScanner(f) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if line == "" || strings.HasPrefix(line, "#") { continue } key, value, ok := strings.Cut(line, "=") if !ok { continue } key = strings.TrimSpace(key) value = strings.TrimSpace(value) value = strings.Trim(value, `"'`) if key == "" { continue } if _, exists := os.LookupEnv(key); exists { continue } _ = os.Setenv(key, value) } } func run(cfg Config) error { models, err := fetchModels(cfg) if err != nil { return err } logger.Info("API 数据获取完成", "records", len(models)) if cfg.DBConn != "" { if err := summarizeDB(cfg.DBConn, models, cfg.BatchSize); err != nil { logger.Error("PostgreSQL 写入失败", "error", err) if cfg.StrictReal { return fmt.Errorf("PostgreSQL 写入失败: %w", err) } logger.Warn("降级为仅写入 JSON") } else { logger.Info("PostgreSQL 写入完成", "records", len(models)) } } return summarize(cfg.OutPath, models) } // fetchModels 抓取 OpenRouter 模型列表(集成指数退避重试) func fetchModels(cfg Config) ([]ModelInfo, error) { if cfg.APIKey == "" { if cfg.StrictReal { return nil, fmt.Errorf("严格真实模式下必须提供 API Key") } logger.Warn("未提供 API Key,使用模拟数据") return []ModelInfo{ {ID: "openai/gpt-4o", ContextLength: 128000, Pricing: ModelPricing{Input: 2.5, Output: 10.0}}, {ID: "anthropic/claude-3.5-sonnet:free", ContextLength: 200000, Pricing: ModelPricing{}}, }, nil } strategy := retry.Strategy{ MaxRetries: cfg.MaxRetries, BaseDelay: 1 * time.Second, MaxDelay: 30 * time.Second, Multiplier: 2.0, Jitter: true, Retryable: retry.IsRetryable, } var models []ModelInfo var lastErr error err := retry.Do(context.Background(), strategy, func() error { client := &http.Client{Timeout: time.Duration(cfg.TimeoutSec) * time.Second} req, err := http.NewRequest("GET", cfg.APIURL, nil) if err != nil { return fmt.Errorf("构造请求失败: %w", err) } req.Header.Set("Authorization", "Bearer "+cfg.APIKey) req.Header.Set("Content-Type", "application/json") resp, err := client.Do(req) if err != nil { lastErr = err return fmt.Errorf("请求失败: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, _ := io.ReadAll(resp.Body) lastErr = retry.HTTPStatusError{StatusCode: resp.StatusCode, Body: string(body)} return lastErr } body, err := io.ReadAll(resp.Body) if err != nil { lastErr = err return fmt.Errorf("读取响应失败: %w", err) } models, err = parseModels(body) if err != nil { lastErr = err return fmt.Errorf("JSON 解析失败: %w", err) } return nil }) if err != nil { return nil, fmt.Errorf("采集失败(%d次尝试): %w", strategy.MaxRetries+1, lastErr) } return models, nil } func parseModels(raw []byte) ([]ModelInfo, error) { var wrapper struct { Data json.RawMessage `json:"data"` } if err := json.Unmarshal(raw, &wrapper); err != nil { return nil, fmt.Errorf("解析 data 字段失败: %w", err) } var rawItems []any if err := json.Unmarshal(wrapper.Data, &rawItems); err != nil { return nil, fmt.Errorf("解析模型数组失败: %w", err) } models := make([]ModelInfo, 0, len(rawItems)) for _, item := range rawItems { m, ok := item.(map[string]any) if !ok { continue } model := ModelInfo{ ID: getString(m, "id"), Name: getString(m, "name"), } if model.ID == "" { continue } if p, ok := m["pricing"].(map[string]any); ok { model.Pricing.Input = getPrice(p, "input", "prompt") model.Pricing.Output = getPrice(p, "output", "completion") } model.ContextLength = getInt(m, "context_length") model.Description = getString(m, "description") model.Created = getInt64(m, "created") if caps, ok := m["capabilities"].([]any); ok { for _, c := range caps { if s, ok := c.(string); ok { model.Capabilities = append(model.Capabilities, s) } } } models = append(models, model) } return models, nil } func deriveModality(model ModelInfo) string { for _, capability := range model.Capabilities { normalized := strings.ToLower(capability) switch { case strings.Contains(normalized, "vision"), strings.Contains(normalized, "image"): return "multimodal" case strings.Contains(normalized, "audio"): return "audio" case strings.Contains(normalized, "video"): return "video" case strings.Contains(normalized, "code"): return "code" } } hints := strings.ToLower(strings.Join([]string{model.ID, model.Name, model.Description}, " ")) switch { case strings.Contains(hints, "video") && (strings.Contains(hints, "omni") || strings.Contains(hints, "vision") || strings.Contains(hints, "multimodal")): return "multimodal" case strings.Contains(hints, "vision") || strings.Contains(hints, "image") || strings.Contains(hints, "vl") || strings.Contains(hints, "omni") || strings.Contains(hints, "multimodal"): return "multimodal" case strings.Contains(hints, "audio") || strings.Contains(hints, "speech") || strings.Contains(hints, "voice"): return "audio" case strings.Contains(hints, "video"): return "video" case strings.Contains(hints, "code"): return "code" default: return "text" } } func getString(m map[string]any, key string) string { if v, ok := m[key].(string); ok { return v } return "" } func getInt(m map[string]any, key string) int { if v, ok := m[key].(float64); ok { return int(v) } return 0 } func getInt64(m map[string]any, key string) int64 { if v, ok := m[key].(float64); ok { return int64(v) } return 0 } func getPrice(m map[string]any, keys ...string) float64 { for _, k := range keys { if v, ok := m[k].(float64); ok { return v } } return 0 } func summarize(outPath string, models []ModelInfo) error { return writeJSON(outPath, models) } // summarizeDB 将采集结果写入 PostgreSQL(批量插入 + ProviderMapper + 价格变动检测 + audit_log) func summarizeDB(connStr string, models []ModelInfo, batchSize int) error { db, err := sql.Open("postgres", connStr) if err != nil { return fmt.Errorf("连接数据库失败: %w", err) } defer db.Close() if err := db.Ping(); err != nil { return fmt.Errorf("ping 数据库失败: %w", err) } batchID := fmt.Sprintf("batch-%d", time.Now().Unix()) now := time.Now() effectiveDate := now.Format("2006-01-02") // 获取默认 operator(OpenRouter) var operatorID int64 err = db.QueryRow("SELECT id FROM operator WHERE name = 'OpenRouter' LIMIT 1").Scan(&operatorID) if err != nil { logger.Warn("未找到 OpenRouter operator,使用 NULL", "error", err) operatorID = 0 } // 获取上次价格数据(用于变动检测) lastPrices := make(map[int64]ModelPricing) rows, err := db.Query(` SELECT model_id, input_price_per_mtok, output_price_per_mtok FROM region_pricing WHERE operator_id = $1 AND effective_date = ( SELECT MAX(effective_date) FROM region_pricing WHERE operator_id = $1 ) `, operatorID) if err == nil { for rows.Next() { var mid int64 var p ModelPricing if err := rows.Scan(&mid, &p.Input, &p.Output); err == nil { lastPrices[mid] = p } } rows.Close() } insertedModels := 0 insertedPrices := 0 priceChanges := 0 // 批量处理 for i := 0; i < len(models); i += batchSize { end := i + batchSize if end > len(models) { end = len(models) } batch := models[i:end] tx, err := db.Begin() if err != nil { return fmt.Errorf("开启事务失败: %w", err) } for _, m := range batch { // 使用 ProviderMapper 映射厂商 mapping, err := collectors.MapOpenRouterID(m.ID) if err != nil { logger.Warn("Provider 映射失败", "id", m.ID, "error", err) mapping = collectors.ModelMapping{ Provider: collectors.ProviderInfo{ID: "unknown", Name: "Unknown"}, ModelName: m.Name, RawID: m.ID, IsFree: false, } } // 查找或创建 provider_id var providerID int64 err = tx.QueryRow("SELECT id FROM model_provider WHERE name = $1 LIMIT 1", mapping.Provider.Name).Scan(&providerID) if err != nil { // 未知厂商,插入 err = tx.QueryRow(` INSERT INTO model_provider (name, name_cn, country, status) VALUES ($1, $2, $3, 'active') ON CONFLICT (name) DO UPDATE SET name = EXCLUDED.name RETURNING id `, mapping.Provider.Name, mapping.Provider.NameCN, mapping.Provider.Country).Scan(&providerID) if err != nil { logger.Warn("创建 provider 失败", "name", mapping.Provider.Name, "error", err) providerID = 0 } } isFree := mapping.IsFree || (m.Pricing.Input == 0 && m.Pricing.Output == 0) // upsert models 表(带新字段) var modelID int64 err = tx.QueryRow(` INSERT INTO models ( source, external_id, name, description, context_length, capabilities, created_at_source, is_free, status, raw_payload, provider_id, version, modality, data_confidence, retrieved_at, batch_id, collector_version, source_url, created_at, updated_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $19) ON CONFLICT (external_id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description, context_length = EXCLUDED.context_length, capabilities = EXCLUDED.capabilities, created_at_source = EXCLUDED.created_at_source, is_free = EXCLUDED.is_free, status = EXCLUDED.status, raw_payload = EXCLUDED.raw_payload, provider_id = EXCLUDED.provider_id, data_confidence = 'official', retrieved_at = EXCLUDED.retrieved_at, batch_id = EXCLUDED.batch_id, collector_version = EXCLUDED.collector_version, updated_at = EXCLUDED.updated_at RETURNING id `, "openrouter", m.ID, m.Name, m.Description, m.ContextLength, jsonCapabilities(m.Capabilities), m.Created, isFree, "active", rawPayload(m), providerID, "", deriveModality(m), "official", now, batchID, collectorVersion, "https://openrouter.ai/api/v1/models", now).Scan(&modelID) if err != nil { _ = tx.Rollback() return fmt.Errorf("写入 models 失败 (%s): %w", m.ID, err) } insertedModels++ // 写入 audit_log _, _ = tx.Exec(` INSERT INTO audit_log (table_name, record_id, field_name, old_value, new_value, operation, operator, batch_id, source_url) VALUES ('models', $1, 'external_id', NULL, $2, 'INSERT', 'fetch_openrouter', $3, $4) `, modelID, m.ID, batchID, "https://openrouter.ai/api/v1/models") // upsert region_pricing 表(替代 model_prices) sourceType := "reseller" freeQuota := "" freeLimitations := "[]" rateLimit := "{}" if isFree { sourceType = "free_tier" freeQuota = "Imported free-tier pricing entry" freeLimitations = `["See source_url for current quota and policy"]` } var pricingID int64 err = tx.QueryRow(` INSERT INTO region_pricing ( model_id, operator_id, region, currency, input_price_per_mtok, output_price_per_mtok, is_free, effective_date, source_url, source_type, free_quota, free_limitations, rate_limit, created_at, updated_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $14) ON CONFLICT (model_id, operator_id, region, currency, effective_date) DO UPDATE SET input_price_per_mtok = EXCLUDED.input_price_per_mtok, output_price_per_mtok = EXCLUDED.output_price_per_mtok, is_free = EXCLUDED.is_free, source_type = EXCLUDED.source_type, free_quota = EXCLUDED.free_quota, free_limitations = EXCLUDED.free_limitations, rate_limit = EXCLUDED.rate_limit, updated_at = EXCLUDED.updated_at RETURNING id `, modelID, operatorID, "global", "USD", m.Pricing.Input, m.Pricing.Output, isFree, effectiveDate, "https://openrouter.ai/api/v1/models", sourceType, freeQuota, freeLimitations, rateLimit, now).Scan(&pricingID) if err != nil { _ = tx.Rollback() return fmt.Errorf("写入 region_pricing 失败 (%s): %w", m.ID, err) } insertedPrices++ // 价格变动检测(>5%) if lastPrice, ok := lastPrices[modelID]; ok { inputChange := calcChangePercent(lastPrice.Input, m.Pricing.Input) outputChange := calcChangePercent(lastPrice.Output, m.Pricing.Output) if abs(inputChange) > 5 || abs(outputChange) > 5 { _, _ = tx.Exec(` INSERT INTO pricing_history ( model_id, region, currency, old_input_price, new_input_price, old_output_price, new_output_price, change_percent, changed_at ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) `, modelID, "global", "USD", lastPrice.Input, m.Pricing.Input, lastPrice.Output, m.Pricing.Output, max(abs(inputChange), abs(outputChange)), now) priceChanges++ } } } if err := tx.Commit(); err != nil { return fmt.Errorf("提交事务失败: %w", err) } logger.Info("批次完成", "batch", i/batchSize+1, "records", len(batch)) } logger.Info("PostgreSQL 写入完成", "models", insertedModels, "prices", insertedPrices, "price_changes", priceChanges, "batch_id", batchID) return nil } func calcChangePercent(old, new float64) float64 { if old == 0 { if new == 0 { return 0 } return 100 } return ((new - old) / old) * 100 } func abs(v float64) float64 { if v < 0 { return -v } return v } func max(a, b float64) float64 { if a > b { return a } return b } func jsonCapabilities(caps []string) []byte { if len(caps) == 0 { return []byte("[]") } b, _ := json.Marshal(caps) return b } func rawPayload(m ModelInfo) []byte { b, _ := json.Marshal(m) return b } func writeJSON(outPath string, models []ModelInfo) error { total := len(models) var freeCnt, paidCnt int for _, m := range models { if len(m.ID) > 5 && m.ID[len(m.ID)-5:] == ":free" { freeCnt++ } else if m.Pricing.Input > 0 || m.Pricing.Output > 0 { paidCnt++ } } summary := fmt.Sprintf("采集完成: 共 %d 模型(免费 %d / 付费 %d)\n", total, freeCnt, paidCnt) fmt.Print(summary) out, err := os.Create(outPath) if err != nil { return fmt.Errorf("创建输出文件失败: %w", err) } defer out.Close() enc := json.NewEncoder(out) enc.SetIndent("", " ") if err := enc.Encode(map[string]any{ "generated_at": time.Now().Format(time.RFC3339), "total": total, "free": freeCnt, "paid": paidCnt, "models": models, }); err != nil { return fmt.Errorf("写入 JSON 失败: %w", err) } fmt.Printf("结果已写入: %s\n", outPath) return nil } // recordCollectorStats 记录采集统计到 collector_stats 表 func recordCollectorStats(connStr string, runErr error, duration time.Duration) error { db, err := sql.Open("postgres", connStr) if err != nil { return err } defer db.Close() success := runErr == nil errMsg := "" if runErr != nil { errMsg = runErr.Error() } _, err = db.Exec(` INSERT INTO collector_stats (source, batch_id, success, duration_ms, error_message, created_at) VALUES ('openrouter', $1, $2, $3, $4, $5) `, fmt.Sprintf("batch-%d", time.Now().Unix()), success, int(duration.Milliseconds()), errMsg, time.Now()) return err }