Files
llm-intelligence/scripts/official_pricing_import_common.go

565 lines
17 KiB
Go

//go:build llm_script
package main
import (
"database/sql"
"fmt"
"html"
"io"
"net/http"
"os"
"regexp"
"strings"
"time"
)
// officialPricingFetchMaxAttempts is the total number of HTTP attempts (the
// first request plus retries) made for one pricing page before giving up.
const (
	officialPricingFetchMaxAttempts = 3
)
// officialPricingFetchOptions customizes the request built by
// fetchRawPricingPageOnce. A blank AcceptLanguage means no Accept-Language
// header is sent.
type officialPricingFetchOptions struct{ AcceptLanguage string }
// officialPricingRecord is one normalized price row produced by an official
// pricing importer and consumed by upsertOfficialPricingRecords.
type officialPricingRecord struct {
	// Model identity. ModelID is matched against models.external_id.
	ModelID, ModelName string

	// Model developer, stored in model_provider (looked up by ProviderName).
	ProviderName, ProviderNameCn, ProviderCountry, ProviderWebsite string

	// Seller of the price, stored in operator (looked up by OperatorName).
	// OperatorType "official" yields source_type "official"; other values
	// yield "reseller" unless IsFree is set.
	OperatorName, OperatorNameCn, OperatorCountry, OperatorWebsite, OperatorType string

	// Scope of the region_pricing row. Blank PricingMode and PriceUnit fall
	// back to "input_output" and "million_tokens".
	Region, Currency, PricingMode, PriceUnit string

	// InputPrice/OutputPrice are written to input_price_per_mtok and
	// output_price_per_mtok; a zero FlatPrice is written as NULL.
	FlatPrice, InputPrice, OutputPrice float64

	ContextLength int
	IsFree        bool

	// Provenance. ModelSourceURL is preferred over SourceURL for
	// models.source_url; ReleaseDate is parsed as YYYY-MM-DD.
	SourceURL, ModelSourceURL, ReleaseDate string

	// Model metadata; when blank they fall back to "unknown",
	// "official_product_page" and "text" respectively.
	DateConfidence, DateSourceKind, Modality string
}
// upsertOfficialPricingRecords persists scraped official pricing rows.
//
// records are first deduplicated with dedupeOfficialPricingRecords (keyed by
// operator, model ID, region and currency). For each remaining record:
//   - the provider, operator and model rows are looked up or created;
//   - a region_pricing row with effective_date = CURRENT_DATE is inserted,
//     or overwritten when one already exists for the same model, operator,
//     region, currency and date.
//
// A blank batchID is replaced by a timestamp-based ID.
//
// Statements are issued one at a time on db without a transaction, so an
// error part-way through leaves the earlier records written. Returned errors
// identify the failing record by its ModelID and wrap the underlying error.
func upsertOfficialPricingRecords(db *sql.DB, records []officialPricingRecord, batchID string) error {
	records = dedupeOfficialPricingRecords(records)
	if len(records) == 0 {
		return fmt.Errorf("official pricing records are empty")
	}
	if strings.TrimSpace(batchID) == "" {
		batchID = fmt.Sprintf("official-pricing-%s", time.Now().Format("20060102-150405"))
	}
	for _, record := range records {
		providerID, err := ensureOfficialPricingProvider(db, record)
		if err != nil {
			return fmt.Errorf("ensure provider %q for %s: %w", record.ProviderName, record.ModelID, err)
		}
		operatorID, err := ensureOfficialPricingOperator(db, record)
		if err != nil {
			return fmt.Errorf("ensure operator %q for %s: %w", record.OperatorName, record.ModelID, err)
		}
		modelID, err := ensureOfficialPricingModel(db, record, providerID, batchID)
		if err != nil {
			return fmt.Errorf("ensure model %s: %w", record.ModelID, err)
		}

		sourceType := officialPricingSourceType(record.OperatorType, record.IsFree)
		// Free rows get placeholder quota text pointing at the source page;
		// paid rows pass a blank quota (through nullIfBlank) and an empty
		// limitations list.
		freeQuota := ""
		freeLimitations := "[]"
		rateLimit := "{}"
		if record.IsFree {
			freeQuota = "See source_url for provider free-tier details"
			freeLimitations = `["See source_url for current quota and policy"]`
		}
		_, err = db.Exec(
			`INSERT INTO region_pricing (
				model_id, operator_id, region, currency,
				pricing_mode, price_unit, flat_price,
				input_price_per_mtok, output_price_per_mtok,
				is_free, effective_date, source_url, source_type,
				free_quota, free_limitations, rate_limit
			) VALUES (
				$1, $2, $3, $4,
				$5, $6, $7,
				$8, $9, $10, CURRENT_DATE, $11, $12,
				$13, $14, $15
			)
			ON CONFLICT (model_id, operator_id, region, currency, effective_date)
			DO UPDATE SET
				pricing_mode = EXCLUDED.pricing_mode,
				price_unit = EXCLUDED.price_unit,
				flat_price = EXCLUDED.flat_price,
				input_price_per_mtok = EXCLUDED.input_price_per_mtok,
				output_price_per_mtok = EXCLUDED.output_price_per_mtok,
				is_free = EXCLUDED.is_free,
				source_url = EXCLUDED.source_url,
				source_type = EXCLUDED.source_type,
				free_quota = EXCLUDED.free_quota,
				free_limitations = EXCLUDED.free_limitations,
				rate_limit = EXCLUDED.rate_limit,
				updated_at = CURRENT_TIMESTAMP`,
			modelID, operatorID, record.Region, record.Currency,
			fallbackPricingMode(record.PricingMode), fallbackPriceUnit(record.PriceUnit), nullIfZeroFloat(record.FlatPrice),
			record.InputPrice, record.OutputPrice, record.IsFree, record.SourceURL, sourceType,
			nullIfBlank(freeQuota), freeLimitations, rateLimit,
		)
		if err != nil {
			return fmt.Errorf("upsert region_pricing %s: %w", record.ModelID, err)
		}
	}
	return nil
}
// ensureOfficialPricingProvider returns the model_provider.id whose name
// equals record.ProviderName, inserting an 'active' provider row when none
// exists.
//
// For an existing row only missing metadata is filled in: name_cn and
// website take the record's values when currently NULL or empty, and
// updated_at is bumped. Country is only written on insert. On error the
// returned ID is 0.
func ensureOfficialPricingProvider(db *sql.DB, record officialPricingRecord) (int64, error) {
	var providerID int64
	err := db.QueryRow(`SELECT id FROM model_provider WHERE name = $1`, record.ProviderName).Scan(&providerID)
	switch {
	case err == nil:
		// NULLIF on both columns so an empty string is treated like NULL,
		// matching how website was already backfilled.
		_, err = db.Exec(
			`UPDATE model_provider
			SET name_cn = COALESCE(NULLIF(name_cn, ''), $2),
				website = COALESCE(NULLIF(website, ''), $3),
				updated_at = CURRENT_TIMESTAMP
			WHERE id = $1`,
			providerID, nullIfBlank(record.ProviderNameCn), nullIfBlank(record.ProviderWebsite),
		)
		if err != nil {
			return 0, fmt.Errorf("update model_provider %q: %w", record.ProviderName, err)
		}
		return providerID, nil
	case err != sql.ErrNoRows:
		return 0, fmt.Errorf("lookup model_provider %q: %w", record.ProviderName, err)
	}

	err = db.QueryRow(
		`INSERT INTO model_provider (name, name_cn, country, website, status)
		VALUES ($1, $2, $3, $4, 'active')
		RETURNING id`,
		record.ProviderName, nullIfBlank(record.ProviderNameCn), record.ProviderCountry, nullIfBlank(record.ProviderWebsite),
	).Scan(&providerID)
	if err != nil {
		return 0, fmt.Errorf("insert model_provider %q: %w", record.ProviderName, err)
	}
	return providerID, nil
}
// ensureOfficialPricingOperator returns the operator.id whose name equals
// record.OperatorName, inserting an 'active' operator row when none exists.
//
// For an existing row only missing metadata is filled in: name_cn, website
// and type take the record's values when currently NULL or empty, and
// updated_at is bumped. Country and description are only written on insert.
// On error the returned ID is 0.
func ensureOfficialPricingOperator(db *sql.DB, record officialPricingRecord) (int64, error) {
	var operatorID int64
	err := db.QueryRow(`SELECT id FROM operator WHERE name = $1`, record.OperatorName).Scan(&operatorID)
	switch {
	case err == nil:
		// NULLIF on every column so an empty string is treated like NULL,
		// consistently for name_cn, website and type.
		_, err = db.Exec(
			`UPDATE operator
			SET name_cn = COALESCE(NULLIF(name_cn, ''), $2),
				website = COALESCE(NULLIF(website, ''), $3),
				type = COALESCE(NULLIF(type, ''), $4),
				updated_at = CURRENT_TIMESTAMP
			WHERE id = $1`,
			operatorID, nullIfBlank(record.OperatorNameCn), nullIfBlank(record.OperatorWebsite), nullIfBlank(record.OperatorType),
		)
		if err != nil {
			return 0, fmt.Errorf("update operator %q: %w", record.OperatorName, err)
		}
		return operatorID, nil
	case err != sql.ErrNoRows:
		return 0, fmt.Errorf("lookup operator %q: %w", record.OperatorName, err)
	}

	err = db.QueryRow(
		`INSERT INTO operator (name, name_cn, country, website, description, status, type)
		VALUES ($1, $2, $3, $4, $5, 'active', $6)
		RETURNING id`,
		record.OperatorName, nullIfBlank(record.OperatorNameCn), record.OperatorCountry, nullIfBlank(record.OperatorWebsite),
		fmt.Sprintf("%s official pricing import", record.OperatorName), record.OperatorType,
	).Scan(&operatorID)
	if err != nil {
		return 0, fmt.Errorf("insert operator %q: %w", record.OperatorName, err)
	}
	return operatorID, nil
}
// ensureOfficialPricingModel returns models.id for record.ModelID (matched on
// external_id), inserting an 'active' model row when absent.
//
// On insert, blank Modality/DateConfidence/DateSourceKind fall back to
// "text"/"unknown"/"official_product_page", and source is set to the operator
// name.
//
// For an existing model:
//   - name, provider_id, source and batch_id are always overwritten;
//   - modality is overwritten only when the record states a non-blank value;
//   - context_length is overwritten only when nullIfZeroIntCommon yields a
//     non-NULL value (presumably a non-zero ContextLength);
//   - source_url, release_date, date_confidence and date_source_kind are
//     filled only when the stored value is NULL (or empty, for the text
//     columns).
//
// On error the returned ID is 0.
func ensureOfficialPricingModel(db *sql.DB, record officialPricingRecord, providerID int64, batchID string) (int64, error) {
	sourceURL := firstNonEmptyText(record.ModelSourceURL, record.SourceURL)
	releaseDate := releaseDateValueCommon(record.ReleaseDate)
	dateConfidence := fallbackDateConfidence(record.DateConfidence)
	dateSourceKind := fallbackDateSourceKind(record.DateSourceKind)
	contextLength := nullIfZeroIntCommon(record.ContextLength)

	var modelID int64
	err := db.QueryRow(`SELECT id FROM models WHERE external_id = $1`, record.ModelID).Scan(&modelID)
	if err == sql.ErrNoRows {
		err = db.QueryRow(
			`INSERT INTO models (
				external_id, name, provider_id, modality, context_length,
				status, source, batch_id, source_url, release_date,
				date_confidence, date_source_kind
			) VALUES (
				$1, $2, $3, $4, $5,
				'active', $6, $7, $8, $9,
				$10, $11
			) RETURNING id`,
			record.ModelID, record.ModelName, providerID, fallbackModality(record.Modality), contextLength,
			record.OperatorName, batchID, sourceURL, releaseDate,
			dateConfidence, dateSourceKind,
		).Scan(&modelID)
		if err != nil {
			return 0, fmt.Errorf("insert model %q: %w", record.ModelID, err)
		}
		return modelID, nil
	}
	if err != nil {
		return 0, fmt.Errorf("lookup model %q: %w", record.ModelID, err)
	}

	// Pass NULL (keeping the stored modality) unless the record states one.
	// Passing the "text" fallback here would make COALESCE($4, modality)
	// always pick "text" and clobber a previously recorded modality.
	var modality any
	if value := strings.TrimSpace(record.Modality); value != "" {
		modality = value
	}
	_, err = db.Exec(
		`UPDATE models
		SET name = $2,
			provider_id = $3,
			modality = COALESCE($4, modality),
			context_length = COALESCE($5, context_length),
			source = $6,
			batch_id = $7,
			source_url = COALESCE(NULLIF(source_url, ''), $8),
			release_date = COALESCE(release_date, $9),
			date_confidence = COALESCE(NULLIF(date_confidence, ''), $10),
			date_source_kind = COALESCE(NULLIF(date_source_kind, ''), $11),
			updated_at = CURRENT_TIMESTAMP
		WHERE id = $1`,
		modelID, record.ModelName, providerID, modality, contextLength,
		record.OperatorName, batchID, sourceURL, releaseDate,
		dateConfidence, dateSourceKind,
	)
	if err != nil {
		return 0, fmt.Errorf("update model %q: %w", record.ModelID, err)
	}
	return modelID, nil
}
// officialPricingSourceType picks region_pricing.source_type for a row.
// Free rows are always "free_tier". Otherwise an operatorType of "official"
// (compared after trimming whitespace and lowercasing) gives "official", and
// anything else gives "reseller".
func officialPricingSourceType(operatorType string, isFree bool) string {
	if isFree {
		return "free_tier"
	}
	if normalized := strings.ToLower(strings.TrimSpace(operatorType)); normalized == "official" {
		return "official"
	}
	return "reseller"
}
// releaseDateValueCommon converts a YYYY-MM-DD string into a value for a
// nullable DATE parameter. It returns a time.Time (midnight UTC, as produced
// by time.Parse without a zone) when raw parses, and nil otherwise.
// Surrounding whitespace is ignored.
func releaseDateValueCommon(raw string) any {
	value := strings.TrimSpace(raw)
	if value == "" {
		return nil
	}
	parsed, err := time.Parse("2006-01-02", value)
	if err != nil {
		// Best effort: an unrecognized format is stored as NULL rather than
		// aborting the import.
		return nil
	}
	return parsed
}
// fallbackDateConfidence returns raw with surrounding whitespace removed, or
// "unknown" when nothing is left. Trimming matches fallbackModality and the
// other fallback helpers, so padded values are not stored verbatim.
func fallbackDateConfidence(raw string) string {
	value := strings.TrimSpace(raw)
	if value == "" {
		return "unknown"
	}
	return value
}
// fallbackDateSourceKind returns raw with surrounding whitespace removed, or
// "official_product_page" when nothing is left. Trimming matches the other
// fallback helpers, so padded values are not stored verbatim.
func fallbackDateSourceKind(raw string) string {
	value := strings.TrimSpace(raw)
	if value == "" {
		return "official_product_page"
	}
	return value
}
// fallbackModality returns raw trimmed of surrounding whitespace, defaulting
// to "text" when the trimmed value is empty.
func fallbackModality(raw string) string {
	if trimmed := strings.TrimSpace(raw); trimmed != "" {
		return trimmed
	}
	return "text"
}
// fallbackPricingMode returns raw trimmed of surrounding whitespace,
// defaulting to "input_output" when the trimmed value is empty.
func fallbackPricingMode(raw string) string {
	if trimmed := strings.TrimSpace(raw); trimmed != "" {
		return trimmed
	}
	return "input_output"
}
// fallbackPriceUnit returns raw trimmed of surrounding whitespace, defaulting
// to "million_tokens" when the trimmed value is empty.
func fallbackPriceUnit(raw string) string {
	if trimmed := strings.TrimSpace(raw); trimmed != "" {
		return trimmed
	}
	return "million_tokens"
}
// nullIfZeroFloat turns 0 (including -0) into a nil query argument so the
// column is written as NULL; any other value is passed through as float64.
func nullIfZeroFloat(value float64) any {
	if value != 0 {
		return value
	}
	return nil
}
func fetchRawPricingPage(url string, fixture string, client *http.Client) (string, error) {
return fetchRawPricingPageWithOptions(url, fixture, client, officialPricingFetchOptions{
AcceptLanguage: "zh-CN,zh;q=0.9,en;q=0.8",
})
}
// fetchRawPricingPageWithOptions returns the raw body of a pricing page.
//
// When fixture is non-empty, the file at that path is returned verbatim and
// no request is made. Otherwise url is fetched with client (or, when nil, a
// client with a 20s timeout). The fetch is tried up to
// officialPricingFetchMaxAttempts times, sleeping attempt*200ms after each
// retryable failure. Non-retryable failures and the final failure are
// returned as-is.
func fetchRawPricingPageWithOptions(url string, fixture string, client *http.Client, opts officialPricingFetchOptions) (string, error) {
	if fixture != "" {
		content, readErr := os.ReadFile(fixture)
		if readErr != nil {
			return "", fmt.Errorf("read fixture %s: %w", fixture, readErr)
		}
		return string(content), nil
	}

	httpClient := client
	if httpClient == nil {
		httpClient = &http.Client{Timeout: 20 * time.Second}
	}

	var lastErr error
	for attempt := 1; attempt <= officialPricingFetchMaxAttempts; attempt++ {
		body, retryable, fetchErr := fetchRawPricingPageOnce(url, httpClient, opts)
		switch {
		case fetchErr == nil:
			return body, nil
		case !retryable, attempt == officialPricingFetchMaxAttempts:
			return "", fetchErr
		}
		lastErr = fetchErr
		// Linear backoff: 200ms, 400ms, ...
		time.Sleep(time.Duration(attempt) * 200 * time.Millisecond)
	}
	return "", lastErr
}
// fetchRawPricingPageOnce performs a single GET of url with browser-like
// headers and returns the response body.
//
// The boolean result reports whether a failure is worth retrying:
//   - transport and body-read errors are classified by
//     isRetriablePricingFetchError;
//   - HTTP 429, 502, 503 and 504 are retryable;
//   - any other non-2xx status is not retryable.
//
// On success the boolean is false.
func fetchRawPricingPageOnce(url string, client *http.Client, opts officialPricingFetchOptions) (string, bool, error) {
	req, err := http.NewRequest(http.MethodGet, url, nil)
	if err != nil {
		return "", false, fmt.Errorf("build request %s: %w", url, err)
	}
	req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/136.0.0.0 Safari/537.36")
	req.Header.Set("Accept", "text/html,application/xhtml+xml,application/json,text/plain;q=0.9,*/*;q=0.8")
	if strings.TrimSpace(opts.AcceptLanguage) != "" {
		req.Header.Set("Accept-Language", opts.AcceptLanguage)
	}

	resp, err := client.Do(req)
	if err != nil {
		return "", isRetriablePricingFetchError(err), fmt.Errorf("fetch %s: %w", url, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		// Drain a bounded amount of the error body so the Transport can reuse
		// the keep-alive connection for the retry. A failure here only costs
		// connection reuse, so it is deliberately ignored.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 64<<10))
		retryable := resp.StatusCode == http.StatusTooManyRequests ||
			resp.StatusCode == http.StatusBadGateway ||
			resp.StatusCode == http.StatusServiceUnavailable ||
			resp.StatusCode == http.StatusGatewayTimeout
		return "", retryable, fmt.Errorf("fetch %s: unexpected status %d", url, resp.StatusCode)
	}

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return "", isRetriablePricingFetchError(err), fmt.Errorf("read %s: %w", url, err)
	}
	return string(body), false, nil
}
// isRetriablePricingFetchError reports whether err looks transient and is
// therefore worth another fetch attempt. Transient means EOF, timeouts,
// resets, refused connections, DNS lookup failures or rate limiting.
//
// Classification is a substring match against the lowercased err.Error(), so
// it is effectively case-insensitive. A nil error is never retriable.
func isRetriablePricingFetchError(err error) bool {
	if err == nil {
		return false
	}
	message := strings.ToLower(err.Error())
	transientMarkers := [...]string{
		"eof",
		"timeout",
		"temporarily unavailable",
		"transport closed",
		"connection reset",
		"connection refused",
		"tls handshake timeout",
		"i/o timeout",
		"too many requests",
		"no such host",
	}
	for i := range transientMarkers {
		if strings.Contains(message, transientMarkers[i]) {
			return true
		}
	}
	return false
}
var (
	// cleanHTMLTagPattern matches one tag-like `<...>` span, across newlines.
	cleanHTMLTagPattern = regexp.MustCompile(`(?is)<[^>]+>`)
	// cleanHTMLSpacePattern matches runs of spaces and tabs (not newlines).
	cleanHTMLSpacePattern = regexp.MustCompile(`[ \t]+`)
)

