Files
user-system/internal/api/handler/device_handler.go

344 lines
8.4 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package handler
import (
"fmt"
"net/http"
"strconv"
"time"
"github.com/gin-gonic/gin"
"github.com/user-management-system/internal/domain"
"github.com/user-management-system/internal/service"
)
// DeviceHandler serves the HTTP endpoints for device management. Every
// handler method delegates the actual work to the wrapped service.DeviceService.
type DeviceHandler struct {
	deviceService *service.DeviceService // backing service used by all handlers
}

// NewDeviceHandler returns a DeviceHandler that forwards requests to
// deviceService. The pointer is stored as-is; a nil service is not rejected here.
func NewDeviceHandler(deviceService *service.DeviceService) *DeviceHandler {
	h := new(DeviceHandler)
	h.deviceService = deviceService
	return h
}
// CreateDevice registers a new device for the authenticated user.
//
// Responses:
//   - 401 when no user ID can be read from the request context
//   - 400 when the JSON body cannot be bound to service.CreateDeviceRequest
//   - 201 with the created device on success
//
// Errors returned by the service are rendered by handleError.
func (h *DeviceHandler) CreateDevice(c *gin.Context) {
	uid, authed := getUserIDFromContext(c)
	if !authed {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
		return
	}

	var body service.CreateDeviceRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	ctx := c.Request.Context()
	created, svcErr := h.deviceService.CreateDevice(ctx, uid, &body)
	if svcErr != nil {
		handleError(c, svcErr)
		return
	}
	c.JSON(http.StatusCreated, created)
}
// GetMyDevices lists the authenticated user's devices, one page at a time.
//
// Query parameters:
//   - page:      1-based page number (default 1)
//   - page_size: number of devices per page (default 20)
//
// A missing, non-numeric, or non-positive value falls back to its default. The
// service therefore never receives a zero or negative page or page size, which
// previously happened because the strconv.Atoi errors were discarded. The
// effective values are echoed in the response next to the devices and total.
//
// NOTE(review): no upper bound on page_size is enforced here. Presumably the
// service caps it; confirm.
func (h *DeviceHandler) GetMyDevices(c *gin.Context) {
	userID, ok := getUserIDFromContext(c)
	if !ok {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
		return
	}

	page, err := strconv.Atoi(c.DefaultQuery("page", "1"))
	if err != nil || page < 1 {
		page = 1
	}
	pageSize, err := strconv.Atoi(c.DefaultQuery("page_size", "20"))
	if err != nil || pageSize < 1 {
		pageSize = 20
	}

	devices, total, err := h.deviceService.GetUserDevices(c.Request.Context(), userID, page, pageSize)
	if err != nil {
		handleError(c, err)
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"devices":   devices,
		"total":     total,
		"page":      page,
		"page_size": pageSize,
	})
}
// GetDevice returns the single device identified by the numeric :id path
// parameter.
//
// It responds 400 if :id is not a base-10 int64. Otherwise it responds 200
// with the device, or with whatever handleError produces for a service error.
//
// NOTE(review): the lookup does not use the caller's identity, so any caller
// that reaches this route can read any device. Confirm the route is admin-only
// or that the service enforces ownership.
func (h *DeviceHandler) GetDevice(c *gin.Context) {
	rawID := c.Param("id")
	deviceID, parseErr := strconv.ParseInt(rawID, 10, 64)
	if parseErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
		return
	}

	found, svcErr := h.deviceService.GetDevice(c.Request.Context(), deviceID)
	if svcErr != nil {
		handleError(c, svcErr)
		return
	}
	c.JSON(http.StatusOK, found)
}
// UpdateDevice applies the JSON body (service.UpdateDeviceRequest) to the
// device identified by the numeric :id path parameter.
//
// It responds 400 for a malformed :id or an unbindable body. On success it
// responds 200 with the device returned by the service. Service errors go
// through handleError.
//
// NOTE(review): no ownership check against the caller is visible here.
// Confirm the service or the route middleware restricts who may update a
// device.
func (h *DeviceHandler) UpdateDevice(c *gin.Context) {
	rawID := c.Param("id")
	deviceID, parseErr := strconv.ParseInt(rawID, 10, 64)
	if parseErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
		return
	}

	var changes service.UpdateDeviceRequest
	if bindErr := c.ShouldBindJSON(&changes); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	updated, svcErr := h.deviceService.UpdateDevice(c.Request.Context(), deviceID, &changes)
	if svcErr != nil {
		handleError(c, svcErr)
		return
	}
	c.JSON(http.StatusOK, updated)
}
// DeleteDevice removes the device identified by the numeric :id path parameter.
//
// It responds 400 for a malformed :id, and 200 with a confirmation message on
// success. Service errors go through handleError.
//
// NOTE(review): the caller's identity is not consulted here. Confirm ownership
// is enforced by the service or by route middleware.
func (h *DeviceHandler) DeleteDevice(c *gin.Context) {
	deviceID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
		return
	}

	ctx := c.Request.Context()
	if svcErr := h.deviceService.DeleteDevice(ctx, deviceID); svcErr != nil {
		handleError(c, svcErr)
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "device deleted"})
}
// UpdateDeviceStatus switches the device identified by the numeric :id path
// parameter between active and inactive.
//
// The JSON body must contain a required "status" field. It accepts "active" or
// "1" for domain.DeviceStatusActive, and "inactive" or "0" for
// domain.DeviceStatusInactive. Any other value, a malformed :id, or an
// unbindable body yields 400. Service errors go through handleError.
func (h *DeviceHandler) UpdateDeviceStatus(c *gin.Context) {
	deviceID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
		return
	}

	var body struct {
		Status string `json:"status" binding:"required"`
	}
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	// Both the symbolic names and their numeric string forms are accepted.
	statusByName := map[string]domain.DeviceStatus{
		"active":   domain.DeviceStatusActive,
		"1":        domain.DeviceStatusActive,
		"inactive": domain.DeviceStatusInactive,
		"0":        domain.DeviceStatusInactive,
	}
	status, known := statusByName[body.Status]
	if !known {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid status"})
		return
	}

	if svcErr := h.deviceService.UpdateDeviceStatus(c.Request.Context(), deviceID, status); svcErr != nil {
		handleError(c, svcErr)
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "status updated"})
}
// GetUserDevices lists the devices of the user named by the numeric :id path
// parameter, one page at a time.
//
// Unlike GetMyDevices, the target user comes from the URL rather than from the
// request context. No check relating the caller to that user is made here.
// NOTE(review): presumably middleware restricts this route to administrators;
// confirm.
//
// The query parameters page (default 1) and page_size (default 20) fall back
// to their defaults when missing, non-numeric, or less than 1. The
// strconv.Atoi errors were previously discarded, which let 0 or negative values
// reach the service. The effective values are echoed in the response.
func (h *DeviceHandler) GetUserDevices(c *gin.Context) {
	userID, err := strconv.ParseInt(c.Param("id"), 10, 64)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid user id"})
		return
	}

	page, err := strconv.Atoi(c.DefaultQuery("page", "1"))
	if err != nil || page < 1 {
		page = 1
	}
	pageSize, err := strconv.Atoi(c.DefaultQuery("page_size", "20"))
	if err != nil || pageSize < 1 {
		pageSize = 20
	}

	devices, total, err := h.deviceService.GetUserDevices(c.Request.Context(), userID, page, pageSize)
	if err != nil {
		handleError(c, err)
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"devices":   devices,
		"total":     total,
		"page":      page,
		"page_size": pageSize,
	})
}
// GetAllDevices lists devices across all users (intended for administrators).
//
// Filters and pagination are bound from the query string into
// service.GetAllDevicesRequest, and a binding failure yields 400. The response
// echoes the request's Page and PageSize as read after the service call. Any
// normalization the service applies to the request struct is therefore
// presumably reflected; confirm in the service.
func (h *DeviceHandler) GetAllDevices(c *gin.Context) {
	var query service.GetAllDevicesRequest
	if bindErr := c.ShouldBindQuery(&query); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	devices, total, svcErr := h.deviceService.GetAllDevices(c.Request.Context(), &query)
	if svcErr != nil {
		handleError(c, svcErr)
		return
	}

	payload := gin.H{
		"devices":   devices,
		"total":     total,
		"page":      query.Page,
		"page_size": query.PageSize,
	}
	c.JSON(http.StatusOK, payload)
}
// TrustDeviceRequest is the JSON body accepted by TrustDevice and
// TrustDeviceByDeviceID.
type TrustDeviceRequest struct {
	// TrustDuration is how long the device stays trusted, written as a count
	// plus a unit suffix, e.g. "30d" for thirty days. It is optional and
	// converted by parseDuration.
	TrustDuration string `json:"trust_duration"`
}
// TrustDevice marks the device identified by the numeric :id path parameter
// as trusted.
//
// The JSON body is a TrustDeviceRequest. Its trust_duration (e.g. "30d" or
// "12h") is converted by parseDuration. An empty value, or one without a d/h
// suffix, becomes a zero duration. Whether zero means "no expiry" or a service
// default is decided by the service; confirm there.
//
// It responds 400 for a malformed :id or an unbindable body. Service errors go
// through handleError.
//
// NOTE(review): the caller's identity is not consulted here. Confirm ownership
// is enforced by the service or by route middleware.
func (h *DeviceHandler) TrustDevice(c *gin.Context) {
	deviceID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
		return
	}

	var body TrustDeviceRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	ttl := parseDuration(body.TrustDuration)
	if svcErr := h.deviceService.TrustDevice(c.Request.Context(), deviceID, ttl); svcErr != nil {
		handleError(c, svcErr)
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "device trusted"})
}
// TrustDeviceByDeviceID marks a device as trusted for the authenticated user.
// The device is identified by its string :deviceId path parameter rather than
// by the numeric primary key.
//
// Responses:
//   - 401 when no user ID can be read from the request context
//   - 400 when :deviceId is empty or the JSON body (TrustDeviceRequest) cannot
//     be bound
//   - 200 with a confirmation message on success
//
// trust_duration is converted by parseDuration. Service errors go through
// handleError.
func (h *DeviceHandler) TrustDeviceByDeviceID(c *gin.Context) {
	uid, authed := getUserIDFromContext(c)
	if !authed {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
		return
	}

	identifier := c.Param("deviceId")
	if len(identifier) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
		return
	}

	var body TrustDeviceRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": bindErr.Error()})
		return
	}

	ttl := parseDuration(body.TrustDuration)
	ctx := c.Request.Context()
	if svcErr := h.deviceService.TrustDeviceByDeviceID(ctx, uid, identifier, ttl); svcErr != nil {
		handleError(c, svcErr)
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "device trusted"})
}
// UntrustDevice revokes the trusted status of the device identified by the
// numeric :id path parameter.
//
// It responds 400 for a malformed :id, and 200 with a confirmation message on
// success. Service errors go through handleError.
//
// NOTE(review): the caller's identity is not consulted here. Confirm ownership
// is enforced by the service or by route middleware.
func (h *DeviceHandler) UntrustDevice(c *gin.Context) {
	deviceID, parseErr := strconv.ParseInt(c.Param("id"), 10, 64)
	if parseErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid device id"})
		return
	}

	ctx := c.Request.Context()
	if svcErr := h.deviceService.UntrustDevice(ctx, deviceID); svcErr != nil {
		handleError(c, svcErr)
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "device untrusted"})
}
// GetMyTrustedDevices returns the authenticated user's trusted devices as
// {"devices": [...]}, without pagination.
//
// It responds 401 when no user ID can be read from the request context.
// Service errors go through handleError.
func (h *DeviceHandler) GetMyTrustedDevices(c *gin.Context) {
	uid, authed := getUserIDFromContext(c)
	if !authed {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
		return
	}

	trusted, svcErr := h.deviceService.GetTrustedDevices(c.Request.Context(), uid)
	if svcErr != nil {
		handleError(c, svcErr)
		return
	}
	c.JSON(http.StatusOK, gin.H{"devices": trusted})
}
// LogoutAllOtherDevices logs the authenticated user out of every device except
// the one making the request.
//
// The current device is taken from the X-Device-ID request header, which must
// be a base-10 int64. A missing or malformed header yields 400. A missing user
// ID in the request context yields 401. Service errors go through handleError.
func (h *DeviceHandler) LogoutAllOtherDevices(c *gin.Context) {
	uid, authed := getUserIDFromContext(c)
	if !authed {
		c.JSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
		return
	}

	header := c.GetHeader("X-Device-ID")
	keepID, parseErr := strconv.ParseInt(header, 10, 64)
	if parseErr != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid current device id"})
		return
	}

	ctx := c.Request.Context()
	if svcErr := h.deviceService.LogoutAllOtherDevices(ctx, uid, keepID); svcErr != nil {
		handleError(c, svcErr)
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "all other devices logged out"})
}
// parseDuration converts a trust-duration string such as "30d" (30 days) or
// "12h" (12 hours) into a time.Duration. The unit suffix is case-insensitive.
//
// It is lenient: an empty string and any malformed, negative, or overflowing
// value all yield 0. Callers that need to tell "not set" apart from "invalid"
// should use parseDurationStrict.
func parseDuration(s string) time.Duration {
	d, err := parseDurationStrict(s)
	if err != nil {
		return 0
	}
	return d
}

