Files
user-system/internal/api/handler/sso_handler.go
2026-05-30 21:29:24 +08:00

346 lines
11 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
	"crypto/subtle"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/user-management-system/internal/auth"
)
// SSOHandler serves the SSO endpoints (authorization-code issuance, token
// exchange, introspection, revocation and user info). Code and token
// lifecycle is delegated to ssoManager; client lookups go through clientsStore.
type SSOHandler struct {
	ssoManager   *auth.SSOManager
	clientsStore auth.SSOClientsStore
}

// NewSSOHandler wires an SSOHandler to its two collaborators. A nil
// clientsStore is accepted; endpoints that need it reject the request.
func NewSSOHandler(ssoManager *auth.SSOManager, clientsStore auth.SSOClientsStore) *SSOHandler {
	h := new(SSOHandler)
	h.ssoManager = ssoManager
	h.clientsStore = clientsStore
	return h
}
// AuthorizeRequest carries the query parameters of an SSO authorization request.
type AuthorizeRequest struct {
	ClientID     string `form:"client_id" binding:"required"`
	RedirectURI  string `form:"redirect_uri" binding:"required"`
	ResponseType string `form:"response_type" binding:"required"`
	Scope        string `form:"scope"`
	State        string `form:"state"`
}

// Authorize handles an SSO authorization request for the already-authenticated user.
//
// Only response_type=code is supported, and redirect_uri must be registered for
// client_id. On success the user is redirected (302) to redirect_uri with "code"
// and, when supplied, "state" added to its query string. Both values are
// percent-encoded, and any query parameters already present in redirect_uri are kept.
//
// @Summary SSO authorize
// @Description Handles an SSO authorization request and returns an authorization code
// @Tags SSO
// @Accept json
// @Produce json
// @Security BearerAuth
// @Param client_id query string true "Client ID"
// @Param redirect_uri query string true "Redirect URI"
// @Param response_type query string true "Response type" Enums(code)
// @Param scope query string false "Requested scope"
// @Param state query string false "Opaque state echoed back to the client"
// @Success 302 {string} string "Redirect to redirect_uri"
// @Failure 400 {object} Response "Invalid request parameters"
// @Failure 401 {object} Response "Not authenticated"
// @Failure 500 {object} Response "Server error"
// @Router /api/v1/sso/authorize [get]
func (h *SSOHandler) Authorize(c *gin.Context) {
	var req AuthorizeRequest
	if err := c.ShouldBindQuery(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
		return
	}
	if req.ResponseType != "code" {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "unsupported response_type"})
		return
	}
	if h.clientsStore == nil || !h.clientsStore.ValidateClientRedirectURI(req.ClientID, req.RedirectURI) {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "invalid redirect_uri"})
		return
	}
	// Parse before a code is issued so an unparseable URI never leaves an orphaned code behind.
	target, err := url.Parse(req.RedirectURI)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "invalid redirect_uri"})
		return
	}
	userID, ok := getUserIDFromContext(c)
	if !ok {
		c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
		return
	}
	username, ok := getUsernameFromContext(c)
	if !ok {
		c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
		return
	}
	code, err := h.ssoManager.GenerateAuthorizationCode(
		req.ClientID,
		req.RedirectURI,
		req.Scope,
		userID,
		username,
	)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to generate code"})
		return
	}
	// Build the query with net/url rather than string concatenation. state is
	// caller-controlled and must be escaped; otherwise "x&code=..." could inject
	// parameters. redirect_uri may also already carry its own query string.
	query := target.Query()
	query.Set("code", code)
	if req.State != "" {
		query.Set("state", req.State)
	}
	target.RawQuery = query.Encode()
	c.Redirect(http.StatusFound, target.String())
}
// TokenRequest carries the form fields of a token request.
type TokenRequest struct {
	GrantType    string `form:"grant_type" binding:"required"`
	Code         string `form:"code"`
	RedirectURI  string `form:"redirect_uri"`
	ClientID     string `form:"client_id" binding:"required"`
	ClientSecret string `form:"client_secret" binding:"required"`
}

// TokenResponse is the payload returned by Token on success.
type TokenResponse struct {
	AccessToken string `json:"access_token"`
	TokenType   string `json:"token_type"`
	ExpiresIn   int64  `json:"expires_in"`
	Scope       string `json:"scope"`
}

