feat(routing): add minimal chat proxy bridge
This commit is contained in:
@@ -52,6 +52,7 @@ type ActionSet struct {
|
||||
AppendRouteStickyAudit func(context.Context, AppendRouteStickyAuditRequest) (RouteStickyAuditInfo, error)
|
||||
ListRouteStickyAudit func(context.Context, ListRouteStickyAuditRequest) ([]RouteStickyAuditInfo, error)
|
||||
ResolveRoute func(context.Context, ResolveRouteRequest) (ResolveRouteInfo, error)
|
||||
ProxyRouteChatCompletions func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error)
|
||||
SetStickyBinding func(context.Context, SetStickyBindingRequest) (StickyBindingInfo, error)
|
||||
GetStickyBinding func(context.Context, GetStickyBindingRequest) (StickyBindingInfo, error)
|
||||
SetRouteFailure func(context.Context, SetRouteFailureRequest) (RouteFailureInfo, error)
|
||||
@@ -402,6 +403,9 @@ func NewAPIHandlerWithAuth(adminAuth AdminAuthConfig, actions ActionSet) http.Ha
|
||||
mux.Handle("POST /api/routing/resolve", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleResolveRoute(w, r, actions.ResolveRoute)
|
||||
})))
|
||||
mux.Handle("POST /api/routing/proxy/chat/completions", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleProxyRouteChatCompletions(w, r, actions.ProxyRouteChatCompletions)
|
||||
})))
|
||||
mux.Handle("POST /api/routing/sticky/bindings", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleSetStickyBinding(w, r, actions.SetStickyBinding)
|
||||
})))
|
||||
@@ -1230,6 +1234,7 @@ func NewActionSet(sqliteDSN string) ActionSet {
|
||||
func NewActionSetWithStickyRuntime(sqliteDSN string, stickyRuntime stickyStoreRuntime) ActionSet {
|
||||
routeLogWriter := newLazyRouteLogWriter(sqliteDSN)
|
||||
resolveRoute := buildResolveRouteAction(sqliteDSN, stickyRuntime, routeLogWriter)
|
||||
proxyRouteChatCompletions := buildProxyRouteChatCompletionsAction(sqliteDSN, resolveRoute, routeLogWriter)
|
||||
return ActionSet{
|
||||
CreateBatchImportRun: buildCreateBatchImportRunAction(sqliteDSN),
|
||||
ListBatchImportRuns: buildListBatchImportRunsAction(sqliteDSN),
|
||||
@@ -1257,6 +1262,7 @@ func NewActionSetWithStickyRuntime(sqliteDSN string, stickyRuntime stickyStoreRu
|
||||
AppendRouteStickyAudit: buildAppendRouteStickyAuditAction(routeLogWriter, sqliteDSN),
|
||||
ListRouteStickyAudit: buildListRouteStickyAuditAction(sqliteDSN),
|
||||
ResolveRoute: resolveRoute,
|
||||
ProxyRouteChatCompletions: proxyRouteChatCompletions,
|
||||
SetStickyBinding: buildSetStickyBindingAction(stickyRuntime),
|
||||
GetStickyBinding: buildGetStickyBindingAction(stickyRuntime),
|
||||
SetRouteFailure: buildSetRouteFailureAction(stickyRuntime),
|
||||
|
||||
338
internal/app/route_proxy_api.go
Normal file
338
internal/app/route_proxy_api.go
Normal file
@@ -0,0 +1,338 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/routing"
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
// routeChatCompletionsPath is the path appended to a shadow host's base URL
// when forwarding a chat completion; it is also reported back as UpstreamPath.
const routeChatCompletionsPath = "/v1/chat/completions"
|
||||
|
||||
type ProxyRouteChatCompletionsRequest struct {
|
||||
RequestID string `json:"request_id,omitempty"`
|
||||
LogicalGroupID string `json:"logical_group_id"`
|
||||
PublicModel string `json:"public_model"`
|
||||
Scope string `json:"scope"`
|
||||
SubjectID string `json:"subject_id"`
|
||||
UserKey string `json:"user_key,omitempty"`
|
||||
ConversationKey string `json:"conversation_key,omitempty"`
|
||||
GatewayAPIKey string `json:"gateway_api_key"`
|
||||
Messages []ChatCompletionMessage `json:"messages,omitempty"`
|
||||
MaxTokens int `json:"max_tokens,omitempty"`
|
||||
Temperature *float64 `json:"temperature,omitempty"`
|
||||
Sync bool `json:"sync,omitempty"`
|
||||
}
|
||||
|
||||
// ChatCompletionMessage is one role/content pair of a chat completion request.
type ChatCompletionMessage struct {
	// Role is the speaker of the message; a blank role is sent upstream as "user".
	Role string `json:"role"`
	// Content is the message text; it is whitespace-trimmed before forwarding.
	Content string `json:"content"`
}
|
||||
|
||||
type ProxyRouteChatCompletionsResult struct {
|
||||
Resolve ResolveRouteInfo `json:"resolve"`
|
||||
Forward RouteChatCompletionsForwardInfo `json:"forward"`
|
||||
}
|
||||
|
||||
// RouteChatCompletionsForwardInfo reports what happened when a chat completion
// was forwarded to a shadow host. Failures are described through ErrorClass
// and ErrorMessage rather than through a Go error.
type RouteChatCompletionsForwardInfo struct {
	// OK is true only when the upstream answered with a 2xx status.
	OK bool `json:"ok"`
	// HostID, HostBaseURL, ShadowGroupID and ShadowModel identify the target
	// the request was sent to.
	HostID        string `json:"host_id"`
	HostBaseURL   string `json:"host_base_url"`
	ShadowGroupID string `json:"shadow_group_id"`
	ShadowModel   string `json:"shadow_model"`
	// UpstreamPath is the path requested on the shadow host.
	UpstreamPath string `json:"upstream_path"`
	// UpstreamStatus is the upstream HTTP status code; it stays 0 when no
	// response was received.
	UpstreamStatus int `json:"upstream_status"`
	// LatencyMS is the elapsed time, in milliseconds, of the upstream HTTP call.
	LatencyMS int64 `json:"latency_ms"`
	// ContentType is the upstream Content-Type header, trimmed.
	ContentType string `json:"content_type,omitempty"`
	// ErrorClass is a short machine-readable failure label; empty on success.
	ErrorClass string `json:"error_class,omitempty"`
	// ErrorMessage carries the underlying error text for local failures.
	ErrorMessage string `json:"error_message,omitempty"`
	// Response is the upstream body: decoded JSON when it parses, otherwise
	// the trimmed raw text.
	Response any `json:"response,omitempty"`
}
|
||||
|
||||
func handleProxyRouteChatCompletions(w http.ResponseWriter, r *http.Request, fn func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error)) {
|
||||
if fn == nil {
|
||||
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "proxy-route-chat-completions action is not configured"})
|
||||
return
|
||||
}
|
||||
var req ProxyRouteChatCompletionsRequest
|
||||
if err := decodeJSON(r, &req); err != nil {
|
||||
writeHTTPError(w, err)
|
||||
return
|
||||
}
|
||||
result, err := fn(r.Context(), req)
|
||||
if err != nil {
|
||||
writeHTTPError(w, classifyError(err))
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, result)
|
||||
}
|
||||
|
||||
func buildProxyRouteChatCompletionsAction(
|
||||
sqliteDSN string,
|
||||
resolveRoute func(context.Context, ResolveRouteRequest) (ResolveRouteInfo, error),
|
||||
writerSource *lazyRouteLogWriter,
|
||||
) func(context.Context, ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
|
||||
return func(ctx context.Context, req ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
|
||||
req.GatewayAPIKey = strings.TrimSpace(req.GatewayAPIKey)
|
||||
if req.GatewayAPIKey == "" {
|
||||
return ProxyRouteChatCompletionsResult{}, fmt.Errorf("gateway_api_key is required")
|
||||
}
|
||||
|
||||
resolveInfo, err := resolveRoute(ctx, ResolveRouteRequest{
|
||||
RequestID: req.RequestID,
|
||||
LogicalGroupID: req.LogicalGroupID,
|
||||
PublicModel: req.PublicModel,
|
||||
Scope: req.Scope,
|
||||
SubjectID: req.SubjectID,
|
||||
UserKey: req.UserKey,
|
||||
ConversationKey: req.ConversationKey,
|
||||
Sync: req.Sync,
|
||||
})
|
||||
if err != nil {
|
||||
return ProxyRouteChatCompletionsResult{}, err
|
||||
}
|
||||
|
||||
store, err := sqlite.Open(ctx, sqliteDSN)
|
||||
if err != nil {
|
||||
return ProxyRouteChatCompletionsResult{}, err
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
hostRow, err := store.Hosts().GetByHostID(ctx, strings.TrimSpace(resolveInfo.ShadowHostID))
|
||||
if err != nil {
|
||||
return ProxyRouteChatCompletionsResult{}, fmt.Errorf("get shadow host %q: %w", resolveInfo.ShadowHostID, err)
|
||||
}
|
||||
|
||||
shadowModel := strings.TrimSpace(resolveInfo.ShadowModel)
|
||||
if shadowModel == "" {
|
||||
shadowModel = strings.TrimSpace(resolveInfo.PublicModel)
|
||||
}
|
||||
|
||||
forward := proxyChatCompletionToShadowHost(ctx, hostRow.BaseURL, req.GatewayAPIKey, shadowModel, req.Messages, req.MaxTokens, req.Temperature)
|
||||
forward.HostID = strings.TrimSpace(hostRow.HostID)
|
||||
forward.HostBaseURL = strings.TrimSpace(hostRow.BaseURL)
|
||||
forward.ShadowGroupID = strings.TrimSpace(resolveInfo.ShadowGroupID)
|
||||
forward.ShadowModel = shadowModel
|
||||
|
||||
if err := appendProxyRouteDecisionLog(ctx, writerSource, req, resolveInfo, forward); err != nil {
|
||||
return ProxyRouteChatCompletionsResult{}, err
|
||||
}
|
||||
if req.Sync {
|
||||
writer, err := writerSource.get(ctx)
|
||||
if err != nil {
|
||||
return ProxyRouteChatCompletionsResult{}, err
|
||||
}
|
||||
if err := writer.Flush(ctx); err != nil {
|
||||
return ProxyRouteChatCompletionsResult{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return ProxyRouteChatCompletionsResult{
|
||||
Resolve: resolveInfo,
|
||||
Forward: forward,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func appendProxyRouteDecisionLog(
|
||||
ctx context.Context,
|
||||
writerSource *lazyRouteLogWriter,
|
||||
req ProxyRouteChatCompletionsRequest,
|
||||
resolveInfo ResolveRouteInfo,
|
||||
forward RouteChatCompletionsForwardInfo,
|
||||
) error {
|
||||
writer, err := writerSource.get(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return writer.AppendDecision(ctx, routing.RouteDecisionEvent{
|
||||
RequestID: strings.TrimSpace(resolveInfo.RequestID),
|
||||
LogicalGroupID: strings.TrimSpace(resolveInfo.LogicalGroupID),
|
||||
PublicModel: strings.TrimSpace(resolveInfo.PublicModel),
|
||||
UserKey: resolveProxyUserKey(req),
|
||||
ConversationKey: resolveProxyConversationKey(req),
|
||||
StickyKey: strings.TrimSpace(resolveInfo.StickyKey),
|
||||
StickyKeyType: strings.TrimSpace(resolveInfo.Scope),
|
||||
StickyHit: resolveInfo.StickyHit,
|
||||
SelectedRouteID: strings.TrimSpace(resolveInfo.RouteID),
|
||||
SelectedShadowGroupID: strings.TrimSpace(resolveInfo.ShadowGroupID),
|
||||
ErrorClass: strings.TrimSpace(forward.ErrorClass),
|
||||
UpstreamStatus: forward.UpstreamStatus,
|
||||
LatencyMS: int(forward.LatencyMS),
|
||||
})
|
||||
}
|
||||
|
||||
func proxyChatCompletionToShadowHost(
|
||||
ctx context.Context,
|
||||
baseURL, gatewayAPIKey, shadowModel string,
|
||||
messages []ChatCompletionMessage,
|
||||
maxTokens int,
|
||||
temperature *float64,
|
||||
) RouteChatCompletionsForwardInfo {
|
||||
info := RouteChatCompletionsForwardInfo{
|
||||
UpstreamPath: routeChatCompletionsPath,
|
||||
}
|
||||
|
||||
requestURL, err := joinRouteProxyPath(baseURL, routeChatCompletionsPath)
|
||||
if err != nil {
|
||||
info.ErrorClass = "invalid_host_base_url"
|
||||
info.ErrorMessage = err.Error()
|
||||
return info
|
||||
}
|
||||
|
||||
payload := map[string]any{
|
||||
"model": strings.TrimSpace(shadowModel),
|
||||
"messages": normalizeProxyChatMessages(messages),
|
||||
"max_tokens": normalizeProxyMaxTokens(maxTokens),
|
||||
"temperature": normalizeProxyTemperature(temperature),
|
||||
}
|
||||
|
||||
var body bytes.Buffer
|
||||
if err := json.NewEncoder(&body).Encode(payload); err != nil {
|
||||
info.ErrorClass = "encode_request_failed"
|
||||
info.ErrorMessage = err.Error()
|
||||
return info
|
||||
}
|
||||
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, &body)
|
||||
if err != nil {
|
||||
info.ErrorClass = "build_request_failed"
|
||||
info.ErrorMessage = err.Error()
|
||||
return info
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Accept", "application/json, text/event-stream")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+strings.TrimSpace(gatewayAPIKey))
|
||||
|
||||
startedAt := time.Now()
|
||||
resp, err := (&http.Client{Timeout: 20 * time.Second}).Do(httpReq)
|
||||
if err != nil {
|
||||
info.LatencyMS = time.Since(startedAt).Milliseconds()
|
||||
info.ErrorClass = "transport_error"
|
||||
info.ErrorMessage = err.Error()
|
||||
return info
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
info.LatencyMS = time.Since(startedAt).Milliseconds()
|
||||
info.UpstreamStatus = resp.StatusCode
|
||||
info.ContentType = strings.TrimSpace(resp.Header.Get("Content-Type"))
|
||||
|
||||
responseBody, readErr := io.ReadAll(resp.Body)
|
||||
if readErr != nil {
|
||||
info.ErrorClass = "read_response_failed"
|
||||
info.ErrorMessage = readErr.Error()
|
||||
return info
|
||||
}
|
||||
|
||||
info.OK = resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices
|
||||
info.Response = decodeProxyResponseBody(responseBody)
|
||||
if !info.OK {
|
||||
info.ErrorClass = classifyProxyUpstreamStatus(resp.StatusCode)
|
||||
}
|
||||
return info
|
||||
}
|
||||
|
||||
func normalizeProxyChatMessages(messages []ChatCompletionMessage) []map[string]string {
|
||||
if len(messages) == 0 {
|
||||
return []map[string]string{{"role": "user", "content": "ping"}}
|
||||
}
|
||||
|
||||
normalized := make([]map[string]string, 0, len(messages))
|
||||
for _, message := range messages {
|
||||
role := strings.TrimSpace(message.Role)
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
normalized = append(normalized, map[string]string{
|
||||
"role": role,
|
||||
"content": strings.TrimSpace(message.Content),
|
||||
})
|
||||
}
|
||||
return normalized
|
||||
}
|
||||
|
||||
// normalizeProxyMaxTokens returns maxTokens when it is positive and the
// default of 8 otherwise.
func normalizeProxyMaxTokens(maxTokens int) int {
	const fallback = 8
	if maxTokens > 0 {
		return maxTokens
	}
	return fallback
}
|
||||
|
||||
// normalizeProxyTemperature dereferences temperature, treating nil as 0.
func normalizeProxyTemperature(temperature *float64) float64 {
	if temperature != nil {
		return *temperature
	}
	return 0
}
|
||||
|
||||
// joinRouteProxyPath builds the absolute request URL for path on the host
// named by baseURL. Both inputs are whitespace-trimmed and path is made
// absolute, so it replaces any path already present on baseURL. An error is
// returned when baseURL does not parse or lacks a scheme or host.
func joinRouteProxyPath(baseURL, path string) (string, error) {
	base, parseErr := url.Parse(strings.TrimSpace(baseURL))
	switch {
	case parseErr != nil:
		return "", parseErr
	case base.Scheme == "", base.Host == "":
		return "", fmt.Errorf("base url must include scheme and host")
	}

	ref := url.URL{Path: strings.TrimSpace(path)}
	if !strings.HasPrefix(ref.Path, "/") {
		ref.Path = "/" + ref.Path
	}
	return base.ResolveReference(&ref).String(), nil
}
|
||||
|
||||
// decodeProxyResponseBody turns an upstream response body into a value that
// can be embedded in the proxy result: nil for a blank body, the decoded JSON
// value when the body parses as JSON, and the trimmed raw text otherwise.
func decodeProxyResponseBody(body []byte) any {
	raw := bytes.TrimSpace(body)
	if len(raw) == 0 {
		return nil
	}

	var decoded any
	if json.Unmarshal(raw, &decoded) != nil {
		// Not JSON: hand the text back verbatim.
		return string(raw)
	}
	return decoded
}
|
||||
|
||||
// classifyProxyUpstreamStatus maps a non-2xx upstream HTTP status to the
// error class recorded for a forward:
//
//   - 401 and 403      -> "gateway_auth_error"
//   - 429              -> "gateway_rate_limited"
//   - any 5xx          -> "gateway_5xx"
//   - other 3xx / 4xx  -> "gateway_<status>"
//
// Statuses below 300 yield the empty string.
func classifyProxyUpstreamStatus(statusCode int) string {
	switch {
	case statusCode == http.StatusUnauthorized, statusCode == http.StatusForbidden:
		return "gateway_auth_error"
	case statusCode == http.StatusTooManyRequests:
		return "gateway_rate_limited"
	case statusCode >= http.StatusInternalServerError:
		// The bucket starts at 500, not 502, so that 500 and 501 are also
		// reported as gateway_5xx.
		return "gateway_5xx"
	case statusCode >= http.StatusMultipleChoices:
		// Covers remaining 4xx plus unfollowed 3xx, so a forward that is not
		// OK never ends up with an empty error class.
		return fmt.Sprintf("gateway_%d", statusCode)
	default:
		return ""
	}
}
|
||||
|
||||
func resolveProxyUserKey(req ProxyRouteChatCompletionsRequest) string {
|
||||
if key := strings.TrimSpace(req.UserKey); key != "" {
|
||||
return key
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(req.Scope), "user") {
|
||||
return strings.TrimSpace(req.SubjectID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func resolveProxyConversationKey(req ProxyRouteChatCompletionsRequest) string {
|
||||
if key := strings.TrimSpace(req.ConversationKey); key != "" {
|
||||
return key
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(req.Scope), "conversation") {
|
||||
return strings.TrimSpace(req.SubjectID)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
276
internal/app/route_proxy_api_test.go
Normal file
276
internal/app/route_proxy_api_test.go
Normal file
@@ -0,0 +1,276 @@
|
||||
package app
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"sub2api-cn-relay-manager/internal/store/sqlite"
|
||||
)
|
||||
|
||||
func TestAPIProxyRouteChatCompletionsReturnsResolveAndForward(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
handler := NewAPIHandler("secret-token", ActionSet{
|
||||
ProxyRouteChatCompletions: func(_ context.Context, req ProxyRouteChatCompletionsRequest) (ProxyRouteChatCompletionsResult, error) {
|
||||
if req.LogicalGroupID != "gpt-shared" {
|
||||
t.Fatalf("LogicalGroupID = %q, want gpt-shared", req.LogicalGroupID)
|
||||
}
|
||||
if req.GatewayAPIKey != "gateway-key" {
|
||||
t.Fatalf("GatewayAPIKey = %q, want gateway-key", req.GatewayAPIKey)
|
||||
}
|
||||
return ProxyRouteChatCompletionsResult{
|
||||
Resolve: ResolveRouteInfo{
|
||||
RequestID: "req-proxy-1",
|
||||
Backend: "memory",
|
||||
LogicalGroupID: req.LogicalGroupID,
|
||||
PublicModel: req.PublicModel,
|
||||
StickyKey: "lg:gpt-shared:m:gpt-5.4:conv:conv-1",
|
||||
StickyHit: false,
|
||||
StickyAction: "bind",
|
||||
RouteID: "asxs",
|
||||
ShadowGroupID: "gpt-shared__asxs",
|
||||
ShadowHostID: "remote43",
|
||||
ShadowModel: "gpt-5.4-asxs",
|
||||
},
|
||||
Forward: RouteChatCompletionsForwardInfo{
|
||||
OK: true,
|
||||
HostID: "remote43",
|
||||
HostBaseURL: "https://sub2api.example.com",
|
||||
ShadowGroupID: "gpt-shared__asxs",
|
||||
ShadowModel: "gpt-5.4-asxs",
|
||||
UpstreamPath: "/v1/chat/completions",
|
||||
UpstreamStatus: 200,
|
||||
LatencyMS: 12,
|
||||
ContentType: "application/json",
|
||||
Response: map[string]any{
|
||||
"id": "chatcmpl_proxy",
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
})
|
||||
|
||||
request := httptestRequest(t, http.MethodPost, "/api/routing/proxy/chat/completions", map[string]any{
|
||||
"logical_group_id": "gpt-shared",
|
||||
"public_model": "gpt-5.4",
|
||||
"scope": "conversation",
|
||||
"subject_id": "conv-1",
|
||||
"gateway_api_key": "gateway-key",
|
||||
"sync": true,
|
||||
}, "secret-token")
|
||||
response := httptestRecorder(handler, request)
|
||||
assertStatusCode(t, response, http.StatusOK)
|
||||
assertJSONContains(t, response.Body().Bytes(), "resolve.route_id", "asxs")
|
||||
assertJSONContains(t, response.Body().Bytes(), "forward.shadow_model", "gpt-5.4-asxs")
|
||||
assertJSONContains(t, response.Body().Bytes(), "forward.upstream_status", float64(200))
|
||||
}
|
||||
|
||||
func TestNewActionSetProxyRouteChatCompletionsFlow(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
gotAuthHeader string
|
||||
gotModel string
|
||||
gotPrompt string
|
||||
)
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.URL.Path != "/v1/chat/completions" {
|
||||
t.Fatalf("URL.Path = %q, want /v1/chat/completions", r.URL.Path)
|
||||
}
|
||||
gotAuthHeader = r.Header.Get("Authorization")
|
||||
|
||||
var payload struct {
|
||||
Model string `json:"model"`
|
||||
Messages []struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"messages"`
|
||||
}
|
||||
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
||||
t.Fatalf("json.Decode() error = %v", err)
|
||||
}
|
||||
gotModel = payload.Model
|
||||
if len(payload.Messages) > 0 {
|
||||
gotPrompt = payload.Messages[0].Content
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]any{
|
||||
"id": "chatcmpl_proxy",
|
||||
"object": "chat.completion",
|
||||
"choices": []map[string]any{
|
||||
{
|
||||
"index": 0,
|
||||
"message": map[string]any{
|
||||
"role": "assistant",
|
||||
"content": "pong",
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
dsn := "file:" + filepath.ToSlash(filepath.Join(t.TempDir(), "route-proxy.db")) + "?_busy_timeout=5000"
|
||||
actions := NewActionSet(dsn)
|
||||
ctx := context.Background()
|
||||
|
||||
store, err := sqlite.Open(ctx, dsn)
|
||||
if err != nil {
|
||||
t.Fatalf("sqlite.Open() error = %v", err)
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
if _, err := store.Hosts().Create(ctx, sqlite.Host{
|
||||
HostID: "remote43",
|
||||
BaseURL: server.URL,
|
||||
HostVersion: "0.1.126",
|
||||
AuthType: "apikey",
|
||||
AuthToken: "host-admin-token",
|
||||
}); err != nil {
|
||||
t.Fatalf("Hosts().Create() error = %v", err)
|
||||
}
|
||||
if _, err := actions.CreateLogicalGroup(ctx, CreateLogicalGroupRequest{
|
||||
LogicalGroupID: "gpt-shared",
|
||||
DisplayName: "GPT Shared",
|
||||
Status: "active",
|
||||
RoutePolicy: "priority",
|
||||
StickyMode: "conversation_preferred",
|
||||
ConversationTTLSeconds: 1200,
|
||||
UserModelTTLSeconds: 600,
|
||||
FailoverThreshold: 2,
|
||||
CooldownSeconds: 300,
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateLogicalGroup() error = %v", err)
|
||||
}
|
||||
if _, err := actions.CreateLogicalGroupModel(ctx, CreateLogicalGroupModelRequest{
|
||||
LogicalGroupID: "gpt-shared",
|
||||
PublicModel: "gpt-5.4",
|
||||
Status: "active",
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateLogicalGroupModel() error = %v", err)
|
||||
}
|
||||
if _, err := actions.CreateLogicalGroupRoute(ctx, CreateLogicalGroupRouteRequest{
|
||||
LogicalGroupID: "gpt-shared",
|
||||
RouteID: "asxs",
|
||||
Name: "ASXS",
|
||||
Status: "active",
|
||||
Priority: 10,
|
||||
ShadowGroupID: "gpt-shared__asxs",
|
||||
ShadowHostID: "remote43",
|
||||
UpstreamBaseURLHint: "https://api.asxs.top/v1",
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateLogicalGroupRoute() error = %v", err)
|
||||
}
|
||||
if _, err := actions.CreateLogicalGroupRouteModel(ctx, CreateLogicalGroupRouteModelRequest{
|
||||
LogicalGroupID: "gpt-shared",
|
||||
RouteID: "asxs",
|
||||
PublicModel: "gpt-5.4",
|
||||
ShadowModel: "gpt-5.4-asxs",
|
||||
Status: "active",
|
||||
}); err != nil {
|
||||
t.Fatalf("CreateLogicalGroupRouteModel() error = %v", err)
|
||||
}
|
||||
|
||||
result, err := actions.ProxyRouteChatCompletions(ctx, ProxyRouteChatCompletionsRequest{
|
||||
RequestID: "req-proxy-1",
|
||||
LogicalGroupID: "gpt-shared",
|
||||
PublicModel: "gpt-5.4",
|
||||
Scope: "conversation",
|
||||
SubjectID: "conv-1",
|
||||
GatewayAPIKey: "gateway-key",
|
||||
Sync: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ProxyRouteChatCompletions() error = %v", err)
|
||||
}
|
||||
if gotAuthHeader != "Bearer gateway-key" {
|
||||
t.Fatalf("Authorization header = %q, want Bearer gateway-key", gotAuthHeader)
|
||||
}
|
||||
if gotModel != "gpt-5.4-asxs" {
|
||||
t.Fatalf("forwarded model = %q, want gpt-5.4-asxs", gotModel)
|
||||
}
|
||||
if gotPrompt != "ping" {
|
||||
t.Fatalf("forwarded prompt = %q, want ping", gotPrompt)
|
||||
}
|
||||
if result.Resolve.RouteID != "asxs" || result.Resolve.ShadowModel != "gpt-5.4-asxs" {
|
||||
t.Fatalf("Resolve = %+v, want selected asxs route with shadow model", result.Resolve)
|
||||
}
|
||||
if !result.Forward.OK || result.Forward.UpstreamStatus != http.StatusOK {
|
||||
t.Fatalf("Forward = %+v, want successful 200 forward", result.Forward)
|
||||
}
|
||||
if result.Forward.HostID != "remote43" || result.Forward.HostBaseURL != server.URL {
|
||||
t.Fatalf("Forward = %+v, want host remote43 and server URL", result.Forward)
|
||||
}
|
||||
|
||||
decisions, err := actions.ListRouteDecisionLogs(ctx, ListRouteDecisionLogsRequest{
|
||||
RequestID: "req-proxy-1",
|
||||
Limit: 10,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ListRouteDecisionLogs() error = %v", err)
|
||||
}
|
||||
if len(decisions) != 2 {
|
||||
t.Fatalf("ListRouteDecisionLogs() len = %d, want 2", len(decisions))
|
||||
}
|
||||
if decisions[0].UpstreamStatus != http.StatusOK || decisions[0].SelectedRouteID != "asxs" {
|
||||
t.Fatalf("latest decision log = %+v, want upstream_status 200 on asxs", decisions[0])
|
||||
}
|
||||
if decisions[1].UpstreamStatus != 0 {
|
||||
t.Fatalf("initial decision log = %+v, want upstream_status 0 before forward", decisions[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyChatCompletionToShadowHostReportsNon2xx(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusTooManyRequests, map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": "rate limited",
|
||||
},
|
||||
})
|
||||
}))
|
||||
defer server.Close()
|
||||
|
||||
info := proxyChatCompletionToShadowHost(context.Background(), server.URL, "gateway-key", "gpt-5.4-asxs", nil, 0, nil)
|
||||
if info.OK {
|
||||
t.Fatalf("proxyChatCompletionToShadowHost() = %+v, want non-ok result", info)
|
||||
}
|
||||
if info.UpstreamStatus != http.StatusTooManyRequests || info.ErrorClass != "gateway_rate_limited" {
|
||||
t.Fatalf("proxyChatCompletionToShadowHost() = %+v, want 429 gateway_rate_limited", info)
|
||||
}
|
||||
response, ok := info.Response.(map[string]any)
|
||||
if !ok || response["error"] == nil {
|
||||
t.Fatalf("proxyChatCompletionToShadowHost() response = %#v, want decoded json body", info.Response)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRouteProxyHelpers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
if got := normalizeProxyMaxTokens(0); got != 8 {
|
||||
t.Fatalf("normalizeProxyMaxTokens(0) = %d, want 8", got)
|
||||
}
|
||||
if got := normalizeProxyTemperature(nil); got != 0 {
|
||||
t.Fatalf("normalizeProxyTemperature(nil) = %v, want 0", got)
|
||||
}
|
||||
if got := normalizeProxyChatMessages(nil); len(got) != 1 || got[0]["content"] != "ping" {
|
||||
t.Fatalf("normalizeProxyChatMessages(nil) = %#v, want default ping message", got)
|
||||
}
|
||||
if got := classifyProxyUpstreamStatus(http.StatusForbidden); got != "gateway_auth_error" {
|
||||
t.Fatalf("classifyProxyUpstreamStatus(403) = %q, want gateway_auth_error", got)
|
||||
}
|
||||
if _, err := joinRouteProxyPath("://bad-url", routeChatCompletionsPath); err == nil {
|
||||
t.Fatal("joinRouteProxyPath(invalid) error = nil, want error")
|
||||
}
|
||||
if got := resolveProxyUserKey(ProxyRouteChatCompletionsRequest{Scope: "user", SubjectID: "user-1"}); got != "user-1" {
|
||||
t.Fatalf("resolveProxyUserKey(user) = %q, want user-1", got)
|
||||
}
|
||||
if got := resolveProxyConversationKey(ProxyRouteChatCompletionsRequest{Scope: "conversation", SubjectID: "conv-1"}); got != "conv-1" {
|
||||
t.Fatalf("resolveProxyConversationKey(conversation) = %q, want conv-1", got)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user