feat: backend core - auth, user, role, permission, device, webhook, monitoring, cache, repository, service, middleware, API handlers

This commit is contained in:
2026-04-02 11:19:50 +08:00
parent e59a77bc49
commit dcc1f186f8
298 changed files with 62603 additions and 0 deletions

View File

@@ -0,0 +1,256 @@
package providers
import (
"context"
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"net/http"
"net/url"
"sort"
"strings"
"time"
)
// AlipayProvider 支付宝 OAuth提供者
// 支付宝使用 RSA2 签名SHA256withRSA
type AlipayProvider struct {
AppID string
PrivateKey string // RSA2 私钥PKCS#8 PEM格式
RedirectURI string
IsSandbox bool
}
// AlipayTokenResponse 支付宝 Token响应
type AlipayTokenResponse struct {
UserID string `json:"user_id"`
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
}
// AlipayUserInfo 支付宝用户信息
type AlipayUserInfo struct {
UserID string `json:"user_id"`
Nickname string `json:"nick_name"`
Avatar string `json:"avatar"`
Gender string `json:"gender"`
}
// NewAlipayProvider 创建支付宝 OAuth提供者
func NewAlipayProvider(appID, privateKey, redirectURI string, isSandbox bool) *AlipayProvider {
return &AlipayProvider{
AppID: appID,
PrivateKey: privateKey,
RedirectURI: redirectURI,
IsSandbox: isSandbox,
}
}
func (a *AlipayProvider) getGateway() string {
if a.IsSandbox {
return "https://openapi-sandbox.dl.alipaydev.com/gateway.do"
}
return "https://openapi.alipay.com/gateway.do"
}
// GetAuthURL 获取支付宝授权URL
func (a *AlipayProvider) GetAuthURL(state string) (string, error) {
authURL := fmt.Sprintf(
"https://openauth.alipay.com/oauth2/publicAppAuthorize.htm?app_id=%s&scope=auth_user&redirect_uri=%s&state=%s",
a.AppID,
url.QueryEscape(a.RedirectURI),
url.QueryEscape(state),
)
return authURL, nil
}
// ExchangeCode 用授权码换取 access_token
func (a *AlipayProvider) ExchangeCode(ctx context.Context, code string) (*AlipayTokenResponse, error) {
params := map[string]string{
"app_id": a.AppID,
"method": "alipay.system.oauth.token",
"charset": "UTF-8",
"sign_type": "RSA2",
"timestamp": time.Now().Format("2006-01-02 15:04:05"),
"version": "1.0",
"grant_type": "authorization_code",
"code": code,
}
if a.PrivateKey != "" {
sign, err := a.signParams(params)
if err != nil {
return nil, fmt.Errorf("sign failed: %w", err)
}
params["sign"] = sign
}
form := url.Values{}
for k, v := range params {
form.Set(k, v)
}
req, err := http.NewRequestWithContext(ctx, "POST", a.getGateway(),
strings.NewReader(form.Encode()))
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var rawResp map[string]json.RawMessage
if err := json.Unmarshal(body, &rawResp); err != nil {
return nil, fmt.Errorf("parse response failed: %w", err)
}
tokenData, ok := rawResp["alipay_system_oauth_token_response"]
if !ok {
return nil, fmt.Errorf("invalid alipay response structure")
}
var tokenResp AlipayTokenResponse
if err := json.Unmarshal(tokenData, &tokenResp); err != nil {
return nil, fmt.Errorf("parse token response failed: %w", err)
}
return &tokenResp, nil
}
// GetUserInfo 获取支付宝用户信息
func (a *AlipayProvider) GetUserInfo(ctx context.Context, accessToken string) (*AlipayUserInfo, error) {
params := map[string]string{
"app_id": a.AppID,
"method": "alipay.user.info.share",
"charset": "UTF-8",
"sign_type": "RSA2",
"timestamp": time.Now().Format("2006-01-02 15:04:05"),
"version": "1.0",
"auth_token": accessToken,
}
if a.PrivateKey != "" {
sign, err := a.signParams(params)
if err != nil {
return nil, fmt.Errorf("sign failed: %w", err)
}
params["sign"] = sign
}
form := url.Values{}
for k, v := range params {
form.Set(k, v)
}
req, err := http.NewRequestWithContext(ctx, "POST", a.getGateway(),
strings.NewReader(form.Encode()))
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var rawResp map[string]json.RawMessage
if err := json.Unmarshal(body, &rawResp); err != nil {
return nil, fmt.Errorf("parse response failed: %w", err)
}
userData, ok := rawResp["alipay_user_info_share_response"]
if !ok {
return nil, fmt.Errorf("invalid alipay user info response")
}
var userInfo AlipayUserInfo
if err := json.Unmarshal(userData, &userInfo); err != nil {
return nil, fmt.Errorf("parse user info failed: %w", err)
}
return &userInfo, nil
}
// signParams 使用 RSA2SHA256withRSA对参数签名
func (a *AlipayProvider) signParams(params map[string]string) (string, error) {
// 按字典序排列参数
keys := make([]string, 0, len(params))
for k := range params {
if k != "sign" {
keys = append(keys, k)
}
}
sort.Strings(keys)
var parts []string
for _, k := range keys {
parts = append(parts, k+"="+params[k])
}
signContent := strings.Join(parts, "&")
// 解析私钥
privKey, err := parseAlipayPrivateKey(a.PrivateKey)
if err != nil {
return "", fmt.Errorf("parse private key: %w", err)
}
// SHA256withRSA 签名
hash := sha256.Sum256([]byte(signContent))
signature, err := rsa.SignPKCS1v15(rand.Reader, privKey, crypto.SHA256, hash[:])
if err != nil {
return "", fmt.Errorf("rsa sign: %w", err)
}
return base64.StdEncoding.EncodeToString(signature), nil
}
// parseAlipayPrivateKey 解析支付宝私钥(支持 PKCS#8 和 PKCS#1
func parseAlipayPrivateKey(pemStr string) (*rsa.PrivateKey, error) {
// 如果没有 PEM 头,添加 PKCS#8 头
if !strings.Contains(pemStr, "-----BEGIN") {
pemStr = "-----BEGIN PRIVATE KEY-----\n" + pemStr + "\n-----END PRIVATE KEY-----"
}
block, _ := pem.Decode([]byte(pemStr))
if block == nil {
return nil, fmt.Errorf("failed to decode PEM block")
}
// 尝试 PKCS#8
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err == nil {
rsaKey, ok := key.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("not an RSA private key")
}
return rsaKey, nil
}
// 尝试 PKCS#1
return x509.ParsePKCS1PrivateKey(block.Bytes)
}

View File

@@ -0,0 +1,138 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
// DouyinProvider 抖音 OAuth提供者
// 抖音 OAuth 文档https://developer.open-douyin.com/docs/resource/zh-CN/dop/develop/openapi/account-permission/get-access-token
type DouyinProvider struct {
ClientKey string // 抖音开放平台 client_key
ClientSecret string // 抖音开放平台 client_secret
RedirectURI string
}
// DouyinTokenResponse 抖音 Token响应
type DouyinTokenResponse struct {
Data struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
RefreshExpiresIn int `json:"refresh_expires_in"`
OpenID string `json:"open_id"`
Scope string `json:"scope"`
} `json:"data"`
Message string `json:"message"`
}
// DouyinUserInfo 抖音用户信息
type DouyinUserInfo struct {
Data struct {
OpenID string `json:"open_id"`
UnionID string `json:"union_id"`
Nickname string `json:"nickname"`
Avatar string `json:"avatar"`
Gender int `json:"gender"` // 0:未知 1:男 2:女
Country string `json:"country"`
Province string `json:"province"`
City string `json:"city"`
} `json:"data"`
Message string `json:"message"`
}
// NewDouyinProvider 创建抖音 OAuth提供者
func NewDouyinProvider(clientKey, clientSecret, redirectURI string) *DouyinProvider {
return &DouyinProvider{
ClientKey: clientKey,
ClientSecret: clientSecret,
RedirectURI: redirectURI,
}
}
// GetAuthURL 获取抖音授权URL
func (d *DouyinProvider) GetAuthURL(state string) (string, error) {
authURL := fmt.Sprintf(
"https://open.douyin.com/platform/oauth/connect?client_key=%s&redirect_uri=%s&response_type=code&scope=user_info&state=%s",
d.ClientKey,
url.QueryEscape(d.RedirectURI),
url.QueryEscape(state),
)
return authURL, nil
}
// ExchangeCode 用授权码换取 access_token
func (d *DouyinProvider) ExchangeCode(ctx context.Context, code string) (*DouyinTokenResponse, error) {
tokenURL := "https://open.douyin.com/oauth/access_token/"
data := url.Values{}
data.Set("client_key", d.ClientKey)
data.Set("client_secret", d.ClientSecret)
data.Set("code", code)
data.Set("grant_type", "authorization_code")
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL,
strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var tokenResp DouyinTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parse token response failed: %w", err)
}
if tokenResp.Data.AccessToken == "" {
return nil, fmt.Errorf("抖音 OAuth: %s", tokenResp.Message)
}
return &tokenResp, nil
}
// GetUserInfo 获取抖音用户信息
func (d *DouyinProvider) GetUserInfo(ctx context.Context, accessToken, openID string) (*DouyinUserInfo, error) {
userInfoURL := fmt.Sprintf("https://open.douyin.com/oauth/userinfo/?open_id=%s&access_token=%s",
url.QueryEscape(openID), url.QueryEscape(accessToken))
req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var userInfo DouyinUserInfo
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("parse user info failed: %w", err)
}
return &userInfo, nil
}

View File

