Files
user-system/internal/auth/oauth_utils.go

197 lines
4.5 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package auth
import (
	"crypto/rand"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"strings"
	"sync"
	"time"

	"golang.org/x/oauth2"
)
// StateStore records the OAuth "state" values issued by GenerateState,
// mapping each one to the instant after which it is no longer accepted.
type StateStore struct {
	states map[string]time.Time // state value -> expiry instant
	mu     sync.RWMutex         // guards states
}

// stateStore is the single package-wide store shared by GenerateState,
// ValidateState and CleanupStates.
var stateStore = &StateStore{states: map[string]time.Time{}}

// GenerateState returns a new random OAuth state parameter (32 bytes from
// crypto/rand, encoded as URL-safe base64 with padding) and records it as
// valid for 10 minutes. It fails only if the random source fails.
func GenerateState() (string, error) {
	const ttl = 10 * time.Minute

	var raw [32]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("generate state failed: %w", err)
	}
	token := base64.URLEncoding.EncodeToString(raw[:])

	stateStore.mu.Lock()
	defer stateStore.mu.Unlock()
	stateStore.states[token] = time.Now().Add(ttl)
	return token, nil
}

// ValidateState reports whether state was issued by GenerateState and has not
// yet expired. A known state is removed on lookup, whether or not it is still
// fresh, so each state can be validated successfully at most once.
func ValidateState(state string) bool {
	stateStore.mu.Lock()
	defer stateStore.mu.Unlock()

	expiry, found := stateStore.states[state]
	if !found {
		return false
	}
	// Single use: consume the entry before deciding freshness.
	delete(stateStore.states, state)
	return !time.Now().After(expiry)
}

// CleanupStates removes every recorded state whose expiry has passed.
// Expired states that are never looked up by ValidateState otherwise remain
// in the store, so this presumably needs to be run periodically (confirm
// with the caller that schedules it).
func CleanupStates() {
	stateStore.mu.Lock()
	defer stateStore.mu.Unlock()

	cutoff := time.Now()
	for token, expiry := range stateStore.states {
		if expiry.Before(cutoff) {
			delete(stateStore.states, token)
		}
	}
}
// HTTPClient is the shared client for outbound OAuth HTTP calls made by the
// helpers below. Its 30-second timeout bounds each whole request, including
// reading the response body.
var HTTPClient = &http.Client{
	Timeout: 30 * time.Second,
}

// Get issues a GET request to rawURL using HTTPClient.
// The caller must close the returned response body.
func Get(rawURL string) (*http.Response, error) {
	return HTTPClient.Get(rawURL)
}

// PostForm issues a POST to rawURL with data encoded as
// application/x-www-form-urlencoded, using HTTPClient.
// The caller must close the returned response body.
func PostForm(rawURL string, data url.Values) (*http.Response, error) {
	return HTTPClient.PostForm(rawURL, data)
}

// GetJSON GETs rawURL and decodes a 200 OK JSON body into result, which must
// be a non-nil pointer. Any other status is reported as an error.
func GetJSON(rawURL string, result interface{}) error {
	resp, err := Get(rawURL)
	if err != nil {
		return err
	}
	return decodeJSONResponse(resp, result)
}

// PostFormJSON POSTs data as a form to rawURL and decodes a 200 OK JSON body
// into result, which must be a non-nil pointer. Any other status is reported
// as an error.
func PostFormJSON(rawURL string, data url.Values, result interface{}) error {
	resp, err := PostForm(rawURL, data)
	if err != nil {
		return err
	}
	return decodeJSONResponse(resp, result)
}

