test: 补齐 handler/repository/domain 层单元测试

This commit is contained in:
2026-05-10 12:54:13 +08:00
parent b8e9af001f
commit 28012140cb
21 changed files with 5837 additions and 1 deletions

View File

@@ -0,0 +1,297 @@
package handler
import (
"bytes"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
func TestAuthHandler_SupportFlags(t *testing.T) {
var nilHandler *AuthHandler
if nilHandler.SupportsPasswordReset() {
t.Fatal("nil handler should not support password reset")
}
handler := &AuthHandler{}
if handler.SupportsPasswordReset() {
t.Fatal("password reset should be disabled by default")
}
handler.SetPasswordResetEnabled(true)
if !handler.SupportsPasswordReset() {
t.Fatal("password reset flag should be enabled")
}
}
func TestGetUserIDFromContext(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/userinfo", nil)
if _, ok := getUserIDFromContext(c); ok {
t.Fatal("expected missing user_id to return false")
}
c.Set("user_id", "1")
if _, ok := getUserIDFromContext(c); ok {
t.Fatal("expected non-int64 user_id to return false")
}
c.Set("user_id", int64(42))
if got, ok := getUserIDFromContext(c); !ok || got != 42 {
t.Fatalf("getUserIDFromContext() = (%d, %v), want (42, true)", got, ok)
}
}
func TestRequestUsesHTTPS(t *testing.T) {
gin.SetMode(gin.TestMode)
if requestUsesHTTPS(nil) {
t.Fatal("nil context should not use https")
}
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
if requestUsesHTTPS(c) {
t.Fatal("plain http request should not use https")
}
c.Request.Header.Set("X-Forwarded-Proto", "https")
if !requestUsesHTTPS(c) {
t.Fatal("forwarded https request should be detected")
}
}
func TestSessionCookies_SetAndClear(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
setSessionCookies(c, nil, "")
if len(recorder.Header().Values("Set-Cookie")) != 0 {
t.Fatal("empty refresh token should not set cookies")
}
setSessionCookies(c, nil, "refresh-token")
setCookies := recorder.Header().Values("Set-Cookie")
if len(setCookies) < 2 {
t.Fatalf("expected session cookies to be set, got %d", len(setCookies))
}
if !strings.Contains(setCookies[0], refreshTokenCookieName+"=refresh-token") &&
!strings.Contains(setCookies[1], refreshTokenCookieName+"=refresh-token") {
t.Fatalf("expected refresh token cookie, got %#v", setCookies)
}
recorder = httptest.NewRecorder()
c, _ = gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
clearSessionCookies(c)
setCookies = recorder.Header().Values("Set-Cookie")
if len(setCookies) < 2 {
t.Fatalf("expected clearing cookies to emit expired cookies, got %d", len(setCookies))
}
}
func TestClassifyErrorMessage(t *testing.T) {
testCases := []struct {
name string
msg string
want int
}{
{name: "not found", msg: "user not found", want: http.StatusNotFound},
{name: "duplicate", msg: "already exists", want: http.StatusConflict},
{name: "verification code", msg: "验证码错误", want: http.StatusUnauthorized},
{name: "unauthorized", msg: "invalid token", want: http.StatusUnauthorized},
{name: "forbidden", msg: "permission denied", want: http.StatusForbidden},
{name: "bad request", msg: "invalid payload", want: http.StatusBadRequest},
{name: "rate limit", msg: "too many attempts", want: http.StatusTooManyRequests},
{name: "fallback", msg: "unexpected boom", want: http.StatusInternalServerError},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
if got := classifyErrorMessage(tc.msg); got != tc.want {
t.Fatalf("classifyErrorMessage(%q) = %d, want %d", tc.msg, got, tc.want)
}
})
}
}
func TestAuthHandler_OAuthFallbackEndpoints(t *testing.T) {
gin.SetMode(gin.TestMode)
h := &AuthHandler{}
testCases := []struct {
name string
run func(*gin.Context)
}{
{
name: "oauth login",
run: func(c *gin.Context) {
c.Params = gin.Params{{Key: "provider", Value: "github"}}
h.OAuthLogin(c)
},
},
{
name: "oauth callback",
run: func(c *gin.Context) {
c.Params = gin.Params{{Key: "provider", Value: "github"}}
h.OAuthCallback(c)
},
},
{
name: "oauth exchange",
run: func(c *gin.Context) {
c.Params = gin.Params{{Key: "provider", Value: "github"}}
h.OAuthExchange(c)
},
},
{
name: "oauth providers",
run: func(c *gin.Context) {
h.GetEnabledOAuthProviders(c)
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
tc.run(c)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
})
}
}
func TestAuthHandler_RefreshToken_InvalidJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
h := &AuthHandler{}
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/auth/refresh", bytes.NewBufferString("{"))
c.Request.Header.Set("Content-Type", "application/json")
h.RefreshToken(c)
if recorder.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", recorder.Code)
}
}
func TestAuthHandler_ActivateEmail_MissingToken(t *testing.T) {
gin.SetMode(gin.TestMode)
h := &AuthHandler{}
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/auth/activate-email", bytes.NewBufferString(`{}`))
c.Request.Header.Set("Content-Type", "application/json")
h.ActivateEmail(c)
if recorder.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", recorder.Code)
}
}
func TestAuthHandler_ResendActivationEmail_InvalidEmail(t *testing.T) {
gin.SetMode(gin.TestMode)
h := &AuthHandler{}
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/auth/resend-activation-email", bytes.NewBufferString(`{"email":"bad-email"}`))
c.Request.Header.Set("Content-Type", "application/json")
h.ResendActivationEmail(c)
if recorder.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", recorder.Code)
}
}
func TestAuthHandler_SendEmailCode_InvalidEmail(t *testing.T) {
gin.SetMode(gin.TestMode)
h := &AuthHandler{}
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/auth/send-email-code", bytes.NewBufferString(`{"email":"bad-email"}`))
c.Request.Header.Set("Content-Type", "application/json")
h.SendEmailCode(c)
if recorder.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", recorder.Code)
}
}
func TestAuthHandler_LoginByEmailCode_InvalidPayload(t *testing.T) {
gin.SetMode(gin.TestMode)
h := &AuthHandler{}
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/auth/login-by-email-code", bytes.NewBufferString(`{"email":"bad-email"}`))
c.Request.Header.Set("Content-Type", "application/json")
h.LoginByEmailCode(c)
if recorder.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", recorder.Code)
}
}
func TestAuthHandler_BootstrapAdmin_HeaderFailures(t *testing.T) {
gin.SetMode(gin.TestMode)
h := &AuthHandler{}
original := os.Getenv("BOOTSTRAP_SECRET")
if err := os.Setenv("BOOTSTRAP_SECRET", "expected-secret"); err != nil {
t.Fatalf("set env failed: %v", err)
}
t.Cleanup(func() {
_ = os.Setenv("BOOTSTRAP_SECRET", original)
})
testCases := []struct {
name string
secret string
want int
}{
{name: "missing header", secret: "", want: http.StatusUnauthorized},
{name: "wrong header", secret: "wrong-secret", want: http.StatusUnauthorized},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodPost, "/auth/bootstrap-admin", bytes.NewBufferString(`{"username":"admin","email":"admin@example.com","password":"AdminPass123!"}`))
c.Request.Header.Set("Content-Type", "application/json")
if tc.secret != "" {
c.Request.Header.Set("X-Bootstrap-Secret", tc.secret)
}
h.BootstrapAdmin(c)
if recorder.Code != tc.want {
t.Fatalf("expected %d, got %d", tc.want, recorder.Code)
}
})
}
}

View File

@@ -0,0 +1,151 @@
package handler_test
import (
"bytes"
"io"
"mime/multipart"
"net/http"
"os"
"testing"
)
// minimalPNG is a valid 1x1 PNG image
var minimalPNG = []byte{
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D,
0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01,
0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00,
0x0C, 0x49, 0x44, 0x41, 0x54, 0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00,
0x00, 0x00, 0x03, 0x00, 0x01, 0x00, 0x05, 0xFE, 0xD8, 0x00, 0x00, 0x00,
0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82,
}
func buildAvatarUploadRequest(t *testing.T, url, token string, fileBody []byte, filename string) *http.Request {
t.Helper()
var body bytes.Buffer
writer := multipart.NewWriter(&body)
part, err := writer.CreateFormFile("avatar", filename)
if err != nil {
t.Fatalf("create form file failed: %v", err)
}
if _, err := part.Write(fileBody); err != nil {
t.Fatalf("write file body failed: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("close multipart writer failed: %v", err)
}
req, err := http.NewRequest(http.MethodPost, url, &body)
if err != nil {
t.Fatalf("create request failed: %v", err)
}
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
return req
}
func TestAvatarHandler_UploadAvatar(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "avatar-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "avatar-bootstrap-secret", "avataradmin", "avataradmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
if ok := registerUser(server.URL, "avataruser", "avataruser@test.com", "UserPass123!"); !ok {
t.Fatal("register user failed")
}
userToken := getToken(server.URL, "avataruser", "UserPass123!")
if userToken == "" {
t.Fatal("get user token failed")
}
tests := []struct {
name string
userID string
token string
fileBody []byte
filename string
wantStatus int
}{
{
name: "admin_upload_for_any_user",
userID: "2",
token: adminToken,
fileBody: minimalPNG,
filename: "avatar.png",
wantStatus: http.StatusForbidden,
},
{
name: "user_upload_own_avatar",
userID: "2",
token: userToken,
fileBody: minimalPNG,
filename: "avatar.png",
wantStatus: http.StatusOK,
},
{
name: "unauthorized",
userID: "1",
token: "",
fileBody: minimalPNG,
filename: "avatar.png",
wantStatus: http.StatusUnauthorized,
},
{
name: "forbidden_cross_user",
userID: "1",
token: userToken,
fileBody: minimalPNG,
filename: "avatar.png",
wantStatus: http.StatusForbidden,
},
{
name: "invalid_user_id",
userID: "invalid",
token: adminToken,
fileBody: minimalPNG,
filename: "avatar.png",
wantStatus: http.StatusBadRequest,
},
{
name: "invalid_file_type",
userID: "1",
token: adminToken,
fileBody: []byte("this is not an image"),
filename: "avatar.txt",
wantStatus: http.StatusBadRequest,
},
{
name: "user_not_found",
userID: "99999",
token: adminToken,
fileBody: minimalPNG,
filename: "avatar.png",
wantStatus: http.StatusForbidden,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
req := buildAvatarUploadRequest(t, server.URL+"/api/v1/users/"+tt.userID+"/avatar", tt.token, tt.fileBody, tt.filename)
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
body, _ := io.ReadAll(resp.Body)
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, string(body))
}
})
}
// Clean up uploaded avatars
_ = os.RemoveAll("./uploads/avatars")
}

View File

