// Package database opens the application's SQLite database through GORM,
// migrates the schema and seeds the default roles, permissions and admin user.
package database

import (
	"errors"
	"fmt"
	"log"

	"github.com/glebarez/sqlite"
	"gorm.io/gorm"

	"github.com/user-management-system/internal/auth"
	"github.com/user-management-system/internal/config"
	"github.com/user-management-system/internal/domain"
)

// defaultUserPermissionCodes lists the permission codes bound to the built-in
// "user" role when default data is seeded.
var defaultUserPermissionCodes = []string{"profile:view", "profile:edit", "log:view_own"}

// DB wraps *gorm.DB so that migration and seeding helpers can be attached to it.
type DB struct {
	*gorm.DB
}

// NewDB opens the SQLite database and returns it wrapped in a DB.
//
// Only SQLite is supported at the moment. The database path is taken from
// cfg.Database.DBName when cfg is non-nil and the value is non-empty;
// otherwise "./data/user_management.db" is used.
func NewDB(cfg *config.Config) (*DB, error) {
	dbPath := "./data/user_management.db"
	if cfg != nil && cfg.Database.DBName != "" {
		dbPath = cfg.Database.DBName
	}

	db, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
	if err != nil {
		return nil, fmt.Errorf("connect database failed: %w", err)
	}
	return &DB{DB: db}, nil
}

// AutoMigrate migrates the schema for every domain model and then seeds the
// default data (roles, permissions and, when configured, the admin user).
func (db *DB) AutoMigrate(cfg *config.Config) error {
	log.Println("starting database migration")
	if err := db.DB.AutoMigrate(
		&domain.User{},
		&domain.Role{},
		&domain.Permission{},
		&domain.UserRole{},
		&domain.RolePermission{},
		&domain.Device{},
		&domain.LoginLog{},
		&domain.OperationLog{},
		&domain.SocialAccount{},
		&domain.Webhook{},
		&domain.WebhookDelivery{},
		&domain.PasswordHistory{},
	); err != nil {
		return fmt.Errorf("database migration failed: %w", err)
	}

	if err := db.initDefaultData(cfg); err != nil {
		return fmt.Errorf("initialize default data failed: %w", err)
	}
	return nil
}

// initDefaultData seeds default roles, permissions and the admin user when the
// roles table is empty. When roles already exist it only makes sure permission
// data is present (upgrade scenario) and returns.
//
// The bootstrap runs inside a single transaction: because "roles exist" is the
// marker used to skip bootstrapping, a partially applied bootstrap would never
// be retried on the next start.
func (db *DB) initDefaultData(cfg *config.Config) error {
	var count int64
	if err := db.DB.Model(&domain.Role{}).Count(&count).Error; err != nil {
		return err
	}

	if count > 0 {
		// Roles already exist; permission data may still be missing after an upgrade.
		// A failure here is reported but does not abort startup.
		if err := db.ensurePermissions(); err != nil {
			log.Printf("warn: ensure permissions failed: %v", err)
		}
		log.Println("default data already exists, skipping bootstrap")
		return nil
	}

	log.Println("bootstrapping default roles and permissions")
	return db.DB.Transaction(func(tx *gorm.DB) error {
		return bootstrapDefaultData(tx, cfg)
	})
}

// bootstrapDefaultData creates the predefined roles, the default permissions,
// the role-permission bindings and, when credentials are configured, the admin
// user. Every statement goes through tx so the caller can roll back as a unit.
func bootstrapDefaultData(tx *gorm.DB, cfg *config.Config) error {
	// 1. Create roles, remembering the IDs of the two built-in ones.
	var adminRoleID, userRoleID int64
	rolesCreated := 0
	for _, predefined := range domain.PredefinedRoles {
		role := predefined
		if err := tx.Create(&role).Error; err != nil {
			return fmt.Errorf("create role failed: %w", err)
		}
		rolesCreated++
		switch role.Code {
		case "admin":
			adminRoleID = role.ID
		case "user":
			userRoleID = role.ID
		}
	}

	// 2. Create permissions (best effort per permission).
	permIDs := seedDefaultPermissions(tx)

	// 3. Bind every permission to the admin role.
	if adminRoleID > 0 {
		if err := grantPermissionsByID(tx, adminRoleID, permIDs); err != nil {
			return fmt.Errorf("assign admin permissions failed: %w", err)
		}
		log.Printf("assigned %d permissions to admin role", len(permIDs))
	}

	// 4. Bind the basic permissions to the regular user role.
	if userRoleID > 0 {
		if err := grantPermissionsByCode(tx, userRoleID, defaultUserPermissionCodes); err != nil {
			return fmt.Errorf("assign user permissions failed: %w", err)
		}
	}

	// 5. Create the admin user. A nil config is treated like missing credentials.
	if cfg == nil || cfg.Default.AdminEmail == "" || cfg.Default.AdminPassword == "" {
		log.Println("admin bootstrap skipped: default.admin_email/admin_password not configured")
		return nil
	}
	// Check the role before inserting the user so no role-less admin is written.
	if adminRoleID == 0 {
		return errors.New("admin role missing during bootstrap")
	}

	adminUsername := cfg.Default.AdminEmail
	passwordHash, err := auth.HashPassword(cfg.Default.AdminPassword)
	if err != nil {
		return fmt.Errorf("hash admin password failed: %w", err)
	}

	adminUser := &domain.User{
		Username: adminUsername,
		Email:    domain.StrPtr(adminUsername),
		Password: passwordHash,
		Nickname: "系统管理员",
		Status:   domain.UserStatusActive,
	}
	if err := tx.Create(adminUser).Error; err != nil {
		return fmt.Errorf("create admin user failed: %w", err)
	}

	if err := tx.Create(&domain.UserRole{
		UserID: adminUser.ID,
		RoleID: adminRoleID,
	}).Error; err != nil {
		return fmt.Errorf("assign admin role failed: %w", err)
	}

	log.Printf("bootstrap completed: admin user=%s, roles=%d, permissions=%d",
		adminUser.Username, rolesCreated, len(permIDs))
	return nil
}

