Files
llm-intelligence/scripts/import_baichuan_pricing.go

189 lines
5.8 KiB
Go

//go:build llm_script
package main
import (
"database/sql"
"flag"
"fmt"
"io"
"net/http"
"os"
"regexp"
"sort"
"strings"
"time"
)
// defaultBaichuanPricingURL is the pricing page used as the default value of
// the -url flag and recorded as the source URL on every parsed record.
const defaultBaichuanPricingURL = "https://platform.baichuan-ai.com/prices"
// baichuanPricingImportConfig bundles the command-line options for a single
// import run.
type baichuanPricingImportConfig struct {
	// URL is the pricing page address handed to the page fetcher.
	URL string
	// Fixture is an optional sample file, handed to the fetcher alongside URL.
	Fixture string
	// DryRun, when true, only parses and prints a summary; nothing is written
	// to the database.
	DryRun bool
	// Timeout is the timeout configured on the HTTP client.
	Timeout time.Duration
}
// baichuanPricingRow is the intermediate result of parsing one model entry
// from the pricing page, before it is expanded into an officialPricingRecord.
type baichuanPricingRow struct {
	// Index is the position of the source chunk within the split section.
	Index int
	// ModelName is the model identifier captured from the page.
	ModelName string
	// ContextLength is the parsed context window size.
	ContextLength int
	// InputPrice is the input price after per-1K to per-1M token conversion.
	InputPrice float64
	// OutputPrice is the output price after per-1K to per-1M token conversion.
	OutputPrice float64
}
// Patterns applied to the whitespace-normalized text of the pricing page.
var (
	// baichuanModelContextPattern captures the model name (group 1) and its
	// context size such as "32k" (group 2) from a "模型调用" entry.
	baichuanModelContextPattern = regexp.MustCompile(`模型调用\s+(Baichuan[-A-Za-z0-9]+)\s+([0-9]+k)`)

	// baichuanPairPricePattern captures separate input (group 1) and output
	// (group 2) prices, both quoted per thousand tokens.
	baichuanPairPricePattern = regexp.MustCompile(`输入:([0-9.]+)元/千tokens\s+输出:([0-9.]+)元/千tokens`)

	// baichuanFlatPricePattern captures a single per-thousand-token price
	// (group 1) that follows a 00:00~24:00 or 00:00~8:00 time window.
	baichuanFlatPricePattern = regexp.MustCompile(`(?:00:00\s*~\s*24:00|00:00\s*~\s*8:00)\s+([0-9.]+)元/千tokens`)
)
// main parses the command-line flags, opens the database unless running in
// dry-run mode, and runs the Baichuan pricing import. Any failure is reported
// on stderr and the process exits with status 1.
func main() {
	loadSubscriptionImportEnv()

	var (
		pricingURL     string
		fixture        string
		dryRun         bool
		timeoutSeconds int
	)
	flag.StringVar(&pricingURL, "url", defaultBaichuanPricingURL, "百川官方价格页")
	flag.StringVar(&fixture, "fixture", "", "百川价格样例文件")
	flag.BoolVar(&dryRun, "dry-run", false, "仅解析并打印摘要,不写入数据库")
	flag.IntVar(&timeoutSeconds, "timeout", 20, "请求超时(秒)")
	flag.Parse()

	cfg := baichuanPricingImportConfig{
		URL:     pricingURL,
		Fixture: fixture,
		DryRun:  dryRun,
		Timeout: time.Duration(timeoutSeconds) * time.Second,
	}

	// The database is only needed when records are actually written.
	var db *sql.DB
	if !cfg.DryRun {
		var err error
		db, err = subscriptionImportDB()
		if err != nil {
			fmt.Fprintf(os.Stderr, "open db: %v\n", err)
			os.Exit(1)
		}
	}

	err := runBaichuanPricingImport(cfg, db, os.Stdout)

	// os.Exit does not run deferred functions, so the handle is closed
	// explicitly before the exit status is decided; a deferred Close would be
	// skipped on the failure path.
	if db != nil {
		if closeErr := db.Close(); closeErr != nil {
			fmt.Fprintf(os.Stderr, "close db: %v\n", closeErr)
		}
	}

	if err != nil {
		fmt.Fprintf(os.Stderr, "import_baichuan_pricing: %v\n", err)
		os.Exit(1)
	}
}
// runBaichuanPricingImport fetches the pricing page (or fixture), parses it
// into pricing records, de-duplicates them, and then either prints a dry-run
// summary or upserts the records and prints a summary that includes the
// current row count of the region_pricing table.
//
// cfg supplies the URL, fixture, dry-run flag and HTTP timeout. db may be nil
// only when cfg.DryRun is true. The one-line summary is written to out.
// The first error encountered is returned.
func runBaichuanPricingImport(cfg baichuanPricingImportConfig, db *sql.DB, out io.Writer) error {
	// Label used both as the upsert source tag and in the printed summary.
	const source = "baichuan-pricing-import"

	client := &http.Client{Timeout: cfg.Timeout}
	raw, err := fetchRawPricingPage(cfg.URL, cfg.Fixture, client)
	if err != nil {
		return err
	}

	records, err := parseBaichuanPricingCatalog(raw)
	if err != nil {
		return err
	}
	records = dedupeOfficialPricingRecords(records)

	// The summary lines read records[0]; fail with an error rather than an
	// index-out-of-range panic if nothing survived de-duplication.
	if len(records) == 0 {
		return fmt.Errorf("no baichuan pricing records to import")
	}
	operator := records[0].OperatorName

	if cfg.DryRun {
		_, err = fmt.Fprintf(out, "source=%s models=%d operator=%s dry_run=true\n", source, len(records), operator)
		return err
	}

	if db == nil {
		return fmt.Errorf("db is required when dry-run=false")
	}
	if err := upsertOfficialPricingRecords(db, records, source); err != nil {
		return err
	}

	var tableRows int
	if err := db.QueryRow(`SELECT COUNT(*) FROM region_pricing`).Scan(&tableRows); err != nil {
		return fmt.Errorf("count region_pricing: %w", err)
	}

	_, err = fmt.Fprintf(out, "source=%s models=%d operator=%s table_rows=%d dry_run=false\n", source, len(records), operator, tableRows)
	return err
}
// baichuanWhitespacePattern matches runs of whitespace. It is compiled once
// at package scope instead of on every parse call.
var baichuanWhitespacePattern = regexp.MustCompile(`\s+`)