@@ -0,0 +1,545 @@
package handler_test
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/api/handler"
"github.com/user-management-system/internal/api/middleware"
"github.com/user-management-system/internal/api/router"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/config"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
"github.com/user-management-system/internal/service"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
var customFieldDbCounter int64
func setupCustomFieldTestServer(t *testing.T) (*httptest.Server, string, string, func()) {
t.Helper()
gin.SetMode(gin.TestMode)
id := atomic.AddInt64(&customFieldDbCounter, 1)
dsn := fmt.Sprintf("file:cfdb_%d_%s?mode=memory&cache=shared", id, t.Name())
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: dsn,
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Skipf("skipping custom field test (SQLite unavailable): %v", err)
return nil, "", "", func() {}
}
if err := db.AutoMigrate(
&domain.User{},
&domain.Role{},
&domain.Permission{},
&domain.UserRole{},
&domain.RolePermission{},
&domain.CustomField{},
&domain.UserCustomFieldValue{},
); err != nil {
t.Fatalf("db migration failed: %v", err)
}
seedHandlerAuthzData(t, db)
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
HS256Secret: "test-cf-secret-key",
AccessTokenExpire: 15 * time.Minute,
RefreshTokenExpire: 7 * 24 * time.Hour,
})
if err != nil {
t.Fatalf("create jwt manager failed: %v", err)
}
l1Cache := cache.NewL1Cache()
l2Cache := cache.NewRedisCache(false)
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
userRepo := repository.NewUserRepository(db)
roleRepo := repository.NewRoleRepository(db)
userRoleRepo := repository.NewUserRoleRepository(db)
authSvc := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
fieldRepo := repository.NewCustomFieldRepository(db)
valueRepo := repository.NewUserCustomFieldValueRepository(db)
cfSvc := service.NewCustomFieldService(fieldRepo, valueRepo)
cfHandler := handler.NewCustomFieldHandler(cfSvc)
rateLimitCfg := config.RateLimitConfig{}
rateLimitMiddleware := middleware.NewRateLimitMiddleware(rateLimitCfg)
authMiddleware := middleware.NewAuthMiddleware(
jwtManager, userRepo, userRoleRepo, l1Cache,
)
authMiddleware.SetCacheManager(cacheManager)
authHandler := handler.NewAuthHandler(authSvc)
r := router.NewRouter(
authHandler, nil, nil, nil, nil, nil,
authMiddleware, rateLimitMiddleware, nil,
nil, nil, nil, nil,
nil, nil, nil, nil, cfHandler, nil, nil, nil, nil,
)
engine := r.Setup()
server := httptest.NewServer(engine)
// Register a regular user
regBody := map[string]interface{}{
"username": fmt.Sprintf("cfuser_%d", id),
"password": "TestPass123!",
"email": fmt.Sprintf("cf_%d@test.com", id),
}
regBytes, _ := json.Marshal(regBody)
regResp, _ := http.Post(server.URL+"/api/v1/auth/register", "application/json", bytes.NewReader(regBytes))
io.ReadAll(regResp.Body)
regResp.Body.Close()
// Login as regular user
loginBody := map[string]interface{}{
"account": regBody["username"],
"password": regBody["password"],
}
loginBytes, _ := json.Marshal(loginBody)
loginResp, _ := http.Post(server.URL+"/api/v1/auth/login", "application/json", bytes.NewReader(loginBytes))
var loginResult struct {
Data struct {
AccessToken string `json:"access_token"`
} `json:"data"`
}
json.NewDecoder(loginResp.Body).Decode(&loginResult)
loginResp.Body.Close()
userToken := loginResult.Data.AccessToken
// Bootstrap admin
t.Setenv("BOOTSTRAP_SECRET", fmt.Sprintf("cf-bootstrap-%d", id))
adminToken := bootstrapAdmin(server.URL, fmt.Sprintf("cf-bootstrap-%d", id), fmt.Sprintf("cfadmin_%d", id), fmt.Sprintf("cfa_%d@test.com", id), "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
return server, adminToken, userToken, func() {
server.Close()
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}
}
func TestCustomFieldHandler_CreateField(t *testing.T) {
server, adminToken, userToken, cleanup := setupCustomFieldTestServer(t)
defer cleanup()
tests := []struct {
name string
payload map[string]interface{}
token string
wantStatus int
}{
{
name: "success",
payload: map[string]interface{}{
"name": "Test Field",
"field_key": "test_field_create",
"type": 1,
},
token: adminToken,
wantStatus: http.StatusCreated,
},
{
name: "unauthorized",
payload: map[string]interface{}{
"name": "Test Field Unauth",
"field_key": "test_field_unauth",
"type": 1,
},
token: "",
wantStatus: http.StatusUnauthorized,
},
{
name: "forbidden",
payload: map[string]interface{}{
"name": "Test Field Forbidden",
"field_key": "test_field_forbidden",
"type": 1,
},
token: userToken,
wantStatus: http.StatusForbidden,
},
{
name: "missing_required_fields",
payload: map[string]interface{}{"name": "Missing Key"},
token: adminToken,
wantStatus: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doPost(server.URL+"/api/v1/custom-fields", tt.token, tt.payload)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestCustomFieldHandler_ListFields(t *testing.T) {
server, adminToken, userToken, cleanup := setupCustomFieldTestServer(t)
defer cleanup()
tests := []struct {
name string
token string
wantStatus int
}{
{
name: "success_admin",
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "forbidden_regular_user",
token: userToken,
wantStatus: http.StatusForbidden,
},
{
name: "unauthorized",
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doGet(server.URL+"/api/v1/custom-fields", tt.token)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestCustomFieldHandler_GetField(t *testing.T) {
server, adminToken, _, cleanup := setupCustomFieldTestServer(t)
defer cleanup()
// Create a field
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
"name": "Get Field Test",
"field_key": "test_field_get",
"type": 1,
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
}
var createResult map[string]interface{}
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
t.Fatalf("parse create response failed: %v", err)
}
fieldData := createResult["data"].(map[string]interface{})
fieldID := int64(fieldData["id"].(float64))
tests := []struct {
name string
fieldID string
token string
wantStatus int
}{
{
name: "success",
fieldID: fmt.Sprintf("%d", fieldID),
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "not_found",
fieldID: "99999",
token: adminToken,
wantStatus: http.StatusNotFound,
},
{
name: "invalid_id",
fieldID: "invalid",
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "unauthorized",
fieldID: fmt.Sprintf("%d", fieldID),
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doGet(server.URL+"/api/v1/custom-fields/"+tt.fieldID, tt.token)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestCustomFieldHandler_UpdateField(t *testing.T) {
server, adminToken, _, cleanup := setupCustomFieldTestServer(t)
defer cleanup()
// Create a field
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
"name": "Update Field Test",
"field_key": "test_field_update",
"type": 1,
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
}
var createResult map[string]interface{}
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
t.Fatalf("parse create response failed: %v", err)
}
fieldData := createResult["data"].(map[string]interface{})
fieldID := int64(fieldData["id"].(float64))
tests := []struct {
name string
fieldID string
payload map[string]interface{}
token string
wantStatus int
}{
{
name: "success",
fieldID: fmt.Sprintf("%d", fieldID),
payload: map[string]interface{}{
"name": "Updated Field Name",
},
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "invalid_id",
fieldID: "invalid",
payload: map[string]interface{}{
"name": "Updated Field Name",
},
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "unauthorized",
fieldID: fmt.Sprintf("%d", fieldID),
payload: map[string]interface{}{
"name": "Updated Field Name",
},
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doPut(server.URL+"/api/v1/custom-fields/"+tt.fieldID, tt.token, tt.payload)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestCustomFieldHandler_DeleteField(t *testing.T) {
server, adminToken, _, cleanup := setupCustomFieldTestServer(t)
defer cleanup()
// Create a field
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
"name": "Delete Field Test",
"field_key": "test_field_delete",
"type": 1,
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
}
var createResult map[string]interface{}
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
t.Fatalf("parse create response failed: %v", err)
}
fieldData := createResult["data"].(map[string]interface{})
fieldID := int64(fieldData["id"].(float64))
tests := []struct {
name string
fieldID string
token string
wantStatus int
}{
{
name: "success",
fieldID: fmt.Sprintf("%d", fieldID),
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "invalid_id",
fieldID: "invalid",
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "unauthorized",
fieldID: fmt.Sprintf("%d", fieldID),
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doDelete(server.URL+"/api/v1/custom-fields/"+tt.fieldID, tt.token)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestCustomFieldHandler_SetUserFieldValues(t *testing.T) {
server, adminToken, userToken, cleanup := setupCustomFieldTestServer(t)
defer cleanup()
// Create a field for the user to set
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
"name": "User Field Test",
"field_key": "user_field_test",
"type": 1,
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
}
tests := []struct {
name string
payload map[string]interface{}
token string
wantStatus int
}{
{
name: "success",
payload: map[string]interface{}{
"values": map[string]string{
"user_field_test": "123",
},
},
token: userToken,
wantStatus: http.StatusOK,
},
{
name: "unauthorized",
payload: map[string]interface{}{
"values": map[string]string{
"user_field_test": "test_value",
},
},
token: "",
wantStatus: http.StatusUnauthorized,
},
{
name: "missing_values",
payload: map[string]interface{}{},
token: userToken,
wantStatus: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doPut(server.URL+"/api/v1/users/me/custom-fields", tt.token, tt.payload)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestCustomFieldHandler_GetUserFieldValues(t *testing.T) {
server, adminToken, userToken, cleanup := setupCustomFieldTestServer(t)
defer cleanup()
// Create a field
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
"name": "User Field Get Test",
"field_key": "user_field_get_test",
"type": 1,
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
}
// Set a value first
setResp, setBody := doPut(server.URL+"/api/v1/users/me/custom-fields", userToken, map[string]interface{}{
"values": map[string]string{
"user_field_get_test": "456",
},
})
defer setResp.Body.Close()
if setResp.StatusCode != http.StatusOK {
t.Fatalf("set field value failed: %d %s", setResp.StatusCode, setBody)
}
tests := []struct {
name string
token string
wantStatus int
}{
{
name: "success",
token: userToken,
wantStatus: http.StatusOK,
},
{
name: "unauthorized",
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doGet(server.URL+"/api/v1/users/me/custom-fields", tt.token)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}

View File

@@ -0,0 +1,510 @@
package handler_test
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"testing"
)
func TestDeviceHandler_ListDevices(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicelistuser", "devicelist@test.com", "UserPass123!")
token := getToken(server.URL, "devicelistuser", "UserPass123!")
resp, body := doGet(server.URL+"/api/v1/devices", token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if result["code"] != float64(0) {
t.Errorf("expected code 0, got %v", result["code"])
}
}
func TestDeviceHandler_ListDevices_Unauthorized(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
resp, _ := doGet(server.URL+"/api/v1/devices", "")
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
}
func TestDeviceHandler_CreateDevice(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicecreateuser", "devicecreate@test.com", "UserPass123!")
token := getToken(server.URL, "devicecreateuser", "UserPass123!")
resp, body := doPost(server.URL+"/api/v1/devices", token, map[string]interface{}{
"name": "Test Device",
"device_id": "device-test-001",
"device_type": 3,
"device_os": "Windows 10",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusCreated, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if result["code"] != float64(0) {
t.Errorf("expected code 0, got %v", result["code"])
}
}
func TestDeviceHandler_CreateDevice_InvalidBody(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicecreatebad", "devicecreatebad@test.com", "UserPass123!")
token := getToken(server.URL, "devicecreatebad", "UserPass123!")
req, _ := http.NewRequest("POST", server.URL+"/api/v1/devices", bytes.NewReader([]byte("not json")))
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d for invalid body, got %d", http.StatusBadRequest, resp.StatusCode)
}
}
func TestDeviceHandler_GetDevice(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicegetuser", "deviceget@test.com", "UserPass123!")
token := getToken(server.URL, "devicegetuser", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-get-001", "Get Device")
resp, body := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if result["code"] != float64(0) {
t.Errorf("expected code 0, got %v", result["code"])
}
}
func TestDeviceHandler_GetDevice_NotFound(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicegetnf", "devicegetnf@test.com", "UserPass123!")
token := getToken(server.URL, "devicegetnf", "UserPass123!")
resp, body := doGet(server.URL+"/api/v1/devices/99999", token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Errorf("expected status %d, got %d, body: %s", http.StatusNotFound, resp.StatusCode, body)
}
}
func TestDeviceHandler_GetDevice_InvalidID(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicegetinv", "devicegetinv@test.com", "UserPass123!")
token := getToken(server.URL, "devicegetinv", "UserPass123!")
resp, body := doGet(server.URL+"/api/v1/devices/invalid", token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestDeviceHandler_UpdateDevice(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "deviceupdateuser", "deviceupdate@test.com", "UserPass123!")
token := getToken(server.URL, "deviceupdateuser", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-update-001", "Original Name")
resp, body := doPut(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), token, map[string]interface{}{
"device_name": "Updated Name",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if result["code"] != float64(0) {
t.Errorf("expected code 0, got %v", result["code"])
}
}
func TestDeviceHandler_UpdateDevice_NotFound(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "deviceupdatenf", "deviceupdatenf@test.com", "UserPass123!")
token := getToken(server.URL, "deviceupdatenf", "UserPass123!")
resp, body := doPut(server.URL+"/api/v1/devices/99999", token, map[string]interface{}{
"device_name": "Updated Name",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Errorf("expected status %d, got %d, body: %s", http.StatusNotFound, resp.StatusCode, body)
}
}
func TestDeviceHandler_DeleteDevice(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicedeluser", "devicedel@test.com", "UserPass123!")
token := getToken(server.URL, "devicedeluser", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-del-001", "Delete Device")
resp, body := doDelete(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
// Verify deletion
getResp, _ := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), token)
defer getResp.Body.Close()
if getResp.StatusCode != http.StatusNotFound {
t.Errorf("expected device to be deleted, got status %d", getResp.StatusCode)
}
}
func TestDeviceHandler_DeleteDevice_NotFound(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicedelnf", "devicedelnf@test.com", "UserPass123!")
token := getToken(server.URL, "devicedelnf", "UserPass123!")
resp, body := doDelete(server.URL+"/api/v1/devices/99999", token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusNotFound {
t.Errorf("expected status %d, got %d, body: %s", http.StatusNotFound, resp.StatusCode, body)
}
}
func TestDeviceHandler_UpdateDeviceStatus(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicestatususer", "devicestatus@test.com", "UserPass123!")
token := getToken(server.URL, "devicestatususer", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-status-001", "Status Device")
resp, body := doPut(fmt.Sprintf("%s/api/v1/devices/%d/status", server.URL, deviceID), token, map[string]interface{}{
"status": "inactive",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
}
func TestDeviceHandler_UpdateDeviceStatus_InvalidStatus(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicestatusinv", "devicestatusinv@test.com", "UserPass123!")
token := getToken(server.URL, "devicestatusinv", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-status-inv-001", "Status Device")
resp, body := doPut(fmt.Sprintf("%s/api/v1/devices/%d/status", server.URL, deviceID), token, map[string]interface{}{
"status": "invalid_status",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestDeviceHandler_TrustDevice(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicetrustuser", "devicetrust@test.com", "UserPass123!")
token := getToken(server.URL, "devicetrustuser", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-trust-001", "Trust Device")
resp, body := doPost(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), token, map[string]interface{}{
"trust_duration": "24h",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
}
func TestDeviceHandler_UntrustDevice(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "deviceuntrustuser", "deviceuntrust@test.com", "UserPass123!")
token := getToken(server.URL, "deviceuntrustuser", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-untrust-001", "Untrust Device")
// First trust the device
trustResp, trustBody := doPost(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), token, map[string]interface{}{
"trust_duration": "24h",
})
defer trustResp.Body.Close()
if trustResp.StatusCode != http.StatusOK {
t.Fatalf("expected trust status %d, got %d, body: %s", http.StatusOK, trustResp.StatusCode, trustBody)
}
// Then untrust
resp, body := doDelete(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
}
func TestDeviceHandler_GetMyTrustedDevices(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicetrusteduser", "devicetrusted@test.com", "UserPass123!")
token := getToken(server.URL, "devicetrusteduser", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-trusted-001", "Trusted Device")
// Trust the device first
trustResp, trustBody := doPost(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), token, map[string]interface{}{
"trust_duration": "24h",
})
defer trustResp.Body.Close()
if trustResp.StatusCode != http.StatusOK {
t.Fatalf("expected trust status %d, got %d, body: %s", http.StatusOK, trustResp.StatusCode, trustBody)
}
resp, body := doGet(server.URL+"/api/v1/devices/me/trusted", token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if result["code"] != float64(0) {
t.Errorf("expected code 0, got %v", result["code"])
}
}
func TestDeviceHandler_LogoutAllOtherDevices(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicelogoutuser", "devicelogout@test.com", "UserPass123!")
token := getToken(server.URL, "devicelogoutuser", "UserPass123!")
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-logout-001", "Logout Device")
req, _ := http.NewRequest("POST", server.URL+"/api/v1/devices/me/logout-others", nil)
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("X-Device-ID", fmt.Sprintf("%d", deviceID))
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := json.Marshal(resp.Body)
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
}
}
func TestDeviceHandler_LogoutAllOtherDevices_MissingDeviceID(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicelogoutbad", "devicelogoutbad@test.com", "UserPass123!")
token := getToken(server.URL, "devicelogoutbad", "UserPass123!")
resp, body := doPost(server.URL+"/api/v1/devices/me/logout-others", token, nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestDeviceHandler_GetUserDevices_AdminCanViewOthers(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "deviceadmin", "deviceadmin@test.com", "AdminPass123!")
registerUser(server.URL, "deviceuserview", "deviceuserview@test.com", "UserPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin should return access token")
}
resp, body := doGet(server.URL+"/api/v1/devices/users/2", adminToken)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
}
func TestDeviceHandler_GetUserDevices_NonAdminForbidden(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "deviceuser1", "deviceuser1@test.com", "UserPass123!")
registerUser(server.URL, "deviceuser2", "deviceuser2@test.com", "UserPass123!")
token := getToken(server.URL, "deviceuser1", "UserPass123!")
resp, body := doGet(server.URL+"/api/v1/devices/users/2", token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
}
}
func TestDeviceHandler_GetAllDevices_AdminOnly(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "deviceadmin2", "deviceadmin2@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin should return access token")
}
resp, body := doGet(server.URL+"/api/v1/admin/devices", adminToken)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
}
func TestDeviceHandler_GetAllDevices_NonAdminForbidden(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "deviceuser3", "deviceuser3@test.com", "UserPass123!")
token := getToken(server.URL, "deviceuser3", "UserPass123!")
resp, body := doGet(server.URL+"/api/v1/admin/devices", token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusForbidden {
t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
}
}
func TestDeviceHandler_TrustDeviceByDeviceID(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicetrustiduser", "devicetrustid@test.com", "UserPass123!")
token := getToken(server.URL, "devicetrustiduser", "UserPass123!")
// Create device with specific device_id
resp, body := doPost(server.URL+"/api/v1/devices", token, map[string]interface{}{
"name": "Trust By ID Device",
"device_id": "my-unique-device-id",
"device_type": 1,
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusCreated {
t.Fatalf("expected create status %d, got %d, body: %s", http.StatusCreated, resp.StatusCode, body)
}
// Trust by device ID
trustResp, trustBody := doPost(server.URL+"/api/v1/devices/by-device-id/my-unique-device-id/trust", token, map[string]interface{}{
"trust_duration": "24h",
})
defer trustResp.Body.Close()
if trustResp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, trustResp.StatusCode, trustBody)
}
}
func TestDeviceHandler_TrustDeviceByDeviceID_EmptyID(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "devicetrustidbad", "devicetrustidbad@test.com", "UserPass123!")
token := getToken(server.URL, "devicetrustidbad", "UserPass123!")
// The route uses ":deviceId" path param, so empty ID would be a different route or 404
// Actually the route is /by-device-id/:deviceId/trust, so empty deviceId is not matched
// Let's test with a device ID that doesn't exist
resp, body := doPost(server.URL+"/api/v1/devices/by-device-id/nonexistent/trust", token, map[string]interface{}{
"trust_duration": "24h",
})
defer resp.Body.Close()
// Service returns error for non-existent device
if resp.StatusCode != http.StatusNotFound && resp.StatusCode != http.StatusInternalServerError {
t.Errorf("expected status 404 or 500 for non-existent device, got %d, body: %s", resp.StatusCode, body)
}
}

View File

@@ -0,0 +1,319 @@
package handler_test
import (
"bytes"
"encoding/json"
"fmt"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/api/handler"
"github.com/user-management-system/internal/api/middleware"
"github.com/user-management-system/internal/api/router"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/cache"
"github.com/user-management-system/internal/config"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
"github.com/user-management-system/internal/service"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
)
var exportDbCounter int64
func setupExportTestServer(t *testing.T) (*httptest.Server, string, string, func()) {
t.Helper()
gin.SetMode(gin.TestMode)
id := atomic.AddInt64(&exportDbCounter, 1)
dsn := fmt.Sprintf("file:exportdb_%d_%s?mode=memory&cache=shared", id, t.Name())
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: dsn,
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Skipf("skipping export test (SQLite unavailable): %v", err)
return nil, "", "", func() {}
}
if err := db.AutoMigrate(
&domain.User{},
&domain.Role{},
&domain.Permission{},
&domain.UserRole{},
&domain.RolePermission{},
); err != nil {
t.Fatalf("db migration failed: %v", err)
}
seedHandlerAuthzData(t, db)
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
HS256Secret: "test-export-secret-key",
AccessTokenExpire: 15 * time.Minute,
RefreshTokenExpire: 7 * 24 * time.Hour,
})
if err != nil {
t.Fatalf("create jwt manager failed: %v", err)
}
l1Cache := cache.NewL1Cache()
l2Cache := cache.NewRedisCache(false)
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
userRepo := repository.NewUserRepository(db)
roleRepo := repository.NewRoleRepository(db)
userRoleRepo := repository.NewUserRoleRepository(db)
authSvc := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
exportSvc := service.NewExportService(userRepo, nil)
exportHandler := handler.NewExportHandler(exportSvc)
rateLimitCfg := config.RateLimitConfig{}
rateLimitMiddleware := middleware.NewRateLimitMiddleware(rateLimitCfg)
authMiddleware := middleware.NewAuthMiddleware(
jwtManager, userRepo, userRoleRepo, l1Cache,
)
authMiddleware.SetCacheManager(cacheManager)
authHandler := handler.NewAuthHandler(authSvc)
r := router.NewRouter(
authHandler, nil, nil, nil, nil, nil,
authMiddleware, rateLimitMiddleware, nil,
nil, nil, nil, nil,
nil, exportHandler, nil, nil, nil, nil, nil, nil, nil,
)
engine := r.Setup()
server := httptest.NewServer(engine)
// Register a regular user
regBody := map[string]interface{}{
"username": fmt.Sprintf("exportuser_%d", id),
"password": "TestPass123!",
"email": fmt.Sprintf("ex_%d@test.com", id),
}
regBytes, _ := json.Marshal(regBody)
regResp, _ := http.Post(server.URL+"/api/v1/auth/register", "application/json", bytes.NewReader(regBytes))
io.ReadAll(regResp.Body)
regResp.Body.Close()
// Login as regular user
loginBody := map[string]interface{}{
"account": regBody["username"],
"password": regBody["password"],
}
loginBytes, _ := json.Marshal(loginBody)
loginResp, _ := http.Post(server.URL+"/api/v1/auth/login", "application/json", bytes.NewReader(loginBytes))
var loginResult struct {
Data struct {
AccessToken string `json:"access_token"`
} `json:"data"`
}
json.NewDecoder(loginResp.Body).Decode(&loginResult)
loginResp.Body.Close()
userToken := loginResult.Data.AccessToken
// Bootstrap admin
t.Setenv("BOOTSTRAP_SECRET", fmt.Sprintf("export-bootstrap-%d", id))
adminToken := bootstrapAdmin(server.URL, fmt.Sprintf("export-bootstrap-%d", id), fmt.Sprintf("exportadmin_%d", id), fmt.Sprintf("exa_%d@test.com", id), "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
return server, adminToken, userToken, func() {
server.Close()
if sqlDB, err := db.DB(); err == nil {
sqlDB.Close()
}
}
}
func TestExportHandler_ExportUsers(t *testing.T) {
server, adminToken, userToken, cleanup := setupExportTestServer(t)
defer cleanup()
tests := []struct {
name string
query string
token string
wantStatus int
}{
{
name: "success_csv",
query: "format=csv",
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "success_excel",
query: "format=xlsx",
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "forbidden_regular_user",
query: "format=csv",
token: userToken,
wantStatus: http.StatusForbidden,
},
{
name: "unauthorized",
query: "format=csv",
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
url := server.URL + "/api/v1/admin/users/export"
if tt.query != "" {
url = url + "?" + tt.query
}
resp, body := doGet(url, tt.token)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestExportHandler_ImportUsers(t *testing.T) {
server, adminToken, userToken, cleanup := setupExportTestServer(t)
defer cleanup()
csvData := []byte("\xEF\xBB\xBF用户名,密码,邮箱,手机号,昵称,性别,地区,个人简介\nimportuser1,Password123!,import1@test.com,13800138001,Import1,男,北京,简介1\n")
tests := []struct {
name string
fileBody []byte
filename string
token string
wantStatus int
}{
{
name: "success_csv",
fileBody: csvData,
filename: "users.csv",
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "forbidden_regular_user",
fileBody: csvData,
filename: "users.csv",
token: userToken,
wantStatus: http.StatusForbidden,
},
{
name: "unauthorized",
fileBody: csvData,
filename: "users.csv",
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var body bytes.Buffer
writer := multipart.NewWriter(&body)
part, err := writer.CreateFormFile("file", tt.filename)
if err != nil {
t.Fatalf("create form file failed: %v", err)
}
if _, err := part.Write(tt.fileBody); err != nil {
t.Fatalf("write file body failed: %v", err)
}
if err := writer.Close(); err != nil {
t.Fatalf("close multipart writer failed: %v", err)
}
req, err := http.NewRequest(http.MethodPost, server.URL+"/api/v1/admin/users/import", &body)
if err != nil {
t.Fatalf("create request failed: %v", err)
}
if tt.token != "" {
req.Header.Set("Authorization", "Bearer "+tt.token)
}
req.Header.Set("Content-Type", writer.FormDataContentType())
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
respBody, _ := io.ReadAll(resp.Body)
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, string(respBody))
}
})
}
}
func TestExportHandler_GetImportTemplate(t *testing.T) {
server, adminToken, userToken, cleanup := setupExportTestServer(t)
defer cleanup()
tests := []struct {
name string
query string
token string
wantStatus int
}{
{
name: "success_csv",
query: "format=csv",
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "success_excel",
query: "format=xlsx",
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "forbidden_regular_user",
query: "format=csv",
token: userToken,
wantStatus: http.StatusForbidden,
},
{
name: "unauthorized",
query: "format=csv",
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
url := server.URL + "/api/v1/admin/users/import/template"
if tt.query != "" {
url = url + "?" + tt.query
}
resp, body := doGet(url, tt.token)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}

View File

@@ -0,0 +1,308 @@
package handler_test
import (
"bytes"
"encoding/json"
"net/http"
"testing"
)
func TestPasswordResetHandler_ForgotPassword(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "resetuser", "resetuser@test.com", "UserPass123!")
resp, body := doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
"email": "resetuser@test.com",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if result["code"] != float64(0) {
t.Errorf("expected code 0, got %v", result["code"])
}
}
func TestPasswordResetHandler_ForgotPassword_MissingEmail(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
resp, body := doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{})
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestPasswordResetHandler_ForgotPassword_NonExistentEmail(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
// For non-existent email, the service returns success to prevent user enumeration
resp, body := doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
"email": "nonexistent@test.com",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Errorf("expected status %d for non-existent email, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
}
func TestPasswordResetHandler_ValidateResetToken(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "validatetokenuser", "validatetoken@test.com", "UserPass123!")
// First request a password reset to generate a token
_, _ = doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
"email": "validatetoken@test.com",
})
// We can't easily get the token from email, so test with an invalid token
resp, body := doPost(server.URL+"/api/v1/auth/password/validate", "", map[string]interface{}{
"token": "invalid-token-12345",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if result["code"] != float64(0) {
t.Errorf("expected code 0, got %v", result["code"])
}
data, ok := result["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected data in response, got %s", body)
}
if data["valid"] != false {
t.Errorf("expected valid=false for invalid token, got %v", data["valid"])
}
}
func TestPasswordResetHandler_ValidateResetToken_MissingToken(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
resp, body := doPost(server.URL+"/api/v1/auth/password/validate", "", map[string]interface{}{})
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestPasswordResetHandler_ResetPassword(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "resetpwuser", "resetpw@test.com", "UserPass123!")
// Request reset to generate token
_, _ = doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
"email": "resetpw@test.com",
})
// Since we can't get the token, test with invalid token
resp, body := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
"token": "invalid-token",
"new_password": "NewPass123!",
})
defer resp.Body.Close()
// Should fail because token is invalid (service returns 404 for "不存在")
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusBadRequest && resp.StatusCode != http.StatusNotFound {
t.Errorf("expected status 401, 400 or 404 for invalid token, got %d, body: %s", resp.StatusCode, body)
}
}
func TestPasswordResetHandler_ResetPassword_MissingToken(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
resp, body := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
"new_password": "NewPass123!",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestPasswordResetHandler_ResetPassword_MissingPassword(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
resp, body := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
"token": "some-token",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestPasswordResetHandler_ResetPassword_WeakPassword(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "resetpwweak", "resetpwweak@test.com", "UserPass123!")
// We need a valid token to test weak password rejection
// Let's manually create one through the cache by using forgot-password
_, _ = doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
"email": "resetpwweak@test.com",
})
// Use invalid token - the validation happens before password strength check
resp, body := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
"token": "invalid-token",
"new_password": "123",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusBadRequest && resp.StatusCode != http.StatusNotFound {
t.Errorf("expected status 401, 400 or 404, got %d, body: %s", resp.StatusCode, body)
}
}
func TestPasswordResetHandler_ForgotPasswordByPhone_ServiceUnavailable(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
// The password reset handler in the test setup does not have SMS service configured
resp, body := doPost(server.URL+"/api/v1/auth/forgot-password/phone", "", map[string]interface{}{
"phone": "13800138000",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusServiceUnavailable {
t.Errorf("expected status %d, got %d, body: %s", http.StatusServiceUnavailable, resp.StatusCode, body)
}
}
func TestPasswordResetHandler_ResetPasswordByPhone_MissingFields(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
resp, body := doPost(server.URL+"/api/v1/auth/reset-password/phone", "", map[string]interface{}{})
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestPasswordResetHandler_ResetPasswordByPhone_InvalidCode(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "resetphoneuser", "resetphone@test.com", "UserPass123!")
resp, body := doPost(server.URL+"/api/v1/auth/reset-password/phone", "", map[string]interface{}{
"phone": "13800138000",
"code": "000000",
"new_password": "NewPass123!",
})
defer resp.Body.Close()
// Should fail because no code was sent
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status 401 or 400 for invalid code, got %d, body: %s", resp.StatusCode, body)
}
}
func TestPasswordResetHandler_ForgotPassword_InvalidJSON(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
req, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/forgot-password", bytes.NewReader([]byte("not json")))
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d for invalid JSON, got %d", http.StatusBadRequest, resp.StatusCode)
}
}
func TestPasswordResetHandler_FullFlow(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "fullflowuser", "fullflow@test.com", "UserPass123!")
// Step 1: Request password reset
forgotResp, forgotBody := doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
"email": "fullflow@test.com",
})
defer forgotResp.Body.Close()
if forgotResp.StatusCode != http.StatusOK {
t.Fatalf("forgot-password failed: status=%d body=%s", forgotResp.StatusCode, forgotBody)
}
// Step 2: Validate token (we don't know the real token, so it will be invalid)
validateResp, validateBody := doPost(server.URL+"/api/v1/auth/password/validate", "", map[string]interface{}{
"token": "unknown-token",
})
defer validateResp.Body.Close()
if validateResp.StatusCode != http.StatusOK {
t.Fatalf("validate token failed: status=%d body=%s", validateResp.StatusCode, validateBody)
}
var validateResult map[string]interface{}
if err := json.Unmarshal([]byte(validateBody), &validateResult); err != nil {
t.Fatalf("failed to parse validate response: %v", err)
}
validateData, ok := validateResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected validate data, got %s", validateBody)
}
if validateData["valid"] != false {
t.Errorf("expected valid=false for unknown token, got %v", validateData["valid"])
}
// Step 3: Try reset with invalid token
resetResp, resetBody := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
"token": "unknown-token",
"new_password": "NewPass123!",
})
defer resetResp.Body.Close()
// Should fail because token is invalid (service returns 404 for "不存在")
if resetResp.StatusCode != http.StatusUnauthorized && resetResp.StatusCode != http.StatusNotFound {
t.Errorf("expected status 401 or 404 for invalid token reset, got %d, body: %s", resetResp.StatusCode, resetBody)
}
// Step 4: Verify old password still works
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "fullflowuser",
"password": "UserPass123!",
})
defer loginResp.Body.Close()
if loginResp.StatusCode != http.StatusOK {
t.Fatalf("old password should still work: status=%d body=%s", loginResp.StatusCode, loginBody)
}
}

View File

@@ -0,0 +1,455 @@
package handler_test
import (
"encoding/json"
"fmt"
"net/http"
"testing"
)
func TestPermissionHandler_CreatePermission(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
if ok := registerUser(server.URL, "permuser", "permuser@test.com", "UserPass123!"); !ok {
t.Fatal("register user failed")
}
userToken := getToken(server.URL, "permuser", "UserPass123!")
if userToken == "" {
t.Fatal("get user token failed")
}
tests := []struct {
name string
payload map[string]interface{}
token string
wantStatus int
}{
{
name: "success",
payload: map[string]interface{}{
"name": "Test Permission",
"code": "test:permission:create",
"type": 2,
},
token: adminToken,
wantStatus: http.StatusCreated,
},
{
name: "unauthorized",
payload: map[string]interface{}{
"name": "Test Permission",
"code": "test:permission:unauth",
"type": 2,
},
token: "",
wantStatus: http.StatusUnauthorized,
},
{
name: "forbidden",
payload: map[string]interface{}{
"name": "Test Permission",
"code": "test:permission:forbid",
"type": 2,
},
token: userToken,
wantStatus: http.StatusForbidden,
},
{
name: "invalid_type",
payload: map[string]interface{}{
"name": "Test Permission",
"code": "test:permission:badtype",
"type": 5,
},
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "missing_required_fields",
payload: map[string]interface{}{"name": "Missing Code"},
token: adminToken,
wantStatus: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doPost(server.URL+"/api/v1/permissions", tt.token, tt.payload)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestPermissionHandler_ListPermissions(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
if ok := registerUser(server.URL, "permuser", "permuser@test.com", "UserPass123!"); !ok {
t.Fatal("register user failed")
}
userToken := getToken(server.URL, "permuser", "UserPass123!")
if userToken == "" {
t.Fatal("get user token failed")
}
tests := []struct {
name string
token string
wantStatus int
}{
{
name: "success_admin",
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "forbidden_regular_user",
token: userToken,
wantStatus: http.StatusForbidden,
},
{
name: "unauthorized",
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doGet(server.URL+"/api/v1/permissions", tt.token)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestPermissionHandler_GetPermission(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
// Create a permission to retrieve
createResp, createBody := doPost(server.URL+"/api/v1/permissions", adminToken, map[string]interface{}{
"name": "Get Permission Test",
"code": "test:permission:get",
"type": 2,
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create permission failed: %d %s", createResp.StatusCode, createBody)
}
var createResult map[string]interface{}
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
t.Fatalf("parse create response failed: %v", err)
}
permData, ok := createResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected data in create response, got %s", createBody)
}
permID := int64(permData["id"].(float64))
tests := []struct {
name string
permID string
token string
wantStatus int
}{
{
name: "success",
permID: fmt.Sprintf("%d", permID),
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "not_found",
permID: "99999",
token: adminToken,
wantStatus: http.StatusNotFound,
},
{
name: "invalid_id",
permID: "invalid",
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "unauthorized",
permID: fmt.Sprintf("%d", permID),
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doGet(server.URL+"/api/v1/permissions/"+tt.permID, tt.token)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestPermissionHandler_UpdatePermission(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
// Create a permission to update
createResp, createBody := doPost(server.URL+"/api/v1/permissions", adminToken, map[string]interface{}{
"name": "Update Permission Test",
"code": "test:permission:update",
"type": 2,
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create permission failed: %d %s", createResp.StatusCode, createBody)
}
var createResult map[string]interface{}
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
t.Fatalf("parse create response failed: %v", err)
}
permData := createResult["data"].(map[string]interface{})
permID := int64(permData["id"].(float64))
tests := []struct {
name string
permID string
payload map[string]interface{}
token string
wantStatus int
}{
{
name: "success",
permID: fmt.Sprintf("%d", permID),
payload: map[string]interface{}{
"name": "Updated Permission Name",
},
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "invalid_id",
permID: "invalid",
payload: map[string]interface{}{
"name": "Updated Permission Name",
},
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "unauthorized",
permID: fmt.Sprintf("%d", permID),
payload: map[string]interface{}{
"name": "Updated Permission Name",
},
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doPut(server.URL+"/api/v1/permissions/"+tt.permID, tt.token, tt.payload)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestPermissionHandler_DeletePermission(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
// Create a permission to delete
createResp, createBody := doPost(server.URL+"/api/v1/permissions", adminToken, map[string]interface{}{
"name": "Delete Permission Test",
"code": "test:permission:delete",
"type": 2,
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create permission failed: %d %s", createResp.StatusCode, createBody)
}
var createResult map[string]interface{}
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
t.Fatalf("parse create response failed: %v", err)
}
permData := createResult["data"].(map[string]interface{})
permID := int64(permData["id"].(float64))
tests := []struct {
name string
permID string
token string
wantStatus int
}{
{
name: "success",
permID: fmt.Sprintf("%d", permID),
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "invalid_id",
permID: "invalid",
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "unauthorized",
permID: fmt.Sprintf("%d", permID),
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doDelete(server.URL+"/api/v1/permissions/"+tt.permID, tt.token)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestPermissionHandler_UpdatePermissionStatus(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
// Create a permission
createResp, createBody := doPost(server.URL+"/api/v1/permissions", adminToken, map[string]interface{}{
"name": "Status Permission Test",
"code": "test:permission:status",
"type": 2,
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create permission failed: %d %s", createResp.StatusCode, createBody)
}
var createResult map[string]interface{}
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
t.Fatalf("parse create response failed: %v", err)
}
permData := createResult["data"].(map[string]interface{})
permID := int64(permData["id"].(float64))
tests := []struct {
name string
permID string
payload map[string]interface{}
token string
wantStatus int
}{
{
name: "success_numeric",
permID: fmt.Sprintf("%d", permID),
payload: map[string]interface{}{
"status": 0,
},
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "invalid_id",
permID: "invalid",
payload: map[string]interface{}{
"status": 0,
},
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "unauthorized",
permID: fmt.Sprintf("%d", permID),
payload: map[string]interface{}{
"status": 0,
},
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doPut(server.URL+"/api/v1/permissions/"+tt.permID+"/status", tt.token, tt.payload)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestPermissionHandler_GetPermissionTree(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
resp, body := doGet(server.URL+"/api/v1/permissions/tree", adminToken)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("parse response failed: %v", err)
}
if result["code"] != float64(0) {
t.Errorf("expected code 0, got %v", result["code"])
}
if result["data"] == nil {
t.Errorf("expected data in response")
}
}

View File

@@ -0,0 +1,527 @@
package handler_test
import (
"encoding/json"
"fmt"
"net/http"
"testing"
)
func TestRoleHandler_CreateRole(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
if ok := registerUser(server.URL, "roleuser", "roleuser@test.com", "UserPass123!"); !ok {
t.Fatal("register user failed")
}
userToken := getToken(server.URL, "roleuser", "UserPass123!")
if userToken == "" {
t.Fatal("get user token failed")
}
tests := []struct {
name string
payload map[string]interface{}
token string
wantStatus int
}{
{
name: "success",
payload: map[string]interface{}{
"name": "Test Role",
"code": "test_role_create",
},
token: adminToken,
wantStatus: http.StatusCreated,
},
{
name: "unauthorized",
payload: map[string]interface{}{
"name": "Test Role Unauth",
"code": "test_role_unauth",
},
token: "",
wantStatus: http.StatusUnauthorized,
},
{
name: "forbidden",
payload: map[string]interface{}{
"name": "Test Role Forbidden",
"code": "test_role_forbidden",
},
token: userToken,
wantStatus: http.StatusForbidden,
},
{
name: "missing_required_fields",
payload: map[string]interface{}{"name": "Missing Code"},
token: adminToken,
wantStatus: http.StatusBadRequest,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doPost(server.URL+"/api/v1/roles", tt.token, tt.payload)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestRoleHandler_ListRoles(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
if ok := registerUser(server.URL, "roleuser", "roleuser@test.com", "UserPass123!"); !ok {
t.Fatal("register user failed")
}
userToken := getToken(server.URL, "roleuser", "UserPass123!")
if userToken == "" {
t.Fatal("get user token failed")
}
tests := []struct {
name string
token string
wantStatus int
}{
{
name: "success_admin",
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "forbidden_regular_user",
token: userToken,
wantStatus: http.StatusForbidden,
},
{
name: "unauthorized",
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doGet(server.URL+"/api/v1/roles", tt.token)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestRoleHandler_GetRole(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
// Create a role to retrieve
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
"name": "Get Role Test",
"code": "test_role_get",
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
}
var createResult map[string]interface{}
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
t.Fatalf("parse create response failed: %v", err)
}
roleData := createResult["data"].(map[string]interface{})
roleID := int64(roleData["id"].(float64))
tests := []struct {
name string
roleID string
token string
wantStatus int
}{
{
name: "success",
roleID: fmt.Sprintf("%d", roleID),
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "not_found",
roleID: "99999",
token: adminToken,
wantStatus: http.StatusNotFound,
},
{
name: "invalid_id",
roleID: "invalid",
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "unauthorized",
roleID: fmt.Sprintf("%d", roleID),
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doGet(server.URL+"/api/v1/roles/"+tt.roleID, tt.token)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestRoleHandler_UpdateRole(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
// Create a role to update
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
"name": "Update Role Test",
"code": "test_role_update",
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
}
var createResult map[string]interface{}
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
t.Fatalf("parse create response failed: %v", err)
}
roleData := createResult["data"].(map[string]interface{})
roleID := int64(roleData["id"].(float64))
tests := []struct {
name string
roleID string
payload map[string]interface{}
token string
wantStatus int
}{
{
name: "success",
roleID: fmt.Sprintf("%d", roleID),
payload: map[string]interface{}{
"name": "Updated Role Name",
},
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "invalid_id",
roleID: "invalid",
payload: map[string]interface{}{
"name": "Updated Role Name",
},
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "unauthorized",
roleID: fmt.Sprintf("%d", roleID),
payload: map[string]interface{}{
"name": "Updated Role Name",
},
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doPut(server.URL+"/api/v1/roles/"+tt.roleID, tt.token, tt.payload)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestRoleHandler_DeleteRole(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
// Create a role to delete
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
"name": "Delete Role Test",
"code": "test_role_delete",
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
}
var createResult map[string]interface{}
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
t.Fatalf("parse create response failed: %v", err)
}
roleData := createResult["data"].(map[string]interface{})
roleID := int64(roleData["id"].(float64))
tests := []struct {
name string
roleID string
token string
wantStatus int
}{
{
name: "success",
roleID: fmt.Sprintf("%d", roleID),
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "invalid_id",
roleID: "invalid",
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "unauthorized",
roleID: fmt.Sprintf("%d", roleID),
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doDelete(server.URL+"/api/v1/roles/"+tt.roleID, tt.token)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestRoleHandler_UpdateRoleStatus(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
// Create a role
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
"name": "Status Role Test",
"code": "test_role_status",
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
}
var createResult map[string]interface{}
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
t.Fatalf("parse create response failed: %v", err)
}
roleData := createResult["data"].(map[string]interface{})
roleID := int64(roleData["id"].(float64))
tests := []struct {
name string
roleID string
payload map[string]interface{}
token string
wantStatus int
}{
{
name: "success_disabled",
roleID: fmt.Sprintf("%d", roleID),
payload: map[string]interface{}{
"status": "disabled",
},
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "success_enabled",
roleID: fmt.Sprintf("%d", roleID),
payload: map[string]interface{}{
"status": "enabled",
},
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "invalid_status",
roleID: fmt.Sprintf("%d", roleID),
payload: map[string]interface{}{
"status": "invalid_status",
},
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "invalid_id",
roleID: "invalid",
payload: map[string]interface{}{
"status": "disabled",
},
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "unauthorized",
roleID: fmt.Sprintf("%d", roleID),
payload: map[string]interface{}{
"status": "disabled",
},
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doPut(server.URL+"/api/v1/roles/"+tt.roleID+"/status", tt.token, tt.payload)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}
func TestRoleHandler_GetRolePermissions(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
// Use the admin role (id=1) for testing
resp, body := doGet(server.URL+"/api/v1/roles/1/permissions", adminToken)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("parse response failed: %v", err)
}
if result["code"] != float64(0) {
t.Errorf("expected code 0, got %v", result["code"])
}
if result["data"] == nil {
t.Errorf("expected data in response")
}
}
func TestRoleHandler_AssignPermissions(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
if adminToken == "" {
t.Fatal("bootstrap admin failed")
}
// Create a role
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
"name": "Assign Perm Role Test",
"code": "test_role_assign_perm",
})
defer createResp.Body.Close()
if createResp.StatusCode != http.StatusCreated {
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
}
var createResult map[string]interface{}
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
t.Fatalf("parse create response failed: %v", err)
}
roleData := createResult["data"].(map[string]interface{})
roleID := int64(roleData["id"].(float64))
tests := []struct {
name string
roleID string
payload map[string]interface{}
token string
wantStatus int
}{
{
name: "success",
roleID: fmt.Sprintf("%d", roleID),
payload: map[string]interface{}{
"permission_ids": []int64{1, 2},
},
token: adminToken,
wantStatus: http.StatusOK,
},
{
name: "invalid_id",
roleID: "invalid",
payload: map[string]interface{}{
"permission_ids": []int64{1},
},
token: adminToken,
wantStatus: http.StatusBadRequest,
},
{
name: "unauthorized",
roleID: fmt.Sprintf("%d", roleID),
payload: map[string]interface{}{
"permission_ids": []int64{1},
},
token: "",
wantStatus: http.StatusUnauthorized,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, body := doPut(server.URL+"/api/v1/roles/"+tt.roleID+"/permissions", tt.token, tt.payload)
defer resp.Body.Close()
if resp.StatusCode != tt.wantStatus {
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
}
})
}
}

View File

@@ -0,0 +1,855 @@
package handler_test
import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/api/handler"
"github.com/user-management-system/internal/auth"
)
func doPostForm(targetURL, token string, data url.Values) (*http.Response, string) {
var bodyReader io.Reader
if data != nil {
bodyReader = strings.NewReader(data.Encode())
}
req, _ := http.NewRequest("POST", targetURL, bodyReader)
if token != "" {
req.Header.Set("Authorization", "Bearer "+token)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
client := &http.Client{}
resp, _ := client.Do(req)
bodyBytes, _ := io.ReadAll(resp.Body)
resp.Body.Close()
return resp, string(bodyBytes)
}
func setupSSOTestServer(t *testing.T) (*httptest.Server, func()) {
t.Helper()
gin.SetMode(gin.TestMode)
engine := gin.New()
engine.Use(gin.Recovery())
ssoManager := auth.NewSSOManager()
clientsStore := auth.NewDefaultSSOClientsStore()
clientsStore.RegisterClient(&auth.SSOClient{
ClientID: "test-client",
ClientSecret: "test-secret",
Name: "Test Client",
RedirectURIs: []string{"http://localhost:8080/callback"},
})
ssoHandler := handler.NewSSOHandler(ssoManager, clientsStore)
// Simple auth middleware for testing
authMiddleware := func() gin.HandlerFunc {
return func(c *gin.Context) {
token := c.GetHeader("Authorization")
if token == "" || token == "Bearer " {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
return
}
c.Set("user_id", int64(1))
c.Set("username", "testuser")
c.Next()
}
}()
ssoGroup := engine.Group("/api/v1/sso")
ssoGroup.Use(authMiddleware)
{
ssoGroup.GET("/authorize", ssoHandler.Authorize)
ssoGroup.POST("/token", ssoHandler.Token)
ssoGroup.POST("/introspect", ssoHandler.Introspect)
ssoGroup.POST("/revoke", ssoHandler.Revoke)
ssoGroup.GET("/userinfo", ssoHandler.UserInfo)
}
server := httptest.NewServer(engine)
return server, func() {
server.Close()
}
}
func TestSSOHandler_Authorize_MissingParams(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, body := doGet(server.URL+"/api/v1/sso/authorize", "Bearer test-token")
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestSSOHandler_Authorize_UnsupportedResponseType(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, body := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=unsupported", "Bearer test-token")
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestSSOHandler_Authorize_Unauthorized(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "")
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
}
func TestSSOHandler_Authorize_CodeFlow(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code&state=xyz", "Bearer test-token")
defer resp.Body.Close()
if resp.StatusCode != http.StatusFound {
t.Fatalf("expected status %d (redirect), got %d", http.StatusFound, resp.StatusCode)
}
location := resp.Header.Get("Location")
if location == "" {
t.Fatal("expected redirect location")
}
if !strings.Contains(location, "code=") {
t.Errorf("expected redirect with code, got %s", location)
}
if !strings.Contains(location, "state=xyz") {
t.Errorf("expected redirect with state, got %s", location)
}
}
func TestSSOHandler_Authorize_InvalidRedirectURI(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, body := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://evil.com/callback&response_type=code", "Bearer test-token")
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestSSOHandler_Authorize_TokenFlow(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=token&state=abc", "Bearer test-token")
defer resp.Body.Close()
if resp.StatusCode != http.StatusFound {
t.Fatalf("expected status %d (redirect), got %d", http.StatusFound, resp.StatusCode)
}
location := resp.Header.Get("Location")
if location == "" {
t.Fatal("expected redirect location")
}
if !strings.Contains(location, "access_token=") {
t.Errorf("expected redirect with access_token, got %s", location)
}
}
func TestSSOHandler_Token_MissingParams(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", nil)
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestSSOHandler_Token_InvalidGrantType(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
formData := url.Values{}
formData.Set("grant_type", "password")
formData.Set("client_id", "test-client")
formData.Set("client_secret", "test-secret")
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestSSOHandler_Token_InvalidClient(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("code", "some-code")
formData.Set("client_id", "invalid-client")
formData.Set("client_secret", "wrong-secret")
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
}
}
func TestSSOHandler_Token_InvalidCode(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("code", "invalid-code")
formData.Set("client_id", "test-client")
formData.Set("client_secret", "test-secret")
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
}
}
func TestSSOHandler_Token_Success(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
// First authorize to get a code
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
defer authResp.Body.Close()
if authResp.StatusCode != http.StatusFound {
t.Fatalf("expected authorize redirect, got %d", authResp.StatusCode)
}
location := authResp.Header.Get("Location")
parsedURL, err := url.Parse(location)
if err != nil {
t.Fatalf("failed to parse redirect URL: %v", err)
}
code := parsedURL.Query().Get("code")
if code == "" {
t.Fatal("expected authorization code in redirect")
}
// Exchange code for token
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("code", code)
formData.Set("client_id", "test-client")
formData.Set("client_secret", "test-secret")
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var tokenResp handler.TokenResponse
if err := json.Unmarshal([]byte(body), &tokenResp); err != nil {
t.Fatalf("failed to parse token response: %v", err)
}
if tokenResp.AccessToken == "" {
t.Errorf("expected access_token in response")
}
if tokenResp.TokenType != "Bearer" {
t.Errorf("expected token_type Bearer, got %s", tokenResp.TokenType)
}
}
func TestSSOHandler_Introspect_MissingToken(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, body := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{})
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestSSOHandler_Introspect_InvalidToken(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, body := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
"token": "invalid-token",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result handler.IntrospectResponse
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse introspect response: %v", err)
}
if result.Active != false {
t.Errorf("expected active=false for invalid token, got %v", result.Active)
}
}
func TestSSOHandler_Introspect_ValidToken(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
// Authorize and get token
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
defer authResp.Body.Close()
location := authResp.Header.Get("Location")
parsedURL, _ := url.Parse(location)
code := parsedURL.Query().Get("code")
tokenForm := url.Values{}
tokenForm.Set("grant_type", "authorization_code")
tokenForm.Set("code", code)
tokenForm.Set("client_id", "test-client")
tokenForm.Set("client_secret", "test-secret")
tokenResp, tokenBody := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", tokenForm)
defer tokenResp.Body.Close()
if tokenResp.StatusCode != http.StatusOK {
t.Fatalf("token exchange failed: status=%d body=%s", tokenResp.StatusCode, tokenBody)
}
var tokenResult handler.TokenResponse
if err := json.Unmarshal([]byte(tokenBody), &tokenResult); err != nil {
t.Fatalf("failed to parse token response: %v", err)
}
// Introspect the token
resp, body := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
"token": tokenResult.AccessToken,
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result handler.IntrospectResponse
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse introspect response: %v", err)
}
if result.Active != true {
t.Errorf("expected active=true for valid token, got %v", result.Active)
}
if result.UserID != 1 {
t.Errorf("expected user_id=1, got %d", result.UserID)
}
}
func TestSSOHandler_Revoke_MissingToken(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, body := doPost(server.URL+"/api/v1/sso/revoke", "Bearer test-token", map[string]interface{}{})
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestSSOHandler_Revoke_Success(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
// Authorize and get token
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
defer authResp.Body.Close()
location := authResp.Header.Get("Location")
parsedURL, _ := url.Parse(location)
code := parsedURL.Query().Get("code")
tokenForm := url.Values{}
tokenForm.Set("grant_type", "authorization_code")
tokenForm.Set("code", code)
tokenForm.Set("client_id", "test-client")
tokenForm.Set("client_secret", "test-secret")
tokenResp, tokenBody := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", tokenForm)
defer tokenResp.Body.Close()
if tokenResp.StatusCode != http.StatusOK {
t.Fatalf("token exchange failed: status=%d body=%s", tokenResp.StatusCode, tokenBody)
}
var tokenResult handler.TokenResponse
if err := json.Unmarshal([]byte(tokenBody), &tokenResult); err != nil {
t.Fatalf("failed to parse token response: %v", err)
}
// Revoke the token
resp, body := doPost(server.URL+"/api/v1/sso/revoke", "Bearer test-token", map[string]interface{}{
"token": tokenResult.AccessToken,
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
// Verify token is revoked
introspectResp, introspectBody := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
"token": tokenResult.AccessToken,
})
defer introspectResp.Body.Close()
if introspectResp.StatusCode != http.StatusOK {
t.Fatalf("introspect failed: status=%d body=%s", introspectResp.StatusCode, introspectBody)
}
var introspectResult handler.IntrospectResponse
if err := json.Unmarshal([]byte(introspectBody), &introspectResult); err != nil {
t.Fatalf("failed to parse introspect response: %v", err)
}
if introspectResult.Active != false {
t.Errorf("expected active=false after revoke, got %v", introspectResult.Active)
}
}
func TestSSOHandler_UserInfo_Unauthorized(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, _ := doGet(server.URL+"/api/v1/sso/userinfo", "")
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
}
func TestSSOHandler_UserInfo_Success(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, body := doGet(server.URL+"/api/v1/sso/userinfo", "Bearer test-token")
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if result["code"] != float64(0) {
t.Errorf("expected code 0, got %v", result["code"])
}
data, ok := result["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected data in response, got %s", body)
}
if data["user_id"] != float64(1) {
t.Errorf("expected user_id=1, got %v", data["user_id"])
}
if data["username"] != "testuser" {
t.Errorf("expected username=testuser, got %v", data["username"])
}
}
func TestSSOHandler_Token_InvalidClientSecret(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
// Authorize to get a code
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
defer authResp.Body.Close()
location := authResp.Header.Get("Location")
parsedURL, _ := url.Parse(location)
code := parsedURL.Query().Get("code")
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("code", code)
formData.Set("client_id", "test-client")
formData.Set("client_secret", "wrong-secret")
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
}
}
func TestSSOHandler_Authorize_MissingClientID(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, body := doGet(server.URL+"/api/v1/sso/authorize?redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestSSOHandler_Introspect_FormData(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
// Test that introspect accepts form-encoded data
formData := url.Values{}
formData.Set("token", "some-token")
req, _ := http.NewRequest("POST", server.URL+"/api/v1/sso/introspect", strings.NewReader(formData.Encode()))
req.Header.Set("Authorization", "Bearer test-token")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := json.Marshal(resp.Body)
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
}
}
func TestSSOHandler_Token_FormData(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
// Authorize to get a code
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
defer authResp.Body.Close()
location := authResp.Header.Get("Location")
parsedURL, _ := url.Parse(location)
code := parsedURL.Query().Get("code")
// Test that token accepts form-encoded data
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("code", code)
formData.Set("client_id", "test-client")
formData.Set("client_secret", "test-secret")
req, _ := http.NewRequest("POST", server.URL+"/api/v1/sso/token", strings.NewReader(formData.Encode()))
req.Header.Set("Authorization", "Bearer test-token")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
bodyBytes, _ := json.Marshal(resp.Body)
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
}
}
func TestSSOHandler_Revoke_FormData(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
formData := url.Values{}
formData.Set("token", "some-token")
req, _ := http.NewRequest("POST", server.URL+"/api/v1/sso/revoke", strings.NewReader(formData.Encode()))
req.Header.Set("Authorization", "Bearer test-token")
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
bodyBytes, _ := json.Marshal(resp.Body)
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
}
}
func TestSSOHandler_Authorize_UnknownClientID(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, body := doGet(server.URL+"/api/v1/sso/authorize?client_id=unknown-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
defer resp.Body.Close()
// When client is unknown, redirect_uri validation fails
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestSSOHandler_Token_WithoutAuth(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("code", "some-code")
formData.Set("client_id", "test-client")
formData.Set("client_secret", "test-secret")
resp, _ := doPostForm(server.URL+"/api/v1/sso/token", "", formData)
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
}
func TestSSOHandler_UserInfo_WithoutAuth(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, _ := doGet(server.URL+"/api/v1/sso/userinfo", "")
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
}
func TestSSOHandler_Introspect_WithoutAuth(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, _ := doPost(server.URL+"/api/v1/sso/introspect", "", map[string]interface{}{
"token": "some-token",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
}
func TestSSOHandler_Revoke_WithoutAuth(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
resp, _ := doPost(server.URL+"/api/v1/sso/revoke", "", map[string]interface{}{
"token": "some-token",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
}
func TestSSOHandler_Authorize_InvalidClientID(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
// Test with valid redirect URI but unknown client
resp, body := doGet(server.URL+"/api/v1/sso/authorize?client_id=unknown&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestSSOHandler_Token_MissingCode(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
formData := url.Values{}
formData.Set("grant_type", "authorization_code")
formData.Set("client_id", "test-client")
formData.Set("client_secret", "test-secret")
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
defer resp.Body.Close()
// Code is empty, so validate should fail
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
}
}
func TestSSOHandler_FullFlow(t *testing.T) {
server, cleanup := setupSSOTestServer(t)
defer cleanup()
// Step 1: Authorize
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code&state=my-state", "Bearer test-token")
defer authResp.Body.Close()
if authResp.StatusCode != http.StatusFound {
t.Fatalf("authorize failed: status=%d", authResp.StatusCode)
}
location := authResp.Header.Get("Location")
parsedURL, _ := url.Parse(location)
code := parsedURL.Query().Get("code")
state := parsedURL.Query().Get("state")
if code == "" {
t.Fatal("expected authorization code")
}
if state != "my-state" {
t.Errorf("expected state=my-state, got %s", state)
}
// Step 2: Exchange code for token
tokenForm := url.Values{}
tokenForm.Set("grant_type", "authorization_code")
tokenForm.Set("code", code)
tokenForm.Set("client_id", "test-client")
tokenForm.Set("client_secret", "test-secret")
tokenResp, tokenBody := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", tokenForm)
defer tokenResp.Body.Close()
if tokenResp.StatusCode != http.StatusOK {
t.Fatalf("token exchange failed: status=%d body=%s", tokenResp.StatusCode, tokenBody)
}
var tokenResult handler.TokenResponse
if err := json.Unmarshal([]byte(tokenBody), &tokenResult); err != nil {
t.Fatalf("failed to parse token response: %v", err)
}
if tokenResult.AccessToken == "" {
t.Fatal("expected access_token")
}
// Step 3: Introspect token
introspectResp, introspectBody := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
"token": tokenResult.AccessToken,
})
defer introspectResp.Body.Close()
if introspectResp.StatusCode != http.StatusOK {
t.Fatalf("introspect failed: status=%d body=%s", introspectResp.StatusCode, introspectBody)
}
var introspectResult handler.IntrospectResponse
if err := json.Unmarshal([]byte(introspectBody), &introspectResult); err != nil {
t.Fatalf("failed to parse introspect response: %v", err)
}
if !introspectResult.Active {
t.Error("expected token to be active")
}
if introspectResult.UserID != 1 {
t.Errorf("expected user_id=1, got %d", introspectResult.UserID)
}
// Step 4: Get userinfo
userinfoResp, userinfoBody := doGet(server.URL+"/api/v1/sso/userinfo", "Bearer test-token")
defer userinfoResp.Body.Close()
if userinfoResp.StatusCode != http.StatusOK {
t.Fatalf("userinfo failed: status=%d body=%s", userinfoResp.StatusCode, userinfoBody)
}
var userinfoResult map[string]interface{}
if err := json.Unmarshal([]byte(userinfoBody), &userinfoResult); err != nil {
t.Fatalf("failed to parse userinfo response: %v", err)
}
userinfoData, ok := userinfoResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected userinfo data, got %s", userinfoBody)
}
if userinfoData["username"] != "testuser" {
t.Errorf("expected username=testuser, got %v", userinfoData["username"])
}
// Step 5: Revoke token
revokeResp, revokeBody := doPost(server.URL+"/api/v1/sso/revoke", "Bearer test-token", map[string]interface{}{
"token": tokenResult.AccessToken,
})
defer revokeResp.Body.Close()
if revokeResp.StatusCode != http.StatusOK {
t.Fatalf("revoke failed: status=%d body=%s", revokeResp.StatusCode, revokeBody)
}
// Step 6: Verify token is revoked
finalIntrospectResp, finalIntrospectBody := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
"token": tokenResult.AccessToken,
})
defer finalIntrospectResp.Body.Close()
if finalIntrospectResp.StatusCode != http.StatusOK {
t.Fatalf("final introspect failed: status=%d body=%s", finalIntrospectResp.StatusCode, finalIntrospectBody)
}
var finalResult handler.IntrospectResponse
if err := json.Unmarshal([]byte(finalIntrospectBody), &finalResult); err != nil {
t.Fatalf("failed to parse final introspect response: %v", err)
}
if finalResult.Active {
t.Error("expected token to be inactive after revoke")
}
}
func TestSSOHandler_Authorize_NoClientStore(t *testing.T) {
gin.SetMode(gin.TestMode)
engine := gin.New()
ssoManager := auth.NewSSOManager()
// Pass nil clientsStore
ssoHandler := handler.NewSSOHandler(ssoManager, nil)
authMiddleware := func() gin.HandlerFunc {
return func(c *gin.Context) {
c.Set("user_id", int64(1))
c.Set("username", "testuser")
c.Next()
}
}()
ssoGroup := engine.Group("/api/v1/sso")
ssoGroup.Use(authMiddleware)
{
ssoGroup.GET("/authorize", ssoHandler.Authorize)
}
server := httptest.NewServer(engine)
defer server.Close()
// Without clients store, any redirect_uri should be accepted (or validation skipped)
resp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=any&redirect_uri=http://any.com/callback&response_type=code", "Bearer test-token")
defer resp.Body.Close()
if resp.StatusCode != http.StatusFound {
t.Errorf("expected redirect when clientsStore is nil, got %d", resp.StatusCode)
}
}

View File

@@ -0,0 +1,685 @@
package handler_test
import (
"bytes"
"encoding/json"
"net/http"
"testing"
"github.com/user-management-system/internal/auth"
)
func TestTOTPHandler_GetTOTPStatus(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "totpstatususer", "totpstatus@test.com", "UserPass123!")
token := getToken(server.URL, "totpstatususer", "UserPass123!")
resp, body := doGet(server.URL+"/api/v1/auth/2fa/status", token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if result["code"] != float64(0) {
t.Errorf("expected code 0, got %v", result["code"])
}
data, ok := result["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected data in response, got %s", body)
}
if data["enabled"] != false {
t.Errorf("expected enabled=false for new user, got %v", data["enabled"])
}
}
func TestTOTPHandler_GetTOTPStatus_Unauthorized(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
resp, _ := doGet(server.URL+"/api/v1/auth/2fa/status", "")
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
}
func TestTOTPHandler_SetupTOTP(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "totpsetupuser", "totpsetup@test.com", "UserPass123!")
token := getToken(server.URL, "totpsetupuser", "UserPass123!")
resp, body := doGet(server.URL+"/api/v1/auth/2fa/setup", token)
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if result["code"] != float64(0) {
t.Errorf("expected code 0, got %v", result["code"])
}
data, ok := result["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected data in response, got %s", body)
}
if data["secret"] == nil || data["secret"] == "" {
t.Errorf("expected secret in setup response, got %+v", data)
}
}
func TestTOTPHandler_SetupTOTP_Unauthorized(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
resp, _ := doGet(server.URL+"/api/v1/auth/2fa/setup", "")
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
}
func TestTOTPHandler_EnableTOTP(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpenableuser", "totpenable@test.com", "UserPass123!")
_ = userID
_ = secret
// setupEnabledTOTPUser already enables TOTP, so let's just verify the user can login with TOTP
// Actually, we need a fresh user to test enable
registerUser(server.URL, "totpenableuser2", "totpenable2@test.com", "UserPass123!")
token := getToken(server.URL, "totpenableuser2", "UserPass123!")
// Setup TOTP
setupResp, setupBody := doGet(server.URL+"/api/v1/auth/2fa/setup", token)
defer setupResp.Body.Close()
if setupResp.StatusCode != http.StatusOK {
t.Fatalf("setup failed: status=%d body=%s", setupResp.StatusCode, setupBody)
}
var setupResult map[string]interface{}
if err := json.Unmarshal([]byte(setupBody), &setupResult); err != nil {
t.Fatalf("failed to parse setup response: %v", err)
}
setupData, ok := setupResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected setup data, got %s", setupBody)
}
newSecret, ok := setupData["secret"].(string)
if !ok || newSecret == "" {
t.Fatalf("expected secret in setup response, got %s", setupBody)
}
// Generate valid code
code, err := auth.NewTOTPManager().GenerateCurrentCode(newSecret)
if err != nil {
t.Fatalf("failed to generate TOTP code: %v", err)
}
// Enable TOTP
enableResp, enableBody := doPost(server.URL+"/api/v1/auth/2fa/enable", token, map[string]interface{}{
"code": code,
})
defer enableResp.Body.Close()
if enableResp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, enableResp.StatusCode, enableBody)
}
}
func TestTOTPHandler_EnableTOTP_InvalidCode(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "totpenableinv", "totpenableinv@test.com", "UserPass123!")
token := getToken(server.URL, "totpenableinv", "UserPass123!")
// Setup TOTP first
setupResp, setupBody := doGet(server.URL+"/api/v1/auth/2fa/setup", token)
defer setupResp.Body.Close()
if setupResp.StatusCode != http.StatusOK {
t.Fatalf("setup failed: status=%d body=%s", setupResp.StatusCode, setupBody)
}
// Try enable with invalid code
enableResp, enableBody := doPost(server.URL+"/api/v1/auth/2fa/enable", token, map[string]interface{}{
"code": "000000",
})
defer enableResp.Body.Close()
if enableResp.StatusCode != http.StatusUnauthorized && enableResp.StatusCode != http.StatusInternalServerError {
t.Errorf("expected status 401 or 500 for invalid code, got %d, body: %s", enableResp.StatusCode, enableBody)
}
}
func TestTOTPHandler_EnableTOTP_MissingCode(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "totpenablemiss", "totpenablemiss@test.com", "UserPass123!")
token := getToken(server.URL, "totpenablemiss", "UserPass123!")
resp, body := doPost(server.URL+"/api/v1/auth/2fa/enable", token, map[string]interface{}{})
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestTOTPHandler_DisableTOTP(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpdisableuser", "totpdisable@test.com", "UserPass123!")
// Login again to get a fresh token (since TOTP is enabled, login may require TOTP)
deviceID := "test-device"
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "totpdisableuser",
"password": "UserPass123!",
"device_id": deviceID,
})
defer loginResp.Body.Close()
if loginResp.StatusCode != http.StatusOK {
t.Fatalf("login failed: status=%d body=%s", loginResp.StatusCode, loginBody)
}
var loginResult map[string]interface{}
if err := json.Unmarshal([]byte(loginBody), &loginResult); err != nil {
t.Fatalf("failed to parse login response: %v", err)
}
// If requires_totp, we need to verify TOTP first
loginData, ok := loginResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected login data, got %s", loginBody)
}
var token string
if loginData["requires_totp"] == true {
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
if err != nil {
t.Fatalf("failed to generate TOTP code: %v", err)
}
tempToken, _ := loginData["temp_token"].(string)
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
"user_id": userID,
"code": code,
"device_id": deviceID,
"temp_token": tempToken,
})
defer verifyResp.Body.Close()
if verifyResp.StatusCode != http.StatusOK {
t.Fatalf("totp verify failed: status=%d body=%s", verifyResp.StatusCode, verifyBody)
}
var verifyResult map[string]interface{}
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err != nil {
t.Fatalf("failed to parse verify response: %v", err)
}
verifyData, ok := verifyResult["data"].(map[string]interface{})
if ok && verifyData["access_token"] != nil {
token, _ = verifyData["access_token"].(string)
}
} else {
token, _ = loginData["access_token"].(string)
}
if token == "" {
t.Fatal("failed to get token after login")
}
// Generate valid code for disable
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
if err != nil {
t.Fatalf("failed to generate TOTP code: %v", err)
}
resp, body := doPost(server.URL+"/api/v1/auth/2fa/disable", token, map[string]interface{}{
"code": code,
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
// Verify TOTP is disabled
statusResp, statusBody := doGet(server.URL+"/api/v1/auth/2fa/status", token)
defer statusResp.Body.Close()
if statusResp.StatusCode != http.StatusOK {
t.Fatalf("status check failed: status=%d body=%s", statusResp.StatusCode, statusBody)
}
var statusResult map[string]interface{}
if err := json.Unmarshal([]byte(statusBody), &statusResult); err != nil {
t.Fatalf("failed to parse status response: %v", err)
}
statusData, ok := statusResult["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected status data, got %s", statusBody)
}
if statusData["enabled"] != false {
t.Errorf("expected enabled=false after disable, got %v", statusData["enabled"])
}
}
func TestTOTPHandler_DisableTOTP_InvalidCode(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpdisableinv", "totpdisableinv@test.com", "UserPass123!")
// Get token (might need TOTP verification)
deviceID := "test-device"
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "totpdisableinv",
"password": "UserPass123!",
"device_id": deviceID,
})
defer loginResp.Body.Close()
var token string
var loginResult map[string]interface{}
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
if loginData["requires_totp"] == true {
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
tempToken, _ := loginData["temp_token"].(string)
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
"user_id": userID,
"code": code,
"device_id": deviceID,
"temp_token": tempToken,
})
defer verifyResp.Body.Close()
if verifyResp.StatusCode == http.StatusOK {
var verifyResult map[string]interface{}
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
token, _ = verifyData["access_token"].(string)
}
}
}
} else {
token, _ = loginData["access_token"].(string)
}
}
}
if token == "" {
t.Fatal("failed to get token after login")
}
resp, body := doPost(server.URL+"/api/v1/auth/2fa/disable", token, map[string]interface{}{
"code": "000000",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusInternalServerError {
t.Errorf("expected status 401 or 500 for invalid code, got %d, body: %s", resp.StatusCode, body)
}
}
func TestTOTPHandler_VerifyTOTP(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpverifyuser", "totpverify@test.com", "UserPass123!")
// Get token (might need TOTP verification)
deviceID := "test-device"
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "totpverifyuser",
"password": "UserPass123!",
"device_id": deviceID,
})
defer loginResp.Body.Close()
var token string
var loginResult map[string]interface{}
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
if loginData["requires_totp"] == true {
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
tempToken, _ := loginData["temp_token"].(string)
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
"user_id": userID,
"code": code,
"device_id": deviceID,
"temp_token": tempToken,
})
defer verifyResp.Body.Close()
if verifyResp.StatusCode == http.StatusOK {
var verifyResult map[string]interface{}
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
token, _ = verifyData["access_token"].(string)
}
}
}
} else {
token, _ = loginData["access_token"].(string)
}
}
}
if token == "" {
t.Fatal("failed to get token after login")
}
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
if err != nil {
t.Fatalf("failed to generate TOTP code: %v", err)
}
resp, body := doPost(server.URL+"/api/v1/auth/2fa/verify", token, map[string]interface{}{
"code": code,
"device_id": deviceID,
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
}
var result map[string]interface{}
if err := json.Unmarshal([]byte(body), &result); err != nil {
t.Fatalf("failed to parse response: %v", err)
}
if result["code"] != float64(0) {
t.Errorf("expected code 0, got %v", result["code"])
}
data, ok := result["data"].(map[string]interface{})
if !ok {
t.Fatalf("expected data in response, got %s", body)
}
if data["verified"] != true {
t.Errorf("expected verified=true, got %v", data["verified"])
}
}
func TestTOTPHandler_VerifyTOTP_InvalidCode(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpverifyinv", "totpverifyinv@test.com", "UserPass123!")
// Get token
deviceID := "test-device"
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "totpverifyinv",
"password": "UserPass123!",
"device_id": deviceID,
})
defer loginResp.Body.Close()
var token string
var loginResult map[string]interface{}
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
if loginData["requires_totp"] == true {
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
tempToken, _ := loginData["temp_token"].(string)
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
"user_id": userID,
"code": code,
"device_id": deviceID,
"temp_token": tempToken,
})
defer verifyResp.Body.Close()
if verifyResp.StatusCode == http.StatusOK {
var verifyResult map[string]interface{}
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
token, _ = verifyData["access_token"].(string)
}
}
}
} else {
token, _ = loginData["access_token"].(string)
}
}
}
if token == "" {
t.Fatal("failed to get token after login")
}
resp, body := doPost(server.URL+"/api/v1/auth/2fa/verify", token, map[string]interface{}{
"code": "000000",
"device_id": deviceID,
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusInternalServerError {
t.Errorf("expected status 401 or 500 for invalid code, got %d, body: %s", resp.StatusCode, body)
}
}
func TestTOTPHandler_VerifyTOTP_MissingCode(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "totpverifymiss", "totpverifymiss@test.com", "UserPass123!")
token := getToken(server.URL, "totpverifymiss", "UserPass123!")
resp, body := doPost(server.URL+"/api/v1/auth/2fa/verify", token, map[string]interface{}{})
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestTOTPHandler_VerifyTOTP_Unauthorized(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
resp, _ := doPost(server.URL+"/api/v1/auth/2fa/verify", "", map[string]interface{}{
"code": "123456",
"device_id": "test-device",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
}
func TestTOTPHandler_DisableTOTP_MissingCode(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpdisablemiss", "totpdisablemiss@test.com", "UserPass123!")
// Get token
deviceID := "test-device"
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "totpdisablemiss",
"password": "UserPass123!",
"device_id": deviceID,
})
defer loginResp.Body.Close()
var token string
var loginResult map[string]interface{}
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
if loginData["requires_totp"] == true {
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
tempToken, _ := loginData["temp_token"].(string)
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
"user_id": userID,
"code": code,
"device_id": deviceID,
"temp_token": tempToken,
})
defer verifyResp.Body.Close()
if verifyResp.StatusCode == http.StatusOK {
var verifyResult map[string]interface{}
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
token, _ = verifyData["access_token"].(string)
}
}
}
} else {
token, _ = loginData["access_token"].(string)
}
}
}
if token == "" {
t.Fatal("failed to get token after login")
}
resp, body := doPost(server.URL+"/api/v1/auth/2fa/disable", token, map[string]interface{}{})
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
}
}
func TestTOTPHandler_DisableTOTP_Unauthorized(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
resp, _ := doPost(server.URL+"/api/v1/auth/2fa/disable", "", map[string]interface{}{
"code": "123456",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
}
func TestTOTPHandler_SetupTOTP_AlreadyEnabled(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpsetupenabled", "totpsetupenabled@test.com", "UserPass123!")
_ = secret
// Get token after TOTP login
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
"account": "totpsetupenabled",
"password": "UserPass123!",
"device_id": "test-device",
})
defer loginResp.Body.Close()
var token string
var loginResult map[string]interface{}
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
if loginData["requires_totp"] == true {
tempToken, _ := loginData["temp_token"].(string)
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
"user_id": userID,
"temp_token": tempToken,
"code": code,
"device_id": "test-device",
})
defer verifyResp.Body.Close()
if verifyResp.StatusCode == http.StatusOK {
var verifyResult map[string]interface{}
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
token, _ = verifyData["access_token"].(string)
}
}
}
} else {
token, _ = loginData["access_token"].(string)
}
}
}
if token == "" {
t.Fatal("failed to get token after login")
}
// Try setup again - should still work or return appropriate response
resp, body := doGet(server.URL+"/api/v1/auth/2fa/setup", token)
defer resp.Body.Close()
// Setup may return 200 with new secret or error if already enabled
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusBadRequest {
t.Errorf("unexpected status %d, body: %s", resp.StatusCode, body)
}
}
func TestTOTPHandler_EnableTOTP_Unauthorized(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
resp, _ := doPost(server.URL+"/api/v1/auth/2fa/enable", "", map[string]interface{}{
"code": "123456",
})
defer resp.Body.Close()
if resp.StatusCode != http.StatusUnauthorized {
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
}
}
func TestTOTPHandler_InvalidJSON(t *testing.T) {
server, cleanup := setupHandlerTestServer(t)
defer cleanup()
registerUser(server.URL, "totpjsonuser", "totpjson@test.com", "UserPass123!")
token := getToken(server.URL, "totpjsonuser", "UserPass123!")
tests := []struct {
name string
path string
method string
}{
{"enable_invalid_json", "/api/v1/auth/2fa/enable", "POST"},
{"disable_invalid_json", "/api/v1/auth/2fa/disable", "POST"},
{"verify_invalid_json", "/api/v1/auth/2fa/verify", "POST"},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(tc.method, server.URL+tc.path, bytes.NewReader([]byte("not json")))
req.Header.Set("Authorization", "Bearer "+token)
req.Header.Set("Content-Type", "application/json")
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatalf("request failed: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadRequest {
t.Errorf("expected status %d for invalid JSON, got %d", http.StatusBadRequest, resp.StatusCode)
}
})
}
}