// cleanHTMLText converts an HTML fragment into plain text. It applies these
// steps in order:
//  1. replace tags with a space;
//  2. decode entities;
//  3. turn CR/CRLF line endings into "\n" and non-breaking spaces into
//     ordinary spaces;
//  4. collapse runs of spaces/tabs to one space;
//  5. trim the result.
//
// Newlines are preserved.
//
// Tags are stripped before entities are decoded. Otherwise escaped text such
// as "&lt;= 32K ... &gt;" would decode to "<...>" and be deleted as markup,
// and "&gt;" inside an attribute value would cut a tag short.
func cleanHTMLText(raw string) string {
	text := cleanHTMLTagPattern.ReplaceAllString(raw, " ")
	text = html.UnescapeString(text)
	text = strings.ReplaceAll(text, "\r\n", "\n")
	text = strings.ReplaceAll(text, "\r", "\n")
	text = strings.ReplaceAll(text, "\u00a0", " ")
	text = cleanHTMLSpacePattern.ReplaceAllString(text, " ")
	return strings.TrimSpace(text)
}
// firstDollarPricePattern matches a decimal number with an optional leading
// "$" and an optional single space, capturing the numeric part.
var firstDollarPricePattern = regexp.MustCompile(`\$? ?([0-9]+(?:\.[0-9]+)?)`)

