260 lines
7.5 KiB
Go
260 lines
7.5 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
)
|
|
|
|
// ProviderDraft mirrors one row of the provider_drafts table.
//
// Every column is carried as a plain string, including the JSON payloads and
// the timestamps, exactly as ProviderDraftsRepo scans them.
type ProviderDraft struct {
	// ID is the row's integer id, read back from the id column.
	ID int64

	// DraftID, PackID and ProviderID are the identifying columns
	// draft_id, pack_id and provider_id.
	DraftID, PackID, ProviderID string

	// DisplayName, Platform, BaseURL and SmokeTestModel map to display_name,
	// platform, base_url and smoke_test_model.
	DisplayName, Platform, BaseURL, SmokeTestModel string

	// SupportedModelsJSON and ManifestJSON hold raw JSON text for the
	// supported_models_json and manifest_json columns.
	SupportedModelsJSON, ManifestJSON string

	// SourceHostID and Notes map to source_host_id and notes; the repo does
	// not require them to be set.
	SourceHostID, Notes string

	// CreatedAt and UpdatedAt are the created_at and updated_at column values.
	CreatedAt, UpdatedAt string
}
|
|
|
|
// ListProviderDraftsFilter narrows the rows returned by ProviderDraftsRepo.List.
//
// A field that is blank after trimming whitespace is ignored; the remaining
// conditions are combined with AND.
type ListProviderDraftsFilter struct {
	// PackID, when set, must equal pack_id exactly.
	PackID string

	// ProviderID, when set, must equal provider_id exactly.
	ProviderID string

	// Query, when set, is matched with LIKE against draft_id, provider_id
	// and display_name.
	Query string
}
|
|
|
|
type ProviderDraftsRepo struct {
|
|
db execQuerier
|
|
}
|
|
|
|
func newProviderDraftsRepo(db execQuerier) *ProviderDraftsRepo {
|
|
return &ProviderDraftsRepo{db: db}
|
|
}
|
|
|
|
func (r *ProviderDraftsRepo) Create(ctx context.Context, draft ProviderDraft) (int64, error) {
|
|
draftID := strings.TrimSpace(draft.DraftID)
|
|
packID := strings.TrimSpace(draft.PackID)
|
|
providerID := strings.TrimSpace(draft.ProviderID)
|
|
displayName := strings.TrimSpace(draft.DisplayName)
|
|
platform := strings.TrimSpace(draft.Platform)
|
|
if draft.ManifestJSON = strings.TrimSpace(draft.ManifestJSON); draft.ManifestJSON == "" {
|
|
draft.ManifestJSON = "{}"
|
|
}
|
|
if draft.SupportedModelsJSON = strings.TrimSpace(draft.SupportedModelsJSON); draft.SupportedModelsJSON == "" {
|
|
draft.SupportedModelsJSON = "[]"
|
|
}
|
|
|
|
switch {
|
|
case draftID == "":
|
|
return 0, fmt.Errorf("draft_id is required")
|
|
case packID == "":
|
|
return 0, fmt.Errorf("pack_id is required")
|
|
case providerID == "":
|
|
return 0, fmt.Errorf("provider_id is required")
|
|
case displayName == "":
|
|
return 0, fmt.Errorf("display_name is required")
|
|
case platform == "":
|
|
return 0, fmt.Errorf("platform is required")
|
|
}
|
|
|
|
result, err := r.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO provider_drafts (draft_id, pack_id, provider_id, display_name, platform, base_url, smoke_test_model, supported_models_json, manifest_json, source_host_id, notes)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
draftID,
|
|
packID,
|
|
providerID,
|
|
displayName,
|
|
platform,
|
|
strings.TrimSpace(draft.BaseURL),
|
|
strings.TrimSpace(draft.SmokeTestModel),
|
|
draft.SupportedModelsJSON,
|
|
draft.ManifestJSON,
|
|
strings.TrimSpace(draft.SourceHostID),
|
|
strings.TrimSpace(draft.Notes),
|
|
)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("insert provider draft %q: %w", draftID, err)
|
|
}
|
|
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("read inserted provider draft id for %q: %w", draftID, err)
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
func (r *ProviderDraftsRepo) UpdateByDraftID(ctx context.Context, draft ProviderDraft) error {
|
|
draftID := strings.TrimSpace(draft.DraftID)
|
|
packID := strings.TrimSpace(draft.PackID)
|
|
providerID := strings.TrimSpace(draft.ProviderID)
|
|
displayName := strings.TrimSpace(draft.DisplayName)
|
|
platform := strings.TrimSpace(draft.Platform)
|
|
if draft.ManifestJSON = strings.TrimSpace(draft.ManifestJSON); draft.ManifestJSON == "" {
|
|
draft.ManifestJSON = "{}"
|
|
}
|
|
if draft.SupportedModelsJSON = strings.TrimSpace(draft.SupportedModelsJSON); draft.SupportedModelsJSON == "" {
|
|
draft.SupportedModelsJSON = "[]"
|
|
}
|
|
|
|
switch {
|
|
case draftID == "":
|
|
return fmt.Errorf("draft_id is required")
|
|
case packID == "":
|
|
return fmt.Errorf("pack_id is required")
|
|
case providerID == "":
|
|
return fmt.Errorf("provider_id is required")
|
|
case displayName == "":
|
|
return fmt.Errorf("display_name is required")
|
|
case platform == "":
|
|
return fmt.Errorf("platform is required")
|
|
}
|
|
|
|
result, err := r.db.ExecContext(
|
|
ctx,
|
|
`UPDATE provider_drafts
|
|
SET pack_id = ?, provider_id = ?, display_name = ?, platform = ?, base_url = ?, smoke_test_model = ?, supported_models_json = ?, manifest_json = ?, source_host_id = ?, notes = ?, updated_at = CURRENT_TIMESTAMP
|
|
WHERE draft_id = ?`,
|
|
packID,
|
|
providerID,
|
|
displayName,
|
|
platform,
|
|
strings.TrimSpace(draft.BaseURL),
|
|
strings.TrimSpace(draft.SmokeTestModel),
|
|
draft.SupportedModelsJSON,
|
|
draft.ManifestJSON,
|
|
strings.TrimSpace(draft.SourceHostID),
|
|
strings.TrimSpace(draft.Notes),
|
|
draftID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("update provider draft %q: %w", draftID, err)
|
|
}
|
|
affected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("read updated provider draft rows for %q: %w", draftID, err)
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("provider draft %q not found", draftID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *ProviderDraftsRepo) GetByDraftID(ctx context.Context, draftID string) (ProviderDraft, error) {
|
|
draftID = strings.TrimSpace(draftID)
|
|
if draftID == "" {
|
|
return ProviderDraft{}, fmt.Errorf("draft_id is required")
|
|
}
|
|
|
|
var draft ProviderDraft
|
|
if err := r.db.QueryRowContext(
|
|
ctx,
|
|
`SELECT id, draft_id, pack_id, provider_id, display_name, platform, base_url, smoke_test_model, supported_models_json, manifest_json, source_host_id, notes, created_at, updated_at
|
|
FROM provider_drafts WHERE draft_id = ?`,
|
|
draftID,
|
|
).Scan(
|
|
&draft.ID,
|
|
&draft.DraftID,
|
|
&draft.PackID,
|
|
&draft.ProviderID,
|
|
&draft.DisplayName,
|
|
&draft.Platform,
|
|
&draft.BaseURL,
|
|
&draft.SmokeTestModel,
|
|
&draft.SupportedModelsJSON,
|
|
&draft.ManifestJSON,
|
|
&draft.SourceHostID,
|
|
&draft.Notes,
|
|
&draft.CreatedAt,
|
|
&draft.UpdatedAt,
|
|
); err != nil {
|
|
return ProviderDraft{}, err
|
|
}
|
|
return draft, nil
|
|
}
|
|
|
|
func (r *ProviderDraftsRepo) List(ctx context.Context, filter ListProviderDraftsFilter) ([]ProviderDraft, error) {
|
|
query := `SELECT id, draft_id, pack_id, provider_id, display_name, platform, base_url, smoke_test_model, supported_models_json, manifest_json, source_host_id, notes, created_at, updated_at
|
|
FROM provider_drafts`
|
|
where := make([]string, 0, 3)
|
|
args := make([]any, 0, 3)
|
|
|
|
if packID := strings.TrimSpace(filter.PackID); packID != "" {
|
|
where = append(where, "pack_id = ?")
|
|
args = append(args, packID)
|
|
}
|
|
if providerID := strings.TrimSpace(filter.ProviderID); providerID != "" {
|
|
where = append(where, "provider_id = ?")
|
|
args = append(args, providerID)
|
|
}
|
|
if rawQuery := strings.TrimSpace(filter.Query); rawQuery != "" {
|
|
like := "%" + rawQuery + "%"
|
|
where = append(where, "(draft_id LIKE ? OR provider_id LIKE ? OR display_name LIKE ?)")
|
|
args = append(args, like, like, like)
|
|
}
|
|
if len(where) > 0 {
|
|
query += " WHERE " + strings.Join(where, " AND ")
|
|
}
|
|
query += " ORDER BY id DESC"
|
|
|
|
rows, err := r.db.QueryContext(ctx, query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list provider drafts: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
drafts := make([]ProviderDraft, 0)
|
|
for rows.Next() {
|
|
var draft ProviderDraft
|
|
if err := rows.Scan(
|
|
&draft.ID,
|
|
&draft.DraftID,
|
|
&draft.PackID,
|
|
&draft.ProviderID,
|
|
&draft.DisplayName,
|
|
&draft.Platform,
|
|
&draft.BaseURL,
|
|
&draft.SmokeTestModel,
|
|
&draft.SupportedModelsJSON,
|
|
&draft.ManifestJSON,
|
|
&draft.SourceHostID,
|
|
&draft.Notes,
|
|
&draft.CreatedAt,
|
|
&draft.UpdatedAt,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("scan provider draft: %w", err)
|
|
}
|
|
drafts = append(drafts, draft)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate provider drafts: %w", err)
|
|
}
|
|
return drafts, nil
|
|
}
|
|
|
|
func (r *ProviderDraftsRepo) DeleteByDraftID(ctx context.Context, draftID string) error {
|
|
draftID = strings.TrimSpace(draftID)
|
|
if draftID == "" {
|
|
return fmt.Errorf("draft_id is required")
|
|
}
|
|
|
|
result, err := r.db.ExecContext(ctx, `DELETE FROM provider_drafts WHERE draft_id = ?`, draftID)
|
|
if err != nil {
|
|
return fmt.Errorf("delete provider draft %q: %w", draftID, err)
|
|
}
|
|
affected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("read deleted provider draft rows for %q: %w", draftID, err)
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("provider draft %q not found", draftID)
|
|
}
|
|
return nil
|
|
}
|