Skip to content

Commit 3ce4714

Browse files
google-genai-bot and copybara-github
authored and committed
refactor: Extract a utility for aggregating partial streaming responses and emitting LlmResponses for them as needed
PiperOrigin-RevId: 796951203
1 parent 2dd432c commit 3ce4714

File tree

4 files changed

+314
-54
lines changed

4 files changed

+314
-54
lines changed

src/google/adk/models/gemini_llm_connection.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
144144

145145
text = ''
146146
async with Aclosing(self._gemini_session.receive()) as agen:
147+
# TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
148+
# partial content and emit responses as needed.
147149
async for message in agen:
148150
logger.debug('Got LLM Live message: %s', message)
149151
if message.server_content:

src/google/adk/models/google_llm.py

Lines changed: 8 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from .. import version
3535
from ..utils.context_utils import Aclosing
36+
from ..utils.streaming_utils import StreamingResponseAggregator
3637
from ..utils.variant_utils import GoogleLLMVariant
3738
from .base_llm import BaseLlm
3839
from .base_llm_connection import BaseLlmConnection
@@ -133,68 +134,21 @@ async def generate_content_async(
133134
contents=llm_request.contents,
134135
config=llm_request.config,
135136
)
136-
response = None
137-
thought_text = ''
138-
text = ''
139-
usage_metadata = None
137+
140138
# for sse, similar as bidi (see receive method in gemini_llm_connecton.py),
141139
# we need to mark those text content as partial and after all partial
142140
# contents are sent, we send an accumulated event which contains all the
143141
# previous partial content. The only difference is bidi rely on
144142
# complete_turn flag to detect end while sse depends on finish_reason.
143+
aggregator = StreamingResponseAggregator()
145144
async with Aclosing(responses) as agen:
146145
async for response in agen:
147146
logger.debug(_build_response_log(response))
148-
llm_response = LlmResponse.create(response)
149-
usage_metadata = llm_response.usage_metadata
150-
if (
151-
llm_response.content
152-
and llm_response.content.parts
153-
and llm_response.content.parts[0].text
154-
):
155-
part0 = llm_response.content.parts[0]
156-
if part0.thought:
157-
thought_text += part0.text
158-
else:
159-
text += part0.text
160-
llm_response.partial = True
161-
elif (thought_text or text) and (
162-
not llm_response.content
163-
or not llm_response.content.parts
164-
# don't yield the merged text event when receiving audio data
165-
or not llm_response.content.parts[0].inline_data
166-
):
167-
parts = []
168-
if thought_text:
169-
parts.append(types.Part(text=thought_text, thought=True))
170-
if text:
171-
parts.append(types.Part.from_text(text=text))
172-
yield LlmResponse(
173-
content=types.ModelContent(parts=parts),
174-
usage_metadata=llm_response.usage_metadata,
175-
)
176-
thought_text = ''
177-
text = ''
178-
yield llm_response
179-
180-
# generate an aggregated content at the end regardless the
181-
# response.candidates[0].finish_reason
182-
if (text or thought_text) and response and response.candidates:
183-
parts = []
184-
if thought_text:
185-
parts.append(types.Part(text=thought_text, thought=True))
186-
if text:
187-
parts.append(types.Part.from_text(text=text))
188-
yield LlmResponse(
189-
content=types.ModelContent(parts=parts),
190-
error_code=None
191-
if response.candidates[0].finish_reason == FinishReason.STOP
192-
else response.candidates[0].finish_reason,
193-
error_message=None
194-
if response.candidates[0].finish_reason == FinishReason.STOP
195-
else response.candidates[0].finish_message,
196-
usage_metadata=usage_metadata,
197-
)
147+
async for llm_response in aggregator.process_response(response):
148+
yield llm_response
149+
150+
async for llm_response in aggregator.close():
151+
yield llm_response
198152

