fix: v6 code review P0 auth/IDOR fixes + frontend regression patches
Backend fixes: - auth_handler: P0 认证逻辑修复 - ratelimit: 限速中间件增强 + 新增单元测试 - auth_service: 认证服务逻辑完善 + 新增测试 - server: server 配置增强 + 新增测试 - handler_test: 新增 handler 层集成测试 - auth_bootstrap_test: bootstrap 路径测试 Frontend patches: - LoginPage/RegisterPage: CSRF + 表单交互修复 - BootstrapAdminPage: 引导流程修复 - DevicesPage: 设备管理页修复 - auth/social-accounts/users/webhooks services: 类型修正 - csrf.ts: CSRF token 处理修正 - E2E 脚本: CDP smoke + auth e2e 增强 Docs: - FULL_CODE_REVIEW_REPORT_2026-04-20 - report-v6 执行计划 - REAL_PROJECT_STATUS 更新 - .gitignore: 新增 .gocache-*/config.yaml 排除 验证: go build/vet 0错误, go test 42/42 PASS, 0 FAIL
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"crypto/subtle"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
@@ -15,6 +16,11 @@ import (
|
||||
"github.com/user-management-system/internal/service"
|
||||
)
|
||||
|
||||
const (
|
||||
refreshTokenCookieName = "ums_refresh_token"
|
||||
sessionPresenceCookieName = "ums_session_present"
|
||||
)
|
||||
|
||||
// newBackgroundCtx 创建用于后台 goroutine 的带超时独立 context(与请求 context 无关)
|
||||
func newBackgroundCtx(timeoutSec int) (context.Context, context.CancelFunc) {
|
||||
return context.WithTimeout(context.Background(), time.Duration(timeoutSec)*time.Second)
|
||||
@@ -129,6 +135,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
setSessionCookies(c, h.authService, resp.RefreshToken)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
@@ -150,20 +157,28 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
// @Router /api/v1/auth/login/totp-verify [post]
|
||||
func (h *AuthHandler) VerifyTOTPAfterPasswordLogin(c *gin.Context) {
|
||||
var req struct {
|
||||
UserID int64 `json:"user_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
DeviceID string `json:"device_id"`
|
||||
UserID int64 `json:"user_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
DeviceID string `json:"device_id"`
|
||||
TempToken string `json:"temp_token"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.authService.VerifyTOTPAfterPasswordLogin(c.Request.Context(), req.UserID, req.Code, req.DeviceID)
|
||||
resp, err := h.authService.VerifyTOTPAfterPasswordLogin(
|
||||
c.Request.Context(),
|
||||
req.UserID,
|
||||
req.Code,
|
||||
req.DeviceID,
|
||||
req.TempToken,
|
||||
)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
setSessionCookies(c, h.authService, resp.RefreshToken)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
@@ -197,6 +212,10 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
if req.RefreshToken == "" {
|
||||
req.RefreshToken, _ = c.Cookie(refreshTokenCookieName)
|
||||
}
|
||||
|
||||
username, _ := c.Get("username")
|
||||
usernameStr, _ := username.(string)
|
||||
|
||||
@@ -206,6 +225,8 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
}
|
||||
_ = h.authService.Logout(c.Request.Context(), usernameStr, logoutReq)
|
||||
|
||||
clearSessionCookies(c)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "logged out"})
|
||||
}
|
||||
|
||||
@@ -222,19 +243,27 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
// @Router /api/v1/auth/refresh-token [post]
|
||||
func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
if err := c.ShouldBindJSON(&req); err != nil && !errors.Is(err, io.EOF) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
if req.RefreshToken == "" {
|
||||
req.RefreshToken, _ = c.Cookie(refreshTokenCookieName)
|
||||
}
|
||||
if req.RefreshToken == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "refresh_token is required"})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.authService.RefreshToken(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
setSessionCookies(c, h.authService, resp.RefreshToken)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
@@ -480,6 +509,7 @@ func (h *AuthHandler) LoginByEmailCode(c *gin.Context) {
|
||||
h.authService.BestEffortRegisterDevicePublic(devCtx, userID, loginReq)
|
||||
}()
|
||||
}
|
||||
setSessionCookies(c, h.authService, resp.RefreshToken)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
@@ -544,6 +574,7 @@ func (h *AuthHandler) BootstrapAdmin(c *gin.Context) {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
setSessionCookies(c, h.authService, resp.RefreshToken)
|
||||
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"code": 0,
|
||||
@@ -673,6 +704,46 @@ func getUserIDFromContext(c *gin.Context) (int64, bool) {
|
||||
return id, ok
|
||||
}
|
||||
|
||||
func setSessionCookies(c *gin.Context, authService *service.AuthService, refreshToken string) {
|
||||
if c == nil || strings.TrimSpace(refreshToken) == "" {
|
||||
return
|
||||
}
|
||||
|
||||
maxAge := 0
|
||||
if authService != nil {
|
||||
if ttl := authService.RefreshTokenTTLSeconds(); ttl > 0 {
|
||||
maxAge = int(ttl)
|
||||
}
|
||||
}
|
||||
secure := requestUsesHTTPS(c)
|
||||
|
||||
c.SetSameSite(http.SameSiteLaxMode)
|
||||
c.SetCookie(refreshTokenCookieName, refreshToken, maxAge, "/", "", secure, true)
|
||||
c.SetCookie(sessionPresenceCookieName, "1", maxAge, "/", "", secure, false)
|
||||
}
|
||||
|
||||
func clearSessionCookies(c *gin.Context) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
secure := requestUsesHTTPS(c)
|
||||
|
||||
c.SetSameSite(http.SameSiteLaxMode)
|
||||
c.SetCookie(refreshTokenCookieName, "", -1, "/", "", secure, true)
|
||||
c.SetCookie(sessionPresenceCookieName, "", -1, "/", "", secure, false)
|
||||
}
|
||||
|
||||
func requestUsesHTTPS(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil {
|
||||
return false
|
||||
}
|
||||
if c.Request.TLS != nil {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(c.GetHeader("X-Forwarded-Proto")), "https")
|
||||
}
|
||||
|
||||
// handleError 将 error 转换为对应的 HTTP 响应。
|
||||
// 优先识别 ApplicationError,其次通过关键词推断业务错误类型,兜底返回 500。
|
||||
func handleError(c *gin.Context, err error) {
|
||||
|
||||
@@ -31,6 +31,46 @@ import (
|
||||
|
||||
var handlerDbCounter int64
|
||||
|
||||
func seedHandlerAuthzData(t *testing.T, db *gorm.DB) {
|
||||
t.Helper()
|
||||
|
||||
roleIDs := make(map[string]int64)
|
||||
for _, predefined := range domain.PredefinedRoles {
|
||||
role := predefined
|
||||
if err := db.Create(&role).Error; err != nil {
|
||||
t.Fatalf("seed role %s failed: %v", role.Code, err)
|
||||
}
|
||||
roleIDs[role.Code] = role.ID
|
||||
}
|
||||
|
||||
permissionIDs := make(map[string]int64)
|
||||
for _, predefined := range domain.DefaultPermissions() {
|
||||
permission := predefined
|
||||
if err := db.Create(&permission).Error; err != nil {
|
||||
t.Fatalf("seed permission %s failed: %v", permission.Code, err)
|
||||
}
|
||||
permissionIDs[permission.Code] = permission.ID
|
||||
}
|
||||
|
||||
adminRoleID := roleIDs["admin"]
|
||||
for _, permissionID := range permissionIDs {
|
||||
if err := db.Create(&domain.RolePermission{RoleID: adminRoleID, PermissionID: permissionID}).Error; err != nil {
|
||||
t.Fatalf("assign admin permission %d failed: %v", permissionID, err)
|
||||
}
|
||||
}
|
||||
|
||||
userRoleID := roleIDs["user"]
|
||||
for _, code := range []string{"profile:view", "profile:edit", "log:view_own"} {
|
||||
permissionID, ok := permissionIDs[code]
|
||||
if !ok {
|
||||
t.Fatalf("seeded permissions missing %s", code)
|
||||
}
|
||||
if err := db.Create(&domain.RolePermission{RoleID: userRoleID, PermissionID: permissionID}).Error; err != nil {
|
||||
t.Fatalf("assign user permission %s failed: %v", code, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
@@ -64,6 +104,8 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
|
||||
t.Fatalf("db migration failed: %v", err)
|
||||
}
|
||||
|
||||
seedHandlerAuthzData(t, db)
|
||||
|
||||
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||
HS256Secret: "test-handler-secret-key",
|
||||
AccessTokenExpire: 15 * time.Minute,
|
||||
@@ -176,6 +218,18 @@ func doDelete(url, token string) (*http.Response, string) {
|
||||
return doRequest("DELETE", url, token, nil)
|
||||
}
|
||||
|
||||
func getCookie(resp *http.Response, name string) *http.Cookie {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
for _, cookie := range resp.Cookies() {
|
||||
if cookie.Name == name {
|
||||
return cookie
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getToken(baseURL, username, password string) string {
|
||||
resp, body := doPost(baseURL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": username,
|
||||
@@ -207,6 +261,111 @@ func registerUser(baseURL, username, email, password string) bool {
|
||||
return resp.StatusCode == http.StatusCreated
|
||||
}
|
||||
|
||||
func bootstrapAdmin(baseURL, secret, username, email, password string) string {
|
||||
payload, _ := json.Marshal(map[string]interface{}{
|
||||
"username": username,
|
||||
"email": email,
|
||||
"password": password,
|
||||
})
|
||||
|
||||
req, _ := http.NewRequest(http.MethodPost, baseURL+"/api/v1/auth/bootstrap-admin", bytes.NewReader(payload))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Bootstrap-Secret", secret)
|
||||
|
||||
resp, err := (&http.Client{}).Do(req)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK {
|
||||
return ""
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(body, &result); err != nil {
|
||||
return ""
|
||||
}
|
||||
data, ok := result["data"].(map[string]interface{})
|
||||
if !ok || data["access_token"] == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
token, _ := data["access_token"].(string)
|
||||
return token
|
||||
}
|
||||
|
||||
func setupEnabledTOTPUser(t *testing.T, baseURL, username, email, password string) (int64, string) {
|
||||
t.Helper()
|
||||
|
||||
if ok := registerUser(baseURL, username, email, password); !ok {
|
||||
t.Fatalf("registration failed for %s", username)
|
||||
}
|
||||
|
||||
token := getToken(baseURL, username, password)
|
||||
if token == "" {
|
||||
t.Fatalf("failed to get token for %s", username)
|
||||
}
|
||||
|
||||
userInfoResp, userInfoBody := doGet(baseURL+"/api/v1/auth/userinfo", token)
|
||||
defer userInfoResp.Body.Close()
|
||||
if userInfoResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("userinfo failed: status=%d body=%s", userInfoResp.StatusCode, userInfoBody)
|
||||
}
|
||||
|
||||
var userInfoResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(userInfoBody), &userInfoResult); err != nil {
|
||||
t.Fatalf("failed to parse userinfo response: %v", err)
|
||||
}
|
||||
userData, ok := userInfoResult["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("userinfo response missing data: %s", userInfoBody)
|
||||
}
|
||||
userID, ok := userData["id"].(float64)
|
||||
if !ok {
|
||||
t.Fatalf("userinfo response missing id: %s", userInfoBody)
|
||||
}
|
||||
|
||||
setupResp, setupBody := doGet(baseURL+"/api/v1/auth/2fa/setup", token)
|
||||
defer setupResp.Body.Close()
|
||||
if setupResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("2fa setup failed: status=%d body=%s", setupResp.StatusCode, setupBody)
|
||||
}
|
||||
|
||||
var setupResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(setupBody), &setupResult); err != nil {
|
||||
t.Fatalf("failed to parse 2fa setup response: %v", err)
|
||||
}
|
||||
setupData, ok := setupResult["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("2fa setup response missing data: %s", setupBody)
|
||||
}
|
||||
secret, ok := setupData["secret"].(string)
|
||||
if !ok || secret == "" {
|
||||
t.Fatalf("2fa setup response missing secret: %s", setupBody)
|
||||
}
|
||||
|
||||
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate TOTP code: %v", err)
|
||||
}
|
||||
|
||||
enableResp, enableBody := doPost(baseURL+"/api/v1/auth/2fa/enable", token, map[string]interface{}{
|
||||
"code": code,
|
||||
})
|
||||
defer enableResp.Body.Close()
|
||||
if enableResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("2fa enable failed: status=%d body=%s", enableResp.StatusCode, enableBody)
|
||||
}
|
||||
|
||||
return int64(userID), secret
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Auth Handler Tests
|
||||
// =============================================================================
|
||||
@@ -292,6 +451,38 @@ func TestAuthHandler_Login_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_Login_SetsSessionCookies(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "logincookieuser", "logincookie@example.com", "Password123!")
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "logincookieuser",
|
||||
"password": "Password123!",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
refreshCookie := getCookie(resp, "ums_refresh_token")
|
||||
if refreshCookie == nil || refreshCookie.Value == "" {
|
||||
t.Fatalf("login response missing refresh cookie, cookies=%v", resp.Cookies())
|
||||
}
|
||||
if !refreshCookie.HttpOnly {
|
||||
t.Fatalf("refresh cookie should be HttpOnly, got %+v", refreshCookie)
|
||||
}
|
||||
|
||||
presenceCookie := getCookie(resp, "ums_session_present")
|
||||
if presenceCookie == nil || presenceCookie.Value != "1" {
|
||||
t.Fatalf("login response missing presence cookie, cookies=%v", resp.Cookies())
|
||||
}
|
||||
if presenceCookie.HttpOnly {
|
||||
t.Fatalf("presence cookie should be readable by the frontend, got %+v", presenceCookie)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_Login_WrongPassword(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
@@ -360,6 +551,66 @@ func TestAuthHandler_GetAuthCapabilities(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_Login_WithTOTPEnabled_ReturnsChallengeToken(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
_, _ = setupEnabledTOTPUser(t, server.URL, "totplogin", "totplogin@example.com", "Password123!")
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "totplogin",
|
||||
"password": "Password123!",
|
||||
"device_id": "device-login-1",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(body), &result); err != nil {
|
||||
t.Fatalf("failed to parse login response: %v", err)
|
||||
}
|
||||
|
||||
data, ok := result["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("expected login response data, got %s", body)
|
||||
}
|
||||
|
||||
if data["requires_totp"] != true {
|
||||
t.Fatalf("expected requires_totp=true, got %+v", data)
|
||||
}
|
||||
|
||||
tempToken, ok := data["temp_token"].(string)
|
||||
if !ok || tempToken == "" {
|
||||
t.Fatalf("expected temp_token in TOTP challenge response, got %+v", data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_VerifyTOTPAfterPasswordLogin_RequiresTempToken(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
userID, secret := setupEnabledTOTPUser(t, server.URL, "totpreverify", "totpreverify@example.com", "Password123!")
|
||||
|
||||
code, err := auth.NewTOTPManager().GenerateCurrentCode(secret)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate TOTP code: %v", err)
|
||||
}
|
||||
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||
"user_id": userID,
|
||||
"code": code,
|
||||
"device_id": "device-login-1",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Fatalf("expected status %d when temp_token is missing, got %d, body: %s", http.StatusUnauthorized, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// User Handler Tests
|
||||
// =============================================================================
|
||||
@@ -451,6 +702,26 @@ func TestUserHandler_UpdateUser_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_UpdateUser_AdminCanUpdateAnotherUser(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
|
||||
token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "updateadmin", "updateadmin@test.com", "AdminPass123!")
|
||||
registerUser(server.URL, "targetuser", "targetuser@test.com", "UserPass123!")
|
||||
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin should return access token")
|
||||
}
|
||||
|
||||
resp, body := doPut(server.URL+"/api/v1/users/2", token, map[string]string{"nickname": "Updated By Admin"})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_DeleteUser_NonAdmin_Forbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
@@ -515,6 +786,26 @@ func TestUserHandler_GetUserRoles_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_GetUserRoles_AdminCanViewAnotherUser(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
t.Setenv("BOOTSTRAP_SECRET", "handler-bootstrap-secret")
|
||||
token := bootstrapAdmin(server.URL, "handler-bootstrap-secret", "rolesadmin2", "rolesadmin2@test.com", "AdminPass123!")
|
||||
registerUser(server.URL, "roles-target", "roles-target@test.com", "UserPass123!")
|
||||
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin should return access token")
|
||||
}
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/users/2/roles", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_AssignRoles_RequiresAdmin(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
@@ -1253,6 +1544,187 @@ func TestAuthHandler_RefreshToken_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_RefreshToken_AcceptsRefreshCookie(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "refreshcookieuser", "refreshcookie@example.com", "Password123!")
|
||||
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "refreshcookieuser",
|
||||
"password": "Password123!",
|
||||
})
|
||||
defer loginResp.Body.Close()
|
||||
|
||||
if loginResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, loginBody)
|
||||
}
|
||||
|
||||
refreshCookie := getCookie(loginResp, "ums_refresh_token")
|
||||
if refreshCookie == nil || refreshCookie.Value == "" {
|
||||
t.Fatalf("login response missing refresh cookie, cookies=%v", loginResp.Cookies())
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", server.URL+"/api/v1/auth/refresh", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create refresh request failed: %v", err)
|
||||
}
|
||||
req.AddCookie(refreshCookie)
|
||||
req.AddCookie(&http.Cookie{Name: "ums_session_present", Value: "1"})
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("refresh request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read refresh response failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
rotatedCookie := getCookie(resp, "ums_refresh_token")
|
||||
if rotatedCookie == nil || rotatedCookie.Value == "" {
|
||||
t.Fatalf("refresh response missing rotated refresh cookie, cookies=%v", resp.Cookies())
|
||||
}
|
||||
if rotatedCookie.Value == refreshCookie.Value {
|
||||
t.Fatalf("refresh should rotate cookie value, old=%q new=%q", refreshCookie.Value, rotatedCookie.Value)
|
||||
}
|
||||
|
||||
presenceCookie := getCookie(resp, "ums_session_present")
|
||||
if presenceCookie == nil || presenceCookie.Value != "1" {
|
||||
t.Fatalf("refresh response missing presence cookie, cookies=%v", resp.Cookies())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_RefreshToken_AllowsImmediateRetryWithPreviousCookie(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "refreshretryuser", "refreshretry@example.com", "Password123!")
|
||||
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "refreshretryuser",
|
||||
"password": "Password123!",
|
||||
})
|
||||
defer loginResp.Body.Close()
|
||||
|
||||
if loginResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, loginBody)
|
||||
}
|
||||
|
||||
refreshCookie := getCookie(loginResp, "ums_refresh_token")
|
||||
if refreshCookie == nil || refreshCookie.Value == "" {
|
||||
t.Fatalf("login response missing refresh cookie, cookies=%v", loginResp.Cookies())
|
||||
}
|
||||
|
||||
newRefreshRequest := func(cookie *http.Cookie) *http.Response {
|
||||
req, err := http.NewRequest(http.MethodPost, server.URL+"/api/v1/auth/refresh", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create refresh request failed: %v", err)
|
||||
}
|
||||
req.AddCookie(cookie)
|
||||
req.AddCookie(&http.Cookie{Name: "ums_session_present", Value: "1"})
|
||||
|
||||
resp, err := (&http.Client{}).Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("refresh request failed: %v", err)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
firstResp := newRefreshRequest(refreshCookie)
|
||||
defer firstResp.Body.Close()
|
||||
firstBody, err := io.ReadAll(firstResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read first refresh response failed: %v", err)
|
||||
}
|
||||
if firstResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected first refresh status %d, got %d, body: %s", http.StatusOK, firstResp.StatusCode, string(firstBody))
|
||||
}
|
||||
|
||||
retryResp := newRefreshRequest(refreshCookie)
|
||||
defer retryResp.Body.Close()
|
||||
retryBody, err := io.ReadAll(retryResp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read retry refresh response failed: %v", err)
|
||||
}
|
||||
if retryResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected retry refresh status %d, got %d, body: %s", http.StatusOK, retryResp.StatusCode, string(retryBody))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_Logout_ClearsSessionCookies(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "logoutcookieuser", "logoutcookie@example.com", "Password123!")
|
||||
loginResp, loginBody := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "logoutcookieuser",
|
||||
"password": "Password123!",
|
||||
})
|
||||
defer loginResp.Body.Close()
|
||||
|
||||
if loginResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, loginBody)
|
||||
}
|
||||
|
||||
var loginResult map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(loginBody), &loginResult); err != nil {
|
||||
t.Fatalf("parse login response failed: %v", err)
|
||||
}
|
||||
loginData, ok := loginResult["data"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("login response missing data: %s", loginBody)
|
||||
}
|
||||
accessToken, ok := loginData["access_token"].(string)
|
||||
if !ok || accessToken == "" {
|
||||
t.Fatalf("login response missing access token: %s", loginBody)
|
||||
}
|
||||
|
||||
refreshCookie := getCookie(loginResp, "ums_refresh_token")
|
||||
if refreshCookie == nil || refreshCookie.Value == "" {
|
||||
t.Fatalf("login response missing refresh cookie, cookies=%v", loginResp.Cookies())
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("POST", server.URL+"/api/v1/auth/logout", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("create logout request failed: %v", err)
|
||||
}
|
||||
req.Header.Set("Authorization", "Bearer "+accessToken)
|
||||
req.AddCookie(refreshCookie)
|
||||
req.AddCookie(&http.Cookie{Name: "ums_session_present", Value: "1"})
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("logout request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("read logout response failed: %v", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
clearedRefreshCookie := getCookie(resp, "ums_refresh_token")
|
||||
if clearedRefreshCookie == nil || clearedRefreshCookie.Value != "" {
|
||||
t.Fatalf("logout response should clear refresh cookie, cookies=%v", resp.Cookies())
|
||||
}
|
||||
|
||||
clearedPresenceCookie := getCookie(resp, "ums_session_present")
|
||||
if clearedPresenceCookie == nil || clearedPresenceCookie.Value != "" {
|
||||
t.Fatalf("logout response should clear presence cookie, cookies=%v", resp.Cookies())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_RefreshToken_InvalidToken(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
@@ -116,6 +116,7 @@ func (h *SMSHandler) LoginByCode(c *gin.Context) {
|
||||
h.authService.BestEffortRegisterDevicePublic(devCtx, userID, loginReq)
|
||||
}()
|
||||
}
|
||||
setSessionCookies(c, h.authService, resp.RefreshToken)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/service"
|
||||
@@ -187,15 +188,7 @@ func (h *UserHandler) UpdateUser(c *gin.Context) {
|
||||
|
||||
// Authorization: only self or admin can update user profile
|
||||
currentUserID := c.GetInt64("user_id")
|
||||
isAdmin := false
|
||||
if roles, ok := c.Get("user_roles"); ok {
|
||||
for _, role := range roles.([]*domain.Role) {
|
||||
if role.Code == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
isAdmin := middleware.IsAdmin(c)
|
||||
if currentUserID != id && !isAdmin {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return
|
||||
@@ -370,15 +363,7 @@ func (h *UserHandler) GetUserRoles(c *gin.Context) {
|
||||
|
||||
// Authorization: only self or admin can view user roles
|
||||
currentUserID := c.GetInt64("user_id")
|
||||
isAdmin := false
|
||||
if roles, ok := c.Get("user_roles"); ok {
|
||||
for _, role := range roles.([]*domain.Role) {
|
||||
if role.Code == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
isAdmin := middleware.IsAdmin(c)
|
||||
if currentUserID != id && !isAdmin {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return
|
||||
|
||||
103
internal/api/middleware/auth_bootstrap_test.go
Normal file
103
internal/api/middleware/auth_bootstrap_test.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"github.com/user-management-system/internal/service"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func TestAuthMiddleware_AcceptsBootstrapAdminTokenImmediately(t *testing.T) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: "file:middleware_bootstrap_test?mode=memory&cache=shared",
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("open sqlite failed: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&domain.User{}, &domain.Role{}, &domain.UserRole{}); err != nil {
|
||||
t.Fatalf("migrate failed: %v", err)
|
||||
}
|
||||
|
||||
if err := db.Create(&domain.Role{
|
||||
Name: "管理员",
|
||||
Code: "admin",
|
||||
IsSystem: true,
|
||||
Status: domain.RoleStatusEnabled,
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("seed admin role failed: %v", err)
|
||||
}
|
||||
|
||||
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||
HS256Secret: "test-bootstrap-token-secret-at-least-32-chars",
|
||||
AccessTokenExpire: 15 * time.Minute,
|
||||
RefreshTokenExpire: 7 * 24 * time.Hour,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
l1Cache := cache.NewL1Cache()
|
||||
l2Cache := cache.NewRedisCache(false)
|
||||
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
||||
|
||||
userRepo := repository.NewUserRepository(db)
|
||||
roleRepo := repository.NewRoleRepository(db)
|
||||
userRoleRepo := repository.NewUserRoleRepository(db)
|
||||
|
||||
authService := service.NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
|
||||
authService.SetRoleRepositories(userRoleRepo, roleRepo)
|
||||
|
||||
loginResponse, err := authService.BootstrapAdmin(context.Background(), &service.BootstrapAdminRequest{
|
||||
Username: "bootstrap_admin",
|
||||
Email: "bootstrap_admin@example.com",
|
||||
Password: "AdminPass123!",
|
||||
}, "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("bootstrap admin failed: %v", err)
|
||||
}
|
||||
if loginResponse == nil || loginResponse.AccessToken == "" {
|
||||
t.Fatalf("expected bootstrap access token, got %+v", loginResponse)
|
||||
}
|
||||
|
||||
if _, err := jwtManager.ValidateAccessToken(loginResponse.AccessToken); err != nil {
|
||||
t.Fatalf("bootstrap access token should validate immediately: %v", err)
|
||||
}
|
||||
|
||||
authMiddleware := NewAuthMiddleware(jwtManager, userRepo, userRoleRepo, l1Cache)
|
||||
authMiddleware.SetCacheManager(cacheManager)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
ctx, engine := gin.CreateTestContext(recorder)
|
||||
ctx.Request = httptest.NewRequest(http.MethodGet, "/protected", nil)
|
||||
ctx.Request.Header.Set("Authorization", "Bearer "+loginResponse.AccessToken)
|
||||
|
||||
engine.Use(authMiddleware.Required())
|
||||
engine.GET("/protected", func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0})
|
||||
})
|
||||
|
||||
engine.ServeHTTP(recorder, ctx.Request)
|
||||
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("expected bootstrap token to pass auth middleware immediately, got %d body: %s", recorder.Code, recorder.Body.String())
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,21 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/user-management-system/internal/config"
|
||||
)
|
||||
|
||||
// RateLimitMiddleware 限流中间件
|
||||
// RateLimitMiddleware provides simple in-memory sliding-window rate limiting.
|
||||
type RateLimitMiddleware struct {
|
||||
cfg config.RateLimitConfig
|
||||
limiters map[string]*SlidingWindowLimiter
|
||||
@@ -16,7 +23,7 @@ type RateLimitMiddleware struct {
|
||||
cleanupInt time.Duration
|
||||
}
|
||||
|
||||
// SlidingWindowLimiter 滑动窗口限流器
|
||||
// SlidingWindowLimiter enforces a fixed-capacity sliding window.
|
||||
type SlidingWindowLimiter struct {
|
||||
mu sync.Mutex
|
||||
window time.Duration
|
||||
@@ -24,7 +31,6 @@ type SlidingWindowLimiter struct {
|
||||
requests []int64
|
||||
}
|
||||
|
||||
// NewSlidingWindowLimiter 创建滑动窗口限流器
|
||||
func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindowLimiter {
|
||||
return &SlidingWindowLimiter{
|
||||
window: window,
|
||||
@@ -33,7 +39,6 @@ func NewSlidingWindowLimiter(window time.Duration, capacity int64) *SlidingWindo
|
||||
}
|
||||
}
|
||||
|
||||
// Allow 检查是否允许请求
|
||||
func (l *SlidingWindowLimiter) Allow() bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
@@ -41,16 +46,14 @@ func (l *SlidingWindowLimiter) Allow() bool {
|
||||
now := time.Now().UnixMilli()
|
||||
cutoff := now - l.window.Milliseconds()
|
||||
|
||||
// 清理过期请求
|
||||
var validRequests []int64
|
||||
for _, t := range l.requests {
|
||||
if t > cutoff {
|
||||
validRequests = append(validRequests, t)
|
||||
validRequests := make([]int64, 0, len(l.requests))
|
||||
for _, ts := range l.requests {
|
||||
if ts > cutoff {
|
||||
validRequests = append(validRequests, ts)
|
||||
}
|
||||
}
|
||||
l.requests = validRequests
|
||||
|
||||
// 检查容量
|
||||
if int64(len(l.requests)) >= l.capacity {
|
||||
return false
|
||||
}
|
||||
@@ -59,7 +62,6 @@ func (l *SlidingWindowLimiter) Allow() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// NewRateLimitMiddleware 创建限流中间件
|
||||
func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
|
||||
return &RateLimitMiddleware{
|
||||
cfg: cfg,
|
||||
@@ -68,30 +70,28 @@ func NewRateLimitMiddleware(cfg config.RateLimitConfig) *RateLimitMiddleware {
|
||||
}
|
||||
}
|
||||
|
||||
// Register 返回注册接口的限流中间件
|
||||
func (m *RateLimitMiddleware) Register() gin.HandlerFunc {
|
||||
return m.limitForKey("register", 60, 10)
|
||||
}
|
||||
|
||||
// Login 返回登录接口的限流中间件
|
||||
func (m *RateLimitMiddleware) Login() gin.HandlerFunc {
|
||||
return m.limitForKey("login", 60, 5)
|
||||
}
|
||||
|
||||
// API 返回 API 接口的限流中间件
|
||||
func (m *RateLimitMiddleware) API() gin.HandlerFunc {
|
||||
return m.limitForKey("api", 60, 100)
|
||||
}
|
||||
|
||||
// Refresh 返回刷新令牌的限流中间件
|
||||
func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
|
||||
return m.limitForKey("refresh", 60, 10)
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc {
|
||||
limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity)
|
||||
func (m *RateLimitMiddleware) limitForKey(bucket string, windowSeconds int, capacity int64) gin.HandlerFunc {
|
||||
window := time.Duration(windowSeconds) * time.Second
|
||||
|
||||
return func(c *gin.Context) {
|
||||
limiterKey := m.resolveLimiterKey(c, bucket)
|
||||
limiter := m.getOrCreateLimiter(limiterKey, window, capacity)
|
||||
if !limiter.Allow() {
|
||||
c.JSON(429, gin.H{
|
||||
"code": 429,
|
||||
@@ -104,6 +104,81 @@ func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacit
|
||||
}
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) resolveLimiterKey(c *gin.Context, bucket string) string {
|
||||
if bucket == "refresh" {
|
||||
if refreshToken := extractRefreshToken(c); refreshToken != "" {
|
||||
return fmt.Sprintf("%s:token:%s", bucket, fingerprintValue(refreshToken))
|
||||
}
|
||||
}
|
||||
|
||||
identity := "anonymous"
|
||||
if c != nil {
|
||||
if userID, ok := c.Get("user_id"); ok {
|
||||
identity = fmt.Sprintf("user:%v", userID)
|
||||
} else if ip := c.ClientIP(); ip != "" {
|
||||
identity = "ip:" + ip
|
||||
}
|
||||
}
|
||||
|
||||
if bucket == "api" {
|
||||
method := ""
|
||||
route := ""
|
||||
if c != nil {
|
||||
if c.Request != nil {
|
||||
method = c.Request.Method
|
||||
if c.Request.URL != nil {
|
||||
route = c.Request.URL.Path
|
||||
}
|
||||
}
|
||||
if fullPath := c.FullPath(); fullPath != "" {
|
||||
route = fullPath
|
||||
}
|
||||
}
|
||||
return fmt.Sprintf("%s:%s:%s:%s", bucket, method, route, identity)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s:%s", bucket, identity)
|
||||
}
|
||||
|
||||
func extractRefreshToken(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
if refreshToken, err := c.Cookie("ums_refresh_token"); err == nil && refreshToken != "" {
|
||||
return refreshToken
|
||||
}
|
||||
|
||||
if c.Request == nil || c.Request.Body == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(body))
|
||||
|
||||
if len(bytes.TrimSpace(body)) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var payload struct {
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return payload.RefreshToken
|
||||
}
|
||||
|
||||
func fingerprintValue(value string) string {
|
||||
sum := sha256.Sum256([]byte(value))
|
||||
return hex.EncodeToString(sum[:12])
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duration, capacity int64) *SlidingWindowLimiter {
|
||||
m.mu.RLock()
|
||||
limiter, exists := m.limiters[key]
|
||||
@@ -116,7 +191,6 @@ func (m *RateLimitMiddleware) getOrCreateLimiter(key string, window time.Duratio
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// 双重检查
|
||||
if limiter, exists = m.limiters[key]; exists {
|
||||
return limiter
|
||||
}
|
||||
|
||||
140
internal/api/middleware/ratelimit_test.go
Normal file
140
internal/api/middleware/ratelimit_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/user-management-system/internal/config"
|
||||
)
|
||||
|
||||
func performRateLimitedRequest(router *gin.Engine, path string, userID int64) *httptest.ResponseRecorder {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
req.Header.Set("X-Test-User-ID", strconv.FormatInt(userID, 10))
|
||||
router.ServeHTTP(recorder, req)
|
||||
return recorder
|
||||
}
|
||||
|
||||
func performRefreshRateLimitedRequestWithCookie(router *gin.Engine, refreshToken string) *httptest.ResponseRecorder {
|
||||
recorder := httptest.NewRecorder()
|
||||
req := httptest.NewRequest(http.MethodPost, "/auth/refresh", nil)
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
if refreshToken != "" {
|
||||
req.AddCookie(&http.Cookie{Name: "ums_refresh_token", Value: refreshToken})
|
||||
}
|
||||
router.ServeHTTP(recorder, req)
|
||||
return recorder
|
||||
}
|
||||
|
||||
func performRefreshRateLimitedRequestWithBody(router *gin.Engine, refreshToken string) *httptest.ResponseRecorder {
|
||||
recorder := httptest.NewRecorder()
|
||||
body := bytes.NewBufferString(`{"refresh_token":"` + refreshToken + `"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/auth/refresh", body)
|
||||
req.RemoteAddr = "127.0.0.1:12345"
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
router.ServeHTTP(recorder, req)
|
||||
return recorder
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_API_ScopesBudgetByRouteForAuthenticatedUser(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
|
||||
router := gin.New()
|
||||
router.Use(func(c *gin.Context) {
|
||||
rawUserID := c.GetHeader("X-Test-User-ID")
|
||||
if rawUserID != "" {
|
||||
userID, err := strconv.ParseInt(rawUserID, 10, 64)
|
||||
if err == nil {
|
||||
c.Set("user_id", userID)
|
||||
}
|
||||
}
|
||||
c.Next()
|
||||
})
|
||||
|
||||
protected := router.Group("")
|
||||
protected.Use(rateLimitMiddleware.API())
|
||||
protected.GET("/users", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
protected.GET("/roles", func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
recorder := performRateLimitedRequest(router, "/users", 1)
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("request %d to /users returned %d, want %d", i+1, recorder.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
sameRouteOverflow := performRateLimitedRequest(router, "/users", 1)
|
||||
if sameRouteOverflow.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("overflow request to /users returned %d, want %d", sameRouteOverflow.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
differentRoute := performRateLimitedRequest(router, "/roles", 1)
|
||||
if differentRoute.Code != http.StatusOK {
|
||||
t.Fatalf("request to /roles after exhausting /users budget returned %d, want %d", differentRoute.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshCookie(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
|
||||
router := gin.New()
|
||||
router.POST("/auth/refresh", rateLimitMiddleware.Refresh(), func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
recorder := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-a")
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("request %d for refresh-token-a returned %d, want %d", i+1, recorder.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
sameTokenOverflow := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-a")
|
||||
if sameTokenOverflow.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("overflow request for refresh-token-a returned %d, want %d", sameTokenOverflow.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
differentToken := performRefreshRateLimitedRequestWithCookie(router, "refresh-token-b")
|
||||
if differentToken.Code != http.StatusOK {
|
||||
t.Fatalf("request for refresh-token-b after exhausting refresh-token-a budget returned %d, want %d", differentToken.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitMiddleware_Refresh_ScopesBudgetByRefreshTokenBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
rateLimitMiddleware := NewRateLimitMiddleware(config.RateLimitConfig{})
|
||||
router := gin.New()
|
||||
router.POST("/auth/refresh", rateLimitMiddleware.Refresh(), func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
})
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
recorder := performRefreshRateLimitedRequestWithBody(router, "refresh-token-a")
|
||||
if recorder.Code != http.StatusOK {
|
||||
t.Fatalf("request %d for refresh-token-a body returned %d, want %d", i+1, recorder.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
|
||||
sameTokenOverflow := performRefreshRateLimitedRequestWithBody(router, "refresh-token-a")
|
||||
if sameTokenOverflow.Code != http.StatusTooManyRequests {
|
||||
t.Fatalf("overflow request for refresh-token-a body returned %d, want %d", sameTokenOverflow.Code, http.StatusTooManyRequests)
|
||||
}
|
||||
|
||||
differentToken := performRefreshRateLimitedRequestWithBody(router, "refresh-token-b")
|
||||
if differentToken.Code != http.StatusOK {
|
||||
t.Fatalf("request for refresh-token-b body after exhausting refresh-token-a budget returned %d, want %d", differentToken.Code, http.StatusOK)
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -43,10 +45,12 @@ func Serve(cfg *config.Config) error {
|
||||
// P1-3:Argon2id 启动时自适应校准
|
||||
auth.CalibrateArgon2id(500 * time.Millisecond)
|
||||
|
||||
accessTokenExpire := resolveJWTAccessTokenExpire(cfg)
|
||||
|
||||
// 初始化 JWT 管理器
|
||||
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||
HS256Secret: cfg.JWT.Secret,
|
||||
AccessTokenExpire: time.Duration(cfg.JWT.AccessTokenExpireMinutes) * time.Minute,
|
||||
AccessTokenExpire: accessTokenExpire,
|
||||
RefreshTokenExpire: time.Duration(cfg.JWT.RefreshTokenExpireDays) * 24 * time.Hour,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -125,6 +129,9 @@ func Serve(cfg *config.Config) error {
|
||||
totpService := service.NewTOTPService(userRepo)
|
||||
|
||||
passwordResetConfig := service.DefaultPasswordResetConfig()
|
||||
if err := configureAuthEmailServices(cfg, cacheManager, authService, passwordResetConfig); err != nil {
|
||||
return fmt.Errorf("configure auth email services failed: %w", err)
|
||||
}
|
||||
passwordResetService := service.NewPasswordResetService(userRepo, cacheManager, passwordResetConfig).
|
||||
WithPasswordHistoryRepo(passwordHistoryRepo)
|
||||
|
||||
@@ -259,3 +266,100 @@ func resolveGinMode(mode string) string {
|
||||
return gin.ReleaseMode
|
||||
}
|
||||
}
|
||||
|
||||
func configureAuthEmailServices(
|
||||
cfg *config.Config,
|
||||
cacheManager *cache.CacheManager,
|
||||
authService *service.AuthService,
|
||||
passwordResetConfig *service.PasswordResetConfig,
|
||||
) error {
|
||||
smtpConfig, enabled, err := resolveSMTPEmailConfigFromEnv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !enabled || cacheManager == nil || authService == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
siteURL := resolveAuthEmailSiteURL(cfg)
|
||||
siteName := resolveAuthEmailSiteName(cfg)
|
||||
provider := service.NewSMTPEmailProvider(smtpConfig)
|
||||
|
||||
authService.SetEmailActivationService(
|
||||
service.NewEmailActivationService(provider, cacheManager, siteURL, siteName),
|
||||
)
|
||||
|
||||
emailCodeConfig := service.DefaultEmailCodeConfig()
|
||||
emailCodeConfig.SiteURL = siteURL
|
||||
emailCodeConfig.SiteName = siteName
|
||||
authService.SetEmailCodeService(service.NewEmailCodeService(provider, cacheManager, emailCodeConfig))
|
||||
|
||||
if passwordResetConfig != nil {
|
||||
passwordResetConfig.SMTPHost = smtpConfig.Host
|
||||
passwordResetConfig.SMTPPort = smtpConfig.Port
|
||||
passwordResetConfig.SMTPUser = smtpConfig.Username
|
||||
passwordResetConfig.SMTPPass = smtpConfig.Password
|
||||
passwordResetConfig.FromEmail = smtpConfig.FromEmail
|
||||
passwordResetConfig.SiteURL = siteURL
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveSMTPEmailConfigFromEnv() (service.SMTPEmailConfig, bool, error) {
|
||||
host := strings.TrimSpace(os.Getenv("EMAIL_HOST"))
|
||||
if host == "" {
|
||||
return service.SMTPEmailConfig{}, false, nil
|
||||
}
|
||||
|
||||
port := 587
|
||||
if rawPort := strings.TrimSpace(os.Getenv("EMAIL_PORT")); rawPort != "" {
|
||||
parsedPort, err := strconv.Atoi(rawPort)
|
||||
if err != nil || parsedPort <= 0 {
|
||||
return service.SMTPEmailConfig{}, false, fmt.Errorf("invalid EMAIL_PORT %q", rawPort)
|
||||
}
|
||||
port = parsedPort
|
||||
}
|
||||
|
||||
fromEmail := strings.TrimSpace(os.Getenv("EMAIL_FROM_EMAIL"))
|
||||
if fromEmail == "" {
|
||||
fromEmail = service.DefaultPasswordResetConfig().FromEmail
|
||||
}
|
||||
|
||||
return service.SMTPEmailConfig{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Username: strings.TrimSpace(os.Getenv("EMAIL_USER")),
|
||||
Password: os.Getenv("EMAIL_PASS"),
|
||||
FromEmail: fromEmail,
|
||||
FromName: strings.TrimSpace(os.Getenv("EMAIL_FROM_NAME")),
|
||||
}, true, nil
|
||||
}
|
||||
|
||||
func resolveAuthEmailSiteURL(cfg *config.Config) string {
|
||||
if cfg != nil {
|
||||
if siteURL := strings.TrimSpace(cfg.Server.FrontendURL); siteURL != "" {
|
||||
return siteURL
|
||||
}
|
||||
}
|
||||
return service.DefaultEmailCodeConfig().SiteURL
|
||||
}
|
||||
|
||||
func resolveAuthEmailSiteName(cfg *config.Config) string {
|
||||
if cfg != nil {
|
||||
if siteName := strings.TrimSpace(cfg.Log.ServiceName); siteName != "" {
|
||||
return siteName
|
||||
}
|
||||
}
|
||||
return service.DefaultEmailCodeConfig().SiteName
|
||||
}
|
||||
|
||||
func resolveJWTAccessTokenExpire(cfg *config.Config) time.Duration {
|
||||
if cfg == nil {
|
||||
return 0
|
||||
}
|
||||
if cfg.JWT.AccessTokenExpireMinutes > 0 {
|
||||
return time.Duration(cfg.JWT.AccessTokenExpireMinutes) * time.Minute
|
||||
}
|
||||
return time.Duration(cfg.JWT.ExpireHour) * time.Hour
|
||||
}
|
||||
|
||||
73
internal/server/server_test.go
Normal file
73
internal/server/server_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/config"
|
||||
"github.com/user-management-system/internal/service"
|
||||
)
|
||||
|
||||
func TestResolveJWTAccessTokenExpire_UsesExpireHourFallback(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.JWT.ExpireHour = 24
|
||||
cfg.JWT.AccessTokenExpireMinutes = 0
|
||||
|
||||
expire := resolveJWTAccessTokenExpire(cfg)
|
||||
|
||||
if expire != 24*time.Hour {
|
||||
t.Fatalf("resolveJWTAccessTokenExpire() = %v, want %v", expire, 24*time.Hour)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveJWTAccessTokenExpire_PrefersMinuteOverride(t *testing.T) {
|
||||
cfg := &config.Config{}
|
||||
cfg.JWT.ExpireHour = 24
|
||||
cfg.JWT.AccessTokenExpireMinutes = 90
|
||||
|
||||
expire := resolveJWTAccessTokenExpire(cfg)
|
||||
|
||||
if expire != 90*time.Minute {
|
||||
t.Fatalf("resolveJWTAccessTokenExpire() = %v, want %v", expire, 90*time.Minute)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigureAuthEmailServices_UsesSMTPEnvironment(t *testing.T) {
|
||||
t.Setenv("EMAIL_HOST", "127.0.0.1")
|
||||
t.Setenv("EMAIL_PORT", "2525")
|
||||
t.Setenv("EMAIL_FROM_EMAIL", "noreply@test.local")
|
||||
t.Setenv("EMAIL_FROM_NAME", "UMS E2E")
|
||||
t.Setenv("EMAIL_USER", "smtp-user")
|
||||
t.Setenv("EMAIL_PASS", "smtp-pass")
|
||||
|
||||
cfg := &config.Config{}
|
||||
cfg.Server.FrontendURL = "http://127.0.0.1:3000"
|
||||
cfg.Log.ServiceName = "UMS E2E"
|
||||
|
||||
cacheManager := cache.NewCacheManager(cache.NewL1Cache(), cache.NewRedisCache(false))
|
||||
authService := service.NewAuthService(nil, nil, nil, cacheManager, 8, 5, time.Minute)
|
||||
passwordResetConfig := service.DefaultPasswordResetConfig()
|
||||
|
||||
if err := configureAuthEmailServices(cfg, cacheManager, authService, passwordResetConfig); err != nil {
|
||||
t.Fatalf("configureAuthEmailServices() error = %v", err)
|
||||
}
|
||||
if !authService.SupportsEmailActivation() {
|
||||
t.Fatal("SupportsEmailActivation() = false, want true")
|
||||
}
|
||||
if !authService.HasEmailCodeService() {
|
||||
t.Fatal("HasEmailCodeService() = false, want true")
|
||||
}
|
||||
if passwordResetConfig.SMTPHost != "127.0.0.1" {
|
||||
t.Fatalf("password reset SMTP host = %q, want %q", passwordResetConfig.SMTPHost, "127.0.0.1")
|
||||
}
|
||||
if passwordResetConfig.SMTPPort != 2525 {
|
||||
t.Fatalf("password reset SMTP port = %d, want %d", passwordResetConfig.SMTPPort, 2525)
|
||||
}
|
||||
if passwordResetConfig.FromEmail != "noreply@test.local" {
|
||||
t.Fatalf("password reset FromEmail = %q, want %q", passwordResetConfig.FromEmail, "noreply@test.local")
|
||||
}
|
||||
if passwordResetConfig.SiteURL != "http://127.0.0.1:3000" {
|
||||
t.Fatalf("password reset SiteURL = %q, want %q", passwordResetConfig.SiteURL, "http://127.0.0.1:3000")
|
||||
}
|
||||
}
|
||||
@@ -2,10 +2,13 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
cryptorand "crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
@@ -19,11 +22,14 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
userInfoCachePrefix = "auth_user_info:"
|
||||
tokenBlacklistPrefix = "auth_token_blacklist:"
|
||||
defaultUserCacheTTL = 15 * time.Minute
|
||||
defaultBlacklistTTL = time.Hour
|
||||
defaultPasswordMinLen = 8
|
||||
userInfoCachePrefix = "auth_user_info:"
|
||||
tokenBlacklistPrefix = "auth_token_blacklist:"
|
||||
totpChallengePrefix = "auth_totp_challenge:"
|
||||
defaultUserCacheTTL = 15 * time.Minute
|
||||
defaultBlacklistTTL = time.Hour
|
||||
defaultTOTPChallengeTTL = 5 * time.Minute
|
||||
defaultPasswordMinLen = 8
|
||||
refreshTokenRetryGrace = 10 * time.Second
|
||||
)
|
||||
|
||||
type userRepositoryInterface interface {
|
||||
@@ -122,13 +128,18 @@ type LoginResponse struct {
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
User *UserInfo `json:"user,omitempty"`
|
||||
// RequiresTOTP 指示登录需要额外的TOTP验证(当设备未信任时)
|
||||
RequiresTOTP bool `json:"requires_totp,omitempty"`
|
||||
RequiresTOTP bool `json:"requires_totp,omitempty"`
|
||||
// TempToken 临时令牌,用于TOTP验证阶段(短生命周期,不可用于常规API)
|
||||
TempToken string `json:"temp_token,omitempty"`
|
||||
// UserID 当RequiresTOTP为true时返回,用于后续TOTP验证
|
||||
UserID int64 `json:"user_id,omitempty"`
|
||||
}
|
||||
|
||||
type totpLoginChallenge struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
}
|
||||
|
||||
type LogoutRequest struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
@@ -432,6 +443,38 @@ func (s *AuthService) blacklistTokenClaims(ctx context.Context, token string, va
|
||||
return s.cache.Set(ctx, tokenBlacklistPrefix+claims.JTI, true, ttl, ttl)
|
||||
}
|
||||
|
||||
func (s *AuthService) getTokenBlacklistValue(ctx context.Context, jti string) (interface{}, bool) {
|
||||
if s == nil || s.cache == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
jti = strings.TrimSpace(jti)
|
||||
if jti == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return s.cache.Get(ctx, tokenBlacklistPrefix+jti)
|
||||
}
|
||||
|
||||
func tokenBlacklistRevokedAt(value interface{}) (time.Time, bool) {
|
||||
switch v := value.(type) {
|
||||
case int64:
|
||||
return time.Unix(0, v), true
|
||||
case int:
|
||||
return time.Unix(0, int64(v)), true
|
||||
case float64:
|
||||
return time.Unix(0, int64(v)), true
|
||||
case string:
|
||||
timestamp, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return time.Unix(0, timestamp), true
|
||||
default:
|
||||
return time.Time{}, false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) recordLoginAnomaly(ctx context.Context, userID *int64, ip, location, deviceFingerprint string, success bool) {
|
||||
if s == nil || s.anomalyDetector == nil || userID == nil {
|
||||
return
|
||||
@@ -601,6 +644,93 @@ func userInfoFromCacheValue(value interface{}) (*UserInfo, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func generateTemporaryLoginToken() (string, error) {
|
||||
payload := make([]byte, 32)
|
||||
if _, err := cryptorand.Read(payload); err != nil {
|
||||
return "", fmt.Errorf("generate temporary login token failed: %w", err)
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(payload), nil
|
||||
}
|
||||
|
||||
func totpLoginChallengeFromCacheValue(value interface{}) (*totpLoginChallenge, bool) {
|
||||
switch typed := value.(type) {
|
||||
case *totpLoginChallenge:
|
||||
return typed, true
|
||||
case totpLoginChallenge:
|
||||
challenge := typed
|
||||
return &challenge, true
|
||||
case map[string]interface{}:
|
||||
payload, err := json.Marshal(typed)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
var challenge totpLoginChallenge
|
||||
if err := json.Unmarshal(payload, &challenge); err != nil {
|
||||
return nil, false
|
||||
}
|
||||
return &challenge, true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AuthService) issueTOTPLoginChallenge(ctx context.Context, user *domain.User, deviceID string) (string, error) {
|
||||
if s == nil || s.cache == nil {
|
||||
return "", errors.New("temporary login token storage is unavailable")
|
||||
}
|
||||
if user == nil {
|
||||
return "", errors.New("temporary login token requires a user")
|
||||
}
|
||||
|
||||
tempToken, err := generateTemporaryLoginToken()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
challenge := &totpLoginChallenge{
|
||||
UserID: user.ID,
|
||||
DeviceID: strings.TrimSpace(deviceID),
|
||||
}
|
||||
if err := s.cache.Set(
|
||||
ctx,
|
||||
totpChallengePrefix+tempToken,
|
||||
challenge,
|
||||
defaultTOTPChallengeTTL,
|
||||
defaultTOTPChallengeTTL,
|
||||
); err != nil {
|
||||
return "", fmt.Errorf("temporary login token storage failed: %w", err)
|
||||
}
|
||||
|
||||
return tempToken, nil
|
||||
}
|
||||
|
||||
func (s *AuthService) validateTOTPLoginChallenge(ctx context.Context, userID int64, deviceID, tempToken string) error {
|
||||
if s == nil || s.cache == nil {
|
||||
return errors.New("temporary login token storage is unavailable")
|
||||
}
|
||||
|
||||
normalizedToken := strings.TrimSpace(tempToken)
|
||||
if normalizedToken == "" {
|
||||
return errors.New("temporary login token is required")
|
||||
}
|
||||
|
||||
value, ok := s.cache.Get(ctx, totpChallengePrefix+normalizedToken)
|
||||
if !ok {
|
||||
return errors.New("temporary login token is invalid or expired")
|
||||
}
|
||||
|
||||
challenge, ok := totpLoginChallengeFromCacheValue(value)
|
||||
if !ok || challenge == nil {
|
||||
return errors.New("temporary login token is invalid or expired")
|
||||
}
|
||||
|
||||
if challenge.UserID != userID || strings.TrimSpace(challenge.DeviceID) != strings.TrimSpace(deviceID) {
|
||||
return errors.New("temporary login token does not match the requested login flow")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AuthService) Register(ctx context.Context, req *RegisterRequest) (*UserInfo, error) {
|
||||
if req == nil {
|
||||
return nil, errors.New("注册请求不能为空")
|
||||
@@ -628,6 +758,9 @@ func (s *AuthService) Register(ctx context.Context, req *RegisterRequest) (*User
|
||||
if err := s.verifyPhoneRegistration(ctx, req); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.emailActivationSvc != nil && req.Email != "" {
|
||||
return s.RegisterWithActivation(ctx, req)
|
||||
}
|
||||
|
||||
exists, err := s.userRepo.ExistsByUsername(ctx, req.Username)
|
||||
if err != nil {
|
||||
@@ -759,11 +892,17 @@ func (s *AuthService) Login(ctx context.Context, req *LoginRequest, ip string) (
|
||||
|
||||
// P0-07 安全修复:检查是否需要TOTP验证(用户启用了TOTP且设备未信任)
|
||||
if s.isTOTPRequiredForLogin(ctx, user, req.DeviceID) {
|
||||
tempToken, err := s.issueTOTPLoginChallenge(ctx, user, req.DeviceID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 返回RequiresTOTP指示前端需要完成TOTP验证
|
||||
// 前端应调用 /auth/login/totp-verify 接口完成验证
|
||||
return &LoginResponse{
|
||||
RequiresTOTP: true,
|
||||
UserID: user.ID,
|
||||
TempToken: tempToken,
|
||||
UserID: user.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -808,10 +947,13 @@ func (s *AuthService) isTOTPRequiredForLogin(ctx context.Context, user *domain.U
|
||||
// VerifyTOTPAfterPasswordLogin 完成密码登录后的TOTP验证
|
||||
// 当用户启用了TOTP但设备未信任时,密码登录会返回RequiresTOTP=true
|
||||
// 前端需要调用此接口完成TOTP验证以获取令牌
|
||||
func (s *AuthService) VerifyTOTPAfterPasswordLogin(ctx context.Context, userID int64, totpCode, deviceID string) (*LoginResponse, error) {
|
||||
func (s *AuthService) VerifyTOTPAfterPasswordLogin(ctx context.Context, userID int64, totpCode, deviceID, tempToken string) (*LoginResponse, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("auth service is not initialized")
|
||||
}
|
||||
if err := s.validateTOTPLoginChallenge(ctx, userID, deviceID, tempToken); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -827,6 +969,10 @@ func (s *AuthService) VerifyTOTPAfterPasswordLogin(ctx context.Context, userID i
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := s.cache.Delete(ctx, totpChallengePrefix+strings.TrimSpace(tempToken)); err != nil {
|
||||
return nil, fmt.Errorf("temporary login token cleanup failed: %w", err)
|
||||
}
|
||||
|
||||
// TOTP验证成功,返回完整登录响应
|
||||
return s.generateLoginResponseWithoutRemember(ctx, user)
|
||||
}
|
||||
@@ -841,8 +987,11 @@ func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*L
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if s.IsTokenBlacklisted(ctx, claims.JTI) {
|
||||
return nil, errors.New("refresh token has been revoked")
|
||||
if blacklistValue, blacklisted := s.getTokenBlacklistValue(ctx, claims.JTI); blacklisted {
|
||||
revokedAt, hasRevocationTimestamp := tokenBlacklistRevokedAt(blacklistValue)
|
||||
if !hasRevocationTimestamp || time.Since(revokedAt) > refreshTokenRetryGrace {
|
||||
return nil, errors.New("refresh token has been revoked")
|
||||
}
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, claims.UserID)
|
||||
@@ -861,7 +1010,7 @@ func (s *AuthService) RefreshToken(ctx context.Context, refreshToken string) (*L
|
||||
if claims.ExpiresAt != nil {
|
||||
remaining := time.Until(claims.ExpiresAt.Time)
|
||||
if remaining > 0 {
|
||||
if err := s.cache.Set(ctx, blacklistKey, "1", 5*time.Minute, remaining); err != nil {
|
||||
if err := s.cache.Set(ctx, blacklistKey, time.Now().UnixNano(), 5*time.Minute, remaining); err != nil {
|
||||
return nil, fmt.Errorf("token revocation failed: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,13 +69,17 @@ func (s *AuthService) RegisterWithActivation(ctx context.Context, req *RegisterR
|
||||
if s.emailActivationSvc != nil && req.Email != "" {
|
||||
initialStatus = domain.UserStatusInactive
|
||||
}
|
||||
nickname := req.Nickname
|
||||
if nickname == "" {
|
||||
nickname = req.Username
|
||||
}
|
||||
|
||||
user := &domain.User{
|
||||
Username: req.Username,
|
||||
Email: domain.StrPtr(req.Email),
|
||||
Phone: domain.StrPtr(req.Phone),
|
||||
Password: hashedPassword,
|
||||
Nickname: req.Nickname,
|
||||
Nickname: nickname,
|
||||
Status: initialStatus,
|
||||
}
|
||||
if err := s.userRepo.Create(ctx, user); err != nil {
|
||||
@@ -85,10 +89,6 @@ func (s *AuthService) RegisterWithActivation(ctx context.Context, req *RegisterR
|
||||
s.bestEffortAssignDefaultRoles(ctx, user.ID, "register_with_activation")
|
||||
|
||||
if s.emailActivationSvc != nil && req.Email != "" {
|
||||
nickname := req.Nickname
|
||||
if nickname == "" {
|
||||
nickname = req.Username
|
||||
}
|
||||
// #nosec G118 - 使用独立上下文避免请求结束后被取消
|
||||
go func() { // #nosec G118
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
|
||||
@@ -375,6 +375,51 @@ func TestAuthService_RegisterWithActivation(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthService_Register_UsesEmailActivationFlowWhenConfigured(t *testing.T) {
|
||||
svc, db := setupAuthEmailTestEnv(t)
|
||||
ctx := context.Background()
|
||||
|
||||
l1Cache := cache.NewL1Cache()
|
||||
l2Cache := cache.NewRedisCache(false)
|
||||
cacheManager := cache.NewCacheManager(l1Cache, l2Cache)
|
||||
emailActivationSvc := service.NewEmailActivationService(
|
||||
&service.MockEmailProvider{},
|
||||
cacheManager,
|
||||
"http://localhost:8080",
|
||||
"TestSite",
|
||||
)
|
||||
svc.SetEmailActivationService(emailActivationSvc)
|
||||
|
||||
userInfo, err := svc.Register(ctx, &service.RegisterRequest{
|
||||
Username: "register_activation_enabled",
|
||||
Password: "Password123!",
|
||||
Email: "register-activation-enabled@test.com",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Register failed: %v", err)
|
||||
}
|
||||
if userInfo == nil {
|
||||
t.Fatal("Register returned nil user info")
|
||||
}
|
||||
if userInfo.Status != domain.UserStatusInactive {
|
||||
t.Fatalf("Register status = %d, want %d", userInfo.Status, domain.UserStatusInactive)
|
||||
}
|
||||
if userInfo.Nickname != "register_activation_enabled" {
|
||||
t.Fatalf("Register nickname = %q, want %q", userInfo.Nickname, "register_activation_enabled")
|
||||
}
|
||||
|
||||
var storedUser domain.User
|
||||
if err := db.WithContext(ctx).Where("username = ?", "register_activation_enabled").First(&storedUser).Error; err != nil {
|
||||
t.Fatalf("load stored user: %v", err)
|
||||
}
|
||||
if storedUser.Status != domain.UserStatusInactive {
|
||||
t.Fatalf("stored user status = %d, want %d", storedUser.Status, domain.UserStatusInactive)
|
||||
}
|
||||
if storedUser.Nickname != "register_activation_enabled" {
|
||||
t.Fatalf("stored user nickname = %q, want %q", storedUser.Nickname, "register_activation_enabled")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Login By Email Code Extended Tests
|
||||
// =============================================================================
|
||||
|
||||
@@ -3,10 +3,12 @@ package service
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"github.com/user-management-system/internal/security"
|
||||
@@ -359,6 +361,73 @@ func TestBuildDeviceFingerprint(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogin_IssuesTOTPChallengeTokenWhenSecondFactorIsRequired(t *testing.T) {
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
DriverName: "sqlite",
|
||||
DSN: fmt.Sprintf("file:login_totp_challenge_%d?mode=memory&cache=shared", time.Now().UnixNano()),
|
||||
}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect database: %v", err)
|
||||
}
|
||||
|
||||
if err := db.AutoMigrate(&domain.User{}); err != nil {
|
||||
t.Fatalf("failed to migrate: %v", err)
|
||||
}
|
||||
|
||||
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||
HS256Secret: "totp-challenge-secret",
|
||||
AccessTokenExpire: 15 * time.Minute,
|
||||
RefreshTokenExpire: 7 * 24 * time.Hour,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create jwt manager: %v", err)
|
||||
}
|
||||
|
||||
cacheManager := cache.NewCacheManager(cache.NewL1Cache(), cache.NewRedisCache(false))
|
||||
userRepo := repository.NewUserRepository(db)
|
||||
svc := NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
|
||||
|
||||
hashedPassword, err := auth.HashPassword("Password123!")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to hash password: %v", err)
|
||||
}
|
||||
|
||||
user := &domain.User{
|
||||
Username: "totpchallenge",
|
||||
Password: hashedPassword,
|
||||
Status: domain.UserStatusActive,
|
||||
TOTPEnabled: true,
|
||||
TOTPSecret: "JBSWY3DPEHPK3PXP",
|
||||
}
|
||||
if err := db.Create(user).Error; err != nil {
|
||||
t.Fatalf("failed to create user: %v", err)
|
||||
}
|
||||
|
||||
resp, err := svc.Login(context.Background(), &LoginRequest{
|
||||
Account: "totpchallenge",
|
||||
Password: "Password123!",
|
||||
DeviceID: "device-1",
|
||||
}, "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("login failed: %v", err)
|
||||
}
|
||||
|
||||
if !resp.RequiresTOTP {
|
||||
t.Fatalf("expected requires_totp response, got %+v", resp)
|
||||
}
|
||||
if resp.UserID != user.ID {
|
||||
t.Fatalf("expected user id %d, got %d", user.ID, resp.UserID)
|
||||
}
|
||||
if strings.TrimSpace(resp.TempToken) == "" {
|
||||
t.Fatalf("expected temp token when TOTP is required, got %+v", resp)
|
||||
}
|
||||
if resp.AccessToken != "" || resp.RefreshToken != "" {
|
||||
t.Fatalf("expected no full session tokens before TOTP verification, got %+v", resp)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthServiceDefaultConfig(t *testing.T) {
|
||||
// Test that default configuration is applied correctly
|
||||
svc := NewAuthService(nil, nil, nil, nil, 0, 0, 0)
|
||||
|
||||
Reference in New Issue
Block a user