Files
sub2api-cn-relay-manager/internal/provision/import_service.go
phamnazage-jpg 9134afed9f fix(provision): stabilize kimi a7m import closure
Downgrade the first third-party account test 403 to an advisory warning when models are already present, and retry transient gateway completion 503 responses during access closure.

Add regression coverage for the probe race and completion retry paths, update the execution board, and store the final v0.1.129 Kimi A7M fresh-host acceptance artifact that now reaches succeeded/active/subscription_ready.
2026-05-22 12:33:12 +08:00

542 lines
17 KiB
Go

package provision
import (
"context"
"errors"
"fmt"
"strings"
"sub2api-cn-relay-manager/internal/access"
"sub2api-cn-relay-manager/internal/host/sub2api"
"sub2api-cn-relay-manager/internal/pack"
)
// Import modes accepted by ImportService.Import (ImportRequest.Mode).
const (
	// ImportModeStrict rolls back the resources an import created whenever
	// the import returns an error.
	ImportModeStrict = "strict"
	// ImportModePartial keeps created resources on failure and reports a
	// degraded result instead.
	ImportModePartial = "partial"
)

// Access modes accepted in AccessRequest.Mode.
const (
	AccessModeSubscription = "subscription"
	AccessModeSelfService  = "self_service"
)

// Values reported in ImportReport.BatchStatus.
const (
	BatchStatusSucceeded  = "succeeded"
	BatchStatusPartial    = "partially_succeeded"
	BatchStatusFailed     = "failed"
	BatchStatusRolledBack = "rolled_back"
)

// Values reported in ImportReport.ProviderStatus.
const (
	ProviderStatusActive   = "active"
	ProviderStatusDegraded = "degraded"
	ProviderStatusFailed   = "failed"
)

// Values reported in ImportReport.AccessStatus.
const (
	AccessStatusSubscriptionReady = "subscription_ready"
	AccessStatusSelfServiceReady  = "self_service_ready"
	AccessStatusFullyReady        = "fully_ready"
	AccessStatusBroken            = "broken"
)

// Values returned by AccountImportResult.ValidationStatus.
const (
	AccountStatusPassed  = "passed"
	AccountStatusWarning = "warning"
	AccountStatusFailed  = "failed"
)
// AccessRequest configures how access to the imported provider is closed.
// Import converts it into an access.ClosureRequest for both access.Validate
// and the final access closure.
type AccessRequest struct {
	// Mode is expected to be AccessModeSubscription or AccessModeSelfService.
	Mode string
	// ProbeAPIKey is forwarded unchanged to access validation and closure.
	ProbeAPIKey string
	// Subscriptions is forwarded (via toAccessSubscriptionTargets) to access
	// validation and closure.
	Subscriptions []SubscriptionTarget
}

// SubscriptionTarget pairs a user ID with a duration in days; it is copied
// field for field into access.SubscriptionTarget.
type SubscriptionTarget struct {
	UserID       string
	DurationDays int
}

// ImportRequest is the input to ImportService.Import.
type ImportRequest struct {
	// Provider is the manifest of the provider being imported.
	Provider pack.ProviderManifest
	// Mode must be ImportModeStrict or ImportModePartial.
	Mode string
	// Access configures access closure once the accounts exist.
	Access AccessRequest
	// Keys are raw API keys; after normalizeKeys trims and deduplicates them,
	// one account is created per remaining key.
	Keys []string
}
// ImportReport summarizes an import. Import returns it on most error paths as
// well, filled in as far as the import progressed.
type ImportReport struct {
	// BatchStatus holds one of the BatchStatus* constants (empty if Import
	// returned before assigning one).
	BatchStatus string
	// ProviderStatus holds one of the ProviderStatus* constants (or empty).
	ProviderStatus string
	// AccessStatus holds one of the AccessStatus* constants (or empty).
	AccessStatus string
	// AcceptedKeys are the normalized, deduplicated keys accounts were built from.
	AcceptedKeys []string
	Group        sub2api.GroupRef
	Channel      sub2api.ChannelRef
	// Plan is only set in subscription access mode.
	Plan *sub2api.PlanRef
	// Accounts holds one validation result per created account, in the order
	// BatchCreateAccounts returned them.
	Accounts []AccountImportResult
	// Gateway is the result of the final access closure check.
	Gateway sub2api.GatewayAccessResult
}

