chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,271 @@
|
||||
from typing import Any, AsyncIterator, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import XAI_API_BASE
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
filter_value_from_dict,
|
||||
strip_name_from_messages,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import (
|
||||
Choices,
|
||||
ModelResponse,
|
||||
ModelResponseStream,
|
||||
PromptTokensDetailsWrapper,
|
||||
Usage,
|
||||
)
|
||||
|
||||
from ...openai.chat.gpt_transformation import (
|
||||
OpenAIChatCompletionStreamingHandler,
|
||||
OpenAIGPTConfig,
|
||||
)
|
||||
|
||||
|
||||
class XAIChatConfig(OpenAIGPTConfig):
    """Chat-completions configuration for xAI (Grok) models.

    xAI's chat API is OpenAI-compatible, so this subclass encodes only the
    provider-specific deltas visible below: per-model supported params,
    stripping the 'name' field from messages, repairing an empty
    ``finish_reason`` on tool calls, and surfacing web-search usage from the
    raw response.
    """

    @property
    def custom_llm_provider(self) -> Optional[str]:
        # Provider key used by litellm for routing and capability lookups.
        return "xai"

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Resolve (api_base, api_key): explicit arguments win, then the
        secret manager / environment, then the library default base URL."""
        api_base = api_base or get_secret_str("XAI_API_BASE") or XAI_API_BASE  # type: ignore
        dynamic_api_key = api_key or get_secret_str("XAI_API_KEY")
        return api_base, dynamic_api_key

    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI-style params this xAI model accepts.

        Starts from a base list and conditionally adds 'stop',
        'frequency_penalty', and 'reasoning_effort' depending on the model.
        """
        base_openai_params = [
            "logit_bias",
            "logprobs",
            "max_tokens",
            "n",
            "presence_penalty",
            "response_format",
            "seed",
            "stream",
            "stream_options",
            "temperature",
            "tool_choice",
            "tools",
            "top_logprobs",
            "top_p",
            "user",
            "web_search_options",
        ]
        # for some reason, grok-3-mini does not support stop tokens
        #########################################################
        # stop tokens check
        #########################################################
        if self._supports_stop_reason(model):
            base_openai_params.append("stop")

        #########################################################
        # frequency penalty check
        #########################################################
        if self._supports_frequency_penalty(model):
            base_openai_params.append("frequency_penalty")

        #########################################################
        # reasoning check
        #########################################################
        # Best-effort: supports_reasoning consults litellm's model registry,
        # which may raise for unknown models — never fail param discovery.
        try:
            if litellm.supports_reasoning(
                model=model, custom_llm_provider=self.custom_llm_provider
            ):
                base_openai_params.append("reasoning_effort")
        except Exception as e:
            verbose_logger.debug(f"Error checking if model supports reasoning: {e}")

        return base_openai_params

    def _supports_stop_reason(self, model: str) -> bool:
        """Whether this model family accepts the 'stop' parameter.

        grok-3-mini, grok-4, and grok-code-fast reject stop tokens.
        """
        if "grok-3-mini" in model:
            return False
        elif "grok-4" in model:
            return False
        elif "grok-code-fast" in model:
            return False
        return True

    def _supports_frequency_penalty(self, model: str) -> bool:
        """
        From manual testing grok-4 does not support `frequency_penalty`

        When sent the model fails from xAI API
        """
        if "grok-4" in model:
            return False
        if "grok-code-fast" in model:
            return False
        return True

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool = False,
    ) -> dict:
        """Map caller-supplied OpenAI params into `optional_params` for xAI.

        - 'max_completion_tokens' is renamed to 'max_tokens'.
        - tools: each tool dict has its 'strict' entry filtered out
          (presumably unsupported by xAI — see filter_value_from_dict).
        - any other param is passed through only when this model supports it
          (see get_supported_openai_params) and its value is not None.
        """
        supported_openai_params = self.get_supported_openai_params(model=model)
        for param, value in non_default_params.items():
            if param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            elif param == "tools" and value is not None:
                tools = []
                for tool in value:
                    # NOTE(review): filter_value_from_dict appears to strip the
                    # "strict" key from the tool — confirm helper semantics.
                    tool = filter_value_from_dict(tool, "strict")
                    if tool is not None:
                        tools.append(tool)
                if len(tools) > 0:
                    optional_params["tools"] = tools
            elif param in supported_openai_params:
                if value is not None:
                    optional_params[param] = value
        return optional_params

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        """Return an xAI-aware streaming handler (copes with usage-only
        final chunks — see XAIChatCompletionStreamingHandler)."""
        return XAIChatCompletionStreamingHandler(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Handle https://github.com/BerriAI/litellm/issues/9720

        Filter out 'name' from messages
        """
        messages = strip_name_from_messages(messages)
        return super().transform_request(
            model, messages, optional_params, litellm_params, headers
        )

    @staticmethod
    def _fix_choice_finish_reason_for_tool_calls(choice: Choices) -> None:
        """
        Helper to fix finish_reason for tool calls when XAI API returns empty string.

        XAI API returns empty string for finish_reason when using tools,
        so we need to set it to "tool_calls" when tool_calls are present.
        """
        if (
            choice.finish_reason == ""
            and choice.message.tool_calls
            and len(choice.message.tool_calls) > 0
        ):
            choice.finish_reason = "tool_calls"

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform the response from the XAI API.

        XAI API returns empty string for finish_reason when using tools,
        so we need to fix this after the standard OpenAI transformation.

        Also handles X.AI web search usage tracking by extracting num_sources_used.
        """

        # First, let the parent class handle the standard transformation
        response = super().transform_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )

        # Fix finish_reason for tool calls across all choices
        if response.choices:
            for choice in response.choices:
                if isinstance(choice, Choices):
                    self._fix_choice_finish_reason_for_tool_calls(choice)

        # Handle X.AI web search usage tracking.
        # Best-effort: re-parsing the raw body may fail (non-JSON, already
        # consumed, etc.) — never fail the whole response over usage extras.
        try:
            raw_response_json = raw_response.json()
            self._enhance_usage_with_xai_web_search_fields(response, raw_response_json)
        except Exception as e:
            verbose_logger.debug(f"Error extracting X.AI web search usage: {e}")

        return response

    def _enhance_usage_with_xai_web_search_fields(
        self, model_response: ModelResponse, raw_response_json: dict
    ) -> None:
        """
        Extract num_sources_used from X.AI response and map it to web_search_requests.

        Mutates ``model_response.usage`` in place; no-op when usage is absent
        or the raw payload carries no positive ``num_sources_used``.
        """
        if not hasattr(model_response, "usage") or model_response.usage is None:
            return

        usage: Usage = model_response.usage
        num_sources_used = None
        response_usage = raw_response_json.get("usage", {})
        if isinstance(response_usage, dict) and "num_sources_used" in response_usage:
            num_sources_used = response_usage.get("num_sources_used")

        # Map num_sources_used to web_search_requests for cost detection
        if num_sources_used is not None and num_sources_used > 0:
            if usage.prompt_tokens_details is None:
                usage.prompt_tokens_details = PromptTokensDetailsWrapper()

            usage.prompt_tokens_details.web_search_requests = int(num_sources_used)
            # Also keep the raw provider field for downstream consumers.
            setattr(usage, "num_sources_used", int(num_sources_used))
            verbose_logger.debug(f"X.AI web search sources used: {num_sources_used}")
