286 lines
10 KiB
Go
286 lines
10 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
)
|
|
|
|
// ImportBatch mirrors one row of the import_batches table: the host, pack and
// provider the batch belongs to, the import mode, and the batch/access status
// strings. The only status values referenced in this file are the
// 'succeeded' and 'partially_succeeded' batch statuses used by
// ListLatestReconcilable.
type ImportBatch struct {
	ID, HostID, PackID, ProviderID  int64
	Mode, BatchStatus, AccessStatus string
}
|
|
|
|
// ImportBatchItem mirrors one row of the import_batch_items table. BatchID
// holds the id of the owning import batch (presumably import_batches.id —
// confirm against the schema). ProbeSummaryJSON is kept as raw text; the
// repository writes "{}" whenever it is blank.
type ImportBatchItem struct {
	ID, BatchID int64

	KeyFingerprint, AccountStatus, ProbeSummaryJSON string
}
|
|
|
|
type ImportBatchesRepo struct {
|
|
db execQuerier
|
|
}
|
|
|
|
type ImportBatchItemsRepo struct {
|
|
db execQuerier
|
|
}
|
|
|
|
func newImportBatchesRepo(db execQuerier) *ImportBatchesRepo {
|
|
return &ImportBatchesRepo{db: db}
|
|
}
|
|
|
|
func newImportBatchItemsRepo(db execQuerier) *ImportBatchItemsRepo {
|
|
return &ImportBatchItemsRepo{db: db}
|
|
}
|
|
|
|
func (r *ImportBatchesRepo) GetByID(ctx context.Context, id int64) (ImportBatch, error) {
|
|
if id <= 0 {
|
|
return ImportBatch{}, fmt.Errorf("id is required")
|
|
}
|
|
|
|
var batch ImportBatch
|
|
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE id = ?`, id).Scan(&batch.ID, &batch.HostID, &batch.PackID, &batch.ProviderID, &batch.Mode, &batch.BatchStatus, &batch.AccessStatus); err != nil {
|
|
return ImportBatch{}, err
|
|
}
|
|
return batch, nil
|
|
}
|
|
|
|
func (r *ImportBatchesRepo) Create(ctx context.Context, batch ImportBatch) (int64, error) {
|
|
mode := strings.TrimSpace(batch.Mode)
|
|
batchStatus := strings.TrimSpace(batch.BatchStatus)
|
|
accessStatus := strings.TrimSpace(batch.AccessStatus)
|
|
|
|
switch {
|
|
case batch.HostID <= 0:
|
|
return 0, fmt.Errorf("host_id is required")
|
|
case batch.PackID <= 0:
|
|
return 0, fmt.Errorf("pack_id is required")
|
|
case batch.ProviderID <= 0:
|
|
return 0, fmt.Errorf("provider_id is required")
|
|
case mode == "":
|
|
return 0, fmt.Errorf("mode is required")
|
|
case batchStatus == "":
|
|
return 0, fmt.Errorf("batch_status is required")
|
|
case accessStatus == "":
|
|
return 0, fmt.Errorf("access_status is required")
|
|
}
|
|
|
|
result, err := r.db.ExecContext(ctx, `INSERT INTO import_batches (host_id, pack_id, provider_id, mode, batch_status, access_status) VALUES (?, ?, ?, ?, ?, ?)`, batch.HostID, batch.PackID, batch.ProviderID, mode, batchStatus, accessStatus)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("insert import batch: %w", err)
|
|
}
|
|
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("read inserted import batch id: %w", err)
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
func (r *ImportBatchesRepo) UpdateStatus(ctx context.Context, id int64, batchStatus, accessStatus string) error {
|
|
if id <= 0 {
|
|
return fmt.Errorf("id is required")
|
|
}
|
|
batchStatus = strings.TrimSpace(batchStatus)
|
|
accessStatus = strings.TrimSpace(accessStatus)
|
|
if batchStatus == "" {
|
|
return fmt.Errorf("batch_status is required")
|
|
}
|
|
if accessStatus == "" {
|
|
return fmt.Errorf("access_status is required")
|
|
}
|
|
if _, err := r.db.ExecContext(ctx, `UPDATE import_batches SET batch_status = ?, access_status = ? WHERE id = ?`, batchStatus, accessStatus, id); err != nil {
|
|
return fmt.Errorf("update import batch %d: %w", id, err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *ImportBatchesRepo) GetLatestByProviderID(ctx context.Context, providerID int64) (ImportBatch, error) {
|
|
if providerID <= 0 {
|
|
return ImportBatch{}, fmt.Errorf("provider_id is required")
|
|
}
|
|
|
|
var batch ImportBatch
|
|
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE provider_id = ? ORDER BY id DESC LIMIT 1`, providerID).Scan(&batch.ID, &batch.HostID, &batch.PackID, &batch.ProviderID, &batch.Mode, &batch.BatchStatus, &batch.AccessStatus); err != nil {
|
|
return ImportBatch{}, err
|
|
}
|
|
return batch, nil
|
|
}
|
|
|
|
func (r *ImportBatchesRepo) GetLatestByProviderIDAndHostID(ctx context.Context, providerID, hostID int64) (ImportBatch, error) {
|
|
if providerID <= 0 {
|
|
return ImportBatch{}, fmt.Errorf("provider_id is required")
|
|
}
|
|
if hostID <= 0 {
|
|
return ImportBatch{}, fmt.Errorf("host_id is required")
|
|
}
|
|
|
|
var batch ImportBatch
|
|
if err := r.db.QueryRowContext(ctx, `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE provider_id = ? AND host_id = ? ORDER BY id DESC LIMIT 1`, providerID, hostID).Scan(&batch.ID, &batch.HostID, &batch.PackID, &batch.ProviderID, &batch.Mode, &batch.BatchStatus, &batch.AccessStatus); err != nil {
|
|
return ImportBatch{}, err
|
|
}
|
|
return batch, nil
|
|
}
|
|
|
|
func (r *ImportBatchesRepo) ListByProviderID(ctx context.Context, providerID int64) ([]ImportBatch, error) {
|
|
if providerID <= 0 {
|
|
return nil, fmt.Errorf("provider_id is required")
|
|
}
|
|
|
|
rows, err := r.db.QueryContext(ctx, `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE provider_id = ? ORDER BY id DESC`, providerID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query import batches by provider_id %d: %w", providerID, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
batches := make([]ImportBatch, 0)
|
|
for rows.Next() {
|
|
var batch ImportBatch
|
|
if err := rows.Scan(&batch.ID, &batch.HostID, &batch.PackID, &batch.ProviderID, &batch.Mode, &batch.BatchStatus, &batch.AccessStatus); err != nil {
|
|
return nil, fmt.Errorf("scan import batch by provider_id %d: %w", providerID, err)
|
|
}
|
|
batches = append(batches, batch)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate import batches by provider_id %d: %w", providerID, err)
|
|
}
|
|
return batches, nil
|
|
}
|
|
|
|
func (r *ImportBatchesRepo) ListByProviderIDAndHostID(ctx context.Context, providerID, hostID int64) ([]ImportBatch, error) {
|
|
if providerID <= 0 {
|
|
return nil, fmt.Errorf("provider_id is required")
|
|
}
|
|
if hostID <= 0 {
|
|
return nil, fmt.Errorf("host_id is required")
|
|
}
|
|
|
|
rows, err := r.db.QueryContext(ctx, `SELECT id, host_id, pack_id, provider_id, mode, batch_status, access_status FROM import_batches WHERE provider_id = ? AND host_id = ? ORDER BY id DESC`, providerID, hostID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query import batches by provider_id %d and host_id %d: %w", providerID, hostID, err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
batches := make([]ImportBatch, 0)
|
|
for rows.Next() {
|
|
var batch ImportBatch
|
|
if err := rows.Scan(&batch.ID, &batch.HostID, &batch.PackID, &batch.ProviderID, &batch.Mode, &batch.BatchStatus, &batch.AccessStatus); err != nil {
|
|
return nil, fmt.Errorf("scan import batch by provider_id %d and host_id %d: %w", providerID, hostID, err)
|
|
}
|
|
batches = append(batches, batch)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate import batches by provider_id %d and host_id %d: %w", providerID, hostID, err)
|
|
}
|
|
return batches, nil
|
|
}
|
|
|
|
func (r *ImportBatchesRepo) ListLatestReconcilable(ctx context.Context) ([]ImportBatch, error) {
|
|
rows, err := r.db.QueryContext(ctx, `
|
|
SELECT ib.id, ib.host_id, ib.pack_id, ib.provider_id, ib.mode, ib.batch_status, ib.access_status
|
|
FROM import_batches ib
|
|
INNER JOIN (
|
|
SELECT provider_id, host_id, MAX(id) AS latest_id
|
|
FROM import_batches
|
|
GROUP BY provider_id, host_id
|
|
) latest ON latest.latest_id = ib.id
|
|
WHERE ib.batch_status IN ('succeeded', 'partially_succeeded')
|
|
ORDER BY ib.id DESC`)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query latest reconcilable import batches: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
batches := make([]ImportBatch, 0)
|
|
for rows.Next() {
|
|
var batch ImportBatch
|
|
if err := rows.Scan(&batch.ID, &batch.HostID, &batch.PackID, &batch.ProviderID, &batch.Mode, &batch.BatchStatus, &batch.AccessStatus); err != nil {
|
|
return nil, fmt.Errorf("scan latest reconcilable import batch: %w", err)
|
|
}
|
|
batches = append(batches, batch)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate latest reconcilable import batches: %w", err)
|
|
}
|
|
return batches, nil
|
|
}
|
|
|
|
func (r *ImportBatchItemsRepo) GetByBatchID(ctx context.Context, batchID int64) ([]ImportBatchItem, error) {
|
|
if batchID <= 0 {
|
|
return nil, fmt.Errorf("batch_id is required")
|
|
}
|
|
|
|
rows, err := r.db.QueryContext(ctx, `SELECT id, batch_id, key_fingerprint, account_status, probe_summary_json FROM import_batch_items WHERE batch_id = ? ORDER BY id`, batchID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("query import batch items: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
items := make([]ImportBatchItem, 0)
|
|
for rows.Next() {
|
|
var item ImportBatchItem
|
|
if err := rows.Scan(&item.ID, &item.BatchID, &item.KeyFingerprint, &item.AccountStatus, &item.ProbeSummaryJSON); err != nil {
|
|
return nil, fmt.Errorf("scan import batch item: %w", err)
|
|
}
|
|
items = append(items, item)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate import batch items: %w", err)
|
|
}
|
|
return items, nil
|
|
}
|
|
|
|
func (r *ImportBatchItemsRepo) Create(ctx context.Context, item ImportBatchItem) (int64, error) {
|
|
keyFingerprint := strings.TrimSpace(item.KeyFingerprint)
|
|
accountStatus := strings.TrimSpace(item.AccountStatus)
|
|
probeSummaryJSON := strings.TrimSpace(item.ProbeSummaryJSON)
|
|
if probeSummaryJSON == "" {
|
|
probeSummaryJSON = "{}"
|
|
}
|
|
|
|
switch {
|
|
case item.BatchID <= 0:
|
|
return 0, fmt.Errorf("batch_id is required")
|
|
case keyFingerprint == "":
|
|
return 0, fmt.Errorf("key_fingerprint is required")
|
|
case accountStatus == "":
|
|
return 0, fmt.Errorf("account_status is required")
|
|
}
|
|
|
|
result, err := r.db.ExecContext(ctx, `INSERT INTO import_batch_items (batch_id, key_fingerprint, account_status, probe_summary_json) VALUES (?, ?, ?, ?)`, item.BatchID, keyFingerprint, accountStatus, probeSummaryJSON)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("insert import batch item: %w", err)
|
|
}
|
|
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("read inserted import batch item id: %w", err)
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
func (r *ImportBatchItemsRepo) UpdateResult(ctx context.Context, id int64, accountStatus, probeSummaryJSON string) error {
|
|
if id <= 0 {
|
|
return fmt.Errorf("id is required")
|
|
}
|
|
accountStatus = strings.TrimSpace(accountStatus)
|
|
probeSummaryJSON = strings.TrimSpace(probeSummaryJSON)
|
|
if accountStatus == "" {
|
|
return fmt.Errorf("account_status is required")
|
|
}
|
|
if probeSummaryJSON == "" {
|
|
probeSummaryJSON = "{}"
|
|
}
|
|
if _, err := r.db.ExecContext(ctx, `UPDATE import_batch_items SET account_status = ?, probe_summary_json = ? WHERE id = ?`, accountStatus, probeSummaryJSON, id); err != nil {
|
|
return fmt.Errorf("update import batch item %d: %w", id, err)
|
|
}
|
|
return nil
|
|
}
|