199153
else:
200154
response = await self.api_client.aio.models.generate_content(
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import AsyncGenerator
18+
19+
from google.genai import types
20+
21+
from ..models.llm_response import LlmResponse
22+
23+
24+
class StreamingResponseAggregator:
  """Aggregates partial streaming responses.

  It aggregates content from partial responses, and generates LlmResponses for
  individual (partial) model responses, as well as for aggregated content.
  """

  def __init__(self):
    # Accumulated non-thought text from partial responses since the last flush.
    self._text = ''
    # Accumulated thought (reasoning) text since the last flush.
    self._thought_text = ''
    # Usage metadata from the most recently processed response.
    self._usage_metadata = None
    # Most recently processed raw response; close() reads the final
    # candidate's finish_reason/finish_message from it.
    self._response = None

  async def process_response(
      self, response: types.GenerateContentResponse
  ) -> AsyncGenerator[LlmResponse, None]:
    """Processes a single model response.

    Args:
      response: The response to process.

    Yields:
      The generated LlmResponse(s), for the partial response, and the aggregated
      response if needed.
    """
    self._response = response
    llm_response = LlmResponse.create(response)
    self._usage_metadata = llm_response.usage_metadata
    if (
        llm_response.content
        and llm_response.content.parts
        and llm_response.content.parts[0].text
    ):
      # Text chunk: accumulate it (in the thought or regular buffer) and mark
      # the per-chunk response as partial.
      part0 = llm_response.content.parts[0]
      if part0.thought:
        self._thought_text += part0.text
      else:
        self._text += part0.text
      llm_response.partial = True
    elif (self._thought_text or self._text) and (
        not llm_response.content
        or not llm_response.content.parts
        # don't yield the merged text event when receiving audio data
        or not llm_response.content.parts[0].inline_data
    ):
      # A non-text response arrived while text was pending: flush the
      # accumulated text as one aggregated response before yielding it.
      parts = []
      if self._thought_text:
        parts.append(types.Part(text=self._thought_text, thought=True))
      if self._text:
        parts.append(types.Part.from_text(text=self._text))
      yield LlmResponse(
          content=types.ModelContent(parts=parts),
          usage_metadata=llm_response.usage_metadata,
      )
      self._thought_text = ''
      self._text = ''
    yield llm_response

  async def close(self) -> AsyncGenerator[LlmResponse, None]:
    """Generates an aggregated response at the end, if needed.

    This should be called after all the model responses are processed.

    Yields:
      The aggregated LlmResponse.
    """
    if (
        (self._text or self._thought_text)
        and self._response
        and self._response.candidates
    ):
      parts = []
      if self._thought_text:
        parts.append(types.Part(text=self._thought_text, thought=True))
      if self._text:
        parts.append(types.Part.from_text(text=self._text))
      candidate = self._response.candidates[0]
      # A STOP finish is success; any other finish reason is surfaced as an
      # error on the aggregated response.
      yield LlmResponse(
          content=types.ModelContent(parts=parts),
          error_code=None
          if candidate.finish_reason == types.FinishReason.STOP
          else candidate.finish_reason,
          error_message=None
          if candidate.finish_reason == types.FinishReason.STOP
          else candidate.finish_message,
          usage_metadata=self._usage_metadata,
      )
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from google.adk.utils import streaming_utils
18+
from google.genai import types
19+
import pytest
20+
21+
22+
class TestStreamingResponseAggregator:
  """Unit tests for streaming_utils.StreamingResponseAggregator."""

  @staticmethod
  def _response(*parts, finish_reason=None, finish_message=None):
    """Builds a single-candidate GenerateContentResponse from the given parts."""
    return types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(parts=list(parts)),
                finish_reason=finish_reason,
                finish_message=finish_message,
            )
        ]
    )

  @staticmethod
  async def _drain(agen):
    """Collects every LlmResponse yielded by the given async generator."""
    return [item async for item in agen]

  @pytest.mark.asyncio
  async def test_process_response_with_text(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    emitted = await self._drain(
        aggregator.process_response(self._response(types.Part(text="Hello")))
    )
    assert len(emitted) == 1
    assert emitted[0].content.parts[0].text == "Hello"
    assert emitted[0].partial

  @pytest.mark.asyncio
  async def test_process_response_with_thought(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    emitted = await self._drain(
        aggregator.process_response(
            self._response(types.Part(text="Thinking...", thought=True))
        )
    )
    assert len(emitted) == 1
    assert emitted[0].content.parts[0].text == "Thinking..."
    assert emitted[0].content.parts[0].thought
    assert emitted[0].partial

  @pytest.mark.asyncio
  async def test_process_response_multiple(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    await self._drain(
        aggregator.process_response(self._response(types.Part(text="Hello ")))
    )
    emitted = await self._drain(
        aggregator.process_response(self._response(types.Part(text="World!")))
    )
    assert len(emitted) == 1
    assert emitted[0].content.parts[0].text == "World!"

    final = await self._drain(aggregator.close())
    assert len(final) == 1
    assert final[0].content.parts[0].text == "Hello World!"

  @pytest.mark.asyncio
  async def test_process_response_interleaved_thought_and_text(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    chunks = [
        self._response(types.Part(text="I am thinking...", thought=True)),
        self._response(types.Part(text="Okay, I have a result.")),
        self._response(types.Part(text=" The result is 42.")),
    ]
    for chunk in chunks:
      await self._drain(aggregator.process_response(chunk))

    final = await self._drain(aggregator.close())
    assert len(final) == 1
    merged = final[0]
    assert len(merged.content.parts) == 2
    assert merged.content.parts[0].text == "I am thinking..."
    assert merged.content.parts[0].thought
    assert (
        merged.content.parts[1].text
        == "Okay, I have a result. The result is 42."
    )
    assert not merged.content.parts[1].thought

  @pytest.mark.asyncio
  async def test_close_with_no_responses(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    assert not await self._drain(aggregator.close())

  @pytest.mark.asyncio
  async def test_close_with_finish_reason(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    await self._drain(
        aggregator.process_response(
            self._response(
                types.Part(text="Hello"),
                finish_reason=types.FinishReason.STOP,
            )
        )
    )
    final = await self._drain(aggregator.close())
    assert len(final) == 1
    assert final[0].content.parts[0].text == "Hello"
    assert final[0].error_code is None
    assert final[0].error_message is None

  @pytest.mark.asyncio
  async def test_close_with_error(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    await self._drain(
        aggregator.process_response(
            self._response(
                types.Part(text="Error"),
                finish_reason=types.FinishReason.RECITATION,
                finish_message="Recitation error",
            )
        )
    )
    final = await self._drain(aggregator.close())
    assert len(final) == 1
    assert final[0].content.parts[0].text == "Error"
    assert final[0].error_code == types.FinishReason.RECITATION
    assert final[0].error_message == "Recitation error"

0 commit comments

Comments
 (0)