chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,286 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
pkghttputil "github.com/Wei-Shaw/sub2api/internal/pkg/httputil"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/ip"
|
||||
"github.com/Wei-Shaw/sub2api/internal/pkg/logger"
|
||||
middleware2 "github.com/Wei-Shaw/sub2api/internal/server/middleware"
|
||||
"github.com/Wei-Shaw/sub2api/internal/service"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// ChatCompletions handles OpenAI Chat Completions API requests.
|
||||
// POST /v1/chat/completions
|
||||
func (h *OpenAIGatewayHandler) ChatCompletions(c *gin.Context) {
|
||||
streamStarted := false
|
||||
defer h.recoverResponsesPanic(c, &streamStarted)
|
||||
|
||||
requestStart := time.Now()
|
||||
|
||||
apiKey, ok := middleware2.GetAPIKeyFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusUnauthorized, "authentication_error", "Invalid API key")
|
||||
return
|
||||
}
|
||||
|
||||
subject, ok := middleware2.GetAuthSubjectFromContext(c)
|
||||
if !ok {
|
||||
h.errorResponse(c, http.StatusInternalServerError, "api_error", "User context not found")
|
||||
return
|
||||
}
|
||||
reqLog := requestLogger(
|
||||
c,
|
||||
"handler.openai_gateway.chat_completions",
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
)
|
||||
|
||||
if !h.ensureResponsesDependencies(c, reqLog) {
|
||||
return
|
||||
}
|
||||
|
||||
body, err := pkghttputil.ReadRequestBodyWithPrealloc(c.Request)
|
||||
if err != nil {
|
||||
if maxErr, ok := extractMaxBytesError(err); ok {
|
||||
h.errorResponse(c, http.StatusRequestEntityTooLarge, "invalid_request_error", buildBodyTooLargeMessage(maxErr.Limit))
|
||||
return
|
||||
}
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to read request body")
|
||||
return
|
||||
}
|
||||
if len(body) == 0 {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Request body is empty")
|
||||
return
|
||||
}
|
||||
|
||||
if !gjson.ValidBytes(body) {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "Failed to parse request body")
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(body, "model")
|
||||
if !modelResult.Exists() || modelResult.Type != gjson.String || modelResult.String() == "" {
|
||||
h.errorResponse(c, http.StatusBadRequest, "invalid_request_error", "model is required")
|
||||
return
|
||||
}
|
||||
reqModel := modelResult.String()
|
||||
reqStream := gjson.GetBytes(body, "stream").Bool()
|
||||
|
||||
reqLog = reqLog.With(zap.String("model", reqModel), zap.Bool("stream", reqStream))
|
||||
|
||||
setOpsRequestContext(c, reqModel, reqStream, body)
|
||||
|
||||
if h.errorPassthroughService != nil {
|
||||
service.BindErrorPassthroughService(c, h.errorPassthroughService)
|
||||
}
|
||||
|
||||
subscription, _ := middleware2.GetSubscriptionFromContext(c)
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsAuthLatencyMsKey, time.Since(requestStart).Milliseconds())
|
||||
routingStart := time.Now()
|
||||
|
||||
userReleaseFunc, acquired := h.acquireResponsesUserSlot(c, subject.UserID, subject.Concurrency, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
if userReleaseFunc != nil {
|
||||
defer userReleaseFunc()
|
||||
}
|
||||
|
||||
if err := h.billingCacheService.CheckBillingEligibility(c.Request.Context(), apiKey.User, apiKey, apiKey.Group, subscription); err != nil {
|
||||
reqLog.Info("openai_chat_completions.billing_eligibility_check_failed", zap.Error(err))
|
||||
status, code, message := billingErrorDetails(err)
|
||||
h.handleStreamingAwareError(c, status, code, message, streamStarted)
|
||||
return
|
||||
}
|
||||
|
||||
sessionHash := h.gatewayService.GenerateSessionHash(c, body)
|
||||
promptCacheKey := h.gatewayService.ExtractSessionID(c, body)
|
||||
|
||||
maxAccountSwitches := h.maxAccountSwitches
|
||||
switchCount := 0
|
||||
failedAccountIDs := make(map[int64]struct{})
|
||||
sameAccountRetryCount := make(map[int64]int)
|
||||
var lastFailoverErr *service.UpstreamFailoverError
|
||||
|
||||
for {
|
||||
c.Set("openai_chat_completions_fallback_model", "")
|
||||
reqLog.Debug("openai_chat_completions.account_selecting", zap.Int("excluded_account_count", len(failedAccountIDs)))
|
||||
selection, scheduleDecision, err := h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
reqModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err != nil {
|
||||
reqLog.Warn("openai_chat_completions.account_select_failed",
|
||||
zap.Error(err),
|
||||
zap.Int("excluded_account_count", len(failedAccountIDs)),
|
||||
)
|
||||
if len(failedAccountIDs) == 0 {
|
||||
defaultModel := ""
|
||||
if apiKey.Group != nil {
|
||||
defaultModel = apiKey.Group.DefaultMappedModel
|
||||
}
|
||||
if defaultModel != "" && defaultModel != reqModel {
|
||||
reqLog.Info("openai_chat_completions.fallback_to_default_model",
|
||||
zap.String("default_mapped_model", defaultModel),
|
||||
)
|
||||
selection, scheduleDecision, err = h.gatewayService.SelectAccountWithScheduler(
|
||||
c.Request.Context(),
|
||||
apiKey.GroupID,
|
||||
"",
|
||||
sessionHash,
|
||||
defaultModel,
|
||||
failedAccountIDs,
|
||||
service.OpenAIUpstreamTransportAny,
|
||||
)
|
||||
if err == nil && selection != nil {
|
||||
c.Set("openai_chat_completions_fallback_model", defaultModel)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "Service temporarily unavailable", streamStarted)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
if lastFailoverErr != nil {
|
||||
h.handleFailoverExhausted(c, lastFailoverErr, streamStarted)
|
||||
} else {
|
||||
h.handleStreamingAwareError(c, http.StatusBadGateway, "api_error", "Upstream request failed", streamStarted)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
if selection == nil || selection.Account == nil {
|
||||
h.handleStreamingAwareError(c, http.StatusServiceUnavailable, "api_error", "No available accounts", streamStarted)
|
||||
return
|
||||
}
|
||||
account := selection.Account
|
||||
sessionHash = ensureOpenAIPoolModeSessionHash(sessionHash, account)
|
||||
reqLog.Debug("openai_chat_completions.account_selected", zap.Int64("account_id", account.ID), zap.String("account_name", account.Name))
|
||||
_ = scheduleDecision
|
||||
setOpsSelectedAccount(c, account.ID, account.Platform)
|
||||
|
||||
accountReleaseFunc, acquired := h.acquireResponsesAccountSlot(c, apiKey.GroupID, sessionHash, selection, reqStream, &streamStarted, reqLog)
|
||||
if !acquired {
|
||||
return
|
||||
}
|
||||
|
||||
service.SetOpsLatencyMs(c, service.OpsRoutingLatencyMsKey, time.Since(routingStart).Milliseconds())
|
||||
forwardStart := time.Now()
|
||||
|
||||
defaultMappedModel := c.GetString("openai_chat_completions_fallback_model")
|
||||
result, err := h.gatewayService.ForwardAsChatCompletions(c.Request.Context(), c, account, body, promptCacheKey, defaultMappedModel)
|
||||
|
||||
forwardDurationMs := time.Since(forwardStart).Milliseconds()
|
||||
if accountReleaseFunc != nil {
|
||||
accountReleaseFunc()
|
||||
}
|
||||
upstreamLatencyMs, _ := getContextInt64(c, service.OpsUpstreamLatencyMsKey)
|
||||
responseLatencyMs := forwardDurationMs
|
||||
if upstreamLatencyMs > 0 && forwardDurationMs > upstreamLatencyMs {
|
||||
responseLatencyMs = forwardDurationMs - upstreamLatencyMs
|
||||
}
|
||||
service.SetOpsLatencyMs(c, service.OpsResponseLatencyMsKey, responseLatencyMs)
|
||||
if err == nil && result != nil && result.FirstTokenMs != nil {
|
||||
service.SetOpsLatencyMs(c, service.OpsTimeToFirstTokenMsKey, int64(*result.FirstTokenMs))
|
||||
}
|
||||
if err != nil {
|
||||
var failoverErr *service.UpstreamFailoverError
|
||||
if errors.As(err, &failoverErr) {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
// Pool mode: retry on the same account
|
||||
if failoverErr.RetryableOnSameAccount {
|
||||
retryLimit := account.GetPoolModeRetryCount()
|
||||
if sameAccountRetryCount[account.ID] < retryLimit {
|
||||
sameAccountRetryCount[account.ID]++
|
||||
reqLog.Warn("openai_chat_completions.pool_mode_same_account_retry",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("retry_limit", retryLimit),
|
||||
zap.Int("retry_count", sameAccountRetryCount[account.ID]),
|
||||
)
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
return
|
||||
case <-time.After(sameAccountRetryDelay):
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
h.gatewayService.RecordOpenAIAccountSwitch()
|
||||
failedAccountIDs[account.ID] = struct{}{}
|
||||
lastFailoverErr = failoverErr
|
||||
if switchCount >= maxAccountSwitches {
|
||||
h.handleFailoverExhausted(c, failoverErr, streamStarted)
|
||||
return
|
||||
}
|
||||
switchCount++
|
||||
reqLog.Warn("openai_chat_completions.upstream_failover_switching",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("upstream_status", failoverErr.StatusCode),
|
||||
zap.Int("switch_count", switchCount),
|
||||
zap.Int("max_switches", maxAccountSwitches),
|
||||
)
|
||||
continue
|
||||
}
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, false, nil)
|
||||
wroteFallback := h.ensureForwardErrorResponse(c, streamStarted)
|
||||
reqLog.Warn("openai_chat_completions.forward_failed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Bool("fallback_error_response_written", wroteFallback),
|
||||
zap.Error(err),
|
||||
)
|
||||
return
|
||||
}
|
||||
if result != nil {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, result.FirstTokenMs)
|
||||
} else {
|
||||
h.gatewayService.ReportOpenAIAccountScheduleResult(account.ID, true, nil)
|
||||
}
|
||||
|
||||
userAgent := c.GetHeader("User-Agent")
|
||||
clientIP := ip.GetClientIP(c)
|
||||
|
||||
h.submitUsageRecordTask(func(ctx context.Context) {
|
||||
if err := h.gatewayService.RecordUsage(ctx, &service.OpenAIRecordUsageInput{
|
||||
Result: result,
|
||||
APIKey: apiKey,
|
||||
User: apiKey.User,
|
||||
Account: account,
|
||||
Subscription: subscription,
|
||||
InboundEndpoint: GetInboundEndpoint(c),
|
||||
UpstreamEndpoint: GetUpstreamEndpoint(c, account.Platform),
|
||||
UserAgent: userAgent,
|
||||
IPAddress: clientIP,
|
||||
APIKeyService: h.apiKeyService,
|
||||
}); err != nil {
|
||||
logger.L().With(
|
||||
zap.String("component", "handler.openai_gateway.chat_completions"),
|
||||
zap.Int64("user_id", subject.UserID),
|
||||
zap.Int64("api_key_id", apiKey.ID),
|
||||
zap.Any("group_id", apiKey.GroupID),
|
||||
zap.String("model", reqModel),
|
||||
zap.Int64("account_id", account.ID),
|
||||
).Error("openai_chat_completions.record_usage_failed", zap.Error(err))
|
||||
}
|
||||
})
|
||||
reqLog.Debug("openai_chat_completions.request_completed",
|
||||
zap.Int64("account_id", account.ID),
|
||||
zap.Int("switch_count", switchCount),
|
||||
)
|
||||
return
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user