chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
This hook is used to inject cache control directives into the messages of a chat completion.
|
||||
|
||||
Users can define
|
||||
- `cache_control_injection_points` in the completion params and litellm will inject the cache control directives into the messages at the specified injection points.
|
||||
|
||||
"""
|
||||
|
||||
import copy
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.custom_prompt_management import CustomPromptManagement
|
||||
from litellm.integrations.prompt_management_base import PromptManagementClient
|
||||
from litellm.types.integrations.anthropic_cache_control_hook import (
|
||||
CacheControlInjectionPoint,
|
||||
CacheControlMessageInjectionPoint,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionCachedContent
|
||||
from litellm.types.prompts.init_prompts import PromptSpec
|
||||
from litellm.types.utils import StandardCallbackDynamicParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AnthropicCacheControlHook(CustomPromptManagement):
|
||||
def get_chat_completion_prompt(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
non_default_params: dict,
|
||||
prompt_id: Optional[str],
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
prompt_spec: Optional[PromptSpec] = None,
|
||||
prompt_label: Optional[str] = None,
|
||||
prompt_version: Optional[int] = None,
|
||||
ignore_prompt_manager_model: Optional[bool] = False,
|
||||
ignore_prompt_manager_optional_params: Optional[bool] = False,
|
||||
) -> Tuple[str, List[AllMessageValues], dict]:
|
||||
"""
|
||||
Apply cache control directives based on specified injection points.
|
||||
|
||||
Returns:
|
||||
- model: str - the model to use
|
||||
- messages: List[AllMessageValues] - messages with applied cache controls
|
||||
- non_default_params: dict - params with any global cache controls
|
||||
"""
|
||||
# Extract cache control injection points
|
||||
injection_points: List[CacheControlInjectionPoint] = non_default_params.pop(
|
||||
"cache_control_injection_points", []
|
||||
)
|
||||
if not injection_points:
|
||||
return model, messages, non_default_params
|
||||
|
||||
# Create a deep copy of messages to avoid modifying the original list
|
||||
processed_messages = copy.deepcopy(messages)
|
||||
|
||||
# Process message-level cache controls
|
||||
for point in injection_points:
|
||||
if point.get("location") == "message":
|
||||
point = cast(CacheControlMessageInjectionPoint, point)
|
||||
processed_messages = self._process_message_injection(
|
||||
point=point, messages=processed_messages
|
||||
)
|
||||
|
||||
return model, processed_messages, non_default_params
|
||||
|
||||
@staticmethod
|
||||
def _process_message_injection(
|
||||
point: CacheControlMessageInjectionPoint, messages: List[AllMessageValues]
|
||||
) -> List[AllMessageValues]:
|
||||
"""Process message-level cache control injection."""
|
||||
control: ChatCompletionCachedContent = point.get(
|
||||
"control", None
|
||||
) or ChatCompletionCachedContent(type="ephemeral")
|
||||
|
||||
_targetted_index: Optional[Union[int, str]] = point.get("index", None)
|
||||
targetted_index: Optional[int] = None
|
||||
if isinstance(_targetted_index, str):
|
||||
try:
|
||||
targetted_index = int(_targetted_index)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
targetted_index = _targetted_index
|
||||
|
||||
targetted_role = point.get("role", None)
|
||||
|
||||
# Case 1: Target by specific index
|
||||
if targetted_index is not None:
|
||||
original_index = targetted_index
|
||||
# Handle negative indices (convert to positive)
|
||||
if targetted_index < 0:
|
||||
targetted_index += len(messages)
|
||||
|
||||
if 0 <= targetted_index < len(messages):
|
||||
messages[
|
||||
targetted_index
|
||||
] = AnthropicCacheControlHook._safe_insert_cache_control_in_message(
|
||||
messages[targetted_index], control
|
||||
)
|
||||
else:
|
||||
verbose_logger.warning(
|
||||
f"AnthropicCacheControlHook: Provided index {original_index} is out of bounds for message list of length {len(messages)}. "
|
||||
f"Targeted index was {targetted_index}. Skipping cache control injection for this point."
|
||||
)
|
||||
# Case 2: Target by role
|
||||
elif targetted_role is not None:
|
||||
for msg in messages:
|
||||
if msg.get("role") == targetted_role:
|
||||
msg = (
|
||||
AnthropicCacheControlHook._safe_insert_cache_control_in_message(
|
||||
message=msg, control=control
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
@staticmethod
|
||||
def _safe_insert_cache_control_in_message(
|
||||
message: AllMessageValues, control: ChatCompletionCachedContent
|
||||
) -> AllMessageValues:
|
||||
"""
|
||||
Safe way to insert cache control in a message
|
||||
|
||||
OpenAI Message content can be either:
|
||||
- string
|
||||
- list of objects
|
||||
|
||||
This method handles inserting cache control in both cases.
|
||||
Per Anthropic's API specification, when using multiple content blocks,
|
||||
only the last content block can have cache_control.
|
||||
"""
|
||||
message_content = message.get("content", None)
|
||||
|
||||
# 1. if string, insert cache control in the message
|
||||
if isinstance(message_content, str):
|
||||
message["cache_control"] = control # type: ignore
|
||||
# 2. list of objects - only apply to last item per Anthropic spec
|
||||
elif isinstance(message_content, list):
|
||||
if len(message_content) > 0 and isinstance(message_content[-1], dict):
|
||||
message_content[-1]["cache_control"] = control # type: ignore
|
||||
return message
|
||||
|
||||
@property
|
||||
def integration_name(self) -> str:
|
||||
"""Return the integration name for this hook."""
|
||||
return "anthropic_cache_control_hook"
|
||||
|
||||
def should_run_prompt_management(
|
||||
self,
|
||||
prompt_id: Optional[str],
|
||||
prompt_spec: Optional[PromptSpec],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
) -> bool:
|
||||
"""Always return False since this is not a true prompt management system."""
|
||||
return False
|
||||
|
||||
def _compile_prompt_helper(
|
||||
self,
|
||||
prompt_id: Optional[str],
|
||||
prompt_spec: Optional[PromptSpec],
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
prompt_label: Optional[str] = None,
|
||||
prompt_version: Optional[int] = None,
|
||||
) -> PromptManagementClient:
|
||||
"""Not used - this hook only modifies messages, doesn't fetch prompts."""
|
||||
return PromptManagementClient(
|
||||
prompt_id=prompt_id,
|
||||
prompt_template=[],
|
||||
prompt_template_model=None,
|
||||
prompt_template_optional_params=None,
|
||||
completed_messages=None,
|
||||
)
|
||||
|
||||
async def async_compile_prompt_helper(
|
||||
self,
|
||||
prompt_id: Optional[str],
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
prompt_spec: Optional[PromptSpec] = None,
|
||||
prompt_label: Optional[str] = None,
|
||||
prompt_version: Optional[int] = None,
|
||||
) -> PromptManagementClient:
|
||||
"""Not used - this hook only modifies messages, doesn't fetch prompts."""
|
||||
return self._compile_prompt_helper(
|
||||
prompt_id=prompt_id,
|
||||
prompt_spec=prompt_spec,
|
||||
prompt_variables=prompt_variables,
|
||||
dynamic_callback_params=dynamic_callback_params,
|
||||
prompt_label=prompt_label,
|
||||
prompt_version=prompt_version,
|
||||
)
|
||||
|
||||
async def async_get_chat_completion_prompt(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
non_default_params: dict,
|
||||
prompt_id: Optional[str],
|
||||
prompt_variables: Optional[dict],
|
||||
dynamic_callback_params: StandardCallbackDynamicParams,
|
||||
litellm_logging_obj: LiteLLMLoggingObj,
|
||||
prompt_spec: Optional[PromptSpec] = None,
|
||||
tools: Optional[List[Dict]] = None,
|
||||
prompt_label: Optional[str] = None,
|
||||
prompt_version: Optional[int] = None,
|
||||
ignore_prompt_manager_model: Optional[bool] = False,
|
||||
ignore_prompt_manager_optional_params: Optional[bool] = False,
|
||||
) -> Tuple[str, List[AllMessageValues], dict]:
|
||||
"""Async version - delegates to sync since no async operations needed."""
|
||||
return self.get_chat_completion_prompt(
|
||||
model=model,
|
||||
messages=messages,
|
||||
non_default_params=non_default_params,
|
||||
prompt_id=prompt_id,
|
||||
prompt_variables=prompt_variables,
|
||||
dynamic_callback_params=dynamic_callback_params,
|
||||
prompt_spec=prompt_spec,
|
||||
prompt_label=prompt_label,
|
||||
prompt_version=prompt_version,
|
||||
ignore_prompt_manager_model=ignore_prompt_manager_model,
|
||||
ignore_prompt_manager_optional_params=ignore_prompt_manager_optional_params,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def should_use_anthropic_cache_control_hook(non_default_params: Dict) -> bool:
|
||||
if non_default_params.get("cache_control_injection_points", None):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_custom_logger_for_anthropic_cache_control_hook(
|
||||
non_default_params: Dict,
|
||||
) -> Optional[CustomLogger]:
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
_init_custom_logger_compatible_class,
|
||||
)
|
||||
|
||||
if AnthropicCacheControlHook.should_use_anthropic_cache_control_hook(
|
||||
non_default_params
|
||||
):
|
||||
return _init_custom_logger_compatible_class(
|
||||
logging_integration="anthropic_cache_control_hook",
|
||||
internal_usage_cache=None,
|
||||
llm_router=None,
|
||||
)
|
||||
return None
|
||||
Reference in New Issue
Block a user