// decodeJSONResponse requires a 200 status and JSON-decodes resp.Body into
// result. It always closes the body, and before closing it discards a bounded
// amount of unread data.
func decodeJSONResponse(resp *http.Response, result interface{}) error {
	defer func() {
		// Draining lets the Transport reuse the keep-alive connection. The
		// decoder may stop before EOF, and error paths read nothing. The
		// limit keeps a misbehaving server from making us read forever.
		_, _ = io.Copy(io.Discard, io.LimitReader(resp.Body, 64<<10))
		_ = resp.Body.Close()
	}()

	if resp.StatusCode != http.StatusOK {
		// Providers often explain failures in the body (e.g. an OAuth
		// {"error": ...} object), so keep a short prefix for diagnosis.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		if len(snippet) == 0 {
			return fmt.Errorf("HTTP request failed with status %d", resp.StatusCode)
		}
		return fmt.Errorf("HTTP request failed with status %d: %q", resp.StatusCode, snippet)
	}
	if err := json.NewDecoder(resp.Body).Decode(result); err != nil {
		return fmt.Errorf("decode JSON response: %w", err)
	}
	return nil
}
// BuildAuthURL builds a standard OAuth 2.0 authorization-code request URL.
// It sets client_id, redirect_uri, scope, state and response_type=code on
// baseURL's query string. Other query parameters already present in baseURL
// are kept, and the encoded query is sorted by key (url.Values.Encode).
//
// If baseURL cannot be parsed, the parameters are appended to it verbatim
// instead of panicking on a nil *url.URL. The result is then only as valid as
// baseURL itself.
func BuildAuthURL(baseURL, clientID, redirectURI, scope, state string) string {
	query := url.Values{}
	u, err := url.Parse(baseURL)
	if err == nil {
		query = u.Query()
	}
	query.Set("client_id", clientID)
	query.Set("redirect_uri", redirectURI)
	query.Set("scope", scope)
	query.Set("state", state)
	query.Set("response_type", "code")

	if err != nil {
		// u is nil here; fall back to string concatenation.
		sep := "?"
		if strings.Contains(baseURL, "?") {
			sep = "&"
		}
		return baseURL + sep + query.Encode()
	}
	u.RawQuery = query.Encode()
	return u.String()
}
// ParseAccessTokenResponse decodes a JSON access-token response body into an
// OAuthToken, reading the access_token, refresh_token, expires_in and
// token_type fields.
//
// It returns an error if resp is not valid JSON for those fields, or if it
// contains no access_token. In the latter case any string "error" /
// "error_description" fields are included in the message. Some providers
// send that error shape even with HTTP 200, and previously it was returned as
// a token with an empty AccessToken.
func ParseAccessTokenResponse(resp []byte) (*OAuthToken, error) {
	var result struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		ExpiresIn    int64  `json:"expires_in"`
		TokenType    string `json:"token_type"`
	}
	if err := json.Unmarshal(resp, &result); err != nil {
		return nil, fmt.Errorf("parse access token response: %w", err)
	}

	if result.AccessToken == "" {
		// Decoded separately, and only on this failure path, so that a
		// provider using a non-string "error" field cannot break a
		// successful parse.
		var oauthErr struct {
			Error            string `json:"error"`
			ErrorDescription string `json:"error_description"`
		}
		if json.Unmarshal(resp, &oauthErr) == nil && oauthErr.Error != "" {
			if oauthErr.ErrorDescription != "" {
				return nil, fmt.Errorf("access token request failed: %s: %s", oauthErr.Error, oauthErr.ErrorDescription)
			}
			return nil, fmt.Errorf("access token request failed: %s", oauthErr.Error)
		}
		return nil, fmt.Errorf("access token response has no access_token")
	}

	return &OAuthToken{
		AccessToken:  result.AccessToken,
		RefreshToken: result.RefreshToken,
		ExpiresIn:    result.ExpiresIn,
		TokenType:    result.TokenType,
	}, nil
}
// ParseQueryAccessToken extracts the access_token value from a form-encoded
// body such as "access_token=...&expires_in=...". Some APIs return this
// format as text/plain instead of JSON. A body without an access_token key
// yields "" and a nil error; a malformed body yields "" and the parse error.
func ParseQueryAccessToken(body string) (accessToken string, err error) {
	var form url.Values
	if form, err = url.ParseQuery(body); err != nil {
		return "", err
	}
	accessToken = form.Get("access_token")
	return accessToken, nil
}
// ParseJSONPResponse unwraps a JSONP response such as
// `callback( {"openid":"..."} );` and decodes the JSON object between the
// first "(" and the last ")" into a map. Platforms such as QQ use this format.
// It returns an error if there is no such parenthesised span or if the
// payload is not a JSON object.
func ParseJSONPResponse(jsonp string) (map[string]interface{}, error) {
	start := strings.Index(jsonp, "(")
	end := strings.LastIndex(jsonp, ")")
	// end must lie after start. Input like ")(" would otherwise make the slice
	// expression below panic with start+1 > end.
	if start == -1 || end <= start {
		return nil, fmt.Errorf("invalid JSONP format")
	}

	var result map[string]interface{}
	if err := json.Unmarshal([]byte(jsonp[start+1:end]), &result); err != nil {
		return nil, fmt.Errorf("parse JSONP payload: %w", err)
	}
	return result, nil
}
// ToOAuth2Config converts an OAuthConfig into an *oauth2.Config.
//
// config.Scope is treated as a comma-separated list. Each entry is trimmed of
// surrounding whitespace and empty entries are dropped, so an empty Scope
// yields no scopes at all. Previously it produced a single "" scope, which
// oauth2's AuthCodeURL still emits as an empty scope parameter.
func ToOAuth2Config(config *OAuthConfig) *oauth2.Config {
	var scopes []string
	for _, s := range strings.Split(config.Scope, ",") {
		if s = strings.TrimSpace(s); s != "" {
			scopes = append(scopes, s)
		}
	}

	return &oauth2.Config{
		ClientID:     config.ClientID,
		ClientSecret: config.ClientSecret,
		RedirectURL:  config.RedirectURI,
		Scopes:       scopes,
		Endpoint: oauth2.Endpoint{
			AuthURL:  config.AuthURL,
			TokenURL: config.TokenURL,
		},
	}
}