chore: initial public snapshot for github upload

This commit is contained in:
Your Name
2026-03-26 20:06:14 +08:00
commit 0e5ecd930e
3497 changed files with 1586236 additions and 0 deletions

View File

@@ -0,0 +1,30 @@
from typing import Optional
from .converse_handler import BedrockConverseLLM
from .invoke_handler import (
AmazonAnthropicClaudeStreamDecoder,
AmazonDeepSeekR1StreamDecoder,
AWSEventStreamDecoder,
BedrockLLM,
)
def get_bedrock_event_stream_decoder(
    invoke_provider: Optional[str], model: str, sync_stream: bool, json_mode: bool
):
    """Return the AWS event-stream decoder matching the Bedrock invoke provider.

    Args:
        invoke_provider: Bedrock invoke provider key (e.g. "anthropic",
            "deepseek_r1"); any other value (or None) selects the generic decoder.
        model: Model name forwarded to the decoder.
        sync_stream: Whether the stream is consumed synchronously.
        json_mode: Whether JSON-mode content handling is enabled
            (only used by the anthropic decoder).

    Returns:
        An ``AWSEventStreamDecoder`` (or provider-specific subclass).
    """
    # Comparing against a string literal is already False when
    # invoke_provider is None, so no separate truthiness check is needed.
    if invoke_provider == "anthropic":
        decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
            model=model,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
    elif invoke_provider == "deepseek_r1":
        decoder = AmazonDeepSeekR1StreamDecoder(
            model=model,
            sync_stream=sync_stream,
        )
    else:
        decoder = AWSEventStreamDecoder(model=model)
    return decoder

View File

@@ -0,0 +1,3 @@
from .transformation import AmazonAgentCoreConfig
__all__ = ["AmazonAgentCoreConfig"]

View File

@@ -0,0 +1,512 @@
import json
from typing import Any, Optional, Union
import httpx
import litellm
from litellm.anthropic_beta_headers_manager import (
update_headers_with_filtered_beta,
)
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
get_async_httpx_client,
)
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper
from ..base_aws_llm import BaseAWSLLM, Credentials
from ..common_utils import BedrockError, _get_all_bedrock_regions
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
def make_sync_call(
    client: Optional[HTTPHandler],
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj: LiteLLMLoggingObject,
    json_mode: Optional[bool] = False,
    fake_stream: bool = False,
    stream_chunk_size: int = 1024,
):
    """Perform a synchronous Bedrock Converse request and return a completion stream.

    When ``fake_stream`` is True the whole response is fetched, transformed into a
    ``ModelResponse``, and wrapped in a ``MockResponseIterator``; otherwise the raw
    HTTP byte stream is decoded incrementally with ``AWSEventStreamDecoder``.

    Raises:
        BedrockError: if the HTTP status code is not 200.
    """
    http_client = client if client is not None else _get_httpx_client()
    response = http_client.post(
        api_base,
        headers=headers,
        data=data,
        stream=not fake_stream,
        logging_obj=logging_obj,
    )
    if response.status_code != 200:
        raise BedrockError(
            status_code=response.status_code, message=str(response.read())
        )
    completion_stream: Any
    if not fake_stream:
        # Real streaming: decode the AWS event stream chunk by chunk.
        stream_decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode)
        completion_stream = stream_decoder.iter_bytes(
            response.iter_bytes(chunk_size=stream_chunk_size)
        )
    else:
        # Fake streaming: transform the full response once, then replay it
        # through a mock iterator so callers still see a stream interface.
        transformed: ModelResponse = litellm.AmazonConverseConfig()._transform_response(
            model=model,
            response=response,
            model_response=litellm.ModelResponse(),
            stream=True,
            logging_obj=logging_obj,
            optional_params={},
            api_key="",
            data=data,
            messages=messages,
            encoding=litellm.encoding,
        )  # type: ignore
        completion_stream = MockResponseIterator(
            model_response=transformed, json_mode=json_mode
        )
    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response="first stream response received",
        additional_args={"complete_input_dict": data},
    )
    return completion_stream
class BedrockConverseLLM(BaseAWSLLM):
    """Handler for the AWS Bedrock Converse API.

    Builds SigV4-signed requests for the ``/converse`` and ``/converse-stream``
    model endpoints and routes between sync/async and streaming/non-streaming
    call paths.
    """
    def __init__(self) -> None:
        super().__init__()
    async def async_streaming(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj,
        stream,
        optional_params: dict,
        litellm_params: dict,
        credentials: Credentials,
        logger_fn=None,
        headers={},  # NOTE(review): mutable default argument, shared across calls — confirm it is never mutated
        client: Optional[AsyncHTTPHandler] = None,
        fake_stream: bool = False,
        json_mode: Optional[bool] = False,
        api_key: Optional[str] = None,
        stream_chunk_size: int = 1024,
    ) -> CustomStreamWrapper:
        """Async streaming Converse call.

        Transforms messages to the Converse request shape, signs the request,
        opens the event stream via ``make_call``, and wraps it in a
        ``CustomStreamWrapper``.
        """
        request_data = await litellm.AmazonConverseConfig()._async_transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        data = json.dumps(request_data)
        # NOTE(review): `headers` is passed as both extra_headers and headers —
        # confirm get_request_headers actually needs both.
        prepped = self.get_request_headers(
            credentials=credentials,
            # falls back to us-west-2 when the caller supplied no region
            aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
            extra_headers=headers,
            endpoint_url=api_base,
            data=data,
            headers=headers,
            api_key=api_key,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": dict(prepped.headers),
            },
        )
        completion_stream = await make_call(
            client=client,
            api_base=api_base,
            headers=dict(prepped.headers),
            data=data,
            model=model,
            messages=messages,
            logging_obj=logging_obj,
            fake_stream=fake_stream,
            json_mode=json_mode,
            stream_chunk_size=stream_chunk_size,
        )
        streaming_response = CustomStreamWrapper(
            completion_stream=completion_stream,
            model=model,
            custom_llm_provider="bedrock",
            logging_obj=logging_obj,
        )
        return streaming_response
    async def async_completion(
        self,
        model: str,
        messages: list,
        api_base: str,
        model_response: ModelResponse,
        timeout: Optional[Union[float, httpx.Timeout]],
        encoding,
        logging_obj: LiteLLMLoggingObject,
        stream,
        optional_params: dict,
        litellm_params: dict,
        credentials: Credentials,
        logger_fn=None,
        headers: dict = {},  # NOTE(review): mutable default argument — confirm it is never mutated
        client: Optional[AsyncHTTPHandler] = None,
        api_key: Optional[str] = None,
    ) -> Union[ModelResponse, CustomStreamWrapper]:
        """Async non-streaming Converse call.

        Signs and POSTs the transformed request, then converts the HTTP
        response into a ``ModelResponse``.

        Raises:
            BedrockError: on non-2xx responses or timeouts.
        """
        request_data = await litellm.AmazonConverseConfig()._async_transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        data = json.dumps(request_data)
        prepped = self.get_request_headers(
            credentials=credentials,
            aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
            extra_headers=headers,
            endpoint_url=api_base,
            data=data,
            headers=headers,
            api_key=api_key,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": api_base,
                "headers": prepped.headers,
            },
        )
        headers = dict(prepped.headers)
        # Reuse the caller's async client when given one; otherwise build a
        # provider-scoped client with the requested timeout.
        if client is None or not isinstance(client, AsyncHTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = get_async_httpx_client(
                params=_params, llm_provider=litellm.LlmProviders.BEDROCK
            )
        else:
            client = client  # type: ignore
        try:
            response = await client.post(
                url=api_base,
                headers=headers,
                data=data,
                logging_obj=logging_obj,
            )  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")
        return litellm.AmazonConverseConfig()._transform_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream if isinstance(stream, bool) else False,
            logging_obj=logging_obj,
            api_key="",
            data=data,
            messages=messages,
            optional_params=optional_params,
            encoding=encoding,
        )
    def completion(  # noqa: PLR0915
        self,
        model: str,
        messages: list,
        api_base: Optional[str],
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        encoding,
        logging_obj: LiteLLMLoggingObject,
        optional_params: dict,
        acompletion: bool,
        timeout: Optional[Union[float, httpx.Timeout]],
        litellm_params: dict,
        logger_fn=None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
        api_key: Optional[str] = None,
    ):
        """Main entry point for Bedrock Converse calls.

        Pops Bedrock-specific parameters out of ``optional_params`` (order
        matters — region resolution reads ``optional_params`` before the pops),
        resolves model id / region / credentials / endpoint, then routes to:
        - ``async_streaming`` / ``async_completion`` when ``acompletion`` is True,
        - a sync streaming call via ``make_sync_call`` when ``stream`` is True,
        - otherwise a sync POST to the ``/converse`` endpoint.
        """
        ## SETUP ##
        stream = optional_params.pop("stream", None)
        stream_chunk_size = optional_params.pop("stream_chunk_size", 1024)
        unencoded_model_id = optional_params.pop("model_id", None)
        fake_stream = optional_params.pop("fake_stream", False)
        # json_mode is read (not popped) so it stays visible to the request transform
        json_mode = optional_params.get("json_mode", False)
        if unencoded_model_id is not None:
            modelId = self.encode_model_id(model_id=unencoded_model_id)
        else:
            # Strip nova spec prefixes before encoding model ID for API URL
            _model_for_id = model
            _stripped = _model_for_id
            for rp in ["bedrock/converse/", "bedrock/", "converse/"]:
                if _stripped.startswith(rp):
                    _stripped = _stripped[len(rp) :]
                    break
            # Strip embedded region prefix (e.g. "bedrock/us-east-1/model" -> "model")
            # and capture it so it can be used as aws_region_name below.
            _region_from_model: Optional[str] = None
            _potential_region = _stripped.split("/", 1)[0]
            if _potential_region in _get_all_bedrock_regions() and "/" in _stripped:
                _region_from_model = _potential_region
                _stripped = _stripped.split("/", 1)[1]
            _model_for_id = _stripped
            for _nova_prefix in ["nova-2/", "nova/"]:
                if _stripped.startswith(_nova_prefix):
                    _model_for_id = _model_for_id.replace(_nova_prefix, "", 1)
                    break
            modelId = self.encode_model_id(model_id=_model_for_id)
            # Inject region extracted from model path so _get_aws_region_name picks it up
            if (
                _region_from_model is not None
                and "aws_region_name" not in optional_params
            ):
                optional_params["aws_region_name"] = _region_from_model
        fake_stream = litellm.AmazonConverseConfig().should_fake_stream(
            fake_stream=fake_stream,
            model=model,
            stream=stream,
            custom_llm_provider="bedrock",
        )
        ### SET REGION NAME ###
        aws_region_name = self._get_aws_region_name(
            optional_params=optional_params,
            model=model,
            model_id=unencoded_model_id,
        )
        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_session_token = optional_params.pop("aws_session_token", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_bedrock_runtime_endpoint = optional_params.pop(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
        aws_external_id = optional_params.pop("aws_external_id", None)
        optional_params.pop("aws_region_name", None)
        litellm_params[
            "aws_region_name"
        ] = aws_region_name  # [DO NOT DELETE] important for async calls
        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
            aws_external_id=aws_external_id,
        )
        ### SET RUNTIME ENDPOINT ###
        endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=aws_region_name,
        )
        # Real streaming uses /converse-stream; fake streaming still hits /converse.
        if (stream is not None and stream is True) and not fake_stream:
            endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
        else:
            endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"
        ## COMPLETION CALL
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}
        # Filter beta headers in HTTP headers before making the request
        headers = update_headers_with_filtered_beta(
            headers=headers, provider="bedrock_converse"
        )
        ### ROUTING (ASYNC, STREAMING, SYNC)
        if acompletion:
            # A sync HTTPHandler cannot be used on the async path; drop it so
            # the async methods create their own client.
            if isinstance(client, HTTPHandler):
                client = None
            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    api_base=proxy_endpoint_url,
                    model_response=model_response,
                    encoding=encoding,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=True,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                    client=client,
                    json_mode=json_mode,
                    fake_stream=fake_stream,
                    credentials=credentials,
                    api_key=api_key,
                    stream_chunk_size=stream_chunk_size,
                )  # type: ignore
            ### ASYNC COMPLETION
            return self.async_completion(
                model=model,
                messages=messages,
                api_base=proxy_endpoint_url,
                model_response=model_response,
                encoding=encoding,
                logging_obj=logging_obj,
                optional_params=optional_params,
                stream=stream,  # type: ignore
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                headers=headers,
                timeout=timeout,
                client=client,
                credentials=credentials,
                api_key=api_key,
            )  # type: ignore
        ## TRANSFORMATION ##
        _data = litellm.AmazonConverseConfig()._transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=extra_headers,
        )
        data = json.dumps(_data)
        prepped = self.get_request_headers(
            credentials=credentials,
            aws_region_name=aws_region_name,
            extra_headers=extra_headers,
            endpoint_url=proxy_endpoint_url,
            data=data,
            headers=headers,
            api_key=api_key,
        )
        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": proxy_endpoint_url,
                "headers": prepped.headers,
            },
        )
        # Sync path needs a sync client — replace a missing or async client.
        if client is None or isinstance(client, AsyncHTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = _get_httpx_client(_params)  # type: ignore
        else:
            client = client
        if stream is not None and stream is True:
            completion_stream = make_sync_call(
                client=(
                    client
                    if client is not None and isinstance(client, HTTPHandler)
                    else None
                ),
                api_base=proxy_endpoint_url,
                headers=prepped.headers,  # type: ignore
                data=data,
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                json_mode=json_mode,
                fake_stream=fake_stream,
                stream_chunk_size=stream_chunk_size,
            )
            streaming_response = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider="bedrock",
                logging_obj=logging_obj,
            )
            return streaming_response
        ### COMPLETION
        try:
            response = client.post(
                url=proxy_endpoint_url,
                headers=prepped.headers,
                data=data,
                logging_obj=logging_obj,
            )  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")
        return litellm.AmazonConverseConfig()._transform_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream if isinstance(stream, bool) else False,
            logging_obj=logging_obj,
            api_key="",
            data=data,
            messages=messages,
            optional_params=optional_params,
            encoding=encoding,
        )