@@ -0,0 +1,207 @@
package providers
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"
)
// FacebookProvider Facebook OAuth提供者
type FacebookProvider struct {
AppID string
AppSecret string
RedirectURI string
}
// FacebookAuthURLResponse Facebook授权URL响应
type FacebookAuthURLResponse struct {
URL string `json:"url"`
State string `json:"state"`
Redirect string `json:"redirect,omitempty"`
}
// FacebookTokenResponse Facebook Token响应
type FacebookTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
}
// FacebookUserInfo Facebook用户信息
type FacebookUserInfo struct {
ID string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
Picture struct {
Data struct {
URL string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
IsSilhouette bool `json:"is_silhouette"`
} `json:"data"`
} `json:"picture"`
}
// NewFacebookProvider 创建Facebook OAuth提供者
func NewFacebookProvider(appID, appSecret, redirectURI string) *FacebookProvider {
return &FacebookProvider{
AppID: appID,
AppSecret: appSecret,
RedirectURI: redirectURI,
}
}
// GenerateState 生成随机状态码
func (f *FacebookProvider) GenerateState() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
// GetAuthURL 获取Facebook授权URL
func (f *FacebookProvider) GetAuthURL(state string) (*FacebookAuthURLResponse, error) {
authURL := fmt.Sprintf(
"https://www.facebook.com/v18.0/dialog/oauth?client_id=%s&redirect_uri=%s&scope=email,public_profile&response_type=code&state=%s",
f.AppID,
url.QueryEscape(f.RedirectURI),
state,
)
return &FacebookAuthURLResponse{
URL: authURL,
State: state,
Redirect: f.RedirectURI,
}, nil
}
// ExchangeCode 用授权码换取访问令牌
func (f *FacebookProvider) ExchangeCode(ctx context.Context, code string) (*FacebookTokenResponse, error) {
tokenURL := fmt.Sprintf(
"https://graph.facebook.com/v18.0/oauth/access_token?client_id=%s&client_secret=%s&redirect_uri=%s&code=%s",
f.AppID,
f.AppSecret,
url.QueryEscape(f.RedirectURI),
code,
)
req, err := http.NewRequestWithContext(ctx, "GET", tokenURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var tokenResp FacebookTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parse token response failed: %w", err)
}
return &tokenResp, nil
}
// GetUserInfo 获取Facebook用户信息
func (f *FacebookProvider) GetUserInfo(ctx context.Context, accessToken string) (*FacebookUserInfo, error) {
// 请求用户信息(包括头像)
userInfoURL := fmt.Sprintf(
"https://graph.facebook.com/v18.0/me?fields=id,name,email,picture&access_token=%s",
accessToken,
)
req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
// Facebook错误响应
var errResp struct {
Error struct {
Message string `json:"message"`
Type string `json:"type"`
Code int `json:"code"`
ErrorSubcode int `json:"error_subcode,omitempty"`
} `json:"error"`
}
if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error.Message != "" {
return nil, fmt.Errorf("facebook api error: %s", errResp.Error.Message)
}
var userInfo FacebookUserInfo
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("parse user info failed: %w", err)
}
return &userInfo, nil
}
// ValidateToken 验证访问令牌是否有效
func (f *FacebookProvider) ValidateToken(ctx context.Context, accessToken string) (bool, error) {
userInfo, err := f.GetUserInfo(ctx, accessToken)
if err != nil {
return false, err
}
return userInfo != nil && userInfo.ID != "", nil
}
// GetLongLivedToken 获取长期有效的访问令牌60天
func (f *FacebookProvider) GetLongLivedToken(ctx context.Context, shortLivedToken string) (*FacebookTokenResponse, error) {
tokenURL := fmt.Sprintf(
"https://graph.facebook.com/v18.0/oauth/access_token?grant_type=fb_exchange_token&client_id=%s&client_secret=%s&fb_exchange_token=%s",
f.AppID,
f.AppSecret,
shortLivedToken,
)
req, err := http.NewRequestWithContext(ctx, "GET", tokenURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var tokenResp FacebookTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parse token response failed: %w", err)
}
return &tokenResp, nil
}

View File

@@ -0,0 +1,172 @@
package providers
import (
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"time"
)
// GitHubProvider GitHub OAuth提供者
type GitHubProvider struct {
ClientID string
ClientSecret string
RedirectURI string
}
// GitHubTokenResponse GitHub Token响应
type GitHubTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
}
// GitHubUserInfo GitHub用户信息
type GitHubUserInfo struct {
ID int64 `json:"id"`
Login string `json:"login"`
Name string `json:"name"`
Email string `json:"email"`
AvatarURL string `json:"avatar_url"`
Bio string `json:"bio"`
Location string `json:"location"`
}
// NewGitHubProvider 创建GitHub OAuth提供者
func NewGitHubProvider(clientID, clientSecret, redirectURI string) *GitHubProvider {
return &GitHubProvider{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURI: redirectURI,
}
}
// GetAuthURL 获取GitHub授权URL
func (g *GitHubProvider) GetAuthURL(state string) (string, error) {
authURL := fmt.Sprintf(
"https://github.com/login/oauth/authorize?client_id=%s&redirect_uri=%s&scope=read:user,user:email&state=%s",
g.ClientID,
url.QueryEscape(g.RedirectURI),
url.QueryEscape(state),
)
return authURL, nil
}
// ExchangeCode 用授权码换取访问令牌
func (g *GitHubProvider) ExchangeCode(ctx context.Context, code string) (*GitHubTokenResponse, error) {
tokenURL := "https://github.com/login/oauth/access_token"
data := url.Values{}
data.Set("client_id", g.ClientID)
data.Set("client_secret", g.ClientSecret)
data.Set("code", code)
data.Set("redirect_uri", g.RedirectURI)
req, err := http.NewRequestWithContext(ctx, "POST", tokenURL,
strings.NewReader(data.Encode()))
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var tokenResp GitHubTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parse token response failed: %w", err)
}
if tokenResp.AccessToken == "" {
return nil, fmt.Errorf("GitHub OAuth: empty access token in response")
}
return &tokenResp, nil
}
// GetUserInfo 获取GitHub用户信息
func (g *GitHubProvider) GetUserInfo(ctx context.Context, accessToken string) (*GitHubUserInfo, error) {
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/vnd.github+json")
req.Header.Set("X-GitHub-Api-Version", "2022-11-28")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var userInfo GitHubUserInfo
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("parse user info failed: %w", err)
}
// 如果用户信息中的邮箱为空,尝试通过邮箱 API 获取主要邮箱
if userInfo.Email == "" {
email, _ := g.getPrimaryEmail(ctx, accessToken)
userInfo.Email = email
}
return &userInfo, nil
}
// getPrimaryEmail 获取用户的主要邮箱
func (g *GitHubProvider) getPrimaryEmail(ctx context.Context, accessToken string) (string, error) {
req, err := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user/emails", nil)
if err != nil {
return "", err
}
req.Header.Set("Authorization", "Bearer "+accessToken)
req.Header.Set("Accept", "application/vnd.github+json")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return "", err
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return "", err
}
var emails []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
Verified bool `json:"verified"`
}
if err := json.Unmarshal(body, &emails); err != nil {
return "", err
}
for _, e := range emails {
if e.Primary && e.Verified {
return e.Email, nil
}
}
return "", nil
}

View File

