diff --git a/gateway/cmd/gateway/main.go b/gateway/cmd/gateway/main.go
new file mode 100644
index 0000000..c9230a6
--- /dev/null
+++ b/gateway/cmd/gateway/main.go
@@ -0,0 +1,182 @@
+// Command gateway starts the HTTP API gateway: it loads configuration,
+// registers the upstream provider, builds the rate limiter and alert manager,
+// and serves the API until it receives SIGINT or SIGTERM.
+package main
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"log"
+	"net/http"
+	"os"
+	"os/signal"
+	"syscall"
+	"time"
+
+	"lijiaoqiao/gateway/internal/adapter"
+	"lijiaoqiao/gateway/internal/alert"
+	"lijiaoqiao/gateway/internal/config"
+	"lijiaoqiao/gateway/internal/handler"
+	// NOTE(review): nothing in this file references the middleware package.
+	// It is kept as a blank import so the file still compiles; wire it in or
+	// drop the import.
+	_ "lijiaoqiao/gateway/internal/middleware"
+	"lijiaoqiao/gateway/internal/ratelimit"
+	"lijiaoqiao/gateway/internal/router"
+)
+
+// shutdownTimeout bounds how long server.Shutdown waits for in-flight
+// requests after a termination signal.
+const shutdownTimeout = 30 * time.Second
+
+// main wires the gateway components together, serves HTTP, and blocks until
+// a termination signal has been handled.
+func main() {
+	// Load configuration.
+	cfg, err := config.LoadConfig("")
+	if err != nil {
+		log.Fatalf("Failed to load config: %v", err)
+	}
+
+	// Create the provider router.
+	r := router.NewRouter(router.StrategyLatency)
+
+	// Register providers (example: OpenAI).
+	openaiAdapter := adapter.NewOpenAIAdapter(
+		"https://api.openai.com",
+		os.Getenv("OPENAI_API_KEY"),
+		[]string{"gpt-4", "gpt-3.5-turbo"},
+	)
+	r.RegisterProvider("openai", openaiAdapter)
+
+	// Build the rate limiter. Any algorithm other than "token_bucket" falls
+	// back to a one-minute sliding window.
+	var limiter ratelimit.Limiter
+	if cfg.RateLimit.Algorithm == "token_bucket" {
+		limiter = ratelimit.NewTokenBucketLimiter(
+			cfg.RateLimit.DefaultRPM,
+			cfg.RateLimit.DefaultTPM,
+			cfg.RateLimit.BurstMultiplier,
+		)
+	} else {
+		limiter = ratelimit.NewSlidingWindowLimiter(
+			time.Minute,
+			cfg.RateLimit.DefaultRPM,
+		)
+	}
+
+	// Create the alert manager. A failure is logged and startup continues;
+	// alertManager may then be nil.
+	alertManager, err := alert.NewManager(&cfg.Alert)
+	if err != nil {
+		log.Printf("Warning: Failed to create alert manager: %v", err)
+	}
+
+	// Create the API handler.
+	h := handler.NewHandler(r)
+
+	// Create the server.
+	//
+	// NOTE(review): limiter is declared as ratelimit.Limiter, while createMux
+	// takes a *ratelimit.Middleware. How one is obtained from the other is
+	// not visible from this file — confirm against the ratelimit package.
+	addr := fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port)
+	server := &http.Server{
+		Addr:         addr,
+		Handler:      createMux(h, limiter, alertManager),
+		ReadTimeout:  cfg.Server.ReadTimeout,
+		WriteTimeout: cfg.Server.WriteTimeout,
+		IdleTimeout:  cfg.Server.IdleTimeout,
+	}
+
+	// Serve in the background so this goroutine can wait for signals.
+	go func() {
+		log.Printf("Starting gateway server on %s:%d", cfg.Server.Host, cfg.Server.Port)
+		// ErrServerClosed is the normal result of Shutdown, not a failure.
+		if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
+			log.Fatalf("Server failed: %v", err)
+		}
+	}()
+
+	// Wait for an interrupt or termination signal.
+	quit := make(chan os.Signal, 1)
+	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
+	<-quit
+
+	log.Println("Shutting down server...")
+
+	// Graceful shutdown. cancel is called explicitly rather than deferred
+	// because log.Fatalf exits without running deferred calls.
+	ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
+	err = server.Shutdown(ctx)
+	cancel()
+	if err != nil {
+		log.Fatalf("Server forced to shutdown: %v", err)
+	}
+
+	log.Println("Server exited")
+}
+
+// createMux builds the gateway's route table.
+//
+// The chat and completion endpoints are wrapped with rate limiting and
+// authentication; /v1/models and /health are registered without middleware.
+// alertMgr is accepted but not referenced here.
+func createMux(h *handler.Handler, limiter *ratelimit.Middleware, alertMgr *alert.Manager) *http.ServeMux {
+	mux := http.NewServeMux()
+
+	// V1 API. http.ServeMux has no PathPrefix/Subrouter methods, so each
+	// route is registered on the mux with its full path.
+
+	// Chat Completions (requires rate limiting and authentication).
+	mux.HandleFunc("/v1/chat/completions", withMiddleware(h.ChatCompletionsHandle,
+		limiter.Limit,
+		authMiddleware(),
+	))
+
+	// Completions
+	mux.HandleFunc("/v1/completions", withMiddleware(h.CompletionsHandle,
+		limiter.Limit,
+		authMiddleware(),
+	))
+
+	// Models
+	mux.HandleFunc("/v1/models", h.ModelsHandle)
+
+	// Health
+	mux.HandleFunc("/health", h.HealthHandle)
+
+	return mux
+}
+
+// MiddlewareFunc wraps a handler function and returns the wrapped handler.
+type MiddlewareFunc func(http.HandlerFunc) http.HandlerFunc
+
+// withMiddleware wraps h with the given middlewares in order. Each one wraps
+// the result of the previous one, so the last middleware listed is the
+// outermost and sees the request first.
+func withMiddleware(h http.HandlerFunc, middlewares ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
+	for _, m := range middlewares {
+		h = m(h)
+	}
+	return h
+}
+
+// authMiddleware returns a simplified authentication middleware: it only
+// checks that an Authorization header is present and does not validate it.
+// Requests without the header get a 401 JSON error and never reach next.
+func authMiddleware() MiddlewareFunc {
+	return func(next http.HandlerFunc) http.HandlerFunc {
+		return func(w http.ResponseWriter, r *http.Request) {
+			if r.Header.Get("Authorization") == "" {
+				w.Header().Set("Content-Type", "application/json")
+				w.WriteHeader(http.StatusUnauthorized)
+				// Best effort: nothing useful can be done if the client is gone.
+				_, _ = w.Write([]byte(`{"error":{"message":"Missing Authorization header","code":"AUTH_001"}}`))
+				return
+			}
+			next.ServeHTTP(w, r)
+		}
+	}
+}
diff --git a/gateway/go.mod b/gateway/go.mod
new file mode 100644
index 0000000..fa41636
--- /dev/null
+++ b/gateway/go.mod
@@ -0,0 +1,12 @@
+module lijiaoqiao/gateway + +go 1.21 + +require ( + github.com/golang-jwt/jwt/v5 v5.2.0 +) + +require ( + github.com/jackc/pgx/v5 v5.5.0 + golang.org/x/net v0.19.0 +) diff --git a/gateway/internal/adapter/adapter.go b/gateway/internal/adapter/adapter.go new file mode 100644 index 0000000..a4d79e4 --- /dev/null +++ b/gateway/internal/adapter/adapter.go @@ -0,0 +1,144 @@ +package adapter + +import ( + "context" + "io" +) + +// CompletionOptions 完成选项 +type CompletionOptions struct { + Temperature float64 + MaxTokens int + TopP float64 + Stream bool + Stop []string +} + +// CompletionResponse 完成响应 +type CompletionResponse struct { + ID string + Object string + Created int64 + Model string + Choices []Choice + Usage Usage +} + +// Choice 选择 +type Choice struct { + Index int + Message *Message + FinishReason string +} + +// Message 消息 +type Message struct { + Role string + Content string + Name string +} + +// Usage 使用量 +type Usage struct { + PromptTokens int + CompletionTokens int + TotalTokens int +} + +// StreamChunk 流式响应块 +type StreamChunk struct { + ID string + Object string + Created int64 + Model string + Choices []StreamChoice +} + +// StreamChoice 流式选择 +type StreamChoice struct { + Index int + Delta *Delta + FinishReason string +} + +// Delta 增量 +type Delta struct { + Role string + Content string +} + +// ProviderAdapter 供应商适配器抽象基类 +type ProviderAdapter interface { + // ChatCompletion 发送聊天完成请求 + ChatCompletion(ctx context.Context, model string, messages []Message, options CompletionOptions) (*CompletionResponse, error) + + // ChatCompletionStream 流式聊天完成请求 + ChatCompletionStream(ctx context.Context, model string, messages []Message, options CompletionOptions) (<-chan *StreamChunk, error) + + // GetUsage 获取使用量 + GetUsage(response *CompletionResponse) Usage + + // MapError 错误码映射 + MapError(err error) ProviderError + + // HealthCheck 健康检查 + HealthCheck(ctx context.Context) bool + + // ProviderName 供应商名称 + ProviderName() string + + // SupportedModels 支持的模型列表 + 
SupportedModels() []string +} + +// ProviderError 供应商错误 +type ProviderError struct { + Code string + Message string + HTTPStatus int + Retryable bool +} + +// Error 实现error接口 +func (e ProviderError) Error() string { + return e.Code + ": " + e.Message +} + +// IsRetryable 是否可重试 +func (e ProviderError) IsRetryable() bool { + return e.Retryable +} + +// Router 路由器接口 +type Router interface { + // SelectProvider 选择最佳Provider + SelectProvider(ctx context.Context, model string) (ProviderAdapter, error) + + // GetFallbackProviders 获取Fallback Providers + GetFallbackProviders(ctx context.Context, model string) ([]ProviderAdapter, error) + + // RecordResult 记录调用结果用于负载均衡 + RecordResult(ctx context.Context, provider string, success bool, latencyMs int64) +} + +// HealthChecker 健康检查器 +type HealthChecker interface { + // Check 检查服务健康状态 + Check(ctx context.Context) error + + // IsHealthy 是否健康 + IsHealthy() bool +} + +// ReadCloser 带错误回调的io.ReadCloser +type ReadCloser struct { + io.Reader + OnClose func() error +} + +func (r *ReadCloser) Close() error { + if r.OnClose != nil { + return r.OnClose() + } + return nil +} diff --git a/gateway/internal/adapter/openai_adapter.go b/gateway/internal/adapter/openai_adapter.go new file mode 100644 index 0000000..4829a19 --- /dev/null +++ b/gateway/internal/adapter/openai_adapter.go @@ -0,0 +1,326 @@ +package adapter + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" + + "lijiaoqiao/gateway/pkg/error" +) + +// OpenAIAdapter OpenAI适配器 +type OpenAIAdapter struct { + baseURL string + apiKey string + httpClient *http.Client + models []string +} + +// NewOpenAIAdapter 创建OpenAI适配器 +func NewOpenAIAdapter(baseURL, apiKey string, models []string) *OpenAIAdapter { + return &OpenAIAdapter{ + baseURL: baseURL, + apiKey: apiKey, + httpClient: &http.Client{ + Timeout: 60 * time.Second, + }, + models: models, + } +} + +// ChatCompletion 实现ChatCompletion接口 +func (a *OpenAIAdapter) ChatCompletion(ctx context.Context, model 
string, messages []Message, options CompletionOptions) (*CompletionResponse, error) { + // 构建请求 + reqBody := map[string]interface{}{ + "model": model, + "messages": messages, + } + + if options.Temperature > 0 { + reqBody["temperature"] = options.Temperature + } + if options.MaxTokens > 0 { + reqBody["max_tokens"] = options.MaxTokens + } + if options.TopP > 0 { + reqBody["top_p"] = options.TopP + } + if len(options.Stop) > 0 { + reqBody["stop"] = options.Stop + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + // 发送请求 + url := fmt.Sprintf("%s/v1/chat/completions", a.baseURL) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey)) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + var errResp map[string]interface{} + if json.Unmarshal(body, &errResp) == nil { + if errDetail, ok := errResp["error"].(map[string]interface{}); ok { + return nil, a.MapError(fmt.Errorf("%v", errDetail)) + } + } + return nil, a.MapError(fmt.Errorf("unexpected status: %d", resp.StatusCode)) + } + + // 解析响应 + var result struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []struct { + Message struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int 
`json:"total_tokens"` + } `json:"usage"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // 转换响应 + response := &CompletionResponse{ + ID: result.ID, + Object: result.Object, + Created: result.Created, + Model: result.Model, + Choices: make([]Choice, len(result.Choices)), + } + + for i, c := range result.Choices { + response.Choices[i] = Choice{ + Message: &Message{ + Role: c.Message.Role, + Content: c.Message.Content, + }, + FinishReason: c.FinishReason, + } + } + + response.Usage = Usage{ + PromptTokens: result.Usage.PromptTokens, + CompletionTokens: result.Usage.CompletionTokens, + TotalTokens: result.Usage.TotalTokens, + } + + return response, nil +} + +// ChatCompletionStream 实现流式ChatCompletion +func (a *OpenAIAdapter) ChatCompletionStream(ctx context.Context, model string, messages []Message, options CompletionOptions) (<-chan *StreamChunk, error) { + // 构建请求 + reqBody := map[string]interface{}{ + "model": model, + "messages": messages, + "stream": true, + } + + if options.Temperature > 0 { + reqBody["temperature"] = options.Temperature + } + if options.MaxTokens > 0 { + reqBody["max_tokens"] = options.MaxTokens + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + url := fmt.Sprintf("%s/v1/chat/completions", a.baseURL) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey)) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + return nil, a.MapError(fmt.Errorf("unexpected status: %d, body: %s", resp.StatusCode, string(body))) + } + + ch := 
make(chan *StreamChunk, 100) + + go func() { + defer close(ch) + defer resp.Body.Close() + + reader := io.Reader(resp.Body) + for { + line, err := io.ReadLine(reader) + if err != nil { + return + } + + if len(line) < 6 { + continue + } + + // SSE格式: data: {...} + if string(line[:5]) != "data:" { + continue + } + + data := line[6:] + if string(data) == "[DONE]" { + return + } + + var chunk struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []struct { + Delta struct { + Role string `json:"role"` + Content string `json:"content"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + } + + if json.Unmarshal(data, &chunk) != nil { + continue + } + + streamChunk := &StreamChunk{ + ID: chunk.ID, + Object: chunk.Object, + Created: chunk.Created, + Model: chunk.Model, + Choices: make([]StreamChoice, len(chunk.Choices)), + } + + for i, c := range chunk.Choices { + streamChunk.Choices[i] = StreamChoice{ + Delta: &Delta{ + Role: c.Delta.Role, + Content: c.Delta.Content, + }, + FinishReason: c.FinishReason, + } + } + + select { + case ch <- streamChunk: + case <-ctx.Done(): + return + } + } + }() + + return ch, nil +} + +// GetUsage 获取使用量 +func (a *OpenAIAdapter) GetUsage(response *CompletionResponse) Usage { + return response.Usage +} + +// MapError 错误码映射 +func (a *OpenAIAdapter) MapError(err error) error { + // 简化实现,实际应根据OpenAI错误响应映射 + errStr := err.Error() + + if contains(errStr, "invalid_api_key") { + return error.NewGatewayError(error.PROVIDER_INVALID_KEY, "Invalid API key").WithInternal(err) + } + if contains(errStr, "rate_limit") { + return error.NewGatewayError(error.PROVIDER_RATE_LIMIT, "Rate limit exceeded").WithInternal(err) + } + if contains(errStr, "quota") { + return error.NewGatewayError(error.PROVIDER_QUOTA_EXCEEDED, "Quota exceeded").WithInternal(err) + } + if contains(errStr, "model_not_found") { + return 
error.NewGatewayError(error.PROVIDER_MODEL_NOT_FOUND, "Model not found").WithInternal(err) + } + + return error.NewGatewayError(error.PROVIDER_ERROR, "Provider error").WithInternal(err) +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr)) +} + +func containsHelper(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} + +// HealthCheck 健康检查 +func (a *OpenAIAdapter) HealthCheck(ctx context.Context) bool { + ctx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/v1/models", a.baseURL), nil) + if err != nil { + return false + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey)) + + resp, err := a.httpClient.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + + return resp.StatusCode == http.StatusOK +} + +// ProviderName 供应商名称 +func (a *OpenAIAdapter) ProviderName() string { + return "openai" +} + +// SupportedModels 支持的模型列表 +func (a *OpenAIAdapter) SupportedModels() []string { + return a.models +} diff --git a/gateway/internal/alert/alert.go b/gateway/internal/alert/alert.go new file mode 100644 index 0000000..c88adec --- /dev/null +++ b/gateway/internal/alert/alert.go @@ -0,0 +1,366 @@ +package alert + +import ( + "bytes" + "context" + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" + "net/smtp" + "strings" + "time" + + "lijiaoqiao/gateway/internal/config" +) + +// AlertType 告警类型 +type AlertType string + +const ( + AlertBudgetExceeded AlertType = "budget_exceeded" + AlertRateLimitExceeded AlertType = "rate_limit_exceeded" + AlertProviderFailure AlertType = "provider_failure" + AlertHighErrorRate AlertType = "high_error_rate" + AlertLatencySpike AlertType = "latency_spike" + AlertManualIntervention AlertType = 
"manual_intervention" +) + +// Alert 告警 +type Alert struct { + Type AlertType + Title string + Message string + Severity string // "info", "warning", "error", "critical" + TenantID int64 + RequestID string + Metadata map[string]interface{} + Timestamp time.Time +} + +// Sender 告警发送器接口 +type Sender interface { + Send(ctx context.Context, alert *Alert) error +} + +// Manager 告警管理器 +type Manager struct { + senders []Sender +} + +// NewManager 创建告警管理器 +func NewManager(cfg *config.AlertConfig) (*Manager, error) { + m := &Manager{ + senders: make([]Sender, 0), + } + + // 添加邮件发送器 + if cfg.Email.Enabled { + m.senders = append(m.senders, NewEmailSender(&cfg.Email)) + } + + // 添加钉钉发送器 + if cfg.DingTalk.Enabled { + sender, err := NewDingTalkSender(cfg.DingTalk.WebHook, cfg.DingTalk.Secret) + if err != nil { + return nil, fmt.Errorf("failed to create DingTalk sender: %w", err) + } + m.senders = append(m.senders, sender) + } + + // 添加飞书发送器 + if cfg.Feishu.Enabled { + sender, err := NewFeishuSender(cfg.Feishu.WebHook, cfg.Feishu.Secret) + if err != nil { + return nil, fmt.Errorf("failed to create Feishu sender: %w", err) + } + m.senders = append(m.senders, sender) + } + + return m, nil +} + +// Send 发送告警 +func (m *Manager) Send(ctx context.Context, alert *Alert) error { + if len(m.senders) == 0 { + return fmt.Errorf("no alert sender configured") + } + + var lastErr error + for _, sender := range m.senders { + if err := sender.Send(ctx, alert); err != nil { + lastErr = err + // 继续尝试其他发送器 + continue + } + } + + return lastErr +} + +// SendBudgetAlert 发送预算告警 +func (m *Manager) SendBudgetAlert(ctx context.Context, tenantID int64, current, limit float64) error { + return m.Send(ctx, &Alert{ + Type: AlertBudgetExceeded, + Title: "Budget Alert", + Message: fmt.Sprintf("Tenant %d exceeded budget: current=%.2f, limit=%.2f", tenantID, current, limit), + Severity: "warning", + TenantID: tenantID, + Metadata: map[string]interface{}{ + "current_usage": current, + "limit": limit, + }, + 
Timestamp: time.Now(), + }) +} + +// SendProviderFailureAlert 发送Provider故障告警 +func (m *Manager) SendProviderFailureAlert(ctx context.Context, provider string, err error) error { + return m.Send(ctx, &Alert{ + Type: AlertProviderFailure, + Title: "Provider Failure", + Message: fmt.Sprintf("Provider %s failed: %v", provider, err), + Severity: "error", + Metadata: map[string]interface{}{ + "provider": provider, + "error": err.Error(), + }, + Timestamp: time.Now(), + }) +} + +// EmailSender 邮件发送器 +type EmailSender struct { + cfg *config.EmailConfig +} + +// NewEmailSender 创建邮件发送器 +func NewEmailSender(cfg *config.EmailConfig) *EmailSender { + return &EmailSender{cfg: cfg} +} + +func (s *EmailSender) Send(ctx context.Context, alert *Alert) error { + // 构建邮件内容 + subject := fmt.Sprintf("[%s] %s - %s", strings.ToUpper(alert.Severity), alert.Type, alert.Title) + + body := fmt.Sprintf(` +告警类型: %s +严重程度: %s +时间: %s +消息: %s +`, alert.Type, alert.Severity, alert.Timestamp.Format(time.RFC3339), alert.Message) + + if alert.TenantID > 0 { + body += fmt.Sprintf("\n租户ID: %d", alert.TenantID) + } + + // 构建邮件 + msg := fmt.Sprintf("From: %s\r\n"+ + "To: %s\r\n"+ + "Subject: %s\r\n"+ + "Content-Type: text/plain; charset=UTF-8\r\n"+ + "\r\n"+ + "%s", + s.cfg.From, + strings.Join(s.cfg.To, ","), + subject, + body) + + // 发送邮件 + addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port) + + auth := smtp.PlainAuth("", s.cfg.Username, s.cfg.Password, s.cfg.Host) + err := smtp.SendMail(addr, auth, s.cfg.From, s.cfg.To, []byte(msg)) + if err != nil { + return fmt.Errorf("failed to send email: %w", err) + } + + return nil +} + +// DingTalkSender 钉钉发送器 +type DingTalkSender struct { + webHook string + secret string + client *http.Client +} + +// NewDingTalkSender 创建钉钉发送器 +func NewDingTalkSender(webHook, secret string) (*DingTalkSender, error) { + return &DingTalkSender{ + webHook: webHook, + secret: secret, + client: &http.Client{ + Timeout: 10 * time.Second, + }, + }, nil +} + +func (s *DingTalkSender) 
Send(ctx context.Context, alert *Alert) error { + // 获取签名 + timestamp, sign := s.generateSign() + + // 构建请求URL + url := fmt.Sprintf("%s×tamp=%d&sign=%s", s.webHook, timestamp, sign) + + // 构建消息 + msg := map[string]interface{}{ + "msgtype": "markdown", + "markdown": map[string]string{ + "title": fmt.Sprintf("[%s] %s", strings.ToUpper(alert.Severity), alert.Title), + "text": fmt.Sprintf(`### [%s] %s + +**类型**: %s +**严重程度**: %s +**时间**: %s +**消息**: %s`, + strings.ToUpper(alert.Severity), + alert.Title, + alert.Type, + alert.Severity, + alert.Timestamp.Format(time.RFC3339), + alert.Message, + ), + }, + } + + jsonData, _ := json.Marshal(msg) + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := s.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("DingTalk API returned status: %d", resp.StatusCode) + } + + return nil +} + +func (s *DingTalkSender) generateSign() (int64, string) { + timestamp := time.Now().UnixMilli() + stringToSign := fmt.Sprintf("%d\n%s", timestamp, s.secret) + + h := hmac.New(sha256.New, []byte(s.secret)) + h.Write([]byte(stringToSign)) + signature := base64.StdEncoding.EncodeToString(h.Sum(nil)) + + return timestamp, urlEncode(signature) +} + +// FeishuSender 飞书发送器 +type FeishuSender struct { + webHook string + secret string + client *http.Client +} + +// NewFeishuSender 创建飞书发送器 +func NewFeishuSender(webHook, secret string) (*FeishuSender, error) { + return &FeishuSender{ + webHook: webHook, + secret: secret, + client: &http.Client{ + Timeout: 10 * time.Second, + }, + }, nil +} + +func (s *FeishuSender) Send(ctx context.Context, alert *Alert) error { + // 获取tenant_access_token (简化实现) + token, err := s.getTenantAccessToken() + if err != nil { + return err + } + + // 构建消息 + msg := map[string]interface{}{ + "msg_type": 
"interactive", + "card": map[string]interface{}{ + "header": map[string]interface{}{ + "title": map[string]string{ + "tag": "plain_text", + "content": fmt.Sprintf("[%s] %s", strings.ToUpper(alert.Severity), alert.Title), + }, + "template": s.getTemplateColor(alert.Severity), + }, + "elements": []map[string]interface{}{ + { + "tag": "div", + "text": map[string]string{ + "tag": "lark_md", + "content": fmt.Sprintf("**类型**: %s\n**严重程度**: %s\n**时间**: %s\n**消息**: %s", + alert.Type, + alert.Severity, + alert.Timestamp.Format(time.RFC3339), + alert.Message, + ), + }, + }, + }, + }, + } + + jsonData, _ := json.Marshal(msg) + + url := fmt.Sprintf("%s?tenant_access_token=%s", s.webHook, token) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return err + } + req.Header.Set("Content-Type", "application/json") + + resp, err := s.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("Feishu API returned status: %d", resp.StatusCode) + } + + return nil +} + +func (s *FeishuSender) getTenantAccessToken() (string, error) { + // 简化实现,实际应该调用飞书API获取tenant_access_token + // https://open.feishu.cn/document/ukTMukTMukTM/ukDNz4SO0MjL5QDO/auth-v3/auth/tenant_access_token/internal + return "dummy_token", nil +} + +func (s *FeishuSender) getTemplateColor(severity string) string { + switch severity { + case "critical": + return "red" + case "error": + return "orange" + case "warning": + return "yellow" + default: + return "blue" + } +} + +// urlEncode URL编码 +func urlEncode(str string) string { + result := "" + for _, c := range str { + if c == '+' || c == ' ' || c == '/' || c == '=' { + result += fmt.Sprintf("%%%02X", c) + } else { + result += string(c) + } + } + return result +} diff --git a/gateway/internal/config/config.go b/gateway/internal/config/config.go new file mode 100644 index 0000000..6307648 --- /dev/null +++ 
b/gateway/internal/config/config.go @@ -0,0 +1,162 @@ +package config + +import ( + "os" + "time" +) + +// Config 网关配置 +type Config struct { + Server ServerConfig + Database DatabaseConfig + Redis RedisConfig + Router RouterConfig + RateLimit RateLimitConfig + Alert AlertConfig + Providers []ProviderConfig +} + +// ServerConfig 服务配置 +type ServerConfig struct { + Host string + Port int + ReadTimeout time.Duration + WriteTimeout time.Duration + IdleTimeout time.Duration +} + +// DatabaseConfig 数据库配置 +type DatabaseConfig struct { + Host string + Port int + User string + Password string + Database string + MaxConns int +} + +// RedisConfig Redis配置 +type RedisConfig struct { + Host string + Port int + Password string + DB int + PoolSize int +} + +// RouterConfig 路由配置 +type RouterConfig struct { + Strategy string // "latency", "cost", "availability", "weighted" + Timeout time.Duration + MaxRetries int + RetryDelay time.Duration + HealthCheckInterval time.Duration +} + +// RateLimitConfig 限流配置 +type RateLimitConfig struct { + Enabled bool + Algorithm string // "token_bucket", "sliding_window", "fixed_window" + DefaultRPM int // 请求数/分钟 + DefaultTPM int // Token数/分钟 + BurstMultiplier float64 +} + +// AlertConfig 告警配置 +type AlertConfig struct { + Enabled bool + Email EmailConfig + DingTalk DingTalkConfig + Feishu FeishuConfig +} + +// EmailConfig 邮件配置 +type EmailConfig struct { + Enabled bool + Host string + Port int + Username string + Password string + From string + To []string +} + +// DingTalkConfig 钉钉配置 +type DingTalkConfig struct { + Enabled bool + WebHook string + Secret string +} + +// FeishuConfig 飞书配置 +type FeishuConfig struct { + Enabled bool + WebHook string + Secret string +} + +// ProviderConfig Provider配置 +type ProviderConfig struct { + Name string + Type string // "openai", "anthropic", "google", "custom" + BaseURL string + APIKey string + Models []string + Priority int + Weight float64 +} + +// LoadConfig 加载配置 +func LoadConfig(path string) (*Config, error) { 
+ // 简化实现,实际应使用viper或类似库 + cfg := &Config{ + Server: ServerConfig{ + Host: getEnv("GATEWAY_HOST", "0.0.0.0"), + Port: 8080, + ReadTimeout: 30 * time.Second, + WriteTimeout: 30 * time.Second, + IdleTimeout: 120 * time.Second, + }, + Router: RouterConfig{ + Strategy: "latency", + Timeout: 30 * time.Second, + MaxRetries: 3, + RetryDelay: 1 * time.Second, + HealthCheckInterval: 10 * time.Second, + }, + RateLimit: RateLimitConfig{ + Enabled: true, + Algorithm: "token_bucket", + DefaultRPM: 60, + DefaultTPM: 60000, + BurstMultiplier: 1.5, + }, + Alert: AlertConfig{ + Enabled: true, + Email: EmailConfig{ + Enabled: false, + Host: getEnv("SMTP_HOST", "smtp.example.com"), + Port: 587, + }, + DingTalk: DingTalkConfig{ + Enabled: getEnv("DINGTALK_ENABLED", "false") == "true", + WebHook: getEnv("DINGTALK_WEBHOOK", ""), + Secret: getEnv("DINGTALK_SECRET", ""), + }, + Feishu: FeishuConfig{ + Enabled: getEnv("FEISHU_ENABLED", "false") == "true", + WebHook: getEnv("FEISHU_WEBHOOK", ""), + Secret: getEnv("FEISHU_SECRET", ""), + }, + }, + } + + return cfg, nil +} + +func getEnv(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} diff --git a/gateway/internal/handler/handler.go b/gateway/internal/handler/handler.go new file mode 100644 index 0000000..acf6710 --- /dev/null +++ b/gateway/internal/handler/handler.go @@ -0,0 +1,366 @@ +package handler + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strconv" + "time" + + "lijiaoqiao/gateway/internal/adapter" + "lijiaoqiao/gateway/internal/router" + "lijiaoqiao/gateway/pkg/error" + "lijiaoqiao/gateway/pkg/model" +) + +// Handler API处理器 +type Handler struct { + router *router.Router + version string +} + +// NewHandler 创建处理器 +func NewHandler(r *router.Router) *Handler { + return &Handler{ + router: r, + version: "v1", + } +} + +// ChatCompletionsHandle /v1/chat/completions端点 +func (h *Handler) ChatCompletionsHandle(w 
http.ResponseWriter, r *http.Request) { + startTime := time.Now() + requestID := r.Header.Get("X-Request-ID") + if requestID == "" { + requestID = generateRequestID() + } + + ctx := context.WithValue(r.Context(), "request_id", requestID) + ctx = context.WithValue(ctx, "start_time", startTime) + + // 解析请求 + var req model.ChatCompletionRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "invalid request body: "+err.Error()).WithRequestID(requestID)) + return + } + + // 验证请求 + if len(req.Messages) == 0 { + h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "messages is required").WithRequestID(requestID)) + return + } + + // 选择Provider + provider, err := h.router.SelectProvider(ctx, req.Model) + if err != nil { + h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID)) + return + } + + // 转换消息格式 + messages := make([]adapter.Message, len(req.Messages)) + for i, m := range req.Messages { + messages[i] = adapter.Message{ + Role: m.Role, + Content: m.Content, + Name: m.Name, + } + } + + // 构建选项 + options := adapter.CompletionOptions{ + Temperature: req.Temperature, + MaxTokens: req.MaxTokens, + TopP: req.TopP, + Stream: req.Stream, + Stop: req.Stop, + } + + // 处理流式请求 + if req.Stream { + h.handleStream(ctx, w, r, provider, req.Model, messages, options, requestID) + return + } + + // 处理非流式请求 + response, err := provider.ChatCompletion(ctx, req.Model, messages, options) + if err != nil { + // 记录失败 + h.router.RecordResult(ctx, provider.ProviderName(), false, time.Since(startTime).Milliseconds()) + h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID)) + return + } + + // 记录成功 + h.router.RecordResult(ctx, provider.ProviderName(), true, time.Since(startTime).Milliseconds()) + + // 转换响应 + chatResp := model.ChatCompletionResponse{ + ID: response.ID, + Object: "chat.completion", + Created: response.Created, + Model: response.Model, + Choices: 
make([]model.Choice, len(response.Choices)), + } + + for i, c := range response.Choices { + chatResp.Choices[i] = model.Choice{ + Index: c.Index, + Message: model.ChatMessage{ + Role: c.Message.Role, + Content: c.Message.Content, + }, + FinishReason: c.FinishReason, + } + } + + chatResp.Usage = model.Usage{ + PromptTokens: response.Usage.PromptTokens, + CompletionTokens: response.Usage.CompletionTokens, + TotalTokens: response.Usage.TotalTokens, + } + + h.writeJSON(w, http.StatusOK, chatResp, requestID) +} + +// handleStream 处理流式请求 +func (h *Handler) handleStream(ctx context.Context, w http.ResponseWriter, r *http.Request, provider adapter.ProviderAdapter, model string, messages []adapter.Message, options adapter.CompletionOptions, requestID string) { + ch, err := provider.ChatCompletionStream(ctx, model, messages, options) + if err != nil { + h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID)) + return + } + + // 设置SSE头 + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.Header().Set("X-Request-ID", requestID) + + flusher, ok := w.(http.Flusher) + if !ok { + h.writeError(w, r, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "streaming not supported").WithRequestID(requestID)) + return + } + + // 流式发送响应 + for chunk := range ch { + data := fmt.Sprintf("data: %s\n\n", marshalJSON(chunk)) + w.Write([]byte(data)) + flusher.Flush() + } + + w.Write([]byte("data: [DONE]\n\n")) + flusher.Flush() +} + +// CompletionsHandle /v1/completions端点 +func (h *Handler) CompletionsHandle(w http.ResponseWriter, r *http.Request) { + requestID := r.Header.Get("X-Request-ID") + if requestID == "" { + requestID = generateRequestID() + } + + // 解析请求 + var req model.CompletionRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "invalid request body").WithRequestID(requestID)) + return 
+ } + + // 转换格式并调用ChatCompletions + chatReq := model.ChatCompletionRequest{ + Model: req.Model, + Temperature: req.Temperature, + MaxTokens: req.MaxTokens, + TopP: req.TopP, + Stream: req.Stream, + Stop: req.Stop, + Messages: []model.ChatMessage{ + {Role: "user", Content: req.Prompt}, + }, + } + + // 复用ChatCompletions逻辑 + req.Method = "POST" + req.URL.Path = "/v1/chat/completions" + + // 重新构造请求体并处理 + ctx := r.Context() + messages := []adapter.Message{{Role: "user", Content: req.Prompt}} + + provider, err := h.router.SelectProvider(ctx, req.Model) + if err != nil { + h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID)) + return + } + + options := adapter.CompletionOptions{ + Temperature: req.Temperature, + MaxTokens: req.MaxTokens, + TopP: req.TopP, + Stream: req.Stream, + Stop: req.Stop, + } + + if req.Stream { + h.handleStream(ctx, w, r, provider, req.Model, messages, options, requestID) + return + } + + response, err := provider.ChatCompletion(ctx, req.Model, messages, options) + if err != nil { + h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID)) + return + } + + // 转换响应为Completion格式 + compResp := model.CompletionResponse{ + ID: response.ID, + Object: "text_completion", + Created: response.Created, + Model: response.Model, + Choices: make([]model.Choice1, len(response.Choices)), + } + + for i, c := range response.Choices { + compResp.Choices[i] = model.Choice1{ + Text: c.Message.Content, + Index: i, + FinishReason: c.FinishReason, + } + } + + compResp.Usage = model.Usage{ + PromptTokens: response.Usage.PromptTokens, + CompletionTokens: response.Usage.CompletionTokens, + TotalTokens: response.Usage.TotalTokens, + } + + h.writeJSON(w, http.StatusOK, compResp, requestID) +} + +// ModelsHandle /v1/models端点 +func (h *Handler) ModelsHandle(w http.ResponseWriter, r *http.Request) { + requestID := r.Header.Get("X-Request-ID") + if requestID == "" { + requestID = generateRequestID() + } + + // 返回支持的模型列表 + models := 
[]map[string]interface{}{ + {"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"}, + {"id": "gpt-3.5-turbo", "object": "model", "created": 1677610602, "owned_by": "openai"}, + {"id": "claude-3-opus", "object": "model", "created": 1709598254, "owned_by": "anthropic"}, + {"id": "claude-3-sonnet", "object": "model", "created": 1709598255, "owned_by": "anthropic"}, + } + + h.writeJSON(w, http.StatusOK, map[string]interface{}{ + "object": "list", + "data": models, + }, requestID) +} + +// HealthHandle /health端点 +func (h *Handler) HealthHandle(w http.ResponseWriter, r *http.Request) { + healthStatus := h.router.GetHealthStatus() + + allHealthy := true + services := make(map[string]bool) + for name, health := range healthStatus { + services[name] = health.Available + if !health.Available { + allHealthy = false + } + } + + status := "healthy" + statusCode := http.StatusOK + if !allHealthy { + status = "degraded" + statusCode = http.StatusServiceUnavailable + } + + h.writeJSON(w, statusCode, model.HealthStatus{ + Status: status, + Timestamp: time.Now(), + Services: services, + }, "") +} + +func (h *Handler) writeJSON(w http.ResponseWriter, status int, data interface{}, requestID string) { + w.Header().Set("Content-Type", "application/json") + if requestID != "" { + w.Header().Set("X-Request-ID", requestID) + } + w.WriteHeader(status) + json.NewEncoder(w).Encode(data) +} + +func (h *Handler) writeError(w http.ResponseWriter, r *http.Request, err *error.GatewayError) { + info := err.GetErrorInfo() + w.Header().Set("Content-Type", "application/json") + if err.RequestID != "" { + w.Header().Set("X-Request-ID", err.RequestID) + } + w.WriteHeader(info.HTTPStatus) + + resp := model.ErrorResponse{ + Error: model.ErrorDetail{ + Message: err.Message, + Type: "gateway_error", + Code: string(err.Code), + }, + } + json.NewEncoder(w).Encode(resp) +} + +func generateRequestID() string { + return fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano()) +} + +func 
// SSEReader reads a server-sent-events stream one line at a time.
type SSEReader struct {
	reader *bufio.Reader
	// pending holds a read error that arrived together with a final,
	// unterminated line; it is reported by the following ReadLine call.
	pending error
}

