feat(supply-api): 完成IAM和Audit数据库-backed Repository实现
- 新增 iam_schema_v1.sql DDL脚本 (iam_roles, iam_scopes, iam_role_scopes, iam_user_roles, iam_role_hierarchy) - 新增 PostgresIAMRepository 实现数据库-backed IAM仓储 - 新增 DatabaseIAMService 使用数据库-backed Repository - 新增 PostgresAuditRepository 实现数据库-backed Audit仓储 - 新增 DatabaseAuditService 使用数据库-backed Repository - 更新实施状态文档 v1.3 R-07~R-09 完成。
This commit is contained in:
419
supply-api/internal/audit/repository/audit_repository.go
Normal file
419
supply-api/internal/audit/repository/audit_repository.go
Normal file
@@ -0,0 +1,419 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
)
|
||||
|
||||
// EventFilter 事件查询过滤器(仓储层定义,避免循环依赖)
|
||||
type EventFilter struct {
|
||||
TenantID int64
|
||||
OperatorID int64
|
||||
Category string
|
||||
EventName string
|
||||
StartTime *time.Time
|
||||
EndTime *time.Time
|
||||
Limit int
|
||||
Offset int
|
||||
}
|
||||
|
||||
// AuditRepository 审计事件仓储接口
|
||||
type AuditRepository interface {
|
||||
// Emit 发送审计事件
|
||||
Emit(ctx context.Context, event *model.AuditEvent) error
|
||||
// Query 查询审计事件
|
||||
Query(ctx context.Context, filter *EventFilter) ([]*model.AuditEvent, int64, error)
|
||||
// GetByIdempotencyKey 根据幂等键获取事件
|
||||
GetByIdempotencyKey(ctx context.Context, key string) (*model.AuditEvent, error)
|
||||
}
|
||||
|
||||
// PostgresAuditRepository PostgreSQL实现的审计仓储
|
||||
type PostgresAuditRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewPostgresAuditRepository 创建PostgreSQL审计仓储
|
||||
func NewPostgresAuditRepository(pool *pgxpool.Pool) *PostgresAuditRepository {
|
||||
return &PostgresAuditRepository{pool: pool}
|
||||
}
|
||||
|
||||
// Ensure interface
|
||||
var _ AuditRepository = (*PostgresAuditRepository)(nil)
|
||||
|
||||
// Emit 发送审计事件
|
||||
func (r *PostgresAuditRepository) Emit(ctx context.Context, event *model.AuditEvent) error {
|
||||
// 生成事件ID
|
||||
if event.EventID == "" {
|
||||
event.EventID = uuid.New().String()
|
||||
}
|
||||
|
||||
// 设置时间戳
|
||||
if event.Timestamp.IsZero() {
|
||||
event.Timestamp = time.Now()
|
||||
}
|
||||
event.TimestampMs = event.Timestamp.UnixMilli()
|
||||
|
||||
// 序列化扩展字段
|
||||
var extensionsJSON []byte
|
||||
if event.Extensions != nil {
|
||||
var err error
|
||||
extensionsJSON, err = json.Marshal(event.Extensions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal extensions: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 序列化安全标记
|
||||
securityFlagsJSON, err := json.Marshal(event.SecurityFlags)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal security flags: %w", err)
|
||||
}
|
||||
|
||||
// 序列化状态变更
|
||||
var beforeStateJSON, afterStateJSON []byte
|
||||
if event.BeforeState != nil {
|
||||
beforeStateJSON, err = json.Marshal(event.BeforeState)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal before state: %w", err)
|
||||
}
|
||||
}
|
||||
if event.AfterState != nil {
|
||||
afterStateJSON, err = json.Marshal(event.AfterState)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal after state: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
query := `
|
||||
INSERT INTO audit_events (
|
||||
event_id, event_name, event_category, event_sub_category,
|
||||
timestamp, timestamp_ms,
|
||||
request_id, trace_id, span_id,
|
||||
idempotency_key,
|
||||
operator_id, operator_type, operator_role,
|
||||
tenant_id, tenant_type,
|
||||
object_type, object_id,
|
||||
action, action_detail,
|
||||
credential_type, credential_id, credential_fingerprint,
|
||||
source_type, source_ip, source_region, user_agent,
|
||||
target_type, target_endpoint, target_direct,
|
||||
result_code, result_message, success,
|
||||
before_data, after_data,
|
||||
security_flags, risk_score,
|
||||
compliance_tags, invariant_rule,
|
||||
extensions,
|
||||
version, created_at
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
|
||||
$11, $12, $13, $14, $15, $16, $17, $18, $19, $20,
|
||||
$21, $22, $23, $24, $25, $26, $27, $28, $29, $30,
|
||||
$31, $32, $33, $34, $35, $36, $37, $38, $39, $40, $41
|
||||
)
|
||||
`
|
||||
|
||||
_, err = r.pool.Exec(ctx, query,
|
||||
event.EventID, event.EventName, event.EventCategory, event.EventSubCategory,
|
||||
event.Timestamp, event.TimestampMs,
|
||||
event.RequestID, event.TraceID, event.SpanID,
|
||||
event.IdempotencyKey,
|
||||
event.OperatorID, event.OperatorType, event.OperatorRole,
|
||||
event.TenantID, event.TenantType,
|
||||
event.ObjectType, event.ObjectID,
|
||||
event.Action, event.ActionDetail,
|
||||
event.CredentialType, event.CredentialID, event.CredentialFingerprint,
|
||||
event.SourceType, event.SourceIP, event.SourceRegion, event.UserAgent,
|
||||
event.TargetType, event.TargetEndpoint, event.TargetDirect,
|
||||
event.ResultCode, event.ResultMessage, event.Success,
|
||||
beforeStateJSON, afterStateJSON,
|
||||
securityFlagsJSON, event.RiskScore,
|
||||
event.ComplianceTags, event.InvariantRule,
|
||||
extensionsJSON,
|
||||
1, time.Now(),
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
// 检查幂等键重复
|
||||
if strings.Contains(err.Error(), "idempotency_key") && strings.Contains(err.Error(), "unique") {
|
||||
return ErrDuplicateIdempotencyKey
|
||||
}
|
||||
return fmt.Errorf("failed to emit audit event: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query 查询审计事件
|
||||
func (r *PostgresAuditRepository) Query(ctx context.Context, filter *EventFilter) ([]*model.AuditEvent, int64, error) {
|
||||
// 构建查询条件
|
||||
conditions := []string{}
|
||||
args := []interface{}{}
|
||||
argIndex := 1
|
||||
|
||||
if filter.TenantID != 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("tenant_id = $%d", argIndex))
|
||||
args = append(args, filter.TenantID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.Category != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("event_category = $%d", argIndex))
|
||||
args = append(args, filter.Category)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.EventName != "" {
|
||||
conditions = append(conditions, fmt.Sprintf("event_name = $%d", argIndex))
|
||||
args = append(args, filter.EventName)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.OperatorID != 0 {
|
||||
conditions = append(conditions, fmt.Sprintf("operator_id = $%d", argIndex))
|
||||
args = append(args, filter.OperatorID)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.StartTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("timestamp >= $%d", argIndex))
|
||||
args = append(args, *filter.StartTime)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
if filter.EndTime != nil {
|
||||
conditions = append(conditions, fmt.Sprintf("timestamp <= $%d", argIndex))
|
||||
args = append(args, *filter.EndTime)
|
||||
argIndex++
|
||||
}
|
||||
|
||||
whereClause := ""
|
||||
if len(conditions) > 0 {
|
||||
whereClause = "WHERE " + strings.Join(conditions, " AND ")
|
||||
}
|
||||
|
||||
// 查询总数
|
||||
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM audit_events %s", whereClause)
|
||||
var total int64
|
||||
err := r.pool.QueryRow(ctx, countQuery, args...).Scan(&total)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to count audit events: %w", err)
|
||||
}
|
||||
|
||||
// 查询事件列表
|
||||
limit := filter.Limit
|
||||
if limit <= 0 {
|
||||
limit = 100
|
||||
}
|
||||
if limit > 1000 {
|
||||
limit = 1000
|
||||
}
|
||||
|
||||
offset := filter.Offset
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
event_id, event_name, event_category, event_sub_category,
|
||||
timestamp, timestamp_ms,
|
||||
request_id, trace_id, span_id,
|
||||
idempotency_key,
|
||||
operator_id, operator_type, operator_role,
|
||||
tenant_id, tenant_type,
|
||||
object_type, object_id,
|
||||
action, action_detail,
|
||||
credential_type, credential_id, credential_fingerprint,
|
||||
source_type, source_ip, source_region, user_agent,
|
||||
target_type, target_endpoint, target_direct,
|
||||
result_code, result_message, success,
|
||||
before_data, after_data,
|
||||
security_flags, risk_score,
|
||||
compliance_tags, invariant_rule,
|
||||
extensions,
|
||||
version, created_at
|
||||
FROM audit_events
|
||||
%s
|
||||
ORDER BY timestamp DESC
|
||||
LIMIT $%d OFFSET $%d
|
||||
`, whereClause, argIndex, argIndex+1)
|
||||
|
||||
args = append(args, limit, offset)
|
||||
|
||||
rows, err := r.pool.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to query audit events: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var events []*model.AuditEvent
|
||||
for rows.Next() {
|
||||
event, err := r.scanAuditEvent(rows)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("failed to scan audit event: %w", err)
|
||||
}
|
||||
events = append(events, event)
|
||||
}
|
||||
|
||||
return events, total, nil
|
||||
}
|
||||
|
||||
// GetByIdempotencyKey 根据幂等键获取事件
|
||||
func (r *PostgresAuditRepository) GetByIdempotencyKey(ctx context.Context, key string) (*model.AuditEvent, error) {
|
||||
query := `
|
||||
SELECT
|
||||
event_id, event_name, event_category, event_sub_category,
|
||||
timestamp, timestamp_ms,
|
||||
request_id, trace_id, span_id,
|
||||
idempotency_key,
|
||||
operator_id, operator_type, operator_role,
|
||||
tenant_id, tenant_type,
|
||||
object_type, object_id,
|
||||
action, action_detail,
|
||||
credential_type, credential_id, credential_fingerprint,
|
||||
source_type, source_ip, source_region, user_agent,
|
||||
target_type, target_endpoint, target_direct,
|
||||
result_code, result_message, success,
|
||||
before_data, after_data,
|
||||
security_flags, risk_score,
|
||||
compliance_tags, invariant_rule,
|
||||
extensions,
|
||||
version, created_at
|
||||
FROM audit_events
|
||||
WHERE idempotency_key = $1
|
||||
`
|
||||
|
||||
row := r.pool.QueryRow(ctx, query, key)
|
||||
event, err := r.scanAuditEventRow(row)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get event by idempotency key: %w", err)
|
||||
}
|
||||
|
||||
return event, nil
|
||||
}
|
||||
|
||||
// scanAuditEvent 扫描审计事件行
|
||||
func (r *PostgresAuditRepository) scanAuditEvent(rows pgx.Rows) (*model.AuditEvent, error) {
|
||||
var event model.AuditEvent
|
||||
var eventSubCategory, traceID, spanID, idempotencyKey, operatorRole string
|
||||
var beforeData, afterData, extensions []byte
|
||||
var securityFlagsJSON []byte
|
||||
var complianceTags []string
|
||||
|
||||
err := rows.Scan(
|
||||
&event.EventID, &event.EventName, &event.EventCategory, &eventSubCategory,
|
||||
&event.Timestamp, &event.TimestampMs,
|
||||
&event.RequestID, &traceID, &spanID,
|
||||
&idempotencyKey,
|
||||
&event.OperatorID, &event.OperatorType, &operatorRole,
|
||||
&event.TenantID, &event.TenantType,
|
||||
&event.ObjectType, &event.ObjectID,
|
||||
&event.Action, &event.ActionDetail,
|
||||
&event.CredentialType, &event.CredentialID, &event.CredentialFingerprint,
|
||||
&event.SourceType, &event.SourceIP, &event.SourceRegion, &event.UserAgent,
|
||||
&event.TargetType, &event.TargetEndpoint, &event.TargetDirect,
|
||||
&event.ResultCode, &event.ResultMessage, &event.Success,
|
||||
&beforeData, &afterData,
|
||||
&securityFlagsJSON, &event.RiskScore,
|
||||
&complianceTags, &event.InvariantRule,
|
||||
&extensions,
|
||||
&event.Version, &event.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
event.EventSubCategory = eventSubCategory
|
||||
event.TraceID = traceID
|
||||
event.SpanID = spanID
|
||||
event.IdempotencyKey = idempotencyKey
|
||||
event.OperatorRole = operatorRole
|
||||
event.ComplianceTags = complianceTags
|
||||
|
||||
// 反序列化JSON字段
|
||||
if beforeData != nil {
|
||||
json.Unmarshal(beforeData, &event.BeforeState)
|
||||
}
|
||||
if afterData != nil {
|
||||
json.Unmarshal(afterData, &event.AfterState)
|
||||
}
|
||||
if securityFlagsJSON != nil {
|
||||
json.Unmarshal(securityFlagsJSON, &event.SecurityFlags)
|
||||
}
|
||||
if extensions != nil {
|
||||
json.Unmarshal(extensions, &event.Extensions)
|
||||
}
|
||||
|
||||
return &event, nil
|
||||
}
|
||||
|
||||
// scanAuditEventRow 扫描单行审计事件
|
||||
func (r *PostgresAuditRepository) scanAuditEventRow(row pgx.Row) (*model.AuditEvent, error) {
|
||||
var event model.AuditEvent
|
||||
var eventSubCategory, traceID, spanID, idempotencyKey, operatorRole string
|
||||
var beforeData, afterData, extensions []byte
|
||||
var securityFlagsJSON []byte
|
||||
var complianceTags []string
|
||||
|
||||
err := row.Scan(
|
||||
&event.EventID, &event.EventName, &event.EventCategory, &eventSubCategory,
|
||||
&event.Timestamp, &event.TimestampMs,
|
||||
&event.RequestID, &traceID, &spanID,
|
||||
&idempotencyKey,
|
||||
&event.OperatorID, &event.OperatorType, &operatorRole,
|
||||
&event.TenantID, &event.TenantType,
|
||||
&event.ObjectType, &event.ObjectID,
|
||||
&event.Action, &event.ActionDetail,
|
||||
&event.CredentialType, &event.CredentialID, &event.CredentialFingerprint,
|
||||
&event.SourceType, &event.SourceIP, &event.SourceRegion, &event.UserAgent,
|
||||
&event.TargetType, &event.TargetEndpoint, &event.TargetDirect,
|
||||
&event.ResultCode, &event.ResultMessage, &event.Success,
|
||||
&beforeData, &afterData,
|
||||
&securityFlagsJSON, &event.RiskScore,
|
||||
&complianceTags, &event.InvariantRule,
|
||||
&extensions,
|
||||
&event.Version, &event.CreatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
event.EventSubCategory = eventSubCategory
|
||||
event.TraceID = traceID
|
||||
event.SpanID = spanID
|
||||
event.IdempotencyKey = idempotencyKey
|
||||
event.OperatorRole = operatorRole
|
||||
event.ComplianceTags = complianceTags
|
||||
|
||||
// 反序列化JSON字段
|
||||
if beforeData != nil {
|
||||
json.Unmarshal(beforeData, &event.BeforeState)
|
||||
}
|
||||
if afterData != nil {
|
||||
json.Unmarshal(afterData, &event.AfterState)
|
||||
}
|
||||
if securityFlagsJSON != nil {
|
||||
json.Unmarshal(securityFlagsJSON, &event.SecurityFlags)
|
||||
}
|
||||
if extensions != nil {
|
||||
json.Unmarshal(extensions, &event.Extensions)
|
||||
}
|
||||
|
||||
return &event, nil
|
||||
}
|
||||
|
||||
// errors
|
||||
var (
|
||||
ErrDuplicateIdempotencyKey = errors.New("duplicate idempotency key")
|
||||
)
|
||||
96
supply-api/internal/audit/service/audit_service_db.go
Normal file
96
supply-api/internal/audit/service/audit_service_db.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/audit/model"
|
||||
"lijiaoqiao/supply-api/internal/audit/repository"
|
||||
)
|
||||
|
||||
// DatabaseAuditService 数据库-backed审计服务
|
||||
type DatabaseAuditService struct {
|
||||
repo repository.AuditRepository
|
||||
}
|
||||
|
||||
// NewDatabaseAuditService 创建数据库-backed审计服务
|
||||
func NewDatabaseAuditService(repo repository.AuditRepository) *DatabaseAuditService {
|
||||
return &DatabaseAuditService{repo: repo}
|
||||
}
|
||||
|
||||
// Ensure interface
|
||||
var _ AuditStoreInterface = (*DatabaseAuditService)(nil)
|
||||
|
||||
// Emit 发送审计事件
|
||||
func (s *DatabaseAuditService) Emit(ctx context.Context, event *model.AuditEvent) error {
|
||||
// 验证事件
|
||||
if event == nil {
|
||||
return ErrInvalidInput
|
||||
}
|
||||
if event.EventName == "" {
|
||||
return ErrMissingEventName
|
||||
}
|
||||
|
||||
// 检查幂等键
|
||||
if event.IdempotencyKey != "" {
|
||||
existing, err := s.repo.GetByIdempotencyKey(ctx, event.IdempotencyKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if existing != nil {
|
||||
// 幂等键已存在,检查payload是否一致
|
||||
if isSamePayload(existing, event) {
|
||||
return repository.ErrDuplicateIdempotencyKey
|
||||
}
|
||||
return ErrIdempotencyConflict
|
||||
}
|
||||
}
|
||||
|
||||
// 发送事件
|
||||
if err := s.repo.Emit(ctx, event); err != nil {
|
||||
if errors.Is(err, repository.ErrDuplicateIdempotencyKey) {
|
||||
return repository.ErrDuplicateIdempotencyKey
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Query 查询审计事件
|
||||
func (s *DatabaseAuditService) Query(ctx context.Context, filter *EventFilter) ([]*model.AuditEvent, int64, error) {
|
||||
if filter == nil {
|
||||
filter = &EventFilter{}
|
||||
}
|
||||
// 转换 filter 类型
|
||||
repoFilter := &repository.EventFilter{
|
||||
TenantID: filter.TenantID,
|
||||
Category: filter.Category,
|
||||
EventName: filter.EventName,
|
||||
Limit: filter.Limit,
|
||||
Offset: filter.Offset,
|
||||
}
|
||||
if !filter.StartTime.IsZero() {
|
||||
repoFilter.StartTime = &filter.StartTime
|
||||
}
|
||||
if !filter.EndTime.IsZero() {
|
||||
repoFilter.EndTime = &filter.EndTime
|
||||
}
|
||||
return s.repo.Query(ctx, repoFilter)
|
||||
}
|
||||
|
||||
// GetByIdempotencyKey 根据幂等键获取事件
|
||||
func (s *DatabaseAuditService) GetByIdempotencyKey(ctx context.Context, key string) (*model.AuditEvent, error) {
|
||||
return s.repo.GetByIdempotencyKey(ctx, key)
|
||||
}
|
||||
|
||||
// NewDatabaseAuditServiceWithPool 从数据库连接池创建审计服务
|
||||
func NewDatabaseAuditServiceWithPool(pool interface {
|
||||
Query(ctx context.Context, sql string, args ...interface{}) (interface{}, error)
|
||||
Exec(ctx context.Context, sql string, args ...interface{}) (interface{}, error)
|
||||
}) *DatabaseAuditService {
|
||||
// 注意:这里需要一个适配器来将通用的pool接口转换为pgxpool.Pool
|
||||
// 在实际使用中,应该直接使用 NewDatabaseAuditService(repo)
|
||||
// 这个函数仅用于类型兼容性
|
||||
return nil
|
||||
}
|
||||
599
supply-api/internal/iam/repository/iam_repository.go
Normal file
599
supply-api/internal/iam/repository/iam_repository.go
Normal file
@@ -0,0 +1,599 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"lijiaoqiao/supply-api/internal/iam/model"
|
||||
)
|
||||
|
||||
// errors
|
||||
var (
|
||||
ErrRoleNotFound = errors.New("role not found")
|
||||
ErrDuplicateRoleCode = errors.New("role code already exists")
|
||||
ErrDuplicateAssignment = errors.New("user already has this role")
|
||||
ErrScopeNotFound = errors.New("scope not found")
|
||||
ErrUserRoleNotFound = errors.New("user role not found")
|
||||
)
|
||||
|
||||
// IAMRepository IAM数据仓储接口
|
||||
type IAMRepository interface {
|
||||
// Role operations
|
||||
CreateRole(ctx context.Context, role *model.Role) error
|
||||
GetRoleByCode(ctx context.Context, code string) (*model.Role, error)
|
||||
UpdateRole(ctx context.Context, role *model.Role) error
|
||||
DeleteRole(ctx context.Context, code string) error
|
||||
ListRoles(ctx context.Context, roleType string) ([]*model.Role, error)
|
||||
|
||||
// Scope operations
|
||||
CreateScope(ctx context.Context, scope *model.Scope) error
|
||||
GetScopeByCode(ctx context.Context, code string) (*model.Scope, error)
|
||||
ListScopes(ctx context.Context) ([]*model.Scope, error)
|
||||
|
||||
// Role-Scope operations
|
||||
AddScopeToRole(ctx context.Context, roleCode, scopeCode string) error
|
||||
RemoveScopeFromRole(ctx context.Context, roleCode, scopeCode string) error
|
||||
GetScopesByRoleCode(ctx context.Context, roleCode string) ([]string, error)
|
||||
|
||||
// User-Role operations
|
||||
AssignRole(ctx context.Context, userRole *model.UserRoleMapping) error
|
||||
RevokeRole(ctx context.Context, userID int64, roleCode string, tenantID int64) error
|
||||
GetUserRoles(ctx context.Context, userID int64) ([]*model.UserRoleMapping, error)
|
||||
GetUserRolesWithCode(ctx context.Context, userID int64) ([]*UserRoleWithCode, error)
|
||||
GetUserScopes(ctx context.Context, userID int64) ([]string, error)
|
||||
}
|
||||
|
||||
// PostgresIAMRepository PostgreSQL实现的IAM仓储
|
||||
type PostgresIAMRepository struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
// NewPostgresIAMRepository 创建PostgreSQL IAM仓储
|
||||
func NewPostgresIAMRepository(pool *pgxpool.Pool) *PostgresIAMRepository {
|
||||
return &PostgresIAMRepository{pool: pool}
|
||||
}
|
||||
|
||||
// Ensure interfaces
|
||||
var _ IAMRepository = (*PostgresIAMRepository)(nil)
|
||||
|
||||
// ============ Role Operations ============
|
||||
|
||||
// CreateRole 创建角色
|
||||
func (r *PostgresIAMRepository) CreateRole(ctx context.Context, role *model.Role) error {
|
||||
query := `
|
||||
INSERT INTO iam_roles (code, name, type, parent_role_id, level, description, is_active,
|
||||
request_id, created_ip, updated_ip, version, created_at, updated_at)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
|
||||
`
|
||||
|
||||
var parentID *int64
|
||||
if role.ParentRoleID != nil {
|
||||
parentID = role.ParentRoleID
|
||||
}
|
||||
|
||||
var createdIP, updatedIP interface{}
|
||||
if role.CreatedIP != "" {
|
||||
createdIP = role.CreatedIP
|
||||
}
|
||||
if role.UpdatedIP != "" {
|
||||
updatedIP = role.UpdatedIP
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
if role.CreatedAt == nil {
|
||||
role.CreatedAt = &now
|
||||
}
|
||||
if role.UpdatedAt == nil {
|
||||
role.UpdatedAt = &now
|
||||
}
|
||||
|
||||
_, err := r.pool.Exec(ctx, query,
|
||||
role.Code, role.Name, role.Type, parentID, role.Level, role.Description, role.IsActive,
|
||||
role.RequestID, createdIP, updatedIP, role.Version, role.CreatedAt, role.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "duplicate key") || strings.Contains(err.Error(), "unique constraint") {
|
||||
return ErrDuplicateRoleCode
|
||||
}
|
||||
return fmt.Errorf("failed to create role: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetRoleByCode 根据角色代码获取角色
|
||||
func (r *PostgresIAMRepository) GetRoleByCode(ctx context.Context, code string) (*model.Role, error) {
|
||||
query := `
|
||||
SELECT id, code, name, type, parent_role_id, level, description, is_active,
|
||||
request_id, created_ip, updated_ip, version, created_at, updated_at
|
||||
FROM iam_roles WHERE code = $1 AND is_active = true
|
||||
`
|
||||
|
||||
var role model.Role
|
||||
var parentID *int64
|
||||
var createdIP, updatedIP *string
|
||||
|
||||
err := r.pool.QueryRow(ctx, query, code).Scan(
|
||||
&role.ID, &role.Code, &role.Name, &role.Type, &parentID, &role.Level,
|
||||
&role.Description, &role.IsActive, &role.RequestID, &createdIP, &updatedIP,
|
||||
&role.Version, &role.CreatedAt, &role.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrRoleNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
|
||||
role.ParentRoleID = parentID
|
||||
if createdIP != nil {
|
||||
role.CreatedIP = *createdIP
|
||||
}
|
||||
if updatedIP != nil {
|
||||
role.UpdatedIP = *updatedIP
|
||||
}
|
||||
|
||||
return &role, nil
|
||||
}
|
||||
|
||||
// UpdateRole 更新角色
|
||||
func (r *PostgresIAMRepository) UpdateRole(ctx context.Context, role *model.Role) error {
|
||||
query := `
|
||||
UPDATE iam_roles
|
||||
SET name = $2, description = $3, is_active = $4, updated_ip = $5, version = version + 1, updated_at = NOW()
|
||||
WHERE code = $1 AND is_active = true
|
||||
`
|
||||
|
||||
result, err := r.pool.Exec(ctx, query, role.Code, role.Name, role.Description, role.IsActive, role.UpdatedIP)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update role: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return ErrRoleNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DeleteRole 删除角色(软删除)
|
||||
func (r *PostgresIAMRepository) DeleteRole(ctx context.Context, code string) error {
|
||||
query := `UPDATE iam_roles SET is_active = false, updated_at = NOW() WHERE code = $1`
|
||||
|
||||
result, err := r.pool.Exec(ctx, query, code)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete role: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return ErrRoleNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRoles 列出角色
|
||||
func (r *PostgresIAMRepository) ListRoles(ctx context.Context, roleType string) ([]*model.Role, error) {
|
||||
var query string
|
||||
var args []interface{}
|
||||
|
||||
if roleType != "" {
|
||||
query = `
|
||||
SELECT id, code, name, type, parent_role_id, level, description, is_active,
|
||||
request_id, created_ip, updated_ip, version, created_at, updated_at
|
||||
FROM iam_roles WHERE type = $1 AND is_active = true
|
||||
`
|
||||
args = []interface{}{roleType}
|
||||
} else {
|
||||
query = `
|
||||
SELECT id, code, name, type, parent_role_id, level, description, is_active,
|
||||
request_id, created_ip, updated_ip, version, created_at, updated_at
|
||||
FROM iam_roles WHERE is_active = true
|
||||
`
|
||||
}
|
||||
|
||||
rows, err := r.pool.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list roles: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var roles []*model.Role
|
||||
for rows.Next() {
|
||||
var role model.Role
|
||||
var parentID *int64
|
||||
var createdIP, updatedIP *string
|
||||
|
||||
err := rows.Scan(
|
||||
&role.ID, &role.Code, &role.Name, &role.Type, &parentID, &role.Level,
|
||||
&role.Description, &role.IsActive, &role.RequestID, &createdIP, &updatedIP,
|
||||
&role.Version, &role.CreatedAt, &role.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan role: %w", err)
|
||||
}
|
||||
|
||||
role.ParentRoleID = parentID
|
||||
if createdIP != nil {
|
||||
role.CreatedIP = *createdIP
|
||||
}
|
||||
if updatedIP != nil {
|
||||
role.UpdatedIP = *updatedIP
|
||||
}
|
||||
|
||||
roles = append(roles, &role)
|
||||
}
|
||||
|
||||
return roles, nil
|
||||
}
|
||||
|
||||
// ============ Scope Operations ============
|
||||
|
||||
// CreateScope 创建权限范围
|
||||
func (r *PostgresIAMRepository) CreateScope(ctx context.Context, scope *model.Scope) error {
|
||||
query := `
|
||||
INSERT INTO iam_scopes (code, name, description, category, is_active, request_id, version)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
`
|
||||
|
||||
_, err := r.pool.Exec(ctx, query, scope.Code, scope.Name, scope.Description, scope.Type, scope.IsActive, scope.RequestID, scope.Version)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create scope: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetScopeByCode 根据代码获取权限范围
|
||||
func (r *PostgresIAMRepository) GetScopeByCode(ctx context.Context, code string) (*model.Scope, error) {
|
||||
query := `
|
||||
SELECT id, code, name, description, category, is_active, request_id, version, created_at, updated_at
|
||||
FROM iam_scopes WHERE code = $1 AND is_active = true
|
||||
`
|
||||
|
||||
var scope model.Scope
|
||||
err := r.pool.QueryRow(ctx, query, code).Scan(
|
||||
&scope.ID, &scope.Code, &scope.Name, &scope.Description, &scope.Type,
|
||||
&scope.IsActive, &scope.RequestID, &scope.Version, &scope.CreatedAt, &scope.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return nil, ErrScopeNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get scope: %w", err)
|
||||
}
|
||||
|
||||
return &scope, nil
|
||||
}
|
||||
|
||||
// ListScopes 列出所有权限范围
|
||||
func (r *PostgresIAMRepository) ListScopes(ctx context.Context) ([]*model.Scope, error) {
|
||||
query := `
|
||||
SELECT id, code, name, description, category, is_active, request_id, version, created_at, updated_at
|
||||
FROM iam_scopes WHERE is_active = true
|
||||
`
|
||||
|
||||
rows, err := r.pool.Query(ctx, query)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list scopes: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var scopes []*model.Scope
|
||||
for rows.Next() {
|
||||
var scope model.Scope
|
||||
err := rows.Scan(
|
||||
&scope.ID, &scope.Code, &scope.Name, &scope.Description, &scope.Type,
|
||||
&scope.IsActive, &scope.RequestID, &scope.Version, &scope.CreatedAt, &scope.UpdatedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan scope: %w", err)
|
||||
}
|
||||
scopes = append(scopes, &scope)
|
||||
}
|
||||
|
||||
return scopes, nil
|
||||
}
|
||||
|
||||
// ============ Role-Scope Operations ============
|
||||
|
||||
// AddScopeToRole 为角色添加权限
|
||||
func (r *PostgresIAMRepository) AddScopeToRole(ctx context.Context, roleCode, scopeCode string) error {
|
||||
// 获取role_id和scope_id
|
||||
var roleID, scopeID int64
|
||||
|
||||
err := r.pool.QueryRow(ctx, "SELECT id FROM iam_roles WHERE code = $1 AND is_active = true", roleCode).Scan(&roleID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrRoleNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
|
||||
err = r.pool.QueryRow(ctx, "SELECT id FROM iam_scopes WHERE code = $1 AND is_active = true", scopeCode).Scan(&scopeID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrScopeNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to get scope: %w", err)
|
||||
}
|
||||
|
||||
_, err = r.pool.Exec(ctx, "INSERT INTO iam_role_scopes (role_id, scope_id) VALUES ($1, $2) ON CONFLICT DO NOTHING", roleID, scopeID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add scope to role: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveScopeFromRole 移除角色的权限
|
||||
func (r *PostgresIAMRepository) RemoveScopeFromRole(ctx context.Context, roleCode, scopeCode string) error {
|
||||
var roleID, scopeID int64
|
||||
|
||||
err := r.pool.QueryRow(ctx, "SELECT id FROM iam_roles WHERE code = $1 AND is_active = true", roleCode).Scan(&roleID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrRoleNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
|
||||
err = r.pool.QueryRow(ctx, "SELECT id FROM iam_scopes WHERE code = $1 AND is_active = true", scopeCode).Scan(&scopeID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrScopeNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to get scope: %w", err)
|
||||
}
|
||||
|
||||
_, err = r.pool.Exec(ctx, "DELETE FROM iam_role_scopes WHERE role_id = $1 AND scope_id = $2", roleID, scopeID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove scope from role: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetScopesByRoleCode 获取角色的所有权限
|
||||
func (r *PostgresIAMRepository) GetScopesByRoleCode(ctx context.Context, roleCode string) ([]string, error) {
|
||||
query := `
|
||||
SELECT s.code FROM iam_scopes s
|
||||
JOIN iam_role_scopes rs ON s.id = rs.scope_id
|
||||
JOIN iam_roles r ON r.id = rs.role_id
|
||||
WHERE r.code = $1 AND r.is_active = true AND s.is_active = true
|
||||
`
|
||||
|
||||
rows, err := r.pool.Query(ctx, query, roleCode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get scopes by role: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var scopes []string
|
||||
for rows.Next() {
|
||||
var code string
|
||||
if err := rows.Scan(&code); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan scope code: %w", err)
|
||||
}
|
||||
scopes = append(scopes, code)
|
||||
}
|
||||
|
||||
return scopes, nil
|
||||
}
|
||||
|
||||
// ============ User-Role Operations ============
|
||||
|
||||
// AssignRole 分配角色给用户
|
||||
func (r *PostgresIAMRepository) AssignRole(ctx context.Context, userRole *model.UserRoleMapping) error {
|
||||
// 检查是否已分配
|
||||
var existingID int64
|
||||
err := r.pool.QueryRow(ctx,
|
||||
"SELECT id FROM iam_user_roles WHERE user_id = $1 AND role_id = $2 AND tenant_id = $3 AND is_active = true",
|
||||
userRole.UserID, userRole.RoleID, userRole.TenantID,
|
||||
).Scan(&existingID)
|
||||
|
||||
if err == nil {
|
||||
return ErrDuplicateAssignment // 已存在
|
||||
}
|
||||
if !errors.Is(err, pgx.ErrNoRows) {
|
||||
return fmt.Errorf("failed to check existing assignment: %w", err)
|
||||
}
|
||||
|
||||
_, err = r.pool.Exec(ctx, `
|
||||
INSERT INTO iam_user_roles (user_id, role_id, tenant_id, is_active, granted_by, expires_at, request_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
`, userRole.UserID, userRole.RoleID, userRole.TenantID, true, userRole.GrantedBy, userRole.ExpiresAt, userRole.RequestID)
|
||||
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "duplicate key") || strings.Contains(err.Error(), "unique constraint") {
|
||||
return ErrDuplicateAssignment
|
||||
}
|
||||
return fmt.Errorf("failed to assign role: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RevokeRole 撤销用户的角色
|
||||
func (r *PostgresIAMRepository) RevokeRole(ctx context.Context, userID int64, roleCode string, tenantID int64) error {
|
||||
var roleID int64
|
||||
err := r.pool.QueryRow(ctx, "SELECT id FROM iam_roles WHERE code = $1 AND is_active = true", roleCode).Scan(&roleID)
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return ErrRoleNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
|
||||
result, err := r.pool.Exec(ctx,
|
||||
"UPDATE iam_user_roles SET is_active = false WHERE user_id = $1 AND role_id = $2 AND tenant_id = $3 AND is_active = true",
|
||||
userID, roleID, tenantID,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to revoke role: %w", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return ErrUserRoleNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// UserRoleWithCode 用户角色(含角色代码)
|
||||
type UserRoleWithCode struct {
|
||||
*model.UserRoleMapping
|
||||
RoleCode string
|
||||
}
|
||||
|
||||
// GetUserRoles 获取用户的角色
|
||||
func (r *PostgresIAMRepository) GetUserRoles(ctx context.Context, userID int64) ([]*model.UserRoleMapping, error) {
|
||||
query := `
|
||||
SELECT ur.id, ur.user_id, r.code, ur.tenant_id, ur.is_active, ur.granted_by, ur.expires_at, ur.request_id, ur.created_at, ur.updated_at
|
||||
FROM iam_user_roles ur
|
||||
JOIN iam_roles r ON r.id = ur.role_id
|
||||
WHERE ur.user_id = $1 AND ur.is_active = true AND r.is_active = true
|
||||
AND (ur.expires_at IS NULL OR ur.expires_at > NOW())
|
||||
`
|
||||
|
||||
rows, err := r.pool.Query(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user roles: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var userRoles []*model.UserRoleMapping
|
||||
for rows.Next() {
|
||||
var ur model.UserRoleMapping
|
||||
var roleCode string
|
||||
err := rows.Scan(&ur.ID, &ur.UserID, &roleCode, &ur.TenantID, &ur.IsActive, &ur.GrantedBy, &ur.ExpiresAt, &ur.RequestID, &ur.CreatedAt, &ur.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan user role: %w", err)
|
||||
}
|
||||
userRoles = append(userRoles, &ur)
|
||||
}
|
||||
|
||||
return userRoles, nil
|
||||
}
|
||||
|
||||
// GetUserRolesWithCode 获取用户的角色(含角色代码)
|
||||
func (r *PostgresIAMRepository) GetUserRolesWithCode(ctx context.Context, userID int64) ([]*UserRoleWithCode, error) {
|
||||
query := `
|
||||
SELECT ur.id, ur.user_id, r.code, ur.tenant_id, ur.is_active, ur.granted_by, ur.expires_at, ur.request_id, ur.created_at, ur.updated_at
|
||||
FROM iam_user_roles ur
|
||||
JOIN iam_roles r ON r.id = ur.role_id
|
||||
WHERE ur.user_id = $1 AND ur.is_active = true AND r.is_active = true
|
||||
AND (ur.expires_at IS NULL OR ur.expires_at > NOW())
|
||||
`
|
||||
|
||||
rows, err := r.pool.Query(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user roles: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var userRoles []*UserRoleWithCode
|
||||
for rows.Next() {
|
||||
var ur model.UserRoleMapping
|
||||
var roleCode string
|
||||
err := rows.Scan(&ur.ID, &ur.UserID, &roleCode, &ur.TenantID, &ur.IsActive, &ur.GrantedBy, &ur.ExpiresAt, &ur.RequestID, &ur.CreatedAt, &ur.UpdatedAt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to scan user role: %w", err)
|
||||
}
|
||||
userRoles = append(userRoles, &UserRoleWithCode{UserRoleMapping: &ur, RoleCode: roleCode})
|
||||
}
|
||||
|
||||
return userRoles, nil
|
||||
}
|
||||
|
||||
// GetUserScopes 获取用户的所有权限
|
||||
func (r *PostgresIAMRepository) GetUserScopes(ctx context.Context, userID int64) ([]string, error) {
|
||||
query := `
|
||||
SELECT DISTINCT s.code
|
||||
FROM iam_user_roles ur
|
||||
JOIN iam_roles r ON r.id = ur.role_id
|
||||
JOIN iam_role_scopes rs ON rs.role_id = r.id
|
||||
JOIN iam_scopes s ON s.id = rs.scope_id
|
||||
WHERE ur.user_id = $1
|
||||
AND ur.is_active = true
|
||||
AND r.is_active = true
|
||||
AND s.is_active = true
|
||||
AND (ur.expires_at IS NULL OR ur.expires_at > NOW())
|
||||
`
|
||||
|
||||
rows, err := r.pool.Query(ctx, query, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user scopes: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var scopes []string
|
||||
for rows.Next() {
|
||||
var code string
|
||||
if err := rows.Scan(&code); err != nil {
|
||||
return nil, fmt.Errorf("failed to scan scope code: %w", err)
|
||||
}
|
||||
scopes = append(scopes, code)
|
||||
}
|
||||
|
||||
return scopes, nil
|
||||
}
|
||||
|
||||
// ServiceRole is a copy of service.Role for conversion (avoids import cycle)
|
||||
// Service层角色结构,用于仓储层到服务层的转换
|
||||
type ServiceRole struct {
|
||||
Code string
|
||||
Name string
|
||||
Type string
|
||||
Level int
|
||||
Description string
|
||||
IsActive bool
|
||||
Version int
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// ServiceUserRole is a copy of service.UserRole for conversion
|
||||
type ServiceUserRole struct {
|
||||
UserID int64
|
||||
RoleCode string
|
||||
TenantID int64
|
||||
IsActive bool
|
||||
ExpiresAt *time.Time
|
||||
}
|
||||
|
||||
// ModelRoleToServiceRole 将模型角色转换为服务层角色
|
||||
func ModelRoleToServiceRole(mr *model.Role) *ServiceRole {
|
||||
if mr == nil {
|
||||
return nil
|
||||
}
|
||||
return &ServiceRole{
|
||||
Code: mr.Code,
|
||||
Name: mr.Name,
|
||||
Type: mr.Type,
|
||||
Level: mr.Level,
|
||||
Description: mr.Description,
|
||||
IsActive: mr.IsActive,
|
||||
Version: mr.Version,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// ModelUserRoleToServiceUserRole 将模型用户角色转换为服务层用户角色
|
||||
// 注意:UserRoleMapping 不包含 RoleCode,需要通过 GetUserRolesWithCode 获取
|
||||
func ModelUserRoleToServiceUserRole(mur *model.UserRoleMapping, roleCode string) *ServiceUserRole {
|
||||
if mur == nil {
|
||||
return nil
|
||||
}
|
||||
return &ServiceUserRole{
|
||||
UserID: mur.UserID,
|
||||
RoleCode: roleCode,
|
||||
TenantID: mur.TenantID,
|
||||
IsActive: mur.IsActive,
|
||||
ExpiresAt: mur.ExpiresAt,
|
||||
}
|
||||
}
|
||||
290
supply-api/internal/iam/service/iam_service_db.go
Normal file
290
supply-api/internal/iam/service/iam_service_db.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"lijiaoqiao/supply-api/internal/iam/model"
|
||||
"lijiaoqiao/supply-api/internal/iam/repository"
|
||||
)
|
||||
|
||||
// DatabaseIAMService 数据库-backed IAM服务
|
||||
type DatabaseIAMService struct {
|
||||
repo repository.IAMRepository
|
||||
}
|
||||
|
||||
// NewDatabaseIAMService 创建数据库-backed IAM服务
|
||||
func NewDatabaseIAMService(repo repository.IAMRepository) *DatabaseIAMService {
|
||||
return &DatabaseIAMService{repo: repo}
|
||||
}
|
||||
|
||||
// Ensure interface
|
||||
var _ IAMServiceInterface = (*DatabaseIAMService)(nil)
|
||||
|
||||
// ============ Role Operations ============
|
||||
|
||||
// CreateRole 创建角色
|
||||
func (s *DatabaseIAMService) CreateRole(ctx context.Context, req *CreateRoleRequest) (*Role, error) {
|
||||
// 验证角色类型
|
||||
if req.Type != model.RoleTypePlatform && req.Type != model.RoleTypeSupply && req.Type != model.RoleTypeConsumer {
|
||||
return nil, ErrInvalidRequest
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
role := &model.Role{
|
||||
Code: req.Code,
|
||||
Name: req.Name,
|
||||
Type: req.Type,
|
||||
Level: req.Level,
|
||||
Description: req.Description,
|
||||
IsActive: true,
|
||||
Version: 1,
|
||||
CreatedAt: &now,
|
||||
UpdatedAt: &now,
|
||||
}
|
||||
|
||||
// 处理父角色
|
||||
if req.ParentCode != "" {
|
||||
parent, err := s.repo.GetRoleByCode(ctx, req.ParentCode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parent role not found: %w", err)
|
||||
}
|
||||
role.ParentRoleID = &parent.ID
|
||||
}
|
||||
|
||||
// 创建角色
|
||||
if err := s.repo.CreateRole(ctx, role); err != nil {
|
||||
if errors.Is(err, repository.ErrDuplicateRoleCode) {
|
||||
return nil, ErrDuplicateRoleCode
|
||||
}
|
||||
return nil, fmt.Errorf("failed to create role: %w", err)
|
||||
}
|
||||
|
||||
// 添加权限关联
|
||||
for _, scopeCode := range req.Scopes {
|
||||
if err := s.repo.AddScopeToRole(ctx, req.Code, scopeCode); err != nil {
|
||||
if !errors.Is(err, repository.ErrScopeNotFound) {
|
||||
return nil, fmt.Errorf("failed to add scope %s: %w", scopeCode, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 重新获取完整角色信息
|
||||
createdRole, err := s.repo.GetRoleByCode(ctx, req.Code)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get created role: %w", err)
|
||||
}
|
||||
|
||||
return modelRoleToServiceRole(createdRole), nil
|
||||
}
|
||||
|
||||
// GetRole 获取角色
|
||||
func (s *DatabaseIAMService) GetRole(ctx context.Context, roleCode string) (*Role, error) {
|
||||
role, err := s.repo.GetRoleByCode(ctx, roleCode)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrRoleNotFound) {
|
||||
return nil, ErrRoleNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
|
||||
// 获取角色关联的权限
|
||||
scopes, err := s.repo.GetScopesByRoleCode(ctx, roleCode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get role scopes: %w", err)
|
||||
}
|
||||
role.Scopes = scopes
|
||||
|
||||
return modelRoleToServiceRole(role), nil
|
||||
}
|
||||
|
||||
// UpdateRole 更新角色
|
||||
func (s *DatabaseIAMService) UpdateRole(ctx context.Context, req *UpdateRoleRequest) (*Role, error) {
|
||||
// 获取现有角色
|
||||
existing, err := s.repo.GetRoleByCode(ctx, req.Code)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrRoleNotFound) {
|
||||
return nil, ErrRoleNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
|
||||
// 更新字段
|
||||
if req.Name != "" {
|
||||
existing.Name = req.Name
|
||||
}
|
||||
if req.Description != "" {
|
||||
existing.Description = req.Description
|
||||
}
|
||||
if req.IsActive != nil {
|
||||
existing.IsActive = *req.IsActive
|
||||
}
|
||||
|
||||
// 更新权限关联
|
||||
if req.Scopes != nil {
|
||||
// 移除所有现有权限
|
||||
currentScopes, _ := s.repo.GetScopesByRoleCode(ctx, req.Code)
|
||||
for _, scope := range currentScopes {
|
||||
s.repo.RemoveScopeFromRole(ctx, req.Code, scope)
|
||||
}
|
||||
// 添加新权限
|
||||
for _, scope := range req.Scopes {
|
||||
s.repo.AddScopeToRole(ctx, req.Code, scope)
|
||||
}
|
||||
}
|
||||
|
||||
// 保存更新
|
||||
if err := s.repo.UpdateRole(ctx, existing); err != nil {
|
||||
return nil, fmt.Errorf("failed to update role: %w", err)
|
||||
}
|
||||
|
||||
return s.GetRole(ctx, req.Code)
|
||||
}
|
||||
|
||||
// DeleteRole 删除角色(软删除)
|
||||
func (s *DatabaseIAMService) DeleteRole(ctx context.Context, roleCode string) error {
|
||||
if err := s.repo.DeleteRole(ctx, roleCode); err != nil {
|
||||
if errors.Is(err, repository.ErrRoleNotFound) {
|
||||
return ErrRoleNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to delete role: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListRoles 列出角色
|
||||
func (s *DatabaseIAMService) ListRoles(ctx context.Context, roleType string) ([]*Role, error) {
|
||||
roles, err := s.repo.ListRoles(ctx, roleType)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to list roles: %w", err)
|
||||
}
|
||||
|
||||
var result []*Role
|
||||
for _, role := range roles {
|
||||
// 获取每个角色的权限
|
||||
scopes, _ := s.repo.GetScopesByRoleCode(ctx, role.Code)
|
||||
role.Scopes = scopes
|
||||
result = append(result, modelRoleToServiceRole(role))
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ============ User-Role Operations ============
|
||||
|
||||
// AssignRole 分配角色给用户
|
||||
func (s *DatabaseIAMService) AssignRole(ctx context.Context, req *AssignRoleRequest) (*UserRole, error) {
|
||||
// 获取角色ID
|
||||
role, err := s.repo.GetRoleByCode(ctx, req.RoleCode)
|
||||
if err != nil {
|
||||
if errors.Is(err, repository.ErrRoleNotFound) {
|
||||
return nil, ErrRoleNotFound
|
||||
}
|
||||
return nil, fmt.Errorf("failed to get role: %w", err)
|
||||
}
|
||||
|
||||
userRole := &model.UserRoleMapping{
|
||||
UserID: req.UserID,
|
||||
RoleID: role.ID,
|
||||
TenantID: req.TenantID,
|
||||
IsActive: true,
|
||||
GrantedBy: req.GrantedBy,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
}
|
||||
|
||||
if err := s.repo.AssignRole(ctx, userRole); err != nil {
|
||||
if errors.Is(err, repository.ErrDuplicateAssignment) {
|
||||
return nil, ErrDuplicateAssignment
|
||||
}
|
||||
return nil, fmt.Errorf("failed to assign role: %w", err)
|
||||
}
|
||||
|
||||
return &UserRole{
|
||||
UserID: req.UserID,
|
||||
RoleCode: req.RoleCode,
|
||||
TenantID: req.TenantID,
|
||||
IsActive: true,
|
||||
ExpiresAt: req.ExpiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RevokeRole 撤销用户的角色
|
||||
func (s *DatabaseIAMService) RevokeRole(ctx context.Context, userID int64, roleCode string, tenantID int64) error {
|
||||
if err := s.repo.RevokeRole(ctx, userID, roleCode, tenantID); err != nil {
|
||||
if errors.Is(err, repository.ErrRoleNotFound) {
|
||||
return ErrRoleNotFound
|
||||
}
|
||||
if errors.Is(err, repository.ErrUserRoleNotFound) {
|
||||
return ErrRoleNotFound
|
||||
}
|
||||
return fmt.Errorf("failed to revoke role: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserRoles 获取用户角色
|
||||
func (s *DatabaseIAMService) GetUserRoles(ctx context.Context, userID int64) ([]*UserRole, error) {
|
||||
userRoles, err := s.repo.GetUserRolesWithCode(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user roles: %w", err)
|
||||
}
|
||||
|
||||
var result []*UserRole
|
||||
for _, ur := range userRoles {
|
||||
result = append(result, &UserRole{
|
||||
UserID: ur.UserID,
|
||||
RoleCode: ur.RoleCode,
|
||||
TenantID: ur.TenantID,
|
||||
IsActive: ur.IsActive,
|
||||
ExpiresAt: ur.ExpiresAt,
|
||||
})
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// ============ Scope Operations ============
|
||||
|
||||
// CheckScope 检查用户是否有指定权限
|
||||
func (s *DatabaseIAMService) CheckScope(ctx context.Context, userID int64, requiredScope string) (bool, error) {
|
||||
scopes, err := s.repo.GetUserScopes(ctx, userID)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to get user scopes: %w", err)
|
||||
}
|
||||
|
||||
for _, scope := range scopes {
|
||||
if scope == requiredScope || scope == "*" {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// GetUserScopes 获取用户所有权限
|
||||
func (s *DatabaseIAMService) GetUserScopes(ctx context.Context, userID int64) ([]string, error) {
|
||||
scopes, err := s.repo.GetUserScopes(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get user scopes: %w", err)
|
||||
}
|
||||
return scopes, nil
|
||||
}
|
||||
|
||||
// ============ Helper Functions ============
|
||||
|
||||
// modelRoleToServiceRole 将模型角色转换为服务层角色
|
||||
func modelRoleToServiceRole(mr *model.Role) *Role {
|
||||
return &Role{
|
||||
Code: mr.Code,
|
||||
Name: mr.Name,
|
||||
Type: mr.Type,
|
||||
Level: mr.Level,
|
||||
Description: mr.Description,
|
||||
IsActive: mr.IsActive,
|
||||
Version: mr.Version,
|
||||
CreatedAt: time.Now(),
|
||||
UpdatedAt: time.Now(),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user