package service

import (
	"bytes"
	"context"
	"crypto/hmac"
	cryptorand "crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"net"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"github.com/user-management-system/internal/domain"
	"github.com/user-management-system/internal/repository"
	"gorm.io/gorm"
)

// WebhookService Webhook 服务
|
||
type WebhookService struct {
|
||
db *gorm.DB
|
||
repo *repository.WebhookRepository
|
||
queue chan *deliveryTask
|
||
workers int
|
||
config WebhookServiceConfig
|
||
wg sync.WaitGroup
|
||
once sync.Once
|
||
}
|
||
|
||
// WebhookServiceConfig holds the tunable settings of WebhookService.
//
// NewWebhookService replaces non-positive numeric fields and empty string
// fields with the values from defaultWebhookServiceConfig; Enabled is used
// exactly as given.
type WebhookServiceConfig struct {
	// Enabled switches event publishing on; Publish returns immediately when false.
	Enabled bool
	// SecretHeader is the HTTP header that carries the HMAC signature.
	SecretHeader string
	// TimeoutSec is the request timeout in seconds used when a webhook has none of its own.
	TimeoutSec int
	// MaxRetries is the retry limit copied onto newly created webhooks.
	MaxRetries int
	// RetryBackoff selects the retry delay: "fixed" waits a constant 2s,
	// any other value doubles the delay with every attempt.
	RetryBackoff string
	// WorkerCount is the number of delivery worker goroutines.
	WorkerCount int
	// QueueSize is the capacity of the delivery queue.
	QueueSize int
}

