# 1508 lines, 52 KiB, Python  (viewer metadata header, not part of the module)
import base64
|
|
import datetime
|
|
import hashlib
|
|
import json
|
|
from dataclasses import dataclass
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
AsyncIterator,
|
|
Dict,
|
|
List,
|
|
Optional,
|
|
Protocol,
|
|
Tuple,
|
|
Union,
|
|
)
|
|
from urllib.parse import urlparse
|
|
|
|
import httpx
|
|
|
|
import litellm
|
|
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
|
|
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
|
from litellm.llms.custom_httpx.http_handler import (
|
|
AsyncHTTPHandler,
|
|
HTTPHandler,
|
|
_get_httpx_client,
|
|
get_async_httpx_client,
|
|
version,
|
|
)
|
|
from litellm.llms.oci.common_utils import OCIError
|
|
from litellm.types.llms.oci import (
|
|
CohereChatRequest,
|
|
CohereMessage,
|
|
CohereChatResult,
|
|
CohereParameterDefinition,
|
|
CohereStreamChunk,
|
|
CohereTool,
|
|
CohereToolCall,
|
|
OCIChatRequestPayload,
|
|
OCICompletionPayload,
|
|
OCICompletionResponse,
|
|
OCIContentPartUnion,
|
|
OCIImageContentPart,
|
|
OCIImageUrl,
|
|
OCIMessage,
|
|
OCIRoles,
|
|
OCIServingMode,
|
|
OCIStreamChunk,
|
|
OCITextContentPart,
|
|
OCIToolCall,
|
|
OCIToolDefinition,
|
|
OCIVendors,
|
|
)
|
|
from litellm.types.llms.openai import AllMessageValues
|
|
from litellm.types.utils import (
|
|
Delta,
|
|
LlmProviders,
|
|
ModelResponse,
|
|
ModelResponseStream,
|
|
StreamingChoices,
|
|
)
|
|
from litellm.utils import (
|
|
ChatCompletionMessageToolCall,
|
|
CustomStreamWrapper,
|
|
Usage,
|
|
)
|
|
|
|
# Give type checkers the real Logging class while keeping the runtime import
# lazy; at runtime the alias degrades to Any.
if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any
|
|
|
|
|
|
class OCISignerProtocol(Protocol):
    """
    Protocol for OCI request signers (e.g., oci.signer.Signer).

    This protocol defines the interface expected for OCI SDK signer objects.
    Compatible with the OCI Python SDK's Signer class.

    See: https://docs.oracle.com/en-us/iaas/tools/python/latest/api/signing.html
    """

    def do_request_sign(
        self, request: Any, *, enforce_content_headers: bool = False
    ) -> None:
        """
        Sign an HTTP request by adding authentication headers in place.

        Args:
            request: Request object with method, url, headers, body, and path_url attributes.
                The signer is expected to mutate ``request.headers``.
            enforce_content_headers: Whether to enforce content-type and content-length headers
        """
        ...
|
|
|
|
|
|
@dataclass
class OCIRequestWrapper:
    """
    Wrapper for HTTP requests compatible with the OCI signer interface.

    OCI SDK signers expect an object exposing ``method``, ``url``, ``headers``,
    ``body``, and ``path_url``; this dataclass provides exactly that shape.
    """

    method: str
    url: str
    headers: dict
    body: bytes

    @property
    def path_url(self) -> str:
        """Path plus query string of ``url``, as required for OCI signing."""
        parts = urlparse(self.url)
        if parts.query:
            return f"{parts.path}?{parts.query}"
        return parts.path
|
|
|
|
|
|
def sha256_base64(data: bytes) -> str:
    """Return the base64-encoded SHA-256 digest of *data* (used for the
    ``x-content-sha256`` signing header)."""
    return base64.b64encode(hashlib.sha256(data).digest()).decode()
|
|
|
|
|
|
def build_signature_string(method, path, headers, signed_headers):
    """Build the newline-joined signing string for the OCI HTTP Signature
    scheme: one ``name: value`` line per entry of *signed_headers*, where the
    pseudo-header ``(request-target)`` expands to ``<method-lowercase> <path>``."""

    def render(name: str) -> str:
        if name == "(request-target)":
            return f"{name}: {method.lower()} {path}"
        return f"{name}: {headers[name]}"

    return "\n".join(render(name) for name in signed_headers)
|
|
|
|
|
|
def load_private_key_from_str(key_str: str):
    """Parse a PEM-encoded, unencrypted RSA private key from a string.

    Raises:
        ImportError: If the ``cryptography`` package is not installed.
        TypeError: If the key parses but is not an RSA key (OCI signing
            requires RSA).
    """
    try:
        from cryptography.hazmat.primitives import serialization
        from cryptography.hazmat.primitives.asymmetric import rsa
    except ImportError as e:
        raise ImportError(
            "cryptography package is required for OCI authentication. "
            "Please install it with: pip install cryptography"
        ) from e

    loaded = serialization.load_pem_private_key(key_str.encode("utf-8"), password=None)
    if not isinstance(loaded, rsa.RSAPrivateKey):
        raise TypeError(
            "The provided private key is not an RSA key, which is required for OCI signing."
        )
    return loaded
