Files

463 lines
13 KiB
Go

package service
import (
"context"
cryptorand "crypto/rand"
"encoding/json"
"fmt"
"log"
"regexp"
"strings"
"time"
aliyunopenapiutil "github.com/alibabacloud-go/darabonba-openapi/v2/utils"
aliyunsms "github.com/alibabacloud-go/dysmsapi-20170525/v5/client"
"github.com/alibabacloud-go/tea/dara"
tccommon "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common"
tcprofile "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/common/profile"
tcsms "github.com/tencentcloud/tencentcloud-sdk-go/tencentcloud/sms/v20210111"
)
var (
	// validPhonePattern accepts either a mainland mobile number (11 digits
	// starting 13–19, optionally prefixed by "+86" or "86") or a generic
	// "+<digits>" number: '+', a non-zero digit, then 6–14 more digits.
	validPhonePattern = regexp.MustCompile(`^((\+86|86)?1[3-9]\d{9}|\+[1-9]\d{6,14})$`)

	// Mainland mobile recognisers: bare 11 digits, "86"-prefixed, and
	// "0086"-prefixed. The prefixed forms capture the bare number in group 1.
	mainlandPhonePattern     = regexp.MustCompile(`^1[3-9]\d{9}$`)
	mainlandPhone86Pattern   = regexp.MustCompile(`^86(1[3-9]\d{9})$`)
	mainlandPhone0086Pattern = regexp.MustCompile(`^0086(1[3-9]\d{9})$`)

	// verificationCodeCharset10 is 10^6, the count of distinct six-digit
	// decimal strings.
	verificationCodeCharset10 = 1000000
)
// SMSProvider is the abstraction over an SMS backend. Each call delivers a
// single verification code to a single phone number. In this file it is
// implemented by the Aliyun, Tencent and mock providers.
type SMSProvider interface {
	// SendVerificationCode delivers code to phone and reports any failure to
	// hand the message to the backend.
	SendVerificationCode(ctx context.Context, phone, code string) error
}
// MockSMSProvider is a test helper and is not wired into the server runtime.
// It sends nothing. It only logs the phone number and a masked form of the code.
type MockSMSProvider struct{}