// deliveryTask 投递任务
|
||
type deliveryTask struct {
|
||
webhook *domain.Webhook
|
||
eventType domain.WebhookEventType
|
||
payload []byte
|
||
attempt int
|
||
}
|
||
|
||
// WebhookEvent 发布的事件结构
|
||
type WebhookEvent struct {
|
||
EventID string `json:"event_id"`
|
||
EventType domain.WebhookEventType `json:"event_type"`
|
||
Timestamp time.Time `json:"timestamp"`
|
||
Data interface{} `json:"data"`
|
||
}
|
||
|
||
// NewWebhookService 创建 Webhook 服务
|
||
func NewWebhookService(db *gorm.DB, cfgs ...WebhookServiceConfig) *WebhookService {
|
||
cfg := defaultWebhookServiceConfig()
|
||
if len(cfgs) > 0 {
|
||
cfg = cfgs[0]
|
||
}
|
||
if cfg.WorkerCount <= 0 {
|
||
cfg.WorkerCount = defaultWebhookServiceConfig().WorkerCount
|
||
}
|
||
if cfg.QueueSize <= 0 {
|
||
cfg.QueueSize = defaultWebhookServiceConfig().QueueSize
|
||
}
|
||
if cfg.SecretHeader == "" {
|
||
cfg.SecretHeader = defaultWebhookServiceConfig().SecretHeader
|
||
}
|
||
if cfg.TimeoutSec <= 0 {
|
||
cfg.TimeoutSec = defaultWebhookServiceConfig().TimeoutSec
|
||
}
|
||
if cfg.MaxRetries <= 0 {
|
||
cfg.MaxRetries = defaultWebhookServiceConfig().MaxRetries
|
||
}
|
||
if cfg.RetryBackoff == "" {
|
||
cfg.RetryBackoff = defaultWebhookServiceConfig().RetryBackoff
|
||
}
|
||
|
||
svc := &WebhookService{
|
||
db: db,
|
||
repo: repository.NewWebhookRepository(db),
|
||
queue: make(chan *deliveryTask, cfg.QueueSize),
|
||
workers: cfg.WorkerCount,
|
||
config: cfg,
|
||
}
|
||
svc.startWorkers()
|
||
return svc
|
||
}
|
||
|
||
func defaultWebhookServiceConfig() WebhookServiceConfig {
|
||
return WebhookServiceConfig{
|
||
Enabled: true,
|
||
SecretHeader: "X-Webhook-Signature",
|
||
TimeoutSec: 10,
|
||
MaxRetries: 3,
|
||
RetryBackoff: "exponential",
|
||
WorkerCount: 4,
|
||
QueueSize: 1000,
|
||
}
|
||
}
|
||
|
||
// startWorkers 启动后台投递 worker
|
||
func (s *WebhookService) startWorkers() {
|
||
s.once.Do(func() {
|
||
for i := 0; i < s.workers; i++ {
|
||
s.wg.Add(1)
|
||
go func() {
|
||
defer s.wg.Done()
|
||
for task := range s.queue {
|
||
s.deliver(task)
|
||
}
|
||
}()
|
||
}
|
||
})
|
||
}
|
||
|
||
// Publish 发布事件:找到订阅该事件的所有 Webhook,异步投递
|
||
func (s *WebhookService) Publish(ctx context.Context, eventType domain.WebhookEventType, data interface{}) {
|
||
if !s.config.Enabled {
|
||
return
|
||
}
|
||
// 查询所有活跃 Webhook
|
||
webhooks, err := s.repo.ListActive(ctx)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
// 构建事件载荷
|
||
eventID, err := generateEventID()
|
||
if err != nil {
|
||
slog.Error("generate event ID failed", "error", err)
|
||
return
|
||
}
|
||
event := &WebhookEvent{
|
||
EventID: eventID,
|
||
EventType: eventType,
|
||
Timestamp: time.Now().UTC(),
|
||
Data: data,
|
||
}
|
||
payloadBytes, err := json.Marshal(event)
|
||
if err != nil {
|
||
return
|
||
}
|
||
|
||
for i := range webhooks {
|
||
wh := webhooks[i]
|
||
// 检查是否订阅了该事件类型
|
||
if !webhookSubscribesTo(wh, eventType) {
|
||
continue
|
||
}
|
||
|
||
task := &deliveryTask{
|
||
webhook: wh,
|
||
eventType: eventType,
|
||
payload: payloadBytes,
|
||
attempt: 1,
|
||
}
|
||
|
||
// 非阻塞投递到队列
|
||
select {
|
||
case s.queue <- task:
|
||
default:
|
||
// 队列满时记录但不阻塞
|
||
}
|
||
}
|
||
}
|
||
|
||
// deliver 执行单次 HTTP 投递
|
||
func (s *WebhookService) deliver(task *deliveryTask) {
|
||
wh := task.webhook
|
||
|
||
// NEW-SEC-01 修复:检查 URL 安全性
|
||
if !isSafeURL(wh.URL) {
|
||
s.recordDelivery(task, 0, "", "webhook URL 不安全: 可能存在 SSRF 风险", false)
|
||
return
|
||
}
|
||
|
||
timeout := time.Duration(wh.TimeoutSec) * time.Second
|
||
if timeout <= 0 {
|
||
timeout = time.Duration(s.config.TimeoutSec) * time.Second
|
||
}
|
||
if timeout <= 0 {
|
||
timeout = 10 * time.Second
|
||
}
|
||
|
||
client := &http.Client{Timeout: timeout}
|
||
|
||
req, err := http.NewRequest("POST", wh.URL, bytes.NewReader(task.payload))
|
||
if err != nil {
|
||
s.recordDelivery(task, 0, "", err.Error(), false)
|
||
return
|
||
}
|
||
|
||
req.Header.Set("Content-Type", "application/json")
|
||
req.Header.Set("User-Agent", "UserManagementSystem-Webhook/1.0")
|
||
req.Header.Set("X-Webhook-Event", string(task.eventType))
|
||
req.Header.Set("X-Webhook-Attempt", fmt.Sprintf("%d", task.attempt))
|
||
|
||
// HMAC 签名
|
||
if wh.Secret != "" {
|
||
sig := computeHMAC(task.payload, wh.Secret)
|
||
req.Header.Set(s.config.SecretHeader, "sha256="+sig)
|
||
}
|
||
|
||
// 使用带超时的 context 避免请求无限等待
|
||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||
defer cancel()
|
||
resp, err := client.Do(req.WithContext(ctx))
|
||
if err != nil {
|
||
s.handleFailure(task, 0, "", err.Error())
|
||
return
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
var respBuf bytes.Buffer
|
||
respBuf.ReadFrom(resp.Body)
|
||
success := resp.StatusCode >= 200 && resp.StatusCode < 300
|
||
|
||
if !success {
|
||
s.handleFailure(task, resp.StatusCode, respBuf.String(), "非 2xx 响应")
|
||
return
|
||
}
|
||
|
||
s.recordDelivery(task, resp.StatusCode, respBuf.String(), "", true)
|
||
}
|
||
|
||
// handleFailure 处理投递失败(重试逻辑)
|
||
func (s *WebhookService) handleFailure(task *deliveryTask, statusCode int, body, errMsg string) {
|
||
s.recordDelivery(task, statusCode, body, errMsg, false)
|
||
|
||
// 指数退避重试
|
||
if task.attempt < task.webhook.MaxRetries {
|
||
backoff := time.Second
|
||
if s.config.RetryBackoff == "fixed" {
|
||
backoff = 2 * time.Second
|
||
} else {
|
||
backoff = time.Duration(1<<uint(task.attempt)) * time.Second
|
||
}
|
||
time.AfterFunc(backoff, func() {
|
||
task.attempt++
|
||
select {
|
||
case s.queue <- task:
|
||
default:
|
||
}
|
||
})
|
||
}
|
||
}
|
||
|
||
// recordDelivery 记录投递日志
|
||
func (s *WebhookService) recordDelivery(task *deliveryTask, statusCode int, body, errMsg string, success bool) {
|
||
now := time.Now()
|
||
delivery := &domain.WebhookDelivery{
|
||
WebhookID: task.webhook.ID,
|
||
EventType: task.eventType,
|
||
Payload: string(task.payload),
|
||
StatusCode: statusCode,
|
||
ResponseBody: body,
|
||
Attempt: task.attempt,
|
||
Success: success,
|
||
Error: errMsg,
|
||
}
|
||
if success {
|
||
delivery.DeliveredAt = &now
|
||
}
|
||
_ = s.repo.CreateDelivery(context.Background(), delivery)
|
||
}
|
||
|
||
// CreateWebhook 创建 Webhook
|
||
func (s *WebhookService) CreateWebhook(ctx context.Context, req *CreateWebhookRequest, createdBy int64) (*domain.Webhook, error) {
|
||
eventsJSON, err := json.Marshal(req.Events)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("序列化事件列表失败")
|
||
}
|
||
|
||
secret := req.Secret
|
||
if secret == "" {
|
||
generatedSecret, err := generateWebhookSecret()
|
||
if err != nil {
|
||
return nil, fmt.Errorf("generate webhook secret failed: %w", err)
|
||
}
|
||
secret = generatedSecret
|
||
}
|
||
|
||
wh := &domain.Webhook{
|
||
Name: req.Name,
|
||
URL: req.URL,
|
||
Secret: secret,
|
||
Events: string(eventsJSON),
|
||
Status: domain.WebhookStatusActive,
|
||
MaxRetries: s.config.MaxRetries,
|
||
TimeoutSec: s.config.TimeoutSec,
|
||
CreatedBy: createdBy,
|
||
}
|
||
if err := s.repo.Create(ctx, wh); err != nil {
|
||
return nil, err
|
||
}
|
||
return wh, nil
|
||
}
|
||
|
||
// UpdateWebhook 更新 Webhook
|
||
func (s *WebhookService) UpdateWebhook(ctx context.Context, id int64, req *UpdateWebhookRequest) error {
|
||
updates := map[string]interface{}{}
|
||
if req.Name != "" {
|
||
updates["name"] = req.Name
|
||
}
|
||
if req.URL != "" {
|
||
updates["url"] = req.URL
|
||
}
|
||
if len(req.Events) > 0 {
|
||
b, _ := json.Marshal(req.Events)
|
||
updates["events"] = string(b)
|
||
}
|
||
if req.Status != nil {
|
||
updates["status"] = *req.Status
|
||
}
|
||
return s.repo.Update(ctx, id, updates)
|
||
}
|
||
|
||
// DeleteWebhook 删除 Webhook
|
||
func (s *WebhookService) DeleteWebhook(ctx context.Context, id int64) error {
|
||
return s.repo.Delete(ctx, id)
|
||
}
|
||
|
||
func (s *WebhookService) GetWebhook(ctx context.Context, id int64) (*domain.Webhook, error) {
|
||
return s.repo.GetByID(ctx, id)
|
||
}
|
||
|
||
// ListWebhooks 获取 Webhook 列表(不分页)
|
||
func (s *WebhookService) ListWebhooks(ctx context.Context, createdBy int64) ([]*domain.Webhook, error) {
|
||
return s.repo.ListByCreator(ctx, createdBy)
|
||
}
|
||
|
||
// ListWebhooksPaginated 获取 Webhook 列表(分页)
|
||
func (s *WebhookService) ListWebhooksPaginated(ctx context.Context, createdBy int64, offset, limit int) ([]*domain.Webhook, int64, error) {
|
||
return s.repo.ListByCreatorPaginated(ctx, createdBy, offset, limit)
|
||
}
|
||
|
||
// GetWebhookDeliveries 获取投递记录
|
||
func (s *WebhookService) GetWebhookDeliveries(ctx context.Context, webhookID int64, limit int) ([]*domain.WebhookDelivery, error) {
|
||
return s.repo.ListDeliveries(ctx, webhookID, limit)
|
||
}
|
||
|
||
// ---- Request/Response types ----

