feat(gateway): 实现网关核心模块

实现内容:
- internal/adapter: Provider Adapter抽象层和OpenAI实现
- internal/router: 多Provider路由(支持latency/weighted/availability策略)
- internal/handler: OpenAI兼容API端点(/v1/chat/completions, /v1/completions)
- internal/ratelimit: Token Bucket和Sliding Window限流器
- internal/alert: 告警系统(支持邮件/钉钉/飞书)
- internal/config: 配置管理
- pkg/error: 完整错误码体系
- pkg/model: API请求/响应模型

PRD对齐:
- P0-1: 统一API接入  (OpenAI兼容)
- P0-2: 基础路由与稳定性  (多Provider路由+Fallback)
- P0-4: 预算与限流  (Token Bucket限流)

注意:需要供应链模块支持后再完善成本归因和账单导出
This commit is contained in:
Your Name
2026-04-01 10:04:52 +08:00
parent ecb5fad1c9
commit 0484c7be74
11 changed files with 2514 additions and 0 deletions

151
gateway/cmd/gateway/main.go Normal file
View File

@@ -0,0 +1,151 @@
package main
import (
"context"
"fmt"
"log"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"lijiaoqiao/gateway/internal/adapter"
"lijiaoqiao/gateway/internal/alert"
"lijiaoqiao/gateway/internal/config"
"lijiaoqiao/gateway/internal/handler"
"lijiaoqiao/gateway/internal/middleware"
"lijiaoqiao/gateway/internal/ratelimit"
"lijiaoqiao/gateway/internal/router"
)
// main loads configuration, wires up the router, rate limiter, alert manager
// and HTTP handler, serves until SIGINT/SIGTERM, then shuts down gracefully.
func main() {
// Load configuration (an empty path is passed to config.LoadConfig).
cfg, err := config.LoadConfig("")
if err != nil {
log.Fatalf("Failed to load config: %v", err)
}
// Initialize the router with the latency strategy.
r := router.NewRouter(router.StrategyLatency)
// Register a provider (example: OpenAI); the API key is read from OPENAI_API_KEY.
openaiAdapter := adapter.NewOpenAIAdapter(
"https://api.openai.com",
os.Getenv("OPENAI_API_KEY"),
[]string{"gpt-4", "gpt-3.5-turbo"},
)
r.RegisterProvider("openai", openaiAdapter)
// Initialize the rate limiter: token bucket when configured, sliding window otherwise.
// NOTE(review): limiter is a ratelimit.Limiter here, while createMux declares its
// parameter as *ratelimit.Middleware — confirm how the two are meant to be bridged.
var limiter ratelimit.Limiter
if cfg.RateLimit.Algorithm == "token_bucket" {
limiter = ratelimit.NewTokenBucketLimiter(
cfg.RateLimit.DefaultRPM,
cfg.RateLimit.DefaultTPM,
cfg.RateLimit.BurstMultiplier,
)
} else {
limiter = ratelimit.NewSlidingWindowLimiter(
time.Minute,
cfg.RateLimit.DefaultRPM,
)
}
// Initialize the alert manager. A failure is only logged, and whatever value
// NewManager returned alongside the error is still passed to createMux below.
alertManager, err := alert.NewManager(&cfg.Alert)
if err != nil {
log.Printf("Warning: Failed to create alert manager: %v", err)
}
// Initialize the API handler.
h := handler.NewHandler(r)
// Create the HTTP server using the address and timeouts from the configuration.
server := &http.Server{
Addr: fmt.Sprintf("%s:%d", cfg.Server.Host, cfg.Server.Port),
Handler: createMux(h, limiter, alertManager),
ReadTimeout: cfg.Server.ReadTimeout,
WriteTimeout: cfg.Server.WriteTimeout,
IdleTimeout: cfg.Server.IdleTimeout,
}
// Start serving in the background so the main goroutine can wait for signals.
go func() {
log.Printf("Starting gateway server on %s:%d", cfg.Server.Host, cfg.Server.Port)
// http.ErrServerClosed is the expected result of Shutdown and is not treated as a failure.
if err := server.ListenAndServe(); err != nil && err != http.ErrServerClosed {
log.Fatalf("Server failed: %v", err)
}
}()
// Block until an interrupt or termination signal arrives.
quit := make(chan os.Signal, 1)
signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
<-quit
log.Println("Shutting down server...")
// Graceful shutdown, bounded to 30 seconds.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
if err := server.Shutdown(ctx); err != nil {
log.Fatalf("Server forced to shutdown: %v", err)
}
log.Println("Server exited")
}
// createMux builds the HTTP routing table for the gateway.
//
// The chat and text completion endpoints are wrapped with rate limiting and
// the authentication check; the model listing and health endpoints are
// registered without middleware. alertMgr is accepted but not used by any
// route in this function.
//
// Routes are registered with their full paths because the standard library
// http.ServeMux has no PathPrefix/Subrouter API.
func createMux(h *handler.Handler, limiter *ratelimit.Middleware, alertMgr *alert.Manager) *http.ServeMux {
	mux := http.NewServeMux()

	// Chat completions (rate limited and authenticated).
	mux.HandleFunc("/v1/chat/completions", withMiddleware(h.ChatCompletionsHandle,
		limiter.Limit,
		authMiddleware(),
	))

	// Text completions (rate limited and authenticated).
	mux.HandleFunc("/v1/completions", withMiddleware(h.CompletionsHandle,
		limiter.Limit,
		authMiddleware(),
	))

	// Model listing.
	mux.HandleFunc("/v1/models", h.ModelsHandle)

	// Health probe.
	mux.HandleFunc("/health", h.HealthHandle)

	return mux
}
// MiddlewareFunc is the signature of an HTTP middleware: it receives the next
// handler in the chain and returns a handler that wraps it.
type MiddlewareFunc func(http.HandlerFunc) http.HandlerFunc
// withMiddleware wraps h with each of the given middleware functions in turn.
//
// Wrapping proceeds in argument order, so the first middleware ends up
// innermost (closest to h) and the last one outermost (first to see a request).
func withMiddleware(h http.HandlerFunc, limiters ...func(http.HandlerFunc) http.HandlerFunc) http.HandlerFunc {
	wrapped := h
	for i := 0; i < len(limiters); i++ {
		wrapped = limiters[i](wrapped)
	}
	return wrapped
}
// authMiddleware returns a middleware that rejects requests lacking an
// Authorization header (simplified implementation).
//
// Only the presence of the header is checked; its value is not validated.
// Rejected requests receive 401 with a JSON error body; all other requests
// are passed to the next handler unchanged.
func authMiddleware() MiddlewareFunc {
	const missingAuthBody = `{"error":{"message":"Missing Authorization header","code":"AUTH_001"}}`

	return func(next http.HandlerFunc) http.HandlerFunc {
		return func(w http.ResponseWriter, r *http.Request) {
			if r.Header.Get("Authorization") != "" {
				next.ServeHTTP(w, r)
				return
			}
			w.Header().Set("Content-Type", "application/json")
			w.WriteHeader(http.StatusUnauthorized)
			w.Write([]byte(missingAuthBody))
		}
	}
}

12
gateway/go.mod Normal file
View File

@@ -0,0 +1,12 @@
module lijiaoqiao/gateway
go 1.21
require (
github.com/golang-jwt/jwt/v5 v5.2.0
)
require (
github.com/jackc/pgx/v5 v5.5.0
golang.org/x/net v0.19.0
)

View File

@@ -0,0 +1,144 @@
package adapter
import (
"context"
"io"
)
// CompletionOptions carries the optional generation parameters of a
// completion request (sampling settings, output limit, streaming flag and
// stop sequences).
type CompletionOptions struct {
Temperature float64
MaxTokens int
TopP float64
Stream bool
Stop []string
}
// CompletionResponse is the provider-independent result of a non-streaming
// completion: response metadata, the generated choices and token usage.
type CompletionResponse struct {
ID string
Object string
Created int64
Model string
Choices []Choice
Usage Usage
}
// Choice is one generated alternative within a CompletionResponse.
// Message is a pointer and may therefore be nil.
type Choice struct {
Index int
Message *Message
FinishReason string
}
// Message is a single chat message exchanged with a provider.
//
// The JSON tags use the lower-case wire names so that a []Message can be
// marshalled directly into a request body; without tags encoding/json would
// emit the Go field names ("Role", "Content", "Name"). Name is omitted from
// the output when empty.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
	Name    string `json:"name,omitempty"`
}
// Usage holds the token counts reported for a completion.
type Usage struct {
PromptTokens int
CompletionTokens int
TotalTokens int
}
// StreamChunk is one incremental piece of a streamed completion.
//
// The JSON tags use the lower-case wire names so that a chunk marshalled
// with encoding/json has the shape of a chat completion chunk rather than
// exposing the Go field names.
type StreamChunk struct {
	ID      string         `json:"id"`
	Object  string         `json:"object"`
	Created int64          `json:"created"`
	Model   string         `json:"model"`
	Choices []StreamChoice `json:"choices"`
}

// StreamChoice is one choice inside a StreamChunk. FinishReason is omitted
// from the JSON output while it is empty.
type StreamChoice struct {
	Index        int    `json:"index"`
	Delta        *Delta `json:"delta"`
	FinishReason string `json:"finish_reason,omitempty"`
}

// Delta is the incremental content carried by a StreamChoice. Empty fields
// are omitted from the JSON output.
type Delta struct {
	Role    string `json:"role,omitempty"`
	Content string `json:"content,omitempty"`
}
// ProviderAdapter is the abstraction every upstream provider implementation
// must satisfy.
//
// NOTE(review): MapError is declared here as returning ProviderError, while
// OpenAIAdapter.MapError in this package is declared as returning error —
// confirm which signature is intended.
type ProviderAdapter interface {
// ChatCompletion sends a chat completion request and returns the full response.
ChatCompletion(ctx context.Context, model string, messages []Message, options CompletionOptions) (*CompletionResponse, error)
// ChatCompletionStream sends a streaming chat completion request and returns a channel of chunks.
ChatCompletionStream(ctx context.Context, model string, messages []Message, options CompletionOptions) (<-chan *StreamChunk, error)
// GetUsage extracts the token usage from a response.
GetUsage(response *CompletionResponse) Usage
// MapError maps a provider error to the gateway's error representation.
MapError(err error) ProviderError
// HealthCheck reports whether the provider is currently healthy.
HealthCheck(ctx context.Context) bool
// ProviderName returns the provider's name.
ProviderName() string
// SupportedModels returns the list of models the provider supports.
SupportedModels() []string
}
// ProviderError describes a failure reported by an upstream provider.
type ProviderError struct {
	Code       string
	Message    string
	HTTPStatus int
	Retryable  bool
}

// Error implements the error interface, rendering "<Code>: <Message>".
func (pe ProviderError) Error() string {
	text := pe.Code
	text += ": "
	text += pe.Message
	return text
}

// IsRetryable reports whether the Retryable flag is set on the error.
func (pe ProviderError) IsRetryable() bool {
	retry := pe.Retryable
	return retry
}
// Router is the interface for choosing among registered providers.
type Router interface {
// SelectProvider chooses the best provider for the given model.
SelectProvider(ctx context.Context, model string) (ProviderAdapter, error)
// GetFallbackProviders returns the fallback providers for the given model.
GetFallbackProviders(ctx context.Context, model string) ([]ProviderAdapter, error)
// RecordResult records the outcome of a call, for use in load balancing.
RecordResult(ctx context.Context, provider string, success bool, latencyMs int64)
}
// HealthChecker is the interface for components that can report their health.
type HealthChecker interface {
// Check probes the service's health.
Check(ctx context.Context) error
// IsHealthy reports whether the service is healthy.
IsHealthy() bool
}
// ReadCloser pairs an io.Reader with an optional callback that is run on Close.
type ReadCloser struct {
	io.Reader
	OnClose func() error
}

