fix: P1/P2 优化 - OAuth验证 + API响应 + 缓存击穿 + Webhook关闭
P1 - OAuth auth_url origin 验证: - 添加 validateOAuthUrl() 函数验证 OAuth URL origin - 仅允许同源或可信 OAuth 提供商 - LoginPage 和 ProfileSecurityPage 调用前验证 P2 - API 响应运行时类型验证: - 添加 isApiResponse() 运行时验证函数 - parseJsonResponse 验证响应结构完整性 P2 - 缓存击穿防护 (singleflight): - AuthMiddleware.isJTIBlacklisted 使用 singleflight.Group - 防止 L1 miss 时并发请求同时打 L2 P2 - Webhook 服务优雅关闭: - WebhookService 添加 Shutdown() 方法 - 服务器关闭时等待 worker 完成 - main.go 集成 shutdown 调用
This commit is contained in:
@@ -215,6 +215,13 @@ func main() {
|
|||||||
|
|
||||||
log.Println("shutting down server...")
|
log.Println("shutting down server...")
|
||||||
|
|
||||||
|
// 关闭 Webhook 服务,等待投递任务完成
|
||||||
|
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer shutdownCancel()
|
||||||
|
if err := webhookService.Shutdown(shutdownCtx); err != nil {
|
||||||
|
log.Printf("webhook service shutdown: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
|||||||
@@ -1,9 +1,10 @@
|
|||||||
import { describe, expect, it } from 'vitest'
|
import { afterAll, describe, expect, it } from 'vitest'
|
||||||
|
|
||||||
import {
|
import {
|
||||||
buildOAuthCallbackReturnTo,
|
buildOAuthCallbackReturnTo,
|
||||||
parseOAuthCallbackHash,
|
parseOAuthCallbackHash,
|
||||||
sanitizeAuthRedirect,
|
sanitizeAuthRedirect,
|
||||||
|
validateOAuthUrl,
|
||||||
} from './oauth'
|
} from './oauth'
|
||||||
|
|
||||||
describe('oauth auth helpers', () => {
|
describe('oauth auth helpers', () => {
|
||||||
@@ -26,4 +27,40 @@ describe('oauth auth helpers', () => {
|
|||||||
message: '',
|
message: '',
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('validateOAuthUrl', () => {
|
||||||
|
const originalOrigin = window.location.origin
|
||||||
|
|
||||||
|
afterAll(() => {
|
||||||
|
// 恢复原始 origin
|
||||||
|
Object.defineProperty(window, 'location', {
|
||||||
|
value: { origin: originalOrigin },
|
||||||
|
writable: true,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('allows same-origin URLs', () => {
|
||||||
|
Object.defineProperty(window, 'location', {
|
||||||
|
value: { origin: 'http://localhost:3000' },
|
||||||
|
writable: true,
|
||||||
|
})
|
||||||
|
expect(validateOAuthUrl('http://localhost:3000/api/v1/auth/oauth/authorize')).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('allows trusted OAuth provider origins', () => {
|
||||||
|
expect(validateOAuthUrl('https://github.com/login/oauth/authorize')).toBe(true)
|
||||||
|
expect(validateOAuthUrl('https://google.com/oauth/authorize')).toBe(true)
|
||||||
|
expect(validateOAuthUrl('https://facebook.com/v1.0/oauth/authorize')).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('rejects untrusted origins', () => {
|
||||||
|
expect(validateOAuthUrl('https://evil.example.com/oauth/authorize')).toBe(false)
|
||||||
|
expect(validateOAuthUrl('https://attacker.com/callback')).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('rejects invalid URLs', () => {
|
||||||
|
expect(validateOAuthUrl('not-a-url')).toBe(false)
|
||||||
|
expect(validateOAuthUrl('')).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -6,6 +6,46 @@ export function sanitizeAuthRedirect(target: string | null | undefined, fallback
|
|||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// 可信的 OAuth 提供商 origin 白名单
|
||||||
|
const TRUSTED_OAUTH_ORIGINS = new Set([
|
||||||
|
// 社交登录提供商
|
||||||
|
'https://github.com',
|
||||||
|
'https://google.com',
|
||||||
|
'https://facebook.com',
|
||||||
|
'https://twitter.com',
|
||||||
|
'https://apple.com',
|
||||||
|
'https://weixin.qq.com',
|
||||||
|
'https://qq.com',
|
||||||
|
'https://alipay.com',
|
||||||
|
'https://douyin.com',
|
||||||
|
])
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 验证 OAuth 授权 URL 的 origin 是否可信
|
||||||
|
* 防止开放重定向攻击
|
||||||
|
*/
|
||||||
|
export function validateOAuthUrl(authUrl: string): boolean {
|
||||||
|
try {
|
||||||
|
const url = new URL(authUrl)
|
||||||
|
|
||||||
|
// 允许同源(当前应用自身作为 OAuth 提供者的情况)
|
||||||
|
if (url.origin === window.location.origin) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否在可信 origin 白名单中
|
||||||
|
if (TRUSTED_OAUTH_ORIGINS.has(url.origin)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 拒绝所有其他 origin
|
||||||
|
return false
|
||||||
|
} catch {
|
||||||
|
// 无效的 URL 格式
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
export function buildOAuthCallbackReturnTo(redirectPath: string): string {
|
export function buildOAuthCallbackReturnTo(redirectPath: string): string {
|
||||||
const callbackUrl = new URL('/login/oauth/callback', window.location.origin)
|
const callbackUrl = new URL('/login/oauth/callback', window.location.origin)
|
||||||
if (redirectPath && redirectPath !== '/dashboard') {
|
if (redirectPath && redirectPath !== '/dashboard') {
|
||||||
|
|||||||
@@ -85,7 +85,41 @@ function createTimeoutSignal(signal?: AbortSignal): { signal: AbortSignal; clean
|
|||||||
}
|
}
|
||||||
|
|
||||||
async function parseJsonResponse<T>(response: Response): Promise<ApiResponse<T>> {
|
async function parseJsonResponse<T>(response: Response): Promise<ApiResponse<T>> {
|
||||||
return response.json() as Promise<ApiResponse<T>>
|
const raw = await response.json()
|
||||||
|
|
||||||
|
// 运行时验证响应结构
|
||||||
|
if (!isApiResponse(raw)) {
|
||||||
|
throw new Error('Invalid API response structure: missing required fields')
|
||||||
|
}
|
||||||
|
|
||||||
|
return raw as ApiResponse<T>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 运行时验证 API 响应结构
|
||||||
|
* 防止后端返回异常格式时导致运行时错误
|
||||||
|
*/
|
||||||
|
function isApiResponse(obj: unknown): obj is ApiResponse<unknown> {
|
||||||
|
if (typeof obj !== 'object' || obj === null) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
const r = obj as Record<string, unknown>
|
||||||
|
|
||||||
|
// 必须有 code 字段且为数字
|
||||||
|
if (typeof r.code !== 'number') {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 必须有 message 字段且为字符串
|
||||||
|
if (typeof r.message !== 'string') {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果有 data 字段,应该存在
|
||||||
|
// (data 可以是 undefined/null/任何类型,但我们允许这些值)
|
||||||
|
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
async function refreshAccessToken(): Promise<TokenBundle> {
|
async function refreshAccessToken(): Promise<TokenBundle> {
|
||||||
|
|||||||
@@ -31,7 +31,8 @@ import type { RcFile } from 'antd/es/upload'
|
|||||||
import dayjs from 'dayjs'
|
import dayjs from 'dayjs'
|
||||||
import { useAuth } from '@/app/providers/auth-context'
|
import { useAuth } from '@/app/providers/auth-context'
|
||||||
import { getErrorMessage } from '@/lib/errors'
|
import { getErrorMessage } from '@/lib/errors'
|
||||||
import { parseOAuthCallbackHash } from '@/lib/auth/oauth'
|
import { parseOAuthCallbackHash, validateOAuthUrl } from '@/lib/auth/oauth'
|
||||||
|
import { getDeviceFingerprint } from '@/lib/device-fingerprint'
|
||||||
import { PageLayout, ContentCard } from '@/components/layout'
|
import { PageLayout, ContentCard } from '@/components/layout'
|
||||||
import { PageHeader } from '@/components/common'
|
import { PageHeader } from '@/components/common'
|
||||||
import { getAuthCapabilities } from '@/services/auth'
|
import { getAuthCapabilities } from '@/services/auth'
|
||||||
@@ -198,6 +199,11 @@ export function ProfileSecurityPage() {
|
|||||||
totp_code: values.totp_code?.trim() || undefined,
|
totp_code: values.totp_code?.trim() || undefined,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// 验证 OAuth URL origin 防止开放重定向攻击
|
||||||
|
if (!validateOAuthUrl(result.auth_url)) {
|
||||||
|
throw new Error('Invalid OAuth authorization URL')
|
||||||
|
}
|
||||||
|
|
||||||
setBindVisible(false)
|
setBindVisible(false)
|
||||||
setActiveProvider(null)
|
setActiveProvider(null)
|
||||||
bindSocialForm.resetFields()
|
bindSocialForm.resetFields()
|
||||||
@@ -306,11 +312,8 @@ export function ProfileSecurityPage() {
|
|||||||
// If "remember device" is checked, trust the current device
|
// If "remember device" is checked, trust the current device
|
||||||
if (totpRememberDevice) {
|
if (totpRememberDevice) {
|
||||||
try {
|
try {
|
||||||
const stored = localStorage.getItem('device_fingerprint')
|
const deviceInfo = getDeviceFingerprint()
|
||||||
if (stored) {
|
await trustDeviceByDeviceId(deviceInfo.device_id, '30d')
|
||||||
const deviceInfo = JSON.parse(stored)
|
|
||||||
await trustDeviceByDeviceId(deviceInfo.device_id, '30d')
|
|
||||||
}
|
|
||||||
} catch {
|
} catch {
|
||||||
// Non-critical: device trust failed, but TOTP was enabled
|
// Non-critical: device trust failed, but TOTP was enabled
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,8 +11,9 @@ import {
|
|||||||
|
|
||||||
import { useAuth } from '@/app/providers/auth-context'
|
import { useAuth } from '@/app/providers/auth-context'
|
||||||
import { AuthLayout } from '@/layouts'
|
import { AuthLayout } from '@/layouts'
|
||||||
import { buildOAuthCallbackReturnTo, sanitizeAuthRedirect } from '@/lib/auth/oauth'
|
import { buildOAuthCallbackReturnTo, sanitizeAuthRedirect, validateOAuthUrl } from '@/lib/auth/oauth'
|
||||||
import { getErrorMessage, isFormValidationError } from '@/lib/errors'
|
import { getErrorMessage, isFormValidationError } from '@/lib/errors'
|
||||||
|
import { getDeviceFingerprint } from '@/lib/device-fingerprint'
|
||||||
import {
|
import {
|
||||||
getAuthCapabilities,
|
getAuthCapabilities,
|
||||||
getOAuthAuthorizationUrl,
|
getOAuthAuthorizationUrl,
|
||||||
@@ -52,34 +53,6 @@ type SmsCodeFormValues = {
|
|||||||
code: string
|
code: string
|
||||||
}
|
}
|
||||||
|
|
||||||
// 构建设备指纹
|
|
||||||
function buildDeviceFingerprint(): { device_id: string; device_name: string; device_browser: string; device_os: string } {
|
|
||||||
const ua = navigator.userAgent
|
|
||||||
let browser = 'Unknown'
|
|
||||||
let os = 'Unknown'
|
|
||||||
|
|
||||||
if (ua.includes('Chrome')) browser = 'Chrome'
|
|
||||||
else if (ua.includes('Firefox')) browser = 'Firefox'
|
|
||||||
else if (ua.includes('Safari')) browser = 'Safari'
|
|
||||||
else if (ua.includes('Edge')) browser = 'Edge'
|
|
||||||
|
|
||||||
if (ua.includes('Windows')) os = 'Windows'
|
|
||||||
else if (ua.includes('Mac')) os = 'macOS'
|
|
||||||
else if (ua.includes('Linux')) os = 'Linux'
|
|
||||||
else if (ua.includes('Android')) os = 'Android'
|
|
||||||
else if (ua.includes('iOS')) os = 'iOS'
|
|
||||||
|
|
||||||
// 使用随机ID作为设备唯一标识
|
|
||||||
const deviceId = `${browser}-${os}-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`
|
|
||||||
|
|
||||||
return {
|
|
||||||
device_id: deviceId,
|
|
||||||
device_name: `${browser} on ${os}`,
|
|
||||||
device_browser: browser,
|
|
||||||
device_os: os,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
export function LoginPage() {
|
export function LoginPage() {
|
||||||
const [activeTab, setActiveTab] = useState('password')
|
const [activeTab, setActiveTab] = useState('password')
|
||||||
const [loading, setLoading] = useState(false)
|
const [loading, setLoading] = useState(false)
|
||||||
@@ -165,6 +138,10 @@ export function LoginPage() {
|
|||||||
provider,
|
provider,
|
||||||
buildOAuthCallbackReturnTo(redirect),
|
buildOAuthCallbackReturnTo(redirect),
|
||||||
)
|
)
|
||||||
|
// 验证 OAuth URL origin 防止开放重定向攻击
|
||||||
|
if (!validateOAuthUrl(result.auth_url)) {
|
||||||
|
throw new Error('Invalid OAuth authorization URL')
|
||||||
|
}
|
||||||
window.location.assign(result.auth_url)
|
window.location.assign(result.auth_url)
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
message.error(getErrorMessage(error, '启动第三方登录失败'))
|
message.error(getErrorMessage(error, '启动第三方登录失败'))
|
||||||
@@ -175,9 +152,7 @@ export function LoginPage() {
|
|||||||
const handlePasswordLogin = useCallback(async (values: LoginFormValues) => {
|
const handlePasswordLogin = useCallback(async (values: LoginFormValues) => {
|
||||||
setLoading(true)
|
setLoading(true)
|
||||||
try {
|
try {
|
||||||
const deviceInfo = buildDeviceFingerprint()
|
const deviceInfo = getDeviceFingerprint()
|
||||||
// Store device info for "remember device" feature on TOTP enable
|
|
||||||
localStorage.setItem('device_fingerprint', JSON.stringify(deviceInfo))
|
|
||||||
const tokenBundle = await loginByPassword({
|
const tokenBundle = await loginByPassword({
|
||||||
username: values.username,
|
username: values.username,
|
||||||
password: values.password,
|
password: values.password,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"golang.org/x/sync/singleflight"
|
||||||
|
|
||||||
"github.com/user-management-system/internal/auth"
|
"github.com/user-management-system/internal/auth"
|
||||||
"github.com/user-management-system/internal/cache"
|
"github.com/user-management-system/internal/cache"
|
||||||
@@ -25,6 +26,7 @@ type AuthMiddleware struct {
|
|||||||
permissionRepo *repository.PermissionRepository
|
permissionRepo *repository.PermissionRepository
|
||||||
l1Cache *cache.L1Cache
|
l1Cache *cache.L1Cache
|
||||||
cacheManager *cache.CacheManager
|
cacheManager *cache.CacheManager
|
||||||
|
sfGroup singleflight.Group
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAuthMiddleware(
|
func NewAuthMiddleware(
|
||||||
@@ -116,12 +118,22 @@ func (m *AuthMiddleware) isJTIBlacklisted(jti string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
key := "jwt_blacklist:" + jti
|
key := "jwt_blacklist:" + jti
|
||||||
|
|
||||||
|
// 先检查 L1 缓存
|
||||||
if _, ok := m.l1Cache.Get(key); ok {
|
if _, ok := m.l1Cache.Get(key); ok {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// L1 miss 时使用 singleflight 防止缓存击穿
|
||||||
|
// 多个并发请求只会触发一次 L2 查询
|
||||||
if m.cacheManager != nil {
|
if m.cacheManager != nil {
|
||||||
if _, ok := m.cacheManager.Get(context.Background(), key); ok {
|
val, err, _ := m.sfGroup.Do(key, func() (interface{}, error) {
|
||||||
|
found, _ := m.cacheManager.Get(context.Background(), key)
|
||||||
|
return found, nil
|
||||||
|
})
|
||||||
|
if err == nil && val != nil {
|
||||||
|
// 回写 L1 缓存
|
||||||
|
m.l1Cache.Set(key, true, 5*time.Minute)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -122,6 +122,29 @@ func (s *WebhookService) startWorkers() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Shutdown 优雅关闭 Webhook 服务
|
||||||
|
// 等待所有处理中的投递任务完成,最多等待 timeout
|
||||||
|
func (s *WebhookService) Shutdown(ctx context.Context) error {
|
||||||
|
// 1. 停止接收新任务:关闭队列
|
||||||
|
close(s.queue)
|
||||||
|
|
||||||
|
// 2. 等待所有 worker 完成
|
||||||
|
done := make(chan struct{})
|
||||||
|
go func() {
|
||||||
|
s.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-done:
|
||||||
|
// 正常完成
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Publish 发布事件:找到订阅该事件的所有 Webhook,异步投递
|
// Publish 发布事件:找到订阅该事件的所有 Webhook,异步投递
|
||||||
func (s *WebhookService) Publish(ctx context.Context, eventType domain.WebhookEventType, data interface{}) {
|
func (s *WebhookService) Publish(ctx context.Context, eventType domain.WebhookEventType, data interface{}) {
|
||||||
if !s.config.Enabled {
|
if !s.config.Enabled {
|
||||||
@@ -270,7 +293,10 @@ func (s *WebhookService) recordDelivery(task *deliveryTask, statusCode int, body
|
|||||||
if success {
|
if success {
|
||||||
delivery.DeliveredAt = &now
|
delivery.DeliveredAt = &now
|
||||||
}
|
}
|
||||||
_ = s.repo.CreateDelivery(context.Background(), delivery)
|
// 使用带超时的独立 context,防止 DB 写入无限等待
|
||||||
|
writeCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
_ = s.repo.CreateDelivery(writeCtx, delivery)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CreateWebhook 创建 Webhook
|
// CreateWebhook 创建 Webhook
|
||||||
|
|||||||
Reference in New Issue
Block a user