package middleware

import (
	"context"
	"errors"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/user-management-system/internal/auth"
	"github.com/user-management-system/internal/cache"
	"github.com/user-management-system/internal/domain"
	"github.com/user-management-system/internal/repository"
	"github.com/user-management-system/internal/service"
	gormsqlite "gorm.io/driver/sqlite"
	"gorm.io/gorm"
	"gorm.io/gorm/logger"
	_ "modernc.org/sqlite"
)

// authStubUserRepo is a canned user repository: GetByID ignores the requested
// ID and always returns the configured user and error.
type authStubUserRepo struct {
	user *domain.User
	err  error
}

func (s authStubUserRepo) GetByID(_ context.Context, _ int64) (*domain.User, error) {
	return s.user, s.err
}

// authStubUserRoleRepo is a canned user-role repository:
// GetUserRolesAndPermissions ignores the requested ID and always returns the
// configured roles, permissions and error.
type authStubUserRoleRepo struct {
	roles []*domain.Role
	perms []*domain.Permission
	err   error
}

func (s authStubUserRoleRepo) GetUserRolesAndPermissions(_ context.Context, _ int64) ([]*domain.Role, []*domain.Permission, error) {
	return s.roles, s.perms, s.err
}

// newTestJWT builds an HS256 JWT manager with a fixed test secret, a 15 minute
// access-token lifetime and a 7 day refresh-token lifetime.
func newTestJWT(t *testing.T) *auth.JWT {
	t.Helper()
	jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
		HS256Secret:        "test-middleware-secret-at-least-32-chars",
		AccessTokenExpire:  15 * time.Minute,
		RefreshTokenExpire: 7 * 24 * time.Hour,
	})
	if err != nil {
		t.Fatalf("create jwt manager failed: %v", err)
	}
	return jwtManager
}

// newAuthMiddlewareForTest wires an AuthMiddleware to stub repositories that
// return the given user, roles and permissions. It also returns the JWT
// manager (for minting tokens) and the L1 cache (for inspecting cache state).
func newAuthMiddlewareForTest(t *testing.T, user *domain.User, roles []*domain.Role, perms []*domain.Permission) (*AuthMiddleware, *auth.JWT, *cache.L1Cache) {
	t.Helper()
	jwtManager := newTestJWT(t)
	l1Cache := cache.NewL1Cache()
	middleware := NewAuthMiddleware(jwtManager, authStubUserRepo{user: user}, authStubUserRoleRepo{roles: roles, perms: perms}, l1Cache)
	return middleware, jwtManager, l1Cache
}

// performMiddlewareRequest sends GET /protected through a fresh gin engine that
// runs middleware before a handler replying 200 {"code": 0}. The Authorization
// header is only set when authHeader is non-empty.
func performMiddlewareRequest(t *testing.T, middleware gin.HandlerFunc, authHeader string) *httptest.ResponseRecorder {
	t.Helper()
	gin.SetMode(gin.TestMode)

	recorder := httptest.NewRecorder()
	router := gin.New()
	router.Use(middleware)
	router.GET("/protected", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"code": 0})
	})

	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	if authHeader != "" {
		req.Header.Set("Authorization", authHeader)
	}
	router.ServeHTTP(recorder, req)
	return recorder
}

// openBootstrapTestDB opens the shared-cache in-memory SQLite database used by
// the bootstrap test, migrates the user/role tables and seeds the "admin" role.
// The connection pool is closed on cleanup. Closing the last connection to an
// in-memory SQLite database discards it, so reruns in the same process
// (e.g. go test -count=2) start from an empty schema instead of finding the
// previously seeded role and bootstrapped admin.
func openBootstrapTestDB(t *testing.T) *gorm.DB {
	t.Helper()

	db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
		DriverName: "sqlite",
		DSN:        "file:middleware_bootstrap_test?mode=memory&cache=shared",
	}), &gorm.Config{
		Logger: logger.Default.LogMode(logger.Silent),
	})
	if err != nil {
		t.Fatalf("open sqlite failed: %v", err)
	}
	sqlDB, err := db.DB()
	if err != nil {
		t.Fatalf("get sqlite handle failed: %v", err)
	}
	t.Cleanup(func() {
		if err := sqlDB.Close(); err != nil {
			t.Errorf("close sqlite failed: %v", err)
		}
	})

	if err := db.AutoMigrate(&domain.User{}, &domain.Role{}, &domain.UserRole{}); err != nil {
		t.Fatalf("migrate failed: %v", err)
	}
	if err := db.Create(&domain.Role{
		Name:     "管理员",
		Code:     "admin",
		IsSystem: true,
		Status:   domain.RoleStatusEnabled,
	}).Error; err != nil {
		t.Fatalf("seed admin role failed: %v", err)
	}
	return db
}

