chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
"""
|
||||
Handles the chat completion request for groq
|
||||
"""
|
||||
|
||||
from typing import Callable, List, Optional, Union, cast
|
||||
|
||||
from httpx._config import Timeout
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import CustomStreamingDecoder
|
||||
from litellm.utils import ModelResponse
|
||||
|
||||
from ...groq.chat.transformation import GroqChatConfig
|
||||
from ...openai_like.chat.handler import OpenAILikeChatHandler
|
||||
|
||||
|
||||
class GroqChatCompletion(OpenAILikeChatHandler):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def completion(
|
||||
self,
|
||||
*,
|
||||
model: str,
|
||||
messages: list,
|
||||
api_base: str,
|
||||
custom_llm_provider: str,
|
||||
custom_prompt_dict: dict,
|
||||
model_response: ModelResponse,
|
||||
print_verbose: Callable,
|
||||
encoding,
|
||||
api_key: Optional[str],
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
acompletion=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers: Optional[dict] = None,
|
||||
timeout: Optional[Union[float, Timeout]] = None,
|
||||
client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
|
||||
custom_endpoint: Optional[bool] = None,
|
||||
streaming_decoder: Optional[CustomStreamingDecoder] = None,
|
||||
fake_stream: bool = False,
|
||||
):
|
||||
messages = GroqChatConfig()._transform_messages(
|
||||
messages=cast(List[AllMessageValues], messages), model=model
|
||||
)
|
||||
|
||||
if optional_params.get("stream") is True:
|
||||
fake_stream = GroqChatConfig()._should_fake_stream(optional_params)
|
||||
else:
|
||||
fake_stream = False
|
||||
|
||||
return super().completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
api_base=api_base,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
model_response=model_response,
|
||||
print_verbose=print_verbose,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
logging_obj=logging_obj,
|
||||
optional_params=optional_params,
|
||||
acompletion=acompletion,
|
||||
litellm_params=litellm_params,
|
||||
logger_fn=logger_fn,
|
||||
headers=headers,
|
||||
timeout=timeout,
|
||||
client=client,
|
||||
custom_endpoint=custom_endpoint,
|
||||
streaming_decoder=streaming_decoder,
|
||||
fake_stream=fake_stream,
|
||||
)
|
||||
@@ -0,0 +1,334 @@
|
||||
"""
|
||||
Translate from OpenAI's `/v1/chat/completions` to Groq's `/v1/chat/completions`
|
||||
"""
|
||||
from typing import (
|
||||
Any,
|
||||
Coroutine,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
Iterator,
|
||||
AsyncIterator,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.llms.openai.chat.gpt_transformation import (
|
||||
OpenAIChatCompletionStreamingHandler,
|
||||
)
|
||||
from litellm.llms.openai.common_utils import OpenAIError
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionAssistantMessage,
|
||||
ChatCompletionToolParam,
|
||||
ChatCompletionToolParamFunctionChunk,
|
||||
)
|
||||
from litellm.types.utils import ModelResponse, ModelResponseStream
|
||||
|
||||
from ...openai_like.chat.transformation import OpenAILikeChatConfig
|
||||
|
||||
|
||||
class GroqChatConfig(OpenAILikeChatConfig):
|
||||
frequency_penalty: Optional[int] = None
|
||||
function_call: Optional[Union[str, dict]] = None
|
||||
functions: Optional[list] = None
|
||||
logit_bias: Optional[dict] = None
|
||||
max_tokens: Optional[int] = None
|
||||
n: Optional[int] = None
|
||||
presence_penalty: Optional[int] = None
|
||||
stop: Optional[Union[str, list]] = None
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
response_format: Optional[dict] = None
|
||||
tools: Optional[list] = None
|
||||
tool_choice: Optional[Union[str, dict]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
function_call: Optional[Union[str, dict]] = None,
|
||||
functions: Optional[list] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
stop: Optional[Union[str, list]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
response_format: Optional[dict] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[Union[str, dict]] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> Optional[str]:
|
||||
return "groq"
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return super().get_config()
|
||||
|
||||
def get_model_response_iterator(
|
||||
self,
|
||||
streaming_response: Union[Iterator[str], AsyncIterator[str], ModelResponse],
|
||||
sync_stream: bool,
|
||||
json_mode: Optional[bool] = False,
|
||||
) -> Any:
|
||||
return GroqChatCompletionStreamingHandler(
|
||||
streaming_response=streaming_response,
|
||||
sync_stream=sync_stream,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> list:
|
||||
base_params = super().get_supported_openai_params(model)
|
||||
try:
|
||||
base_params.remove("max_retries")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
if litellm.supports_reasoning(
|
||||
model=model, custom_llm_provider=self.custom_llm_provider
|
||||
):
|
||||
base_params.append("reasoning_effort")
|
||||
except Exception as e:
|
||||
verbose_logger.debug(f"Error checking if model supports reasoning: {e}")
|
||||
|
||||
return base_params
|
||||
|
||||
@overload
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues], model: str, is_async: Literal[True]
|
||||
) -> Coroutine[Any, Any, List[AllMessageValues]]:
|
||||
...
|
||||
|
||||
@overload
|
||||
def _transform_messages(
|
||||
self,
|
||||
messages: List[AllMessageValues],
|
||||
model: str,
|
||||
is_async: Literal[False] = False,
|
||||
) -> List[AllMessageValues]:
|
||||
...
|
||||
|
||||
def _transform_messages(
|
||||
self, messages: List[AllMessageValues], model: str, is_async: bool = False
|
||||
) -> Union[List[AllMessageValues], Coroutine[Any, Any, List[AllMessageValues]]]:
|
||||
for idx, message in enumerate(messages):
|
||||
"""
|
||||
1. Don't pass 'null' function_call assistant message to groq - https://github.com/BerriAI/litellm/issues/5839
|
||||
"""
|
||||
if isinstance(message, BaseModel):
|
||||
_message = message.model_dump()
|
||||
else:
|
||||
_message = message
|
||||
assistant_message = _message.get("role") == "assistant"
|
||||
if assistant_message:
|
||||
new_message = ChatCompletionAssistantMessage(role="assistant")
|
||||
for k, v in _message.items():
|
||||
if v is not None:
|
||||
new_message[k] = v # type: ignore
|
||||
messages[idx] = new_message
|
||||
|
||||
if is_async:
|
||||
return super()._transform_messages(
|
||||
messages=messages, model=model, is_async=True
|
||||
)
|
||||
else:
|
||||
return super()._transform_messages(
|
||||
messages=messages, model=model, is_async=False
|
||||
)
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self, api_base: Optional[str], api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
# groq is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.groq.com/openai/v1
|
||||
api_base = (
|
||||
api_base
|
||||
or get_secret_str("GROQ_API_BASE")
|
||||
or "https://api.groq.com/openai/v1"
|
||||
) # type: ignore
|
||||
dynamic_api_key = api_key or get_secret_str("GROQ_API_KEY")
|
||||
return api_base, dynamic_api_key
|
||||
|
||||
def _should_fake_stream(self, optional_params: dict) -> bool:
|
||||
"""
|
||||
Groq doesn't support 'response_format' while streaming
|
||||
"""
|
||||
if optional_params.get("response_format") is not None:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _create_json_tool_call_for_response_format(
|
||||
self,
|
||||
json_schema: dict,
|
||||
):
|
||||
"""
|
||||
Handles creating a tool call for getting responses in JSON format.
|
||||
|
||||
Args:
|
||||
json_schema (Optional[dict]): The JSON schema the response should be in
|
||||
|
||||
Returns:
|
||||
AnthropicMessagesTool: The tool call to send to Anthropic API to get responses in JSON format
|
||||
"""
|
||||
return ChatCompletionToolParam(
|
||||
type="function",
|
||||
function=ChatCompletionToolParamFunctionChunk(
|
||||
name="json_tool_call",
|
||||
parameters=json_schema,
|
||||
),
|
||||
)
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool = False,
|
||||
replace_max_completion_tokens_with_max_tokens: bool = False, # groq supports max_completion_tokens
|
||||
) -> dict:
|
||||
_response_format = non_default_params.get("response_format")
|
||||
if self._should_fake_stream(non_default_params):
|
||||
optional_params["fake_stream"] = True
|
||||
if _response_format is not None and isinstance(_response_format, dict):
|
||||
json_schema: Optional[dict] = None
|
||||
if "response_schema" in _response_format:
|
||||
json_schema = _response_format["response_schema"]
|
||||
elif "json_schema" in _response_format:
|
||||
json_schema = _response_format["json_schema"]["schema"]
|
||||
"""
|
||||
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
|
||||
- You usually want to provide a single tool
|
||||
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
|
||||
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model's perspective.
|
||||
|
||||
Note: This workaround is only for models that don't support native json_schema.
|
||||
Models like gpt-oss-120b, llama-4, kimi-k2 support native json_schema and should
|
||||
pass response_format directly to Groq.
|
||||
See: https://console.groq.com/docs/structured-outputs#supported-models
|
||||
"""
|
||||
if json_schema is not None:
|
||||
# Check if model supports native response_schema
|
||||
if not litellm.supports_response_schema(
|
||||
model=model, custom_llm_provider="groq"
|
||||
):
|
||||
# Check if user is also passing tools - this combination won't work
|
||||
# See: https://console.groq.com/docs/structured-outputs
|
||||
# "Streaming and tool use are not currently supported with Structured Outputs"
|
||||
if "tools" in non_default_params:
|
||||
raise litellm.BadRequestError(
|
||||
message=f"Groq model '{model}' does not support native structured outputs. "
|
||||
"LiteLLM uses a tool-calling workaround for structured outputs on this model, "
|
||||
"which is incompatible with user-provided tools. "
|
||||
"Either use a model that supports native structured outputs "
|
||||
"(e.g., gpt-oss-120b, llama-4, kimi-k2), or remove the tools parameter. "
|
||||
"See: https://console.groq.com/docs/structured-outputs#supported-models",
|
||||
model=model,
|
||||
llm_provider="groq",
|
||||
)
|
||||
# Use workaround only for models without native support
|
||||
_tool_choice = {
|
||||
"type": "function",
|
||||
"function": {"name": "json_tool_call"},
|
||||
}
|
||||
_tool = self._create_json_tool_call_for_response_format(
|
||||
json_schema=json_schema,
|
||||
)
|
||||
optional_params["tools"] = [_tool]
|
||||
optional_params["tool_choice"] = _tool_choice
|
||||
optional_params["json_mode"] = True
|
||||
non_default_params.pop(
|
||||
"response_format", None
|
||||
) # only remove if it's a json_schema - handled via using groq's tool calling params.
|
||||
# else: model supports native json_schema, let response_format pass through
|
||||
optional_params = super().map_openai_params(
|
||||
non_default_params, optional_params, model, drop_params
|
||||
)
|
||||
|
||||
return optional_params
|
||||
|
||||
def transform_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: httpx.Response,
|
||||
model_response: ModelResponse,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
request_data: dict,
|
||||
messages: List[AllMessageValues],
|
||||
optional_params: dict,
|
||||
litellm_params: dict,
|
||||
encoding: Any,
|
||||
api_key: Optional[str] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
) -> ModelResponse:
|
||||
model_response = super().transform_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=request_data,
|
||||
messages=messages,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params,
|
||||
encoding=encoding,
|
||||
api_key=api_key,
|
||||
json_mode=json_mode,
|
||||
)
|
||||
|
||||
mapped_service_tier: Literal[
|
||||
"auto", "default", "flex"
|
||||
] = self._map_groq_service_tier(
|
||||
original_service_tier=getattr(model_response, "service_tier")
|
||||
)
|
||||
setattr(model_response, "service_tier", mapped_service_tier)
|
||||
return model_response
|
||||
|
||||
def _map_groq_service_tier(
|
||||
self, original_service_tier: Optional[str]
|
||||
) -> Literal["auto", "default", "flex"]:
|
||||
"""
|
||||
Ensure groq service tier is OpenAI compatible.
|
||||
"""
|
||||
if original_service_tier is None:
|
||||
return "auto"
|
||||
if original_service_tier not in ["auto", "default", "flex"]:
|
||||
return "auto"
|
||||
|
||||
return cast(Literal["auto", "default", "flex"], original_service_tier)
|
||||
|
||||
|
||||
class GroqChatCompletionStreamingHandler(OpenAIChatCompletionStreamingHandler):
|
||||
def chunk_parser(self, chunk: dict) -> ModelResponseStream:
|
||||
error = chunk.get("error")
|
||||
if error:
|
||||
raise OpenAIError(
|
||||
status_code=error.get("code"), message=error.get("message"), body=error
|
||||
)
|
||||
|
||||
# Map Groq's 'reasoning' field to LiteLLM's 'reasoning_content' field
|
||||
# Groq returns delta.reasoning, but LiteLLM expects delta.reasoning_content
|
||||
choices = chunk.get("choices", [])
|
||||
for choice in choices:
|
||||
delta = choice.get("delta", {})
|
||||
if "reasoning" in delta:
|
||||
delta["reasoning_content"] = delta.pop("reasoning")
|
||||
|
||||
return super().chunk_parser(chunk)
|
||||
@@ -0,0 +1,100 @@
|
||||
"""
|
||||
Translate from OpenAI's `/v1/audio/transcriptions` to Groq's `/v1/audio/transcriptions`
|
||||
"""
|
||||
|
||||
import types
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class GroqSTTConfig:
|
||||
frequency_penalty: Optional[int] = None
|
||||
function_call: Optional[Union[str, dict]] = None
|
||||
functions: Optional[list] = None
|
||||
logit_bias: Optional[dict] = None
|
||||
max_tokens: Optional[int] = None
|
||||
n: Optional[int] = None
|
||||
presence_penalty: Optional[int] = None
|
||||
stop: Optional[Union[str, list]] = None
|
||||
temperature: Optional[int] = None
|
||||
top_p: Optional[int] = None
|
||||
response_format: Optional[dict] = None
|
||||
tools: Optional[list] = None
|
||||
tool_choice: Optional[Union[str, dict]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
frequency_penalty: Optional[int] = None,
|
||||
function_call: Optional[Union[str, dict]] = None,
|
||||
functions: Optional[list] = None,
|
||||
logit_bias: Optional[dict] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[int] = None,
|
||||
stop: Optional[Union[str, list]] = None,
|
||||
temperature: Optional[int] = None,
|
||||
top_p: Optional[int] = None,
|
||||
response_format: Optional[dict] = None,
|
||||
tools: Optional[list] = None,
|
||||
tool_choice: Optional[Union[str, dict]] = None,
|
||||
) -> None:
|
||||
locals_ = locals().copy()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params_stt(self):
|
||||
return [
|
||||
"prompt",
|
||||
"response_format",
|
||||
"temperature",
|
||||
"language",
|
||||
]
|
||||
|
||||
def get_supported_openai_response_formats_stt(self) -> List[str]:
|
||||
return ["json", "verbose_json", "text"]
|
||||
|
||||
def map_openai_params_stt(
|
||||
self,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
drop_params: bool,
|
||||
) -> dict:
|
||||
response_formats = self.get_supported_openai_response_formats_stt()
|
||||
for param, value in non_default_params.items():
|
||||
if param == "response_format":
|
||||
if value in response_formats:
|
||||
optional_params[param] = value
|
||||
else:
|
||||
if litellm.drop_params is True or drop_params is True:
|
||||
pass
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="Groq doesn't support response_format={}. To drop unsupported openai params from the call, set `litellm.drop_params = True`".format(
|
||||
value
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
else:
|
||||
optional_params[param] = value
|
||||
return optional_params
|
||||
Reference in New Issue
Block a user