test: add service layer unit tests for webhook/metadata/error/config
- webhook_service_test.go: isPrivateIP, isSafeURL, computeHMAC - request_metadata_test.go: context functions - classified_error_test.go: error types - config_defaults_test.go: password reset/SMS defaults - email_config_test.go: email code defaults - auth_runtime_test.go: isUserNotFoundError Service coverage: 11.2% -> 14.7%
This commit is contained in:
75
internal/service/auth_runtime_test.go
Normal file
75
internal/service/auth_runtime_test.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Auth Runtime Helper Functions Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestIsUserNotFoundError(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "nil error",
|
||||
err: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "gorm record not found",
|
||||
err: gorm.ErrRecordNotFound,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "wrapped gorm record not found",
|
||||
err: errors.Join(gorm.ErrRecordNotFound, errors.New("additional context")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "other error",
|
||||
err: errors.New("some other error"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "generic error",
|
||||
err: errors.New("something went wrong"),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "error containing user not found",
|
||||
err: errors.New("user not found"),
|
||||
expected: true, // contains "user not found" in lowercase
|
||||
},
|
||||
{
|
||||
name: "error containing record not found",
|
||||
err: errors.New("record not found"),
|
||||
expected: true, // contains "record not found"
|
||||
},
|
||||
{
|
||||
name: "error containing not found",
|
||||
err: errors.New("entity not found"),
|
||||
expected: true, // contains "not found"
|
||||
},
|
||||
{
|
||||
name: "error containing 用户不存在",
|
||||
err: errors.New("用户不存在"),
|
||||
expected: true, // contains Chinese "用户不存在"
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isUserNotFoundError(tt.err)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isUserNotFoundError(%v) = %v, want %v", tt.err, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
99
internal/service/classified_error_test.go
Normal file
99
internal/service/classified_error_test.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Classified Error Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestClassifiedError(t *testing.T) {
|
||||
// Test error with message
|
||||
e1 := &classifiedError{message: "custom message", cause: errors.New("cause")}
|
||||
if e1.Error() != "custom message" {
|
||||
t.Errorf("Error() = %q, want %q", e1.Error(), "custom message")
|
||||
}
|
||||
|
||||
// Test error with cause but no message
|
||||
e2 := &classifiedError{cause: errors.New("underlying error")}
|
||||
if e2.Error() != "underlying error" {
|
||||
t.Errorf("Error() = %q, want %q", e2.Error(), "underlying error")
|
||||
}
|
||||
|
||||
// Test error with neither message nor cause
|
||||
e3 := &classifiedError{}
|
||||
if e3.Error() != "" {
|
||||
t.Errorf("Error() = %q, want empty string", e3.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func TestClassifiedErrorUnwrap(t *testing.T) {
|
||||
innerErr := errors.New("inner error")
|
||||
e := &classifiedError{message: "outer", cause: innerErr}
|
||||
|
||||
unwrapped := e.Unwrap()
|
||||
if unwrapped != innerErr {
|
||||
t.Errorf("Unwrap() = %v, want %v", unwrapped, innerErr)
|
||||
}
|
||||
|
||||
// Test errors.Is
|
||||
if !errors.Is(e, innerErr) {
|
||||
t.Error("errors.Is(e, innerErr) = false, want true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewRateLimitError(t *testing.T) {
|
||||
err := newRateLimitError("too many requests")
|
||||
|
||||
// Should be a classifiedError
|
||||
var ce *classifiedError
|
||||
if !errors.As(err, &ce) {
|
||||
t.Errorf("errors.As(err, &classifiedError{}) = false")
|
||||
}
|
||||
|
||||
// Should wrap ErrRateLimitExceeded
|
||||
if !errors.Is(err, ErrRateLimitExceeded) {
|
||||
t.Error("errors.Is(err, ErrRateLimitExceeded) = false")
|
||||
}
|
||||
|
||||
// Error message should be "too many requests"
|
||||
if err.Error() != "too many requests" {
|
||||
t.Errorf("err.Error() = %q, want %q", err.Error(), "too many requests")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewValidationError(t *testing.T) {
|
||||
err := newValidationError("invalid input")
|
||||
|
||||
// Should be a classifiedError
|
||||
var ce *classifiedError
|
||||
if !errors.As(err, &ce) {
|
||||
t.Errorf("errors.As(err, &classifiedError{}) = false")
|
||||
}
|
||||
|
||||
// Should wrap ErrValidationFailed
|
||||
if !errors.Is(err, ErrValidationFailed) {
|
||||
t.Error("errors.Is(err, ErrValidationFailed) = false")
|
||||
}
|
||||
|
||||
// Error message should be "invalid input"
|
||||
if err.Error() != "invalid input" {
|
||||
t.Errorf("err.Error() = %q, want %q", err.Error(), "invalid input")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrRateLimitExceeded(t *testing.T) {
|
||||
// ErrRateLimitExceeded is a sentinel error
|
||||
if ErrRateLimitExceeded.Error() != "rate limit exceeded" {
|
||||
t.Errorf("ErrRateLimitExceeded.Error() = %q, want %q", ErrRateLimitExceeded.Error(), "rate limit exceeded")
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrValidationFailed(t *testing.T) {
|
||||
// ErrValidationFailed is a sentinel error
|
||||
if ErrValidationFailed.Error() != "validation failed" {
|
||||
t.Errorf("ErrValidationFailed.Error() = %q, want %q", ErrValidationFailed.Error(), "validation failed")
|
||||
}
|
||||
}
|
||||
63
internal/service/config_defaults_test.go
Normal file
63
internal/service/config_defaults_test.go
Normal file
@@ -0,0 +1,63 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Password Reset Configuration Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestDefaultPasswordResetConfig(t *testing.T) {
|
||||
cfg := DefaultPasswordResetConfig()
|
||||
|
||||
if cfg.TokenTTL != 15*time.Minute {
|
||||
t.Errorf("TokenTTL = %v, want %v", cfg.TokenTTL, 15*time.Minute)
|
||||
}
|
||||
if cfg.SMTPHost != "" {
|
||||
t.Errorf("SMTPHost = %q, want empty", cfg.SMTPHost)
|
||||
}
|
||||
if cfg.SMTPPort != 587 {
|
||||
t.Errorf("SMTPPort = %d, want 587", cfg.SMTPPort)
|
||||
}
|
||||
if cfg.SMTPUser != "" {
|
||||
t.Errorf("SMTPUser = %q, want empty", cfg.SMTPUser)
|
||||
}
|
||||
if cfg.SMTPPass != "" {
|
||||
t.Errorf("SMTPPass = %q, want empty", cfg.SMTPPass)
|
||||
}
|
||||
if cfg.FromEmail != "noreply@example.com" {
|
||||
t.Errorf("FromEmail = %q, want %q", cfg.FromEmail, "noreply@example.com")
|
||||
}
|
||||
if cfg.SiteURL != "http://localhost:8080" {
|
||||
t.Errorf("SiteURL = %q, want %q", cfg.SiteURL, "http://localhost:8080")
|
||||
}
|
||||
if cfg.PasswordMinLen != 8 {
|
||||
t.Errorf("PasswordMinLen = %d, want 8", cfg.PasswordMinLen)
|
||||
}
|
||||
if cfg.PasswordRequireSpecial != false {
|
||||
t.Error("PasswordRequireSpecial = true, want false")
|
||||
}
|
||||
if cfg.PasswordRequireNumber != false {
|
||||
t.Error("PasswordRequireNumber = true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
// =============================================================================
|
||||
// SMS Configuration Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestDefaultSMSCodeConfig(t *testing.T) {
|
||||
cfg := DefaultSMSCodeConfig()
|
||||
|
||||
if cfg.CodeTTL != 5*time.Minute {
|
||||
t.Errorf("CodeTTL = %v, want %v", cfg.CodeTTL, 5*time.Minute)
|
||||
}
|
||||
if cfg.ResendCooldown != time.Minute {
|
||||
t.Errorf("ResendCooldown = %v, want %v", cfg.ResendCooldown, time.Minute)
|
||||
}
|
||||
if cfg.MaxDailyLimit != 10 {
|
||||
t.Errorf("MaxDailyLimit = %d, want 10", cfg.MaxDailyLimit)
|
||||
}
|
||||
}
|
||||
30
internal/service/email_config_test.go
Normal file
30
internal/service/email_config_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Email Configuration Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestDefaultEmailCodeConfig(t *testing.T) {
|
||||
cfg := DefaultEmailCodeConfig()
|
||||
|
||||
if cfg.CodeTTL != 5*time.Minute {
|
||||
t.Errorf("CodeTTL = %v, want %v", cfg.CodeTTL, 5*time.Minute)
|
||||
}
|
||||
if cfg.ResendCooldown != time.Minute {
|
||||
t.Errorf("ResendCooldown = %v, want %v", cfg.ResendCooldown, time.Minute)
|
||||
}
|
||||
if cfg.MaxDailyLimit != 10 {
|
||||
t.Errorf("MaxDailyLimit = %d, want 10", cfg.MaxDailyLimit)
|
||||
}
|
||||
if cfg.SiteURL != "http://localhost:8080" {
|
||||
t.Errorf("SiteURL = %q, want %q", cfg.SiteURL, "http://localhost:8080")
|
||||
}
|
||||
if cfg.SiteName != "User Management System" {
|
||||
t.Errorf("SiteName = %q, want %q", cfg.SiteName, "User Management System")
|
||||
}
|
||||
}
|
||||
180
internal/service/request_metadata_test.go
Normal file
180
internal/service/request_metadata_test.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Request Metadata Context Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestRequestMetadataFallbackStats(t *testing.T) {
|
||||
isMaxTokens, thinking, prefetchAccount, prefetchGroup, singleAccount, accountSwitch := RequestMetadataFallbackStats()
|
||||
|
||||
if isMaxTokens != 0 {
|
||||
t.Errorf("isMaxTokens = %d, want 0", isMaxTokens)
|
||||
}
|
||||
if thinking != 0 {
|
||||
t.Errorf("thinking = %d, want 0", thinking)
|
||||
}
|
||||
if prefetchAccount != 0 {
|
||||
t.Errorf("prefetchAccount = %d, want 0", prefetchAccount)
|
||||
}
|
||||
if prefetchGroup != 0 {
|
||||
t.Errorf("prefetchGroup = %d, want 0", prefetchGroup)
|
||||
}
|
||||
if singleAccount != 0 {
|
||||
t.Errorf("singleAccount = %d, want 0", singleAccount)
|
||||
}
|
||||
if accountSwitch != 0 {
|
||||
t.Errorf("accountSwitch = %d, want 0", accountSwitch)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithIsMaxTokensOneHaikuRequest(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test setting true
|
||||
ctx1 := WithIsMaxTokensOneHaikuRequest(ctx, true, false)
|
||||
val, ok := IsMaxTokensOneHaikuRequestFromContext(ctx1)
|
||||
if !ok {
|
||||
t.Error("IsMaxTokensOneHaikuRequestFromContext returned !ok")
|
||||
}
|
||||
if val != true {
|
||||
t.Errorf("IsMaxTokensOneHaikuRequestFromContext = %v, want true", val)
|
||||
}
|
||||
|
||||
// Test setting false
|
||||
ctx2 := WithIsMaxTokensOneHaikuRequest(ctx, false, false)
|
||||
val2, ok2 := IsMaxTokensOneHaikuRequestFromContext(ctx2)
|
||||
if !ok2 {
|
||||
t.Error("IsMaxTokensOneHaikuRequestFromContext returned !ok")
|
||||
}
|
||||
if val2 != false {
|
||||
t.Errorf("IsMaxTokensOneHaikuRequestFromContext = %v, want false", val2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithThinkingEnabled(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test setting true
|
||||
ctx1 := WithThinkingEnabled(ctx, true, false)
|
||||
val, ok := ThinkingEnabledFromContext(ctx1)
|
||||
if !ok {
|
||||
t.Error("ThinkingEnabledFromContext returned !ok")
|
||||
}
|
||||
if val != true {
|
||||
t.Errorf("ThinkingEnabledFromContext = %v, want true", val)
|
||||
}
|
||||
|
||||
// Test setting false
|
||||
ctx2 := WithThinkingEnabled(ctx, false, false)
|
||||
val2, ok2 := ThinkingEnabledFromContext(ctx2)
|
||||
if !ok2 {
|
||||
t.Error("ThinkingEnabledFromContext returned !ok")
|
||||
}
|
||||
if val2 != false {
|
||||
t.Errorf("ThinkingEnabledFromContext = %v, want false", val2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithPrefetchedStickySession(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test setting values
|
||||
ctx1 := WithPrefetchedStickySession(ctx, 123, 456, false)
|
||||
accountID, ok := PrefetchedStickyAccountIDFromContext(ctx1)
|
||||
if !ok {
|
||||
t.Error("PrefetchedStickyAccountIDFromContext returned !ok")
|
||||
}
|
||||
if accountID != 123 {
|
||||
t.Errorf("PrefetchedStickyAccountIDFromContext = %d, want 123", accountID)
|
||||
}
|
||||
|
||||
groupID, ok2 := PrefetchedStickyGroupIDFromContext(ctx1)
|
||||
if !ok2 {
|
||||
t.Error("PrefetchedStickyGroupIDFromContext returned !ok")
|
||||
}
|
||||
if groupID != 456 {
|
||||
t.Errorf("PrefetchedStickyGroupIDFromContext = %d, want 456", groupID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithSingleAccountRetry(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test setting true
|
||||
ctx1 := WithSingleAccountRetry(ctx, true, false)
|
||||
val, ok := SingleAccountRetryFromContext(ctx1)
|
||||
if !ok {
|
||||
t.Error("SingleAccountRetryFromContext returned !ok")
|
||||
}
|
||||
if val != true {
|
||||
t.Errorf("SingleAccountRetryFromContext = %v, want true", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWithAccountSwitchCount(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Test setting count
|
||||
ctx1 := WithAccountSwitchCount(ctx, 5, false)
|
||||
val, ok := AccountSwitchCountFromContext(ctx1)
|
||||
if !ok {
|
||||
t.Error("AccountSwitchCountFromContext returned !ok")
|
||||
}
|
||||
if val != 5 {
|
||||
t.Errorf("AccountSwitchCountFromContext = %d, want 5", val)
|
||||
}
|
||||
}
|
||||
|
||||
func TestContextDefaults(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// All context getters should return !ok for fresh context
|
||||
_, ok := IsMaxTokensOneHaikuRequestFromContext(ctx)
|
||||
if ok {
|
||||
t.Error("IsMaxTokensOneHaikuRequestFromContext returned ok for fresh context")
|
||||
}
|
||||
|
||||
_, ok = ThinkingEnabledFromContext(ctx)
|
||||
if ok {
|
||||
t.Error("ThinkingEnabledFromContext returned ok for fresh context")
|
||||
}
|
||||
|
||||
_, ok = PrefetchedStickyAccountIDFromContext(ctx)
|
||||
if ok {
|
||||
t.Error("PrefetchedStickyAccountIDFromContext returned ok for fresh context")
|
||||
}
|
||||
|
||||
_, ok = PrefetchedStickyGroupIDFromContext(ctx)
|
||||
if ok {
|
||||
t.Error("PrefetchedStickyGroupIDFromContext returned ok for fresh context")
|
||||
}
|
||||
|
||||
_, ok = SingleAccountRetryFromContext(ctx)
|
||||
if ok {
|
||||
t.Error("SingleAccountRetryFromContext returned ok for fresh context")
|
||||
}
|
||||
|
||||
_, ok = AccountSwitchCountFromContext(ctx)
|
||||
if ok {
|
||||
t.Error("AccountSwitchCountFromContext returned ok for fresh context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBridgeOldKeys(t *testing.T) {
|
||||
// Test that bridgeOldKeys=true allows setting values
|
||||
// even when old keys might already exist
|
||||
ctx := context.Background()
|
||||
ctx1 := WithIsMaxTokensOneHaikuRequest(ctx, true, true) // bridgeOldKeys=true
|
||||
val, ok := IsMaxTokensOneHaikuRequestFromContext(ctx1)
|
||||
if !ok {
|
||||
t.Error("IsMaxTokensOneHaikuRequestFromContext returned !ok with bridgeOldKeys=true")
|
||||
}
|
||||
if val != true {
|
||||
t.Errorf("IsMaxTokensOneHaikuRequestFromContext = %v, want true", val)
|
||||
}
|
||||
}
|
||||
201
internal/service/webhook_service_test.go
Normal file
201
internal/service/webhook_service_test.go
Normal file
@@ -0,0 +1,201 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// =============================================================================
|
||||
// Webhook Security Functions Tests
|
||||
// =============================================================================
|
||||
|
||||
func TestIsPrivateIP(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ip string
|
||||
expected bool
|
||||
}{
|
||||
// Private ranges - 10.0.0.0/8
|
||||
{"10.0.0.0", "10.0.0.0", true},
|
||||
{"10.255.255.255", "10.255.255.255", true},
|
||||
{"10.1.2.3", "10.1.2.3", true},
|
||||
|
||||
// Private ranges - 172.16.0.0/12
|
||||
{"172.16.0.0", "172.16.0.0", true},
|
||||
{"172.31.255.255", "172.31.255.255", true},
|
||||
{"172.20.1.1", "172.20.1.1", true},
|
||||
|
||||
// Private ranges - 192.168.0.0/16
|
||||
{"192.168.0.0", "192.168.0.0", true},
|
||||
{"192.168.255.255", "192.168.255.255", true},
|
||||
{"192.168.1.100", "192.168.1.100", true},
|
||||
|
||||
// Loopback
|
||||
{"127.0.0.1", "127.0.0.1", true},
|
||||
{"127.255.255.255", "127.255.255.255", true},
|
||||
{"::1", "::1", true},
|
||||
|
||||
// Public IPs
|
||||
{"8.8.8.8", "8.8.8.8", false},
|
||||
{"1.1.1.1", "1.1.1.1", false},
|
||||
{"93.184.216.34", "93.184.216.34", false},
|
||||
{"142.250.80.46", "142.250.80.46", false},
|
||||
|
||||
// Edge cases
|
||||
{"", "", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ip := net.ParseIP(tt.ip)
|
||||
if tt.ip == "" {
|
||||
// Empty IP should return false
|
||||
result := isPrivateIP(nil)
|
||||
if result != false {
|
||||
t.Errorf("isPrivateIP(nil) = %v, want %v", result, false)
|
||||
}
|
||||
return
|
||||
}
|
||||
if ip == nil {
|
||||
t.Skipf("could not parse IP: %s", tt.ip)
|
||||
}
|
||||
result := isPrivateIP(ip)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSafeURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
expected bool
|
||||
}{
|
||||
// Valid public HTTPS URLs
|
||||
{"https example.com", "https://example.com/webhook", true},
|
||||
{"https with path", "https://example.com/api/v1/hook", true},
|
||||
{"https with query", "https://example.com/hook?a=1&b=2", true},
|
||||
{"https with port", "https://example.com:8443/hook", true},
|
||||
{"https subdomains", "https://sub.example.com/hook", true},
|
||||
|
||||
// HTTP (allowed but public only)
|
||||
{"http public", "http://example.com/hook", true},
|
||||
{"http with port", "http://example.com:8080/hook", true},
|
||||
|
||||
// Invalid schemes
|
||||
{"ftp scheme", "ftp://example.com/hook", false},
|
||||
{"file scheme", "file:///etc/passwd", false},
|
||||
{"data scheme", "data:text/html,<script>alert(1)</script>", false},
|
||||
{"javascript scheme", "javascript:alert(1)", false},
|
||||
|
||||
// Localhost blocked
|
||||
{"localhost http", "http://localhost/hook", false},
|
||||
{"localhost https", "https://localhost/hook", false},
|
||||
{"127.0.0.1", "http://127.0.0.1/hook", false},
|
||||
{"::1", "http://[::1]/hook", false},
|
||||
|
||||
// Private IPs blocked
|
||||
{"10.x.x.x", "http://10.0.0.1/hook", false},
|
||||
{"172.16.x.x", "http://172.16.0.1/hook", false},
|
||||
{"192.168.x.x", "http://192.168.1.1/hook", false},
|
||||
|
||||
// Internal domains blocked
|
||||
{"internal domain", "https://server.internal/hook", false},
|
||||
{"local domain", "https://host.local/hook", false},
|
||||
{"corp domain", "https://host.corp/hook", false},
|
||||
{"lan domain", "https://host.lan/hook", false},
|
||||
{"intranet domain", "https://host.intranet/hook", false},
|
||||
|
||||
// Cloud metadata IPs blocked
|
||||
{"gcp metadata", "http://metadata.google.internal/", false},
|
||||
{"aws metadata", "http://169.254.169.254/latest/meta-data/", false},
|
||||
{"azure metadata", "http://metadata.azure.internal/", false},
|
||||
{"aliyun metadata", "http://100.100.100.200/latest/meta-data/", false},
|
||||
|
||||
// Invalid URLs
|
||||
{"empty", "", false},
|
||||
{"no scheme", "example.com/hook", false},
|
||||
{"relative", "/hook", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isSafeURL(tt.url)
|
||||
if result != tt.expected {
|
||||
t.Errorf("isSafeURL(%q) = %v, want %v", tt.url, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeHMAC(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
payload []byte
|
||||
secret string
|
||||
}{
|
||||
{
|
||||
name: "simple payload",
|
||||
payload: []byte(`{"event":"user.created"}`),
|
||||
secret: "test-secret",
|
||||
},
|
||||
{
|
||||
name: "empty payload",
|
||||
payload: []byte{},
|
||||
secret: "test-secret",
|
||||
},
|
||||
{
|
||||
name: "empty secret",
|
||||
payload: []byte(`{"event":"user.deleted"}`),
|
||||
secret: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result1 := computeHMAC(tt.payload, tt.secret)
|
||||
result2 := computeHMAC(tt.payload, tt.secret)
|
||||
|
||||
// Same input should produce same output
|
||||
if result1 != result2 {
|
||||
t.Errorf("computeHMAC not deterministic: got %s and %s", result1, result2)
|
||||
}
|
||||
|
||||
// Result should not be empty for non-empty payload
|
||||
if len(tt.payload) > 0 && result1 == "" {
|
||||
t.Error("computeHMAC returned empty string for non-empty payload")
|
||||
}
|
||||
|
||||
// Result should be hex-encoded (64 chars for SHA256)
|
||||
if len(result1) != 64 {
|
||||
t.Errorf("computeHMAC returned %d chars, want 64 (SHA256 hex)", len(result1))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeHMAC_DifferentInputs(t *testing.T) {
|
||||
payload1 := []byte(`{"event":"user.created"}`)
|
||||
payload2 := []byte(`{"event":"user.deleted"}`)
|
||||
secret := "test-secret"
|
||||
|
||||
result1 := computeHMAC(payload1, secret)
|
||||
result2 := computeHMAC(payload2, secret)
|
||||
|
||||
if result1 == result2 {
|
||||
t.Error("Different payloads should produce different HMACs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeHMAC_DifferentSecrets(t *testing.T) {
|
||||
payload := []byte(`{"event":"user.created"}`)
|
||||
|
||||
result1 := computeHMAC(payload, "secret1")
|
||||
result2 := computeHMAC(payload, "secret2")
|
||||
|
||||
if result1 == result2 {
|
||||
t.Error("Different secrets should produce different HMACs")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user