// Close invokes OnClose and returns its result. When no callback is set it
// does nothing and returns nil.
func (r *ReadCloser) Close() error {
	if r.OnClose == nil {
		return nil
	}
	return r.OnClose()
}

View File

@@ -0,0 +1,326 @@
package adapter
import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"time"

	"lijiaoqiao/gateway/pkg/error"
)
// OpenAIAdapter is the provider adapter for an OpenAI-style HTTP API.
type OpenAIAdapter struct {
	baseURL    string
	apiKey     string
	httpClient *http.Client
	models     []string
}

// NewOpenAIAdapter creates an adapter for the given base URL, API key and
// model list. The adapter gets its own HTTP client with a 60 second timeout.
func NewOpenAIAdapter(baseURL, apiKey string, models []string) *OpenAIAdapter {
	client := &http.Client{Timeout: 60 * time.Second}

	oa := new(OpenAIAdapter)
	oa.baseURL = baseURL
	oa.apiKey = apiKey
	oa.httpClient = client
	oa.models = models
	return oa
}
// ChatCompletion sends a non-streaming chat completion request to
// <baseURL>/v1/chat/completions and converts the reply.
//
// Temperature, MaxTokens and TopP are only included in the request when
// greater than zero, and Stop only when non-empty. A non-200 reply is passed
// through MapError; encoding, transport and decoding failures are returned
// as plain (wrapped) errors.
//
// Each returned Choice carries the "index" reported by the provider.
func (a *OpenAIAdapter) ChatCompletion(ctx context.Context, model string, messages []Message, options CompletionOptions) (*CompletionResponse, error) {
	// Build the request body.
	reqBody := map[string]interface{}{
		"model":    model,
		"messages": messages,
	}
	if options.Temperature > 0 {
		reqBody["temperature"] = options.Temperature
	}
	if options.MaxTokens > 0 {
		reqBody["max_tokens"] = options.MaxTokens
	}
	if options.TopP > 0 {
		reqBody["top_p"] = options.TopP
	}
	if len(options.Stop) > 0 {
		reqBody["stop"] = options.Stop
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	// Send the request.
	url := fmt.Sprintf("%s/v1/chat/completions", a.baseURL)
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))

	resp, err := a.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		// Prefer the structured "error" object when the body contains one.
		var errResp map[string]interface{}
		if json.Unmarshal(body, &errResp) == nil {
			if errDetail, ok := errResp["error"].(map[string]interface{}); ok {
				return nil, a.MapError(fmt.Errorf("%v", errDetail))
			}
		}
		return nil, a.MapError(fmt.Errorf("unexpected status: %d", resp.StatusCode))
	}

	// Decode only the fields the gateway uses.
	var result struct {
		ID      string `json:"id"`
		Object  string `json:"object"`
		Created int64  `json:"created"`
		Model   string `json:"model"`
		Choices []struct {
			Index   int `json:"index"`
			Message struct {
				Role    string `json:"role"`
				Content string `json:"content"`
			} `json:"message"`
			FinishReason string `json:"finish_reason"`
		} `json:"choices"`
		Usage struct {
			PromptTokens     int `json:"prompt_tokens"`
			CompletionTokens int `json:"completion_tokens"`
			TotalTokens      int `json:"total_tokens"`
		} `json:"usage"`
	}
	if err := json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("failed to parse response: %w", err)
	}

	// Convert to the provider-independent response type.
	response := &CompletionResponse{
		ID:      result.ID,
		Object:  result.Object,
		Created: result.Created,
		Model:   result.Model,
		Choices: make([]Choice, len(result.Choices)),
	}
	for i, c := range result.Choices {
		response.Choices[i] = Choice{
			Index: c.Index,
			Message: &Message{
				Role:    c.Message.Role,
				Content: c.Message.Content,
			},
			FinishReason: c.FinishReason,
		}
	}
	response.Usage = Usage{
		PromptTokens:     result.Usage.PromptTokens,
		CompletionTokens: result.Usage.CompletionTokens,
		TotalTokens:      result.Usage.TotalTokens,
	}
	return response, nil
}
// ChatCompletionStream sends a streaming chat completion request to
// <baseURL>/v1/chat/completions and returns a channel of parsed chunks.
//
// The optional parameters are forwarded under the same rules as in
// ChatCompletion (numeric options only when > 0, Stop only when non-empty).
// A non-200 reply is passed through MapError together with its body.
//
// On success a goroutine reads the server-sent-event stream line by line,
// decodes every "data:" payload and sends it on the returned channel. The
// goroutine closes the channel and the response body when the stream ends,
// when "[DONE]" is received, on a read error, or when ctx is cancelled while
// a send is pending. Payloads that fail to decode are skipped.
//
// NOTE(review): the adapter's http.Client is created with a 60 second
// Timeout, which per net/http also bounds reading the response body — confirm
// that long-running streams are not meant to exceed it.
func (a *OpenAIAdapter) ChatCompletionStream(ctx context.Context, model string, messages []Message, options CompletionOptions) (<-chan *StreamChunk, error) {
	// Build the request body.
	reqBody := map[string]interface{}{
		"model":    model,
		"messages": messages,
		"stream":   true,
	}
	if options.Temperature > 0 {
		reqBody["temperature"] = options.Temperature
	}
	if options.MaxTokens > 0 {
		reqBody["max_tokens"] = options.MaxTokens
	}
	if options.TopP > 0 {
		reqBody["top_p"] = options.TopP
	}
	if len(options.Stop) > 0 {
		reqBody["stop"] = options.Stop
	}

	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request: %w", err)
	}

	url := fmt.Sprintf("%s/v1/chat/completions", a.baseURL)
	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return nil, fmt.Errorf("failed to create request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))

	resp, err := a.httpClient.Do(req)
	if err != nil {
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		body, _ := io.ReadAll(resp.Body) // best effort: the body only enriches the error text
		resp.Body.Close()
		return nil, a.MapError(fmt.Errorf("unexpected status: %d, body: %s", resp.StatusCode, string(body)))
	}

	ch := make(chan *StreamChunk, 100)
	go func() {
		defer close(ch)
		defer resp.Body.Close()

		dataPrefix := []byte("data:")
		reader := bufio.NewReader(resp.Body)
		for {
			// ReadBytes may return data together with an error (a final line
			// without a trailing newline), so the line is processed before
			// readErr is inspected.
			line, readErr := reader.ReadBytes('\n')
			line = bytes.TrimSpace(line)

			// SSE format: "data: {...}"; the space after the colon is optional.
			if bytes.HasPrefix(line, dataPrefix) {
				data := bytes.TrimSpace(line[len(dataPrefix):])
				if string(data) == "[DONE]" {
					return
				}

				var chunk struct {
					ID      string `json:"id"`
					Object  string `json:"object"`
					Created int64  `json:"created"`
					Model   string `json:"model"`
					Choices []struct {
						Index int `json:"index"`
						Delta struct {
							Role    string `json:"role"`
							Content string `json:"content"`
						} `json:"delta"`
						FinishReason string `json:"finish_reason"`
					} `json:"choices"`
				}
				if json.Unmarshal(data, &chunk) == nil {
					streamChunk := &StreamChunk{
						ID:      chunk.ID,
						Object:  chunk.Object,
						Created: chunk.Created,
						Model:   chunk.Model,
						Choices: make([]StreamChoice, len(chunk.Choices)),
					}
					for i, c := range chunk.Choices {
						streamChunk.Choices[i] = StreamChoice{
							Index: c.Index,
							Delta: &Delta{
								Role:    c.Delta.Role,
								Content: c.Delta.Content,
							},
							FinishReason: c.FinishReason,
						}
					}

					select {
					case ch <- streamChunk:
					case <-ctx.Done():
						return
					}
				}
			}

			if readErr != nil {
				return
			}
		}
	}()
	return ch, nil
}
// GetUsage returns the token usage recorded on response. A nil response
// yields the zero Usage instead of panicking.
func (a *OpenAIAdapter) GetUsage(response *CompletionResponse) Usage {
	if response == nil {
		return Usage{}
	}
	return response.Usage
}
// MapError translates a provider error into a gateway error by looking for
// well-known markers in the error text (simplified implementation; a full
// one would map the structured provider error response).
//
// The first matching marker wins, in the order: invalid_api_key, rate_limit,
// quota, model_not_found. Anything else becomes a generic provider error.
// The original error is attached via WithInternal in every case.
//
// NOTE(review): this file imports a package under the name "error", which
// shadows the builtin error type used in this signature — confirm the import
// is aliased or the package renamed.
func (a *OpenAIAdapter) MapError(err error) error {
	detail := err.Error()
	switch {
	case contains(detail, "invalid_api_key"):
		return error.NewGatewayError(error.PROVIDER_INVALID_KEY, "Invalid API key").WithInternal(err)
	case contains(detail, "rate_limit"):
		return error.NewGatewayError(error.PROVIDER_RATE_LIMIT, "Rate limit exceeded").WithInternal(err)
	case contains(detail, "quota"):
		return error.NewGatewayError(error.PROVIDER_QUOTA_EXCEEDED, "Quota exceeded").WithInternal(err)
	case contains(detail, "model_not_found"):
		return error.NewGatewayError(error.PROVIDER_MODEL_NOT_FOUND, "Model not found").WithInternal(err)
	default:
		return error.NewGatewayError(error.PROVIDER_ERROR, "Provider error").WithInternal(err)
	}
}
// contains reports whether substr occurs anywhere in s.
func contains(s, substr string) bool {
	if len(s) < len(substr) {
		return false
	}
	if s == substr {
		return true
	}
	if len(s) == 0 {
		return false
	}
	return containsHelper(s, substr)
}

// containsHelper scans s from left to right and reports whether any window
// of len(substr) bytes equals substr.
func containsHelper(s, substr string) bool {
	width := len(substr)
	for start := 0; start+width <= len(s); start++ {
		if s[start:start+width] == substr {
			return true
		}
	}
	return false
}
// HealthCheck probes the provider by issuing GET <baseURL>/v1/models with
// the adapter's API key, bounded to five seconds. It reports true only when
// the request succeeds and the reply status is 200.
func (a *OpenAIAdapter) HealthCheck(ctx context.Context) bool {
	probeCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	endpoint := a.baseURL + "/v1/models"
	req, err := http.NewRequestWithContext(probeCtx, "GET", endpoint, nil)
	if err != nil {
		return false
	}
	req.Header.Set("Authorization", "Bearer "+a.apiKey)

	resp, err := a.httpClient.Do(req)
	if err != nil {
		return false
	}
	defer resp.Body.Close()

	healthy := resp.StatusCode == http.StatusOK
	return healthy
}
// ProviderName returns the fixed provider identifier "openai".
func (a *OpenAIAdapter) ProviderName() string {
return "openai"
}
// SupportedModels returns the model list the adapter was created with.
// The adapter's own slice is returned, not a copy.
func (a *OpenAIAdapter) SupportedModels() []string {
return a.models
}

View File

