chore: initial snapshot for gitea/github upload
This commit is contained in:
@@ -0,0 +1,107 @@
|
||||
"""
|
||||
Functions to create audit logs for LiteLLM Proxy
|
||||
"""
|
||||
|
||||
import json
|
||||
from litellm._uuid import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import (
|
||||
AUDIT_ACTIONS,
|
||||
LiteLLM_AuditLogs,
|
||||
LitellmTableNames,
|
||||
Optional,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
|
||||
|
||||
async def create_object_audit_log(
|
||||
object_id: str,
|
||||
action: AUDIT_ACTIONS,
|
||||
litellm_changed_by: Optional[str],
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
litellm_proxy_admin_name: Optional[str],
|
||||
table_name: LitellmTableNames,
|
||||
before_value: Optional[str] = None,
|
||||
after_value: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Create an audit log for an internal user.
|
||||
|
||||
Parameters:
|
||||
- user_id: str - The id of the user to create the audit log for.
|
||||
- action: AUDIT_ACTIONS - The action to create the audit log for.
|
||||
- user_row: LiteLLM_UserTable - The user row to create the audit log for.
|
||||
- litellm_changed_by: Optional[str] - The user id of the user who is changing the user.
|
||||
- user_api_key_dict: UserAPIKeyAuth - The user api key dictionary.
|
||||
- litellm_proxy_admin_name: Optional[str] - The name of the proxy admin.
|
||||
"""
|
||||
from litellm.secret_managers.main import get_secret_bool
|
||||
|
||||
store_audit_logs = litellm.store_audit_logs or get_secret_bool(
|
||||
"LITELLM_STORE_AUDIT_LOGS"
|
||||
)
|
||||
|
||||
if store_audit_logs is not True:
|
||||
return
|
||||
|
||||
await create_audit_log_for_update(
|
||||
request_data=LiteLLM_AuditLogs(
|
||||
id=str(uuid.uuid4()),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
changed_by=litellm_changed_by
|
||||
or user_api_key_dict.user_id
|
||||
or litellm_proxy_admin_name,
|
||||
changed_by_api_key=user_api_key_dict.api_key,
|
||||
table_name=table_name,
|
||||
object_id=object_id,
|
||||
action=action,
|
||||
updated_values=after_value,
|
||||
before_value=before_value,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs):
|
||||
"""
|
||||
Create an audit log for an object.
|
||||
"""
|
||||
from litellm.secret_managers.main import get_secret_bool
|
||||
|
||||
store_audit_logs = litellm.store_audit_logs or get_secret_bool(
|
||||
"LITELLM_STORE_AUDIT_LOGS"
|
||||
)
|
||||
if store_audit_logs is not True:
|
||||
return
|
||||
|
||||
from litellm.proxy.proxy_server import premium_user, prisma_client
|
||||
|
||||
if premium_user is not True:
|
||||
return
|
||||
|
||||
if prisma_client is None:
|
||||
raise Exception("prisma_client is None, no DB connected")
|
||||
|
||||
verbose_proxy_logger.debug("creating audit log for %s", request_data)
|
||||
|
||||
if isinstance(request_data.updated_values, dict):
|
||||
request_data.updated_values = json.dumps(request_data.updated_values)
|
||||
|
||||
if isinstance(request_data.before_value, dict):
|
||||
request_data.before_value = json.dumps(request_data.before_value)
|
||||
|
||||
_request_data = request_data.model_dump(exclude_none=True)
|
||||
|
||||
try:
|
||||
await prisma_client.db.litellm_auditlog.create(
|
||||
data={
|
||||
**_request_data, # type: ignore
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
# [Non-Blocking Exception. Do not allow blocking LLM API call]
|
||||
verbose_proxy_logger.error(f"Failed Creating audit log {e}")
|
||||
|
||||
return
|
||||
@@ -0,0 +1,366 @@
|
||||
"""
|
||||
Common utility functions for handling object permission updates across
|
||||
organizations, teams, and keys.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_ObjectPermissionTable,
|
||||
LiteLLM_TeamTableCachedObj,
|
||||
)
|
||||
|
||||
|
||||
async def attach_object_permission_to_dict(
|
||||
data_dict: Dict,
|
||||
prisma_client: PrismaClient,
|
||||
) -> Dict:
|
||||
"""
|
||||
Helper method to attach object_permission to a dictionary if object_permission_id is set.
|
||||
|
||||
This function:
|
||||
1. Checks if the dictionary has an object_permission_id
|
||||
2. If found, queries the database for the corresponding object permission
|
||||
3. Converts the object permission to a dictionary format
|
||||
4. Attaches it to the input dictionary under the 'object_permission' key
|
||||
|
||||
Args:
|
||||
data_dict: The dictionary to attach object_permission to
|
||||
prisma_client: The database client
|
||||
|
||||
Returns:
|
||||
Dict: The input dictionary with object_permission attached if found
|
||||
|
||||
Raises:
|
||||
ValueError: If prisma_client is None
|
||||
"""
|
||||
if prisma_client is None:
|
||||
raise ValueError("Prisma client not found")
|
||||
|
||||
object_permission_id = data_dict.get("object_permission_id")
|
||||
if object_permission_id:
|
||||
object_permission = (
|
||||
await prisma_client.db.litellm_objectpermissiontable.find_unique(
|
||||
where={"object_permission_id": object_permission_id},
|
||||
)
|
||||
)
|
||||
if object_permission:
|
||||
# Convert to dict if needed
|
||||
try:
|
||||
object_permission = object_permission.model_dump()
|
||||
except Exception:
|
||||
object_permission = object_permission.dict()
|
||||
data_dict["object_permission"] = object_permission
|
||||
return data_dict
|
||||
|
||||
|
||||
async def handle_update_object_permission_common(
|
||||
data_json: Dict,
|
||||
existing_object_permission_id: Optional[str],
|
||||
prisma_client: Optional[PrismaClient],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Common logic for handling object permission updates across organizations, teams, and keys.
|
||||
|
||||
This function:
|
||||
1. Extracts `object_permission` from data_json
|
||||
2. Looks up existing object permission if it exists
|
||||
3. Merges new permissions with existing ones
|
||||
4. Upserts to the LiteLLM_ObjectPermissionTable
|
||||
5. Returns the object_permission_id
|
||||
|
||||
Args:
|
||||
data_json: The data dictionary containing the object_permission to update
|
||||
existing_object_permission_id: The current object_permission_id from the entity (can be None)
|
||||
prisma_client: The database client
|
||||
|
||||
Returns:
|
||||
Optional[str]: The object_permission_id after the update/creation, or None if no object_permission to process
|
||||
|
||||
Raises:
|
||||
ValueError: If prisma_client is None
|
||||
"""
|
||||
if prisma_client is None:
|
||||
raise ValueError("Prisma client not found")
|
||||
|
||||
#########################################################
|
||||
# Ensure `object_permission` is not added to the data_json
|
||||
# We need to update the entity at the object_permission_id level in the LiteLLM_ObjectPermissionTable
|
||||
#########################################################
|
||||
new_object_permission: Union[dict, str] = data_json.pop("object_permission", None)
|
||||
if new_object_permission is None:
|
||||
return None
|
||||
|
||||
# Lookup existing object permission ID and update that entry
|
||||
object_permission_id_to_use: str = existing_object_permission_id or str(
|
||||
uuid.uuid4()
|
||||
)
|
||||
existing_object_permissions_dict: Dict = {}
|
||||
|
||||
existing_object_permission = (
|
||||
await prisma_client.db.litellm_objectpermissiontable.find_unique(
|
||||
where={"object_permission_id": object_permission_id_to_use},
|
||||
)
|
||||
)
|
||||
|
||||
# Update the object permission
|
||||
if existing_object_permission is not None:
|
||||
existing_object_permissions_dict = existing_object_permission.model_dump(
|
||||
exclude_unset=True, exclude_none=True
|
||||
)
|
||||
|
||||
# Handle string JSON object permission
|
||||
if isinstance(new_object_permission, str):
|
||||
new_object_permission = json.loads(new_object_permission)
|
||||
|
||||
if isinstance(new_object_permission, dict):
|
||||
existing_object_permissions_dict.update(new_object_permission)
|
||||
|
||||
#########################################################
|
||||
# Serialize mcp_tool_permissions JSON field to avoid GraphQL parsing issues
|
||||
# (e.g., server IDs starting with "3e64" being interpreted as floats)
|
||||
#########################################################
|
||||
if "mcp_tool_permissions" in existing_object_permissions_dict:
|
||||
existing_object_permissions_dict["mcp_tool_permissions"] = safe_dumps(
|
||||
existing_object_permissions_dict["mcp_tool_permissions"]
|
||||
)
|
||||
|
||||
#########################################################
|
||||
# Commit the update to the LiteLLM_ObjectPermissionTable
|
||||
#########################################################
|
||||
created_object_permission_row = (
|
||||
await prisma_client.db.litellm_objectpermissiontable.upsert(
|
||||
where={"object_permission_id": object_permission_id_to_use},
|
||||
data={
|
||||
"create": existing_object_permissions_dict,
|
||||
"update": existing_object_permissions_dict,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
f"created_object_permission_row: {created_object_permission_row}"
|
||||
)
|
||||
|
||||
return created_object_permission_row.object_permission_id
|
||||
|
||||
|
||||
async def _set_object_permission(
|
||||
data_json: dict,
|
||||
prisma_client: Optional[PrismaClient],
|
||||
):
|
||||
"""
|
||||
Creates the LiteLLM_ObjectPermissionTable record for the key/team.
|
||||
Handles permissions for vector stores and mcp servers.
|
||||
"""
|
||||
if prisma_client is None or "object_permission" not in data_json:
|
||||
return data_json
|
||||
|
||||
permission_data = data_json["object_permission"]
|
||||
if not isinstance(permission_data, dict):
|
||||
data_json.pop("object_permission")
|
||||
return data_json
|
||||
|
||||
# Clean data: exclude None values and object_permission_id
|
||||
clean_data = {
|
||||
k: v
|
||||
for k, v in permission_data.items()
|
||||
if v is not None and k != "object_permission_id"
|
||||
}
|
||||
|
||||
# Serialize mcp_tool_permissions to JSON string for GraphQL compatibility
|
||||
if "mcp_tool_permissions" in clean_data:
|
||||
clean_data["mcp_tool_permissions"] = safe_dumps(
|
||||
clean_data["mcp_tool_permissions"]
|
||||
)
|
||||
|
||||
created_permission = await prisma_client.db.litellm_objectpermissiontable.create(
|
||||
data=clean_data
|
||||
)
|
||||
|
||||
data_json["object_permission_id"] = created_permission.object_permission_id
|
||||
data_json.pop("object_permission")
|
||||
return data_json
|
||||
|
||||
|
||||
async def _resolve_team_allowed_mcp_servers(
|
||||
team_object_permission: "LiteLLM_ObjectPermissionTable",
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Resolve the full set of MCP server IDs a team has access to.
|
||||
|
||||
Combines:
|
||||
- Direct mcp_servers list
|
||||
- Servers from mcp_access_groups
|
||||
- Server IDs referenced in mcp_tool_permissions keys
|
||||
"""
|
||||
from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import (
|
||||
MCPRequestHandler,
|
||||
)
|
||||
|
||||
direct_servers: List[str] = team_object_permission.mcp_servers or []
|
||||
access_group_servers: List[
|
||||
str
|
||||
] = await MCPRequestHandler._get_mcp_servers_from_access_groups(
|
||||
team_object_permission.mcp_access_groups or []
|
||||
)
|
||||
raw_tool_perms = team_object_permission.mcp_tool_permissions or {}
|
||||
if isinstance(raw_tool_perms, str):
|
||||
raw_tool_perms = json.loads(raw_tool_perms)
|
||||
tool_perm_servers: List[str] = list(raw_tool_perms.keys())
|
||||
return set(direct_servers + access_group_servers + tool_perm_servers)
|
||||
|
||||
|
||||
def _get_allow_all_keys_server_ids() -> Set[str]:
|
||||
"""Return the set of MCP server IDs marked with allow_all_keys=True."""
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
global_mcp_server_manager,
|
||||
)
|
||||
|
||||
return set(global_mcp_server_manager.get_allow_all_keys_server_ids())
|
||||
|
||||
|
||||
async def _get_team_allowed_mcp_servers(
|
||||
team_obj: Optional["LiteLLM_TeamTableCachedObj"],
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Get the full set of MCP server IDs a team allows.
|
||||
|
||||
If team has no object_permission or no MCP config, returns empty set
|
||||
(meaning only allow_all_keys servers are permitted).
|
||||
"""
|
||||
if team_obj is None:
|
||||
return set()
|
||||
|
||||
team_object_permission = team_obj.object_permission
|
||||
if team_object_permission is None:
|
||||
return set()
|
||||
|
||||
return await _resolve_team_allowed_mcp_servers(team_object_permission)
|
||||
|
||||
|
||||
def _extract_requested_mcp_server_ids(
|
||||
object_permission: Optional[dict],
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Extract all MCP server IDs referenced in a key's object_permission dict.
|
||||
|
||||
Includes:
|
||||
- mcp_servers list
|
||||
- Keys from mcp_tool_permissions
|
||||
"""
|
||||
if not object_permission or not isinstance(object_permission, dict):
|
||||
return set()
|
||||
|
||||
server_ids: Set[str] = set()
|
||||
mcp_servers = object_permission.get("mcp_servers")
|
||||
if isinstance(mcp_servers, list):
|
||||
server_ids.update(mcp_servers)
|
||||
|
||||
mcp_tool_permissions = object_permission.get("mcp_tool_permissions")
|
||||
if isinstance(mcp_tool_permissions, dict):
|
||||
server_ids.update(mcp_tool_permissions.keys())
|
||||
|
||||
return server_ids
|
||||
|
||||
|
||||
def _extract_requested_mcp_access_groups(
|
||||
object_permission: Optional[dict],
|
||||
) -> Set[str]:
|
||||
"""Extract MCP access groups from a key's object_permission dict."""
|
||||
if not object_permission or not isinstance(object_permission, dict):
|
||||
return set()
|
||||
|
||||
groups = object_permission.get("mcp_access_groups")
|
||||
if isinstance(groups, list):
|
||||
return set(groups)
|
||||
return set()
|
||||
|
||||
|
||||
async def validate_key_mcp_servers_against_team(
|
||||
object_permission: Optional[dict],
|
||||
team_obj: Optional["LiteLLM_TeamTableCachedObj"],
|
||||
):
|
||||
"""
|
||||
Validate that MCP servers requested on a key are within the allowed scope.
|
||||
|
||||
Rules:
|
||||
- If key is in a team: key's mcp_servers must be a subset of
|
||||
(team's allowed servers + allow_all_keys servers)
|
||||
- If key is NOT in a team: key's mcp_servers must only contain
|
||||
allow_all_keys servers
|
||||
- If team has no MCP config: key can only use allow_all_keys servers
|
||||
|
||||
Raises HTTPException(403) if validation fails.
|
||||
"""
|
||||
requested_servers = _extract_requested_mcp_server_ids(object_permission)
|
||||
requested_access_groups = _extract_requested_mcp_access_groups(object_permission)
|
||||
|
||||
# Nothing to validate
|
||||
if not requested_servers and not requested_access_groups:
|
||||
return
|
||||
|
||||
allow_all_keys_servers = _get_allow_all_keys_server_ids()
|
||||
team_allowed_servers = await _get_team_allowed_mcp_servers(team_obj)
|
||||
|
||||
# Combined allowed set = team servers + allow_all_keys servers
|
||||
all_allowed_servers = team_allowed_servers | allow_all_keys_servers
|
||||
|
||||
# Validate requested server IDs
|
||||
if requested_servers:
|
||||
disallowed_servers = requested_servers - all_allowed_servers
|
||||
if disallowed_servers:
|
||||
if team_obj is not None:
|
||||
detail = (
|
||||
f"Key requests MCP servers not allowed by team '{team_obj.team_id}': "
|
||||
f"{sorted(disallowed_servers)}. "
|
||||
f"Team allows: {sorted(team_allowed_servers)}. "
|
||||
f"Global (allow_all_keys) servers: {sorted(allow_all_keys_servers)}."
|
||||
)
|
||||
else:
|
||||
detail = (
|
||||
f"Key is not in a team. Only globally available (allow_all_keys) MCP servers "
|
||||
f"can be assigned: {sorted(allow_all_keys_servers)}. "
|
||||
f"Disallowed servers: {sorted(disallowed_servers)}."
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={"error": detail},
|
||||
)
|
||||
|
||||
# Validate requested access groups (must be subset of team's access groups)
|
||||
if requested_access_groups:
|
||||
team_access_groups: Set[str] = set()
|
||||
if (
|
||||
team_obj is not None
|
||||
and team_obj.object_permission is not None
|
||||
and team_obj.object_permission.mcp_access_groups
|
||||
):
|
||||
team_access_groups = set(team_obj.object_permission.mcp_access_groups)
|
||||
|
||||
disallowed_groups = requested_access_groups - team_access_groups
|
||||
if disallowed_groups:
|
||||
if team_obj is not None:
|
||||
detail = (
|
||||
f"Key requests MCP access groups not allowed by team '{team_obj.team_id}': "
|
||||
f"{sorted(disallowed_groups)}. "
|
||||
f"Team allows: {sorted(team_access_groups)}."
|
||||
)
|
||||
else:
|
||||
detail = (
|
||||
f"Key is not in a team. MCP access groups cannot be assigned to "
|
||||
f"keys outside of a team. Disallowed groups: {sorted(disallowed_groups)}."
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={"error": detail},
|
||||
)
|
||||
@@ -0,0 +1,179 @@
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.caching import DualCache
|
||||
from litellm.proxy._types import (
|
||||
KeyManagementRoutes,
|
||||
LiteLLM_TeamTableCachedObj,
|
||||
LiteLLM_VerificationToken,
|
||||
LiteLLMRoutes,
|
||||
LitellmUserRoles,
|
||||
Member,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.auth_checks import get_team_object
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
DEFAULT_TEAM_MEMBER_PERMISSIONS = [
|
||||
KeyManagementRoutes.KEY_INFO,
|
||||
KeyManagementRoutes.KEY_HEALTH,
|
||||
]
|
||||
|
||||
|
||||
class TeamMemberPermissionChecks:
|
||||
@staticmethod
|
||||
def get_permissions_for_team_member(
|
||||
team_member_object: Member,
|
||||
team_table: LiteLLM_TeamTableCachedObj,
|
||||
) -> List[KeyManagementRoutes]:
|
||||
"""
|
||||
Returns the permissions for a team member
|
||||
"""
|
||||
if team_table.team_member_permissions and isinstance(
|
||||
team_table.team_member_permissions, list
|
||||
):
|
||||
return [
|
||||
KeyManagementRoutes(permission)
|
||||
for permission in team_table.team_member_permissions
|
||||
]
|
||||
|
||||
return DEFAULT_TEAM_MEMBER_PERMISSIONS
|
||||
|
||||
@staticmethod
|
||||
def _get_list_of_route_enum_as_str(
|
||||
route_enum: List[KeyManagementRoutes],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Returns a list of the route enum as a list of strings
|
||||
"""
|
||||
return [route.value for route in route_enum]
|
||||
|
||||
@staticmethod
|
||||
async def can_team_member_execute_key_management_endpoint(
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
route: KeyManagementRoutes,
|
||||
prisma_client: PrismaClient,
|
||||
user_api_key_cache: DualCache,
|
||||
existing_key_row: LiteLLM_VerificationToken,
|
||||
):
|
||||
"""
|
||||
Main handler for checking if a team member can update a key
|
||||
"""
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
_get_user_in_team,
|
||||
)
|
||||
|
||||
# 1. Don't execute these checks if the user role is proxy admin
|
||||
if user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value:
|
||||
return
|
||||
|
||||
# 2. Check if the operation is being done on a team key
|
||||
if existing_key_row.team_id is None:
|
||||
return
|
||||
|
||||
# 3. Get Team Object from DB
|
||||
team_table = await get_team_object(
|
||||
team_id=existing_key_row.team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
check_db_only=True,
|
||||
)
|
||||
|
||||
# 4. Extract `Member` object from `team_table`
|
||||
key_assigned_user_in_team = _get_user_in_team(
|
||||
team_table=team_table, user_id=user_api_key_dict.user_id
|
||||
)
|
||||
|
||||
# 5. Check if the team member has permissions for the endpoint
|
||||
TeamMemberPermissionChecks.does_team_member_have_permissions_for_endpoint(
|
||||
team_member_object=key_assigned_user_in_team,
|
||||
team_table=team_table,
|
||||
route=route,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def does_team_member_have_permissions_for_endpoint(
|
||||
team_member_object: Optional[Member],
|
||||
team_table: LiteLLM_TeamTableCachedObj,
|
||||
route: str,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
Raises an exception if the team member does not have permissions for calling the endpoint for a team
|
||||
"""
|
||||
|
||||
# permission checks only run for non-admin users
|
||||
# Non-Admin user trying to access information about a team's key
|
||||
if team_member_object is None:
|
||||
return False
|
||||
if team_member_object.role == "admin":
|
||||
return True
|
||||
|
||||
_team_member_permissions = (
|
||||
TeamMemberPermissionChecks.get_permissions_for_team_member(
|
||||
team_member_object=team_member_object,
|
||||
team_table=team_table,
|
||||
)
|
||||
)
|
||||
team_member_permissions = (
|
||||
TeamMemberPermissionChecks._get_list_of_route_enum_as_str(
|
||||
_team_member_permissions
|
||||
)
|
||||
)
|
||||
|
||||
if not RouteChecks.check_route_access(
|
||||
route=route, allowed_routes=team_member_permissions
|
||||
):
|
||||
raise ProxyException(
|
||||
message=f"Team member does not have permissions for endpoint: {route}. You only have access to the following endpoints: {team_member_permissions} for team {team_table.team_id}. To create keys for this team, please ask your proxy admin to check the team member permission settings and update the settings to allow team member users to create keys.",
|
||||
type=ProxyErrorTypes.team_member_permission_error,
|
||||
param=route,
|
||||
code=401,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def user_belongs_to_keys_team(
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
existing_key_row: LiteLLM_VerificationToken,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if the user belongs to the team that the key is assigned to
|
||||
"""
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
_get_user_in_team,
|
||||
)
|
||||
from litellm.proxy.proxy_server import prisma_client, user_api_key_cache
|
||||
|
||||
if existing_key_row.team_id is None:
|
||||
return False
|
||||
team_table = await get_team_object(
|
||||
team_id=existing_key_row.team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||
check_db_only=True,
|
||||
)
|
||||
|
||||
# 4. Extract `Member` object from `team_table`
|
||||
team_member_object = _get_user_in_team(
|
||||
team_table=team_table, user_id=user_api_key_dict.user_id
|
||||
)
|
||||
return team_member_object is not None
|
||||
|
||||
@staticmethod
|
||||
def get_all_available_team_member_permissions() -> List[str]:
|
||||
"""
|
||||
Returns all available team member permissions
|
||||
"""
|
||||
all_available_permissions = []
|
||||
for route in LiteLLMRoutes.key_management_routes.value:
|
||||
all_available_permissions.append(route)
|
||||
return all_available_permissions
|
||||
|
||||
@staticmethod
|
||||
def default_team_member_permissions() -> List[str]:
|
||||
return [route.value for route in DEFAULT_TEAM_MEMBER_PERMISSIONS]
|
||||
@@ -0,0 +1,47 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import CommonProxyErrors, InvitationNew, UserAPIKeyAuth
|
||||
|
||||
|
||||
async def create_invitation_for_user(
|
||||
data: InvitationNew,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
):
|
||||
"""
|
||||
Create an invitation for the user to onboard to LiteLLM Admin UI.
|
||||
"""
|
||||
from litellm.proxy.proxy_server import litellm_proxy_admin_name, prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": CommonProxyErrors.db_not_connected_error.value},
|
||||
)
|
||||
|
||||
current_time = litellm.utils.get_utc_datetime()
|
||||
expires_at = current_time + timedelta(days=7)
|
||||
|
||||
try:
|
||||
response = await prisma_client.db.litellm_invitationlink.create(
|
||||
data={
|
||||
"user_id": data.user_id,
|
||||
"created_at": current_time,
|
||||
"expires_at": expires_at,
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
"updated_at": current_time,
|
||||
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
} # type: ignore
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
if "Foreign key constraint failed on the field" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "User id does not exist in 'LiteLLM_UserTable'. Fix this by creating user via `/user/new`."
|
||||
},
|
||||
)
|
||||
raise HTTPException(status_code=500, detail={"error": str(e)})
|
||||
@@ -0,0 +1,459 @@
|
||||
# What is this?
|
||||
## Helper utils for the management endpoints (keys/users/teams)
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.proxy._types import ( # key request types; user request types; team request types; customer request types
|
||||
BudgetNewRequest,
|
||||
DeleteCustomerRequest,
|
||||
DeleteTeamRequest,
|
||||
DeleteUserRequest,
|
||||
KeyRequest,
|
||||
LiteLLM_BudgetTable,
|
||||
LiteLLM_TeamMembership,
|
||||
LiteLLM_UserTable,
|
||||
ManagementEndpointLoggingPayload,
|
||||
Member,
|
||||
SSOUserDefinedValues,
|
||||
UpdateCustomerRequest,
|
||||
UpdateKeyRequest,
|
||||
UpdateTeamRequest,
|
||||
UpdateUserRequest,
|
||||
UserAPIKeyAuth,
|
||||
VirtualKeyEvent,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
|
||||
def get_new_internal_user_defaults(
|
||||
user_id: str, user_email: Optional[str] = None
|
||||
) -> dict:
|
||||
user_info = litellm.default_internal_user_params or {}
|
||||
|
||||
returned_dict: SSOUserDefinedValues = {
|
||||
"models": user_info.get("models") or [],
|
||||
"max_budget": user_info.get("max_budget", litellm.max_internal_user_budget),
|
||||
"budget_duration": user_info.get(
|
||||
"budget_duration", litellm.internal_user_budget_duration
|
||||
),
|
||||
"user_email": user_email or user_info.get("user_email", None),
|
||||
"user_id": user_id,
|
||||
"user_role": "internal_user",
|
||||
}
|
||||
|
||||
non_null_dict = {}
|
||||
for k, v in returned_dict.items():
|
||||
if v is not None:
|
||||
non_null_dict[k] = v
|
||||
return non_null_dict
|
||||
|
||||
|
||||
async def handle_budget_for_entity(
|
||||
data,
|
||||
existing_budget_id: Optional[str],
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
prisma_client: PrismaClient,
|
||||
litellm_proxy_admin_name: str,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Common helper to handle budget creation/updates for entities (organizations, tags, etc).
|
||||
|
||||
This function:
|
||||
1. Creates a new budget if budget_id is None but budget fields are provided
|
||||
2. Updates an existing budget if budget fields are provided and budget_id exists
|
||||
3. Returns the budget_id to use (existing or newly created)
|
||||
|
||||
Args:
|
||||
data: The request object (e.g., TagNewRequest, NewOrganizationRequest, etc.) containing budget fields
|
||||
existing_budget_id: The existing budget_id if updating an entity, None if creating new
|
||||
user_api_key_dict: User authentication info
|
||||
prisma_client: Database client
|
||||
litellm_proxy_admin_name: Admin name for audit trail
|
||||
|
||||
Returns:
|
||||
Optional[str]: The budget_id to use, or None if no budget was created/updated
|
||||
"""
|
||||
from litellm.proxy.management_endpoints.budget_management_endpoints import (
|
||||
update_budget,
|
||||
)
|
||||
|
||||
# Get all budget field names
|
||||
budget_params = LiteLLM_BudgetTable.model_fields.keys()
|
||||
|
||||
# Extract budget fields from data
|
||||
_json_data = (
|
||||
data.model_dump(exclude_none=True) if hasattr(data, "model_dump") else data
|
||||
)
|
||||
_budget_data = {k: v for k, v in _json_data.items() if k in budget_params}
|
||||
|
||||
# Check if budget_id is explicitly provided in the data
|
||||
data_budget_id = getattr(data, "budget_id", None)
|
||||
|
||||
# Case 1: Creating new entity - no existing budget_id
|
||||
if existing_budget_id is None:
|
||||
if data_budget_id is not None:
|
||||
# Use the provided budget_id
|
||||
return data_budget_id
|
||||
elif _budget_data:
|
||||
# Create a new budget with the provided fields
|
||||
budget_row = LiteLLM_BudgetTable(**_budget_data)
|
||||
new_budget_data = prisma_client.jsonify_object(
|
||||
budget_row.model_dump(exclude_none=True)
|
||||
)
|
||||
|
||||
_budget = await prisma_client.db.litellm_budgettable.create(
|
||||
data={
|
||||
**new_budget_data, # type: ignore
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
}
|
||||
) # type: ignore
|
||||
|
||||
return _budget.budget_id
|
||||
else:
|
||||
# No budget fields provided, no budget to create
|
||||
return None
|
||||
|
||||
# Case 2: Updating existing entity - has existing budget_id
|
||||
else:
|
||||
# If budget fields are provided, update the existing budget
|
||||
if _budget_data:
|
||||
await update_budget(
|
||||
budget_obj=BudgetNewRequest(
|
||||
budget_id=existing_budget_id, **_budget_data
|
||||
),
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
# If a different budget_id is explicitly provided, use that instead
|
||||
if data_budget_id is not None and data_budget_id != existing_budget_id:
|
||||
return data_budget_id
|
||||
|
||||
# Otherwise, keep using the existing budget_id
|
||||
return existing_budget_id
|
||||
|
||||
|
||||
async def add_new_member(
|
||||
new_member: Member,
|
||||
max_budget_in_team: Optional[float],
|
||||
prisma_client: PrismaClient,
|
||||
team_id: str,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
litellm_proxy_admin_name: str,
|
||||
default_team_budget_id: Optional[str] = None,
|
||||
) -> Tuple[LiteLLM_UserTable, Optional[LiteLLM_TeamMembership]]:
|
||||
"""
|
||||
Add a new member to a team
|
||||
|
||||
- add team id to user table
|
||||
- add team member w/ budget to team member table
|
||||
|
||||
Returns created/existing user + team membership w/ budget id
|
||||
"""
|
||||
returned_user: Optional[LiteLLM_UserTable] = None
|
||||
returned_team_membership: Optional[LiteLLM_TeamMembership] = None
|
||||
## ADD TEAM ID, to USER TABLE IF NEW ##
|
||||
if new_member.user_id is not None:
|
||||
new_user_defaults = get_new_internal_user_defaults(user_id=new_member.user_id)
|
||||
_returned_user = await prisma_client.db.litellm_usertable.upsert(
|
||||
where={"user_id": new_member.user_id},
|
||||
data={
|
||||
"update": {"teams": {"push": [team_id]}},
|
||||
"create": {"teams": [team_id], **new_user_defaults}, # type: ignore
|
||||
},
|
||||
)
|
||||
if _returned_user is not None:
|
||||
returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
|
||||
elif new_member.user_email is not None:
|
||||
new_user_defaults = get_new_internal_user_defaults(
|
||||
user_id=str(uuid.uuid4()), user_email=new_member.user_email
|
||||
)
|
||||
## user email is not unique acc. to prisma schema -> future improvement
|
||||
### for now: check if it exists in db, if not - insert it
|
||||
existing_user_row: Optional[list] = await prisma_client.get_data(
|
||||
key_val={"user_email": new_member.user_email},
|
||||
table_name="user",
|
||||
query_type="find_all",
|
||||
)
|
||||
if existing_user_row is None or (
|
||||
isinstance(existing_user_row, list) and len(existing_user_row) == 0
|
||||
):
|
||||
new_user_defaults["teams"] = [team_id]
|
||||
_returned_user = await prisma_client.insert_data(data=new_user_defaults, table_name="user") # type: ignore
|
||||
|
||||
if _returned_user is not None:
|
||||
returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
|
||||
elif len(existing_user_row) == 1:
|
||||
user_info = existing_user_row[0]
|
||||
_returned_user = await prisma_client.db.litellm_usertable.update(
|
||||
where={"user_id": user_info.user_id}, # type: ignore
|
||||
data={"teams": {"push": [team_id]}},
|
||||
)
|
||||
if _returned_user is not None:
|
||||
returned_user = LiteLLM_UserTable(**_returned_user.model_dump())
|
||||
elif len(existing_user_row) > 1:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Multiple users with this email found in db. Please use 'user_id' instead."
|
||||
},
|
||||
)
|
||||
|
||||
# Check if trying to set a budget for team member
|
||||
|
||||
if max_budget_in_team is not None:
|
||||
# create a new budget item for this member
|
||||
response = await prisma_client.db.litellm_budgettable.create(
|
||||
data={
|
||||
"max_budget": max_budget_in_team,
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
}
|
||||
)
|
||||
|
||||
_budget_id = response.budget_id
|
||||
else:
|
||||
_budget_id = default_team_budget_id
|
||||
|
||||
if _budget_id and returned_user is not None and returned_user.user_id is not None:
|
||||
_returned_team_membership = (
|
||||
await prisma_client.db.litellm_teammembership.create(
|
||||
data={
|
||||
"team_id": team_id,
|
||||
"user_id": returned_user.user_id,
|
||||
"budget_id": _budget_id,
|
||||
},
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
)
|
||||
|
||||
returned_team_membership = LiteLLM_TeamMembership(
|
||||
**_returned_team_membership.model_dump()
|
||||
)
|
||||
|
||||
if returned_user is None:
|
||||
raise Exception("Unable to update user table with membership information!")
|
||||
|
||||
return returned_user, returned_team_membership
|
||||
|
||||
|
||||
def _delete_user_id_from_cache(kwargs):
|
||||
from litellm.proxy.proxy_server import user_api_key_cache
|
||||
|
||||
if kwargs.get("data") is not None:
|
||||
update_user_request = kwargs.get("data")
|
||||
if isinstance(update_user_request, UpdateUserRequest):
|
||||
user_api_key_cache.delete_cache(key=update_user_request.user_id)
|
||||
|
||||
# delete user request
|
||||
if isinstance(update_user_request, DeleteUserRequest):
|
||||
for user_id in update_user_request.user_ids:
|
||||
user_api_key_cache.delete_cache(key=user_id)
|
||||
pass
|
||||
|
||||
|
||||
def _delete_api_key_from_cache(kwargs):
|
||||
from litellm.proxy.proxy_server import user_api_key_cache
|
||||
|
||||
if kwargs.get("data") is not None:
|
||||
update_request = kwargs.get("data")
|
||||
if isinstance(update_request, UpdateKeyRequest):
|
||||
user_api_key_cache.delete_cache(key=update_request.key)
|
||||
|
||||
# delete key request
|
||||
if isinstance(update_request, KeyRequest) and update_request.keys:
|
||||
for key in update_request.keys:
|
||||
user_api_key_cache.delete_cache(key=key)
|
||||
pass
|
||||
|
||||
|
||||
def _delete_team_id_from_cache(kwargs):
|
||||
from litellm.proxy.proxy_server import user_api_key_cache
|
||||
|
||||
if kwargs.get("data") is not None:
|
||||
update_request = kwargs.get("data")
|
||||
if isinstance(update_request, UpdateTeamRequest):
|
||||
user_api_key_cache.delete_cache(key=update_request.team_id)
|
||||
|
||||
# delete team request
|
||||
if isinstance(update_request, DeleteTeamRequest):
|
||||
for team_id in update_request.team_ids:
|
||||
user_api_key_cache.delete_cache(key=team_id)
|
||||
pass
|
||||
|
||||
|
||||
def _delete_customer_id_from_cache(kwargs):
|
||||
from litellm.proxy.proxy_server import user_api_key_cache
|
||||
|
||||
if kwargs.get("data") is not None:
|
||||
update_request = kwargs.get("data")
|
||||
if isinstance(update_request, UpdateCustomerRequest):
|
||||
user_api_key_cache.delete_cache(key=update_request.user_id)
|
||||
|
||||
# delete customer request
|
||||
if isinstance(update_request, DeleteCustomerRequest):
|
||||
for user_id in update_request.user_ids:
|
||||
user_api_key_cache.delete_cache(key=user_id)
|
||||
pass
|
||||
|
||||
|
||||
async def send_management_endpoint_alert(
|
||||
request_kwargs: dict,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
function_name: str,
|
||||
):
|
||||
"""
|
||||
Sends a slack alert when:
|
||||
- A virtual key is created, updated, or deleted
|
||||
- An internal user is created, updated, or deleted
|
||||
- A team is created, updated, or deleted
|
||||
"""
|
||||
from litellm.proxy.proxy_server import proxy_logging_obj
|
||||
from litellm.types.integrations.slack_alerting import AlertType
|
||||
|
||||
management_function_to_event_name = {
|
||||
"generate_key_fn": AlertType.new_virtual_key_created,
|
||||
"update_key_fn": AlertType.virtual_key_updated,
|
||||
"delete_key_fn": AlertType.virtual_key_deleted,
|
||||
# Team events
|
||||
"new_team": AlertType.new_team_created,
|
||||
"update_team": AlertType.team_updated,
|
||||
"delete_team": AlertType.team_deleted,
|
||||
# Internal User events
|
||||
"new_user": AlertType.new_internal_user_created,
|
||||
"user_update": AlertType.internal_user_updated,
|
||||
"delete_user": AlertType.internal_user_deleted,
|
||||
}
|
||||
|
||||
# Check if alerting is enabled
|
||||
if (
|
||||
proxy_logging_obj is not None
|
||||
and proxy_logging_obj.slack_alerting_instance is not None
|
||||
):
|
||||
# Virtual Key Events
|
||||
if function_name in management_function_to_event_name:
|
||||
_event_name: AlertType = management_function_to_event_name[function_name]
|
||||
|
||||
key_event = VirtualKeyEvent(
|
||||
created_by_user_id=user_api_key_dict.user_id or "Unknown",
|
||||
created_by_user_role=user_api_key_dict.user_role or "Unknown",
|
||||
created_by_key_alias=user_api_key_dict.key_alias,
|
||||
request_kwargs=request_kwargs,
|
||||
)
|
||||
|
||||
# replace all "_" with " " and capitalize
|
||||
event_name = _event_name.replace("_", " ").title()
|
||||
await proxy_logging_obj.slack_alerting_instance.send_virtual_key_event_slack(
|
||||
key_event=key_event,
|
||||
event_name=event_name,
|
||||
alert_type=_event_name,
|
||||
)
|
||||
|
||||
|
||||
def management_endpoint_wrapper(func):
|
||||
"""
|
||||
This wrapper does the following:
|
||||
|
||||
1. Log I/O, Exceptions to OTEL
|
||||
2. Create an Audit log for success calls
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
start_time = datetime.now()
|
||||
_http_request: Optional[Request] = None
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
end_time = datetime.now()
|
||||
try:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
user_api_key_dict: UserAPIKeyAuth = (
|
||||
kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
|
||||
)
|
||||
|
||||
await send_management_endpoint_alert(
|
||||
request_kwargs=kwargs,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
function_name=func.__name__,
|
||||
)
|
||||
_http_request = kwargs.get("http_request", None)
|
||||
parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
|
||||
if parent_otel_span is not None:
|
||||
from litellm.proxy.proxy_server import open_telemetry_logger
|
||||
|
||||
if open_telemetry_logger is not None:
|
||||
if _http_request:
|
||||
_route = _http_request.url.path
|
||||
_request_body: dict = await _read_request_body(
|
||||
request=_http_request
|
||||
)
|
||||
_response = dict(result) if result is not None else None
|
||||
|
||||
logging_payload = ManagementEndpointLoggingPayload(
|
||||
route=_route,
|
||||
request_data=_request_body,
|
||||
response=_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
await open_telemetry_logger.async_management_endpoint_success_hook( # type: ignore
|
||||
logging_payload=logging_payload,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
||||
# Delete updated/deleted info from cache
|
||||
_delete_api_key_from_cache(kwargs=kwargs)
|
||||
_delete_user_id_from_cache(kwargs=kwargs)
|
||||
_delete_team_id_from_cache(kwargs=kwargs)
|
||||
_delete_customer_id_from_cache(kwargs=kwargs)
|
||||
except Exception as e:
|
||||
# Non-Blocking Exception
|
||||
verbose_logger.debug("Error in management endpoint wrapper: %s", str(e))
|
||||
pass
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
end_time = datetime.now()
|
||||
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
user_api_key_dict: UserAPIKeyAuth = (
|
||||
kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
|
||||
)
|
||||
parent_otel_span = getattr(user_api_key_dict, "parent_otel_span", None)
|
||||
if parent_otel_span is not None:
|
||||
from litellm.proxy.proxy_server import open_telemetry_logger
|
||||
|
||||
if open_telemetry_logger is not None:
|
||||
_http_request = kwargs.get("http_request")
|
||||
if _http_request:
|
||||
_route = _http_request.url.path
|
||||
_request_body: dict = await _read_request_body(
|
||||
request=_http_request
|
||||
)
|
||||
logging_payload = ManagementEndpointLoggingPayload(
|
||||
route=_route,
|
||||
request_data=_request_body,
|
||||
response=None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
exception=e,
|
||||
)
|
||||
|
||||
await open_telemetry_logger.async_management_endpoint_failure_hook( # type: ignore
|
||||
logging_payload=logging_payload,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
Reference in New Issue
Block a user