chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,60 @@
"""
LiteLLM Policy Engine
The Policy Engine allows administrators to define policies that combine guardrails
with scoping rules. Policies can target specific teams, API keys, and models using
wildcard patterns, and support inheritance from base policies.
Configuration structure:
- `policies`: Define WHAT guardrails to apply (with inheritance and conditions)
- `policy_attachments`: Define WHERE policies apply (teams, keys, models)
Example:
```yaml
policies:
global-baseline:
description: "Base guardrails for all requests"
guardrails:
add: [pii_blocker]
gpt4-safety:
inherit: global-baseline
description: "Extra safety for GPT-4"
guardrails:
add: [toxicity_filter]
condition:
model: "gpt-4.*" # regex pattern
policy_attachments:
- policy: global-baseline
scope: "*"
- policy: gpt4-safety
scope: "*"
```
"""
from litellm.proxy.policy_engine.attachment_registry import (
AttachmentRegistry,
get_attachment_registry,
)
from litellm.proxy.policy_engine.condition_evaluator import ConditionEvaluator
from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
from litellm.proxy.policy_engine.policy_registry import (
PolicyRegistry,
get_policy_registry,
)
from litellm.proxy.policy_engine.policy_resolver import PolicyResolver
from litellm.proxy.policy_engine.policy_validator import PolicyValidator
# Public API of the policy engine package: every name listed here is imported
# at the top of this module and re-exported for `from ... import *` users.
__all__ = [
# Registries
"PolicyRegistry",
"get_policy_registry",
"AttachmentRegistry",
"get_attachment_registry",
# Core components
"PolicyMatcher",
"PolicyResolver",
"PolicyValidator",
"ConditionEvaluator",
]

View File

@@ -0,0 +1,54 @@
# Policy Engine Architecture
## Overview
The Policy Engine allows administrators to define policies that combine guardrails with scoping rules. Policies can target specific teams, API keys, and models using wildcard patterns, and support inheritance from base policies.
## Architecture Diagram
```mermaid
flowchart TD
subgraph Config["config.yaml"]
PC[policies config]
end
subgraph PolicyEngine["Policy Engine"]
PR[PolicyRegistry]
PV[PolicyValidator]
PM[PolicyMatcher]
PRe[PolicyResolver]
end
subgraph Request["Incoming Request"]
CTX[Context: team_alias, key_alias, model]
end
subgraph Output["Output"]
GR[Guardrails to Apply]
end
PC -->|load| PR
PC -->|validate| PV
PV -->|errors/warnings| PR
CTX -->|match| PM
PM -->|matching policies| PRe
PR -->|policies| PM
PR -->|policies| PRe
PRe -->|resolve inheritance + add/remove| GR
```
## Components
| Component | File | Description |
|-----------|------|-------------|
| **PolicyRegistry** | `policy_registry.py` | In-memory singleton store for parsed policies |
| **PolicyValidator** | `policy_validator.py` | Validates configs (guardrails, inheritance, teams/keys/models) |
| **PolicyMatcher** | `policy_matcher.py` | Matches request context against policy scopes |
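| **PolicyResolver** | `policy_resolver.py` | Resolves final guardrails via inheritance chain |
| **AttachmentRegistry** | `attachment_registry.py` | In-memory store for policy attachments (which scopes each policy applies to) |
| **ConditionEvaluator** | `condition_evaluator.py` | Evaluates policy `condition` blocks (model exact match or regex) |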
## Flow
1. **Startup**: `init_policies()` loads policies from config, validates, and populates `PolicyRegistry`
2. **Request**: `PolicyMatcher` finds policies matching the request's team/key/model
3. **Resolution**: `PolicyResolver` traverses inheritance and applies add/remove to get final guardrails

View File

@@ -0,0 +1,501 @@
"""
Attachment Registry - Manages policy attachments from YAML config.
Attachments define WHERE policies apply, separate from the policy definitions.
This allows the same policy to be attached to multiple scopes.
"""
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.types.proxy.policy_engine import (
PolicyAttachment,
PolicyAttachmentCreateRequest,
PolicyAttachmentDBResponse,
PolicyMatchContext,
)
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
class AttachmentRegistry:
    """
    In-memory registry for storing and managing policy attachments.

    Attachments define the relationship between policies and their scopes.
    A single policy can have multiple attachments (applied to different scopes).

    Example YAML:
    ```yaml
    policy_attachments:
      - policy: global-baseline
        scope: "*"
      - policy: healthcare-compliance
        teams: [healthcare-team]
      - policy: dev-safety
        keys: ["dev-key-*"]
    ```
    """

    def __init__(self):
        # Attachments currently in effect, in load/insert order.
        self._attachments: List[PolicyAttachment] = []
        # True once load_attachments() or sync_attachments_from_db() succeeded.
        self._initialized: bool = False

    def load_attachments(self, attachments_config: List[Dict[str, Any]]) -> None:
        """
        Replace the registry contents with attachments from a configuration list.

        The new list is built completely before it replaces the current one, so
        a malformed entry leaves the previously loaded attachments untouched.

        Args:
            attachments_config: List of attachment dictionaries from YAML.

        Raises:
            ValueError: If any entry cannot be parsed into a PolicyAttachment.
        """
        parsed: List[PolicyAttachment] = []
        for attachment_data in attachments_config:
            try:
                attachment = self._parse_attachment(attachment_data)
            except Exception as e:
                verbose_proxy_logger.error(f"Error loading attachment: {str(e)}")
                raise ValueError(f"Invalid attachment: {str(e)}") from e
            parsed.append(attachment)
            verbose_proxy_logger.debug(
                f"Loaded attachment for policy: {attachment.policy}"
            )

        # Swap only after every entry parsed successfully.
        self._attachments = parsed
        self._initialized = True
        verbose_proxy_logger.info(f"Loaded {len(self._attachments)} policy attachments")

    def _parse_attachment(self, attachment_data: Dict[str, Any]) -> PolicyAttachment:
        """
        Parse an attachment from raw configuration data.

        Args:
            attachment_data: Raw attachment configuration

        Returns:
            Parsed PolicyAttachment object
        """
        return PolicyAttachment(
            policy=attachment_data.get("policy", ""),
            scope=attachment_data.get("scope"),
            teams=attachment_data.get("teams"),
            keys=attachment_data.get("keys"),
            models=attachment_data.get("models"),
            tags=attachment_data.get("tags"),
        )

    def get_attached_policies(self, context: PolicyMatchContext) -> List[str]:
        """
        Get list of policy names attached to the given context.

        Args:
            context: The request context to match against

        Returns:
            List of policy names that are attached to matching scopes
        """
        return [
            r["policy_name"] for r in self.get_attached_policies_with_reasons(context)
        ]

    def get_attached_policies_with_reasons(
        self, context: PolicyMatchContext
    ) -> List[Dict[str, Any]]:
        """
        Get list of policy names and match reasons for the given context.

        Returns a list of dicts with 'policy_name' and 'matched_via' keys.
        The 'matched_via' describes which dimension caused the match.
        Each policy appears at most once; the first matching attachment wins.
        """
        # Imported lazily (presumably to avoid a circular import between the
        # registry and the matcher - confirm before moving to module level).
        from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher

        results: List[Dict[str, Any]] = []
        seen_policies: set = set()
        for attachment in self._attachments:
            scope = attachment.to_policy_scope()
            if not PolicyMatcher.scope_matches(scope=scope, context=context):
                continue
            if attachment.policy in seen_policies:
                continue
            seen_policies.add(attachment.policy)
            matched_via = self._describe_match_reason(attachment, context)
            results.append(
                {
                    "policy_name": attachment.policy,
                    "matched_via": matched_via,
                }
            )
            verbose_proxy_logger.debug(
                f"Attachment matched: policy={attachment.policy}, "
                f"matched_via={matched_via}, "
                f"context=(team={context.team_alias}, key={context.key_alias}, model={context.model})"
            )
        return results

    @staticmethod
    def _describe_match_reason(
        attachment: PolicyAttachment, context: PolicyMatchContext
    ) -> str:
        """
        Describe why an attachment matched the context.

        Returns "scope:*" for global attachments, otherwise a "+"-joined list of
        the dimensions (tag/team/key/model) that are set on both the attachment
        and the context, or "scope:default" when none applies.
        """
        from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher

        if attachment.is_global():
            return "scope:*"

        reasons = []
        if attachment.tags and context.tags:
            matching_tags = [
                t
                for t in context.tags
                if PolicyMatcher.matches_pattern(t, attachment.tags)
            ]
            if matching_tags:
                # Only the first matching tag is reported.
                reasons.append(f"tag:{matching_tags[0]}")
        if attachment.teams and context.team_alias:
            reasons.append(f"team:{context.team_alias}")
        if attachment.keys and context.key_alias:
            reasons.append(f"key:{context.key_alias}")
        if attachment.models and context.model:
            reasons.append(f"model:{context.model}")
        return "+".join(reasons) if reasons else "scope:default"

    def is_policy_attached(self, policy_name: str, context: PolicyMatchContext) -> bool:
        """
        Check if a specific policy is attached to the given context.

        Args:
            policy_name: Name of the policy to check
            context: The request context to match against

        Returns:
            True if the policy is attached to a matching scope
        """
        attached = self.get_attached_policies(context)
        return policy_name in attached

    def get_all_attachments(self) -> List[PolicyAttachment]:
        """
        Get all loaded attachments.

        Returns:
            A shallow copy of the attachment list (mutating it does not affect
            the registry)
        """
        return self._attachments.copy()

    def get_attachments_for_policy(self, policy_name: str) -> List[PolicyAttachment]:
        """
        Get all attachments for a specific policy.

        Args:
            policy_name: Name of the policy

        Returns:
            List of attachments for the policy
        """
        return [a for a in self._attachments if a.policy == policy_name]

    def is_initialized(self) -> bool:
        """
        Check if the registry has been initialized with attachments.

        Returns:
            True if attachments have been loaded, False otherwise
        """
        return self._initialized

    def clear(self) -> None:
        """
        Clear all attachments from the registry and mark it uninitialized.
        """
        self._attachments = []
        self._initialized = False

    def add_attachment(self, attachment: PolicyAttachment) -> None:
        """
        Add a single attachment.

        Args:
            attachment: PolicyAttachment object to add
        """
        self._attachments.append(attachment)
        verbose_proxy_logger.debug(f"Added attachment for policy: {attachment.policy}")

    def remove_attachments_for_policy(self, policy_name: str) -> int:
        """
        Remove all attachments for a specific policy.

        Args:
            policy_name: Name of the policy

        Returns:
            Number of attachments removed
        """
        original_count = len(self._attachments)
        self._attachments = [a for a in self._attachments if a.policy != policy_name]
        removed_count = original_count - len(self._attachments)
        if removed_count > 0:
            verbose_proxy_logger.debug(
                f"Removed {removed_count} attachment(s) for policy: {policy_name}"
            )
        return removed_count

    def remove_attachment_by_id(self, attachment_id: str) -> bool:
        """
        Remove an attachment by its ID (for DB-synced attachments).

        Args:
            attachment_id: The ID of the attachment to remove

        Returns:
            Always False: in-memory attachments carry no ID, so nothing can be
            removed here. Use sync_attachments_from_db() after a DB delete.
        """
        # Note: In-memory attachments don't have IDs, so this is primarily
        # for consistency after DB operations
        return False

    # ─────────────────────────────────────────────────────────────────────────
    # Database CRUD Methods
    # ─────────────────────────────────────────────────────────────────────────

    @staticmethod
    def _to_db_response(row: Any) -> PolicyAttachmentDBResponse:
        """Convert a policy-attachment DB row into a PolicyAttachmentDBResponse."""
        return PolicyAttachmentDBResponse(
            attachment_id=row.attachment_id,
            policy_name=row.policy_name,
            scope=row.scope,
            teams=row.teams or [],
            keys=row.keys or [],
            models=row.models or [],
            tags=row.tags or [],
            created_at=row.created_at,
            updated_at=row.updated_at,
            created_by=row.created_by,
            updated_by=row.updated_by,
        )

    async def add_attachment_to_db(
        self,
        attachment_request: PolicyAttachmentCreateRequest,
        prisma_client: "PrismaClient",
        created_by: Optional[str] = None,
    ) -> PolicyAttachmentDBResponse:
        """
        Add a policy attachment to the database and mirror it in memory.

        Args:
            attachment_request: The attachment creation request
            prisma_client: The Prisma client instance
            created_by: User who created the attachment

        Returns:
            PolicyAttachmentDBResponse with the created attachment

        Raises:
            Exception: If the DB insert or the in-memory update fails.
        """
        try:
            # One timestamp so created_at and updated_at are identical on insert.
            now = datetime.now(timezone.utc)
            created_attachment = (
                await prisma_client.db.litellm_policyattachmenttable.create(
                    data={
                        "policy_name": attachment_request.policy_name,
                        "scope": attachment_request.scope,
                        "teams": attachment_request.teams or [],
                        "keys": attachment_request.keys or [],
                        "models": attachment_request.models or [],
                        "tags": attachment_request.tags or [],
                        "created_at": now,
                        "updated_at": now,
                        "created_by": created_by,
                        "updated_by": created_by,
                    }
                )
            )

            # Also add to in-memory registry. Empty lists are normalized to None
            # exactly as sync_attachments_from_db() does, so the in-memory state
            # right after creation equals the state after the next DB sync.
            attachment = PolicyAttachment(
                policy=attachment_request.policy_name,
                scope=attachment_request.scope,
                teams=attachment_request.teams or None,
                keys=attachment_request.keys or None,
                models=attachment_request.models or None,
                tags=attachment_request.tags or None,
            )
            self.add_attachment(attachment)

            return self._to_db_response(created_attachment)
        except Exception as e:
            verbose_proxy_logger.exception(f"Error adding attachment to DB: {e}")
            raise Exception(f"Error adding attachment to DB: {str(e)}") from e

    async def delete_attachment_from_db(
        self,
        attachment_id: str,
        prisma_client: "PrismaClient",
    ) -> Dict[str, str]:
        """
        Delete a policy attachment from the database and re-sync memory.

        Args:
            attachment_id: The ID of the attachment to delete
            prisma_client: The Prisma client instance

        Returns:
            Dict with success message

        Raises:
            Exception: If the attachment does not exist or the delete/sync fails.
        """
        try:
            # Get attachment before deleting
            attachment = (
                await prisma_client.db.litellm_policyattachmenttable.find_unique(
                    where={"attachment_id": attachment_id}
                )
            )
            if attachment is None:
                raise Exception(f"Attachment with ID {attachment_id} not found")

            # Delete from DB
            await prisma_client.db.litellm_policyattachmenttable.delete(
                where={"attachment_id": attachment_id}
            )

            # Note: In-memory attachments don't have IDs, so we need to sync from DB
            # to properly update in-memory state
            await self.sync_attachments_from_db(prisma_client)

            return {"message": f"Attachment {attachment_id} deleted successfully"}
        except Exception as e:
            verbose_proxy_logger.exception(f"Error deleting attachment from DB: {e}")
            raise Exception(f"Error deleting attachment from DB: {str(e)}") from e

    async def get_attachment_by_id_from_db(
        self,
        attachment_id: str,
        prisma_client: "PrismaClient",
    ) -> Optional[PolicyAttachmentDBResponse]:
        """
        Get a policy attachment by ID from the database.

        Args:
            attachment_id: The ID of the attachment to retrieve
            prisma_client: The Prisma client instance

        Returns:
            PolicyAttachmentDBResponse if found, None otherwise

        Raises:
            Exception: If the DB lookup fails.
        """
        try:
            attachment = (
                await prisma_client.db.litellm_policyattachmenttable.find_unique(
                    where={"attachment_id": attachment_id}
                )
            )
            if attachment is None:
                return None
            return self._to_db_response(attachment)
        except Exception as e:
            verbose_proxy_logger.exception(f"Error getting attachment from DB: {e}")
            raise Exception(f"Error getting attachment from DB: {str(e)}") from e

    async def get_all_attachments_from_db(
        self,
        prisma_client: "PrismaClient",
    ) -> List[PolicyAttachmentDBResponse]:
        """
        Get all policy attachments from the database, newest first.

        Args:
            prisma_client: The Prisma client instance

        Returns:
            List of PolicyAttachmentDBResponse objects

        Raises:
            Exception: If the DB query fails.
        """
        try:
            attachments = (
                await prisma_client.db.litellm_policyattachmenttable.find_many(
                    order={"created_at": "desc"},
                )
            )
            return [self._to_db_response(a) for a in attachments]
        except Exception as e:
            verbose_proxy_logger.exception(f"Error getting attachments from DB: {e}")
            raise Exception(f"Error getting attachments from DB: {str(e)}") from e

    async def sync_attachments_from_db(
        self,
        prisma_client: "PrismaClient",
    ) -> None:
        """
        Replace the in-memory attachments with the rows stored in the database.

        The replacement list is built first and swapped in at the end, so a
        failure while reading or converting rows keeps the previous in-memory
        attachments in place.

        Args:
            prisma_client: The Prisma client instance

        Raises:
            Exception: If reading from the DB or converting a row fails.
        """
        try:
            attachments = await self.get_all_attachments_from_db(prisma_client)

            # Empty lists coming back from the DB mean "dimension not set".
            refreshed = [
                PolicyAttachment(
                    policy=row.policy_name,
                    scope=row.scope,
                    teams=row.teams or None,
                    keys=row.keys or None,
                    models=row.models or None,
                    tags=row.tags or None,
                )
                for row in attachments
            ]

            self._attachments = refreshed
            self._initialized = True
            verbose_proxy_logger.info(
                f"Synced {len(attachments)} attachments from DB to in-memory registry"
            )
        except Exception as e:
            verbose_proxy_logger.exception(f"Error syncing attachments from DB: {e}")
            raise Exception(f"Error syncing attachments from DB: {str(e)}") from e
# Process-wide AttachmentRegistry; created on first use by get_attachment_registry()
_attachment_registry: Optional[AttachmentRegistry] = None


def get_attachment_registry() -> AttachmentRegistry:
    """
    Return the shared AttachmentRegistry, creating it on the first call.

    Returns:
        The module-level AttachmentRegistry instance
    """
    global _attachment_registry
    registry = _attachment_registry
    if registry is None:
        registry = AttachmentRegistry()
        _attachment_registry = registry
    return registry

View File

@@ -0,0 +1,111 @@
"""
Condition Evaluator - Evaluates policy conditions.
Supports model-based conditions with exact match or regex patterns.
"""
import re
from typing import List, Optional, Union
from litellm._logging import verbose_proxy_logger
from litellm.types.proxy.policy_engine import (
PolicyCondition,
PolicyMatchContext,
)
class ConditionEvaluator:
    """
    Decides whether a policy's condition holds for a request context.

    Only model conditions are evaluated. A model condition may be:
    - an exact string: "gpt-4"
    - a regex that must match the whole model name: "gpt-4.*"
    - a list of either: ["gpt-4", "gpt-4-turbo"]
    """

    @staticmethod
    def evaluate(
        condition: Optional[PolicyCondition],
        context: PolicyMatchContext,
    ) -> bool:
        """
        Check a policy condition against a request context.

        Args:
            condition: The condition to check; None means "no restriction"
            context: The request context (its `model` is what gets compared)

        Returns:
            True when there is no condition, no model condition, or the model
            condition matches; False otherwise
        """
        # Nothing to check: absent condition or condition without a model part
        if condition is None or condition.model is None:
            return True

        matched = ConditionEvaluator._evaluate_model_condition(
            condition=condition.model,
            model=context.model,
        )
        if not matched:
            verbose_proxy_logger.debug(
                f"Condition failed: model={context.model} did not match {condition.model}"
            )
        return matched

    @staticmethod
    def _evaluate_model_condition(
        condition: Union[str, List[str]],
        model: Optional[str],
    ) -> bool:
        """
        Check a model name against one pattern or a list of patterns.

        Args:
            condition: A single pattern or a list of patterns
            model: The model name from the request, possibly None

        Returns:
            True if any pattern matches; False if none does or model is None
        """
        if model is None:
            return False

        # Treat the single-pattern form as a one-element list
        patterns = condition if isinstance(condition, list) else [condition]
        return any(
            ConditionEvaluator._matches_pattern(candidate, model)
            for candidate in patterns
        )

    @staticmethod
    def _matches_pattern(pattern: str, value: str) -> bool:
        """
        Compare a value with a pattern, literally first and then as a regex.

        Args:
            pattern: Literal string or regular expression
            value: The string being tested

        Returns:
            True on literal equality or a full regex match, False otherwise
            (including when the pattern is not a valid regex)
        """
        if pattern == value:
            return True

        try:
            return re.fullmatch(pattern, value) is not None
        except re.error:
            # Not a valid regex; the literal comparison above already failed
            return False

View File

@@ -0,0 +1,286 @@
"""
Policy Initialization - Loads policies from config and validates on startup.
Configuration structure:
- policies: Define WHAT guardrails to apply (with inheritance and conditions)
- policy_attachments: Define WHERE policies apply (teams, keys, models)
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
from litellm.proxy.policy_engine.policy_validator import PolicyValidator
from litellm.types.proxy.policy_engine import PolicyValidationResponse
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
# ANSI color codes for terminal output
_green_color_code = "\033[92m"
_blue_color_code = "\033[94m"
_yellow_color_code = "\033[93m"
_reset_color_code = "\033[0m"


def _print_policies_on_startup(
    policies_config: Dict[str, Any],
    policy_attachments_config: Optional[List[Dict[str, Any]]] = None,
) -> None:
    """
    Print loaded policies to console on startup (similar to model list).

    This is called on the raw configuration, before validation, so malformed
    entries (e.g. a policy whose YAML value is empty and therefore None) are
    tolerated here instead of raising; reporting them is left to validation.

    Args:
        policies_config: Dictionary mapping policy names to policy definitions
        policy_attachments_config: Optional list of policy attachment dicts;
            when empty or None a warning is printed instead.
    """
    import sys

    print(  # noqa: T201
        f"{_green_color_code}\nLiteLLM Policy Engine: Loaded {len(policies_config)} policies{_reset_color_code}\n"
    )
    sys.stdout.flush()

    for policy_name, policy_data in policies_config.items():
        # Unvalidated input: a non-mapping entry is printed with name only.
        if not isinstance(policy_data, dict):
            policy_data = {}
        guardrails = policy_data.get("guardrails", {})
        inherit = policy_data.get("inherit")
        condition = policy_data.get("condition")
        description = policy_data.get("description")

        guardrails_add = (
            guardrails.get("add", []) if isinstance(guardrails, dict) else []
        )
        guardrails_remove = (
            guardrails.get("remove", []) if isinstance(guardrails, dict) else []
        )
        inherit_str = f" (inherits: {inherit})" if inherit else ""

        print(  # noqa: T201
            f"{_blue_color_code} - {policy_name}{inherit_str}{_reset_color_code}"
        )
        if description:
            print(f" description: {description}")  # noqa: T201
        if guardrails_add:
            print(f" guardrails.add: {guardrails_add}")  # noqa: T201
        if guardrails_remove:
            print(f" guardrails.remove: {guardrails_remove}")  # noqa: T201
        if condition:
            model_condition = (
                condition.get("model") if isinstance(condition, dict) else None
            )
            if model_condition:
                print(f" condition.model: {model_condition}")  # noqa: T201

    # Print attachments
    if policy_attachments_config:
        print(  # noqa: T201
            f"\n{_yellow_color_code}Policy Attachments: {len(policy_attachments_config)} attachment(s){_reset_color_code}"
        )
        for attachment in policy_attachments_config:
            if not isinstance(attachment, dict):
                print(f" - (invalid attachment entry: {attachment!r})")  # noqa: T201
                continue
            policy = attachment.get("policy", "unknown")
            scope = attachment.get("scope")
            teams = attachment.get("teams")
            keys = attachment.get("keys")
            models = attachment.get("models")
            tags = attachment.get("tags")

            scope_parts = []
            if scope == "*":
                scope_parts.append("scope=* (global)")
            if teams:
                scope_parts.append(f"teams={teams}")
            if keys:
                scope_parts.append(f"keys={keys}")
            if models:
                scope_parts.append(f"models={models}")
            # Tags are a scoping dimension too; without this a tags-only
            # attachment would be reported as applying to "all".
            if tags:
                scope_parts.append(f"tags={tags}")

            scope_str = ", ".join(scope_parts) if scope_parts else "all"
            print(f" - {policy} -> {scope_str}")  # noqa: T201
    else:
        print(  # noqa: T201
            f"\n{_yellow_color_code}Warning: No policy_attachments configured. Policies will not be applied to any requests.{_reset_color_code}"
        )

    print()  # noqa: T201
    sys.stdout.flush()
async def init_policies(
    policies_config: Dict[str, Any],
    policy_attachments_config: Optional[List[Dict[str, Any]]] = None,
    prisma_client: Optional["PrismaClient"] = None,
    validate_db: bool = True,
    fail_on_error: bool = True,
) -> PolicyValidationResponse:
    """
    Validate a policy configuration and load it into the global registries.

    Steps, in order:
    1. Print the configured policies/attachments to the console
    2. Validate the policy configuration with PolicyValidator
    3. Log every validation error and warning
    4. Load the policies into the policy registry
    5. Load the attachments into the attachment registry (only if provided)

    Args:
        policies_config: Dictionary mapping policy names to policy definitions
        policy_attachments_config: Optional list of policy attachment configurations
        prisma_client: Optional Prisma client handed to the validator
        validate_db: Forwarded to the validator's `validate_db` argument
        fail_on_error: If True, raise instead of loading when validation fails

    Returns:
        The PolicyValidationResponse produced by the validator

    Raises:
        ValueError: If fail_on_error is True and the result is not valid
    """
    verbose_proxy_logger.info(f"Initializing {len(policies_config)} policies...")

    # Console overview happens first, on the raw configuration
    _print_policies_on_startup(policies_config, policy_attachments_config)

    policy_registry = get_policy_registry()
    attachment_registry = get_attachment_registry()

    validator = PolicyValidator(prisma_client=prisma_client)
    validation_result = await validator.validate_policy_config(
        policies_config,
        validate_db=validate_db,
    )

    # Surface every finding in the proxy log
    for error in validation_result.errors or []:
        verbose_proxy_logger.error(
            f"Policy validation error in '{error.policy_name}': "
            f"[{error.error_type}] {error.message}"
        )
    for warning in validation_result.warnings or []:
        verbose_proxy_logger.warning(
            f"Policy validation warning in '{warning.policy_name}': "
            f"[{warning.error_type}] {warning.message}"
        )

    # Abort before touching the registries when the caller asked for strictness
    if fail_on_error and not validation_result.valid:
        details = "\n".join(
            f"[{e.policy_name}] {e.message}" for e in validation_result.errors
        )
        raise ValueError(
            f"Policy validation failed with {len(validation_result.errors)} error(s):\n"
            + details
        )

    # Warnings alone do not prevent loading
    try:
        policy_registry.load_policies(policies_config)
        verbose_proxy_logger.info(
            f"Successfully loaded {len(policies_config)} policies"
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Failed to load policies: {str(e)}")
        raise

    if not policy_attachments_config:
        return validation_result

    try:
        attachment_registry.load_attachments(policy_attachments_config)
        verbose_proxy_logger.info(
            f"Successfully loaded {len(policy_attachments_config)} policy attachments"
        )
    except Exception as e:
        verbose_proxy_logger.error(f"Failed to load policy attachments: {str(e)}")
        raise

    return validation_result
def init_policies_sync(
    policies_config: Dict[str, Any],
    policy_attachments_config: Optional[List[Dict[str, Any]]] = None,
    fail_on_error: bool = True,
) -> None:
    """
    Synchronous version of init_policies (without DB validation).

    Use this when async is not available or DB validation is not needed.
    It is also safe to call from a thread whose event loop is already running:
    in that case the initialization runs to completion on a short-lived worker
    thread with its own loop, because `run_until_complete()` cannot be used on
    a loop that is already running.

    Args:
        policies_config: Dictionary mapping policy names to policy definitions
        policy_attachments_config: Optional list of policy attachment configurations
        fail_on_error: If True, raise exception on validation errors

    Raises:
        ValueError: Propagated from init_policies when validation fails and
            fail_on_error is True.
    """
    import asyncio
    import concurrent.futures

    def _build_coroutine():
        # Created on demand so no coroutine is left un-awaited on any path.
        return init_policies(
            policies_config=policies_config,
            policy_attachments_config=policy_attachments_config,
            prisma_client=None,
            validate_db=False,
            fail_on_error=fail_on_error,
        )

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        loop_is_running = False
    else:
        loop_is_running = True

    if loop_is_running:
        # No DB client is involved, so nothing is bound to the caller's loop
        # and the work can be done on a separate loop in a worker thread.
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(lambda: asyncio.run(_build_coroutine())).result()
        return

    # Run the async function without DB validation
    try:
        loop = asyncio.get_event_loop()
        if loop.is_closed():
            # A closed loop cannot run anything; fall through to a fresh one.
            raise RuntimeError("current event loop is closed")
    except RuntimeError:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
    loop.run_until_complete(_build_coroutine())
def get_policies_summary() -> Dict[str, Any]:
    """
    Get a summary of loaded policies for debugging/display.

    Returns:
        Dictionary with policy information. When the policy registry is not
        initialized only `initialized` (False), `policies` and `attachments`
        are present. Otherwise it also contains `policy_count`,
        `attachment_count`, one entry per resolved policy, and one entry per
        attachment (policy, scope, teams, keys, models, tags).
    """
    from litellm.proxy.policy_engine.policy_resolver import PolicyResolver

    policy_registry = get_policy_registry()
    attachment_registry = get_attachment_registry()

    if not policy_registry.is_initialized():
        return {"initialized": False, "policies": {}, "attachments": []}

    resolved = PolicyResolver.get_all_resolved_policies()
    # One snapshot, so the count and the listed attachments always agree.
    attachments = attachment_registry.get_all_attachments()

    summary: Dict[str, Any] = {
        "initialized": True,
        "policy_count": len(resolved),
        "attachment_count": len(attachments),
        "policies": {},
        "attachments": [],
    }

    for policy_name, resolved_policy in resolved.items():
        # The registry entry may be missing for a resolved name; fall back to
        # empty values rather than failing the whole summary.
        policy = policy_registry.get_policy(policy_name)
        summary["policies"][policy_name] = {
            "inherit": policy.inherit if policy else None,
            "description": policy.description if policy else None,
            "guardrails_add": policy.guardrails.get_add() if policy else [],
            "guardrails_remove": policy.guardrails.get_remove() if policy else [],
            "condition": policy.condition.model_dump()
            if policy and policy.condition
            else None,
            "resolved_guardrails": resolved_policy.guardrails,
            "inheritance_chain": resolved_policy.inheritance_chain,
        }

    # Add attachment info (including tags, which are a scoping dimension too)
    for attachment in attachments:
        summary["attachments"].append(
            {
                "policy": attachment.policy,
                "scope": attachment.scope,
                "teams": attachment.teams,
                "keys": attachment.keys,
                "models": attachment.models,
                "tags": attachment.tags,
            }
        )

    return summary

View File

@@ -0,0 +1,217 @@
"""
Pipeline Executor - Executes guardrail pipelines with conditional step logic.
Runs guardrails sequentially per pipeline step definitions, handling
pass/fail actions (allow, block, next, modify_response) and data forwarding.
"""
import time
from typing import Any, List, Optional
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.integrations.custom_guardrail import (
CustomGuardrail,
ModifyResponseException,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import (
UnifiedLLMGuardrails,
)
from litellm.types.proxy.policy_engine.pipeline_types import (
PipelineExecutionResult,
PipelineStep,
PipelineStepResult,
)
try:
from fastapi.exceptions import HTTPException
except ImportError:
HTTPException = None # type: ignore
class PipelineExecutor:
"""Executes guardrail pipelines with ordered, conditional step logic."""
@staticmethod
async def execute_steps(
steps: List[PipelineStep],
mode: str,
data: dict,
user_api_key_dict: Any,
call_type: str,
policy_name: str,
) -> PipelineExecutionResult:
"""
Execute pipeline steps sequentially with conditional actions.
Args:
steps: Ordered list of pipeline steps
mode: Event hook mode (pre_call, post_call)
data: Request data dict
user_api_key_dict: User API key auth
call_type: Type of call (completion, etc.)
policy_name: Name of the owning policy (for logging)
Returns:
PipelineExecutionResult with terminal action and step results
"""
step_results: List[PipelineStepResult] = []
working_data = data.copy()
if "metadata" in working_data:
working_data["metadata"] = working_data["metadata"].copy()
for i, step in enumerate(steps):
start_time = time.perf_counter()
outcome, modified_data, error_detail = await PipelineExecutor._run_step(
step=step,
mode=mode,
data=working_data,
user_api_key_dict=user_api_key_dict,
call_type=call_type,
)
duration = time.perf_counter() - start_time
action = step.on_pass if outcome == "pass" else step.on_fail
step_result = PipelineStepResult(
guardrail_name=step.guardrail,
outcome=outcome,
action_taken=action,
modified_data=modified_data,
error_detail=error_detail,
duration_seconds=round(duration, 4),
)
step_results.append(step_result)
verbose_proxy_logger.debug(
f"Pipeline '{policy_name}' step {i}: guardrail={step.guardrail}, "
f"outcome={outcome}, action={action}"
)
# Forward modified data to next step if pass_data is True
if step.pass_data and modified_data is not None:
working_data = {**working_data, **modified_data}
# Handle terminal actions
if action == "allow":
return PipelineExecutionResult(
terminal_action="allow",
step_results=step_results,
modified_data=working_data if working_data != data else None,
)
if action == "block":
return PipelineExecutionResult(
terminal_action="block",
step_results=step_results,
error_message=error_detail,
)
if action == "modify_response":
return PipelineExecutionResult(
terminal_action="modify_response",
step_results=step_results,
modify_response_message=step.modify_response_message
or error_detail,
)
# action == "next" → continue to next step
# Ran out of steps without a terminal action → default allow
return PipelineExecutionResult(
terminal_action="allow",
step_results=step_results,
modified_data=working_data if working_data != data else None,
)
@staticmethod
async def _run_step(
step: PipelineStep,
mode: str,
data: dict,
user_api_key_dict: Any,
call_type: str,
) -> tuple:
"""
Run a single pipeline step's guardrail.
Returns:
Tuple of (outcome, modified_data, error_detail) where:
- outcome: "pass", "fail", or "error"
- modified_data: dict if guardrail returned modified data, else None
- error_detail: error message string if fail/error, else None
"""
callback = PipelineExecutor._find_guardrail_callback(step.guardrail)
if callback is None:
verbose_proxy_logger.warning(
f"Pipeline: guardrail '{step.guardrail}' not found in callbacks"
)
return ("error", None, f"Guardrail '{step.guardrail}' not found")
try:
# Inject guardrail name into metadata so should_run_guardrail() allows it
if "metadata" not in data:
data["metadata"] = {}
data["metadata"]["guardrails"] = [step.guardrail]
# Use unified_guardrail path if callback implements apply_guardrail
target: CustomLogger = callback
use_unified = "apply_guardrail" in type(callback).__dict__
if use_unified:
data["guardrail_to_apply"] = callback
target = UnifiedLLMGuardrails()
if mode == "pre_call":
response = await target.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=None, # type: ignore
data=data,
call_type=call_type, # type: ignore
)
elif mode == "post_call":
response = await target.async_post_call_success_hook(
user_api_key_dict=user_api_key_dict,
data=data,
response=data.get("response"), # type: ignore
)
else:
return ("error", None, f"Unsupported pipeline mode: {mode}")
# Normal return means pass
modified_data = None
if response is not None and isinstance(response, dict):
modified_data = response
return ("pass", modified_data, None)
except Exception as e:
if CustomGuardrail._is_guardrail_intervention(e):
error_msg = _extract_error_message(e)
return ("fail", None, error_msg)
else:
verbose_proxy_logger.error(
f"Pipeline: unexpected error from guardrail '{step.guardrail}': {e}"
)
return ("error", None, str(e))
@staticmethod
def _find_guardrail_callback(guardrail_name: str) -> Optional[CustomGuardrail]:
    """
    Return the first CustomGuardrail in ``litellm.callbacks`` whose
    ``guardrail_name`` equals ``guardrail_name``, or None when there is none.
    """
    return next(
        (
            candidate
            for candidate in litellm.callbacks
            if isinstance(candidate, CustomGuardrail)
            and candidate.guardrail_name == guardrail_name
        ),
        None,
    )
def _extract_error_message(e: Exception) -> str:
    """
    Build a human-readable message from an exception raised by a guardrail.

    For an HTTPException (that is not a ModifyResponseException) with a truthy
    ``detail`` attribute, the detail is returned as a string. In every other
    case the result is ``str(e)``.
    """
    if not isinstance(e, ModifyResponseException):
        # HTTPException may be None here (the check guards against that).
        if HTTPException is not None and isinstance(e, HTTPException):
            http_detail = getattr(e, "detail", None)
            if http_detail:
                return str(http_detail)
    return str(e)

View File

@@ -0,0 +1,834 @@
"""
CRUD ENDPOINTS FOR POLICIES
Provides REST API endpoints for managing policies and policy attachments.
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
from litellm.proxy.policy_engine.pipeline_executor import PipelineExecutor
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
from litellm.types.proxy.policy_engine import (
GuardrailPipeline,
PipelineTestRequest,
PolicyAttachmentCreateRequest,
PolicyAttachmentDBResponse,
PolicyAttachmentListResponse,
PolicyCreateRequest,
PolicyDBResponse,
PolicyListDBResponse,
PolicyUpdateRequest,
PolicyVersionCompareResponse,
PolicyVersionCreateRequest,
PolicyVersionListResponse,
PolicyVersionStatusUpdateRequest,
)
# Router on which every policy and policy-attachment endpoint in this module
# is registered via the @router.<method> decorators below.
router = APIRouter()
# ─────────────────────────────────────────────────────────────────────────────
# Policy CRUD Endpoints
# ─────────────────────────────────────────────────────────────────────────────
@router.get(
    "/policies/list",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PolicyListDBResponse,
)
async def list_policies(version_status: Optional[str] = None):
    """
    Return the policies stored in the database, optionally narrowed to one
    version status.

    Query params:
    - version_status: Optional. One of "draft", "published", "production".
      If omitted, all versions are returned.

    Example Request:
    ```bash
    curl -X GET "http://localhost:4000/policies/list" \\
        -H "Authorization: Bearer <your_api_key>"
    curl -X GET "http://localhost:4000/policies/list?version_status=production" \\
        -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "policies": [
            {
                "policy_id": "123e4567-e89b-12d3-a456-426614174000",
                "policy_name": "global-baseline",
                "version_number": 1,
                "version_status": "production",
                "inherit": null,
                "description": "Base guardrails for all requests",
                "guardrails_add": ["pii_masking"],
                "guardrails_remove": [],
                "condition": null,
                "created_at": "2024-01-01T00:00:00Z",
                "updated_at": "2024-01-01T00:00:00Z"
            }
        ],
        "total_count": 1
    }
    ```

    Raises:
        HTTPException: 500 when the database is not connected or the lookup fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        registry = get_policy_registry()
        db_policies = await registry.get_all_policies_from_db(
            prisma_client, version_status=version_status
        )
        return PolicyListDBResponse(
            policies=db_policies, total_count=len(db_policies)
        )
    except Exception as err:
        verbose_proxy_logger.exception(f"Error listing policies: {err}")
        raise HTTPException(status_code=500, detail=str(err))
@router.post(
    "/policies",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PolicyDBResponse,
)
async def create_policy(
    request: PolicyCreateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Persist a new policy, recording the caller's user_id as its creator.

    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/policies" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "policy_name": "global-baseline",
            "description": "Base guardrails for all requests",
            "guardrails_add": ["pii_masking", "prompt_injection"],
            "guardrails_remove": []
        }'
    ```

    Example Response:
    ```json
    {
        "policy_id": "123e4567-e89b-12d3-a456-426614174000",
        "policy_name": "global-baseline",
        "inherit": null,
        "description": "Base guardrails for all requests",
        "guardrails_add": ["pii_masking", "prompt_injection"],
        "guardrails_remove": [],
        "condition": null,
        "created_at": "2024-01-01T00:00:00Z",
        "updated_at": "2024-01-01T00:00:00Z"
    }
    ```

    Raises:
        HTTPException: 400 when the failure message mentions a unique
            constraint (duplicate policy name); 500 when the database is not
            connected or for any other failure.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        return await get_policy_registry().add_policy_to_db(
            policy_request=request,
            prisma_client=prisma_client,
            created_by=user_api_key_dict.user_id,
        )
    except Exception as err:
        verbose_proxy_logger.exception(f"Error creating policy: {err}")
        # A unique-constraint violation is reported as a duplicate name.
        is_duplicate = "unique constraint" in str(err).lower()
        if not is_duplicate:
            raise HTTPException(status_code=500, detail=str(err))
        raise HTTPException(
            status_code=400,
            detail=f"Policy with name '{request.policy_name}' already exists",
        )
# ─────────────────────────────────────────────────────────────────────────────
# Policy Versioning Endpoints (must be before /policies/{policy_id} to avoid path conflicts)
# ─────────────────────────────────────────────────────────────────────────────
@router.get(
    "/policies/name/{policy_name}/versions",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PolicyVersionListResponse,
)
async def list_policy_versions(policy_name: str):
    """
    Return every version of the named policy, ordered by version_number descending.

    Raises:
        HTTPException: 500 when the database is not connected or the lookup fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        registry = get_policy_registry()
        version_list = await registry.get_versions_by_policy_name(
            policy_name=policy_name,
            prisma_client=prisma_client,
        )
        return version_list
    except Exception as err:
        verbose_proxy_logger.exception(f"Error listing policy versions: {err}")
        raise HTTPException(status_code=500, detail=str(err))
@router.post(
    "/policies/name/{policy_name}/versions",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PolicyDBResponse,
)
async def create_policy_version(
    policy_name: str,
    request: PolicyVersionCreateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Create a new draft version of a policy. Copies all fields from the source.
    Source is current production if source_policy_id is not provided.

    Args:
        policy_name: Name of the policy to create a version for (path parameter).
        request: Body carrying the optional ``source_policy_id``.
        user_api_key_dict: Authenticated caller; its ``user_id`` is recorded
            as the creator.

    Raises:
        HTTPException: 404 when the failure message contains "not found" or
            "no production"; 500 when the database is not connected or for any
            other failure. An HTTPException raised while creating the version
            is propagated unchanged.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        created_by = user_api_key_dict.user_id
        return await get_policy_registry().create_new_version(
            policy_name=policy_name,
            prisma_client=prisma_client,
            source_policy_id=request.source_policy_id,
            created_by=created_by,
        )
    except HTTPException:
        # Already carries a deliberate status/detail; re-wrapping it would
        # replace the status and prefix the detail with "<status>: ".
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error creating policy version: {e}")
        message = str(e)
        lowered = message.lower()
        # The failure kind is inferred from the error message text.
        if "not found" in lowered or "no production" in lowered:
            raise HTTPException(status_code=404, detail=message) from e
        raise HTTPException(status_code=500, detail=message) from e
@router.put(
    "/policies/{policy_id}/status",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PolicyDBResponse,
)
async def update_policy_version_status(
    policy_id: str,
    request: PolicyVersionStatusUpdateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Move a policy version to a new status. Valid transitions:
    - draft -> published
    - published -> production (demotes current production to published)
    - production -> published (demotes, policy becomes inactive)

    Raises:
        HTTPException: 400 when the failure message contains "invalid status",
            "only draft" or "cannot promote"; 404 when it contains "not found";
            500 when the database is not connected or for any other failure.
            HTTPExceptions raised during the update propagate unchanged.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        return await get_policy_registry().update_version_status(
            policy_id=policy_id,
            new_status=request.version_status,
            prisma_client=prisma_client,
            updated_by=user_api_key_dict.user_id,
        )
    except HTTPException:
        raise
    except Exception as err:
        verbose_proxy_logger.exception(f"Error updating version status: {err}")
        lowered = str(err).lower()
        # Bad-request markers are checked before the not-found marker.
        bad_request_markers = ("invalid status", "only draft", "cannot promote")
        if any(marker in lowered for marker in bad_request_markers):
            status = 400
        elif "not found" in lowered:
            status = 404
        else:
            status = 500
        raise HTTPException(status_code=status, detail=str(err))
@router.get(
    "/policies/compare",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PolicyVersionCompareResponse,
)
async def compare_policy_versions(
    version_a: str,
    version_b: str,
):
    """
    Compare two policy versions. Query params: version_a, version_b (policy version IDs).

    Args:
        version_a: Policy version ID passed to the registry as ``policy_id_a``.
        version_b: Policy version ID passed to the registry as ``policy_id_b``.

    Raises:
        HTTPException: 404 when the failure message contains "not found";
            500 when the database is not connected or for any other failure.
            An HTTPException raised during the comparison is propagated
            unchanged.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        return await get_policy_registry().compare_versions(
            policy_id_a=version_a,
            policy_id_b=version_b,
            prisma_client=prisma_client,
        )
    except HTTPException:
        # Already carries a deliberate status/detail; re-wrapping it would
        # replace the status and prefix the detail with "<status>: ".
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error comparing versions: {e}")
        message = str(e)
        # The failure kind is inferred from the error message text.
        if "not found" in message.lower():
            raise HTTPException(status_code=404, detail=message) from e
        raise HTTPException(status_code=500, detail=message) from e
@router.delete(
    "/policies/name/{policy_name}/all-versions",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_all_policy_versions(policy_name: str):
    """
    Delete every version of the named policy. Also removes it from the
    in-memory registry.

    Raises:
        HTTPException: 500 when the database is not connected or the deletion fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        registry = get_policy_registry()
        deletion_result = await registry.delete_all_versions(
            policy_name=policy_name,
            prisma_client=prisma_client,
        )
        return deletion_result
    except Exception as err:
        verbose_proxy_logger.exception(f"Error deleting all versions: {err}")
        raise HTTPException(status_code=500, detail=str(err))
# ─────────────────────────────────────────────────────────────────────────────
# Policy CRUD by ID
# ─────────────────────────────────────────────────────────────────────────────
@router.get(
    "/policies/{policy_id}",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PolicyDBResponse,
)
async def get_policy(policy_id: str):
    """
    Fetch a single policy by its ID.

    Example Request:
    ```bash
    curl -X GET "http://localhost:4000/policies/123e4567-e89b-12d3-a456-426614174000" \\
        -H "Authorization: Bearer <your_api_key>"
    ```

    Raises:
        HTTPException: 404 when no policy has this ID; 500 when the database
            is not connected or the lookup fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        policy = await get_policy_registry().get_policy_by_id_from_db(
            policy_id=policy_id,
            prisma_client=prisma_client,
        )
        if policy is not None:
            return policy
        raise HTTPException(
            status_code=404, detail=f"Policy with ID {policy_id} not found"
        )
    except HTTPException:
        # Let the 404 above through untouched.
        raise
    except Exception as err:
        verbose_proxy_logger.exception(f"Error getting policy: {err}")
        raise HTTPException(status_code=500, detail=str(err))
@router.put(
    "/policies/{policy_id}",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PolicyDBResponse,
)
async def update_policy(
    policy_id: str,
    request: PolicyUpdateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Apply changes to an existing policy version; only draft versions may be edited.

    Example Request:
    ```bash
    curl -X PUT "http://localhost:4000/policies/123e4567-e89b-12d3-a456-426614174000" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "description": "Updated description",
            "guardrails_add": ["pii_masking", "toxicity_filter"]
        }'
    ```

    Raises:
        HTTPException: 404 when no policy has this ID; 400 when the version is
            not a draft; 500 when the database is not connected or the update
            fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        registry = get_policy_registry()
        current = await registry.get_policy_by_id_from_db(
            policy_id=policy_id,
            prisma_client=prisma_client,
        )
        if current is None:
            raise HTTPException(
                status_code=404, detail=f"Policy with ID {policy_id} not found"
            )

        # A record without a version_status attribute is treated as production,
        # i.e. not editable.
        current_status = getattr(current, "version_status", "production")
        if current_status != "draft":
            raise HTTPException(
                status_code=400,
                detail="Only draft versions can be updated. Publish or create a new version to change published/production.",
            )

        return await get_policy_registry().update_policy_in_db(
            policy_id=policy_id,
            policy_request=request,
            prisma_client=prisma_client,
            updated_by=user_api_key_dict.user_id,
        )
    except HTTPException:
        raise
    except Exception as err:
        verbose_proxy_logger.exception(f"Error updating policy: {err}")
        raise HTTPException(status_code=500, detail=str(err))
@router.delete(
    "/policies/{policy_id}",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_policy(policy_id: str):
    """
    Remove a policy by ID after confirming it exists.

    Example Request:
    ```bash
    curl -X DELETE "http://localhost:4000/policies/123e4567-e89b-12d3-a456-426614174000" \\
        -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "message": "Policy 123e4567-e89b-12d3-a456-426614174000 deleted successfully"
    }
    ```

    Raises:
        HTTPException: 404 when no policy has this ID; 500 when the database
            is not connected or the deletion fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        registry = get_policy_registry()
        found = await registry.get_policy_by_id_from_db(
            policy_id=policy_id,
            prisma_client=prisma_client,
        )
        if found is None:
            raise HTTPException(
                status_code=404, detail=f"Policy with ID {policy_id} not found"
            )

        # Result may include "warning" if production was deleted
        return await get_policy_registry().delete_policy_from_db(
            policy_id=policy_id,
            prisma_client=prisma_client,
        )
    except HTTPException:
        raise
    except Exception as err:
        verbose_proxy_logger.exception(f"Error deleting policy: {err}")
        raise HTTPException(status_code=500, detail=str(err))
@router.get(
    "/policies/{policy_id}/resolved-guardrails",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_resolved_guardrails(policy_id: str):
    """
    Return the final guardrail set for a policy, inherited guardrails included.

    The policy is looked up by ID and its name is then resolved through the
    full inheritance chain, yielding the guardrails that would be applied.

    Example Request:
    ```bash
    curl -X GET "http://localhost:4000/policies/123e4567-e89b-12d3-a456-426614174000/resolved-guardrails" \\
        -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "policy_id": "123e4567-e89b-12d3-a456-426614174000",
        "policy_name": "healthcare-compliance",
        "resolved_guardrails": ["pii_masking", "prompt_injection", "toxicity_filter"]
    }
    ```

    Raises:
        HTTPException: 404 when no policy has this ID; 400 when resolution
            raises ValueError; 500 when the database is not connected or for
            any other failure.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        db_policy = await get_policy_registry().get_policy_by_id_from_db(
            policy_id=policy_id,
            prisma_client=prisma_client,
        )
        if db_policy is None:
            raise HTTPException(
                status_code=404, detail=f"Policy with ID {policy_id} not found"
            )

        # Resolution is keyed by policy name, not by ID.
        guardrail_names = await get_policy_registry().resolve_guardrails_from_db(
            policy_name=db_policy.policy_name,
            prisma_client=prisma_client,
        )
        payload = {
            "policy_id": db_policy.policy_id,
            "policy_name": db_policy.policy_name,
            "resolved_guardrails": guardrail_names,
        }
        return payload
    except HTTPException:
        raise
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as err:
        verbose_proxy_logger.exception(f"Error resolving guardrails: {err}")
        raise HTTPException(status_code=500, detail=str(err))
# ─────────────────────────────────────────────────────────────────────────────
# Pipeline Test Endpoint
# ─────────────────────────────────────────────────────────────────────────────
@router.post(
    "/policies/test-pipeline",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
)
async def test_pipeline(
    request: PipelineTestRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Dry-run a guardrail pipeline against sample messages.

    The pipeline steps are executed against the supplied test messages and the
    step-by-step results are returned, showing which guardrails passed/failed,
    actions taken, and timing information.

    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/policies/test-pipeline" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "pipeline": {
                "mode": "pre_call",
                "steps": [
                    {"guardrail": "pii-guard", "on_pass": "next", "on_fail": "block"}
                ]
            },
            "test_messages": [{"role": "user", "content": "My SSN is 123-45-6789"}]
        }'
    ```

    Raises:
        HTTPException: 400 when the pipeline definition fails validation;
            500 when executing the pipeline fails.
    """
    try:
        pipeline = GuardrailPipeline(**request.pipeline)
    except Exception as err:
        raise HTTPException(status_code=400, detail=f"Invalid pipeline: {err}")

    # Synthetic request payload: the supplied messages with a placeholder
    # model name and empty metadata.
    sample_request = dict(
        messages=request.test_messages,
        model="test",
        metadata={},
    )

    try:
        outcome = await PipelineExecutor.execute_steps(
            steps=pipeline.steps,
            mode=pipeline.mode,
            data=sample_request,
            user_api_key_dict=user_api_key_dict,
            call_type="completion",
            policy_name="test-pipeline",
        )
        return outcome.model_dump()
    except Exception as err:
        verbose_proxy_logger.exception(f"Error testing pipeline: {err}")
        raise HTTPException(status_code=500, detail=str(err))
# ─────────────────────────────────────────────────────────────────────────────
# Policy Attachment CRUD Endpoints
# ─────────────────────────────────────────────────────────────────────────────
@router.get(
    "/policies/attachments/list",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PolicyAttachmentListResponse,
)
async def list_policy_attachments():
    """
    Return every policy attachment stored in the database.

    Example Request:
    ```bash
    curl -X GET "http://localhost:4000/policies/attachments/list" \\
        -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "attachments": [
            {
                "attachment_id": "123e4567-e89b-12d3-a456-426614174000",
                "policy_name": "global-baseline",
                "scope": "*",
                "teams": [],
                "keys": [],
                "models": [],
                "created_at": "2024-01-01T00:00:00Z",
                "updated_at": "2024-01-01T00:00:00Z"
            }
        ],
        "total_count": 1
    }
    ```

    Raises:
        HTTPException: 500 when the database is not connected or the lookup fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        registry = get_attachment_registry()
        db_attachments = await registry.get_all_attachments_from_db(prisma_client)
        return PolicyAttachmentListResponse(
            attachments=db_attachments, total_count=len(db_attachments)
        )
    except Exception as err:
        verbose_proxy_logger.exception(f"Error listing policy attachments: {err}")
        raise HTTPException(status_code=500, detail=str(err))
@router.post(
    "/policies/attachments",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PolicyAttachmentDBResponse,
)
async def create_policy_attachment(
    request: PolicyAttachmentCreateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Attach a policy to a scope. The policy must have a production version.

    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/policies/attachments" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "policy_name": "global-baseline",
            "scope": "*"
        }'
    ```

    Example with team-specific attachment:
    ```bash
    curl -X POST "http://localhost:4000/policies/attachments" \\
        -H "Authorization: Bearer <your_api_key>" \\
        -H "Content-Type: application/json" \\
        -d '{
            "policy_name": "healthcare-compliance",
            "teams": ["healthcare-team", "medical-research"]
        }'
    ```

    Example Response:
    ```json
    {
        "attachment_id": "123e4567-e89b-12d3-a456-426614174000",
        "policy_name": "global-baseline",
        "scope": "*",
        "teams": [],
        "keys": [],
        "models": [],
        "created_at": "2024-01-01T00:00:00Z",
        "updated_at": "2024-01-01T00:00:00Z"
    }
    ```

    Raises:
        HTTPException: 404 when the named policy has no production version;
            500 when the database is not connected or the insert fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        # Verify the policy has a production version (attachments resolve against production)
        production_policies = await get_policy_registry().get_all_policies_from_db(
            prisma_client, version_status="production"
        )
        production_names = frozenset(
            policy.policy_name for policy in production_policies
        )
        if request.policy_name not in production_names:
            raise HTTPException(
                status_code=404,
                detail=f"Policy '{request.policy_name}' not found. Create the policy first.",
            )

        return await get_attachment_registry().add_attachment_to_db(
            attachment_request=request,
            prisma_client=prisma_client,
            created_by=user_api_key_dict.user_id,
        )
    except HTTPException:
        raise
    except Exception as err:
        verbose_proxy_logger.exception(f"Error creating policy attachment: {err}")
        raise HTTPException(status_code=500, detail=str(err))
@router.get(
    "/policies/attachments/{attachment_id}",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PolicyAttachmentDBResponse,
)
async def get_policy_attachment(attachment_id: str):
    """
    Fetch a single policy attachment by its ID.

    Example Request:
    ```bash
    curl -X GET "http://localhost:4000/policies/attachments/123e4567-e89b-12d3-a456-426614174000" \\
        -H "Authorization: Bearer <your_api_key>"
    ```

    Raises:
        HTTPException: 404 when no attachment has this ID; 500 when the
            database is not connected or the lookup fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        attachment = await get_attachment_registry().get_attachment_by_id_from_db(
            attachment_id=attachment_id,
            prisma_client=prisma_client,
        )
        if attachment is not None:
            return attachment
        raise HTTPException(
            status_code=404,
            detail=f"Attachment with ID {attachment_id} not found",
        )
    except HTTPException:
        # Let the 404 above through untouched.
        raise
    except Exception as err:
        verbose_proxy_logger.exception(f"Error getting policy attachment: {err}")
        raise HTTPException(status_code=500, detail=str(err))
@router.delete(
    "/policies/attachments/{attachment_id}",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_policy_attachment(attachment_id: str):
    """
    Remove a policy attachment by ID after confirming it exists.

    Example Request:
    ```bash
    curl -X DELETE "http://localhost:4000/policies/attachments/123e4567-e89b-12d3-a456-426614174000" \\
        -H "Authorization: Bearer <your_api_key>"
    ```

    Example Response:
    ```json
    {
        "message": "Attachment 123e4567-e89b-12d3-a456-426614174000 deleted successfully"
    }
    ```

    Raises:
        HTTPException: 404 when no attachment has this ID; 500 when the
            database is not connected or the deletion fails.
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        registry = get_attachment_registry()
        found = await registry.get_attachment_by_id_from_db(
            attachment_id=attachment_id,
            prisma_client=prisma_client,
        )
        if found is None:
            raise HTTPException(
                status_code=404,
                detail=f"Attachment with ID {attachment_id} not found",
            )

        return await get_attachment_registry().delete_attachment_from_db(
            attachment_id=attachment_id,
            prisma_client=prisma_client,
        )
    except HTTPException:
        raise
    except Exception as err:
        verbose_proxy_logger.exception(f"Error deleting policy attachment: {err}")
        raise HTTPException(status_code=500, detail=str(err))

View File

@@ -0,0 +1,180 @@
"""
Policy Matcher - Matches requests against policy attachments.
Uses existing wildcard pattern matching helpers to determine which policies
apply to a given request based on team alias, key alias, and model.
Policies are matched via policy_attachments which define WHERE each policy applies.
"""
from typing import Dict, List, Optional
from litellm._logging import verbose_proxy_logger
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.types.proxy.policy_engine import Policy, PolicyMatchContext, PolicyScope
class PolicyMatcher:
    """
    Decides which policies apply to an incoming request.

    Matching is driven by policy attachments and supports wildcard patterns:
    - "*" matches everything
    - "prefix-*" matches anything starting with "prefix-"
    """

    @staticmethod
    def matches_pattern(value: Optional[str], patterns: List[str]) -> bool:
        """
        Tell whether ``value`` is matched by at least one entry of ``patterns``.

        Pattern matching is delegated to
        RouteChecks._route_matches_wildcard_pattern.

        Args:
            value: The value to check (e.g., team alias, key alias, model)
            patterns: List of patterns to match against

        Returns:
            True if value matches any pattern, False otherwise. A missing
            value (None) matches only when "*" is one of the patterns.
        """
        if value is None:
            return "*" in patterns
        return any(
            RouteChecks._route_matches_wildcard_pattern(route=value, pattern=candidate)
            for candidate in patterns
        )

    @staticmethod
    def scope_matches(scope: PolicyScope, context: PolicyMatchContext) -> bool:
        """
        Tell whether a policy scope applies to the given request context.

        Every one of these must hold:
        - scope teams match context.team_alias
        - scope keys match context.key_alias
        - scope models match context.model
        - if the scope lists tags, at least one context tag matches one of them

        Args:
            scope: The policy scope to check
            context: The request context

        Returns:
            True if scope matches context, False otherwise
        """
        # Teams, keys and models are checked in that order; evaluation stops at
        # the first dimension that fails.
        core_match = (
            PolicyMatcher.matches_pattern(context.team_alias, scope.get_teams())
            and PolicyMatcher.matches_pattern(context.key_alias, scope.get_keys())
            and PolicyMatcher.matches_pattern(context.model, scope.get_models())
        )
        if not core_match:
            return False

        # Unlike teams/keys/models, empty tags means "do not check" rather than "match all"
        tag_patterns = scope.get_tags()
        if not tag_patterns:
            return True
        if not context.tags:
            return False
        # Match if ANY context tag matches ANY scope tag pattern
        return any(
            PolicyMatcher.matches_pattern(tag, tag_patterns) for tag in context.tags
        )

    @staticmethod
    def get_matching_policies(
        context: PolicyMatchContext,
    ) -> List[str]:
        """
        Collect the names of policies attached to the given context.

        Args:
            context: The request context to match against

        Returns:
            List of policy names that match the context; empty when the
            attachment registry has not been initialized.
        """
        from litellm.proxy.policy_engine.attachment_registry import (
            get_attachment_registry,
        )

        attachments = get_attachment_registry()
        if attachments.is_initialized():
            return attachments.get_attached_policies(context)

        verbose_proxy_logger.debug(
            "AttachmentRegistry not initialized, returning empty list"
        )
        return []

    @staticmethod
    def get_matching_policies_from_registry(
        context: PolicyMatchContext,
    ) -> List[str]:
        """
        Alias of get_matching_policies that reads from the global registry.

        Args:
            context: The request context to match against

        Returns:
            List of policy names that match the context
        """
        return PolicyMatcher.get_matching_policies(context=context)

    @staticmethod
    def get_policies_with_matching_conditions(
        policy_names: List[str],
        context: PolicyMatchContext,
        policies: Optional[Dict[str, Policy]] = None,
    ) -> List[str]:
        """
        Keep only the policies whose condition holds for the context.

        A policy is kept when:
        - it has no condition (condition is None), OR
        - its condition evaluates to True for the given context

        Names that are not present in ``policies`` are dropped.

        Args:
            policy_names: List of policy names to filter
            context: The request context to evaluate conditions against
            policies: Dictionary of all policies (if None, uses global registry)

        Returns:
            List of policy names whose conditions match the context, in input
            order; empty when the global registry is needed but uninitialized.
        """
        from litellm.proxy.policy_engine.condition_evaluator import ConditionEvaluator
        from litellm.proxy.policy_engine.policy_registry import get_policy_registry

        if policies is None:
            global_registry = get_policy_registry()
            if not global_registry.is_initialized():
                return []
            policies = global_registry.get_all_policies()

        def _condition_holds(name: str) -> bool:
            candidate = policies.get(name)
            if candidate is None:
                return False
            return candidate.condition is None or ConditionEvaluator.evaluate(
                candidate.condition, context
            )

        return [name for name in policy_names if _condition_holds(name)]

View File

@@ -0,0 +1,978 @@
"""
Policy Registry - In-memory storage for policies.
Handles storing, retrieving, and managing policies.
Policies define WHAT guardrails to apply. WHERE they apply is defined
by policy_attachments (see AttachmentRegistry).
"""
import json
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from litellm._logging import verbose_proxy_logger
from litellm.types.proxy.policy_engine import (
GuardrailPipeline,
PipelineStep,
Policy,
PolicyCondition,
PolicyCreateRequest,
PolicyDBResponse,
PolicyGuardrails,
PolicyUpdateRequest,
PolicyVersionCompareResponse,
PolicyVersionListResponse,
)
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
# Prefix that marks a policy version ID in a request body: a value of the form
# policy_<uuid> asks for that specific policy version to be executed.
POLICY_VERSION_ID_PREFIX = "policy_"
def _row_to_policy_db_response(row: Any) -> PolicyDBResponse:
    """
    Convert a Prisma LiteLLM_PolicyTable row into a PolicyDBResponse.

    Versioning attributes are read with fallbacks, so a row lacking them is
    reported as version 1, status "production", latest, with no parent and no
    publish/production timestamps. Empty guardrail lists replace falsy values.
    """
    # Attributes that may be absent on the row, paired with their fallback.
    versioning_defaults = {
        "version_number": 1,
        "version_status": "production",
        "parent_version_id": None,
        "is_latest": True,
        "published_at": None,
        "production_at": None,
    }
    versioning = {
        field: getattr(row, field, fallback)
        for field, fallback in versioning_defaults.items()
    }

    return PolicyDBResponse(
        policy_id=row.policy_id,
        policy_name=row.policy_name,
        inherit=row.inherit,
        description=row.description,
        guardrails_add=row.guardrails_add or [],
        guardrails_remove=row.guardrails_remove or [],
        condition=row.condition,
        pipeline=row.pipeline,
        created_at=row.created_at,
        updated_at=row.updated_at,
        created_by=row.created_by,
        updated_by=row.updated_by,
        **versioning,
    )
class PolicyRegistry:
"""
In-memory registry for storing and managing policies.
This is a singleton that holds all loaded policies and provides
methods to access them.
Policies define WHAT guardrails to apply:
- Base guardrails via guardrails.add/remove
- Inheritance via inherit field
- Conditional guardrails via condition.model
"""
def __init__(self):
self._policies: Dict[str, Policy] = {}
self._policies_by_id: Dict[str, Tuple[str, Policy]] = {}
self._initialized: bool = False
def load_policies(self, policies_config: Dict[str, Any]) -> None:
    """
    Replace the registry contents with policies parsed from a config mapping.

    Both in-memory tables are reset before parsing. The registry is marked
    initialized only after every entry has parsed successfully.

    Args:
        policies_config: Dictionary mapping policy names to policy definitions.
            This is the raw config from the YAML file.

    Raises:
        ValueError: If any policy definition fails to parse (chained to the
            underlying error).
    """
    self._policies, self._policies_by_id = {}, {}

    for policy_name, policy_data in policies_config.items():
        try:
            self._policies[policy_name] = self._parse_policy(policy_name, policy_data)
            verbose_proxy_logger.debug(f"Loaded policy: {policy_name}")
        except Exception as e:
            verbose_proxy_logger.error(
                f"Error loading policy '{policy_name}': {str(e)}"
            )
            raise ValueError(f"Invalid policy '{policy_name}': {str(e)}") from e

    self._initialized = True
    verbose_proxy_logger.info(f"Loaded {len(self._policies)} policies")
def _parse_policy(self, policy_name: str, policy_data: Dict[str, Any]) -> Policy:
    """
    Build a Policy object from one raw configuration entry.

    Args:
        policy_name: Name of the policy (not used while parsing)
        policy_data: Raw policy configuration

    Returns:
        Parsed Policy object
    """
    raw_guardrails = policy_data.get("guardrails", {})
    if not isinstance(raw_guardrails, dict):
        # Handle legacy format where guardrails might be a list
        guardrails = PolicyGuardrails(add=raw_guardrails or None)
    else:
        guardrails = PolicyGuardrails(
            add=raw_guardrails.get("add"),
            remove=raw_guardrails.get("remove"),
        )

    # Only a simple model-based condition is supported here.
    raw_condition = policy_data.get("condition")
    condition = (
        PolicyCondition(model=raw_condition.get("model")) if raw_condition else None
    )

    return Policy(
        inherit=policy_data.get("inherit"),
        description=policy_data.get("description"),
        guardrails=guardrails,
        condition=condition,
        # Optional ordered guardrail execution.
        pipeline=PolicyRegistry._parse_pipeline(policy_data.get("pipeline")),
    )
@staticmethod
def _parse_pipeline(
pipeline_data: Optional[Dict[str, Any]],
) -> Optional[GuardrailPipeline]:
"""Parse a pipeline configuration from raw data."""
if pipeline_data is None:
return None
steps_data = pipeline_data.get("steps", [])
steps = [
PipelineStep(**step_data) if isinstance(step_data, dict) else step_data
for step_data in steps_data
]
return GuardrailPipeline(
mode=pipeline_data.get("mode", "pre_call"),
steps=steps,
)
def get_policy(self, policy_name: str) -> Optional[Policy]:
    """
    Look up a single policy by its name.

    Args:
        policy_name: Name of the policy to retrieve

    Returns:
        The stored Policy, or None when no policy has that name
    """
    by_name = self._policies
    return by_name.get(policy_name, None)
def get_all_policies(self) -> Dict[str, Policy]:
    """
    Return a snapshot of every policy held in the name-keyed registry.

    Returns:
        A new (shallow-copied) dictionary mapping policy names to Policy
        objects; adding or removing keys on it does not affect the registry
    """
    return dict(self._policies)
def get_policy_names(self) -> List[str]:
    """
    List the names of all policies in the name-keyed registry.

    Returns:
        A new list of policy names, in the registry's insertion order
    """
    # Iterating a dict yields its keys.
    return list(self._policies)
def has_policy(self, policy_name: str) -> bool:
    """
    Tell whether a policy with the given name is registered.

    Args:
        policy_name: Name of the policy to check

    Returns:
        True if the name is a key of the registry, False otherwise
    """
    registered = self._policies
    return policy_name in registered
def is_initialized(self) -> bool:
    """
    Report the registry's initialization flag.

    Returns:
        The current value of the flag that the load/sync methods set to True
        and clear() sets back to False
    """
    initialized = self._initialized
    return initialized
def clear(self) -> None:
    """
    Clear all policies from the registry.

    Resets the name-keyed policy map and the ID-keyed version cache, and
    marks the registry as uninitialized.
    """
    self._policies = {}
    # The by-ID cache is read by get_policy_by_id_for_request; leaving it
    # populated would keep serving versions from a registry that was cleared.
    self._policies_by_id = {}
    self._initialized = False
def add_policy(self, policy_name: str, policy: Policy) -> None:
    """
    Register a policy under the given name, replacing any existing entry.

    Args:
        policy_name: Name of the policy
        policy: Policy object to store
    """
    registry = self._policies
    registry[policy_name] = policy
    verbose_proxy_logger.debug(f"Added/updated policy: {policy_name}")
def remove_policy(self, policy_name: str) -> bool:
    """
    Drop a policy from the name-keyed registry.

    Args:
        policy_name: Name of the policy to remove

    Returns:
        True if policy was removed, False if it didn't exist
    """
    # Guard clause: nothing to do (and nothing logged) for unknown names.
    if policy_name not in self._policies:
        return False

    del self._policies[policy_name]
    verbose_proxy_logger.debug(f"Removed policy: {policy_name}")
    return True
# ─────────────────────────────────────────────────────────────────────────
# Database CRUD Methods
# ─────────────────────────────────────────────────────────────────────────
async def add_policy_to_db(
    self,
    policy_request: PolicyCreateRequest,
    prisma_client: "PrismaClient",
    created_by: Optional[str] = None,
) -> PolicyDBResponse:
    """
    Persist a new policy and mirror it into the in-memory registry.

    The row is written as version 1 with status "production" and
    is_latest=True.

    Args:
        policy_request: The policy creation request
        prisma_client: The Prisma client instance
        created_by: User who created the policy

    Returns:
        PolicyDBResponse with the created policy

    Raises:
        Exception: Any failure is logged and re-raised with the message
            prefix "Error adding policy to DB: ".
    """
    try:
        timestamp = datetime.now(timezone.utc)
        name = policy_request.policy_name
        condition = policy_request.condition

        # Mandatory columns of the new v1 production row.
        row_data: Dict[str, Any] = {
            "policy_name": name,
            "version_number": 1,
            "version_status": "production",
            "is_latest": True,
            "production_at": timestamp,
            "guardrails_add": policy_request.guardrails_add or [],
            "guardrails_remove": policy_request.guardrails_remove or [],
            "created_at": timestamp,
            "updated_at": timestamp,
        }

        # Optional scalar columns are sent only when a value was supplied.
        optional_columns = {
            "inherit": policy_request.inherit,
            "description": policy_request.description,
            "created_by": created_by,
            "updated_by": created_by,
        }
        row_data.update(
            {column: value for column, value in optional_columns.items() if value is not None}
        )

        # Json columns are passed to Prisma as JSON strings.
        if condition is not None:
            row_data["condition"] = json.dumps(condition.model_dump())
        if policy_request.pipeline is not None:
            # Constructing the model validates the pipeline before it is stored.
            pipeline_model = GuardrailPipeline(**policy_request.pipeline)
            row_data["pipeline"] = json.dumps(pipeline_model.model_dump())

        created_row = await prisma_client.db.litellm_policytable.create(
            data=row_data
        )

        # Mirror the new policy into the in-memory registry.
        raw_policy = {
            "inherit": policy_request.inherit,
            "description": policy_request.description,
            "guardrails": {
                "add": policy_request.guardrails_add,
                "remove": policy_request.guardrails_remove,
            },
            "condition": condition.model_dump() if condition else None,
            "pipeline": policy_request.pipeline,
        }
        self.add_policy(name, self._parse_policy(name, raw_policy))

        return _row_to_policy_db_response(created_row)
    except Exception as e:
        verbose_proxy_logger.exception(f"Error adding policy to DB: {e}")
        raise Exception(f"Error adding policy to DB: {str(e)}")
async def update_policy_in_db(
    self,
    policy_id: str,
    policy_request: PolicyUpdateRequest,
    prisma_client: "PrismaClient",
    updated_by: Optional[str] = None,
) -> PolicyDBResponse:
    """
    Apply a partial update to a draft policy version in the database.

    Only request fields that are not None are written; "updated_at" and
    "updated_by" are always written. The in-memory registry is not touched.

    Args:
        policy_id: The ID of the policy to update
        policy_request: The policy update request
        prisma_client: The Prisma client instance
        updated_by: User who updated the policy

    Returns:
        PolicyDBResponse with the updated policy

    Raises:
        Exception: If the policy does not exist, is not in draft status, or
            the DB call fails. Every failure is logged and re-raised with
            the message prefix "Error updating policy in DB: ".
    """
    try:
        table = prisma_client.db.litellm_policytable

        current_row = await table.find_unique(where={"policy_id": policy_id})
        if current_row is None:
            raise Exception(f"Policy with ID {policy_id} not found")

        # Rows without a version_status attribute are treated as production.
        version_status = getattr(current_row, "version_status", "production")
        if version_status != "draft":
            raise Exception(
                f"Only draft versions can be updated. This policy has status '{version_status}'."
            )

        changes: Dict[str, Any] = {
            "updated_at": datetime.now(timezone.utc),
            "updated_by": updated_by,
        }

        # Plain columns: copy over whatever the request actually set.
        for column in (
            "policy_name",
            "inherit",
            "description",
            "guardrails_add",
            "guardrails_remove",
        ):
            value = getattr(policy_request, column)
            if value is not None:
                changes[column] = value

        # Json columns are passed to Prisma as JSON strings.
        if policy_request.condition is not None:
            changes["condition"] = json.dumps(policy_request.condition.model_dump())
        if policy_request.pipeline is not None:
            pipeline_model = GuardrailPipeline(**policy_request.pipeline)
            changes["pipeline"] = json.dumps(pipeline_model.model_dump())

        updated_row = await table.update(
            where={"policy_id": policy_id},
            data=changes,
        )

        # Do NOT update the name-keyed in-memory registry: it only holds
        # production versions, and this row is a draft.
        return _row_to_policy_db_response(updated_row)
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating policy in DB: {e}")
        raise Exception(f"Error updating policy in DB: {str(e)}")
async def delete_policy_from_db(
    self,
    policy_id: str,
    prisma_client: "PrismaClient",
) -> Dict[str, Any]:
    """
    Delete a policy version from the database.

    The deleted version is always evicted from the ID-keyed in-memory cache.
    If it was the production version, it is also removed from the name-keyed
    registry. No other version is auto-promoted; admin must explicitly promote.

    Args:
        policy_id: The ID of the policy version to delete
        prisma_client: The Prisma client instance

    Returns:
        Dict with "message" and optional "warning" if production was deleted.

    Raises:
        Exception: If the policy does not exist or the DB call fails. Every
            failure is logged and re-raised with the message prefix
            "Error deleting policy from DB: ".
    """
    try:
        policy = await prisma_client.db.litellm_policytable.find_unique(
            where={"policy_id": policy_id}
        )
        if policy is None:
            raise Exception(f"Policy with ID {policy_id} not found")

        # Rows without a version_status attribute are treated as production.
        version_status = getattr(policy, "version_status", "production")
        policy_name = policy.policy_name

        # Delete from DB
        await prisma_client.db.litellm_policytable.delete(
            where={"policy_id": policy_id}
        )

        # Evict the version from the by-ID cache so that
        # get_policy_by_id_for_request stops returning a deleted version
        # before the next sync_policies_from_db run.
        self._policies_by_id.pop(policy_id, None)

        result: Dict[str, Any] = {
            "message": f"Policy {policy_id} deleted successfully"
        }

        # Remove from the name-keyed registry only if this was the production version
        if version_status == "production":
            self.remove_policy(policy_name)
            result["warning"] = (
                "Production version was deleted. No other version was promoted. "
                "Promote another version to production if this policy should remain active."
            )
        return result
    except Exception as e:
        verbose_proxy_logger.exception(f"Error deleting policy from DB: {e}")
        raise Exception(f"Error deleting policy from DB: {str(e)}")
async def get_policy_by_id_from_db(
    self,
    policy_id: str,
    prisma_client: "PrismaClient",
) -> Optional[PolicyDBResponse]:
    """
    Fetch one policy version from the database by its ID.

    Args:
        policy_id: The ID of the policy to retrieve
        prisma_client: The Prisma client instance

    Returns:
        PolicyDBResponse if found, None otherwise

    Raises:
        Exception: On any failure, logged and re-raised with the message
            prefix "Error getting policy from DB: ".
    """
    try:
        row = await prisma_client.db.litellm_policytable.find_unique(
            where={"policy_id": policy_id}
        )
        return None if row is None else _row_to_policy_db_response(row)
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting policy from DB: {e}")
        raise Exception(f"Error getting policy from DB: {str(e)}")
def get_policy_by_id_for_request(
    self, policy_id: str
) -> Optional[Tuple[str, Policy]]:
    """
    Look up a policy version by ID in the in-memory cache (no DB access).

    Intended for requests whose body names policy_<uuid> to run one specific
    version. The cache is filled by sync_policies_from_db with the draft and
    published versions, keyed by policy_id.

    Args:
        policy_id: The policy version ID (raw UUID, no prefix)

    Returns:
        (policy_name, Policy) if found, None otherwise
    """
    versions_by_id = self._policies_by_id
    return versions_by_id.get(policy_id, None)
async def get_all_policies_from_db(
    self,
    prisma_client: "PrismaClient",
    version_status: Optional[str] = None,
) -> List[PolicyDBResponse]:
    """
    List policy rows from the database, newest first (by created_at).

    Args:
        prisma_client: The Prisma client instance
        version_status: If set, only return policies with this status
            ("draft", "published", "production").

    Returns:
        List of PolicyDBResponse objects

    Raises:
        Exception: On any failure, logged and re-raised with the message
            prefix "Error getting policies from DB: ".
    """
    try:
        # No filter is expressed as where=None rather than an empty dict.
        filters: Optional[Dict[str, Any]] = (
            None if version_status is None else {"version_status": version_status}
        )
        rows = await prisma_client.db.litellm_policytable.find_many(
            where=filters,
            order={"created_at": "desc"},
        )
        return [_row_to_policy_db_response(row) for row in rows]
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting policies from DB: {e}")
        raise Exception(f"Error getting policies from DB: {str(e)}")
async def sync_policies_from_db(
    self,
    prisma_client: "PrismaClient",
) -> None:
    """
    Sync policies from the database to in-memory registry.

    - Production versions are loaded into _policies (by policy name) for resolution.
    - Draft and published versions are loaded into _policies_by_id so request-body
      policy_<uuid> overrides can be resolved without DB access in the hot path.

    Both maps are built off to the side and swapped in only after every row
    has been fetched and parsed. Callers running while the sync awaits the
    database therefore keep seeing the previous, complete registry, and a
    failed sync leaves the previous contents in place.

    Args:
        prisma_client: The Prisma client instance

    Raises:
        Exception: On any failure, logged and re-raised with the message
            prefix "Error syncing policies from DB: ".
    """
    try:
        production = await self.get_all_policies_from_db(
            prisma_client, version_status="production"
        )
        fresh_policies: Dict[str, Policy] = {}
        for policy_response in production:
            fresh_policies[policy_response.policy_name] = self._parse_policy(
                policy_response.policy_name,
                {
                    "inherit": policy_response.inherit,
                    "description": policy_response.description,
                    "guardrails": {
                        "add": policy_response.guardrails_add,
                        "remove": policy_response.guardrails_remove,
                    },
                    "condition": policy_response.condition,
                    "pipeline": policy_response.pipeline,
                },
            )
            verbose_proxy_logger.debug(
                f"Added/updated policy: {policy_response.policy_name}"
            )

        non_production = await prisma_client.db.litellm_policytable.find_many(
            where={"version_status": {"in": ["draft", "published"]}},
            order={"created_at": "desc"},
        )
        fresh_by_id: Dict[str, Any] = {}
        for row in non_production:
            policy = self._parse_policy(
                row.policy_name,
                {
                    "inherit": row.inherit,
                    "description": row.description,
                    "guardrails": {
                        "add": row.guardrails_add or [],
                        "remove": row.guardrails_remove or [],
                    },
                    "condition": row.condition,
                    "pipeline": row.pipeline,
                },
            )
            fresh_by_id[row.policy_id] = (row.policy_name, policy)

        # Swap in the new state only now; there is no await between these
        # assignments, so other coroutines never observe a half-built registry.
        self._policies = fresh_policies
        self._policies_by_id = fresh_by_id
        self._initialized = True

        verbose_proxy_logger.info(
            f"Synced {len(production)} production policies and {len(non_production)} "
            "draft/published (by ID) from DB to in-memory registry"
        )
    except Exception as e:
        verbose_proxy_logger.exception(f"Error syncing policies from DB: {e}")
        raise Exception(f"Error syncing policies from DB: {str(e)}")
async def resolve_guardrails_from_db(
    self,
    policy_name: str,
    prisma_client: "PrismaClient",
) -> List[str]:
    """
    Compute the effective guardrails of a policy straight from the database.

    All production versions are loaded into a throwaway name -> Policy map
    and PolicyResolver resolves the requested policy against that map, so
    inheritance is resolved against production versions only. The live
    in-memory registry is neither read nor modified.

    Args:
        policy_name: Name of the policy to resolve
        prisma_client: The Prisma client instance

    Returns:
        Sorted list of resolved guardrail names

    Raises:
        Exception: On any failure, logged and re-raised with the message
            prefix "Error resolving guardrails from DB: ".
    """
    from litellm.proxy.policy_engine.policy_resolver import PolicyResolver

    try:
        production_rows = await self.get_all_policies_from_db(
            prisma_client, version_status="production"
        )

        # Temporary in-memory map used only for this resolution.
        production_policies = {
            row.policy_name: self._parse_policy(
                row.policy_name,
                {
                    "inherit": row.inherit,
                    "description": row.description,
                    "guardrails": {
                        "add": row.guardrails_add,
                        "remove": row.guardrails_remove,
                    },
                    "condition": row.condition,
                    "pipeline": row.pipeline,
                },
            )
            for row in production_rows
        }

        resolution = PolicyResolver.resolve_policy_guardrails(
            policy_name=policy_name,
            policies=production_policies,
            context=None,  # No context needed for simple resolution
        )
        return sorted(resolution.guardrails)
    except Exception as e:
        verbose_proxy_logger.exception(f"Error resolving guardrails from DB: {e}")
        raise Exception(f"Error resolving guardrails from DB: {str(e)}")
async def get_versions_by_policy_name(
    self,
    policy_name: str,
    prisma_client: "PrismaClient",
) -> PolicyVersionListResponse:
    """
    List every stored version of a policy, highest version_number first.

    Args:
        policy_name: Name of the policy
        prisma_client: The Prisma client instance

    Returns:
        PolicyVersionListResponse with policy_name, the list of versions and
        total_count set to the number of versions returned

    Raises:
        Exception: On any failure, logged and re-raised with the message
            prefix "Error getting versions: ".
    """
    try:
        version_rows = await prisma_client.db.litellm_policytable.find_many(
            where={"policy_name": policy_name},
            order={"version_number": "desc"},
        )
        versions = []
        for version_row in version_rows:
            versions.append(_row_to_policy_db_response(version_row))
        return PolicyVersionListResponse(
            policy_name=policy_name,
            versions=versions,
            total_count=len(versions),
        )
    except Exception as e:
        verbose_proxy_logger.exception(f"Error getting versions: {e}")
        raise Exception(f"Error getting versions: {str(e)}")
async def create_new_version(
    self,
    policy_name: str,
    prisma_client: "PrismaClient",
    source_policy_id: Optional[str] = None,
    created_by: Optional[str] = None,
) -> PolicyDBResponse:
    """
    Create a new draft version of a policy by cloning an existing version.

    The clone copies inherit, description, guardrails, condition and pipeline
    from the source, gets the next version_number for the policy name, and
    becomes the only row of that name flagged is_latest.

    Args:
        policy_name: Name of the policy
        prisma_client: The Prisma client instance
        source_policy_id: Policy ID to clone from; if None, use current production
        created_by: User who created the version

    Returns:
        PolicyDBResponse for the new draft version

    Raises:
        Exception: If the source is missing, belongs to another policy name,
            no production version exists when one is needed, or the DB call
            fails. Every failure is logged and re-raised with the message
            prefix "Error creating new version: ".
    """
    try:
        table = prisma_client.db.litellm_policytable

        # Pick the row to clone from.
        if source_policy_id is None:
            # Default source: the current production version of this name.
            source = await table.find_first(
                where={
                    "policy_name": policy_name,
                    "version_status": "production",
                }
            )
            if source is None:
                raise Exception(
                    f"No production version found for policy '{policy_name}'"
                )
        else:
            source = await table.find_unique(where={"policy_id": source_policy_id})
            if source is None:
                raise Exception(f"Source policy {source_policy_id} not found")
            if source.policy_name != policy_name:
                raise Exception(
                    f"Source policy name '{source.policy_name}' does not match '{policy_name}'"
                )

        # Next version number: one past the highest existing one, or 1.
        newest = await table.find_first(
            where={"policy_name": policy_name},
            order={"version_number": "desc"},
        )
        next_version = (newest.version_number + 1) if newest else 1
        timestamp = datetime.now(timezone.utc)

        # The new draft becomes the latest, so unflag every existing version.
        await table.update_many(
            where={"policy_name": policy_name},
            data={"is_latest": False},
        )

        row_data: Dict[str, Any] = {
            "policy_name": policy_name,
            "version_number": next_version,
            "version_status": "draft",
            "parent_version_id": source.policy_id,
            "is_latest": True,
            "published_at": None,
            "production_at": None,
            "inherit": source.inherit,
            "description": source.description,
            "guardrails_add": source.guardrails_add or [],
            "guardrails_remove": source.guardrails_remove or [],
            "created_at": timestamp,
            "updated_at": timestamp,
            "created_by": created_by,
            "updated_by": created_by,
        }

        # Prisma expects Json fields as JSON strings on create (same as
        # add_policy_to_db): dict values are serialized, anything else is
        # forwarded unchanged.
        for json_column in ("condition", "pipeline"):
            source_value = getattr(source, json_column)
            if source_value is None:
                continue
            if isinstance(source_value, dict):
                row_data[json_column] = json.dumps(source_value)
            else:
                row_data[json_column] = source_value

        created_row = await table.create(data=row_data)
        return _row_to_policy_db_response(created_row)
    except Exception as e:
        verbose_proxy_logger.exception(f"Error creating new version: {e}")
        raise Exception(f"Error creating new version: {str(e)}")
async def update_version_status(
    self,
    policy_id: str,
    new_status: str,
    prisma_client: "PrismaClient",
    updated_by: Optional[str] = None,
) -> PolicyDBResponse:
    """
    Update a policy version's status.

    Transitions accepted by this method:
    - draft -> published (sets published_at)
    - published -> production (sets production_at, demotes the current
      production row(s) of the same policy name to published, and replaces
      the name-keyed in-memory entry with this version)

    Transitions rejected with an Exception:
    - draft -> production (must publish first)
    - production -> published, published -> published (only drafts can be
      published)
    - production -> production
    - anything -> draft (new_status must be "published" or "production")

    Args:
        policy_id: The policy version ID
        new_status: "published" or "production"
        prisma_client: The Prisma client instance
        updated_by: User who updated

    Returns:
        PolicyDBResponse for the updated version

    Raises:
        Exception: For an invalid status, unknown policy_id, rejected
            transition, or DB failure. Every failure is logged and re-raised
            with the message prefix "Error updating version status: ".
    """
    try:
        if new_status not in ("published", "production"):
            raise Exception(
                f"Invalid status '{new_status}'. Use 'published' or 'production'."
            )

        row = await prisma_client.db.litellm_policytable.find_unique(
            where={"policy_id": policy_id}
        )
        if row is None:
            raise Exception(f"Policy with ID {policy_id} not found")

        # Rows without a version_status attribute are treated as production.
        current = getattr(row, "version_status", "production")
        policy_name = row.policy_name
        now = datetime.now(timezone.utc)

        if new_status == "published":
            if current != "draft":
                raise Exception(
                    f"Only draft versions can be published. Current status: '{current}'."
                )
            updated = await prisma_client.db.litellm_policytable.update(
                where={"policy_id": policy_id},
                data={
                    "version_status": "published",
                    "published_at": now,
                    "updated_at": now,
                    "updated_by": updated_by,
                },
            )
            return _row_to_policy_db_response(updated)

        # new_status == "production"
        if current not in ("draft", "published"):
            raise Exception(
                f"Only draft or published versions can be promoted to production. Current: '{current}'."
            )
        # Drafts get their own, more specific error: they must be published first.
        if current == "draft":
            raise Exception(
                "Cannot promote draft directly to production. Publish the version first."
            )

        # Demote current production to published
        await prisma_client.db.litellm_policytable.update_many(
            where={
                "policy_name": policy_name,
                "version_status": "production",
            },
            data={
                "version_status": "published",
                "updated_at": now,
                "updated_by": updated_by,
            },
        )

        # Promote this version to production
        updated = await prisma_client.db.litellm_policytable.update(
            where={"policy_id": policy_id},
            data={
                "version_status": "production",
                "production_at": now,
                "updated_at": now,
                "updated_by": updated_by,
            },
        )

        # Update in-memory registry: remove old production (by name), add this one
        self.remove_policy(policy_name)
        policy = self._parse_policy(
            policy_name,
            {
                "inherit": updated.inherit,
                "description": updated.description,
                "guardrails": {
                    "add": updated.guardrails_add or [],
                    "remove": updated.guardrails_remove or [],
                },
                "condition": updated.condition,
                "pipeline": updated.pipeline,
            },
        )
        self.add_policy(policy_name, policy)
        return _row_to_policy_db_response(updated)
    except Exception as e:
        verbose_proxy_logger.exception(f"Error updating version status: {e}")
        raise Exception(f"Error updating version status: {str(e)}")
async def compare_versions(
    self,
    policy_id_a: str,
    policy_id_b: str,
    prisma_client: "PrismaClient",
) -> PolicyVersionCompareResponse:
    """
    Diff the content of two policy versions field by field.

    Only content fields are compared (inherit, description, guardrails_add,
    guardrails_remove, condition, pipeline); metadata such as timestamps and
    status is ignored.

    Args:
        policy_id_a: First policy version ID
        policy_id_b: Second policy version ID
        prisma_client: The Prisma client instance

    Returns:
        PolicyVersionCompareResponse with both versions and field_diffs,
        where field_diffs maps each differing field to
        {"version_a": ..., "version_b": ...}

    Raises:
        Exception: If either version is missing or the DB call fails. Every
            failure is logged and re-raised with the message prefix
            "Error comparing versions: ".
    """
    try:
        table = prisma_client.db.litellm_policytable
        row_a = await table.find_unique(where={"policy_id": policy_id_a})
        row_b = await table.find_unique(where={"policy_id": policy_id_b})

        # Both lookups run before either is checked; A is reported first.
        for policy_id, row in ((policy_id_a, row_a), (policy_id_b, row_b)):
            if row is None:
                raise Exception(f"Policy {policy_id} not found")

        resp_a = _row_to_policy_db_response(row_a)
        resp_b = _row_to_policy_db_response(row_b)

        # Fields that are part of policy content (not metadata)
        content_fields = (
            "inherit",
            "description",
            "guardrails_add",
            "guardrails_remove",
            "condition",
            "pipeline",
        )
        value_pairs = (
            (name, getattr(resp_a, name), getattr(resp_b, name))
            for name in content_fields
        )
        field_diffs: Dict[str, Dict[str, Any]] = {
            name: {"version_a": value_a, "version_b": value_b}
            for name, value_a, value_b in value_pairs
            if value_a != value_b
        }

        return PolicyVersionCompareResponse(
            version_a=resp_a,
            version_b=resp_b,
            field_diffs=field_diffs,
        )
    except Exception as e:
        verbose_proxy_logger.exception(f"Error comparing versions: {e}")
        raise Exception(f"Error comparing versions: {str(e)}")
async def delete_all_versions(
    self,
    policy_name: str,
    prisma_client: "PrismaClient",
) -> Dict[str, str]:
    """
    Delete all versions of a policy.

    Also removes the policy from the name-keyed in-memory registry and drops
    every version of that name from the ID-keyed in-memory cache.

    Args:
        policy_name: Name of the policy
        prisma_client: The Prisma client instance

    Returns:
        Dict with success message

    Raises:
        Exception: On any failure, logged and re-raised with the message
            prefix "Error deleting all versions: ".
    """
    try:
        await prisma_client.db.litellm_policytable.delete_many(
            where={"policy_name": policy_name}
        )
        self.remove_policy(policy_name)

        # By-ID cache entries are (policy_name, Policy) tuples. Drop the ones
        # for this name so get_policy_by_id_for_request stops returning
        # versions that no longer exist in the database.
        self._policies_by_id = {
            cached_id: entry
            for cached_id, entry in self._policies_by_id.items()
            if entry[0] != policy_name
        }

        return {
            "message": f"All versions of policy '{policy_name}' deleted successfully"
        }
    except Exception as e:
        verbose_proxy_logger.exception(f"Error deleting all versions: {e}")
        raise Exception(f"Error deleting all versions: {str(e)}")
# Module-wide registry instance, created on first use by get_policy_registry()
_policy_registry: Optional[PolicyRegistry] = None


def get_policy_registry() -> PolicyRegistry:
    """
    Return the shared PolicyRegistry, creating it on the first call.

    Returns:
        The global PolicyRegistry instance
    """
    global _policy_registry
    if _policy_registry is not None:
        return _policy_registry
    _policy_registry = PolicyRegistry()
    return _policy_registry

View File

@@ -0,0 +1,416 @@
"""
Policy resolve and attachment impact estimation endpoints.
- /policies/resolve — debug which guardrails apply for a given context
- /policies/attachments/estimate-impact — preview blast radius before creating an attachment
"""
import json
from fastapi import APIRouter, Depends, HTTPException, Query
from litellm._logging import verbose_proxy_logger
from litellm.constants import MAX_POLICY_ESTIMATE_IMPACT_ROWS
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.policy_engine.attachment_registry import get_attachment_registry
from litellm.proxy.policy_engine.policy_registry import get_policy_registry
from litellm.types.proxy.policy_engine import (
AttachmentImpactResponse,
PolicyAttachmentCreateRequest,
PolicyMatchContext,
PolicyMatchDetail,
PolicyResolveRequest,
PolicyResolveResponse,
)
# Router that the endpoints in this module register themselves on via @router.post.
router = APIRouter()
def _build_alias_where(field: str, patterns: list) -> dict:
"""Build a Prisma ``where`` clause for alias patterns.
Supports exact matches and suffix wildcards (``prefix*``).
Returns something like:
{"OR": [{"field": {"in": ["a","b"]}}, {"field": {"startsWith": "dev-"}}]}
"""
exact: list = []
prefix_conditions: list = []
for pat in patterns:
if pat.endswith("*"):
prefix_conditions.append({field: {"startsWith": pat[:-1]}})
else:
exact.append(pat)
conditions: list = []
if exact:
conditions.append({field: {"in": exact}})
conditions.extend(prefix_conditions)
if not conditions:
return {field: {"not": None}}
if len(conditions) == 1:
return conditions[0]
return {"OR": conditions}
def _parse_metadata(raw_metadata: object) -> dict:
"""Parse metadata that may be a dict, JSON string, or None."""
if raw_metadata is None:
return {}
if isinstance(raw_metadata, str):
try:
return json.loads(raw_metadata)
except (json.JSONDecodeError, TypeError):
return {}
return raw_metadata if isinstance(raw_metadata, dict) else {}
def _get_tags_from_metadata(metadata: object, json_metadata: object = None) -> list:
    """Return the "tags" entry of a metadata value, or [] when absent or empty.

    ``json_metadata`` takes precedence whenever it is not None; ``metadata``
    is only consulted otherwise. The chosen value is normalized with
    ``_parse_metadata`` before "tags" is read.
    """
    source = metadata if json_metadata is None else json_metadata
    tags = _parse_metadata(source).get("tags", [])
    return tags if tags else []
async def _fetch_all_teams(prisma_client: object) -> list:
    """Load team rows in one query so tag and alias lookups can share them.

    Rows come back newest first (by created_at), capped at
    MAX_POLICY_ESTIMATE_IMPACT_ROWS.
    """
    team_table = prisma_client.db.litellm_teamtable  # type: ignore
    return await team_table.find_many(
        where={},
        order={"created_at": "desc"},
        take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
    )
def _filter_keys_by_tags(keys: list, tag_patterns: list) -> tuple:
    """Select the key rows having at least one tag that matches a pattern.

    Tags are read from the row's ``metadata_json`` attribute when it exists
    and is not None, otherwise from ``metadata``.

    Returns (named_aliases, unnamed_count): the aliases of matching keys that
    have one, and the number of matching keys without an alias.
    """
    named_aliases: list = []
    unnamed_total = 0
    for key_row in keys:
        alias = key_row.key_alias or ""
        tags = _get_tags_from_metadata(
            key_row.metadata, getattr(key_row, "metadata_json", None)
        )
        if not tags:
            continue

        is_match = any(
            RouteChecks._route_matches_wildcard_pattern(route=tag, pattern=pattern)
            for tag in tags
            for pattern in tag_patterns
        )
        if not is_match:
            continue

        if alias:
            named_aliases.append(alias)
        else:
            unnamed_total += 1
    return named_aliases, unnamed_total
def _filter_teams_by_tags(teams: list, tag_patterns: list) -> tuple:
    """Select the pre-fetched team rows having at least one tag that matches a pattern.

    Tags are read from each row's ``metadata``.

    Returns (named_aliases, unnamed_count): the aliases of matching teams that
    have one, and the number of matching teams without an alias.
    """
    named_aliases: list = []
    unnamed_total = 0
    for team_row in teams:
        alias = team_row.team_alias or ""
        tags = _get_tags_from_metadata(team_row.metadata)
        if not tags:
            continue

        is_match = any(
            RouteChecks._route_matches_wildcard_pattern(route=tag, pattern=pattern)
            for tag in tags
            for pattern in tag_patterns
        )
        if not is_match:
            continue

        if alias:
            named_aliases.append(alias)
        else:
            unnamed_total += 1
    return named_aliases, unnamed_total
async def _find_affected_by_team_patterns(
    prisma_client: object,
    all_teams: list,
    team_patterns: list,
    existing_teams: list,
    existing_keys: list,
) -> tuple:
    """Match pre-fetched teams against alias patterns, then load those teams' keys.

    Teams without an alias never match. Aliases already present in
    ``existing_teams`` / ``existing_keys`` are not reported again, but the
    keys of an already-listed team are still looked up. The key query is
    capped at MAX_POLICY_ESTIMATE_IMPACT_ROWS, newest first.

    Returns (new_teams, new_keys, unnamed_keys_count).
    """
    newly_matched_teams: list = []
    team_ids: list = []
    for team_row in all_teams:
        alias = team_row.team_alias or ""
        if not alias:
            continue
        alias_matches = any(
            RouteChecks._route_matches_wildcard_pattern(route=alias, pattern=pattern)
            for pattern in team_patterns
        )
        if not alias_matches:
            continue
        if alias not in existing_teams:
            newly_matched_teams.append(alias)
        team_ids.append(str(team_row.team_id))

    newly_matched_keys: list = []
    unnamed_total = 0
    # Without matched teams there is nothing to query.
    if not team_ids:
        return newly_matched_teams, newly_matched_keys, unnamed_total

    key_rows = await prisma_client.db.litellm_verificationtoken.find_many(  # type: ignore
        where={"team_id": {"in": team_ids}},
        order={"created_at": "desc"},
        take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
    )
    for key_row in key_rows:
        key_alias = key_row.key_alias or ""
        if not key_alias:
            unnamed_total += 1
        elif key_alias not in existing_keys:
            newly_matched_keys.append(key_alias)
    return newly_matched_teams, newly_matched_keys, unnamed_total
async def _find_affected_keys_by_alias(
    prisma_client: object, key_patterns: list, existing_keys: list
) -> list:
    """Return aliases of keys matching the patterns that are not already listed.

    The database is queried with the clause built by ``_build_alias_where``
    (newest first, capped at MAX_POLICY_ESTIMATE_IMPACT_ROWS); each returned
    alias is then re-checked against the patterns in Python.
    """
    candidate_rows = await prisma_client.db.litellm_verificationtoken.find_many(  # type: ignore
        where=_build_alias_where("key_alias", key_patterns),
        order={"created_at": "desc"},
        take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
    )

    newly_matched: list = []
    for key_row in candidate_rows:
        alias = key_row.key_alias or ""
        if not alias:
            continue
        alias_matches = any(
            RouteChecks._route_matches_wildcard_pattern(route=alias, pattern=pattern)
            for pattern in key_patterns
        )
        if alias_matches and alias not in existing_keys:
            newly_matched.append(alias)
    return newly_matched
# ─────────────────────────────────────────────────────────────────────────────
# Policy Resolve Endpoint
# ─────────────────────────────────────────────────────────────────────────────
@router.post(
    "/policies/resolve",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=PolicyResolveResponse,
)
async def resolve_policies_for_context(
    request: PolicyResolveRequest,
    force_sync: bool = Query(
        default=False,
        description="Force a DB sync before resolving. Default uses in-memory cache.",
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Resolve which policies and guardrails apply for a given context.
    Use this endpoint to debug "what guardrails would apply to a request
    with this team/key/model/tags combination?"
    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/policies/resolve" \\
    -H "Authorization: Bearer <your_api_key>" \\
    -H "Content-Type: application/json" \\
    -d '{
    "tags": ["healthcare"],
    "model": "gpt-4"
    }'
    ```
    """
    # Imported inside the handler, as in the rest of this module's endpoints.
    from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher
    from litellm.proxy.policy_engine.policy_resolver import PolicyResolver
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        attachment_registry = get_attachment_registry()

        # Only sync from DB when explicitly requested; otherwise use in-memory cache
        if force_sync:
            await get_policy_registry().sync_policies_from_db(prisma_client)
            await attachment_registry.sync_attachments_from_db(prisma_client)

        # Build context from request
        context = PolicyMatchContext(
            team_alias=request.team_alias,
            key_alias=request.key_alias,
            model=request.model,
            tags=request.tags,
        )

        # Get matching policies with reasons
        match_results = attachment_registry.get_attached_policies_with_reasons(
            context=context
        )
        if not match_results:
            return PolicyResolveResponse(
                effective_guardrails=[],
                matched_policies=[],
            )

        # Filter by conditions
        policy_names = [r["policy_name"] for r in match_results]
        applied_policy_names = PolicyMatcher.get_policies_with_matching_conditions(
            policy_names=policy_names,
            context=context,
        )

        # get_all_policies() copies the whole registry, so take one snapshot
        # here instead of one per matched policy. Nothing is awaited inside
        # the loop below, so the registry cannot change while it runs.
        all_policies = get_policy_registry().get_all_policies()

        # Resolve guardrails for each applied policy
        matched_policies = []
        all_guardrails: set = set()
        for result in match_results:
            pname = result["policy_name"]
            if pname not in applied_policy_names:
                continue
            resolved = PolicyResolver.resolve_policy_guardrails(
                policy_name=pname,
                policies=all_policies,
                context=context,
            )
            guardrails = resolved.guardrails if resolved else []
            all_guardrails.update(guardrails)
            matched_policies.append(
                PolicyMatchDetail(
                    policy_name=pname,
                    matched_via=result["matched_via"],
                    guardrails_added=guardrails,
                )
            )

        return PolicyResolveResponse(
            effective_guardrails=sorted(all_guardrails),
            matched_policies=matched_policies,
        )
    except HTTPException:
        # Pass HTTP errors through untouched instead of wrapping them as 500s.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error resolving policies: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# ─────────────────────────────────────────────────────────────────────────────
# Attachment Impact Estimation Endpoint
# ─────────────────────────────────────────────────────────────────────────────
@router.post(
    "/policies/attachments/estimate-impact",
    tags=["Policies"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=AttachmentImpactResponse,
)
async def estimate_attachment_impact(
    request: PolicyAttachmentCreateRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Estimate how many keys and teams would be affected by a policy attachment.
    Use this before creating an attachment to preview the blast radius.
    Example Request:
    ```bash
    curl -X POST "http://localhost:4000/policies/attachments/estimate-impact" \\
    -H "Authorization: Bearer <your_api_key>" \\
    -H "Content-Type: application/json" \\
    -d '{
    "policy_name": "hipaa-compliance",
    "tags": ["healthcare", "health-*"]
    }'
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail="Database not connected")

    try:
        # A global attachment touches everything, so nothing is enumerated and
        # both counts are reported as -1 with placeholder samples.
        if request.scope == "*":
            return AttachmentImpactResponse(
                affected_keys_count=-1,
                affected_teams_count=-1,
                sample_keys=["(global scope — affects all keys)"],
                sample_teams=["(global scope — affects all teams)"],
            )

        tag_patterns = request.tags or []
        team_patterns = request.teams or []
        key_patterns = request.keys or []

        matched_keys: list = []
        matched_teams: list = []
        unnamed_key_total = 0
        unnamed_team_total = 0

        # A single team fetch serves both the tag lookup and the alias lookup.
        team_rows: list = (
            await _fetch_all_teams(prisma_client)
            if (tag_patterns or team_patterns)
            else []
        )

        # 1) Impact through tags: newest keys first, capped at the row limit.
        if tag_patterns:
            key_rows = await prisma_client.db.litellm_verificationtoken.find_many(  # type: ignore
                where={},
                order={"created_at": "desc"},
                take=MAX_POLICY_ESTIMATE_IMPACT_ROWS,
            )
            matched_keys, unnamed_key_total = _filter_keys_by_tags(
                key_rows, tag_patterns
            )
            matched_teams, unnamed_team_total = _filter_teams_by_tags(
                team_rows,
                tag_patterns,
            )

        # 2) Impact through team aliases, plus the keys owned by those teams.
        #    The current matches are handed to the helper alongside the patterns.
        if team_patterns:
            (
                extra_teams,
                extra_keys,
                extra_unnamed,
            ) = await _find_affected_by_team_patterns(
                prisma_client,
                team_rows,
                team_patterns,
                matched_teams,
                matched_keys,
            )
            matched_teams.extend(extra_teams)
            matched_keys.extend(extra_keys)
            unnamed_key_total += extra_unnamed

        # 3) Impact through direct key-alias matching.
        if key_patterns:
            extra_keys = await _find_affected_keys_by_alias(
                prisma_client,
                key_patterns,
                matched_keys,
            )
            matched_keys.extend(extra_keys)

        # Counts include the unnamed entries; samples are the first ten names.
        return AttachmentImpactResponse(
            affected_keys_count=len(matched_keys) + unnamed_key_total,
            affected_teams_count=len(matched_teams) + unnamed_team_total,
            unnamed_keys_count=unnamed_key_total,
            unnamed_teams_count=unnamed_team_total,
            sample_keys=matched_keys[:10],
            sample_teams=matched_teams[:10],
        )
    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        verbose_proxy_logger.exception(f"Error estimating attachment impact: {e}")
        raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,299 @@
"""
Policy Resolver - Resolves final guardrail list from policies.
Handles:
- Inheritance chain resolution (inherit with add/remove)
- Applying add/remove guardrails
- Evaluating model conditions
- Combining guardrails from multiple matching policies
"""
from typing import Dict, List, Optional, Set, Tuple
from litellm._logging import verbose_proxy_logger
from litellm.types.proxy.policy_engine import (
GuardrailPipeline,
Policy,
PolicyMatchContext,
ResolvedPolicy,
)
class PolicyResolver:
    """
    Resolves the final list of guardrails (and pipelines) from policies.

    Handles:
    - Inheritance chains with add/remove operations
    - Policy conditions evaluated against the request context
    - Combining the results of several matching policies
    """

    @staticmethod
    def _load_policies(
        policies: Optional[Dict[str, Policy]],
    ) -> Optional[Dict[str, Policy]]:
        """Return `policies`, else the global registry's policies; None if the registry is uninitialized."""
        if policies is not None:
            return policies

        from litellm.proxy.policy_engine.policy_registry import get_policy_registry

        registry = get_policy_registry()
        if not registry.is_initialized():
            return None
        return registry.get_all_policies()

    @staticmethod
    def _matching_policy_names(
        context: PolicyMatchContext,
        policy_names: Optional[List[str]],
    ) -> List[str]:
        """Return the caller-supplied names, else the policies matched via attachments."""
        if policy_names is not None:
            return policy_names

        from litellm.proxy.policy_engine.policy_matcher import PolicyMatcher

        return PolicyMatcher.get_matching_policies(context=context)

    @staticmethod
    def resolve_inheritance_chain(
        policy_name: str,
        policies: Dict[str, Policy],
        visited: Optional[Set[str]] = None,
    ) -> List[str]:
        """
        Get the inheritance chain for a policy (from root to policy).

        Args:
            policy_name: Name of the policy
            policies: Dictionary of all policies
            visited: Set of visited policies (for cycle detection); it is
                mutated as the chain is walked

        Returns:
            List of policy names from root ancestor to the given policy.
            Empty if the policy is unknown. On a cycle or a missing parent the
            chain is cut at that point and a warning is logged.
        """
        if visited is None:
            visited = set()

        if policy_name in visited:
            verbose_proxy_logger.warning(
                f"Circular inheritance detected for policy '{policy_name}'"
            )
            return []

        policy = policies.get(policy_name)
        if policy is None:
            return []

        visited.add(policy_name)

        if not policy.inherit:
            return [policy_name]

        if policies.get(policy.inherit) is None:
            # The chain is the same as before (parent contributes nothing),
            # but a dangling `inherit` silently drops the parent's guardrails,
            # so make the misconfiguration visible.
            verbose_proxy_logger.warning(
                f"Policy '{policy_name}' inherits from unknown policy "
                f"'{policy.inherit}'; parent guardrails will not be applied"
            )
            return [policy_name]

        parent_chain = PolicyResolver.resolve_inheritance_chain(
            policy_name=policy.inherit, policies=policies, visited=visited
        )
        return parent_chain + [policy_name]

    @staticmethod
    def resolve_policy_guardrails(
        policy_name: str,
        policies: Dict[str, Policy],
        context: Optional[PolicyMatchContext] = None,
    ) -> ResolvedPolicy:
        """
        Resolve the final guardrails for a single policy, including inheritance.

        This method:
        1. Resolves the inheritance chain
        2. Applies add/remove from each policy in the chain (root first)
        3. Skips any chain member whose condition does not match the context
           (only when a context is provided)

        Args:
            policy_name: Name of the policy to resolve
            policies: Dictionary of all policies
            context: Optional request context for evaluating conditions

        Returns:
            ResolvedPolicy with the final guardrails list (sorted by name so
            the result is deterministic) and the inheritance chain used.
        """
        from litellm.proxy.policy_engine.condition_evaluator import ConditionEvaluator

        inheritance_chain = PolicyResolver.resolve_inheritance_chain(
            policy_name=policy_name, policies=policies
        )

        guardrails: Set[str] = set()

        # Root first, so a descendant's add/remove overrides its ancestors.
        for chain_policy_name in inheritance_chain:
            policy = policies.get(chain_policy_name)
            if policy is None:
                continue

            if context is not None and policy.condition is not None:
                if not ConditionEvaluator.evaluate(
                    condition=policy.condition,
                    context=context,
                ):
                    verbose_proxy_logger.debug(
                        f"Policy '{chain_policy_name}' condition did not match, skipping guardrails"
                    )
                    continue

            # Additions are applied before removals within the same policy.
            guardrails.update(policy.guardrails.get_add())
            guardrails.difference_update(policy.guardrails.get_remove())

        return ResolvedPolicy(
            policy_name=policy_name,
            # Set iteration order for strings varies between processes;
            # sort so every worker reports the same list.
            guardrails=sorted(guardrails),
            inheritance_chain=inheritance_chain,
        )

    @staticmethod
    def resolve_guardrails_for_context(
        context: PolicyMatchContext,
        policies: Optional[Dict[str, Policy]] = None,
        policy_names: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Resolve the final list of guardrails for a request context.

        This:
        1. Finds all policies that match the context via policy_attachments
           (or uses policy_names if provided)
        2. Resolves each policy's guardrails (including inheritance and
           conditions)
        3. Combines all guardrails (union)

        Args:
            context: The request context
            policies: Dictionary of all policies (if None, uses global registry)
            policy_names: If provided, use this list instead of attachment matching

        Returns:
            Sorted list of guardrail names to apply. Empty when the registry is
            not initialized or no policy matches.
        """
        policies = PolicyResolver._load_policies(policies)
        if policies is None:
            return []

        matching_policy_names = PolicyResolver._matching_policy_names(
            context=context, policy_names=policy_names
        )
        if not matching_policy_names:
            verbose_proxy_logger.debug(
                f"No policies match context: team_alias={context.team_alias}, "
                f"key_alias={context.key_alias}, model={context.model}"
            )
            return []

        all_guardrails: Set[str] = set()
        for policy_name in matching_policy_names:
            resolved = PolicyResolver.resolve_policy_guardrails(
                policy_name=policy_name,
                policies=policies,
                context=context,
            )
            all_guardrails.update(resolved.guardrails)
            verbose_proxy_logger.debug(
                f"Policy '{policy_name}' contributes guardrails: {resolved.guardrails}"
            )

        # Sorted for a stable, reproducible order across processes.
        result = sorted(all_guardrails)
        verbose_proxy_logger.debug(f"Final guardrails for context: {result}")
        return result

    @staticmethod
    def resolve_pipelines_for_context(
        context: PolicyMatchContext,
        policies: Optional[Dict[str, Policy]] = None,
        policy_names: Optional[List[str]] = None,
    ) -> List[Tuple[str, GuardrailPipeline]]:
        """
        Resolve pipelines from matching policies for a request context.

        Returns (policy_name, pipeline) tuples for policies that have pipelines.
        Guardrails managed by pipelines should be excluded from the flat
        guardrails list to avoid double execution.

        Args:
            context: The request context
            policies: Dictionary of all policies (if None, uses global registry)
            policy_names: If provided, use this list instead of attachment matching

        Returns:
            List of (policy_name, GuardrailPipeline) tuples, in the order the
            policies were matched. A policy whose condition does not match the
            context contributes no pipeline.
        """
        from litellm.proxy.policy_engine.condition_evaluator import ConditionEvaluator

        policies = PolicyResolver._load_policies(policies)
        if policies is None:
            return []

        matching_policy_names = PolicyResolver._matching_policy_names(
            context=context, policy_names=policy_names
        )
        if not matching_policy_names:
            return []

        pipelines: List[Tuple[str, GuardrailPipeline]] = []
        for policy_name in matching_policy_names:
            policy = policies.get(policy_name)
            if policy is None or policy.pipeline is None:
                continue

            # Apply the same condition gate as resolve_policy_guardrails.
            # Without it, a policy whose condition fails would contribute no
            # flat guardrails yet still have its pipeline executed.
            if policy.condition is not None and not ConditionEvaluator.evaluate(
                condition=policy.condition,
                context=context,
            ):
                verbose_proxy_logger.debug(
                    f"Policy '{policy_name}' condition did not match, skipping pipeline"
                )
                continue

            pipelines.append((policy_name, policy.pipeline))
            verbose_proxy_logger.debug(
                f"Policy '{policy_name}' has pipeline with "
                f"{len(policy.pipeline.steps)} steps"
            )

        return pipelines

    @staticmethod
    def get_pipeline_managed_guardrails(
        pipelines: List[Tuple[str, GuardrailPipeline]],
    ) -> Set[str]:
        """
        Get the set of guardrail names managed by pipelines.

        These guardrails should be excluded from normal independent execution.

        Args:
            pipelines: (policy_name, pipeline) tuples, e.g. from
                resolve_pipelines_for_context

        Returns:
            The guardrail name of every step of every pipeline.
        """
        return {
            step.guardrail
            for _policy_name, pipeline in pipelines
            for step in pipeline.steps
        }

    @staticmethod
    def get_all_resolved_policies(
        policies: Optional[Dict[str, Policy]] = None,
        context: Optional[PolicyMatchContext] = None,
    ) -> Dict[str, ResolvedPolicy]:
        """
        Resolve all policies and return their final guardrails.

        Useful for debugging and displaying policy configurations.

        Args:
            policies: Dictionary of all policies (if None, uses global registry)
            context: Optional context for evaluating conditions

        Returns:
            Dictionary mapping policy names to ResolvedPolicy objects. Empty
            when the registry is not initialized.
        """
        policies = PolicyResolver._load_policies(policies)
        if policies is None:
            return {}

        return {
            policy_name: PolicyResolver.resolve_policy_guardrails(
                policy_name=policy_name,
                policies=policies,
                context=context,
            )
            for policy_name in policies
        }

View File

@@ -0,0 +1,405 @@
"""
Policy Validator - Validates policy configurations.
Validates:
- Guardrail names exist in the guardrail registry
- Non-wildcard team aliases exist in the database
- Non-wildcard key aliases exist in the database
- Non-wildcard model names exist in the router or match a wildcard route
- Inheritance chains are valid (no cycles, parents exist)
"""
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from litellm._logging import verbose_proxy_logger
from litellm.types.proxy.policy_engine import (
Policy,
PolicyValidationError,
PolicyValidationErrorType,
PolicyValidationResponse,
)
if TYPE_CHECKING:
from litellm.proxy.utils import PrismaClient
from litellm.router import Router
class PolicyValidator:
    """
    Validates policy configurations against actual data.

    Guardrail names are checked against the in-memory guardrail registry and
    inheritance chains are checked for missing parents, cycles and depth.
    Helpers are also provided to look up team aliases, key aliases and model
    names when a Prisma client / router is supplied.
    """

    def __init__(
        self,
        prisma_client: Optional["PrismaClient"] = None,
        llm_router: Optional["Router"] = None,
    ):
        """
        Initialize the validator.

        Args:
            prisma_client: Optional Prisma client for database validation
            llm_router: Optional LLM router for model validation
        """
        self.prisma_client = prisma_client
        self.llm_router = llm_router

    @staticmethod
    def is_wildcard_pattern(pattern: str) -> bool:
        """
        Check if a pattern contains wildcards.

        Args:
            pattern: The pattern to check

        Returns:
            True if the pattern contains `*` or `?`
        """
        return "*" in pattern or "?" in pattern

    def get_available_guardrails(self) -> Set[str]:
        """
        Get set of available guardrail names from the guardrail registry.

        Returns:
            Set of non-empty guardrail names. An empty set if the registry
            cannot be read (a warning is logged in that case).
        """
        try:
            from litellm.proxy.guardrails.guardrail_registry import (
                IN_MEMORY_GUARDRAIL_HANDLER,
            )

            guardrails = IN_MEMORY_GUARDRAIL_HANDLER.list_in_memory_guardrails()
            return {
                g.get("guardrail_name", "")
                for g in guardrails
                if g.get("guardrail_name")
            }
        except Exception as e:
            # Best effort: validation continues without a registry.
            verbose_proxy_logger.warning(
                f"Could not get guardrails from registry: {e}"
            )
            return set()

    async def check_team_alias_exists(self, team_alias: str) -> bool:
        """
        Check if a specific team alias exists in the database.

        Args:
            team_alias: The team alias to check

        Returns:
            True if the team alias exists. Also True when no Prisma client is
            configured or the lookup fails, so a missing or unreachable DB
            never produces a false "not found".
        """
        if self.prisma_client is None:
            return True  # Can't validate without DB, assume valid

        try:
            team = await self.prisma_client.db.litellm_teamtable.find_first(
                where={"team_alias": team_alias},
            )
            return team is not None
        except Exception as e:
            verbose_proxy_logger.warning(
                f"Could not check team alias '{team_alias}': {e}"
            )
            return True  # Assume valid on error

    async def check_key_alias_exists(self, key_alias: str) -> bool:
        """
        Check if a specific key alias exists in the database.

        Args:
            key_alias: The key alias to check

        Returns:
            True if the key alias exists. Also True when no Prisma client is
            configured or the lookup fails.
        """
        if self.prisma_client is None:
            return True  # Can't validate without DB, assume valid

        try:
            key = await self.prisma_client.db.litellm_verificationtoken.find_first(
                where={"key_alias": key_alias},
            )
            return key is not None
        except Exception as e:
            verbose_proxy_logger.warning(
                f"Could not check key alias '{key_alias}': {e}"
            )
            return True  # Assume valid on error

    def check_model_exists(self, model: str) -> bool:
        """
        Check if a model exists in the router or matches a wildcard pattern.

        Args:
            model: The model name to check

        Returns:
            True if the model is one of the router's model names or the
            router's pattern router returns deployments for it. Also True when
            no router is configured or the lookup raises.
        """
        if self.llm_router is None:
            return True  # Can't validate without router, assume valid

        try:
            if model in self.llm_router.model_names:
                return True

            # Fall back to wildcard routes when the router exposes them.
            if hasattr(self.llm_router, "pattern_router"):
                pattern_deployments = (
                    self.llm_router.pattern_router.get_deployments_by_pattern(
                        model=model
                    )
                )
                if pattern_deployments:
                    return True

            return False
        except Exception as e:
            verbose_proxy_logger.warning(f"Could not check model '{model}': {e}")
            return True  # Assume valid on error

    def _validate_inheritance_chain(
        self,
        policy_name: str,
        policies: Dict[str, Policy],
        visited: Optional[Set[str]] = None,
        max_depth: int = 100,
    ) -> List[PolicyValidationError]:
        """
        Validate the inheritance chain for a policy.

        Walks from `policy_name` up through its `inherit` links and stops at
        the first problem found.

        Checks for:
        - Parent policy exists
        - No circular inheritance
        - Max depth not exceeded

        Args:
            policy_name: Name of the policy to validate
            policies: All policies
            visited: Set of already visited policy names (for cycle detection);
                it is mutated as the chain is walked
            max_depth: Maximum number of chain links to follow

        Returns:
            List of validation errors (empty, or a single error describing the
            first problem encountered)
        """
        errors: List[PolicyValidationError] = []

        seen: Set[str] = set() if visited is None else visited
        # A set has no order, so the walked path is tracked separately to make
        # the cycle message read in inheritance order. Names supplied by the
        # caller have no known order and are listed sorted.
        path: List[str] = sorted(seen)

        current = policy_name
        remaining = max_depth

        while True:
            if remaining <= 0:
                errors.append(
                    PolicyValidationError(
                        policy_name=current,
                        error_type=PolicyValidationErrorType.CIRCULAR_INHERITANCE,
                        message=f"Inheritance chain too deep (exceeded max depth of {max_depth})",
                        field="inherit",
                    )
                )
                break

            if current in seen:
                errors.append(
                    PolicyValidationError(
                        policy_name=current,
                        error_type=PolicyValidationErrorType.CIRCULAR_INHERITANCE,
                        message=f"Circular inheritance detected: {' -> '.join(path)} -> {current}",
                        field="inherit",
                    )
                )
                break

            policy = policies.get(current)
            if policy is None or not policy.inherit:
                # Unknown policy or a root: nothing further to follow.
                break

            if policy.inherit not in policies:
                errors.append(
                    PolicyValidationError(
                        policy_name=current,
                        error_type=PolicyValidationErrorType.INVALID_INHERITANCE,
                        message=f"Parent policy '{policy.inherit}' not found",
                        field="inherit",
                        value=policy.inherit,
                    )
                )
                break

            seen.add(current)
            path.append(current)
            current = policy.inherit
            remaining -= 1

        return errors

    async def validate_policies(
        self,
        policies: Dict[str, Policy],
        validate_db: bool = True,
    ) -> PolicyValidationResponse:
        """
        Validate a set of policies.

        For every policy this checks the guardrails in `add` (unknown names are
        errors) and `remove` (unknown names are warnings), the pipeline if one
        is defined, and the inheritance chain. Guardrail-name checks are
        skipped when the guardrail registry yields no names.

        Args:
            policies: Dictionary mapping policy names to Policy objects
            validate_db: Whether to validate against database (teams, keys).
                NOTE: this method does not currently read this flag.

        Returns:
            PolicyValidationResponse with errors and warnings; `valid` is True
            only when there are no errors.
        """
        errors: List[PolicyValidationError] = []
        warnings: List[PolicyValidationError] = []

        available_guardrails = self.get_available_guardrails()

        for policy_name, policy in policies.items():
            # Unknown guardrail being added -> error
            for guardrail in policy.guardrails.get_add():
                if available_guardrails and guardrail not in available_guardrails:
                    errors.append(
                        PolicyValidationError(
                            policy_name=policy_name,
                            error_type=PolicyValidationErrorType.INVALID_GUARDRAIL,
                            message=f"Guardrail '{guardrail}' not found in guardrail registry",
                            field="guardrails.add",
                            value=guardrail,
                        )
                    )

            # Unknown guardrail being removed -> warning only
            for guardrail in policy.guardrails.get_remove():
                if available_guardrails and guardrail not in available_guardrails:
                    warnings.append(
                        PolicyValidationError(
                            policy_name=policy_name,
                            error_type=PolicyValidationErrorType.INVALID_GUARDRAIL,
                            message=f"Guardrail '{guardrail}' in remove list not found in guardrail registry",
                            field="guardrails.remove",
                            value=guardrail,
                        )
                    )

            if policy.pipeline is not None:
                errors.extend(
                    PolicyValidator._validate_pipeline(
                        policy_name=policy_name,
                        policy=policy,
                        available_guardrails=available_guardrails,
                    )
                )

            errors.extend(
                self._validate_inheritance_chain(
                    policy_name=policy_name, policies=policies
                )
            )

        return PolicyValidationResponse(
            valid=len(errors) == 0,
            errors=errors,
            warnings=warnings,
        )

    @staticmethod
    def _validate_pipeline(
        policy_name: str,
        policy: Policy,
        available_guardrails: Set[str],
    ) -> List[PolicyValidationError]:
        """
        Validate a policy's pipeline configuration.

        Each step's guardrail must appear in the policy's own `guardrails.add`
        list and, when the registry yields names, must exist in the registry.

        Args:
            policy_name: Name of the policy (used in the reported errors)
            policy: The policy whose pipeline is validated
            available_guardrails: Guardrail names known to the registry; an
                empty set disables the registry check

        Returns:
            List of validation errors (empty if the policy has no pipeline)
        """
        errors: List[PolicyValidationError] = []

        pipeline = policy.pipeline
        if pipeline is None:
            return errors

        guardrails_add = set(policy.guardrails.get_add())

        for i, step in enumerate(pipeline.steps):
            if step.guardrail not in guardrails_add:
                errors.append(
                    PolicyValidationError(
                        policy_name=policy_name,
                        error_type=PolicyValidationErrorType.INVALID_GUARDRAIL,
                        message=(
                            f"Pipeline step {i} guardrail '{step.guardrail}' "
                            f"is not in the policy's guardrails.add list"
                        ),
                        field="pipeline.steps",
                        value=step.guardrail,
                    )
                )

            if available_guardrails and step.guardrail not in available_guardrails:
                errors.append(
                    PolicyValidationError(
                        policy_name=policy_name,
                        error_type=PolicyValidationErrorType.INVALID_GUARDRAIL,
                        message=(
                            f"Pipeline step {i} guardrail '{step.guardrail}' "
                            f"not found in guardrail registry"
                        ),
                        field="pipeline.steps",
                        value=step.guardrail,
                    )
                )

        return errors

    async def validate_policy_config(
        self,
        policy_config: Dict[str, Any],
        validate_db: bool = True,
    ) -> PolicyValidationResponse:
        """
        Validate a raw policy configuration dictionary.

        Each entry is first parsed into a Policy. If any entry fails to parse,
        the parse errors are returned immediately and no further validation
        runs; otherwise the parsed policies go through validate_policies.

        Args:
            policy_config: Raw policy configuration from YAML
            validate_db: Whether to validate against database (forwarded to
                validate_policies)

        Returns:
            PolicyValidationResponse with errors and warnings
        """
        from litellm.proxy.policy_engine.policy_registry import PolicyRegistry

        errors: List[PolicyValidationError] = []
        policies: Dict[str, Policy] = {}

        # A throwaway registry is used purely for its parsing logic.
        temp_registry = PolicyRegistry()

        for policy_name, policy_data in policy_config.items():
            try:
                policies[policy_name] = temp_registry._parse_policy(
                    policy_name, policy_data
                )
            except Exception as e:
                errors.append(
                    PolicyValidationError(
                        policy_name=policy_name,
                        error_type=PolicyValidationErrorType.INVALID_SYNTAX,
                        message=f"Failed to parse policy: {e}",
                    )
                )

        if errors:
            return PolicyValidationResponse(
                valid=False,
                errors=errors,
                warnings=[],
            )

        return await self.validate_policies(policies, validate_db=validate_db)