chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,60 @@
|
||||
"""
|
||||
LiteLLM Policy Engine
|
||||
|
||||
The Policy Engine allows administrators to define policies that combine guardrails
|
||||
with scoping rules. Policies can target specific teams, API keys, and models using
|
||||
wildcard patterns, and support inheritance from base policies.
|
||||
|
||||
Configuration structure:
|
||||
- `policies`: Define WHAT guardrails to apply (with inheritance and conditions)
|
||||
- `policy_attachments`: Define WHERE policies apply (teams, keys, models)
|
||||
|
||||
Example:
|
||||
```yaml
|
||||
policies:
|
||||
global-baseline:
|
||||
description: "Base guardrails for all requests"
|
||||
guardrails:
|
||||
add: [pii_blocker]
|
||||
|
||||
gpt4-safety:
|
||||
inherit: global-baseline
|
||||
description: "Extra safety for GPT-4"
|
||||
guardrails:
|
||||
add: [toxicity_filter]
|
||||
condition:
|
||||
model: "gpt-4.*" # regex pattern
|
||||
|
||||
policy_attachments:
|
||||
- policy: global-baseline
|
||||
scope: "*"
|
||||
- policy: gpt4-safety
|
||||
scope: "*"
|
||||
```
|
||||
"""
|
||||
|
||||
from litellm.proxy.policy_engine.attachment_registry import (
|
||||
AttachmentRegistry,
|
||||
get_attachment_registry,
|
||||
)
|
||||
from litellm.proxy.policy_engine.condition_evaluator import ConditionEvaluator
|
||||
from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
|
||||
from litellm.proxy.policy_engine.policy_registry import (
|
||||
PolicyRegistry,
|
||||
get_policy_registry,
|
||||
)
|
||||
from litellm.proxy.policy_engine.policy_resolver import PolicyResolver
|
||||
from litellm.proxy.policy_engine.policy_validator import PolicyValidator
|
||||
|
||||
__all__ = [
|
||||
# Registries
|
||||
"PolicyRegistry",
|
||||
"get_policy_registry",
|
||||
"AttachmentRegistry",
|
||||
"get_attachment_registry",
|
||||
# Core components
|
||||
"PolicyMatcher",
|
||||
"PolicyResolver",
|
||||
"PolicyValidator",
|
||||
"ConditionEvaluator",
|
||||
]
|
||||
@@ -0,0 +1,54 @@
|
||||
# Policy Engine Architecture
|
||||
|
||||
## Overview
|
||||
|
||||
The Policy Engine allows administrators to define policies that combine guardrails with scoping rules. Policies can target specific teams, API keys, and models using wildcard patterns, and support inheritance from base policies.
|
||||
|
||||
## Architecture Diagram
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
subgraph Config["config.yaml"]
|
||||
PC[policies config]
|
||||
end
|
||||
|
||||
subgraph PolicyEngine["Policy Engine"]
|
||||
PR[PolicyRegistry]
|
||||
PV[PolicyValidator]
|
||||
PM[PolicyMatcher]
|
||||
PRe[PolicyResolver]
|
||||
end
|
||||
|
||||
subgraph Request["Incoming Request"]
|
||||
CTX[Context: team_alias, key_alias, model]
|
||||
end
|
||||
|
||||
subgraph Output["Output"]
|
||||
GR[Guardrails to Apply]
|
||||
end
|
||||
|
||||
PC -->|load| PR
|
||||
PC -->|validate| PV
|
||||
PV -->|errors/warnings| PR
|
||||
|
||||
CTX -->|match| PM
|
||||
PM -->|matching policies| PRe
|
||||
PR -->|policies| PM
|
||||
PR -->|policies| PRe
|
||||
PRe -->|resolve inheritance + add/remove| GR
|
||||
```
|
||||
|
||||
## Components
|
||||
|
||||
| Component | File | Description |
|
||||
|-----------|------|-------------|
|
||||
| **PolicyRegistry** | `policy_registry.py` | In-memory singleton store for parsed policies |
|
||||
| **PolicyValidator** | `policy_validator.py` | Validates configs (guardrails, inheritance, teams/keys/models) |
|
||||
| **PolicyMatcher** | `policy_matcher.py` | Matches request context against policy scopes |
|
||||
| **PolicyResolver** | `policy_resolver.py` | Resolves final guardrails via inheritance chain |
|
||||
|
||||
## Flow
|
||||
|
||||
1. **Startup**: `init_policies()` loads policies from config, validates, and populates `PolicyRegistry`
|
||||
2. **Request**: `PolicyMatcher` finds policies matching the request's team/key/model
|
||||
3. **Resolution**: `PolicyResolver` traverses inheritance and applies add/remove to get final guardrails
|
||||
@@ -0,0 +1,501 @@
|
||||
"""
|
||||
Attachment Registry - Manages policy attachments from YAML config.
|
||||
|
||||
Attachments define WHERE policies apply, separate from the policy definitions.
|
||||
This allows the same policy to be attached to multiple scopes.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
PolicyAttachment,
|
||||
PolicyAttachmentCreateRequest,
|
||||
PolicyAttachmentDBResponse,
|
||||
PolicyMatchContext,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
|
||||
class AttachmentRegistry:
|
||||
"""
|
||||
In-memory registry for storing and managing policy attachments.
|
||||
|
||||
Attachments define the relationship between policies and their scopes.
|
||||
A single policy can have multiple attachments (applied to different scopes).
|
||||
|
||||
Example YAML:
|
||||
```yaml
|
||||
attachments:
|
||||
- policy: global-baseline
|
||||
scope: "*"
|
||||
- policy: healthcare-compliance
|
||||
teams: [healthcare-team]
|
||||
- policy: dev-safety
|
||||
keys: ["dev-key-*"]
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._attachments: List[PolicyAttachment] = []
|
||||
self._initialized: bool = False
|
||||
|
||||
def load_attachments(self, attachments_config: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
Load attachments from a configuration list.
|
||||
|
||||
Args:
|
||||
attachments_config: List of attachment dictionaries from YAML.
|
||||
"""
|
||||
self._attachments = []
|
||||
|
||||
for attachment_data in attachments_config:
|
||||
try:
|
||||
attachment = self._parse_attachment(attachment_data)
|
||||
self._attachments.append(attachment)
|
||||
verbose_proxy_logger.debug(
|
||||
f"Loaded attachment for policy: {attachment.policy}"
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Error loading attachment: {str(e)}")
|
||||
raise ValueError(f"Invalid attachment: {str(e)}") from e
|
||||
|
||||
self._initialized = True
|
||||
verbose_proxy_logger.info(f"Loaded {len(self._attachments)} policy attachments")
|
||||
|
||||
def _parse_attachment(self, attachment_data: Dict[str, Any]) -> PolicyAttachment:
|
||||
"""
|
||||
Parse an attachment from raw configuration data.
|
||||
|
||||
Args:
|
||||
attachment_data: Raw attachment configuration
|
||||
|
||||
Returns:
|
||||
Parsed PolicyAttachment object
|
||||
"""
|
||||
return PolicyAttachment(
|
||||
policy=attachment_data.get("policy", ""),
|
||||
scope=attachment_data.get("scope"),
|
||||
teams=attachment_data.get("teams"),
|
||||
keys=attachment_data.get("keys"),
|
||||
models=attachment_data.get("models"),
|
||||
tags=attachment_data.get("tags"),
|
||||
)
|
||||
|
||||
def get_attached_policies(self, context: PolicyMatchContext) -> List[str]:
|
||||
"""
|
||||
Get list of policy names attached to the given context.
|
||||
|
||||
Args:
|
||||
context: The request context to match against
|
||||
|
||||
Returns:
|
||||
List of policy names that are attached to matching scopes
|
||||
"""
|
||||
return [
|
||||
r["policy_name"] for r in self.get_attached_policies_with_reasons(context)
|
||||
]
|
||||
|
||||
def get_attached_policies_with_reasons(
|
||||
self, context: PolicyMatchContext
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Get list of policy names and match reasons for the given context.
|
||||
|
||||
Returns a list of dicts with 'policy_name' and 'matched_via' keys.
|
||||
The 'matched_via' describes which dimension caused the match.
|
||||
"""
|
||||
from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
|
||||
|
||||
results: List[Dict[str, Any]] = []
|
||||
seen_policies: set = set()
|
||||
|
||||
for attachment in self._attachments:
|
||||
scope = attachment.to_policy_scope()
|
||||
if PolicyMatcher.scope_matches(scope=scope, context=context):
|
||||
if attachment.policy not in seen_policies:
|
||||
seen_policies.add(attachment.policy)
|
||||
matched_via = self._describe_match_reason(attachment, context)
|
||||
results.append(
|
||||
{
|
||||
"policy_name": attachment.policy,
|
||||
"matched_via": matched_via,
|
||||
}
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
f"Attachment matched: policy={attachment.policy}, "
|
||||
f"matched_via={matched_via}, "
|
||||
f"context=(team={context.team_alias}, key={context.key_alias}, model={context.model})"
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _describe_match_reason(
|
||||
attachment: PolicyAttachment, context: PolicyMatchContext
|
||||
) -> str:
|
||||
"""Describe why an attachment matched the context."""
|
||||
from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
|
||||
|
||||
if attachment.is_global():
|
||||
return "scope:*"
|
||||
|
||||
reasons = []
|
||||
if attachment.tags and context.tags:
|
||||
matching_tags = [
|
||||
t
|
||||
for t in context.tags
|
||||
if PolicyMatcher.matches_pattern(t, attachment.tags)
|
||||
]
|
||||
if matching_tags:
|
||||
reasons.append(f"tag:{matching_tags[0]}")
|
||||
if attachment.teams and context.team_alias:
|
||||
reasons.append(f"team:{context.team_alias}")
|
||||
if attachment.keys and context.key_alias:
|
||||
reasons.append(f"key:{context.key_alias}")
|
||||
if attachment.models and context.model:
|
||||
reasons.append(f"model:{context.model}")
|
||||
|
||||
return "+".join(reasons) if reasons else "scope:default"
|
||||
|
||||
def is_policy_attached(self, policy_name: str, context: PolicyMatchContext) -> bool:
|
||||
"""
|
||||
Check if a specific policy is attached to the given context.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy to check
|
||||
context: The request context to match against
|
||||
|
||||
Returns:
|
||||
True if the policy is attached to a matching scope
|
||||
"""
|
||||
attached = self.get_attached_policies(context)
|
||||
return policy_name in attached
|
||||
|
||||
def get_all_attachments(self) -> List[PolicyAttachment]:
|
||||
"""
|
||||
Get all loaded attachments.
|
||||
|
||||
Returns:
|
||||
List of all PolicyAttachment objects
|
||||
"""
|
||||
return self._attachments.copy()
|
||||
|
||||
def get_attachments_for_policy(self, policy_name: str) -> List[PolicyAttachment]:
|
||||
"""
|
||||
Get all attachments for a specific policy.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
|
||||
Returns:
|
||||
List of attachments for the policy
|
||||
"""
|
||||
return [a for a in self._attachments if a.policy == policy_name]
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
"""
|
||||
Check if the registry has been initialized with attachments.
|
||||
|
||||
Returns:
|
||||
True if attachments have been loaded, False otherwise
|
||||
"""
|
||||
return self._initialized
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all attachments from the registry.
|
||||
"""
|
||||
self._attachments = []
|
||||
self._initialized = False
|
||||
|
||||
def add_attachment(self, attachment: PolicyAttachment) -> None:
|
||||
"""
|
||||
Add a single attachment.
|
||||
|
||||
Args:
|
||||
attachment: PolicyAttachment object to add
|
||||
"""
|
||||
self._attachments.append(attachment)
|
||||
verbose_proxy_logger.debug(f"Added attachment for policy: {attachment.policy}")
|
||||
|
||||
def remove_attachments_for_policy(self, policy_name: str) -> int:
|
||||
"""
|
||||
Remove all attachments for a specific policy.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
|
||||
Returns:
|
||||
Number of attachments removed
|
||||
"""
|
||||
original_count = len(self._attachments)
|
||||
self._attachments = [a for a in self._attachments if a.policy != policy_name]
|
||||
removed_count = original_count - len(self._attachments)
|
||||
if removed_count > 0:
|
||||
verbose_proxy_logger.debug(
|
||||
f"Removed {removed_count} attachment(s) for policy: {policy_name}"
|
||||
)
|
||||
return removed_count
|
||||
|
||||
def remove_attachment_by_id(self, attachment_id: str) -> bool:
|
||||
"""
|
||||
Remove an attachment by its ID (for DB-synced attachments).
|
||||
|
||||
Args:
|
||||
attachment_id: The ID of the attachment to remove
|
||||
|
||||
Returns:
|
||||
True if removed, False if not found
|
||||
"""
|
||||
# Note: In-memory attachments don't have IDs, so this is primarily
|
||||
# for consistency after DB operations
|
||||
return False
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Database CRUD Methods
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def add_attachment_to_db(
|
||||
self,
|
||||
attachment_request: PolicyAttachmentCreateRequest,
|
||||
prisma_client: "PrismaClient",
|
||||
created_by: Optional[str] = None,
|
||||
) -> PolicyAttachmentDBResponse:
|
||||
"""
|
||||
Add a policy attachment to the database.
|
||||
|
||||
Args:
|
||||
attachment_request: The attachment creation request
|
||||
prisma_client: The Prisma client instance
|
||||
created_by: User who created the attachment
|
||||
|
||||
Returns:
|
||||
PolicyAttachmentDBResponse with the created attachment
|
||||
"""
|
||||
try:
|
||||
created_attachment = (
|
||||
await prisma_client.db.litellm_policyattachmenttable.create(
|
||||
data={
|
||||
"policy_name": attachment_request.policy_name,
|
||||
"scope": attachment_request.scope,
|
||||
"teams": attachment_request.teams or [],
|
||||
"keys": attachment_request.keys or [],
|
||||
"models": attachment_request.models or [],
|
||||
"tags": attachment_request.tags or [],
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
"created_by": created_by,
|
||||
"updated_by": created_by,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Also add to in-memory registry
|
||||
attachment = PolicyAttachment(
|
||||
policy=attachment_request.policy_name,
|
||||
scope=attachment_request.scope,
|
||||
teams=attachment_request.teams,
|
||||
keys=attachment_request.keys,
|
||||
models=attachment_request.models,
|
||||
tags=attachment_request.tags,
|
||||
)
|
||||
self.add_attachment(attachment)
|
||||
|
||||
return PolicyAttachmentDBResponse(
|
||||
attachment_id=created_attachment.attachment_id,
|
||||
policy_name=created_attachment.policy_name,
|
||||
scope=created_attachment.scope,
|
||||
teams=created_attachment.teams or [],
|
||||
keys=created_attachment.keys or [],
|
||||
models=created_attachment.models or [],
|
||||
tags=created_attachment.tags or [],
|
||||
created_at=created_attachment.created_at,
|
||||
updated_at=created_attachment.updated_at,
|
||||
created_by=created_attachment.created_by,
|
||||
updated_by=created_attachment.updated_by,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error adding attachment to DB: {e}")
|
||||
raise Exception(f"Error adding attachment to DB: {str(e)}")
|
||||
|
||||
async def delete_attachment_from_db(
|
||||
self,
|
||||
attachment_id: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Delete a policy attachment from the database.
|
||||
|
||||
Args:
|
||||
attachment_id: The ID of the attachment to delete
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
Dict with success message
|
||||
"""
|
||||
try:
|
||||
# Get attachment before deleting
|
||||
attachment = (
|
||||
await prisma_client.db.litellm_policyattachmenttable.find_unique(
|
||||
where={"attachment_id": attachment_id}
|
||||
)
|
||||
)
|
||||
|
||||
if attachment is None:
|
||||
raise Exception(f"Attachment with ID {attachment_id} not found")
|
||||
|
||||
# Delete from DB
|
||||
await prisma_client.db.litellm_policyattachmenttable.delete(
|
||||
where={"attachment_id": attachment_id}
|
||||
)
|
||||
|
||||
# Note: In-memory attachments don't have IDs, so we need to sync from DB
|
||||
# to properly update in-memory state
|
||||
await self.sync_attachments_from_db(prisma_client)
|
||||
|
||||
return {"message": f"Attachment {attachment_id} deleted successfully"}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error deleting attachment from DB: {e}")
|
||||
raise Exception(f"Error deleting attachment from DB: {str(e)}")
|
||||
|
||||
async def get_attachment_by_id_from_db(
|
||||
self,
|
||||
attachment_id: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> Optional[PolicyAttachmentDBResponse]:
|
||||
"""
|
||||
Get a policy attachment by ID from the database.
|
||||
|
||||
Args:
|
||||
attachment_id: The ID of the attachment to retrieve
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
PolicyAttachmentDBResponse if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
attachment = (
|
||||
await prisma_client.db.litellm_policyattachmenttable.find_unique(
|
||||
where={"attachment_id": attachment_id}
|
||||
)
|
||||
)
|
||||
|
||||
if attachment is None:
|
||||
return None
|
||||
|
||||
return PolicyAttachmentDBResponse(
|
||||
attachment_id=attachment.attachment_id,
|
||||
policy_name=attachment.policy_name,
|
||||
scope=attachment.scope,
|
||||
teams=attachment.teams or [],
|
||||
keys=attachment.keys or [],
|
||||
models=attachment.models or [],
|
||||
tags=attachment.tags or [],
|
||||
created_at=attachment.created_at,
|
||||
updated_at=attachment.updated_at,
|
||||
created_by=attachment.created_by,
|
||||
updated_by=attachment.updated_by,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting attachment from DB: {e}")
|
||||
raise Exception(f"Error getting attachment from DB: {str(e)}")
|
||||
|
||||
async def get_all_attachments_from_db(
|
||||
self,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> List[PolicyAttachmentDBResponse]:
|
||||
"""
|
||||
Get all policy attachments from the database.
|
||||
|
||||
Args:
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
List of PolicyAttachmentDBResponse objects
|
||||
"""
|
||||
try:
|
||||
attachments = (
|
||||
await prisma_client.db.litellm_policyattachmenttable.find_many(
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
PolicyAttachmentDBResponse(
|
||||
attachment_id=a.attachment_id,
|
||||
policy_name=a.policy_name,
|
||||
scope=a.scope,
|
||||
teams=a.teams or [],
|
||||
keys=a.keys or [],
|
||||
models=a.models or [],
|
||||
tags=a.tags or [],
|
||||
created_at=a.created_at,
|
||||
updated_at=a.updated_at,
|
||||
created_by=a.created_by,
|
||||
updated_by=a.updated_by,
|
||||
)
|
||||
for a in attachments
|
||||
]
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting attachments from DB: {e}")
|
||||
raise Exception(f"Error getting attachments from DB: {str(e)}")
|
||||
|
||||
async def sync_attachments_from_db(
|
||||
self,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> None:
|
||||
"""
|
||||
Sync policy attachments from the database to in-memory registry.
|
||||
|
||||
Args:
|
||||
prisma_client: The Prisma client instance
|
||||
"""
|
||||
try:
|
||||
attachments = await self.get_all_attachments_from_db(prisma_client)
|
||||
|
||||
# Clear existing attachments and reload from DB
|
||||
self._attachments = []
|
||||
|
||||
for attachment_response in attachments:
|
||||
attachment = PolicyAttachment(
|
||||
policy=attachment_response.policy_name,
|
||||
scope=attachment_response.scope,
|
||||
teams=attachment_response.teams
|
||||
if attachment_response.teams
|
||||
else None,
|
||||
keys=attachment_response.keys if attachment_response.keys else None,
|
||||
models=attachment_response.models
|
||||
if attachment_response.models
|
||||
else None,
|
||||
tags=attachment_response.tags if attachment_response.tags else None,
|
||||
)
|
||||
self._attachments.append(attachment)
|
||||
|
||||
self._initialized = True
|
||||
verbose_proxy_logger.info(
|
||||
f"Synced {len(attachments)} attachments from DB to in-memory registry"
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error syncing attachments from DB: {e}")
|
||||
raise Exception(f"Error syncing attachments from DB: {str(e)}")
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_attachment_registry: Optional[AttachmentRegistry] = None
|
||||
|
||||
|
||||
def get_attachment_registry() -> AttachmentRegistry:
|
||||
"""
|
||||
Get the global AttachmentRegistry singleton.
|
||||
|
||||
Returns:
|
||||
The global AttachmentRegistry instance
|
||||
"""
|
||||
global _attachment_registry
|
||||
if _attachment_registry is None:
|
||||
_attachment_registry = AttachmentRegistry()
|
||||
return _attachment_registry
|
||||
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
Condition Evaluator - Evaluates policy conditions.
|
||||
|
||||
Supports model-based conditions with exact match or regex patterns.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
PolicyCondition,
|
||||
PolicyMatchContext,
|
||||
)
|
||||
|
||||
|
||||
class ConditionEvaluator:
|
||||
"""
|
||||
Evaluates policy conditions against request context.
|
||||
|
||||
Supports model conditions with:
|
||||
- Exact string match: "gpt-4"
|
||||
- Regex pattern: "gpt-4.*"
|
||||
- List of values: ["gpt-4", "gpt-4-turbo"]
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def evaluate(
|
||||
condition: Optional[PolicyCondition],
|
||||
context: PolicyMatchContext,
|
||||
) -> bool:
|
||||
"""
|
||||
Evaluate a policy condition against a request context.
|
||||
|
||||
Args:
|
||||
condition: The condition to evaluate (None = always matches)
|
||||
context: The request context with team, key, model
|
||||
|
||||
Returns:
|
||||
True if condition matches, False otherwise
|
||||
"""
|
||||
# No condition means always matches
|
||||
if condition is None:
|
||||
return True
|
||||
|
||||
# Check model condition
|
||||
if condition.model is not None:
|
||||
if not ConditionEvaluator._evaluate_model_condition(
|
||||
condition=condition.model,
|
||||
model=context.model,
|
||||
):
|
||||
verbose_proxy_logger.debug(
|
||||
f"Condition failed: model={context.model} did not match {condition.model}"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def _evaluate_model_condition(
|
||||
condition: Union[str, List[str]],
|
||||
model: Optional[str],
|
||||
) -> bool:
|
||||
"""
|
||||
Evaluate a model condition.
|
||||
|
||||
Args:
|
||||
condition: String (exact or regex) or list of strings
|
||||
model: The model name to check
|
||||
|
||||
Returns:
|
||||
True if model matches condition, False otherwise
|
||||
"""
|
||||
if model is None:
|
||||
return False
|
||||
|
||||
# Handle list of values
|
||||
if isinstance(condition, list):
|
||||
return any(
|
||||
ConditionEvaluator._matches_pattern(pattern, model)
|
||||
for pattern in condition
|
||||
)
|
||||
|
||||
# Single value - check as pattern
|
||||
return ConditionEvaluator._matches_pattern(condition, model)
|
||||
|
||||
@staticmethod
|
||||
def _matches_pattern(pattern: str, value: str) -> bool:
|
||||
"""
|
||||
Check if value matches pattern (exact match or regex).
|
||||
|
||||
Args:
|
||||
pattern: Pattern to match (exact string or regex)
|
||||
value: Value to check
|
||||
|
||||
Returns:
|
||||
True if matches, False otherwise
|
||||
"""
|
||||
# First try exact match
|
||||
if pattern == value:
|
||||
return True
|
||||
|
||||
# Try as regex pattern
|
||||
try:
|
||||
if re.fullmatch(pattern, value):
|
||||
return True
|
||||
except re.error:
|
||||
# Invalid regex, treat as literal string (already checked above)
|
||||
pass
|
||||
|
||||
return False
|
||||
@@ -0,0 +1,286 @@
|
||||
"""
|
||||
Policy Initialization - Loads policies from config and validates on startup.
|
||||
|
||||
Configuration structure:
|
||||
- policies: Define WHAT guardrails to apply (with inheritance and conditions)
|
||||
- policy_attachments: Define WHERE policies apply (teams, keys, models)
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
from litellm.proxy.policy_engine.policy_validator import PolicyValidator
|
||||
from litellm.types.proxy.policy_engine import PolicyValidationResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
# ANSI color codes for terminal output
|
||||
_green_color_code = "\033[92m"
|
||||
_blue_color_code = "\033[94m"
|
||||
_yellow_color_code = "\033[93m"
|
||||
_reset_color_code = "\033[0m"
|
||||
|
||||
|
||||
def _print_policies_on_startup(
|
||||
policies_config: Dict[str, Any],
|
||||
policy_attachments_config: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Print loaded policies to console on startup (similar to model list).
|
||||
"""
|
||||
import sys
|
||||
|
||||
print( # noqa: T201
|
||||
f"{_green_color_code}\nLiteLLM Policy Engine: Loaded {len(policies_config)} policies{_reset_color_code}\n"
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
for policy_name, policy_data in policies_config.items():
|
||||
guardrails = policy_data.get("guardrails", {})
|
||||
inherit = policy_data.get("inherit")
|
||||
condition = policy_data.get("condition")
|
||||
description = policy_data.get("description")
|
||||
|
||||
guardrails_add = (
|
||||
guardrails.get("add", []) if isinstance(guardrails, dict) else []
|
||||
)
|
||||
guardrails_remove = (
|
||||
guardrails.get("remove", []) if isinstance(guardrails, dict) else []
|
||||
)
|
||||
inherit_str = f" (inherits: {inherit})" if inherit else ""
|
||||
|
||||
print( # noqa: T201
|
||||
f"{_blue_color_code} - {policy_name}{inherit_str}{_reset_color_code}"
|
||||
)
|
||||
if description:
|
||||
print(f" description: {description}") # noqa: T201
|
||||
if guardrails_add:
|
||||
print(f" guardrails.add: {guardrails_add}") # noqa: T201
|
||||
if guardrails_remove:
|
||||
print(f" guardrails.remove: {guardrails_remove}") # noqa: T201
|
||||
if condition:
|
||||
model_condition = (
|
||||
condition.get("model") if isinstance(condition, dict) else None
|
||||
)
|
||||
if model_condition:
|
||||
print(f" condition.model: {model_condition}") # noqa: T201
|
||||
|
||||
# Print attachments
|
||||
if policy_attachments_config:
|
||||
print( # noqa: T201
|
||||
f"\n{_yellow_color_code}Policy Attachments: {len(policy_attachments_config)} attachment(s){_reset_color_code}"
|
||||
)
|
||||
for attachment in policy_attachments_config:
|
||||
policy = attachment.get("policy", "unknown")
|
||||
scope = attachment.get("scope")
|
||||
teams = attachment.get("teams")
|
||||
keys = attachment.get("keys")
|
||||
models = attachment.get("models")
|
||||
|
||||
scope_parts = []
|
||||
if scope == "*":
|
||||
scope_parts.append("scope=* (global)")
|
||||
if teams:
|
||||
scope_parts.append(f"teams={teams}")
|
||||
if keys:
|
||||
scope_parts.append(f"keys={keys}")
|
||||
if models:
|
||||
scope_parts.append(f"models={models}")
|
||||
scope_str = ", ".join(scope_parts) if scope_parts else "all"
|
||||
|
||||
print(f" - {policy} -> {scope_str}") # noqa: T201
|
||||
else:
|
||||
print( # noqa: T201
|
||||
f"\n{_yellow_color_code}Warning: No policy_attachments configured. Policies will not be applied to any requests.{_reset_color_code}"
|
||||
)
|
||||
|
||||
print() # noqa: T201
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
async def init_policies(
|
||||
policies_config: Dict[str, Any],
|
||||
policy_attachments_config: Optional[List[Dict[str, Any]]] = None,
|
||||
prisma_client: Optional["PrismaClient"] = None,
|
||||
validate_db: bool = True,
|
||||
fail_on_error: bool = True,
|
||||
) -> PolicyValidationResponse:
|
||||
"""
|
||||
Initialize policies from configuration.
|
||||
|
||||
This function:
|
||||
1. Parses the policy configuration
|
||||
2. Validates policies (guardrails exist, teams/keys exist in DB)
|
||||
3. Loads policies into the global registry
|
||||
4. Loads attachments into the attachment registry (if provided)
|
||||
|
||||
Args:
|
||||
policies_config: Dictionary mapping policy names to policy definitions
|
||||
policy_attachments_config: Optional list of policy attachment configurations
|
||||
prisma_client: Optional Prisma client for database validation
|
||||
validate_db: Whether to validate team/key aliases against database
|
||||
fail_on_error: If True, raise exception on validation errors
|
||||
|
||||
Returns:
|
||||
PolicyValidationResponse with validation results
|
||||
|
||||
Raises:
|
||||
ValueError: If fail_on_error is True and validation errors are found
|
||||
"""
|
||||
verbose_proxy_logger.info(f"Initializing {len(policies_config)} policies...")
|
||||
|
||||
# Print policies to console on startup
|
||||
_print_policies_on_startup(policies_config, policy_attachments_config)
|
||||
|
||||
# Get the global registries
|
||||
policy_registry = get_policy_registry()
|
||||
attachment_registry = get_attachment_registry()
|
||||
|
||||
# Create validator
|
||||
validator = PolicyValidator(prisma_client=prisma_client)
|
||||
|
||||
# Validate the configuration
|
||||
validation_result = await validator.validate_policy_config(
|
||||
policies_config,
|
||||
validate_db=validate_db,
|
||||
)
|
||||
|
||||
# Log validation results
|
||||
if validation_result.errors:
|
||||
for error in validation_result.errors:
|
||||
verbose_proxy_logger.error(
|
||||
f"Policy validation error in '{error.policy_name}': "
|
||||
f"[{error.error_type}] {error.message}"
|
||||
)
|
||||
|
||||
if validation_result.warnings:
|
||||
for warning in validation_result.warnings:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Policy validation warning in '{warning.policy_name}': "
|
||||
f"[{warning.error_type}] {warning.message}"
|
||||
)
|
||||
|
||||
# Fail if there are errors and fail_on_error is True
|
||||
if not validation_result.valid and fail_on_error:
|
||||
error_messages = [
|
||||
f"[{e.policy_name}] {e.message}" for e in validation_result.errors
|
||||
]
|
||||
raise ValueError(
|
||||
f"Policy validation failed with {len(validation_result.errors)} error(s):\n"
|
||||
+ "\n".join(error_messages)
|
||||
)
|
||||
|
||||
# Load policies into registry (even with warnings)
|
||||
try:
|
||||
policy_registry.load_policies(policies_config)
|
||||
verbose_proxy_logger.info(
|
||||
f"Successfully loaded {len(policies_config)} policies"
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Failed to load policies: {str(e)}")
|
||||
raise
|
||||
|
||||
# Load attachments if provided
|
||||
if policy_attachments_config:
|
||||
try:
|
||||
attachment_registry.load_attachments(policy_attachments_config)
|
||||
verbose_proxy_logger.info(
|
||||
f"Successfully loaded {len(policy_attachments_config)} policy attachments"
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Failed to load policy attachments: {str(e)}")
|
||||
raise
|
||||
|
||||
return validation_result
|
||||
|
||||
|
||||
def init_policies_sync(
|
||||
policies_config: Dict[str, Any],
|
||||
policy_attachments_config: Optional[List[Dict[str, Any]]] = None,
|
||||
fail_on_error: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
Synchronous version of init_policies (without DB validation).
|
||||
|
||||
Use this when async is not available or DB validation is not needed.
|
||||
|
||||
Args:
|
||||
policies_config: Dictionary mapping policy names to policy definitions
|
||||
policy_attachments_config: Optional list of policy attachment configurations
|
||||
fail_on_error: If True, raise exception on validation errors
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# Run the async function without DB validation
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
loop.run_until_complete(
|
||||
init_policies(
|
||||
policies_config=policies_config,
|
||||
policy_attachments_config=policy_attachments_config,
|
||||
prisma_client=None,
|
||||
validate_db=False,
|
||||
fail_on_error=fail_on_error,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_policies_summary() -> Dict[str, Any]:
|
||||
"""
|
||||
Get a summary of loaded policies for debugging/display.
|
||||
|
||||
Returns:
|
||||
Dictionary with policy information
|
||||
"""
|
||||
from litellm.proxy.policy_engine.policy_resolver import PolicyResolver
|
||||
|
||||
policy_registry = get_policy_registry()
|
||||
attachment_registry = get_attachment_registry()
|
||||
|
||||
if not policy_registry.is_initialized():
|
||||
return {"initialized": False, "policies": {}, "attachments": []}
|
||||
|
||||
resolved = PolicyResolver.get_all_resolved_policies()
|
||||
|
||||
summary: Dict[str, Any] = {
|
||||
"initialized": True,
|
||||
"policy_count": len(resolved),
|
||||
"attachment_count": len(attachment_registry.get_all_attachments()),
|
||||
"policies": {},
|
||||
"attachments": [],
|
||||
}
|
||||
|
||||
for policy_name, resolved_policy in resolved.items():
|
||||
policy = policy_registry.get_policy(policy_name)
|
||||
summary["policies"][policy_name] = {
|
||||
"inherit": policy.inherit if policy else None,
|
||||
"description": policy.description if policy else None,
|
||||
"guardrails_add": policy.guardrails.get_add() if policy else [],
|
||||
"guardrails_remove": policy.guardrails.get_remove() if policy else [],
|
||||
"condition": policy.condition.model_dump()
|
||||
if policy and policy.condition
|
||||
else None,
|
||||
"resolved_guardrails": resolved_policy.guardrails,
|
||||
"inheritance_chain": resolved_policy.inheritance_chain,
|
||||
}
|
||||
|
||||
# Add attachment info
|
||||
for attachment in attachment_registry.get_all_attachments():
|
||||
summary["attachments"].append(
|
||||
{
|
||||
"policy": attachment.policy,
|
||||
"scope": attachment.scope,
|
||||
"teams": attachment.teams,
|
||||
"keys": attachment.keys,
|
||||
"models": attachment.models,
|
||||
}
|
||||
)
|
||||
|
||||
return summary
|
||||
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Pipeline Executor - Executes guardrail pipelines with conditional step logic.
|
||||
|
||||
Runs guardrails sequentially per pipeline step definitions, handling
|
||||
pass/fail actions (allow, block, next, modify_response) and data forwarding.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
ModifyResponseException,
|
||||
)
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import (
|
||||
UnifiedLLMGuardrails,
|
||||
)
|
||||
from litellm.types.proxy.policy_engine.pipeline_types import (
|
||||
PipelineExecutionResult,
|
||||
PipelineStep,
|
||||
PipelineStepResult,
|
||||
)
|
||||
|
||||
try:
|
||||
from fastapi.exceptions import HTTPException
|
||||
except ImportError:
|
||||
HTTPException = None # type: ignore
|
||||
|
||||
|
||||
class PipelineExecutor:
|
||||
"""Executes guardrail pipelines with ordered, conditional step logic."""
|
||||
|
||||
@staticmethod
|
||||
async def execute_steps(
|
||||
steps: List[PipelineStep],
|
||||
mode: str,
|
||||
data: dict,
|
||||
user_api_key_dict: Any,
|
||||
call_type: str,
|
||||
policy_name: str,
|
||||
) -> PipelineExecutionResult:
|
||||
"""
|
||||
Execute pipeline steps sequentially with conditional actions.
|
||||
|
||||
Args:
|
||||
steps: Ordered list of pipeline steps
|
||||
mode: Event hook mode (pre_call, post_call)
|
||||
data: Request data dict
|
||||
user_api_key_dict: User API key auth
|
||||
call_type: Type of call (completion, etc.)
|
||||
policy_name: Name of the owning policy (for logging)
|
||||
|
||||
Returns:
|
||||
PipelineExecutionResult with terminal action and step results
|
||||
"""
|
||||
step_results: List[PipelineStepResult] = []
|
||||
working_data = data.copy()
|
||||
if "metadata" in working_data:
|
||||
working_data["metadata"] = working_data["metadata"].copy()
|
||||
|
||||
for i, step in enumerate(steps):
|
||||
start_time = time.perf_counter()
|
||||
|
||||
outcome, modified_data, error_detail = await PipelineExecutor._run_step(
|
||||
step=step,
|
||||
mode=mode,
|
||||
data=working_data,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_type=call_type,
|
||||
)
|
||||
|
||||
duration = time.perf_counter() - start_time
|
||||
|
||||
action = step.on_pass if outcome == "pass" else step.on_fail
|
||||
|
||||
step_result = PipelineStepResult(
|
||||
guardrail_name=step.guardrail,
|
||||
outcome=outcome,
|
||||
action_taken=action,
|
||||
modified_data=modified_data,
|
||||
error_detail=error_detail,
|
||||
duration_seconds=round(duration, 4),
|
||||
)
|
||||
step_results.append(step_result)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"Pipeline '{policy_name}' step {i}: guardrail={step.guardrail}, "
|
||||
f"outcome={outcome}, action={action}"
|
||||
)
|
||||
|
||||
# Forward modified data to next step if pass_data is True
|
||||
if step.pass_data and modified_data is not None:
|
||||
working_data = {**working_data, **modified_data}
|
||||
|
||||
# Handle terminal actions
|
||||
if action == "allow":
|
||||
return PipelineExecutionResult(
|
||||
terminal_action="allow",
|
||||
step_results=step_results,
|
||||
modified_data=working_data if working_data != data else None,
|
||||
)
|
||||
|
||||
if action == "block":
|
||||
return PipelineExecutionResult(
|
||||
terminal_action="block",
|
||||
step_results=step_results,
|
||||
error_message=error_detail,
|
||||
)
|
||||
|
||||
if action == "modify_response":
|
||||
return PipelineExecutionResult(
|
||||
terminal_action="modify_response",
|
||||
step_results=step_results,
|
||||
modify_response_message=step.modify_response_message
|
||||
or error_detail,
|
||||
)
|
||||
|
||||
# action == "next" → continue to next step
|
||||
|
||||
# Ran out of steps without a terminal action → default allow
|
||||
return PipelineExecutionResult(
|
||||
terminal_action="allow",
|
||||
step_results=step_results,
|
||||
modified_data=working_data if working_data != data else None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _run_step(
|
||||
step: PipelineStep,
|
||||
mode: str,
|
||||
data: dict,
|
||||
user_api_key_dict: Any,
|
||||
call_type: str,
|
||||
) -> tuple:
|
||||
"""
|
||||
Run a single pipeline step's guardrail.
|
||||
|
||||
Returns:
|
||||
Tuple of (outcome, modified_data, error_detail) where:
|
||||
- outcome: "pass", "fail", or "error"
|
||||
- modified_data: dict if guardrail returned modified data, else None
|
||||
- error_detail: error message string if fail/error, else None
|
||||
"""
|
||||
callback = PipelineExecutor._find_guardrail_callback(step.guardrail)
|
||||
if callback is None:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Pipeline: guardrail '{step.guardrail}' not found in callbacks"
|
||||
)
|
||||
return ("error", None, f"Guardrail '{step.guardrail}' not found")
|
||||
|
||||
try:
|
||||
# Inject guardrail name into metadata so should_run_guardrail() allows it
|
||||
if "metadata" not in data:
|
||||
data["metadata"] = {}
|
||||
data["metadata"]["guardrails"] = [step.guardrail]
|
||||
|
||||
# Use unified_guardrail path if callback implements apply_guardrail
|
||||
target: CustomLogger = callback
|
||||
use_unified = "apply_guardrail" in type(callback).__dict__
|
||||
if use_unified:
|
||||
data["guardrail_to_apply"] = callback
|
||||
target = UnifiedLLMGuardrails()
|
||||
|
||||
if mode == "pre_call":
|
||||
response = await target.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=None, # type: ignore
|
||||
data=data,
|
||||
call_type=call_type, # type: ignore
|
||||
)
|
||||
elif mode == "post_call":
|
||||
response = await target.async_post_call_success_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
data=data,
|
||||
response=data.get("response"), # type: ignore
|
||||
)
|
||||
else:
|
||||
return ("error", None, f"Unsupported pipeline mode: {mode}")
|
||||
|
||||
# Normal return means pass
|
||||
modified_data = None
|
||||
if response is not None and isinstance(response, dict):
|
||||
modified_data = response
|
||||
return ("pass", modified_data, None)
|
||||
|
||||
except Exception as e:
|
||||
if CustomGuardrail._is_guardrail_intervention(e):
|
||||
error_msg = _extract_error_message(e)
|
||||
return ("fail", None, error_msg)
|
||||
else:
|
||||
verbose_proxy_logger.error(
|
||||
f"Pipeline: unexpected error from guardrail '{step.guardrail}': {e}"
|
||||
)
|
||||
return ("error", None, str(e))
|
||||
|
||||
@staticmethod
|
||||
def _find_guardrail_callback(guardrail_name: str) -> Optional[CustomGuardrail]:
|
||||
"""Look up an initialized guardrail callback by name from litellm.callbacks."""
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, CustomGuardrail):
|
||||
if callback.guardrail_name == guardrail_name:
|
||||
return callback
|
||||
return None
|
||||
|
||||
|
||||
def _extract_error_message(e: Exception) -> str:
|
||||
"""Extract a human-readable error message from a guardrail exception."""
|
||||
if isinstance(e, ModifyResponseException):
|
||||
return str(e)
|
||||
if HTTPException is not None and isinstance(e, HTTPException):
|
||||
detail = getattr(e, "detail", None)
|
||||
if detail:
|
||||
return str(detail)
|
||||
return str(e)
|
||||
@@ -0,0 +1,834 @@
|
||||
"""
|
||||
CRUD ENDPOINTS FOR POLICIES
|
||||
|
||||
Provides REST API endpoints for managing policies and policy attachments.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
|
||||
from litellm.proxy.policy_engine.pipeline_executor import PipelineExecutor
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
GuardrailPipeline,
|
||||
PipelineTestRequest,
|
||||
PolicyAttachmentCreateRequest,
|
||||
PolicyAttachmentDBResponse,
|
||||
PolicyAttachmentListResponse,
|
||||
PolicyCreateRequest,
|
||||
PolicyDBResponse,
|
||||
PolicyListDBResponse,
|
||||
PolicyUpdateRequest,
|
||||
PolicyVersionCompareResponse,
|
||||
PolicyVersionCreateRequest,
|
||||
PolicyVersionListResponse,
|
||||
PolicyVersionStatusUpdateRequest,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Policy CRUD Endpoints
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
|
||||
"/policies/list",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PolicyListDBResponse,
|
||||
)
|
||||
async def list_policies(version_status: Optional[str] = None):
|
||||
"""
|
||||
List all policies from the database. Optionally filter by version_status.
|
||||
|
||||
Query params:
|
||||
- version_status: Optional. One of "draft", "published", "production".
|
||||
If omitted, all versions are returned.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X GET "http://localhost:4000/policies/list" \\
|
||||
-H "Authorization: Bearer <your_api_key>"
|
||||
curl -X GET "http://localhost:4000/policies/list?version_status=production" \\
|
||||
-H "Authorization: Bearer <your_api_key>"
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"policies": [
|
||||
{
|
||||
"policy_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"policy_name": "global-baseline",
|
||||
"version_number": 1,
|
||||
"version_status": "production",
|
||||
"inherit": null,
|
||||
"description": "Base guardrails for all requests",
|
||||
"guardrails_add": ["pii_masking"],
|
||||
"guardrails_remove": [],
|
||||
"condition": null,
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"updated_at": "2024-01-01T00:00:00Z"
|
||||
}
|
||||
],
|
||||
"total_count": 1
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
policies = await get_policy_registry().get_all_policies_from_db(
|
||||
prisma_client, version_status=version_status
|
||||
)
|
||||
return PolicyListDBResponse(policies=policies, total_count=len(policies))
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error listing policies: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/policies",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PolicyDBResponse,
|
||||
)
|
||||
async def create_policy(
|
||||
request: PolicyCreateRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Create a new policy.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/policies" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"policy_name": "global-baseline",
|
||||
"description": "Base guardrails for all requests",
|
||||
"guardrails_add": ["pii_masking", "prompt_injection"],
|
||||
"guardrails_remove": []
|
||||
}'
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"policy_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"policy_name": "global-baseline",
|
||||
"inherit": null,
|
||||
"description": "Base guardrails for all requests",
|
||||
"guardrails_add": ["pii_masking", "prompt_injection"],
|
||||
"guardrails_remove": [],
|
||||
"condition": null,
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"updated_at": "2024-01-01T00:00:00Z"
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
created_by = user_api_key_dict.user_id
|
||||
result = await get_policy_registry().add_policy_to_db(
|
||||
policy_request=request,
|
||||
prisma_client=prisma_client,
|
||||
created_by=created_by,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error creating policy: {e}")
|
||||
if "unique constraint" in str(e).lower():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Policy with name '{request.policy_name}' already exists",
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Policy Versioning Endpoints (must be before /policies/{policy_id} to avoid path conflicts)
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
|
||||
"/policies/name/{policy_name}/versions",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PolicyVersionListResponse,
|
||||
)
|
||||
async def list_policy_versions(policy_name: str):
|
||||
"""
|
||||
List all versions of a policy by name, ordered by version_number descending.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
return await get_policy_registry().get_versions_by_policy_name(
|
||||
policy_name=policy_name,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error listing policy versions: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/policies/name/{policy_name}/versions",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PolicyDBResponse,
|
||||
)
|
||||
async def create_policy_version(
|
||||
policy_name: str,
|
||||
request: PolicyVersionCreateRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Create a new draft version of a policy. Copies all fields from the source.
|
||||
Source is current production if source_policy_id is not provided.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
created_by = user_api_key_dict.user_id
|
||||
return await get_policy_registry().create_new_version(
|
||||
policy_name=policy_name,
|
||||
prisma_client=prisma_client,
|
||||
source_policy_id=request.source_policy_id,
|
||||
created_by=created_by,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error creating policy version: {e}")
|
||||
if "not found" in str(e).lower() or "no production" in str(e).lower():
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put(
|
||||
"/policies/{policy_id}/status",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PolicyDBResponse,
|
||||
)
|
||||
async def update_policy_version_status(
|
||||
policy_id: str,
|
||||
request: PolicyVersionStatusUpdateRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Update a policy version's status. Valid transitions:
|
||||
- draft -> published
|
||||
- published -> production (demotes current production to published)
|
||||
- production -> published (demotes, policy becomes inactive)
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
updated_by = user_api_key_dict.user_id
|
||||
return await get_policy_registry().update_version_status(
|
||||
policy_id=policy_id,
|
||||
new_status=request.version_status,
|
||||
prisma_client=prisma_client,
|
||||
updated_by=updated_by,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error updating version status: {e}")
|
||||
if (
|
||||
"invalid status" in str(e).lower()
|
||||
or "only draft" in str(e).lower()
|
||||
or "cannot promote" in str(e).lower()
|
||||
):
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
if "not found" in str(e).lower():
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/policies/compare",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PolicyVersionCompareResponse,
|
||||
)
|
||||
async def compare_policy_versions(
|
||||
version_a: str,
|
||||
version_b: str,
|
||||
):
|
||||
"""
|
||||
Compare two policy versions. Query params: version_a, version_b (policy version IDs).
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
return await get_policy_registry().compare_versions(
|
||||
policy_id_a=version_a,
|
||||
policy_id_b=version_b,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error comparing versions: {e}")
|
||||
if "not found" in str(e).lower():
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/policies/name/{policy_name}/all-versions",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def delete_all_policy_versions(policy_name: str):
|
||||
"""
|
||||
Delete all versions of a policy. Also removes from in-memory registry.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
return await get_policy_registry().delete_all_versions(
|
||||
policy_name=policy_name,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error deleting all versions: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Policy CRUD by ID
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
|
||||
"/policies/{policy_id}",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PolicyDBResponse,
|
||||
)
|
||||
async def get_policy(policy_id: str):
|
||||
"""
|
||||
Get a policy by ID.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X GET "http://localhost:4000/policies/123e4567-e89b-12d3-a456-426614174000" \\
|
||||
-H "Authorization: Bearer <your_api_key>"
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
result = await get_policy_registry().get_policy_by_id_from_db(
|
||||
policy_id=policy_id,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Policy with ID {policy_id} not found"
|
||||
)
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting policy: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.put(
|
||||
"/policies/{policy_id}",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PolicyDBResponse,
|
||||
)
|
||||
async def update_policy(
|
||||
policy_id: str,
|
||||
request: PolicyUpdateRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Update an existing policy.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X PUT "http://localhost:4000/policies/123e4567-e89b-12d3-a456-426614174000" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"description": "Updated description",
|
||||
"guardrails_add": ["pii_masking", "toxicity_filter"]
|
||||
}'
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
# Check if policy exists and is draft (only drafts can be updated)
|
||||
existing = await get_policy_registry().get_policy_by_id_from_db(
|
||||
policy_id=policy_id,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
if existing is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Policy with ID {policy_id} not found"
|
||||
)
|
||||
if getattr(existing, "version_status", "production") != "draft":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Only draft versions can be updated. Publish or create a new version to change published/production.",
|
||||
)
|
||||
|
||||
updated_by = user_api_key_dict.user_id
|
||||
result = await get_policy_registry().update_policy_in_db(
|
||||
policy_id=policy_id,
|
||||
policy_request=request,
|
||||
prisma_client=prisma_client,
|
||||
updated_by=updated_by,
|
||||
)
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error updating policy: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/policies/{policy_id}",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def delete_policy(policy_id: str):
|
||||
"""
|
||||
Delete a policy.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X DELETE "http://localhost:4000/policies/123e4567-e89b-12d3-a456-426614174000" \\
|
||||
-H "Authorization: Bearer <your_api_key>"
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"message": "Policy 123e4567-e89b-12d3-a456-426614174000 deleted successfully"
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
# Check if policy exists
|
||||
existing = await get_policy_registry().get_policy_by_id_from_db(
|
||||
policy_id=policy_id,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
if existing is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Policy with ID {policy_id} not found"
|
||||
)
|
||||
|
||||
result = await get_policy_registry().delete_policy_from_db(
|
||||
policy_id=policy_id,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
# Result may include "warning" if production was deleted
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error deleting policy: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/policies/{policy_id}/resolved-guardrails",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def get_resolved_guardrails(policy_id: str):
|
||||
"""
|
||||
Get the resolved guardrails for a policy (including inherited guardrails).
|
||||
|
||||
This endpoint resolves the full inheritance chain and returns the final
|
||||
set of guardrails that would be applied for this policy.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X GET "http://localhost:4000/policies/123e4567-e89b-12d3-a456-426614174000/resolved-guardrails" \\
|
||||
-H "Authorization: Bearer <your_api_key>"
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"policy_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"policy_name": "healthcare-compliance",
|
||||
"resolved_guardrails": ["pii_masking", "prompt_injection", "toxicity_filter"]
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
# Get the policy
|
||||
policy = await get_policy_registry().get_policy_by_id_from_db(
|
||||
policy_id=policy_id,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
if policy is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Policy with ID {policy_id} not found"
|
||||
)
|
||||
|
||||
# Resolve guardrails
|
||||
resolved = await get_policy_registry().resolve_guardrails_from_db(
|
||||
policy_name=policy.policy_name,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
return {
|
||||
"policy_id": policy.policy_id,
|
||||
"policy_name": policy.policy_name,
|
||||
"resolved_guardrails": resolved,
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error resolving guardrails: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Pipeline Test Endpoint
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
|
||||
"/policies/test-pipeline",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def test_pipeline(
|
||||
request: PipelineTestRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Test a guardrail pipeline with sample messages.
|
||||
|
||||
Executes the pipeline steps against the provided test messages and returns
|
||||
step-by-step results showing which guardrails passed/failed, actions taken,
|
||||
and timing information.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/policies/test-pipeline" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"pipeline": {
|
||||
"mode": "pre_call",
|
||||
"steps": [
|
||||
{"guardrail": "pii-guard", "on_pass": "next", "on_fail": "block"}
|
||||
]
|
||||
},
|
||||
"test_messages": [{"role": "user", "content": "My SSN is 123-45-6789"}]
|
||||
}'
|
||||
```
|
||||
"""
|
||||
try:
|
||||
validated_pipeline = GuardrailPipeline(**request.pipeline)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid pipeline: {e}")
|
||||
|
||||
data = {
|
||||
"messages": request.test_messages,
|
||||
"model": "test",
|
||||
"metadata": {},
|
||||
}
|
||||
|
||||
try:
|
||||
result = await PipelineExecutor.execute_steps(
|
||||
steps=validated_pipeline.steps,
|
||||
mode=validated_pipeline.mode,
|
||||
data=data,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
call_type="completion",
|
||||
policy_name="test-pipeline",
|
||||
)
|
||||
return result.model_dump()
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error testing pipeline: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Policy Attachment CRUD Endpoints
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
|
||||
"/policies/attachments/list",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PolicyAttachmentListResponse,
|
||||
)
|
||||
async def list_policy_attachments():
|
||||
"""
|
||||
List all policy attachments from the database.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X GET "http://localhost:4000/policies/attachments/list" \\
|
||||
-H "Authorization: Bearer <your_api_key>"
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"attachments": [
|
||||
{
|
||||
"attachment_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"policy_name": "global-baseline",
|
||||
"scope": "*",
|
||||
"teams": [],
|
||||
"keys": [],
|
||||
"models": [],
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"updated_at": "2024-01-01T00:00:00Z"
|
||||
}
|
||||
],
|
||||
"total_count": 1
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
attachments = await get_attachment_registry().get_all_attachments_from_db(
|
||||
prisma_client
|
||||
)
|
||||
return PolicyAttachmentListResponse(
|
||||
attachments=attachments, total_count=len(attachments)
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error listing policy attachments: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/policies/attachments",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PolicyAttachmentDBResponse,
|
||||
)
|
||||
async def create_policy_attachment(
|
||||
request: PolicyAttachmentCreateRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Create a new policy attachment.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/policies/attachments" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"policy_name": "global-baseline",
|
||||
"scope": "*"
|
||||
}'
|
||||
```
|
||||
|
||||
Example with team-specific attachment:
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/policies/attachments" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"policy_name": "healthcare-compliance",
|
||||
"teams": ["healthcare-team", "medical-research"]
|
||||
}'
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"attachment_id": "123e4567-e89b-12d3-a456-426614174000",
|
||||
"policy_name": "global-baseline",
|
||||
"scope": "*",
|
||||
"teams": [],
|
||||
"keys": [],
|
||||
"models": [],
|
||||
"created_at": "2024-01-01T00:00:00Z",
|
||||
"updated_at": "2024-01-01T00:00:00Z"
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
# Verify the policy has a production version (attachments resolve against production)
|
||||
policies = await get_policy_registry().get_all_policies_from_db(
|
||||
prisma_client, version_status="production"
|
||||
)
|
||||
policy_names = {p.policy_name for p in policies}
|
||||
if request.policy_name not in policy_names:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Policy '{request.policy_name}' not found. Create the policy first.",
|
||||
)
|
||||
|
||||
created_by = user_api_key_dict.user_id
|
||||
result = await get_attachment_registry().add_attachment_to_db(
|
||||
attachment_request=request,
|
||||
prisma_client=prisma_client,
|
||||
created_by=created_by,
|
||||
)
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error creating policy attachment: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/policies/attachments/{attachment_id}",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PolicyAttachmentDBResponse,
|
||||
)
|
||||
async def get_policy_attachment(attachment_id: str):
|
||||
"""
|
||||
Get a policy attachment by ID.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X GET "http://localhost:4000/policies/attachments/123e4567-e89b-12d3-a456-426614174000" \\
|
||||
-H "Authorization: Bearer <your_api_key>"
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
result = await get_attachment_registry().get_attachment_by_id_from_db(
|
||||
attachment_id=attachment_id,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Attachment with ID {attachment_id} not found",
|
||||
)
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting policy attachment: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/policies/attachments/{attachment_id}",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def delete_policy_attachment(attachment_id: str):
|
||||
"""
|
||||
Delete a policy attachment.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X DELETE "http://localhost:4000/policies/attachments/123e4567-e89b-12d3-a456-426614174000" \\
|
||||
-H "Authorization: Bearer <your_api_key>"
|
||||
```
|
||||
|
||||
Example Response:
|
||||
```json
|
||||
{
|
||||
"message": "Attachment 123e4567-e89b-12d3-a456-426614174000 deleted successfully"
|
||||
}
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
# Check if attachment exists
|
||||
existing = await get_attachment_registry().get_attachment_by_id_from_db(
|
||||
attachment_id=attachment_id,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
if existing is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Attachment with ID {attachment_id} not found",
|
||||
)
|
||||
|
||||
result = await get_attachment_registry().delete_attachment_from_db(
|
||||
attachment_id=attachment_id,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
return result
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error deleting policy attachment: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Policy Matcher - Matches requests against policy attachments.
|
||||
|
||||
Uses existing wildcard pattern matching helpers to determine which policies
|
||||
apply to a given request based on team alias, key alias, and model.
|
||||
|
||||
Policies are matched via policy_attachments which define WHERE each policy applies.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.types.proxy.policy_engine import Policy, PolicyMatchContext, PolicyScope
|
||||
|
||||
|
||||
class PolicyMatcher:
|
||||
"""
|
||||
Matches incoming requests against policy attachments.
|
||||
|
||||
Supports wildcard patterns:
|
||||
- "*" matches everything
|
||||
- "prefix-*" matches anything starting with "prefix-"
|
||||
|
||||
Uses policy_attachments to determine which policies apply to a request.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def matches_pattern(value: Optional[str], patterns: List[str]) -> bool:
|
||||
"""
|
||||
Check if a value matches any of the given patterns.
|
||||
|
||||
Uses the existing RouteChecks._route_matches_wildcard_pattern helper.
|
||||
|
||||
Args:
|
||||
value: The value to check (e.g., team alias, key alias, model)
|
||||
patterns: List of patterns to match against
|
||||
|
||||
Returns:
|
||||
True if value matches any pattern, False otherwise
|
||||
"""
|
||||
# If no value provided, only match if patterns include "*"
|
||||
if value is None:
|
||||
return "*" in patterns
|
||||
|
||||
for pattern in patterns:
|
||||
# Use existing wildcard pattern matching helper
|
||||
if RouteChecks._route_matches_wildcard_pattern(
|
||||
route=value, pattern=pattern
|
||||
):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def scope_matches(scope: PolicyScope, context: PolicyMatchContext) -> bool:
|
||||
"""
|
||||
Check if a policy scope matches the given context.
|
||||
|
||||
A scope matches if ALL of its fields match:
|
||||
- teams matches context.team_alias
|
||||
- keys matches context.key_alias
|
||||
- models matches context.model
|
||||
|
||||
Args:
|
||||
scope: The policy scope to check
|
||||
context: The request context
|
||||
|
||||
Returns:
|
||||
True if scope matches context, False otherwise
|
||||
"""
|
||||
# Check teams
|
||||
if not PolicyMatcher.matches_pattern(context.team_alias, scope.get_teams()):
|
||||
return False
|
||||
|
||||
# Check keys
|
||||
if not PolicyMatcher.matches_pattern(context.key_alias, scope.get_keys()):
|
||||
return False
|
||||
|
||||
# Check models
|
||||
if not PolicyMatcher.matches_pattern(context.model, scope.get_models()):
|
||||
return False
|
||||
|
||||
# Check tags (only if scope specifies tags)
|
||||
# Unlike teams/keys/models, empty tags means "do not check" rather than "match all"
|
||||
scope_tags = scope.get_tags()
|
||||
if scope_tags:
|
||||
if not context.tags:
|
||||
return False
|
||||
# Match if ANY context tag matches ANY scope tag pattern
|
||||
if not any(
|
||||
PolicyMatcher.matches_pattern(tag, scope_tags) for tag in context.tags
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_matching_policies(
|
||||
context: PolicyMatchContext,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get list of policy names that match the given context via attachments.
|
||||
|
||||
Args:
|
||||
context: The request context to match against
|
||||
|
||||
Returns:
|
||||
List of policy names that match the context
|
||||
"""
|
||||
from litellm.proxy.policy_engine.attachment_registry import (
|
||||
get_attachment_registry,
|
||||
)
|
||||
|
||||
registry = get_attachment_registry()
|
||||
if not registry.is_initialized():
|
||||
verbose_proxy_logger.debug(
|
||||
"AttachmentRegistry not initialized, returning empty list"
|
||||
)
|
||||
return []
|
||||
|
||||
return registry.get_attached_policies(context)
|
||||
|
||||
@staticmethod
|
||||
def get_matching_policies_from_registry(
|
||||
context: PolicyMatchContext,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get list of policy names that match the given context from the global registry.
|
||||
|
||||
Args:
|
||||
context: The request context to match against
|
||||
|
||||
Returns:
|
||||
List of policy names that match the context
|
||||
"""
|
||||
return PolicyMatcher.get_matching_policies(context=context)
|
||||
|
||||
@staticmethod
|
||||
def get_policies_with_matching_conditions(
|
||||
policy_names: List[str],
|
||||
context: PolicyMatchContext,
|
||||
policies: Optional[Dict[str, Policy]] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Filter policies to only those whose conditions match the context.
|
||||
|
||||
A policy's condition matches if:
|
||||
- The policy has no condition (condition is None), OR
|
||||
- The policy's condition evaluates to True for the given context
|
||||
|
||||
Args:
|
||||
policy_names: List of policy names to filter
|
||||
context: The request context to evaluate conditions against
|
||||
policies: Dictionary of all policies (if None, uses global registry)
|
||||
|
||||
Returns:
|
||||
List of policy names whose conditions match the context
|
||||
"""
|
||||
from litellm.proxy.policy_engine.condition_evaluator import ConditionEvaluator
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
|
||||
if policies is None:
|
||||
registry = get_policy_registry()
|
||||
if not registry.is_initialized():
|
||||
return []
|
||||
policies = registry.get_all_policies()
|
||||
|
||||
matching_policies = []
|
||||
for policy_name in policy_names:
|
||||
policy = policies.get(policy_name)
|
||||
if policy is None:
|
||||
continue
|
||||
# Policy matches if it has no condition OR condition evaluates to True
|
||||
if policy.condition is None or ConditionEvaluator.evaluate(
|
||||
policy.condition, context
|
||||
):
|
||||
matching_policies.append(policy_name)
|
||||
|
||||
return matching_policies
|
||||
@@ -0,0 +1,978 @@
|
||||
"""
|
||||
Policy Registry - In-memory storage for policies.
|
||||
|
||||
Handles storing, retrieving, and managing policies.
|
||||
|
||||
Policies define WHAT guardrails to apply. WHERE they apply is defined
|
||||
by policy_attachments (see AttachmentRegistry).
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
GuardrailPipeline,
|
||||
PipelineStep,
|
||||
Policy,
|
||||
PolicyCondition,
|
||||
PolicyCreateRequest,
|
||||
PolicyDBResponse,
|
||||
PolicyGuardrails,
|
||||
PolicyUpdateRequest,
|
||||
PolicyVersionCompareResponse,
|
||||
PolicyVersionListResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
# Prefix for policy version IDs in request body. Use policy_<uuid> to execute a specific version.
|
||||
POLICY_VERSION_ID_PREFIX = "policy_"
|
||||
|
||||
|
||||
def _row_to_policy_db_response(row: Any) -> PolicyDBResponse:
|
||||
"""Build PolicyDBResponse from a Prisma LiteLLM_PolicyTable row."""
|
||||
return PolicyDBResponse(
|
||||
policy_id=row.policy_id,
|
||||
policy_name=row.policy_name,
|
||||
version_number=getattr(row, "version_number", 1),
|
||||
version_status=getattr(row, "version_status", "production"),
|
||||
parent_version_id=getattr(row, "parent_version_id", None),
|
||||
is_latest=getattr(row, "is_latest", True),
|
||||
published_at=getattr(row, "published_at", None),
|
||||
production_at=getattr(row, "production_at", None),
|
||||
inherit=row.inherit,
|
||||
description=row.description,
|
||||
guardrails_add=row.guardrails_add or [],
|
||||
guardrails_remove=row.guardrails_remove or [],
|
||||
condition=row.condition,
|
||||
pipeline=row.pipeline,
|
||||
created_at=row.created_at,
|
||||
updated_at=row.updated_at,
|
||||
created_by=row.created_by,
|
||||
updated_by=row.updated_by,
|
||||
)
|
||||
|
||||
|
||||
class PolicyRegistry:
|
||||
"""
|
||||
In-memory registry for storing and managing policies.
|
||||
|
||||
This is a singleton that holds all loaded policies and provides
|
||||
methods to access them.
|
||||
|
||||
Policies define WHAT guardrails to apply:
|
||||
- Base guardrails via guardrails.add/remove
|
||||
- Inheritance via inherit field
|
||||
- Conditional guardrails via condition.model
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._policies: Dict[str, Policy] = {}
|
||||
self._policies_by_id: Dict[str, Tuple[str, Policy]] = {}
|
||||
self._initialized: bool = False
|
||||
|
||||
def load_policies(self, policies_config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Load policies from a configuration dictionary.
|
||||
|
||||
Args:
|
||||
policies_config: Dictionary mapping policy names to policy definitions.
|
||||
This is the raw config from the YAML file.
|
||||
"""
|
||||
self._policies = {}
|
||||
self._policies_by_id = {}
|
||||
|
||||
for policy_name, policy_data in policies_config.items():
|
||||
try:
|
||||
policy = self._parse_policy(policy_name, policy_data)
|
||||
self._policies[policy_name] = policy
|
||||
verbose_proxy_logger.debug(f"Loaded policy: {policy_name}")
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"Error loading policy '{policy_name}': {str(e)}"
|
||||
)
|
||||
raise ValueError(f"Invalid policy '{policy_name}': {str(e)}") from e
|
||||
|
||||
self._initialized = True
|
||||
verbose_proxy_logger.info(f"Loaded {len(self._policies)} policies")
|
||||
|
||||
def _parse_policy(self, policy_name: str, policy_data: Dict[str, Any]) -> Policy:
|
||||
"""
|
||||
Parse a policy from raw configuration data.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
policy_data: Raw policy configuration
|
||||
|
||||
Returns:
|
||||
Parsed Policy object
|
||||
"""
|
||||
# Parse guardrails
|
||||
guardrails_data = policy_data.get("guardrails", {})
|
||||
if isinstance(guardrails_data, dict):
|
||||
guardrails = PolicyGuardrails(
|
||||
add=guardrails_data.get("add"),
|
||||
remove=guardrails_data.get("remove"),
|
||||
)
|
||||
else:
|
||||
# Handle legacy format where guardrails might be a list
|
||||
guardrails = PolicyGuardrails(
|
||||
add=guardrails_data if guardrails_data else None
|
||||
)
|
||||
|
||||
# Parse condition (simple model-based condition)
|
||||
condition = None
|
||||
condition_data = policy_data.get("condition")
|
||||
if condition_data:
|
||||
condition = PolicyCondition(model=condition_data.get("model"))
|
||||
|
||||
# Parse pipeline (optional ordered guardrail execution)
|
||||
pipeline = PolicyRegistry._parse_pipeline(policy_data.get("pipeline"))
|
||||
|
||||
return Policy(
|
||||
inherit=policy_data.get("inherit"),
|
||||
description=policy_data.get("description"),
|
||||
guardrails=guardrails,
|
||||
condition=condition,
|
||||
pipeline=pipeline,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_pipeline(
|
||||
pipeline_data: Optional[Dict[str, Any]],
|
||||
) -> Optional[GuardrailPipeline]:
|
||||
"""Parse a pipeline configuration from raw data."""
|
||||
if pipeline_data is None:
|
||||
return None
|
||||
|
||||
steps_data = pipeline_data.get("steps", [])
|
||||
steps = [
|
||||
PipelineStep(**step_data) if isinstance(step_data, dict) else step_data
|
||||
for step_data in steps_data
|
||||
]
|
||||
|
||||
return GuardrailPipeline(
|
||||
mode=pipeline_data.get("mode", "pre_call"),
|
||||
steps=steps,
|
||||
)
|
||||
|
||||
def get_policy(self, policy_name: str) -> Optional[Policy]:
|
||||
"""
|
||||
Get a policy by name.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy to retrieve
|
||||
|
||||
Returns:
|
||||
Policy object if found, None otherwise
|
||||
"""
|
||||
return self._policies.get(policy_name)
|
||||
|
||||
def get_all_policies(self) -> Dict[str, Policy]:
|
||||
"""
|
||||
Get all loaded policies.
|
||||
|
||||
Returns:
|
||||
Dictionary mapping policy names to Policy objects
|
||||
"""
|
||||
return self._policies.copy()
|
||||
|
||||
def get_policy_names(self) -> List[str]:
|
||||
"""
|
||||
Get list of all policy names.
|
||||
|
||||
Returns:
|
||||
List of policy names
|
||||
"""
|
||||
return list(self._policies.keys())
|
||||
|
||||
def has_policy(self, policy_name: str) -> bool:
|
||||
"""
|
||||
Check if a policy exists.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy to check
|
||||
|
||||
Returns:
|
||||
True if policy exists, False otherwise
|
||||
"""
|
||||
return policy_name in self._policies
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
"""
|
||||
Check if the registry has been initialized with policies.
|
||||
|
||||
Returns:
|
||||
True if policies have been loaded, False otherwise
|
||||
"""
|
||||
return self._initialized
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
Clear all policies from the registry.
|
||||
"""
|
||||
self._policies = {}
|
||||
self._initialized = False
|
||||
|
||||
def add_policy(self, policy_name: str, policy: Policy) -> None:
|
||||
"""
|
||||
Add or update a single policy.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
policy: Policy object to add
|
||||
"""
|
||||
self._policies[policy_name] = policy
|
||||
verbose_proxy_logger.debug(f"Added/updated policy: {policy_name}")
|
||||
|
||||
def remove_policy(self, policy_name: str) -> bool:
|
||||
"""
|
||||
Remove a policy by name.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy to remove
|
||||
|
||||
Returns:
|
||||
True if policy was removed, False if it didn't exist
|
||||
"""
|
||||
if policy_name in self._policies:
|
||||
del self._policies[policy_name]
|
||||
verbose_proxy_logger.debug(f"Removed policy: {policy_name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
# Database CRUD Methods
|
||||
# ─────────────────────────────────────────────────────────────────────────
|
||||
|
||||
async def add_policy_to_db(
|
||||
self,
|
||||
policy_request: PolicyCreateRequest,
|
||||
prisma_client: "PrismaClient",
|
||||
created_by: Optional[str] = None,
|
||||
) -> PolicyDBResponse:
|
||||
"""
|
||||
Add a policy to the database.
|
||||
|
||||
Args:
|
||||
policy_request: The policy creation request
|
||||
prisma_client: The Prisma client instance
|
||||
created_by: User who created the policy
|
||||
|
||||
Returns:
|
||||
PolicyDBResponse with the created policy
|
||||
"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
# Build data dict; new policy is v1 production
|
||||
data: Dict[str, Any] = {
|
||||
"policy_name": policy_request.policy_name,
|
||||
"version_number": 1,
|
||||
"version_status": "production",
|
||||
"is_latest": True,
|
||||
"production_at": now,
|
||||
"guardrails_add": policy_request.guardrails_add or [],
|
||||
"guardrails_remove": policy_request.guardrails_remove or [],
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
# Only add optional fields if they have values
|
||||
if policy_request.inherit is not None:
|
||||
data["inherit"] = policy_request.inherit
|
||||
if policy_request.description is not None:
|
||||
data["description"] = policy_request.description
|
||||
if created_by is not None:
|
||||
data["created_by"] = created_by
|
||||
data["updated_by"] = created_by
|
||||
if policy_request.condition is not None:
|
||||
data["condition"] = json.dumps(policy_request.condition.model_dump())
|
||||
if policy_request.pipeline is not None:
|
||||
validated_pipeline = GuardrailPipeline(**policy_request.pipeline)
|
||||
data["pipeline"] = json.dumps(validated_pipeline.model_dump())
|
||||
|
||||
created_policy = await prisma_client.db.litellm_policytable.create(
|
||||
data=data
|
||||
)
|
||||
|
||||
# Also add to in-memory registry
|
||||
policy = self._parse_policy(
|
||||
policy_request.policy_name,
|
||||
{
|
||||
"inherit": policy_request.inherit,
|
||||
"description": policy_request.description,
|
||||
"guardrails": {
|
||||
"add": policy_request.guardrails_add,
|
||||
"remove": policy_request.guardrails_remove,
|
||||
},
|
||||
"condition": (
|
||||
policy_request.condition.model_dump()
|
||||
if policy_request.condition
|
||||
else None
|
||||
),
|
||||
"pipeline": policy_request.pipeline,
|
||||
},
|
||||
)
|
||||
self.add_policy(policy_request.policy_name, policy)
|
||||
|
||||
return _row_to_policy_db_response(created_policy)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error adding policy to DB: {e}")
|
||||
raise Exception(f"Error adding policy to DB: {str(e)}")
|
||||
|
||||
async def update_policy_in_db(
|
||||
self,
|
||||
policy_id: str,
|
||||
policy_request: PolicyUpdateRequest,
|
||||
prisma_client: "PrismaClient",
|
||||
updated_by: Optional[str] = None,
|
||||
) -> PolicyDBResponse:
|
||||
"""
|
||||
Update a policy in the database. Only draft versions can be updated.
|
||||
|
||||
Args:
|
||||
policy_id: The ID of the policy to update
|
||||
policy_request: The policy update request
|
||||
prisma_client: The Prisma client instance
|
||||
updated_by: User who updated the policy
|
||||
|
||||
Returns:
|
||||
PolicyDBResponse with the updated policy
|
||||
|
||||
Raises:
|
||||
Exception: If policy is not in draft status (only drafts are editable).
|
||||
"""
|
||||
try:
|
||||
existing = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
if existing is None:
|
||||
raise Exception(f"Policy with ID {policy_id} not found")
|
||||
version_status = getattr(existing, "version_status", "production")
|
||||
if version_status != "draft":
|
||||
raise Exception(
|
||||
f"Only draft versions can be updated. This policy has status '{version_status}'."
|
||||
)
|
||||
|
||||
# Build update data - only include fields that are set
|
||||
update_data: Dict[str, Any] = {
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
"updated_by": updated_by,
|
||||
}
|
||||
|
||||
if policy_request.policy_name is not None:
|
||||
update_data["policy_name"] = policy_request.policy_name
|
||||
if policy_request.inherit is not None:
|
||||
update_data["inherit"] = policy_request.inherit
|
||||
if policy_request.description is not None:
|
||||
update_data["description"] = policy_request.description
|
||||
if policy_request.guardrails_add is not None:
|
||||
update_data["guardrails_add"] = policy_request.guardrails_add
|
||||
if policy_request.guardrails_remove is not None:
|
||||
update_data["guardrails_remove"] = policy_request.guardrails_remove
|
||||
if policy_request.condition is not None:
|
||||
update_data["condition"] = json.dumps(
|
||||
policy_request.condition.model_dump()
|
||||
)
|
||||
if policy_request.pipeline is not None:
|
||||
validated_pipeline = GuardrailPipeline(**policy_request.pipeline)
|
||||
update_data["pipeline"] = json.dumps(validated_pipeline.model_dump())
|
||||
|
||||
updated_policy = await prisma_client.db.litellm_policytable.update(
|
||||
where={"policy_id": policy_id},
|
||||
data=update_data,
|
||||
)
|
||||
|
||||
# Do NOT update in-memory registry: drafts are not loaded into memory.
|
||||
|
||||
return _row_to_policy_db_response(updated_policy)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error updating policy in DB: {e}")
|
||||
raise Exception(f"Error updating policy in DB: {str(e)}")
|
||||
|
||||
async def delete_policy_from_db(
|
||||
self,
|
||||
policy_id: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Delete a policy version from the database.
|
||||
|
||||
If the deleted version was production, it is removed from the in-memory
|
||||
registry. No other version is auto-promoted; admin must explicitly promote.
|
||||
|
||||
Args:
|
||||
policy_id: The ID of the policy version to delete
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
Dict with "message" and optional "warning" if production was deleted.
|
||||
"""
|
||||
try:
|
||||
policy = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
|
||||
if policy is None:
|
||||
raise Exception(f"Policy with ID {policy_id} not found")
|
||||
|
||||
version_status = getattr(policy, "version_status", "production")
|
||||
policy_name = policy.policy_name
|
||||
|
||||
# Delete from DB
|
||||
await prisma_client.db.litellm_policytable.delete(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
|
||||
result: Dict[str, Any] = {
|
||||
"message": f"Policy {policy_id} deleted successfully"
|
||||
}
|
||||
|
||||
# Remove from in-memory registry only if this was the production version
|
||||
if version_status == "production":
|
||||
self.remove_policy(policy_name)
|
||||
result["warning"] = (
|
||||
"Production version was deleted. No other version was promoted. "
|
||||
"Promote another version to production if this policy should remain active."
|
||||
)
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error deleting policy from DB: {e}")
|
||||
raise Exception(f"Error deleting policy from DB: {str(e)}")
|
||||
|
||||
async def get_policy_by_id_from_db(
|
||||
self,
|
||||
policy_id: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> Optional[PolicyDBResponse]:
|
||||
"""
|
||||
Get a policy by ID from the database.
|
||||
|
||||
Args:
|
||||
policy_id: The ID of the policy to retrieve
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
PolicyDBResponse if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
policy = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
|
||||
if policy is None:
|
||||
return None
|
||||
|
||||
return _row_to_policy_db_response(policy)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting policy from DB: {e}")
|
||||
raise Exception(f"Error getting policy from DB: {str(e)}")
|
||||
|
||||
def get_policy_by_id_for_request(
|
||||
self, policy_id: str
|
||||
) -> Optional[Tuple[str, Policy]]:
|
||||
"""
|
||||
Return a policy version by ID from in-memory cache (no DB access).
|
||||
|
||||
Used when the request body specifies policy_<uuid> to execute a specific version
|
||||
(e.g. published or draft). The cache is populated by sync_policies_from_db,
|
||||
which loads draft and published versions keyed by policy_id.
|
||||
|
||||
Args:
|
||||
policy_id: The policy version ID (raw UUID, no prefix)
|
||||
|
||||
Returns:
|
||||
(policy_name, Policy) if found, None otherwise
|
||||
"""
|
||||
return self._policies_by_id.get(policy_id)
|
||||
|
||||
async def get_all_policies_from_db(
|
||||
self,
|
||||
prisma_client: "PrismaClient",
|
||||
version_status: Optional[str] = None,
|
||||
) -> List[PolicyDBResponse]:
|
||||
"""
|
||||
Get all policies from the database, optionally filtered by version_status.
|
||||
|
||||
Args:
|
||||
prisma_client: The Prisma client instance
|
||||
version_status: If set, only return policies with this status
|
||||
("draft", "published", "production").
|
||||
|
||||
Returns:
|
||||
List of PolicyDBResponse objects
|
||||
"""
|
||||
try:
|
||||
where: Dict[str, Any] = {}
|
||||
if version_status is not None:
|
||||
where["version_status"] = version_status
|
||||
|
||||
policies = await prisma_client.db.litellm_policytable.find_many(
|
||||
where=where if where else None,
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
|
||||
return [_row_to_policy_db_response(p) for p in policies]
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting policies from DB: {e}")
|
||||
raise Exception(f"Error getting policies from DB: {str(e)}")
|
||||
|
||||
async def sync_policies_from_db(
|
||||
self,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> None:
|
||||
"""
|
||||
Sync policies from the database to in-memory registry.
|
||||
- Production versions are loaded into _policies (by policy name) for resolution.
|
||||
- Draft and published versions are loaded into _policies_by_id so request-body
|
||||
policy_<uuid> overrides can be resolved without DB access in the hot path.
|
||||
"""
|
||||
try:
|
||||
self._policies = {}
|
||||
production = await self.get_all_policies_from_db(
|
||||
prisma_client, version_status="production"
|
||||
)
|
||||
for policy_response in production:
|
||||
policy = self._parse_policy(
|
||||
policy_response.policy_name,
|
||||
{
|
||||
"inherit": policy_response.inherit,
|
||||
"description": policy_response.description,
|
||||
"guardrails": {
|
||||
"add": policy_response.guardrails_add,
|
||||
"remove": policy_response.guardrails_remove,
|
||||
},
|
||||
"condition": policy_response.condition,
|
||||
"pipeline": policy_response.pipeline,
|
||||
},
|
||||
)
|
||||
self.add_policy(policy_response.policy_name, policy)
|
||||
|
||||
self._policies_by_id = {}
|
||||
non_production = await prisma_client.db.litellm_policytable.find_many(
|
||||
where={"version_status": {"in": ["draft", "published"]}},
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
for row in non_production:
|
||||
policy = self._parse_policy(
|
||||
row.policy_name,
|
||||
{
|
||||
"inherit": row.inherit,
|
||||
"description": row.description,
|
||||
"guardrails": {
|
||||
"add": row.guardrails_add or [],
|
||||
"remove": row.guardrails_remove or [],
|
||||
},
|
||||
"condition": row.condition,
|
||||
"pipeline": row.pipeline,
|
||||
},
|
||||
)
|
||||
self._policies_by_id[row.policy_id] = (row.policy_name, policy)
|
||||
|
||||
self._initialized = True
|
||||
verbose_proxy_logger.info(
|
||||
f"Synced {len(production)} production policies and {len(non_production)} "
|
||||
"draft/published (by ID) from DB to in-memory registry"
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error syncing policies from DB: {e}")
|
||||
raise Exception(f"Error syncing policies from DB: {str(e)}")
|
||||
|
||||
async def resolve_guardrails_from_db(
|
||||
self,
|
||||
policy_name: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> List[str]:
|
||||
"""
|
||||
Resolve all guardrails for a policy from the database.
|
||||
|
||||
Uses the existing PolicyResolver to handle inheritance chain resolution.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy to resolve
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
List of resolved guardrail names
|
||||
"""
|
||||
from litellm.proxy.policy_engine.policy_resolver import PolicyResolver
|
||||
|
||||
try:
|
||||
# Load only production versions so inheritance resolves against production
|
||||
policies = await self.get_all_policies_from_db(
|
||||
prisma_client, version_status="production"
|
||||
)
|
||||
|
||||
# Build a temporary in-memory map for resolution
|
||||
temp_policies = {}
|
||||
for policy_response in policies:
|
||||
policy = self._parse_policy(
|
||||
policy_response.policy_name,
|
||||
{
|
||||
"inherit": policy_response.inherit,
|
||||
"description": policy_response.description,
|
||||
"guardrails": {
|
||||
"add": policy_response.guardrails_add,
|
||||
"remove": policy_response.guardrails_remove,
|
||||
},
|
||||
"condition": policy_response.condition,
|
||||
"pipeline": policy_response.pipeline,
|
||||
},
|
||||
)
|
||||
temp_policies[policy_response.policy_name] = policy
|
||||
|
||||
# Use the existing PolicyResolver to resolve guardrails
|
||||
resolved_policy = PolicyResolver.resolve_policy_guardrails(
|
||||
policy_name=policy_name,
|
||||
policies=temp_policies,
|
||||
context=None, # No context needed for simple resolution
|
||||
)
|
||||
|
||||
return sorted(resolved_policy.guardrails)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error resolving guardrails from DB: {e}")
|
||||
raise Exception(f"Error resolving guardrails from DB: {str(e)}")
|
||||
|
||||
async def get_versions_by_policy_name(
|
||||
self,
|
||||
policy_name: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> PolicyVersionListResponse:
|
||||
"""
|
||||
Get all versions of a policy by name, ordered by version_number descending.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
PolicyVersionListResponse with policy_name and list of versions
|
||||
"""
|
||||
try:
|
||||
rows = await prisma_client.db.litellm_policytable.find_many(
|
||||
where={"policy_name": policy_name},
|
||||
order={"version_number": "desc"},
|
||||
)
|
||||
versions = [_row_to_policy_db_response(r) for r in rows]
|
||||
return PolicyVersionListResponse(
|
||||
policy_name=policy_name,
|
||||
versions=versions,
|
||||
total_count=len(versions),
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error getting versions: {e}")
|
||||
raise Exception(f"Error getting versions: {str(e)}")
|
||||
|
||||
async def create_new_version(
|
||||
self,
|
||||
policy_name: str,
|
||||
prisma_client: "PrismaClient",
|
||||
source_policy_id: Optional[str] = None,
|
||||
created_by: Optional[str] = None,
|
||||
) -> PolicyDBResponse:
|
||||
"""
|
||||
Create a new draft version of a policy. Copies all fields from the source.
|
||||
Source is current production if source_policy_id is None.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
prisma_client: The Prisma client instance
|
||||
source_policy_id: Policy ID to clone from; if None, use current production
|
||||
created_by: User who created the version
|
||||
|
||||
Returns:
|
||||
PolicyDBResponse for the new draft version
|
||||
"""
|
||||
try:
|
||||
if source_policy_id is not None:
|
||||
source = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": source_policy_id}
|
||||
)
|
||||
if source is None:
|
||||
raise Exception(f"Source policy {source_policy_id} not found")
|
||||
if source.policy_name != policy_name:
|
||||
raise Exception(
|
||||
f"Source policy name '{source.policy_name}' does not match '{policy_name}'"
|
||||
)
|
||||
else:
|
||||
# Find current production version for this policy_name
|
||||
prod = await prisma_client.db.litellm_policytable.find_first(
|
||||
where={
|
||||
"policy_name": policy_name,
|
||||
"version_status": "production",
|
||||
}
|
||||
)
|
||||
if prod is None:
|
||||
raise Exception(
|
||||
f"No production version found for policy '{policy_name}'"
|
||||
)
|
||||
source = prod
|
||||
|
||||
# Next version number
|
||||
latest = await prisma_client.db.litellm_policytable.find_first(
|
||||
where={"policy_name": policy_name},
|
||||
order={"version_number": "desc"},
|
||||
)
|
||||
next_num = (latest.version_number + 1) if latest else 1
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
# Set is_latest=False on all existing versions for this policy_name
|
||||
await prisma_client.db.litellm_policytable.update_many(
|
||||
where={"policy_name": policy_name},
|
||||
data={"is_latest": False},
|
||||
)
|
||||
|
||||
data: Dict[str, Any] = {
|
||||
"policy_name": policy_name,
|
||||
"version_number": next_num,
|
||||
"version_status": "draft",
|
||||
"parent_version_id": source.policy_id,
|
||||
"is_latest": True,
|
||||
"published_at": None,
|
||||
"production_at": None,
|
||||
"inherit": source.inherit,
|
||||
"description": source.description,
|
||||
"guardrails_add": source.guardrails_add or [],
|
||||
"guardrails_remove": source.guardrails_remove or [],
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"created_by": created_by,
|
||||
"updated_by": created_by,
|
||||
}
|
||||
# Prisma expects Json fields as JSON strings on create (same as add_policy_to_db)
|
||||
if source.condition is not None:
|
||||
data["condition"] = (
|
||||
json.dumps(source.condition)
|
||||
if isinstance(source.condition, dict)
|
||||
else source.condition
|
||||
)
|
||||
if source.pipeline is not None:
|
||||
data["pipeline"] = (
|
||||
json.dumps(source.pipeline)
|
||||
if isinstance(source.pipeline, dict)
|
||||
else source.pipeline
|
||||
)
|
||||
|
||||
created = await prisma_client.db.litellm_policytable.create(data=data)
|
||||
return _row_to_policy_db_response(created)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error creating new version: {e}")
|
||||
raise Exception(f"Error creating new version: {str(e)}")
|
||||
|
||||
async def update_version_status(
|
||||
self,
|
||||
policy_id: str,
|
||||
new_status: str,
|
||||
prisma_client: "PrismaClient",
|
||||
updated_by: Optional[str] = None,
|
||||
) -> PolicyDBResponse:
|
||||
"""
|
||||
Update a policy version's status. Valid transitions:
|
||||
- draft -> published (sets published_at)
|
||||
- published -> production (sets production_at, demotes current production to published, updates in-memory)
|
||||
- production -> published (demotes, removes from in-memory)
|
||||
- draft -> production: NOT allowed (must publish first)
|
||||
- published -> draft: NOT allowed
|
||||
|
||||
Args:
|
||||
policy_id: The policy version ID
|
||||
new_status: "published" or "production"
|
||||
prisma_client: The Prisma client instance
|
||||
updated_by: User who updated
|
||||
|
||||
Returns:
|
||||
PolicyDBResponse for the updated version
|
||||
"""
|
||||
try:
|
||||
if new_status not in ("published", "production"):
|
||||
raise Exception(
|
||||
f"Invalid status '{new_status}'. Use 'published' or 'production'."
|
||||
)
|
||||
|
||||
row = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": policy_id}
|
||||
)
|
||||
if row is None:
|
||||
raise Exception(f"Policy with ID {policy_id} not found")
|
||||
|
||||
current = getattr(row, "version_status", "production")
|
||||
policy_name = row.policy_name
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
if new_status == "published":
|
||||
if current != "draft":
|
||||
raise Exception(
|
||||
f"Only draft versions can be published. Current status: '{current}'."
|
||||
)
|
||||
updated = await prisma_client.db.litellm_policytable.update(
|
||||
where={"policy_id": policy_id},
|
||||
data={
|
||||
"version_status": "published",
|
||||
"published_at": now,
|
||||
"updated_at": now,
|
||||
"updated_by": updated_by,
|
||||
},
|
||||
)
|
||||
return _row_to_policy_db_response(updated)
|
||||
|
||||
# new_status == "production"
|
||||
if current not in ("draft", "published"):
|
||||
raise Exception(
|
||||
f"Only draft or published versions can be promoted to production. Current: '{current}'."
|
||||
)
|
||||
# Plan: "draft -> production" NOT allowed
|
||||
if current == "draft":
|
||||
raise Exception(
|
||||
"Cannot promote draft directly to production. Publish the version first."
|
||||
)
|
||||
|
||||
# Demote current production to published
|
||||
await prisma_client.db.litellm_policytable.update_many(
|
||||
where={
|
||||
"policy_name": policy_name,
|
||||
"version_status": "production",
|
||||
},
|
||||
data={
|
||||
"version_status": "published",
|
||||
"updated_at": now,
|
||||
"updated_by": updated_by,
|
||||
},
|
||||
)
|
||||
|
||||
# Promote this version to production
|
||||
updated = await prisma_client.db.litellm_policytable.update(
|
||||
where={"policy_id": policy_id},
|
||||
data={
|
||||
"version_status": "production",
|
||||
"production_at": now,
|
||||
"updated_at": now,
|
||||
"updated_by": updated_by,
|
||||
},
|
||||
)
|
||||
|
||||
# Update in-memory registry: remove old production (by name), add this one
|
||||
self.remove_policy(policy_name)
|
||||
policy = self._parse_policy(
|
||||
policy_name,
|
||||
{
|
||||
"inherit": updated.inherit,
|
||||
"description": updated.description,
|
||||
"guardrails": {
|
||||
"add": updated.guardrails_add or [],
|
||||
"remove": updated.guardrails_remove or [],
|
||||
},
|
||||
"condition": updated.condition,
|
||||
"pipeline": updated.pipeline,
|
||||
},
|
||||
)
|
||||
self.add_policy(policy_name, policy)
|
||||
|
||||
return _row_to_policy_db_response(updated)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error updating version status: {e}")
|
||||
raise Exception(f"Error updating version status: {str(e)}")
|
||||
|
||||
async def compare_versions(
|
||||
self,
|
||||
policy_id_a: str,
|
||||
policy_id_b: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> PolicyVersionCompareResponse:
|
||||
"""
|
||||
Compare two policy versions and return field-by-field diffs.
|
||||
|
||||
Args:
|
||||
policy_id_a: First policy version ID
|
||||
policy_id_b: Second policy version ID
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
PolicyVersionCompareResponse with both versions and field_diffs
|
||||
"""
|
||||
try:
|
||||
a = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": policy_id_a}
|
||||
)
|
||||
b = await prisma_client.db.litellm_policytable.find_unique(
|
||||
where={"policy_id": policy_id_b}
|
||||
)
|
||||
if a is None:
|
||||
raise Exception(f"Policy {policy_id_a} not found")
|
||||
if b is None:
|
||||
raise Exception(f"Policy {policy_id_b} not found")
|
||||
|
||||
resp_a = _row_to_policy_db_response(a)
|
||||
resp_b = _row_to_policy_db_response(b)
|
||||
|
||||
# Compare fields that are part of policy content (not metadata)
|
||||
compare_fields = [
|
||||
"inherit",
|
||||
"description",
|
||||
"guardrails_add",
|
||||
"guardrails_remove",
|
||||
"condition",
|
||||
"pipeline",
|
||||
]
|
||||
field_diffs: Dict[str, Dict[str, Any]] = {}
|
||||
for field in compare_fields:
|
||||
val_a = getattr(resp_a, field)
|
||||
val_b = getattr(resp_b, field)
|
||||
if val_a != val_b:
|
||||
field_diffs[field] = {"version_a": val_a, "version_b": val_b}
|
||||
|
||||
return PolicyVersionCompareResponse(
|
||||
version_a=resp_a,
|
||||
version_b=resp_b,
|
||||
field_diffs=field_diffs,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error comparing versions: {e}")
|
||||
raise Exception(f"Error comparing versions: {str(e)}")
|
||||
|
||||
async def delete_all_versions(
|
||||
self,
|
||||
policy_name: str,
|
||||
prisma_client: "PrismaClient",
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Delete all versions of a policy. Also removes from in-memory registry.
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
prisma_client: The Prisma client instance
|
||||
|
||||
Returns:
|
||||
Dict with success message
|
||||
"""
|
||||
try:
|
||||
await prisma_client.db.litellm_policytable.delete_many(
|
||||
where={"policy_name": policy_name}
|
||||
)
|
||||
self.remove_policy(policy_name)
|
||||
return {
|
||||
"message": f"All versions of policy '{policy_name}' deleted successfully"
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error deleting all versions: {e}")
|
||||
raise Exception(f"Error deleting all versions: {str(e)}")
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_policy_registry: Optional[PolicyRegistry] = None
|
||||
|
||||
|
||||
def get_policy_registry() -> PolicyRegistry:
|
||||
"""
|
||||
Get the global PolicyRegistry singleton.
|
||||
|
||||
Returns:
|
||||
The global PolicyRegistry instance
|
||||
"""
|
||||
global _policy_registry
|
||||
if _policy_registry is None:
|
||||
_policy_registry = PolicyRegistry()
|
||||
return _policy_registry
|
||||
@@ -0,0 +1,416 @@
|
||||
"""
|
||||
Policy resolve and attachment impact estimation endpoints.
|
||||
|
||||
- /policies/resolve — debug which guardrails apply for a given context
|
||||
- /policies/attachments/estimate-impact — preview blast radius before creating an attachment
|
||||
"""
|
||||
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.constants import MAX_POLICY_ESTIMATE_IMPACT_ROWS
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
AttachmentImpactResponse,
|
||||
PolicyAttachmentCreateRequest,
|
||||
PolicyMatchContext,
|
||||
PolicyMatchDetail,
|
||||
PolicyResolveRequest,
|
||||
PolicyResolveResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _build_alias_where(field: str, patterns: list) -> dict:
|
||||
"""Build a Prisma ``where`` clause for alias patterns.
|
||||
|
||||
Supports exact matches and suffix wildcards (``prefix*``).
|
||||
Returns something like:
|
||||
{"OR": [{"field": {"in": ["a","b"]}}, {"field": {"startsWith": "dev-"}}]}
|
||||
"""
|
||||
exact: list = []
|
||||
prefix_conditions: list = []
|
||||
for pat in patterns:
|
||||
if pat.endswith("*"):
|
||||
prefix_conditions.append({field: {"startsWith": pat[:-1]}})
|
||||
else:
|
||||
exact.append(pat)
|
||||
|
||||
conditions: list = []
|
||||
if exact:
|
||||
conditions.append({field: {"in": exact}})
|
||||
conditions.extend(prefix_conditions)
|
||||
|
||||
if not conditions:
|
||||
return {field: {"not": None}}
|
||||
if len(conditions) == 1:
|
||||
return conditions[0]
|
||||
return {"OR": conditions}
|
||||
|
||||
|
||||
def _parse_metadata(raw_metadata: object) -> dict:
|
||||
"""Parse metadata that may be a dict, JSON string, or None."""
|
||||
if raw_metadata is None:
|
||||
return {}
|
||||
if isinstance(raw_metadata, str):
|
||||
try:
|
||||
return json.loads(raw_metadata)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return {}
|
||||
return raw_metadata if isinstance(raw_metadata, dict) else {}
|
||||
|
||||
|
||||
def _get_tags_from_metadata(metadata: object, json_metadata: object = None) -> list:
|
||||
"""Extract tags list from a metadata field (or metadata_json fallback)."""
|
||||
raw = json_metadata if json_metadata is not None else metadata
|
||||
parsed = _parse_metadata(raw)
|
||||
return parsed.get("tags", []) or []
|
||||
|
||||
|
||||
async def _fetch_all_teams(prisma_client: object) -> list:
|
||||
"""Fetch teams from DB once. Reuse the result across tag and alias lookups."""
|
||||
return await prisma_client.db.litellm_teamtable.find_many( # type: ignore
|
||||
where={},
|
||||
order={"created_at": "desc"},
|
||||
take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
|
||||
)
|
||||
|
||||
|
||||
def _filter_keys_by_tags(keys: list, tag_patterns: list) -> tuple:
|
||||
"""Filter key rows whose metadata.tags match any of the given patterns.
|
||||
|
||||
Returns (named_aliases, unnamed_count).
|
||||
"""
|
||||
|
||||
affected: list = []
|
||||
unnamed_count = 0
|
||||
for key in keys:
|
||||
key_alias = key.key_alias or ""
|
||||
key_tags = _get_tags_from_metadata(
|
||||
key.metadata, getattr(key, "metadata_json", None)
|
||||
)
|
||||
if key_tags and any(
|
||||
RouteChecks._route_matches_wildcard_pattern(route=tag, pattern=pat)
|
||||
for tag in key_tags
|
||||
for pat in tag_patterns
|
||||
):
|
||||
if key_alias:
|
||||
affected.append(key_alias)
|
||||
else:
|
||||
unnamed_count += 1
|
||||
return affected, unnamed_count
|
||||
|
||||
|
||||
def _filter_teams_by_tags(teams: list, tag_patterns: list) -> tuple:
|
||||
"""Filter pre-fetched team rows whose metadata.tags match any patterns.
|
||||
|
||||
Returns (named_aliases, unnamed_count).
|
||||
"""
|
||||
|
||||
affected: list = []
|
||||
unnamed_count = 0
|
||||
for team in teams:
|
||||
team_alias = team.team_alias or ""
|
||||
team_tags = _get_tags_from_metadata(team.metadata)
|
||||
if team_tags and any(
|
||||
RouteChecks._route_matches_wildcard_pattern(route=tag, pattern=pat)
|
||||
for tag in team_tags
|
||||
for pat in tag_patterns
|
||||
):
|
||||
if team_alias:
|
||||
affected.append(team_alias)
|
||||
else:
|
||||
unnamed_count += 1
|
||||
return affected, unnamed_count
|
||||
|
||||
|
||||
async def _find_affected_by_team_patterns(
|
||||
prisma_client: object,
|
||||
all_teams: list,
|
||||
team_patterns: list,
|
||||
existing_teams: list,
|
||||
existing_keys: list,
|
||||
) -> tuple:
|
||||
"""Filter pre-fetched teams by alias patterns, then fetch their keys.
|
||||
|
||||
Returns (new_teams, new_keys, unnamed_keys_count).
|
||||
"""
|
||||
|
||||
new_teams: list = []
|
||||
matched_team_ids: list = []
|
||||
|
||||
for team in all_teams:
|
||||
team_alias = team.team_alias or ""
|
||||
if team_alias and any(
|
||||
RouteChecks._route_matches_wildcard_pattern(route=team_alias, pattern=pat)
|
||||
for pat in team_patterns
|
||||
):
|
||||
if team_alias not in existing_teams:
|
||||
new_teams.append(team_alias)
|
||||
matched_team_ids.append(str(team.team_id))
|
||||
|
||||
new_keys: list = []
|
||||
unnamed_keys_count = 0
|
||||
if matched_team_ids:
|
||||
keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore
|
||||
where={"team_id": {"in": matched_team_ids}},
|
||||
order={"created_at": "desc"},
|
||||
take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
|
||||
)
|
||||
for key in keys:
|
||||
key_alias = key.key_alias or ""
|
||||
if key_alias:
|
||||
if key_alias not in existing_keys:
|
||||
new_keys.append(key_alias)
|
||||
else:
|
||||
unnamed_keys_count += 1
|
||||
|
||||
return new_teams, new_keys, unnamed_keys_count
|
||||
|
||||
|
||||
async def _find_affected_keys_by_alias(
|
||||
prisma_client: object, key_patterns: list, existing_keys: list
|
||||
) -> list:
|
||||
"""Find keys whose alias matches the given patterns."""
|
||||
|
||||
affected: list = []
|
||||
|
||||
keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore
|
||||
where=_build_alias_where("key_alias", key_patterns),
|
||||
order={"created_at": "desc"},
|
||||
take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
|
||||
)
|
||||
for key in keys:
|
||||
key_alias = key.key_alias or ""
|
||||
if key_alias and any(
|
||||
RouteChecks._route_matches_wildcard_pattern(route=key_alias, pattern=pat)
|
||||
for pat in key_patterns
|
||||
):
|
||||
if key_alias not in existing_keys:
|
||||
affected.append(key_alias)
|
||||
return affected
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Policy Resolve Endpoint
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
|
||||
"/policies/resolve",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=PolicyResolveResponse,
|
||||
)
|
||||
async def resolve_policies_for_context(
|
||||
request: PolicyResolveRequest,
|
||||
force_sync: bool = Query(
|
||||
default=False,
|
||||
description="Force a DB sync before resolving. Default uses in-memory cache.",
|
||||
),
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Resolve which policies and guardrails apply for a given context.
|
||||
|
||||
Use this endpoint to debug "what guardrails would apply to a request
|
||||
with this team/key/model/tags combination?"
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/policies/resolve" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"tags": ["healthcare"],
|
||||
"model": "gpt-4"
|
||||
}'
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
|
||||
from litellm.proxy.policy_engine.policy_resolver import PolicyResolver
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
# Only sync from DB when explicitly requested; otherwise use in-memory cache
|
||||
if force_sync:
|
||||
await get_policy_registry().sync_policies_from_db(prisma_client)
|
||||
await get_attachment_registry().sync_attachments_from_db(prisma_client)
|
||||
|
||||
# Build context from request
|
||||
context = PolicyMatchContext(
|
||||
team_alias=request.team_alias,
|
||||
key_alias=request.key_alias,
|
||||
model=request.model,
|
||||
tags=request.tags,
|
||||
)
|
||||
|
||||
# Get matching policies with reasons
|
||||
match_results = get_attachment_registry().get_attached_policies_with_reasons(
|
||||
context=context
|
||||
)
|
||||
|
||||
if not match_results:
|
||||
return PolicyResolveResponse(
|
||||
effective_guardrails=[],
|
||||
matched_policies=[],
|
||||
)
|
||||
|
||||
# Filter by conditions
|
||||
policy_names = [r["policy_name"] for r in match_results]
|
||||
applied_policy_names = PolicyMatcher.get_policies_with_matching_conditions(
|
||||
policy_names=policy_names,
|
||||
context=context,
|
||||
)
|
||||
|
||||
# Resolve guardrails for each applied policy
|
||||
matched_policies = []
|
||||
all_guardrails: set = set()
|
||||
for result in match_results:
|
||||
pname = result["policy_name"]
|
||||
if pname not in applied_policy_names:
|
||||
continue
|
||||
resolved = PolicyResolver.resolve_policy_guardrails(
|
||||
policy_name=pname,
|
||||
policies=get_policy_registry().get_all_policies(),
|
||||
context=context,
|
||||
)
|
||||
guardrails = resolved.guardrails if resolved else []
|
||||
all_guardrails.update(guardrails)
|
||||
matched_policies.append(
|
||||
PolicyMatchDetail(
|
||||
policy_name=pname,
|
||||
matched_via=result["matched_via"],
|
||||
guardrails_added=guardrails,
|
||||
)
|
||||
)
|
||||
|
||||
return PolicyResolveResponse(
|
||||
effective_guardrails=sorted(all_guardrails),
|
||||
matched_policies=matched_policies,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error resolving policies: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Attachment Impact Estimation Endpoint
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
|
||||
"/policies/attachments/estimate-impact",
|
||||
tags=["Policies"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=AttachmentImpactResponse,
|
||||
)
|
||||
async def estimate_attachment_impact(
|
||||
request: PolicyAttachmentCreateRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Estimate how many keys and teams would be affected by a policy attachment.
|
||||
|
||||
Use this before creating an attachment to preview the blast radius.
|
||||
|
||||
Example Request:
|
||||
```bash
|
||||
curl -X POST "http://localhost:4000/policies/attachments/estimate-impact" \\
|
||||
-H "Authorization: Bearer <your_api_key>" \\
|
||||
-H "Content-Type: application/json" \\
|
||||
-d '{
|
||||
"policy_name": "hipaa-compliance",
|
||||
"tags": ["healthcare", "health-*"]
|
||||
}'
|
||||
```
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(status_code=500, detail="Database not connected")
|
||||
|
||||
try:
|
||||
# If global scope, everything is affected — not useful to enumerate
|
||||
if request.scope == "*":
|
||||
return AttachmentImpactResponse(
|
||||
affected_keys_count=-1,
|
||||
affected_teams_count=-1,
|
||||
sample_keys=["(global scope — affects all keys)"],
|
||||
sample_teams=["(global scope — affects all teams)"],
|
||||
)
|
||||
|
||||
affected_keys: list = []
|
||||
affected_teams: list = []
|
||||
unnamed_keys = 0
|
||||
unnamed_teams = 0
|
||||
|
||||
tag_patterns = request.tags or []
|
||||
team_patterns = request.teams or []
|
||||
|
||||
# Fetch teams once — reused by both tag-based and alias-based lookups
|
||||
all_teams: list = []
|
||||
if tag_patterns or team_patterns:
|
||||
all_teams = await _fetch_all_teams(prisma_client)
|
||||
|
||||
# Tag-based impact
|
||||
if tag_patterns:
|
||||
keys = await prisma_client.db.litellm_verificationtoken.find_many( # type: ignore
|
||||
where={},
|
||||
order={"created_at": "desc"},
|
||||
take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
|
||||
)
|
||||
affected_keys, unnamed_keys = _filter_keys_by_tags(keys, tag_patterns)
|
||||
affected_teams, unnamed_teams = _filter_teams_by_tags(
|
||||
all_teams,
|
||||
tag_patterns,
|
||||
)
|
||||
|
||||
# Team-based impact (alias matching + keys belonging to those teams)
|
||||
if team_patterns:
|
||||
new_teams, new_keys, new_unnamed = await _find_affected_by_team_patterns(
|
||||
prisma_client,
|
||||
all_teams,
|
||||
team_patterns,
|
||||
affected_teams,
|
||||
affected_keys,
|
||||
)
|
||||
affected_teams.extend(new_teams)
|
||||
affected_keys.extend(new_keys)
|
||||
unnamed_keys += new_unnamed
|
||||
|
||||
# Key-based impact (direct alias matching)
|
||||
key_patterns = request.keys or []
|
||||
if key_patterns:
|
||||
new_keys = await _find_affected_keys_by_alias(
|
||||
prisma_client,
|
||||
key_patterns,
|
||||
affected_keys,
|
||||
)
|
||||
affected_keys.extend(new_keys)
|
||||
|
||||
return AttachmentImpactResponse(
|
||||
affected_keys_count=len(affected_keys) + unnamed_keys,
|
||||
affected_teams_count=len(affected_teams) + unnamed_teams,
|
||||
unnamed_keys_count=unnamed_keys,
|
||||
unnamed_teams_count=unnamed_teams,
|
||||
sample_keys=affected_keys[:10],
|
||||
sample_teams=affected_teams[:10],
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error estimating attachment impact: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -0,0 +1,299 @@
|
||||
"""
|
||||
Policy Resolver - Resolves final guardrail list from policies.
|
||||
|
||||
Handles:
|
||||
- Inheritance chain resolution (inherit with add/remove)
|
||||
- Applying add/remove guardrails
|
||||
- Evaluating model conditions
|
||||
- Combining guardrails from multiple matching policies
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Optional, Set, Tuple
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
GuardrailPipeline,
|
||||
Policy,
|
||||
PolicyMatchContext,
|
||||
ResolvedPolicy,
|
||||
)
|
||||
|
||||
|
||||
class PolicyResolver:
|
||||
"""
|
||||
Resolves the final list of guardrails from policies.
|
||||
|
||||
Handles:
|
||||
- Inheritance chains with add/remove operations
|
||||
- Model-based conditions
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def resolve_inheritance_chain(
|
||||
policy_name: str,
|
||||
policies: Dict[str, Policy],
|
||||
visited: Optional[Set[str]] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get the inheritance chain for a policy (from root to policy).
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy
|
||||
policies: Dictionary of all policies
|
||||
visited: Set of visited policies (for cycle detection)
|
||||
|
||||
Returns:
|
||||
List of policy names from root ancestor to the given policy
|
||||
"""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
if policy_name in visited:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Circular inheritance detected for policy '{policy_name}'"
|
||||
)
|
||||
return []
|
||||
|
||||
policy = policies.get(policy_name)
|
||||
if policy is None:
|
||||
return []
|
||||
|
||||
visited.add(policy_name)
|
||||
|
||||
if policy.inherit:
|
||||
parent_chain = PolicyResolver.resolve_inheritance_chain(
|
||||
policy_name=policy.inherit, policies=policies, visited=visited
|
||||
)
|
||||
return parent_chain + [policy_name]
|
||||
|
||||
return [policy_name]
|
||||
|
||||
@staticmethod
|
||||
def resolve_policy_guardrails(
|
||||
policy_name: str,
|
||||
policies: Dict[str, Policy],
|
||||
context: Optional[PolicyMatchContext] = None,
|
||||
) -> ResolvedPolicy:
|
||||
"""
|
||||
Resolve the final guardrails for a single policy, including inheritance.
|
||||
|
||||
This method:
|
||||
1. Resolves the inheritance chain
|
||||
2. Applies add/remove from each policy in the chain
|
||||
3. Evaluates model conditions (if context provided)
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy to resolve
|
||||
policies: Dictionary of all policies
|
||||
context: Optional request context for evaluating conditions
|
||||
|
||||
Returns:
|
||||
ResolvedPolicy with final guardrails list
|
||||
"""
|
||||
from litellm.proxy.policy_engine.condition_evaluator import ConditionEvaluator
|
||||
|
||||
inheritance_chain = PolicyResolver.resolve_inheritance_chain(
|
||||
policy_name=policy_name, policies=policies
|
||||
)
|
||||
|
||||
# Start with empty set of guardrails
|
||||
guardrails: Set[str] = set()
|
||||
|
||||
# Apply each policy in the chain (from root to leaf)
|
||||
for chain_policy_name in inheritance_chain:
|
||||
policy = policies.get(chain_policy_name)
|
||||
if policy is None:
|
||||
continue
|
||||
|
||||
# Check if policy condition matches (if context provided)
|
||||
if context is not None and policy.condition is not None:
|
||||
if not ConditionEvaluator.evaluate(
|
||||
condition=policy.condition,
|
||||
context=context,
|
||||
):
|
||||
verbose_proxy_logger.debug(
|
||||
f"Policy '{chain_policy_name}' condition did not match, skipping guardrails"
|
||||
)
|
||||
continue
|
||||
|
||||
# Add guardrails from guardrails.add
|
||||
for guardrail in policy.guardrails.get_add():
|
||||
guardrails.add(guardrail)
|
||||
|
||||
# Remove guardrails from guardrails.remove
|
||||
for guardrail in policy.guardrails.get_remove():
|
||||
guardrails.discard(guardrail)
|
||||
|
||||
return ResolvedPolicy(
|
||||
policy_name=policy_name,
|
||||
guardrails=list(guardrails),
|
||||
inheritance_chain=inheritance_chain,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def resolve_guardrails_for_context(
|
||||
context: PolicyMatchContext,
|
||||
policies: Optional[Dict[str, Policy]] = None,
|
||||
policy_names: Optional[List[str]] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Resolve the final list of guardrails for a request context.
|
||||
|
||||
This:
|
||||
1. Finds all policies that match the context via policy_attachments (or policy_names if provided)
|
||||
2. Resolves each policy's guardrails (including inheritance)
|
||||
3. Evaluates model conditions
|
||||
4. Combines all guardrails (union)
|
||||
|
||||
Args:
|
||||
context: The request context
|
||||
policies: Dictionary of all policies (if None, uses global registry)
|
||||
policy_names: If provided, use this list instead of attachment matching
|
||||
|
||||
Returns:
|
||||
List of guardrail names to apply
|
||||
"""
|
||||
from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
|
||||
if policies is None:
|
||||
registry = get_policy_registry()
|
||||
if not registry.is_initialized():
|
||||
return []
|
||||
policies = registry.get_all_policies()
|
||||
|
||||
# Use provided policy names or get matching policies via attachments
|
||||
matching_policy_names = (
|
||||
policy_names
|
||||
if policy_names is not None
|
||||
else PolicyMatcher.get_matching_policies(context=context)
|
||||
)
|
||||
|
||||
if not matching_policy_names:
|
||||
verbose_proxy_logger.debug(
|
||||
f"No policies match context: team_alias={context.team_alias}, "
|
||||
f"key_alias={context.key_alias}, model={context.model}"
|
||||
)
|
||||
return []
|
||||
|
||||
# Resolve each matching policy and combine guardrails
|
||||
all_guardrails: Set[str] = set()
|
||||
|
||||
for policy_name in matching_policy_names:
|
||||
resolved = PolicyResolver.resolve_policy_guardrails(
|
||||
policy_name=policy_name,
|
||||
policies=policies,
|
||||
context=context,
|
||||
)
|
||||
all_guardrails.update(resolved.guardrails)
|
||||
verbose_proxy_logger.debug(
|
||||
f"Policy '{policy_name}' contributes guardrails: {resolved.guardrails}"
|
||||
)
|
||||
|
||||
result = list(all_guardrails)
|
||||
verbose_proxy_logger.debug(f"Final guardrails for context: {result}")
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def resolve_pipelines_for_context(
|
||||
context: PolicyMatchContext,
|
||||
policies: Optional[Dict[str, Policy]] = None,
|
||||
policy_names: Optional[List[str]] = None,
|
||||
) -> List[Tuple[str, GuardrailPipeline]]:
|
||||
"""
|
||||
Resolve pipelines from matching policies for a request context.
|
||||
|
||||
Returns (policy_name, pipeline) tuples for policies that have pipelines.
|
||||
Guardrails managed by pipelines should be excluded from the flat
|
||||
guardrails list to avoid double execution.
|
||||
|
||||
Args:
|
||||
context: The request context
|
||||
policies: Dictionary of all policies (if None, uses global registry)
|
||||
policy_names: If provided, use this list instead of attachment matching
|
||||
|
||||
Returns:
|
||||
List of (policy_name, GuardrailPipeline) tuples
|
||||
"""
|
||||
from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
|
||||
if policies is None:
|
||||
registry = get_policy_registry()
|
||||
if not registry.is_initialized():
|
||||
return []
|
||||
policies = registry.get_all_policies()
|
||||
|
||||
matching_policy_names = (
|
||||
policy_names
|
||||
if policy_names is not None
|
||||
else PolicyMatcher.get_matching_policies(context=context)
|
||||
)
|
||||
if not matching_policy_names:
|
||||
return []
|
||||
|
||||
pipelines: List[Tuple[str, GuardrailPipeline]] = []
|
||||
for policy_name in matching_policy_names:
|
||||
policy = policies.get(policy_name)
|
||||
if policy is None:
|
||||
continue
|
||||
if policy.pipeline is not None:
|
||||
pipelines.append((policy_name, policy.pipeline))
|
||||
verbose_proxy_logger.debug(
|
||||
f"Policy '{policy_name}' has pipeline with "
|
||||
f"{len(policy.pipeline.steps)} steps"
|
||||
)
|
||||
|
||||
return pipelines
|
||||
|
||||
@staticmethod
|
||||
def get_pipeline_managed_guardrails(
|
||||
pipelines: List[Tuple[str, GuardrailPipeline]],
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Get the set of guardrail names managed by pipelines.
|
||||
|
||||
These guardrails should be excluded from normal independent execution.
|
||||
"""
|
||||
managed: Set[str] = set()
|
||||
for _policy_name, pipeline in pipelines:
|
||||
for step in pipeline.steps:
|
||||
managed.add(step.guardrail)
|
||||
return managed
|
||||
|
||||
@staticmethod
|
||||
def get_all_resolved_policies(
|
||||
policies: Optional[Dict[str, Policy]] = None,
|
||||
context: Optional[PolicyMatchContext] = None,
|
||||
) -> Dict[str, ResolvedPolicy]:
|
||||
"""
|
||||
Resolve all policies and return their final guardrails.
|
||||
|
||||
Useful for debugging and displaying policy configurations.
|
||||
|
||||
Args:
|
||||
policies: Dictionary of all policies (if None, uses global registry)
|
||||
context: Optional context for evaluating conditions
|
||||
|
||||
Returns:
|
||||
Dictionary mapping policy names to ResolvedPolicy objects
|
||||
"""
|
||||
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
|
||||
|
||||
if policies is None:
|
||||
registry = get_policy_registry()
|
||||
if not registry.is_initialized():
|
||||
return {}
|
||||
policies = registry.get_all_policies()
|
||||
|
||||
resolved: Dict[str, ResolvedPolicy] = {}
|
||||
|
||||
for policy_name in policies:
|
||||
resolved[policy_name] = PolicyResolver.resolve_policy_guardrails(
|
||||
policy_name=policy_name,
|
||||
policies=policies,
|
||||
context=context,
|
||||
)
|
||||
|
||||
return resolved
|
||||
@@ -0,0 +1,405 @@
|
||||
"""
|
||||
Policy Validator - Validates policy configurations.
|
||||
|
||||
Validates:
|
||||
- Guardrail names exist in the guardrail registry
|
||||
- Non-wildcard team aliases exist in the database
|
||||
- Non-wildcard key aliases exist in the database
|
||||
- Non-wildcard model names exist in the router or match a wildcard route
|
||||
- Inheritance chains are valid (no cycles, parents exist)
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.types.proxy.policy_engine import (
|
||||
Policy,
|
||||
PolicyValidationError,
|
||||
PolicyValidationErrorType,
|
||||
PolicyValidationResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.router import Router
|
||||
|
||||
|
||||
class PolicyValidator:
|
||||
"""
|
||||
Validates policy configurations against actual data.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prisma_client: Optional["PrismaClient"] = None,
|
||||
llm_router: Optional["Router"] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the validator.
|
||||
|
||||
Args:
|
||||
prisma_client: Optional Prisma client for database validation
|
||||
llm_router: Optional LLM router for model validation
|
||||
"""
|
||||
self.prisma_client = prisma_client
|
||||
self.llm_router = llm_router
|
||||
|
||||
@staticmethod
|
||||
def is_wildcard_pattern(pattern: str) -> bool:
|
||||
"""
|
||||
Check if a pattern contains wildcards.
|
||||
|
||||
Args:
|
||||
pattern: The pattern to check
|
||||
|
||||
Returns:
|
||||
True if the pattern contains wildcard characters
|
||||
"""
|
||||
return "*" in pattern or "?" in pattern
|
||||
|
||||
def get_available_guardrails(self) -> Set[str]:
|
||||
"""
|
||||
Get set of available guardrail names from the guardrail registry.
|
||||
|
||||
Returns:
|
||||
Set of guardrail names
|
||||
"""
|
||||
try:
|
||||
from litellm.proxy.guardrails.guardrail_registry import (
|
||||
IN_MEMORY_GUARDRAIL_HANDLER,
|
||||
)
|
||||
|
||||
guardrails = IN_MEMORY_GUARDRAIL_HANDLER.list_in_memory_guardrails()
|
||||
return {
|
||||
g.get("guardrail_name", "")
|
||||
for g in guardrails
|
||||
if g.get("guardrail_name")
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Could not get guardrails from registry: {str(e)}"
|
||||
)
|
||||
return set()
|
||||
|
||||
async def check_team_alias_exists(self, team_alias: str) -> bool:
|
||||
"""
|
||||
Check if a specific team alias exists in the database.
|
||||
|
||||
Args:
|
||||
team_alias: The team alias to check
|
||||
|
||||
Returns:
|
||||
True if the team alias exists
|
||||
"""
|
||||
if self.prisma_client is None:
|
||||
return True # Can't validate without DB, assume valid
|
||||
|
||||
try:
|
||||
team = await self.prisma_client.db.litellm_teamtable.find_first(
|
||||
where={"team_alias": team_alias},
|
||||
)
|
||||
return team is not None
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Could not check team alias '{team_alias}': {str(e)}"
|
||||
)
|
||||
return True # Assume valid on error
|
||||
|
||||
async def check_key_alias_exists(self, key_alias: str) -> bool:
|
||||
"""
|
||||
Check if a specific key alias exists in the database.
|
||||
|
||||
Args:
|
||||
key_alias: The key alias to check
|
||||
|
||||
Returns:
|
||||
True if the key alias exists
|
||||
"""
|
||||
if self.prisma_client is None:
|
||||
return True # Can't validate without DB, assume valid
|
||||
|
||||
try:
|
||||
key = await self.prisma_client.db.litellm_verificationtoken.find_first(
|
||||
where={"key_alias": key_alias},
|
||||
)
|
||||
return key is not None
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Could not check key alias '{key_alias}': {str(e)}"
|
||||
)
|
||||
return True # Assume valid on error
|
||||
|
||||
def check_model_exists(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a model exists in the router or matches a wildcard pattern.
|
||||
|
||||
Args:
|
||||
model: The model name to check
|
||||
|
||||
Returns:
|
||||
True if the model exists or matches a pattern in the router
|
||||
"""
|
||||
if self.llm_router is None:
|
||||
return True # Can't validate without router, assume valid
|
||||
|
||||
try:
|
||||
# Check if model is in router's model names
|
||||
if model in self.llm_router.model_names:
|
||||
return True
|
||||
|
||||
# Check if model matches any pattern via pattern router
|
||||
if hasattr(self.llm_router, "pattern_router"):
|
||||
pattern_deployments = (
|
||||
self.llm_router.pattern_router.get_deployments_by_pattern(
|
||||
model=model
|
||||
)
|
||||
)
|
||||
if pattern_deployments:
|
||||
return True
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(f"Could not check model '{model}': {str(e)}")
|
||||
return True # Assume valid on error
|
||||
|
||||
def _validate_inheritance_chain(
|
||||
self,
|
||||
policy_name: str,
|
||||
policies: Dict[str, Policy],
|
||||
visited: Optional[Set[str]] = None,
|
||||
max_depth: int = 100,
|
||||
) -> List[PolicyValidationError]:
|
||||
"""
|
||||
Validate the inheritance chain for a policy.
|
||||
|
||||
Checks for:
|
||||
- Parent policy exists
|
||||
- No circular inheritance
|
||||
- Max depth not exceeded
|
||||
|
||||
Args:
|
||||
policy_name: Name of the policy to validate
|
||||
policies: All policies
|
||||
visited: Set of already visited policy names (for cycle detection)
|
||||
max_depth: Maximum recursion depth to prevent infinite loops
|
||||
|
||||
Returns:
|
||||
List of validation errors
|
||||
"""
|
||||
errors: List[PolicyValidationError] = []
|
||||
|
||||
# Prevent infinite recursion
|
||||
if max_depth <= 0:
|
||||
errors.append(
|
||||
PolicyValidationError(
|
||||
policy_name=policy_name,
|
||||
error_type=PolicyValidationErrorType.CIRCULAR_INHERITANCE,
|
||||
message="Inheritance chain too deep (exceeded max depth of 100)",
|
||||
field="inherit",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
if visited is None:
|
||||
visited = set()
|
||||
|
||||
if policy_name in visited:
|
||||
errors.append(
|
||||
PolicyValidationError(
|
||||
policy_name=policy_name,
|
||||
error_type=PolicyValidationErrorType.CIRCULAR_INHERITANCE,
|
||||
message=f"Circular inheritance detected: {' -> '.join(visited)} -> {policy_name}",
|
||||
field="inherit",
|
||||
)
|
||||
)
|
||||
return errors
|
||||
|
||||
policy = policies.get(policy_name)
|
||||
if policy is None:
|
||||
return errors
|
||||
|
||||
if policy.inherit:
|
||||
if policy.inherit not in policies:
|
||||
errors.append(
|
||||
PolicyValidationError(
|
||||
policy_name=policy_name,
|
||||
error_type=PolicyValidationErrorType.INVALID_INHERITANCE,
|
||||
message=f"Parent policy '{policy.inherit}' not found",
|
||||
field="inherit",
|
||||
value=policy.inherit,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Recursively check parent with decremented depth
|
||||
visited.add(policy_name)
|
||||
errors.extend(
|
||||
self._validate_inheritance_chain(
|
||||
policy.inherit, policies, visited, max_depth - 1
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
async def validate_policies(
|
||||
self,
|
||||
policies: Dict[str, Policy],
|
||||
validate_db: bool = True,
|
||||
) -> PolicyValidationResponse:
|
||||
"""
|
||||
Validate a set of policies.
|
||||
|
||||
Args:
|
||||
policies: Dictionary mapping policy names to Policy objects
|
||||
validate_db: Whether to validate against database (teams, keys)
|
||||
|
||||
Returns:
|
||||
PolicyValidationResponse with errors and warnings
|
||||
"""
|
||||
errors: List[PolicyValidationError] = []
|
||||
warnings: List[PolicyValidationError] = []
|
||||
|
||||
# Get available guardrails
|
||||
available_guardrails = self.get_available_guardrails()
|
||||
|
||||
for policy_name, policy in policies.items():
|
||||
# Validate guardrails
|
||||
for guardrail in policy.guardrails.get_add():
|
||||
if available_guardrails and guardrail not in available_guardrails:
|
||||
errors.append(
|
||||
PolicyValidationError(
|
||||
policy_name=policy_name,
|
||||
error_type=PolicyValidationErrorType.INVALID_GUARDRAIL,
|
||||
message=f"Guardrail '{guardrail}' not found in guardrail registry",
|
||||
field="guardrails.add",
|
||||
value=guardrail,
|
||||
)
|
||||
)
|
||||
|
||||
for guardrail in policy.guardrails.get_remove():
|
||||
if available_guardrails and guardrail not in available_guardrails:
|
||||
warnings.append(
|
||||
PolicyValidationError(
|
||||
policy_name=policy_name,
|
||||
error_type=PolicyValidationErrorType.INVALID_GUARDRAIL,
|
||||
message=f"Guardrail '{guardrail}' in remove list not found in guardrail registry",
|
||||
field="guardrails.remove",
|
||||
value=guardrail,
|
||||
)
|
||||
)
|
||||
|
||||
# Validate pipeline if present
|
||||
if policy.pipeline is not None:
|
||||
pipeline_errors = PolicyValidator._validate_pipeline(
|
||||
policy_name=policy_name,
|
||||
policy=policy,
|
||||
available_guardrails=available_guardrails,
|
||||
)
|
||||
errors.extend(pipeline_errors)
|
||||
|
||||
# Validate inheritance
|
||||
inheritance_errors = self._validate_inheritance_chain(
|
||||
policy_name=policy_name, policies=policies
|
||||
)
|
||||
errors.extend(inheritance_errors)
|
||||
|
||||
return PolicyValidationResponse(
|
||||
valid=len(errors) == 0,
|
||||
errors=errors,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _validate_pipeline(
|
||||
policy_name: str,
|
||||
policy: Policy,
|
||||
available_guardrails: Set[str],
|
||||
) -> List[PolicyValidationError]:
|
||||
"""Validate a policy's pipeline configuration."""
|
||||
errors: List[PolicyValidationError] = []
|
||||
pipeline = policy.pipeline
|
||||
if pipeline is None:
|
||||
return errors
|
||||
|
||||
guardrails_add = set(policy.guardrails.get_add())
|
||||
|
||||
for i, step in enumerate(pipeline.steps):
|
||||
# Check guardrail is in policy's guardrails.add
|
||||
if step.guardrail not in guardrails_add:
|
||||
errors.append(
|
||||
PolicyValidationError(
|
||||
policy_name=policy_name,
|
||||
error_type=PolicyValidationErrorType.INVALID_GUARDRAIL,
|
||||
message=(
|
||||
f"Pipeline step {i} guardrail '{step.guardrail}' "
|
||||
f"is not in the policy's guardrails.add list"
|
||||
),
|
||||
field="pipeline.steps",
|
||||
value=step.guardrail,
|
||||
)
|
||||
)
|
||||
|
||||
# Check guardrail exists in registry
|
||||
if available_guardrails and step.guardrail not in available_guardrails:
|
||||
errors.append(
|
||||
PolicyValidationError(
|
||||
policy_name=policy_name,
|
||||
error_type=PolicyValidationErrorType.INVALID_GUARDRAIL,
|
||||
message=(
|
||||
f"Pipeline step {i} guardrail '{step.guardrail}' "
|
||||
f"not found in guardrail registry"
|
||||
),
|
||||
field="pipeline.steps",
|
||||
value=step.guardrail,
|
||||
)
|
||||
)
|
||||
|
||||
return errors
|
||||
|
||||
async def validate_policy_config(
|
||||
self,
|
||||
policy_config: Dict[str, Any],
|
||||
validate_db: bool = True,
|
||||
) -> PolicyValidationResponse:
|
||||
"""
|
||||
Validate a raw policy configuration dictionary.
|
||||
|
||||
This parses the config and then validates it.
|
||||
|
||||
Args:
|
||||
policy_config: Raw policy configuration from YAML
|
||||
validate_db: Whether to validate against database
|
||||
|
||||
Returns:
|
||||
PolicyValidationResponse with errors and warnings
|
||||
"""
|
||||
from litellm.proxy.policy_engine.policy_registry import PolicyRegistry
|
||||
|
||||
# First, try to parse the policies
|
||||
errors: List[PolicyValidationError] = []
|
||||
policies: Dict[str, Policy] = {}
|
||||
|
||||
temp_registry = PolicyRegistry()
|
||||
|
||||
for policy_name, policy_data in policy_config.items():
|
||||
try:
|
||||
policy = temp_registry._parse_policy(policy_name, policy_data)
|
||||
policies[policy_name] = policy
|
||||
except Exception as e:
|
||||
errors.append(
|
||||
PolicyValidationError(
|
||||
policy_name=policy_name,
|
||||
error_type=PolicyValidationErrorType.INVALID_SYNTAX,
|
||||
message=f"Failed to parse policy: {str(e)}",
|
||||
)
|
||||
)
|
||||
|
||||
# If there were parsing errors, return early
|
||||
if errors:
|
||||
return PolicyValidationResponse(
|
||||
valid=False,
|
||||
errors=errors,
|
||||
warnings=[],
|
||||
)
|
||||
|
||||
# Validate the parsed policies
|
||||
return await self.validate_policies(policies, validate_db=validate_db)
|
||||
Reference in New Issue
Block a user