507 lines
14 KiB
Go
507 lines
14 KiB
Go
package auth
|
||
|
||
import (
|
||
"context"
|
||
"errors"
|
||
"fmt"
|
||
"net/url"
|
||
|
||
"github.com/user-management-system/internal/auth/providers"
|
||
)
|
||
|
||
// OAuthProvider identifies a third-party OAuth login provider. The underlying
// string is the stable provider ID; it is used as a map key by
// DefaultOAuthManager and serialized in JSON (e.g. OAuthUser.Provider).
type OAuthProvider string

// Provider IDs known to this package.
const (
	OAuthProviderWeChat   = OAuthProvider("wechat")
	OAuthProviderQQ       = OAuthProvider("qq")
	OAuthProviderWeibo    = OAuthProvider("weibo")
	OAuthProviderGoogle   = OAuthProvider("google")
	OAuthProviderFacebook = OAuthProvider("facebook")
	OAuthProviderTwitter  = OAuthProvider("twitter")
	OAuthProviderGitHub   = OAuthProvider("github")
	OAuthProviderAlipay   = OAuthProvider("alipay")
	OAuthProviderDouyin   = OAuthProvider("douyin")
)
|
||
|
||
// OAuthUser is the provider-neutral user profile produced by GetUserInfo.
// Which fields are filled depends on the provider mapping in GetUserInfo.
type OAuthUser struct {
	Provider OAuthProvider          `json:"provider"`           // provider the profile came from
	OpenID   string                 `json:"open_id"`            // provider-scoped user ID (GitHub's numeric ID is rendered in decimal)
	UnionID  string                 `json:"union_id,omitempty"` // cross-app ID; only WeChat and Douyin fill it in this file
	Nickname string                 `json:"nickname"`
	Avatar   string                 `json:"avatar"`
	Gender   string                 `json:"gender,omitempty"` // "male"/"female" for WeChat and Douyin; QQ's value is passed through as-is
	Email    string                 `json:"email,omitempty"`  // filled for Google and GitHub
	Phone    string                 `json:"phone,omitempty"`  // not set by any provider mapping in this file
	Extra    map[string]interface{} `json:"extra,omitempty"`  // provider-specific extras (QQ: province, city, year)
}
|
||
|
||
// OAuthToken is the provider-neutral token produced by ExchangeCode and
// consumed by GetUserInfo.
type OAuthToken struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token,omitempty"` // empty for GitHub in this file
	ExpiresIn    int64  `json:"expires_in"`              // lifetime as reported by the provider; presumably seconds — confirm per provider
	TokenType    string `json:"token_type"`              // "Bearer" is hard-coded for WeChat, QQ, Alipay and Douyin
	OpenID       string `json:"open_id,omitempty"`       // passed to the WeChat/QQ/Douyin user-info calls; holds the user ID for Alipay
}
|
||
|
||
// OAuthConfig holds the client credentials and endpoints for one provider.
// Providers with a dedicated client only receive ClientID, ClientSecret and
// RedirectURI (see RegisterProvider); Scope and AuthURL are read by the
// generic fallback in GetAuthURL.
type OAuthConfig struct {
	ClientID     string `json:"client_id"`     // for Alipay: the AppID
	ClientSecret string `json:"client_secret"` // for Alipay: the RSA private key
	RedirectURI  string `json:"redirect_uri"`
	Scope        string `json:"scope"`
	AuthURL      string `json:"auth_url"`
	TokenURL     string `json:"token_url"`     // not read anywhere in this file
	UserInfoURL  string `json:"user_info_url"` // not read anywhere in this file
}
|
||
|
||
// OAuthManager OAuth管理器接口
|
||
type OAuthManager interface {
|
||
// GetAuthURL 获取授权URL
|
||
GetAuthURL(provider OAuthProvider, state string) (string, error)
|
||
|
||
// ExchangeCode 换取访问令牌
|
||
ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error)
|
||
|
||
// GetUserInfo 获取用户信息
|
||
GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error)
|
||
|
||
// ValidateToken 验证令牌
|
||
ValidateToken(token string) (bool, error)
|
||
|
||
// GetConfig 获取OAuth配置
|
||
GetConfig(provider OAuthProvider) (*OAuthConfig, bool)
|
||
|
||
// GetEnabledProviders 获取已启用的OAuth提供商
|
||
GetEnabledProviders() []OAuthProviderInfo
|
||
}
|
||
|
||
// OAuthProviderInfo describes one registered provider, as listed by
// GetEnabledProviders.
type OAuthProviderInfo struct {
	Provider OAuthProvider `json:"provider"`
	Enabled  bool          `json:"enabled"` // true when the provider was registered with a non-nil config
	Name     string        `json:"name"`    // human-readable display name; falls back to the provider ID
}
|
||
|
||
// providerEntry is the per-provider state kept by DefaultOAuthManager.
// RegisterProvider sets at most one of the client fields, chosen by provider
// type; all of them stay nil for providers without a dedicated client.
type providerEntry struct {
	config      *OAuthConfig
	google      *providers.GoogleProvider
	wechat      *providers.WeChatProvider
	wechatRedir string // WeChat's client takes the redirect URI per GetAuthURL call, not at construction
	qq          *providers.QQProvider
	github      *providers.GitHubProvider
	alipay      *providers.AlipayProvider
	douyin      *providers.DouyinProvider
}
|
||
|
||
// DefaultOAuthManager 默认OAuth管理器(集成真实 provider HTTP 调用)
|
||
type DefaultOAuthManager struct {
|
||
entries map[OAuthProvider]*providerEntry
|
||
}
|
||
|
||
// NewOAuthManager 创建OAuth管理器
|
||
func NewOAuthManager() *DefaultOAuthManager {
|
||
return &DefaultOAuthManager{
|
||
entries: make(map[OAuthProvider]*providerEntry),
|
||
}
|
||
}
|
||
|
||
// RegisterProvider 注册OAuth提供商(保留旧接口,仅存储配置)
|
||
func (m *DefaultOAuthManager) RegisterProvider(provider OAuthProvider, config *OAuthConfig) {
|
||
entry := &providerEntry{config: config}
|
||
|
||
switch provider {
|
||
case OAuthProviderGoogle:
|
||
entry.google = providers.NewGoogleProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
|
||
case OAuthProviderWeChat:
|
||
entry.wechat = providers.NewWeChatProvider(config.ClientID, config.ClientSecret, "web")
|
||
entry.wechatRedir = config.RedirectURI
|
||
case OAuthProviderQQ:
|
||
entry.qq = providers.NewQQProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
|
||
case OAuthProviderGitHub:
|
||
entry.github = providers.NewGitHubProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
|
||
case OAuthProviderAlipay:
|
||
// 支付宝使用 ClientID 存储 AppID,ClientSecret 存储 RSA 私钥
|
||
entry.alipay = providers.NewAlipayProvider(config.ClientID, config.ClientSecret, config.RedirectURI, false)
|
||
case OAuthProviderDouyin:
|
||
entry.douyin = providers.NewDouyinProvider(config.ClientID, config.ClientSecret, config.RedirectURI)
|
||
}
|
||
|
||
m.entries[provider] = entry
|
||
}
|
||
|
||
// GetConfig 获取OAuth配置
|
||
func (m *DefaultOAuthManager) GetConfig(provider OAuthProvider) (*OAuthConfig, bool) {
|
||
entry, ok := m.entries[provider]
|
||
if !ok {
|
||
return nil, false
|
||
}
|
||
return entry.config, true
|
||
}
|
||
|
||
// GetAuthURL 获取授权URL(使用真实 provider 实现)
|
||
func (m *DefaultOAuthManager) GetAuthURL(provider OAuthProvider, state string) (string, error) {
|
||
entry, ok := m.entries[provider]
|
||
if !ok {
|
||
return "", ErrOAuthProviderNotSupported
|
||
}
|
||
|
||
switch provider {
|
||
case OAuthProviderGoogle:
|
||
if entry.google != nil {
|
||
resp, err := entry.google.GetAuthURL(state)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return resp.URL, nil
|
||
}
|
||
case OAuthProviderWeChat:
|
||
if entry.wechat != nil {
|
||
resp, err := entry.wechat.GetAuthURL(entry.wechatRedir, state)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return resp.URL, nil
|
||
}
|
||
case OAuthProviderQQ:
|
||
if entry.qq != nil {
|
||
resp, err := entry.qq.GetAuthURL(state)
|
||
if err != nil {
|
||
return "", err
|
||
}
|
||
return resp.URL, nil
|
||
}
|
||
case OAuthProviderGitHub:
|
||
if entry.github != nil {
|
||
return entry.github.GetAuthURL(state)
|
||
}
|
||
case OAuthProviderAlipay:
|
||
if entry.alipay != nil {
|
||
return entry.alipay.GetAuthURL(state)
|
||
}
|
||
case OAuthProviderDouyin:
|
||
if entry.douyin != nil {
|
||
return entry.douyin.GetAuthURL(state)
|
||
}
|
||
}
|
||
|
||
// 通用 fallback:按标准 OAuth2 拼接 URL(对 QQ/微博/Twitter/Facebook)
|
||
config := entry.config
|
||
if config == nil {
|
||
return "", ErrOAuthProviderNotSupported
|
||
}
|
||
return fmt.Sprintf("%s?client_id=%s&redirect_uri=%s&response_type=code&scope=%s&state=%s",
|
||
config.AuthURL,
|
||
url.QueryEscape(config.ClientID),
|
||
url.QueryEscape(config.RedirectURI),
|
||
url.QueryEscape(config.Scope),
|
||
url.QueryEscape(state),
|
||
), nil
|
||
}
|
||
|
||
// ExchangeCode 换取访问令牌(使用真实 provider 实现)
|
||
func (m *DefaultOAuthManager) ExchangeCode(provider OAuthProvider, code string) (*OAuthToken, error) {
|
||
entry, ok := m.entries[provider]
|
||
if !ok {
|
||
return nil, ErrOAuthProviderNotSupported
|
||
}
|
||
|
||
ctx := context.Background()
|
||
|
||
switch provider {
|
||
case OAuthProviderGoogle:
|
||
if entry.google != nil {
|
||
resp, err := entry.google.ExchangeCode(ctx, code)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &OAuthToken{
|
||
AccessToken: resp.AccessToken,
|
||
RefreshToken: resp.RefreshToken,
|
||
ExpiresIn: int64(resp.ExpiresIn),
|
||
TokenType: resp.TokenType,
|
||
}, nil
|
||
}
|
||
case OAuthProviderWeChat:
|
||
if entry.wechat != nil {
|
||
resp, err := entry.wechat.ExchangeCode(ctx, code)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &OAuthToken{
|
||
AccessToken: resp.AccessToken,
|
||
RefreshToken: resp.RefreshToken,
|
||
ExpiresIn: int64(resp.ExpiresIn),
|
||
TokenType: "Bearer",
|
||
OpenID: resp.OpenID,
|
||
}, nil
|
||
}
|
||
case OAuthProviderQQ:
|
||
if entry.qq != nil {
|
||
resp, err := entry.qq.ExchangeCode(ctx, code)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
openIDResp, err := entry.qq.GetOpenID(ctx, resp.AccessToken)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &OAuthToken{
|
||
AccessToken: resp.AccessToken,
|
||
RefreshToken: resp.RefreshToken,
|
||
ExpiresIn: int64(resp.ExpiresIn),
|
||
TokenType: "Bearer",
|
||
OpenID: openIDResp.OpenID,
|
||
}, nil
|
||
}
|
||
case OAuthProviderGitHub:
|
||
if entry.github != nil {
|
||
resp, err := entry.github.ExchangeCode(ctx, code)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &OAuthToken{
|
||
AccessToken: resp.AccessToken,
|
||
TokenType: resp.TokenType,
|
||
}, nil
|
||
}
|
||
case OAuthProviderAlipay:
|
||
if entry.alipay != nil {
|
||
resp, err := entry.alipay.ExchangeCode(ctx, code)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &OAuthToken{
|
||
AccessToken: resp.AccessToken,
|
||
RefreshToken: resp.RefreshToken,
|
||
ExpiresIn: int64(resp.ExpiresIn),
|
||
TokenType: "Bearer",
|
||
OpenID: resp.UserID,
|
||
}, nil
|
||
}
|
||
case OAuthProviderDouyin:
|
||
if entry.douyin != nil {
|
||
resp, err := entry.douyin.ExchangeCode(ctx, code)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &OAuthToken{
|
||
AccessToken: resp.Data.AccessToken,
|
||
RefreshToken: resp.Data.RefreshToken,
|
||
ExpiresIn: int64(resp.Data.ExpiresIn),
|
||
TokenType: "Bearer",
|
||
OpenID: resp.Data.OpenID,
|
||
}, nil
|
||
}
|
||
}
|
||
|
||
return nil, fmt.Errorf("provider %s: real HTTP exchange not implemented yet", provider)
|
||
}
|
||
|
||
// GetUserInfo 获取用户信息(使用真实 provider 实现)
|
||
func (m *DefaultOAuthManager) GetUserInfo(provider OAuthProvider, token *OAuthToken) (*OAuthUser, error) {
|
||
entry, ok := m.entries[provider]
|
||
if !ok {
|
||
return nil, ErrOAuthProviderNotSupported
|
||
}
|
||
|
||
ctx := context.Background()
|
||
|
||
switch provider {
|
||
case OAuthProviderGoogle:
|
||
if entry.google != nil {
|
||
info, err := entry.google.GetUserInfo(ctx, token.AccessToken)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &OAuthUser{
|
||
Provider: provider,
|
||
OpenID: info.ID,
|
||
Nickname: info.Name,
|
||
Avatar: info.Picture,
|
||
Email: info.Email,
|
||
}, nil
|
||
}
|
||
case OAuthProviderWeChat:
|
||
if entry.wechat != nil {
|
||
openID := token.OpenID
|
||
info, err := entry.wechat.GetUserInfo(ctx, token.AccessToken, openID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
gender := ""
|
||
switch info.Sex {
|
||
case 1:
|
||
gender = "male"
|
||
case 2:
|
||
gender = "female"
|
||
}
|
||
return &OAuthUser{
|
||
Provider: provider,
|
||
OpenID: info.OpenID,
|
||
UnionID: info.UnionID,
|
||
Nickname: info.Nickname,
|
||
Avatar: info.HeadImgURL,
|
||
Gender: gender,
|
||
}, nil
|
||
}
|
||
case OAuthProviderQQ:
|
||
if entry.qq != nil {
|
||
info, err := entry.qq.GetUserInfo(ctx, token.AccessToken, token.OpenID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
avatar := info.FigureURL2
|
||
if avatar == "" {
|
||
avatar = info.FigureURL1
|
||
}
|
||
if avatar == "" {
|
||
avatar = info.FigureURL
|
||
}
|
||
return &OAuthUser{
|
||
Provider: provider,
|
||
OpenID: token.OpenID,
|
||
Nickname: info.Nickname,
|
||
Avatar: avatar,
|
||
Gender: info.Gender,
|
||
Extra: map[string]interface{}{
|
||
"province": info.Province,
|
||
"city": info.City,
|
||
"year": info.Year,
|
||
},
|
||
}, nil
|
||
}
|
||
case OAuthProviderGitHub:
|
||
if entry.github != nil {
|
||
info, err := entry.github.GetUserInfo(ctx, token.AccessToken)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
nickname := info.Name
|
||
if nickname == "" {
|
||
nickname = info.Login
|
||
}
|
||
return &OAuthUser{
|
||
Provider: provider,
|
||
OpenID: fmt.Sprintf("%d", info.ID),
|
||
Nickname: nickname,
|
||
Email: info.Email,
|
||
}, nil
|
||
}
|
||
case OAuthProviderAlipay:
|
||
if entry.alipay != nil {
|
||
info, err := entry.alipay.GetUserInfo(ctx, token.AccessToken)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
return &OAuthUser{
|
||
Provider: provider,
|
||
OpenID: info.UserID,
|
||
Nickname: info.Nickname,
|
||
Avatar: info.Avatar,
|
||
}, nil
|
||
}
|
||
case OAuthProviderDouyin:
|
||
if entry.douyin != nil {
|
||
info, err := entry.douyin.GetUserInfo(ctx, token.AccessToken, token.OpenID)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
gender := ""
|
||
switch info.Data.Gender {
|
||
case 1:
|
||
gender = "male"
|
||
case 2:
|
||
gender = "female"
|
||
}
|
||
return &OAuthUser{
|
||
Provider: provider,
|
||
OpenID: info.Data.OpenID,
|
||
UnionID: info.Data.UnionID,
|
||
Nickname: info.Data.Nickname,
|
||
Avatar: info.Data.Avatar,
|
||
Gender: gender,
|
||
}, nil
|
||
}
|
||
}
|
||
|
||
return nil, fmt.Errorf("provider %s: real HTTP user info not implemented yet", provider)
|
||
}
|
||
|
||
// ValidateToken 验证令牌
|
||
// 注意:由于 ValidateToken 不持有 provider 上下文,无法进行真正的 token 验证
|
||
// 对于需要验证 token 的场景,应使用 GetUserInfo 通过 provider 的 userinfo 端点验证
|
||
// 如果没有可用的 provider,返回错误
|
||
func (m *DefaultOAuthManager) ValidateToken(token string) (bool, error) {
|
||
if len(token) == 0 {
|
||
return false, nil
|
||
}
|
||
// 由于缺乏 provider 上下文,无法进行有意义的验证
|
||
// 遍历所有已启用的 provider,尝试通过 GetUserInfo 验证
|
||
// 如果没有任何 provider 可用,返回错误而不是默认通过
|
||
providers := m.GetEnabledProviders()
|
||
if len(providers) == 0 {
|
||
return false, errors.New("no OAuth providers configured")
|
||
}
|
||
// 尝试任一 provider 的 userinfo 端点验证
|
||
tokenObj := &OAuthToken{AccessToken: token}
|
||
for _, p := range providers {
|
||
if _, err := m.GetUserInfo(p.Provider, tokenObj); err == nil {
|
||
return true, nil
|
||
}
|
||
}
|
||
return false, nil
|
||
}
|
||
|
||
// ValidateTokenWithProvider 通过指定 provider 验证令牌
|
||
func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider, token string) (bool, error) {
|
||
if token == "" {
|
||
return false, nil
|
||
}
|
||
|
||
cfg, ok := m.GetConfig(provider)
|
||
if !ok || cfg.ClientID == "" {
|
||
return false, fmt.Errorf("provider %s not configured", provider)
|
||
}
|
||
|
||
// 通过 provider 的 userinfo 端点验证 token
|
||
tokenObj := &OAuthToken{AccessToken: token}
|
||
_, err := m.GetUserInfo(provider, tokenObj)
|
||
if err != nil {
|
||
return false, err
|
||
}
|
||
return true, nil
|
||
}
|
||
|
||
// GetEnabledProviders 获取已启用的OAuth提供商
|
||
func (m *DefaultOAuthManager) GetEnabledProviders() []OAuthProviderInfo {
|
||
providerNames := map[OAuthProvider]string{
|
||
OAuthProviderGoogle: "Google",
|
||
OAuthProviderWeChat: "微信",
|
||
OAuthProviderQQ: "QQ",
|
||
OAuthProviderWeibo: "微博",
|
||
OAuthProviderFacebook: "Facebook",
|
||
OAuthProviderTwitter: "Twitter",
|
||
OAuthProviderGitHub: "GitHub",
|
||
OAuthProviderAlipay: "支付宝",
|
||
OAuthProviderDouyin: "抖音",
|
||
}
|
||
|
||
var result []OAuthProviderInfo
|
||
for provider, entry := range m.entries {
|
||
name := providerNames[provider]
|
||
if name == "" {
|
||
name = string(provider)
|
||
}
|
||
result = append(result, OAuthProviderInfo{
|
||
Provider: provider,
|
||
Enabled: entry.config != nil,
|
||
Name: name,
|
||
})
|
||
}
|
||
return result
|
||
}
|