import hashlib
import json
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import litellm
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy.management_helpers.object_permission_utils import (
    handle_update_object_permission_common,
)
from litellm.proxy.utils import PrismaClient
from litellm.types.agents import AgentConfig, AgentResponse, PatchAgentRequest


class AgentRegistry:
    """In-memory registry of agents plus helpers that persist agents via Prisma.

    ``agent_list`` is populated from proxy config and/or DB rows (see
    ``load_agents_from_db_and_config``). The ``*_db`` methods operate on the
    ``litellm_agentstable`` table and do not modify ``agent_list``.
    """

    # Optional rate-limit columns that are copied through on create/patch/update.
    _RATE_LIMIT_FIELDS = (
        "tpm_limit",
        "rpm_limit",
        "session_tpm_limit",
        "session_rpm_limit",
    )

    def __init__(self):
        self.agent_list: List[AgentResponse] = []

    def reset_agent_list(self):
        """Drop every agent from the in-memory registry."""
        self.agent_list = []

    def register_agent(self, agent_config: AgentResponse):
        """Append ``agent_config`` to the in-memory registry (no de-duplication)."""
        self.agent_list.append(agent_config)

    def deregister_agent(self, agent_name: str):
        """Remove every registered agent whose ``agent_name`` equals ``agent_name``."""
        self.agent_list = [
            agent for agent in self.agent_list if agent.agent_name != agent_name
        ]

    def get_agent_list(self, agent_names: Optional[List[str]] = None):
        """Return registered agents, optionally filtered by name.

        Args:
            agent_names: If given, only agents whose ``agent_name`` is in this
                list are returned (as a new list).

        Returns:
            The filtered list, or the live ``agent_list`` itself (not a copy)
            when ``agent_names`` is None.
        """
        if agent_names is not None:
            return [
                agent for agent in self.agent_list if agent.agent_name in agent_names
            ]
        return self.agent_list

    def get_public_agent_list(self) -> List[AgentResponse]:
        """Return registered agents whose ``agent_id`` is in ``litellm.public_agent_groups``.

        Returns an empty list when ``litellm.public_agent_groups`` is None.
        """
        public_agent_list: List[AgentResponse] = []
        if litellm.public_agent_groups is None:
            return public_agent_list
        for agent in self.agent_list:
            if agent.agent_id in litellm.public_agent_groups:
                public_agent_list.append(agent)
        return public_agent_list

    def _create_agent_id(self, agent_config: AgentConfig) -> str:
        """Return a stable SHA-256 hex id computed from the config's sorted-key JSON."""
        return hashlib.sha256(
            json.dumps(agent_config, sort_keys=True).encode()
        ).hexdigest()

    def load_agents_from_config(self, agent_config: Optional[List[AgentConfig]] = None):
        """Register config-defined agents, each keyed by a hash of its config.

        Items missing a truthy ``agent_name`` or ``agent_card_params`` are
        skipped. The registry is NOT reset first.

        Raises:
            ValueError: if any item is not a dict.
        """
        if agent_config is None:
            return None

        for agent_config_item in agent_config:
            if not isinstance(agent_config_item, dict):
                raise ValueError("agent_config must be a list of dictionaries")

            agent_name = agent_config_item.get("agent_name")
            agent_card_params = agent_config_item.get("agent_card_params")
            if not all([agent_name, agent_card_params]):
                continue

            # create a stable hash id for config item
            config_hash = self._create_agent_id(agent_config_item)

            self.register_agent(agent_config=AgentResponse(agent_id=config_hash, **agent_config_item))  # type: ignore

    def load_agents_from_db_and_config(
        self,
        agent_config: Optional[List[AgentConfig]] = None,
        db_agents: Optional[List[Dict[str, Any]]] = None,
    ):
        """Rebuild the registry from config agents followed by DB agents.

        Unlike ``load_agents_from_config``, config items are registered without
        the ``agent_name``/``agent_card_params`` presence check.

        Returns:
            The rebuilt ``agent_list``.

        Raises:
            ValueError: if any config item or DB agent is not a dict.
        """
        self.reset_agent_list()

        if agent_config:
            for agent_config_item in agent_config:
                if not isinstance(agent_config_item, dict):
                    raise ValueError("agent_config must be a list of dictionaries")
                self.register_agent(agent_config=AgentResponse(agent_id=self._create_agent_id(agent_config_item), **agent_config_item))  # type: ignore

        if db_agents:
            for db_agent in db_agents:
                if not isinstance(db_agent, dict):
                    raise ValueError("db_agents must be a list of dictionaries")
                self.register_agent(agent_config=AgentResponse(**db_agent))  # type: ignore

        return self.agent_list

    ###########################################################
    ########### DB management helpers for agents ###########
    ############################################################

    @staticmethod
    def _params_to_dict(params_obj: Any) -> Dict[str, Any]:
        """Convert a params value to a plain dict.

        Objects exposing ``model_dump`` are dumped; other truthy values go
        through ``dict()``; falsy values (None, ``{}``) become ``{}``.
        """
        if hasattr(params_obj, "model_dump"):
            return params_obj.model_dump()
        return dict(params_obj) if params_obj else {}

    @staticmethod
    def _object_permission_to_dict(object_permission: Any) -> Dict[str, Any]:
        """Serialize an included ``object_permission`` relation to a dict.

        Tries ``model_dump()`` and falls back to ``.dict()`` if that raises
        (presumably for pydantic-v1-style models — unverified).
        """
        try:
            return object_permission.model_dump()
        except Exception:
            return object_permission.dict()

    @classmethod
    def _agent_row_to_dict(cls, agent_row: Any) -> Dict[str, Any]:
        """``model_dump()`` a Prisma agent row and serialize its ``object_permission``."""
        row_dict = agent_row.model_dump()
        if agent_row.object_permission is not None:
            row_dict["object_permission"] = cls._object_permission_to_dict(
                agent_row.object_permission
            )
        return row_dict

    async def add_agent_to_db(
        self, agent: AgentConfig, prisma_client: PrismaClient, created_by: str
    ) -> AgentResponse:
        """
        Add an agent to the database.

        Args:
            agent: Agent definition; ``litellm_params``/``agent_card_params``
                are stored as JSON strings.
            prisma_client: Prisma client used for the insert.
            created_by: Stored as both ``created_by`` and ``updated_by``.

        Returns:
            The created row (with ``object_permission`` included) as an AgentResponse.

        Raises:
            Exception: wrapping any underlying error ("Error adding agent to DB: ...").
        """
        try:
            agent_name = agent.get("agent_name")

            # Serialize litellm_params and agent_card_params to JSON strings
            litellm_params: str = safe_dumps(
                self._params_to_dict(agent.get("litellm_params", {}))
            )
            agent_card_params: str = safe_dumps(
                self._params_to_dict(agent.get("agent_card_params", {}))
            )

            # Handle object_permission (MCP tool access for agent)
            object_permission_id: Optional[str] = None
            if agent.get("object_permission") is not None:
                agent_copy = dict(agent)
                object_permission_id = await handle_update_object_permission_common(
                    agent_copy, None, prisma_client
                )

            # Serialize static_headers; falsy values are not written at all
            static_headers_obj = agent.get("static_headers")
            static_headers_val: Optional[str] = (
                safe_dumps(dict(static_headers_obj)) if static_headers_obj else None
            )
            extra_headers_val: Optional[List[str]] = agent.get("extra_headers")

            # Take one timestamp so created_at == updated_at on a fresh row.
            now = datetime.now(timezone.utc)
            create_data: Dict[str, Any] = {
                "agent_name": agent_name,
                "litellm_params": litellm_params,
                "agent_card_params": agent_card_params,
                "created_by": created_by,
                "updated_by": created_by,
                "created_at": now,
                "updated_at": now,
            }
            if static_headers_val is not None:
                create_data["static_headers"] = static_headers_val
            if extra_headers_val is not None:
                create_data["extra_headers"] = extra_headers_val
            if object_permission_id is not None:
                create_data["object_permission_id"] = object_permission_id
            for rate_field in self._RATE_LIMIT_FIELDS:
                rate_value = agent.get(rate_field)
                if rate_value is not None:
                    create_data[rate_field] = rate_value

            # Create agent in DB
            created_agent = await prisma_client.db.litellm_agentstable.create(
                data=create_data,
                include={"object_permission": True},
            )
            return AgentResponse(**self._agent_row_to_dict(created_agent))  # type: ignore
        except Exception as e:
            raise Exception(f"Error adding agent to DB: {str(e)}") from e

    async def delete_agent_from_db(
        self, agent_id: str, prisma_client: PrismaClient
    ) -> Dict[str, Any]:
        """
        Delete an agent from the database.

        Returns:
            The deleted row converted with ``dict()``.

        Raises:
            Exception: wrapping any underlying error ("Error deleting agent from DB: ...").
        """
        try:
            deleted_agent = await prisma_client.db.litellm_agentstable.delete(
                where={"agent_id": agent_id}
            )
            return dict(deleted_agent)
        except Exception as e:
            raise Exception(f"Error deleting agent from DB: {str(e)}") from e

    async def patch_agent_in_db(
        self,
        agent_id: str,
        agent: PatchAgentRequest,
        prisma_client: PrismaClient,
        updated_by: str,
    ) -> AgentResponse:
        """
        Patch an agent in the database.

        Get the existing agent from the database and patch it with the new values.

        Args:
            agent_id: The ID of the agent to patch
            agent: The new agent values to patch
            prisma_client: The Prisma client to use
            updated_by: The user ID of the user who is patching the agent

        Returns:
            The patched agent

        Raises:
            Exception: wrapping any error, including a missing agent
                ("Error patching agent in DB: Agent with ID ... not found").
        """
        try:
            existing_row = await prisma_client.db.litellm_agentstable.find_unique(
                where={"agent_id": agent_id}
            )
            if existing_row is None:
                raise Exception(f"Agent with ID {agent_id} not found")
            existing_agent = dict(existing_row)

            # Patch values override stored values.
            augment_agent = {**existing_agent, **agent}

            update_data: Dict[str, Any] = {}
            if augment_agent.get("agent_name"):
                update_data["agent_name"] = augment_agent.get("agent_name")
            if augment_agent.get("litellm_params"):
                update_data["litellm_params"] = safe_dumps(
                    augment_agent.get("litellm_params")
                )
            if augment_agent.get("agent_card_params"):
                update_data["agent_card_params"] = safe_dumps(
                    augment_agent.get("agent_card_params")
                )

            # Only keys explicitly present in the patch are written; an explicit
            # None is written through.
            for rate_field in self._RATE_LIMIT_FIELDS:
                if rate_field in agent:
                    update_data[rate_field] = agent.get(rate_field)
            if "static_headers" in agent:
                headers_value = agent.get("static_headers")
                update_data["static_headers"] = safe_dumps(
                    dict(headers_value) if headers_value is not None else {}
                )
            if "extra_headers" in agent:
                extra_headers_value = agent.get("extra_headers")
                update_data["extra_headers"] = (
                    extra_headers_value if extra_headers_value is not None else []
                )

            if agent.get("object_permission") is not None:
                agent_copy = dict(augment_agent)
                existing_object_permission_id = existing_agent.get(
                    "object_permission_id"
                )
                object_permission_id = await handle_update_object_permission_common(
                    agent_copy,
                    existing_object_permission_id,
                    prisma_client,
                )
                if object_permission_id is not None:
                    update_data["object_permission_id"] = object_permission_id

            # Patch agent in DB
            patched_agent = await prisma_client.db.litellm_agentstable.update(
                where={"agent_id": agent_id},
                data={
                    **update_data,
                    "updated_by": updated_by,
                    "updated_at": datetime.now(timezone.utc),
                },
                include={"object_permission": True},
            )
            return AgentResponse(**self._agent_row_to_dict(patched_agent))  # type: ignore
        except Exception as e:
            raise Exception(f"Error patching agent in DB: {str(e)}") from e

    async def update_agent_in_db(
        self,
        agent_id: str,
        agent: AgentConfig,
        prisma_client: PrismaClient,
        updated_by: str,
    ) -> AgentResponse:
        """
        Update (replace) an agent in the database.

        NOTE: ``static_headers``/``extra_headers`` missing from ``agent`` are
        reset to empty, whereas rate-limit fields that are None are left
        unchanged in the DB.

        Raises:
            Exception: wrapping any underlying error ("Error updating agent in DB: ...").
        """
        try:
            agent_name = agent.get("agent_name")

            # Serialize litellm_params and agent_card_params to JSON strings
            litellm_params: str = safe_dumps(
                self._params_to_dict(agent.get("litellm_params", {}))
            )
            agent_card_params: str = safe_dumps(
                self._params_to_dict(agent.get("agent_card_params", {}))
            )

            # Serialize static_headers for update
            static_headers_obj_u = agent.get("static_headers")
            static_headers_val_u: str = (
                safe_dumps(dict(static_headers_obj_u))
                if static_headers_obj_u is not None
                else safe_dumps({})
            )
            extra_headers_val_u: List[str] = agent.get("extra_headers") or []

            update_data: Dict[str, Any] = {
                "agent_name": agent_name,
                "litellm_params": litellm_params,
                "agent_card_params": agent_card_params,
                "static_headers": static_headers_val_u,
                "extra_headers": extra_headers_val_u,
                "updated_by": updated_by,
                "updated_at": datetime.now(timezone.utc),
            }
            for rate_field in self._RATE_LIMIT_FIELDS:
                rate_value = agent.get(rate_field)
                if rate_value is not None:
                    update_data[rate_field] = rate_value

            if agent.get("object_permission") is not None:
                existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
                    where={"agent_id": agent_id}
                )
                existing_object_permission_id = (
                    existing_agent.object_permission_id
                    if existing_agent is not None
                    else None
                )
                agent_copy = dict(agent)
                object_permission_id = await handle_update_object_permission_common(
                    agent_copy,
                    existing_object_permission_id,
                    prisma_client,
                )
                if object_permission_id is not None:
                    update_data["object_permission_id"] = object_permission_id

            # Update agent in DB
            updated_agent = await prisma_client.db.litellm_agentstable.update(
                where={"agent_id": agent_id},
                data=update_data,
                include={"object_permission": True},
            )
            return AgentResponse(**self._agent_row_to_dict(updated_agent))  # type: ignore
        except Exception as e:
            raise Exception(f"Error updating agent in DB: {str(e)}") from e

    @staticmethod
    async def get_all_agents_from_db(
        prisma_client: PrismaClient,
    ) -> List[Dict[str, Any]]:
        """
        Get all agents from the database, ordered by ``created_at`` descending.

        Each row is converted with ``dict(row)``; the eagerly loaded
        ``object_permission`` relation, when present, is serialized to a dict.

        Raises:
            Exception: wrapping any underlying error ("Error getting agents from DB: ...").
        """
        try:
            agents_from_db = await prisma_client.db.litellm_agentstable.find_many(
                order={"created_at": "desc"},
                include={"object_permission": True},
            )

            agents: List[Dict[str, Any]] = []
            for agent in agents_from_db:
                agent_dict = dict(agent)
                # object_permission is eagerly loaded via include above
                if agent.object_permission is not None:
                    agent_dict[
                        "object_permission"
                    ] = AgentRegistry._object_permission_to_dict(agent.object_permission)
                agents.append(agent_dict)

            return agents
        except Exception as e:
            raise Exception(f"Error getting agents from DB: {str(e)}") from e

    def get_agent_by_id(
        self,
        agent_id: str,
    ) -> Optional[AgentResponse]:
        """
        Get an agent by its ID from the in-memory registry (no DB access).

        Returns:
            The first registered agent with a matching ``agent_id``, or None.
        """
        try:
            for agent in self.agent_list:
                if agent.agent_id == agent_id:
                    return agent
            return None
        except Exception as e:
            # Message kept for backward compatibility even though no DB is used.
            raise Exception(f"Error getting agent from DB: {str(e)}") from e

    def get_agent_by_name(self, agent_name: str) -> Optional[AgentResponse]:
        """
        Get an agent by its name from the in-memory registry (no DB access).

        Returns:
            The first registered agent with a matching ``agent_name``, or None.
        """
        try:
            for agent in self.agent_list:
                if agent.agent_name == agent_name:
                    return agent
            return None
        except Exception as e:
            # Message kept for backward compatibility even though no DB is used.
            raise Exception(f"Error getting agent from DB: {str(e)}") from e


global_agent_registry = AgentRegistry()