Files
sub2api-cn-relay-manager/internal/store/sqlite/logical_groups_repo.go
2026-05-28 15:57:34 +08:00

250 lines
6.9 KiB
Go

package sqlite
import (
"context"
"fmt"
"strings"
)
// Fallback values that normalizeLogicalGroup substitutes when the
// corresponding LogicalGroup field is empty (strings) or non-positive (ints).
const (
	// Used when RoutePolicy is blank after trimming.
	defaultLogicalGroupRoutePolicy = "priority"
	// Used when StickyMode is blank after trimming.
	defaultLogicalGroupStickyMode = "conversation_preferred"

	// Used when ConversationTTLSeconds is <= 0.
	defaultConversationTTLSeconds = 7200
	// Used when UserModelTTLSeconds is <= 0.
	defaultUserModelTTLSeconds = 1800
	// Used when FailoverThreshold is <= 0.
	defaultFailoverThreshold = 2
	// Used when CooldownSeconds is <= 0.
	defaultCooldownSeconds = 600
)
// LogicalGroup mirrors one row of the logical_groups table as read and
// written by LogicalGroupsRepo.
type LogicalGroup struct {
	// ID is scanned from the table's id column; Create returns it as the
	// last-insert id.
	ID int64
	// LogicalGroupID is the string key used by the lookup, update and
	// delete queries (logical_group_id column).
	LogicalGroupID string

	// Required descriptive fields (display_name, status) plus the optional
	// description column.
	DisplayName string
	Status      string
	Description string

	// Routing settings; blank values are replaced with package defaults
	// before a row is written.
	RoutePolicy string
	StickyMode  string

	// Numeric tuning values; non-positive values are replaced with package
	// defaults before a row is written.
	ConversationTTLSeconds int
	UserModelTTLSeconds    int
	FailoverThreshold      int
	CooldownSeconds        int

	// CreatedAt and UpdatedAt are scanned as strings from the created_at
	// and updated_at columns; they are never bound as write parameters.
	CreatedAt string
	UpdatedAt string
}
// LogicalGroupsRepo performs create, read, list, update and delete
// operations against the logical_groups table.
type LogicalGroupsRepo struct {
	// db is the handle every query runs through.
	// NOTE(review): execQuerier is declared elsewhere in the package;
	// presumably both *sql.DB and *sql.Tx satisfy it — confirm.
	db execQuerier
}
// newLogicalGroupsRepo returns a LogicalGroupsRepo whose queries run
// through db. The handle is stored as-is; no validation is performed.
func newLogicalGroupsRepo(db execQuerier) *LogicalGroupsRepo {
	repo := &LogicalGroupsRepo{}
	repo.db = db
	return repo
}
// Create normalizes group and inserts it as a new logical_groups row.
//
// It returns the last-insert id reported for the statement. Validation
// failures from normalizeLogicalGroup are returned unchanged; insert and
// id-read failures are wrapped with the logical group id for context.
func (r *LogicalGroupsRepo) Create(ctx context.Context, group LogicalGroup) (int64, error) {
	normalized, err := normalizeLogicalGroup(group)
	if err != nil {
		return 0, err
	}

	// Only the ten caller-supplied columns are listed; id, created_at and
	// updated_at are not bound by this statement.
	const insertSQL = `INSERT INTO logical_groups (
logical_group_id,
display_name,
status,
description,
route_policy,
sticky_mode,
conversation_ttl_seconds,
user_model_ttl_seconds,
failover_threshold,
cooldown_seconds
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	res, err := r.db.ExecContext(
		ctx,
		insertSQL,
		normalized.LogicalGroupID,
		normalized.DisplayName,
		normalized.Status,
		normalized.Description,
		normalized.RoutePolicy,
		normalized.StickyMode,
		normalized.ConversationTTLSeconds,
		normalized.UserModelTTLSeconds,
		normalized.FailoverThreshold,
		normalized.CooldownSeconds,
	)
	if err != nil {
		return 0, fmt.Errorf("insert logical group %q: %w", normalized.LogicalGroupID, err)
	}

	newID, err := res.LastInsertId()
	if err != nil {
		return 0, fmt.Errorf("read inserted logical group id for %q: %w", normalized.LogicalGroupID, err)
	}
	return newID, nil
}
// GetByLogicalGroupID loads the logical_groups row whose logical_group_id
// equals the trimmed logicalGroupID.
//
// A blank id is rejected before any query runs. On any failure the zero
// LogicalGroup is returned.
//
// NOTE(review): unlike the other methods on this repo, the Scan error is
// returned unwrapped. Presumably this lets callers detect a missing row
// (e.g. sql.ErrNoRows) directly — confirm before adding context to it.
func (r *LogicalGroupsRepo) GetByLogicalGroupID(ctx context.Context, logicalGroupID string) (LogicalGroup, error) {
	key := strings.TrimSpace(logicalGroupID)
	if key == "" {
		return LogicalGroup{}, fmt.Errorf("logical_group_id is required")
	}

	const selectSQL = `SELECT id, logical_group_id, display_name, status, description, route_policy, sticky_mode, conversation_ttl_seconds, user_model_ttl_seconds, failover_threshold, cooldown_seconds, created_at, updated_at
FROM logical_groups
WHERE logical_group_id = ?`

	var found LogicalGroup
	row := r.db.QueryRowContext(ctx, selectSQL, key)
	// Destination order must match the SELECT column order above.
	scanErr := row.Scan(
		&found.ID,
		&found.LogicalGroupID,
		&found.DisplayName,
		&found.Status,
		&found.Description,
		&found.RoutePolicy,
		&found.StickyMode,
		&found.ConversationTTLSeconds,
		&found.UserModelTTLSeconds,
		&found.FailoverThreshold,
		&found.CooldownSeconds,
		&found.CreatedAt,
		&found.UpdatedAt,
	)
	if scanErr != nil {
		return LogicalGroup{}, scanErr
	}
	return found, nil
}
// List returns every row of logical_groups ordered by id ascending.
//
// The returned slice is non-nil even when the table is empty. Query, scan
// and iteration failures are each wrapped with a distinct message and
// result in a nil slice.
func (r *LogicalGroupsRepo) List(ctx context.Context) ([]LogicalGroup, error) {
	const listSQL = `SELECT id, logical_group_id, display_name, status, description, route_policy, sticky_mode, conversation_ttl_seconds, user_model_ttl_seconds, failover_threshold, cooldown_seconds, created_at, updated_at
FROM logical_groups
ORDER BY id ASC`

	rows, err := r.db.QueryContext(ctx, listSQL)
	if err != nil {
		return nil, fmt.Errorf("list logical groups: %w", err)
	}
	defer rows.Close()

	// Start from an empty, non-nil slice so "no rows" is not reported as nil.
	result := []LogicalGroup{}
	for rows.Next() {
		var item LogicalGroup
		// Destination order must match the SELECT column order above.
		scanErr := rows.Scan(
			&item.ID,
			&item.LogicalGroupID,
			&item.DisplayName,
			&item.Status,
			&item.Description,
			&item.RoutePolicy,
			&item.StickyMode,
			&item.ConversationTTLSeconds,
			&item.UserModelTTLSeconds,
			&item.FailoverThreshold,
			&item.CooldownSeconds,
			&item.CreatedAt,
			&item.UpdatedAt,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("scan logical group: %w", scanErr)
		}
		result = append(result, item)
	}

	// rows.Next returning false may hide an iteration error; surface it.
	if iterErr := rows.Err(); iterErr != nil {
		return nil, fmt.Errorf("iterate logical groups: %w", iterErr)
	}
	return result, nil
}
// UpdateByLogicalGroupID normalizes group and overwrites every mutable
// column of the row whose logical_group_id matches, also setting
// updated_at to CURRENT_TIMESTAMP.
//
// Validation failures from normalizeLogicalGroup are returned unchanged.
// An update that affects zero rows is reported as a "not found" error.
func (r *LogicalGroupsRepo) UpdateByLogicalGroupID(ctx context.Context, group LogicalGroup) error {
	normalized, err := normalizeLogicalGroup(group)
	if err != nil {
		return err
	}

	const updateSQL = `UPDATE logical_groups
SET display_name = ?, status = ?, description = ?, route_policy = ?, sticky_mode = ?, conversation_ttl_seconds = ?, user_model_ttl_seconds = ?, failover_threshold = ?, cooldown_seconds = ?, updated_at = CURRENT_TIMESTAMP
WHERE logical_group_id = ?`

	// Argument order follows the SET list, with the WHERE key last.
	res, err := r.db.ExecContext(
		ctx,
		updateSQL,
		normalized.DisplayName,
		normalized.Status,
		normalized.Description,
		normalized.RoutePolicy,
		normalized.StickyMode,
		normalized.ConversationTTLSeconds,
		normalized.UserModelTTLSeconds,
		normalized.FailoverThreshold,
		normalized.CooldownSeconds,
		normalized.LogicalGroupID,
	)
	if err != nil {
		return fmt.Errorf("update logical group %q: %w", normalized.LogicalGroupID, err)
	}

	count, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("read updated logical group rows for %q: %w", normalized.LogicalGroupID, err)
	case count == 0:
		return fmt.Errorf("logical group %q not found", normalized.LogicalGroupID)
	}
	return nil
}
// DeleteByLogicalGroupID removes the logical_groups row whose
// logical_group_id equals the trimmed logicalGroupID.
//
// A blank id is rejected before any statement runs. A delete that affects
// zero rows is reported as a "not found" error.
func (r *LogicalGroupsRepo) DeleteByLogicalGroupID(ctx context.Context, logicalGroupID string) error {
	key := strings.TrimSpace(logicalGroupID)
	if key == "" {
		return fmt.Errorf("logical_group_id is required")
	}

	const deleteSQL = `DELETE FROM logical_groups WHERE logical_group_id = ?`

	res, err := r.db.ExecContext(ctx, deleteSQL, key)
	if err != nil {
		return fmt.Errorf("delete logical group %q: %w", key, err)
	}

	count, err := res.RowsAffected()
	switch {
	case err != nil:
		return fmt.Errorf("read deleted logical group rows for %q: %w", key, err)
	case count == 0:
		return fmt.Errorf("logical group %q not found", key)
	}
	return nil
}
// normalizeLogicalGroup returns a cleaned copy of group ready to be written.
//
// All string fields except CreatedAt and UpdatedAt are whitespace-trimmed.
// LogicalGroupID, DisplayName and Status must be non-empty after trimming;
// they are checked in that order and the first failure is returned together
// with the zero LogicalGroup. A blank RoutePolicy or StickyMode and any
// non-positive numeric setting are replaced with the package defaults.
func normalizeLogicalGroup(group LogicalGroup) (LogicalGroup, error) {
	// group is a by-value copy, so editing it through pointers below does
	// not touch the caller's struct.
	trimmed := []*string{
		&group.LogicalGroupID,
		&group.DisplayName,
		&group.Status,
		&group.Description,
		&group.RoutePolicy,
		&group.StickyMode,
	}
	for _, field := range trimmed {
		*field = strings.TrimSpace(*field)
	}

	if group.LogicalGroupID == "" {
		return LogicalGroup{}, fmt.Errorf("logical_group_id is required")
	}
	if group.DisplayName == "" {
		return LogicalGroup{}, fmt.Errorf("display_name is required")
	}
	if group.Status == "" {
		return LogicalGroup{}, fmt.Errorf("status is required")
	}

	if group.RoutePolicy == "" {
		group.RoutePolicy = defaultLogicalGroupRoutePolicy
	}
	if group.StickyMode == "" {
		group.StickyMode = defaultLogicalGroupStickyMode
	}

	// Zero and negative values both fall back to the default.
	numeric := []struct {
		field    *int
		fallback int
	}{
		{&group.ConversationTTLSeconds, defaultConversationTTLSeconds},
		{&group.UserModelTTLSeconds, defaultUserModelTTLSeconds},
		{&group.FailoverThreshold, defaultFailoverThreshold},
		{&group.CooldownSeconds, defaultCooldownSeconds},
	}
	for _, n := range numeric {
		if *n.field <= 0 {
			*n.field = n.fallback
		}
	}

	return group, nil
}