250 lines
6.9 KiB
Go
250 lines
6.9 KiB
Go
package sqlite
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"strings"
|
|
)
|
|
|
|
// Fallback values that normalizeLogicalGroup substitutes when a LogicalGroup
// leaves a string setting blank or an integer setting zero/negative.
const (
	// defaultLogicalGroupRoutePolicy is used when RoutePolicy is blank.
	defaultLogicalGroupRoutePolicy = "priority"

	// defaultLogicalGroupStickyMode is used when StickyMode is blank.
	defaultLogicalGroupStickyMode = "conversation_preferred"

	// defaultConversationTTLSeconds is two hours, in seconds.
	defaultConversationTTLSeconds = 2 * 60 * 60

	// defaultUserModelTTLSeconds is thirty minutes, in seconds.
	defaultUserModelTTLSeconds = 30 * 60

	// defaultFailoverThreshold is used when FailoverThreshold is not positive.
	defaultFailoverThreshold = 2

	// defaultCooldownSeconds is ten minutes, in seconds.
	defaultCooldownSeconds = 10 * 60
)
|
|
|
|
// LogicalGroup mirrors one row of the logical_groups table as read and
// written by LogicalGroupsRepo.
type LogicalGroup struct {
	// ID holds the id column. It is filled in on reads and not used by
	// Create or UpdateByLogicalGroupID.
	ID int64

	// LogicalGroupID is the logical_group_id column: the key callers use to
	// look up, update and delete a group. Writes trim it and reject blanks.
	LogicalGroupID string

	// DisplayName is trimmed and must be non-blank on writes.
	DisplayName string

	// Status is trimmed and must be non-blank on writes. Its allowed values
	// are not checked by this repository.
	Status string

	// Description is optional; writes trim it.
	Description string

	// RoutePolicy falls back to defaultLogicalGroupRoutePolicy when blank.
	RoutePolicy string

	// StickyMode falls back to defaultLogicalGroupStickyMode when blank.
	StickyMode string

	// ConversationTTLSeconds falls back to defaultConversationTTLSeconds when
	// zero or negative.
	ConversationTTLSeconds int

	// UserModelTTLSeconds falls back to defaultUserModelTTLSeconds when zero
	// or negative.
	UserModelTTLSeconds int

	// FailoverThreshold falls back to defaultFailoverThreshold when zero or
	// negative.
	FailoverThreshold int

	// CooldownSeconds falls back to defaultCooldownSeconds when zero or
	// negative.
	CooldownSeconds int

	// CreatedAt and UpdatedAt are scanned back from the database as strings.
	// Writes never send them; UpdateByLogicalGroupID sets updated_at to
	// CURRENT_TIMESTAMP in SQL.
	CreatedAt string
	UpdatedAt string
}
|
|
|
|
type LogicalGroupsRepo struct {
|
|
db execQuerier
|
|
}
|
|
|
|
func newLogicalGroupsRepo(db execQuerier) *LogicalGroupsRepo {
|
|
return &LogicalGroupsRepo{db: db}
|
|
}
|
|
|
|
func (r *LogicalGroupsRepo) Create(ctx context.Context, group LogicalGroup) (int64, error) {
|
|
group, err := normalizeLogicalGroup(group)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
result, err := r.db.ExecContext(
|
|
ctx,
|
|
`INSERT INTO logical_groups (
|
|
logical_group_id,
|
|
display_name,
|
|
status,
|
|
description,
|
|
route_policy,
|
|
sticky_mode,
|
|
conversation_ttl_seconds,
|
|
user_model_ttl_seconds,
|
|
failover_threshold,
|
|
cooldown_seconds
|
|
) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
group.LogicalGroupID,
|
|
group.DisplayName,
|
|
group.Status,
|
|
group.Description,
|
|
group.RoutePolicy,
|
|
group.StickyMode,
|
|
group.ConversationTTLSeconds,
|
|
group.UserModelTTLSeconds,
|
|
group.FailoverThreshold,
|
|
group.CooldownSeconds,
|
|
)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("insert logical group %q: %w", group.LogicalGroupID, err)
|
|
}
|
|
|
|
id, err := result.LastInsertId()
|
|
if err != nil {
|
|
return 0, fmt.Errorf("read inserted logical group id for %q: %w", group.LogicalGroupID, err)
|
|
}
|
|
return id, nil
|
|
}
|
|
|
|
func (r *LogicalGroupsRepo) GetByLogicalGroupID(ctx context.Context, logicalGroupID string) (LogicalGroup, error) {
|
|
logicalGroupID = strings.TrimSpace(logicalGroupID)
|
|
if logicalGroupID == "" {
|
|
return LogicalGroup{}, fmt.Errorf("logical_group_id is required")
|
|
}
|
|
|
|
var group LogicalGroup
|
|
if err := r.db.QueryRowContext(
|
|
ctx,
|
|
`SELECT id, logical_group_id, display_name, status, description, route_policy, sticky_mode, conversation_ttl_seconds, user_model_ttl_seconds, failover_threshold, cooldown_seconds, created_at, updated_at
|
|
FROM logical_groups
|
|
WHERE logical_group_id = ?`,
|
|
logicalGroupID,
|
|
).Scan(
|
|
&group.ID,
|
|
&group.LogicalGroupID,
|
|
&group.DisplayName,
|
|
&group.Status,
|
|
&group.Description,
|
|
&group.RoutePolicy,
|
|
&group.StickyMode,
|
|
&group.ConversationTTLSeconds,
|
|
&group.UserModelTTLSeconds,
|
|
&group.FailoverThreshold,
|
|
&group.CooldownSeconds,
|
|
&group.CreatedAt,
|
|
&group.UpdatedAt,
|
|
); err != nil {
|
|
return LogicalGroup{}, err
|
|
}
|
|
return group, nil
|
|
}
|
|
|
|
func (r *LogicalGroupsRepo) List(ctx context.Context) ([]LogicalGroup, error) {
|
|
rows, err := r.db.QueryContext(
|
|
ctx,
|
|
`SELECT id, logical_group_id, display_name, status, description, route_policy, sticky_mode, conversation_ttl_seconds, user_model_ttl_seconds, failover_threshold, cooldown_seconds, created_at, updated_at
|
|
FROM logical_groups
|
|
ORDER BY id ASC`,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list logical groups: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
groups := make([]LogicalGroup, 0)
|
|
for rows.Next() {
|
|
var group LogicalGroup
|
|
if err := rows.Scan(
|
|
&group.ID,
|
|
&group.LogicalGroupID,
|
|
&group.DisplayName,
|
|
&group.Status,
|
|
&group.Description,
|
|
&group.RoutePolicy,
|
|
&group.StickyMode,
|
|
&group.ConversationTTLSeconds,
|
|
&group.UserModelTTLSeconds,
|
|
&group.FailoverThreshold,
|
|
&group.CooldownSeconds,
|
|
&group.CreatedAt,
|
|
&group.UpdatedAt,
|
|
); err != nil {
|
|
return nil, fmt.Errorf("scan logical group: %w", err)
|
|
}
|
|
groups = append(groups, group)
|
|
}
|
|
if err := rows.Err(); err != nil {
|
|
return nil, fmt.Errorf("iterate logical groups: %w", err)
|
|
}
|
|
return groups, nil
|
|
}
|
|
|
|
func (r *LogicalGroupsRepo) UpdateByLogicalGroupID(ctx context.Context, group LogicalGroup) error {
|
|
group, err := normalizeLogicalGroup(group)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
result, err := r.db.ExecContext(
|
|
ctx,
|
|
`UPDATE logical_groups
|
|
SET display_name = ?, status = ?, description = ?, route_policy = ?, sticky_mode = ?, conversation_ttl_seconds = ?, user_model_ttl_seconds = ?, failover_threshold = ?, cooldown_seconds = ?, updated_at = CURRENT_TIMESTAMP
|
|
WHERE logical_group_id = ?`,
|
|
group.DisplayName,
|
|
group.Status,
|
|
group.Description,
|
|
group.RoutePolicy,
|
|
group.StickyMode,
|
|
group.ConversationTTLSeconds,
|
|
group.UserModelTTLSeconds,
|
|
group.FailoverThreshold,
|
|
group.CooldownSeconds,
|
|
group.LogicalGroupID,
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("update logical group %q: %w", group.LogicalGroupID, err)
|
|
}
|
|
affected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("read updated logical group rows for %q: %w", group.LogicalGroupID, err)
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("logical group %q not found", group.LogicalGroupID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (r *LogicalGroupsRepo) DeleteByLogicalGroupID(ctx context.Context, logicalGroupID string) error {
|
|
logicalGroupID = strings.TrimSpace(logicalGroupID)
|
|
if logicalGroupID == "" {
|
|
return fmt.Errorf("logical_group_id is required")
|
|
}
|
|
|
|
result, err := r.db.ExecContext(ctx, `DELETE FROM logical_groups WHERE logical_group_id = ?`, logicalGroupID)
|
|
if err != nil {
|
|
return fmt.Errorf("delete logical group %q: %w", logicalGroupID, err)
|
|
}
|
|
affected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return fmt.Errorf("read deleted logical group rows for %q: %w", logicalGroupID, err)
|
|
}
|
|
if affected == 0 {
|
|
return fmt.Errorf("logical group %q not found", logicalGroupID)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func normalizeLogicalGroup(group LogicalGroup) (LogicalGroup, error) {
|
|
group.LogicalGroupID = strings.TrimSpace(group.LogicalGroupID)
|
|
group.DisplayName = strings.TrimSpace(group.DisplayName)
|
|
group.Status = strings.TrimSpace(group.Status)
|
|
group.Description = strings.TrimSpace(group.Description)
|
|
group.RoutePolicy = strings.TrimSpace(group.RoutePolicy)
|
|
group.StickyMode = strings.TrimSpace(group.StickyMode)
|
|
|
|
switch {
|
|
case group.LogicalGroupID == "":
|
|
return LogicalGroup{}, fmt.Errorf("logical_group_id is required")
|
|
case group.DisplayName == "":
|
|
return LogicalGroup{}, fmt.Errorf("display_name is required")
|
|
case group.Status == "":
|
|
return LogicalGroup{}, fmt.Errorf("status is required")
|
|
}
|
|
|
|
if group.RoutePolicy == "" {
|
|
group.RoutePolicy = defaultLogicalGroupRoutePolicy
|
|
}
|
|
if group.StickyMode == "" {
|
|
group.StickyMode = defaultLogicalGroupStickyMode
|
|
}
|
|
if group.ConversationTTLSeconds <= 0 {
|
|
group.ConversationTTLSeconds = defaultConversationTTLSeconds
|
|
}
|
|
if group.UserModelTTLSeconds <= 0 {
|
|
group.UserModelTTLSeconds = defaultUserModelTTLSeconds
|
|
}
|
|
if group.FailoverThreshold <= 0 {
|
|
group.FailoverThreshold = defaultFailoverThreshold
|
|
}
|
|
if group.CooldownSeconds <= 0 {
|
|
group.CooldownSeconds = defaultCooldownSeconds
|
|
}
|
|
|
|
return group, nil
|
|
}
|