|
||||
|
||||
|
||||
class XAIChatCompletionStreamingHandler(OpenAIChatCompletionStreamingHandler):
    """Streaming handler that papers over xAI-specific chunk quirks."""

    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        """
        Handle xAI-specific streaming behavior.

        When stream_options={"include_usage": True} is set, xAI Grok emits a
        final chunk whose "choices" array is empty but which carries the
        usage block, e.g.:

        {"id":"...","object":"chat.completion.chunk","created":...,"model":"grok-4-1-fast-non-reasoning",
        "choices":[],"usage":{"prompt_tokens":171,"completion_tokens":2,"total_tokens":173,...}}

        A placeholder choice is injected so the parent parser still
        processes the usage payload.
        """
        if not chunk.get("choices", []) and "usage" in chunk:
            # Dummy choice with an empty delta keeps downstream processing
            # happy while the usage block rides along.
            chunk["choices"] = [{"index": 0, "delta": {}, "finish_reason": None}]

        return super().chunk_parser(chunk)
|
||||
@@ -0,0 +1,83 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import ProviderSpecificModelInfo
|
||||
|
||||
|
||||
class XAIModelInfo(BaseLLMModelInfo):
    """Model metadata and environment helpers for the xAI provider."""

    def get_provider_info(
        self,
        model: str,
    ) -> Optional[ProviderSpecificModelInfo]:
        """
        Default values all models of this provider support.
        """
        return {
            "supports_web_search": True,
        }

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """Attach auth and content-type headers for an xAI request."""
        if api_key is not None:
            headers["Authorization"] = f"Bearer {api_key}"

        # Default the content type, honouring either casing a caller may
        # already have used.
        if not any(key in headers for key in ("content-type", "Content-Type")):
            headers["Content-Type"] = "application/json"

        return headers

    @staticmethod
    def get_api_base(api_base: Optional[str] = None) -> Optional[str]:
        """Resolve the API base: explicit arg -> env secret -> default."""
        return api_base or get_secret_str("XAI_API_BASE") or "https://api.x.ai"

    @staticmethod
    def get_api_key(api_key: Optional[str] = None) -> Optional[str]:
        """Resolve the API key: explicit arg -> env secret."""
        return api_key or get_secret_str("XAI_API_KEY")

    @staticmethod
    def get_base_model(model: str) -> Optional[str]:
        """Strip the 'xai/' routing prefix from a model name."""
        return model.replace("xai/", "")

    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        """Query xAI's `/models` endpoint and return 'xai/'-prefixed names.

        Raises ValueError when no base/key can be resolved, and a generic
        Exception on a non-2xx HTTP response.
        """
        resolved_base = self.get_api_base(api_base)
        resolved_key = self.get_api_key(api_key)
        if resolved_base is None or resolved_key is None:
            raise ValueError(
                "XAI_API_BASE or XAI_API_KEY is not set. Please set the environment variable, to query XAI's `/models` endpoint."
            )
        response = litellm.module_level_client.get(
            url=f"{resolved_base}/v1/models",
            headers={"Authorization": f"Bearer {resolved_key}"},
        )

        try:
            response.raise_for_status()
        except httpx.HTTPStatusError:
            raise Exception(
                f"Failed to fetch models from XAI. Status code: {response.status_code}, Response: {response.text}"
            )

        # Prefix every provider model id with "xai/" for litellm routing.
        return ["xai/" + entry["id"] for entry in response.json()["data"]]