// parseDurationStrict is the validating form of parseDuration. It accepts a
// non-negative base-10 integer followed by a single unit suffix: 'd'/'D' for
// days or 'h'/'H' for hours. Anything else is reported as an error. An empty
// string means "not set" and returns 0 with no error.
func parseDurationStrict(s string) (time.Duration, error) {
	if s == "" {
		return 0, nil
	}

	var unit time.Duration
	switch s[len(s)-1] {
	case 'd', 'D':
		unit = 24 * time.Hour
	case 'h', 'H':
		unit = time.Hour
	default:
		return 0, fmt.Errorf("duration %q: unsupported unit, want a d or h suffix", s)
	}

	// strconv.ParseInt rejects trailing garbage ("3x") that fmt.Sscanf("%d")
	// silently ignored, and it rejects an empty count (a bare "d").
	n, err := strconv.ParseInt(s[:len(s)-1], 10, 64)
	if err != nil {
		return 0, fmt.Errorf("duration %q: %w", s, err)
	}
	if n < 0 {
		return 0, fmt.Errorf("duration %q: must not be negative", s)
	}
	// time.Duration is an int64 nanosecond count. Guard the multiplication so
	// large day/hour counts are rejected instead of silently wrapping.
	const maxDuration = time.Duration(1<<63 - 1)
	if n > int64(maxDuration/unit) {
		return 0, fmt.Errorf("duration %q: out of range", s)
	}
	return time.Duration(n) * unit, nil
}