|
|
|
|
|
|
def load_private_key_from_file(file_path: str):
    """Loads a private key from a file path.

    Raises:
        FileNotFoundError: If *file_path* does not exist.
        OSError: If the file exists but cannot be read.
        ValueError: If the file is empty after stripping whitespace.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            pem_text = handle.read().strip()
    except FileNotFoundError:
        raise FileNotFoundError(f"Private key file not found: {file_path}")
    except OSError as e:
        raise OSError(f"Failed to read private key file '{file_path}': {e}") from e

    if not pem_text:
        raise ValueError(f"Private key file is empty: {file_path}")

    return load_private_key_from_str(pem_text)
|
|
|
|
|
|
def get_vendor_from_model(model: str) -> OCIVendors:
    """
    Extracts the vendor from the model name.

    The vendor is the (case-insensitive) prefix before the first ``.``;
    anything other than ``cohere`` maps to the generic vendor.

    Args:
        model (str): The model name.
    Returns:
        OCIVendors: The vendor enum member.
    """
    prefix = model.split(".", 1)[0].lower()
    return OCIVendors.COHERE if prefix == "cohere" else OCIVendors.GENERIC
|
|
|
|
|
|
# 5 minute timeout (models may need to load)
# Seconds; passed as the httpx timeout for streaming requests.
STREAMING_TIMEOUT = 60 * 5
|
|
|
|
|
|
class OCIChatConfig(BaseConfig):
    """
    Configuration class for OCI's API interface.

    Maps OpenAI-style chat parameters/messages to the OCI Generative AI
    payload formats (generic and Cohere), signs requests (via an OCI SDK
    signer or manual credentials), and transforms OCI responses back into
    litellm ModelResponse objects.
    """
|
|
|
|
def __init__(
    self,
) -> None:
    """Initialize the OpenAI->OCI parameter maps and mark the class as using
    a custom stream wrapper.

    NOTE: following the litellm BaseConfig convention, constructor kwargs
    (none today) are promoted to CLASS attributes via setattr on
    ``self.__class__``.
    """
    locals_ = locals().copy()
    for key, value in locals_.items():
        if key != "self" and value is not None:
            setattr(self.__class__, key, value)
    # mark the class as using a custom stream wrapper because the default only iterates on lines
    setattr(self.__class__, "has_custom_stream_wrapper", True)

    # OpenAI param name -> OCI camelCase name. A value of False marks a
    # parameter that OCI does not support at all (raised or dropped in
    # map_openai_params depending on drop_params).
    self.openai_to_oci_generic_param_map = {
        "stream": "isStream",
        "max_tokens": "maxTokens",
        "max_completion_tokens": "maxTokens",
        "temperature": "temperature",
        "tools": "tools",
        "frequency_penalty": "frequencyPenalty",
        "logprobs": "logProbs",
        "logit_bias": "logitBias",
        "n": "numGenerations",
        "presence_penalty": "presencePenalty",
        "seed": "seed",
        "stop": "stop",
        "tool_choice": "toolChoice",
        "top_p": "topP",
        "max_retries": False,
        "top_logprobs": False,
        "modalities": False,
        "prediction": False,
        "stream_options": False,
        "function_call": False,
        "functions": False,
        "extra_headers": False,
        "parallel_tool_calls": False,
        "audio": False,
        "web_search_options": False,
        "response_format": "responseFormat",
    }

    # Cohere and Gemini use the same parameter mapping as GENERIC
    self.openai_to_oci_cohere_param_map = (
        self.openai_to_oci_generic_param_map.copy()
    )
|
|
|
|
def get_supported_openai_params(self, model: str) -> List[str]:
    """Return the OpenAI parameter names supported for *model* on OCI.

    Entries whose mapped value is False are explicitly unsupported and are
    excluded. For Cohere models, ``tool_choice`` and ``max_retries`` are
    additionally excluded.

    Args:
        model: Model name; its vendor prefix selects the parameter map.

    Returns:
        List of supported OpenAI parameter names.
    """
    vendor = get_vendor_from_model(model)
    if vendor == OCIVendors.COHERE:
        # BUG FIX: the original popped keys from the SHARED instance map,
        # so a second call for a Cohere model raised KeyError and the
        # mutation leaked into map_openai_params/_get_optional_params.
        # Work on a copy and pop with a default instead.
        open_ai_to_oci_param_map = dict(self.openai_to_oci_cohere_param_map)
        open_ai_to_oci_param_map.pop("tool_choice", None)
        open_ai_to_oci_param_map.pop("max_retries", None)
    else:
        open_ai_to_oci_param_map = self.openai_to_oci_generic_param_map

    return [key for key, value in open_ai_to_oci_param_map.items() if value]
|
|
|
|
def map_openai_params(
    self,
    non_default_params: dict,
    optional_params: dict,
    model: str,
    drop_params: bool,
) -> dict:
    """Translate OpenAI-style parameters to their OCI equivalents.

    Parameters mapped to False are unsupported on OCI: they are dropped
    when ``drop_params`` (or ``litellm.drop_params``) is set, otherwise an
    exception is raised. Unmapped parameters pass through unchanged.

    Args:
        non_default_params: Parameters explicitly set by the caller.
        optional_params: Additional parameters (override on key collision).
        model: Model name; its vendor prefix selects the parameter map.
        drop_params: Whether unsupported parameters are silently dropped.

    Returns:
        Dict of parameters keyed by their OCI (camelCase) names.
    """
    vendor = get_vendor_from_model(model)
    param_map = (
        self.openai_to_oci_cohere_param_map
        if vendor == OCIVendors.COHERE
        else self.openai_to_oci_generic_param_map
    )

    merged = {**non_default_params, **optional_params}
    mapped: dict = {}

    for key, value in merged.items():
        alias = param_map.get(key)

        if alias is False:
            # Explicitly unsupported on OCI.
            if drop_params or litellm.drop_params:
                continue
            raise Exception(f"param `{key}` is not supported on OCI")

        if alias is None:
            # Unknown parameter: pass through unchanged.
            mapped[key] = value
            continue

        mapped[alias] = value
        if alias == "responseFormat":
            # Keep the original key as well; downstream reads both forms.
            mapped["response_format"] = value

    return mapped
|
|
|
|
def _sign_with_oci_signer(
    self,
    headers: dict,
    optional_params: dict,
    request_data: dict,
    api_base: str,
) -> Tuple[dict, bytes]:
    """
    Sign request using OCI SDK Signer object.

    Args:
        headers: Request headers to be signed (mutated in place and returned)
        optional_params: Optional parameters including oci_signer
        request_data: The request body dict to be sent in HTTP request
        api_base: The complete URL for the HTTP request

    Returns:
        Tuple of (signed_headers, encoded_body)

    Raises:
        OCIError: If signing fails
        ValueError: If oci_signer is missing or the HTTP method is unsupported
    """
    oci_signer = optional_params.get("oci_signer")
    # FIX: validate the signer up front — the original serialized the body
    # and built the request wrapper before discovering the signer was None.
    if oci_signer is None:
        raise ValueError(
            "oci_signer cannot be None when calling _sign_with_oci_signer"
        )

    method = str(optional_params.get("method", "POST")).upper()
    if method not in ["POST", "GET", "PUT", "DELETE", "PATCH"]:
        raise ValueError(f"Unsupported HTTP method: {method}")

    body = json.dumps(request_data).encode("utf-8")

    # Copy so the signer's header mutations are merged back explicitly below.
    prepared_headers = headers.copy()
    prepared_headers.setdefault("content-type", "application/json")
    prepared_headers.setdefault("content-length", str(len(body)))

    request_wrapper = OCIRequestWrapper(
        method=method, url=api_base, headers=prepared_headers, body=body
    )

    try:
        oci_signer.do_request_sign(request_wrapper, enforce_content_headers=True)
    except Exception as e:
        raise OCIError(
            status_code=500,
            message=(
                f"Failed to sign request with provided oci_signer: {str(e)}. "
                "The signer must implement the OCI SDK Signer interface with a "
                "do_request_sign(request, enforce_content_headers=True) method. "
                "See: https://docs.oracle.com/en-us/iaas/tools/python/latest/api/signing.html"
            ),
        ) from e

    # The signer mutates request_wrapper.headers; merge the result back.
    headers.update(request_wrapper.headers)
    return headers, body
|
|
|
|
def _sign_with_manual_credentials(
    self,
    headers: dict,
    optional_params: dict,
    request_data: dict,
    api_base: str,
) -> Tuple[dict, None]:
    """
    Sign request using manual OCI credentials (HTTP Signature scheme,
    rsa-sha256 over the canonical signing string).

    Args:
        headers: Request headers to be signed (mutated in place and returned)
        optional_params: Optional parameters including OCI credentials
            (oci_user, oci_fingerprint, oci_tenancy, oci_key/oci_key_file,
            optional oci_region and method)
        request_data: The request body dict to be sent in HTTP request
        api_base: The complete URL for the HTTP request

    Returns:
        Tuple of (signed_headers, None)

    Raises:
        Exception: If required credentials are missing
        ImportError: If cryptography package is not installed
        OCIError: If the private key is missing or of the wrong type
    """
    oci_region = optional_params.get("oci_region", "us-ashburn-1")
    api_base = (
        api_base
        or litellm.api_base
        or f"https://inference.generativeai.{oci_region}.oci.oraclecloud.com"
    )
    oci_user = optional_params.get("oci_user")
    oci_fingerprint = optional_params.get("oci_fingerprint")
    oci_tenancy = optional_params.get("oci_tenancy")
    oci_key = optional_params.get("oci_key")
    oci_key_file = optional_params.get("oci_key_file")

    if (
        not oci_user
        or not oci_fingerprint
        or not oci_tenancy
        or not (oci_key or oci_key_file)
    ):
        raise Exception(
            "Missing required parameters: oci_user, oci_fingerprint, oci_tenancy, "
            "and at least one of oci_key or oci_key_file."
        )

    method = str(optional_params.get("method", "POST")).upper()
    body = json.dumps(request_data).encode("utf-8")
    parsed = urlparse(api_base)
    path = parsed.path or "/"
    host = parsed.netloc

    # FIX: datetime.utcnow() is deprecated since Python 3.12; an aware
    # now(timezone.utc) renders the exact same RFC-7231 date string.
    date = datetime.datetime.now(datetime.timezone.utc).strftime(
        "%a, %d %b %Y %H:%M:%S GMT"
    )
    content_type = headers.get("content-type", "application/json")
    content_length = str(len(body))
    x_content_sha256 = sha256_base64(body)

    headers_to_sign = {
        "date": date,
        "host": host,
        "content-type": content_type,
        "content-length": content_length,
        "x-content-sha256": x_content_sha256,
    }

    # Order matters: it must match the `headers="..."` list in the
    # authorization header below.
    signed_headers = [
        "date",
        "(request-target)",
        "host",
        "content-length",
        "content-type",
        "x-content-sha256",
    ]
    signing_string = build_signature_string(
        method, path, headers_to_sign, signed_headers
    )

    try:
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.asymmetric import padding
    except ImportError as e:
        raise ImportError(
            "cryptography package is required for OCI authentication. "
            "Please install it with: pip install cryptography"
        ) from e

    # Handle oci_key - it should be a string (PEM content)
    oci_key_content = None
    if oci_key:
        if isinstance(oci_key, str):
            oci_key_content = oci_key
            # Fix common issues with PEM content:
            # replace escaped newlines with actual newlines
            oci_key_content = oci_key_content.replace("\\n", "\n")
            # and normalize CRLF line endings
            if "\r\n" in oci_key_content:
                oci_key_content = oci_key_content.replace("\r\n", "\n")
        else:
            raise OCIError(
                status_code=400,
                message=f"oci_key must be a string containing the PEM private key content. "
                f"Got type: {type(oci_key).__name__}",
            )

    # Inline key content takes precedence over a key file path.
    private_key = (
        load_private_key_from_str(oci_key_content)
        if oci_key_content
        else load_private_key_from_file(oci_key_file)
        if oci_key_file
        else None
    )

    if private_key is None:
        raise OCIError(
            status_code=400,
            message="Private key is required for OCI authentication. Please provide either oci_key or oci_key_file.",
        )

    signature = private_key.sign(
        signing_string.encode("utf-8"),
        padding.PKCS1v15(),
        hashes.SHA256(),
    )
    signature_b64 = base64.b64encode(signature).decode()

    key_id = f"{oci_tenancy}/{oci_user}/{oci_fingerprint}"

    authorization = (
        'Signature version="1",'
        f'keyId="{key_id}",'
        'algorithm="rsa-sha256",'
        f'headers="{" ".join(signed_headers)}",'
        f'signature="{signature_b64}"'
    )

    headers.update(
        {
            "authorization": authorization,
            "date": date,
            "host": host,
            "content-type": content_type,
            "content-length": content_length,
            "x-content-sha256": x_content_sha256,
        }
    )

    return headers, None
|
|
|
|
def sign_request(
    self,
    headers: dict,
    optional_params: dict,
    request_data: dict,
    api_base: str,
    api_key: Optional[str] = None,
    model: Optional[str] = None,
    stream: Optional[bool] = None,
    fake_stream: Optional[bool] = None,
) -> Tuple[dict, Optional[bytes]]:
    """
    Add OCI authentication headers to a request.

    Dispatches between the two supported signing modes:
      1. SDK signer — when ``optional_params["oci_signer"]`` is set,
         delegates to ``_sign_with_oci_signer`` and returns
         (headers, encoded_body).
      2. Manual credentials — otherwise delegates to
         ``_sign_with_manual_credentials`` and returns (headers, None).

    Args:
        headers: Request headers to be signed.
        optional_params: Auth credentials or an ``oci_signer`` object.
        request_data: The request body dict to be sent in the HTTP request.
        api_base: The complete URL for the HTTP request.
        api_key: Optional API key (not used for OCI).
        model: Optional model name (unused here).
        stream: Optional streaming flag (unused here).
        fake_stream: Optional fake-streaming flag (unused here).

    Returns:
        Tuple of (signed_headers, encoded_body-or-None).

    Raises:
        OCIError: If signing fails with oci_signer.
        Exception: If required manual credentials are missing.
        ImportError: If the cryptography package is not installed
            (manual signing only).

    Example:
        >>> from oci.signer import Signer
        >>> signer = Signer(
        ...     tenancy="ocid1.tenancy.oc1..",
        ...     user="ocid1.user.oc1..",
        ...     fingerprint="xx:xx:xx",
        ...     private_key_file_location="~/.oci/key.pem"
        ... )
        >>> headers, body = config.sign_request(
        ...     headers={},
        ...     optional_params={"oci_signer": signer},
        ...     request_data={"message": "Hello"},
        ...     api_base="https://inference.generativeai.us-ashburn-1.oci.oraclecloud.com/..."
        ... )
    """
    use_sdk_signer = optional_params.get("oci_signer") is not None
    sign = (
        self._sign_with_oci_signer
        if use_sdk_signer
        else self._sign_with_manual_credentials
    )
    return sign(headers, optional_params, request_data, api_base)
|
|
|
|
def validate_environment(
    self,
    headers: dict,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    api_key: Optional[str] = None,
    api_base: Optional[str] = None,
) -> dict:
    """
    Validate the OCI environment and credentials, and populate base headers.

    Two authentication modes are accepted:
      1. OCI SDK signer: an ``oci_signer`` object in optional_params.
      2. Manual credentials: oci_user, oci_fingerprint, oci_tenancy,
         oci_compartment_id, and oci_key or oci_key_file.

    Args:
        headers: Request headers to populate (mutated in place).
        model: Model name (unused here).
        messages: Chat messages; must be non-empty.
        optional_params: Authentication credentials / signer / region.
        litellm_params: LiteLLM parameters (unused here).
        api_key: Optional API key (not used for OCI).
        api_base: Optional API base URL.

    Returns:
        The updated headers dict.

    Raises:
        Exception: If required credentials are missing or messages is empty.
    """
    oci_signer = optional_params.get("oci_signer")
    oci_region = optional_params.get("oci_region", "us-ashburn-1")

    resolved_base = (
        api_base
        or litellm.api_base
        or f"https://inference.generativeai.{oci_region}.oci.oraclecloud.com"
    )
    # NOTE: unreachable in practice — the f-string fallback above is always
    # truthy. Kept to preserve the original behavior/message verbatim.
    if not resolved_base:
        raise Exception(
            "Either `api_base` must be provided or `litellm.api_base` must be set. "
            "Alternatively, you can set the `oci_region` optional parameter to use the default OCI region."
        )

    # Manual credentials are only required when no SDK signer is supplied.
    if oci_signer is None:
        required_values = (
            optional_params.get("oci_user"),
            optional_params.get("oci_fingerprint"),
            optional_params.get("oci_tenancy"),
            optional_params.get("oci_key") or optional_params.get("oci_key_file"),
            optional_params.get("oci_compartment_id"),
        )
        if not all(required_values):
            raise Exception(
                "Missing required parameters: oci_user, oci_fingerprint, oci_tenancy, oci_compartment_id "
                "and at least one of oci_key or oci_key_file. "
                "Alternatively, provide an oci_signer object from the OCI SDK."
            )

    # Common header setup
    headers.update(
        {
            "content-type": "application/json",
            "user-agent": f"litellm/{version}",
        }
    )

    if not messages:
        raise Exception(
            "kwarg `messages` must be an array of messages that follow the openai chat standard"
        )

    return headers
|
|
|
|
def get_complete_url(
    self,
    api_base: Optional[str],
    api_key: Optional[str],
    model: str,
    optional_params: dict,
    litellm_params: dict,
    stream: Optional[bool] = None,
) -> str:
    """
    Build the OCI Generative AI chat endpoint URL.

    FIX: the original ignored both the ``api_base`` argument and
    ``litellm.api_base`` and always rebuilt the URL from ``oci_region``,
    inconsistent with validate_environment/_sign_with_manual_credentials.
    Now an explicit base is honored, with the regional default as fallback —
    identical output when no base is configured.

    Args:
        api_base: Optional explicit base URL.
        api_key: Unused for OCI.
        model: Model name (unused here).
        optional_params: May contain ``oci_region`` (default us-ashburn-1).
        litellm_params: LiteLLM parameters (unused here).
        stream: Optional streaming flag (unused here).

    Returns:
        The complete chat-action URL.
    """
    oci_region = optional_params.get("oci_region", "us-ashburn-1")
    base = (
        api_base
        or litellm.api_base
        or f"https://inference.generativeai.{oci_region}.oci.oraclecloud.com"
    )
    # rstrip guards against a trailing slash in a user-supplied base.
    return f"{base.rstrip('/')}/20231130/actions/chat"
|
|
|
|
def _get_optional_params(self, vendor: OCIVendors, optional_params: dict) -> Dict:
    """Select and normalize optional parameters for the OCI request payload.

    Maps OpenAI-named params to their OCI camelCase names, passes through
    already-mapped OCI names, adapts tool definitions to the vendor format,
    and rewrites response_format types to OCI's casing.

    Args:
        vendor: Target vendor (selects the map and tool/format handling).
        optional_params: Caller-supplied optional parameters.

    Returns:
        Dict of OCI-ready parameters (Cohere gets API defaults pre-filled).
    """
    selected_params = {}
    if vendor == OCIVendors.COHERE:
        # BUG FIX: the original popped "tool_choice" from the SHARED
        # instance map, so a second Cohere call raised KeyError and the
        # mutation leaked into the other methods using the same map.
        # Copy first and pop with a default.
        open_ai_to_oci_param_map = dict(self.openai_to_oci_cohere_param_map)
        open_ai_to_oci_param_map.pop("tool_choice", None)
        # Add default values for Cohere API
        selected_params = {
            "maxTokens": 600,
            "temperature": 1,
            "topK": 0,
            "topP": 0.75,
            "frequencyPenalty": 0,
        }
    else:
        open_ai_to_oci_param_map = self.openai_to_oci_generic_param_map

    # Map OpenAI params to OCI params
    for openai_key, oci_key in open_ai_to_oci_param_map.items():
        if oci_key and openai_key in optional_params:
            selected_params[oci_key] = optional_params[openai_key]  # type: ignore[index]

    # Also check for already-mapped OCI params (for backward compatibility)
    for oci_value in open_ai_to_oci_param_map.values():
        if (
            oci_value
            and oci_value in optional_params
            and oci_value not in selected_params
        ):
            selected_params[oci_value] = optional_params[oci_value]  # type: ignore[index]

    # Adapt OpenAI tool definitions to the vendor-specific schema.
    if "tools" in selected_params:
        if vendor == OCIVendors.COHERE:
            selected_params["tools"] = self.adapt_tool_definitions_to_cohere_standard(  # type: ignore[assignment]
                selected_params["tools"]  # type: ignore[arg-type]
            )
        else:
            selected_params["tools"] = adapt_tool_definition_to_oci_standard(  # type: ignore[assignment]
                selected_params["tools"], vendor  # type: ignore[arg-type]
            )

    # Transform response_format type to OCI uppercase format
    if "responseFormat" in selected_params:
        rf = selected_params["responseFormat"]
        if isinstance(rf, dict) and "type" in rf:
            rf_payload = dict(rf)
            selected_params["responseFormat"] = rf_payload

            response_type = rf_payload["type"]
            schema_payload: Optional[Any] = None

            # OCI expects camelCase "jsonSchema" instead of "json_schema".
            if "json_schema" in rf_payload:
                raw_schema_payload = rf_payload.pop("json_schema")
                if isinstance(raw_schema_payload, dict):
                    schema_payload = dict(raw_schema_payload)
                else:
                    schema_payload = raw_schema_payload

            if schema_payload is not None:
                rf_payload["jsonSchema"] = schema_payload

            if vendor == OCIVendors.COHERE:
                # Cohere expects lower-case type values
                rf_payload["type"] = response_type
            else:
                format_type = response_type.upper()
                if format_type == "JSON":
                    format_type = "JSON_OBJECT"
                rf_payload["type"] = format_type

    return selected_params
|
|
|
|
def adapt_messages_to_cohere_standard(
    self, messages: List[AllMessageValues]
) -> List[CohereMessage]:
    """Build chat history for Cohere models.

    Converts all messages EXCEPT the last one (the final message is sent
    separately as the Cohere `message` field by transform_request) into
    CohereMessage entries: user -> USER, assistant -> CHATBOT (with tool
    calls), tool -> TOOL. Messages with other roles (e.g. system) are
    skipped here.
    """
    chat_history = []
    for msg in messages[:-1]:  # All messages except the last one
        role = msg.get("role")
        content = msg.get("content")

        if isinstance(content, list):
            # Extract text from content array (non-text parts are dropped)
            text_content = ""
            for content_item in content:
                if (
                    isinstance(content_item, dict)
                    and content_item.get("type") == "text"
                ):
                    text_content += content_item.get("text", "")
            content = text_content

        # Ensure content is a string
        if not isinstance(content, str):
            content = str(content) if content is not None else ""

        # Handle tool calls
        tool_calls: Optional[List[CohereToolCall]] = None
        if role == "assistant" and "tool_calls" in msg and msg.get("tool_calls"):  # type: ignore[union-attr,typeddict-item]
            tool_calls = []
            for tool_call in msg["tool_calls"]:  # type: ignore[union-attr,typeddict-item]
                # Parse arguments if they're a JSON string; unparseable
                # argument strings degrade to an empty dict.
                raw_arguments: Any = tool_call.get("function", {}).get(
                    "arguments", {}
                )
                if isinstance(raw_arguments, str):
                    try:
                        arguments: Dict[str, Any] = json.loads(raw_arguments)
                    except json.JSONDecodeError:
                        arguments = {}
                else:
                    arguments = raw_arguments

                tool_calls.append(
                    CohereToolCall(
                        name=str(tool_call.get("function", {}).get("name", "")),
                        parameters=arguments,
                    )
                )

        if role == "user":
            chat_history.append(CohereMessage(role="USER", message=content))
        elif role == "assistant":
            chat_history.append(
                CohereMessage(role="CHATBOT", message=content, toolCalls=tool_calls)
            )
        elif role == "tool":
            # Tool messages need special handling
            chat_history.append(
                CohereMessage(
                    role="TOOL",
                    message=content,
                    toolCalls=None,  # Tool messages don't have tool calls
                )
            )

    return chat_history
|
|
|
|
def adapt_tool_definitions_to_cohere_standard(
    self, tools: List[Dict[str, Any]]
) -> List[CohereTool]:
    """Adapt OpenAI-style tool definitions to the Cohere tool format.

    Each JSON-schema property becomes a CohereParameterDefinition with its
    required flag derived from the schema's `required` list.
    """
    adapted: List[CohereTool] = []
    for tool in tools:
        function_def = tool.get("function", {})
        schema = function_def.get("parameters", {})
        properties = schema.get("properties", {})
        required_names = schema.get("required", [])

        parameter_definitions = {
            param_name: CohereParameterDefinition(
                description=param_schema.get("description", ""),
                type=param_schema.get("type", "string"),
                isRequired=param_name in required_names,
            )
            for param_name, param_schema in properties.items()
        }

        adapted.append(
            CohereTool(
                name=function_def.get("name", ""),
                description=function_def.get("description", ""),
                parameterDefinitions=parameter_definitions,
            )
        )

    return adapted
|
|
|
|
def _extract_text_content(self, content: Any) -> str:
    """Extract plain text from OpenAI-style message content.

    Strings pass through; content-part lists are reduced to the
    concatenation of their "text" parts; anything else is stringified.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return str(content)
