diff --git a/internal/api/router/router.go b/internal/api/router/router.go index a2c760c..b6d4a04 100644 --- a/internal/api/router/router.go +++ b/internal/api/router/router.go @@ -1,6 +1,11 @@ package router import ( + "net/http" + "os" + "path/filepath" + "strings" + "github.com/gin-gonic/gin" "github.com/prometheus/client_golang/prometheus/promhttp" swaggerFiles "github.com/swaggo/files" @@ -122,9 +127,9 @@ func (r *Router) Setup() *gin.Engine { ) } - // P0 安全修复:/uploads 目录不再公开暴露,改为需要认证后才能访问 + // P0 安全修复:/uploads 目录使用受控文件服务,防止路径遍历 uploadsGroup := r.engine.Group("/uploads", r.authMiddleware.Required()) - uploadsGroup.Static("", "./uploads") + uploadsGroup.GET("/*filepath", r.serveUploads) r.engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) @@ -408,3 +413,37 @@ func (r *Router) Setup() *gin.Engine { func (r *Router) GetEngine() *gin.Engine { return r.engine } + +// serveUploads 提供受控的上传文件访问,防止路径遍历攻击 +func (r *Router) serveUploads(c *gin.Context) { + filePath := c.Param("filepath") + + // 1. 清理路径,阻止路径遍历 + filePath = filepath.Clean("/" + filePath) + if strings.Contains(filePath, "..") { + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"code": 403, "message": "invalid path"}) + return + } + + // 2. 限制在上传目录内 + fullPath := filepath.Join("./uploads", filePath) + absUploads, _ := filepath.Abs("./uploads") + absPath, _ := filepath.Abs(fullPath) + if !strings.HasPrefix(absPath, absUploads) { + c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"code": 403, "message": "access denied"}) + return + } + + // 3. 检查文件存在 + if _, err := os.Stat(fullPath); os.IsNotExist(err) { + c.AbortWithStatus(http.StatusNotFound) + return + } + + // 4. 设置安全响应头(禁止浏览器执行) + c.Header("Content-Security-Policy", "default-src 'none'") + c.Header("X-Content-Type-Options", "nosniff") + + // 5. 提供文件 + c.File(fullPath) +} diff --git a/internal/auth/oauth.go b/internal/auth/oauth.go index cd053e3..6803551 100644 --- a/internal/auth/oauth.go +++ b/internal/auth/oauth.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/url" + "time" "github.com/user-management-system/internal/auth/providers" ) @@ -71,6 +72,9 @@ type OAuthManager interface { // ValidateToken 验证令牌 ValidateToken(token string) (bool, error) + // ValidateTokenWithProvider 通过指定 provider 验证令牌 + ValidateTokenWithProvider(ctx context.Context, provider OAuthProvider, token string) (bool, error) + // GetConfig 获取OAuth配置 GetConfig(provider OAuthProvider) (*OAuthConfig, bool) @@ -442,9 +446,11 @@ func (m *DefaultOAuthManager) ValidateToken(token string) (bool, error) { if len(providers) == 0 { return false, errors.New("no OAuth providers configured") } + // 添加 5 秒超时,防止 provider API 无响应导致阻塞 + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() // 尝试任一 provider 的 userinfo 端点验证 tokenObj := &OAuthToken{AccessToken: token} - ctx := context.Background() for _, p := range providers { if _, err := m.GetUserInfo(ctx, p.Provider, tokenObj); err == nil { return true, nil @@ -454,10 +460,13 @@ func (m *DefaultOAuthManager) ValidateToken(token string) (bool, error) { } // ValidateTokenWithProvider 通过指定 provider 验证令牌 -func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider, token string) (bool, error) { +func (m *DefaultOAuthManager) ValidateTokenWithProvider(ctx context.Context, provider OAuthProvider, token string) (bool, error) { if token == "" { return false, nil } + if ctx == nil { + ctx = context.Background() + } cfg, ok := m.GetConfig(provider) if !ok || cfg.ClientID == "" { @@ -466,7 +475,6 @@ func (m *DefaultOAuthManager) ValidateTokenWithProvider(provider OAuthProvider, // 通过 provider 的 userinfo 端点验证 token tokenObj := &OAuthToken{AccessToken: token} - ctx := context.Background() _, err := m.GetUserInfo(ctx, provider, tokenObj) if err != nil { return false, err diff --git a/internal/auth/oauth_test.go b/internal/auth/oauth_test.go index e3c51d4..a85f609 100644 --- a/internal/auth/oauth_test.go +++ b/internal/auth/oauth_test.go @@ -175,15 +175,16 @@ func TestDefaultOAuthManager_ValidateToken(t *testing.T) { func TestDefaultOAuthManager_ValidateTokenWithProvider(t *testing.T) { m := NewOAuthManager() + ctx := context.Background() // Test empty token - valid, err := m.ValidateTokenWithProvider(OAuthProviderGoogle, "") + valid, err := m.ValidateTokenWithProvider(ctx, OAuthProviderGoogle, "") if valid || err != nil { t.Errorf("ValidateTokenWithProvider('') = %v, %v, want false, nil", valid, err) } // Test non-existent provider - valid, err = m.ValidateTokenWithProvider(OAuthProviderGoogle, "some-token") + valid, err = m.ValidateTokenWithProvider(ctx, OAuthProviderGoogle, "some-token") if valid { t.Error("ValidateTokenWithProvider() should return false for unconfigured provider") } @@ -607,7 +608,7 @@ func TestOAuthManager_ValidateTokenWithProvider_WithConfig(t *testing.T) { }) // ValidateTokenWithProvider will try GetUserInfo which will fail - valid, err := m.ValidateTokenWithProvider(OAuthProviderGoogle, "some-token") + valid, err := m.ValidateTokenWithProvider(context.Background(), OAuthProviderGoogle, "some-token") // Should return false if valid { t.Error("ValidateTokenWithProvider() should return false for invalid token") diff --git a/internal/service/auth_oauth_internal_test.go b/internal/service/auth_oauth_internal_test.go index 3d288de..f694e52 100644 --- a/internal/service/auth_oauth_internal_test.go +++ b/internal/service/auth_oauth_internal_test.go @@ -59,6 +59,10 @@ func (m *mockOAuthManager) ValidateToken(token string) (bool, error) { return token != "", nil } +func (m *mockOAuthManager) ValidateTokenWithProvider(ctx context.Context, provider auth.OAuthProvider, token string) (bool, error) { + return token != "", nil +} + func (m *mockOAuthManager) GetConfig(provider auth.OAuthProvider) (*auth.OAuthConfig, bool) { if m.config != nil { return m.config, true