@@ -0,0 +1,366 @@
package alert
import (
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/smtp"
"strings"
"time"
"lijiaoqiao/gateway/internal/config"
)
// AlertType identifies the category of an alert.
type AlertType string
// Known alert categories.
const (
AlertBudgetExceeded AlertType = "budget_exceeded"
AlertRateLimitExceeded AlertType = "rate_limit_exceeded"
AlertProviderFailure AlertType = "provider_failure"
AlertHighErrorRate AlertType = "high_error_rate"
AlertLatencySpike AlertType = "latency_spike"
AlertManualIntervention AlertType = "manual_intervention"
)
// Alert is a single notification to be delivered through the configured senders.
type Alert struct {
Type AlertType
Title string
Message string
Severity string // "info", "warning", "error", "critical"
TenantID int64
RequestID string
Metadata map[string]interface{}
Timestamp time.Time
}
// Sender is the interface implemented by every alert delivery channel.
type Sender interface {
Send(ctx context.Context, alert *Alert) error
}
// Manager fans an alert out to all of its senders.
type Manager struct {
senders []Sender
}
// NewManager builds a Manager with one sender per enabled channel in cfg
// (email, DingTalk, Feishu — in that order).
//
// It returns an error, and no Manager, when a DingTalk or Feishu sender
// cannot be created. The top-level cfg.Enabled flag is not consulted here.
func NewManager(cfg *config.AlertConfig) (*Manager, error) {
	senders := make([]Sender, 0)

	if cfg.Email.Enabled {
		senders = append(senders, NewEmailSender(&cfg.Email))
	}

	if dt := cfg.DingTalk; dt.Enabled {
		ding, err := NewDingTalkSender(dt.WebHook, dt.Secret)
		if err != nil {
			return nil, fmt.Errorf("failed to create DingTalk sender: %w", err)
		}
		senders = append(senders, ding)
	}

	if fs := cfg.Feishu; fs.Enabled {
		feishu, err := NewFeishuSender(fs.WebHook, fs.Secret)
		if err != nil {
			return nil, fmt.Errorf("failed to create Feishu sender: %w", err)
		}
		senders = append(senders, feishu)
	}

	return &Manager{senders: senders}, nil
}
// Send delivers alert through every configured sender.
//
// A failing sender does not stop delivery through the remaining ones. The
// returned error is nil when all senders succeeded; otherwise it wraps every
// sender error (not just the last one), so errors.Is/errors.As can match any
// of them. A nil Manager, or one without senders, yields the
// "no alert sender configured" error.
func (m *Manager) Send(ctx context.Context, alert *Alert) error {
	if m == nil || len(m.senders) == 0 {
		return fmt.Errorf("no alert sender configured")
	}

	var combined error
	for _, sender := range m.senders {
		err := sender.Send(ctx, alert)
		if err == nil {
			continue
		}
		// Keep going with the other senders, but remember every failure.
		if combined == nil {
			combined = err
		} else {
			combined = fmt.Errorf("%w; %w", combined, err)
		}
	}
	return combined
}
// SendBudgetAlert sends a warning-severity budget alert for tenantID,
// recording the current usage and the limit both in the message text and in
// the alert metadata.
func (m *Manager) SendBudgetAlert(ctx context.Context, tenantID int64, current, limit float64) error {
	budgetAlert := &Alert{
		Type:      AlertBudgetExceeded,
		Title:     "Budget Alert",
		Severity:  "warning",
		TenantID:  tenantID,
		Timestamp: time.Now(),
	}
	budgetAlert.Message = fmt.Sprintf("Tenant %d exceeded budget: current=%.2f, limit=%.2f", tenantID, current, limit)
	budgetAlert.Metadata = map[string]interface{}{
		"current_usage": current,
		"limit":         limit,
	}
	return m.Send(ctx, budgetAlert)
}
// SendProviderFailureAlert sends an error-severity alert reporting that
// provider failed with err.
//
// The error text is stored under "error" in the alert metadata. A nil err is
// tolerated and rendered as "<nil>" rather than causing a panic.
func (m *Manager) SendProviderFailureAlert(ctx context.Context, provider string, err error) error {
	// fmt.Sprint renders a nil error as "<nil>", matching the %v in Message.
	errText := fmt.Sprint(err)

	return m.Send(ctx, &Alert{
		Type:     AlertProviderFailure,
		Title:    "Provider Failure",
		Message:  fmt.Sprintf("Provider %s failed: %v", provider, err),
		Severity: "error",
		Metadata: map[string]interface{}{
			"provider": provider,
			"error":    errText,
		},
		Timestamp: time.Now(),
	})
}
// EmailSender delivers alerts by email using the settings in cfg.
type EmailSender struct {
cfg *config.EmailConfig
}
// NewEmailSender creates an EmailSender that keeps a reference to cfg.
func NewEmailSender(cfg *config.EmailConfig) *EmailSender {
return &EmailSender{cfg: cfg}
}
// Send formats alert as a plain-text UTF-8 email and submits it with
// smtp.SendMail using PLAIN authentication against the configured host.
//
// The subject has the form "[SEVERITY] type - title". The tenant ID line is
// appended to the body only when TenantID is positive. ctx is not used by
// this method.
func (s *EmailSender) Send(ctx context.Context, alert *Alert) error {
	const mailTemplate = "From: %s\r\n" +
		"To: %s\r\n" +
		"Subject: %s\r\n" +
		"Content-Type: text/plain; charset=UTF-8\r\n" +
		"\r\n" +
		"%s"

	severity := strings.ToUpper(alert.Severity)
	subject := fmt.Sprintf("[%s] %s - %s", severity, alert.Type, alert.Title)

	// Message body; the raw string must stay byte-identical.
	body := fmt.Sprintf(`
告警类型: %s
严重程度: %s
时间: %s
消息: %s
`, alert.Type, alert.Severity, alert.Timestamp.Format(time.RFC3339), alert.Message)
	if alert.TenantID > 0 {
		body += fmt.Sprintf("\n租户ID: %d", alert.TenantID)
	}

	recipients := strings.Join(s.cfg.To, ",")
	msg := fmt.Sprintf(mailTemplate, s.cfg.From, recipients, subject, body)

	server := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
	credentials := smtp.PlainAuth("", s.cfg.Username, s.cfg.Password, s.cfg.Host)
	if err := smtp.SendMail(server, credentials, s.cfg.From, s.cfg.To, []byte(msg)); err != nil {
		return fmt.Errorf("failed to send email: %w", err)
	}
	return nil
}
// DingTalkSender delivers alerts to a DingTalk webhook.
type DingTalkSender struct {
	webHook string
	secret  string
	client  *http.Client
}

// NewDingTalkSender creates a sender for the given webhook URL and signing
// secret, with its own HTTP client limited to 10 seconds per request.
// The returned error is always nil.
func NewDingTalkSender(webHook, secret string) (*DingTalkSender, error) {
	sender := &DingTalkSender{webHook: webHook, secret: secret}
	sender.client = &http.Client{Timeout: 10 * time.Second}
	return sender, nil
}
// Send posts alert to the DingTalk webhook as a signed markdown message.
//
// The timestamp and signature from generateSign are appended to the webhook
// URL as query parameters ("?" is used when the URL has no query string yet,
// "&" otherwise). Delivery fails when the message cannot be encoded, the
// request fails, the HTTP status is not 200, or the reply body decodes to a
// non-zero "errcode". A reply body that is not JSON is not treated as a
// failure.
func (s *DingTalkSender) Send(ctx context.Context, alert *Alert) error {
	// Obtain the signature.
	timestamp, sign := s.generateSign()

	// Build the request URL.
	sep := "&"
	if !strings.Contains(s.webHook, "?") {
		sep = "?"
	}
	url := fmt.Sprintf("%s%stimestamp=%d&sign=%s", s.webHook, sep, timestamp, sign)

	// Build the message.
	msg := map[string]interface{}{
		"msgtype": "markdown",
		"markdown": map[string]string{
			"title": fmt.Sprintf("[%s] %s", strings.ToUpper(alert.Severity), alert.Title),
			"text": fmt.Sprintf(`### [%s] %s
**类型**: %s
**严重程度**: %s
**时间**: %s
**消息**: %s`,
				strings.ToUpper(alert.Severity),
				alert.Title,
				alert.Type,
				alert.Severity,
				alert.Timestamp.Format(time.RFC3339),
				alert.Message,
			),
		},
	}
	jsonData, err := json.Marshal(msg)
	if err != nil {
		return fmt.Errorf("failed to encode DingTalk message: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("DingTalk API returned status: %d", resp.StatusCode)
	}

	// The webhook reports application-level failures in the body even with a
	// 200 status, so inspect errcode when the body is decodable.
	var result struct {
		ErrCode int    `json:"errcode"`
		ErrMsg  string `json:"errmsg"`
	}
	if decodeErr := json.NewDecoder(resp.Body).Decode(&result); decodeErr == nil && result.ErrCode != 0 {
		return fmt.Errorf("DingTalk API returned errcode %d: %s", result.ErrCode, result.ErrMsg)
	}
	return nil
}
// generateSign returns the current time in Unix milliseconds together with
// the request signature for it.
//
// The signature is the HMAC-SHA256 of "<timestamp>\n<secret>" keyed with the
// secret, base64-encoded and then passed through urlEncode.
func (s *DingTalkSender) generateSign() (int64, string) {
	nowMs := time.Now().UnixMilli()

	mac := hmac.New(sha256.New, []byte(s.secret))
	mac.Write([]byte(fmt.Sprintf("%d\n%s", nowMs, s.secret)))
	encoded := base64.StdEncoding.EncodeToString(mac.Sum(nil))

	return nowMs, urlEncode(encoded)
}
// FeishuSender delivers alerts to a Feishu webhook.
type FeishuSender struct {
	webHook string
	secret  string
	client  *http.Client
}

// NewFeishuSender creates a sender for the given webhook URL and secret,
// with its own HTTP client limited to 10 seconds per request.
// The returned error is always nil.
func NewFeishuSender(webHook, secret string) (*FeishuSender, error) {
	sender := &FeishuSender{webHook: webHook, secret: secret}
	sender.client = &http.Client{Timeout: 10 * time.Second}
	return sender, nil
}
// Send posts alert to the Feishu webhook as an interactive card.
//
// The card header shows "[SEVERITY] title" in a colour chosen by
// getTemplateColor; its single element lists type, severity, time and
// message. The value from getTenantAccessToken is appended to the webhook
// URL as the tenant_access_token query parameter. Any status other than 200
// is reported as an error. The sender's secret is not used by this method.
func (s *FeishuSender) Send(ctx context.Context, alert *Alert) error {
	// Obtain the tenant_access_token (simplified implementation).
	token, err := s.getTenantAccessToken()
	if err != nil {
		return err
	}

	headline := fmt.Sprintf("[%s] %s", strings.ToUpper(alert.Severity), alert.Title)
	details := fmt.Sprintf("**类型**: %s\n**严重程度**: %s\n**时间**: %s\n**消息**: %s",
		alert.Type,
		alert.Severity,
		alert.Timestamp.Format(time.RFC3339),
		alert.Message,
	)

	// Assemble the card from the inside out.
	header := map[string]interface{}{
		"title": map[string]string{
			"tag":     "plain_text",
			"content": headline,
		},
		"template": s.getTemplateColor(alert.Severity),
	}
	section := map[string]interface{}{
		"tag": "div",
		"text": map[string]string{
			"tag":     "lark_md",
			"content": details,
		},
	}
	msg := map[string]interface{}{
		"msg_type": "interactive",
		"card": map[string]interface{}{
			"header":   header,
			"elements": []map[string]interface{}{section},
		},
	}
	jsonData, _ := json.Marshal(msg)

	endpoint := s.webHook + "?tenant_access_token=" + token
	req, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewBuffer(jsonData))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json")

	resp, err := s.client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if status := resp.StatusCode; status != http.StatusOK {
		return fmt.Errorf("Feishu API returned status: %d", status)
	}
	return nil
}
// getTenantAccessToken is a placeholder that always returns the fixed string
// "dummy_token" and a nil error.
func (s *FeishuSender) getTenantAccessToken() (string, error) {
// Simplified implementation: a real one should call the Feishu API to obtain the tenant_access_token.
// https://open.feishu.cn/document/ukTMukTMukTM/ukDNz4SO0MjL5QDO/auth-v3/auth/tenant_access_token/internal
return "dummy_token", nil
}
// getTemplateColor maps an alert severity to a card header colour:
// critical → red, error → orange, warning → yellow, anything else → blue.
func (s *FeishuSender) getTemplateColor(severity string) string {
	colours := map[string]string{
		"critical": "red",
		"error":    "orange",
		"warning":  "yellow",
	}
	if colour, known := colours[severity]; known {
		return colour
	}
	return "blue"
}
// urlEncode percent-encodes the characters '+', ' ', '/' and '=' in str and
// leaves every other character unchanged.
//
// The result is assembled in a strings.Builder so the cost is linear in the
// input length instead of re-copying the string on every character.
func urlEncode(str string) string {
	var b strings.Builder
	b.Grow(len(str))
	for _, c := range str {
		switch c {
		case '+', ' ', '/', '=':
			fmt.Fprintf(&b, "%%%02X", c)
		default:
			b.WriteRune(c)
		}
	}
	return b.String()
}