// Token exchanges an authorization code for an access token. This is the second
// step of the authorization-code flow.
//
// The client authenticates with client_id/client_secret. redirect_uri must be
// registered for the client and must equal the one the code was issued for.
// Every response carries "Cache-Control: no-store" and "Pragma: no-cache" so the
// issued token is never cached by intermediaries.
//
// @Summary Obtain an access token
// @Description Exchanges an authorization code for an access token (step two of the authorization-code flow)
// @Tags SSO
// @Accept x-www-form-urlencoded
// @Produce json
// @Param grant_type formData string true "Grant type" Enums(authorization_code)
// @Param code formData string true "Authorization code"
// @Param redirect_uri formData string true "Redirect URI"
// @Param client_id formData string true "Client ID"
// @Param client_secret formData string true "Client secret"
// @Success 200 {object} Response{data=TokenResponse} "Access token response"
// @Failure 400 {object} Response "Invalid request parameters"
// @Failure 401 {object} Response "Client authentication failed"
// @Failure 500 {object} Response "Server error"
// @Router /api/v1/sso/token [post]
func (h *SSOHandler) Token(c *gin.Context) {
	// RFC 6749 §5.1 requires these headers on any response that contains tokens
	// or credentials. Setting them up front covers the error paths too.
	c.Header("Cache-Control", "no-store")
	c.Header("Pragma", "no-cache")

	var req TokenRequest
	if err := c.ShouldBind(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
		return
	}
	if req.GrantType != "authorization_code" {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "unsupported grant_type"})
		return
	}
	if req.Code == "" || req.RedirectURI == "" {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "code and redirect_uri are required"})
		return
	}
	client, ok := h.authenticateClient(req.ClientID, req.ClientSecret)
	if !ok {
		c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "invalid client credentials"})
		return
	}
	// clientsStore is non-nil here: authenticateClient fails when it is nil.
	if !h.clientsStore.ValidateClientRedirectURI(client.ClientID, req.RedirectURI) {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "invalid redirect_uri"})
		return
	}
	session, err := h.ssoManager.ValidateAuthorizationCode(req.Code)
	if err != nil {
		c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "invalid code"})
		return
	}
	// The code must be redeemed by the client and redirect_uri it was issued for.
	if session.ClientID != req.ClientID || session.RedirectURI != req.RedirectURI {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": "authorization code does not match client or redirect_uri"})
		return
	}
	token, expiresAt, err := h.ssoManager.GenerateAccessToken(req.ClientID, session)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"code": 500, "message": "failed to generate token"})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"code":    0,
		"message": "success",
		"data": TokenResponse{
			AccessToken: token,
			TokenType:   "Bearer",
			ExpiresIn:   int64(time.Until(expiresAt).Seconds()),
			Scope:       session.Scope,
		},
	})
}
// IntrospectRequest describes the form fields of an introspection request.
// NOTE(review): Introspect below binds an inline struct instead of this type, and
// that struct also requires client_id and client_secret. Confirm whether this type
// is still referenced elsewhere before changing it.
type IntrospectRequest struct {
	Token    string `form:"token" binding:"required"`
	ClientID string `form:"client_id"`
}

// IntrospectResponse is the payload returned by Introspect. For an inactive or
// unknown token only Active (false) is populated.
type IntrospectResponse struct {
	Active    bool   `json:"active"`
	UserID    int64  `json:"user_id,omitempty"`
	Username  string `json:"username,omitempty"`
	ExpiresAt int64  `json:"exp,omitempty"`
	Scope     string `json:"scope,omitempty"`
}

// Introspect reports whether an access token is active and, if so, whom it belongs to.
//
// The calling client must authenticate with client_id/client_secret. A token the
// manager rejects, or one it reports as inactive, yields a 200 response with only
// active=false, so nothing about the token's former owner is disclosed.
//
// @Summary Introspect an access token
// @Description Checks whether an access token is valid and returns its associated information
// @Tags SSO
// @Accept x-www-form-urlencoded
// @Produce json
// @Param token formData string true "Access token"
// @Param client_id formData string true "Client ID"
// @Param client_secret formData string true "Client secret"
// @Success 200 {object} Response{data=IntrospectResponse} "Token information"
// @Failure 400 {object} Response "Invalid request parameters"
// @Failure 401 {object} Response "Client authentication failed"
// @Router /api/v1/sso/introspect [post]
func (h *SSOHandler) Introspect(c *gin.Context) {
	var req struct {
		Token        string `form:"token" binding:"required"`
		ClientID     string `form:"client_id" binding:"required"`
		ClientSecret string `form:"client_secret" binding:"required"`
	}
	if err := c.ShouldBind(&req); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": err.Error()})
		return
	}
	if _, ok := h.authenticateClient(req.ClientID, req.ClientSecret); !ok {
		c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "invalid client credentials"})
		return
	}
	info, err := h.ssoManager.IntrospectToken(req.Token)
	if err != nil || !info.Active {
		// RFC 7662 §2.2: an inactive token gets active=false and no further
		// information (previously user_id/username/scope/exp leaked here).
		c.JSON(http.StatusOK, gin.H{"code": 0, "message": "success", "data": IntrospectResponse{Active: false}})
		return
	}
	resp := IntrospectResponse{
		Active:   true,
		UserID:   info.UserID,
		Username: info.Username,
		Scope:    info.Scope,
	}
	// An unset expiry (zero time.Time, assuming that is ExpiresAt's type) has a
	// large negative Unix value that omitempty would not drop, so omit exp instead.
	if !info.ExpiresAt.IsZero() {
		resp.ExpiresAt = info.ExpiresAt.Unix()
	}
	c.JSON(http.StatusOK, gin.H{
		"code":    0,
		"message": "success",
		"data":    resp,
	})
}
// RevokeRequest describes the form field of a revocation request.
// NOTE(review): Revoke below binds an inline struct that additionally requires
// client credentials instead of this type. Confirm whether this type is used elsewhere.
type RevokeRequest struct {
	Token string `form:"token" binding:"required"`
}