@@ -0,0 +1,182 @@
package providers
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"
)
// GoogleProvider Google OAuth提供者
type GoogleProvider struct {
ClientID string
ClientSecret string
RedirectURI string
}
// GoogleAuthURLResponse Google授权URL响应
type GoogleAuthURLResponse struct {
URL string `json:"url"`
State string `json:"state"`
Redirect string `json:"redirect,omitempty"`
}
// GoogleTokenResponse Google Token响应
type GoogleTokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
TokenType string `json:"token_type"`
Scope string `json:"scope"`
}
// GoogleUserInfo Google用户信息
type GoogleUserInfo struct {
ID string `json:"id"`
Email string `json:"email"`
VerifiedEmail bool `json:"verified_email"`
Name string `json:"name"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
Picture string `json:"picture"`
Locale string `json:"locale"`
}
// NewGoogleProvider 创建Google OAuth提供者
func NewGoogleProvider(clientID, clientSecret, redirectURI string) *GoogleProvider {
return &GoogleProvider{
ClientID: clientID,
ClientSecret: clientSecret,
RedirectURI: redirectURI,
}
}
// GenerateState 生成随机状态码
func (g *GoogleProvider) GenerateState() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
// GetAuthURL 获取Google授权URL
func (g *GoogleProvider) GetAuthURL(state string) (*GoogleAuthURLResponse, error) {
authURL := fmt.Sprintf(
"https://accounts.google.com/o/oauth2/v2/auth?client_id=%s&redirect_uri=%s&response_type=code&scope=openid+email+profile&state=%s",
g.ClientID,
url.QueryEscape(g.RedirectURI),
state,
)
return &GoogleAuthURLResponse{
URL: authURL,
State: state,
Redirect: g.RedirectURI,
}, nil
}
// ExchangeCode 用授权码换取访问令牌
func (g *GoogleProvider) ExchangeCode(ctx context.Context, code string) (*GoogleTokenResponse, error) {
tokenURL := "https://oauth2.googleapis.com/token"
data := url.Values{}
data.Set("code", code)
data.Set("client_id", g.ClientID)
data.Set("client_secret", g.ClientSecret)
data.Set("redirect_uri", g.RedirectURI)
data.Set("grant_type", "authorization_code")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := postFormWithContext(ctx, client, tokenURL, data)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var tokenResp GoogleTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parse token response failed: %w", err)
}
return &tokenResp, nil
}
// GetUserInfo 获取Google用户信息
func (g *GoogleProvider) GetUserInfo(ctx context.Context, accessToken string) (*GoogleUserInfo, error) {
userInfoURL := fmt.Sprintf("https://www.googleapis.com/oauth2/v2/userinfo?access_token=%s", accessToken)
req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var userInfo GoogleUserInfo
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("parse user info failed: %w", err)
}
return &userInfo, nil
}
// RefreshToken 刷新访问令牌
func (g *GoogleProvider) RefreshToken(ctx context.Context, refreshToken string) (*GoogleTokenResponse, error) {
tokenURL := "https://oauth2.googleapis.com/token"
data := url.Values{}
data.Set("refresh_token", refreshToken)
data.Set("client_id", g.ClientID)
data.Set("client_secret", g.ClientSecret)
data.Set("grant_type", "refresh_token")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := postFormWithContext(ctx, client, tokenURL, data)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var tokenResp GoogleTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parse token response failed: %w", err)
}
return &tokenResp, nil
}
// ValidateToken 验证访问令牌是否有效
func (g *GoogleProvider) ValidateToken(ctx context.Context, accessToken string) (bool, error) {
userInfo, err := g.GetUserInfo(ctx, accessToken)
if err != nil {
return false, err
}
return userInfo != nil, nil
}

View File

@@ -0,0 +1,43 @@
package providers
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
"strings"
)
const maxOAuthResponseBodyBytes = 1 << 20
func postFormWithContext(ctx context.Context, client *http.Client, endpoint string, data url.Values) (*http.Response, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, strings.NewReader(data.Encode()))
if err != nil {
return nil, err
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
return client.Do(req)
}
func readOAuthResponseBody(resp *http.Response) ([]byte, error) {
limited := io.LimitReader(resp.Body, maxOAuthResponseBodyBytes+1)
body, err := io.ReadAll(limited)
if err != nil {
return nil, err
}
if len(body) > maxOAuthResponseBodyBytes {
return nil, fmt.Errorf("oauth response body exceeded %d bytes", maxOAuthResponseBodyBytes)
}
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
snippet := strings.TrimSpace(string(body))
if len(snippet) > 256 {
snippet = snippet[:256]
}
if snippet == "" {
return nil, fmt.Errorf("oauth request failed with status %d", resp.StatusCode)
}
return nil, fmt.Errorf("oauth request failed with status %d: %s", resp.StatusCode, snippet)
}
return body, nil
}

View File

@@ -0,0 +1,66 @@
package providers
import (
"bytes"
"io"
"net/http"
"strings"
"testing"
)
func TestReadOAuthResponseBodyRejectsOversizedResponse(t *testing.T) {
resp := &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(bytes.NewReader(
bytes.Repeat([]byte("a"), maxOAuthResponseBodyBytes+1),
)),
}
_, err := readOAuthResponseBody(resp)
if err == nil || !strings.Contains(err.Error(), "exceeded") {
t.Fatalf("expected oversized response error, got %v", err)
}
}
func TestReadOAuthResponseBodyRejectsNonSuccessStatus(t *testing.T) {
resp := &http.Response{
StatusCode: http.StatusBadGateway,
Body: io.NopCloser(strings.NewReader("provider unavailable")),
}
_, err := readOAuthResponseBody(resp)
if err == nil || !strings.Contains(err.Error(), "502") {
t.Fatalf("expected status error, got %v", err)
}
}
func TestReadOAuthResponseBodyHandlesEmptyErrorBody(t *testing.T) {
resp := &http.Response{
StatusCode: http.StatusServiceUnavailable,
Body: io.NopCloser(strings.NewReader(" ")),
}
_, err := readOAuthResponseBody(resp)
if err == nil || !strings.Contains(err.Error(), "503") {
t.Fatalf("expected empty-body status error, got %v", err)
}
}
func TestReadOAuthResponseBodyTruncatesLongErrorSnippet(t *testing.T) {
longBody := strings.Repeat("x", 400)
resp := &http.Response{
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(strings.NewReader(longBody)),
}
_, err := readOAuthResponseBody(resp)
if err == nil {
t.Fatal("expected long error body to produce status error")
}
if !strings.Contains(err.Error(), "400") {
t.Fatalf("expected status code in error, got %v", err)
}
if strings.Contains(err.Error(), strings.Repeat("x", 300)) {
t.Fatalf("expected error snippet to be truncated, got %v", err)
}
}

View File

@@ -0,0 +1,169 @@
package providers
import (
"crypto"
"crypto/rand"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"net/url"
"strings"
"testing"
)
func generateRSAKeyForTest(t *testing.T) *rsa.PrivateKey {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
t.Fatalf("generate rsa key failed: %v", err)
}
return key
}
func marshalPKCS8PEMForTest(t *testing.T, key *rsa.PrivateKey) string {
t.Helper()
der, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
t.Fatalf("marshal PKCS#8 failed: %v", err)
}
return string(pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: der,
}))
}
func TestParseAlipayPrivateKeySupportsRawPKCS8AndPKCS1(t *testing.T) {
key := generateRSAKeyForTest(t)
pkcs8DER, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
t.Fatalf("marshal PKCS#8 failed: %v", err)
}
rawPKCS8 := base64.StdEncoding.EncodeToString(pkcs8DER)
parsedPKCS8, err := parseAlipayPrivateKey(rawPKCS8)
if err != nil {
t.Fatalf("parse raw PKCS#8 key failed: %v", err)
}
if parsedPKCS8.N.Cmp(key.N) != 0 || parsedPKCS8.D.Cmp(key.D) != 0 {
t.Fatal("parsed raw PKCS#8 key does not match original key")
}
pkcs1PEM := string(pem.EncodeToMemory(&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
}))
parsedPKCS1, err := parseAlipayPrivateKey(pkcs1PEM)
if err != nil {
t.Fatalf("parse PKCS#1 key failed: %v", err)
}
if parsedPKCS1.N.Cmp(key.N) != 0 || parsedPKCS1.D.Cmp(key.D) != 0 {
t.Fatal("parsed PKCS#1 key does not match original key")
}
}
func TestParseAlipayPrivateKeyRejectsInvalidPEM(t *testing.T) {
if _, err := parseAlipayPrivateKey("not-a-valid-private-key"); err == nil {
t.Fatal("expected invalid private key parsing to fail")
}
}
func TestAlipayProviderSignParamsProducesVerifiableSignature(t *testing.T) {
key := generateRSAKeyForTest(t)
provider := NewAlipayProvider(
"app-id",
marshalPKCS8PEMForTest(t, key),
"https://admin.example.com/login/oauth/callback",
false,
)
params := map[string]string{
"method": "alipay.system.oauth.token",
"app_id": "app-id",
"code": "auth-code",
"sign": "should-be-ignored",
}
signature, err := provider.signParams(params)
if err != nil {
t.Fatalf("signParams failed: %v", err)
}
if signature == "" {
t.Fatal("expected non-empty signature")
}
signatureBytes, err := base64.StdEncoding.DecodeString(signature)
if err != nil {
t.Fatalf("decode signature failed: %v", err)
}
signContent := "app_id=app-id&code=auth-code&method=alipay.system.oauth.token"
hash := sha256.Sum256([]byte(signContent))
if err := rsa.VerifyPKCS1v15(&key.PublicKey, crypto.SHA256, hash[:], signatureBytes); err != nil {
t.Fatalf("signature verification failed: %v", err)
}
}
func TestTwitterProviderPKCEHelpersAndAuthURL(t *testing.T) {
provider := NewTwitterProvider("twitter-client", "https://admin.example.com/login/oauth/callback")
verifierA, err := provider.GenerateCodeVerifier()
if err != nil {
t.Fatalf("GenerateCodeVerifier(first) failed: %v", err)
}
verifierB, err := provider.GenerateCodeVerifier()
if err != nil {
t.Fatalf("GenerateCodeVerifier(second) failed: %v", err)
}
if verifierA == "" || verifierB == "" {
t.Fatal("expected non-empty code verifiers")
}
if verifierA == verifierB {
t.Fatal("expected code verifiers to differ across calls")
}
if strings.Contains(verifierA, "=") || strings.Contains(verifierB, "=") {
t.Fatal("expected code verifiers to be base64url values without padding")
}
if provider.GenerateCodeChallenge(verifierA) != verifierA {
t.Fatal("expected current code challenge implementation to mirror the verifier")
}
authURL, err := provider.GetAuthURL()
if err != nil {
t.Fatalf("GetAuthURL failed: %v", err)
}
if authURL.CodeVerifier == "" || authURL.State == "" {
t.Fatal("expected auth url response to include verifier and state")
}
if authURL.Redirect != provider.RedirectURI {
t.Fatalf("expected redirect %q, got %q", provider.RedirectURI, authURL.Redirect)
}
parsed, err := url.Parse(authURL.URL)
if err != nil {
t.Fatalf("parse auth url failed: %v", err)
}
query := parsed.Query()
if query.Get("client_id") != "twitter-client" {
t.Fatalf("expected twitter client_id, got %q", query.Get("client_id"))
}
if query.Get("redirect_uri") != provider.RedirectURI {
t.Fatalf("expected redirect_uri %q, got %q", provider.RedirectURI, query.Get("redirect_uri"))
}
if query.Get("code_challenge") != authURL.CodeVerifier {
t.Fatalf("expected code challenge to equal verifier, got %q", query.Get("code_challenge"))
}
if query.Get("code_challenge_method") != "plain" {
t.Fatalf("expected code_challenge_method plain, got %q", query.Get("code_challenge_method"))
}
if query.Get("state") != authURL.State {
t.Fatalf("expected state %q, got %q", authURL.State, query.Get("state"))
}
}

View File

@@ -0,0 +1,649 @@
package providers
import (
"context"
"io"
"net/http"
"net/url"
"strings"
"testing"
)
func parseRequestForm(t *testing.T, req *http.Request) url.Values {
t.Helper()
body, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("read request body failed: %v", err)
}
values, err := url.ParseQuery(string(body))
if err != nil {
t.Fatalf("parse request body failed: %v", err)
}
return values
}
func TestPostFormWithContextSendsEncodedBody(t *testing.T) {
client := &http.Client{
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.Method != http.MethodPost {
t.Fatalf("expected POST request, got %s", req.Method)
}
if req.URL.String() != "https://oauth.example.com/token" {
t.Fatalf("unexpected endpoint: %s", req.URL.String())
}
if req.Header.Get("Content-Type") != "application/x-www-form-urlencoded" {
t.Fatalf("unexpected content type: %s", req.Header.Get("Content-Type"))
}
form := parseRequestForm(t, req)
if form.Get("code") != "auth-code" || form.Get("grant_type") != "authorization_code" {
t.Fatalf("unexpected form payload: %#v", form)
}
return oauthResponse(`{"ok":true}`), nil
}),
}
resp, err := postFormWithContext(context.Background(), client, "https://oauth.example.com/token", url.Values{
"code": {"auth-code"},
"grant_type": {"authorization_code"},
})
if err != nil {
t.Fatalf("postFormWithContext failed: %v", err)
}
defer resp.Body.Close()
}
func TestAlipayProviderExchangeCodeAndGetUserInfo(t *testing.T) {
ctx := context.Background()
provider := NewAlipayProvider("alipay-app", "", "https://example.com/callback", false)
t.Run("exchange code success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "openapi.alipay.com" || req.URL.Path != "/gateway.do" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
form := parseRequestForm(t, req)
if form.Get("method") != "alipay.system.oauth.token" || form.Get("code") != "auth-code" {
t.Fatalf("unexpected exchange payload: %#v", form)
}
return oauthResponse(`{"alipay_system_oauth_token_response":{"user_id":"2088","access_token":"ali-token","expires_in":3600}}`), nil
}))
tokenResp, err := provider.ExchangeCode(ctx, "auth-code")
if err != nil {
t.Fatalf("expected exchange success, got error %v", err)
}
if tokenResp.AccessToken != "ali-token" || tokenResp.UserID != "2088" {
t.Fatalf("unexpected alipay token response: %#v", tokenResp)
}
})
t.Run("exchange code rejects invalid structure", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
return oauthResponse(`{"unexpected":{}}`), nil
}))
_, err := provider.ExchangeCode(ctx, "auth-code")
if err == nil || !strings.Contains(err.Error(), "invalid alipay response structure") {
t.Fatalf("expected invalid structure error, got %v", err)
}
})
t.Run("get user info success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "openapi.alipay.com" || req.URL.Path != "/gateway.do" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
form := parseRequestForm(t, req)
if form.Get("method") != "alipay.user.info.share" || form.Get("auth_token") != "ali-token" {
t.Fatalf("unexpected user-info payload: %#v", form)
}
return oauthResponse(`{"alipay_user_info_share_response":{"user_id":"2088","nick_name":"Ali User","avatar":"https://cdn.example.com/avatar.png"}}`), nil
}))
userInfo, err := provider.GetUserInfo(ctx, "ali-token")
if err != nil {
t.Fatalf("expected user info success, got error %v", err)
}
if userInfo.UserID != "2088" || userInfo.Nickname != "Ali User" {
t.Fatalf("unexpected alipay user info: %#v", userInfo)
}
})
t.Run("get user info rejects invalid structure", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
return oauthResponse(`{"unexpected":{}}`), nil
}))
_, err := provider.GetUserInfo(ctx, "ali-token")
if err == nil || !strings.Contains(err.Error(), "invalid alipay user info response") {
t.Fatalf("expected invalid user info response error, got %v", err)
}
})
}
func TestDouyinProviderExchangeCodeAndGetUserInfo(t *testing.T) {
ctx := context.Background()
provider := NewDouyinProvider("douyin-key", "douyin-secret", "https://example.com/callback")
t.Run("exchange code success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "open.douyin.com" || req.URL.Path != "/oauth/access_token/" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
form := parseRequestForm(t, req)
if form.Get("client_key") != "douyin-key" || form.Get("code") != "auth-code" {
t.Fatalf("unexpected exchange payload: %#v", form)
}
return oauthResponse(`{"data":{"access_token":"douyin-token","open_id":"open-1"},"message":"success"}`), nil
}))
tokenResp, err := provider.ExchangeCode(ctx, "auth-code")
if err != nil {
t.Fatalf("expected exchange success, got error %v", err)
}
if tokenResp.Data.AccessToken != "douyin-token" || tokenResp.Data.OpenID != "open-1" {
t.Fatalf("unexpected douyin token response: %#v", tokenResp)
}
})
t.Run("exchange code rejects empty access token", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
return oauthResponse(`{"data":{},"message":"invalid code"}`), nil
}))
_, err := provider.ExchangeCode(ctx, "auth-code")
if err == nil || !strings.Contains(err.Error(), "invalid code") {
t.Fatalf("expected douyin api error, got %v", err)
}
})
t.Run("get user info success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "open.douyin.com" || req.URL.Path != "/oauth/userinfo/" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
if req.URL.Query().Get("open_id") != "open-1" {
t.Fatalf("unexpected open_id: %s", req.URL.Query().Get("open_id"))
}
return oauthResponse(`{"data":{"open_id":"open-1","union_id":"union-1","nickname":"Douyin User"}}`), nil
}))
userInfo, err := provider.GetUserInfo(ctx, "douyin-token", "open-1")
if err != nil {
t.Fatalf("expected user info success, got error %v", err)
}
if userInfo.Data.OpenID != "open-1" || userInfo.Data.Nickname != "Douyin User" {
t.Fatalf("unexpected douyin user info: %#v", userInfo)
}
})
}
func TestGitHubProviderExchangeCodeAndGetUserInfo(t *testing.T) {
ctx := context.Background()
provider := NewGitHubProvider("github-client", "github-secret", "https://example.com/callback")
t.Run("exchange code success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "github.com" || req.URL.Path != "/login/oauth/access_token" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
form := parseRequestForm(t, req)
if form.Get("client_id") != "github-client" || form.Get("code") != "auth-code" {
t.Fatalf("unexpected exchange payload: %#v", form)
}
return oauthResponse(`{"access_token":"gh-token","token_type":"bearer","scope":"read:user"}`), nil
}))
tokenResp, err := provider.ExchangeCode(ctx, "auth-code")
if err != nil {
t.Fatalf("expected exchange success, got error %v", err)
}
if tokenResp.AccessToken != "gh-token" {
t.Fatalf("unexpected github token response: %#v", tokenResp)
}
})
t.Run("exchange code rejects empty token", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
return oauthResponse(`{"token_type":"bearer"}`), nil
}))
_, err := provider.ExchangeCode(ctx, "auth-code")
if err == nil || !strings.Contains(err.Error(), "empty access token") {
t.Fatalf("expected empty access token error, got %v", err)
}
})
t.Run("get user info falls back to primary email", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
switch req.URL.Host + req.URL.Path {
case "api.github.com/user":
if req.Header.Get("Authorization") != "Bearer gh-token" {
t.Fatalf("unexpected auth header: %s", req.Header.Get("Authorization"))
}
return oauthResponse(`{"id":101,"login":"octocat","name":"The Octocat","email":"","avatar_url":"https://cdn.example.com/octocat.png"}`), nil
case "api.github.com/user/emails":
return oauthResponse(`[{"email":"secondary@example.com","primary":false,"verified":true},{"email":"primary@example.com","primary":true,"verified":true}]`), nil
default:
t.Fatalf("unexpected request: %s", req.URL.String())
return nil, nil
}
}))
userInfo, err := provider.GetUserInfo(ctx, "gh-token")
if err != nil {
t.Fatalf("expected user info success, got error %v", err)
}
if userInfo.Login != "octocat" || userInfo.Email != "primary@example.com" {
t.Fatalf("unexpected github user info: %#v", userInfo)
}
})
}
func TestGoogleProviderExchangeCodeAndRefreshToken(t *testing.T) {
ctx := context.Background()
provider := NewGoogleProvider("google-client", "google-secret", "https://example.com/callback")
t.Run("exchange code success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "oauth2.googleapis.com" || req.URL.Path != "/token" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
form := parseRequestForm(t, req)
if form.Get("grant_type") != "authorization_code" || form.Get("code") != "auth-code" {
t.Fatalf("unexpected exchange payload: %#v", form)
}
return oauthResponse(`{"access_token":"google-token","expires_in":3600,"refresh_token":"refresh-1","token_type":"Bearer"}`), nil
}))
tokenResp, err := provider.ExchangeCode(ctx, "auth-code")
if err != nil {
t.Fatalf("expected exchange success, got error %v", err)
}
if tokenResp.AccessToken != "google-token" || tokenResp.RefreshToken != "refresh-1" {
t.Fatalf("unexpected google token response: %#v", tokenResp)
}
})
t.Run("refresh token success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "oauth2.googleapis.com" || req.URL.Path != "/token" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
form := parseRequestForm(t, req)
if form.Get("grant_type") != "refresh_token" || form.Get("refresh_token") != "refresh-1" {
t.Fatalf("unexpected refresh payload: %#v", form)
}
return oauthResponse(`{"access_token":"google-token-2","expires_in":3600,"token_type":"Bearer"}`), nil
}))
tokenResp, err := provider.RefreshToken(ctx, "refresh-1")
if err != nil {
t.Fatalf("expected refresh success, got error %v", err)
}
if tokenResp.AccessToken != "google-token-2" {
t.Fatalf("unexpected google refresh response: %#v", tokenResp)
}
})
}
func TestQQProviderExchangeCodeAndValidateToken(t *testing.T) {
ctx := context.Background()
provider := NewQQProvider("qq-app", "qq-secret", "https://example.com/callback")
t.Run("exchange code success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "graph.qq.com" || req.URL.Path != "/oauth2.0/token" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
if req.URL.Query().Get("code") != "auth-code" {
t.Fatalf("unexpected code: %s", req.URL.Query().Get("code"))
}
return oauthResponse(`{"access_token":"qq-token","expires_in":3600,"refresh_token":"qq-refresh"}`), nil
}))
tokenResp, err := provider.ExchangeCode(ctx, "auth-code")
if err != nil {
t.Fatalf("expected exchange success, got error %v", err)
}
if tokenResp.AccessToken != "qq-token" || tokenResp.RefreshToken != "qq-refresh" {
t.Fatalf("unexpected qq token response: %#v", tokenResp)
}
})
t.Run("validate token success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "graph.qq.com" || req.URL.Path != "/oauth2.0/me" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"client_id":"qq-app","openid":"openid-1"}`), nil
}))
valid, err := provider.ValidateToken(ctx, "qq-token")
if err != nil {
t.Fatalf("expected validate success, got error %v", err)
}
if !valid {
t.Fatal("expected qq token to be valid")
}
})
}
func TestTwitterProviderNetworkMethods(t *testing.T) {
ctx := context.Background()
provider := NewTwitterProvider("twitter-client", "https://example.com/callback")
t.Run("exchange code rejects twitter error response", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "api.twitter.com" || req.URL.Path != "/2/oauth2/token" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
form := parseRequestForm(t, req)
if form.Get("grant_type") != "authorization_code" || form.Get("code_verifier") != "verifier-1" {
t.Fatalf("unexpected exchange payload: %#v", form)
}
return oauthResponse(`{"title":"Unauthorized","detail":"invalid verifier","status":401}`), nil
}))
_, err := provider.ExchangeCode(ctx, "auth-code", "verifier-1")
if err == nil || !strings.Contains(err.Error(), "invalid verifier") {
t.Fatalf("expected twitter api error, got %v", err)
}
})
t.Run("exchange code success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
return oauthResponse(`{"access_token":"twitter-token","refresh_token":"twitter-refresh","token_type":"bearer"}`), nil
}))
tokenResp, err := provider.ExchangeCode(ctx, "auth-code", "verifier-1")
if err != nil {
t.Fatalf("expected exchange success, got error %v", err)
}
if tokenResp.AccessToken != "twitter-token" {
t.Fatalf("unexpected twitter token response: %#v", tokenResp)
}
})
t.Run("get user info rejects twitter error response", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "api.twitter.com" || req.URL.Path != "/2/users/me" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"title":"Unauthorized","detail":"token expired","status":401}`), nil
}))
_, err := provider.GetUserInfo(ctx, "twitter-token")
if err == nil || !strings.Contains(err.Error(), "token expired") {
t.Fatalf("expected twitter user info error, got %v", err)
}
})
t.Run("get user info success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
return oauthResponse(`{"data":{"id":"user-1","name":"Twitter User","username":"tw-user"}}`), nil
}))
userInfo, err := provider.GetUserInfo(ctx, "twitter-token")
if err != nil {
t.Fatalf("expected user info success, got error %v", err)
}
if userInfo.Data.ID != "user-1" || userInfo.Data.Username != "tw-user" {
t.Fatalf("unexpected twitter user info: %#v", userInfo)
}
})
t.Run("refresh token success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
form := parseRequestForm(t, req)
if form.Get("grant_type") != "refresh_token" || form.Get("refresh_token") != "twitter-refresh" {
t.Fatalf("unexpected refresh payload: %#v", form)
}
return oauthResponse(`{"access_token":"twitter-token-2","refresh_token":"twitter-refresh-2","token_type":"bearer"}`), nil
}))
tokenResp, err := provider.RefreshToken(ctx, "twitter-refresh")
if err != nil {
t.Fatalf("expected refresh success, got error %v", err)
}
if tokenResp.AccessToken != "twitter-token-2" {
t.Fatalf("unexpected twitter refresh response: %#v", tokenResp)
}
})
t.Run("validate token returns false when user id is empty", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
return oauthResponse(`{"data":{"id":"","username":"anonymous"}}`), nil
}))
valid, err := provider.ValidateToken(ctx, "twitter-token")
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if valid {
t.Fatal("expected twitter token to be reported invalid")
}
})
t.Run("revoke token success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "api.twitter.com" || req.URL.Path != "/2/oauth2/revoke" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
form := parseRequestForm(t, req)
if form.Get("token") != "twitter-token" || form.Get("token_type_hint") != "access_token" {
t.Fatalf("unexpected revoke payload: %#v", form)
}
return oauthResponse(`{}`), nil
}))
if err := provider.RevokeToken(ctx, "twitter-token"); err != nil {
t.Fatalf("expected revoke success, got error %v", err)
}
})
}
func TestWeChatProviderExchangeUserInfoAndRefreshToken(t *testing.T) {
ctx := context.Background()
provider := NewWeChatProvider("wx-app", "wx-secret", "web")
t.Run("exchange code rejects api error", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "api.weixin.qq.com" || req.URL.Path != "/sns/oauth2/access_token" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"errcode":40029,"errmsg":"invalid code"}`), nil
}))
_, err := provider.ExchangeCode(ctx, "auth-code")
if err == nil || !strings.Contains(err.Error(), "wechat api error: 40029 - invalid code") {
t.Fatalf("expected wechat api error, got %v", err)
}
})
t.Run("exchange code success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
return oauthResponse(`{"access_token":"wx-token","refresh_token":"wx-refresh","openid":"openid-1","scope":"snsapi_login"}`), nil
}))
tokenResp, err := provider.ExchangeCode(ctx, "auth-code")
if err != nil {
t.Fatalf("expected exchange success, got error %v", err)
}
if tokenResp.AccessToken != "wx-token" || tokenResp.OpenID != "openid-1" {
t.Fatalf("unexpected wechat token response: %#v", tokenResp)
}
})
t.Run("get user info rejects api error", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "api.weixin.qq.com" || req.URL.Path != "/sns/userinfo" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"errcode":40003,"errmsg":"invalid openid"}`), nil
}))
_, err := provider.GetUserInfo(ctx, "wx-token", "openid-1")
if err == nil || !strings.Contains(err.Error(), "wechat api error: 40003 - invalid openid") {
t.Fatalf("expected wechat user info error, got %v", err)
}
})
t.Run("get user info success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
return oauthResponse(`{"openid":"openid-1","nickname":"WeChat User","province":"Shanghai"}`), nil
}))
userInfo, err := provider.GetUserInfo(ctx, "wx-token", "openid-1")
if err != nil {
t.Fatalf("expected user info success, got error %v", err)
}
if userInfo.OpenID != "openid-1" || userInfo.Nickname != "WeChat User" {
t.Fatalf("unexpected wechat user info: %#v", userInfo)
}
})
t.Run("refresh token rejects api error", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "api.weixin.qq.com" || req.URL.Path != "/sns/oauth2/refresh_token" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"errcode":40030,"errmsg":"invalid refresh token"}`), nil
}))
_, err := provider.RefreshToken(ctx, "wx-refresh")
if err == nil || !strings.Contains(err.Error(), "wechat api error: 40030 - invalid refresh token") {
t.Fatalf("expected wechat refresh error, got %v", err)
}
})
t.Run("refresh token success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
return oauthResponse(`{"access_token":"wx-token-2","refresh_token":"wx-refresh-2","openid":"openid-1"}`), nil
}))
tokenResp, err := provider.RefreshToken(ctx, "wx-refresh")
if err != nil {
t.Fatalf("expected refresh success, got error %v", err)
}
if tokenResp.AccessToken != "wx-token-2" {
t.Fatalf("unexpected wechat refresh response: %#v", tokenResp)
}
})
}
func TestWeiboProviderExchangeCodeAndGetUserInfo(t *testing.T) {
ctx := context.Background()
provider := NewWeiboProvider("weibo-app", "weibo-secret", "https://example.com/callback")
t.Run("exchange code success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "api.weibo.com" || req.URL.Path != "/oauth2/access_token" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
form := parseRequestForm(t, req)
if form.Get("client_id") != "weibo-app" || form.Get("code") != "auth-code" {
t.Fatalf("unexpected exchange payload: %#v", form)
}
return oauthResponse(`{"access_token":"weibo-token","expires_in":3600,"uid":"1001"}`), nil
}))
tokenResp, err := provider.ExchangeCode(ctx, "auth-code")
if err != nil {
t.Fatalf("expected exchange success, got error %v", err)
}
if tokenResp.AccessToken != "weibo-token" || tokenResp.UID != "1001" {
t.Fatalf("unexpected weibo token response: %#v", tokenResp)
}
})
t.Run("get user info rejects api error", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "api.weibo.com" || req.URL.Path != "/2/users/show.json" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"error":1,"error_code":21315,"request":"/2/users/show.json"}`), nil
}))
_, err := provider.GetUserInfo(ctx, "weibo-token", "1001")
if err == nil || !strings.Contains(err.Error(), "weibo api error: code=21315") {
t.Fatalf("expected weibo api error, got %v", err)
}
})
t.Run("get user info success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
return oauthResponse(`{"id":1001,"idstr":"1001","screen_name":"weibo-user","name":"Weibo User"}`), nil
}))
userInfo, err := provider.GetUserInfo(ctx, "weibo-token", "1001")
if err != nil {
t.Fatalf("expected user info success, got error %v", err)
}
if userInfo.ID != 1001 || userInfo.ScreenName != "weibo-user" {
t.Fatalf("unexpected weibo user info: %#v", userInfo)
}
})
}
func TestFacebookProviderExchangeValidateAndLongLivedToken(t *testing.T) {
ctx := context.Background()
provider := NewFacebookProvider("facebook-app", "facebook-secret", "https://example.com/callback")
t.Run("exchange code success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "graph.facebook.com" || req.URL.Path != "/v18.0/oauth/access_token" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
if req.URL.Query().Get("code") != "auth-code" {
t.Fatalf("unexpected code: %s", req.URL.Query().Get("code"))
}
return oauthResponse(`{"access_token":"fb-token","token_type":"bearer","expires_in":3600}`), nil
}))
tokenResp, err := provider.ExchangeCode(ctx, "auth-code")
if err != nil {
t.Fatalf("expected exchange success, got error %v", err)
}
if tokenResp.AccessToken != "fb-token" {
t.Fatalf("unexpected facebook token response: %#v", tokenResp)
}
})
t.Run("validate token returns false for empty id", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Path != "/v18.0/me" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"id":"","name":"No ID User"}`), nil
}))
valid, err := provider.ValidateToken(ctx, "fb-token")
if err != nil {
t.Fatalf("expected validate success, got error %v", err)
}
if valid {
t.Fatal("expected facebook token to be reported invalid")
}
})
t.Run("get long lived token success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Path != "/v18.0/oauth/access_token" || req.URL.Query().Get("grant_type") != "fb_exchange_token" {
t.Fatalf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"access_token":"fb-long-lived","token_type":"bearer","expires_in":5184000}`), nil
}))
tokenResp, err := provider.GetLongLivedToken(ctx, "fb-token")
if err != nil {
t.Fatalf("expected long-lived token success, got error %v", err)
}
if tokenResp.AccessToken != "fb-long-lived" {
t.Fatalf("unexpected facebook long-lived token response: %#v", tokenResp)
}
})
}

View File

@@ -0,0 +1,284 @@
package providers
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"testing"
)
type roundTripFunc func(*http.Request) (*http.Response, error)
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return fn(req)
}
func useDefaultTransport(t *testing.T, fn roundTripFunc) {
t.Helper()
originalTransport := http.DefaultTransport
http.DefaultTransport = fn
t.Cleanup(func() {
http.DefaultTransport = originalTransport
})
}
func oauthResponse(body string) *http.Response {
return &http.Response{
StatusCode: http.StatusOK,
Body: io.NopCloser(strings.NewReader(body)),
Header: make(http.Header),
}
}
func TestQQProviderGetOpenIDAndUserInfoWithDefaultTransport(t *testing.T) {
ctx := context.Background()
provider := NewQQProvider("qq-app", "qq-secret", "https://example.com/callback")
t.Run("get openid success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "graph.qq.com" || req.URL.Path != "/oauth2.0/me" {
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"client_id":"qq-app","openid":"openid-123"}`), nil
}))
resp, err := provider.GetOpenID(ctx, "access-token")
if err != nil {
t.Fatalf("expected openid success, got error %v", err)
}
if resp.OpenID != "openid-123" || resp.ClientID != "qq-app" {
t.Fatalf("unexpected openid response: %#v", resp)
}
})
t.Run("get openid parse error", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "graph.qq.com" || req.URL.Path != "/oauth2.0/me" {
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`not-json`), nil
}))
_, err := provider.GetOpenID(ctx, "access-token")
if err == nil || !strings.Contains(err.Error(), "parse openid response failed") {
t.Fatalf("expected openid parse error, got %v", err)
}
})
t.Run("get user info api error", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "graph.qq.com" || req.URL.Path != "/user/get_user_info" {
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"ret":1001,"msg":"invalid token"}`), nil
}))
_, err := provider.GetUserInfo(ctx, "access-token", "openid-123")
if err == nil || !strings.Contains(err.Error(), "qq api error: invalid token") {
t.Fatalf("expected qq api error, got %v", err)
}
})
t.Run("get user info success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "graph.qq.com" || req.URL.Path != "/user/get_user_info" {
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"ret":0,"msg":"","nickname":"tester","gender":"male","city":"Shanghai"}`), nil
}))
info, err := provider.GetUserInfo(ctx, "access-token", "openid-123")
if err != nil {
t.Fatalf("expected user info success, got error %v", err)
}
if info.Nickname != "tester" || info.City != "Shanghai" {
t.Fatalf("unexpected user info response: %#v", info)
}
})
}
func TestWeiboProviderValidateTokenWithDefaultTransport(t *testing.T) {
ctx := context.Background()
provider := NewWeiboProvider("weibo-app", "weibo-secret", "https://example.com/callback")
tests := []struct {
name string
body string
wantValid bool
wantErrContains string
}{
{
name: "rejects error response",
body: `{"error":"invalid_token"}`,
wantValid: false,
},
{
name: "accepts expire_in response",
body: `{"expire_in":3600}`,
wantValid: true,
},
{
name: "rejects ambiguous response",
body: `{"uid":"123"}`,
wantValid: false,
},
{
name: "returns parse error",
body: `not-json`,
wantErrContains: "parse response failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "api.weibo.com" || req.URL.Path != "/oauth2/get_token_info" {
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
return oauthResponse(tt.body), nil
}))
valid, err := provider.ValidateToken(ctx, "access-token")
if tt.wantErrContains != "" {
if err == nil || !strings.Contains(err.Error(), tt.wantErrContains) {
t.Fatalf("expected error containing %q, got %v", tt.wantErrContains, err)
}
return
}
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if valid != tt.wantValid {
t.Fatalf("expected valid=%v, got %v", tt.wantValid, valid)
}
})
}
}
func TestWeChatProviderValidateTokenWithDefaultTransport(t *testing.T) {
ctx := context.Background()
provider := NewWeChatProvider("wx-app", "wx-secret", "web")
tests := []struct {
name string
body string
wantValid bool
wantErrContains string
}{
{
name: "accepts errcode zero",
body: `{"errcode":0,"errmsg":"ok"}`,
wantValid: true,
},
{
name: "rejects non-zero errcode",
body: `{"errcode":40003,"errmsg":"invalid openid"}`,
wantValid: false,
},
{
name: "returns parse error",
body: `not-json`,
wantErrContains: "parse response failed",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "api.weixin.qq.com" || req.URL.Path != "/sns/auth" {
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
return oauthResponse(tt.body), nil
}))
valid, err := provider.ValidateToken(ctx, "access-token", "openid-123")
if tt.wantErrContains != "" {
if err == nil || !strings.Contains(err.Error(), tt.wantErrContains) {
t.Fatalf("expected error containing %q, got %v", tt.wantErrContains, err)
}
return
}
if err != nil {
t.Fatalf("expected no error, got %v", err)
}
if valid != tt.wantValid {
t.Fatalf("expected valid=%v, got %v", tt.wantValid, valid)
}
})
}
}
func TestGoogleProviderValidateTokenWithDefaultTransport(t *testing.T) {
ctx := context.Background()
provider := NewGoogleProvider("google-client", "google-secret", "https://example.com/callback")
t.Run("validate token success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "www.googleapis.com" || req.URL.Path != "/oauth2/v2/userinfo" {
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"id":"user-1","email":"user@example.com","name":"Google User"}`), nil
}))
valid, err := provider.ValidateToken(ctx, "access-token")
if err != nil {
t.Fatalf("expected success, got error %v", err)
}
if !valid {
t.Fatal("expected token to be valid")
}
})
t.Run("validate token parse error", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "www.googleapis.com" || req.URL.Path != "/oauth2/v2/userinfo" {
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`not-json`), nil
}))
valid, err := provider.ValidateToken(ctx, "access-token")
if err == nil || !strings.Contains(err.Error(), "parse user info failed") {
t.Fatalf("expected user info parse error, got valid=%v err=%v", valid, err)
}
})
}
func TestFacebookProviderGetUserInfoWithDefaultTransport(t *testing.T) {
ctx := context.Background()
provider := NewFacebookProvider("facebook-app", "facebook-secret", "https://example.com/callback")
t.Run("facebook api error", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "graph.facebook.com" || req.URL.Path != "/v18.0/me" {
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"error":{"message":"token expired","type":"OAuthException","code":190}}`), nil
}))
_, err := provider.GetUserInfo(ctx, "access-token")
if err == nil || !strings.Contains(err.Error(), "facebook api error: token expired") {
t.Fatalf("expected facebook api error, got %v", err)
}
})
t.Run("facebook success", func(t *testing.T) {
useDefaultTransport(t, roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL.Host != "graph.facebook.com" || req.URL.Path != "/v18.0/me" {
return nil, fmt.Errorf("unexpected request: %s", req.URL.String())
}
return oauthResponse(`{"id":"user-1","name":"Facebook User","email":"fb@example.com","picture":{"data":{"url":"https://cdn.example.com/a.png"}}}`), nil
}))
info, err := provider.GetUserInfo(ctx, "access-token")
if err != nil {
t.Fatalf("expected user info success, got error %v", err)
}
if info.ID != "user-1" || info.Picture.Data.URL == "" {
t.Fatalf("unexpected facebook user info response: %#v", info)
}
})
}