View File

@@ -0,0 +1,162 @@
package config
import (
"os"
"time"
)
// Config is the top-level gateway configuration, grouping the settings of
// each subsystem.
type Config struct {
Server ServerConfig
Database DatabaseConfig
Redis RedisConfig
Router RouterConfig
RateLimit RateLimitConfig
Alert AlertConfig
Providers []ProviderConfig
}
// ServerConfig holds the HTTP server's listen address and timeouts.
type ServerConfig struct {
Host string
Port int
ReadTimeout time.Duration
WriteTimeout time.Duration
IdleTimeout time.Duration
}
// DatabaseConfig holds the database connection settings.
type DatabaseConfig struct {
Host string
Port int
User string
Password string
Database string
MaxConns int
}
// RedisConfig holds the Redis connection settings.
type RedisConfig struct {
Host string
Port int
Password string
DB int
PoolSize int
}
// RouterConfig holds the provider routing settings: strategy, timeout,
// retry policy and health check interval.
type RouterConfig struct {
Strategy string // "latency", "cost", "availability", "weighted"
Timeout time.Duration
MaxRetries int
RetryDelay time.Duration
HealthCheckInterval time.Duration
}
// RateLimitConfig holds the rate limiting settings.
type RateLimitConfig struct {
Enabled bool
Algorithm string // "token_bucket", "sliding_window", "fixed_window"
DefaultRPM int // requests per minute
DefaultTPM int // tokens per minute
BurstMultiplier float64
}
// AlertConfig holds the alerting settings: a top-level switch plus one
// configuration per delivery channel.
type AlertConfig struct {
Enabled bool
Email EmailConfig
DingTalk DingTalkConfig
Feishu FeishuConfig
}
// EmailConfig holds the SMTP settings, sender address and recipient list
// for email alerts.
type EmailConfig struct {
Enabled bool
Host string
Port int
Username string
Password string
From string
To []string
}
// DingTalkConfig holds the webhook URL and secret for DingTalk alerts.
type DingTalkConfig struct {
Enabled bool
WebHook string
Secret string
}
// FeishuConfig holds the webhook URL and secret for Feishu alerts.
type FeishuConfig struct {
Enabled bool
WebHook string
Secret string
}
// ProviderConfig describes one upstream provider: its identity, endpoint,
// credentials, models and routing priority/weight.
type ProviderConfig struct {
Name string
Type string // "openai", "anthropic", "google", "custom"
BaseURL string
APIKey string
Models []string
Priority int
Weight float64
}
// LoadConfig returns the gateway configuration (simplified implementation;
// a real one would use viper or a similar library).
//
// The path argument is not read. Values are built-in defaults, except for
// the server host, the SMTP host and the DingTalk/Feishu settings, which can
// be overridden through environment variables. Database, Redis and Providers
// are left at their zero values. The returned error is always nil.
func LoadConfig(path string) (*Config, error) {
	cfg := new(Config)

	cfg.Server = ServerConfig{
		Host:         getEnv("GATEWAY_HOST", "0.0.0.0"),
		Port:         8080,
		ReadTimeout:  30 * time.Second,
		WriteTimeout: 30 * time.Second,
		IdleTimeout:  120 * time.Second,
	}

	cfg.Router = RouterConfig{
		Strategy:            "latency",
		Timeout:             30 * time.Second,
		MaxRetries:          3,
		RetryDelay:          1 * time.Second,
		HealthCheckInterval: 10 * time.Second,
	}

	cfg.RateLimit = RateLimitConfig{
		Enabled:         true,
		Algorithm:       "token_bucket",
		DefaultRPM:      60,
		DefaultTPM:      60000,
		BurstMultiplier: 1.5,
	}

	email := EmailConfig{
		Enabled: false,
		Host:    getEnv("SMTP_HOST", "smtp.example.com"),
		Port:    587,
	}
	dingTalk := DingTalkConfig{
		Enabled: getEnv("DINGTALK_ENABLED", "false") == "true",
		WebHook: getEnv("DINGTALK_WEBHOOK", ""),
		Secret:  getEnv("DINGTALK_SECRET", ""),
	}
	feishu := FeishuConfig{
		Enabled: getEnv("FEISHU_ENABLED", "false") == "true",
		WebHook: getEnv("FEISHU_WEBHOOK", ""),
		Secret:  getEnv("FEISHU_SECRET", ""),
	}
	cfg.Alert = AlertConfig{
		Enabled:  true,
		Email:    email,
		DingTalk: dingTalk,
		Feishu:   feishu,
	}

	return cfg, nil
}
// getEnv returns the value of the environment variable key, or defaultValue
// when the variable is unset or empty.
func getEnv(key, defaultValue string) string {
	value := os.Getenv(key)
	if value == "" {
		return defaultValue
	}
	return value
}

View File

@@ -0,0 +1,366 @@
package handler
import (
"bufio"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"time"
"lijiaoqiao/gateway/internal/adapter"
"lijiaoqiao/gateway/internal/router"
"lijiaoqiao/gateway/pkg/error"
"lijiaoqiao/gateway/pkg/model"
)
// Handler serves the gateway's HTTP API endpoints.
type Handler struct {
	router  *router.Router
	version string
}