View File

@@ -0,0 +1,102 @@
package middleware
import (
"bytes"
"compress/gzip"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
)
func TestGzipMiddleware_CompressesLargeJSONResponses(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(GzipMiddleware())
router.GET("/data", func(c *gin.Context) {
c.Header("Content-Type", "application/json")
c.String(http.StatusOK, strings.Repeat("a", gzipMinLength+128))
})
req := httptest.NewRequest(http.MethodGet, "/data", nil)
req.Header.Set("Accept-Encoding", "gzip")
router.ServeHTTP(recorder, req)
if got := recorder.Header().Get("Content-Encoding"); got != "gzip" {
t.Fatalf("Content-Encoding = %q, want gzip", got)
}
reader, err := gzip.NewReader(bytes.NewReader(recorder.Body.Bytes()))
if err != nil {
t.Fatalf("gzip.NewReader() error = %v", err)
}
defer reader.Close()
payload, err := io.ReadAll(reader)
if err != nil {
t.Fatalf("ReadAll() error = %v", err)
}
if got := string(payload); got != strings.Repeat("a", gzipMinLength+128) {
t.Fatalf("decompressed payload length = %d, want %d", len(got), gzipMinLength+128)
}
}
func TestGzipMiddleware_PassesThroughWhenCompressionNotUseful(t *testing.T) {
gin.SetMode(gin.TestMode)
testCases := []struct {
name string
acceptEncoding string
contentType string
body string
}{
{
name: "client does not accept gzip",
acceptEncoding: "",
contentType: "application/json",
body: strings.Repeat("b", gzipMinLength+64),
},
{
name: "body below threshold",
acceptEncoding: "gzip",
contentType: "application/json",
body: "small-body",
},
{
name: "unsupported content type",
acceptEncoding: "gzip",
contentType: "image/png",
body: strings.Repeat("c", gzipMinLength+64),
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(GzipMiddleware())
router.GET("/data", func(c *gin.Context) {
c.Header("Content-Type", tc.contentType)
c.String(http.StatusOK, tc.body)
})
req := httptest.NewRequest(http.MethodGet, "/data", nil)
if tc.acceptEncoding != "" {
req.Header.Set("Accept-Encoding", tc.acceptEncoding)
}
router.ServeHTTP(recorder, req)
if got := recorder.Header().Get("Content-Encoding"); got != "" {
t.Fatalf("Content-Encoding = %q, want empty", got)
}
if got := recorder.Body.String(); got != tc.body {
t.Fatalf("body length = %d, want %d", len(got), len(tc.body))
}
})
}
}

