refactor(token-runtime): abstract runtime and audit stores
This commit is contained in:
@@ -13,12 +13,12 @@ import (
|
||||
type Config struct {
|
||||
Addr string
|
||||
Env string
|
||||
RuntimeStore *service.InMemoryRuntimeStore
|
||||
AuditStore *service.MemoryAuditStore
|
||||
RuntimeStore service.RuntimeStore
|
||||
AuditStore service.AuditStore
|
||||
Now func() time.Time
|
||||
}
|
||||
|
||||
func BuildRuntime(cfg Config) (*service.InMemoryTokenRuntime, *service.MemoryAuditStore, error) {
|
||||
func BuildRuntime(cfg Config) (*service.InMemoryTokenRuntime, service.AuditStore, error) {
|
||||
now := cfg.Now
|
||||
if now == nil {
|
||||
now = time.Now
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/platform-token-runtime/internal/auth/service"
|
||||
)
|
||||
|
||||
func TestBuildRuntime_ProdRequiresConcreteStore(t *testing.T) {
|
||||
@@ -31,6 +33,26 @@ func TestBuildRuntime_DevUsesInMemoryDefaults(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildRuntime_ProdAcceptsStoreContracts(t *testing.T) {
|
||||
runtimeStore := service.RuntimeStore(service.NewInMemoryRuntimeStore())
|
||||
auditStore := service.AuditStore(service.NewMemoryAuditStore())
|
||||
|
||||
runtime, auditor, err := BuildRuntime(Config{
|
||||
Env: "prod",
|
||||
RuntimeStore: runtimeStore,
|
||||
AuditStore: auditStore,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("BuildRuntime returned error: %v", err)
|
||||
}
|
||||
if runtime == nil {
|
||||
t.Fatal("expected runtime")
|
||||
}
|
||||
if auditor == nil {
|
||||
t.Fatal("expected audit store")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildServer_HealthEndpoint(t *testing.T) {
|
||||
srv, err := BuildServer(Config{
|
||||
Env: "dev",
|
||||
|
||||
@@ -40,19 +40,14 @@ type IssueTokenInput struct {
|
||||
type InMemoryTokenRuntime struct {
|
||||
mu sync.RWMutex
|
||||
now func() time.Time
|
||||
store *InMemoryRuntimeStore
|
||||
}
|
||||
|
||||
type idempotencyEntry struct {
|
||||
RequestHash string
|
||||
TokenID string
|
||||
store RuntimeStore
|
||||
}
|
||||
|
||||
func NewInMemoryTokenRuntime(now func() time.Time) *InMemoryTokenRuntime {
|
||||
return NewInMemoryTokenRuntimeWithStore(now, NewInMemoryRuntimeStore())
|
||||
}
|
||||
|
||||
func NewInMemoryTokenRuntimeWithStore(now func() time.Time, store *InMemoryRuntimeStore) *InMemoryTokenRuntime {
|
||||
func NewInMemoryTokenRuntimeWithStore(now func() time.Time, store RuntimeStore) *InMemoryTokenRuntime {
|
||||
if now == nil {
|
||||
now = time.Now
|
||||
}
|
||||
@@ -65,7 +60,7 @@ func NewInMemoryTokenRuntimeWithStore(now func() time.Time, store *InMemoryRunti
|
||||
}
|
||||
}
|
||||
|
||||
func (r *InMemoryTokenRuntime) Issue(_ context.Context, input IssueTokenInput) (TokenRecord, error) {
|
||||
func (r *InMemoryTokenRuntime) Issue(ctx context.Context, input IssueTokenInput) (TokenRecord, error) {
|
||||
if strings.TrimSpace(input.SubjectID) == "" {
|
||||
return TokenRecord{}, errors.New("subject_id is required")
|
||||
}
|
||||
@@ -104,90 +99,156 @@ func (r *InMemoryTokenRuntime) Issue(_ context.Context, input IssueTokenInput) (
|
||||
RevokedReason: "",
|
||||
}
|
||||
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
r.mu.Lock()
|
||||
if idempotencyKey != "" {
|
||||
entry, ok := r.store.LookupIdempotency(idempotencyKey)
|
||||
entry, ok, err := r.store.LookupIdempotency(ctx, idempotencyKey)
|
||||
if err != nil {
|
||||
r.mu.Unlock()
|
||||
return TokenRecord{}, err
|
||||
}
|
||||
if ok {
|
||||
if entry.RequestHash != requestHash {
|
||||
r.mu.Unlock()
|
||||
return TokenRecord{}, errors.New("idempotency key payload mismatch")
|
||||
}
|
||||
existing, exists := r.store.GetByTokenID(entry.TokenID)
|
||||
existing, exists, err := r.store.GetByTokenID(ctx, entry.TokenID)
|
||||
if err != nil {
|
||||
r.mu.Unlock()
|
||||
return TokenRecord{}, err
|
||||
}
|
||||
if exists {
|
||||
r.mu.Unlock()
|
||||
return cloneRecord(*existing), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
r.store.Save(record, idempotencyKey, requestHash)
|
||||
if err := r.store.Save(ctx, record, idempotencyKey, requestHash); err != nil {
|
||||
r.mu.Unlock()
|
||||
return TokenRecord{}, err
|
||||
}
|
||||
r.mu.Unlock()
|
||||
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (r *InMemoryTokenRuntime) Refresh(_ context.Context, tokenID string, ttl time.Duration) (TokenRecord, error) {
|
||||
func (r *InMemoryTokenRuntime) Refresh(ctx context.Context, tokenID string, ttl time.Duration) (TokenRecord, error) {
|
||||
if ttl <= 0 {
|
||||
return TokenRecord{}, errors.New("ttl must be positive")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
record, ok := r.store.GetByTokenID(tokenID)
|
||||
record, ok, err := r.store.GetByTokenID(ctx, tokenID)
|
||||
if err != nil {
|
||||
return TokenRecord{}, err
|
||||
}
|
||||
if !ok {
|
||||
return TokenRecord{}, errors.New("token not found")
|
||||
}
|
||||
r.applyExpiry(record)
|
||||
if r.applyExpiry(record) {
|
||||
if err := r.store.Save(ctx, *record, "", ""); err != nil {
|
||||
return TokenRecord{}, err
|
||||
}
|
||||
}
|
||||
if record.Status != TokenStatusActive {
|
||||
return TokenRecord{}, errors.New("token is not active")
|
||||
}
|
||||
|
||||
record.ExpiresAt = r.now().Add(ttl)
|
||||
r.store.Save(*record, "", "")
|
||||
if err := r.store.Save(ctx, *record, "", ""); err != nil {
|
||||
return TokenRecord{}, err
|
||||
}
|
||||
return cloneRecord(*record), nil
|
||||
}
|
||||
|
||||
func (r *InMemoryTokenRuntime) Revoke(_ context.Context, tokenID, reason string) (TokenRecord, error) {
|
||||
func (r *InMemoryTokenRuntime) Revoke(ctx context.Context, tokenID, reason string) (TokenRecord, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
record, ok := r.store.GetByTokenID(tokenID)
|
||||
record, ok, err := r.store.GetByTokenID(ctx, tokenID)
|
||||
if err != nil {
|
||||
return TokenRecord{}, err
|
||||
}
|
||||
if !ok {
|
||||
return TokenRecord{}, errors.New("token not found")
|
||||
}
|
||||
r.applyExpiry(record)
|
||||
if r.applyExpiry(record) {
|
||||
if err := r.store.Save(ctx, *record, "", ""); err != nil {
|
||||
return TokenRecord{}, err
|
||||
}
|
||||
}
|
||||
record.Status = TokenStatusRevoked
|
||||
record.RevokedReason = strings.TrimSpace(reason)
|
||||
if err := r.store.Save(ctx, *record, "", ""); err != nil {
|
||||
return TokenRecord{}, err
|
||||
}
|
||||
return cloneRecord(*record), nil
|
||||
}
|
||||
|
||||
func (r *InMemoryTokenRuntime) Introspect(_ context.Context, accessToken string) (TokenRecord, error) {
|
||||
func (r *InMemoryTokenRuntime) Introspect(ctx context.Context, accessToken string) (TokenRecord, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
record, ok := r.store.GetByAccessToken(accessToken)
|
||||
record, ok, err := r.store.GetByAccessToken(ctx, accessToken)
|
||||
if err != nil {
|
||||
return TokenRecord{}, err
|
||||
}
|
||||
if !ok {
|
||||
return TokenRecord{}, errors.New("token not found")
|
||||
}
|
||||
r.applyExpiry(record)
|
||||
if r.applyExpiry(record) {
|
||||
if err := r.store.Save(ctx, *record, "", ""); err != nil {
|
||||
return TokenRecord{}, err
|
||||
}
|
||||
}
|
||||
return cloneRecord(*record), nil
|
||||
}
|
||||
|
||||
func (r *InMemoryTokenRuntime) Lookup(_ context.Context, tokenID string) (TokenRecord, error) {
|
||||
func (r *InMemoryTokenRuntime) Lookup(ctx context.Context, tokenID string) (TokenRecord, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
record, ok := r.store.GetByTokenID(tokenID)
|
||||
record, ok, err := r.store.GetByTokenID(ctx, tokenID)
|
||||
if err != nil {
|
||||
return TokenRecord{}, err
|
||||
}
|
||||
if !ok {
|
||||
return TokenRecord{}, errors.New("token not found")
|
||||
}
|
||||
r.applyExpiry(record)
|
||||
if r.applyExpiry(record) {
|
||||
if err := r.store.Save(ctx, *record, "", ""); err != nil {
|
||||
return TokenRecord{}, err
|
||||
}
|
||||
}
|
||||
return cloneRecord(*record), nil
|
||||
}
|
||||
|
||||
func (r *InMemoryTokenRuntime) Verify(_ context.Context, rawToken string) (VerifiedToken, error) {
|
||||
func (r *InMemoryTokenRuntime) Verify(ctx context.Context, rawToken string) (VerifiedToken, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
r.mu.RLock()
|
||||
record, ok := r.store.GetByAccessToken(rawToken)
|
||||
record, ok, err := r.store.GetByAccessToken(ctx, rawToken)
|
||||
if err != nil {
|
||||
r.mu.RUnlock()
|
||||
return VerifiedToken{}, NewAuthError(CodeAuthInvalidToken, err)
|
||||
}
|
||||
if !ok {
|
||||
r.mu.RUnlock()
|
||||
return VerifiedToken{}, NewAuthError(CodeAuthInvalidToken, errors.New("token not found"))
|
||||
@@ -204,22 +265,33 @@ func (r *InMemoryTokenRuntime) Verify(_ context.Context, rawToken string) (Verif
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (r *InMemoryTokenRuntime) Resolve(_ context.Context, tokenID string) (TokenStatus, error) {
|
||||
func (r *InMemoryTokenRuntime) Resolve(ctx context.Context, tokenID string) (TokenStatus, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
r.mu.Lock()
|
||||
defer r.mu.Unlock()
|
||||
|
||||
record, ok := r.store.GetByTokenID(tokenID)
|
||||
record, ok, err := r.store.GetByTokenID(ctx, tokenID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !ok {
|
||||
return "", NewAuthError(CodeAuthInvalidToken, errors.New("token not found"))
|
||||
}
|
||||
r.applyExpiry(record)
|
||||
if r.applyExpiry(record) {
|
||||
if err := r.store.Save(ctx, *record, "", ""); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
return record.Status, nil
|
||||
}
|
||||
|
||||
func (r *InMemoryTokenRuntime) TokenCount() int {
|
||||
r.mu.RLock()
|
||||
defer r.mu.RUnlock()
|
||||
return r.store.TokenCount()
|
||||
if counter, ok := r.store.(interface{ TokenCount() int }); ok {
|
||||
return counter.TokenCount()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (r *InMemoryTokenRuntime) IssueAndAudit(ctx context.Context, input IssueTokenInput, auditor AuditEmitter) (TokenRecord, error) {
|
||||
@@ -269,13 +341,15 @@ func (r *InMemoryTokenRuntime) RevokeAndAudit(ctx context.Context, tokenID, reas
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (r *InMemoryTokenRuntime) applyExpiry(record *TokenRecord) {
|
||||
func (r *InMemoryTokenRuntime) applyExpiry(record *TokenRecord) bool {
|
||||
if record == nil {
|
||||
return
|
||||
return false
|
||||
}
|
||||
if record.Status == TokenStatusActive && !record.ExpiresAt.IsZero() && !r.now().Before(record.ExpiresAt) {
|
||||
record.Status = TokenStatusExpired
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func cloneRecord(record TokenRecord) TokenRecord {
|
||||
|
||||
@@ -1,57 +1,77 @@
|
||||
package service
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type InMemoryRuntimeStore struct {
|
||||
mu sync.RWMutex
|
||||
records map[string]*TokenRecord
|
||||
tokenToID map[string]string
|
||||
idempotencyByKey map[string]idempotencyEntry
|
||||
mu sync.RWMutex
|
||||
records map[string]*TokenRecord
|
||||
tokenToID map[string]string
|
||||
idempotencyByKey map[string]IdempotencyEntry
|
||||
}
|
||||
|
||||
func NewInMemoryRuntimeStore() *InMemoryRuntimeStore {
|
||||
return &InMemoryRuntimeStore{
|
||||
records: make(map[string]*TokenRecord),
|
||||
tokenToID: make(map[string]string),
|
||||
idempotencyByKey: make(map[string]idempotencyEntry),
|
||||
idempotencyByKey: make(map[string]IdempotencyEntry),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *InMemoryRuntimeStore) Save(record TokenRecord, idempotencyKey, requestHash string) {
|
||||
func (s *InMemoryRuntimeStore) Save(_ context.Context, record TokenRecord, idempotencyKey, requestHash string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
recordCopy := cloneRecord(record)
|
||||
s.records[record.TokenID] = &recordCopy
|
||||
s.tokenToID[record.AccessToken] = record.TokenID
|
||||
if record.AccessToken != "" {
|
||||
s.tokenToID[record.AccessToken] = record.TokenID
|
||||
}
|
||||
if idempotencyKey != "" {
|
||||
s.idempotencyByKey[idempotencyKey] = idempotencyEntry{
|
||||
s.idempotencyByKey[idempotencyKey] = IdempotencyEntry{
|
||||
RequestHash: requestHash,
|
||||
TokenID: record.TokenID,
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *InMemoryRuntimeStore) GetByTokenID(tokenID string) (*TokenRecord, bool) {
|
||||
func (s *InMemoryRuntimeStore) GetByTokenID(_ context.Context, tokenID string) (*TokenRecord, bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
record, ok := s.records[tokenID]
|
||||
return record, ok
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
cloned := cloneRecord(*record)
|
||||
return &cloned, true, nil
|
||||
}
|
||||
|
||||
func (s *InMemoryRuntimeStore) GetByAccessToken(accessToken string) (*TokenRecord, bool) {
|
||||
func (s *InMemoryRuntimeStore) GetByAccessToken(_ context.Context, accessToken string) (*TokenRecord, bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
tokenID, ok := s.tokenToID[accessToken]
|
||||
if !ok {
|
||||
return nil, false
|
||||
return nil, false, nil
|
||||
}
|
||||
record, ok := s.records[tokenID]
|
||||
return record, ok
|
||||
if !ok {
|
||||
return nil, false, nil
|
||||
}
|
||||
cloned := cloneRecord(*record)
|
||||
return &cloned, true, nil
|
||||
}
|
||||
|
||||
func (s *InMemoryRuntimeStore) LookupIdempotency(idempotencyKey string) (idempotencyEntry, bool) {
|
||||
func (s *InMemoryRuntimeStore) LookupIdempotency(_ context.Context, idempotencyKey string) (IdempotencyEntry, bool, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
entry, ok := s.idempotencyByKey[idempotencyKey]
|
||||
return entry, ok
|
||||
return entry, ok, nil
|
||||
}
|
||||
|
||||
func (s *InMemoryRuntimeStore) TokenCount() int {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return len(s.records)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestInMemoryRuntimeStore_SaveAndLookup(t *testing.T) {
|
||||
store := NewInMemoryRuntimeStore()
|
||||
@@ -12,9 +15,14 @@ func TestInMemoryRuntimeStore_SaveAndLookup(t *testing.T) {
|
||||
Scope: []string{"supply:*"},
|
||||
}
|
||||
|
||||
store.Save(record, "idem-1", "hash-1")
|
||||
if err := store.Save(context.Background(), record, "idem-1", "hash-1"); err != nil {
|
||||
t.Fatalf("save record: %v", err)
|
||||
}
|
||||
|
||||
byID, ok := store.GetByTokenID("tok_123")
|
||||
byID, ok, err := store.GetByTokenID(context.Background(), "tok_123")
|
||||
if err != nil {
|
||||
t.Fatalf("get by token id: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("expected record by token id")
|
||||
}
|
||||
@@ -22,7 +30,10 @@ func TestInMemoryRuntimeStore_SaveAndLookup(t *testing.T) {
|
||||
t.Fatalf("unexpected token id: %s", byID.TokenID)
|
||||
}
|
||||
|
||||
byToken, ok := store.GetByAccessToken("ptk_123")
|
||||
byToken, ok, err := store.GetByAccessToken(context.Background(), "ptk_123")
|
||||
if err != nil {
|
||||
t.Fatalf("get by access token: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("expected record by access token")
|
||||
}
|
||||
@@ -30,7 +41,10 @@ func TestInMemoryRuntimeStore_SaveAndLookup(t *testing.T) {
|
||||
t.Fatalf("unexpected subject id: %s", byToken.SubjectID)
|
||||
}
|
||||
|
||||
entry, ok := store.LookupIdempotency("idem-1")
|
||||
entry, ok, err := store.LookupIdempotency(context.Background(), "idem-1")
|
||||
if err != nil {
|
||||
t.Fatalf("lookup idempotency: %v", err)
|
||||
}
|
||||
if !ok {
|
||||
t.Fatal("expected idempotency entry")
|
||||
}
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
package service
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestInMemoryStoresImplementContracts(t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
var runtimeStore RuntimeStore = NewInMemoryRuntimeStore()
|
||||
var auditStore AuditStore = NewMemoryAuditStore()
|
||||
|
||||
if runtimeStore == nil {
|
||||
t.Fatal("expected runtime store contract implementation")
|
||||
}
|
||||
if auditStore == nil {
|
||||
t.Fatal("expected audit store contract implementation")
|
||||
}
|
||||
}
|
||||
@@ -79,6 +79,16 @@ type AuditEmitter interface {
|
||||
Emit(ctx context.Context, event AuditEvent) error
|
||||
}
|
||||
|
||||
type AuditStore interface {
|
||||
AuditEmitter
|
||||
AuditEventQuerier
|
||||
}
|
||||
|
||||
type IdempotencyEntry struct {
|
||||
RequestHash string
|
||||
TokenID string
|
||||
}
|
||||
|
||||
type AuditEventFilter struct {
|
||||
RequestID string
|
||||
TokenID string
|
||||
@@ -92,6 +102,13 @@ type AuditEventQuerier interface {
|
||||
QueryEvents(ctx context.Context, filter AuditEventFilter) ([]AuditEvent, error)
|
||||
}
|
||||
|
||||
type RuntimeStore interface {
|
||||
Save(ctx context.Context, record TokenRecord, idempotencyKey, requestHash string) error
|
||||
GetByTokenID(ctx context.Context, tokenID string) (*TokenRecord, bool, error)
|
||||
GetByAccessToken(ctx context.Context, accessToken string) (*TokenRecord, bool, error)
|
||||
LookupIdempotency(ctx context.Context, idempotencyKey string) (IdempotencyEntry, bool, error)
|
||||
}
|
||||
|
||||
type AuthError struct {
|
||||
Code string
|
||||
Cause error
|
||||
|
||||
Reference in New Issue
Block a user