470 lines
15 KiB
Go
470 lines
15 KiB
Go
package app
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"sub2api-cn-relay-manager/internal/config"
|
|
"sub2api-cn-relay-manager/internal/routing"
|
|
)
|
|
|
|
// defaultStickyTTLSeconds is the TTL, in seconds, used when a set request
// leaves ttl_seconds at zero (ten minutes).
const defaultStickyTTLSeconds = 10 * 60
|
|
|
|
type stickyStoreRuntime struct {
|
|
backend string
|
|
store routing.StickyStore
|
|
}
|
|
|
|
// SetStickyBindingRequest is the JSON body decoded by handleSetStickyBinding.
type SetStickyBindingRequest struct {
	// Inputs used to derive the sticky key; LogicalGroupID and PublicModel
	// are also stored on the binding itself.
	Scope          string `json:"scope"`
	LogicalGroupID string `json:"logical_group_id"`
	PublicModel    string `json:"public_model"`
	SubjectID      string `json:"subject_id"`

	// Route target stored on the binding.
	RouteID       string `json:"route_id"`
	ShadowGroupID string `json:"shadow_group_id"`

	// TTLSeconds is the binding lifetime in seconds; zero selects
	// defaultStickyTTLSeconds.
	TTLSeconds int `json:"ttl_seconds,omitempty"`
}
|
|
|
|
// GetStickyBindingRequest carries the four values that identify a sticky
// binding. It is populated from URL query parameters by
// decodeGetStickyBindingRequest.
type GetStickyBindingRequest struct {
	Scope          string
	LogicalGroupID string
	PublicModel    string
	SubjectID      string
}
|
|
|
|
// StickyBindingInfo is the JSON view of a stored sticky binding.
//
// Backend is the runtime's backend name and Key is the sticky key the binding
// is stored under. Scope and SubjectID echo the trimmed request values; the
// remaining fields are copied from the stored binding.
type StickyBindingInfo struct {
	Backend        string `json:"backend"`
	Key            string `json:"key"`
	Scope          string `json:"scope"`
	LogicalGroupID string `json:"logical_group_id"`
	PublicModel    string `json:"public_model"`
	SubjectID      string `json:"subject_id"`
	RouteID        string `json:"route_id"`
	ShadowGroupID  string `json:"shadow_group_id"`
	BoundAt        string `json:"bound_at,omitempty"`
	ExpiresAt      string `json:"expires_at,omitempty"`
}
|
|
|
|
// SetRouteFailureRequest is the JSON body decoded by handleSetRouteFailure.
type SetRouteFailureRequest struct {
	RouteID        string `json:"route_id"`
	FailureCount   int    `json:"failure_count"`
	LastErrorClass string `json:"last_error_class,omitempty"`
	// TTLSeconds is the record lifetime in seconds; zero selects
	// defaultStickyTTLSeconds.
	TTLSeconds int `json:"ttl_seconds,omitempty"`
}
|
|
|
|
// RouteFailureInfo is the JSON view of a stored route failure record.
//
// Backend is the runtime's backend name and Key is the route failure key; the
// remaining fields are copied from the stored state.
type RouteFailureInfo struct {
	Backend        string `json:"backend"`
	Key            string `json:"key"`
	RouteID        string `json:"route_id"`
	FailureCount   int    `json:"failure_count"`
	LastErrorClass string `json:"last_error_class,omitempty"`
	LastFailureAt  string `json:"last_failure_at,omitempty"`
	ExpiresAt      string `json:"expires_at,omitempty"`
}
|
|
|
|
// GetRouteFailureRequest identifies the route whose failure record is looked
// up. It is populated from the route_id query parameter by
// decodeGetRouteFailureRequest.
type GetRouteFailureRequest struct {
	RouteID string
}
|
|
|
|
// SetRouteCooldownRequest is the JSON body decoded by handleSetRouteCooldown.
type SetRouteCooldownRequest struct {
	RouteID string `json:"route_id"`
	Reason  string `json:"reason,omitempty"`
	// TTLSeconds is the cooldown lifetime in seconds; zero selects
	// defaultStickyTTLSeconds.
	TTLSeconds int `json:"ttl_seconds,omitempty"`
}
|
|
|
|
// GetRouteCooldownRequest identifies the route whose cooldown is looked up.
// It is populated from the route_id query parameter by
// decodeGetRouteCooldownRequest.
type GetRouteCooldownRequest struct {
	RouteID string
}
|
|
|
|
// RouteCooldownInfo is the JSON view of a stored route cooldown.
//
// Backend is the runtime's backend name and Key is the route cooldown key;
// the remaining fields are copied from the stored state.
type RouteCooldownInfo struct {
	Backend   string `json:"backend"`
	Key       string `json:"key"`
	RouteID   string `json:"route_id"`
	Reason    string `json:"reason,omitempty"`
	Until     string `json:"until,omitempty"`
	ExpiresAt string `json:"expires_at,omitempty"`
}
|
|
|
|
func newStickyStoreRuntime(ctx context.Context, cfg config.RouteRuntimeConfig) (stickyStoreRuntime, error) {
|
|
backend, err := routing.NormalizeRuntimeBackend(cfg.Backend)
|
|
if err != nil {
|
|
return stickyStoreRuntime{}, err
|
|
}
|
|
|
|
switch backend {
|
|
case routing.RuntimeBackendMemory:
|
|
return stickyStoreRuntime{
|
|
backend: routing.RuntimeBackendMemory,
|
|
store: routing.NewInMemoryStickyStore(),
|
|
}, nil
|
|
case routing.RuntimeBackendRedis:
|
|
store, err := routing.NewRedisStickyStore(ctx, routing.RedisConfig{
|
|
Addr: cfg.Redis.Addr,
|
|
Password: cfg.Redis.Password,
|
|
DB: cfg.Redis.DB,
|
|
})
|
|
if err != nil {
|
|
return stickyStoreRuntime{}, err
|
|
}
|
|
return stickyStoreRuntime{
|
|
backend: routing.RuntimeBackendRedis,
|
|
store: store,
|
|
}, nil
|
|
default:
|
|
return stickyStoreRuntime{}, fmt.Errorf("unsupported sticky backend %q", backend)
|
|
}
|
|
}
|
|
|
|
func defaultStickyStoreRuntime() stickyStoreRuntime {
|
|
return stickyStoreRuntime{
|
|
backend: routing.RuntimeBackendMemory,
|
|
store: routing.NewInMemoryStickyStore(),
|
|
}
|
|
}
|
|
|
|
func handleSetStickyBinding(w http.ResponseWriter, r *http.Request, fn func(context.Context, SetStickyBindingRequest) (StickyBindingInfo, error)) {
|
|
if fn == nil {
|
|
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "set-sticky-binding action is not configured"})
|
|
return
|
|
}
|
|
var req SetStickyBindingRequest
|
|
if err := decodeJSON(r, &req); err != nil {
|
|
writeHTTPError(w, err)
|
|
return
|
|
}
|
|
info, err := fn(r.Context(), req)
|
|
if err != nil {
|
|
writeHTTPError(w, classifyError(err))
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusCreated, map[string]any{"binding": info})
|
|
}
|
|
|
|
func handleGetStickyBinding(w http.ResponseWriter, r *http.Request, fn func(context.Context, GetStickyBindingRequest) (StickyBindingInfo, error)) {
|
|
if fn == nil {
|
|
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "get-sticky-binding action is not configured"})
|
|
return
|
|
}
|
|
req, err := decodeGetStickyBindingRequest(r)
|
|
if err != nil {
|
|
writeHTTPError(w, err)
|
|
return
|
|
}
|
|
info, actionErr := fn(r.Context(), req)
|
|
if actionErr != nil {
|
|
writeHTTPError(w, classifyError(actionErr))
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{"binding": info})
|
|
}
|
|
|
|
func handleSetRouteFailure(w http.ResponseWriter, r *http.Request, fn func(context.Context, SetRouteFailureRequest) (RouteFailureInfo, error)) {
|
|
if fn == nil {
|
|
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "set-route-failure action is not configured"})
|
|
return
|
|
}
|
|
var req SetRouteFailureRequest
|
|
if err := decodeJSON(r, &req); err != nil {
|
|
writeHTTPError(w, err)
|
|
return
|
|
}
|
|
info, err := fn(r.Context(), req)
|
|
if err != nil {
|
|
writeHTTPError(w, classifyError(err))
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusCreated, map[string]any{"route_failure": info})
|
|
}
|
|
|
|
func handleGetRouteFailure(w http.ResponseWriter, r *http.Request, fn func(context.Context, GetRouteFailureRequest) (RouteFailureInfo, error)) {
|
|
if fn == nil {
|
|
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "get-route-failure action is not configured"})
|
|
return
|
|
}
|
|
req, err := decodeGetRouteFailureRequest(r)
|
|
if err != nil {
|
|
writeHTTPError(w, err)
|
|
return
|
|
}
|
|
info, actionErr := fn(r.Context(), req)
|
|
if actionErr != nil {
|
|
writeHTTPError(w, classifyError(actionErr))
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{"route_failure": info})
|
|
}
|
|
|
|
func handleSetRouteCooldown(w http.ResponseWriter, r *http.Request, fn func(context.Context, SetRouteCooldownRequest) (RouteCooldownInfo, error)) {
|
|
if fn == nil {
|
|
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "set-route-cooldown action is not configured"})
|
|
return
|
|
}
|
|
var req SetRouteCooldownRequest
|
|
if err := decodeJSON(r, &req); err != nil {
|
|
writeHTTPError(w, err)
|
|
return
|
|
}
|
|
info, err := fn(r.Context(), req)
|
|
if err != nil {
|
|
writeHTTPError(w, classifyError(err))
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusCreated, map[string]any{"route_cooldown": info})
|
|
}
|
|
|
|
func handleGetRouteCooldown(w http.ResponseWriter, r *http.Request, fn func(context.Context, GetRouteCooldownRequest) (RouteCooldownInfo, error)) {
|
|
if fn == nil {
|
|
writeHTTPError(w, &httpError{StatusCode: http.StatusInternalServerError, Code: "server_misconfigured", Message: "get-route-cooldown action is not configured"})
|
|
return
|
|
}
|
|
req, err := decodeGetRouteCooldownRequest(r)
|
|
if err != nil {
|
|
writeHTTPError(w, err)
|
|
return
|
|
}
|
|
info, actionErr := fn(r.Context(), req)
|
|
if actionErr != nil {
|
|
writeHTTPError(w, classifyError(actionErr))
|
|
return
|
|
}
|
|
writeJSON(w, http.StatusOK, map[string]any{"route_cooldown": info})
|
|
}
|
|
|
|
func buildSetStickyBindingAction(runtime stickyStoreRuntime) func(context.Context, SetStickyBindingRequest) (StickyBindingInfo, error) {
|
|
return func(ctx context.Context, req SetStickyBindingRequest) (StickyBindingInfo, error) {
|
|
key, err := routing.BuildStickyKey(req.Scope, req.LogicalGroupID, req.PublicModel, req.SubjectID)
|
|
if err != nil {
|
|
return StickyBindingInfo{}, err
|
|
}
|
|
ttl, err := secondsToDuration(req.TTLSeconds, defaultStickyTTLSeconds)
|
|
if err != nil {
|
|
return StickyBindingInfo{}, err
|
|
}
|
|
binding := routing.StickyBinding{
|
|
LogicalGroupID: strings.TrimSpace(req.LogicalGroupID),
|
|
PublicModel: strings.TrimSpace(req.PublicModel),
|
|
RouteID: strings.TrimSpace(req.RouteID),
|
|
ShadowGroupID: strings.TrimSpace(req.ShadowGroupID),
|
|
}
|
|
if err := runtime.store.Set(ctx, key, binding, ttl); err != nil {
|
|
return StickyBindingInfo{}, err
|
|
}
|
|
stored, ok, err := runtime.store.Get(ctx, key)
|
|
if err != nil {
|
|
return StickyBindingInfo{}, err
|
|
}
|
|
if !ok {
|
|
return StickyBindingInfo{}, fmt.Errorf("sticky binding %q not found after set", key)
|
|
}
|
|
return stickyBindingToInfo(runtime.backend, req.Scope, req.SubjectID, key, stored), nil
|
|
}
|
|
}
|
|
|
|
func buildGetStickyBindingAction(runtime stickyStoreRuntime) func(context.Context, GetStickyBindingRequest) (StickyBindingInfo, error) {
|
|
return func(ctx context.Context, req GetStickyBindingRequest) (StickyBindingInfo, error) {
|
|
key, err := routing.BuildStickyKey(req.Scope, req.LogicalGroupID, req.PublicModel, req.SubjectID)
|
|
if err != nil {
|
|
return StickyBindingInfo{}, err
|
|
}
|
|
stored, ok, err := runtime.store.Get(ctx, key)
|
|
if err != nil {
|
|
return StickyBindingInfo{}, err
|
|
}
|
|
if !ok {
|
|
return StickyBindingInfo{}, fmt.Errorf("sticky binding %q not found", key)
|
|
}
|
|
return stickyBindingToInfo(runtime.backend, req.Scope, req.SubjectID, key, stored), nil
|
|
}
|
|
}
|
|
|
|
func buildSetRouteFailureAction(runtime stickyStoreRuntime) func(context.Context, SetRouteFailureRequest) (RouteFailureInfo, error) {
|
|
return func(ctx context.Context, req SetRouteFailureRequest) (RouteFailureInfo, error) {
|
|
key, err := routing.BuildRouteFailureKey(req.RouteID)
|
|
if err != nil {
|
|
return RouteFailureInfo{}, err
|
|
}
|
|
ttl, err := secondsToDuration(req.TTLSeconds, defaultStickyTTLSeconds)
|
|
if err != nil {
|
|
return RouteFailureInfo{}, err
|
|
}
|
|
state := routing.RouteFailureState{
|
|
RouteID: strings.TrimSpace(req.RouteID),
|
|
FailureCount: req.FailureCount,
|
|
LastErrorClass: strings.TrimSpace(req.LastErrorClass),
|
|
}
|
|
if err := runtime.store.SetRouteFailure(ctx, req.RouteID, state, ttl); err != nil {
|
|
return RouteFailureInfo{}, err
|
|
}
|
|
stored, ok, err := runtime.store.GetRouteFailure(ctx, req.RouteID)
|
|
if err != nil {
|
|
return RouteFailureInfo{}, err
|
|
}
|
|
if !ok {
|
|
return RouteFailureInfo{}, fmt.Errorf("route failure %q not found after set", req.RouteID)
|
|
}
|
|
return routeFailureToInfo(runtime.backend, key, stored), nil
|
|
}
|
|
}
|
|
|
|
func buildGetRouteFailureAction(runtime stickyStoreRuntime) func(context.Context, GetRouteFailureRequest) (RouteFailureInfo, error) {
|
|
return func(ctx context.Context, req GetRouteFailureRequest) (RouteFailureInfo, error) {
|
|
key, err := routing.BuildRouteFailureKey(req.RouteID)
|
|
if err != nil {
|
|
return RouteFailureInfo{}, err
|
|
}
|
|
stored, ok, err := runtime.store.GetRouteFailure(ctx, req.RouteID)
|
|
if err != nil {
|
|
return RouteFailureInfo{}, err
|
|
}
|
|
if !ok {
|
|
return RouteFailureInfo{}, fmt.Errorf("route failure %q not found", req.RouteID)
|
|
}
|
|
return routeFailureToInfo(runtime.backend, key, stored), nil
|
|
}
|
|
}
|
|
|
|
func buildSetRouteCooldownAction(runtime stickyStoreRuntime) func(context.Context, SetRouteCooldownRequest) (RouteCooldownInfo, error) {
|
|
return func(ctx context.Context, req SetRouteCooldownRequest) (RouteCooldownInfo, error) {
|
|
key, err := routing.BuildRouteCooldownKey(req.RouteID)
|
|
if err != nil {
|
|
return RouteCooldownInfo{}, err
|
|
}
|
|
ttl, err := secondsToDuration(req.TTLSeconds, defaultStickyTTLSeconds)
|
|
if err != nil {
|
|
return RouteCooldownInfo{}, err
|
|
}
|
|
state := routing.RouteCooldownState{
|
|
RouteID: strings.TrimSpace(req.RouteID),
|
|
Reason: strings.TrimSpace(req.Reason),
|
|
}
|
|
if err := runtime.store.SetCooldown(ctx, req.RouteID, state, ttl); err != nil {
|
|
return RouteCooldownInfo{}, err
|
|
}
|
|
stored, ok, err := runtime.store.GetCooldown(ctx, req.RouteID)
|
|
if err != nil {
|
|
return RouteCooldownInfo{}, err
|
|
}
|
|
if !ok {
|
|
return RouteCooldownInfo{}, fmt.Errorf("route cooldown %q not found after set", req.RouteID)
|
|
}
|
|
return routeCooldownToInfo(runtime.backend, key, stored), nil
|
|
}
|
|
}
|
|
|
|
func buildGetRouteCooldownAction(runtime stickyStoreRuntime) func(context.Context, GetRouteCooldownRequest) (RouteCooldownInfo, error) {
|
|
return func(ctx context.Context, req GetRouteCooldownRequest) (RouteCooldownInfo, error) {
|
|
key, err := routing.BuildRouteCooldownKey(req.RouteID)
|
|
if err != nil {
|
|
return RouteCooldownInfo{}, err
|
|
}
|
|
stored, ok, err := runtime.store.GetCooldown(ctx, req.RouteID)
|
|
if err != nil {
|
|
return RouteCooldownInfo{}, err
|
|
}
|
|
if !ok {
|
|
return RouteCooldownInfo{}, fmt.Errorf("route cooldown %q not found", req.RouteID)
|
|
}
|
|
return routeCooldownToInfo(runtime.backend, key, stored), nil
|
|
}
|
|
}
|
|
|
|
func decodeGetStickyBindingRequest(r *http.Request) (GetStickyBindingRequest, *httpError) {
|
|
req := GetStickyBindingRequest{
|
|
Scope: strings.TrimSpace(r.URL.Query().Get("scope")),
|
|
LogicalGroupID: strings.TrimSpace(r.URL.Query().Get("logical_group_id")),
|
|
PublicModel: strings.TrimSpace(r.URL.Query().Get("public_model")),
|
|
SubjectID: strings.TrimSpace(r.URL.Query().Get("subject_id")),
|
|
}
|
|
if req.Scope == "" || req.LogicalGroupID == "" || req.PublicModel == "" || req.SubjectID == "" {
|
|
return GetStickyBindingRequest{}, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "scope, logical_group_id, public_model, subject_id are required"}
|
|
}
|
|
return req, nil
|
|
}
|
|
|
|
func decodeGetRouteFailureRequest(r *http.Request) (GetRouteFailureRequest, *httpError) {
|
|
req := GetRouteFailureRequest{
|
|
RouteID: strings.TrimSpace(r.URL.Query().Get("route_id")),
|
|
}
|
|
if req.RouteID == "" {
|
|
return GetRouteFailureRequest{}, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "route_id is required"}
|
|
}
|
|
return req, nil
|
|
}
|
|
|
|
func decodeGetRouteCooldownRequest(r *http.Request) (GetRouteCooldownRequest, *httpError) {
|
|
req := GetRouteCooldownRequest{
|
|
RouteID: strings.TrimSpace(r.URL.Query().Get("route_id")),
|
|
}
|
|
if req.RouteID == "" {
|
|
return GetRouteCooldownRequest{}, &httpError{StatusCode: http.StatusBadRequest, Code: "bad_request", Message: "route_id is required"}
|
|
}
|
|
return req, nil
|
|
}
|
|
|
|
// secondsToDuration converts a TTL given in whole seconds to a time.Duration.
//
// A raw value of zero selects defaultSeconds. An error is returned when the
// effective value is negative, is still zero after defaulting, or is too
// large to be represented as a time.Duration.
func secondsToDuration(raw int, defaultSeconds int) (time.Duration, error) {
	// Largest whole number of seconds a time.Duration (int64 nanoseconds) can
	// hold. Without this bound the multiplication below silently wraps around
	// and can produce a negative or arbitrary TTL.
	const maxSeconds = int64((1<<63 - 1) / time.Second)

	seconds := raw
	if seconds == 0 {
		seconds = defaultSeconds
	}

	switch {
	case seconds < 0:
		return 0, fmt.Errorf("ttl_seconds must be >= 0")
	case seconds == 0:
		// Only reachable when defaultSeconds is itself zero.
		return 0, fmt.Errorf("ttl_seconds must be > 0")
	case int64(seconds) > maxSeconds:
		return 0, fmt.Errorf("ttl_seconds must be <= %d", maxSeconds)
	}

	return time.Duration(seconds) * time.Second, nil
}
|
|
|
|
func stickyBindingToInfo(backend, scope, subjectID, key string, binding routing.StickyBinding) StickyBindingInfo {
|
|
return StickyBindingInfo{
|
|
Backend: backend,
|
|
Key: key,
|
|
Scope: strings.TrimSpace(scope),
|
|
SubjectID: strings.TrimSpace(subjectID),
|
|
LogicalGroupID: binding.LogicalGroupID,
|
|
PublicModel: binding.PublicModel,
|
|
RouteID: binding.RouteID,
|
|
ShadowGroupID: binding.ShadowGroupID,
|
|
BoundAt: binding.BoundAt,
|
|
ExpiresAt: binding.ExpiresAt,
|
|
}
|
|
}
|
|
|
|
func routeFailureToInfo(backend, key string, state routing.RouteFailureState) RouteFailureInfo {
|
|
return RouteFailureInfo{
|
|
Backend: backend,
|
|
Key: key,
|
|
RouteID: state.RouteID,
|
|
FailureCount: state.FailureCount,
|
|
LastErrorClass: state.LastErrorClass,
|
|
LastFailureAt: state.LastFailureAt,
|
|
ExpiresAt: state.ExpiresAt,
|
|
}
|
|
}
|
|
|
|
func routeCooldownToInfo(backend, key string, state routing.RouteCooldownState) RouteCooldownInfo {
|
|
return RouteCooldownInfo{
|
|
Backend: backend,
|
|
Key: key,
|
|
RouteID: state.RouteID,
|
|
Reason: state.Reason,
|
|
Until: state.Until,
|
|
ExpiresAt: state.ExpiresAt,
|
|
}
|
|
}
|
|
|
|
// parseTTLQuery parses a ttl_seconds query value.
//
// Surrounding whitespace is ignored. An empty value yields 0, which
// secondsToDuration interprets as "use the default TTL"; an explicit "0" is
// accepted for the same reason. Values that are not integers, and negative
// integers, are rejected with an error.
func parseTTLQuery(raw string) (int, error) {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return 0, nil
	}

	value, err := strconv.Atoi(trimmed)
	// Atoi accepts a leading minus sign, so negatives must be rejected here
	// for the error message below to hold.
	if err != nil || value < 0 {
		return 0, fmt.Errorf("ttl_seconds must be a positive integer")
	}
	return value, nil
}
|