Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,7 @@ async def run_agent(req: RunAgentRequest) -> list[Event]:
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
state_delta=req.state_delta,
)
) as agen:
events = [event async for event in agen]
Expand Down
37 changes: 37 additions & 0 deletions tests/unittests/cli/test_fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from google.adk.evaluation.eval_set import EvalSet
from google.adk.evaluation.in_memory_eval_sets_manager import InMemoryEvalSetsManager
from google.adk.events.event import Event
from google.adk.events.event_actions import EventActions
from google.adk.runners import Runner
from google.adk.sessions.base_session_service import ListSessionsResponse
from google.genai import types
Expand Down Expand Up @@ -94,6 +95,14 @@ def _event_3():
)


def _event_state_delta(state_delta: dict[str, Any]):
return Event(
author="dummy agent",
invocation_id="invocation_id",
actions=EventActions(state_delta=state_delta),
)


# Define mocked async generator functions for the Runner
async def dummy_run_live(self, session, live_request_queue):
yield _event_1()
Expand All @@ -110,6 +119,7 @@ async def dummy_run_async(
user_id,
session_id,
new_message,
state_delta=None,
run_config: RunConfig = RunConfig(),
):
yield _event_1()
Expand All @@ -119,6 +129,10 @@ async def dummy_run_async(
await asyncio.sleep(0)

yield _event_3()
await asyncio.sleep(0)

if state_delta is not None:
yield _event_state_delta(state_delta)


# Define a local mock for EvalCaseResult specific to fast_api tests
Expand Down Expand Up @@ -744,6 +758,29 @@ def test_agent_run(test_app, create_test_session):
logger.info("Agent run test completed successfully")


def test_agent_run_passes_state_delta(test_app, create_test_session):
"""Test /run forwards state_delta and surfaces it in events."""
info = create_test_session
payload = {
"app_name": info["app_name"],
"user_id": info["user_id"],
"session_id": info["session_id"],
"new_message": {"role": "user", "parts": [{"text": "Hello"}]},
"streaming": False,
"state_delta": {"k": "v", "count": 1},
}

# Verify the response
response = test_app.post("/run", json=payload)
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
assert len(data) == 4

# Verify we got the expected event
assert data[3]["actions"]["stateDelta"] == payload["state_delta"]


def test_list_artifact_names(test_app, create_test_session):
"""Test listing artifact names for a session."""
info = create_test_session
Expand Down