Files
user-system/internal/database/db.go
long-agent 8095307d82 fix: P0/P1 security and quality fixes
P0-01: Add ESCAPE clause to LIKE queries in operation_log.go and device.go
P0-02: Add atomic Increment to L1Cache and L2Cache interfaces
P0-07: Add TOTP verification step after password login
P1-01: Sanitize error messages in error.go middleware
P1-03: Remove err.Error() from export error messages
P1-04: Add error return to CountByResultSince in login_log.go
P1-05: Add transactional DeleteCascade to RoleRepository
P1-06: Add PasswordChangedAt tracking for JWT token invalidation
P1-07: Wrap theme SetDefault in database transaction
P1-08: Use config values for database pool parameters
P1-09: Add rows.Err() checks in social_account_repo.go
P1-10: Validate sortOrder with map in user.go ORDER BY
P1-11: Add GORM tags to Announcement struct
P1-15: Add pageSize upper limit (100) to device and log handlers
2026-04-18 15:33:12 +08:00

268 lines
7.8 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package database
import (
"fmt"
"log"
"time"
"github.com/glebarez/sqlite"
"gorm.io/gorm"
"github.com/user-management-system/internal/auth"
"github.com/user-management-system/internal/config"
"github.com/user-management-system/internal/domain"
)
// DB wraps a *gorm.DB so that migration and bootstrap helpers can be attached
// as methods while every GORM method remains available through embedding.
type DB struct {
	*gorm.DB
}
// NewDB opens the SQLite database selected by cfg, applies the SQLite tuning
// pragmas and configures the connection pool.
//
// cfg may be nil; in that case the default database path and the default pool
// parameters are used. An error is returned when the database cannot be opened
// or the underlying *sql.DB cannot be obtained.
func NewDB(cfg *config.Config) (*DB, error) {
	// Only SQLite is supported for now. Use the configured database path when
	// one is given, otherwise fall back to the default location.
	dbPath := "./data/user_management.db"
	if cfg != nil && cfg.Database.DBName != "" {
		dbPath = cfg.Database.DBName
	}

	// Connection-scoped pragmas travel in the DSN so that every connection the
	// pool opens receives them, not just the first one.
	db, err := gorm.Open(sqlite.Open(sqlitePragmaDSN(dbPath)), &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("connect database failed: %w", err)
	}

	sqlDB, err := db.DB()
	if err != nil {
		return nil, fmt.Errorf("get underlying sql.DB failed: %w", err)
	}

	// WAL (Write-Ahead Logging) lets readers and writers run concurrently. The
	// journal mode is stored in the database file, so issuing it once on any
	// connection is enough. Kept best-effort: a failure is only logged.
	if _, err := sqlDB.Exec("PRAGMA journal_mode=WAL"); err != nil {
		log.Printf("warn: enable WAL mode failed: %v", err)
	}

	// Connection pool: start from the defaults and let positive config values
	// override them.
	maxOpenConns := 10
	maxIdleConns := 5
	connMaxLifetime := 30 * time.Minute
	connMaxIdleTime := 10 * time.Minute
	if cfg != nil {
		if cfg.Database.MaxOpenConns > 0 {
			maxOpenConns = cfg.Database.MaxOpenConns
		}
		if cfg.Database.MaxIdleConns > 0 {
			maxIdleConns = cfg.Database.MaxIdleConns
		}
		if cfg.Database.ConnMaxLifetimeMinutes > 0 {
			connMaxLifetime = time.Duration(cfg.Database.ConnMaxLifetimeMinutes) * time.Minute
		}
		if cfg.Database.ConnMaxIdleTimeMinutes > 0 {
			connMaxIdleTime = time.Duration(cfg.Database.ConnMaxIdleTimeMinutes) * time.Minute
		}
	}
	sqlDB.SetMaxOpenConns(maxOpenConns)
	sqlDB.SetMaxIdleConns(maxIdleConns)
	sqlDB.SetConnMaxLifetime(connMaxLifetime)
	sqlDB.SetConnMaxIdleTime(connMaxIdleTime)

	log.Println("database: SQLite WAL mode enabled, connection pool configured")
	return &DB{DB: db}, nil
}