// CreateWebhookRequest 创建 Webhook 请求
|
||
type CreateWebhookRequest struct {
|
||
Name string `json:"name" binding:"required"`
|
||
URL string `json:"url" binding:"required,url"`
|
||
Secret string `json:"secret"`
|
||
Events []domain.WebhookEventType `json:"events" binding:"required,min=1"`
|
||
}
|
||
|
||
// UpdateWebhookRequest 更新 Webhook 请求
|
||
type UpdateWebhookRequest struct {
|
||
Name string `json:"name"`
|
||
URL string `json:"url"`
|
||
Events []domain.WebhookEventType `json:"events"`
|
||
Status *domain.WebhookStatus `json:"status"`
|
||
}
|
||
|
||
// ---- Helper functions ----

// webhookSubscribesTo 检查 Webhook 是否订阅了指定事件类型
|
||
func webhookSubscribesTo(w *domain.Webhook, eventType domain.WebhookEventType) bool {
|
||
var events []domain.WebhookEventType
|
||
if err := json.Unmarshal([]byte(w.Events), &events); err != nil {
|
||
return false
|
||
}
|
||
for _, e := range events {
|
||
if e == eventType || e == "*" {
|
||
return true
|
||
}
|
||
}
|
||
return false
|
||
}
|
||
|
||
// NOTE: methods cannot be added to domain.Webhook from outside the domain
// package, so the subscription check is the standalone function
// webhookSubscribesTo rather than a SubscribesTo method.