// firstDollarPrice returns the first number found in raw, converted with
// mustParseSubscriptionPrice, together with true. It returns 0 and false when
// raw contains no digits. Because the "$" is optional in the pattern, a bare
// number such as "128" also matches.
func firstDollarPrice(raw string) (float64, bool) {
	match := firstDollarPricePattern.FindStringSubmatch(raw)
	if match == nil {
		return 0, false
	}
	return mustParseSubscriptionPrice(match[1]), true
}
// externalIDSeparatorPattern matches each run of characters that is not an
// ASCII lowercase letter or digit.
var externalIDSeparatorPattern = regexp.MustCompile(`[^a-z0-9]+`)

// normalizeExternalID builds a slug-style identifier from parts:
//   - the parts are joined with "-" and lowercased;
//   - every run of characters outside [a-z0-9] becomes a single "-";
//   - leading and trailing dashes are removed.
//
// Non-ASCII characters (e.g. Chinese model names) are dropped entirely, so a
// part made only of such characters contributes nothing to the ID.
func normalizeExternalID(parts ...string) string {
	joined := strings.ToLower(strings.Join(parts, "-"))
	// One replacement already collapses adjacent separators into a single
	// "-", so no separate "-+" pass is needed.
	return strings.Trim(externalIDSeparatorPattern.ReplaceAllString(joined, "-"), "-")
}
// parseContextLengthCommon parses a context-window size such as "128K",
// "1M" or "32,768" into a token count. The input is handled as follows:
//   - commas are removed, the text is uppercased and trimmed, and blank
//     input yields 0;
//   - a trailing "M" or "K" scales the rest via parseDecimalMultiplier;
//   - anything else goes to mustParseSubscriptionInt64.
//
// Handling of malformed input is defined by those helpers.
func parseContextLengthCommon(raw string) int {
	value := strings.ToUpper(strings.ReplaceAll(raw, ",", ""))
	value = strings.TrimSpace(value)
	if value == "" {
		return 0
	}
	if strings.HasSuffix(value, "M") {
		return int(parseDecimalMultiplier(strings.TrimSuffix(value, "M"), 1000000))
	}
	if strings.HasSuffix(value, "K") {
		return int(parseDecimalMultiplier(strings.TrimSuffix(value, "K"), 1000))
	}
	return int(mustParseSubscriptionInt64(value))
}
// detectModality guesses a model's modality from substrings of its lowercased
// name. Rules are checked in order (code, then audio, then multimodal) and
// the first hit wins; names matching nothing are "text". The markers are
// plain substrings, so e.g. "vl" can also match inside unrelated words.
func detectModality(modelName string) string {
	name := strings.ToLower(modelName)
	rules := []struct {
		modality string
		markers  []string
	}{
		{"code", []string{"coder", "code"}},
		{"audio", []string{"voice", "audio", "speech"}},
		{"multimodal", []string{"vision", "vl", "omni", "multi", "live"}},
	}
	for _, rule := range rules {
		for _, marker := range rule.markers {
			if strings.Contains(name, marker) {
				return rule.modality
			}
		}
	}
	return "text"
}
// providerMetadataByName holds the static metadata providerMetadata knows for
// each canonical provider name: Chinese display name, country code, website.
var providerMetadataByName = map[string]struct{ nameCn, country, website string }{
	"Alibaba":      {"阿里云", "CN", "https://tongyi.aliyun.com"},
	"Qwen":         {"阿里云", "CN", "https://tongyi.aliyun.com"},
	"360":          {"360", "CN", "https://ai.360.com/open/models"},
	"Amazon":       {"亚马逊", "US", "https://aws.amazon.com"},
	"Anthropic":    {"Anthropic", "US", "https://www.anthropic.com"},
	"BAAI":         {"智源", "CN", "https://www.baai.ac.cn"},
	"Baidu":        {"百度", "CN", "https://cloud.baidu.com"},
	"Baichuan":     {"百川智能", "CN", "https://platform.baichuan-ai.com"},
	"ByteDance":    {"字节跳动", "CN", "https://www.volcengine.com"},
	"China Mobile": {"中国移动", "CN", "https://ecloud.10086.cn"},
	"Cloudflare":   {"Cloudflare", "US", "https://www.cloudflare.com"},
	"Cohere":       {"Cohere", "CA", "https://cohere.com"},
	"DeepSeek":     {"深度求索", "CN", "https://www.deepseek.com"},
	"Google":       {"谷歌", "US", "https://ai.google.dev"},
	"Meta":         {"Meta", "US", "https://about.meta.com"},
	"MiniMax":      {"MiniMax", "CN", "https://www.minimax.io"},
	"Mistral AI":   {"Mistral AI", "FR", "https://mistral.ai"},
	"Moonshot AI":  {"月之暗面", "CN", "https://www.moonshot.cn"},
	"NVIDIA":       {"NVIDIA", "US", "https://build.nvidia.com"},
	"OpenAI":       {"OpenAI", "US", "https://openai.com"},
	"Perplexity":   {"Perplexity", "US", "https://www.perplexity.ai"},
	"SenseTime":    {"商汤科技", "CN", "https://www.sensetime.com"},
	"Tencent":      {"腾讯", "CN", "https://cloud.tencent.com"},
	"Huawei":       {"华为", "CN", "https://www.huaweicloud.com"},
	"iFlytek":      {"科大讯飞", "CN", "https://www.xfyun.cn"},
	"Yi":           {"零一万物", "CN", "https://platform.lingyiwanwu.com"},
	"xAI":          {"xAI", "US", "https://x.ai"},
	"Xiaomi":       {"小米", "CN", "https://xiaomi.com"},
	"Zhipu AI":     {"智谱", "CN", "https://open.bigmodel.cn"},
}

