feat(routing): add sticky runtime backends

This commit is contained in:
phamnazage-jpg
2026-05-29 07:43:29 +08:00
parent 9d92360401
commit 98bd619ec8
11 changed files with 2046 additions and 5 deletions

View File

@@ -4,3 +4,7 @@ SUB2API_CRM_ADMIN_TOKEN=change-me-before-production
SUB2API_CRM_ADMIN_USERNAME=admin SUB2API_CRM_ADMIN_USERNAME=admin
SUB2API_CRM_ADMIN_PASSWORD=change-me-before-production SUB2API_CRM_ADMIN_PASSWORD=change-me-before-production
SUB2API_CRM_ADMIN_SESSION_TTL=12h SUB2API_CRM_ADMIN_SESSION_TTL=12h
SUB2API_CRM_ROUTE_RUNTIME_BACKEND=memory
SUB2API_CRM_REDIS_ADDR=
SUB2API_CRM_REDIS_PASSWORD=
SUB2API_CRM_REDIS_DB=0

View File

@@ -20,13 +20,17 @@ func Bootstrap(ctx context.Context) (*Server, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
stickyRuntime, err := newStickyStoreRuntime(ctx, cfg.RouteRuntime)
if err != nil {
return nil, err
}
startBackgroundSchedulers(ctx, cfg, defaultBackgroundSchedulers()) startBackgroundSchedulers(ctx, cfg, defaultBackgroundSchedulers())
handler := NewAPIHandlerWithAuth(AdminAuthConfig{ handler := NewAPIHandlerWithAuth(AdminAuthConfig{
Token: adminToken, Token: adminToken,
Username: adminSession.Username, Username: adminSession.Username,
Password: adminSession.Password, Password: adminSession.Password,
SessionTTL: adminSession.SessionTTL, SessionTTL: adminSession.SessionTTL,
}, NewActionSet(cfg.Database.SQLiteDSN)) }, NewActionSetWithStickyRuntime(cfg.Database.SQLiteDSN, stickyRuntime))
return NewServer(cfg.Server.ListenAddr, handler, nil), nil return NewServer(cfg.Server.ListenAddr, handler, nil), nil
} }

View File

