Files
sub2api-cn-relay-manager/internal/store/sqlite/providers_repo.go
2026-05-27 20:23:42 +08:00

350 lines
11 KiB
Go

package sqlite
import (
"context"
"fmt"
"strings"
)
// Provider is one row of the providers table. The field order mirrors the
// column order used by the SELECT statements in this file (the INSERT
// statements use the same order minus id).
//
// The *JSON fields carry JSON text that is stored as-is; the code in this file
// does not parse it. On Create/Upsert a blank DefaultModelsJSON is stored as
// "[]" and the other blank *JSON fields are stored as "{}".
type Provider struct {
	// ID is the providers.id row identifier; it is ignored on insert.
	ID int64
	// PackID is providers.pack_id; Create/Upsert require it to be > 0.
	PackID int64
	// ProviderID is providers.provider_id. Together with PackID it forms the
	// ON CONFLICT target used by Upsert. Trimmed and required on write.
	ProviderID string
	// DisplayName is providers.display_name. Trimmed and required on write.
	DisplayName string
	// BaseURL is providers.base_url. Trimmed and required on write; also a
	// lookup key for ListByBaseURL.
	BaseURL string
	// Platform is providers.platform. Trimmed and required on write.
	Platform string
	// AccountType is providers.account_type. Trimmed on write; may be empty.
	AccountType string
	// DefaultModelsJSON is providers.default_models_json; blank becomes "[]".
	DefaultModelsJSON string
	// SmokeTestModel is providers.smoke_test_model. Trimmed on write; may be
	// empty.
	SmokeTestModel string
	// GroupTemplateJSON is providers.group_template_json; blank becomes "{}".
	GroupTemplateJSON string
	// ChannelTemplateJSON is providers.channel_template_json; blank becomes
	// "{}".
	ChannelTemplateJSON string
	// PlanTemplateJSON is providers.plan_template_json; blank becomes "{}".
	PlanTemplateJSON string
	// ImportOptionsJSON is providers.import_options_json; blank becomes "{}".
	ImportOptionsJSON string
	// ManifestJSON is providers.manifest_json. Trimmed on write; blank
	// becomes "{}".
	ManifestJSON string
}
// ProvidersRepo reads and writes rows of the providers table through db.
type ProvidersRepo struct {
	// db executes the repository's SQL. NOTE(review): presumably a *sql.DB or
	// *sql.Tx satisfying execQuerier — confirm against its declaration.
	db execQuerier
}

