Skip to content

Commit e1e71ff

Browse files
google-genai-bot authored and
copybara-github committed
refactor: Extract a utility for aggregating partial streaming responses and emitting LlmResponses for them as needed
PiperOrigin-RevId: 796951203
1 parent 5b999ed commit e1e71ff

File tree

3 files changed

+314
-54
lines changed

3 files changed

+314
-54
lines changed

src/google/adk/models/google_llm.py

Lines changed: 8 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
from .. import version
3535
from ..utils.context_utils import Aclosing
36+
from ..utils.streaming_utils import StreamingResponseAggregator
3637
from ..utils.variant_utils import GoogleLLMVariant
3738
from .base_llm import BaseLlm
3839
from .base_llm_connection import BaseLlmConnection
@@ -133,68 +134,21 @@ async def generate_content_async(
133134
contents=llm_request.contents,
134135
config=llm_request.config,
135136
)
136-
response = None
137-
thought_text = ''
138-
text = ''
139-
usage_metadata = None
137+
140138
# for sse, similar to bidi (see receive method in gemini_llm_connection.py),
141139
# we need to mark those text content as partial and after all partial
142140
# contents are sent, we send an accumulated event which contains all the
143141
# previous partial content. The only difference is bidi rely on
144142
# complete_turn flag to detect end while sse depends on finish_reason.
143+
aggregator = StreamingResponseAggregator()
145144
async with Aclosing(responses) as agen:
146145
async for response in agen:
147146
logger.debug(_build_response_log(response))
148-
llm_response = LlmResponse.create(response)
149-
usage_metadata = llm_response.usage_metadata
150-
if (
151-
llm_response.content
152-
and llm_response.content.parts
153-
and llm_response.content.parts[0].text
154-
):
155-
part0 = llm_response.content.parts[0]
156-
if part0.thought:
157-
thought_text += part0.text
158-
else:
159-
text += part0.text
160-
llm_response.partial = True
161-
elif (thought_text or text) and (
162-
not llm_response.content
163-
or not llm_response.content.parts
164-
# don't yield the merged text event when receiving audio data
165-
or not llm_response.content.parts[0].inline_data
166-
):
167-
parts = []
168-
if thought_text:
169-
parts.append(types.Part(text=thought_text, thought=True))
170-
if text:
171-
parts.append(types.Part.from_text(text=text))
172-
yield LlmResponse(
173-
content=types.ModelContent(parts=parts),
174-
usage_metadata=llm_response.usage_metadata,
175-
)
176-
thought_text = ''
177-
text = ''
178-
yield llm_response
179-
180-
# generate an aggregated content at the end regardless the
181-
# response.candidates[0].finish_reason
182-
if (text or thought_text) and response and response.candidates:
183-
parts = []
184-
if thought_text:
185-
parts.append(types.Part(text=thought_text, thought=True))
186-
if text:
187-
parts.append(types.Part.from_text(text=text))
188-
yield LlmResponse(
189-
content=types.ModelContent(parts=parts),
190-
error_code=None
191-
if response.candidates[0].finish_reason == FinishReason.STOP
192-
else response.candidates[0].finish_reason,
193-
error_message=None
194-
if response.candidates[0].finish_reason == FinishReason.STOP
195-
else response.candidates[0].finish_message,
196-
usage_metadata=usage_metadata,
197-
)
147+
async for llm_response in aggregator.process_response(response):
148+
yield llm_response
149+
150+
async for llm_response in aggregator.close():
151+
yield llm_response
198152