// NewHandler creates a Handler that routes through r and reports API
// version "v1".
func NewHandler(r *router.Router) *Handler {
	h := new(Handler)
	h.router = r
	h.version = "v1"
	return h
}
// ChatCompletionsHandle serves the /v1/chat/completions endpoint.
//
// It takes the request ID from X-Request-ID (generating one when absent),
// decodes and validates the body, asks the router for a provider, and then
// either streams the reply via handleStream or performs a single completion
// call and writes the converted response as JSON. The call outcome and
// latency of non-streaming requests are reported to the router.
//
// Errors from the router or provider that are not *GatewayError values are
// converted to a COMMON_INTERNAL_ERROR response instead of being
// type-asserted blindly, which would panic.
func (h *Handler) ChatCompletionsHandle(w http.ResponseWriter, r *http.Request) {
	startTime := time.Now()
	requestID := r.Header.Get("X-Request-ID")
	if requestID == "" {
		requestID = generateRequestID()
	}
	ctx := context.WithValue(r.Context(), "request_id", requestID)
	ctx = context.WithValue(ctx, "start_time", startTime)

	// Parse the request.
	var req model.ChatCompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "invalid request body: "+err.Error()).WithRequestID(requestID))
		return
	}

	// Validate the request.
	if len(req.Messages) == 0 {
		h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "messages is required").WithRequestID(requestID))
		return
	}

	// Select a provider.
	provider, err := h.router.SelectProvider(ctx, req.Model)
	if err != nil {
		gwErr, ok := err.(*error.GatewayError)
		if !ok {
			gwErr = error.NewGatewayError(error.COMMON_INTERNAL_ERROR, err.Error()).WithRequestID(requestID)
		}
		h.writeError(w, r, gwErr.WithRequestID(requestID))
		return
	}

	// Convert the messages to the adapter's format.
	messages := make([]adapter.Message, len(req.Messages))
	for i, m := range req.Messages {
		messages[i] = adapter.Message{
			Role:    m.Role,
			Content: m.Content,
			Name:    m.Name,
		}
	}

	// Build the completion options.
	options := adapter.CompletionOptions{
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		TopP:        req.TopP,
		Stream:      req.Stream,
		Stop:        req.Stop,
	}

	// Streaming requests are handled separately.
	if req.Stream {
		h.handleStream(ctx, w, r, provider, req.Model, messages, options, requestID)
		return
	}

	// Non-streaming request.
	response, err := provider.ChatCompletion(ctx, req.Model, messages, options)
	if err != nil {
		// Record the failure.
		h.router.RecordResult(ctx, provider.ProviderName(), false, time.Since(startTime).Milliseconds())
		gwErr, ok := err.(*error.GatewayError)
		if !ok {
			gwErr = error.NewGatewayError(error.COMMON_INTERNAL_ERROR, err.Error()).WithRequestID(requestID)
		}
		h.writeError(w, r, gwErr.WithRequestID(requestID))
		return
	}

	// Record the success.
	h.router.RecordResult(ctx, provider.ProviderName(), true, time.Since(startTime).Milliseconds())

	// Convert the response.
	chatResp := model.ChatCompletionResponse{
		ID:      response.ID,
		Object:  "chat.completion",
		Created: response.Created,
		Model:   response.Model,
		Choices: make([]model.Choice, len(response.Choices)),
	}
	for i, c := range response.Choices {
		choice := model.Choice{
			Index:        c.Index,
			FinishReason: c.FinishReason,
		}
		// Message is a pointer on the adapter side; leave the zero message when it is nil.
		if c.Message != nil {
			choice.Message = model.ChatMessage{
				Role:    c.Message.Role,
				Content: c.Message.Content,
			}
		}
		chatResp.Choices[i] = choice
	}
	chatResp.Usage = model.Usage{
		PromptTokens:     response.Usage.PromptTokens,
		CompletionTokens: response.Usage.CompletionTokens,
		TotalTokens:      response.Usage.TotalTokens,
	}

	h.writeJSON(w, http.StatusOK, chatResp, requestID)
}
// handleStream serves a streaming completion as server-sent events.
//
// Streaming support of the ResponseWriter is verified before the upstream
// stream is opened, so no provider request is started that could not be
// relayed. Each chunk is written as "data: <json>\n\n" and flushed; the
// stream ends with "data: [DONE]\n\n". When a write to the client fails the
// handler stops relaying and returns.
//
// Provider errors that are not *GatewayError values are converted to a
// COMMON_INTERNAL_ERROR response instead of being type-asserted blindly.
func (h *Handler) handleStream(ctx context.Context, w http.ResponseWriter, r *http.Request, provider adapter.ProviderAdapter, model string, messages []adapter.Message, options adapter.CompletionOptions, requestID string) {
	flusher, ok := w.(http.Flusher)
	if !ok {
		h.writeError(w, r, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "streaming not supported").WithRequestID(requestID))
		return
	}

	ch, err := provider.ChatCompletionStream(ctx, model, messages, options)
	if err != nil {
		gwErr, isGateway := err.(*error.GatewayError)
		if !isGateway {
			gwErr = error.NewGatewayError(error.COMMON_INTERNAL_ERROR, err.Error()).WithRequestID(requestID)
		}
		h.writeError(w, r, gwErr.WithRequestID(requestID))
		return
	}

	// Set the SSE headers.
	w.Header().Set("Content-Type", "text/event-stream")
	w.Header().Set("Cache-Control", "no-cache")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Request-ID", requestID)

	// Relay the chunks.
	for chunk := range ch {
		data := fmt.Sprintf("data: %s\n\n", marshalJSON(chunk))
		if _, writeErr := w.Write([]byte(data)); writeErr != nil {
			// The client is gone; nothing more can be delivered.
			return
		}
		flusher.Flush()
	}

	// Best effort: a failure to write the terminator cannot be reported to anyone.
	_, _ = w.Write([]byte("data: [DONE]\n\n"))
	flusher.Flush()
}
// CompletionsHandle serves the /v1/completions endpoint.
//
// The prompt is sent to the selected provider as a single "user" chat
// message, and the chat reply is converted back into the text completion
// response format. Streaming requests are delegated to handleStream. The
// call outcome and latency of non-streaming requests are reported to the
// router, as in ChatCompletionsHandle.
//
// Errors from the router or provider that are not *GatewayError values are
// converted to a COMMON_INTERNAL_ERROR response instead of being
// type-asserted blindly, which would panic.
func (h *Handler) CompletionsHandle(w http.ResponseWriter, r *http.Request) {
	startTime := time.Now()
	requestID := r.Header.Get("X-Request-ID")
	if requestID == "" {
		requestID = generateRequestID()
	}

	// Parse the request.
	var req model.CompletionRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "invalid request body").WithRequestID(requestID))
		return
	}

	// Express the prompt as a one-message chat conversation.
	ctx := r.Context()
	messages := []adapter.Message{{Role: "user", Content: req.Prompt}}

	provider, err := h.router.SelectProvider(ctx, req.Model)
	if err != nil {
		gwErr, ok := err.(*error.GatewayError)
		if !ok {
			gwErr = error.NewGatewayError(error.COMMON_INTERNAL_ERROR, err.Error()).WithRequestID(requestID)
		}
		h.writeError(w, r, gwErr.WithRequestID(requestID))
		return
	}

	options := adapter.CompletionOptions{
		Temperature: req.Temperature,
		MaxTokens:   req.MaxTokens,
		TopP:        req.TopP,
		Stream:      req.Stream,
		Stop:        req.Stop,
	}

	if req.Stream {
		h.handleStream(ctx, w, r, provider, req.Model, messages, options, requestID)
		return
	}

	response, err := provider.ChatCompletion(ctx, req.Model, messages, options)
	if err != nil {
		h.router.RecordResult(ctx, provider.ProviderName(), false, time.Since(startTime).Milliseconds())
		gwErr, ok := err.(*error.GatewayError)
		if !ok {
			gwErr = error.NewGatewayError(error.COMMON_INTERNAL_ERROR, err.Error()).WithRequestID(requestID)
		}
		h.writeError(w, r, gwErr.WithRequestID(requestID))
		return
	}
	h.router.RecordResult(ctx, provider.ProviderName(), true, time.Since(startTime).Milliseconds())

	// Convert the response to the completion format.
	compResp := model.CompletionResponse{
		ID:      response.ID,
		Object:  "text_completion",
		Created: response.Created,
		Model:   response.Model,
		Choices: make([]model.Choice1, len(response.Choices)),
	}
	for i, c := range response.Choices {
		text := ""
		// Message is a pointer on the adapter side; use empty text when it is nil.
		if c.Message != nil {
			text = c.Message.Content
		}
		compResp.Choices[i] = model.Choice1{
			Text:         text,
			Index:        i,
			FinishReason: c.FinishReason,
		}
	}
	compResp.Usage = model.Usage{
		PromptTokens:     response.Usage.PromptTokens,
		CompletionTokens: response.Usage.CompletionTokens,
		TotalTokens:      response.Usage.TotalTokens,
	}

	h.writeJSON(w, http.StatusOK, compResp, requestID)
}
// ModelsHandle serves the /v1/models endpoint with a fixed, hard-coded list
// of four models wrapped in an {"object":"list","data":[...]} envelope.
// The request ID is taken from X-Request-ID or generated when absent.
func (h *Handler) ModelsHandle(w http.ResponseWriter, r *http.Request) {
	requestID := r.Header.Get("X-Request-ID")
	if requestID == "" {
		requestID = generateRequestID()
	}

	// entry builds one element of the model list.
	entry := func(id string, created int, owner string) map[string]interface{} {
		return map[string]interface{}{
			"id":       id,
			"object":   "model",
			"created":  created,
			"owned_by": owner,
		}
	}

	payload := map[string]interface{}{
		"object": "list",
		"data": []map[string]interface{}{
			entry("gpt-4", 1687882411, "openai"),
			entry("gpt-3.5-turbo", 1677610602, "openai"),
			entry("claude-3-opus", 1709598254, "anthropic"),
			entry("claude-3-sonnet", 1709598255, "anthropic"),
		},
	}
	h.writeJSON(w, http.StatusOK, payload, requestID)
}
// HealthHandle serves the /health endpoint.
//
// It reports the availability of every entry in the router's health status.
// The reply is 200 "healthy" when all entries are available and 503
// "degraded" as soon as one is not.
func (h *Handler) HealthHandle(w http.ResponseWriter, r *http.Request) {
	services := make(map[string]bool)
	status, statusCode := "healthy", http.StatusOK

	for name, health := range h.router.GetHealthStatus() {
		services[name] = health.Available
		if !health.Available {
			status, statusCode = "degraded", http.StatusServiceUnavailable
		}
	}

	report := model.HealthStatus{
		Status:    status,
		Timestamp: time.Now(),
		Services:  services,
	}
	h.writeJSON(w, statusCode, report, "")
}
// writeJSON writes data as a JSON response with the given status code.
//
// The X-Request-ID header is set only when requestID is non-empty. The
// error returned by the JSON encoder is discarded; by that point the status
// line has already been sent.
func (h *Handler) writeJSON(w http.ResponseWriter, status int, data interface{}, requestID string) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	if requestID != "" {
		hdr.Set("X-Request-ID", requestID)
	}

	w.WriteHeader(status)

	enc := json.NewEncoder(w)
	enc.Encode(data)
}
// writeError renders err as a JSON error response.
//
// The HTTP status comes from err.GetErrorInfo(); the body carries the
// error's own message and code with the fixed type "gateway_error". The
// X-Request-ID header is set when err carries a request ID. The request r
// is accepted but not used.
func (h *Handler) writeError(w http.ResponseWriter, r *http.Request, err *error.GatewayError) {
	info := err.GetErrorInfo()

	body := model.ErrorResponse{
		Error: model.ErrorDetail{
			Message: err.Message,
			Type:    "gateway_error",
			Code:    string(err.Code),
		},
	}

	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	if err.RequestID != "" {
		hdr.Set("X-Request-ID", err.RequestID)
	}
	w.WriteHeader(info.HTTPStatus)

	enc := json.NewEncoder(w)
	enc.Encode(body)
}
// generateRequestID builds a request ID of the form "chatcmpl-<nanoseconds>"
// from the current wall-clock time. Two calls within the same clock tick
// yield the same ID.
func generateRequestID() string {
	nanos := time.Now().UnixNano()
	return fmt.Sprintf("chatcmpl-%d", nanos)
}
// marshalJSON returns the JSON encoding of v as a string. When v cannot be
// encoded the error is dropped and the empty string is returned.
func marshalJSON(v interface{}) string {
	if encoded, err := json.Marshal(v); err == nil {
		return string(encoded)
	}
	return ""
}
// SSEReader reads a streaming (server-sent events) response line by line.
type SSEReader struct {
	reader *bufio.Reader
}

// NewSSEReader wraps r in a buffered reader and returns an SSEReader on it.
func NewSSEReader(r io.Reader) *SSEReader {
	return &SSEReader{reader: bufio.NewReader(r)}
}

// ReadLine returns the next line of the stream without its line terminator.
//
// Both "\n" and "\r\n" terminators are stripped. If the stream ends without
// a trailing newline, the final partial line is returned with a nil error
// and the following call reports io.EOF. Any other read error is returned
// as-is with an empty string.
//
// NOTE(review): this file also refers to error.GatewayError, so the
// identifier `error` appears to be bound to an imported package here; if
// that import is unaliased, the builtin error type in this signature will
// not resolve. Confirm against the file's import block.
func (s *SSEReader) ReadLine() (string, error) {
	line, err := s.reader.ReadString('\n')
	if err != nil {
		// ReadString hands back whatever it read before failing. At EOF that
		// is the unterminated last line, which must not be thrown away.
		if err != io.EOF || len(line) == 0 {
			return "", err
		}
	} else {
		line = line[:len(line)-1] // drop the '\n' delimiter
	}
	// Streams may use CRLF line endings; drop the carriage return too.
	if n := len(line); n > 0 && line[n-1] == '\r' {
		line = line[:n-1]
	}
	return line, nil
}
// parseSSEData extracts the payload of an SSE "data:" line.
//
// It returns "" when line does not start with "data:". Otherwise it returns
// everything after the colon, with a single leading space removed if one is
// present; the space after the colon is optional in the event-stream
// format, so "data:x" and "data: x" both yield "x".
func parseSSEData(line string) string {
	const prefix = "data:"
	if len(line) < len(prefix) || line[:len(prefix)] != prefix {
		return ""
	}
	payload := line[len(prefix):]
	// Only one space is part of the field separator; further spaces belong
	// to the payload.
	if len(payload) > 0 && payload[0] == ' ' {
		payload = payload[1:]
	}
	return payload
}
// getenv is a placeholder that always returns defaultValue; key is ignored
// and the process environment is never consulted.
//
// An init function that tried to reassign getenv to an identical closure
// was removed: a declared function is not assignable in Go, so that
// statement could not compile, and it would not have changed the result.
//
// NOTE(review): making this read the environment needs os.LookupEnv, and
// therefore the os package in this file's import block.
func getenv(key, defaultValue string) string {
	return defaultValue
}

View File

