test: 补齐 handler/repository/domain 层单元测试
This commit is contained in:
297
internal/api/handler/auth_handler_unit_test.go
Normal file
297
internal/api/handler/auth_handler_unit_test.go
Normal file
@@ -0,0 +1,297 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestAuthHandler_SupportFlags(t *testing.T) {
|
||||
var nilHandler *AuthHandler
|
||||
if nilHandler.SupportsPasswordReset() {
|
||||
t.Fatal("nil handler should not support password reset")
|
||||
}
|
||||
|
||||
handler := &AuthHandler{}
|
||||
if handler.SupportsPasswordReset() {
|
||||
t.Fatal("password reset should be disabled by default")
|
||||
}
|
||||
|
||||
handler.SetPasswordResetEnabled(true)
|
||||
if !handler.SupportsPasswordReset() {
|
||||
t.Fatal("password reset flag should be enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetUserIDFromContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/userinfo", nil)
|
||||
|
||||
if _, ok := getUserIDFromContext(c); ok {
|
||||
t.Fatal("expected missing user_id to return false")
|
||||
}
|
||||
|
||||
c.Set("user_id", "1")
|
||||
if _, ok := getUserIDFromContext(c); ok {
|
||||
t.Fatal("expected non-int64 user_id to return false")
|
||||
}
|
||||
|
||||
c.Set("user_id", int64(42))
|
||||
if got, ok := getUserIDFromContext(c); !ok || got != 42 {
|
||||
t.Fatalf("getUserIDFromContext() = (%d, %v), want (42, true)", got, ok)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestUsesHTTPS(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
if requestUsesHTTPS(nil) {
|
||||
t.Fatal("nil context should not use https")
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||||
if requestUsesHTTPS(c) {
|
||||
t.Fatal("plain http request should not use https")
|
||||
}
|
||||
|
||||
c.Request.Header.Set("X-Forwarded-Proto", "https")
|
||||
if !requestUsesHTTPS(c) {
|
||||
t.Fatal("forwarded https request should be detected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSessionCookies_SetAndClear(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||||
|
||||
setSessionCookies(c, nil, "")
|
||||
if len(recorder.Header().Values("Set-Cookie")) != 0 {
|
||||
t.Fatal("empty refresh token should not set cookies")
|
||||
}
|
||||
|
||||
setSessionCookies(c, nil, "refresh-token")
|
||||
setCookies := recorder.Header().Values("Set-Cookie")
|
||||
if len(setCookies) < 2 {
|
||||
t.Fatalf("expected session cookies to be set, got %d", len(setCookies))
|
||||
}
|
||||
if !strings.Contains(setCookies[0], refreshTokenCookieName+"=refresh-token") &&
|
||||
!strings.Contains(setCookies[1], refreshTokenCookieName+"=refresh-token") {
|
||||
t.Fatalf("expected refresh token cookie, got %#v", setCookies)
|
||||
}
|
||||
|
||||
recorder = httptest.NewRecorder()
|
||||
c, _ = gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||||
clearSessionCookies(c)
|
||||
setCookies = recorder.Header().Values("Set-Cookie")
|
||||
if len(setCookies) < 2 {
|
||||
t.Fatalf("expected clearing cookies to emit expired cookies, got %d", len(setCookies))
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifyErrorMessage(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
msg string
|
||||
want int
|
||||
}{
|
||||
{name: "not found", msg: "user not found", want: http.StatusNotFound},
|
||||
{name: "duplicate", msg: "already exists", want: http.StatusConflict},
|
||||
{name: "verification code", msg: "验证码错误", want: http.StatusUnauthorized},
|
||||
{name: "unauthorized", msg: "invalid token", want: http.StatusUnauthorized},
|
||||
{name: "forbidden", msg: "permission denied", want: http.StatusForbidden},
|
||||
{name: "bad request", msg: "invalid payload", want: http.StatusBadRequest},
|
||||
{name: "rate limit", msg: "too many attempts", want: http.StatusTooManyRequests},
|
||||
{name: "fallback", msg: "unexpected boom", want: http.StatusInternalServerError},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if got := classifyErrorMessage(tc.msg); got != tc.want {
|
||||
t.Fatalf("classifyErrorMessage(%q) = %d, want %d", tc.msg, got, tc.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_OAuthFallbackEndpoints(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
run func(*gin.Context)
|
||||
}{
|
||||
{
|
||||
name: "oauth login",
|
||||
run: func(c *gin.Context) {
|
||||
c.Params = gin.Params{{Key: "provider", Value: "github"}}
|
||||
h.OAuthLogin(c)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "oauth callback",
|
||||
run: func(c *gin.Context) {
|
||||
c.Params = gin.Params{{Key: "provider", Value: "github"}}
|
||||
h.OAuthCallback(c)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "oauth exchange",
|
||||
run: func(c *gin.Context) {
|
||||
c.Params = gin.Params{{Key: "provider", Value: "github"}}
|
||||
h.OAuthExchange(c)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "oauth providers",
|
||||
run: func(c *gin.Context) {
|
||||
h.GetEnabledOAuthProviders(c)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/auth", nil)
|
||||
tc.run(c)
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_RefreshToken_InvalidJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/auth/refresh", bytes.NewBufferString("{"))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.RefreshToken(c)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_ActivateEmail_MissingToken(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/auth/activate-email", bytes.NewBufferString(`{}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.ActivateEmail(c)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_ResendActivationEmail_InvalidEmail(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/auth/resend-activation-email", bytes.NewBufferString(`{"email":"bad-email"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.ResendActivationEmail(c)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_SendEmailCode_InvalidEmail(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/auth/send-email-code", bytes.NewBufferString(`{"email":"bad-email"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.SendEmailCode(c)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_LoginByEmailCode_InvalidPayload(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/auth/login-by-email-code", bytes.NewBufferString(`{"email":"bad-email"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
|
||||
h.LoginByEmailCode(c)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_BootstrapAdmin_HeaderFailures(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
h := &AuthHandler{}
|
||||
|
||||
original := os.Getenv("BOOTSTRAP_SECRET")
|
||||
if err := os.Setenv("BOOTSTRAP_SECRET", "expected-secret"); err != nil {
|
||||
t.Fatalf("set env failed: %v", err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
_ = os.Setenv("BOOTSTRAP_SECRET", original)
|
||||
})
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
secret string
|
||||
want int
|
||||
}{
|
||||
{name: "missing header", secret: "", want: http.StatusUnauthorized},
|
||||
{name: "wrong header", secret: "wrong-secret", want: http.StatusUnauthorized},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/auth/bootstrap-admin", bytes.NewBufferString(`{"username":"admin","email":"admin@example.com","password":"AdminPass123!"}`))
|
||||
c.Request.Header.Set("Content-Type", "application/json")
|
||||
if tc.secret != "" {
|
||||
c.Request.Header.Set("X-Bootstrap-Secret", tc.secret)
|
||||
}
|
||||
|
||||
h.BootstrapAdmin(c)
|
||||
|
||||
if recorder.Code != tc.want {
|
||||
t.Fatalf("expected %d, got %d", tc.want, recorder.Code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
151
internal/api/handler/avatar_handler_test.go
Normal file
151
internal/api/handler/avatar_handler_test.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// minimalPNG is a valid 1x1 PNG image
|
||||
var minimalPNG = []byte{
|
||||
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D,
|
||||
0x49, 0x48, 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01,
|
||||
0x08, 0x02, 0x00, 0x00, 0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00,
|
||||
0x0C, 0x49, 0x44, 0x41, 0x54, 0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00,
|
||||
0x00, 0x00, 0x03, 0x00, 0x01, 0x00, 0x05, 0xFE, 0xD8, 0x00, 0x00, 0x00,
|
||||
0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82,
|
||||
}
|
||||
|
||||
func buildAvatarUploadRequest(t *testing.T, url, token string, fileBody []byte, filename string) *http.Request {
|
||||
t.Helper()
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
part, err := writer.CreateFormFile("avatar", filename)
|
||||
if err != nil {
|
||||
t.Fatalf("create form file failed: %v", err)
|
||||
}
|
||||
if _, err := part.Write(fileBody); err != nil {
|
||||
t.Fatalf("write file body failed: %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("close multipart writer failed: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, url, &body)
|
||||
if err != nil {
|
||||
t.Fatalf("create request failed: %v", err)
|
||||
}
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
return req
|
||||
}
|
||||
|
||||
func TestAvatarHandler_UploadAvatar(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "avatar-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "avatar-bootstrap-secret", "avataradmin", "avataradmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
if ok := registerUser(server.URL, "avataruser", "avataruser@test.com", "UserPass123!"); !ok {
|
||||
t.Fatal("register user failed")
|
||||
}
|
||||
userToken := getToken(server.URL, "avataruser", "UserPass123!")
|
||||
if userToken == "" {
|
||||
t.Fatal("get user token failed")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID string
|
||||
token string
|
||||
fileBody []byte
|
||||
filename string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "admin_upload_for_any_user",
|
||||
userID: "2",
|
||||
token: adminToken,
|
||||
fileBody: minimalPNG,
|
||||
filename: "avatar.png",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "user_upload_own_avatar",
|
||||
userID: "2",
|
||||
token: userToken,
|
||||
fileBody: minimalPNG,
|
||||
filename: "avatar.png",
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
userID: "1",
|
||||
token: "",
|
||||
fileBody: minimalPNG,
|
||||
filename: "avatar.png",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "forbidden_cross_user",
|
||||
userID: "1",
|
||||
token: userToken,
|
||||
fileBody: minimalPNG,
|
||||
filename: "avatar.png",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "invalid_user_id",
|
||||
userID: "invalid",
|
||||
token: adminToken,
|
||||
fileBody: minimalPNG,
|
||||
filename: "avatar.png",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invalid_file_type",
|
||||
userID: "1",
|
||||
token: adminToken,
|
||||
fileBody: []byte("this is not an image"),
|
||||
filename: "avatar.txt",
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "user_not_found",
|
||||
userID: "99999",
|
||||
token: adminToken,
|
||||
fileBody: minimalPNG,
|
||||
filename: "avatar.png",
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
req := buildAvatarUploadRequest(t, server.URL+"/api/v1/users/"+tt.userID+"/avatar", tt.token, tt.fileBody, tt.filename)
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
body, _ := io.ReadAll(resp.Body)
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, string(body))
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Clean up uploaded avatars
|
||||
_ = os.RemoveAll("./uploads/avatars")
|
||||
}
|
||||
545
internal/api/handler/custom_field_handler_test.go
Normal file
545
internal/api/handler/custom_field_handler_test.go
Normal file
@@ -0,0 +1,545 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/user-management-system/internal/api/handler"
|
||||
"github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/api/router"
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/config"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"github.com/user-management-system/internal/service"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var customFieldDbCounter int64
|
||||
|
||||
func setupCustomFieldTestServer(t *testing.T) (*httptest.Server, string, string, func()) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
id := atomic.AddInt64(&customFieldDbCounter, 1)
|
||||
dsn := fmt.Sprintf("file:cfdb_%d_%s?mode=memory&cache=shared", id, t.Name())
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: dsn,
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Skipf("skipping custom field test (SQLite unavailable): %v", err)
|
||||
return nil, "", "", func() {}
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(
|
||||
&domain.User{},
|
||||
&domain.Role{},
|
||||
&domain.Permission{},
|
||||
&domain.UserRole{},
|
||||
&domain.RolePermission{},
|
||||
&domain.CustomField{},
|
||||
&domain.UserCustomFieldValue{},
|
||||
); err != nil {
|
||||
t.Fatalf("db migration failed: %v", err)
|
||||
}
|
||||
|
||||
seedHandlerAuthzData(t, db)
|
||||
|
||||
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||
HS256Secret: "test-cf-secret-key",
|
||||
AccessTokenExpire: 15 * time.Minute,
|
||||
RefreshTokenExpire: 7 * 24 * time.Hour,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
l1Cache := cache.NewL1Cache()
|
||||
l2Cache := cache.NewRedisCache(false)
|
||||
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
||||
|
||||
userRepo := repository.NewUserRepository(db)
|
||||
roleRepo := repository.NewRoleRepository(db)
|
||||
userRoleRepo := repository.NewUserRoleRepository(db)
|
||||
|
||||
authSvc := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
|
||||
authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
|
||||
|
||||
fieldRepo := repository.NewCustomFieldRepository(db)
|
||||
valueRepo := repository.NewUserCustomFieldValueRepository(db)
|
||||
cfSvc := service.NewCustomFieldService(fieldRepo, valueRepo)
|
||||
cfHandler := handler.NewCustomFieldHandler(cfSvc)
|
||||
|
||||
rateLimitCfg := config.RateLimitConfig{}
|
||||
rateLimitMiddleware := middleware.NewRateLimitMiddleware(rateLimitCfg)
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
jwtManager, userRepo, userRoleRepo, l1Cache,
|
||||
)
|
||||
authMiddleware.SetCacheManager(cacheManager)
|
||||
|
||||
authHandler := handler.NewAuthHandler(authSvc)
|
||||
|
||||
r := router.NewRouter(
|
||||
authHandler, nil, nil, nil, nil, nil,
|
||||
authMiddleware, rateLimitMiddleware, nil,
|
||||
nil, nil, nil, nil,
|
||||
nil, nil, nil, nil, cfHandler, nil, nil, nil, nil,
|
||||
)
|
||||
engine := r.Setup()
|
||||
server := httptest.NewServer(engine)
|
||||
|
||||
// Register a regular user
|
||||
regBody := map[string]interface{}{
|
||||
"username": fmt.Sprintf("cfuser_%d", id),
|
||||
"password": "TestPass123!",
|
||||
"email": fmt.Sprintf("cf_%d@test.com", id),
|
||||
}
|
||||
regBytes, _ := json.Marshal(regBody)
|
||||
regResp, _ := http.Post(server.URL+"/api/v1/auth/register", "application/json", bytes.NewReader(regBytes))
|
||||
io.ReadAll(regResp.Body)
|
||||
regResp.Body.Close()
|
||||
|
||||
// Login as regular user
|
||||
loginBody := map[string]interface{}{
|
||||
"account": regBody["username"],
|
||||
"password": regBody["password"],
|
||||
}
|
||||
loginBytes, _ := json.Marshal(loginBody)
|
||||
loginResp, _ := http.Post(server.URL+"/api/v1/auth/login", "application/json", bytes.NewReader(loginBytes))
|
||||
var loginResult struct {
|
||||
Data struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
} `json:"data"`
|
||||
}
|
||||
json.NewDecoder(loginResp.Body).Decode(&loginResult)
|
||||
loginResp.Body.Close()
|
||||
userToken := loginResult.Data.AccessToken
|
||||
|
||||
// Bootstrap admin
|
||||
t.Setenv("BOOTSTRAP_SECRET", fmt.Sprintf("cf-bootstrap-%d", id))
|
||||
adminToken := bootstrapAdmin(server.URL, fmt.Sprintf("cf-bootstrap-%d", id), fmt.Sprintf("cfadmin_%d", id), fmt.Sprintf("cfa_%d@test.com", id), "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
return server, adminToken, userToken, func() {
|
||||
server.Close()
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomFieldHandler_CreateField(t *testing.T) {
|
||||
server, adminToken, userToken, cleanup := setupCustomFieldTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
payload map[string]interface{}
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
payload: map[string]interface{}{
|
||||
"name": "Test Field",
|
||||
"field_key": "test_field_create",
|
||||
"type": 1,
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
payload: map[string]interface{}{
|
||||
"name": "Test Field Unauth",
|
||||
"field_key": "test_field_unauth",
|
||||
"type": 1,
|
||||
},
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "forbidden",
|
||||
payload: map[string]interface{}{
|
||||
"name": "Test Field Forbidden",
|
||||
"field_key": "test_field_forbidden",
|
||||
"type": 1,
|
||||
},
|
||||
token: userToken,
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "missing_required_fields",
|
||||
payload: map[string]interface{}{"name": "Missing Key"},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doPost(server.URL+"/api/v1/custom-fields", tt.token, tt.payload)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomFieldHandler_ListFields(t *testing.T) {
|
||||
server, adminToken, userToken, cleanup := setupCustomFieldTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success_admin",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "forbidden_regular_user",
|
||||
token: userToken,
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doGet(server.URL+"/api/v1/custom-fields", tt.token)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomFieldHandler_GetField(t *testing.T) {
|
||||
server, adminToken, _, cleanup := setupCustomFieldTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a field
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
|
||||
"name": "Get Field Test",
|
||||
"field_key": "test_field_get",
|
||||
"type": 1,
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
var createResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||
t.Fatalf("parse create response failed: %v", err)
|
||||
}
|
||||
fieldData := createResult["data"].(map[string]interface{})
|
||||
fieldID := int64(fieldData["id"].(float64))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fieldID string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
fieldID: fmt.Sprintf("%d", fieldID),
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "not_found",
|
||||
fieldID: "99999",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "invalid_id",
|
||||
fieldID: "invalid",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
fieldID: fmt.Sprintf("%d", fieldID),
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doGet(server.URL+"/api/v1/custom-fields/"+tt.fieldID, tt.token)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomFieldHandler_UpdateField(t *testing.T) {
|
||||
server, adminToken, _, cleanup := setupCustomFieldTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a field
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
|
||||
"name": "Update Field Test",
|
||||
"field_key": "test_field_update",
|
||||
"type": 1,
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
var createResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||
t.Fatalf("parse create response failed: %v", err)
|
||||
}
|
||||
fieldData := createResult["data"].(map[string]interface{})
|
||||
fieldID := int64(fieldData["id"].(float64))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fieldID string
|
||||
payload map[string]interface{}
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
fieldID: fmt.Sprintf("%d", fieldID),
|
||||
payload: map[string]interface{}{
|
||||
"name": "Updated Field Name",
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invalid_id",
|
||||
fieldID: "invalid",
|
||||
payload: map[string]interface{}{
|
||||
"name": "Updated Field Name",
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
fieldID: fmt.Sprintf("%d", fieldID),
|
||||
payload: map[string]interface{}{
|
||||
"name": "Updated Field Name",
|
||||
},
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doPut(server.URL+"/api/v1/custom-fields/"+tt.fieldID, tt.token, tt.payload)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomFieldHandler_DeleteField(t *testing.T) {
|
||||
server, adminToken, _, cleanup := setupCustomFieldTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a field
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
|
||||
"name": "Delete Field Test",
|
||||
"field_key": "test_field_delete",
|
||||
"type": 1,
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
var createResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||
t.Fatalf("parse create response failed: %v", err)
|
||||
}
|
||||
fieldData := createResult["data"].(map[string]interface{})
|
||||
fieldID := int64(fieldData["id"].(float64))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fieldID string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
fieldID: fmt.Sprintf("%d", fieldID),
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invalid_id",
|
||||
fieldID: "invalid",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
fieldID: fmt.Sprintf("%d", fieldID),
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doDelete(server.URL+"/api/v1/custom-fields/"+tt.fieldID, tt.token)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomFieldHandler_SetUserFieldValues(t *testing.T) {
|
||||
server, adminToken, userToken, cleanup := setupCustomFieldTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a field for the user to set
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
|
||||
"name": "User Field Test",
|
||||
"field_key": "user_field_test",
|
||||
"type": 1,
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
payload map[string]interface{}
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
payload: map[string]interface{}{
|
||||
"values": map[string]string{
|
||||
"user_field_test": "123",
|
||||
},
|
||||
},
|
||||
token: userToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
payload: map[string]interface{}{
|
||||
"values": map[string]string{
|
||||
"user_field_test": "test_value",
|
||||
},
|
||||
},
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "missing_values",
|
||||
payload: map[string]interface{}{},
|
||||
token: userToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doPut(server.URL+"/api/v1/users/me/custom-fields", tt.token, tt.payload)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCustomFieldHandler_GetUserFieldValues(t *testing.T) {
|
||||
server, adminToken, userToken, cleanup := setupCustomFieldTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Create a field
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/custom-fields", adminToken, map[string]interface{}{
|
||||
"name": "User Field Get Test",
|
||||
"field_key": "user_field_get_test",
|
||||
"type": 1,
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create field failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
|
||||
// Set a value first
|
||||
setResp, setBody := doPut(server.URL+"/api/v1/users/me/custom-fields", userToken, map[string]interface{}{
|
||||
"values": map[string]string{
|
||||
"user_field_get_test": "456",
|
||||
},
|
||||
})
|
||||
defer setResp.Body.Close()
|
||||
if setResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("set field value failed: %d %s", setResp.StatusCode, setBody)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
token: userToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doGet(server.URL+"/api/v1/users/me/custom-fields", tt.token)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
510
internal/api/handler/device_handler_test.go
Normal file
510
internal/api/handler/device_handler_test.go
Normal file
@@ -0,0 +1,510 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDeviceHandler_ListDevices(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicelistuser", "devicelist@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicelistuser", "UserPass123!")
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/devices", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_ListDevices_Unauthorized(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/devices", "")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_CreateDevice(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicecreateuser", "devicecreate@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicecreateuser", "UserPass123!")
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/devices", token, map[string]interface{}{
|
||||
"name": "Test Device",
|
||||
"device_id": "device-test-001",
|
||||
"device_type": 3,
|
||||
"device_os": "Windows 10",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusCreated, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_CreateDevice_InvalidBody(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicecreatebad", "devicecreatebad@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicecreatebad", "UserPass123!")
|
||||
|
||||
req, _ := http.NewRequest("POST", server.URL+"/api/v1/devices", bytes.NewReader([]byte("not json")))
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d for invalid body, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_GetDevice(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicegetuser", "deviceget@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicegetuser", "UserPass123!")
|
||||
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-get-001", "Get Device")
|
||||
|
||||
resp, body := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_GetDevice_NotFound(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicegetnf", "devicegetnf@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicegetnf", "UserPass123!")
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/devices/99999", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusNotFound, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_GetDevice_InvalidID(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicegetinv", "devicegetinv@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicegetinv", "UserPass123!")
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/devices/invalid", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_UpdateDevice(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "deviceupdateuser", "deviceupdate@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "deviceupdateuser", "UserPass123!")
|
||||
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-update-001", "Original Name")
|
||||
|
||||
resp, body := doPut(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), token, map[string]interface{}{
|
||||
"device_name": "Updated Name",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_UpdateDevice_NotFound(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "deviceupdatenf", "deviceupdatenf@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "deviceupdatenf", "UserPass123!")
|
||||
|
||||
resp, body := doPut(server.URL+"/api/v1/devices/99999", token, map[string]interface{}{
|
||||
"device_name": "Updated Name",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusNotFound, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_DeleteDevice(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicedeluser", "devicedel@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicedeluser", "UserPass123!")
|
||||
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-del-001", "Delete Device")
|
||||
|
||||
resp, body := doDelete(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
// Verify deletion
|
||||
getResp, _ := doGet(fmt.Sprintf("%s/api/v1/devices/%d", server.URL, deviceID), token)
|
||||
defer getResp.Body.Close()
|
||||
if getResp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("expected device to be deleted, got status %d", getResp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_DeleteDevice_NotFound(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicedelnf", "devicedelnf@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicedelnf", "UserPass123!")
|
||||
|
||||
resp, body := doDelete(server.URL+"/api/v1/devices/99999", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusNotFound, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_UpdateDeviceStatus(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicestatususer", "devicestatus@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicestatususer", "UserPass123!")
|
||||
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-status-001", "Status Device")
|
||||
|
||||
resp, body := doPut(fmt.Sprintf("%s/api/v1/devices/%d/status", server.URL, deviceID), token, map[string]interface{}{
|
||||
"status": "inactive",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_UpdateDeviceStatus_InvalidStatus(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicestatusinv", "devicestatusinv@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicestatusinv", "UserPass123!")
|
||||
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-status-inv-001", "Status Device")
|
||||
|
||||
resp, body := doPut(fmt.Sprintf("%s/api/v1/devices/%d/status", server.URL, deviceID), token, map[string]interface{}{
|
||||
"status": "invalid_status",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_TrustDevice(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicetrustuser", "devicetrust@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicetrustuser", "UserPass123!")
|
||||
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-trust-001", "Trust Device")
|
||||
|
||||
resp, body := doPost(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), token, map[string]interface{}{
|
||||
"trust_duration": "24h",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_UntrustDevice(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "deviceuntrustuser", "deviceuntrust@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "deviceuntrustuser", "UserPass123!")
|
||||
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-untrust-001", "Untrust Device")
|
||||
|
||||
// First trust the device
|
||||
trustResp, trustBody := doPost(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), token, map[string]interface{}{
|
||||
"trust_duration": "24h",
|
||||
})
|
||||
defer trustResp.Body.Close()
|
||||
if trustResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected trust status %d, got %d, body: %s", http.StatusOK, trustResp.StatusCode, trustBody)
|
||||
}
|
||||
|
||||
// Then untrust
|
||||
resp, body := doDelete(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_GetMyTrustedDevices(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicetrusteduser", "devicetrusted@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicetrusteduser", "UserPass123!")
|
||||
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-trusted-001", "Trusted Device")
|
||||
|
||||
// Trust the device first
|
||||
trustResp, trustBody := doPost(fmt.Sprintf("%s/api/v1/devices/%d/trust", server.URL, deviceID), token, map[string]interface{}{
|
||||
"trust_duration": "24h",
|
||||
})
|
||||
defer trustResp.Body.Close()
|
||||
if trustResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected trust status %d, got %d, body: %s", http.StatusOK, trustResp.StatusCode, trustBody)
|
||||
}
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/devices/me/trusted", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_LogoutAllOtherDevices(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicelogoutuser", "devicelogout@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicelogoutuser", "UserPass123!")
|
||||
|
||||
deviceID := createDeviceForHandlerTest(t, server.URL, token, "device-logout-001", "Logout Device")
|
||||
|
||||
req, _ := http.NewRequest("POST", server.URL+"/api/v1/devices/me/logout-others", nil)
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("X-Device-ID", fmt.Sprintf("%d", deviceID))
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := json.Marshal(resp.Body)
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_LogoutAllOtherDevices_MissingDeviceID(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicelogoutbad", "devicelogoutbad@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicelogoutbad", "UserPass123!")
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/devices/me/logout-others", token, nil)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_GetUserDevices_AdminCanViewOthers(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "deviceadmin", "deviceadmin@test.com", "AdminPass123!")
|
||||
registerUser(server.URL, "deviceuserview", "deviceuserview@test.com", "UserPass123!")
|
||||
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin should return access token")
|
||||
}
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/devices/users/2", adminToken)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_GetUserDevices_NonAdminForbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "deviceuser1", "deviceuser1@test.com", "UserPass123!")
|
||||
registerUser(server.URL, "deviceuser2", "deviceuser2@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "deviceuser1", "UserPass123!")
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/devices/users/2", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_GetAllDevices_AdminOnly(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "deviceadmin2", "deviceadmin2@test.com", "AdminPass123!")
|
||||
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin should return access token")
|
||||
}
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/admin/devices", adminToken)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_GetAllDevices_NonAdminForbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "deviceuser3", "deviceuser3@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "deviceuser3", "UserPass123!")
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/admin/devices", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_TrustDeviceByDeviceID(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicetrustiduser", "devicetrustid@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicetrustiduser", "UserPass123!")
|
||||
|
||||
// Create device with specific device_id
|
||||
resp, body := doPost(server.URL+"/api/v1/devices", token, map[string]interface{}{
|
||||
"name": "Trust By ID Device",
|
||||
"device_id": "my-unique-device-id",
|
||||
"device_type": 1,
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("expected create status %d, got %d, body: %s", http.StatusCreated, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
// Trust by device ID
|
||||
trustResp, trustBody := doPost(server.URL+"/api/v1/devices/by-device-id/my-unique-device-id/trust", token, map[string]interface{}{
|
||||
"trust_duration": "24h",
|
||||
})
|
||||
defer trustResp.Body.Close()
|
||||
|
||||
if trustResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, trustResp.StatusCode, trustBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceHandler_TrustDeviceByDeviceID_EmptyID(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "devicetrustidbad", "devicetrustidbad@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "devicetrustidbad", "UserPass123!")
|
||||
|
||||
// The route uses ":deviceId" path param, so empty ID would be a different route or 404
|
||||
// Actually the route is /by-device-id/:deviceId/trust, so empty deviceId is not matched
|
||||
// Let's test with a device ID that doesn't exist
|
||||
resp, body := doPost(server.URL+"/api/v1/devices/by-device-id/nonexistent/trust", token, map[string]interface{}{
|
||||
"trust_duration": "24h",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Service returns error for non-existent device
|
||||
if resp.StatusCode != http.StatusNotFound && resp.StatusCode != http.StatusInternalServerError {
|
||||
t.Errorf("expected status 404 or 500 for non-existent device, got %d, body: %s", resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
319
internal/api/handler/export_handler_test.go
Normal file
319
internal/api/handler/export_handler_test.go
Normal file
@@ -0,0 +1,319 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/user-management-system/internal/api/handler"
|
||||
"github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/api/router"
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/config"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"github.com/user-management-system/internal/service"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
var exportDbCounter int64
|
||||
|
||||
func setupExportTestServer(t *testing.T) (*httptest.Server, string, string, func()) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
id := atomic.AddInt64(&exportDbCounter, 1)
|
||||
dsn := fmt.Sprintf("file:exportdb_%d_%s?mode=memory&cache=shared", id, t.Name())
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: dsn,
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Skipf("skipping export test (SQLite unavailable): %v", err)
|
||||
return nil, "", "", func() {}
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(
|
||||
&domain.User{},
|
||||
&domain.Role{},
|
||||
&domain.Permission{},
|
||||
&domain.UserRole{},
|
||||
&domain.RolePermission{},
|
||||
); err != nil {
|
||||
t.Fatalf("db migration failed: %v", err)
|
||||
}
|
||||
|
||||
seedHandlerAuthzData(t, db)
|
||||
|
||||
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||
HS256Secret: "test-export-secret-key",
|
||||
AccessTokenExpire: 15 * time.Minute,
|
||||
RefreshTokenExpire: 7 * 24 * time.Hour,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
l1Cache := cache.NewL1Cache()
|
||||
l2Cache := cache.NewRedisCache(false)
|
||||
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
||||
|
||||
userRepo := repository.NewUserRepository(db)
|
||||
roleRepo := repository.NewRoleRepository(db)
|
||||
userRoleRepo := repository.NewUserRoleRepository(db)
|
||||
|
||||
authSvc := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
|
||||
authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
|
||||
|
||||
exportSvc := service.NewExportService(userRepo, nil)
|
||||
exportHandler := handler.NewExportHandler(exportSvc)
|
||||
|
||||
rateLimitCfg := config.RateLimitConfig{}
|
||||
rateLimitMiddleware := middleware.NewRateLimitMiddleware(rateLimitCfg)
|
||||
authMiddleware := middleware.NewAuthMiddleware(
|
||||
jwtManager, userRepo, userRoleRepo, l1Cache,
|
||||
)
|
||||
authMiddleware.SetCacheManager(cacheManager)
|
||||
|
||||
authHandler := handler.NewAuthHandler(authSvc)
|
||||
|
||||
r := router.NewRouter(
|
||||
authHandler, nil, nil, nil, nil, nil,
|
||||
authMiddleware, rateLimitMiddleware, nil,
|
||||
nil, nil, nil, nil,
|
||||
nil, exportHandler, nil, nil, nil, nil, nil, nil, nil,
|
||||
)
|
||||
engine := r.Setup()
|
||||
server := httptest.NewServer(engine)
|
||||
|
||||
// Register a regular user
|
||||
regBody := map[string]interface{}{
|
||||
"username": fmt.Sprintf("exportuser_%d", id),
|
||||
"password": "TestPass123!",
|
||||
"email": fmt.Sprintf("ex_%d@test.com", id),
|
||||
}
|
||||
regBytes, _ := json.Marshal(regBody)
|
||||
regResp, _ := http.Post(server.URL+"/api/v1/auth/register", "application/json", bytes.NewReader(regBytes))
|
||||
io.ReadAll(regResp.Body)
|
||||
regResp.Body.Close()
|
||||
|
||||
// Login as regular user
|
||||
loginBody := map[string]interface{}{
|
||||
"account": regBody["username"],
|
||||
"password": regBody["password"],
|
||||
}
|
||||
loginBytes, _ := json.Marshal(loginBody)
|
||||
loginResp, _ := http.Post(server.URL+"/api/v1/auth/login", "application/json", bytes.NewReader(loginBytes))
|
||||
var loginResult struct {
|
||||
Data struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
} `json:"data"`
|
||||
}
|
||||
json.NewDecoder(loginResp.Body).Decode(&loginResult)
|
||||
loginResp.Body.Close()
|
||||
userToken := loginResult.Data.AccessToken
|
||||
|
||||
// Bootstrap admin
|
||||
t.Setenv("BOOTSTRAP_SECRET", fmt.Sprintf("export-bootstrap-%d", id))
|
||||
adminToken := bootstrapAdmin(server.URL, fmt.Sprintf("export-bootstrap-%d", id), fmt.Sprintf("exportadmin_%d", id), fmt.Sprintf("exa_%d@test.com", id), "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
return server, adminToken, userToken, func() {
|
||||
server.Close()
|
||||
if sqlDB, err := db.DB(); err == nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportHandler_ExportUsers(t *testing.T) {
|
||||
server, adminToken, userToken, cleanup := setupExportTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success_csv",
|
||||
query: "format=csv",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "success_excel",
|
||||
query: "format=xlsx",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "forbidden_regular_user",
|
||||
query: "format=csv",
|
||||
token: userToken,
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
query: "format=csv",
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
url := server.URL + "/api/v1/admin/users/export"
|
||||
if tt.query != "" {
|
||||
url = url + "?" + tt.query
|
||||
}
|
||||
resp, body := doGet(url, tt.token)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportHandler_ImportUsers(t *testing.T) {
|
||||
server, adminToken, userToken, cleanup := setupExportTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
csvData := []byte("\xEF\xBB\xBF用户名,密码,邮箱,手机号,昵称,性别,地区,个人简介\nimportuser1,Password123!,import1@test.com,13800138001,Import1,男,北京,简介1\n")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
fileBody []byte
|
||||
filename string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success_csv",
|
||||
fileBody: csvData,
|
||||
filename: "users.csv",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "forbidden_regular_user",
|
||||
fileBody: csvData,
|
||||
filename: "users.csv",
|
||||
token: userToken,
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
fileBody: csvData,
|
||||
filename: "users.csv",
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var body bytes.Buffer
|
||||
writer := multipart.NewWriter(&body)
|
||||
part, err := writer.CreateFormFile("file", tt.filename)
|
||||
if err != nil {
|
||||
t.Fatalf("create form file failed: %v", err)
|
||||
}
|
||||
if _, err := part.Write(tt.fileBody); err != nil {
|
||||
t.Fatalf("write file body failed: %v", err)
|
||||
}
|
||||
if err := writer.Close(); err != nil {
|
||||
t.Fatalf("close multipart writer failed: %v", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, server.URL+"/api/v1/admin/users/import", &body)
|
||||
if err != nil {
|
||||
t.Fatalf("create request failed: %v", err)
|
||||
}
|
||||
if tt.token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+tt.token)
|
||||
}
|
||||
req.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, string(respBody))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExportHandler_GetImportTemplate(t *testing.T) {
|
||||
server, adminToken, userToken, cleanup := setupExportTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
query string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success_csv",
|
||||
query: "format=csv",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "success_excel",
|
||||
query: "format=xlsx",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "forbidden_regular_user",
|
||||
query: "format=csv",
|
||||
token: userToken,
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
query: "format=csv",
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
url := server.URL + "/api/v1/admin/users/import/template"
|
||||
if tt.query != "" {
|
||||
url = url + "?" + tt.query
|
||||
}
|
||||
resp, body := doGet(url, tt.token)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
308
internal/api/handler/password_reset_handler_test.go
Normal file
308
internal/api/handler/password_reset_handler_test.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPasswordResetHandler_ForgotPassword(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "resetuser", "resetuser@test.com", "UserPass123!")
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
|
||||
"email": "resetuser@test.com",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetHandler_ForgotPassword_MissingEmail(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetHandler_ForgotPassword_NonExistentEmail(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// For non-existent email, the service returns success to prevent user enumeration
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
|
||||
"email": "nonexistent@test.com",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d for non-existent email, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetHandler_ValidateResetToken(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "validatetokenuser", "validatetoken@test.com", "UserPass123!")
|
||||
|
||||
// First request a password reset to generate a token
|
||||
_, _ = doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
|
||||
"email": "validatetoken@test.com",
|
||||
})
|
||||
|
||||
// We can't easily get the token from email, so test with an invalid token
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/password/validate", "", map[string]interface{}{
|
||||
"token": "invalid-token-12345",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
}
|
||||
|
||||
data, ok := result["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected data in response, got %s", body)
|
||||
}
|
||||
if data["valid"] != false {
|
||||
t.Errorf("expected valid=false for invalid token, got %v", data["valid"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetHandler_ValidateResetToken_MissingToken(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/password/validate", "", map[string]interface{}{})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetHandler_ResetPassword(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "resetpwuser", "resetpw@test.com", "UserPass123!")
|
||||
|
||||
// Request reset to generate token
|
||||
_, _ = doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
|
||||
"email": "resetpw@test.com",
|
||||
})
|
||||
|
||||
// Since we can't get the token, test with invalid token
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
|
||||
"token": "invalid-token",
|
||||
"new_password": "NewPass123!",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Should fail because token is invalid (service returns 404 for "不存在")
|
||||
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusBadRequest && resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("expected status 401, 400 or 404 for invalid token, got %d, body: %s", resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetHandler_ResetPassword_MissingToken(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
|
||||
"new_password": "NewPass123!",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetHandler_ResetPassword_MissingPassword(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
|
||||
"token": "some-token",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetHandler_ResetPassword_WeakPassword(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "resetpwweak", "resetpwweak@test.com", "UserPass123!")
|
||||
|
||||
// We need a valid token to test weak password rejection
|
||||
// Let's manually create one through the cache by using forgot-password
|
||||
_, _ = doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
|
||||
"email": "resetpwweak@test.com",
|
||||
})
|
||||
|
||||
// Use invalid token - the validation happens before password strength check
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
|
||||
"token": "invalid-token",
|
||||
"new_password": "123",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusBadRequest && resp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("expected status 401, 400 or 404, got %d, body: %s", resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetHandler_ForgotPasswordByPhone_ServiceUnavailable(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// The password reset handler in the test setup does not have SMS service configured
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/forgot-password/phone", "", map[string]interface{}{
|
||||
"phone": "13800138000",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusServiceUnavailable, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetHandler_ResetPasswordByPhone_MissingFields(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/reset-password/phone", "", map[string]interface{}{})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetHandler_ResetPasswordByPhone_InvalidCode(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "resetphoneuser", "resetphone@test.com", "UserPass123!")
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/reset-password/phone", "", map[string]interface{}{
|
||||
"phone": "13800138000",
|
||||
"code": "000000",
|
||||
"new_password": "NewPass123!",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Should fail because no code was sent
|
||||
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status 401 or 400 for invalid code, got %d, body: %s", resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetHandler_ForgotPassword_InvalidJSON(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
req, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/forgot-password", bytes.NewReader([]byte("not json")))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d for invalid JSON, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordResetHandler_FullFlow(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "fullflowuser", "fullflow@test.com", "UserPass123!")
|
||||
|
||||
// Step 1: Request password reset
|
||||
forgotResp, forgotBody := doPost(server.URL+"/api/v1/auth/forgot-password", "", map[string]interface{}{
|
||||
"email": "fullflow@test.com",
|
||||
})
|
||||
defer forgotResp.Body.Close()
|
||||
if forgotResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("forgot-password failed: status=%d body=%s", forgotResp.StatusCode, forgotBody)
|
||||
}
|
||||
|
||||
// Step 2: Validate token (we don't know the real token, so it will be invalid)
|
||||
validateResp, validateBody := doPost(server.URL+"/api/v1/auth/password/validate", "", map[string]interface{}{
|
||||
"token": "unknown-token",
|
||||
})
|
||||
defer validateResp.Body.Close()
|
||||
if validateResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("validate token failed: status=%d body=%s", validateResp.StatusCode, validateBody)
|
||||
}
|
||||
|
||||
var validateResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(validateBody), &validateResult); err != nil {
|
||||
t.Fatalf("failed to parse validate response: %v", err)
|
||||
}
|
||||
validateData, ok := validateResult["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected validate data, got %s", validateBody)
|
||||
}
|
||||
if validateData["valid"] != false {
|
||||
t.Errorf("expected valid=false for unknown token, got %v", validateData["valid"])
|
||||
}
|
||||
|
||||
// Step 3: Try reset with invalid token
|
||||
resetResp, resetBody := doPost(server.URL+"/api/v1/auth/reset-password", "", map[string]interface{}{
|
||||
"token": "unknown-token",
|
||||
"new_password": "NewPass123!",
|
||||
})
|
||||
defer resetResp.Body.Close()
|
||||
|
||||
// Should fail because token is invalid (service returns 404 for "不存在")
|
||||
if resetResp.StatusCode != http.StatusUnauthorized && resetResp.StatusCode != http.StatusNotFound {
|
||||
t.Errorf("expected status 401 or 404 for invalid token reset, got %d, body: %s", resetResp.StatusCode, resetBody)
|
||||
}
|
||||
|
||||
// Step 4: Verify old password still works
|
||||
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "fullflowuser",
|
||||
"password": "UserPass123!",
|
||||
})
|
||||
defer loginResp.Body.Close()
|
||||
if loginResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("old password should still work: status=%d body=%s", loginResp.StatusCode, loginBody)
|
||||
}
|
||||
}
|
||||
455
internal/api/handler/permission_handler_test.go
Normal file
455
internal/api/handler/permission_handler_test.go
Normal file
@@ -0,0 +1,455 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestPermissionHandler_CreatePermission(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
if ok := registerUser(server.URL, "permuser", "permuser@test.com", "UserPass123!"); !ok {
|
||||
t.Fatal("register user failed")
|
||||
}
|
||||
userToken := getToken(server.URL, "permuser", "UserPass123!")
|
||||
if userToken == "" {
|
||||
t.Fatal("get user token failed")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
payload map[string]interface{}
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
payload: map[string]interface{}{
|
||||
"name": "Test Permission",
|
||||
"code": "test:permission:create",
|
||||
"type": 2,
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
payload: map[string]interface{}{
|
||||
"name": "Test Permission",
|
||||
"code": "test:permission:unauth",
|
||||
"type": 2,
|
||||
},
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "forbidden",
|
||||
payload: map[string]interface{}{
|
||||
"name": "Test Permission",
|
||||
"code": "test:permission:forbid",
|
||||
"type": 2,
|
||||
},
|
||||
token: userToken,
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "invalid_type",
|
||||
payload: map[string]interface{}{
|
||||
"name": "Test Permission",
|
||||
"code": "test:permission:badtype",
|
||||
"type": 5,
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "missing_required_fields",
|
||||
payload: map[string]interface{}{"name": "Missing Code"},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doPost(server.URL+"/api/v1/permissions", tt.token, tt.payload)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionHandler_ListPermissions(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
if ok := registerUser(server.URL, "permuser", "permuser@test.com", "UserPass123!"); !ok {
|
||||
t.Fatal("register user failed")
|
||||
}
|
||||
userToken := getToken(server.URL, "permuser", "UserPass123!")
|
||||
if userToken == "" {
|
||||
t.Fatal("get user token failed")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success_admin",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "forbidden_regular_user",
|
||||
token: userToken,
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doGet(server.URL+"/api/v1/permissions", tt.token)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionHandler_GetPermission(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
// Create a permission to retrieve
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/permissions", adminToken, map[string]interface{}{
|
||||
"name": "Get Permission Test",
|
||||
"code": "test:permission:get",
|
||||
"type": 2,
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create permission failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
var createResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||
t.Fatalf("parse create response failed: %v", err)
|
||||
}
|
||||
permData, ok := createResult["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected data in create response, got %s", createBody)
|
||||
}
|
||||
permID := int64(permData["id"].(float64))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
permID string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
permID: fmt.Sprintf("%d", permID),
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "not_found",
|
||||
permID: "99999",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "invalid_id",
|
||||
permID: "invalid",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
permID: fmt.Sprintf("%d", permID),
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doGet(server.URL+"/api/v1/permissions/"+tt.permID, tt.token)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionHandler_UpdatePermission(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
// Create a permission to update
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/permissions", adminToken, map[string]interface{}{
|
||||
"name": "Update Permission Test",
|
||||
"code": "test:permission:update",
|
||||
"type": 2,
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create permission failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
var createResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||
t.Fatalf("parse create response failed: %v", err)
|
||||
}
|
||||
permData := createResult["data"].(map[string]interface{})
|
||||
permID := int64(permData["id"].(float64))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
permID string
|
||||
payload map[string]interface{}
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
permID: fmt.Sprintf("%d", permID),
|
||||
payload: map[string]interface{}{
|
||||
"name": "Updated Permission Name",
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invalid_id",
|
||||
permID: "invalid",
|
||||
payload: map[string]interface{}{
|
||||
"name": "Updated Permission Name",
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
permID: fmt.Sprintf("%d", permID),
|
||||
payload: map[string]interface{}{
|
||||
"name": "Updated Permission Name",
|
||||
},
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doPut(server.URL+"/api/v1/permissions/"+tt.permID, tt.token, tt.payload)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionHandler_DeletePermission(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
// Create a permission to delete
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/permissions", adminToken, map[string]interface{}{
|
||||
"name": "Delete Permission Test",
|
||||
"code": "test:permission:delete",
|
||||
"type": 2,
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create permission failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
var createResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||
t.Fatalf("parse create response failed: %v", err)
|
||||
}
|
||||
permData := createResult["data"].(map[string]interface{})
|
||||
permID := int64(permData["id"].(float64))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
permID string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
permID: fmt.Sprintf("%d", permID),
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invalid_id",
|
||||
permID: "invalid",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
permID: fmt.Sprintf("%d", permID),
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doDelete(server.URL+"/api/v1/permissions/"+tt.permID, tt.token)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionHandler_UpdatePermissionStatus(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
// Create a permission
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/permissions", adminToken, map[string]interface{}{
|
||||
"name": "Status Permission Test",
|
||||
"code": "test:permission:status",
|
||||
"type": 2,
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create permission failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
var createResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||
t.Fatalf("parse create response failed: %v", err)
|
||||
}
|
||||
permData := createResult["data"].(map[string]interface{})
|
||||
permID := int64(permData["id"].(float64))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
permID string
|
||||
payload map[string]interface{}
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success_numeric",
|
||||
permID: fmt.Sprintf("%d", permID),
|
||||
payload: map[string]interface{}{
|
||||
"status": 0,
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invalid_id",
|
||||
permID: "invalid",
|
||||
payload: map[string]interface{}{
|
||||
"status": 0,
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
permID: fmt.Sprintf("%d", permID),
|
||||
payload: map[string]interface{}{
|
||||
"status": 0,
|
||||
},
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doPut(server.URL+"/api/v1/permissions/"+tt.permID+"/status", tt.token, tt.payload)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPermissionHandler_GetPermissionTree(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "perm-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "perm-bootstrap-secret", "permadmin", "permadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/permissions/tree", adminToken)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("parse response failed: %v", err)
|
||||
}
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
}
|
||||
if result["data"] == nil {
|
||||
t.Errorf("expected data in response")
|
||||
}
|
||||
}
|
||||
527
internal/api/handler/role_handler_test.go
Normal file
527
internal/api/handler/role_handler_test.go
Normal file
@@ -0,0 +1,527 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRoleHandler_CreateRole(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
if ok := registerUser(server.URL, "roleuser", "roleuser@test.com", "UserPass123!"); !ok {
|
||||
t.Fatal("register user failed")
|
||||
}
|
||||
userToken := getToken(server.URL, "roleuser", "UserPass123!")
|
||||
if userToken == "" {
|
||||
t.Fatal("get user token failed")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
payload map[string]interface{}
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
payload: map[string]interface{}{
|
||||
"name": "Test Role",
|
||||
"code": "test_role_create",
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusCreated,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
payload: map[string]interface{}{
|
||||
"name": "Test Role Unauth",
|
||||
"code": "test_role_unauth",
|
||||
},
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
{
|
||||
name: "forbidden",
|
||||
payload: map[string]interface{}{
|
||||
"name": "Test Role Forbidden",
|
||||
"code": "test_role_forbidden",
|
||||
},
|
||||
token: userToken,
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "missing_required_fields",
|
||||
payload: map[string]interface{}{"name": "Missing Code"},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doPost(server.URL+"/api/v1/roles", tt.token, tt.payload)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleHandler_ListRoles(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
if ok := registerUser(server.URL, "roleuser", "roleuser@test.com", "UserPass123!"); !ok {
|
||||
t.Fatal("register user failed")
|
||||
}
|
||||
userToken := getToken(server.URL, "roleuser", "UserPass123!")
|
||||
if userToken == "" {
|
||||
t.Fatal("get user token failed")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success_admin",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "forbidden_regular_user",
|
||||
token: userToken,
|
||||
wantStatus: http.StatusForbidden,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doGet(server.URL+"/api/v1/roles", tt.token)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleHandler_GetRole(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
// Create a role to retrieve
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
|
||||
"name": "Get Role Test",
|
||||
"code": "test_role_get",
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
var createResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||
t.Fatalf("parse create response failed: %v", err)
|
||||
}
|
||||
roleData := createResult["data"].(map[string]interface{})
|
||||
roleID := int64(roleData["id"].(float64))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roleID string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
roleID: fmt.Sprintf("%d", roleID),
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "not_found",
|
||||
roleID: "99999",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusNotFound,
|
||||
},
|
||||
{
|
||||
name: "invalid_id",
|
||||
roleID: "invalid",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
roleID: fmt.Sprintf("%d", roleID),
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doGet(server.URL+"/api/v1/roles/"+tt.roleID, tt.token)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleHandler_UpdateRole(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
// Create a role to update
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
|
||||
"name": "Update Role Test",
|
||||
"code": "test_role_update",
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
var createResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||
t.Fatalf("parse create response failed: %v", err)
|
||||
}
|
||||
roleData := createResult["data"].(map[string]interface{})
|
||||
roleID := int64(roleData["id"].(float64))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roleID string
|
||||
payload map[string]interface{}
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
roleID: fmt.Sprintf("%d", roleID),
|
||||
payload: map[string]interface{}{
|
||||
"name": "Updated Role Name",
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invalid_id",
|
||||
roleID: "invalid",
|
||||
payload: map[string]interface{}{
|
||||
"name": "Updated Role Name",
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
roleID: fmt.Sprintf("%d", roleID),
|
||||
payload: map[string]interface{}{
|
||||
"name": "Updated Role Name",
|
||||
},
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doPut(server.URL+"/api/v1/roles/"+tt.roleID, tt.token, tt.payload)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleHandler_DeleteRole(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
// Create a role to delete
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
|
||||
"name": "Delete Role Test",
|
||||
"code": "test_role_delete",
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
var createResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||
t.Fatalf("parse create response failed: %v", err)
|
||||
}
|
||||
roleData := createResult["data"].(map[string]interface{})
|
||||
roleID := int64(roleData["id"].(float64))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roleID string
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
roleID: fmt.Sprintf("%d", roleID),
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invalid_id",
|
||||
roleID: "invalid",
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
roleID: fmt.Sprintf("%d", roleID),
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doDelete(server.URL+"/api/v1/roles/"+tt.roleID, tt.token)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleHandler_UpdateRoleStatus(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
// Create a role
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
|
||||
"name": "Status Role Test",
|
||||
"code": "test_role_status",
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
var createResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||
t.Fatalf("parse create response failed: %v", err)
|
||||
}
|
||||
roleData := createResult["data"].(map[string]interface{})
|
||||
roleID := int64(roleData["id"].(float64))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roleID string
|
||||
payload map[string]interface{}
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success_disabled",
|
||||
roleID: fmt.Sprintf("%d", roleID),
|
||||
payload: map[string]interface{}{
|
||||
"status": "disabled",
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "success_enabled",
|
||||
roleID: fmt.Sprintf("%d", roleID),
|
||||
payload: map[string]interface{}{
|
||||
"status": "enabled",
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invalid_status",
|
||||
roleID: fmt.Sprintf("%d", roleID),
|
||||
payload: map[string]interface{}{
|
||||
"status": "invalid_status",
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "invalid_id",
|
||||
roleID: "invalid",
|
||||
payload: map[string]interface{}{
|
||||
"status": "disabled",
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
roleID: fmt.Sprintf("%d", roleID),
|
||||
payload: map[string]interface{}{
|
||||
"status": "disabled",
|
||||
},
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doPut(server.URL+"/api/v1/roles/"+tt.roleID+"/status", tt.token, tt.payload)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleHandler_GetRolePermissions(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
// Use the admin role (id=1) for testing
|
||||
resp, body := doGet(server.URL+"/api/v1/roles/1/permissions", adminToken)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("parse response failed: %v", err)
|
||||
}
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
}
|
||||
if result["data"] == nil {
|
||||
t.Errorf("expected data in response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoleHandler_AssignPermissions(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "role-bootstrap-secret")
|
||||
adminToken := bootstrapAdmin(server.URL, "role-bootstrap-secret", "roleadmin", "roleadmin@test.com", "AdminPass123!")
|
||||
if adminToken == "" {
|
||||
t.Fatal("bootstrap admin failed")
|
||||
}
|
||||
|
||||
// Create a role
|
||||
createResp, createBody := doPost(server.URL+"/api/v1/roles", adminToken, map[string]interface{}{
|
||||
"name": "Assign Perm Role Test",
|
||||
"code": "test_role_assign_perm",
|
||||
})
|
||||
defer createResp.Body.Close()
|
||||
if createResp.StatusCode != http.StatusCreated {
|
||||
t.Fatalf("create role failed: %d %s", createResp.StatusCode, createBody)
|
||||
}
|
||||
var createResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(createBody), &createResult); err != nil {
|
||||
t.Fatalf("parse create response failed: %v", err)
|
||||
}
|
||||
roleData := createResult["data"].(map[string]interface{})
|
||||
roleID := int64(roleData["id"].(float64))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
roleID string
|
||||
payload map[string]interface{}
|
||||
token string
|
||||
wantStatus int
|
||||
}{
|
||||
{
|
||||
name: "success",
|
||||
roleID: fmt.Sprintf("%d", roleID),
|
||||
payload: map[string]interface{}{
|
||||
"permission_ids": []int64{1, 2},
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusOK,
|
||||
},
|
||||
{
|
||||
name: "invalid_id",
|
||||
roleID: "invalid",
|
||||
payload: map[string]interface{}{
|
||||
"permission_ids": []int64{1},
|
||||
},
|
||||
token: adminToken,
|
||||
wantStatus: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
name: "unauthorized",
|
||||
roleID: fmt.Sprintf("%d", roleID),
|
||||
payload: map[string]interface{}{
|
||||
"permission_ids": []int64{1},
|
||||
},
|
||||
token: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resp, body := doPut(server.URL+"/api/v1/roles/"+tt.roleID+"/permissions", tt.token, tt.payload)
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != tt.wantStatus {
|
||||
t.Errorf("expected status %d, got %d, body: %s", tt.wantStatus, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
855
internal/api/handler/sso_handler_test.go
Normal file
855
internal/api/handler/sso_handler_test.go
Normal file
@@ -0,0 +1,855 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/user-management-system/internal/api/handler"
|
||||
"github.com/user-management-system/internal/auth"
|
||||
)
|
||||
|
||||
func doPostForm(targetURL, token string, data url.Values) (*http.Response, string) {
|
||||
var bodyReader io.Reader
|
||||
if data != nil {
|
||||
bodyReader = strings.NewReader(data.Encode())
|
||||
}
|
||||
req, _ := http.NewRequest("POST", targetURL, bodyReader)
|
||||
if token != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
client := &http.Client{}
|
||||
resp, _ := client.Do(req)
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
resp.Body.Close()
|
||||
return resp, string(bodyBytes)
|
||||
}
|
||||
|
||||
func setupSSOTestServer(t *testing.T) (*httptest.Server, func()) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
engine := gin.New()
|
||||
engine.Use(gin.Recovery())
|
||||
|
||||
ssoManager := auth.NewSSOManager()
|
||||
clientsStore := auth.NewDefaultSSOClientsStore()
|
||||
clientsStore.RegisterClient(&auth.SSOClient{
|
||||
ClientID: "test-client",
|
||||
ClientSecret: "test-secret",
|
||||
Name: "Test Client",
|
||||
RedirectURIs: []string{"http://localhost:8080/callback"},
|
||||
})
|
||||
|
||||
ssoHandler := handler.NewSSOHandler(ssoManager, clientsStore)
|
||||
|
||||
// Simple auth middleware for testing
|
||||
authMiddleware := func() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
token := c.GetHeader("Authorization")
|
||||
if token == "" || token == "Bearer " {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
|
||||
return
|
||||
}
|
||||
c.Set("user_id", int64(1))
|
||||
c.Set("username", "testuser")
|
||||
c.Next()
|
||||
}
|
||||
}()
|
||||
|
||||
ssoGroup := engine.Group("/api/v1/sso")
|
||||
ssoGroup.Use(authMiddleware)
|
||||
{
|
||||
ssoGroup.GET("/authorize", ssoHandler.Authorize)
|
||||
ssoGroup.POST("/token", ssoHandler.Token)
|
||||
ssoGroup.POST("/introspect", ssoHandler.Introspect)
|
||||
ssoGroup.POST("/revoke", ssoHandler.Revoke)
|
||||
ssoGroup.GET("/userinfo", ssoHandler.UserInfo)
|
||||
}
|
||||
|
||||
server := httptest.NewServer(engine)
|
||||
return server, func() {
|
||||
server.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Authorize_MissingParams(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/sso/authorize", "Bearer test-token")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Authorize_UnsupportedResponseType(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=unsupported", "Bearer test-token")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Authorize_Unauthorized(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Authorize_CodeFlow(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code&state=xyz", "Bearer test-token")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusFound {
|
||||
t.Fatalf("expected status %d (redirect), got %d", http.StatusFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
location := resp.Header.Get("Location")
|
||||
if location == "" {
|
||||
t.Fatal("expected redirect location")
|
||||
}
|
||||
if !strings.Contains(location, "code=") {
|
||||
t.Errorf("expected redirect with code, got %s", location)
|
||||
}
|
||||
if !strings.Contains(location, "state=xyz") {
|
||||
t.Errorf("expected redirect with state, got %s", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Authorize_InvalidRedirectURI(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://evil.com/callback&response_type=code", "Bearer test-token")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Authorize_TokenFlow(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=token&state=abc", "Bearer test-token")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusFound {
|
||||
t.Fatalf("expected status %d (redirect), got %d", http.StatusFound, resp.StatusCode)
|
||||
}
|
||||
|
||||
location := resp.Header.Get("Location")
|
||||
if location == "" {
|
||||
t.Fatal("expected redirect location")
|
||||
}
|
||||
if !strings.Contains(location, "access_token=") {
|
||||
t.Errorf("expected redirect with access_token, got %s", location)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Token_MissingParams(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", nil)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Token_InvalidGrantType(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "password")
|
||||
formData.Set("client_id", "test-client")
|
||||
formData.Set("client_secret", "test-secret")
|
||||
|
||||
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Token_InvalidClient(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("code", "some-code")
|
||||
formData.Set("client_id", "invalid-client")
|
||||
formData.Set("client_secret", "wrong-secret")
|
||||
|
||||
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Token_InvalidCode(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("code", "invalid-code")
|
||||
formData.Set("client_id", "test-client")
|
||||
formData.Set("client_secret", "test-secret")
|
||||
|
||||
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Token_Success(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// First authorize to get a code
|
||||
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||
defer authResp.Body.Close()
|
||||
|
||||
if authResp.StatusCode != http.StatusFound {
|
||||
t.Fatalf("expected authorize redirect, got %d", authResp.StatusCode)
|
||||
}
|
||||
|
||||
location := authResp.Header.Get("Location")
|
||||
parsedURL, err := url.Parse(location)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to parse redirect URL: %v", err)
|
||||
}
|
||||
code := parsedURL.Query().Get("code")
|
||||
if code == "" {
|
||||
t.Fatal("expected authorization code in redirect")
|
||||
}
|
||||
|
||||
// Exchange code for token
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("code", code)
|
||||
formData.Set("client_id", "test-client")
|
||||
formData.Set("client_secret", "test-secret")
|
||||
|
||||
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var tokenResp handler.TokenResponse
|
||||
if err := json.Unmarshal([]byte(body), &tokenResp); err != nil {
|
||||
t.Fatalf("failed to parse token response: %v", err)
|
||||
}
|
||||
if tokenResp.AccessToken == "" {
|
||||
t.Errorf("expected access_token in response")
|
||||
}
|
||||
if tokenResp.TokenType != "Bearer" {
|
||||
t.Errorf("expected token_type Bearer, got %s", tokenResp.TokenType)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Introspect_MissingToken(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Introspect_InvalidToken(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
|
||||
"token": "invalid-token",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result handler.IntrospectResponse
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse introspect response: %v", err)
|
||||
}
|
||||
if result.Active != false {
|
||||
t.Errorf("expected active=false for invalid token, got %v", result.Active)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Introspect_ValidToken(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Authorize and get token
|
||||
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||
defer authResp.Body.Close()
|
||||
|
||||
location := authResp.Header.Get("Location")
|
||||
parsedURL, _ := url.Parse(location)
|
||||
code := parsedURL.Query().Get("code")
|
||||
|
||||
tokenForm := url.Values{}
|
||||
tokenForm.Set("grant_type", "authorization_code")
|
||||
tokenForm.Set("code", code)
|
||||
tokenForm.Set("client_id", "test-client")
|
||||
tokenForm.Set("client_secret", "test-secret")
|
||||
|
||||
tokenResp, tokenBody := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", tokenForm)
|
||||
defer tokenResp.Body.Close()
|
||||
|
||||
if tokenResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("token exchange failed: status=%d body=%s", tokenResp.StatusCode, tokenBody)
|
||||
}
|
||||
|
||||
var tokenResult handler.TokenResponse
|
||||
if err := json.Unmarshal([]byte(tokenBody), &tokenResult); err != nil {
|
||||
t.Fatalf("failed to parse token response: %v", err)
|
||||
}
|
||||
|
||||
// Introspect the token
|
||||
resp, body := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
|
||||
"token": tokenResult.AccessToken,
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result handler.IntrospectResponse
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse introspect response: %v", err)
|
||||
}
|
||||
if result.Active != true {
|
||||
t.Errorf("expected active=true for valid token, got %v", result.Active)
|
||||
}
|
||||
if result.UserID != 1 {
|
||||
t.Errorf("expected user_id=1, got %d", result.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Revoke_MissingToken(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/sso/revoke", "Bearer test-token", map[string]interface{}{})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Revoke_Success(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Authorize and get token
|
||||
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||
defer authResp.Body.Close()
|
||||
|
||||
location := authResp.Header.Get("Location")
|
||||
parsedURL, _ := url.Parse(location)
|
||||
code := parsedURL.Query().Get("code")
|
||||
|
||||
tokenForm := url.Values{}
|
||||
tokenForm.Set("grant_type", "authorization_code")
|
||||
tokenForm.Set("code", code)
|
||||
tokenForm.Set("client_id", "test-client")
|
||||
tokenForm.Set("client_secret", "test-secret")
|
||||
|
||||
tokenResp, tokenBody := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", tokenForm)
|
||||
defer tokenResp.Body.Close()
|
||||
|
||||
if tokenResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("token exchange failed: status=%d body=%s", tokenResp.StatusCode, tokenBody)
|
||||
}
|
||||
|
||||
var tokenResult handler.TokenResponse
|
||||
if err := json.Unmarshal([]byte(tokenBody), &tokenResult); err != nil {
|
||||
t.Fatalf("failed to parse token response: %v", err)
|
||||
}
|
||||
|
||||
// Revoke the token
|
||||
resp, body := doPost(server.URL+"/api/v1/sso/revoke", "Bearer test-token", map[string]interface{}{
|
||||
"token": tokenResult.AccessToken,
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
// Verify token is revoked
|
||||
introspectResp, introspectBody := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
|
||||
"token": tokenResult.AccessToken,
|
||||
})
|
||||
defer introspectResp.Body.Close()
|
||||
|
||||
if introspectResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("introspect failed: status=%d body=%s", introspectResp.StatusCode, introspectBody)
|
||||
}
|
||||
|
||||
var introspectResult handler.IntrospectResponse
|
||||
if err := json.Unmarshal([]byte(introspectBody), &introspectResult); err != nil {
|
||||
t.Fatalf("failed to parse introspect response: %v", err)
|
||||
}
|
||||
if introspectResult.Active != false {
|
||||
t.Errorf("expected active=false after revoke, got %v", introspectResult.Active)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_UserInfo_Unauthorized(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/sso/userinfo", "")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_UserInfo_Success(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/sso/userinfo", "Bearer test-token")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
}
|
||||
|
||||
data, ok := result["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected data in response, got %s", body)
|
||||
}
|
||||
if data["user_id"] != float64(1) {
|
||||
t.Errorf("expected user_id=1, got %v", data["user_id"])
|
||||
}
|
||||
if data["username"] != "testuser" {
|
||||
t.Errorf("expected username=testuser, got %v", data["username"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Token_InvalidClientSecret(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Authorize to get a code
|
||||
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||
defer authResp.Body.Close()
|
||||
|
||||
location := authResp.Header.Get("Location")
|
||||
parsedURL, _ := url.Parse(location)
|
||||
code := parsedURL.Query().Get("code")
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("code", code)
|
||||
formData.Set("client_id", "test-client")
|
||||
formData.Set("client_secret", "wrong-secret")
|
||||
|
||||
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Authorize_MissingClientID(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/sso/authorize?redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Introspect_FormData(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Test that introspect accepts form-encoded data
|
||||
formData := url.Values{}
|
||||
formData.Set("token", "some-token")
|
||||
|
||||
req, _ := http.NewRequest("POST", server.URL+"/api/v1/sso/introspect", strings.NewReader(formData.Encode()))
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := json.Marshal(resp.Body)
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Token_FormData(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Authorize to get a code
|
||||
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||
defer authResp.Body.Close()
|
||||
|
||||
location := authResp.Header.Get("Location")
|
||||
parsedURL, _ := url.Parse(location)
|
||||
code := parsedURL.Query().Get("code")
|
||||
|
||||
// Test that token accepts form-encoded data
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("code", code)
|
||||
formData.Set("client_id", "test-client")
|
||||
formData.Set("client_secret", "test-secret")
|
||||
|
||||
req, _ := http.NewRequest("POST", server.URL+"/api/v1/sso/token", strings.NewReader(formData.Encode()))
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bodyBytes, _ := json.Marshal(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Revoke_FormData(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("token", "some-token")
|
||||
|
||||
req, _ := http.NewRequest("POST", server.URL+"/api/v1/sso/revoke", strings.NewReader(formData.Encode()))
|
||||
req.Header.Set("Authorization", "Bearer test-token")
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := json.Marshal(resp.Body)
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Authorize_UnknownClientID(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/sso/authorize?client_id=unknown-client&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||
defer resp.Body.Close()
|
||||
|
||||
// When client is unknown, redirect_uri validation fails
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Token_WithoutAuth(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("code", "some-code")
|
||||
formData.Set("client_id", "test-client")
|
||||
formData.Set("client_secret", "test-secret")
|
||||
|
||||
resp, _ := doPostForm(server.URL+"/api/v1/sso/token", "", formData)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_UserInfo_WithoutAuth(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/sso/userinfo", "")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Introspect_WithoutAuth(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, _ := doPost(server.URL+"/api/v1/sso/introspect", "", map[string]interface{}{
|
||||
"token": "some-token",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Revoke_WithoutAuth(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, _ := doPost(server.URL+"/api/v1/sso/revoke", "", map[string]interface{}{
|
||||
"token": "some-token",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Authorize_InvalidClientID(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Test with valid redirect URI but unknown client
|
||||
resp, body := doGet(server.URL+"/api/v1/sso/authorize?client_id=unknown&redirect_uri=http://localhost:8080/callback&response_type=code", "Bearer test-token")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Token_MissingCode(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
formData := url.Values{}
|
||||
formData.Set("grant_type", "authorization_code")
|
||||
formData.Set("client_id", "test-client")
|
||||
formData.Set("client_secret", "test-secret")
|
||||
|
||||
resp, body := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", formData)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Code is empty, so validate should fail
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_FullFlow(t *testing.T) {
|
||||
server, cleanup := setupSSOTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
// Step 1: Authorize
|
||||
authResp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=test-client&redirect_uri=http://localhost:8080/callback&response_type=code&state=my-state", "Bearer test-token")
|
||||
defer authResp.Body.Close()
|
||||
|
||||
if authResp.StatusCode != http.StatusFound {
|
||||
t.Fatalf("authorize failed: status=%d", authResp.StatusCode)
|
||||
}
|
||||
|
||||
location := authResp.Header.Get("Location")
|
||||
parsedURL, _ := url.Parse(location)
|
||||
code := parsedURL.Query().Get("code")
|
||||
state := parsedURL.Query().Get("state")
|
||||
if code == "" {
|
||||
t.Fatal("expected authorization code")
|
||||
}
|
||||
if state != "my-state" {
|
||||
t.Errorf("expected state=my-state, got %s", state)
|
||||
}
|
||||
|
||||
// Step 2: Exchange code for token
|
||||
tokenForm := url.Values{}
|
||||
tokenForm.Set("grant_type", "authorization_code")
|
||||
tokenForm.Set("code", code)
|
||||
tokenForm.Set("client_id", "test-client")
|
||||
tokenForm.Set("client_secret", "test-secret")
|
||||
|
||||
tokenResp, tokenBody := doPostForm(server.URL+"/api/v1/sso/token", "Bearer test-token", tokenForm)
|
||||
defer tokenResp.Body.Close()
|
||||
|
||||
if tokenResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("token exchange failed: status=%d body=%s", tokenResp.StatusCode, tokenBody)
|
||||
}
|
||||
|
||||
var tokenResult handler.TokenResponse
|
||||
if err := json.Unmarshal([]byte(tokenBody), &tokenResult); err != nil {
|
||||
t.Fatalf("failed to parse token response: %v", err)
|
||||
}
|
||||
if tokenResult.AccessToken == "" {
|
||||
t.Fatal("expected access_token")
|
||||
}
|
||||
|
||||
// Step 3: Introspect token
|
||||
introspectResp, introspectBody := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
|
||||
"token": tokenResult.AccessToken,
|
||||
})
|
||||
defer introspectResp.Body.Close()
|
||||
|
||||
if introspectResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("introspect failed: status=%d body=%s", introspectResp.StatusCode, introspectBody)
|
||||
}
|
||||
|
||||
var introspectResult handler.IntrospectResponse
|
||||
if err := json.Unmarshal([]byte(introspectBody), &introspectResult); err != nil {
|
||||
t.Fatalf("failed to parse introspect response: %v", err)
|
||||
}
|
||||
if !introspectResult.Active {
|
||||
t.Error("expected token to be active")
|
||||
}
|
||||
if introspectResult.UserID != 1 {
|
||||
t.Errorf("expected user_id=1, got %d", introspectResult.UserID)
|
||||
}
|
||||
|
||||
// Step 4: Get userinfo
|
||||
userinfoResp, userinfoBody := doGet(server.URL+"/api/v1/sso/userinfo", "Bearer test-token")
|
||||
defer userinfoResp.Body.Close()
|
||||
|
||||
if userinfoResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("userinfo failed: status=%d body=%s", userinfoResp.StatusCode, userinfoBody)
|
||||
}
|
||||
|
||||
var userinfoResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(userinfoBody), &userinfoResult); err != nil {
|
||||
t.Fatalf("failed to parse userinfo response: %v", err)
|
||||
}
|
||||
userinfoData, ok := userinfoResult["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected userinfo data, got %s", userinfoBody)
|
||||
}
|
||||
if userinfoData["username"] != "testuser" {
|
||||
t.Errorf("expected username=testuser, got %v", userinfoData["username"])
|
||||
}
|
||||
|
||||
// Step 5: Revoke token
|
||||
revokeResp, revokeBody := doPost(server.URL+"/api/v1/sso/revoke", "Bearer test-token", map[string]interface{}{
|
||||
"token": tokenResult.AccessToken,
|
||||
})
|
||||
defer revokeResp.Body.Close()
|
||||
|
||||
if revokeResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("revoke failed: status=%d body=%s", revokeResp.StatusCode, revokeBody)
|
||||
}
|
||||
|
||||
// Step 6: Verify token is revoked
|
||||
finalIntrospectResp, finalIntrospectBody := doPost(server.URL+"/api/v1/sso/introspect", "Bearer test-token", map[string]interface{}{
|
||||
"token": tokenResult.AccessToken,
|
||||
})
|
||||
defer finalIntrospectResp.Body.Close()
|
||||
|
||||
if finalIntrospectResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("final introspect failed: status=%d body=%s", finalIntrospectResp.StatusCode, finalIntrospectBody)
|
||||
}
|
||||
|
||||
var finalResult handler.IntrospectResponse
|
||||
if err := json.Unmarshal([]byte(finalIntrospectBody), &finalResult); err != nil {
|
||||
t.Fatalf("failed to parse final introspect response: %v", err)
|
||||
}
|
||||
if finalResult.Active {
|
||||
t.Error("expected token to be inactive after revoke")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSOHandler_Authorize_NoClientStore(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
engine := gin.New()
|
||||
ssoManager := auth.NewSSOManager()
|
||||
// Pass nil clientsStore
|
||||
ssoHandler := handler.NewSSOHandler(ssoManager, nil)
|
||||
|
||||
authMiddleware := func() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Set("user_id", int64(1))
|
||||
c.Set("username", "testuser")
|
||||
c.Next()
|
||||
}
|
||||
}()
|
||||
|
||||
ssoGroup := engine.Group("/api/v1/sso")
|
||||
ssoGroup.Use(authMiddleware)
|
||||
{
|
||||
ssoGroup.GET("/authorize", ssoHandler.Authorize)
|
||||
}
|
||||
|
||||
server := httptest.NewServer(engine)
|
||||
defer server.Close()
|
||||
|
||||
// Without clients store, any redirect_uri should be accepted (or validation skipped)
|
||||
resp, _ := doGet(server.URL+"/api/v1/sso/authorize?client_id=any&redirect_uri=http://any.com/callback&response_type=code", "Bearer test-token")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusFound {
|
||||
t.Errorf("expected redirect when clientsStore is nil, got %d", resp.StatusCode)
|
||||
}
|
||||
}
|
||||
685
internal/api/handler/totp_handler_test.go
Normal file
685
internal/api/handler/totp_handler_test.go
Normal file
@@ -0,0 +1,685 @@
|
||||
package handler_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
)
|
||||
|
||||
func TestTOTPHandler_GetTOTPStatus(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "totpstatususer", "totpstatus@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "totpstatususer", "UserPass123!")
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/auth/2fa/status", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
}
|
||||
|
||||
data, ok := result["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected data in response, got %s", body)
|
||||
}
|
||||
if data["enabled"] != false {
|
||||
t.Errorf("expected enabled=false for new user, got %v", data["enabled"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_GetTOTPStatus_Unauthorized(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/auth/2fa/status", "")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_SetupTOTP(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "totpsetupuser", "totpsetup@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "totpsetupuser", "UserPass123!")
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/auth/2fa/setup", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
}
|
||||
|
||||
data, ok := result["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected data in response, got %s", body)
|
||||
}
|
||||
if data["secret"] == nil || data["secret"] == "" {
|
||||
t.Errorf("expected secret in setup response, got %+v", data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_SetupTOTP_Unauthorized(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/auth/2fa/setup", "")
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_EnableTOTP(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpenableuser", "totpenable@test.com", "UserPass123!")
|
||||
_ = userID
|
||||
_ = secret
|
||||
|
||||
// setupEnabledTOTPUser already enables TOTP, so let's just verify the user can login with TOTP
|
||||
// Actually, we need a fresh user to test enable
|
||||
registerUser(server.URL, "totpenableuser2", "totpenable2@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "totpenableuser2", "UserPass123!")
|
||||
|
||||
// Setup TOTP
|
||||
setupResp, setupBody := doGet(server.URL+"/api/v1/auth/2fa/setup", token)
|
||||
defer setupResp.Body.Close()
|
||||
if setupResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("setup failed: status=%d body=%s", setupResp.StatusCode, setupBody)
|
||||
}
|
||||
|
||||
var setupResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(setupBody), &setupResult); err != nil {
|
||||
t.Fatalf("failed to parse setup response: %v", err)
|
||||
}
|
||||
setupData, ok := setupResult["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected setup data, got %s", setupBody)
|
||||
}
|
||||
newSecret, ok := setupData["secret"].(string)
|
||||
if !ok || newSecret == "" {
|
||||
t.Fatalf("expected secret in setup response, got %s", setupBody)
|
||||
}
|
||||
|
||||
// Generate valid code
|
||||
code, err := auth.NewTOTPManager().GenerateCurrentCode(newSecret)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate TOTP code: %v", err)
|
||||
}
|
||||
|
||||
// Enable TOTP
|
||||
enableResp, enableBody := doPost(server.URL+"/api/v1/auth/2fa/enable", token, map[string]interface{}{
|
||||
"code": code,
|
||||
})
|
||||
defer enableResp.Body.Close()
|
||||
|
||||
if enableResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, enableResp.StatusCode, enableBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_EnableTOTP_InvalidCode(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "totpenableinv", "totpenableinv@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "totpenableinv", "UserPass123!")
|
||||
|
||||
// Setup TOTP first
|
||||
setupResp, setupBody := doGet(server.URL+"/api/v1/auth/2fa/setup", token)
|
||||
defer setupResp.Body.Close()
|
||||
if setupResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("setup failed: status=%d body=%s", setupResp.StatusCode, setupBody)
|
||||
}
|
||||
|
||||
// Try enable with invalid code
|
||||
enableResp, enableBody := doPost(server.URL+"/api/v1/auth/2fa/enable", token, map[string]interface{}{
|
||||
"code": "000000",
|
||||
})
|
||||
defer enableResp.Body.Close()
|
||||
|
||||
if enableResp.StatusCode != http.StatusUnauthorized && enableResp.StatusCode != http.StatusInternalServerError {
|
||||
t.Errorf("expected status 401 or 500 for invalid code, got %d, body: %s", enableResp.StatusCode, enableBody)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_EnableTOTP_MissingCode(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "totpenablemiss", "totpenablemiss@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "totpenablemiss", "UserPass123!")
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/2fa/enable", token, map[string]interface{}{})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_DisableTOTP(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpdisableuser", "totpdisable@test.com", "UserPass123!")
|
||||
|
||||
// Login again to get a fresh token (since TOTP is enabled, login may require TOTP)
|
||||
deviceID := "test-device"
|
||||
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "totpdisableuser",
|
||||
"password": "UserPass123!",
|
||||
"device_id": deviceID,
|
||||
})
|
||||
defer loginResp.Body.Close()
|
||||
|
||||
if loginResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("login failed: status=%d body=%s", loginResp.StatusCode, loginBody)
|
||||
}
|
||||
|
||||
var loginResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(loginBody), &loginResult); err != nil {
|
||||
t.Fatalf("failed to parse login response: %v", err)
|
||||
}
|
||||
|
||||
// If requires_totp, we need to verify TOTP first
|
||||
loginData, ok := loginResult["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected login data, got %s", loginBody)
|
||||
}
|
||||
|
||||
var token string
|
||||
if loginData["requires_totp"] == true {
|
||||
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate TOTP code: %v", err)
|
||||
}
|
||||
|
||||
tempToken, _ := loginData["temp_token"].(string)
|
||||
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"code": code,
|
||||
"device_id": deviceID,
|
||||
"temp_token": tempToken,
|
||||
})
|
||||
defer verifyResp.Body.Close()
|
||||
if verifyResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("totp verify failed: status=%d body=%s", verifyResp.StatusCode, verifyBody)
|
||||
}
|
||||
|
||||
var verifyResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err != nil {
|
||||
t.Fatalf("failed to parse verify response: %v", err)
|
||||
}
|
||||
verifyData, ok := verifyResult["data"].(map[string]interface{})
|
||||
if ok && verifyData["access_token"] != nil {
|
||||
token, _ = verifyData["access_token"].(string)
|
||||
}
|
||||
} else {
|
||||
token, _ = loginData["access_token"].(string)
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Fatal("failed to get token after login")
|
||||
}
|
||||
|
||||
// Generate valid code for disable
|
||||
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate TOTP code: %v", err)
|
||||
}
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/2fa/disable", token, map[string]interface{}{
|
||||
"code": code,
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
// Verify TOTP is disabled
|
||||
statusResp, statusBody := doGet(server.URL+"/api/v1/auth/2fa/status", token)
|
||||
defer statusResp.Body.Close()
|
||||
if statusResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("status check failed: status=%d body=%s", statusResp.StatusCode, statusBody)
|
||||
}
|
||||
|
||||
var statusResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(statusBody), &statusResult); err != nil {
|
||||
t.Fatalf("failed to parse status response: %v", err)
|
||||
}
|
||||
statusData, ok := statusResult["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected status data, got %s", statusBody)
|
||||
}
|
||||
if statusData["enabled"] != false {
|
||||
t.Errorf("expected enabled=false after disable, got %v", statusData["enabled"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_DisableTOTP_InvalidCode(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpdisableinv", "totpdisableinv@test.com", "UserPass123!")
|
||||
|
||||
// Get token (might need TOTP verification)
|
||||
deviceID := "test-device"
|
||||
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "totpdisableinv",
|
||||
"password": "UserPass123!",
|
||||
"device_id": deviceID,
|
||||
})
|
||||
defer loginResp.Body.Close()
|
||||
|
||||
var token string
|
||||
var loginResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
|
||||
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
|
||||
if loginData["requires_totp"] == true {
|
||||
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||
tempToken, _ := loginData["temp_token"].(string)
|
||||
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"code": code,
|
||||
"device_id": deviceID,
|
||||
"temp_token": tempToken,
|
||||
})
|
||||
defer verifyResp.Body.Close()
|
||||
if verifyResp.StatusCode == http.StatusOK {
|
||||
var verifyResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
|
||||
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
|
||||
token, _ = verifyData["access_token"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
token, _ = loginData["access_token"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Fatal("failed to get token after login")
|
||||
}
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/2fa/disable", token, map[string]interface{}{
|
||||
"code": "000000",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusInternalServerError {
|
||||
t.Errorf("expected status 401 or 500 for invalid code, got %d, body: %s", resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_VerifyTOTP(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpverifyuser", "totpverify@test.com", "UserPass123!")
|
||||
|
||||
// Get token (might need TOTP verification)
|
||||
deviceID := "test-device"
|
||||
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "totpverifyuser",
|
||||
"password": "UserPass123!",
|
||||
"device_id": deviceID,
|
||||
})
|
||||
defer loginResp.Body.Close()
|
||||
|
||||
var token string
|
||||
var loginResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
|
||||
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
|
||||
if loginData["requires_totp"] == true {
|
||||
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||
tempToken, _ := loginData["temp_token"].(string)
|
||||
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"code": code,
|
||||
"device_id": deviceID,
|
||||
"temp_token": tempToken,
|
||||
})
|
||||
defer verifyResp.Body.Close()
|
||||
if verifyResp.StatusCode == http.StatusOK {
|
||||
var verifyResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
|
||||
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
|
||||
token, _ = verifyData["access_token"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
token, _ = loginData["access_token"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Fatal("failed to get token after login")
|
||||
}
|
||||
|
||||
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate TOTP code: %v", err)
|
||||
}
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/2fa/verify", token, map[string]interface{}{
|
||||
"code": code,
|
||||
"device_id": deviceID,
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse response: %v", err)
|
||||
}
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
}
|
||||
data, ok := result["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected data in response, got %s", body)
|
||||
}
|
||||
if data["verified"] != true {
|
||||
t.Errorf("expected verified=true, got %v", data["verified"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_VerifyTOTP_InvalidCode(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpverifyinv", "totpverifyinv@test.com", "UserPass123!")
|
||||
|
||||
// Get token
|
||||
deviceID := "test-device"
|
||||
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "totpverifyinv",
|
||||
"password": "UserPass123!",
|
||||
"device_id": deviceID,
|
||||
})
|
||||
defer loginResp.Body.Close()
|
||||
|
||||
var token string
|
||||
var loginResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
|
||||
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
|
||||
if loginData["requires_totp"] == true {
|
||||
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||
tempToken, _ := loginData["temp_token"].(string)
|
||||
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"code": code,
|
||||
"device_id": deviceID,
|
||||
"temp_token": tempToken,
|
||||
})
|
||||
defer verifyResp.Body.Close()
|
||||
if verifyResp.StatusCode == http.StatusOK {
|
||||
var verifyResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
|
||||
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
|
||||
token, _ = verifyData["access_token"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
token, _ = loginData["access_token"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Fatal("failed to get token after login")
|
||||
}
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/2fa/verify", token, map[string]interface{}{
|
||||
"code": "000000",
|
||||
"device_id": deviceID,
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized && resp.StatusCode != http.StatusInternalServerError {
|
||||
t.Errorf("expected status 401 or 500 for invalid code, got %d, body: %s", resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_VerifyTOTP_MissingCode(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "totpverifymiss", "totpverifymiss@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "totpverifymiss", "UserPass123!")
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/2fa/verify", token, map[string]interface{}{})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_VerifyTOTP_Unauthorized(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, _ := doPost(server.URL+"/api/v1/auth/2fa/verify", "", map[string]interface{}{
|
||||
"code": "123456",
|
||||
"device_id": "test-device",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_DisableTOTP_MissingCode(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpdisablemiss", "totpdisablemiss@test.com", "UserPass123!")
|
||||
|
||||
// Get token
|
||||
deviceID := "test-device"
|
||||
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "totpdisablemiss",
|
||||
"password": "UserPass123!",
|
||||
"device_id": deviceID,
|
||||
})
|
||||
defer loginResp.Body.Close()
|
||||
|
||||
var token string
|
||||
var loginResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
|
||||
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
|
||||
if loginData["requires_totp"] == true {
|
||||
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||
tempToken, _ := loginData["temp_token"].(string)
|
||||
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"code": code,
|
||||
"device_id": deviceID,
|
||||
"temp_token": tempToken,
|
||||
})
|
||||
defer verifyResp.Body.Close()
|
||||
if verifyResp.StatusCode == http.StatusOK {
|
||||
var verifyResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
|
||||
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
|
||||
token, _ = verifyData["access_token"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
token, _ = loginData["access_token"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Fatal("failed to get token after login")
|
||||
}
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/2fa/disable", token, map[string]interface{}{})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_DisableTOTP_Unauthorized(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, _ := doPost(server.URL+"/api/v1/auth/2fa/disable", "", map[string]interface{}{
|
||||
"code": "123456",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_SetupTOTP_AlreadyEnabled(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpsetupenabled", "totpsetupenabled@test.com", "UserPass123!")
|
||||
_ = secret
|
||||
|
||||
// Get token after TOTP login
|
||||
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "totpsetupenabled",
|
||||
"password": "UserPass123!",
|
||||
"device_id": "test-device",
|
||||
})
|
||||
defer loginResp.Body.Close()
|
||||
|
||||
var token string
|
||||
var loginResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(loginBody), &loginResult); err == nil {
|
||||
if loginData, ok := loginResult["data"].(map[string]interface{}); ok {
|
||||
if loginData["requires_totp"] == true {
|
||||
tempToken, _ := loginData["temp_token"].(string)
|
||||
code, _ := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||
verifyResp, verifyBody := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"temp_token": tempToken,
|
||||
"code": code,
|
||||
"device_id": "test-device",
|
||||
})
|
||||
defer verifyResp.Body.Close()
|
||||
if verifyResp.StatusCode == http.StatusOK {
|
||||
var verifyResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(verifyBody), &verifyResult); err == nil {
|
||||
if verifyData, ok := verifyResult["data"].(map[string]interface{}); ok {
|
||||
token, _ = verifyData["access_token"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
token, _ = loginData["access_token"].(string)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if token == "" {
|
||||
t.Fatal("failed to get token after login")
|
||||
}
|
||||
|
||||
// Try setup again - should still work or return appropriate response
|
||||
resp, body := doGet(server.URL+"/api/v1/auth/2fa/setup", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Setup may return 200 with new secret or error if already enabled
|
||||
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("unexpected status %d, body: %s", resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_EnableTOTP_Unauthorized(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, _ := doPost(server.URL+"/api/v1/auth/2fa/enable", "", map[string]interface{}{
|
||||
"code": "123456",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPHandler_InvalidJSON(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "totpjsonuser", "totpjson@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "totpjsonuser", "UserPass123!")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
path string
|
||||
method string
|
||||
}{
|
||||
{"enable_invalid_json", "/api/v1/auth/2fa/enable", "POST"},
|
||||
{"disable_invalid_json", "/api/v1/auth/2fa/disable", "POST"},
|
||||
{"verify_invalid_json", "/api/v1/auth/2fa/verify", "POST"},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(tc.method, server.URL+tc.path, bytes.NewReader([]byte("not json")))
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Errorf("expected status %d for invalid JSON, got %d", http.StatusBadRequest, resp.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
102
internal/api/middleware/gzip_test.go
Normal file
102
internal/api/middleware/gzip_test.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestGzipMiddleware_CompressesLargeJSONResponses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(GzipMiddleware())
|
||||
router.GET("/data", func(c *gin.Context) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.String(http.StatusOK, strings.Repeat("a", gzipMinLength+128))
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/data", nil)
|
||||
req.Header.Set("Accept-Encoding", "gzip")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if got := recorder.Header().Get("Content-Encoding"); got != "gzip" {
|
||||
t.Fatalf("Content-Encoding = %q, want gzip", got)
|
||||
}
|
||||
|
||||
reader, err := gzip.NewReader(bytes.NewReader(recorder.Body.Bytes()))
|
||||
if err != nil {
|
||||
t.Fatalf("gzip.NewReader() error = %v", err)
|
||||
}
|
||||
defer reader.Close()
|
||||
|
||||
payload, err := io.ReadAll(reader)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadAll() error = %v", err)
|
||||
}
|
||||
if got := string(payload); got != strings.Repeat("a", gzipMinLength+128) {
|
||||
t.Fatalf("decompressed payload length = %d, want %d", len(got), gzipMinLength+128)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGzipMiddleware_PassesThroughWhenCompressionNotUseful(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
acceptEncoding string
|
||||
contentType string
|
||||
body string
|
||||
}{
|
||||
{
|
||||
name: "client does not accept gzip",
|
||||
acceptEncoding: "",
|
||||
contentType: "application/json",
|
||||
body: strings.Repeat("b", gzipMinLength+64),
|
||||
},
|
||||
{
|
||||
name: "body below threshold",
|
||||
acceptEncoding: "gzip",
|
||||
contentType: "application/json",
|
||||
body: "small-body",
|
||||
},
|
||||
{
|
||||
name: "unsupported content type",
|
||||
acceptEncoding: "gzip",
|
||||
contentType: "image/png",
|
||||
body: strings.Repeat("c", gzipMinLength+64),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(GzipMiddleware())
|
||||
router.GET("/data", func(c *gin.Context) {
|
||||
c.Header("Content-Type", tc.contentType)
|
||||
c.String(http.StatusOK, tc.body)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/data", nil)
|
||||
if tc.acceptEncoding != "" {
|
||||
req.Header.Set("Accept-Encoding", tc.acceptEncoding)
|
||||
}
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if got := recorder.Header().Get("Content-Encoding"); got != "" {
|
||||
t.Fatalf("Content-Encoding = %q, want empty", got)
|
||||
}
|
||||
if got := recorder.Body.String(); got != tc.body {
|
||||
t.Fatalf("body length = %d, want %d", len(got), len(tc.body))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
165
internal/api/middleware/operation_log_test.go
Normal file
165
internal/api/middleware/operation_log_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func newOperationLogRepositoryForTest(t *testing.T) *repository.OperationLogRepository {
|
||||
t.Helper()
|
||||
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: "file:operation_log_test?mode=memory&cache=shared",
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite failed: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&domain.OperationLog{}); err != nil {
|
||||
t.Fatalf("migrate failed: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Exec("DELETE FROM operation_logs").Error; err != nil {
|
||||
t.Fatalf("cleanup operation_logs failed: %v", err)
|
||||
}
|
||||
|
||||
return repository.NewOperationLogRepository(db)
|
||||
}
|
||||
|
||||
func waitForOperationLogs(t *testing.T, repo *repository.OperationLogRepository, want int) []*domain.OperationLog {
|
||||
t.Helper()
|
||||
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
logs, _, err := repo.List(context.Background(), 0, 20)
|
||||
if err != nil {
|
||||
t.Fatalf("list operation logs failed: %v", err)
|
||||
}
|
||||
if len(logs) >= want {
|
||||
return logs
|
||||
}
|
||||
time.Sleep(25 * time.Millisecond)
|
||||
}
|
||||
|
||||
logs, _, err := repo.List(context.Background(), 0, 20)
|
||||
if err != nil {
|
||||
t.Fatalf("list operation logs failed: %v", err)
|
||||
}
|
||||
t.Fatalf("timed out waiting for %d operation logs, got %d", want, len(logs))
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestOperationLogMiddleware_SkipsReadOnlyMethods(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := newOperationLogRepositoryForTest(t)
|
||||
router := gin.New()
|
||||
router.Use(NewOperationLogMiddleware(repo).Record())
|
||||
router.GET("/logs", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/logs", nil)
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
logs, _, err := repo.List(context.Background(), 0, 20)
|
||||
if err != nil {
|
||||
t.Fatalf("list operation logs failed: %v", err)
|
||||
}
|
||||
if len(logs) != 0 {
|
||||
t.Fatalf("expected no logs for GET request, got %d", len(logs))
|
||||
}
|
||||
}
|
||||
|
||||
func TestOperationLogMiddleware_RecordsAdminMutationAndSanitizesParams(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
repo := newOperationLogRepositoryForTest(t)
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
c.Set("user_id", int64(42))
|
||||
c.Set(ContextKeyRoleCodes, []string{"admin"})
|
||||
c.Next()
|
||||
})
|
||||
router.Use(NewOperationLogMiddleware(repo).Record())
|
||||
router.POST("/users", func(c *gin.Context) {
|
||||
c.Status(http.StatusCreated)
|
||||
})
|
||||
|
||||
body := `{"username":"alice","password":"super-secret","token":"abc"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/users", strings.NewReader(body))
|
||||
req.RemoteAddr = "203.0.113.10:8080"
|
||||
req.Header.Set("User-Agent", "middleware-test")
|
||||
recorder := httptest.NewRecorder()
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusCreated {
|
||||
t.Fatalf("expected 201, got %d", recorder.Code)
|
||||
}
|
||||
|
||||
logs := waitForOperationLogs(t, repo, 1)
|
||||
entry := logs[0]
|
||||
if entry.UserID == nil || *entry.UserID != 42 {
|
||||
t.Fatalf("user_id = %#v, want 42", entry.UserID)
|
||||
}
|
||||
if entry.OperationType != "admin:CREATE" {
|
||||
t.Fatalf("operation_type = %q, want admin:CREATE", entry.OperationType)
|
||||
}
|
||||
if entry.ResponseStatus != http.StatusCreated {
|
||||
t.Fatalf("response_status = %d, want %d", entry.ResponseStatus, http.StatusCreated)
|
||||
}
|
||||
if strings.Contains(entry.RequestParams, "super-secret") || strings.Contains(entry.RequestParams, "abc") {
|
||||
t.Fatalf("expected sanitized params, got %s", entry.RequestParams)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOperationLogMiddleware_MethodToTypeAndSanitizeFallbacks(t *testing.T) {
|
||||
if got := methodToType(http.MethodPatch); got != "UPDATE" {
|
||||
t.Fatalf("methodToType(PATCH) = %q, want UPDATE", got)
|
||||
}
|
||||
if got := methodToType(http.MethodDelete); got != "DELETE" {
|
||||
t.Fatalf("methodToType(DELETE) = %q, want DELETE", got)
|
||||
}
|
||||
if got := methodToType(http.MethodGet); got != "OTHER" {
|
||||
t.Fatalf("methodToType(GET) = %q, want OTHER", got)
|
||||
}
|
||||
|
||||
raw := []byte(`{"password":"secret","name":"alice"}`)
|
||||
sanitized := sanitizeParams(raw)
|
||||
if strings.Contains(sanitized, "secret") {
|
||||
t.Fatalf("expected password to be masked, got %s", sanitized)
|
||||
}
|
||||
|
||||
plain := sanitizeParams([]byte("not-json"))
|
||||
if plain != "not-json" {
|
||||
t.Fatalf("sanitizeParams(non-json) = %q, want not-json", plain)
|
||||
}
|
||||
|
||||
var payload map[string]string
|
||||
if err := json.Unmarshal([]byte(sanitized), &payload); err != nil {
|
||||
t.Fatalf("unmarshal sanitized params failed: %v", err)
|
||||
}
|
||||
if payload["password"] != "***" {
|
||||
t.Fatalf("password = %q, want ***", payload["password"])
|
||||
}
|
||||
}
|
||||
114
internal/api/middleware/rbac_test.go
Normal file
114
internal/api/middleware/rbac_test.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func performRBACRequest(t *testing.T, setup func(*gin.Context), middleware gin.HandlerFunc) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
if setup != nil {
|
||||
router.Use(setup)
|
||||
}
|
||||
router.Use(middleware)
|
||||
router.GET("/protected", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
return recorder
|
||||
}
|
||||
|
||||
func TestRequirePermissionRejectsMissingPermission(t *testing.T) {
|
||||
recorder := performRBACRequest(t, func(c *gin.Context) {
|
||||
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
|
||||
c.Next()
|
||||
}, RequirePermission("users:write"))
|
||||
|
||||
if recorder.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequirePermissionAllowsMatchingPermission(t *testing.T) {
|
||||
recorder := performRBACRequest(t, func(c *gin.Context) {
|
||||
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
|
||||
c.Next()
|
||||
}, RequirePermission("users:read"))
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAllPermissionsRequiresEveryCode(t *testing.T) {
|
||||
recorder := performRBACRequest(t, func(c *gin.Context) {
|
||||
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
|
||||
c.Next()
|
||||
}, RequireAllPermissions("users:read", "users:write"))
|
||||
|
||||
if recorder.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected 403, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAnyPermissionIsAliasOfRequirePermission(t *testing.T) {
|
||||
recorder := performRBACRequest(t, func(c *gin.Context) {
|
||||
c.Set(ContextKeyPermissionCodes, []string{"users:write"})
|
||||
c.Next()
|
||||
}, RequireAnyPermission("users:read", "users:write"))
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireRoleAndAdminOnly(t *testing.T) {
|
||||
roleRecorder := performRBACRequest(t, func(c *gin.Context) {
|
||||
c.Set(ContextKeyRoleCodes, []string{"auditor"})
|
||||
c.Next()
|
||||
}, RequireRole("admin"))
|
||||
if roleRecorder.Code != http.StatusForbidden {
|
||||
t.Fatalf("expected role check to return 403, got %d", roleRecorder.Code)
|
||||
}
|
||||
|
||||
adminRecorder := performRBACRequest(t, func(c *gin.Context) {
|
||||
c.Set(ContextKeyRoleCodes, []string{"admin"})
|
||||
c.Next()
|
||||
}, AdminOnly())
|
||||
if adminRecorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected admin check to return 200, got %d", adminRecorder.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRBACHelpersHandleMissingContextValues(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
|
||||
if got := GetRoleCodes(c); got != nil {
|
||||
t.Fatalf("GetRoleCodes() = %#v, want nil", got)
|
||||
}
|
||||
if got := GetPermissionCodes(c); got != nil {
|
||||
t.Fatalf("GetPermissionCodes() = %#v, want nil", got)
|
||||
}
|
||||
if IsAdmin(c) {
|
||||
t.Fatal("IsAdmin() = true, want false")
|
||||
}
|
||||
|
||||
c.Set(ContextKeyRoleCodes, []string{"admin"})
|
||||
c.Set(ContextKeyPermissionCodes, []string{"users:read"})
|
||||
|
||||
if !IsAdmin(c) {
|
||||
t.Fatal("IsAdmin() = false, want true")
|
||||
}
|
||||
}
|
||||
119
internal/api/middleware/response_wrapper_test.go
Normal file
119
internal/api/middleware/response_wrapper_test.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestResponseWrapper_WrapsSuccessfulJSONPayload(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(ResponseWrapper())
|
||||
router.GET("/users", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"id": 1, "name": "alice"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
want := `{"code":0,"data":{"id":1,"name":"alice"},"message":"success"}`
|
||||
if got := recorder.Body.String(); got != want {
|
||||
t.Fatalf("body = %s, want %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseWrapper_PassesThroughMarkedResponses(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(ResponseWrapper())
|
||||
router.GET("/users", func(c *gin.Context) {
|
||||
WrapResponse(c)
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "already wrapped"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
want := `{"code":0,"message":"already wrapped"}`
|
||||
if got := recorder.Body.String(); got != want {
|
||||
t.Fatalf("body = %s, want %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseWrapper_PassesThroughNonSuccessStatus(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(ResponseWrapper())
|
||||
router.GET("/users", func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"message": "bad request"})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", recorder.Code)
|
||||
}
|
||||
want := `{"message":"bad request"}`
|
||||
if got := recorder.Body.String(); got != want {
|
||||
t.Fatalf("body = %s, want %s", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseWrapper_PassesThroughInvalidJSON(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(ResponseWrapper())
|
||||
router.GET("/users", func(c *gin.Context) {
|
||||
c.Writer.WriteHeader(http.StatusOK)
|
||||
_, _ = c.Writer.WriteString("plain text")
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
if got := recorder.Body.String(); got != "plain text" {
|
||||
t.Fatalf("body = %q, want plain text", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseWrapper_NoWrapperMarksContext(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
router := gin.New()
|
||||
router.Use(NoWrapper())
|
||||
router.GET("/users", func(c *gin.Context) {
|
||||
if _, exists := c.Get("response_wrapped"); !exists {
|
||||
t.Fatal("expected response_wrapped marker in context")
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"ok": true})
|
||||
})
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/users", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", recorder.Code)
|
||||
}
|
||||
}
|
||||
136
internal/domain/device_test.go
Normal file
136
internal/domain/device_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDeviceType_Constants(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value DeviceType
|
||||
expected int
|
||||
}{
|
||||
{"Unknown", DeviceTypeUnknown, 0},
|
||||
{"Web", DeviceTypeWeb, 1},
|
||||
{"Mobile", DeviceTypeMobile, 2},
|
||||
{"Desktop", DeviceTypeDesktop, 3},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if int(tc.value) != tc.expected {
|
||||
t.Errorf("expected %d, got %d", tc.expected, int(tc.value))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeviceStatus_Constants(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value DeviceStatus
|
||||
expected int
|
||||
}{
|
||||
{"Inactive", DeviceStatusInactive, 0},
|
||||
{"Active", DeviceStatusActive, 1},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
if int(tc.value) != tc.expected {
|
||||
t.Errorf("expected %d, got %d", tc.expected, int(tc.value))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDevice_TableName(t *testing.T) {
|
||||
var d Device
|
||||
if got := d.TableName(); got != "devices" {
|
||||
t.Errorf("expected table name 'devices', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDevice_StructFields(t *testing.T) {
|
||||
now := time.Now()
|
||||
trustExpires := now.Add(24 * time.Hour)
|
||||
|
||||
d := Device{
|
||||
ID: 1,
|
||||
UserID: 2,
|
||||
DeviceID: "device-123",
|
||||
DeviceName: "Test Device",
|
||||
DeviceType: DeviceTypeWeb,
|
||||
DeviceOS: "Windows",
|
||||
DeviceBrowser: "Chrome",
|
||||
IP: "127.0.0.1",
|
||||
Location: "Beijing",
|
||||
IsTrusted: true,
|
||||
TrustExpiresAt: &trustExpires,
|
||||
Status: DeviceStatusActive,
|
||||
LastActiveTime: now,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
|
||||
if d.ID != 1 {
|
||||
t.Errorf("expected ID 1, got %d", d.ID)
|
||||
}
|
||||
if d.UserID != 2 {
|
||||
t.Errorf("expected UserID 2, got %d", d.UserID)
|
||||
}
|
||||
if d.DeviceID != "device-123" {
|
||||
t.Errorf("expected DeviceID 'device-123', got %q", d.DeviceID)
|
||||
}
|
||||
if d.DeviceName != "Test Device" {
|
||||
t.Errorf("expected DeviceName 'Test Device', got %q", d.DeviceName)
|
||||
}
|
||||
if d.DeviceType != DeviceTypeWeb {
|
||||
t.Errorf("expected DeviceTypeWeb, got %d", d.DeviceType)
|
||||
}
|
||||
if d.DeviceOS != "Windows" {
|
||||
t.Errorf("expected DeviceOS 'Windows', got %q", d.DeviceOS)
|
||||
}
|
||||
if d.DeviceBrowser != "Chrome" {
|
||||
t.Errorf("expected DeviceBrowser 'Chrome', got %q", d.DeviceBrowser)
|
||||
}
|
||||
if d.IP != "127.0.0.1" {
|
||||
t.Errorf("expected IP '127.0.0.1', got %q", d.IP)
|
||||
}
|
||||
if d.Location != "Beijing" {
|
||||
t.Errorf("expected Location 'Beijing', got %q", d.Location)
|
||||
}
|
||||
if !d.IsTrusted {
|
||||
t.Error("expected IsTrusted to be true")
|
||||
}
|
||||
if d.TrustExpiresAt == nil || !d.TrustExpiresAt.Equal(trustExpires) {
|
||||
t.Error("expected TrustExpiresAt to match")
|
||||
}
|
||||
if d.Status != DeviceStatusActive {
|
||||
t.Errorf("expected DeviceStatusActive, got %d", d.Status)
|
||||
}
|
||||
if d.LastActiveTime.IsZero() {
|
||||
t.Error("expected LastActiveTime to be set")
|
||||
}
|
||||
if d.CreatedAt.IsZero() {
|
||||
t.Error("expected CreatedAt to be set")
|
||||
}
|
||||
if d.UpdatedAt.IsZero() {
|
||||
t.Error("expected UpdatedAt to be set")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDevice_DefaultStatus(t *testing.T) {
|
||||
var d Device
|
||||
if d.Status != DeviceStatusInactive {
|
||||
t.Errorf("expected default status Inactive(0), got %d", d.Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDevice_DefaultDeviceType(t *testing.T) {
|
||||
var d Device
|
||||
if d.DeviceType != DeviceTypeUnknown {
|
||||
t.Errorf("expected default device type Unknown(0), got %d", d.DeviceType)
|
||||
}
|
||||
}
|
||||
35
internal/domain/password_history_test.go
Normal file
35
internal/domain/password_history_test.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPasswordHistory_TableName(t *testing.T) {
|
||||
var h PasswordHistory
|
||||
if got := h.TableName(); got != "password_histories" {
|
||||
t.Errorf("expected table name 'password_histories', got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordHistory_StructTags(t *testing.T) {
|
||||
h := PasswordHistory{
|
||||
ID: 1,
|
||||
UserID: 2,
|
||||
PasswordHash: "hash123",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if h.ID != 1 {
|
||||
t.Errorf("expected ID 1, got %d", h.ID)
|
||||
}
|
||||
if h.UserID != 2 {
|
||||
t.Errorf("expected UserID 2, got %d", h.UserID)
|
||||
}
|
||||
if h.PasswordHash != "hash123" {
|
||||
t.Errorf("expected PasswordHash 'hash123', got %q", h.PasswordHash)
|
||||
}
|
||||
if h.CreatedAt.IsZero() {
|
||||
t.Error("expected CreatedAt to be set")
|
||||
}
|
||||
}
|
||||
77
internal/pkg/pagination/pagination_test.go
Normal file
77
internal/pkg/pagination/pagination_test.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package pagination
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestDefaultPagination(t *testing.T) {
|
||||
p := DefaultPagination()
|
||||
if p.Page != 1 {
|
||||
t.Errorf("expected default page 1, got %d", p.Page)
|
||||
}
|
||||
if p.PageSize != 20 {
|
||||
t.Errorf("expected default page_size 20, got %d", p.PageSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginationParams_Offset(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
page int
|
||||
pageSize int
|
||||
wantOffset int
|
||||
}{
|
||||
{"page 1", 1, 20, 0},
|
||||
{"page 2", 2, 20, 20},
|
||||
{"page 5", 5, 20, 80},
|
||||
{"zero page", 0, 20, 0},
|
||||
{"negative page", -1, 20, 0},
|
||||
{"page 1 size 10", 1, 10, 0},
|
||||
{"page 3 size 10", 3, 10, 20},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := PaginationParams{Page: tc.page, PageSize: tc.pageSize}
|
||||
if got := p.Offset(); got != tc.wantOffset {
|
||||
t.Errorf("expected offset %d, got %d", tc.wantOffset, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginationParams_Limit(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
pageSize int
|
||||
want int
|
||||
}{
|
||||
{"default", 20, 20},
|
||||
{"size 10", 10, 10},
|
||||
{"size 50", 50, 50},
|
||||
{"size 100", 100, 100},
|
||||
{"max cap", 101, 100},
|
||||
{"zero size", 0, 20},
|
||||
{"negative size", -1, 20},
|
||||
{"size 1", 1, 1},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
p := PaginationParams{PageSize: tc.pageSize}
|
||||
if got := p.Limit(); got != tc.want {
|
||||
t.Errorf("expected limit %d, got %d", tc.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPaginationParams_OffsetAndLimit(t *testing.T) {
|
||||
p := PaginationParams{Page: 3, PageSize: 15}
|
||||
if got := p.Offset(); got != 30 {
|
||||
t.Errorf("expected offset 30, got %d", got)
|
||||
}
|
||||
if got := p.Limit(); got != 15 {
|
||||
t.Errorf("expected limit 15, got %d", got)
|
||||
}
|
||||
}
|
||||
95
internal/repository/pagination_test.go
Normal file
95
internal/repository/pagination_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/user-management-system/internal/pkg/pagination"
|
||||
)
|
||||
|
||||
func TestPaginationResultFromTotal(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
total int64
|
||||
params pagination.PaginationParams
|
||||
wantPages int
|
||||
wantTotal int64
|
||||
wantPage int
|
||||
wantPageSize int
|
||||
}{
|
||||
{
|
||||
name: "exact division",
|
||||
total: 100,
|
||||
params: pagination.PaginationParams{Page: 1, PageSize: 20},
|
||||
wantPages: 5,
|
||||
wantTotal: 100,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "with remainder",
|
||||
total: 105,
|
||||
params: pagination.PaginationParams{Page: 1, PageSize: 20},
|
||||
wantPages: 6,
|
||||
wantTotal: 105,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "zero total",
|
||||
total: 0,
|
||||
params: pagination.PaginationParams{Page: 1, PageSize: 20},
|
||||
wantPages: 0,
|
||||
wantTotal: 0,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "single page",
|
||||
total: 5,
|
||||
params: pagination.PaginationParams{Page: 1, PageSize: 20},
|
||||
wantPages: 1,
|
||||
wantTotal: 5,
|
||||
wantPage: 1,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "page 2",
|
||||
total: 50,
|
||||
params: pagination.PaginationParams{Page: 2, PageSize: 20},
|
||||
wantPages: 3,
|
||||
wantTotal: 50,
|
||||
wantPage: 2,
|
||||
wantPageSize: 20,
|
||||
},
|
||||
{
|
||||
name: "small page size",
|
||||
total: 10,
|
||||
params: pagination.PaginationParams{Page: 1, PageSize: 3},
|
||||
wantPages: 4,
|
||||
wantTotal: 10,
|
||||
wantPage: 1,
|
||||
wantPageSize: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
result := paginationResultFromTotal(tc.total, tc.params)
|
||||
if result == nil {
|
||||
t.Fatal("expected non-nil result")
|
||||
}
|
||||
if result.Total != tc.wantTotal {
|
||||
t.Errorf("expected total %d, got %d", tc.wantTotal, result.Total)
|
||||
}
|
||||
if result.Page != tc.wantPage {
|
||||
t.Errorf("expected page %d, got %d", tc.wantPage, result.Page)
|
||||
}
|
||||
if result.PageSize != tc.wantPageSize {
|
||||
t.Errorf("expected page_size %d, got %d", tc.wantPageSize, result.PageSize)
|
||||
}
|
||||
if result.Pages != tc.wantPages {
|
||||
t.Errorf("expected pages %d, got %d", tc.wantPages, result.Pages)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
224
internal/repository/password_history_test.go
Normal file
224
internal/repository/password_history_test.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
func TestPasswordHistoryRepository_Create(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||
t.Fatalf("migrate password_history failed: %v", err)
|
||||
}
|
||||
|
||||
repo := NewPasswordHistoryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
history := &domain.PasswordHistory{
|
||||
UserID: 1,
|
||||
PasswordHash: "hash1",
|
||||
CreatedAt: time.Now(),
|
||||
}
|
||||
|
||||
if err := repo.Create(ctx, history); err != nil {
|
||||
t.Fatalf("create failed: %v", err)
|
||||
}
|
||||
if history.ID == 0 {
|
||||
t.Error("expected ID to be set after create")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordHistoryRepository_GetByUserID(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||
t.Fatalf("migrate password_history failed: %v", err)
|
||||
}
|
||||
|
||||
repo := NewPasswordHistoryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create multiple records for user 1
|
||||
for i := 0; i < 5; i++ {
|
||||
h := &domain.PasswordHistory{
|
||||
UserID: 1,
|
||||
PasswordHash: "hash",
|
||||
CreatedAt: time.Now().Add(time.Duration(i) * time.Second),
|
||||
}
|
||||
if err := repo.Create(ctx, h); err != nil {
|
||||
t.Fatalf("create failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Create record for user 2
|
||||
if err := repo.Create(ctx, &domain.PasswordHistory{UserID: 2, PasswordHash: "hash", CreatedAt: time.Now()}); err != nil {
|
||||
t.Fatalf("create failed: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
userID int64
|
||||
limit int
|
||||
wantLen int
|
||||
wantUser int64
|
||||
}{
|
||||
{"get all for user 1", 1, 10, 5, 1},
|
||||
{"limit 3 for user 1", 1, 3, 3, 1},
|
||||
{"get for user 2", 2, 10, 1, 2},
|
||||
{"get for nonexistent user", 999, 10, 0, 999},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
histories, err := repo.GetByUserID(ctx, tc.userID, tc.limit)
|
||||
if err != nil {
|
||||
t.Fatalf("get failed: %v", err)
|
||||
}
|
||||
if len(histories) != tc.wantLen {
|
||||
t.Errorf("expected %d histories, got %d", tc.wantLen, len(histories))
|
||||
}
|
||||
for _, h := range histories {
|
||||
if h.UserID != tc.wantUser {
|
||||
t.Errorf("expected user_id %d, got %d", tc.wantUser, h.UserID)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordHistoryRepository_GetByUserID_Order(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||
t.Fatalf("migrate password_history failed: %v", err)
|
||||
}
|
||||
|
||||
repo := NewPasswordHistoryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create records with different timestamps
|
||||
now := time.Now()
|
||||
for i := 0; i < 3; i++ {
|
||||
h := &domain.PasswordHistory{
|
||||
UserID: 1,
|
||||
PasswordHash: "hash",
|
||||
CreatedAt: now.Add(time.Duration(i) * time.Hour),
|
||||
}
|
||||
if err := repo.Create(ctx, h); err != nil {
|
||||
t.Fatalf("create failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
histories, err := repo.GetByUserID(ctx, 1, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("get failed: %v", err)
|
||||
}
|
||||
if len(histories) != 3 {
|
||||
t.Fatalf("expected 3 histories, got %d", len(histories))
|
||||
}
|
||||
|
||||
// Should be ordered by created_at DESC (newest first)
|
||||
for i := 0; i < len(histories)-1; i++ {
|
||||
if !histories[i].CreatedAt.After(histories[i+1].CreatedAt) && !histories[i].CreatedAt.Equal(histories[i+1].CreatedAt) {
|
||||
t.Errorf("expected descending order, got %v before %v", histories[i].CreatedAt, histories[i+1].CreatedAt)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordHistoryRepository_DeleteOldRecords(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||
t.Fatalf("migrate password_history failed: %v", err)
|
||||
}
|
||||
|
||||
repo := NewPasswordHistoryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 5 records for user 1
|
||||
now := time.Now()
|
||||
for i := 0; i < 5; i++ {
|
||||
h := &domain.PasswordHistory{
|
||||
UserID: 1,
|
||||
PasswordHash: "hash",
|
||||
CreatedAt: now.Add(time.Duration(i) * time.Hour),
|
||||
}
|
||||
if err := repo.Create(ctx, h); err != nil {
|
||||
t.Fatalf("create failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete old records, keep only 3
|
||||
if err := repo.DeleteOldRecords(ctx, 1, 3); err != nil {
|
||||
t.Fatalf("delete old records failed: %v", err)
|
||||
}
|
||||
|
||||
histories, err := repo.GetByUserID(ctx, 1, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("get failed: %v", err)
|
||||
}
|
||||
if len(histories) != 3 {
|
||||
t.Errorf("expected 3 histories after cleanup, got %d", len(histories))
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordHistoryRepository_DeleteOldRecords_NoRecords(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||
t.Fatalf("migrate password_history failed: %v", err)
|
||||
}
|
||||
|
||||
repo := NewPasswordHistoryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Should not error when no records exist
|
||||
if err := repo.DeleteOldRecords(ctx, 999, 3); err != nil {
|
||||
t.Fatalf("delete old records on empty table should not error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPasswordHistoryRepository_KeepsNewestRecords(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
if err := db.AutoMigrate(&domain.PasswordHistory{}); err != nil {
|
||||
t.Fatalf("migrate password_history failed: %v", err)
|
||||
}
|
||||
|
||||
repo := NewPasswordHistoryRepository(db)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 5 records with different timestamps
|
||||
now := time.Now()
|
||||
var createdIDs []int64
|
||||
for i := 0; i < 5; i++ {
|
||||
h := &domain.PasswordHistory{
|
||||
UserID: 1,
|
||||
PasswordHash: "hash",
|
||||
CreatedAt: now.Add(time.Duration(i) * time.Hour),
|
||||
}
|
||||
if err := repo.Create(ctx, h); err != nil {
|
||||
t.Fatalf("create failed: %v", err)
|
||||
}
|
||||
createdIDs = append(createdIDs, h.ID)
|
||||
}
|
||||
|
||||
// Delete old records, keep only 2
|
||||
if err := repo.DeleteOldRecords(ctx, 1, 2); err != nil {
|
||||
t.Fatalf("delete old records failed: %v", err)
|
||||
}
|
||||
|
||||
histories, err := repo.GetByUserID(ctx, 1, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("get failed: %v", err)
|
||||
}
|
||||
if len(histories) != 2 {
|
||||
t.Fatalf("expected 2 histories after cleanup, got %d", len(histories))
|
||||
}
|
||||
|
||||
// The remaining records should be the newest (last 2 created)
|
||||
expectedIDs := map[int64]bool{createdIDs[3]: true, createdIDs[4]: true}
|
||||
for _, h := range histories {
|
||||
if !expectedIDs[h.ID] {
|
||||
t.Errorf("expected remaining IDs to be %v, got %d", expectedIDs, h.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
117
internal/repository/sql_scan_test.go
Normal file
117
internal/repository/sql_scan_test.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// mockQueryer implements sqlQueryer for testing
|
||||
type mockQueryer struct {
|
||||
rows *sql.Rows
|
||||
err error
|
||||
}
|
||||
|
||||
func (m *mockQueryer) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) {
|
||||
return m.rows, m.err
|
||||
}
|
||||
|
||||
func TestScanSingleRow_QueryError(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
mockErr := errors.New("query failed")
|
||||
q := &mockQueryer{err: mockErr}
|
||||
|
||||
var dest int
|
||||
err := scanSingleRow(ctx, q, "SELECT 1", nil, &dest)
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
if !errors.Is(err, mockErr) {
|
||||
t.Errorf("expected query error, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanSingleRow_NoRows(t *testing.T) {
|
||||
// This test requires a real database connection to create sql.Rows.
|
||||
// scanSingleRow is designed to work with any sqlQueryer, but creating
|
||||
// a mock sql.Rows without a real driver is complex.
|
||||
// We test the behavior through integration with the test database.
|
||||
db := openTestDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Use the raw sql.DB from gorm
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("get sql.DB failed: %v", err)
|
||||
}
|
||||
|
||||
var dest int
|
||||
err = scanSingleRow(ctx, sqlDB, "SELECT 1 WHERE 1=0", nil, &dest)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for no rows, got nil")
|
||||
}
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
t.Errorf("expected sql.ErrNoRows, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanSingleRow_Success(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("get sql.DB failed: %v", err)
|
||||
}
|
||||
|
||||
var dest int
|
||||
err = scanSingleRow(ctx, sqlDB, "SELECT 42", nil, &dest)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if dest != 42 {
|
||||
t.Errorf("expected 42, got %d", dest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanSingleRow_MultipleColumns(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("get sql.DB failed: %v", err)
|
||||
}
|
||||
|
||||
var a, b int
|
||||
err = scanSingleRow(ctx, sqlDB, "SELECT 1, 2", nil, &a, &b)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if a != 1 {
|
||||
t.Errorf("expected a=1, got %d", a)
|
||||
}
|
||||
if b != 2 {
|
||||
t.Errorf("expected b=2, got %d", b)
|
||||
}
|
||||
}
|
||||
|
||||
func TestScanSingleRow_StringResult(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
ctx := context.Background()
|
||||
|
||||
sqlDB, err := db.DB()
|
||||
if err != nil {
|
||||
t.Fatalf("get sql.DB failed: %v", err)
|
||||
}
|
||||
|
||||
var dest string
|
||||
err = scanSingleRow(ctx, sqlDB, "SELECT 'hello'", nil, &dest)
|
||||
if err != nil {
|
||||
t.Fatalf("expected no error, got %v", err)
|
||||
}
|
||||
if dest != "hello" {
|
||||
t.Errorf("expected 'hello', got %q", dest)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user