// parseBaichuanPricingCatalog converts the raw pricing page into official
// pricing records.
//
// The page text is cleaned and whitespace-normalized, the span between the
// "通用大模型" and "搜索增强服务" headings is isolated, and that span is split
// on "模型调用 " into per-model chunks. Chunks that cannot be parsed are
// skipped. An error is returned when either heading is missing or when no
// chunk yields a row.
func parseBaichuanPricingCatalog(raw string) ([]officialPricingRecord, error) {
	section, err := baichuanGeneralModelSection(cleanHTMLText(raw))
	if err != nil {
		return nil, err
	}

	chunks := strings.Split(section, "模型调用 ")
	rows := make([]baichuanPricingRow, 0, len(chunks))
	for idx, chunk := range chunks {
		if row, ok := parseBaichuanPricingChunk(idx, chunk); ok {
			rows = append(rows, row)
		}
	}
	if len(rows) == 0 {
		return nil, fmt.Errorf("unexpected baichuan pricing content: no model rows parsed")
	}

	// Rows are appended in chunk order, so they are already sorted by Index;
	// the sort keeps that ordering explicit and guards against future changes
	// to how rows are collected.
	sort.Slice(rows, func(i, j int) bool { return rows[i].Index < rows[j].Index })

	providerNameCn, providerCountry, providerWebsite := providerMetadata("Baichuan")
	records := make([]officialPricingRecord, 0, len(rows))
	for _, row := range rows {
		records = append(records, officialPricingRecord{
			ModelID:         normalizeExternalID("baichuan", row.ModelName),
			ModelName:       row.ModelName,
			ProviderName:    "Baichuan",
			ProviderNameCn:  providerNameCn,
			ProviderCountry: providerCountry,
			ProviderWebsite: providerWebsite,
			OperatorName:    "Baichuan API",
			OperatorNameCn:  "百川开放平台",
			OperatorCountry: "CN",
			OperatorWebsite: "https://platform.baichuan-ai.com/docs",
			OperatorType:    "official",
			Region:          "CN",
			Currency:        "CNY",
			InputPrice:      row.InputPrice,
			OutputPrice:     row.OutputPrice,
			ContextLength:   row.ContextLength,
			SourceURL:       defaultBaichuanPricingURL,
			ModelSourceURL:  defaultBaichuanPricingURL,
			DateConfidence:  "unknown",
			DateSourceKind:  "official_pricing",
			Modality:        detectModality(row.ModelName),
		})
	}
	return records, nil
}