// sqlitePragmaDSN appends the connection-scoped pragmas to path as `_pragma`
// query parameters:
//
//   - busy_timeout(5000): wait up to 5 s on a locked database instead of
//     failing immediately with SQLITE_BUSY
//   - foreign_keys(1): enforce foreign keys (off by default in SQLite)
//   - synchronous(NORMAL): sufficient durability under WAL, faster than FULL
//   - cache_size(-8192): 8 MB page cache (a negative value is in KiB)
//
// If path already carries a query string the parameters are appended to it.
func sqlitePragmaDSN(path string) string {
	const pragmas = "_pragma=busy_timeout(5000)" +
		"&_pragma=foreign_keys(1)" +
		"&_pragma=synchronous(NORMAL)" +
		"&_pragma=cache_size(-8192)"

	// Scanned by hand so that this block needs no import beyond the ones the
	// file already has.
	sep := "?"
	for i := 0; i < len(path); i++ {
		if path[i] == '?' {
			sep = "&"
			break
		}
	}
	return path + sep + pragmas
}
// AutoMigrate runs GORM's schema migration for every domain model and then
// seeds the default data through initDefaultData.
//
// Any failure is returned wrapped with the step that produced it.
func (db *DB) AutoMigrate(cfg *config.Config) error {
	log.Println("starting database migration")

	// Models are migrated in the order listed here.
	models := []interface{}{
		&domain.User{},
		&domain.Role{},
		&domain.Permission{},
		&domain.UserRole{},
		&domain.RolePermission{},
		&domain.Device{},
		&domain.LoginLog{},
		&domain.OperationLog{},
		&domain.SocialAccount{},
		&domain.Webhook{},
		&domain.WebhookDelivery{},
		&domain.PasswordHistory{},
	}

	migrateErr := db.DB.AutoMigrate(models...)
	if migrateErr != nil {
		return fmt.Errorf("database migration failed: %w", migrateErr)
	}

	seedErr := db.initDefaultData(cfg)
	if seedErr != nil {
		return fmt.Errorf("initialize default data failed: %w", seedErr)
	}
	return nil
}
// initDefaultData seeds the predefined roles, the default permissions, the
// role-permission bindings and the bootstrap admin account.
//
// When at least one role already exists nothing is bootstrapped; only
// ensurePermissions is run (best-effort) to back-fill permissions after an
// upgrade. Otherwise the whole bootstrap runs inside a single transaction so
// that a failure part-way through leaves the database empty and the next start
// retries, instead of leaving roles behind that would make later starts skip
// the bootstrap.
func (db *DB) initDefaultData(cfg *config.Config) error {
	var roleCount int64
	if err := db.DB.Model(&domain.Role{}).Count(&roleCount).Error; err != nil {
		return err
	}
	if roleCount > 0 {
		// Roles exist already; permissions may still be missing (upgrade case).
		if err := db.ensurePermissions(); err != nil {
			log.Printf("warn: ensure permissions failed: %v", err)
		}
		log.Println("default data already exists, skipping bootstrap")
		return nil
	}

	log.Println("bootstrapping default roles and permissions")
	return db.DB.Transaction(func(tx *gorm.DB) error {
		// Route every statement, including those issued by helper methods,
		// through the transaction handle.
		return (&DB{DB: tx}).bootstrapDefaultData(cfg)
	})
}