|
||||
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Helper util for handling XAI-specific cost calculation
|
||||
- Uses the generic cost calculator which already handles tiered pricing correctly
|
||||
- Handles XAI-specific reasoning token billing (billed as part of completion tokens)
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from litellm.types.utils import Usage
|
||||
from litellm.litellm_core_utils.llm_cost_calc.utils import generic_cost_per_token
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.utils import ModelInfo
|
||||
|
||||
|
||||
def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
    """
    Calculates the cost per token for a given XAI model, prompt tokens, and completion tokens.
    Uses the generic cost calculator for all pricing logic, with XAI-specific reasoning token handling.

    Input:
        - model: str, the model name without provider prefix
        - usage: LiteLLM Usage block, containing XAI-specific usage information

    Returns:
        Tuple[float, float] - prompt_cost_in_usd, completion_cost_in_usd
    """
    # XAI bills completion as (visible completion tokens + reasoning tokens),
    # so fold any reasoning tokens into the completion count before pricing.
    visible_completion = int(getattr(usage, "completion_tokens", 0) or 0)

    details = getattr(usage, "completion_tokens_details", None)
    billed_reasoning = int(getattr(details, "reasoning_tokens", 0) or 0) if details else 0

    billed_usage = Usage(
        prompt_tokens=usage.prompt_tokens,
        completion_tokens=visible_completion + billed_reasoning,
        total_tokens=usage.total_tokens,
        prompt_tokens_details=usage.prompt_tokens_details,
        completion_tokens_details=None,
    )

    # Delegate all pricing (including tiered pricing) to the generic calculator.
    return generic_cost_per_token(
        model=model, usage=billed_usage, custom_llm_provider="xai"
    )
|
||||
|
||||
|
||||
def cost_per_web_search_request(usage: "Usage", model_info: "ModelInfo") -> float:
    """
    Calculate the cost of web search requests for X.AI models.

    X.AI Live Search costs $25 per 1,000 sources used.
    Each source costs $0.025.

    The number of sources is stored in prompt_tokens_details.web_search_requests
    by the transformation layer to be compatible with the existing detection system.
    """
    # Cost per source used: $25 per 1,000 sources = $0.025 per source
    per_source_usd = 25.0 / 1000.0

    sources = 0

    token_details = getattr(usage, "prompt_tokens_details", None)
    if (
        token_details is not None
        and getattr(token_details, "web_search_requests", None) is not None
    ):
        sources = int(token_details.web_search_requests)
    # Fallback: the transformation layer may have set num_sources_used directly.
    elif getattr(usage, "num_sources_used", None) is not None:
        sources = int(usage.num_sources_used)

    return per_source_usd * sources
