feat(P1/P2): 完成TDD开发及P1/P2设计文档
## 设计文档 - multi_role_permission_design: 多角色权限设计 (CONDITIONAL GO) - audit_log_enhancement_design: 审计日志增强 (CONDITIONAL GO) - routing_strategy_template_design: 路由策略模板 (CONDITIONAL GO) - sso_saml_technical_research: SSO/SAML调研 (CONDITIONAL GO) - compliance_capability_package_design: 合规能力包设计 (CONDITIONAL GO) ## TDD开发成果 - IAM模块: supply-api/internal/iam/ (111个测试) - 审计日志模块: supply-api/internal/audit/ (40+测试) - 路由策略模块: gateway/internal/router/ (33+测试) - 合规能力包: gateway/internal/compliance/ + scripts/ci/compliance/ ## 规范文档 - parallel_agent_output_quality_standards: 并行Agent产出质量规范 - project_experience_summary: 项目经验总结 (v2) - 2026-04-02-p1-p2-tdd-execution-plan: TDD执行计划 ## 评审报告 - 5个CONDITIONAL GO设计文档评审报告 - fix_verification_report: 修复验证报告 - full_verification_report: 全面质量验证报告 - tdd_module_quality_verification: TDD模块质量验证 - tdd_execution_summary: TDD执行总结 依据: Superpowers执行框架 + TDD规范
This commit is contained in:
114
gateway/internal/middleware/audit.go
Normal file
114
gateway/internal/middleware/audit.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
)
|
||||
|
||||
// DatabaseAuditEmitter 实现 AuditEmitter 接口,将审计事件存入数据库
|
||||
type DatabaseAuditEmitter struct {
|
||||
db *sql.DB
|
||||
mu sync.RWMutex
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// NewDatabaseAuditEmitter 创建数据库审计发射器
|
||||
func NewDatabaseAuditEmitter(dsn string, now func() time.Time) (*DatabaseAuditEmitter, error) {
|
||||
if now == nil {
|
||||
now = time.Now
|
||||
}
|
||||
|
||||
db, err := sql.Open("pgx", dsn)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// 测试连接
|
||||
if err := db.Ping(); err != nil {
|
||||
return nil, fmt.Errorf("failed to ping database: %w", err)
|
||||
}
|
||||
|
||||
emitter := &DatabaseAuditEmitter{
|
||||
db: db,
|
||||
now: now,
|
||||
}
|
||||
|
||||
// 初始化表
|
||||
if err := emitter.initSchema(); err != nil {
|
||||
return nil, fmt.Errorf("failed to init schema: %w", err)
|
||||
}
|
||||
|
||||
return emitter, nil
|
||||
}
|
||||
|
||||
// initSchema 创建审计表
|
||||
func (e *DatabaseAuditEmitter) initSchema() error {
|
||||
schema := `
|
||||
CREATE TABLE IF NOT EXISTS token_audit_events (
|
||||
event_id VARCHAR(64) PRIMARY KEY,
|
||||
event_name VARCHAR(128) NOT NULL,
|
||||
request_id VARCHAR(128) NOT NULL,
|
||||
token_id VARCHAR(128),
|
||||
subject_id VARCHAR(128),
|
||||
route VARCHAR(256) NOT NULL,
|
||||
result_code VARCHAR(64) NOT NULL,
|
||||
client_ip VARCHAR(64),
|
||||
created_at TIMESTAMP NOT NULL DEFAULT NOW()
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_token_audit_request_id ON token_audit_events(request_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_token_audit_token_id ON token_audit_events(token_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_token_audit_subject_id ON token_audit_events(subject_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_token_audit_created_at ON token_audit_events(created_at);
|
||||
`
|
||||
_, err := e.db.Exec(schema)
|
||||
return err
|
||||
}
|
||||
|
||||
// Emit 实现 AuditEmitter 接口
|
||||
func (e *DatabaseAuditEmitter) Emit(_ context.Context, event AuditEvent) error {
|
||||
if event.EventID == "" {
|
||||
event.EventID = fmt.Sprintf("evt-%d", e.now().UnixNano())
|
||||
}
|
||||
if event.CreatedAt.IsZero() {
|
||||
event.CreatedAt = e.now()
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO token_audit_events (event_id, event_name, request_id, token_id, subject_id, route, result_code, client_ip, created_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
`
|
||||
_, err := e.db.Exec(query,
|
||||
event.EventID,
|
||||
event.EventName,
|
||||
event.RequestID,
|
||||
nullString(event.TokenID),
|
||||
nullString(event.SubjectID),
|
||||
event.Route,
|
||||
event.ResultCode,
|
||||
nullString(event.ClientIP),
|
||||
event.CreatedAt,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// Close 关闭数据库连接
|
||||
func (e *DatabaseAuditEmitter) Close() error {
|
||||
if e.db != nil {
|
||||
return e.db.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// nullString 安全处理空字符串
|
||||
func nullString(s string) sql.NullString {
|
||||
if s == "" {
|
||||
return sql.NullString{}
|
||||
}
|
||||
return sql.NullString{String: s, Valid: true}
|
||||
}
|
||||
311
gateway/internal/middleware/chain.go
Normal file
311
gateway/internal/middleware/chain.go
Normal file
@@ -0,0 +1,311 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const requestIDHeader = "X-Request-Id"
|
||||
|
||||
var defaultNowFunc = time.Now
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
requestIDKey contextKey = "request_id"
|
||||
principalKey contextKey = "principal"
|
||||
)
|
||||
|
||||
// Principal 认证成功后的主体信息
|
||||
type Principal struct {
|
||||
RequestID string
|
||||
TokenID string
|
||||
SubjectID string
|
||||
Role string
|
||||
Scope []string
|
||||
}
|
||||
|
||||
// BuildTokenAuthChain 构建认证中间件链
|
||||
func BuildTokenAuthChain(cfg AuthMiddlewareConfig, next http.Handler) http.Handler {
|
||||
handler := tokenAuthMiddleware(cfg)(next)
|
||||
handler = queryKeyRejectMiddleware(handler, cfg.Auditor, cfg.Now)
|
||||
handler = requestIDMiddleware(handler, cfg.Now)
|
||||
return handler
|
||||
}
|
||||
|
||||
// RequestIDMiddleware 请求ID中间件
|
||||
func requestIDMiddleware(next http.Handler, now func() time.Time) http.Handler {
|
||||
if next == nil {
|
||||
return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
||||
}
|
||||
if now == nil {
|
||||
now = defaultNowFunc
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestID := ensureRequestID(r, now)
|
||||
w.Header().Set(requestIDHeader, requestID)
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// queryKeyRejectMiddleware 拒绝query key入站
|
||||
func queryKeyRejectMiddleware(next http.Handler, auditor AuditEmitter, now func() time.Time) http.Handler {
|
||||
if next == nil {
|
||||
return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
||||
}
|
||||
if now == nil {
|
||||
now = defaultNowFunc
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if hasExternalQueryKey(r) {
|
||||
requestID, _ := RequestIDFromContext(r.Context())
|
||||
emitAudit(r.Context(), auditor, AuditEvent{
|
||||
EventName: EventTokenQueryKeyRejected,
|
||||
RequestID: requestID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: CodeQueryKeyNotAllowed,
|
||||
ClientIP: extractClientIP(r),
|
||||
CreatedAt: now(),
|
||||
})
|
||||
writeError(w, http.StatusUnauthorized, requestID, CodeQueryKeyNotAllowed, "query key not allowed")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// tokenAuthMiddleware Token认证中间件
|
||||
func tokenAuthMiddleware(cfg AuthMiddlewareConfig) func(http.Handler) http.Handler {
|
||||
cfg = cfg.withDefaults()
|
||||
return func(next http.Handler) http.Handler {
|
||||
if next == nil {
|
||||
next = http.HandlerFunc(func(http.ResponseWriter, *http.Request) {})
|
||||
}
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if !cfg.shouldProtect(r.URL.Path) {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
requestID := ensureRequestID(r, cfg.Now)
|
||||
if cfg.Verifier == nil || cfg.StatusResolver == nil || cfg.Authorizer == nil {
|
||||
writeError(w, http.StatusServiceUnavailable, requestID, CodeAuthNotReady, "auth middleware dependencies are not ready")
|
||||
return
|
||||
}
|
||||
|
||||
rawToken, ok := extractBearerToken(r.Header.Get("Authorization"))
|
||||
if !ok {
|
||||
emitAudit(r.Context(), cfg.Auditor, AuditEvent{
|
||||
EventName: EventTokenAuthnFail,
|
||||
RequestID: requestID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: CodeAuthMissingBearer,
|
||||
ClientIP: extractClientIP(r),
|
||||
CreatedAt: cfg.Now(),
|
||||
})
|
||||
writeError(w, http.StatusUnauthorized, requestID, CodeAuthMissingBearer, "missing bearer token")
|
||||
return
|
||||
}
|
||||
|
||||
claims, err := cfg.Verifier.Verify(r.Context(), rawToken)
|
||||
if err != nil {
|
||||
emitAudit(r.Context(), cfg.Auditor, AuditEvent{
|
||||
EventName: EventTokenAuthnFail,
|
||||
RequestID: requestID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: CodeAuthInvalidToken,
|
||||
ClientIP: extractClientIP(r),
|
||||
CreatedAt: cfg.Now(),
|
||||
})
|
||||
writeError(w, http.StatusUnauthorized, requestID, CodeAuthInvalidToken, "invalid bearer token")
|
||||
return
|
||||
}
|
||||
|
||||
tokenStatus, err := cfg.StatusResolver.Resolve(r.Context(), claims.TokenID)
|
||||
if err != nil || tokenStatus != TokenStatusActive {
|
||||
emitAudit(r.Context(), cfg.Auditor, AuditEvent{
|
||||
EventName: EventTokenAuthnFail,
|
||||
RequestID: requestID,
|
||||
TokenID: claims.TokenID,
|
||||
SubjectID: claims.SubjectID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: CodeAuthTokenInactive,
|
||||
ClientIP: extractClientIP(r),
|
||||
CreatedAt: cfg.Now(),
|
||||
})
|
||||
writeError(w, http.StatusUnauthorized, requestID, CodeAuthTokenInactive, "token is inactive")
|
||||
return
|
||||
}
|
||||
|
||||
if !cfg.Authorizer.Authorize(r.URL.Path, r.Method, claims.Scope, claims.Role) {
|
||||
emitAudit(r.Context(), cfg.Auditor, AuditEvent{
|
||||
EventName: EventTokenAuthzDenied,
|
||||
RequestID: requestID,
|
||||
TokenID: claims.TokenID,
|
||||
SubjectID: claims.SubjectID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: CodeAuthScopeDenied,
|
||||
ClientIP: extractClientIP(r),
|
||||
CreatedAt: cfg.Now(),
|
||||
})
|
||||
writeError(w, http.StatusForbidden, requestID, CodeAuthScopeDenied, "scope denied")
|
||||
return
|
||||
}
|
||||
|
||||
principal := Principal{
|
||||
RequestID: requestID,
|
||||
TokenID: claims.TokenID,
|
||||
SubjectID: claims.SubjectID,
|
||||
Role: claims.Role,
|
||||
Scope: append([]string(nil), claims.Scope...),
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), principalKey, principal)
|
||||
ctx = context.WithValue(ctx, requestIDKey, requestID)
|
||||
|
||||
emitAudit(ctx, cfg.Auditor, AuditEvent{
|
||||
EventName: EventTokenAuthnSuccess,
|
||||
RequestID: requestID,
|
||||
TokenID: claims.TokenID,
|
||||
SubjectID: claims.SubjectID,
|
||||
Route: r.URL.Path,
|
||||
ResultCode: "OK",
|
||||
ClientIP: extractClientIP(r),
|
||||
CreatedAt: cfg.Now(),
|
||||
})
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// RequestIDFromContext 从Context获取请求ID
|
||||
func RequestIDFromContext(ctx context.Context) (string, bool) {
|
||||
if ctx == nil {
|
||||
return "", false
|
||||
}
|
||||
value, ok := ctx.Value(requestIDKey).(string)
|
||||
return value, ok
|
||||
}
|
||||
|
||||
// PrincipalFromContext 从Context获取认证主体
|
||||
func PrincipalFromContext(ctx context.Context) (Principal, bool) {
|
||||
if ctx == nil {
|
||||
return Principal{}, false
|
||||
}
|
||||
value, ok := ctx.Value(principalKey).(Principal)
|
||||
return value, ok
|
||||
}
|
||||
|
||||
func (cfg AuthMiddlewareConfig) withDefaults() AuthMiddlewareConfig {
|
||||
if cfg.Now == nil {
|
||||
cfg.Now = defaultNowFunc
|
||||
}
|
||||
if len(cfg.ProtectedPrefixes) == 0 {
|
||||
cfg.ProtectedPrefixes = []string{"/api/v1/supply", "/api/v1/platform"}
|
||||
}
|
||||
if len(cfg.ExcludedPrefixes) == 0 {
|
||||
cfg.ExcludedPrefixes = []string{"/health", "/healthz", "/metrics", "/readyz"}
|
||||
}
|
||||
return cfg
|
||||
}
|
||||
|
||||
func (cfg AuthMiddlewareConfig) shouldProtect(path string) bool {
|
||||
for _, prefix := range cfg.ExcludedPrefixes {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
for _, prefix := range cfg.ProtectedPrefixes {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func ensureRequestID(r *http.Request, now func() time.Time) string {
|
||||
if now == nil {
|
||||
now = defaultNowFunc
|
||||
}
|
||||
if requestID, ok := RequestIDFromContext(r.Context()); ok && requestID != "" {
|
||||
return requestID
|
||||
}
|
||||
requestID := strings.TrimSpace(r.Header.Get(requestIDHeader))
|
||||
if requestID == "" {
|
||||
requestID = fmt.Sprintf("req-%d", now().UnixNano())
|
||||
}
|
||||
ctx := context.WithValue(r.Context(), requestIDKey, requestID)
|
||||
*r = *r.WithContext(ctx)
|
||||
return requestID
|
||||
}
|
||||
|
||||
func extractBearerToken(authHeader string) (string, bool) {
|
||||
const bearerPrefix = "Bearer "
|
||||
if !strings.HasPrefix(authHeader, bearerPrefix) {
|
||||
return "", false
|
||||
}
|
||||
token := strings.TrimSpace(strings.TrimPrefix(authHeader, bearerPrefix))
|
||||
return token, token != ""
|
||||
}
|
||||
|
||||
func hasExternalQueryKey(r *http.Request) bool {
|
||||
if r.URL == nil {
|
||||
return false
|
||||
}
|
||||
query := r.URL.Query()
|
||||
for key := range query {
|
||||
lowerKey := strings.ToLower(key)
|
||||
if lowerKey == "key" || lowerKey == "api_key" || lowerKey == "token" || lowerKey == "access_token" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func emitAudit(ctx context.Context, auditor AuditEmitter, event AuditEvent) {
|
||||
if auditor == nil {
|
||||
return
|
||||
}
|
||||
_ = auditor.Emit(ctx, event)
|
||||
}
|
||||
|
||||
type errorResponse struct {
|
||||
RequestID string `json:"request_id"`
|
||||
Error errorPayload `json:"error"`
|
||||
}
|
||||
|
||||
type errorPayload struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Details map[string]any `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
func writeError(w http.ResponseWriter, status int, requestID, code, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
payload := errorResponse{
|
||||
RequestID: requestID,
|
||||
Error: errorPayload{
|
||||
Code: code,
|
||||
Message: message,
|
||||
},
|
||||
}
|
||||
_ = json.NewEncoder(w).Encode(payload)
|
||||
}
|
||||
|
||||
func extractClientIP(r *http.Request) string {
|
||||
xForwardedFor := strings.TrimSpace(r.Header.Get("X-Forwarded-For"))
|
||||
if xForwardedFor != "" {
|
||||
parts := strings.Split(xForwardedFor, ",")
|
||||
return strings.TrimSpace(parts[0])
|
||||
}
|
||||
host, _, err := net.SplitHostPort(r.RemoteAddr)
|
||||
if err == nil {
|
||||
return host
|
||||
}
|
||||
return r.RemoteAddr
|
||||
}
|
||||
856
gateway/internal/middleware/middleware_test.go
Normal file
856
gateway/internal/middleware/middleware_test.go
Normal file
@@ -0,0 +1,856 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestExtractBearerToken(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
authHeader string
|
||||
wantToken string
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
name: "valid bearer token",
|
||||
authHeader: "Bearer test-token-123",
|
||||
wantToken: "test-token-123",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "valid bearer token with extra spaces",
|
||||
authHeader: "Bearer test-token-456 ",
|
||||
wantToken: "test-token-456",
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "missing bearer prefix",
|
||||
authHeader: "test-token-123",
|
||||
wantToken: "",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "empty bearer token",
|
||||
authHeader: "Bearer ",
|
||||
wantToken: "",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "empty header",
|
||||
authHeader: "",
|
||||
wantToken: "",
|
||||
wantOK: false,
|
||||
},
|
||||
{
|
||||
name: "case sensitive bearer",
|
||||
authHeader: "bearer test-token",
|
||||
wantToken: "",
|
||||
wantOK: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
token, ok := extractBearerToken(tt.authHeader)
|
||||
if token != tt.wantToken {
|
||||
t.Errorf("extractBearerToken() token = %v, want %v", token, tt.wantToken)
|
||||
}
|
||||
if ok != tt.wantOK {
|
||||
t.Errorf("extractBearerToken() ok = %v, want %v", ok, tt.wantOK)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasExternalQueryKey(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "has key param",
|
||||
query: "?key=abc123",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "has api_key param",
|
||||
query: "?api_key=abc123",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "has token param",
|
||||
query: "?token=abc123",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "has access_token param",
|
||||
query: "?access_token=abc123",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "has other param",
|
||||
query: "?name=test&value=123",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "no params",
|
||||
query: "",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "case insensitive key",
|
||||
query: "?KEY=abc123",
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := httptest.NewRequest("GET", "/test"+tt.query, nil)
|
||||
if got := hasExternalQueryKey(req); got != tt.want {
|
||||
t.Errorf("hasExternalQueryKey() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestIDMiddleware(t *testing.T) {
|
||||
t.Run("generates request ID when not present", func(t *testing.T) {
|
||||
var capturedReqID string
|
||||
handler := requestIDMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedReqID, _ = RequestIDFromContext(r.Context())
|
||||
}), time.Now)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if capturedReqID == "" {
|
||||
t.Error("expected request ID to be set in context")
|
||||
}
|
||||
if rr.Header().Get("X-Request-Id") == "" {
|
||||
t.Error("expected X-Request-Id header to be set in response")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("uses existing request ID from header", func(t *testing.T) {
|
||||
existingID := "existing-req-id-123"
|
||||
var capturedID string
|
||||
handler := requestIDMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
capturedID = r.Header.Get("X-Request-Id")
|
||||
}), time.Now)
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req.Header.Set("X-Request-Id", existingID)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if capturedID != existingID {
|
||||
t.Errorf("expected request ID %q, got %q", existingID, capturedID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("nil next handler does not panic", func(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
t.Errorf("panic with nil next handler: %v", r)
|
||||
}
|
||||
}()
|
||||
handler := requestIDMiddleware(nil, time.Now)
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
})
|
||||
}
|
||||
|
||||
func TestQueryKeyRejectMiddleware(t *testing.T) {
|
||||
t.Run("rejects request with query key", func(t *testing.T) {
|
||||
auditCalled := false
|
||||
auditor := &mockAuditEmitter{
|
||||
onEmit: func(ctx context.Context, event AuditEvent) error {
|
||||
auditCalled = true
|
||||
if event.EventName != EventTokenQueryKeyRejected {
|
||||
t.Errorf("expected event %s, got %s", EventTokenQueryKeyRejected, event.EventName)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
handler := queryKeyRejectMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("next handler should not be called")
|
||||
}), auditor, time.Now)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/supply?key=abc123", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if !auditCalled {
|
||||
t.Error("expected audit to be called")
|
||||
}
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("allows request without query key", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
handler := queryKeyRejectMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
}), nil, time.Now)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/supply?name=test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Error("expected next handler to be called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects api_key parameter", func(t *testing.T) {
|
||||
handler := queryKeyRejectMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("next handler should not be called")
|
||||
}), nil, time.Now)
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/supply?api_key=secret", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenAuthMiddleware(t *testing.T) {
|
||||
t.Run("allows request when all checks pass", func(t *testing.T) {
|
||||
now := time.Now()
|
||||
tokenRuntime := NewInMemoryTokenRuntime(func() time.Time { return now })
|
||||
|
||||
// Issue a valid token
|
||||
token, err := tokenRuntime.Issue(context.Background(), "user1", "admin", []string{"supply:read", "supply:write"}, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to issue token: %v", err)
|
||||
}
|
||||
|
||||
cfg := AuthMiddlewareConfig{
|
||||
Verifier: tokenRuntime,
|
||||
StatusResolver: tokenRuntime,
|
||||
Authorizer: NewScopeRoleAuthorizer(),
|
||||
ProtectedPrefixes: []string{"/api/v1/supply"},
|
||||
ExcludedPrefixes: []string{"/health"},
|
||||
Now: func() time.Time { return now },
|
||||
}
|
||||
|
||||
nextCalled := false
|
||||
handler := tokenAuthMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
// Verify principal is set in context
|
||||
principal, ok := PrincipalFromContext(r.Context())
|
||||
if !ok {
|
||||
t.Error("expected principal in context")
|
||||
}
|
||||
if principal.SubjectID != "user1" {
|
||||
t.Errorf("expected subject user1, got %s", principal.SubjectID)
|
||||
}
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/supply", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Error("expected next handler to be called")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects request without bearer token", func(t *testing.T) {
|
||||
cfg := AuthMiddlewareConfig{
|
||||
Verifier: &mockVerifier{},
|
||||
StatusResolver: &mockStatusResolver{},
|
||||
Authorizer: NewScopeRoleAuthorizer(),
|
||||
ProtectedPrefixes: []string{"/api/v1/supply"},
|
||||
Now: time.Now,
|
||||
}
|
||||
|
||||
handler := tokenAuthMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("next handler should not be called")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/supply", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("rejects request to excluded path", func(t *testing.T) {
|
||||
cfg := AuthMiddlewareConfig{
|
||||
Verifier: &mockVerifier{},
|
||||
StatusResolver: &mockStatusResolver{},
|
||||
Authorizer: NewScopeRoleAuthorizer(),
|
||||
ProtectedPrefixes: []string{"/api/v1/supply"},
|
||||
ExcludedPrefixes: []string{"/health"},
|
||||
Now: time.Now,
|
||||
}
|
||||
|
||||
nextCalled := false
|
||||
handler := tokenAuthMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/health", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Error("expected next handler to be called for excluded path")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("returns 503 when dependencies not ready", func(t *testing.T) {
|
||||
cfg := AuthMiddlewareConfig{
|
||||
Verifier: nil,
|
||||
StatusResolver: nil,
|
||||
Authorizer: nil,
|
||||
ProtectedPrefixes: []string{"/api/v1/supply"},
|
||||
Now: time.Now,
|
||||
}
|
||||
|
||||
handler := tokenAuthMiddleware(cfg)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("next handler should not be called")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/supply", nil)
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("expected status 503, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestScopeRoleAuthorizer(t *testing.T) {
|
||||
authorizer := NewScopeRoleAuthorizer()
|
||||
|
||||
t.Run("admin role has access to all", func(t *testing.T) {
|
||||
if !authorizer.Authorize("/api/v1/supply", "POST", []string{}, "admin") {
|
||||
t.Error("expected admin to have access")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("supply read scope for GET", func(t *testing.T) {
|
||||
if !authorizer.Authorize("/api/v1/supply", "GET", []string{"supply:read"}, "user") {
|
||||
t.Error("expected supply:read to have access to GET")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("supply write scope for POST", func(t *testing.T) {
|
||||
if !authorizer.Authorize("/api/v1/supply", "POST", []string{"supply:write"}, "user") {
|
||||
t.Error("expected supply:write to have access to POST")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("supply:read scope is denied for POST", func(t *testing.T) {
|
||||
// supply:read only allows GET, POST should be denied
|
||||
if authorizer.Authorize("/api/v1/supply", "POST", []string{"supply:read"}, "user") {
|
||||
t.Error("expected supply:read to be denied for POST")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("wildcard scope works", func(t *testing.T) {
|
||||
if !authorizer.Authorize("/api/v1/supply", "POST", []string{"supply:*"}, "user") {
|
||||
t.Error("expected supply:* to have access")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("platform admin scope", func(t *testing.T) {
|
||||
if !authorizer.Authorize("/api/v1/platform/users", "GET", []string{"platform:admin"}, "user") {
|
||||
t.Error("expected platform:admin to have access")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestInMemoryTokenRuntime(t *testing.T) {
|
||||
now := time.Now()
|
||||
runtime := NewInMemoryTokenRuntime(func() time.Time { return now })
|
||||
|
||||
t.Run("issue and verify token", func(t *testing.T) {
|
||||
token, err := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read"}, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to issue token: %v", err)
|
||||
}
|
||||
if token == "" {
|
||||
t.Error("expected non-empty token")
|
||||
}
|
||||
|
||||
claims, err := runtime.Verify(context.Background(), token)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to verify token: %v", err)
|
||||
}
|
||||
if claims.SubjectID != "user1" {
|
||||
t.Errorf("expected subject user1, got %s", claims.SubjectID)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("resolve token status", func(t *testing.T) {
|
||||
token, err := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read"}, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to issue token: %v", err)
|
||||
}
|
||||
|
||||
// Get token ID first
|
||||
claims, _ := runtime.Verify(context.Background(), token)
|
||||
|
||||
status, err := runtime.Resolve(context.Background(), claims.TokenID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to resolve status: %v", err)
|
||||
}
|
||||
if status != TokenStatusActive {
|
||||
t.Errorf("expected status active, got %s", status)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("revoke token", func(t *testing.T) {
|
||||
token, _ := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read"}, time.Hour)
|
||||
claims, _ := runtime.Verify(context.Background(), token)
|
||||
|
||||
err := runtime.Revoke(context.Background(), claims.TokenID)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to revoke token: %v", err)
|
||||
}
|
||||
|
||||
status, _ := runtime.Resolve(context.Background(), claims.TokenID)
|
||||
if status != TokenStatusRevoked {
|
||||
t.Errorf("expected status revoked, got %s", status)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("verify invalid token", func(t *testing.T) {
|
||||
_, err := runtime.Verify(context.Background(), "invalid-token")
|
||||
if err == nil {
|
||||
t.Error("expected error for invalid token")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildTokenAuthChain(t *testing.T) {
|
||||
now := time.Now()
|
||||
runtime := NewInMemoryTokenRuntime(func() time.Time { return now })
|
||||
token, _ := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read", "supply:write"}, time.Hour)
|
||||
|
||||
cfg := AuthMiddlewareConfig{
|
||||
Verifier: runtime,
|
||||
StatusResolver: runtime,
|
||||
Authorizer: NewScopeRoleAuthorizer(),
|
||||
ProtectedPrefixes: []string{"/api/v1/supply", "/api/v1/platform"},
|
||||
ExcludedPrefixes: []string{"/health", "/healthz"},
|
||||
Now: func() time.Time { return now },
|
||||
}
|
||||
|
||||
t.Run("full chain with valid token", func(t *testing.T) {
|
||||
nextCalled := false
|
||||
handler := BuildTokenAuthChain(cfg, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
nextCalled = true
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/supply", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
recorder := httptest.NewRecorder()
|
||||
handler.ServeHTTP(recorder, req)
|
||||
|
||||
if !nextCalled {
|
||||
t.Error("expected chain to complete successfully")
|
||||
}
|
||||
if recorder.Header().Get("X-Request-Id") == "" {
|
||||
t.Error("expected X-Request-Id header to be set by chain")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("full chain rejects query key", func(t *testing.T) {
|
||||
handler := BuildTokenAuthChain(cfg, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
t.Error("next handler should not be called")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/api/v1/supply?key=blocked", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Errorf("expected status 401, got %d", rr.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Mock implementations
|
||||
type mockVerifier struct{}
|
||||
|
||||
func (m *mockVerifier) Verify(ctx context.Context, rawToken string) (VerifiedToken, error) {
|
||||
return VerifiedToken{}, nil
|
||||
}
|
||||
|
||||
type mockStatusResolver struct{}
|
||||
|
||||
func (m *mockStatusResolver) Resolve(ctx context.Context, tokenID string) (TokenStatus, error) {
|
||||
return TokenStatusActive, nil
|
||||
}
|
||||
|
||||
type mockAuditEmitter struct {
|
||||
onEmit func(ctx context.Context, event AuditEvent) error
|
||||
}
|
||||
|
||||
func (m *mockAuditEmitter) Emit(ctx context.Context, event AuditEvent) error {
|
||||
if m.onEmit != nil {
|
||||
return m.onEmit(ctx, event)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestHasScope(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
scopes []string
|
||||
required string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "exact match",
|
||||
scopes: []string{"supply:read", "supply:write"},
|
||||
required: "supply:read",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "no match",
|
||||
scopes: []string{"supply:read"},
|
||||
required: "supply:write",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "wildcard match",
|
||||
scopes: []string{"supply:*"},
|
||||
required: "supply:read",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "wildcard match write",
|
||||
scopes: []string{"supply:*"},
|
||||
required: "supply:write",
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "empty scopes",
|
||||
scopes: []string{},
|
||||
required: "supply:read",
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "partial wildcard no match",
|
||||
scopes: []string{"supply:read"},
|
||||
required: "platform:admin",
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := hasScope(tt.scopes, tt.required)
|
||||
if got != tt.want {
|
||||
t.Errorf("hasScope(%v, %s) = %v, want %v", tt.scopes, tt.required, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequiredScopeForRoute(t *testing.T) {
|
||||
tests := []struct {
|
||||
path string
|
||||
method string
|
||||
want string
|
||||
}{
|
||||
{"/api/v1/supply", "GET", "supply:read"},
|
||||
{"/api/v1/supply", "HEAD", "supply:read"},
|
||||
{"/api/v1/supply", "OPTIONS", "supply:read"},
|
||||
{"/api/v1/supply", "POST", "supply:write"},
|
||||
{"/api/v1/supply", "PUT", "supply:write"},
|
||||
{"/api/v1/supply", "DELETE", "supply:write"},
|
||||
{"/api/v1/supply/", "GET", "supply:read"},
|
||||
{"/api/v1/supply/123", "GET", "supply:read"},
|
||||
{"/api/v1/platform", "GET", "platform:admin"},
|
||||
{"/api/v1/platform", "POST", "platform:admin"},
|
||||
{"/api/v1/platform/", "DELETE", "platform:admin"},
|
||||
{"/api/v1/platform/users", "GET", "platform:admin"},
|
||||
{"/unknown", "GET", ""},
|
||||
{"/api/v1/other", "GET", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path+"_"+tt.method, func(t *testing.T) {
|
||||
got := requiredScopeForRoute(tt.path, tt.method)
|
||||
if got != tt.want {
|
||||
t.Errorf("requiredScopeForRoute(%s, %s) = %s, want %s", tt.path, tt.method, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateAccessToken(t *testing.T) {
|
||||
token, err := generateAccessToken()
|
||||
if err != nil {
|
||||
t.Fatalf("generateAccessToken() error = %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(token, "ptk_") {
|
||||
t.Errorf("expected token to start with ptk_, got %s", token)
|
||||
}
|
||||
if len(token) < 10 {
|
||||
t.Errorf("expected token length >= 10, got %d", len(token))
|
||||
}
|
||||
|
||||
// 生成多个token应该不同
|
||||
token2, _ := generateAccessToken()
|
||||
if token == token2 {
|
||||
t.Error("expected different tokens")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateTokenID(t *testing.T) {
|
||||
tokenID, err := generateTokenID()
|
||||
if err != nil {
|
||||
t.Fatalf("generateTokenID() error = %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(tokenID, "tok_") {
|
||||
t.Errorf("expected token ID to start with tok_, got %s", tokenID)
|
||||
}
|
||||
|
||||
tokenID2, _ := generateTokenID()
|
||||
if tokenID == tokenID2 {
|
||||
t.Error("expected different token IDs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGenerateEventID(t *testing.T) {
|
||||
eventID, err := generateEventID()
|
||||
if err != nil {
|
||||
t.Fatalf("generateEventID() error = %v", err)
|
||||
}
|
||||
if !strings.HasPrefix(eventID, "evt_") {
|
||||
t.Errorf("expected event ID to start with evt_, got %s", eventID)
|
||||
}
|
||||
|
||||
eventID2, _ := generateEventID()
|
||||
if eventID == eventID2 {
|
||||
t.Error("expected different event IDs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullString(t *testing.T) {
|
||||
tests := []struct {
|
||||
input string
|
||||
wantStr string
|
||||
wantValid bool
|
||||
}{
|
||||
{"hello", "hello", true},
|
||||
{"", "", false},
|
||||
{"world", "world", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := nullString(tt.input)
|
||||
if got.String != tt.wantStr {
|
||||
t.Errorf("nullString(%q).String = %q, want %q", tt.input, got.String, tt.wantStr)
|
||||
}
|
||||
if got.Valid != tt.wantValid {
|
||||
t.Errorf("nullString(%q).Valid = %v, want %v", tt.input, got.Valid, tt.wantValid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInMemoryTokenRuntime_Issue_Errors(t *testing.T) {
|
||||
now := time.Now()
|
||||
runtime := NewInMemoryTokenRuntime(func() time.Time { return now })
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
subjectID string
|
||||
role string
|
||||
scopes []string
|
||||
ttl time.Duration
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "empty subject_id",
|
||||
subjectID: "",
|
||||
role: "admin",
|
||||
scopes: []string{"supply:read"},
|
||||
ttl: time.Hour,
|
||||
wantErr: "subject_id is required",
|
||||
},
|
||||
{
|
||||
name: "whitespace subject_id",
|
||||
subjectID: " ",
|
||||
role: "admin",
|
||||
scopes: []string{"supply:read"},
|
||||
ttl: time.Hour,
|
||||
wantErr: "subject_id is required",
|
||||
},
|
||||
{
|
||||
name: "empty role",
|
||||
subjectID: "user1",
|
||||
role: "",
|
||||
scopes: []string{"supply:read"},
|
||||
ttl: time.Hour,
|
||||
wantErr: "role is required",
|
||||
},
|
||||
{
|
||||
name: "empty scopes",
|
||||
subjectID: "user1",
|
||||
role: "admin",
|
||||
scopes: []string{},
|
||||
ttl: time.Hour,
|
||||
wantErr: "scope must not be empty",
|
||||
},
|
||||
{
|
||||
name: "zero ttl",
|
||||
subjectID: "user1",
|
||||
role: "admin",
|
||||
scopes: []string{"supply:read"},
|
||||
ttl: 0,
|
||||
wantErr: "ttl must be positive",
|
||||
},
|
||||
{
|
||||
name: "negative ttl",
|
||||
subjectID: "user1",
|
||||
role: "admin",
|
||||
scopes: []string{"supply:read"},
|
||||
ttl: -time.Second,
|
||||
wantErr: "ttl must be positive",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := runtime.Issue(context.Background(), tt.subjectID, tt.role, tt.scopes, tt.ttl)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
if err.Error() != tt.wantErr {
|
||||
t.Errorf("error = %q, want %q", err.Error(), tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInMemoryTokenRuntime_Verify_Expired(t *testing.T) {
|
||||
now := time.Now()
|
||||
runtime := NewInMemoryTokenRuntime(func() time.Time { return now })
|
||||
|
||||
token, _ := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read"}, time.Hour)
|
||||
|
||||
// 验证token仍然有效
|
||||
claims, err := runtime.Verify(context.Background(), token)
|
||||
if err != nil {
|
||||
t.Fatalf("Verify failed: %v", err)
|
||||
}
|
||||
if claims.SubjectID != "user1" {
|
||||
t.Errorf("SubjectID = %s, want user1", claims.SubjectID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInMemoryTokenRuntime_ApplyExpiry(t *testing.T) {
|
||||
now := time.Now()
|
||||
runtime := NewInMemoryTokenRuntime(func() time.Time { return now })
|
||||
|
||||
token, _ := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read"}, time.Hour)
|
||||
claims, _ := runtime.Verify(context.Background(), token)
|
||||
|
||||
// 手动设置过期
|
||||
runtime.mu.Lock()
|
||||
record := runtime.records[claims.TokenID]
|
||||
record.ExpiresAt = now.Add(-time.Hour) // 1小时前过期
|
||||
runtime.mu.Unlock()
|
||||
|
||||
// Resolve应该检测到过期
|
||||
status, _ := runtime.Resolve(context.Background(), claims.TokenID)
|
||||
if status != TokenStatusExpired {
|
||||
t.Errorf("status = %s, want Expired", status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScopeRoleAuthorizer_Authorize(t *testing.T) {
|
||||
authorizer := NewScopeRoleAuthorizer()
|
||||
|
||||
tests := []struct {
|
||||
path string
|
||||
method string
|
||||
scopes []string
|
||||
role string
|
||||
want bool
|
||||
}{
|
||||
{"/api/v1/supply", "GET", []string{"supply:read"}, "user", true},
|
||||
{"/api/v1/supply", "POST", []string{"supply:write"}, "user", true},
|
||||
{"/api/v1/supply", "DELETE", []string{"supply:read"}, "user", false},
|
||||
{"/api/v1/supply", "GET", []string{}, "admin", true},
|
||||
{"/api/v1/supply", "POST", []string{}, "admin", true},
|
||||
{"/api/v1/other", "GET", []string{}, "user", true}, // 无需权限
|
||||
{"/api/v1/platform/users", "GET", []string{"platform:admin"}, "user", true},
|
||||
{"/api/v1/platform/users", "POST", []string{"platform:admin"}, "user", true},
|
||||
{"/api/v1/platform/users", "DELETE", []string{"supply:read"}, "user", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.path+"_"+tt.method, func(t *testing.T) {
|
||||
got := authorizer.Authorize(tt.path, tt.method, tt.scopes, tt.role)
|
||||
if got != tt.want {
|
||||
t.Errorf("Authorize(%s, %s, %v, %s) = %v, want %v", tt.path, tt.method, tt.scopes, tt.role, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMemoryAuditEmitter(t *testing.T) {
|
||||
emitter := NewMemoryAuditEmitter()
|
||||
|
||||
event := AuditEvent{
|
||||
EventName: EventTokenQueryKeyRejected,
|
||||
RequestID: "req-123",
|
||||
Route: "/api/v1/supply",
|
||||
ResultCode: "401",
|
||||
}
|
||||
|
||||
err := emitter.Emit(context.Background(), event)
|
||||
if err != nil {
|
||||
t.Fatalf("Emit failed: %v", err)
|
||||
}
|
||||
|
||||
if len(emitter.events) != 1 {
|
||||
t.Errorf("expected 1 event, got %d", len(emitter.events))
|
||||
}
|
||||
|
||||
if emitter.events[0].EventID == "" {
|
||||
t.Error("expected EventID to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewInMemoryTokenRuntime_NilNow(t *testing.T) {
|
||||
// 不传入now函数,应该使用默认的time.Now
|
||||
runtime := NewInMemoryTokenRuntime(nil)
|
||||
if runtime == nil {
|
||||
t.Fatal("expected non-nil runtime")
|
||||
}
|
||||
|
||||
// 验证基本功能
|
||||
_, err := runtime.Issue(context.Background(), "user1", "admin", []string{"supply:read"}, time.Hour)
|
||||
if err != nil {
|
||||
t.Fatalf("Issue failed: %v", err)
|
||||
}
|
||||
}
|
||||
239
gateway/internal/middleware/runtime.go
Normal file
239
gateway/internal/middleware/runtime.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// InMemoryTokenRuntime 内存中的Token运行时实现
|
||||
type InMemoryTokenRuntime struct {
|
||||
mu sync.RWMutex
|
||||
now func() time.Time
|
||||
records map[string]*tokenRecord
|
||||
tokenToID map[string]string
|
||||
}
|
||||
|
||||
type tokenRecord struct {
|
||||
TokenID string
|
||||
AccessToken string
|
||||
SubjectID string
|
||||
Role string
|
||||
Scope []string
|
||||
IssuedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
Status TokenStatus
|
||||
}
|
||||
|
||||
// NewInMemoryTokenRuntime 创建内存Token运行时
|
||||
func NewInMemoryTokenRuntime(now func() time.Time) *InMemoryTokenRuntime {
|
||||
if now == nil {
|
||||
now = time.Now
|
||||
}
|
||||
return &InMemoryTokenRuntime{
|
||||
now: now,
|
||||
records: make(map[string]*tokenRecord),
|
||||
tokenToID: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
// Issue 颁发Token
|
||||
func (r *InMemoryTokenRuntime) Issue(_ context.Context, subjectID, role string, scopes []string, ttl time.Duration) (string, error) {
|
||||
if strings.TrimSpace(subjectID) == "" {
|
||||
return "", errors.New("subject_id is required")
|
||||
}
|
||||
if strings.TrimSpace(role) == "" {
|
||||
return "", errors.New("role is required")
|
||||
}
|
||||
if len(scopes) == 0 {
|
||||
return "", errors.New("scope must not be empty")
|
||||
}
|
||||
if ttl <= 0 {
|
||||
return "", errors.New("ttl must be positive")
|
||||
}
|
||||
|
||||
issuedAt := r.now()
|
||||
tokenID, _ := generateTokenID()
|
||||
accessToken, _ := generateAccessToken()
|
||||
|
||||
record := &tokenRecord{
|
||||
TokenID: tokenID,
|
||||
AccessToken: accessToken,
|
||||
SubjectID: subjectID,
|
||||
Role: role,
|
||||
Scope: append([]string(nil), scopes...),
|
||||
IssuedAt: issuedAt,
|
||||
ExpiresAt: issuedAt.Add(ttl),
|
||||
Status: TokenStatusActive,
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
r.records[tokenID] = record
|
||||
r.tokenToID[accessToken] = tokenID
|
||||
r.mu.Unlock()
|
||||
|
||||
return accessToken, nil
|
||||
}
|
||||
|
||||
// Verify 验证Token
|
||||
func (r *InMemoryTokenRuntime) Verify(_ context.Context, rawToken string) (VerifiedToken, error) {
|
||||
r.mu.RLock()
|
||||
tokenID, ok := r.tokenToID[rawToken]
|
||||
if !ok {
|
||||
r.mu.RUnlock()
|
||||
return VerifiedToken{}, errors.New("token not found")
|
||||
}
|
||||
record, ok := r.records[tokenID]
|
||||
if !ok {
|
||||
r.mu.RUnlock()
|
||||
return VerifiedToken{}, errors.New("token record not found")
|
||||
}
|
||||
claims := VerifiedToken{
|
||||
TokenID: record.TokenID,
|
||||
SubjectID: record.SubjectID,
|
||||
Role: record.Role,
|
||||
Scope: append([]string(nil), record.Scope...),
|
||||
IssuedAt: record.IssuedAt,
|
||||
ExpiresAt: record.ExpiresAt,
|
||||
}
|
||||
r.mu.RUnlock()
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// Resolve 解析Token状态
|
||||
func (r *InMemoryTokenRuntime) Resolve(_ context.Context, tokenID string) (TokenStatus, error) {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
record, ok := r.records[tokenID]
|
||||
if !ok {
|
||||
return "", errors.New("token not found")
|
||||
}
|
||||
r.applyExpiry(record)
|
||||
return record.Status, nil
|
||||
}
|
||||
|
||||
// Revoke 吊销Token
|
||||
func (r *InMemoryTokenRuntime) Revoke(_ context.Context, tokenID string) error {
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
record, ok := r.records[tokenID]
|
||||
if !ok {
|
||||
return errors.New("token not found")
|
||||
}
|
||||
record.Status = TokenStatusRevoked
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *InMemoryTokenRuntime) applyExpiry(record *tokenRecord) {
|
||||
if record == nil {
|
||||
return
|
||||
}
|
||||
if record.Status == TokenStatusActive && !record.ExpiresAt.IsZero() && !r.now().Before(record.ExpiresAt) {
|
||||
record.Status = TokenStatusExpired
|
||||
}
|
||||
}
|
||||
|
||||
// ScopeRoleAuthorizer 基于Scope和Role的授权器
|
||||
type ScopeRoleAuthorizer struct{}
|
||||
|
||||
func NewScopeRoleAuthorizer() *ScopeRoleAuthorizer {
|
||||
return &ScopeRoleAuthorizer{}
|
||||
}
|
||||
|
||||
func (a *ScopeRoleAuthorizer) Authorize(path, method string, scopes []string, role string) bool {
|
||||
if role == "admin" {
|
||||
return true
|
||||
}
|
||||
|
||||
requiredScope := requiredScopeForRoute(path, method)
|
||||
if requiredScope == "" {
|
||||
return true
|
||||
}
|
||||
return hasScope(scopes, requiredScope)
|
||||
}
|
||||
|
||||
func requiredScopeForRoute(path, method string) string {
|
||||
// Handle /api/v1/supply (with or without trailing slash)
|
||||
if path == "/api/v1/supply" || strings.HasPrefix(path, "/api/v1/supply/") {
|
||||
switch method {
|
||||
case "GET", "HEAD", "OPTIONS":
|
||||
return "supply:read"
|
||||
default:
|
||||
return "supply:write"
|
||||
}
|
||||
}
|
||||
// Handle /api/v1/platform (with or without trailing slash)
|
||||
if path == "/api/v1/platform" || strings.HasPrefix(path, "/api/v1/platform/") {
|
||||
return "platform:admin"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func hasScope(scopes []string, required string) bool {
|
||||
for _, scope := range scopes {
|
||||
if scope == required {
|
||||
return true
|
||||
}
|
||||
if strings.HasSuffix(scope, ":*") {
|
||||
prefix := strings.TrimSuffix(scope, ":*")
|
||||
if strings.HasPrefix(required, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MemoryAuditEmitter 内存审计发射器
|
||||
type MemoryAuditEmitter struct {
|
||||
mu sync.RWMutex
|
||||
events []AuditEvent
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
func NewMemoryAuditEmitter() *MemoryAuditEmitter {
|
||||
return &MemoryAuditEmitter{now: time.Now}
|
||||
}
|
||||
|
||||
func (e *MemoryAuditEmitter) Emit(_ context.Context, event AuditEvent) error {
|
||||
if event.EventID == "" {
|
||||
event.EventID, _ = generateEventID()
|
||||
}
|
||||
if event.CreatedAt.IsZero() {
|
||||
event.CreatedAt = e.now()
|
||||
}
|
||||
e.mu.Lock()
|
||||
e.events = append(e.events, event)
|
||||
e.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateAccessToken() (string, error) {
|
||||
var entropy [16]byte
|
||||
if _, err := rand.Read(entropy[:]); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "ptk_" + hex.EncodeToString(entropy[:]), nil
|
||||
}
|
||||
|
||||
func generateTokenID() (string, error) {
|
||||
var entropy [8]byte
|
||||
if _, err := rand.Read(entropy[:]); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "tok_" + hex.EncodeToString(entropy[:]), nil
|
||||
}
|
||||
|
||||
func generateEventID() (string, error) {
|
||||
var entropy [8]byte
|
||||
if _, err := rand.Read(entropy[:]); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return "evt_" + hex.EncodeToString(entropy[:]), nil
|
||||
}
|
||||
90
gateway/internal/middleware/types.go
Normal file
90
gateway/internal/middleware/types.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// 认证常量
|
||||
const (
|
||||
CodeAuthMissingBearer = "AUTH_MISSING_BEARER"
|
||||
CodeQueryKeyNotAllowed = "QUERY_KEY_NOT_ALLOWED"
|
||||
CodeAuthInvalidToken = "AUTH_INVALID_TOKEN"
|
||||
CodeAuthTokenInactive = "AUTH_TOKEN_INACTIVE"
|
||||
CodeAuthScopeDenied = "AUTH_SCOPE_DENIED"
|
||||
CodeAuthNotReady = "AUTH_NOT_READY"
|
||||
)
|
||||
|
||||
// 审计事件常量
|
||||
const (
|
||||
EventTokenAuthnSuccess = "token.authn.success"
|
||||
EventTokenAuthnFail = "token.authn.fail"
|
||||
EventTokenAuthzDenied = "token.authz.denied"
|
||||
EventTokenQueryKeyRejected = "token.query_key.rejected"
|
||||
)
|
||||
|
||||
// TokenStatus Token状态
|
||||
type TokenStatus string
|
||||
|
||||
const (
|
||||
TokenStatusActive TokenStatus = "active"
|
||||
TokenStatusRevoked TokenStatus = "revoked"
|
||||
TokenStatusExpired TokenStatus = "expired"
|
||||
)
|
||||
|
||||
// VerifiedToken 验证后的Token声明
|
||||
type VerifiedToken struct {
|
||||
TokenID string
|
||||
SubjectID string
|
||||
Role string
|
||||
Scope []string
|
||||
IssuedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
NotBefore time.Time
|
||||
Issuer string
|
||||
Audience string
|
||||
}
|
||||
|
||||
// TokenVerifier Token验证器接口
|
||||
type TokenVerifier interface {
|
||||
Verify(ctx context.Context, rawToken string) (VerifiedToken, error)
|
||||
}
|
||||
|
||||
// TokenStatusResolver Token状态解析器接口
|
||||
type TokenStatusResolver interface {
|
||||
Resolve(ctx context.Context, tokenID string) (TokenStatus, error)
|
||||
}
|
||||
|
||||
// RouteAuthorizer 路由授权器接口
|
||||
type RouteAuthorizer interface {
|
||||
Authorize(path, method string, scopes []string, role string) bool
|
||||
}
|
||||
|
||||
// AuditEvent 审计事件
|
||||
type AuditEvent struct {
|
||||
EventID string
|
||||
EventName string
|
||||
RequestID string
|
||||
TokenID string
|
||||
SubjectID string
|
||||
Route string
|
||||
ResultCode string
|
||||
ClientIP string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// AuditEmitter 审计事件发射器接口
|
||||
type AuditEmitter interface {
|
||||
Emit(ctx context.Context, event AuditEvent) error
|
||||
}
|
||||
|
||||
// AuthMiddlewareConfig 认证中间件配置
|
||||
type AuthMiddlewareConfig struct {
|
||||
Verifier TokenVerifier
|
||||
StatusResolver TokenStatusResolver
|
||||
Authorizer RouteAuthorizer
|
||||
Auditor AuditEmitter
|
||||
ProtectedPrefixes []string
|
||||
ExcludedPrefixes []string
|
||||
Now func() time.Time
|
||||
}
|
||||
Reference in New Issue
Block a user