func TestAuthMiddleware_AcceptsBootstrapAdminTokenImmediately(t *testing.T) {
	db := openBootstrapTestDB(t)
	jwtManager := newTestJWT(t)

	l1Cache := cache.NewL1Cache()
	l2Cache := cache.NewRedisCache(false)
	cacheManager := cache.NewCacheManager(l1Cache, l2Cache)

	userRepo := repository.NewUserRepository(db)
	roleRepo := repository.NewRoleRepository(db)
	userRoleRepo := repository.NewUserRoleRepository(db)

	authService := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
	authService.SetRoleRepositories(userRoleRepo, roleRepo)

	loginResponse, err := authService.BootstrapAdmin(context.Background(), &service.BootstrapAdminRequest{
		Username: "bootstrap_admin",
		Email:    "bootstrap_admin@example.com",
		Password: "AdminPass123!",
	}, "127.0.0.1")
	if err != nil {
		t.Fatalf("bootstrap admin failed: %v", err)
	}
	if loginResponse == nil || loginResponse.AccessToken == "" {
		t.Fatalf("expected bootstrap access token, got %+v", loginResponse)
	}
	if _, err := jwtManager.ValidateAccessToken(loginResponse.AccessToken); err != nil {
		t.Fatalf("bootstrap access token should validate immediately: %v", err)
	}

	authMiddleware := NewAuthMiddleware(jwtManager, userRepo, userRoleRepo, l1Cache)
	authMiddleware.SetCacheManager(cacheManager)

	recorder := performMiddlewareRequest(t, authMiddleware.Required(), "Bearer "+loginResponse.AccessToken)
	if recorder.Code != http.StatusOK {
		t.Fatalf("expected bootstrap token to pass auth middleware immediately, got %d body: %s", recorder.Code, recorder.Body.String())
	}
}

func TestAuthMiddleware_RequiredRejectsMissingToken(t *testing.T) {
	middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil)
	recorder := performMiddlewareRequest(t, middleware.Required(), "")
	if recorder.Code != http.StatusUnauthorized {
		t.Fatalf("expected 401 for missing token, got %d", recorder.Code)
	}
}

func TestAuthMiddleware_RequiredRejectsInvalidToken(t *testing.T) {
	middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil)
	recorder := performMiddlewareRequest(t, middleware.Required(), "Bearer not-a-jwt")
	if recorder.Code != http.StatusUnauthorized {
		t.Fatalf("expected 401 for invalid token, got %d", recorder.Code)
	}
}

func TestAuthMiddleware_RequiredRejectsBlacklistedToken(t *testing.T) {
	user := &domain.User{ID: 7, Username: "alice", Status: domain.UserStatusActive}
	middleware, jwtManager, l1Cache := newAuthMiddlewareForTest(t, user, nil, nil)

	token, err := jwtManager.GenerateAccessToken(user.ID, user.Username, 0)
	if err != nil {
		t.Fatalf("generate access token failed: %v", err)
	}
	claims, err := jwtManager.ValidateAccessToken(token)
	if err != nil {
		t.Fatalf("validate access token failed: %v", err)
	}
	// Blacklist the token's JTI directly in the L1 cache the middleware reads.
	l1Cache.Set("jwt_blacklist:"+claims.JTI, true, time.Minute)

	recorder := performMiddlewareRequest(t, middleware.Required(), "Bearer "+token)
	if recorder.Code != http.StatusUnauthorized {
		t.Fatalf("expected 401 for blacklisted token, got %d", recorder.Code)
	}
}

func TestAuthMiddleware_RequiredRejectsInactiveUser(t *testing.T) {
	user := &domain.User{ID: 8, Username: "disabled", Status: domain.UserStatusDisabled}
	middleware, jwtManager, _ := newAuthMiddlewareForTest(t, user, nil, nil)

	token, err := jwtManager.GenerateAccessToken(user.ID, user.Username, 0)
	if err != nil {
		t.Fatalf("generate access token failed: %v", err)
	}

	recorder := performMiddlewareRequest(t, middleware.Required(), "Bearer "+token)
	if recorder.Code != http.StatusUnauthorized {
		t.Fatalf("expected 401 for inactive user, got %d", recorder.Code)
	}
}

