Files
sub2api-cn-relay-manager/internal/app/batch_runtime.go
2026-05-22 16:12:52 +08:00

429 lines
13 KiB
Go

package app
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"time"
"sub2api-cn-relay-manager/internal/batch"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
"sub2api-cn-relay-manager/internal/probe"
"sub2api-cn-relay-manager/internal/provision"
"sub2api-cn-relay-manager/internal/store/sqlite"
)
// Identity of the synthetic pack generated for every batch import entry.
const (
	generatedBatchImportPackID      = "batch-auto-import-v2-generated"
	generatedBatchImportPackVersion = "2026.05.22"
)

// batchImportRetryDelay is used both as the confirmation worker's retry delay
// and as the pause between advanceRun polling passes.
const batchImportRetryDelay = 200 * time.Millisecond
// batchImportRuntimeRunner carries everything needed to run one batch import
// request against a single host: execute starts the run and advanceRun then
// drives confirmation and validation for it.
type batchImportRuntimeRunner struct {
	store      *sqlite.DB                  // import runs, items, events and managed-resource lookups
	hostRow    sqlite.Host                 // target host; its HostID is recorded on the run
	hostClient *sub2api.Client             // host API client used for account tests, gateway checks and subscriptions
	request    CreateBatchImportRunRequest // the originating create-run request
}
// execute creates a new batch import run from r.request and advances it
// synchronously via advanceRun. It then reports the run's counters as
// persisted afterwards.
//
// The run ID is "run_" followed by the current Unix time in nanoseconds.
// Errors from starting, advancing or re-reading the run are returned
// unchanged, together with a zero response.
func (r batchImportRuntimeRunner) execute(ctx context.Context) (BatchImportRunCreateResponse, error) {
	runID := fmt.Sprintf("run_%d", time.Now().UnixNano())

	provisioner := batchImportProvisioner{
		store:      r.store,
		hostRow:    r.hostRow,
		hostClient: r.hostClient,
		request:    r.request,
	}
	service := batch.BatchImportService{
		RunStore:          r.store.ImportRuns(),
		ItemStore:         r.store.ImportRunItems(),
		ProbeModels:       probe.ProviderModels,
		ProbeCapabilities: probe.ProbeCapabilities,
		Provisioner:       provisioner,
	}

	// Copy each entry, including its model list, so the run does not alias
	// the request's slices.
	runEntries := make([]batch.BatchImportEntry, 0, len(r.request.Entries))
	for _, src := range r.request.Entries {
		runEntries = append(runEntries, batch.BatchImportEntry{
			BaseURL:         src.BaseURL,
			APIKey:          src.APIKey,
			RequestedModels: append([]string(nil), src.RequestedModels...),
		})
	}

	runRequest := batch.BatchImportRunRequest{
		RunID:      runID,
		Mode:       r.request.Mode,
		AccessMode: r.request.AccessMode,
		HostID:     r.hostRow.HostID,
		Entries:    runEntries,
	}
	if _, err := service.StartRun(ctx, runRequest); err != nil {
		return BatchImportRunCreateResponse{}, err
	}
	if err := r.advanceRun(ctx, runID); err != nil {
		return BatchImportRunCreateResponse{}, err
	}

	stored, err := r.store.ImportRuns().GetByRunID(ctx, runID)
	if err != nil {
		return BatchImportRunCreateResponse{}, err
	}
	return BatchImportRunCreateResponse{
		RunID:         stored.RunID,
		State:         stored.State,
		ResultPage:    "/batch-import/runs/" + stored.RunID,
		TotalItems:    stored.TotalItems,
		ActiveItems:   stored.ActiveItems,
		DegradedItems: stored.DegradedItems,
		BrokenItems:   stored.BrokenItems,
		WarningItems:  stored.WarningItems,
	}, nil
}
// advanceRun drives the run identified by runID forward in polling passes.
//
// Each pass does three things in order:
//   - runs one confirmation worker tick,
//   - validates every item currently in the validate stage,
//   - re-reads the run.
//
// The loop returns nil as soon as any of these holds:
//   - the run reports all of its items completed,
//   - no item is left waiting on confirmation,
//   - the wait deadline has passed.
//
// Otherwise it sleeps for batchImportRetryDelay and starts another pass. The
// deadline is request.ConfirmWaitTimeoutSec seconds from the start, or one
// second when that value is not positive.
//
// A nil return does not mean the run finished; callers must re-read the run
// to learn its state. Worker, store and validation errors, and context
// cancellation during the sleep, are returned unchanged.
func (r batchImportRuntimeRunner) advanceRun(ctx context.Context, runID string) error {
	timeout := time.Duration(r.request.ConfirmWaitTimeoutSec) * time.Second
	if timeout <= 0 {
		timeout = time.Second
	}
	deadline := time.Now().Add(timeout)
	// The worker lists items through a store scoped to this run only.
	worker := batch.ConfirmationWorker{
		WorkerID:      "batch-import-api",
		ItemStore:     batchImportRunItemStore{store: r.store, runID: runID},
		EventStore:    r.store.ImportRunEvents(),
		LeaseDuration: time.Minute,
		RetryDelay:    batchImportRetryDelay,
		Confirmer:     r.confirmItem,
	}
	validator := batch.ValidationService{
		ItemStore: r.store.ImportRunItems(),
		RunStore:  r.store.ImportRuns(),
		Validator: r.validateItem,
	}
	for {
		now := time.Now()
		if err := worker.Tick(ctx, now); err != nil {
			return err
		}
		// Re-list after the tick so any updates the worker persisted are visible.
		items, err := r.store.ImportRunItems().ListByRunID(ctx, runID)
		if err != nil {
			return err
		}
		pendingWork := false
		for _, item := range items {
			switch item.CurrentStage {
			case string(batch.ItemStageValidate):
				// Validation runs synchronously, inline in this pass.
				if err := validator.ValidateItem(ctx, item); err != nil {
					return err
				}
			case string(batch.ItemStageConfirm):
				// Only items still awaiting confirmation keep the loop polling.
				if item.ConfirmationStatus == string(batch.ConfirmationPending) {
					pendingWork = true
				}
			}
		}
		run, err := r.store.ImportRuns().GetByRunID(ctx, runID)
		if err != nil {
			return err
		}
		// Every item has completed. A run reporting zero TotalItems never
		// matches here; it can only stop via the check below.
		if run.TotalItems > 0 && run.CompletedItems >= run.TotalItems {
			return nil
		}
		// Nothing left to confirm, or out of time: stop waiting.
		if !pendingWork || !time.Now().Before(deadline) {
			return nil
		}
		if err := sleepWithContext(ctx, batchImportRetryDelay); err != nil {
			return err
		}
	}
}
// confirmItem tests the host account provisioned for item, using the item's
// resolved smoke model, and translates the outcome into a
// batch.ConfirmationResult.
//
// A *sub2api.HTTPError from the host is reported as a result carrying its
// status code and body, not as a Go error. A successful probe maps to 200.
// A failed probe is classified from its trimmed message, matched
// case-insensitively:
//   - "no available accounts" maps to 503,
//   - "forbidden" maps to 403,
//   - anything else maps to 400.
//
// All other errors, including account lookup failures, are returned
// unchanged.
func (r batchImportRuntimeRunner) confirmItem(ctx context.Context, item sqlite.ImportRunItem) (batch.ConfirmationResult, error) {
	accountID, err := resolveManagedResourceHostID(ctx, r.store, item, "account")
	if err != nil {
		return batch.ConfirmationResult{}, err
	}

	outcome, err := r.hostClient.TestAccount(ctx, accountID, item.ResolvedSmokeModel)
	if err != nil {
		var httpErr *sub2api.HTTPError
		if !errors.As(err, &httpErr) {
			return batch.ConfirmationResult{}, err
		}
		return batch.ConfirmationResult{StatusCode: httpErr.StatusCode, Message: httpErr.Body}, nil
	}
	if outcome.OK {
		return batch.ConfirmationResult{StatusCode: http.StatusOK, Message: outcome.Message}, nil
	}

	message := strings.TrimSpace(outcome.Message)
	verdict := batch.ConfirmationResult{StatusCode: http.StatusBadRequest, Message: message}
	// Cases are evaluated in order; the first matching phrase wins.
	switch lowered := strings.ToLower(message); {
	case strings.Contains(lowered, "no available accounts"):
		verdict.StatusCode = http.StatusServiceUnavailable
	case strings.Contains(lowered, "forbidden"):
		verdict.StatusCode = http.StatusForbidden
	}
	return verdict, nil
}
// validateItem sends a minimal gateway completion for the item's resolved
// smoke model: prompt "ping", MaxTokens 8. It authenticates with the API key
// that resolveValidationAPIKey selects for the run's access mode. Key
// resolution errors are returned with a zero result.
func (r batchImportRuntimeRunner) validateItem(ctx context.Context, item sqlite.ImportRunItem) (sub2api.GatewayCompletionResult, error) {
	key, err := r.resolveValidationAPIKey(ctx, item)
	if err != nil {
		var none sub2api.GatewayCompletionResult
		return none, err
	}
	check := sub2api.GatewayCompletionCheckRequest{
		APIKey:    key,
		Model:     item.ResolvedSmokeModel,
		Prompt:    "ping",
		MaxTokens: 8,
	}
	return r.hostClient.CheckGatewayCompletion(ctx, check)
}
// resolveValidationAPIKey chooses the gateway API key used to validate item,
// based on the request's (trimmed) access mode.
//
// Self-service: returns the request's probe API key, trimmed. It may be
// empty.
//
// Subscription: requires at least one subscription user. It resolves the
// item's managed "group" resource and calls EnsureSubscriptionAccess for the
// first listed user on that group. It then calls AssignSubscription for the
// returned user ID, falling back to the listed user when that ID is blank,
// with request.SubscriptionDays. Finally it returns the access API key,
// trimmed; a blank key is an error.
//
// Any other access mode yields an "unsupported access mode" error.
func (r batchImportRuntimeRunner) resolveValidationAPIKey(ctx context.Context, item sqlite.ImportRunItem) (string, error) {
	mode := strings.TrimSpace(r.request.AccessMode)
	if mode == provision.AccessModeSelfService {
		return strings.TrimSpace(r.request.ProbeAPIKey), nil
	}
	if mode != provision.AccessModeSubscription {
		return "", fmt.Errorf("unsupported access mode %q", r.request.AccessMode)
	}

	if len(r.request.SubscriptionUsers) == 0 {
		return "", fmt.Errorf("subscription_users is required")
	}
	selector := r.request.SubscriptionUsers[0]

	groupID, err := resolveManagedResourceHostID(ctx, r.store, item, "group")
	if err != nil {
		return "", err
	}
	access, err := r.hostClient.EnsureSubscriptionAccess(ctx, sub2api.EnsureSubscriptionAccessRequest{
		UserSelector: selector,
		GroupID:      groupID,
	})
	if err != nil {
		return "", err
	}

	userID := strings.TrimSpace(access.UserID)
	if userID == "" {
		userID = selector
	}
	assignment := sub2api.AssignSubscriptionRequest{
		UserID:       userID,
		GroupID:      groupID,
		DurationDays: r.request.SubscriptionDays,
	}
	if _, err := r.hostClient.AssignSubscription(ctx, assignment); err != nil {
		return "", err
	}

	key := strings.TrimSpace(access.APIKey)
	if key == "" {
		return "", fmt.Errorf("subscription access api key is empty")
	}
	return key, nil
}
// batchImportProvisioner is installed as the Provisioner of the
// batch.BatchImportService built in execute. Provision imports each entry into
// the host through provision's runtime import service, using a synthesised
// single-provider pack.
type batchImportProvisioner struct {
	store      *sqlite.DB                  // passed to provision.NewRuntimeImportService
	hostRow    sqlite.Host                 // import target; supplies HostID and BaseURL
	hostClient *sub2api.Client             // host API client used by the import service
	request    CreateBatchImportRunRequest // originating request (mode, access and subscription settings)
}
// Provision imports one batch entry into the host. It synthesises a
// single-provider pack for the entry and runs it through the runtime import
// service, with the entry's trimmed API key as the only key.
//
// The import mode defaults to provision.ImportModeStrict when the request
// does not set one. On success it returns the import's batch ID as the legacy
// batch ID, alongside the request's provider ID. Import errors are returned
// unchanged.
func (p batchImportProvisioner) Provision(ctx context.Context, req batch.ProvisionRequest) (batch.ProvisionResult, error) {
	importer := provision.NewRuntimeImportService(p.store, p.hostClient)
	manifest := generatedBatchImportProviderManifest(req, p.request)

	importReq := provision.RuntimeImportRequest{
		HostID:      p.hostRow.HostID,
		HostBaseURL: p.hostRow.BaseURL,
		Pack:        generatedBatchImportPack(manifest),
		Provider:    manifest,
		// firstNonEmptyString trims its arguments, so a blank mode falls back.
		Mode:   firstNonEmptyString(p.request.Mode, provision.ImportModeStrict),
		Keys:   []string{strings.TrimSpace(req.Entry.APIKey)},
		Access: batchImportAccessRequest(p.request),
	}
	imported, err := importer.Import(ctx, importReq)
	if err != nil {
		return batch.ProvisionResult{}, err
	}

	// Copy into a local so the result holds a pointer to its own value.
	batchID := imported.BatchID
	return batch.ProvisionResult{
		LegacyBatchID:    &batchID,
		LegacyProviderID: req.ProviderID,
	}, nil
}
// Patch is the provisioner's patch hook (presumably required by the batch
// provisioner interface). It ignores its arguments and always returns nil.
// NOTE(review): it performs no host-side update; confirm that
// batch.BatchImportService never relies on Patch changing provisioned
// resources.
func (p batchImportProvisioner) Patch(_ context.Context, _ batch.PatchProvisionRequest) error {
	return nil
}
// batchImportRunItemStore adapts the import-run item table for the
// confirmation worker in advanceRun. List is scoped to runID; Upsert writes
// through to the shared item store unchanged.
type batchImportRunItemStore struct {
	store *sqlite.DB
	runID string // run whose items List returns
}
// List returns the items that belong to the store's run.
func (s batchImportRunItemStore) List(ctx context.Context) ([]sqlite.ImportRunItem, error) {
	itemTable := s.store.ImportRunItems()
	return itemTable.ListByRunID(ctx, s.runID)
}
// Upsert inserts or updates item in the shared import-run item store. It does
// not check that the item belongs to the store's run.
func (s batchImportRunItemStore) Upsert(ctx context.Context, item sqlite.ImportRunItem) error {
	itemTable := s.store.ImportRunItems()
	return itemTable.Upsert(ctx, item)
}
// generatedBatchImportPack wraps providerManifest in a synthetic, already
// loaded pack. The pack's ID and version come from the generated-pack
// constants, and its checksum is "<id>@<version>".
func generatedBatchImportPack(providerManifest pack.ProviderManifest) pack.LoadedPack {
	header := pack.Manifest{
		PackID:     generatedBatchImportPackID,
		Version:    generatedBatchImportPackVersion,
		Vendor:     "sub2api-cn-relay-manager",
		TargetHost: "sub2api",
	}
	checksum := generatedBatchImportPackID + "@" + generatedBatchImportPackVersion
	return pack.LoadedPack{
		Manifest:  header,
		Providers: []pack.ProviderManifest{providerManifest},
		Checksum:  checksum,
	}
}
// generatedBatchImportProviderManifest builds a single-provider manifest for
// one batch entry so it can be imported through the runtime import service.
//
// Models:
//   - Default models are the capability profile's raw model IDs, trimmed and
//     de-duplicated. If there are none, the resolved model is used instead.
//   - The smoke-test model is the resolved model. Failing that, it is the
//     first default model, and failing that, the literal "ping".
//   - When there are still no default models, the smoke model becomes the
//     only one.
//   - Every default model is mapped to itself on the channel.
//
// Naming and validity:
//   - Group, channel and plan names are "crm-<trimmed provider ID>" plus
//     "-group", "-channel" or "-plan".
//   - Plan validity is the request's SubscriptionDays, or 30 days when that
//     value is not positive.
func generatedBatchImportProviderManifest(req batch.ProvisionRequest, createReq CreateBatchImportRunRequest) pack.ProviderManifest {
	models := uniqueNonEmptyStrings(capabilityProfileModels(req.CapabilityProfile))
	if len(models) == 0 {
		models = uniqueNonEmptyStrings([]string{req.ResolvedModel})
	}

	// uniqueNonEmptyStrings only yields trimmed, non-blank values, so
	// models[0] is a valid fallback candidate as-is.
	firstModel := ""
	if len(models) > 0 {
		firstModel = models[0]
	}
	smokeModel := firstNonEmptyString(req.ResolvedModel, firstModel, "ping")
	if len(models) == 0 {
		models = []string{smokeModel}
	}

	identity := make(map[string]string, len(models))
	for _, id := range models {
		identity[id] = id
	}

	prefix := "crm-" + strings.TrimSpace(req.ProviderID)
	days := createReq.SubscriptionDays
	if days <= 0 {
		days = 30
	}

	return pack.ProviderManifest{
		ProviderID:     req.ProviderID,
		DisplayName:    req.ProviderID,
		BaseURL:        strings.TrimSpace(req.Entry.BaseURL),
		Platform:       "openai",
		AccountType:    "apikey",
		DefaultModels:  models,
		SmokeTestModel: smokeModel,
		GroupTemplate: pack.GroupTemplate{
			Name:           prefix + "-group",
			RateMultiplier: 1,
		},
		ChannelTemplate: pack.ChannelTemplate{
			Name:         prefix + "-channel",
			ModelMapping: identity,
		},
		PlanTemplate: pack.PlanTemplate{
			Name:         prefix + "-plan",
			Price:        1,
			ValidityDays: days,
			ValidityUnit: "day",
		},
		Import: pack.ImportOptions{
			SupportsMultiKey: true,
			SupportsStrict:   true,
			SupportsPartial:  true,
		},
	}
}
// batchImportAccessRequest converts the run request's access settings into a
// provision.AccessRequest. It carries the trimmed access mode and probe API
// key, plus one subscription target per listed user. User IDs are used
// verbatim, and each target's duration is the request's SubscriptionDays.
func batchImportAccessRequest(req CreateBatchImportRunRequest) provision.AccessRequest {
	access := provision.AccessRequest{
		Mode:          strings.TrimSpace(req.AccessMode),
		ProbeAPIKey:   strings.TrimSpace(req.ProbeAPIKey),
		Subscriptions: make([]provision.SubscriptionTarget, 0, len(req.SubscriptionUsers)),
	}
	for _, user := range req.SubscriptionUsers {
		target := provision.SubscriptionTarget{
			UserID:       user,
			DurationDays: req.SubscriptionDays,
		}
		access.Subscriptions = append(access.Subscriptions, target)
	}
	return access
}
// resolveManagedResourceHostID returns the host-side ID of the managed
// resource of type resourceType (callers pass "account" or "group") recorded
// for item's legacy import batch.
//
// Resource types are compared after trimming whitespace, and the returned ID
// is trimmed. A matching resource whose host ID is blank is skipped, so
// callers never receive an empty ID with a nil error; an empty ID would
// otherwise be passed straight into host API calls.
//
// An error is returned when:
//   - store is nil,
//   - the item has no positive legacy batch ID,
//   - the lookup fails,
//   - no matching resource with a non-blank host ID exists.
func resolveManagedResourceHostID(ctx context.Context, store *sqlite.DB, item sqlite.ImportRunItem, resourceType string) (string, error) {
	wanted := strings.TrimSpace(resourceType)
	if store == nil {
		return "", fmt.Errorf("store is required")
	}
	if item.LegacyBatchID == nil || *item.LegacyBatchID <= 0 {
		return "", fmt.Errorf("legacy_batch_id is required for %s lookup", wanted)
	}
	batchID := *item.LegacyBatchID
	resources, err := store.ManagedResources().GetByBatchID(ctx, batchID)
	if err != nil {
		return "", err
	}
	for _, resource := range resources {
		if strings.TrimSpace(resource.ResourceType) != wanted {
			continue
		}
		// A blank host ID is unusable for host API calls; keep searching.
		if hostID := strings.TrimSpace(resource.HostResourceID); hostID != "" {
			return hostID, nil
		}
	}
	return "", fmt.Errorf("%s resource not found for batch %d", wanted, batchID)
}
// capabilityProfileModels returns the trimmed RawModelID of every model
// profile in profile, in range order. Blank IDs and duplicates are kept, so
// callers de-duplicate separately. A nil profile yields a nil slice.
func capabilityProfileModels(profile *probe.CapabilityProfile) []string {
	var ids []string
	if profile != nil {
		ids = make([]string, 0, len(profile.ModelProfiles))
		for _, mp := range profile.ModelProfiles {
			ids = append(ids, strings.TrimSpace(mp.RawModelID))
		}
	}
	return ids
}
// uniqueNonEmptyStrings trims each value and keeps only the first occurrence
// of each non-blank result, preserving input order. The returned slice is
// never nil, even for empty input.
func uniqueNonEmptyStrings(values []string) []string {
	out := make([]string, 0, len(values))
	seen := make(map[string]struct{}, len(values))
	for i := range values {
		v := strings.TrimSpace(values[i])
		if _, dup := seen[v]; v == "" || dup {
			continue
		}
		seen[v] = struct{}{}
		out = append(out, v)
	}
	return out
}
// firstNonEmptyString returns the first argument that is non-blank after
// trimming whitespace, in its trimmed form. It returns "" when every argument
// is blank or none are given.
func firstNonEmptyString(values ...string) string {
	for i := range values {
		candidate := strings.TrimSpace(values[i])
		if candidate == "" {
			continue
		}
		return candidate
	}
	return ""
}
// sleepWithContext blocks for delay or until ctx is done, whichever happens
// first. It returns nil when the delay elapses and ctx.Err() when the context
// ends first.
func sleepWithContext(ctx context.Context, delay time.Duration) error {
	wait := time.NewTimer(delay)
	select {
	case <-wait.C:
		return nil
	case <-ctx.Done():
		// Release the timer promptly; it has not fired on this path.
		wait.Stop()
		return ctx.Err()
	}
}