chore: initial public snapshot for github upload
@@ -0,0 +1,373 @@
import json
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import cohere_messages_pt_v2
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
|
||||
from ..common_utils import ModelResponseIterator as CohereModelResponseIterator
|
||||
from ..common_utils import validate_environment as cohere_validate_environment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class CohereError(BaseLLMException):
|
||||
def __init__(
|
||||
self,
|
||||
status_code: int,
|
||||
message: str,
|
||||
headers: Optional[httpx.Headers] = None,
|
||||
):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(method="POST", url="https://api.cohere.ai/v1/chat")
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
status_code=status_code,
|
||||
message=message,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
class CohereChatConfig(BaseConfig):
|
||||
"""
|
||||
Configuration class for Cohere's API interface.
|
||||
|
||||
Args:
|
||||
preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
|
||||
chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
|
||||
generation_id (str, optional): Unique identifier for the generated reply.
|
||||
response_id (str, optional): Unique identifier for the response.
|
||||
conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
|
||||
prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
|
||||
connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
|
||||
search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
|
||||
documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
|
||||
temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
|
||||
max_tokens [DEPRECATED - use max_completion_tokens] (int, optional): The maximum number of tokens the model will generate as part of the response.
|
||||
max_completion_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
|
||||
k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
|
||||
p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
|
||||
frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
|
||||
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
|
||||
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
|
||||
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
|
||||
seed (int, optional): A seed to assist reproducibility of the model's response.
|
||||
"""
|
||||
|
||||
preamble: Optional[str] = None
|
||||
chat_history: Optional[list] = None
|
||||
generation_id: Optional[str] = None
|
||||
response_id: Optional[str] = None
|
||||
conversation_id: Optional[str] = None
|
||||
prompt_truncation: Optional[str] = None
|
||||
connectors: Optional[list] = None
|
||||
search_queries_only: Optional[bool] = None
|
||||
documents: Optional[list] = None
|
||||
temperature: Optional[int] = None
|
||||
max_tokens: Optional[int] = None
|
||||
max_completion_tokens: Optional[int] = None
|
||||
k: Optional[int] = None
|
||||
p: Optional[int] = None
|
||||
frequency_penalty: Optional[int] = None
|
||||
presence_penalty: Optional[int] = None
|
||||
tools: Optional[list] = None
|
||||
tool_results: Optional[list] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
preamble: Optional[str] = None,
|
||||
chat_history: Optional[list] = None,
|
||||
generation_id: Optional[str] = None,
|
||||
response_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
prompt_truncation: Optional[str] = None,
|
||||
connectors: Optional[list] = None,
|
||||
search_queries_only: Optional[bool] = None,
|
||||
documents: Optional[list] = None,
|
||||
temperature: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
k: Optional[int] = None,
|
||||
p: Optional[int] = None,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_results: Optional[list] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return cohere_validate_environment(
|
||||
headers=headers,
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
"n",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"seed",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for param, value in non_default_params.items():
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "max_completion_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "n":
|
||||
optional_params["num_generations"] = value
|
||||
if param == "top_p":
|
||||
optional_params["p"] = value
|
||||
if param == "frequency_penalty":
|
||||
optional_params["frequency_penalty"] = value
|
||||
if param == "presence_penalty":
|
||||
optional_params["presence_penalty"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "seed":
|
||||
optional_params["seed"] = value
|
||||
return optional_params
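# Illustrative usage (not part of the original module): how the OpenAI -> Cohere
# parameter mapping above behaves; the argument values are assumed examples.
#
#   >>> CohereChatConfig().map_openai_params(
#   ...     non_default_params={"max_completion_tokens": 256, "top_p": 0.9, "stop": ["\n"]},
#   ...     optional_params={},
#   ...     model="command-r",
#   ...     drop_params=False,
#   ... )
#   {'max_tokens': 256, 'p': 0.9, 'stop_sequences': ['\n']}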
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
## Load Config
|
||||
for k, v in litellm.CohereChatConfig.get_config().items():
|
||||
if (
|
||||
k not in optional_params
|
||||
): # completion(top_k=3) > cohere_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
optional_params[k] = v
|
||||
|
||||
most_recent_message, chat_history = cohere_messages_pt_v2(
|
||||
messages=messages, model=model, llm_provider="cohere_chat"
|
||||
)
|
||||
|
||||
## Handle Tool Calling
|
||||
if "tools" in optional_params:
|
||||
_is_function_call = True
|
||||
cohere_tools = self._construct_cohere_tool(tools=optional_params["tools"])
|
||||
optional_params["tools"] = cohere_tools
|
||||
if isinstance(most_recent_message, dict):
|
||||
optional_params["tool_results"] = [most_recent_message]
|
||||
elif isinstance(most_recent_message, str):
|
||||
optional_params["message"] = most_recent_message
|
||||
|
||||
## If the last chat-history message is from the USER and 'tool_results' are given, set force_single_step=True; otherwise the Cohere API fails
|
||||
if len(chat_history) > 0 and chat_history[-1]["role"] == "USER":
|
||||
optional_params["force_single_step"] = True
|
||||
|
||||
return optional_params
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
model_response.choices[0].message.content = raw_response_json["text"] # type: ignore
|
||||
except Exception:
|
||||
raise CohereError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
|
||||
## ADD CITATIONS
|
||||
if "citations" in raw_response_json:
|
||||
setattr(model_response, "citations", raw_response_json["citations"])
|
||||
|
||||
## Tool calling response
|
||||
cohere_tools_response = raw_response_json.get("tool_calls", None)
|
||||
if cohere_tools_response is not None and cohere_tools_response != []:
|
||||
# convert cohere_tools_response to OpenAI response format
|
||||
tool_calls = []
|
||||
for tool in cohere_tools_response:
|
||||
function_name = tool.get("name", "")
|
||||
generation_id = tool.get("generation_id", "")
|
||||
parameters = tool.get("parameters", {})
|
||||
tool_call = {
|
||||
"id": f"call_{generation_id}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": function_name,
|
||||
"arguments": json.dumps(parameters),
|
||||
},
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
content=None,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
|
||||
## CALCULATING USAGE - use cohere `billed_units` for returning usage
|
||||
billed_units = raw_response_json.get("meta", {}).get("billed_units", {})
|
||||
|
||||
prompt_tokens = billed_units.get("input_tokens", 0)
|
||||
completion_tokens = billed_units.get("output_tokens", 0)
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
def _construct_cohere_tool(
|
||||
self,
|
||||
tools: Optional[list] = None,
|
||||
):
|
||||
if tools is None:
|
||||
tools = []
|
||||
cohere_tools = []
|
||||
for tool in tools:
|
||||
cohere_tool = self._translate_openai_tool_to_cohere(tool)
|
||||
cohere_tools.append(cohere_tool)
|
||||
return cohere_tools
|
||||
|
||||
def _translate_openai_tool_to_cohere(
|
||||
self,
|
||||
openai_tool: dict,
|
||||
):
|
||||
# cohere tools look like this
|
||||
"""
|
||||
{
|
||||
"name": "query_daily_sales_report",
|
||||
"description": "Connects to a database to retrieve overall sales volumes and sales information for a given day.",
|
||||
"parameter_definitions": {
|
||||
"day": {
|
||||
"description": "Retrieves sales data for this day, formatted as YYYY-MM-DD.",
|
||||
"type": "str",
|
||||
"required": True
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
# OpenAI tools look like this
|
||||
"""
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
},
|
||||
}
|
||||
"""
|
||||
cohere_tool = {
|
||||
"name": openai_tool["function"]["name"],
|
||||
"description": openai_tool["function"]["description"],
|
||||
"parameter_definitions": {},
|
||||
}
|
||||
|
||||
for param_name, param_def in openai_tool["function"]["parameters"][
|
||||
"properties"
|
||||
].items():
|
||||
required_params = (
|
||||
openai_tool.get("function", {})
|
||||
.get("parameters", {})
|
||||
.get("required", [])
|
||||
)
|
||||
cohere_param_def = {
|
||||
"description": param_def.get("description", ""),
|
||||
"type": param_def.get("type", ""),
|
||||
"required": param_name in required_params,
|
||||
}
|
||||
cohere_tool["parameter_definitions"][param_name] = cohere_param_def
|
||||
|
||||
return cohere_tool
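# Illustrative usage (not part of the original module): translating the OpenAI
# weather tool shown in the docstring above into Cohere's tool format.
#
#   >>> CohereChatConfig()._translate_openai_tool_to_cohere({
#   ...     "type": "function",
#   ...     "function": {
#   ...         "name": "get_current_weather",
#   ...         "description": "Get the current weather in a given location",
#   ...         "parameters": {
#   ...             "type": "object",
#   ...             "properties": {"location": {"type": "string", "description": "The city and state"}},
#   ...             "required": ["location"],
#   ...         },
#   ...     },
#   ... })
#   {'name': 'get_current_weather',
#    'description': 'Get the current weather in a given location',
#    'parameter_definitions': {'location': {'description': 'The city and state', 'type': 'string', 'required': True}}}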
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||
sync_stream: bool,
|
||||
json_mode: Optional[bool] = False,
|
||||
):
|
||||
return CohereModelResponseIterator(
|
||||
streaming_response=streaming_response,
|
||||
sync_stream=sync_stream,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return CohereError(status_code=status_code, message=error_message)
|
||||
@@ -0,0 +1,364 @@
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, AsyncIterator, Iterator, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.cohere import CohereV2ChatResponse
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionAnnotation,
|
||||
ChatCompletionAnnotationURLCitation,
|
||||
)
|
||||
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
|
||||
from ..common_utils import CohereError
|
||||
from ..common_utils import CohereV2ModelResponseIterator
|
||||
from ..common_utils import validate_environment as cohere_validate_environment
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class CohereV2ChatConfig(OpenAIGPTConfig):
|
||||
"""
|
||||
Configuration class for Cohere's API interface.
|
||||
|
||||
Args:
|
||||
preamble (str, optional): When specified, the default Cohere preamble will be replaced with the provided one.
|
||||
chat_history (List[Dict[str, str]], optional): A list of previous messages between the user and the model.
|
||||
generation_id (str, optional): Unique identifier for the generated reply.
|
||||
response_id (str, optional): Unique identifier for the response.
|
||||
conversation_id (str, optional): An alternative to chat_history, creates or resumes a persisted conversation.
|
||||
prompt_truncation (str, optional): Dictates how the prompt will be constructed. Options: 'AUTO', 'AUTO_PRESERVE_ORDER', 'OFF'.
|
||||
connectors (List[Dict[str, str]], optional): List of connectors (e.g., web-search) to enrich the model's reply.
|
||||
search_queries_only (bool, optional): When true, the response will only contain a list of generated search queries.
|
||||
documents (List[Dict[str, str]], optional): A list of relevant documents that the model can cite.
|
||||
temperature (float, optional): A non-negative float that tunes the degree of randomness in generation.
|
||||
max_tokens (int, optional): The maximum number of tokens the model will generate as part of the response.
|
||||
k (int, optional): Ensures only the top k most likely tokens are considered for generation at each step.
|
||||
p (float, optional): Ensures that only the most likely tokens, with total probability mass of p, are considered for generation.
|
||||
frequency_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
|
||||
presence_penalty (float, optional): Used to reduce repetitiveness of generated tokens.
|
||||
tools (List[Dict[str, str]], optional): A list of available tools (functions) that the model may suggest invoking.
|
||||
tool_results (List[Dict[str, Any]], optional): A list of results from invoking tools.
|
||||
seed (int, optional): A seed to assist reproducibility of the model's response.
|
||||
"""
|
||||
|
||||
preamble: Optional[str] = None
|
||||
chat_history: Optional[list] = None
|
||||
generation_id: Optional[str] = None
|
||||
response_id: Optional[str] = None
|
||||
conversation_id: Optional[str] = None
|
||||
prompt_truncation: Optional[str] = None
|
||||
connectors: Optional[list] = None
|
||||
search_queries_only: Optional[bool] = None
|
||||
documents: Optional[list] = None
|
||||
temperature: Optional[int] = None
|
||||
max_tokens: Optional[int] = None
|
||||
k: Optional[int] = None
|
||||
p: Optional[int] = None
|
||||
frequency_penalty: Optional[int] = None
|
||||
presence_penalty: Optional[int] = None
|
||||
tools: Optional[list] = None
|
||||
tool_results: Optional[list] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
preamble: Optional[str] = None,
|
||||
chat_history: Optional[list] = None,
|
||||
generation_id: Optional[str] = None,
|
||||
response_id: Optional[str] = None,
|
||||
conversation_id: Optional[str] = None,
|
||||
prompt_truncation: Optional[str] = None,
|
||||
connectors: Optional[list] = None,
|
||||
search_queries_only: Optional[bool] = None,
|
||||
documents: Optional[list] = None,
|
||||
temperature: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
k: Optional[int] = None,
|
||||
p: Optional[int] = None,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_results: Optional[list] = None,
|
||||
seed: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return cohere_validate_environment(
|
||||
headers=headers,
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return [
|
||||
"stream",
|
||||
"temperature",
|
||||
"max_tokens",
|
||||
"top_p",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
"stop",
|
||||
"n",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"seed",
|
||||
"extra_headers",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for param, value in non_default_params.items():
|
||||
if param == "stream":
|
||||
optional_params["stream"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "max_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "n":
|
||||
optional_params["num_generations"] = value
|
||||
if param == "top_p":
|
||||
optional_params["p"] = value
|
||||
if param == "frequency_penalty":
|
||||
optional_params["frequency_penalty"] = value
|
||||
if param == "presence_penalty":
|
||||
optional_params["presence_penalty"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "tools":
|
||||
optional_params["tools"] = value
|
||||
if param == "seed":
|
||||
optional_params["seed"] = value
|
||||
return optional_params
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
The Cohere v2 chat API follows the OpenAI format, so we can reuse the OpenAI transform_request logic to transform the request.
|
||||
"""
|
||||
data = super().transform_request(
|
||||
model, messages, optional_params, litellm_params, headers
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
try:
|
||||
raw_response_json = raw_response.json()
|
||||
except Exception:
|
||||
raise CohereError(
|
||||
message=raw_response.text, status_code=raw_response.status_code
|
||||
)
|
||||
|
||||
try:
|
||||
cohere_v2_chat_response = CohereV2ChatResponse(**raw_response_json) # type: ignore
|
||||
except Exception:
|
||||
raise CohereError(message=raw_response.text, status_code=422)
|
||||
|
||||
cohere_content = cohere_v2_chat_response["message"].get("content", None)
|
||||
if cohere_content is not None:
|
||||
model_response.choices[0].message.content = "".join( # type: ignore
|
||||
[
|
||||
content.get("text", "")
|
||||
for content in cohere_content
|
||||
if content is not None
|
||||
]
|
||||
)
|
||||
|
||||
## ADD CITATIONS AS ANNOTATIONS
|
||||
annotations: Optional[List[ChatCompletionAnnotation]] = None
|
||||
citations = None
|
||||
|
||||
if (
|
||||
"message" in cohere_v2_chat_response
|
||||
and "citations" in cohere_v2_chat_response["message"]
|
||||
):
|
||||
citations = cohere_v2_chat_response["message"]["citations"]
|
||||
|
||||
if citations:
|
||||
annotations = self._translate_citations_to_openai_annotations(citations)
|
||||
|
||||
## Tool calling response
|
||||
cohere_tools_response = cohere_v2_chat_response["message"].get("tool_calls", [])
|
||||
if cohere_tools_response is not None and cohere_tools_response != []:
|
||||
# convert cohere_tools_response to OpenAI response format
|
||||
tool_calls: List[ChatCompletionToolCallChunk] = []
|
||||
for index, tool in enumerate(cohere_tools_response):
|
||||
tool_call: ChatCompletionToolCallChunk = {
|
||||
**tool, # type: ignore
|
||||
"index": index,
|
||||
}
|
||||
tool_calls.append(tool_call)
|
||||
_message = litellm.Message(
|
||||
tool_calls=tool_calls,
|
||||
content=None,
|
||||
annotations=annotations,
|
||||
)
|
||||
model_response.choices[0].message = _message # type: ignore
|
||||
else:
|
||||
if annotations:
|
||||
current_message = model_response.choices[0].message # type: ignore
|
||||
current_message.annotations = annotations
|
||||
|
||||
## CALCULATING USAGE - use cohere `billed_units` for returning usage
|
||||
token_usage = cohere_v2_chat_response["usage"].get("tokens", {})
|
||||
prompt_tokens = token_usage.get("input_tokens", 0)
|
||||
completion_tokens = token_usage.get("output_tokens", 0)
|
||||
|
||||
model_response.created = int(time.time())
|
||||
model_response.model = model
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
setattr(model_response, "usage", usage)
|
||||
return model_response
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||
sync_stream: bool,
|
||||
json_mode: Optional[bool] = False,
|
||||
):
|
||||
return CohereV2ModelResponseIterator(
|
||||
streaming_response=streaming_response,
|
||||
sync_stream=sync_stream,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for Cohere v2 chat completion.
|
||||
The api_base should already include the full path.
|
||||
"""
|
||||
if api_base is None:
|
||||
raise ValueError("api_base is required")
|
||||
return api_base
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return CohereError(status_code=status_code, message=error_message)
|
||||
|
||||
def _translate_citations_to_openai_annotations(
|
||||
self, citations: List[dict]
|
||||
) -> List[ChatCompletionAnnotation]:
|
||||
"""
|
||||
Transform Cohere citations to OpenAI annotations format.
|
||||
|
||||
Creates separate annotations for each source in a citation, allowing multiple
|
||||
annotations with the same start/end index if they reference different sources.
|
||||
|
||||
Args:
|
||||
citations: List of Cohere citation objects with format:
|
||||
{
|
||||
"start": int,
|
||||
"end": int,
|
||||
"text": str,
|
||||
"sources": [
|
||||
{
|
||||
"type": "document",
|
||||
"document": {
|
||||
"title": str,
|
||||
"snippet": str,
|
||||
...
|
||||
},
|
||||
"id": str
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Returns:
|
||||
List of OpenAI ChatCompletionAnnotation objects (one per source)
|
||||
"""
|
||||
annotations: List[ChatCompletionAnnotation] = []
|
||||
|
||||
for citation in citations:
|
||||
start_index = citation.get("start", 0)
|
||||
end_index = citation.get("end", 0)
|
||||
|
||||
# Extract source information - loop through all sources
|
||||
sources = citation.get("sources", [])
|
||||
if not sources:
|
||||
continue
|
||||
|
||||
# Create an annotation for each source
|
||||
for source in sources:
|
||||
if source.get("type") == "document" and "document" in source:
|
||||
document = source["document"]
|
||||
title = document.get("title", "")
|
||||
url = source.get("url") or f"source:{source.get('id', 'unknown')}"
|
||||
|
||||
url_citation: ChatCompletionAnnotationURLCitation = {
|
||||
"start_index": start_index,
|
||||
"end_index": end_index,
|
||||
"title": title,
|
||||
"url": url,
|
||||
}
|
||||
|
||||
annotation: ChatCompletionAnnotation = {
|
||||
"type": "url_citation",
|
||||
"url_citation": url_citation,
|
||||
}
|
||||
|
||||
annotations.append(annotation)
|
||||
|
||||
return annotations
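# Illustrative usage (not part of the original module), using the citation shape
# documented in the docstring above; the document id and title are assumed examples.
#
#   >>> CohereV2ChatConfig()._translate_citations_to_openai_annotations([{
#   ...     "start": 0,
#   ...     "end": 5,
#   ...     "text": "Tokyo",
#   ...     "sources": [{"type": "document", "id": "doc_0",
#   ...                  "document": {"title": "Cities", "snippet": "..."}}],
#   ... }])
#   [{'type': 'url_citation', 'url_citation': {'start_index': 0, 'end_index': 5,
#     'title': 'Cities', 'url': 'source:doc_0'}}]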
|
||||
@@ -0,0 +1,417 @@
|
||||
import json
|
||||
from typing import List, Optional, Literal, Tuple
|
||||
|
||||
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionUsageBlock,
|
||||
GenericStreamingChunk,
|
||||
ProviderSpecificModelInfo,
|
||||
)
|
||||
|
||||
|
||||
class CohereError(BaseLLMException):
|
||||
def __init__(self, status_code, message):
|
||||
super().__init__(status_code=status_code, message=message)
|
||||
|
||||
|
||||
class CohereModelInfo(BaseLLMModelInfo):
|
||||
def get_provider_info(
|
||||
self,
|
||||
model: str,
|
||||
) -> Optional[ProviderSpecificModelInfo]:
|
||||
"""
|
||||
Default values that all models of this provider support.
|
||||
"""
|
||||
return None
|
||||
|
||||
def get_models(
|
||||
self, api_key: Optional[str] = None, api_base: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""
|
||||
Returns a list of models supported by this provider.
|
||||
"""
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
|
||||
return api_key
|
||||
|
||||
@staticmethod
|
||||
def get_api_base(
|
||||
api_base: Optional[str] = None,
|
||||
) -> Optional[str]:
|
||||
return api_base
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_base_model(model: str) -> Optional[str]:
|
||||
"""
|
||||
Returns the base model name from the given model name.
|
||||
|
||||
Some providers like bedrock - can receive model=`invoke/anthropic.claude-3-opus-20240229-v1:0` or `converse/anthropic.claude-3-opus-20240229-v1:0`
|
||||
This function will return `anthropic.claude-3-opus-20240229-v1:0`
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def get_cohere_route(model: str) -> Literal["v1", "v2"]:
|
||||
"""
|
||||
Get the Cohere route for the given model.
|
||||
|
||||
Args:
|
||||
model: The model name (e.g., "cohere_chat/v2/command-r-plus", "command-r-plus")
|
||||
|
||||
Returns:
|
||||
"v2" for standard Cohere v2 API (default), "v1" for Cohere v1 API
|
||||
"""
|
||||
# Check for explicit v1 route
|
||||
if "v1/" in model:
|
||||
return "v1"
|
||||
|
||||
# Default to v2 for all other cases
|
||||
return "v2"
|
||||
|
||||
|
||||
def validate_environment(
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Return headers to use for cohere chat completion request
|
||||
|
||||
Cohere API Ref: https://docs.cohere.com/reference/chat
|
||||
Expected headers:
|
||||
{
|
||||
"Request-Source": "unspecified:litellm",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
"Authorization": "Bearer $CO_API_KEY"
|
||||
}
|
||||
"""
|
||||
headers.update(
|
||||
{
|
||||
"Request-Source": "unspecified:litellm",
|
||||
"accept": "application/json",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
)
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
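# Illustrative usage (not part of the original module); the API key is a fake
# placeholder value.
#
#   >>> validate_environment(
#   ...     headers={}, model="command-r", messages=[], optional_params={}, api_key="co-xxx"
#   ... )
#   {'Request-Source': 'unspecified:litellm', 'accept': 'application/json',
#    'content-type': 'application/json', 'Authorization': 'Bearer co-xxx'}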
|
||||
|
||||
|
||||
class ModelResponseIterator:
|
||||
def __init__(
|
||||
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
|
||||
):
|
||||
self.streaming_response = streaming_response
|
||||
self.response_iterator = self.streaming_response
|
||||
self.content_blocks: List = []
|
||||
self.tool_index = -1
|
||||
self.json_mode = json_mode
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
try:
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ChatCompletionUsageBlock] = None
|
||||
provider_specific_fields = None
|
||||
|
||||
index = int(chunk.get("index", 0))
|
||||
|
||||
if "text" in chunk:
|
||||
text = chunk["text"]
|
||||
elif "is_finished" in chunk and chunk["is_finished"] is True:
|
||||
is_finished = chunk["is_finished"]
|
||||
finish_reason = chunk["finish_reason"]
|
||||
|
||||
if "citations" in chunk:
|
||||
provider_specific_fields = {"citations": chunk["citations"]}
|
||||
|
||||
returned_chunk = GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_use=tool_use,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
index=index,
|
||||
provider_specific_fields=provider_specific_fields,
|
||||
)
|
||||
|
||||
return returned_chunk
|
||||
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Failed to decode JSON from chunk: {chunk}")
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
chunk = self.response_iterator.__next__()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
def convert_str_chunk_to_generic_chunk(self, chunk: str) -> GenericStreamingChunk:
|
||||
"""
|
||||
Convert a string chunk to a GenericStreamingChunk
|
||||
|
||||
Note: This is used for Cohere pass-through streaming logging
|
||||
"""
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
index = str_line.find("data:")
|
||||
if index != -1:
|
||||
str_line = str_line[index:]
|
||||
|
||||
data_json = json.loads(str_line)
|
||||
return self.chunk_parser(chunk=data_json)
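# Illustrative usage (not part of the original module): a plain JSON chunk from a
# Cohere v1 stream converted into a GenericStreamingChunk. The chunk payload here
# is an assumed example.
#
#   >>> it = ModelResponseIterator(streaming_response=iter([]), sync_stream=True)
#   >>> parsed = it.convert_str_chunk_to_generic_chunk('{"text": "Hello"}')
#   >>> parsed["text"], parsed["is_finished"]
#   ('Hello', False)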
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
self.async_response_iterator = self.streaming_response.__aiter__()
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
chunk = await self.async_response_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
|
||||
class CohereV2ModelResponseIterator:
|
||||
"""V2-specific response iterator for Cohere streaming"""
|
||||
|
||||
def __init__(
|
||||
self, streaming_response, sync_stream: bool, json_mode: Optional[bool] = False
|
||||
):
|
||||
self.streaming_response = streaming_response
|
||||
self.response_iterator = self.streaming_response
|
||||
self.content_blocks: List = []
|
||||
self.tool_index = -1
|
||||
self.json_mode = json_mode
|
||||
|
||||
def _parse_content_delta(self, chunk: dict) -> str:
|
||||
"""Parse content-delta chunks to extract text."""
|
||||
delta = chunk.get("delta", {})
|
||||
message = delta.get("message", {})
|
||||
content = message.get("content", {})
|
||||
if isinstance(content, dict) and "text" in content:
|
||||
return content["text"]
|
||||
elif isinstance(content, str):
|
||||
return content
|
||||
return ""
|
||||
|
||||
def _parse_tool_call_delta(
|
||||
self, chunk: dict
|
||||
) -> Optional[ChatCompletionToolCallChunk]:
|
||||
"""Parse tool-call-delta chunks to extract tool calls."""
|
||||
delta = chunk.get("delta", {})
|
||||
tool_calls = delta.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
return {
|
||||
"id": tool_calls[0].get("id", ""),
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_calls[0].get("name", ""),
|
||||
"arguments": tool_calls[0].get("arguments", ""),
|
||||
},
|
||||
} # type: ignore
|
||||
return None
|
||||
|
||||
def _parse_tool_plan_delta(self, chunk: dict) -> Optional[dict]:
|
||||
"""Parse tool-plan-delta events to extract tool plan."""
|
||||
data = chunk.get("data", {})
|
||||
delta = data.get("delta", {})
|
||||
message = delta.get("message", {})
|
||||
tool_plan = message.get("tool_plan", "")
|
||||
if tool_plan:
|
||||
return {"tool_plan": tool_plan}
|
||||
return None
|
||||
|
||||
def _parse_citation_start(self, chunk: dict) -> Optional[dict]:
|
||||
"""Parse citation-start events to extract citations."""
|
||||
data = chunk.get("data", {})
|
||||
delta = data.get("delta", {})
|
||||
message = delta.get("message", {})
|
||||
citations = message.get("citations", {})
|
||||
if citations:
|
||||
citation_data = {
|
||||
"start": citations.get("start", 0),
|
||||
"end": citations.get("end", 0),
|
||||
"text": citations.get("text", ""),
|
||||
"sources": citations.get("sources", []),
|
||||
"type": citations.get("type", "TEXT_CONTENT"),
|
||||
}
|
||||
return {"citations": [citation_data]}
|
||||
return None
|
||||
|
||||
def _parse_message_end(
|
||||
self, chunk: dict
|
||||
) -> Tuple[bool, str, Optional[ChatCompletionUsageBlock]]:
|
||||
"""Parse message-end events to extract finish info and usage."""
|
||||
data = chunk.get("data", {})
|
||||
delta = data.get("delta", {})
|
||||
is_finished = True
|
||||
finish_reason = delta.get("finish_reason", "stop")
|
||||
|
||||
usage = None
|
||||
usage_data = delta.get("usage", {})
|
||||
if usage_data:
|
||||
tokens_data = usage_data.get("tokens", {})
|
||||
usage = ChatCompletionUsageBlock(
|
||||
prompt_tokens=tokens_data.get("input_tokens", 0),
|
||||
completion_tokens=tokens_data.get("output_tokens", 0),
|
||||
total_tokens=tokens_data.get("input_tokens", 0)
|
||||
+ tokens_data.get("output_tokens", 0),
|
||||
)
|
||||
|
||||
return is_finished, finish_reason, usage
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> GenericStreamingChunk:
|
||||
"""
|
||||
Parse Cohere v2 streaming chunks.
|
||||
|
||||
v2 format:
|
||||
- Content: chunk.type == "content-delta" -> chunk.delta.message.content.text
|
||||
- Tool calls: chunk.type == "tool-call-delta" -> chunk.delta.tool_calls
|
||||
- Tool plan: chunk.event == "tool-plan-delta" -> chunk.data.delta.message.tool_plan
|
||||
- Citations: chunk.event == "citation-start" -> chunk.data.delta.message.citations
|
||||
- Finish: chunk.event == "message-end" -> chunk.data.delta.finish_reason
|
||||
"""
|
||||
try:
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ChatCompletionUsageBlock] = None
|
||||
provider_specific_fields = None
|
||||
|
||||
index = int(chunk.get("index", 0))
|
||||
chunk_type = chunk.get("type", "")
|
||||
event_type = chunk.get("event", "")
|
||||
|
||||
# Handle different chunk types
|
||||
if chunk_type == "content-delta":
|
||||
text = self._parse_content_delta(chunk)
|
||||
elif chunk_type == "tool-call-delta":
|
||||
tool_use = self._parse_tool_call_delta(chunk)
|
||||
elif event_type == "tool-plan-delta":
|
||||
provider_specific_fields = self._parse_tool_plan_delta(chunk)
|
||||
elif event_type == "citation-start":
|
||||
provider_specific_fields = self._parse_citation_start(chunk)
|
||||
elif event_type == "message-end":
|
||||
is_finished, finish_reason, usage = self._parse_message_end(chunk)
|
||||
|
||||
# Handle citations in any chunk type (fallback)
|
||||
if "citations" in chunk:
|
||||
if provider_specific_fields is None:
|
||||
provider_specific_fields = {}
|
||||
provider_specific_fields["citations"] = chunk["citations"]
|
||||
|
||||
return GenericStreamingChunk(
|
||||
text=text,
|
||||
tool_use=tool_use,
|
||||
is_finished=is_finished,
|
||||
finish_reason=finish_reason,
|
||||
usage=usage,
|
||||
index=index,
|
||||
provider_specific_fields=provider_specific_fields,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse v2 chunk: {e}, chunk: {chunk}")
|
||||
|
||||
# Sync iterator
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
try:
|
||||
chunk = self.response_iterator.__next__()
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
|
||||
except StopIteration:
|
||||
raise StopIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
|
||||
def convert_str_chunk_to_generic_chunk(self, chunk: str) -> GenericStreamingChunk:
|
||||
"""
|
||||
Convert a string chunk to a GenericStreamingChunk for v2
|
||||
|
||||
Note: This is used for Cohere v2 pass-through streaming logging
|
||||
"""
|
||||
str_line = chunk
|
||||
if isinstance(chunk, bytes): # Handle binary data
|
||||
str_line = chunk.decode("utf-8") # Convert bytes to string
|
||||
index = str_line.find("data:")
|
||||
if index != -1:
|
||||
str_line = str_line[index:]
|
||||
|
||||
data_json = json.loads(str_line)
|
||||
return self.chunk_parser(chunk=data_json)
|
||||
|
||||
# Async iterator
|
||||
def __aiter__(self):
|
||||
self.async_response_iterator = self.streaming_response.__aiter__()
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
try:
|
||||
chunk = await self.async_response_iterator.__anext__()
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error receiving chunk from stream: {e}")
|
||||
|
||||
try:
|
||||
return self.convert_str_chunk_to_generic_chunk(chunk=chunk)
|
||||
except StopAsyncIteration:
|
||||
raise StopAsyncIteration
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error parsing chunk: {e},\nReceived chunk: {chunk}")
|
||||
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
Legacy /v1/embedding handler for Bedrock Cohere.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.bedrock import CohereEmbeddingRequest
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .v1_transformation import CohereEmbeddingConfig
|
||||
|
||||
|
||||
def validate_environment(api_key, headers: dict):
|
||||
# Create a lowercase key lookup to avoid duplicate headers with different cases
|
||||
# This is important when headers come from AWS signed requests (which use Title-Case)
|
||||
existing_keys_lower = {k.lower(): k for k in headers.keys()}
|
||||
|
||||
# Only add headers if they don't already exist (case-insensitive check)
|
||||
if "request-source" not in existing_keys_lower:
|
||||
headers["Request-Source"] = "unspecified:litellm"
|
||||
if "accept" not in existing_keys_lower:
|
||||
headers["accept"] = "application/json"
|
||||
if "content-type" not in existing_keys_lower:
|
||||
headers["content-type"] = "application/json"
|
||||
if api_key and "authorization" not in existing_keys_lower:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
return headers
|
||||
|
||||
|
||||
class CohereError(Exception):
|
||||
def __init__(self, status_code, message):
|
||||
self.status_code = status_code
|
||||
self.message = message
|
||||
self.request = httpx.Request(
|
||||
method="POST", url="https://api.cohere.ai/v1/generate"
|
||||
)
|
||||
self.response = httpx.Response(status_code=status_code, request=self.request)
|
||||
super().__init__(
|
||||
self.message
|
||||
) # Call the base class constructor with the parameters it needs
|
||||
|
||||
|
||||
async def async_embedding(
|
||||
model: str,
|
||||
data: Union[dict, CohereEmbeddingRequest],
|
||||
input: list,
|
||||
model_response: litellm.utils.EmbeddingResponse,
|
||||
timeout: Optional[Union[float, httpx.Timeout]],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
optional_params: dict,
|
||||
api_base: str,
|
||||
api_key: Optional[str],
|
||||
headers: dict,
|
||||
encoding: Callable,
|
||||
client: Optional[AsyncHTTPHandler] = None,
|
||||
):
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={
|
||||
"complete_input_dict": data,
|
||||
"headers": headers,
|
||||
"api_base": api_base,
|
||||
},
|
||||
)
|
||||
## COMPLETION CALL
|
||||
|
||||
if client is None:
|
||||
client = get_async_httpx_client(
|
||||
llm_provider=litellm.LlmProviders.COHERE,
|
||||
params={"timeout": timeout},
|
||||
)
|
||||
|
||||
try:
|
||||
response = await client.post(api_base, headers=headers, data=json.dumps(data))
|
||||
except httpx.HTTPStatusError as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=e.response.text,
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
## PROCESS RESPONSE ##
|
||||
return CohereEmbeddingConfig()._transform_response(
|
||||
response=response,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
data=data,
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
encoding=encoding,
|
||||
input=input,
|
||||
)
|
||||
|
||||
|
||||
def embedding(
|
||||
model: str,
|
||||
input: list,
|
||||
model_response: EmbeddingResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
encoding: Any,
|
||||
data: Optional[Union[dict, CohereEmbeddingRequest]] = None,
|
||||
complete_api_base: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
aembedding: Optional[bool] = None,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = httpx.Timeout(None),
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
):
|
||||
headers = validate_environment(api_key, headers=headers)
|
||||
embed_url = complete_api_base or "https://api.cohere.ai/v1/embed"
|
||||
model = model
|
||||
|
||||
data = data or CohereEmbeddingConfig()._transform_request(
|
||||
model=model, input=input, inference_params=optional_params
|
||||
)
|
||||
|
||||
## ROUTING
|
||||
if aembedding is True:
|
||||
return async_embedding(
|
||||
model=model,
|
||||
data=data,
|
||||
input=input,
|
||||
model_response=model_response,
|
||||
timeout=timeout,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
api_base=embed_url,
|
||||
api_key=api_key,
|
||||
headers=headers,
|
||||
encoding=encoding,
|
||||
client=(
|
||||
client
|
||||
if client is not None and isinstance(client, AsyncHTTPHandler)
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = HTTPHandler(concurrent_limit=1)
|
||||
|
||||
response = client.post(embed_url, headers=headers, data=json.dumps(data))
|
||||
|
||||
return CohereEmbeddingConfig()._transform_response(
|
||||
response=response,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
data=data,
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
encoding=encoding,
|
||||
input=input,
|
||||
)
|
||||
@@ -0,0 +1,246 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Cohere's /v1/embed format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Convers
|
||||
- v3 embedding models
|
||||
- v2 embedding models
|
||||
|
||||
Docs - https://docs.cohere.com/v2/reference/embed
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import COHERE_DEFAULT_EMBEDDING_INPUT_TYPE
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.base_llm import BaseEmbeddingConfig
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.types.llms.bedrock import (
|
||||
CohereEmbeddingRequest,
|
||||
CohereEmbeddingRequestWithModel,
|
||||
)
|
||||
from litellm.types.llms.openai import AllEmbeddingInputValues, AllMessageValues
|
||||
from litellm.types.utils import EmbeddingResponse, PromptTokensDetailsWrapper, Usage
|
||||
from litellm.utils import is_base64_encoded
|
||||
|
||||
from ..common_utils import CohereError
|
||||
|
||||
|
||||
class CohereEmbeddingConfig(BaseEmbeddingConfig):
|
||||
"""
|
||||
Reference: https://docs.cohere.com/v2/reference/embed
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return ["encoding_format", "dimensions"]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool = False,
|
||||
) -> dict:
|
||||
for k, v in non_default_params.items():
|
||||
if k == "encoding_format":
|
||||
if isinstance(v, list):
|
||||
optional_params["embedding_types"] = v
|
||||
else:
|
||||
optional_params["embedding_types"] = [v]
|
||||
elif k == "dimensions":
|
||||
optional_params["output_dimension"] = v
|
||||
return optional_params
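# Illustrative usage (not part of the original module): OpenAI embedding params
# mapped onto Cohere /v2/embed fields; the model name is an assumed example.
#
#   >>> CohereEmbeddingConfig().map_openai_params(
#   ...     non_default_params={"encoding_format": "float", "dimensions": 512},
#   ...     optional_params={},
#   ...     model="embed-v4.0",
#   ... )
#   {'embedding_types': ['float'], 'output_dimension': 512}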
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
default_headers = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
if api_key:
|
||||
default_headers["Authorization"] = f"Bearer {api_key}"
|
||||
headers = {**default_headers, **headers}
|
||||
return headers
|
||||
|
||||
def _is_v3_model(self, model: str) -> bool:
|
||||
return "3" in model
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
return api_base or "https://api.cohere.ai/v2/embed"
|
||||
|
||||
def _transform_request(
|
||||
self, model: str, input: List[str], inference_params: dict
|
||||
) -> CohereEmbeddingRequestWithModel:
|
||||
is_encoded = False
|
||||
for input_str in input:
|
||||
is_encoded = is_base64_encoded(input_str)
|
||||
|
||||
if is_encoded: # check if string is b64 encoded image or not
|
||||
transformed_request = CohereEmbeddingRequestWithModel(
|
||||
model=model,
|
||||
images=input,
|
||||
input_type="image",
|
||||
)
|
||||
else:
|
||||
transformed_request = CohereEmbeddingRequestWithModel(
|
||||
model=model,
|
||||
texts=input,
|
||||
input_type=COHERE_DEFAULT_EMBEDDING_INPUT_TYPE,
|
||||
)
|
||||
|
||||
for k, v in inference_params.items():
|
||||
transformed_request[k] = v # type: ignore
|
||||
|
||||
return transformed_request
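# Illustrative usage (not part of the original module): plain text inputs are sent
# as `texts`; base64-encoded image inputs would be sent as `images` instead. The
# expected `input_type` below assumes litellm's COHERE_DEFAULT_EMBEDDING_INPUT_TYPE
# is "search_document".
#
#   >>> CohereEmbeddingConfig()._transform_request(
#   ...     model="embed-english-v3.0",
#   ...     input=["hello world"],
#   ...     inference_params={"embedding_types": ["float"]},
#   ... )
#   {'model': 'embed-english-v3.0', 'texts': ['hello world'],
#    'input_type': 'search_document', 'embedding_types': ['float']}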
|
||||
|
||||
def transform_embedding_request(
|
||||
self,
|
||||
model: str,
|
||||
input: AllEmbeddingInputValues,
|
||||
optional_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
if isinstance(input, list) and (
|
||||
isinstance(input[0], list) or isinstance(input[0], int)
|
||||
):
|
||||
raise ValueError("Input must be a list of strings")
|
||||
return cast(
|
||||
dict,
|
||||
self._transform_request(
|
||||
model=model,
|
||||
input=cast(List[str], input) if isinstance(input, List) else [input],
|
||||
inference_params=optional_params,
|
||||
),
|
||||
)
|
||||
|
||||
def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:
|
||||
input_tokens = 0
|
||||
|
||||
text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")
|
||||
|
||||
image_tokens: Optional[int] = meta.get("billed_units", {}).get("images")
|
||||
|
||||
prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
|
||||
if image_tokens is None and text_tokens is None:
|
||||
for text in input:
|
||||
input_tokens += len(encoding.encode(text))
|
||||
else:
|
||||
prompt_tokens_details = PromptTokensDetailsWrapper(
|
||||
image_tokens=image_tokens,
|
||||
text_tokens=text_tokens,
|
||||
)
|
||||
if image_tokens:
|
||||
input_tokens += image_tokens
|
||||
if text_tokens:
|
||||
input_tokens += text_tokens
|
||||
|
||||
return Usage(
|
||||
prompt_tokens=input_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=input_tokens,
|
||||
prompt_tokens_details=prompt_tokens_details,
|
||||
)
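# Illustrative behaviour (not part of the original module): when Cohere returns
# `billed_units` in the response `meta`, those counts are used directly instead of
# re-tokenizing the input, so `encoding` is never called in that branch.
#
#   >>> usage = CohereEmbeddingConfig()._calculate_usage(
#   ...     input=["hello world"],
#   ...     encoding=None,
#   ...     meta={"billed_units": {"input_tokens": 2}},
#   ... )
#   >>> (usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)
#   (2, 0, 2)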
|
||||
|
||||
def _transform_response(
|
||||
self,
|
||||
response: httpx.Response,
|
||||
api_key: Optional[str],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
data: Union[dict, CohereEmbeddingRequest],
|
||||
model_response: EmbeddingResponse,
|
||||
model: str,
|
||||
encoding: Any,
|
||||
input: list,
|
||||
) -> EmbeddingResponse:
|
||||
response_json = response.json()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response_json,
|
||||
)
|
||||
"""
|
||||
response
|
||||
{
|
||||
'object': "list",
|
||||
'data': [
|
||||
|
||||
]
|
||||
'model',
|
||||
'usage'
|
||||
}
|
||||
"""
|
||||
embeddings = response_json["embeddings"]
|
||||
output_data = []
|
||||
for k, embedding_list in embeddings.items():
|
||||
for idx, embedding in enumerate(embedding_list):
|
||||
output_data.append(
|
||||
{"object": "embedding", "index": idx, "embedding": embedding}
|
||||
)
|
||||
model_response.object = "list"
|
||||
model_response.data = output_data
|
||||
model_response.model = model
|
||||
input_tokens = 0
|
||||
for text in input:
|
||||
input_tokens += len(encoding.encode(text))
|
||||
|
||||
setattr(
|
||||
model_response,
|
||||
"usage",
|
||||
self._calculate_usage(input, encoding, response_json.get("meta", {})),
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
||||
def transform_embedding_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: EmbeddingResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_key: Optional[str],
|
||||
request_data: dict,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
) -> EmbeddingResponse:
|
||||
return self._transform_response(
|
||||
response=raw_response,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
data=request_data,
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
encoding=litellm.encoding,
|
||||
input=logging_obj.model_call_details["input"],
|
||||
)
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BaseLLMException:
|
||||
return CohereError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
)
|
||||
@@ -0,0 +1,162 @@
"""
Legacy /v1/embedding transformation logic for Bedrock Cohere.
"""

from typing import Any, List, Optional, Union

import httpx

from litellm import COHERE_DEFAULT_EMBEDDING_INPUT_TYPE
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.bedrock import (
    CohereEmbeddingRequest,
    CohereEmbeddingRequestWithModel,
)
from litellm.types.utils import EmbeddingResponse, PromptTokensDetailsWrapper, Usage
from litellm.utils import is_base64_encoded


class CohereEmbeddingConfig:
    """
    Reference: https://docs.cohere.com/v2/reference/embed
    """

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self) -> List[str]:
        return ["encoding_format"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        for k, v in non_default_params.items():
            if k == "encoding_format":
                optional_params["embedding_types"] = v
        return optional_params

    def _is_v3_model(self, model: str) -> bool:
        return "3" in model

    def _transform_request(
        self, model: str, input: List[str], inference_params: dict
    ) -> CohereEmbeddingRequestWithModel:
        is_encoded = False
        for input_str in input:
            is_encoded = is_base64_encoded(input_str)

        if is_encoded:  # check if string is b64 encoded image or not
            transformed_request = CohereEmbeddingRequestWithModel(
                model=model,
                images=input,
                input_type="image",
            )
        else:
            transformed_request = CohereEmbeddingRequestWithModel(
                model=model,
                texts=input,
                input_type=COHERE_DEFAULT_EMBEDDING_INPUT_TYPE,
            )

        for k, v in inference_params.items():
            transformed_request[k] = v  # type: ignore

        return transformed_request

    def _calculate_usage(self, input: List[str], encoding: Any, meta: dict) -> Usage:
        input_tokens = 0

        text_tokens: Optional[int] = meta.get("billed_units", {}).get("input_tokens")

        image_tokens: Optional[int] = meta.get("billed_units", {}).get("images")

        prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
        if image_tokens is None and text_tokens is None:
            for text in input:
                input_tokens += len(encoding.encode(text))
        else:
            prompt_tokens_details = PromptTokensDetailsWrapper(
                image_tokens=image_tokens,
                text_tokens=text_tokens,
            )
            if image_tokens:
                input_tokens += image_tokens
            if text_tokens:
                input_tokens += text_tokens

        return Usage(
            prompt_tokens=input_tokens,
            completion_tokens=0,
            total_tokens=input_tokens,
            prompt_tokens_details=prompt_tokens_details,
        )

    def _transform_response(
        self,
        response: httpx.Response,
        api_key: Optional[str],
        logging_obj: LiteLLMLoggingObj,
        data: Union[dict, CohereEmbeddingRequest],
        model_response: EmbeddingResponse,
        model: str,
        encoding: Any,
        input: list,
    ) -> EmbeddingResponse:
        response_json = response.json()
        ## LOGGING
        logging_obj.post_call(
            input=input,
            api_key=api_key,
            additional_args={"complete_input_dict": data},
            original_response=response_json,
        )
        """
        response
        {
            'object': "list",
            'data': [

            ]
            'model',
            'usage'
        }
        """
        embeddings = response_json["embeddings"]
        output_data = []
        is_embeddings_by_type = (
            response_json.get("response_type") == "embeddings_by_type"
        )

        if isinstance(embeddings, dict):
            is_embeddings_by_type = True

        if is_embeddings_by_type:
            for embedding_type in embeddings:
                for idx, embedding in enumerate(embeddings[embedding_type]):
                    output_data.append(
                        {
                            "object": "embedding",
                            "index": idx,
                            "embedding": embedding,
                            "type": embedding_type,
                        }
                    )
        else:
            for idx, embedding in enumerate(embeddings):
                output_data.append(
                    {"object": "embedding", "index": idx, "embedding": embedding}
                )
        model_response.object = "list"
        model_response.data = output_data
        model_response.model = model
        input_tokens = 0
        for text in input:
            input_tokens += len(encoding.encode(text))

        setattr(
            model_response,
            "usage",
            self._calculate_usage(input, encoding, response_json.get("meta", {})),
        )

        return model_response
@@ -0,0 +1,229 @@
# Cohere Rerank Guardrail Translation Handler

Handler for processing the rerank endpoint (`/v1/rerank`) with guardrails.

## Overview

This handler processes rerank requests by:
1. Extracting the query text from the request
2. Applying guardrails to the query
3. Updating the request with the guardrailed query
4. Returning the output unchanged (rankings are not text)

Note: Documents are not processed by guardrails as they represent the corpus
being searched, not user input. Only the query is guardrailed.
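As an illustration of that flow, the sketch below drives the handler directly. It is illustrative only: the `guardrail` argument stands in for any configured `CustomGuardrail` instance, and the request values are made up.

```python
# Illustrative sketch: exercise the handler outside the proxy.
# `guardrail` is assumed to be a configured CustomGuardrail instance.
from litellm.llms.cohere.rerank.guardrail_translation.handler import (
    CohereRerankHandler,
)


async def guardrail_rerank_request(guardrail) -> dict:
    data = {
        "model": "rerank-english-v3.0",
        "query": "Find documents about john@example.com",
        "documents": ["Doc 1 content.", "Doc 2 content."],  # never guardrailed
    }
    handler = CohereRerankHandler()
    # Only data["query"] may be rewritten; documents pass through untouched.
    return await handler.process_input_messages(data, guardrail_to_apply=guardrail)
```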

## Data Format

### Input Format

**With String Documents:**

```json
{
  "model": "rerank-english-v3.0",
  "query": "What is the capital of France?",
  "documents": [
    "Paris is the capital of France.",
    "Berlin is the capital of Germany.",
    "Madrid is the capital of Spain."
  ],
  "top_n": 2
}
```

**With Dict Documents:**

```json
{
  "model": "rerank-english-v3.0",
  "query": "What is the capital of France?",
  "documents": [
    {"text": "Paris is the capital of France.", "id": "doc1"},
    {"text": "Berlin is the capital of Germany.", "id": "doc2"},
    {"text": "Madrid is the capital of Spain.", "id": "doc3"}
  ],
  "top_n": 2
}
```

### Output Format

```json
{
  "id": "rerank-abc123",
  "results": [
    {"index": 0, "relevance_score": 0.98},
    {"index": 2, "relevance_score": 0.12}
  ],
  "meta": {
    "billed_units": {"search_units": 1}
  }
}
```

## Usage

The handler is automatically discovered and applied when guardrails are used with the rerank endpoint.

### Example: Using Guardrails with Rerank

```bash
curl -X POST 'http://localhost:4000/v1/rerank' \
  -H 'Content-Type: application/json' \
  -H 'Authorization: Bearer your-api-key' \
  -d '{
    "model": "rerank-english-v3.0",
    "query": "What is machine learning?",
    "documents": [
      "Machine learning is a subset of AI.",
      "Deep learning uses neural networks.",
      "Python is a programming language."
    ],
    "guardrails": ["content_filter"],
    "top_n": 2
  }'
```

The guardrail will be applied to the query only (not the documents).

### Example: PII Masking in Query

```bash
curl -X POST 'http://localhost:4000/v1/rerank' \
  -H 'Content-Type: application/json' \
  -H 'Authorization: Bearer your-api-key' \
  -d '{
    "model": "rerank-english-v3.0",
    "query": "Find documents about John Doe from john@example.com",
    "documents": [
      "Document 1 content here.",
      "Document 2 content here.",
      "Document 3 content here."
    ],
    "guardrails": ["mask_pii"],
    "top_n": 3
  }'
```

The query will be masked to: "Find documents about [NAME_REDACTED] from [EMAIL_REDACTED]"

### Example: Mixed Document Types

```bash
curl -X POST 'http://localhost:4000/v1/rerank' \
  -H 'Content-Type: application/json' \
  -H 'Authorization: Bearer your-api-key' \
  -d '{
    "model": "rerank-english-v3.0",
    "query": "Technical documentation",
    "documents": [
      {"text": "This is document 1", "metadata": {"source": "wiki"}},
      {"text": "This is document 2", "metadata": {"source": "docs"}},
      "This is document 3 as a plain string"
    ],
    "guardrails": ["content_moderation"]
  }'
```

## Implementation Details

### Input Processing

- **Query Field**: `query` (string)
  - Processing: Apply guardrail to query text
  - Result: Updated query

- **Documents Field**: `documents` (list)
  - Processing: Not processed (corpus being searched, not user input)
  - Result: Unchanged

### Output Processing

- **Processing**: Not applicable (output contains relevance scores, not text)
- **Result**: Response returned unchanged

## Use Cases

1. **PII Protection**: Remove PII from queries before reranking
2. **Content Filtering**: Filter inappropriate content from search queries
3. **Compliance**: Ensure queries meet policy or regulatory requirements
4. **Data Sanitization**: Clean up query text before semantic search operations

## Extension

Override these methods to customize behavior (a minimal subclass sketch follows this list):

- `process_input_messages()`: Customize how the query is processed
- `process_output_response()`: Currently a no-op, but can be overridden if needed
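
For example, a handler that normalizes the query before the guardrail runs might look like this. It is a sketch only: `NormalizingRerankHandler` is a hypothetical name, and the delegated base behavior is the handler shipped in this commit.

```python
# Sketch: customize query handling by overriding process_input_messages().
from typing import Any, Optional

from litellm.llms.cohere.rerank.guardrail_translation.handler import (
    CohereRerankHandler,
)


class NormalizingRerankHandler(CohereRerankHandler):
    """Hypothetical handler that normalizes the query before guardrails run."""

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: Any,
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        query = data.get("query")
        if isinstance(query, str):
            data["query"] = " ".join(query.split())  # collapse whitespace
        # Delegate to the stock handler, which applies the guardrail to the query.
        return await super().process_input_messages(
            data, guardrail_to_apply, litellm_logging_obj
        )
```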

## Supported Call Types

- `CallTypes.rerank` - Synchronous rerank
- `CallTypes.arerank` - Asynchronous rerank

## Notes

- Only the query is processed by guardrails
- Documents are not processed (they represent the corpus, not user input)
- Output processing is a no-op since rankings don't contain text
- Both sync and async call types use the same handler
- Works with all rerank providers (Cohere, Together AI, etc.)

## Common Patterns

### PII Masking in Search

```python
import litellm

response = litellm.rerank(
    model="rerank-english-v3.0",
    query="Find info about john@example.com",
    documents=[
        "Document 1 content.",
        "Document 2 content.",
        "Document 3 content."
    ],
    guardrails=["mask_pii"],
    top_n=2
)

# Query will have PII masked
# query becomes: "Find info about [EMAIL_REDACTED]"
print(response.results)
```

### Content Filtering

```python
import litellm

response = litellm.rerank(
    model="rerank-english-v3.0",
    query="Search query here",
    documents=[
        {"text": "Document 1 content", "id": "doc1"},
        {"text": "Document 2 content", "id": "doc2"},
    ],
    guardrails=["content_filter"],
)
```

### Async Rerank with Guardrails

```python
import litellm
import asyncio

async def rerank_with_guardrails():
    response = await litellm.arerank(
        model="rerank-english-v3.0",
        query="Technical query",
        documents=["Doc 1", "Doc 2", "Doc 3"],
        guardrails=["sanitize"],
        top_n=2
    )
    return response

result = asyncio.run(rerank_with_guardrails())
```
@@ -0,0 +1,11 @@
"""Cohere Rerank handler for Unified Guardrails."""

from litellm.llms.cohere.rerank.guardrail_translation.handler import CohereRerankHandler
from litellm.types.utils import CallTypes

guardrail_translation_mappings = {
    CallTypes.rerank: CohereRerankHandler,
    CallTypes.arerank: CohereRerankHandler,
}

__all__ = ["guardrail_translation_mappings", "CohereRerankHandler"]
@@ -0,0 +1,107 @@
"""
Cohere Rerank Handler for Unified Guardrails

This module provides guardrail translation support for the rerank endpoint.
The handler processes only the 'query' parameter for guardrails.
"""

from typing import TYPE_CHECKING, Any, Optional

from litellm._logging import verbose_proxy_logger
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
from litellm.types.utils import GenericGuardrailAPIInputs

if TYPE_CHECKING:
    from litellm.integrations.custom_guardrail import CustomGuardrail
    from litellm.types.rerank import RerankResponse


class CohereRerankHandler(BaseTranslation):
    """
    Handler for processing rerank requests with guardrails.

    This class provides methods to:
    1. Process input query (pre-call hook)
    2. Process output response (post-call hook) - not applicable for rerank

    The handler specifically processes:
    - The 'query' parameter (string)

    Note: Documents are not processed by guardrails as they are the corpus
    being searched, not user input.
    """

    async def process_input_messages(
        self,
        data: dict,
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
    ) -> Any:
        """
        Process input query by applying guardrails.

        Args:
            data: Request data dictionary containing 'query'
            guardrail_to_apply: The guardrail instance to apply

        Returns:
            Modified data with guardrails applied to query only
        """
        # Process query only
        query = data.get("query")
        if query is not None and isinstance(query, str):
            inputs = GenericGuardrailAPIInputs(texts=[query])
            # Include model information if available
            model = data.get("model")
            if model:
                inputs["model"] = model
            guardrailed_inputs = await guardrail_to_apply.apply_guardrail(
                inputs=inputs,
                request_data=data,
                input_type="request",
                logging_obj=litellm_logging_obj,
            )
            guardrailed_texts = guardrailed_inputs.get("texts", [])
            data["query"] = guardrailed_texts[0] if guardrailed_texts else query

            verbose_proxy_logger.debug(
                "Rerank: Applied guardrail to query. "
                "Original length: %d, New length: %d",
                len(query),
                len(data["query"]),
            )
        else:
            verbose_proxy_logger.debug(
                "Rerank: No query to process or query is not a string"
            )

        return data

    async def process_output_response(
        self,
        response: "RerankResponse",
        guardrail_to_apply: "CustomGuardrail",
        litellm_logging_obj: Optional[Any] = None,
        user_api_key_dict: Optional[Any] = None,
    ) -> Any:
        """
        Process output response - not applicable for rerank.

        Rerank responses contain relevance scores and indices, not text,
        so there's nothing to apply guardrails to. This method returns
        the response unchanged.

        Args:
            response: Rerank response object with rankings
            guardrail_to_apply: The guardrail instance (unused)
            litellm_logging_obj: Optional logging object (unused)
            user_api_key_dict: User API key metadata (unused)

        Returns:
            Unmodified response (rankings don't need text guardrails)
        """
        verbose_proxy_logger.debug(
            "Rerank: Output processing not applicable "
            "(output contains relevance scores, not text)"
        )
        return response
@@ -0,0 +1,5 @@
"""
Cohere Rerank - uses `llm_http_handler.py` to make httpx requests

Request/Response transformation is handled in `transformation.py`
"""
@@ -0,0 +1,158 @@
from typing import Any, Dict, List, Optional, Union

import httpx

import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.base_llm.chat.transformation import BaseLLMException
from litellm.llms.base_llm.rerank.transformation import BaseRerankConfig
from litellm.secret_managers.main import get_secret_str
from litellm.types.rerank import OptionalRerankParams, RerankRequest, RerankResponse

from ..common_utils import CohereError


class CohereRerankConfig(BaseRerankConfig):
    """
    Reference: https://docs.cohere.com/v2/reference/rerank
    """

    def __init__(self) -> None:
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: Optional[dict] = None,
    ) -> str:
        if api_base:
            # Remove trailing slashes and ensure clean base URL
            api_base = api_base.rstrip("/")
            if not api_base.endswith("/v1/rerank"):
                api_base = f"{api_base}/v1/rerank"
            return api_base
        return "https://api.cohere.ai/v1/rerank"

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return [
            "query",
            "documents",
            "top_n",
            "max_chunks_per_doc",
            "rank_fields",
            "return_documents",
        ]

    def map_cohere_rerank_params(
        self,
        non_default_params: Optional[dict],
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> Dict:
        """
        Map Cohere rerank params

        No mapping required - returns all supported params
        """
        return dict(
            OptionalRerankParams(
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_chunks_per_doc=max_chunks_per_doc,
            )
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        api_key: Optional[str] = None,
        optional_params: Optional[dict] = None,
    ) -> dict:
        if api_key is None:
            api_key = (
                get_secret_str("COHERE_API_KEY")
                or get_secret_str("CO_API_KEY")
                or litellm.cohere_key
            )

        if api_key is None:
            raise ValueError(
                "Cohere API key is required. Please set 'COHERE_API_KEY' or 'CO_API_KEY' or 'litellm.cohere_key'"
            )

        default_headers = {
            "Authorization": f"Bearer {api_key}",
            "accept": "application/json",
            "content-type": "application/json",
        }

        # If 'Authorization' is provided in headers, it overrides the default.
        if "Authorization" in headers:
            default_headers["Authorization"] = headers["Authorization"]

        # Merge other headers, overriding any default ones except Authorization
        return {**default_headers, **headers}

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: Dict,
        headers: dict,
    ) -> dict:
        if "query" not in optional_rerank_params:
            raise ValueError("query is required for Cohere rerank")
        if "documents" not in optional_rerank_params:
            raise ValueError("documents is required for Cohere rerank")
        rerank_request = RerankRequest(
            model=model,
            query=optional_rerank_params["query"],
            documents=optional_rerank_params["documents"],
            top_n=optional_rerank_params.get("top_n", None),
            rank_fields=optional_rerank_params.get("rank_fields", None),
            return_documents=optional_rerank_params.get("return_documents", None),
            max_chunks_per_doc=optional_rerank_params.get("max_chunks_per_doc", None),
        )
        return rerank_request.model_dump(exclude_none=True)

    def transform_rerank_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: RerankResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str] = None,
        request_data: dict = {},
        optional_params: dict = {},
        litellm_params: dict = {},
    ) -> RerankResponse:
        """
        Transform Cohere rerank response

        No transformation required, litellm follows cohere API response format
        """
        try:
            raw_response_json = raw_response.json()
        except Exception:
            raise CohereError(
                message=raw_response.text, status_code=raw_response.status_code
            )

        return RerankResponse(**raw_response_json)

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        return CohereError(message=error_message, status_code=status_code)
@@ -0,0 +1,88 @@
from typing import Any, Dict, List, Optional, Union

from litellm.llms.cohere.rerank.transformation import CohereRerankConfig
from litellm.types.rerank import OptionalRerankParams, RerankRequest


class CohereRerankV2Config(CohereRerankConfig):
    """
    Reference: https://docs.cohere.com/v2/reference/rerank
    """

    def __init__(self) -> None:
        pass

    def get_complete_url(
        self,
        api_base: Optional[str],
        model: str,
        optional_params: Optional[dict] = None,
    ) -> str:
        if api_base:
            # Remove trailing slashes and ensure clean base URL
            api_base = api_base.rstrip("/")
            if not api_base.endswith("/v2/rerank"):
                api_base = f"{api_base}/v2/rerank"
            return api_base
        return "https://api.cohere.ai/v2/rerank"

    def get_supported_cohere_rerank_params(self, model: str) -> list:
        return [
            "query",
            "documents",
            "top_n",
            "max_tokens_per_doc",
            "rank_fields",
            "return_documents",
        ]

    def map_cohere_rerank_params(
        self,
        non_default_params: Optional[dict],
        model: str,
        drop_params: bool,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        custom_llm_provider: Optional[str] = None,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        max_tokens_per_doc: Optional[int] = None,
    ) -> Dict:
        """
        Map Cohere rerank params

        No mapping required - returns all supported params
        """
        return dict(
            OptionalRerankParams(
                query=query,
                documents=documents,
                top_n=top_n,
                rank_fields=rank_fields,
                return_documents=return_documents,
                max_tokens_per_doc=max_tokens_per_doc,
            )
        )

    def transform_rerank_request(
        self,
        model: str,
        optional_rerank_params: Dict,
        headers: dict,
    ) -> dict:
        if "query" not in optional_rerank_params:
            raise ValueError("query is required for Cohere rerank")
        if "documents" not in optional_rerank_params:
            raise ValueError("documents is required for Cohere rerank")
        rerank_request = RerankRequest(
            model=model,
            query=optional_rerank_params["query"],
            documents=optional_rerank_params["documents"],
            top_n=optional_rerank_params.get("top_n", None),
            rank_fields=optional_rerank_params.get("rank_fields", None),
            return_documents=optional_rerank_params.get("return_documents", None),
            max_tokens_per_doc=optional_rerank_params.get("max_tokens_per_doc", None),
        )
        return rerank_request.model_dump(exclude_none=True)