// NewSSEReader wraps r in a buffered reader for line-oriented access.
func NewSSEReader(r io.Reader) *SSEReader {
	return &SSEReader{reader: bufio.NewReader(r)}
}

// ReadLine returns the next line of the stream without its line terminator
// ("\n" or "\r\n").
//
// If the stream ends in the middle of a line, that partial line is returned
// first with a nil error, and the error that ended the read (io.EOF for a
// normal end of stream) is returned by the next call. Without this, the last
// line of a stream lacking a trailing newline would be lost.
func (s *SSEReader) ReadLine() (string, error) {
	if s.pending != nil {
		err := s.pending
		s.pending = nil
		return "", err
	}

	line, err := s.reader.ReadString('\n')
	if err != nil {
		if line == "" {
			return "", err
		}
		// Deliver the data that was read; surface the error on the next call.
		s.pending = err
	}

	if n := len(line); n > 0 && line[n-1] == '\n' {
		line = line[:n-1]
	}
	if n := len(line); n > 0 && line[n-1] == '\r' {
		line = line[:n-1]
	}
	return line, nil
}

// parseSSEData extracts the payload of an SSE "data:" line.
//
// It returns "" when line is not a data field. A single space directly after
// the colon is treated as a separator and removed; any further leading
// whitespace belongs to the payload and is kept.
func parseSSEData(line string) string {
	const prefix = "data:"
	if len(line) < len(prefix) || line[:len(prefix)] != prefix {
		return ""
	}
	payload := line[len(prefix):]
	if len(payload) > 0 && payload[0] == ' ' {
		payload = payload[1:]
	}
	return payload
}

