463 lines
13 KiB
Go
463 lines
13 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
cryptorand "crypto/rand"
|
|
"encoding/json"
|
|
"fmt"
|
|
"log"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
|
|
aliyunopenapiutil "github.com/alibabacloud-go/darabonba-openapi/v2/utils"
|
|
aliyunsms "github.com/alibabacloud-go/dysmsapi-20170525/v5/client"
|
|
"github.com/alibabacloud-go/tea/dara"
|
|
tccommon "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
|
|
tcprofile "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
|
|
tcsms "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/sms/v20210111"
|
|
)
|
|
|
|
var (
	// validPhonePattern accepts either a mainland China mobile number
	// (11 digits starting with 1[3-9], optionally prefixed with "+86" or "86")
	// or a "+"-prefixed international number of 7 to 15 digits.
	validPhonePattern = regexp.MustCompile(`^((\+86|86)?1[3-9]\d{9}|\+[1-9]\d{6,14})$`)
	// mainlandPhonePattern matches a bare 11-digit mainland mobile number.
	mainlandPhonePattern = regexp.MustCompile(`^1[3-9]\d{9}$`)
	// mainlandPhone86Pattern matches a mainland mobile number prefixed with
	// "86" (no plus sign); group 1 is the 11-digit national number.
	mainlandPhone86Pattern = regexp.MustCompile(`^86(1[3-9]\d{9})$`)
	// mainlandPhone0086Pattern matches a mainland mobile number prefixed with
	// "0086"; group 1 is the 11-digit national number.
	mainlandPhone0086Pattern = regexp.MustCompile(`^0086(1[3-9]\d{9})$`)
	// verificationCodeCharset10 is the exclusive upper bound (10^6) of the
	// numeric value behind a generated verification code, i.e. codes have at
	// most six decimal digits.
	verificationCodeCharset10 = 1000000
)
|
|
|
|
// SMSProvider sends one verification code to one phone number.
//
// The implementations in this file return nil once the upstream service has
// accepted the message and a non-nil error otherwise.
type SMSProvider interface {
	SendVerificationCode(ctx context.Context, phone, code string) error
}
|
|
|
|
// MockSMSProvider is a test helper and is not wired into the server runtime.
// Instead of contacting an SMS service it writes one masked log line per call.
type MockSMSProvider struct{}

