- Replace raw SQL with GORM chain calls in Create/Update/Delete/List - Maintain backward compatibility for *sql.DB construction (wrapped via GORM) - Update only permitted fields in Update to prevent accidental overwrite of binding keys - Add repository-level tests for new implementation Refs: UNFIXED_ISSUES_20260329 social_account_repo GORM refactor
149 lines
4.9 KiB
Go
149 lines
4.9 KiB
Go
package repository
|
||
|
||
import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	gormsqlite "gorm.io/driver/sqlite"
	"gorm.io/gorm"

	"github.com/user-management-system/internal/domain"
)
|
||
|
||
// SocialAccountRepository 社交账号仓库接口
|
||
type SocialAccountRepository interface {
|
||
Create(ctx context.Context, account *domain.SocialAccount) error
|
||
Update(ctx context.Context, account *domain.SocialAccount) error
|
||
Delete(ctx context.Context, id int64) error
|
||
DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error
|
||
GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error)
|
||
GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error)
|
||
GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error)
|
||
List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error)
|
||
}
|
||
|
||
// SocialAccountRepositoryImpl 社交账号仓库实现
|
||
type SocialAccountRepositoryImpl struct {
|
||
db *gorm.DB
|
||
}
|
||
|
||
// NewSocialAccountRepository builds a SocialAccountRepository from either a
// *gorm.DB or a *sql.DB.
//
// The repository itself is implemented purely on GORM. The *sql.DB form is
// kept only for backward compatibility with this package's SQLite tests: the
// raw connection is wrapped with the SQLite dialector unconditionally, so it
// must be a SQLite connection. NOTE(review): recent gorm.io/driver/sqlite
// versions probe the pool with `select sqlite_version()` during Open, which
// should make a non-SQLite *sql.DB fail here — confirm against the pinned
// version. The repository never closes the supplied connection; the caller
// keeps ownership of it.
//
// A nil *gorm.DB / *sql.DB, or any other argument type (including an untyped
// nil, reported as "<nil>"), yields an error.
func NewSocialAccountRepository(db interface{}) (SocialAccountRepository, error) {
	var handle *gorm.DB

	switch conn := db.(type) {
	case *gorm.DB:
		if conn == nil {
			return nil, fmt.Errorf("gorm db is nil")
		}
		handle = conn

	case *sql.DB:
		if conn == nil {
			return nil, fmt.Errorf("sql db is nil")
		}
		dialector := gormsqlite.New(gormsqlite.Config{
			DriverName: "sqlite",
			Conn:       conn,
		})
		wrapped, err := gorm.Open(dialector, &gorm.Config{})
		if err != nil {
			return nil, fmt.Errorf("wrap sql db with gorm failed: %w", err)
		}
		handle = wrapped

	default:
		return nil, fmt.Errorf("unsupported db type: %T", db)
	}

	return &SocialAccountRepositoryImpl{db: handle}, nil
}
|
||
|
||
// Create inserts account as a new row. On success GORM back-fills the
// generated primary key into account. Errors from GORM are returned unwrapped.
func (r *SocialAccountRepositoryImpl) Create(ctx context.Context, account *domain.SocialAccount) error {
	session := r.db.WithContext(ctx)
	result := session.Create(account)
	return result.Error
}
|
||
|
||
// Update persists the mutable profile fields of an existing social account,
// matched by account.ID.
//
// Only union_id, nickname, avatar, gender, email, phone, extra and status are
// written. The binding keys (provider, open_id, user_id) and the primary key
// are never part of the SET list, so a partially populated account cannot
// re-bind or re-key a row. Because a map (not a struct) is passed to Updates,
// zero values ("" / 0) in these fields ARE written to the database.
//
// A nil account returns an error instead of panicking. An ID that matches no
// row is not reported as an error (RowsAffected is not checked).
func (r *SocialAccountRepositoryImpl) Update(ctx context.Context, account *domain.SocialAccount) error {
	// Guard before building the map: dereferencing a nil account below would panic.
	if account == nil {
		return fmt.Errorf("social account is nil")
	}

	updates := map[string]interface{}{
		"union_id": account.UnionID,
		"nickname": account.Nickname,
		"avatar":   account.Avatar,
		"gender":   account.Gender,
		"email":    account.Email,
		"phone":    account.Phone,
		"extra":    account.Extra,
		"status":   account.Status,
	}

	return r.db.WithContext(ctx).
		Model(&domain.SocialAccount{}).
		Where("id = ?", account.ID).
		Updates(updates).Error
}
|
||
|
||
// Delete removes the social account whose primary key is id. Deleting a
// non-existent id is not reported as an error.
// NOTE(review): hard vs. soft delete depends on whether domain.SocialAccount
// embeds gorm.DeletedAt — not visible from this file.
func (r *SocialAccountRepositoryImpl) Delete(ctx context.Context, id int64) error {
	session := r.db.WithContext(ctx)
	result := session.Delete(&domain.SocialAccount{}, id)
	return result.Error
}
|
||
|
||
// DeleteByProviderAndUserID removes every social account row matching both
// provider and userID. Deleting zero rows is not reported as an error.
func (r *SocialAccountRepositoryImpl) DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error {
	scoped := r.db.WithContext(ctx).Where("provider = ? AND user_id = ?", provider, userID)
	result := scoped.Delete(&domain.SocialAccount{})
	return result.Error
}
|
||
|
||
// GetByID returns the social account with primary key id.
//
// When no row matches it returns (nil, nil) rather than an error, so callers
// must nil-check the account. Any other database error is wrapped with %w.
func (r *SocialAccountRepositoryImpl) GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error) {
	var account domain.SocialAccount
	if err := r.db.WithContext(ctx).First(&account, id).Error; err != nil {
		// errors.Is still recognises not-found if a callback/plugin wraps it,
		// which a plain == comparison would miss.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to get social account: %w", err)
	}

	return &account, nil
}
|
||
|
||
// GetByUserID lists all social accounts bound to userID, most recently
// created first. A user with no bindings is not an error.
func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error) {
	var result []*domain.SocialAccount

	q := r.db.WithContext(ctx).
		Where("user_id = ?", userID).
		Order("created_at DESC")

	if err := q.Find(&result).Error; err != nil {
		return nil, fmt.Errorf("failed to query social accounts: %w", err)
	}
	return result, nil
}
|
||
|
||
// GetByProviderAndOpenID looks up the social account bound to openID on the
// given provider.
//
// When no row matches it returns (nil, nil) rather than an error, so callers
// must nil-check the account. Any other database error is wrapped with %w.
func (r *SocialAccountRepositoryImpl) GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error) {
	var account domain.SocialAccount
	if err := r.db.WithContext(ctx).
		Where("provider = ? AND open_id = ?", provider, openID).
		First(&account).Error; err != nil {
		// errors.Is still recognises not-found if a callback/plugin wraps it.
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil, nil
		}
		return nil, fmt.Errorf("failed to get social account: %w", err)
	}

	return &account, nil
}
|
||
|
||
// List returns one page of social accounts, newest first, together with the
// total number of rows in the table.
//
// offset and limit are passed unchanged to GORM's Offset/Limit. The count and
// the page are read by two separate queries outside a transaction, so total
// may drift from the page contents under concurrent writes.
func (r *SocialAccountRepositoryImpl) List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error) {
	var (
		accounts []*domain.SocialAccount
		total    int64
	)

	// WithContext yields a new-session handle that is safe to branch from.
	// Each query builds its own chain from it instead of reusing one chained
	// *gorm.DB across two finishers, which GORM documents as unsafe (state
	// from the Count statement can leak into the Find).
	session := r.db.WithContext(ctx)

	if err := session.Model(&domain.SocialAccount{}).Count(&total).Error; err != nil {
		return nil, 0, fmt.Errorf("failed to count social accounts: %w", err)
	}

	// id DESC breaks ties between rows sharing created_at so offset paging is
	// deterministic and does not repeat or skip rows across pages.
	if err := session.Model(&domain.SocialAccount{}).
		Order("created_at DESC, id DESC").
		Offset(offset).
		Limit(limit).
		Find(&accounts).Error; err != nil {
		return nil, 0, fmt.Errorf("failed to query social accounts: %w", err)
	}

	return accounts, total, nil
}
|