350 lines
11 KiB
Go
350 lines
11 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
)
|
|
|
|
// Provider mirrors one row of the providers table. Field order matches the
// column order used by every SELECT in this file.
//
// Create and Upsert require PackID > 0 and non-blank ProviderID, DisplayName,
// BaseURL and Platform. The *JSON fields hold raw JSON text that this package
// stores without validating; blank values are written as "[]" for
// DefaultModelsJSON and "{}" for every other *JSON field.
type Provider struct {
	// ID is providers.id, assigned by the database on insert.
	ID int64
	// PackID is providers.pack_id; together with ProviderID it is the
	// conflict key Upsert uses (ON CONFLICT(pack_id, provider_id)).
	PackID int64
	// ProviderID is providers.provider_id (stored trimmed).
	ProviderID string
	// DisplayName is providers.display_name (stored trimmed).
	DisplayName string
	// BaseURL is providers.base_url (stored trimmed).
	BaseURL string
	// Platform is providers.platform (stored trimmed).
	Platform string
	// AccountType is providers.account_type (stored trimmed; may be empty).
	AccountType string
	// DefaultModelsJSON is providers.default_models_json; "[]" when blank.
	DefaultModelsJSON string
	// SmokeTestModel is providers.smoke_test_model (stored trimmed; may be empty).
	SmokeTestModel string
	// GroupTemplateJSON is providers.group_template_json; "{}" when blank.
	GroupTemplateJSON string
	// ChannelTemplateJSON is providers.channel_template_json; "{}" when blank.
	ChannelTemplateJSON string
	// PlanTemplateJSON is providers.plan_template_json; "{}" when blank.
	PlanTemplateJSON string
	// ImportOptionsJSON is providers.import_options_json; "{}" when blank.
	ImportOptionsJSON string
	// ManifestJSON is providers.manifest_json; stored trimmed, "{}" when blank.
	ManifestJSON string
}
|
|
|
|
// ProvidersRepo reads and writes rows of the providers table.
type ProvidersRepo struct {
	// db is the handle every QueryContext / QueryRowContext / ExecContext call
	// in this repo runs against. execQuerier is declared elsewhere; presumably
	// satisfied by *sql.DB and *sql.Tx — confirm.
	db execQuerier
}
|
|
|
|
func newProvidersRepo(db execQuerier) *ProvidersRepo {
|
|
return &ProvidersRepo{db: db}
|
|
}
|
|
|
|
func (r *ProvidersRepo) ListByPackID(ctx context.Context, packID int64) ([]Provider, error) {
|
|
if packID <= 0 {
|
|
return nil, fmt.Errorf("pack_id is required")
|
|
}
|
|
|
|
rows, err := r.db.QueryContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE pack_id = ? ORDER BY id`, packID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query providers by pack_id %d: %w", packID, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
providers := make([]Provider, 0)
|
|
for rows.Next() {
|
|
var provider Provider
|
|
if err := rows.Scan(
|
|
&provider.ID,
|
|
&provider.PackID,
|
|
&provider.ProviderID,
|
|
&provider.DisplayName,
|
|
&provider.BaseURL,
|
|
&provider.Platform,
|
|
&provider.AccountType,
|
|
&provider.DefaultModelsJSON,
|
|
&provider.SmokeTestModel,
|
|
&provider.GroupTemplateJSON,
|
|
&provider.ChannelTemplateJSON,
|
|
&provider.PlanTemplateJSON,
|
|
&provider.ImportOptionsJSON,
|
|
&provider.ManifestJSON,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("scan provider: %w", err)
|
|
}
|
|
providers = append(providers, provider)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate providers: %w", err)
|
|
}
|
|
return providers, nil
|
|
}
|
|
|
|
func (r *ProvidersRepo) ListByProviderID(ctx context.Context, providerID string) ([]Provider, error) {
|
|
providerID = strings.TrimSpace(providerID)
|
|
if providerID == "" {
|
|
return nil, fmt.Errorf("provider_id is required")
|
|
}
|
|
|
|
rows, err := r.db.QueryContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE provider_id = ? ORDER BY id`, providerID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query providers by provider_id %q: %w", providerID, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
providers := make([]Provider, 0)
|
|
for rows.Next() {
|
|
var provider Provider
|
|
if err := rows.Scan(
|
|
&provider.ID,
|
|
&provider.PackID,
|
|
&provider.ProviderID,
|
|
&provider.DisplayName,
|
|
&provider.BaseURL,
|
|
&provider.Platform,
|
|
&provider.AccountType,
|
|
&provider.DefaultModelsJSON,
|
|
&provider.SmokeTestModel,
|
|
&provider.GroupTemplateJSON,
|
|
&provider.ChannelTemplateJSON,
|
|
&provider.PlanTemplateJSON,
|
|
&provider.ImportOptionsJSON,
|
|
&provider.ManifestJSON,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("scan provider by provider_id %q: %w", providerID, err)
|
|
}
|
|
providers = append(providers, provider)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate providers by provider_id %q: %w", providerID, err)
|
|
}
|
|
return providers, nil
|
|
}
|
|
|
|
func (r *ProvidersRepo) ListByBaseURL(ctx context.Context, baseURL string) ([]Provider, error) {
|
|
baseURL = strings.TrimSpace(baseURL)
|
|
if baseURL == "" {
|
|
return nil, fmt.Errorf("base_url is required")
|
|
}
|
|
|
|
rows, err := r.db.QueryContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE base_url = ? ORDER BY id`, baseURL)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query providers by base_url %q: %w", baseURL, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
providers := make([]Provider, 0)
|
|
for rows.Next() {
|
|
var provider Provider
|
|
if err := rows.Scan(
|
|
&provider.ID,
|
|
&provider.PackID,
|
|
&provider.ProviderID,
|
|
&provider.DisplayName,
|
|
&provider.BaseURL,
|
|
&provider.Platform,
|
|
&provider.AccountType,
|
|
&provider.DefaultModelsJSON,
|
|
&provider.SmokeTestModel,
|
|
&provider.GroupTemplateJSON,
|
|
&provider.ChannelTemplateJSON,
|
|
&provider.PlanTemplateJSON,
|
|
&provider.ImportOptionsJSON,
|
|
&provider.ManifestJSON,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("scan provider by base_url %q: %w", baseURL, err)
|
|
}
|
|
providers = append(providers, provider)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate providers by base_url %q: %w", baseURL, err)
|
|
}
|
|
return providers, nil
|
|
}
|
|
|
|
func (r *ProvidersRepo) GetByPackIDAndProviderID(ctx context.Context, packID int64, providerID string) (Provider, error) {
|
|
if packID <= 0 {
|
|
return Provider{}, fmt.Errorf("pack_id is required")
|
|
}
|
|
providerID = strings.TrimSpace(providerID)
|
|
if providerID == "" {
|
|
return Provider{}, fmt.Errorf("provider_id is required")
|
|
}
|
|
|
|
var provider Provider
|
|
if err := r.db.QueryRowContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE pack_id = ? AND provider_id = ?`, packID, providerID).Scan(
|
|
&provider.ID,
|
|
&provider.PackID,
|
|
&provider.ProviderID,
|
|
&provider.DisplayName,
|
|
&provider.BaseURL,
|
|
&provider.Platform,
|
|
&provider.AccountType,
|
|
&provider.DefaultModelsJSON,
|
|
&provider.SmokeTestModel,
|
|
&provider.GroupTemplateJSON,
|
|
&provider.ChannelTemplateJSON,
|
|
&provider.PlanTemplateJSON,
|
|
&provider.ImportOptionsJSON,
|
|
&provider.ManifestJSON,
|
|
); err != nil {
|
|
return Provider{}, err
|
|
}
|
|
return provider, nil
|
|
}
|
|
|
|
func (r *ProvidersRepo) GetByID(ctx context.Context, id int64) (Provider, error) {
|
|
if id <= 0 {
|
|
return Provider{}, fmt.Errorf("id is required")
|
|
}
|
|
|
|
var provider Provider
|
|
if err := r.db.QueryRowContext(ctx, `SELECT id, pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json FROM providers WHERE id = ?`, id).Scan(
|
|
&provider.ID,
|
|
&provider.PackID,
|
|
&provider.ProviderID,
|
|
&provider.DisplayName,
|
|
&provider.BaseURL,
|
|
&provider.Platform,
|
|
&provider.AccountType,
|
|
&provider.DefaultModelsJSON,
|
|
&provider.SmokeTestModel,
|
|
&provider.GroupTemplateJSON,
|
|
&provider.ChannelTemplateJSON,
|
|
&provider.PlanTemplateJSON,
|
|
&provider.ImportOptionsJSON,
|
|
&provider.ManifestJSON,
|
|
); err != nil {
|
|
return Provider{}, err
|
|
}
|
|
return provider, nil
|
|
}
|
|
|
|
func (r *ProvidersRepo) Create(ctx context.Context, provider Provider) (int64, error) {
|
|
providerID := strings.TrimSpace(provider.ProviderID)
|
|
displayName := strings.TrimSpace(provider.DisplayName)
|
|
baseURL := strings.TrimSpace(provider.BaseURL)
|
|
platform := strings.TrimSpace(provider.Platform)
|
|
manifestJSON := strings.TrimSpace(provider.ManifestJSON)
|
|
if manifestJSON == "" {
|
|
manifestJSON = "{}"
|
|
}
|
|
|
|
switch {
|
|
case provider.PackID <= 0:
|
|
return 0, fmt.Errorf("pack_id is required")
|
|
case providerID == "":
|
|
return 0, fmt.Errorf("provider_id is required")
|
|
case displayName == "":
|
|
return 0, fmt.Errorf("display_name is required")
|
|
case baseURL == "":
|
|
return 0, fmt.Errorf("base_url is required")
|
|
case platform == "":
|
|
return 0, fmt.Errorf("platform is required")
|
|
}
|
|
|
|
result, err := r.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO providers (pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
provider.PackID,
|
|
providerID,
|
|
displayName,
|
|
baseURL,
|
|
platform,
|
|
strings.TrimSpace(provider.AccountType),
|
|
defaultJSONArray(provider.DefaultModelsJSON),
|
|
strings.TrimSpace(provider.SmokeTestModel),
|
|
defaultJSONObject(provider.GroupTemplateJSON),
|
|
defaultJSONObject(provider.ChannelTemplateJSON),
|
|
defaultJSONObject(provider.PlanTemplateJSON),
|
|
defaultJSONObject(provider.ImportOptionsJSON),
|
|
manifestJSON,
|
|
)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("insert provider %q: %w", providerID, err)
|
|
}
|
|
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("read inserted provider id for %q: %w", providerID, err)
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
func (r *ProvidersRepo) Upsert(ctx context.Context, provider Provider) (int64, error) {
|
|
providerID := strings.TrimSpace(provider.ProviderID)
|
|
displayName := strings.TrimSpace(provider.DisplayName)
|
|
baseURL := strings.TrimSpace(provider.BaseURL)
|
|
platform := strings.TrimSpace(provider.Platform)
|
|
manifestJSON := strings.TrimSpace(provider.ManifestJSON)
|
|
if manifestJSON == "" {
|
|
manifestJSON = "{}"
|
|
}
|
|
|
|
switch {
|
|
case provider.PackID <= 0:
|
|
return 0, fmt.Errorf("pack_id is required")
|
|
case providerID == "":
|
|
return 0, fmt.Errorf("provider_id is required")
|
|
case displayName == "":
|
|
return 0, fmt.Errorf("display_name is required")
|
|
case baseURL == "":
|
|
return 0, fmt.Errorf("base_url is required")
|
|
case platform == "":
|
|
return 0, fmt.Errorf("platform is required")
|
|
}
|
|
|
|
result, err := r.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO providers (pack_id, provider_id, display_name, base_url, platform, account_type, default_models_json, smoke_test_model, group_template_json, channel_template_json, plan_template_json, import_options_json, manifest_json)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
ON CONFLICT(pack_id, provider_id) DO UPDATE SET
|
|
display_name = excluded.display_name,
|
|
base_url = excluded.base_url,
|
|
platform = excluded.platform,
|
|
account_type = excluded.account_type,
|
|
default_models_json = excluded.default_models_json,
|
|
smoke_test_model = excluded.smoke_test_model,
|
|
group_template_json = excluded.group_template_json,
|
|
channel_template_json = excluded.channel_template_json,
|
|
plan_template_json = excluded.plan_template_json,
|
|
import_options_json = excluded.import_options_json,
|
|
manifest_json = excluded.manifest_json`,
|
|
provider.PackID,
|
|
providerID,
|
|
displayName,
|
|
baseURL,
|
|
platform,
|
|
strings.TrimSpace(provider.AccountType),
|
|
defaultJSONArray(provider.DefaultModelsJSON),
|
|
strings.TrimSpace(provider.SmokeTestModel),
|
|
defaultJSONObject(provider.GroupTemplateJSON),
|
|
defaultJSONObject(provider.ChannelTemplateJSON),
|
|
defaultJSONObject(provider.PlanTemplateJSON),
|
|
defaultJSONObject(provider.ImportOptionsJSON),
|
|
manifestJSON,
|
|
)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("upsert provider %q: %w", providerID, err)
|
|
}
|
|
|
|
id, err := result.LastInsertId()
|
|
if err == nil && id > 0 {
|
|
return id, nil
|
|
}
|
|
persisted, getErr := r.GetByPackIDAndProviderID(ctx, provider.PackID, providerID)
|
|
if getErr != nil {
|
|
if err != nil {
|
|
return 0, fmt.Errorf("read upserted provider %q: %w", providerID, getErr)
|
|
}
|
|
return 0, getErr
|
|
}
|
|
return persisted.ID, nil
|
|
}
|
|
|
|
// defaultJSONObject returns value verbatim (untrimmed) unless it is empty or
// whitespace-only, in which case it substitutes the empty JSON object "{}".
func defaultJSONObject(value string) string {
	if trimmed := strings.TrimSpace(value); trimmed != "" {
		return value
	}
	return "{}"
}
|
|
|
|
// defaultJSONArray substitutes the empty JSON array "[]" for a blank value
// (empty or whitespace-only); any other value is returned verbatim, untrimmed.
func defaultJSONArray(value string) string {
	switch strings.TrimSpace(value) {
	case "":
		return "[]"
	default:
		return value
	}
}
|