package auth import ( "crypto/rand" "encoding/base64" "encoding/json" "fmt" "net/http" "net/url" "strings" "sync" "time" "golang.org/x/oauth2" ) // StateStore OAuth状态存储 type StateStore struct { states map[string]time.Time mu sync.RWMutex } var stateStore = &StateStore{ states: make(map[string]time.Time), } // GenerateState 生成OAuth状态参数 func GenerateState() (string, error) { b := make([]byte, 32) if _, err := rand.Read(b); err != nil { return "", fmt.Errorf("generate state failed: %w", err) } state := base64.URLEncoding.EncodeToString(b) // 存储状态,10分钟过期 stateStore.mu.Lock() stateStore.states[state] = time.Now().Add(10 * time.Minute) stateStore.mu.Unlock() return state, nil } // ValidateState 验证OAuth状态参数 func ValidateState(state string) bool { stateStore.mu.Lock() defer stateStore.mu.Unlock() expireTime, ok := stateStore.states[state] if !ok { return false } // 检查是否过期 if time.Now().After(expireTime) { delete(stateStore.states, state) return false } // 使用后删除 delete(stateStore.states, state) return true } // CleanupStates 清理过期的状态 func CleanupStates() { stateStore.mu.Lock() defer stateStore.mu.Unlock() now := time.Now() for state, expireTime := range stateStore.states { if now.After(expireTime) { delete(stateStore.states, state) } } } // HTTPClient OAuth HTTP客户端 var HTTPClient = &http.Client{ Timeout: 30 * time.Second, } // Get 发送GET请求 func Get(url string) (*http.Response, error) { return HTTPClient.Get(url) } // PostForm 发送POST表单请求 func PostForm(url string, data url.Values) (*http.Response, error) { return HTTPClient.PostForm(url, data) } // GetJSON 发送GET请求并解析JSON响应 func GetJSON(url string, result interface{}) error { resp, err := Get(url) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("HTTP request failed with status %d", resp.StatusCode) } return json.NewDecoder(resp.Body).Decode(result) } // PostFormJSON 发送POST表单请求并解析JSON响应 func PostFormJSON(url string, data url.Values, result interface{}) error { resp, err := PostForm(url, data) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("HTTP request failed with status %d", resp.StatusCode) } return json.NewDecoder(resp.Body).Decode(result) } // BuildAuthURL 构建标准OAuth授权URL func BuildAuthURL(baseURL, clientID, redirectURI, scope, state string) string { u, _ := url.Parse(baseURL) q := u.Query() q.Set("client_id", clientID) q.Set("redirect_uri", redirectURI) q.Set("scope", scope) q.Set("state", state) q.Set("response_type", "code") u.RawQuery = q.Encode() return u.String() } // ParseAccessTokenResponse 解析访问令牌响应 func ParseAccessTokenResponse(resp []byte) (*OAuthToken, error) { var result struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn int64 `json:"expires_in"` TokenType string `json:"token_type"` } if err := json.Unmarshal(resp, &result); err != nil { return nil, err } return &OAuthToken{ AccessToken: result.AccessToken, RefreshToken: result.RefreshToken, ExpiresIn: result.ExpiresIn, TokenType: result.TokenType, }, nil } // ParseQueryAccessToken 解析查询字符串形式的访问令牌(用于某些返回text/plain的API) func ParseQueryAccessToken(body string) (accessToken string, err error) { values, err := url.ParseQuery(body) if err != nil { return "", err } return values.Get("access_token"), nil } // ParseJSONPResponse 解析JSONP响应(用于QQ等平台) func ParseJSONPResponse(jsonp string) (map[string]interface{}, error) { // 移除callback包装 start := strings.Index(jsonp, "(") end := strings.LastIndex(jsonp, ")") if start == -1 || end == -1 { return nil, fmt.Errorf("invalid JSONP format") } jsonStr := jsonp[start+1 : end] var result map[string]interface{} if err := json.Unmarshal([]byte(jsonStr), &result); err != nil { return nil, err } return result, nil } // ToOAuth2Config 转换为oauth2.Config func ToOAuth2Config(config *OAuthConfig) *oauth2.Config { return &oauth2.Config{ ClientID: config.ClientID, ClientSecret: config.ClientSecret, RedirectURL: config.RedirectURI, Scopes: strings.Split(config.Scope, ","), Endpoint: oauth2.Endpoint{ AuthURL: config.AuthURL, TokenURL: config.TokenURL, }, } }