refactor: migrate SocialAccountRepository to GORM for consistency
- Replace raw SQL with GORM chain calls in Create/Update/Delete/List - Maintain backward compatibility for *sql.DB construction (wrapped via GORM) - Update only permitted fields in Update to prevent accidental overwrite of binding keys - Add repository-level tests for new implementation Refs: UNFIXED_ISSUES_20260329 social_account_repo GORM refactor
This commit is contained in:
@@ -5,8 +5,10 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/user-management-system/internal/domain"
|
gormsqlite "gorm.io/driver/sqlite"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
|
|
||||||
|
"github.com/user-management-system/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SocialAccountRepository 社交账号仓库接口
|
// SocialAccountRepository 社交账号仓库接口
|
||||||
@@ -23,142 +25,78 @@ type SocialAccountRepository interface {
|
|||||||
|
|
||||||
// SocialAccountRepositoryImpl 社交账号仓库实现
|
// SocialAccountRepositoryImpl 社交账号仓库实现
|
||||||
type SocialAccountRepositoryImpl struct {
|
type SocialAccountRepositoryImpl struct {
|
||||||
db *sql.DB
|
db *gorm.DB
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewSocialAccountRepository 创建社交账号仓库(支持 gorm.DB 或 *sql.DB)
|
// NewSocialAccountRepository 创建社交账号仓库。
|
||||||
|
// 仓库主实现统一基于 GORM;保留 *sql.DB 构造兼容仅用于当前仓库的 SQLite 测试场景。
|
||||||
func NewSocialAccountRepository(db interface{}) (SocialAccountRepository, error) {
|
func NewSocialAccountRepository(db interface{}) (SocialAccountRepository, error) {
|
||||||
var sqlDB *sql.DB
|
|
||||||
switch d := db.(type) {
|
switch d := db.(type) {
|
||||||
case *gorm.DB:
|
case *gorm.DB:
|
||||||
var err error
|
if d == nil {
|
||||||
sqlDB, err = d.DB()
|
return nil, fmt.Errorf("gorm db is nil")
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("resolve sql db from gorm db failed: %w", err)
|
|
||||||
}
|
}
|
||||||
|
return &SocialAccountRepositoryImpl{db: d}, nil
|
||||||
case *sql.DB:
|
case *sql.DB:
|
||||||
sqlDB = d
|
if d == nil {
|
||||||
|
return nil, fmt.Errorf("sql db is nil")
|
||||||
|
}
|
||||||
|
gormDB, err := gorm.Open(gormsqlite.New(gormsqlite.Config{
|
||||||
|
Conn: d,
|
||||||
|
DriverName: "sqlite",
|
||||||
|
}), &gorm.Config{})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("wrap sql db with gorm failed: %w", err)
|
||||||
|
}
|
||||||
|
return &SocialAccountRepositoryImpl{db: gormDB}, nil
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported db type: %T", db)
|
return nil, fmt.Errorf("unsupported db type: %T", db)
|
||||||
}
|
}
|
||||||
if sqlDB == nil {
|
|
||||||
return nil, fmt.Errorf("sql db is nil")
|
|
||||||
}
|
|
||||||
return &SocialAccountRepositoryImpl{db: sqlDB}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create 创建社交账号
|
// Create 创建社交账号
|
||||||
func (r *SocialAccountRepositoryImpl) Create(ctx context.Context, account *domain.SocialAccount) error {
|
func (r *SocialAccountRepositoryImpl) Create(ctx context.Context, account *domain.SocialAccount) error {
|
||||||
query := `
|
return r.db.WithContext(ctx).Create(account).Error
|
||||||
INSERT INTO user_social_accounts (user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status)
|
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
|
||||||
`
|
|
||||||
|
|
||||||
result, err := r.db.ExecContext(ctx, query,
|
|
||||||
account.UserID,
|
|
||||||
account.Provider,
|
|
||||||
account.OpenID,
|
|
||||||
account.UnionID,
|
|
||||||
account.Nickname,
|
|
||||||
account.Avatar,
|
|
||||||
account.Gender,
|
|
||||||
account.Email,
|
|
||||||
account.Phone,
|
|
||||||
account.Extra,
|
|
||||||
account.Status,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to create social account: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := result.LastInsertId()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
account.ID = id
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update 更新社交账号
|
// Update 更新社交账号
|
||||||
func (r *SocialAccountRepositoryImpl) Update(ctx context.Context, account *domain.SocialAccount) error {
|
func (r *SocialAccountRepositoryImpl) Update(ctx context.Context, account *domain.SocialAccount) error {
|
||||||
query := `
|
updates := map[string]interface{}{
|
||||||
UPDATE user_social_accounts
|
"union_id": account.UnionID,
|
||||||
SET union_id = ?, nickname = ?, avatar = ?, gender = ?, email = ?, phone = ?, extra = ?, status = ?, updated_at = CURRENT_TIMESTAMP
|
"nickname": account.Nickname,
|
||||||
WHERE id = ?
|
"avatar": account.Avatar,
|
||||||
`
|
"gender": account.Gender,
|
||||||
|
"email": account.Email,
|
||||||
_, err := r.db.ExecContext(ctx, query,
|
"phone": account.Phone,
|
||||||
account.UnionID,
|
"extra": account.Extra,
|
||||||
account.Nickname,
|
"status": account.Status,
|
||||||
account.Avatar,
|
|
||||||
account.Gender,
|
|
||||||
account.Email,
|
|
||||||
account.Phone,
|
|
||||||
account.Extra,
|
|
||||||
account.Status,
|
|
||||||
account.ID,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to update social account: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return r.db.WithContext(ctx).
|
||||||
|
Model(&domain.SocialAccount{}).
|
||||||
|
Where("id = ?", account.ID).
|
||||||
|
Updates(updates).Error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete 删除社交账号
|
// Delete 删除社交账号
|
||||||
func (r *SocialAccountRepositoryImpl) Delete(ctx context.Context, id int64) error {
|
func (r *SocialAccountRepositoryImpl) Delete(ctx context.Context, id int64) error {
|
||||||
query := `DELETE FROM user_social_accounts WHERE id = ?`
|
return r.db.WithContext(ctx).Delete(&domain.SocialAccount{}, id).Error
|
||||||
|
|
||||||
_, err := r.db.ExecContext(ctx, query, id)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to delete social account: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteByProviderAndUserID 删除指定用户和提供商的社交账号
|
// DeleteByProviderAndUserID 删除指定用户和提供商的社交账号
|
||||||
func (r *SocialAccountRepositoryImpl) DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error {
|
func (r *SocialAccountRepositoryImpl) DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error {
|
||||||
query := `DELETE FROM user_social_accounts WHERE provider = ? AND user_id = ?`
|
return r.db.WithContext(ctx).
|
||||||
|
Where("provider = ? AND user_id = ?", provider, userID).
|
||||||
_, err := r.db.ExecContext(ctx, query, provider, userID)
|
Delete(&domain.SocialAccount{}).Error
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to delete social account: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetByID 根据ID获取社交账号
|
// GetByID 根据ID获取社交账号
|
||||||
func (r *SocialAccountRepositoryImpl) GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error) {
|
func (r *SocialAccountRepositoryImpl) GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error) {
|
||||||
query := `
|
|
||||||
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
|
|
||||||
FROM user_social_accounts
|
|
||||||
WHERE id = ?
|
|
||||||
`
|
|
||||||
|
|
||||||
var account domain.SocialAccount
|
var account domain.SocialAccount
|
||||||
err := r.db.QueryRowContext(ctx, query, id).Scan(
|
if err := r.db.WithContext(ctx).First(&account, id).Error; err != nil {
|
||||||
&account.ID,
|
if err == gorm.ErrRecordNotFound {
|
||||||
&account.UserID,
|
return nil, nil
|
||||||
&account.Provider,
|
}
|
||||||
&account.OpenID,
|
|
||||||
&account.UnionID,
|
|
||||||
&account.Nickname,
|
|
||||||
&account.Avatar,
|
|
||||||
&account.Gender,
|
|
||||||
&account.Email,
|
|
||||||
&account.Phone,
|
|
||||||
&account.Extra,
|
|
||||||
&account.Status,
|
|
||||||
&account.CreatedAt,
|
|
||||||
&account.UpdatedAt,
|
|
||||||
)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get social account: %w", err)
|
return nil, fmt.Errorf("failed to get social account: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -167,45 +105,12 @@ func (r *SocialAccountRepositoryImpl) GetByID(ctx context.Context, id int64) (*d
|
|||||||
|
|
||||||
// GetByUserID 根据用户ID获取社交账号列表
|
// GetByUserID 根据用户ID获取社交账号列表
|
||||||
func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error) {
|
func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error) {
|
||||||
query := `
|
|
||||||
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
|
|
||||||
FROM user_social_accounts
|
|
||||||
WHERE user_id = ?
|
|
||||||
ORDER BY created_at DESC
|
|
||||||
`
|
|
||||||
|
|
||||||
rows, err := r.db.QueryContext(ctx, query, userID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to query social accounts: %w", err)
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var accounts []*domain.SocialAccount
|
var accounts []*domain.SocialAccount
|
||||||
for rows.Next() {
|
if err := r.db.WithContext(ctx).
|
||||||
var account domain.SocialAccount
|
Where("user_id = ?", userID).
|
||||||
err := rows.Scan(
|
Order("created_at DESC").
|
||||||
&account.ID,
|
Find(&accounts).Error; err != nil {
|
||||||
&account.UserID,
|
return nil, fmt.Errorf("failed to query social accounts: %w", err)
|
||||||
&account.Provider,
|
|
||||||
&account.OpenID,
|
|
||||||
&account.UnionID,
|
|
||||||
&account.Nickname,
|
|
||||||
&account.Avatar,
|
|
||||||
&account.Gender,
|
|
||||||
&account.Email,
|
|
||||||
&account.Phone,
|
|
||||||
&account.Extra,
|
|
||||||
&account.Status,
|
|
||||||
&account.CreatedAt,
|
|
||||||
&account.UpdatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
accounts = append(accounts, &account)
|
|
||||||
}
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return accounts, nil
|
return accounts, nil
|
||||||
@@ -213,33 +118,13 @@ func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID in
|
|||||||
|
|
||||||
// GetByProviderAndOpenID 根据提供商和OpenID获取社交账号
|
// GetByProviderAndOpenID 根据提供商和OpenID获取社交账号
|
||||||
func (r *SocialAccountRepositoryImpl) GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error) {
|
func (r *SocialAccountRepositoryImpl) GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error) {
|
||||||
query := `
|
|
||||||
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
|
|
||||||
FROM user_social_accounts
|
|
||||||
WHERE provider = ? AND open_id = ?
|
|
||||||
`
|
|
||||||
|
|
||||||
var account domain.SocialAccount
|
var account domain.SocialAccount
|
||||||
err := r.db.QueryRowContext(ctx, query, provider, openID).Scan(
|
if err := r.db.WithContext(ctx).
|
||||||
&account.ID,
|
Where("provider = ? AND open_id = ?", provider, openID).
|
||||||
&account.UserID,
|
First(&account).Error; err != nil {
|
||||||
&account.Provider,
|
if err == gorm.ErrRecordNotFound {
|
||||||
&account.OpenID,
|
return nil, nil
|
||||||
&account.UnionID,
|
}
|
||||||
&account.Nickname,
|
|
||||||
&account.Avatar,
|
|
||||||
&account.Gender,
|
|
||||||
&account.Email,
|
|
||||||
&account.Phone,
|
|
||||||
&account.Extra,
|
|
||||||
&account.Status,
|
|
||||||
&account.CreatedAt,
|
|
||||||
&account.UpdatedAt,
|
|
||||||
)
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to get social account: %w", err)
|
return nil, fmt.Errorf("failed to get social account: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -248,54 +133,16 @@ func (r *SocialAccountRepositoryImpl) GetByProviderAndOpenID(ctx context.Context
|
|||||||
|
|
||||||
// List 分页获取社交账号列表
|
// List 分页获取社交账号列表
|
||||||
func (r *SocialAccountRepositoryImpl) List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error) {
|
func (r *SocialAccountRepositoryImpl) List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error) {
|
||||||
// 获取总数
|
var accounts []*domain.SocialAccount
|
||||||
var total int64
|
var total int64
|
||||||
countQuery := `SELECT COUNT(*) FROM user_social_accounts`
|
query := r.db.WithContext(ctx).Model(&domain.SocialAccount{})
|
||||||
if err := r.db.QueryRowContext(ctx, countQuery).Scan(&total); err != nil {
|
|
||||||
|
if err := query.Count(&total).Error; err != nil {
|
||||||
return nil, 0, fmt.Errorf("failed to count social accounts: %w", err)
|
return nil, 0, fmt.Errorf("failed to count social accounts: %w", err)
|
||||||
}
|
}
|
||||||
|
if err := query.Order("created_at DESC").Offset(offset).Limit(limit).Find(&accounts).Error; err != nil {
|
||||||
// 获取列表
|
|
||||||
query := `
|
|
||||||
SELECT id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at
|
|
||||||
FROM user_social_accounts
|
|
||||||
ORDER BY created_at DESC
|
|
||||||
LIMIT ? OFFSET ?
|
|
||||||
`
|
|
||||||
|
|
||||||
rows, err := r.db.QueryContext(ctx, query, limit, offset)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, fmt.Errorf("failed to query social accounts: %w", err)
|
return nil, 0, fmt.Errorf("failed to query social accounts: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
var accounts []*domain.SocialAccount
|
|
||||||
for rows.Next() {
|
|
||||||
var account domain.SocialAccount
|
|
||||||
err := rows.Scan(
|
|
||||||
&account.ID,
|
|
||||||
&account.UserID,
|
|
||||||
&account.Provider,
|
|
||||||
&account.OpenID,
|
|
||||||
&account.UnionID,
|
|
||||||
&account.Nickname,
|
|
||||||
&account.Avatar,
|
|
||||||
&account.Gender,
|
|
||||||
&account.Email,
|
|
||||||
&account.Phone,
|
|
||||||
&account.Extra,
|
|
||||||
&account.Status,
|
|
||||||
&account.CreatedAt,
|
|
||||||
&account.UpdatedAt,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
accounts = append(accounts, &account)
|
|
||||||
}
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
return nil, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return accounts, total, nil
|
return accounts, total, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -182,6 +182,54 @@ func TestSocialAccountRepository_Update(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSocialAccountRepository_Update_DoesNotRewriteBindingIdentityFields(t *testing.T) {
|
||||||
|
db := setupSocialAccountTestDB(t)
|
||||||
|
repo, err := NewSocialAccountRepository(db)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewSocialAccountRepository() error = %v", err)
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
account := &domain.SocialAccount{
|
||||||
|
UserID: 1,
|
||||||
|
Provider: "github",
|
||||||
|
OpenID: "openid-identity",
|
||||||
|
Nickname: "before-update",
|
||||||
|
Status: domain.SocialAccountStatusActive,
|
||||||
|
}
|
||||||
|
if err := repo.Create(ctx, account); err != nil {
|
||||||
|
t.Fatalf("Create() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
account.UserID = 999
|
||||||
|
account.Provider = "wechat"
|
||||||
|
account.OpenID = "rewritten-openid"
|
||||||
|
account.Nickname = "after-update"
|
||||||
|
if err := repo.Update(ctx, account); err != nil {
|
||||||
|
t.Fatalf("Update() error = %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
found, err := repo.GetByID(ctx, account.ID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetByID() error = %v", err)
|
||||||
|
}
|
||||||
|
if found == nil {
|
||||||
|
t.Fatal("expected social account after update")
|
||||||
|
}
|
||||||
|
if found.UserID != 1 {
|
||||||
|
t.Fatalf("UserID = %d, want 1", found.UserID)
|
||||||
|
}
|
||||||
|
if found.Provider != "github" {
|
||||||
|
t.Fatalf("Provider = %q, want github", found.Provider)
|
||||||
|
}
|
||||||
|
if found.OpenID != "openid-identity" {
|
||||||
|
t.Fatalf("OpenID = %q, want openid-identity", found.OpenID)
|
||||||
|
}
|
||||||
|
if found.Nickname != "after-update" {
|
||||||
|
t.Fatalf("Nickname = %q, want after-update", found.Nickname)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSocialAccountRepository_Delete(t *testing.T) {
|
func TestSocialAccountRepository_Delete(t *testing.T) {
|
||||||
db := setupSocialAccountTestDB(t)
|
db := setupSocialAccountTestDB(t)
|
||||||
repo, err := NewSocialAccountRepository(db)
|
repo, err := NewSocialAccountRepository(db)
|
||||||
|
|||||||
Reference in New Issue
Block a user