package repository

import (
	"context"
	"database/sql"
	"errors"
	"fmt"

	"github.com/user-management-system/internal/domain"
	"gorm.io/gorm"
)

// SocialAccountRepository is the storage interface for social (third-party) accounts
// linked to users.
type SocialAccountRepository interface {
	Create(ctx context.Context, account *domain.SocialAccount) error
	Update(ctx context.Context, account *domain.SocialAccount) error
	Delete(ctx context.Context, id int64) error
	DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error
	GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error)
	GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error)
	GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error)
	List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error)
}

// Compile-time check that the implementation satisfies the interface.
var _ SocialAccountRepository = (*SocialAccountRepositoryImpl)(nil)

// SocialAccountRepositoryImpl implements SocialAccountRepository with hand-written SQL
// against the user_social_accounts table through database/sql. Every statement uses
// `?` placeholders, so the underlying driver must accept that bind style.
type SocialAccountRepositoryImpl struct {
	db *sql.DB
}

// socialAccountColumns is the column list selected by every read query in this file.
// Its order must match the Scan destinations in scanSocialAccount.
const socialAccountColumns = "id, user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status, created_at, updated_at"

// NewSocialAccountRepository creates a social account repository.
//
// db may be either a *gorm.DB (its underlying *sql.DB is extracted) or a *sql.DB.
// Any other type, or a nil handle of either supported type, yields an error.
func NewSocialAccountRepository(db interface{}) (SocialAccountRepository, error) {
	var sqlDB *sql.DB
	switch d := db.(type) {
	case *gorm.DB:
		// A typed-nil *gorm.DB would dereference nil inside d.DB() and panic,
		// so reject it before calling into GORM.
		if d == nil {
			return nil, errors.New("gorm db is nil")
		}
		var err error
		sqlDB, err = d.DB()
		if err != nil {
			return nil, fmt.Errorf("resolve sql db from gorm db failed: %w", err)
		}
	case *sql.DB:
		sqlDB = d
	default:
		return nil, fmt.Errorf("unsupported db type: %T", db)
	}
	if sqlDB == nil {
		return nil, errors.New("sql db is nil")
	}
	return &SocialAccountRepositoryImpl{db: sqlDB}, nil
}

