Skip to content

Commit fc6b1e5

Browse files
authored
BREAKING CHANGE: Make AgentStreamEvent union of ModelResponseStreamEvent and HandleResponseEvent (#2689)
1 parent ebb4ee8 commit fc6b1e5

File tree

10 files changed

+38
-52
lines changed

10 files changed

+38
-52
lines changed

docs/agents.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -115,15 +115,13 @@ The example below shows how to stream events and text output. You can also [stre
115115
import asyncio
116116
from collections.abc import AsyncIterable
117117
from datetime import date
118-
from typing import Union
119118

120119
from pydantic_ai import Agent
121120
from pydantic_ai.messages import (
122121
AgentStreamEvent,
123122
FinalResultEvent,
124123
FunctionToolCallEvent,
125124
FunctionToolResultEvent,
126-
HandleResponseEvent,
127125
PartDeltaEvent,
128126
PartStartEvent,
129127
TextPartDelta,
@@ -152,7 +150,7 @@ output_messages: list[str] = []
152150

153151
async def event_stream_handler(
154152
ctx: RunContext,
155-
event_stream: AsyncIterable[Union[AgentStreamEvent, HandleResponseEvent]],
153+
event_stream: AsyncIterable[AgentStreamEvent],
156154
):
157155
async for event in event_stream:
158156
if isinstance(event, PartStartEvent):

pydantic_ai_slim/pydantic_ai/ag_ui.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,11 @@
2828
from .agent import AbstractAgent, AgentRun
2929
from .exceptions import UserError
3030
from .messages import (
31-
AgentStreamEvent,
3231
FunctionToolResultEvent,
3332
ModelMessage,
3433
ModelRequest,
3534
ModelResponse,
35+
ModelResponseStreamEvent,
3636
PartDeltaEvent,
3737
PartStartEvent,
3838
SystemPromptPart,
@@ -403,7 +403,7 @@ async def _agent_stream(run: AgentRun[AgentDepsT, Any]) -> AsyncIterator[BaseEve
403403

404404
async def _handle_model_request_event(
405405
stream_ctx: _RequestStreamContext,
406-
agent_event: AgentStreamEvent,
406+
agent_event: ModelResponseStreamEvent,
407407
) -> AsyncIterator[BaseEvent]:
408408
"""Handle an agent event and yield AG-UI protocol events.
409409

pydantic_ai_slim/pydantic_ai/agent/abstract.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator, Mapping, Sequence
66
from contextlib import AbstractAsyncContextManager, asynccontextmanager, contextmanager
77
from types import FrameType
8-
from typing import TYPE_CHECKING, Any, Callable, Generic, Union, cast, overload
8+
from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload
99

1010
from typing_extensions import Self, TypeAlias, TypeIs, TypeVar
1111

@@ -53,11 +53,7 @@
5353
"""Type variable for the result data of a run where `output_type` was customized on the run call."""
5454

5555
EventStreamHandler: TypeAlias = Callable[
56-
[
57-
RunContext[AgentDepsT],
58-
AsyncIterable[Union[_messages.AgentStreamEvent, _messages.HandleResponseEvent]],
59-
],
60-
Awaitable[None],
56+
[RunContext[AgentDepsT], AsyncIterable[_messages.AgentStreamEvent]], Awaitable[None]
6157
]
6258
"""A function that receives agent [`RunContext`][pydantic_ai.tools.RunContext] and an async iterable of events from the model's streaming response and the agent's execution of tools."""
6359

@@ -445,7 +441,9 @@ async def main():
445441
async with node.stream(graph_ctx) as stream:
446442
final_result_event = None
447443

448-
async def stream_to_final(stream: AgentStream) -> AsyncIterator[_messages.AgentStreamEvent]:
444+
async def stream_to_final(
445+
stream: AgentStream,
446+
) -> AsyncIterator[_messages.ModelResponseStreamEvent]:
449447
nonlocal final_result_event
450448
async for event in stream:
451449
yield event

pydantic_ai_slim/pydantic_ai/direct.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,9 @@ class StreamedResponseSync:
275275
"""
276276

277277
_async_stream_cm: AbstractAsyncContextManager[StreamedResponse]
278-
_queue: queue.Queue[messages.AgentStreamEvent | Exception | None] = field(default_factory=queue.Queue, init=False)
278+
_queue: queue.Queue[messages.ModelResponseStreamEvent | Exception | None] = field(
279+
default_factory=queue.Queue, init=False
280+
)
279281
_thread: threading.Thread | None = field(default=None, init=False)
280282
_stream_response: StreamedResponse | None = field(default=None, init=False)
281283
_exception: Exception | None = field(default=None, init=False)
@@ -295,8 +297,8 @@ def __exit__(
295297
) -> None:
296298
self._cleanup()
297299

298-
def __iter__(self) -> Iterator[messages.AgentStreamEvent]:
299-
"""Stream the response as an iterable of [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s."""
300+
def __iter__(self) -> Iterator[messages.ModelResponseStreamEvent]:
301+
"""Stream the response as an iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
300302
self._check_context_manager_usage()
301303

302304
while True:

pydantic_ai_slim/pydantic_ai/messages.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,13 +1266,10 @@ class FinalResultEvent:
12661266
__repr__ = _utils.dataclasses_no_defaults_repr
12671267

12681268

1269-
ModelResponseStreamEvent = Annotated[Union[PartStartEvent, PartDeltaEvent], pydantic.Discriminator('event_kind')]
1270-
"""An event in the model response stream, either starting a new part or applying a delta to an existing one."""
1271-
1272-
AgentStreamEvent = Annotated[
1269+
ModelResponseStreamEvent = Annotated[
12731270
Union[PartStartEvent, PartDeltaEvent, FinalResultEvent], pydantic.Discriminator('event_kind')
12741271
]
1275-
"""An event in the agent stream."""
1272+
"""An event in the model response stream, starting a new part, applying a delta to an existing one, or indicating the final result."""
12761273

12771274

12781275
@dataclass(repr=False)
@@ -1342,3 +1339,6 @@ class BuiltinToolResultEvent:
13421339
pydantic.Discriminator('event_kind'),
13431340
]
13441341
"""An event yielded when handling a model response, indicating tool calls and results."""
1342+
1343+
AgentStreamEvent = Annotated[Union[ModelResponseStreamEvent, HandleResponseEvent], pydantic.Discriminator('event_kind')]
1344+
"""An event in the agent stream: model response stream events and response-handling events."""

pydantic_ai_slim/pydantic_ai/models/__init__.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from ..builtin_tools import AbstractBuiltinTool
2626
from ..exceptions import UserError
2727
from ..messages import (
28-
AgentStreamEvent,
2928
FileUrl,
3029
FinalResultEvent,
3130
ModelMessage,
@@ -555,11 +554,11 @@ class StreamedResponse(ABC):
555554
final_result_event: FinalResultEvent | None = field(default=None, init=False)
556555

557556
_parts_manager: ModelResponsePartsManager = field(default_factory=ModelResponsePartsManager, init=False)
558-
_event_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
557+
_event_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
559558
_usage: RequestUsage = field(default_factory=RequestUsage, init=False)
560559

561-
def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
562-
"""Stream the response as an async iterable of [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s.
560+
def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
561+
"""Stream the response as an async iterable of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s.
563562
564563
This proxies the `_event_iterator()` and emits all events, while also checking for matches
565564
on the result schema and emitting a [`FinalResultEvent`][pydantic_ai.messages.FinalResultEvent] if/when the
@@ -569,7 +568,7 @@ def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
569568

570569
async def iterator_with_final_event(
571570
iterator: AsyncIterator[ModelResponseStreamEvent],
572-
) -> AsyncIterator[AgentStreamEvent]:
571+
) -> AsyncIterator[ModelResponseStreamEvent]:
573572
async for event in iterator:
574573
yield event
575574
if (

pydantic_ai_slim/pydantic_ai/result.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
ToolOutputSchema,
2323
)
2424
from ._run_context import AgentDepsT, RunContext
25-
from .messages import AgentStreamEvent
25+
from .messages import ModelResponseStreamEvent
2626
from .output import (
2727
OutputDataT,
2828
ToolOutput,
@@ -51,7 +51,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]):
5151
_usage_limits: UsageLimits | None
5252
_tool_manager: ToolManager[AgentDepsT]
5353

54-
_agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False)
54+
_agent_stream_iterator: AsyncIterator[ModelResponseStreamEvent] | None = field(default=None, init=False)
5555
_initial_run_ctx_usage: RunUsage = field(init=False)
5656

5757
def __post_init__(self):
@@ -221,8 +221,8 @@ async def _stream_text_deltas() -> AsyncIterator[str]:
221221
deltas.append(text)
222222
yield ''.join(deltas)
223223

224-
def __aiter__(self) -> AsyncIterator[AgentStreamEvent]:
225-
"""Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s."""
224+
def __aiter__(self) -> AsyncIterator[ModelResponseStreamEvent]:
225+
"""Stream [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s."""
226226
if self._agent_stream_iterator is None:
227227
self._agent_stream_iterator = _get_usage_checking_stream_response(
228228
self._raw_stream_response, self._usage_limits, self.usage
@@ -426,7 +426,7 @@ def _get_usage_checking_stream_response(
426426
stream_response: models.StreamedResponse,
427427
limits: UsageLimits | None,
428428
get_usage: Callable[[], RunUsage],
429-
) -> AsyncIterator[AgentStreamEvent]:
429+
) -> AsyncIterator[ModelResponseStreamEvent]:
430430
if limits is not None and limits.has_token_limits():
431431

432432
async def _usage_checking_iterator():

tests/test_agent.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
from pydantic_ai.messages import (
3030
AgentStreamEvent,
3131
BinaryContent,
32-
HandleResponseEvent,
3332
ImageUrl,
3433
ModelMessage,
3534
ModelMessagesTypeAdapter,
@@ -4220,9 +4219,7 @@ def foo() -> str:
42204219

42214220

42224221
async def test_wrapper_agent():
4223-
async def event_stream_handler(
4224-
ctx: RunContext[None], events: AsyncIterable[Union[AgentStreamEvent, HandleResponseEvent]]
4225-
):
4222+
async def event_stream_handler(ctx: RunContext[None], events: AsyncIterable[AgentStreamEvent]):
42264223
pass # pragma: no cover
42274224

42284225
foo_toolset = FunctionToolset()

tests/test_streaming.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
FinalResultEvent,
2121
FunctionToolCallEvent,
2222
FunctionToolResultEvent,
23-
HandleResponseEvent,
2423
ModelMessage,
2524
ModelRequest,
2625
ModelResponse,
@@ -1271,11 +1270,9 @@ async def test_run_event_stream_handler():
12711270
async def ret_a(x: str) -> str:
12721271
return f'{x}-apple'
12731272

1274-
events: list[AgentStreamEvent | HandleResponseEvent] = []
1273+
events: list[AgentStreamEvent] = []
12751274

1276-
async def event_stream_handler(
1277-
ctx: RunContext[None], stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent]
1278-
):
1275+
async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[AgentStreamEvent]):
12791276
async for event in stream:
12801277
events.append(event)
12811278

@@ -1314,11 +1311,9 @@ def test_run_sync_event_stream_handler():
13141311
async def ret_a(x: str) -> str:
13151312
return f'{x}-apple'
13161313

1317-
events: list[AgentStreamEvent | HandleResponseEvent] = []
1314+
events: list[AgentStreamEvent] = []
13181315

1319-
async def event_stream_handler(
1320-
ctx: RunContext[None], stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent]
1321-
):
1316+
async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[AgentStreamEvent]):
13221317
async for event in stream:
13231318
events.append(event)
13241319

@@ -1357,11 +1352,9 @@ async def test_run_stream_event_stream_handler():
13571352
async def ret_a(x: str) -> str:
13581353
return f'{x}-apple'
13591354

1360-
events: list[AgentStreamEvent | HandleResponseEvent] = []
1355+
events: list[AgentStreamEvent] = []
13611356

1362-
async def event_stream_handler(
1363-
ctx: RunContext[None], stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent]
1364-
):
1357+
async def event_stream_handler(ctx: RunContext[None], stream: AsyncIterable[AgentStreamEvent]):
13651358
async for event in stream:
13661359
events.append(event)
13671360

tests/test_temporal.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
FinalResultEvent,
2020
FunctionToolCallEvent,
2121
FunctionToolResultEvent,
22-
HandleResponseEvent,
2322
ModelMessage,
2423
ModelRequest,
2524
PartDeltaEvent,
@@ -196,7 +195,7 @@ class Deps(BaseModel):
196195

197196
async def event_stream_handler(
198197
ctx: RunContext[Deps],
199-
stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent],
198+
stream: AsyncIterable[AgentStreamEvent],
200199
):
201200
logfire.info(f'{ctx.run_step=}')
202201
async for event in stream:
@@ -636,11 +635,11 @@ async def test_complex_agent_run_in_workflow(
636635

637636

638637
async def test_complex_agent_run(allow_model_requests: None):
639-
events: list[AgentStreamEvent | HandleResponseEvent] = []
638+
events: list[AgentStreamEvent] = []
640639

641640
async def event_stream_handler(
642641
ctx: RunContext[Deps],
643-
stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent],
642+
stream: AsyncIterable[AgentStreamEvent],
644643
):
645644
async for event in stream:
646645
events.append(event)
@@ -1161,7 +1160,7 @@ async def test_temporal_agent_iter_in_workflow(allow_model_requests: None, clien
11611160

11621161
async def simple_event_stream_handler(
11631162
ctx: RunContext[None],
1164-
stream: AsyncIterable[AgentStreamEvent | HandleResponseEvent],
1163+
stream: AsyncIterable[AgentStreamEvent],
11651164
):
11661165
pass
11671166

0 commit comments

Comments
 (0)