View File

@@ -0,0 +1,191 @@
package providers
import (
"net/url"
"strings"
"testing"
)
func TestAdditionalProviderStateGeneratorsProduceDistinctTokens(t *testing.T) {
tests := []struct {
name string
generateState func() (string, error)
}{
{
name: "facebook",
generateState: func() (string, error) {
return NewFacebookProvider("app-id", "secret", "https://admin.example.com/callback").GenerateState()
},
},
{
name: "qq",
generateState: func() (string, error) {
return NewQQProvider("app-id", "secret", "https://admin.example.com/callback").GenerateState()
},
},
{
name: "weibo",
generateState: func() (string, error) {
return NewWeiboProvider("app-id", "secret", "https://admin.example.com/callback").GenerateState()
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
stateA, err := tc.generateState()
if err != nil {
t.Fatalf("GenerateState(first) failed: %v", err)
}
stateB, err := tc.generateState()
if err != nil {
t.Fatalf("GenerateState(second) failed: %v", err)
}
if stateA == "" || stateB == "" {
t.Fatal("expected non-empty generated states")
}
if stateA == stateB {
t.Fatal("expected generated states to differ between calls")
}
})
}
}
func TestAdditionalProviderAuthURLs(t *testing.T) {
tests := []struct {
name string
buildURL func(t *testing.T) (string, string)
expectedHost string
expectedPath string
expectedKey string
expectedValue string
expectedClause string
}{
{
name: "facebook",
buildURL: func(t *testing.T) (string, string) {
t.Helper()
redirectURI := "https://admin.example.com/login/oauth/callback?from=fb"
authURL, err := NewFacebookProvider("fb-app-id", "fb-secret", redirectURI).GetAuthURL("fb-state")
if err != nil {
t.Fatalf("GetAuthURL failed: %v", err)
}
return authURL.URL, redirectURI
},
expectedHost: "www.facebook.com",
expectedPath: "/v18.0/dialog/oauth",
expectedKey: "client_id",
expectedValue: "fb-app-id",
expectedClause: "scope=email,public_profile",
},
{
name: "qq",
buildURL: func(t *testing.T) (string, string) {
t.Helper()
redirectURI := "https://admin.example.com/login/oauth/callback?from=qq"
authURL, err := NewQQProvider("qq-app-id", "qq-secret", redirectURI).GetAuthURL("qq-state")
if err != nil {
t.Fatalf("GetAuthURL failed: %v", err)
}
return authURL.URL, redirectURI
},
expectedHost: "graph.qq.com",
expectedPath: "/oauth2.0/authorize",
expectedKey: "client_id",
expectedValue: "qq-app-id",
expectedClause: "scope=get_user_info",
},
{
name: "weibo",
buildURL: func(t *testing.T) (string, string) {
t.Helper()
redirectURI := "https://admin.example.com/login/oauth/callback?from=weibo"
authURL, err := NewWeiboProvider("wb-app-id", "wb-secret", redirectURI).GetAuthURL("wb-state")
if err != nil {
t.Fatalf("GetAuthURL failed: %v", err)
}
return authURL.URL, redirectURI
},
expectedHost: "api.weibo.com",
expectedPath: "/oauth2/authorize",
expectedKey: "client_id",
expectedValue: "wb-app-id",
expectedClause: "response_type=code",
},
{
name: "douyin",
buildURL: func(t *testing.T) (string, string) {
t.Helper()
redirectURI := "https://admin.example.com/login/oauth/callback?from=douyin"
authURL, err := NewDouyinProvider("dy-client", "dy-secret", redirectURI).GetAuthURL("dy-state")
if err != nil {
t.Fatalf("GetAuthURL failed: %v", err)
}
return authURL, redirectURI
},
expectedHost: "open.douyin.com",
expectedPath: "/platform/oauth/connect",
expectedKey: "client_key",
expectedValue: "dy-client",
expectedClause: "scope=user_info",
},
{
name: "alipay",
buildURL: func(t *testing.T) (string, string) {
t.Helper()
redirectURI := "https://admin.example.com/login/oauth/callback?from=alipay"
authURL, err := NewAlipayProvider("ali-app-id", "private-key", redirectURI, false).GetAuthURL("ali-state")
if err != nil {
t.Fatalf("GetAuthURL failed: %v", err)
}
return authURL, redirectURI
},
expectedHost: "openauth.alipay.com",
expectedPath: "/oauth2/publicAppAuthorize.htm",
expectedKey: "app_id",
expectedValue: "ali-app-id",
expectedClause: "scope=auth_user",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
authURL, redirectURI := tc.buildURL(t)
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("parse auth url failed: %v", err)
}
if parsed.Host != tc.expectedHost {
t.Fatalf("expected host %q, got %q", tc.expectedHost, parsed.Host)
}
if parsed.Path != tc.expectedPath {
t.Fatalf("expected path %q, got %q", tc.expectedPath, parsed.Path)
}
query := parsed.Query()
if query.Get(tc.expectedKey) != tc.expectedValue {
t.Fatalf("expected %s=%q, got %q", tc.expectedKey, tc.expectedValue, query.Get(tc.expectedKey))
}
if query.Get("redirect_uri") != redirectURI {
t.Fatalf("expected redirect_uri %q, got %q", redirectURI, query.Get("redirect_uri"))
}
if !strings.Contains(authURL, tc.expectedClause) {
t.Fatalf("expected auth url to contain %q, got %q", tc.expectedClause, authURL)
}
})
}
}
func TestAlipayProviderUsesExpectedGatewayForSandboxAndProduction(t *testing.T) {
productionProvider := NewAlipayProvider("prod-app-id", "private-key", "https://admin.example.com/callback", false)
if gateway := productionProvider.getGateway(); gateway != "https://openapi.alipay.com/gateway.do" {
t.Fatalf("expected production gateway, got %q", gateway)
}
sandboxProvider := NewAlipayProvider("sandbox-app-id", "private-key", "https://admin.example.com/callback", true)
if gateway := sandboxProvider.getGateway(); gateway != "https://openapi-sandbox.dl.alipaydev.com/gateway.do" {
t.Fatalf("expected sandbox gateway, got %q", gateway)
}
}