// isSafeURL reports whether rawURL is acceptable as a webhook target. It is
// the textual SSRF guard (fix NEW-SEC-01).
//
// A URL is rejected when:
//   - it does not parse, or its scheme is not http/https;
//   - it has no host;
//   - the host is "localhost" or ends in a reserved/internal suffix;
//   - the host is an IP literal that is loopback, private, unspecified,
//     link-local, multicast or in the 100.64.0.0/10 shared address space;
//   - the host is a known cloud metadata endpoint.
//
// The host is lower-cased and stripped of a trailing root dot before
// comparison, because DNS names are case-insensitive ("LOCALHOST" and
// "localhost." both name the loopback host).
//
// The check looks at the URL text only. It does not resolve DNS, so a public
// name that resolves to an internal address is not detected here.
func isSafeURL(rawURL string) bool {
	u, err := url.Parse(rawURL)
	if err != nil {
		return false
	}
	// url.Parse lower-cases the scheme; an empty scheme is rejected here too.
	if u.Scheme != "http" && u.Scheme != "https" {
		return false
	}

	host := strings.TrimSuffix(strings.ToLower(u.Hostname()), ".")
	if host == "" || host == "localhost" {
		return false
	}

	// Drop an IPv6 zone ("fe80::1%eth0") so the literal can be parsed.
	literal, _, _ := strings.Cut(host, "%")
	if ip := net.ParseIP(literal); ip != nil {
		if isPrivateIP(ip) || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || ip.IsMulticast() {
			return false
		}
		// 100.64.0.0/10 (RFC 6598 shared address space) is not publicly
		// routable; it also contains the Alibaba Cloud metadata address.
		if ip4 := ip.To4(); ip4 != nil && ip4[0] == 100 && ip4[1]&0xc0 == 0x40 {
			return false
		}
	}

	// Suffixes commonly used for internal-only names.
	for _, suffix := range []string{".localhost", ".internal", ".local", ".corp", ".lan", ".intranet"} {
		if strings.HasSuffix(host, suffix) {
			return false
		}
	}

	// Well-known metadata endpoints, kept explicit even where an earlier
	// rule already covers them.
	switch host {
	case "metadata.google.internal", // GCP metadata service
		"169.254.169.254",         // AWS/Azure/GCP metadata service
		"metadata.azure.internal", // Azure metadata service
		"100.100.100.200":         // Alibaba Cloud metadata service
		return false
	}

	return true
}

// isPrivateIP reports whether ip is a loopback address (127.0.0.0/8, ::1) or
// a private-network address (10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16,
// fc00::/7). IPv4-mapped IPv6 addresses are classified by their IPv4 form.
func isPrivateIP(ip net.IP) bool {
	return ip.IsLoopback() || ip.IsPrivate()
}

// computeHMAC returns the lowercase hex encoding of the HMAC-SHA256 of
// payload, keyed with secret.
func computeHMAC(payload []byte, secret string) string {
	signer := hmac.New(sha256.New, []byte(secret))
	// hash.Hash documents that Write never returns an error.
	_, _ = signer.Write(payload)
	digest := signer.Sum(nil)
	return hex.EncodeToString(digest)
}

// generateEventID returns a random event identifier of the form
// "evt_" followed by 16 hex characters (8 random bytes from crypto/rand).
// It fails only when the random source cannot be read.
func generateEventID() (string, error) {
	var raw [8]byte
	if _, err := cryptorand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("generate event ID failed: %w", err)
	}
	return "evt_" + hex.EncodeToString(raw[:]), nil
}

// generateWebhookSecret returns a random webhook signing secret: 24 bytes
// from crypto/rand encoded as 48 lowercase hex characters. It fails only when
// the random source cannot be read.
func generateWebhookSecret() (string, error) {
	var raw [24]byte
	if _, err := cryptorand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("generate webhook secret failed: %w", err)
	}
	// hex.EncodeToString already emits lowercase digits.
	return hex.EncodeToString(raw[:]), nil
}