Files
user-system/internal/repository/social_account_repo.go
Your Name 8a45548ed8 refactor: migrate SocialAccountRepository to GORM for consistency
- Replace raw SQL with GORM chain calls in Create/Update/Delete/List
- Maintain backward compatibility for *sql.DB construction (wrapped via GORM)
- Update only permitted fields in Update to prevent accidental overwrite of binding keys
- Add repository-level tests for new implementation

Refs: UNFIXED_ISSUES_20260329 social_account_repo GORM refactor
2026-05-29 12:31:48 +08:00

149 lines
4.9 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package repository
import (
"context"
"database/sql"
"fmt"
gormsqlite "gorm.io/driver/sqlite"
"gorm.io/gorm"
"github.com/user-management-system/internal/domain"
)
// SocialAccountRepository defines persistence operations for social accounts,
// i.e. external-provider identities (provider + open ID) bound to a user.
type SocialAccountRepository interface {
	// Create inserts a new social account.
	Create(ctx context.Context, account *domain.SocialAccount) error
	// Update writes the mutable fields of the account identified by account.ID.
	Update(ctx context.Context, account *domain.SocialAccount) error
	// Delete removes the account with the given ID.
	Delete(ctx context.Context, id int64) error
	// DeleteByProviderAndUserID removes the user's accounts for one provider.
	DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error

	// GetByID loads one account by ID. The GORM implementation returns
	// (nil, nil) when no row matches.
	GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error)
	// GetByUserID lists every account bound to the given user.
	GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error)
	// GetByProviderAndOpenID looks up an account by provider and open ID. The
	// GORM implementation returns (nil, nil) when no row matches.
	GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error)
	// List returns one page of accounts plus the total row count.
	List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error)
}
// SocialAccountRepositoryImpl is the GORM-backed implementation of
// SocialAccountRepository.
type SocialAccountRepositoryImpl struct {
	// db is the base GORM handle; every method derives a request-scoped
	// session from it with WithContext.
	db *gorm.DB
}

