Files
sub2api-cn-relay-manager/internal/app/batch_runtime_reuse.go
2026-05-27 20:23:42 +08:00

422 lines
12 KiB
Go

package app
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"sort"
	"strconv"
	"strings"

	"sub2api-cn-relay-manager/internal/batch"
	"sub2api-cn-relay-manager/internal/pack"
	"sub2api-cn-relay-manager/internal/probe"
	"sub2api-cn-relay-manager/internal/store/sqlite"
)
// batchImportReuseInspector answers batch.ReuseLookupInput queries for a single
// host by searching previously recorded import runs first and legacy import
// batches second (see Inspect).
type batchImportReuseInspector struct {
	// store is the database handle every lookup reads from; Inspect rejects a nil store.
	store *sqlite.DB
	// hostRow identifies the host being imported into. Its ID must be > 0
	// (checked by Inspect) and is used for import-batch lookups; its HostID is
	// compared against each import run's HostID.
	hostRow sqlite.Host
	// currentRunID is the run in progress; its items are skipped when
	// scanning prior runs so a run never "reuses" itself.
	currentRunID string
}
// Inspect looks for existing provider/account state matching input on the
// inspector's host. Prior import-run items take precedence; when none match,
// legacy import batches are consulted. An empty result with a nil error means
// nothing reusable was found.
func (i batchImportReuseInspector) Inspect(ctx context.Context, input batch.ReuseLookupInput) (batch.ReuseLookupResult, error) {
	switch {
	case i.store == nil:
		return batch.ReuseLookupResult{}, fmt.Errorf("store is required")
	case i.hostRow.ID <= 0:
		return batch.ReuseLookupResult{}, fmt.Errorf("host row is required")
	}

	priorReuse, found, err := i.lookupPriorRunItem(ctx, input)
	if err != nil {
		return batch.ReuseLookupResult{}, err
	}
	if found {
		return priorReuse, nil
	}
	return i.lookupLegacyImportBatch(ctx, input)
}
// lookupPriorRunItem scans up to 1000 import runs (as returned by
// ImportRuns().List) belonging to the inspector's host, excluding the current
// run, and returns reuse data for the first run item whose provider ID and API
// key fingerprint match input. The bool reports whether a match was found.
func (i batchImportReuseInspector) lookupPriorRunItem(ctx context.Context, input batch.ReuseLookupInput) (batch.ReuseLookupResult, bool, error) {
	runs, err := i.store.ImportRuns().List(ctx, 1000)
	if err != nil {
		return batch.ReuseLookupResult{}, false, err
	}

	// These comparisons are loop-invariant; normalize them once.
	hostID := strings.TrimSpace(i.hostRow.HostID)
	currentRun := strings.TrimSpace(i.currentRunID)
	wantProvider := strings.TrimSpace(input.ProviderID)

	for _, run := range runs {
		if strings.TrimSpace(run.HostID) != hostID || strings.TrimSpace(run.RunID) == currentRun {
			continue
		}
		items, err := i.store.ImportRunItems().ListByRunID(ctx, run.RunID)
		if err != nil {
			return batch.ReuseLookupResult{}, false, err
		}
		for _, item := range items {
			if strings.TrimSpace(item.ProviderID) == wantProvider &&
				apiKeyFingerprintMatches(item.APIKeyFingerprint, input.APIKeyFingerprint) {
				return i.reuseFromRunItem(ctx, item)
			}
		}
	}
	return batch.ReuseLookupResult{}, false, nil
}
// reuseFromRunItem converts a matching prior import-run item into a reuse
// result. The matched account ID prefers item.AccountID and falls back to
// item.ReusedFromAccountID (0 when both are nil); a blank matched-account state
// becomes batch.MatchedAccountStateNone. The bool result is always true on success.
func (i batchImportReuseInspector) reuseFromRunItem(ctx context.Context, item sqlite.ImportRunItem) (batch.ReuseLookupResult, bool, error) {
	mapping, err := i.loadExistingModelMapping(ctx, item.ProviderID)
	if err != nil {
		return batch.ReuseLookupResult{}, false, err
	}

	var accountID int64
	switch {
	case item.AccountID != nil:
		accountID = *item.AccountID
	case item.ReusedFromAccountID != nil:
		accountID = *item.ReusedFromAccountID
	}

	stateText := strings.TrimSpace(item.MatchedAccountState)
	if len(stateText) == 0 {
		stateText = string(batch.MatchedAccountStateNone)
	}

	result := batch.ReuseLookupResult{
		ProviderMatched:          true,
		ExistingProviderID:       strings.TrimSpace(item.ProviderID),
		ExistingAccessStatus:     normalizeRunItemAccessStatus(item.AccessStatus),
		ExistingCanonicalFamilys: parseStringArrayJSON(item.CanonicalFamiliesJSON),
		MatchedAccountID:         accountID,
		MatchedAccountState:      batch.MatchedAccountState(stateText),
		ExistingModelMapping:     mapping,
		LegacyBatchID:            item.LegacyBatchID,
	}
	return result, true, nil
}
// lookupLegacyImportBatch searches legacy import batches on the inspector's
// host for the first batch item whose key fingerprint matches input, walking
// candidate providers, then their batches, then batch items in listing order.
// It returns an empty result and nil error when nothing matches.
func (i batchImportReuseInspector) lookupLegacyImportBatch(ctx context.Context, input batch.ReuseLookupInput) (batch.ReuseLookupResult, error) {
	providers, err := i.lookupLegacyProviders(ctx, input)
	if err != nil {
		return batch.ReuseLookupResult{}, err
	}

	var (
		found            bool
		matchedProvider  sqlite.Provider
		matchedBatch     sqlite.ImportBatch
		matchedItem      sqlite.ImportBatchItem
		matchedResources []sqlite.ManagedResource
	)
search:
	for _, providerRow := range providers {
		batches, err := i.store.ImportBatches().ListByProviderIDAndHostID(ctx, providerRow.ID, i.hostRow.ID)
		if err != nil {
			return batch.ReuseLookupResult{}, err
		}
		for _, batchRow := range batches {
			items, err := i.store.ImportBatchItems().GetByBatchID(ctx, batchRow.ID)
			if err != nil {
				return batch.ReuseLookupResult{}, err
			}
			for _, item := range items {
				if !apiKeyFingerprintMatches(item.KeyFingerprint, input.APIKeyFingerprint) {
					continue
				}
				// Managed resources are only loaded for the batch that matched.
				resources, err := i.store.ManagedResources().GetByBatchID(ctx, batchRow.ID)
				if err != nil {
					return batch.ReuseLookupResult{}, err
				}
				matchedProvider, matchedBatch, matchedItem, matchedResources = providerRow, batchRow, item, resources
				found = true
				break search
			}
		}
	}
	if !found {
		return batch.ReuseLookupResult{}, nil
	}

	modelMapping, err := providerModelMapping(matchedProvider)
	if err != nil {
		return batch.ReuseLookupResult{}, err
	}
	canonicalFamilies, err := providerCanonicalFamilies(matchedProvider)
	if err != nil {
		return batch.ReuseLookupResult{}, err
	}
	accountHostID, err := accountIDFromProbeSummary(matchedItem.ProbeSummaryJSON)
	if err != nil {
		return batch.ReuseLookupResult{}, err
	}

	return batch.ReuseLookupResult{
		ProviderMatched:          true,
		ExistingProviderID:       strings.TrimSpace(matchedProvider.ProviderID),
		ExistingAccessStatus:     normalizeLegacyBatchAccessStatus(matchedBatch.AccessStatus),
		ExistingCanonicalFamilys: canonicalFamilies,
		MatchedAccountID:         resolveManagedAccountNumericID(accountHostID, matchedResources),
		MatchedAccountState:      normalizeLegacyMatchedAccountState(matchedItem.AccountStatus, matchedBatch.AccessStatus),
		ExistingModelMapping:     modelMapping,
		LegacyBatchID:            int64Ptr(matchedBatch.ID),
	}, nil
}
// lookupLegacyProviders collects provider rows matching input by provider ID
// and/or base URL, de-duplicated by row ID and kept in first-seen order
// (provider-ID matches first). Blank criteria are skipped. Both criteria are
// trimmed before querying so the query key agrees with the emptiness check.
func (i batchImportReuseInspector) lookupLegacyProviders(ctx context.Context, input batch.ReuseLookupInput) ([]sqlite.Provider, error) {
	seen := make(map[int64]struct{})
	providers := make([]sqlite.Provider, 0)
	appendUnique := func(rows []sqlite.Provider) {
		for _, row := range rows {
			if _, ok := seen[row.ID]; ok {
				continue
			}
			seen[row.ID] = struct{}{}
			providers = append(providers, row)
		}
	}

	// Previously the untrimmed ID was passed to the query even though the
	// emptiness check trimmed it; surrounding whitespace would miss rows.
	if providerID := strings.TrimSpace(input.ProviderID); providerID != "" {
		rows, err := i.store.Providers().ListByProviderID(ctx, providerID)
		if err != nil {
			return nil, err
		}
		appendUnique(rows)
	}
	if baseURL := strings.TrimSpace(input.BaseURL); baseURL != "" {
		rows, err := i.store.Providers().ListByBaseURL(ctx, baseURL)
		if err != nil {
			return nil, err
		}
		appendUnique(rows)
	}
	return providers, nil
}
// loadExistingModelMapping returns the model mapping of the last provider row
// (in ListByProviderID order, walked backwards) with the given provider ID that
// has an import batch on the inspector's host. It returns nil, nil when no such
// provider row exists.
func (i batchImportReuseInspector) loadExistingModelMapping(ctx context.Context, providerID string) (map[string]string, error) {
	// Callers compare provider IDs trimmed; query with the same normalized key.
	providers, err := i.store.Providers().ListByProviderID(ctx, strings.TrimSpace(providerID))
	if err != nil {
		return nil, err
	}
	// NOTE(review): walking backwards presumably prefers the newest row;
	// confirm ListByProviderID's ordering.
	for idx := len(providers) - 1; idx >= 0; idx-- {
		providerRow := providers[idx]
		batchRow, err := i.store.ImportBatches().GetLatestByProviderIDAndHostID(ctx, providerRow.ID, i.hostRow.ID)
		// errors.Is also matches a wrapped sql.ErrNoRows, which == would miss.
		if errors.Is(err, sql.ErrNoRows) {
			continue
		}
		if err != nil {
			return nil, err
		}
		if batchRow.ID <= 0 {
			continue
		}
		return providerModelMapping(providerRow)
	}
	return nil, nil
}
// providerModelMapping returns a copy of the provider's model mapping. A
// non-empty mapping from the provider manifest wins; otherwise the mapping from
// the channel template JSON is used. Blank or "{}" payloads are ignored, and an
// empty (non-nil) map is returned when neither source supplies one. Malformed
// JSON in either payload yields an error naming the provider.
func providerModelMapping(providerRow sqlite.Provider) (map[string]string, error) {
	hasPayload := func(raw string) bool {
		trimmed := strings.TrimSpace(raw)
		return trimmed != "" && trimmed != "{}"
	}

	if hasPayload(providerRow.ManifestJSON) {
		var manifest pack.ProviderManifest
		if err := json.Unmarshal([]byte(providerRow.ManifestJSON), &manifest); err != nil {
			return nil, fmt.Errorf("decode provider manifest for %q: %w", providerRow.ProviderID, err)
		}
		if mapping := manifest.ChannelTemplate.ModelMapping; len(mapping) > 0 {
			return cloneStringMap(mapping), nil
		}
	}

	if !hasPayload(providerRow.ChannelTemplateJSON) {
		return map[string]string{}, nil
	}
	var template struct {
		ModelMapping map[string]string `json:"model_mapping"`
	}
	if err := json.Unmarshal([]byte(providerRow.ChannelTemplateJSON), &template); err != nil {
		return nil, fmt.Errorf("decode provider channel template for %q: %w", providerRow.ProviderID, err)
	}
	return cloneStringMap(template.ModelMapping), nil
}
// providerCanonicalFamilies derives the distinct canonical model families
// (via probe.CanonicalModelFamily) advertised by a provider row. It draws model
// IDs from the manifest's default models, the manifest's model-mapping targets,
// and the row's DefaultModelsJSON, in that order. Models with no canonical
// family are skipped. The result is never nil and its order is deterministic.
func providerCanonicalFamilies(providerRow sqlite.Provider) ([]string, error) {
	var models []string
	if manifestJSON := strings.TrimSpace(providerRow.ManifestJSON); manifestJSON != "" && manifestJSON != "{}" {
		var manifest pack.ProviderManifest
		if err := json.Unmarshal([]byte(providerRow.ManifestJSON), &manifest); err != nil {
			return nil, fmt.Errorf("decode provider manifest for %q: %w", providerRow.ProviderID, err)
		}
		models = append(models, manifest.DefaultModels...)

		// Go map iteration order is randomized; walk the mapping by sorted
		// source key so the family order is stable across calls.
		mapping := manifest.ChannelTemplate.ModelMapping
		sources := make([]string, 0, len(mapping))
		for source := range mapping {
			sources = append(sources, source)
		}
		sort.Strings(sources)
		for _, source := range sources {
			models = append(models, mapping[source])
		}
	}
	models = append(models, parseStringArrayJSON(providerRow.DefaultModelsJSON)...)

	seen := make(map[string]struct{}, len(models))
	families := make([]string, 0, len(models))
	for _, modelID := range models {
		canonical := probe.CanonicalModelFamily(modelID)
		if canonical == "" {
			continue
		}
		if _, ok := seen[canonical]; ok {
			continue
		}
		seen[canonical] = struct{}{}
		families = append(families, canonical)
	}
	return families, nil
}
// normalizeRunItemAccessStatus maps a stored run-item access status string to
// the matching batch.AccessStatus (active, degraded, or broken) after trimming
// whitespace; anything else is batch.AccessStatusUnknown.
func normalizeRunItemAccessStatus(raw string) batch.AccessStatus {
	trimmed := strings.TrimSpace(raw)
	known := []batch.AccessStatus{
		batch.AccessStatusActive,
		batch.AccessStatusDegraded,
		batch.AccessStatusBroken,
	}
	for _, status := range known {
		if string(status) == trimmed {
			return status
		}
	}
	return batch.AccessStatusUnknown
}
// normalizeLegacyBatchAccessStatus translates a legacy import batch access
// status into batch.AccessStatus. The three "*_ready" states count as active;
// "degraded" and "broken" map directly; anything else is unknown.
func normalizeLegacyBatchAccessStatus(raw string) batch.AccessStatus {
	status := strings.TrimSpace(raw)
	switch {
	case status == "subscription_ready" || status == "self_service_ready" || status == "fully_ready":
		return batch.AccessStatusActive
	case status == "degraded":
		return batch.AccessStatusDegraded
	case status == "broken":
		return batch.AccessStatusBroken
	}
	return batch.AccessStatusUnknown
}
// normalizeLegacyMatchedAccountState derives the matched-account state for a
// legacy batch item. A broken batch makes the account broken regardless of its
// own status. Otherwise the item's account status decides: "passed"/"warning"
// are active, "disabled"/"deprecated" map directly, "failed" is broken, and
// anything else is none.
func normalizeLegacyMatchedAccountState(accountStatus string, batchAccessStatus string) batch.MatchedAccountState {
	batchBroken := normalizeLegacyBatchAccessStatus(batchAccessStatus) == batch.AccessStatusBroken
	status := strings.TrimSpace(accountStatus)
	switch {
	case batchBroken, status == "failed":
		return batch.MatchedAccountStateBroken
	case status == "passed" || status == "warning":
		return batch.MatchedAccountStateActive
	case status == "disabled":
		return batch.MatchedAccountStateDisabled
	case status == "deprecated":
		return batch.MatchedAccountStateDeprecated
	}
	return batch.MatchedAccountStateNone
}
// accountIDFromProbeSummary extracts the "account_id" field from a probe
// summary JSON object.
//
// A string value is returned trimmed. A numeric value is returned in its JSON
// textual form; previously a numeric ID was silently dropped. A missing, null,
// or non-string/non-number value yields "". A blank summary yields "" with no
// error. A summary that is not a JSON object (or null) returns a wrapped decode
// error.
func accountIDFromProbeSummary(summaryJSON string) (string, error) {
	if strings.TrimSpace(summaryJSON) == "" {
		return "", nil
	}
	var payload map[string]json.RawMessage
	if err := json.Unmarshal([]byte(summaryJSON), &payload); err != nil {
		return "", fmt.Errorf("decode probe summary: %w", err)
	}
	raw, ok := payload["account_id"]
	if !ok {
		return "", nil
	}
	var text string
	if err := json.Unmarshal(raw, &text); err == nil {
		return strings.TrimSpace(text), nil
	}
	var number json.Number
	if err := json.Unmarshal(raw, &number); err == nil {
		return number.String(), nil
	}
	return "", nil
}
// resolveManagedAccountNumericID turns a host-side account identifier into a
// numeric ID. A positive base-10 integer is returned as-is. Otherwise it
// returns the ID of the first managed resource of type "account" whose
// HostResourceID equals the (trimmed) identifier. It returns 0 when the
// identifier is blank or nothing matches.
func resolveManagedAccountNumericID(accountHostID string, resources []sqlite.ManagedResource) int64 {
	wanted := strings.TrimSpace(accountHostID)
	if wanted == "" {
		return 0
	}
	numericID, parseErr := strconv.ParseInt(wanted, 10, 64)
	if parseErr == nil && numericID > 0 {
		return numericID
	}
	for idx := range resources {
		res := &resources[idx]
		if strings.TrimSpace(res.ResourceType) == "account" && strings.TrimSpace(res.HostResourceID) == wanted {
			return res.ID
		}
	}
	return 0
}
// apiKeyFingerprintMatches reports whether two API key fingerprints refer to
// the same key. Both sides are normalized first; a blank side never matches.
// The fingerprints match when they are equal or when one is a prefix of the
// other (presumably to tolerate truncated fingerprints — confirm with callers).
func apiKeyFingerprintMatches(stored string, lookup string) bool {
	a, b := normalizeFingerprint(stored), normalizeFingerprint(lookup)
	if len(a) == 0 || len(b) == 0 {
		return false
	}
	shorter, longer := a, b
	if len(shorter) > len(longer) {
		shorter, longer = longer, shorter
	}
	return strings.HasPrefix(longer, shorter)
}

