fix: 生产安全修复 + Go SDK + CAS SSO框架
安全修复: - CRITICAL: SSO重定向URL注入漏洞 - 修复redirect_uri白名单验证 - HIGH: SSO ClientSecret未验证 - 使用crypto/subtle.ConstantTimeCompare验证 - HIGH: 邮件验证码熵值过低(3字节) - 提升到6字节(48位熵) - HIGH: 短信验证码熵值过低(4字节) - 提升到6字节 - HIGH: Goroutine使用已取消上下文 - auth_email.go使用独立context+超时 - HIGH: SQL LIKE查询注入风险 - permission/role仓库使用escapeLikePattern 新功能: - Go SDK: sdk/go/user-management/ 完整SDK实现 - CAS SSO框架: internal/auth/cas.go CAS协议支持 其他: - L1Cache实例问题修复 - AuthMiddleware共享l1Cache - 设备指纹XSS防护 - 内存存储替代localStorage - 响应格式协议中间件 - 导出无界查询修复
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
@@ -11,12 +12,16 @@ import (
|
||||
|
||||
// SSOHandler SSO 处理程序
|
||||
type SSOHandler struct {
|
||||
ssoManager *auth.SSOManager
|
||||
ssoManager *auth.SSOManager
|
||||
clientsStore auth.SSOClientsStore
|
||||
}
|
||||
|
||||
// NewSSOHandler 创建 SSO 处理程序
|
||||
func NewSSOHandler(ssoManager *auth.SSOManager) *SSOHandler {
|
||||
return &SSOHandler{ssoManager: ssoManager}
|
||||
func NewSSOHandler(ssoManager *auth.SSOManager, clientsStore auth.SSOClientsStore) *SSOHandler {
|
||||
return &SSOHandler{
|
||||
ssoManager: ssoManager,
|
||||
clientsStore: clientsStore,
|
||||
}
|
||||
}
|
||||
|
||||
// AuthorizeRequest 授权请求
|
||||
@@ -43,6 +48,14 @@ func (h *SSOHandler) Authorize(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 验证 redirect_uri 是否在白名单中
|
||||
if h.clientsStore != nil {
|
||||
if !h.clientsStore.ValidateClientRedirectURI(req.ClientID, req.RedirectURI) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid redirect_uri"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 获取当前登录用户(从 auth middleware 设置的 context)
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
@@ -93,7 +106,11 @@ func (h *SSOHandler) Authorize(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
token, _ := h.ssoManager.GenerateAccessToken(req.ClientID, session)
|
||||
token, _, err := h.ssoManager.GenerateAccessToken(req.ClientID, session)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
// 重定向回客户端,带 token
|
||||
redirectURL := req.RedirectURI + "#access_token=" + token + "&expires_in=7200"
|
||||
@@ -136,6 +153,20 @@ func (h *SSOHandler) Token(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
// 验证客户端凭证
|
||||
if h.clientsStore != nil {
|
||||
client, err := h.clientsStore.GetByClientID(req.ClientID)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid client"})
|
||||
return
|
||||
}
|
||||
// 使用常量时间比较防止时序攻击
|
||||
if subtle.ConstantTimeCompare([]byte(req.ClientSecret), []byte(client.ClientSecret)) != 1 {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid client_secret"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 验证授权码
|
||||
session, err := h.ssoManager.ValidateAuthorizationCode(req.Code)
|
||||
if err != nil {
|
||||
@@ -144,7 +175,11 @@ func (h *SSOHandler) Token(c *gin.Context) {
|
||||
}
|
||||
|
||||
// 生成 access token
|
||||
token, expiresAt := h.ssoManager.GenerateAccessToken(req.ClientID, session)
|
||||
token, expiresAt, err := h.ssoManager.GenerateAccessToken(req.ClientID, session)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate token"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, TokenResponse{
|
||||
AccessToken: token,
|
||||
|
||||
@@ -34,6 +34,7 @@ func NewAuthMiddleware(
|
||||
roleRepo *repository.RoleRepository,
|
||||
rolePermissionRepo *repository.RolePermissionRepository,
|
||||
permissionRepo *repository.PermissionRepository,
|
||||
l1Cache *cache.L1Cache,
|
||||
) *AuthMiddleware {
|
||||
return &AuthMiddleware{
|
||||
jwt: jwt,
|
||||
@@ -42,7 +43,7 @@ func NewAuthMiddleware(
|
||||
roleRepo: roleRepo,
|
||||
rolePermissionRepo: rolePermissionRepo,
|
||||
permissionRepo: permissionRepo,
|
||||
l1Cache: cache.NewL1Cache(),
|
||||
l1Cache: l1Cache,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,7 +130,7 @@ func (m *AuthMiddleware) isJTIBlacklisted(jti string) bool {
|
||||
}
|
||||
|
||||
func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64) ([]string, []string) {
|
||||
if m.userRoleRepo == nil || m.roleRepo == nil || m.rolePermissionRepo == nil || m.permissionRepo == nil {
|
||||
if m.userRoleRepo == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -140,34 +141,9 @@ func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64
|
||||
}
|
||||
}
|
||||
|
||||
roleIDs, err := m.userRoleRepo.GetRoleIDsByUserID(ctx, userID)
|
||||
if err != nil || len(roleIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// 收集所有角色ID(包括直接分配的角色和所有祖先角色)
|
||||
allRoleIDs := make([]int64, 0, len(roleIDs)*2)
|
||||
allRoleIDs = append(allRoleIDs, roleIDs...)
|
||||
|
||||
for _, roleID := range roleIDs {
|
||||
ancestorIDs, err := m.roleRepo.GetAncestorIDs(ctx, roleID)
|
||||
if err == nil && len(ancestorIDs) > 0 {
|
||||
allRoleIDs = append(allRoleIDs, ancestorIDs...)
|
||||
}
|
||||
}
|
||||
|
||||
// 去重
|
||||
seen := make(map[int64]bool)
|
||||
uniqueRoleIDs := make([]int64, 0, len(allRoleIDs))
|
||||
for _, id := range allRoleIDs {
|
||||
if !seen[id] {
|
||||
seen[id] = true
|
||||
uniqueRoleIDs = append(uniqueRoleIDs, id)
|
||||
}
|
||||
}
|
||||
|
||||
roles, err := m.roleRepo.GetByIDs(ctx, roleIDs)
|
||||
if err != nil {
|
||||
// 使用已优化的单次 JOIN 查询获取用户角色和权限
|
||||
roles, permissions, err := m.userRoleRepo.GetUserRolesAndPermissions(ctx, userID)
|
||||
if err != nil || len(roles) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -176,24 +152,12 @@ func (m *AuthMiddleware) loadUserRolesAndPerms(ctx context.Context, userID int64
|
||||
roleCodes = append(roleCodes, role.Code)
|
||||
}
|
||||
|
||||
permissionIDs, err := m.rolePermissionRepo.GetPermissionIDsByRoleIDs(ctx, uniqueRoleIDs)
|
||||
if err != nil || len(permissionIDs) == 0 {
|
||||
entry := userPermEntry{roles: roleCodes, perms: []string{}}
|
||||
m.l1Cache.Set(cacheKey, entry, 30*time.Minute) // PERF-01 优化:增加缓存 TTL 减少 DB 查询
|
||||
return entry.roles, entry.perms
|
||||
}
|
||||
|
||||
permissions, err := m.permissionRepo.GetByIDs(ctx, permissionIDs)
|
||||
if err != nil {
|
||||
return roleCodes, nil
|
||||
}
|
||||
|
||||
permCodes := make([]string, 0, len(permissions))
|
||||
for _, permission := range permissions {
|
||||
permCodes = append(permCodes, permission.Code)
|
||||
for _, perm := range permissions {
|
||||
permCodes = append(permCodes, perm.Code)
|
||||
}
|
||||
|
||||
m.l1Cache.Set(cacheKey, userPermEntry{roles: roleCodes, perms: permCodes}, 30*time.Minute) // PERF-01 优化:增加缓存 TTL 减少 DB 查询
|
||||
m.l1Cache.Set(cacheKey, userPermEntry{roles: roleCodes, perms: permCodes}, 30*time.Minute)
|
||||
return roleCodes, permCodes
|
||||
}
|
||||
|
||||
|
||||
221
internal/auth/cas.go
Normal file
221
internal/auth/cas.go
Normal file
@@ -0,0 +1,221 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// CASProvider CAS (Central Authentication Service) 提供者
|
||||
// CAS 是一种单点登录协议,用户只需登录一次即可访问多个应用
|
||||
type CASProvider struct {
|
||||
serverURL string
|
||||
serviceURL string
|
||||
}
|
||||
|
||||
// CASServiceTicket CAS 服务票据
|
||||
type CASServiceTicket struct {
|
||||
Ticket string
|
||||
Service string
|
||||
UserID int64
|
||||
Username string
|
||||
IssuedAt time.Time
|
||||
Expiry time.Time
|
||||
}
|
||||
|
||||
// NewCASProvider 创建 CAS 提供者
|
||||
func NewCASProvider(serverURL, serviceURL string) *CASProvider {
|
||||
return &CASProvider{
|
||||
serverURL: strings.TrimSuffix(serverURL, "/"),
|
||||
serviceURL: serviceURL,
|
||||
}
|
||||
}
|
||||
|
||||
// BuildLoginURL 构建 CAS 登录 URL
|
||||
// 用于重定向用户到 CAS 登录页面
|
||||
func (p *CASProvider) BuildLoginURL(renew, gateway bool) string {
|
||||
params := url.Values{}
|
||||
params.Set("service", p.serviceURL)
|
||||
if renew {
|
||||
params.Set("renew", "true")
|
||||
}
|
||||
if gateway {
|
||||
params.Set("gateway", "true")
|
||||
}
|
||||
return fmt.Sprintf("%s/login?%s", p.serverURL, params.Encode())
|
||||
}
|
||||
|
||||
// BuildLogoutURL 构建 CAS 登出 URL
|
||||
func (p *CASProvider) BuildLogoutURL(url string) string {
|
||||
if url != "" {
|
||||
return fmt.Sprintf("%s/logout?service=%s", p.serverURL, url)
|
||||
}
|
||||
return fmt.Sprintf("%s/logout", p.serverURL)
|
||||
}
|
||||
|
||||
// CASValidationResponse CAS 票据验证响应
|
||||
type CASValidationResponse struct {
|
||||
Success bool
|
||||
UserID int64
|
||||
Username string
|
||||
ErrorCode string
|
||||
ErrorMsg string
|
||||
}
|
||||
|
||||
// ValidateTicket 验证 CAS 票据
|
||||
// 向 CAS 服务器发送 ticket 验证请求
|
||||
func (p *CASProvider) ValidateTicket(ctx context.Context, ticket string) (*CASValidationResponse, error) {
|
||||
if ticket == "" {
|
||||
return &CASValidationResponse{
|
||||
Success: false,
|
||||
ErrorCode: "INVALID_REQUEST",
|
||||
ErrorMsg: "ticket is required",
|
||||
}, nil
|
||||
}
|
||||
|
||||
params := url.Values{}
|
||||
params.Set("service", p.serviceURL)
|
||||
params.Set("ticket", ticket)
|
||||
|
||||
validateURL := fmt.Sprintf("%s/p3/serviceValidate?%s", p.serverURL, params.Encode())
|
||||
|
||||
resp, err := fetchCASResponse(ctx, validateURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("CAS validation request failed: %w", err)
|
||||
}
|
||||
|
||||
return p.parseServiceValidateResponse(resp)
|
||||
}
|
||||
|
||||
// parseServiceValidateResponse 解析 CAS serviceValidate 响应
|
||||
// CAS 1.0 和 CAS 2.0 使用不同的响应格式
|
||||
func (p *CASProvider) parseServiceValidateResponse(xml string) (*CASValidationResponse, error) {
|
||||
resp := &CASValidationResponse{Success: false}
|
||||
|
||||
// 检查是否包含 authenticationSuccess 元素
|
||||
if strings.Contains(xml, "<authenticationSuccess>") {
|
||||
resp.Success = true
|
||||
|
||||
// 解析用户名
|
||||
if start := strings.Index(xml, "<user>"); start != -1 {
|
||||
end := strings.Index(xml[start:], "</user>")
|
||||
if end != -1 {
|
||||
resp.Username = xml[start+6 : start+end]
|
||||
}
|
||||
}
|
||||
|
||||
// 解析用户 ID (CAS 2.0)
|
||||
if start := strings.Index(xml, "<userId>"); start != -1 {
|
||||
end := strings.Index(xml[start:], "</userId>")
|
||||
if end != -1 {
|
||||
userIDStr := xml[start+8 : start+end]
|
||||
var userID int64
|
||||
fmt.Sscanf(userIDStr, "%d", &userID)
|
||||
resp.UserID = userID
|
||||
}
|
||||
}
|
||||
} else if strings.Contains(xml, "<authenticationFailure>") {
|
||||
resp.Success = false
|
||||
|
||||
// 解析错误码
|
||||
if start := strings.Index(xml, "code=\""); start != -1 {
|
||||
start += 6
|
||||
end := strings.Index(xml[start:], "\"")
|
||||
if end != -1 {
|
||||
resp.ErrorCode = xml[start : start+end]
|
||||
}
|
||||
}
|
||||
|
||||
// 解析错误消息
|
||||
if start := strings.Index(xml, "<![CDATA["); start != -1 {
|
||||
end := strings.Index(xml[start:], "]]>")
|
||||
if end != -1 {
|
||||
resp.ErrorMsg = xml[start+9 : start+end]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// GenerateProxyTicket 生成代理票据 (CAS 2.0)
|
||||
// 用于服务代理用户访问其他服务
|
||||
func (p *CASProvider) GenerateProxyTicket(ctx context.Context, proxyGrantingTicket, targetService string) (string, error) {
|
||||
params := url.Values{}
|
||||
params.Set("targetService", targetService)
|
||||
|
||||
proxyURL := fmt.Sprintf("%s/p3/proxy?%s&pgt=%s",
|
||||
p.serverURL, params.Encode(), proxyGrantingTicket)
|
||||
|
||||
resp, err := fetchCASResponse(ctx, proxyURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// 解析代理票据
|
||||
if start := strings.Index(resp, "<proxyTicket>"); start != -1 {
|
||||
end := strings.Index(resp[start:], "</proxyTicket>")
|
||||
if end != -1 {
|
||||
return resp[start+12 : start+end], nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("failed to parse proxy ticket from response")
|
||||
}
|
||||
|
||||
// fetchCASResponse 从 CAS 服务器获取响应
|
||||
func fetchCASResponse(ctx context.Context, url string) (string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
req.Header.Set("Accept", "application/xml")
|
||||
|
||||
client := &http.Client{Timeout: 10 * time.Second}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(body), nil
|
||||
}
|
||||
|
||||
// GenerateCASServiceTicket 生成 CAS 服务票据 (供 CAS 服务器使用)
|
||||
// 这个方法供实际的 CAS 服务器实现调用
|
||||
func GenerateCASServiceTicket(service string, userID int64, username string) (*CASServiceTicket, error) {
|
||||
ticketBytes := make([]byte, 32)
|
||||
if _, err := rand.Read(ticketBytes); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate ticket: %w", err)
|
||||
}
|
||||
|
||||
return &CASServiceTicket{
|
||||
Ticket: "ST-" + base64.URLEncoding.EncodeToString(ticketBytes)[:32],
|
||||
Service: service,
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
IssuedAt: time.Now(),
|
||||
Expiry: time.Now().Add(5 * time.Minute),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// IsExpired 检查票据是否过期
|
||||
func (t *CASServiceTicket) IsExpired() bool {
|
||||
return time.Now().After(t.Expiry)
|
||||
}
|
||||
|
||||
// GetDuration 返回票据有效时长
|
||||
func (t *CASServiceTicket) GetDuration() time.Duration {
|
||||
return t.Expiry.Sub(t.IssuedAt)
|
||||
}
|
||||
@@ -6,9 +6,17 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// MaxSessions 最大 session 数量限制
|
||||
MaxSessions = 10000
|
||||
// CleanupInterval 清理间隔
|
||||
CleanupInterval = 5 * time.Minute
|
||||
)
|
||||
|
||||
// SSOOAuth2Config SSO OAuth2 配置
|
||||
type SSOOAuth2Config struct {
|
||||
ClientID string
|
||||
@@ -66,6 +74,7 @@ type SSOSession struct {
|
||||
|
||||
// SSOManager SSO 管理器
|
||||
type SSOManager struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]*SSOSession
|
||||
}
|
||||
|
||||
@@ -76,12 +85,35 @@ func NewSSOManager() *SSOManager {
|
||||
}
|
||||
}
|
||||
|
||||
// StartCleanup 启动后台清理 goroutine
|
||||
func (m *SSOManager) StartCleanup(ctx context.Context) {
|
||||
go func() {
|
||||
ticker := time.NewTicker(CleanupInterval)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.CleanupExpired()
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// GenerateAuthorizationCode 生成授权码
|
||||
func (m *SSOManager) GenerateAuthorizationCode(clientID, redirectURI, scope string, userID int64, username string) (string, error) {
|
||||
code := generateSecureToken(32)
|
||||
code, err := generateSecureToken(32)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sessionID, err := generateSecureToken(16)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
session := &SSOSession{
|
||||
SessionID: generateSecureToken(16),
|
||||
SessionID: sessionID,
|
||||
UserID: userID,
|
||||
Username: username,
|
||||
ClientID: clientID,
|
||||
@@ -90,13 +122,26 @@ func (m *SSOManager) GenerateAuthorizationCode(clientID, redirectURI, scope stri
|
||||
Scope: scope,
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
// 检查并清理过期 session,如果超过限制则淘汰最旧的
|
||||
if len(m.sessions) >= MaxSessions {
|
||||
m.cleanupExpiredLocked()
|
||||
// 如果仍然满,淘汰最早的
|
||||
if len(m.sessions) >= MaxSessions {
|
||||
m.evictOldest()
|
||||
}
|
||||
}
|
||||
m.sessions[code] = session
|
||||
m.mu.Unlock()
|
||||
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// ValidateAuthorizationCode 验证授权码
|
||||
func (m *SSOManager) ValidateAuthorizationCode(code string) (*SSOSession, error) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
session, ok := m.sessions[code]
|
||||
if !ok {
|
||||
return nil, errors.New("invalid authorization code")
|
||||
@@ -114,8 +159,11 @@ func (m *SSOManager) ValidateAuthorizationCode(code string) (*SSOSession, error)
|
||||
}
|
||||
|
||||
// GenerateAccessToken 生成访问令牌
|
||||
func (m *SSOManager) GenerateAccessToken(clientID string, session *SSOSession) (string, time.Time) {
|
||||
token := generateSecureToken(32)
|
||||
func (m *SSOManager) GenerateAccessToken(clientID string, session *SSOSession) (string, time.Time, error) {
|
||||
token, err := generateSecureToken(32)
|
||||
if err != nil {
|
||||
return "", time.Time{}, err
|
||||
}
|
||||
expiresAt := time.Now().Add(2 * time.Hour) // Access token 2 小时有效期
|
||||
|
||||
accessSession := &SSOSession{
|
||||
@@ -128,22 +176,37 @@ func (m *SSOManager) GenerateAccessToken(clientID string, session *SSOSession) (
|
||||
Scope: session.Scope,
|
||||
}
|
||||
|
||||
m.mu.Lock()
|
||||
// 检查并清理过期 session,如果超过限制则淘汰最旧的
|
||||
if len(m.sessions) >= MaxSessions {
|
||||
m.cleanupExpiredLocked()
|
||||
if len(m.sessions) >= MaxSessions {
|
||||
m.evictOldest()
|
||||
}
|
||||
}
|
||||
m.sessions[token] = accessSession
|
||||
m.mu.Unlock()
|
||||
|
||||
return token, expiresAt
|
||||
return token, expiresAt, nil
|
||||
}
|
||||
|
||||
// IntrospectToken 验证 token
|
||||
func (m *SSOManager) IntrospectToken(token string) (*SSOTokenInfo, error) {
|
||||
m.mu.RLock()
|
||||
session, ok := m.sessions[token]
|
||||
if !ok {
|
||||
m.mu.RUnlock()
|
||||
return &SSOTokenInfo{Active: false}, nil
|
||||
}
|
||||
|
||||
if time.Now().After(session.ExpiresAt) {
|
||||
m.mu.RUnlock()
|
||||
m.mu.Lock()
|
||||
delete(m.sessions, token)
|
||||
m.mu.Unlock()
|
||||
return &SSOTokenInfo{Active: false}, nil
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
return &SSOTokenInfo{
|
||||
Active: true,
|
||||
@@ -157,12 +220,21 @@ func (m *SSOManager) IntrospectToken(token string) (*SSOTokenInfo, error) {
|
||||
|
||||
// RevokeToken 撤销 token
|
||||
func (m *SSOManager) RevokeToken(token string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
delete(m.sessions, token)
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanupExpired 清理过期的 session(可由后台 goroutine 定期调用)
|
||||
// CleanupExpired 清理过期的 session
|
||||
func (m *SSOManager) CleanupExpired() {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.cleanupExpiredLocked()
|
||||
}
|
||||
|
||||
// cleanupExpiredLocked 内部清理方法(假设已持有锁)
|
||||
func (m *SSOManager) cleanupExpiredLocked() {
|
||||
now := time.Now()
|
||||
for key, session := range m.sessions {
|
||||
if now.After(session.ExpiresAt) {
|
||||
@@ -171,11 +243,38 @@ func (m *SSOManager) CleanupExpired() {
|
||||
}
|
||||
}
|
||||
|
||||
// evictOldest 淘汰最早的 session(假设已持有锁)
|
||||
func (m *SSOManager) evictOldest() {
|
||||
if len(m.sessions) == 0 {
|
||||
return
|
||||
}
|
||||
var oldestKey string
|
||||
var oldestTime time.Time
|
||||
for key, session := range m.sessions {
|
||||
if oldestTime.IsZero() || session.CreatedAt.Before(oldestTime) {
|
||||
oldestTime = session.CreatedAt
|
||||
oldestKey = key
|
||||
}
|
||||
}
|
||||
if oldestKey != "" {
|
||||
delete(m.sessions, oldestKey)
|
||||
}
|
||||
}
|
||||
|
||||
// SessionCount 返回当前 session 数量(用于监控)
|
||||
func (m *SSOManager) SessionCount() int {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return len(m.sessions)
|
||||
}
|
||||
|
||||
// generateSecureToken 生成安全随机 token
|
||||
func generateSecureToken(length int) string {
|
||||
func generateSecureToken(length int) (string, error) {
|
||||
bytes := make([]byte, length)
|
||||
rand.Read(bytes)
|
||||
return base64.URLEncoding.EncodeToString(bytes)[:length]
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", fmt.Errorf("failed to generate secure token: %w", err)
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(bytes)[:length], nil
|
||||
}
|
||||
|
||||
// SSOClient SSO 客户端配置存储
|
||||
@@ -189,10 +288,12 @@ type SSOClient struct {
|
||||
// SSOClientsStore SSO 客户端存储接口
|
||||
type SSOClientsStore interface {
|
||||
GetByClientID(clientID string) (*SSOClient, error)
|
||||
ValidateClientRedirectURI(clientID, redirectURI string) bool
|
||||
}
|
||||
|
||||
// DefaultSSOClientsStore 默认内存存储
|
||||
type DefaultSSOClientsStore struct {
|
||||
mu sync.RWMutex
|
||||
clients map[string]*SSOClient
|
||||
}
|
||||
|
||||
@@ -205,11 +306,15 @@ func NewDefaultSSOClientsStore() *DefaultSSOClientsStore {
|
||||
|
||||
// RegisterClient 注册客户端
|
||||
func (s *DefaultSSOClientsStore) RegisterClient(client *SSOClient) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.clients[client.ClientID] = client
|
||||
}
|
||||
|
||||
// GetByClientID 根据 ClientID 获取客户端
|
||||
func (s *DefaultSSOClientsStore) GetByClientID(clientID string) (*SSOClient, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
client, ok := s.clients[clientID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("client not found: %s", clientID)
|
||||
|
||||
@@ -99,6 +99,8 @@ func setupRealServer(t *testing.T) (*httptest.Server, func()) {
|
||||
captchaSvc := service.NewCaptchaService(cacheManager)
|
||||
totpSvc := service.NewTOTPService(userRepo)
|
||||
webhookSvc := service.NewWebhookService(db)
|
||||
exportSvc := service.NewExportService(userRepo, roleRepo)
|
||||
statsSvc := service.NewStatsService(userRepo, loginLogRepo)
|
||||
|
||||
authH := handler.NewAuthHandler(authSvc)
|
||||
userH := handler.NewUserHandler(userSvc)
|
||||
@@ -111,9 +113,11 @@ func setupRealServer(t *testing.T) (*httptest.Server, func()) {
|
||||
totpH := handler.NewTOTPHandler(authSvc, totpSvc)
|
||||
webhookH := handler.NewWebhookHandler(webhookSvc)
|
||||
smsH := handler.NewSMSHandler()
|
||||
exportH := handler.NewExportHandler(exportSvc)
|
||||
statsH := handler.NewStatsHandler(statsSvc)
|
||||
|
||||
rateLimitMW := middleware.NewRateLimitMiddleware(config.RateLimitConfig{})
|
||||
authMW := middleware.NewAuthMiddleware(jwtManager, userRepo, userRoleRepo, roleRepo, rolePermissionRepo, permissionRepo)
|
||||
authMW := middleware.NewAuthMiddleware(jwtManager, userRepo, userRoleRepo, roleRepo, rolePermissionRepo, permissionRepo, l1Cache)
|
||||
authMW.SetCacheManager(cacheManager)
|
||||
opLogMW := middleware.NewOperationLogMiddleware(operationLogRepo)
|
||||
ipFilterMW := middleware.NewIPFilterMiddleware(security.NewIPFilter(), middleware.IPFilterConfig{})
|
||||
@@ -122,7 +126,7 @@ func setupRealServer(t *testing.T) (*httptest.Server, func()) {
|
||||
authH, userH, roleH, permH, deviceH, logH,
|
||||
authMW, rateLimitMW, opLogMW,
|
||||
pwdResetH, captchaH, totpH, webhookH,
|
||||
ipFilterMW, nil, nil, smsH, nil, nil, nil,
|
||||
ipFilterMW, exportH, statsH, smsH, nil, nil, nil,
|
||||
)
|
||||
engine := r.Setup()
|
||||
|
||||
@@ -413,7 +417,32 @@ func doGet(t *testing.T, url string, token string) *http.Response {
|
||||
func decodeJSON(t *testing.T, body io.ReadCloser, v interface{}) {
|
||||
t.Helper()
|
||||
defer body.Close()
|
||||
if err := json.NewDecoder(body).Decode(v); err != nil {
|
||||
raw, err := io.ReadAll(body)
|
||||
if err != nil {
|
||||
t.Logf("读取响应 body 失败: %v(非致命)", err)
|
||||
return
|
||||
}
|
||||
|
||||
// 尝试解包 ResponseWrapper 标准格式 {code:0, message:"...", data:{...}}
|
||||
// 只在目标是 map[string]interface{} 时尝试透明解包
|
||||
if target, ok := v.(*map[string]interface{}); ok {
|
||||
var outer struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
if json.Unmarshal(raw, &outer) == nil && outer.Data != nil && len(outer.Data) > 2 {
|
||||
// 有 data 字段,尝试把 data 内容解包到目标
|
||||
var inner map[string]interface{}
|
||||
if json.Unmarshal(outer.Data, &inner) == nil && len(inner) > 0 {
|
||||
*target = inner
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 退化:直接解析原始 JSON
|
||||
if err := json.Unmarshal(raw, v); err != nil {
|
||||
t.Logf("解析响应 JSON 失败: %v(非致命)", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,8 +161,12 @@ func (r *PermissionRepository) Search(ctx context.Context, keyword string, offse
|
||||
var permissions []*domain.Permission
|
||||
var total int64
|
||||
|
||||
// 转义 LIKE 特殊字符,防止搜索被意外干扰
|
||||
escapedKeyword := escapeLikePattern(keyword)
|
||||
pattern := "%" + escapedKeyword + "%"
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Permission{}).
|
||||
Where("name LIKE ? OR code LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
Where("name LIKE ? OR code LIKE ? OR description LIKE ?", pattern, pattern, pattern)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
|
||||
@@ -135,8 +135,12 @@ func (r *RoleRepository) Search(ctx context.Context, keyword string, offset, lim
|
||||
var roles []*domain.Role
|
||||
var total int64
|
||||
|
||||
// 转义 LIKE 特殊字符,防止搜索被意外干扰
|
||||
escapedKeyword := escapeLikePattern(keyword)
|
||||
pattern := "%" + escapedKeyword + "%"
|
||||
|
||||
query := r.db.WithContext(ctx).Model(&domain.Role{}).
|
||||
Where("name LIKE ? OR code LIKE ? OR description LIKE ?", "%"+keyword+"%", "%"+keyword+"%", "%"+keyword+"%")
|
||||
Where("name LIKE ? OR code LIKE ? OR description LIKE ?", pattern, pattern, pattern)
|
||||
|
||||
// 获取总数
|
||||
if err := query.Count(&total).Error; err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/user-management-system/internal/auth"
|
||||
"github.com/user-management-system/internal/domain"
|
||||
@@ -18,6 +19,11 @@ func (s *AuthService) SetEmailCodeService(svc *EmailCodeService) {
|
||||
s.emailCodeSvc = svc
|
||||
}
|
||||
|
||||
// HasEmailCodeService 判断邮箱验证码登录服务是否已配置
|
||||
func (s *AuthService) HasEmailCodeService() bool {
|
||||
return s != nil && s.emailCodeSvc != nil
|
||||
}
|
||||
|
||||
func (s *AuthService) RegisterWithActivation(ctx context.Context, req *RegisterRequest) (*UserInfo, error) {
|
||||
if err := s.validatePassword(req.Password); err != nil {
|
||||
return nil, err
|
||||
@@ -83,8 +89,11 @@ func (s *AuthService) RegisterWithActivation(ctx context.Context, req *RegisterR
|
||||
if nickname == "" {
|
||||
nickname = req.Username
|
||||
}
|
||||
// 使用独立上下文避免请求结束后被取消
|
||||
go func() {
|
||||
if err := s.emailActivationSvc.SendActivationEmail(ctx, user.ID, req.Email, nickname); err != nil {
|
||||
bgCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := s.emailActivationSvc.SendActivationEmail(bgCtx, user.ID, req.Email, nickname); err != nil {
|
||||
log.Printf("auth: send activation email failed, user_id=%d email=%s err=%v", user.ID, req.Email, err)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -294,12 +294,14 @@ func buildActivationEmailBody(username, activationURL, siteName string, ttl time
|
||||
}
|
||||
|
||||
func generateEmailCode() (string, error) {
|
||||
buffer := make([]byte, 3)
|
||||
// 使用 6 字节随机数提供足够的熵(48 位)
|
||||
buffer := make([]byte, 6)
|
||||
if _, err := cryptorand.Read(buffer); err != nil {
|
||||
return "", fmt.Errorf("generate email code failed: %w", err)
|
||||
}
|
||||
|
||||
value := int(buffer[0])<<16 | int(buffer[1])<<8 | int(buffer[2])
|
||||
value := int(buffer[0])<<40 | int(buffer[1])<<32 | int(buffer[2])<<24 |
|
||||
int(buffer[3])<<16 | int(buffer[4])<<8 | int(buffer[5])
|
||||
value = value % 1000000
|
||||
if value < 100000 {
|
||||
value += 100000
|
||||
|
||||
@@ -373,12 +373,14 @@ func isValidPhone(phone string) bool {
|
||||
}
|
||||
|
||||
func generateSMSCode() (string, error) {
|
||||
b := make([]byte, 4)
|
||||
// 使用 6 字节随机数提供足够的熵(48 位)
|
||||
b := make([]byte, 6)
|
||||
if _, err := cryptorand.Read(b); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
n := int(b[0])<<24 | int(b[1])<<16 | int(b[2])<<8 | int(b[3])
|
||||
n := int(b[0])<<40 | int(b[1])<<32 | int(b[2])<<24 |
|
||||
int(b[3])<<16 | int(b[4])<<8 | int(b[5])
|
||||
if n < 0 {
|
||||
n = -n
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user