func TestAuthMiddleware_RequiredInjectsIdentityAndAuthorizations(t *testing.T) {
	gin.SetMode(gin.TestMode)
	user := &domain.User{ID: 9, Username: "admin", Status: domain.UserStatusActive}
	roles := []*domain.Role{{Code: "admin"}, {Code: "auditor"}}
	perms := []*domain.Permission{{Code: "users:read"}, {Code: "users:write"}}
	middleware, jwtManager, _ := newAuthMiddlewareForTest(t, user, roles, perms)

	token, err := jwtManager.GenerateAccessToken(user.ID, user.Username, 0)
	if err != nil {
		t.Fatalf("generate access token failed: %v", err)
	}

	recorder := httptest.NewRecorder()
	router := gin.New()
	router.Use(middleware.Required())
	// The handler runs synchronously inside router.ServeHTTP on the test
	// goroutine, so calling t.Fatalf here is safe.
	router.GET("/protected", func(c *gin.Context) {
		if got := c.GetInt64("user_id"); got != user.ID {
			t.Fatalf("user_id = %d, want %d", got, user.ID)
		}
		if got := c.GetString("username"); got != user.Username {
			t.Fatalf("username = %q, want %q", got, user.Username)
		}
		roleCodes := GetRoleCodes(c)
		if len(roleCodes) != 2 || roleCodes[0] != "admin" || roleCodes[1] != "auditor" {
			t.Fatalf("unexpected role codes: %#v", roleCodes)
		}
		permCodes := GetPermissionCodes(c)
		if len(permCodes) != 2 || permCodes[0] != "users:read" || permCodes[1] != "users:write" {
			t.Fatalf("unexpected permission codes: %#v", permCodes)
		}
		c.JSON(http.StatusOK, gin.H{"code": 0})
	})

	req := httptest.NewRequest(http.MethodGet, "/protected", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	router.ServeHTTP(recorder, req)

	if recorder.Code != http.StatusOK {
		t.Fatalf("expected 200 for valid token, got %d body: %s", recorder.Code, recorder.Body.String())
	}
}

func TestAuthMiddleware_OptionalAllowsAnonymousRequest(t *testing.T) {
	middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil)
	recorder := performMiddlewareRequest(t, middleware.Optional(), "")
	if recorder.Code != http.StatusOK {
		t.Fatalf("expected optional middleware to allow anonymous request, got %d", recorder.Code)
	}
}

func TestAuthMiddleware_OptionalInjectsIdentityForValidToken(t *testing.T) {
	gin.SetMode(gin.TestMode)
	user := &domain.User{ID: 21, Username: "optional-user", Status: domain.UserStatusActive}
	roles := []*domain.Role{{Code: "viewer"}}
	perms := []*domain.Permission{{Code: "users:read"}}
	middleware, jwtManager, _ := newAuthMiddlewareForTest(t, user, roles, perms)

	token, err := jwtManager.GenerateAccessToken(user.ID, user.Username, 0)
	if err != nil {
		t.Fatalf("generate access token failed: %v", err)
	}

	recorder := httptest.NewRecorder()
	router := gin.New()
	router.Use(middleware.Optional())
	router.GET("/optional", func(c *gin.Context) {
		if got := c.GetInt64("user_id"); got != user.ID {
			t.Fatalf("user_id = %d, want %d", got, user.ID)
		}
		if got := c.GetString("username"); got != user.Username {
			t.Fatalf("username = %q, want %q", got, user.Username)
		}
		if got := GetRoleCodes(c); len(got) != 1 || got[0] != "viewer" {
			t.Fatalf("role_codes = %#v, want [viewer]", got)
		}
		if got := GetPermissionCodes(c); len(got) != 1 || got[0] != "users:read" {
			t.Fatalf("permission_codes = %#v, want [users:read]", got)
		}
		c.Status(http.StatusOK)
	})

	req := httptest.NewRequest(http.MethodGet, "/optional", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	router.ServeHTTP(recorder, req)

	if recorder.Code != http.StatusOK {
		t.Fatalf("expected valid optional auth request to pass, got %d", recorder.Code)
	}
}

func TestAuthMiddleware_ExtractTokenCases(t *testing.T) {
	gin.SetMode(gin.TestMode)
	middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil)

	testCases := []struct {
		name   string
		header string
		want   string
	}{
		{name: "missing header", header: "", want: ""},
		{name: "valid bearer", header: "Bearer abc.def", want: "abc.def"},
		{name: "lowercase bearer rejected", header: "bearer abc", want: ""},
		{name: "missing token value", header: "Bearer", want: ""},
		{name: "wrong scheme", header: "Basic abc", want: ""},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			recorder := httptest.NewRecorder()
			c, _ := gin.CreateTestContext(recorder)
			c.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
			if tc.header != "" {
				c.Request.Header.Set("Authorization", tc.header)
			}
			if got := middleware.extractToken(c); got != tc.want {
				t.Fatalf("extractToken() = %q, want %q", got, tc.want)
			}
		})
	}
}

