Files
sub2api-cn-relay-manager/internal/store/sqlite/import_batches_repo.go

286 lines
10 KiB
Go

package sqlite
import (
"context"
"fmt"
"strings"
)
// ImportBatch mirrors one row of the import_batches table.
type ImportBatch struct {
	// ID is the row id; Create reports it via LastInsertId.
	ID int64

	// HostID, PackID and ProviderID are the host_id, pack_id and provider_id
	// columns. Create rejects non-positive values for each of them.
	HostID, PackID, ProviderID int64

	// Mode, BatchStatus and AccessStatus are text columns. The repository
	// trims surrounding whitespace and rejects blank values when writing them.
	// ListLatestReconcilable only selects rows whose batch_status is
	// "succeeded" or "partially_succeeded".
	Mode, BatchStatus, AccessStatus string
}
// ImportBatchItem mirrors one row of the import_batch_items table.
type ImportBatchItem struct {
	// ID is the row id; Create reports it via LastInsertId.
	ID int64

	// BatchID is the batch_id column; GetByBatchID lists items by it and
	// Create rejects non-positive values.
	BatchID int64

	// KeyFingerprint and AccountStatus are stored with surrounding whitespace
	// trimmed, and the repository rejects blank values when writing them.
	KeyFingerprint, AccountStatus string

	// ProbeSummaryJSON is stored as text; a blank value is written as "{}".
	ProbeSummaryJSON string
}
type (
	// ImportBatchesRepo reads and writes rows of the import_batches table
	// through db.
	ImportBatchesRepo struct {
		db execQuerier
	}

	// ImportBatchItemsRepo reads and writes rows of the import_batch_items
	// table through db.
	ImportBatchItemsRepo struct {
		db execQuerier
	}
)
// newImportBatchesRepo returns a fresh ImportBatchesRepo that issues its
// statements through db.
func newImportBatchesRepo(db execQuerier) *ImportBatchesRepo {
	repo := new(ImportBatchesRepo)
	repo.db = db
	return repo
}
// newImportBatchItemsRepo returns a fresh ImportBatchItemsRepo that issues its
// statements through db.
func newImportBatchItemsRepo(db execQuerier) *ImportBatchItemsRepo {
	repo := new(ImportBatchItemsRepo)
	repo.db = db
	return repo
}
// GetByID loads the import batch whose id column equals id.
//
// It returns an error when id is not positive. Query/scan errors are returned
// unwrapped, so callers can compare them directly. When execQuerier is backed
// by database/sql, a missing row presumably surfaces as sql.ErrNoRows — confirm
// against the execQuerier definition. On any error the returned ImportBatch is
// the zero value.
func (r *ImportBatchesRepo) GetByID(ctx context.Context, id int64) (ImportBatch, error) {
	if id <= 0 {
		return ImportBatch{}, fmt.Errorf("id is required")
	}

	const query = `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE id = ?`
	row := r.db.QueryRowContext(ctx, query, id)

	var out ImportBatch
	err := row.Scan(&out.ID, &out.HostID, &out.PackID, &out.ProviderID, &out.Mode, &out.BatchStatus, &out.AccessStatus)
	if err != nil {
		return ImportBatch{}, err
	}
	return out, nil
}
// Create inserts batch into import_batches and returns the new row id.
//
// The ID field of batch is ignored. HostID, PackID and ProviderID must be
// positive. Mode, BatchStatus and AccessStatus are trimmed of surrounding
// whitespace and must be non-empty afterwards. Checks run in that order, and
// the first failure is returned without touching the database. Insert and
// LastInsertId failures are wrapped with %w.
func (r *ImportBatchesRepo) Create(ctx context.Context, batch ImportBatch) (int64, error) {
	var (
		mode         = strings.TrimSpace(batch.Mode)
		batchStatus  = strings.TrimSpace(batch.BatchStatus)
		accessStatus = strings.TrimSpace(batch.AccessStatus)
	)

	if batch.HostID <= 0 {
		return 0, fmt.Errorf("host_id is required")
	}
	if batch.PackID <= 0 {
		return 0, fmt.Errorf("pack_id is required")
	}
	if batch.ProviderID <= 0 {
		return 0, fmt.Errorf("provider_id is required")
	}
	if mode == "" {
		return 0, fmt.Errorf("mode is required")
	}
	if batchStatus == "" {
		return 0, fmt.Errorf("batch_status is required")
	}
	if accessStatus == "" {
		return 0, fmt.Errorf("access_status is required")
	}

	const stmt = `INSERT INTO import_batches (host_id, pack_id, provider_id, mode, batch_status, access_status) VALUES (?, ?, ?, ?, ?, ?)`
	res, err := r.db.ExecContext(ctx, stmt, batch.HostID, batch.PackID, batch.ProviderID, mode, batchStatus, accessStatus)
	if err != nil {
		return 0, fmt.Errorf("insert import batch: %w", err)
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return 0, fmt.Errorf("read inserted import batch id: %w", err)
	}
	return newID, nil
}
// UpdateStatus overwrites batch_status and access_status on the import batch
// whose id column equals id.
//
// Both statuses are trimmed of surrounding whitespace and must be non-empty.
// It returns an error when:
//   - id is not positive;
//   - a status is blank;
//   - the UPDATE fails;
//   - the affected-row count cannot be read;
//   - no row has that id.
//
// Without the last check, a stale or mistyped id would be silently reported
// as a successful update.
func (r *ImportBatchesRepo) UpdateStatus(ctx context.Context, id int64, batchStatus, accessStatus string) error {
	if id <= 0 {
		return fmt.Errorf("id is required")
	}
	batchStatus = strings.TrimSpace(batchStatus)
	accessStatus = strings.TrimSpace(accessStatus)
	if batchStatus == "" {
		return fmt.Errorf("batch_status is required")
	}
	if accessStatus == "" {
		return fmt.Errorf("access_status is required")
	}
	result, err := r.db.ExecContext(ctx, `UPDATE import_batches SET batch_status = ?, access_status = ? WHERE id = ?`, batchStatus, accessStatus, id)
	if err != nil {
		return fmt.Errorf("update import batch %d: %w", id, err)
	}
	// SQLite's change count includes rows matched by the WHERE clause even when
	// the new values equal the old ones, so zero means no row has this id.
	affected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("read affected rows for import batch %d: %w", id, err)
	}
	if affected == 0 {
		return fmt.Errorf("update import batch %d: no such batch", id)
	}
	return nil
}
// GetLatestByProviderID loads the import batch with the highest id among those
// whose provider_id equals providerID.
//
// It returns an error when providerID is not positive. Query/scan errors are
// returned unwrapped, so callers can compare them directly. With a
// database/sql backing, "no batches" presumably surfaces as sql.ErrNoRows —
// confirm against execQuerier. On any error the returned ImportBatch is the
// zero value.
func (r *ImportBatchesRepo) GetLatestByProviderID(ctx context.Context, providerID int64) (ImportBatch, error) {
	if providerID <= 0 {
		return ImportBatch{}, fmt.Errorf("provider_id is required")
	}

	const query = `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE provider_id = ? ORDER BY id DESC LIMIT 1`
	row := r.db.QueryRowContext(ctx, query, providerID)

	var latest ImportBatch
	err := row.Scan(&latest.ID, &latest.HostID, &latest.PackID, &latest.ProviderID, &latest.Mode, &latest.BatchStatus, &latest.AccessStatus)
	if err != nil {
		return ImportBatch{}, err
	}
	return latest, nil
}
// GetLatestByProviderIDAndHostID loads the import batch with the highest id
// among those matching both providerID and hostID.
//
// providerID is validated before hostID; both must be positive. Query/scan
// errors are returned unwrapped, so callers can compare them directly. With a
// database/sql backing, "no batches" presumably surfaces as sql.ErrNoRows —
// confirm against execQuerier. On any error the returned ImportBatch is the
// zero value.
func (r *ImportBatchesRepo) GetLatestByProviderIDAndHostID(ctx context.Context, providerID, hostID int64) (ImportBatch, error) {
	switch {
	case providerID <= 0:
		return ImportBatch{}, fmt.Errorf("provider_id is required")
	case hostID <= 0:
		return ImportBatch{}, fmt.Errorf("host_id is required")
	}

	const query = `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE provider_id = ? AND host_id = ? ORDER BY id DESC LIMIT 1`
	row := r.db.QueryRowContext(ctx, query, providerID, hostID)

	var latest ImportBatch
	err := row.Scan(&latest.ID, &latest.HostID, &latest.PackID, &latest.ProviderID, &latest.Mode, &latest.BatchStatus, &latest.AccessStatus)
	if err != nil {
		return ImportBatch{}, err
	}
	return latest, nil
}
// ListByProviderID returns every import batch whose provider_id equals
// providerID, ordered by id descending (newest first).
//
// When nothing matches, the result is a non-nil, empty slice. providerID must
// be positive. Query, scan and iteration failures are wrapped with %w and
// mention the provider id.
func (r *ImportBatchesRepo) ListByProviderID(ctx context.Context, providerID int64) ([]ImportBatch, error) {
	if providerID <= 0 {
		return nil, fmt.Errorf("provider_id is required")
	}

	const query = `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE provider_id = ? ORDER BY id DESC`
	rows, err := r.db.QueryContext(ctx, query, providerID)
	if err != nil {
		return nil, fmt.Errorf("query import batches by provider_id %d: %w", providerID, err)
	}
	defer rows.Close()

	result := []ImportBatch{}
	for rows.Next() {
		var b ImportBatch
		if scanErr := rows.Scan(&b.ID, &b.HostID, &b.PackID, &b.ProviderID, &b.Mode, &b.BatchStatus, &b.AccessStatus); scanErr != nil {
			return nil, fmt.Errorf("scan import batch by provider_id %d: %w", providerID, scanErr)
		}
		result = append(result, b)
	}
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate import batches by provider_id %d: %w", providerID, err)
	}
	return result, nil
}
// ListByProviderIDAndHostID returns every import batch matching both
// providerID and hostID, ordered by id descending (newest first).
//
// When nothing matches, the result is a non-nil, empty slice. providerID is
// validated before hostID; both must be positive. Query, scan and iteration
// failures are wrapped with %w and mention both ids.
func (r *ImportBatchesRepo) ListByProviderIDAndHostID(ctx context.Context, providerID, hostID int64) ([]ImportBatch, error) {
	switch {
	case providerID <= 0:
		return nil, fmt.Errorf("provider_id is required")
	case hostID <= 0:
		return nil, fmt.Errorf("host_id is required")
	}

	const query = `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE provider_id = ? AND host_id = ? ORDER BY id DESC`
	rows, err := r.db.QueryContext(ctx, query, providerID, hostID)
	if err != nil {
		return nil, fmt.Errorf("query import batches by provider_id %d and host_id %d: %w", providerID, hostID, err)
	}
	defer rows.Close()

	result := []ImportBatch{}
	for rows.Next() {
		var b ImportBatch
		if scanErr := rows.Scan(&b.ID, &b.HostID, &b.PackID, &b.ProviderID, &b.Mode, &b.BatchStatus, &b.AccessStatus); scanErr != nil {
			return nil, fmt.Errorf("scan import batch by provider_id %d and host_id %d: %w", providerID, hostID, scanErr)
		}
		result = append(result, b)
	}
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate import batches by provider_id %d and host_id %d: %w", providerID, hostID, err)
	}
	return result, nil
}
// ListLatestReconcilable returns, for each (provider_id, host_id) pair, the
// import batch with the highest id. A pair's batch is included only if its
// batch_status is "succeeded" or "partially_succeeded". Results are ordered by
// id descending.
//
// The status filter is applied after the newest batch per pair is chosen. A
// pair whose newest batch has any other status is therefore omitted entirely;
// older successful batches for that pair are not considered. When nothing
// qualifies, the result is a non-nil, empty slice. Query, scan and iteration
// failures are wrapped with %w.
func (r *ImportBatchesRepo) ListLatestReconcilable(ctx context.Context) ([]ImportBatch, error) {
	const query = `
SELECT ib.id, ib.host_id, ib.pack_id, ib.provider_id, ib.mode, ib.batch_status, ib.access_status
FROM import_batches ib
INNER JOIN (
SELECT provider_id, host_id, MAX(id) AS latest_id
FROM import_batches
GROUP BY provider_id, host_id
) latest ON latest.latest_id = ib.id
WHERE ib.batch_status IN ('succeeded', 'partially_succeeded')
ORDER BY ib.id DESC`

	rows, err := r.db.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("query latest reconcilable import batches: %w", err)
	}
	defer rows.Close()

	result := []ImportBatch{}
	for rows.Next() {
		var b ImportBatch
		if scanErr := rows.Scan(&b.ID, &b.HostID, &b.PackID, &b.ProviderID, &b.Mode, &b.BatchStatus, &b.AccessStatus); scanErr != nil {
			return nil, fmt.Errorf("scan latest reconcilable import batch: %w", scanErr)
		}
		result = append(result, b)
	}
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate latest reconcilable import batches: %w", err)
	}
	return result, nil
}
// GetByBatchID returns every import batch item whose batch_id equals batchID,
// ordered by id ascending.
//
// When nothing matches, the result is a non-nil, empty slice. batchID must be
// positive. Query, scan and iteration failures are wrapped with %w.
func (r *ImportBatchItemsRepo) GetByBatchID(ctx context.Context, batchID int64) ([]ImportBatchItem, error) {
	if batchID <= 0 {
		return nil, fmt.Errorf("batch_id is required")
	}

	const query = `SELECT id, batch_id, key_fingerprint, account_status, probe_summary_json FROM import_batch_items WHERE batch_id = ? ORDER BY id`
	rows, err := r.db.QueryContext(ctx, query, batchID)
	if err != nil {
		return nil, fmt.Errorf("query import batch items: %w", err)
	}
	defer rows.Close()

	result := []ImportBatchItem{}
	for rows.Next() {
		var it ImportBatchItem
		if scanErr := rows.Scan(&it.ID, &it.BatchID, &it.KeyFingerprint, &it.AccountStatus, &it.ProbeSummaryJSON); scanErr != nil {
			return nil, fmt.Errorf("scan import batch item: %w", scanErr)
		}
		result = append(result, it)
	}
	if err = rows.Err(); err != nil {
		return nil, fmt.Errorf("iterate import batch items: %w", err)
	}
	return result, nil
}
// Create inserts item into import_batch_items and returns the new row id.
//
// The ID field of item is ignored. BatchID must be positive. KeyFingerprint
// and AccountStatus are trimmed and must be non-empty afterwards. Checks run
// in that order. ProbeSummaryJSON is trimmed and stored as "{}" when blank;
// its content is not validated. Insert and LastInsertId failures are wrapped
// with %w.
func (r *ImportBatchItemsRepo) Create(ctx context.Context, item ImportBatchItem) (int64, error) {
	if item.BatchID <= 0 {
		return 0, fmt.Errorf("batch_id is required")
	}
	fingerprint := strings.TrimSpace(item.KeyFingerprint)
	if fingerprint == "" {
		return 0, fmt.Errorf("key_fingerprint is required")
	}
	status := strings.TrimSpace(item.AccountStatus)
	if status == "" {
		return 0, fmt.Errorf("account_status is required")
	}
	summary := strings.TrimSpace(item.ProbeSummaryJSON)
	if summary == "" {
		// Store an empty JSON object rather than an empty string.
		summary = "{}"
	}

	const stmt = `INSERT INTO import_batch_items (batch_id, key_fingerprint, account_status, probe_summary_json) VALUES (?, ?, ?, ?)`
	res, err := r.db.ExecContext(ctx, stmt, item.BatchID, fingerprint, status, summary)
	if err != nil {
		return 0, fmt.Errorf("insert import batch item: %w", err)
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return 0, fmt.Errorf("read inserted import batch item id: %w", err)
	}
	return newID, nil
}
// UpdateResult overwrites account_status and probe_summary_json on the import
// batch item whose id column equals id.
//
// accountStatus is trimmed and must be non-empty. probeSummaryJSON is trimmed
// and stored as "{}" when blank; its content is not validated here. It returns
// an error when:
//   - id is not positive;
//   - accountStatus is blank;
//   - the UPDATE fails;
//   - the affected-row count cannot be read;
//   - no item has that id.
//
// Without the last check, a stale or mistyped id would be silently reported
// as a successful update.
func (r *ImportBatchItemsRepo) UpdateResult(ctx context.Context, id int64, accountStatus, probeSummaryJSON string) error {
	if id <= 0 {
		return fmt.Errorf("id is required")
	}
	accountStatus = strings.TrimSpace(accountStatus)
	probeSummaryJSON = strings.TrimSpace(probeSummaryJSON)
	if accountStatus == "" {
		return fmt.Errorf("account_status is required")
	}
	if probeSummaryJSON == "" {
		probeSummaryJSON = "{}"
	}
	result, err := r.db.ExecContext(ctx, `UPDATE import_batch_items SET account_status = ?, probe_summary_json = ? WHERE id = ?`, accountStatus, probeSummaryJSON, id)
	if err != nil {
		return fmt.Errorf("update import batch item %d: %w", id, err)
	}
	// SQLite's change count includes rows matched by the WHERE clause even when
	// the new values equal the old ones, so zero means no row has this id.
	affected, err := result.RowsAffected()
	if err != nil {
		return fmt.Errorf("read affected rows for import batch item %d: %w", id, err)
	}
	if affected == 0 {
		return fmt.Errorf("update import batch item %d: no such item", id)
	}
	return nil
}