View File

@@ -0,0 +1,165 @@
package middleware
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/repository"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/logger"
_ "modernc.org/sqlite"
)
func newOperationLogRepositoryForTest(t *testing.T) *repository.OperationLogRepository {
t.Helper()
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
DriverName: "sqlite",
DSN: "file:operation_log_test?mode=memory&cache=shared",
}), &gorm.Config{
Logger: logger.Default.LogMode(logger.Silent),
})
if err != nil {
t.Fatalf("open sqlite failed: %v", err)
}
if err := db.AutoMigrate(&domain.OperationLog{}); err != nil {
t.Fatalf("migrate failed: %v", err)
}
if err := db.Exec("DELETE FROM operation_logs").Error; err != nil {
t.Fatalf("cleanup operation_logs failed: %v", err)
}
return repository.NewOperationLogRepository(db)
}
func waitForOperationLogs(t *testing.T, repo *repository.OperationLogRepository, want int) []*domain.OperationLog {
t.Helper()
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
logs, _, err := repo.List(context.Background(), 0, 20)
if err != nil {
t.Fatalf("list operation logs failed: %v", err)
}
if len(logs) >= want {
return logs
}
time.Sleep(25 * time.Millisecond)
}
logs, _, err := repo.List(context.Background(), 0, 20)
if err != nil {
t.Fatalf("list operation logs failed: %v", err)
}
t.Fatalf("timed out waiting for %d operation logs, got %d", want, len(logs))
return nil
}
func TestOperationLogMiddleware_SkipsReadOnlyMethods(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := newOperationLogRepositoryForTest(t)
router := gin.New()
router.Use(NewOperationLogMiddleware(repo).Record())
router.GET("/logs", func(c *gin.Context) {
c.Status(http.StatusOK)
})
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
time.Sleep(100 * time.Millisecond)
logs, _, err := repo.List(context.Background(), 0, 20)
if err != nil {
t.Fatalf("list operation logs failed: %v", err)
}
if len(logs) != 0 {
t.Fatalf("expected no logs for GET request, got %d", len(logs))
}
}
func TestOperationLogMiddleware_RecordsAdminMutationAndSanitizesParams(t *testing.T) {
gin.SetMode(gin.TestMode)
repo := newOperationLogRepositoryForTest(t)
router := gin.New()
router.Use(func(c *gin.Context) {
c.Set("user_id", int64(42))
c.Set(ContextKeyRoleCodes, []string{"admin"})
c.Next()
})
router.Use(NewOperationLogMiddleware(repo).Record())
router.POST("/users", func(c *gin.Context) {
c.Status(http.StatusCreated)
})
body := `{"username":"alice","password":"super-secret","token":"abc"}`
req := httptest.NewRequest(http.MethodPost, "/users", strings.NewReader(body))
req.RemoteAddr = "203.0.113.10:8080"
req.Header.Set("User-Agent", "middleware-test")
recorder := httptest.NewRecorder()
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusCreated {
t.Fatalf("expected 201, got %d", recorder.Code)
}
logs := waitForOperationLogs(t, repo, 1)
entry := logs[0]
if entry.UserID == nil || *entry.UserID != 42 {
t.Fatalf("user_id = %#v, want 42", entry.UserID)
}
if entry.OperationType != "admin:CREATE" {
t.Fatalf("operation_type = %q, want admin:CREATE", entry.OperationType)
}
if entry.ResponseStatus != http.StatusCreated {
t.Fatalf("response_status = %d, want %d", entry.ResponseStatus, http.StatusCreated)
}
if strings.Contains(entry.RequestParams, "super-secret") || strings.Contains(entry.RequestParams, "abc") {
t.Fatalf("expected sanitized params, got %s", entry.RequestParams)
}
}
func TestOperationLogMiddleware_MethodToTypeAndSanitizeFallbacks(t *testing.T) {
if got := methodToType(http.MethodPatch); got != "UPDATE" {
t.Fatalf("methodToType(PATCH) = %q, want UPDATE", got)
}
if got := methodToType(http.MethodDelete); got != "DELETE" {
t.Fatalf("methodToType(DELETE) = %q, want DELETE", got)
}
if got := methodToType(http.MethodGet); got != "OTHER" {
t.Fatalf("methodToType(GET) = %q, want OTHER", got)
}
raw := []byte(`{"password":"secret","name":"alice"}`)
sanitized := sanitizeParams(raw)
if strings.Contains(sanitized, "secret") {
t.Fatalf("expected password to be masked, got %s", sanitized)
}
plain := sanitizeParams([]byte("not-json"))
if plain != "not-json" {
t.Fatalf("sanitizeParams(non-json) = %q, want not-json", plain)
}
var payload map[string]string
if err := json.Unmarshal([]byte(sanitized), &payload); err != nil {
t.Fatalf("unmarshal sanitized params failed: %v", err)
}
if payload["password"] != "***" {
t.Fatalf("password = %q, want ***", payload["password"])
}
}