@@ -51,6 +51,12 @@ type ActionSet struct {
ListRouteFailoverEvents func(context.Context, ListRouteFailoverEventsRequest) ([]RouteFailoverEventInfo, error) ListRouteFailoverEvents func(context.Context, ListRouteFailoverEventsRequest) ([]RouteFailoverEventInfo, error)
AppendRouteStickyAudit func(context.Context, AppendRouteStickyAuditRequest) (RouteStickyAuditInfo, error) AppendRouteStickyAudit func(context.Context, AppendRouteStickyAuditRequest) (RouteStickyAuditInfo, error)
ListRouteStickyAudit func(context.Context, ListRouteStickyAuditRequest) ([]RouteStickyAuditInfo, error) ListRouteStickyAudit func(context.Context, ListRouteStickyAuditRequest) ([]RouteStickyAuditInfo, error)
SetStickyBinding func(context.Context, SetStickyBindingRequest) (StickyBindingInfo, error)
GetStickyBinding func(context.Context, GetStickyBindingRequest) (StickyBindingInfo, error)
SetRouteFailure func(context.Context, SetRouteFailureRequest) (RouteFailureInfo, error)
GetRouteFailure func(context.Context, GetRouteFailureRequest) (RouteFailureInfo, error)
SetRouteCooldown func(context.Context, SetRouteCooldownRequest) (RouteCooldownInfo, error)
GetRouteCooldown func(context.Context, GetRouteCooldownRequest) (RouteCooldownInfo, error)
CreateProviderDraft func(context.Context, CreateProviderDraftRequest) (ProviderDraftInfo, error) CreateProviderDraft func(context.Context, CreateProviderDraftRequest) (ProviderDraftInfo, error)
ListProviderDrafts func(context.Context, ListProviderDraftsRequest) ([]ProviderDraftInfo, error) ListProviderDrafts func(context.Context, ListProviderDraftsRequest) ([]ProviderDraftInfo, error)
GetProviderDraft func(context.Context, string) (ProviderDraftInfo, error) GetProviderDraft func(context.Context, string) (ProviderDraftInfo, error)
@@ -392,6 +398,24 @@ func NewAPIHandlerWithAuth(adminAuth AdminAuthConfig, actions ActionSet) http.Ha
mux.Handle("GET /api/routing/logs/sticky-audit", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mux.Handle("GET /api/routing/logs/sticky-audit", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleListRouteStickyAudit(w, r, actions.ListRouteStickyAudit) handleListRouteStickyAudit(w, r, actions.ListRouteStickyAudit)
}))) })))
mux.Handle("POST /api/routing/sticky/bindings", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleSetStickyBinding(w, r, actions.SetStickyBinding)
})))
mux.Handle("GET /api/routing/sticky/bindings", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleGetStickyBinding(w, r, actions.GetStickyBinding)
})))
mux.Handle("POST /api/routing/sticky/route-failures", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleSetRouteFailure(w, r, actions.SetRouteFailure)
})))
mux.Handle("GET /api/routing/sticky/route-failures", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleGetRouteFailure(w, r, actions.GetRouteFailure)
})))
mux.Handle("POST /api/routing/sticky/cooldowns", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleSetRouteCooldown(w, r, actions.SetRouteCooldown)
})))
mux.Handle("GET /api/routing/sticky/cooldowns", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleGetRouteCooldown(w, r, actions.GetRouteCooldown)
})))
mux.Handle("POST /api/provider-drafts", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mux.Handle("POST /api/provider-drafts", requireAdminAccess(adminAuth, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleCreateProviderDraft(w, r, actions.CreateProviderDraft) handleCreateProviderDraft(w, r, actions.CreateProviderDraft)
}))) })))
@@ -1196,6 +1220,10 @@ func classifyError(err error) *httpError {
} }
func NewActionSet(sqliteDSN string) ActionSet { func NewActionSet(sqliteDSN string) ActionSet {
return NewActionSetWithStickyRuntime(sqliteDSN, defaultStickyStoreRuntime())
}
func NewActionSetWithStickyRuntime(sqliteDSN string, stickyRuntime stickyStoreRuntime) ActionSet {
routeLogWriter := newLazyRouteLogWriter(sqliteDSN) routeLogWriter := newLazyRouteLogWriter(sqliteDSN)
return ActionSet{ return ActionSet{
CreateBatchImportRun: buildCreateBatchImportRunAction(sqliteDSN), CreateBatchImportRun: buildCreateBatchImportRunAction(sqliteDSN),
@@ -1223,6 +1251,12 @@ func NewActionSet(sqliteDSN string) ActionSet {
ListRouteFailoverEvents: buildListRouteFailoverEventsAction(sqliteDSN), ListRouteFailoverEvents: buildListRouteFailoverEventsAction(sqliteDSN),
AppendRouteStickyAudit: buildAppendRouteStickyAuditAction(routeLogWriter, sqliteDSN), AppendRouteStickyAudit: buildAppendRouteStickyAuditAction(routeLogWriter, sqliteDSN),
ListRouteStickyAudit: buildListRouteStickyAuditAction(sqliteDSN), ListRouteStickyAudit: buildListRouteStickyAuditAction(sqliteDSN),
SetStickyBinding: buildSetStickyBindingAction(stickyRuntime),
GetStickyBinding: buildGetStickyBindingAction(stickyRuntime),
SetRouteFailure: buildSetRouteFailureAction(stickyRuntime),
GetRouteFailure: buildGetRouteFailureAction(stickyRuntime),
SetRouteCooldown: buildSetRouteCooldownAction(stickyRuntime),
GetRouteCooldown: buildGetRouteCooldownAction(stickyRuntime),
CreateProviderDraft: func(ctx context.Context, req CreateProviderDraftRequest) (ProviderDraftInfo, error) { CreateProviderDraft: func(ctx context.Context, req CreateProviderDraftRequest) (ProviderDraftInfo, error) {
store, err := sqlite.Open(ctx, sqliteDSN) store, err := sqlite.Open(ctx, sqliteDSN)
if err != nil { if err != nil {

View File

@@ -0,0 +1,469 @@
package app
import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"sub2api-cn-relay-manager/internal/config"
"sub2api-cn-relay-manager/internal/routing"
)
const defaultStickyTTLSeconds = 600
type stickyStoreRuntime struct {
backend string
store routing.StickyStore
}
type SetStickyBindingRequest struct {
Scope string `json:"scope"`
LogicalGroupID string `json:"logical_group_id"`
PublicModel string `json:"public_model"`
SubjectID string `json:"subject_id"`
RouteID string `json:"route_id"`
ShadowGroupID string `json:"shadow_group_id"`
TTLSeconds int `json:"ttl_seconds,omitempty"`
}
type GetStickyBindingRequest struct {
Scope string
LogicalGroupID string
PublicModel string
SubjectID string
}
type StickyBindingInfo struct {
Backend string `json:"backend"`
Key string `json:"key"`
Scope string `json:"scope"`
LogicalGroupID string `json:"logical_group_id"`
PublicModel string `json:"public_model"`
SubjectID string `json:"subject_id"`
RouteID string `json:"route_id"`
ShadowGroupID string `json:"shadow_group_id"`
BoundAt string `json:"bound_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"`
}
type SetRouteFailureRequest struct {
RouteID string `json:"route_id"`
FailureCount int `json:"failure_count"`
LastErrorClass string `json:"last_error_class,omitempty"`
TTLSeconds int `json:"ttl_seconds,omitempty"`
}
type RouteFailureInfo struct {
Backend string `json:"backend"`
Key string `json:"key"`
RouteID string `json:"route_id"`
FailureCount int `json:"failure_count"`
LastErrorClass string `json:"last_error_class,omitempty"`
LastFailureAt string `json:"last_failure_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"`
}
type GetRouteFailureRequest struct {
RouteID string
}
type SetRouteCooldownRequest struct {
RouteID string `json:"route_id"`
Reason string `json:"reason,omitempty"`
TTLSeconds int `json:"ttl_seconds,omitempty"`
}
type GetRouteCooldownRequest struct {
RouteID string
}
type RouteCooldownInfo struct {
Backend string `json:"backend"`
Key string `json:"key"`
RouteID string `json:"route_id"`
Reason string `json:"reason,omitempty"`
Until string `json:"until,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"`
}
func newStickyStoreRuntime(ctx context.Context, cfg config.RouteRuntimeConfig) (stickyStoreRuntime, error) {
backend, err := routing.NormalizeRuntimeBackend(cfg.Backend)
if err != nil {
return stickyStoreRuntime{}, err
}
switch backend {
case routing.RuntimeBackendMemory:
return stickyStoreRuntime{
backend: routing.RuntimeBackendMemory,
store: routing.NewInMemoryStickyStore(),
}, nil
case routing.RuntimeBackendRedis:
store, err := routing.NewRedisStickyStore(ctx, routing.RedisConfig{
Addr: cfg.Redis.Addr,
Password: cfg.Redis.Password,
DB: cfg.Redis.DB,
})
if err != nil {
return stickyStoreRuntime{}, err
}
return stickyStoreRuntime{
backend: routing.RuntimeBackendRedis,
store: store,
}, nil
default:
return stickyStoreRuntime{}, fmt.Errorf("unsupported sticky backend %q", backend)
}
}
func defaultStickyStoreRuntime() stickyStoreRuntime {
return stickyStoreRuntime{
backend: routing.RuntimeBackendMemory,
store: routing.NewInMemoryStickyStore(),
}
}
func handleSetStickyBinding(w http.ResponseWriter, r *http.Request, fn func(context.Context, SetStickyBindingRequest) (StickyBindingInfo, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "set-sticky-binding action is not configured"})
return
}
var req SetStickyBindingRequest
if err := decodeJSON(r, &req); err != nil {
writeHTTPError(w, err)
return
}
info, err := fn(r.Context(), req)
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
writeJSON(w, http.StatusCreated, map[string]any{"binding": info})
}
func handleGetStickyBinding(w http.ResponseWriter, r *http.Request, fn func(context.Context, GetStickyBindingRequest) (StickyBindingInfo, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "get-sticky-binding action is not configured"})
return
}
req, err := decodeGetStickyBindingRequest(r)
if err != nil {
writeHTTPError(w, err)
return
}
info, actionErr := fn(r.Context(), req)
if actionErr != nil {
writeHTTPError(w, classifyError(actionErr))
return
}
writeJSON(w, http.StatusOK, map[string]any{"binding": info})
}
func handleSetRouteFailure(w http.ResponseWriter, r *http.Request, fn func(context.Context, SetRouteFailureRequest) (RouteFailureInfo, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "set-route-failure action is not configured"})
return
}
var req SetRouteFailureRequest
if err := decodeJSON(r, &req); err != nil {
writeHTTPError(w, err)
return
}
info, err := fn(r.Context(), req)
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
writeJSON(w, http.StatusCreated, map[string]any{"route_failure": info})
}
func handleGetRouteFailure(w http.ResponseWriter, r *http.Request, fn func(context.Context, GetRouteFailureRequest) (RouteFailureInfo, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "get-route-failure action is not configured"})
return
}
req, err := decodeGetRouteFailureRequest(r)
if err != nil {
writeHTTPError(w, err)
return
}
info, actionErr := fn(r.Context(), req)
if actionErr != nil {
writeHTTPError(w, classifyError(actionErr))
return
}
writeJSON(w, http.StatusOK, map[string]any{"route_failure": info})
}
func handleSetRouteCooldown(w http.ResponseWriter, r *http.Request, fn func(context.Context, SetRouteCooldownRequest) (RouteCooldownInfo, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "set-route-cooldown action is not configured"})
return
}
var req SetRouteCooldownRequest
if err := decodeJSON(r, &req); err != nil {
writeHTTPError(w, err)
return
}
info, err := fn(r.Context(), req)
if err != nil {
writeHTTPError(w, classifyError(err))
return
}
writeJSON(w, http.StatusCreated, map[string]any{"route_cooldown": info})
}
func handleGetRouteCooldown(w http.ResponseWriter, r *http.Request, fn func(context.Context, GetRouteCooldownRequest) (RouteCooldownInfo, error)) {
if fn == nil {
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "get-route-cooldown action is not configured"})
return
}
req, err := decodeGetRouteCooldownRequest(r)
if err != nil {
writeHTTPError(w, err)
return
}
info, actionErr := fn(r.Context(), req)
if actionErr != nil {
writeHTTPError(w, classifyError(actionErr))
return
}
writeJSON(w, http.StatusOK, map[string]any{"route_cooldown": info})
}
func buildSetStickyBindingAction(runtime stickyStoreRuntime) func(context.Context, SetStickyBindingRequest) (StickyBindingInfo, error) {
return func(ctx context.Context, req SetStickyBindingRequest) (StickyBindingInfo, error) {
key, err := routing.BuildStickyKey(req.Scope, req.LogicalGroupID, req.PublicModel, req.SubjectID)
if err != nil {
return StickyBindingInfo{}, err
}
ttl, err := secondsToDuration(req.TTLSeconds, defaultStickyTTLSeconds)
if err != nil {
return StickyBindingInfo{}, err
}
binding := routing.StickyBinding{
LogicalGroupID: strings.TrimSpace(req.LogicalGroupID),
PublicModel: strings.TrimSpace(req.PublicModel),
RouteID: strings.TrimSpace(req.RouteID),
ShadowGroupID: strings.TrimSpace(req.ShadowGroupID),
}
if err := runtime.store.Set(ctx, key, binding, ttl); err != nil {
return StickyBindingInfo{}, err
}
stored, ok, err := runtime.store.Get(ctx, key)
if err != nil {
return StickyBindingInfo{}, err
}
if !ok {
return StickyBindingInfo{}, fmt.Errorf("sticky binding %q not found after set", key)
}
return stickyBindingToInfo(runtime.backend, req.Scope, req.SubjectID, key, stored), nil
}
}
func buildGetStickyBindingAction(runtime stickyStoreRuntime) func(context.Context, GetStickyBindingRequest) (StickyBindingInfo, error) {
return func(ctx context.Context, req GetStickyBindingRequest) (StickyBindingInfo, error) {
key, err := routing.BuildStickyKey(req.Scope, req.LogicalGroupID, req.PublicModel, req.SubjectID)
if err != nil {
return StickyBindingInfo{}, err
}
stored, ok, err := runtime.store.Get(ctx, key)
if err != nil {
return StickyBindingInfo{}, err
}
if !ok {
return StickyBindingInfo{}, fmt.Errorf("sticky binding %q not found", key)
}
return stickyBindingToInfo(runtime.backend, req.Scope, req.SubjectID, key, stored), nil
}
}
func buildSetRouteFailureAction(runtime stickyStoreRuntime) func(context.Context, SetRouteFailureRequest) (RouteFailureInfo, error) {
return func(ctx context.Context, req SetRouteFailureRequest) (RouteFailureInfo, error) {
key, err := routing.BuildRouteFailureKey(req.RouteID)
if err != nil {
return RouteFailureInfo{}, err
}
ttl, err := secondsToDuration(req.TTLSeconds, defaultStickyTTLSeconds)
if err != nil {
return RouteFailureInfo{}, err
}
state := routing.RouteFailureState{
RouteID: strings.TrimSpace(req.RouteID),
FailureCount: req.FailureCount,
LastErrorClass: strings.TrimSpace(req.LastErrorClass),
}
if err := runtime.store.SetRouteFailure(ctx, req.RouteID, state, ttl); err != nil {
return RouteFailureInfo{}, err
}
stored, ok, err := runtime.store.GetRouteFailure(ctx, req.RouteID)
if err != nil {
return RouteFailureInfo{}, err
}
if !ok {
return RouteFailureInfo{}, fmt.Errorf("route failure %q not found after set", req.RouteID)
}
return routeFailureToInfo(runtime.backend, key, stored), nil
}
}
func buildGetRouteFailureAction(runtime stickyStoreRuntime) func(context.Context, GetRouteFailureRequest) (RouteFailureInfo, error) {
return func(ctx context.Context, req GetRouteFailureRequest) (RouteFailureInfo, error) {
key, err := routing.BuildRouteFailureKey(req.RouteID)
if err != nil {
return RouteFailureInfo{}, err
}
stored, ok, err := runtime.store.GetRouteFailure(ctx, req.RouteID)
if err != nil {
return RouteFailureInfo{}, err
}
if !ok {
return RouteFailureInfo{}, fmt.Errorf("route failure %q not found", req.RouteID)
}
return routeFailureToInfo(runtime.backend, key, stored), nil
}
}
func buildSetRouteCooldownAction(runtime stickyStoreRuntime) func(context.Context, SetRouteCooldownRequest) (RouteCooldownInfo, error) {
return func(ctx context.Context, req SetRouteCooldownRequest) (RouteCooldownInfo, error) {
key, err := routing.BuildRouteCooldownKey(req.RouteID)
if err != nil {
return RouteCooldownInfo{}, err
}
ttl, err := secondsToDuration(req.TTLSeconds, defaultStickyTTLSeconds)
if err != nil {
return RouteCooldownInfo{}, err
}
state := routing.RouteCooldownState{
RouteID: strings.TrimSpace(req.RouteID),
Reason: strings.TrimSpace(req.Reason),
}
if err := runtime.store.SetCooldown(ctx, req.RouteID, state, ttl); err != nil {
return RouteCooldownInfo{}, err
}
stored, ok, err := runtime.store.GetCooldown(ctx, req.RouteID)
if err != nil {
return RouteCooldownInfo{}, err
}
if !ok {
return RouteCooldownInfo{}, fmt.Errorf("route cooldown %q not found after set", req.RouteID)
}
return routeCooldownToInfo(runtime.backend, key, stored), nil
}
}
func buildGetRouteCooldownAction(runtime stickyStoreRuntime) func(context.Context, GetRouteCooldownRequest) (RouteCooldownInfo, error) {
return func(ctx context.Context, req GetRouteCooldownRequest) (RouteCooldownInfo, error) {
key, err := routing.BuildRouteCooldownKey(req.RouteID)
if err != nil {
return RouteCooldownInfo{}, err
}
stored, ok, err := runtime.store.GetCooldown(ctx, req.RouteID)
if err != nil {
return RouteCooldownInfo{}, err
}
if !ok {
return RouteCooldownInfo{}, fmt.Errorf("route cooldown %q not found", req.RouteID)
}
return routeCooldownToInfo(runtime.backend, key, stored), nil
}
}
func decodeGetStickyBindingRequest(r *http.Request) (GetStickyBindingRequest, *httpError) {
req := GetStickyBindingRequest{
Scope: strings.TrimSpace(r.URL.Query().Get("scope")),
LogicalGroupID: strings.TrimSpace(r.URL.Query().Get("logical_group_id")),
PublicModel: strings.TrimSpace(r.URL.Query().Get("public_model")),
SubjectID: strings.TrimSpace(r.URL.Query().Get("subject_id")),
}
if req.Scope == "" || req.LogicalGroupID == "" || req.PublicModel == "" || req.SubjectID == "" {
return GetStickyBindingRequest{}, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "scope, logical_group_id, public_model, subject_id are required"}
}
return req, nil
}
func decodeGetRouteFailureRequest(r *http.Request) (GetRouteFailureRequest, *httpError) {
req := GetRouteFailureRequest{
RouteID: strings.TrimSpace(r.URL.Query().Get("route_id")),
}
if req.RouteID == "" {
return GetRouteFailureRequest{}, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "route_id is required"}
}
return req, nil
}
func decodeGetRouteCooldownRequest(r *http.Request) (GetRouteCooldownRequest, *httpError) {
req := GetRouteCooldownRequest{
RouteID: strings.TrimSpace(r.URL.Query().Get("route_id")),
}
if req.RouteID == "" {
return GetRouteCooldownRequest{}, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "route_id is required"}
}
return req, nil
}
func secondsToDuration(raw int, defaultSeconds int) (time.Duration, error) {
if raw == 0 {
raw = defaultSeconds
}
if raw < 0 {
return 0, fmt.Errorf("ttl_seconds must be >= 0")
}
if raw == 0 {
return 0, fmt.Errorf("ttl_seconds must be > 0")
}
return time.Duration(raw) * time.Second, nil
}
func stickyBindingToInfo(backend, scope, subjectID, key string, binding routing.StickyBinding) StickyBindingInfo {
return StickyBindingInfo{
Backend: backend,
Key: key,
Scope: strings.TrimSpace(scope),
SubjectID: strings.TrimSpace(subjectID),
LogicalGroupID: binding.LogicalGroupID,
PublicModel: binding.PublicModel,
RouteID: binding.RouteID,
ShadowGroupID: binding.ShadowGroupID,
BoundAt: binding.BoundAt,
ExpiresAt: binding.ExpiresAt,
}
}
func routeFailureToInfo(backend, key string, state routing.RouteFailureState) RouteFailureInfo {
return RouteFailureInfo{
Backend: backend,
Key: key,
RouteID: state.RouteID,
FailureCount: state.FailureCount,
LastErrorClass: state.LastErrorClass,
LastFailureAt: state.LastFailureAt,
ExpiresAt: state.ExpiresAt,
}
}
func routeCooldownToInfo(backend, key string, state routing.RouteCooldownState) RouteCooldownInfo {
return RouteCooldownInfo{
Backend: backend,
Key: key,
RouteID: state.RouteID,
Reason: state.Reason,
Until: state.Until,
ExpiresAt: state.ExpiresAt,
}
}
func parseTTLQuery(raw string) (int, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return 0, nil
}
value, err := strconv.Atoi(raw)
if err != nil {
return 0, fmt.Errorf("ttl_seconds must be a positive integer")
}
return value, nil
}