// AccountImportResult records the smoke validation of one created account.
type AccountImportResult struct {
	Ref sub2api.AccountRef
	// Probe is the host's TestAccount result for the provider's smoke-test model.
	Probe sub2api.ProbeResult
	// Models is the account's model list as returned by GetAccountModels.
	Models []sub2api.AccountModel
	// SmokeModelSeen reports whether Models contains the smoke-test model.
	SmokeModelSeen bool
}
// ValidationStatus classifies the account's smoke validation:
//   - AccountStatusFailed when the smoke-test model is missing from Models,
//     regardless of the probe result;
//   - AccountStatusPassed when the model is present and the probe succeeded;
//   - AccountStatusWarning when the model is present and the probe failure is
//     advisory according to isAdvisoryAccountProbeFailure;
//   - AccountStatusFailed otherwise.
func (r AccountImportResult) ValidationStatus() string {
	switch {
	case !r.SmokeModelSeen:
		return AccountStatusFailed
	case r.Probe.OK:
		return AccountStatusPassed
	case isAdvisoryAccountProbeFailure(r.Probe):
		return AccountStatusWarning
	default:
		return AccountStatusFailed
	}
}
// HasBlockingFailure reports whether ValidationStatus is AccountStatusFailed.
// Import counts such accounts, and in strict mode any of them aborts the import.
func (r AccountImportResult) HasBlockingFailure() bool {
	status := r.ValidationStatus()
	return status == AccountStatusFailed
}

