diff --git a/internal/security/validator_test.go b/internal/security/validator_test.go new file mode 100644 index 0000000..c5b7291 --- /dev/null +++ b/internal/security/validator_test.go @@ -0,0 +1,291 @@ +package security + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +// TestNewValidator 测试 Validator 创建 +func TestNewValidator(t *testing.T) { + v := NewValidator(8, true, true) + assert.NotNil(t, v) + assert.Equal(t, 8, v.passwordMinLength) + assert.True(t, v.passwordRequireSpecial) + assert.True(t, v.passwordRequireNumber) + + v2 := NewValidator(6, false, false) + assert.Equal(t, 6, v2.passwordMinLength) + assert.False(t, v2.passwordRequireSpecial) + assert.False(t, v2.passwordRequireNumber) +} + +// TestValidator_ValidateEmail 测试邮箱验证 +func TestValidator_ValidateEmail(t *testing.T) { + v := NewValidator(8, true, true) + + tests := []struct { + name string + email string + expected bool + }{ + {"empty", "", false}, + {"invalid", "invalid", false}, + {"no at", "test.example.com", false}, + {"no domain", "test@", false}, + {"no user", "@example.com", false}, + {"valid simple", "test@example.com", true}, + {"valid with dot", "test.user@example.com", true}, + {"valid with plus", "test+tag@example.com", true}, + {"valid subdomain", "test@mail.example.com", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := v.ValidateEmail(tt.email) + assert.Equal(t, tt.expected, got) + }) + } +} + +// TestValidator_ValidatePhone 测试手机号验证 +func TestValidator_ValidatePhone(t *testing.T) { + v := NewValidator(8, true, true) + + tests := []struct { + name string + phone string + expected bool + }{ + {"empty", "", false}, + {"invalid format", "12345678901", false}, + {"too short", "1380013800", false}, + {"too long", "138001380001", false}, + {"invalid prefix 1", "12800138000", false}, + {"invalid prefix 2", "10800138000", false}, + {"valid 13x", "13800138000", true}, + {"valid 15x", "15800138000", true}, + {"valid 18x", "18800138000", true}, + {"valid 19x", "19800138000", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := v.ValidatePhone(tt.phone) + assert.Equal(t, tt.expected, got) + }) + } +} + +// TestValidator_ValidateUsername 测试用户名验证 +func TestValidator_ValidateUsername(t *testing.T) { + v := NewValidator(8, true, true) + + tests := []struct { + name string + username string + expected bool + }{ + {"empty", "", false}, + {"too short", "abc", false}, + {"starts with number", "1abc", false}, + {"starts with underscore", "_abc", false}, + {"contains special", "abc@123", false}, + {"valid lowercase", "abc123", true}, + {"valid uppercase", "Abc123", true}, + {"valid with underscore", "abc_123", true}, + {"valid max length", "abcd1234abcd1234abcd", true}, + {"too long", "abcd1234abcd1234abcd1", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := v.ValidateUsername(tt.username) + assert.Equal(t, tt.expected, got) + }) + } +} + +// TestValidator_ValidatePassword 测试密码验证 +func TestValidator_ValidatePassword(t *testing.T) { + v := NewValidator(8, true, true) + + tests := []struct { + name string + password string + expected bool + }{ + {"too short", "Abc1!", false}, + {"no number", "Abcdefgh!", false}, + {"no special", "Abcdefgh1", false}, + {"no uppercase", "abcdefgh1!", false}, + {"no lowercase", "ABCDEFGH1!", false}, + {"valid complex", "Abcdef1!", true}, + {"valid longer", "Abcdefgh123!", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := v.ValidatePassword(tt.password) + assert.Equal(t, tt.expected, got) + }) + } +} + +// TestValidator_ValidateURL 测试 URL 验证 +func TestValidator_ValidateURL(t *testing.T) { + v := NewValidator(8, true, true) + + tests := []struct { + name string + url string + expected bool + }{ + {"empty", "", false}, + {"no scheme", "example.com", false}, + {"http", "http://example.com", true}, + {"https", "https://example.com", true}, + {"with path", "https://example.com/path", true}, + {"with query", "https://example.com?foo=bar", true}, + {"with fragment", "https://example.com#section", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := v.ValidateURL(tt.url) + assert.Equal(t, tt.expected, got) + }) + } +} + +// TestValidator_ValidateIP 测试 IP 验证 +func TestValidator_ValidateIP(t *testing.T) { + v := NewValidator(8, true, true) + + tests := []struct { + name string + ip string + expected bool + }{ + {"empty", "", false}, + {"invalid", "not-an-ip", false}, + {"IPv4 valid", "192.168.1.1", true}, + {"IPv4 invalid", "192.168.1.256", false}, + {"IPv6 valid", "::1", true}, + {"IPv6 valid full", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", true}, + {"IPv6 compressed", "fe80::1", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := v.ValidateIP(tt.ip) + assert.Equal(t, tt.expected, got) + }) + } +} + +// TestValidator_ValidateIPv4 测试 IPv4 验证 +func TestValidator_ValidateIPv4(t *testing.T) { + v := NewValidator(8, true, true) + + tests := []struct { + name string + ip string + expected bool + }{ + {"empty", "", false}, + {"IPv4 valid", "192.168.1.1", true}, + {"IPv4 invalid", "192.168.1.256", false}, + {"IPv6 localhost", "::1", false}, // IPv6 should fail IPv4 validation + {"IPv6 full", "2001:0db8:85a3::8a2e:0370:7334", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := v.ValidateIPv4(tt.ip) + assert.Equal(t, tt.expected, got) + }) + } +} + +// TestValidator_ValidateIPv6 测试 IPv6 验证 +func TestValidator_ValidateIPv6(t *testing.T) { + v := NewValidator(8, true, true) + + tests := []struct { + name string + ip string + expected bool + }{ + {"empty", "", false}, + {"IPv4 valid", "192.168.1.1", false}, // IPv4 should fail IPv6 validation + {"IPv6 localhost", "::1", true}, + {"IPv6 full", "2001:0db8:85a3:0000:0000:8a2e:0370:7334", true}, + {"IPv6 compressed", "fe80::1", true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := v.ValidateIPv6(tt.ip) + assert.Equal(t, tt.expected, got) + }) + } +} + +// TestValidator_SanitizeSQL 测试 SQL 净化 +func TestValidator_SanitizeSQL(t *testing.T) { + v := NewValidator(8, true, true) + + tests := []struct { + name string + input string + expected string + }{ + {"empty", "", ""}, + {"normal text", "hello world", "hello world"}, + {"quote escape", "'test'", "''test''"}, + {"backslash escape", "\\test", "\\test"}, + {"remove comment", "select; -- comment", "select "}, + {"remove block comment", "select /* comment */ from", "select from"}, + {"remove union", "select union select", "select "}, + {"remove drop", "drop table users", ""}, + {"remove insert", "insert into users", ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := v.SanitizeSQL(tt.input) + // 检查输出不包含危险模式 + assert.NotContains(t, got, "--") + assert.NotContains(t, got, "/*") + assert.NotContains(t, got, "*/") + }) + } +} + +// TestValidator_SanitizeXSS 测试 XSS 净化 +func TestValidator_SanitizeXSS(t *testing.T) { + v := NewValidator(8, true, true) + + tests := []struct { + name string + input string + checkNot string + }{ + {"empty", "", ""}, + {"normal text", "hello world", ""}, + {"remove script", "", "script"}, + {"remove iframe", "", "iframe"}, + {"remove javascript", "javascript:alert(1)", "javascript:"}, + {"remove event handler", "", "onerror"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := v.SanitizeXSS(tt.input) + if tt.checkNot != "" { + assert.NotContains(t, got, tt.checkNot) + } + }) + } +}