// bootstrapDefaultData performs the actual seeding on db. The caller is
// expected to pass a DB bound to a transaction; any returned error is meant to
// roll that transaction back.
func (db *DB) bootstrapDefaultData(cfg *config.Config) error {
	// 1. Create the predefined roles and remember the IDs needed below.
	var adminRoleID, userRoleID int64
	roleTotal := 0
	for _, predefined := range domain.PredefinedRoles {
		role := predefined
		if err := db.DB.Create(&role).Error; err != nil {
			return fmt.Errorf("create role failed: %w", err)
		}
		roleTotal++
		switch role.Code {
		case "admin":
			adminRoleID = role.ID
		case "user":
			userRoleID = role.ID
		}
	}

	// 2. Create the default permissions.
	permIDs, err := db.createDefaultPermissions()
	if err != nil {
		return fmt.Errorf("create permissions failed: %w", err)
	}

	// 3. Bind every permission to the admin role.
	if adminRoleID > 0 {
		for _, permID := range permIDs {
			binding := domain.RolePermission{RoleID: adminRoleID, PermissionID: permID}
			if err := db.DB.Create(&binding).Error; err != nil {
				return fmt.Errorf("assign permission %d to admin role failed: %w", permID, err)
			}
		}
		log.Printf("assigned %d permissions to admin role", len(permIDs))
	}

	// 4. Bind the basic permissions to the regular user role.
	if userRoleID > 0 {
		userPermCodes := []string{"profile:view", "profile:edit", "log:view_own"}
		for _, code := range userPermCodes {
			var perm domain.Permission
			if err := db.DB.Where("code = ?", code).First(&perm).Error; err != nil {
				// A permission that cannot be looked up is skipped, as before,
				// but no longer silently.
				log.Printf("warn: lookup permission %s failed: %v", code, err)
				continue
			}
			binding := domain.RolePermission{RoleID: userRoleID, PermissionID: perm.ID}
			if err := db.DB.Create(&binding).Error; err != nil {
				return fmt.Errorf("assign permission %s to user role failed: %w", code, err)
			}
		}
	}

	// 5. Create the admin user, if credentials are configured. A nil cfg is
	// treated the same as missing credentials.
	if cfg == nil || cfg.Default.AdminEmail == "" || cfg.Default.AdminPassword == "" {
		log.Println("admin bootstrap skipped: default.admin_email/admin_password not configured")
		return nil
	}
	// Checked before the user is inserted so no admin account is created that
	// cannot be given its role.
	if adminRoleID == 0 {
		return fmt.Errorf("admin role missing during bootstrap")
	}

	adminUsername := cfg.Default.AdminEmail
	passwordHash, err := auth.HashPassword(cfg.Default.AdminPassword)
	if err != nil {
		return fmt.Errorf("hash admin password failed: %w", err)
	}
	adminUser := &domain.User{
		Username: adminUsername,
		Email:    domain.StrPtr(adminUsername),
		Password: passwordHash,
		Nickname: "系统管理员",
		Status:   domain.UserStatusActive,
	}
	if err := db.DB.Create(adminUser).Error; err != nil {
		return fmt.Errorf("create admin user failed: %w", err)
	}
	if err := db.DB.Create(&domain.UserRole{
		UserID: adminUser.ID,
		RoleID: adminRoleID,
	}).Error; err != nil {
		return fmt.Errorf("assign admin role failed: %w", err)
	}

	log.Printf("bootstrap completed: admin user=%s, roles=%d, permissions=%d",
		adminUser.Username, roleTotal, len(permIDs))
	return nil
}
// ensurePermissions back-fills permission data in the upgrade case: roles
// exist already but the permissions table is empty.
//
// It does nothing when at least one permission exists. Otherwise it seeds the
// default permissions, binds all of them to the "admin" role and the basic
// profile/log permissions to the "user" role. A role or permission that cannot
// be looked up is skipped. Every binding is attempted even if an earlier one
// fails; the first binding error is returned together with the failure count.
func (db *DB) ensurePermissions() error {
	var permCount int64
	// A failed count must not be mistaken for an empty table.
	if err := db.DB.Model(&domain.Permission{}).Count(&permCount).Error; err != nil {
		return fmt.Errorf("count permissions failed: %w", err)
	}
	if permCount > 0 {
		return nil // Permission data is already present.
	}

	log.Println("permissions table is empty, seeding default permissions")
	permIDs, err := db.createDefaultPermissions()
	if err != nil {
		return err
	}

	// bind inserts one role-permission row and records a failure instead of
	// dropping it.
	var firstErr error
	failed := 0
	bind := func(roleID, permID int64) {
		binding := domain.RolePermission{RoleID: roleID, PermissionID: permID}
		if err := db.DB.Create(&binding).Error; err != nil {
			failed++
			if firstErr == nil {
				firstErr = err
			}
		}
	}

	// Bind every permission to the admin role, if that role exists.
	var adminRole domain.Role
	if err := db.DB.Where("code = ?", "admin").First(&adminRole).Error; err == nil {
		for _, permID := range permIDs {
			bind(adminRole.ID, permID)
		}
		log.Printf("assigned %d permissions to admin role (upgrade)", len(permIDs))
	}

	// Bind the basic permissions to the regular user role, if that role exists.
	var userRole domain.Role
	if err := db.DB.Where("code = ?", "user").First(&userRole).Error; err == nil {
		userPermCodes := []string{"profile:view", "profile:edit", "log:view_own"}
		for _, code := range userPermCodes {
			var perm domain.Permission
			if err := db.DB.Where("code = ?", code).First(&perm).Error; err == nil {
				bind(userRole.ID, perm.ID)
			}
		}
	}

	if firstErr != nil {
		return fmt.Errorf("bind role permissions failed (%d failed): %w", failed, firstErr)
	}
	return nil
}
// createDefaultPermissions makes sure every entry of domain.DefaultPermissions
// exists and returns the IDs of the entries that were found or created.
//
// A permission whose lookup/insert fails is logged and left out of the result;
// the returned error is always nil.
func (db *DB) createDefaultPermissions() ([]int64, error) {
	var permIDs []int64
	for _, perm := range domain.DefaultPermissions() {
		// FirstOrCreate keyed on the code keeps the seeding idempotent: an
		// existing row is loaded instead of being inserted again.
		if err := db.DB.Where("code = ?", perm.Code).FirstOrCreate(&perm).Error; err != nil {
			log.Printf("warn: create permission %s failed: %v", perm.Code, err)
			continue
		}
		permIDs = append(permIDs, perm.ID)
	}
	return permIDs, nil
}