422 lines
12 KiB
Go
422 lines
12 KiB
Go
package app
|
|
|
|
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strconv"
	"strings"

	"sub2api-cn-relay-manager/internal/batch"
	"sub2api-cn-relay-manager/internal/pack"
	"sub2api-cn-relay-manager/internal/probe"
	"sub2api-cn-relay-manager/internal/store/sqlite"
)
|
|
|
|
type batchImportReuseInspector struct {
|
|
store *sqlite.DB
|
|
hostRow sqlite.Host
|
|
currentRunID string
|
|
}
|
|
|
|
// Inspect reports whether the provider / API-key combination described by
// input was already imported on this inspector's host, and in what state.
//
// Items from earlier import runs take precedence; legacy import batches are
// consulted only when no prior run item matches. An error is returned when the
// inspector is missing its store or host row, or when a store query fails.
func (i batchImportReuseInspector) Inspect(ctx context.Context, input batch.ReuseLookupInput) (batch.ReuseLookupResult, error) {
	switch {
	case i.store == nil:
		return batch.ReuseLookupResult{}, fmt.Errorf("store is required")
	case i.hostRow.ID <= 0:
		return batch.ReuseLookupResult{}, fmt.Errorf("host row is required")
	}

	priorResult, found, err := i.lookupPriorRunItem(ctx, input)
	if err != nil {
		return batch.ReuseLookupResult{}, err
	}
	if found {
		return priorResult, nil
	}

	return i.lookupLegacyImportBatch(ctx, input)
}
|
|
|
|
func (i batchImportReuseInspector) lookupPriorRunItem(ctx context.Context, input batch.ReuseLookupInput) (batch.ReuseLookupResult, bool, error) {
|
|
runs, err := i.store.ImportRuns().List(ctx, 1000)
|
|
if err != nil {
|
|
return batch.ReuseLookupResult{}, false, err
|
|
}
|
|
|
|
for _, run := range runs {
|
|
if strings.TrimSpace(run.HostID) != strings.TrimSpace(i.hostRow.HostID) {
|
|
continue
|
|
}
|
|
if strings.TrimSpace(run.RunID) == strings.TrimSpace(i.currentRunID) {
|
|
continue
|
|
}
|
|
items, err := i.store.ImportRunItems().ListByRunID(ctx, run.RunID)
|
|
if err != nil {
|
|
return batch.ReuseLookupResult{}, false, err
|
|
}
|
|
for _, item := range items {
|
|
if strings.TrimSpace(item.ProviderID) != strings.TrimSpace(input.ProviderID) {
|
|
continue
|
|
}
|
|
if !apiKeyFingerprintMatches(item.APIKeyFingerprint, input.APIKeyFingerprint) {
|
|
continue
|
|
}
|
|
return i.reuseFromRunItem(ctx, item)
|
|
}
|
|
}
|
|
|
|
return batch.ReuseLookupResult{}, false, nil
|
|
}
|
|
|
|
func (i batchImportReuseInspector) reuseFromRunItem(ctx context.Context, item sqlite.ImportRunItem) (batch.ReuseLookupResult, bool, error) {
|
|
modelMapping, err := i.loadExistingModelMapping(ctx, item.ProviderID)
|
|
if err != nil {
|
|
return batch.ReuseLookupResult{}, false, err
|
|
}
|
|
|
|
reusedAccountID := int64(0)
|
|
if item.AccountID != nil {
|
|
reusedAccountID = *item.AccountID
|
|
} else if item.ReusedFromAccountID != nil {
|
|
reusedAccountID = *item.ReusedFromAccountID
|
|
}
|
|
|
|
state := strings.TrimSpace(item.MatchedAccountState)
|
|
if state == "" {
|
|
state = string(batch.MatchedAccountStateNone)
|
|
}
|
|
|
|
return batch.ReuseLookupResult{
|
|
ProviderMatched: true,
|
|
ExistingProviderID: strings.TrimSpace(item.ProviderID),
|
|
ExistingAccessStatus: normalizeRunItemAccessStatus(item.AccessStatus),
|
|
ExistingCanonicalFamilys: parseStringArrayJSON(item.CanonicalFamiliesJSON),
|
|
MatchedAccountID: reusedAccountID,
|
|
MatchedAccountState: batch.MatchedAccountState(state),
|
|
ExistingModelMapping: modelMapping,
|
|
LegacyBatchID: item.LegacyBatchID,
|
|
}, true, nil
|
|
}
|
|
|
|
// lookupLegacyImportBatch is the fallback lookup over legacy import-batch
// records. It walks the candidate providers from lookupLegacyProviders, each
// provider's batches on this host, and each batch's items, and stops at the
// first item whose key fingerprint matches the input. A zero-value result with
// a nil error means no legacy record matched.
//
// For a match, the model mapping and canonical families come from the matched
// provider row. The matched account ID is resolved from the item's probe
// summary against the batch's managed resources. Any store or decode error
// aborts the lookup and is returned.
func (i batchImportReuseInspector) lookupLegacyImportBatch(ctx context.Context, input batch.ReuseLookupInput) (batch.ReuseLookupResult, error) {
	providers, err := i.lookupLegacyProviders(ctx, input)
	if err != nil {
		return batch.ReuseLookupResult{}, err
	}

	type legacyMatch struct {
		provider  sqlite.Provider
		batch     sqlite.ImportBatch
		item      sqlite.ImportBatchItem
		resources []sqlite.ManagedResource
	}

	var match *legacyMatch
search:
	for _, providerRow := range providers {
		batches, err := i.store.ImportBatches().ListByProviderIDAndHostID(ctx, providerRow.ID, i.hostRow.ID)
		if err != nil {
			return batch.ReuseLookupResult{}, err
		}
		for _, batchRow := range batches {
			items, err := i.store.ImportBatchItems().GetByBatchID(ctx, batchRow.ID)
			if err != nil {
				return batch.ReuseLookupResult{}, err
			}
			for _, item := range items {
				if !apiKeyFingerprintMatches(item.KeyFingerprint, input.APIKeyFingerprint) {
					continue
				}
				// Managed resources are only needed for the winning batch.
				resources, err := i.store.ManagedResources().GetByBatchID(ctx, batchRow.ID)
				if err != nil {
					return batch.ReuseLookupResult{}, err
				}
				match = &legacyMatch{
					provider:  providerRow,
					batch:     batchRow,
					item:      item,
					resources: resources,
				}
				break search
			}
		}
	}

	if match == nil {
		return batch.ReuseLookupResult{}, nil
	}

	modelMapping, err := providerModelMapping(match.provider)
	if err != nil {
		return batch.ReuseLookupResult{}, err
	}
	canonicalFamilies, err := providerCanonicalFamilies(match.provider)
	if err != nil {
		return batch.ReuseLookupResult{}, err
	}
	accountHostID, err := accountIDFromProbeSummary(match.item.ProbeSummaryJSON)
	if err != nil {
		return batch.ReuseLookupResult{}, err
	}

	result := batch.ReuseLookupResult{
		ProviderMatched:          true,
		ExistingProviderID:       strings.TrimSpace(match.provider.ProviderID),
		ExistingAccessStatus:     normalizeLegacyBatchAccessStatus(match.batch.AccessStatus),
		ExistingCanonicalFamilys: canonicalFamilies,
		MatchedAccountID:         resolveManagedAccountNumericID(accountHostID, match.resources),
		MatchedAccountState:      normalizeLegacyMatchedAccountState(match.item.AccountStatus, match.batch.AccessStatus),
		ExistingModelMapping:     modelMapping,
		LegacyBatchID:            int64Ptr(match.batch.ID),
	}
	return result, nil
}
|
|
|
|
// lookupLegacyProviders collects provider rows that may own legacy import
// batches for input. Rows sharing the input's provider ID come first, then
// rows sharing its base URL. Each row (keyed by its database ID) appears at
// most once, in first-seen order. Blank criteria are skipped, and both
// criteria are whitespace-trimmed before they are used as query arguments.
func (i batchImportReuseInspector) lookupLegacyProviders(ctx context.Context, input batch.ReuseLookupInput) ([]sqlite.Provider, error) {
	seen := make(map[int64]struct{})
	providers := make([]sqlite.Provider, 0)
	appendUnique := func(rows []sqlite.Provider) {
		for _, row := range rows {
			if _, dup := seen[row.ID]; dup {
				continue
			}
			seen[row.ID] = struct{}{}
			providers = append(providers, row)
		}
	}

	// Query with the same trimmed value that passed the emptiness check, so
	// surrounding whitespace in the input cannot make the lookup miss rows.
	if providerID := strings.TrimSpace(input.ProviderID); providerID != "" {
		rows, err := i.store.Providers().ListByProviderID(ctx, providerID)
		if err != nil {
			return nil, err
		}
		appendUnique(rows)
	}

	if baseURL := strings.TrimSpace(input.BaseURL); baseURL != "" {
		rows, err := i.store.Providers().ListByBaseURL(ctx, baseURL)
		if err != nil {
			return nil, err
		}
		appendUnique(rows)
	}

	return providers, nil
}
|
|
|
|
// loadExistingModelMapping returns the model mapping of a provider row for
// providerID that has an import batch on this host. Rows from
// ListByProviderID are scanned from last to first. The first row for which
// GetLatestByProviderIDAndHostID yields a batch with a positive ID supplies
// the mapping via providerModelMapping. It returns (nil, nil) when no row
// qualifies, and any other store error as-is.
//
// NOTE(review): scanning backwards presumably prefers the newest row; that
// depends on ListByProviderID's ordering — confirm.
func (i batchImportReuseInspector) loadExistingModelMapping(ctx context.Context, providerID string) (map[string]string, error) {
	// lookupPriorRunItem compares provider IDs trimmed; query the same way.
	providers, err := i.store.Providers().ListByProviderID(ctx, strings.TrimSpace(providerID))
	if err != nil {
		return nil, err
	}

	for idx := len(providers) - 1; idx >= 0; idx-- {
		providerRow := providers[idx]
		batchRow, err := i.store.ImportBatches().GetLatestByProviderIDAndHostID(ctx, providerRow.ID, i.hostRow.ID)
		// errors.Is also recognizes sql.ErrNoRows wrapped by the store layer.
		if errors.Is(err, sql.ErrNoRows) {
			continue
		}
		if err != nil {
			return nil, err
		}
		if batchRow.ID <= 0 {
			continue
		}
		return providerModelMapping(providerRow)
	}

	return nil, nil
}
|
|
|
|
// providerModelMapping extracts the model mapping configured for a provider.
//
// A non-empty ChannelTemplate.ModelMapping decoded from ManifestJSON wins.
// Otherwise the "model_mapping" field of ChannelTemplateJSON is used. Columns
// that are blank or exactly "{}" (after trimming) are skipped. On success the
// result is a fresh copy and never nil. A JSON decode error in either column
// is returned wrapped with the provider ID.
func providerModelMapping(providerRow sqlite.Provider) (map[string]string, error) {
	hasContent := func(raw string) bool {
		trimmed := strings.TrimSpace(raw)
		return trimmed != "" && trimmed != "{}"
	}

	if hasContent(providerRow.ManifestJSON) {
		var manifest pack.ProviderManifest
		if err := json.Unmarshal([]byte(providerRow.ManifestJSON), &manifest); err != nil {
			return nil, fmt.Errorf("decode provider manifest for %q: %w", providerRow.ProviderID, err)
		}
		if mapping := manifest.ChannelTemplate.ModelMapping; len(mapping) > 0 {
			return cloneStringMap(mapping), nil
		}
	}

	if !hasContent(providerRow.ChannelTemplateJSON) {
		return map[string]string{}, nil
	}

	var template struct {
		ModelMapping map[string]string `json:"model_mapping"`
	}
	if err := json.Unmarshal([]byte(providerRow.ChannelTemplateJSON), &template); err != nil {
		return nil, fmt.Errorf("decode provider channel template for %q: %w", providerRow.ProviderID, err)
	}
	return cloneStringMap(template.ModelMapping), nil
}
|
|
|
|
// providerCanonicalFamilies derives the de-duplicated canonical model families
// for a provider, in first-seen order. Candidate model IDs are gathered in
// this order:
//   - DefaultModels from the decoded ManifestJSON;
//   - the target values of its ChannelTemplate.ModelMapping;
//   - the entries of DefaultModelsJSON.
//
// IDs for which probe.CanonicalModelFamily returns "" are dropped. A
// ManifestJSON decode error is returned wrapped with the provider ID.
//
// Mapping values are visited in sorted key order. Ranging over the map
// directly would make the result's order vary between calls.
func providerCanonicalFamilies(providerRow sqlite.Provider) ([]string, error) {
	models := make([]string, 0)

	if trimmed := strings.TrimSpace(providerRow.ManifestJSON); trimmed != "" && trimmed != "{}" {
		var manifest pack.ProviderManifest
		if err := json.Unmarshal([]byte(providerRow.ManifestJSON), &manifest); err != nil {
			return nil, fmt.Errorf("decode provider manifest for %q: %w", providerRow.ProviderID, err)
		}
		models = append(models, manifest.DefaultModels...)

		mapping := manifest.ChannelTemplate.ModelMapping
		keys := make([]string, 0, len(mapping))
		for key := range mapping {
			keys = append(keys, key)
		}
		sort.Strings(keys)
		for _, key := range keys {
			models = append(models, mapping[key])
		}
	}

	models = append(models, parseStringArrayJSON(providerRow.DefaultModelsJSON)...)

	seen := make(map[string]struct{}, len(models))
	families := make([]string, 0, len(models))
	for _, modelID := range models {
		canonical := probe.CanonicalModelFamily(modelID)
		if canonical == "" {
			continue
		}
		if _, dup := seen[canonical]; dup {
			continue
		}
		seen[canonical] = struct{}{}
		families = append(families, canonical)
	}
	return families, nil
}
|
|
|
|
// normalizeRunItemAccessStatus maps an access status stored on an import run
// item back onto batch.AccessStatus. Only the active, degraded and broken
// values are recognized (after trimming whitespace); anything else becomes
// batch.AccessStatusUnknown.
func normalizeRunItemAccessStatus(raw string) batch.AccessStatus {
	value := strings.TrimSpace(raw)
	known := []batch.AccessStatus{
		batch.AccessStatusActive,
		batch.AccessStatusDegraded,
		batch.AccessStatusBroken,
	}
	for _, candidate := range known {
		if value == string(candidate) {
			return candidate
		}
	}
	return batch.AccessStatusUnknown
}
|
|
|
|
// normalizeLegacyBatchAccessStatus maps the access status stored on a legacy
// import batch onto batch.AccessStatus. Surrounding whitespace is ignored.
// The mapping is:
//   - "subscription_ready", "self_service_ready", "fully_ready" -> active;
//   - "degraded" and "broken" map to their namesakes;
//   - anything else -> unknown.
func normalizeLegacyBatchAccessStatus(raw string) batch.AccessStatus {
	status := strings.TrimSpace(raw)
	if status == "degraded" {
		return batch.AccessStatusDegraded
	}
	if status == "broken" {
		return batch.AccessStatusBroken
	}
	if status == "subscription_ready" || status == "self_service_ready" || status == "fully_ready" {
		return batch.AccessStatusActive
	}
	return batch.AccessStatusUnknown
}
|
|
|
|
// normalizeLegacyMatchedAccountState maps a legacy batch item's account status
// onto batch.MatchedAccountState. A batch whose access status normalizes to
// broken forces MatchedAccountStateBroken regardless of the account status.
// Otherwise:
//   - "passed"/"warning" -> active;
//   - "disabled" -> disabled;
//   - "deprecated" -> deprecated;
//   - "failed" -> broken;
//   - anything else -> none.
func normalizeLegacyMatchedAccountState(accountStatus string, batchAccessStatus string) batch.MatchedAccountState {
	batchBroken := normalizeLegacyBatchAccessStatus(batchAccessStatus) == batch.AccessStatusBroken
	status := strings.TrimSpace(accountStatus)

	switch {
	case batchBroken, status == "failed":
		return batch.MatchedAccountStateBroken
	case status == "passed", status == "warning":
		return batch.MatchedAccountStateActive
	case status == "disabled":
		return batch.MatchedAccountStateDisabled
	case status == "deprecated":
		return batch.MatchedAccountStateDeprecated
	default:
		return batch.MatchedAccountStateNone
	}
}
|
|
|
|
// accountIDFromProbeSummary extracts the "account_id" field from a probe
// summary JSON object.
//
// A blank summary, a missing/null field, or a field that is neither a string
// nor a number yields "" with no error. A string value is whitespace-trimmed.
// A numeric value is returned as its exact JSON literal; decoding through
// json.Number avoids float64 rounding of large integers. An error is returned
// only when the summary is not a valid JSON object (or null).
func accountIDFromProbeSummary(summaryJSON string) (string, error) {
	if strings.TrimSpace(summaryJSON) == "" {
		return "", nil
	}

	var payload map[string]json.RawMessage
	if err := json.Unmarshal([]byte(summaryJSON), &payload); err != nil {
		return "", fmt.Errorf("decode probe summary: %w", err)
	}

	raw, ok := payload["account_id"]
	if !ok {
		return "", nil
	}

	// A JSON null leaves asString empty without error, which yields "".
	var asString string
	if err := json.Unmarshal(raw, &asString); err == nil {
		return strings.TrimSpace(asString), nil
	}

	// Numeric IDs were previously dropped by a string-only type assertion.
	var asNumber json.Number
	if err := json.Unmarshal(raw, &asNumber); err == nil {
		return asNumber.String(), nil
	}

	return "", nil
}
|
|
|
|
// resolveManagedAccountNumericID turns the account identifier recorded in a
// legacy probe summary into a numeric account ID, returning 0 when it cannot.
// After trimming, a positive base-10 integer is returned as-is. Otherwise the
// first resource of type "account" whose trimmed HostResourceID equals the
// identifier supplies its ID.
//
// NOTE(review): the two paths return different kinds of ID (the parsed
// identifier vs. the ManagedResource row's ID) — confirm callers expect that.
func resolveManagedAccountNumericID(accountHostID string, resources []sqlite.ManagedResource) int64 {
	wanted := strings.TrimSpace(accountHostID)
	if wanted == "" {
		return 0
	}

	if parsed, err := strconv.ParseInt(wanted, 10, 64); err == nil && parsed > 0 {
		return parsed
	}

	for idx := range resources {
		resource := &resources[idx]
		isAccount := strings.TrimSpace(resource.ResourceType) == "account"
		if isAccount && strings.TrimSpace(resource.HostResourceID) == wanted {
			return resource.ID
		}
	}
	return 0
}
|
|
|
|
// apiKeyFingerprintMatches reports whether two API key fingerprints match
// after normalizeFingerprint is applied to both. A blank fingerprint never
// matches. Besides exact equality, one fingerprint being a prefix of the other
// also counts as a match, presumably to tolerate truncated fingerprints.
//
// NOTE(review): a very short stored value would prefix-match many keys;
// confirm a minimum fingerprint length is enforced upstream.
func apiKeyFingerprintMatches(stored string, lookup string) bool {
	left, right := normalizeFingerprint(stored), normalizeFingerprint(lookup)
	switch {
	case left == "" || right == "":
		return false
	case left == right:
		return true
	}

	shorter, longer := left, right
	if len(shorter) > len(longer) {
		shorter, longer = longer, shorter
	}
	return strings.HasPrefix(longer, shorter)
}
|
|
|
|
// normalizeFingerprint trims surrounding whitespace and then drops a leading
// (case-sensitive) "sha256:" tag. Blank input yields "".
func normalizeFingerprint(raw string) string {
	return strings.TrimPrefix(strings.TrimSpace(raw), "sha256:")
}
|
|
|
|
func parseStringArrayJSON(raw string) []string {
|
|
values := []string{}
|
|
if strings.TrimSpace(raw) == "" {
|
|
return values
|
|
}
|
|
if err := json.Unmarshal([]byte(raw), &values); err != nil {
|
|
return []string{}
|
|
}
|
|
return values
|
|
}
|
|
|
|
// cloneStringMap returns a shallow copy of input that callers may mutate
// freely. A nil or empty input yields a new empty (non-nil) map.
func cloneStringMap(input map[string]string) map[string]string {
	out := make(map[string]string, len(input))
	for k, v := range input {
		out[k] = v
	}
	return out
}
|
|
|
|
// int64Ptr returns a pointer to a private copy of value, or nil when value is
// not positive (so zero/negative IDs read as "absent").
func int64Ptr(value int64) *int64 {
	if value > 0 {
		// value is this call's own copy of the argument, so its address is
		// never shared with the caller or with other calls.
		return &value
	}
	return nil
}
|