// providerMetadata returns the Chinese name, country code and website for a
// canonical provider name (exact, case-sensitive match). Unknown providers
// yield "", "unknown", "".
func providerMetadata(providerName string) (string, string, string) {
	if meta, ok := providerMetadataByName[providerName]; ok {
		return meta.nameCn, meta.country, meta.website
	}
	return "", "unknown", ""
}
// providerFromModelPath infers the canonical provider from the organization
// prefix of a model path such as "anthropic/claude-3" (case-insensitive).
// Prefixes are checked in order and the first match wins. Note that a bare
// "deepseek" prefix (no slash) also matches. Unrecognized paths return
// "unknown".
func providerFromModelPath(modelName string) string {
	path := strings.ToLower(modelName)
	rules := []struct {
		prefix, provider string
	}{
		{"amazon/", "Amazon"},
		{"anthropic/", "Anthropic"},
		{"cohere/", "Cohere"},
		{"qwen/", "Qwen"},
		{"deepseek", "DeepSeek"},
		{"deepseek-ai/", "DeepSeek"},
		{"google/", "Google"},
		{"gemini/", "Google"},
		{"meta/", "Meta"},
		{"mistral/", "Mistral AI"},
		{"mistralai/", "Mistral AI"},
		{"moonshotai/", "Moonshot AI"},
		{"minimaxai/", "MiniMax"},
		{"nvidia/", "NVIDIA"},
		{"perplexity/", "Perplexity"},
		{"zai-org/", "Zhipu AI"},
		{"glm/", "Zhipu AI"},
		{"openai/", "OpenAI"},
		{"xai/", "xAI"},
	}
	for _, rule := range rules {
		if strings.HasPrefix(path, rule.prefix) {
			return rule.provider
		}
	}
	return "unknown"
}
// dedupeOfficialPricingRecords collapses records sharing the same operator
// name, model ID, region and currency. Each key keeps the position of its
// first occurrence but the contents of its last occurrence. The result is a
// new, non-nil slice; the input is not modified.
func dedupeOfficialPricingRecords(records []officialPricingRecord) []officialPricingRecord {
	positionByKey := make(map[string]int, len(records))
	unique := make([]officialPricingRecord, 0, len(records))
	for _, rec := range records {
		key := strings.Join([]string{rec.OperatorName, rec.ModelID, rec.Region, rec.Currency}, "|")
		if pos, seen := positionByKey[key]; seen {
			// Later duplicates overwrite in place, preserving first-seen order.
			unique[pos] = rec
			continue
		}
		positionByKey[key] = len(unique)
		unique = append(unique, rec)
	}
	return unique
}