package repository

import (
	"context"
	"errors"
	"time"

	"gorm.io/gorm"

	"github.com/user-management-system/internal/domain"
)

const (
	// activeDeviceWindow is how recently a device must have been active to be
	// returned by GetActiveDevices.
	activeDeviceWindow = 30 * 24 * time.Hour

	// deviceOrderByID gives paginated listings without a natural sort key a
	// stable order, so consecutive pages neither repeat nor skip rows.
	deviceOrderByID = "id ASC"

	// deviceOrderByActivity lists the most recently active devices first. The
	// id tiebreaker keeps devices that share a last_active_time in a stable
	// order across pages.
	deviceOrderByActivity = "last_active_time DESC, id DESC"
)

// DeviceRepository is the data-access layer for devices, backed by GORM.
type DeviceRepository struct {
	db *gorm.DB
}

// NewDeviceRepository returns a DeviceRepository that issues its queries
// through db.
func NewDeviceRepository(db *gorm.DB) *DeviceRepository {
	return &DeviceRepository{db: db}
}

// Create inserts device inside a single transaction.
//
// GORM omits zero values on insert for fields with DB defaults, so a device
// requested with DeviceStatusInactive (status=0) would otherwise be stored
// with the column's default status. In that case the status is written
// explicitly by a follow-up UPDATE in the same transaction, and device.Status
// is set back to the requested value.
func (r *DeviceRepository) Create(ctx context.Context, device *domain.Device) error {
	requestedStatus := device.Status
	return r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
		if err := tx.Create(device).Error; err != nil {
			return err
		}
		if requestedStatus != domain.DeviceStatusInactive {
			return nil
		}
		if err := tx.Model(&domain.Device{}).Where("id = ?", device.ID).Update("status", requestedStatus).Error; err != nil {
			return err
		}
		device.Status = requestedStatus
		return nil
	})
}

// Update writes all fields of device using GORM's Save.
func (r *DeviceRepository) Update(ctx context.Context, device *domain.Device) error {
	return r.db.WithContext(ctx).Save(device).Error
}

// Delete removes the device with primary key id. No error is returned when
// no row matches.
// NOTE(review): this is a soft delete if domain.Device has a gorm.DeletedAt
// field — confirm against the model.
func (r *DeviceRepository) Delete(ctx context.Context, id int64) error {
	return r.db.WithContext(ctx).Delete(&domain.Device{}, id).Error
}

// GetByID returns the device with primary key id. When no row matches, the
// error from GORM's First (gorm.ErrRecordNotFound) is returned unchanged.
func (r *DeviceRepository) GetByID(ctx context.Context, id int64) (*domain.Device, error) {
	var device domain.Device
	if err := r.db.WithContext(ctx).First(&device, id).Error; err != nil {
		return nil, err
	}
	return &device, nil
}

// GetByDeviceID returns the device owned by userID whose device_id column
// equals deviceID (distinct from the numeric primary key). When no row
// matches, gorm.ErrRecordNotFound is returned unchanged.
func (r *DeviceRepository) GetByDeviceID(ctx context.Context, userID int64, deviceID string) (*domain.Device, error) {
	var device domain.Device
	err := r.db.WithContext(ctx).
		Where("user_id = ? AND device_id = ?", userID, deviceID).
		First(&device).Error
	if err != nil {
		return nil, err
	}
	return &device, nil
}

// List returns one page of all devices ordered by id, together with the
// total number of devices.
func (r *DeviceRepository) List(ctx context.Context, offset, limit int) ([]*domain.Device, int64, error) {
	query := r.db.WithContext(ctx).Model(&domain.Device{})
	return findDevicePage(query, offset, limit, deviceOrderByID)
}

// ListByUserID returns one page of userID's devices, most recently active
// first, together with the total number of devices the user owns.
func (r *DeviceRepository) ListByUserID(ctx context.Context, userID int64, offset, limit int) ([]*domain.Device, int64, error) {
	query := r.db.WithContext(ctx).Model(&domain.Device{}).Where("user_id = ?", userID)
	return findDevicePage(query, offset, limit, deviceOrderByActivity)
}

// ListByStatus returns one page of devices with the given status ordered by
// id, together with the total number of devices in that status.
func (r *DeviceRepository) ListByStatus(ctx context.Context, status domain.DeviceStatus, offset, limit int) ([]*domain.Device, int64, error) {
	query := r.db.WithContext(ctx).Model(&domain.Device{}).Where("status = ?", status)
	return findDevicePage(query, offset, limit, deviceOrderByID)
}

// UpdateStatus sets the status column of the device with primary key id.
// No error is returned when id matches no row.
func (r *DeviceRepository) UpdateStatus(ctx context.Context, id int64, status domain.DeviceStatus) error {
	return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", id).Update("status", status).Error
}

// UpdateLastActiveTime sets last_active_time of the device with primary key
// id to the current time. No error is returned when id matches no row.
func (r *DeviceRepository) UpdateLastActiveTime(ctx context.Context, id int64) error {
	return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", id).Update("last_active_time", time.Now()).Error
}

// Exists reports whether userID owns a device whose device_id column equals
// deviceID. On a query error it returns false together with the error.
func (r *DeviceRepository) Exists(ctx context.Context, userID int64, deviceID string) (bool, error) {
	var count int64
	err := r.db.WithContext(ctx).Model(&domain.Device{}).
		Where("user_id = ? AND device_id = ?", userID, deviceID).
		Count(&count).Error
	if err != nil {
		return false, err
	}
	return count > 0, nil
}

