189 lines
5.8 KiB
Go
189 lines
5.8 KiB
Go
//go:build llm_script
|
|
|
|
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"flag"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"os"
|
|
"regexp"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
const defaultBaichuanPricingURL = "https://platform.baichuan-ai.com/prices"
|
|
|
|
// baichuanPricingImportConfig carries the command-line options of the Baichuan
// pricing import.
type baichuanPricingImportConfig struct {
	// URL is the pricing page address; Fixture is the value of the -fixture
	// flag (a sample pricing file). Both are handed to fetchRawPricingPage.
	URL, Fixture string
	// DryRun selects the parse-and-summarize mode that never touches the database.
	DryRun bool
	// Timeout is used as the HTTP client timeout.
	Timeout time.Duration
}
|
|
|
|
// baichuanPricingRow is one model entry extracted from the pricing page before
// it is turned into an officialPricingRecord.
type baichuanPricingRow struct {
	// Index is the position of the source chunk on the page; rows are ordered by it.
	Index int
	// ModelName is the model identifier captured from the page text.
	ModelName string
	// ContextLength is the parsed context window size.
	ContextLength int
	// InputPrice and OutputPrice hold the per-thousand-token page prices
	// multiplied by 1000.
	InputPrice, OutputPrice float64
}
|
|
|
|
// Patterns applied to the whitespace-normalized text of the pricing page.
var (
	// baichuanModelContextPattern captures the model name and the context size
	// token (for example "32k") that follow the "模型调用" label.
	baichuanModelContextPattern = regexp.MustCompile(`模型调用\s+(Baichuan[-A-Za-z0-9]+)\s+([0-9]+k)`)

	// baichuanPairPricePattern captures separate input and output prices, both
	// quoted in yuan per thousand tokens.
	baichuanPairPricePattern = regexp.MustCompile(`输入:([0-9.]+)元/千tokens\s+输出:([0-9.]+)元/千tokens`)

	// baichuanFlatPricePattern captures a single price that follows a
	// 00:00~24:00 or 00:00~8:00 time window.
	baichuanFlatPricePattern = regexp.MustCompile(`(?:00:00\s*~\s*24:00|00:00\s*~\s*8:00)\s+([0-9.]+)元/千tokens`)
)
|
|
|
|
func main() {
|
|
loadSubscriptionImportEnv()
|
|
|
|
var url string
|
|
var fixture string
|
|
var dryRun bool
|
|
var timeoutSeconds int
|
|
|
|
flag.StringVar(&url, "url", defaultBaichuanPricingURL, "百川官方价格页")
|
|
flag.StringVar(&fixture, "fixture", "", "百川价格样例文件")
|
|
flag.BoolVar(&dryRun, "dry-run", false, "仅解析并打印摘要,不写入数据库")
|
|
flag.IntVar(&timeoutSeconds, "timeout", 20, "请求超时(秒)")
|
|
flag.Parse()
|
|
|
|
cfg := baichuanPricingImportConfig{URL: url, Fixture: fixture, DryRun: dryRun, Timeout: time.Duration(timeoutSeconds) * time.Second}
|
|
|
|
var db *sql.DB
|
|
var err error
|
|
if !cfg.DryRun {
|
|
db, err = subscriptionImportDB()
|
|
if err != nil {
|
|
fmt.Fprintf(os.Stderr, "open db: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
defer db.Close()
|
|
}
|
|
|
|
if err := runBaichuanPricingImport(cfg, db, os.Stdout); err != nil {
|
|
fmt.Fprintf(os.Stderr, "import_baichuan_pricing: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func runBaichuanPricingImport(cfg baichuanPricingImportConfig, db *sql.DB, out io.Writer) error {
|
|
client := &http.Client{Timeout: cfg.Timeout}
|
|
raw, err := fetchRawPricingPage(cfg.URL, cfg.Fixture, client)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
records, err := parseBaichuanPricingCatalog(raw)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
records = dedupeOfficialPricingRecords(records)
|
|
if cfg.DryRun {
|
|
_, err = fmt.Fprintf(out, "source=baichuan-pricing-import models=%d operator=%s dry_run=true\n", len(records), records[0].OperatorName)
|
|
return err
|
|
}
|
|
if db == nil {
|
|
return fmt.Errorf("db is required when dry-run=false")
|
|
}
|
|
if err := upsertOfficialPricingRecords(db, records, "baichuan-pricing-import"); err != nil {
|
|
return err
|
|
}
|
|
var tableRows int
|
|
if err := db.QueryRow(`SELECT COUNT(*) FROM region_pricing`).Scan(&tableRows); err != nil {
|
|
return fmt.Errorf("count region_pricing: %w", err)
|
|
}
|
|
_, err = fmt.Fprintf(out, "source=baichuan-pricing-import models=%d operator=%s table_rows=%d dry_run=false\n", len(records), records[0].OperatorName, tableRows)
|
|
return err
|
|
}
|
|
|
|
func parseBaichuanPricingCatalog(raw string) ([]officialPricingRecord, error) {
|
|
text := cleanHTMLText(raw)
|
|
text = strings.ReplaceAll(text, "\n", " ")
|
|
text = regexp.MustCompile(`\s+`).ReplaceAllString(text, " ")
|
|
text = strings.TrimSpace(text)
|
|
|
|
sectionStart := strings.Index(text, "通用大模型")
|
|
if sectionStart == -1 {
|
|
return nil, fmt.Errorf("unexpected baichuan pricing content: missing 通用大模型")
|
|
}
|
|
text = text[sectionStart:]
|
|
sectionEnd := strings.Index(text, "搜索增强服务")
|
|
if sectionEnd == -1 {
|
|
return nil, fmt.Errorf("unexpected baichuan pricing content: missing 搜索增强服务")
|
|
}
|
|
section := text[:sectionEnd]
|
|
|
|
chunks := strings.Split(section, "模型调用 ")
|
|
rows := make([]baichuanPricingRow, 0, len(chunks))
|
|
for idx, chunk := range chunks {
|
|
chunk = strings.TrimSpace(chunk)
|
|
if chunk == "" {
|
|
continue
|
|
}
|
|
chunk = "模型调用 " + chunk
|
|
if strings.Contains(chunk, "Baichuan-Text-Embedding") {
|
|
continue
|
|
}
|
|
meta := baichuanModelContextPattern.FindStringSubmatch(chunk)
|
|
if len(meta) != 3 {
|
|
continue
|
|
}
|
|
modelName := strings.TrimSpace(meta[1])
|
|
contextLength := parseContextLengthCommon(meta[2])
|
|
if contextLength == 0 {
|
|
continue
|
|
}
|
|
row := baichuanPricingRow{Index: idx, ModelName: modelName, ContextLength: contextLength}
|
|
if pair := baichuanPairPricePattern.FindStringSubmatch(chunk); len(pair) == 3 {
|
|
row.InputPrice = baichuanPerKTokenToPerMToken(pair[1])
|
|
row.OutputPrice = baichuanPerKTokenToPerMToken(pair[2])
|
|
} else if flat := baichuanFlatPricePattern.FindStringSubmatch(chunk); len(flat) == 2 {
|
|
price := baichuanPerKTokenToPerMToken(flat[1])
|
|
row.InputPrice = price
|
|
row.OutputPrice = price
|
|
} else {
|
|
continue
|
|
}
|
|
rows = append(rows, row)
|
|
}
|
|
if len(rows) == 0 {
|
|
return nil, fmt.Errorf("unexpected baichuan pricing content: no model rows parsed")
|
|
}
|
|
sort.Slice(rows, func(i, j int) bool { return rows[i].Index < rows[j].Index })
|
|
|
|
providerNameCn, providerCountry, providerWebsite := providerMetadata("Baichuan")
|
|
records := make([]officialPricingRecord, 0, len(rows))
|
|
for _, row := range rows {
|
|
records = append(records, officialPricingRecord{
|
|
ModelID: normalizeExternalID("baichuan", row.ModelName),
|
|
ModelName: row.ModelName,
|
|
ProviderName: "Baichuan",
|
|
ProviderNameCn: providerNameCn,
|
|
ProviderCountry: providerCountry,
|
|
ProviderWebsite: providerWebsite,
|
|
OperatorName: "Baichuan API",
|
|
OperatorNameCn: "百川开放平台",
|
|
OperatorCountry: "CN",
|
|
OperatorWebsite: "https://platform.baichuan-ai.com/docs",
|
|
OperatorType: "official",
|
|
Region: "CN",
|
|
Currency: "CNY",
|
|
InputPrice: row.InputPrice,
|
|
OutputPrice: row.OutputPrice,
|
|
ContextLength: row.ContextLength,
|
|
SourceURL: defaultBaichuanPricingURL,
|
|
ModelSourceURL: defaultBaichuanPricingURL,
|
|
DateConfidence: "unknown",
|
|
DateSourceKind: "official_pricing",
|
|
Modality: detectModality(row.ModelName),
|
|
})
|
|
}
|
|
return records, nil
|
|
}
|
|
|
|
func baichuanPerKTokenToPerMToken(raw string) float64 {
|
|
return mustParseSubscriptionPrice(raw) * 1000
|
|
}
|