View File

@@ -0,0 +1,114 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func performRBACRequest(t *testing.T, setup func(*gin.Context), middleware gin.HandlerFunc) *httptest.ResponseRecorder {
t.Helper()
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
if setup != nil {
router.Use(setup)
}
router.Use(middleware)
router.GET("/protected", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"code": 0})
})
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
router.ServeHTTP(recorder, req)
return recorder
}
func TestRequirePermissionRejectsMissingPermission(t *testing.T) {
recorder := performRBACRequest(t, func(c *gin.Context) {
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
c.Next()
}, RequirePermission("users:write"))
if recorder.Code != http.StatusForbidden {
t.Fatalf("expected 403, got %d", recorder.Code)
}
}
func TestRequirePermissionAllowsMatchingPermission(t *testing.T) {
recorder := performRBACRequest(t, func(c *gin.Context) {
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
c.Next()
}, RequirePermission("users:read"))
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
}
func TestRequireAllPermissionsRequiresEveryCode(t *testing.T) {
recorder := performRBACRequest(t, func(c *gin.Context) {
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
c.Next()
}, RequireAllPermissions("users:read", "users:write"))
if recorder.Code != http.StatusForbidden {
t.Fatalf("expected 403, got %d", recorder.Code)
}
}
func TestRequireAnyPermissionIsAliasOfRequirePermission(t *testing.T) {
recorder := performRBACRequest(t, func(c *gin.Context) {
c.Set(ContextKeyPermissionCodes, []string{"users:write"})
c.Next()
}, RequireAnyPermission("users:read", "users:write"))
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
}
func TestRequireRoleAndAdminOnly(t *testing.T) {
roleRecorder := performRBACRequest(t, func(c *gin.Context) {
c.Set(ContextKeyRoleCodes, []string{"auditor"})
c.Next()
}, RequireRole("admin"))
if roleRecorder.Code != http.StatusForbidden {
t.Fatalf("expected role check to return 403, got %d", roleRecorder.Code)
}
adminRecorder := performRBACRequest(t, func(c *gin.Context) {
c.Set(ContextKeyRoleCodes, []string{"admin"})
c.Next()
}, AdminOnly())
if adminRecorder.Code != http.StatusOK {
t.Fatalf("expected admin check to return 200, got %d", adminRecorder.Code)
}
}
func TestRBACHelpersHandleMissingContextValues(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
c.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
if got := GetRoleCodes(c); got != nil {
t.Fatalf("GetRoleCodes() = %#v, want nil", got)
}
if got := GetPermissionCodes(c); got != nil {
t.Fatalf("GetPermissionCodes() = %#v, want nil", got)
}
if IsAdmin(c) {
t.Fatal("IsAdmin() = true, want false")
}
c.Set(ContextKeyRoleCodes, []string{"admin"})
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
if !IsAdmin(c) {
t.Fatal("IsAdmin() = false, want true")
}
}

