Files
lijiaoqiao/llm-gateway-competitors/litellm-wheel-src/litellm/llms/chatgpt/responses/transformation.py
2026-03-26 20:06:14 +08:00

207 lines
7.3 KiB
Python

import json
from typing import Any, Optional
from litellm.constants import STREAM_SSE_DONE_STRING
from litellm.exceptions import AuthenticationError
from litellm.litellm_core_utils.core_helpers import process_response_headers
from litellm.litellm_core_utils.llm_response_utils.convert_dict_to_response import (
_safe_convert_created_field,
)
from litellm.llms.openai.common_utils import OpenAIError
from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig
from litellm.types.llms.openai import (
ResponsesAPIResponse,
ResponsesAPIStreamEvents,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LlmProviders
from litellm.utils import CustomStreamWrapper
from ..authenticator import Authenticator
from ..common_utils import (
CHATGPT_API_BASE,
GetAccessTokenError,
ensure_chatgpt_session_id,
get_chatgpt_default_headers,
get_chatgpt_default_instructions,
)
class ChatGPTResponsesAPIConfig(OpenAIResponsesAPIConfig):
    """Responses API configuration for the ChatGPT backend.

    Extends the OpenAI Responses API config with ChatGPT-specific behavior:

    - OAuth-style authentication via :class:`Authenticator` (access token,
      account id, per-request session id headers).
    - Mandatory base instructions prepended to every request.
    - Forced ``store=False`` / ``stream=True`` and a whitelist of request keys.
    - Parsing of a fully buffered SSE body into a single
      ``ResponsesAPIResponse`` (the backend always streams, even for
      non-streaming LiteLLM calls).
    """

    def __init__(self) -> None:
        super().__init__()
        # Resolves/refreshes credentials for the ChatGPT backend.
        self.authenticator = Authenticator()

    @property
    def custom_llm_provider(self) -> LlmProviders:
        return LlmProviders.CHATGPT

    def validate_environment(
        self,
        headers: dict,
        model: str,
        litellm_params: Optional[GenericLiteLLMParams],
    ) -> dict:
        """Build the request headers, resolving a ChatGPT access token.

        Args:
            headers: Caller-supplied headers; they take precedence over the
                ChatGPT defaults.
            model: Model name (used only for error reporting).
            litellm_params: Optional params; used to pin a stable session id.

        Returns:
            Merged header dict (defaults overlaid by ``headers``).

        Raises:
            AuthenticationError: if an access token cannot be obtained.
        """
        try:
            access_token = self.authenticator.get_access_token()
        except GetAccessTokenError as e:
            # Chain the cause so the underlying token failure stays visible
            # in tracebacks instead of being swallowed by the re-raise.
            raise AuthenticationError(
                model=model,
                llm_provider="chatgpt",
                message=str(e),
            ) from e
        account_id = self.authenticator.get_account_id()
        session_id = ensure_chatgpt_session_id(litellm_params)
        default_headers = get_chatgpt_default_headers(
            access_token, account_id, session_id
        )
        # Caller-supplied headers win over the defaults.
        return {**default_headers, **headers}

    def transform_responses_api_request(
        self,
        model: str,
        input: Any,
        response_api_optional_request_params: dict,
        litellm_params: GenericLiteLLMParams,
        headers: dict,
    ) -> dict:
        """Transform a Responses API request for the ChatGPT endpoint.

        Starts from the standard OpenAI transformation, then:

        1. Prepends the mandatory ChatGPT base instructions (unless they are
           already contained in the caller's instructions).
        2. Forces ``store=False`` and ``stream=True``.
        3. Ensures ``reasoning.encrypted_content`` is in ``include``.
        4. Strips every key the ChatGPT endpoint does not accept.
        """
        request = super().transform_responses_api_request(
            model,
            input,
            response_api_optional_request_params,
            litellm_params,
            headers,
        )
        base_instructions = get_chatgpt_default_instructions()
        existing_instructions = request.get("instructions")
        if existing_instructions:
            # Avoid duplicating the base instructions if the caller already
            # embedded them.
            if base_instructions not in existing_instructions:
                request[
                    "instructions"
                ] = f"{base_instructions}\n\n{existing_instructions}"
        else:
            request["instructions"] = base_instructions
        # The ChatGPT backend only supports streamed, non-stored requests.
        request["store"] = False
        request["stream"] = True
        include = list(request.get("include") or [])
        if "reasoning.encrypted_content" not in include:
            include.append("reasoning.encrypted_content")
        request["include"] = include
        # Whitelist of parameters the ChatGPT endpoint accepts; everything
        # else is silently dropped.
        allowed_keys = {
            "model",
            "input",
            "instructions",
            "stream",
            "store",
            "include",
            "tools",
            "tool_choice",
            "reasoning",
            "previous_response_id",
            "truncation",
        }
        return {k: v for k, v in request.items() if k in allowed_keys}

    def transform_response_api_response(
        self,
        model: str,
        raw_response: Any,
        logging_obj: Any,
    ):
        """Parse a (possibly SSE-framed) response into a ResponsesAPIResponse.

        The ChatGPT backend streams even when the caller did not ask for
        streaming, so a non-streaming call receives the full SSE body at
        once. If the body does not look like SSE, defer to the standard
        JSON handling; otherwise scan the buffered events for the
        ``response.completed`` payload (or an error event).

        Raises:
            OpenAIError: if no completed response is found in the body.
        """
        content_type = (raw_response.headers or {}).get("content-type", "")
        body_text = raw_response.text or ""
        if "text/event-stream" not in content_type.lower():
            # Heuristic: some responses are SSE-framed without the proper
            # content-type, so also sniff the body for SSE field markers.
            trimmed_body = body_text.lstrip()
            if not (
                trimmed_body.startswith("event:")
                or trimmed_body.startswith("data:")
                or "\nevent:" in body_text
                or "\ndata:" in body_text
            ):
                return super().transform_response_api_response(
                    model=model,
                    raw_response=raw_response,
                    logging_obj=logging_obj,
                )
        logging_obj.post_call(
            original_response=raw_response.text,
            additional_args={"complete_input_dict": {}},
        )
        completed_response = None
        error_message = None
        # Each SSE "data:" line is expected to carry one complete JSON event.
        for chunk in body_text.splitlines():
            stripped_chunk = CustomStreamWrapper._strip_sse_data_from_chunk(chunk)
            if not stripped_chunk:
                continue
            stripped_chunk = stripped_chunk.strip()
            if not stripped_chunk:
                continue
            if stripped_chunk == STREAM_SSE_DONE_STRING:
                break
            try:
                parsed_chunk = json.loads(stripped_chunk)
            except json.JSONDecodeError:
                # Non-JSON SSE lines (comments, partial frames) are skipped.
                continue
            if not isinstance(parsed_chunk, dict):
                continue
            event_type = parsed_chunk.get("type")
            if event_type == ResponsesAPIStreamEvents.RESPONSE_COMPLETED:
                response_payload = parsed_chunk.get("response")
                if isinstance(response_payload, dict):
                    # Copy before mutating so the parsed event stays intact.
                    response_payload = dict(response_payload)
                    if "created_at" in response_payload:
                        response_payload["created_at"] = _safe_convert_created_field(
                            response_payload["created_at"]
                        )
                    try:
                        completed_response = ResponsesAPIResponse(**response_payload)
                    except Exception:
                        # Fall back to an unvalidated model if strict
                        # construction rejects the payload shape.
                        completed_response = ResponsesAPIResponse.model_construct(
                            **response_payload
                        )
                break
            if event_type in (
                ResponsesAPIStreamEvents.RESPONSE_FAILED,
                ResponsesAPIStreamEvents.ERROR,
            ):
                # Remember the error but keep scanning: a completed event
                # later in the buffer still wins.
                error_obj = parsed_chunk.get("error") or (
                    parsed_chunk.get("response") or {}
                ).get("error")
                if error_obj is not None:
                    if isinstance(error_obj, dict):
                        error_message = error_obj.get("message") or str(error_obj)
                    else:
                        error_message = str(error_obj)
        if completed_response is None:
            raise OpenAIError(
                message=error_message or raw_response.text,
                status_code=raw_response.status_code,
            )
        raw_headers = dict(raw_response.headers)
        processed_headers = process_response_headers(raw_headers)
        if not hasattr(completed_response, "_hidden_params"):
            setattr(completed_response, "_hidden_params", {})
        completed_response._hidden_params["additional_headers"] = processed_headers
        completed_response._hidden_params["headers"] = raw_headers
        return completed_response

    def get_complete_url(
        self,
        api_base: Optional[str],
        litellm_params: dict,
    ) -> str:
        """Return the full /responses endpoint URL.

        Precedence: explicit ``api_base`` > authenticator-provided base >
        the hard-coded ChatGPT default.
        """
        api_base = api_base or self.authenticator.get_api_base() or CHATGPT_API_BASE
        api_base = api_base.rstrip("/")
        return f"{api_base}/responses"

    def supports_native_websocket(self) -> bool:
        """ChatGPT does not support native WebSocket for Responses API"""
        return False