// Revoke revokes an access token on behalf of an authenticated client.
//
// After the client's credentials check out, the token is handed to the SSO
// manager and a success response is returned regardless of the outcome.
//
// @Summary Revoke an access token
// @Description Revokes the given access token
// @Tags SSO
// @Accept x-www-form-urlencoded
// @Produce json
// @Param token formData string true "Access token"
// @Param client_id formData string true "Client ID"
// @Param client_secret formData string true "Client secret"
// @Success 200 {object} Response "Token revoked"
// @Failure 400 {object} Response "Invalid request parameters"
// @Failure 401 {object} Response "Client authentication failed"
// @Router /api/v1/sso/revoke [post]
func (h *SSOHandler) Revoke(c *gin.Context) {
	var form struct {
		Token        string `form:"token" binding:"required"`
		ClientID     string `form:"client_id" binding:"required"`
		ClientSecret string `form:"client_secret" binding:"required"`
	}
	if bindErr := c.ShouldBind(&form); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"code": 400, "message": bindErr.Error()})
		return
	}

	_, authenticated := h.authenticateClient(form.ClientID, form.ClientSecret)
	if !authenticated {
		c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "invalid client credentials"})
		return
	}

	// The result is deliberately discarded and success is always reported.
	// RFC 7009 §2.2 says an invalid token must not produce an error response.
	// NOTE(review): this also hides internal RevokeToken failures; confirm that is intended.
	_ = h.ssoManager.RevokeToken(form.Token)
	c.JSON(http.StatusOK, gin.H{"code": 0, "message": "token revoked"})
}
// UserInfoResponse is the payload returned by UserInfo.
type UserInfoResponse struct {
	UserID   int64  `json:"user_id"`
	Username string `json:"username"`
}

// UserInfo returns the user bound to the SSO access token sent as a Bearer token
// in the Authorization header.
//
// Both 401 responses carry a WWW-Authenticate challenge. It is a bare "Bearer"
// when no token was sent, and `Bearer error="invalid_token"` when the token was
// rejected.
//
// @Summary Get SSO user info
// @Description Returns the user authorized by the presented SSO access token
// @Tags SSO
// @Produce json
// @Security BearerAuth
// @Success 200 {object} Response{data=UserInfoResponse} "User information"
// @Failure 401 {object} Response "Not authenticated"
// @Router /api/v1/sso/userinfo [get]
func (h *SSOHandler) UserInfo(c *gin.Context) {
	token := extractBearerToken(c)
	if token == "" {
		// RFC 6750 §3.1: with no credentials, challenge without error details.
		c.Header("WWW-Authenticate", "Bearer")
		c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "unauthorized"})
		return
	}
	session, err := h.ssoManager.ValidateAccessToken(token)
	if err != nil {
		// RFC 6750 §3: a rejected token is signalled with error="invalid_token".
		c.Header("WWW-Authenticate", `Bearer error="invalid_token"`)
		c.JSON(http.StatusUnauthorized, gin.H{"code": 401, "message": "invalid access token"})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"code":    0,
		"message": "success",
		"data": UserInfoResponse{
			UserID:   session.UserID,
			Username: session.Username,
		},
	})
}
// authenticateClient looks up clientID and checks clientSecret against the stored
// secret using a constant-time comparison.
//
// It returns the client and true only on a match. A missing store, an empty
// secret, a failed lookup or a nil client all yield (nil, false).
func (h *SSOHandler) authenticateClient(clientID, clientSecret string) (*auth.SSOClient, bool) {
	// subtle.ConstantTimeCompare reports two empty inputs as equal. Without this
	// guard, a client stored without a secret would authenticate with an empty one.
	if h.clientsStore == nil || clientSecret == "" {
		return nil, false
	}
	client, err := h.clientsStore.GetByClientID(clientID)
	if err != nil || client == nil {
		return nil, false
	}
	if subtle.ConstantTimeCompare([]byte(clientSecret), []byte(client.ClientSecret)) != 1 {
		return nil, false
	}
	return client, true
}
// extractBearerToken returns the token from an "Authorization: Bearer <token>"
// header. It returns "" when the header is absent, uses another scheme, or
// carries no token.
//
// The scheme name is matched case-insensitively ("bearer", "BEARER", ...), as
// RFC 7235 §2.1 requires for authentication schemes.
func extractBearerToken(c *gin.Context) string {
	const scheme = "Bearer "
	authorization := strings.TrimSpace(c.GetHeader("Authorization"))
	if len(authorization) < len(scheme) || !strings.EqualFold(authorization[:len(scheme)], scheme) {
		return ""
	}
	return strings.TrimSpace(authorization[len(scheme):])
}