|
||||
@@ -0,0 +1,5 @@
|
||||
"""xAI Realtime API handler."""
|
||||
|
||||
from .handler import XAIRealtime
|
||||
|
||||
__all__ = ["XAIRealtime"]
|
||||
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
This file contains the handler for xAI's Grok Voice Agent API `/v1/realtime` endpoint.
|
||||
|
||||
xAI's Realtime API is fully OpenAI-compatible, so we inherit from OpenAIRealtime
|
||||
and only override the configuration differences.
|
||||
|
||||
This requires websockets, and is currently only supported on LiteLLM Proxy.
|
||||
"""
|
||||
|
||||
from litellm.constants import XAI_API_BASE
|
||||
|
||||
from ...openai.realtime.handler import OpenAIRealtime
|
||||
|
||||
|
||||
class XAIRealtime(OpenAIRealtime):
    """
    Handler for xAI Grok Voice Agent API.

    The wire protocol is identical to OpenAI's Realtime WebSocket API; only
    configuration differs:
    - Different endpoint: wss://api.x.ai/v1/realtime (via _get_default_api_base)
    - No OpenAI-Beta header required (via _get_additional_headers)
    - Model: grok-4-1-fast-non-reasoning

    All WebSocket logic is inherited from OpenAIRealtime.
    """

    def _get_default_api_base(self) -> str:
        """xAI uses a different API base URL."""
        return XAI_API_BASE

    def _get_additional_headers(self, api_key: str) -> dict:
        """Send only the bearer Authorization header — xAI does NOT require
        the OpenAI-Beta header."""
        return {"Authorization": f"Bearer {api_key}"}
|
||||
@@ -0,0 +1,259 @@
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import XAI_API_BASE
|
||||
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams
|
||||
from litellm.types.llms.xai import XAIWebSearchTool, XAIXSearchTool
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
|
||||
|
||||
class XAIResponsesAPIConfig(OpenAIResponsesAPIConfig):
    """
    Configuration for XAI's Responses API.

    Inherits from OpenAIResponsesAPIConfig since XAI's Responses API is largely
    compatible with OpenAI's, with a few differences:
    - Does not support the 'instructions' or 'metadata' parameters
    - Requires code_interpreter tools to have 'container' field removed
    - Uses its own filter structure for web_search / x_search tools

    Reference: https://docs.x.ai/docs/api-reference#create-new-response
    """

    @property
    def custom_llm_provider(self) -> LlmProviders:
        # Provider enum used for routing within litellm.
        return LlmProviders.XAI

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get supported parameters for XAI Responses API.

        XAI supports most OpenAI Responses API params except 'instructions'.
        """
        supported_params = super().get_supported_openai_params(model)

        # Remove 'instructions' as it's not supported by XAI
        if "instructions" in supported_params:
            supported_params.remove("instructions")

        return supported_params

    def _transform_web_search_tool(
        self, tool: Dict[str, Any]
    ) -> Union[XAIWebSearchTool, Dict[str, Any]]:
        """
        Transform web_search tool to XAI format.

        XAI supports web_search with specific filters:
        - allowed_domains (max 5)
        - excluded_domains (max 5)
        - enable_image_understanding

        XAI does NOT support search_context_size (OpenAI-specific).
        """
        # Build a fresh dict: anything not copied over (e.g.
        # search_context_size) is implicitly dropped.
        xai_tool: Dict[str, Any] = {"type": "web_search"}

        # Remove search_context_size if present (not supported by XAI)
        if "search_context_size" in tool:
            verbose_logger.info(
                "XAI does not support 'search_context_size' parameter. Removing it from web_search tool."
            )

        # Handle filters (XAI-specific structure)
        filters = {}
        if "allowed_domains" in tool:
            allowed_domains = tool["allowed_domains"]
            filters["allowed_domains"] = allowed_domains

        if "excluded_domains" in tool:
            excluded_domains = tool["excluded_domains"]
            filters["excluded_domains"] = excluded_domains

        # Add filters if any were specified
        if filters:
            xai_tool["filters"] = filters

        # Handle enable_image_understanding (top-level in XAI format)
        if "enable_image_understanding" in tool:
            xai_tool["enable_image_understanding"] = tool["enable_image_understanding"]

        return xai_tool

    def _transform_x_search_tool(
        self, tool: Dict[str, Any]
    ) -> Union[XAIXSearchTool, Dict[str, Any]]:
        """
        Transform x_search tool to XAI format.

        XAI supports x_search with specific parameters:
        - allowed_x_handles (max 10)
        - excluded_x_handles (max 10)
        - from_date (ISO8601: YYYY-MM-DD)
        - to_date (ISO8601: YYYY-MM-DD)
        - enable_image_understanding
        - enable_video_understanding
        """
        # Only recognized keys are copied; unknown keys are dropped.
        xai_tool: Dict[str, Any] = {"type": "x_search"}

        # Handle allowed_x_handles
        if "allowed_x_handles" in tool:
            allowed_handles = tool["allowed_x_handles"]
            xai_tool["allowed_x_handles"] = allowed_handles

        # Handle excluded_x_handles
        if "excluded_x_handles" in tool:
            excluded_handles = tool["excluded_x_handles"]
            xai_tool["excluded_x_handles"] = excluded_handles

        # Handle date range
        if "from_date" in tool:
            xai_tool["from_date"] = tool["from_date"]

        if "to_date" in tool:
            xai_tool["to_date"] = tool["to_date"]

        # Handle media understanding flags
        if "enable_image_understanding" in tool:
            xai_tool["enable_image_understanding"] = tool["enable_image_understanding"]

        if "enable_video_understanding" in tool:
            xai_tool["enable_video_understanding"] = tool["enable_video_understanding"]

        return xai_tool

    def map_openai_params(
        self,
        response_api_optional_params: ResponsesAPIOptionalRequestParams,
        model: str,
        drop_params: bool,
    ) -> Dict:
        """
        Map parameters for XAI Responses API.

        Handles XAI-specific transformations:
        1. Drops 'instructions' parameter (not supported)
        2. Drops 'metadata' parameter (not supported)
        3. Transforms code_interpreter tools to remove 'container' field
        4. Transforms web_search tools to XAI format (removes search_context_size, adds filters)
        5. Transforms x_search tools to XAI format
        """
        # Copy so the caller's params object is never mutated.
        params = dict(response_api_optional_params)

        # Drop instructions parameter (not supported by XAI)
        if "instructions" in params:
            verbose_logger.debug(
                "XAI Responses API does not support 'instructions' parameter. Dropping it."
            )
            params.pop("instructions")

        if "metadata" in params:
            verbose_logger.debug(
                "XAI Responses API does not support 'metadata' parameter. Dropping it."
            )
            params.pop("metadata")

        # Transform tools
        if "tools" in params and params["tools"]:
            tools_list = params["tools"]
            # Ensure tools is a list for iteration
            if not isinstance(tools_list, list):
                tools_list = [tools_list]

            transformed_tools: List[Any] = []
            for tool in tools_list:
                if isinstance(tool, dict):
                    tool_type = tool.get("type")

                    if tool_type == "code_interpreter":
                        # XAI supports code_interpreter but doesn't use the container field
                        verbose_logger.debug(
                            "XAI: Transforming code_interpreter tool, removing container field"
                        )
                        transformed_tools.append({"type": "code_interpreter"})

                    elif tool_type == "web_search":
                        # Transform web_search to XAI format
                        verbose_logger.debug(
                            "XAI: Transforming web_search tool to XAI format"
                        )
                        transformed_tools.append(self._transform_web_search_tool(tool))

                    elif tool_type == "x_search":
                        # Transform x_search to XAI format
                        verbose_logger.debug(
                            "XAI: Transforming x_search tool to XAI format"
                        )
                        transformed_tools.append(self._transform_x_search_tool(tool))

                    else:
                        # Keep other tools as-is
                        transformed_tools.append(tool)
                else:
                    # Non-dict tool entries (e.g. typed objects) pass through.
                    transformed_tools.append(tool)

            params["tools"] = transformed_tools

        return params

    def validate_environment(
        self, headers: dict, model: str, litellm_params: Optional[GenericLiteLLMParams]
    ) -> dict:
        """
        Validate environment and set up headers for XAI API.

        Uses XAI_API_KEY from environment or litellm_params.

        Raises:
            ValueError: when no API key can be resolved.
        """
        litellm_params = litellm_params or GenericLiteLLMParams()
        api_key = (
            litellm_params.api_key or litellm.api_key or get_secret_str("XAI_API_KEY")
        )

        if not api_key:
            raise ValueError(
                "XAI API key is required. Set XAI_API_KEY environment variable or pass api_key parameter."
            )

        headers.update(
            {
                "Authorization": f"Bearer {api_key}",
            }
        )
        return headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """
        Get the complete URL for XAI Responses API endpoint.

        Resolution order: explicit arg -> litellm.api_base -> env secret ->
        library default base.

        Returns:
            str: The full URL for the XAI /responses endpoint
        """
        api_base = (
            api_base
            or litellm.api_base
            or get_secret_str("XAI_API_BASE")
            or XAI_API_BASE
        )

        # Remove trailing slashes
        api_base = api_base.rstrip("/")

        return f"{api_base}/responses"

    def supports_native_websocket(self) -> bool:
        """XAI does not support native WebSocket for Responses API"""
        return False
|
||||
Reference in New Issue
Block a user