// SendVerificationCode logs phone and a masked version of code, then returns nil.
//
// Security: never log the full verification code; show only partial
// information for debugging. Codes longer than four characters keep their last
// four characters visible. Codes of four characters or fewer are replaced
// entirely by "****", so even a 4-character code is never logged in full.
func (m *MockSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
	_ = ctx // unused: the mock performs no I/O beyond logging
	maskedCode := "****"
	// Strictly greater than 4: with ">= 4", a 4-character code would get zero
	// stars prepended and be logged verbatim.
	if len(code) > 4 {
		maskedCode = strings.Repeat("*", len(code)-4) + code[len(code)-4:]
	}
	log.Printf("[sms-mock] phone=%s code=%s ttl=5m", phone, maskedCode)
	return nil
}
// aliyunSMSClient is the single Aliyun SDK method AliyunSMSProvider uses. It is
// declared as an interface, presumably so tests can substitute a fake client.
type aliyunSMSClient interface {
	// SendSms submits one SendSms request to the Aliyun SMS API.
	SendSms(request *aliyunsms.SendSmsRequest) (*aliyunsms.SendSmsResponse, error)
}
// tencentSMSClient is the single Tencent Cloud SDK method TencentSMSProvider
// uses. It is declared as an interface, presumably so tests can substitute a
// fake client.
type tencentSMSClient interface {
	// SendSmsWithContext submits one SendSms request, honouring ctx.
	SendSmsWithContext(ctx context.Context, request *tcsms.SendSmsRequest) (*tcsms.SendSmsResponse, error)
}
// AliyunSMSConfig carries the credentials and template settings for the Aliyun
// SMS provider. NewAliyunSMSProvider trims every field and fills defaults
// before use.
type AliyunSMSConfig struct {
	// OpenAPI credentials; both are required.
	AccessKeyID, AccessKeySecret string
	// Approved SMS signature and template; both are required.
	SignName, TemplateCode string
	// Endpoint optionally overrides the API host. Empty leaves it unset on the SDK config.
	Endpoint string
	// RegionID defaults to "cn-hangzhou" when empty.
	RegionID string
	// CodeParamName is the template variable that receives the code; defaults to "code".
	CodeParamName string
}
// AliyunSMSProvider delivers verification codes through Aliyun SMS. Build it
// with NewAliyunSMSProvider.
type AliyunSMSProvider struct {
	cfg    AliyunSMSConfig // normalized configuration, as stored by the constructor
	client aliyunSMSClient // SDK client (or a test double)
}
// NewAliyunSMSProvider normalizes cfg (trims whitespace, applies defaults),
// checks that credentials, signature and template are present, and returns an
// SMSProvider backed by a freshly created Aliyun SDK client.
func NewAliyunSMSProvider(cfg AliyunSMSConfig) (SMSProvider, error) {
	normalized := normalizeAliyunSMSConfig(cfg)
	required := []string{
		normalized.AccessKeyID,
		normalized.AccessKeySecret,
		normalized.SignName,
		normalized.TemplateCode,
	}
	for _, v := range required {
		if v == "" {
			return nil, fmt.Errorf("aliyun SMS config is incomplete")
		}
	}

	sdkClient, err := newAliyunSMSClient(normalized)
	if err != nil {
		return nil, fmt.Errorf("create aliyun SMS client failed: %w", err)
	}
	provider := &AliyunSMSProvider{cfg: normalized, client: sdkClient}
	return provider, nil
}
// newAliyunSMSClient builds the Aliyun SMS SDK client from cfg. An empty
// Endpoint is passed as nil, so the SDK config has no explicit endpoint.
func newAliyunSMSClient(cfg AliyunSMSConfig) (aliyunSMSClient, error) {
	sdkCfg := &aliyunopenapiutil.Config{
		AccessKeyId:     dara.String(cfg.AccessKeyID),
		AccessKeySecret: dara.String(cfg.AccessKeySecret),
		RegionId:        dara.String(cfg.RegionID),
		Endpoint:        stringPointerOrNil(cfg.Endpoint),
	}
	c, err := aliyunsms.NewClient(sdkCfg)
	if err != nil {
		// Return a literal nil so callers never see a typed-nil interface.
		return nil, err
	}
	return c, nil
}
// SendVerificationCode sends code to phone through Aliyun SMS.
//
// The code goes into the template as JSON {CodeParamName: code}. The phone is
// canonicalized with normalizePhoneForSMS. Success requires a response body
// whose Code equals "OK" (case-insensitive). Anything else becomes an error
// carrying Aliyun's code, message and request ID.
func (a *AliyunSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
	_ = ctx // aliyunSMSClient exposes no context-aware send method

	params := map[string]string{a.cfg.CodeParamName: code}
	paramJSON, err := json.Marshal(params)
	if err != nil {
		return fmt.Errorf("marshal aliyun SMS template param failed: %w", err)
	}

	request := new(aliyunsms.SendSmsRequest)
	request.SetPhoneNumbers(normalizePhoneForSMS(phone))
	request.SetSignName(a.cfg.SignName)
	request.SetTemplateCode(a.cfg.TemplateCode)
	request.SetTemplateParam(string(paramJSON))

	resp, err := a.client.SendSms(request)
	if err != nil {
		return fmt.Errorf("aliyun SMS request failed: %w", err)
	}
	if resp == nil || resp.Body == nil {
		return fmt.Errorf("aliyun SMS returned empty response")
	}

	result := resp.Body
	resultCode := dara.StringValue(result.Code)
	if strings.EqualFold(resultCode, "OK") {
		return nil
	}
	return fmt.Errorf(
		"aliyun SMS rejected: code=%s message=%s request_id=%s",
		valueOrDefault(resultCode, "unknown"),
		valueOrDefault(dara.StringValue(result.Message), "unknown"),
		valueOrDefault(dara.StringValue(result.RequestId), "unknown"),
	)
}
// TencentSMSConfig carries the credentials and template settings for the
// Tencent Cloud SMS provider. NewTencentSMSProvider trims every field and
// fills defaults before use.
type TencentSMSConfig struct {
	// API credentials; both are required.
	SecretID, SecretKey string
	// AppID is the SMS SDK application ID; required.
	AppID string
	// Approved SMS signature and template ID; both are required.
	SignName, TemplateID string
	// Region defaults to "ap-guangzhou" when empty.
	Region string
	// Endpoint optionally overrides the API host; empty keeps the SDK default.
	Endpoint string
}
// TencentSMSProvider delivers verification codes through Tencent Cloud SMS.
// Build it with NewTencentSMSProvider.
type TencentSMSProvider struct {
	cfg    TencentSMSConfig // normalized configuration, as stored by the constructor
	client tencentSMSClient // SDK client (or a test double)
}
// NewTencentSMSProvider normalizes cfg (trims whitespace, applies defaults),
// checks that credentials, app ID, signature and template are present, and
// returns an SMSProvider backed by a freshly created Tencent Cloud SDK client.
func NewTencentSMSProvider(cfg TencentSMSConfig) (SMSProvider, error) {
	normalized := normalizeTencentSMSConfig(cfg)
	complete := normalized.SecretID != "" &&
		normalized.SecretKey != "" &&
		normalized.AppID != "" &&
		normalized.SignName != "" &&
		normalized.TemplateID != ""
	if !complete {
		return nil, fmt.Errorf("tencent SMS config is incomplete")
	}

	sdkClient, err := newTencentSMSClient(normalized)
	if err != nil {
		return nil, fmt.Errorf("create tencent SMS client failed: %w", err)
	}
	provider := &TencentSMSProvider{cfg: normalized, client: sdkClient}
	return provider, nil
}
// newTencentSMSClient builds the Tencent Cloud SMS SDK client for cfg.Region.
// It sets the HTTP profile's ReqTimeout to 30 (the SDK documents this field in
// seconds; confirm) and overrides the endpoint only when cfg.Endpoint is set.
func newTencentSMSClient(cfg TencentSMSConfig) (tencentSMSClient, error) {
	profile := tcprofile.NewClientProfile()
	profile.HttpProfile.ReqTimeout = 30
	if endpoint := cfg.Endpoint; endpoint != "" {
		profile.HttpProfile.Endpoint = endpoint
	}

	credential := tccommon.NewCredential(cfg.SecretID, cfg.SecretKey)
	c, err := tcsms.NewClient(credential, cfg.Region, profile)
	if err != nil {
		// Return a literal nil so callers never see a typed-nil interface.
		return nil, err
	}
	return c, nil
}
// SendVerificationCode sends code to phone through Tencent Cloud SMS.
//
// The phone is canonicalized with normalizePhoneForSMS, and code is the only
// template parameter. Only the first entry of SendStatusSet is inspected.
// Success requires its Code to equal "Ok" (case-insensitive). An empty
// response, an empty status list, or any other status code becomes an error
// that includes the request ID.
func (t *TencentSMSProvider) SendVerificationCode(ctx context.Context, phone, code string) error {
	request := tcsms.NewSendSmsRequest()
	request.SmsSdkAppId = tccommon.StringPtr(t.cfg.AppID)
	request.SignName = tccommon.StringPtr(t.cfg.SignName)
	request.TemplateId = tccommon.StringPtr(t.cfg.TemplateID)
	request.PhoneNumberSet = []*string{tccommon.StringPtr(normalizePhoneForSMS(phone))}
	request.TemplateParamSet = []*string{tccommon.StringPtr(code)}

	resp, err := t.client.SendSmsWithContext(ctx, request)
	if err != nil {
		return fmt.Errorf("tencent SMS request failed: %w", err)
	}
	if resp == nil || resp.Response == nil {
		return fmt.Errorf("tencent SMS returned empty response")
	}

	payload := resp.Response
	requestID := valueOrDefault(pointerString(payload.RequestId), "unknown")
	if len(payload.SendStatusSet) == 0 {
		return fmt.Errorf("tencent SMS returned empty status list: request_id=%s", requestID)
	}

	first := payload.SendStatusSet[0]
	statusCode := pointerString(first.Code)
	if strings.EqualFold(statusCode, "Ok") {
		return nil
	}
	return fmt.Errorf(
		"tencent SMS rejected: code=%s message=%s request_id=%s",
		valueOrDefault(statusCode, "unknown"),
		valueOrDefault(pointerString(first.Message), "unknown"),
		requestID,
	)
}
// SMSCodeConfig tunes SMSCodeService. NewSMSCodeService replaces
// non-positive values with the defaults from DefaultSMSCodeConfig's values.
type SMSCodeConfig struct {
	CodeTTL        time.Duration // lifetime of an issued code
	ResendCooldown time.Duration // minimum gap between sends to one phone
	MaxDailyLimit  int           // sends allowed per phone per calendar day
}
// DefaultSMSCodeConfig returns the stock settings: a 5-minute code lifetime,
// a 1-minute resend cooldown, and at most 10 sends per phone per day.
func DefaultSMSCodeConfig() SMSCodeConfig {
	var defaults SMSCodeConfig
	defaults.CodeTTL = 5 * time.Minute
	defaults.ResendCooldown = 1 * time.Minute
	defaults.MaxDailyLimit = 10
	return defaults
}
// SMSCodeService issues and verifies SMS verification codes. It stores codes,
// cooldown markers and daily counters in a cache. Build it with
// NewSMSCodeService.
type SMSCodeService struct {
	provider SMSProvider    // delivers the generated code
	cache    cacheInterface // holds codes, cooldowns and daily counters
	cfg      SMSCodeConfig  // TTL, cooldown and daily limit
}
// cacheInterface is the minimal key/value cache contract SMSCodeService needs
// for codes, cooldown markers and daily counters.
type cacheInterface interface {
	// Get returns the value stored under key and whether one was found.
	Get(ctx context.Context, key string) (interface{}, bool)
	// Set stores value under key. l1TTL and l2TTL are separate expirations,
	// presumably for two cache tiers (TODO confirm against the implementation).
	Set(ctx context.Context, key string, value interface{}, l1TTL, l2TTL time.Duration) error
	// Delete removes key.
	Delete(ctx context.Context, key string) error
}
// NewSMSCodeService wires a provider and cache into an SMSCodeService. Any
// non-positive field in cfg is replaced by its default: 5-minute CodeTTL,
// 1-minute ResendCooldown, and MaxDailyLimit 10. Nil provider or cache are
// accepted here; SendCode and VerifyCode report them as "not configured".
func NewSMSCodeService(provider SMSProvider, cacheManager cacheInterface, cfg SMSCodeConfig) *SMSCodeService {
	svc := &SMSCodeService{provider: provider, cache: cacheManager, cfg: cfg}
	if svc.cfg.CodeTTL <= 0 {
		svc.cfg.CodeTTL = 5 * time.Minute
	}
	if svc.cfg.ResendCooldown <= 0 {
		svc.cfg.ResendCooldown = time.Minute
	}
	if svc.cfg.MaxDailyLimit <= 0 {
		svc.cfg.MaxDailyLimit = 10
	}
	return svc
}
// SendCodeRequest is the payload for requesting a verification code.
// SendCode uses Purpose to scope the code; Scene is consulted only when
// Purpose is blank.
type SendCodeRequest struct {
	Phone   string `json:"phone" binding:"required"` // target phone number
	Purpose string `json:"purpose"`                  // flow the code is issued for
	Scene   string `json:"scene"`                    // fallback for Purpose
}
// SendCodeResponse tells the client how long the issued code stays valid and
// how long to wait before requesting another one. Both values are in seconds.
type SendCodeResponse struct {
	ExpiresIn int `json:"expires_in"` // code lifetime, seconds
	Cooldown  int `json:"cooldown"`   // resend cooldown, seconds
}
// SendCode generates a verification code for req.Phone, stores it in the
// cache, and delivers it through the configured SMSProvider.
//
// The code's purpose is req.Purpose, falling back to req.Scene when Purpose is
// blank. Sending is throttled per phone number by s.cfg.ResendCooldown and by
// s.cfg.MaxDailyLimit sends per calendar day (server local date). The response
// reports the code TTL and the cooldown, both in seconds.
//
// Rate-limit keys (cooldown and daily counter) use the phone as canonicalized
// by normalizePhoneForSMS. As a result "13800000000", "8613800000000" and
// "+8613800000000", which all pass validation and reach the same handset,
// share one cooldown and one daily counter. Keying on the raw input let a
// caller bypass both limits for a single handset just by switching prefixes.
// The code key still uses the trimmed input verbatim, so it matches the key
// VerifyCode derives.
//
// NOTE(review): the cooldown/daily key format changed, so counters written
// under the old raw-phone keys are not consulted after deploy.
//
// Returns a validation error for a nil request or malformed phone, a
// rate-limit error when throttled, and a wrapped error if code generation,
// caching, or delivery fails. If delivery fails, the stored code and cooldown
// marker are removed so the user can retry immediately. The daily counter is
// not rolled back, so failed deliveries still count toward the daily limit.
func (s *SMSCodeService) SendCode(ctx context.Context, req *SendCodeRequest) (*SendCodeResponse, error) {
	if s == nil || s.provider == nil || s.cache == nil {
		return nil, fmt.Errorf("sms code service is not configured")
	}
	if req == nil {
		return nil, newValidationError("\u8bf7\u6c42\u4e0d\u80fd\u4e3a\u7a7a")
	}
	phone := strings.TrimSpace(req.Phone)
	if !isValidPhone(phone) {
		return nil, newValidationError("\u624b\u673a\u53f7\u7801\u683c\u5f0f\u4e0d\u6b63\u786e")
	}
	purpose := strings.TrimSpace(req.Purpose)
	if purpose == "" {
		purpose = strings.TrimSpace(req.Scene)
	}

	// Throttle on the canonical form so prefix variants of one number share limits.
	limitPhone := normalizePhoneForSMS(phone)

	cooldownKey := fmt.Sprintf("sms_cooldown:%s", limitPhone)
	if _, ok := s.cache.Get(ctx, cooldownKey); ok {
		return nil, newRateLimitError(fmt.Sprintf("\u64cd\u4f5c\u8fc7\u4e8e\u9891\u7e41\uff0c\u8bf7 %d \u79d2\u540e\u518d\u8bd5", int(s.cfg.ResendCooldown.Seconds())))
	}

	dailyKey := fmt.Sprintf("sms_daily:%s:%s", limitPhone, time.Now().Format("2006-01-02"))
	var dailyCount int
	if val, ok := s.cache.Get(ctx, dailyKey); ok {
		if n, ok := intValue(val); ok {
			dailyCount = n
		}
	}
	if dailyCount >= s.cfg.MaxDailyLimit {
		return nil, newRateLimitError(fmt.Sprintf("\u4eca\u65e5\u53d1\u9001\u6b21\u6570\u5df2\u8fbe\u4e0a\u9650\uff08%d\u6b21\uff09\uff0c\u8bf7\u660e\u65e5\u518d\u8bd5", s.cfg.MaxDailyLimit))
	}

	code, err := generateSMSCode()
	if err != nil {
		return nil, fmt.Errorf("generate sms code failed: %w", err)
	}

	// Keyed by the trimmed input (not limitPhone) to stay compatible with VerifyCode.
	codeKey := fmt.Sprintf("sms_code:%s:%s", purpose, phone)
	if err := s.cache.Set(ctx, codeKey, code, s.cfg.CodeTTL, s.cfg.CodeTTL); err != nil {
		return nil, fmt.Errorf("store sms code failed: %w", err)
	}
	// Rollback deletes below are best-effort; the original error is what matters.
	if err := s.cache.Set(ctx, cooldownKey, true, s.cfg.ResendCooldown, s.cfg.ResendCooldown); err != nil {
		_ = s.cache.Delete(ctx, codeKey)
		return nil, fmt.Errorf("store sms cooldown failed: %w", err)
	}
	if err := s.cache.Set(ctx, dailyKey, dailyCount+1, 24*time.Hour, 24*time.Hour); err != nil {
		_ = s.cache.Delete(ctx, codeKey)
		_ = s.cache.Delete(ctx, cooldownKey)
		return nil, fmt.Errorf("store sms daily counter failed: %w", err)
	}

	if err := s.provider.SendVerificationCode(ctx, phone, code); err != nil {
		_ = s.cache.Delete(ctx, codeKey)
		_ = s.cache.Delete(ctx, cooldownKey)
		return nil, fmt.Errorf("\u77ed\u4fe1\u53d1\u9001\u5931\u8d25: %w", err)
	}

	return &SendCodeResponse{
		ExpiresIn: int(s.cfg.CodeTTL.Seconds()),
		Cooldown:  int(s.cfg.ResendCooldown.Seconds()),
	}, nil
}
// VerifyCode checks code against the value SendCode stored for phone and
// purpose, and consumes it on success so it cannot be reused.
//
// phone, purpose and code are whitespace-trimmed before use. It returns an
// error when the service has no cache, the code is empty, no code is stored
// (never sent, expired, consumed, or burned), the code does not match, or the
// matched code cannot be deleted.
//
// Wrong guesses are counted per phone/purpose. Once maxFailedAttempts wrong
// guesses accumulate, the stored code is deleted. This prevents a six-digit
// code from being brute-forced within its TTL; the user must request a new
// code. Counter bookkeeping is best-effort: if the cache rejects a write, the
// function behaves as if no limit existed rather than failing closed.
//
// NOTE(review): the failure counter is not reset by SendCode. Failures against
// a previous code (at most maxFailedAttempts-1, for up to CodeTTL after the
// last miss) therefore carry over to a newly issued one.
func (s *SMSCodeService) VerifyCode(ctx context.Context, phone, purpose, code string) error {
	// Wrong guesses allowed per phone/purpose before the stored code is burned.
	const maxFailedAttempts = 5

	if s == nil || s.cache == nil {
		return fmt.Errorf("sms code service is not configured")
	}
	// Trim before comparing, consistent with the emptiness check and with
	// phone/purpose; stored codes never contain whitespace.
	code = strings.TrimSpace(code)
	if code == "" {
		return fmt.Errorf("\u9a8c\u8bc1\u7801\u4e0d\u80fd\u4e3a\u7a7a")
	}
	phone = strings.TrimSpace(phone)
	purpose = strings.TrimSpace(purpose)
	codeKey := fmt.Sprintf("sms_code:%s:%s", purpose, phone)
	attemptsKey := fmt.Sprintf("sms_code_attempts:%s:%s", purpose, phone)

	val, ok := s.cache.Get(ctx, codeKey)
	if !ok {
		return fmt.Errorf("\u9a8c\u8bc1\u7801\u5df2\u8fc7\u671f\u6216\u4e0d\u5b58\u5728")
	}
	stored, ok := val.(string)
	if !ok || stored != code {
		failures := 1
		if prev, found := s.cache.Get(ctx, attemptsKey); found {
			switch n := prev.(type) {
			case int:
				failures += n
			case int64:
				failures += int(n)
			case float64:
				// In case the cache's second tier decodes numbers as float64 (unverified).
				failures += int(n)
			}
		}
		if failures >= maxFailedAttempts {
			// Burn the code: later guesses see "expired or missing". Delete
			// errors are ignored because the caller is already being rejected;
			// a failed delete only means the code lives until its TTL.
			_ = s.cache.Delete(ctx, codeKey)
			_ = s.cache.Delete(ctx, attemptsKey)
		} else {
			// Best-effort: a failed write just leaves the previous count.
			_ = s.cache.Set(ctx, attemptsKey, failures, s.cfg.CodeTTL, s.cfg.CodeTTL)
		}
		return fmt.Errorf("\u9a8c\u8bc1\u7801\u4e0d\u6b63\u786e")
	}

	if err := s.cache.Delete(ctx, codeKey); err != nil {
		return fmt.Errorf("consume sms code failed: %w", err)
	}
	// Best-effort cleanup; a leftover counter simply expires after CodeTTL.
	_ = s.cache.Delete(ctx, attemptsKey)
	return nil
}
// isValidPhone reports whether phone, ignoring surrounding whitespace, matches
// validPhonePattern.
func isValidPhone(phone string) bool {
	trimmed := strings.TrimSpace(phone)
	return validPhonePattern.MatchString(trimmed)
}
// generateSMSCode returns a cryptographically random six-digit code in the
// range 100000–999999, so it never has a leading zero. Every code in that
// range is equally likely.
//
// Four random bytes form a 32-bit sample. Samples at or above the largest
// multiple of 900000 that fits in 32 bits are rejected and redrawn, which
// removes modulo bias. The previous approach took the sample modulo 10^6 and
// bumped values below 100000 up by 100000, which made 100000–199999 twice as
// likely as every other code. Returns an error only if the system random
// source fails.
func generateSMSCode() (string, error) {
	const (
		minCode = 100000
		span    = 900000 // distinct codes in [100000, 999999]
		// Largest multiple of span representable in 32 bits (4294800000).
		limit = (1 << 32) / span * span
	)
	var buf [4]byte
	for {
		if _, err := cryptorand.Read(buf[:]); err != nil {
			return "", err
		}
		n := uint64(buf[0])<<24 | uint64(buf[1])<<16 | uint64(buf[2])<<8 | uint64(buf[3])
		// Rejection happens with probability ~4e-5, so this loop almost
		// always exits on the first draw.
		if n < limit {
			return fmt.Sprintf("%06d", minCode+n%span), nil
		}
	}
}
// normalizeAliyunSMSConfig returns a copy of cfg with every field trimmed of
// surrounding whitespace. An empty RegionID becomes "cn-hangzhou" and an empty
// CodeParamName becomes "code".
func normalizeAliyunSMSConfig(cfg AliyunSMSConfig) AliyunSMSConfig {
	fields := []*string{
		&cfg.AccessKeyID,
		&cfg.AccessKeySecret,
		&cfg.SignName,
		&cfg.TemplateCode,
		&cfg.Endpoint,
		&cfg.RegionID,
		&cfg.CodeParamName,
	}
	for _, field := range fields {
		*field = strings.TrimSpace(*field)
	}

	if cfg.RegionID == "" {
		cfg.RegionID = "cn-hangzhou"
	}
	if cfg.CodeParamName == "" {
		cfg.CodeParamName = "code"
	}
	return cfg
}
// normalizeTencentSMSConfig returns a copy of cfg with every field trimmed of
// surrounding whitespace. An empty Region becomes "ap-guangzhou".
func normalizeTencentSMSConfig(cfg TencentSMSConfig) TencentSMSConfig {
	fields := []*string{
		&cfg.SecretID,
		&cfg.SecretKey,
		&cfg.AppID,
		&cfg.SignName,
		&cfg.TemplateID,
		&cfg.Region,
		&cfg.Endpoint,
	}
	for _, field := range fields {
		*field = strings.TrimSpace(*field)
	}

	if cfg.Region == "" {
		cfg.Region = "ap-guangzhou"
	}
	return cfg
}
// normalizePhoneForSMS trims phone and rewrites mainland mobile numbers into
// "+86XXXXXXXXXXX" form, accepting bare, "86"- or "0086"-prefixed input.
// Anything else (including numbers already starting with '+') is returned
// trimmed but otherwise unchanged.
func normalizePhoneForSMS(phone string) string {
	p := strings.TrimSpace(phone)
	if mainlandPhonePattern.MatchString(p) {
		return "+86" + p
	}
	if mainlandPhone86Pattern.MatchString(p) {
		return "+" + p
	}
	if mainlandPhone0086Pattern.MatchString(p) {
		// The pattern is anchored on a literal "0086" prefix, so stripping it
		// leaves exactly the captured subscriber number.
		return "+86" + strings.TrimPrefix(p, "0086")
	}
	return p
}
// stringPointerOrNil returns a pointer to value, or nil when value is empty,
// so optional SDK fields can be left unset.
func stringPointerOrNil(value string) *string {
	if value != "" {
		return dara.String(value)
	}
	return nil
}
// pointerString dereferences value, treating a nil pointer as "".
func pointerString(value *string) string {
	var out string
	if value != nil {
		out = *value
	}
	return out
}
// valueOrDefault returns value unchanged unless it is empty or only
// whitespace, in which case fallback is returned.
func valueOrDefault(value, fallback string) string {
	if strings.TrimSpace(value) != "" {
		return value
	}
	return fallback
}