|
|
|
|
def transform_request(
    self,
    model: str,
    messages: List[AllMessageValues],
    optional_params: dict,
    litellm_params: dict,
    headers: dict,
) -> dict:
    """Build the OCI chat request payload from OpenAI-style inputs.

    Chooses the serving mode (ON_DEMAND vs DEDICATED) and, based on the
    vendor inferred from the model name, emits either a Cohere-specific
    chat request or the generic OCI chat payload.

    Args:
        model: Model name; the vendor prefix decides the payload shape.
        messages: OpenAI-style chat messages.
        optional_params: Must include `oci_compartment_id`; may include
            `oci_serving_mode` and `oci_endpoint_id`.
        litellm_params: LiteLLM parameters (unused here).
        headers: Request headers (unused here).

    Returns:
        The request body as a plain dict (None fields excluded).

    Raises:
        Exception: If `oci_compartment_id` is missing, the serving mode is
            invalid, or a Cohere model has no user message.
    """
    oci_compartment_id = optional_params.get("oci_compartment_id", None)
    if not oci_compartment_id:
        raise Exception("kwarg `oci_compartment_id` is required for OCI requests")

    vendor = get_vendor_from_model(model)

    oci_serving_mode = optional_params.get("oci_serving_mode", "ON_DEMAND")
    if oci_serving_mode not in ["ON_DEMAND", "DEDICATED"]:
        raise Exception(
            "kwarg `oci_serving_mode` must be either 'ON_DEMAND' or 'DEDICATED'"
        )

    if oci_serving_mode == "DEDICATED":
        # Dedicated endpoints are addressed by endpoint id (falls back to
        # the model name when `oci_endpoint_id` is not given).
        oci_endpoint_id = optional_params.get("oci_endpoint_id", model)
        servingMode = OCIServingMode(
            servingType="DEDICATED",
            endpointId=oci_endpoint_id,
        )
    else:
        servingMode = OCIServingMode(
            servingType="ON_DEMAND",
            modelId=model,
        )

    # Build request based on vendor type
    if vendor == OCIVendors.COHERE:
        # For Cohere, we need to use the specific Cohere format
        # Extract the last user message as the main message
        user_messages = [msg for msg in messages if msg.get("role") == "user"]
        if not user_messages:
            raise Exception("No user message found for Cohere model")

        # Extract system messages into preambleOverride
        system_messages = [msg for msg in messages if msg.get("role") == "system"]
        preamble_override = None
        if system_messages:
            preamble = "\n".join(
                self._extract_text_content(msg["content"])
                for msg in system_messages
            )
            if preamble:
                preamble_override = preamble

        # Create Cohere-specific chat request
        optional_cohere_params = self._get_optional_params(
            OCIVendors.COHERE, optional_params
        )
        chat_request = CohereChatRequest(
            apiFormat="COHERE",
            message=self._extract_text_content(user_messages[-1]["content"]),
            chatHistory=self.adapt_messages_to_cohere_standard(messages),
            preambleOverride=preamble_override,
            **optional_cohere_params,
        )

        data = OCICompletionPayload(
            compartmentId=oci_compartment_id,
            servingMode=servingMode,
            chatRequest=chat_request,
        )
    else:
        # Use generic format for other vendors
        data = OCICompletionPayload(
            compartmentId=oci_compartment_id,
            servingMode=servingMode,
            chatRequest=OCIChatRequestPayload(
                apiFormat=vendor.value,
                messages=adapt_messages_to_generic_oci_standard(messages),
                **self._get_optional_params(vendor, optional_params),
            ),
        )

    return data.model_dump(exclude_none=True)