View File

@@ -0,0 +1,5 @@
"""
Uses base_llm_http_handler to call the 'converse like' endpoint.
Relevant issue: https://github.com/BerriAI/litellm/issues/8085
"""

View File

@@ -0,0 +1,3 @@
"""
Uses `converse_transformation.py` to transform the messages to the format required by Bedrock Converse.
"""

View File

@@ -0,0 +1,547 @@
"""
Transformation for Bedrock Invoke Agent
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agent-runtime_InvokeAgent.html
"""
import base64
import json
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import httpx
from litellm._logging import verbose_logger
from litellm._uuid import uuid
from litellm.litellm_core_utils.prompt_templates.common_utils import (
convert_content_list_to_str,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.bedrock.common_utils import BedrockError
from litellm.types.llms.bedrock_invoke_agents import (
InvokeAgentChunkPayload,
InvokeAgentEvent,
InvokeAgentEventHeaders,
InvokeAgentEventList,
InvokeAgentMetadata,
InvokeAgentModelInvocationInput,
InvokeAgentModelInvocationOutput,
InvokeAgentOrchestrationTrace,
InvokeAgentPreProcessingTrace,
InvokeAgentTrace,
InvokeAgentTracePayload,
InvokeAgentUsage,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices, Message, ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonInvokeAgentConfig(BaseConfig, BaseAWSLLM):
    """Config/transformation for the Bedrock InvokeAgent endpoint.

    Parses the AWS binary event stream returned by InvokeAgent into an
    OpenAI-style ``ModelResponse``: chunk events carry the answer text, trace
    events carry token usage and the underlying foundation model name.
    """
    def __init__(self, **kwargs):
        BaseConfig.__init__(self, **kwargs)
        BaseAWSLLM.__init__(self, **kwargs)
    def get_supported_openai_params(self, model: str) -> List[str]:
        """
        This is a base invoke agent model mapping. For Invoke Agent - define a bedrock provider specific config that extends this class.
        Bedrock Invoke Agents has 0 OpenAI compatible params
        As of May 29th, 2025 - they don't support streaming.
        """
        return []
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        This is a base invoke agent model mapping. For Invoke Agent - define a bedrock provider specific config that extends this class.
        """
        return optional_params
    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete url for the request
        """
        ### SET RUNTIME ENDPOINT ###
        aws_bedrock_runtime_endpoint = optional_params.get(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        endpoint_url, _ = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=self._get_aws_region_name(
                optional_params=optional_params, model=model
            ),
            endpoint_type="agent",
        )
        agent_id, agent_alias_id = self._get_agent_id_and_alias_id(model)
        session_id = self._get_session_id(optional_params)
        # InvokeAgent REST path: /agents/{agentId}/agentAliases/{agentAliasId}/sessions/{sessionId}/text
        endpoint_url = f"{endpoint_url}/agents/{agent_id}/agentAliases/{agent_alias_id}/sessions/{session_id}/text"
        return endpoint_url
    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """SigV4-sign the request via ``BaseAWSLLM._sign_request`` for the "bedrock" service."""
        return self._sign_request(
            service_name="bedrock",
            headers=headers,
            optional_params=optional_params,
            request_data=request_data,
            api_base=api_base,
            model=model,
            stream=stream,
            fake_stream=fake_stream,
            api_key=api_key,
        )
    def _get_agent_id_and_alias_id(self, model: str) -> tuple[str, str]:
        """
        model = "agent/L1RT58GYRW/MFPSBCXYTW"
        agent_id = "L1RT58GYRW"
        agent_alias_id = "MFPSBCXYTW"
        """
        # Split the model string by '/' and extract components
        parts = model.split("/")
        if len(parts) != 3 or parts[0] != "agent":
            raise ValueError(
                "Invalid model format. Expected format: 'model=agent/AGENT_ID/ALIAS_ID'"
            )
        return parts[1], parts[2]  # Return (agent_id, agent_alias_id)
    def _get_session_id(self, optional_params: dict) -> str:
        """Return ``sessionID`` from optional_params if set, else a fresh UUID4 string."""
        return optional_params.get("sessionID", None) or str(uuid.uuid4())
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the InvokeAgent request body: the last message's content becomes
        ``inputText`` and tracing is enabled so usage can be extracted later."""
        # use the last message content as the query
        query: str = convert_content_list_to_str(messages[-1])
        return {
            "inputText": query,
            "enableTrace": True,
            # optional_params is spread last, so callers can override inputText/enableTrace
            **optional_params,
        }
    def _parse_aws_event_stream(self, raw_content: bytes) -> InvokeAgentEventList:
        """
        Parse AWS event stream format using boto3/botocore's built-in parser.
        This is the same approach used in the existing AWSEventStreamDecoder.
        """
        try:
            from botocore.eventstream import EventStreamBuffer
            from botocore.parsers import EventStreamJSONParser
        except ImportError:
            raise ImportError("boto3/botocore is required for AWS event stream parsing")
        events: InvokeAgentEventList = []
        parser = EventStreamJSONParser()
        event_stream_buffer = EventStreamBuffer()
        # Add the entire response to the buffer
        event_stream_buffer.add_data(raw_content)
        # Process all events in the buffer
        for event in event_stream_buffer:
            try:
                headers = self._extract_headers_from_event(event)
                event_type = headers.get("event_type", "")
                if event_type == "chunk":
                    # Handle chunk events specially - they contain decoded content, not JSON
                    message = self._parse_message_from_event(event, parser)
                    parsed_event: InvokeAgentEvent = InvokeAgentEvent()
                    if message:
                        # For chunk events, create a payload with the decoded content
                        parsed_event = {
                            "headers": headers,
                            "payload": {
                                "bytes": base64.b64encode(
                                    message.encode("utf-8")
                                ).decode("utf-8")
                            },  # Re-encode for consistency
                        }
                        events.append(parsed_event)
                elif event_type == "trace":
                    # Handle trace events normally - they contain JSON
                    message = self._parse_message_from_event(event, parser)
                    if message:
                        try:
                            event_data = json.loads(message)
                            parsed_event = {
                                "headers": headers,
                                "payload": event_data,
                            }
                            events.append(parsed_event)
                        except json.JSONDecodeError as e:
                            verbose_logger.warning(
                                f"Failed to parse trace event JSON: {e}"
                            )
                else:
                    verbose_logger.debug(f"Unknown event type: {event_type}")
            except Exception as e:
                # Best-effort: a malformed event is logged and skipped, not fatal.
                verbose_logger.error(f"Error processing event: {e}")
                continue
        return events
    def _parse_message_from_event(self, event, parser) -> Optional[str]:
        """Extract message content from an AWS event, adapted from AWSEventStreamDecoder."""
        try:
            response_dict = event.to_response_dict()
            verbose_logger.debug(f"Response dict: {response_dict}")
            # Use the same response shape parsing as the existing decoder
            parsed_response = parser.parse(
                response_dict, self._get_response_stream_shape()
            )
            verbose_logger.debug(f"Parsed response: {parsed_response}")
            if response_dict["status_code"] != 200:
                decoded_body = response_dict["body"].decode()
                # NOTE(review): bytes.decode() always returns str, so the dict
                # branch below appears unreachable — confirm original intent.
                if isinstance(decoded_body, dict):
                    error_message = decoded_body.get("message")
                elif isinstance(decoded_body, str):
                    error_message = decoded_body
                else:
                    error_message = ""
                exception_status = response_dict["headers"].get(":exception-type")
                error_message = exception_status + " " + error_message
                # NOTE(review): this BedrockError is caught by the broad
                # `except Exception` below and converted to None — confirm
                # that swallowing the error is intended.
                raise BedrockError(
                    status_code=response_dict["status_code"],
                    message=(
                        json.dumps(error_message)
                        if isinstance(error_message, dict)
                        else error_message
                    ),
                )
            if "chunk" in parsed_response:
                chunk = parsed_response.get("chunk")
                if not chunk:
                    return None
                return chunk.get("bytes").decode()
            else:
                # No parsed chunk: fall back to the raw event body
                chunk = response_dict.get("body")
                if not chunk:
                    return None
                return chunk.decode()
        except Exception as e:
            verbose_logger.debug(f"Error parsing message from event: {e}")
            return None
    def _extract_headers_from_event(self, event) -> InvokeAgentEventHeaders:
        """Extract headers from an AWS event for categorization."""
        try:
            response_dict = event.to_response_dict()
            headers = response_dict.get("headers", {})
            # Extract the event-type and content-type headers that we care about
            return InvokeAgentEventHeaders(
                event_type=headers.get(":event-type", ""),
                content_type=headers.get(":content-type", ""),
                message_type=headers.get(":message-type", ""),
            )
        except Exception as e:
            verbose_logger.debug(f"Error extracting headers: {e}")
            # Fall back to empty headers so callers treat the event as unknown
            return InvokeAgentEventHeaders(
                event_type="", content_type="", message_type=""
            )
    def _get_response_stream_shape(self):
        """Get the response stream shape for parsing, reusing existing logic."""
        try:
            # Try to reuse the cached shape from the existing decoder
            from litellm.llms.bedrock.chat.invoke_handler import (
                get_response_stream_shape,
            )
            return get_response_stream_shape()
        except ImportError:
            # Fallback: create our own shape
            try:
                from botocore.loaders import Loader
                from botocore.model import ServiceModel
                loader = Loader()
                bedrock_service_dict = loader.load_service_model(
                    "bedrock-runtime", "service-2"
                )
                bedrock_service_model = ServiceModel(bedrock_service_dict)
                return bedrock_service_model.shape_for("ResponseStream")
            except Exception as e:
                verbose_logger.warning(f"Could not load response stream shape: {e}")
                return None
    def _extract_response_content(self, events: InvokeAgentEventList) -> str:
        """Extract the final response content from parsed events."""
        response_parts = []
        for event in events:
            headers = event.get("headers", {})
            payload = event.get("payload")
            event_type = headers.get(
                "event_type"
            )  # Note: using event_type not event-type
            if event_type == "chunk" and payload:
                # Extract base64 encoded content from chunk events
                chunk_payload: InvokeAgentChunkPayload = payload  # type: ignore
                encoded_bytes = chunk_payload.get("bytes", "")
                if encoded_bytes:
                    try:
                        decoded_content = base64.b64decode(encoded_bytes).decode(
                            "utf-8"
                        )
                        response_parts.append(decoded_content)
                    except Exception as e:
                        verbose_logger.warning(f"Failed to decode chunk content: {e}")
        return "".join(response_parts)
    def _extract_usage_info(self, events: InvokeAgentEventList) -> InvokeAgentUsage:
        """Extract token usage information from trace events."""
        usage_info = InvokeAgentUsage(
            inputTokens=0,
            outputTokens=0,
            model=None,
        )
        response_model: Optional[str] = None
        for event in events:
            if not self._is_trace_event(event):
                continue
            trace_data = self._get_trace_data(event)
            if not trace_data:
                continue
            verbose_logger.debug(f"Trace event: {trace_data}")
            # Extract usage from pre-processing trace
            self._extract_and_update_preprocessing_usage(
                trace_data=trace_data,
                usage_info=usage_info,
            )
            # Extract model from orchestration trace
            if response_model is None:
                response_model = self._extract_orchestration_model(trace_data)
        usage_info["model"] = response_model
        return usage_info
    def _is_trace_event(self, event: InvokeAgentEvent) -> bool:
        """Check if the event is a trace event."""
        headers = event.get("headers", {})
        event_type = headers.get("event_type")
        payload = event.get("payload")
        return event_type == "trace" and payload is not None
    def _get_trace_data(self, event: InvokeAgentEvent) -> Optional[InvokeAgentTrace]:
        """Extract trace data from a trace event."""
        payload = event.get("payload")
        if not payload:
            return None
        trace_payload: InvokeAgentTracePayload = payload  # type: ignore
        return trace_payload.get("trace", {})
    def _extract_and_update_preprocessing_usage(
        self, trace_data: InvokeAgentTrace, usage_info: InvokeAgentUsage
    ) -> None:
        """Extract usage information from preprocessing trace.

        Mutates ``usage_info`` in place, accumulating input/output token counts.
        """
        pre_processing: Optional[InvokeAgentPreProcessingTrace] = trace_data.get(
            "preProcessingTrace"
        )
        if not pre_processing:
            return
        model_output: Optional[InvokeAgentModelInvocationOutput] = (
            pre_processing.get("modelInvocationOutput")
            or InvokeAgentModelInvocationOutput()
        )
        if not model_output:
            return
        metadata: Optional[InvokeAgentMetadata] = (
            model_output.get("metadata") or InvokeAgentMetadata()
        )
        if not metadata:
            return
        usage: Optional[Union[InvokeAgentUsage, Dict]] = metadata.get("usage", {})
        if not usage:
            return
        usage_info["inputTokens"] += usage.get("inputTokens", 0)
        usage_info["outputTokens"] += usage.get("outputTokens", 0)
    def _extract_orchestration_model(
        self, trace_data: InvokeAgentTrace
    ) -> Optional[str]:
        """Extract model information from orchestration trace."""
        orchestration_trace: Optional[InvokeAgentOrchestrationTrace] = trace_data.get(
            "orchestrationTrace"
        )
        if not orchestration_trace:
            return None
        model_invocation: Optional[InvokeAgentModelInvocationInput] = (
            orchestration_trace.get("modelInvocationInput")
            or InvokeAgentModelInvocationInput()
        )
        if not model_invocation:
            return None
        return model_invocation.get("foundationModel")
    def _build_model_response(
        self,
        content: str,
        model: str,
        usage_info: InvokeAgentUsage,
        model_response: ModelResponse,
    ) -> ModelResponse:
        """Build the final ModelResponse object.

        Mutates and returns ``model_response`` with a single "stop" choice,
        the model name from trace data (falling back to ``model``), and usage.
        """
        # Create the message content
        message = Message(content=content, role="assistant")
        # Create choices
        choice = Choices(finish_reason="stop", index=0, message=message)
        # Update model response
        model_response.choices = [choice]
        model_response.model = usage_info.get("model", model)
        # Add usage information if available
        if usage_info:
            from litellm.types.utils import Usage
            usage = Usage(
                prompt_tokens=usage_info.get("inputTokens", 0),
                completion_tokens=usage_info.get("outputTokens", 0),
                total_tokens=usage_info.get("inputTokens", 0)
                + usage_info.get("outputTokens", 0),
            )
            setattr(model_response, "usage", usage)
        return model_response
    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Transform the raw InvokeAgent event-stream response into a ModelResponse.

        Raises:
            BedrockError: if the event stream cannot be parsed/processed.
        """
        try:
            # Get the raw binary content
            raw_content = raw_response.content
            verbose_logger.debug(
                f"Processing {len(raw_content)} bytes of AWS event stream data"
            )
            # Parse the AWS event stream format
            events = self._parse_aws_event_stream(raw_content)
            verbose_logger.debug(f"Parsed {len(events)} events from stream")
            # Extract response content from chunk events
            content = self._extract_response_content(events)
            # Extract usage information from trace events
            usage_info = self._extract_usage_info(events)
            # Build and return the model response
            return self._build_model_response(
                content=content,
                model=model,
                usage_info=usage_info,
                model_response=model_response,
            )
        except Exception as e:
            verbose_logger.error(
                f"Error processing Bedrock Invoke Agent response: {str(e)}"
            )
            raise BedrockError(
                message=f"Error processing response: {str(e)}",
                status_code=raw_response.status_code,
            )
    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """No extra auth headers are added here — signing happens in sign_request."""
        return headers
    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Map provider errors to BedrockError."""
        return BedrockError(status_code=status_code, message=error_message)
    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        """Always fake streaming — InvokeAgent does not support streaming (see class notes)."""
        return True

View File

@@ -0,0 +1,99 @@
import types
from typing import List, Optional
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
class AmazonAI21Config(AmazonInvokeConfig, BaseConfig):
"""
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra
Supported Params for the Amazon / AI21 models:
- `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.
- `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.
- `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.
- `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.
- `frequencyPenalty` (object): Placeholder for frequency penalty object.
- `presencePenalty` (object): Placeholder for presence penalty object.
- `countPenalty` (object): Placeholder for count penalty object.
"""
maxTokens: Optional[int] = None
temperature: Optional[float] = None
topP: Optional[float] = None
stopSequences: Optional[list] = None
frequencePenalty: Optional[dict] = None
presencePenalty: Optional[dict] = None
countPenalty: Optional[dict] = None
def __init__(
self,
maxTokens: Optional[int] = None,
temperature: Optional[float] = None,
topP: Optional[float] = None,
stopSequences: Optional[list] = None,
frequencePenalty: Optional[dict] = None,
presencePenalty: Optional[dict] = None,
countPenalty: Optional[dict] = None,
) -> None:
locals_ = locals().copy()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
AmazonInvokeConfig.__init__(self)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not k.startswith("_abc")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self, model: str) -> List:
    """OpenAI params that `map_openai_params` can translate for AI21 models."""
    return list(("max_tokens", "temperature", "top_p", "stream"))
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """Translate OpenAI-style params into AI21's camelCase invoke params."""
    name_map = {
        "max_tokens": "maxTokens",
        "temperature": "temperature",
        "top_p": "topP",
        "stream": "stream",
    }
    for openai_name, ai21_name in name_map.items():
        if openai_name in non_default_params:
            optional_params[ai21_name] = non_default_params[openai_name]
    return optional_params

View File

@@ -0,0 +1,75 @@
import types
from typing import List, Optional
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.llms.cohere.chat.transformation import CohereChatConfig
class AmazonCohereConfig(AmazonInvokeConfig, CohereChatConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command

    Supported Params for the Amazon / Cohere models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `return_likelihood` (string) n/a
    """

    # Class-level defaults; None values never reach the request body.
    max_tokens: Optional[int] = None
    return_likelihood: Optional[str] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        return_likelihood: Optional[str] = None,
    ) -> None:
        """Promote explicitly supplied values to class-level defaults."""
        supplied = locals().copy()
        supplied.pop("self", None)
        for attr_name, attr_value in supplied.items():
            if attr_value is not None:
                setattr(self.__class__, attr_name, attr_value)
        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return public, non-None, non-callable attributes of this class."""
        skipped_value_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config = {}
        for attr_name, attr_value in cls.__dict__.items():
            if attr_name.startswith("__") or attr_name.startswith("_abc"):
                continue
            if isinstance(attr_value, skipped_value_types):
                continue
            if attr_value is None:
                continue
            config[attr_name] = attr_value
        return config

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Delegate to the Cohere chat config for the supported param list."""
        return CohereChatConfig.get_supported_openai_params(self, model=model)

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Delegate OpenAI -> Cohere param mapping to the Cohere chat config."""
        return CohereChatConfig.map_openai_params(
            self,
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

View File

@@ -0,0 +1,135 @@
from typing import Any, List, Optional, cast
from httpx import Response
from litellm import verbose_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
_parse_content_for_reasoning,
)
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
LiteLLMLoggingObj,
)
from litellm.types.llms.bedrock import AmazonDeepSeekR1StreamingResponse
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
ChatCompletionUsageBlock,
Choices,
Delta,
Message,
ModelResponse,
ModelResponseStream,
StreamingChoices,
)
from .amazon_llama_transformation import AmazonLlamaConfig
class AmazonDeepSeekR1Config(AmazonLlamaConfig):
    # DeepSeek R1 on Bedrock invoke reuses the Llama request/response shape;
    # this subclass only post-processes the response to split out reasoning.
    def transform_response(
        self,
        model: str,
        raw_response: Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Extract the reasoning content, and return it as a separate field in the response.
        """
        response = super().transform_response(
            model,
            raw_response,
            model_response,
            logging_obj,
            request_data,
            messages,
            optional_params,
            litellm_params,
            encoding,
            api_key,
            json_mode,
        )
        prompt = cast(Optional[str], request_data.get("prompt"))
        message_content = cast(
            Optional[str], cast(Choices, response.choices[0]).message.get("content")
        )
        # When the prompt was primed with an opening "<think>" tag, the model's
        # output is the continuation of that reasoning block; re-attach the tag
        # so the shared parser can split reasoning from the final answer.
        if prompt and prompt.strip().endswith("<think>") and message_content:
            message_content_with_reasoning_token = "<think>" + message_content
            reasoning, content = _parse_content_for_reasoning(
                message_content_with_reasoning_token
            )
            provider_specific_fields = (
                cast(Choices, response.choices[0]).message.provider_specific_fields
                or {}
            )
            if reasoning:
                provider_specific_fields["reasoning_content"] = reasoning
            # Rebuild the message with reasoning stripped from `content` and
            # surfaced under provider_specific_fields instead.
            message = Message(
                **{
                    **cast(Choices, response.choices[0]).message.model_dump(),
                    "content": content,
                    "provider_specific_fields": provider_specific_fields,
                }
            )
            cast(Choices, response.choices[0]).message = message
        return response
class AmazonDeepseekR1ResponseIterator(BaseModelResponseIterator):
    """Streaming iterator for DeepSeek R1 on Bedrock invoke.

    The model emits its chain-of-thought first; everything streamed before the
    literal "</think>" sentinel chunk is surfaced as `reasoning_content`, and
    everything after it as regular `content`.
    """

    def __init__(self, streaming_response: Any, sync_stream: bool) -> None:
        super().__init__(streaming_response=streaming_response, sync_stream=sync_stream)
        # Flips to True once the "</think>" sentinel chunk is seen.
        self.has_finished_thinking = False

    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        """Convert one raw Bedrock streaming chunk into a ModelResponseStream.

        Chunks received while the model is "thinking" populate
        `reasoning_content`; chunks after the sentinel populate `content`.
        """
        # Fix: removed the original `try: ... except Exception as e: raise e`
        # wrapper -- it was a no-op that only added a frame to tracebacks.
        typed_chunk = AmazonDeepSeekR1StreamingResponse(**chunk)  # type: ignore
        generated_content = typed_chunk["generation"]
        if generated_content == "</think>" and not self.has_finished_thinking:
            verbose_logger.debug(
                "Deepseek r1: </think> received, setting has_finished_thinking to True"
            )
            # The sentinel itself is never shown to the caller.
            generated_content = ""
            self.has_finished_thinking = True

        prompt_token_count = typed_chunk.get("prompt_token_count") or 0
        generation_token_count = typed_chunk.get("generation_token_count") or 0
        usage = ChatCompletionUsageBlock(
            prompt_tokens=prompt_token_count,
            completion_tokens=generation_token_count,
            total_tokens=prompt_token_count + generation_token_count,
        )
        return ModelResponseStream(
            choices=[
                StreamingChoices(
                    finish_reason=typed_chunk["stop_reason"],
                    delta=Delta(
                        content=(
                            generated_content
                            if self.has_finished_thinking
                            else None
                        ),
                        reasoning_content=(
                            generated_content
                            if not self.has_finished_thinking
                            else None
                        ),
                    ),
                )
            ],
            usage=usage,
        )

View File

@@ -0,0 +1,80 @@
import types
from typing import List, Optional
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
class AmazonLlamaConfig(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1

    Supported Params for the Amazon / Meta Llama models:

    - `max_gen_len` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    """

    # Class-level defaults; None values are dropped by get_config().
    max_gen_len: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
        max_gen_len: Optional[int] = None,
    ) -> None:
        # Fix: `max_gen_len` -- the attribute this config declares and the key
        # map_openai_params() produces -- could previously never be set through
        # the constructor, because the first parameter was named
        # `maxTokenCount` (apparently copied from the Titan config).
        # `maxTokenCount` is retained (and still stored, as before) so existing
        # positional and keyword callers keep working; `max_gen_len` is a new,
        # backward-compatible keyword.
        # Also fixed: `topP` was annotated Optional[int] although the class
        # attribute is a float.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return public, non-None, non-callable attributes of this class."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List:
        """OpenAI params that map onto the Llama invoke body."""
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate OpenAI params to Llama invoke params (max_tokens -> max_gen_len)."""
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_gen_len"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["top_p"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params

View File

@@ -0,0 +1,119 @@
import types
from typing import List, Optional, TYPE_CHECKING
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.llms.bedrock.common_utils import BedrockError
if TYPE_CHECKING:
from litellm.types.utils import ModelResponse
class AmazonMistralConfig(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
    Supported Params for the Amazon / Mistral models:

    - `max_tokens` (integer) max tokens,
    - `temperature` (float) temperature for model,
    - `top_p` (float) top p for model
    - `stop` [string] A list of stop sequences that if generated by the model, stops the model from generating further output.
    - `top_k` (float) top k for model
    """

    # Class-level defaults; None values are dropped by get_config().
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    # NOTE(review): `top_k` is documented above but is neither returned by
    # get_supported_openai_params() nor handled in map_openai_params() --
    # confirm whether it should be exposed.
    top_k: Optional[float] = None
    stop: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,  # annotation fixed: attribute is a float
        top_k: Optional[float] = None,
        stop: Optional[List[str]] = None,
    ) -> None:
        # Store every explicitly supplied value on the class so get_config()
        # reflects user-set defaults.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        # Public, non-None, non-callable attributes of this config class.
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def get_supported_openai_params(self, model: str) -> List[str]:
        """OpenAI params that map_openai_params() can translate for Mistral."""
        return ["max_tokens", "temperature", "top_p", "stop", "stream"]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Copy supported OpenAI params into the Mistral invoke body (names are 1:1)."""
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_tokens"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["top_p"] = v
            if k == "stop":
                optional_params["stop"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params

    @staticmethod
    def get_outputText(
        completion_response: dict, model_response: "ModelResponse"
    ) -> str:
        """This function extracts the output text from a bedrock mistral completion.
        As a side effect, it updates the finish reason for a model response.

        Args:
            completion_response: JSON from the completion.
            model_response: ModelResponse

        Returns:
            A string with the response of the LLM
        """
        # Two response shapes are handled: a chat-style body with "choices"
        # and an invoke-style body with "outputs".
        if "choices" in completion_response:
            outputText = completion_response["choices"][0]["message"]["content"]
            model_response.choices[0].finish_reason = completion_response["choices"][0][
                "finish_reason"
            ]
        elif "outputs" in completion_response:
            outputText = completion_response["outputs"][0]["text"]
            model_response.choices[0].finish_reason = completion_response["outputs"][0][
                "stop_reason"
            ]
        else:
            raise BedrockError(
                message="Unexpected mistral completion response", status_code=400
            )
        return outputText

View File

@@ -0,0 +1,266 @@
"""
Transformation for Bedrock Moonshot AI (Kimi K2) models.
Supports the Kimi K2 Thinking model available on Amazon Bedrock.
Model format: bedrock/moonshot.kimi-k2-thinking-v1:0
Reference: https://aws.amazon.com/about-aws/whats-new/2025/12/amazon-bedrock-fully-managed-open-weight-models/
"""
from typing import TYPE_CHECKING, Any, List, Optional, Union
import re
import httpx
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.llms.bedrock.common_utils import BedrockError
from litellm.llms.moonshot.chat.transformation import MoonshotChatConfig
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import Choices
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
from litellm.types.utils import ModelResponse
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonMoonshotConfig(AmazonInvokeConfig, MoonshotChatConfig):
    """
    Configuration for Bedrock Moonshot AI (Kimi K2) models.

    Reference:
    https://aws.amazon.com/about-aws/whats-new/2025/12/amazon-bedrock-fully-managed-open-weight-models/
    https://platform.moonshot.ai/docs/api/chat

    Supported Params for the Amazon / Moonshot models:

    - `max_tokens` (integer) max tokens
    - `temperature` (float) temperature for model (0-1 for Moonshot)
    - `top_p` (float) top p for model
    - `stream` (bool) whether to stream responses
    - `tools` (list) tool definitions (supported on kimi-k2-thinking)
    - `tool_choice` (str|dict) tool choice specification (supported on kimi-k2-thinking)

    NOT Supported on Bedrock:
    - `stop` sequences (Bedrock doesn't support stopSequences field for this model)

    Note: The kimi-k2-thinking model DOES support tool calls, unlike kimi-thinking-preview.
    """

    def __init__(self, **kwargs):
        # Initialize both bases explicitly -- each has its own __init__.
        AmazonInvokeConfig.__init__(self, **kwargs)
        MoonshotChatConfig.__init__(self, **kwargs)

    @property
    def custom_llm_provider(self) -> Optional[str]:
        # Always routed through Bedrock, regardless of the Moonshot base class.
        return "bedrock"

    def _get_model_id(self, model: str) -> str:
        """
        Extract the actual model ID from the LiteLLM model name.

        Removes routing prefixes like:
        - bedrock/invoke/moonshot.kimi-k2-thinking -> moonshot.kimi-k2-thinking
        - invoke/moonshot.kimi-k2-thinking -> moonshot.kimi-k2-thinking
        - moonshot.kimi-k2-thinking -> moonshot.kimi-k2-thinking
        """
        # Remove bedrock/ prefix if present
        if model.startswith("bedrock/"):
            model = model[8:]
        # Remove invoke/ prefix if present
        if model.startswith("invoke/"):
            model = model[7:]
        # Remove any provider prefix (e.g., moonshot/); ARNs are left intact.
        if "/" in model and not model.startswith("arn:"):
            parts = model.split("/", 1)
            if len(parts) == 2:
                model = parts[1]
        return model

    def get_supported_openai_params(self, model: str) -> List[str]:
        """
        Get the supported OpenAI params for Moonshot AI models on Bedrock.

        Bedrock-specific limitations:
        - stopSequences field is not supported on Bedrock (unlike native Moonshot API)
        - functions parameter is not supported (use tools instead)
        - tool_choice doesn't support "required" value

        Note: kimi-k2-thinking DOES support tool calls (unlike kimi-thinking-preview)
        The parent MoonshotChatConfig class handles the kimi-thinking-preview exclusion.
        """
        excluded_params: List[str] = [
            "functions",
            "stop",
        ]  # Bedrock doesn't support stopSequences
        # super(MoonshotChatConfig, self) deliberately skips MoonshotChatConfig's
        # own override and calls the next implementation in the MRO.
        base_openai_params = super(
            MoonshotChatConfig, self
        ).get_supported_openai_params(model=model)
        final_params: List[str] = []
        for param in base_openai_params:
            if param not in excluded_params:
                final_params.append(param)
        return final_params

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI parameters to Moonshot AI parameters for Bedrock.

        Handles Moonshot AI specific limitations:
        - tool_choice doesn't support "required" value
        - Temperature <0.3 limitation for n>1
        - Temperature range is [0, 1] (not [0, 2] like OpenAI)
        """
        return MoonshotChatConfig.map_openai_params(
            self,
            non_default_params=non_default_params,
            optional_params=optional_params,
            model=model,
            drop_params=drop_params,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the request for Bedrock Moonshot AI models.

        Uses the Moonshot transformation logic which handles:
        - Converting content lists to strings (Moonshot doesn't support list format)
        - Adding tool_choice="required" message if needed
        - Temperature and parameter validation
        """
        # Filter out AWS credentials using the existing method from BaseAWSLLM
        self._get_boto_credentials_from_optional_params(optional_params, model)
        # Strip routing prefixes to get the actual model ID
        clean_model_id = self._get_model_id(model)
        # Use Moonshot's transform_request which handles message transformation
        # and tool_choice="required" workaround
        return MoonshotChatConfig.transform_request(
            self,
            model=clean_model_id,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )

    def _extract_reasoning_from_content(
        self, content: str
    ) -> tuple[Optional[str], str]:
        """
        Extract reasoning content from <reasoning> tags in the response.

        Moonshot AI's Kimi K2 Thinking model returns reasoning in <reasoning> tags.
        This method extracts that content and returns it separately.

        Args:
            content: The full content string from the API response

        Returns:
            tuple: (reasoning_content, main_content)
        """
        if not content:
            return None, content
        # Match <reasoning>...</reasoning> tags.
        # NOTE: re.match anchors at the start of the string, so reasoning is
        # only extracted when the tag opens the content.
        reasoning_match = re.match(
            r"<reasoning>(.*?)</reasoning>\s*(.*)", content, re.DOTALL
        )
        if reasoning_match:
            reasoning_content = reasoning_match.group(1).strip()
            main_content = reasoning_match.group(2).strip()
            return reasoning_content, main_content
        return None, content

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: "ModelResponse",
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> "ModelResponse":
        """
        Transform the response from Bedrock Moonshot AI models.

        Moonshot AI uses OpenAI-compatible response format, but returns reasoning
        content in <reasoning> tags. This method:
        1. Calls parent class transformation
        2. Extracts reasoning content from <reasoning> tags
        3. Sets reasoning_content on the message object
        """
        # First, get the standard transformation
        model_response = MoonshotChatConfig.transform_response(
            self,
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )
        # Extract reasoning content from <reasoning> tags
        if model_response.choices and len(model_response.choices) > 0:
            for choice in model_response.choices:
                # Only process Choices (not StreamingChoices) which have message attribute
                if (
                    isinstance(choice, Choices)
                    and choice.message
                    and choice.message.content
                ):
                    (
                        reasoning_content,
                        main_content,
                    ) = self._extract_reasoning_from_content(choice.message.content)
                    if reasoning_content:
                        # Set the reasoning_content field
                        choice.message.reasoning_content = reasoning_content
                        # Update the main content without reasoning tags
                        choice.message.content = main_content
        return model_response

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BedrockError:
        """Return the appropriate error class for Bedrock."""
        return BedrockError(status_code=status_code, message=error_message)

View File

@@ -0,0 +1,120 @@
"""
Handles transforming requests for `bedrock/invoke/{nova} models`
Inherits from `AmazonConverseConfig`
Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
"""
from typing import Any, List, Optional
import httpx
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
from ..converse_transformation import AmazonConverseConfig
from .base_invoke_transformation import AmazonInvokeConfig
class AmazonInvokeNovaConfig(AmazonInvokeConfig, AmazonConverseConfig):
    """
    Config for sending `nova` requests to `/bedrock/invoke/`

    Nova uses the Converse request shape, so this class reuses the Converse
    transformations and then trims the payload to what /invoke/ accepts.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def get_supported_openai_params(self, model: str) -> list:
        """Same OpenAI params as the Converse API."""
        return AmazonConverseConfig.get_supported_openai_params(self, model)

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Delegate OpenAI param mapping to the Converse config."""
        return AmazonConverseConfig.map_openai_params(
            self, non_default_params, optional_params, model, drop_params
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build a Converse-style request, then adapt it for /invoke/."""
        converse_request = AmazonConverseConfig.transform_request(
            self,
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        nova_request = BedrockInvokeNovaRequest(**converse_request)
        # /invoke/ rejects empty `system` blocks and unknown top-level fields.
        self._remove_empty_system_messages(nova_request)
        return self._filter_allowed_fields(nova_request)

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: Logging,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Nova responses are Converse-shaped; reuse the Converse parser."""
        return AmazonConverseConfig.transform_response(
            self,
            model,
            raw_response,
            model_response,
            logging_obj,
            request_data,
            messages,
            optional_params,
            litellm_params,
            encoding,
            api_key,
            json_mode,
        )

    def _filter_allowed_fields(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> dict:
        """
        Filter out fields that are not allowed in the `BedrockInvokeNovaRequest` dataclass.
        """
        permitted = set(BedrockInvokeNovaRequest.__annotations__.keys())
        return {
            field: value
            for field, value in bedrock_invoke_nova_request.items()
            if field in permitted
        }

    def _remove_empty_system_messages(
        self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
    ) -> None:
        """
        In-place remove empty `system` messages from the request.

        /bedrock/invoke/ does not allow empty `system` messages.
        """
        system_block = bedrock_invoke_nova_request.get("system", None)
        if isinstance(system_block, list) and not system_block:
            bedrock_invoke_nova_request.pop("system", None)

View File

@@ -0,0 +1,192 @@
"""
Transformation for Bedrock imported models that use OpenAI Chat Completions format.
Use this for models imported into Bedrock that accept the OpenAI API format.
Model format: bedrock/openai/<model-id>
Example: bedrock/openai/arn:aws:bedrock:us-east-1:123456789012:imported-model/abc123
"""
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
import httpx
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.bedrock.common_utils import BedrockError
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
from litellm.passthrough.utils import CommonUtils
from litellm.types.llms.openai import AllMessageValues
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonBedrockOpenAIConfig(OpenAIGPTConfig, BaseAWSLLM):
    """
    Configuration for Bedrock imported models that use OpenAI Chat Completions format.

    This class handles the transformation of requests and responses for Bedrock
    imported models that accept the OpenAI API format directly.

    Inherits from OpenAIGPTConfig to leverage standard OpenAI parameter handling
    and response transformation, while adding Bedrock-specific URL generation
    and AWS request signing.

    Usage:
        model = "bedrock/openai/arn:aws:bedrock:us-east-1:123456789012:imported-model/abc123"
    """

    def __init__(self, **kwargs):
        # Initialize both bases explicitly -- each has its own __init__.
        OpenAIGPTConfig.__init__(self, **kwargs)
        BaseAWSLLM.__init__(self, **kwargs)

    @property
    def custom_llm_provider(self) -> Optional[str]:
        return "bedrock"

    def _get_openai_model_id(self, model: str) -> str:
        """
        Extract the actual model ID from the LiteLLM model name.

        Input format: bedrock/openai/<model-id>
        Returns: <model-id>
        """
        # Remove bedrock/ prefix if present (8 == len("bedrock/"))
        if model.startswith("bedrock/"):
            model = model[8:]
        # Remove openai/ prefix (7 == len("openai/"))
        if model.startswith("openai/"):
            model = model[7:]
        return model

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for the Bedrock invoke endpoint.

        Uses the standard Bedrock invoke endpoint format.
        """
        model_id = self._get_openai_model_id(model)
        # Get AWS region
        aws_region_name = self._get_aws_region_name(
            optional_params=optional_params, model=model
        )
        # Get runtime endpoint
        aws_bedrock_runtime_endpoint = optional_params.get(
            "aws_bedrock_runtime_endpoint", None
        )
        endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=aws_region_name,
        )
        # Encode model ID for ARNs (e.g., :imported-model/ -> :imported-model%2F)
        model_id = CommonUtils.encode_bedrock_runtime_modelid_arn(model_id)
        # Build the invoke URL; streaming uses a dedicated endpoint path.
        if stream:
            endpoint_url = (
                f"{endpoint_url}/model/{model_id}/invoke-with-response-stream"
            )
        else:
            endpoint_url = f"{endpoint_url}/model/{model_id}/invoke"
        return endpoint_url

    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """
        Sign the request using AWS Signature Version 4.
        """
        return self._sign_request(
            service_name="bedrock",
            headers=headers,
            optional_params=optional_params,
            request_data=request_data,
            api_base=api_base,
            api_key=api_key,
            model=model,
            stream=stream,
            fake_stream=fake_stream,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform the request to OpenAI Chat Completions format for Bedrock imported models.

        Removes AWS-specific params and stream param (handled separately in URL),
        then delegates to parent class for standard OpenAI request transformation.
        """
        # Remove stream from optional_params as it's handled separately in URL
        optional_params.pop("stream", None)
        # Remove AWS-specific params that shouldn't be in the request body
        inference_params = {
            k: v
            for k, v in optional_params.items()
            if k not in self.aws_authentication_params
        }
        # Use parent class transform_request for OpenAI format
        return super().transform_request(
            model=self._get_openai_model_id(model),
            messages=messages,
            optional_params=inference_params,
            litellm_params=litellm_params,
            headers=headers,
        )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate the environment and return headers.

        For Bedrock, we don't need Bearer token auth since we use AWS SigV4.
        """
        return headers

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BedrockError:
        """Return the appropriate error class for Bedrock."""
        return BedrockError(status_code=status_code, message=error_message)

View File

@@ -0,0 +1,99 @@
"""
Handles transforming requests for `bedrock/invoke/{qwen2} models`
Inherits from `AmazonQwen3Config` since Qwen2 and Qwen3 architectures are mostly similar.
The main difference is in the response format: Qwen2 uses "text" field while Qwen3 uses "generation" field.
Qwen2 + Invoke API Tutorial: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
"""
from typing import Any, List, Optional
import httpx
from litellm.llms.bedrock.chat.invoke_transformations.amazon_qwen3_transformation import (
AmazonQwen3Config,
)
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
class AmazonQwen2Config(AmazonQwen3Config):
    """
    Config for sending `qwen2` requests to `/bedrock/invoke/`

    Inherits from AmazonQwen3Config since Qwen2 and Qwen3 architectures are mostly similar.
    The main difference is in the response format: Qwen2 uses "text" field while Qwen3 uses "generation" field.

    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
    """

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform Qwen2 Bedrock response to OpenAI format
        Qwen2 uses "text" field, but we also support "generation" field for compatibility.
        """
        try:
            # Accept either an httpx.Response or an already-decoded mapping.
            if hasattr(raw_response, "json"):
                response_data = raw_response.json()
            else:
                response_data = raw_response
            # Extract the generated text - Qwen2 uses "text" field, but also support "generation" for compatibility
            generated_text = response_data.get("generation", "") or response_data.get(
                "text", ""
            )
            # Clean up the response (remove assistant start token if present)
            if generated_text.startswith("<|im_start|>assistant\n"):
                generated_text = generated_text[len("<|im_start|>assistant\n") :]
            if generated_text.endswith("<|im_end|>"):
                generated_text = generated_text[: -len("<|im_end|>")]
            # Set the content in the existing model_response structure
            if hasattr(model_response, "choices") and len(model_response.choices) > 0:
                choice = model_response.choices[0]
                choice.message.content = generated_text
                choice.finish_reason = "stop"
            # Set usage information if available in response
            if "usage" in response_data:
                usage_data = response_data["usage"]
                setattr(
                    model_response,
                    "usage",
                    Usage(
                        prompt_tokens=usage_data.get("prompt_tokens", 0),
                        completion_tokens=usage_data.get("completion_tokens", 0),
                        total_tokens=usage_data.get("total_tokens", 0),
                    ),
                )
            return model_response
        except Exception as e:
            # Surface the failure through the logging hook before re-raising.
            if logging_obj:
                logging_obj.post_call(
                    input=messages,
                    api_key=api_key,
                    original_response=raw_response,
                    additional_args={"error": str(e)},
                )
            raise e

View File

@@ -0,0 +1,225 @@
"""
Handles transforming requests for `bedrock/invoke/{qwen3} models`
Inherits from `AmazonInvokeConfig`
Qwen3 + Invoke API Tutorial: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
"""
from typing import Any, List, Optional
import httpx
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
LiteLLMLoggingObj,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
class AmazonQwen3Config(AmazonInvokeConfig, BaseConfig):
"""
Config for sending `qwen3` requests to `/bedrock/invoke/`
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
"""
# Class-level defaults for the Qwen3 invoke body; None values are never sent.
max_tokens: Optional[int] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop: Optional[List[str]] = None
def __init__(
    self,
    max_tokens: Optional[int] = None,
    temperature: Optional[float] = None,
    top_p: Optional[float] = None,
    top_k: Optional[int] = None,
    stop: Optional[List[str]] = None,
) -> None:
    """Store any explicitly supplied values as class-level defaults."""
    provided = locals().copy()
    provided.pop("self", None)
    for attr_name, attr_value in provided.items():
        if attr_value is not None:
            setattr(self.__class__, attr_name, attr_value)
    AmazonInvokeConfig.__init__(self)
def get_supported_openai_params(self, model: str) -> List[str]:
    """OpenAI params this config can translate for Qwen3 invoke requests."""
    return list(
        ("max_tokens", "temperature", "top_p", "top_k", "stop", "stream")
    )
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """Copy supported OpenAI params into `optional_params` (names map 1:1 here)."""
    passthrough_keys = (
        "max_tokens",
        "temperature",
        "top_p",
        "top_k",
        "stop",
        "stream",
    )
    for key in passthrough_keys:
        if key in non_default_params:
            optional_params[key] = non_default_params[key]
    return optional_params
def transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """
    Transform OpenAI format to Qwen3 Bedrock invoke format
    """
    # Render the chat messages into a single ChatML prompt string.
    request_body: dict = {"prompt": self._convert_messages_to_prompt(messages)}
    # Note: OpenAI's `max_tokens` becomes `max_gen_len` in the invoke body;
    # all other supported params keep their names.
    body_key_map = {
        "max_tokens": "max_gen_len",
        "temperature": "temperature",
        "top_p": "top_p",
        "top_k": "top_k",
        "stop": "stop",
    }
    for openai_key, body_key in body_key_map.items():
        if openai_key in optional_params:
            request_body[body_key] = optional_params[openai_key]
    return request_body
def _convert_messages_to_prompt(self, messages: List[AllMessageValues]) -> str:
"""
Convert OpenAI messages format to Qwen3 prompt format
Supports tool calls, multimodal content, and various message types
"""
prompt_parts = []
for message in messages:
role = message.get("role", "")
content = message.get("content", "")
tool_calls = message.get("tool_calls", [])
if role == "system":
prompt_parts.append(f"<|im_start|>system\n{content}<|im_end|>")
elif role == "user":
# Handle multimodal content
if isinstance(content, list):
text_content = []
for item in content:
if item.get("type") == "text":
text_content.append(item.get("text", ""))
elif item.get("type") == "image_url":
# For Qwen3, we can include image placeholders
text_content.append(
"<|vision_start|><|image_pad|><|vision_end|>"
)
content = "".join(text_content)
prompt_parts.append(f"<|im_start|>user\n{content}<|im_end|>")
elif role == "assistant":
if tool_calls and isinstance(tool_calls, list):
# Handle tool calls
for tool_call in tool_calls:
function_name = tool_call.get("function", {}).get("name", "")
function_args = tool_call.get("function", {}).get(
"arguments", ""
)
prompt_parts.append(
f'<|im_start|>assistant\n<tool_call>\n{{"name": "{function_name}", "arguments": "{function_args}"}}\n</tool_call><|im_end|>'
)
else:
prompt_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
elif role == "tool":
# Handle tool responses
prompt_parts.append(f"<|im_start|>tool\n{content}<|im_end|>")
# Add assistant start token for response generation
prompt_parts.append("<|im_start|>assistant\n")
return "\n".join(prompt_parts)
    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform a Qwen3 Bedrock invoke response into the OpenAI-style
        ModelResponse litellm returns to callers.

        Reads the completion from the provider's "generation" field, strips
        any chat-template tokens that leaked into the output, and copies
        provider-reported token usage onto the response when present.
        """
        try:
            # raw_response is normally an httpx.Response, but tolerate a
            # pre-parsed dict as well.
            if hasattr(raw_response, "json"):
                response_data = raw_response.json()
            else:
                response_data = raw_response
            # Qwen3 places the completion text under "generation".
            generated_text = response_data.get("generation", "")
            # Strip ChatML markers if the model echoed them back.
            if generated_text.startswith("<|im_start|>assistant\n"):
                generated_text = generated_text[len("<|im_start|>assistant\n") :]
            if generated_text.endswith("<|im_end|>"):
                generated_text = generated_text[: -len("<|im_end|>")]
            # Write the cleaned text into the pre-built ModelResponse.
            if hasattr(model_response, "choices") and len(model_response.choices) > 0:
                choice = model_response.choices[0]
                choice.message.content = generated_text
                choice.finish_reason = "stop"
            # Copy provider-reported token counts when supplied.
            if "usage" in response_data:
                usage_data = response_data["usage"]
                setattr(
                    model_response,
                    "usage",
                    Usage(
                        prompt_tokens=usage_data.get("prompt_tokens", 0),
                        completion_tokens=usage_data.get("completion_tokens", 0),
                        total_tokens=usage_data.get("total_tokens", 0),
                    ),
                )
            return model_response
        except Exception as e:
            # Surface the failure through the logging callback before
            # re-raising so callers still see the original error.
            if logging_obj:
                logging_obj.post_call(
                    input=messages,
                    api_key=api_key,
                    original_response=raw_response,
                    additional_args={"error": str(e)},
                )
            raise e

View File

@@ -0,0 +1,116 @@
import re
import types
from typing import List, Optional, Union
import litellm
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
class AmazonTitanConfig(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1
    Supported Params for the Amazon Titan models:
    - `maxTokenCount` (integer) max tokens,
    - `stopSequences` (string[]) list of stop sequence strings
    - `temperature` (float) temperature for model,
    - `topP` (int) top p for model
    """

    # Class-level defaults (None means "not set"); __init__ below promotes
    # explicitly-passed values onto the class itself.
    maxTokenCount: Optional[int] = None
    stopSequences: Optional[list] = None
    temperature: Optional[float] = None
    topP: Optional[int] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        stopSequences: Optional[list] = None,
        temperature: Optional[float] = None,
        topP: Optional[int] = None,
    ) -> None:
        # Promote any explicitly-passed kwargs to *class* attributes —
        # litellm's config-class convention for process-wide defaults.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return the non-None, non-callable config attributes as a dict."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not k.startswith("_abc")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    def _map_and_modify_arg(
        self,
        supported_params: dict,
        provider: str,
        model: str,
        stop: Union[List[str], str],
    ):
        """
        filter params to fit the required provider format, drop those that don't fit if user sets `litellm.drop_params = True`.
        """
        filtered_stop = None
        if "stop" in supported_params and litellm.drop_params:
            if provider == "bedrock" and "amazon" in model:
                # Titan only accepts runs of pipes ("|", "||", ...) or
                # "User:" as stop sequences; anything else is dropped.
                filtered_stop = []
                if isinstance(stop, list):
                    for s in stop:
                        if re.match(r"^(\|+|User:)$", s):
                            filtered_stop.append(s)
        if filtered_stop is not None:
            supported_params["stop"] = filtered_stop
        return supported_params

    def get_supported_openai_params(self, model: str) -> List[str]:
        """OpenAI params that can be translated to Titan request params."""
        return [
            "max_tokens",
            "max_completion_tokens",
            "stop",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate OpenAI param names to Titan's camelCase equivalents."""
        for k, v in non_default_params.items():
            if k == "max_tokens" or k == "max_completion_tokens":
                optional_params["maxTokenCount"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "stop":
                # Filtering only happens when litellm.drop_params is set;
                # otherwise the stop list passes through unchanged.
                filtered_stop = self._map_and_modify_arg(
                    {"stop": v}, provider="bedrock", model=model, stop=v
                )
                optional_params["stopSequences"] = filtered_stop["stop"]
            if k == "top_p":
                optional_params["topP"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params

View File

@@ -0,0 +1,280 @@
"""
Transforms OpenAI-style requests into TwelveLabs Pegasus 1.2 requests for Bedrock.
Reference:
https://docs.twelvelabs.io/docs/models/pegasus
"""
import json
import time
from typing import TYPE_CHECKING, Any, Dict, List, Optional
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.base_llm.base_utils import type_to_response_format_param
from litellm.llms.base_llm.chat.transformation import BaseConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.llms.bedrock.common_utils import BedrockError
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import get_base64_str
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonTwelveLabsPegasusConfig(AmazonInvokeConfig, BaseConfig):
    """
    Handles transforming OpenAI-style requests into Bedrock InvokeModel requests for
    `twelvelabs.pegasus-1-2-v1:0`.
    Pegasus 1.2 requires an `inputPrompt` and a `mediaSource` that either references
    an S3 object or a base64-encoded clip. Optional OpenAI params (temperature,
    response_format, max_tokens) are translated to the TwelveLabs schema.
    """

    def get_supported_openai_params(self, model: str) -> List[str]:
        """OpenAI params that can be mapped to the Pegasus request schema."""
        return [
            "max_tokens",
            "max_completion_tokens",
            "temperature",
            "response_format",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate supported OpenAI params to TwelveLabs names."""
        for param, value in non_default_params.items():
            if param in {"max_tokens", "max_completion_tokens"}:
                optional_params["maxOutputTokens"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "response_format":
                optional_params["responseFormat"] = self._normalize_response_format(
                    value
                )
        return optional_params

    def _normalize_response_format(self, value: Any) -> Any:
        """Normalize response_format to TwelveLabs format.
        TwelveLabs expects:
        {
            "jsonSchema": {...}
        }
        But OpenAI format is:
        {
            "type": "json_schema",
            "json_schema": {
                "name": "...",
                "schema": {...}
            }
        }
        """
        if isinstance(value, dict):
            # If it has json_schema field, extract and transform it
            if "json_schema" in value:
                json_schema = value["json_schema"]
                # Extract the schema if nested
                if isinstance(json_schema, dict) and "schema" in json_schema:
                    return {"jsonSchema": json_schema["schema"]}
                # Otherwise use json_schema directly
                return {"jsonSchema": json_schema}
            # If it already has jsonSchema, return as is
            if "jsonSchema" in value:
                return value
            # Otherwise return the dict as is
            return value
        # Non-dict values (e.g. a pydantic model class) go through litellm's
        # generic converter; fall back to the raw value if that yields None.
        return type_to_response_format_param(response_format=value) or value

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the Pegasus invoke body: inputPrompt + optional mediaSource
        and generation params."""
        input_prompt = self._convert_messages_to_prompt(messages=messages)
        request_data: Dict[str, Any] = {"inputPrompt": input_prompt}
        media_source = self._build_media_source(optional_params)
        if media_source is not None:
            request_data["mediaSource"] = media_source
        # Handle temperature and maxOutputTokens
        for key in ("temperature", "maxOutputTokens"):
            if key in optional_params:
                request_data[key] = optional_params.get(key)
        # Handle responseFormat - transform to TwelveLabs format
        if "responseFormat" in optional_params:
            response_format = optional_params["responseFormat"]
            transformed_format = self._normalize_response_format(response_format)
            if transformed_format:
                request_data["responseFormat"] = transformed_format
        return request_data

    def _build_media_source(self, optional_params: dict) -> Optional[dict]:
        """Resolve a Pegasus mediaSource from the several accepted aliases.

        Precedence: explicit mediaSource dict > base64 clip > S3 URI.
        Returns None when no media input was provided.
        """
        direct_source = optional_params.get("mediaSource") or optional_params.get(
            "media_source"
        )
        if isinstance(direct_source, dict):
            return direct_source
        base64_input = optional_params.get("video_base64") or optional_params.get(
            "base64_string"
        )
        if base64_input:
            # Strip any data-URL prefix so only the raw base64 payload is sent.
            return {"base64String": get_base64_str(base64_input)}
        s3_uri = (
            optional_params.get("video_s3_uri")
            or optional_params.get("s3_uri")
            or optional_params.get("media_source_s3_uri")
        )
        if s3_uri:
            s3_location = {"uri": s3_uri}
            bucket_owner = (
                optional_params.get("video_s3_bucket_owner")
                or optional_params.get("s3_bucket_owner")
                or optional_params.get("media_source_bucket_owner")
            )
            if bucket_owner:
                s3_location["bucketOwner"] = bucket_owner
            return {"s3Location": s3_location}
        return None

    def _convert_messages_to_prompt(self, messages: List[AllMessageValues]) -> str:
        """Flatten OpenAI messages to a plain "role: content" prompt string.

        Multimodal parts are replaced with <image>/<video>/<audio> markers.
        """
        prompt_parts: List[str] = []
        for message in messages:
            role = message.get("role", "user")
            content = message.get("content", "")
            if isinstance(content, list):
                text_fragments = []
                for item in content:
                    if isinstance(item, dict):
                        item_type = item.get("type")
                        if item_type == "text":
                            text_fragments.append(item.get("text", ""))
                        elif item_type == "image_url":
                            text_fragments.append("<image>")
                        elif item_type == "video_url":
                            text_fragments.append("<video>")
                        elif item_type == "audio_url":
                            text_fragments.append("<audio>")
                    elif isinstance(item, str):
                        text_fragments.append(item)
                content = " ".join(text_fragments)
            prompt_parts.append(f"{role}: {content}")
        # Empty parts are dropped so blank messages do not add blank lines.
        return "\n".join(part for part in prompt_parts if part).strip()

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform TwelveLabs Pegasus response to LiteLLM format.
        TwelveLabs response format:
        {
            "message": "...",
            "finishReason": "stop" | "length"
        }
        LiteLLM format:
        ModelResponse with choices[0].message.content and finish_reason
        """
        try:
            completion_response = raw_response.json()
        except Exception as e:
            raise BedrockError(
                message=f"Error parsing response: {raw_response.text}, error: {str(e)}",
                status_code=raw_response.status_code,
            )
        verbose_logger.debug(
            "twelvelabs pegasus response: %s",
            json.dumps(completion_response, indent=4, default=str),
        )
        # Extract message content
        message_content = completion_response.get("message", "")
        # Extract finish reason and map to LiteLLM format
        finish_reason_raw = completion_response.get("finishReason", "stop")
        finish_reason = map_finish_reason(finish_reason_raw)
        # Set the response content
        try:
            # Only write content if no tool_calls were already populated on
            # the pre-built response.
            if (
                message_content
                and hasattr(model_response.choices[0], "message")
                and getattr(model_response.choices[0].message, "tool_calls", None)
                is None
            ):
                model_response.choices[0].message.content = message_content  # type: ignore
                model_response.choices[0].finish_reason = finish_reason
            else:
                raise Exception("Unable to set message content")
        except Exception as e:
            raise BedrockError(
                message=f"Error setting response content: {str(e)}. Response: {completion_response}",
                status_code=raw_response.status_code,
            )
        # Calculate usage from headers; fall back to local token counting
        # when Bedrock does not return the usage headers.
        bedrock_input_tokens = raw_response.headers.get(
            "x-amzn-bedrock-input-token-count", None
        )
        bedrock_output_tokens = raw_response.headers.get(
            "x-amzn-bedrock-output-token-count", None
        )
        prompt_tokens = int(
            bedrock_input_tokens or litellm.token_counter(messages=messages)
        )
        completion_tokens = int(
            bedrock_output_tokens
            or litellm.token_counter(
                text=model_response.choices[0].message.content,  # type: ignore
                count_response_tokens=True,
            )
        )
        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response

View File

@@ -0,0 +1,98 @@
import types
from typing import Optional
import litellm
from .base_invoke_transformation import AmazonInvokeConfig
class AmazonAnthropicConfig(AmazonInvokeConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
    Supported Params for the Amazon / Anthropic models:
    - `max_tokens_to_sample` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `top_k` (integer) top k,
    - `top_p` (integer) top p,
    - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
    - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
    """

    # Class-level defaults; __init__ promotes explicitly-passed values onto
    # the class (litellm config-class convention).
    max_tokens_to_sample: Optional[int] = litellm.max_tokens
    stop_sequences: Optional[list] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[float] = None  # annotation fixed: top_p is a 0..1 float
    anthropic_version: Optional[str] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[int] = None,
        stop_sequences: Optional[list] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        # NOTE(review): unlike sibling configs, this __init__ does not call
        # AmazonInvokeConfig.__init__ — confirm whether that is intentional.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        """Return the non-None, non-callable config attributes as a dict."""
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }

    @staticmethod
    def get_legacy_anthropic_model_names():
        """Model IDs that use the legacy (pre-messages) Anthropic API shape."""
        return [
            "anthropic.claude-v2",
            "anthropic.claude-instant-v1",
            "anthropic.claude-v2:1",
        ]

    def get_supported_openai_params(self, model: str):
        """OpenAI params mappable to the legacy Anthropic completion API."""
        return [
            "max_tokens",
            "max_completion_tokens",
            "temperature",
            "stop",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        """Translate OpenAI params to legacy Anthropic completion params."""
        for param, value in non_default_params.items():
            if param == "max_tokens" or param == "max_completion_tokens":
                optional_params["max_tokens_to_sample"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "top_p":
                optional_params["top_p"] = value
            if param == "stop":
                optional_params["stop_sequences"] = value
            if param == "stream" and value is True:
                optional_params["stream"] = value
        return optional_params

View File

@@ -0,0 +1,206 @@
from typing import TYPE_CHECKING, Any, List, Optional
import httpx
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
AmazonInvokeConfig,
)
from litellm.llms.bedrock.common_utils import (
get_anthropic_beta_from_headers,
remove_custom_field_from_tools,
)
from litellm.types.llms.anthropic import ANTHROPIC_TOOL_SEARCH_BETA_HEADER
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class AmazonAnthropicClaudeConfig(AmazonInvokeConfig, AnthropicConfig):
    """
    Reference:
    https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
    https://docs.anthropic.com/claude/docs/models-overview#model-comparison
    https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html
    Supported Params for the Amazon / Anthropic Claude models (Claude 3, Claude 4, etc.):
    Supports anthropic_beta parameter for beta features like:
    - computer-use-2025-01-24 (Claude 3.7 Sonnet)
    - computer-use-2024-10-22 (Claude 3.5 Sonnet v2)
    - token-efficient-tools-2025-02-19 (Claude 3.7 Sonnet)
    - interleaved-thinking-2025-05-14 (Claude 4 models)
    - output-128k-2025-02-19 (Claude 3.7 Sonnet)
    - dev-full-thinking-2025-05-14 (Claude 4 models)
    - context-1m-2025-08-07 (Claude Sonnet 4)
    """

    # Bedrock requires this fixed version string in the request body.
    anthropic_version: str = "bedrock-2023-05-31"

    @property
    def custom_llm_provider(self) -> Optional[str]:
        """Provider label used by the shared Anthropic transformation."""
        return "bedrock"

    def get_supported_openai_params(self, model: str) -> List[str]:
        # Delegates entirely to the shared Anthropic param list.
        return AnthropicConfig.get_supported_openai_params(self, model)

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        # Force tool-based structured outputs for Bedrock Invoke
        # (similar to VertexAI fix in #19201)
        # Bedrock Invoke doesn't support output_format parameter
        original_model = model
        if "response_format" in non_default_params:
            # Use a model name that forces tool-based approach
            model = "claude-3-sonnet-20240229"
        optional_params = AnthropicConfig.map_openai_params(
            self,
            non_default_params,
            optional_params,
            model,
            drop_params,
        )
        # Restore original model name
        model = original_model
        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """Build the Bedrock Invoke request body via the shared Anthropic
        transformation, then strip/adjust fields Bedrock does not accept and
        compute the anthropic_beta list."""
        # Filter out AWS authentication parameters before passing to Anthropic transformation
        # AWS params should only be used for signing requests, not included in request body
        filtered_params = {
            k: v
            for k, v in optional_params.items()
            if k not in self.aws_authentication_params
        }
        filtered_params = self._normalize_bedrock_tool_search_tools(filtered_params)
        _anthropic_request = AnthropicConfig.transform_request(
            self,
            model=model,
            messages=messages,
            optional_params=filtered_params,
            litellm_params=litellm_params,
            headers=headers,
        )
        # model/stream live in the URL / invoke route for Bedrock, not the body.
        _anthropic_request.pop("model", None)
        _anthropic_request.pop("stream", None)
        # Bedrock Invoke doesn't support output_format parameter
        _anthropic_request.pop("output_format", None)
        # Bedrock Invoke doesn't support output_config parameter
        # Fixes: https://github.com/BerriAI/litellm/issues/22797
        _anthropic_request.pop("output_config", None)
        if "anthropic_version" not in _anthropic_request:
            _anthropic_request["anthropic_version"] = self.anthropic_version
        # Remove `custom` field from tools (Bedrock doesn't support it)
        # Claude Code sends `custom: {defer_loading: true}` on tool definitions,
        # which causes Bedrock to reject the request with "Extra inputs are not permitted"
        # Ref: https://github.com/BerriAI/litellm/issues/22847
        remove_custom_field_from_tools(_anthropic_request)
        tools = optional_params.get("tools")
        tool_search_used = self.is_tool_search_used(tools)
        programmatic_tool_calling_used = self.is_programmatic_tool_calling_used(tools)
        input_examples_used = self.is_input_examples_used(tools)
        # Merge caller-supplied beta headers with the betas auto-derived from
        # the request's features (computer use, file IDs, MCP servers, ...).
        beta_set = set(get_anthropic_beta_from_headers(headers))
        auto_betas = self.get_anthropic_beta_list(
            model=model,
            optional_params=optional_params,
            computer_tool_used=self.is_computer_tool_used(tools),
            prompt_caching_set=False,
            file_id_used=self.is_file_id_used(messages),
            mcp_server_used=self.is_mcp_server_used(optional_params.get("mcp_servers")),
        )
        beta_set.update(auto_betas)
        # The tool-search beta is only needed when programmatic tool calling
        # or input examples are also in play; drop it otherwise.
        if tool_search_used and not (
            programmatic_tool_calling_used or input_examples_used
        ):
            beta_set.discard(ANTHROPIC_TOOL_SEARCH_BETA_HEADER)
        if "opus-4" in model.lower() or "opus_4" in model.lower():
            beta_set.add("tool-search-tool-2025-10-19")
        # Filter out beta headers that Bedrock Invoke doesn't support
        # Uses centralized configuration from anthropic_beta_headers_config.json
        beta_list = list(beta_set)
        _anthropic_request["anthropic_beta"] = beta_list
        return _anthropic_request

    def _normalize_bedrock_tool_search_tools(self, optional_params: dict) -> dict:
        """
        Convert tool search entries to the format supported by the Bedrock Invoke API.
        """
        tools = optional_params.get("tools")
        if not tools or not isinstance(tools, list):
            return optional_params
        normalized_tools = []
        for tool in tools:
            tool_type = tool.get("type")
            if tool_type == "tool_search_tool_bm25_20251119":
                # Bedrock Invoke does not support the BM25 variant, so skip it.
                continue
            if tool_type == "tool_search_tool_regex_20251119":
                # Bedrock expects the un-dated type name for the regex variant.
                normalized_tool = tool.copy()
                normalized_tool["type"] = "tool_search_tool_regex"
                normalized_tool["name"] = normalized_tool.get(
                    "name", "tool_search_tool_regex"
                )
                normalized_tools.append(normalized_tool)
                continue
            normalized_tools.append(tool)
        optional_params["tools"] = normalized_tools
        return optional_params

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        # Response parsing is identical to the direct Anthropic API; delegate.
        return AnthropicConfig.transform_response(
            self,
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )

View File

@@ -0,0 +1,613 @@
import copy
import json
import time
from functools import partial
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.litellm_core_utils.prompt_templates.factory import (
cohere_message_pt,
custom_prompt,
deepseek_r1_pt,
prompt_factory,
)
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
from litellm.llms.bedrock.chat.invoke_handler import make_call, make_sync_call
from litellm.llms.bedrock.common_utils import BedrockError
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
_get_httpx_client,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import ModelResponse, Usage
from litellm.utils import CustomStreamWrapper
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
    def __init__(self, **kwargs):
        # Initialize both base classes explicitly with the same kwargs.
        BaseConfig.__init__(self, **kwargs)
        BaseAWSLLM.__init__(self, **kwargs)
def get_supported_openai_params(self, model: str) -> List[str]:
"""
This is a base invoke model mapping. For Invoke - define a bedrock provider specific config that extends this class.
"""
return [
"max_tokens",
"max_completion_tokens",
"stream",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""
This is a base invoke model mapping. For Invoke - define a bedrock provider specific config that extends this class.
"""
for param, value in non_default_params.items():
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value
if param == "stream":
optional_params["stream"] = value
return optional_params
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete url for the request
"""
provider = self.get_bedrock_invoke_provider(model)
modelId = self.get_bedrock_model_id(
model=model,
provider=provider,
optional_params=optional_params,
)
### SET RUNTIME ENDPOINT ###
aws_bedrock_runtime_endpoint = optional_params.get(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
api_base=api_base,
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
aws_region_name=self._get_aws_region_name(
optional_params=optional_params, model=model
),
)
if (stream is not None and stream is True) and provider != "ai21":
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
proxy_endpoint_url = (
f"{proxy_endpoint_url}/model/{modelId}/invoke-with-response-stream"
)
else:
endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"
proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
return endpoint_url
    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """
        SigV4-sign an invoke request for the "bedrock" service.

        Delegates to BaseAWSLLM._sign_request; returns the signed headers and
        the (possibly serialized) request body bytes.
        """
        return self._sign_request(
            service_name="bedrock",
            headers=headers,
            optional_params=optional_params,
            request_data=request_data,
            api_base=api_base,
            api_key=api_key,
            model=model,
            stream=stream,
            fake_stream=fake_stream,
        )
def _apply_config_to_params(self, config: dict, inference_params: dict) -> None:
"""Apply config values to inference_params if not already set."""
for k, v in config.items():
if k not in inference_params:
inference_params[k] = v
    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the Bedrock invoke request body, dispatching on the invoke
        provider parsed from the model name.

        Providers with their own full transformation (anthropic, nova,
        twelvelabs, openai) are delegated to entirely; the rest are built
        here from a prompt plus provider-config defaults.

        Raises:
            BedrockError (404) for providers this route does not support.
        """
        ## SETUP ##
        stream = optional_params.pop("stream", None)
        custom_prompt_dict: dict = litellm_params.pop("custom_prompt_dict", None) or {}
        hf_model_name = litellm_params.get("hf_model_name", None)
        provider = self.get_bedrock_invoke_provider(model)
        prompt, chat_history = self.convert_messages_to_prompt(
            model=hf_model_name or model,
            messages=messages,
            provider=provider,
            custom_prompt_dict=custom_prompt_dict,
        )
        # Work on a copy so caller's optional_params are not mutated, and
        # strip AWS auth params — those are only for request signing.
        inference_params = copy.deepcopy(optional_params)
        inference_params = {
            k: v
            for k, v in inference_params.items()
            if k not in self.aws_authentication_params
        }
        request_data: dict = {}
        if provider == "cohere":
            # command-r uses the chat-style schema; older cohere models use
            # the plain prompt schema.
            if model.startswith("cohere.command-r"):
                ## LOAD CONFIG
                config = litellm.AmazonCohereChatConfig().get_config()
                self._apply_config_to_params(config, inference_params)
                _data = {"message": prompt, **inference_params}
                if chat_history is not None:
                    _data["chat_history"] = chat_history
                request_data = _data
            else:
                ## LOAD CONFIG
                config = litellm.AmazonCohereConfig.get_config()
                self._apply_config_to_params(config, inference_params)
                if stream is True:
                    inference_params[
                        "stream"
                    ] = True  # cohere requires stream = True in inference params
                request_data = {"prompt": prompt, **inference_params}
        elif provider == "anthropic":
            # Full delegation: Claude has its own messages-API transformation.
            transformed_request = (
                litellm.AmazonAnthropicClaudeConfig().transform_request(
                    model=model,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    headers=headers,
                )
            )
            return transformed_request
        elif provider == "nova":
            return litellm.AmazonInvokeNovaConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        elif provider == "ai21":
            ## LOAD CONFIG
            config = litellm.AmazonAI21Config.get_config()
            self._apply_config_to_params(config, inference_params)
            request_data = {"prompt": prompt, **inference_params}
        elif provider == "mistral":
            ## LOAD CONFIG
            config = litellm.AmazonMistralConfig.get_config()
            self._apply_config_to_params(config, inference_params)
            request_data = {"prompt": prompt, **inference_params}
        elif provider == "amazon":  # amazon titan
            ## LOAD CONFIG
            config = litellm.AmazonTitanConfig.get_config()
            self._apply_config_to_params(config, inference_params)
            # Titan nests generation params under textGenerationConfig.
            request_data = {
                "inputText": prompt,
                "textGenerationConfig": inference_params,
            }
        elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
            ## LOAD CONFIG
            config = litellm.AmazonLlamaConfig.get_config()
            self._apply_config_to_params(config, inference_params)
            request_data = {"prompt": prompt, **inference_params}
        elif provider == "twelvelabs":
            return litellm.AmazonTwelveLabsPegasusConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        elif provider == "openai":
            # OpenAI imported models use OpenAI Chat Completions format
            return litellm.AmazonBedrockOpenAIConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        else:
            raise BedrockError(
                status_code=404,
                message="Bedrock Invoke HTTPX: Unknown provider={}, model={}. Try calling via converse route - `bedrock/converse/<model>`.".format(
                    provider, model
                ),
            )
        return request_data
    def transform_response( # noqa: PLR0915
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform a raw Bedrock Invoke API response into a litellm ModelResponse.

        Dispatch is keyed on the provider parsed from `model`:
        - anthropic / nova / twelvelabs delegate entirely to their dedicated
          config classes and return early;
        - cohere / ai21 / meta / llama / deepseek_r1 / mistral / amazon (titan)
          have their completion text extracted here and written into
          `model_response` in place.

        Token usage is read from the `x-amzn-bedrock-*-token-count` response
        headers when present, otherwise recomputed with litellm's token counter.

        Raises:
            BedrockError: if the body is not JSON (original status passed
                through), if the provider payload cannot be parsed (422), or
                if neither output text nor tool calls can be placed on the
                response (original status passed through).
        """
        ## DECODE BODY - surface a non-JSON body as a BedrockError carrying the raw text
        try:
            completion_response = raw_response.json()
        except Exception:
            raise BedrockError(
                message=raw_response.text, status_code=raw_response.status_code
            )
        verbose_logger.debug(
            "bedrock invoke response % s",
            json.dumps(completion_response, indent=4, default=str),
        )
        provider = self.get_bedrock_invoke_provider(model)
        # Text extracted by the provider branches below; remains None for
        # providers that delegate to their own transform_response and return early.
        outputText: Optional[str] = None
        try:
            if provider == "cohere":
                # cohere may return either a flat "text" field or a
                # "generations" list that also carries a finish reason
                if "text" in completion_response:
                    outputText = completion_response["text"]  # type: ignore
                elif "generations" in completion_response:
                    outputText = completion_response["generations"][0]["text"]
                    model_response.choices[0].finish_reason = map_finish_reason(
                        completion_response["generations"][0]["finish_reason"]
                    )
            elif provider == "anthropic":
                # full delegation - anthropic config owns parsing (incl. json_mode)
                return litellm.AmazonAnthropicClaudeConfig().transform_response(
                    model=model,
                    raw_response=raw_response,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    request_data=request_data,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    encoding=encoding,
                    api_key=api_key,
                    json_mode=json_mode,
                )
            elif provider == "nova":
                # full delegation to the nova config
                return litellm.AmazonInvokeNovaConfig().transform_response(
                    model=model,
                    raw_response=raw_response,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    request_data=request_data,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    encoding=encoding,
                )
            elif provider == "twelvelabs":
                # full delegation to the twelvelabs pegasus config
                return litellm.AmazonTwelveLabsPegasusConfig().transform_response(
                    model=model,
                    raw_response=raw_response,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    request_data=request_data,
                    messages=messages,
                    optional_params=optional_params,
                    litellm_params=litellm_params,
                    encoding=encoding,
                    api_key=api_key,
                    json_mode=json_mode,
                )
            elif provider == "ai21":
                outputText = (
                    completion_response.get("completions")[0].get("data").get("text")
                )
            elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
                outputText = completion_response["generation"]
            elif provider == "mistral":
                # helper also mutates model_response (e.g. finish reason) as needed
                outputText = litellm.AmazonMistralConfig.get_outputText(
                    completion_response, model_response
                )
            else:  # amazon titan
                outputText = completion_response.get("results")[0].get("outputText")
        except Exception as e:
            # any KeyError/IndexError/AttributeError above means the payload
            # did not match the expected provider schema
            raise BedrockError(
                message="Error processing={}, Received error={}".format(
                    raw_response.text, str(e)
                ),
                status_code=422,
            )
        ## SET CONTENT - only when tool calls were not already populated
        try:
            if (
                outputText is not None
                and len(outputText) > 0
                and hasattr(model_response.choices[0], "message")
                and getattr(model_response.choices[0].message, "tool_calls", None)  # type: ignore
                is None
            ):
                model_response.choices[0].message.content = outputText  # type: ignore
            elif (
                hasattr(model_response.choices[0], "message")
                and getattr(model_response.choices[0].message, "tool_calls", None)  # type: ignore
                is not None
            ):
                # tool calls already set; leave content untouched
                pass
            else:
                # no usable text and no tool calls -> treat as a parse failure
                raise Exception()
        except Exception as e:
            raise BedrockError(
                message="Error parsing received text={}.\nError-{}".format(
                    outputText, str(e)
                ),
                status_code=raw_response.status_code,
            )
        ## CALCULATING USAGE - bedrock returns usage in the headers
        bedrock_input_tokens = raw_response.headers.get(
            "x-amzn-bedrock-input-token-count", None
        )
        bedrock_output_tokens = raw_response.headers.get(
            "x-amzn-bedrock-output-token-count", None
        )
        # fall back to local token counting when Bedrock omits the headers
        prompt_tokens = int(
            bedrock_input_tokens or litellm.token_counter(messages=messages)
        )
        completion_tokens = int(
            bedrock_output_tokens
            or litellm.token_counter(
                text=model_response.choices[0].message.content,  # type: ignore
                count_response_tokens=True,
            )
        )
        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)
        return model_response
def validate_environment(
self,
headers: dict,
model: str,
messages: List[AllMessageValues],
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
return headers
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
return BedrockError(status_code=status_code, message=error_message)
    @track_llm_api_timing()
    async def get_async_custom_stream_wrapper(
        self,
        model: str,
        custom_llm_provider: str,
        logging_obj: LiteLLMLoggingObj,
        api_base: str,
        headers: dict,
        data: dict,
        messages: list,
        client: Optional[AsyncHTTPHandler] = None,
        json_mode: Optional[bool] = None,
        signed_json_body: Optional[bytes] = None,
    ) -> CustomStreamWrapper:
        """
        Build a CustomStreamWrapper for async Bedrock invoke streaming.

        The HTTP request itself is deferred: `make_call` is bound here via
        functools.partial and only invoked by the wrapper when the stream is
        iterated. `data` is serialized to JSON at binding time.
        """
        # NOTE(review): unlike the sync variant, `signed_json_body` is accepted
        # here but NOT forwarded to `make_call` — confirm whether the async
        # path re-signs the serialized payload itself or this is an oversight.
        streaming_response = CustomStreamWrapper(
            completion_stream=None,
            make_call=partial(
                make_call,
                client=client,
                api_base=api_base,
                headers=headers,
                data=json.dumps(data),
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                # ai21 responses are replayed as a fake stream — presumably
                # because invoke streaming is unavailable for ai21; verify.
                fake_stream=True if "ai21" in api_base else False,
                bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
                json_mode=json_mode,
            ),
            model=model,
            custom_llm_provider="bedrock",
            logging_obj=logging_obj,
        )
        return streaming_response
@track_llm_api_timing()
def get_sync_custom_stream_wrapper(
self,
model: str,
custom_llm_provider: str,
logging_obj: LiteLLMLoggingObj,
api_base: str,
headers: dict,
data: dict,
messages: list,
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
json_mode: Optional[bool] = None,
signed_json_body: Optional[bytes] = None,
) -> CustomStreamWrapper:
if client is None or isinstance(client, AsyncHTTPHandler):
client = _get_httpx_client(params={})
streaming_response = CustomStreamWrapper(
completion_stream=None,
make_call=partial(
make_sync_call,
client=client,
api_base=api_base,
headers=headers,
data=json.dumps(data),
signed_json_body=signed_json_body,
model=model,
messages=messages,
logging_obj=logging_obj,
fake_stream=True if "ai21" in api_base else False,
bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
json_mode=json_mode,
),
model=model,
custom_llm_provider="bedrock",
logging_obj=logging_obj,
)
return streaming_response
    @property
    def has_custom_stream_wrapper(self) -> bool:
        """This config builds its own stream wrappers (get_*_custom_stream_wrapper)."""
        return True
    @property
    def supports_stream_param_in_request_body(self) -> bool:
        """
        Bedrock invoke does not allow passing `stream` in the request body.
        Callers must strip `stream` before serializing the payload — streaming
        is presumably selected by the endpoint variant instead (to confirm:
        invoke-with-response-stream).
        """
        return False
@staticmethod
def get_bedrock_invoke_provider(
model: str,
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
"""
Helper function to get the bedrock provider from the model
handles 4 scenarios:
1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
"""
if model.startswith("invoke/"):
model = model.replace("invoke/", "", 1)
# Special case: Check for "nova" in model name first (before "amazon")
# This handles amazon.nova-* models which would otherwise match "amazon" (Titan)
if "nova" in model.lower():
if "nova" in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, "nova")
_split_model = model.split(".")[0]
if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
# If not a known provider, check for pattern with two slashes
provider = AmazonInvokeConfig._get_provider_from_model_path(model)
if provider is not None:
return provider
for provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
if provider in model:
return provider
return None
@staticmethod
def _get_provider_from_model_path(
model_path: str,
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
"""
Helper function to get the provider from a model path with format: provider/model-name
Args:
model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
Returns:
Optional[str]: The provider name, or None if no valid provider found
"""
parts = model_path.split("/")
if len(parts) >= 1:
provider = parts[0]
if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
return None
def convert_messages_to_prompt(
self, model, messages, provider, custom_prompt_dict
) -> Tuple[str, Optional[list]]:
# handle anthropic prompts and amazon titan prompts
prompt = ""
chat_history: Optional[list] = None
## CUSTOM PROMPT
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details.get(
"initial_prompt_value", ""
),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
messages=messages,
)
return prompt, None
## ELSE
if provider == "anthropic" or provider == "amazon":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "mistral":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "meta" or provider == "llama":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
elif provider == "cohere":
prompt, chat_history = cohere_message_pt(messages=messages)
elif provider == "deepseek_r1":
prompt = deepseek_r1_pt(messages=messages)
else:
prompt = ""
for message in messages:
if "role" in message:
if message["role"] == "user":
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
else:
prompt += f"{message['content']}"
return prompt, chat_history # type: ignore