Skip to content

Commit 49cc800

Browse files
committed
fix(fastapi): pass state_delta to runner in /run endpoint
1 parent 831e2e6 commit 49cc800

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

src/google/adk/cli/adk_web_server.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -993,6 +993,7 @@ async def run_agent(req: RunAgentRequest) -> list[Event]:
993993
user_id=req.user_id,
994994
session_id=req.session_id,
995995
new_message=req.new_message,
996+
state_delta=req.state_delta,
996997
)
997998
) as agen:
998999
events = [event async for event in agen]

tests/unittests/cli/test_fast_api.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from google.adk.evaluation.eval_set import EvalSet
3535
from google.adk.evaluation.in_memory_eval_sets_manager import InMemoryEvalSetsManager
3636
from google.adk.events.event import Event
37+
from google.adk.events.event_actions import EventActions
3738
from google.adk.runners import Runner
3839
from google.adk.sessions.base_session_service import ListSessionsResponse
3940
from google.genai import types
@@ -94,6 +95,14 @@ def _event_3():
9495
)
9596

9697

98+
def _event_state_delta(state_delta: dict[str, Any]):
99+
return Event(
100+
author="dummy agent",
101+
invocation_id="invocation_id",
102+
actions=EventActions(state_delta=state_delta),
103+
)
104+
105+
97106
# Define mocked async generator functions for the Runner
98107
async def dummy_run_live(self, session, live_request_queue):
99108
yield _event_1()
@@ -110,6 +119,7 @@ async def dummy_run_async(
110119
user_id,
111120
session_id,
112121
new_message,
122+
state_delta=None,
113123
run_config: RunConfig = RunConfig(),
114124
):
115125
yield _event_1()
@@ -119,6 +129,10 @@ async def dummy_run_async(
119129
await asyncio.sleep(0)
120130

121131
yield _event_3()
132+
await asyncio.sleep(0)
133+
134+
if state_delta is not None:
135+
yield _event_state_delta(state_delta)
122136

123137

124138
# Define a local mock for EvalCaseResult specific to fast_api tests
@@ -743,6 +757,27 @@ def test_agent_run(test_app, create_test_session):
743757

744758
logger.info("Agent run test completed successfully")
745759

760+
def test_agent_run_passes_state_delta(test_app, create_test_session):
761+
"""Test /run forwards state_delta and surfaces it in events."""
762+
info = create_test_session
763+
payload = {
764+
"app_name": info["app_name"],
765+
"user_id": info["user_id"],
766+
"session_id": info["session_id"],
767+
"new_message": {"role": "user", "parts": [{"text": "Hello"}]},
768+
"streaming": False,
769+
"state_delta": {"k": "v", "count": 1},
770+
}
771+
772+
# Verify the response
773+
response = test_app.post("/run", json=payload)
774+
assert response.status_code == 200
775+
data = response.json()
776+
assert isinstance(data, list)
777+
assert len(data) == 4
778+
779+
# Verify we got the expected event
780+
assert data[3]["actions"]["stateDelta"] == payload["state_delta"]
746781

747782
def test_list_artifact_names(test_app, create_test_session):
748783
"""Test listing artifact names for a session."""

0 commit comments

Comments
 (0)