chore: initial public snapshot for github upload
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,97 @@
|
||||
from openai.types.batch import BatchRequestCounts
|
||||
from openai.types.batch import Metadata as OpenAIBatchMetadata
|
||||
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
|
||||
|
||||
class BedrockBatchesHandler:
    """
    Handler for Bedrock Batches.

    Specific providers/models needed some special handling.

    E.g. Twelve Labs Embedding Async Invoke
    """

    @staticmethod
    def _handle_async_invoke_status(
        batch_id: str, aws_region_name: str, logging_obj=None, **kwargs
    ) -> "LiteLLMBatch":
        """
        Handle async invoke status check for AWS Bedrock.

        This is for Twelve Labs Embedding Async Invoke.

        Args:
            batch_id: The async invoke ARN
            aws_region_name: AWS region name
            logging_obj: Optional LiteLLM logging object, forwarded to the
                embedding handler's status call.
            **kwargs: Additional parameters forwarded verbatim to
                ``_get_async_invoke_status``.

        Returns:
            LiteLLMBatch: Status information including status, the S3 output
            URI (in metadata["output_file_id"]), and request counts.
        """
        import asyncio

        from litellm.llms.bedrock.embed.embedding import BedrockEmbedding

        async def _async_get_status():
            # Create embedding handler instance
            embedding_handler = BedrockEmbedding()

            # Get the status of the async invoke job
            status_response = await embedding_handler._get_async_invoke_status(
                invocation_arn=batch_id,
                aws_region_name=aws_region_name,
                logging_obj=logging_obj,
                **kwargs,
            )

            # Transform response to a LiteLLMBatch object
            from litellm.types.utils import LiteLLMBatch

            # Bedrock-specific fields are tunneled through OpenAI batch metadata
            # so callers can locate the S3 output without a second API call.
            openai_batch_metadata: OpenAIBatchMetadata = {
                "output_file_id": status_response["outputDataConfig"][
                    "s3OutputDataConfig"
                ]["s3Uri"],
                "failure_message": status_response.get("failureMessage") or "",
                "model_arn": status_response["modelArn"],
            }

            # NOTE(review): created_at/in_progress_at are passed through as
            # whatever type `submitTime`/`lastModifiedTime` carry — confirm the
            # embedding handler returns epoch-compatible values.
            result = LiteLLMBatch(
                id=status_response["invocationArn"],
                object="batch",
                status=status_response["status"],
                created_at=status_response["submitTime"],
                in_progress_at=status_response["lastModifiedTime"],
                completed_at=status_response.get("endTime"),
                failed_at=status_response.get("endTime")
                if status_response["status"] == "failed"
                else None,
                # Async invoke handles exactly one item per invocation.
                request_counts=BatchRequestCounts(
                    total=1,
                    completed=1 if status_response["status"] == "completed" else 0,
                    failed=1 if status_response["status"] == "failed" else 0,
                ),
                metadata=openai_batch_metadata,
                completion_window="24h",
                endpoint="/v1/embeddings",
                input_file_id="",
            )

            return result

        # Since this function is called from within an async context via run_in_executor,
        # we need to create a new event loop in a thread to avoid conflicts
        import concurrent.futures

        def run_in_thread():
            # A fresh loop per call: asyncio.run would also work, but this keeps
            # the loop explicitly scoped to the worker thread.
            new_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(new_loop)
            try:
                return new_loop.run_until_complete(_async_get_status())
            finally:
                new_loop.close()

        # Block until the worker thread finishes; exceptions propagate via .result().
        with concurrent.futures.ThreadPoolExecutor() as executor:
            future = executor.submit(run_in_thread)
            return future.result()
|
||||
@@ -0,0 +1,549 @@
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List, Literal, Optional, Union, cast
|
||||
|
||||
from httpx import Headers, Response
|
||||
|
||||
from litellm.llms.base_llm.batches.transformation import BaseBatchesConfig
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.bedrock import (
|
||||
BedrockCreateBatchRequest,
|
||||
BedrockCreateBatchResponse,
|
||||
BedrockInputDataConfig,
|
||||
BedrockOutputDataConfig,
|
||||
BedrockS3InputDataConfig,
|
||||
BedrockS3OutputDataConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
CreateBatchRequest,
|
||||
)
|
||||
from litellm.types.utils import LiteLLMBatch, LlmProviders
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import CommonBatchFilesUtils
|
||||
|
||||
|
||||
class BedrockBatchesConfig(BaseAWSLLM, BaseBatchesConfig):
|
||||
"""
|
||||
Config for Bedrock Batches - handles batch job creation and management for Bedrock
|
||||
"""
|
||||
|
||||
def __init__(self):
    """Initialize AWS auth plumbing (BaseAWSLLM) and shared batch/file helpers."""
    super().__init__()
    # Shared helpers for S3 URI parsing, job naming, and SigV4 signing.
    self.common_utils = CommonBatchFilesUtils()
|
||||
|
||||
@property
def custom_llm_provider(self) -> LlmProviders:
    """Identify this config as the Bedrock provider for routing/logging."""
    return LlmProviders.BEDROCK
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """
    Prepare request headers for a Bedrock batch call.

    AWS credential resolution and SigV4 signing happen in BaseAWSLLM /
    CommonBatchFilesUtils, so no validation is needed here; the incoming
    headers are passed through untouched.
    """
    return headers
|
||||
|
||||
def get_complete_batch_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: Dict,
    litellm_params: Dict,
    data: CreateBatchRequest,
) -> str:
    """
    Build the Bedrock endpoint for creating a batch job.

    Bedrock batch jobs go through the model-invocation-job API, i.e.
    https://bedrock.{region}.amazonaws.com/model-invocation-job — the region
    is resolved from optional_params / the model string.
    """
    region = self._get_aws_region_name(optional_params, model)
    return f"https://bedrock.{region}.amazonaws.com/model-invocation-job"
|
||||
|
||||
def transform_create_batch_request(
    self,
    model: str,
    create_batch_data: CreateBatchRequest,
    optional_params: dict,
    litellm_params: dict,
) -> Dict[str, Any]:
    """
    Transform the batch creation request to Bedrock format.

    Bedrock batch inference requires:
    - modelId: The Bedrock model ID
    - jobName: Unique name for the batch job
    - inputDataConfig: Configuration for input data (S3 location)
    - outputDataConfig: Configuration for output data (S3 location)
    - roleArn: IAM role ARN for the batch job

    Returns a pre-signed request dict (method/url/headers/data) rather than a
    bare payload, because Bedrock needs SigV4-signed headers computed over the
    final body.

    Raises:
        ValueError: If input_file_id, the IAM role ARN, or the model is missing.
    """
    # Get required parameters
    input_file_id = create_batch_data.get("input_file_id")
    if not input_file_id:
        raise ValueError("input_file_id is required for Bedrock batch creation")

    # Extract S3 information from file ID using common utility
    input_bucket, input_key = self.common_utils.parse_s3_uri(input_file_id)

    # Get output S3 configuration
    output_bucket = litellm_params.get("s3_output_bucket_name") or os.getenv(
        "AWS_S3_OUTPUT_BUCKET_NAME"
    )
    if not output_bucket:
        # Use same bucket as input if no output bucket specified
        output_bucket = input_bucket

    # Get IAM role ARN — checked in litellm_params, then optional_params, then env.
    role_arn = (
        litellm_params.get("aws_batch_role_arn")
        or optional_params.get("aws_batch_role_arn")
        or os.getenv("AWS_BATCH_ROLE_ARN")
    )
    if not role_arn:
        raise ValueError(
            "AWS IAM role ARN is required for Bedrock batch jobs. "
            "Set 'aws_batch_role_arn' in litellm_params or AWS_BATCH_ROLE_ARN env var"
        )

    if not model:
        raise ValueError(
            "Could not determine Bedrock model ID. Please pass `model` in your request body."
        )

    # Generate job name with the correct model ID using common utility;
    # outputs are grouped under a per-job S3 prefix.
    job_name = self.common_utils.generate_unique_job_name(model, prefix="litellm")
    output_key = f"litellm-batch-outputs/{job_name}/"

    # Build input data config
    input_data_config: BedrockInputDataConfig = {
        "s3InputDataConfig": BedrockS3InputDataConfig(
            s3Uri=f"s3://{input_bucket}/{input_key}"
        )
    }

    # Build output data config
    s3_output_config: BedrockS3OutputDataConfig = BedrockS3OutputDataConfig(
        s3Uri=f"s3://{output_bucket}/{output_key}"
    )

    # Add optional KMS encryption key ID if provided
    s3_encryption_key_id = litellm_params.get(
        "s3_encryption_key_id"
    ) or get_secret_str("AWS_S3_ENCRYPTION_KEY_ID")
    if s3_encryption_key_id:
        s3_output_config["s3EncryptionKeyId"] = s3_encryption_key_id

    output_data_config: BedrockOutputDataConfig = {
        "s3OutputDataConfig": s3_output_config
    }

    # Create Bedrock batch request with proper typing
    bedrock_request: BedrockCreateBatchRequest = {
        "modelId": model,
        "jobName": job_name,
        "inputDataConfig": input_data_config,
        "outputDataConfig": output_data_config,
        "roleArn": role_arn,
    }

    # Add optional parameters if provided
    completion_window = create_batch_data.get("completion_window")
    if completion_window:
        # Map OpenAI completion window to Bedrock timeout
        # OpenAI uses "24h", Bedrock expects timeout in hours
        # NOTE(review): any other window value is silently dropped — confirm
        # that's intentional.
        if completion_window == "24h":
            bedrock_request["timeoutDurationInHours"] = 24

    # For Bedrock, we need to return a pre-signed request with AWS auth headers
    # Use common utility for AWS signing
    endpoint_url = f"https://bedrock.{self._get_aws_region_name(optional_params, model)}.amazonaws.com/model-invocation-job"
    signed_headers, signed_data = self.common_utils.sign_aws_request(
        service_name="bedrock",
        data=bedrock_request,
        endpoint_url=endpoint_url,
        optional_params=optional_params,
        method="POST",
    )

    # Return a pre-signed request format that the HTTP handler can use.
    # `signed_data` is the exact signed byte payload; it must be sent verbatim
    # or the SigV4 signature breaks.
    return {
        "method": "POST",
        "url": endpoint_url,
        "headers": signed_headers,
        "data": signed_data.decode("utf-8"),
    }
|
||||
|
||||
def transform_create_batch_response(
    self,
    model: Optional[str],
    raw_response: Response,
    logging_obj: Any,
    litellm_params: dict,
) -> LiteLLMBatch:
    """
    Transform Bedrock batch creation response to LiteLLM format.

    Args:
        model: Bedrock model ID (kept for interface parity; unused here).
        raw_response: HTTP response from the CreateModelInvocationJob call.
        logging_obj: LiteLLM logging object (unused here).
        litellm_params: May carry 'original_batch_request' whose endpoint /
            input_file_id / completion_window / metadata are echoed back.

    Returns:
        LiteLLMBatch with the job ARN as id and an OpenAI-compatible status.

    Raises:
        ValueError: If the response body is not valid JSON.
    """
    try:
        response_data: BedrockCreateBatchResponse = raw_response.json()
    except Exception as e:
        # Chain the original error so the JSON failure isn't masked (PEP 3134).
        raise ValueError(f"Failed to parse Bedrock batch response: {e}") from e

    # Extract information from typed Bedrock response
    job_arn = response_data.get("jobArn", "")
    status_str: str = str(response_data.get("status", "Submitted"))

    # Map Bedrock status to OpenAI-compatible status.
    # NOTE(review): "PartiallyCompleted" is reported as "completed" — confirm
    # callers don't need to distinguish partial success.
    status_mapping: Dict[str, str] = {
        "Submitted": "validating",
        "Validating": "validating",
        "Scheduled": "in_progress",
        "InProgress": "in_progress",
        "PartiallyCompleted": "completed",
        "Completed": "completed",
        "Failed": "failed",
        "Stopping": "cancelling",
        "Stopped": "cancelled",
        "Expired": "expired",
    }

    # cast() narrows to the Literal accepted by LiteLLMBatch; unknown Bedrock
    # states conservatively default to "validating".
    openai_status = cast(
        Literal[
            "validating",
            "failed",
            "in_progress",
            "finalizing",
            "completed",
            "expired",
            "cancelling",
            "cancelled",
        ],
        status_mapping.get(status_str, "validating"),
    )

    # Get original request data from litellm_params if available
    original_request = litellm_params.get("original_batch_request", {})

    # Create LiteLLM batch object
    return LiteLLMBatch(
        id=job_arn,  # Use ARN as the batch ID
        object="batch",
        endpoint=original_request.get("endpoint", "/v1/chat/completions"),
        errors=None,
        input_file_id=original_request.get("input_file_id", ""),
        completion_window=original_request.get("completion_window", "24h"),
        status=openai_status,
        output_file_id=None,  # Will be populated when job completes
        error_file_id=None,
        created_at=int(time.time()),
        in_progress_at=int(time.time()) if status_str == "InProgress" else None,
        expires_at=None,
        finalizing_at=None,
        completed_at=None,
        failed_at=None,
        expired_at=None,
        cancelling_at=None,
        cancelled_at=None,
        request_counts=None,
        metadata=original_request.get("metadata", {}),
    )
|
||||
|
||||
def transform_retrieve_batch_request(
    self,
    batch_id: str,
    optional_params: dict,
    litellm_params: dict,
) -> Dict[str, Any]:
    """
    Transform batch retrieval request for Bedrock.

    Args:
        batch_id: Bedrock job ARN
        optional_params: Optional parameters
        litellm_params: LiteLLM parameters

    Returns:
        Transformed request data for Bedrock GetModelInvocationJob API
        (a pre-signed method/url/headers/data dict).

    Raises:
        ValueError: If batch_id is not a well-formed Bedrock job ARN.
    """
    # For Bedrock, batch_id should be the full job ARN
    # The GetModelInvocationJob API expects the full ARN as the identifier.
    # NOTE(review): this prefix check rejects non-standard partitions
    # (arn:aws-us-gov:..., arn:aws-cn:...) — confirm that's acceptable.
    if not batch_id.startswith("arn:aws:bedrock:"):
        raise ValueError(f"Invalid batch_id format. Expected ARN, got: {batch_id}")

    # Extract the job identifier from the ARN - use the full ARN path part
    # ARN format: arn:aws:bedrock:region:account:model-invocation-job/job-name
    arn_parts = batch_id.split(":")
    if len(arn_parts) < 6:
        raise ValueError(f"Invalid ARN format: {batch_id}")

    # The region embedded in the ARN decides which regional endpoint to hit.
    region = arn_parts[3]
    # arn_parts[5] contains "model-invocation-job/{jobId}"

    # Build the endpoint URL for GetModelInvocationJob
    # AWS API format: GET /model-invocation-job/{jobIdentifier}
    # Use the FULL ARN as jobIdentifier and URL-encode it (includes ':' and '/')
    import urllib.parse as _ul

    # safe="" forces ':' and '/' to be percent-encoded too.
    encoded_arn = _ul.quote(batch_id, safe="")
    endpoint_url = (
        f"https://bedrock.{region}.amazonaws.com/model-invocation-job/{encoded_arn}"
    )

    # Use common utility for AWS signing
    signed_headers, _ = self.common_utils.sign_aws_request(
        service_name="bedrock",
        data={},  # GET request has no body
        endpoint_url=endpoint_url,
        optional_params=optional_params,
        method="GET",
    )

    # Return pre-signed request format
    return {
        "method": "GET",
        "url": endpoint_url,
        "headers": signed_headers,
        "data": None,
    }
|
||||
|
||||
def _parse_timestamps_and_status(self, response_data, status_str: str):
|
||||
"""Helper to parse timestamps based on status."""
|
||||
import datetime
|
||||
|
||||
def parse_timestamp(ts_str: Optional[str]) -> Optional[int]:
|
||||
if not ts_str:
|
||||
return None
|
||||
try:
|
||||
dt = datetime.datetime.fromisoformat(ts_str.replace("Z", "+00:00"))
|
||||
return int(dt.timestamp())
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
created_at = parse_timestamp(
|
||||
str(response_data.get("submitTime"))
|
||||
if response_data.get("submitTime") is not None
|
||||
else None
|
||||
)
|
||||
in_progress_states = {"InProgress", "Validating", "Scheduled"}
|
||||
in_progress_at = (
|
||||
parse_timestamp(
|
||||
str(response_data.get("lastModifiedTime"))
|
||||
if response_data.get("lastModifiedTime") is not None
|
||||
else None
|
||||
)
|
||||
if status_str in in_progress_states
|
||||
else None
|
||||
)
|
||||
completed_at = (
|
||||
parse_timestamp(
|
||||
str(response_data.get("endTime"))
|
||||
if response_data.get("endTime") is not None
|
||||
else None
|
||||
)
|
||||
if status_str in {"Completed", "PartiallyCompleted"}
|
||||
else None
|
||||
)
|
||||
failed_at = (
|
||||
parse_timestamp(
|
||||
str(response_data.get("endTime"))
|
||||
if response_data.get("endTime") is not None
|
||||
else None
|
||||
)
|
||||
if status_str == "Failed"
|
||||
else None
|
||||
)
|
||||
cancelled_at = (
|
||||
parse_timestamp(
|
||||
str(response_data.get("endTime"))
|
||||
if response_data.get("endTime") is not None
|
||||
else None
|
||||
)
|
||||
if status_str == "Stopped"
|
||||
else None
|
||||
)
|
||||
expires_at = parse_timestamp(
|
||||
str(response_data.get("jobExpirationTime"))
|
||||
if response_data.get("jobExpirationTime") is not None
|
||||
else None
|
||||
)
|
||||
|
||||
return (
|
||||
created_at,
|
||||
in_progress_at,
|
||||
completed_at,
|
||||
failed_at,
|
||||
cancelled_at,
|
||||
expires_at,
|
||||
)
|
||||
|
||||
def _extract_file_configs(self, response_data):
|
||||
"""Helper to extract input and output file configurations."""
|
||||
# Extract input file ID
|
||||
input_file_id = ""
|
||||
input_data_config = response_data.get("inputDataConfig", {})
|
||||
if isinstance(input_data_config, dict):
|
||||
s3_input_config = input_data_config.get("s3InputDataConfig", {})
|
||||
if isinstance(s3_input_config, dict):
|
||||
input_file_id = s3_input_config.get("s3Uri", "")
|
||||
|
||||
# Extract output file ID
|
||||
output_file_id = None
|
||||
output_data_config = response_data.get("outputDataConfig", {})
|
||||
if isinstance(output_data_config, dict):
|
||||
s3_output_config = output_data_config.get("s3OutputDataConfig", {})
|
||||
if isinstance(s3_output_config, dict):
|
||||
output_file_id = s3_output_config.get("s3Uri", "")
|
||||
|
||||
return input_file_id, output_file_id
|
||||
|
||||
def _extract_errors_and_metadata(self, response_data, raw_response):
|
||||
"""Helper to extract errors and enriched metadata."""
|
||||
# Extract errors
|
||||
message = response_data.get("message")
|
||||
errors = None
|
||||
if message:
|
||||
from openai.types.batch import Errors
|
||||
from openai.types.batch_error import BatchError
|
||||
|
||||
errors = Errors(
|
||||
data=[BatchError(message=message, code=str(raw_response.status_code))],
|
||||
object="list",
|
||||
)
|
||||
|
||||
# Enrich metadata with useful Bedrock fields
|
||||
enriched_metadata_raw: Dict[str, Any] = {
|
||||
"jobName": response_data.get("jobName"),
|
||||
"clientRequestToken": response_data.get("clientRequestToken"),
|
||||
"modelId": response_data.get("modelId"),
|
||||
"roleArn": response_data.get("roleArn"),
|
||||
"timeoutDurationInHours": response_data.get("timeoutDurationInHours"),
|
||||
"vpcConfig": response_data.get("vpcConfig"),
|
||||
}
|
||||
import json as _json
|
||||
|
||||
enriched_metadata: Dict[str, str] = {}
|
||||
for _k, _v in enriched_metadata_raw.items():
|
||||
if _v is None:
|
||||
continue
|
||||
if isinstance(_v, (dict, list)):
|
||||
try:
|
||||
enriched_metadata[_k] = _json.dumps(_v)
|
||||
except Exception:
|
||||
enriched_metadata[_k] = str(_v)
|
||||
else:
|
||||
enriched_metadata[_k] = str(_v)
|
||||
|
||||
return errors, enriched_metadata
|
||||
|
||||
def transform_retrieve_batch_response(
    self,
    model: Optional[str],
    raw_response: Response,
    logging_obj: Any,
    litellm_params: dict,
) -> LiteLLMBatch:
    """
    Transform Bedrock batch retrieval response to LiteLLM format.

    Args:
        model: Bedrock model ID (kept for interface parity; unused here).
        raw_response: HTTP response from the GetModelInvocationJob call.
        logging_obj: LiteLLM logging object (unused here).
        litellm_params: LiteLLM parameters (unused here).

    Returns:
        LiteLLMBatch with job ARN as id, OpenAI-compatible status, parsed
        timestamps, S3 file ids, and Bedrock details tunneled via metadata.

    Raises:
        ValueError: If the response body is not valid JSON.
    """
    from litellm.types.llms.bedrock import BedrockGetBatchResponse

    try:
        response_data: BedrockGetBatchResponse = raw_response.json()
    except Exception as e:
        # Chain the original error so the JSON failure isn't masked (PEP 3134).
        raise ValueError(f"Failed to parse Bedrock batch response: {e}") from e

    job_arn = response_data.get("jobArn", "")
    status_str: str = str(response_data.get("status", "Submitted"))

    # Map Bedrock status to OpenAI-compatible status (mirrors
    # transform_create_batch_response — keep the two maps in sync).
    status_mapping: Dict[str, str] = {
        "Submitted": "validating",
        "Validating": "validating",
        "Scheduled": "in_progress",
        "InProgress": "in_progress",
        "PartiallyCompleted": "completed",
        "Completed": "completed",
        "Failed": "failed",
        "Stopping": "cancelling",
        "Stopped": "cancelled",
        "Expired": "expired",
    }
    openai_status = cast(
        Literal[
            "validating",
            "failed",
            "in_progress",
            "finalizing",
            "completed",
            "expired",
            "cancelling",
            "cancelled",
        ],
        status_mapping.get(status_str, "validating"),
    )

    # Parse timestamps (status-dependent slots handled by the helper)
    (
        created_at,
        in_progress_at,
        completed_at,
        failed_at,
        cancelled_at,
        expires_at,
    ) = self._parse_timestamps_and_status(response_data, status_str)

    # Extract file configurations (S3 URIs double as file ids)
    input_file_id, output_file_id = self._extract_file_configs(response_data)

    # Extract errors and metadata
    errors, enriched_metadata = self._extract_errors_and_metadata(
        response_data, raw_response
    )

    return LiteLLMBatch(
        id=job_arn,
        object="batch",
        endpoint="/v1/chat/completions",
        errors=errors,
        input_file_id=input_file_id,
        completion_window="24h",
        status=openai_status,
        output_file_id=output_file_id,
        error_file_id=None,
        # Fall back to "now" when Bedrock omits/garbles submitTime.
        created_at=created_at or int(time.time()),
        in_progress_at=in_progress_at,
        expires_at=expires_at,
        finalizing_at=None,
        completed_at=completed_at,
        failed_at=failed_at,
        expired_at=None,
        cancelling_at=None,
        cancelled_at=cancelled_at,
        request_counts=None,
        metadata=enriched_metadata,
    )
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[Dict, Headers]
) -> BaseLLMException:
    """
    Get Bedrock-specific error class using common utility.

    Delegates to CommonBatchFilesUtils so batch endpoints raise the same
    exception type as the rest of the Bedrock integration.
    """
    return self.common_utils.get_error_class(error_message, status_code, headers)
|
||||
@@ -0,0 +1,30 @@
|
||||
from typing import Optional
|
||||
|
||||
from .converse_handler import BedrockConverseLLM
|
||||
from .invoke_handler import (
|
||||
AmazonAnthropicClaudeStreamDecoder,
|
||||
AmazonDeepSeekR1StreamDecoder,
|
||||
AWSEventStreamDecoder,
|
||||
BedrockLLM,
|
||||
)
|
||||
|
||||
|
||||
def get_bedrock_event_stream_decoder(
    invoke_provider: Optional[str], model: str, sync_stream: bool, json_mode: bool
):
    """
    Select the event-stream decoder for a Bedrock invoke provider.

    "anthropic" and "deepseek_r1" get provider-specific decoders; every other
    provider (including None) falls back to the generic AWSEventStreamDecoder.
    """
    if invoke_provider == "anthropic":
        return AmazonAnthropicClaudeStreamDecoder(
            model=model,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )
    if invoke_provider == "deepseek_r1":
        # DeepSeek decoder takes no json_mode — matches upstream signature.
        return AmazonDeepSeekR1StreamDecoder(
            model=model,
            sync_stream=sync_stream,
        )
    return AWSEventStreamDecoder(model=model)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .transformation import AmazonAgentCoreConfig
|
||||
|
||||
__all__ = ["AmazonAgentCoreConfig"]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,512 @@
|
||||
import json
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.anthropic_beta_headers_manager import (
|
||||
update_headers_with_filtered_beta,
|
||||
)
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObject
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM, Credentials
|
||||
from ..common_utils import BedrockError, _get_all_bedrock_regions
|
||||
from .invoke_handler import AWSEventStreamDecoder, MockResponseIterator, make_call
|
||||
|
||||
|
||||
def make_sync_call(
    client: Optional[HTTPHandler],
    api_base: str,
    headers: dict,
    data: str,
    model: str,
    messages: list,
    logging_obj: LiteLLMLoggingObject,
    json_mode: Optional[bool] = False,
    fake_stream: bool = False,
    stream_chunk_size: int = 1024,
):
    """
    Synchronously POST a Converse request and return a completion stream.

    When fake_stream is True the full (non-streamed) response is fetched and
    wrapped in a MockResponseIterator so callers can consume it like a stream;
    otherwise the raw HTTP byte stream is decoded via AWSEventStreamDecoder.

    Raises:
        BedrockError: On any non-200 response.
    """
    if client is None:
        client = _get_httpx_client()  # Create a new client if none provided

    # stream=not fake_stream: a faked stream needs the whole body up front.
    response = client.post(
        api_base,
        headers=headers,
        data=data,
        stream=not fake_stream,
        logging_obj=logging_obj,
    )

    if response.status_code != 200:
        # read() drains the (possibly streaming) body for the error message.
        raise BedrockError(
            status_code=response.status_code, message=str(response.read())
        )

    if fake_stream:
        # Transform the complete response once, then replay it as chunks.
        model_response: (
            ModelResponse
        ) = litellm.AmazonConverseConfig()._transform_response(
            model=model,
            response=response,
            model_response=litellm.ModelResponse(),
            stream=True,
            logging_obj=logging_obj,
            optional_params={},
            api_key="",
            data=data,
            messages=messages,
            encoding=litellm.encoding,
        )  # type: ignore
        completion_stream: Any = MockResponseIterator(
            model_response=model_response, json_mode=json_mode
        )
    else:
        # Decode AWS event-stream frames from the raw byte stream.
        decoder = AWSEventStreamDecoder(model=model, json_mode=json_mode)
        completion_stream = decoder.iter_bytes(
            response.iter_bytes(chunk_size=stream_chunk_size)
        )

    # LOGGING
    logging_obj.post_call(
        input=messages,
        api_key="",
        original_response="first stream response received",
        additional_args={"complete_input_dict": data},
    )

    return completion_stream
|
||||
|
||||
|
||||
class BedrockConverseLLM(BaseAWSLLM):
|
||||
def __init__(self) -> None:
    # No Converse-specific state; BaseAWSLLM handles credential resolution.
    super().__init__()
|
||||
|
||||
async def async_streaming(
    self,
    model: str,
    messages: list,
    api_base: str,
    model_response: ModelResponse,
    timeout: Optional[Union[float, httpx.Timeout]],
    encoding,
    logging_obj,
    stream,
    optional_params: dict,
    litellm_params: dict,
    credentials: Credentials,
    logger_fn=None,
    headers={},
    client: Optional[AsyncHTTPHandler] = None,
    fake_stream: bool = False,
    json_mode: Optional[bool] = False,
    api_key: Optional[str] = None,
    stream_chunk_size: int = 1024,
) -> CustomStreamWrapper:
    """
    Stream a Bedrock Converse completion asynchronously.

    Transforms messages to the Converse payload, SigV4-signs the request, then
    delegates the HTTP streaming to make_call and wraps the resulting chunk
    iterator in a CustomStreamWrapper.

    Note: `headers={}` is a shared mutable default — safe only because it is
    never mutated here.
    """
    request_data = await litellm.AmazonConverseConfig()._async_transform_request(
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        headers=headers,
    )
    data = json.dumps(request_data)

    # SigV4-sign the serialized payload; region falls back to us-west-2.
    prepped = self.get_request_headers(
        credentials=credentials,
        aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
        extra_headers=headers,
        endpoint_url=api_base,
        data=data,
        headers=headers,
        api_key=api_key,
    )

    ## LOGGING
    logging_obj.pre_call(
        input=messages,
        api_key="",
        additional_args={
            "complete_input_dict": data,
            "api_base": api_base,
            "headers": dict(prepped.headers),
        },
    )

    # make_call performs the async HTTP request and yields decoded chunks.
    completion_stream = await make_call(
        client=client,
        api_base=api_base,
        headers=dict(prepped.headers),
        data=data,
        model=model,
        messages=messages,
        logging_obj=logging_obj,
        fake_stream=fake_stream,
        json_mode=json_mode,
        stream_chunk_size=stream_chunk_size,
    )
    streaming_response = CustomStreamWrapper(
        completion_stream=completion_stream,
        model=model,
        custom_llm_provider="bedrock",
        logging_obj=logging_obj,
    )
    return streaming_response
|
||||
|
||||
async def async_completion(
    self,
    model: str,
    messages: list,
    api_base: str,
    model_response: ModelResponse,
    timeout: Optional[Union[float, httpx.Timeout]],
    encoding,
    logging_obj: LiteLLMLoggingObject,
    stream,
    optional_params: dict,
    litellm_params: dict,
    credentials: Credentials,
    logger_fn=None,
    headers: dict = {},
    client: Optional[AsyncHTTPHandler] = None,
    api_key: Optional[str] = None,
) -> Union[ModelResponse, CustomStreamWrapper]:
    """
    Execute a non-streaming Bedrock Converse completion asynchronously.

    Transforms messages to the Converse payload, SigV4-signs the request,
    POSTs it, and transforms the JSON response back to a ModelResponse.

    Raises:
        BedrockError: On HTTP error status (propagated status code) or
            timeout (status 408).
    """
    request_data = await litellm.AmazonConverseConfig()._async_transform_request(
        model=model,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        headers=headers,
    )
    data = json.dumps(request_data)

    # SigV4-sign the serialized payload; region falls back to us-west-2.
    prepped = self.get_request_headers(
        credentials=credentials,
        aws_region_name=litellm_params.get("aws_region_name") or "us-west-2",
        extra_headers=headers,
        endpoint_url=api_base,
        data=data,
        headers=headers,
        api_key=api_key,
    )

    ## LOGGING
    logging_obj.pre_call(
        input=messages,
        api_key="",
        additional_args={
            "complete_input_dict": data,
            "api_base": api_base,
            "headers": prepped.headers,
        },
    )

    # From here on, `headers` holds the signed headers (rebinds the param).
    headers = dict(prepped.headers)
    if client is None or not isinstance(client, AsyncHTTPHandler):
        _params = {}
        if timeout is not None:
            if isinstance(timeout, float) or isinstance(timeout, int):
                timeout = httpx.Timeout(timeout)
            _params["timeout"] = timeout
        client = get_async_httpx_client(
            params=_params, llm_provider=litellm.LlmProviders.BEDROCK
        )
    else:
        client = client  # type: ignore

    try:
        response = await client.post(
            url=api_base,
            headers=headers,
            data=data,
            logging_obj=logging_obj,
        )  # type: ignore
        response.raise_for_status()
    except httpx.HTTPStatusError as err:
        error_code = err.response.status_code
        raise BedrockError(status_code=error_code, message=err.response.text)
    except httpx.TimeoutException:
        raise BedrockError(status_code=408, message="Timeout error occurred.")

    # Non-bool `stream` values are treated as non-streaming.
    return litellm.AmazonConverseConfig()._transform_response(
        model=model,
        response=response,
        model_response=model_response,
        stream=stream if isinstance(stream, bool) else False,
        logging_obj=logging_obj,
        api_key="",
        data=data,
        messages=messages,
        optional_params=optional_params,
        encoding=encoding,
    )
|
||||
|
||||
    def completion(  # noqa: PLR0915
        self,
        model: str,
        messages: list,
        api_base: Optional[str],
        custom_prompt_dict: dict,
        model_response: ModelResponse,
        encoding,
        logging_obj: LiteLLMLoggingObject,
        optional_params: dict,
        acompletion: bool,
        timeout: Optional[Union[float, httpx.Timeout]],
        litellm_params: dict,
        logger_fn=None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
        api_key: Optional[str] = None,
    ):
        """
        Entry point for Bedrock Converse chat completions.

        Resolves the model id / region, builds SigV4-signed request headers, and
        routes to one of four paths: async streaming, async completion, sync
        streaming, or sync completion.

        NOTE: this method deliberately mutates its inputs — `optional_params`
        has stream/model_id/aws_* keys popped off, and `litellm_params` gets
        the resolved `aws_region_name` stored on it. Statement order below is
        load-bearing.

        Args:
            model: Model name, possibly carrying "bedrock/", "converse/",
                region, or nova prefixes that are stripped before URL encoding.
            messages: OpenAI-format chat messages.
            api_base: Optional endpoint override.
            model_response: ModelResponse object to populate (sync path).
            acompletion: If True, return a coroutine (async_streaming /
                async_completion) instead of executing synchronously.
            client: Optional pre-built HTTP client; a sync client is dropped
                on the async path and vice versa.

        Returns:
            ModelResponse, CustomStreamWrapper, or a coroutine resolving to
            one of those, depending on acompletion/stream.

        Raises:
            BedrockError: on non-2xx responses or request timeout (sync path).
        """
        ## SETUP ##
        stream = optional_params.pop("stream", None)
        stream_chunk_size = optional_params.pop("stream_chunk_size", 1024)
        unencoded_model_id = optional_params.pop("model_id", None)
        fake_stream = optional_params.pop("fake_stream", False)
        json_mode = optional_params.get("json_mode", False)
        if unencoded_model_id is not None:
            # Explicit model_id (e.g. a provisioned-throughput ARN) wins over
            # anything derived from the model string.
            modelId = self.encode_model_id(model_id=unencoded_model_id)
        else:
            # Strip nova spec prefixes before encoding model ID for API URL
            _model_for_id = model
            _stripped = _model_for_id
            for rp in ["bedrock/converse/", "bedrock/", "converse/"]:
                if _stripped.startswith(rp):
                    _stripped = _stripped[len(rp) :]
                    break
            # Strip embedded region prefix (e.g. "bedrock/us-east-1/model" -> "model")
            # and capture it so it can be used as aws_region_name below.
            _region_from_model: Optional[str] = None
            _potential_region = _stripped.split("/", 1)[0]
            if _potential_region in _get_all_bedrock_regions() and "/" in _stripped:
                _region_from_model = _potential_region
                _stripped = _stripped.split("/", 1)[1]
            _model_for_id = _stripped
            for _nova_prefix in ["nova-2/", "nova/"]:
                if _stripped.startswith(_nova_prefix):
                    _model_for_id = _model_for_id.replace(_nova_prefix, "", 1)
                    break
            modelId = self.encode_model_id(model_id=_model_for_id)
            # Inject region extracted from model path so _get_aws_region_name picks it up
            if (
                _region_from_model is not None
                and "aws_region_name" not in optional_params
            ):
                optional_params["aws_region_name"] = _region_from_model

        # Some models can't stream natively; fall back to buffering the full
        # response and replaying it as a fake stream.
        fake_stream = litellm.AmazonConverseConfig().should_fake_stream(
            fake_stream=fake_stream,
            model=model,
            stream=stream,
            custom_llm_provider="bedrock",
        )

        ### SET REGION NAME ###
        aws_region_name = self._get_aws_region_name(
            optional_params=optional_params,
            model=model,
            model_id=unencoded_model_id,
        )

        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_session_token = optional_params.pop("aws_session_token", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_bedrock_runtime_endpoint = optional_params.pop(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)
        aws_external_id = optional_params.pop("aws_external_id", None)
        optional_params.pop("aws_region_name", None)

        litellm_params[
            "aws_region_name"
        ] = aws_region_name  # [DO NOT DELETE] important for async calls

        credentials: Credentials = self.get_credentials(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
            aws_external_id=aws_external_id,
        )

        ### SET RUNTIME ENDPOINT ###
        endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=aws_region_name,
        )
        # Real streaming uses the converse-stream endpoint; fake streaming hits
        # the plain converse endpoint and replays the buffered result.
        if (stream is not None and stream is True) and not fake_stream:
            endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream"
            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse-stream"
        else:
            endpoint_url = f"{endpoint_url}/model/{modelId}/converse"
            proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/converse"

        ## COMPLETION CALL
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}

        # Filter beta headers in HTTP headers before making the request
        headers = update_headers_with_filtered_beta(
            headers=headers, provider="bedrock_converse"
        )
        ### ROUTING (ASYNC, STREAMING, SYNC)
        if acompletion:
            # A sync HTTPHandler can't be awaited — drop it and let the async
            # helpers build their own client.
            if isinstance(client, HTTPHandler):
                client = None
            if stream is True:
                return self.async_streaming(
                    model=model,
                    messages=messages,
                    api_base=proxy_endpoint_url,
                    model_response=model_response,
                    encoding=encoding,
                    logging_obj=logging_obj,
                    optional_params=optional_params,
                    stream=True,
                    litellm_params=litellm_params,
                    logger_fn=logger_fn,
                    headers=headers,
                    timeout=timeout,
                    client=client,
                    json_mode=json_mode,
                    fake_stream=fake_stream,
                    credentials=credentials,
                    api_key=api_key,
                    stream_chunk_size=stream_chunk_size,
                )  # type: ignore
            ### ASYNC COMPLETION
            return self.async_completion(
                model=model,
                messages=messages,
                api_base=proxy_endpoint_url,
                model_response=model_response,
                encoding=encoding,
                logging_obj=logging_obj,
                optional_params=optional_params,
                stream=stream,  # type: ignore
                litellm_params=litellm_params,
                logger_fn=logger_fn,
                headers=headers,
                timeout=timeout,
                client=client,
                credentials=credentials,
                api_key=api_key,
            )  # type: ignore

        ## TRANSFORMATION ##

        _data = litellm.AmazonConverseConfig()._transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=extra_headers,
        )
        data = json.dumps(_data)

        # SigV4-sign the request (signature covers the serialized body).
        prepped = self.get_request_headers(
            credentials=credentials,
            aws_region_name=aws_region_name,
            extra_headers=extra_headers,
            endpoint_url=proxy_endpoint_url,
            data=data,
            headers=headers,
            api_key=api_key,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": proxy_endpoint_url,
                "headers": prepped.headers,
            },
        )
        # Sync path needs a sync client; replace a missing or async client.
        if client is None or isinstance(client, AsyncHTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = _get_httpx_client(_params)  # type: ignore
        else:
            client = client

        if stream is not None and stream is True:
            completion_stream = make_sync_call(
                client=(
                    client
                    if client is not None and isinstance(client, HTTPHandler)
                    else None
                ),
                api_base=proxy_endpoint_url,
                headers=prepped.headers,  # type: ignore
                data=data,
                model=model,
                messages=messages,
                logging_obj=logging_obj,
                json_mode=json_mode,
                fake_stream=fake_stream,
                stream_chunk_size=stream_chunk_size,
            )
            streaming_response = CustomStreamWrapper(
                completion_stream=completion_stream,
                model=model,
                custom_llm_provider="bedrock",
                logging_obj=logging_obj,
            )

            return streaming_response

        ### COMPLETION

        try:
            response = client.post(
                url=proxy_endpoint_url,
                headers=prepped.headers,
                data=data,
                logging_obj=logging_obj,
            )  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        return litellm.AmazonConverseConfig()._transform_response(
            model=model,
            response=response,
            model_response=model_response,
            stream=stream if isinstance(stream, bool) else False,
            logging_obj=logging_obj,
            api_key="",
            data=data,
            messages=messages,
            optional_params=optional_params,
            encoding=encoding,
        )
|
||||
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
Uses base_llm_http_handler to call the 'converse like' endpoint.
|
||||
|
||||
Relevant issue: https://github.com/BerriAI/litellm/issues/8085
|
||||
"""
|
||||
@@ -0,0 +1,3 @@
|
||||
"""
|
||||
Uses `converse_transformation.py` to transform the messages to the format required by Bedrock Converse.
|
||||
"""
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,547 @@
|
||||
"""
|
||||
Transformation for Bedrock Invoke Agent
|
||||
|
||||
https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agent-runtime_InvokeAgent.html
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.types.llms.bedrock_invoke_agents import (
|
||||
InvokeAgentChunkPayload,
|
||||
InvokeAgentEvent,
|
||||
InvokeAgentEventHeaders,
|
||||
InvokeAgentEventList,
|
||||
InvokeAgentMetadata,
|
||||
InvokeAgentModelInvocationInput,
|
||||
InvokeAgentModelInvocationOutput,
|
||||
InvokeAgentOrchestrationTrace,
|
||||
InvokeAgentPreProcessingTrace,
|
||||
InvokeAgentTrace,
|
||||
InvokeAgentTracePayload,
|
||||
InvokeAgentUsage,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import Choices, Message, ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonInvokeAgentConfig(BaseConfig, BaseAWSLLM):
    """
    Config/transformation layer for the Bedrock InvokeAgent runtime API.

    Builds the agent-runtime URL, signs requests with SigV4, and parses the
    AWS binary event-stream response (chunk + trace events) back into an
    OpenAI-style ModelResponse.
    """

    def __init__(self, **kwargs):
        # Initialize both parent classes explicitly — BaseConfig for the
        # transformation interface, BaseAWSLLM for credentials/signing.
        BaseConfig.__init__(self, **kwargs)
        BaseAWSLLM.__init__(self, **kwargs)

    def get_supported_openai_params(self, model: str) -> List[str]:
        """
        This is a base invoke agent model mapping. For Invoke Agent - define a bedrock provider specific config that extends this class.

        Bedrock Invoke Agents has 0 OpenAI compatible params

        As of May 29th, 2025 - they don't support streaming.
        """
        return []

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        This is a base invoke agent model mapping. For Invoke Agent - define a bedrock provider specific config that extends this class.
        """
        # No OpenAI params are supported, so nothing is mapped here.
        return optional_params

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete url for the request

        Builds:
        {runtime_endpoint}/agents/{agent_id}/agentAliases/{agent_alias_id}/sessions/{session_id}/text
        """
        ### SET RUNTIME ENDPOINT ###
        aws_bedrock_runtime_endpoint = optional_params.get(
            "aws_bedrock_runtime_endpoint", None
        )  # https://bedrock-runtime.{region_name}.amazonaws.com
        endpoint_url, _ = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
            aws_region_name=self._get_aws_region_name(
                optional_params=optional_params, model=model
            ),
            endpoint_type="agent",
        )

        agent_id, agent_alias_id = self._get_agent_id_and_alias_id(model)
        session_id = self._get_session_id(optional_params)

        endpoint_url = f"{endpoint_url}/agents/{agent_id}/agentAliases/{agent_alias_id}/sessions/{session_id}/text"

        return endpoint_url

    def sign_request(
        self,
        headers: dict,
        optional_params: dict,
        request_data: dict,
        api_base: str,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        stream: Optional[bool] = None,
        fake_stream: Optional[bool] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """SigV4-sign the request; delegates to BaseAWSLLM._sign_request."""
        return self._sign_request(
            service_name="bedrock",
            headers=headers,
            optional_params=optional_params,
            request_data=request_data,
            api_base=api_base,
            model=model,
            stream=stream,
            fake_stream=fake_stream,
            api_key=api_key,
        )

    def _get_agent_id_and_alias_id(self, model: str) -> tuple[str, str]:
        """
        model = "agent/L1RT58GYRW/MFPSBCXYTW"
        agent_id = "L1RT58GYRW"
        agent_alias_id = "MFPSBCXYTW"

        Raises:
            ValueError: if the model string is not exactly "agent/<id>/<alias>".
        """
        # Split the model string by '/' and extract components
        parts = model.split("/")
        if len(parts) != 3 or parts[0] != "agent":
            raise ValueError(
                "Invalid model format. Expected format: 'model=agent/AGENT_ID/ALIAS_ID'"
            )

        return parts[1], parts[2]  # Return (agent_id, agent_alias_id)

    def _get_session_id(self, optional_params: dict) -> str:
        """
        Return the caller-supplied "sessionID" from optional_params, or
        generate a fresh UUID4 so each request gets its own agent session.
        """
        # NOTE(review): key is "sessionID" (capital ID) — callers must match
        # this exact casing; confirm against documented param name.
        return optional_params.get("sessionID", None) or str(uuid.uuid4())

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Build the InvokeAgent request body.

        Only the LAST message's content is sent as inputText; earlier turns
        are expected to live in the agent session, not the request.
        """
        # use the last message content as the query
        query: str = convert_content_list_to_str(messages[-1])
        return {
            "inputText": query,
            "enableTrace": True,  # traces carry token usage + model info
            **optional_params,
        }

    def _parse_aws_event_stream(self, raw_content: bytes) -> InvokeAgentEventList:
        """
        Parse AWS event stream format using boto3/botocore's built-in parser.
        This is the same approach used in the existing AWSEventStreamDecoder.

        Returns a list of {"headers": ..., "payload": ...} dicts; events that
        fail to parse are logged and skipped rather than raised.
        """
        try:
            from botocore.eventstream import EventStreamBuffer
            from botocore.parsers import EventStreamJSONParser
        except ImportError:
            raise ImportError("boto3/botocore is required for AWS event stream parsing")

        events: InvokeAgentEventList = []
        parser = EventStreamJSONParser()
        event_stream_buffer = EventStreamBuffer()

        # Add the entire response to the buffer
        event_stream_buffer.add_data(raw_content)

        # Process all events in the buffer
        for event in event_stream_buffer:
            try:
                headers = self._extract_headers_from_event(event)

                event_type = headers.get("event_type", "")

                if event_type == "chunk":
                    # Handle chunk events specially - they contain decoded content, not JSON
                    message = self._parse_message_from_event(event, parser)
                    parsed_event: InvokeAgentEvent = InvokeAgentEvent()
                    if message:
                        # For chunk events, create a payload with the decoded content
                        parsed_event = {
                            "headers": headers,
                            "payload": {
                                "bytes": base64.b64encode(
                                    message.encode("utf-8")
                                ).decode("utf-8")
                            },  # Re-encode for consistency
                        }
                        events.append(parsed_event)

                elif event_type == "trace":
                    # Handle trace events normally - they contain JSON
                    message = self._parse_message_from_event(event, parser)

                    if message:
                        try:
                            event_data = json.loads(message)
                            parsed_event = {
                                "headers": headers,
                                "payload": event_data,
                            }
                            events.append(parsed_event)
                        except json.JSONDecodeError as e:
                            verbose_logger.warning(
                                f"Failed to parse trace event JSON: {e}"
                            )
                else:
                    verbose_logger.debug(f"Unknown event type: {event_type}")

            except Exception as e:
                # Best-effort: skip the bad event and keep parsing the rest.
                verbose_logger.error(f"Error processing event: {e}")
                continue

        return events

    def _parse_message_from_event(self, event, parser) -> Optional[str]:
        """Extract message content from an AWS event, adapted from AWSEventStreamDecoder."""
        try:
            response_dict = event.to_response_dict()
            verbose_logger.debug(f"Response dict: {response_dict}")

            # Use the same response shape parsing as the existing decoder
            parsed_response = parser.parse(
                response_dict, self._get_response_stream_shape()
            )
            verbose_logger.debug(f"Parsed response: {parsed_response}")

            # Non-200 in-stream events are service errors; surface as BedrockError.
            if response_dict["status_code"] != 200:
                decoded_body = response_dict["body"].decode()
                if isinstance(decoded_body, dict):
                    error_message = decoded_body.get("message")
                elif isinstance(decoded_body, str):
                    error_message = decoded_body
                else:
                    error_message = ""
                exception_status = response_dict["headers"].get(":exception-type")
                error_message = exception_status + " " + error_message
                raise BedrockError(
                    status_code=response_dict["status_code"],
                    message=(
                        json.dumps(error_message)
                        if isinstance(error_message, dict)
                        else error_message
                    ),
                )

            if "chunk" in parsed_response:
                chunk = parsed_response.get("chunk")
                if not chunk:
                    return None
                return chunk.get("bytes").decode()
            else:
                # Fallback: return the raw body when no "chunk" key exists.
                chunk = response_dict.get("body")
                if not chunk:
                    return None
                return chunk.decode()

        except Exception as e:
            # NOTE: this also swallows the BedrockError raised above — the
            # caller only ever sees None for unparseable/error events.
            verbose_logger.debug(f"Error parsing message from event: {e}")
            return None

    def _extract_headers_from_event(self, event) -> InvokeAgentEventHeaders:
        """Extract headers from an AWS event for categorization."""
        try:
            response_dict = event.to_response_dict()
            headers = response_dict.get("headers", {})

            # Extract the event-type and content-type headers that we care about
            return InvokeAgentEventHeaders(
                event_type=headers.get(":event-type", ""),
                content_type=headers.get(":content-type", ""),
                message_type=headers.get(":message-type", ""),
            )
        except Exception as e:
            # Fall back to empty headers so the event is treated as unknown.
            verbose_logger.debug(f"Error extracting headers: {e}")
            return InvokeAgentEventHeaders(
                event_type="", content_type="", message_type=""
            )

    def _get_response_stream_shape(self):
        """Get the response stream shape for parsing, reusing existing logic."""
        try:
            # Try to reuse the cached shape from the existing decoder
            from litellm.llms.bedrock.chat.invoke_handler import (
                get_response_stream_shape,
            )

            return get_response_stream_shape()
        except ImportError:
            # Fallback: create our own shape
            try:
                from botocore.loaders import Loader
                from botocore.model import ServiceModel

                loader = Loader()
                bedrock_service_dict = loader.load_service_model(
                    "bedrock-runtime", "service-2"
                )
                bedrock_service_model = ServiceModel(bedrock_service_dict)
                return bedrock_service_model.shape_for("ResponseStream")
            except Exception as e:
                verbose_logger.warning(f"Could not load response stream shape: {e}")
                return None

    def _extract_response_content(self, events: InvokeAgentEventList) -> str:
        """Extract the final response content from parsed events.

        Concatenates the base64-decoded bytes of every "chunk" event, in order.
        """
        response_parts = []

        for event in events:
            headers = event.get("headers", {})
            payload = event.get("payload")

            event_type = headers.get(
                "event_type"
            )  # Note: using event_type not event-type

            if event_type == "chunk" and payload:
                # Extract base64 encoded content from chunk events
                chunk_payload: InvokeAgentChunkPayload = payload  # type: ignore
                encoded_bytes = chunk_payload.get("bytes", "")
                if encoded_bytes:
                    try:
                        decoded_content = base64.b64decode(encoded_bytes).decode(
                            "utf-8"
                        )
                        response_parts.append(decoded_content)
                    except Exception as e:
                        verbose_logger.warning(f"Failed to decode chunk content: {e}")

        return "".join(response_parts)

    def _extract_usage_info(self, events: InvokeAgentEventList) -> InvokeAgentUsage:
        """Extract token usage information from trace events.

        Sums input/output tokens across all pre-processing traces and takes
        the first foundation model seen in an orchestration trace.
        """
        usage_info = InvokeAgentUsage(
            inputTokens=0,
            outputTokens=0,
            model=None,
        )

        response_model: Optional[str] = None

        for event in events:
            if not self._is_trace_event(event):
                continue

            trace_data = self._get_trace_data(event)
            if not trace_data:
                continue

            verbose_logger.debug(f"Trace event: {trace_data}")

            # Extract usage from pre-processing trace
            self._extract_and_update_preprocessing_usage(
                trace_data=trace_data,
                usage_info=usage_info,
            )

            # Extract model from orchestration trace
            if response_model is None:
                response_model = self._extract_orchestration_model(trace_data)

        usage_info["model"] = response_model
        return usage_info

    def _is_trace_event(self, event: InvokeAgentEvent) -> bool:
        """Check if the event is a trace event (with a non-None payload)."""
        headers = event.get("headers", {})
        event_type = headers.get("event_type")
        payload = event.get("payload")
        return event_type == "trace" and payload is not None

    def _get_trace_data(self, event: InvokeAgentEvent) -> Optional[InvokeAgentTrace]:
        """Extract trace data from a trace event."""
        payload = event.get("payload")
        if not payload:
            return None

        trace_payload: InvokeAgentTracePayload = payload  # type: ignore
        return trace_payload.get("trace", {})

    def _extract_and_update_preprocessing_usage(
        self, trace_data: InvokeAgentTrace, usage_info: InvokeAgentUsage
    ) -> None:
        """Extract usage information from preprocessing trace.

        Mutates `usage_info` in place, accumulating inputTokens/outputTokens.
        """
        pre_processing: Optional[InvokeAgentPreProcessingTrace] = trace_data.get(
            "preProcessingTrace"
        )
        if not pre_processing:
            return

        # The `or Type()` defaults are empty TypedDicts (falsy), so the
        # following truthiness guard still short-circuits on missing data.
        model_output: Optional[InvokeAgentModelInvocationOutput] = (
            pre_processing.get("modelInvocationOutput")
            or InvokeAgentModelInvocationOutput()
        )
        if not model_output:
            return

        metadata: Optional[InvokeAgentMetadata] = (
            model_output.get("metadata") or InvokeAgentMetadata()
        )
        if not metadata:
            return

        usage: Optional[Union[InvokeAgentUsage, Dict]] = metadata.get("usage", {})
        if not usage:
            return

        usage_info["inputTokens"] += usage.get("inputTokens", 0)
        usage_info["outputTokens"] += usage.get("outputTokens", 0)

    def _extract_orchestration_model(
        self, trace_data: InvokeAgentTrace
    ) -> Optional[str]:
        """Extract model information from orchestration trace."""
        orchestration_trace: Optional[InvokeAgentOrchestrationTrace] = trace_data.get(
            "orchestrationTrace"
        )
        if not orchestration_trace:
            return None

        model_invocation: Optional[InvokeAgentModelInvocationInput] = (
            orchestration_trace.get("modelInvocationInput")
            or InvokeAgentModelInvocationInput()
        )
        if not model_invocation:
            return None

        return model_invocation.get("foundationModel")

    def _build_model_response(
        self,
        content: str,
        model: str,
        usage_info: InvokeAgentUsage,
        model_response: ModelResponse,
    ) -> ModelResponse:
        """Build the final ModelResponse object.

        Mutates and returns `model_response` with a single "assistant" choice,
        the model name from trace data (falling back to `model`), and usage.
        """

        # Create the message content
        message = Message(content=content, role="assistant")

        # Create choices
        choice = Choices(finish_reason="stop", index=0, message=message)

        # Update model response
        model_response.choices = [choice]
        model_response.model = usage_info.get("model", model)

        # Add usage information if available
        if usage_info:
            from litellm.types.utils import Usage

            usage = Usage(
                prompt_tokens=usage_info.get("inputTokens", 0),
                completion_tokens=usage_info.get("outputTokens", 0),
                total_tokens=usage_info.get("inputTokens", 0)
                + usage_info.get("outputTokens", 0),
            )
            setattr(model_response, "usage", usage)

        return model_response

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Convert the raw InvokeAgent event-stream response into a ModelResponse.

        Raises:
            BedrockError: if the event stream cannot be parsed/processed.
        """
        try:
            # Get the raw binary content
            raw_content = raw_response.content
            verbose_logger.debug(
                f"Processing {len(raw_content)} bytes of AWS event stream data"
            )

            # Parse the AWS event stream format
            events = self._parse_aws_event_stream(raw_content)
            verbose_logger.debug(f"Parsed {len(events)} events from stream")

            # Extract response content from chunk events
            content = self._extract_response_content(events)

            # Extract usage information from trace events
            usage_info = self._extract_usage_info(events)

            # Build and return the model response
            return self._build_model_response(
                content=content,
                model=model,
                usage_info=usage_info,
                model_response=model_response,
            )

        except Exception as e:
            verbose_logger.error(
                f"Error processing Bedrock Invoke Agent response: {str(e)}"
            )
            raise BedrockError(
                message=f"Error processing response: {str(e)}",
                status_code=raw_response.status_code,
            )

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        # Auth is handled by SigV4 signing, not headers — nothing to validate.
        return headers

    def get_error_class(
        self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
    ) -> BaseLLMException:
        """Map provider errors to BedrockError for uniform handling upstream."""
        return BedrockError(status_code=status_code, message=error_message)

    def should_fake_stream(
        self,
        model: Optional[str],
        stream: Optional[bool],
        custom_llm_provider: Optional[str] = None,
    ) -> bool:
        # InvokeAgent does not support streaming, so streaming requests are
        # always simulated from the buffered response.
        return True
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,99 @@
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
|
||||
|
||||
class AmazonAI21Config(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra

    Supported Params for the Amazon / AI21 models:

    - `maxTokens` (int32): The maximum number of tokens to generate per result. Optional, default is 16. If no `stopSequences` are given, generation stops after producing `maxTokens`.

    - `temperature` (float): Modifies the distribution from which tokens are sampled. Optional, default is 0.7. A value of 0 essentially disables sampling and results in greedy decoding.

    - `topP` (float): Used for sampling tokens from the corresponding top percentile of probability mass. Optional, default is 1. For instance, a value of 0.9 considers only tokens comprising the top 90% probability mass.

    - `stopSequences` (array of strings): Stops decoding if any of the input strings is generated. Optional.

    - `frequencyPenalty` (object): Placeholder for frequency penalty object.

    - `presencePenalty` (object): Placeholder for presence penalty object.

    - `countPenalty` (object): Placeholder for count penalty object.
    """

    maxTokens: Optional[int] = None
    temperature: Optional[float] = None
    topP: Optional[float] = None
    stopSequences: Optional[list] = None
    # NOTE(review): attribute is spelled "frequencePenalty" (missing "y") while
    # the docstring says "frequencyPenalty" — kept as-is since the attribute
    # name is part of the public config interface; confirm against AI21 docs.
    frequencePenalty: Optional[dict] = None
    presencePenalty: Optional[dict] = None
    countPenalty: Optional[dict] = None

    def __init__(
        self,
        maxTokens: Optional[int] = None,
        temperature: Optional[float] = None,
        topP: Optional[float] = None,
        stopSequences: Optional[list] = None,
        frequencePenalty: Optional[dict] = None,
        presencePenalty: Optional[dict] = None,
        countPenalty: Optional[dict] = None,
    ) -> None:
        # Mirror the constructor args onto the *class* (not the instance),
        # skipping anything left as None, matching litellm's config pattern.
        provided = {
            "maxTokens": maxTokens,
            "temperature": temperature,
            "topP": topP,
            "stopSequences": stopSequences,
            "frequencePenalty": frequencePenalty,
            "presencePenalty": presencePenalty,
            "countPenalty": countPenalty,
        }
        for attr_name, attr_value in provided.items():
            if attr_value is not None:
                setattr(self.__class__, attr_name, attr_value)

        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return the class-level config values as a plain dict."""
        excluded_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config: dict = {}
        for attr_name, attr_value in cls.__dict__.items():
            if attr_name.startswith("__") or attr_name.startswith("_abc"):
                continue
            if isinstance(attr_value, excluded_types):
                continue
            if attr_value is None:
                continue
            config[attr_name] = attr_value
        return config

    def get_supported_openai_params(self, model: str) -> List:
        """OpenAI params that can be translated for AI21-on-Bedrock."""
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate supported OpenAI param names to AI21's native names."""
        # openai name -> AI21 name; unsupported keys are silently ignored.
        param_name_map = {
            "max_tokens": "maxTokens",
            "temperature": "temperature",
            "top_p": "topP",
            "stream": "stream",
        }
        for openai_name, value in non_default_params.items():
            native_name = param_name_map.get(openai_name)
            if native_name is not None:
                optional_params[native_name] = value
        return optional_params
|
||||
@@ -0,0 +1,75 @@
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.llms.cohere.chat.transformation import CohereChatConfig
|
||||
|
||||
|
||||
class AmazonCohereConfig(AmazonInvokeConfig, CohereChatConfig):
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=command
|
||||
|
||||
Supported Params for the Amazon / Cohere models:
|
||||
|
||||
- `max_tokens` (integer) max tokens,
|
||||
- `temperature` (float) model temperature,
|
||||
- `return_likelihood` (string) n/a
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = None
|
||||
return_likelihood: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
return_likelihood: Optional[str] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
AmazonInvokeConfig.__init__(self)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not k.startswith("_abc")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
supported_params = CohereChatConfig.get_supported_openai_params(
|
||||
self, model=model
|
||||
)
|
||||
return supported_params
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
return CohereChatConfig.map_openai_params(
|
||||
self,
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=drop_params,
|
||||
)
|
||||
@@ -0,0 +1,135 @@
|
||||
from typing import Any, List, Optional, cast
|
||||
|
||||
from httpx import Response
|
||||
|
||||
from litellm import verbose_logger
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
_parse_content_for_reasoning,
|
||||
)
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.bedrock import AmazonDeepSeekR1StreamingResponse
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
ChatCompletionUsageBlock,
|
||||
Choices,
|
||||
Delta,
|
||||
Message,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
)
|
||||
|
||||
from .amazon_llama_transformation import AmazonLlamaConfig
|
||||
|
||||
|
||||
class AmazonDeepSeekR1Config(AmazonLlamaConfig):
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
"""
|
||||
Extract the reasoning content, and return it as a separate field in the response.
|
||||
"""
|
||||
response = super().transform_response(
|
||||
model,
|
||||
raw_response,
|
||||
model_response,
|
||||
logging_obj,
|
||||
request_data,
|
||||
messages,
|
||||
optional_params,
|
||||
litellm_params,
|
||||
encoding,
|
||||
api_key,
|
||||
json_mode,
|
||||
)
|
||||
prompt = cast(Optional[str], request_data.get("prompt"))
|
||||
message_content = cast(
|
||||
Optional[str], cast(Choices, response.choices[0]).message.get("content")
|
||||
)
|
||||
if prompt and prompt.strip().endswith("<think>") and message_content:
|
||||
message_content_with_reasoning_token = "<think>" + message_content
|
||||
reasoning, content = _parse_content_for_reasoning(
|
||||
message_content_with_reasoning_token
|
||||
)
|
||||
provider_specific_fields = (
|
||||
cast(Choices, response.choices[0]).message.provider_specific_fields
|
||||
or {}
|
||||
)
|
||||
if reasoning:
|
||||
provider_specific_fields["reasoning_content"] = reasoning
|
||||
|
||||
message = Message(
|
||||
**{
|
||||
**cast(Choices, response.choices[0]).message.model_dump(),
|
||||
"content": content,
|
||||
"provider_specific_fields": provider_specific_fields,
|
||||
}
|
||||
)
|
||||
cast(Choices, response.choices[0]).message = message
|
||||
return response
|
||||
|
||||
|
||||
class AmazonDeepseekR1ResponseIterator(BaseModelResponseIterator):
|
||||
def __init__(self, streaming_response: Any, sync_stream: bool) -> None:
|
||||
super().__init__(streaming_response=streaming_response, sync_stream=sync_stream)
|
||||
self.has_finished_thinking = False
|
||||
|
||||
def chunk_parser(self, chunk: dict) -> ModelResponseStream:
|
||||
"""
|
||||
Deepseek r1 starts by thinking, then it generates the response.
|
||||
"""
|
||||
try:
|
||||
typed_chunk = AmazonDeepSeekR1StreamingResponse(**chunk) # type: ignore
|
||||
generated_content = typed_chunk["generation"]
|
||||
if generated_content == "</think>" and not self.has_finished_thinking:
|
||||
verbose_logger.debug(
|
||||
"Deepseek r1: </think> received, setting has_finished_thinking to True"
|
||||
)
|
||||
generated_content = ""
|
||||
self.has_finished_thinking = True
|
||||
|
||||
prompt_token_count = typed_chunk.get("prompt_token_count") or 0
|
||||
generation_token_count = typed_chunk.get("generation_token_count") or 0
|
||||
usage = ChatCompletionUsageBlock(
|
||||
prompt_tokens=prompt_token_count,
|
||||
completion_tokens=generation_token_count,
|
||||
total_tokens=prompt_token_count + generation_token_count,
|
||||
)
|
||||
|
||||
return ModelResponseStream(
|
||||
choices=[
|
||||
StreamingChoices(
|
||||
finish_reason=typed_chunk["stop_reason"],
|
||||
delta=Delta(
|
||||
content=(
|
||||
generated_content
|
||||
if self.has_finished_thinking
|
||||
else None
|
||||
),
|
||||
reasoning_content=(
|
||||
generated_content
|
||||
if not self.has_finished_thinking
|
||||
else None
|
||||
),
|
||||
),
|
||||
)
|
||||
],
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise e
|
||||
@@ -0,0 +1,80 @@
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
|
||||
|
||||
class AmazonLlamaConfig(AmazonInvokeConfig, BaseConfig):
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=meta.llama2-13b-chat-v1
|
||||
|
||||
Supported Params for the Amazon / Meta Llama models:
|
||||
|
||||
- `max_gen_len` (integer) max tokens,
|
||||
- `temperature` (float) temperature for model,
|
||||
- `top_p` (float) top p for model
|
||||
"""
|
||||
|
||||
max_gen_len: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
topP: Optional[float] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
maxTokenCount: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
topP: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
AmazonInvokeConfig.__init__(self)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not k.startswith("_abc")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List:
|
||||
return [
|
||||
"max_tokens",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"stream",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for k, v in non_default_params.items():
|
||||
if k == "max_tokens":
|
||||
optional_params["max_gen_len"] = v
|
||||
if k == "temperature":
|
||||
optional_params["temperature"] = v
|
||||
if k == "top_p":
|
||||
optional_params["top_p"] = v
|
||||
if k == "stream":
|
||||
optional_params["stream"] = v
|
||||
return optional_params
|
||||
@@ -0,0 +1,119 @@
|
||||
import types
|
||||
from typing import List, Optional, TYPE_CHECKING
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
|
||||
class AmazonMistralConfig(AmazonInvokeConfig, BaseConfig):
|
||||
"""
|
||||
Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-mistral.html
|
||||
Supported Params for the Amazon / Mistral models:
|
||||
|
||||
- `max_tokens` (integer) max tokens,
|
||||
- `temperature` (float) temperature for model,
|
||||
- `top_p` (float) top p for model
|
||||
- `stop` [string] A list of stop sequences that if generated by the model, stops the model from generating further output.
|
||||
- `top_k` (float) top k for model
|
||||
"""
|
||||
|
||||
max_tokens: Optional[int] = None
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[float] = None
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_tokens: Optional[int] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[int] = None,
|
||||
top_k: Optional[float] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
AmazonInvokeConfig.__init__(self)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not k.startswith("_abc")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
return ["max_tokens", "temperature", "top_p", "stop", "stream"]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
for k, v in non_default_params.items():
|
||||
if k == "max_tokens":
|
||||
optional_params["max_tokens"] = v
|
||||
if k == "temperature":
|
||||
optional_params["temperature"] = v
|
||||
if k == "top_p":
|
||||
optional_params["top_p"] = v
|
||||
if k == "stop":
|
||||
optional_params["stop"] = v
|
||||
if k == "stream":
|
||||
optional_params["stream"] = v
|
||||
return optional_params
|
||||
|
||||
@staticmethod
|
||||
def get_outputText(
|
||||
completion_response: dict, model_response: "ModelResponse"
|
||||
) -> str:
|
||||
"""This function extracts the output text from a bedrock mistral completion.
|
||||
As a side effect, it updates the finish reason for a model response.
|
||||
|
||||
Args:
|
||||
completion_response: JSON from the completion.
|
||||
model_response: ModelResponse
|
||||
|
||||
Returns:
|
||||
A string with the response of the LLM
|
||||
|
||||
"""
|
||||
if "choices" in completion_response:
|
||||
outputText = completion_response["choices"][0]["message"]["content"]
|
||||
model_response.choices[0].finish_reason = completion_response["choices"][0][
|
||||
"finish_reason"
|
||||
]
|
||||
elif "outputs" in completion_response:
|
||||
outputText = completion_response["outputs"][0]["text"]
|
||||
model_response.choices[0].finish_reason = completion_response["outputs"][0][
|
||||
"stop_reason"
|
||||
]
|
||||
else:
|
||||
raise BedrockError(
|
||||
message="Unexpected mistral completion response", status_code=400
|
||||
)
|
||||
|
||||
return outputText
|
||||
@@ -0,0 +1,266 @@
|
||||
"""
|
||||
Transformation for Bedrock Moonshot AI (Kimi K2) models.
|
||||
|
||||
Supports the Kimi K2 Thinking model available on Amazon Bedrock.
|
||||
Model format: bedrock/moonshot.kimi-k2-thinking-v1:0
|
||||
|
||||
Reference: https://aws.amazon.com/about-aws/whats-new/2025/12/amazon-bedrock-fully-managed-open-weight-models/
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Union
|
||||
import re
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.llms.moonshot.chat.transformation import MoonshotChatConfig
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import Choices
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonMoonshotConfig(AmazonInvokeConfig, MoonshotChatConfig):
|
||||
"""
|
||||
Configuration for Bedrock Moonshot AI (Kimi K2) models.
|
||||
|
||||
Reference:
|
||||
https://aws.amazon.com/about-aws/whats-new/2025/12/amazon-bedrock-fully-managed-open-weight-models/
|
||||
https://platform.moonshot.ai/docs/api/chat
|
||||
|
||||
Supported Params for the Amazon / Moonshot models:
|
||||
- `max_tokens` (integer) max tokens
|
||||
- `temperature` (float) temperature for model (0-1 for Moonshot)
|
||||
- `top_p` (float) top p for model
|
||||
- `stream` (bool) whether to stream responses
|
||||
- `tools` (list) tool definitions (supported on kimi-k2-thinking)
|
||||
- `tool_choice` (str|dict) tool choice specification (supported on kimi-k2-thinking)
|
||||
|
||||
NOT Supported on Bedrock:
|
||||
- `stop` sequences (Bedrock doesn't support stopSequences field for this model)
|
||||
|
||||
Note: The kimi-k2-thinking model DOES support tool calls, unlike kimi-thinking-preview.
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
AmazonInvokeConfig.__init__(self, **kwargs)
|
||||
MoonshotChatConfig.__init__(self, **kwargs)
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> Optional[str]:
|
||||
return "bedrock"
|
||||
|
||||
def _get_model_id(self, model: str) -> str:
|
||||
"""
|
||||
Extract the actual model ID from the LiteLLM model name.
|
||||
|
||||
Removes routing prefixes like:
|
||||
- bedrock/invoke/moonshot.kimi-k2-thinking -> moonshot.kimi-k2-thinking
|
||||
- invoke/moonshot.kimi-k2-thinking -> moonshot.kimi-k2-thinking
|
||||
- moonshot.kimi-k2-thinking -> moonshot.kimi-k2-thinking
|
||||
"""
|
||||
# Remove bedrock/ prefix if present
|
||||
if model.startswith("bedrock/"):
|
||||
model = model[8:]
|
||||
|
||||
# Remove invoke/ prefix if present
|
||||
if model.startswith("invoke/"):
|
||||
model = model[7:]
|
||||
|
||||
# Remove any provider prefix (e.g., moonshot/)
|
||||
if "/" in model and not model.startswith("arn:"):
|
||||
parts = model.split("/", 1)
|
||||
if len(parts) == 2:
|
||||
model = parts[1]
|
||||
|
||||
return model
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
"""
|
||||
Get the supported OpenAI params for Moonshot AI models on Bedrock.
|
||||
|
||||
Bedrock-specific limitations:
|
||||
- stopSequences field is not supported on Bedrock (unlike native Moonshot API)
|
||||
- functions parameter is not supported (use tools instead)
|
||||
- tool_choice doesn't support "required" value
|
||||
|
||||
Note: kimi-k2-thinking DOES support tool calls (unlike kimi-thinking-preview)
|
||||
The parent MoonshotChatConfig class handles the kimi-thinking-preview exclusion.
|
||||
"""
|
||||
excluded_params: List[str] = [
|
||||
"functions",
|
||||
"stop",
|
||||
] # Bedrock doesn't support stopSequences
|
||||
|
||||
base_openai_params = super(
|
||||
MoonshotChatConfig, self
|
||||
).get_supported_openai_params(model=model)
|
||||
final_params: List[str] = []
|
||||
for param in base_openai_params:
|
||||
if param not in excluded_params:
|
||||
final_params.append(param)
|
||||
|
||||
return final_params
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
"""
|
||||
Map OpenAI parameters to Moonshot AI parameters for Bedrock.
|
||||
|
||||
Handles Moonshot AI specific limitations:
|
||||
- tool_choice doesn't support "required" value
|
||||
- Temperature <0.3 limitation for n>1
|
||||
- Temperature range is [0, 1] (not [0, 2] like OpenAI)
|
||||
"""
|
||||
return MoonshotChatConfig.map_openai_params(
|
||||
self,
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=drop_params,
|
||||
)
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the request for Bedrock Moonshot AI models.
|
||||
|
||||
Uses the Moonshot transformation logic which handles:
|
||||
- Converting content lists to strings (Moonshot doesn't support list format)
|
||||
- Adding tool_choice="required" message if needed
|
||||
- Temperature and parameter validation
|
||||
|
||||
"""
|
||||
# Filter out AWS credentials using the existing method from BaseAWSLLM
|
||||
self._get_boto_credentials_from_optional_params(optional_params, model)
|
||||
|
||||
# Strip routing prefixes to get the actual model ID
|
||||
clean_model_id = self._get_model_id(model)
|
||||
|
||||
# Use Moonshot's transform_request which handles message transformation
|
||||
# and tool_choice="required" workaround
|
||||
return MoonshotChatConfig.transform_request(
|
||||
self,
|
||||
model=clean_model_id,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def _extract_reasoning_from_content(
|
||||
self, content: str
|
||||
) -> tuple[Optional[str], str]:
|
||||
"""
|
||||
Extract reasoning content from <reasoning> tags in the response.
|
||||
|
||||
Moonshot AI's Kimi K2 Thinking model returns reasoning in <reasoning> tags.
|
||||
This method extracts that content and returns it separately.
|
||||
|
||||
Args:
|
||||
content: The full content string from the API response
|
||||
|
||||
Returns:
|
||||
tuple: (reasoning_content, main_content)
|
||||
"""
|
||||
if not content:
|
||||
return None, content
|
||||
|
||||
# Match <reasoning>...</reasoning> tags
|
||||
reasoning_match = re.match(
|
||||
r"<reasoning>(.*?)</reasoning>\s*(.*)", content, re.DOTALL
|
||||
)
|
||||
|
||||
if reasoning_match:
|
||||
reasoning_content = reasoning_match.group(1).strip()
|
||||
main_content = reasoning_match.group(2).strip()
|
||||
return reasoning_content, main_content
|
||||
|
||||
return None, content
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: "ModelResponse",
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> "ModelResponse":
|
||||
"""
|
||||
Transform the response from Bedrock Moonshot AI models.
|
||||
|
||||
Moonshot AI uses OpenAI-compatible response format, but returns reasoning
|
||||
content in <reasoning> tags. This method:
|
||||
1. Calls parent class transformation
|
||||
2. Extracts reasoning content from <reasoning> tags
|
||||
3. Sets reasoning_content on the message object
|
||||
"""
|
||||
# First, get the standard transformation
|
||||
model_response = MoonshotChatConfig.transform_response(
|
||||
self,
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
# Extract reasoning content from <reasoning> tags
|
||||
if model_response.choices and len(model_response.choices) > 0:
|
||||
for choice in model_response.choices:
|
||||
# Only process Choices (not StreamingChoices) which have message attribute
|
||||
if (
|
||||
isinstance(choice, Choices)
|
||||
and choice.message
|
||||
and choice.message.content
|
||||
):
|
||||
(
|
||||
reasoning_content,
|
||||
main_content,
|
||||
) = self._extract_reasoning_from_content(choice.message.content)
|
||||
|
||||
if reasoning_content:
|
||||
# Set the reasoning_content field
|
||||
choice.message.reasoning_content = reasoning_content
|
||||
# Update the main content without reasoning tags
|
||||
choice.message.content = main_content
|
||||
|
||||
return model_response
|
||||
|
||||
def get_error_class(
|
||||
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
||||
) -> BedrockError:
|
||||
"""Return the appropriate error class for Bedrock."""
|
||||
return BedrockError(status_code=status_code, message=error_message)
|
||||
@@ -0,0 +1,120 @@
|
||||
"""
|
||||
Handles transforming requests for `bedrock/invoke/{nova} models`
|
||||
|
||||
Inherits from `AmazonConverseConfig`
|
||||
|
||||
Nova + Invoke API Tutorial: https://docs.aws.amazon.com/nova/latest/userguide/using-invoke-api.html
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.types.llms.bedrock import BedrockInvokeNovaRequest
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
from ..converse_transformation import AmazonConverseConfig
|
||||
from .base_invoke_transformation import AmazonInvokeConfig
|
||||
|
||||
|
||||
class AmazonInvokeNovaConfig(AmazonInvokeConfig, AmazonConverseConfig):
|
||||
"""
|
||||
Config for sending `nova` requests to `/bedrock/invoke/`
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
return AmazonConverseConfig.get_supported_openai_params(self, model)
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
return AmazonConverseConfig.map_openai_params(
|
||||
self, non_default_params, optional_params, model, drop_params
|
||||
)
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
_transformed_nova_request = AmazonConverseConfig.transform_request(
|
||||
self,
|
||||
model=model,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
_bedrock_invoke_nova_request = BedrockInvokeNovaRequest(
|
||||
**_transformed_nova_request
|
||||
)
|
||||
self._remove_empty_system_messages(_bedrock_invoke_nova_request)
|
||||
bedrock_invoke_nova_request = self._filter_allowed_fields(
|
||||
_bedrock_invoke_nova_request
|
||||
)
|
||||
return bedrock_invoke_nova_request
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: Logging,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
return AmazonConverseConfig.transform_response(
|
||||
self,
|
||||
model,
|
||||
raw_response,
|
||||
model_response,
|
||||
logging_obj,
|
||||
request_data,
|
||||
messages,
|
||||
optional_params,
|
||||
litellm_params,
|
||||
encoding,
|
||||
api_key,
|
||||
json_mode,
|
||||
)
|
||||
|
||||
def _filter_allowed_fields(
|
||||
self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
|
||||
) -> dict:
|
||||
"""
|
||||
Filter out fields that are not allowed in the `BedrockInvokeNovaRequest` dataclass.
|
||||
"""
|
||||
allowed_fields = set(BedrockInvokeNovaRequest.__annotations__.keys())
|
||||
return {
|
||||
k: v for k, v in bedrock_invoke_nova_request.items() if k in allowed_fields
|
||||
}
|
||||
|
||||
def _remove_empty_system_messages(
|
||||
self, bedrock_invoke_nova_request: BedrockInvokeNovaRequest
|
||||
) -> None:
|
||||
"""
|
||||
In-place remove empty `system` messages from the request.
|
||||
|
||||
/bedrock/invoke/ does not allow empty `system` messages.
|
||||
"""
|
||||
_system_message = bedrock_invoke_nova_request.get("system", None)
|
||||
if isinstance(_system_message, list) and len(_system_message) == 0:
|
||||
bedrock_invoke_nova_request.pop("system", None)
|
||||
return
|
||||
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Transformation for Bedrock imported models that use OpenAI Chat Completions format.
|
||||
|
||||
Use this for models imported into Bedrock that accept the OpenAI API format.
|
||||
Model format: bedrock/openai/<model-id>
|
||||
|
||||
Example: bedrock/openai/arn:aws:bedrock:us-east-1:123456789012:imported-model/abc123
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
from litellm.passthrough.utils import CommonUtils
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonBedrockOpenAIConfig(OpenAIGPTConfig, BaseAWSLLM):
|
||||
"""
|
||||
Configuration for Bedrock imported models that use OpenAI Chat Completions format.
|
||||
|
||||
This class handles the transformation of requests and responses for Bedrock
|
||||
imported models that accept the OpenAI API format directly.
|
||||
|
||||
Inherits from OpenAIGPTConfig to leverage standard OpenAI parameter handling
|
||||
and response transformation, while adding Bedrock-specific URL generation
|
||||
and AWS request signing.
|
||||
|
||||
Usage:
|
||||
model = "bedrock/openai/arn:aws:bedrock:us-east-1:123456789012:imported-model/abc123"
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
OpenAIGPTConfig.__init__(self, **kwargs)
|
||||
BaseAWSLLM.__init__(self, **kwargs)
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> Optional[str]:
|
||||
return "bedrock"
|
||||
|
||||
def _get_openai_model_id(self, model: str) -> str:
|
||||
"""
|
||||
Extract the actual model ID from the LiteLLM model name.
|
||||
|
||||
Input format: bedrock/openai/<model-id>
|
||||
Returns: <model-id>
|
||||
"""
|
||||
# Remove bedrock/ prefix if present
|
||||
if model.startswith("bedrock/"):
|
||||
model = model[8:]
|
||||
|
||||
# Remove openai/ prefix
|
||||
if model.startswith("openai/"):
|
||||
model = model[7:]
|
||||
|
||||
return model
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
model: str,
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get the complete URL for the Bedrock invoke endpoint.
|
||||
|
||||
Uses the standard Bedrock invoke endpoint format.
|
||||
"""
|
||||
model_id = self._get_openai_model_id(model)
|
||||
|
||||
# Get AWS region
|
||||
aws_region_name = self._get_aws_region_name(
|
||||
optional_params=optional_params, model=model
|
||||
)
|
||||
|
||||
# Get runtime endpoint
|
||||
aws_bedrock_runtime_endpoint = optional_params.get(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
)
|
||||
endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
|
||||
api_base=api_base,
|
||||
aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
|
||||
aws_region_name=aws_region_name,
|
||||
)
|
||||
|
||||
# Encode model ID for ARNs (e.g., :imported-model/ -> :imported-model%2F)
|
||||
model_id = CommonUtils.encode_bedrock_runtime_modelid_arn(model_id)
|
||||
|
||||
# Build the invoke URL
|
||||
if stream:
|
||||
endpoint_url = (
|
||||
f"{endpoint_url}/model/{model_id}/invoke-with-response-stream"
|
||||
)
|
||||
else:
|
||||
endpoint_url = f"{endpoint_url}/model/{model_id}/invoke"
|
||||
|
||||
return endpoint_url
|
||||
|
||||
def sign_request(
|
||||
self,
|
||||
headers: dict,
|
||||
optional_params: dict,
|
||||
request_data: dict,
|
||||
api_base: str,
|
||||
api_key: Optional[str] = None,
|
||||
model: Optional[str] = None,
|
||||
stream: Optional[bool] = None,
|
||||
fake_stream: Optional[bool] = None,
|
||||
) -> Tuple[dict, Optional[bytes]]:
|
||||
"""
|
||||
Sign the request using AWS Signature Version 4.
|
||||
"""
|
||||
return self._sign_request(
|
||||
service_name="bedrock",
|
||||
headers=headers,
|
||||
optional_params=optional_params,
|
||||
request_data=request_data,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
stream=stream,
|
||||
fake_stream=fake_stream,
|
||||
)
|
||||
|
||||
def transform_request(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
"""
|
||||
Transform the request to OpenAI Chat Completions format for Bedrock imported models.
|
||||
|
||||
Removes AWS-specific params and stream param (handled separately in URL),
|
||||
then delegates to parent class for standard OpenAI request transformation.
|
||||
"""
|
||||
# Remove stream from optional_params as it's handled separately in URL
|
||||
optional_params.pop("stream", None)
|
||||
|
||||
# Remove AWS-specific params that shouldn't be in the request body
|
||||
inference_params = {
|
||||
k: v
|
||||
for k, v in optional_params.items()
|
||||
if k not in self.aws_authentication_params
|
||||
}
|
||||
|
||||
# Use parent class transform_request for OpenAI format
|
||||
return super().transform_request(
|
||||
model=self._get_openai_model_id(model),
|
||||
messages=messages,
|
||||
optional_params=inference_params,
|
||||
litellm_params=litellm_params,
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """Return the request headers unchanged.

    Bedrock authenticates through AWS SigV4 request signing (handled by
    ``sign_request``), so no Bearer-token / API-key header needs to be
    injected here.
    """
    return headers
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BedrockError:
    """Wrap an upstream failure in the provider-specific ``BedrockError``.

    The response headers are accepted for interface compatibility but are
    not needed to construct the error.
    """
    bedrock_error = BedrockError(status_code=status_code, message=error_message)
    return bedrock_error
|
||||
@@ -0,0 +1,99 @@
|
||||
"""
|
||||
Handles transforming requests for `bedrock/invoke/{qwen2} models`
|
||||
|
||||
Inherits from `AmazonQwen3Config` since Qwen2 and Qwen3 architectures are mostly similar.
|
||||
The main difference is in the response format: Qwen2 uses "text" field while Qwen3 uses "generation" field.
|
||||
|
||||
Qwen2 + Invoke API Tutorial: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.amazon_qwen3_transformation import (
|
||||
AmazonQwen3Config,
|
||||
)
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
|
||||
|
||||
class AmazonQwen2Config(AmazonQwen3Config):
    """
    Config for sending `qwen2` requests to `/bedrock/invoke/`

    Inherits from AmazonQwen3Config since Qwen2 and Qwen3 architectures are mostly similar.
    The main difference is in the response format: Qwen2 uses "text" field while Qwen3 uses "generation" field.

    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
    """

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """Convert a Qwen2 Bedrock invoke response into OpenAI format.

        Qwen2 normally returns its completion under "text"; "generation"
        (the Qwen3 key) is also accepted for compatibility and is checked
        first.
        """
        try:
            if hasattr(raw_response, "json"):
                payload = raw_response.json()
            else:
                payload = raw_response

            # Prefer "generation" (Qwen3-style), fall back to "text" (Qwen2-style).
            completion_text = payload.get("generation", "") or payload.get("text", "")

            # Strip chat-template framing tokens if the model echoed them.
            start_marker = "<|im_start|>assistant\n"
            end_marker = "<|im_end|>"
            if completion_text.startswith(start_marker):
                completion_text = completion_text[len(start_marker) :]
            if completion_text.endswith(end_marker):
                completion_text = completion_text[: -len(end_marker)]

            # Write the cleaned text into the pre-built response object.
            if hasattr(model_response, "choices") and len(model_response.choices) > 0:
                first_choice = model_response.choices[0]
                first_choice.message.content = completion_text
                first_choice.finish_reason = "stop"

            # Attach token usage when the provider reported it.
            if "usage" in payload:
                reported_usage = payload["usage"]
                setattr(
                    model_response,
                    "usage",
                    Usage(
                        prompt_tokens=reported_usage.get("prompt_tokens", 0),
                        completion_tokens=reported_usage.get("completion_tokens", 0),
                        total_tokens=reported_usage.get("total_tokens", 0),
                    ),
                )

            return model_response

        except Exception as e:
            # Surface the failure through the logging hook before re-raising.
            if logging_obj:
                logging_obj.post_call(
                    input=messages,
                    api_key=api_key,
                    original_response=raw_response,
                    additional_args={"error": str(e)},
                )
            raise e
|
||||
@@ -0,0 +1,225 @@
|
||||
"""
|
||||
Handles transforming requests for `bedrock/invoke/{qwen3} models`
|
||||
|
||||
Inherits from `AmazonInvokeConfig`
|
||||
|
||||
Qwen3 + Invoke API Tutorial: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
|
||||
"""
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
|
||||
|
||||
class AmazonQwen3Config(AmazonInvokeConfig, BaseConfig):
    """
    Config for sending `qwen3` requests to `/bedrock/invoke/`

    Translates OpenAI-style chat requests into Qwen chat-template prompts
    and maps the raw invoke response back to an OpenAI-shaped ModelResponse.

    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/invoke-imported-model.html
    """

    # Default inference parameters; None means "not set" and is omitted
    # from the request body.
    max_tokens: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    stop: Optional[List[str]] = None

    def __init__(
        self,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        stop: Optional[List[str]] = None,
    ) -> None:
        # NOTE: non-None constructor args are written onto the *class*,
        # following the pattern used by the other Bedrock config classes.
        locals_ = locals().copy()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
        AmazonInvokeConfig.__init__(self)

    def get_supported_openai_params(self, model: str) -> List[str]:
        """Return the OpenAI params this config can translate for Qwen3."""
        return [
            "max_tokens",
            "temperature",
            "top_p",
            "top_k",
            "stop",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Copy supported OpenAI params into `optional_params` unchanged;
        unrecognized params are silently ignored."""
        for k, v in non_default_params.items():
            if k == "max_tokens":
                optional_params["max_tokens"] = v
            if k == "temperature":
                optional_params["temperature"] = v
            if k == "top_p":
                optional_params["top_p"] = v
            if k == "top_k":
                optional_params["top_k"] = v
            if k == "stop":
                optional_params["stop"] = v
            if k == "stream":
                optional_params["stream"] = v
        return optional_params

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform OpenAI format to Qwen3 Bedrock invoke format
        """
        # Convert messages to prompt format
        prompt = self._convert_messages_to_prompt(messages)

        # Build the request body
        request_body = {
            "prompt": prompt,
        }

        # Add optional parameters
        # NOTE(review): max_tokens is sent as "max_gen_len", which is the
        # Llama-style key — confirm this matches the imported Qwen model's
        # expected schema.
        if "max_tokens" in optional_params:
            request_body["max_gen_len"] = optional_params["max_tokens"]
        if "temperature" in optional_params:
            request_body["temperature"] = optional_params["temperature"]
        if "top_p" in optional_params:
            request_body["top_p"] = optional_params["top_p"]
        if "top_k" in optional_params:
            request_body["top_k"] = optional_params["top_k"]
        if "stop" in optional_params:
            request_body["stop"] = optional_params["stop"]

        return request_body

    def _convert_messages_to_prompt(self, messages: List[AllMessageValues]) -> str:
        """
        Convert OpenAI messages format to Qwen3 prompt format
        Supports tool calls, multimodal content, and various message types
        """
        prompt_parts = []

        for message in messages:
            role = message.get("role", "")
            content = message.get("content", "")
            tool_calls = message.get("tool_calls", [])

            if role == "system":
                prompt_parts.append(f"<|im_start|>system\n{content}<|im_end|>")
            elif role == "user":
                # Handle multimodal content: keep text parts, replace images
                # with Qwen vision placeholder tokens.
                if isinstance(content, list):
                    text_content = []
                    for item in content:
                        if item.get("type") == "text":
                            text_content.append(item.get("text", ""))
                        elif item.get("type") == "image_url":
                            # For Qwen3, we can include image placeholders
                            text_content.append(
                                "<|vision_start|><|image_pad|><|vision_end|>"
                            )
                    content = "".join(text_content)
                prompt_parts.append(f"<|im_start|>user\n{content}<|im_end|>")
            elif role == "assistant":
                if tool_calls and isinstance(tool_calls, list):
                    # Handle tool calls: each call is rendered as a
                    # <tool_call> block. NOTE(review): the arguments string
                    # is interpolated without JSON escaping — nested quotes
                    # would produce malformed JSON inside the prompt.
                    for tool_call in tool_calls:
                        function_name = tool_call.get("function", {}).get("name", "")
                        function_args = tool_call.get("function", {}).get(
                            "arguments", ""
                        )
                        prompt_parts.append(
                            f'<|im_start|>assistant\n<tool_call>\n{{"name": "{function_name}", "arguments": "{function_args}"}}\n</tool_call><|im_end|>'
                        )
                else:
                    prompt_parts.append(f"<|im_start|>assistant\n{content}<|im_end|>")
            elif role == "tool":
                # Handle tool responses
                prompt_parts.append(f"<|im_start|>tool\n{content}<|im_end|>")

        # Add assistant start token for response generation
        prompt_parts.append("<|im_start|>assistant\n")

        return "\n".join(prompt_parts)

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform Qwen3 Bedrock response to OpenAI format
        """
        try:
            # raw_response may be an httpx.Response or an already-parsed dict
            if hasattr(raw_response, "json"):
                response_data = raw_response.json()
            else:
                response_data = raw_response

            # Extract the generated text - Qwen3 uses "generation" field
            generated_text = response_data.get("generation", "")

            # Clean up the response (remove assistant start token if present)
            if generated_text.startswith("<|im_start|>assistant\n"):
                generated_text = generated_text[len("<|im_start|>assistant\n") :]
            if generated_text.endswith("<|im_end|>"):
                generated_text = generated_text[: -len("<|im_end|>")]

            # Set the content in the existing model_response structure
            if hasattr(model_response, "choices") and len(model_response.choices) > 0:
                choice = model_response.choices[0]
                choice.message.content = generated_text
                choice.finish_reason = "stop"

            # Set usage information if available in response
            if "usage" in response_data:
                usage_data = response_data["usage"]
                setattr(
                    model_response,
                    "usage",
                    Usage(
                        prompt_tokens=usage_data.get("prompt_tokens", 0),
                        completion_tokens=usage_data.get("completion_tokens", 0),
                        total_tokens=usage_data.get("total_tokens", 0),
                    ),
                )

            return model_response

        except Exception as e:
            # Report the failure via the logging hook, then propagate.
            if logging_obj:
                logging_obj.post_call(
                    input=messages,
                    api_key=api_key,
                    original_response=raw_response,
                    additional_args={"error": str(e)},
                )
            raise e
|
||||
@@ -0,0 +1,116 @@
|
||||
import re
|
||||
import types
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
|
||||
|
||||
class AmazonTitanConfig(AmazonInvokeConfig, BaseConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-text-express-v1

    Supported Params for the Amazon Titan models:

    - `maxTokenCount` (integer) max tokens,
    - `stopSequences` (string[]) list of stop sequence strings
    - `temperature` (float) temperature for model,
    - `topP` (int) top p for model
    """

    maxTokenCount: Optional[int] = None
    stopSequences: Optional[list] = None
    temperature: Optional[float] = None
    topP: Optional[int] = None

    def __init__(
        self,
        maxTokenCount: Optional[int] = None,
        stopSequences: Optional[list] = None,
        temperature: Optional[float] = None,
        topP: Optional[int] = None,
    ) -> None:
        # Non-None constructor args are promoted to class attributes,
        # matching the pattern used by the other Bedrock configs.
        init_args = locals().copy()
        for name, val in init_args.items():
            if name == "self" or val is None:
                continue
            setattr(self.__class__, name, val)

        AmazonInvokeConfig.__init__(self)

    @classmethod
    def get_config(cls):
        """Return the config's data attributes (dunder/ABC machinery,
        callables, and unset values excluded)."""
        skip_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        return {
            name: attr
            for name, attr in cls.__dict__.items()
            if not name.startswith("__")
            and not name.startswith("_abc")
            and not isinstance(attr, skip_types)
            and attr is not None
        }

    def _map_and_modify_arg(
        self,
        supported_params: dict,
        provider: str,
        model: str,
        stop: Union[List[str], str],
    ):
        """
        filter params to fit the required provider format, drop those that don't fit if user sets `litellm.drop_params = True`.
        """
        filtered_stop = None
        if "stop" in supported_params and litellm.drop_params:
            # Titan only accepts a narrow set of stop sequences; keep
            # pipe-runs and the "User:" marker, drop everything else.
            if provider == "bedrock" and "amazon" in model:
                if isinstance(stop, list):
                    filtered_stop = [
                        seq for seq in stop if re.match(r"^(\|+|User:)$", seq)
                    ]
                else:
                    filtered_stop = []
        if filtered_stop is not None:
            supported_params["stop"] = filtered_stop

        return supported_params

    def get_supported_openai_params(self, model: str) -> List[str]:
        """OpenAI params this config can translate into Titan params."""
        return [
            "max_tokens",
            "max_completion_tokens",
            "stop",
            "temperature",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """Translate OpenAI params into Titan's `textGenerationConfig` keys."""
        # Straightforward renames; "stop" needs extra filtering below.
        rename_map = {
            "max_tokens": "maxTokenCount",
            "max_completion_tokens": "maxTokenCount",
            "temperature": "temperature",
            "top_p": "topP",
            "stream": "stream",
        }
        for param, value in non_default_params.items():
            if param == "stop":
                filtered = self._map_and_modify_arg(
                    {"stop": value}, provider="bedrock", model=model, stop=value
                )
                optional_params["stopSequences"] = filtered["stop"]
            elif param in rename_map:
                optional_params[rename_map[param]] = value
        return optional_params
|
||||
@@ -0,0 +1,280 @@
|
||||
"""
|
||||
Transforms OpenAI-style requests into TwelveLabs Pegasus 1.2 requests for Bedrock.
|
||||
|
||||
Reference:
|
||||
https://docs.twelvelabs.io/docs/models/pegasus
|
||||
"""
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.llms.base_llm.base_utils import type_to_response_format_param
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
from litellm.utils import get_base64_str
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonTwelveLabsPegasusConfig(AmazonInvokeConfig, BaseConfig):
    """
    Handles transforming OpenAI-style requests into Bedrock InvokeModel requests for
    `twelvelabs.pegasus-1-2-v1:0`.

    Pegasus 1.2 requires an `inputPrompt` and a `mediaSource` that either references
    an S3 object or a base64-encoded clip. Optional OpenAI params (temperature,
    response_format, max_tokens) are translated to the TwelveLabs schema.
    """

    def get_supported_openai_params(self, model: str) -> List[str]:
        # OpenAI params this config knows how to translate for Pegasus.
        return [
            "max_tokens",
            "max_completion_tokens",
            "temperature",
            "response_format",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        # Translate OpenAI param names into the TwelveLabs schema;
        # unknown params are silently ignored.
        for param, value in non_default_params.items():
            if param in {"max_tokens", "max_completion_tokens"}:
                optional_params["maxOutputTokens"] = value
            if param == "temperature":
                optional_params["temperature"] = value
            if param == "response_format":
                optional_params["responseFormat"] = self._normalize_response_format(
                    value
                )
        return optional_params

    def _normalize_response_format(self, value: Any) -> Any:
        """Normalize response_format to TwelveLabs format.

        TwelveLabs expects:
        {
            "jsonSchema": {...}
        }

        But OpenAI format is:
        {
            "type": "json_schema",
            "json_schema": {
                "name": "...",
                "schema": {...}
            }
        }
        """
        if isinstance(value, dict):
            # If it has json_schema field, extract and transform it
            if "json_schema" in value:
                json_schema = value["json_schema"]
                # Extract the schema if nested
                if isinstance(json_schema, dict) and "schema" in json_schema:
                    return {"jsonSchema": json_schema["schema"]}
                # Otherwise use json_schema directly
                return {"jsonSchema": json_schema}
            # If it already has jsonSchema, return as is
            if "jsonSchema" in value:
                return value
            # Otherwise return the dict as is
            return value
        # Non-dict values (e.g. a pydantic model class) go through litellm's
        # converter; fall back to the raw value if conversion yields nothing.
        return type_to_response_format_param(response_format=value) or value

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        # Flatten the chat messages into the single inputPrompt string
        # Pegasus expects.
        input_prompt = self._convert_messages_to_prompt(messages=messages)
        request_data: Dict[str, Any] = {"inputPrompt": input_prompt}

        media_source = self._build_media_source(optional_params)
        if media_source is not None:
            request_data["mediaSource"] = media_source

        # Handle temperature and maxOutputTokens
        for key in ("temperature", "maxOutputTokens"):
            if key in optional_params:
                request_data[key] = optional_params.get(key)

        # Handle responseFormat - transform to TwelveLabs format
        if "responseFormat" in optional_params:
            response_format = optional_params["responseFormat"]
            transformed_format = self._normalize_response_format(response_format)
            if transformed_format:
                request_data["responseFormat"] = transformed_format

        return request_data

    def _build_media_source(self, optional_params: dict) -> Optional[dict]:
        # Resolve the media source, in priority order: an explicit
        # mediaSource dict, then base64 video input, then an S3 URI.
        direct_source = optional_params.get("mediaSource") or optional_params.get(
            "media_source"
        )
        if isinstance(direct_source, dict):
            return direct_source

        base64_input = optional_params.get("video_base64") or optional_params.get(
            "base64_string"
        )
        if base64_input:
            # get_base64_str strips any data-URI prefix from the payload.
            return {"base64String": get_base64_str(base64_input)}

        s3_uri = (
            optional_params.get("video_s3_uri")
            or optional_params.get("s3_uri")
            or optional_params.get("media_source_s3_uri")
        )
        if s3_uri:
            s3_location = {"uri": s3_uri}
            # bucketOwner is only required for cross-account S3 access.
            bucket_owner = (
                optional_params.get("video_s3_bucket_owner")
                or optional_params.get("s3_bucket_owner")
                or optional_params.get("media_source_bucket_owner")
            )
            if bucket_owner:
                s3_location["bucketOwner"] = bucket_owner
            return {"s3Location": s3_location}
        return None

    def _convert_messages_to_prompt(self, messages: List[AllMessageValues]) -> str:
        # Render each message as "role: content"; multimodal parts collapse
        # to their text plus <image>/<video>/<audio> placeholder tokens.
        prompt_parts: List[str] = []
        for message in messages:
            role = message.get("role", "user")
            content = message.get("content", "")
            if isinstance(content, list):
                text_fragments = []
                for item in content:
                    if isinstance(item, dict):
                        item_type = item.get("type")
                        if item_type == "text":
                            text_fragments.append(item.get("text", ""))
                        elif item_type == "image_url":
                            text_fragments.append("<image>")
                        elif item_type == "video_url":
                            text_fragments.append("<video>")
                        elif item_type == "audio_url":
                            text_fragments.append("<audio>")
                    elif isinstance(item, str):
                        text_fragments.append(item)
                content = " ".join(text_fragments)
            prompt_parts.append(f"{role}: {content}")
        return "\n".join(part for part in prompt_parts if part).strip()

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj: LiteLLMLoggingObj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding: Any,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform TwelveLabs Pegasus response to LiteLLM format.

        TwelveLabs response format:
        {
            "message": "...",
            "finishReason": "stop" | "length"
        }

        LiteLLM format:
        ModelResponse with choices[0].message.content and finish_reason
        """
        try:
            completion_response = raw_response.json()
        except Exception as e:
            raise BedrockError(
                message=f"Error parsing response: {raw_response.text}, error: {str(e)}",
                status_code=raw_response.status_code,
            )

        verbose_logger.debug(
            "twelvelabs pegasus response: %s",
            json.dumps(completion_response, indent=4, default=str),
        )

        # Extract message content
        message_content = completion_response.get("message", "")

        # Extract finish reason and map to LiteLLM format
        finish_reason_raw = completion_response.get("finishReason", "stop")
        finish_reason = map_finish_reason(finish_reason_raw)

        # Set the response content
        try:
            if (
                message_content
                and hasattr(model_response.choices[0], "message")
                and getattr(model_response.choices[0].message, "tool_calls", None)
                is None
            ):
                model_response.choices[0].message.content = message_content  # type: ignore
                model_response.choices[0].finish_reason = finish_reason
            else:
                raise Exception("Unable to set message content")
        except Exception as e:
            raise BedrockError(
                message=f"Error setting response content: {str(e)}. Response: {completion_response}",
                status_code=raw_response.status_code,
            )

        # Calculate usage from headers; fall back to local token counting
        # when Bedrock doesn't report counts.
        bedrock_input_tokens = raw_response.headers.get(
            "x-amzn-bedrock-input-token-count", None
        )
        bedrock_output_tokens = raw_response.headers.get(
            "x-amzn-bedrock-output-token-count", None
        )

        prompt_tokens = int(
            bedrock_input_tokens or litellm.token_counter(messages=messages)
        )

        completion_tokens = int(
            bedrock_output_tokens
            or litellm.token_counter(
                text=model_response.choices[0].message.content,  # type: ignore
                count_response_tokens=True,
            )
        )

        model_response.created = int(time.time())
        model_response.model = model
        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        )
        setattr(model_response, "usage", usage)

        return model_response
|
||||
@@ -0,0 +1,98 @@
|
||||
import types
|
||||
from typing import Optional
|
||||
|
||||
import litellm
|
||||
|
||||
from .base_invoke_transformation import AmazonInvokeConfig
|
||||
|
||||
|
||||
class AmazonAnthropicConfig(AmazonInvokeConfig):
    """
    Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude

    Supported Params for the Amazon / Anthropic models:

    - `max_tokens_to_sample` (integer) max tokens,
    - `temperature` (float) model temperature,
    - `top_k` (integer) top k,
    - `top_p` (integer) top p,
    - `stop_sequences` (string[]) list of stop sequences - e.g. ["\\n\\nHuman:"],
    - `anthropic_version` (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
    """

    max_tokens_to_sample: Optional[int] = litellm.max_tokens
    stop_sequences: Optional[list] = None
    temperature: Optional[float] = None
    top_k: Optional[int] = None
    top_p: Optional[int] = None
    anthropic_version: Optional[str] = None

    def __init__(
        self,
        max_tokens_to_sample: Optional[int] = None,
        stop_sequences: Optional[list] = None,
        temperature: Optional[float] = None,
        top_k: Optional[int] = None,
        top_p: Optional[int] = None,
        anthropic_version: Optional[str] = None,
    ) -> None:
        # Non-None constructor args are promoted to class attributes,
        # matching the other Bedrock config classes. (No base __init__ call
        # here, mirroring the original behavior.)
        init_args = locals().copy()
        for name, val in init_args.items():
            if name == "self" or val is None:
                continue
            setattr(self.__class__, name, val)

    @classmethod
    def get_config(cls):
        """Return the config's data attributes (dunders, callables, and
        unset values excluded)."""
        skip_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        return {
            name: attr
            for name, attr in cls.__dict__.items()
            if not name.startswith("__")
            and not isinstance(attr, skip_types)
            and attr is not None
        }

    @staticmethod
    def get_legacy_anthropic_model_names():
        """Model IDs that still use the legacy text-completions schema."""
        return [
            "anthropic.claude-v2",
            "anthropic.claude-instant-v1",
            "anthropic.claude-v2:1",
        ]

    def get_supported_openai_params(self, model: str):
        """OpenAI params this config can translate for legacy Claude models."""
        return [
            "max_tokens",
            "max_completion_tokens",
            "temperature",
            "stop",
            "top_p",
            "stream",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ):
        """Translate OpenAI params to the legacy Anthropic completion schema."""
        for param, value in non_default_params.items():
            if param in ("max_tokens", "max_completion_tokens"):
                optional_params["max_tokens_to_sample"] = value
            elif param == "temperature":
                optional_params["temperature"] = value
            elif param == "top_p":
                optional_params["top_p"] = value
            elif param == "stop":
                optional_params["stop_sequences"] = value
            elif param == "stream" and value is True:
                # stream is only forwarded when explicitly enabled
                optional_params["stream"] = value
        return optional_params
|
||||
@@ -0,0 +1,206 @@
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.anthropic.chat.transformation import AnthropicConfig
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.llms.bedrock.common_utils import (
|
||||
get_anthropic_beta_from_headers,
|
||||
remove_custom_field_from_tools,
|
||||
)
|
||||
from litellm.types.llms.anthropic import ANTHROPIC_TOOL_SEARCH_BETA_HEADER
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonAnthropicClaudeConfig(AmazonInvokeConfig, AnthropicConfig):
|
||||
"""
|
||||
Reference:
|
||||
https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
|
||||
https://docs.anthropic.com/claude/docs/models-overview#model-comparison
|
||||
https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html
|
||||
|
||||
Supported Params for the Amazon / Anthropic Claude models (Claude 3, Claude 4, etc.):
|
||||
Supports anthropic_beta parameter for beta features like:
|
||||
- computer-use-2025-01-24 (Claude 3.7 Sonnet)
|
||||
- computer-use-2024-10-22 (Claude 3.5 Sonnet v2)
|
||||
- token-efficient-tools-2025-02-19 (Claude 3.7 Sonnet)
|
||||
- interleaved-thinking-2025-05-14 (Claude 4 models)
|
||||
- output-128k-2025-02-19 (Claude 3.7 Sonnet)
|
||||
- dev-full-thinking-2025-05-14 (Claude 4 models)
|
||||
- context-1m-2025-08-07 (Claude Sonnet 4)
|
||||
"""
|
||||
|
||||
anthropic_version: str = "bedrock-2023-05-31"
|
||||
|
||||
@property
def custom_llm_provider(self) -> Optional[str]:
    """Provider tag for this config: always the Bedrock backend."""
    return "bedrock"
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
    """Expose exactly the OpenAI params the Anthropic chat config supports;
    Bedrock-Claude adds no restrictions of its own here."""
    supported = AnthropicConfig.get_supported_openai_params(self, model)
    return supported
|
||||
|
||||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """Map OpenAI params to Anthropic params for Bedrock Invoke.

    Bedrock Invoke does not support Anthropic's native `output_format`
    parameter, so when the caller requests `response_format` we pass a model
    name to AnthropicConfig that forces the tool-based structured-output
    path instead (same approach as the VertexAI fix in #19201).

    Fix: the original rebound the `model` parameter and then "restored" it
    with a dead `model = original_model` assignment (a local rebinding after
    its last use has no effect). A separate local makes the substitution
    explicit and leaves `model` untouched.
    """
    # Substitute the model name only for the delegated call; the choice of
    # name just needs to steer AnthropicConfig away from output_format.
    anthropic_model = model
    if "response_format" in non_default_params:
        anthropic_model = "claude-3-sonnet-20240229"

    return AnthropicConfig.map_openai_params(
        self,
        non_default_params,
        optional_params,
        anthropic_model,
        drop_params,
    )
|
||||
|
||||
def transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """
    Build the Anthropic-format request body for Bedrock Invoke.

    Delegates to AnthropicConfig.transform_request, then strips fields that
    Bedrock Invoke rejects (`model`, `stream`, `output_format`,
    `output_config`, tool `custom` fields), injects `anthropic_version`,
    and assembles the `anthropic_beta` list from request headers plus
    auto-detected features (computer use, file IDs, MCP servers, ...).
    """
    # Filter out AWS authentication parameters before passing to Anthropic transformation
    # AWS params should only be used for signing requests, not included in request body
    filtered_params = {
        k: v
        for k, v in optional_params.items()
        if k not in self.aws_authentication_params
    }
    filtered_params = self._normalize_bedrock_tool_search_tools(filtered_params)

    _anthropic_request = AnthropicConfig.transform_request(
        self,
        model=model,
        messages=messages,
        optional_params=filtered_params,
        litellm_params=litellm_params,
        headers=headers,
    )

    # Bedrock carries the model in the URL, and streaming is selected via a
    # separate endpoint — neither belongs in the request body.
    _anthropic_request.pop("model", None)
    _anthropic_request.pop("stream", None)
    # Bedrock Invoke doesn't support output_format parameter
    _anthropic_request.pop("output_format", None)
    # Bedrock Invoke doesn't support output_config parameter
    # Fixes: https://github.com/BerriAI/litellm/issues/22797
    _anthropic_request.pop("output_config", None)
    if "anthropic_version" not in _anthropic_request:
        _anthropic_request["anthropic_version"] = self.anthropic_version

    # Remove `custom` field from tools (Bedrock doesn't support it)
    # Claude Code sends `custom: {defer_loading: true}` on tool definitions,
    # which causes Bedrock to reject the request with "Extra inputs are not permitted"
    # Ref: https://github.com/BerriAI/litellm/issues/22847
    remove_custom_field_from_tools(_anthropic_request)

    tools = optional_params.get("tools")
    tool_search_used = self.is_tool_search_used(tools)
    programmatic_tool_calling_used = self.is_programmatic_tool_calling_used(tools)
    input_examples_used = self.is_input_examples_used(tools)

    # Merge caller-supplied beta headers with feature-derived ones.
    beta_set = set(get_anthropic_beta_from_headers(headers))
    auto_betas = self.get_anthropic_beta_list(
        model=model,
        optional_params=optional_params,
        computer_tool_used=self.is_computer_tool_used(tools),
        prompt_caching_set=False,
        file_id_used=self.is_file_id_used(messages),
        mcp_server_used=self.is_mcp_server_used(optional_params.get("mcp_servers")),
    )
    beta_set.update(auto_betas)

    # NOTE(review): indentation reconstructed from a mangled source — the
    # opus-4 branch is assumed nested under the tool-search condition;
    # confirm against upstream.
    if tool_search_used and not (
        programmatic_tool_calling_used or input_examples_used
    ):
        beta_set.discard(ANTHROPIC_TOOL_SEARCH_BETA_HEADER)
        if "opus-4" in model.lower() or "opus_4" in model.lower():
            beta_set.add("tool-search-tool-2025-10-19")

    # Filter out beta headers that Bedrock Invoke doesn't support
    # Uses centralized configuration from anthropic_beta_headers_config.json
    # NOTE(review): no filtering is visible at this point — presumably it
    # happens inside get_anthropic_beta_list / upstream; verify.
    beta_list = list(beta_set)
    _anthropic_request["anthropic_beta"] = beta_list

    return _anthropic_request
|
||||
|
||||
def _normalize_bedrock_tool_search_tools(self, optional_params: dict) -> dict:
|
||||
"""
|
||||
Convert tool search entries to the format supported by the Bedrock Invoke API.
|
||||
"""
|
||||
tools = optional_params.get("tools")
|
||||
if not tools or not isinstance(tools, list):
|
||||
return optional_params
|
||||
|
||||
normalized_tools = []
|
||||
for tool in tools:
|
||||
tool_type = tool.get("type")
|
||||
if tool_type == "tool_search_tool_bm25_20251119":
|
||||
# Bedrock Invoke does not support the BM25 variant, so skip it.
|
||||
continue
|
||||
if tool_type == "tool_search_tool_regex_20251119":
|
||||
normalized_tool = tool.copy()
|
||||
normalized_tool["type"] = "tool_search_tool_regex"
|
||||
normalized_tool["name"] = normalized_tool.get(
|
||||
"name", "tool_search_tool_regex"
|
||||
)
|
||||
normalized_tools.append(normalized_tool)
|
||||
continue
|
||||
normalized_tools.append(tool)
|
||||
|
||||
optional_params["tools"] = normalized_tools
|
||||
return optional_params
|
||||
|
||||
def transform_response(
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """Delegate response parsing to the shared Anthropic transformation."""
    parsed: ModelResponse = AnthropicConfig.transform_response(
        self,
        model=model,
        raw_response=raw_response,
        model_response=model_response,
        logging_obj=logging_obj,
        request_data=request_data,
        messages=messages,
        optional_params=optional_params,
        litellm_params=litellm_params,
        encoding=encoding,
        api_key=api_key,
        json_mode=json_mode,
    )
    return parsed
|
||||
@@ -0,0 +1,613 @@
|
||||
import copy
|
||||
import json
|
||||
import time
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
|
||||
from litellm.litellm_core_utils.prompt_templates.factory import (
|
||||
cohere_message_pt,
|
||||
custom_prompt,
|
||||
deepseek_r1_pt,
|
||||
prompt_factory,
|
||||
)
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.llms.bedrock.chat.invoke_handler import make_call, make_sync_call
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ModelResponse, Usage
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
|
||||
|
||||
class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
|
||||
def __init__(self, **kwargs):
    """Initialize both base classes with the same keyword arguments."""
    # Explicit dual-init instead of cooperative super(): each base gets the
    # full kwargs. NOTE(review): presumably intentional to sidestep MRO
    # interactions between BaseConfig and BaseAWSLLM — confirm before
    # converting to super().__init__().
    BaseConfig.__init__(self, **kwargs)
    BaseAWSLLM.__init__(self, **kwargs)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
    """
    Base Invoke-model mapping. Provider-specific Bedrock Invoke configs
    extend this class and override with their own supported params.
    """
    supported: List[str] = [
        "max_tokens",
        "max_completion_tokens",
        "stream",
    ]
    return supported
|
||||
|
||||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """
    Base Invoke-model mapping. Provider-specific Bedrock Invoke configs
    extend this class and override with their own mapping.

    Both OpenAI spellings of the token limit collapse onto Bedrock's
    `max_tokens`; `stream` passes through unchanged.
    """
    for key, val in non_default_params.items():
        if key in ("max_tokens", "max_completion_tokens"):
            optional_params["max_tokens"] = val
        if key == "stream":
            optional_params["stream"] = val
    return optional_params
|
||||
|
||||
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: dict,
    stream: Optional[bool] = None,
) -> str:
    """
    Build the complete Bedrock Invoke endpoint URL for the request.

    Streaming requests (except AI21, which has no streaming endpoint and is
    fake-streamed upstream) hit `invoke-with-response-stream`; everything
    else hits `invoke`.

    Note: the original also built a `proxy_endpoint_url` string that was
    never used or returned — dead code, removed here.
    """
    provider = self.get_bedrock_invoke_provider(model)
    modelId = self.get_bedrock_model_id(
        model=model,
        provider=provider,
        optional_params=optional_params,
    )
    ### SET RUNTIME ENDPOINT ###
    aws_bedrock_runtime_endpoint = optional_params.get(
        "aws_bedrock_runtime_endpoint", None
    )  # https://bedrock-runtime.{region_name}.amazonaws.com
    # get_runtime_endpoint also returns a proxy endpoint, unused here.
    endpoint_url, _ = self.get_runtime_endpoint(
        api_base=api_base,
        aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
        aws_region_name=self._get_aws_region_name(
            optional_params=optional_params, model=model
        ),
    )

    if stream is True and provider != "ai21":
        endpoint_url = f"{endpoint_url}/model/{modelId}/invoke-with-response-stream"
    else:
        endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"

    return endpoint_url
|
||||
|
||||
def sign_request(
    self,
    headers: dict,
    optional_params: dict,
    request_data: dict,
    api_base: str,
    api_key: Optional[str] = None,
    model: Optional[str] = None,
    stream: Optional[bool] = None,
    fake_stream: Optional[bool] = None,
) -> Tuple[dict, Optional[bytes]]:
    """SigV4-sign the request against the ``bedrock`` service via BaseAWSLLM."""
    signed_headers, signed_body = self._sign_request(
        service_name="bedrock",
        headers=headers,
        optional_params=optional_params,
        request_data=request_data,
        api_base=api_base,
        api_key=api_key,
        model=model,
        stream=stream,
        fake_stream=fake_stream,
    )
    return signed_headers, signed_body
|
||||
|
||||
def _apply_config_to_params(self, config: dict, inference_params: dict) -> None:
|
||||
"""Apply config values to inference_params if not already set."""
|
||||
for k, v in config.items():
|
||||
if k not in inference_params:
|
||||
inference_params[k] = v
|
||||
|
||||
def transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """
    Build the provider-specific Bedrock Invoke request body.

    Providers with their own transformation config (anthropic, nova,
    twelvelabs, openai) return early via delegation; for the rest, messages
    are converted to a prompt string and provider-default config values are
    merged into the inference params.

    Raises:
        BedrockError: 404 if the provider cannot be determined from `model`.
    """
    ## SETUP ##
    # `stream` must not appear in the Invoke request body (see
    # supports_stream_param_in_request_body), so it is popped up front.
    stream = optional_params.pop("stream", None)
    custom_prompt_dict: dict = litellm_params.pop("custom_prompt_dict", None) or {}
    hf_model_name = litellm_params.get("hf_model_name", None)

    provider = self.get_bedrock_invoke_provider(model)

    prompt, chat_history = self.convert_messages_to_prompt(
        model=hf_model_name or model,
        messages=messages,
        provider=provider,
        custom_prompt_dict=custom_prompt_dict,
    )
    # Deep-copy so the provider-default merging below cannot mutate the
    # caller's dict; AWS auth params are stripped — signing-only, never
    # part of the request body.
    inference_params = copy.deepcopy(optional_params)
    inference_params = {
        k: v
        for k, v in inference_params.items()
        if k not in self.aws_authentication_params
    }
    request_data: dict = {}
    if provider == "cohere":
        if model.startswith("cohere.command-r"):
            ## LOAD CONFIG
            config = litellm.AmazonCohereChatConfig().get_config()
            self._apply_config_to_params(config, inference_params)
            _data = {"message": prompt, **inference_params}
            if chat_history is not None:
                _data["chat_history"] = chat_history
            request_data = _data
        else:
            ## LOAD CONFIG
            config = litellm.AmazonCohereConfig.get_config()
            self._apply_config_to_params(config, inference_params)
            if stream is True:
                inference_params[
                    "stream"
                ] = True  # cohere requires stream = True in inference params
            request_data = {"prompt": prompt, **inference_params}
    elif provider == "anthropic":
        transformed_request = (
            litellm.AmazonAnthropicClaudeConfig().transform_request(
                model=model,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                headers=headers,
            )
        )

        return transformed_request
    elif provider == "nova":
        return litellm.AmazonInvokeNovaConfig().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
    elif provider == "ai21":
        ## LOAD CONFIG
        config = litellm.AmazonAI21Config.get_config()
        self._apply_config_to_params(config, inference_params)
        request_data = {"prompt": prompt, **inference_params}
    elif provider == "mistral":
        ## LOAD CONFIG
        config = litellm.AmazonMistralConfig.get_config()
        self._apply_config_to_params(config, inference_params)
        request_data = {"prompt": prompt, **inference_params}
    elif provider == "amazon":  # amazon titan
        ## LOAD CONFIG
        config = litellm.AmazonTitanConfig.get_config()
        self._apply_config_to_params(config, inference_params)
        # Titan nests generation options under textGenerationConfig.
        request_data = {
            "inputText": prompt,
            "textGenerationConfig": inference_params,
        }
    elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
        ## LOAD CONFIG
        config = litellm.AmazonLlamaConfig.get_config()
        self._apply_config_to_params(config, inference_params)
        request_data = {"prompt": prompt, **inference_params}
    elif provider == "twelvelabs":
        return litellm.AmazonTwelveLabsPegasusConfig().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
    elif provider == "openai":
        # OpenAI imported models use OpenAI Chat Completions format
        return litellm.AmazonBedrockOpenAIConfig().transform_request(
            model=model,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            headers=headers,
        )
    else:
        raise BedrockError(
            status_code=404,
            message="Bedrock Invoke HTTPX: Unknown provider={}, model={}. Try calling via converse route - `bedrock/converse/<model>`.".format(
                provider, model
            ),
        )

    return request_data
|
||||
|
||||
def transform_response(  # noqa: PLR0915
    self,
    model: str,
    raw_response: httpx.Response,
    model_response: ModelResponse,
    logging_obj: LiteLLMLoggingObj,
    request_data: dict,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    encoding: Any,
    api_key: Optional[str] = None,
    json_mode: Optional[bool] = None,
) -> ModelResponse:
    """
    Parse a raw Bedrock Invoke HTTP response into a ModelResponse.

    anthropic/nova/twelvelabs responses are delegated to their dedicated
    configs and returned directly; other providers have their output text
    extracted inline. Usage is taken from the
    `x-amzn-bedrock-*-token-count` response headers when present, otherwise
    estimated with litellm.token_counter.

    Raises:
        BedrockError: on unparseable JSON (original status code), a failed
            field extraction (422), or missing output text (original
            status code).
    """
    try:
        completion_response = raw_response.json()
    except Exception:
        raise BedrockError(
            message=raw_response.text, status_code=raw_response.status_code
        )
    verbose_logger.debug(
        "bedrock invoke response % s",
        json.dumps(completion_response, indent=4, default=str),
    )
    provider = self.get_bedrock_invoke_provider(model)
    outputText: Optional[str] = None
    try:
        if provider == "cohere":
            # cohere responses may carry either "text" or "generations".
            if "text" in completion_response:
                outputText = completion_response["text"]  # type: ignore
            elif "generations" in completion_response:
                outputText = completion_response["generations"][0]["text"]
                model_response.choices[0].finish_reason = map_finish_reason(
                    completion_response["generations"][0]["finish_reason"]
                )
        elif provider == "anthropic":
            return litellm.AmazonAnthropicClaudeConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                api_key=api_key,
                json_mode=json_mode,
            )
        elif provider == "nova":
            return litellm.AmazonInvokeNovaConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
            )
        elif provider == "twelvelabs":
            return litellm.AmazonTwelveLabsPegasusConfig().transform_response(
                model=model,
                raw_response=raw_response,
                model_response=model_response,
                logging_obj=logging_obj,
                request_data=request_data,
                messages=messages,
                optional_params=optional_params,
                litellm_params=litellm_params,
                encoding=encoding,
                api_key=api_key,
                json_mode=json_mode,
            )
        elif provider == "ai21":
            outputText = (
                completion_response.get("completions")[0].get("data").get("text")
            )
        elif provider == "meta" or provider == "llama" or provider == "deepseek_r1":
            outputText = completion_response["generation"]
        elif provider == "mistral":
            outputText = litellm.AmazonMistralConfig.get_outputText(
                completion_response, model_response
            )
        else:  # amazon titan
            outputText = completion_response.get("results")[0].get("outputText")
    except Exception as e:
        raise BedrockError(
            message="Error processing={}, Received error={}".format(
                raw_response.text, str(e)
            ),
            status_code=422,
        )

    try:
        # Only overwrite message content when there are no tool calls;
        # tool-call-only responses pass through unchanged.
        if (
            outputText is not None
            and len(outputText) > 0
            and hasattr(model_response.choices[0], "message")
            and getattr(model_response.choices[0].message, "tool_calls", None)  # type: ignore
            is None
        ):
            model_response.choices[0].message.content = outputText  # type: ignore
        elif (
            hasattr(model_response.choices[0], "message")
            and getattr(model_response.choices[0].message, "tool_calls", None)  # type: ignore
            is not None
        ):
            pass
        else:
            # Neither text nor tool calls — surface as a parse error below.
            raise Exception()
    except Exception as e:
        raise BedrockError(
            message="Error parsing received text={}.\nError-{}".format(
                outputText, str(e)
            ),
            status_code=raw_response.status_code,
        )

    ## CALCULATING USAGE - bedrock returns usage in the headers
    bedrock_input_tokens = raw_response.headers.get(
        "x-amzn-bedrock-input-token-count", None
    )
    bedrock_output_tokens = raw_response.headers.get(
        "x-amzn-bedrock-output-token-count", None
    )

    prompt_tokens = int(
        bedrock_input_tokens or litellm.token_counter(messages=messages)
    )

    completion_tokens = int(
        bedrock_output_tokens
        or litellm.token_counter(
            text=model_response.choices[0].message.content,  # type: ignore
            count_response_tokens=True,
        )
    )

    model_response.created = int(time.time())
    model_response.model = model
    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    setattr(model_response, "usage", usage)

    return model_response
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
) -> dict:
|
||||
return headers
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
) -> BaseLLMException:
    """Wrap provider failures in a BedrockError; response headers are unused here."""
    err: BaseLLMException = BedrockError(
        status_code=status_code, message=error_message
    )
    return err
|
||||
|
||||
@track_llm_api_timing()
async def get_async_custom_stream_wrapper(
    self,
    model: str,
    custom_llm_provider: str,
    logging_obj: LiteLLMLoggingObj,
    api_base: str,
    headers: dict,
    data: dict,
    messages: list,
    client: Optional[AsyncHTTPHandler] = None,
    json_mode: Optional[bool] = None,
    signed_json_body: Optional[bytes] = None,
) -> CustomStreamWrapper:
    """
    Build an async streaming wrapper for Bedrock Invoke.

    The HTTP call itself is deferred: `make_call` is bound with `partial`
    and invoked lazily by CustomStreamWrapper. AI21 endpoints have no
    streaming route, so streaming is faked for them.

    NOTE(review): unlike the sync variant, `signed_json_body` is accepted
    here but NOT forwarded to `make_call` — confirm whether intentional.
    """
    streaming_response = CustomStreamWrapper(
        completion_stream=None,
        make_call=partial(
            make_call,
            client=client,
            api_base=api_base,
            headers=headers,
            data=json.dumps(data),
            model=model,
            messages=messages,
            logging_obj=logging_obj,
            fake_stream=True if "ai21" in api_base else False,
            bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
            json_mode=json_mode,
        ),
        model=model,
        custom_llm_provider="bedrock",
        logging_obj=logging_obj,
    )
    return streaming_response
|
||||
|
||||
@track_llm_api_timing()
def get_sync_custom_stream_wrapper(
    self,
    model: str,
    custom_llm_provider: str,
    logging_obj: LiteLLMLoggingObj,
    api_base: str,
    headers: dict,
    data: dict,
    messages: list,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    json_mode: Optional[bool] = None,
    signed_json_body: Optional[bytes] = None,
) -> CustomStreamWrapper:
    """
    Build a sync streaming wrapper for Bedrock Invoke.

    Mirrors get_async_custom_stream_wrapper but via `make_sync_call` (and
    it forwards `signed_json_body`). AI21 endpoints don't stream, so
    streaming is faked for them.
    """
    # A sync call needs a sync client; replace a missing/async client.
    if client is None or isinstance(client, AsyncHTTPHandler):
        client = _get_httpx_client(params={})
    streaming_response = CustomStreamWrapper(
        completion_stream=None,
        make_call=partial(
            make_sync_call,
            client=client,
            api_base=api_base,
            headers=headers,
            data=json.dumps(data),
            signed_json_body=signed_json_body,
            model=model,
            messages=messages,
            logging_obj=logging_obj,
            fake_stream=True if "ai21" in api_base else False,
            bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
            json_mode=json_mode,
        ),
        model=model,
        custom_llm_provider="bedrock",
        logging_obj=logging_obj,
    )
    return streaming_response
|
||||
|
||||
@property
def has_custom_stream_wrapper(self) -> bool:
    """Bedrock Invoke supplies its own sync/async streaming wrappers."""
    return True
|
||||
|
||||
@property
def supports_stream_param_in_request_body(self) -> bool:
    """
    Bedrock invoke does not allow passing `stream` in the request body;
    streaming is selected via the endpoint URL instead.
    """
    return False
|
||||
|
||||
@staticmethod
def get_bedrock_invoke_provider(
    model: str,
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
    """
    Infer the Bedrock provider from the model string.

    Handles four shapes:
    1. "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0" -> "anthropic"
    2. "anthropic.claude-3-5-sonnet-20240620-v1:0"        -> "anthropic"
    3. "llama/arn:aws:bedrock:...:imported-model/xyz"      -> "llama"
    4. "us.amazon.nova-pro-v1:0"                           -> "nova"
    """
    stripped = model[len("invoke/"):] if model.startswith("invoke/") else model

    known_providers = get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL)

    # "nova" must win before the dotted-prefix lookup, otherwise
    # "amazon.nova-*" would match the "amazon" (Titan) provider.
    if "nova" in stripped.lower() and "nova" in known_providers:
        return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, "nova")

    dotted_prefix = stripped.split(".")[0]
    if dotted_prefix in known_providers:
        return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, dotted_prefix)

    # "provider/model-name" style paths (e.g. imported-model ARNs).
    path_provider = AmazonInvokeConfig._get_provider_from_model_path(stripped)
    if path_provider is not None:
        return path_provider

    # Last resort: substring match anywhere in the model string.
    for candidate in known_providers:
        if candidate in stripped:
            return candidate
    return None
|
||||
|
||||
@staticmethod
def _get_provider_from_model_path(
    model_path: str,
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
    """
    Extract the provider from a model path of the form ``provider/model-name``.

    Args:
        model_path: e.g. 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n'
            or 'anthropic/model-name'

    Returns:
        The provider name, or None if the leading path segment is not a
        known Bedrock Invoke provider.
    """
    # str.split always yields at least one element, so the original
    # `len(parts) >= 1` guard was vacuous and has been removed.
    provider = model_path.split("/")[0]
    if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
        return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
    return None
|
||||
|
||||
def convert_messages_to_prompt(
    self, model, messages, provider, custom_prompt_dict
) -> Tuple[str, Optional[list]]:
    """
    Convert chat messages to a provider-appropriate prompt string.

    A model registered in `custom_prompt_dict` always takes precedence.
    anthropic/amazon/mistral/meta/llama share the bedrock prompt factory
    (the original repeated the identical call in three branches — merged).
    cohere additionally returns a chat history list; all other providers
    return (prompt, None).

    Note: the original fallback branch checked `"role" in message` and
    `role == "user"` but appended `message['content']` identically in every
    branch — dead branching, removed.
    """
    prompt = ""
    chat_history: Optional[list] = None
    ## CUSTOM PROMPT
    if model in custom_prompt_dict:
        # check if the model has a registered custom prompt
        model_prompt_details = custom_prompt_dict[model]
        prompt = custom_prompt(
            role_dict=model_prompt_details["roles"],
            initial_prompt_value=model_prompt_details.get(
                "initial_prompt_value", ""
            ),
            final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
            messages=messages,
        )
        return prompt, None
    ## ELSE
    if provider in ("anthropic", "amazon", "mistral", "meta", "llama"):
        prompt = prompt_factory(
            model=model, messages=messages, custom_llm_provider="bedrock"
        )
    elif provider == "cohere":
        prompt, chat_history = cohere_message_pt(messages=messages)
    elif provider == "deepseek_r1":
        prompt = deepseek_r1_pt(messages=messages)
    else:
        # Fallback: naive concatenation of all message contents.
        prompt = "".join(f"{message['content']}" for message in messages)
    return prompt, chat_history  # type: ignore
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,27 @@
|
||||
"""
|
||||
Helper util for handling bedrock-specific cost calculation
|
||||
- e.g.: prompt caching
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.utils import Usage
|
||||
|
||||
|
||||
def cost_per_token(
    model: str, usage: "Usage", service_tier: Optional[str] = None
) -> Tuple[float, float]:
    """
    Calculate the prompt/completion cost for a Bedrock model.

    Thin wrapper over the generic calculator, pinned to the ``bedrock``
    provider; follows the same logic as Anthropic's cost-per-token
    calculation.
    """
    costs: Tuple[float, float] = generic_cost_per_token(
        model=model,
        usage=usage,
        custom_llm_provider="bedrock",
        service_tier=service_tier,
    )
    return costs
|
||||
@@ -0,0 +1,115 @@
|
||||
"""
|
||||
Bedrock Token Counter implementation using the CountTokens API.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.base_llm.base_utils import BaseTokenCounter
|
||||
from litellm.llms.bedrock.common_utils import BedrockError, get_bedrock_base_model
|
||||
from litellm.llms.bedrock.count_tokens.handler import BedrockCountTokensHandler
|
||||
from litellm.types.utils import LlmProviders, TokenCountResponse
|
||||
|
||||
|
||||
class BedrockTokenCounter(BaseTokenCounter):
|
||||
"""Token counter implementation for AWS Bedrock provider using the CountTokens API."""
|
||||
|
||||
def should_use_token_counting_api(
    self,
    custom_llm_provider: Optional[str] = None,
) -> bool:
    """
    Whether the Bedrock CountTokens API should be used for token counting.

    True only when the request is routed to the Bedrock provider.
    """
    is_bedrock = custom_llm_provider == LlmProviders.BEDROCK.value
    return is_bedrock
|
||||
|
||||
async def count_tokens(
    self,
    model_to_use: str,
    messages: Optional[List[Dict[str, Any]]],
    contents: Optional[List[Dict[str, Any]]],
    deployment: Optional[Dict[str, Any]] = None,
    request_model: str = "",
    tools: Optional[List[Dict[str, Any]]] = None,
    system: Optional[Any] = None,
) -> Optional[TokenCountResponse]:
    """
    Count tokens using AWS Bedrock's CountTokens API.

    This method calls the existing BedrockCountTokensHandler to make an API call
    to Bedrock's token counting endpoint, bypassing the local tiktoken-based counting.

    Args:
        model_to_use: The model identifier
        messages: The messages to count tokens for
        contents: Alternative content format (not used for Bedrock)
        deployment: Deployment configuration containing litellm_params
        request_model: The original request model name
        tools: Optional tool definitions included in the count request
        system: Optional system prompt included in the count request

    Returns:
        TokenCountResponse with the token count; an error-flagged
        TokenCountResponse (total_tokens=0) if the API call fails; or None
        when there are no messages or the handler returns no result.
    """
    if not messages:
        return None

    deployment = deployment or {}
    litellm_params = deployment.get("litellm_params", {})

    # Build request data in the format expected by BedrockCountTokensHandler
    request_data: Dict[str, Any] = {
        "model": model_to_use,
        "messages": messages,
    }

    if tools:
        request_data["tools"] = tools

    if system:
        request_data["system"] = system

    # Get the resolved model (strip prefixes like bedrock/, converse/, etc.)
    resolved_model = get_bedrock_base_model(model_to_use)

    try:
        handler = BedrockCountTokensHandler()
        result = await handler.handle_count_tokens_request(
            request_data=request_data,
            litellm_params=litellm_params,
            resolved_model=resolved_model,
        )

        # Transform response to TokenCountResponse
        if result is not None:
            return TokenCountResponse(
                total_tokens=result.get("input_tokens", 0),
                request_model=request_model,
                model_used=model_to_use,
                tokenizer_type="bedrock_api",
                original_response=result,
            )
    except BedrockError as e:
        # API-level failure: report it but don't raise — callers treat an
        # error-flagged response as "counting unavailable".
        verbose_logger.warning(
            f"Bedrock CountTokens API error: status={e.status_code}, message={e.message}"
        )
        return TokenCountResponse(
            total_tokens=0,
            request_model=request_model,
            model_used=model_to_use,
            tokenizer_type="bedrock_api",
            error=True,
            error_message=e.message,
            status_code=e.status_code,
        )
    except Exception as e:
        verbose_logger.warning(f"Error calling Bedrock CountTokens API: {e}")
        return TokenCountResponse(
            total_tokens=0,
            request_model=request_model,
            model_used=model_to_use,
            tokenizer_type="bedrock_api",
            error=True,
            error_message=str(e),
            status_code=500,
        )

    # Reached only when the handler returned None without raising.
    return None
|
||||
@@ -0,0 +1,136 @@
|
||||
"""
|
||||
AWS Bedrock CountTokens API handler.
|
||||
|
||||
Simplified handler leveraging existing LiteLLM Bedrock infrastructure.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.llms.bedrock.count_tokens.transformation import BedrockCountTokensConfig
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
|
||||
|
||||
class BedrockCountTokensHandler(BedrockCountTokensConfig):
    """
    Simplified handler for AWS Bedrock CountTokens API requests.

    Builds on existing LiteLLM infrastructure (request signing, async HTTP
    client) rather than re-implementing authentication.
    """

    async def handle_count_tokens_request(
        self,
        request_data: Dict[str, Any],
        litellm_params: Dict[str, Any],
        resolved_model: str,
    ) -> Dict[str, Any]:
        """
        Handle a CountTokens request using existing LiteLLM patterns.

        Args:
            request_data: The incoming request payload
            litellm_params: LiteLLM configuration parameters
            resolved_model: The actual model ID resolved from router

        Returns:
            Dictionary containing token count response

        Raises:
            BedrockError: On validation, HTTP, or unexpected failures.
        """
        try:
            # Reject malformed payloads before doing any network work.
            self.validate_count_tokens_request(request_data)
            verbose_logger.debug(
                f"Processing CountTokens request for resolved model: {resolved_model}"
            )

            # Resolve the AWS region via the shared BaseAWSLLM helper.
            region = self._get_aws_region_name(
                optional_params=litellm_params,
                model=resolved_model,
                model_id=None,
            )
            verbose_logger.debug(f"Retrieved AWS region: {region}")

            # Convert the payload into Bedrock's CountTokens shape
            # (supports both Converse and InvokeModel input types).
            provider_payload = self.transform_anthropic_to_bedrock_count_tokens(
                request_data=request_data
            )
            verbose_logger.debug(f"Transformed request: {provider_payload}")

            target_url = self.get_bedrock_count_tokens_endpoint(
                resolved_model, region
            )
            verbose_logger.debug(f"Making request to: {target_url}")

            # Sign via the shared BaseAWSLLM._sign_request helper; an
            # explicit api_key enables bearer-token auth when provided.
            signed_headers, signed_body = self._sign_request(
                service_name="bedrock",
                headers={"Content-Type": "application/json"},
                optional_params=litellm_params,
                request_data=provider_payload,
                api_base=target_url,
                model=resolved_model,
                api_key=litellm_params.get("api_key", None),
            )

            client = get_async_httpx_client(
                llm_provider=litellm.LlmProviders.BEDROCK
            )
            response = await client.post(
                target_url,
                headers=signed_headers,
                data=signed_body,
                timeout=30.0,
            )
            verbose_logger.debug(f"Response status: {response.status_code}")

            if response.status_code != 200:
                error_text = response.text
                verbose_logger.error(f"AWS Bedrock error: {error_text}")
                raise BedrockError(
                    status_code=response.status_code,
                    message=error_text,
                )

            raw_response = response.json()
            verbose_logger.debug(f"Bedrock response: {raw_response}")

            # Map Bedrock's {"inputTokens": N} back to the expected format.
            final_response = self.transform_bedrock_response_to_anthropic(
                raw_response
            )
            verbose_logger.debug(f"Final response: {final_response}")
            return final_response

        except BedrockError:
            # Already normalized - propagate untouched.
            raise
        except httpx.HTTPStatusError as e:
            # Keep the real HTTP status instead of collapsing to 500.
            verbose_logger.error(f"HTTP error in CountTokens handler: {str(e)}")
            raise BedrockError(
                status_code=e.response.status_code,
                message=e.response.text,
            )
        except Exception as e:
            verbose_logger.error(f"Error in CountTokens handler: {str(e)}")
            raise BedrockError(
                status_code=500,
                message=f"CountTokens processing error: {str(e)}",
            )
|
||||
@@ -0,0 +1,262 @@
|
||||
"""
|
||||
AWS Bedrock CountTokens API transformation logic.
|
||||
|
||||
This module handles the transformation of requests from Anthropic Messages API format
|
||||
to AWS Bedrock's CountTokens API format and vice versa.
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.bedrock.common_utils import get_bedrock_base_model
|
||||
|
||||
|
||||
class BedrockCountTokensConfig(BaseAWSLLM):
    """
    Configuration and transformation logic for AWS Bedrock CountTokens API.

    AWS Bedrock CountTokens API Specification:
    - Endpoint: POST /model/{modelId}/count-tokens
    - Input formats: 'invokeModel' or 'converse'
    - Response: {"inputTokens": <number>}
    """

    def _detect_input_type(self, request_data: Dict[str, Any]) -> str:
        """
        Decide between the 'converse' and 'invokeModel' input formats.

        Anthropic-style payloads carry a list of messages; anything else
        (prompt-based or raw Bedrock bodies) is treated as invokeModel.
        """
        has_message_list = isinstance(request_data.get("messages"), list)
        return "converse" if has_message_list else "invokeModel"

    def transform_anthropic_to_bedrock_count_tokens(
        self,
        request_data: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Transform a request to Bedrock CountTokens format.

        Produces {"input": {"converse": {...}}} for message-based payloads
        and {"input": {"invokeModel": {"body": "..."}}} for everything else.
        """
        if self._detect_input_type(request_data) == "converse":
            return self._transform_to_converse_format(request_data)
        return self._transform_to_invoke_model_format(request_data)

    def _transform_to_converse_format(
        self, request_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Transform to Converse input format, including system and tools."""
        converse_messages = []
        for msg in request_data.get("messages", []):
            entry: Dict[str, Any] = {"role": msg.get("role"), "content": []}
            raw_content = msg.get("content", "")
            if isinstance(raw_content, str):
                # Bare strings become a single text block.
                entry["content"] = [{"text": raw_content}]
            elif isinstance(raw_content, list):
                # Block lists are passed through untouched.
                entry["content"] = raw_content
            converse_messages.append(entry)

        payload: Dict[str, Any] = {"messages": converse_messages}

        # System prompt (string or list of blocks -> Bedrock format).
        system_blocks = self._transform_system(request_data.get("system"))
        if system_blocks:
            payload["system"] = system_blocks

        # Tools (Anthropic format -> Bedrock toolConfig).
        tool_config = self._transform_tools(request_data.get("tools"))
        if tool_config:
            payload["toolConfig"] = tool_config

        return {"input": {"converse": payload}}

    def _transform_system(self, system: Optional[Any]) -> List[Dict[str, Any]]:
        """Transform an Anthropic system prompt to Bedrock system blocks."""
        if isinstance(system, str):
            return [{"text": system}]
        if isinstance(system, list):
            # Already block-shaped (e.g. [{"type": "text", "text": "..."}]);
            # keep only the text of each dict block.
            return [
                {"text": entry.get("text", "")}
                for entry in system
                if isinstance(entry, dict)
            ]
        # None or any unrecognized type contributes no system blocks.
        return []

    def _transform_tools(
        self, tools: Optional[List[Dict[str, Any]]]
    ) -> Optional[Dict[str, Any]]:
        """Transform Anthropic tools to Bedrock toolConfig format."""
        if not tools:
            return None

        specs = []
        for tool in tools:
            raw_name = tool.get("name", "")
            # Bedrock tool names must match [a-zA-Z][a-zA-Z0-9_]* (max 64).
            safe_name = re.sub(r"[^a-zA-Z0-9_]", "_", raw_name)
            if safe_name and not safe_name[0].isalpha():
                safe_name = "t_" + safe_name
            safe_name = safe_name[:64]

            specs.append(
                {
                    "toolSpec": {
                        "name": safe_name,
                        "description": tool.get("description") or safe_name,
                        "inputSchema": {
                            "json": tool.get(
                                "input_schema",
                                {"type": "object", "properties": {}},
                            )
                        },
                    }
                }
            )

        return {"tools": specs}

    def _transform_to_invoke_model_format(
        self, request_data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Transform to InvokeModel input format."""
        import json

        # The raw body is everything except 'model', serialized exactly as
        # it would be sent to the model itself.
        raw_body = {key: val for key, val in request_data.items() if key != "model"}
        return {"input": {"invokeModel": {"body": json.dumps(raw_body)}}}

    def get_bedrock_count_tokens_endpoint(
        self, model: str, aws_region_name: str
    ) -> str:
        """
        Build the Bedrock CountTokens endpoint URL for a model/region pair.

        Args:
            model: The resolved model ID from router lookup
            aws_region_name: AWS region (e.g., "eu-west-1")

        Returns:
            Complete endpoint URL for the CountTokens API
        """
        # Strip any region prefix, then a leading "bedrock/" if present.
        model_id = get_bedrock_base_model(model)
        prefix = "bedrock/"
        if model_id.startswith(prefix):
            model_id = model_id[len(prefix):]

        return (
            f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"
            f"/model/{model_id}/count-tokens"
        )

    def transform_bedrock_response_to_anthropic(
        self, bedrock_response: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Map Bedrock's {"inputTokens": N} to Anthropic's {"input_tokens": N}."""
        return {"input_tokens": bedrock_response.get("inputTokens", 0)}

    def validate_count_tokens_request(self, request_data: Dict[str, Any]) -> None:
        """
        Validate the incoming count tokens request.

        Supports both Converse and InvokeModel input formats.

        Args:
            request_data: The request payload

        Raises:
            ValueError: If the request is invalid
        """
        if not request_data.get("model"):
            raise ValueError("model parameter is required")

        if self._detect_input_type(request_data) == "converse":
            # Converse format: every message needs a role and content.
            messages = request_data.get("messages", [])
            if not messages:
                raise ValueError("messages parameter is required for Converse input")
            if not isinstance(messages, list):
                raise ValueError("messages must be a list")
            for i, message in enumerate(messages):
                if not isinstance(message, dict):
                    raise ValueError(f"Message {i} must be a dictionary")
                if "role" not in message:
                    raise ValueError(f"Message {i} must have a 'role' field")
                if "content" not in message:
                    raise ValueError(f"Message {i} must have a 'content' field")
        else:
            # InvokeModel bodies vary by model; just require some content
            # beyond the 'model' field itself.
            if len(request_data) <= 1:
                raise ValueError("Request must contain content to count tokens")
|
||||
@@ -0,0 +1,361 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Nova /invoke and /async-invoke format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Supports:
|
||||
- Synchronous embeddings (SINGLE_EMBEDDING)
|
||||
- Asynchronous embeddings with segmentation (SEGMENTED_EMBEDDING)
|
||||
- Multimodal inputs: text, image, video, audio
|
||||
- Multiple embedding purposes and dimensions
|
||||
|
||||
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/nova-embed.html
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.types.utils import (
|
||||
Embedding,
|
||||
EmbeddingResponse,
|
||||
PromptTokensDetailsWrapper,
|
||||
Usage,
|
||||
)
|
||||
|
||||
|
||||
class AmazonNovaEmbeddingConfig:
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/nova-embed.html

    Amazon Nova Multimodal Embeddings supports:
    - Text, image, video, and audio inputs
    - Synchronous (InvokeModel) and asynchronous (StartAsyncInvoke) APIs
    - Multiple embedding purposes and dimensions
    """

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self) -> List[str]:
        return [
            "dimensions",
        ]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        """Map OpenAI-style parameters to Nova parameters."""
        supported = self.get_supported_openai_params()
        for key, value in non_default_params.items():
            if key == "dimensions":
                # OpenAI's `dimensions` becomes Nova's `embedding_dimension`.
                optional_params["embedding_dimension"] = value
            elif key in supported:
                optional_params[key] = value
        return optional_params

    def _parse_data_url(self, data_url: str) -> tuple:
        """
        Split a data URL into its media type and base64 payload.

        Args:
            data_url: Data URL in format: data:image/jpeg;base64,/9j/4AAQ...

        Returns:
            tuple: (media_type, base64_data)
                media_type: e.g., "image/jpeg", "video/mp4", "audio/mpeg"
                base64_data: the base64-encoded data without the prefix

        Raises:
            ValueError: If the string is not a well-formed data URL.
        """
        if not data_url.startswith("data:"):
            raise ValueError(f"Invalid data URL format: {data_url[:50]}...")

        # Expected shape: data:<media-type>[;base64],<payload>
        if "," not in data_url:
            raise ValueError(
                f"Invalid data URL format (missing comma): {data_url[:50]}..."
            )

        header, payload = data_url.split(",", 1)
        header = header[5:]  # drop the leading 'data:'
        media_type = header.split(";")[0] if ";" in header else header
        return media_type, payload

    def _transform_request(
        self,
        input: str,
        inference_params: dict,
        async_invoke_route: bool = False,
        model_id: Optional[str] = None,
        output_s3_uri: Optional[str] = None,
    ) -> dict:
        """
        Transform OpenAI-style input to Nova format.

        Only OpenAI params (dimensions) get special handling; every other
        Nova-specific param in inference_params is passed through as-is.

        Args:
            input: The input text or media reference
            inference_params: Additional parameters (passed through)
            async_invoke_route: Whether this is for async invoke
            model_id: Model ID (for async invoke)
            output_s3_uri: S3 URI for output (for async invoke)

        Returns:
            dict: Nova embedding request
        """
        # Async jobs use segmentation; sync calls embed a single item.
        if async_invoke_route:
            task_type = "SEGMENTED_EMBEDDING"
        else:
            task_type = "SINGLE_EMBEDDING"

        params = inference_params.copy()
        params.pop("output_s3_uri", None)  # routing detail, not a model param

        # Normalize the dimension parameter to Nova's camelCase name.
        if "dimensions" in params:
            params["embeddingDimension"] = params.pop("dimensions")
        elif "embedding_dimension" in params:
            params["embeddingDimension"] = params.pop("embedding_dimension")

        # Nova requires both of these; fill in defaults when absent.
        params.setdefault("embeddingPurpose", "GENERIC_INDEX")
        params.setdefault("embeddingDimension", 3072)

        # Infer the input modality only when the caller didn't already
        # supply one of text/image/video/audio explicitly.
        if not any(k in params for k in ("text", "image", "video", "audio")):
            if input.startswith("data:"):
                # Data URL: route by the declared MIME type.
                media_type, b64_payload = self._parse_data_url(input)
                if media_type.startswith("image/"):
                    fmt = media_type.split("/")[1].lower()
                    if fmt == "jpg":
                        fmt = "jpeg"  # Nova expects "jpeg", not "jpg"
                    params["image"] = {
                        "format": fmt,
                        "source": {"bytes": b64_payload},
                    }
                elif media_type.startswith("video/"):
                    params["video"] = {
                        "format": media_type.split("/")[1].lower(),
                        "source": {"bytes": b64_payload},
                    }
                elif media_type.startswith("audio/"):
                    params["audio"] = {
                        "format": media_type.split("/")[1].lower(),
                        "source": {"bytes": b64_payload},
                    }
                else:
                    # Unknown media type: fall back to treating it as text.
                    params["text"] = {"value": input, "truncationMode": "END"}
            elif input.startswith("s3://"):
                # S3 URL - default to text; caller should specify modality.
                params["text"] = {
                    "source": {"s3Location": {"uri": input}},
                    "truncationMode": "END",  # Required by Nova API
                }
            else:
                # Plain text input.
                params["text"] = {
                    "value": input,
                    "truncationMode": "END",  # Required by Nova API
                }

        request: dict = {
            "schemaVersion": "nova-multimodal-embed-v1",
            "taskType": task_type,
        }
        if task_type == "SINGLE_EMBEDDING":
            request["singleEmbeddingParams"] = params
        else:
            request["segmentedEmbeddingParams"] = params

        # Async invoke wraps the model input in a job envelope.
        if async_invoke_route and model_id:
            return self._wrap_async_invoke_request(
                model_input=request,
                model_id=model_id,
                output_s3_uri=output_s3_uri,
            )

        return request

    def _wrap_async_invoke_request(
        self,
        model_input: dict,
        model_id: str,
        output_s3_uri: Optional[str] = None,
    ) -> dict:
        """
        Wrap the transformed request in the AWS Bedrock async invoke format.

        Args:
            model_input: The transformed Nova embedding request
            model_id: The model identifier (without async_invoke prefix)
            output_s3_uri: S3 URI for output data config

        Returns:
            dict: The wrapped async invoke request

        Raises:
            ValueError: If output_s3_uri is missing or blank.
        """
        import urllib.parse

        # Normalize the model ID (URL-decode, drop the routing prefix).
        clean_model_id = urllib.parse.unquote(model_id)
        if clean_model_id.startswith("async_invoke/"):
            clean_model_id = clean_model_id.replace("async_invoke/", "")

        # Async invoke cannot run without a destination bucket.
        if not output_s3_uri or output_s3_uri.strip() == "":
            raise ValueError("output_s3_uri is required for async invoke requests")

        return {
            "modelId": clean_model_id,
            "modelInput": model_input,
            "outputDataConfig": {"s3OutputDataConfig": {"s3Uri": output_s3_uri}},
        }

    def _transform_response(
        self,
        response_list: List[dict],
        model: str,
        batch_data: Optional[List[dict]] = None,
    ) -> EmbeddingResponse:
        """
        Transform Nova responses to OpenAI format.

        Nova response format:
        {
            "embeddings": [
                {
                    "embeddingType": "TEXT" | "IMAGE" | "VIDEO" | "AUDIO" | "AUDIO_VIDEO_COMBINED",
                    "embedding": [0.1, 0.2, ...],
                    "truncatedCharLength": 100  # Optional, only for text
                }
            ]
        }
        """
        out: List[Embedding] = []
        token_estimate = 0

        for resp in response_list:
            nested = resp.get("embeddings")
            if isinstance(nested, list):
                for item in nested:
                    if "embedding" not in item:
                        continue
                    out.append(
                        Embedding(
                            embedding=item["embedding"],
                            index=len(out),
                            object="embedding",
                        )
                    )
                    # Rough token estimate: prefer truncatedCharLength
                    # (text only), otherwise fall back to embedding length.
                    if "truncatedCharLength" in item:
                        token_estimate += item["truncatedCharLength"] // 4
                    else:
                        token_estimate += len(item["embedding"]) // 4
            elif "embedding" in resp:
                # Direct embedding response (fallback).
                out.append(
                    Embedding(
                        embedding=resp["embedding"],
                        index=len(out),
                        object="embedding",
                    )
                )
                token_estimate += len(resp["embedding"]) // 4

        # Count image inputs from the original requests for cost tracking.
        # Nova wraps params in singleEmbeddingParams / segmentedEmbeddingParams.
        image_count = 0
        for request_data in batch_data or []:
            params = request_data.get(
                "singleEmbeddingParams",
                request_data.get("segmentedEmbeddingParams", {}),
            )
            if "image" in params:
                image_count += 1

        prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
        if image_count > 0:
            prompt_tokens_details = PromptTokensDetailsWrapper(
                image_count=image_count,
            )

        usage = Usage(
            prompt_tokens=token_estimate,
            total_tokens=token_estimate,
            prompt_tokens_details=prompt_tokens_details,
        )

        return EmbeddingResponse(data=out, model=model, usage=usage)

    def _transform_async_invoke_response(
        self, response: dict, model: str
    ) -> EmbeddingResponse:
        """
        Transform an async invoke response (invocation ARN) to OpenAI format.

        AWS async invoke returns:
        {
            "invocationArn": "arn:aws:bedrock:us-east-1:123456789012:async-invoke/abc123"
        }

        The ARN is stashed in hidden params on a placeholder job response.
        """
        from litellm.types.llms.base import HiddenParams

        hidden_params = HiddenParams()
        setattr(hidden_params, "_invocation_arn", response.get("invocationArn", ""))

        # Async jobs have no embedding data or usage yet - placeholders only.
        placeholder = Embedding(
            embedding=[],
            index=0,
            object="embedding",
        )

        return EmbeddingResponse(
            data=[placeholder],
            model=model,
            usage=Usage(prompt_tokens=0, total_tokens=0),
            hidden_params=hidden_params,
        )
|
||||
@@ -0,0 +1,88 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan G1 /invoke format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Covers
|
||||
- G1 request format
|
||||
|
||||
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
AmazonTitanG1EmbeddingRequest,
|
||||
AmazonTitanG1EmbeddingResponse,
|
||||
)
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||
|
||||
|
||||
class AmazonTitanG1Config:
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
    """

    def __init__(
        self,
    ) -> None:
        # Mirror any constructor arguments onto the class (config pattern
        # shared by the other Amazon embedding configs; effectively a no-op
        # for G1, which has no tunable parameters).
        captured = locals().copy()
        for key, value in captured.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        # Expose class-level config values, skipping dunders, callables,
        # descriptors, and unset (None) entries.
        skip_types = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        config = {}
        for attr, value in cls.__dict__.items():
            if attr.startswith("__"):
                continue
            if isinstance(value, skip_types):
                continue
            if value is None:
                continue
            config[attr] = value
        return config

    def get_supported_openai_params(self) -> List[str]:
        # Titan G1 accepts no OpenAI-compatible tuning parameters.
        return []

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        # Nothing to map for G1; pass optional params straight through.
        return optional_params

    def _transform_request(
        self, input: str, inference_params: dict
    ) -> AmazonTitanG1EmbeddingRequest:
        # G1 takes only the raw input text; inference_params are unused.
        return AmazonTitanG1EmbeddingRequest(inputText=input)

    def _transform_response(
        self, response_list: List[dict], model: str
    ) -> EmbeddingResponse:
        # Convert each Titan G1 response into an OpenAI-style Embedding,
        # summing the provider-reported input token counts for usage.
        embeddings: List[Embedding] = []
        prompt_tokens = 0
        for idx, raw in enumerate(response_list):
            parsed = AmazonTitanG1EmbeddingResponse(**raw)  # type: ignore
            embeddings.append(
                Embedding(
                    embedding=parsed["embedding"],
                    index=idx,
                    object="embedding",
                )
            )
            prompt_tokens += parsed["inputTextTokenCount"]

        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=0,
            total_tokens=prompt_tokens,
        )
        return EmbeddingResponse(model=model, usage=usage, data=embeddings)
|
||||
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan multimodal /invoke format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-mm.html
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
AmazonTitanMultimodalEmbeddingConfig,
|
||||
AmazonTitanMultimodalEmbeddingRequest,
|
||||
AmazonTitanMultimodalEmbeddingResponse,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
Embedding,
|
||||
EmbeddingResponse,
|
||||
PromptTokensDetailsWrapper,
|
||||
Usage,
|
||||
)
|
||||
from litellm.utils import get_base64_str, is_base64_encoded
|
||||
|
||||
|
||||
class AmazonTitanMultimodalEmbeddingG1Config:
    """
    Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-mm.html
    """

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self) -> List[str]:
        return ["dimensions"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        # OpenAI's `dimensions` maps onto Titan's embeddingConfig block.
        for key, value in non_default_params.items():
            if key == "dimensions":
                optional_params["embeddingConfig"] = (
                    AmazonTitanMultimodalEmbeddingConfig(outputEmbeddingLength=value)
                )
        return optional_params

    def _transform_request(
        self, input: str, inference_params: dict
    ) -> AmazonTitanMultimodalEmbeddingRequest:
        # Base64-encoded inputs are treated as images, everything else as text.
        if is_base64_encoded(input):
            request = AmazonTitanMultimodalEmbeddingRequest(
                inputImage=get_base64_str(input)
            )
        else:
            request = AmazonTitanMultimodalEmbeddingRequest(inputText=input)

        # Pass any extra inference parameters straight through.
        for key, value in inference_params.items():
            request[key] = value  # type: ignore
        return request

    def _transform_response(
        self,
        response_list: List[dict],
        model: str,
        batch_data: Optional[List[dict]] = None,
    ) -> EmbeddingResponse:
        # Convert each Titan response into an OpenAI-style Embedding and
        # accumulate the provider-reported input token counts.
        embeddings: List[Embedding] = []
        prompt_tokens = 0
        for idx, raw in enumerate(response_list):
            parsed = AmazonTitanMultimodalEmbeddingResponse(**raw)  # type: ignore
            embeddings.append(
                Embedding(
                    embedding=parsed["embedding"],
                    index=idx,
                    object="embedding",
                )
            )
            prompt_tokens += parsed["inputTextTokenCount"]

        # Count image inputs from the original requests for cost tracking.
        image_count = sum(
            1 for request_data in (batch_data or []) if "inputImage" in request_data
        )

        prompt_tokens_details: Optional[PromptTokensDetailsWrapper] = None
        if image_count > 0:
            prompt_tokens_details = PromptTokensDetailsWrapper(
                image_count=image_count,
            )

        usage = Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=0,
            total_tokens=prompt_tokens,
            prompt_tokens_details=prompt_tokens_details,
        )
        return EmbeddingResponse(model=model, usage=usage, data=embeddings)
|
||||
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Amazon Titan V2 /invoke format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Covers
|
||||
- v2 request format
|
||||
|
||||
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
AmazonTitanV2EmbeddingRequest,
|
||||
AmazonTitanV2EmbeddingResponse,
|
||||
)
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||
|
||||
|
||||
class AmazonTitanV2Config:
    """
    Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html

    normalize: boolean - flag indicating whether or not to normalize the output embeddings. Defaults to true
    dimensions: int - The number of dimensions the output embeddings should have. The following values are accepted: 1024 (default), 512, 256.
    """

    normalize: Optional[bool] = None
    dimensions: Optional[int] = None

    def __init__(
        self, normalize: Optional[bool] = None, dimensions: Optional[int] = None
    ) -> None:
        # litellm config convention: every non-None constructor argument is
        # mirrored onto the class so get_config() can pick it up.
        for name, value in locals().copy().items():
            if name != "self" and value is not None:
                setattr(self.__class__, name, value)

    @classmethod
    def get_config(cls):
        """Return the non-callable, non-dunder, non-None class attributes as a dict."""
        skipped_kinds = (
            types.FunctionType,
            types.BuiltinFunctionType,
            classmethod,
            staticmethod,
        )
        return {
            name: attr
            for name, attr in cls.__dict__.items()
            if not name.startswith("__")
            and not isinstance(attr, skipped_kinds)
            and attr is not None
        }

    def get_supported_openai_params(self) -> List[str]:
        """OpenAI embedding params this config can translate."""
        return ["dimensions", "encoding_format"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        """Translate OpenAI-style params into Titan V2 request params (mutates and returns optional_params)."""
        for param, value in non_default_params.items():
            if param == "dimensions":
                optional_params["dimensions"] = value
            elif param == "encoding_format":
                # OpenAI "base64" maps to AWS's binary embedding type; every
                # other encoding format (including "float") falls back to float.
                embedding_kind = "binary" if value == "base64" else "float"
                optional_params["embeddingTypes"] = [embedding_kind]
        return optional_params

    def _transform_request(
        self, input: str, inference_params: dict
    ) -> AmazonTitanV2EmbeddingRequest:
        """Build the Titan V2 /invoke request body from the input text and params."""
        return AmazonTitanV2EmbeddingRequest(inputText=input, **inference_params)  # type: ignore

    def _transform_response(
        self, response_list: List[dict], model: str
    ) -> EmbeddingResponse:
        """Convert raw Titan V2 /invoke responses into an OpenAI-style EmbeddingResponse."""
        embeddings: List[Embedding] = []
        prompt_token_sum = 0

        for idx, raw in enumerate(response_list):
            parsed = AmazonTitanV2EmbeddingResponse(**raw)  # type: ignore

            # Per AWS docs, embeddingsByType is always present. Prefer binary
            # data (requested via encoding_format="base64"), then float data,
            # then fall back to the legacy top-level "embedding" field.
            vector: Union[List[float], List[int]]
            has_by_type = "embeddingsByType" in parsed
            if has_by_type and "binary" in parsed["embeddingsByType"]:
                vector = parsed["embeddingsByType"]["binary"]
            elif has_by_type and "float" in parsed["embeddingsByType"]:
                vector = parsed["embeddingsByType"]["float"]
            elif "embedding" in parsed:
                vector = parsed["embedding"]
            else:
                raise ValueError(f"No embedding data found in response: {raw}")

            embeddings.append(
                Embedding(
                    embedding=vector,
                    index=idx,
                    object="embedding",
                )
            )
            prompt_token_sum += parsed["inputTextTokenCount"]

        usage = Usage(
            prompt_tokens=prompt_token_sum,
            completion_tokens=0,
            total_tokens=prompt_token_sum,
        )
        return EmbeddingResponse(model=model, usage=usage, data=embeddings)
|
||||
@@ -0,0 +1,47 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock Cohere /invoke format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
|
||||
from litellm.llms.cohere.embed.transformation import CohereEmbeddingConfig
|
||||
from litellm.types.llms.bedrock import CohereEmbeddingRequest
|
||||
|
||||
|
||||
class BedrockCohereEmbeddingConfig:
    """Maps OpenAI /v1/embeddings params onto Bedrock's Cohere /invoke request format."""

    def __init__(self) -> None:
        pass

    def get_supported_openai_params(self) -> List[str]:
        """OpenAI embedding params this config can translate."""
        return ["encoding_format", "dimensions"]

    def map_openai_params(
        self, non_default_params: dict, optional_params: dict
    ) -> dict:
        """Translate OpenAI-style params into Bedrock Cohere params (mutates and returns optional_params)."""
        rename = {
            "encoding_format": "embedding_types",
            "dimensions": "output_dimension",
        }
        for param, value in non_default_params.items():
            target = rename.get(param)
            if target is not None:
                optional_params[target] = value
        return optional_params

    def _is_v3_model(self, model: str) -> bool:
        # Heuristic: Cohere embed v3 model ids contain a "3".
        return "3" in model

    def _transform_request(
        self, model: str, input: List[str], inference_params: dict
    ) -> CohereEmbeddingRequest:
        """Build a Bedrock Cohere request, keeping only the fields Bedrock's schema declares."""
        # Start from the generic (native) Cohere transformation...
        base_request = CohereEmbeddingConfig()._transform_request(
            model, input, inference_params
        )

        # ...then copy across only the keys declared on CohereEmbeddingRequest,
        # since Bedrock accepts a subset of the native Cohere API's fields.
        bedrock_request = CohereEmbeddingRequest(
            input_type=base_request["input_type"],
        )
        for field in CohereEmbeddingRequest.__annotations__.keys():
            if field in base_request:
                bedrock_request[field] = base_request[field]  # type: ignore

        return bedrock_request
|
||||
@@ -0,0 +1,699 @@
|
||||
"""
|
||||
Handles embedding calls to Bedrock's `/invoke` endpoint
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import urllib.parse
|
||||
from typing import Any, Callable, List, Optional, Tuple, Union, get_args
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.constants import BEDROCK_EMBEDDING_PROVIDERS_LITERAL
|
||||
from litellm.llms.cohere.embed.handler import embedding as cohere_embedding
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.llms.bedrock import (
|
||||
AmazonEmbeddingRequest,
|
||||
CohereEmbeddingRequest,
|
||||
)
|
||||
from litellm.types.utils import EmbeddingResponse, LlmProviders
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import BedrockError
|
||||
from .amazon_nova_transformation import AmazonNovaEmbeddingConfig
|
||||
from .amazon_titan_g1_transformation import AmazonTitanG1Config
|
||||
from .amazon_titan_multimodal_transformation import (
|
||||
AmazonTitanMultimodalEmbeddingG1Config,
|
||||
)
|
||||
from .amazon_titan_v2_transformation import AmazonTitanV2Config
|
||||
from .cohere_transformation import BedrockCohereEmbeddingConfig
|
||||
from .twelvelabs_marengo_transformation import TwelveLabsMarengoEmbeddingConfig
|
||||
|
||||
|
||||
class BedrockEmbedding(BaseAWSLLM):
    """Handles embedding calls to Bedrock's `/invoke` and `/async-invoke` endpoints.

    Per-model/provider request & response transformation is delegated to the
    *Config classes imported above; this class owns credential loading,
    SigV4 request signing (via BaseAWSLLM helpers), HTTP dispatch, and logging.
    """

    def _load_credentials(
        self,
        optional_params: dict,
    ) -> Tuple[Any, str]:
        """Pop AWS auth params out of optional_params and resolve botocore credentials.

        Returns (credentials, aws_region_name). Region resolution order:
        explicit param > AWS_REGION_NAME secret/env > AWS_REGION > "us-west-2".
        NOTE: mutates optional_params (pops all aws_* auth keys).
        """
        try:
            from botocore.credentials import Credentials
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
        ## CREDENTIALS ##
        # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
        aws_secret_access_key = optional_params.pop("aws_secret_access_key", None)
        aws_access_key_id = optional_params.pop("aws_access_key_id", None)
        aws_session_token = optional_params.pop("aws_session_token", None)
        aws_region_name = optional_params.pop("aws_region_name", None)
        aws_role_name = optional_params.pop("aws_role_name", None)
        aws_session_name = optional_params.pop("aws_session_name", None)
        aws_profile_name = optional_params.pop("aws_profile_name", None)
        aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
        aws_sts_endpoint = optional_params.pop("aws_sts_endpoint", None)

        ### SET REGION NAME ###
        if aws_region_name is None:
            # check env #
            litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)

            if litellm_aws_region_name is not None and isinstance(
                litellm_aws_region_name, str
            ):
                aws_region_name = litellm_aws_region_name

            # AWS_REGION (the standard AWS SDK variable) takes precedence over
            # AWS_REGION_NAME when both are set, since it is checked second.
            standard_aws_region_name = get_secret("AWS_REGION", None)
            if standard_aws_region_name is not None and isinstance(
                standard_aws_region_name, str
            ):
                aws_region_name = standard_aws_region_name

            if aws_region_name is None:
                aws_region_name = "us-west-2"

        credentials: Credentials = self.get_credentials(  # type: ignore
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            aws_session_token=aws_session_token,
            aws_region_name=aws_region_name,
            aws_session_name=aws_session_name,
            aws_profile_name=aws_profile_name,
            aws_role_name=aws_role_name,
            aws_web_identity_token=aws_web_identity_token,
            aws_sts_endpoint=aws_sts_endpoint,
        )
        return credentials, aws_region_name

    async def async_embeddings(self):
        # NOTE(review): unused placeholder — async flow goes through
        # _async_single_func_embeddings instead.
        pass

    def _make_sync_call(
        self,
        client: Optional[HTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        api_base: str,
        headers: dict,
        data: dict,
    ) -> dict:
        """POST `data` as JSON to `api_base` synchronously; return the parsed JSON body.

        Falls back to a fresh HTTPHandler when no (suitable) client is passed.
        Raises BedrockError on HTTP error status or timeout.
        """
        if client is None or not isinstance(client, HTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = _get_httpx_client(_params)  # type: ignore
        else:
            client = client
        try:
            response = client.post(url=api_base, headers=headers, data=json.dumps(data))  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        return response.json()

    async def _make_async_call(
        self,
        client: Optional[AsyncHTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        api_base: str,
        headers: dict,
        data: dict,
    ) -> dict:
        """Async counterpart of _make_sync_call; same error semantics."""
        if client is None or not isinstance(client, AsyncHTTPHandler):
            _params = {}
            if timeout is not None:
                if isinstance(timeout, float) or isinstance(timeout, int):
                    timeout = httpx.Timeout(timeout)
                _params["timeout"] = timeout
            client = get_async_httpx_client(
                params=_params, llm_provider=litellm.LlmProviders.BEDROCK
            )
        else:
            client = client

        try:
            response = await client.post(url=api_base, headers=headers, data=json.dumps(data))  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        return response.json()

    def _transform_response(
        self,
        response_list: List[dict],
        model: str,
        provider: BEDROCK_EMBEDDING_PROVIDERS_LITERAL,
        is_async_invoke: Optional[bool] = False,
        batch_data: Optional[List[dict]] = None,
    ) -> Optional[EmbeddingResponse]:
        """
        Transforms the response from the Bedrock embedding provider to the OpenAI format.

        Async-invoke responses (a single dict containing "invocationArn") are
        routed to provider-specific async transformers, or wrapped in a
        placeholder EmbeddingResponse carrying the ARN in hidden params.
        Regular invoke responses are dispatched by model id / provider.
        Raises if no transformer matches.
        """
        returned_response: Optional[EmbeddingResponse] = None

        # Handle async invoke responses (single response with invocationArn)
        if (
            is_async_invoke
            and len(response_list) == 1
            and "invocationArn" in response_list[0]
        ):
            if provider == "twelvelabs":
                returned_response = (
                    TwelveLabsMarengoEmbeddingConfig()._transform_async_invoke_response(
                        response=response_list[0], model=model
                    )
                )
            elif provider == "nova":
                returned_response = (
                    AmazonNovaEmbeddingConfig()._transform_async_invoke_response(
                        response=response_list[0], model=model
                    )
                )
            else:
                # For other providers, create a generic async response
                invocation_arn = response_list[0].get("invocationArn", "")

                from litellm.types.utils import Embedding, Usage

                # Placeholder embedding: the real vectors land in S3 later.
                embedding = Embedding(
                    embedding=[],
                    index=0,
                    object="embedding",  # Must be literal "embedding"
                )
                usage = Usage(prompt_tokens=0, total_tokens=0)

                # Create hidden params with job ID
                from litellm.types.llms.base import HiddenParams

                hidden_params = HiddenParams()
                setattr(hidden_params, "_invocation_arn", invocation_arn)

                returned_response = EmbeddingResponse(
                    data=[embedding],
                    model=model,
                    usage=usage,
                    hidden_params=hidden_params,
                )
        else:
            # Handle regular invoke responses
            if model == "amazon.titan-embed-image-v1":
                returned_response = (
                    AmazonTitanMultimodalEmbeddingG1Config()._transform_response(
                        response_list=response_list, model=model, batch_data=batch_data
                    )
                )
            elif model == "amazon.titan-embed-text-v1":
                returned_response = AmazonTitanG1Config()._transform_response(
                    response_list=response_list, model=model
                )
            elif model == "amazon.titan-embed-text-v2:0":
                returned_response = AmazonTitanV2Config()._transform_response(
                    response_list=response_list, model=model
                )
            elif provider == "twelvelabs":
                returned_response = (
                    TwelveLabsMarengoEmbeddingConfig()._transform_response(
                        response_list=response_list, model=model
                    )
                )
            elif provider == "nova":
                returned_response = AmazonNovaEmbeddingConfig()._transform_response(
                    response_list=response_list, model=model, batch_data=batch_data
                )

        ##########################################################
        # Validate returned response
        ##########################################################
        if returned_response is None:
            raise Exception(
                "Unable to map model response to known provider format. model={}".format(
                    model
                )
            )
        return returned_response

    def _single_func_embeddings(
        self,
        client: Optional[HTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        batch_data: List[dict],
        credentials: Any,
        extra_headers: Optional[dict],
        endpoint_url: str,
        aws_region_name: str,
        model: str,
        logging_obj: Any,
        provider: BEDROCK_EMBEDDING_PROVIDERS_LITERAL,
        api_key: Optional[str] = None,
        is_async_invoke: Optional[bool] = False,
    ):
        """Sign and POST each request in `batch_data` sequentially (sync), then transform.

        Each item is SigV4-signed individually because the signature covers the
        exact request body. Results are aggregated via _transform_response.
        """
        responses: List[dict] = []
        for data in batch_data:
            headers = {"Content-Type": "application/json"}
            if extra_headers is not None:
                headers = {"Content-Type": "application/json", **extra_headers}

            prepped = self.get_request_headers(  # type: ignore
                credentials=credentials,
                aws_region_name=aws_region_name,
                extra_headers=extra_headers,
                endpoint_url=endpoint_url,
                data=json.dumps(data),
                headers=headers,
                api_key=api_key,
            )

            ## LOGGING
            logging_obj.pre_call(
                input=data,
                api_key="",
                additional_args={
                    "complete_input_dict": data,
                    "api_base": prepped.url,
                    "headers": prepped.headers,
                },
            )
            # Convert CaseInsensitiveDict to a plain dict for httpx.
            headers_for_request = (
                dict(prepped.headers) if hasattr(prepped, "headers") else {}
            )
            response = self._make_sync_call(
                client=client,
                timeout=timeout,
                api_base=prepped.url,
                headers=headers_for_request,
                data=data,
            )

            ## LOGGING
            logging_obj.post_call(
                input=data,
                api_key="",
                original_response=response,
                additional_args={"complete_input_dict": data},
            )

            responses.append(response)

        return self._transform_response(
            response_list=responses,
            model=model,
            provider=provider,
            is_async_invoke=is_async_invoke,
            batch_data=batch_data,
        )

    async def _async_single_func_embeddings(
        self,
        client: Optional[AsyncHTTPHandler],
        timeout: Optional[Union[float, httpx.Timeout]],
        batch_data: List[dict],
        credentials: Any,
        extra_headers: Optional[dict],
        endpoint_url: str,
        aws_region_name: str,
        model: str,
        logging_obj: Any,
        provider: BEDROCK_EMBEDDING_PROVIDERS_LITERAL,
        api_key: Optional[str] = None,
        is_async_invoke: Optional[bool] = False,
    ):
        """Async counterpart of _single_func_embeddings (requests still issued sequentially)."""
        responses: List[dict] = []
        for data in batch_data:
            headers = {"Content-Type": "application/json"}
            if extra_headers is not None:
                headers = {"Content-Type": "application/json", **extra_headers}

            prepped = self.get_request_headers(  # type: ignore
                credentials=credentials,
                aws_region_name=aws_region_name,
                extra_headers=extra_headers,
                endpoint_url=endpoint_url,
                data=json.dumps(data),
                headers=headers,
                api_key=api_key,
            )

            ## LOGGING
            logging_obj.pre_call(
                input=data,
                api_key="",
                additional_args={
                    "complete_input_dict": data,
                    "api_base": prepped.url,
                    "headers": prepped.headers,
                },
            )
            # Convert CaseInsensitiveDict to regular dict for httpx compatibility
            # This ensures custom headers are properly forwarded, especially with IAM roles and custom api_base
            headers_for_request = (
                dict(prepped.headers) if hasattr(prepped, "headers") else {}
            )
            response = await self._make_async_call(
                client=client,
                timeout=timeout,
                api_base=prepped.url,
                headers=headers_for_request,
                data=data,
            )

            ## LOGGING
            logging_obj.post_call(
                input=data,
                api_key="",
                original_response=response,
                additional_args={"complete_input_dict": data},
            )

            responses.append(response)
        ## TRANSFORM RESPONSE ##
        return self._transform_response(
            response_list=responses,
            model=model,
            provider=provider,
            is_async_invoke=is_async_invoke,
            batch_data=batch_data,
        )

    def embeddings(  # noqa: PLR0915
        self,
        model: str,
        input: List[str],
        api_base: Optional[str],
        model_response: EmbeddingResponse,
        print_verbose: Callable,
        encoding,
        logging_obj,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]],
        timeout: Optional[Union[float, httpx.Timeout]],
        aembedding: Optional[bool],
        extra_headers: Optional[dict],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
    ) -> EmbeddingResponse:
        """Entry point: transform inputs per provider, build the endpoint URL, and dispatch.

        Providers that use per-input requests (amazon titan, twelvelabs, nova)
        go through _single_func_embeddings / _async_single_func_embeddings;
        cohere is delegated wholesale to cohere_embedding. Models prefixed
        with "async_invoke/" are routed to the /async-invoke endpoint.
        """
        credentials, aws_region_name = self._load_credentials(optional_params)

        ### TRANSFORMATION ###
        unencoded_model_id = (
            optional_params.pop("model_id", None) or model
        )  # default to model if not passed
        # URL-encode the model id since it is interpolated into the path.
        modelId = urllib.parse.quote(unencoded_model_id, safe="")
        aws_region_name = self._get_aws_region_name(
            optional_params={"aws_region_name": aws_region_name},
            model=model,
            model_id=unencoded_model_id,
        )
        # Check async invoke needs to be used
        has_async_invoke = "async_invoke/" in model
        if has_async_invoke:
            model = model.replace("async_invoke/", "", 1)
        provider = self.get_bedrock_embedding_provider(model)
        if provider is None:
            raise Exception(
                f"Unable to determine bedrock embedding provider for model: {model}. "
                f"Supported providers: {list(get_args(BEDROCK_EMBEDDING_PROVIDERS_LITERAL))}"
            )
        # Deep-copy so transformations can't mutate the caller's params, then
        # strip auth params and OpenAI's "user" (not accepted by Bedrock).
        inference_params = copy.deepcopy(optional_params)
        inference_params = {
            k: v
            for k, v in inference_params.items()
            if k.lower() not in self.aws_authentication_params
        }
        inference_params.pop(
            "user", None
        )  # make sure user is not passed in for bedrock call

        data: Optional[CohereEmbeddingRequest] = None
        batch_data: Optional[List] = None
        if provider == "cohere":
            data = BedrockCohereEmbeddingConfig()._transform_request(
                model=model, input=input, inference_params=inference_params
            )
        elif provider == "amazon" and model in [
            "amazon.titan-embed-image-v1",
            "amazon.titan-embed-text-v1",
            "amazon.titan-embed-text-v2:0",
        ]:
            batch_data = []
            for i in input:
                if model == "amazon.titan-embed-image-v1":
                    transformed_request: (
                        AmazonEmbeddingRequest
                    ) = AmazonTitanMultimodalEmbeddingG1Config()._transform_request(
                        input=i, inference_params=inference_params
                    )
                elif model == "amazon.titan-embed-text-v1":
                    transformed_request = AmazonTitanG1Config()._transform_request(
                        input=i, inference_params=inference_params
                    )
                elif model == "amazon.titan-embed-text-v2:0":
                    transformed_request = AmazonTitanV2Config()._transform_request(
                        input=i, inference_params=inference_params
                    )
                else:
                    raise Exception(
                        "Unmapped model. Received={}. Expected={}".format(
                            model,
                            [
                                "amazon.titan-embed-image-v1",
                                "amazon.titan-embed-text-v1",
                                "amazon.titan-embed-text-v2:0",
                            ],
                        )
                    )
                batch_data.append(transformed_request)
        elif provider == "twelvelabs":
            batch_data = []
            for i in input:
                twelvelabs_request = (
                    TwelveLabsMarengoEmbeddingConfig()._transform_request(
                        input=i,
                        inference_params=inference_params,
                        async_invoke_route=has_async_invoke,
                        model_id=modelId,
                        output_s3_uri=inference_params.get("output_s3_uri"),
                    )
                )
                batch_data.append(twelvelabs_request)
        elif provider == "nova":
            batch_data = []
            for i in input:
                nova_request = AmazonNovaEmbeddingConfig()._transform_request(
                    input=i,
                    inference_params=inference_params,
                    async_invoke_route=has_async_invoke,
                    model_id=modelId,
                    output_s3_uri=inference_params.get("output_s3_uri"),
                )
                batch_data.append(nova_request)

        ### SET RUNTIME ENDPOINT ###
        endpoint_url, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=optional_params.pop(
                "aws_bedrock_runtime_endpoint", None
            ),
            aws_region_name=aws_region_name,
        )
        if has_async_invoke:
            endpoint_url = f"{endpoint_url}/async-invoke"
        else:
            endpoint_url = f"{endpoint_url}/model/{modelId}/invoke"

        if batch_data is not None:
            if aembedding:
                return self._async_single_func_embeddings(  # type: ignore
                    client=(
                        client
                        if client is not None and isinstance(client, AsyncHTTPHandler)
                        else None
                    ),
                    timeout=timeout,
                    batch_data=batch_data,
                    credentials=credentials,
                    extra_headers=extra_headers,
                    endpoint_url=endpoint_url,
                    aws_region_name=aws_region_name,
                    model=model,
                    logging_obj=logging_obj,
                    api_key=api_key,
                    provider=provider,
                    is_async_invoke=has_async_invoke,
                )
            returned_response = self._single_func_embeddings(
                client=(
                    client
                    if client is not None and isinstance(client, HTTPHandler)
                    else None
                ),
                timeout=timeout,
                batch_data=batch_data,
                credentials=credentials,
                extra_headers=extra_headers,
                endpoint_url=endpoint_url,
                aws_region_name=aws_region_name,
                model=model,
                logging_obj=logging_obj,
                api_key=api_key,
                provider=provider,
                is_async_invoke=has_async_invoke,
            )
            if returned_response is None:
                raise Exception("Unable to map Bedrock request to provider")
            return returned_response
        elif data is None:
            raise Exception("Unable to map Bedrock request to provider")

        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}

        prepped = self.get_request_headers(  # type: ignore
            credentials=credentials,
            aws_region_name=aws_region_name,
            extra_headers=extra_headers,
            endpoint_url=endpoint_url,
            data=json.dumps(data),
            headers=headers,
            api_key=api_key,
        )

        ## ROUTING ##
        # Convert CaseInsensitiveDict to regular dict for httpx compatibility
        headers_for_request = (
            dict(prepped.headers) if hasattr(prepped, "headers") else {}
        )
        return cohere_embedding(
            model=model,
            input=input,
            model_response=model_response,
            logging_obj=logging_obj,
            optional_params=optional_params,
            encoding=encoding,
            data=data,  # type: ignore
            complete_api_base=prepped.url,
            api_key=None,
            aembedding=aembedding,
            timeout=timeout,
            client=client,
            headers=headers_for_request,
        )

    async def _get_async_invoke_status(
        self, invocation_arn: str, aws_region_name: str, logging_obj=None, **kwargs
    ) -> dict:
        """
        Get the status of an async invoke job using the GetAsyncInvoke operation.

        Args:
            invocation_arn: The invocation ARN from the async invoke response
            aws_region_name: AWS region name
            **kwargs: Additional parameters (credentials, etc.)

        Returns:
            dict: Status response from AWS Bedrock
        """

        # Get AWS credentials using the same method as other Bedrock methods
        credentials, _ = self._load_credentials(kwargs)

        # Get the runtime endpoint
        endpoint_url, _ = self.get_runtime_endpoint(
            api_base=None,
            aws_bedrock_runtime_endpoint=kwargs.get("aws_bedrock_runtime_endpoint"),
            aws_region_name=aws_region_name,
        )

        from urllib.parse import quote

        # Encode the ARN for use in URL path
        encoded_arn = quote(invocation_arn, safe="")
        status_url = f"{endpoint_url.rstrip('/')}/async-invoke/{encoded_arn}"

        # Prepare headers for GET request
        headers = {"Content-Type": "application/json"}

        # Use AWSRequest directly for GET requests (get_request_headers hardcodes POST)
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        # Create AWSRequest with GET method and encoded URL
        request = AWSRequest(
            method="GET",
            url=status_url,
            data=None,  # GET request, no body
            headers=headers,
        )

        # Sign the request - SigV4Auth will create canonical string from request URL
        sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name)
        sigv4.add_auth(request)

        # Prepare the request
        prepped = request.prepare()

        # LOGGING
        if logging_obj is not None:
            # Create custom curl command for GET request
            masked_headers = logging_obj._get_masked_headers(prepped.headers)
            formatted_headers = " ".join(
                [f"-H '{k}: {v}'" for k, v in masked_headers.items()]
            )
            custom_curl = "\n\nGET Request Sent from LiteLLM:\n"
            custom_curl += "curl -X GET \\\n"
            custom_curl += f"{prepped.url} \\\n"
            custom_curl += f"{formatted_headers}\n"

            logging_obj.pre_call(
                input=invocation_arn,
                api_key="",
                additional_args={
                    "complete_input_dict": {"invocation_arn": invocation_arn},
                    "api_base": prepped.url,
                    "headers": prepped.headers,
                    "request_str": custom_curl,  # Override with custom GET curl command
                },
            )

        # Make the GET request
        client = get_async_httpx_client(llm_provider=LlmProviders.BEDROCK)
        response = await client.get(
            url=prepped.url,
            headers=prepped.headers,
        )

        # LOGGING
        if logging_obj is not None:
            logging_obj.post_call(
                input=invocation_arn,
                api_key="",
                original_response=response,
                additional_args={
                    "complete_input_dict": {"invocation_arn": invocation_arn}
                },
            )

        # Parse response
        if response.status_code == 200:
            return response.json()
        else:
            raise Exception(
                f"Failed to get async invoke status: {response.status_code} - {response.text}"
            )
|
||||
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Transformation logic from OpenAI /v1/embeddings format to Bedrock TwelveLabs Marengo /invoke and /async-invoke format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
|
||||
Docs - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-marengo.html
|
||||
"""
|
||||
|
||||
from typing import List, Optional, Union, cast
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
TWELVELABS_EMBEDDING_INPUT_TYPES,
|
||||
TwelveLabsAsyncInvokeRequest,
|
||||
TwelveLabsMarengoEmbeddingRequest,
|
||||
TwelveLabsOutputDataConfig,
|
||||
TwelveLabsS3Location,
|
||||
TwelveLabsS3OutputDataConfig,
|
||||
)
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||
|
||||
|
||||
class TwelveLabsMarengoEmbeddingConfig:
|
||||
"""
|
||||
Reference - https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-marengo.html
|
||||
|
||||
Supports text, image, video, and audio inputs.
|
||||
- InvokeModel: text and image inputs
|
||||
- StartAsyncInvoke: video, audio, image, and text inputs
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
    """No configuration state to initialize; transformation methods are stateless."""
    pass
|
||||
|
||||
def get_supported_openai_params(self) -> List[str]:
    """OpenAI-compatible params accepted for TwelveLabs Marengo embeddings."""
    supported_params = [
        "encoding_format",
        "textTruncate",
        "embeddingOption",
        "startSec",
        "lengthSec",
        "useFixedLengthSec",
        "minClipSec",
        "input_type",
    ]
    return supported_params
|
||||
|
||||
def map_openai_params(
    self, non_default_params: dict, optional_params: dict
) -> dict:
    """Translate OpenAI-style params into TwelveLabs Marengo request params (mutates and returns optional_params)."""
    clip_params = ("startSec", "lengthSec", "useFixedLengthSec", "minClipSec")
    for param, value in non_default_params.items():
        if param == "encoding_format":
            # TwelveLabs doesn't have encoding_format, but we can map it to embeddingOption
            if value == "float":
                optional_params["embeddingOption"] = ["visual-text", "visual-image"]
        elif param == "textTruncate":
            optional_params["textTruncate"] = value
        elif param == "embeddingOption":
            optional_params["embeddingOption"] = value
        elif param == "input_type":
            # Map input_type to inputType for Bedrock
            optional_params["inputType"] = value
        elif param in clip_params:
            optional_params[param] = value
    return optional_params
|
||||
|
||||
def _extract_bucket_owner_from_params(self, inference_params: dict) -> str:
|
||||
"""
|
||||
Extract bucket owner from inference parameters.
|
||||
"""
|
||||
return inference_params.get("bucketOwner", "")
|
||||
|
||||
def _is_s3_url(self, input: str) -> bool:
|
||||
"""Check if input is an S3 URL."""
|
||||
return input.startswith("s3://")
|
||||
|
||||
def _transform_request(
|
||||
self,
|
||||
input: str,
|
||||
inference_params: dict,
|
||||
async_invoke_route: bool = False,
|
||||
model_id: Optional[str] = None,
|
||||
output_s3_uri: Optional[str] = None,
|
||||
) -> Union[TwelveLabsMarengoEmbeddingRequest, TwelveLabsAsyncInvokeRequest]:
|
||||
"""
|
||||
Transform OpenAI-style input to TwelveLabs Marengo format/async-invoke format.
|
||||
|
||||
Supports:
|
||||
- Text inputs (for both invoke and async-invoke)
|
||||
- Image inputs (for both invoke and async-invoke)
|
||||
- Video inputs (async-invoke only)
|
||||
- Audio inputs (async-invoke only)
|
||||
- S3 URLs for all media types (async-invoke only)
|
||||
"""
|
||||
# Get input_type or default to "text"
|
||||
input_type = cast(
|
||||
TWELVELABS_EMBEDDING_INPUT_TYPES,
|
||||
inference_params.get("inputType")
|
||||
or inference_params.get("input_type")
|
||||
or "text",
|
||||
)
|
||||
|
||||
# Validate that async-invoke is used for video/audio
|
||||
if input_type in ["video", "audio"] and not async_invoke_route:
|
||||
raise ValueError(
|
||||
f"Input type '{input_type}' requires async_invoke route. "
|
||||
f"Use model format: 'bedrock/async_invoke/model_id'"
|
||||
)
|
||||
|
||||
transformed_request: TwelveLabsMarengoEmbeddingRequest = {
|
||||
"inputType": input_type
|
||||
}
|
||||
|
||||
if input_type == "text":
|
||||
transformed_request["inputText"] = input
|
||||
# Set default textTruncate if not specified
|
||||
if "textTruncate" not in inference_params:
|
||||
transformed_request["textTruncate"] = "end"
|
||||
|
||||
elif input_type in ["image", "video", "audio"]:
|
||||
if self._is_s3_url(input):
|
||||
# S3 URL input
|
||||
s3_location: TwelveLabsS3Location = {"uri": input}
|
||||
bucket_owner = self._extract_bucket_owner_from_params(inference_params)
|
||||
if bucket_owner:
|
||||
s3_location["bucketOwner"] = bucket_owner
|
||||
|
||||
transformed_request["mediaSource"] = {"s3Location": s3_location}
|
||||
else:
|
||||
# Base64 encoded input
|
||||
if input.startswith("data:"):
|
||||
# Extract base64 data from data URL
|
||||
b64_str = input.split(",", 1)[1] if "," in input else input
|
||||
else:
|
||||
# Direct base64 string
|
||||
from litellm.utils import get_base64_str
|
||||
|
||||
b64_str = get_base64_str(input)
|
||||
|
||||
transformed_request["mediaSource"] = {"base64String": b64_str}
|
||||
|
||||
# Apply any additional inference parameters
|
||||
for k, v in inference_params.items():
|
||||
if k not in [
|
||||
"inputType",
|
||||
"input_type", # Exclude both camelCase and snake_case
|
||||
"inputText",
|
||||
"mediaSource",
|
||||
"bucketOwner", # Don't include bucketOwner in the request
|
||||
]: # Don't override core fields
|
||||
transformed_request[k] = v # type: ignore
|
||||
|
||||
# If async invoke route, wrap in the async invoke format
|
||||
if async_invoke_route and model_id:
|
||||
return self._wrap_async_invoke_request(
|
||||
model_input=transformed_request,
|
||||
model_id=model_id,
|
||||
output_s3_uri=output_s3_uri,
|
||||
)
|
||||
|
||||
return transformed_request
|
||||
|
||||
def _wrap_async_invoke_request(
|
||||
self,
|
||||
model_input: TwelveLabsMarengoEmbeddingRequest,
|
||||
model_id: str,
|
||||
output_s3_uri: Optional[str] = None,
|
||||
) -> TwelveLabsAsyncInvokeRequest:
|
||||
"""
|
||||
Wrap the transformed request in the correct AWS Bedrock async invoke format.
|
||||
|
||||
Args:
|
||||
model_input: The transformed TwelveLabs Marengo embedding request
|
||||
model_id: The model identifier (without async_invoke prefix)
|
||||
output_s3_uri: Optional S3 URI for output data config
|
||||
|
||||
Returns:
|
||||
TwelveLabsAsyncInvokeRequest: The wrapped async invoke request
|
||||
"""
|
||||
import urllib.parse
|
||||
|
||||
# Clean the model ID
|
||||
unquoted_model_id = urllib.parse.unquote(model_id)
|
||||
if unquoted_model_id.startswith("async_invoke/"):
|
||||
unquoted_model_id = unquoted_model_id.replace("async_invoke/", "")
|
||||
|
||||
# Validate that the S3 URI is not empty
|
||||
if not output_s3_uri or output_s3_uri.strip() == "":
|
||||
raise ValueError("output_s3_uri cannot be empty for async invoke requests")
|
||||
|
||||
return TwelveLabsAsyncInvokeRequest(
|
||||
modelId=unquoted_model_id,
|
||||
modelInput=model_input,
|
||||
outputDataConfig=TwelveLabsOutputDataConfig(
|
||||
s3OutputDataConfig=TwelveLabsS3OutputDataConfig(s3Uri=output_s3_uri)
|
||||
),
|
||||
)
|
||||
|
||||
def _transform_response(
|
||||
self, response_list: List[dict], model: str
|
||||
) -> EmbeddingResponse:
|
||||
"""
|
||||
Transform TwelveLabs response to OpenAI format.
|
||||
Handles the actual TwelveLabs response format: {"data": [{"embedding": [...]}]}
|
||||
"""
|
||||
embeddings: List[Embedding] = []
|
||||
total_tokens = 0
|
||||
|
||||
for response in response_list:
|
||||
# TwelveLabs response format has a "data" field containing the embeddings
|
||||
if "data" in response and isinstance(response["data"], list):
|
||||
for item in response["data"]:
|
||||
if "embedding" in item:
|
||||
# Single embedding response
|
||||
embedding = Embedding(
|
||||
embedding=item["embedding"],
|
||||
index=len(embeddings),
|
||||
object="embedding",
|
||||
)
|
||||
embeddings.append(embedding)
|
||||
|
||||
# Estimate token count (rough approximation)
|
||||
if "inputTextTokenCount" in item:
|
||||
total_tokens += item["inputTextTokenCount"]
|
||||
else:
|
||||
# Rough estimate: 1 token per 4 characters for text, or use embedding size
|
||||
total_tokens += len(item["embedding"]) // 4
|
||||
elif "embedding" in response:
|
||||
# Direct embedding response (fallback for other formats)
|
||||
embedding = Embedding(
|
||||
embedding=response["embedding"],
|
||||
index=len(embeddings),
|
||||
object="embedding",
|
||||
)
|
||||
embeddings.append(embedding)
|
||||
|
||||
# Estimate token count (rough approximation)
|
||||
if "inputTextTokenCount" in response:
|
||||
total_tokens += response["inputTextTokenCount"]
|
||||
else:
|
||||
# Rough estimate: 1 token per 4 characters for text
|
||||
total_tokens += len(response.get("inputText", "")) // 4
|
||||
elif "embeddings" in response:
|
||||
# Multiple embeddings response (from video/audio)
|
||||
for i, emb in enumerate(response["embeddings"]):
|
||||
embedding = Embedding(
|
||||
embedding=emb["embedding"],
|
||||
index=len(embeddings),
|
||||
object="embedding",
|
||||
)
|
||||
embeddings.append(embedding)
|
||||
total_tokens += len(emb["embedding"]) // 4 # Rough estimate
|
||||
|
||||
usage = Usage(prompt_tokens=total_tokens, total_tokens=total_tokens)
|
||||
|
||||
return EmbeddingResponse(data=embeddings, model=model, usage=usage)
|
||||
|
||||
def _transform_async_invoke_response(
|
||||
self, response: dict, model: str
|
||||
) -> EmbeddingResponse:
|
||||
"""
|
||||
Transform async invoke response (invocation ARN) to OpenAI format.
|
||||
|
||||
AWS async invoke returns:
|
||||
{
|
||||
"invocationArn": "arn:aws:bedrock:us-east-1:123456789012:async-invoke/abc123"
|
||||
}
|
||||
|
||||
We transform this to a job-like embedding response:
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"object": "embedding_job_id:1234567890",
|
||||
"embedding": [],
|
||||
"index": 0
|
||||
}
|
||||
],
|
||||
"model": "model",
|
||||
"usage": {}
|
||||
}
|
||||
"""
|
||||
invocation_arn = response.get("invocationArn", "")
|
||||
|
||||
# Create a placeholder embedding object for the job
|
||||
embedding = Embedding(
|
||||
embedding=[], # Empty embedding for async jobs
|
||||
index=0,
|
||||
object="embedding",
|
||||
)
|
||||
|
||||
# Create usage object (empty for async jobs)
|
||||
usage = Usage(prompt_tokens=0, total_tokens=0)
|
||||
|
||||
# Create hidden params with job ID
|
||||
from litellm.types.llms.base import HiddenParams
|
||||
|
||||
hidden_params = HiddenParams()
|
||||
setattr(hidden_params, "_invocation_arn", invocation_arn)
|
||||
|
||||
return EmbeddingResponse(
|
||||
data=[embedding],
|
||||
model=model,
|
||||
usage=usage,
|
||||
hidden_params=hidden_params,
|
||||
)
|
||||
@@ -0,0 +1,210 @@
|
||||
import asyncio
|
||||
import base64
|
||||
from typing import Any, Coroutine, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm import LlmProviders
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.types.llms.openai import (
|
||||
FileContentRequest,
|
||||
HttpxBinaryResponseContent,
|
||||
)
|
||||
from litellm.types.utils import SpecialEnums
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
|
||||
|
||||
class BedrockFilesHandler(BaseAWSLLM):
    """
    Handles downloading files from S3 for Bedrock batch processing.

    This implementation downloads files from S3 buckets where Bedrock
    stores batch output files.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): this httpx client is not referenced by any method in
        # this class — downloads below go through boto3. Confirm whether it is
        # needed for interface parity elsewhere.
        self.async_httpx_client = get_async_httpx_client(
            llm_provider=LlmProviders.BEDROCK,
        )

    def _extract_s3_uri_from_file_id(self, file_id: str) -> str:
        """
        Extract S3 URI from encoded file ID.

        The file ID can be in two formats:
        1. Base64-encoded unified file ID containing: llm_output_file_id,s3://bucket/path
        2. Direct S3 URI: s3://bucket/path

        Args:
            file_id: Encoded file ID or direct S3 URI

        Returns:
            S3 URI (e.g., "s3://bucket-name/path/to/file")
        """
        # First, try to decode if it's a base64-encoded unified file ID
        try:
            # Add padding if needed (urlsafe_b64decode requires length % 4 == 0)
            padded = file_id + "=" * (-len(file_id) % 4)
            decoded = base64.urlsafe_b64decode(padded).decode()

            # Check if it's a unified file ID format
            if decoded.startswith(SpecialEnums.LITELM_MANAGED_FILE_ID_PREFIX.value):
                # Extract llm_output_file_id from the decoded string;
                # the value runs up to the next ";" separator.
                if "llm_output_file_id," in decoded:
                    s3_uri = decoded.split("llm_output_file_id,")[1].split(";")[0]
                    return s3_uri
        except Exception:
            # Not base64 / not decodable — fall through to the raw-URI paths.
            pass

        # If not base64 encoded or doesn't contain llm_output_file_id, assume it's already an S3 URI
        if file_id.startswith("s3://"):
            return file_id

        # Last resort: treat the value as a bucket/key path and prepend the scheme
        return f"s3://{file_id}"

    def _parse_s3_uri(self, s3_uri: str) -> Tuple[str, str]:
        """
        Parse S3 URI to extract bucket name and object key.

        Args:
            s3_uri: S3 URI (e.g., "s3://bucket-name/path/to/file")

        Returns:
            Tuple of (bucket_name, object_key); object_key is "" when the URI
            names only a bucket.

        Raises:
            ValueError: if the URI does not start with "s3://".
        """
        if not s3_uri.startswith("s3://"):
            raise ValueError(
                f"Invalid S3 URI format: {s3_uri}. Expected format: s3://bucket-name/path/to/file"
            )

        # Remove 's3://' prefix
        path = s3_uri[5:]

        if "/" in path:
            bucket_name, object_key = path.split("/", 1)
        else:
            bucket_name = path
            object_key = ""

        return bucket_name, object_key

    async def afile_content(
        self,
        file_content_request: FileContentRequest,
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> HttpxBinaryResponseContent:
        """
        Download file content from S3 bucket for Bedrock files.

        Args:
            file_content_request: Contains file_id (encoded or S3 URI)
            optional_params: Optional parameters containing AWS credentials
            timeout: Request timeout (accepted for interface compatibility;
                not used by this implementation)
            max_retries: Max retry attempts (accepted for interface
                compatibility; not used by this implementation)

        Returns:
            HttpxBinaryResponseContent: Binary content wrapped in compatible response format

        Raises:
            ValueError: when file_id is missing or the S3 download fails.

        NOTE(review): the boto3 get_object call below is synchronous and will
        block the event loop while downloading — confirm this is acceptable
        for the expected file sizes.
        """
        import boto3
        from botocore.credentials import Credentials

        file_id = file_content_request.get("file_id")
        if not file_id:
            raise ValueError("file_id is required in file_content_request")

        # Extract S3 URI from file ID
        s3_uri = self._extract_s3_uri_from_file_id(file_id)
        bucket_name, object_key = self._parse_s3_uri(s3_uri)

        # Get AWS credentials
        aws_region_name = self._get_aws_region_name(
            optional_params=optional_params, model=""
        )
        credentials: Credentials = self.get_credentials(
            aws_access_key_id=optional_params.get("aws_access_key_id"),
            aws_secret_access_key=optional_params.get("aws_secret_access_key"),
            aws_session_token=optional_params.get("aws_session_token"),
            aws_region_name=aws_region_name,
            aws_session_name=optional_params.get("aws_session_name"),
            aws_profile_name=optional_params.get("aws_profile_name"),
            aws_role_name=optional_params.get("aws_role_name"),
            aws_web_identity_token=optional_params.get("aws_web_identity_token"),
            aws_sts_endpoint=optional_params.get("aws_sts_endpoint"),
        )

        # Create S3 client
        s3_client = boto3.client(
            "s3",
            aws_access_key_id=credentials.access_key,
            aws_secret_access_key=credentials.secret_key,
            aws_session_token=credentials.token,
            region_name=aws_region_name,
            verify=self._get_ssl_verify(),
        )

        # Download file from S3
        try:
            response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
            file_content = response["Body"].read()
        except Exception as e:
            raise ValueError(
                f"Failed to download file from S3: {s3_uri}. Error: {str(e)}"
            )

        # Create mock HTTP response so callers get the same wrapper type the
        # OpenAI files API returns.
        mock_response = httpx.Response(
            status_code=200,
            content=file_content,
            headers={"content-type": "application/octet-stream"},
            request=httpx.Request(method="GET", url=s3_uri),
        )

        return HttpxBinaryResponseContent(response=mock_response)

    def file_content(
        self,
        _is_async: bool,
        file_content_request: FileContentRequest,
        api_base: Optional[str],
        optional_params: dict,
        timeout: Union[float, httpx.Timeout],
        max_retries: Optional[int],
    ) -> Union[
        HttpxBinaryResponseContent, Coroutine[Any, Any, HttpxBinaryResponseContent]
    ]:
        """
        Download file content from S3 bucket for Bedrock files.
        Supports both sync and async operations.

        Args:
            _is_async: Whether to run asynchronously
            file_content_request: Contains file_id (encoded or S3 URI)
            api_base: API base (unused for S3 operations)
            optional_params: Optional parameters containing AWS credentials
            timeout: Request timeout
            max_retries: Max retry attempts

        Returns:
            HttpxBinaryResponseContent or Coroutine: Binary content wrapped in compatible response format

        NOTE(review): the sync path uses asyncio.run(), which raises
        RuntimeError if called from a thread that already has a running event
        loop — confirm sync callers are never inside an event loop.
        """
        if _is_async:
            # Return the coroutine for the caller to await.
            return self.afile_content(
                file_content_request=file_content_request,
                optional_params=optional_params,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            # Drive the coroutine to completion on a fresh event loop.
            return asyncio.run(
                self.afile_content(
                    file_content_request=file_content_request,
                    optional_params=optional_params,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            )
|
||||
@@ -0,0 +1,772 @@
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
from httpx import Headers, Response
|
||||
from openai.types.file_deleted import FileDeleted
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.files.utils import FilesAPIUtils
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import extract_file_data
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
from litellm.llms.base_llm.files.transformation import (
|
||||
BaseFilesConfig,
|
||||
LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
CreateFileRequest,
|
||||
FileTypes,
|
||||
HttpxBinaryResponseContent,
|
||||
OpenAICreateFileRequestOptionalParams,
|
||||
OpenAIFileObject,
|
||||
PathLike,
|
||||
)
|
||||
from litellm.types.utils import ExtractedFileData, LlmProviders
|
||||
from litellm.utils import get_llm_provider
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import BedrockError
|
||||
|
||||
|
||||
class BedrockFilesConfig(BaseAWSLLM, BaseFilesConfig):
|
||||
"""
|
||||
Config for Bedrock Files - handles S3 uploads for Bedrock batch processing
|
||||
"""
|
||||
|
||||
def __init__(self):
    # Dedicated JSONL transformer for batch files (defined elsewhere in this module).
    self.jsonl_transformation = BedrockJsonlFilesTransformation()
    super().__init__()
|
||||
|
||||
@property
def custom_llm_provider(self) -> LlmProviders:
    """Identify this config as belonging to the Bedrock provider."""
    return LlmProviders.BEDROCK
|
||||
|
||||
@property
def file_upload_http_method(self) -> str:
    """
    HTTP method used for file uploads.

    Bedrock files are uploaded to S3, which requires PUT requests
    (rather than the multipart POST the OpenAI files API uses).
    """
    return "PUT"
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """
    Return request headers unchanged.

    No additional headers are needed for S3 uploads — AWS credentials and
    SigV4 signing are handled by BaseAWSLLM (see _sign_s3_request).
    """
    return headers
|
||||
|
||||
def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
    """
    Normalize any OpenAI-style file payload into a UTF-8 string.

    Handles:
    - Direct content (str, bytes, IO[bytes])
    - Tuple formats: (filename, content, [content_type], [headers])
    - PathLike objects

    Unrecognized payload types yield an empty string.
    """
    # Tuple forms always carry the payload at position 1.
    payload = (
        openai_file_content[1]
        if isinstance(openai_file_content, tuple)
        else openai_file_content
    )

    raw: Union[str, bytes] = b""
    if isinstance(payload, str):
        # Already text — use as-is.
        raw = payload
    elif isinstance(payload, bytes):
        # Raw bytes — decoded below.
        raw = payload
    elif isinstance(payload, PathLike):
        # Filesystem path: read the file's raw bytes.
        with open(str(payload), "rb") as fh:
            raw = fh.read()
    elif hasattr(payload, "read"):
        # File-like object (IO[bytes]).
        raw = payload.read()

    # Decode bytes (including the b"" default for unrecognized payloads).
    if isinstance(raw, bytes):
        raw = raw.decode("utf-8")

    return raw
|
||||
|
||||
def _get_s3_object_name_from_batch_jsonl(
    self,
    openai_jsonl_content: List[Dict[str, Any]],
) -> str:
    """
    Build a unique S3 object name for a Bedrock batch processing job.

    Shape: litellm-bedrock-files-{model}-{uuid}.jsonl, where the model is
    taken from the first record's request body.
    """
    model_name = openai_jsonl_content[0].get("body", {}).get("model", "")

    # Strip a leading "bedrock/" routing prefix if present.
    if model_name.startswith("bedrock/"):
        model_name = model_name[len("bedrock/") :]

    # Bedrock S3 URIs disallow colons (common in versioned model ids).
    model_name = model_name.replace(":", "-")

    return f"litellm-bedrock-files-{model_name}-{uuid.uuid4()}.jsonl"
|
||||
|
||||
def get_object_name(
    self, extracted_file_data: ExtractedFileData, purpose: str
) -> str:
    """
    Choose the S3 object name for the upload request.

    Preference order:
    1. batch purpose with parseable JSONL -> generated batch object name
       embedding the model id
    2. the extracted filename, when present
    3. a unix-timestamp string as a last resort

    Raises:
        ValueError: when the extracted file data has no content.
    """
    raw_content = extracted_file_data.get("content")

    if raw_content is None:
        raise ValueError("file content is required")

    if purpose == "batch":
        ## 1. If jsonl, check if there's a model name
        text = self._get_content_from_openai_file(raw_content)

        # Split into lines and parse each non-blank line as JSON.
        records = [
            json.loads(line) for line in text.splitlines() if line.strip()
        ]
        if records:
            return self._get_s3_object_name_from_batch_jsonl(records)

    ## 2. If not jsonl, return the filename
    filename = extracted_file_data.get("filename")
    if filename:
        return filename
    ## 3. If no file name, return timestamp
    return str(int(time.time()))
|
||||
|
||||
def get_complete_file_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: Dict,
    litellm_params: Dict,
    data: CreateFileRequest,
) -> str:
    """
    Build the complete S3 PUT URL for a file upload request.

    The bucket comes from litellm_params["s3_bucket_name"] or the
    AWS_S3_BUCKET_NAME env var; the endpoint from
    optional_params["s3_endpoint_url"] or the regional default.

    Raises:
        ValueError: when the bucket, file, or purpose is missing.
    """
    bucket_name = litellm_params.get("s3_bucket_name") or os.getenv(
        "AWS_S3_BUCKET_NAME"
    )
    if not bucket_name:
        raise ValueError(
            "S3 bucket_name is required. Set 's3_bucket_name' in litellm_params or AWS_S3_BUCKET_NAME env var"
        )

    region = self._get_aws_region_name(optional_params, model)

    file_payload = data.get("file")
    purpose = data.get("purpose")
    if file_payload is None:
        raise ValueError("file is required")
    if purpose is None:
        raise ValueError("purpose is required")

    # Derive the object name from the (parsed) file contents.
    object_name = self.get_object_name(extract_file_data(file_payload), purpose)

    # S3 endpoint URL format: explicit override or regional default.
    endpoint = (
        optional_params.get("s3_endpoint_url")
        or f"https://s3.{region}.amazonaws.com"
    )

    return f"{endpoint}/{bucket_name}/{object_name}"
|
||||
|
||||
def get_supported_openai_params(
    self, model: str
) -> List[OpenAICreateFileRequestOptionalParams]:
    """No optional OpenAI create-file params are supported for Bedrock files."""
    return []
|
||||
|
||||
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """Identity mapping — Bedrock file uploads take no OpenAI optional params."""
    return optional_params
|
||||
|
||||
# Providers whose InvokeModel body uses the Converse API format
# (messages + inferenceConfig + image blocks). Nova is the primary
# example; add others here as they adopt the same schema.
# Checked by _map_openai_to_bedrock_params when routing batch records.
CONVERSE_INVOKE_PROVIDERS = ("nova",)
|
||||
|
||||
def _map_openai_to_bedrock_params(
    self,
    openai_request_body: Dict[str, Any],
    provider: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Transform OpenAI request body to Bedrock-compatible modelInput
    parameters using existing transformation logic.

    Routes to the correct per-provider transformation so that the
    resulting dict matches the InvokeModel body that Bedrock expects
    for batch inference.

    Args:
        openai_request_body: One OpenAI-style chat request body
            ("model", "messages", plus optional params).
        provider: Bedrock invoke provider inferred from the model name;
            None falls through to the passthrough branch.
    """
    from litellm.types.utils import LlmProviders

    _model = openai_request_body.get("model", "")
    messages = openai_request_body.get("messages", [])
    # Everything besides model/messages is treated as an inference param.
    optional_params = {
        k: v
        for k, v in openai_request_body.items()
        if k not in ["model", "messages"]
    }

    # --- Anthropic: use existing AmazonAnthropicClaudeConfig ---
    if provider == LlmProviders.ANTHROPIC:
        from litellm.llms.bedrock.chat.invoke_transformations.anthropic_claude3_transformation import (
            AmazonAnthropicClaudeConfig,
        )

        config = AmazonAnthropicClaudeConfig()
        # NOTE(review): here the request params are passed via optional_params
        # while the Converse branch below passes them via non_default_params —
        # confirm both configs expect their inputs this way.
        mapped_params = config.map_openai_params(
            non_default_params={},
            optional_params=optional_params,
            model=_model,
            drop_params=False,
        )
        return config.transform_request(
            model=_model,
            messages=messages,
            optional_params=mapped_params,
            litellm_params={},
            headers={},
        )

    # --- Converse API providers (e.g. Nova): use AmazonConverseConfig
    # to correctly convert image_url blocks to Bedrock image format
    # and wrap inference params inside inferenceConfig. ---
    if provider in self.CONVERSE_INVOKE_PROVIDERS:
        from litellm.llms.bedrock.chat.converse_transformation import (
            AmazonConverseConfig,
        )

        converse_config = AmazonConverseConfig()
        mapped_params = converse_config.map_openai_params(
            non_default_params=optional_params,
            optional_params={},
            model=_model,
            drop_params=False,
        )
        return converse_config.transform_request(
            model=_model,
            messages=messages,
            optional_params=mapped_params,
            litellm_params={},
            headers={},
        )

    # --- All other providers: passthrough (OpenAI-compatible models
    # like openai.gpt-oss-*, qwen, deepseek, etc.) ---
    return {
        "messages": messages,
        **optional_params,
    }
|
||||
|
||||
def _transform_openai_jsonl_content_to_bedrock_jsonl_content(
    self, openai_jsonl_content: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """
    Convert OpenAI batch JSONL records into Bedrock batch records.

    Bedrock batch format: { "recordId": "alphanumeric string", "modelInput": {JSON body} }
    Example:
        {
            "recordId": "CALL0000001",
            "modelInput": {
                "anthropic_version": "bedrock-2023-05-31",
                "max_tokens": 1024,
                "messages": [
                    {
                        "role": "user",
                        "content": [{"type": "text", "text": "Hello"}]
                    }
                ]
            }
        }
    """
    bedrock_records: List[Dict[str, Any]] = []

    for idx, openai_record in enumerate(openai_jsonl_content):
        # Pull the OpenAI-format request body out of the batch record.
        body = openai_record.get("body", {})
        model = body.get("model", "")

        # Normalize the model name; a failure here is logged but non-fatal.
        try:
            model, _, _, _ = get_llm_provider(
                model=model,
                custom_llm_provider=None,
            )
        except Exception as e:
            verbose_logger.exception(
                f"litellm.llms.bedrock.files.transformation.py::_transform_openai_jsonl_content_to_bedrock_jsonl_content() - Error inferring custom_llm_provider - {str(e)}"
            )

        # Determine provider from model name, then build the InvokeModel body.
        provider = self.get_bedrock_invoke_provider(model)
        model_input = self._map_openai_to_bedrock_params(
            openai_request_body=body, provider=provider
        )

        # Honor a caller-supplied custom_id; otherwise synthesize CALLnnnnnnn.
        record_id = openai_record.get(
            "custom_id", f"CALL{str(idx).zfill(7)}"
        )
        bedrock_records.append({"recordId": record_id, "modelInput": model_input})

    return bedrock_records
|
||||
|
||||
def transform_create_file_request(
    self,
    model: str,
    create_file_data: CreateFileRequest,
    optional_params: dict,
    litellm_params: dict,
) -> Union[bytes, str, dict]:
    """
    Transform file request and return a pre-signed request for S3.
    This keeps the HTTP handler clean by doing all the signing here.

    Batch JSONL files are rewritten into Bedrock's record format first;
    other payloads are uploaded as UTF-8 text.

    Returns:
        dict with "method"/"url"/"headers"/"data" describing the exact
        PUT request the HTTP handler should issue.

    Raises:
        ValueError: when the file, its content, or the content type is unusable.

    Side effects:
        Stores the upload URL in litellm_params["upload_url"].
    """
    file_data = create_file_data.get("file")
    if file_data is None:
        raise ValueError("file is required")
    extracted_file_data = extract_file_data(file_data)
    extracted_file_data_content = extracted_file_data.get("content")

    if extracted_file_data_content is None:
        raise ValueError("file content is required")

    # Get and transform the file content
    if FilesAPIUtils.is_batch_jsonl_file(
        create_file_data=create_file_data,
        extracted_file_data=extracted_file_data,
    ):
        ## Transform JSONL content to Bedrock format
        original_file_content = self._get_content_from_openai_file(
            extracted_file_data_content
        )
        # Parse each non-blank line as one OpenAI batch record.
        openai_jsonl_content = [
            json.loads(line)
            for line in original_file_content.splitlines()
            if line.strip()
        ]
        bedrock_jsonl_content = (
            self._transform_openai_jsonl_content_to_bedrock_jsonl_content(
                openai_jsonl_content
            )
        )
        # Re-serialize as newline-delimited JSON for S3.
        file_content = "\n".join(json.dumps(item) for item in bedrock_jsonl_content)
    elif isinstance(extracted_file_data_content, bytes):
        file_content = extracted_file_data_content.decode("utf-8")
    elif isinstance(extracted_file_data_content, str):
        file_content = extracted_file_data_content
    else:
        raise ValueError("Unsupported file content type")

    # Get the S3 URL for upload
    api_base = self.get_complete_file_url(
        api_base=None,
        api_key=None,
        model=model,
        optional_params=optional_params,
        litellm_params=litellm_params,
        data=create_file_data,
    )

    # Sign the request and return a pre-signed request object
    signed_headers, signed_body = self._sign_s3_request(
        content=file_content,
        api_base=api_base,
        optional_params=optional_params,
    )

    # Stash the destination so later steps (e.g. batch creation) can find it.
    litellm_params["upload_url"] = api_base

    # Return a dict that tells the HTTP handler exactly what to do
    return {
        "method": "PUT",
        "url": api_base,
        "headers": signed_headers,
        "data": signed_body or file_content,
    }
|
||||
|
||||
def _sign_s3_request(
    self,
    content: str,
    api_base: str,
    optional_params: dict,
) -> Tuple[dict, str]:
    """
    Sign S3 PUT request using the same proven logic as S3Logger.
    Reuses the exact pattern from litellm/integrations/s3_v2.py

    Args:
        content: File payload to upload (UTF-8 text).
        api_base: Full S3 object URL the PUT will target.
        optional_params: May carry explicit AWS credentials/region overrides.

    Returns:
        Tuple of (signed request headers, signed/serialized request body).

    Raises:
        ImportError: when boto3/botocore are not installed.
    """
    try:
        import hashlib

        import requests
        from botocore.auth import SigV4Auth
        from botocore.awsrequest import AWSRequest
    except ImportError:
        raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

    # Get AWS credentials using existing methods
    aws_region_name = self._get_aws_region_name(
        optional_params=optional_params, model=""
    )
    credentials = self.get_credentials(
        aws_access_key_id=optional_params.get("aws_access_key_id"),
        aws_secret_access_key=optional_params.get("aws_secret_access_key"),
        aws_session_token=optional_params.get("aws_session_token"),
        aws_region_name=aws_region_name,
        aws_session_name=optional_params.get("aws_session_name"),
        aws_profile_name=optional_params.get("aws_profile_name"),
        aws_role_name=optional_params.get("aws_role_name"),
        aws_web_identity_token=optional_params.get("aws_web_identity_token"),
        aws_sts_endpoint=optional_params.get("aws_sts_endpoint"),
    )

    # Calculate SHA256 hash of the content (REQUIRED for S3)
    content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()

    # Prepare headers with required S3 headers (same as s3_v2.py)
    request_headers = {
        "Content-Type": "application/json",  # JSONL files are JSON content
        "x-amz-content-sha256": content_hash,  # REQUIRED by S3
        "Content-Language": "en",
        "Cache-Control": "private, immutable, max-age=31536000, s-maxage=0",
    }

    # Use requests.Request to prepare the request (same pattern as s3_v2.py)
    req = requests.Request("PUT", api_base, data=content, headers=request_headers)
    prepped = req.prepare()

    # Sign the request with S3 service
    aws_request = AWSRequest(
        method=prepped.method,
        url=prepped.url,
        data=prepped.body,
        headers=prepped.headers,
    )

    # Get region name for non-LLM API calls (same as s3_v2.py)
    signing_region = self.get_aws_region_name_for_non_llm_api_calls(
        aws_region_name=aws_region_name
    )

    # SigV4Auth mutates aws_request in place, adding Authorization et al.
    SigV4Auth(credentials, "s3", signing_region).add_auth(aws_request)

    # Return signed headers and body
    signed_body = aws_request.body
    if isinstance(signed_body, bytes):
        signed_body = signed_body.decode("utf-8")
    elif signed_body is None:
        signed_body = content  # Fallback to original content

    return dict(aws_request.headers), signed_body
|
||||
|
||||
def _convert_https_url_to_s3_uri(self, https_url: str) -> tuple[str, str]:
|
||||
"""
|
||||
Convert HTTPS S3 URL to s3:// URI format.
|
||||
|
||||
Args:
|
||||
https_url: HTTPS S3 URL (e.g., "https://s3.us-west-2.amazonaws.com/bucket/key")
|
||||
|
||||
Returns:
|
||||
Tuple of (s3_uri, filename)
|
||||
|
||||
Example:
|
||||
Input: "https://s3.us-west-2.amazonaws.com/litellm-proxy/file.jsonl"
|
||||
Output: ("s3://litellm-proxy/file.jsonl", "file.jsonl")
|
||||
"""
|
||||
import re
|
||||
|
||||
# Match HTTPS S3 URL patterns
|
||||
# Pattern 1: https://s3.region.amazonaws.com/bucket/key
|
||||
# Pattern 2: https://bucket.s3.region.amazonaws.com/key
|
||||
|
||||
pattern1 = r"https://s3\.([^.]+)\.amazonaws\.com/([^/]+)/(.+)"
|
||||
pattern2 = r"https://([^.]+)\.s3\.([^.]+)\.amazonaws\.com/(.+)"
|
||||
|
||||
match1 = re.match(pattern1, https_url)
|
||||
match2 = re.match(pattern2, https_url)
|
||||
|
||||
if match1:
|
||||
# Pattern: https://s3.region.amazonaws.com/bucket/key
|
||||
region, bucket, key = match1.groups()
|
||||
s3_uri = f"s3://{bucket}/{key}"
|
||||
elif match2:
|
||||
# Pattern: https://bucket.s3.region.amazonaws.com/key
|
||||
bucket, region, key = match2.groups()
|
||||
s3_uri = f"s3://{bucket}/{key}"
|
||||
else:
|
||||
# Fallback: try to extract bucket and key from URL path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
parsed = urlparse(https_url)
|
||||
path_parts = parsed.path.lstrip("/").split("/", 1)
|
||||
if len(path_parts) >= 2:
|
||||
bucket, key = path_parts[0], path_parts[1]
|
||||
s3_uri = f"s3://{bucket}/{key}"
|
||||
else:
|
||||
raise ValueError(f"Unable to parse S3 URL: {https_url}")
|
||||
|
||||
# Extract filename from key
|
||||
filename = key.split("/")[-1] if "/" in key else key
|
||||
|
||||
return s3_uri, filename
|
||||
|
||||
def transform_create_file_response(
    self,
    model: Optional[str],
    raw_response: Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> OpenAIFileObject:
    """
    Transform an S3 file-upload response into an OpenAI-style FileObject.

    S3 returns no file id; the id (an ``s3://`` URI) and filename are derived
    from the upload URL recorded in ``litellm_params["upload_url"]``.
    """
    # S3 reports the object size via the Content-Length response header
    size_header = raw_response.headers.get("Content-Length", "0")
    size_bytes = int(size_header) if size_header.isdigit() else 0

    # Derive the s3:// id and filename from the URL used for the upload
    object_id, object_name = "", ""
    destination = litellm_params.get("upload_url")
    if destination:
        object_id, object_name = self._convert_https_url_to_s3_uri(destination)

    return OpenAIFileObject(
        purpose="batch",  # Default purpose for Bedrock files
        id=object_id,
        filename=object_name,
        created_at=int(time.time()),  # S3 response carries no creation time
        status="uploaded",
        bytes=size_bytes,
        object="file",
    )
|
||||
|
||||
def get_error_class(
    self, error_message: str, status_code: int, headers: Union[Dict, Headers]
) -> BaseLLMException:
    """Wrap an error response in the Bedrock-specific exception type."""
    return BedrockError(
        message=error_message,
        status_code=status_code,
        headers=headers,
    )
|
||||
|
||||
def transform_retrieve_file_request(
    self,
    file_id: str,
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """Not supported: Bedrock files are plain S3 objects with no retrieve API."""
    raise NotImplementedError("BedrockFilesConfig does not support file retrieval")
|
||||
|
||||
def transform_retrieve_file_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> OpenAIFileObject:
    """Not supported: Bedrock files are plain S3 objects with no retrieve API."""
    raise NotImplementedError("BedrockFilesConfig does not support file retrieval")
|
||||
|
||||
def transform_delete_file_request(
    self,
    file_id: str,
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """Not supported: deletion of Bedrock S3 files is not exposed here."""
    raise NotImplementedError("BedrockFilesConfig does not support file deletion")
|
||||
|
||||
def transform_delete_file_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> FileDeleted:
    """Not supported: deletion of Bedrock S3 files is not exposed here."""
    raise NotImplementedError("BedrockFilesConfig does not support file deletion")
|
||||
|
||||
def transform_list_files_request(
    self,
    purpose: Optional[str],
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """Not supported: there is no Bedrock file-listing endpoint."""
    raise NotImplementedError("BedrockFilesConfig does not support file listing")
|
||||
|
||||
def transform_list_files_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> List[OpenAIFileObject]:
    """Not supported: there is no Bedrock file-listing endpoint."""
    raise NotImplementedError("BedrockFilesConfig does not support file listing")
|
||||
|
||||
def transform_file_content_request(
    self,
    file_content_request,
    optional_params: dict,
    litellm_params: dict,
) -> tuple[str, dict]:
    """Not supported: file content must be fetched from S3 directly."""
    raise NotImplementedError(
        "BedrockFilesConfig does not support file content retrieval"
    )
|
||||
|
||||
def transform_file_content_response(
    self,
    raw_response: httpx.Response,
    logging_obj: LiteLLMLoggingObj,
    litellm_params: dict,
) -> HttpxBinaryResponseContent:
    """Not supported: file content must be fetched from S3 directly."""
    raise NotImplementedError(
        "BedrockFilesConfig does not support file content retrieval"
    )
|
||||
|
||||
|
||||
class BedrockJsonlFilesTransformation:
    """
    Transforms OpenAI /v1/files/* requests to Bedrock S3 file uploads for batch processing
    """

    def transform_openai_file_content_to_bedrock_file_content(
        self, openai_file_content: Optional[FileTypes] = None
    ) -> Tuple[str, str]:
        """
        Transforms OpenAI FileContentRequest to Bedrock S3 file format.

        Returns:
            Tuple of (bedrock-formatted JSONL string, unique S3 object name).

        Raises:
            ValueError: if no file content was provided.
        """
        if openai_file_content is None:
            raise ValueError("contents of file are None")
        # Read the content of the file as a string
        file_content = self._get_content_from_openai_file(openai_file_content)

        # Split into lines and parse each non-empty line as JSON
        openai_jsonl_content = [
            json.loads(line) for line in file_content.splitlines() if line.strip()
        ]
        bedrock_jsonl_content = (
            self._transform_openai_jsonl_content_to_bedrock_jsonl_content(
                openai_jsonl_content
            )
        )
        bedrock_jsonl_string = "\n".join(
            json.dumps(item) for item in bedrock_jsonl_content
        )
        object_name = self._get_s3_object_name(
            openai_jsonl_content=openai_jsonl_content
        )
        return bedrock_jsonl_string, object_name

    def _transform_openai_jsonl_content_to_bedrock_jsonl_content(
        self, openai_jsonl_content: List[Dict[str, Any]]
    ):
        """
        Delegate to the main BedrockFilesConfig transformation method.
        """
        config = BedrockFilesConfig()
        return config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
            openai_jsonl_content
        )

    def _get_s3_object_name(
        self,
        openai_jsonl_content: List[Dict[str, Any]],
    ) -> str:
        """
        Gets a unique S3 object name for the Bedrock batch processing job.

        Named as: litellm-bedrock-files-{model}-{uuid}.jsonl
        """
        # Guard against an empty batch file — previously this indexed
        # openai_jsonl_content[0] unconditionally and raised IndexError.
        _model = ""
        if openai_jsonl_content:
            _model = openai_jsonl_content[0].get("body", {}).get("model", "")
        # Remove bedrock/ prefix if present
        if _model.startswith("bedrock/"):
            _model = _model[8:]
        object_name = f"litellm-bedrock-files-{_model}-{uuid.uuid4()}.jsonl"
        return object_name

    def _get_content_from_openai_file(self, openai_file_content: FileTypes) -> str:
        """
        Helper to extract content from various OpenAI file types and return as string.

        Handles:
        - Direct content (str, bytes, IO[bytes])
        - Tuple formats: (filename, content, [content_type], [headers])
        - PathLike objects
        """
        content: Union[str, bytes] = b""
        # Extract file content from tuple if necessary
        if isinstance(openai_file_content, tuple):
            # Take the second element which is always the file content
            file_content = openai_file_content[1]
        else:
            file_content = openai_file_content

        # Handle different file content types
        if isinstance(file_content, str):
            # String content can be used directly
            content = file_content
        elif isinstance(file_content, bytes):
            # Bytes content is decoded to str at the end of this method
            content = file_content
        elif isinstance(file_content, PathLike):  # PathLike
            with open(str(file_content), "rb") as f:
                content = f.read()
        elif hasattr(file_content, "read"):  # IO[bytes]
            # File-like objects need to be read
            content = file_content.read()

        # Ensure content is string
        if isinstance(content, bytes):
            content = content.decode("utf-8")

        return content

    def transform_s3_bucket_response_to_openai_file_object(
        self, create_file_data: CreateFileRequest, s3_upload_response: Dict[str, Any]
    ) -> OpenAIFileObject:
        """
        Transforms S3 Bucket upload file response to OpenAI FileObject.
        """
        # S3 response typically contains ETag, key, etc.
        object_key = s3_upload_response.get("Key", "")
        bucket_name = s3_upload_response.get("Bucket", "")

        # Extract filename from object key
        filename = object_key.split("/")[-1] if "/" in object_key else object_key

        return OpenAIFileObject(
            purpose=create_file_data.get("purpose", "batch"),
            id=f"s3://{bucket_name}/{object_key}",
            filename=filename,
            created_at=int(time.time()),  # S3 response carries no creation time
            status="uploaded",
            bytes=s3_upload_response.get("ContentLength", 0),
            object="file",
        )
|
||||
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Bedrock Image Edit Module
|
||||
|
||||
Handles image edit operations for Bedrock stability models.
|
||||
"""
|
||||
|
||||
from .handler import BedrockImageEdit
|
||||
|
||||
__all__ = ["BedrockImageEdit"]
|
||||
@@ -0,0 +1,309 @@
|
||||
"""
|
||||
Bedrock Image Edit Handler
|
||||
|
||||
Handles image edit requests for Bedrock stability models.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||
from litellm.llms.bedrock.image_edit.stability_transformation import (
|
||||
BedrockStabilityImageEditConfig,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import BedrockError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botocore.awsrequest import AWSPreparedRequest
|
||||
else:
|
||||
AWSPreparedRequest = Any
|
||||
|
||||
|
||||
class BedrockImageEditPreparedRequest(BaseModel):
    """
    Internal/Helper class for preparing the request for bedrock image edit
    """

    # Fully-resolved Bedrock `invoke` URL for the target model
    endpoint_url: str
    # SigV4-signed request; its headers carry the AWS auth signature
    prepped: AWSPreparedRequest
    # JSON-encoded request body, ready to POST
    body: bytes
    # Decoded request payload, kept for logging and response transformation
    data: dict
|
||||
|
||||
|
||||
class BedrockImageEdit(BaseAWSLLM):
    """
    Bedrock Image Edit handler
    """

    @classmethod
    def get_config_class(cls, model: str | None):
        # Dispatch on model id; only Stability edit models are supported today.
        if BedrockStabilityImageEditConfig._is_stability_edit_model(model):
            return BedrockStabilityImageEditConfig
        else:
            raise ValueError(f"Unsupported model for bedrock image edit: {model}")

    def image_edit(
        self,
        model: str,
        image: list,
        prompt: Optional[str],
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: LitellmLogging,
        timeout: Optional[Union[float, httpx.Timeout]],
        aimage_edit: bool = False,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
        api_key: Optional[str] = None,
    ):
        """
        Edit an image via a Bedrock `invoke` call.

        Returns the populated ImageResponse, or — when ``aimage_edit`` is True —
        the coroutine produced by ``async_image_edit`` for the caller to await.
        """
        # Build the signed request once; both sync and async paths reuse it.
        prepared_request = self._prepare_request(
            model=model,
            image=image,
            prompt=prompt,
            optional_params=optional_params,
            api_base=api_base,
            extra_headers=extra_headers,
            logging_obj=logging_obj,
            api_key=api_key,
        )

        if aimage_edit is True:
            # Hand off to the async path; only forward the client if it is
            # actually an async handler.
            return self.async_image_edit(
                prepared_request=prepared_request,
                timeout=timeout,
                model=model,
                logging_obj=logging_obj,
                prompt=prompt,
                model_response=model_response,
                client=(
                    client
                    if client is not None and isinstance(client, AsyncHTTPHandler)
                    else None
                ),
            )

        # Sync path: fall back to a fresh sync client if none (or the wrong
        # kind) was provided.
        # NOTE(review): `timeout` is not applied to the sync client here,
        # while the async path passes it — confirm whether that is intended.
        if client is None or not isinstance(client, HTTPHandler):
            client = _get_httpx_client()
        try:
            response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        ### FORMAT RESPONSE TO OPENAI FORMAT ###
        model_response = self._transform_response_dict_to_openai_response(
            model_response=model_response,
            model=model,
            logging_obj=logging_obj,
            prompt=prompt,
            response=response,
            data=prepared_request.data,
        )
        return model_response

    async def async_image_edit(
        self,
        prepared_request: BedrockImageEditPreparedRequest,
        timeout: Optional[Union[float, httpx.Timeout]],
        model: str,
        logging_obj: LitellmLogging,
        prompt: Optional[str],
        model_response: ImageResponse,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ImageResponse:
        """
        Asynchronous handler for bedrock image edit
        """
        # Reuse the caller's async client when provided, otherwise create one
        # with the requested timeout.
        async_client = client or get_async_httpx_client(
            llm_provider=litellm.LlmProviders.BEDROCK,
            params={"timeout": timeout},
        )

        try:
            response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        ### FORMAT RESPONSE TO OPENAI FORMAT ###
        model_response = self._transform_response_dict_to_openai_response(
            model=model,
            logging_obj=logging_obj,
            prompt=prompt,
            response=response,
            data=prepared_request.data,
            model_response=model_response,
        )
        return model_response

    def _prepare_request(
        self,
        model: str,
        image: list,
        prompt: Optional[str],
        optional_params: dict,
        api_base: Optional[str],
        extra_headers: Optional[dict],
        logging_obj: LitellmLogging,
        api_key: Optional[str],
    ) -> BedrockImageEditPreparedRequest:
        """
        Prepare the request body, headers, and endpoint URL for the Bedrock Image Edit API

        Args:
            model (str): The model to use for the image edit
            image (list): The images to edit
            prompt (Optional[str]): The prompt for the edit
            optional_params (dict): The optional parameters for the image edit
            api_base (Optional[str]): The base URL for the Bedrock API
            extra_headers (Optional[dict]): The extra headers to include in the request
            logging_obj (LitellmLogging): The logging object to use for logging
            api_key (Optional[str]): The API key to use

        Returns:
            BedrockImageEditPreparedRequest: The prepared request object
        """
        # Resolve AWS credentials/region from the optional params.
        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params, model
        )

        # Use the existing ARN-aware provider detection method
        bedrock_provider = self.get_bedrock_invoke_provider(model)
        ### SET RUNTIME ENDPOINT ###
        modelId = self.get_bedrock_model_id(
            model=model,
            provider=bedrock_provider,
            optional_params=optional_params,
        )
        _, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        # Bedrock synchronous inference endpoint for this model
        proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
        data = self._get_request_body(
            model=model,
            image=image,
            prompt=prompt,
            optional_params=optional_params,
        )

        # Make POST Request
        body = json.dumps(data).encode("utf-8")
        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            # extra_headers take precedence over the default Content-Type
            headers = {"Content-Type": "application/json", **extra_headers}

        # SigV4-sign the request (returns an AWSPreparedRequest)
        prepped = self.get_request_headers(
            credentials=boto3_credentials_info.credentials,
            aws_region_name=boto3_credentials_info.aws_region_name,
            extra_headers=extra_headers,
            endpoint_url=proxy_endpoint_url,
            data=body,
            headers=headers,
            api_key=api_key,
        )

        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": proxy_endpoint_url,
                "headers": prepped.headers,
            },
        )
        return BedrockImageEditPreparedRequest(
            endpoint_url=proxy_endpoint_url,
            prepped=prepped,
            body=body,
            data=data,
        )

    def _get_request_body(
        self,
        model: str,
        image: list,
        prompt: Optional[str],
        optional_params: dict,
    ) -> dict:
        """
        Get the request body for the Bedrock Image Edit API

        Checks the model/provider and transforms the request body accordingly

        Returns:
            dict: The request body to use for the Bedrock Image Edit API
        """
        config_class = self.get_config_class(model=model)
        config_instance = config_class()
        # Only the first image is forwarded to the provider transformation.
        request_body, _ = config_instance.transform_image_edit_request(
            model=model,
            prompt=prompt,
            image=image[0] if image else None,
            image_edit_optional_request_params=optional_params,
            litellm_params={},
            headers={},
        )
        return dict(request_body)

    def _transform_response_dict_to_openai_response(
        self,
        model_response: ImageResponse,
        model: str,
        logging_obj: LitellmLogging,
        prompt: Optional[str],
        response: httpx.Response,
        data: dict,
    ) -> ImageResponse:
        """
        Transforms the Image Edit response from Bedrock to OpenAI format
        """

        ## LOGGING
        if logging_obj is not None:
            logging_obj.post_call(
                input=prompt,
                api_key="",
                original_response=response.text,
                additional_args={"complete_input_dict": data},
            )
        verbose_logger.debug("raw model_response: %s", response.text)
        # Parse only to validate the payload is well-formed, non-null JSON;
        # the config class re-reads the raw response below.
        response_dict = response.json()
        if response_dict is None:
            raise ValueError("Error in response object format, got None")

        config_class = self.get_config_class(model=model)
        config_instance = config_class()

        model_response = config_instance.transform_image_edit_response(
            model=model,
            raw_response=response,
            logging_obj=logging_obj,
        )

        return model_response
|
||||
@@ -0,0 +1,399 @@
|
||||
"""
|
||||
Bedrock Stability AI Image Edit Transformation
|
||||
|
||||
Handles transformation between OpenAI-compatible format and Bedrock Stability AI Image Edit API format.
|
||||
|
||||
Supported models:
|
||||
- stability.stable-conservative-upscale-v1:0
|
||||
- stability.stable-creative-upscale-v1:0
|
||||
- stability.stable-fast-upscale-v1:0
|
||||
- stability.stable-outpaint-v1:0
|
||||
- stability.stable-image-control-sketch-v1:0
|
||||
- stability.stable-image-control-structure-v1:0
|
||||
- stability.stable-image-erase-object-v1:0
|
||||
- stability.stable-image-inpaint-v1:0
|
||||
- stability.stable-image-remove-background-v1:0
|
||||
- stability.stable-image-search-recolor-v1:0
|
||||
- stability.stable-image-search-replace-v1:0
|
||||
- stability.stable-image-style-guide-v1:0
|
||||
- stability.stable-style-transfer-v1:0
|
||||
|
||||
API Reference: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters.html
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.image_edit.transformation import BaseImageEditConfig
|
||||
from litellm.types.images.main import ImageEditOptionalRequestParams
|
||||
from litellm.types.llms.stability import (
|
||||
OPENAI_SIZE_TO_STABILITY_ASPECT_RATIO,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import FileTypes, ImageObject, ImageResponse
|
||||
from litellm.utils import get_model_info
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BedrockStabilityImageEditConfig(BaseImageEditConfig):
|
||||
"""
|
||||
Configuration for Bedrock Stability AI image edit.
|
||||
|
||||
Supports all Stability image edit operations through Bedrock.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def _is_stability_edit_model(cls, model: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Returns True if the model is a Bedrock Stability edit model.
|
||||
|
||||
Bedrock Stability edit models follow this pattern:
|
||||
stability.stable-conservative-upscale-v1:0
|
||||
stability.stable-creative-upscale-v1:0
|
||||
stability.stable-fast-upscale-v1:0
|
||||
stability.stable-outpaint-v1:0
|
||||
stability.stable-image-inpaint-v1:0
|
||||
stability.stable-image-erase-object-v1:0
|
||||
etc.
|
||||
"""
|
||||
if model:
|
||||
model_lower = model.lower()
|
||||
if "stability." in model_lower and any(
|
||||
[
|
||||
"upscale" in model_lower,
|
||||
"outpaint" in model_lower,
|
||||
"inpaint" in model_lower,
|
||||
"erase" in model_lower,
|
||||
"remove-background" in model_lower,
|
||||
"search-recolor" in model_lower,
|
||||
"search-replace" in model_lower,
|
||||
"control-sketch" in model_lower,
|
||||
"control-structure" in model_lower,
|
||||
"style-guide" in model_lower,
|
||||
"style-transfer" in model_lower,
|
||||
]
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
"""
|
||||
Return list of OpenAI params supported by Bedrock Stability.
|
||||
"""
|
||||
return [
|
||||
"n", # Number of images (Stability always returns 1, we can loop)
|
||||
"size", # Maps to aspect_ratio
|
||||
"response_format", # b64_json or url (Stability only returns b64)
|
||||
"mask",
|
||||
]
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
image_edit_optional_params: ImageEditOptionalRequestParams,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> Dict:
|
||||
"""
|
||||
Map OpenAI parameters to Bedrock Stability parameters.
|
||||
|
||||
OpenAI -> Stability mappings:
|
||||
- size -> aspect_ratio
|
||||
- n -> (handled separately, Stability returns 1 image per request)
|
||||
"""
|
||||
supported_params = self.get_supported_openai_params(model)
|
||||
# Define mapping from OpenAI params to Stability params
|
||||
param_mapping = {
|
||||
"size": "aspect_ratio",
|
||||
# "n" and "response_format" are handled separately
|
||||
}
|
||||
|
||||
# Create a copy to not mutate original - convert TypedDict to regular dict
|
||||
mapped_params: Dict[str, Any] = dict(image_edit_optional_params)
|
||||
|
||||
for k, v in image_edit_optional_params.items():
|
||||
if k in param_mapping:
|
||||
# Map param if mapping exists and value is valid
|
||||
if k == "size" and v in OPENAI_SIZE_TO_STABILITY_ASPECT_RATIO:
|
||||
mapped_params[param_mapping[k]] = OPENAI_SIZE_TO_STABILITY_ASPECT_RATIO[v] # type: ignore
|
||||
# Don't copy "size" itself to final dict
|
||||
elif k == "n":
|
||||
# Store for logic but do not add to outgoing params
|
||||
mapped_params["_n"] = v
|
||||
elif k == "response_format":
|
||||
# Only b64 supported at Stability; store for postprocessing
|
||||
mapped_params["_response_format"] = v
|
||||
elif k not in supported_params:
|
||||
if not drop_params:
|
||||
raise ValueError(
|
||||
f"Parameter {k} is not supported for model {model}. "
|
||||
f"Supported parameters are {supported_params}. "
|
||||
f"Set drop_params=True to drop unsupported parameters."
|
||||
)
|
||||
# Otherwise, param will simply be dropped
|
||||
else:
|
||||
# param is supported and not mapped, keep as-is
|
||||
continue
|
||||
|
||||
# Remove OpenAI params that have been mapped unless they're in stability
|
||||
for mapped in ["size", "n", "response_format"]:
|
||||
if mapped in mapped_params:
|
||||
del mapped_params[mapped]
|
||||
|
||||
return mapped_params
|
||||
|
||||
def transform_image_edit_request( # noqa: PLR0915
|
||||
self,
|
||||
model: str,
|
||||
prompt: Optional[str],
|
||||
image: Optional[FileTypes],
|
||||
image_edit_optional_request_params: Dict,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> Tuple[Dict, Any]:
|
||||
"""
|
||||
Transform OpenAI-style request to Bedrock Stability request format.
|
||||
|
||||
Returns the request body dict that will be JSON-encoded by the handler.
|
||||
"""
|
||||
# Build Bedrock Stability request
|
||||
data: Dict[str, Any] = {
|
||||
"output_format": "png", # Default to PNG
|
||||
}
|
||||
|
||||
# Add prompt only if provided (some models don't require it)
|
||||
if prompt is not None and prompt != "":
|
||||
data["prompt"] = prompt
|
||||
|
||||
# Convert image to base64 if provided
|
||||
if image is not None:
|
||||
image_b64: str
|
||||
if hasattr(image, "read") and callable(getattr(image, "read", None)):
|
||||
# File-like object (e.g., BufferedReader from open())
|
||||
image_bytes = image.read() # type: ignore
|
||||
image_b64 = base64.b64encode(image_bytes).decode("utf-8") # type: ignore
|
||||
elif isinstance(image, bytes):
|
||||
# Raw bytes
|
||||
image_b64 = base64.b64encode(image).decode("utf-8")
|
||||
elif isinstance(image, str):
|
||||
# Already a base64 string
|
||||
image_b64 = image
|
||||
else:
|
||||
# Try to handle as bytes
|
||||
image_b64 = base64.b64encode(bytes(image)).decode("utf-8") # type: ignore
|
||||
|
||||
# For style-transfer models, map image to init_image
|
||||
model_lower = model.lower()
|
||||
if "style-transfer" in model_lower:
|
||||
data["init_image"] = image_b64
|
||||
else:
|
||||
data["image"] = image_b64
|
||||
|
||||
# Add optional params (already mapped in map_openai_params)
|
||||
for key, value in image_edit_optional_request_params.items(): # type: ignore
|
||||
# Skip internal params (prefixed with _)
|
||||
if key.startswith("_") or value is None:
|
||||
continue
|
||||
|
||||
# File-like optional params (mask, init_image, style_image, etc.)
|
||||
if key in ["mask", "init_image", "style_image"]:
|
||||
# Handle case where value might be in a list
|
||||
file_value = value
|
||||
if isinstance(value, list) and len(value) > 0:
|
||||
file_value = value[0]
|
||||
|
||||
if hasattr(file_value, "read") and callable(
|
||||
getattr(file_value, "read", None)
|
||||
):
|
||||
file_bytes = file_value.read() # type: ignore
|
||||
elif isinstance(file_value, bytes):
|
||||
file_bytes = file_value
|
||||
elif isinstance(file_value, str):
|
||||
# Already a base64 string
|
||||
data[key] = file_value
|
||||
continue
|
||||
else:
|
||||
file_bytes = file_value # type: ignore
|
||||
|
||||
if isinstance(file_bytes, bytes):
|
||||
file_b64 = base64.b64encode(file_bytes).decode("utf-8")
|
||||
else:
|
||||
file_b64 = str(file_bytes)
|
||||
data[key] = file_b64
|
||||
continue
|
||||
|
||||
# Numeric fields that need to be converted to int/float
|
||||
numeric_int_fields = ["left", "right", "up", "down", "seed"]
|
||||
numeric_float_fields = [
|
||||
"strength",
|
||||
"creativity",
|
||||
"control_strength",
|
||||
"grow_mask",
|
||||
"fidelity",
|
||||
"composition_fidelity",
|
||||
"style_strength",
|
||||
"change_strength",
|
||||
]
|
||||
|
||||
if key in numeric_int_fields:
|
||||
# Convert to int (these are pixel values for outpaint)
|
||||
try:
|
||||
data[key] = int(value) # type: ignore
|
||||
except (ValueError, TypeError):
|
||||
data[key] = value # type: ignore
|
||||
elif key in numeric_float_fields:
|
||||
# Convert to float
|
||||
try:
|
||||
data[key] = float(value) # type: ignore
|
||||
except (ValueError, TypeError):
|
||||
data[key] = value # type: ignore
|
||||
|
||||
# Supported text fields
|
||||
elif key in [
|
||||
"negative_prompt",
|
||||
"aspect_ratio",
|
||||
"output_format",
|
||||
"model",
|
||||
"mode",
|
||||
"style_preset",
|
||||
"select_prompt",
|
||||
"search_prompt",
|
||||
]:
|
||||
data[key] = value # type: ignore
|
||||
|
||||
return data, {}
|
||||
|
||||
def transform_image_edit_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform Bedrock Stability response to OpenAI-compatible ImageResponse.
|
||||
|
||||
Bedrock returns: {"images": ["base64..."], "finish_reasons": [null], "seeds": [123]}
|
||||
OpenAI expects: {"data": [{"b64_json": "base64..."}], "created": timestamp}
|
||||
"""
|
||||
try:
|
||||
response_data = raw_response.json()
|
||||
except Exception as e:
|
||||
raise self.get_error_class(
|
||||
error_message=f"Error parsing Bedrock Stability response: {e}",
|
||||
status_code=raw_response.status_code,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
# Check for errors in response
|
||||
if "errors" in response_data:
|
||||
raise self.get_error_class(
|
||||
error_message=f"Bedrock Stability error: {response_data['errors']}",
|
||||
status_code=raw_response.status_code,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
# Check finish_reasons
|
||||
finish_reasons = response_data.get("finish_reasons", [])
|
||||
if finish_reasons and finish_reasons[0]:
|
||||
raise self.get_error_class(
|
||||
error_message=f"Bedrock Stability error: {finish_reasons[0]}",
|
||||
status_code=400,
|
||||
headers=raw_response.headers,
|
||||
)
|
||||
|
||||
model_response = ImageResponse()
|
||||
if not model_response.data:
|
||||
model_response.data = []
|
||||
|
||||
# Extract images from response
|
||||
images = response_data.get("images", [])
|
||||
if images:
|
||||
for image_b64 in images:
|
||||
if image_b64:
|
||||
model_response.data.append(
|
||||
ImageObject(
|
||||
b64_json=image_b64,
|
||||
url=None,
|
||||
revised_prompt=None,
|
||||
)
|
||||
)
|
||||
|
||||
if not hasattr(model_response, "_hidden_params"):
|
||||
model_response._hidden_params = {}
|
||||
if "additional_headers" not in model_response._hidden_params:
|
||||
model_response._hidden_params["additional_headers"] = {}
|
||||
|
||||
# Set cost based on model
|
||||
model_info = get_model_info(model, custom_llm_provider="bedrock")
|
||||
cost_per_image = model_info.get("output_cost_per_image", 0)
|
||||
if cost_per_image is not None:
|
||||
model_response._hidden_params["additional_headers"][
|
||||
"llm_provider-x-litellm-response-cost"
|
||||
] = float(cost_per_image)
|
||||
|
||||
return model_response
|
||||
|
||||
def use_multipart_form_data(self) -> bool:
    """Bedrock Stability sends a JSON body, so multipart/form-data is never used."""
    return False
|
||||
|
||||
def get_complete_url(
    self,
    model: str,
    api_base: Optional[str],
    litellm_params: dict,
) -> str:
    """
    Return a placeholder URL for the Bedrock Image Edit API.

    Bedrock endpoint URLs are built by the handler (via boto3) from the
    model ID and AWS region, so this base-class hook only exists to satisfy
    the abstract interface; the real URL construction happens in
    BedrockImageEdit.image_edit().

    Returns:
        A placeholder string — the handler ignores this value.
    """
    # Bedrock URLs are constructed in the handler using boto3; this is only
    # here because the base class requires the method.
    placeholder_url = "bedrock://image-edit"
    return placeholder_url
|
||||
|
||||
def validate_environment(
    self,
    headers: dict,
    model: str,
    api_key: Optional[str] = None,
) -> dict:
    """
    Ensure request headers are usable for a Bedrock Stability image edit.

    Bedrock authenticates with AWS credentials (managed by BaseAWSLLM and
    applied via the handler's get_request_headers()), not an API key, so
    this method only guarantees that a header dict exists and carries a
    JSON Content-Type.

    Args:
        headers: Request headers to validate/update (may be None).
        model: The model name being used (not consulted here).
        api_key: Ignored — Bedrock relies on AWS credentials instead.

    Returns:
        The (possibly newly created) headers dict.
    """
    validated_headers = headers if headers is not None else {}
    # SigV4 signing happens later in the handler; only the JSON content
    # type needs to be present at this point.
    validated_headers.setdefault("Content-Type", "application/json")
    return validated_headers
|
||||
@@ -0,0 +1,220 @@
|
||||
import types
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from openai.types.image import Image
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
AmazonNovaCanvasColorGuidedGenerationParams,
|
||||
AmazonNovaCanvasColorGuidedRequest,
|
||||
AmazonNovaCanvasImageGenerationConfig,
|
||||
AmazonNovaCanvasInpaintingParams,
|
||||
AmazonNovaCanvasInpaintingRequest,
|
||||
AmazonNovaCanvasRequestBase,
|
||||
AmazonNovaCanvasTextToImageParams,
|
||||
AmazonNovaCanvasTextToImageRequest,
|
||||
AmazonNovaCanvasTextToImageResponse,
|
||||
)
|
||||
from litellm.llms.bedrock.common_utils import get_cached_model_info
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
class AmazonNovaCanvasConfig:
|
||||
"""
|
||||
Reference: https://us-east-1.console.aws.amazon.com/bedrock/home?region=us-east-1#/model-catalog/serverless/amazon.nova-canvas-v1:0
|
||||
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
|
||||
""" """
|
||||
return ["n", "size", "quality"]
|
||||
|
||||
@classmethod
|
||||
def _is_nova_model(cls, model: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Returns True if the model is a Nova Canvas model
|
||||
|
||||
Nova models follow this pattern:
|
||||
|
||||
"""
|
||||
if model and "amazon.nova-canvas" in model:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def transform_request_body(
|
||||
cls, text: str, optional_params: dict
|
||||
) -> AmazonNovaCanvasRequestBase:
|
||||
"""
|
||||
Transform the request body for Amazon Nova Canvas model
|
||||
"""
|
||||
task_type = optional_params.pop("taskType", "TEXT_IMAGE")
|
||||
image_generation_config = optional_params.pop("imageGenerationConfig", {})
|
||||
|
||||
# Extract model_id parameter to prevent "extraneous key" error from Bedrock API
|
||||
# Following the same pattern as chat completions and embeddings
|
||||
unencoded_model_id = optional_params.pop("model_id", None) # noqa: F841
|
||||
|
||||
image_generation_config = {**image_generation_config, **optional_params}
|
||||
if task_type == "TEXT_IMAGE":
|
||||
text_to_image_params: Dict[str, Any] = image_generation_config.pop(
|
||||
"textToImageParams", {}
|
||||
)
|
||||
text_to_image_params = {"text": text, **text_to_image_params}
|
||||
try:
|
||||
text_to_image_params_typed = AmazonNovaCanvasTextToImageParams(
|
||||
**text_to_image_params # type: ignore
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Error transforming text to image params: {e}. Got params: {text_to_image_params}, Expected params: {AmazonNovaCanvasTextToImageParams.__annotations__}"
|
||||
)
|
||||
|
||||
try:
|
||||
image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
|
||||
**image_generation_config
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
|
||||
)
|
||||
|
||||
return AmazonNovaCanvasTextToImageRequest(
|
||||
textToImageParams=text_to_image_params_typed,
|
||||
taskType=task_type,
|
||||
imageGenerationConfig=image_generation_config_typed,
|
||||
)
|
||||
if task_type == "COLOR_GUIDED_GENERATION":
|
||||
color_guided_generation_params: Dict[
|
||||
str, Any
|
||||
] = image_generation_config.pop("colorGuidedGenerationParams", {})
|
||||
color_guided_generation_params = {
|
||||
"text": text,
|
||||
**color_guided_generation_params,
|
||||
}
|
||||
try:
|
||||
color_guided_generation_params_typed = AmazonNovaCanvasColorGuidedGenerationParams(
|
||||
**color_guided_generation_params # type: ignore
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Error transforming color guided generation params: {e}. Got params: {color_guided_generation_params}, Expected params: {AmazonNovaCanvasColorGuidedGenerationParams.__annotations__}"
|
||||
)
|
||||
|
||||
try:
|
||||
image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
|
||||
**image_generation_config
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
|
||||
)
|
||||
|
||||
return AmazonNovaCanvasColorGuidedRequest(
|
||||
taskType=task_type,
|
||||
colorGuidedGenerationParams=color_guided_generation_params_typed,
|
||||
imageGenerationConfig=image_generation_config_typed,
|
||||
)
|
||||
if task_type == "INPAINTING":
|
||||
inpainting_params: Dict[str, Any] = image_generation_config.pop(
|
||||
"inpaintingParams", {}
|
||||
)
|
||||
inpainting_params = {"text": text, **inpainting_params}
|
||||
try:
|
||||
inpainting_params_typed = AmazonNovaCanvasInpaintingParams(
|
||||
**inpainting_params # type: ignore
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Error transforming inpainting params: {e}. Got params: {inpainting_params}, Expected params: {AmazonNovaCanvasInpaintingParams.__annotations__}"
|
||||
)
|
||||
|
||||
try:
|
||||
image_generation_config_typed = AmazonNovaCanvasImageGenerationConfig(
|
||||
**image_generation_config
|
||||
)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"Error transforming image generation config: {e}. Got params: {image_generation_config}, Expected params: {AmazonNovaCanvasImageGenerationConfig.__annotations__}"
|
||||
)
|
||||
|
||||
return AmazonNovaCanvasInpaintingRequest(
|
||||
taskType=task_type,
|
||||
inpaintingParams=inpainting_params_typed,
|
||||
imageGenerationConfig=image_generation_config_typed,
|
||||
)
|
||||
raise NotImplementedError(f"Task type {task_type} is not supported")
|
||||
|
||||
@classmethod
|
||||
def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> dict:
|
||||
"""
|
||||
Map the OpenAI params to the Bedrock params
|
||||
"""
|
||||
_size = non_default_params.get("size")
|
||||
if _size is not None:
|
||||
width, height = _size.split("x")
|
||||
optional_params["width"], optional_params["height"] = int(width), int(
|
||||
height
|
||||
)
|
||||
if non_default_params.get("n") is not None:
|
||||
optional_params["numberOfImages"] = non_default_params.get("n")
|
||||
if non_default_params.get("quality") is not None:
|
||||
if non_default_params.get("quality") in ("hd", "premium"):
|
||||
optional_params["quality"] = "premium"
|
||||
if non_default_params.get("quality") == "standard":
|
||||
optional_params["quality"] = "standard"
|
||||
return optional_params
|
||||
|
||||
@classmethod
|
||||
def transform_response_dict_to_openai_response(
|
||||
cls, model_response: ImageResponse, response_dict: dict
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform the response dict to the OpenAI response
|
||||
"""
|
||||
|
||||
nova_response = AmazonNovaCanvasTextToImageResponse(**response_dict)
|
||||
openai_images: List[Image] = []
|
||||
for _img in nova_response.get("images", []):
|
||||
openai_images.append(Image(b64_json=_img))
|
||||
|
||||
model_response.data = openai_images
|
||||
return model_response
|
||||
|
||||
@classmethod
|
||||
def cost_calculator(
|
||||
cls,
|
||||
model: str,
|
||||
image_response: ImageResponse,
|
||||
size: Optional[str] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
) -> float:
|
||||
get_model_info = get_cached_model_info()
|
||||
model_info = get_model_info(
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
)
|
||||
|
||||
output_cost_per_image: float = model_info.get("output_cost_per_image") or 0.0
|
||||
num_images: int = 0
|
||||
if image_response.data:
|
||||
num_images = len(image_response.data)
|
||||
return output_cost_per_image * num_images
|
||||
@@ -0,0 +1,164 @@
|
||||
import copy
|
||||
import os
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.image import Image
|
||||
|
||||
from litellm.llms.bedrock.common_utils import get_cached_model_info
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
class AmazonStabilityConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||
|
||||
Supported Params for the Amazon / Stable Diffusion models:
|
||||
|
||||
- `cfg_scale` (integer): Default `7`. Between [ 0 .. 35 ]. How strictly the diffusion process adheres to the prompt text (higher values keep your image closer to your prompt)
|
||||
|
||||
- `seed` (float): Default: `0`. Between [ 0 .. 4294967295 ]. Random noise seed (omit this option or use 0 for a random seed)
|
||||
|
||||
- `steps` (array of strings): Default `30`. Between [ 10 .. 50 ]. Number of diffusion steps to run.
|
||||
|
||||
- `width` (integer): Default: `512`. multiple of 64 >= 128. Width of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
|
||||
- `height` (integer): Default: `512`. multiple of 64 >= 128. Height of the image to generate, in pixels, in an increment divible by 64.
|
||||
Engine-specific dimension validation:
|
||||
|
||||
- SDXL Beta: must be between 128x128 and 512x896 (or 896x512); only one dimension can be greater than 512.
|
||||
- SDXL v0.9: must be one of 1024x1024, 1152x896, 1216x832, 1344x768, 1536x640, 640x1536, 768x1344, 832x1216, or 896x1152
|
||||
- SDXL v1.0: same as SDXL v0.9
|
||||
- SD v1.6: must be between 320x320 and 1536x1536
|
||||
"""
|
||||
|
||||
cfg_scale: Optional[int] = None
|
||||
seed: Optional[float] = None
|
||||
steps: Optional[List[str]] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg_scale: Optional[int] = None,
|
||||
seed: Optional[float] = None,
|
||||
steps: Optional[List[str]] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
|
||||
return ["size"]
|
||||
|
||||
@classmethod
|
||||
def map_openai_params(
|
||||
cls,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
):
|
||||
_size = non_default_params.get("size")
|
||||
if _size is not None:
|
||||
width, height = _size.split("x")
|
||||
optional_params["width"] = int(width)
|
||||
optional_params["height"] = int(height)
|
||||
|
||||
return optional_params
|
||||
|
||||
@classmethod
|
||||
def transform_request_body(
|
||||
cls,
|
||||
text: str,
|
||||
optional_params: dict,
|
||||
) -> dict:
|
||||
inference_params = copy.deepcopy(optional_params)
|
||||
inference_params.pop(
|
||||
"user", None
|
||||
) # make sure user is not passed in for bedrock call
|
||||
|
||||
prompt = text.replace(os.linesep, " ")
|
||||
## LOAD CONFIG
|
||||
config = cls.get_config()
|
||||
for k, v in config.items():
|
||||
if (
|
||||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
|
||||
return {
|
||||
"text_prompts": [{"text": prompt, "weight": 1}],
|
||||
**inference_params,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def transform_response_dict_to_openai_response(
|
||||
cls, model_response: ImageResponse, response_dict: dict
|
||||
) -> ImageResponse:
|
||||
image_list: List[Image] = []
|
||||
for artifact in response_dict["artifacts"]:
|
||||
_image = Image(b64_json=artifact["base64"])
|
||||
image_list.append(_image)
|
||||
|
||||
model_response.data = image_list
|
||||
|
||||
return model_response
|
||||
|
||||
@classmethod
|
||||
def cost_calculator(
|
||||
cls,
|
||||
model: str,
|
||||
image_response: ImageResponse,
|
||||
size: Optional[str] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
) -> float:
|
||||
optional_params = optional_params or {}
|
||||
|
||||
# see model_prices_and_context_window.json for details on how steps is used
|
||||
# Reference pricing by steps for stability 1: https://aws.amazon.com/bedrock/pricing/
|
||||
_steps = optional_params.get("steps", 50)
|
||||
steps = "max-steps" if _steps > 50 else "50-steps"
|
||||
|
||||
# size is stored in model_prices_and_context_window.json as 1024-x-1024
|
||||
# current size has 1024x1024
|
||||
size = size or "1024-x-1024"
|
||||
model = f"{size}/{steps}/{model}"
|
||||
|
||||
get_model_info = get_cached_model_info()
|
||||
model_info = get_model_info(
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
)
|
||||
|
||||
output_cost_per_image: float = model_info.get("output_cost_per_image") or 0.0
|
||||
num_images: int = 0
|
||||
if image_response.data:
|
||||
num_images = len(image_response.data)
|
||||
return output_cost_per_image * num_images
|
||||
@@ -0,0 +1,128 @@
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.image import Image
|
||||
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.types.llms.bedrock import (
|
||||
AmazonStability3TextToImageRequest,
|
||||
AmazonStability3TextToImageResponse,
|
||||
)
|
||||
from litellm.llms.bedrock.common_utils import get_cached_model_info
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
class AmazonStability3Config:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||
|
||||
Stability API Ref: https://platform.stability.ai/docs/api-reference#tag/Generate/paths/~1v2beta~1stable-image~1generate~1sd3/post
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
|
||||
"""
|
||||
No additional OpenAI params are mapped for stability 3
|
||||
"""
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _is_stability_3_model(cls, model: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Returns True if the model is a Stability 3 model
|
||||
|
||||
Stability 3 models follow this pattern:
|
||||
sd3-large
|
||||
sd3-large-turbo
|
||||
sd3-medium
|
||||
sd3.5-large
|
||||
sd3.5-large-turbo
|
||||
|
||||
Stability ultra models
|
||||
stable-image-ultra-v1
|
||||
"""
|
||||
if model:
|
||||
if "sd3" in model or "sd3.5" in model:
|
||||
return True
|
||||
if "stable-image" in model:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def transform_request_body(
|
||||
cls, text: str, optional_params: dict
|
||||
) -> AmazonStability3TextToImageRequest:
|
||||
"""
|
||||
Transform the request body for the Stability 3 models
|
||||
"""
|
||||
data = AmazonStability3TextToImageRequest(prompt=text, **optional_params)
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def map_openai_params(cls, non_default_params: dict, optional_params: dict) -> dict:
|
||||
"""
|
||||
Map the OpenAI params to the Bedrock params
|
||||
|
||||
No OpenAI params are mapped for Stability 3, so directly return the optional_params
|
||||
"""
|
||||
return optional_params
|
||||
|
||||
@classmethod
|
||||
def transform_response_dict_to_openai_response(
|
||||
cls, model_response: ImageResponse, response_dict: dict
|
||||
) -> ImageResponse:
|
||||
"""
|
||||
Transform the response dict to the OpenAI response
|
||||
"""
|
||||
|
||||
stability_3_response = AmazonStability3TextToImageResponse(**response_dict)
|
||||
|
||||
finish_reasons = stability_3_response.get("finish_reasons", [])
|
||||
finish_reasons = [reason for reason in finish_reasons if reason]
|
||||
if len(finish_reasons) > 0:
|
||||
raise BedrockError(status_code=400, message="; ".join(finish_reasons))
|
||||
|
||||
openai_images: List[Image] = []
|
||||
for _img in stability_3_response.get("images", []):
|
||||
openai_images.append(Image(b64_json=_img))
|
||||
|
||||
model_response.data = openai_images
|
||||
return model_response
|
||||
|
||||
@classmethod
|
||||
def cost_calculator(
|
||||
cls,
|
||||
model: str,
|
||||
image_response: ImageResponse,
|
||||
size: Optional[str] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
) -> float:
|
||||
get_model_info = get_cached_model_info()
|
||||
model_info = get_model_info(
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
)
|
||||
|
||||
output_cost_per_image: float = model_info.get("output_cost_per_image") or 0.0
|
||||
num_images: int = 0
|
||||
if image_response.data:
|
||||
num_images = len(image_response.data)
|
||||
return output_cost_per_image * num_images
|
||||
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
Transformation logic for Amazon Titan Image Generation.
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional
|
||||
|
||||
from openai.types.image import Image
|
||||
|
||||
from litellm.utils import get_model_info
|
||||
from litellm.types.llms.bedrock import (
|
||||
AmazonNovaCanvasImageGenerationConfig,
|
||||
AmazonTitanImageGenerationRequestBody,
|
||||
AmazonTitanTextToImageParams,
|
||||
)
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
class AmazonTitanImageGenerationConfig:
|
||||
"""
|
||||
Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=stability.stable-diffusion-xl-v0
|
||||
"""
|
||||
|
||||
cfg_scale: Optional[int] = None
|
||||
seed: Optional[float] = None
|
||||
steps: Optional[List[str]] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
cfg_scale: Optional[int] = None,
|
||||
seed: Optional[float] = None,
|
||||
steps: Optional[List[str]] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _is_titan_model(cls, model: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Returns True if the model is a Titan model
|
||||
|
||||
Titan models follow this pattern:
|
||||
|
||||
"""
|
||||
if model and "amazon.titan" in model:
|
||||
return True
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_supported_openai_params(cls, model: Optional[str] = None) -> List:
|
||||
return ["size", "n", "quality"]
|
||||
|
||||
@classmethod
|
||||
def map_openai_params(
|
||||
cls,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
):
|
||||
from typing import Any, Dict
|
||||
|
||||
image_generation_config: Dict[str, Any] = {}
|
||||
for k, v in non_default_params.items():
|
||||
if k == "size" and v is not None:
|
||||
width, height = v.split("x")
|
||||
image_generation_config["width"] = int(width)
|
||||
image_generation_config["height"] = int(height)
|
||||
elif k == "n" and v is not None:
|
||||
image_generation_config["numberOfImages"] = v
|
||||
elif (
|
||||
k == "quality" and v is not None
|
||||
): # 'auto', 'hd', 'standard', 'high', 'medium', 'low'
|
||||
if v in ("hd", "premium", "high"):
|
||||
image_generation_config["quality"] = "premium"
|
||||
elif v in ("standard", "medium", "low"):
|
||||
image_generation_config["quality"] = "standard"
|
||||
|
||||
if image_generation_config:
|
||||
optional_params["imageGenerationConfig"] = image_generation_config
|
||||
return optional_params
|
||||
|
||||
@classmethod
|
||||
def transform_request_body(
|
||||
cls,
|
||||
text: str,
|
||||
optional_params: dict,
|
||||
) -> AmazonTitanImageGenerationRequestBody:
|
||||
from typing import Any, Dict
|
||||
|
||||
image_generation_config = optional_params.pop("imageGenerationConfig", {})
|
||||
negative_text = optional_params.pop("negativeText", None)
|
||||
text_to_image_params: Dict[str, Any] = {"text": text}
|
||||
if negative_text:
|
||||
text_to_image_params["negativeText"] = negative_text
|
||||
task_type = optional_params.pop("taskType", "TEXT_IMAGE")
|
||||
user_specified_image_generation_config = optional_params.pop(
|
||||
"imageGenerationConfig", {}
|
||||
)
|
||||
image_generation_config = {
|
||||
**image_generation_config,
|
||||
**user_specified_image_generation_config,
|
||||
}
|
||||
return AmazonTitanImageGenerationRequestBody(
|
||||
taskType=task_type,
|
||||
textToImageParams=AmazonTitanTextToImageParams(**text_to_image_params), # type: ignore
|
||||
imageGenerationConfig=AmazonNovaCanvasImageGenerationConfig(
|
||||
**image_generation_config
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def transform_response_dict_to_openai_response(
|
||||
cls, model_response: ImageResponse, response_dict: dict
|
||||
) -> ImageResponse:
|
||||
image_list: List[Image] = []
|
||||
for image in response_dict["images"]:
|
||||
_image = Image(b64_json=image)
|
||||
image_list.append(_image)
|
||||
|
||||
model_response.data = image_list
|
||||
|
||||
return model_response
|
||||
|
||||
@classmethod
|
||||
def cost_calculator(
|
||||
cls,
|
||||
model: str,
|
||||
image_response: ImageResponse,
|
||||
size: Optional[str] = None,
|
||||
optional_params: Optional[dict] = None,
|
||||
) -> float:
|
||||
model_info = get_model_info(model=model)
|
||||
output_cost_per_image = model_info.get("output_cost_per_image") or 0.0
|
||||
if not image_response.data:
|
||||
return 0.0
|
||||
num_images = len(image_response.data)
|
||||
return output_cost_per_image * num_images
|
||||
@@ -0,0 +1,24 @@
|
||||
from typing import Optional
|
||||
|
||||
from litellm.llms.bedrock.image_generation.image_handler import BedrockImageGeneration
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
|
||||
def cost_calculator(
    model: str,
    image_response: ImageResponse,
    size: Optional[str] = None,
    optional_params: Optional[dict] = None,
) -> float:
    """
    Bedrock image generation cost calculator.

    Handles both Stability 1 and Stability 3 models by dispatching to the
    model-specific config class and delegating the pricing logic to it.
    """
    matched_config = BedrockImageGeneration.get_config_class(model=model)
    return matched_config.cost_calculator(
        model=model,
        image_response=image_response,
        size=size,
        optional_params=optional_params,
    )
|
||||
@@ -0,0 +1,333 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||
from litellm.llms.bedrock.image_generation.amazon_nova_canvas_transformation import (
|
||||
AmazonNovaCanvasConfig,
|
||||
)
|
||||
from litellm.llms.bedrock.image_generation.amazon_stability1_transformation import (
|
||||
AmazonStabilityConfig,
|
||||
)
|
||||
from litellm.llms.bedrock.image_generation.amazon_stability3_transformation import (
|
||||
AmazonStability3Config,
|
||||
)
|
||||
from litellm.llms.bedrock.image_generation.amazon_titan_transformation import (
|
||||
AmazonTitanImageGenerationConfig,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import BedrockError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botocore.awsrequest import AWSPreparedRequest
|
||||
else:
|
||||
AWSPreparedRequest = Any
|
||||
|
||||
|
||||
class BedrockImagePreparedRequest(BaseModel):
    """
    Internal/Helper class for preparing the request for bedrock image generation
    """

    # Fully resolved Bedrock invoke URL the handler will POST to.
    endpoint_url: str
    # Signed request object (botocore AWSPreparedRequest at runtime; typed as
    # Any when botocore is unavailable — see the TYPE_CHECKING guard above).
    # Its .headers are reused when posting via httpx.
    prepped: AWSPreparedRequest
    # Serialized request body bytes sent as the POST payload.
    body: bytes
    # Same request payload as a dict — passed to the response transform for
    # logging / post-processing.
    data: dict
|
||||
|
||||
|
||||
# Union of the per-model transformation config *classes* (class objects, not
# instances) that BedrockImageGeneration.get_config_class can return.
BedrockImageConfigClass = Union[
    type[AmazonTitanImageGenerationConfig],
    type[AmazonNovaCanvasConfig],
    type[AmazonStability3Config],
    type[AmazonStabilityConfig],
]
|
||||
|
||||
|
||||
class BedrockImageGeneration(BaseAWSLLM):
|
||||
"""
|
||||
Bedrock Image Generation handler
|
||||
"""
|
||||
|
||||
@classmethod
def get_config_class(cls, model: str | None) -> BedrockImageConfigClass:
    """
    Pick the transformation config class for the given Bedrock model id.

    Checks Titan, Nova Canvas and Stability 3 in turn; anything else falls
    back to the classic Stability (SDXL) config.
    """
    if AmazonTitanImageGenerationConfig._is_titan_model(model):
        return AmazonTitanImageGenerationConfig
    if AmazonNovaCanvasConfig._is_nova_model(model):
        return AmazonNovaCanvasConfig
    if AmazonStability3Config._is_stability_3_model(model):
        return AmazonStability3Config
    # Default family: resolved via the litellm namespace, as before.
    return litellm.AmazonStabilityConfig
|
||||
|
||||
def image_generation(
    self,
    model: str,
    prompt: str,
    model_response: ImageResponse,
    optional_params: dict,
    logging_obj: LitellmLogging,
    timeout: Optional[Union[float, httpx.Timeout]],
    aimg_generation: bool = False,
    api_base: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    api_key: Optional[str] = None,
):
    """
    Entry point for Bedrock image generation.

    Prepares and signs the request, then either returns the coroutine from
    async_image_generation (when aimg_generation is True) or POSTs
    synchronously and transforms the Bedrock response into an OpenAI-format
    ImageResponse.

    Args:
        model: Bedrock model id.
        prompt: Text prompt for the image.
        model_response: ImageResponse object to populate.
        optional_params: Provider params (consumed by _prepare_request).
        logging_obj: litellm logging object.
        timeout: Request timeout — only applied when building the async
            client inside async_image_generation.
        aimg_generation: When True, delegate to the async path and return
            its coroutine instead of an ImageResponse.
        api_base: Optional endpoint override passed to _prepare_request.
        extra_headers: Optional extra request headers.
        client: Optional pre-built HTTP client; used only if its type
            matches the chosen sync/async path, otherwise a fresh client
            is created.
        api_key: Optional key forwarded to _prepare_request.

    Raises:
        BedrockError: on an HTTP error status or a timeout.
    """
    # Build the signed request (endpoint URL, SigV4 headers, body).
    prepared_request = self._prepare_request(
        model=model,
        optional_params=optional_params,
        api_base=api_base,
        extra_headers=extra_headers,
        logging_obj=logging_obj,
        prompt=prompt,
        api_key=api_key,
    )

    if aimg_generation is True:
        # Async path: hand off the already-prepared request; a sync-only
        # client can't be reused here, so pass None in that case.
        return self.async_image_generation(
            prepared_request=prepared_request,
            timeout=timeout,
            model=model,
            logging_obj=logging_obj,
            prompt=prompt,
            model_response=model_response,
            client=(
                client
                if client is not None and isinstance(client, AsyncHTTPHandler)
                else None
            ),
        )

    # Sync path: make sure we hold a synchronous HTTP handler.
    if client is None or not isinstance(client, HTTPHandler):
        client = _get_httpx_client()
    try:
        # Reuse the SigV4-signed headers from the prepared request.
        response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body)  # type: ignore
        response.raise_for_status()
    except httpx.HTTPStatusError as err:
        # Surface the provider's status code and body via BedrockError.
        error_code = err.response.status_code
        raise BedrockError(status_code=error_code, message=err.response.text)
    except httpx.TimeoutException:
        raise BedrockError(status_code=408, message="Timeout error occurred.")
    ### FORMAT RESPONSE TO OPENAI FORMAT ###
    model_response = self._transform_response_dict_to_openai_response(
        model_response=model_response,
        model=model,
        logging_obj=logging_obj,
        prompt=prompt,
        response=response,
        data=prepared_request.data,
    )
    return model_response
|
||||
|
||||
async def async_image_generation(
    self,
    prepared_request: BedrockImagePreparedRequest,
    timeout: Optional[Union[float, httpx.Timeout]],
    model: str,
    logging_obj: LitellmLogging,
    prompt: str,
    model_response: ImageResponse,
    client: Optional[AsyncHTTPHandler] = None,
) -> ImageResponse:
    """
    Asynchronous handler for bedrock image generation.

    Awaits the response from the bedrock image generation endpoint and
    transforms it into an OpenAI-format ImageResponse.

    Args:
        prepared_request: Already-signed request (endpoint, headers, body)
            built by _prepare_request.
        timeout: Applied when constructing a new async client (ignored if
            `client` is provided).
        model: Bedrock model id (used for response transformation).
        logging_obj: litellm logging object.
        prompt: Original text prompt (used for logging in the transform).
        model_response: ImageResponse object to populate.
        client: Optional pre-built async HTTP client to reuse.

    Raises:
        BedrockError: on an HTTP error status or a timeout.
    """
    # Reuse the caller's async client if given; otherwise build one with
    # the requested timeout.
    async_client = client or get_async_httpx_client(
        llm_provider=litellm.LlmProviders.BEDROCK,
        params={"timeout": timeout},
    )

    try:
        # Reuse the SigV4-signed headers from the prepared request.
        response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body)  # type: ignore
        response.raise_for_status()
    except httpx.HTTPStatusError as err:
        # Surface the provider's status code and body via BedrockError.
        error_code = err.response.status_code
        raise BedrockError(status_code=error_code, message=err.response.text)
    except httpx.TimeoutException:
        raise BedrockError(status_code=408, message="Timeout error occurred.")

    ### FORMAT RESPONSE TO OPENAI FORMAT ###
    model_response = self._transform_response_dict_to_openai_response(
        model=model,
        logging_obj=logging_obj,
        prompt=prompt,
        response=response,
        data=prepared_request.data,
        model_response=model_response,
    )
    return model_response
|
||||
|
||||
def _extract_headers_from_optional_params(self, optional_params: dict) -> dict:
|
||||
"""
|
||||
Extract guardrail parameters from optional_params and convert them to headers.
|
||||
"""
|
||||
headers = {}
|
||||
guardrail_identifier = optional_params.pop("guardrailIdentifier", None)
|
||||
guardrail_version = optional_params.pop("guardrailVersion", None)
|
||||
|
||||
if guardrail_identifier is not None:
|
||||
headers["x-amz-bedrock-guardrail-identifier"] = guardrail_identifier
|
||||
if guardrail_version is not None:
|
||||
headers["x-amz-bedrock-guardrail-version"] = guardrail_version
|
||||
|
||||
return headers
|
||||
|
||||
def _prepare_request(
    self,
    model: str,
    optional_params: dict,
    api_base: Optional[str],
    extra_headers: Optional[dict],
    logging_obj: LitellmLogging,
    prompt: str,
    api_key: Optional[str],
) -> BedrockImagePreparedRequest:
    """
    Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API

    Args:
        model (str): The model to use for the image generation
        optional_params (dict): The optional parameters for the image generation
        api_base (Optional[str]): The base URL for the Bedrock API
        extra_headers (Optional[dict]): The extra headers to include in the request
        logging_obj (LitellmLogging): The logging object to use for logging
        prompt (str): The prompt to use for the image generation
        api_key (Optional[str]): Optional API key forwarded to request signing

    Returns:
        BedrockImagePreparedRequest: The prepared request object

    The BedrockImagePreparedRequest contains:
        endpoint_url (str): The endpoint URL for the Bedrock Image Generation API
        prepped (httpx.Request): The prepared request object
        body (bytes): The request body
    """
    # Resolve AWS credentials + region from optional_params / environment.
    boto3_credentials_info = self._get_boto_credentials_from_optional_params(
        optional_params, model
    )

    # Use the existing ARN-aware provider detection method
    bedrock_provider = self.get_bedrock_invoke_provider(model)
    ### SET RUNTIME ENDPOINT ###
    modelId = self.get_bedrock_model_id(
        model=model,
        provider=bedrock_provider,
        optional_params=optional_params,
    )
    _, proxy_endpoint_url = self.get_runtime_endpoint(
        api_base=api_base,
        aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
        aws_region_name=boto3_credentials_info.aws_region_name,
    )
    # Image generation always goes through the synchronous invoke endpoint.
    proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"
    data = self._get_request_body(
        model=model,
        prompt=prompt,
        optional_params=optional_params,
    )

    # Make POST Request
    body = json.dumps(data).encode("utf-8")
    headers = {"Content-Type": "application/json"}
    if extra_headers is not None:
        headers = {"Content-Type": "application/json", **extra_headers}

    # Extract guardrail parameters and add them as headers
    # (popped from optional_params so they don't leak into the JSON body).
    guardrail_headers = self._extract_headers_from_optional_params(optional_params)
    headers.update(guardrail_headers)

    # Build the signed (AWS-authenticated) request for the resolved endpoint.
    prepped = self.get_request_headers(
        credentials=boto3_credentials_info.credentials,
        aws_region_name=boto3_credentials_info.aws_region_name,
        extra_headers=extra_headers,
        endpoint_url=proxy_endpoint_url,
        data=body,
        headers=headers,
        api_key=api_key,
    )

    ## LOGGING
    logging_obj.pre_call(
        input=prompt,
        api_key="",
        additional_args={
            "complete_input_dict": data,
            "api_base": proxy_endpoint_url,
            "headers": prepped.headers,
        },
    )
    return BedrockImagePreparedRequest(
        endpoint_url=proxy_endpoint_url,
        prepped=prepped,
        body=body,
        data=data,
    )
|
||||
|
||||
def _get_request_body(
    self,
    model: str,
    prompt: str,
    optional_params: dict,
) -> dict:
    """
    Build the request body for the Bedrock Image Generation API.

    Looks up the provider/model-specific config class and lets it translate
    the prompt and optional params into that provider's wire format.

    Returns:
        dict: The JSON-serializable request body.
    """
    provider_config = self.get_config_class(model=model)
    transformed_body = provider_config.transform_request_body(
        text=prompt, optional_params=optional_params
    )
    return dict(transformed_body)
|
||||
|
||||
def _transform_response_dict_to_openai_response(
    self,
    model_response: ImageResponse,
    model: str,
    logging_obj: LitellmLogging,
    prompt: str,
    response: httpx.Response,
    data: dict,
) -> ImageResponse:
    """
    Transforms the Image Generation response from Bedrock to OpenAI format

    Args:
        model_response: ImageResponse populated in place and returned.
        model: Bedrock model id — selects the provider-specific config class.
        logging_obj: LiteLLM logging object; its post-call hook is invoked.
        prompt: Original prompt (forwarded to logging only).
        response: Raw HTTP response from Bedrock.
        data: Request body that was sent (logging context only).

    Raises:
        ValueError: if the response body deserializes to None.
    """

    ## LOGGING
    if logging_obj is not None:
        logging_obj.post_call(
            input=prompt,
            api_key="",
            original_response=response.text,
            additional_args={"complete_input_dict": data},
        )
    verbose_logger.debug("raw model_response: %s", response.text)
    response_dict = response.json()
    if response_dict is None:
        raise ValueError("Error in response object format, got None")

    # Provider-specific parsing lives in the config class; it mutates
    # model_response in place rather than returning a new object.
    config_class = self.get_config_class(model=model)

    config_class.transform_response_dict_to_openai_response(
        model_response=model_response,
        response_dict=response_dict,
    )

    return model_response
|
||||
@@ -0,0 +1,550 @@
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.anthropic.common_utils import AnthropicModelInfo
|
||||
from litellm.llms.anthropic.experimental_pass_through.messages.transformation import (
|
||||
AnthropicMessagesConfig,
|
||||
)
|
||||
from litellm.llms.base_llm.anthropic_messages.transformation import (
|
||||
BaseAnthropicMessagesConfig,
|
||||
)
|
||||
from litellm.llms.bedrock.chat.invoke_handler import AWSEventStreamDecoder
|
||||
from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
|
||||
AmazonInvokeConfig,
|
||||
)
|
||||
from litellm.llms.bedrock.common_utils import (
|
||||
get_anthropic_beta_from_headers,
|
||||
is_claude_4_5_on_bedrock,
|
||||
remove_custom_field_from_tools,
|
||||
)
|
||||
from litellm.types.llms.anthropic import ANTHROPIC_TOOL_SEARCH_BETA_HEADER
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import GenericStreamingChunk
|
||||
from litellm.types.utils import GenericStreamingChunk as GChunk
|
||||
from litellm.types.utils import ModelResponseStream
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class AmazonAnthropicClaudeMessagesConfig(
|
||||
AnthropicMessagesConfig,
|
||||
AmazonInvokeConfig,
|
||||
):
|
||||
"""
|
||||
Call Claude model family in the /v1/messages API spec
|
||||
Supports anthropic_beta parameter for beta features.
|
||||
"""
|
||||
|
||||
DEFAULT_BEDROCK_ANTHROPIC_API_VERSION = "bedrock-2023-05-31"
|
||||
|
||||
# Beta header patterns that are not supported by Bedrock Invoke API
|
||||
# These will be filtered out to prevent 400 "invalid beta flag" errors
|
||||
|
||||
def __init__(self, **kwargs):
    # Initialize both bases explicitly (instead of relying on cooperative
    # super().__init__) so the Anthropic /v1/messages config and the
    # Bedrock Invoke config each receive the same kwargs.
    BaseAnthropicMessagesConfig.__init__(self, **kwargs)
    AmazonInvokeConfig.__init__(self, **kwargs)
|
||||
|
||||
def validate_anthropic_messages_environment(
    self,
    headers: dict,
    model: str,
    messages: List[Any],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> Tuple[dict, Optional[str]]:
    """
    No Anthropic-specific environment validation is required on Bedrock —
    authentication happens via AWS request signing in `sign_request` —
    so the headers and api_base are passed through untouched.
    """
    return (headers, api_base)
|
||||
|
||||
def sign_request(
    self,
    headers: dict,
    optional_params: dict,
    request_data: dict,
    api_base: str,
    api_key: Optional[str] = None,
    model: Optional[str] = None,
    stream: Optional[bool] = None,
    fake_stream: Optional[bool] = None,
) -> Tuple[dict, Optional[bytes]]:
    """
    Sign the outgoing request for Bedrock.

    Delegates explicitly to the Bedrock Invoke implementation (bypassing the
    MRO) so the Invoke signing path is always used even though this class
    also inherits from the Anthropic messages config.
    """
    return AmazonInvokeConfig.sign_request(
        self,
        headers=headers,
        optional_params=optional_params,
        request_data=request_data,
        api_base=api_base,
        api_key=api_key,
        model=model,
        stream=stream,
        fake_stream=fake_stream,
    )
|
||||
|
||||
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: dict,
    stream: Optional[bool] = None,
) -> str:
    """
    Resolve the Bedrock Invoke endpoint URL for this request.

    Delegates explicitly to the Bedrock Invoke implementation (bypassing the
    MRO) so URL construction follows the Invoke rules, not the Anthropic
    API's.
    """
    return AmazonInvokeConfig.get_complete_url(
        self,
        api_base=api_base,
        api_key=api_key,
        model=model,
        optional_params=optional_params,
        litellm_params=litellm_params,
        stream=stream,
    )
|
||||
|
||||
def _remove_ttl_from_cache_control(
|
||||
self, anthropic_messages_request: Dict, model: Optional[str] = None
|
||||
) -> None:
|
||||
"""
|
||||
Remove unsupported fields from cache_control for Bedrock.
|
||||
|
||||
Bedrock only supports `type` and `ttl` in cache_control. It does NOT support:
|
||||
- `scope` (e.g., "global") - always removed
|
||||
- `ttl` - removed for older models; Claude 4.5+ supports "5m" and "1h"
|
||||
|
||||
Processes both `system` and `messages` content blocks.
|
||||
|
||||
Args:
|
||||
anthropic_messages_request: The request dictionary to modify in-place
|
||||
model: The model name to check if it supports ttl
|
||||
"""
|
||||
is_claude_4_5 = False
|
||||
if model:
|
||||
is_claude_4_5 = self._is_claude_4_5_on_bedrock(model)
|
||||
|
||||
def _sanitize_cache_control(cache_control: dict) -> None:
|
||||
if not isinstance(cache_control, dict):
|
||||
return
|
||||
# Bedrock doesn't support scope (e.g., "global" for cross-request caching)
|
||||
cache_control.pop("scope", None)
|
||||
# Remove ttl for models that don't support it
|
||||
if "ttl" in cache_control:
|
||||
ttl = cache_control["ttl"]
|
||||
if is_claude_4_5 and ttl in ["5m", "1h"]:
|
||||
return
|
||||
cache_control.pop("ttl", None)
|
||||
|
||||
def _process_content_list(content: list) -> None:
|
||||
for item in content:
|
||||
if isinstance(item, dict) and "cache_control" in item:
|
||||
_sanitize_cache_control(item["cache_control"])
|
||||
|
||||
# Process system (list of content blocks)
|
||||
if "system" in anthropic_messages_request:
|
||||
system = anthropic_messages_request["system"]
|
||||
if isinstance(system, list):
|
||||
_process_content_list(system)
|
||||
|
||||
# Process messages
|
||||
if "messages" in anthropic_messages_request:
|
||||
for message in anthropic_messages_request["messages"]:
|
||||
if isinstance(message, dict) and "content" in message:
|
||||
content = message["content"]
|
||||
if isinstance(content, list):
|
||||
_process_content_list(content)
|
||||
|
||||
def _supports_extended_thinking_on_bedrock(self, model: str) -> bool:
|
||||
"""
|
||||
Check if the model supports extended thinking beta headers on Bedrock.
|
||||
|
||||
On 3rd-party platforms (e.g., Amazon Bedrock), extended thinking is only
|
||||
supported on: Claude Opus 4.5, Claude Opus 4.1, Opus 4, or Sonnet 4.
|
||||
|
||||
Ref: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking
|
||||
|
||||
Args:
|
||||
model: The model name
|
||||
|
||||
Returns:
|
||||
True if the model supports extended thinking on Bedrock
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
|
||||
# Supported models on Bedrock for extended thinking
|
||||
supported_patterns = [
|
||||
"opus-4.5",
|
||||
"opus_4.5",
|
||||
"opus-4-5",
|
||||
"opus_4_5", # Opus 4.5
|
||||
"opus-4.1",
|
||||
"opus_4.1",
|
||||
"opus-4-1",
|
||||
"opus_4_1", # Opus 4.1
|
||||
"opus-4",
|
||||
"opus_4", # Opus 4
|
||||
"sonnet-4",
|
||||
"sonnet_4", # Sonnet 4
|
||||
"sonnet-4.6",
|
||||
"sonnet_4.6",
|
||||
"sonnet-4-6",
|
||||
"sonnet_4_6",
|
||||
"opus-4.6",
|
||||
"opus_4.6",
|
||||
"opus-4-6",
|
||||
"opus_4_6",
|
||||
]
|
||||
|
||||
return any(pattern in model_lower for pattern in supported_patterns)
|
||||
|
||||
def _is_claude_opus_4_5(self, model: str) -> bool:
|
||||
"""
|
||||
Check if the model is Claude Opus 4.5.
|
||||
|
||||
Args:
|
||||
model: The model name
|
||||
|
||||
Returns:
|
||||
True if the model is Claude Opus 4.5
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
opus_4_5_patterns = [
|
||||
"opus-4.5",
|
||||
"opus_4.5",
|
||||
"opus-4-5",
|
||||
"opus_4_5",
|
||||
]
|
||||
return any(pattern in model_lower for pattern in opus_4_5_patterns)
|
||||
|
||||
def _is_claude_4_5_on_bedrock(self, model: str) -> bool:
    """
    Check whether `model` is a Claude 4.5-family model on Bedrock
    (Sonnet 4.5 / Haiku 4.5 / Opus 4.5 — these support 1-hour prompt
    caching).

    Thin wrapper around the shared `is_claude_4_5_on_bedrock` helper so the
    version check stays consistent across Bedrock code paths.

    Args:
        model: The model name.

    Returns:
        True if the model is Claude 4.5.
    """
    return is_claude_4_5_on_bedrock(model)
|
||||
|
||||
def _supports_tool_search_on_bedrock(self, model: str) -> bool:
|
||||
"""
|
||||
Check if the model supports tool search on Bedrock.
|
||||
|
||||
On Amazon Bedrock, server-side tool search is supported on Claude Opus 4.5
|
||||
and Claude Sonnet 4.5 with the tool-search-tool-2025-10-19 beta header.
|
||||
|
||||
Ref: https://platform.claude.com/docs/en/agents-and-tools/tool-use/tool-search-tool
|
||||
|
||||
Args:
|
||||
model: The model name
|
||||
|
||||
Returns:
|
||||
True if the model supports tool search on Bedrock
|
||||
"""
|
||||
model_lower = model.lower()
|
||||
|
||||
# Supported models for tool search on Bedrock
|
||||
supported_patterns = [
|
||||
# Opus 4.5
|
||||
"opus-4.5",
|
||||
"opus_4.5",
|
||||
"opus-4-5",
|
||||
"opus_4_5",
|
||||
# Sonnet 4.5
|
||||
"sonnet-4.5",
|
||||
"sonnet_4.5",
|
||||
"sonnet-4-5",
|
||||
"sonnet_4_5",
|
||||
# Opus 4.6
|
||||
"opus-4.6",
|
||||
"opus_4.6",
|
||||
"opus-4-6",
|
||||
"opus_4_6",
|
||||
# sonnet 4.6
|
||||
"sonnet-4.6",
|
||||
"sonnet_4.6",
|
||||
"sonnet-4-6",
|
||||
"sonnet_4_6",
|
||||
]
|
||||
|
||||
return any(pattern in model_lower for pattern in supported_patterns)
|
||||
|
||||
def _get_tool_search_beta_header_for_bedrock(
    self,
    model: str,
    tool_search_used: bool,
    programmatic_tool_calling_used: bool,
    input_examples_used: bool,
    beta_set: set,
) -> None:
    """
    Swap the generic Anthropic tool-search beta flag for Bedrock's own.

    When tool search is used *without* programmatic tool calling or input
    examples, the standard Anthropic tool-search beta header is removed and
    — on models that support tool search on Bedrock (Opus/Sonnet 4.5+) —
    Bedrock's `tool-search-tool-2025-10-19` flag is added instead.
    Mutates `beta_set` in place.

    Ref: https://platform.claude.com/docs/en/agents-and-tools/tool-use/tool-search-tool

    Args:
        model: The model name.
        tool_search_used: Whether tool search is used.
        programmatic_tool_calling_used: Whether programmatic tool calling is used.
        input_examples_used: Whether input examples are used.
        beta_set: The set of beta headers to modify in place.
    """
    plain_tool_search = tool_search_used and not (
        programmatic_tool_calling_used or input_examples_used
    )
    if not plain_tool_search:
        return
    beta_set.discard(ANTHROPIC_TOOL_SEARCH_BETA_HEADER)
    if self._supports_tool_search_on_bedrock(model):
        beta_set.add("tool-search-tool-2025-10-19")
|
||||
|
||||
def _convert_output_format_to_inline_schema(
|
||||
self,
|
||||
output_format: Dict,
|
||||
anthropic_messages_request: Dict,
|
||||
) -> None:
|
||||
"""
|
||||
Convert Anthropic output_format to inline schema in message content.
|
||||
|
||||
Bedrock Invoke doesn't support the output_format parameter, so we embed
|
||||
the schema directly into the user message content as text instructions.
|
||||
|
||||
This approach adds the schema to the last user message, instructing the model
|
||||
to respond in the specified JSON format.
|
||||
|
||||
Args:
|
||||
output_format: The output_format dict with 'type' and 'schema'
|
||||
anthropic_messages_request: The request dict to modify in-place
|
||||
|
||||
Ref: https://aws.amazon.com/blogs/machine-learning/structured-data-response-with-amazon-bedrock-prompt-engineering-and-tool-use/
|
||||
"""
|
||||
import json
|
||||
|
||||
# Extract schema from output_format
|
||||
schema = output_format.get("schema")
|
||||
if not schema:
|
||||
return
|
||||
|
||||
# Get messages from the request
|
||||
messages = anthropic_messages_request.get("messages", [])
|
||||
if not messages:
|
||||
return
|
||||
|
||||
# Find the last user message
|
||||
last_user_message_idx = None
|
||||
for idx in range(len(messages) - 1, -1, -1):
|
||||
if messages[idx].get("role") == "user":
|
||||
last_user_message_idx = idx
|
||||
break
|
||||
|
||||
if last_user_message_idx is None:
|
||||
return
|
||||
|
||||
last_user_message = messages[last_user_message_idx]
|
||||
content = last_user_message.get("content", [])
|
||||
|
||||
# Ensure content is a list
|
||||
if isinstance(content, str):
|
||||
content = [{"type": "text", "text": content}]
|
||||
last_user_message["content"] = content
|
||||
|
||||
# Add schema as text content to the message
|
||||
schema_text = {"type": "text", "text": json.dumps(schema)}
|
||||
content.append(schema_text)
|
||||
|
||||
def transform_anthropic_messages_request(
    self,
    model: str,
    messages: List[Dict],
    anthropic_messages_optional_request_params: Dict,
    litellm_params: GenericLiteLLMParams,
    headers: dict,
) -> Dict:
    """
    Build the Bedrock Invoke request body for the Anthropic /v1/messages spec.

    Starts from the generic Anthropic transformation, then applies the
    Bedrock-specific adjustments: required `anthropic_version`; removal of
    fields Bedrock Invoke rejects (`stream`, `model`, unsupported
    cache_control fields, `output_format`, `output_config`, per-tool
    `custom` fields); and auto-injection of `anthropic_beta` flags based on
    the features actually in use.
    """
    anthropic_messages_request = AnthropicMessagesConfig.transform_anthropic_messages_request(
        self=self,
        model=model,
        messages=messages,
        anthropic_messages_optional_request_params=anthropic_messages_optional_request_params,
        litellm_params=litellm_params,
        headers=headers,
    )
    #########################################################
    ############## BEDROCK Invoke SPECIFIC TRANSFORMATION ###
    #########################################################

    # 1. anthropic_version is required for all claude models
    if "anthropic_version" not in anthropic_messages_request:
        anthropic_messages_request[
            "anthropic_version"
        ] = self.DEFAULT_BEDROCK_ANTHROPIC_API_VERSION

    # 2. `stream` is not allowed in request body for bedrock invoke
    # (streaming is selected via the endpoint, not the body)
    if "stream" in anthropic_messages_request:
        anthropic_messages_request.pop("stream", None)

    # 3. `model` is not allowed in request body for bedrock invoke
    # (the model is part of the endpoint URL)
    if "model" in anthropic_messages_request:
        anthropic_messages_request.pop("model", None)

    # 4. Remove cache_control fields Bedrock rejects (`scope` always;
    # `ttl` on models older than Claude 4.5)
    self._remove_ttl_from_cache_control(
        anthropic_messages_request=anthropic_messages_request, model=model
    )

    # 5. Convert `output_format` to an inline schema
    # (Bedrock Invoke doesn't support the output_format parameter)
    output_format = anthropic_messages_request.pop("output_format", None)
    if output_format:
        self._convert_output_format_to_inline_schema(
            output_format=output_format,
            anthropic_messages_request=anthropic_messages_request,
        )

    # 6. Strip `output_config` — Bedrock Invoke doesn't support it
    # Fixes: https://github.com/BerriAI/litellm/issues/22797
    anthropic_messages_request.pop("output_config", None)

    # 7. Remove `custom` field from tools (Bedrock doesn't support it)
    # Claude Code sends `custom: {defer_loading: true}` on tool definitions,
    # which causes Bedrock to reject the request with "Extra inputs are not permitted"
    # Ref: https://github.com/BerriAI/litellm/issues/22847
    remove_custom_field_from_tools(anthropic_messages_request)

    # 8. AUTO-INJECT beta headers based on the features actually used
    anthropic_model_info = AnthropicModelInfo()
    tools = anthropic_messages_optional_request_params.get("tools")
    messages_typed = cast(List[AllMessageValues], messages)
    tool_search_used = anthropic_model_info.is_tool_search_used(tools)
    programmatic_tool_calling_used = (
        anthropic_model_info.is_programmatic_tool_calling_used(tools)
    )
    input_examples_used = anthropic_model_info.is_input_examples_used(tools)

    # Merge caller-supplied beta headers with the auto-detected ones.
    beta_set = set(get_anthropic_beta_from_headers(headers))
    auto_betas = anthropic_model_info.get_anthropic_beta_list(
        model=model,
        optional_params=anthropic_messages_optional_request_params,
        computer_tool_used=anthropic_model_info.is_computer_tool_used(tools),
        prompt_caching_set=False,
        file_id_used=anthropic_model_info.is_file_id_used(messages_typed),
        mcp_server_used=anthropic_model_info.is_mcp_server_used(
            anthropic_messages_optional_request_params.get("mcp_servers")
        ),
    )
    beta_set.update(auto_betas)

    # Bedrock uses its own tool-search beta flag; swap it in when applicable.
    self._get_tool_search_beta_header_for_bedrock(
        model=model,
        tool_search_used=tool_search_used,
        programmatic_tool_calling_used=programmatic_tool_calling_used,
        input_examples_used=input_examples_used,
        beta_set=beta_set,
    )

    # NOTE(review): the tool-examples beta is always paired with Bedrock's
    # tool-search flag here — presumably required together; confirm.
    if "tool-search-tool-2025-10-19" in beta_set:
        beta_set.add("tool-examples-2025-10-29")

    if beta_set:
        anthropic_messages_request["anthropic_beta"] = list(beta_set)

    return anthropic_messages_request
|
||||
|
||||
def get_async_streaming_response_iterator(
    self,
    model: str,
    httpx_response: httpx.Response,
    request_body: dict,
    litellm_logging_obj: LiteLLMLoggingObj,
) -> AsyncIterator:
    """
    Wrap a Bedrock Invoke streaming HTTP response in an async iterator
    yielding Anthropic /v1/messages-style SSE chunks.
    """
    # Decode the raw AWS event stream into /v1/messages-shaped chunks.
    aws_decoder = AmazonAnthropicClaudeMessagesStreamDecoder(
        model=model,
    )
    completion_stream = aws_decoder.aiter_bytes(
        httpx_response.aiter_bytes(chunk_size=aws_decoder.DEFAULT_CHUNK_SIZE)
    )
    # Convert decoded Bedrock events to Server-Sent Events expected by Anthropic clients.
    return self.bedrock_sse_wrapper(
        completion_stream=completion_stream,
        litellm_logging_obj=litellm_logging_obj,
        request_body=request_body,
    )
|
||||
|
||||
async def bedrock_sse_wrapper(
    self,
    completion_stream: AsyncIterator[
        Union[bytes, GenericStreamingChunk, ModelResponseStream, dict]
    ],
    litellm_logging_obj: LiteLLMLoggingObj,
    request_body: dict,
):
    """
    Bedrock invoke does not return SSE formatted data. This function is a wrapper to ensure litellm chunks are SSE formatted.

    Async generator: re-yields each chunk after passing it through LiteLLM's
    Anthropic streaming iterator (which also drives logging for the stream).
    """
    # Function-scope import — NOTE(review): presumably to avoid an import
    # cycle at module load time; confirm before moving to file level.
    from litellm.llms.anthropic.experimental_pass_through.messages.streaming_iterator import (
        BaseAnthropicMessagesStreamingIterator,
    )

    handler = BaseAnthropicMessagesStreamingIterator(
        litellm_logging_obj=litellm_logging_obj,
        request_body=request_body,
    )

    async for chunk in handler.async_sse_wrapper(completion_stream):
        yield chunk
|
||||
|
||||
|
||||
class AmazonAnthropicClaudeMessagesStreamDecoder(AWSEventStreamDecoder):
    """
    Decodes Bedrock Invoke event-stream chunks into the Anthropic
    /v1/messages streaming format.
    """

    def __init__(
        self,
        model: str,
    ) -> None:
        """
        Iterator to return Bedrock invoke response in anthropic /messages format
        """
        super().__init__(model=model)
        # Read the underlying HTTP stream in 1 KiB slices.
        self.DEFAULT_CHUNK_SIZE = 1024

    def _chunk_parser(
        self, chunk_data: dict
    ) -> Union[GChunk, ModelResponseStream, dict]:
        """
        Normalize one decoded Bedrock chunk to the Anthropic /v1/messages shape.

        Bedrock reports usage metrics under the camelCase
        `amazon-bedrock-invocationMetrics` key; translate those into the
        snake_case `usage` block of the Anthropic `/v1/messages` spec so
        streaming callers receive a consistent response shape. All other
        chunk fields pass through unchanged.
        """
        invocation_metrics = chunk_data.pop(
            "amazon-bedrock-invocationMetrics", {}
        )
        if invocation_metrics:
            key_map = {
                "inputTokenCount": "input_tokens",
                "outputTokenCount": "output_tokens",
            }
            chunk_data["usage"] = {
                anthropic_key: invocation_metrics[bedrock_key]
                for bedrock_key, anthropic_key in key_map.items()
                if bedrock_key in invocation_metrics
            }
        return chunk_data
|
||||
@@ -0,0 +1,3 @@
|
||||
# /v1/messages
|
||||
|
||||
This folder contains transformation logic for calling bedrock models in the Anthropic /v1/messages API spec.
|
||||
@@ -0,0 +1,249 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, cast
|
||||
|
||||
from httpx import Response
|
||||
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import BedrockEventStreamDecoderBase, BedrockModelInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.utils import CostResponseTypes
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from httpx import URL
|
||||
|
||||
|
||||
class BedrockPassthroughConfig(
|
||||
BaseAWSLLM, BedrockModelInfo, BedrockEventStreamDecoderBase, BasePassthroughConfig
|
||||
):
|
||||
def is_streaming_request(self, endpoint: str, request_data: dict) -> bool:
    # Bedrock streaming endpoints embed "stream" in the path (e.g.
    # "invoke-with-response-stream", "converse-stream"), so a substring
    # check on the endpoint suffices; the request body is not consulted.
    return "stream" in endpoint
|
||||
|
||||
def _encode_model_id_for_endpoint(self, model_id: str) -> str:
    """
    URL-encode a model_id (notably ARNs) for use in Bedrock endpoint paths.

    ARNs contain special characters like colons and slashes that need to be
    properly URL-encoded when used in HTTP request paths. For example:
        arn:aws:bedrock:us-east-1:123:application-inference-profile/abc123
    becomes:
        arn:aws:bedrock:us-east-1:123:application-inference-profile%2Fabc123

    The encoding itself is delegated to CommonUtils so it matches the rest
    of the passthrough routing logic: the model_id is wrapped in a synthetic
    endpoint, encoded, and the encoded segment extracted back out.

    Args:
        model_id: The model ID or ARN to encode.

    Returns:
        The encoded model_id, or the original value if extraction fails.
    """
    import re

    from litellm.passthrough.utils import CommonUtils

    # Round-trip through a synthetic endpoint because CommonUtils operates
    # on full endpoint paths rather than bare model ids.
    encoded_endpoint = CommonUtils.encode_bedrock_runtime_modelid_arn(
        f"/model/{model_id}/converse"
    )

    match = re.search(r"/model/([^/]+)/", encoded_endpoint)
    if match is None:
        # Fall back to the raw id rather than failing the request.
        return model_id
    return match.group(1)
|
||||
|
||||
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    endpoint: str,
    request_query_params: Optional[dict],
    litellm_params: dict,
) -> Tuple["URL", str]:
    """
    Resolve the full Bedrock runtime URL for a passthrough request.

    Determines the AWS region and runtime endpoint from `litellm_params`,
    then — when a `model_id` (e.g. an Application Inference Profile ARN) is
    configured — splices the URL-encoded model_id into the endpoint path in
    place of the translated model name.

    Returns:
        Tuple of (complete request URL, base endpoint URL).
    """
    # Work on a copy so helpers can consume keys without mutating the caller's params.
    optional_params = litellm_params.copy()
    model_id = optional_params.get("model_id", None)

    aws_region_name = self._get_aws_region_name(
        optional_params=optional_params,
        model=model,
        model_id=model_id,
    )

    aws_bedrock_runtime_endpoint = optional_params.get(
        "aws_bedrock_runtime_endpoint"
    )
    endpoint_url, _ = self.get_runtime_endpoint(
        api_base=api_base,
        aws_bedrock_runtime_endpoint=aws_bedrock_runtime_endpoint,
        aws_region_name=aws_region_name,
        endpoint_type="runtime",
    )

    # If model_id is provided (e.g., Application Inference Profile ARN), use it in the endpoint
    # instead of the translated model name
    if model_id is not None:
        import re

        # Encode the model_id if it's an ARN to properly handle special characters
        encoded_model_id = self._encode_model_id_for_endpoint(model_id)

        # Replace the model name in the endpoint with the encoded model_id
        endpoint = re.sub(r"model/[^/]+/", f"model/{encoded_model_id}/", endpoint)
    return (
        self.format_url(endpoint, endpoint_url, request_query_params or {}),
        endpoint_url,
    )
|
||||
|
||||
def sign_request(
    self,
    headers: dict,
    litellm_params: dict,
    request_data: Optional[dict],
    api_base: str,
    model: Optional[str] = None,
) -> Tuple[dict, Optional[bytes]]:
    """
    Sign a passthrough request against the "bedrock" AWS service.

    `litellm_params` is copied before being handed to the signer so any key
    consumption in `_sign_request` cannot leak back to the caller.
    """
    signing_params = dict(litellm_params)
    return self._sign_request(
        service_name="bedrock",
        headers=headers,
        optional_params=signing_params,
        request_data=request_data or {},
        api_base=api_base,
        model=model,
    )
|
||||
|
||||
def logging_non_streaming_response(
    self,
    model: str,
    custom_llm_provider: str,
    httpx_response: Response,
    request_data: dict,
    logging_obj: Logging,
    endpoint: str,
) -> Optional["CostResponseTypes"]:
    """
    Convert a non-streaming Bedrock passthrough response into a LiteLLM
    ModelResponse so cost tracking / logging can run on it.

    Only `invoke` and `converse` endpoints are handled; any other endpoint
    returns None (no logging transformation available).

    Raises:
        ValueError: if no provider chat config exists for the model.
    """
    from litellm import encoding
    from litellm.types.utils import LlmProviders, ModelResponse
    from litellm.utils import ProviderConfigManager

    # Route to the matching chat config by prefixing the model with the
    # endpoint family.
    if "invoke" in endpoint:
        chat_config_model = "invoke/" + model
    elif "converse" in endpoint:
        chat_config_model = "converse/" + model
    else:
        return None

    provider_chat_config = ProviderConfigManager.get_provider_chat_config(
        provider=LlmProviders(custom_llm_provider),
        model=chat_config_model,
    )

    if provider_chat_config is None:
        raise ValueError(f"No provider config found for model: {model}")

    # Passthrough requests carry no OpenAI-style messages, but
    # transform_response requires one — supply a placeholder.
    litellm_model_response: ModelResponse = provider_chat_config.transform_response(
        model=model,
        messages=[{"role": "user", "content": "no-message-pass-through-endpoint"}],
        raw_response=httpx_response,
        model_response=ModelResponse(),
        logging_obj=logging_obj,
        optional_params={},
        litellm_params={},
        api_key="",
        request_data=request_data,
        encoding=encoding,
    )

    return litellm_model_response
|
||||
|
||||
def _convert_raw_bytes_to_str_lines(self, raw_bytes: List[bytes]) -> List[str]:
    """
    Decode raw AWS event-stream bytes into parsed message strings.

    Feeds each chunk through botocore's EventStreamBuffer and collects every
    message `_parse_message_from_event` can extract; None results (events
    without a usable payload) are skipped.
    """
    from botocore.eventstream import EventStreamBuffer

    stream_buffer = EventStreamBuffer()
    messages: List[str] = []
    for raw_chunk in raw_bytes:
        stream_buffer.add_data(raw_chunk)
        for event in stream_buffer:
            parsed = self._parse_message_from_event(event)
            if parsed is not None:
                messages.append(parsed)

    return messages
|
||||
|
||||
def handle_logging_collected_chunks(
    self,
    all_chunks: List[str],
    litellm_logging_obj: "LiteLLMLoggingObj",
    model: str,
    custom_llm_provider: str,
    endpoint: str,
) -> Optional["CostResponseTypes"]:
    """
    1. Convert all_chunks to a ModelResponseStream
    2. combine model_response_stream to model_response
    3. Return the model_response

    Used after a passthrough streaming request finishes to rebuild a full
    ModelResponse from the collected raw JSON chunks for cost tracking /
    logging. Returns None when the endpoint is neither invoke nor converse,
    or when no chunk could be translated.

    Raises:
        ValueError: for an invoke endpoint whose model has no recognizable
            invoke provider.
    """

    from litellm.litellm_core_utils.streaming_handler import (
        convert_generic_chunk_to_model_response_stream,
        generic_chunk_has_all_required_fields,
    )
    from litellm.llms.bedrock.chat import get_bedrock_event_stream_decoder
    from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
        AmazonInvokeConfig,
    )
    from litellm.main import stream_chunk_builder
    from litellm.types.utils import GenericStreamingChunk, ModelResponseStream

    all_translated_chunks = []
    # Pick the decoder matching the endpoint family: invoke endpoints need a
    # provider-specific decoder, converse endpoints use the generic one.
    if "invoke" in endpoint:
        invoke_provider = AmazonInvokeConfig.get_bedrock_invoke_provider(model)
        if invoke_provider is None:
            raise ValueError(
                f"Invalid invoke provider: {invoke_provider}, for model: {model}"
            )
        obj = get_bedrock_event_stream_decoder(
            invoke_provider=invoke_provider,
            model=model,
            sync_stream=True,
            json_mode=False,
        )
    elif "converse" in endpoint:
        obj = get_bedrock_event_stream_decoder(
            invoke_provider=None,
            model=model,
            sync_stream=True,
            json_mode=False,
        )
    else:
        return None

    for chunk in all_chunks:
        message = json.loads(chunk)
        translated_chunk = obj._chunk_parser(chunk_data=message)

        # Decoders may emit generic dict chunks or full ModelResponseStream
        # objects; normalize both, skip anything else.
        if isinstance(
            translated_chunk, dict
        ) and generic_chunk_has_all_required_fields(cast(dict, translated_chunk)):
            chunk_obj = convert_generic_chunk_to_model_response_stream(
                cast(GenericStreamingChunk, translated_chunk)
            )
        elif isinstance(translated_chunk, ModelResponseStream):
            chunk_obj = translated_chunk
        else:
            continue

        all_translated_chunks.append(chunk_obj)

    if len(all_translated_chunks) > 0:
        model_response = stream_chunk_builder(
            chunks=all_translated_chunks,
            logging_obj=litellm_logging_obj,
        )
        return model_response
    return None
|
||||
@@ -0,0 +1,307 @@
|
||||
"""
|
||||
This file contains the handler for AWS Bedrock Nova Sonic realtime API.
|
||||
|
||||
This uses aws_sdk_bedrock_runtime for bidirectional streaming with Nova Sonic.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from .transformation import BedrockRealtimeConfig
|
||||
|
||||
|
||||
class BedrockRealtime(BaseAWSLLM):
    """Handler for Bedrock Nova Sonic realtime speech-to-speech API.

    Proxies a client WebSocket (OpenAI realtime protocol) to Bedrock's
    bidirectional stream via aws_sdk_bedrock_runtime, translating messages
    in both directions with BedrockRealtimeConfig.
    """

    def __init__(self):
        super().__init__()

    async def async_realtime(
        self,
        model: str,
        websocket: Any,
        logging_obj: LiteLLMLogging,
        api_base: Optional[str] = None,
        api_key: Optional[str] = None,
        timeout: Optional[float] = None,
        aws_region_name: Optional[str] = None,
        aws_access_key_id: Optional[str] = None,
        aws_secret_access_key: Optional[str] = None,
        aws_session_token: Optional[str] = None,
        aws_role_name: Optional[str] = None,
        aws_session_name: Optional[str] = None,
        aws_profile_name: Optional[str] = None,
        aws_web_identity_token: Optional[str] = None,
        aws_sts_endpoint: Optional[str] = None,
        aws_bedrock_runtime_endpoint: Optional[str] = None,
        aws_external_id: Optional[str] = None,
        **kwargs,
    ):
        """
        Establish bidirectional streaming connection with Bedrock Nova Sonic.

        Spawns two concurrent tasks — one forwarding client traffic to Bedrock,
        one forwarding Bedrock output back to the client — and waits for both.

        Args:
            model: Model ID (e.g., 'amazon.nova-sonic-v1:0')
            websocket: Client WebSocket connection
            logging_obj: LiteLLM logging object
            aws_region_name: AWS region
            Various AWS authentication parameters

        Raises:
            ImportError: if aws_sdk_bedrock_runtime is not installed.

        NOTE(review): credentials are resolved only via
        EnvironmentCredentialsResolver below; the explicit aws_access_key_id /
        aws_secret_access_key / role parameters are accepted but never used
        here — confirm whether that is intentional.
        """
        try:
            from aws_sdk_bedrock_runtime.client import (
                BedrockRuntimeClient,
                InvokeModelWithBidirectionalStreamOperationInput,
            )
            from aws_sdk_bedrock_runtime.config import Config
            from smithy_aws_core.identity.environment import (
                EnvironmentCredentialsResolver,
            )
        except ImportError:
            raise ImportError(
                "Missing aws_sdk_bedrock_runtime. Install with: pip install aws-sdk-bedrock-runtime"
            )

        # Get AWS region
        if aws_region_name is None:
            optional_params = {
                "aws_region_name": aws_region_name,
            }
            aws_region_name = self._get_aws_region_name(optional_params, model)

        # Get endpoint URL — explicit api_base wins, then the configured
        # bedrock runtime endpoint, then the regional default.
        if api_base is not None:
            endpoint_uri = api_base
        elif aws_bedrock_runtime_endpoint is not None:
            endpoint_uri = aws_bedrock_runtime_endpoint
        else:
            endpoint_uri = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

        verbose_proxy_logger.debug(
            f"Bedrock Realtime: Connecting to {endpoint_uri} with model {model}"
        )

        # Initialize Bedrock client with aws_sdk_bedrock_runtime
        config = Config(
            endpoint_uri=endpoint_uri,
            region=aws_region_name,
            aws_credentials_identity_resolver=EnvironmentCredentialsResolver(),
        )
        bedrock_client = BedrockRuntimeClient(config=config)

        transformation_config = BedrockRealtimeConfig()

        try:
            # Initialize the bidirectional stream
            bedrock_stream = (
                await bedrock_client.invoke_model_with_bidirectional_stream(
                    InvokeModelWithBidirectionalStreamOperationInput(model_id=model)
                )
            )

            verbose_proxy_logger.debug(
                "Bedrock Realtime: Bidirectional stream established"
            )

            # Track state for transformation — shared (mutated) by both
            # forwarding tasks so response/item ids stay consistent.
            session_state = {
                "current_output_item_id": None,
                "current_response_id": None,
                "current_conversation_id": None,
                "current_delta_chunks": None,
                "current_item_chunks": None,
                "current_delta_type": None,
                "session_configuration_request": None,
            }

            # Create tasks for bidirectional forwarding
            client_to_bedrock_task = asyncio.create_task(
                self._forward_client_to_bedrock(
                    websocket,
                    bedrock_stream,
                    transformation_config,
                    model,
                    session_state,
                )
            )

            bedrock_to_client_task = asyncio.create_task(
                self._forward_bedrock_to_client(
                    bedrock_stream,
                    websocket,
                    transformation_config,
                    model,
                    logging_obj,
                    session_state,
                )
            )

            # Wait for both tasks to complete; return_exceptions=True so one
            # side failing does not cancel awaiting the other.
            await asyncio.gather(
                client_to_bedrock_task,
                bedrock_to_client_task,
                return_exceptions=True,
            )

        except Exception as e:
            verbose_proxy_logger.exception(
                f"Error in BedrockRealtime.async_realtime: {e}"
            )
            # best-effort close of the client socket before re-raising
            try:
                await websocket.close(code=1011, reason=f"Internal error: {str(e)}")
            except Exception:
                pass
            raise

    async def _forward_client_to_bedrock(
        self,
        client_ws: Any,
        bedrock_stream: Any,
        transformation_config: BedrockRealtimeConfig,
        model: str,
        session_state: dict,
    ):
        """Forward messages from client WebSocket to Bedrock stream.

        Runs until the client disconnects (receive raises); then closes the
        Bedrock input stream so the other direction can drain and finish.
        """
        try:
            from aws_sdk_bedrock_runtime.models import (
                BidirectionalInputPayloadPart,
                InvokeModelWithBidirectionalStreamInputChunk,
            )

            while True:
                # Receive message from client
                message = await client_ws.receive_text()
                verbose_proxy_logger.debug(
                    f"Bedrock Realtime: Received from client: {message[:200]}"
                )

                # Transform OpenAI format to Bedrock format; one OpenAI message
                # may expand to several Bedrock events.
                transformed_messages = transformation_config.transform_realtime_request(
                    message=message,
                    model=model,
                    session_configuration_request=session_state.get(
                        "session_configuration_request"
                    ),
                )

                # Send transformed messages to Bedrock
                for bedrock_message in transformed_messages:
                    event = InvokeModelWithBidirectionalStreamInputChunk(
                        value=BidirectionalInputPayloadPart(
                            bytes_=bedrock_message.encode("utf-8")
                        )
                    )
                    await bedrock_stream.input_stream.send(event)
                    verbose_proxy_logger.debug(
                        f"Bedrock Realtime: Sent to Bedrock: {bedrock_message[:200]}"
                    )

        except Exception as e:
            # normal termination path: client disconnects raise here
            verbose_proxy_logger.debug(
                f"Client to Bedrock forwarding ended: {e}", exc_info=True
            )
            # Close the Bedrock stream input
            try:
                await bedrock_stream.input_stream.close()
            except Exception:
                pass

    async def _forward_bedrock_to_client(
        self,
        bedrock_stream: Any,
        client_ws: Any,
        transformation_config: BedrockRealtimeConfig,
        model: str,
        logging_obj: LiteLLMLogging,
        session_state: dict,
    ):
        """Forward messages from Bedrock stream to client WebSocket.

        Runs until the Bedrock stream ends (receive raises); then closes the
        client WebSocket.
        """
        try:
            while True:
                # Receive from Bedrock
                output = await bedrock_stream.await_output()
                result = await output[1].receive()

                if result.value and result.value.bytes_:
                    bedrock_response = result.value.bytes_.decode("utf-8")
                    verbose_proxy_logger.debug(
                        f"Bedrock Realtime: Received from Bedrock: {bedrock_response[:200]}"
                    )

                    # Transform Bedrock format to OpenAI format
                    from litellm.types.realtime import RealtimeResponseTransformInput

                    # snapshot of the shared session state handed to the transformer
                    realtime_response_transform_input: RealtimeResponseTransformInput = {
                        "current_output_item_id": session_state.get(
                            "current_output_item_id"
                        ),
                        "current_response_id": session_state.get("current_response_id"),
                        "current_conversation_id": session_state.get(
                            "current_conversation_id"
                        ),
                        "current_delta_chunks": session_state.get(
                            "current_delta_chunks"
                        ),
                        "current_item_chunks": session_state.get("current_item_chunks"),
                        "current_delta_type": session_state.get("current_delta_type"),
                        "session_configuration_request": session_state.get(
                            "session_configuration_request"
                        ),
                    }

                    transformed_response = transformation_config.transform_realtime_response(
                        message=bedrock_response,
                        model=model,
                        logging_obj=logging_obj,
                        realtime_response_transform_input=realtime_response_transform_input,
                    )

                    # Update session state with the transformer's new values so
                    # the next iteration (and the other task) see them.
                    session_state.update(
                        {
                            "current_output_item_id": transformed_response.get(
                                "current_output_item_id"
                            ),
                            "current_response_id": transformed_response.get(
                                "current_response_id"
                            ),
                            "current_conversation_id": transformed_response.get(
                                "current_conversation_id"
                            ),
                            "current_delta_chunks": transformed_response.get(
                                "current_delta_chunks"
                            ),
                            "current_item_chunks": transformed_response.get(
                                "current_item_chunks"
                            ),
                            "current_delta_type": transformed_response.get(
                                "current_delta_type"
                            ),
                            "session_configuration_request": transformed_response.get(
                                "session_configuration_request"
                            ),
                        }
                    )

                    # Send transformed messages to client
                    openai_messages = transformed_response.get("response", [])
                    for openai_message in openai_messages:
                        message_json = json.dumps(openai_message)
                        await client_ws.send_text(message_json)
                        verbose_proxy_logger.debug(
                            f"Bedrock Realtime: Sent to client: {message_json[:200]}"
                        )

        except Exception as e:
            # normal termination path: Bedrock stream ends / errors raise here
            verbose_proxy_logger.debug(
                f"Bedrock to client forwarding ended: {e}", exc_info=True
            )
            # Close the client WebSocket
            try:
                await client_ws.close()
            except Exception:
                pass
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,179 @@
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.types.llms.bedrock import BedrockPreparedRequest
|
||||
from litellm.types.rerank import RerankRequest
|
||||
from litellm.types.utils import RerankResponse
|
||||
|
||||
from ..base_aws_llm import BaseAWSLLM
|
||||
from ..common_utils import BedrockError
|
||||
from .transformation import BedrockRerankConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from botocore.awsrequest import AWSPreparedRequest
|
||||
else:
|
||||
AWSPreparedRequest = Any
|
||||
|
||||
|
||||
class BedrockRerankHandler(BaseAWSLLM):
    """Handler for the Bedrock Agent Runtime ``/rerank`` endpoint.

    Accepts Cohere-style rerank arguments, transforms them with
    BedrockRerankConfig, SigV4-signs the request, and POSTs it (sync or async).
    """

    async def arerank(
        self,
        prepared_request: BedrockPreparedRequest,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        client: Optional[AsyncHTTPHandler] = None,
    ):
        """Async POST of an already-signed rerank request.

        Args:
            prepared_request: SigV4-signed request from ``_prepare_request``.
            timeout: optional request timeout.
            client: optional async HTTP client; the shared Bedrock client is
                used when omitted.

        Returns:
            RerankResponse parsed from the Bedrock JSON body.

        Raises:
            BedrockError: on non-2xx responses (original status preserved) or
                timeouts (reported as 408).

        NOTE(review): unlike the sync path, no logging_obj.post_call is made
        here — confirm whether async rerank responses should be logged too.
        """
        if client is None:
            client = get_async_httpx_client(llm_provider=litellm.LlmProviders.BEDROCK)
        try:
            response = await client.post(
                url=prepared_request["endpoint_url"],
                headers=dict(prepared_request["prepped"].headers),
                data=prepared_request["body"],
                timeout=timeout,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        return BedrockRerankConfig()._transform_response(response.json())

    def rerank(
        self,
        model: str,
        query: str,
        documents: List[Union[str, Dict[str, Any]]],
        optional_params: dict,
        logging_obj: LitellmLogging,
        top_n: Optional[int] = None,
        rank_fields: Optional[List[str]] = None,
        return_documents: Optional[bool] = True,
        max_chunks_per_doc: Optional[int] = None,
        _is_async: Optional[bool] = False,
        timeout: Optional[Union[float, httpx.Timeout]] = None,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ) -> RerankResponse:
        """Rerank ``documents`` against ``query`` via Bedrock.

        Builds a RerankRequest, transforms it to Bedrock's format, signs it,
        then either returns the ``arerank`` coroutine (when ``_is_async``) or
        executes the request synchronously.

        Raises:
            BedrockError: on HTTP errors (status preserved) or timeouts (408).

        NOTE(review): max_chunks_per_doc is accepted but not forwarded into
        RerankRequest — confirm whether Bedrock supports an equivalent.
        """
        request_data = RerankRequest(
            model=model,
            query=query,
            documents=documents,
            top_n=top_n,
            rank_fields=rank_fields,
            return_documents=return_documents,
        )
        data = BedrockRerankConfig()._transform_request(request_data)

        prepared_request = self._prepare_request(
            model=model,
            optional_params=optional_params,
            api_base=api_base,
            extra_headers=extra_headers,
            data=cast(dict, data),
        )

        logging_obj.pre_call(
            input=data,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": prepared_request["endpoint_url"],
                "headers": dict(prepared_request["prepped"].headers),
            },
        )

        if _is_async:
            # returns the coroutine for the caller to await; sync clients are
            # not forwarded (only AsyncHTTPHandler instances are usable here)
            return self.arerank(prepared_request, timeout=timeout, client=client if client is not None and isinstance(client, AsyncHTTPHandler) else None)  # type: ignore

        if client is None or not isinstance(client, HTTPHandler):
            client = _get_httpx_client()
        try:
            response = client.post(
                url=prepared_request["endpoint_url"],
                headers=dict(prepared_request["prepped"].headers),
                data=prepared_request["body"],
                timeout=timeout,
            )
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text)
        except httpx.TimeoutException:
            raise BedrockError(status_code=408, message="Timeout error occurred.")

        logging_obj.post_call(
            original_response=response.text,
            api_key="",
        )

        response_json = response.json()

        return BedrockRerankConfig()._transform_response(response_json)

    def _prepare_request(
        self,
        model: str,
        api_base: Optional[str],
        extra_headers: Optional[dict],
        data: dict,
        optional_params: dict,
    ) -> BedrockPreparedRequest:
        """Build and SigV4-sign the rerank POST request.

        Resolves the runtime endpoint, rewrites it to the agent-runtime host
        (rerank is served by bedrock-agent-runtime, not bedrock-runtime),
        serializes the body, and signs with SigV4 for service "bedrock".

        Returns:
            BedrockPreparedRequest with endpoint_url, the prepared (signed)
            request, the encoded body, and the original payload dict.

        Raises:
            ImportError: if boto3/botocore is not installed.
        """
        try:
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params, model
        )

        ### SET RUNTIME ENDPOINT ###
        _, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        # rerank lives on the agent-runtime service, so swap the host segment
        proxy_endpoint_url = proxy_endpoint_url.replace(
            "bedrock-runtime", "bedrock-agent-runtime"
        )
        proxy_endpoint_url = f"{proxy_endpoint_url}/rerank"
        sigv4 = SigV4Auth(
            boto3_credentials_info.credentials,
            "bedrock",
            boto3_credentials_info.aws_region_name,
        )
        # Make POST Request
        body = json.dumps(data).encode("utf-8")

        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}
        request = AWSRequest(
            method="POST", url=proxy_endpoint_url, data=body, headers=headers
        )
        sigv4.add_auth(request)
        if (
            extra_headers is not None and "Authorization" in extra_headers
        ):  # prevent sigv4 from overwriting the auth header
            request.headers["Authorization"] = extra_headers["Authorization"]
        prepped = request.prepare()

        return BedrockPreparedRequest(
            endpoint_url=proxy_endpoint_url,
            prepped=prepped,
            body=body,
            data=data,
        )
|
||||
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Translates from Cohere's `/v1/rerank` input format to Bedrock's `/rerank` input format.
|
||||
|
||||
Why separate file? Make it easy to see how transformation works
|
||||
"""
|
||||
|
||||
from litellm._uuid import uuid
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from litellm.types.llms.bedrock import (
|
||||
BedrockRerankBedrockRerankingConfiguration,
|
||||
BedrockRerankConfiguration,
|
||||
BedrockRerankInlineDocumentSource,
|
||||
BedrockRerankModelConfiguration,
|
||||
BedrockRerankQuery,
|
||||
BedrockRerankRequest,
|
||||
BedrockRerankSource,
|
||||
BedrockRerankTextDocument,
|
||||
BedrockRerankTextQuery,
|
||||
)
|
||||
from litellm.types.rerank import (
|
||||
RerankBilledUnits,
|
||||
RerankRequest,
|
||||
RerankResponse,
|
||||
RerankResponseMeta,
|
||||
RerankResponseResult,
|
||||
RerankTokens,
|
||||
)
|
||||
|
||||
|
||||
class BedrockRerankConfig:
    """Translates Cohere-style rerank requests/responses to/from Bedrock's format."""

    def _transform_sources(
        self, documents: List[Union[str, dict]]
    ) -> List[BedrockRerankSource]:
        """
        Transform the sources from RerankRequest format to Bedrock format.

        Strings become TEXT inline documents; dicts become JSON inline documents.
        """

        def _to_source(doc: Union[str, dict]) -> BedrockRerankSource:
            if isinstance(doc, str):
                inline = BedrockRerankInlineDocumentSource(
                    textDocument=BedrockRerankTextDocument(text=doc),
                    type="TEXT",
                )
            else:
                inline = BedrockRerankInlineDocumentSource(
                    jsonDocument=doc, type="JSON"
                )
            return BedrockRerankSource(inlineDocumentSource=inline, type="INLINE")

        return [_to_source(doc) for doc in documents]

    def _transform_request(self, request_data: RerankRequest) -> BedrockRerankRequest:
        """
        Transform the request from RerankRequest format to Bedrock format.
        """
        query_entry = BedrockRerankQuery(
            textQuery=BedrockRerankTextQuery(text=request_data.query),
            type="TEXT",
        )
        model_config = BedrockRerankModelConfiguration(modelArn=request_data.model)
        reranking_config = BedrockRerankConfiguration(
            bedrockRerankingConfiguration=BedrockRerankBedrockRerankingConfiguration(
                modelConfiguration=model_config,
                # default to returning a score for every document
                numberOfResults=request_data.top_n or len(request_data.documents),
            ),
            type="BEDROCK_RERANKING_MODEL",
        )
        return BedrockRerankRequest(
            queries=[query_entry],
            rerankingConfiguration=reranking_config,
            sources=self._transform_sources(request_data.documents),
        )

    def _transform_response(self, response: dict) -> RerankResponse:
        """
        Transform the response from Bedrock into the RerankResponse format.

        example input:
        {"results":[{"index":0,"relevanceScore":0.6847912669181824},{"index":1,"relevanceScore":0.5980774760246277}]}

        Raises:
            ValueError: when the response carries no results.
        """
        # by default 1 search unit is billed when Bedrock reports no usage
        billed = RerankBilledUnits(**response.get("usage", {"search_units": 1}))
        tokens = RerankTokens(**response.get("usage", {}))
        rerank_meta = RerankResponseMeta(billed_units=billed, tokens=tokens)

        raw_results = response.get("results")
        if not raw_results:
            raise ValueError(f"No results found in the response={response}")

        parsed_results = [
            RerankResponseResult(
                index=entry.get("index"),
                relevance_score=entry.get("relevanceScore"),
            )
            for entry in raw_results
        ]

        return RerankResponse(
            id=response.get("id") or str(uuid.uuid4()),
            results=parsed_results,
            meta=rerank_meta,
        )
|
||||
@@ -0,0 +1,356 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.base_llm.vector_store.transformation import BaseVectorStoreConfig
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.types.integrations.rag.bedrock_knowledgebase import (
|
||||
BedrockKBContent,
|
||||
BedrockKBResponse,
|
||||
BedrockKBRetrievalConfiguration,
|
||||
BedrockKBRetrievalQuery,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.vector_stores import (
|
||||
BaseVectorStoreAuthCredentials,
|
||||
VectorStoreIndexEndpoints,
|
||||
VECTOR_STORE_OPENAI_PARAMS,
|
||||
VectorStoreResultContent,
|
||||
VectorStoreSearchOptionalRequestParams,
|
||||
VectorStoreSearchResponse,
|
||||
VectorStoreSearchResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class BedrockVectorStoreConfig(BaseVectorStoreConfig, BaseAWSLLM):
|
||||
"""Vector store configuration for AWS Bedrock Knowledge Bases."""
|
||||
|
||||
    def __init__(self) -> None:
        # Initialize both bases explicitly — they are independent classes with
        # no cooperative super() chain.
        BaseVectorStoreConfig.__init__(self)
        BaseAWSLLM.__init__(self)
|
||||
|
||||
    def get_auth_credentials(
        self, litellm_params: dict
    ) -> BaseVectorStoreAuthCredentials:
        """Return explicit auth credentials for the vector store.

        Empty here: Bedrock KB authentication is done via AWS SigV4 request
        signing (see ``sign_request``), not via credential headers.
        """
        return {}
|
||||
|
||||
def get_vector_store_endpoints_by_type(self) -> VectorStoreIndexEndpoints:
|
||||
return {
|
||||
"read": [("POST", "/knowledgebases/{knowledge_base_id}/retrieve")],
|
||||
"write": [],
|
||||
}
|
||||
|
||||
def get_supported_openai_params(
|
||||
self, model: str
|
||||
) -> List[VECTOR_STORE_OPENAI_PARAMS]:
|
||||
return ["filters", "max_num_results", "ranking_options"]
|
||||
|
||||
def _map_operator_to_aws(self, operator: str) -> str:
|
||||
"""
|
||||
Map OpenAI-style operators to AWS Bedrock operator names.
|
||||
|
||||
OpenAI uses: eq, ne, gt, gte, lt, lte, in, nin
|
||||
AWS uses: equals, notEquals, greaterThan, greaterThanOrEquals, lessThan, lessThanOrEquals, in, notIn, startsWith, listContains, stringContains
|
||||
"""
|
||||
operator_mapping = {
|
||||
"eq": "equals",
|
||||
"ne": "notEquals",
|
||||
"gt": "greaterThan",
|
||||
"gte": "greaterThanOrEquals",
|
||||
"lt": "lessThan",
|
||||
"lte": "lessThanOrEquals",
|
||||
"in": "in",
|
||||
"nin": "notIn",
|
||||
# AWS-specific operators (pass through)
|
||||
"equals": "equals",
|
||||
"notEquals": "notEquals",
|
||||
"greaterThan": "greaterThan",
|
||||
"greaterThanOrEquals": "greaterThanOrEquals",
|
||||
"lessThan": "lessThan",
|
||||
"lessThanOrEquals": "lessThanOrEquals",
|
||||
"notIn": "notIn",
|
||||
"startsWith": "startsWith",
|
||||
"listContains": "listContains",
|
||||
"stringContains": "stringContains",
|
||||
}
|
||||
return operator_mapping.get(operator, operator)
|
||||
|
||||
def _map_operator_filter(self, filter_dict: dict) -> dict:
|
||||
"""
|
||||
Map a single OpenAI operator filter to AWS KB format.
|
||||
|
||||
OpenAI format: {"key": <key>, "value": <value>, "operator": <operator>}
|
||||
AWS KB format: {"operator": {"key": <key>, "value": <value>}}
|
||||
"""
|
||||
aws_operator = self._map_operator_to_aws(filter_dict["operator"])
|
||||
return {
|
||||
aws_operator: {
|
||||
"key": filter_dict["key"],
|
||||
"value": filter_dict["value"],
|
||||
}
|
||||
}
|
||||
|
||||
def _map_and_or_filters(self, value: dict) -> dict:
|
||||
"""
|
||||
Map OpenAI and/or filters to AWS KB format.
|
||||
|
||||
OpenAI format: {"and" | "or": [{"key": <key>, "value": <value>, "operator": <operator>}]}
|
||||
AWS KB format: {"andAll" | "orAll": [{"operator": {"key": <key>, "value": <value>}}]}
|
||||
|
||||
Note: AWS requires andAll/orAll to have at least 2 elements.
|
||||
For single filters, unwrap and return just the operator.
|
||||
"""
|
||||
aws_filters = {}
|
||||
|
||||
if "and" in value:
|
||||
and_filters = value["and"]
|
||||
# If only 1 filter, return just the operator (AWS requires andAll to have >=2 elements)
|
||||
if len(and_filters) == 1:
|
||||
return self._map_operator_filter(and_filters[0])
|
||||
|
||||
aws_filters["andAll"] = [
|
||||
{
|
||||
self._map_operator_to_aws(and_filters[i]["operator"]): {
|
||||
"key": and_filters[i]["key"],
|
||||
"value": and_filters[i]["value"],
|
||||
}
|
||||
}
|
||||
for i in range(len(and_filters))
|
||||
]
|
||||
|
||||
if "or" in value:
|
||||
or_filters = value["or"]
|
||||
# If only 1 filter, return just the operator (AWS requires orAll to have >=2 elements)
|
||||
if len(or_filters) == 1:
|
||||
return self._map_operator_filter(or_filters[0])
|
||||
|
||||
aws_filters["orAll"] = [
|
||||
{
|
||||
self._map_operator_to_aws(or_filters[i]["operator"]): {
|
||||
"key": or_filters[i]["key"],
|
||||
"value": or_filters[i]["value"],
|
||||
}
|
||||
}
|
||||
for i in range(len(or_filters))
|
||||
]
|
||||
|
||||
return aws_filters
|
||||
|
||||
    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        drop_params: bool,
    ) -> dict:
        """Translate OpenAI vector-store search params into Bedrock KB params.

        - ``max_num_results`` -> ``numberOfResults``
        - ``filters`` -> AWS KB retrieval filter format (see inline notes)

        NOTE(review): "ranking_options" is listed in
        get_supported_openai_params but has no mapping here, so it is silently
        dropped — confirm whether that is intentional.
        """
        for param, value in non_default_params.items():
            if param == "max_num_results":
                optional_params["numberOfResults"] = value
            elif param == "filters" and value is not None:
                # map the openai filters to the aws kb filters format
                # openai filters = {"key": <key>, "value": <value>, "operator": <operator>} OR {"and" | "or": [{"key": <key>, "value": <value>, "operator": <operator>}]}
                # aws kb filters = {"operator": {"<key>": <value>}} OR {"andAll | orAll": [{"operator": {"<key>": <value>}}]}
                # 1. check if filter is in openai format
                # 2. if it is, map it to the aws kb filters format
                # 3. if it is not, assume it is in aws kb filters format and add it to the optional_params
                aws_filters: Optional[Dict] = None

                if isinstance(value, dict):
                    if "operator" in value.keys():
                        # Single operator - map directly (no wrapping needed)
                        aws_filters = self._map_operator_filter(value)
                    elif "and" in value.keys() or "or" in value.keys():
                        aws_filters = self._map_and_or_filters(value)
                    else:
                        # Assume it's already in AWS KB format
                        aws_filters = value
                # NOTE(review): a non-dict filter value leaves aws_filters as
                # None and stores filters=None — confirm this is intended.
                optional_params["filters"] = aws_filters

        return optional_params
|
||||
|
||||
def validate_environment(
|
||||
self, headers: dict, litellm_params: Optional[GenericLiteLLMParams]
|
||||
) -> dict:
|
||||
headers = headers or {}
|
||||
headers.setdefault("Content-Type", "application/json")
|
||||
return headers
|
||||
|
||||
    def get_complete_url(self, api_base: Optional[str], litellm_params: dict) -> str:
        """Build the base URL for Bedrock Knowledge Base calls.

        Resolves the regional endpoint and appends the ``/knowledgebases``
        path prefix; per-KB paths are added by the request transform.
        """
        aws_region_name = litellm_params.get("aws_region_name")
        # endpoint_type="agent": KB retrieval is served by bedrock-agent-runtime,
        # not the model-invocation runtime endpoint.
        endpoint_url, _ = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=litellm_params.get(
                "aws_bedrock_runtime_endpoint"
            ),
            aws_region_name=self.get_aws_region_name_for_non_llm_api_calls(
                aws_region_name=aws_region_name
            ),
            endpoint_type="agent",
        )
        return f"{endpoint_url}/knowledgebases"
|
||||
|
||||
    def transform_search_vector_store_request(
        self,
        vector_store_id: str,
        query: Union[str, List[str]],
        vector_store_search_optional_params: VectorStoreSearchOptionalRequestParams,
        api_base: str,
        litellm_logging_obj: LiteLLMLoggingObj,
        litellm_params: dict,
    ) -> Tuple[str, Dict]:
        """Build the Bedrock KB ``retrieve`` URL and request body.

        Args:
            vector_store_id: the Bedrock knowledge base id.
            query: search text; a list of queries is joined into one string
                since the retrieve API takes a single text query.
            vector_store_search_optional_params: optional max_num_results /
                filters already mapped to Bedrock names by map_openai_params.
            api_base: base URL ending in /knowledgebases (from get_complete_url).
            litellm_logging_obj: the query text is recorded on it for logging.
            litellm_params: unused here.

        Returns:
            (url, request_body) for the POST to /{kb_id}/retrieve.
        """
        if isinstance(query, list):
            query = " ".join(query)

        url = f"{api_base}/{vector_store_id}/retrieve"

        request_body: Dict[str, Any] = {
            "retrievalQuery": BedrockKBRetrievalQuery(text=query),
        }

        # accumulate vectorSearchConfiguration entries only if any are set,
        # so an empty configuration key is never sent
        retrieval_config: Dict[str, Any] = {}
        max_results = vector_store_search_optional_params.get("max_num_results")
        if max_results is not None:
            retrieval_config.setdefault("vectorSearchConfiguration", {})[
                "numberOfResults"
            ] = max_results
        filters = vector_store_search_optional_params.get("filters")
        if filters is not None:
            retrieval_config.setdefault("vectorSearchConfiguration", {})[
                "filter"
            ] = filters
        if retrieval_config:
            # Create a properly typed retrieval configuration
            typed_retrieval_config: BedrockKBRetrievalConfiguration = {}
            if "vectorSearchConfiguration" in retrieval_config:
                typed_retrieval_config["vectorSearchConfiguration"] = retrieval_config[
                    "vectorSearchConfiguration"
                ]
            request_body["retrievalConfiguration"] = typed_retrieval_config

        # record the (possibly joined) query for downstream logging
        litellm_logging_obj.model_call_details["query"] = query
        return url, request_body
|
||||
|
||||
    def sign_request(
        self,
        headers: dict,
        optional_params: Dict,
        request_data: Dict,
        api_base: str,
        api_key: Optional[str] = None,
    ) -> Tuple[dict, Optional[bytes]]:
        """SigV4-sign the outgoing KB request.

        Thin wrapper over BaseAWSLLM._sign_request with service name "bedrock";
        returns the signed headers and the serialized body bytes.
        """
        return self._sign_request(
            service_name="bedrock",
            headers=headers,
            optional_params=optional_params,
            request_data=request_data,
            api_base=api_base,
            api_key=api_key,
        )
|
||||
|
||||
def _get_file_id_from_metadata(self, metadata: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Extract file_id from Bedrock KB metadata.
|
||||
Uses source URI if available, otherwise generates a fallback ID.
|
||||
"""
|
||||
source_uri = metadata.get("x-amz-bedrock-kb-source-uri", "") if metadata else ""
|
||||
if source_uri:
|
||||
return source_uri
|
||||
|
||||
chunk_id = (
|
||||
metadata.get("x-amz-bedrock-kb-chunk-id", "unknown")
|
||||
if metadata
|
||||
else "unknown"
|
||||
)
|
||||
return f"bedrock-kb-{chunk_id}"
|
||||
|
||||
def _get_filename_from_metadata(self, metadata: Dict[str, Any]) -> str:
    """
    Derive a display filename for a Bedrock KB search result.

    Uses the last path segment of the source URI, falling back to the
    URI's host when the path is empty or "/"; if the URI cannot be
    parsed the raw URI string is returned. Without a source URI, a
    synthetic name is built from the data source ID.
    """
    source_uri = metadata.get("x-amz-bedrock-kb-source-uri", "") if metadata else ""

    if source_uri:
        try:
            parsed = urlparse(source_uri)
            # Last path segment names the document; an empty or "/"
            # path means the URI points at the host itself.
            name = parsed.netloc
            if parsed.path and parsed.path != "/":
                name = parsed.path.split("/")[-1]
            if not name or name == "/":
                name = parsed.netloc
            return name
        except Exception:
            # Unparseable URI: surface the raw string rather than fail.
            return source_uri

    data_source_id = (
        metadata.get("x-amz-bedrock-kb-data-source-id", "unknown")
        if metadata
        else "unknown"
    )
    return f"bedrock-kb-document-{data_source_id}"
|
||||
def _get_attributes_from_metadata(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a shallow copy of the Bedrock KB metadata as result attributes.

    A falsy metadata value (None or empty) yields a fresh empty dict so
    callers never share or mutate the original mapping.
    """
    return dict(metadata) if metadata else {}
||||
|
||||
def transform_search_vector_store_response(
    self, response: httpx.Response, litellm_logging_obj: LiteLLMLoggingObj
) -> VectorStoreSearchResponse:
    """
    Convert a Bedrock Knowledge Base retrieve response into the
    OpenAI-style vector store search response.

    Results without text content are skipped. The search query is read
    back from `litellm_logging_obj.model_call_details` (stored by the
    request transformation). Any failure while parsing is re-raised as
    the provider error class carrying the HTTP status and headers.
    """
    try:
        parsed = BedrockKBResponse(**response.json())
        search_results: List[VectorStoreSearchResult] = []
        for hit in parsed.get("retrievalResults", []) or []:
            hit_content: Optional[BedrockKBContent] = hit.get("content")
            hit_text = hit_content.get("text") if hit_content else None
            if hit_text is None:
                # No text to surface for this hit — drop it.
                continue

            hit_metadata = hit.get("metadata", {}) or {}
            search_results.append(
                VectorStoreSearchResult(
                    score=hit.get("score"),
                    content=[VectorStoreResultContent(text=hit_text, type="text")],
                    file_id=self._get_file_id_from_metadata(hit_metadata),
                    filename=self._get_filename_from_metadata(hit_metadata),
                    attributes=self._get_attributes_from_metadata(hit_metadata),
                )
            )
        return VectorStoreSearchResponse(
            object="vector_store.search_results.page",
            search_query=litellm_logging_obj.model_call_details.get("query", ""),
            data=search_results,
        )
    except Exception as e:
        raise self.get_error_class(
            error_message=str(e),
            status_code=response.status_code,
            headers=response.headers,
        )
|
||||
# Vector store creation is not yet implemented
def transform_create_vector_store_request(
    self,
    vector_store_create_optional_params,
    api_base: str,
) -> Tuple[str, Dict]:
    """
    Build the URL and request body for creating a vector store.

    Creating vector stores is not supported for Bedrock Knowledge
    Bases; this always raises NotImplementedError.
    """
    raise NotImplementedError
||||
|
||||
def transform_create_vector_store_response(self, response: httpx.Response):
    """
    Parse the provider response from a vector store creation call.

    Creating vector stores is not supported for Bedrock Knowledge
    Bases; this always raises NotImplementedError.
    """
    raise NotImplementedError
||||
Reference in New Issue
Block a user