// Compile-time check that SocialAccountRepositoryImpl satisfies
// SocialAccountRepository, so a signature drift fails the build here rather
// than at the constructor's return statement.
var _ SocialAccountRepository = (*SocialAccountRepositoryImpl)(nil)
// NewSocialAccountRepository 创建社交账号仓库。
// 仓库主实现统一基于 GORM保留 *sql.DB 构造兼容仅用于当前仓库的 SQLite 测试场景。
func NewSocialAccountRepository(db interface{}) (SocialAccountRepository, error) {
switch d := db.(type) {
case *gorm.DB:
if d == nil {
return nil, fmt.Errorf("gorm db is nil")
}
return &SocialAccountRepositoryImpl{db: d}, nil
case *sql.DB:
if d == nil {
return nil, fmt.Errorf("sql db is nil")
}
gormDB, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
Conn: d,
DriverName: "sqlite",
}), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("wrap sql db with gorm failed: %w", err)
}
return &SocialAccountRepositoryImpl{db: gormDB}, nil
default:
return nil, fmt.Errorf("unsupported db type: %T", db)
}
}
// Create inserts account as a new row. The GORM error, if any, is returned
// unwrapped. On success GORM writes generated values back into account
// (e.g. an auto-increment ID, if the model declares one).
func (r *SocialAccountRepositoryImpl) Create(ctx context.Context, account *domain.SocialAccount) error {
	result := r.db.WithContext(ctx).Create(account)
	return result.Error
}
// Update writes the mutable fields of the social account identified by
// account.ID.
//
// Only the whitelisted columns below are written. The binding keys (user_id,
// provider, open_id) are deliberately excluded so an update can never
// re-bind the account. Updating through a map (rather than the struct) makes
// GORM write zero values as well.
//
// A nil account is rejected with an error instead of panicking. An ID that
// matches no row is not reported as an error (RowsAffected is not checked).
func (r *SocialAccountRepositoryImpl) Update(ctx context.Context, account *domain.SocialAccount) error {
	if account == nil {
		return fmt.Errorf("social account is nil")
	}

	updates := map[string]interface{}{
		"union_id": account.UnionID,
		"nickname": account.Nickname,
		"avatar":   account.Avatar,
		"gender":   account.Gender,
		"email":    account.Email,
		"phone":    account.Phone,
		"extra":    account.Extra,
		"status":   account.Status,
	}
	return r.db.WithContext(ctx).
		Model(&domain.SocialAccount{}).
		Where("id = ?", account.ID).
		Updates(updates).Error
}
// Delete removes the social account whose primary key equals id. The GORM
// error, if any, is returned unwrapped.
//
// NOTE(review): if domain.SocialAccount embeds a gorm.DeletedAt field, GORM
// performs a soft delete here. Confirm against the domain model.
func (r *SocialAccountRepositoryImpl) Delete(ctx context.Context, id int64) error {
	session := r.db.WithContext(ctx)
	target := &domain.SocialAccount{}
	return session.Delete(target, id).Error
}
// DeleteByProviderAndUserID removes every social account of userID that is
// bound to provider. Both conditions are always applied, so this can never
// become a table-wide delete.
func (r *SocialAccountRepositoryImpl) DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error {
	// Chained Where calls are AND-ed together by GORM.
	scoped := r.db.WithContext(ctx).
		Where("provider = ?", provider).
		Where("user_id = ?", userID)
	return scoped.Delete(&domain.SocialAccount{}).Error
}
// GetByID loads the social account with the given primary key.
//
// It returns (nil, nil) when no row matches, so callers must nil-check the
// result. Other database errors are wrapped.
//
// The lookup uses Limit(1).Find rather than First. A miss is an expected
// outcome here, and First would surface it as gorm.ErrRecordNotFound, which
// GORM's default logger reports at error level. Detecting the miss through
// RowsAffected also avoids comparing a sentinel error with ==, which breaks
// as soon as anything wraps it.
func (r *SocialAccountRepositoryImpl) GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error) {
	var account domain.SocialAccount
	result := r.db.WithContext(ctx).Limit(1).Find(&account, id)
	if result.Error != nil {
		return nil, fmt.Errorf("failed to get social account: %w", result.Error)
	}
	if result.RowsAffected == 0 {
		return nil, nil
	}
	return &account, nil
}
// GetByUserID returns all social accounts bound to userID, newest first.
//
// Rows with equal created_at are ordered by id DESC, so the order is
// deterministic (and matches List) even when timestamps collide. When nothing
// matches, the result is empty and the error is nil. Query errors are wrapped.
func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error) {
	var accounts []*domain.SocialAccount
	err := r.db.WithContext(ctx).
		Where("user_id = ?", userID).
		Order("created_at DESC, id DESC").
		Find(&accounts).Error
	if err != nil {
		return nil, fmt.Errorf("failed to query social accounts: %w", err)
	}
	return accounts, nil
}
// GetByProviderAndOpenID looks up the social account bound to the given
// provider and open ID.
//
// It returns (nil, nil) when no row matches; other errors are wrapped.
//
// Limit(1).Find with a RowsAffected check replaces First. A miss is a normal
// outcome for this lookup, and First would turn it into
// gorm.ErrRecordNotFound and log it as an error. Order("id") preserves
// First's primary-key ordering if more than one row ever matches.
func (r *SocialAccountRepositoryImpl) GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error) {
	var account domain.SocialAccount
	result := r.db.WithContext(ctx).
		Where("provider = ? AND open_id = ?", provider, openID).
		Order("id").
		Limit(1).
		Find(&account)
	if result.Error != nil {
		return nil, fmt.Errorf("failed to get social account: %w", result.Error)
	}
	if result.RowsAffected == 0 {
		return nil, nil
	}
	return &account, nil
}
// List returns one page of social accounts, newest first, together with the
// total number of accounts.
//
// offset and limit are passed straight to GORM's Offset/Limit. Ordering
// breaks created_at ties by id DESC. Without that tie-breaker, rows sharing a
// timestamp can be skipped or duplicated across pages. Errors from either
// query are wrapped.
func (r *SocialAccountRepositoryImpl) List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error) {
	// A WithContext session can safely start several independent chains.
	// The count and the page query each build their own chain rather than
	// sharing one statement that the Count finisher mutates.
	session := r.db.WithContext(ctx)

	var total int64
	if err := session.Model(&domain.SocialAccount{}).Count(&total).Error; err != nil {
		return nil, 0, fmt.Errorf("failed to count social accounts: %w", err)
	}

	var accounts []*domain.SocialAccount
	if err := session.Model(&domain.SocialAccount{}).
		Order("created_at DESC, id DESC").
		Offset(offset).
		Limit(limit).
		Find(&accounts).Error; err != nil {
		return nil, 0, fmt.Errorf("failed to query social accounts: %w", err)
	}
	return accounts, total, nil
}