// HasAdvisoryWarning reports whether ValidationStatus is AccountStatusWarning,
// i.e. the probe failed in a way that does not block the import.
func (r AccountImportResult) HasAdvisoryWarning() bool {
	status := r.ValidationStatus()
	return status == AccountStatusWarning
}
// hostAdapter is the host surface the import flow depends on: every
// sub2api.HostAdapter operation plus a gateway access check.
type hostAdapter interface {
	sub2api.HostAdapter
	CheckGatewayAccess(ctx context.Context, req sub2api.GatewayAccessCheckRequest) (sub2api.GatewayAccessResult, error)
}
// GatewayAccessReady reports whether a gateway access check fully succeeded:
// OK, HasExpectedModel and CompletionOK must all be set.
func GatewayAccessReady(result sub2api.GatewayAccessResult) bool {
	if !result.OK || !result.HasExpectedModel {
		return false
	}
	return result.CompletionOK
}
// isAdvisoryAccountProbeFailure reports whether a failed account probe should
// only produce a warning instead of failing the account. A successful probe or
// an empty message is never advisory. The (lowercased, trimmed) message is
// advisory when it is transient (isTransientAccountProbeFailure), when it is
// the plain "api returned 403: forbidden" response, or when it mentions the
// "responses api" together with one of the host's "use the real API instead"
// hints.
func isAdvisoryAccountProbeFailure(probe sub2api.ProbeResult) bool {
	if probe.OK {
		return false
	}
	msg := strings.ToLower(strings.TrimSpace(probe.Message))
	switch {
	case msg == "":
		return false
	case isTransientAccountProbeFailure(msg):
		return true
	case strings.Contains(msg, "api returned 403: forbidden"):
		// OpenAI-compatible third-party upstreams such as Kimi/DeepSeek may
		// create accounts and expose /models immediately, but the host's
		// asynchronous /responses capability probe can complete slightly later.
		// During that race window, the first /accounts/:id/test still takes the
		// default /responses path and returns a plain 403 Forbidden even though
		// the actual chat/completions route already works.
		// NOTE(review): this matches any 403 Forbidden, including one caused by
		// a rejected key; confirm the SmokeModelSeen requirement and the later
		// gateway closure check are enough to catch that case.
		return true
	case !strings.Contains(msg, "responses api"):
		return false
	}
	hints := []string{
		"当前测试接口仅支持",
		"账号本身可正常使用",
		"please directly",
		"actual api",
	}
	for _, hint := range hints {
		if strings.Contains(msg, hint) {
			return true
		}
	}
	return false
}
// isTransientAccountProbeFailure reports whether a probe message looks like a
// temporary upstream condition rather than a real account problem. Matching is
// case-sensitive, so callers pass an already-lowercased message.
//
// Two markers must both appear:
//   - a condition marker: 429 / rate limit / too many requests, 502 / 503 /
//     504 / bad gateway / service unavailable, or timeout;
//   - a context marker: "api returned", "rate_limit", "upstream", "temporar"
//     (temporary/temporarily) or "retry".
//
// A bare condition marker alone (for example just "timeout") is not enough.
func isTransientAccountProbeFailure(message string) bool {
	containsAny := func(needles ...string) bool {
		for _, needle := range needles {
			if strings.Contains(message, needle) {
				return true
			}
		}
		return false
	}
	looksTransient := containsAny(
		"429", "rate limit", "too many requests",
		"502", "503", "504",
		"bad gateway", "service unavailable", "timeout",
	)
	if !looksTransient {
		return false
	}
	return containsAny("api returned", "rate_limit", "upstream", "temporar", "retry")
}
// resolvedManagedResources is what ensureManagedResources found or created for
// a provider.
type resolvedManagedResources struct {
	Group   sub2api.GroupRef
	Channel sub2api.ChannelRef
	// Plan is only set in subscription access mode.
	Plan *sub2api.PlanRef
	// Accounts are the provider's accounts listed before this import (matched
	// by the provider's account name prefix).
	Accounts []sub2api.NamedResource
	// The Created* flags report whether this import created the resource
	// (and may therefore roll it back) rather than reusing an existing one.
	CreatedGroup   bool
	CreatedChannel bool
	CreatedPlan    bool
}
// ImportService imports providers into a sub2api host; see Import.
type ImportService struct {
	host hostAdapter
}
// NewImportService returns an ImportService that provisions through host.
func NewImportService(host hostAdapter) *ImportService {
	svc := new(ImportService)
	svc.host = host
	return svc
}
// Import provisions a provider on the host end to end:
//
//  1. normalizes the API keys and validates the import and access modes;
//  2. resolves or creates the managed group, channel and (subscription mode)
//     plan via ensureManagedResources;
//  3. creates one account per key and records each account's probe result and
//     model list;
//  4. when no account has a blocking failure, deletes the provider's accounts
//     that existed before this import;
//  5. closes access through access.Service and derives the final statuses.
//
// In ImportModeStrict, any returned error triggers a rollback of every resource
// this call created, and blocking account failures abort the import. In
// ImportModePartial, errors are returned with a degraded report and nothing is
// rolled back. Except for input validation errors, the returned report is
// filled in as far as the import progressed.
func (s *ImportService) Import(ctx context.Context, req ImportRequest) (report ImportReport, err error) {
	normalizedKeys, err := normalizeKeys(req.Keys)
	if err != nil {
		return ImportReport{}, err
	}
	if err := validateMode(req.Mode); err != nil {
		return ImportReport{}, err
	}
	// validateMode tolerates surrounding whitespace, so compare against the
	// trimmed value from here on. Comparing req.Mode directly would let
	// " strict " pass validation and then silently skip strict rollback and
	// strict failure handling.
	mode := strings.TrimSpace(req.Mode)
	if err := access.Validate(access.ClosureRequest{
		Mode:          req.Access.Mode,
		ProbeAPIKey:   req.Access.ProbeAPIKey,
		Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions),
	}); err != nil {
		return ImportReport{}, err
	}
	report = ImportReport{AcceptedKeys: normalizedKeys}
	rollback := newManagedResourceRollback(s.host)
	// err is the named result, so this observes the error of whichever return
	// statement ends the function.
	defer func() {
		if err == nil || mode != ImportModeStrict {
			return
		}
		if rollbackErr := rollback.Run(ctx); rollbackErr != nil {
			err = errors.Join(err, fmt.Errorf("rollback managed resources: %w", rollbackErr))
		}
	}()

	resources, err := s.ensureManagedResources(ctx, req.Provider, req.Access.Mode)
	// Register created resources before checking err. On failure,
	// ensureManagedResources still reports what it created before the failing
	// step, and strict mode must remove those resources too.
	if resources.CreatedGroup {
		rollback.AddGroup(resources.Group.ID)
	}
	if resources.CreatedChannel {
		rollback.AddChannel(resources.Channel.ID)
	}
	if resources.CreatedPlan && resources.Plan != nil {
		rollback.AddPlan(resources.Plan.ID)
	}
	if err != nil {
		return report, err
	}
	report.Group = resources.Group
	report.Channel = resources.Channel
	report.Plan = resources.Plan

	accounts, err := s.host.BatchCreateAccounts(ctx, buildBatchAccountsRequest(req.Provider, resources.Group.ID, normalizedKeys))
	// Registered before the error check in case the host reports partially
	// created accounts alongside an error; AddAccounts ignores blank IDs and
	// is a no-op for an empty slice.
	rollback.AddAccounts(accounts)
	if err != nil {
		return report, fmt.Errorf("batch create accounts: %w", err)
	}
	for _, account := range accounts {
		probe, err := s.host.TestAccount(ctx, account.ID, req.Provider.SmokeTestModel)
		if err != nil {
			return failOrDegrade(report, mode, fmt.Errorf("test account %s: %w", account.ID, err))
		}
		models, err := s.host.GetAccountModels(ctx, account.ID)
		if err != nil {
			return failOrDegrade(report, mode, fmt.Errorf("get account models %s: %w", account.ID, err))
		}
		report.Accounts = append(report.Accounts, AccountImportResult{
			Ref:            account,
			Probe:          probe,
			Models:         models,
			SmokeModelSeen: hasModel(models, req.Provider.SmokeTestModel),
		})
	}

	failedAccounts := 0
	for _, account := range report.Accounts {
		if account.HasBlockingFailure() {
			failedAccounts++
		}
	}
	// Accounts from earlier imports are removed only when every new account
	// passed (possibly with an advisory warning).
	// NOTE(review): this runs before access closure. If closure then fails in
	// strict mode, rollback also deletes the new accounts, leaving the provider
	// with none. Confirm this ordering is intended.
	if failedAccounts == 0 {
		if err := deleteNamedAccounts(ctx, s.host, resources.Accounts); err != nil {
			return failOrDegrade(report, mode, fmt.Errorf("cleanup existing accounts: %w", err))
		}
	}
	if failedAccounts > 0 && mode == ImportModeStrict {
		report.BatchStatus = BatchStatusFailed
		report.ProviderStatus = ProviderStatusFailed
		report.AccessStatus = AccessStatusBroken
		return report, fmt.Errorf("strict import failed: %d account(s) did not pass smoke validation", failedAccounts)
	}

	closureService := access.NewService(s.host)
	gateway, err := closureService.Close(ctx, access.ClosureRequest{
		Mode:          req.Access.Mode,
		ProbeAPIKey:   req.Access.ProbeAPIKey,
		Subscriptions: toAccessSubscriptionTargets(req.Access.Subscriptions),
		GroupID:       resources.Group.ID,
		ExpectedModel: req.Provider.SmokeTestModel,
	})
	if err != nil {
		return failOrDegrade(report, mode, err)
	}
	report.Gateway = gateway
	ready := GatewayAccessReady(gateway)
	report.BatchStatus = BatchStatusSucceeded
	report.ProviderStatus = ProviderStatusActive
	if failedAccounts > 0 || !ready {
		report.BatchStatus = BatchStatusPartial
		report.ProviderStatus = ProviderStatusDegraded
	}
	switch req.Access.Mode {
	case AccessModeSubscription:
		report.AccessStatus = AccessStatusSubscriptionReady
	case AccessModeSelfService:
		report.AccessStatus = AccessStatusSelfServiceReady
	}
	if !ready {
		report.AccessStatus = AccessStatusBroken
	}
	return report, nil
}