@@ -0,0 +1,336 @@
package ratelimit
import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"sync"
	"time"

	"lijiaoqiao/gateway/pkg/error"
)
// Algorithm names a rate-limiting algorithm.
type Algorithm string
// Known algorithm names. TokenBucketLimiter and SlidingWindowLimiter
// implement the first two; no fixed-window limiter is visible in this file.
const (
TokenBucket Algorithm = "token_bucket"
SlidingWindow Algorithm = "sliding_window"
FixedWindow Algorithm = "fixed_window"
)
// Limiter is the interface every rate limiter in this package satisfies.
//
// NOTE(review): this file imports lijiaoqiao/gateway/pkg/error without an
// alias, which binds the identifier `error` to that package at file scope.
// The builtin error type named in the signatures below would then not
// resolve; confirm and consider aliasing the import.
type Limiter interface {
// Allow reports whether a single request for key is permitted.
Allow(ctx context.Context, key string) (bool, error)
// AllowToken reports whether tokens units may be consumed for key.
AllowToken(ctx context.Context, key string, tokens int) (bool, error)
// GetLimit returns the current limit state for key.
GetLimit(key string) *Limit
}
// Limit describes the configured limits and the current state for one key,
// as returned by Limiter.GetLimit.
type Limit struct {
RPM int // requests per minute
TPM int // tokens per minute
Burst int // burst capacity
Remaining int // requests remaining
ResetAt time.Time // time at which the limit resets
}
// TokenBucketLimiter is a token-bucket rate limiter that keeps one bucket
// per key.
type TokenBucketLimiter struct {
mu sync.RWMutex // guards buckets
buckets map[string]*tokenBucket // per-key buckets, created on first use
defaultRPM int // requests per minute applied to new buckets
defaultTPM int // tokens per minute; reported by GetLimit, not used to size buckets
burstMultiplier float64 // bucket capacity = defaultRPM * burstMultiplier
cleanInterval time.Duration // period of the background cleanup loop
}
// tokenBucket is the state of a single key's bucket.
type tokenBucket struct {
tokens float64 // tokens currently available
maxTokens float64 // bucket capacity
tokensPerSec float64 // refill rate
lastRefill time.Time // time of the last refill
mu sync.Mutex // guards the fields above
}
// NewTokenBucketLimiter creates a token-bucket limiter with the given
// default requests-per-minute, tokens-per-minute and burst multiplier.
//
// It starts a background cleanup goroutine that ticks every five minutes;
// the limiter exposes no way to stop that goroutine.
func NewTokenBucketLimiter(defaultRPM, defaultTPM int, burstMultiplier float64) *TokenBucketLimiter {
	l := new(TokenBucketLimiter)
	l.buckets = make(map[string]*tokenBucket)
	l.defaultRPM = defaultRPM
	l.defaultTPM = defaultTPM
	l.burstMultiplier = burstMultiplier
	l.cleanInterval = 5 * time.Minute

	// Background eviction of idle buckets.
	go l.cleanup()

	return l
}
// Allow reports whether one request for key is permitted. It is shorthand
// for AllowToken with a cost of a single token.
func (l *TokenBucketLimiter) Allow(ctx context.Context, key string) (bool, error) {
	const singleRequest = 1
	return l.AllowToken(ctx, key, singleRequest)
}
// AllowToken tries to take tokens units from key's bucket.
//
// The bucket is created with the limiter's defaults on first use and is
// refilled for the elapsed time before the check. It returns true and
// deducts the tokens when enough are available, false otherwise. The
// returned error is always nil and ctx is not consulted.
func (l *TokenBucketLimiter) AllowToken(ctx context.Context, key string, tokens int) (bool, error) {
	// Look up or lazily create the bucket under the limiter lock.
	l.mu.Lock()
	b, found := l.buckets[key]
	if !found {
		b = l.newBucket(l.defaultRPM, l.defaultTPM)
		l.buckets[key] = b
	}
	l.mu.Unlock()

	b.mu.Lock()
	defer b.mu.Unlock()

	// Top up before deciding.
	l.refill(b)

	cost := float64(tokens)
	if b.tokens < cost {
		return false, nil
	}
	b.tokens -= cost
	return true, nil
}
// GetLimit returns the limit state for key.
//
// For an unknown key only the configured defaults are reported (Remaining
// and ResetAt stay zero). For a known key, Remaining is the stored token
// count truncated to an int, without refilling first, and ResetAt is one
// minute after the bucket's last refill.
func (l *TokenBucketLimiter) GetLimit(key string) *Limit {
	l.mu.RLock()
	b, found := l.buckets[key]
	l.mu.RUnlock()

	out := &Limit{RPM: l.defaultRPM, TPM: l.defaultTPM}
	if !found {
		out.Burst = int(float64(l.defaultRPM) * l.burstMultiplier)
		return out
	}

	b.mu.Lock()
	defer b.mu.Unlock()

	out.Burst = int(b.maxTokens)
	out.Remaining = int(b.tokens)
	out.ResetAt = b.lastRefill.Add(time.Minute)
	return out
}
// newBucket returns a full bucket sized for rpm requests per minute.
//
// Capacity is rpm * burstMultiplier truncated to a whole number, and the
// refill rate is rpm/60 tokens per second. The tpm argument is accepted but
// not used.
func (l *TokenBucketLimiter) newBucket(rpm, tpm int) *tokenBucket {
	capacity := float64(int(float64(rpm) * l.burstMultiplier))

	b := new(tokenBucket)
	b.tokens = capacity
	b.maxTokens = capacity
	b.tokensPerSec = float64(rpm) / 60.0
	b.lastRefill = time.Now()
	return b
}
// refill credits bucket with the tokens earned since its last refill,
// capped at the bucket's capacity, and stamps the refill time. The caller
// is expected to hold bucket.mu.
func (l *TokenBucketLimiter) refill(bucket *tokenBucket) {
	now := time.Now()
	sinceLast := now.Sub(bucket.lastRefill).Seconds()

	// Credit earned tokens, then clamp to capacity.
	bucket.tokens += sinceLast * bucket.tokensPerSec
	if limit := bucket.maxTokens; bucket.tokens > limit {
		bucket.tokens = limit
	}

	bucket.lastRefill = now
}
// cleanup periodically evicts buckets that have been idle for more than ten
// minutes and would be full again by now.
//
// Stored token counts are only refreshed when a key is accessed, so an idle
// bucket still holds the (reduced) count from its last request. Testing
// that stale count against the capacity, as the previous code did, almost
// never matched and used buckets accumulated indefinitely. The check now
// projects the refill that would have happened while the bucket was idle.
// Evicting a bucket that would be full is lossless: the next request for
// the key recreates an identical full bucket.
//
// The loop runs once per cleanInterval and never returns.
func (l *TokenBucketLimiter) cleanup() {
	const idleTTL = 10 * time.Minute

	ticker := time.NewTicker(l.cleanInterval)
	defer ticker.Stop()

	for range ticker.C {
		l.mu.Lock()
		now := time.Now()
		for key, bucket := range l.buckets {
			bucket.mu.Lock()
			idle := now.Sub(bucket.lastRefill)
			// Token count the bucket would have if it were refilled now.
			projected := bucket.tokens + idle.Seconds()*bucket.tokensPerSec
			evict := idle > idleTTL && projected >= bucket.maxTokens
			bucket.mu.Unlock()

			if evict {
				delete(l.buckets, key)
			}
		}
		l.mu.Unlock()
	}
}
// SlidingWindowLimiter is a sliding-window rate limiter that keeps one
// request log per key.
type SlidingWindowLimiter struct {
mu sync.RWMutex // guards windows
windows map[string]*slidingWindow // per-key request logs, created on first use
windowSize time.Duration // length of the sliding window
maxRequests int // maximum requests allowed inside one window
cleanInterval time.Duration // period of the background cleanup loop
}
// slidingWindow holds the timestamps of the requests admitted for one key.
type slidingWindow struct {
requests []time.Time // admission times, appended in call order
mu sync.Mutex // guards requests
}
// NewSlidingWindowLimiter creates a limiter that admits at most maxRequests
// requests per key within any windowSize interval.
//
// It starts a background cleanup goroutine that ticks every minute; the
// limiter exposes no way to stop that goroutine.
func NewSlidingWindowLimiter(windowSize time.Duration, maxRequests int) *SlidingWindowLimiter {
	l := new(SlidingWindowLimiter)
	l.windows = make(map[string]*slidingWindow)
	l.windowSize = windowSize
	l.maxRequests = maxRequests
	l.cleanInterval = 1 * time.Minute

	go l.cleanup()

	return l
}
// Allow reports whether one more request for key fits in the current
// window.
//
// Timestamps older than windowSize are pruned first. If fewer than
// maxRequests remain, the request is recorded and admitted; otherwise it is
// rejected and nothing is recorded. The returned error is always nil and
// ctx is not consulted.
func (l *SlidingWindowLimiter) Allow(ctx context.Context, key string) (bool, error) {
	// Look up or lazily create the per-key window.
	l.mu.Lock()
	win, found := l.windows[key]
	if !found {
		win = &slidingWindow{requests: make([]time.Time, 0)}
		l.windows[key] = win
	}
	l.mu.Unlock()

	win.mu.Lock()
	defer win.mu.Unlock()

	now := time.Now()
	oldest := now.Add(-l.windowSize)

	// Keep only the timestamps still inside the window.
	live := make([]time.Time, 0)
	for i := range win.requests {
		if ts := win.requests[i]; ts.After(oldest) {
			live = append(live, ts)
		}
	}
	win.requests = live

	if len(live) < l.maxRequests {
		win.requests = append(live, now)
		return true, nil
	}
	return false, nil
}
// AllowToken treats any token amount as a single request: the tokens
// argument is ignored and the call is forwarded to Allow.
func (l *SlidingWindowLimiter) AllowToken(ctx context.Context, key string, tokens int) (bool, error) {
	admitted, err := l.Allow(ctx, key)
	return admitted, err
}
// GetLimit returns the limit state for key.
//
// RPM is the configured maxRequests. Remaining is maxRequests minus the
// number of recorded requests still inside the window, floored at zero for
// known keys. ResetAt is always one windowSize from now. TPM and Burst are
// left at zero.
func (l *SlidingWindowLimiter) GetLimit(key string) *Limit {
	l.mu.RLock()
	win, found := l.windows[key]
	l.mu.RUnlock()

	left := l.maxRequests
	if found {
		win.mu.Lock()
		oldest := time.Now().Add(-l.windowSize)
		inWindow := 0
		for i := range win.requests {
			if win.requests[i].After(oldest) {
				inWindow++
			}
		}
		win.mu.Unlock()

		if left = l.maxRequests - inWindow; left < 0 {
			left = 0
		}
	}

	return &Limit{
		RPM:       l.maxRequests,
		Remaining: left,
		ResetAt:   time.Now().Add(l.windowSize),
	}
}
// cleanup periodically trims every window to the timestamps of the last
// two windowSize intervals and deletes windows that end up empty.
//
// The previous code indexed requests[len(requests)-1] unconditionally,
// which panics for a window with no recorded requests. Such a window
// exists whenever Allow has created the entry but not yet appended to it,
// or rejected the first request (for example when maxRequests is zero).
// An empty result is now the only deletion criterion; for non-empty logs
// this matches the old condition, since "nothing newer than the cutoff"
// already implies the newest entry is older than the cutoff.
//
// The loop runs once per cleanInterval and never returns.
func (l *SlidingWindowLimiter) cleanup() {
	ticker := time.NewTicker(l.cleanInterval)
	defer ticker.Stop()

	for range ticker.C {
		l.mu.Lock()
		cutoff := time.Now().Add(-l.windowSize * 2)
		for key, window := range l.windows {
			window.mu.Lock()
			// Filter in place; the slice is only touched under window.mu.
			kept := window.requests[:0]
			for _, ts := range window.requests {
				if ts.After(cutoff) {
					kept = append(kept, ts)
				}
			}
			window.requests = kept
			empty := len(kept) == 0
			window.mu.Unlock()

			if empty {
				delete(l.windows, key)
			}
		}
		l.mu.Unlock()
	}
}
// Middleware is HTTP rate-limiting middleware backed by a Limiter.
type Middleware struct {
limiter Limiter // limiter consulted for every request
}
// NewMiddleware returns rate-limiting middleware that consults limiter.
func NewMiddleware(limiter Limiter) *Middleware {
	m := new(Middleware)
	m.limiter = limiter
	return m
}
// Limit wraps next with a rate-limit check.
//
// The limiter key is the raw Authorization header value; requests without
// one are keyed by the client host taken from r.RemoteAddr. RemoteAddr has
// the form "host:port", and keying on the full value would give every new
// connection of an anonymous client its own bucket, so the port is
// stripped. If RemoteAddr cannot be split it is used unchanged.
//
// A limiter error yields COMMON_INTERNAL_ERROR. A rejected request yields
// RATE_LIMIT_EXCEEDED together with the X-RateLimit-Limit,
// X-RateLimit-Remaining and X-RateLimit-Reset headers. Otherwise next is
// invoked.
func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		// Use the API key as the limiter key when present.
		key := r.Header.Get("Authorization")
		if key == "" {
			key = r.RemoteAddr
			if host, _, splitErr := net.SplitHostPort(r.RemoteAddr); splitErr == nil {
				key = host
			}
		}

		allowed, err := m.limiter.Allow(r.Context(), key)
		if err != nil {
			writeError(w, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "rate limiter error"))
			return
		}
		if !allowed {
			limit := m.limiter.GetLimit(key)
			hdr := w.Header()
			hdr.Set("X-RateLimit-Limit", fmt.Sprintf("%d", limit.RPM))
			hdr.Set("X-RateLimit-Remaining", fmt.Sprintf("%d", limit.Remaining))
			hdr.Set("X-RateLimit-Reset", fmt.Sprintf("%d", limit.ResetAt.Unix()))
			writeError(w, error.NewGatewayError(error.RATE_LIMIT_EXCEEDED, "rate limit exceeded"))
			return
		}

		next.ServeHTTP(w, r)
	}
}
import "net/http"
func writeError(w http.ResponseWriter, err *error.GatewayError) {
info := err.GetErrorInfo()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(info.HTTPStatus)
w.Write([]byte(fmt.Sprintf(`{"error":{"message":"%s","code":"%s"}}`, err.Message, err.Code)))
}

