P1-01: 提取重复的角色层级定义为包级常量 - 将 roleHierarchy 提取为 roleHierarchyLevels 包级变量 - 消除重复定义 P1-02: 修复伪随机数用于加权选择 - 使用 math/rand 的线程安全随机数生成器替代时间戳 - 确保加权路由的均匀分布 P1-03: 修复 FailureRate 初始化计算错误 - 将成功时的恢复因子从 0.9 改为 0.5 - 加速失败后的恢复过程 P1-04: 为 DefaultIAMService 添加并发控制 - 添加 sync.RWMutex 保护 map 操作 - 确保所有服务方法的线程安全 P1-05: 修复 IP 伪造漏洞 - 添加 TrustedProxies 配置 - 只在来自可信代理时才使用 X-Forwarded-For P1-06: 修复限流 key 提取逻辑错误 - 从 Authorization header 中提取 Bearer token - 避免使用完整的 header 作为限流 key
327 lines
9.0 KiB
Go
327 lines
9.0 KiB
Go
package middleware
|
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"path"
	"strings"
	"time"
)
|
|
|
|
// requestIDHeader is the HTTP header that carries the request ID. It is read
// from incoming requests and echoed on responses.
const requestIDHeader = "X-Request-Id"

// defaultNowFunc is the clock used whenever a caller does not supply one.
var defaultNowFunc = time.Now

// contextKey is the package-private key type for context values. Because it is
// a distinct named type, its keys cannot compare equal to keys of other types.
type contextKey string