// ensureManagedResources resolves the group, channel and (in subscription
// access mode) plan the provider is managed through. Missing resources are
// created, and an existing channel is updated in place. It also returns the
// accounts that already exist under the provider's account name prefix, so the
// caller can clean them up once replacements validate.
//
// On error, the returned value still carries every resource resolved or created
// before the failing step, with its Created* flag set, so the caller can roll
// back partial work. Only a failure to list resources returns a zero value.
func (s *ImportService) ensureManagedResources(ctx context.Context, provider pack.ProviderManifest, accessMode string) (resolvedManagedResources, error) {
	names := SuggestResourceNamesForMode(provider, accessMode)
	snapshot, err := s.host.ListManagedResources(ctx, sub2api.ListManagedResourcesRequest{
		GroupName:         names.Group,
		ChannelName:       names.Channel,
		PlanName:          names.Plan,
		AccountNamePrefix: SuggestAccountNamePrefix(provider),
	})
	if err != nil {
		return resolvedManagedResources{}, fmt.Errorf("list managed resources: %w", err)
	}
	result := resolvedManagedResources{Accounts: append([]sub2api.NamedResource(nil), snapshot.Accounts...)}

	group, created, err := ensureGroup(ctx, s.host, snapshot.Groups, provider, accessMode, names.Group)
	result.Group = group
	result.CreatedGroup = created
	if err != nil {
		return result, fmt.Errorf("ensure group: %w", err)
	}

	channel, created, err := ensureChannel(ctx, s.host, snapshot.Channels, provider, group.ID, names.Channel)
	result.Channel = channel
	result.CreatedChannel = created
	if err != nil {
		return result, fmt.Errorf("ensure channel: %w", err)
	}

	if accessMode == AccessModeSubscription {
		plan, created, err := ensurePlan(ctx, s.host, snapshot.Plans, provider, group.ID, names.Plan)
		result.Plan = &plan
		result.CreatedPlan = created
		if err != nil {
			return result, fmt.Errorf("ensure plan: %w", err)
		}
	}
	return result, nil
}
// ensureGroup returns the provider's managed group and whether this call
// created it.
//   - Exactly one entry in existing: that group is reused.
//   - existing is empty: a group named groupName is created from the
//     provider's platform and group template. In subscription access mode it
//     is marked with SubscriptionType "subscription".
//   - More than one entry: an error, since the target group is ambiguous.
func ensureGroup(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, accessMode, groupName string) (sub2api.GroupRef, bool, error) {
	if len(existing) > 1 {
		return sub2api.GroupRef{}, false, fmt.Errorf("multiple groups already exist for %q", groupName)
	}
	if len(existing) == 1 {
		found := existing[0]
		return sub2api.GroupRef{ID: found.ID, Name: found.Name}, false, nil
	}
	createReq := sub2api.CreateGroupRequest{
		Name:           groupName,
		Platform:       provider.Platform,
		RateMultiplier: provider.GroupTemplate.RateMultiplier,
	}
	if accessMode == AccessModeSubscription {
		createReq.SubscriptionType = "subscription"
	}
	created, err := host.CreateGroup(ctx, createReq)
	return created, true, err
}
// ensureChannel makes the provider's channel match buildChannelRequest and
// reports whether this call created it.
//   - existing is empty: the channel is created.
//   - Exactly one entry: it is updated in place and reused.
//   - More than one entry: an error.
func ensureChannel(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID, channelName string) (sub2api.ChannelRef, bool, error) {
	if len(existing) > 1 {
		return sub2api.ChannelRef{}, false, fmt.Errorf("multiple channels already exist for %q", channelName)
	}
	desired := buildChannelRequest(provider, groupID, channelName)
	if len(existing) == 0 {
		created, err := host.CreateChannel(ctx, desired)
		return created, true, err
	}
	current := existing[0]
	if err := host.UpdateChannel(ctx, current.ID, desired); err != nil {
		return sub2api.ChannelRef{}, false, err
	}
	return sub2api.ChannelRef{ID: current.ID, Name: current.Name}, false, nil
}
// buildChannelRequest describes the desired channel for a provider:
//   - bound to groupID, using the channel template's model mapping;
//   - restricted to the provider's default models, with token-based pricing
//     and an explicitly empty (non-nil) interval list;
//   - billing model source "channel_mapped".
//
// The default model list is copied so the request does not alias the manifest.
func buildChannelRequest(provider pack.ProviderManifest, groupID, channelName string) sub2api.CreateChannelRequest {
	pricing := sub2api.ChannelModelPricing{
		Platform:    provider.Platform,
		Models:      append([]string(nil), provider.DefaultModels...),
		BillingMode: "token",
		Intervals:   []sub2api.ChannelPricingTier{},
	}
	req := sub2api.CreateChannelRequest{
		Name:               channelName,
		GroupIDs:           []string{groupID},
		ModelMapping:       provider.ChannelTemplate.ModelMapping,
		ModelPricing:       []sub2api.ChannelModelPricing{pricing},
		Platform:           provider.Platform,
		RestrictModels:     true,
		BillingModelSource: "channel_mapped",
	}
	return req
}
// ensurePlan returns the provider's subscription plan and whether this call
// created it.
//   - Exactly one entry in existing: that plan is reused.
//   - existing is empty: a plan for groupID is created from the provider's
//     plan template.
//   - More than one entry: an error.
func ensurePlan(ctx context.Context, host hostAdapter, existing []sub2api.NamedResource, provider pack.ProviderManifest, groupID, planName string) (sub2api.PlanRef, bool, error) {
	if len(existing) > 1 {
		return sub2api.PlanRef{}, false, fmt.Errorf("multiple plans already exist for %q", planName)
	}
	if len(existing) == 1 {
		found := existing[0]
		return sub2api.PlanRef{ID: found.ID, Name: found.Name}, false, nil
	}
	created, err := host.CreatePlan(ctx, sub2api.CreatePlanRequest{
		GroupID:      groupID,
		Name:         planName,
		Price:        provider.PlanTemplate.Price,
		ValidityDays: provider.PlanTemplate.ValidityDays,
		ValidityUnit: provider.PlanTemplate.ValidityUnit,
	})
	return created, true, err
}
// validateMode accepts ImportModeStrict or ImportModePartial, ignoring
// surrounding whitespace. Any other value returns an error quoting the raw
// input.
func validateMode(mode string) error {
	if m := strings.TrimSpace(mode); m == ImportModeStrict || m == ImportModePartial {
		return nil
	}
	return fmt.Errorf("unsupported import mode %q", mode)
}
// toAccessSubscriptionTargets converts subscription targets into the access
// package's type, field for field and in order. The result is never nil, even
// for empty input.
func toAccessSubscriptionTargets(targets []SubscriptionTarget) []access.SubscriptionTarget {
	converted := make([]access.SubscriptionTarget, len(targets))
	for i, target := range targets {
		converted[i] = access.SubscriptionTarget{
			UserID:       target.UserID,
			DurationDays: target.DurationDays,
		}
	}
	return converted
}
// normalizeKeys cleans a raw API key list:
//   - trims surrounding whitespace and a leading byte order mark (U+FEFF)
//     from each key;
//   - drops keys that end up empty;
//   - removes duplicates, keeping first-seen order.
//
// It returns an error when no usable key remains.
func normalizeKeys(keys []string) ([]string, error) {
	seen := make(map[string]struct{}, len(keys))
	result := make([]string, 0, len(keys))
	for _, key := range keys {
		// U+FEFF is not a Unicode space, so TrimSpace alone keeps it. Trim
		// whitespace on both sides of the BOM, so a BOM preceded by whitespace
		// (e.g. " \ufeffsk-...") does not survive into a key the upstream
		// would likely reject.
		normalized := strings.TrimSpace(strings.TrimPrefix(strings.TrimSpace(key), "\ufeff"))
		if normalized == "" {
			continue
		}
		if _, dup := seen[normalized]; dup {
			continue
		}
		seen[normalized] = struct{}{}
		result = append(result, normalized)
	}
	if len(result) == 0 {
		return nil, errors.New("at least one api key is required")
	}
	return result, nil
}
// buildBatchAccountsRequest builds one account creation request per key, in
// key order. Each account:
//   - is named "<ProviderID>-NN" (1-based index, at least two digits);
//   - is attached to groupID;
//   - carries the provider's base URL, the key and the channel model mapping
//     as credentials.
func buildBatchAccountsRequest(provider pack.ProviderManifest, groupID string, keys []string) sub2api.BatchCreateAccountsRequest {
	requests := make([]sub2api.CreateAccountRequest, len(keys))
	for i, apiKey := range keys {
		requests[i] = sub2api.CreateAccountRequest{
			Name:     fmt.Sprintf("%s-%02d", provider.ProviderID, i+1),
			Platform: provider.Platform,
			Type:     provider.AccountType,
			GroupIDs: []string{groupID},
			Credentials: map[string]any{
				"base_url":      provider.BaseURL,
				"api_key":       apiKey,
				"model_mapping": provider.ChannelTemplate.ModelMapping,
			},
		}
	}
	return sub2api.BatchCreateAccountsRequest{Accounts: requests}
}
// hasModel reports whether any model's ID equals target, ignoring surrounding
// whitespace on both sides.
func hasModel(models []sub2api.AccountModel, target string) bool {
	want := strings.TrimSpace(target) // loop-invariant, so trim once
	for i := range models {
		if strings.TrimSpace(models[i].ID) == want {
			return true
		}
	}
	return false
}
// deleteNamedAccounts deletes the given accounts from last to first, skipping
// blank IDs. Every deletion is attempted. All failures are returned joined, or
// nil if there were none.
func deleteNamedAccounts(ctx context.Context, host hostAdapter, accounts []sub2api.NamedResource) error {
	var failures []error
	for offset := range accounts {
		id := strings.TrimSpace(accounts[len(accounts)-1-offset].ID)
		if id == "" {
			continue
		}
		if err := host.DeleteAccount(ctx, id); err != nil {
			failures = append(failures, fmt.Errorf("delete stale account %s: %w", id, err))
		}
	}
	return errors.Join(failures...)
}
// managedResourceRollback records the resources created during one import so
// that Run can delete them again (Import runs it when a strict import fails).
// Blank IDs are ignored.
type managedResourceRollback struct {
	host       hostAdapter
	groupID    string
	channelID  string
	planID     string
	accountIDs []string
}
// newManagedResourceRollback returns an empty rollback that deletes through host.
func newManagedResourceRollback(host hostAdapter) *managedResourceRollback {
	rb := new(managedResourceRollback)
	rb.host = host
	return rb
}

