chore: 删除未使用的孤立包
清理以下未导入的包: - internal/response (未使用的响应结构体) - pkg/response (未使用的响应封装) - internal/model (TLSFingerprintProfile, ErrorPassthroughRule) - internal/models (SocialAccount, domain已有) - internal/pkg/response (未使用的响应封装) - internal/security/ratelimit (已迁移到middleware) 验证: go build ./... && go test ./... 通过
This commit is contained in:
@@ -1,75 +0,0 @@
|
|||||||
// Package model 定义服务层使用的数据模型。
|
|
||||||
package model
|
|
||||||
|
|
||||||
import "time"
|
|
||||||
|
|
||||||
// ErrorPassthroughRule 全局错误透传规则
|
|
||||||
// 用于控制上游错误如何返回给客户端
|
|
||||||
type ErrorPassthroughRule struct {
|
|
||||||
ID int64 `json:"id"`
|
|
||||||
Name string `json:"name"` // 规则名称
|
|
||||||
Enabled bool `json:"enabled"` // 是否启用
|
|
||||||
Priority int `json:"priority"` // 优先级(数字越小优先级越高)
|
|
||||||
ErrorCodes []int `json:"error_codes"` // 匹配的错误码列表(OR关系)
|
|
||||||
Keywords []string `json:"keywords"` // 匹配的关键词列表(OR关系)
|
|
||||||
MatchMode string `json:"match_mode"` // "any"(任一条件) 或 "all"(所有条件)
|
|
||||||
Platforms []string `json:"platforms"` // 适用平台列表
|
|
||||||
PassthroughCode bool `json:"passthrough_code"` // 是否透传原始状态码
|
|
||||||
ResponseCode *int `json:"response_code"` // 自定义状态码(passthrough_code=false 时使用)
|
|
||||||
PassthroughBody bool `json:"passthrough_body"` // 是否透传原始错误信息
|
|
||||||
CustomMessage *string `json:"custom_message"` // 自定义错误信息(passthrough_body=false 时使用)
|
|
||||||
SkipMonitoring bool `json:"skip_monitoring"` // 是否跳过运维监控记录
|
|
||||||
Description *string `json:"description"` // 规则描述
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// MatchModeAny 表示任一条件匹配即可
|
|
||||||
const MatchModeAny = "any"
|
|
||||||
|
|
||||||
// MatchModeAll 表示所有条件都必须匹配
|
|
||||||
const MatchModeAll = "all"
|
|
||||||
|
|
||||||
// 支持的平台常量
|
|
||||||
const (
|
|
||||||
PlatformAnthropic = "anthropic"
|
|
||||||
PlatformOpenAI = "openai"
|
|
||||||
PlatformGemini = "gemini"
|
|
||||||
PlatformAntigravity = "antigravity"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AllPlatforms 返回所有支持的平台列表
|
|
||||||
func AllPlatforms() []string {
|
|
||||||
return []string{PlatformAnthropic, PlatformOpenAI, PlatformGemini, PlatformAntigravity}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate 验证规则配置的有效性
|
|
||||||
func (r *ErrorPassthroughRule) Validate() error {
|
|
||||||
if r.Name == "" {
|
|
||||||
return &ValidationError{Field: "name", Message: "name is required"}
|
|
||||||
}
|
|
||||||
if r.MatchMode != MatchModeAny && r.MatchMode != MatchModeAll {
|
|
||||||
return &ValidationError{Field: "match_mode", Message: "match_mode must be 'any' or 'all'"}
|
|
||||||
}
|
|
||||||
// 至少需要配置一个匹配条件(错误码或关键词)
|
|
||||||
if len(r.ErrorCodes) == 0 && len(r.Keywords) == 0 {
|
|
||||||
return &ValidationError{Field: "conditions", Message: "at least one error_code or keyword is required"}
|
|
||||||
}
|
|
||||||
if !r.PassthroughCode && (r.ResponseCode == nil || *r.ResponseCode <= 0) {
|
|
||||||
return &ValidationError{Field: "response_code", Message: "response_code is required when passthrough_code is false"}
|
|
||||||
}
|
|
||||||
if !r.PassthroughBody && (r.CustomMessage == nil || *r.CustomMessage == "") {
|
|
||||||
return &ValidationError{Field: "custom_message", Message: "custom_message is required when passthrough_body is false"}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidationError 表示验证错误
|
|
||||||
type ValidationError struct {
|
|
||||||
Field string
|
|
||||||
Message string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *ValidationError) Error() string {
|
|
||||||
return e.Field + ": " + e.Message
|
|
||||||
}
|
|
||||||
@@ -1,54 +0,0 @@
|
|||||||
// Package model 定义服务层使用的数据模型。
|
|
||||||
package model
|
|
||||||
|
|
||||||
import (
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/user-management-system/internal/pkg/tlsfingerprint"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TLSFingerprintProfile TLS 指纹配置模板
|
|
||||||
// 包含完整的 ClientHello 参数,用于模拟特定客户端的 TLS 握手特征
|
|
||||||
type TLSFingerprintProfile struct {
|
|
||||||
ID int64 `json:"id"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Description *string `json:"description"`
|
|
||||||
EnableGREASE bool `json:"enable_grease"`
|
|
||||||
CipherSuites []uint16 `json:"cipher_suites"`
|
|
||||||
Curves []uint16 `json:"curves"`
|
|
||||||
PointFormats []uint16 `json:"point_formats"`
|
|
||||||
SignatureAlgorithms []uint16 `json:"signature_algorithms"`
|
|
||||||
ALPNProtocols []string `json:"alpn_protocols"`
|
|
||||||
SupportedVersions []uint16 `json:"supported_versions"`
|
|
||||||
KeyShareGroups []uint16 `json:"key_share_groups"`
|
|
||||||
PSKModes []uint16 `json:"psk_modes"`
|
|
||||||
Extensions []uint16 `json:"extensions"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
UpdatedAt time.Time `json:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate 验证模板配置的有效性
|
|
||||||
func (p *TLSFingerprintProfile) Validate() error {
|
|
||||||
if p.Name == "" {
|
|
||||||
return &ValidationError{Field: "name", Message: "name is required"}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToTLSProfile 将领域模型转换为运行时使用的 tlsfingerprint.Profile
|
|
||||||
// 空切片字段会在 dialer 中 fallback 到内置默认值
|
|
||||||
func (p *TLSFingerprintProfile) ToTLSProfile() *tlsfingerprint.Profile {
|
|
||||||
return &tlsfingerprint.Profile{
|
|
||||||
Name: p.Name,
|
|
||||||
EnableGREASE: p.EnableGREASE,
|
|
||||||
CipherSuites: p.CipherSuites,
|
|
||||||
Curves: p.Curves,
|
|
||||||
PointFormats: p.PointFormats,
|
|
||||||
SignatureAlgorithms: p.SignatureAlgorithms,
|
|
||||||
ALPNProtocols: p.ALPNProtocols,
|
|
||||||
SupportedVersions: p.SupportedVersions,
|
|
||||||
KeyShareGroups: p.KeyShareGroups,
|
|
||||||
PSKModes: p.PSKModes,
|
|
||||||
Extensions: p.Extensions,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
package models
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SocialAccount 社交账号绑定模型
|
|
||||||
type SocialAccount struct {
|
|
||||||
ID uint64 `json:"id" db:"id"`
|
|
||||||
UserID uint64 `json:"user_id" db:"user_id"`
|
|
||||||
Provider string `json:"provider" db:"provider"` // wechat, qq, weibo, google, facebook, twitter
|
|
||||||
ProviderUserID string `json:"provider_user_id" db:"provider_user_id"`
|
|
||||||
ProviderUsername string `json:"provider_username" db:"provider_username"`
|
|
||||||
AccessToken string `json:"-" db:"access_token"` // 不返回给前端
|
|
||||||
RefreshToken string `json:"-" db:"refresh_token"`
|
|
||||||
ExpiresAt *time.Time `json:"expires_at" db:"expires_at"`
|
|
||||||
RawData JSON `json:"-" db:"raw_data"`
|
|
||||||
IsPrimary bool `json:"is_primary" db:"is_primary"`
|
|
||||||
CreatedAt time.Time `json:"created_at" db:"created_at"`
|
|
||||||
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// SocialAccountInfo 返回给前端的社交账号信息(不含敏感信息)
|
|
||||||
type SocialAccountInfo struct {
|
|
||||||
ID uint64 `json:"id"`
|
|
||||||
Provider string `json:"provider"`
|
|
||||||
ProviderUserID string `json:"provider_user_id"`
|
|
||||||
ProviderUsername string `json:"provider_username"`
|
|
||||||
IsPrimary bool `json:"is_primary"`
|
|
||||||
CreatedAt time.Time `json:"created_at"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToInfo 转换为安全信息
|
|
||||||
func (sa *SocialAccount) ToInfo() *SocialAccountInfo {
|
|
||||||
return &SocialAccountInfo{
|
|
||||||
ID: sa.ID,
|
|
||||||
Provider: sa.Provider,
|
|
||||||
ProviderUserID: sa.ProviderUserID,
|
|
||||||
ProviderUsername: sa.ProviderUsername,
|
|
||||||
IsPrimary: sa.IsPrimary,
|
|
||||||
CreatedAt: sa.CreatedAt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// JSON 自定义JSON类型,用于存储RawData
|
|
||||||
type JSON struct {
|
|
||||||
Data interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan 实现 sql.Scanner 接口
|
|
||||||
func (j *JSON) Scan(value interface{}) error {
|
|
||||||
if value == nil {
|
|
||||||
j.Data = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
bytes, ok := value.([]byte)
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return json.Unmarshal(bytes, &j.Data)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value 实现 driver.Valuer 接口
|
|
||||||
func (j JSON) Value() (interface{}, error) {
|
|
||||||
if j.Data == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return json.Marshal(j.Data)
|
|
||||||
}
|
|
||||||
@@ -1,203 +0,0 @@
|
|||||||
// Package response provides standardized HTTP response helpers.
|
|
||||||
package response
|
|
||||||
|
|
||||||
import (
|
|
||||||
"log"
|
|
||||||
"math"
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
infraerrors "github.com/user-management-system/internal/pkg/errors"
|
|
||||||
"github.com/user-management-system/internal/util/logredact"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Response 标准API响应格式
|
|
||||||
type Response struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Reason string `json:"reason,omitempty"`
|
|
||||||
Metadata map[string]string `json:"metadata,omitempty"`
|
|
||||||
Data any `json:"data,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// PaginatedData 分页数据格式(匹配前端期望)
|
|
||||||
type PaginatedData struct {
|
|
||||||
Items any `json:"items"`
|
|
||||||
Total int64 `json:"total"`
|
|
||||||
Page int `json:"page"`
|
|
||||||
PageSize int `json:"page_size"`
|
|
||||||
Pages int `json:"pages"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Success 返回成功响应
|
|
||||||
func Success(c *gin.Context, data any) {
|
|
||||||
c.JSON(http.StatusOK, Response{
|
|
||||||
Code: 0,
|
|
||||||
Message: "success",
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Created 返回创建成功响应
|
|
||||||
func Created(c *gin.Context, data any) {
|
|
||||||
c.JSON(http.StatusCreated, Response{
|
|
||||||
Code: 0,
|
|
||||||
Message: "success",
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Accepted 返回异步接受响应 (HTTP 202)
|
|
||||||
func Accepted(c *gin.Context, data any) {
|
|
||||||
c.JSON(http.StatusAccepted, Response{
|
|
||||||
Code: 0,
|
|
||||||
Message: "accepted",
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error 返回错误响应
|
|
||||||
func Error(c *gin.Context, statusCode int, message string) {
|
|
||||||
c.JSON(statusCode, Response{
|
|
||||||
Code: statusCode,
|
|
||||||
Message: message,
|
|
||||||
Reason: "",
|
|
||||||
Metadata: nil,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrorWithDetails returns an error response compatible with the existing envelope while
|
|
||||||
// optionally providing structured error fields (reason/metadata).
|
|
||||||
func ErrorWithDetails(c *gin.Context, statusCode int, message, reason string, metadata map[string]string) {
|
|
||||||
c.JSON(statusCode, Response{
|
|
||||||
Code: statusCode,
|
|
||||||
Message: message,
|
|
||||||
Reason: reason,
|
|
||||||
Metadata: metadata,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrorFrom converts an ApplicationError (or any error) into the envelope-compatible error response.
|
|
||||||
// It returns true if an error was written.
|
|
||||||
func ErrorFrom(c *gin.Context, err error) bool {
|
|
||||||
if err == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
statusCode, status := infraerrors.ToHTTP(err)
|
|
||||||
|
|
||||||
// Log internal errors with full details for debugging
|
|
||||||
if statusCode >= 500 && c.Request != nil {
|
|
||||||
log.Printf("[ERROR] %s %s\n Error: %s", c.Request.Method, c.Request.URL.Path, logredact.RedactText(err.Error()))
|
|
||||||
}
|
|
||||||
|
|
||||||
ErrorWithDetails(c, statusCode, status.Message, status.Reason, status.Metadata)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// BadRequest 返回400错误
|
|
||||||
func BadRequest(c *gin.Context, message string) {
|
|
||||||
Error(c, http.StatusBadRequest, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Unauthorized 返回401错误
|
|
||||||
func Unauthorized(c *gin.Context, message string) {
|
|
||||||
Error(c, http.StatusUnauthorized, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Forbidden 返回403错误
|
|
||||||
func Forbidden(c *gin.Context, message string) {
|
|
||||||
Error(c, http.StatusForbidden, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NotFound 返回404错误
|
|
||||||
func NotFound(c *gin.Context, message string) {
|
|
||||||
Error(c, http.StatusNotFound, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// InternalError 返回500错误
|
|
||||||
func InternalError(c *gin.Context, message string) {
|
|
||||||
Error(c, http.StatusInternalServerError, message)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Paginated 返回分页数据
|
|
||||||
func Paginated(c *gin.Context, items any, total int64, page, pageSize int) {
|
|
||||||
pages := int(math.Ceil(float64(total) / float64(pageSize)))
|
|
||||||
if pages < 1 {
|
|
||||||
pages = 1
|
|
||||||
}
|
|
||||||
|
|
||||||
Success(c, PaginatedData{
|
|
||||||
Items: items,
|
|
||||||
Total: total,
|
|
||||||
Page: page,
|
|
||||||
PageSize: pageSize,
|
|
||||||
Pages: pages,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// PaginationResult 分页结果(与pagination.PaginationResult兼容)
|
|
||||||
type PaginationResult struct {
|
|
||||||
Total int64
|
|
||||||
Page int
|
|
||||||
PageSize int
|
|
||||||
Pages int
|
|
||||||
}
|
|
||||||
|
|
||||||
// PaginatedWithResult 使用PaginationResult返回分页数据
|
|
||||||
func PaginatedWithResult(c *gin.Context, items any, pagination *PaginationResult) {
|
|
||||||
if pagination == nil {
|
|
||||||
Success(c, PaginatedData{
|
|
||||||
Items: items,
|
|
||||||
Total: 0,
|
|
||||||
Page: 1,
|
|
||||||
PageSize: 20,
|
|
||||||
Pages: 1,
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
Success(c, PaginatedData{
|
|
||||||
Items: items,
|
|
||||||
Total: pagination.Total,
|
|
||||||
Page: pagination.Page,
|
|
||||||
PageSize: pagination.PageSize,
|
|
||||||
Pages: pagination.Pages,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ParsePagination 解析分页参数
|
|
||||||
func ParsePagination(c *gin.Context) (page, pageSize int) {
|
|
||||||
page = 1
|
|
||||||
pageSize = 20
|
|
||||||
|
|
||||||
if p := c.Query("page"); p != "" {
|
|
||||||
if val, err := parseInt(p); err == nil && val > 0 {
|
|
||||||
page = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 支持 page_size 和 limit 两种参数名
|
|
||||||
if ps := c.Query("page_size"); ps != "" {
|
|
||||||
if val, err := parseInt(ps); err == nil && val > 0 && val <= 1000 {
|
|
||||||
pageSize = val
|
|
||||||
}
|
|
||||||
} else if l := c.Query("limit"); l != "" {
|
|
||||||
if val, err := parseInt(l); err == nil && val > 0 && val <= 1000 {
|
|
||||||
pageSize = val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return page, pageSize
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseInt(s string) (int, error) {
|
|
||||||
var result int
|
|
||||||
for _, c := range s {
|
|
||||||
if c < '0' || c > '9' {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
result = result*10 + int(c-'0')
|
|
||||||
}
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
@@ -1,788 +0,0 @@
|
|||||||
//go:build unit
|
|
||||||
|
|
||||||
package response
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
errors2 "github.com/user-management-system/internal/pkg/errors"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ---------- 辅助函数 ----------
|
|
||||||
|
|
||||||
// parseResponseBody 从 httptest.ResponseRecorder 中解析 JSON 响应体
|
|
||||||
func parseResponseBody(t *testing.T, w *httptest.ResponseRecorder) Response {
|
|
||||||
t.Helper()
|
|
||||||
var got Response
|
|
||||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
|
||||||
return got
|
|
||||||
}
|
|
||||||
|
|
||||||
// parsePaginatedBody 从响应体中解析分页数据(Data 字段是 PaginatedData)
|
|
||||||
func parsePaginatedBody(t *testing.T, w *httptest.ResponseRecorder) (Response, PaginatedData) {
|
|
||||||
t.Helper()
|
|
||||||
// 先用 raw json 解析,因为 Data 是 any 类型
|
|
||||||
var raw struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Reason string `json:"reason,omitempty"`
|
|
||||||
Data json.RawMessage `json:"data,omitempty"`
|
|
||||||
}
|
|
||||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &raw))
|
|
||||||
|
|
||||||
var pd PaginatedData
|
|
||||||
require.NoError(t, json.Unmarshal(raw.Data, &pd))
|
|
||||||
|
|
||||||
return Response{Code: raw.Code, Message: raw.Message, Reason: raw.Reason}, pd
|
|
||||||
}
|
|
||||||
|
|
||||||
// newContextWithQuery 创建一个带有 URL query 参数的 gin.Context 用于测试 ParsePagination
|
|
||||||
func newContextWithQuery(query string) (*httptest.ResponseRecorder, *gin.Context) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
c.Request = httptest.NewRequest(http.MethodGet, "/?"+query, nil)
|
|
||||||
return w, c
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------- 现有测试 ----------
|
|
||||||
|
|
||||||
func TestErrorWithDetails(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
statusCode int
|
|
||||||
message string
|
|
||||||
reason string
|
|
||||||
metadata map[string]string
|
|
||||||
want Response
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "plain_error",
|
|
||||||
statusCode: http.StatusBadRequest,
|
|
||||||
message: "invalid request",
|
|
||||||
want: Response{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
Message: "invalid request",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "structured_error",
|
|
||||||
statusCode: http.StatusForbidden,
|
|
||||||
message: "no access",
|
|
||||||
reason: "FORBIDDEN",
|
|
||||||
metadata: map[string]string{"k": "v"},
|
|
||||||
want: Response{
|
|
||||||
Code: http.StatusForbidden,
|
|
||||||
Message: "no access",
|
|
||||||
Reason: "FORBIDDEN",
|
|
||||||
Metadata: map[string]string{"k": "v"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
ErrorWithDetails(c, tt.statusCode, tt.message, tt.reason, tt.metadata)
|
|
||||||
|
|
||||||
require.Equal(t, tt.statusCode, w.Code)
|
|
||||||
|
|
||||||
var got Response
|
|
||||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
|
||||||
require.Equal(t, tt.want, got)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestErrorFrom(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
err error
|
|
||||||
wantWritten bool
|
|
||||||
wantHTTPCode int
|
|
||||||
wantBody Response
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "nil_error",
|
|
||||||
err: nil,
|
|
||||||
wantWritten: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "application_error",
|
|
||||||
err: errors2.Forbidden("FORBIDDEN", "no access").WithMetadata(map[string]string{"scope": "admin"}),
|
|
||||||
wantWritten: true,
|
|
||||||
wantHTTPCode: http.StatusForbidden,
|
|
||||||
wantBody: Response{
|
|
||||||
Code: http.StatusForbidden,
|
|
||||||
Message: "no access",
|
|
||||||
Reason: "FORBIDDEN",
|
|
||||||
Metadata: map[string]string{"scope": "admin"},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "bad_request_error",
|
|
||||||
err: errors2.BadRequest("INVALID_REQUEST", "invalid request"),
|
|
||||||
wantWritten: true,
|
|
||||||
wantHTTPCode: http.StatusBadRequest,
|
|
||||||
wantBody: Response{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
Message: "invalid request",
|
|
||||||
Reason: "INVALID_REQUEST",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "unauthorized_error",
|
|
||||||
err: errors2.Unauthorized("UNAUTHORIZED", "unauthorized"),
|
|
||||||
wantWritten: true,
|
|
||||||
wantHTTPCode: http.StatusUnauthorized,
|
|
||||||
wantBody: Response{
|
|
||||||
Code: http.StatusUnauthorized,
|
|
||||||
Message: "unauthorized",
|
|
||||||
Reason: "UNAUTHORIZED",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "not_found_error",
|
|
||||||
err: errors2.NotFound("NOT_FOUND", "not found"),
|
|
||||||
wantWritten: true,
|
|
||||||
wantHTTPCode: http.StatusNotFound,
|
|
||||||
wantBody: Response{
|
|
||||||
Code: http.StatusNotFound,
|
|
||||||
Message: "not found",
|
|
||||||
Reason: "NOT_FOUND",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "conflict_error",
|
|
||||||
err: errors2.Conflict("CONFLICT", "conflict"),
|
|
||||||
wantWritten: true,
|
|
||||||
wantHTTPCode: http.StatusConflict,
|
|
||||||
wantBody: Response{
|
|
||||||
Code: http.StatusConflict,
|
|
||||||
Message: "conflict",
|
|
||||||
Reason: "CONFLICT",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "unknown_error_defaults_to_500",
|
|
||||||
err: errors.New("boom"),
|
|
||||||
wantWritten: true,
|
|
||||||
wantHTTPCode: http.StatusInternalServerError,
|
|
||||||
wantBody: Response{
|
|
||||||
Code: http.StatusInternalServerError,
|
|
||||||
Message: errors2.UnknownMessage,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
written := ErrorFrom(c, tt.err)
|
|
||||||
require.Equal(t, tt.wantWritten, written)
|
|
||||||
|
|
||||||
if !tt.wantWritten {
|
|
||||||
require.Equal(t, 200, w.Code)
|
|
||||||
require.Empty(t, w.Body.String())
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equal(t, tt.wantHTTPCode, w.Code)
|
|
||||||
var got Response
|
|
||||||
require.NoError(t, json.Unmarshal(w.Body.Bytes(), &got))
|
|
||||||
require.Equal(t, tt.wantBody, got)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ---------- 新增测试 ----------
|
|
||||||
|
|
||||||
func TestSuccess(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
data any
|
|
||||||
wantCode int
|
|
||||||
wantBody Response
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "返回字符串数据",
|
|
||||||
data: "hello",
|
|
||||||
wantCode: http.StatusOK,
|
|
||||||
wantBody: Response{Code: 0, Message: "success", Data: "hello"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "返回nil数据",
|
|
||||||
data: nil,
|
|
||||||
wantCode: http.StatusOK,
|
|
||||||
wantBody: Response{Code: 0, Message: "success"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "返回map数据",
|
|
||||||
data: map[string]string{"key": "value"},
|
|
||||||
wantCode: http.StatusOK,
|
|
||||||
wantBody: Response{Code: 0, Message: "success"},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
Success(c, tt.data)
|
|
||||||
|
|
||||||
require.Equal(t, tt.wantCode, w.Code)
|
|
||||||
|
|
||||||
// 只验证 code 和 message,data 字段类型在 JSON 反序列化时会变成 map/slice
|
|
||||||
got := parseResponseBody(t, w)
|
|
||||||
require.Equal(t, 0, got.Code)
|
|
||||||
require.Equal(t, "success", got.Message)
|
|
||||||
|
|
||||||
if tt.data == nil {
|
|
||||||
require.Nil(t, got.Data)
|
|
||||||
} else {
|
|
||||||
require.NotNil(t, got.Data)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCreated(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
data any
|
|
||||||
wantCode int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "创建成功_返回数据",
|
|
||||||
data: map[string]int{"id": 42},
|
|
||||||
wantCode: http.StatusCreated,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "创建成功_nil数据",
|
|
||||||
data: nil,
|
|
||||||
wantCode: http.StatusCreated,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
Created(c, tt.data)
|
|
||||||
|
|
||||||
require.Equal(t, tt.wantCode, w.Code)
|
|
||||||
|
|
||||||
got := parseResponseBody(t, w)
|
|
||||||
require.Equal(t, 0, got.Code)
|
|
||||||
require.Equal(t, "success", got.Message)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestError(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
statusCode int
|
|
||||||
message string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "400错误",
|
|
||||||
statusCode: http.StatusBadRequest,
|
|
||||||
message: "bad request",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "500错误",
|
|
||||||
statusCode: http.StatusInternalServerError,
|
|
||||||
message: "internal error",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "自定义状态码",
|
|
||||||
statusCode: 418,
|
|
||||||
message: "I'm a teapot",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
Error(c, tt.statusCode, tt.message)
|
|
||||||
|
|
||||||
require.Equal(t, tt.statusCode, w.Code)
|
|
||||||
|
|
||||||
got := parseResponseBody(t, w)
|
|
||||||
require.Equal(t, tt.statusCode, got.Code)
|
|
||||||
require.Equal(t, tt.message, got.Message)
|
|
||||||
require.Empty(t, got.Reason)
|
|
||||||
require.Nil(t, got.Metadata)
|
|
||||||
require.Nil(t, got.Data)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBadRequest(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
BadRequest(c, "参数无效")
|
|
||||||
|
|
||||||
require.Equal(t, http.StatusBadRequest, w.Code)
|
|
||||||
got := parseResponseBody(t, w)
|
|
||||||
require.Equal(t, http.StatusBadRequest, got.Code)
|
|
||||||
require.Equal(t, "参数无效", got.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnauthorized(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
Unauthorized(c, "未登录")
|
|
||||||
|
|
||||||
require.Equal(t, http.StatusUnauthorized, w.Code)
|
|
||||||
got := parseResponseBody(t, w)
|
|
||||||
require.Equal(t, http.StatusUnauthorized, got.Code)
|
|
||||||
require.Equal(t, "未登录", got.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestForbidden(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
Forbidden(c, "无权限")
|
|
||||||
|
|
||||||
require.Equal(t, http.StatusForbidden, w.Code)
|
|
||||||
got := parseResponseBody(t, w)
|
|
||||||
require.Equal(t, http.StatusForbidden, got.Code)
|
|
||||||
require.Equal(t, "无权限", got.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNotFound(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
NotFound(c, "资源不存在")
|
|
||||||
|
|
||||||
require.Equal(t, http.StatusNotFound, w.Code)
|
|
||||||
got := parseResponseBody(t, w)
|
|
||||||
require.Equal(t, http.StatusNotFound, got.Code)
|
|
||||||
require.Equal(t, "资源不存在", got.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInternalError(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
InternalError(c, "服务器内部错误")
|
|
||||||
|
|
||||||
require.Equal(t, http.StatusInternalServerError, w.Code)
|
|
||||||
got := parseResponseBody(t, w)
|
|
||||||
require.Equal(t, http.StatusInternalServerError, got.Code)
|
|
||||||
require.Equal(t, "服务器内部错误", got.Message)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPaginated(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
items any
|
|
||||||
total int64
|
|
||||||
page int
|
|
||||||
pageSize int
|
|
||||||
wantPages int
|
|
||||||
wantTotal int64
|
|
||||||
wantPage int
|
|
||||||
wantPageSize int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "标准分页_多页",
|
|
||||||
items: []string{"a", "b"},
|
|
||||||
total: 25,
|
|
||||||
page: 1,
|
|
||||||
pageSize: 10,
|
|
||||||
wantPages: 3,
|
|
||||||
wantTotal: 25,
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 10,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "总数刚好整除",
|
|
||||||
items: []string{"a"},
|
|
||||||
total: 20,
|
|
||||||
page: 2,
|
|
||||||
pageSize: 10,
|
|
||||||
wantPages: 2,
|
|
||||||
wantTotal: 20,
|
|
||||||
wantPage: 2,
|
|
||||||
wantPageSize: 10,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "总数为0_pages至少为1",
|
|
||||||
items: []string{},
|
|
||||||
total: 0,
|
|
||||||
page: 1,
|
|
||||||
pageSize: 10,
|
|
||||||
wantPages: 1,
|
|
||||||
wantTotal: 0,
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 10,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "单页数据",
|
|
||||||
items: []int{1, 2, 3},
|
|
||||||
total: 3,
|
|
||||||
page: 1,
|
|
||||||
pageSize: 20,
|
|
||||||
wantPages: 1,
|
|
||||||
wantTotal: 3,
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 20,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "总数为1",
|
|
||||||
items: []string{"only"},
|
|
||||||
total: 1,
|
|
||||||
page: 1,
|
|
||||||
pageSize: 10,
|
|
||||||
wantPages: 1,
|
|
||||||
wantTotal: 1,
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 10,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
Paginated(c, tt.items, tt.total, tt.page, tt.pageSize)
|
|
||||||
|
|
||||||
require.Equal(t, http.StatusOK, w.Code)
|
|
||||||
|
|
||||||
resp, pd := parsePaginatedBody(t, w)
|
|
||||||
require.Equal(t, 0, resp.Code)
|
|
||||||
require.Equal(t, "success", resp.Message)
|
|
||||||
require.Equal(t, tt.wantTotal, pd.Total)
|
|
||||||
require.Equal(t, tt.wantPage, pd.Page)
|
|
||||||
require.Equal(t, tt.wantPageSize, pd.PageSize)
|
|
||||||
require.Equal(t, tt.wantPages, pd.Pages)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPaginatedWithResult(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
items any
|
|
||||||
pagination *PaginationResult
|
|
||||||
wantTotal int64
|
|
||||||
wantPage int
|
|
||||||
wantPageSize int
|
|
||||||
wantPages int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "正常分页结果",
|
|
||||||
items: []string{"a", "b"},
|
|
||||||
pagination: &PaginationResult{
|
|
||||||
Total: 50,
|
|
||||||
Page: 3,
|
|
||||||
PageSize: 10,
|
|
||||||
Pages: 5,
|
|
||||||
},
|
|
||||||
wantTotal: 50,
|
|
||||||
wantPage: 3,
|
|
||||||
wantPageSize: 10,
|
|
||||||
wantPages: 5,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "pagination为nil_使用默认值",
|
|
||||||
items: []string{},
|
|
||||||
pagination: nil,
|
|
||||||
wantTotal: 0,
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 20,
|
|
||||||
wantPages: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "单页结果",
|
|
||||||
items: []int{1},
|
|
||||||
pagination: &PaginationResult{
|
|
||||||
Total: 1,
|
|
||||||
Page: 1,
|
|
||||||
PageSize: 20,
|
|
||||||
Pages: 1,
|
|
||||||
},
|
|
||||||
wantTotal: 1,
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 20,
|
|
||||||
wantPages: 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
c, _ := gin.CreateTestContext(w)
|
|
||||||
|
|
||||||
PaginatedWithResult(c, tt.items, tt.pagination)
|
|
||||||
|
|
||||||
require.Equal(t, http.StatusOK, w.Code)
|
|
||||||
|
|
||||||
resp, pd := parsePaginatedBody(t, w)
|
|
||||||
require.Equal(t, 0, resp.Code)
|
|
||||||
require.Equal(t, "success", resp.Message)
|
|
||||||
require.Equal(t, tt.wantTotal, pd.Total)
|
|
||||||
require.Equal(t, tt.wantPage, pd.Page)
|
|
||||||
require.Equal(t, tt.wantPageSize, pd.PageSize)
|
|
||||||
require.Equal(t, tt.wantPages, pd.Pages)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsePagination(t *testing.T) {
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
query string
|
|
||||||
wantPage int
|
|
||||||
wantPageSize int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "无参数_使用默认值",
|
|
||||||
query: "",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 20,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "仅指定page",
|
|
||||||
query: "page=3",
|
|
||||||
wantPage: 3,
|
|
||||||
wantPageSize: 20,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "仅指定page_size",
|
|
||||||
query: "page_size=50",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 50,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "同时指定page和page_size",
|
|
||||||
query: "page=2&page_size=30",
|
|
||||||
wantPage: 2,
|
|
||||||
wantPageSize: 30,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "使用limit代替page_size",
|
|
||||||
query: "limit=15",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 15,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "page_size优先于limit",
|
|
||||||
query: "page_size=25&limit=50",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 25,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "page为0_使用默认值",
|
|
||||||
query: "page=0",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 20,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "page_size超过1000_使用默认值",
|
|
||||||
query: "page_size=1001",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 20,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "page_size恰好1000_有效",
|
|
||||||
query: "page_size=1000",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 1000,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "page为非数字_使用默认值",
|
|
||||||
query: "page=abc",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 20,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "page_size为非数字_使用默认值",
|
|
||||||
query: "page_size=xyz",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 20,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "limit为非数字_使用默认值",
|
|
||||||
query: "limit=abc",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 20,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "page_size为0_使用默认值",
|
|
||||||
query: "page_size=0",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 20,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "limit为0_使用默认值",
|
|
||||||
query: "limit=0",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 20,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "大页码",
|
|
||||||
query: "page=999&page_size=100",
|
|
||||||
wantPage: 999,
|
|
||||||
wantPageSize: 100,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "page_size为1_最小有效值",
|
|
||||||
query: "page_size=1",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "混合数字和字母的page",
|
|
||||||
query: "page=12a",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 20,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "limit超过1000_使用默认值",
|
|
||||||
query: "limit=2000",
|
|
||||||
wantPage: 1,
|
|
||||||
wantPageSize: 20,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
_, c := newContextWithQuery(tt.query)
|
|
||||||
|
|
||||||
page, pageSize := ParsePagination(c)
|
|
||||||
|
|
||||||
require.Equal(t, tt.wantPage, page, "page 不符合预期")
|
|
||||||
require.Equal(t, tt.wantPageSize, pageSize, "pageSize 不符合预期")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_parseInt(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
input string
|
|
||||||
wantVal int
|
|
||||||
wantErr bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "正常数字",
|
|
||||||
input: "123",
|
|
||||||
wantVal: 123,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "零",
|
|
||||||
input: "0",
|
|
||||||
wantVal: 0,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "单个数字",
|
|
||||||
input: "5",
|
|
||||||
wantVal: 5,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "大数字",
|
|
||||||
input: "99999",
|
|
||||||
wantVal: 99999,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "包含字母_返回0",
|
|
||||||
input: "abc",
|
|
||||||
wantVal: 0,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "数字开头接字母_返回0",
|
|
||||||
input: "12a",
|
|
||||||
wantVal: 0,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "包含负号_返回0",
|
|
||||||
input: "-1",
|
|
||||||
wantVal: 0,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "包含小数点_返回0",
|
|
||||||
input: "1.5",
|
|
||||||
wantVal: 0,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "包含空格_返回0",
|
|
||||||
input: "1 2",
|
|
||||||
wantVal: 0,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "空字符串",
|
|
||||||
input: "",
|
|
||||||
wantVal: 0,
|
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
val, err := parseInt(tt.input)
|
|
||||||
if tt.wantErr {
|
|
||||||
require.Error(t, err)
|
|
||||||
} else {
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
require.Equal(t, tt.wantVal, val)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
package response
|
|
||||||
|
|
||||||
// Response 统一响应结构
|
|
||||||
type Response struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Data interface{} `json:"data,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Success 成功响应
|
|
||||||
func Success(data interface{}) *Response {
|
|
||||||
return &Response{
|
|
||||||
Code: 0,
|
|
||||||
Message: "success",
|
|
||||||
Data: data,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error 错误响应
|
|
||||||
func Error(message string) *Response {
|
|
||||||
return &Response{
|
|
||||||
Code: -1,
|
|
||||||
Message: message,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrorWithCode 带错误码的错误响应
|
|
||||||
func ErrorWithCode(code int, message string) *Response {
|
|
||||||
return &Response{
|
|
||||||
Code: code,
|
|
||||||
Message: message,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithData 带扩展数据的成功响应
|
|
||||||
func WithData(data interface{}, extra map[string]interface{}) *Response {
|
|
||||||
payload, ok := data.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
payload = map[string]interface{}{
|
|
||||||
"items": data,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range extra {
|
|
||||||
payload[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := Success(payload)
|
|
||||||
return resp
|
|
||||||
}
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
package response
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestWithDataWrapsSlicesAndMergesExtra(t *testing.T) {
|
|
||||||
resp := WithData([]string{"a", "b"}, map[string]interface{}{
|
|
||||||
"total": 2,
|
|
||||||
"page": 1,
|
|
||||||
})
|
|
||||||
|
|
||||||
data, ok := resp.Data.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
t.Fatalf("expected map payload, got %T", resp.Data)
|
|
||||||
}
|
|
||||||
if data["total"] != 2 {
|
|
||||||
t.Fatalf("expected total=2, got %v", data["total"])
|
|
||||||
}
|
|
||||||
items, ok := data["items"].([]string)
|
|
||||||
if !ok || len(items) != 2 {
|
|
||||||
t.Fatalf("expected items slice to be preserved, got %#v", data["items"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithDataPreservesMapPayload(t *testing.T) {
|
|
||||||
resp := WithData(map[string]interface{}{"user": "alice"}, map[string]interface{}{"page": 1})
|
|
||||||
|
|
||||||
data := resp.Data.(map[string]interface{})
|
|
||||||
if data["user"] != "alice" {
|
|
||||||
t.Fatalf("expected user=alice, got %v", data["user"])
|
|
||||||
}
|
|
||||||
if data["page"] != 1 {
|
|
||||||
t.Fatalf("expected page=1, got %v", data["page"])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
package security
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RateLimitAlgorithm 限流算法类型
|
|
||||||
type RateLimitAlgorithm string
|
|
||||||
|
|
||||||
const (
|
|
||||||
AlgorithmTokenBucket RateLimitAlgorithm = "token_bucket"
|
|
||||||
AlgorithmLeakyBucket RateLimitAlgorithm = "leaky_bucket"
|
|
||||||
AlgorithmSlidingWindow RateLimitAlgorithm = "sliding_window"
|
|
||||||
AlgorithmFixedWindow RateLimitAlgorithm = "fixed_window"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TokenBucket 令牌桶算法
|
|
||||||
type TokenBucket struct {
|
|
||||||
capacity int64
|
|
||||||
tokens int64
|
|
||||||
rate int64 // 每秒产生的令牌数
|
|
||||||
lastRefill time.Time
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewTokenBucket 创建令牌桶
|
|
||||||
func NewTokenBucket(capacity, rate int64) *TokenBucket {
|
|
||||||
return &TokenBucket{
|
|
||||||
capacity: capacity,
|
|
||||||
tokens: capacity,
|
|
||||||
rate: rate,
|
|
||||||
lastRefill: time.Now(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allow 检查是否允许访问
|
|
||||||
func (tb *TokenBucket) Allow() bool {
|
|
||||||
tb.mu.Lock()
|
|
||||||
defer tb.mu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
elapsed := now.Sub(tb.lastRefill).Seconds()
|
|
||||||
|
|
||||||
// 计算需要补充的令牌数
|
|
||||||
refillTokens := int64(elapsed * float64(tb.rate))
|
|
||||||
tb.tokens += refillTokens
|
|
||||||
if tb.tokens > tb.capacity {
|
|
||||||
tb.tokens = tb.capacity
|
|
||||||
}
|
|
||||||
tb.lastRefill = now
|
|
||||||
|
|
||||||
// 检查是否有足够的令牌
|
|
||||||
if tb.tokens > 0 {
|
|
||||||
tb.tokens--
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// LeakyBucket 漏桶算法
|
|
||||||
type LeakyBucket struct {
|
|
||||||
capacity int64
|
|
||||||
water int64
|
|
||||||
rate int64 // 每秒漏出的水量
|
|
||||||
lastLeak time.Time
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewLeakyBucket 创建漏桶
|
|
||||||
func NewLeakyBucket(capacity, rate int64) *LeakyBucket {
|
|
||||||
return &LeakyBucket{
|
|
||||||
capacity: capacity,
|
|
||||||
water: 0,
|
|
||||||
rate: rate,
|
|
||||||
lastLeak: time.Now(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allow 检查是否允许访问
|
|
||||||
func (lb *LeakyBucket) Allow() bool {
|
|
||||||
lb.mu.Lock()
|
|
||||||
defer lb.mu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
elapsed := now.Sub(lb.lastLeak).Seconds()
|
|
||||||
|
|
||||||
// 计算漏出的水量
|
|
||||||
leakWater := int64(elapsed * float64(lb.rate))
|
|
||||||
lb.water -= leakWater
|
|
||||||
if lb.water < 0 {
|
|
||||||
lb.water = 0
|
|
||||||
}
|
|
||||||
lb.lastLeak = now
|
|
||||||
|
|
||||||
// 检查桶是否已满
|
|
||||||
if lb.water < lb.capacity {
|
|
||||||
lb.water++
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// SlidingWindow 滑动窗口算法
|
|
||||||
type SlidingWindow struct {
|
|
||||||
window time.Duration
|
|
||||||
capacity int64
|
|
||||||
requests []time.Time
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewSlidingWindow 创建滑动窗口
|
|
||||||
func NewSlidingWindow(window time.Duration, capacity int64) *SlidingWindow {
|
|
||||||
return &SlidingWindow{
|
|
||||||
window: window,
|
|
||||||
capacity: capacity,
|
|
||||||
requests: make([]time.Time, 0),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allow 检查是否允许访问
|
|
||||||
func (sw *SlidingWindow) Allow() bool {
|
|
||||||
sw.mu.Lock()
|
|
||||||
defer sw.mu.Unlock()
|
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
// 移除窗口外的请求
|
|
||||||
validRequests := make([]time.Time, 0)
|
|
||||||
for _, req := range sw.requests {
|
|
||||||
if now.Sub(req) < sw.window {
|
|
||||||
validRequests = append(validRequests, req)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sw.requests = validRequests
|
|
||||||
|
|
||||||
// 检查是否超过容量
|
|
||||||
if int64(len(sw.requests)) < sw.capacity {
|
|
||||||
sw.requests = append(sw.requests, now)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// RateLimiter 限流器
|
|
||||||
type RateLimiter struct {
|
|
||||||
algorithm RateLimitAlgorithm
|
|
||||||
limiter interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRateLimiter 创建限流器
|
|
||||||
func NewRateLimiter(algorithm RateLimitAlgorithm, capacity, rate int64, window time.Duration) *RateLimiter {
|
|
||||||
limiter := &RateLimiter{algorithm: algorithm}
|
|
||||||
|
|
||||||
switch algorithm {
|
|
||||||
case AlgorithmTokenBucket:
|
|
||||||
limiter.limiter = NewTokenBucket(capacity, rate)
|
|
||||||
case AlgorithmLeakyBucket:
|
|
||||||
limiter.limiter = NewLeakyBucket(capacity, rate)
|
|
||||||
case AlgorithmSlidingWindow:
|
|
||||||
limiter.limiter = NewSlidingWindow(window, capacity)
|
|
||||||
default:
|
|
||||||
limiter.limiter = NewSlidingWindow(window, capacity)
|
|
||||||
}
|
|
||||||
|
|
||||||
return limiter
|
|
||||||
}
|
|
||||||
|
|
||||||
// Allow 检查是否允许访问
|
|
||||||
func (rl *RateLimiter) Allow() bool {
|
|
||||||
switch rl.algorithm {
|
|
||||||
case AlgorithmTokenBucket:
|
|
||||||
return rl.limiter.(*TokenBucket).Allow()
|
|
||||||
case AlgorithmLeakyBucket:
|
|
||||||
return rl.limiter.(*LeakyBucket).Allow()
|
|
||||||
case AlgorithmSlidingWindow:
|
|
||||||
return rl.limiter.(*SlidingWindow).Allow()
|
|
||||||
default:
|
|
||||||
return rl.limiter.(*SlidingWindow).Allow()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
package response
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Response 统一响应结构
|
|
||||||
type Response struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Data interface{} `json:"data,omitempty"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Success 成功响应
|
|
||||||
func Success(c *gin.Context, data interface{}) {
|
|
||||||
c.JSON(http.StatusOK, Response{
|
|
||||||
Code: 0,
|
|
||||||
Message: "success",
|
|
||||||
Data: data,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error 错误响应
|
|
||||||
func Error(c *gin.Context, httpStatus int, message string, err error) {
|
|
||||||
if err != nil {
|
|
||||||
// 在开发环境下返回详细错误信息
|
|
||||||
if gin.Mode() == gin.DebugMode {
|
|
||||||
c.JSON(httpStatus, Response{
|
|
||||||
Code: httpStatus,
|
|
||||||
Message: message,
|
|
||||||
Data: err.Error(),
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
c.JSON(httpStatus, Response{
|
|
||||||
Code: httpStatus,
|
|
||||||
Message: message,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ErrorWithCode 错误响应(带自定义错误码)
|
|
||||||
func ErrorWithCode(c *gin.Context, code int, message string) {
|
|
||||||
c.JSON(http.StatusOK, Response{
|
|
||||||
Code: code,
|
|
||||||
Message: message,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user