// newProvidersRepo builds a ProvidersRepo that issues its statements on db.
func newProvidersRepo(db execQuerier) *ProvidersRepo {
	repo := new(ProvidersRepo)
	repo.db = db
	return repo
}
func (r *ProvidersRepo) ListByPackID(ctx context.Context, packID int64) ([]Provider, error) {
if packID <= 0 {
return nil, fmt.Errorf("pack_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE pack_id = ? ORDER BY id`, packID)
if err != nil {
return nil, fmt.Errorf("query providers by pack_id %d: %w", packID, err)
}
defer rows.Close()
providers := make([]Provider, 0)
for rows.Next() {
var provider Provider
if err := rows.Scan(
&provider.ID,
&provider.PackID,
&provider.ProviderID,
&provider.DisplayName,
&provider.BaseURL,
&provider.Platform,
&provider.AccountType,
&provider.DefaultModelsJSON,
&provider.SmokeTestModel,
&provider.GroupTemplateJSON,
&provider.ChannelTemplateJSON,
&provider.PlanTemplateJSON,
&provider.ImportOptionsJSON,
&provider.ManifestJSON,
); err != nil {
return nil, fmt.Errorf("scan provider: %w", err)
}
providers = append(providers, provider)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate providers: %w", err)
}
return providers, nil
}
func (r *ProvidersRepo) ListByProviderID(ctx context.Context, providerID string) ([]Provider, error) {
providerID = strings.TrimSpace(providerID)
if providerID == "" {
return nil, fmt.Errorf("provider_id is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE provider_id = ? ORDER BY id`, providerID)
if err != nil {
return nil, fmt.Errorf("query providers by provider_id %q: %w", providerID, err)
}
defer rows.Close()
providers := make([]Provider, 0)
for rows.Next() {
var provider Provider
if err := rows.Scan(
&provider.ID,
&provider.PackID,
&provider.ProviderID,
&provider.DisplayName,
&provider.BaseURL,
&provider.Platform,
&provider.AccountType,
&provider.DefaultModelsJSON,
&provider.SmokeTestModel,
&provider.GroupTemplateJSON,
&provider.ChannelTemplateJSON,
&provider.PlanTemplateJSON,
&provider.ImportOptionsJSON,
&provider.ManifestJSON,
); err != nil {
return nil, fmt.Errorf("scan provider by provider_id %q: %w", providerID, err)
}
providers = append(providers, provider)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate providers by provider_id %q: %w", providerID, err)
}
return providers, nil
}
func (r *ProvidersRepo) ListByBaseURL(ctx context.Context, baseURL string) ([]Provider, error) {
baseURL = strings.TrimSpace(baseURL)
if baseURL == "" {
return nil, fmt.Errorf("base_url is required")
}
rows, err := r.db.QueryContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE base_url = ? ORDER BY id`, baseURL)
if err != nil {
return nil, fmt.Errorf("query providers by base_url %q: %w", baseURL, err)
}
defer rows.Close()
providers := make([]Provider, 0)
for rows.Next() {
var provider Provider
if err := rows.Scan(
&provider.ID,
&provider.PackID,
&provider.ProviderID,
&provider.DisplayName,
&provider.BaseURL,
&provider.Platform,
&provider.AccountType,
&provider.DefaultModelsJSON,
&provider.SmokeTestModel,
&provider.GroupTemplateJSON,
&provider.ChannelTemplateJSON,
&provider.PlanTemplateJSON,
&provider.ImportOptionsJSON,
&provider.ManifestJSON,
); err != nil {
return nil, fmt.Errorf("scan provider by base_url %q: %w", baseURL, err)
}
providers = append(providers, provider)
}
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("iterate providers by base_url %q: %w", baseURL, err)
}
return providers, nil
}
func (r *ProvidersRepo) GetByPackIDAndProviderID(ctx context.Context, packID int64, providerID string) (Provider, error) {
if packID <= 0 {
return Provider{}, fmt.Errorf("pack_id is required")
}
providerID = strings.TrimSpace(providerID)
if providerID == "" {
return Provider{}, fmt.Errorf("provider_id is required")
}
var provider Provider
if err := r.db.QueryRowContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE pack_id = ? AND provider_id = ?`, packID, providerID).Scan(
&provider.ID,
&provider.PackID,
&provider.ProviderID,
&provider.DisplayName,
&provider.BaseURL,
&provider.Platform,
&provider.AccountType,
&provider.DefaultModelsJSON,
&provider.SmokeTestModel,
&provider.GroupTemplateJSON,
&provider.ChannelTemplateJSON,
&provider.PlanTemplateJSON,
&provider.ImportOptionsJSON,
&provider.ManifestJSON,
); err != nil {
return Provider{}, err
}
return provider, nil
}
func (r *ProvidersRepo) GetByID(ctx context.Context, id int64) (Provider, error) {
if id <= 0 {
return Provider{}, fmt.Errorf("id is required")
}
var provider Provider
if err := r.db.QueryRowContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE id = ?`, id).Scan(
&provider.ID,
&provider.PackID,
&provider.ProviderID,
&provider.DisplayName,
&provider.BaseURL,
&provider.Platform,
&provider.AccountType,
&provider.DefaultModelsJSON,
&provider.SmokeTestModel,
&provider.GroupTemplateJSON,
&provider.ChannelTemplateJSON,
&provider.PlanTemplateJSON,
&provider.ImportOptionsJSON,
&provider.ManifestJSON,
); err != nil {
return Provider{}, err
}
return provider, nil
}
func (r *ProvidersRepo) Create(ctx context.Context, provider Provider) (int64, error) {
providerID := strings.TrimSpace(provider.ProviderID)
displayName := strings.TrimSpace(provider.DisplayName)
baseURL := strings.TrimSpace(provider.BaseURL)
platform := strings.TrimSpace(provider.Platform)
manifestJSON := strings.TrimSpace(provider.ManifestJSON)
if manifestJSON == "" {
manifestJSON = "{}"
}
switch {
case provider.PackID <= 0:
return 0, fmt.Errorf("pack_id is required")
case providerID == "":
return 0, fmt.Errorf("provider_id is required")
case displayName == "":
return 0, fmt.Errorf("display_name is required")
case baseURL == "":
return 0, fmt.Errorf("base_url is required")
case platform == "":
return 0, fmt.Errorf("platform is required")
}
result, err := r.db.ExecContext(
ctx,
`INSERT INTO providers (pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
provider.PackID,
providerID,
displayName,
baseURL,
platform,
strings.TrimSpace(provider.AccountType),
defaultJSONArray(provider.DefaultModelsJSON),
strings.TrimSpace(provider.SmokeTestModel),
defaultJSONObject(provider.GroupTemplateJSON),
defaultJSONObject(provider.ChannelTemplateJSON),
defaultJSONObject(provider.PlanTemplateJSON),
defaultJSONObject(provider.ImportOptionsJSON),
manifestJSON,
)
if err != nil {
return 0, fmt.Errorf("insert provider %q: %w", providerID, err)
}
id, err := result.LastInsertId()
if err != nil {
return 0, fmt.Errorf("read inserted provider id for %q: %w", providerID, err)
}
return id, nil
}
func (r *ProvidersRepo) Upsert(ctx context.Context, provider Provider) (int64, error) {
providerID := strings.TrimSpace(provider.ProviderID)
displayName := strings.TrimSpace(provider.DisplayName)
baseURL := strings.TrimSpace(provider.BaseURL)
platform := strings.TrimSpace(provider.Platform)
manifestJSON := strings.TrimSpace(provider.ManifestJSON)
if manifestJSON == "" {
manifestJSON = "{}"
}
switch {
case provider.PackID <= 0:
return 0, fmt.Errorf("pack_id is required")
case providerID == "":
return 0, fmt.Errorf("provider_id is required")
case displayName == "":
return 0, fmt.Errorf("display_name is required")
case baseURL == "":
return 0, fmt.Errorf("base_url is required")
case platform == "":
return 0, fmt.Errorf("platform is required")
}
result, err := r.db.ExecContext(
ctx,
`INSERT INTO providers (pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(pack_id, provider_id) DO UPDATE SET
display_name = excluded.display_name,
base_url = excluded.base_url,
platform = excluded.platform,
account_type = excluded.account_type,
default_models_json = excluded.default_models_json,
smoke_test_model = excluded.smoke_test_model,
group_template_json = excluded.group_template_json,
channel_template_json = excluded.channel_template_json,
plan_template_json = excluded.plan_template_json,
import_options_json = excluded.import_options_json,
manifest_json = excluded.manifest_json`,
provider.PackID,
providerID,
displayName,
baseURL,
platform,
strings.TrimSpace(provider.AccountType),
defaultJSONArray(provider.DefaultModelsJSON),
strings.TrimSpace(provider.SmokeTestModel),
defaultJSONObject(provider.GroupTemplateJSON),
defaultJSONObject(provider.ChannelTemplateJSON),
defaultJSONObject(provider.PlanTemplateJSON),
defaultJSONObject(provider.ImportOptionsJSON),
manifestJSON,
)
if err != nil {
return 0, fmt.Errorf("upsert provider %q: %w", providerID, err)
}
id, err := result.LastInsertId()
if err == nil && id > 0 {
return id, nil
}
persisted, getErr := r.GetByPackIDAndProviderID(ctx, provider.PackID, providerID)
if getErr != nil {
if err != nil {
return 0, fmt.Errorf("read upserted provider %q: %w", providerID, getErr)
}
return 0, getErr
}
return persisted.ID, nil
}
// defaultJSONObject returns value untouched (including any surrounding
// whitespace) unless it is empty or whitespace-only, in which case it returns
// the empty JSON object "{}".
func defaultJSONObject(value string) string {
	if trimmed := strings.TrimSpace(value); len(trimmed) > 0 {
		return value
	}
	return "{}"
}
// defaultJSONArray substitutes the empty JSON array "[]" for an empty or
// whitespace-only value; any other value is passed through exactly as given.
func defaultJSONArray(value string) string {
	switch strings.TrimSpace(value) {
	case "":
		return "[]"
	default:
		return value
	}
}