Files
tokens-reef/backend/internal/setup/setup_security_test.go

355 lines
11 KiB
Go

package setup
import (
"fmt"
"os"
"path/filepath"
"strings"
"testing"
"github.com/stretchr/testify/assert"
)
// =============================================================================
// Test: setup.go — quoteIdentifier SQL Injection Prevention
// Verifies that PostgreSQL identifier quoting correctly defends against SQL injection.
// =============================================================================
// TestQuoteIdentifier_SQLInjectionDefense pins the exact output of
// quoteIdentifier for ordinary names and for injection attempts. Every input
// is wrapped in double quotes, and any embedded double quote is doubled so it
// cannot terminate the identifier early.
func TestQuoteIdentifier_SQLInjectionDefense(t *testing.T) {
	t.Parallel()

	type quoteCase struct {
		label string // subtest name
		raw   string // value passed to quoteIdentifier
		want  string // exact expected quoted form
		why   string // rationale echoed in the failure message
	}

	cases := []quoteCase{
		{
			label: "normal identifier",
			raw:   "mydatabase",
			want:  `"mydatabase"`,
			why:   "Normal database name should be quoted as-is",
		},
		{
			label: "identifier with underscores",
			raw:   "my_db_name",
			want:  `"my_db_name"`,
			why:   "Underscores are valid in identifiers",
		},
		{
			label: "identifier with numbers",
			raw:   "db123",
			want:  `"db123"`,
			why:   "Numbers after first char are valid",
		},
		{
			label: "identifier starting with number",
			raw:   "123db",
			want:  `"123db"`,
			why:   "Numbers at start need quoting but are valid",
		},
		{
			label: "SQL injection via double quote escape",
			raw:   `mydb"; DROP TABLE users; --`,
			want:  `"mydb""; DROP TABLE users; --"`,
			why:   "Double quotes must be escaped by doubling to prevent injection",
		},
		{
			label: "SQL injection single double quote",
			raw:   `foo"bar`,
			want:  `"foo""bar"`,
			why:   "Single internal double quote gets doubled",
		},
		{
			label: "SQL injection multiple double quotes",
			raw:   `a"b"c"d"e`,
			want:  `"a""b""c""d""e"`,
			why:   "All double quotes must be escaped",
		},
		{
			label: "empty string produces empty quoted",
			raw:   "",
			want:  `""`,
			why:   "Empty input becomes empty quoted identifier",
		},
		{
			label: "SQL injection UNION attack",
			raw:   `db" UNION SELECT * FROM secrets --`,
			want:  `"db"" UNION SELECT * FROM secrets --"`,
			why:   "UNION injection attempt neutralized by quote escaping",
		},
		{
			label: "SQL injection with semicolon and comment",
			raw:   `test; SELECT 1--`,
			want:  `"test; SELECT 1--"`,
			why:   "Semicolons and comments inside quotes are literal text, not SQL syntax",
		},
		{
			label: "whitespace is preserved inside quotes",
			raw:   `my db name`,
			want:  `"my db name"`,
			why:   "Spaces inside quoted identifiers are preserved",
		},
		{
			label: "special characters preserved",
			raw:   `my-db.name$v2.0`,
			want:  `"my-db.name$v2.0"`,
			why:   "Non-quote special characters pass through (PostgreSQL allows these)",
		},
	}

	for _, c := range cases {
		c := c // per-iteration copy for the parallel closure (needed before Go 1.22)
		t.Run(c.label, func(t *testing.T) {
			t.Parallel()
			out := quoteIdentifier(c.raw)
			assert.Equal(t, c.want, out,
				"quoteIdentifier(%q): got %q, want %q — %s", c.raw, out, c.want, c.why)
		})
	}
}
// TestQuoteIdentifier_SafetyInvariant checks structural properties of
// quoteIdentifier's output that must hold for any input, independent of the
// exact expected string:
//
//  1. the output is wrapped in a leading and trailing double quote;
//  2. every double quote between the wrappers belongs to a doubled "" pair,
//     so no quote can close the identifier early;
//  3. un-doubling the inner content reproduces the input exactly, so the
//     function changed nothing except escaping quotes.
//
// Several inputs contain double quotes on purpose: they are the only inputs
// that exercise the escaping path.
func TestQuoteIdentifier_SafetyInvariant(t *testing.T) {
	t.Parallel()
	attackStrings := []string{
		`mydb`,
		`my_db_123`,
		`; COPY users TO '/etc/passwd'; --`,
		`mydb"; DROP TABLE users; --`,
		`db" UNION SELECT * FROM secrets --`,
		`"`,
		`""`,
		`"leading`,
		`trailing"`,
	}
	for _, attack := range attackStrings {
		attack := attack // per-iteration copy for the parallel closure (needed before Go 1.22)
		// Hash-based names keep raw SQL out of subtest names; t.Run de-duplicates collisions.
		safeName := fmt.Sprintf("inv_%d", hashString(attack))
		t.Run(safeName, func(t *testing.T) {
			t.Parallel()
			quoted := quoteIdentifier(attack)

			// Invariant 1. This is fatal because the slicing below needs both wrapper quotes.
			if len(quoted) < 2 || !strings.HasPrefix(quoted, `"`) || !strings.HasSuffix(quoted, `"`) {
				t.Fatalf("quoteIdentifier(%q) = %q: must start and end with a double quote", attack, quoted)
			}
			inner := quoted[1 : len(quoted)-1]

			// Invariant 2: walk the inner content. Each quote must be immediately
			// followed by a second quote, and that pair is consumed as one unit. A
			// quote as the last inner byte is unescaped, because it would combine
			// with the closing wrapper.
			for j := 0; j < len(inner); j++ {
				if inner[j] != '"' {
					continue
				}
				if j+1 >= len(inner) || inner[j+1] != '"' {
					t.Fatalf("quoteIdentifier(%q) = %q: unescaped double quote at inner position %d",
						attack, quoted, j)
				}
				j++ // skip the second quote of the escaped pair
			}

			// Invariant 3: reversing the escaping must give back the original input.
			if unescaped := strings.ReplaceAll(inner, `""`, `"`); unescaped != attack {
				t.Errorf("quoteIdentifier(%q) = %q: un-escaping gives %q, want the original input",
					attack, quoted, unescaped)
			}
		})
	}
}
// min returns the smaller of two ints.
//
// NOTE(review): on Go 1.21+ this package-level helper shadows the builtin
// min within package setup. It is not referenced in the visible portion of
// this file; confirm whether other test files in the package use it before
// removing it.
func min(a, b int) int {
	if b < a {
		return b
	}
	return a
}
// hashString maps s to an integer in [0, 9999]. It is used to give
// table-driven subtests short names that are stable across runs.
//
// It is a base-31 polynomial rolling hash over the runes of s. It is not
// collision-resistant and must not be used for anything security-related.
func hashString(s string) int {
	h := 0
	for _, c := range s {
		// Signed overflow wraps in Go; the wrap is harmless for naming purposes.
		h = h*31 + int(c)
	}
	// Reduce before taking the absolute value. Negating math.MinInt overflows
	// back to itself, so the old "abs, then %" order could return a negative
	// result. Go's % truncates toward zero, so (-h)%m == -(h%m), and the
	// result is unchanged for every other h.
	h %= 10000
	if h < 0 {
		h = -h
	}
	return h
}
// =============================================================================
// Test: setup.go — readExistingJWTSecret / JWT Secret Mismatch Detection
// =============================================================================
func TestReadExistingJWTSecret(t *testing.T) {
t.Run("returns empty when no config file exists", func(t *testing.T) {
dir := t.TempDir()
t.Setenv("DATA_DIR", dir)
secret := readExistingJWTSecret()
assert.Empty(t, secret)
})
t.Run("reads jwt.secret from config file", func(t *testing.T) {
dir := t.TempDir()
t.Setenv("DATA_DIR", dir)
configPath := filepath.Join(dir, "config.yaml")
content := []byte(`jwt:
secret: my-test-secret-32-bytes-long-value!!
`)
assert.NoError(t, os.WriteFile(configPath, content, 0o644))
secret := readExistingJWTSecret()
assert.Equal(t, "my-test-secret-32-bytes-long-value!!", secret)
})
t.Run("returns empty for missing jwt.secret key", func(t *testing.T) {
dir := t.TempDir()
t.Setenv("DATA_DIR", dir)
configPath := filepath.Join(dir, "config.yaml")
content := []byte(`server:
port: 8080
`)
assert.NoError(t, os.WriteFile(configPath, content, 0o644))
secret := readExistingJWTSecret()
assert.Empty(t, secret)
})
t.Run("trims whitespace from secret", func(t *testing.T) {
dir := t.TempDir()
t.Setenv("DATA_DIR", dir)
configPath := filepath.Join(dir, "config.yaml")
content := []byte("jwt:\n secret: spaced-secret-32b \n")
assert.NoError(t, os.WriteFile(configPath, content, 0o644))
secret := readExistingJWTSecret()
assert.Equal(t, "spaced-secret-32b", secret)
})
t.Run("returns empty on malformed YAML", func(t *testing.T) {
dir := t.TempDir()
t.Setenv("DATA_DIR", dir)
configPath := filepath.Join(dir, "config.yaml")
content := []byte(`{invalid yaml [[[`)
assert.NoError(t, os.WriteFile(configPath, content, 0o644))
secret := readExistingJWTSecret()
assert.Empty(t, secret, "malformed YAML should return empty secret without error")
})
}
// =============================================================================
// Test: setup.go — AutoSetupFromEnv helpers
// =============================================================================
// TestGetEnvOrDefault checks that getEnvOrDefault returns a set, non-empty
// environment value and otherwise falls back to the default. Both an unset
// key and an empty value count as "otherwise".
func TestGetEnvOrDefault(t *testing.T) {
	t.Run("returns env var value", func(t *testing.T) {
		t.Setenv("TEST_GETENV_KEY", "hello_value")
		assert.Equal(t, "hello_value", getEnvOrDefault("TEST_GETENV_KEY", "default"))
	})
	t.Run("returns default when not set", func(t *testing.T) {
		// t.Setenv first registers a cleanup that restores the key's prior
		// state. The Unsetenv then makes the key truly absent for this subtest
		// only. A bare os.Unsetenv would leak into the rest of the process.
		t.Setenv("TEST_NONEXISTENT_KEY_XYZ", "")
		if err := os.Unsetenv("TEST_NONEXISTENT_KEY_XYZ"); err != nil {
			t.Fatalf("unsetting TEST_NONEXISTENT_KEY_XYZ: %v", err)
		}
		assert.Equal(t, "fallback", getEnvOrDefault("TEST_NONEXISTENT_KEY_XYZ", "fallback"))
	})
	t.Run("returns default for empty string env", func(t *testing.T) {
		t.Setenv("TEST_EMPTY_ENV_KEY", "")
		assert.Equal(t, "fallback", getEnvOrDefault("TEST_EMPTY_ENV_KEY", "fallback"))
	})
}
// TestGetEnvIntOrDefault checks that getEnvIntOrDefault parses a valid integer
// from the environment and falls back to the default when the value is
// unparsable or the key is unset.
func TestGetEnvIntOrDefault(t *testing.T) {
	t.Run("parses valid integer", func(t *testing.T) {
		t.Setenv("TEST_INT_KEY", "5432")
		assert.Equal(t, 5432, getEnvIntOrDefault("TEST_INT_KEY", 0))
	})
	t.Run("returns default for invalid int", func(t *testing.T) {
		t.Setenv("TEST_BAD_INT", "not_a_number")
		assert.Equal(t, 9999, getEnvIntOrDefault("TEST_BAD_INT", 9999))
	})
	t.Run("returns default for empty", func(t *testing.T) {
		// Register restoration via t.Setenv before removing the key, so this
		// subtest does not permanently alter the process environment.
		t.Setenv("TEST_EMPTY_INT_KEY", "")
		if err := os.Unsetenv("TEST_EMPTY_INT_KEY"); err != nil {
			t.Fatalf("unsetting TEST_EMPTY_INT_KEY: %v", err)
		}
		assert.Equal(t, 42, getEnvIntOrDefault("TEST_EMPTY_INT_KEY", 42))
	})
}
// TestAutoSetupEnabled pins how AutoSetupEnabled interprets AUTO_SETUP.
// Per the table, "true", "1" and "yes" enable auto-setup. Matching is
// case-sensitive, so "TRUE" and "Yes" do not.
func TestAutoSetupEnabled(t *testing.T) {
	for _, tc := range []struct {
		value string
		want  bool
	}{
		{"true", true}, {"1", true}, {"yes", true},
		{"false", false}, {"0", false}, {"no", false},
		{"", false}, {"TRUE", false}, {"Yes", false}, // case-sensitive
	} {
		tc := tc // per-iteration copy for the closure (needed before Go 1.22)
		t.Run(fmt.Sprintf("AUTO_SETUP=%q", tc.value), func(t *testing.T) {
			t.Setenv("AUTO_SETUP", tc.value)
			assert.Equal(t, tc.want, AutoSetupEnabled())
		})
	}
}
// =============================================================================
// Test: setup.go — GetDataDir / NeedsSetup
// =============================================================================
// TestGetDataDir_Priority checks GetDataDir's resolution order. An explicit
// DATA_DIR wins. Without it, the function must still return some non-empty
// fallback.
func TestGetDataDir_Priority(t *testing.T) {
	t.Run("DATA_DIR env takes priority", func(t *testing.T) {
		t.Setenv("DATA_DIR", "/custom/data/path")
		assert.Equal(t, "/custom/data/path", GetDataDir())
	})
	t.Run("falls back to current directory when no DATA_DIR and no /app/data", func(t *testing.T) {
		// t.Setenv registers a cleanup that restores DATA_DIR's prior state
		// before the key is removed. A bare os.Unsetenv would permanently clear
		// DATA_DIR (e.g. one set by CI) for every later test in the process.
		t.Setenv("DATA_DIR", "")
		if err := os.Unsetenv("DATA_DIR"); err != nil {
			t.Fatalf("unsetting DATA_DIR: %v", err)
		}
		// /app/data likely doesn't exist on a dev machine, so this reaches the
		// final fallback. The subtest name suggests ".", but only non-emptiness
		// is asserted here.
		assert.NotEmpty(t, GetDataDir())
	})
}
// TestNeedsSetup_WithNoFiles checks that a freshly created, empty DATA_DIR
// (no config.yaml and no .installed marker) reports that setup is required.
func TestNeedsSetup_WithNoFiles(t *testing.T) {
	t.Setenv("DATA_DIR", t.TempDir())
	assert.True(t, NeedsSetup(), "should need setup when no config/lock files exist")
}
// TestNeedsSetup_WithConfigFile checks that setup is considered done once
// config.yaml exists in DATA_DIR. The file content is arbitrary here; the
// test expects the file's presence alone to suppress setup.
func TestNeedsSetup_WithConfigFile(t *testing.T) {
	dataDir := t.TempDir()
	t.Setenv("DATA_DIR", dataDir)

	cfgFile := filepath.Join(dataDir, "config.yaml")
	writeErr := os.WriteFile(cfgFile, []byte("test: data"), 0o644)
	assert.NoError(t, writeErr)

	assert.False(t, NeedsSetup(), "should NOT need setup when config.yaml exists")
}
// TestNeedsSetup_WithLockFile checks that the .installed marker in DATA_DIR
// on its own (without config.yaml) is enough to report setup as complete.
func TestNeedsSetup_WithLockFile(t *testing.T) {
	dataDir := t.TempDir()
	t.Setenv("DATA_DIR", dataDir)

	marker := filepath.Join(dataDir, ".installed")
	writeErr := os.WriteFile(marker, []byte("installed_at=2024"), 0o644)
	assert.NoError(t, writeErr)

	assert.False(t, NeedsSetup(), "should NOT need setup when .installed lock exists")
}
// =============================================================================
// Test: setup.go — generateSecret
// =============================================================================
// TestGenerateSecret checks that generateSecret(n) returns n random bytes
// encoded as lowercase hex (2n characters) and does not repeat across calls.
func TestGenerateSecret(t *testing.T) {
	t.Parallel()

	// lowerHex is the full alphabet a lowercase hex encoding may use.
	const lowerHex = "0123456789abcdef"

	t.Run("generates hex-encoded string of correct length", func(t *testing.T) {
		s, err := generateSecret(16)
		assert.NoError(t, err)
		assert.Len(t, s, 32) // 16 bytes = 32 hex chars
	})
	t.Run("generates different values each call", func(t *testing.T) {
		s1, err1 := generateSecret(16)
		s2, err2 := generateSecret(16)
		// A failed call's result is meaningless, so report the error itself
		// instead of a confusing inequality failure.
		if err1 != nil || err2 != nil {
			t.Fatalf("generateSecret(16) returned errors: %v, %v", err1, err2)
		}
		assert.NotEqual(t, s1, s2)
	})
	t.Run("valid hex characters only", func(t *testing.T) {
		s, err := generateSecret(32)
		if err != nil {
			t.Fatalf("generateSecret(32): %v", err)
		}
		assert.Len(t, s, 64) // 32 bytes = 64 hex chars
		// Trim stops at the first byte outside the cutset, so anything left
		// over means s contains a character that is not lowercase hex.
		assert.Empty(t, strings.Trim(s, lowerHex), "secret %q contains characters outside [0-9a-f]", s)
	})
}