125 lines
3.6 KiB
Go
125 lines
3.6 KiB
Go
package providers
|
|
|
|
import (
|
|
"net/url"
|
|
"strings"
|
|
"testing"
|
|
)
|
|
|
|
func TestGitHubProviderGetAuthURLEscapesRedirectAndState(t *testing.T) {
|
|
provider := NewGitHubProvider("client-id", "client-secret", "https://admin.example.com/login/oauth/callback")
|
|
|
|
authURL, err := provider.GetAuthURL("state value")
|
|
if err != nil {
|
|
t.Fatalf("GetAuthURL failed: %v", err)
|
|
}
|
|
|
|
parsed, err := url.Parse(authURL)
|
|
if err != nil {
|
|
t.Fatalf("parse auth url failed: %v", err)
|
|
}
|
|
|
|
query := parsed.Query()
|
|
if query.Get("client_id") != "client-id" {
|
|
t.Fatalf("expected client_id to be propagated, got %q", query.Get("client_id"))
|
|
}
|
|
if query.Get("redirect_uri") != "https://admin.example.com/login/oauth/callback" {
|
|
t.Fatalf("expected redirect_uri to be propagated, got %q", query.Get("redirect_uri"))
|
|
}
|
|
if query.Get("state") != "state value" {
|
|
t.Fatalf("expected state to be propagated, got %q", query.Get("state"))
|
|
}
|
|
if !strings.Contains(query.Get("scope"), "read:user") {
|
|
t.Fatalf("expected GitHub scope to include read:user, got %q", query.Get("scope"))
|
|
}
|
|
}
|
|
|
|
func TestGoogleProviderGenerateStateAndBuildAuthURL(t *testing.T) {
|
|
provider := NewGoogleProvider("google-client", "google-secret", "https://admin.example.com/login/oauth/callback")
|
|
|
|
stateA, err := provider.GenerateState()
|
|
if err != nil {
|
|
t.Fatalf("GenerateState failed: %v", err)
|
|
}
|
|
stateB, err := provider.GenerateState()
|
|
if err != nil {
|
|
t.Fatalf("GenerateState failed: %v", err)
|
|
}
|
|
|
|
if stateA == "" || stateB == "" {
|
|
t.Fatal("expected non-empty generated states")
|
|
}
|
|
if stateA == stateB {
|
|
t.Fatal("expected generated states to be unique across calls")
|
|
}
|
|
|
|
authURL, err := provider.GetAuthURL("redirect-state")
|
|
if err != nil {
|
|
t.Fatalf("GetAuthURL failed: %v", err)
|
|
}
|
|
if authURL.State != "redirect-state" {
|
|
t.Fatalf("expected auth url state to be preserved, got %q", authURL.State)
|
|
}
|
|
if authURL.Redirect != provider.RedirectURI {
|
|
t.Fatalf("expected redirect uri to be preserved, got %q", authURL.Redirect)
|
|
}
|
|
if !strings.Contains(authURL.URL, "response_type=code") {
|
|
t.Fatalf("expected google auth url to request authorization code flow, got %q", authURL.URL)
|
|
}
|
|
}
|
|
|
|
func TestWeChatProviderGetAuthURLSupportsKnownTypes(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
oauthType string
|
|
expectedHost string
|
|
expectedPath string
|
|
}{
|
|
{
|
|
name: "web login",
|
|
oauthType: "web",
|
|
expectedHost: "open.weixin.qq.com",
|
|
expectedPath: "/connect/qrconnect",
|
|
},
|
|
{
|
|
name: "public account login",
|
|
oauthType: "mp",
|
|
expectedHost: "open.weixin.qq.com",
|
|
expectedPath: "/connect/oauth2/authorize",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
provider := NewWeChatProvider("wx-app-id", "wx-app-secret", tc.oauthType)
|
|
authURL, err := provider.GetAuthURL("https://admin.example.com/login/oauth/callback", "wechat-state")
|
|
if err != nil {
|
|
t.Fatalf("GetAuthURL failed: %v", err)
|
|
}
|
|
|
|
parsed, err := url.Parse(authURL.URL)
|
|
if err != nil {
|
|
t.Fatalf("parse auth url failed: %v", err)
|
|
}
|
|
|
|
if parsed.Host != tc.expectedHost {
|
|
t.Fatalf("expected host %q, got %q", tc.expectedHost, parsed.Host)
|
|
}
|
|
if parsed.Path != tc.expectedPath {
|
|
t.Fatalf("expected path %q, got %q", tc.expectedPath, parsed.Path)
|
|
}
|
|
if authURL.State != "wechat-state" {
|
|
t.Fatalf("expected state to be preserved, got %q", authURL.State)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWeChatProviderRejectsUnsupportedOAuthType(t *testing.T) {
|
|
provider := NewWeChatProvider("wx-app-id", "wx-app-secret", "mini")
|
|
|
|
if _, err := provider.GetAuthURL("https://admin.example.com/login/oauth/callback", "state"); err == nil {
|
|
t.Fatal("expected unsupported oauth type error")
|
|
}
|
|
}
|