199153
else:
200154
response = await self.api_client.aio.models.generate_content(
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from typing import AsyncGenerator
18+
19+
from google.adk.models import llm_response as llm_response_lib
20+
from google.genai import types
21+
22+
23+
class StreamingResponseAggregator:
  """Aggregates partial streaming responses.

  It aggregates content from partial responses, and generates LlmResponses for
  individual (partial) model responses, as well as for aggregated content.
  """

  def __init__(self):
    # Accumulated regular (non-thought) text from partial responses.
    self._text = ''
    # Accumulated thought text from partial responses.
    self._thought_text = ''
    # Usage metadata from the most recently processed response; attached to
    # the final aggregated response in close().
    self._usage_metadata = None
    # The most recently processed raw response; close() reads its first
    # candidate's finish_reason / finish_message.
    self._response = None

  async def process_response(
      self, response: types.GenerateContentResponse
  ) -> AsyncGenerator[llm_response_lib.LlmResponse, None]:
    """Processes a single model response.

    Args:
      response: The response to process.

    Yields:
      The generated LlmResponse(s), for the partial response, and the
      aggregated response if needed.
    """
    self._response = response
    llm_response = llm_response_lib.LlmResponse.create(response)
    self._usage_metadata = llm_response.usage_metadata
    if (
        llm_response.content
        and llm_response.content.parts
        and llm_response.content.parts[0].text
    ):
      # Text chunk: accumulate it (into the thought or regular buffer) and
      # mark the response partial so consumers know a merged event follows.
      part0 = llm_response.content.parts[0]
      if part0.thought:
        self._thought_text += part0.text
      else:
        self._text += part0.text
      llm_response.partial = True
    elif (self._thought_text or self._text) and (
        not llm_response.content
        or not llm_response.content.parts
        # don't yield the merged text event when receiving audio data
        or not llm_response.content.parts[0].inline_data
    ):
      # A non-text response arrived while text is buffered: flush the
      # buffered text as a merged event before yielding this response.
      parts = []
      if self._thought_text:
        parts.append(types.Part(text=self._thought_text, thought=True))
      if self._text:
        parts.append(types.Part.from_text(text=self._text))
      yield llm_response_lib.LlmResponse(
          content=types.ModelContent(parts=parts),
          usage_metadata=llm_response.usage_metadata,
      )
      self._thought_text = ''
      self._text = ''
    yield llm_response

  async def close(self) -> AsyncGenerator[llm_response_lib.LlmResponse, None]:
    """Generate an aggregated response at the end, if needed.

    This should be called after all the model responses are processed.

    Yields:
      The aggregated LlmResponse.
    """
    if (
        (self._text or self._thought_text)
        and self._response
        and self._response.candidates
    ):
      parts = []
      if self._thought_text:
        parts.append(types.Part(text=self._thought_text, thought=True))
      if self._text:
        parts.append(types.Part.from_text(text=self._text))
      candidate = self._response.candidates[0]
      # STOP is the normal termination; any other finish reason is surfaced
      # as an error on the aggregated response.
      yield llm_response_lib.LlmResponse(
          content=types.ModelContent(parts=parts),
          error_code=None
          if candidate.finish_reason == types.FinishReason.STOP
          else candidate.finish_reason,
          error_message=None
          if candidate.finish_reason == types.FinishReason.STOP
          else candidate.finish_message,
          usage_metadata=self._usage_metadata,
      )
Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
import unittest
18+
19+
from google.adk.utils import streaming_utils
20+
from google.genai import types
21+
22+
23+
class TestStreamingResponseAggregator(unittest.IsolatedAsyncioTestCase):
  """Tests for streaming_utils.StreamingResponseAggregator."""

  @staticmethod
  def _make_response(*parts, finish_reason=None, finish_message=None):
    # Builds a GenerateContentResponse with one candidate holding the parts.
    return types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(parts=list(parts)),
                finish_reason=finish_reason,
                finish_message=finish_message,
            )
        ]
    )

  @staticmethod
  async def _collect(agen):
    # Drains an async generator into a list.
    return [item async for item in agen]

  async def test_process_response_with_text(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    outputs = await self._collect(
        aggregator.process_response(
            self._make_response(types.Part(text="Hello"))
        )
    )
    self.assertEqual(len(outputs), 1)
    self.assertEqual(outputs[0].content.parts[0].text, "Hello")
    self.assertTrue(outputs[0].partial)

  async def test_process_response_with_thought(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    outputs = await self._collect(
        aggregator.process_response(
            self._make_response(types.Part(text="Thinking...", thought=True))
        )
    )
    self.assertEqual(len(outputs), 1)
    self.assertEqual(outputs[0].content.parts[0].text, "Thinking...")
    self.assertTrue(outputs[0].content.parts[0].thought)
    self.assertTrue(outputs[0].partial)

  async def test_process_response_multiple(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    await self._collect(
        aggregator.process_response(
            self._make_response(types.Part(text="Hello "))
        )
    )
    outputs = await self._collect(
        aggregator.process_response(
            self._make_response(types.Part(text="World!"))
        )
    )
    self.assertEqual(len(outputs), 1)
    self.assertEqual(outputs[0].content.parts[0].text, "World!")

    final = await self._collect(aggregator.close())
    self.assertEqual(len(final), 1)
    self.assertEqual(final[0].content.parts[0].text, "Hello World!")

  async def test_process_response_interleaved_thought_and_text(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    chunks = [
        self._make_response(
            types.Part(text="I am thinking...", thought=True)
        ),
        self._make_response(types.Part(text="Okay, I have a result.")),
        self._make_response(types.Part(text=" The result is 42.")),
    ]
    for chunk in chunks:
      await self._collect(aggregator.process_response(chunk))

    final = await self._collect(aggregator.close())
    self.assertEqual(len(final), 1)
    merged = final[0]
    self.assertEqual(len(merged.content.parts), 2)
    self.assertEqual(merged.content.parts[0].text, "I am thinking...")
    self.assertTrue(merged.content.parts[0].thought)
    self.assertEqual(
        merged.content.parts[1].text,
        "Okay, I have a result. The result is 42.",
    )
    self.assertFalse(merged.content.parts[1].thought)

  async def test_close_with_no_responses(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    final = await self._collect(aggregator.close())
    self.assertEqual(len(final), 0)

  async def test_close_with_finish_reason(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    await self._collect(
        aggregator.process_response(
            self._make_response(
                types.Part(text="Hello"),
                finish_reason=types.FinishReason.STOP,
            )
        )
    )
    final = await self._collect(aggregator.close())
    self.assertEqual(len(final), 1)
    self.assertEqual(final[0].content.parts[0].text, "Hello")
    self.assertIsNone(final[0].error_code)
    self.assertIsNone(final[0].error_message)

  async def test_close_with_error(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    await self._collect(
        aggregator.process_response(
            self._make_response(
                types.Part(text="Error"),
                finish_reason=types.FinishReason.RECITATION,
                finish_message="Recitation error",
            )
        )
    )
    final = await self._collect(aggregator.close())
    self.assertEqual(len(final), 1)
    self.assertEqual(final[0].content.parts[0].text, "Error")
    self.assertEqual(final[0].error_code, types.FinishReason.RECITATION)
    self.assertEqual(final[0].error_message, "Recitation error")


if __name__ == "__main__":
  unittest.main()

0 commit comments

Comments
 (0)