// Create inserts account and, on success, stores the driver-reported LastInsertId in
// account.ID. created_at/updated_at are not written by this statement; presumably
// column defaults fill them — confirm against the schema.
func (r *SocialAccountRepositoryImpl) Create(ctx context.Context, account *domain.SocialAccount) error {
	const query = `
		INSERT INTO user_social_accounts
			(user_id, provider, open_id, union_id, nickname, avatar, gender, email, phone, extra, status)
		VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	result, err := r.db.ExecContext(ctx, query,
		account.UserID,
		account.Provider,
		account.OpenID,
		account.UnionID,
		account.Nickname,
		account.Avatar,
		account.Gender,
		account.Email,
		account.Phone,
		account.Extra,
		account.Status,
	)
	if err != nil {
		return fmt.Errorf("failed to create social account: %w", err)
	}

	// LastInsertId is driver-dependent; surface a failure with context instead of bare.
	id, err := result.LastInsertId()
	if err != nil {
		return fmt.Errorf("failed to get social account insert id: %w", err)
	}
	account.ID = id
	return nil
}

// Update overwrites the mutable profile fields (union_id, nickname, avatar, gender,
// email, phone, extra, status) of the row whose id is account.ID and bumps updated_at.
// user_id, provider and open_id are never changed. Matching zero rows is not treated
// as an error.
func (r *SocialAccountRepositoryImpl) Update(ctx context.Context, account *domain.SocialAccount) error {
	const query = `
		UPDATE user_social_accounts
		SET union_id = ?, nickname = ?, avatar = ?, gender = ?, email = ?, phone = ?, extra = ?, status = ?,
			updated_at = CURRENT_TIMESTAMP
		WHERE id = ?`

	_, err := r.db.ExecContext(ctx, query,
		account.UnionID,
		account.Nickname,
		account.Avatar,
		account.Gender,
		account.Email,
		account.Phone,
		account.Extra,
		account.Status,
		account.ID,
	)
	if err != nil {
		return fmt.Errorf("failed to update social account: %w", err)
	}
	return nil
}

// Delete removes the social account with the given id. Deleting a non-existent id is
// not treated as an error.
func (r *SocialAccountRepositoryImpl) Delete(ctx context.Context, id int64) error {
	const query = `DELETE FROM user_social_accounts WHERE id = ?`
	if _, err := r.db.ExecContext(ctx, query, id); err != nil {
		return fmt.Errorf("failed to delete social account: %w", err)
	}
	return nil
}

// DeleteByProviderAndUserID removes every social account of userID bound to provider
// (e.g. unlinking one login provider from a user). Matching zero rows is not an error.
func (r *SocialAccountRepositoryImpl) DeleteByProviderAndUserID(ctx context.Context, provider string, userID int64) error {
	const query = `DELETE FROM user_social_accounts WHERE provider = ? AND user_id = ?`
	if _, err := r.db.ExecContext(ctx, query, provider, userID); err != nil {
		return fmt.Errorf("failed to delete social account: %w", err)
	}
	return nil
}

// GetByID returns the social account with the given id, or (nil, nil) when no such
// row exists.
func (r *SocialAccountRepositoryImpl) GetByID(ctx context.Context, id int64) (*domain.SocialAccount, error) {
	const query = "SELECT " + socialAccountColumns + " FROM user_social_accounts WHERE id = ?"
	return r.queryOne(ctx, query, id)
}

// GetByUserID returns all social accounts of userID, newest (by created_at) first.
// The result is nil when the user has none.
func (r *SocialAccountRepositoryImpl) GetByUserID(ctx context.Context, userID int64) ([]*domain.SocialAccount, error) {
	const query = "SELECT " + socialAccountColumns +
		" FROM user_social_accounts WHERE user_id = ? ORDER BY created_at DESC"
	return r.queryMany(ctx, query, userID)
}

// GetByProviderAndOpenID returns the social account identified by (provider, openID),
// or (nil, nil) when no such row exists. If several rows match, only the first one
// returned by the database is used.
func (r *SocialAccountRepositoryImpl) GetByProviderAndOpenID(ctx context.Context, provider, openID string) (*domain.SocialAccount, error) {
	const query = "SELECT " + socialAccountColumns +
		" FROM user_social_accounts WHERE provider = ? AND open_id = ?"
	return r.queryOne(ctx, query, provider, openID)
}

// List returns one page of social accounts ordered by created_at descending, together
// with the total number of rows in the table. offset and limit are passed straight to
// the SQL OFFSET/LIMIT clauses without validation. The count and the page are read
// by two separate statements, so they may disagree under concurrent writes.
func (r *SocialAccountRepositoryImpl) List(ctx context.Context, offset, limit int) ([]*domain.SocialAccount, int64, error) {
	var total int64
	const countQuery = `SELECT COUNT(*) FROM user_social_accounts`
	if err := r.db.QueryRowContext(ctx, countQuery).Scan(&total); err != nil {
		return nil, 0, fmt.Errorf("failed to count social accounts: %w", err)
	}

	const query = "SELECT " + socialAccountColumns +
		" FROM user_social_accounts ORDER BY created_at DESC LIMIT ? OFFSET ?"
	accounts, err := r.queryMany(ctx, query, limit, offset)
	if err != nil {
		return nil, 0, err
	}
	return accounts, total, nil
}

// queryOne runs a single-row SELECT of socialAccountColumns. A missing row is
// reported as (nil, nil) to preserve this repository's not-found contract.
func (r *SocialAccountRepositoryImpl) queryOne(ctx context.Context, query string, args ...interface{}) (*domain.SocialAccount, error) {
	account, err := scanSocialAccount(r.db.QueryRowContext(ctx, query, args...))
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("failed to get social account: %w", err)
	}
	return account, nil
}

// queryMany runs a multi-row SELECT of socialAccountColumns and collects every row.
// It returns a nil slice when nothing matches.
func (r *SocialAccountRepositoryImpl) queryMany(ctx context.Context, query string, args ...interface{}) ([]*domain.SocialAccount, error) {
	rows, err := r.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, fmt.Errorf("failed to query social accounts: %w", err)
	}
	defer rows.Close()

	var accounts []*domain.SocialAccount
	for rows.Next() {
		account, err := scanSocialAccount(rows)
		if err != nil {
			return nil, fmt.Errorf("failed to scan social account: %w", err)
		}
		accounts = append(accounts, account)
	}
	// rows.Next returning false can mean an error occurred mid-iteration; without this
	// check a truncated result would be returned as if it were complete.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("failed to iterate social accounts: %w", err)
	}
	return accounts, nil
}

// scanSocialAccount reads one row laid out as socialAccountColumns from either a
// *sql.Row or a *sql.Rows. Errors (including sql.ErrNoRows) are returned unwrapped.
func scanSocialAccount(row interface{ Scan(dest ...interface{}) error }) (*domain.SocialAccount, error) {
	var account domain.SocialAccount
	err := row.Scan(
		&account.ID,
		&account.UserID,
		&account.Provider,
		&account.OpenID,
		&account.UnionID,
		&account.Nickname,
		&account.Avatar,
		&account.Gender,
		&account.Email,
		&account.Phone,
		&account.Extra,
		&account.Status,
		&account.CreatedAt,
		&account.UpdatedAt,
	)
	if err != nil {
		return nil, err
	}
	return &account, nil
}