View File

@@ -0,0 +1,262 @@
package app
import (
"context"
"net/http"
"net/url"
"path/filepath"
"testing"
"sub2api-cn-relay-manager/internal/config"
)
func TestAPISetStickyBindingReturnsCreated(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
SetStickyBinding: func(_ context.Context, req SetStickyBindingRequest) (StickyBindingInfo, error) {
if req.Scope != "conversation" {
t.Fatalf("Scope = %q, want conversation", req.Scope)
}
return StickyBindingInfo{
Backend: "memory",
Key: "lg:gpt-shared:m:gpt-5.4:conv:conv-1",
Scope: req.Scope,
RouteID: req.RouteID,
}, nil
},
})
request := httptestRequest(t, http.MethodPost, "/api/routing/sticky/bindings", map[string]any{
"scope": "conversation",
"logical_group_id": "gpt-shared",
"public_model": "gpt-5.4",
"subject_id": "conv-1",
"route_id": "asxs",
"shadow_group_id": "gpt-shared__asxs",
"ttl_seconds": 600,
}, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusCreated)
assertJSONContains(t, response.Body().Bytes(), "binding.backend", "memory")
assertJSONContains(t, response.Body().Bytes(), "binding.route_id", "asxs")
}
func TestAPIGetStickyBindingRejectsMissingQuery(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
GetStickyBinding: func(context.Context, GetStickyBindingRequest) (StickyBindingInfo, error) {
t.Fatal("GetStickyBinding should not be called")
return StickyBindingInfo{}, nil
},
})
request := httptestRequest(t, http.MethodGet, "/api/routing/sticky/bindings?scope=conversation", nil, "secret-token")
response := httptestRecorder(handler, request)
assertStatusCode(t, response, http.StatusBadRequest)
}
func TestAPISetAndGetRouteFailureAndCooldown(t *testing.T) {
handler := NewAPIHandler("secret-token", ActionSet{
SetRouteFailure: func(_ context.Context, req SetRouteFailureRequest) (RouteFailureInfo, error) {
return RouteFailureInfo{
Backend: "memory",
Key: "routefail:" + req.RouteID,
RouteID: req.RouteID,
FailureCount: req.FailureCount,
}, nil
},
GetRouteFailure: func(_ context.Context, req GetRouteFailureRequest) (RouteFailureInfo, error) {
return RouteFailureInfo{
Backend: "memory",
Key: "routefail:" + req.RouteID,
RouteID: req.RouteID,
FailureCount: 2,
}, nil
},
SetRouteCooldown: func(_ context.Context, req SetRouteCooldownRequest) (RouteCooldownInfo, error) {
return RouteCooldownInfo{
Backend: "memory",
Key: "routecool:" + req.RouteID,
RouteID: req.RouteID,
Reason: req.Reason,
}, nil
},
GetRouteCooldown: func(_ context.Context, req GetRouteCooldownRequest) (RouteCooldownInfo, error) {
return RouteCooldownInfo{
Backend: "memory",
Key: "routecool:" + req.RouteID,
RouteID: req.RouteID,
Reason: "cooldown",
}, nil
},
})
setFailureReq := httptestRequest(t, http.MethodPost, "/api/routing/sticky/route-failures", map[string]any{
"route_id": "asxs",
"failure_count": 2,
"ttl_seconds": 600,
}, "secret-token")
setFailureResp := httptestRecorder(handler, setFailureReq)
assertStatusCode(t, setFailureResp, http.StatusCreated)
assertJSONContains(t, setFailureResp.Body().Bytes(), "route_failure.key", "routefail:asxs")
getFailureReq := httptestRequest(t, http.MethodGet, "/api/routing/sticky/route-failures?route_id=asxs", nil, "secret-token")
getFailureResp := httptestRecorder(handler, getFailureReq)
assertStatusCode(t, getFailureResp, http.StatusOK)
assertJSONContains(t, getFailureResp.Body().Bytes(), "route_failure.failure_count", float64(2))
setCooldownReq := httptestRequest(t, http.MethodPost, "/api/routing/sticky/cooldowns", map[string]any{
"route_id": "asxs",
"reason": "cooldown",
"ttl_seconds": 600,
}, "secret-token")
setCooldownResp := httptestRecorder(handler, setCooldownReq)
assertStatusCode(t, setCooldownResp, http.StatusCreated)
assertJSONContains(t, setCooldownResp.Body().Bytes(), "route_cooldown.key", "routecool:asxs")
getCooldownReq := httptestRequest(t, http.MethodGet, "/api/routing/sticky/cooldowns?route_id=asxs", nil, "secret-token")
getCooldownResp := httptestRecorder(handler, getCooldownReq)
assertStatusCode(t, getCooldownResp, http.StatusOK)
assertJSONContains(t, getCooldownResp.Body().Bytes(), "route_cooldown.reason", "cooldown")
}
func TestNewActionSetStickyRuntimeFlow(t *testing.T) {
dsn := "file:" + filepath.ToSlash(filepath.Join(t.TempDir(), "sticky.db")) + "?_busy_timeout=5000"
actions := NewActionSet(dsn)
ctx := context.Background()
binding, err := actions.SetStickyBinding(ctx, SetStickyBindingRequest{
Scope: "conversation",
LogicalGroupID: "gpt-shared",
PublicModel: "gpt-5.4",
SubjectID: "conv-1",
RouteID: "asxs",
ShadowGroupID: "gpt-shared__asxs",
TTLSeconds: 600,
})
if err != nil {
t.Fatalf("SetStickyBinding() error = %v", err)
}
if binding.Backend != "memory" || binding.RouteID != "asxs" {
t.Fatalf("SetStickyBinding() = %+v, want memory/asxs", binding)
}
loadedBinding, err := actions.GetStickyBinding(ctx, GetStickyBindingRequest{
Scope: "conversation",
LogicalGroupID: "gpt-shared",
PublicModel: "gpt-5.4",
SubjectID: "conv-1",
})
if err != nil {
t.Fatalf("GetStickyBinding() error = %v", err)
}
if loadedBinding.Key != binding.Key {
t.Fatalf("GetStickyBinding().Key = %q, want %q", loadedBinding.Key, binding.Key)
}
failure, err := actions.SetRouteFailure(ctx, SetRouteFailureRequest{
RouteID: "asxs",
FailureCount: 2,
LastErrorClass: "timeout",
TTLSeconds: 600,
})
if err != nil {
t.Fatalf("SetRouteFailure() error = %v", err)
}
if failure.FailureCount != 2 {
t.Fatalf("SetRouteFailure() = %+v, want count 2", failure)
}
loadedFailure, err := actions.GetRouteFailure(ctx, GetRouteFailureRequest{RouteID: "asxs"})
if err != nil {
t.Fatalf("GetRouteFailure() error = %v", err)
}
if loadedFailure.Key != failure.Key {
t.Fatalf("GetRouteFailure().Key = %q, want %q", loadedFailure.Key, failure.Key)
}
cooldown, err := actions.SetRouteCooldown(ctx, SetRouteCooldownRequest{
RouteID: "asxs",
Reason: "degraded",
TTLSeconds: 600,
})
if err != nil {
t.Fatalf("SetRouteCooldown() error = %v", err)
}
if cooldown.Reason != "degraded" {
t.Fatalf("SetRouteCooldown() = %+v, want reason degraded", cooldown)
}
loadedCooldown, err := actions.GetRouteCooldown(ctx, GetRouteCooldownRequest{RouteID: "asxs"})
if err != nil {
t.Fatalf("GetRouteCooldown() error = %v", err)
}
if loadedCooldown.Key != cooldown.Key {
t.Fatalf("GetRouteCooldown().Key = %q, want %q", loadedCooldown.Key, cooldown.Key)
}
}
func TestStickyRuntimeHelpers(t *testing.T) {
t.Parallel()
runtime, err := newStickyStoreRuntime(context.Background(), config.RouteRuntimeConfig{
Backend: "memory",
})
if err != nil {
t.Fatalf("newStickyStoreRuntime(memory) error = %v", err)
}
if runtime.backend != "memory" || runtime.store == nil {
t.Fatalf("newStickyStoreRuntime(memory) = %+v, want memory backend with store", runtime)
}
if _, err := newStickyStoreRuntime(context.Background(), config.RouteRuntimeConfig{
Backend: "bad",
}); err == nil {
t.Fatal("newStickyStoreRuntime(bad) error = nil, want error")
}
if got := defaultStickyStoreRuntime(); got.backend != "memory" || got.store == nil {
t.Fatalf("defaultStickyStoreRuntime() = %+v, want memory backend with store", got)
}
}
func TestStickyRequestDecodersAndHelpers(t *testing.T) {
t.Parallel()
req := &http.Request{URL: &url.URL{RawQuery: "scope=conversation&logical_group_id=gpt-shared&public_model=gpt-5.4&subject_id=conv-1"}}
stickyReq, stickyErr := decodeGetStickyBindingRequest(req)
if stickyErr != nil {
t.Fatalf("decodeGetStickyBindingRequest() error = %v", stickyErr)
}
if stickyReq.SubjectID != "conv-1" {
t.Fatalf("decodeGetStickyBindingRequest() = %+v, want subject_id conv-1", stickyReq)
}
req = &http.Request{URL: &url.URL{RawQuery: "route_id=asxs"}}
failureReq, failureErr := decodeGetRouteFailureRequest(req)
if failureErr != nil {
t.Fatalf("decodeGetRouteFailureRequest() error = %v", failureErr)
}
if failureReq.RouteID != "asxs" {
t.Fatalf("decodeGetRouteFailureRequest() = %+v, want route_id asxs", failureReq)
}
cooldownReq, cooldownErr := decodeGetRouteCooldownRequest(req)
if cooldownErr != nil {
t.Fatalf("decodeGetRouteCooldownRequest() error = %v", cooldownErr)
}
if cooldownReq.RouteID != "asxs" {
t.Fatalf("decodeGetRouteCooldownRequest() = %+v, want route_id asxs", cooldownReq)
}
if _, err := secondsToDuration(-1, defaultStickyTTLSeconds); err == nil {
t.Fatal("secondsToDuration(-1) error = nil, want error")
}
duration, err := secondsToDuration(0, defaultStickyTTLSeconds)
if err != nil {
t.Fatalf("secondsToDuration(default) error = %v", err)
}
if duration <= 0 {
t.Fatalf("secondsToDuration(default) = %s, want positive", duration)
}
if _, err := parseTTLQuery("bad"); err == nil {
t.Fatal("parseTTLQuery(bad) error = nil, want error")
}
}

