From 98bd619ec8cc58db9d587f936a5ad348b153f7ff Mon Sep 17 00:00:00 2001 From: phamnazage-jpg Date: Fri, 29 May 2026 07:43:29 +0800 Subject: [PATCH] feat(routing): add sticky runtime backends --- .env.example | 4 + internal/app/bootstrap.go | 6 +- internal/app/http_api.go | 34 ++ internal/app/route_sticky_api.go | 469 ++++++++++++++++++++++++++ internal/app/route_sticky_api_test.go | 262 ++++++++++++++ internal/config/config.go | 53 ++- internal/config/config_test.go | 37 ++ internal/routing/sticky.go | 199 +++++++++++ internal/routing/sticky_memory.go | 175 ++++++++++ internal/routing/sticky_redis.go | 352 +++++++++++++++++++ internal/routing/sticky_test.go | 460 +++++++++++++++++++++++++ 11 files changed, 2046 insertions(+), 5 deletions(-) create mode 100644 internal/app/route_sticky_api.go create mode 100644 internal/app/route_sticky_api_test.go create mode 100644 internal/routing/sticky.go create mode 100644 internal/routing/sticky_memory.go create mode 100644 internal/routing/sticky_redis.go create mode 100644 internal/routing/sticky_test.go diff --git a/.env.example b/.env.example index 72f03dc5..d15dfbf9 100644 --- a/.env.example +++ b/.env.example @@ -4,3 +4,7 @@ SUB2API_CRM_ADMIN_TOKEN=change-me-before-production SUB2API_CRM_ADMIN_USERNAME=admin SUB2API_CRM_ADMIN_PASSWORD=change-me-before-production SUB2API_CRM_ADMIN_SESSION_TTL=12h +SUB2API_CRM_ROUTE_RUNTIME_BACKEND=memory +SUB2API_CRM_REDIS_ADDR= +SUB2API_CRM_REDIS_PASSWORD= +SUB2API_CRM_REDIS_DB=0 diff --git a/internal/app/bootstrap.go b/internal/app/bootstrap.go index 17bdaa6d..e802ae96 100644 --- a/internal/app/bootstrap.go +++ b/internal/app/bootstrap.go @@ -20,13 +20,17 @@ func Bootstrap(ctx context.Context) (*Server, error) { if err != nil { return nil, err } + stickyRuntime, err := newStickyStoreRuntime(ctx, cfg.RouteRuntime) + if err != nil { + return nil, err + } startBackgroundSchedulers(ctx, cfg, defaultBackgroundSchedulers()) handler := NewAPIHandlerWithAuth(AdminAuthConfig{ Token: adminToken, Username: adminSession.Username, Password: adminSession.Password, SessionTTL: adminSession.SessionTTL, - }, NewActionSet(cfg.Database.SQLiteDSN)) + }, NewActionSetWithStickyRuntime(cfg.Database.SQLiteDSN, stickyRuntime)) return NewServer(cfg.Server.ListenAddr, handler, nil), nil } diff --git a/internal/app/http_api.go b/internal/app/http_api.go index 427c214c..2d042638 100644 --- a/internal/app/http_api.go +++ b/internal/app/http_api.go @@ -51,6 +51,12 @@ type ActionSet struct { ListRouteFailoverEvents func(context.Context, ListRouteFailoverEventsRequest) ([]RouteFailoverEventInfo, error) AppendRouteStickyAudit func(context.Context, AppendRouteStickyAuditRequest) (RouteStickyAuditInfo, error) ListRouteStickyAudit func(context.Context, ListRouteStickyAuditRequest) ([]RouteStickyAuditInfo, error) + SetStickyBinding func(context.Context, SetStickyBindingRequest) (StickyBindingInfo, error) + GetStickyBinding func(context.Context, GetStickyBindingRequest) (StickyBindingInfo, error) + SetRouteFailure func(context.Context, SetRouteFailureRequest) (RouteFailureInfo, error) + GetRouteFailure func(context.Context, GetRouteFailureRequest) (RouteFailureInfo, error) + SetRouteCooldown func(context.Context, SetRouteCooldownRequest) (RouteCooldownInfo, error) + GetRouteCooldown func(context.Context, GetRouteCooldownRequest) (RouteCooldownInfo, error) CreateProviderDraft func(context.Context, CreateProviderDraftRequest) (ProviderDraftInfo, error) ListProviderDrafts func(context.Context, ListProviderDraftsRequest) ([]ProviderDraftInfo, error) GetProviderDraft func(context.Context, string) (ProviderDraftInfo, error) @@ -392,6 +398,24 @@ func NewAPIHandlerWithAuth(adminAuth AdminAuthConfig, actions ActionSet) http.Ha mux.Handle("GET /api/routing/logs/sticky-audit", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handleListRouteStickyAudit(w, r, actions.ListRouteStickyAudit) }))) + mux.Handle("POST /api/routing/sticky/bindings", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleSetStickyBinding(w, r, actions.SetStickyBinding) + }))) + mux.Handle("GET /api/routing/sticky/bindings", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleGetStickyBinding(w, r, actions.GetStickyBinding) + }))) + mux.Handle("POST /api/routing/sticky/route-failures", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleSetRouteFailure(w, r, actions.SetRouteFailure) + }))) + mux.Handle("GET /api/routing/sticky/route-failures", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleGetRouteFailure(w, r, actions.GetRouteFailure) + }))) + mux.Handle("POST /api/routing/sticky/cooldowns", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleSetRouteCooldown(w, r, actions.SetRouteCooldown) + }))) + mux.Handle("GET /api/routing/sticky/cooldowns", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handleGetRouteCooldown(w, r, actions.GetRouteCooldown) + }))) mux.Handle("POST /api/provider-drafts", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handleCreateProviderDraft(w, r, actions.CreateProviderDraft) }))) @@ -1196,6 +1220,10 @@ func classifyError(err error) *httpError { } func NewActionSet(sqliteDSN string) ActionSet { + return NewActionSetWithStickyRuntime(sqliteDSN, defaultStickyStoreRuntime()) +} + +func NewActionSetWithStickyRuntime(sqliteDSN string, stickyRuntime stickyStoreRuntime) ActionSet { routeLogWriter := newLazyRouteLogWriter(sqliteDSN) return ActionSet{ CreateBatchImportRun: buildCreateBatchImportRunAction(sqliteDSN), @@ -1223,6 +1251,12 @@ func NewActionSet(sqliteDSN string) ActionSet { ListRouteFailoverEvents: buildListRouteFailoverEventsAction(sqliteDSN), AppendRouteStickyAudit: buildAppendRouteStickyAuditAction(routeLogWriter, sqliteDSN), ListRouteStickyAudit: buildListRouteStickyAuditAction(sqliteDSN), + SetStickyBinding: buildSetStickyBindingAction(stickyRuntime), + GetStickyBinding: buildGetStickyBindingAction(stickyRuntime), + SetRouteFailure: buildSetRouteFailureAction(stickyRuntime), + GetRouteFailure: buildGetRouteFailureAction(stickyRuntime), + SetRouteCooldown: buildSetRouteCooldownAction(stickyRuntime), + GetRouteCooldown: buildGetRouteCooldownAction(stickyRuntime), CreateProviderDraft: func(ctx context.Context, req CreateProviderDraftRequest) (ProviderDraftInfo, error) { store, err := sqlite.Open(ctx, sqliteDSN) if err != nil { diff --git a/internal/app/route_sticky_api.go b/internal/app/route_sticky_api.go new file mode 100644 index 00000000..15fd1383 --- /dev/null +++ b/internal/app/route_sticky_api.go @@ -0,0 +1,469 @@ +package app + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "sub2api-cn-relay-manager/internal/config" + "sub2api-cn-relay-manager/internal/routing" +) + +const defaultStickyTTLSeconds = 600 + +type stickyStoreRuntime struct { + backend string + store routing.StickyStore +} + +type SetStickyBindingRequest struct { + Scope string `json:"scope"` + LogicalGroupID string `json:"logical_group_id"` + PublicModel string `json:"public_model"` + SubjectID string `json:"subject_id"` + RouteID string `json:"route_id"` + ShadowGroupID string `json:"shadow_group_id"` + TTLSeconds int `json:"ttl_seconds,omitempty"` +} + +type GetStickyBindingRequest struct { + Scope string + LogicalGroupID string + PublicModel string + SubjectID string +} + +type StickyBindingInfo struct { + Backend string `json:"backend"` + Key string `json:"key"` + Scope string `json:"scope"` + LogicalGroupID string `json:"logical_group_id"` + PublicModel string `json:"public_model"` + SubjectID string `json:"subject_id"` + RouteID string `json:"route_id"` + ShadowGroupID string `json:"shadow_group_id"` + BoundAt string `json:"bound_at,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` +} + +type SetRouteFailureRequest struct { + RouteID string `json:"route_id"` + FailureCount int `json:"failure_count"` + LastErrorClass string `json:"last_error_class,omitempty"` + TTLSeconds int `json:"ttl_seconds,omitempty"` +} + +type RouteFailureInfo struct { + Backend string `json:"backend"` + Key string `json:"key"` + RouteID string `json:"route_id"` + FailureCount int `json:"failure_count"` + LastErrorClass string `json:"last_error_class,omitempty"` + LastFailureAt string `json:"last_failure_at,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` +} + +type GetRouteFailureRequest struct { + RouteID string +} + +type SetRouteCooldownRequest struct { + RouteID string `json:"route_id"` + Reason string `json:"reason,omitempty"` + TTLSeconds int `json:"ttl_seconds,omitempty"` +} + +type GetRouteCooldownRequest struct { + RouteID string +} + +type RouteCooldownInfo struct { + Backend string `json:"backend"` + Key string `json:"key"` + RouteID string `json:"route_id"` + Reason string `json:"reason,omitempty"` + Until string `json:"until,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` +} + +func newStickyStoreRuntime(ctx context.Context, cfg config.RouteRuntimeConfig) (stickyStoreRuntime, error) { + backend, err := routing.NormalizeRuntimeBackend(cfg.Backend) + if err != nil { + return stickyStoreRuntime{}, err + } + + switch backend { + case routing.RuntimeBackendMemory: + return stickyStoreRuntime{ + backend: routing.RuntimeBackendMemory, + store: routing.NewInMemoryStickyStore(), + }, nil + case routing.RuntimeBackendRedis: + store, err := routing.NewRedisStickyStore(ctx, routing.RedisConfig{ + Addr: cfg.Redis.Addr, + Password: cfg.Redis.Password, + DB: cfg.Redis.DB, + }) + if err != nil { + return stickyStoreRuntime{}, err + } + return stickyStoreRuntime{ + backend: routing.RuntimeBackendRedis, + store: store, + }, nil + default: + return stickyStoreRuntime{}, fmt.Errorf("unsupported sticky backend %q", backend) + } +} + +func defaultStickyStoreRuntime() stickyStoreRuntime { + return stickyStoreRuntime{ + backend: routing.RuntimeBackendMemory, + store: routing.NewInMemoryStickyStore(), + } +} + +func handleSetStickyBinding(w http.ResponseWriter, r *http.Request, fn func(context.Context, SetStickyBindingRequest) (StickyBindingInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "set-sticky-binding action is not configured"}) + return + } + var req SetStickyBindingRequest + if err := decodeJSON(r, &req); err != nil { + writeHTTPError(w, err) + return + } + info, err := fn(r.Context(), req) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusCreated, map[string]any{"binding": info}) +} + +func handleGetStickyBinding(w http.ResponseWriter, r *http.Request, fn func(context.Context, GetStickyBindingRequest) (StickyBindingInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "get-sticky-binding action is not configured"}) + return + } + req, err := decodeGetStickyBindingRequest(r) + if err != nil { + writeHTTPError(w, err) + return + } + info, actionErr := fn(r.Context(), req) + if actionErr != nil { + writeHTTPError(w, classifyError(actionErr)) + return + } + writeJSON(w, http.StatusOK, map[string]any{"binding": info}) +} + +func handleSetRouteFailure(w http.ResponseWriter, r *http.Request, fn func(context.Context, SetRouteFailureRequest) (RouteFailureInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "set-route-failure action is not configured"}) + return + } + var req SetRouteFailureRequest + if err := decodeJSON(r, &req); err != nil { + writeHTTPError(w, err) + return + } + info, err := fn(r.Context(), req) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusCreated, map[string]any{"route_failure": info}) +} + +func handleGetRouteFailure(w http.ResponseWriter, r *http.Request, fn func(context.Context, GetRouteFailureRequest) (RouteFailureInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "get-route-failure action is not configured"}) + return + } + req, err := decodeGetRouteFailureRequest(r) + if err != nil { + writeHTTPError(w, err) + return + } + info, actionErr := fn(r.Context(), req) + if actionErr != nil { + writeHTTPError(w, classifyError(actionErr)) + return + } + writeJSON(w, http.StatusOK, map[string]any{"route_failure": info}) +} + +func handleSetRouteCooldown(w http.ResponseWriter, r *http.Request, fn func(context.Context, SetRouteCooldownRequest) (RouteCooldownInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "set-route-cooldown action is not configured"}) + return + } + var req SetRouteCooldownRequest + if err := decodeJSON(r, &req); err != nil { + writeHTTPError(w, err) + return + } + info, err := fn(r.Context(), req) + if err != nil { + writeHTTPError(w, classifyError(err)) + return + } + writeJSON(w, http.StatusCreated, map[string]any{"route_cooldown": info}) +} + +func handleGetRouteCooldown(w http.ResponseWriter, r *http.Request, fn func(context.Context, GetRouteCooldownRequest) (RouteCooldownInfo, error)) { + if fn == nil { + writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "get-route-cooldown action is not configured"}) + return + } + req, err := decodeGetRouteCooldownRequest(r) + if err != nil { + writeHTTPError(w, err) + return + } + info, actionErr := fn(r.Context(), req) + if actionErr != nil { + writeHTTPError(w, classifyError(actionErr)) + return + } + writeJSON(w, http.StatusOK, map[string]any{"route_cooldown": info}) +} + +func buildSetStickyBindingAction(runtime stickyStoreRuntime) func(context.Context, SetStickyBindingRequest) (StickyBindingInfo, error) { + return func(ctx context.Context, req SetStickyBindingRequest) (StickyBindingInfo, error) { + key, err := routing.BuildStickyKey(req.Scope, req.LogicalGroupID, req.PublicModel, req.SubjectID) + if err != nil { + return StickyBindingInfo{}, err + } + ttl, err := secondsToDuration(req.TTLSeconds, defaultStickyTTLSeconds) + if err != nil { + return StickyBindingInfo{}, err + } + binding := routing.StickyBinding{ + LogicalGroupID: strings.TrimSpace(req.LogicalGroupID), + PublicModel: strings.TrimSpace(req.PublicModel), + RouteID: strings.TrimSpace(req.RouteID), + ShadowGroupID: strings.TrimSpace(req.ShadowGroupID), + } + if err := runtime.store.Set(ctx, key, binding, ttl); err != nil { + return StickyBindingInfo{}, err + } + stored, ok, err := runtime.store.Get(ctx, key) + if err != nil { + return StickyBindingInfo{}, err + } + if !ok { + return StickyBindingInfo{}, fmt.Errorf("sticky binding %q not found after set", key) + } + return stickyBindingToInfo(runtime.backend, req.Scope, req.SubjectID, key, stored), nil + } +} + +func buildGetStickyBindingAction(runtime stickyStoreRuntime) func(context.Context, GetStickyBindingRequest) (StickyBindingInfo, error) { + return func(ctx context.Context, req GetStickyBindingRequest) (StickyBindingInfo, error) { + key, err := routing.BuildStickyKey(req.Scope, req.LogicalGroupID, req.PublicModel, req.SubjectID) + if err != nil { + return StickyBindingInfo{}, err + } + stored, ok, err := runtime.store.Get(ctx, key) + if err != nil { + return StickyBindingInfo{}, err + } + if !ok { + return StickyBindingInfo{}, fmt.Errorf("sticky binding %q not found", key) + } + return stickyBindingToInfo(runtime.backend, req.Scope, req.SubjectID, key, stored), nil + } +} + +func buildSetRouteFailureAction(runtime stickyStoreRuntime) func(context.Context, SetRouteFailureRequest) (RouteFailureInfo, error) { + return func(ctx context.Context, req SetRouteFailureRequest) (RouteFailureInfo, error) { + key, err := routing.BuildRouteFailureKey(req.RouteID) + if err != nil { + return RouteFailureInfo{}, err + } + ttl, err := secondsToDuration(req.TTLSeconds, defaultStickyTTLSeconds) + if err != nil { + return RouteFailureInfo{}, err + } + state := routing.RouteFailureState{ + RouteID: strings.TrimSpace(req.RouteID), + FailureCount: req.FailureCount, + LastErrorClass: strings.TrimSpace(req.LastErrorClass), + } + if err := runtime.store.SetRouteFailure(ctx, req.RouteID, state, ttl); err != nil { + return RouteFailureInfo{}, err + } + stored, ok, err := runtime.store.GetRouteFailure(ctx, req.RouteID) + if err != nil { + return RouteFailureInfo{}, err + } + if !ok { + return RouteFailureInfo{}, fmt.Errorf("route failure %q not found after set", req.RouteID) + } + return routeFailureToInfo(runtime.backend, key, stored), nil + } +} + +func buildGetRouteFailureAction(runtime stickyStoreRuntime) func(context.Context, GetRouteFailureRequest) (RouteFailureInfo, error) { + return func(ctx context.Context, req GetRouteFailureRequest) (RouteFailureInfo, error) { + key, err := routing.BuildRouteFailureKey(req.RouteID) + if err != nil { + return RouteFailureInfo{}, err + } + stored, ok, err := runtime.store.GetRouteFailure(ctx, req.RouteID) + if err != nil { + return RouteFailureInfo{}, err + } + if !ok { + return RouteFailureInfo{}, fmt.Errorf("route failure %q not found", req.RouteID) + } + return routeFailureToInfo(runtime.backend, key, stored), nil + } +} + +func buildSetRouteCooldownAction(runtime stickyStoreRuntime) func(context.Context, SetRouteCooldownRequest) (RouteCooldownInfo, error) { + return func(ctx context.Context, req SetRouteCooldownRequest) (RouteCooldownInfo, error) { + key, err := routing.BuildRouteCooldownKey(req.RouteID) + if err != nil { + return RouteCooldownInfo{}, err + } + ttl, err := secondsToDuration(req.TTLSeconds, defaultStickyTTLSeconds) + if err != nil { + return RouteCooldownInfo{}, err + } + state := routing.RouteCooldownState{ + RouteID: strings.TrimSpace(req.RouteID), + Reason: strings.TrimSpace(req.Reason), + } + if err := runtime.store.SetCooldown(ctx, req.RouteID, state, ttl); err != nil { + return RouteCooldownInfo{}, err + } + stored, ok, err := runtime.store.GetCooldown(ctx, req.RouteID) + if err != nil { + return RouteCooldownInfo{}, err + } + if !ok { + return RouteCooldownInfo{}, fmt.Errorf("route cooldown %q not found after set", req.RouteID) + } + return routeCooldownToInfo(runtime.backend, key, stored), nil + } +} + +func buildGetRouteCooldownAction(runtime stickyStoreRuntime) func(context.Context, GetRouteCooldownRequest) (RouteCooldownInfo, error) { + return func(ctx context.Context, req GetRouteCooldownRequest) (RouteCooldownInfo, error) { + key, err := routing.BuildRouteCooldownKey(req.RouteID) + if err != nil { + return RouteCooldownInfo{}, err + } + stored, ok, err := runtime.store.GetCooldown(ctx, req.RouteID) + if err != nil { + return RouteCooldownInfo{}, err + } + if !ok { + return RouteCooldownInfo{}, fmt.Errorf("route cooldown %q not found", req.RouteID) + } + return routeCooldownToInfo(runtime.backend, key, stored), nil + } +} + +func decodeGetStickyBindingRequest(r *http.Request) (GetStickyBindingRequest, *httpError) { + req := GetStickyBindingRequest{ + Scope: strings.TrimSpace(r.URL.Query().Get("scope")), + LogicalGroupID: strings.TrimSpace(r.URL.Query().Get("logical_group_id")), + PublicModel: strings.TrimSpace(r.URL.Query().Get("public_model")), + SubjectID: strings.TrimSpace(r.URL.Query().Get("subject_id")), + } + if req.Scope == "" || req.LogicalGroupID == "" || req.PublicModel == "" || req.SubjectID == "" { + return GetStickyBindingRequest{}, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "scope, logical_group_id, public_model, subject_id are required"} + } + return req, nil +} + +func decodeGetRouteFailureRequest(r *http.Request) (GetRouteFailureRequest, *httpError) { + req := GetRouteFailureRequest{ + RouteID: strings.TrimSpace(r.URL.Query().Get("route_id")), + } + if req.RouteID == "" { + return GetRouteFailureRequest{}, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "route_id is required"} + } + return req, nil +} + +func decodeGetRouteCooldownRequest(r *http.Request) (GetRouteCooldownRequest, *httpError) { + req := GetRouteCooldownRequest{ + RouteID: strings.TrimSpace(r.URL.Query().Get("route_id")), + } + if req.RouteID == "" { + return GetRouteCooldownRequest{}, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "route_id is required"} + } + return req, nil +} + +func secondsToDuration(raw int, defaultSeconds int) (time.Duration, error) { + if raw == 0 { + raw = defaultSeconds + } + if raw < 0 { + return 0, fmt.Errorf("ttl_seconds must be >= 0") + } + if raw == 0 { + return 0, fmt.Errorf("ttl_seconds must be > 0") + } + return time.Duration(raw) * time.Second, nil +} + +func stickyBindingToInfo(backend, scope, subjectID, key string, binding routing.StickyBinding) StickyBindingInfo { + return StickyBindingInfo{ + Backend: backend, + Key: key, + Scope: strings.TrimSpace(scope), + SubjectID: strings.TrimSpace(subjectID), + LogicalGroupID: binding.LogicalGroupID, + PublicModel: binding.PublicModel, + RouteID: binding.RouteID, + ShadowGroupID: binding.ShadowGroupID, + BoundAt: binding.BoundAt, + ExpiresAt: binding.ExpiresAt, + } +} + +func routeFailureToInfo(backend, key string, state routing.RouteFailureState) RouteFailureInfo { + return RouteFailureInfo{ + Backend: backend, + Key: key, + RouteID: state.RouteID, + FailureCount: state.FailureCount, + LastErrorClass: state.LastErrorClass, + LastFailureAt: state.LastFailureAt, + ExpiresAt: state.ExpiresAt, + } +} + +func routeCooldownToInfo(backend, key string, state routing.RouteCooldownState) RouteCooldownInfo { + return RouteCooldownInfo{ + Backend: backend, + Key: key, + RouteID: state.RouteID, + Reason: state.Reason, + Until: state.Until, + ExpiresAt: state.ExpiresAt, + } +} + +func parseTTLQuery(raw string) (int, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return 0, nil + } + value, err := strconv.Atoi(raw) + if err != nil { + return 0, fmt.Errorf("ttl_seconds must be a positive integer") + } + return value, nil +} diff --git a/internal/app/route_sticky_api_test.go b/internal/app/route_sticky_api_test.go new file mode 100644 index 00000000..78e4934a --- /dev/null +++ b/internal/app/route_sticky_api_test.go @@ -0,0 +1,262 @@ +package app + +import ( + "context" + "net/http" + "net/url" + "path/filepath" + "testing" + + "sub2api-cn-relay-manager/internal/config" +) + +func TestAPISetStickyBindingReturnsCreated(t *testing.T) { + handler := NewAPIHandler("secret-token", ActionSet{ + SetStickyBinding: func(_ context.Context, req SetStickyBindingRequest) (StickyBindingInfo, error) { + if req.Scope != "conversation" { + t.Fatalf("Scope = %q, want conversation", req.Scope) + } + return StickyBindingInfo{ + Backend: "memory", + Key: "lg:gpt-shared:m:gpt-5.4:conv:conv-1", + Scope: req.Scope, + RouteID: req.RouteID, + }, nil + }, + }) + + request := httptestRequest(t, http.MethodPost, "/api/routing/sticky/bindings", map[string]any{ + "scope": "conversation", + "logical_group_id": "gpt-shared", + "public_model": "gpt-5.4", + "subject_id": "conv-1", + "route_id": "asxs", + "shadow_group_id": "gpt-shared__asxs", + "ttl_seconds": 600, + }, "secret-token") + response := httptestRecorder(handler, request) + assertStatusCode(t, response, http.StatusCreated) + assertJSONContains(t, response.Body().Bytes(), "binding.backend", "memory") + assertJSONContains(t, response.Body().Bytes(), "binding.route_id", "asxs") +} + +func TestAPIGetStickyBindingRejectsMissingQuery(t *testing.T) { + handler := NewAPIHandler("secret-token", ActionSet{ + GetStickyBinding: func(context.Context, GetStickyBindingRequest) (StickyBindingInfo, error) { + t.Fatal("GetStickyBinding should not be called") + return StickyBindingInfo{}, nil + }, + }) + + request := httptestRequest(t, http.MethodGet, "/api/routing/sticky/bindings?scope=conversation", nil, "secret-token") + response := httptestRecorder(handler, request) + assertStatusCode(t, response, http.StatusBadRequest) +} + +func TestAPISetAndGetRouteFailureAndCooldown(t *testing.T) { + handler := NewAPIHandler("secret-token", ActionSet{ + SetRouteFailure: func(_ context.Context, req SetRouteFailureRequest) (RouteFailureInfo, error) { + return RouteFailureInfo{ + Backend: "memory", + Key: "routefail:" + req.RouteID, + RouteID: req.RouteID, + FailureCount: req.FailureCount, + }, nil + }, + GetRouteFailure: func(_ context.Context, req GetRouteFailureRequest) (RouteFailureInfo, error) { + return RouteFailureInfo{ + Backend: "memory", + Key: "routefail:" + req.RouteID, + RouteID: req.RouteID, + FailureCount: 2, + }, nil + }, + SetRouteCooldown: func(_ context.Context, req SetRouteCooldownRequest) (RouteCooldownInfo, error) { + return RouteCooldownInfo{ + Backend: "memory", + Key: "routecool:" + req.RouteID, + RouteID: req.RouteID, + Reason: req.Reason, + }, nil + }, + GetRouteCooldown: func(_ context.Context, req GetRouteCooldownRequest) (RouteCooldownInfo, error) { + return RouteCooldownInfo{ + Backend: "memory", + Key: "routecool:" + req.RouteID, + RouteID: req.RouteID, + Reason: "cooldown", + }, nil + }, + }) + + setFailureReq := httptestRequest(t, http.MethodPost, "/api/routing/sticky/route-failures", map[string]any{ + "route_id": "asxs", + "failure_count": 2, + "ttl_seconds": 600, + }, "secret-token") + setFailureResp := httptestRecorder(handler, setFailureReq) + assertStatusCode(t, setFailureResp, http.StatusCreated) + assertJSONContains(t, setFailureResp.Body().Bytes(), "route_failure.key", "routefail:asxs") + + getFailureReq := httptestRequest(t, http.MethodGet, "/api/routing/sticky/route-failures?route_id=asxs", nil, "secret-token") + getFailureResp := httptestRecorder(handler, getFailureReq) + assertStatusCode(t, getFailureResp, http.StatusOK) + assertJSONContains(t, getFailureResp.Body().Bytes(), "route_failure.failure_count", float64(2)) + + setCooldownReq := httptestRequest(t, http.MethodPost, "/api/routing/sticky/cooldowns", map[string]any{ + "route_id": "asxs", + "reason": "cooldown", + "ttl_seconds": 600, + }, "secret-token") + setCooldownResp := httptestRecorder(handler, setCooldownReq) + assertStatusCode(t, setCooldownResp, http.StatusCreated) + assertJSONContains(t, setCooldownResp.Body().Bytes(), "route_cooldown.key", "routecool:asxs") + + getCooldownReq := httptestRequest(t, http.MethodGet, "/api/routing/sticky/cooldowns?route_id=asxs", nil, "secret-token") + getCooldownResp := httptestRecorder(handler, getCooldownReq) + assertStatusCode(t, getCooldownResp, http.StatusOK) + assertJSONContains(t, getCooldownResp.Body().Bytes(), "route_cooldown.reason", "cooldown") +} + +func TestNewActionSetStickyRuntimeFlow(t *testing.T) { + dsn := "file:" + filepath.ToSlash(filepath.Join(t.TempDir(), "sticky.db")) + "?_busy_timeout=5000" + actions := NewActionSet(dsn) + ctx := context.Background() + + binding, err := actions.SetStickyBinding(ctx, SetStickyBindingRequest{ + Scope: "conversation", + LogicalGroupID: "gpt-shared", + PublicModel: "gpt-5.4", + SubjectID: "conv-1", + RouteID: "asxs", + ShadowGroupID: "gpt-shared__asxs", + TTLSeconds: 600, + }) + if err != nil { + t.Fatalf("SetStickyBinding() error = %v", err) + } + if binding.Backend != "memory" || binding.RouteID != "asxs" { + t.Fatalf("SetStickyBinding() = %+v, want memory/asxs", binding) + } + + loadedBinding, err := actions.GetStickyBinding(ctx, GetStickyBindingRequest{ + Scope: "conversation", + LogicalGroupID: "gpt-shared", + PublicModel: "gpt-5.4", + SubjectID: "conv-1", + }) + if err != nil { + t.Fatalf("GetStickyBinding() error = %v", err) + } + if loadedBinding.Key != binding.Key { + t.Fatalf("GetStickyBinding().Key = %q, want %q", loadedBinding.Key, binding.Key) + } + + failure, err := actions.SetRouteFailure(ctx, SetRouteFailureRequest{ + RouteID: "asxs", + FailureCount: 2, + LastErrorClass: "timeout", + TTLSeconds: 600, + }) + if err != nil { + t.Fatalf("SetRouteFailure() error = %v", err) + } + if failure.FailureCount != 2 { + t.Fatalf("SetRouteFailure() = %+v, want count 2", failure) + } + loadedFailure, err := actions.GetRouteFailure(ctx, GetRouteFailureRequest{RouteID: "asxs"}) + if err != nil { + t.Fatalf("GetRouteFailure() error = %v", err) + } + if loadedFailure.Key != failure.Key { + t.Fatalf("GetRouteFailure().Key = %q, want %q", loadedFailure.Key, failure.Key) + } + + cooldown, err := actions.SetRouteCooldown(ctx, SetRouteCooldownRequest{ + RouteID: "asxs", + Reason: "degraded", + TTLSeconds: 600, + }) + if err != nil { + t.Fatalf("SetRouteCooldown() error = %v", err) + } + if cooldown.Reason != "degraded" { + t.Fatalf("SetRouteCooldown() = %+v, want reason degraded", cooldown) + } + loadedCooldown, err := actions.GetRouteCooldown(ctx, GetRouteCooldownRequest{RouteID: "asxs"}) + if err != nil { + t.Fatalf("GetRouteCooldown() error = %v", err) + } + if loadedCooldown.Key != cooldown.Key { + t.Fatalf("GetRouteCooldown().Key = %q, want %q", loadedCooldown.Key, cooldown.Key) + } +} + +func TestStickyRuntimeHelpers(t *testing.T) { + t.Parallel() + + runtime, err := newStickyStoreRuntime(context.Background(), config.RouteRuntimeConfig{ + Backend: "memory", + }) + if err != nil { + t.Fatalf("newStickyStoreRuntime(memory) error = %v", err) + } + if runtime.backend != "memory" || runtime.store == nil { + t.Fatalf("newStickyStoreRuntime(memory) = %+v, want memory backend with store", runtime) + } + + if _, err := newStickyStoreRuntime(context.Background(), config.RouteRuntimeConfig{ + Backend: "bad", + }); err == nil { + t.Fatal("newStickyStoreRuntime(bad) error = nil, want error") + } + + if got := defaultStickyStoreRuntime(); got.backend != "memory" || got.store == nil { + t.Fatalf("defaultStickyStoreRuntime() = %+v, want memory backend with store", got) + } +} + +func TestStickyRequestDecodersAndHelpers(t *testing.T) { + t.Parallel() + + req := &http.Request{URL: &url.URL{RawQuery: "scope=conversation&logical_group_id=gpt-shared&public_model=gpt-5.4&subject_id=conv-1"}} + stickyReq, stickyErr := decodeGetStickyBindingRequest(req) + if stickyErr != nil { + t.Fatalf("decodeGetStickyBindingRequest() error = %v", stickyErr) + } + if stickyReq.SubjectID != "conv-1" { + t.Fatalf("decodeGetStickyBindingRequest() = %+v, want subject_id conv-1", stickyReq) + } + + req = &http.Request{URL: &url.URL{RawQuery: "route_id=asxs"}} + failureReq, failureErr := decodeGetRouteFailureRequest(req) + if failureErr != nil { + t.Fatalf("decodeGetRouteFailureRequest() error = %v", failureErr) + } + if failureReq.RouteID != "asxs" { + t.Fatalf("decodeGetRouteFailureRequest() = %+v, want route_id asxs", failureReq) + } + + cooldownReq, cooldownErr := decodeGetRouteCooldownRequest(req) + if cooldownErr != nil { + t.Fatalf("decodeGetRouteCooldownRequest() error = %v", cooldownErr) + } + if cooldownReq.RouteID != "asxs" { + t.Fatalf("decodeGetRouteCooldownRequest() = %+v, want route_id asxs", cooldownReq) + } + + if _, err := secondsToDuration(-1, defaultStickyTTLSeconds); err == nil { + t.Fatal("secondsToDuration(-1) error = nil, want error") + } + duration, err := secondsToDuration(0, defaultStickyTTLSeconds) + if err != nil { + t.Fatalf("secondsToDuration(default) error = %v", err) + } + if duration <= 0 { + t.Fatalf("secondsToDuration(default) = %s, want positive", duration) + } + + if _, err := parseTTLQuery("bad"); err == nil { + t.Fatal("parseTTLQuery(bad) error = nil, want error") + } +} diff --git a/internal/config/config.go b/internal/config/config.go index f920fa61..2b6a5b2f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,7 @@ package config import ( "fmt" "os" + "strconv" "strings" "time" ) @@ -17,12 +18,17 @@ const ( EnvRepoRoot = "SUB2API_CRM_REPO_ROOT" EnvReconcileWorkerEnabled = "SUB2API_CRM_RECONCILE_WORKER_ENABLED" EnvReconcilePollInterval = "SUB2API_CRM_RECONCILE_POLL_INTERVAL" + EnvRouteRuntimeBackend = "SUB2API_CRM_ROUTE_RUNTIME_BACKEND" + EnvRedisAddr = "SUB2API_CRM_REDIS_ADDR" + EnvRedisPassword = "SUB2API_CRM_REDIS_PASSWORD" + EnvRedisDB = "SUB2API_CRM_REDIS_DB" DefaultListenAddr = ":8080" DefaultSQLiteDSN = "file:sub2api-cn-relay-manager.db?_foreign_keys=on&_busy_timeout=5000" DefaultAdminUsername = "admin" DefaultAdminSessionTTL = 12 * time.Hour DefaultReconcilePollInterval = 10 * time.Minute + DefaultRouteRuntimeBackend = "memory" ) type ServerConfig struct { @@ -33,6 +39,17 @@ type DatabaseConfig struct { SQLiteDSN string } +type RedisRuntimeConfig struct { + Addr string + Password string + DB int +} + +type RouteRuntimeConfig struct { + Backend string + Redis RedisRuntimeConfig +} + type ReconcileConfig struct { WorkerEnabled bool PollInterval time.Duration @@ -43,10 +60,11 @@ type RepositoryConfig struct { } type StartupConfig struct { - Server ServerConfig - Database DatabaseConfig - Repository RepositoryConfig - Reconcile ReconcileConfig + Server ServerConfig + Database DatabaseConfig + Repository RepositoryConfig + RouteRuntime RouteRuntimeConfig + Reconcile ReconcileConfig } type AdminSessionConfig struct { @@ -64,6 +82,10 @@ func loadStartupFromLookupEnv(lookup func(string) (string, bool)) (StartupConfig if err != nil { return StartupConfig{}, err } + redisDB, err := readOptionalIntEnv(lookup, EnvRedisDB, 0) + if err != nil { + return StartupConfig{}, err + } cfg := StartupConfig{ Server: ServerConfig{ ListenAddr: readOptionalEnv(lookup, EnvListenAddr, DefaultListenAddr), @@ -74,6 +96,14 @@ func loadStartupFromLookupEnv(lookup func(string) (string, bool)) (StartupConfig Repository: RepositoryConfig{ RepoRoot: readOptionalEnv(lookup, EnvRepoRoot, ""), }, + RouteRuntime: RouteRuntimeConfig{ + Backend: readOptionalEnv(lookup, EnvRouteRuntimeBackend, DefaultRouteRuntimeBackend), + Redis: RedisRuntimeConfig{ + Addr: readOptionalEnv(lookup, EnvRedisAddr, ""), + Password: readOptionalEnv(lookup, EnvRedisPassword, ""), + DB: redisDB, + }, + }, Reconcile: ReconcileConfig{ WorkerEnabled: readOptionalBoolEnv(lookup, EnvReconcileWorkerEnabled, false), PollInterval: reconcilePollInterval, @@ -164,3 +194,18 @@ func readOptionalDurationEnv(lookup func(string) (string, bool), key string, def } return duration, nil } + +func readOptionalIntEnv(lookup func(string) (string, bool), key string, defaultValue int) (int, error) { + value, ok := lookup(key) + if !ok || strings.TrimSpace(value) == "" { + return defaultValue, nil + } + number, err := strconv.Atoi(strings.TrimSpace(value)) + if err != nil { + return 0, fmt.Errorf("%s: parse int: %w", key, err) + } + if number < 0 { + return 0, fmt.Errorf("%s: value must be >= 0", key) + } + return number, nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 63dcb533..d98fbc99 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -69,6 +69,14 @@ func TestLoadStartupFromLookupEnv(t *testing.T) { return "true", true case EnvReconcilePollInterval: return "15m", true + case EnvRouteRuntimeBackend: + return "redis", true + case EnvRedisAddr: + return "127.0.0.1:16379", true + case EnvRedisPassword: + return " redis-pass ", true + case EnvRedisDB: + return "5", true default: return "", false } @@ -92,6 +100,18 @@ func TestLoadStartupFromLookupEnv(t *testing.T) { if cfg.Reconcile.PollInterval != 15*time.Minute { t.Fatalf("PollInterval = %s, want 15m", cfg.Reconcile.PollInterval) } + if cfg.RouteRuntime.Backend != "redis" { + t.Fatalf("RouteRuntime.Backend = %q, want redis", cfg.RouteRuntime.Backend) + } + if cfg.RouteRuntime.Redis.Addr != "127.0.0.1:16379" { + t.Fatalf("RouteRuntime.Redis.Addr = %q, want 127.0.0.1:16379", cfg.RouteRuntime.Redis.Addr) + } + if cfg.RouteRuntime.Redis.Password != "redis-pass" { + t.Fatalf("RouteRuntime.Redis.Password = %q, want redis-pass", cfg.RouteRuntime.Redis.Password) + } + if cfg.RouteRuntime.Redis.DB != 5 { + t.Fatalf("RouteRuntime.Redis.DB = %d, want 5", cfg.RouteRuntime.Redis.DB) + } }) t.Run("default values", func(t *testing.T) { lookup := func(k string) (string, bool) { @@ -116,6 +136,12 @@ func TestLoadStartupFromLookupEnv(t *testing.T) { if cfg.Reconcile.PollInterval != DefaultReconcilePollInterval { t.Fatalf("PollInterval = %s, want %s", cfg.Reconcile.PollInterval, DefaultReconcilePollInterval) } + if cfg.RouteRuntime.Backend != DefaultRouteRuntimeBackend { + t.Fatalf("RouteRuntime.Backend = %q, want %q", cfg.RouteRuntime.Backend, DefaultRouteRuntimeBackend) + } + if cfg.RouteRuntime.Redis.Addr != "" || cfg.RouteRuntime.Redis.Password != "" || cfg.RouteRuntime.Redis.DB != 0 { + t.Fatalf("RouteRuntime.Redis = %+v, want zero value", cfg.RouteRuntime.Redis) + } }) t.Run("invalid reconcile interval", func(t *testing.T) { lookup := func(k string) (string, bool) { @@ -128,6 +154,17 @@ func TestLoadStartupFromLookupEnv(t *testing.T) { t.Fatal("loadStartupFromLookupEnv() error = nil, want invalid interval") } }) + t.Run("invalid redis db", func(t *testing.T) { + lookup := func(k string) (string, bool) { + if k == EnvRedisDB { + return "-1", true + } + return "", false + } + if _, err := loadStartupFromLookupEnv(lookup); err == nil { + t.Fatal("loadStartupFromLookupEnv() error = nil, want invalid redis db") + } + }) } func TestLoadAdminTokenFromLookupEnv(t *testing.T) { diff --git a/internal/routing/sticky.go b/internal/routing/sticky.go new file mode 100644 index 00000000..24424a4f --- /dev/null +++ b/internal/routing/sticky.go @@ -0,0 +1,199 @@ +package routing + +import ( + "context" + "fmt" + "strings" + "time" +) + +const ( + RuntimeBackendMemory = "memory" + RuntimeBackendRedis = "redis" + + StickyScopeConversation = "conversation" + StickyScopeSession = "session" + StickyScopeUser = "user" +) + +type StickyBinding struct { + LogicalGroupID string `json:"logical_group_id"` + PublicModel string `json:"public_model"` + RouteID string `json:"route_id"` + ShadowGroupID string `json:"shadow_group_id"` + BoundAt string `json:"bound_at,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` +} + +type RouteFailureState struct { + RouteID string `json:"route_id"` + FailureCount int `json:"failure_count"` + LastErrorClass string `json:"last_error_class,omitempty"` + LastFailureAt string `json:"last_failure_at,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` +} + +type RouteCooldownState struct { + RouteID string `json:"route_id"` + Reason string `json:"reason,omitempty"` + Until string `json:"until,omitempty"` + ExpiresAt string `json:"expires_at,omitempty"` +} + +type StickyStore interface { + Get(ctx context.Context, key string) (StickyBinding, bool, error) + Set(ctx context.Context, key string, binding StickyBinding, ttl time.Duration) error + Delete(ctx context.Context, key string) error + + GetRouteFailure(ctx context.Context, routeID string) (RouteFailureState, bool, error) + SetRouteFailure(ctx context.Context, routeID string, state RouteFailureState, ttl time.Duration) error + ClearRouteFailure(ctx context.Context, routeID string) error + + GetCooldown(ctx context.Context, routeID string) (RouteCooldownState, bool, error) + SetCooldown(ctx context.Context, routeID string, state RouteCooldownState, ttl time.Duration) error + ClearCooldown(ctx context.Context, routeID string) error +} + +type StoreConfig struct { + Backend string + Redis RedisConfig +} + +func NormalizeRuntimeBackend(backend string) (string, error) { + switch strings.ToLower(strings.TrimSpace(backend)) { + case "", RuntimeBackendMemory: + return RuntimeBackendMemory, nil + case RuntimeBackendRedis: + return RuntimeBackendRedis, nil + default: + return "", fmt.Errorf("unsupported route runtime backend %q", backend) + } +} + +func BuildStickyKey(scope, logicalGroupID, publicModel, subjectID string) (string, error) { + scope = strings.ToLower(strings.TrimSpace(scope)) + logicalGroupID = strings.TrimSpace(logicalGroupID) + publicModel = strings.TrimSpace(publicModel) + subjectID = strings.TrimSpace(subjectID) + + switch { + case logicalGroupID == "": + return "", fmt.Errorf("logical_group_id is required") + case publicModel == "": + return "", fmt.Errorf("public_model is required") + case subjectID == "": + return "", fmt.Errorf("subject_id is required") + } + + var subjectPrefix string + switch scope { + case StickyScopeConversation: + subjectPrefix = "conv" + case StickyScopeSession: + subjectPrefix = "sess" + case StickyScopeUser: + subjectPrefix = "user" + default: + return "", fmt.Errorf("unsupported sticky scope %q", scope) + } + + return fmt.Sprintf("lg:%s:m:%s:%s:%s", logicalGroupID, publicModel, subjectPrefix, subjectID), nil +} + +func BuildRouteFailureKey(routeID string) (string, error) { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return "", fmt.Errorf("route_id is required") + } + return "routefail:" + routeID, nil +} + +func BuildRouteCooldownKey(routeID string) (string, error) { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return "", fmt.Errorf("route_id is required") + } + return "routecool:" + routeID, nil +} + +func normalizeStickyBinding(binding StickyBinding, ttl time.Duration, now time.Time) (StickyBinding, error) { + binding.LogicalGroupID = strings.TrimSpace(binding.LogicalGroupID) + binding.PublicModel = strings.TrimSpace(binding.PublicModel) + binding.RouteID = strings.TrimSpace(binding.RouteID) + binding.ShadowGroupID = strings.TrimSpace(binding.ShadowGroupID) + + switch { + case binding.LogicalGroupID == "": + return StickyBinding{}, fmt.Errorf("logical_group_id is required") + case binding.PublicModel == "": + return StickyBinding{}, fmt.Errorf("public_model is required") + case binding.RouteID == "": + return StickyBinding{}, fmt.Errorf("route_id is required") + case binding.ShadowGroupID == "": + return StickyBinding{}, fmt.Errorf("shadow_group_id is required") + case ttl <= 0: + return StickyBinding{}, fmt.Errorf("ttl must be positive") + } + + if binding.BoundAt == "" { + binding.BoundAt = now.UTC().Format(time.RFC3339) + } + if binding.ExpiresAt == "" { + binding.ExpiresAt = now.UTC().Add(ttl).Format(time.RFC3339) + } + return binding, nil +} + +func normalizeRouteFailureState(routeID string, state RouteFailureState, ttl time.Duration, now time.Time) (RouteFailureState, error) { + routeID = strings.TrimSpace(routeID) + state.RouteID = strings.TrimSpace(state.RouteID) + state.LastErrorClass = strings.TrimSpace(state.LastErrorClass) + + if state.RouteID == "" { + state.RouteID = routeID + } + switch { + case routeID == "": + return RouteFailureState{}, fmt.Errorf("route_id is required") + case state.RouteID != routeID: + return RouteFailureState{}, fmt.Errorf("route_id mismatch") + case state.FailureCount < 0: + return RouteFailureState{}, fmt.Errorf("failure_count must be >= 0") + case ttl <= 0: + return RouteFailureState{}, fmt.Errorf("ttl must be positive") + } + + if state.LastFailureAt == "" { + state.LastFailureAt = now.UTC().Format(time.RFC3339) + } + if state.ExpiresAt == "" { + state.ExpiresAt = now.UTC().Add(ttl).Format(time.RFC3339) + } + return state, nil +} + +func normalizeRouteCooldownState(routeID string, state RouteCooldownState, ttl time.Duration, now time.Time) (RouteCooldownState, error) { + routeID = strings.TrimSpace(routeID) + state.RouteID = strings.TrimSpace(state.RouteID) + state.Reason = strings.TrimSpace(state.Reason) + + if state.RouteID == "" { + state.RouteID = routeID + } + switch { + case routeID == "": + return RouteCooldownState{}, fmt.Errorf("route_id is required") + case state.RouteID != routeID: + return RouteCooldownState{}, fmt.Errorf("route_id mismatch") + case ttl <= 0: + return RouteCooldownState{}, fmt.Errorf("ttl must be positive") + } + + if state.Until == "" { + state.Until = now.UTC().Add(ttl).Format(time.RFC3339) + } + if state.ExpiresAt == "" { + state.ExpiresAt = now.UTC().Add(ttl).Format(time.RFC3339) + } + return state, nil +} diff --git a/internal/routing/sticky_memory.go b/internal/routing/sticky_memory.go new file mode 100644 index 00000000..d7a708f2 --- /dev/null +++ b/internal/routing/sticky_memory.go @@ -0,0 +1,175 @@ +package routing + +import ( + "context" + "strings" + "sync" + "time" +) + +type memoryStickyEntry struct { + binding StickyBinding + expiresAt time.Time +} + +type memoryRouteFailureEntry struct { + state RouteFailureState + expiresAt time.Time +} + +type memoryRouteCooldownEntry struct { + state RouteCooldownState + expiresAt time.Time +} + +type InMemoryStickyStore struct { + now func() time.Time + mu sync.RWMutex + bindings map[string]memoryStickyEntry + routeFailure map[string]memoryRouteFailureEntry + cooldowns map[string]memoryRouteCooldownEntry +} + +func NewInMemoryStickyStore() *InMemoryStickyStore { + return &InMemoryStickyStore{ + now: time.Now, + bindings: make(map[string]memoryStickyEntry), + routeFailure: make(map[string]memoryRouteFailureEntry), + cooldowns: make(map[string]memoryRouteCooldownEntry), + } +} + +func (s *InMemoryStickyStore) Get(_ context.Context, key string) (StickyBinding, bool, error) { + key = strings.TrimSpace(key) + if key == "" { + return StickyBinding{}, false, nil + } + + s.mu.Lock() + defer s.mu.Unlock() + + entry, ok := s.bindings[key] + if !ok { + return StickyBinding{}, false, nil + } + if s.expired(entry.expiresAt) { + delete(s.bindings, key) + return StickyBinding{}, false, nil + } + return entry.binding, true, nil +} + +func (s *InMemoryStickyStore) Set(_ context.Context, key string, binding StickyBinding, ttl time.Duration) error { + key = strings.TrimSpace(key) + binding, err := normalizeStickyBinding(binding, ttl, s.now()) + if err != nil { + return err + } + if key == "" { + return nil + } + + s.mu.Lock() + defer s.mu.Unlock() + s.bindings[key] = memoryStickyEntry{binding: binding, expiresAt: s.now().UTC().Add(ttl)} + return nil +} + +func (s *InMemoryStickyStore) Delete(_ context.Context, key string) error { + key = strings.TrimSpace(key) + if key == "" { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + delete(s.bindings, key) + return nil +} + +func (s *InMemoryStickyStore) GetRouteFailure(_ context.Context, routeID string) (RouteFailureState, bool, error) { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return RouteFailureState{}, false, nil + } + + s.mu.Lock() + defer s.mu.Unlock() + entry, ok := s.routeFailure[routeID] + if !ok { + return RouteFailureState{}, false, nil + } + if s.expired(entry.expiresAt) { + delete(s.routeFailure, routeID) + return RouteFailureState{}, false, nil + } + return entry.state, true, nil +} + +func (s *InMemoryStickyStore) SetRouteFailure(_ context.Context, routeID string, state RouteFailureState, ttl time.Duration) error { + state, err := normalizeRouteFailureState(routeID, state, ttl, s.now()) + if err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + s.routeFailure[state.RouteID] = memoryRouteFailureEntry{state: state, expiresAt: s.now().UTC().Add(ttl)} + return nil +} + +func (s *InMemoryStickyStore) ClearRouteFailure(_ context.Context, routeID string) error { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + delete(s.routeFailure, routeID) + return nil +} + +func (s *InMemoryStickyStore) GetCooldown(_ context.Context, routeID string) (RouteCooldownState, bool, error) { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return RouteCooldownState{}, false, nil + } + + s.mu.Lock() + defer s.mu.Unlock() + entry, ok := s.cooldowns[routeID] + if !ok { + return RouteCooldownState{}, false, nil + } + if s.expired(entry.expiresAt) { + delete(s.cooldowns, routeID) + return RouteCooldownState{}, false, nil + } + return entry.state, true, nil +} + +func (s *InMemoryStickyStore) SetCooldown(_ context.Context, routeID string, state RouteCooldownState, ttl time.Duration) error { + state, err := normalizeRouteCooldownState(routeID, state, ttl, s.now()) + if err != nil { + return err + } + + s.mu.Lock() + defer s.mu.Unlock() + s.cooldowns[state.RouteID] = memoryRouteCooldownEntry{state: state, expiresAt: s.now().UTC().Add(ttl)} + return nil +} + +func (s *InMemoryStickyStore) ClearCooldown(_ context.Context, routeID string) error { + routeID = strings.TrimSpace(routeID) + if routeID == "" { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + delete(s.cooldowns, routeID) + return nil +} + +func (s *InMemoryStickyStore) expired(expiresAt time.Time) bool { + return !expiresAt.IsZero() && !expiresAt.After(s.now().UTC()) +} diff --git a/internal/routing/sticky_redis.go b/internal/routing/sticky_redis.go new file mode 100644 index 00000000..8e768939 --- /dev/null +++ b/internal/routing/sticky_redis.go @@ -0,0 +1,352 @@ +package routing + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "net" + "strconv" + "strings" + "time" +) + +const ( + defaultRedisDialTimeout = 3 * time.Second +) + +type RedisConfig struct { + Addr string + Password string + DB int + DialTimeout time.Duration +} + +type RedisStickyStore struct { + cfg RedisConfig +} + +func NewRedisStickyStore(ctx context.Context, cfg RedisConfig) (*RedisStickyStore, error) { + cfg = normalizeRedisConfig(cfg) + if strings.TrimSpace(cfg.Addr) == "" { + return nil, fmt.Errorf("redis addr is required") + } + + store := &RedisStickyStore{cfg: cfg} + if err := store.ping(ctx); err != nil { + return nil, err + } + return store, nil +} + +func (s *RedisStickyStore) Get(ctx context.Context, key string) (StickyBinding, bool, error) { + key = strings.TrimSpace(key) + if key == "" { + return StickyBinding{}, false, nil + } + + payload, ok, err := s.getJSON(ctx, key) + if err != nil || !ok { + return StickyBinding{}, ok, err + } + var binding StickyBinding + if err := json.Unmarshal(payload, &binding); err != nil { + return StickyBinding{}, false, fmt.Errorf("decode sticky binding %q: %w", key, err) + } + return binding, true, nil +} + +func (s *RedisStickyStore) Set(ctx context.Context, key string, binding StickyBinding, ttl time.Duration) error { + key = strings.TrimSpace(key) + if key == "" { + return nil + } + binding, err := normalizeStickyBinding(binding, ttl, time.Now()) + if err != nil { + return err + } + return s.setJSON(ctx, key, binding, ttl) +} + +func (s *RedisStickyStore) Delete(ctx context.Context, key string) error { + key = strings.TrimSpace(key) + if key == "" { + return nil + } + return s.delKey(ctx, key) +} + +func (s *RedisStickyStore) GetRouteFailure(ctx context.Context, routeID string) (RouteFailureState, bool, error) { + key, err := BuildRouteFailureKey(routeID) + if err != nil { + return RouteFailureState{}, false, err + } + payload, ok, err := s.getJSON(ctx, key) + if err != nil || !ok { + return RouteFailureState{}, ok, err + } + var state RouteFailureState + if err := json.Unmarshal(payload, &state); err != nil { + return RouteFailureState{}, false, fmt.Errorf("decode route failure %q: %w", routeID, err) + } + return state, true, nil +} + +func (s *RedisStickyStore) SetRouteFailure(ctx context.Context, routeID string, state RouteFailureState, ttl time.Duration) error { + key, err := BuildRouteFailureKey(routeID) + if err != nil { + return err + } + state, err = normalizeRouteFailureState(routeID, state, ttl, time.Now()) + if err != nil { + return err + } + return s.setJSON(ctx, key, state, ttl) +} + +func (s *RedisStickyStore) ClearRouteFailure(ctx context.Context, routeID string) error { + key, err := BuildRouteFailureKey(routeID) + if err != nil { + return err + } + return s.delKey(ctx, key) +} + +func (s *RedisStickyStore) GetCooldown(ctx context.Context, routeID string) (RouteCooldownState, bool, error) { + key, err := BuildRouteCooldownKey(routeID) + if err != nil { + return RouteCooldownState{}, false, err + } + payload, ok, err := s.getJSON(ctx, key) + if err != nil || !ok { + return RouteCooldownState{}, ok, err + } + var state RouteCooldownState + if err := json.Unmarshal(payload, &state); err != nil { + return RouteCooldownState{}, false, fmt.Errorf("decode route cooldown %q: %w", routeID, err) + } + return state, true, nil +} + +func (s *RedisStickyStore) SetCooldown(ctx context.Context, routeID string, state RouteCooldownState, ttl time.Duration) error { + key, err := BuildRouteCooldownKey(routeID) + if err != nil { + return err + } + state, err = normalizeRouteCooldownState(routeID, state, ttl, time.Now()) + if err != nil { + return err + } + return s.setJSON(ctx, key, state, ttl) +} + +func (s *RedisStickyStore) ClearCooldown(ctx context.Context, routeID string) error { + key, err := BuildRouteCooldownKey(routeID) + if err != nil { + return err + } + return s.delKey(ctx, key) +} + +func (s *RedisStickyStore) ping(ctx context.Context) error { + conn, reader, err := s.open(ctx) + if err != nil { + return err + } + defer conn.Close() + if err := writeRESPArray(conn, "PING"); err != nil { + return fmt.Errorf("redis ping write: %w", err) + } + reply, err := readRESPValue(reader) + if err != nil { + return fmt.Errorf("redis ping read: %w", err) + } + if reply.kind != '+' || reply.stringValue != "PONG" { + return fmt.Errorf("redis ping unexpected response: kind=%q value=%q", reply.kind, reply.stringValue) + } + return nil +} + +func (s *RedisStickyStore) getJSON(ctx context.Context, key string) ([]byte, bool, error) { + conn, reader, err := s.open(ctx) + if err != nil { + return nil, false, err + } + defer conn.Close() + + if err := writeRESPArray(conn, "GET", key); err != nil { + return nil, false, fmt.Errorf("redis GET %q: write: %w", key, err) + } + reply, err := readRESPValue(reader) + if err != nil { + return nil, false, fmt.Errorf("redis GET %q: read: %w", key, err) + } + switch reply.kind { + case '$': + return []byte(reply.stringValue), true, nil + case '_': + return nil, false, nil + default: + return nil, false, fmt.Errorf("redis GET %q: unexpected response %q", key, reply.kind) + } +} + +func (s *RedisStickyStore) setJSON(ctx context.Context, key string, payload any, ttl time.Duration) error { + encoded, err := json.Marshal(payload) + if err != nil { + return fmt.Errorf("marshal redis value for %q: %w", key, err) + } + + conn, reader, err := s.open(ctx) + if err != nil { + return err + } + defer conn.Close() + + seconds := int(ttl / time.Second) + if ttl%time.Second != 0 { + seconds++ + } + if seconds <= 0 { + seconds = 1 + } + + if err := writeRESPArray(conn, "SET", key, string(encoded), "EX", strconv.Itoa(seconds)); err != nil { + return fmt.Errorf("redis SET %q: write: %w", key, err) + } + reply, err := readRESPValue(reader) + if err != nil { + return fmt.Errorf("redis SET %q: read: %w", key, err) + } + if reply.kind != '+' || reply.stringValue != "OK" { + return fmt.Errorf("redis SET %q: unexpected response kind=%q value=%q", key, reply.kind, reply.stringValue) + } + return nil +} + +func (s *RedisStickyStore) delKey(ctx context.Context, key string) error { + conn, reader, err := s.open(ctx) + if err != nil { + return err + } + defer conn.Close() + + if err := writeRESPArray(conn, "DEL", key); err != nil { + return fmt.Errorf("redis DEL %q: write: %w", key, err) + } + reply, err := readRESPValue(reader) + if err != nil { + return fmt.Errorf("redis DEL %q: read: %w", key, err) + } + if reply.kind != ':' { + return fmt.Errorf("redis DEL %q: unexpected response kind=%q value=%q", key, reply.kind, reply.stringValue) + } + return nil +} + +func (s *RedisStickyStore) open(ctx context.Context) (net.Conn, *bufio.Reader, error) { + cfg := normalizeRedisConfig(s.cfg) + dialer := &net.Dialer{Timeout: cfg.DialTimeout} + conn, err := dialer.DialContext(ctx, "tcp", cfg.Addr) + if err != nil { + return nil, nil, fmt.Errorf("dial redis %q: %w", cfg.Addr, err) + } + reader := bufio.NewReader(conn) + + if strings.TrimSpace(cfg.Password) != "" { + if err := writeRESPArray(conn, "AUTH", cfg.Password); err != nil { + conn.Close() + return nil, nil, fmt.Errorf("redis AUTH write: %w", err) + } + reply, err := readRESPValue(reader) + if err != nil { + conn.Close() + return nil, nil, fmt.Errorf("redis AUTH read: %w", err) + } + if reply.kind != '+' || reply.stringValue != "OK" { + conn.Close() + return nil, nil, fmt.Errorf("redis AUTH unexpected response kind=%q value=%q", reply.kind, reply.stringValue) + } + } + + if cfg.DB > 0 { + if err := writeRESPArray(conn, "SELECT", strconv.Itoa(cfg.DB)); err != nil { + conn.Close() + return nil, nil, fmt.Errorf("redis SELECT write: %w", err) + } + reply, err := readRESPValue(reader) + if err != nil { + conn.Close() + return nil, nil, fmt.Errorf("redis SELECT read: %w", err) + } + if reply.kind != '+' || reply.stringValue != "OK" { + conn.Close() + return nil, nil, fmt.Errorf("redis SELECT unexpected response kind=%q value=%q", reply.kind, reply.stringValue) + } + } + + return conn, reader, nil +} + +func normalizeRedisConfig(cfg RedisConfig) RedisConfig { + cfg.Addr = strings.TrimSpace(cfg.Addr) + cfg.Password = strings.TrimSpace(cfg.Password) + if cfg.DialTimeout <= 0 { + cfg.DialTimeout = defaultRedisDialTimeout + } + return cfg +} + +type respValue struct { + kind byte + stringValue string +} + +func writeRESPArray(w io.Writer, parts ...string) error { + if _, err := io.WriteString(w, fmt.Sprintf("*%d\r\n", len(parts))); err != nil { + return err + } + for _, part := range parts { + if _, err := io.WriteString(w, fmt.Sprintf("$%d\r\n%s\r\n", len(part), part)); err != nil { + return err + } + } + return nil +} + +func readRESPValue(r *bufio.Reader) (respValue, error) { + prefix, err := r.ReadByte() + if err != nil { + return respValue{}, err + } + + line, err := r.ReadString('\n') + if err != nil { + return respValue{}, err + } + line = strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r") + + switch prefix { + case '+', '-', ':': + if prefix == '-' { + return respValue{}, fmt.Errorf("redis error: %s", line) + } + return respValue{kind: prefix, stringValue: line}, nil + case '$': + size, err := strconv.Atoi(line) + if err != nil { + return respValue{}, fmt.Errorf("parse bulk length: %w", err) + } + if size < 0 { + return respValue{kind: '_'}, nil + } + payload := make([]byte, size+2) + if _, err := io.ReadFull(r, payload); err != nil { + return respValue{}, err + } + return respValue{kind: '$', stringValue: string(payload[:size])}, nil + default: + return respValue{}, fmt.Errorf("unsupported redis response prefix %q", prefix) + } +} diff --git a/internal/routing/sticky_test.go b/internal/routing/sticky_test.go new file mode 100644 index 00000000..28fb8c78 --- /dev/null +++ b/internal/routing/sticky_test.go @@ -0,0 +1,460 @@ +package routing + +import ( + "bufio" + "context" + "fmt" + "io" + "net" + "path/filepath" + "strconv" + "strings" + "sync" + "testing" + "time" +) + +func TestBuildStickyKey(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + scope string + want string + wantErr bool + }{ + {name: "conversation", scope: StickyScopeConversation, want: "lg:gpt-shared:m:gpt-5.4:conv:conversation-1"}, + {name: "session", scope: StickyScopeSession, want: "lg:gpt-shared:m:gpt-5.4:sess:session-1"}, + {name: "user", scope: StickyScopeUser, want: "lg:gpt-shared:m:gpt-5.4:user:user-1"}, + {name: "invalid", scope: "bad", wantErr: true}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got, err := BuildStickyKey(tt.scope, "gpt-shared", "gpt-5.4", tt.name+"-1") + if tt.wantErr { + if err == nil { + t.Fatal("BuildStickyKey() error = nil, want error") + } + return + } + if err != nil { + t.Fatalf("BuildStickyKey() error = %v", err) + } + if got != tt.want { + t.Fatalf("BuildStickyKey() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestInMemoryStickyStoreBindingFailureAndCooldown(t *testing.T) { + t.Parallel() + + store := NewInMemoryStickyStore() + ctx := context.Background() + key, err := BuildStickyKey(StickyScopeConversation, "gpt-shared", "gpt-5.4", "conv-1") + if err != nil { + t.Fatalf("BuildStickyKey() error = %v", err) + } + + if err := store.Set(ctx, key, StickyBinding{ + LogicalGroupID: "gpt-shared", + PublicModel: "gpt-5.4", + RouteID: "asxs", + ShadowGroupID: "gpt-shared__asxs", + }, 2*time.Second); err != nil { + t.Fatalf("Set() error = %v", err) + } + binding, ok, err := store.Get(ctx, key) + if err != nil || !ok { + t.Fatalf("Get() = (%+v, %v, %v), want binding", binding, ok, err) + } + if binding.RouteID != "asxs" { + t.Fatalf("binding.RouteID = %q, want asxs", binding.RouteID) + } + if err := store.Delete(ctx, key); err != nil { + t.Fatalf("Delete() error = %v", err) + } + if _, ok, err := store.Get(ctx, key); err != nil || ok { + t.Fatalf("Get() after delete = (ok=%v, err=%v), want false nil", ok, err) + } + + if err := store.SetRouteFailure(ctx, "asxs", RouteFailureState{ + FailureCount: 2, + LastErrorClass: "timeout", + }, time.Second); err != nil { + t.Fatalf("SetRouteFailure() error = %v", err) + } + failure, ok, err := store.GetRouteFailure(ctx, "asxs") + if err != nil || !ok || failure.FailureCount != 2 { + t.Fatalf("GetRouteFailure() = (%+v, %v, %v), want count 2", failure, ok, err) + } + if err := store.ClearRouteFailure(ctx, "asxs"); err != nil { + t.Fatalf("ClearRouteFailure() error = %v", err) + } + + if err := store.SetCooldown(ctx, "asxs", RouteCooldownState{ + Reason: "cooldown", + }, time.Second); err != nil { + t.Fatalf("SetCooldown() error = %v", err) + } + cooldown, ok, err := store.GetCooldown(ctx, "asxs") + if err != nil || !ok || cooldown.RouteID != "asxs" { + t.Fatalf("GetCooldown() = (%+v, %v, %v), want route asxs", cooldown, ok, err) + } + if err := store.ClearCooldown(ctx, "asxs"); err != nil { + t.Fatalf("ClearCooldown() error = %v", err) + } +} + +func TestInMemoryStickyStoreTTlExpiry(t *testing.T) { + t.Parallel() + + store := NewInMemoryStickyStore() + ctx := context.Background() + key, err := BuildStickyKey(StickyScopeUser, "gpt-shared", "gpt-5.4", "user-1") + if err != nil { + t.Fatalf("BuildStickyKey() error = %v", err) + } + if err := store.Set(ctx, key, StickyBinding{ + LogicalGroupID: "gpt-shared", + PublicModel: "gpt-5.4", + RouteID: "asxs", + ShadowGroupID: "gpt-shared__asxs", + }, 40*time.Millisecond); err != nil { + t.Fatalf("Set() error = %v", err) + } + time.Sleep(60 * time.Millisecond) + if _, ok, err := store.Get(ctx, key); err != nil || ok { + t.Fatalf("Get() after ttl = (ok=%v, err=%v), want false nil", ok, err) + } +} + +func TestRedisStickyStoreRoundTripWithFakeServer(t *testing.T) { + t.Parallel() + + server := newFakeRedisServer(t) + defer server.Close() + + store, err := NewRedisStickyStore(context.Background(), RedisConfig{ + Addr: server.Addr(), + Password: "secret", + DB: 2, + }) + if err != nil { + t.Fatalf("NewRedisStickyStore() error = %v", err) + } + + ctx := context.Background() + key, err := BuildStickyKey(StickyScopeSession, "gpt-shared", "gpt-5.4", "sess-1") + if err != nil { + t.Fatalf("BuildStickyKey() error = %v", err) + } + if err := store.Set(ctx, key, StickyBinding{ + LogicalGroupID: "gpt-shared", + PublicModel: "gpt-5.4", + RouteID: "asxs", + ShadowGroupID: "gpt-shared__asxs", + }, time.Minute); err != nil { + t.Fatalf("Set() error = %v", err) + } + if binding, ok, err := store.Get(ctx, key); err != nil || !ok || binding.RouteID != "asxs" { + t.Fatalf("Get() = (%+v, %v, %v), want route asxs", binding, ok, err) + } + if err := store.SetRouteFailure(ctx, "asxs", RouteFailureState{ + FailureCount: 3, + LastErrorClass: "timeout", + }, time.Minute); err != nil { + t.Fatalf("SetRouteFailure() error = %v", err) + } + if state, ok, err := store.GetRouteFailure(ctx, "asxs"); err != nil || !ok || state.FailureCount != 3 { + t.Fatalf("GetRouteFailure() = (%+v, %v, %v), want count 3", state, ok, err) + } + if err := store.SetCooldown(ctx, "asxs", RouteCooldownState{ + Reason: "degraded", + }, time.Minute); err != nil { + t.Fatalf("SetCooldown() error = %v", err) + } + if state, ok, err := store.GetCooldown(ctx, "asxs"); err != nil || !ok || state.Reason != "degraded" { + t.Fatalf("GetCooldown() = (%+v, %v, %v), want reason degraded", state, ok, err) + } + if err := store.Delete(ctx, key); err != nil { + t.Fatalf("Delete() error = %v", err) + } + if _, ok, err := store.Get(ctx, key); err != nil || ok { + t.Fatalf("Get() after delete = (ok=%v, err=%v), want false nil", ok, err) + } +} + +type fakeRedisServer struct { + t *testing.T + listener net.Listener + password string + mu sync.Mutex + values map[int]map[string]fakeRedisValue +} + +type fakeRedisValue struct { + value string + expiresAt time.Time +} + +func newFakeRedisServer(t *testing.T) *fakeRedisServer { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("net.Listen() error = %v", err) + } + server := &fakeRedisServer{ + t: t, + listener: ln, + password: "secret", + values: make(map[int]map[string]fakeRedisValue), + } + go server.serve() + return server +} + +func (s *fakeRedisServer) Addr() string { + return s.listener.Addr().String() +} + +func (s *fakeRedisServer) Close() { + _ = s.listener.Close() +} + +func (s *fakeRedisServer) serve() { + for { + conn, err := s.listener.Accept() + if err != nil { + return + } + go s.handleConn(conn) + } +} + +func (s *fakeRedisServer) handleConn(conn net.Conn) { + defer conn.Close() + + reader := bufio.NewReader(conn) + currentDB := 0 + authed := false + for { + command, err := readRESPArray(reader) + if err != nil { + if err == io.EOF { + return + } + s.writeError(conn, err.Error()) + return + } + if len(command) == 0 { + s.writeError(conn, "empty command") + continue + } + + switch strings.ToUpper(command[0]) { + case "PING": + s.writeSimpleString(conn, "PONG") + case "AUTH": + if len(command) != 2 || command[1] != s.password { + s.writeError(conn, "ERR invalid password") + continue + } + authed = true + s.writeSimpleString(conn, "OK") + case "SELECT": + if len(command) != 2 { + s.writeError(conn, "ERR bad select") + continue + } + db, err := strconv.Atoi(command[1]) + if err != nil { + s.writeError(conn, "ERR bad db") + continue + } + currentDB = db + s.writeSimpleString(conn, "OK") + case "SET": + if !authed { + s.writeError(conn, "NOAUTH") + continue + } + if len(command) != 5 || strings.ToUpper(command[3]) != "EX" { + s.writeError(conn, "ERR bad set") + continue + } + ttl, err := strconv.Atoi(command[4]) + if err != nil { + s.writeError(conn, "ERR bad ttl") + continue + } + s.setValue(currentDB, command[1], command[2], time.Duration(ttl)*time.Second) + s.writeSimpleString(conn, "OK") + case "GET": + if !authed { + s.writeError(conn, "NOAUTH") + continue + } + if len(command) != 2 { + s.writeError(conn, "ERR bad get") + continue + } + value, ok := s.getValue(currentDB, command[1]) + if !ok { + s.writeNullBulk(conn) + continue + } + s.writeBulk(conn, value) + case "DEL": + if !authed { + s.writeError(conn, "NOAUTH") + continue + } + if len(command) != 2 { + s.writeError(conn, "ERR bad del") + continue + } + s.deleteValue(currentDB, command[1]) + s.writeInteger(conn, 1) + default: + s.writeError(conn, "ERR unknown command") + } + } +} + +func readRESPArray(reader *bufio.Reader) ([]string, error) { + line, err := reader.ReadString('\n') + if err != nil { + return nil, err + } + line = strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r") + if !strings.HasPrefix(line, "*") { + return nil, fmt.Errorf("expected array, got %q", line) + } + count, err := strconv.Atoi(strings.TrimPrefix(line, "*")) + if err != nil { + return nil, err + } + + parts := make([]string, 0, count) + for i := 0; i < count; i++ { + header, err := reader.ReadString('\n') + if err != nil { + return nil, err + } + header = strings.TrimSuffix(strings.TrimSuffix(header, "\n"), "\r") + if !strings.HasPrefix(header, "$") { + return nil, fmt.Errorf("expected bulk header, got %q", header) + } + size, err := strconv.Atoi(strings.TrimPrefix(header, "$")) + if err != nil { + return nil, err + } + payload := make([]byte, size+2) + if _, err := io.ReadFull(reader, payload); err != nil { + return nil, err + } + parts = append(parts, string(payload[:size])) + } + return parts, nil +} + +func (s *fakeRedisServer) writeSimpleString(w io.Writer, value string) { + _, _ = io.WriteString(w, "+"+value+"\r\n") +} + +func (s *fakeRedisServer) writeBulk(w io.Writer, value string) { + _, _ = io.WriteString(w, fmt.Sprintf("$%d\r\n%s\r\n", len(value), value)) +} + +func (s *fakeRedisServer) writeNullBulk(w io.Writer) { + _, _ = io.WriteString(w, "$-1\r\n") +} + +func (s *fakeRedisServer) writeInteger(w io.Writer, value int) { + _, _ = io.WriteString(w, fmt.Sprintf(":%d\r\n", value)) +} + +func (s *fakeRedisServer) writeError(w io.Writer, message string) { + _, _ = io.WriteString(w, "-"+message+"\r\n") +} + +func (s *fakeRedisServer) setValue(db int, key, value string, ttl time.Duration) { + s.mu.Lock() + defer s.mu.Unlock() + if s.values[db] == nil { + s.values[db] = make(map[string]fakeRedisValue) + } + s.values[db][key] = fakeRedisValue{value: value, expiresAt: time.Now().Add(ttl)} +} + +func (s *fakeRedisServer) getValue(db int, key string) (string, bool) { + s.mu.Lock() + defer s.mu.Unlock() + value, ok := s.values[db][key] + if !ok { + return "", false + } + if !value.expiresAt.IsZero() && !value.expiresAt.After(time.Now()) { + delete(s.values[db], key) + return "", false + } + return value.value, true +} + +func (s *fakeRedisServer) deleteValue(db int, key string) { + s.mu.Lock() + defer s.mu.Unlock() + if s.values[db] == nil { + return + } + delete(s.values[db], key) +} + +func TestRedisStickyStoreRequiresAddr(t *testing.T) { + t.Parallel() + + if _, err := NewRedisStickyStore(context.Background(), RedisConfig{}); err == nil { + t.Fatal("NewRedisStickyStore() error = nil, want missing addr") + } +} + +func TestNormalizeRuntimeBackend(t *testing.T) { + t.Parallel() + + if got, err := NormalizeRuntimeBackend(""); err != nil || got != RuntimeBackendMemory { + t.Fatalf("NormalizeRuntimeBackend(\"\") = (%q, %v), want memory nil", got, err) + } + if got, err := NormalizeRuntimeBackend("redis"); err != nil || got != RuntimeBackendRedis { + t.Fatalf("NormalizeRuntimeBackend(redis) = (%q, %v), want redis nil", got, err) + } + if _, err := NormalizeRuntimeBackend("bad"); err == nil { + t.Fatal("NormalizeRuntimeBackend(bad) error = nil, want error") + } +} + +func TestRouteFailureAndCooldownKeyBuilders(t *testing.T) { + t.Parallel() + + failureKey, err := BuildRouteFailureKey("asxs") + if err != nil || failureKey != "routefail:asxs" { + t.Fatalf("BuildRouteFailureKey() = (%q, %v), want routefail:asxs nil", failureKey, err) + } + cooldownKey, err := BuildRouteCooldownKey("asxs") + if err != nil || cooldownKey != "routecool:asxs" { + t.Fatalf("BuildRouteCooldownKey() = (%q, %v), want routecool:asxs nil", cooldownKey, err) + } +} + +func TestRedisStickyStoreFixturePathExists(t *testing.T) { + t.Parallel() + + if filepath.Base(t.TempDir()) == "" { + t.Fatal("temp dir base should not be empty") + } +}