func TestAuthMiddleware_ValidateUserStateAndCacheInvalidation(t *testing.T) {
	user := &domain.User{
		ID:                11,
		Username:          "cached-user",
		Status:            domain.UserStatusActive,
		PasswordChangedAt: time.Unix(200, 0),
	}
	middleware, _, l1Cache := newAuthMiddlewareForTest(t, user, nil, nil)

	// A token stamped at 150 predates the password change at 200.
	if got := middleware.validateUserState(context.Background(), user.ID, 150); got == "" {
		t.Fatal("expected password-changed denial for stale token")
	}
	if _, ok := l1Cache.Get("user_state:11"); !ok {
		t.Fatal("expected user state to be cached")
	}

	middleware.InvalidateUserStateCache(user.ID)
	if _, ok := l1Cache.Get("user_state:11"); ok {
		t.Fatal("expected user state cache to be cleared")
	}
}

func TestAuthMiddleware_LoadUserRolesAndPermsCachesAndInvalidates(t *testing.T) {
	user := &domain.User{ID: 12, Username: "role-user", Status: domain.UserStatusActive}
	roles := []*domain.Role{{Code: "admin"}}
	perms := []*domain.Permission{{Code: "users:read"}}
	middleware, _, l1Cache := newAuthMiddlewareForTest(t, user, roles, perms)

	roleCodes, permCodes := middleware.loadUserRolesAndPerms(context.Background(), user.ID)
	if len(roleCodes) != 1 || roleCodes[0] != "admin" {
		t.Fatalf("unexpected role codes: %#v", roleCodes)
	}
	if len(permCodes) != 1 || permCodes[0] != "users:read" {
		t.Fatalf("unexpected permission codes: %#v", permCodes)
	}
	if _, ok := l1Cache.Get("user_perms:12"); !ok {
		t.Fatal("expected user permissions to be cached")
	}

	middleware.InvalidateUserPermCache(user.ID)
	if _, ok := l1Cache.Get("user_perms:12"); ok {
		t.Fatal("expected user permission cache to be cleared")
	}
}

func TestAuthMiddleware_AddToBlacklistAndUserHelpers(t *testing.T) {
	activeUser := &domain.User{ID: 13, Username: "active", Status: domain.UserStatusActive}
	middleware, _, l1Cache := newAuthMiddlewareForTest(t, activeUser, nil, nil)

	middleware.AddToBlacklist("jti-1", time.Minute)
	if _, ok := l1Cache.Get("jwt_blacklist:jti-1"); !ok {
		t.Fatal("expected blacklist entry in cache")
	}
	if !middleware.isUserActive(context.Background(), activeUser.ID) {
		t.Fatal("expected active user to be active")
	}
	if middleware.isPasswordChangedSinceTokenIssued(context.Background(), activeUser.ID, 0) {
		t.Fatal("expected zero token pce to skip password change check")
	}

	changedUser := &domain.User{
		ID:                14,
		Username:          "changed",
		Status:            domain.UserStatusActive,
		PasswordChangedAt: time.Unix(300, 0),
	}
	changedMiddleware, _, _ := newAuthMiddlewareForTest(t, changedUser, nil, nil)
	if !changedMiddleware.isPasswordChangedSinceTokenIssued(context.Background(), changedUser.ID, 200) {
		t.Fatal("expected password-changed helper to return true")
	}
}

func TestAuthMiddleware_UserHelpersHandleRepoFailures(t *testing.T) {
	middleware, _, _ := newAuthMiddlewareForTest(t, nil, nil, nil)
	// Swap in a repo whose lookups always fail to exercise the deny-on-error paths.
	middleware.userRepo = authStubUserRepo{err: errors.New("db down")}

	if middleware.isUserActive(context.Background(), 99) {
		t.Fatal("expected repo failure to mark user inactive")
	}
	if got := middleware.validateUserState(context.Background(), 99, 0); got == "" {
		t.Fatal("expected validateUserState to deny on repo failure")
	}
}