View File

@@ -0,0 +1,119 @@
package middleware
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestResponseWrapper_WrapsSuccessfulJSONPayload(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(ResponseWrapper())
router.GET("/users", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"id": 1, "name": "alice"})
})
req := httptest.NewRequest(http.MethodGet, "/users", nil)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
want := `{"code":0,"data":{"id":1,"name":"alice"},"message":"success"}`
if got := recorder.Body.String(); got != want {
t.Fatalf("body = %s, want %s", got, want)
}
}
func TestResponseWrapper_PassesThroughMarkedResponses(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(ResponseWrapper())
router.GET("/users", func(c *gin.Context) {
WrapResponse(c)
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "already wrapped"})
})
req := httptest.NewRequest(http.MethodGet, "/users", nil)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
want := `{"code":0,"message":"already wrapped"}`
if got := recorder.Body.String(); got != want {
t.Fatalf("body = %s, want %s", got, want)
}
}
func TestResponseWrapper_PassesThroughNonSuccessStatus(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(ResponseWrapper())
router.GET("/users", func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
})
req := httptest.NewRequest(http.MethodGet, "/users", nil)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", recorder.Code)
}
want := `{"message":"bad request"}`
if got := recorder.Body.String(); got != want {
t.Fatalf("body = %s, want %s", got, want)
}
}
func TestResponseWrapper_PassesThroughInvalidJSON(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(ResponseWrapper())
router.GET("/users", func(c *gin.Context) {
c.Writer.WriteHeader(http.StatusOK)
_, _ = c.Writer.WriteString("plain text")
})
req := httptest.NewRequest(http.MethodGet, "/users", nil)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
if got := recorder.Body.String(); got != "plain text" {
t.Fatalf("body = %q, want plain text", got)
}
}
func TestResponseWrapper_NoWrapperMarksContext(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
router := gin.New()
router.Use(NoWrapper())
router.GET("/users", func(c *gin.Context) {
if _, exists := c.Get("response_wrapped"); !exists {
t.Fatal("expected response_wrapped marker in context")
}
c.JSON(http.StatusOK, gin.H{"ok": true})
})
req := httptest.NewRequest(http.MethodGet, "/users", nil)
router.ServeHTTP(recorder, req)
if recorder.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", recorder.Code)
}
}