View File

@@ -0,0 +1,124 @@
package providers
import (
"net/url"
"strings"
"testing"
)
func TestGitHubProviderGetAuthURLEscapesRedirectAndState(t *testing.T) {
provider := NewGitHubProvider("client-id", "client-secret", "https://admin.example.com/login/oauth/callback")
authURL, err := provider.GetAuthURL("state value")
if err != nil {
t.Fatalf("GetAuthURL failed: %v", err)
}
parsed, err := url.Parse(authURL)
if err != nil {
t.Fatalf("parse auth url failed: %v", err)
}
query := parsed.Query()
if query.Get("client_id") != "client-id" {
t.Fatalf("expected client_id to be propagated, got %q", query.Get("client_id"))
}
if query.Get("redirect_uri") != "https://admin.example.com/login/oauth/callback" {
t.Fatalf("expected redirect_uri to be propagated, got %q", query.Get("redirect_uri"))
}
if query.Get("state") != "state value" {
t.Fatalf("expected state to be propagated, got %q", query.Get("state"))
}
if !strings.Contains(query.Get("scope"), "read:user") {
t.Fatalf("expected GitHub scope to include read:user, got %q", query.Get("scope"))
}
}
func TestGoogleProviderGenerateStateAndBuildAuthURL(t *testing.T) {
provider := NewGoogleProvider("google-client", "google-secret", "https://admin.example.com/login/oauth/callback")
stateA, err := provider.GenerateState()
if err != nil {
t.Fatalf("GenerateState failed: %v", err)
}
stateB, err := provider.GenerateState()
if err != nil {
t.Fatalf("GenerateState failed: %v", err)
}
if stateA == "" || stateB == "" {
t.Fatal("expected non-empty generated states")
}
if stateA == stateB {
t.Fatal("expected generated states to be unique across calls")
}
authURL, err := provider.GetAuthURL("redirect-state")
if err != nil {
t.Fatalf("GetAuthURL failed: %v", err)
}
if authURL.State != "redirect-state" {
t.Fatalf("expected auth url state to be preserved, got %q", authURL.State)
}
if authURL.Redirect != provider.RedirectURI {
t.Fatalf("expected redirect uri to be preserved, got %q", authURL.Redirect)
}
if !strings.Contains(authURL.URL, "response_type=code") {
t.Fatalf("expected google auth url to request authorization code flow, got %q", authURL.URL)
}
}
func TestWeChatProviderGetAuthURLSupportsKnownTypes(t *testing.T) {
tests := []struct {
name string
oauthType string
expectedHost string
expectedPath string
}{
{
name: "web login",
oauthType: "web",
expectedHost: "open.weixin.qq.com",
expectedPath: "/connect/qrconnect",
},
{
name: "public account login",
oauthType: "mp",
expectedHost: "open.weixin.qq.com",
expectedPath: "/connect/oauth2/authorize",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
provider := NewWeChatProvider("wx-app-id", "wx-app-secret", tc.oauthType)
authURL, err := provider.GetAuthURL("https://admin.example.com/login/oauth/callback", "wechat-state")
if err != nil {
t.Fatalf("GetAuthURL failed: %v", err)
}
parsed, err := url.Parse(authURL.URL)
if err != nil {
t.Fatalf("parse auth url failed: %v", err)
}
if parsed.Host != tc.expectedHost {
t.Fatalf("expected host %q, got %q", tc.expectedHost, parsed.Host)
}
if parsed.Path != tc.expectedPath {
t.Fatalf("expected path %q, got %q", tc.expectedPath, parsed.Path)
}
if authURL.State != "wechat-state" {
t.Fatalf("expected state to be preserved, got %q", authURL.State)
}
})
}
}
func TestWeChatProviderRejectsUnsupportedOAuthType(t *testing.T) {
provider := NewWeChatProvider("wx-app-id", "wx-app-secret", "mini")
if _, err := provider.GetAuthURL("https://admin.example.com/login/oauth/callback", "state"); err == nil {
t.Fatal("expected unsupported oauth type error")
}
}

