Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/llms/xai/chat/transformation.py
2026-03-26 20:06:14 +08:00

272 lines
9.6 KiB
Python

from typing import Any, AsyncIterator, Iterator, List, Optional, Tuple, Union
import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.constants import XAI_API_BASE
from litellm.litellm_core_utils.prompt_templates.common_utils import (
filter_value_from_dict,
strip_name_from_messages,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
Choices,
ModelResponse,
ModelResponseStream,
PromptTokensDetailsWrapper,
Usage,
)
from ...openai.chat.gpt_transformation import (
OpenAIChatCompletionStreamingHandler,
OpenAIGPTConfig,
)
class XAIChatConfig(OpenAIGPTConfig):
    """litellm transformation config for the xAI (Grok) chat-completions API.

    xAI exposes an OpenAI-compatible endpoint, so this class only layers
    provider-specific quirks on top of ``OpenAIGPTConfig``:

    - filters out request params that individual Grok model families reject,
    - strips the ``name`` field from messages before sending,
    - repairs the empty ``finish_reason`` xAI returns for tool calls,
    - surfaces xAI's ``num_sources_used`` web-search usage metric.
    """

    # Model families (matched by substring) that reject the given request
    # parameter. Derived from the checks previously inlined in the
    # _supports_* helpers (manual testing against the xAI API).
    _STOP_UNSUPPORTED_FAMILIES = ("grok-3-mini", "grok-4", "grok-code-fast")
    _FREQUENCY_PENALTY_UNSUPPORTED_FAMILIES = ("grok-4", "grok-code-fast")

    @property
    def custom_llm_provider(self) -> Optional[str]:
        # Provider slug used by litellm for routing and capability lookups.
        return "xai"

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Resolve ``(api_base, api_key)``.

        Explicit arguments win; otherwise fall back to the XAI_API_BASE /
        XAI_API_KEY secrets, and finally the library-default base URL.
        """
        api_base = api_base or get_secret_str("XAI_API_BASE") or XAI_API_BASE  # type: ignore
        dynamic_api_key = api_key or get_secret_str("XAI_API_KEY")
        return api_base, dynamic_api_key

    def get_supported_openai_params(self, model: str) -> list:
        """Return the OpenAI-compatible params xAI supports for ``model``.

        Starts from a base list common to all Grok models, then conditionally
        adds ``stop``, ``frequency_penalty`` and ``reasoning_effort`` based on
        per-family capability checks.
        """
        base_openai_params = [
            "logit_bias",
            "logprobs",
            "max_tokens",
            "n",
            "presence_penalty",
            "response_format",
            "seed",
            "stream",
            "stream_options",
            "temperature",
            "tool_choice",
            "tools",
            "top_logprobs",
            "top_p",
            "user",
            "web_search_options",
        ]
        # for some reason, grok-3-mini does not support stop tokens
        #########################################################
        # stop tokens check
        #########################################################
        if self._supports_stop_reason(model):
            base_openai_params.append("stop")
        #########################################################
        # frequency penalty check
        #########################################################
        if self._supports_frequency_penalty(model):
            base_openai_params.append("frequency_penalty")
        #########################################################
        # reasoning check
        #########################################################
        # A registry lookup failure must not break param mapping, so the
        # capability check is best-effort.
        try:
            if litellm.supports_reasoning(
                model=model, custom_llm_provider=self.custom_llm_provider
            ):
                base_openai_params.append("reasoning_effort")
        except Exception as e:
            verbose_logger.debug(f"Error checking if model supports reasoning: {e}")
        return base_openai_params

    def _supports_stop_reason(self, model: str) -> bool:
        """True if ``model`` accepts the ``stop`` parameter."""
        return not any(
            family in model for family in self._STOP_UNSUPPORTED_FAMILIES
        )

    def _supports_frequency_penalty(self, model: str) -> bool:
        """
        From manual testing grok-4 does not support `frequency_penalty`
        When sent the model fails from xAI API
        """
        return not any(
            family in model
            for family in self._FREQUENCY_PENALTY_UNSUPPORTED_FAMILIES
        )

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool = False,
    ) -> dict:
        """Translate OpenAI request params into the payload xAI accepts.

        - ``max_completion_tokens`` is renamed to the legacy ``max_tokens``.
        - tool definitions get the ``strict`` key removed (rejected by xAI).
        - every other supported, non-None param is passed through unchanged.

        Mutates and returns ``optional_params``.
        """
        supported_openai_params = self.get_supported_openai_params(model=model)
        for param, value in non_default_params.items():
            if param == "max_completion_tokens":
                # Fix: previously a None value was forwarded as
                # `max_tokens: None`; skip it like every other param branch.
                if value is not None:
                    optional_params["max_tokens"] = value
            elif param == "tools" and value is not None:
                tools = []
                for tool in value:
                    tool = filter_value_from_dict(tool, "strict")
                    if tool is not None:
                        tools.append(tool)
                if len(tools) > 0:
                    optional_params["tools"] = tools
            elif param in supported_openai_params:
                if value is not None:
                    optional_params[param] = value
        return optional_params

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        """Return the xAI-specific streaming handler, which tolerates the
        usage-only chunks xAI emits with an empty ``choices`` array."""
        return XAIChatCompletionStreamingHandler(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Handle https://github.com/BerriAI/litellm/issues/9720

        Filter out 'name' from messages
        """
        messages = strip_name_from_messages(messages)
        return super().transform_request(
            model, messages, optional_params, litellm_params, headers
        )

    @staticmethod
    def _fix_choice_finish_reason_for_tool_calls(choice: Choices) -> None:
        """
        Helper to fix finish_reason for tool calls when XAI API returns empty string.

        XAI API returns empty string for finish_reason when using tools,
        so we need to set it to "tool_calls" when tool_calls are present.
        """
        if (
            choice.finish_reason == ""
            and choice.message.tool_calls
            and len(choice.message.tool_calls) > 0
        ):
            choice.finish_reason = "tool_calls"

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform the response from the XAI API.

        XAI API returns empty string for finish_reason when using tools,
        so we need to fix this after the standard OpenAI transformation.

        Also handles X.AI web search usage tracking by extracting num_sources_used.
        """
        # First, let the parent class handle the standard transformation
        response = super().transform_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )
        # Fix finish_reason for tool calls across all choices
        if response.choices:
            for choice in response.choices:
                if isinstance(choice, Choices):
                    self._fix_choice_finish_reason_for_tool_calls(choice)
        # Handle X.AI web search usage tracking. Best-effort: a missing or
        # malformed usage payload must not fail the whole response.
        try:
            raw_response_json = raw_response.json()
            self._enhance_usage_with_xai_web_search_fields(response, raw_response_json)
        except Exception as e:
            verbose_logger.debug(f"Error extracting X.AI web search usage: {e}")
        return response

    def _enhance_usage_with_xai_web_search_fields(
        self, model_response: ModelResponse, raw_response_json: dict
    ) -> None:
        """
        Extract num_sources_used from X.AI response and map it to web_search_requests.

        Mutates ``model_response.usage`` in place; no-op when usage is absent
        or ``num_sources_used`` is missing/zero.
        """
        if not hasattr(model_response, "usage") or model_response.usage is None:
            return
        usage: Usage = model_response.usage
        num_sources_used = None
        response_usage = raw_response_json.get("usage", {})
        if isinstance(response_usage, dict) and "num_sources_used" in response_usage:
            # Assumed to be a non-negative integer count — the `> 0` / int()
            # below would raise on other types, caught by the caller's
            # try/except.
            num_sources_used = response_usage.get("num_sources_used")
        # Map num_sources_used to web_search_requests for cost detection
        if num_sources_used is not None and num_sources_used > 0:
            if usage.prompt_tokens_details is None:
                usage.prompt_tokens_details = PromptTokensDetailsWrapper()
            usage.prompt_tokens_details.web_search_requests = int(num_sources_used)
            # Also keep the provider-native field name for downstream loggers.
            setattr(usage, "num_sources_used", int(num_sources_used))
            verbose_logger.debug(f"X.AI web search sources used: {num_sources_used}")
class XAIChatCompletionStreamingHandler(OpenAIChatCompletionStreamingHandler):
    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        """
        Handle xAI-specific streaming behavior.

        xAI Grok sends a final chunk with empty choices array but with usage data
        when stream_options={"include_usage": True} is set.

        Example from xAI API:
        {"id":"...","object":"chat.completion.chunk","created":...,"model":"grok-4-1-fast-non-reasoning",
        "choices":[],"usage":{"prompt_tokens":171,"completion_tokens":2,"total_tokens":173,...}}
        """
        # Handle chunks with empty choices but with usage data.
        # `or []` also covers an explicit `"choices": null`, which would make
        # len() raise.
        choices = chunk.get("choices") or []
        if len(choices) == 0 and "usage" in chunk:
            # xAI sends usage in a chunk with empty choices array.
            # Add a dummy choice with empty delta so the parent parser
            # processes the chunk; use a shallow copy instead of mutating
            # the caller's dict in place.
            chunk = {
                **chunk,
                "choices": [{"index": 0, "delta": {}, "finish_reason": None}],
            }
        return super().chunk_parser(chunk)