|
|
|
|
def _handle_cohere_response(
    self, json_response: dict, model: str, model_response: ModelResponse
) -> ModelResponse:
    """Handle Cohere-specific response format.

    Populates *model_response* in place from the camelCase Cohere payload:
    text, finish reason (COMPLETE->stop, MAX_TOKENS->length, anything
    else->stop), tool calls, and token usage.
    """
    cohere_response = CohereChatResult(**json_response)
    # Cohere response format (uses camelCase)
    model_id = model

    # Set basic response info; Cohere does not echo a creation time, so
    # the local clock is used.
    model_response.model = model_id
    model_response.created = int(datetime.datetime.now().timestamp())

    # Extract the response text
    response_text = cohere_response.chatResponse.text
    oci_finish_reason = cohere_response.chatResponse.finishReason

    # Map finish reason
    if oci_finish_reason == "COMPLETE":
        finish_reason = "stop"
    elif oci_finish_reason == "MAX_TOKENS":
        finish_reason = "length"
    else:
        finish_reason = "stop"

    # Handle tool calls (Cohere returns name + parameters; re-serialize
    # parameters to a JSON string for the OpenAI shape)
    tool_calls: Optional[List[Dict[str, Any]]] = None
    if cohere_response.chatResponse.toolCalls:
        tool_calls = []
        for tool_call in cohere_response.chatResponse.toolCalls:
            tool_calls.append(
                {
                    "id": f"call_{len(tool_calls)}",  # Generate a simple ID
                    "type": "function",
                    "function": {
                        "name": tool_call.name,
                        "arguments": json.dumps(tool_call.parameters),
                    },
                }
            )

    # Create choice
    from litellm.types.utils import Choices

    choice = Choices(
        index=0,
        message={
            "role": "assistant",
            "content": response_text,
            "tool_calls": tool_calls,
        },
        finish_reason=finish_reason,
    )
    model_response.choices = [choice]

    # Extract usage info
    usage_info = cohere_response.chatResponse.usage
    from litellm.types.utils import Usage

    model_response.usage = Usage(  # type: ignore[attr-defined]
        prompt_tokens=usage_info.promptTokens,  # type: ignore[union-attr]
        completion_tokens=usage_info.completionTokens,  # type: ignore[union-attr]
        total_tokens=usage_info.totalTokens,  # type: ignore[union-attr]
    )

    return model_response