View File

@@ -0,0 +1,202 @@
package providers
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"
)
// QQProvider QQ OAuth提供者
type QQProvider struct {
AppID string
AppKey string
RedirectURI string
}
// QQAuthURLResponse QQ授权URL响应
type QQAuthURLResponse struct {
URL string `json:"url"`
State string `json:"state"`
Redirect string `json:"redirect,omitempty"`
}
// QQTokenResponse QQ Token响应
type QQTokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
}
// QQOpenIDResponse QQ OpenID响应
type QQOpenIDResponse struct {
ClientID string `json:"client_id"`
OpenID string `json:"openid"`
}
// QQUserInfo QQ用户信息
type QQUserInfo struct {
Ret int `json:"ret"`
Msg string `json:"msg"`
Nickname string `json:"nickname"`
Gender string `json:"gender"` // 男, 女
Province string `json:"province"`
City string `json:"city"`
Year string `json:"year"`
FigureURL string `json:"figureurl"`
FigureURL1 string `json:"figureurl_1"`
FigureURL2 string `json:"figureurl_2"`
}
// NewQQProvider 创建QQ OAuth提供者
func NewQQProvider(appID, appKey, redirectURI string) *QQProvider {
return &QQProvider{
AppID: appID,
AppKey: appKey,
RedirectURI: redirectURI,
}
}
// GenerateState 生成随机状态码
func (q *QQProvider) GenerateState() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
// GetAuthURL 获取QQ授权URL
func (q *QQProvider) GetAuthURL(state string) (*QQAuthURLResponse, error) {
authURL := fmt.Sprintf(
"https://graph.qq.com/oauth2.0/authorize?response_type=code&client_id=%s&redirect_uri=%s&scope=get_user_info&state=%s",
q.AppID,
url.QueryEscape(q.RedirectURI),
state,
)
return &QQAuthURLResponse{
URL: authURL,
State: state,
Redirect: q.RedirectURI,
}, nil
}
// ExchangeCode 用授权码换取访问令牌
func (q *QQProvider) ExchangeCode(ctx context.Context, code string) (*QQTokenResponse, error) {
tokenURL := fmt.Sprintf(
"https://graph.qq.com/oauth2.0/token?grant_type=authorization_code&client_id=%s&client_secret=%s&code=%s&redirect_uri=%s&fmt=json",
q.AppID,
q.AppKey,
code,
url.QueryEscape(q.RedirectURI),
)
req, err := http.NewRequestWithContext(ctx, "GET", tokenURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var tokenResp QQTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parse token response failed: %w", err)
}
return &tokenResp, nil
}
// GetOpenID 用访问令牌获取OpenID
func (q *QQProvider) GetOpenID(ctx context.Context, accessToken string) (*QQOpenIDResponse, error) {
openIDURL := fmt.Sprintf(
"https://graph.qq.com/oauth2.0/me?access_token=%s&fmt=json",
accessToken,
)
req, err := http.NewRequestWithContext(ctx, "GET", openIDURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var openIDResp QQOpenIDResponse
if err := json.Unmarshal(body, &openIDResp); err != nil {
return nil, fmt.Errorf("parse openid response failed: %w", err)
}
return &openIDResp, nil
}
// GetUserInfo 获取QQ用户信息
func (q *QQProvider) GetUserInfo(ctx context.Context, accessToken, openID string) (*QQUserInfo, error) {
userInfoURL := fmt.Sprintf(
"https://graph.qq.com/user/get_user_info?access_token=%s&oauth_consumer_key=%s&openid=%s&format=json",
accessToken,
q.AppID,
openID,
)
req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var userInfo QQUserInfo
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("parse user info failed: %w", err)
}
if userInfo.Ret != 0 {
return nil, fmt.Errorf("qq api error: %s", userInfo.Msg)
}
return &userInfo, nil
}
// ValidateToken 验证访问令牌是否有效
func (q *QQProvider) ValidateToken(ctx context.Context, accessToken string) (bool, error) {
_, err := q.GetOpenID(ctx, accessToken)
if err != nil {
return false, err
}
return true, nil
}