View File

@@ -0,0 +1,136 @@
package domain
import (
"testing"
"time"
)
func TestDeviceType_Constants(t *testing.T) {
tests := []struct {
name string
value DeviceType
expected int
}{
{"Unknown", DeviceTypeUnknown, 0},
{"Web", DeviceTypeWeb, 1},
{"Mobile", DeviceTypeMobile, 2},
{"Desktop", DeviceTypeDesktop, 3},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if int(tc.value) != tc.expected {
t.Errorf("expected %d, got %d", tc.expected, int(tc.value))
}
})
}
}
func TestDeviceStatus_Constants(t *testing.T) {
tests := []struct {
name string
value DeviceStatus
expected int
}{
{"Inactive", DeviceStatusInactive, 0},
{"Active", DeviceStatusActive, 1},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if int(tc.value) != tc.expected {
t.Errorf("expected %d, got %d", tc.expected, int(tc.value))
}
})
}
}
func TestDevice_TableName(t *testing.T) {
var d Device
if got := d.TableName(); got != "devices" {
t.Errorf("expected table name 'devices', got %q", got)
}
}
func TestDevice_StructFields(t *testing.T) {
now := time.Now()
trustExpires := now.Add(24 * time.Hour)
d := Device{
ID: 1,
UserID: 2,
DeviceID: "device-123",
DeviceName: "Test Device",
DeviceType: DeviceTypeWeb,
DeviceOS: "Windows",
DeviceBrowser: "Chrome",
IP: "127.0.0.1",
Location: "Beijing",
IsTrusted: true,
TrustExpiresAt: &trustExpires,
Status: DeviceStatusActive,
LastActiveTime: now,
CreatedAt: now,
UpdatedAt: now,
}
if d.ID != 1 {
t.Errorf("expected ID 1, got %d", d.ID)
}
if d.UserID != 2 {
t.Errorf("expected UserID 2, got %d", d.UserID)
}
if d.DeviceID != "device-123" {
t.Errorf("expected DeviceID 'device-123', got %q", d.DeviceID)
}
if d.DeviceName != "Test Device" {
t.Errorf("expected DeviceName 'Test Device', got %q", d.DeviceName)
}
if d.DeviceType != DeviceTypeWeb {
t.Errorf("expected DeviceTypeWeb, got %d", d.DeviceType)
}
if d.DeviceOS != "Windows" {
t.Errorf("expected DeviceOS 'Windows', got %q", d.DeviceOS)
}
if d.DeviceBrowser != "Chrome" {
t.Errorf("expected DeviceBrowser 'Chrome', got %q", d.DeviceBrowser)
}
if d.IP != "127.0.0.1" {
t.Errorf("expected IP '127.0.0.1', got %q", d.IP)
}
if d.Location != "Beijing" {
t.Errorf("expected Location 'Beijing', got %q", d.Location)
}
if !d.IsTrusted {
t.Error("expected IsTrusted to be true")
}
if d.TrustExpiresAt == nil || !d.TrustExpiresAt.Equal(trustExpires) {
t.Error("expected TrustExpiresAt to match")
}
if d.Status != DeviceStatusActive {
t.Errorf("expected DeviceStatusActive, got %d", d.Status)
}
if d.LastActiveTime.IsZero() {
t.Error("expected LastActiveTime to be set")
}
if d.CreatedAt.IsZero() {
t.Error("expected CreatedAt to be set")
}
if d.UpdatedAt.IsZero() {
t.Error("expected UpdatedAt to be set")
}
}
func TestDevice_DefaultStatus(t *testing.T) {
var d Device
if d.Status != DeviceStatusInactive {
t.Errorf("expected default status Inactive(0), got %d", d.Status)
}
}
func TestDevice_DefaultDeviceType(t *testing.T) {
var d Device
if d.DeviceType != DeviceTypeUnknown {
t.Errorf("expected default device type Unknown(0), got %d", d.DeviceType)
}
}