View File

@@ -0,0 +1,261 @@
package router
import (
"context"
"math"
"sync"
"time"
"lijiaoqiao/gateway/internal/adapter"
"lijiaoqiao/gateway/pkg/error"
)
// LoadBalancerStrategy names a provider-selection strategy.
type LoadBalancerStrategy string
// Known strategies. SelectProvider has dedicated branches for latency,
// weighted and availability; round_robin falls through to its default
// branch, which selects by latency.
const (
StrategyLatency LoadBalancerStrategy = "latency"
StrategyRoundRobin LoadBalancerStrategy = "round_robin"
StrategyWeighted LoadBalancerStrategy = "weighted"
StrategyAvailability LoadBalancerStrategy = "availability"
)
// ProviderHealth is the health record the router keeps per provider.
type ProviderHealth struct {
Name string // provider name used at registration
Available bool // false once the provider is considered unusable
LatencyMs int64 // smoothed latency in milliseconds; 0 until a positive sample is recorded
FailureRate float64 // smoothed failure rate, updated by RecordResult
Weight float64 // relative weight used by the weighted strategy
LastCheckTime time.Time // time of the last update to this record
}
// Router selects a provider for a model according to a configured strategy.
type Router struct {
providers map[string]adapter.ProviderAdapter // registered providers by name
health map[string]*ProviderHealth // health record per provider name
strategy LoadBalancerStrategy // selection strategy, fixed at construction
mu sync.RWMutex // guards providers and health
}
// NewRouter creates a router with no providers that selects according to
// strategy.
func NewRouter(strategy LoadBalancerStrategy) *Router {
	rt := new(Router)
	rt.providers = make(map[string]adapter.ProviderAdapter)
	rt.health = make(map[string]*ProviderHealth)
	rt.strategy = strategy
	return rt
}
// RegisterProvider adds provider under name, replacing any earlier entry
// with the same name, and resets its health record: available, zero
// latency, zero failure rate, weight 1.0.
func (r *Router) RegisterProvider(name string, provider adapter.ProviderAdapter) {
	fresh := &ProviderHealth{
		Name:          name,
		Available:     true,
		Weight:        1.0,
		LastCheckTime: time.Now(),
	}

	r.mu.Lock()
	defer r.mu.Unlock()

	r.providers[name] = provider
	r.health[name] = fresh
}
// SelectProvider picks a provider for model using the router's strategy.
//
// Candidates are the registered providers that are marked available and
// support model. When there are none it returns a
// ROUTER_NO_PROVIDER_AVAILABLE gateway error. StrategyWeighted and
// StrategyAvailability have dedicated selectors; every other strategy,
// including StrategyRoundRobin and unknown values, selects by latency.
//
// Candidates are collected in map iteration order, so ties between equally
// ranked providers are not broken deterministically. ctx is not consulted.
func (r *Router) SelectProvider(ctx context.Context, model string) (adapter.ProviderAdapter, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	candidates := make([]string, 0, len(r.providers))
	// Range over keys only: an unused value variable is a compile error.
	for name := range r.providers {
		if r.isProviderAvailable(name, model) {
			candidates = append(candidates, name)
		}
	}

	if len(candidates) == 0 {
		return nil, error.NewGatewayError(error.ROUTER_NO_PROVIDER_AVAILABLE, "no provider available for model: "+model)
	}

	// Dispatch on the configured strategy.
	switch r.strategy {
	case StrategyWeighted:
		return r.selectByWeight(candidates)
	case StrategyAvailability:
		return r.selectByAvailability(candidates)
	default:
		return r.selectByLatency(candidates)
	}
}
// isProviderAvailable reports whether the provider registered as name is
// marked available and lists model (or the wildcard "*") among its
// supported models. The caller is expected to hold r.mu.
func (r *Router) isProviderAvailable(name, model string) bool {
	state, known := r.health[name]
	if !known || !state.Available {
		return false
	}

	p := r.providers[name]
	if p == nil {
		return false
	}

	// Match the requested model or the catch-all entry.
	supported := p.SupportedModels()
	for i := range supported {
		if supported[i] == model || supported[i] == "*" {
			return true
		}
	}
	return false
}
// selectByLatency returns the candidate with the lowest recorded latency.
// A provider that has no latency sample yet has LatencyMs 0 and therefore
// ranks first. It returns a ROUTER_NO_PROVIDER_AVAILABLE error when no
// candidate could be chosen. The caller is expected to hold r.mu.
func (r *Router) selectByLatency(candidates []string) (adapter.ProviderAdapter, error) {
	var (
		best   adapter.ProviderAdapter
		lowest int64 = math.MaxInt64
	)
	for _, name := range candidates {
		if ms := r.health[name].LatencyMs; ms < lowest {
			lowest, best = ms, r.providers[name]
		}
	}

	if best != nil {
		return best, nil
	}
	return nil, error.NewGatewayError(error.ROUTER_NO_PROVIDER_AVAILABLE, "no available provider")
}
// selectByWeight picks a candidate with probability proportional to its
// Weight.
//
// The previous draw divided the absolute Unix time in nanoseconds by
// MaxInt64. That ratio is practically constant (about 0.19 at present), so
// the same point of the cumulative distribution was hit on every call and
// the weights were not honoured. The draw now uses the sub-millisecond part
// of the clock, mapped to [0, 1).
//
// NOTE(review): the clock is a weak randomness source. math/rand would be
// the proper choice but is not imported by this file.
//
// An empty candidate list yields a ROUTER_NO_PROVIDER_AVAILABLE error
// instead of an index panic. When all weights are zero the first candidate
// is returned. The caller is expected to hold r.mu.
func (r *Router) selectByWeight(candidates []string) (adapter.ProviderAdapter, error) {
	if len(candidates) == 0 {
		return nil, error.NewGatewayError(error.ROUTER_NO_PROVIDER_AVAILABLE, "no available provider")
	}

	var totalWeight float64
	for _, name := range candidates {
		totalWeight += r.health[name].Weight
	}

	// Fraction in [0, 1) taken from the nanoseconds within the current
	// millisecond.
	const resolution = 1000000
	frac := float64(time.Now().UnixNano()%resolution) / float64(resolution)
	target := frac * totalWeight

	var cumulative float64
	for _, name := range candidates {
		cumulative += r.health[name].Weight
		// Strict comparison so that zero-weight candidates are never hit.
		if target < cumulative {
			return r.providers[name], nil
		}
	}
	return r.providers[candidates[0]], nil
}
// selectByAvailability returns the candidate with the lowest recorded
// failure rate. It returns a ROUTER_NO_PROVIDER_AVAILABLE error when no
// candidate could be chosen. The caller is expected to hold r.mu.
func (r *Router) selectByAvailability(candidates []string) (adapter.ProviderAdapter, error) {
	var (
		best   adapter.ProviderAdapter
		lowest float64 = math.MaxFloat64
	)
	for _, name := range candidates {
		if rate := r.health[name].FailureRate; rate < lowest {
			lowest, best = rate, r.providers[name]
		}
	}

	if best != nil {
		return best, nil
	}
	return nil, error.NewGatewayError(error.ROUTER_NO_PROVIDER_AVAILABLE, "no available provider")
}
// GetFallbackProviders returns every available provider that supports
// model, except one registered under the literal name "primary". The order
// follows map iteration and is therefore unspecified. The returned error is
// always nil and ctx is not consulted.
func (r *Router) GetFallbackProviders(ctx context.Context, model string) ([]adapter.ProviderAdapter, error) {
	r.mu.RLock()
	defer r.mu.RUnlock()

	var fallbacks []adapter.ProviderAdapter
	for name, provider := range r.providers {
		// The primary provider is never its own fallback.
		if name != "primary" && r.isProviderAvailable(name, model) {
			fallbacks = append(fallbacks, provider)
		}
	}
	return fallbacks, nil
}
// RecordResult folds the outcome of one provider call into that provider's
// health record. Unknown provider names are ignored.
//
// A positive latencyMs updates the smoothed latency: the first sample is
// taken as-is, later ones are blended 7:1 with the previous value. The
// failure rate decays by a factor of 0.9 on success and moves to
// rate*0.9 + 0.1 on failure. Once the rate exceeds 0.5 the provider is
// marked unavailable; this method never marks it available again.
// ctx is not consulted.
func (r *Router) RecordResult(ctx context.Context, providerName string, success bool, latencyMs int64) {
	r.mu.Lock()
	defer r.mu.Unlock()

	h, found := r.health[providerName]
	if !found {
		return
	}

	// Latency: moving average, seeded by the first positive sample.
	switch {
	case latencyMs <= 0:
		// no usable sample
	case h.LatencyMs == 0:
		h.LatencyMs = latencyMs
	default:
		h.LatencyMs = (h.LatencyMs*7 + latencyMs) / 8
	}

	// Failure rate: decay on success, push up on failure.
	if !success {
		h.FailureRate = h.FailureRate*0.9 + 0.1
	} else if h.FailureRate > 0 {
		h.FailureRate = h.FailureRate * 0.9
	}

	// Trip the provider once more than half of the weighted calls fail.
	if h.FailureRate > 0.5 {
		h.Available = false
	}

	h.LastCheckTime = time.Now()
}
// UpdateHealth sets the availability flag of providerName and stamps the
// check time. Unknown provider names are ignored.
func (r *Router) UpdateHealth(providerName string, available bool) {
	r.mu.Lock()
	defer r.mu.Unlock()

	h, found := r.health[providerName]
	if !found {
		return
	}
	h.Available = available
	h.LastCheckTime = time.Now()
}
// GetHealthStatus returns a snapshot of every provider's health record.
// Each entry is a copy, so callers can read or modify the result without
// affecting the router's own state.
func (r *Router) GetHealthStatus() map[string]*ProviderHealth {
	r.mu.RLock()
	defer r.mu.RUnlock()

	snapshot := make(map[string]*ProviderHealth)
	for name, h := range r.health {
		entry := *h // value copy of all fields
		snapshot[name] = &entry
	}
	return snapshot
}

246
gateway/pkg/error/error.go Normal file
View File