// normalizeFingerprint trims surrounding whitespace and then strips one
// leading "sha256:" tag, so tagged and bare fingerprints compare equal.
func normalizeFingerprint(raw string) string {
	return strings.TrimPrefix(strings.TrimSpace(raw), "sha256:")
}
// parseStringArrayJSON decodes a JSON array of strings. Blank input, malformed
// JSON, and a JSON null all yield an empty, non-nil slice, so callers (and JSON
// re-encoding, which renders nil as null) always see an array.
func parseStringArrayJSON(raw string) []string {
	if strings.TrimSpace(raw) == "" {
		return []string{}
	}
	var values []string
	// json.Unmarshal sets the slice to nil for a JSON null, e.g. a column written
	// from a marshalled nil slice; normalize that back to an empty slice.
	if err := json.Unmarshal([]byte(raw), &values); err != nil || values == nil {
		return []string{}
	}
	return values
}
// cloneStringMap returns a shallow copy of input. The result is always a
// non-nil, writable map, even for a nil or empty input.
func cloneStringMap(input map[string]string) map[string]string {
	out := make(map[string]string, len(input))
	for k, v := range input {
		out[k] = v
	}
	return out
}
// int64Ptr returns a pointer to a fresh copy of value when it is positive, and
// nil for zero or negative values (treated as "no ID").
func int64Ptr(value int64) *int64 {
	if value > 0 {
		return &value
	}
	return nil
}