// DeleteByUserID deletes every device owned by userID.
func (r *DeviceRepository) DeleteByUserID(ctx context.Context, userID int64) error {
	return r.db.WithContext(ctx).Where("user_id = ?", userID).Delete(&domain.Device{}).Error
}

// GetActiveDevices returns userID's devices whose last_active_time falls
// within activeDeviceWindow (30 days) of now, most recently active first.
// The result is not paginated.
func (r *DeviceRepository) GetActiveDevices(ctx context.Context, userID int64) ([]*domain.Device, error) {
	since := time.Now().Add(-activeDeviceWindow)
	var devices []*domain.Device
	err := r.db.WithContext(ctx).
		Where("user_id = ? AND last_active_time > ?", userID, since).
		Order(deviceOrderByActivity).
		Find(&devices).Error
	if err != nil {
		return nil, err
	}
	return devices, nil
}

// TrustDevice marks the device with primary key deviceID as trusted until
// expiresAt. A nil expiresAt is written as NULL, which GetTrustedDevices
// treats as a trust that never expires.
func (r *DeviceRepository) TrustDevice(ctx context.Context, deviceID int64, expiresAt *time.Time) error {
	updates := map[string]interface{}{
		"is_trusted":       true,
		"trust_expires_at": expiresAt,
	}
	return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", deviceID).Updates(updates).Error
}

// UntrustDevice clears the trusted flag of the device with primary key
// deviceID and resets its trust expiry to NULL.
func (r *DeviceRepository) UntrustDevice(ctx context.Context, deviceID int64) error {
	updates := map[string]interface{}{
		"is_trusted":       false,
		"trust_expires_at": nil,
	}
	return r.db.WithContext(ctx).Model(&domain.Device{}).Where("id = ?", deviceID).Updates(updates).Error
}

// DeleteAllByUserIDExcept deletes every device owned by userID except the one
// whose primary key is exceptDeviceID.
func (r *DeviceRepository) DeleteAllByUserIDExcept(ctx context.Context, userID int64, exceptDeviceID int64) error {
	return r.db.WithContext(ctx).
		Where("user_id = ? AND id != ?", userID, exceptDeviceID).
		Delete(&domain.Device{}).Error
}

// GetTrustedDevices returns userID's devices that are trusted and whose trust
// has not expired (trust_expires_at is NULL or in the future), most recently
// active first. The result is not paginated.
func (r *DeviceRepository) GetTrustedDevices(ctx context.Context, userID int64) ([]*domain.Device, error) {
	var devices []*domain.Device
	err := r.db.WithContext(ctx).
		Where("user_id = ? AND is_trusted = ? AND (trust_expires_at IS NULL OR trust_expires_at > ?)", userID, true, time.Now()).
		Order(deviceOrderByActivity).
		Find(&devices).Error
	if err != nil {
		return nil, err
	}
	return devices, nil
}

// ListDevicesParams holds the optional filters and pagination for ListAll.
type ListDevicesParams struct {
	UserID    int64               // filter by owner when > 0
	Status    domain.DeviceStatus // filter by status when >= 0; see ListAll
	IsTrusted *bool               // filter by trust flag when non-nil
	Keyword   string              // substring match on device_name, ip or location when non-empty
	Offset    int                 // number of rows to skip
	Limit     int                 // maximum rows to return (GORM: -1 disables the limit)
}

// ListAll returns one page of devices matching params, most recently active
// first, together with the total number of matching devices. A nil params is
// rejected with an error.
//
// Each filter is applied only when set: UserID > 0, Status >= 0, IsTrusted
// non-nil, Keyword non-empty.
//
// NOTE(review): the zero value of ListDevicesParams has Status 0
// (DeviceStatusInactive), which passes the >= 0 check, so callers wanting all
// statuses must pass a negative Status. Confirm domain.DeviceStatus is signed;
// if it is unsigned, the status filter is always applied.
// NOTE(review): '%' and '_' inside Keyword are not escaped and act as LIKE
// wildcards.
func (r *DeviceRepository) ListAll(ctx context.Context, params *ListDevicesParams) ([]*domain.Device, int64, error) {
	if params == nil {
		return nil, 0, errors.New("list devices: nil params")
	}

	query := r.db.WithContext(ctx).Model(&domain.Device{})
	if params.UserID > 0 {
		query = query.Where("user_id = ?", params.UserID)
	}
	if params.Status >= 0 {
		query = query.Where("status = ?", params.Status)
	}
	if params.IsTrusted != nil {
		query = query.Where("is_trusted = ?", *params.IsTrusted)
	}
	if params.Keyword != "" {
		// Parenthesised explicitly so the OR group can never bind with the
		// other filters' ANDs, independent of how the builder joins clauses.
		pattern := "%" + params.Keyword + "%"
		query = query.Where("(device_name LIKE ? OR ip LIKE ? OR location LIKE ?)", pattern, pattern, pattern)
	}
	return findDevicePage(query, params.Offset, params.Limit, deviceOrderByActivity)
}

// findDevicePage counts the rows matched by query, then loads up to limit of
// them after skipping offset, sorted by order. The count runs before ORDER BY
// is attached, so it stays a plain COUNT.
func findDevicePage(query *gorm.DB, offset, limit int, order string) ([]*domain.Device, int64, error) {
	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}
	var devices []*domain.Device
	if err := query.Offset(offset).Limit(limit).Order(order).Find(&devices).Error; err != nil {
		return nil, 0, err
	}
	return devices, total, nil
}