// baichuanGeneralModelSection collapses whitespace in text and returns the
// span that starts at "通用大模型" and ends just before the first following
// "搜索增强服务". It returns an error when either marker is missing.
func baichuanGeneralModelSection(text string) (string, error) {
	// Collapsing every whitespace run (newlines included) to one space lets
	// the price patterns match across what were originally line breaks.
	text = baichuanWhitespacePattern.ReplaceAllString(text, " ")
	text = strings.TrimSpace(text)

	start := strings.Index(text, "通用大模型")
	if start == -1 {
		return "", fmt.Errorf("unexpected baichuan pricing content: missing 通用大模型")
	}
	text = text[start:]

	end := strings.Index(text, "搜索增强服务")
	if end == -1 {
		return "", fmt.Errorf("unexpected baichuan pricing content: missing 搜索增强服务")
	}
	return text[:end], nil
}

// parseBaichuanPricingChunk parses one "模型调用" chunk into a pricing row.
// The second result is false when the chunk is empty, describes a
// Baichuan-Text-Embedding model, lacks a model name or usable context length,
// or has no recognizable price.
func parseBaichuanPricingChunk(idx int, chunk string) (baichuanPricingRow, bool) {
	chunk = strings.TrimSpace(chunk)
	if chunk == "" {
		return baichuanPricingRow{}, false
	}
	// Restore the separator removed by strings.Split so the model pattern,
	// which is anchored on "模型调用", can match.
	chunk = "模型调用 " + chunk
	if strings.Contains(chunk, "Baichuan-Text-Embedding") {
		return baichuanPricingRow{}, false
	}

	meta := baichuanModelContextPattern.FindStringSubmatch(chunk)
	if len(meta) != 3 {
		return baichuanPricingRow{}, false
	}
	contextLength := parseContextLengthCommon(meta[2])
	if contextLength == 0 {
		return baichuanPricingRow{}, false
	}

	row := baichuanPricingRow{
		Index:         idx,
		ModelName:     strings.TrimSpace(meta[1]),
		ContextLength: contextLength,
	}

	// Separate input/output prices take precedence over a single flat price.
	if pair := baichuanPairPricePattern.FindStringSubmatch(chunk); len(pair) == 3 {
		row.InputPrice = baichuanPerKTokenToPerMToken(pair[1])
		row.OutputPrice = baichuanPerKTokenToPerMToken(pair[2])
		return row, true
	}
	if flat := baichuanFlatPricePattern.FindStringSubmatch(chunk); len(flat) == 2 {
		price := baichuanPerKTokenToPerMToken(flat[1])
		row.InputPrice = price
		row.OutputPrice = price
		return row, true
	}
	return baichuanPricingRow{}, false
}
// baichuanPerKTokenToPerMToken parses a price quoted per thousand tokens and
// scales it to a price per million tokens.
func baichuanPerKTokenToPerMToken(raw string) float64 {
	// One million tokens is a thousand blocks of a thousand tokens.
	const thousandsPerMillion = 1000
	perThousand := mustParseSubscriptionPrice(raw)
	return perThousand * thousandsPerMillion
}