|
|
|
|
def _handle_generic_response(
    self,
    json: dict,
    model: str,
    model_response: ModelResponse,
    raw_response: httpx.Response,
) -> ModelResponse:
    """Handle generic OCI response format.

    Populates *model_response* in place: creation time, model id, first
    choice's text content, tool calls, and token usage.
    """
    try:
        completion_response = OCICompletionResponse(**json)
    except TypeError as e:
        # NOTE(review): if OCICompletionResponse is a pydantic model (the
        # payload classes use model_dump elsewhere), schema violations raise
        # pydantic.ValidationError, not TypeError — confirm this handler
        # catches the intended failure mode.
        raise OCIError(
            message=f"Response cannot be casted to OCICompletionResponse: {str(e)}",
            status_code=raw_response.status_code,
        )

    # Parse ISO-8601 creation time; "Z" is normalized to "+00:00" because
    # datetime.fromisoformat() did not accept "Z" before Python 3.11.
    iso_str = completion_response.chatResponse.timeCreated
    dt = datetime.datetime.fromisoformat(iso_str.replace("Z", "+00:00"))
    model_response.created = int(dt.timestamp())

    model_response.model = completion_response.modelId

    message = model_response.choices[0].message  # type: ignore
    response_message = completion_response.chatResponse.choices[0].message
    # Only the first content part is used, and only when it is TEXT.
    if response_message.content and response_message.content[0].type == "TEXT":
        message.content = response_message.content[0].text
    if response_message.toolCalls:
        message.tool_calls = adapt_tools_to_openai_standard(
            response_message.toolCalls
        )

    usage = Usage(
        prompt_tokens=completion_response.chatResponse.usage.promptTokens,
        completion_tokens=completion_response.chatResponse.usage.completionTokens,
        total_tokens=completion_response.chatResponse.usage.totalTokens,
    )
    model_response.usage = usage  # type: ignore

    return model_response
