257 lines
6.5 KiB
Go
257 lines
6.5 KiB
Go
package providers
|
||
|
||
import (
|
||
"context"
|
||
"crypto"
|
||
"crypto/rand"
|
||
"crypto/rsa"
|
||
"crypto/sha256"
|
||
"crypto/x509"
|
||
"encoding/base64"
|
||
"encoding/json"
|
||
"encoding/pem"
|
||
"fmt"
|
||
"net/http"
|
||
"net/url"
|
||
"sort"
|
||
"strings"
|
||
"time"
|
||
)
|
||
|
||
// AlipayProvider is the Alipay OAuth login provider.
//
// Gateway requests are signed with RSA2 (SHA256withRSA) using PrivateKey.
type AlipayProvider struct {
	// AppID is the Alipay application ID, sent as the app_id parameter.
	AppID string

	// PrivateKey is the application's RSA2 signing key. It may be a PEM
	// document (PKCS#8 or PKCS#1) or the bare base64 key body.
	PrivateKey string

	// RedirectURI is the OAuth callback URL, sent as redirect_uri.
	RedirectURI string

	// IsSandbox routes requests to Alipay's sandbox environment instead of
	// production.
	IsSandbox bool
}
|
||
|
||
// AlipayTokenResponse is the decoded "alipay_system_oauth_token_response"
// node returned by the alipay.system.oauth.token gateway method.
type AlipayTokenResponse struct {
	// UserID is the Alipay user ID (user_id).
	UserID string `json:"user_id"`
	// AccessToken is the token passed as auth_token to later gateway calls.
	AccessToken string `json:"access_token"`
	// ExpiresIn is the access token lifetime reported by Alipay; presumably
	// seconds (NOTE(review): confirm against Alipay docs).
	ExpiresIn int `json:"expires_in"`
	// RefreshToken is the refresh_token value returned with the access token.
	RefreshToken string `json:"refresh_token"`
	// OpenID is the user's OpenID (open_id). NOTE(review): apps on Alipay's
	// OpenID scheme receive this instead of user_id, so UserID may be empty for
	// them — confirm against the app's configuration.
	OpenID string `json:"open_id,omitempty"`
}
|
||
|
||
// AlipayUserInfo is the decoded "alipay_user_info_share_response" node
// returned by the alipay.user.info.share gateway method.
type AlipayUserInfo struct {
	// UserID is the Alipay user ID (user_id).
	UserID string `json:"user_id"`
	// Nickname is the user's display name (nick_name).
	Nickname string `json:"nick_name"`
	// Avatar is the avatar value (avatar); presumably an image URL.
	Avatar string `json:"avatar"`
	// Gender is Alipay's gender code (gender). NOTE(review): presumably
	// "F"/"M" — confirm.
	Gender string `json:"gender"`
	// OpenID is the user's OpenID (open_id). NOTE(review): apps on Alipay's
	// OpenID scheme receive this instead of user_id — confirm.
	OpenID string `json:"open_id,omitempty"`
}
|
||
|
||
// NewAlipayProvider 创建支付宝 OAuth提供者
|
||
func NewAlipayProvider(appID, privateKey, redirectURI string, isSandbox bool) *AlipayProvider {
|
||
return &AlipayProvider{
|
||
AppID: appID,
|
||
PrivateKey: privateKey,
|
||
RedirectURI: redirectURI,
|
||
IsSandbox: isSandbox,
|
||
}
|
||
}
|
||
|
||
func (a *AlipayProvider) getGateway() string {
|
||
if a.IsSandbox {
|
||
return "https://openapi-sandbox.dl.alipaydev.com/gateway.do"
|
||
}
|
||
return "https://openapi.alipay.com/gateway.do"
|
||
}
|
||
|
||
// GetAuthURL 获取支付宝授权URL
|
||
func (a *AlipayProvider) GetAuthURL(state string) (string, error) {
|
||
authURL := fmt.Sprintf(
|
||
"https://openauth.alipay.com/oauth2/publicAppAuthorize.htm?app_id=%s&scope=auth_user&redirect_uri=%s&state=%s",
|
||
a.AppID,
|
||
url.QueryEscape(a.RedirectURI),
|
||
url.QueryEscape(state),
|
||
)
|
||
return authURL, nil
|
||
}
|
||
|
||
// ExchangeCode 用授权码换取 access_token
|
||
func (a *AlipayProvider) ExchangeCode(ctx context.Context, code string) (*AlipayTokenResponse, error) {
|
||
params := map[string]string{
|
||
"app_id": a.AppID,
|
||
"method": "alipay.system.oauth.token",
|
||
"charset": "UTF-8",
|
||
"sign_type": "RSA2",
|
||
"timestamp": time.Now().Format("2006-01-02 15:04:05"),
|
||
"version": "1.0",
|
||
"grant_type": "authorization_code",
|
||
"code": code,
|
||
}
|
||
|
||
if a.PrivateKey != "" {
|
||
sign, err := a.signParams(params)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("sign failed: %w", err)
|
||
}
|
||
params["sign"] = sign
|
||
}
|
||
|
||
form := url.Values{}
|
||
for k, v := range params {
|
||
form.Set(k, v)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", a.getGateway(),
|
||
strings.NewReader(form.Encode()))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("create request failed: %w", err)
|
||
}
|
||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||
|
||
client := &http.Client{Timeout: 10 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("request failed: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err := readOAuthResponseBody(resp)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("read response failed: %w", err)
|
||
}
|
||
|
||
var rawResp map[string]json.RawMessage
|
||
if err := json.Unmarshal(body, &rawResp); err != nil {
|
||
return nil, fmt.Errorf("parse response failed: %w", err)
|
||
}
|
||
|
||
tokenData, ok := rawResp["alipay_system_oauth_token_response"]
|
||
if !ok {
|
||
return nil, fmt.Errorf("invalid alipay response structure")
|
||
}
|
||
|
||
var tokenResp AlipayTokenResponse
|
||
if err := json.Unmarshal(tokenData, &tokenResp); err != nil {
|
||
return nil, fmt.Errorf("parse token response failed: %w", err)
|
||
}
|
||
|
||
return &tokenResp, nil
|
||
}
|
||
|
||
// GetUserInfo 获取支付宝用户信息
|
||
func (a *AlipayProvider) GetUserInfo(ctx context.Context, accessToken string) (*AlipayUserInfo, error) {
|
||
params := map[string]string{
|
||
"app_id": a.AppID,
|
||
"method": "alipay.user.info.share",
|
||
"charset": "UTF-8",
|
||
"sign_type": "RSA2",
|
||
"timestamp": time.Now().Format("2006-01-02 15:04:05"),
|
||
"version": "1.0",
|
||
"auth_token": accessToken,
|
||
}
|
||
|
||
if a.PrivateKey != "" {
|
||
sign, err := a.signParams(params)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("sign failed: %w", err)
|
||
}
|
||
params["sign"] = sign
|
||
}
|
||
|
||
form := url.Values{}
|
||
for k, v := range params {
|
||
form.Set(k, v)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "POST", a.getGateway(),
|
||
strings.NewReader(form.Encode()))
|
||
if err != nil {
|
||
return nil, fmt.Errorf("create request failed: %w", err)
|
||
}
|
||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||
|
||
client := &http.Client{Timeout: 10 * time.Second}
|
||
resp, err := client.Do(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("request failed: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
body, err := readOAuthResponseBody(resp)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("read response failed: %w", err)
|
||
}
|
||
|
||
var rawResp map[string]json.RawMessage
|
||
if err := json.Unmarshal(body, &rawResp); err != nil {
|
||
return nil, fmt.Errorf("parse response failed: %w", err)
|
||
}
|
||
|
||
userData, ok := rawResp["alipay_user_info_share_response"]
|
||
if !ok {
|
||
return nil, fmt.Errorf("invalid alipay user info response")
|
||
}
|
||
|
||
var userInfo AlipayUserInfo
|
||
if err := json.Unmarshal(userData, &userInfo); err != nil {
|
||
return nil, fmt.Errorf("parse user info failed: %w", err)
|
||
}
|
||
|
||
return &userInfo, nil
|
||
}
|
||
|
||
// signParams 使用 RSA2(SHA256withRSA)对参数签名
|
||
func (a *AlipayProvider) signParams(params map[string]string) (string, error) {
|
||
// 按字典序排列参数
|
||
keys := make([]string, 0, len(params))
|
||
for k := range params {
|
||
if k != "sign" {
|
||
keys = append(keys, k)
|
||
}
|
||
}
|
||
sort.Strings(keys)
|
||
|
||
var parts []string
|
||
for _, k := range keys {
|
||
parts = append(parts, k+"="+params[k])
|
||
}
|
||
signContent := strings.Join(parts, "&")
|
||
|
||
// 解析私钥
|
||
privKey, err := parseAlipayPrivateKey(a.PrivateKey)
|
||
if err != nil {
|
||
return "", fmt.Errorf("parse private key: %w", err)
|
||
}
|
||
|
||
// SHA256withRSA 签名
|
||
hash := sha256.Sum256([]byte(signContent))
|
||
signature, err := rsa.SignPKCS1v15(rand.Reader, privKey, crypto.SHA256, hash[:])
|
||
if err != nil {
|
||
return "", fmt.Errorf("rsa sign: %w", err)
|
||
}
|
||
|
||
return base64.StdEncoding.EncodeToString(signature), nil
|
||
}
|
||
|
||
// parseAlipayPrivateKey decodes the RSA private key used to sign Alipay
// gateway requests.
//
// pemStr may be a PEM document of any block type (for example "PRIVATE KEY" or
// "RSA PRIVATE KEY"). It may also be the bare base64 key body, in which case
// PEM armor is added before decoding. Surrounding whitespace is ignored. The
// DER payload is parsed as PKCS#8 first, then as PKCS#1.
//
// It returns an error if the input is empty, is not decodable PEM/base64, is
// neither PKCS#8 nor PKCS#1, or holds a non-RSA key.
func parseAlipayPrivateKey(pemStr string) (*rsa.PrivateKey, error) {
	pemStr = strings.TrimSpace(pemStr)
	if pemStr == "" {
		return nil, fmt.Errorf("private key is empty")
	}

	// A bare base64 body gets wrapped so pem.Decode can handle it. The header
	// type does not matter, because both PKCS#8 and PKCS#1 are tried below.
	if !strings.Contains(pemStr, "-----BEGIN") {
		pemStr = "-----BEGIN PRIVATE KEY-----\n" + pemStr + "\n-----END PRIVATE KEY-----"
	}

	block, _ := pem.Decode([]byte(pemStr))
	if block == nil {
		return nil, fmt.Errorf("failed to decode PEM block")
	}

	key, pkcs8Err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if pkcs8Err == nil {
		rsaKey, ok := key.(*rsa.PrivateKey)
		if !ok {
			return nil, fmt.Errorf("not an RSA private key (got %T)", key)
		}
		return rsaKey, nil
	}

	rsaKey, pkcs1Err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if pkcs1Err != nil {
		// Report both attempts. The PKCS#1 error alone suggests retrying as
		// PKCS#8, which has already failed.
		return nil, fmt.Errorf("key is neither PKCS#8 (%v) nor PKCS#1: %w", pkcs8Err, pkcs1Err)
	}
	return rsaKey, nil
}
|