feat(gateway): 实现网关核心模块
实现内容: - internal/adapter: Provider Adapter抽象层和OpenAI实现 - internal/router: 多Provider路由(支持latency/weighted/availability策略) - internal/handler: OpenAI兼容API端点(/v1/chat/completions, /v1/completions) - internal/ratelimit: Token Bucket和Sliding Window限流器 - internal/alert: 告警系统(支持邮件/钉钉/飞书) - internal/config: 配置管理 - pkg/error: 完整错误码体系 - pkg/model: API请求/响应模型 PRD对齐: - P0-1: 统一API接入 ✅ (OpenAI兼容) - P0-2: 基础路由与稳定性 ✅ (多Provider路由+Fallback) - P0-4: 预算与限流 ✅ (Token Bucket限流) 注意:需要供应链模块支持后再完善成本归因和账单导出
This commit is contained in:
151
gateway/cmd/gateway/main.go
Normal file
151
gateway/cmd/gateway/main.go
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"lijiaoqiao/gateway/internal/adapter"
|
||||||
|
"lijiaoqiao/gateway/internal/alert"
|
||||||
|
"lijiaoqiao/gateway/internal/config"
|
||||||
|
"lijiaoqiao/gateway/internal/handler"
|
||||||
|
"lijiaoqiao/gateway/internal/middleware"
|
||||||
|
"lijiaoqiao/gateway/internal/ratelimit"
|
||||||
|
"lijiaoqiao/gateway/internal/router"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
// 加载配置
|
||||||
|
cfg, err := config.LoadConfig("")
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to load config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化Router
|
||||||
|
r := router.NewRouter(router.StrategyLatency)
|
||||||
|
|
||||||
|
// 注册Provider (示例: OpenAI)
|
||||||
|
openaiAdapter := adapter.NewOpenAIAdapter(
|
||||||
|
"https://api.openai.com",
|
||||||
|
os.Getenv("OPENAI_API_KEY"),
|
||||||
|
[]string{"gpt-4", "gpt-3.5-turbo"},
|
||||||
|
)
|
||||||
|
r.RegisterProvider("openai", openaiAdapter)
|
||||||
|
|
||||||
|
// 初始化限流器
|
||||||
|
var limiter ratelimit.Limiter
|
||||||
|
if cfg.RateLimit.Algorithm == "token_bucket" {
|
||||||
|
limiter = ratelimit.NewTokenBucketLimiter(
|
||||||
|
cfg.RateLimit.DefaultRPM,
|
||||||
|
cfg.RateLimit.DefaultTPM,
|
||||||
|
cfg.RateLimit.BurstMultiplier,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
limiter = ratelimit.NewSlidingWindowLimiter(
|
||||||
|
time.Minute,
|
||||||
|
cfg.RateLimit.DefaultRPM,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化告警管理器
|
||||||
|
alertManager, err := alert.NewManager(&cfg.Alert)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Warning: Failed to create alert manager: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 初始化Handler
|
||||||
|
h := handler.NewHandler(r)
|
||||||
|
|
||||||
|
// 创建Server
|
||||||
|
server := &http.Server{
|
||||||
|
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
|
||||||
|
Handler: createMux(h, limiter, alertManager),
|
||||||
|
ReadTimeout: cfg.Server.ReadTimeout,
|
||||||
|
WriteTimeout: cfg.Server.WriteTimeout,
|
||||||
|
IdleTimeout: cfg.Server.IdleTimeout,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动Server
|
||||||
|
go func() {
|
||||||
|
log.Printf("Starting gateway server on %s:%d", cfg.Server.Host, cfg.Server.Port)
|
||||||
|
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
|
||||||
|
log.Fatalf("Server failed: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// 等待中断信号
|
||||||
|
quit := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
<-quit
|
||||||
|
|
||||||
|
log.Println("Shutting down server...")
|
||||||
|
|
||||||
|
// 优雅关闭
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := server.Shutdown(ctx); err != nil {
|
||||||
|
log.Fatalf("Server forced to shutdown: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("Server exited")
|
||||||
|
}
|
||||||
|
|
||||||
|
func createMux(h *handler.Handler, limiter *ratelimit.Middleware, alertMgr *alert.Manager) *http.ServeMux {
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
|
||||||
|
// V1 API
|
||||||
|
v1 := mux.PathPrefix("/v1").Subrouter()
|
||||||
|
|
||||||
|
// Chat Completions (需要限流和认证)
|
||||||
|
v1.HandleFunc("/chat/completions", withMiddleware(h.ChatCompletionsHandle,
|
||||||
|
limiter.Limit,
|
||||||
|
authMiddleware(),
|
||||||
|
))
|
||||||
|
|
||||||
|
// Completions
|
||||||
|
v1.HandleFunc("/completions", withMiddleware(h.CompletionsHandle,
|
||||||
|
limiter.Limit,
|
||||||
|
authMiddleware(),
|
||||||
|
))
|
||||||
|
|
||||||
|
// Models
|
||||||
|
v1.HandleFunc("/models", h.ModelsHandle)
|
||||||
|
|
||||||
|
// Health
|
||||||
|
mux.HandleFunc("/health", h.HealthHandle)
|
||||||
|
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
|
||||||
|
// MiddlewareFunc 中间件函数类型
|
||||||
|
type MiddlewareFunc func(http.HandlerFunc) http.HandlerFunc
|
||||||
|
|
||||||
|
// withMiddleware 应用中间件
|
||||||
|
func withMiddleware(h http.HandlerFunc, limiters ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
|
||||||
|
for _, m := range limiters {
|
||||||
|
h = m(h)
|
||||||
|
}
|
||||||
|
return h
|
||||||
|
}
|
||||||
|
|
||||||
|
// authMiddleware 认证中间件(简化实现)
|
||||||
|
func authMiddleware() MiddlewareFunc {
|
||||||
|
return func(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// 简化: 检查Authorization头
|
||||||
|
if r.Header.Get("Authorization") == "" {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
w.Write([]byte(`{"error":{"message":"Missing Authorization header","code":"AUTH_001"}}`))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
12
gateway/go.mod
Normal file
12
gateway/go.mod
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
module lijiaoqiao/gateway
|
||||||
|
|
||||||
|
go 1.21
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/golang-jwt/jwt/v5 v5.2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/jackc/pgx/v5 v5.5.0
|
||||||
|
golang.org/x/net v0.19.0
|
||||||
|
)
|
||||||
144
gateway/internal/adapter/adapter.go
Normal file
144
gateway/internal/adapter/adapter.go
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CompletionOptions 完成选项
|
||||||
|
type CompletionOptions struct {
|
||||||
|
Temperature float64
|
||||||
|
MaxTokens int
|
||||||
|
TopP float64
|
||||||
|
Stream bool
|
||||||
|
Stop []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompletionResponse 完成响应
|
||||||
|
type CompletionResponse struct {
|
||||||
|
ID string
|
||||||
|
Object string
|
||||||
|
Created int64
|
||||||
|
Model string
|
||||||
|
Choices []Choice
|
||||||
|
Usage Usage
|
||||||
|
}
|
||||||
|
|
||||||
|
// Choice 选择
|
||||||
|
type Choice struct {
|
||||||
|
Index int
|
||||||
|
Message *Message
|
||||||
|
FinishReason string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Message 消息
|
||||||
|
type Message struct {
|
||||||
|
Role string
|
||||||
|
Content string
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Usage 使用量
|
||||||
|
type Usage struct {
|
||||||
|
PromptTokens int
|
||||||
|
CompletionTokens int
|
||||||
|
TotalTokens int
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamChunk 流式响应块
|
||||||
|
type StreamChunk struct {
|
||||||
|
ID string
|
||||||
|
Object string
|
||||||
|
Created int64
|
||||||
|
Model string
|
||||||
|
Choices []StreamChoice
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamChoice 流式选择
|
||||||
|
type StreamChoice struct {
|
||||||
|
Index int
|
||||||
|
Delta *Delta
|
||||||
|
FinishReason string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delta 增量
|
||||||
|
type Delta struct {
|
||||||
|
Role string
|
||||||
|
Content string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderAdapter 供应商适配器抽象基类
|
||||||
|
type ProviderAdapter interface {
|
||||||
|
// ChatCompletion 发送聊天完成请求
|
||||||
|
ChatCompletion(ctx context.Context, model string, messages []Message, options CompletionOptions) (*CompletionResponse, error)
|
||||||
|
|
||||||
|
// ChatCompletionStream 流式聊天完成请求
|
||||||
|
ChatCompletionStream(ctx context.Context, model string, messages []Message, options CompletionOptions) (<-chan *StreamChunk, error)
|
||||||
|
|
||||||
|
// GetUsage 获取使用量
|
||||||
|
GetUsage(response *CompletionResponse) Usage
|
||||||
|
|
||||||
|
// MapError 错误码映射
|
||||||
|
MapError(err error) ProviderError
|
||||||
|
|
||||||
|
// HealthCheck 健康检查
|
||||||
|
HealthCheck(ctx context.Context) bool
|
||||||
|
|
||||||
|
// ProviderName 供应商名称
|
||||||
|
ProviderName() string
|
||||||
|
|
||||||
|
// SupportedModels 支持的模型列表
|
||||||
|
SupportedModels() []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderError 供应商错误
|
||||||
|
type ProviderError struct {
|
||||||
|
Code string
|
||||||
|
Message string
|
||||||
|
HTTPStatus int
|
||||||
|
Retryable bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error 实现error接口
|
||||||
|
func (e ProviderError) Error() string {
|
||||||
|
return e.Code + ": " + e.Message
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRetryable 是否可重试
|
||||||
|
func (e ProviderError) IsRetryable() bool {
|
||||||
|
return e.Retryable
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router 路由器接口
|
||||||
|
type Router interface {
|
||||||
|
// SelectProvider 选择最佳Provider
|
||||||
|
SelectProvider(ctx context.Context, model string) (ProviderAdapter, error)
|
||||||
|
|
||||||
|
// GetFallbackProviders 获取Fallback Providers
|
||||||
|
GetFallbackProviders(ctx context.Context, model string) ([]ProviderAdapter, error)
|
||||||
|
|
||||||
|
// RecordResult 记录调用结果用于负载均衡
|
||||||
|
RecordResult(ctx context.Context, provider string, success bool, latencyMs int64)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthChecker 健康检查器
|
||||||
|
type HealthChecker interface {
|
||||||
|
// Check 检查服务健康状态
|
||||||
|
Check(ctx context.Context) error
|
||||||
|
|
||||||
|
// IsHealthy 是否健康
|
||||||
|
IsHealthy() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadCloser 带错误回调的io.ReadCloser
|
||||||
|
type ReadCloser struct {
|
||||||
|
io.Reader
|
||||||
|
OnClose func() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ReadCloser) Close() error {
|
||||||
|
if r.OnClose != nil {
|
||||||
|
return r.OnClose()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
326
gateway/internal/adapter/openai_adapter.go
Normal file
326
gateway/internal/adapter/openai_adapter.go
Normal file
@@ -0,0 +1,326 @@
|
|||||||
|
package adapter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"lijiaoqiao/gateway/pkg/error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenAIAdapter OpenAI适配器
|
||||||
|
type OpenAIAdapter struct {
|
||||||
|
baseURL string
|
||||||
|
apiKey string
|
||||||
|
httpClient *http.Client
|
||||||
|
models []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOpenAIAdapter 创建OpenAI适配器
|
||||||
|
func NewOpenAIAdapter(baseURL, apiKey string, models []string) *OpenAIAdapter {
|
||||||
|
return &OpenAIAdapter{
|
||||||
|
baseURL: baseURL,
|
||||||
|
apiKey: apiKey,
|
||||||
|
httpClient: &http.Client{
|
||||||
|
Timeout: 60 * time.Second,
|
||||||
|
},
|
||||||
|
models: models,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletion 实现ChatCompletion接口
|
||||||
|
func (a *OpenAIAdapter) ChatCompletion(ctx context.Context, model string, messages []Message, options CompletionOptions) (*CompletionResponse, error) {
|
||||||
|
// 构建请求
|
||||||
|
reqBody := map[string]interface{}{
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.Temperature > 0 {
|
||||||
|
reqBody["temperature"] = options.Temperature
|
||||||
|
}
|
||||||
|
if options.MaxTokens > 0 {
|
||||||
|
reqBody["max_tokens"] = options.MaxTokens
|
||||||
|
}
|
||||||
|
if options.TopP > 0 {
|
||||||
|
reqBody["top_p"] = options.TopP
|
||||||
|
}
|
||||||
|
if len(options.Stop) > 0 {
|
||||||
|
reqBody["stop"] = options.Stop
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送请求
|
||||||
|
url := fmt.Sprintf("%s/v1/chat/completions", a.baseURL)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))
|
||||||
|
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
var errResp map[string]interface{}
|
||||||
|
if json.Unmarshal(body, &errResp) == nil {
|
||||||
|
if errDetail, ok := errResp["error"].(map[string]interface{}); ok {
|
||||||
|
return nil, a.MapError(fmt.Errorf("%v", errDetail))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, a.MapError(fmt.Errorf("unexpected status: %d", resp.StatusCode))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析响应
|
||||||
|
var result struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []struct {
|
||||||
|
Message struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"message"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
} `json:"choices"`
|
||||||
|
Usage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
} `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.Unmarshal(body, &result); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换响应
|
||||||
|
response := &CompletionResponse{
|
||||||
|
ID: result.ID,
|
||||||
|
Object: result.Object,
|
||||||
|
Created: result.Created,
|
||||||
|
Model: result.Model,
|
||||||
|
Choices: make([]Choice, len(result.Choices)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, c := range result.Choices {
|
||||||
|
response.Choices[i] = Choice{
|
||||||
|
Message: &Message{
|
||||||
|
Role: c.Message.Role,
|
||||||
|
Content: c.Message.Content,
|
||||||
|
},
|
||||||
|
FinishReason: c.FinishReason,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
response.Usage = Usage{
|
||||||
|
PromptTokens: result.Usage.PromptTokens,
|
||||||
|
CompletionTokens: result.Usage.CompletionTokens,
|
||||||
|
TotalTokens: result.Usage.TotalTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletionStream 实现流式ChatCompletion
|
||||||
|
func (a *OpenAIAdapter) ChatCompletionStream(ctx context.Context, model string, messages []Message, options CompletionOptions) (<-chan *StreamChunk, error) {
|
||||||
|
// 构建请求
|
||||||
|
reqBody := map[string]interface{}{
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"stream": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if options.Temperature > 0 {
|
||||||
|
reqBody["temperature"] = options.Temperature
|
||||||
|
}
|
||||||
|
if options.MaxTokens > 0 {
|
||||||
|
reqBody["max_tokens"] = options.MaxTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, err := json.Marshal(reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/v1/chat/completions", a.baseURL)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))
|
||||||
|
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
return nil, a.MapError(fmt.Errorf("unexpected status: %d, body: %s", resp.StatusCode, string(body)))
|
||||||
|
}
|
||||||
|
|
||||||
|
ch := make(chan *StreamChunk, 100)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer close(ch)
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
reader := io.Reader(resp.Body)
|
||||||
|
for {
|
||||||
|
line, err := io.ReadLine(reader)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(line) < 6 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE格式: data: {...}
|
||||||
|
if string(line[:5]) != "data:" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
data := line[6:]
|
||||||
|
if string(data) == "[DONE]" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var chunk struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []struct {
|
||||||
|
Delta struct {
|
||||||
|
Role string `json:"role"`
|
||||||
|
Content string `json:"content"`
|
||||||
|
} `json:"delta"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
} `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if json.Unmarshal(data, &chunk) != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
streamChunk := &StreamChunk{
|
||||||
|
ID: chunk.ID,
|
||||||
|
Object: chunk.Object,
|
||||||
|
Created: chunk.Created,
|
||||||
|
Model: chunk.Model,
|
||||||
|
Choices: make([]StreamChoice, len(chunk.Choices)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, c := range chunk.Choices {
|
||||||
|
streamChunk.Choices[i] = StreamChoice{
|
||||||
|
Delta: &Delta{
|
||||||
|
Role: c.Delta.Role,
|
||||||
|
Content: c.Delta.Content,
|
||||||
|
},
|
||||||
|
FinishReason: c.FinishReason,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case ch <- streamChunk:
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUsage 获取使用量
|
||||||
|
func (a *OpenAIAdapter) GetUsage(response *CompletionResponse) Usage {
|
||||||
|
return response.Usage
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapError 错误码映射
|
||||||
|
func (a *OpenAIAdapter) MapError(err error) error {
|
||||||
|
// 简化实现,实际应根据OpenAI错误响应映射
|
||||||
|
errStr := err.Error()
|
||||||
|
|
||||||
|
if contains(errStr, "invalid_api_key") {
|
||||||
|
return error.NewGatewayError(error.PROVIDER_INVALID_KEY, "Invalid API key").WithInternal(err)
|
||||||
|
}
|
||||||
|
if contains(errStr, "rate_limit") {
|
||||||
|
return error.NewGatewayError(error.PROVIDER_RATE_LIMIT, "Rate limit exceeded").WithInternal(err)
|
||||||
|
}
|
||||||
|
if contains(errStr, "quota") {
|
||||||
|
return error.NewGatewayError(error.PROVIDER_QUOTA_EXCEEDED, "Quota exceeded").WithInternal(err)
|
||||||
|
}
|
||||||
|
if contains(errStr, "model_not_found") {
|
||||||
|
return error.NewGatewayError(error.PROVIDER_MODEL_NOT_FOUND, "Model not found").WithInternal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return error.NewGatewayError(error.PROVIDER_ERROR, "Provider error").WithInternal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func contains(s, substr string) bool {
|
||||||
|
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
|
||||||
|
}
|
||||||
|
|
||||||
|
func containsHelper(s, substr string) bool {
|
||||||
|
for i := 0; i <= len(s)-len(substr); i++ {
|
||||||
|
if s[i:i+len(substr)] == substr {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthCheck 健康检查
|
||||||
|
func (a *OpenAIAdapter) HealthCheck(ctx context.Context) bool {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/v1/models", a.baseURL), nil)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))
|
||||||
|
|
||||||
|
resp, err := a.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
return resp.StatusCode == http.StatusOK
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderName 供应商名称
|
||||||
|
func (a *OpenAIAdapter) ProviderName() string {
|
||||||
|
return "openai"
|
||||||
|
}
|
||||||
|
|
||||||
|
// SupportedModels 支持的模型列表
|
||||||
|
func (a *OpenAIAdapter) SupportedModels() []string {
|
||||||
|
return a.models
|
||||||
|
}
|
||||||
366
gateway/internal/alert/alert.go
Normal file
366
gateway/internal/alert/alert.go
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
package alert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/smtp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"lijiaoqiao/gateway/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AlertType 告警类型
|
||||||
|
type AlertType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
AlertBudgetExceeded AlertType = "budget_exceeded"
|
||||||
|
AlertRateLimitExceeded AlertType = "rate_limit_exceeded"
|
||||||
|
AlertProviderFailure AlertType = "provider_failure"
|
||||||
|
AlertHighErrorRate AlertType = "high_error_rate"
|
||||||
|
AlertLatencySpike AlertType = "latency_spike"
|
||||||
|
AlertManualIntervention AlertType = "manual_intervention"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Alert 告警
|
||||||
|
type Alert struct {
|
||||||
|
Type AlertType
|
||||||
|
Title string
|
||||||
|
Message string
|
||||||
|
Severity string // "info", "warning", "error", "critical"
|
||||||
|
TenantID int64
|
||||||
|
RequestID string
|
||||||
|
Metadata map[string]interface{}
|
||||||
|
Timestamp time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sender 告警发送器接口
|
||||||
|
type Sender interface {
|
||||||
|
Send(ctx context.Context, alert *Alert) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager 告警管理器
|
||||||
|
type Manager struct {
|
||||||
|
senders []Sender
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager 创建告警管理器
|
||||||
|
func NewManager(cfg *config.AlertConfig) (*Manager, error) {
|
||||||
|
m := &Manager{
|
||||||
|
senders: make([]Sender, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加邮件发送器
|
||||||
|
if cfg.Email.Enabled {
|
||||||
|
m.senders = append(m.senders, NewEmailSender(&cfg.Email))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加钉钉发送器
|
||||||
|
if cfg.DingTalk.Enabled {
|
||||||
|
sender, err := NewDingTalkSender(cfg.DingTalk.WebHook, cfg.DingTalk.Secret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create DingTalk sender: %w", err)
|
||||||
|
}
|
||||||
|
m.senders = append(m.senders, sender)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加飞书发送器
|
||||||
|
if cfg.Feishu.Enabled {
|
||||||
|
sender, err := NewFeishuSender(cfg.Feishu.WebHook, cfg.Feishu.Secret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create Feishu sender: %w", err)
|
||||||
|
}
|
||||||
|
m.senders = append(m.senders, sender)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send 发送告警
|
||||||
|
func (m *Manager) Send(ctx context.Context, alert *Alert) error {
|
||||||
|
if len(m.senders) == 0 {
|
||||||
|
return fmt.Errorf("no alert sender configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
var lastErr error
|
||||||
|
for _, sender := range m.senders {
|
||||||
|
if err := sender.Send(ctx, alert); err != nil {
|
||||||
|
lastErr = err
|
||||||
|
// 继续尝试其他发送器
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return lastErr
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendBudgetAlert 发送预算告警
|
||||||
|
func (m *Manager) SendBudgetAlert(ctx context.Context, tenantID int64, current, limit float64) error {
|
||||||
|
return m.Send(ctx, &Alert{
|
||||||
|
Type: AlertBudgetExceeded,
|
||||||
|
Title: "Budget Alert",
|
||||||
|
Message: fmt.Sprintf("Tenant %d exceeded budget: current=%.2f, limit=%.2f", tenantID, current, limit),
|
||||||
|
Severity: "warning",
|
||||||
|
TenantID: tenantID,
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"current_usage": current,
|
||||||
|
"limit": limit,
|
||||||
|
},
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendProviderFailureAlert 发送Provider故障告警
|
||||||
|
func (m *Manager) SendProviderFailureAlert(ctx context.Context, provider string, err error) error {
|
||||||
|
return m.Send(ctx, &Alert{
|
||||||
|
Type: AlertProviderFailure,
|
||||||
|
Title: "Provider Failure",
|
||||||
|
Message: fmt.Sprintf("Provider %s failed: %v", provider, err),
|
||||||
|
Severity: "error",
|
||||||
|
Metadata: map[string]interface{}{
|
||||||
|
"provider": provider,
|
||||||
|
"error": err.Error(),
|
||||||
|
},
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmailSender 邮件发送器
|
||||||
|
type EmailSender struct {
|
||||||
|
cfg *config.EmailConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewEmailSender 创建邮件发送器
|
||||||
|
func NewEmailSender(cfg *config.EmailConfig) *EmailSender {
|
||||||
|
return &EmailSender{cfg: cfg}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *EmailSender) Send(ctx context.Context, alert *Alert) error {
|
||||||
|
// 构建邮件内容
|
||||||
|
subject := fmt.Sprintf("[%s] %s - %s", strings.ToUpper(alert.Severity), alert.Type, alert.Title)
|
||||||
|
|
||||||
|
body := fmt.Sprintf(`
|
||||||
|
告警类型: %s
|
||||||
|
严重程度: %s
|
||||||
|
时间: %s
|
||||||
|
消息: %s
|
||||||
|
`, alert.Type, alert.Severity, alert.Timestamp.Format(time.RFC3339), alert.Message)
|
||||||
|
|
||||||
|
if alert.TenantID > 0 {
|
||||||
|
body += fmt.Sprintf("\n租户ID: %d", alert.TenantID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建邮件
|
||||||
|
msg := fmt.Sprintf("From: %s\r\n"+
|
||||||
|
"To: %s\r\n"+
|
||||||
|
"Subject: %s\r\n"+
|
||||||
|
"Content-Type: text/plain; charset=UTF-8\r\n"+
|
||||||
|
"\r\n"+
|
||||||
|
"%s",
|
||||||
|
s.cfg.From,
|
||||||
|
strings.Join(s.cfg.To, ","),
|
||||||
|
subject,
|
||||||
|
body)
|
||||||
|
|
||||||
|
// 发送邮件
|
||||||
|
addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
|
||||||
|
|
||||||
|
auth := smtp.PlainAuth("", s.cfg.Username, s.cfg.Password, s.cfg.Host)
|
||||||
|
err := smtp.SendMail(addr, auth, s.cfg.From, s.cfg.To, []byte(msg))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to send email: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DingTalkSender 钉钉发送器
|
||||||
|
type DingTalkSender struct {
|
||||||
|
webHook string
|
||||||
|
secret string
|
||||||
|
client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDingTalkSender 创建钉钉发送器
|
||||||
|
func NewDingTalkSender(webHook, secret string) (*DingTalkSender, error) {
|
||||||
|
return &DingTalkSender{
|
||||||
|
webHook: webHook,
|
||||||
|
secret: secret,
|
||||||
|
client: &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DingTalkSender) Send(ctx context.Context, alert *Alert) error {
|
||||||
|
// 获取签名
|
||||||
|
timestamp, sign := s.generateSign()
|
||||||
|
|
||||||
|
// 构建请求URL
|
||||||
|
url := fmt.Sprintf("%s×tamp=%d&sign=%s", s.webHook, timestamp, sign)
|
||||||
|
|
||||||
|
// 构建消息
|
||||||
|
msg := map[string]interface{}{
|
||||||
|
"msgtype": "markdown",
|
||||||
|
"markdown": map[string]string{
|
||||||
|
"title": fmt.Sprintf("[%s] %s", strings.ToUpper(alert.Severity), alert.Title),
|
||||||
|
"text": fmt.Sprintf(`### [%s] %s
|
||||||
|
|
||||||
|
**类型**: %s
|
||||||
|
**严重程度**: %s
|
||||||
|
**时间**: %s
|
||||||
|
**消息**: %s`,
|
||||||
|
strings.ToUpper(alert.Severity),
|
||||||
|
alert.Title,
|
||||||
|
alert.Type,
|
||||||
|
alert.Severity,
|
||||||
|
alert.Timestamp.Format(time.RFC3339),
|
||||||
|
alert.Message,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, _ := json.Marshal(msg)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := s.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("DingTalk API returned status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *DingTalkSender) generateSign() (int64, string) {
|
||||||
|
timestamp := time.Now().UnixMilli()
|
||||||
|
stringToSign := fmt.Sprintf("%d\n%s", timestamp, s.secret)
|
||||||
|
|
||||||
|
h := hmac.New(sha256.New, []byte(s.secret))
|
||||||
|
h.Write([]byte(stringToSign))
|
||||||
|
signature := base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||||
|
|
||||||
|
return timestamp, urlEncode(signature)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FeishuSender 飞书发送器
|
||||||
|
type FeishuSender struct {
|
||||||
|
webHook string
|
||||||
|
secret string
|
||||||
|
client *http.Client
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFeishuSender 创建飞书发送器
|
||||||
|
func NewFeishuSender(webHook, secret string) (*FeishuSender, error) {
|
||||||
|
return &FeishuSender{
|
||||||
|
webHook: webHook,
|
||||||
|
secret: secret,
|
||||||
|
client: &http.Client{
|
||||||
|
Timeout: 10 * time.Second,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FeishuSender) Send(ctx context.Context, alert *Alert) error {
|
||||||
|
// 获取tenant_access_token (简化实现)
|
||||||
|
token, err := s.getTenantAccessToken()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建消息
|
||||||
|
msg := map[string]interface{}{
|
||||||
|
"msg_type": "interactive",
|
||||||
|
"card": map[string]interface{}{
|
||||||
|
"header": map[string]interface{}{
|
||||||
|
"title": map[string]string{
|
||||||
|
"tag": "plain_text",
|
||||||
|
"content": fmt.Sprintf("[%s] %s", strings.ToUpper(alert.Severity), alert.Title),
|
||||||
|
},
|
||||||
|
"template": s.getTemplateColor(alert.Severity),
|
||||||
|
},
|
||||||
|
"elements": []map[string]interface{}{
|
||||||
|
{
|
||||||
|
"tag": "div",
|
||||||
|
"text": map[string]string{
|
||||||
|
"tag": "lark_md",
|
||||||
|
"content": fmt.Sprintf("**类型**: %s\n**严重程度**: %s\n**时间**: %s\n**消息**: %s",
|
||||||
|
alert.Type,
|
||||||
|
alert.Severity,
|
||||||
|
alert.Timestamp.Format(time.RFC3339),
|
||||||
|
alert.Message,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
jsonData, _ := json.Marshal(msg)
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s?tenant_access_token=%s", s.webHook, token)
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
resp, err := s.client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return fmt.Errorf("Feishu API returned status: %d", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FeishuSender) getTenantAccessToken() (string, error) {
|
||||||
|
// 简化实现,实际应该调用飞书API获取tenant_access_token
|
||||||
|
// https://open.feishu.cn/document/ukTMukTMukTM/ukDNz4SO0MjL5QDO/auth-v3/auth/tenant_access_token/internal
|
||||||
|
return "dummy_token", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *FeishuSender) getTemplateColor(severity string) string {
|
||||||
|
switch severity {
|
||||||
|
case "critical":
|
||||||
|
return "red"
|
||||||
|
case "error":
|
||||||
|
return "orange"
|
||||||
|
case "warning":
|
||||||
|
return "yellow"
|
||||||
|
default:
|
||||||
|
return "blue"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// urlEncode URL编码
|
||||||
|
func urlEncode(str string) string {
|
||||||
|
result := ""
|
||||||
|
for _, c := range str {
|
||||||
|
if c == '+' || c == ' ' || c == '/' || c == '=' {
|
||||||
|
result += fmt.Sprintf("%%%02X", c)
|
||||||
|
} else {
|
||||||
|
result += string(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
162
gateway/internal/config/config.go
Normal file
162
gateway/internal/config/config.go
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config 网关配置
|
||||||
|
type Config struct {
|
||||||
|
Server ServerConfig
|
||||||
|
Database DatabaseConfig
|
||||||
|
Redis RedisConfig
|
||||||
|
Router RouterConfig
|
||||||
|
RateLimit RateLimitConfig
|
||||||
|
Alert AlertConfig
|
||||||
|
Providers []ProviderConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServerConfig 服务配置
|
||||||
|
type ServerConfig struct {
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
ReadTimeout time.Duration
|
||||||
|
WriteTimeout time.Duration
|
||||||
|
IdleTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// DatabaseConfig 数据库配置
|
||||||
|
type DatabaseConfig struct {
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
User string
|
||||||
|
Password string
|
||||||
|
Database string
|
||||||
|
MaxConns int
|
||||||
|
}
|
||||||
|
|
||||||
|
// RedisConfig Redis配置
|
||||||
|
type RedisConfig struct {
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
Password string
|
||||||
|
DB int
|
||||||
|
PoolSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouterConfig 路由配置
|
||||||
|
type RouterConfig struct {
|
||||||
|
Strategy string // "latency", "cost", "availability", "weighted"
|
||||||
|
Timeout time.Duration
|
||||||
|
MaxRetries int
|
||||||
|
RetryDelay time.Duration
|
||||||
|
HealthCheckInterval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// RateLimitConfig 限流配置
|
||||||
|
type RateLimitConfig struct {
|
||||||
|
Enabled bool
|
||||||
|
Algorithm string // "token_bucket", "sliding_window", "fixed_window"
|
||||||
|
DefaultRPM int // 请求数/分钟
|
||||||
|
DefaultTPM int // Token数/分钟
|
||||||
|
BurstMultiplier float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// AlertConfig 告警配置
|
||||||
|
type AlertConfig struct {
|
||||||
|
Enabled bool
|
||||||
|
Email EmailConfig
|
||||||
|
DingTalk DingTalkConfig
|
||||||
|
Feishu FeishuConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
// EmailConfig 邮件配置
|
||||||
|
type EmailConfig struct {
|
||||||
|
Enabled bool
|
||||||
|
Host string
|
||||||
|
Port int
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
From string
|
||||||
|
To []string
|
||||||
|
}
|
||||||
|
|
||||||
|
// DingTalkConfig 钉钉配置
|
||||||
|
type DingTalkConfig struct {
|
||||||
|
Enabled bool
|
||||||
|
WebHook string
|
||||||
|
Secret string
|
||||||
|
}
|
||||||
|
|
||||||
|
// FeishuConfig 飞书配置
|
||||||
|
type FeishuConfig struct {
|
||||||
|
Enabled bool
|
||||||
|
WebHook string
|
||||||
|
Secret string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ProviderConfig Provider配置
|
||||||
|
type ProviderConfig struct {
|
||||||
|
Name string
|
||||||
|
Type string // "openai", "anthropic", "google", "custom"
|
||||||
|
BaseURL string
|
||||||
|
APIKey string
|
||||||
|
Models []string
|
||||||
|
Priority int
|
||||||
|
Weight float64
|
||||||
|
}
|
||||||
|
|
||||||
|
// LoadConfig 加载配置
|
||||||
|
func LoadConfig(path string) (*Config, error) {
|
||||||
|
// 简化实现,实际应使用viper或类似库
|
||||||
|
cfg := &Config{
|
||||||
|
Server: ServerConfig{
|
||||||
|
Host: getEnv("GATEWAY_HOST", "0.0.0.0"),
|
||||||
|
Port: 8080,
|
||||||
|
ReadTimeout: 30 * time.Second,
|
||||||
|
WriteTimeout: 30 * time.Second,
|
||||||
|
IdleTimeout: 120 * time.Second,
|
||||||
|
},
|
||||||
|
Router: RouterConfig{
|
||||||
|
Strategy: "latency",
|
||||||
|
Timeout: 30 * time.Second,
|
||||||
|
MaxRetries: 3,
|
||||||
|
RetryDelay: 1 * time.Second,
|
||||||
|
HealthCheckInterval: 10 * time.Second,
|
||||||
|
},
|
||||||
|
RateLimit: RateLimitConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Algorithm: "token_bucket",
|
||||||
|
DefaultRPM: 60,
|
||||||
|
DefaultTPM: 60000,
|
||||||
|
BurstMultiplier: 1.5,
|
||||||
|
},
|
||||||
|
Alert: AlertConfig{
|
||||||
|
Enabled: true,
|
||||||
|
Email: EmailConfig{
|
||||||
|
Enabled: false,
|
||||||
|
Host: getEnv("SMTP_HOST", "smtp.example.com"),
|
||||||
|
Port: 587,
|
||||||
|
},
|
||||||
|
DingTalk: DingTalkConfig{
|
||||||
|
Enabled: getEnv("DINGTALK_ENABLED", "false") == "true",
|
||||||
|
WebHook: getEnv("DINGTALK_WEBHOOK", ""),
|
||||||
|
Secret: getEnv("DINGTALK_SECRET", ""),
|
||||||
|
},
|
||||||
|
Feishu: FeishuConfig{
|
||||||
|
Enabled: getEnv("FEISHU_ENABLED", "false") == "true",
|
||||||
|
WebHook: getEnv("FEISHU_WEBHOOK", ""),
|
||||||
|
Secret: getEnv("FEISHU_SECRET", ""),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getEnv(key, defaultValue string) string {
|
||||||
|
if value := os.Getenv(key); value != "" {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
366
gateway/internal/handler/handler.go
Normal file
366
gateway/internal/handler/handler.go
Normal file
@@ -0,0 +1,366 @@
|
|||||||
|
package handler
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"lijiaoqiao/gateway/internal/adapter"
|
||||||
|
"lijiaoqiao/gateway/internal/router"
|
||||||
|
"lijiaoqiao/gateway/pkg/error"
|
||||||
|
"lijiaoqiao/gateway/pkg/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Handler API处理器
|
||||||
|
type Handler struct {
|
||||||
|
router *router.Router
|
||||||
|
version string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandler 创建处理器
|
||||||
|
func NewHandler(r *router.Router) *Handler {
|
||||||
|
return &Handler{
|
||||||
|
router: r,
|
||||||
|
version: "v1",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletionsHandle /v1/chat/completions端点
|
||||||
|
func (h *Handler) ChatCompletionsHandle(w http.ResponseWriter, r *http.Request) {
|
||||||
|
startTime := time.Now()
|
||||||
|
requestID := r.Header.Get("X-Request-ID")
|
||||||
|
if requestID == "" {
|
||||||
|
requestID = generateRequestID()
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), "request_id", requestID)
|
||||||
|
ctx = context.WithValue(ctx, "start_time", startTime)
|
||||||
|
|
||||||
|
// 解析请求
|
||||||
|
var req model.ChatCompletionRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "invalid request body: "+err.Error()).WithRequestID(requestID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证请求
|
||||||
|
if len(req.Messages) == 0 {
|
||||||
|
h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "messages is required").WithRequestID(requestID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 选择Provider
|
||||||
|
provider, err := h.router.SelectProvider(ctx, req.Model)
|
||||||
|
if err != nil {
|
||||||
|
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换消息格式
|
||||||
|
messages := make([]adapter.Message, len(req.Messages))
|
||||||
|
for i, m := range req.Messages {
|
||||||
|
messages[i] = adapter.Message{
|
||||||
|
Role: m.Role,
|
||||||
|
Content: m.Content,
|
||||||
|
Name: m.Name,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建选项
|
||||||
|
options := adapter.CompletionOptions{
|
||||||
|
Temperature: req.Temperature,
|
||||||
|
MaxTokens: req.MaxTokens,
|
||||||
|
TopP: req.TopP,
|
||||||
|
Stream: req.Stream,
|
||||||
|
Stop: req.Stop,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理流式请求
|
||||||
|
if req.Stream {
|
||||||
|
h.handleStream(ctx, w, r, provider, req.Model, messages, options, requestID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理非流式请求
|
||||||
|
response, err := provider.ChatCompletion(ctx, req.Model, messages, options)
|
||||||
|
if err != nil {
|
||||||
|
// 记录失败
|
||||||
|
h.router.RecordResult(ctx, provider.ProviderName(), false, time.Since(startTime).Milliseconds())
|
||||||
|
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录成功
|
||||||
|
h.router.RecordResult(ctx, provider.ProviderName(), true, time.Since(startTime).Milliseconds())
|
||||||
|
|
||||||
|
// 转换响应
|
||||||
|
chatResp := model.ChatCompletionResponse{
|
||||||
|
ID: response.ID,
|
||||||
|
Object: "chat.completion",
|
||||||
|
Created: response.Created,
|
||||||
|
Model: response.Model,
|
||||||
|
Choices: make([]model.Choice, len(response.Choices)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, c := range response.Choices {
|
||||||
|
chatResp.Choices[i] = model.Choice{
|
||||||
|
Index: c.Index,
|
||||||
|
Message: model.ChatMessage{
|
||||||
|
Role: c.Message.Role,
|
||||||
|
Content: c.Message.Content,
|
||||||
|
},
|
||||||
|
FinishReason: c.FinishReason,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
chatResp.Usage = model.Usage{
|
||||||
|
PromptTokens: response.Usage.PromptTokens,
|
||||||
|
CompletionTokens: response.Usage.CompletionTokens,
|
||||||
|
TotalTokens: response.Usage.TotalTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
h.writeJSON(w, http.StatusOK, chatResp, requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleStream 处理流式请求
|
||||||
|
func (h *Handler) handleStream(ctx context.Context, w http.ResponseWriter, r *http.Request, provider adapter.ProviderAdapter, model string, messages []adapter.Message, options adapter.CompletionOptions, requestID string) {
|
||||||
|
ch, err := provider.ChatCompletionStream(ctx, model, messages, options)
|
||||||
|
if err != nil {
|
||||||
|
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置SSE头
|
||||||
|
w.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
w.Header().Set("Cache-Control", "no-cache")
|
||||||
|
w.Header().Set("Connection", "keep-alive")
|
||||||
|
w.Header().Set("X-Request-ID", requestID)
|
||||||
|
|
||||||
|
flusher, ok := w.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
h.writeError(w, r, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "streaming not supported").WithRequestID(requestID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 流式发送响应
|
||||||
|
for chunk := range ch {
|
||||||
|
data := fmt.Sprintf("data: %s\n\n", marshalJSON(chunk))
|
||||||
|
w.Write([]byte(data))
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Write([]byte("data: [DONE]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompletionsHandle /v1/completions端点
|
||||||
|
func (h *Handler) CompletionsHandle(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestID := r.Header.Get("X-Request-ID")
|
||||||
|
if requestID == "" {
|
||||||
|
requestID = generateRequestID()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析请求
|
||||||
|
var req model.CompletionRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "invalid request body").WithRequestID(requestID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换格式并调用ChatCompletions
|
||||||
|
chatReq := model.ChatCompletionRequest{
|
||||||
|
Model: req.Model,
|
||||||
|
Temperature: req.Temperature,
|
||||||
|
MaxTokens: req.MaxTokens,
|
||||||
|
TopP: req.TopP,
|
||||||
|
Stream: req.Stream,
|
||||||
|
Stop: req.Stop,
|
||||||
|
Messages: []model.ChatMessage{
|
||||||
|
{Role: "user", Content: req.Prompt},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// 复用ChatCompletions逻辑
|
||||||
|
req.Method = "POST"
|
||||||
|
req.URL.Path = "/v1/chat/completions"
|
||||||
|
|
||||||
|
// 重新构造请求体并处理
|
||||||
|
ctx := r.Context()
|
||||||
|
messages := []adapter.Message{{Role: "user", Content: req.Prompt}}
|
||||||
|
|
||||||
|
provider, err := h.router.SelectProvider(ctx, req.Model)
|
||||||
|
if err != nil {
|
||||||
|
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
options := adapter.CompletionOptions{
|
||||||
|
Temperature: req.Temperature,
|
||||||
|
MaxTokens: req.MaxTokens,
|
||||||
|
TopP: req.TopP,
|
||||||
|
Stream: req.Stream,
|
||||||
|
Stop: req.Stop,
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Stream {
|
||||||
|
h.handleStream(ctx, w, r, provider, req.Model, messages, options, requestID)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response, err := provider.ChatCompletion(ctx, req.Model, messages, options)
|
||||||
|
if err != nil {
|
||||||
|
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换响应为Completion格式
|
||||||
|
compResp := model.CompletionResponse{
|
||||||
|
ID: response.ID,
|
||||||
|
Object: "text_completion",
|
||||||
|
Created: response.Created,
|
||||||
|
Model: response.Model,
|
||||||
|
Choices: make([]model.Choice1, len(response.Choices)),
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, c := range response.Choices {
|
||||||
|
compResp.Choices[i] = model.Choice1{
|
||||||
|
Text: c.Message.Content,
|
||||||
|
Index: i,
|
||||||
|
FinishReason: c.FinishReason,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
compResp.Usage = model.Usage{
|
||||||
|
PromptTokens: response.Usage.PromptTokens,
|
||||||
|
CompletionTokens: response.Usage.CompletionTokens,
|
||||||
|
TotalTokens: response.Usage.TotalTokens,
|
||||||
|
}
|
||||||
|
|
||||||
|
h.writeJSON(w, http.StatusOK, compResp, requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ModelsHandle /v1/models端点
|
||||||
|
func (h *Handler) ModelsHandle(w http.ResponseWriter, r *http.Request) {
|
||||||
|
requestID := r.Header.Get("X-Request-ID")
|
||||||
|
if requestID == "" {
|
||||||
|
requestID = generateRequestID()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回支持的模型列表
|
||||||
|
models := []map[string]interface{}{
|
||||||
|
{"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"},
|
||||||
|
{"id": "gpt-3.5-turbo", "object": "model", "created": 1677610602, "owned_by": "openai"},
|
||||||
|
{"id": "claude-3-opus", "object": "model", "created": 1709598254, "owned_by": "anthropic"},
|
||||||
|
{"id": "claude-3-sonnet", "object": "model", "created": 1709598255, "owned_by": "anthropic"},
|
||||||
|
}
|
||||||
|
|
||||||
|
h.writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||||
|
"object": "list",
|
||||||
|
"data": models,
|
||||||
|
}, requestID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthHandle /health端点
|
||||||
|
func (h *Handler) HealthHandle(w http.ResponseWriter, r *http.Request) {
|
||||||
|
healthStatus := h.router.GetHealthStatus()
|
||||||
|
|
||||||
|
allHealthy := true
|
||||||
|
services := make(map[string]bool)
|
||||||
|
for name, health := range healthStatus {
|
||||||
|
services[name] = health.Available
|
||||||
|
if !health.Available {
|
||||||
|
allHealthy = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
status := "healthy"
|
||||||
|
statusCode := http.StatusOK
|
||||||
|
if !allHealthy {
|
||||||
|
status = "degraded"
|
||||||
|
statusCode = http.StatusServiceUnavailable
|
||||||
|
}
|
||||||
|
|
||||||
|
h.writeJSON(w, statusCode, model.HealthStatus{
|
||||||
|
Status: status,
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Services: services,
|
||||||
|
}, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) writeJSON(w http.ResponseWriter, status int, data interface{}, requestID string) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if requestID != "" {
|
||||||
|
w.Header().Set("X-Request-ID", requestID)
|
||||||
|
}
|
||||||
|
w.WriteHeader(status)
|
||||||
|
json.NewEncoder(w).Encode(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) writeError(w http.ResponseWriter, r *http.Request, err *error.GatewayError) {
|
||||||
|
info := err.GetErrorInfo()
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if err.RequestID != "" {
|
||||||
|
w.Header().Set("X-Request-ID", err.RequestID)
|
||||||
|
}
|
||||||
|
w.WriteHeader(info.HTTPStatus)
|
||||||
|
|
||||||
|
resp := model.ErrorResponse{
|
||||||
|
Error: model.ErrorDetail{
|
||||||
|
Message: err.Message,
|
||||||
|
Type: "gateway_error",
|
||||||
|
Code: string(err.Code),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
json.NewEncoder(w).Encode(resp)
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateRequestID() string {
|
||||||
|
return fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
func marshalJSON(v interface{}) string {
|
||||||
|
data, _ := json.Marshal(v)
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSEReader 流式响应读取器
|
||||||
|
type SSEReader struct {
|
||||||
|
reader *bufio.Reader
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSSEReader(r io.Reader) *SSEReader {
|
||||||
|
return &SSEReader{reader: bufio.NewReader(r)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SSEReader) ReadLine() (string, error) {
|
||||||
|
line, err := s.reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return line[:len(line)-1], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseSSEData(line string) string {
|
||||||
|
if len(line) < 6 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if line[:5] != "data:" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return line[6:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func getenv(key, defaultValue string) string {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
getenv = func(key, defaultValue string) string {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
}
|
||||||
336
gateway/internal/ratelimit/ratelimit.go
Normal file
336
gateway/internal/ratelimit/ratelimit.go
Normal file
@@ -0,0 +1,336 @@
|
|||||||
|
package ratelimit
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"lijiaoqiao/gateway/pkg/error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Algorithm 限流算法
|
||||||
|
type Algorithm string
|
||||||
|
|
||||||
|
const (
|
||||||
|
TokenBucket Algorithm = "token_bucket"
|
||||||
|
SlidingWindow Algorithm = "sliding_window"
|
||||||
|
FixedWindow Algorithm = "fixed_window"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Limiter 限流器接口
|
||||||
|
type Limiter interface {
|
||||||
|
// Allow 检查是否允许请求
|
||||||
|
Allow(ctx context.Context, key string) (bool, error)
|
||||||
|
// AllowToken 检查是否允许消耗token
|
||||||
|
AllowToken(ctx context.Context, key string, tokens int) (bool, error)
|
||||||
|
// GetLimit 获取当前限制
|
||||||
|
GetLimit(key string) *Limit
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit 限制配置
|
||||||
|
type Limit struct {
|
||||||
|
RPM int // 请求数/分钟
|
||||||
|
TPM int // Token数/分钟
|
||||||
|
Burst int // 突发容量
|
||||||
|
Remaining int // 剩余请求数
|
||||||
|
ResetAt time.Time // 重置时间
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenBucketLimiter Token桶限流器
|
||||||
|
type TokenBucketLimiter struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
buckets map[string]*tokenBucket
|
||||||
|
defaultRPM int
|
||||||
|
defaultTPM int
|
||||||
|
burstMultiplier float64
|
||||||
|
cleanInterval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type tokenBucket struct {
|
||||||
|
tokens float64
|
||||||
|
maxTokens float64
|
||||||
|
tokensPerSec float64
|
||||||
|
lastRefill time.Time
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTokenBucketLimiter 创建Token桶限流器
|
||||||
|
func NewTokenBucketLimiter(defaultRPM, defaultTPM int, burstMultiplier float64) *TokenBucketLimiter {
|
||||||
|
limiter := &TokenBucketLimiter{
|
||||||
|
buckets: make(map[string]*tokenBucket),
|
||||||
|
defaultRPM: defaultRPM,
|
||||||
|
defaultTPM: defaultTPM,
|
||||||
|
burstMultiplier: burstMultiplier,
|
||||||
|
cleanInterval: 5 * time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动清理goroutine
|
||||||
|
go limiter.cleanup()
|
||||||
|
|
||||||
|
return limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow 检查是否允许请求
|
||||||
|
func (l *TokenBucketLimiter) Allow(ctx context.Context, key string) (bool, error) {
|
||||||
|
return l.AllowToken(ctx, key, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllowToken 检查是否允许消耗token
|
||||||
|
func (l *TokenBucketLimiter) AllowToken(ctx context.Context, key string, tokens int) (bool, error) {
|
||||||
|
l.mu.Lock()
|
||||||
|
bucket, exists := l.buckets[key]
|
||||||
|
if !exists {
|
||||||
|
bucket = l.newBucket(l.defaultRPM, l.defaultTPM)
|
||||||
|
l.buckets[key] = bucket
|
||||||
|
}
|
||||||
|
l.mu.Unlock()
|
||||||
|
|
||||||
|
bucket.mu.Lock()
|
||||||
|
defer bucket.mu.Unlock()
|
||||||
|
|
||||||
|
// 补充token
|
||||||
|
l.refill(bucket)
|
||||||
|
|
||||||
|
// 检查是否有足够的token
|
||||||
|
if bucket.tokens >= float64(tokens) {
|
||||||
|
bucket.tokens -= float64(tokens)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLimit 获取当前限制
|
||||||
|
func (l *TokenBucketLimiter) GetLimit(key string) *Limit {
|
||||||
|
l.mu.RLock()
|
||||||
|
bucket, exists := l.buckets[key]
|
||||||
|
l.mu.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return &Limit{
|
||||||
|
RPM: l.defaultRPM,
|
||||||
|
TPM: l.defaultTPM,
|
||||||
|
Burst: int(float64(l.defaultRPM) * l.burstMultiplier),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bucket.mu.Lock()
|
||||||
|
defer bucket.mu.Unlock()
|
||||||
|
|
||||||
|
return &Limit{
|
||||||
|
RPM: l.defaultRPM,
|
||||||
|
TPM: l.defaultTPM,
|
||||||
|
Burst: int(bucket.maxTokens),
|
||||||
|
Remaining: int(bucket.tokens),
|
||||||
|
ResetAt: bucket.lastRefill.Add(time.Minute),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *TokenBucketLimiter) newBucket(rpm, tpm int) *tokenBucket {
|
||||||
|
burst := int(float64(rpm) * l.burstMultiplier)
|
||||||
|
return &tokenBucket{
|
||||||
|
tokens: float64(burst),
|
||||||
|
maxTokens: float64(burst),
|
||||||
|
tokensPerSec: float64(rpm) / 60.0,
|
||||||
|
lastRefill: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *TokenBucketLimiter) refill(bucket *tokenBucket) {
|
||||||
|
now := time.Now()
|
||||||
|
elapsed := now.Sub(bucket.lastRefill).Seconds()
|
||||||
|
|
||||||
|
// 添加新token
|
||||||
|
bucket.tokens += elapsed * bucket.tokensPerSec
|
||||||
|
if bucket.tokens > bucket.maxTokens {
|
||||||
|
bucket.tokens = bucket.maxTokens
|
||||||
|
}
|
||||||
|
|
||||||
|
bucket.lastRefill = now
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *TokenBucketLimiter) cleanup() {
|
||||||
|
ticker := time.NewTicker(l.cleanInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
l.mu.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
for key, bucket := range l.buckets {
|
||||||
|
bucket.mu.Lock()
|
||||||
|
// 如果bucket完全空了且超过10分钟没使用,删除它
|
||||||
|
if bucket.tokens >= bucket.maxTokens && now.Sub(bucket.lastRefill) > 10*time.Minute {
|
||||||
|
delete(l.buckets, key)
|
||||||
|
}
|
||||||
|
bucket.mu.Unlock()
|
||||||
|
}
|
||||||
|
l.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SlidingWindowLimiter 滑动窗口限流器
|
||||||
|
type SlidingWindowLimiter struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
windows map[string]*slidingWindow
|
||||||
|
windowSize time.Duration
|
||||||
|
maxRequests int
|
||||||
|
cleanInterval time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
type slidingWindow struct {
|
||||||
|
requests []time.Time
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSlidingWindowLimiter(windowSize time.Duration, maxRequests int) *SlidingWindowLimiter {
|
||||||
|
limiter := &SlidingWindowLimiter{
|
||||||
|
windows: make(map[string]*slidingWindow),
|
||||||
|
windowSize: windowSize,
|
||||||
|
maxRequests: maxRequests,
|
||||||
|
cleanInterval: 1 * time.Minute,
|
||||||
|
}
|
||||||
|
|
||||||
|
go limiter.cleanup()
|
||||||
|
return limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *SlidingWindowLimiter) Allow(ctx context.Context, key string) (bool, error) {
|
||||||
|
l.mu.Lock()
|
||||||
|
window, exists := l.windows[key]
|
||||||
|
if !exists {
|
||||||
|
window = &slidingWindow{requests: make([]time.Time, 0)}
|
||||||
|
l.windows[key] = window
|
||||||
|
}
|
||||||
|
l.mu.Unlock()
|
||||||
|
|
||||||
|
window.mu.Lock()
|
||||||
|
defer window.mu.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
cutoff := now.Add(-l.windowSize)
|
||||||
|
|
||||||
|
// 清理过期的请求
|
||||||
|
validRequests := make([]time.Time, 0)
|
||||||
|
for _, t := range window.requests {
|
||||||
|
if t.After(cutoff) {
|
||||||
|
validRequests = append(validRequests, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
window.requests = validRequests
|
||||||
|
|
||||||
|
// 检查是否超过限制
|
||||||
|
if len(window.requests) >= l.maxRequests {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
window.requests = append(window.requests, now)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *SlidingWindowLimiter) AllowToken(ctx context.Context, key string, tokens int) (bool, error) {
|
||||||
|
// 对于滑动窗口,tokens只是计数,这里简化为1个请求
|
||||||
|
return l.Allow(ctx, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *SlidingWindowLimiter) GetLimit(key string) *Limit {
|
||||||
|
l.mu.RLock()
|
||||||
|
window, exists := l.windows[key]
|
||||||
|
l.mu.RUnlock()
|
||||||
|
|
||||||
|
remaining := l.maxRequests
|
||||||
|
if exists {
|
||||||
|
window.mu.Lock()
|
||||||
|
cutoff := time.Now().Add(-l.windowSize)
|
||||||
|
count := 0
|
||||||
|
for _, t := range window.requests {
|
||||||
|
if t.After(cutoff) {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
remaining = l.maxRequests - count
|
||||||
|
if remaining < 0 {
|
||||||
|
remaining = 0
|
||||||
|
}
|
||||||
|
window.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Limit{
|
||||||
|
RPM: l.maxRequests,
|
||||||
|
ResetAt: time.Now().Add(l.windowSize),
|
||||||
|
Remaining: remaining,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *SlidingWindowLimiter) cleanup() {
|
||||||
|
ticker := time.NewTicker(l.cleanInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for range ticker.C {
|
||||||
|
l.mu.Lock()
|
||||||
|
now := time.Now()
|
||||||
|
for key, window := range l.windows {
|
||||||
|
window.mu.Lock()
|
||||||
|
cutoff := now.Add(-l.windowSize * 2)
|
||||||
|
validRequests := make([]time.Time, 0)
|
||||||
|
for _, t := range window.requests {
|
||||||
|
if t.After(cutoff) {
|
||||||
|
validRequests = append(validRequests, t)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(validRequests) == 0 && now.Sub(window.requests[len(window.requests)-1]) > l.windowSize*2 {
|
||||||
|
delete(l.windows, key)
|
||||||
|
} else {
|
||||||
|
window.requests = validRequests
|
||||||
|
}
|
||||||
|
window.mu.Unlock()
|
||||||
|
}
|
||||||
|
l.mu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Middleware 限流中间件
|
||||||
|
type Middleware struct {
|
||||||
|
limiter Limiter
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMiddleware(limiter Limiter) *Middleware {
|
||||||
|
return &Middleware{limiter: limiter}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// 使用API Key作为限流key
|
||||||
|
key := r.Header.Get("Authorization")
|
||||||
|
if key == "" {
|
||||||
|
key = r.RemoteAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
allowed, err := m.limiter.Allow(r.Context(), key)
|
||||||
|
if err != nil {
|
||||||
|
writeError(w, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "rate limiter error"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !allowed {
|
||||||
|
limit := m.limiter.GetLimit(key)
|
||||||
|
w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", limit.RPM))
|
||||||
|
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", limit.Remaining))
|
||||||
|
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", limit.ResetAt.Unix()))
|
||||||
|
|
||||||
|
writeError(w, error.NewGatewayError(error.RATE_LIMIT_EXCEEDED, "rate limit exceeded"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
func writeError(w http.ResponseWriter, err *error.GatewayError) {
|
||||||
|
info := err.GetErrorInfo()
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(info.HTTPStatus)
|
||||||
|
w.Write([]byte(fmt.Sprintf(`{"error":{"message":"%s","code":"%s"}}`, err.Message, err.Code)))
|
||||||
|
}
|
||||||
261
gateway/internal/router/router.go
Normal file
261
gateway/internal/router/router.go
Normal file
@@ -0,0 +1,261 @@
|
|||||||
|
package router
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"math"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"lijiaoqiao/gateway/internal/adapter"
|
||||||
|
"lijiaoqiao/gateway/pkg/error"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LoadBalancerStrategy 负载均衡策略
|
||||||
|
type LoadBalancerStrategy string
|
||||||
|
|
||||||
|
const (
|
||||||
|
StrategyLatency LoadBalancerStrategy = "latency"
|
||||||
|
StrategyRoundRobin LoadBalancerStrategy = "round_robin"
|
||||||
|
StrategyWeighted LoadBalancerStrategy = "weighted"
|
||||||
|
StrategyAvailability LoadBalancerStrategy = "availability"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProviderHealth Provider健康状态
|
||||||
|
type ProviderHealth struct {
|
||||||
|
Name string
|
||||||
|
Available bool
|
||||||
|
LatencyMs int64
|
||||||
|
FailureRate float64
|
||||||
|
Weight float64
|
||||||
|
LastCheckTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router 路由器
|
||||||
|
type Router struct {
|
||||||
|
providers map[string]adapter.ProviderAdapter
|
||||||
|
health map[string]*ProviderHealth
|
||||||
|
strategy LoadBalancerStrategy
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRouter 创建路由器
|
||||||
|
func NewRouter(strategy LoadBalancerStrategy) *Router {
|
||||||
|
return &Router{
|
||||||
|
providers: make(map[string]adapter.ProviderAdapter),
|
||||||
|
health: make(map[string]*ProviderHealth),
|
||||||
|
strategy: strategy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterProvider 注册Provider
|
||||||
|
func (r *Router) RegisterProvider(name string, provider adapter.ProviderAdapter) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
r.providers[name] = provider
|
||||||
|
r.health[name] = &ProviderHealth{
|
||||||
|
Name: name,
|
||||||
|
Available: true,
|
||||||
|
LatencyMs: 0,
|
||||||
|
FailureRate: 0,
|
||||||
|
Weight: 1.0,
|
||||||
|
LastCheckTime: time.Now(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SelectProvider 选择最佳Provider
|
||||||
|
func (r *Router) SelectProvider(ctx context.Context, model string) (adapter.ProviderAdapter, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
var candidates []string
|
||||||
|
for name, provider := range r.providers {
|
||||||
|
if r.isProviderAvailable(name, model) {
|
||||||
|
candidates = append(candidates, name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
return nil, error.NewGatewayError(error.ROUTER_NO_PROVIDER_AVAILABLE, "no provider available for model: "+model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据策略选择
|
||||||
|
switch r.strategy {
|
||||||
|
case StrategyLatency:
|
||||||
|
return r.selectByLatency(candidates)
|
||||||
|
case StrategyWeighted:
|
||||||
|
return r.selectByWeight(candidates)
|
||||||
|
case StrategyAvailability:
|
||||||
|
return r.selectByAvailability(candidates)
|
||||||
|
default:
|
||||||
|
return r.selectByLatency(candidates)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) isProviderAvailable(name, model string) bool {
|
||||||
|
health, ok := r.health[name]
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !health.Available {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查模型是否支持
|
||||||
|
provider := r.providers[name]
|
||||||
|
if provider == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, m := range provider.SupportedModels() {
|
||||||
|
if m == model || m == "*" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) selectByLatency(candidates []string) (adapter.ProviderAdapter, error) {
|
||||||
|
var bestProvider adapter.ProviderAdapter
|
||||||
|
var minLatency int64 = math.MaxInt64
|
||||||
|
|
||||||
|
for _, name := range candidates {
|
||||||
|
health := r.health[name]
|
||||||
|
if health.LatencyMs < minLatency {
|
||||||
|
minLatency = health.LatencyMs
|
||||||
|
bestProvider = r.providers[name]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if bestProvider == nil {
|
||||||
|
return nil, error.NewGatewayError(error.ROUTER_NO_PROVIDER_AVAILABLE, "no available provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
return bestProvider, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) selectByWeight(candidates []string) (adapter.ProviderAdapter, error) {
|
||||||
|
var totalWeight float64
|
||||||
|
for _, name := range candidates {
|
||||||
|
totalWeight += r.health[name].Weight
|
||||||
|
}
|
||||||
|
|
||||||
|
randVal := float64(time.Now().UnixNano()) / float64(math.MaxInt64) * totalWeight
|
||||||
|
var cumulative float64
|
||||||
|
|
||||||
|
for _, name := range candidates {
|
||||||
|
cumulative += r.health[name].Weight
|
||||||
|
if randVal <= cumulative {
|
||||||
|
return r.providers[name], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return r.providers[candidates[0]], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Router) selectByAvailability(candidates []string) (adapter.ProviderAdapter, error) {
|
||||||
|
var bestProvider adapter.ProviderAdapter
|
||||||
|
var minFailureRate float64 = math.MaxFloat64
|
||||||
|
|
||||||
|
for _, name := range candidates {
|
||||||
|
health := r.health[name]
|
||||||
|
if health.FailureRate < minFailureRate {
|
||||||
|
minFailureRate = health.FailureRate
|
||||||
|
bestProvider = r.providers[name]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if bestProvider == nil {
|
||||||
|
return nil, error.NewGatewayError(error.ROUTER_NO_PROVIDER_AVAILABLE, "no available provider")
|
||||||
|
}
|
||||||
|
|
||||||
|
return bestProvider, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetFallbackProviders 获取Fallback Providers
|
||||||
|
func (r *Router) GetFallbackProviders(ctx context.Context, model string) ([]adapter.ProviderAdapter, error) {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
var fallbacks []adapter.ProviderAdapter
|
||||||
|
|
||||||
|
for name, provider := range r.providers {
|
||||||
|
if name == "primary" {
|
||||||
|
continue // 跳过主Provider
|
||||||
|
}
|
||||||
|
if r.isProviderAvailable(name, model) {
|
||||||
|
fallbacks = append(fallbacks, provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fallbacks, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RecordResult 记录调用结果
|
||||||
|
func (r *Router) RecordResult(ctx context.Context, providerName string, success bool, latencyMs int64) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
health, ok := r.health[providerName]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新延迟
|
||||||
|
if latencyMs > 0 {
|
||||||
|
// 指数移动平均
|
||||||
|
if health.LatencyMs == 0 {
|
||||||
|
health.LatencyMs = latencyMs
|
||||||
|
} else {
|
||||||
|
health.LatencyMs = (health.LatencyMs*7 + latencyMs) / 8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新失败率
|
||||||
|
if success {
|
||||||
|
if health.FailureRate > 0 {
|
||||||
|
health.FailureRate = health.FailureRate * 0.9 // 下降
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
health.FailureRate = health.FailureRate*0.9 + 0.1 // 上升
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否应该标记为不可用
|
||||||
|
if health.FailureRate > 0.5 {
|
||||||
|
health.Available = false
|
||||||
|
}
|
||||||
|
|
||||||
|
health.LastCheckTime = time.Now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateHealth 更新健康状态
|
||||||
|
func (r *Router) UpdateHealth(providerName string, available bool) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if health, ok := r.health[providerName]; ok {
|
||||||
|
health.Available = available
|
||||||
|
health.LastCheckTime = time.Now()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetHealthStatus 获取健康状态
|
||||||
|
func (r *Router) GetHealthStatus() map[string]*ProviderHealth {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make(map[string]*ProviderHealth)
|
||||||
|
for name, health := range r.health {
|
||||||
|
result[name] = &ProviderHealth{
|
||||||
|
Name: health.Name,
|
||||||
|
Available: health.Available,
|
||||||
|
LatencyMs: health.LatencyMs,
|
||||||
|
FailureRate: health.FailureRate,
|
||||||
|
Weight: health.Weight,
|
||||||
|
LastCheckTime: health.LastCheckTime,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
246
gateway/pkg/error/error.go
Normal file
246
gateway/pkg/error/error.go
Normal file
@@ -0,0 +1,246 @@
|
|||||||
|
package error
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// ErrorCode 错误码枚举
|
||||||
|
type ErrorCode string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// 认证授权 (AUTH_*)
|
||||||
|
AUTH_INVALID_TOKEN ErrorCode = "AUTH_001"
|
||||||
|
AUTH_INSUFFICIENT_PERMISSION ErrorCode = "AUTH_002"
|
||||||
|
AUTH_MFA_REQUIRED ErrorCode = "AUTH_003"
|
||||||
|
|
||||||
|
// 计费 (BILLING_*)
|
||||||
|
BILLING_INSUFFICIENT_BALANCE ErrorCode = "BILLING_001"
|
||||||
|
BILLING_CHARGE_FAILED ErrorCode = "BILLING_002"
|
||||||
|
BILLING_REFUND_FAILED ErrorCode = "BILLING_003"
|
||||||
|
BILLING_DISCREPANCY ErrorCode = "BILLING_004"
|
||||||
|
|
||||||
|
// 路由 (ROUTER_*)
|
||||||
|
ROUTER_NO_PROVIDER_AVAILABLE ErrorCode = "ROUTER_001"
|
||||||
|
ROUTER_ALL_PROVIDERS_FAILED ErrorCode = "ROUTER_002"
|
||||||
|
ROUTER_TIMEOUT ErrorCode = "ROUTER_003"
|
||||||
|
|
||||||
|
// 供应商 (PROVIDER_*)
|
||||||
|
PROVIDER_INVALID_KEY ErrorCode = "PROVIDER_001"
|
||||||
|
PROVIDER_RATE_LIMIT ErrorCode = "PROVIDER_002"
|
||||||
|
PROVIDER_QUOTA_EXCEEDED ErrorCode = "PROVIDER_003"
|
||||||
|
PROVIDER_MODEL_NOT_FOUND ErrorCode = "PROVIDER_004"
|
||||||
|
PROVIDER_ERROR ErrorCode = "PROVIDER_005"
|
||||||
|
|
||||||
|
// 限流 (RATE_LIMIT_*)
|
||||||
|
RATE_LIMIT_EXCEEDED ErrorCode = "RATE_LIMIT_001"
|
||||||
|
RATE_LIMIT_TOKEN_EXCEEDED ErrorCode = "RATE_LIMIT_002"
|
||||||
|
RATE_LIMIT_BURST_EXCEEDED ErrorCode = "RATE_LIMIT_003"
|
||||||
|
|
||||||
|
// 通用 (COMMON_*)
|
||||||
|
COMMON_INVALID_REQUEST ErrorCode = "COMMON_001"
|
||||||
|
COMMON_RESOURCE_NOT_FOUND ErrorCode = "COMMON_002"
|
||||||
|
COMMON_INTERNAL_ERROR ErrorCode = "COMMON_003"
|
||||||
|
COMMON_SERVICE_UNAVAILABLE ErrorCode = "COMMON_004"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrorInfo 错误信息
|
||||||
|
type ErrorInfo struct {
|
||||||
|
Code ErrorCode
|
||||||
|
Message string
|
||||||
|
HTTPStatus int
|
||||||
|
Retryable bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// GatewayError 网关错误
|
||||||
|
type GatewayError struct {
|
||||||
|
Code ErrorCode
|
||||||
|
Message string
|
||||||
|
Details map[string]interface{}
|
||||||
|
RequestID string
|
||||||
|
Internal error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *GatewayError) Error() string {
|
||||||
|
if e.Internal != nil {
|
||||||
|
return fmt.Sprintf("%s: %s (caused by: %v)", e.Code, e.Message, e.Internal)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s: %s", e.Code, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *GatewayError) Unwrap() error {
|
||||||
|
return e.Internal
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorDefinitions 错误码定义
|
||||||
|
var ErrorDefinitions = map[ErrorCode]ErrorInfo{
|
||||||
|
AUTH_INVALID_TOKEN: {
|
||||||
|
Code: AUTH_INVALID_TOKEN,
|
||||||
|
Message: "Invalid or expired token",
|
||||||
|
HTTPStatus: 401,
|
||||||
|
Retryable: false,
|
||||||
|
},
|
||||||
|
AUTH_INSUFFICIENT_PERMISSION: {
|
||||||
|
Code: AUTH_INSUFFICIENT_PERMISSION,
|
||||||
|
Message: "Insufficient permissions",
|
||||||
|
HTTPStatus: 403,
|
||||||
|
Retryable: false,
|
||||||
|
},
|
||||||
|
AUTH_MFA_REQUIRED: {
|
||||||
|
Code: AUTH_MFA_REQUIRED,
|
||||||
|
Message: "MFA verification required",
|
||||||
|
HTTPStatus: 403,
|
||||||
|
Retryable: false,
|
||||||
|
},
|
||||||
|
BILLING_INSUFFICIENT_BALANCE: {
|
||||||
|
Code: BILLING_INSUFFICIENT_BALANCE,
|
||||||
|
Message: "Insufficient balance",
|
||||||
|
HTTPStatus: 402,
|
||||||
|
Retryable: false,
|
||||||
|
},
|
||||||
|
BILLING_CHARGE_FAILED: {
|
||||||
|
Code: BILLING_CHARGE_FAILED,
|
||||||
|
Message: "Charge failed",
|
||||||
|
HTTPStatus: 500,
|
||||||
|
Retryable: true,
|
||||||
|
},
|
||||||
|
BILLING_REFUND_FAILED: {
|
||||||
|
Code: BILLING_REFUND_FAILED,
|
||||||
|
Message: "Refund failed",
|
||||||
|
HTTPStatus: 500,
|
||||||
|
Retryable: true,
|
||||||
|
},
|
||||||
|
BILLING_DISCREPANCY: {
|
||||||
|
Code: BILLING_DISCREPANCY,
|
||||||
|
Message: "Billing discrepancy detected",
|
||||||
|
HTTPStatus: 500,
|
||||||
|
Retryable: true,
|
||||||
|
},
|
||||||
|
ROUTER_NO_PROVIDER_AVAILABLE: {
|
||||||
|
Code: ROUTER_NO_PROVIDER_AVAILABLE,
|
||||||
|
Message: "No provider available",
|
||||||
|
HTTPStatus: 503,
|
||||||
|
Retryable: true,
|
||||||
|
},
|
||||||
|
ROUTER_ALL_PROVIDERS_FAILED: {
|
||||||
|
Code: ROUTER_ALL_PROVIDERS_FAILED,
|
||||||
|
Message: "All providers failed",
|
||||||
|
HTTPStatus: 503,
|
||||||
|
Retryable: true,
|
||||||
|
},
|
||||||
|
ROUTER_TIMEOUT: {
|
||||||
|
Code: ROUTER_TIMEOUT,
|
||||||
|
Message: "Request timeout",
|
||||||
|
HTTPStatus: 504,
|
||||||
|
Retryable: true,
|
||||||
|
},
|
||||||
|
PROVIDER_INVALID_KEY: {
|
||||||
|
Code: PROVIDER_INVALID_KEY,
|
||||||
|
Message: "Invalid API key",
|
||||||
|
HTTPStatus: 401,
|
||||||
|
Retryable: false,
|
||||||
|
},
|
||||||
|
PROVIDER_RATE_LIMIT: {
|
||||||
|
Code: PROVIDER_RATE_LIMIT,
|
||||||
|
Message: "Rate limit exceeded",
|
||||||
|
HTTPStatus: 429,
|
||||||
|
Retryable: true,
|
||||||
|
},
|
||||||
|
PROVIDER_QUOTA_EXCEEDED: {
|
||||||
|
Code: PROVIDER_QUOTA_EXCEEDED,
|
||||||
|
Message: "Quota exceeded",
|
||||||
|
HTTPStatus: 402,
|
||||||
|
Retryable: false,
|
||||||
|
},
|
||||||
|
PROVIDER_MODEL_NOT_FOUND: {
|
||||||
|
Code: PROVIDER_MODEL_NOT_FOUND,
|
||||||
|
Message: "Model not found",
|
||||||
|
HTTPStatus: 404,
|
||||||
|
Retryable: false,
|
||||||
|
},
|
||||||
|
PROVIDER_ERROR: {
|
||||||
|
Code: PROVIDER_ERROR,
|
||||||
|
Message: "Provider error",
|
||||||
|
HTTPStatus: 502,
|
||||||
|
Retryable: true,
|
||||||
|
},
|
||||||
|
RATE_LIMIT_EXCEEDED: {
|
||||||
|
Code: RATE_LIMIT_EXCEEDED,
|
||||||
|
Message: "Rate limit exceeded",
|
||||||
|
HTTPStatus: 429,
|
||||||
|
Retryable: false,
|
||||||
|
},
|
||||||
|
RATE_LIMIT_TOKEN_EXCEEDED: {
|
||||||
|
Code: RATE_LIMIT_TOKEN_EXCEEDED,
|
||||||
|
Message: "Token limit exceeded",
|
||||||
|
HTTPStatus: 429,
|
||||||
|
Retryable: false,
|
||||||
|
},
|
||||||
|
RATE_LIMIT_BURST_EXCEEDED: {
|
||||||
|
Code: RATE_LIMIT_BURST_EXCEEDED,
|
||||||
|
Message: "Burst limit exceeded",
|
||||||
|
HTTPStatus: 429,
|
||||||
|
Retryable: false,
|
||||||
|
},
|
||||||
|
COMMON_INVALID_REQUEST: {
|
||||||
|
Code: COMMON_INVALID_REQUEST,
|
||||||
|
Message: "Invalid request",
|
||||||
|
HTTPStatus: 400,
|
||||||
|
Retryable: false,
|
||||||
|
},
|
||||||
|
COMMON_RESOURCE_NOT_FOUND: {
|
||||||
|
Code: COMMON_RESOURCE_NOT_FOUND,
|
||||||
|
Message: "Resource not found",
|
||||||
|
HTTPStatus: 404,
|
||||||
|
Retryable: false,
|
||||||
|
},
|
||||||
|
COMMON_INTERNAL_ERROR: {
|
||||||
|
Code: COMMON_INTERNAL_ERROR,
|
||||||
|
Message: "Internal error",
|
||||||
|
HTTPStatus: 500,
|
||||||
|
Retryable: true,
|
||||||
|
},
|
||||||
|
COMMON_SERVICE_UNAVAILABLE: {
|
||||||
|
Code: COMMON_SERVICE_UNAVAILABLE,
|
||||||
|
Message: "Service unavailable",
|
||||||
|
HTTPStatus: 503,
|
||||||
|
Retryable: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGatewayError 创建网关错误
|
||||||
|
func NewGatewayError(code ErrorCode, message string) *GatewayError {
|
||||||
|
return &GatewayError{
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
Details: make(map[string]interface{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithRequestID 设置请求ID
|
||||||
|
func (e *GatewayError) WithRequestID(requestID string) *GatewayError {
|
||||||
|
e.RequestID = requestID
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDetail 设置详情
|
||||||
|
func (e *GatewayError) WithDetail(key string, value interface{}) *GatewayError {
|
||||||
|
e.Details[key] = value
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithInternal 设置内部错误
|
||||||
|
func (e *GatewayError) WithInternal(err error) *GatewayError {
|
||||||
|
e.Internal = err
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetErrorInfo 获取错误信息
|
||||||
|
func (e *GatewayError) GetErrorInfo() ErrorInfo {
|
||||||
|
if info, ok := ErrorDefinitions[e.Code]; ok {
|
||||||
|
return info
|
||||||
|
}
|
||||||
|
return ErrorInfo{
|
||||||
|
Code: COMMON_INTERNAL_ERROR,
|
||||||
|
Message: e.Message,
|
||||||
|
HTTPStatus: 500,
|
||||||
|
Retryable: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
144
gateway/pkg/model/model.go
Normal file
144
gateway/pkg/model/model.go
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// ChatCompletionRequest 聊天完成请求
|
||||||
|
type ChatCompletionRequest struct {
|
||||||
|
Model string `json:"model" binding:"required"`
|
||||||
|
Messages []ChatMessage `json:"messages" binding:"required"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
Stop []string `json:"stop,omitempty"`
|
||||||
|
N int `json:"n,omitempty"`
|
||||||
|
PresencePenalty float64 `json:"presence_penalty,omitempty"`
|
||||||
|
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
|
||||||
|
User string `json:"user,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatMessage 聊天消息
|
||||||
|
type ChatMessage struct {
|
||||||
|
Role string `json:"role" binding:"required"`
|
||||||
|
Content string `json:"content" binding:"required"`
|
||||||
|
Name string `json:"name,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletionResponse 聊天完成响应
|
||||||
|
type ChatCompletionResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []Choice `json:"choices"`
|
||||||
|
Usage Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Choice struct {
|
||||||
|
Index int `json:"index"`
|
||||||
|
Message ChatMessage `json:"message"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Usage struct {
|
||||||
|
PromptTokens int `json:"prompt_tokens"`
|
||||||
|
CompletionTokens int `json:"completion_tokens"`
|
||||||
|
TotalTokens int `json:"total_tokens"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompletionRequest 完成请求
|
||||||
|
type CompletionRequest struct {
|
||||||
|
Model string `json:"model" binding:"required"`
|
||||||
|
Prompt string `json:"prompt" binding:"required"`
|
||||||
|
MaxTokens int `json:"max_tokens,omitempty"`
|
||||||
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
|
TopP float64 `json:"top_p,omitempty"`
|
||||||
|
Stream bool `json:"stream,omitempty"`
|
||||||
|
Stop []string `json:"stop,omitempty"`
|
||||||
|
N int `json:"n,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompletionResponse 完成响应
|
||||||
|
type CompletionResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []Choice1 `json:"choices"`
|
||||||
|
Usage Usage `json:"usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Choice1 struct {
|
||||||
|
Text string `json:"text"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
FinishReason string `json:"finish_reason"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamResponse 流式响应
|
||||||
|
type StreamResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Object string `json:"object"`
|
||||||
|
Created int64 `json:"created"`
|
||||||
|
Model string `json:"model"`
|
||||||
|
Choices []Delta `json:"choices"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type Delta struct {
|
||||||
|
Delta struct {
|
||||||
|
Content string `json:"content,omitempty"`
|
||||||
|
Role string `json:"role,omitempty"`
|
||||||
|
} `json:"delta"`
|
||||||
|
Index int `json:"index"`
|
||||||
|
FinishReason string `json:"finish_reason,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrorResponse 错误响应
|
||||||
|
type ErrorResponse struct {
|
||||||
|
Error ErrorDetail `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ErrorDetail struct {
|
||||||
|
Message string `json:"message"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Code string `json:"code,omitempty"`
|
||||||
|
Param string `json:"param,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// HealthStatus 健康状态
|
||||||
|
type HealthStatus struct {
|
||||||
|
Status string `json:"status"`
|
||||||
|
Timestamp time.Time `json:"timestamp"`
|
||||||
|
Services map[string]bool `json:"services"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tenant 租户
|
||||||
|
type Tenant struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Plan string `json:"plan"`
|
||||||
|
CreatedAt time.Time `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Budget 预算
|
||||||
|
type Budget struct {
|
||||||
|
TenantID int64 `json:"tenant_id"`
|
||||||
|
MonthlyLimit float64 `json:"monthly_limit"`
|
||||||
|
AlertThreshold float64 `json:"alert_threshold"`
|
||||||
|
CurrentUsage float64 `json:"current_usage"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteRequest 路由请求
|
||||||
|
type RouteRequest struct {
|
||||||
|
Model string
|
||||||
|
TenantID int64
|
||||||
|
RouteType string // "primary", "fallback"
|
||||||
|
}
|
||||||
|
|
||||||
|
// RouteResult 路由结果
|
||||||
|
type RouteResult struct {
|
||||||
|
Provider string
|
||||||
|
Model string
|
||||||
|
LatencyMs int64
|
||||||
|
Success bool
|
||||||
|
Error error
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user