|
|
|
|
def transform_response(
|
|
self,
|
|
model: str,
|
|
raw_response: httpx.Response,
|
|
model_response: ModelResponse,
|
|
logging_obj: LiteLLMLoggingObj,
|
|
request_data: dict,
|
|
messages: List[AllMessageValues],
|
|
optional_params: dict,
|
|
litellm_params: dict,
|
|
encoding: Any,
|
|
api_key: Optional[str] = None,
|
|
json_mode: Optional[bool] = None,
|
|
) -> ModelResponse:
|
|
json = raw_response.json() # noqa: F811
|
|
|
|
error = json.get("error")
|
|
|
|
if error is not None:
|
|
raise OCIError(
|
|
message=str(json["error"]),
|
|
status_code=raw_response.status_code,
|
|
)
|
|
|
|
if not isinstance(json, dict):
|
|
raise OCIError(
|
|
message="Invalid response format from OCI",
|
|
status_code=raw_response.status_code,
|
|
)
|
|
|
|
vendor = get_vendor_from_model(model)
|
|
|
|
# Handle response based on vendor type
|
|
if vendor == OCIVendors.COHERE:
|
|
model_response = self._handle_cohere_response(json, model, model_response)
|
|
else:
|
|
model_response = self._handle_generic_response(
|
|
json, model, model_response, raw_response
|
|
)
|
|
|
|
model_response._hidden_params["additional_headers"] = raw_response.headers
|
|
|
|
return model_response
|
|
|
|
@track_llm_api_timing()
|
|
def get_sync_custom_stream_wrapper(
|
|
self,
|
|
model: str,
|
|
custom_llm_provider: str,
|
|
logging_obj: LiteLLMLoggingObj,
|
|
api_base: str,
|
|
headers: dict,
|
|
data: dict,
|
|
messages: list,
|
|
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
|
json_mode: Optional[bool] = None,
|
|
signed_json_body: Optional[bytes] = None,
|
|
) -> "OCIStreamWrapper":
|
|
if "stream" in data:
|
|
del data["stream"]
|
|
if client is None or isinstance(client, AsyncHTTPHandler):
|
|
client = _get_httpx_client(params={})
|
|
|
|
try:
|
|
response = client.post(
|
|
api_base,
|
|
headers=headers,
|
|
data=json.dumps(data),
|
|
stream=True,
|
|
logging_obj=logging_obj,
|
|
timeout=STREAMING_TIMEOUT,
|
|
)
|
|
except httpx.HTTPStatusError as e:
|
|
raise OCIError(status_code=e.response.status_code, message=e.response.text)
|
|
|
|
if response.status_code != 200:
|
|
raise OCIError(status_code=response.status_code, message=response.text)
|
|
|
|
completion_stream = response.iter_text()
|
|
|
|
streaming_response = OCIStreamWrapper(
|
|
completion_stream=completion_stream,
|
|
model=model,
|
|
custom_llm_provider=custom_llm_provider,
|
|
logging_obj=logging_obj,
|
|
)
|
|
return streaming_response
|
|
|
|
@track_llm_api_timing()
|
|
async def get_async_custom_stream_wrapper(
|
|
self,
|
|
model: str,
|
|
custom_llm_provider: str,
|
|
logging_obj: LiteLLMLoggingObj,
|
|
api_base: str,
|
|
headers: dict,
|
|
data: dict,
|
|
messages: list,
|
|
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
|
json_mode: Optional[bool] = None,
|
|
signed_json_body: Optional[bytes] = None,
|
|
) -> "OCIStreamWrapper":
|
|
if "stream" in data:
|
|
del data["stream"]
|
|
|
|
if client is None or isinstance(client, HTTPHandler):
|
|
client = get_async_httpx_client(llm_provider=LlmProviders.BYTEZ, params={})
|
|
|
|
try:
|
|
response = await client.post(
|
|
api_base,
|
|
headers=headers,
|
|
data=json.dumps(data),
|
|
stream=True,
|
|
logging_obj=logging_obj,
|
|
timeout=STREAMING_TIMEOUT,
|
|
)
|
|
except httpx.HTTPStatusError as e:
|
|
raise OCIError(status_code=e.response.status_code, message=e.response.text)
|
|
|
|
if response.status_code != 200:
|
|
raise OCIError(status_code=response.status_code, message=response.text)
|
|
|
|
completion_stream = response.aiter_text()
|
|
|
|
async def split_chunks(completion_stream: AsyncIterator[str]):
|
|
async for item in completion_stream:
|
|
for chunk in item.split("\n\n"):
|
|
if not chunk:
|
|
continue
|
|
yield chunk.strip()
|
|
|
|
streaming_response = OCIStreamWrapper(
|
|
completion_stream=split_chunks(completion_stream),
|
|
model=model,
|
|
custom_llm_provider=custom_llm_provider,
|
|
logging_obj=logging_obj,
|
|
)
|
|
return streaming_response
|
|
|
|
def get_error_class(
|
|
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]
|
|
) -> BaseLLMException:
|
|
return OCIError(status_code=status_code, message=error_message)
|
|
|
|
|
|
# Maps OpenAI chat roles onto the upper-case role names used by the generic
# OCI chat API.
open_ai_to_generic_oci_role_map: Dict[str, OCIRoles] = {
    "system": "SYSTEM",
    "user": "USER",
    "assistant": "ASSISTANT",
    "tool": "TOOL",
}
|
|
|
|
|
|
def adapt_messages_to_generic_oci_standard_content_message(
    role: str, content: Union[str, list]
) -> OCIMessage:
    """Convert OpenAI-style message content into a generic-format OCIMessage.

    Args:
        role: OpenAI chat role ("system" / "user" / "assistant" / "tool").
        content: either a plain string or a list of content-part dicts.

    Returns:
        An ``OCIMessage`` whose content parts are TEXT and/or IMAGE parts.

    Raises:
        Exception: if a content item is malformed or uses an unsupported type.
    """
    if isinstance(content, str):
        # Plain-string content maps to a single TEXT part.
        return OCIMessage(
            role=open_ai_to_generic_oci_role_map[role],
            content=[OCITextContentPart(text=content)],
            toolCalls=None,
            toolCallId=None,
        )

    # content is a list of content items:
    # [
    #     {"type": "text", "text": "Hello"},
    #     {"type": "image_url", "image_url": "https://example.com/image.png"}
    # ]
    new_content: List[OCIContentPartUnion] = []
    for content_item in content:
        if not isinstance(content_item, dict):
            raise Exception("Each content item must be a dictionary")

        # Renamed from `type` to avoid shadowing the builtin.
        content_type = content_item.get("type")
        if not isinstance(content_type, str):
            raise Exception("Prop `type` is not a string")

        if content_type not in ["text", "image_url"]:
            raise Exception(f"Prop `{content_type}` is not supported")

        if content_type == "text":
            text = content_item.get("text")
            if not isinstance(text, str):
                raise Exception("Prop `text` is not a string")
            new_content.append(OCITextContentPart(text=text))

        elif content_type == "image_url":
            image_url = content_item.get("image_url")
            # Handle both OpenAI format (object with url) and string format
            if isinstance(image_url, dict):
                image_url = image_url.get("url")
            if not isinstance(image_url, str):
                raise Exception(
                    "Prop `image_url` must be a string or an object with a `url` property"
                )
            new_content.append(OCIImageContentPart(imageUrl=OCIImageUrl(url=image_url)))

    return OCIMessage(
        role=open_ai_to_generic_oci_role_map[role],
        content=new_content,
        toolCalls=None,
        toolCallId=None,
    )