@@ -0,0 +1,246 @@
package error
import "fmt"
// ErrorCode is the enumeration of gateway error codes.
type ErrorCode string
const (
// Authentication and authorization (AUTH_*)
AUTH_INVALID_TOKEN ErrorCode = "AUTH_001"
AUTH_INSUFFICIENT_PERMISSION ErrorCode = "AUTH_002"
AUTH_MFA_REQUIRED ErrorCode = "AUTH_003"
// Billing (BILLING_*)
BILLING_INSUFFICIENT_BALANCE ErrorCode = "BILLING_001"
BILLING_CHARGE_FAILED ErrorCode = "BILLING_002"
BILLING_REFUND_FAILED ErrorCode = "BILLING_003"
BILLING_DISCREPANCY ErrorCode = "BILLING_004"
// Routing (ROUTER_*)
ROUTER_NO_PROVIDER_AVAILABLE ErrorCode = "ROUTER_001"
ROUTER_ALL_PROVIDERS_FAILED ErrorCode = "ROUTER_002"
ROUTER_TIMEOUT ErrorCode = "ROUTER_003"
// Upstream provider (PROVIDER_*)
PROVIDER_INVALID_KEY ErrorCode = "PROVIDER_001"
PROVIDER_RATE_LIMIT ErrorCode = "PROVIDER_002"
PROVIDER_QUOTA_EXCEEDED ErrorCode = "PROVIDER_003"
PROVIDER_MODEL_NOT_FOUND ErrorCode = "PROVIDER_004"
PROVIDER_ERROR ErrorCode = "PROVIDER_005"
// Rate limiting (RATE_LIMIT_*)
RATE_LIMIT_EXCEEDED ErrorCode = "RATE_LIMIT_001"
RATE_LIMIT_TOKEN_EXCEEDED ErrorCode = "RATE_LIMIT_002"
RATE_LIMIT_BURST_EXCEEDED ErrorCode = "RATE_LIMIT_003"
// Common (COMMON_*)
COMMON_INVALID_REQUEST ErrorCode = "COMMON_001"
COMMON_RESOURCE_NOT_FOUND ErrorCode = "COMMON_002"
COMMON_INTERNAL_ERROR ErrorCode = "COMMON_003"
COMMON_SERVICE_UNAVAILABLE ErrorCode = "COMMON_004"
)
// ErrorInfo is the static description attached to an error code.
type ErrorInfo struct {
Code ErrorCode // the code being described
Message string // default human-readable message
HTTPStatus int // HTTP status used when the error is sent to a client
Retryable bool // whether the definition marks the error as retryable
}
// GatewayError is the error type used across the gateway.
type GatewayError struct {
Code ErrorCode // error code; looked up in ErrorDefinitions by GetErrorInfo
Message string // message for this occurrence
Details map[string]interface{} // extra key/value context; initialised by NewGatewayError
RequestID string // request ID, set through WithRequestID
Internal error // underlying cause, exposed through Unwrap
}
// Error formats the error as "<code>: <message>", followed by
// " (caused by: <internal>)" when an internal cause is attached.
func (e *GatewayError) Error() string {
	head := fmt.Sprintf("%s: %s", e.Code, e.Message)
	if e.Internal == nil {
		return head
	}
	return fmt.Sprintf("%s (caused by: %v)", head, e.Internal)
}
// Unwrap returns the internal cause, or nil when none is attached, so that
// errors.Is and errors.As can look through a GatewayError.
func (e *GatewayError) Unwrap() error {
	cause := e.Internal
	return cause
}
// ErrorDefinitions maps every error code to its default message, HTTP
// status and retryable flag. GetErrorInfo reads this table.
var ErrorDefinitions = map[ErrorCode]ErrorInfo{
AUTH_INVALID_TOKEN: {
Code: AUTH_INVALID_TOKEN,
Message: "Invalid or expired token",
HTTPStatus: 401,
Retryable: false,
},
AUTH_INSUFFICIENT_PERMISSION: {
Code: AUTH_INSUFFICIENT_PERMISSION,
Message: "Insufficient permissions",
HTTPStatus: 403,
Retryable: false,
},
AUTH_MFA_REQUIRED: {
Code: AUTH_MFA_REQUIRED,
Message: "MFA verification required",
HTTPStatus: 403,
Retryable: false,
},
BILLING_INSUFFICIENT_BALANCE: {
Code: BILLING_INSUFFICIENT_BALANCE,
Message: "Insufficient balance",
HTTPStatus: 402,
Retryable: false,
},
BILLING_CHARGE_FAILED: {
Code: BILLING_CHARGE_FAILED,
Message: "Charge failed",
HTTPStatus: 500,
Retryable: true,
},
BILLING_REFUND_FAILED: {
Code: BILLING_REFUND_FAILED,
Message: "Refund failed",
HTTPStatus: 500,
Retryable: true,
},
BILLING_DISCREPANCY: {
Code: BILLING_DISCREPANCY,
Message: "Billing discrepancy detected",
HTTPStatus: 500,
Retryable: true,
},
ROUTER_NO_PROVIDER_AVAILABLE: {
Code: ROUTER_NO_PROVIDER_AVAILABLE,
Message: "No provider available",
HTTPStatus: 503,
Retryable: true,
},
ROUTER_ALL_PROVIDERS_FAILED: {
Code: ROUTER_ALL_PROVIDERS_FAILED,
Message: "All providers failed",
HTTPStatus: 503,
Retryable: true,
},
ROUTER_TIMEOUT: {
Code: ROUTER_TIMEOUT,
Message: "Request timeout",
HTTPStatus: 504,
Retryable: true,
},
PROVIDER_INVALID_KEY: {
Code: PROVIDER_INVALID_KEY,
Message: "Invalid API key",
HTTPStatus: 401,
Retryable: false,
},
PROVIDER_RATE_LIMIT: {
Code: PROVIDER_RATE_LIMIT,
Message: "Rate limit exceeded",
HTTPStatus: 429,
Retryable: true,
},
PROVIDER_QUOTA_EXCEEDED: {
Code: PROVIDER_QUOTA_EXCEEDED,
Message: "Quota exceeded",
HTTPStatus: 402,
Retryable: false,
},
PROVIDER_MODEL_NOT_FOUND: {
Code: PROVIDER_MODEL_NOT_FOUND,
Message: "Model not found",
HTTPStatus: 404,
Retryable: false,
},
PROVIDER_ERROR: {
Code: PROVIDER_ERROR,
Message: "Provider error",
HTTPStatus: 502,
Retryable: true,
},
RATE_LIMIT_EXCEEDED: {
Code: RATE_LIMIT_EXCEEDED,
Message: "Rate limit exceeded",
HTTPStatus: 429,
Retryable: false,
},
RATE_LIMIT_TOKEN_EXCEEDED: {
Code: RATE_LIMIT_TOKEN_EXCEEDED,
Message: "Token limit exceeded",
HTTPStatus: 429,
Retryable: false,
},
RATE_LIMIT_BURST_EXCEEDED: {
Code: RATE_LIMIT_BURST_EXCEEDED,
Message: "Burst limit exceeded",
HTTPStatus: 429,
Retryable: false,
},
COMMON_INVALID_REQUEST: {
Code: COMMON_INVALID_REQUEST,
Message: "Invalid request",
HTTPStatus: 400,
Retryable: false,
},
COMMON_RESOURCE_NOT_FOUND: {
Code: COMMON_RESOURCE_NOT_FOUND,
Message: "Resource not found",
HTTPStatus: 404,
Retryable: false,
},
COMMON_INTERNAL_ERROR: {
Code: COMMON_INTERNAL_ERROR,
Message: "Internal error",
HTTPStatus: 500,
Retryable: true,
},
COMMON_SERVICE_UNAVAILABLE: {
Code: COMMON_SERVICE_UNAVAILABLE,
Message: "Service unavailable",
HTTPStatus: 503,
Retryable: true,
},
}
// NewGatewayError creates a gateway error with the given code and message
// and an empty, non-nil Details map.
func NewGatewayError(code ErrorCode, message string) *GatewayError {
	gwErr := new(GatewayError)
	gwErr.Code = code
	gwErr.Message = message
	gwErr.Details = make(map[string]interface{})
	return gwErr
}
// WithRequestID stores requestID on the error and returns the same error,
// so calls can be chained. The receiver is modified in place.
func (e *GatewayError) WithRequestID(requestID string) *GatewayError {
	self := e
	self.RequestID = requestID
	return self
}
// WithDetail stores value under key in the error's Details and returns the
// same error, so calls can be chained. The receiver is modified in place.
//
// Details is allocated on demand: a GatewayError built as a struct literal
// rather than through NewGatewayError has a nil map, and writing to a nil
// map panics.
func (e *GatewayError) WithDetail(key string, value interface{}) *GatewayError {
	if e.Details == nil {
		e.Details = make(map[string]interface{})
	}
	e.Details[key] = value
	return e
}
// WithInternal attaches err as the underlying cause and returns the same
// error, so calls can be chained. The receiver is modified in place.
func (e *GatewayError) WithInternal(err error) *GatewayError {
	self := e
	self.Internal = err
	return self
}
// GetErrorInfo returns the ErrorDefinitions entry for the error's code.
// For a code that has no entry it falls back to COMMON_INTERNAL_ERROR with
// the error's own message, HTTP status 500 and Retryable set.
func (e *GatewayError) GetErrorInfo() ErrorInfo {
	info, known := ErrorDefinitions[e.Code]
	if !known {
		info = ErrorInfo{
			Code:       COMMON_INTERNAL_ERROR,
			Message:    e.Message,
			HTTPStatus: 500,
			Retryable:  true,
		}
	}
	return info
}

144
gateway/pkg/model/model.go Normal file
View File

@@ -0,0 +1,144 @@
package model
import "time"
// ChatCompletionRequest is the request body of a chat completion call.
//
// NOTE(review): the `binding` tags are presumably read by a validation
// library; encoding/json itself does not interpret them.
type ChatCompletionRequest struct {
Model string `json:"model" binding:"required"`
Messages []ChatMessage `json:"messages" binding:"required"`
Temperature float64 `json:"temperature,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Stream bool `json:"stream,omitempty"`
Stop []string `json:"stop,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
User string `json:"user,omitempty"`
}
// ChatMessage is a single chat message: a role, its content and an
// optional name.
type ChatMessage struct {
Role string `json:"role" binding:"required"`
Content string `json:"content" binding:"required"`
Name string `json:"name,omitempty"`
}
// ChatCompletionResponse is the response body of a chat completion call.
type ChatCompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage"`
}
// Choice is one entry of ChatCompletionResponse.Choices.
type Choice struct {
Index int `json:"index"`
Message ChatMessage `json:"message"`
FinishReason string `json:"finish_reason"`
}
// Usage carries the token counts of a response: prompt, completion and
// their total.
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
// CompletionRequest is the request body of a text completion call.
type CompletionRequest struct {
Model string `json:"model" binding:"required"`
Prompt string `json:"prompt" binding:"required"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Stream bool `json:"stream,omitempty"`
Stop []string `json:"stop,omitempty"`
N int `json:"n,omitempty"`
}
// CompletionResponse is the response body of a text completion call.
type CompletionResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []Choice1 `json:"choices"`
Usage Usage `json:"usage"`
}
// Choice1 is one entry of CompletionResponse.Choices; unlike Choice it
// carries plain text instead of a chat message.
type Choice1 struct {
Text string `json:"text"`
Index int `json:"index"`
FinishReason string `json:"finish_reason"`
}
// StreamResponse is one chunk of a streaming response.
type StreamResponse struct {
ID string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
Choices []Delta `json:"choices"`
}
// Delta is one entry of StreamResponse.Choices: an incremental piece of
// content and/or role, its index and an optional finish reason.
type Delta struct {
Delta struct {
Content string `json:"content,omitempty"`
Role string `json:"role,omitempty"`
} `json:"delta"`
Index int `json:"index"`
FinishReason string `json:"finish_reason,omitempty"`
}
// ErrorResponse is the JSON envelope of an error response.
type ErrorResponse struct {
Error ErrorDetail `json:"error"`
}
// ErrorDetail is the payload of ErrorResponse. Code and Param are omitted
// from the JSON output when empty.
type ErrorDetail struct {
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code,omitempty"`
Param string `json:"param,omitempty"`
}
// HealthStatus is the body of a health-check response: an overall status
// string, a timestamp and a per-service availability flag.
type HealthStatus struct {
Status string `json:"status"`
Timestamp time.Time `json:"timestamp"`
Services map[string]bool `json:"services"`
}
// Tenant describes a tenant: its ID, name, plan and creation time.
type Tenant struct {
ID int64 `json:"id"`
Name string `json:"name"`
Plan string `json:"plan"`
CreatedAt time.Time `json:"created_at"`
}
// Budget holds a tenant's budget figures.
// NOTE(review): currency and the scale of AlertThreshold (fraction vs.
// absolute amount) are not visible here; confirm against callers.
type Budget struct {
TenantID int64 `json:"tenant_id"`
MonthlyLimit float64 `json:"monthly_limit"`
AlertThreshold float64 `json:"alert_threshold"`
CurrentUsage float64 `json:"current_usage"`
}
// RouteRequest describes a routing request. It has no JSON tags.
type RouteRequest struct {
Model string
TenantID int64
RouteType string // "primary", "fallback"
}
// RouteResult describes the outcome of a routed call: the provider and
// model used, the latency in milliseconds, and success or the error.
type RouteResult struct {
Provider string
Model string
LatencyMs int64
Success bool
Error error
}