View File

@@ -0,0 +1,264 @@
package providers
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"
)
// TwitterProvider Twitter OAuth提供者 (OAuth 2.0 with PKCE)
type TwitterProvider struct {
ClientID string
RedirectURI string
}
// TwitterAuthURLResponse Twitter授权URL响应
type TwitterAuthURLResponse struct {
URL string `json:"url"`
CodeVerifier string `json:"code_verifier"`
State string `json:"state"`
Redirect string `json:"redirect,omitempty"`
}
// TwitterTokenResponse Twitter Token响应
type TwitterTokenResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
Scope string `json:"scope"`
}
// TwitterUserInfo Twitter用户信息
type TwitterUserInfo struct {
Data struct {
ID string `json:"id"`
Name string `json:"name"`
Username string `json:"username"`
CreatedAt string `json:"created_at"`
Description string `json:"description"`
PublicMetrics struct {
FollowersCount int `json:"followers_count"`
FollowingCount int `json:"following_count"`
TweetCount int `json:"tweet_count"`
ListedCount int `json:"listed_count"`
} `json:"public_metrics"`
ProfileImageURL string `json:"profile_image_url"`
} `json:"data"`
}
// TwitterErrorResponse Twitter错误响应
type TwitterErrorResponse struct {
Title string `json:"title"`
Detail string `json:"detail"`
Type string `json:"type"`
Status int `json:"status"`
}
// NewTwitterProvider 创建Twitter OAuth提供者
func NewTwitterProvider(clientID, redirectURI string) *TwitterProvider {
return &TwitterProvider{
ClientID: clientID,
RedirectURI: redirectURI,
}
}
// GenerateCodeVerifier 生成PKCE Code Verifier
func (t *TwitterProvider) GenerateCodeVerifier() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(b), nil
}
// GenerateCodeChallenge 从Code Verifier生成Code Challenge
func (t *TwitterProvider) GenerateCodeChallenge(verifier string) string {
// 简化的base64编码实际应用中应该使用SHA256哈希
return verifier
}
// GenerateState 生成随机状态码
func (t *TwitterProvider) GenerateState() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
// GetAuthURL 获取Twitter授权URL (OAuth 2.0 with PKCE)
func (t *TwitterProvider) GetAuthURL() (*TwitterAuthURLResponse, error) {
verifier, err := t.GenerateCodeVerifier()
if err != nil {
return nil, fmt.Errorf("generate code verifier failed: %w", err)
}
challenge := t.GenerateCodeChallenge(verifier)
state, err := t.GenerateState()
if err != nil {
return nil, fmt.Errorf("generate state failed: %w", err)
}
authURL := fmt.Sprintf(
"https://twitter.com/i/oauth2/authorize?response_type=code&client_id=%s&redirect_uri=%s&scope=tweet.read%%20users.read%%20offline.access&state=%s&code_challenge=%s&code_challenge_method=plain",
t.ClientID,
url.QueryEscape(t.RedirectURI),
state,
challenge,
)
return &TwitterAuthURLResponse{
URL: authURL,
CodeVerifier: verifier,
State: state,
Redirect: t.RedirectURI,
}, nil
}
// ExchangeCode 用授权码换取访问令牌
func (t *TwitterProvider) ExchangeCode(ctx context.Context, code, codeVerifier string) (*TwitterTokenResponse, error) {
tokenURL := "https://api.twitter.com/2/oauth2/token"
data := url.Values{}
data.Set("code", code)
data.Set("grant_type", "authorization_code")
data.Set("client_id", t.ClientID)
data.Set("redirect_uri", t.RedirectURI)
data.Set("code_verifier", codeVerifier)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := postFormWithContext(ctx, client, tokenURL, data)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
// 检查错误响应
var errResp TwitterErrorResponse
if err := json.Unmarshal(body, &errResp); err == nil && errResp.Detail != "" {
return nil, fmt.Errorf("twitter api error: %s - %s", errResp.Title, errResp.Detail)
}
var tokenResp TwitterTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parse token response failed: %w", err)
}
return &tokenResp, nil
}
// GetUserInfo 获取Twitter用户信息
func (t *TwitterProvider) GetUserInfo(ctx context.Context, accessToken string) (*TwitterUserInfo, error) {
userInfoURL := "https://api.twitter.com/2/users/me?user.fields=created_at,description,public_metrics,profile_image_url"
req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
// 检查错误响应
var errResp TwitterErrorResponse
if err := json.Unmarshal(body, &errResp); err == nil && errResp.Detail != "" {
return nil, fmt.Errorf("twitter api error: %s - %s", errResp.Title, errResp.Detail)
}
var userInfo TwitterUserInfo
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("parse user info failed: %w", err)
}
return &userInfo, nil
}
// RefreshToken 刷新访问令牌
func (t *TwitterProvider) RefreshToken(ctx context.Context, refreshToken string) (*TwitterTokenResponse, error) {
tokenURL := "https://api.twitter.com/2/oauth2/token"
data := url.Values{}
data.Set("refresh_token", refreshToken)
data.Set("grant_type", "refresh_token")
data.Set("client_id", t.ClientID)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := postFormWithContext(ctx, client, tokenURL, data)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var errResp TwitterErrorResponse
if err := json.Unmarshal(body, &errResp); err == nil && errResp.Detail != "" {
return nil, fmt.Errorf("twitter api error: %s - %s", errResp.Title, errResp.Detail)
}
var tokenResp TwitterTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parse token response failed: %w", err)
}
return &tokenResp, nil
}
// ValidateToken 验证访问令牌是否有效
func (t *TwitterProvider) ValidateToken(ctx context.Context, accessToken string) (bool, error) {
userInfo, err := t.GetUserInfo(ctx, accessToken)
if err != nil {
return false, err
}
return userInfo != nil && userInfo.Data.ID != "", nil
}
// RevokeToken 撤销访问令牌
func (t *TwitterProvider) RevokeToken(ctx context.Context, accessToken string) error {
revokeURL := "https://api.twitter.com/2/oauth2/revoke"
data := url.Values{}
data.Set("token", accessToken)
data.Set("client_id", t.ClientID)
data.Set("token_type_hint", "access_token")
client := &http.Client{Timeout: 10 * time.Second}
resp, err := postFormWithContext(ctx, client, revokeURL, data)
if err != nil {
return fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
if _, err := readOAuthResponseBody(resp); err != nil {
return fmt.Errorf("revoke token failed: %w", err)
}
return nil
}

View File

@@ -0,0 +1,258 @@
package providers
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"
)
// WeChatProvider 微信OAuth提供者
type WeChatProvider struct {
AppID string
AppSecret string
Type string // "web" for 扫码登录, "mp" for 公众号, "mini" for 小程序
}
// WeChatAuthURLResponse 获取授权URL响应
type WeChatAuthURLResponse struct {
URL string `json:"url"`
State string `json:"state"`
Redirect string `json:"redirect,omitempty"`
}
// WeChatTokenResponse 微信Token响应
type WeChatTokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token"`
OpenID string `json:"openid"`
Scope string `json:"scope"`
UnionID string `json:"unionid,omitempty"`
}
// WeChatUserInfo 微信用户信息
type WeChatUserInfo struct {
OpenID string `json:"openid"`
Nickname string `json:"nickname"`
Sex int `json:"sex"` // 1男性, 2女性, 0未知
Province string `json:"province"`
City string `json:"city"`
Country string `json:"country"`
HeadImgURL string `json:"headimgurl"`
UnionID string `json:"unionid,omitempty"`
}
// WeChatErrorCode 微信错误码
type WeChatErrorCode struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
// NewWeChatProvider 创建微信OAuth提供者
func NewWeChatProvider(appID, appSecret, oAuthType string) *WeChatProvider {
return &WeChatProvider{
AppID: appID,
AppSecret: appSecret,
Type: oAuthType,
}
}
// GenerateState 生成随机状态码
func (w *WeChatProvider) GenerateState() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
// GetAuthURL 获取微信授权URL
func (w *WeChatProvider) GetAuthURL(redirectURI, state string) (*WeChatAuthURLResponse, error) {
var authURL string
switch w.Type {
case "web":
// 微信扫码登录 (开放平台)
authURL = fmt.Sprintf(
"https://open.weixin.qq.com/connect/qrconnect?appid=%s&redirect_uri=%s&response_type=code&scope=snsapi_login&state=%s#wechat_redirect",
w.AppID,
url.QueryEscape(redirectURI),
state,
)
case "mp":
// 微信公众号登录
authURL = fmt.Sprintf(
"https://open.weixin.qq.com/connect/oauth2/authorize?appid=%s&redirect_uri=%s&response_type=code&scope=snsapi_userinfo&state=%s#wechat_redirect",
w.AppID,
url.QueryEscape(redirectURI),
state,
)
default:
return nil, fmt.Errorf("unsupported wechat oauth type: %s", w.Type)
}
return &WeChatAuthURLResponse{
URL: authURL,
State: state,
Redirect: redirectURI,
}, nil
}
// ExchangeCode 用授权码换取访问令牌
func (w *WeChatProvider) ExchangeCode(ctx context.Context, code string) (*WeChatTokenResponse, error) {
tokenURL := fmt.Sprintf(
"https://api.weixin.qq.com/sns/oauth2/access_token?appid=%s&secret=%s&code=%s&grant_type=authorization_code",
w.AppID,
w.AppSecret,
code,
)
req, err := http.NewRequestWithContext(ctx, "GET", tokenURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
// 检查是否返回错误
var errResp WeChatErrorCode
if err := json.Unmarshal(body, &errResp); err == nil && errResp.ErrCode != 0 {
return nil, fmt.Errorf("wechat api error: %d - %s", errResp.ErrCode, errResp.ErrMsg)
}
var tokenResp WeChatTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parse token response failed: %w", err)
}
return &tokenResp, nil
}
// GetUserInfo 获取微信用户信息
func (w *WeChatProvider) GetUserInfo(ctx context.Context, accessToken, openID string) (*WeChatUserInfo, error) {
userInfoURL := fmt.Sprintf(
"https://api.weixin.qq.com/sns/userinfo?access_token=%s&openid=%s&lang=zh_CN",
accessToken,
openID,
)
req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
// 检查是否返回错误
var errResp WeChatErrorCode
if err := json.Unmarshal(body, &errResp); err == nil && errResp.ErrCode != 0 {
return nil, fmt.Errorf("wechat api error: %d - %s", errResp.ErrCode, errResp.ErrMsg)
}
var userInfo WeChatUserInfo
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("parse user info failed: %w", err)
}
return &userInfo, nil
}
// RefreshToken 刷新访问令牌
func (w *WeChatProvider) RefreshToken(ctx context.Context, refreshToken string) (*WeChatTokenResponse, error) {
refreshURL := fmt.Sprintf(
"https://api.weixin.qq.com/sns/oauth2/refresh_token?appid=%s&grant_type=refresh_token&refresh_token=%s",
w.AppID,
refreshToken,
)
req, err := http.NewRequestWithContext(ctx, "GET", refreshURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var errResp WeChatErrorCode
if err := json.Unmarshal(body, &errResp); err == nil && errResp.ErrCode != 0 {
return nil, fmt.Errorf("wechat api error: %d - %s", errResp.ErrCode, errResp.ErrMsg)
}
var tokenResp WeChatTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parse token response failed: %w", err)
}
return &tokenResp, nil
}
// ValidateToken 验证访问令牌是否有效
func (w *WeChatProvider) ValidateToken(ctx context.Context, accessToken, openID string) (bool, error) {
validateURL := fmt.Sprintf(
"https://api.weixin.qq.com/sns/auth?access_token=%s&openid=%s",
accessToken,
openID,
)
req, err := http.NewRequestWithContext(ctx, "GET", validateURL, nil)
if err != nil {
return false, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return false, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return false, fmt.Errorf("read response failed: %w", err)
}
var result struct {
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg"`
}
if err := json.Unmarshal(body, &result); err != nil {
return false, fmt.Errorf("parse response failed: %w", err)
}
return result.ErrCode == 0, nil
}