View File

@@ -0,0 +1,35 @@
package domain
import (
"testing"
"time"
)
func TestPasswordHistory_TableName(t *testing.T) {
var h PasswordHistory
if got := h.TableName(); got != "password_histories" {
t.Errorf("expected table name 'password_histories', got %q", got)
}
}
func TestPasswordHistory_StructTags(t *testing.T) {
h := PasswordHistory{
ID: 1,
UserID: 2,
PasswordHash: "hash123",
CreatedAt: time.Now(),
}
if h.ID != 1 {
t.Errorf("expected ID 1, got %d", h.ID)
}
if h.UserID != 2 {
t.Errorf("expected UserID 2, got %d", h.UserID)
}
if h.PasswordHash != "hash123" {
t.Errorf("expected PasswordHash 'hash123', got %q", h.PasswordHash)
}
if h.CreatedAt.IsZero() {
t.Error("expected CreatedAt to be set")
}
}

View File

@@ -0,0 +1,77 @@
package pagination
import (
"testing"
)
func TestDefaultPagination(t *testing.T) {
p := DefaultPagination()
if p.Page != 1 {
t.Errorf("expected default page 1, got %d", p.Page)
}
if p.PageSize != 20 {
t.Errorf("expected default page_size 20, got %d", p.PageSize)
}
}
func TestPaginationParams_Offset(t *testing.T) {
tests := []struct {
name string
page int
pageSize int
wantOffset int
}{
{"page 1", 1, 20, 0},
{"page 2", 2, 20, 20},
{"page 5", 5, 20, 80},
{"zero page", 0, 20, 0},
{"negative page", -1, 20, 0},
{"page 1 size 10", 1, 10, 0},
{"page 3 size 10", 3, 10, 20},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
p := PaginationParams{Page: tc.page, PageSize: tc.pageSize}
if got := p.Offset(); got != tc.wantOffset {
t.Errorf("expected offset %d, got %d", tc.wantOffset, got)
}
})
}
}
func TestPaginationParams_Limit(t *testing.T) {
tests := []struct {
name string
pageSize int
want int
}{
{"default", 20, 20},
{"size 10", 10, 10},
{"size 50", 50, 50},
{"size 100", 100, 100},
{"max cap", 101, 100},
{"zero size", 0, 20},
{"negative size", -1, 20},
{"size 1", 1, 1},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
p := PaginationParams{PageSize: tc.pageSize}
if got := p.Limit(); got != tc.want {
t.Errorf("expected limit %d, got %d", tc.want, got)
}
})
}
}
func TestPaginationParams_OffsetAndLimit(t *testing.T) {
p := PaginationParams{Page: 3, PageSize: 15}
if got := p.Offset(); got != 30 {
t.Errorf("expected offset 30, got %d", got)
}
if got := p.Limit(); got != 15 {
t.Errorf("expected limit 15, got %d", got)
}
}

View File

@@ -0,0 +1,95 @@
package repository
import (
"testing"
"github.com/user-management-system/internal/pkg/pagination"
)
func TestPaginationResultFromTotal(t *testing.T) {
tests := []struct {
name string
total int64
params pagination.PaginationParams
wantPages int
wantTotal int64
wantPage int
wantPageSize int
}{
{
name: "exact division",
total: 100,
params: pagination.PaginationParams{Page: 1, PageSize: 20},
wantPages: 5,
wantTotal: 100,
wantPage: 1,
wantPageSize: 20,
},
{
name: "with remainder",
total: 105,
params: pagination.PaginationParams{Page: 1, PageSize: 20},
wantPages: 6,
wantTotal: 105,
wantPage: 1,
wantPageSize: 20,
},
{
name: "zero total",
total: 0,
params: pagination.PaginationParams{Page: 1, PageSize: 20},
wantPages: 0,
wantTotal: 0,
wantPage: 1,
wantPageSize: 20,
},
{
name: "single page",
total: 5,
params: pagination.PaginationParams{Page: 1, PageSize: 20},
wantPages: 1,
wantTotal: 5,
wantPage: 1,
wantPageSize: 20,
},
{
name: "page 2",
total: 50,
params: pagination.PaginationParams{Page: 2, PageSize: 20},
wantPages: 3,
wantTotal: 50,
wantPage: 2,
wantPageSize: 20,
},
{
name: "small page size",
total: 10,
params: pagination.PaginationParams{Page: 1, PageSize: 3},
wantPages: 4,
wantTotal: 10,
wantPage: 1,
wantPageSize: 3,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
result := paginationResultFromTotal(tc.total, tc.params)
if result == nil {
t.Fatal("expected non-nil result")
}
if result.Total != tc.wantTotal {
t.Errorf("expected total %d, got %d", tc.wantTotal, result.Total)
}
if result.Page != tc.wantPage {
t.Errorf("expected page %d, got %d", tc.wantPage, result.Page)
}
if result.PageSize != tc.wantPageSize {
t.Errorf("expected page_size %d, got %d", tc.wantPageSize, result.PageSize)
}
if result.Pages != tc.wantPages {
t.Errorf("expected pages %d, got %d", tc.wantPages, result.Pages)
}
})
}
}

View File

@@ -0,0 +1,224 @@
package repository
import (
"context"
"testing"
"time"
"github.com/user-management-system/internal/domain"
)
func TestPasswordHistoryRepository_Create(t *testing.T) {
db := openTestDB(t)
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
t.Fatalf("migrate password_history failed: %v", err)
}
repo := NewPasswordHistoryRepository(db)
ctx := context.Background()
history := &domain.PasswordHistory{
UserID: 1,
PasswordHash: "hash1",
CreatedAt: time.Now(),
}
if err := repo.Create(ctx, history); err != nil {
t.Fatalf("create failed: %v", err)
}
if history.ID == 0 {
t.Error("expected ID to be set after create")
}
}
func TestPasswordHistoryRepository_GetByUserID(t *testing.T) {
db := openTestDB(t)
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
t.Fatalf("migrate password_history failed: %v", err)
}
repo := NewPasswordHistoryRepository(db)
ctx := context.Background()
// Create multiple records for user 1
for i := 0; i < 5; i++ {
h := &domain.PasswordHistory{
UserID: 1,
PasswordHash: "hash",
CreatedAt: time.Now().Add(time.Duration(i) * time.Second),
}
if err := repo.Create(ctx, h); err != nil {
t.Fatalf("create failed: %v", err)
}
}
// Create record for user 2
if err := repo.Create(ctx, &domain.PasswordHistory{UserID: 2, PasswordHash: "hash", CreatedAt: time.Now()}); err != nil {
t.Fatalf("create failed: %v", err)
}
tests := []struct {
name string
userID int64
limit int
wantLen int
wantUser int64
}{
{"get all for user 1", 1, 10, 5, 1},
{"limit 3 for user 1", 1, 3, 3, 1},
{"get for user 2", 2, 10, 1, 2},
{"get for nonexistent user", 999, 10, 0, 999},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
histories, err := repo.GetByUserID(ctx, tc.userID, tc.limit)
if err != nil {
t.Fatalf("get failed: %v", err)
}
if len(histories) != tc.wantLen {
t.Errorf("expected %d histories, got %d", tc.wantLen, len(histories))
}
for _, h := range histories {
if h.UserID != tc.wantUser {
t.Errorf("expected user_id %d, got %d", tc.wantUser, h.UserID)
}
}
})
}
}
func TestPasswordHistoryRepository_GetByUserID_Order(t *testing.T) {
db := openTestDB(t)
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
t.Fatalf("migrate password_history failed: %v", err)
}
repo := NewPasswordHistoryRepository(db)
ctx := context.Background()
// Create records with different timestamps
now := time.Now()
for i := 0; i < 3; i++ {
h := &domain.PasswordHistory{
UserID: 1,
PasswordHash: "hash",
CreatedAt: now.Add(time.Duration(i) * time.Hour),
}
if err := repo.Create(ctx, h); err != nil {
t.Fatalf("create failed: %v", err)
}
}
histories, err := repo.GetByUserID(ctx, 1, 10)
if err != nil {
t.Fatalf("get failed: %v", err)
}
if len(histories) != 3 {
t.Fatalf("expected 3 histories, got %d", len(histories))
}
// Should be ordered by created_at DESC (newest first)
for i := 0; i < len(histories)-1; i++ {
if !histories[i].CreatedAt.After(histories[i+1].CreatedAt) && !histories[i].CreatedAt.Equal(histories[i+1].CreatedAt) {
t.Errorf("expected descending order, got %v before %v", histories[i].CreatedAt, histories[i+1].CreatedAt)
}
}
}
func TestPasswordHistoryRepository_DeleteOldRecords(t *testing.T) {
db := openTestDB(t)
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
t.Fatalf("migrate password_history failed: %v", err)
}
repo := NewPasswordHistoryRepository(db)
ctx := context.Background()
// Create 5 records for user 1
now := time.Now()
for i := 0; i < 5; i++ {
h := &domain.PasswordHistory{
UserID: 1,
PasswordHash: "hash",
CreatedAt: now.Add(time.Duration(i) * time.Hour),
}
if err := repo.Create(ctx, h); err != nil {
t.Fatalf("create failed: %v", err)
}
}
// Delete old records, keep only 3
if err := repo.DeleteOldRecords(ctx, 1, 3); err != nil {
t.Fatalf("delete old records failed: %v", err)
}
histories, err := repo.GetByUserID(ctx, 1, 10)
if err != nil {
t.Fatalf("get failed: %v", err)
}
if len(histories) != 3 {
t.Errorf("expected 3 histories after cleanup, got %d", len(histories))
}
}
func TestPasswordHistoryRepository_DeleteOldRecords_NoRecords(t *testing.T) {
db := openTestDB(t)
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
t.Fatalf("migrate password_history failed: %v", err)
}
repo := NewPasswordHistoryRepository(db)
ctx := context.Background()
// Should not error when no records exist
if err := repo.DeleteOldRecords(ctx, 999, 3); err != nil {
t.Fatalf("delete old records on empty table should not error: %v", err)
}
}
func TestPasswordHistoryRepository_KeepsNewestRecords(t *testing.T) {
db := openTestDB(t)
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
t.Fatalf("migrate password_history failed: %v", err)
}
repo := NewPasswordHistoryRepository(db)
ctx := context.Background()
// Create 5 records with different timestamps
now := time.Now()
var createdIDs []int64
for i := 0; i < 5; i++ {
h := &domain.PasswordHistory{
UserID: 1,
PasswordHash: "hash",
CreatedAt: now.Add(time.Duration(i) * time.Hour),
}
if err := repo.Create(ctx, h); err != nil {
t.Fatalf("create failed: %v", err)
}
createdIDs = append(createdIDs, h.ID)
}
// Delete old records, keep only 2
if err := repo.DeleteOldRecords(ctx, 1, 2); err != nil {
t.Fatalf("delete old records failed: %v", err)
}
histories, err := repo.GetByUserID(ctx, 1, 10)
if err != nil {
t.Fatalf("get failed: %v", err)
}
if len(histories) != 2 {
t.Fatalf("expected 2 histories after cleanup, got %d", len(histories))
}
// The remaining records should be the newest (last 2 created)
expectedIDs := map[int64]bool{createdIDs[3]: true, createdIDs[4]: true}
for _, h := range histories {
if !expectedIDs[h.ID] {
t.Errorf("expected remaining IDs to be %v, got %d", expectedIDs, h.ID)
}
}
}

View File

@@ -0,0 +1,117 @@
package repository
import (
"context"
"database/sql"
"errors"
"testing"
)
// mockQueryer implements sqlQueryer for testing
type mockQueryer struct {
rows *sql.Rows
err error
}
func (m *mockQueryer) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
return m.rows, m.err
}
func TestScanSingleRow_QueryError(t *testing.T) {
ctx := context.Background()
mockErr := errors.New("query failed")
q := &mockQueryer{err: mockErr}
var dest int
err := scanSingleRow(ctx, q, "SELECT 1", nil, &dest)
if err == nil {
t.Fatal("expected error, got nil")
}
if !errors.Is(err, mockErr) {
t.Errorf("expected query error, got %v", err)
}
}
func TestScanSingleRow_NoRows(t *testing.T) {
// This test requires a real database connection to create sql.Rows.
// scanSingleRow is designed to work with any sqlQueryer, but creating
// a mock sql.Rows without a real driver is complex.
// We test the behavior through integration with the test database.
db := openTestDB(t)
ctx := context.Background()
// Use the raw sql.DB from gorm
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("get sql.DB failed: %v", err)
}
var dest int
err = scanSingleRow(ctx, sqlDB, "SELECT 1 WHERE 1=0", nil, &dest)
if err == nil {
t.Fatal("expected error for no rows, got nil")
}
if !errors.Is(err, sql.ErrNoRows) {
t.Errorf("expected sql.ErrNoRows, got %v", err)
}
}
func TestScanSingleRow_Success(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("get sql.DB failed: %v", err)
}
var dest int
err = scanSingleRow(ctx, sqlDB, "SELECT 42", nil, &dest)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if dest != 42 {
t.Errorf("expected 42, got %d", dest)
}
}
func TestScanSingleRow_MultipleColumns(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("get sql.DB failed: %v", err)
}
var a, b int
err = scanSingleRow(ctx, sqlDB, "SELECT 1, 2", nil, &a, &b)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if a != 1 {
t.Errorf("expected a=1, got %d", a)
}
if b != 2 {
t.Errorf("expected b=2, got %d", b)
}
}
func TestScanSingleRow_StringResult(t *testing.T) {
db := openTestDB(t)
ctx := context.Background()
sqlDB, err := db.DB()
if err != nil {
t.Fatalf("get sql.DB failed: %v", err)
}
var dest string
err = scanSingleRow(ctx, sqlDB, "SELECT 'hello'", nil, &dest)
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if dest != "hello" {
t.Errorf("expected 'hello', got %q", dest)
}
}