|
|
|
|
|
|
def adapt_messages_to_generic_oci_standard_tool_call(
    role: str, tool_calls: list
) -> OCIMessage:
    """Convert an OpenAI-style assistant tool-call list into an OCIMessage.

    Args:
        role: OpenAI chat role (expected to be "assistant").
        tool_calls: OpenAI-format tool-call dicts (`type` == "function").

    Returns:
        An ``OCIMessage`` carrying only ``toolCalls`` (no content).

    Raises:
        Exception: if any tool call is malformed or not a function tool.
    """
    formatted_tool_calls = []
    for tool_call in tool_calls:
        if not isinstance(tool_call, dict):
            raise Exception("Each tool call must be a dictionary")

        if tool_call.get("type") != "function":
            raise Exception("OCI only supports function tools")

        tool_call_id = tool_call.get("id")
        if not isinstance(tool_call_id, str):
            raise Exception("Prop `id` is not a string")

        tool_function = tool_call.get("function")
        if not isinstance(tool_function, dict):
            raise Exception("Prop `function` is not a dictionary")

        function_name = tool_function.get("name")
        if not isinstance(function_name, str):
            raise Exception("Prop `name` is not a string")

        # Reuse the already-validated `tool_function` instead of re-indexing
        # `tool_call["function"]`.
        arguments = tool_function.get("arguments", "{}")
        if not isinstance(arguments, str):
            raise Exception("Prop `arguments` is not a string")

        formatted_tool_calls.append(
            OCIToolCall(
                id=tool_call_id,
                type="FUNCTION",
                name=function_name,
                arguments=arguments,
            )
        )

    return OCIMessage(
        role=open_ai_to_generic_oci_role_map[role],
        content=None,
        toolCalls=formatted_tool_calls,
        toolCallId=None,
    )
|
|
|
|
|
|
def adapt_messages_to_generic_oci_standard_tool_response(
    role: str, tool_call_id: str, content: str
) -> OCIMessage:
    """Build an OCIMessage carrying a tool result, linked via ``toolCallId``."""
    oci_role = open_ai_to_generic_oci_role_map[role]
    text_parts = [OCITextContentPart(text=content)]
    return OCIMessage(
        role=oci_role,
        content=text_parts,
        toolCalls=None,
        toolCallId=tool_call_id,
    )
|
|
|
|
|
|
def adapt_messages_to_generic_oci_standard(
    messages: List[AllMessageValues],
) -> List[OCIMessage]:
    """Convert an OpenAI-style message list into generic-format OCI messages.

    Dispatch order per message: an assistant message with tool calls becomes a
    tool-call message (its content, if any, is ignored); system/user/assistant
    messages with content become content messages; tool messages become tool
    responses. Messages matching none of these shapes are silently skipped.

    Raises:
        Exception: if a message field has an unexpected type.
    """
    converted: List[OCIMessage] = []
    for message in messages:
        role = message["role"]
        content = message.get("content")
        tool_calls = message.get("tool_calls")
        tool_call_id = message.get("tool_call_id")

        if role == "assistant" and tool_calls is not None:
            if not isinstance(tool_calls, list):
                raise Exception("Prop `tool_calls` must be a list of tool calls")
            oci_message = adapt_messages_to_generic_oci_standard_tool_call(
                role, tool_calls
            )
            converted.append(oci_message)

        elif role in ["system", "user", "assistant"] and content is not None:
            if not isinstance(content, (str, list)):
                raise Exception(
                    "Prop `content` must be a string or a list of content items"
                )
            oci_message = adapt_messages_to_generic_oci_standard_content_message(
                role, content
            )
            converted.append(oci_message)

        elif role == "tool":
            if not isinstance(tool_call_id, str):
                raise Exception("Prop `tool_call_id` is required and must be a string")
            if not isinstance(content, str):
                raise Exception("Prop `content` is not a string")
            oci_message = adapt_messages_to_generic_oci_standard_tool_response(
                role, tool_call_id, content
            )
            converted.append(oci_message)

    return converted