// AddGroup records the group to delete on rollback, trimmed of whitespace.
// It replaces any previously recorded group.
func (r *managedResourceRollback) AddGroup(groupID string) {
	trimmed := strings.TrimSpace(groupID)
	r.groupID = trimmed
}

// AddChannel records the channel to delete on rollback, trimmed of whitespace.
// It replaces any previously recorded channel.
func (r *managedResourceRollback) AddChannel(channelID string) {
	trimmed := strings.TrimSpace(channelID)
	r.channelID = trimmed
}

// AddPlan records the plan to delete on rollback, trimmed of whitespace.
// It replaces any previously recorded plan.
func (r *managedResourceRollback) AddPlan(planID string) {
	trimmed := strings.TrimSpace(planID)
	r.planID = trimmed
}

// AddAccounts appends the trimmed, non-blank IDs of accounts. Run deletes
// accounts in reverse order of addition.
func (r *managedResourceRollback) AddAccounts(accounts []sub2api.AccountRef) {
	for i := range accounts {
		if id := strings.TrimSpace(accounts[i].ID); id != "" {
			r.accountIDs = append(r.accountIDs, id)
		}
	}
}
// Run deletes the recorded resources in dependency order:
//   - accounts, newest first;
//   - then the plan, the channel and the group, each only if recorded.
//
// Every deletion is attempted even after a failure. All failures are returned
// joined, or nil if there were none. A nil rollback or one without a host does
// nothing.
func (r *managedResourceRollback) Run(ctx context.Context) error {
	if r == nil || r.host == nil {
		return nil
	}
	var failures []error
	record := func(err error, format, id string) {
		if err != nil {
			failures = append(failures, fmt.Errorf(format, id, err))
		}
	}
	for i := len(r.accountIDs) - 1; i >= 0; i-- {
		id := r.accountIDs[i]
		record(r.host.DeleteAccount(ctx, id), "delete account %s: %w", id)
	}
	if id := r.planID; id != "" {
		record(r.host.DeletePlan(ctx, id), "delete plan %s: %w", id)
	}
	if id := r.channelID; id != "" {
		record(r.host.DeleteChannel(ctx, id), "delete channel %s: %w", id)
	}
	if id := r.groupID; id != "" {
		record(r.host.DeleteGroup(ctx, id), "delete group %s: %w", id)
	}
	return errors.Join(failures...)
}
// failOrDegrade marks report for an import that is returning err and passes
// err through unchanged.
//   - Access is always marked broken.
//   - Strict mode: batch and provider are marked failed.
//   - Any other mode: batch is marked partial and provider degraded.
func failOrDegrade(report ImportReport, mode string, err error) (ImportReport, error) {
	report.AccessStatus = AccessStatusBroken
	if mode == ImportModeStrict {
		report.BatchStatus, report.ProviderStatus = BatchStatusFailed, ProviderStatusFailed
	} else {
		report.BatchStatus, report.ProviderStatus = BatchStatusPartial, ProviderStatusDegraded
	}
	return report, err
}