|
13 | 13 | from typing_extensions import Unpack, override |
14 | 14 |
|
15 | 15 | from ..types.content import ContentBlock, Messages |
| 16 | +from ..types.exceptions import ContextWindowOverflowException |
16 | 17 | from ..types.streaming import StreamEvent |
17 | 18 | from ..types.tools import ToolChoice, ToolSpec |
18 | 19 | from ._validation import validate_config_keys |
|
22 | 23 |
|
23 | 24 | T = TypeVar("T", bound=BaseModel) |
24 | 25 |
|
| 26 | +LITELLM_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [ |
| 27 | + "Context Window Error", |
| 28 | + "Context Window Exceeded", |
| 29 | + "ContextWindowExceeded", |
| 30 | + "Context window exceeded", |
| 31 | + "Input is too long", |
| 32 | + "ContextWindowExceededError", |
| 33 | +] |
| 34 | + |
25 | 35 |
|
26 | 36 | class LiteLLMModel(OpenAIModel): |
27 | 37 | """LiteLLM model provider implementation.""" |
@@ -135,7 +145,25 @@ async def stream( |
135 | 145 | logger.debug("request=<%s>", request) |
136 | 146 |
|
137 | 147 | logger.debug("invoking model") |
138 | | - response = await litellm.acompletion(**self.client_args, **request) |
| 148 | + try: |
| 149 | + response = await litellm.acompletion(**self.client_args, **request) |
| 150 | + except Exception as e: |
| 151 | + # Prefer litellm-specific typed exception if exposed |
| 152 | + litellm_exc_type = getattr(litellm, "ContextWindowExceededError", None) or getattr( |
| 153 | + litellm, "ContextWindowExceeded", None |
| 154 | + ) |
| 155 | + if litellm_exc_type and isinstance(e, litellm_exc_type): |
| 156 | + logger.warning("litellm client raised context window overflow") |
| 157 | + raise ContextWindowOverflowException(e) from e |
| 158 | + |
| 159 | + # Fall back to substring checks, similar to the Bedrock handling |
| 160 | + error_message = str(e) |
| 161 | + if any(substr in error_message for substr in LITELLM_CONTEXT_WINDOW_OVERFLOW_MESSAGES): |
| 162 | + logger.warning("litellm threw context window overflow error") |
| 163 | + raise ContextWindowOverflowException(e) from e |
| 164 | + |
| 165 | + # Not a context-window error — re-raise original |
| 166 | + raise |
139 | 167 |
|
140 | 168 | logger.debug("got response from model") |
141 | 169 | yield self.format_chunk({"chunk_type": "message_start"}) |
@@ -205,15 +233,37 @@ async def structured_output( |
205 | 233 | Yields: |
206 | 234 | Model events with the last being the structured output. |
207 | 235 | """ |
208 | | - if not supports_response_schema(self.get_config()["model_id"]): |
| 236 | + supports_schema = supports_response_schema(self.get_config()["model_id"]) |
| 237 | + |
| 238 | + # If the provider does not support response schemas, we cannot reliably parse structured output. |
| 239 | + # In that case we must not call the provider and must raise the documented ValueError. |
| 240 | + if not supports_schema: |
209 | 241 | raise ValueError("Model does not support response_format") |
210 | 242 |
|
211 | | - response = await litellm.acompletion( |
212 | | - **self.client_args, |
213 | | - model=self.get_config()["model_id"], |
214 | | - messages=self.format_request(prompt, system_prompt=system_prompt)["messages"], |
215 | | - response_format=output_model, |
216 | | - ) |
| 243 | + # For providers that DO support response schemas, call litellm and map context-window errors. |
| 244 | + try: |
| 245 | + response = await litellm.acompletion( |
| 246 | + **self.client_args, |
| 247 | + model=self.get_config()["model_id"], |
| 248 | + messages=self.format_request(prompt, system_prompt=system_prompt)["messages"], |
| 249 | + response_format=output_model, |
| 250 | + ) |
| 251 | + except Exception as e: |
| 252 | + # Prefer litellm-specific typed exception if exposed |
| 253 | + litellm_exc_type = getattr(litellm, "ContextWindowExceededError", None) or getattr( |
| 254 | + litellm, "ContextWindowExceeded", None |
| 255 | + ) |
| 256 | + if litellm_exc_type and isinstance(e, litellm_exc_type): |
| 257 | + logger.warning("litellm client raised context window overflow in structured_output") |
| 258 | + raise ContextWindowOverflowException(e) from e |
| 259 | + |
| 260 | + error_message = str(e) |
| 261 | + if any(substr in error_message for substr in LITELLM_CONTEXT_WINDOW_OVERFLOW_MESSAGES): |
| 262 | + logger.warning("litellm threw context window overflow error in structured_output") |
| 263 | + raise ContextWindowOverflowException(e) from e |
| 264 | + |
| 265 | + # Not a context-window error — re-raise original |
| 266 | + raise |
217 | 267 |
|
218 | 268 | if len(response.choices) > 1: |
219 | 269 | raise ValueError("Multiple choices found in the response.") |
|
0 commit comments