fix: close auth, permission, contract and e2e review blockers
This commit is contained in:
@@ -216,6 +216,7 @@ $env:CORS_ALLOWED_ORIGINS = "$frontendBaseUrl,http://localhost:$selectedFrontend
|
||||
|
||||
$env:VITE_API_PROXY_TARGET = $backendBaseUrl
|
||||
$env:VITE_API_BASE_URL = '/api/v1'
|
||||
$env:NODE_ENV = 'development'
|
||||
$frontendHandle = Start-ManagedProcess `
|
||||
-Name 'ums-frontend-playwright' `
|
||||
-FilePath 'npm.cmd' `
|
||||
@@ -288,10 +289,11 @@ $env:CORS_ALLOWED_ORIGINS = "$frontendBaseUrl,http://localhost:$selectedFrontend
|
||||
Remove-Item Env:EMAIL_PORT -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:EMAIL_FROM_EMAIL -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:EMAIL_FROM_NAME -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:VITE_API_PROXY_TARGET -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:VITE_API_BASE_URL -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:JWT_SECRET -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:DEFAULT_ADMIN_EMAIL -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:VITE_API_PROXY_TARGET -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:VITE_API_BASE_URL -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:NODE_ENV -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:JWT_SECRET -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:DEFAULT_ADMIN_EMAIL -ErrorAction SilentlyContinue
|
||||
Remove-Item Env:DEFAULT_ADMIN_PASSWORD -ErrorAction SilentlyContinue
|
||||
Remove-Item $serverExePath -Force -ErrorAction SilentlyContinue
|
||||
Remove-Item $e2eRunRoot -Recurse -Force -ErrorAction SilentlyContinue
|
||||
|
||||
142
frontend/admin/scripts/run-playwright-auth-e2e.sh
Normal file
142
frontend/admin/scripts/run-playwright-auth-e2e.sh
Normal file
@@ -0,0 +1,142 @@
|
||||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
ADMIN_USERNAME="${E2E_LOGIN_USERNAME:-e2e_admin}"
|
||||
ADMIN_PASSWORD="${E2E_LOGIN_PASSWORD:-E2EAdmin@123456}"
|
||||
ADMIN_EMAIL="${E2E_LOGIN_EMAIL:-e2e_admin@example.com}"
|
||||
BOOTSTRAP_SECRET_VALUE="${E2E_BOOTSTRAP_SECRET:-${BOOTSTRAP_SECRET:-e2e-bootstrap-secret-0123456789abcdefghijklmnopqrstuvwxyz}}"
|
||||
BROWSER_PORT="${E2E_CDP_PORT:-0}"
|
||||
BACKEND_PORT="${E2E_BACKEND_PORT:-0}"
|
||||
FRONTEND_PORT="${E2E_FRONTEND_PORT:-0}"
|
||||
|
||||
SCRIPT_DIR="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)"
|
||||
FRONTEND_ROOT="$(cd -- "$SCRIPT_DIR/.." && pwd)"
|
||||
PROJECT_ROOT="$(cd -- "$SCRIPT_DIR/../../.." && pwd)"
|
||||
TMP_ROOT="$(mktemp -d -t ums-playwright-e2e-XXXXXX)"
|
||||
DATA_ROOT="$TMP_ROOT/data"
|
||||
SMTP_CAPTURE_FILE="$TMP_ROOT/smtp-capture.jsonl"
|
||||
SERVER_BIN="$TMP_ROOT/ums-server"
|
||||
mkdir -p "$DATA_ROOT"
|
||||
|
||||
backend_pid=''
|
||||
frontend_pid=''
|
||||
smtp_pid=''
|
||||
|
||||
cleanup() {
|
||||
local exit_code=$?
|
||||
for pid in "$frontend_pid" "$backend_pid" "$smtp_pid"; do
|
||||
if [[ -n "$pid" ]] && kill -0 "$pid" 2>/dev/null; then
|
||||
kill "$pid" 2>/dev/null || true
|
||||
wait "$pid" 2>/dev/null || true
|
||||
fi
|
||||
done
|
||||
rm -rf "$TMP_ROOT"
|
||||
exit "$exit_code"
|
||||
}
|
||||
trap cleanup EXIT INT TERM
|
||||
|
||||
get_free_port() {
|
||||
python3 - <<'PY'
|
||||
import socket
|
||||
s = socket.socket()
|
||||
s.bind(('127.0.0.1', 0))
|
||||
print(s.getsockname()[1])
|
||||
s.close()
|
||||
PY
|
||||
}
|
||||
|
||||
wait_url_ready() {
|
||||
local url="$1"
|
||||
local label="$2"
|
||||
local attempts="${3:-120}"
|
||||
local delay="${4:-0.5}"
|
||||
for ((i=0; i<attempts; i++)); do
|
||||
if curl -fsS "$url" >/dev/null 2>&1; then
|
||||
return 0
|
||||
fi
|
||||
sleep "$delay"
|
||||
done
|
||||
echo "$label did not become ready: $url" >&2
|
||||
return 1
|
||||
}
|
||||
|
||||
SELECTED_BACKEND_PORT="$BACKEND_PORT"
|
||||
if [[ "$SELECTED_BACKEND_PORT" == "0" ]]; then
|
||||
SELECTED_BACKEND_PORT="$(get_free_port)"
|
||||
fi
|
||||
SELECTED_FRONTEND_PORT="$FRONTEND_PORT"
|
||||
if [[ "$SELECTED_FRONTEND_PORT" == "0" ]]; then
|
||||
SELECTED_FRONTEND_PORT="$(get_free_port)"
|
||||
fi
|
||||
SELECTED_SMTP_PORT="$(get_free_port)"
|
||||
|
||||
BACKEND_BASE_URL="http://127.0.0.1:${SELECTED_BACKEND_PORT}"
|
||||
FRONTEND_BASE_URL="http://127.0.0.1:${SELECTED_FRONTEND_PORT}"
|
||||
SQLITE_PATH="$DATA_ROOT/user_management.e2e.db"
|
||||
|
||||
cd "$PROJECT_ROOT"
|
||||
go build -o "$SERVER_BIN" ./cmd/server
|
||||
|
||||
echo "playwright e2e backend: $BACKEND_BASE_URL"
|
||||
echo "playwright e2e frontend: $FRONTEND_BASE_URL"
|
||||
echo "playwright e2e smtp: 127.0.0.1:$SELECTED_SMTP_PORT"
|
||||
echo "playwright e2e sqlite: $SQLITE_PATH"
|
||||
|
||||
node "$SCRIPT_DIR/mock-smtp-capture.mjs" --port "$SELECTED_SMTP_PORT" --output "$SMTP_CAPTURE_FILE" >"$TMP_ROOT/smtp.log" 2>&1 &
|
||||
smtp_pid=$!
|
||||
sleep 0.5
|
||||
if ! kill -0 "$smtp_pid" 2>/dev/null; then
|
||||
cat "$TMP_ROOT/smtp.log" >&2 || true
|
||||
echo "smtp capture server failed to start" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
(
|
||||
export SERVER_PORT="$SELECTED_BACKEND_PORT"
|
||||
export DATABASE_DBNAME="$SQLITE_PATH"
|
||||
export SERVER_MODE='debug'
|
||||
export SERVER_FRONTEND_URL="$FRONTEND_BASE_URL"
|
||||
export CORS_ALLOWED_ORIGINS="$FRONTEND_BASE_URL,http://localhost:${SELECTED_FRONTEND_PORT}"
|
||||
export LOGGING_OUTPUT='stdout'
|
||||
export DISABLE_RATE_LIMIT='1'
|
||||
export EMAIL_HOST='127.0.0.1'
|
||||
export EMAIL_PORT="$SELECTED_SMTP_PORT"
|
||||
export EMAIL_FROM_EMAIL='noreply@test.local'
|
||||
export EMAIL_FROM_NAME='UMS E2E'
|
||||
export JWT_SECRET='e2e-test-jwt-secret-at-least-32-bytes-long-for-security'
|
||||
export BOOTSTRAP_SECRET="$BOOTSTRAP_SECRET_VALUE"
|
||||
exec "$SERVER_BIN"
|
||||
) >"$TMP_ROOT/backend.log" 2>&1 &
|
||||
backend_pid=$!
|
||||
|
||||
if ! wait_url_ready "$BACKEND_BASE_URL/health" 'backend'; then
|
||||
cat "$TMP_ROOT/backend.log" >&2 || true
|
||||
exit 1
|
||||
fi
|
||||
|
||||
(
|
||||
cd "$FRONTEND_ROOT"
|
||||
export VITE_API_PROXY_TARGET="$BACKEND_BASE_URL"
|
||||
export VITE_API_BASE_URL='/api/v1'
|
||||
exec env -u NODE_ENV npm run dev -- --host 127.0.0.1 --port "$SELECTED_FRONTEND_PORT"
|
||||
) >"$TMP_ROOT/frontend.log" 2>&1 &
|
||||
frontend_pid=$!
|
||||
|
||||
if ! wait_url_ready "$FRONTEND_BASE_URL" 'frontend'; then
|
||||
cat "$TMP_ROOT/frontend.log" >&2 || true
|
||||
exit 1
|
||||
fi
|
||||
|
||||
cd "$FRONTEND_ROOT"
|
||||
export E2E_LOGIN_USERNAME="$ADMIN_USERNAME"
|
||||
export E2E_LOGIN_PASSWORD="$ADMIN_PASSWORD"
|
||||
export E2E_LOGIN_EMAIL="$ADMIN_EMAIL"
|
||||
export E2E_BOOTSTRAP_SECRET="$BOOTSTRAP_SECRET_VALUE"
|
||||
export BOOTSTRAP_SECRET="$BOOTSTRAP_SECRET_VALUE"
|
||||
export E2E_EXPECT_ADMIN_BOOTSTRAP='1'
|
||||
export E2E_EXTERNAL_WEB_SERVER='1'
|
||||
export E2E_MANAGED_BROWSER='1'
|
||||
export E2E_BASE_URL="$FRONTEND_BASE_URL"
|
||||
export E2E_SMTP_CAPTURE_FILE="$SMTP_CAPTURE_FILE"
|
||||
|
||||
env -u NODE_ENV node ./scripts/run-playwright-cdp-e2e.mjs
|
||||
@@ -18,16 +18,18 @@ const TEXT = {
|
||||
assignPermissions: '\u5206\u914d\u6743\u9650',
|
||||
assignRoles: '\u5206\u914d\u89d2\u8272',
|
||||
assignRolesAction: '\u89d2\u8272',
|
||||
auditLogs: '\u5ba1\u8ba1\u65e5\u5fd7',
|
||||
backToLogin: '\u8fd4\u56de\u767b\u5f55',
|
||||
bootstrapAdminConfirmPasswordPlaceholder: '\u786e\u8ba4\u7ba1\u7406\u5458\u5bc6\u7801',
|
||||
bootstrapAdminEmailPlaceholder: '\u7ba1\u7406\u5458\u90ae\u7bb1\uff08\u9009\u586b\uff09',
|
||||
bootstrapAdminEmailPlaceholder: '\u7ba1\u7406\u5458\u90ae\u7bb1',
|
||||
bootstrapAdminPasswordPlaceholder: '\u7ba1\u7406\u5458\u5bc6\u7801',
|
||||
bootstrapAdminSecretPlaceholder: 'Bootstrap Secret',
|
||||
bootstrapAdminSubmit: '\u5b8c\u6210\u521d\u59cb\u5316\u5e76\u8fdb\u5165\u7cfb\u7edf',
|
||||
bootstrapAdminUsernamePlaceholder: '\u7ba1\u7406\u5458\u7528\u6237\u540d',
|
||||
changePassword: '\u4fee\u6539\u5bc6\u7801',
|
||||
confirmPasswordPlaceholder: '\u786e\u8ba4\u5bc6\u7801',
|
||||
createAccount: '\u521b\u5efa\u8d26\u53f7',
|
||||
createUser: '\u521b\u5efa\u7528\u5458',
|
||||
createUser: '\u521b\u5efa\u7528\u6237',
|
||||
createUserEmailPlaceholder: '\u90ae\u7bb1\u5730\u5740',
|
||||
createUserPasswordPlaceholder: '\u8bf7\u8f93\u5165\u521d\u59cb\u5bc6\u7801',
|
||||
createUserUsernamePlaceholder: '\u8bf7\u8f93\u5165\u7528\u6237\u540d',
|
||||
@@ -45,6 +47,7 @@ const TEXT = {
|
||||
emailActivationSuccess: '\u90ae\u7bb1\u9a8c\u8bc1\u6210\u529f',
|
||||
export: '\u5bfc\u51fa',
|
||||
forgotPassword: '\u5fd8\u8bb0\u5bc6\u7801\uff1f',
|
||||
integration: '\u96c6\u6210\u80fd\u529b',
|
||||
loginAction: '\u767b\u5f55',
|
||||
loginLogs: '\u767b\u5f55\u65e5\u5fd7',
|
||||
loginNow: '\u7acb\u5373\u767b\u5f55',
|
||||
@@ -104,6 +107,7 @@ const SMTP_CAPTURE_FILE = (process.env.E2E_SMTP_CAPTURE_FILE ?? '').trim()
|
||||
const SESSION_PRESENCE_COOKIE_NAME = 'ums_session_present'
|
||||
|
||||
let managedCdpUrl = null
|
||||
const IS_WINDOWS = process.platform === 'win32'
|
||||
|
||||
function appUrl(pathname) {
|
||||
return new URL(pathname, `${BASE_URL}/`).toString()
|
||||
@@ -193,6 +197,16 @@ async function waitForActivationLink(email, timeoutMs = 20_000) {
|
||||
throw new Error(`Timed out waiting for activation email for ${email}.`)
|
||||
}
|
||||
|
||||
async function fetchAuthCapabilitiesSnapshot() {
|
||||
const response = await fetch(appUrl('/api/v1/auth/capabilities'))
|
||||
if (!response.ok) {
|
||||
throw new Error(`Failed to fetch auth capabilities: ${response.status} ${response.statusText}`)
|
||||
}
|
||||
|
||||
const payload = await response.json()
|
||||
return payload?.data ?? {}
|
||||
}
|
||||
|
||||
function resolveCdpUrl() {
|
||||
if (managedCdpUrl) {
|
||||
return managedCdpUrl
|
||||
@@ -272,12 +286,24 @@ async function resolveManagedBrowserPath() {
|
||||
return candidate
|
||||
}
|
||||
|
||||
for (const candidate of [
|
||||
'C:\\Program Files\\Google\\Chrome\\Application\\chrome.exe',
|
||||
'C:\\Program Files (x86)\\Google\\Chrome\\Application\\chrome.exe',
|
||||
'C:\\Program Files\\Microsoft\\Edge\\Application\\msedge.exe',
|
||||
'C:\\Program Files (x86)\\Microsoft\\Edge\\Application\\msedge.exe',
|
||||
]) {
|
||||
const platformCandidates = IS_WINDOWS
|
||||
? [
|
||||
'C:\\Program Files\\Google\\Chrome\\Application\\chrome.exe',
|
||||
'C:\\Program Files (x86)\\Google\\Chrome\\Application\\chrome.exe',
|
||||
'C:\\Program Files\\Microsoft\\Edge\\Application\\msedge.exe',
|
||||
'C:\\Program Files (x86)\\Microsoft\\Edge\\Application\\msedge.exe',
|
||||
]
|
||||
: [
|
||||
'/snap/bin/chromium',
|
||||
'/usr/bin/chromium',
|
||||
'/usr/bin/chromium-browser',
|
||||
'/usr/bin/google-chrome',
|
||||
'/usr/bin/google-chrome-stable',
|
||||
'/usr/bin/microsoft-edge',
|
||||
'/usr/bin/msedge',
|
||||
]
|
||||
|
||||
for (const candidate of platformCandidates) {
|
||||
try {
|
||||
await assertFileExists(candidate)
|
||||
return candidate
|
||||
@@ -286,7 +312,9 @@ async function resolveManagedBrowserPath() {
|
||||
}
|
||||
}
|
||||
|
||||
const baseDir = path.join(process.env.LOCALAPPDATA ?? '', 'ms-playwright')
|
||||
const baseDir = IS_WINDOWS
|
||||
? path.join(process.env.LOCALAPPDATA ?? '', 'ms-playwright')
|
||||
: path.join(process.env.HOME ?? '', '.cache', 'ms-playwright')
|
||||
const candidates = []
|
||||
|
||||
try {
|
||||
@@ -297,11 +325,16 @@ async function resolveManagedBrowserPath() {
|
||||
}
|
||||
|
||||
candidates.push(
|
||||
path.join(baseDir, entry.name, 'chrome-headless-shell-win64', 'chrome-headless-shell.exe'),
|
||||
path.join(
|
||||
baseDir,
|
||||
entry.name,
|
||||
IS_WINDOWS ? 'chrome-headless-shell-win64' : 'chrome-headless-shell-linux64',
|
||||
IS_WINDOWS ? 'chrome-headless-shell.exe' : 'chrome-headless-shell',
|
||||
),
|
||||
)
|
||||
}
|
||||
} catch {
|
||||
throw new Error('failed to scan Playwright browser cache under LOCALAPPDATA')
|
||||
throw new Error(`failed to scan Playwright browser cache under ${baseDir}`)
|
||||
}
|
||||
|
||||
candidates.sort().reverse()
|
||||
@@ -376,6 +409,15 @@ async function killManagedBrowser(browserProcess) {
|
||||
return
|
||||
}
|
||||
|
||||
if (!IS_WINDOWS) {
|
||||
try {
|
||||
browserProcess.kill('SIGKILL')
|
||||
} catch {
|
||||
// ignore
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
await new Promise((resolve) => {
|
||||
const killer = spawn('taskkill', ['/PID', String(browserProcess.pid), '/T', '/F'], {
|
||||
stdio: 'ignore',
|
||||
@@ -547,8 +589,28 @@ function attachSignalCollectors(page, signals) {
|
||||
}
|
||||
}
|
||||
|
||||
async function assertBaseUrlServesAdminApp(page) {
|
||||
await page.goto(appUrl('/login'), { waitUntil: 'domcontentloaded' })
|
||||
await page.waitForLoadState('networkidle').catch(() => {})
|
||||
|
||||
const title = await page.title().catch(() => '')
|
||||
const bodyText = (await page.locator('body').textContent())?.trim() ?? ''
|
||||
const matchesAppTitle = title.includes(TEXT.appTitle)
|
||||
const matchesAppBody = bodyText.includes(TEXT.welcomeLogin) || bodyText.includes(TEXT.adminBootstrapTitle)
|
||||
if (matchesAppTitle || matchesAppBody) {
|
||||
return
|
||||
}
|
||||
|
||||
throw new Error(
|
||||
`E2E_BASE_URL resolved to ${appUrl('/login')}, but the page does not look like the admin app. ` +
|
||||
`title=${JSON.stringify(title)} body_excerpt=${JSON.stringify(bodyText.slice(0, 160))}. ` +
|
||||
`Set E2E_BASE_URL to the running frontend app (default expects the Vite dev server on :3000).`,
|
||||
)
|
||||
}
|
||||
|
||||
async function resetBrowserState(context, page) {
|
||||
logDebug('resetting browser state')
|
||||
await page.setViewportSize({ width: VIEWPORTS[0].width, height: VIEWPORTS[0].height })
|
||||
await context.clearCookies()
|
||||
await page.goto(appUrl('/login'), { waitUntil: 'domcontentloaded' })
|
||||
await page.evaluate(() => {
|
||||
@@ -709,7 +771,12 @@ async function forceClick(locator) {
|
||||
})
|
||||
}
|
||||
|
||||
async function readRefreshToken(page) {
|
||||
async function hasHttpOnlyRefreshCookie(page) {
|
||||
const cookies = await page.context().cookies()
|
||||
return cookies.some((cookie) => cookie.name === 'ums_refresh_token' && Boolean(cookie.value))
|
||||
}
|
||||
|
||||
async function readSessionPresenceCookie(page) {
|
||||
return await page.evaluate((cookieName) => {
|
||||
const target = `${cookieName}=`
|
||||
const matched = document.cookie
|
||||
@@ -731,19 +798,31 @@ async function assertApiSuccessResponse(response, label) {
|
||||
try {
|
||||
payload = JSON.parse(responseBody)
|
||||
} catch (error) {
|
||||
if (error instanceof SyntaxError) {
|
||||
throw new Error(`${label} response is not valid JSON: ${responseBody}`)
|
||||
}
|
||||
throw error
|
||||
throw new Error(`${label} response is not valid JSON: ${responseBody}`)
|
||||
}
|
||||
|
||||
if (payload?.code !== 0) {
|
||||
throw new Error(`${label} business response failed: ${responseBody}`)
|
||||
throw new Error(`${label} response code ${payload?.code}: ${payload?.message ?? responseBody}`)
|
||||
}
|
||||
|
||||
return payload
|
||||
}
|
||||
|
||||
async function waitForSessionCookies(context, timeoutMs = 10_000) {
|
||||
const startedAt = Date.now()
|
||||
while (Date.now() - startedAt < timeoutMs) {
|
||||
const cookies = await context.cookies()
|
||||
const hasRefresh = cookies.some((cookie) => cookie.name === 'ums_refresh_token' && cookie.value)
|
||||
const hasPresence = cookies.some((cookie) => cookie.name === 'ums_session_present' && cookie.value === '1')
|
||||
if (hasRefresh && hasPresence) {
|
||||
return
|
||||
}
|
||||
await delay(100)
|
||||
}
|
||||
|
||||
throw new Error('session cookies were not persisted after login within timeout')
|
||||
}
|
||||
|
||||
async function loginWithPassword(page, username, password, expectedUrlPattern) {
|
||||
const usernameInput = page
|
||||
.locator(`input[autocomplete="username"], input[placeholder="${TEXT.usernamePlaceholder}"]`)
|
||||
@@ -761,12 +840,25 @@ async function loginWithPassword(page, username, password, expectedUrlPattern) {
|
||||
if (loginResponse) {
|
||||
await assertApiSuccessResponse(loginResponse, 'password login')
|
||||
}
|
||||
await waitForSessionCookies(page.context())
|
||||
|
||||
if (expectedUrlPattern) {
|
||||
await expect(page).toHaveURL(expectedUrlPattern, { timeout: 30 * 1000 })
|
||||
}
|
||||
}
|
||||
|
||||
async function expectLoggedInLanding(page, timeoutMs = 30 * 1000) {
|
||||
await expect(page).toHaveURL(/\/(dashboard|profile)$/, { timeout: timeoutMs })
|
||||
|
||||
const currentUrl = page.url()
|
||||
if (currentUrl.endsWith('/dashboard')) {
|
||||
await expect(page.getByText(TEXT.todaySuccessLogins)).toBeVisible()
|
||||
return
|
||||
}
|
||||
|
||||
await expect(page.locator('body')).toContainText(TEXT.profile)
|
||||
}
|
||||
|
||||
async function loginFromLoginPage(page) {
|
||||
const username = requireEnv('E2E_LOGIN_USERNAME')
|
||||
const password = requireEnv('E2E_LOGIN_PASSWORD')
|
||||
@@ -775,7 +867,8 @@ async function loginFromLoginPage(page) {
|
||||
await expect(page).toHaveURL(/\/login$/)
|
||||
await expect(page.getByRole('heading', { name: TEXT.welcomeLogin })).toBeVisible()
|
||||
|
||||
await loginWithPassword(page, username, password, /\/dashboard$/)
|
||||
await loginWithPassword(page, username, password)
|
||||
await expectLoggedInLanding(page)
|
||||
|
||||
return { username, password }
|
||||
}
|
||||
@@ -784,6 +877,10 @@ async function verifyAdminBootstrapWorkflow(page) {
|
||||
const username = requireEnv('E2E_LOGIN_USERNAME')
|
||||
const password = requireEnv('E2E_LOGIN_PASSWORD')
|
||||
const email = (process.env.E2E_LOGIN_EMAIL ?? `${username}@example.com`).trim()
|
||||
const bootstrapSecret = (process.env.E2E_BOOTSTRAP_SECRET ?? process.env.BOOTSTRAP_SECRET ?? '').trim()
|
||||
if (!bootstrapSecret) {
|
||||
throw new Error('E2E_BOOTSTRAP_SECRET or BOOTSTRAP_SECRET is required when E2E_EXPECT_ADMIN_BOOTSTRAP=1.')
|
||||
}
|
||||
|
||||
const capabilitiesResponse = page.waitForResponse((response) => {
|
||||
return response.url().includes('/api/v1/auth/capabilities') && response.request().method() === 'GET'
|
||||
@@ -800,6 +897,7 @@ async function verifyAdminBootstrapWorkflow(page) {
|
||||
|
||||
await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminUsernamePlaceholder}"]`).first(), username)
|
||||
await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminEmailPlaceholder}"]`).first(), email)
|
||||
await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminSecretPlaceholder}"]`).first(), bootstrapSecret)
|
||||
await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminPasswordPlaceholder}"]`).first(), password)
|
||||
await forceFillInput(page.locator(`input[placeholder="${TEXT.bootstrapAdminConfirmPasswordPlaceholder}"]`).first(), password)
|
||||
|
||||
@@ -811,8 +909,7 @@ async function verifyAdminBootstrapWorkflow(page) {
|
||||
])
|
||||
await assertApiSuccessResponse(bootstrapResponse, 'bootstrap admin')
|
||||
|
||||
await expect(page).toHaveURL(/\/dashboard$/, { timeout: 30 * 1000 })
|
||||
await expect(page.getByText(TEXT.todaySuccessLogins)).toBeVisible()
|
||||
await expectLoggedInLanding(page)
|
||||
|
||||
await forceClick(page.locator('[class*="userTrigger"]'))
|
||||
await forceClick(page.getByText(TEXT.logout, { exact: true }))
|
||||
@@ -1012,7 +1109,8 @@ async function verifyAuthWorkflow(page) {
|
||||
await page.goto(appUrl('/users'))
|
||||
await expect(page).toHaveURL(/\/users$/)
|
||||
|
||||
expect(await readRefreshToken(page)).toBeTruthy()
|
||||
expect(await hasHttpOnlyRefreshCookie(page)).toBe(true)
|
||||
expect(await readSessionPresenceCookie(page)).toBe('1')
|
||||
|
||||
const userRow = page.locator('tbody tr').filter({ hasText: credentials.username }).first()
|
||||
await expect(userRow).toBeVisible({ timeout: 20 * 1000 })
|
||||
@@ -1084,7 +1182,8 @@ async function verifyAuthWorkflow(page) {
|
||||
await forceClick(page.locator('[class*="userTrigger"]'))
|
||||
await forceClick(page.getByText(TEXT.logout, { exact: true }))
|
||||
await expect(page).toHaveURL(/\/login$/)
|
||||
await expect(await readRefreshToken(page)).toBeNull()
|
||||
await expect(await hasHttpOnlyRefreshCookie(page)).toBe(false)
|
||||
await expect(await readSessionPresenceCookie(page)).toBeNull()
|
||||
|
||||
await page.goto(appUrl('/dashboard'))
|
||||
const postLogoutRedirect = await getProtectedRouteRedirect(page)
|
||||
@@ -1191,7 +1290,7 @@ async function verifyUserManagementCRUD(page) {
|
||||
|
||||
const userRow = page.locator('tbody tr').filter({ hasText: testUsername }).first()
|
||||
await forceClick(userRow.getByRole('button', { name: TEXT.edit }))
|
||||
const editDrawer = page.locator('.ant-drawer')
|
||||
const editDrawer = page.locator('.ant-drawer.ant-drawer-open')
|
||||
await expect(editDrawer).toBeVisible({ timeout: 10 * 1000 })
|
||||
|
||||
const editResponsePromise = page.waitForResponse((response) => {
|
||||
@@ -1202,7 +1301,7 @@ async function verifyUserManagementCRUD(page) {
|
||||
await assertApiSuccessResponse(editResponse, 'edit user CRUD')
|
||||
|
||||
await forceClick(userRow.getByRole('button', { name: TEXT.userDetailAction }))
|
||||
const detailDrawer = page.locator('.ant-drawer')
|
||||
const detailDrawer = page.locator('.ant-drawer.ant-drawer-open')
|
||||
await expect(detailDrawer).toBeVisible({ timeout: 10 * 1000 })
|
||||
await expect(detailDrawer).toContainText(testUsername)
|
||||
|
||||
@@ -1211,13 +1310,14 @@ async function verifyUserManagementCRUD(page) {
|
||||
await expect(page.locator('tbody tr').filter({ hasText: testUsername }).first()).toBeVisible({ timeout: 10 * 1000 })
|
||||
|
||||
await forceClick(userRow.getByRole('button', { name: TEXT.delete }))
|
||||
const deleteConfirmModal = page.locator('.ant-modal-confirm')
|
||||
const deleteConfirmModal = page.locator('.ant-popover').filter({ hasText: '确定要删除用户' }).last()
|
||||
await expect(deleteConfirmModal).toBeVisible({ timeout: 10 * 1000 })
|
||||
const deleteResponsePromise = page.waitForResponse((response) => {
|
||||
return response.url().includes(`/api/v1/users/`) && response.request().method() === 'DELETE'
|
||||
})
|
||||
await forceClick(deleteConfirmModal.locator('.ant-btn-primary').last())
|
||||
const deleteResponse = await deleteResponsePromise
|
||||
const [deleteResponse] = await Promise.all([
|
||||
page.waitForResponse((response) => {
|
||||
return response.url().includes(`/api/v1/users/`) && response.request().method() === 'DELETE'
|
||||
}),
|
||||
forceClick(deleteConfirmModal.locator('.ant-popconfirm-buttons .ant-btn-primary').last()),
|
||||
])
|
||||
await assertApiSuccessResponse(deleteResponse, 'delete user CRUD')
|
||||
|
||||
await expect(page.locator('tbody tr').filter({ hasText: testUsername }).first()).toHaveCount(0, { timeout: 10 * 1000 })
|
||||
@@ -1255,8 +1355,7 @@ async function verifyDeviceManagement(page) {
|
||||
logDebug('verifyDeviceManagement: login /login')
|
||||
await loginFromLoginPage(page)
|
||||
|
||||
await expandSidebarGroup(page, TEXT.systemManagement)
|
||||
await clickSidebarMenu(page, TEXT.devices)
|
||||
await page.goto(appUrl('/devices'))
|
||||
await expect(page).toHaveURL(/\/devices$/)
|
||||
|
||||
await expect(page.getByText(TEXT.deviceManagement)).toBeVisible({ timeout: 10 * 1000 })
|
||||
@@ -1270,11 +1369,11 @@ async function verifyLoginLogs(page) {
|
||||
logDebug('verifyLoginLogs: login /login')
|
||||
await loginFromLoginPage(page)
|
||||
|
||||
await expandSidebarGroup(page, TEXT.systemManagement)
|
||||
await expandSidebarGroup(page, TEXT.auditLogs)
|
||||
await clickSidebarMenu(page, TEXT.loginLogs)
|
||||
await expect(page).toHaveURL(/\/login-logs$/)
|
||||
await expect(page).toHaveURL(/\/logs\/login$/)
|
||||
|
||||
await expect(page.getByText(TEXT.loginLogs)).toBeVisible({ timeout: 10 * 1000 })
|
||||
await expect(page.getByRole('heading', { name: TEXT.loginLogs })).toBeVisible({ timeout: 10 * 1000 })
|
||||
|
||||
await forceClick(page.locator('[class*="userTrigger"]'))
|
||||
await forceClick(page.getByText(TEXT.logout, { exact: true }))
|
||||
@@ -1285,11 +1384,11 @@ async function verifyOperationLogs(page) {
|
||||
logDebug('verifyOperationLogs: login /login')
|
||||
await loginFromLoginPage(page)
|
||||
|
||||
await expandSidebarGroup(page, TEXT.systemManagement)
|
||||
await expandSidebarGroup(page, TEXT.auditLogs)
|
||||
await clickSidebarMenu(page, TEXT.operationLogs)
|
||||
await expect(page).toHaveURL(/\/operation-logs$/)
|
||||
await expect(page).toHaveURL(/\/logs\/operation$/)
|
||||
|
||||
await expect(page.getByText(TEXT.operationLogs)).toBeVisible({ timeout: 10 * 1000 })
|
||||
await expect(page.getByRole('heading', { name: TEXT.operationLogs })).toBeVisible({ timeout: 10 * 1000 })
|
||||
|
||||
await forceClick(page.locator('[class*="userTrigger"]'))
|
||||
await forceClick(page.getByText(TEXT.logout, { exact: true }))
|
||||
@@ -1300,11 +1399,11 @@ async function verifyWebhookManagement(page) {
|
||||
logDebug('verifyWebhookManagement: login /login')
|
||||
await loginFromLoginPage(page)
|
||||
|
||||
await expandSidebarGroup(page, TEXT.systemManagement)
|
||||
await expandSidebarGroup(page, TEXT.integration)
|
||||
await clickSidebarMenu(page, TEXT.webhooks)
|
||||
await expect(page).toHaveURL(/\/webhooks$/)
|
||||
|
||||
await expect(page.getByText(TEXT.webhooks)).toBeVisible({ timeout: 10 * 1000 })
|
||||
await expect(page.locator('body')).toContainText('Webhook 管理', { timeout: 10 * 1000 })
|
||||
|
||||
await forceClick(page.locator('[class*="userTrigger"]'))
|
||||
await forceClick(page.getByText(TEXT.logout, { exact: true }))
|
||||
@@ -1322,10 +1421,10 @@ async function verifyProfileAndSecurity(page) {
|
||||
await expect(page.locator('body')).toContainText(credentials.username, { timeout: 10 * 1000 })
|
||||
|
||||
await forceClick(page.locator('[class*="userTrigger"]'))
|
||||
await forceClick(page.getByText(TEXT.security))
|
||||
await forceClick(page.locator('.ant-dropdown').getByText(TEXT.security, { exact: true }).last())
|
||||
await expect(page).toHaveURL(/\/profile\/security$/)
|
||||
|
||||
await expect(page.getByText(TEXT.changePassword)).toBeVisible({ timeout: 10 * 1000 })
|
||||
await expect(page.getByRole('button', { name: TEXT.changePassword })).toBeVisible({ timeout: 10 * 1000 })
|
||||
|
||||
await forceClick(page.locator('[class*="userTrigger"]'))
|
||||
await forceClick(page.getByText(TEXT.logout, { exact: true }))
|
||||
@@ -1370,11 +1469,22 @@ async function main() {
|
||||
throw new Error('No persistent Chromium context is available through CDP.')
|
||||
}
|
||||
|
||||
const preflightPage = await ensurePersistentPage(browser, context)
|
||||
if (!preflightPage) {
|
||||
throw new Error('No persistent page is available in the Chromium CDP context.')
|
||||
}
|
||||
await assertBaseUrlServesAdminApp(preflightPage)
|
||||
const authCapabilities = await fetchAuthCapabilitiesSnapshot()
|
||||
|
||||
if (process.env.E2E_EXPECT_ADMIN_BOOTSTRAP === '1') {
|
||||
await runScenario(browser, context, 'admin-bootstrap', verifyAdminBootstrapWorkflow)
|
||||
}
|
||||
await runScenario(browser, context, 'public-registration', verifyPublicRegistration)
|
||||
await runScenario(browser, context, 'email-activation', verifyEmailActivationWorkflow)
|
||||
if (authCapabilities.email_activation) {
|
||||
await runScenario(browser, context, 'email-activation', verifyEmailActivationWorkflow)
|
||||
} else {
|
||||
console.log('SKIP email-activation (auth capability disabled)')
|
||||
}
|
||||
await runScenario(browser, context, 'login-surface', verifyLoginSurface)
|
||||
await runScenario(browser, context, 'auth-workflow', verifyAuthWorkflow)
|
||||
await runScenario(browser, context, 'responsive-login', verifyResponsiveLogin)
|
||||
|
||||
@@ -18,6 +18,7 @@ import { CSRF_PROTECTED_METHODS, getCSRFHeaders } from './csrf'
|
||||
import type { TokenBundle } from '@/types'
|
||||
|
||||
const DEFAULT_TIMEOUT = 30_000
|
||||
let inFlightRefreshBundle: Promise<TokenBundle> | null = null
|
||||
|
||||
function isFormDataBody(body: unknown): body is FormData {
|
||||
return typeof FormData !== 'undefined' && body instanceof FormData
|
||||
@@ -145,6 +146,40 @@ async function refreshAccessToken(): Promise<TokenBundle> {
|
||||
return result.data
|
||||
}
|
||||
|
||||
async function performTokenRefresh(): Promise<TokenBundle> {
|
||||
if (inFlightRefreshBundle) {
|
||||
return inFlightRefreshBundle
|
||||
}
|
||||
|
||||
startRefreshing()
|
||||
const promise = (async () => {
|
||||
try {
|
||||
const tokenBundle = await refreshAccessToken()
|
||||
setAccessToken(tokenBundle.access_token, tokenBundle.expires_in)
|
||||
setRefreshToken(tokenBundle.refresh_token)
|
||||
return tokenBundle
|
||||
} finally {
|
||||
endRefreshing()
|
||||
clearRefreshPromise()
|
||||
inFlightRefreshBundle = null
|
||||
}
|
||||
})()
|
||||
|
||||
inFlightRefreshBundle = promise
|
||||
setRefreshPromise(
|
||||
promise.then(
|
||||
() => undefined,
|
||||
() => undefined,
|
||||
),
|
||||
)
|
||||
|
||||
return promise
|
||||
}
|
||||
|
||||
export async function refreshSessionBundle(): Promise<TokenBundle> {
|
||||
return await performTokenRefresh()
|
||||
}
|
||||
|
||||
async function performRefresh(): Promise<string> {
|
||||
if (isRefreshing()) {
|
||||
const promise = getRefreshPromise()
|
||||
@@ -160,26 +195,8 @@ async function performRefresh(): Promise<string> {
|
||||
return token
|
||||
}
|
||||
|
||||
startRefreshing()
|
||||
const promise = (async () => {
|
||||
try {
|
||||
const tokenBundle = await refreshAccessToken()
|
||||
setAccessToken(tokenBundle.access_token, tokenBundle.expires_in)
|
||||
setRefreshToken(tokenBundle.refresh_token)
|
||||
return tokenBundle.access_token
|
||||
} finally {
|
||||
endRefreshing()
|
||||
clearRefreshPromise()
|
||||
}
|
||||
})()
|
||||
|
||||
setRefreshPromise(
|
||||
promise.then(
|
||||
() => undefined,
|
||||
() => undefined,
|
||||
),
|
||||
)
|
||||
return promise
|
||||
const tokenBundle = await performTokenRefresh()
|
||||
return tokenBundle.access_token
|
||||
}
|
||||
|
||||
async function resolveAuthorizationHeader(auth: boolean): Promise<string | null> {
|
||||
|
||||
@@ -345,14 +345,12 @@ export function ContactBindingsSection({
|
||||
label="验证码"
|
||||
rules={[{ required: true, message: '请输入验证码' }]}
|
||||
>
|
||||
<Input
|
||||
placeholder="请输入验证码"
|
||||
addonAfter={
|
||||
<Button type="link" size="small" loading={sendCodeLoading} onClick={handleSendCode}>
|
||||
发送验证码
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
<Space.Compact style={{ width: '100%' }}>
|
||||
<Input placeholder="请输入验证码" />
|
||||
<Button type="link" loading={sendCodeLoading} onClick={handleSendCode}>
|
||||
发送验证码
|
||||
</Button>
|
||||
</Space.Compact>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item name="current_password" label="当前密码">
|
||||
|
||||
@@ -29,7 +29,7 @@ const authContextValue: AuthContextValue = {
|
||||
|
||||
function renderBootstrapAdminPage() {
|
||||
return render(
|
||||
<MemoryRouter initialEntries={['/bootstrap-admin']}>
|
||||
<MemoryRouter future={{ v7_startTransition: true, v7_relativeSplatPath: true }} initialEntries={['/bootstrap-admin']}>
|
||||
<AuthContext.Provider value={authContextValue}>
|
||||
<BootstrapAdminPage />
|
||||
</AuthContext.Provider>
|
||||
@@ -88,7 +88,8 @@ describe('BootstrapAdminPage', () => {
|
||||
|
||||
await user.type(screen.getByPlaceholderText('管理员用户名'), 'bootstrap_admin')
|
||||
await user.type(screen.getByPlaceholderText('管理员昵称(选填)'), 'Bootstrap Admin')
|
||||
await user.type(screen.getByPlaceholderText('管理员邮箱(选填)'), 'bootstrap_admin@example.com')
|
||||
await user.type(screen.getByPlaceholderText('管理员邮箱'), 'bootstrap_admin@example.com')
|
||||
await user.type(screen.getByPlaceholderText('Bootstrap Secret'), 'bootstrap-secret-demo')
|
||||
await user.type(screen.getByPlaceholderText('管理员密码'), 'Bootstrap123!@#')
|
||||
await user.type(screen.getByPlaceholderText('确认管理员密码'), 'Bootstrap123!@#')
|
||||
await user.click(screen.getByRole('button', { name: '完成初始化并进入系统' }))
|
||||
@@ -99,6 +100,7 @@ describe('BootstrapAdminPage', () => {
|
||||
nickname: 'Bootstrap Admin',
|
||||
email: 'bootstrap_admin@example.com',
|
||||
password: 'Bootstrap123!@#',
|
||||
bootstrap_secret: 'bootstrap-secret-demo',
|
||||
}),
|
||||
)
|
||||
|
||||
|
||||
@@ -24,7 +24,8 @@ const DEFAULT_CAPABILITIES: AuthCapabilities = {
|
||||
type BootstrapAdminFormValues = {
|
||||
username: string
|
||||
nickname?: string
|
||||
email?: string
|
||||
email: string
|
||||
bootstrapSecret: string
|
||||
password: string
|
||||
confirmPassword: string
|
||||
}
|
||||
@@ -71,7 +72,8 @@ export function BootstrapAdminPage() {
|
||||
const tokenBundle = await bootstrapAdmin({
|
||||
username: values.username.trim(),
|
||||
nickname: values.nickname?.trim() || undefined,
|
||||
email: values.email?.trim() || undefined,
|
||||
email: values.email!.trim(),
|
||||
bootstrap_secret: values.bootstrapSecret!.trim(),
|
||||
password: values.password,
|
||||
})
|
||||
await onLoginSuccess(tokenBundle)
|
||||
@@ -110,7 +112,7 @@ export function BootstrapAdminPage() {
|
||||
初始化首个管理员账号
|
||||
</Title>
|
||||
<Paragraph type="secondary" style={{ marginBottom: 24 }}>
|
||||
当前版本不内置默认账号。首次部署时,请先创建首个管理员账号,初始化完成后系统会自动关闭该入口。
|
||||
当前版本不内置默认账号。首次部署时,请提供 Bootstrap Secret 并创建首个管理员账号,初始化完成后系统会自动关闭该入口。
|
||||
</Paragraph>
|
||||
|
||||
<Alert
|
||||
@@ -143,15 +145,29 @@ export function BootstrapAdminPage() {
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name="email"
|
||||
rules={[{ type: 'email', message: '请输入有效的邮箱地址' }]}
|
||||
rules={[
|
||||
{ required: true, message: '请输入管理员邮箱' },
|
||||
{ type: 'email', message: '请输入有效的邮箱地址' },
|
||||
]}
|
||||
>
|
||||
<Input
|
||||
prefix={<MailOutlined />}
|
||||
placeholder="管理员邮箱(选填)"
|
||||
placeholder="管理员邮箱"
|
||||
size="large"
|
||||
autoComplete="email"
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name="bootstrapSecret"
|
||||
rules={[{ required: true, message: '请输入 Bootstrap Secret' }]}
|
||||
>
|
||||
<Input.Password
|
||||
prefix={<LockOutlined />}
|
||||
placeholder="Bootstrap Secret"
|
||||
size="large"
|
||||
autoComplete="one-time-code"
|
||||
/>
|
||||
</Form.Item>
|
||||
<Form.Item
|
||||
name="password"
|
||||
rules={[{ required: true, message: '请输入管理员密码' }]}
|
||||
|
||||
@@ -41,16 +41,13 @@ const defaultCapabilities: AuthCapabilities = {
|
||||
}
|
||||
|
||||
const activeRegisterResponse: RegisterResponse = {
|
||||
user: {
|
||||
id: 2,
|
||||
username: 'new-user',
|
||||
email: 'new-user@example.com',
|
||||
phone: '',
|
||||
nickname: 'New User',
|
||||
avatar: '',
|
||||
status: 1,
|
||||
},
|
||||
message: 'registered successfully',
|
||||
id: 2,
|
||||
username: 'new-user',
|
||||
email: 'new-user@example.com',
|
||||
phone: '',
|
||||
nickname: 'New User',
|
||||
avatar: '',
|
||||
status: 1,
|
||||
}
|
||||
|
||||
vi.mock('@/services/auth', () => ({
|
||||
@@ -61,7 +58,7 @@ vi.mock('@/services/auth', () => ({
|
||||
|
||||
function renderRegisterPage() {
|
||||
return render(
|
||||
<MemoryRouter initialEntries={['/register']}>
|
||||
<MemoryRouter future={{ v7_startTransition: true, v7_relativeSplatPath: true }} initialEntries={['/register']}>
|
||||
<RegisterPage />
|
||||
</MemoryRouter>,
|
||||
)
|
||||
@@ -321,16 +318,13 @@ describe('RegisterPage', () => {
|
||||
email_activation: true,
|
||||
})
|
||||
registerMock.mockResolvedValue({
|
||||
user: {
|
||||
id: 3,
|
||||
username: 'inactive-user',
|
||||
email: 'inactive-user@example.com',
|
||||
phone: '',
|
||||
nickname: 'Inactive User',
|
||||
avatar: '',
|
||||
status: 0,
|
||||
},
|
||||
message: 'registered successfully, please check your email to activate the account',
|
||||
id: 3,
|
||||
username: 'inactive-user',
|
||||
email: 'inactive-user@example.com',
|
||||
phone: '',
|
||||
nickname: 'Inactive User',
|
||||
avatar: '',
|
||||
status: 0,
|
||||
})
|
||||
|
||||
renderRegisterPage()
|
||||
@@ -350,16 +344,13 @@ describe('RegisterPage', () => {
|
||||
|
||||
it('shows the generic activation summary when the new inactive account has no email address', async () => {
|
||||
registerMock.mockResolvedValue({
|
||||
user: {
|
||||
id: 4,
|
||||
username: 'inactive-without-email',
|
||||
email: '',
|
||||
phone: '',
|
||||
nickname: '',
|
||||
avatar: '',
|
||||
status: 0,
|
||||
},
|
||||
message: 'registered successfully, activation required',
|
||||
id: 4,
|
||||
username: 'inactive-without-email',
|
||||
email: '',
|
||||
phone: '',
|
||||
nickname: '',
|
||||
avatar: '',
|
||||
status: 0,
|
||||
})
|
||||
|
||||
renderRegisterPage()
|
||||
|
||||
@@ -38,10 +38,10 @@ type RegisterFormValues = {
|
||||
confirmPassword: string
|
||||
}
|
||||
|
||||
function buildRegisterSummary(result: RegisterResponse) {
|
||||
if (result.user.status === 0) {
|
||||
if (result.user.email) {
|
||||
return `账号已创建,激活邮件会发送到 ${result.user.email}。请完成激活后再登录。`
|
||||
function buildRegisterSummary(user: RegisterResponse) {
|
||||
if (user.status === 0) {
|
||||
if (user.email) {
|
||||
return `账号已创建,激活邮件会发送到 ${user.email}。请完成激活后再登录。`
|
||||
}
|
||||
return '账号已创建,请按页面提示完成激活后再登录。'
|
||||
}
|
||||
@@ -128,7 +128,7 @@ export function RegisterPage() {
|
||||
form.resetFields()
|
||||
setSmsCountdown(0)
|
||||
setSubmitted(result)
|
||||
message.success(result.user.status === 0 ? '注册成功,请完成邮箱激活' : '注册成功')
|
||||
message.success(result.status === 0 ? '注册成功,请完成邮箱激活' : '注册成功')
|
||||
} catch (error) {
|
||||
message.error(getErrorMessage(error, '注册失败,请检查输入信息后重试'))
|
||||
} finally {
|
||||
@@ -137,7 +137,7 @@ export function RegisterPage() {
|
||||
}, [capabilities.sms_code, form])
|
||||
|
||||
if (submitted) {
|
||||
const activationEmail = submitted.user.email?.trim()
|
||||
const activationEmail = submitted.email?.trim()
|
||||
|
||||
return (
|
||||
<AuthLayout>
|
||||
@@ -146,7 +146,7 @@ export function RegisterPage() {
|
||||
title="注册成功"
|
||||
subTitle={(
|
||||
<Paragraph>
|
||||
<Text strong>{submitted.user.username}</Text>
|
||||
<Text strong>{submitted.username}</Text>
|
||||
{' '}
|
||||
{buildRegisterSummary(submitted)}
|
||||
</Paragraph>
|
||||
@@ -155,7 +155,7 @@ export function RegisterPage() {
|
||||
<Link key="login" to="/login">
|
||||
<Button type="primary">返回登录</Button>
|
||||
</Link>,
|
||||
submitted.user.status === 0 && activationEmail && capabilities.email_activation ? (
|
||||
submitted.status === 0 && activationEmail && capabilities.email_activation ? (
|
||||
<Link key="activation" to={`/activate-account?email=${encodeURIComponent(activationEmail)}`}>
|
||||
<Button>重新发送激活邮件</Button>
|
||||
</Link>
|
||||
|
||||
@@ -2,17 +2,21 @@ import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const getMock = vi.fn()
|
||||
const postMock = vi.fn()
|
||||
const refreshSessionBundleMock = vi.fn()
|
||||
|
||||
vi.mock('@/lib/http/client', () => ({
|
||||
get: getMock,
|
||||
post: postMock,
|
||||
refreshSessionBundle: refreshSessionBundleMock,
|
||||
}))
|
||||
|
||||
describe('auth service', () => {
|
||||
beforeEach(() => {
|
||||
getMock.mockReset()
|
||||
postMock.mockReset()
|
||||
refreshSessionBundleMock.mockReset()
|
||||
postMock.mockResolvedValue(undefined)
|
||||
refreshSessionBundleMock.mockResolvedValue(undefined)
|
||||
})
|
||||
|
||||
it('loads public auth capabilities without auth headers', async () => {
|
||||
@@ -84,6 +88,28 @@ describe('auth service', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('verifies password-login totp with the temporary challenge token', async () => {
|
||||
const { verifyTOTPAfterPasswordLogin } = await import('./auth')
|
||||
|
||||
await verifyTOTPAfterPasswordLogin({
|
||||
user_id: 42,
|
||||
code: '123456',
|
||||
device_id: 'device-1',
|
||||
temp_token: 'temp-token-demo',
|
||||
})
|
||||
|
||||
expect(postMock).toHaveBeenCalledWith(
|
||||
'/auth/login/totp-verify',
|
||||
{
|
||||
user_id: 42,
|
||||
code: '123456',
|
||||
device_id: 'device-1',
|
||||
temp_token: 'temp-token-demo',
|
||||
},
|
||||
{ auth: false, credentials: 'include' },
|
||||
)
|
||||
})
|
||||
|
||||
it('submits public registration without auth headers', async () => {
|
||||
const { register } = await import('./auth')
|
||||
|
||||
@@ -106,7 +132,7 @@ describe('auth service', () => {
|
||||
)
|
||||
})
|
||||
|
||||
it('submits first-admin bootstrap without auth headers', async () => {
|
||||
it('submits first-admin bootstrap with bootstrap secret header', async () => {
|
||||
const { bootstrapAdmin } = await import('./auth')
|
||||
|
||||
await bootstrapAdmin({
|
||||
@@ -114,6 +140,7 @@ describe('auth service', () => {
|
||||
password: 'Bootstrap123!@#',
|
||||
email: 'bootstrap_admin@example.com',
|
||||
nickname: 'Bootstrap Admin',
|
||||
bootstrap_secret: 'bootstrap-secret-demo',
|
||||
})
|
||||
|
||||
expect(postMock).toHaveBeenCalledWith(
|
||||
@@ -124,7 +151,13 @@ describe('auth service', () => {
|
||||
email: 'bootstrap_admin@example.com',
|
||||
nickname: 'Bootstrap Admin',
|
||||
},
|
||||
{ auth: false, credentials: 'include' },
|
||||
{
|
||||
auth: false,
|
||||
credentials: 'include',
|
||||
headers: {
|
||||
'X-Bootstrap-Secret': 'bootstrap-secret-demo',
|
||||
},
|
||||
},
|
||||
)
|
||||
})
|
||||
|
||||
@@ -192,12 +225,13 @@ describe('auth service', () => {
|
||||
expect(postMock).toHaveBeenCalledWith('/auth/logout', undefined, { credentials: 'include' })
|
||||
})
|
||||
|
||||
it('refreshes the session with credentials even when no body token is supplied', async () => {
|
||||
it('refreshes the session through the shared refresh single-flight when no body token is supplied', async () => {
|
||||
const { refreshSession } = await import('./auth')
|
||||
|
||||
await refreshSession()
|
||||
|
||||
expect(postMock).toHaveBeenCalledWith(
|
||||
expect(refreshSessionBundleMock).toHaveBeenCalledTimes(1)
|
||||
expect(postMock).not.toHaveBeenCalledWith(
|
||||
'/auth/refresh',
|
||||
undefined,
|
||||
{ auth: false, credentials: 'include' },
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { get, post } from '@/lib/http/client'
|
||||
import { refreshSessionBundle } from '@/lib/http/client'
|
||||
import type {
|
||||
ActionMessageResponse,
|
||||
AuthCapabilities,
|
||||
@@ -59,7 +60,14 @@ export function register(data: RegisterRequest): Promise<RegisterResponse> {
|
||||
}
|
||||
|
||||
export function bootstrapAdmin(data: BootstrapAdminRequest): Promise<TokenBundle> {
|
||||
return post<TokenBundle>('/auth/bootstrap-admin', data, { auth: false, credentials: 'include' })
|
||||
const { bootstrap_secret, ...payload } = data
|
||||
return post<TokenBundle>('/auth/bootstrap-admin', payload, {
|
||||
auth: false,
|
||||
credentials: 'include',
|
||||
headers: {
|
||||
'X-Bootstrap-Secret': bootstrap_secret,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
export function activateEmail(token: string): Promise<ActionMessageResponse> {
|
||||
@@ -81,8 +89,11 @@ export function sendSmsCode(data: SendSmsCodeRequest): Promise<void> {
|
||||
}
|
||||
|
||||
export function refreshSession(refreshToken?: string | null): Promise<TokenBundle> {
|
||||
const body = refreshToken ? { refresh_token: refreshToken } : undefined
|
||||
return post<TokenBundle>('/auth/refresh', body, { auth: false, credentials: 'include' })
|
||||
if (!refreshToken) {
|
||||
return refreshSessionBundle()
|
||||
}
|
||||
|
||||
return post<TokenBundle>('/auth/refresh', { refresh_token: refreshToken }, { auth: false, credentials: 'include' })
|
||||
}
|
||||
|
||||
export function getOAuthAuthorizationUrl(
|
||||
|
||||
@@ -28,6 +28,29 @@ describe('social account service', () => {
|
||||
expect(getMock).toHaveBeenCalledWith('/users/me/social-accounts')
|
||||
})
|
||||
|
||||
it('normalizes object-wrapped social account payloads', async () => {
|
||||
getMock.mockResolvedValue({
|
||||
social_accounts: [
|
||||
{
|
||||
provider: 'github',
|
||||
provider_user_id: '123',
|
||||
provider_username: 'octocat',
|
||||
bound_at: '2026-03-27 20:00:00',
|
||||
},
|
||||
],
|
||||
})
|
||||
|
||||
const { listSocialAccounts } = await import('./social-accounts')
|
||||
const result = await listSocialAccounts()
|
||||
|
||||
expect(result).toEqual([
|
||||
expect.objectContaining({
|
||||
provider: 'github',
|
||||
provider_username: 'octocat',
|
||||
}),
|
||||
])
|
||||
})
|
||||
|
||||
it('starts social binding with the current verification payload', async () => {
|
||||
const { startSocialBinding } = await import('./social-accounts')
|
||||
|
||||
|
||||
@@ -6,8 +6,35 @@ import type {
|
||||
SocialBindingStartResponse,
|
||||
} from '@/types'
|
||||
|
||||
export function listSocialAccounts(): Promise<SocialAccountInfo[]> {
|
||||
return get<SocialAccountInfo[]>('/users/me/social-accounts')
|
||||
interface SocialAccountsResponse {
|
||||
items?: SocialAccountInfo[]
|
||||
accounts?: SocialAccountInfo[]
|
||||
social_accounts?: SocialAccountInfo[]
|
||||
}
|
||||
|
||||
function normalizeSocialAccounts(payload: SocialAccountInfo[] | SocialAccountsResponse): SocialAccountInfo[] {
|
||||
if (Array.isArray(payload)) {
|
||||
return payload
|
||||
}
|
||||
|
||||
if (Array.isArray(payload.items)) {
|
||||
return payload.items
|
||||
}
|
||||
|
||||
if (Array.isArray(payload.accounts)) {
|
||||
return payload.accounts
|
||||
}
|
||||
|
||||
if (Array.isArray(payload.social_accounts)) {
|
||||
return payload.social_accounts
|
||||
}
|
||||
|
||||
return []
|
||||
}
|
||||
|
||||
export async function listSocialAccounts(): Promise<SocialAccountInfo[]> {
|
||||
const payload = await get<SocialAccountInfo[] | SocialAccountsResponse>('/users/me/social-accounts')
|
||||
return normalizeSocialAccounts(payload)
|
||||
}
|
||||
|
||||
export function startSocialBinding(
|
||||
|
||||
@@ -20,6 +20,52 @@ describe('users service', () => {
|
||||
delMock.mockReset()
|
||||
})
|
||||
|
||||
it('normalizes backend user list payloads that use users/limit/offset fields', async () => {
|
||||
getMock.mockResolvedValue({
|
||||
users: [
|
||||
{
|
||||
id: 7,
|
||||
username: 'e2e_admin',
|
||||
email: 'admin@example.com',
|
||||
nickname: '管理员',
|
||||
status: '1',
|
||||
},
|
||||
],
|
||||
total: 1,
|
||||
limit: 20,
|
||||
offset: 0,
|
||||
})
|
||||
|
||||
const { listUsers } = await import('./users')
|
||||
const result = await listUsers({ page: 1, page_size: 20 })
|
||||
|
||||
expect(getMock).toHaveBeenCalledWith('/users', { page: 1, page_size: 20 })
|
||||
expect(result).toEqual({
|
||||
items: [
|
||||
{
|
||||
id: 7,
|
||||
username: 'e2e_admin',
|
||||
email: 'admin@example.com',
|
||||
phone: '',
|
||||
nickname: '管理员',
|
||||
avatar: '',
|
||||
gender: 0,
|
||||
birthday: '',
|
||||
region: '',
|
||||
bio: '',
|
||||
status: 1,
|
||||
last_login_at: '',
|
||||
last_login_ip: '',
|
||||
created_at: '',
|
||||
updated_at: '',
|
||||
},
|
||||
],
|
||||
total: 1,
|
||||
page: 1,
|
||||
page_size: 20,
|
||||
})
|
||||
})
|
||||
|
||||
it('creates a user through the protected users endpoint', async () => {
|
||||
const payload = {
|
||||
username: 'new-user',
|
||||
|
||||
@@ -17,12 +17,59 @@ import type {
|
||||
AssignUserRolesRequest,
|
||||
} from '@/types/user'
|
||||
|
||||
interface RawUserListResponse {
|
||||
items?: Partial<User>[]
|
||||
users?: Partial<User>[]
|
||||
total?: number
|
||||
page?: number
|
||||
page_size?: number
|
||||
limit?: number
|
||||
offset?: number
|
||||
}
|
||||
|
||||
function normalizeUser(user: Partial<User>): User {
|
||||
const numericStatus = typeof user.status === 'string' ? Number(user.status) : user.status
|
||||
return {
|
||||
id: user.id ?? 0,
|
||||
username: user.username ?? '',
|
||||
email: user.email ?? '',
|
||||
phone: user.phone ?? '',
|
||||
nickname: user.nickname ?? '',
|
||||
avatar: user.avatar ?? '',
|
||||
gender: user.gender ?? 0,
|
||||
birthday: user.birthday ?? '',
|
||||
region: user.region ?? '',
|
||||
bio: user.bio ?? '',
|
||||
status: (typeof numericStatus === 'number' && !Number.isNaN(numericStatus) ? numericStatus : 0) as UserStatus,
|
||||
last_login_at: user.last_login_at ?? '',
|
||||
last_login_ip: user.last_login_ip ?? '',
|
||||
created_at: user.created_at ?? '',
|
||||
updated_at: user.updated_at ?? '',
|
||||
}
|
||||
}
|
||||
|
||||
function normalizeUserListResponse(result?: RawUserListResponse | null): PaginatedData<User> {
|
||||
const payload = result ?? {}
|
||||
const items = Array.isArray(payload.items) ? payload.items : Array.isArray(payload.users) ? payload.users : []
|
||||
const pageSize = payload.page_size ?? payload.limit ?? items.length
|
||||
const offset = payload.offset ?? 0
|
||||
const page = payload.page ?? (pageSize > 0 ? Math.floor(offset / pageSize) + 1 : 1)
|
||||
|
||||
return {
|
||||
items: items.map(normalizeUser),
|
||||
total: payload.total ?? items.length,
|
||||
page,
|
||||
page_size: pageSize,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取用户列表
|
||||
* GET /api/v1/users
|
||||
*/
|
||||
export function listUsers(params: UserListParams): Promise<PaginatedData<User>> {
|
||||
return get<PaginatedData<User>>('/users', params as Record<string, string | number | boolean | undefined>)
|
||||
export async function listUsers(params: UserListParams): Promise<PaginatedData<User>> {
|
||||
const result = await get<RawUserListResponse>('/users', params as Record<string, string | number | boolean | undefined>)
|
||||
return normalizeUserListResponse(result)
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -74,6 +74,44 @@ describe('webhooks service', () => {
|
||||
expect(result.data[2].events).toEqual([])
|
||||
})
|
||||
|
||||
it('normalizes backend webhook list payloads that use items/limit/offset fields', async () => {
|
||||
getMock.mockResolvedValue({
|
||||
items: [
|
||||
{
|
||||
id: 11,
|
||||
name: 'Compat Hook',
|
||||
url: 'https://example.com/compat',
|
||||
events: '["user.updated"]',
|
||||
status: 1,
|
||||
max_retries: 3,
|
||||
timeout_sec: 10,
|
||||
created_by: 1,
|
||||
created_at: '2026-03-27 20:20:00',
|
||||
updated_at: '2026-03-27 20:20:00',
|
||||
},
|
||||
],
|
||||
total: 1,
|
||||
limit: 20,
|
||||
offset: 0,
|
||||
})
|
||||
|
||||
const { listWebhooks } = await import('./webhooks')
|
||||
const result = await listWebhooks({ page: 1, page_size: 20 })
|
||||
|
||||
expect(result).toEqual({
|
||||
data: [
|
||||
expect.objectContaining({
|
||||
id: 11,
|
||||
name: 'Compat Hook',
|
||||
events: ['user.updated'],
|
||||
}),
|
||||
],
|
||||
total: 1,
|
||||
page: 1,
|
||||
page_size: 20,
|
||||
})
|
||||
})
|
||||
|
||||
it('sends create, update, delete, and delivery requests through the HTTP client', async () => {
|
||||
postMock.mockResolvedValue({
|
||||
id: 1,
|
||||
|
||||
@@ -33,18 +33,42 @@ function normalizeWebhook(webhook: RawWebhook): Webhook {
|
||||
}
|
||||
|
||||
interface PaginatedResponse<T> {
|
||||
data: T[]
|
||||
total: number
|
||||
page: number
|
||||
page_size: number
|
||||
data?: T[]
|
||||
items?: T[]
|
||||
webhooks?: T[]
|
||||
total?: number
|
||||
page?: number
|
||||
page_size?: number
|
||||
limit?: number
|
||||
offset?: number
|
||||
}
|
||||
|
||||
function normalizeWebhookList(result: PaginatedResponse<RawWebhook>): { data: Webhook[]; total: number; page: number; page_size: number } {
|
||||
const rawItems = Array.isArray(result.data)
|
||||
? result.data
|
||||
: Array.isArray(result.items)
|
||||
? result.items
|
||||
: Array.isArray(result.webhooks)
|
||||
? result.webhooks
|
||||
: []
|
||||
const data = rawItems.map(normalizeWebhook)
|
||||
const pageSize = result.page_size ?? result.limit ?? data.length
|
||||
const offset = result.offset ?? 0
|
||||
const page = result.page ?? (pageSize > 0 ? Math.floor(offset / pageSize) + 1 : 1)
|
||||
|
||||
return {
|
||||
data,
|
||||
total: result.total ?? data.length,
|
||||
page,
|
||||
page_size: pageSize,
|
||||
}
|
||||
}
|
||||
|
||||
export async function listWebhooks(
|
||||
params?: WebhookListParams,
|
||||
): Promise<{ data: Webhook[]; total: number; page: number; page_size: number }> {
|
||||
const result = await get<PaginatedResponse<RawWebhook>>('/webhooks', params as Record<string, string | number | boolean | undefined>)
|
||||
const webhooks = result.data.map(normalizeWebhook)
|
||||
return { data: webhooks, total: result.total, page: result.page, page_size: result.page_size }
|
||||
return normalizeWebhookList(result)
|
||||
}
|
||||
|
||||
export function createWebhook(data: CreateWebhookRequest): Promise<Webhook> {
|
||||
|
||||
@@ -30,11 +30,74 @@ type AuthHandler struct {
|
||||
authService *service.AuthService
|
||||
}
|
||||
|
||||
const (
|
||||
refreshTokenCookieName = "ums_refresh_token"
|
||||
sessionPresenceCookieName = "ums_session_present"
|
||||
)
|
||||
|
||||
// NewAuthHandler creates a new AuthHandler
|
||||
func NewAuthHandler(authService *service.AuthService) *AuthHandler {
|
||||
return &AuthHandler{authService: authService}
|
||||
}
|
||||
|
||||
func isSecureRequest(c *gin.Context) bool {
|
||||
if c == nil || c.Request == nil {
|
||||
return false
|
||||
}
|
||||
if c.Request.TLS != nil {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(c.GetHeader("X-Forwarded-Proto"), "https")
|
||||
}
|
||||
|
||||
func (h *AuthHandler) setSessionCookies(c *gin.Context, resp *service.LoginResponse) {
|
||||
if c == nil || resp == nil || strings.TrimSpace(resp.RefreshToken) == "" || h == nil || h.authService == nil {
|
||||
return
|
||||
}
|
||||
|
||||
maxAge := int(h.authService.RefreshTokenTTLSeconds())
|
||||
secure := isSecureRequest(c)
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: refreshTokenCookieName,
|
||||
Value: resp.RefreshToken,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: maxAge,
|
||||
})
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: sessionPresenceCookieName,
|
||||
Value: "1",
|
||||
Path: "/",
|
||||
HttpOnly: false,
|
||||
Secure: secure,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: maxAge,
|
||||
})
|
||||
}
|
||||
|
||||
func clearCookie(c *gin.Context, name string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
http.SetCookie(c.Writer, &http.Cookie{
|
||||
Name: name,
|
||||
Value: "",
|
||||
Path: "/",
|
||||
HttpOnly: name == refreshTokenCookieName,
|
||||
Secure: isSecureRequest(c),
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
MaxAge: -1,
|
||||
Expires: time.Unix(0, 0),
|
||||
})
|
||||
}
|
||||
|
||||
func clearSessionCookies(c *gin.Context) {
|
||||
clearCookie(c, refreshTokenCookieName)
|
||||
clearCookie(c, sessionPresenceCookieName)
|
||||
}
|
||||
|
||||
// Register 用户注册
|
||||
// @Summary 用户注册
|
||||
// @Description 用户注册新账号,支持用户名+密码或手机号注册
|
||||
@@ -130,6 +193,7 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.setSessionCookies(c, resp)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
@@ -150,21 +214,23 @@ func (h *AuthHandler) Login(c *gin.Context) {
|
||||
// @Router /api/v1/auth/login/totp-verify [post]
|
||||
func (h *AuthHandler) VerifyTOTPAfterPasswordLogin(c *gin.Context) {
|
||||
var req struct {
|
||||
UserID int64 `json:"user_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
DeviceID string `json:"device_id"`
|
||||
UserID int64 `json:"user_id" binding:"required"`
|
||||
Code string `json:"code" binding:"required"`
|
||||
DeviceID string `json:"device_id"`
|
||||
TempToken string `json:"temp_token" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.authService.VerifyTOTPAfterPasswordLogin(c.Request.Context(), req.UserID, req.Code, req.DeviceID)
|
||||
resp, err := h.authService.VerifyTOTPAfterPasswordLogin(c.Request.Context(), req.UserID, req.Code, req.DeviceID, req.TempToken)
|
||||
if err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.setSessionCookies(c, resp)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
@@ -197,6 +263,12 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
if req.RefreshToken == "" {
|
||||
if cookie, err := c.Request.Cookie(refreshTokenCookieName); err == nil {
|
||||
req.RefreshToken = cookie.Value
|
||||
}
|
||||
}
|
||||
|
||||
username, _ := c.Get("username")
|
||||
usernameStr, _ := username.(string)
|
||||
|
||||
@@ -204,7 +276,11 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
AccessToken: req.AccessToken,
|
||||
RefreshToken: req.RefreshToken,
|
||||
}
|
||||
_ = h.authService.Logout(c.Request.Context(), usernameStr, logoutReq)
|
||||
if err := h.authService.Logout(c.Request.Context(), usernameStr, logoutReq); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
clearSessionCookies(c)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"message": "logged out"})
|
||||
}
|
||||
@@ -222,20 +298,28 @@ func (h *AuthHandler) Logout(c *gin.Context) {
|
||||
// @Router /api/v1/auth/refresh-token [post]
|
||||
func (h *AuthHandler) RefreshToken(c *gin.Context) {
|
||||
var req struct {
|
||||
RefreshToken string `json:"refresh_token" binding:"required"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
}
|
||||
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
_ = c.ShouldBindJSON(&req)
|
||||
if strings.TrimSpace(req.RefreshToken) == "" {
|
||||
if cookie, err := c.Request.Cookie(refreshTokenCookieName); err == nil {
|
||||
req.RefreshToken = cookie.Value
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(req.RefreshToken) == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "refresh_token is required"})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := h.authService.RefreshToken(c.Request.Context(), req.RefreshToken)
|
||||
if err != nil {
|
||||
clearSessionCookies(c)
|
||||
handleError(c, err)
|
||||
return
|
||||
}
|
||||
|
||||
h.setSessionCookies(c, resp)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
@@ -315,7 +399,7 @@ func (h *AuthHandler) GetAuthCapabilities(c *gin.Context) {
|
||||
// @Router /api/v1/auth/oauth/{provider} [get]
|
||||
func (h *AuthHandler) OAuthLogin(c *gin.Context) {
|
||||
provider := c.Param("provider")
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "OAuth not configured", "data": gin.H{"provider": provider}})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "OAuth login is not configured", "data": gin.H{"provider": provider}})
|
||||
}
|
||||
|
||||
// OAuthCallback OAuth回调
|
||||
@@ -327,7 +411,7 @@ func (h *AuthHandler) OAuthLogin(c *gin.Context) {
|
||||
// @Success 200 {object} Response "OAuth未配置"
|
||||
// @Router /api/v1/auth/oauth/{provider}/callback [get]
|
||||
func (h *AuthHandler) OAuthCallback(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "OAuth not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "OAuth callback is not configured"})
|
||||
}
|
||||
|
||||
// OAuthExchange OAuth令牌交换
|
||||
@@ -340,7 +424,7 @@ func (h *AuthHandler) OAuthCallback(c *gin.Context) {
|
||||
// @Success 200 {object} Response "OAuth未配置"
|
||||
// @Router /api/v1/auth/oauth/{provider}/exchange [post]
|
||||
func (h *AuthHandler) OAuthExchange(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "OAuth not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "OAuth exchange is not configured"})
|
||||
}
|
||||
|
||||
// GetEnabledOAuthProviders 获取已启用的OAuth提供商
|
||||
@@ -481,6 +565,7 @@ func (h *AuthHandler) LoginByEmailCode(c *gin.Context) {
|
||||
}()
|
||||
}
|
||||
|
||||
h.setSessionCookies(c, resp)
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
@@ -545,6 +630,7 @@ func (h *AuthHandler) BootstrapAdmin(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
h.setSessionCookies(c, resp)
|
||||
c.JSON(http.StatusCreated, gin.H{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
@@ -561,7 +647,7 @@ func (h *AuthHandler) BootstrapAdmin(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/email/bind/send [post]
|
||||
func (h *AuthHandler) SendEmailBindCode(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "email bind not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "email binding is not configured"})
|
||||
}
|
||||
|
||||
// BindEmail 绑定邮箱
|
||||
@@ -573,7 +659,7 @@ func (h *AuthHandler) SendEmailBindCode(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/email/bind [post]
|
||||
func (h *AuthHandler) BindEmail(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "email bind not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "email binding is not configured"})
|
||||
}
|
||||
|
||||
// UnbindEmail 解绑邮箱
|
||||
@@ -585,7 +671,7 @@ func (h *AuthHandler) BindEmail(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/email/unbind [post]
|
||||
func (h *AuthHandler) UnbindEmail(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "email unbind not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "email binding is not configured"})
|
||||
}
|
||||
|
||||
// SendPhoneBindCode 发送手机绑定验证码
|
||||
@@ -597,7 +683,7 @@ func (h *AuthHandler) UnbindEmail(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/phone/bind/send [post]
|
||||
func (h *AuthHandler) SendPhoneBindCode(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "phone bind not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "phone binding is not configured"})
|
||||
}
|
||||
|
||||
// BindPhone 绑定手机号
|
||||
@@ -609,7 +695,7 @@ func (h *AuthHandler) SendPhoneBindCode(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/phone/bind [post]
|
||||
func (h *AuthHandler) BindPhone(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "phone bind not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "phone binding is not configured"})
|
||||
}
|
||||
|
||||
// UnbindPhone 解绑手机号
|
||||
@@ -621,7 +707,7 @@ func (h *AuthHandler) BindPhone(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/phone/unbind [post]
|
||||
func (h *AuthHandler) UnbindPhone(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "phone unbind not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "phone binding is not configured"})
|
||||
}
|
||||
|
||||
// GetSocialAccounts 获取社交账号列表
|
||||
@@ -645,7 +731,7 @@ func (h *AuthHandler) GetSocialAccounts(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/social/bind [post]
|
||||
func (h *AuthHandler) BindSocialAccount(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "social binding not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "social binding is not configured"})
|
||||
}
|
||||
|
||||
// UnbindSocialAccount 解绑社交账号
|
||||
@@ -657,7 +743,7 @@ func (h *AuthHandler) BindSocialAccount(c *gin.Context) {
|
||||
// @Success 200 {object} Response "功能未配置"
|
||||
// @Router /api/v1/auth/social/unbind [post]
|
||||
func (h *AuthHandler) UnbindSocialAccount(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"code": 0, "message": "social unbinding not configured"})
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"code": http.StatusServiceUnavailable, "message": "social binding is not configured"})
|
||||
}
|
||||
|
||||
func (h *AuthHandler) SupportsEmailCodeLogin() bool {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
apimiddleware "github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
)
|
||||
|
||||
@@ -33,10 +34,12 @@ func NewAvatarHandler(userRepo avatarUserRepository) *AvatarHandler {
|
||||
}
|
||||
|
||||
// generateSecureToken generates a secure random token
|
||||
func generateSecureToken(length int) string {
|
||||
func generateSecureToken(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
rand.Read(bytes)
|
||||
return hex.EncodeToString(bytes)[:length]
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return hex.EncodeToString(bytes)[:length], nil
|
||||
}
|
||||
|
||||
// UploadAvatar 上传用户头像
|
||||
@@ -70,17 +73,7 @@ func (h *AvatarHandler) UploadAvatar(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Check permission: user can only update their own avatar, or admin can update any
|
||||
isAdmin := false
|
||||
if roles, ok := c.Get("user_roles"); ok {
|
||||
for _, role := range roles.([]*domain.Role) {
|
||||
if role.Code == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if currentUserID != userID && !isAdmin {
|
||||
if currentUserID != userID && !apimiddleware.IsAdmin(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return
|
||||
}
|
||||
@@ -140,7 +133,12 @@ func (h *AvatarHandler) UploadAvatar(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Generate unique filename
|
||||
avatarFilename := fmt.Sprintf("avatar_%d_%s%s", userID, generateSecureToken(8), ext)
|
||||
token, err := generateSecureToken(8)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to generate avatar token"})
|
||||
return
|
||||
}
|
||||
avatarFilename := fmt.Sprintf("avatar_%d_%s%s", userID, token, ext)
|
||||
uploadDir := "./uploads/avatars"
|
||||
|
||||
// Create upload directory if not exists
|
||||
|
||||
@@ -7,7 +7,9 @@ import (
|
||||
"io"
|
||||
"mime/multipart"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -35,6 +37,11 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
|
||||
t.Helper()
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
previousBootstrapSecret, hadBootstrapSecret := os.LookupEnv("BOOTSTRAP_SECRET")
|
||||
if err := os.Setenv("BOOTSTRAP_SECRET", "test-bootstrap-secret"); err != nil {
|
||||
t.Fatalf("set bootstrap secret failed: %v", err)
|
||||
}
|
||||
|
||||
id := atomic.AddInt64(&handlerDbCounter, 1)
|
||||
dsn := fmt.Sprintf("file:handlerdb_%d_%s?mode=memory&cache=shared", id, t.Name())
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||
@@ -64,6 +71,20 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
|
||||
t.Fatalf("db migration failed: %v", err)
|
||||
}
|
||||
|
||||
adminRole := &domain.Role{Code: "admin", Name: "管理员", Status: domain.RoleStatusEnabled}
|
||||
if err := db.Create(adminRole).Error; err != nil {
|
||||
t.Fatalf("seed admin role failed: %v", err)
|
||||
}
|
||||
for _, permission := range domain.DefaultPermissions() {
|
||||
perm := permission
|
||||
if err := db.Create(&perm).Error; err != nil {
|
||||
t.Fatalf("seed permission %s failed: %v", perm.Code, err)
|
||||
}
|
||||
if err := db.Create(&domain.RolePermission{RoleID: adminRole.ID, PermissionID: perm.ID}).Error; err != nil {
|
||||
t.Fatalf("seed role permission %s failed: %v", perm.Code, err)
|
||||
}
|
||||
}
|
||||
|
||||
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||
HS256Secret: "test-handler-secret-key",
|
||||
AccessTokenExpire: 15 * time.Minute,
|
||||
@@ -136,6 +157,11 @@ func setupHandlerTestServer(t *testing.T) (*httptest.Server, func()) {
|
||||
server := httptest.NewServer(engine)
|
||||
return server, func() {
|
||||
server.Close()
|
||||
if hadBootstrapSecret {
|
||||
_ = os.Setenv("BOOTSTRAP_SECRET", previousBootstrapSecret)
|
||||
} else {
|
||||
_ = os.Unsetenv("BOOTSTRAP_SECRET")
|
||||
}
|
||||
if sqlDB, _ := db.DB(); sqlDB != nil {
|
||||
sqlDB.Close()
|
||||
}
|
||||
@@ -207,6 +233,35 @@ func registerUser(baseURL, username, email, password string) bool {
|
||||
return resp.StatusCode == http.StatusCreated
|
||||
}
|
||||
|
||||
func bootstrapAdminToken(baseURL, username, email, password string) string {
|
||||
payload, _ := json.Marshal(map[string]interface{}{
|
||||
"username": username,
|
||||
"email": email,
|
||||
"password": password,
|
||||
})
|
||||
req, _ := http.NewRequest("POST", baseURL+"/api/v1/auth/bootstrap-admin", bytes.NewReader(payload))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("X-Bootstrap-Secret", "test-bootstrap-secret")
|
||||
resp, err := (&http.Client{}).Do(req)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusCreated {
|
||||
return ""
|
||||
}
|
||||
var result map[string]interface{}
|
||||
if err := json.Unmarshal(bodyBytes, &result); err != nil {
|
||||
return ""
|
||||
}
|
||||
data, ok := result["data"].(map[string]interface{})
|
||||
if !ok || data["access_token"] == nil {
|
||||
return ""
|
||||
}
|
||||
return data["access_token"].(string)
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// Auth Handler Tests
|
||||
// =============================================================================
|
||||
@@ -292,6 +347,89 @@ func TestAuthHandler_Login_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_Login_SetsSessionCookies(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "cookieuser", "cookie@example.com", "Password123!")
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/login", "", map[string]interface{}{
|
||||
"account": "cookieuser",
|
||||
"password": "Password123!",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
cookies := resp.Cookies()
|
||||
var hasRefreshCookie bool
|
||||
var hasPresenceCookie bool
|
||||
for _, cookie := range cookies {
|
||||
switch cookie.Name {
|
||||
case "ums_refresh_token":
|
||||
hasRefreshCookie = cookie.HttpOnly && cookie.Value != ""
|
||||
case "ums_session_present":
|
||||
hasPresenceCookie = !cookie.HttpOnly && cookie.Value == "1"
|
||||
}
|
||||
}
|
||||
if !hasRefreshCookie {
|
||||
t.Fatalf("expected login response to set ums_refresh_token cookie, got %#v", cookies)
|
||||
}
|
||||
if !hasPresenceCookie {
|
||||
t.Fatalf("expected login response to set ums_session_present cookie, got %#v", cookies)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_RefreshToken_UsesCookieFallback(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "refreshcookieuser", "refreshcookie@example.com", "Password123!")
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("cookiejar.New() error: %v", err)
|
||||
}
|
||||
client := &http.Client{Jar: jar}
|
||||
|
||||
loginBody, _ := json.Marshal(map[string]interface{}{
|
||||
"account": "refreshcookieuser",
|
||||
"password": "Password123!",
|
||||
})
|
||||
loginReq, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/login", bytes.NewReader(loginBody))
|
||||
loginReq.Header.Set("Content-Type", "application/json")
|
||||
loginResp, err := client.Do(loginReq)
|
||||
if err != nil {
|
||||
t.Fatalf("login request failed: %v", err)
|
||||
}
|
||||
defer loginResp.Body.Close()
|
||||
if loginResp.StatusCode != http.StatusOK {
|
||||
payload, _ := io.ReadAll(loginResp.Body)
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, loginResp.StatusCode, string(payload))
|
||||
}
|
||||
|
||||
refreshReq, _ := http.NewRequest("POST", server.URL+"/api/v1/auth/refresh", nil)
|
||||
refreshReq.Header.Set("Content-Type", "application/json")
|
||||
refreshResp, err := client.Do(refreshReq)
|
||||
if err != nil {
|
||||
t.Fatalf("refresh request failed: %v", err)
|
||||
}
|
||||
defer refreshResp.Body.Close()
|
||||
refreshPayload, _ := io.ReadAll(refreshResp.Body)
|
||||
if refreshResp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusOK, refreshResp.StatusCode, string(refreshPayload))
|
||||
}
|
||||
|
||||
var parsed map[string]interface{}
|
||||
if err := json.Unmarshal(refreshPayload, &parsed); err != nil {
|
||||
t.Fatalf("refresh response json unmarshal failed: %v", err)
|
||||
}
|
||||
data, _ := parsed["data"].(map[string]interface{})
|
||||
if data == nil || data["access_token"] == nil || data["refresh_token"] == nil {
|
||||
t.Fatalf("expected refresh response to include token pair, got %v", parsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_Login_WrongPassword(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
@@ -336,33 +474,61 @@ func TestAuthHandler_BootstrapAdmin_MissingSecret(t *testing.T) {
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
// Without BOOTSTRAP_SECRET env var set, should get forbidden
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("expected status %d for missing bootstrap secret, got %d", http.StatusForbidden, resp.StatusCode)
|
||||
// P0 修复后:已配置 BOOTSTRAP_SECRET 但未提供 header,应返回 401
|
||||
if resp.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected status %d for missing bootstrap secret header, got %d", http.StatusUnauthorized, resp.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthHandler_GetAuthCapabilities(t *testing.T) {
|
||||
func TestAuthHandler_VerifyTOTPAfterPasswordLogin_RequiresTempToken(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/auth/capabilities", "")
|
||||
resp, body := doPost(server.URL+"/api/v1/auth/login/totp-verify", "", map[string]interface{}{
|
||||
"user_id": 1,
|
||||
"code": "123456",
|
||||
"device_id": "device-1",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
if resp.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusBadRequest, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// User Handler Tests
|
||||
// =============================================================================
|
||||
func TestAuthHandler_UnconfiguredOAuthAndBindingsFailClosed(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "failclosed", "failclosed@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "failclosed", "AdminPass123!")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
body map[string]interface{}
|
||||
}{
|
||||
{name: "oauth login", url: server.URL + "/api/v1/auth/oauth/github"},
|
||||
{name: "email bind code", url: server.URL + "/api/v1/users/me/bind-email/code", body: map[string]interface{}{"email": "bind@example.com"}},
|
||||
{name: "social bind", url: server.URL + "/api/v1/users/me/bind-social", body: map[string]interface{}{"provider": "github"}},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var resp *http.Response
|
||||
var body string
|
||||
if tc.body == nil {
|
||||
resp, body = doGet(tc.url, token)
|
||||
} else {
|
||||
resp, body = doPost(tc.url, token, tc.body)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusServiceUnavailable {
|
||||
t.Fatalf("expected status %d, got %d, body: %s", http.StatusServiceUnavailable, resp.StatusCode, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_CreateUser_RequiresAdmin(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
@@ -400,39 +566,33 @@ func TestUserHandler_CreateUser_Unauthorized(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_ListUsers_Success(t *testing.T) {
|
||||
func TestUserHandler_ListUsers_ForbiddenForRegularUser(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "listadmin", "listadmin@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "listadmin", "AdminPass123!")
|
||||
registerUser(server.URL, "listuser", "listuser@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "listuser", "AdminPass123!")
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/users?page=1&page_size=10", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
|
||||
var result map[string]interface{}
|
||||
json.Unmarshal([]byte(body), &result)
|
||||
if result["code"] != float64(0) {
|
||||
t.Errorf("expected code 0, got %v", result["code"])
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_GetUser_Success(t *testing.T) {
|
||||
func TestUserHandler_GetUser_ForbiddenForRegularUser(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "getadmin", "getadmin@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "getadmin", "AdminPass123!")
|
||||
registerUser(server.URL, "getuser", "getuser@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "getuser", "AdminPass123!")
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/users/1", token)
|
||||
resp, body := doGet(server.URL+"/api/v1/users/1", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d", http.StatusOK, resp.StatusCode)
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -440,8 +600,8 @@ func TestUserHandler_UpdateUser_Success(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "updateadmin", "updateadmin@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "updateadmin", "AdminPass123!")
|
||||
registerUser(server.URL, "updateuser", "update@example.com", "UserPass123!")
|
||||
token := getToken(server.URL, "updateuser", "UserPass123!")
|
||||
|
||||
resp, body := doPut(server.URL+"/api/v1/users/1", token, map[string]string{"nickname": "Updated Nickname"})
|
||||
defer resp.Body.Close()
|
||||
@@ -451,6 +611,43 @@ func TestUserHandler_UpdateUser_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_UpdateUser_AdminCanUpdateOther(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
token := bootstrapAdminToken(server.URL, "updateadmin", "updateadmin@test.com", "AdminPass123!")
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin token should succeed")
|
||||
}
|
||||
registerUser(server.URL, "manageduser", "manageduser@test.com", "UserPass123!")
|
||||
|
||||
resp, body := doPut(server.URL+"/api/v1/users/2", token, map[string]string{"nickname": "Admin Updated"})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_UpdatePassword_NonAdminCannotUpdateOther(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "pwd-user-1", "pwd-user-1@test.com", "UserPass123!")
|
||||
token := getToken(server.URL, "pwd-user-1", "UserPass123!")
|
||||
registerUser(server.URL, "pwd-user-2", "pwd-user-2@test.com", "TargetPass123!")
|
||||
|
||||
resp, body := doPut(server.URL+"/api/v1/users/2/password", token, map[string]string{
|
||||
"old_password": "TargetPass123!",
|
||||
"new_password": "TargetNew456!",
|
||||
})
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusForbidden, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_DeleteUser_NonAdmin_Forbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
@@ -471,8 +668,10 @@ func TestUserHandler_SearchUsers_Success(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "searchadmin", "searchadmin@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "searchadmin", "AdminPass123!")
|
||||
token := bootstrapAdminToken(server.URL, "searchadmin", "searchadmin@test.com", "AdminPass123!")
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin token should succeed")
|
||||
}
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/users/1", token)
|
||||
defer resp.Body.Close()
|
||||
@@ -515,6 +714,24 @@ func TestUserHandler_GetUserRoles_Success(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_GetUserRoles_AdminCanViewOther(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
token := bootstrapAdminToken(server.URL, "rolesbootstrap", "rolesbootstrap@test.com", "AdminPass123!")
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin token should succeed")
|
||||
}
|
||||
registerUser(server.URL, "role-target", "role-target@test.com", "UserPass123!")
|
||||
|
||||
resp, body := doGet(server.URL+"/api/v1/users/2/roles", token)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, resp.StatusCode, body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserHandler_AssignRoles_RequiresAdmin(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
@@ -974,8 +1191,10 @@ func TestInvalidUserID_ReturnsBadRequest(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "invalidid", "invalidid@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "invalidid", "AdminPass123!")
|
||||
token := bootstrapAdminToken(server.URL, "invalidid", "invalidid@test.com", "AdminPass123!")
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin token should succeed")
|
||||
}
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/users/invalid", token)
|
||||
defer resp.Body.Close()
|
||||
@@ -989,8 +1208,10 @@ func TestNonExistentUserID_ReturnsNotFound(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
registerUser(server.URL, "notfound", "notfound@test.com", "AdminPass123!")
|
||||
token := getToken(server.URL, "notfound", "AdminPass123!")
|
||||
token := bootstrapAdminToken(server.URL, "notfound", "notfound@test.com", "AdminPass123!")
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin token should succeed")
|
||||
}
|
||||
|
||||
resp, _ := doGet(server.URL+"/api/v1/users/99999", token)
|
||||
defer resp.Body.Close()
|
||||
@@ -1350,6 +1571,29 @@ func TestAvatarHandler_UploadAvatar_NonAdminCannotUpdateOther(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAvatarHandler_UploadAvatar_AdminCanUpdateOther(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
token := bootstrapAdminToken(server.URL, "avataradmin", "avataradmin@test.com", "AdminPass123!")
|
||||
if token == "" {
|
||||
t.Fatal("bootstrap admin token should succeed")
|
||||
}
|
||||
registerUser(server.URL, "avatar-target", "avatar-target@test.com", "UserPass123!")
|
||||
|
||||
fileContent := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A}
|
||||
resp, err := doUploadFile(server.URL+"/api/v1/users/2/avatar", token, "avatar", "test.png", fileContent)
|
||||
if err != nil {
|
||||
t.Fatalf("upload request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
t.Fatalf("expected status %d for admin updating other's avatar, got %d, body: %s", http.StatusOK, resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAvatarHandler_UploadAvatar_UserNotFoundOrForbidden(t *testing.T) {
|
||||
server, cleanup := setupHandlerTestServer(t)
|
||||
defer cleanup()
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
apimiddleware "github.com/user-management-system/internal/api/middleware"
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/service"
|
||||
@@ -187,16 +188,7 @@ func (h *UserHandler) UpdateUser(c *gin.Context) {
|
||||
|
||||
// Authorization: only self or admin can update user profile
|
||||
currentUserID := c.GetInt64("user_id")
|
||||
isAdmin := false
|
||||
if roles, ok := c.Get("user_roles"); ok {
|
||||
for _, role := range roles.([]*domain.Role) {
|
||||
if role.Code == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if currentUserID != id && !isAdmin {
|
||||
if currentUserID != id && !apimiddleware.IsAdmin(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return
|
||||
}
|
||||
@@ -289,6 +281,12 @@ func (h *UserHandler) UpdatePassword(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
currentUserID := c.GetInt64("user_id")
|
||||
if currentUserID != id && !apimiddleware.IsAdmin(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return
|
||||
}
|
||||
|
||||
if err := h.userService.ChangePassword(c.Request.Context(), id, req.OldPassword, req.NewPassword); err != nil {
|
||||
handleError(c, err)
|
||||
return
|
||||
@@ -370,16 +368,7 @@ func (h *UserHandler) GetUserRoles(c *gin.Context) {
|
||||
|
||||
// Authorization: only self or admin can view user roles
|
||||
currentUserID := c.GetInt64("user_id")
|
||||
isAdmin := false
|
||||
if roles, ok := c.Get("user_roles"); ok {
|
||||
for _, role := range roles.([]*domain.Role) {
|
||||
if role.Code == "admin" {
|
||||
isAdmin = true
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if currentUserID != id && !isAdmin {
|
||||
if currentUserID != id && !apimiddleware.IsAdmin(c) {
|
||||
c.JSON(http.StatusForbidden, gin.H{"code": 403, "message": "permission denied"})
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -89,6 +90,12 @@ func (m *RateLimitMiddleware) Refresh() gin.HandlerFunc {
|
||||
}
|
||||
|
||||
func (m *RateLimitMiddleware) limitForKey(key string, windowSeconds int, capacity int64) gin.HandlerFunc {
|
||||
if os.Getenv("DISABLE_RATE_LIMIT") == "1" {
|
||||
return func(c *gin.Context) {
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
limiter := m.getOrCreateLimiter(key, time.Duration(windowSeconds)*time.Second, capacity)
|
||||
|
||||
return func(c *gin.Context) {
|
||||
|
||||
@@ -142,6 +142,7 @@ func (r *Router) Setup() *gin.Engine {
|
||||
authGroup.POST("/login/totp-verify", r.rateLimitMiddleware.Login(), r.authHandler.VerifyTOTPAfterPasswordLogin)
|
||||
authGroup.POST("/refresh", r.rateLimitMiddleware.Refresh(), r.authHandler.RefreshToken)
|
||||
authGroup.GET("/capabilities", r.authHandler.GetAuthCapabilities)
|
||||
authGroup.GET("/csrf-token", r.authHandler.GetCSRFToken)
|
||||
|
||||
authGroup.POST("/activate-email", r.authHandler.ActivateEmail)
|
||||
authGroup.POST("/resend-activation", r.authHandler.ResendActivationEmail)
|
||||
@@ -189,7 +190,6 @@ func (r *Router) Setup() *gin.Engine {
|
||||
protected.Use(r.authMiddleware.Required())
|
||||
protected.Use(r.rateLimitMiddleware.API())
|
||||
{
|
||||
protected.GET("/auth/csrf-token", r.authHandler.GetCSRFToken)
|
||||
protected.POST("/auth/logout", r.authHandler.Logout)
|
||||
protected.GET("/auth/userinfo", r.authHandler.GetUserInfo)
|
||||
|
||||
@@ -206,8 +206,8 @@ func (r *Router) Setup() *gin.Engine {
|
||||
users := protected.Group("/users")
|
||||
{
|
||||
users.POST("", middleware.RequirePermission("user:manage"), r.userHandler.CreateUser)
|
||||
users.GET("", r.userHandler.ListUsers)
|
||||
users.GET("/:id", r.userHandler.GetUser)
|
||||
users.GET("", middleware.RequirePermission("user:manage"), r.userHandler.ListUsers)
|
||||
users.GET("/:id", middleware.RequirePermission("user:manage"), r.userHandler.GetUser)
|
||||
users.PUT("/:id", r.userHandler.UpdateUser)
|
||||
users.DELETE("/:id", middleware.RequirePermission("user:delete"), r.userHandler.DeleteUser)
|
||||
users.PUT("/:id/password", r.userHandler.UpdatePassword)
|
||||
|
||||
@@ -54,6 +54,7 @@ type Claims struct {
|
||||
Remember bool `json:"remember,omitempty"` // 记住登录标记
|
||||
JTI string `json:"jti"` // JWT ID,用于黑名单
|
||||
PCE int64 `json:"pce,omitempty"` // Password Changed Epoch,密码变更时间戳,用于 token 失效机制
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
|
||||
@@ -494,6 +495,47 @@ func (j *JWT) ValidateRefreshToken(tokenString string) (*Claims, error) {
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (j *JWT) GenerateTOTPChallengeToken(userID int64, username, deviceID string, pce int64) (string, error) {
|
||||
if err := j.ensureReady(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
jti, err := generateJTI()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
claims := Claims{
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
Type: "totp_challenge",
|
||||
JTI: jti,
|
||||
PCE: pce,
|
||||
DeviceID: strings.TrimSpace(deviceID),
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(5 * time.Minute)),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
NotBefore: jwt.NewNumericDate(now),
|
||||
},
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(j.signingMethod(), claims)
|
||||
return token.SignedString(j.signingKey())
|
||||
}
|
||||
|
||||
func (j *JWT) ValidateTOTPChallengeToken(tokenString string) (*Claims, error) {
|
||||
claims, err := j.ParseToken(tokenString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if claims.Type != "totp_challenge" {
|
||||
return nil, errors.New("invalid token type")
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// RefreshAccessToken 刷新访问令牌
|
||||
func (j *JWT) RefreshAccessToken(refreshTokenString string) (string, error) {
|
||||
claims, err := j.ValidateRefreshToken(refreshTokenString)
|
||||
|
||||
@@ -122,7 +122,7 @@ type LoginResponse struct {
|
||||
ExpiresIn int64 `json:"expires_in,omitempty"`
|
||||
User *UserInfo `json:"user,omitempty"`
|
||||
// RequiresTOTP 指示登录需要额外的TOTP验证(当设备未信任时)
|
||||
RequiresTOTP bool `json:"requires_totp,omitempty"`
|
||||
RequiresTOTP bool `json:"requires_totp,omitempty"`
|
||||
// TempToken 临时令牌,用于TOTP验证阶段(短生命周期,不可用于常规API)
|
||||
TempToken string `json:"temp_token,omitempty"`
|
||||
// UserID 当RequiresTOTP为true时返回,用于后续TOTP验证
|
||||
@@ -759,11 +759,16 @@ func (s *AuthService) Login(ctx context.Context, req *LoginRequest, ip string) (
|
||||
|
||||
// P0-07 安全修复:检查是否需要TOTP验证(用户启用了TOTP且设备未信任)
|
||||
if s.isTOTPRequiredForLogin(ctx, user, req.DeviceID) {
|
||||
tempToken, err := s.jwtManager.GenerateTOTPChallengeToken(user.ID, user.Username, req.DeviceID, user.PasswordChangedAt.Unix())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 返回RequiresTOTP指示前端需要完成TOTP验证
|
||||
// 前端应调用 /auth/login/totp-verify 接口完成验证
|
||||
return &LoginResponse{
|
||||
RequiresTOTP: true,
|
||||
UserID: user.ID,
|
||||
TempToken: tempToken,
|
||||
UserID: user.ID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -808,10 +813,27 @@ func (s *AuthService) isTOTPRequiredForLogin(ctx context.Context, user *domain.U
|
||||
// VerifyTOTPAfterPasswordLogin 完成密码登录后的TOTP验证
|
||||
// 当用户启用了TOTP但设备未信任时,密码登录会返回RequiresTOTP=true
|
||||
// 前端需要调用此接口完成TOTP验证以获取令牌
|
||||
func (s *AuthService) VerifyTOTPAfterPasswordLogin(ctx context.Context, userID int64, totpCode, deviceID string) (*LoginResponse, error) {
|
||||
func (s *AuthService) VerifyTOTPAfterPasswordLogin(ctx context.Context, userID int64, totpCode, deviceID, tempToken string) (*LoginResponse, error) {
|
||||
if s == nil {
|
||||
return nil, errors.New("auth service is not initialized")
|
||||
}
|
||||
if s.jwtManager == nil {
|
||||
return nil, errors.New("jwt manager is not configured")
|
||||
}
|
||||
|
||||
claims, err := s.jwtManager.ValidateTOTPChallengeToken(strings.TrimSpace(tempToken))
|
||||
if err != nil {
|
||||
return nil, errors.New("TOTP challenge is invalid or expired")
|
||||
}
|
||||
if claims == nil || claims.UserID != userID {
|
||||
return nil, errors.New("TOTP challenge does not match user")
|
||||
}
|
||||
if strings.TrimSpace(claims.DeviceID) != strings.TrimSpace(deviceID) {
|
||||
return nil, errors.New("TOTP challenge does not match device")
|
||||
}
|
||||
if s.IsTokenBlacklisted(ctx, claims.JTI) {
|
||||
return nil, errors.New("TOTP challenge has already been used")
|
||||
}
|
||||
|
||||
user, err := s.userRepo.GetByID(ctx, userID)
|
||||
if err != nil {
|
||||
@@ -826,6 +848,9 @@ func (s *AuthService) VerifyTOTPAfterPasswordLogin(ctx context.Context, userID i
|
||||
if err := s.VerifyTOTP(ctx, userID, totpCode, deviceID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.blacklistTokenClaims(ctx, tempToken, s.jwtManager.ValidateTOTPChallengeToken); err != nil {
|
||||
return nil, fmt.Errorf("totp challenge revocation failed: %w", err)
|
||||
}
|
||||
|
||||
// TOTP验证成功,返回完整登录响应
|
||||
return s.generateLoginResponseWithoutRemember(ctx, user)
|
||||
@@ -902,18 +927,22 @@ func (s *AuthService) Logout(ctx context.Context, username string, req *LogoutRe
|
||||
return nil
|
||||
}
|
||||
|
||||
_ = s.blacklistTokenClaims(ctx, req.AccessToken, func(token string) (*auth.Claims, error) {
|
||||
if err := s.blacklistTokenClaims(ctx, req.AccessToken, func(token string) (*auth.Claims, error) {
|
||||
if s.jwtManager == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.jwtManager.ValidateAccessToken(token)
|
||||
})
|
||||
_ = s.blacklistTokenClaims(ctx, req.RefreshToken, func(token string) (*auth.Claims, error) {
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := s.blacklistTokenClaims(ctx, req.RefreshToken, func(token string) (*auth.Claims, error) {
|
||||
if s.jwtManager == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return s.jwtManager.ValidateRefreshToken(token)
|
||||
})
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if strings.TrimSpace(username) != "" {
|
||||
s.publishEvent(ctx, domain.EventUserLogout, map[string]interface{}{
|
||||
|
||||
@@ -157,6 +157,41 @@ func TestAuthService_Login(t *testing.T) {
|
||||
t.Error("nil service should return error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("login with totp enabled returns temporary challenge token", func(t *testing.T) {
|
||||
req := &service.RegisterRequest{
|
||||
Username: "totploginuser",
|
||||
Password: "Test123!",
|
||||
Email: "totplogin@test.com",
|
||||
}
|
||||
user, err := env.authSvc.Register(ctx, req)
|
||||
if err != nil {
|
||||
t.Fatalf("Register failed: %v", err)
|
||||
}
|
||||
if err := env.db.Model(&domain.User{}).Where("id = ?", user.ID).Updates(map[string]interface{}{
|
||||
"totp_enabled": true,
|
||||
"totp_secret": "JBSWY3DPEHPK3PXP",
|
||||
}).Error; err != nil {
|
||||
t.Fatalf("enable totp failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := env.authSvc.Login(ctx, &service.LoginRequest{
|
||||
Username: "totploginuser",
|
||||
Password: "Test123!",
|
||||
}, "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("Login failed: %v", err)
|
||||
}
|
||||
if !resp.RequiresTOTP {
|
||||
t.Fatal("expected requires_totp response")
|
||||
}
|
||||
if resp.TempToken == "" {
|
||||
t.Fatal("expected temp_token for second-factor challenge")
|
||||
}
|
||||
if resp.AccessToken != "" || resp.RefreshToken != "" {
|
||||
t.Fatal("totp challenge should not mint full session tokens before second factor verification")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthService_Register(t *testing.T) {
|
||||
|
||||
87
internal/service/auth_logout_failclosed_test.go
Normal file
87
internal/service/auth_logout_failclosed_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/cache"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
gormsqlite "gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
)
|
||||
|
||||
type failingL2Cache struct {
|
||||
setErr error
|
||||
}
|
||||
|
||||
func (f *failingL2Cache) Set(ctx context.Context, key string, value interface{}, ttl time.Duration) error {
|
||||
return f.setErr
|
||||
}
|
||||
func (f *failingL2Cache) Get(ctx context.Context, key string) (interface{}, error) { return nil, nil }
|
||||
func (f *failingL2Cache) Delete(ctx context.Context, key string) error { return nil }
|
||||
func (f *failingL2Cache) Exists(ctx context.Context, key string) (bool, error) { return false, nil }
|
||||
func (f *failingL2Cache) Clear(ctx context.Context) error { return nil }
|
||||
func (f *failingL2Cache) Increment(ctx context.Context, key string, delta int64, ttl time.Duration) (int64, error) {
|
||||
return 0, nil
|
||||
}
|
||||
func (f *failingL2Cache) Close() error { return nil }
|
||||
|
||||
func TestAuthService_Logout_FailsClosedWhenBlacklistWriteFails(t *testing.T) {
|
||||
dsn := fmt.Sprintf("file:logoutfailclosed_%d?mode=memory&cache=shared", time.Now().UnixNano())
|
||||
db, err := gorm.Open(gormsqlite.New(gormsqlite.Config{DriverName: "sqlite", DSN: dsn}), &gorm.Config{
|
||||
Logger: logger.Default.LogMode(logger.Silent),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("open db failed: %v", err)
|
||||
}
|
||||
if err := db.AutoMigrate(&domain.User{}, &domain.Role{}, &domain.UserRole{}, &domain.LoginLog{}, &domain.PasswordHistory{}); err != nil {
|
||||
t.Fatalf("migrate failed: %v", err)
|
||||
}
|
||||
for _, role := range domain.PredefinedRoles {
|
||||
roleCopy := role
|
||||
if err := db.Create(&roleCopy).Error; err != nil {
|
||||
t.Fatalf("seed role %s failed: %v", role.Code, err)
|
||||
}
|
||||
}
|
||||
|
||||
jwtManager, err := auth.NewJWTWithOptions(auth.JWTOptions{
|
||||
HS256Secret: fmt.Sprintf("logout-failclosed-secret-%d", time.Now().UnixNano()),
|
||||
AccessTokenExpire: 15 * time.Minute,
|
||||
RefreshTokenExpire: 7 * 24 * time.Hour,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("create jwt manager failed: %v", err)
|
||||
}
|
||||
|
||||
userRepo := repository.NewUserRepository(db)
|
||||
userRoleRepo := repository.NewUserRoleRepository(db)
|
||||
roleRepo := repository.NewRoleRepository(db)
|
||||
cacheManager := cache.NewCacheManager(cache.NewL1Cache(), &failingL2Cache{setErr: errors.New("forced blacklist failure")})
|
||||
|
||||
authSvc := NewAuthService(userRepo, nil, jwtManager, cacheManager, 8, 5, 15*time.Minute)
|
||||
authSvc.SetRoleRepositories(userRoleRepo, roleRepo)
|
||||
|
||||
ctx := context.Background()
|
||||
if _, err := authSvc.Register(ctx, &RegisterRequest{Username: "logoutfail", Password: "Password123!"}); err != nil {
|
||||
t.Fatalf("register failed: %v", err)
|
||||
}
|
||||
loginResp, err := authSvc.Login(ctx, &LoginRequest{Username: "logoutfail", Password: "Password123!"}, "127.0.0.1")
|
||||
if err != nil {
|
||||
t.Fatalf("login failed: %v", err)
|
||||
}
|
||||
|
||||
err = authSvc.Logout(ctx, "logoutfail", &LogoutRequest{AccessToken: loginResp.AccessToken, RefreshToken: loginResp.RefreshToken})
|
||||
if err == nil {
|
||||
t.Fatal("expected logout to fail closed when blacklist write fails")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "forced blacklist failure") {
|
||||
t.Fatalf("expected propagated blacklist error, got: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -125,24 +125,49 @@ func (s *UserService) ChangePassword(ctx context.Context, userID int64, oldPassw
|
||||
return errors.New("密码哈希失败")
|
||||
}
|
||||
|
||||
// 保存新密码到历史记录(异步,不阻塞密码更新)
|
||||
if s.passwordHistoryRepo != nil {
|
||||
// #nosec G118 - 使用带超时的独立 context(不能使用请求 ctx,该 goroutine 在请求完成后仍可能运行)
|
||||
go func(hashedPw string) { // #nosec G118
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = s.passwordHistoryRepo.Create(bgCtx, &domain.PasswordHistory{
|
||||
UserID: userID,
|
||||
PasswordHash: hashedPw,
|
||||
})
|
||||
_ = s.passwordHistoryRepo.DeleteOldRecords(bgCtx, userID, passwordHistoryLimit)
|
||||
}(newHashedPassword)
|
||||
}
|
||||
|
||||
// 更新密码(使用同一哈希值)
|
||||
oldPasswordHash := user.Password
|
||||
oldPasswordChangedAt := user.PasswordChangedAt
|
||||
user.Password = newHashedPassword
|
||||
user.PasswordChangedAt = time.Now()
|
||||
return s.userRepo.Update(ctx, user)
|
||||
|
||||
if s.passwordHistoryRepo == nil {
|
||||
return s.userRepo.Update(ctx, user)
|
||||
}
|
||||
|
||||
return s.userRepo.DB().WithContext(ctx).Transaction(func(tx *gorm.DB) error {
|
||||
if err := tx.Model(&domain.User{}).
|
||||
Where("id = ?", user.ID).
|
||||
Updates(map[string]interface{}{"password": user.Password, "password_changed_at": user.PasswordChangedAt}).Error; err != nil {
|
||||
user.Password = oldPasswordHash
|
||||
user.PasswordChangedAt = oldPasswordChangedAt
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Create(&domain.PasswordHistory{UserID: userID, PasswordHash: newHashedPassword}).Error; err != nil {
|
||||
user.Password = oldPasswordHash
|
||||
user.PasswordChangedAt = oldPasswordChangedAt
|
||||
return err
|
||||
}
|
||||
|
||||
var ids []int64
|
||||
if err := tx.Model(&domain.PasswordHistory{}).
|
||||
Where("user_id = ?", userID).
|
||||
Order("created_at DESC").
|
||||
Limit(passwordHistoryLimit).
|
||||
Pluck("id", &ids).Error; err != nil {
|
||||
user.Password = oldPasswordHash
|
||||
user.PasswordChangedAt = oldPasswordChangedAt
|
||||
return err
|
||||
}
|
||||
if len(ids) > 0 {
|
||||
if err := tx.Where("user_id = ? AND id NOT IN ?", userID, ids).Delete(&domain.PasswordHistory{}).Error; err != nil {
|
||||
user.Password = oldPasswordHash
|
||||
user.PasswordChangedAt = oldPasswordChangedAt
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetByID 根据ID获取用户
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
"github.com/user-management-system/internal/repository"
|
||||
"github.com/user-management-system/internal/service"
|
||||
)
|
||||
|
||||
@@ -339,6 +340,32 @@ func TestUserService_ChangePassword(t *testing.T) {
|
||||
t.Error("Expected error for weak new password")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Change password persists history synchronously", func(t *testing.T) {
|
||||
hashedPassword, _ := auth.HashPassword("HistoryOld123!")
|
||||
user := &domain.User{
|
||||
Username: "historysync",
|
||||
Password: hashedPassword,
|
||||
Status: domain.UserStatusActive,
|
||||
}
|
||||
env.userSvc.Create(ctx, user)
|
||||
|
||||
if err := env.userSvc.ChangePassword(ctx, user.ID, "HistoryOld123!", "HistoryNew456!"); err != nil {
|
||||
t.Fatalf("ChangePassword failed: %v", err)
|
||||
}
|
||||
|
||||
historyRepo := repository.NewPasswordHistoryRepository(env.db)
|
||||
history, err := historyRepo.GetByUserID(ctx, user.ID, 10)
|
||||
if err != nil {
|
||||
t.Fatalf("GetByUserID failed: %v", err)
|
||||
}
|
||||
if len(history) == 0 {
|
||||
t.Fatal("expected password history to be written synchronously")
|
||||
}
|
||||
if !auth.VerifyPassword(history[0].PasswordHash, "HistoryNew456!") {
|
||||
t.Fatal("latest password history hash does not match new password")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserService_BatchUpdateStatus(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user