
Commit 7975e8e

google-genai-bot authored and copybara-github committed
refactor: Extract a utility for aggregating partial streaming responses and emitting LlmResponses for them as needed
PiperOrigin-RevId: 800521404
1 parent 3bc2d77 commit 7975e8e

File tree: 4 files changed (+305, -54 lines)

src/google/adk/models/gemini_llm_connection.py (2 additions, 0 deletions)

@@ -144,6 +144,8 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
 
     text = ''
     async with Aclosing(self._gemini_session.receive()) as agen:
+      # TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
+      # partial content and emit responses as needed.
       async for message in agen:
         logger.debug('Got LLM Live message: %s', message)
         if message.server_content:

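The TODO above points at reusing the new utility in the live (bidi) path, which still accumulates text by hand. A hedged sketch of one conceivable shape of that reuse, re-wrapping each live model turn as a GenerateContentResponse so the aggregator can consume it; everything below is hypothetical and not part of this commit:

# Hypothetical sketch only; the shipped receive() does not do this yet.
from google.adk.utils.streaming_utils import StreamingResponseAggregator
from google.genai import types


async def receive_via_aggregator(live_messages):
  """live_messages: any async iterator of types.LiveServerMessage."""
  aggregator = StreamingResponseAggregator()
  async for message in live_messages:
    content = message.server_content and message.server_content.model_turn
    if content:
      # Re-wrap the live turn so the SSE-oriented aggregator can consume it.
      response = types.GenerateContentResponse(
          candidates=[types.Candidate(content=content)]
      )
      async for llm_response in aggregator.process_response(response):
        yield llm_response
  # Flush any remaining accumulated text at the end of the turn.
  if (close_result := aggregator.close()) is not None:
    yield close_result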
src/google/adk/models/google_llm.py (10 additions, 54 deletions)

@@ -33,6 +33,7 @@
 
 from .. import version
 from ..utils.context_utils import Aclosing
+from ..utils.streaming_utils import StreamingResponseAggregator
 from ..utils.variant_utils import GoogleLLMVariant
 from .base_llm import BaseLlm
 from .base_llm_connection import BaseLlmConnection

@@ -133,68 +134,23 @@ async def generate_content_async(
           contents=llm_request.contents,
           config=llm_request.config,
       )
-      response = None
-      thought_text = ''
-      text = ''
-      usage_metadata = None
+
       # for sse, similar to bidi (see the receive method in
       # gemini_llm_connection.py), we need to mark text content as partial,
       # and after all partial contents are sent, we send an accumulated event
       # which contains all the previous partial content. The only difference
       # is that bidi relies on the turn_complete flag to detect the end,
       # while sse depends on finish_reason.
+      aggregator = StreamingResponseAggregator()
       async with Aclosing(responses) as agen:
         async for response in agen:
           logger.debug(_build_response_log(response))
-          llm_response = LlmResponse.create(response)
-          usage_metadata = llm_response.usage_metadata
-          if (
-              llm_response.content
-              and llm_response.content.parts
-              and llm_response.content.parts[0].text
-          ):
-            part0 = llm_response.content.parts[0]
-            if part0.thought:
-              thought_text += part0.text
-            else:
-              text += part0.text
-            llm_response.partial = True
-          elif (thought_text or text) and (
-              not llm_response.content
-              or not llm_response.content.parts
-              # don't yield the merged text event when receiving audio data
-              or not llm_response.content.parts[0].inline_data
-          ):
-            parts = []
-            if thought_text:
-              parts.append(types.Part(text=thought_text, thought=True))
-            if text:
-              parts.append(types.Part.from_text(text=text))
-            yield LlmResponse(
-                content=types.ModelContent(parts=parts),
-                usage_metadata=llm_response.usage_metadata,
-            )
-            thought_text = ''
-            text = ''
-          yield llm_response
-
-      # generate an aggregated content at the end regardless of
-      # response.candidates[0].finish_reason
-      if (text or thought_text) and response and response.candidates:
-        parts = []
-        if thought_text:
-          parts.append(types.Part(text=thought_text, thought=True))
-        if text:
-          parts.append(types.Part.from_text(text=text))
-        yield LlmResponse(
-            content=types.ModelContent(parts=parts),
-            error_code=None
-            if response.candidates[0].finish_reason == FinishReason.STOP
-            else response.candidates[0].finish_reason,
-            error_message=None
-            if response.candidates[0].finish_reason == FinishReason.STOP
-            else response.candidates[0].finish_message,
-            usage_metadata=usage_metadata,
-        )
+          async with Aclosing(
+              aggregator.process_response(response)
+          ) as aggregator_gen:
+            async for llm_response in aggregator_gen:
+              yield llm_response
+      if (close_result := aggregator.close()) is not None:
+        yield close_result
 
     else:
       response = await self.api_client.aio.models.generate_content(
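With the aggregator extracted, the streaming branch above reduces to one pattern: feed each streamed chunk through process_response(), forward whatever it yields, and flush with close() once the stream ends. A minimal sketch of that pattern as a standalone wrapper (the stream_with_aggregation name is illustrative, not part of the commit):

# Illustrative wrapper (not part of this commit): drives the new utility the
# same way the streaming branch of generate_content_async now does.
from google.adk.utils.streaming_utils import StreamingResponseAggregator


async def stream_with_aggregation(responses):
  """responses: any async iterator of types.GenerateContentResponse chunks."""
  aggregator = StreamingResponseAggregator()
  async for response in responses:
    # Each chunk may yield a partial LlmResponse, plus an aggregated one
    # when accumulated text is interrupted by non-text content.
    async for llm_response in aggregator.process_response(response):
      yield llm_response
  # Emit any remaining accumulated text as one final aggregated response.
  if (close_result := aggregator.close()) is not None:
    yield close_result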
src/google/adk/utils/streaming_utils.py (new file, 112 additions, 0 deletions)

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import AsyncGenerator
from typing import Optional

from google.genai import types

from ..models.llm_response import LlmResponse


class StreamingResponseAggregator:
  """Aggregates partial streaming responses.

  It aggregates content from partial responses, and generates LlmResponses
  for individual (partial) model responses, as well as for the aggregated
  content.
  """

  def __init__(self):
    self._text = ''
    self._thought_text = ''
    self._usage_metadata = None
    self._response = None

  async def process_response(
      self, response: types.GenerateContentResponse
  ) -> AsyncGenerator[LlmResponse, None]:
    """Processes a single model response.

    Args:
      response: The response to process.

    Yields:
      The generated LlmResponse(s): the partial response, plus an aggregated
      response when needed.
    """
    self._response = response
    llm_response = LlmResponse.create(response)
    self._usage_metadata = llm_response.usage_metadata
    if (
        llm_response.content
        and llm_response.content.parts
        and llm_response.content.parts[0].text
    ):
      part0 = llm_response.content.parts[0]
      if part0.thought:
        self._thought_text += part0.text
      else:
        self._text += part0.text
      llm_response.partial = True
    elif (self._thought_text or self._text) and (
        not llm_response.content
        or not llm_response.content.parts
        # don't yield the merged text event when receiving audio data
        or not llm_response.content.parts[0].inline_data
    ):
      parts = []
      if self._thought_text:
        parts.append(types.Part(text=self._thought_text, thought=True))
      if self._text:
        parts.append(types.Part.from_text(text=self._text))
      yield LlmResponse(
          content=types.ModelContent(parts=parts),
          usage_metadata=llm_response.usage_metadata,
      )
      self._thought_text = ''
      self._text = ''
    yield llm_response

  def close(self) -> Optional[LlmResponse]:
    """Generates an aggregated response at the end, if needed.

    This should be called after all the model responses are processed.

    Returns:
      The aggregated LlmResponse, or None if there is nothing to aggregate.
    """
    if (
        (self._text or self._thought_text)
        and self._response
        and self._response.candidates
    ):
      parts = []
      if self._thought_text:
        parts.append(types.Part(text=self._thought_text, thought=True))
      if self._text:
        parts.append(types.Part.from_text(text=self._text))
      candidate = self._response.candidates[0]
      return LlmResponse(
          content=types.ModelContent(parts=parts),
          error_code=None
          if candidate.finish_reason == types.FinishReason.STOP
          else candidate.finish_reason,
          error_message=None
          if candidate.finish_reason == types.FinishReason.STOP
          else candidate.finish_message,
          usage_metadata=self._usage_metadata,
      )
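For reference, a minimal self-contained usage sketch of the new class; the hand-built two-chunk stream and the _chunk helper below are illustrative only, not part of the shipped module:

import asyncio

from google.adk.utils.streaming_utils import StreamingResponseAggregator
from google.genai import types


def _chunk(text: str) -> types.GenerateContentResponse:
  # Hand-built stand-in for one streamed chunk from the model.
  return types.GenerateContentResponse(
      candidates=[
          types.Candidate(content=types.Content(parts=[types.Part(text=text)]))
      ]
  )


async def main():
  aggregator = StreamingResponseAggregator()
  for chunk in (_chunk('Hello '), _chunk('World!')):
    # Text chunks come back marked partial=True.
    async for llm_response in aggregator.process_response(chunk):
      print('partial:', llm_response.partial, llm_response.content.parts[0].text)
  # After the stream ends, close() returns one aggregated response.
  final = aggregator.close()
  if final is not None:
    print('aggregated:', final.content.parts[0].text)  # Hello World!


if __name__ == '__main__':
  asyncio.run(main())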
New test file for StreamingResponseAggregator (path not shown, 181 additions, 0 deletions)

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from google.adk.utils import streaming_utils
from google.genai import types
import pytest


class TestStreamingResponseAggregator:

  @pytest.mark.asyncio
  async def test_process_response_with_text(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    response = types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(parts=[types.Part(text="Hello")])
            )
        ]
    )
    results = []
    async for r in aggregator.process_response(response):
      results.append(r)
    assert len(results) == 1
    assert results[0].content.parts[0].text == "Hello"
    assert results[0].partial

  @pytest.mark.asyncio
  async def test_process_response_with_thought(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    response = types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(
                    parts=[types.Part(text="Thinking...", thought=True)]
                )
            )
        ]
    )
    results = []
    async for r in aggregator.process_response(response):
      results.append(r)
    assert len(results) == 1
    assert results[0].content.parts[0].text == "Thinking..."
    assert results[0].content.parts[0].thought
    assert results[0].partial

  @pytest.mark.asyncio
  async def test_process_response_multiple(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    response1 = types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(parts=[types.Part(text="Hello ")])
            )
        ]
    )
    response2 = types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(parts=[types.Part(text="World!")])
            )
        ]
    )
    async for _ in aggregator.process_response(response1):
      pass
    results = []
    async for r in aggregator.process_response(response2):
      results.append(r)
    assert len(results) == 1
    assert results[0].content.parts[0].text == "World!"

    closed_response = aggregator.close()
    assert closed_response is not None
    assert closed_response.content.parts[0].text == "Hello World!"

  @pytest.mark.asyncio
  async def test_process_response_interleaved_thought_and_text(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    response1 = types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(
                    parts=[types.Part(text="I am thinking...", thought=True)]
                )
            )
        ]
    )
    response2 = types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(
                    parts=[types.Part(text="Okay, I have a result.")]
                )
            )
        ]
    )
    response3 = types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(
                    parts=[types.Part(text=" The result is 42.")]
                )
            )
        ]
    )

    async for _ in aggregator.process_response(response1):
      pass
    async for _ in aggregator.process_response(response2):
      pass
    async for _ in aggregator.process_response(response3):
      pass

    closed_response = aggregator.close()
    assert closed_response is not None
    assert len(closed_response.content.parts) == 2
    assert closed_response.content.parts[0].text == "I am thinking..."
    assert closed_response.content.parts[0].thought
    assert (
        closed_response.content.parts[1].text
        == "Okay, I have a result. The result is 42."
    )
    assert not closed_response.content.parts[1].thought

  def test_close_with_no_responses(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    closed_response = aggregator.close()
    assert closed_response is None

  @pytest.mark.asyncio
  async def test_close_with_finish_reason(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    response = types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(parts=[types.Part(text="Hello")]),
                finish_reason=types.FinishReason.STOP,
            )
        ]
    )
    async for _ in aggregator.process_response(response):
      pass
    closed_response = aggregator.close()
    assert closed_response is not None
    assert closed_response.content.parts[0].text == "Hello"
    assert closed_response.error_code is None
    assert closed_response.error_message is None

  @pytest.mark.asyncio
  async def test_close_with_error(self):
    aggregator = streaming_utils.StreamingResponseAggregator()
    response = types.GenerateContentResponse(
        candidates=[
            types.Candidate(
                content=types.Content(parts=[types.Part(text="Error")]),
                finish_reason=types.FinishReason.RECITATION,
                finish_message="Recitation error",
            )
        ]
    )
    async for _ in aggregator.process_response(response):
      pass
    closed_response = aggregator.close()
    assert closed_response is not None
    assert closed_response.content.parts[0].text == "Error"
    assert closed_response.error_code == types.FinishReason.RECITATION
    assert closed_response.error_message == "Recitation error"
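Note: the async tests above use pytest.mark.asyncio, which comes from the pytest-asyncio plugin; running this file assumes that plugin is installed and enabled in the project's pytest configuration.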
