feat(gateway): 实现网关核心模块
实现内容: - internal/adapter: Provider Adapter抽象层和OpenAI实现 - internal/router: 多Provider路由(支持latency/weighted/availability策略) - internal/handler: OpenAI兼容API端点(/v1/chat/completions, /v1/completions) - internal/ratelimit: Token Bucket和Sliding Window限流器 - internal/alert: 告警系统(支持邮件/钉钉/飞书) - internal/config: 配置管理 - pkg/error: 完整错误码体系 - pkg/model: API请求/响应模型 PRD对齐: - P0-1: 统一API接入 ✅ (OpenAI兼容) - P0-2: 基础路由与稳定性 ✅ (多Provider路由+Fallback) - P0-4: 预算与限流 ✅ (Token Bucket限流) 注意:需要供应链模块支持后再完善成本归因和账单导出
This commit is contained in:
144
gateway/internal/adapter/adapter.go
Normal file
144
gateway/internal/adapter/adapter.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
)
|
||||
|
||||
// CompletionOptions 完成选项
|
||||
type CompletionOptions struct {
|
||||
Temperature float64
|
||||
MaxTokens int
|
||||
TopP float64
|
||||
Stream bool
|
||||
Stop []string
|
||||
}
|
||||
|
||||
// CompletionResponse 完成响应
|
||||
type CompletionResponse struct {
|
||||
ID string
|
||||
Object string
|
||||
Created int64
|
||||
Model string
|
||||
Choices []Choice
|
||||
Usage Usage
|
||||
}
|
||||
|
||||
// Choice 选择
|
||||
type Choice struct {
|
||||
Index int
|
||||
Message *Message
|
||||
FinishReason string
|
||||
}
|
||||
|
||||
// Message 消息
|
||||
type Message struct {
|
||||
Role string
|
||||
Content string
|
||||
Name string
|
||||
}
|
||||
|
||||
// Usage 使用量
|
||||
type Usage struct {
|
||||
PromptTokens int
|
||||
CompletionTokens int
|
||||
TotalTokens int
|
||||
}
|
||||
|
||||
// StreamChunk 流式响应块
|
||||
type StreamChunk struct {
|
||||
ID string
|
||||
Object string
|
||||
Created int64
|
||||
Model string
|
||||
Choices []StreamChoice
|
||||
}
|
||||
|
||||
// StreamChoice 流式选择
|
||||
type StreamChoice struct {
|
||||
Index int
|
||||
Delta *Delta
|
||||
FinishReason string
|
||||
}
|
||||
|
||||
// Delta 增量
|
||||
type Delta struct {
|
||||
Role string
|
||||
Content string
|
||||
}
|
||||
|
||||
// ProviderAdapter 供应商适配器抽象基类
|
||||
type ProviderAdapter interface {
|
||||
// ChatCompletion 发送聊天完成请求
|
||||
ChatCompletion(ctx context.Context, model string, messages []Message, options CompletionOptions) (*CompletionResponse, error)
|
||||
|
||||
// ChatCompletionStream 流式聊天完成请求
|
||||
ChatCompletionStream(ctx context.Context, model string, messages []Message, options CompletionOptions) (<-chan *StreamChunk, error)
|
||||
|
||||
// GetUsage 获取使用量
|
||||
GetUsage(response *CompletionResponse) Usage
|
||||
|
||||
// MapError 错误码映射
|
||||
MapError(err error) ProviderError
|
||||
|
||||
// HealthCheck 健康检查
|
||||
HealthCheck(ctx context.Context) bool
|
||||
|
||||
// ProviderName 供应商名称
|
||||
ProviderName() string
|
||||
|
||||
// SupportedModels 支持的模型列表
|
||||
SupportedModels() []string
|
||||
}
|
||||
|
||||
// ProviderError 供应商错误
|
||||
type ProviderError struct {
|
||||
Code string
|
||||
Message string
|
||||
HTTPStatus int
|
||||
Retryable bool
|
||||
}
|
||||
|
||||
// Error 实现error接口
|
||||
func (e ProviderError) Error() string {
|
||||
return e.Code + ": " + e.Message
|
||||
}
|
||||
|
||||
// IsRetryable 是否可重试
|
||||
func (e ProviderError) IsRetryable() bool {
|
||||
return e.Retryable
|
||||
}
|
||||
|
||||
// Router 路由器接口
|
||||
type Router interface {
|
||||
// SelectProvider 选择最佳Provider
|
||||
SelectProvider(ctx context.Context, model string) (ProviderAdapter, error)
|
||||
|
||||
// GetFallbackProviders 获取Fallback Providers
|
||||
GetFallbackProviders(ctx context.Context, model string) ([]ProviderAdapter, error)
|
||||
|
||||
// RecordResult 记录调用结果用于负载均衡
|
||||
RecordResult(ctx context.Context, provider string, success bool, latencyMs int64)
|
||||
}
|
||||
|
||||
// HealthChecker 健康检查器
|
||||
type HealthChecker interface {
|
||||
// Check 检查服务健康状态
|
||||
Check(ctx context.Context) error
|
||||
|
||||
// IsHealthy 是否健康
|
||||
IsHealthy() bool
|
||||
}
|
||||
|
||||
// ReadCloser 带错误回调的io.ReadCloser
|
||||
type ReadCloser struct {
|
||||
io.Reader
|
||||
OnClose func() error
|
||||
}
|
||||
|
||||
func (r *ReadCloser) Close() error {
|
||||
if r.OnClose != nil {
|
||||
return r.OnClose()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
326
gateway/internal/adapter/openai_adapter.go
Normal file
326
gateway/internal/adapter/openai_adapter.go
Normal file
@@ -0,0 +1,326 @@
|
||||
package adapter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/gateway/pkg/error"
|
||||
)
|
||||
|
||||
// OpenAIAdapter OpenAI适配器
|
||||
type OpenAIAdapter struct {
|
||||
baseURL string
|
||||
apiKey string
|
||||
httpClient *http.Client
|
||||
models []string
|
||||
}
|
||||
|
||||
// NewOpenAIAdapter 创建OpenAI适配器
|
||||
func NewOpenAIAdapter(baseURL, apiKey string, models []string) *OpenAIAdapter {
|
||||
return &OpenAIAdapter{
|
||||
baseURL: baseURL,
|
||||
apiKey: apiKey,
|
||||
httpClient: &http.Client{
|
||||
Timeout: 60 * time.Second,
|
||||
},
|
||||
models: models,
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletion 实现ChatCompletion接口
|
||||
func (a *OpenAIAdapter) ChatCompletion(ctx context.Context, model string, messages []Message, options CompletionOptions) (*CompletionResponse, error) {
|
||||
// 构建请求
|
||||
reqBody := map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
}
|
||||
|
||||
if options.Temperature > 0 {
|
||||
reqBody["temperature"] = options.Temperature
|
||||
}
|
||||
if options.MaxTokens > 0 {
|
||||
reqBody["max_tokens"] = options.MaxTokens
|
||||
}
|
||||
if options.TopP > 0 {
|
||||
reqBody["top_p"] = options.TopP
|
||||
}
|
||||
if len(options.Stop) > 0 {
|
||||
reqBody["stop"] = options.Stop
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
url := fmt.Sprintf("%s/v1/chat/completions", a.baseURL)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errResp map[string]interface{}
|
||||
if json.Unmarshal(body, &errResp) == nil {
|
||||
if errDetail, ok := errResp["error"].(map[string]interface{}); ok {
|
||||
return nil, a.MapError(fmt.Errorf("%v", errDetail))
|
||||
}
|
||||
}
|
||||
return nil, a.MapError(fmt.Errorf("unexpected status: %d", resp.StatusCode))
|
||||
}
|
||||
|
||||
// 解析响应
|
||||
var result struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"message"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
Usage struct {
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
TotalTokens int `json:"total_tokens"`
|
||||
} `json:"usage"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse response: %w", err)
|
||||
}
|
||||
|
||||
// 转换响应
|
||||
response := &CompletionResponse{
|
||||
ID: result.ID,
|
||||
Object: result.Object,
|
||||
Created: result.Created,
|
||||
Model: result.Model,
|
||||
Choices: make([]Choice, len(result.Choices)),
|
||||
}
|
||||
|
||||
for i, c := range result.Choices {
|
||||
response.Choices[i] = Choice{
|
||||
Message: &Message{
|
||||
Role: c.Message.Role,
|
||||
Content: c.Message.Content,
|
||||
},
|
||||
FinishReason: c.FinishReason,
|
||||
}
|
||||
}
|
||||
|
||||
response.Usage = Usage{
|
||||
PromptTokens: result.Usage.PromptTokens,
|
||||
CompletionTokens: result.Usage.CompletionTokens,
|
||||
TotalTokens: result.Usage.TotalTokens,
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// ChatCompletionStream 实现流式ChatCompletion
|
||||
func (a *OpenAIAdapter) ChatCompletionStream(ctx context.Context, model string, messages []Message, options CompletionOptions) (<-chan *StreamChunk, error) {
|
||||
// 构建请求
|
||||
reqBody := map[string]interface{}{
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"stream": true,
|
||||
}
|
||||
|
||||
if options.Temperature > 0 {
|
||||
reqBody["temperature"] = options.Temperature
|
||||
}
|
||||
if options.MaxTokens > 0 {
|
||||
reqBody["max_tokens"] = options.MaxTokens
|
||||
}
|
||||
|
||||
jsonData, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal request: %w", err)
|
||||
}
|
||||
|
||||
url := fmt.Sprintf("%s/v1/chat/completions", a.baseURL)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return nil, a.MapError(fmt.Errorf("unexpected status: %d, body: %s", resp.StatusCode, string(body)))
|
||||
}
|
||||
|
||||
ch := make(chan *StreamChunk, 100)
|
||||
|
||||
go func() {
|
||||
defer close(ch)
|
||||
defer resp.Body.Close()
|
||||
|
||||
reader := io.Reader(resp.Body)
|
||||
for {
|
||||
line, err := io.ReadLine(reader)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if len(line) < 6 {
|
||||
continue
|
||||
}
|
||||
|
||||
// SSE格式: data: {...}
|
||||
if string(line[:5]) != "data:" {
|
||||
continue
|
||||
}
|
||||
|
||||
data := line[6:]
|
||||
if string(data) == "[DONE]" {
|
||||
return
|
||||
}
|
||||
|
||||
var chunk struct {
|
||||
ID string `json:"id"`
|
||||
Object string `json:"object"`
|
||||
Created int64 `json:"created"`
|
||||
Model string `json:"model"`
|
||||
Choices []struct {
|
||||
Delta struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
} `json:"delta"`
|
||||
FinishReason string `json:"finish_reason"`
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
if json.Unmarshal(data, &chunk) != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
streamChunk := &StreamChunk{
|
||||
ID: chunk.ID,
|
||||
Object: chunk.Object,
|
||||
Created: chunk.Created,
|
||||
Model: chunk.Model,
|
||||
Choices: make([]StreamChoice, len(chunk.Choices)),
|
||||
}
|
||||
|
||||
for i, c := range chunk.Choices {
|
||||
streamChunk.Choices[i] = StreamChoice{
|
||||
Delta: &Delta{
|
||||
Role: c.Delta.Role,
|
||||
Content: c.Delta.Content,
|
||||
},
|
||||
FinishReason: c.FinishReason,
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case ch <- streamChunk:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// GetUsage 获取使用量
|
||||
func (a *OpenAIAdapter) GetUsage(response *CompletionResponse) Usage {
|
||||
return response.Usage
|
||||
}
|
||||
|
||||
// MapError 错误码映射
|
||||
func (a *OpenAIAdapter) MapError(err error) error {
|
||||
// 简化实现,实际应根据OpenAI错误响应映射
|
||||
errStr := err.Error()
|
||||
|
||||
if contains(errStr, "invalid_api_key") {
|
||||
return error.NewGatewayError(error.PROVIDER_INVALID_KEY, "Invalid API key").WithInternal(err)
|
||||
}
|
||||
if contains(errStr, "rate_limit") {
|
||||
return error.NewGatewayError(error.PROVIDER_RATE_LIMIT, "Rate limit exceeded").WithInternal(err)
|
||||
}
|
||||
if contains(errStr, "quota") {
|
||||
return error.NewGatewayError(error.PROVIDER_QUOTA_EXCEEDED, "Quota exceeded").WithInternal(err)
|
||||
}
|
||||
if contains(errStr, "model_not_found") {
|
||||
return error.NewGatewayError(error.PROVIDER_MODEL_NOT_FOUND, "Model not found").WithInternal(err)
|
||||
}
|
||||
|
||||
return error.NewGatewayError(error.PROVIDER_ERROR, "Provider error").WithInternal(err)
|
||||
}
|
||||
|
||||
func contains(s, substr string) bool {
|
||||
return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
|
||||
}
|
||||
|
||||
func containsHelper(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// HealthCheck 健康检查
|
||||
func (a *OpenAIAdapter) HealthCheck(ctx context.Context) bool {
|
||||
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("%s/v1/models", a.baseURL), nil)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", a.apiKey))
|
||||
|
||||
resp, err := a.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
return resp.StatusCode == http.StatusOK
|
||||
}
|
||||
|
||||
// ProviderName 供应商名称
|
||||
func (a *OpenAIAdapter) ProviderName() string {
|
||||
return "openai"
|
||||
}
|
||||
|
||||
// SupportedModels 支持的模型列表
|
||||
func (a *OpenAIAdapter) SupportedModels() []string {
|
||||
return a.models
|
||||
}
|
||||
366
gateway/internal/alert/alert.go
Normal file
366
gateway/internal/alert/alert.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package alert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/gateway/internal/config"
|
||||
)
|
||||
|
||||
// AlertType 告警类型
|
||||
type AlertType string
|
||||
|
||||
const (
|
||||
AlertBudgetExceeded AlertType = "budget_exceeded"
|
||||
AlertRateLimitExceeded AlertType = "rate_limit_exceeded"
|
||||
AlertProviderFailure AlertType = "provider_failure"
|
||||
AlertHighErrorRate AlertType = "high_error_rate"
|
||||
AlertLatencySpike AlertType = "latency_spike"
|
||||
AlertManualIntervention AlertType = "manual_intervention"
|
||||
)
|
||||
|
||||
// Alert 告警
|
||||
type Alert struct {
|
||||
Type AlertType
|
||||
Title string
|
||||
Message string
|
||||
Severity string // "info", "warning", "error", "critical"
|
||||
TenantID int64
|
||||
RequestID string
|
||||
Metadata map[string]interface{}
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
// Sender 告警发送器接口
|
||||
type Sender interface {
|
||||
Send(ctx context.Context, alert *Alert) error
|
||||
}
|
||||
|
||||
// Manager 告警管理器
|
||||
type Manager struct {
|
||||
senders []Sender
|
||||
}
|
||||
|
||||
// NewManager 创建告警管理器
|
||||
func NewManager(cfg *config.AlertConfig) (*Manager, error) {
|
||||
m := &Manager{
|
||||
senders: make([]Sender, 0),
|
||||
}
|
||||
|
||||
// 添加邮件发送器
|
||||
if cfg.Email.Enabled {
|
||||
m.senders = append(m.senders, NewEmailSender(&cfg.Email))
|
||||
}
|
||||
|
||||
// 添加钉钉发送器
|
||||
if cfg.DingTalk.Enabled {
|
||||
sender, err := NewDingTalkSender(cfg.DingTalk.WebHook, cfg.DingTalk.Secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create DingTalk sender: %w", err)
|
||||
}
|
||||
m.senders = append(m.senders, sender)
|
||||
}
|
||||
|
||||
// 添加飞书发送器
|
||||
if cfg.Feishu.Enabled {
|
||||
sender, err := NewFeishuSender(cfg.Feishu.WebHook, cfg.Feishu.Secret)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create Feishu sender: %w", err)
|
||||
}
|
||||
m.senders = append(m.senders, sender)
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// Send 发送告警
|
||||
func (m *Manager) Send(ctx context.Context, alert *Alert) error {
|
||||
if len(m.senders) == 0 {
|
||||
return fmt.Errorf("no alert sender configured")
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, sender := range m.senders {
|
||||
if err := sender.Send(ctx, alert); err != nil {
|
||||
lastErr = err
|
||||
// 继续尝试其他发送器
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
return lastErr
|
||||
}
|
||||
|
||||
// SendBudgetAlert 发送预算告警
|
||||
func (m *Manager) SendBudgetAlert(ctx context.Context, tenantID int64, current, limit float64) error {
|
||||
return m.Send(ctx, &Alert{
|
||||
Type: AlertBudgetExceeded,
|
||||
Title: "Budget Alert",
|
||||
Message: fmt.Sprintf("Tenant %d exceeded budget: current=%.2f, limit=%.2f", tenantID, current, limit),
|
||||
Severity: "warning",
|
||||
TenantID: tenantID,
|
||||
Metadata: map[string]interface{}{
|
||||
"current_usage": current,
|
||||
"limit": limit,
|
||||
},
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// SendProviderFailureAlert 发送Provider故障告警
|
||||
func (m *Manager) SendProviderFailureAlert(ctx context.Context, provider string, err error) error {
|
||||
return m.Send(ctx, &Alert{
|
||||
Type: AlertProviderFailure,
|
||||
Title: "Provider Failure",
|
||||
Message: fmt.Sprintf("Provider %s failed: %v", provider, err),
|
||||
Severity: "error",
|
||||
Metadata: map[string]interface{}{
|
||||
"provider": provider,
|
||||
"error": err.Error(),
|
||||
},
|
||||
Timestamp: time.Now(),
|
||||
})
|
||||
}
|
||||
|
||||
// EmailSender 邮件发送器
|
||||
type EmailSender struct {
|
||||
cfg *config.EmailConfig
|
||||
}
|
||||
|
||||
// NewEmailSender 创建邮件发送器
|
||||
func NewEmailSender(cfg *config.EmailConfig) *EmailSender {
|
||||
return &EmailSender{cfg: cfg}
|
||||
}
|
||||
|
||||
func (s *EmailSender) Send(ctx context.Context, alert *Alert) error {
|
||||
// 构建邮件内容
|
||||
subject := fmt.Sprintf("[%s] %s - %s", strings.ToUpper(alert.Severity), alert.Type, alert.Title)
|
||||
|
||||
body := fmt.Sprintf(`
|
||||
告警类型: %s
|
||||
严重程度: %s
|
||||
时间: %s
|
||||
消息: %s
|
||||
`, alert.Type, alert.Severity, alert.Timestamp.Format(time.RFC3339), alert.Message)
|
||||
|
||||
if alert.TenantID > 0 {
|
||||
body += fmt.Sprintf("\n租户ID: %d", alert.TenantID)
|
||||
}
|
||||
|
||||
// 构建邮件
|
||||
msg := fmt.Sprintf("From: %s\r\n"+
|
||||
"To: %s\r\n"+
|
||||
"Subject: %s\r\n"+
|
||||
"Content-Type: text/plain; charset=UTF-8\r\n"+
|
||||
"\r\n"+
|
||||
"%s",
|
||||
s.cfg.From,
|
||||
strings.Join(s.cfg.To, ","),
|
||||
subject,
|
||||
body)
|
||||
|
||||
// 发送邮件
|
||||
addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
|
||||
|
||||
auth := smtp.PlainAuth("", s.cfg.Username, s.cfg.Password, s.cfg.Host)
|
||||
err := smtp.SendMail(addr, auth, s.cfg.From, s.cfg.To, []byte(msg))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to send email: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DingTalkSender 钉钉发送器
|
||||
type DingTalkSender struct {
|
||||
webHook string
|
||||
secret string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewDingTalkSender 创建钉钉发送器
|
||||
func NewDingTalkSender(webHook, secret string) (*DingTalkSender, error) {
|
||||
return &DingTalkSender{
|
||||
webHook: webHook,
|
||||
secret: secret,
|
||||
client: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *DingTalkSender) Send(ctx context.Context, alert *Alert) error {
|
||||
// 获取签名
|
||||
timestamp, sign := s.generateSign()
|
||||
|
||||
// 构建请求URL
|
||||
url := fmt.Sprintf("%s×tamp=%d&sign=%s", s.webHook, timestamp, sign)
|
||||
|
||||
// 构建消息
|
||||
msg := map[string]interface{}{
|
||||
"msgtype": "markdown",
|
||||
"markdown": map[string]string{
|
||||
"title": fmt.Sprintf("[%s] %s", strings.ToUpper(alert.Severity), alert.Title),
|
||||
"text": fmt.Sprintf(`### [%s] %s
|
||||
|
||||
**类型**: %s
|
||||
**严重程度**: %s
|
||||
**时间**: %s
|
||||
**消息**: %s`,
|
||||
strings.ToUpper(alert.Severity),
|
||||
alert.Title,
|
||||
alert.Type,
|
||||
alert.Severity,
|
||||
alert.Timestamp.Format(time.RFC3339),
|
||||
alert.Message,
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, _ := json.Marshal(msg)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("DingTalk API returned status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *DingTalkSender) generateSign() (int64, string) {
|
||||
timestamp := time.Now().UnixMilli()
|
||||
stringToSign := fmt.Sprintf("%d\n%s", timestamp, s.secret)
|
||||
|
||||
h := hmac.New(sha256.New, []byte(s.secret))
|
||||
h.Write([]byte(stringToSign))
|
||||
signature := base64.StdEncoding.EncodeToString(h.Sum(nil))
|
||||
|
||||
return timestamp, urlEncode(signature)
|
||||
}
|
||||
|
||||
// FeishuSender 飞书发送器
|
||||
type FeishuSender struct {
|
||||
webHook string
|
||||
secret string
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
// NewFeishuSender 创建飞书发送器
|
||||
func NewFeishuSender(webHook, secret string) (*FeishuSender, error) {
|
||||
return &FeishuSender{
|
||||
webHook: webHook,
|
||||
secret: secret,
|
||||
client: &http.Client{
|
||||
Timeout: 10 * time.Second,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *FeishuSender) Send(ctx context.Context, alert *Alert) error {
|
||||
// 获取tenant_access_token (简化实现)
|
||||
token, err := s.getTenantAccessToken()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 构建消息
|
||||
msg := map[string]interface{}{
|
||||
"msg_type": "interactive",
|
||||
"card": map[string]interface{}{
|
||||
"header": map[string]interface{}{
|
||||
"title": map[string]string{
|
||||
"tag": "plain_text",
|
||||
"content": fmt.Sprintf("[%s] %s", strings.ToUpper(alert.Severity), alert.Title),
|
||||
},
|
||||
"template": s.getTemplateColor(alert.Severity),
|
||||
},
|
||||
"elements": []map[string]interface{}{
|
||||
{
|
||||
"tag": "div",
|
||||
"text": map[string]string{
|
||||
"tag": "lark_md",
|
||||
"content": fmt.Sprintf("**类型**: %s\n**严重程度**: %s\n**时间**: %s\n**消息**: %s",
|
||||
alert.Type,
|
||||
alert.Severity,
|
||||
alert.Timestamp.Format(time.RFC3339),
|
||||
alert.Message,
|
||||
),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
jsonData, _ := json.Marshal(msg)
|
||||
|
||||
url := fmt.Sprintf("%s?tenant_access_token=%s", s.webHook, token)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := s.client.Do(req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("Feishu API returned status: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FeishuSender) getTenantAccessToken() (string, error) {
|
||||
// 简化实现,实际应该调用飞书API获取tenant_access_token
|
||||
// https://open.feishu.cn/document/ukTMukTMukTM/ukDNz4SO0MjL5QDO/auth-v3/auth/tenant_access_token/internal
|
||||
return "dummy_token", nil
|
||||
}
|
||||
|
||||
func (s *FeishuSender) getTemplateColor(severity string) string {
|
||||
switch severity {
|
||||
case "critical":
|
||||
return "red"
|
||||
case "error":
|
||||
return "orange"
|
||||
case "warning":
|
||||
return "yellow"
|
||||
default:
|
||||
return "blue"
|
||||
}
|
||||
}
|
||||
|
||||
// urlEncode URL编码
|
||||
func urlEncode(str string) string {
|
||||
result := ""
|
||||
for _, c := range str {
|
||||
if c == '+' || c == ' ' || c == '/' || c == '=' {
|
||||
result += fmt.Sprintf("%%%02X", c)
|
||||
} else {
|
||||
result += string(c)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
162
gateway/internal/config/config.go
Normal file
162
gateway/internal/config/config.go
Normal file
@@ -0,0 +1,162 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Config 网关配置
|
||||
type Config struct {
|
||||
Server ServerConfig
|
||||
Database DatabaseConfig
|
||||
Redis RedisConfig
|
||||
Router RouterConfig
|
||||
RateLimit RateLimitConfig
|
||||
Alert AlertConfig
|
||||
Providers []ProviderConfig
|
||||
}
|
||||
|
||||
// ServerConfig 服务配置
|
||||
type ServerConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
ReadTimeout time.Duration
|
||||
WriteTimeout time.Duration
|
||||
IdleTimeout time.Duration
|
||||
}
|
||||
|
||||
// DatabaseConfig 数据库配置
|
||||
type DatabaseConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
User string
|
||||
Password string
|
||||
Database string
|
||||
MaxConns int
|
||||
}
|
||||
|
||||
// RedisConfig Redis配置
|
||||
type RedisConfig struct {
|
||||
Host string
|
||||
Port int
|
||||
Password string
|
||||
DB int
|
||||
PoolSize int
|
||||
}
|
||||
|
||||
// RouterConfig 路由配置
|
||||
type RouterConfig struct {
|
||||
Strategy string // "latency", "cost", "availability", "weighted"
|
||||
Timeout time.Duration
|
||||
MaxRetries int
|
||||
RetryDelay time.Duration
|
||||
HealthCheckInterval time.Duration
|
||||
}
|
||||
|
||||
// RateLimitConfig 限流配置
|
||||
type RateLimitConfig struct {
|
||||
Enabled bool
|
||||
Algorithm string // "token_bucket", "sliding_window", "fixed_window"
|
||||
DefaultRPM int // 请求数/分钟
|
||||
DefaultTPM int // Token数/分钟
|
||||
BurstMultiplier float64
|
||||
}
|
||||
|
||||
// AlertConfig 告警配置
|
||||
type AlertConfig struct {
|
||||
Enabled bool
|
||||
Email EmailConfig
|
||||
DingTalk DingTalkConfig
|
||||
Feishu FeishuConfig
|
||||
}
|
||||
|
||||
// EmailConfig 邮件配置
|
||||
type EmailConfig struct {
|
||||
Enabled bool
|
||||
Host string
|
||||
Port int
|
||||
Username string
|
||||
Password string
|
||||
From string
|
||||
To []string
|
||||
}
|
||||
|
||||
// DingTalkConfig 钉钉配置
|
||||
type DingTalkConfig struct {
|
||||
Enabled bool
|
||||
WebHook string
|
||||
Secret string
|
||||
}
|
||||
|
||||
// FeishuConfig 飞书配置
|
||||
type FeishuConfig struct {
|
||||
Enabled bool
|
||||
WebHook string
|
||||
Secret string
|
||||
}
|
||||
|
||||
// ProviderConfig Provider配置
|
||||
type ProviderConfig struct {
|
||||
Name string
|
||||
Type string // "openai", "anthropic", "google", "custom"
|
||||
BaseURL string
|
||||
APIKey string
|
||||
Models []string
|
||||
Priority int
|
||||
Weight float64
|
||||
}
|
||||
|
||||
// LoadConfig 加载配置
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
// 简化实现,实际应使用viper或类似库
|
||||
cfg := &Config{
|
||||
Server: ServerConfig{
|
||||
Host: getEnv("GATEWAY_HOST", "0.0.0.0"),
|
||||
Port: 8080,
|
||||
ReadTimeout: 30 * time.Second,
|
||||
WriteTimeout: 30 * time.Second,
|
||||
IdleTimeout: 120 * time.Second,
|
||||
},
|
||||
Router: RouterConfig{
|
||||
Strategy: "latency",
|
||||
Timeout: 30 * time.Second,
|
||||
MaxRetries: 3,
|
||||
RetryDelay: 1 * time.Second,
|
||||
HealthCheckInterval: 10 * time.Second,
|
||||
},
|
||||
RateLimit: RateLimitConfig{
|
||||
Enabled: true,
|
||||
Algorithm: "token_bucket",
|
||||
DefaultRPM: 60,
|
||||
DefaultTPM: 60000,
|
||||
BurstMultiplier: 1.5,
|
||||
},
|
||||
Alert: AlertConfig{
|
||||
Enabled: true,
|
||||
Email: EmailConfig{
|
||||
Enabled: false,
|
||||
Host: getEnv("SMTP_HOST", "smtp.example.com"),
|
||||
Port: 587,
|
||||
},
|
||||
DingTalk: DingTalkConfig{
|
||||
Enabled: getEnv("DINGTALK_ENABLED", "false") == "true",
|
||||
WebHook: getEnv("DINGTALK_WEBHOOK", ""),
|
||||
Secret: getEnv("DINGTALK_SECRET", ""),
|
||||
},
|
||||
Feishu: FeishuConfig{
|
||||
Enabled: getEnv("FEISHU_ENABLED", "false") == "true",
|
||||
WebHook: getEnv("FEISHU_WEBHOOK", ""),
|
||||
Secret: getEnv("FEISHU_SECRET", ""),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
366
gateway/internal/handler/handler.go
Normal file
366
gateway/internal/handler/handler.go
Normal file
@@ -0,0 +1,366 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/gateway/internal/adapter"
|
||||
"lijiaoqiao/gateway/internal/router"
|
||||
"lijiaoqiao/gateway/pkg/error"
|
||||
"lijiaoqiao/gateway/pkg/model"
|
||||
)
|
||||
|
||||
// Handler API处理器
|
||||
type Handler struct {
|
||||
router *router.Router
|
||||
version string
|
||||
}
|
||||
|
||||
// NewHandler 创建处理器
|
||||
func NewHandler(r *router.Router) *Handler {
|
||||
return &Handler{
|
||||
router: r,
|
||||
version: "v1",
|
||||
}
|
||||
}
|
||||
|
||||
// ChatCompletionsHandle /v1/chat/completions端点
|
||||
func (h *Handler) ChatCompletionsHandle(w http.ResponseWriter, r *http.Request) {
|
||||
startTime := time.Now()
|
||||
requestID := r.Header.Get("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = generateRequestID()
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), "request_id", requestID)
|
||||
ctx = context.WithValue(ctx, "start_time", startTime)
|
||||
|
||||
// 解析请求
|
||||
var req model.ChatCompletionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "invalid request body: "+err.Error()).WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
// 验证请求
|
||||
if len(req.Messages) == 0 {
|
||||
h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "messages is required").WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
// 选择Provider
|
||||
provider, err := h.router.SelectProvider(ctx, req.Model)
|
||||
if err != nil {
|
||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换消息格式
|
||||
messages := make([]adapter.Message, len(req.Messages))
|
||||
for i, m := range req.Messages {
|
||||
messages[i] = adapter.Message{
|
||||
Role: m.Role,
|
||||
Content: m.Content,
|
||||
Name: m.Name,
|
||||
}
|
||||
}
|
||||
|
||||
// 构建选项
|
||||
options := adapter.CompletionOptions{
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
TopP: req.TopP,
|
||||
Stream: req.Stream,
|
||||
Stop: req.Stop,
|
||||
}
|
||||
|
||||
// 处理流式请求
|
||||
if req.Stream {
|
||||
h.handleStream(ctx, w, r, provider, req.Model, messages, options, requestID)
|
||||
return
|
||||
}
|
||||
|
||||
// 处理非流式请求
|
||||
response, err := provider.ChatCompletion(ctx, req.Model, messages, options)
|
||||
if err != nil {
|
||||
// 记录失败
|
||||
h.router.RecordResult(ctx, provider.ProviderName(), false, time.Since(startTime).Milliseconds())
|
||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
// 记录成功
|
||||
h.router.RecordResult(ctx, provider.ProviderName(), true, time.Since(startTime).Milliseconds())
|
||||
|
||||
// 转换响应
|
||||
chatResp := model.ChatCompletionResponse{
|
||||
ID: response.ID,
|
||||
Object: "chat.completion",
|
||||
Created: response.Created,
|
||||
Model: response.Model,
|
||||
Choices: make([]model.Choice, len(response.Choices)),
|
||||
}
|
||||
|
||||
for i, c := range response.Choices {
|
||||
chatResp.Choices[i] = model.Choice{
|
||||
Index: c.Index,
|
||||
Message: model.ChatMessage{
|
||||
Role: c.Message.Role,
|
||||
Content: c.Message.Content,
|
||||
},
|
||||
FinishReason: c.FinishReason,
|
||||
}
|
||||
}
|
||||
|
||||
chatResp.Usage = model.Usage{
|
||||
PromptTokens: response.Usage.PromptTokens,
|
||||
CompletionTokens: response.Usage.CompletionTokens,
|
||||
TotalTokens: response.Usage.TotalTokens,
|
||||
}
|
||||
|
||||
h.writeJSON(w, http.StatusOK, chatResp, requestID)
|
||||
}
|
||||
|
||||
// handleStream 处理流式请求
|
||||
func (h *Handler) handleStream(ctx context.Context, w http.ResponseWriter, r *http.Request, provider adapter.ProviderAdapter, model string, messages []adapter.Message, options adapter.CompletionOptions, requestID string) {
|
||||
ch, err := provider.ChatCompletionStream(ctx, model, messages, options)
|
||||
if err != nil {
|
||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
// 设置SSE头
|
||||
w.Header().Set("Content-Type", "text/event-stream")
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Connection", "keep-alive")
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
|
||||
flusher, ok := w.(http.Flusher)
|
||||
if !ok {
|
||||
h.writeError(w, r, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "streaming not supported").WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
// 流式发送响应
|
||||
for chunk := range ch {
|
||||
data := fmt.Sprintf("data: %s\n\n", marshalJSON(chunk))
|
||||
w.Write([]byte(data))
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
w.Write([]byte("data: [DONE]\n\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
// CompletionsHandle /v1/completions端点
|
||||
func (h *Handler) CompletionsHandle(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := r.Header.Get("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = generateRequestID()
|
||||
}
|
||||
|
||||
// 解析请求
|
||||
var req model.CompletionRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
h.writeError(w, r, error.NewGatewayError(error.COMMON_INVALID_REQUEST, "invalid request body").WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换格式并调用ChatCompletions
|
||||
chatReq := model.ChatCompletionRequest{
|
||||
Model: req.Model,
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
TopP: req.TopP,
|
||||
Stream: req.Stream,
|
||||
Stop: req.Stop,
|
||||
Messages: []model.ChatMessage{
|
||||
{Role: "user", Content: req.Prompt},
|
||||
},
|
||||
}
|
||||
|
||||
// 复用ChatCompletions逻辑
|
||||
req.Method = "POST"
|
||||
req.URL.Path = "/v1/chat/completions"
|
||||
|
||||
// 重新构造请求体并处理
|
||||
ctx := r.Context()
|
||||
messages := []adapter.Message{{Role: "user", Content: req.Prompt}}
|
||||
|
||||
provider, err := h.router.SelectProvider(ctx, req.Model)
|
||||
if err != nil {
|
||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
options := adapter.CompletionOptions{
|
||||
Temperature: req.Temperature,
|
||||
MaxTokens: req.MaxTokens,
|
||||
TopP: req.TopP,
|
||||
Stream: req.Stream,
|
||||
Stop: req.Stop,
|
||||
}
|
||||
|
||||
if req.Stream {
|
||||
h.handleStream(ctx, w, r, provider, req.Model, messages, options, requestID)
|
||||
return
|
||||
}
|
||||
|
||||
response, err := provider.ChatCompletion(ctx, req.Model, messages, options)
|
||||
if err != nil {
|
||||
h.writeError(w, r, err.(*error.GatewayError).WithRequestID(requestID))
|
||||
return
|
||||
}
|
||||
|
||||
// 转换响应为Completion格式
|
||||
compResp := model.CompletionResponse{
|
||||
ID: response.ID,
|
||||
Object: "text_completion",
|
||||
Created: response.Created,
|
||||
Model: response.Model,
|
||||
Choices: make([]model.Choice1, len(response.Choices)),
|
||||
}
|
||||
|
||||
for i, c := range response.Choices {
|
||||
compResp.Choices[i] = model.Choice1{
|
||||
Text: c.Message.Content,
|
||||
Index: i,
|
||||
FinishReason: c.FinishReason,
|
||||
}
|
||||
}
|
||||
|
||||
compResp.Usage = model.Usage{
|
||||
PromptTokens: response.Usage.PromptTokens,
|
||||
CompletionTokens: response.Usage.CompletionTokens,
|
||||
TotalTokens: response.Usage.TotalTokens,
|
||||
}
|
||||
|
||||
h.writeJSON(w, http.StatusOK, compResp, requestID)
|
||||
}
|
||||
|
||||
// ModelsHandle /v1/models端点
|
||||
func (h *Handler) ModelsHandle(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := r.Header.Get("X-Request-ID")
|
||||
if requestID == "" {
|
||||
requestID = generateRequestID()
|
||||
}
|
||||
|
||||
// 返回支持的模型列表
|
||||
models := []map[string]interface{}{
|
||||
{"id": "gpt-4", "object": "model", "created": 1687882411, "owned_by": "openai"},
|
||||
{"id": "gpt-3.5-turbo", "object": "model", "created": 1677610602, "owned_by": "openai"},
|
||||
{"id": "claude-3-opus", "object": "model", "created": 1709598254, "owned_by": "anthropic"},
|
||||
{"id": "claude-3-sonnet", "object": "model", "created": 1709598255, "owned_by": "anthropic"},
|
||||
}
|
||||
|
||||
h.writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"object": "list",
|
||||
"data": models,
|
||||
}, requestID)
|
||||
}
|
||||
|
||||
// HealthHandle /health端点
|
||||
func (h *Handler) HealthHandle(w http.ResponseWriter, r *http.Request) {
|
||||
healthStatus := h.router.GetHealthStatus()
|
||||
|
||||
allHealthy := true
|
||||
services := make(map[string]bool)
|
||||
for name, health := range healthStatus {
|
||||
services[name] = health.Available
|
||||
if !health.Available {
|
||||
allHealthy = false
|
||||
}
|
||||
}
|
||||
|
||||
status := "healthy"
|
||||
statusCode := http.StatusOK
|
||||
if !allHealthy {
|
||||
status = "degraded"
|
||||
statusCode = http.StatusServiceUnavailable
|
||||
}
|
||||
|
||||
h.writeJSON(w, statusCode, model.HealthStatus{
|
||||
Status: status,
|
||||
Timestamp: time.Now(),
|
||||
Services: services,
|
||||
}, "")
|
||||
}
|
||||
|
||||
func (h *Handler) writeJSON(w http.ResponseWriter, status int, data interface{}, requestID string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if requestID != "" {
|
||||
w.Header().Set("X-Request-ID", requestID)
|
||||
}
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
func (h *Handler) writeError(w http.ResponseWriter, r *http.Request, err *error.GatewayError) {
|
||||
info := err.GetErrorInfo()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if err.RequestID != "" {
|
||||
w.Header().Set("X-Request-ID", err.RequestID)
|
||||
}
|
||||
w.WriteHeader(info.HTTPStatus)
|
||||
|
||||
resp := model.ErrorResponse{
|
||||
Error: model.ErrorDetail{
|
||||
Message: err.Message,
|
||||
Type: "gateway_error",
|
||||
Code: string(err.Code),
|
||||
},
|
||||
}
|
||||
json.NewEncoder(w).Encode(resp)
|
||||
}
|
||||
|
||||
func generateRequestID() string {
|
||||
return fmt.Sprintf("chatcmpl-%d", time.Now().UnixNano())
|
||||
}
|
||||
|
||||
func marshalJSON(v interface{}) string {
|
||||
data, _ := json.Marshal(v)
|
||||
return string(data)
|
||||
}
|
||||
|
||||
// SSEReader 流式响应读取器
|
||||
type SSEReader struct {
|
||||
reader *bufio.Reader
|
||||
}
|
||||
|
||||
func NewSSEReader(r io.Reader) *SSEReader {
|
||||
return &SSEReader{reader: bufio.NewReader(r)}
|
||||
}
|
||||
|
||||
func (s *SSEReader) ReadLine() (string, error) {
|
||||
line, err := s.reader.ReadString('\n')
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return line[:len(line)-1], nil
|
||||
}
|
||||
|
||||
func parseSSEData(line string) string {
|
||||
if len(line) < 6 {
|
||||
return ""
|
||||
}
|
||||
if line[:5] != "data:" {
|
||||
return ""
|
||||
}
|
||||
return line[6:]
|
||||
}
|
||||
|
||||
func getenv(key, defaultValue string) string {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
func init() {
|
||||
getenv = func(key, defaultValue string) string {
|
||||
return defaultValue
|
||||
}
|
||||
}
|
||||
336
gateway/internal/ratelimit/ratelimit.go
Normal file
336
gateway/internal/ratelimit/ratelimit.go
Normal file
@@ -0,0 +1,336 @@
|
||||
package ratelimit
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/gateway/pkg/error"
|
||||
)
|
||||
|
||||
// Algorithm 限流算法
|
||||
type Algorithm string
|
||||
|
||||
const (
|
||||
TokenBucket Algorithm = "token_bucket"
|
||||
SlidingWindow Algorithm = "sliding_window"
|
||||
FixedWindow Algorithm = "fixed_window"
|
||||
)
|
||||
|
||||
// Limiter 限流器接口
|
||||
type Limiter interface {
|
||||
// Allow 检查是否允许请求
|
||||
Allow(ctx context.Context, key string) (bool, error)
|
||||
// AllowToken 检查是否允许消耗token
|
||||
AllowToken(ctx context.Context, key string, tokens int) (bool, error)
|
||||
// GetLimit 获取当前限制
|
||||
GetLimit(key string) *Limit
|
||||
}
|
||||
|
||||
// Limit 限制配置
|
||||
type Limit struct {
|
||||
RPM int // 请求数/分钟
|
||||
TPM int // Token数/分钟
|
||||
Burst int // 突发容量
|
||||
Remaining int // 剩余请求数
|
||||
ResetAt time.Time // 重置时间
|
||||
}
|
||||
|
||||
// TokenBucketLimiter Token桶限流器
|
||||
type TokenBucketLimiter struct {
|
||||
mu sync.RWMutex
|
||||
buckets map[string]*tokenBucket
|
||||
defaultRPM int
|
||||
defaultTPM int
|
||||
burstMultiplier float64
|
||||
cleanInterval time.Duration
|
||||
}
|
||||
|
||||
type tokenBucket struct {
|
||||
tokens float64
|
||||
maxTokens float64
|
||||
tokensPerSec float64
|
||||
lastRefill time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewTokenBucketLimiter 创建Token桶限流器
|
||||
func NewTokenBucketLimiter(defaultRPM, defaultTPM int, burstMultiplier float64) *TokenBucketLimiter {
|
||||
limiter := &TokenBucketLimiter{
|
||||
buckets: make(map[string]*tokenBucket),
|
||||
defaultRPM: defaultRPM,
|
||||
defaultTPM: defaultTPM,
|
||||
burstMultiplier: burstMultiplier,
|
||||
cleanInterval: 5 * time.Minute,
|
||||
}
|
||||
|
||||
// 启动清理goroutine
|
||||
go limiter.cleanup()
|
||||
|
||||
return limiter
|
||||
}
|
||||
|
||||
// Allow 检查是否允许请求
|
||||
func (l *TokenBucketLimiter) Allow(ctx context.Context, key string) (bool, error) {
|
||||
return l.AllowToken(ctx, key, 1)
|
||||
}
|
||||
|
||||
// AllowToken 检查是否允许消耗token
|
||||
func (l *TokenBucketLimiter) AllowToken(ctx context.Context, key string, tokens int) (bool, error) {
|
||||
l.mu.Lock()
|
||||
bucket, exists := l.buckets[key]
|
||||
if !exists {
|
||||
bucket = l.newBucket(l.defaultRPM, l.defaultTPM)
|
||||
l.buckets[key] = bucket
|
||||
}
|
||||
l.mu.Unlock()
|
||||
|
||||
bucket.mu.Lock()
|
||||
defer bucket.mu.Unlock()
|
||||
|
||||
// 补充token
|
||||
l.refill(bucket)
|
||||
|
||||
// 检查是否有足够的token
|
||||
if bucket.tokens >= float64(tokens) {
|
||||
bucket.tokens -= float64(tokens)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GetLimit 获取当前限制
|
||||
func (l *TokenBucketLimiter) GetLimit(key string) *Limit {
|
||||
l.mu.RLock()
|
||||
bucket, exists := l.buckets[key]
|
||||
l.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return &Limit{
|
||||
RPM: l.defaultRPM,
|
||||
TPM: l.defaultTPM,
|
||||
Burst: int(float64(l.defaultRPM) * l.burstMultiplier),
|
||||
}
|
||||
}
|
||||
|
||||
bucket.mu.Lock()
|
||||
defer bucket.mu.Unlock()
|
||||
|
||||
return &Limit{
|
||||
RPM: l.defaultRPM,
|
||||
TPM: l.defaultTPM,
|
||||
Burst: int(bucket.maxTokens),
|
||||
Remaining: int(bucket.tokens),
|
||||
ResetAt: bucket.lastRefill.Add(time.Minute),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *TokenBucketLimiter) newBucket(rpm, tpm int) *tokenBucket {
|
||||
burst := int(float64(rpm) * l.burstMultiplier)
|
||||
return &tokenBucket{
|
||||
tokens: float64(burst),
|
||||
maxTokens: float64(burst),
|
||||
tokensPerSec: float64(rpm) / 60.0,
|
||||
lastRefill: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (l *TokenBucketLimiter) refill(bucket *tokenBucket) {
|
||||
now := time.Now()
|
||||
elapsed := now.Sub(bucket.lastRefill).Seconds()
|
||||
|
||||
// 添加新token
|
||||
bucket.tokens += elapsed * bucket.tokensPerSec
|
||||
if bucket.tokens > bucket.maxTokens {
|
||||
bucket.tokens = bucket.maxTokens
|
||||
}
|
||||
|
||||
bucket.lastRefill = now
|
||||
}
|
||||
|
||||
func (l *TokenBucketLimiter) cleanup() {
|
||||
ticker := time.NewTicker(l.cleanInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
l.mu.Lock()
|
||||
now := time.Now()
|
||||
for key, bucket := range l.buckets {
|
||||
bucket.mu.Lock()
|
||||
// 如果bucket完全空了且超过10分钟没使用,删除它
|
||||
if bucket.tokens >= bucket.maxTokens && now.Sub(bucket.lastRefill) > 10*time.Minute {
|
||||
delete(l.buckets, key)
|
||||
}
|
||||
bucket.mu.Unlock()
|
||||
}
|
||||
l.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// SlidingWindowLimiter 滑动窗口限流器
|
||||
type SlidingWindowLimiter struct {
|
||||
mu sync.RWMutex
|
||||
windows map[string]*slidingWindow
|
||||
windowSize time.Duration
|
||||
maxRequests int
|
||||
cleanInterval time.Duration
|
||||
}
|
||||
|
||||
type slidingWindow struct {
|
||||
requests []time.Time
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewSlidingWindowLimiter(windowSize time.Duration, maxRequests int) *SlidingWindowLimiter {
|
||||
limiter := &SlidingWindowLimiter{
|
||||
windows: make(map[string]*slidingWindow),
|
||||
windowSize: windowSize,
|
||||
maxRequests: maxRequests,
|
||||
cleanInterval: 1 * time.Minute,
|
||||
}
|
||||
|
||||
go limiter.cleanup()
|
||||
return limiter
|
||||
}
|
||||
|
||||
func (l *SlidingWindowLimiter) Allow(ctx context.Context, key string) (bool, error) {
|
||||
l.mu.Lock()
|
||||
window, exists := l.windows[key]
|
||||
if !exists {
|
||||
window = &slidingWindow{requests: make([]time.Time, 0)}
|
||||
l.windows[key] = window
|
||||
}
|
||||
l.mu.Unlock()
|
||||
|
||||
window.mu.Lock()
|
||||
defer window.mu.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
cutoff := now.Add(-l.windowSize)
|
||||
|
||||
// 清理过期的请求
|
||||
validRequests := make([]time.Time, 0)
|
||||
for _, t := range window.requests {
|
||||
if t.After(cutoff) {
|
||||
validRequests = append(validRequests, t)
|
||||
}
|
||||
}
|
||||
window.requests = validRequests
|
||||
|
||||
// 检查是否超过限制
|
||||
if len(window.requests) >= l.maxRequests {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
window.requests = append(window.requests, now)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (l *SlidingWindowLimiter) AllowToken(ctx context.Context, key string, tokens int) (bool, error) {
|
||||
// 对于滑动窗口,tokens只是计数,这里简化为1个请求
|
||||
return l.Allow(ctx, key)
|
||||
}
|
||||
|
||||
func (l *SlidingWindowLimiter) GetLimit(key string) *Limit {
|
||||
l.mu.RLock()
|
||||
window, exists := l.windows[key]
|
||||
l.mu.RUnlock()
|
||||
|
||||
remaining := l.maxRequests
|
||||
if exists {
|
||||
window.mu.Lock()
|
||||
cutoff := time.Now().Add(-l.windowSize)
|
||||
count := 0
|
||||
for _, t := range window.requests {
|
||||
if t.After(cutoff) {
|
||||
count++
|
||||
}
|
||||
}
|
||||
remaining = l.maxRequests - count
|
||||
if remaining < 0 {
|
||||
remaining = 0
|
||||
}
|
||||
window.mu.Unlock()
|
||||
}
|
||||
|
||||
return &Limit{
|
||||
RPM: l.maxRequests,
|
||||
ResetAt: time.Now().Add(l.windowSize),
|
||||
Remaining: remaining,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *SlidingWindowLimiter) cleanup() {
|
||||
ticker := time.NewTicker(l.cleanInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
l.mu.Lock()
|
||||
now := time.Now()
|
||||
for key, window := range l.windows {
|
||||
window.mu.Lock()
|
||||
cutoff := now.Add(-l.windowSize * 2)
|
||||
validRequests := make([]time.Time, 0)
|
||||
for _, t := range window.requests {
|
||||
if t.After(cutoff) {
|
||||
validRequests = append(validRequests, t)
|
||||
}
|
||||
}
|
||||
if len(validRequests) == 0 && now.Sub(window.requests[len(window.requests)-1]) > l.windowSize*2 {
|
||||
delete(l.windows, key)
|
||||
} else {
|
||||
window.requests = validRequests
|
||||
}
|
||||
window.mu.Unlock()
|
||||
}
|
||||
l.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Middleware 限流中间件
|
||||
type Middleware struct {
|
||||
limiter Limiter
|
||||
}
|
||||
|
||||
func NewMiddleware(limiter Limiter) *Middleware {
|
||||
return &Middleware{limiter: limiter}
|
||||
}
|
||||
|
||||
func (m *Middleware) Limit(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
// 使用API Key作为限流key
|
||||
key := r.Header.Get("Authorization")
|
||||
if key == "" {
|
||||
key = r.RemoteAddr
|
||||
}
|
||||
|
||||
allowed, err := m.limiter.Allow(r.Context(), key)
|
||||
if err != nil {
|
||||
writeError(w, error.NewGatewayError(error.COMMON_INTERNAL_ERROR, "rate limiter error"))
|
||||
return
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
limit := m.limiter.GetLimit(key)
|
||||
w.Header().Set("X-RateLimit-Limit", fmt.Sprintf("%d", limit.RPM))
|
||||
w.Header().Set("X-RateLimit-Remaining", fmt.Sprintf("%d", limit.Remaining))
|
||||
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", limit.ResetAt.Unix()))
|
||||
|
||||
writeError(w, error.NewGatewayError(error.RATE_LIMIT_EXCEEDED, "rate limit exceeded"))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
||||
import "net/http"
|
||||
|
||||
func writeError(w http.ResponseWriter, err *error.GatewayError) {
|
||||
info := err.GetErrorInfo()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(info.HTTPStatus)
|
||||
w.Write([]byte(fmt.Sprintf(`{"error":{"message":"%s","code":"%s"}}`, err.Message, err.Code)))
|
||||
}
|
||||
261
gateway/internal/router/router.go
Normal file
261
gateway/internal/router/router.go
Normal file
@@ -0,0 +1,261 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/gateway/internal/adapter"
|
||||
"lijiaoqiao/gateway/pkg/error"
|
||||
)
|
||||
|
||||
// LoadBalancerStrategy 负载均衡策略
|
||||
type LoadBalancerStrategy string
|
||||
|
||||
const (
|
||||
StrategyLatency LoadBalancerStrategy = "latency"
|
||||
StrategyRoundRobin LoadBalancerStrategy = "round_robin"
|
||||
StrategyWeighted LoadBalancerStrategy = "weighted"
|
||||
StrategyAvailability LoadBalancerStrategy = "availability"
|
||||
)
|
||||
|
||||
// ProviderHealth Provider健康状态
|
||||
type ProviderHealth struct {
|
||||
Name string
|
||||
Available bool
|
||||
LatencyMs int64
|
||||
FailureRate float64
|
||||
Weight float64
|
||||
LastCheckTime time.Time
|
||||
}
|
||||
|
||||
// Router 路由器
|
||||
type Router struct {
|
||||
providers map[string]adapter.ProviderAdapter
|
||||
health map[string]*ProviderHealth
|
||||
strategy LoadBalancerStrategy
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewRouter 创建路由器
|
||||
func NewRouter(strategy LoadBalancerStrategy) *Router {
|
||||
return &Router{
|
||||
providers: make(map[string]adapter.ProviderAdapter),
|
||||
health: make(map[string]*ProviderHealth),
|
||||
strategy: strategy,
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterProvider 注册Provider
|
||||
func (r *Router) RegisterProvider(name string, provider adapter.ProviderAdapter) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
r.providers[name] = provider
|
||||
r.health[name] = &ProviderHealth{
|
||||
Name: name,
|
||||
Available: true,
|
||||
LatencyMs: 0,
|
||||
FailureRate: 0,
|
||||
Weight: 1.0,
|
||||
LastCheckTime: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// SelectProvider 选择最佳Provider
|
||||
func (r *Router) SelectProvider(ctx context.Context, model string) (adapter.ProviderAdapter, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var candidates []string
|
||||
for name, provider := range r.providers {
|
||||
if r.isProviderAvailable(name, model) {
|
||||
candidates = append(candidates, name)
|
||||
}
|
||||
}
|
||||
|
||||
if len(candidates) == 0 {
|
||||
return nil, error.NewGatewayError(error.ROUTER_NO_PROVIDER_AVAILABLE, "no provider available for model: "+model)
|
||||
}
|
||||
|
||||
// 根据策略选择
|
||||
switch r.strategy {
|
||||
case StrategyLatency:
|
||||
return r.selectByLatency(candidates)
|
||||
case StrategyWeighted:
|
||||
return r.selectByWeight(candidates)
|
||||
case StrategyAvailability:
|
||||
return r.selectByAvailability(candidates)
|
||||
default:
|
||||
return r.selectByLatency(candidates)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Router) isProviderAvailable(name, model string) bool {
|
||||
health, ok := r.health[name]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if !health.Available {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查模型是否支持
|
||||
provider := r.providers[name]
|
||||
if provider == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, m := range provider.SupportedModels() {
|
||||
if m == model || m == "*" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (r *Router) selectByLatency(candidates []string) (adapter.ProviderAdapter, error) {
|
||||
var bestProvider adapter.ProviderAdapter
|
||||
var minLatency int64 = math.MaxInt64
|
||||
|
||||
for _, name := range candidates {
|
||||
health := r.health[name]
|
||||
if health.LatencyMs < minLatency {
|
||||
minLatency = health.LatencyMs
|
||||
bestProvider = r.providers[name]
|
||||
}
|
||||
}
|
||||
|
||||
if bestProvider == nil {
|
||||
return nil, error.NewGatewayError(error.ROUTER_NO_PROVIDER_AVAILABLE, "no available provider")
|
||||
}
|
||||
|
||||
return bestProvider, nil
|
||||
}
|
||||
|
||||
func (r *Router) selectByWeight(candidates []string) (adapter.ProviderAdapter, error) {
|
||||
var totalWeight float64
|
||||
for _, name := range candidates {
|
||||
totalWeight += r.health[name].Weight
|
||||
}
|
||||
|
||||
randVal := float64(time.Now().UnixNano()) / float64(math.MaxInt64) * totalWeight
|
||||
var cumulative float64
|
||||
|
||||
for _, name := range candidates {
|
||||
cumulative += r.health[name].Weight
|
||||
if randVal <= cumulative {
|
||||
return r.providers[name], nil
|
||||
}
|
||||
}
|
||||
|
||||
return r.providers[candidates[0]], nil
|
||||
}
|
||||
|
||||
func (r *Router) selectByAvailability(candidates []string) (adapter.ProviderAdapter, error) {
|
||||
var bestProvider adapter.ProviderAdapter
|
||||
var minFailureRate float64 = math.MaxFloat64
|
||||
|
||||
for _, name := range candidates {
|
||||
health := r.health[name]
|
||||
if health.FailureRate < minFailureRate {
|
||||
minFailureRate = health.FailureRate
|
||||
bestProvider = r.providers[name]
|
||||
}
|
||||
}
|
||||
|
||||
if bestProvider == nil {
|
||||
return nil, error.NewGatewayError(error.ROUTER_NO_PROVIDER_AVAILABLE, "no available provider")
|
||||
}
|
||||
|
||||
return bestProvider, nil
|
||||
}
|
||||
|
||||
// GetFallbackProviders 获取Fallback Providers
|
||||
func (r *Router) GetFallbackProviders(ctx context.Context, model string) ([]adapter.ProviderAdapter, error) {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
var fallbacks []adapter.ProviderAdapter
|
||||
|
||||
for name, provider := range r.providers {
|
||||
if name == "primary" {
|
||||
continue // 跳过主Provider
|
||||
}
|
||||
if r.isProviderAvailable(name, model) {
|
||||
fallbacks = append(fallbacks, provider)
|
||||
}
|
||||
}
|
||||
|
||||
return fallbacks, nil
|
||||
}
|
||||
|
||||
// RecordResult 记录调用结果
|
||||
func (r *Router) RecordResult(ctx context.Context, providerName string, success bool, latencyMs int64) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
health, ok := r.health[providerName]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// 更新延迟
|
||||
if latencyMs > 0 {
|
||||
// 指数移动平均
|
||||
if health.LatencyMs == 0 {
|
||||
health.LatencyMs = latencyMs
|
||||
} else {
|
||||
health.LatencyMs = (health.LatencyMs*7 + latencyMs) / 8
|
||||
}
|
||||
}
|
||||
|
||||
// 更新失败率
|
||||
if success {
|
||||
if health.FailureRate > 0 {
|
||||
health.FailureRate = health.FailureRate * 0.9 // 下降
|
||||
}
|
||||
} else {
|
||||
health.FailureRate = health.FailureRate*0.9 + 0.1 // 上升
|
||||
}
|
||||
|
||||
// 检查是否应该标记为不可用
|
||||
if health.FailureRate > 0.5 {
|
||||
health.Available = false
|
||||
}
|
||||
|
||||
health.LastCheckTime = time.Now()
|
||||
}
|
||||
|
||||
// UpdateHealth 更新健康状态
|
||||
func (r *Router) UpdateHealth(providerName string, available bool) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
if health, ok := r.health[providerName]; ok {
|
||||
health.Available = available
|
||||
health.LastCheckTime = time.Now()
|
||||
}
|
||||
}
|
||||
|
||||
// GetHealthStatus 获取健康状态
|
||||
func (r *Router) GetHealthStatus() map[string]*ProviderHealth {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
|
||||
result := make(map[string]*ProviderHealth)
|
||||
for name, health := range r.health {
|
||||
result[name] = &ProviderHealth{
|
||||
Name: health.Name,
|
||||
Available: health.Available,
|
||||
LatencyMs: health.LatencyMs,
|
||||
FailureRate: health.FailureRate,
|
||||
Weight: health.Weight,
|
||||
LastCheckTime: health.LastCheckTime,
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
Reference in New Issue
Block a user