View File

@@ -3,6 +3,7 @@ package config
import ( import (
"fmt" "fmt"
"os" "os"
"strconv"
"strings" "strings"
"time" "time"
) )
@@ -17,12 +18,17 @@ const (
EnvRepoRoot = "SUB2API_CRM_REPO_ROOT" EnvRepoRoot = "SUB2API_CRM_REPO_ROOT"
EnvReconcileWorkerEnabled = "SUB2API_CRM_RECONCILE_WORKER_ENABLED" EnvReconcileWorkerEnabled = "SUB2API_CRM_RECONCILE_WORKER_ENABLED"
EnvReconcilePollInterval = "SUB2API_CRM_RECONCILE_POLL_INTERVAL" EnvReconcilePollInterval = "SUB2API_CRM_RECONCILE_POLL_INTERVAL"
EnvRouteRuntimeBackend = "SUB2API_CRM_ROUTE_RUNTIME_BACKEND"
EnvRedisAddr = "SUB2API_CRM_REDIS_ADDR"
EnvRedisPassword = "SUB2API_CRM_REDIS_PASSWORD"
EnvRedisDB = "SUB2API_CRM_REDIS_DB"
DefaultListenAddr = ":8080" DefaultListenAddr = ":8080"
DefaultSQLiteDSN = "file:sub2api-cn-relay-manager.db?_foreign_keys=on&_busy_timeout=5000" DefaultSQLiteDSN = "file:sub2api-cn-relay-manager.db?_foreign_keys=on&_busy_timeout=5000"
DefaultAdminUsername = "admin" DefaultAdminUsername = "admin"
DefaultAdminSessionTTL = 12 * time.Hour DefaultAdminSessionTTL = 12 * time.Hour
DefaultReconcilePollInterval = 10 * time.Minute DefaultReconcilePollInterval = 10 * time.Minute
DefaultRouteRuntimeBackend = "memory"
) )
type ServerConfig struct { type ServerConfig struct {
@@ -33,6 +39,17 @@ type DatabaseConfig struct {
SQLiteDSN string SQLiteDSN string
} }
type RedisRuntimeConfig struct {
Addr string
Password string
DB int
}
type RouteRuntimeConfig struct {
Backend string
Redis RedisRuntimeConfig
}
type ReconcileConfig struct { type ReconcileConfig struct {
WorkerEnabled bool WorkerEnabled bool
PollInterval time.Duration PollInterval time.Duration
@@ -43,10 +60,11 @@ type RepositoryConfig struct {
} }
type StartupConfig struct { type StartupConfig struct {
Server ServerConfig Server ServerConfig
Database DatabaseConfig Database DatabaseConfig
Repository RepositoryConfig Repository RepositoryConfig
Reconcile ReconcileConfig RouteRuntime RouteRuntimeConfig
Reconcile ReconcileConfig
} }
type AdminSessionConfig struct { type AdminSessionConfig struct {
@@ -64,6 +82,10 @@ func loadStartupFromLookupEnv(lookup func(string) (string, bool)) (StartupConfig
if err != nil { if err != nil {
return StartupConfig{}, err return StartupConfig{}, err
} }
redisDB, err := readOptionalIntEnv(lookup, EnvRedisDB, 0)
if err != nil {
return StartupConfig{}, err
}
cfg := StartupConfig{ cfg := StartupConfig{
Server: ServerConfig{ Server: ServerConfig{
ListenAddr: readOptionalEnv(lookup, EnvListenAddr, DefaultListenAddr), ListenAddr: readOptionalEnv(lookup, EnvListenAddr, DefaultListenAddr),
@@ -74,6 +96,14 @@ func loadStartupFromLookupEnv(lookup func(string) (string, bool)) (StartupConfig
Repository: RepositoryConfig{ Repository: RepositoryConfig{
RepoRoot: readOptionalEnv(lookup, EnvRepoRoot, ""), RepoRoot: readOptionalEnv(lookup, EnvRepoRoot, ""),
}, },
RouteRuntime: RouteRuntimeConfig{
Backend: readOptionalEnv(lookup, EnvRouteRuntimeBackend, DefaultRouteRuntimeBackend),
Redis: RedisRuntimeConfig{
Addr: readOptionalEnv(lookup, EnvRedisAddr, ""),
Password: readOptionalEnv(lookup, EnvRedisPassword, ""),
DB: redisDB,
},
},
Reconcile: ReconcileConfig{ Reconcile: ReconcileConfig{
WorkerEnabled: readOptionalBoolEnv(lookup, EnvReconcileWorkerEnabled, false), WorkerEnabled: readOptionalBoolEnv(lookup, EnvReconcileWorkerEnabled, false),
PollInterval: reconcilePollInterval, PollInterval: reconcilePollInterval,
@@ -164,3 +194,18 @@ func readOptionalDurationEnv(lookup func(string) (string, bool), key string, def
} }
return duration, nil return duration, nil
} }
func readOptionalIntEnv(lookup func(string) (string, bool), key string, defaultValue int) (int, error) {
value, ok := lookup(key)
if !ok || strings.TrimSpace(value) == "" {
return defaultValue, nil
}
number, err := strconv.Atoi(strings.TrimSpace(value))
if err != nil {
return 0, fmt.Errorf("%s: parse int: %w", key, err)
}
if number < 0 {
return 0, fmt.Errorf("%s: value must be >= 0", key)
}
return number, nil
}

View File

@@ -69,6 +69,14 @@ func TestLoadStartupFromLookupEnv(t *testing.T) {
return "true", true return "true", true
case EnvReconcilePollInterval: case EnvReconcilePollInterval:
return "15m", true return "15m", true
case EnvRouteRuntimeBackend:
return "redis", true
case EnvRedisAddr:
return "127.0.0.1:16379", true
case EnvRedisPassword:
return " redis-pass ", true
case EnvRedisDB:
return "5", true
default: default:
return "", false return "", false
} }
@@ -92,6 +100,18 @@ func TestLoadStartupFromLookupEnv(t *testing.T) {
if cfg.Reconcile.PollInterval != 15*time.Minute { if cfg.Reconcile.PollInterval != 15*time.Minute {
t.Fatalf("PollInterval = %s, want 15m", cfg.Reconcile.PollInterval) t.Fatalf("PollInterval = %s, want 15m", cfg.Reconcile.PollInterval)
} }
if cfg.RouteRuntime.Backend != "redis" {
t.Fatalf("RouteRuntime.Backend = %q, want redis", cfg.RouteRuntime.Backend)
}
if cfg.RouteRuntime.Redis.Addr != "127.0.0.1:16379" {
t.Fatalf("RouteRuntime.Redis.Addr = %q, want 127.0.0.1:16379", cfg.RouteRuntime.Redis.Addr)
}
if cfg.RouteRuntime.Redis.Password != "redis-pass" {
t.Fatalf("RouteRuntime.Redis.Password = %q, want redis-pass", cfg.RouteRuntime.Redis.Password)
}
if cfg.RouteRuntime.Redis.DB != 5 {
t.Fatalf("RouteRuntime.Redis.DB = %d, want 5", cfg.RouteRuntime.Redis.DB)
}
}) })
t.Run("default values", func(t *testing.T) { t.Run("default values", func(t *testing.T) {
lookup := func(k string) (string, bool) { lookup := func(k string) (string, bool) {
@@ -116,6 +136,12 @@ func TestLoadStartupFromLookupEnv(t *testing.T) {
if cfg.Reconcile.PollInterval != DefaultReconcilePollInterval { if cfg.Reconcile.PollInterval != DefaultReconcilePollInterval {
t.Fatalf("PollInterval = %s, want %s", cfg.Reconcile.PollInterval, DefaultReconcilePollInterval) t.Fatalf("PollInterval = %s, want %s", cfg.Reconcile.PollInterval, DefaultReconcilePollInterval)
} }
if cfg.RouteRuntime.Backend != DefaultRouteRuntimeBackend {
t.Fatalf("RouteRuntime.Backend = %q, want %q", cfg.RouteRuntime.Backend, DefaultRouteRuntimeBackend)
}
if cfg.RouteRuntime.Redis.Addr != "" || cfg.RouteRuntime.Redis.Password != "" || cfg.RouteRuntime.Redis.DB != 0 {
t.Fatalf("RouteRuntime.Redis = %+v, want zero value", cfg.RouteRuntime.Redis)
}
}) })
t.Run("invalid reconcile interval", func(t *testing.T) { t.Run("invalid reconcile interval", func(t *testing.T) {
lookup := func(k string) (string, bool) { lookup := func(k string) (string, bool) {
@@ -128,6 +154,17 @@ func TestLoadStartupFromLookupEnv(t *testing.T) {
t.Fatal("loadStartupFromLookupEnv() error = nil, want invalid interval") t.Fatal("loadStartupFromLookupEnv() error = nil, want invalid interval")
} }
}) })
t.Run("invalid redis db", func(t *testing.T) {
lookup := func(k string) (string, bool) {
if k == EnvRedisDB {
return "-1", true
}
return "", false
}
if _, err := loadStartupFromLookupEnv(lookup); err == nil {
t.Fatal("loadStartupFromLookupEnv() error = nil, want invalid redis db")
}
})
} }
func TestLoadAdminTokenFromLookupEnv(t *testing.T) { func TestLoadAdminTokenFromLookupEnv(t *testing.T) {

199
internal/routing/sticky.go Normal file
View File

@@ -0,0 +1,199 @@
package routing
import (
"context"
"fmt"
"strings"
"time"
)
const (
RuntimeBackendMemory = "memory"
RuntimeBackendRedis = "redis"
StickyScopeConversation = "conversation"
StickyScopeSession = "session"
StickyScopeUser = "user"
)
type StickyBinding struct {
LogicalGroupID string `json:"logical_group_id"`
PublicModel string `json:"public_model"`
RouteID string `json:"route_id"`
ShadowGroupID string `json:"shadow_group_id"`
BoundAt string `json:"bound_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"`
}
type RouteFailureState struct {
RouteID string `json:"route_id"`
FailureCount int `json:"failure_count"`
LastErrorClass string `json:"last_error_class,omitempty"`
LastFailureAt string `json:"last_failure_at,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"`
}
type RouteCooldownState struct {
RouteID string `json:"route_id"`
Reason string `json:"reason,omitempty"`
Until string `json:"until,omitempty"`
ExpiresAt string `json:"expires_at,omitempty"`
}
type StickyStore interface {
Get(ctx context.Context, key string) (StickyBinding, bool, error)
Set(ctx context.Context, key string, binding StickyBinding, ttl time.Duration) error
Delete(ctx context.Context, key string) error
GetRouteFailure(ctx context.Context, routeID string) (RouteFailureState, bool, error)
SetRouteFailure(ctx context.Context, routeID string, state RouteFailureState, ttl time.Duration) error
ClearRouteFailure(ctx context.Context, routeID string) error
GetCooldown(ctx context.Context, routeID string) (RouteCooldownState, bool, error)
SetCooldown(ctx context.Context, routeID string, state RouteCooldownState, ttl time.Duration) error
ClearCooldown(ctx context.Context, routeID string) error
}
type StoreConfig struct {
Backend string
Redis RedisConfig
}
func NormalizeRuntimeBackend(backend string) (string, error) {
switch strings.ToLower(strings.TrimSpace(backend)) {
case "", RuntimeBackendMemory:
return RuntimeBackendMemory, nil
case RuntimeBackendRedis:
return RuntimeBackendRedis, nil
default:
return "", fmt.Errorf("unsupported route runtime backend %q", backend)
}
}
func BuildStickyKey(scope, logicalGroupID, publicModel, subjectID string) (string, error) {
scope = strings.ToLower(strings.TrimSpace(scope))
logicalGroupID = strings.TrimSpace(logicalGroupID)
publicModel = strings.TrimSpace(publicModel)
subjectID = strings.TrimSpace(subjectID)
switch {
case logicalGroupID == "":
return "", fmt.Errorf("logical_group_id is required")
case publicModel == "":
return "", fmt.Errorf("public_model is required")
case subjectID == "":
return "", fmt.Errorf("subject_id is required")
}
var subjectPrefix string
switch scope {
case StickyScopeConversation:
subjectPrefix = "conv"
case StickyScopeSession:
subjectPrefix = "sess"
case StickyScopeUser:
subjectPrefix = "user"
default:
return "", fmt.Errorf("unsupported sticky scope %q", scope)
}
return fmt.Sprintf("lg:%s:m:%s:%s:%s", logicalGroupID, publicModel, subjectPrefix, subjectID), nil
}
func BuildRouteFailureKey(routeID string) (string, error) {
routeID = strings.TrimSpace(routeID)
if routeID == "" {
return "", fmt.Errorf("route_id is required")
}
return "routefail:" + routeID, nil
}
func BuildRouteCooldownKey(routeID string) (string, error) {
routeID = strings.TrimSpace(routeID)
if routeID == "" {
return "", fmt.Errorf("route_id is required")
}
return "routecool:" + routeID, nil
}
func normalizeStickyBinding(binding StickyBinding, ttl time.Duration, now time.Time) (StickyBinding, error) {
binding.LogicalGroupID = strings.TrimSpace(binding.LogicalGroupID)
binding.PublicModel = strings.TrimSpace(binding.PublicModel)
binding.RouteID = strings.TrimSpace(binding.RouteID)
binding.ShadowGroupID = strings.TrimSpace(binding.ShadowGroupID)
switch {
case binding.LogicalGroupID == "":
return StickyBinding{}, fmt.Errorf("logical_group_id is required")
case binding.PublicModel == "":
return StickyBinding{}, fmt.Errorf("public_model is required")
case binding.RouteID == "":
return StickyBinding{}, fmt.Errorf("route_id is required")
case binding.ShadowGroupID == "":
return StickyBinding{}, fmt.Errorf("shadow_group_id is required")
case ttl <= 0:
return StickyBinding{}, fmt.Errorf("ttl must be positive")
}
if binding.BoundAt == "" {
binding.BoundAt = now.UTC().Format(time.RFC3339)
}
if binding.ExpiresAt == "" {
binding.ExpiresAt = now.UTC().Add(ttl).Format(time.RFC3339)
}
return binding, nil
}
func normalizeRouteFailureState(routeID string, state RouteFailureState, ttl time.Duration, now time.Time) (RouteFailureState, error) {
routeID = strings.TrimSpace(routeID)
state.RouteID = strings.TrimSpace(state.RouteID)
state.LastErrorClass = strings.TrimSpace(state.LastErrorClass)
if state.RouteID == "" {
state.RouteID = routeID
}
switch {
case routeID == "":
return RouteFailureState{}, fmt.Errorf("route_id is required")
case state.RouteID != routeID:
return RouteFailureState{}, fmt.Errorf("route_id mismatch")
case state.FailureCount < 0:
return RouteFailureState{}, fmt.Errorf("failure_count must be >= 0")
case ttl <= 0:
return RouteFailureState{}, fmt.Errorf("ttl must be positive")
}
if state.LastFailureAt == "" {
state.LastFailureAt = now.UTC().Format(time.RFC3339)
}
if state.ExpiresAt == "" {
state.ExpiresAt = now.UTC().Add(ttl).Format(time.RFC3339)
}
return state, nil
}
func normalizeRouteCooldownState(routeID string, state RouteCooldownState, ttl time.Duration, now time.Time) (RouteCooldownState, error) {
routeID = strings.TrimSpace(routeID)
state.RouteID = strings.TrimSpace(state.RouteID)
state.Reason = strings.TrimSpace(state.Reason)
if state.RouteID == "" {
state.RouteID = routeID
}
switch {
case routeID == "":
return RouteCooldownState{}, fmt.Errorf("route_id is required")
case state.RouteID != routeID:
return RouteCooldownState{}, fmt.Errorf("route_id mismatch")
case ttl <= 0:
return RouteCooldownState{}, fmt.Errorf("ttl must be positive")
}
if state.Until == "" {
state.Until = now.UTC().Add(ttl).Format(time.RFC3339)
}
if state.ExpiresAt == "" {
state.ExpiresAt = now.UTC().Add(ttl).Format(time.RFC3339)
}
return state, nil
}

View File

@@ -0,0 +1,175 @@
package routing
import (
"context"
"strings"
"sync"
"time"
)
type memoryStickyEntry struct {
binding StickyBinding
expiresAt time.Time
}
type memoryRouteFailureEntry struct {
state RouteFailureState
expiresAt time.Time
}
type memoryRouteCooldownEntry struct {
state RouteCooldownState
expiresAt time.Time
}
type InMemoryStickyStore struct {
now func() time.Time
mu sync.RWMutex
bindings map[string]memoryStickyEntry
routeFailure map[string]memoryRouteFailureEntry
cooldowns map[string]memoryRouteCooldownEntry
}
func NewInMemoryStickyStore() *InMemoryStickyStore {
return &InMemoryStickyStore{
now: time.Now,
bindings: make(map[string]memoryStickyEntry),
routeFailure: make(map[string]memoryRouteFailureEntry),
cooldowns: make(map[string]memoryRouteCooldownEntry),
}
}
func (s *InMemoryStickyStore) Get(_ context.Context, key string) (StickyBinding, bool, error) {
key = strings.TrimSpace(key)
if key == "" {
return StickyBinding{}, false, nil
}
s.mu.Lock()
defer s.mu.Unlock()
entry, ok := s.bindings[key]
if !ok {
return StickyBinding{}, false, nil
}
if s.expired(entry.expiresAt) {
delete(s.bindings, key)
return StickyBinding{}, false, nil
}
return entry.binding, true, nil
}
func (s *InMemoryStickyStore) Set(_ context.Context, key string, binding StickyBinding, ttl time.Duration) error {
key = strings.TrimSpace(key)
binding, err := normalizeStickyBinding(binding, ttl, s.now())
if err != nil {
return err
}
if key == "" {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
s.bindings[key] = memoryStickyEntry{binding: binding, expiresAt: s.now().UTC().Add(ttl)}
return nil
}
func (s *InMemoryStickyStore) Delete(_ context.Context, key string) error {
key = strings.TrimSpace(key)
if key == "" {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
delete(s.bindings, key)
return nil
}
func (s *InMemoryStickyStore) GetRouteFailure(_ context.Context, routeID string) (RouteFailureState, bool, error) {
routeID = strings.TrimSpace(routeID)
if routeID == "" {
return RouteFailureState{}, false, nil
}
s.mu.Lock()
defer s.mu.Unlock()
entry, ok := s.routeFailure[routeID]
if !ok {
return RouteFailureState{}, false, nil
}
if s.expired(entry.expiresAt) {
delete(s.routeFailure, routeID)
return RouteFailureState{}, false, nil
}
return entry.state, true, nil
}
func (s *InMemoryStickyStore) SetRouteFailure(_ context.Context, routeID string, state RouteFailureState, ttl time.Duration) error {
state, err := normalizeRouteFailureState(routeID, state, ttl, s.now())
if err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
s.routeFailure[state.RouteID] = memoryRouteFailureEntry{state: state, expiresAt: s.now().UTC().Add(ttl)}
return nil
}
func (s *InMemoryStickyStore) ClearRouteFailure(_ context.Context, routeID string) error {
routeID = strings.TrimSpace(routeID)
if routeID == "" {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
delete(s.routeFailure, routeID)
return nil
}
func (s *InMemoryStickyStore) GetCooldown(_ context.Context, routeID string) (RouteCooldownState, bool, error) {
routeID = strings.TrimSpace(routeID)
if routeID == "" {
return RouteCooldownState{}, false, nil
}
s.mu.Lock()
defer s.mu.Unlock()
entry, ok := s.cooldowns[routeID]
if !ok {
return RouteCooldownState{}, false, nil
}
if s.expired(entry.expiresAt) {
delete(s.cooldowns, routeID)
return RouteCooldownState{}, false, nil
}
return entry.state, true, nil
}
func (s *InMemoryStickyStore) SetCooldown(_ context.Context, routeID string, state RouteCooldownState, ttl time.Duration) error {
state, err := normalizeRouteCooldownState(routeID, state, ttl, s.now())
if err != nil {
return err
}
s.mu.Lock()
defer s.mu.Unlock()
s.cooldowns[state.RouteID] = memoryRouteCooldownEntry{state: state, expiresAt: s.now().UTC().Add(ttl)}
return nil
}
func (s *InMemoryStickyStore) ClearCooldown(_ context.Context, routeID string) error {
routeID = strings.TrimSpace(routeID)
if routeID == "" {
return nil
}
s.mu.Lock()
defer s.mu.Unlock()
delete(s.cooldowns, routeID)
return nil
}
func (s *InMemoryStickyStore) expired(expiresAt time.Time) bool {
return !expiresAt.IsZero() && !expiresAt.After(s.now().UTC())
}

View File

@@ -0,0 +1,352 @@
package routing
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"net"
"strconv"
"strings"
"time"
)
const (
defaultRedisDialTimeout = 3 * time.Second
)
type RedisConfig struct {
Addr string
Password string
DB int
DialTimeout time.Duration
}
type RedisStickyStore struct {
cfg RedisConfig
}
func NewRedisStickyStore(ctx context.Context, cfg RedisConfig) (*RedisStickyStore, error) {
cfg = normalizeRedisConfig(cfg)
if strings.TrimSpace(cfg.Addr) == "" {
return nil, fmt.Errorf("redis addr is required")
}
store := &RedisStickyStore{cfg: cfg}
if err := store.ping(ctx); err != nil {
return nil, err
}
return store, nil
}
func (s *RedisStickyStore) Get(ctx context.Context, key string) (StickyBinding, bool, error) {
key = strings.TrimSpace(key)
if key == "" {
return StickyBinding{}, false, nil
}
payload, ok, err := s.getJSON(ctx, key)
if err != nil || !ok {
return StickyBinding{}, ok, err
}
var binding StickyBinding
if err := json.Unmarshal(payload, &binding); err != nil {
return StickyBinding{}, false, fmt.Errorf("decode sticky binding %q: %w", key, err)
}
return binding, true, nil
}
func (s *RedisStickyStore) Set(ctx context.Context, key string, binding StickyBinding, ttl time.Duration) error {
key = strings.TrimSpace(key)
if key == "" {
return nil
}
binding, err := normalizeStickyBinding(binding, ttl, time.Now())
if err != nil {
return err
}
return s.setJSON(ctx, key, binding, ttl)
}
func (s *RedisStickyStore) Delete(ctx context.Context, key string) error {
key = strings.TrimSpace(key)
if key == "" {
return nil
}
return s.delKey(ctx, key)
}
func (s *RedisStickyStore) GetRouteFailure(ctx context.Context, routeID string) (RouteFailureState, bool, error) {
key, err := BuildRouteFailureKey(routeID)
if err != nil {
return RouteFailureState{}, false, err
}
payload, ok, err := s.getJSON(ctx, key)
if err != nil || !ok {
return RouteFailureState{}, ok, err
}
var state RouteFailureState
if err := json.Unmarshal(payload, &state); err != nil {
return RouteFailureState{}, false, fmt.Errorf("decode route failure %q: %w", routeID, err)
}
return state, true, nil
}
func (s *RedisStickyStore) SetRouteFailure(ctx context.Context, routeID string, state RouteFailureState, ttl time.Duration) error {
key, err := BuildRouteFailureKey(routeID)
if err != nil {
return err
}
state, err = normalizeRouteFailureState(routeID, state, ttl, time.Now())
if err != nil {
return err
}
return s.setJSON(ctx, key, state, ttl)
}
func (s *RedisStickyStore) ClearRouteFailure(ctx context.Context, routeID string) error {
key, err := BuildRouteFailureKey(routeID)
if err != nil {
return err
}
return s.delKey(ctx, key)
}
func (s *RedisStickyStore) GetCooldown(ctx context.Context, routeID string) (RouteCooldownState, bool, error) {
key, err := BuildRouteCooldownKey(routeID)
if err != nil {
return RouteCooldownState{}, false, err
}
payload, ok, err := s.getJSON(ctx, key)
if err != nil || !ok {
return RouteCooldownState{}, ok, err
}
var state RouteCooldownState
if err := json.Unmarshal(payload, &state); err != nil {
return RouteCooldownState{}, false, fmt.Errorf("decode route cooldown %q: %w", routeID, err)
}
return state, true, nil
}
func (s *RedisStickyStore) SetCooldown(ctx context.Context, routeID string, state RouteCooldownState, ttl time.Duration) error {
key, err := BuildRouteCooldownKey(routeID)
if err != nil {
return err
}
state, err = normalizeRouteCooldownState(routeID, state, ttl, time.Now())
if err != nil {
return err
}
return s.setJSON(ctx, key, state, ttl)
}
func (s *RedisStickyStore) ClearCooldown(ctx context.Context, routeID string) error {
key, err := BuildRouteCooldownKey(routeID)
if err != nil {
return err
}
return s.delKey(ctx, key)
}
func (s *RedisStickyStore) ping(ctx context.Context) error {
conn, reader, err := s.open(ctx)
if err != nil {
return err
}
defer conn.Close()
if err := writeRESPArray(conn, "PING"); err != nil {
return fmt.Errorf("redis ping write: %w", err)
}
reply, err := readRESPValue(reader)
if err != nil {
return fmt.Errorf("redis ping read: %w", err)
}
if reply.kind != '+' || reply.stringValue != "PONG" {
return fmt.Errorf("redis ping unexpected response: kind=%q value=%q", reply.kind, reply.stringValue)
}
return nil
}
func (s *RedisStickyStore) getJSON(ctx context.Context, key string) ([]byte, bool, error) {
conn, reader, err := s.open(ctx)
if err != nil {
return nil, false, err
}
defer conn.Close()
if err := writeRESPArray(conn, "GET", key); err != nil {
return nil, false, fmt.Errorf("redis GET %q: write: %w", key, err)
}
reply, err := readRESPValue(reader)
if err != nil {
return nil, false, fmt.Errorf("redis GET %q: read: %w", key, err)
}
switch reply.kind {
case '$':
return []byte(reply.stringValue), true, nil
case '_':
return nil, false, nil
default:
return nil, false, fmt.Errorf("redis GET %q: unexpected response %q", key, reply.kind)
}
}
func (s *RedisStickyStore) setJSON(ctx context.Context, key string, payload any, ttl time.Duration) error {
encoded, err := json.Marshal(payload)
if err != nil {
return fmt.Errorf("marshal redis value for %q: %w", key, err)
}
conn, reader, err := s.open(ctx)
if err != nil {
return err
}
defer conn.Close()
seconds := int(ttl / time.Second)
if ttl%time.Second != 0 {
seconds++
}
if seconds <= 0 {
seconds = 1
}
if err := writeRESPArray(conn, "SET", key, string(encoded), "EX", strconv.Itoa(seconds)); err != nil {
return fmt.Errorf("redis SET %q: write: %w", key, err)
}
reply, err := readRESPValue(reader)
if err != nil {
return fmt.Errorf("redis SET %q: read: %w", key, err)
}
if reply.kind != '+' || reply.stringValue != "OK" {
return fmt.Errorf("redis SET %q: unexpected response kind=%q value=%q", key, reply.kind, reply.stringValue)
}
return nil
}
func (s *RedisStickyStore) delKey(ctx context.Context, key string) error {
conn, reader, err := s.open(ctx)
if err != nil {
return err
}
defer conn.Close()
if err := writeRESPArray(conn, "DEL", key); err != nil {
return fmt.Errorf("redis DEL %q: write: %w", key, err)
}
reply, err := readRESPValue(reader)
if err != nil {
return fmt.Errorf("redis DEL %q: read: %w", key, err)
}
if reply.kind != ':' {
return fmt.Errorf("redis DEL %q: unexpected response kind=%q value=%q", key, reply.kind, reply.stringValue)
}
return nil
}
func (s *RedisStickyStore) open(ctx context.Context) (net.Conn, *bufio.Reader, error) {
cfg := normalizeRedisConfig(s.cfg)
dialer := &net.Dialer{Timeout: cfg.DialTimeout}
conn, err := dialer.DialContext(ctx, "tcp", cfg.Addr)
if err != nil {
return nil, nil, fmt.Errorf("dial redis %q: %w", cfg.Addr, err)
}
reader := bufio.NewReader(conn)
if strings.TrimSpace(cfg.Password) != "" {
if err := writeRESPArray(conn, "AUTH", cfg.Password); err != nil {
conn.Close()
return nil, nil, fmt.Errorf("redis AUTH write: %w", err)
}
reply, err := readRESPValue(reader)
if err != nil {
conn.Close()
return nil, nil, fmt.Errorf("redis AUTH read: %w", err)
}
if reply.kind != '+' || reply.stringValue != "OK" {
conn.Close()
return nil, nil, fmt.Errorf("redis AUTH unexpected response kind=%q value=%q", reply.kind, reply.stringValue)
}
}
if cfg.DB > 0 {
if err := writeRESPArray(conn, "SELECT", strconv.Itoa(cfg.DB)); err != nil {
conn.Close()
return nil, nil, fmt.Errorf("redis SELECT write: %w", err)
}
reply, err := readRESPValue(reader)
if err != nil {
conn.Close()
return nil, nil, fmt.Errorf("redis SELECT read: %w", err)
}
if reply.kind != '+' || reply.stringValue != "OK" {
conn.Close()
return nil, nil, fmt.Errorf("redis SELECT unexpected response kind=%q value=%q", reply.kind, reply.stringValue)
}
}
return conn, reader, nil
}
func normalizeRedisConfig(cfg RedisConfig) RedisConfig {
cfg.Addr = strings.TrimSpace(cfg.Addr)
cfg.Password = strings.TrimSpace(cfg.Password)
if cfg.DialTimeout <= 0 {
cfg.DialTimeout = defaultRedisDialTimeout
}
return cfg
}
type respValue struct {
kind byte
stringValue string
}
func writeRESPArray(w io.Writer, parts ...string) error {
if _, err := io.WriteString(w, fmt.Sprintf("*%d\r\n", len(parts))); err != nil {
return err
}
for _, part := range parts {
if _, err := io.WriteString(w, fmt.Sprintf("$%d\r\n%s\r\n", len(part), part)); err != nil {
return err
}
}
return nil
}
func readRESPValue(r *bufio.Reader) (respValue, error) {
prefix, err := r.ReadByte()
if err != nil {
return respValue{}, err
}
line, err := r.ReadString('\n')
if err != nil {
return respValue{}, err
}
line = strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r")
switch prefix {
case '+', '-', ':':
if prefix == '-' {
return respValue{}, fmt.Errorf("redis error: %s", line)
}
return respValue{kind: prefix, stringValue: line}, nil
case '$':
size, err := strconv.Atoi(line)
if err != nil {
return respValue{}, fmt.Errorf("parse bulk length: %w", err)
}
if size < 0 {
return respValue{kind: '_'}, nil
}
payload := make([]byte, size+2)
if _, err := io.ReadFull(r, payload); err != nil {
return respValue{}, err
}
return respValue{kind: '$', stringValue: string(payload[:size])}, nil
default:
return respValue{}, fmt.Errorf("unsupported redis response prefix %q", prefix)
}
}

View File

@@ -0,0 +1,460 @@
package routing
import (
"bufio"
"context"
"fmt"
"io"
"net"
"path/filepath"
"strconv"
"strings"
"sync"
"testing"
"time"
)
func TestBuildStickyKey(t *testing.T) {
t.Parallel()
tests := []struct {
name string
scope string
want string
wantErr bool
}{
{name: "conversation", scope: StickyScopeConversation, want: "lg:gpt-shared:m:gpt-5.4:conv:conversation-1"},
{name: "session", scope: StickyScopeSession, want: "lg:gpt-shared:m:gpt-5.4:sess:session-1"},
{name: "user", scope: StickyScopeUser, want: "lg:gpt-shared:m:gpt-5.4:user:user-1"},
{name: "invalid", scope: "bad", wantErr: true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := BuildStickyKey(tt.scope, "gpt-shared", "gpt-5.4", tt.name+"-1")
if tt.wantErr {
if err == nil {
t.Fatal("BuildStickyKey() error = nil, want error")
}
return
}
if err != nil {
t.Fatalf("BuildStickyKey() error = %v", err)
}
if got != tt.want {
t.Fatalf("BuildStickyKey() = %q, want %q", got, tt.want)
}
})
}
}
func TestInMemoryStickyStoreBindingFailureAndCooldown(t *testing.T) {
t.Parallel()
store := NewInMemoryStickyStore()
ctx := context.Background()
key, err := BuildStickyKey(StickyScopeConversation, "gpt-shared", "gpt-5.4", "conv-1")
if err != nil {
t.Fatalf("BuildStickyKey() error = %v", err)
}
if err := store.Set(ctx, key, StickyBinding{
LogicalGroupID: "gpt-shared",
PublicModel: "gpt-5.4",
RouteID: "asxs",
ShadowGroupID: "gpt-shared__asxs",
}, 2*time.Second); err != nil {
t.Fatalf("Set() error = %v", err)
}
binding, ok, err := store.Get(ctx, key)
if err != nil || !ok {
t.Fatalf("Get() = (%+v, %v, %v), want binding", binding, ok, err)
}
if binding.RouteID != "asxs" {
t.Fatalf("binding.RouteID = %q, want asxs", binding.RouteID)
}
if err := store.Delete(ctx, key); err != nil {
t.Fatalf("Delete() error = %v", err)
}
if _, ok, err := store.Get(ctx, key); err != nil || ok {
t.Fatalf("Get() after delete = (ok=%v, err=%v), want false nil", ok, err)
}
if err := store.SetRouteFailure(ctx, "asxs", RouteFailureState{
FailureCount: 2,
LastErrorClass: "timeout",
}, time.Second); err != nil {
t.Fatalf("SetRouteFailure() error = %v", err)
}
failure, ok, err := store.GetRouteFailure(ctx, "asxs")
if err != nil || !ok || failure.FailureCount != 2 {
t.Fatalf("GetRouteFailure() = (%+v, %v, %v), want count 2", failure, ok, err)
}
if err := store.ClearRouteFailure(ctx, "asxs"); err != nil {
t.Fatalf("ClearRouteFailure() error = %v", err)
}
if err := store.SetCooldown(ctx, "asxs", RouteCooldownState{
Reason: "cooldown",
}, time.Second); err != nil {
t.Fatalf("SetCooldown() error = %v", err)
}
cooldown, ok, err := store.GetCooldown(ctx, "asxs")
if err != nil || !ok || cooldown.RouteID != "asxs" {
t.Fatalf("GetCooldown() = (%+v, %v, %v), want route asxs", cooldown, ok, err)
}
if err := store.ClearCooldown(ctx, "asxs"); err != nil {
t.Fatalf("ClearCooldown() error = %v", err)
}
}
func TestInMemoryStickyStoreTTlExpiry(t *testing.T) {
t.Parallel()
store := NewInMemoryStickyStore()
ctx := context.Background()
key, err := BuildStickyKey(StickyScopeUser, "gpt-shared", "gpt-5.4", "user-1")
if err != nil {
t.Fatalf("BuildStickyKey() error = %v", err)
}
if err := store.Set(ctx, key, StickyBinding{
LogicalGroupID: "gpt-shared",
PublicModel: "gpt-5.4",
RouteID: "asxs",
ShadowGroupID: "gpt-shared__asxs",
}, 40*time.Millisecond); err != nil {
t.Fatalf("Set() error = %v", err)
}
time.Sleep(60 * time.Millisecond)
if _, ok, err := store.Get(ctx, key); err != nil || ok {
t.Fatalf("Get() after ttl = (ok=%v, err=%v), want false nil", ok, err)
}
}
func TestRedisStickyStoreRoundTripWithFakeServer(t *testing.T) {
t.Parallel()
server := newFakeRedisServer(t)
defer server.Close()
store, err := NewRedisStickyStore(context.Background(), RedisConfig{
Addr: server.Addr(),
Password: "secret",
DB: 2,
})
if err != nil {
t.Fatalf("NewRedisStickyStore() error = %v", err)
}
ctx := context.Background()
key, err := BuildStickyKey(StickyScopeSession, "gpt-shared", "gpt-5.4", "sess-1")
if err != nil {
t.Fatalf("BuildStickyKey() error = %v", err)
}
if err := store.Set(ctx, key, StickyBinding{
LogicalGroupID: "gpt-shared",
PublicModel: "gpt-5.4",
RouteID: "asxs",
ShadowGroupID: "gpt-shared__asxs",
}, time.Minute); err != nil {
t.Fatalf("Set() error = %v", err)
}
if binding, ok, err := store.Get(ctx, key); err != nil || !ok || binding.RouteID != "asxs" {
t.Fatalf("Get() = (%+v, %v, %v), want route asxs", binding, ok, err)
}
if err := store.SetRouteFailure(ctx, "asxs", RouteFailureState{
FailureCount: 3,
LastErrorClass: "timeout",
}, time.Minute); err != nil {
t.Fatalf("SetRouteFailure() error = %v", err)
}
if state, ok, err := store.GetRouteFailure(ctx, "asxs"); err != nil || !ok || state.FailureCount != 3 {
t.Fatalf("GetRouteFailure() = (%+v, %v, %v), want count 3", state, ok, err)
}
if err := store.SetCooldown(ctx, "asxs", RouteCooldownState{
Reason: "degraded",
}, time.Minute); err != nil {
t.Fatalf("SetCooldown() error = %v", err)
}
if state, ok, err := store.GetCooldown(ctx, "asxs"); err != nil || !ok || state.Reason != "degraded" {
t.Fatalf("GetCooldown() = (%+v, %v, %v), want reason degraded", state, ok, err)
}
if err := store.Delete(ctx, key); err != nil {
t.Fatalf("Delete() error = %v", err)
}
if _, ok, err := store.Get(ctx, key); err != nil || ok {
t.Fatalf("Get() after delete = (ok=%v, err=%v), want false nil", ok, err)
}
}
type fakeRedisServer struct {
t *testing.T
listener net.Listener
password string
mu sync.Mutex
values map[int]map[string]fakeRedisValue
}
type fakeRedisValue struct {
value string
expiresAt time.Time
}
func newFakeRedisServer(t *testing.T) *fakeRedisServer {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("net.Listen() error = %v", err)
}
server := &fakeRedisServer{
t: t,
listener: ln,
password: "secret",
values: make(map[int]map[string]fakeRedisValue),
}
go server.serve()
return server
}
func (s *fakeRedisServer) Addr() string {
return s.listener.Addr().String()
}
func (s *fakeRedisServer) Close() {
_ = s.listener.Close()
}
func (s *fakeRedisServer) serve() {
for {
conn, err := s.listener.Accept()
if err != nil {
return
}
go s.handleConn(conn)
}
}
func (s *fakeRedisServer) handleConn(conn net.Conn) {
defer conn.Close()
reader := bufio.NewReader(conn)
currentDB := 0
authed := false
for {
command, err := readRESPArray(reader)
if err != nil {
if err == io.EOF {
return
}
s.writeError(conn, err.Error())
return
}
if len(command) == 0 {
s.writeError(conn, "empty command")
continue
}
switch strings.ToUpper(command[0]) {
case "PING":
s.writeSimpleString(conn, "PONG")
case "AUTH":
if len(command) != 2 || command[1] != s.password {
s.writeError(conn, "ERR invalid password")
continue
}
authed = true
s.writeSimpleString(conn, "OK")
case "SELECT":
if len(command) != 2 {
s.writeError(conn, "ERR bad select")
continue
}
db, err := strconv.Atoi(command[1])
if err != nil {
s.writeError(conn, "ERR bad db")
continue
}
currentDB = db
s.writeSimpleString(conn, "OK")
case "SET":
if !authed {
s.writeError(conn, "NOAUTH")
continue
}
if len(command) != 5 || strings.ToUpper(command[3]) != "EX" {
s.writeError(conn, "ERR bad set")
continue
}
ttl, err := strconv.Atoi(command[4])
if err != nil {
s.writeError(conn, "ERR bad ttl")
continue
}
s.setValue(currentDB, command[1], command[2], time.Duration(ttl)*time.Second)
s.writeSimpleString(conn, "OK")
case "GET":
if !authed {
s.writeError(conn, "NOAUTH")
continue
}
if len(command) != 2 {
s.writeError(conn, "ERR bad get")
continue
}
value, ok := s.getValue(currentDB, command[1])
if !ok {
s.writeNullBulk(conn)
continue
}
s.writeBulk(conn, value)
case "DEL":
if !authed {
s.writeError(conn, "NOAUTH")
continue
}
if len(command) != 2 {
s.writeError(conn, "ERR bad del")
continue
}
s.deleteValue(currentDB, command[1])
s.writeInteger(conn, 1)
default:
s.writeError(conn, "ERR unknown command")
}
}
}
func readRESPArray(reader *bufio.Reader) ([]string, error) {
line, err := reader.ReadString('\n')
if err != nil {
return nil, err
}
line = strings.TrimSuffix(strings.TrimSuffix(line, "\n"), "\r")
if !strings.HasPrefix(line, "*") {
return nil, fmt.Errorf("expected array, got %q", line)
}
count, err := strconv.Atoi(strings.TrimPrefix(line, "*"))
if err != nil {
return nil, err
}
parts := make([]string, 0, count)
for i := 0; i < count; i++ {
header, err := reader.ReadString('\n')
if err != nil {
return nil, err
}
header = strings.TrimSuffix(strings.TrimSuffix(header, "\n"), "\r")
if !strings.HasPrefix(header, "$") {
return nil, fmt.Errorf("expected bulk header, got %q", header)
}
size, err := strconv.Atoi(strings.TrimPrefix(header, "$"))
if err != nil {
return nil, err
}
payload := make([]byte, size+2)
if _, err := io.ReadFull(reader, payload); err != nil {
return nil, err
}
parts = append(parts, string(payload[:size]))
}
return parts, nil
}
func (s *fakeRedisServer) writeSimpleString(w io.Writer, value string) {
_, _ = io.WriteString(w, "+"+value+"\r\n")
}
func (s *fakeRedisServer) writeBulk(w io.Writer, value string) {
_, _ = io.WriteString(w, fmt.Sprintf("$%d\r\n%s\r\n", len(value), value))
}
func (s *fakeRedisServer) writeNullBulk(w io.Writer) {
_, _ = io.WriteString(w, "$-1\r\n")
}
func (s *fakeRedisServer) writeInteger(w io.Writer, value int) {
_, _ = io.WriteString(w, fmt.Sprintf(":%d\r\n", value))
}
func (s *fakeRedisServer) writeError(w io.Writer, message string) {
_, _ = io.WriteString(w, "-"+message+"\r\n")
}
func (s *fakeRedisServer) setValue(db int, key, value string, ttl time.Duration) {
s.mu.Lock()
defer s.mu.Unlock()
if s.values[db] == nil {
s.values[db] = make(map[string]fakeRedisValue)
}
s.values[db][key] = fakeRedisValue{value: value, expiresAt: time.Now().Add(ttl)}
}
func (s *fakeRedisServer) getValue(db int, key string) (string, bool) {
s.mu.Lock()
defer s.mu.Unlock()
value, ok := s.values[db][key]
if !ok {
return "", false
}
if !value.expiresAt.IsZero() && !value.expiresAt.After(time.Now()) {
delete(s.values[db], key)
return "", false
}
return value.value, true
}
func (s *fakeRedisServer) deleteValue(db int, key string) {
s.mu.Lock()
defer s.mu.Unlock()
if s.values[db] == nil {
return
}
delete(s.values[db], key)
}
func TestRedisStickyStoreRequiresAddr(t *testing.T) {
t.Parallel()
if _, err := NewRedisStickyStore(context.Background(), RedisConfig{}); err == nil {
t.Fatal("NewRedisStickyStore() error = nil, want missing addr")
}
}
func TestNormalizeRuntimeBackend(t *testing.T) {
t.Parallel()
if got, err := NormalizeRuntimeBackend(""); err != nil || got != RuntimeBackendMemory {
t.Fatalf("NormalizeRuntimeBackend(\"\") = (%q, %v), want memory nil", got, err)
}
if got, err := NormalizeRuntimeBackend("redis"); err != nil || got != RuntimeBackendRedis {
t.Fatalf("NormalizeRuntimeBackend(redis) = (%q, %v), want redis nil", got, err)
}
if _, err := NormalizeRuntimeBackend("bad"); err == nil {
t.Fatal("NormalizeRuntimeBackend(bad) error = nil, want error")
}
}
func TestRouteFailureAndCooldownKeyBuilders(t *testing.T) {
t.Parallel()
failureKey, err := BuildRouteFailureKey("asxs")
if err != nil || failureKey != "routefail:asxs" {
t.Fatalf("BuildRouteFailureKey() = (%q, %v), want routefail:asxs nil", failureKey, err)
}
cooldownKey, err := BuildRouteCooldownKey("asxs")
if err != nil || cooldownKey != "routecool:asxs" {
t.Fatalf("BuildRouteCooldownKey() = (%q, %v), want routecool:asxs nil", cooldownKey, err)
}
}
func TestRedisStickyStoreFixturePathExists(t *testing.T) {
t.Parallel()
if filepath.Base(t.TempDir()) == "" {
t.Fatal("temp dir base should not be empty")
}
}