Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/google/adk/cli/cli_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,7 @@ async def run_evals(
print(f"Running Eval: {eval_set_id}:{eval_name}")
session_id = f"{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}"

inference_result = (
inference_result, session_id = (
await EvaluationGenerator._generate_inferences_from_root_agent(
invocations=eval_case.conversation,
root_agent=root_agent,
Expand Down
25 changes: 19 additions & 6 deletions src/google/adk/evaluation/evaluation_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..runners import Runner
from ..sessions.base_session_service import BaseSessionService
try:
from ..sessions.vertex_ai_session_service import VertexAiSessionService
except ImportError:
VertexAiSessionService = None
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from ..utils.context_utils import Aclosing
Expand Down Expand Up @@ -132,9 +136,10 @@ async def _process_query(
agent_to_evaluate = root_agent.find_agent(agent_name)
assert agent_to_evaluate, f"Sub-Agent `{agent_name}` not found."

return await EvaluationGenerator._generate_inferences_from_root_agent(
response_invocations, _ = await EvaluationGenerator._generate_inferences_from_root_agent(
invocations, agent_to_evaluate, reset_func, initial_session
)
return response_invocations

@staticmethod
async def _generate_inferences_from_root_agent(
Expand All @@ -146,7 +151,7 @@ async def _generate_inferences_from_root_agent(
session_service: Optional[BaseSessionService] = None,
artifact_service: Optional[BaseArtifactService] = None,
memory_service: Optional[BaseMemoryService] = None,
) -> list[Invocation]:
) -> tuple[list[Invocation], str]:
"""Scrapes the root agent given the list of Invocations."""
if not session_service:
session_service = InMemorySessionService()
Expand All @@ -158,14 +163,22 @@ async def _generate_inferences_from_root_agent(
initial_session.app_name if initial_session else "EvaluationGenerator"
)
user_id = initial_session.user_id if initial_session else "test_user_id"
session_id = session_id if session_id else str(uuid.uuid4())

_ = await session_service.create_session(
if VertexAiSessionService and isinstance(session_service, VertexAiSessionService):
vertex_session = await session_service.create_session(
app_name=app_name,
user_id=user_id,
state=initial_session.state if initial_session else {}
)
session_id = vertex_session.id
else:
session_id = session_id if session_id else str(uuid.uuid4())
_ = await session_service.create_session(
app_name=app_name,
user_id=user_id,
state=initial_session.state if initial_session else {},
session_id=session_id,
)
)

if not artifact_service:
artifact_service = InMemoryArtifactService()
Expand Down Expand Up @@ -219,7 +232,7 @@ async def _generate_inferences_from_root_agent(
)
)

return response_invocations
return response_invocations, session_id

@staticmethod
def _process_query_with_session(session_data, data):
Expand Down
3 changes: 2 additions & 1 deletion src/google/adk/evaluation/local_eval_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ async def _perform_inference_sigle_eval_item(
)

try:
inferences = (
inferences, session_id = (
await EvaluationGenerator._generate_inferences_from_root_agent(
invocations=eval_case.conversation,
root_agent=root_agent,
Expand All @@ -371,6 +371,7 @@ async def _perform_inference_sigle_eval_item(
)

inference_result.inferences = inferences
inference_result.session_id = session_id # Relevant for Vertex AI Session Service and other services that use ad-hoc session id.
inference_result.status = InferenceStatus.SUCCESS

return inference_result
Expand Down