from typing import Any, AsyncIterator, Iterator, List, Optional, Tuple, Union

import httpx

import litellm
from litellm._logging import verbose_logger
from litellm.constants import XAI_API_BASE
from litellm.litellm_core_utils.prompt_templates.common_utils import (
    filter_value_from_dict,
    strip_name_from_messages,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.utils import (
    Choices,
    ModelResponse,
    ModelResponseStream,
    PromptTokensDetailsWrapper,
    Usage,
)

from ...openai.chat.gpt_transformation import (
    OpenAIChatCompletionStreamingHandler,
    OpenAIGPTConfig,
)


class XAIChatConfig(OpenAIGPTConfig):
    @property
    def custom_llm_provider(self) -> Optional[str]:
        return "xai"

    def _get_openai_compatible_provider_info(
        self, api_base: Optional[str], api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        api_base = api_base or get_secret_str("XAI_API_BASE") or XAI_API_BASE  # type: ignore
        dynamic_api_key = api_key or get_secret_str("XAI_API_KEY")
        return api_base, dynamic_api_key

    def get_supported_openai_params(self, model: str) -> list:
        base_openai_params = [
            "logit_bias",
            "logprobs",
            "max_tokens",
            "n",
            "presence_penalty",
            "response_format",
            "seed",
            "stream",
            "stream_options",
            "temperature",
            "tool_choice",
            "tools",
            "top_logprobs",
            "top_p",
            "user",
            "web_search_options",
        ]
        #########################################################
        # stop tokens check
        # some Grok models (e.g. grok-3-mini) do not support stop tokens
        #########################################################
        if self._supports_stop_reason(model):
            base_openai_params.append("stop")

        #########################################################
        # frequency penalty check
        #########################################################
        if self._supports_frequency_penalty(model):
            base_openai_params.append("frequency_penalty")

        #########################################################
        # reasoning check
        #########################################################
        try:
            if litellm.supports_reasoning(
                model=model, custom_llm_provider=self.custom_llm_provider
            ):
                base_openai_params.append("reasoning_effort")
        except Exception as e:
            verbose_logger.debug(f"Error checking if model supports reasoning: {e}")

        return base_openai_params

    def _supports_stop_reason(self, model: str) -> bool:
        """grok-3-mini, grok-4, and grok-code-fast do not support stop tokens."""
        if "grok-3-mini" in model:
            return False
        elif "grok-4" in model:
            return False
        elif "grok-code-fast" in model:
            return False
        return True

    def _supports_frequency_penalty(self, model: str) -> bool:
        """
        From manual testing, grok-4 does not support `frequency_penalty`;
        when it is sent, the request fails on the xAI API.
        """
        if "grok-4" in model:
            return False
        if "grok-code-fast" in model:
            return False
        return True

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool = False,
    ) -> dict:
        supported_openai_params = self.get_supported_openai_params(model=model)
        for param, value in non_default_params.items():
            if param == "max_completion_tokens":
                optional_params["max_tokens"] = value
            elif param == "tools" and value is not None:
                # strip the "strict" key from each tool definition
                tools = []
                for tool in value:
                    tool = filter_value_from_dict(tool, "strict")
                    if tool is not None:
                        tools.append(tool)
                if len(tools) > 0:
                    optional_params["tools"] = tools
            elif param in supported_openai_params:
                if value is not None:
                    optional_params[param] = value
        return optional_params
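
    # ------------------------------------------------------------------
    # Illustrative example (an assumption for demonstration, not upstream
    # code): how the param mapping above behaves for grok-3-mini, which
    # rejects stop tokens.
    #
    #     >>> XAIChatConfig().map_openai_params(
    #     ...     non_default_params={"stop": ["\n"], "max_completion_tokens": 256},
    #     ...     optional_params={},
    #     ...     model="grok-3-mini",
    #     ... )
    #     {'max_tokens': 256}
    #
    # "stop" is silently dropped because _supports_stop_reason() returns
    # False for grok-3-mini; "max_completion_tokens" is renamed to
    # "max_tokens".
    # ------------------------------------------------------------------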

    def get_model_response_iterator(
        self,
        streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
        sync_stream: bool,
        json_mode: Optional[bool] = False,
    ) -> Any:
        return XAIChatCompletionStreamingHandler(
            streaming_response=streaming_response,
            sync_stream=sync_stream,
            json_mode=json_mode,
        )

    def transform_request(
        self,
        model: str,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        headers: dict,
    ) -> dict:
        """
        Handle https://github.com/BerriAI/litellm/issues/9720

        Filter out 'name' from messages
        """
        messages = strip_name_from_messages(messages)
        return super().transform_request(
            model, messages, optional_params, litellm_params, headers
        )

    @staticmethod
    def _fix_choice_finish_reason_for_tool_calls(choice: Choices) -> None:
        """
        Helper to fix finish_reason for tool calls when the xAI API returns
        an empty string.

        The xAI API returns an empty string for finish_reason when using
        tools, so we need to set it to "tool_calls" when tool_calls are
        present.
        """
        if (
            choice.finish_reason == ""
            and choice.message.tool_calls
            and len(choice.message.tool_calls) > 0
        ):
            choice.finish_reason = "tool_calls"

    def transform_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: ModelResponse,
        logging_obj,
        request_data: dict,
        messages: List[AllMessageValues],
        optional_params: dict,
        litellm_params: dict,
        encoding,
        api_key: Optional[str] = None,
        json_mode: Optional[bool] = None,
    ) -> ModelResponse:
        """
        Transform the response from the xAI API.

        The xAI API returns an empty string for finish_reason when using
        tools, so we need to fix this after the standard OpenAI
        transformation.

        Also handles xAI web search usage tracking by extracting
        num_sources_used.
        """
        # First, let the parent class handle the standard transformation
        response = super().transform_response(
            model=model,
            raw_response=raw_response,
            model_response=model_response,
            logging_obj=logging_obj,
            request_data=request_data,
            messages=messages,
            optional_params=optional_params,
            litellm_params=litellm_params,
            encoding=encoding,
            api_key=api_key,
            json_mode=json_mode,
        )

        # Fix finish_reason for tool calls across all choices
        if response.choices:
            for choice in response.choices:
                if isinstance(choice, Choices):
                    self._fix_choice_finish_reason_for_tool_calls(choice)

        # Handle xAI web search usage tracking
        try:
            raw_response_json = raw_response.json()
            self._enhance_usage_with_xai_web_search_fields(response, raw_response_json)
        except Exception as e:
            verbose_logger.debug(f"Error extracting X.AI web search usage: {e}")

        return response

    def _enhance_usage_with_xai_web_search_fields(
        self, model_response: ModelResponse, raw_response_json: dict
    ) -> None:
        """
        Extract num_sources_used from the xAI response and map it to
        web_search_requests.
        """
        if not hasattr(model_response, "usage") or model_response.usage is None:
            return

        usage: Usage = model_response.usage
        num_sources_used = None

        response_usage = raw_response_json.get("usage", {})
        if isinstance(response_usage, dict) and "num_sources_used" in response_usage:
            num_sources_used = response_usage.get("num_sources_used")

        # Map num_sources_used to web_search_requests for cost detection
        if num_sources_used is not None and num_sources_used > 0:
            if usage.prompt_tokens_details is None:
                usage.prompt_tokens_details = PromptTokensDetailsWrapper()
            usage.prompt_tokens_details.web_search_requests = int(num_sources_used)
            setattr(usage, "num_sources_used", int(num_sources_used))
            verbose_logger.debug(f"X.AI web search sources used: {num_sources_used}")
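

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the upstream module: exercises
# _enhance_usage_with_xai_web_search_fields with a made-up raw payload to show
# how num_sources_used is mapped onto web_search_requests. The helper name and
# all values below are assumptions for demonstration only.
# ---------------------------------------------------------------------------
def _demo_web_search_usage_mapping() -> None:
    config = XAIChatConfig()
    model_response = ModelResponse()
    setattr(
        model_response,
        "usage",
        Usage(prompt_tokens=171, completion_tokens=2, total_tokens=173),
    )
    # Fake raw JSON body, shaped like an xAI response that used web search
    raw_response_json = {"usage": {"num_sources_used": 3}}
    config._enhance_usage_with_xai_web_search_fields(model_response, raw_response_json)
    usage = model_response.usage
    assert usage.prompt_tokens_details is not None
    assert usage.prompt_tokens_details.web_search_requests == 3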


class XAIChatCompletionStreamingHandler(OpenAIChatCompletionStreamingHandler):
    def chunk_parser(self, chunk: dict) -> ModelResponseStream:
        """
        Handle xAI-specific streaming behavior.

        xAI Grok sends a final chunk with an empty choices array but with
        usage data when stream_options={"include_usage": True} is set.

        Example from the xAI API:
        {"id":"...","object":"chat.completion.chunk","created":...,"model":"grok-4-1-fast-non-reasoning",
         "choices":[],"usage":{"prompt_tokens":171,"completion_tokens":2,"total_tokens":173,...}}
        """
        # Handle chunks with empty choices but with usage data
        choices = chunk.get("choices", [])
        if len(choices) == 0 and "usage" in chunk:
            # xAI sends usage in a chunk with an empty choices array.
            # Add a dummy choice with an empty delta to ensure proper processing.
            chunk["choices"] = [{"index": 0, "delta": {}, "finish_reason": None}]

        return super().chunk_parser(chunk)
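

# ---------------------------------------------------------------------------
# Minimal usage sketch, not upstream code: feeds chunk_parser a usage-only
# final chunk like the one in the docstring above and shows the dummy-choice
# injection. The chunk values are invented; the handler kwargs mirror the
# ones get_model_response_iterator passes.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    handler = XAIChatCompletionStreamingHandler(
        streaming_response=iter([]),
        sync_stream=True,
        json_mode=False,
    )
    usage_only_chunk = {
        "id": "chunk-abc",
        "object": "chat.completion.chunk",
        "created": 1700000000,
        "model": "grok-4",
        "choices": [],
        "usage": {"prompt_tokens": 171, "completion_tokens": 2, "total_tokens": 173},
    }
    # The injected dummy choice lets the base parser emit a stream chunk
    # that still carries the usage block.
    parsed = handler.chunk_parser(usage_only_chunk)
    print(parsed)

    _demo_web_search_usage_mapping()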