chore: initial public snapshot for github upload
This commit is contained in:
@@ -0,0 +1,206 @@
|
||||
import json
|
||||
from typing import Any, Optional
|
||||
|
||||
from litellm.constants import STREAM_SSE_DONE_STRING
|
||||
from litellm.exceptions import AuthenticationError
|
||||
from litellm.litellm_core_utils.core_helpers import process_response_headers
|
||||
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
|
||||
_safe_convert_created_field,
|
||||
)
|
||||
from litellm.llms.openai.common_utils import OpenAIError
|
||||
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
|
||||
from litellm.types.llms.openai import (
|
||||
ResponsesAPIResponse,
|
||||
ResponsesAPIStreamEvents,
|
||||
)
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.types.utils import LlmProviders
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
from ..authenticator import Authenticator
|
||||
from ..common_utils import (
|
||||
CHATGPT_API_BASE,
|
||||
GetAccessTokenError,
|
||||
ensure_chatgpt_session_id,
|
||||
get_chatgpt_default_headers,
|
||||
get_chatgpt_default_instructions,
|
||||
)
|
||||
|
||||
|
||||
class ChatGPTResponsesAPIConfig(OpenAIResponsesAPIConfig):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.authenticator = Authenticator()
|
||||
|
||||
@property
|
||||
def custom_llm_provider(self) -> LlmProviders:
|
||||
return LlmProviders.CHATGPT
|
||||
|
||||
def validate_environment(
|
||||
self,
|
||||
headers: dict,
|
||||
model: str,
|
||||
litellm_params: Optional[GenericLiteLLMParams],
|
||||
) -> dict:
|
||||
try:
|
||||
access_token = self.authenticator.get_access_token()
|
||||
except GetAccessTokenError as e:
|
||||
raise AuthenticationError(
|
||||
model=model,
|
||||
llm_provider="chatgpt",
|
||||
message=str(e),
|
||||
)
|
||||
|
||||
account_id = self.authenticator.get_account_id()
|
||||
session_id = ensure_chatgpt_session_id(litellm_params)
|
||||
default_headers = get_chatgpt_default_headers(
|
||||
access_token, account_id, session_id
|
||||
)
|
||||
return {**default_headers, **headers}
|
||||
|
||||
def transform_responses_api_request(
|
||||
self,
|
||||
model: str,
|
||||
input: Any,
|
||||
response_api_optional_request_params: dict,
|
||||
litellm_params: GenericLiteLLMParams,
|
||||
headers: dict,
|
||||
) -> dict:
|
||||
request = super().transform_responses_api_request(
|
||||
model,
|
||||
input,
|
||||
response_api_optional_request_params,
|
||||
litellm_params,
|
||||
headers,
|
||||
)
|
||||
base_instructions = get_chatgpt_default_instructions()
|
||||
existing_instructions = request.get("instructions")
|
||||
if existing_instructions:
|
||||
if base_instructions not in existing_instructions:
|
||||
request[
|
||||
"instructions"
|
||||
] = f"{base_instructions}\n\n{existing_instructions}"
|
||||
else:
|
||||
request["instructions"] = base_instructions
|
||||
request["store"] = False
|
||||
request["stream"] = True
|
||||
include = list(request.get("include") or [])
|
||||
if "reasoning.encrypted_content" not in include:
|
||||
include.append("reasoning.encrypted_content")
|
||||
request["include"] = include
|
||||
|
||||
allowed_keys = {
|
||||
"model",
|
||||
"input",
|
||||
"instructions",
|
||||
"stream",
|
||||
"store",
|
||||
"include",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"reasoning",
|
||||
"previous_response_id",
|
||||
"truncation",
|
||||
}
|
||||
|
||||
return {k: v for k, v in request.items() if k in allowed_keys}
|
||||
|
||||
def transform_response_api_response(
|
||||
self,
|
||||
model: str,
|
||||
raw_response: Any,
|
||||
logging_obj: Any,
|
||||
):
|
||||
content_type = (raw_response.headers or {}).get("content-type", "")
|
||||
body_text = raw_response.text or ""
|
||||
if "text/event-stream" not in content_type.lower():
|
||||
trimmed_body = body_text.lstrip()
|
||||
if not (
|
||||
trimmed_body.startswith("event:")
|
||||
or trimmed_body.startswith("data:")
|
||||
or "\nevent:" in body_text
|
||||
or "\ndata:" in body_text
|
||||
):
|
||||
return super().transform_response_api_response(
|
||||
model=model,
|
||||
raw_response=raw_response,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
logging_obj.post_call(
|
||||
original_response=raw_response.text,
|
||||
additional_args={"complete_input_dict": {}},
|
||||
)
|
||||
|
||||
completed_response = None
|
||||
error_message = None
|
||||
for chunk in body_text.splitlines():
|
||||
stripped_chunk = CustomStreamWrapper._strip_sse_data_from_chunk(chunk)
|
||||
if not stripped_chunk:
|
||||
continue
|
||||
stripped_chunk = stripped_chunk.strip()
|
||||
if not stripped_chunk:
|
||||
continue
|
||||
if stripped_chunk == STREAM_SSE_DONE_STRING:
|
||||
break
|
||||
try:
|
||||
parsed_chunk = json.loads(stripped_chunk)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
if not isinstance(parsed_chunk, dict):
|
||||
continue
|
||||
event_type = parsed_chunk.get("type")
|
||||
if event_type == ResponsesAPIStreamEvents.RESPONSE_COMPLETED:
|
||||
response_payload = parsed_chunk.get("response")
|
||||
if isinstance(response_payload, dict):
|
||||
response_payload = dict(response_payload)
|
||||
if "created_at" in response_payload:
|
||||
response_payload["created_at"] = _safe_convert_created_field(
|
||||
response_payload["created_at"]
|
||||
)
|
||||
try:
|
||||
completed_response = ResponsesAPIResponse(**response_payload)
|
||||
except Exception:
|
||||
completed_response = ResponsesAPIResponse.model_construct(
|
||||
**response_payload
|
||||
)
|
||||
break
|
||||
if event_type in (
|
||||
ResponsesAPIStreamEvents.RESPONSE_FAILED,
|
||||
ResponsesAPIStreamEvents.ERROR,
|
||||
):
|
||||
error_obj = parsed_chunk.get("error") or (
|
||||
parsed_chunk.get("response") or {}
|
||||
).get("error")
|
||||
if error_obj is not None:
|
||||
if isinstance(error_obj, dict):
|
||||
error_message = error_obj.get("message") or str(error_obj)
|
||||
else:
|
||||
error_message = str(error_obj)
|
||||
|
||||
if completed_response is None:
|
||||
raise OpenAIError(
|
||||
message=error_message or raw_response.text,
|
||||
status_code=raw_response.status_code,
|
||||
)
|
||||
|
||||
raw_headers = dict(raw_response.headers)
|
||||
processed_headers = process_response_headers(raw_headers)
|
||||
if not hasattr(completed_response, "_hidden_params"):
|
||||
setattr(completed_response, "_hidden_params", {})
|
||||
completed_response._hidden_params["additional_headers"] = processed_headers
|
||||
completed_response._hidden_params["headers"] = raw_headers
|
||||
return completed_response
|
||||
|
||||
def get_complete_url(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
litellm_params: dict,
|
||||
) -> str:
|
||||
api_base = api_base or self.authenticator.get_api_base() or CHATGPT_API_BASE
|
||||
api_base = api_base.rstrip("/")
|
||||
return f"{api_base}/responses"
|
||||
|
||||
def supports_native_websocket(self) -> bool:
|
||||
"""ChatGPT does not support native WebSocket for Responses API"""
|
||||
return False
|
||||
Reference in New Issue
Block a user