|
|
|
|
|
|
def adapt_tool_definition_to_oci_standard(
    tools: List[Dict], vendor: OCIVendors
) -> List[OCIToolDefinition]:
    """Convert OpenAI-format tool definitions into generic OCI tool definitions.

    Args:
        tools: OpenAI-format tool dicts (`type` == "function").
        vendor: the target OCI vendor. NOTE(review): currently unused here —
            confirm whether vendor-specific conversion was intended.

    Returns:
        A list of ``OCIToolDefinition`` objects.

    Raises:
        Exception: if a tool is not a function tool or is malformed.
    """
    new_tools = []
    for tool in tools:
        if tool["type"] != "function":
            raise Exception("OCI only supports function tools")

        tool_function = tool.get("function")
        if not isinstance(tool_function, dict):
            raise Exception("Prop `function` is not a dictionary")

        # Validate the name explicitly (consistent with the tool-call
        # converter) rather than letting None flow into OCIToolDefinition.
        function_name = tool_function.get("name")
        if not isinstance(function_name, str):
            raise Exception("Prop `name` is not a string")

        new_tool = OCIToolDefinition(
            type="FUNCTION",
            name=function_name,
            description=tool_function.get("description", ""),
            parameters=tool_function.get("parameters", {}),
        )
        new_tools.append(new_tool)

    return new_tools
|
|
|
|
|
|
def adapt_tools_to_openai_standard(
    tools: List[OCIToolCall],
) -> List[ChatCompletionMessageToolCall]:
    """Translate OCI tool calls into OpenAI-format tool-call objects."""
    return [
        ChatCompletionMessageToolCall(
            id=tool.id,
            type="function",
            function={
                "name": tool.name,
                "arguments": tool.arguments,
            },
        )
        for tool in tools
    ]
|
|
|
|
|
|
class OCIStreamWrapper(CustomStreamWrapper):
    """
    Custom stream wrapper for OCI responses.

    This class is used to handle streaming responses from OCI's API.
    Each chunk handed to ``chunk_creator`` must be a single SSE event
    string of the form ``data: {...json...}``.
    """

    def __init__(
        self,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)

    def chunk_creator(self, chunk: Any):
        """Parse one SSE event string and dispatch by API format.

        Raises:
            ValueError: if the chunk is not a ``data:``-prefixed string.
        """
        if not isinstance(chunk, str):
            raise ValueError(f"Chunk is not a string: {chunk}")
        if not chunk.startswith("data:"):
            raise ValueError(f"Chunk does not start with 'data:': {chunk}")
        # Strip the 5-char "data:" prefix; json.loads tolerates the
        # leading space that may remain.
        dict_chunk = json.loads(chunk[5:])

        # Cohere-hosted models mark their chunks with apiFormat == "COHERE".
        if "apiFormat" in dict_chunk and dict_chunk.get("apiFormat") == "COHERE":
            return self._handle_cohere_stream_chunk(dict_chunk)
        else:
            return self._handle_generic_stream_chunk(dict_chunk)

    def _handle_cohere_stream_chunk(self, dict_chunk: dict):
        """Handle Cohere-specific streaming chunks."""
        try:
            typed_chunk = CohereStreamChunk(**dict_chunk)
        except TypeError as e:
            raise ValueError(f"Chunk cannot be casted to CohereStreamChunk: {str(e)}")

        if typed_chunk.index is None:
            typed_chunk.index = 0

        # Extract text content
        text = typed_chunk.text or ""

        # Map Cohere finish reasons onto OpenAI's: MAX_TOKENS -> "length";
        # COMPLETE and any other terminal reason -> "stop"; absent -> None
        # (removed a previously redundant `is None` branch).
        raw_finish = typed_chunk.finishReason
        if raw_finish is None:
            finish_reason = None
        elif raw_finish == "MAX_TOKENS":
            finish_reason = "length"
        else:
            finish_reason = "stop"

        # For Cohere, we don't have tool calls in the streaming format
        tool_calls = None

        return ModelResponseStream(
            choices=[
                StreamingChoices(
                    index=typed_chunk.index if typed_chunk.index else 0,
                    delta=Delta(
                        content=text,
                        tool_calls=tool_calls,
                        provider_specific_fields=None,
                        thinking_blocks=None,
                        reasoning_content=None,
                    ),
                    finish_reason=finish_reason,
                )
            ]
        )

    def _handle_generic_stream_chunk(self, dict_chunk: dict):
        """Handle generic OCI streaming chunks."""
        # Fix missing required fields in tool calls before Pydantic validation.
        # OCI streams tool calls progressively, so early chunks may be missing
        # required fields; default them to empty strings.
        if dict_chunk.get("message") and dict_chunk["message"].get("toolCalls"):
            for tool_call in dict_chunk["message"]["toolCalls"]:
                if "arguments" not in tool_call:
                    tool_call["arguments"] = ""
                if "id" not in tool_call:
                    tool_call["id"] = ""
                if "name" not in tool_call:
                    tool_call["name"] = ""

        try:
            typed_chunk = OCIStreamChunk(**dict_chunk)
        except TypeError as e:
            raise ValueError(f"Chunk cannot be casted to OCIStreamChunk: {str(e)}")

        if typed_chunk.index is None:
            typed_chunk.index = 0

        # Concatenate all TEXT parts; images are invalid in a stream.
        text = ""
        if typed_chunk.message and typed_chunk.message.content:
            for item in typed_chunk.message.content:
                if isinstance(item, OCITextContentPart):
                    text += item.text
                elif isinstance(item, OCIImageContentPart):
                    raise ValueError(
                        "OCI does not support image content in streaming responses"
                    )
                else:
                    raise ValueError(
                        f"Unsupported content type in OCI response: {item.type}"
                    )

        tool_calls = None
        if typed_chunk.message and typed_chunk.message.toolCalls:
            tool_calls = adapt_tools_to_openai_standard(typed_chunk.message.toolCalls)

        return ModelResponseStream(
            choices=[
                StreamingChoices(
                    index=typed_chunk.index if typed_chunk.index else 0,
                    delta=Delta(
                        content=text,
                        tool_calls=(
                            [tool.model_dump() for tool in tool_calls]
                            if tool_calls
                            else None
                        ),
                        provider_specific_fields=None,  # OCI does not have provider specific fields in the response
                        thinking_blocks=None,  # OCI does not have thinking blocks in the response
                        reasoning_content=None,  # OCI does not have reasoning content in the response
                    ),
                    finish_reason=typed_chunk.finishReason,
                )
            ]
        )
|