View File

@@ -0,0 +1,201 @@
package providers
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"net/url"
"time"
)
// WeiboProvider 微博OAuth提供者
type WeiboProvider struct {
AppKey string
AppSecret string
RedirectURI string
}
// WeiboAuthURLResponse 微博授权URL响应
type WeiboAuthURLResponse struct {
URL string `json:"url"`
State string `json:"state"`
Redirect string `json:"redirect,omitempty"`
}
// WeiboTokenResponse 微博Token响应
type WeiboTokenResponse struct {
AccessToken string `json:"access_token"`
ExpiresIn int `json:"expires_in"`
RemindIn string `json:"remind_in"`
UID string `json:"uid"`
}
// WeiboUserInfo 微博用户信息
type WeiboUserInfo struct {
ID int64 `json:"id"`
IDStr string `json:"idstr"`
ScreenName string `json:"screen_name"`
Name string `json:"name"`
Province string `json:"province"`
City string `json:"city"`
Location string `json:"location"`
Description string `json:"description"`
URL string `json:"url"`
ProfileImageURL string `json:"profile_image_url"`
Gender string `json:"gender"` // m:男, f:女, n:未知
FollowersCount int `json:"followers_count"`
FriendsCount int `json:"friends_count"`
StatusesCount int `json:"statuses_count"`
}
// NewWeiboProvider 创建微博OAuth提供者
func NewWeiboProvider(appKey, appSecret, redirectURI string) *WeiboProvider {
return &WeiboProvider{
AppKey: appKey,
AppSecret: appSecret,
RedirectURI: redirectURI,
}
}
// GenerateState 生成随机状态码
func (w *WeiboProvider) GenerateState() (string, error) {
b := make([]byte, 32)
_, err := rand.Read(b)
if err != nil {
return "", err
}
return base64.URLEncoding.EncodeToString(b), nil
}
// GetAuthURL 获取微博授权URL
func (w *WeiboProvider) GetAuthURL(state string) (*WeiboAuthURLResponse, error) {
authURL := fmt.Sprintf(
"https://api.weibo.com/oauth2/authorize?client_id=%s&redirect_uri=%s&response_type=code&state=%s",
w.AppKey,
url.QueryEscape(w.RedirectURI),
state,
)
return &WeiboAuthURLResponse{
URL: authURL,
State: state,
Redirect: w.RedirectURI,
}, nil
}
// ExchangeCode 用授权码换取访问令牌
func (w *WeiboProvider) ExchangeCode(ctx context.Context, code string) (*WeiboTokenResponse, error) {
tokenURL := "https://api.weibo.com/oauth2/access_token"
data := url.Values{}
data.Set("client_id", w.AppKey)
data.Set("client_secret", w.AppSecret)
data.Set("grant_type", "authorization_code")
data.Set("code", code)
data.Set("redirect_uri", w.RedirectURI)
client := &http.Client{Timeout: 10 * time.Second}
resp, err := postFormWithContext(ctx, client, tokenURL, data)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
var tokenResp WeiboTokenResponse
if err := json.Unmarshal(body, &tokenResp); err != nil {
return nil, fmt.Errorf("parse token response failed: %w", err)
}
return &tokenResp, nil
}
// GetUserInfo 获取微博用户信息
func (w *WeiboProvider) GetUserInfo(ctx context.Context, accessToken, uid string) (*WeiboUserInfo, error) {
userInfoURL := fmt.Sprintf(
"https://api.weibo.com/2/users/show.json?access_token=%s&uid=%s",
accessToken,
uid,
)
req, err := http.NewRequestWithContext(ctx, "GET", userInfoURL, nil)
if err != nil {
return nil, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return nil, fmt.Errorf("read response failed: %w", err)
}
// 微博错误响应
var errResp struct {
Error int `json:"error"`
ErrorCode int `json:"error_code"`
Request string `json:"request"`
}
if err := json.Unmarshal(body, &errResp); err == nil && errResp.Error != 0 {
return nil, fmt.Errorf("weibo api error: code=%d", errResp.ErrorCode)
}
var userInfo WeiboUserInfo
if err := json.Unmarshal(body, &userInfo); err != nil {
return nil, fmt.Errorf("parse user info failed: %w", err)
}
return &userInfo, nil
}
// ValidateToken 验证访问令牌是否有效
func (w *WeiboProvider) ValidateToken(ctx context.Context, accessToken string) (bool, error) {
// 微博没有专门的token验证接口通过获取API token信息来验证
tokenInfoURL := fmt.Sprintf("https://api.weibo.com/oauth2/get_token_info?access_token=%s", accessToken)
req, err := http.NewRequestWithContext(ctx, "GET", tokenInfoURL, nil)
if err != nil {
return false, fmt.Errorf("create request failed: %w", err)
}
client := &http.Client{Timeout: 10 * time.Second}
resp, err := client.Do(req)
if err != nil {
return false, fmt.Errorf("request failed: %w", err)
}
defer resp.Body.Close()
body, err := readOAuthResponseBody(resp)
if err != nil {
return false, fmt.Errorf("read response failed: %w", err)
}
var result map[string]interface{}
if err := json.Unmarshal(body, &result); err != nil {
return false, fmt.Errorf("parse response failed: %w", err)
}
// 如果返回了错误说明token无效
if _, ok := result["error"]; ok {
return false, nil
}
// 如果有expire_in字段说明token有效
if _, ok := result["expire_in"]; ok {
return true, nil
}
return false, nil
}