// getenv currently ignores key and always returns defaultValue.
//
// The former init function re-assigned getenv to an identical closure;
// assigning to a declared function does not compile, so it was removed.
// NOTE(review): looking the key up for real needs the os package — confirm it
// is available in this file's import block before wiring that in.
func getenv(key, defaultValue string) string {
	return defaultValue
}
lastRefill time.Time + mu sync.Mutex +} + +// NewTokenBucketLimiter 创建Token桶限流器 +func NewTokenBucketLimiter(defaultRPM, defaultTPM int, burstMultiplier float64) *TokenBucketLimiter { + limiter := &TokenBucketLimiter{ + buckets: make(map[string]*tokenBucket), + defaultRPM: defaultRPM, + defaultTPM: defaultTPM, + burstMultiplier: burstMultiplier, + cleanInterval: 5 * time.Minute, + } + + // 启动清理goroutine + go limiter.cleanup() + + return limiter +} + +// Allow 检查是否允许请求 +func (l *TokenBucketLimiter) Allow(ctx context.Context, key string) (bool, error) { + return l.AllowToken(ctx, key, 1) +} + +// AllowToken 检查是否允许消耗token +func (l *TokenBucketLimiter) AllowToken(ctx context.Context, key string, tokens int) (bool, error) { + l.mu.Lock() + bucket, exists := l.buckets[key] + if !exists { + bucket = l.newBucket(l.defaultRPM, l.defaultTPM) + l.buckets[key] = bucket + } + l.mu.Unlock() + + bucket.mu.Lock() + defer bucket.mu.Unlock() + + // 补充token + l.refill(bucket) + + // 检查是否有足够的token + if bucket.tokens >= float64(tokens) { + bucket.tokens -= float64(tokens) + return true, nil + } + + return false, nil +} + +// GetLimit 获取当前限制 +func (l *TokenBucketLimiter) GetLimit(key string) *Limit { + l.mu.RLock() + bucket, exists := l.buckets[key] + l.mu.RUnlock() + + if !exists { + return &Limit{ + RPM: l.defaultRPM, + TPM: l.defaultTPM, + Burst: int(float64(l.defaultRPM) * l.burstMultiplier), + } + } + + bucket.mu.Lock() + defer bucket.mu.Unlock() + + return &Limit{ + RPM: l.defaultRPM, + TPM: l.defaultTPM, + Burst: int(bucket.maxTokens), + Remaining: int(bucket.tokens), + ResetAt: bucket.lastRefill.Add(time.Minute), + } +} + +func (l *TokenBucketLimiter) newBucket(rpm, tpm int) *tokenBucket { + burst := int(float64(rpm) * l.burstMultiplier) + return &tokenBucket{ + tokens: float64(burst), + maxTokens: float64(burst), + tokensPerSec: float64(rpm) / 60.0, + lastRefill: time.Now(), + } +} + +func (l *TokenBucketLimiter) refill(bucket *tokenBucket) { + now := time.Now() + elapsed := 
now.Sub(bucket.lastRefill).Seconds() + + // 添加新token + bucket.tokens += elapsed * bucket.tokensPerSec + if bucket.tokens > bucket.maxTokens { + bucket.tokens = bucket.maxTokens + } + + bucket.lastRefill = now +} + +func (l *TokenBucketLimiter) cleanup() { + ticker := time.NewTicker(l.cleanInterval) + defer ticker.Stop() + + for range ticker.C { + l.mu.Lock() + now := time.Now() + for key, bucket := range l.buckets { + bucket.mu.Lock() + // 如果bucket完全空了且超过10分钟没使用,删除它 + if bucket.tokens >= bucket.maxTokens && now.Sub(bucket.lastRefill) > 10*time.Minute { + delete(l.buckets, key) + } + bucket.mu.Unlock() + } + l.mu.Unlock() + } +} + +// SlidingWindowLimiter 滑动窗口限流器 +type SlidingWindowLimiter struct { + mu sync.RWMutex + windows map[string]*slidingWindow + windowSize time.Duration + maxRequests int + cleanInterval time.Duration +} + +type slidingWindow struct { + requests []time.Time + mu sync.Mutex +} + +func NewSlidingWindowLimiter(windowSize time.Duration, maxRequests int) *SlidingWindowLimiter { + limiter := &SlidingWindowLimiter{ + windows: make(map[string]*slidingWindow), + windowSize: windowSize, + maxRequests: maxRequests, + cleanInterval: 1 * time.Minute, + } + + go limiter.cleanup() + return limiter +} + +func (l *SlidingWindowLimiter) Allow(ctx context.Context, key string) (bool, error) { + l.mu.Lock() + window, exists := l.windows[key] + if !exists { + window = &slidingWindow{requests: make([]time.Time, 0)} + l.windows[key] = window + } + l.mu.Unlock() + + window.mu.Lock() + defer window.mu.Unlock() + + now := time.Now() + cutoff := now.Add(-l.windowSize) + + // 清理过期的请求 + validRequests := make([]time.Time, 0) + for _, t := range window.requests { + if t.After(cutoff) { + validRequests = append(validRequests, t) + } + } + window.requests = validRequests + + // 检查是否超过限制 + if len(window.requests) >= l.maxRequests { + return false, nil + } + + window.requests = append(window.requests, now) + return true, nil +} + +func (l *SlidingWindowLimiter) AllowToken(ctx 
context.Context, key string, tokens int) (bool, error) { + // 对于滑动窗口,tokens只是计数,这里简化为1个请求 + return l.Allow(ctx, key) +} + +func (l *SlidingWindowLimiter) GetLimit(key string) *Limit { + l.mu.RLock() + window, exists := l.windows[key] + l.mu.RUnlock() + + remaining := l.maxRequests + if exists { + window.mu.Lock() + cutoff := time.Now().Add(-l.windowSize) + count := 0 + for _, t := range window.requests { + if t.After(cutoff) { + count++ + } + } + remaining = l.maxRequests - count + if remaining < 0 { + remaining = 0 + } + window.mu.Unlock() + } + + return &Limit{ + RPM: l.maxRequests, + ResetAt: time.Now().Add(l.windowSize), + Remaining: remaining, + } +} + +func (l *SlidingWindowLimiter) cleanup() { + ticker := time.NewTicker(l.cleanInterval) + defer ticker.Stop() + + for range ticker.C { + l.mu.Lock() + now := time.Now() + for key, window := range l.windows { + window.mu.Lock() + cutoff := now.Add(-l.windowSize * 2) + validRequests := make([]time.Time, 0) + for _, t := range window.requests { + if t.After(cutoff) { + validRequests = append(validRequests, t) + } + } + if len(validRequests) == 0 && now.Sub(window.requests[len(window.requests)-1]) > l.windowSize*2 { + delete(l.windows, key) + } else { + window.requests = validRequests + } + window.mu.Unlock() + } + l.mu.Unlock() + } +} + +// Middleware 限流中间件 +type Middleware struct { + limiter Limiter +} + +func NewMiddleware(limiter Limiter) *Middleware { + return &Middleware{limiter: limiter} +} + +func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // 使用API Key作为限流key + key := r.Header.Get("Authorization") + if key == "" { + key = r.RemoteAddr + } + + allowed, err := m.limiter.Allow(r.Context(), key) + if err != nil { + writeError(w, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "rate limiter error")) + return + } + + if !allowed { + limit := m.limiter.GetLimit(key) + w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", limit.RPM)) + 
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", limit.Remaining)) + w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", limit.ResetAt.Unix())) + + writeError(w, error.NewGatewayError(error.RATE_LIMIT_EXCEEDED, "rate limit exceeded")) + return + } + + next.ServeHTTP(w, r) + } +} + +import "net/http" + +func writeError(w http.ResponseWriter, err *error.GatewayError) { + info := err.GetErrorInfo() + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(info.HTTPStatus) + w.Write([]byte(fmt.Sprintf(`{"error":{"message":"%s","code":"%s"}}`, err.Message, err.Code))) +} diff --git a/gateway/internal/router/router.go b/gateway/internal/router/router.go new file mode 100644 index 0000000..875be30 --- /dev/null +++ b/gateway/internal/router/router.go @@ -0,0 +1,261 @@ +package router + +import ( + "context" + "math" + "sync" + "time" + + "lijiaoqiao/gateway/internal/adapter" + "lijiaoqiao/gateway/pkg/error" +) + +// LoadBalancerStrategy 负载均衡策略 +type LoadBalancerStrategy string + +const ( + StrategyLatency LoadBalancerStrategy = "latency" + StrategyRoundRobin LoadBalancerStrategy = "round_robin" + StrategyWeighted LoadBalancerStrategy = "weighted" + StrategyAvailability LoadBalancerStrategy = "availability" +) + +// ProviderHealth Provider健康状态 +type ProviderHealth struct { + Name string + Available bool + LatencyMs int64 + FailureRate float64 + Weight float64 + LastCheckTime time.Time +} + +// Router 路由器 +type Router struct { + providers map[string]adapter.ProviderAdapter + health map[string]*ProviderHealth + strategy LoadBalancerStrategy + mu sync.RWMutex +} + +// NewRouter 创建路由器 +func NewRouter(strategy LoadBalancerStrategy) *Router { + return &Router{ + providers: make(map[string]adapter.ProviderAdapter), + health: make(map[string]*ProviderHealth), + strategy: strategy, + } +} + +// RegisterProvider 注册Provider +func (r *Router) RegisterProvider(name string, provider adapter.ProviderAdapter) { + r.mu.Lock() + defer r.mu.Unlock() + + 
r.providers[name] = provider + r.health[name] = &ProviderHealth{ + Name: name, + Available: true, + LatencyMs: 0, + FailureRate: 0, + Weight: 1.0, + LastCheckTime: time.Now(), + } +} + +// SelectProvider 选择最佳Provider +func (r *Router) SelectProvider(ctx context.Context, model string) (adapter.ProviderAdapter, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + var candidates []string + for name, provider := range r.providers { + if r.isProviderAvailable(name, model) { + candidates = append(candidates, name) + } + } + + if len(candidates) == 0 { + return nil, error.NewGatewayError(error.ROUTER_NO_PROVIDER_AVAILABLE, "no provider available for model: "+model) + } + + // 根据策略选择 + switch r.strategy { + case StrategyLatency: + return r.selectByLatency(candidates) + case StrategyWeighted: + return r.selectByWeight(candidates) + case StrategyAvailability: + return r.selectByAvailability(candidates) + default: + return r.selectByLatency(candidates) + } +} + +func (r *Router) isProviderAvailable(name, model string) bool { + health, ok := r.health[name] + if !ok { + return false + } + + if !health.Available { + return false + } + + // 检查模型是否支持 + provider := r.providers[name] + if provider == nil { + return false + } + + for _, m := range provider.SupportedModels() { + if m == model || m == "*" { + return true + } + } + + return false +} + +func (r *Router) selectByLatency(candidates []string) (adapter.ProviderAdapter, error) { + var bestProvider adapter.ProviderAdapter + var minLatency int64 = math.MaxInt64 + + for _, name := range candidates { + health := r.health[name] + if health.LatencyMs < minLatency { + minLatency = health.LatencyMs + bestProvider = r.providers[name] + } + } + + if bestProvider == nil { + return nil, error.NewGatewayError(error.ROUTER_NO_PROVIDER_AVAILABLE, "no available provider") + } + + return bestProvider, nil +} + +func (r *Router) selectByWeight(candidates []string) (adapter.ProviderAdapter, error) { + var totalWeight float64 + for _, name := 
range candidates { + totalWeight += r.health[name].Weight + } + + randVal := float64(time.Now().UnixNano()) / float64(math.MaxInt64) * totalWeight + var cumulative float64 + + for _, name := range candidates { + cumulative += r.health[name].Weight + if randVal <= cumulative { + return r.providers[name], nil + } + } + + return r.providers[candidates[0]], nil +} + +func (r *Router) selectByAvailability(candidates []string) (adapter.ProviderAdapter, error) { + var bestProvider adapter.ProviderAdapter + var minFailureRate float64 = math.MaxFloat64 + + for _, name := range candidates { + health := r.health[name] + if health.FailureRate < minFailureRate { + minFailureRate = health.FailureRate + bestProvider = r.providers[name] + } + } + + if bestProvider == nil { + return nil, error.NewGatewayError(error.ROUTER_NO_PROVIDER_AVAILABLE, "no available provider") + } + + return bestProvider, nil +} + +// GetFallbackProviders 获取Fallback Providers +func (r *Router) GetFallbackProviders(ctx context.Context, model string) ([]adapter.ProviderAdapter, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + var fallbacks []adapter.ProviderAdapter + + for name, provider := range r.providers { + if name == "primary" { + continue // 跳过主Provider + } + if r.isProviderAvailable(name, model) { + fallbacks = append(fallbacks, provider) + } + } + + return fallbacks, nil +} + +// RecordResult 记录调用结果 +func (r *Router) RecordResult(ctx context.Context, providerName string, success bool, latencyMs int64) { + r.mu.Lock() + defer r.mu.Unlock() + + health, ok := r.health[providerName] + if !ok { + return + } + + // 更新延迟 + if latencyMs > 0 { + // 指数移动平均 + if health.LatencyMs == 0 { + health.LatencyMs = latencyMs + } else { + health.LatencyMs = (health.LatencyMs*7 + latencyMs) / 8 + } + } + + // 更新失败率 + if success { + if health.FailureRate > 0 { + health.FailureRate = health.FailureRate * 0.9 // 下降 + } + } else { + health.FailureRate = health.FailureRate*0.9 + 0.1 // 上升 + } + + // 检查是否应该标记为不可用 + if 
health.FailureRate > 0.5 { + health.Available = false + } + + health.LastCheckTime = time.Now() +} + +// UpdateHealth 更新健康状态 +func (r *Router) UpdateHealth(providerName string, available bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if health, ok := r.health[providerName]; ok { + health.Available = available + health.LastCheckTime = time.Now() + } +} + +// GetHealthStatus 获取健康状态 +func (r *Router) GetHealthStatus() map[string]*ProviderHealth { + r.mu.RLock() + defer r.mu.RUnlock() + + result := make(map[string]*ProviderHealth) + for name, health := range r.health { + result[name] = &ProviderHealth{ + Name: health.Name, + Available: health.Available, + LatencyMs: health.LatencyMs, + FailureRate: health.FailureRate, + Weight: health.Weight, + LastCheckTime: health.LastCheckTime, + } + } + return result +} diff --git a/gateway/pkg/error/error.go b/gateway/pkg/error/error.go new file mode 100644 index 0000000..b70ca59 --- /dev/null +++ b/gateway/pkg/error/error.go @@ -0,0 +1,246 @@ +package error + +import "fmt" + +// ErrorCode 错误码枚举 +type ErrorCode string + +const ( + // 认证授权 (AUTH_*) + AUTH_INVALID_TOKEN ErrorCode = "AUTH_001" + AUTH_INSUFFICIENT_PERMISSION ErrorCode = "AUTH_002" + AUTH_MFA_REQUIRED ErrorCode = "AUTH_003" + + // 计费 (BILLING_*) + BILLING_INSUFFICIENT_BALANCE ErrorCode = "BILLING_001" + BILLING_CHARGE_FAILED ErrorCode = "BILLING_002" + BILLING_REFUND_FAILED ErrorCode = "BILLING_003" + BILLING_DISCREPANCY ErrorCode = "BILLING_004" + + // 路由 (ROUTER_*) + ROUTER_NO_PROVIDER_AVAILABLE ErrorCode = "ROUTER_001" + ROUTER_ALL_PROVIDERS_FAILED ErrorCode = "ROUTER_002" + ROUTER_TIMEOUT ErrorCode = "ROUTER_003" + + // 供应商 (PROVIDER_*) + PROVIDER_INVALID_KEY ErrorCode = "PROVIDER_001" + PROVIDER_RATE_LIMIT ErrorCode = "PROVIDER_002" + PROVIDER_QUOTA_EXCEEDED ErrorCode = "PROVIDER_003" + PROVIDER_MODEL_NOT_FOUND ErrorCode = "PROVIDER_004" + PROVIDER_ERROR ErrorCode = "PROVIDER_005" + + // 限流 (RATE_LIMIT_*) + RATE_LIMIT_EXCEEDED ErrorCode = "RATE_LIMIT_001" + 
RATE_LIMIT_TOKEN_EXCEEDED ErrorCode = "RATE_LIMIT_002" + RATE_LIMIT_BURST_EXCEEDED ErrorCode = "RATE_LIMIT_003" + + // 通用 (COMMON_*) + COMMON_INVALID_REQUEST ErrorCode = "COMMON_001" + COMMON_RESOURCE_NOT_FOUND ErrorCode = "COMMON_002" + COMMON_INTERNAL_ERROR ErrorCode = "COMMON_003" + COMMON_SERVICE_UNAVAILABLE ErrorCode = "COMMON_004" +) + +// ErrorInfo 错误信息 +type ErrorInfo struct { + Code ErrorCode + Message string + HTTPStatus int + Retryable bool +} + +// GatewayError 网关错误 +type GatewayError struct { + Code ErrorCode + Message string + Details map[string]interface{} + RequestID string + Internal error +} + +func (e *GatewayError) Error() string { + if e.Internal != nil { + return fmt.Sprintf("%s: %s (caused by: %v)", e.Code, e.Message, e.Internal) + } + return fmt.Sprintf("%s: %s", e.Code, e.Message) +} + +func (e *GatewayError) Unwrap() error { + return e.Internal +} + +// ErrorDefinitions 错误码定义 +var ErrorDefinitions = map[ErrorCode]ErrorInfo{ + AUTH_INVALID_TOKEN: { + Code: AUTH_INVALID_TOKEN, + Message: "Invalid or expired token", + HTTPStatus: 401, + Retryable: false, + }, + AUTH_INSUFFICIENT_PERMISSION: { + Code: AUTH_INSUFFICIENT_PERMISSION, + Message: "Insufficient permissions", + HTTPStatus: 403, + Retryable: false, + }, + AUTH_MFA_REQUIRED: { + Code: AUTH_MFA_REQUIRED, + Message: "MFA verification required", + HTTPStatus: 403, + Retryable: false, + }, + BILLING_INSUFFICIENT_BALANCE: { + Code: BILLING_INSUFFICIENT_BALANCE, + Message: "Insufficient balance", + HTTPStatus: 402, + Retryable: false, + }, + BILLING_CHARGE_FAILED: { + Code: BILLING_CHARGE_FAILED, + Message: "Charge failed", + HTTPStatus: 500, + Retryable: true, + }, + BILLING_REFUND_FAILED: { + Code: BILLING_REFUND_FAILED, + Message: "Refund failed", + HTTPStatus: 500, + Retryable: true, + }, + BILLING_DISCREPANCY: { + Code: BILLING_DISCREPANCY, + Message: "Billing discrepancy detected", + HTTPStatus: 500, + Retryable: true, + }, + ROUTER_NO_PROVIDER_AVAILABLE: { + Code: 
ROUTER_NO_PROVIDER_AVAILABLE, + Message: "No provider available", + HTTPStatus: 503, + Retryable: true, + }, + ROUTER_ALL_PROVIDERS_FAILED: { + Code: ROUTER_ALL_PROVIDERS_FAILED, + Message: "All providers failed", + HTTPStatus: 503, + Retryable: true, + }, + ROUTER_TIMEOUT: { + Code: ROUTER_TIMEOUT, + Message: "Request timeout", + HTTPStatus: 504, + Retryable: true, + }, + PROVIDER_INVALID_KEY: { + Code: PROVIDER_INVALID_KEY, + Message: "Invalid API key", + HTTPStatus: 401, + Retryable: false, + }, + PROVIDER_RATE_LIMIT: { + Code: PROVIDER_RATE_LIMIT, + Message: "Rate limit exceeded", + HTTPStatus: 429, + Retryable: true, + }, + PROVIDER_QUOTA_EXCEEDED: { + Code: PROVIDER_QUOTA_EXCEEDED, + Message: "Quota exceeded", + HTTPStatus: 402, + Retryable: false, + }, + PROVIDER_MODEL_NOT_FOUND: { + Code: PROVIDER_MODEL_NOT_FOUND, + Message: "Model not found", + HTTPStatus: 404, + Retryable: false, + }, + PROVIDER_ERROR: { + Code: PROVIDER_ERROR, + Message: "Provider error", + HTTPStatus: 502, + Retryable: true, + }, + RATE_LIMIT_EXCEEDED: { + Code: RATE_LIMIT_EXCEEDED, + Message: "Rate limit exceeded", + HTTPStatus: 429, + Retryable: false, + }, + RATE_LIMIT_TOKEN_EXCEEDED: { + Code: RATE_LIMIT_TOKEN_EXCEEDED, + Message: "Token limit exceeded", + HTTPStatus: 429, + Retryable: false, + }, + RATE_LIMIT_BURST_EXCEEDED: { + Code: RATE_LIMIT_BURST_EXCEEDED, + Message: "Burst limit exceeded", + HTTPStatus: 429, + Retryable: false, + }, + COMMON_INVALID_REQUEST: { + Code: COMMON_INVALID_REQUEST, + Message: "Invalid request", + HTTPStatus: 400, + Retryable: false, + }, + COMMON_RESOURCE_NOT_FOUND: { + Code: COMMON_RESOURCE_NOT_FOUND, + Message: "Resource not found", + HTTPStatus: 404, + Retryable: false, + }, + COMMON_INTERNAL_ERROR: { + Code: COMMON_INTERNAL_ERROR, + Message: "Internal error", + HTTPStatus: 500, + Retryable: true, + }, + COMMON_SERVICE_UNAVAILABLE: { + Code: COMMON_SERVICE_UNAVAILABLE, + Message: "Service unavailable", + HTTPStatus: 503, + Retryable: true, + }, 
+} + +// NewGatewayError 创建网关错误 +func NewGatewayError(code ErrorCode, message string) *GatewayError { + return &GatewayError{ + Code: code, + Message: message, + Details: make(map[string]interface{}), + } +} + +// WithRequestID 设置请求ID +func (e *GatewayError) WithRequestID(requestID string) *GatewayError { + e.RequestID = requestID + return e +} + +// WithDetail 设置详情 +func (e *GatewayError) WithDetail(key string, value interface{}) *GatewayError { + e.Details[key] = value + return e +} + +// WithInternal 设置内部错误 +func (e *GatewayError) WithInternal(err error) *GatewayError { + e.Internal = err + return e +} + +// GetErrorInfo 获取错误信息 +func (e *GatewayError) GetErrorInfo() ErrorInfo { + if info, ok := ErrorDefinitions[e.Code]; ok { + return info + } + return ErrorInfo{ + Code: COMMON_INTERNAL_ERROR, + Message: e.Message, + HTTPStatus: 500, + Retryable: true, + } +} diff --git a/gateway/pkg/model/model.go b/gateway/pkg/model/model.go new file mode 100644 index 0000000..4bbcc32 --- /dev/null +++ b/gateway/pkg/model/model.go @@ -0,0 +1,144 @@ +package model + +import "time" + +// ChatCompletionRequest 聊天完成请求 +type ChatCompletionRequest struct { + Model string `json:"model" binding:"required"` + Messages []ChatMessage `json:"messages" binding:"required"` + Temperature float64 `json:"temperature,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + N int `json:"n,omitempty"` + PresencePenalty float64 `json:"presence_penalty,omitempty"` + FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` + User string `json:"user,omitempty"` +} + +// ChatMessage 聊天消息 +type ChatMessage struct { + Role string `json:"role" binding:"required"` + Content string `json:"content" binding:"required"` + Name string `json:"name,omitempty"` +} + +// ChatCompletionResponse 聊天完成响应 +type ChatCompletionResponse struct { + ID string `json:"id"` + Object string 
`json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []Choice `json:"choices"` + Usage Usage `json:"usage"` +} + +type Choice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type Usage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// CompletionRequest 完成请求 +type CompletionRequest struct { + Model string `json:"model" binding:"required"` + Prompt string `json:"prompt" binding:"required"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + TopP float64 `json:"top_p,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop []string `json:"stop,omitempty"` + N int `json:"n,omitempty"` +} + +// CompletionResponse 完成响应 +type CompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []Choice1 `json:"choices"` + Usage Usage `json:"usage"` +} + +type Choice1 struct { + Text string `json:"text"` + Index int `json:"index"` + FinishReason string `json:"finish_reason"` +} + +// StreamResponse 流式响应 +type StreamResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []Delta `json:"choices"` +} + +type Delta struct { + Delta struct { + Content string `json:"content,omitempty"` + Role string `json:"role,omitempty"` + } `json:"delta"` + Index int `json:"index"` + FinishReason string `json:"finish_reason,omitempty"` +} + +// ErrorResponse 错误响应 +type ErrorResponse struct { + Error ErrorDetail `json:"error"` +} + +type ErrorDetail struct { + Message string `json:"message"` + Type string `json:"type"` + Code string `json:"code,omitempty"` + Param string `json:"param,omitempty"` +} + +// HealthStatus 健康状态 +type HealthStatus struct { + Status 
string `json:"status"` + Timestamp time.Time `json:"timestamp"` + Services map[string]bool `json:"services"` +} + +// Tenant 租户 +type Tenant struct { + ID int64 `json:"id"` + Name string `json:"name"` + Plan string `json:"plan"` + CreatedAt time.Time `json:"created_at"` +} + +// Budget 预算 +type Budget struct { + TenantID int64 `json:"tenant_id"` + MonthlyLimit float64 `json:"monthly_limit"` + AlertThreshold float64 `json:"alert_threshold"` + CurrentUsage float64 `json:"current_usage"` +} + +// RouteRequest 路由请求 +type RouteRequest struct { + Model string + TenantID int64 + RouteType string // "primary", "fallback" +} + +// RouteResult 路由结果 +type RouteResult struct { + Provider string + Model string + LatencyMs int64 + Success bool + Error error +}