
Commit 9b94a1b

fix(litellm): map LiteLLM context-window errors to ContextWindowOverflowException
1 parent 776fd93 commit 9b94a1b
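
In practice, this change lets callers handle prompt overflow uniformly across providers: LiteLLM failures now raise the same ContextWindowOverflowException already used for Bedrock, instead of a provider-specific error string. A minimal caller-side sketch (the api_key value, model_id, and message text below are illustrative placeholders, not part of this commit):

import asyncio

from strands.models.litellm import LiteLLMModel
from strands.types.exceptions import ContextWindowOverflowException


async def main() -> None:
    # Hypothetical client args and model id, for illustration only.
    model = LiteLLMModel(client_args={"api_key": "..."}, model_id="gpt-4o")
    messages = [{"role": "user", "content": [{"text": "a very long prompt ..."}]}]
    try:
        async for event in model.stream(messages):
            pass  # consume streamed events as usual
    except ContextWindowOverflowException:
        # With this commit, LiteLLM context-window failures land here,
        # so a caller can, for example, truncate history and retry.
        print("prompt exceeded the model's context window")


asyncio.run(main())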

File tree

2 files changed: +85 -8 lines


src/strands/models/litellm.py

Lines changed: 58 additions & 8 deletions
@@ -13,6 +13,7 @@
 from typing_extensions import Unpack, override
 
 from ..types.content import ContentBlock, Messages
+from ..types.exceptions import ContextWindowOverflowException
 from ..types.streaming import StreamEvent
 from ..types.tools import ToolChoice, ToolSpec
 from ._validation import validate_config_keys
@@ -22,6 +23,15 @@
 
 T = TypeVar("T", bound=BaseModel)
 
+LITELLM_CONTEXT_WINDOW_OVERFLOW_MESSAGES = [
+    "Context Window Error",
+    "Context Window Exceeded",
+    "ContextWindowExceeded",
+    "Context window exceeded",
+    "Input is too long",
+    "ContextWindowExceededError",
+]
+
 
 class LiteLLMModel(OpenAIModel):
     """LiteLLM model provider implementation."""
@@ -135,7 +145,25 @@ async def stream(
         logger.debug("request=<%s>", request)
 
         logger.debug("invoking model")
-        response = await litellm.acompletion(**self.client_args, **request)
+        try:
+            response = await litellm.acompletion(**self.client_args, **request)
+        except Exception as e:
+            # Prefer the litellm-specific typed exception if it is exposed.
+            litellm_exc_type = getattr(litellm, "ContextWindowExceededError", None) or getattr(
+                litellm, "ContextWindowExceeded", None
+            )
+            if litellm_exc_type and isinstance(e, litellm_exc_type):
+                logger.warning("litellm client raised context window overflow")
+                raise ContextWindowOverflowException(e) from e
+
+            # Fall back to substring checks, similar to the Bedrock handling.
+            error_message = str(e)
+            if any(substr in error_message for substr in LITELLM_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
+                logger.warning("litellm threw context window overflow error")
+                raise ContextWindowOverflowException(e) from e
+
+            # Not a context-window error; re-raise the original.
+            raise
 
         logger.debug("got response from model")
         yield self.format_chunk({"chunk_type": "message_start"})
@@ -205,15 +233,37 @@ async def structured_output(
         Yields:
             Model events with the last being the structured output.
         """
-        if not supports_response_schema(self.get_config()["model_id"]):
+        supports_schema = supports_response_schema(self.get_config()["model_id"])
+
+        # If the provider does not support response schemas, we cannot reliably parse structured output.
+        # In that case we must not call the provider and must raise the documented ValueError.
+        if not supports_schema:
             raise ValueError("Model does not support response_format")
 
-        response = await litellm.acompletion(
-            **self.client_args,
-            model=self.get_config()["model_id"],
-            messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
-            response_format=output_model,
-        )
+        # For providers that DO support response schemas, call litellm and map context-window errors.
+        try:
+            response = await litellm.acompletion(
+                **self.client_args,
+                model=self.get_config()["model_id"],
+                messages=self.format_request(prompt, system_prompt=system_prompt)["messages"],
+                response_format=output_model,
+            )
+        except Exception as e:
+            # Prefer the litellm-specific typed exception if it is exposed.
+            litellm_exc_type = getattr(litellm, "ContextWindowExceededError", None) or getattr(
+                litellm, "ContextWindowExceeded", None
+            )
+            if litellm_exc_type and isinstance(e, litellm_exc_type):
+                logger.warning("litellm client raised context window overflow in structured_output")
+                raise ContextWindowOverflowException(e) from e
+
+            error_message = str(e)
+            if any(substr in error_message for substr in LITELLM_CONTEXT_WINDOW_OVERFLOW_MESSAGES):
+                logger.warning("litellm threw context window overflow error in structured_output")
+                raise ContextWindowOverflowException(e) from e
+
+            # Not a context-window error; re-raise the original.
+            raise
 
         if len(response.choices) > 1:
             raise ValueError("Multiple choices found in the response.")

tests/strands/models/test_litellm.py

Lines changed: 27 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 import strands
 from strands.models.litellm import LiteLLMModel
+from strands.types.exceptions import ContextWindowOverflowException
 
 
 @pytest.fixture
@@ -301,6 +302,32 @@ async def test_structured_output_unsupported_model(litellm_acompletion, model, t
     litellm_acompletion.assert_not_called()
 
 
+@pytest.mark.asyncio
+async def test_stream_context_window_maps_to_exception(litellm_acompletion, model):
+    # Make the litellm client raise an error that indicates a context-window overflow.
+    litellm_acompletion.side_effect = Exception("Input is too long for requested model")
+
+    with pytest.raises(ContextWindowOverflowException):
+        async for _ in model.stream([{"role": "user", "content": [{"text": "x"}]}]):
+            pass
+
+
+@pytest.mark.asyncio
+async def test_structured_output_context_window_maps_to_exception(litellm_acompletion, model, test_output_model_cls):
+    # The litellm structured_output path raising a similar message should be mapped too.
+    litellm_acompletion.side_effect = Exception("Context Window Error - Input too long")
+
+    # Ensure supports_response_schema returns True so structured_output will call litellm.acompletion
+    # and we can observe the mapping to ContextWindowOverflowException.
+    with unittest.mock.patch.object(strands.models.litellm, "supports_response_schema", return_value=True):
+        with pytest.raises(ContextWindowOverflowException):
+            # structured_output is an async generator; consuming it should raise our mapped exception.
+            async for _ in model.structured_output(
+                output_model=test_output_model_cls, prompt=[{"role": "user", "content": [{"text": "x"}]}]
+            ):
+                pass
+
+
 def test_config_validation_warns_on_unknown_keys(litellm_acompletion, captured_warnings):
     """Test that unknown config keys emit a warning."""
     LiteLLMModel(client_args={"api_key": "test"}, model_id="test-model", invalid_param="test")
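
Note the patch target in the second test: supports_response_schema is patched on strands.models.litellm, where the function is looked up, rather than on the package that defines it, which is the usual unittest.mock pattern. Assuming a standard pytest-asyncio setup, the new tests can be run in isolation with:

pytest tests/strands/models/test_litellm.py -k context_window

where the -k filter is simply the substring shared by both new test names.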