// Keys under which this package stores values in a request context.
const (
	requestIDKey contextKey = "request_id" // holds the request ID (string)
	principalKey contextKey = "principal"  // holds the authenticated Principal
)
|
|
|
|
// Principal describes the caller after successful token authentication.
// tokenAuthMiddleware builds it from the verified token claims and stores it
// in the request context; read it back with PrincipalFromContext.
type Principal struct {
	// RequestID is the ID of the request that was authenticated.
	RequestID string
	// TokenID identifies the verified token.
	TokenID string
	// SubjectID identifies the subject the token was issued to.
	SubjectID string
	// Role is the role carried by the token claims.
	Role string
	// Scope is a private copy of the scopes carried by the token claims.
	Scope []string
}
|
|
|
|
// BuildTokenAuthChain 构建认证中间件链
|
|
func BuildTokenAuthChain(cfg AuthMiddlewareConfig, next http.Handler) http.Handler {
|
|
handler := tokenAuthMiddleware(cfg)(next)
|
|
handler = queryKeyRejectMiddleware(handler, cfg.Auditor, cfg.Now, cfg.TrustedProxies)
|
|
handler = requestIDMiddleware(handler, cfg.Now)
|
|
return handler
|
|
}
|
|
|
|
// RequestIDMiddleware 请求ID中间件
|
|
func requestIDMiddleware(next http.Handler, now func() time.Time) http.Handler {
|
|
if next == nil {
|
|
return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
|
}
|
|
if now == nil {
|
|
now = defaultNowFunc
|
|
}
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
requestID := ensureRequestID(r, now)
|
|
w.Header().Set(requestIDHeader, requestID)
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// queryKeyRejectMiddleware 拒绝query key入站
|
|
func queryKeyRejectMiddleware(next http.Handler, auditor AuditEmitter, now func() time.Time, trustedProxies []string) http.Handler {
|
|
if next == nil {
|
|
return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
|
}
|
|
if now == nil {
|
|
now = defaultNowFunc
|
|
}
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if hasExternalQueryKey(r) {
|
|
requestID, _ := RequestIDFromContext(r.Context())
|
|
emitAudit(r.Context(), auditor, AuditEvent{
|
|
EventName: EventTokenQueryKeyRejected,
|
|
RequestID: requestID,
|
|
Route: r.URL.Path,
|
|
ResultCode: CodeQueryKeyNotAllowed,
|
|
ClientIP: extractClientIP(r, trustedProxies),
|
|
CreatedAt: now(),
|
|
})
|
|
writeError(w, http.StatusUnauthorized, requestID, CodeQueryKeyNotAllowed, "query key not allowed")
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
// tokenAuthMiddleware Token认证中间件
|
|
func tokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handler {
|
|
cfg = cfg.withDefaults()
|
|
return func(next http.Handler) http.Handler {
|
|
if next == nil {
|
|
next = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
|
}
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if !cfg.shouldProtect(r.URL.Path) {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
requestID := ensureRequestID(r, cfg.Now)
|
|
if cfg.Verifier == nil || cfg.StatusResolver == nil || cfg.Authorizer == nil {
|
|
writeError(w, http.StatusServiceUnavailable, requestID, CodeAuthNotReady, "auth middleware dependencies are not ready")
|
|
return
|
|
}
|
|
|
|
rawToken, ok := extractBearerToken(r.Header.Get("Authorization"))
|
|
if !ok {
|
|
emitAudit(r.Context(), cfg.Auditor, AuditEvent{
|
|
EventName: EventTokenAuthnFail,
|
|
RequestID: requestID,
|
|
Route: r.URL.Path,
|
|
ResultCode: CodeAuthMissingBearer,
|
|
ClientIP: extractClientIP(r, cfg.TrustedProxies),
|
|
CreatedAt: cfg.Now(),
|
|
})
|
|
writeError(w, http.StatusUnauthorized, requestID, CodeAuthMissingBearer, "missing bearer token")
|
|
return
|
|
}
|
|
|
|
claims, err := cfg.Verifier.Verify(r.Context(), rawToken)
|
|
if err != nil {
|
|
emitAudit(r.Context(), cfg.Auditor, AuditEvent{
|
|
EventName: EventTokenAuthnFail,
|
|
RequestID: requestID,
|
|
Route: r.URL.Path,
|
|
ResultCode: CodeAuthInvalidToken,
|
|
ClientIP: extractClientIP(r, cfg.TrustedProxies),
|
|
CreatedAt: cfg.Now(),
|
|
})
|
|
writeError(w, http.StatusUnauthorized, requestID, CodeAuthInvalidToken, "invalid bearer token")
|
|
return
|
|
}
|
|
|
|
tokenStatus, err := cfg.StatusResolver.Resolve(r.Context(), claims.TokenID)
|
|
if err != nil || tokenStatus != TokenStatusActive {
|
|
emitAudit(r.Context(), cfg.Auditor, AuditEvent{
|
|
EventName: EventTokenAuthnFail,
|
|
RequestID: requestID,
|
|
TokenID: claims.TokenID,
|
|
SubjectID: claims.SubjectID,
|
|
Route: r.URL.Path,
|
|
ResultCode: CodeAuthTokenInactive,
|
|
ClientIP: extractClientIP(r, cfg.TrustedProxies),
|
|
CreatedAt: cfg.Now(),
|
|
})
|
|
writeError(w, http.StatusUnauthorized, requestID, CodeAuthTokenInactive, "token is inactive")
|
|
return
|
|
}
|
|
|
|
if !cfg.Authorizer.Authorize(r.URL.Path, r.Method, claims.Scope, claims.Role) {
|
|
emitAudit(r.Context(), cfg.Auditor, AuditEvent{
|
|
EventName: EventTokenAuthzDenied,
|
|
RequestID: requestID,
|
|
TokenID: claims.TokenID,
|
|
SubjectID: claims.SubjectID,
|
|
Route: r.URL.Path,
|
|
ResultCode: CodeAuthScopeDenied,
|
|
ClientIP: extractClientIP(r, cfg.TrustedProxies),
|
|
CreatedAt: cfg.Now(),
|
|
})
|
|
writeError(w, http.StatusForbidden, requestID, CodeAuthScopeDenied, "scope denied")
|
|
return
|
|
}
|
|
|
|
principal := Principal{
|
|
RequestID: requestID,
|
|
TokenID: claims.TokenID,
|
|
SubjectID: claims.SubjectID,
|
|
Role: claims.Role,
|
|
Scope: append([]string(nil), claims.Scope...),
|
|
}
|
|
ctx := context.WithValue(r.Context(), principalKey, principal)
|
|
ctx = context.WithValue(ctx, requestIDKey, requestID)
|
|
|
|
emitAudit(ctx, cfg.Auditor, AuditEvent{
|
|
EventName: EventTokenAuthnSuccess,
|
|
RequestID: requestID,
|
|
TokenID: claims.TokenID,
|
|
SubjectID: claims.SubjectID,
|
|
Route: r.URL.Path,
|
|
ResultCode: "OK",
|
|
ClientIP: extractClientIP(r, cfg.TrustedProxies),
|
|
CreatedAt: cfg.Now(),
|
|
})
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
})
|
|
}
|
|
}
|
|
|
|
// RequestIDFromContext 从Context获取请求ID
|
|
func RequestIDFromContext(ctx context.Context) (string, bool) {
|
|
if ctx == nil {
|
|
return "", false
|
|
}
|
|
value, ok := ctx.Value(requestIDKey).(string)
|
|
return value, ok
|
|
}
|
|
|
|
// PrincipalFromContext 从Context获取认证主体
|
|
func PrincipalFromContext(ctx context.Context) (Principal, bool) {
|
|
if ctx == nil {
|
|
return Principal{}, false
|
|
}
|
|
value, ok := ctx.Value(principalKey).(Principal)
|
|
return value, ok
|
|
}
|
|
|
|
func (cfg AuthMiddlewareConfig) withDefaults() AuthMiddlewareConfig {
|
|
if cfg.Now == nil {
|
|
cfg.Now = defaultNowFunc
|
|
}
|
|
if len(cfg.ProtectedPrefixes) == 0 {
|
|
cfg.ProtectedPrefixes = []string{"/api/v1/supply", "/api/v1/platform"}
|
|
}
|
|
if len(cfg.ExcludedPrefixes) == 0 {
|
|
cfg.ExcludedPrefixes = []string{"/health", "/healthz", "/metrics", "/readyz"}
|
|
}
|
|
return cfg
|
|
}
|
|
|
|
func (cfg AuthMiddlewareConfig) shouldProtect(path string) bool {
|
|
for _, prefix := range cfg.ExcludedPrefixes {
|
|
if strings.HasPrefix(path, prefix) {
|
|
return false
|
|
}
|
|
}
|
|
for _, prefix := range cfg.ProtectedPrefixes {
|
|
if strings.HasPrefix(path, prefix) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func ensureRequestID(r *http.Request, now func() time.Time) string {
|
|
if now == nil {
|
|
now = defaultNowFunc
|
|
}
|
|
if requestID, ok := RequestIDFromContext(r.Context()); ok && requestID != "" {
|
|
return requestID
|
|
}
|
|
requestID := strings.TrimSpace(r.Header.Get(requestIDHeader))
|
|
if requestID == "" {
|
|
requestID = fmt.Sprintf("req-%d", now().UnixNano())
|
|
}
|
|
ctx := context.WithValue(r.Context(), requestIDKey, requestID)
|
|
*r = *r.WithContext(ctx)
|
|
return requestID
|
|
}
|
|
|
|
// extractBearerToken parses an Authorization header value of the form
// "Bearer <token>" and returns the token with surrounding whitespace removed.
//
// The scheme name is matched case-insensitively ("bearer", "BEARER", ...), as
// HTTP authentication schemes are case-insensitive (RFC 7235 §2.1). The
// previous exact "Bearer " match rejected valid headers from clients that
// lowercase the scheme. The scheme must be followed by a space.
//
// ok is false when the header uses another scheme, is malformed, or carries an
// empty token.
func extractBearerToken(authHeader string) (string, bool) {
	const scheme = "Bearer"
	if len(authHeader) <= len(scheme) ||
		!strings.EqualFold(authHeader[:len(scheme)], scheme) ||
		authHeader[len(scheme)] != ' ' {
		return "", false
	}
	token := strings.TrimSpace(authHeader[len(scheme)+1:])
	return token, token != ""
}
|
|
|
|
// hasExternalQueryKey reports whether r's query string contains any of the
// parameter names key, api_key, token or access_token, compared
// case-insensitively after URL decoding. A request without a URL reports false.
// Query pairs that url.ParseQuery cannot parse are silently ignored.
func hasExternalQueryKey(r *http.Request) bool {
	if r.URL == nil {
		return false
	}
	for name := range r.URL.Query() {
		switch strings.ToLower(name) {
		case "key", "api_key", "token", "access_token":
			return true
		}
	}
	return false
}
|
|
|
|
func emitAudit(ctx context.Context, auditor AuditEmitter, event AuditEvent) {
|
|
if auditor == nil {
|
|
return
|
|
}
|
|
_ = auditor.Emit(ctx, event)
|
|
}
|
|
|
|
// JSON error envelope written by writeError:
//
//	{"request_id": "...", "error": {"code": "...", "message": "...", "details": {...}}}
type (
	// errorResponse is the top-level error body.
	errorResponse struct {
		RequestID string       `json:"request_id"`
		Error     errorPayload `json:"error"`
	}

	// errorPayload describes the error itself. Details is omitted from the
	// JSON when it is nil or empty.
	errorPayload struct {
		Code    string         `json:"code"`
		Message string         `json:"message"`
		Details map[string]any `json:"details,omitempty"`
	}
)
|
|
|
|
func writeError(w http.ResponseWriter, status int, requestID, code, message string) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
w.WriteHeader(status)
|
|
payload := errorResponse{
|
|
RequestID: requestID,
|
|
Error: errorPayload{
|
|
Code: code,
|
|
Message: message,
|
|
},
|
|
}
|
|
_ = json.NewEncoder(w).Encode(payload)
|
|
}
|
|
|
|
// extractClientIP returns the best-effort client address for r.
//
// X-Forwarded-For is consulted only when the direct peer (the host part of
// r.RemoteAddr) is a trusted proxy. In that case all X-Forwarded-For header
// lines are combined and walked from right to left, skipping hops that are
// themselves trusted proxies, and the first untrusted hop is returned. Proxies
// append the address they received the connection from, so everything left of
// that hop was supplied by the client. Returning the leftmost entry, as this
// function used to do, let any client spoof its IP by sending its own
// X-Forwarded-For header through a trusted proxy. If every hop is trusted, the
// leftmost hop is returned.
//
// trustedProxies entries may be exact addresses (matched as strings, or as IPs
// when both sides parse) or CIDR ranges such as "10.0.0.0/8".
//
// When the peer is not trusted, or the header has no non-empty entries, the
// peer address is returned: the host part of RemoteAddr, or RemoteAddr itself
// when it has no port.
func extractClientIP(r *http.Request, trustedProxies []string) string {
	// isTrusted reports whether addr matches a trusted proxy entry.
	isTrusted := func(addr string) bool {
		ip := net.ParseIP(addr)
		for _, entry := range trustedProxies {
			entry = strings.TrimSpace(entry)
			if entry == "" {
				continue
			}
			if addr == entry {
				return true
			}
			if ip == nil {
				continue
			}
			if _, network, err := net.ParseCIDR(entry); err == nil {
				if network.Contains(ip) {
					return true
				}
			} else if proxyIP := net.ParseIP(entry); proxyIP != nil && proxyIP.Equal(ip) {
				return true
			}
		}
		return false
	}

	peer := r.RemoteAddr
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		peer = host
	}
	if !isTrusted(peer) {
		return peer
	}

	// Several header lines are equivalent to one comma-joined list.
	var hops []string
	for _, line := range r.Header.Values("X-Forwarded-For") {
		for _, part := range strings.Split(line, ",") {
			if hop := strings.TrimSpace(part); hop != "" {
				hops = append(hops, hop)
			}
		}
	}
	if len(hops) == 0 {
		return peer
	}

	for i := len(hops) - 1; i >= 0; i-- {
		if !isTrusted(hops[i]) {
			return hops[i]
		}
	}
	return hops[0]
}