// SendVerificationCode logs phone together with a masked form of code and
// always returns nil.
//
// Only the last two characters of the code are ever shown; codes shorter than
// four characters are replaced by a fixed "****". The previous implementation
// revealed the last four characters, which exposed a four-character code in
// full.
func (m *MockSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
	_ = ctx // nothing in the mock blocks, so the context is unused

	// Security: never write the complete verification code to the log; keep
	// just enough of the tail to correlate a log line while debugging.
	const visibleSuffix = 2

	maskedCode := "****"
	if len(code) >= 4 {
		hidden := len(code) - visibleSuffix
		maskedCode = strings.Repeat("*", hidden) + code[hidden:]
	}

	log.Printf("[sms-mock] phone=%s code=%s ttl=5m", phone, maskedCode)
	return nil
}
|
|
|
|
// aliyunSMSClient is the single Aliyun SDK call AliyunSMSProvider depends on.
// The value returned by aliyunsms.NewClient is used as the implementation.
type aliyunSMSClient interface {
	SendSms(request *aliyunsms.SendSmsRequest) (*aliyunsms.SendSmsResponse, error)
}
|
|
|
|
// tencentSMSClient is the single Tencent Cloud SDK call TencentSMSProvider
// depends on. The value returned by tcsms.NewClient is used as the
// implementation.
type tencentSMSClient interface {
	SendSmsWithContext(ctx context.Context, request *tcsms.SendSmsRequest) (*tcsms.SendSmsResponse, error)
}
|
|
|
|
// AliyunSMSConfig carries the settings needed to send verification codes via
// Aliyun SMS. NewAliyunSMSProvider requires the credentials, SignName and
// TemplateCode to be non-empty; normalizeAliyunSMSConfig fills RegionID with
// "cn-hangzhou" and CodeParamName with "code" when they are blank.
type AliyunSMSConfig struct {
	// Access key pair used to build the SDK client.
	AccessKeyID, AccessKeySecret string
	// Signature name and template code attached to every SendSms request.
	SignName, TemplateCode string
	// Optional endpoint override (omitted from the SDK config when empty) and
	// the region passed to the SDK.
	Endpoint, RegionID string
	// JSON key under which the code is placed in the template parameters.
	CodeParamName string
}
|
|
|
|
// AliyunSMSProvider is the SMSProvider implementation backed by Aliyun SMS.
// Instances are created by NewAliyunSMSProvider.
type AliyunSMSProvider struct {
	cfg    AliyunSMSConfig // normalized configuration (sign name, template, ...)
	client aliyunSMSClient // SDK client used to issue SendSms requests
}
|
|
|
|
func NewAliyunSMSProvider(cfg AliyunSMSConfig) (SMSProvider, error) {
|
|
cfg = normalizeAliyunSMSConfig(cfg)
|
|
if cfg.AccessKeyID == "" || cfg.AccessKeySecret == "" || cfg.SignName == "" || cfg.TemplateCode == "" {
|
|
return nil, fmt.Errorf("aliyun SMS config is incomplete")
|
|
}
|
|
|
|
client, err := newAliyunSMSClient(cfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create aliyun SMS client failed: %w", err)
|
|
}
|
|
|
|
return &AliyunSMSProvider{
|
|
cfg: cfg,
|
|
client: client,
|
|
}, nil
|
|
}
|
|
|
|
func newAliyunSMSClient(cfg AliyunSMSConfig) (aliyunSMSClient, error) {
|
|
client, err := aliyunsms.NewClient(&aliyunopenapiutil.Config{
|
|
AccessKeyId: dara.String(cfg.AccessKeyID),
|
|
AccessKeySecret: dara.String(cfg.AccessKeySecret),
|
|
Endpoint: stringPointerOrNil(cfg.Endpoint),
|
|
RegionId: dara.String(cfg.RegionID),
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return client, nil
|
|
}
|
|
|
|
func (a *AliyunSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
|
|
_ = ctx
|
|
|
|
templateParam, err := json.Marshal(map[string]string{
|
|
a.cfg.CodeParamName: code,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("marshal aliyun SMS template param failed: %w", err)
|
|
}
|
|
|
|
resp, err := a.client.SendSms(
|
|
new(aliyunsms.SendSmsRequest).
|
|
SetPhoneNumbers(normalizePhoneForSMS(phone)).
|
|
SetSignName(a.cfg.SignName).
|
|
SetTemplateCode(a.cfg.TemplateCode).
|
|
SetTemplateParam(string(templateParam)),
|
|
)
|
|
if err != nil {
|
|
return fmt.Errorf("aliyun SMS request failed: %w", err)
|
|
}
|
|
if resp == nil || resp.Body == nil {
|
|
return fmt.Errorf("aliyun SMS returned empty response")
|
|
}
|
|
|
|
body := resp.Body
|
|
if !strings.EqualFold(dara.StringValue(body.Code), "OK") {
|
|
return fmt.Errorf(
|
|
"aliyun SMS rejected: code=%s message=%s request_id=%s",
|
|
valueOrDefault(dara.StringValue(body.Code), "unknown"),
|
|
valueOrDefault(dara.StringValue(body.Message), "unknown"),
|
|
valueOrDefault(dara.StringValue(body.RequestId), "unknown"),
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// TencentSMSConfig carries the settings needed to send verification codes via
// Tencent Cloud SMS. NewTencentSMSProvider requires the credentials, AppID,
// SignName and TemplateID to be non-empty; normalizeTencentSMSConfig fills
// Region with "ap-guangzhou" when it is blank.
type TencentSMSConfig struct {
	// Credential pair used to build the SDK client.
	SecretID, SecretKey string
	// SMS application ID, signature name and template ID attached to every
	// SendSms request.
	AppID, SignName, TemplateID string
	// Region passed to the SDK client and an optional endpoint override
	// (applied only when non-empty).
	Region, Endpoint string
}
|
|
|
|
// TencentSMSProvider is the SMSProvider implementation backed by Tencent Cloud
// SMS. Instances are created by NewTencentSMSProvider.
type TencentSMSProvider struct {
	cfg    TencentSMSConfig // normalized configuration (app ID, sign name, template, ...)
	client tencentSMSClient // SDK client used to issue SendSms requests
}
|
|
|
|
func NewTencentSMSProvider(cfg TencentSMSConfig) (SMSProvider, error) {
|
|
cfg = normalizeTencentSMSConfig(cfg)
|
|
if cfg.SecretID == "" || cfg.SecretKey == "" || cfg.AppID == "" || cfg.SignName == "" || cfg.TemplateID == "" {
|
|
return nil, fmt.Errorf("tencent SMS config is incomplete")
|
|
}
|
|
|
|
client, err := newTencentSMSClient(cfg)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("create tencent SMS client failed: %w", err)
|
|
}
|
|
|
|
return &TencentSMSProvider{
|
|
cfg: cfg,
|
|
client: client,
|
|
}, nil
|
|
}
|
|
|
|
func newTencentSMSClient(cfg TencentSMSConfig) (tencentSMSClient, error) {
|
|
clientProfile := tcprofile.NewClientProfile()
|
|
clientProfile.HttpProfile.ReqTimeout = 30
|
|
if cfg.Endpoint != "" {
|
|
clientProfile.HttpProfile.Endpoint = cfg.Endpoint
|
|
}
|
|
|
|
client, err := tcsms.NewClient(
|
|
tccommon.NewCredential(cfg.SecretID, cfg.SecretKey),
|
|
cfg.Region,
|
|
clientProfile,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return client, nil
|
|
}
|
|
|
|
func (t *TencentSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
|
|
req := tcsms.NewSendSmsRequest()
|
|
req.PhoneNumberSet = []*string{tccommon.StringPtr(normalizePhoneForSMS(phone))}
|
|
req.SmsSdkAppId = tccommon.StringPtr(t.cfg.AppID)
|
|
req.SignName = tccommon.StringPtr(t.cfg.SignName)
|
|
req.TemplateId = tccommon.StringPtr(t.cfg.TemplateID)
|
|
req.TemplateParamSet = []*string{tccommon.StringPtr(code)}
|
|
|
|
resp, err := t.client.SendSmsWithContext(ctx, req)
|
|
if err != nil {
|
|
return fmt.Errorf("tencent SMS request failed: %w", err)
|
|
}
|
|
if resp == nil || resp.Response == nil {
|
|
return fmt.Errorf("tencent SMS returned empty response")
|
|
}
|
|
if len(resp.Response.SendStatusSet) == 0 {
|
|
return fmt.Errorf(
|
|
"tencent SMS returned empty status list: request_id=%s",
|
|
valueOrDefault(pointerString(resp.Response.RequestId), "unknown"),
|
|
)
|
|
}
|
|
|
|
status := resp.Response.SendStatusSet[0]
|
|
if !strings.EqualFold(pointerString(status.Code), "Ok") {
|
|
return fmt.Errorf(
|
|
"tencent SMS rejected: code=%s message=%s request_id=%s",
|
|
valueOrDefault(pointerString(status.Code), "unknown"),
|
|
valueOrDefault(pointerString(status.Message), "unknown"),
|
|
valueOrDefault(pointerString(resp.Response.RequestId), "unknown"),
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// SMSCodeConfig tunes how SMSCodeService issues verification codes.
type SMSCodeConfig struct {
	// CodeTTL is how long a stored code stays valid; ResendCooldown is the
	// minimum interval between two sends to the same phone number.
	CodeTTL, ResendCooldown time.Duration
	// MaxDailyLimit caps the number of sends per phone number per day.
	MaxDailyLimit int
}

// DefaultSMSCodeConfig returns the stock settings: codes valid for five
// minutes, one minute between sends, and at most ten sends per day.
func DefaultSMSCodeConfig() SMSCodeConfig {
	var cfg SMSCodeConfig
	cfg.CodeTTL = 5 * time.Minute
	cfg.ResendCooldown = time.Minute
	cfg.MaxDailyLimit = 10
	return cfg
}
|
|
|
|
// SMSCodeService issues and verifies SMS verification codes. Codes are
// delivered through provider, while the codes themselves, resend cooldown
// markers and per-day send counters are kept in cache.
type SMSCodeService struct {
	provider SMSProvider    // delivers the generated code to the phone
	cache    cacheInterface // stores codes, cooldown markers and daily counters
	cfg      SMSCodeConfig  // code TTL, resend cooldown and daily limit
}
|
|
|
|
// cacheInterface is the subset of cache operations SMSCodeService relies on.
type cacheInterface interface {
	// Get returns the value stored under key and whether it was found.
	Get(ctx context.Context, key string) (interface{}, bool)
	// Set stores value under key with the given lifetimes. SMSCodeService
	// always passes the same duration for l1TTL and l2TTL.
	// NOTE(review): the parameter names suggest a two-tier cache — confirm
	// against the concrete implementation.
	Set(ctx context.Context, key string, value interface{}, l1TTL, l2TTL time.Duration) error
	// Delete removes key.
	Delete(ctx context.Context, key string) error
}
|
|
|
|
func NewSMSCodeService(provider SMSProvider, cacheManager cacheInterface, cfg SMSCodeConfig) *SMSCodeService {
|
|
if cfg.CodeTTL <= 0 {
|
|
cfg.CodeTTL = 5 * time.Minute
|
|
}
|
|
if cfg.ResendCooldown <= 0 {
|
|
cfg.ResendCooldown = time.Minute
|
|
}
|
|
if cfg.MaxDailyLimit <= 0 {
|
|
cfg.MaxDailyLimit = 10
|
|
}
|
|
|
|
return &SMSCodeService{
|
|
provider: provider,
|
|
cache: cacheManager,
|
|
cfg: cfg,
|
|
}
|
|
}
|
|
|
|
// SendCodeRequest is the input to SMSCodeService.SendCode.
type SendCodeRequest struct {
	// Phone is the destination number; SendCode trims it and validates it
	// with isValidPhone.
	Phone string `json:"phone" binding:"required"`
	// Purpose names what the code is for and becomes part of the cache key
	// the code is stored under.
	Purpose string `json:"purpose"`
	// Scene is used in place of Purpose when Purpose is blank.
	Scene string `json:"scene"`
}
|
|
|
|
// SendCodeResponse is returned by SMSCodeService.SendCode on success.
type SendCodeResponse struct {
	// ExpiresIn is the code lifetime in whole seconds.
	ExpiresIn int `json:"expires_in"`
	// Cooldown is the resend cooldown in whole seconds.
	Cooldown int `json:"cooldown"`
}
|
|
|
|
func (s *SMSCodeService) SendCode(ctx context.Context, req *SendCodeRequest) (*SendCodeResponse, error) {
|
|
if s == nil || s.provider == nil || s.cache == nil {
|
|
return nil, fmt.Errorf("sms code service is not configured")
|
|
}
|
|
if req == nil {
|
|
return nil, newValidationError("\u8bf7\u6c42\u4e0d\u80fd\u4e3a\u7a7a")
|
|
}
|
|
|
|
phone := strings.TrimSpace(req.Phone)
|
|
if !isValidPhone(phone) {
|
|
return nil, newValidationError("\u624b\u673a\u53f7\u7801\u683c\u5f0f\u4e0d\u6b63\u786e")
|
|
}
|
|
purpose := strings.TrimSpace(req.Purpose)
|
|
if purpose == "" {
|
|
purpose = strings.TrimSpace(req.Scene)
|
|
}
|
|
|
|
cooldownKey := fmt.Sprintf("sms_cooldown:%s", phone)
|
|
if _, ok := s.cache.Get(ctx, cooldownKey); ok {
|
|
return nil, newRateLimitError(fmt.Sprintf("\u64cd\u4f5c\u8fc7\u4e8e\u9891\u7e41\uff0c\u8bf7 %d \u79d2\u540e\u518d\u8bd5", int(s.cfg.ResendCooldown.Seconds())))
|
|
}
|
|
|
|
dailyKey := fmt.Sprintf("sms_daily:%s:%s", phone, time.Now().Format("2006-01-02"))
|
|
var dailyCount int
|
|
if val, ok := s.cache.Get(ctx, dailyKey); ok {
|
|
if n, ok := intValue(val); ok {
|
|
dailyCount = n
|
|
}
|
|
}
|
|
if dailyCount >= s.cfg.MaxDailyLimit {
|
|
return nil, newRateLimitError(fmt.Sprintf("\u4eca\u65e5\u53d1\u9001\u6b21\u6570\u5df2\u8fbe\u4e0a\u9650\uff08%d\u6b21\uff09\uff0c\u8bf7\u660e\u65e5\u518d\u8bd5", s.cfg.MaxDailyLimit))
|
|
}
|
|
|
|
code, err := generateSMSCode()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generate sms code failed: %w", err)
|
|
}
|
|
|
|
codeKey := fmt.Sprintf("sms_code:%s:%s", purpose, phone)
|
|
if err := s.cache.Set(ctx, codeKey, code, s.cfg.CodeTTL, s.cfg.CodeTTL); err != nil {
|
|
return nil, fmt.Errorf("store sms code failed: %w", err)
|
|
}
|
|
if err := s.cache.Set(ctx, cooldownKey, true, s.cfg.ResendCooldown, s.cfg.ResendCooldown); err != nil {
|
|
_ = s.cache.Delete(ctx, codeKey)
|
|
return nil, fmt.Errorf("store sms cooldown failed: %w", err)
|
|
}
|
|
if err := s.cache.Set(ctx, dailyKey, dailyCount+1, 24*time.Hour, 24*time.Hour); err != nil {
|
|
_ = s.cache.Delete(ctx, codeKey)
|
|
_ = s.cache.Delete(ctx, cooldownKey)
|
|
return nil, fmt.Errorf("store sms daily counter failed: %w", err)
|
|
}
|
|
|
|
if err := s.provider.SendVerificationCode(ctx, phone, code); err != nil {
|
|
_ = s.cache.Delete(ctx, codeKey)
|
|
_ = s.cache.Delete(ctx, cooldownKey)
|
|
return nil, fmt.Errorf("\u77ed\u4fe1\u53d1\u9001\u5931\u8d25: %w", err)
|
|
}
|
|
|
|
return &SendCodeResponse{
|
|
ExpiresIn: int(s.cfg.CodeTTL.Seconds()),
|
|
Cooldown: int(s.cfg.ResendCooldown.Seconds()),
|
|
}, nil
|
|
}
|
|
|
|
func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, purpose, code string) error {
|
|
if s == nil || s.cache == nil {
|
|
return fmt.Errorf("sms code service is not configured")
|
|
}
|
|
if strings.TrimSpace(code) == "" {
|
|
return fmt.Errorf("\u9a8c\u8bc1\u7801\u4e0d\u80fd\u4e3a\u7a7a")
|
|
}
|
|
|
|
phone = strings.TrimSpace(phone)
|
|
purpose = strings.TrimSpace(purpose)
|
|
codeKey := fmt.Sprintf("sms_code:%s:%s", purpose, phone)
|
|
val, ok := s.cache.Get(ctx, codeKey)
|
|
if !ok {
|
|
return fmt.Errorf("\u9a8c\u8bc1\u7801\u5df2\u8fc7\u671f\u6216\u4e0d\u5b58\u5728")
|
|
}
|
|
|
|
stored, ok := val.(string)
|
|
if !ok || stored != code {
|
|
return fmt.Errorf("\u9a8c\u8bc1\u7801\u4e0d\u6b63\u786e")
|
|
}
|
|
|
|
if err := s.cache.Delete(ctx, codeKey); err != nil {
|
|
return fmt.Errorf("consume sms code failed: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func isValidPhone(phone string) bool {
|
|
return validPhonePattern.MatchString(strings.TrimSpace(phone))
|
|
}
|
|
|
|
func generateSMSCode() (string, error) {
|
|
b := make([]byte, 4)
|
|
if _, err := cryptorand.Read(b); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
n := int(b[0])<<24 | int(b[1])<<16 | int(b[2])<<8 | int(b[3])
|
|
if n < 0 {
|
|
n = -n
|
|
}
|
|
n = n % verificationCodeCharset10
|
|
if n < 100000 {
|
|
n += 100000
|
|
}
|
|
|
|
return fmt.Sprintf("%06d", n), nil
|
|
}
|
|
|
|
func normalizeAliyunSMSConfig(cfg AliyunSMSConfig) AliyunSMSConfig {
|
|
cfg.AccessKeyID = strings.TrimSpace(cfg.AccessKeyID)
|
|
cfg.AccessKeySecret = strings.TrimSpace(cfg.AccessKeySecret)
|
|
cfg.SignName = strings.TrimSpace(cfg.SignName)
|
|
cfg.TemplateCode = strings.TrimSpace(cfg.TemplateCode)
|
|
cfg.Endpoint = strings.TrimSpace(cfg.Endpoint)
|
|
cfg.RegionID = strings.TrimSpace(cfg.RegionID)
|
|
cfg.CodeParamName = strings.TrimSpace(cfg.CodeParamName)
|
|
|
|
if cfg.RegionID == "" {
|
|
cfg.RegionID = "cn-hangzhou"
|
|
}
|
|
if cfg.CodeParamName == "" {
|
|
cfg.CodeParamName = "code"
|
|
}
|
|
|
|
return cfg
|
|
}
|
|
|
|
func normalizeTencentSMSConfig(cfg TencentSMSConfig) TencentSMSConfig {
|
|
cfg.SecretID = strings.TrimSpace(cfg.SecretID)
|
|
cfg.SecretKey = strings.TrimSpace(cfg.SecretKey)
|
|
cfg.AppID = strings.TrimSpace(cfg.AppID)
|
|
cfg.SignName = strings.TrimSpace(cfg.SignName)
|
|
cfg.TemplateID = strings.TrimSpace(cfg.TemplateID)
|
|
cfg.Region = strings.TrimSpace(cfg.Region)
|
|
cfg.Endpoint = strings.TrimSpace(cfg.Endpoint)
|
|
|
|
if cfg.Region == "" {
|
|
cfg.Region = "ap-guangzhou"
|
|
}
|
|
|
|
return cfg
|
|
}
|
|
|
|
func normalizePhoneForSMS(phone string) string {
|
|
phone = strings.TrimSpace(phone)
|
|
|
|
switch {
|
|
case mainlandPhonePattern.MatchString(phone):
|
|
return "+86" + phone
|
|
case mainlandPhone86Pattern.MatchString(phone):
|
|
return "+" + phone
|
|
case mainlandPhone0086Pattern.MatchString(phone):
|
|
return "+86" + mainlandPhone0086Pattern.ReplaceAllString(phone, "$1")
|
|
default:
|
|
return phone
|
|
}
|
|
}
|
|
|
|
func stringPointerOrNil(value string) *string {
|
|
if value == "" {
|
|
return nil
|
|
}
|
|
return dara.String(value)
|
|
}
|
|
|
|
// pointerString dereferences value, yielding the empty string for a nil
// pointer.
func pointerString(value *string) string {
	var out string
	if value != nil {
		out = *value
	}
	return out
}
|
|
|
|
// valueOrDefault returns value unchanged unless it is empty or consists only
// of whitespace, in which case fallback is returned.
func valueOrDefault(value, fallback string) string {
	if strings.TrimSpace(value) != "" {
		return value
	}
	return fallback
}
|