// ensurePermissions seeds the default permissions and their role bindings when
// the permissions table is empty (upgrade scenario). It does nothing when any
// permission row already exists.
//
// Seeding and binding share one transaction: "permissions exist" is the marker
// used to skip this step, so a half-applied run would not be retried.
func (db *DB) ensurePermissions() error {
	var permCount int64
	if err := db.DB.Model(&domain.Permission{}).Count(&permCount).Error; err != nil {
		return fmt.Errorf("count permissions failed: %w", err)
	}
	if permCount > 0 {
		return nil // permission data already present
	}

	log.Println("permissions table is empty, seeding default permissions")
	return db.DB.Transaction(func(tx *gorm.DB) error {
		permIDs := seedDefaultPermissions(tx)

		// Bind every permission to the admin role, if that role exists.
		adminRoleID, found, err := lookupRoleID(tx, "admin")
		if err != nil {
			return err
		}
		if found {
			if err := grantPermissionsByID(tx, adminRoleID, permIDs); err != nil {
				return fmt.Errorf("assign admin permissions failed: %w", err)
			}
			log.Printf("assigned %d permissions to admin role (upgrade)", len(permIDs))
		}

		// Bind the basic permissions to the regular user role, if that role exists.
		userRoleID, found, err := lookupRoleID(tx, "user")
		if err != nil {
			return err
		}
		if found {
			if err := grantPermissionsByCode(tx, userRoleID, defaultUserPermissionCodes); err != nil {
				return fmt.Errorf("assign user permissions failed: %w", err)
			}
		}
		return nil
	})
}

// createDefaultPermissions creates the default permission list and returns the
// IDs of all permissions that exist afterwards. Creation is best effort: a
// permission that fails is logged and skipped, so the returned error is always
// nil; it is kept for signature compatibility.
func (db *DB) createDefaultPermissions() ([]int64, error) {
	return seedDefaultPermissions(db.DB), nil
}

// seedDefaultPermissions inserts each entry of domain.DefaultPermissions that
// is not present yet (matched by code) and returns the IDs of the permissions
// that were found or created. Failures are logged and skipped.
func seedDefaultPermissions(tx *gorm.DB) []int64 {
	permissions := domain.DefaultPermissions()
	ids := make([]int64, 0, len(permissions))
	for i := range permissions {
		p := permissions[i]
		// FirstOrCreate keeps the insert idempotent.
		if err := tx.Where("code = ?", p.Code).FirstOrCreate(&p).Error; err != nil {
			log.Printf("warn: create permission %s failed: %v", p.Code, err)
			continue
		}
		ids = append(ids, p.ID)
	}
	return ids
}

// lookupRoleID returns the ID of the role with the given code. The boolean is
// false, with a nil error, when no such role exists.
func lookupRoleID(tx *gorm.DB, code string) (int64, bool, error) {
	var role domain.Role
	if err := tx.Where("code = ?", code).First(&role).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return 0, false, nil
		}
		return 0, false, fmt.Errorf("load role %s failed: %w", code, err)
	}
	return role.ID, true, nil
}

// grantPermission binds one permission to one role. The binding is written
// with FirstOrCreate so an already existing row is not an error.
func grantPermission(tx *gorm.DB, roleID, permID int64) error {
	var binding domain.RolePermission
	err := tx.Where(domain.RolePermission{RoleID: roleID, PermissionID: permID}).
		FirstOrCreate(&binding).Error
	if err != nil {
		return fmt.Errorf("bind permission %d to role %d failed: %w", permID, roleID, err)
	}
	return nil
}

// grantPermissionsByID binds every permission in permIDs to the role.
func grantPermissionsByID(tx *gorm.DB, roleID int64, permIDs []int64) error {
	for _, permID := range permIDs {
		if err := grantPermission(tx, roleID, permID); err != nil {
			return err
		}
	}
	return nil
}

// grantPermissionsByCode binds the permissions identified by codes to the
// role. A code with no matching permission row is logged and skipped; any
// other lookup or write error is returned.
func grantPermissionsByCode(tx *gorm.DB, roleID int64, codes []string) error {
	for _, code := range codes {
		var perm domain.Permission
		err := tx.Where("code = ?", code).First(&perm).Error
		if errors.Is(err, gorm.ErrRecordNotFound) {
			log.Printf("warn: permission %s not found, skipping", code)
			continue
		}
		if err != nil {
			return fmt.Errorf("load permission %s failed: %w", code, err)
		}
		if err := grantPermission(tx, roleID, perm.ID); err != nil {
			return err
		}
	}
	return nil
}