From c06cacff0d949e3fd51aa9f1dda7e14422ec3a48 Mon Sep 17 00:00:00 2001 From: Your Name Date: Fri, 17 Apr 2026 17:56:59 +0800 Subject: [PATCH] refactor(token-runtime): abstract runtime and audit stores --- .../internal/app/bootstrap.go | 6 +- .../internal/app/bootstrap_test.go | 22 +++ .../internal/auth/service/inmemory_runtime.go | 142 +++++++++++++----- .../internal/auth/service/runtime_store.go | 52 +++++-- .../auth/service/runtime_store_test.go | 24 ++- .../auth/service/store_contract_test.go | 17 +++ .../internal/auth/service/token_verifier.go | 17 +++ 7 files changed, 222 insertions(+), 58 deletions(-) create mode 100644 platform-token-runtime/internal/auth/service/store_contract_test.go diff --git a/platform-token-runtime/internal/app/bootstrap.go b/platform-token-runtime/internal/app/bootstrap.go index 59f8f1d5..7f773e43 100644 --- a/platform-token-runtime/internal/app/bootstrap.go +++ b/platform-token-runtime/internal/app/bootstrap.go @@ -13,12 +13,12 @@ import ( type Config struct { Addr string Env string - RuntimeStore *service.InMemoryRuntimeStore - AuditStore *service.MemoryAuditStore + RuntimeStore service.RuntimeStore + AuditStore service.AuditStore Now func() time.Time } -func BuildRuntime(cfg Config) (*service.InMemoryTokenRuntime, *service.MemoryAuditStore, error) { +func BuildRuntime(cfg Config) (*service.InMemoryTokenRuntime, service.AuditStore, error) { now := cfg.Now if now == nil { now = time.Now diff --git a/platform-token-runtime/internal/app/bootstrap_test.go b/platform-token-runtime/internal/app/bootstrap_test.go index 67d3b880..14d480f6 100644 --- a/platform-token-runtime/internal/app/bootstrap_test.go +++ b/platform-token-runtime/internal/app/bootstrap_test.go @@ -6,6 +6,8 @@ import ( "strings" "testing" "time" + + "lijiaoqiao/platform-token-runtime/internal/auth/service" ) func TestBuildRuntime_ProdRequiresConcreteStore(t *testing.T) { @@ -31,6 +33,26 @@ func TestBuildRuntime_DevUsesInMemoryDefaults(t *testing.T) { } } +func TestBuildRuntime_ProdAcceptsStoreContracts(t *testing.T) { + runtimeStore := service.RuntimeStore(service.NewInMemoryRuntimeStore()) + auditStore := service.AuditStore(service.NewMemoryAuditStore()) + + runtime, auditor, err := BuildRuntime(Config{ + Env: "prod", + RuntimeStore: runtimeStore, + AuditStore: auditStore, + }) + if err != nil { + t.Fatalf("BuildRuntime returned error: %v", err) + } + if runtime == nil { + t.Fatal("expected runtime") + } + if auditor == nil { + t.Fatal("expected audit store") + } +} + func TestBuildServer_HealthEndpoint(t *testing.T) { srv, err := BuildServer(Config{ Env: "dev", diff --git a/platform-token-runtime/internal/auth/service/inmemory_runtime.go b/platform-token-runtime/internal/auth/service/inmemory_runtime.go index ee6800fc..276d894f 100644 --- a/platform-token-runtime/internal/auth/service/inmemory_runtime.go +++ b/platform-token-runtime/internal/auth/service/inmemory_runtime.go @@ -40,19 +40,14 @@ type IssueTokenInput struct { type InMemoryTokenRuntime struct { mu sync.RWMutex now func() time.Time - store *InMemoryRuntimeStore -} - -type idempotencyEntry struct { - RequestHash string - TokenID string + store RuntimeStore } func NewInMemoryTokenRuntime(now func() time.Time) *InMemoryTokenRuntime { return NewInMemoryTokenRuntimeWithStore(now, NewInMemoryRuntimeStore()) } -func NewInMemoryTokenRuntimeWithStore(now func() time.Time, store *InMemoryRuntimeStore) *InMemoryTokenRuntime { +func NewInMemoryTokenRuntimeWithStore(now func() time.Time, store RuntimeStore) *InMemoryTokenRuntime { if now == nil { now = time.Now } @@ -65,7 +60,7 @@ func NewInMemoryTokenRuntimeWithStore(now func() time.Time, store *InMemoryRunti } } -func (r *InMemoryTokenRuntime) Issue(_ context.Context, input IssueTokenInput) (TokenRecord, error) { +func (r *InMemoryTokenRuntime) Issue(ctx context.Context, input IssueTokenInput) (TokenRecord, error) { if strings.TrimSpace(input.SubjectID) == "" { return TokenRecord{}, errors.New("subject_id is required") } @@ -104,90 +99,156 @@ func (r *InMemoryTokenRuntime) Issue(_ context.Context, input IssueTokenInput) ( RevokedReason: "", } + if ctx == nil { + ctx = context.Background() + } r.mu.Lock() if idempotencyKey != "" { - entry, ok := r.store.LookupIdempotency(idempotencyKey) + entry, ok, err := r.store.LookupIdempotency(ctx, idempotencyKey) + if err != nil { + r.mu.Unlock() + return TokenRecord{}, err + } if ok { if entry.RequestHash != requestHash { r.mu.Unlock() return TokenRecord{}, errors.New("idempotency key payload mismatch") } - existing, exists := r.store.GetByTokenID(entry.TokenID) + existing, exists, err := r.store.GetByTokenID(ctx, entry.TokenID) + if err != nil { + r.mu.Unlock() + return TokenRecord{}, err + } if exists { r.mu.Unlock() return cloneRecord(*existing), nil } } } - r.store.Save(record, idempotencyKey, requestHash) + if err := r.store.Save(ctx, record, idempotencyKey, requestHash); err != nil { + r.mu.Unlock() + return TokenRecord{}, err + } r.mu.Unlock() return record, nil } -func (r *InMemoryTokenRuntime) Refresh(_ context.Context, tokenID string, ttl time.Duration) (TokenRecord, error) { +func (r *InMemoryTokenRuntime) Refresh(ctx context.Context, tokenID string, ttl time.Duration) (TokenRecord, error) { if ttl <= 0 { return TokenRecord{}, errors.New("ttl must be positive") } + if ctx == nil { + ctx = context.Background() + } r.mu.Lock() defer r.mu.Unlock() - record, ok := r.store.GetByTokenID(tokenID) + record, ok, err := r.store.GetByTokenID(ctx, tokenID) + if err != nil { + return TokenRecord{}, err + } if !ok { return TokenRecord{}, errors.New("token not found") } - r.applyExpiry(record) + if r.applyExpiry(record) { + if err := r.store.Save(ctx, *record, "", ""); err != nil { + return TokenRecord{}, err + } + } if record.Status != TokenStatusActive { return TokenRecord{}, errors.New("token is not active") } record.ExpiresAt = r.now().Add(ttl) - r.store.Save(*record, "", "") + if err := r.store.Save(ctx, *record, "", ""); err != nil { + return TokenRecord{}, err + } return cloneRecord(*record), nil } -func (r *InMemoryTokenRuntime) Revoke(_ context.Context, tokenID, reason string) (TokenRecord, error) { +func (r *InMemoryTokenRuntime) Revoke(ctx context.Context, tokenID, reason string) (TokenRecord, error) { + if ctx == nil { + ctx = context.Background() + } r.mu.Lock() defer r.mu.Unlock() - record, ok := r.store.GetByTokenID(tokenID) + record, ok, err := r.store.GetByTokenID(ctx, tokenID) + if err != nil { + return TokenRecord{}, err + } if !ok { return TokenRecord{}, errors.New("token not found") } - r.applyExpiry(record) + if r.applyExpiry(record) { + if err := r.store.Save(ctx, *record, "", ""); err != nil { + return TokenRecord{}, err + } + } record.Status = TokenStatusRevoked record.RevokedReason = strings.TrimSpace(reason) + if err := r.store.Save(ctx, *record, "", ""); err != nil { + return TokenRecord{}, err + } return cloneRecord(*record), nil } -func (r *InMemoryTokenRuntime) Introspect(_ context.Context, accessToken string) (TokenRecord, error) { +func (r *InMemoryTokenRuntime) Introspect(ctx context.Context, accessToken string) (TokenRecord, error) { + if ctx == nil { + ctx = context.Background() + } r.mu.Lock() defer r.mu.Unlock() - record, ok := r.store.GetByAccessToken(accessToken) + record, ok, err := r.store.GetByAccessToken(ctx, accessToken) + if err != nil { + return TokenRecord{}, err + } if !ok { return TokenRecord{}, errors.New("token not found") } - r.applyExpiry(record) + if r.applyExpiry(record) { + if err := r.store.Save(ctx, *record, "", ""); err != nil { + return TokenRecord{}, err + } + } return cloneRecord(*record), nil } -func (r *InMemoryTokenRuntime) Lookup(_ context.Context, tokenID string) (TokenRecord, error) { +func (r *InMemoryTokenRuntime) Lookup(ctx context.Context, tokenID string) (TokenRecord, error) { + if ctx == nil { + ctx = context.Background() + } r.mu.Lock() defer r.mu.Unlock() - record, ok := r.store.GetByTokenID(tokenID) + record, ok, err := r.store.GetByTokenID(ctx, tokenID) + if err != nil { + return TokenRecord{}, err + } if !ok { return TokenRecord{}, errors.New("token not found") } - r.applyExpiry(record) + if r.applyExpiry(record) { + if err := r.store.Save(ctx, *record, "", ""); err != nil { + return TokenRecord{}, err + } + } return cloneRecord(*record), nil } -func (r *InMemoryTokenRuntime) Verify(_ context.Context, rawToken string) (VerifiedToken, error) { +func (r *InMemoryTokenRuntime) Verify(ctx context.Context, rawToken string) (VerifiedToken, error) { + if ctx == nil { + ctx = context.Background() + } r.mu.RLock() - record, ok := r.store.GetByAccessToken(rawToken) + record, ok, err := r.store.GetByAccessToken(ctx, rawToken) + if err != nil { + r.mu.RUnlock() + return VerifiedToken{}, NewAuthError(CodeAuthInvalidToken, err) + } if !ok { r.mu.RUnlock() return VerifiedToken{}, NewAuthError(CodeAuthInvalidToken, errors.New("token not found")) @@ -204,22 +265,33 @@ func (r *InMemoryTokenRuntime) Verify(_ context.Context, rawToken string) (Verif return claims, nil } -func (r *InMemoryTokenRuntime) Resolve(_ context.Context, tokenID string) (TokenStatus, error) { +func (r *InMemoryTokenRuntime) Resolve(ctx context.Context, tokenID string) (TokenStatus, error) { + if ctx == nil { + ctx = context.Background() + } r.mu.Lock() defer r.mu.Unlock() - record, ok := r.store.GetByTokenID(tokenID) + record, ok, err := r.store.GetByTokenID(ctx, tokenID) + if err != nil { + return "", err + } if !ok { return "", NewAuthError(CodeAuthInvalidToken, errors.New("token not found")) } - r.applyExpiry(record) + if r.applyExpiry(record) { + if err := r.store.Save(ctx, *record, "", ""); err != nil { + return "", err + } + } return record.Status, nil } func (r *InMemoryTokenRuntime) TokenCount() int { - r.mu.RLock() - defer r.mu.RUnlock() - return r.store.TokenCount() + if counter, ok := r.store.(interface{ TokenCount() int }); ok { + return counter.TokenCount() + } + return 0 } func (r *InMemoryTokenRuntime) IssueAndAudit(ctx context.Context, input IssueTokenInput, auditor AuditEmitter) (TokenRecord, error) { @@ -269,13 +341,15 @@ func (r *InMemoryTokenRuntime) RevokeAndAudit(ctx context.Context, tokenID, reas return record, nil } -func (r *InMemoryTokenRuntime) applyExpiry(record *TokenRecord) { +func (r *InMemoryTokenRuntime) applyExpiry(record *TokenRecord) bool { if record == nil { - return + return false } if record.Status == TokenStatusActive && !record.ExpiresAt.IsZero() && !r.now().Before(record.ExpiresAt) { record.Status = TokenStatusExpired + return true } + return false } func cloneRecord(record TokenRecord) TokenRecord { diff --git a/platform-token-runtime/internal/auth/service/runtime_store.go b/platform-token-runtime/internal/auth/service/runtime_store.go index c7847448..0dc0d3cd 100644 --- a/platform-token-runtime/internal/auth/service/runtime_store.go +++ b/platform-token-runtime/internal/auth/service/runtime_store.go @@ -1,57 +1,77 @@ package service -import "sync" +import ( + "context" + "sync" +) type InMemoryRuntimeStore struct { - mu sync.RWMutex - records map[string]*TokenRecord - tokenToID map[string]string - idempotencyByKey map[string]idempotencyEntry + mu sync.RWMutex + records map[string]*TokenRecord + tokenToID map[string]string + idempotencyByKey map[string]IdempotencyEntry } func NewInMemoryRuntimeStore() *InMemoryRuntimeStore { return &InMemoryRuntimeStore{ records: make(map[string]*TokenRecord), tokenToID: make(map[string]string), - idempotencyByKey: make(map[string]idempotencyEntry), + idempotencyByKey: make(map[string]IdempotencyEntry), } } -func (s *InMemoryRuntimeStore) Save(record TokenRecord, idempotencyKey, requestHash string) { +func (s *InMemoryRuntimeStore) Save(_ context.Context, record TokenRecord, idempotencyKey, requestHash string) error { s.mu.Lock() defer s.mu.Unlock() recordCopy := cloneRecord(record) s.records[record.TokenID] = &recordCopy - s.tokenToID[record.AccessToken] = record.TokenID + if record.AccessToken != "" { + s.tokenToID[record.AccessToken] = record.TokenID + } if idempotencyKey != "" { - s.idempotencyByKey[idempotencyKey] = idempotencyEntry{ + s.idempotencyByKey[idempotencyKey] = IdempotencyEntry{ RequestHash: requestHash, TokenID: record.TokenID, } } + return nil } -func (s *InMemoryRuntimeStore) GetByTokenID(tokenID string) (*TokenRecord, bool) { +func (s *InMemoryRuntimeStore) GetByTokenID(_ context.Context, tokenID string) (*TokenRecord, bool, error) { s.mu.RLock() defer s.mu.RUnlock() record, ok := s.records[tokenID] - return record, ok + if !ok { + return nil, false, nil + } + cloned := cloneRecord(*record) + return &cloned, true, nil } -func (s *InMemoryRuntimeStore) GetByAccessToken(accessToken string) (*TokenRecord, bool) { +func (s *InMemoryRuntimeStore) GetByAccessToken(_ context.Context, accessToken string) (*TokenRecord, bool, error) { + s.mu.RLock() + defer s.mu.RUnlock() tokenID, ok := s.tokenToID[accessToken] if !ok { - return nil, false + return nil, false, nil } record, ok := s.records[tokenID] - return record, ok + if !ok { + return nil, false, nil + } + cloned := cloneRecord(*record) + return &cloned, true, nil } -func (s *InMemoryRuntimeStore) LookupIdempotency(idempotencyKey string) (idempotencyEntry, bool) { +func (s *InMemoryRuntimeStore) LookupIdempotency(_ context.Context, idempotencyKey string) (IdempotencyEntry, bool, error) { + s.mu.RLock() + defer s.mu.RUnlock() entry, ok := s.idempotencyByKey[idempotencyKey] - return entry, ok + return entry, ok, nil } func (s *InMemoryRuntimeStore) TokenCount() int { + s.mu.RLock() + defer s.mu.RUnlock() return len(s.records) } diff --git a/platform-token-runtime/internal/auth/service/runtime_store_test.go b/platform-token-runtime/internal/auth/service/runtime_store_test.go index 77e91c25..c0712745 100644 --- a/platform-token-runtime/internal/auth/service/runtime_store_test.go +++ b/platform-token-runtime/internal/auth/service/runtime_store_test.go @@ -1,6 +1,9 @@ package service -import "testing" +import ( + "context" + "testing" +) func TestInMemoryRuntimeStore_SaveAndLookup(t *testing.T) { store := NewInMemoryRuntimeStore() @@ -12,9 +15,14 @@ func TestInMemoryRuntimeStore_SaveAndLookup(t *testing.T) { Scope: []string{"supply:*"}, } - store.Save(record, "idem-1", "hash-1") + if err := store.Save(context.Background(), record, "idem-1", "hash-1"); err != nil { + t.Fatalf("save record: %v", err) + } - byID, ok := store.GetByTokenID("tok_123") + byID, ok, err := store.GetByTokenID(context.Background(), "tok_123") + if err != nil { + t.Fatalf("get by token id: %v", err) + } if !ok { t.Fatal("expected record by token id") } @@ -22,7 +30,10 @@ func TestInMemoryRuntimeStore_SaveAndLookup(t *testing.T) { t.Fatalf("unexpected token id: %s", byID.TokenID) } - byToken, ok := store.GetByAccessToken("ptk_123") + byToken, ok, err := store.GetByAccessToken(context.Background(), "ptk_123") + if err != nil { + t.Fatalf("get by access token: %v", err) + } if !ok { t.Fatal("expected record by access token") } @@ -30,7 +41,10 @@ func TestInMemoryRuntimeStore_SaveAndLookup(t *testing.T) { t.Fatalf("unexpected subject id: %s", byToken.SubjectID) } - entry, ok := store.LookupIdempotency("idem-1") + entry, ok, err := store.LookupIdempotency(context.Background(), "idem-1") + if err != nil { + t.Fatalf("lookup idempotency: %v", err) + } if !ok { t.Fatal("expected idempotency entry") } diff --git a/platform-token-runtime/internal/auth/service/store_contract_test.go b/platform-token-runtime/internal/auth/service/store_contract_test.go new file mode 100644 index 00000000..b6d174bc --- /dev/null +++ b/platform-token-runtime/internal/auth/service/store_contract_test.go @@ -0,0 +1,17 @@ +package service + +import "testing" + +func TestInMemoryStoresImplementContracts(t *testing.T) { + t.Helper() + + var runtimeStore RuntimeStore = NewInMemoryRuntimeStore() + var auditStore AuditStore = NewMemoryAuditStore() + + if runtimeStore == nil { + t.Fatal("expected runtime store contract implementation") + } + if auditStore == nil { + t.Fatal("expected audit store contract implementation") + } +} diff --git a/platform-token-runtime/internal/auth/service/token_verifier.go b/platform-token-runtime/internal/auth/service/token_verifier.go index e6d8b104..ec8a2d51 100644 --- a/platform-token-runtime/internal/auth/service/token_verifier.go +++ b/platform-token-runtime/internal/auth/service/token_verifier.go @@ -79,6 +79,16 @@ type AuditEmitter interface { Emit(ctx context.Context, event AuditEvent) error } +type AuditStore interface { + AuditEmitter + AuditEventQuerier +} + +type IdempotencyEntry struct { + RequestHash string + TokenID string +} + type AuditEventFilter struct { RequestID string TokenID string @@ -92,6 +102,13 @@ type AuditEventQuerier interface { QueryEvents(ctx context.Context, filter AuditEventFilter) ([]AuditEvent, error) } +type RuntimeStore interface { + Save(ctx context.Context, record TokenRecord, idempotencyKey, requestHash string) error + GetByTokenID(ctx context.Context, tokenID string) (*TokenRecord, bool, error) + GetByAccessToken(ctx context.Context, accessToken string) (*TokenRecord, bool, error) + LookupIdempotency(ctx context.Context, idempotencyKey string) (IdempotencyEntry, bool, error) +} + type AuthError struct { Code string Cause error