
Commit 5806ce1

Merge branch 'google:main' into oauth-audience-prompt
2 parents 9a38de4 + 2dd432c commit 5806ce1

9 files changed: +184 −30 lines


src/google/adk/cli/adk_web_server.py

Lines changed: 101 additions & 12 deletions
@@ -173,7 +173,18 @@ class AddSessionToEvalSetRequest(common.BaseModel):
 
 
 class RunEvalRequest(common.BaseModel):
-  eval_ids: list[str]  # if empty, then all evals in the eval set are run.
+  eval_ids: list[str] = Field(
+      deprecated=True,
+      default_factory=list,
+      description="This field is deprecated, use eval_case_ids instead.",
+  )
+  eval_case_ids: list[str] = Field(
+      default_factory=list,
+      description=(
+          "List of eval case ids to evaluate. if empty, then all eval cases in"
+          " the eval set are run."
+      ),
+  )
   eval_metrics: list[EvalMetric]
 
 
@@ -195,6 +206,10 @@ class RunEvalResult(common.BaseModel):
   session_id: str
 
 
+class RunEvalResponse(common.BaseModel):
+  run_eval_results: list[RunEvalResult]
+
+
 class GetEventGraphResult(common.BaseModel):
   dot_src: str
 
@@ -207,6 +222,22 @@ class ListEvalSetsResponse(common.BaseModel):
   eval_set_ids: list[str]
 
 
+class EvalResult(EvalSetResult):
+  """This class has no field intentionally.
+
+  The goal here is to just give a new name to the class to align with the API
+  endpoint.
+  """
+
+
+class ListEvalResultsResponse(common.BaseModel):
+  eval_result_ids: list[str]
+
+
+class ListMetricsInfoResponse(common.BaseModel):
+  metrics_info: list[MetricInfo]
+
+
 class AdkWebServer:
   """Helper class for setting up and running the ADK web server on FastAPI.
 
@@ -690,14 +721,30 @@ async def delete_eval(
       except NotFoundError as nfe:
         raise HTTPException(status_code=404, detail=str(nfe)) from nfe
 
+    @deprecated(
+        "Please use run_eval instead. This will be removed in future releases."
+    )
     @app.post(
         "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
         response_model_exclude_none=True,
         tags=[TAG_EVALUATION],
     )
-    async def run_eval(
+    async def run_eval_legacy(
         app_name: str, eval_set_id: str, req: RunEvalRequest
     ) -> list[RunEvalResult]:
+      run_eval_response = await run_eval(
+          app_name=app_name, eval_set_id=eval_set_id, req=req
+      )
+      return run_eval_response.run_eval_results
+
+    @app.post(
+        "/apps/{app_name}/eval-sets/{eval_set_id}/run",
+        response_model_exclude_none=True,
+        tags=[TAG_EVALUATION],
+    )
+    async def run_eval(
+        app_name: str, eval_set_id: str, req: RunEvalRequest
+    ) -> RunEvalResponse:
       """Runs an eval given the details in the eval request."""
       # Create a mapping from eval set file to all the evals that needed to be
       # run.
@@ -727,7 +774,7 @@ async def run_eval(
       inference_request = InferenceRequest(
           app_name=app_name,
           eval_set_id=eval_set.eval_set_id,
-          eval_case_ids=req.eval_ids,
+          eval_case_ids=req.eval_case_ids or req.eval_ids,
          inference_config=InferenceConfig(),
      )
      inference_results = await _collect_inferences(
@@ -760,18 +807,41 @@ async def run_eval(
            )
        )
 
-      return run_eval_results
+      return RunEvalResponse(run_eval_results=run_eval_results)
 
     @app.get(
-        "/apps/{app_name}/eval_results/{eval_result_id}",
+        "/apps/{app_name}/eval-results/{eval_result_id}",
         response_model_exclude_none=True,
         tags=[TAG_EVALUATION],
     )
     async def get_eval_result(
         app_name: str,
         eval_result_id: str,
-    ) -> EvalSetResult:
+    ) -> EvalResult:
       """Gets the eval result for the given eval id."""
+      try:
+        eval_set_result = self.eval_set_results_manager.get_eval_set_result(
+            app_name, eval_result_id
+        )
+        return EvalResult(**eval_set_result.model_dump())
+      except ValueError as ve:
+        raise HTTPException(status_code=404, detail=str(ve)) from ve
+      except ValidationError as ve:
+        raise HTTPException(status_code=500, detail=str(ve)) from ve
+
+    @deprecated(
+        "Please use get_eval_result instead. This will be removed in future"
+        " releases."
+    )
+    @app.get(
+        "/apps/{app_name}/eval_results/{eval_result_id}",
+        response_model_exclude_none=True,
+        tags=[TAG_EVALUATION],
+    )
+    async def get_eval_result_legacy(
+        app_name: str,
+        eval_result_id: str,
+    ) -> EvalSetResult:
       try:
         return self.eval_set_results_manager.get_eval_set_result(
             app_name, eval_result_id
@@ -782,27 +852,46 @@ async def get_eval_result(
         raise HTTPException(status_code=500, detail=str(ve)) from ve
 
     @app.get(
-        "/apps/{app_name}/eval_results",
+        "/apps/{app_name}/eval-results",
         response_model_exclude_none=True,
         tags=[TAG_EVALUATION],
     )
-    async def list_eval_results(app_name: str) -> list[str]:
+    async def list_eval_results(app_name: str) -> ListEvalResultsResponse:
       """Lists all eval results for the given app."""
-      return self.eval_set_results_manager.list_eval_set_results(app_name)
+      eval_result_ids = self.eval_set_results_manager.list_eval_set_results(
+          app_name
+      )
+      return ListEvalResultsResponse(eval_result_ids=eval_result_ids)
+
+    @deprecated(
+        "Please use list_eval_results instead. This will be removed in future"
+        " releases."
+    )
+    @app.get(
+        "/apps/{app_name}/eval_results",
+        response_model_exclude_none=True,
+        tags=[TAG_EVALUATION],
+    )
+    async def list_eval_results_legacy(app_name: str) -> list[str]:
+      list_eval_results_response = await list_eval_results(app_name)
+      return list_eval_results_response.eval_result_ids
 
     @app.get(
-        "/apps/{app_name}/eval_metrics",
+        "/apps/{app_name}/metrics-info",
         response_model_exclude_none=True,
         tags=[TAG_EVALUATION],
     )
-    async def list_eval_metrics(app_name: str) -> list[MetricInfo]:
+    async def list_metrics_info(app_name: str) -> ListMetricsInfoResponse:
       """Lists all eval metrics for the given app."""
       try:
         from ..evaluation.metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
 
         # Right now we ignore the app_name as eval metrics are not tied to the
         # app_name, but they could be moving forward.
-        return DEFAULT_METRIC_EVALUATOR_REGISTRY.get_registered_metrics()
+        metrics_info = (
+            DEFAULT_METRIC_EVALUATOR_REGISTRY.get_registered_metrics()
+        )
+        return ListMetricsInfoResponse(metrics_info=metrics_info)
       except ModuleNotFoundError as e:
         logger.exception("%s\n%s", MISSING_EVAL_DEPENDENCIES_MESSAGE, e)
         raise HTTPException(

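For reference, the renamed eval endpoints and the new response envelopes change what HTTP clients send and receive. Below is a minimal client sketch, not part of the commit: the base URL, app name, eval set id, and the use of httpx are assumptions for illustration; the paths and field names mirror the routes and models added above.

import httpx

BASE_URL = "http://localhost:8000"  # assumption: a locally running ADK web server
APP_NAME = "my_app"                 # assumption: an app that already has eval sets
EVAL_SET_ID = "my_eval_set"         # assumption

with httpx.Client(base_url=BASE_URL) as client:
  # New kebab-case route; the result list is now wrapped in RunEvalResponse.
  run_resp = client.post(
      f"/apps/{APP_NAME}/eval-sets/{EVAL_SET_ID}/run",
      json={"eval_case_ids": [], "eval_metrics": []},
  )
  run_resp.raise_for_status()
  for result in run_resp.json()["run_eval_results"]:
    print(result["session_id"])

  # Listing eval results now returns an object instead of a bare list of ids.
  ids = client.get(f"/apps/{APP_NAME}/eval-results").json()["eval_result_ids"]
  if ids:
    detail = client.get(f"/apps/{APP_NAME}/eval-results/{ids[0]}").json()

  # Metric info moved to /metrics-info and is wrapped in ListMetricsInfoResponse.
  metrics = client.get(f"/apps/{APP_NAME}/metrics-info").json()["metrics_info"]

The old snake_case routes remain registered but are marked deprecated, so existing clients keep working while they migrate.
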
src/google/adk/flows/llm_flows/audio_cache_manager.py

Lines changed: 0 additions & 1 deletion
@@ -141,7 +141,6 @@ async def _flush_cache_to_services(
     Returns:
       True if the cache was successfully flushed, False otherwise.
     """
-    print('flush cache')
     if not invocation_context.artifact_service or not audio_cache:
       logger.debug('Skipping cache flush: no artifact service or empty cache')
       return False

src/google/adk/models/gemini_llm_connection.py

Lines changed: 13 additions & 2 deletions
@@ -164,8 +164,14 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
             message.server_content.input_transcription
             and message.server_content.input_transcription.text
         ):
+          user_text = message.server_content.input_transcription.text
+          parts = [
+              types.Part.from_text(
+                  text=user_text,
+              )
+          ]
           llm_response = LlmResponse(
-              input_transcription=message.server_content.input_transcription,
+              content=types.Content(role='user', parts=parts)
           )
           yield llm_response
         if (
@@ -180,8 +186,13 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
           # We rely on other control signals to determine when to yield the
           # full text response(turn_complete, interrupted, or tool_call).
           text += message.server_content.output_transcription.text
+          parts = [
+              types.Part.from_text(
+                  text=message.server_content.output_transcription.text
+              )
+          ]
           llm_response = LlmResponse(
-              output_transcription=message.server_content.output_transcription
+              content=types.Content(role='model', parts=parts), partial=True
           )
           yield llm_response
 

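With this change, live input and output transcriptions are surfaced as ordinary Content on the yielded LlmResponse (role 'user' for input, role 'model' for partial output) instead of the dedicated transcription fields. A rough sketch of what a consumer of receive() might now do, assuming an open GeminiLlmConnection named connection; the helper function is illustrative, not part of the commit.

from typing import Optional


def transcription_text(llm_response) -> Optional[str]:
  """Extracts streamed transcription text from an LlmResponse, if present."""
  content = llm_response.content
  if not content or not content.parts:
    return None
  return ''.join(p.text for p in content.parts if p.text)


# Usage sketch inside an async consumer:
#   async for llm_response in connection.receive():
#     text = transcription_text(llm_response)
#     if text:
#       print(llm_response.content.role, text)  # 'user' or 'model'
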
src/google/adk/tools/agent_tool.py

Lines changed: 5 additions & 4 deletions
@@ -139,7 +139,7 @@ async def run_async(
         state=tool_context.state.to_dict(),
     )
 
-    last_event = None
+    last_content = None
     async with Aclosing(
         runner.run_async(
             user_id=session.user_id, session_id=session.id, new_message=content
@@ -149,11 +149,12 @@ async def run_async(
       # Forward state delta to parent session.
       if event.actions.state_delta:
         tool_context.state.update(event.actions.state_delta)
-      last_event = event
+      if event.content:
+        last_content = event.content
 
-    if not last_event or not last_event.content or not last_event.content.parts:
+    if not last_content:
       return ''
-    merged_text = '\n'.join(p.text for p in last_event.content.parts if p.text)
+    merged_text = '\n'.join(p.text for p in last_content.parts if p.text)
     if isinstance(self.agent, LlmAgent) and self.agent.output_schema:
       tool_result = self.agent.output_schema.model_validate_json(
           merged_text

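The refactor above keeps the last event that actually carried content, so control-only events (for example pure state-delta updates) no longer wipe out the tool's reply, and then joins the text parts of that content. A small standalone illustration of the merge step, using google.genai types directly; the part values are made up.

from google.genai import types

last_content = types.Content(
    role='model',
    parts=[
        types.Part.from_text(text='first line'),
        types.Part(function_call=types.FunctionCall(name='noop', args={})),  # no text
        types.Part.from_text(text='second line'),
    ],
)

merged_text = '\n'.join(p.text for p in last_content.parts if p.text)
print(merged_text)  # -> "first line\nsecond line"
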
src/google/adk/tools/base_toolset.py

Lines changed: 22 additions & 0 deletions
@@ -22,14 +22,17 @@
 from typing import Optional
 from typing import Protocol
 from typing import runtime_checkable
+from typing import Type
 from typing import TYPE_CHECKING
+from typing import TypeVar
 from typing import Union
 
 from ..agents.readonly_context import ReadonlyContext
 from .base_tool import BaseTool
 
 if TYPE_CHECKING:
   from ..models.llm_request import LlmRequest
+  from .tool_configs import ToolArgsConfig
   from .tool_context import ToolContext
 
 
@@ -53,6 +56,9 @@ def __call__(
     """
 
 
+SelfToolset = TypeVar("SelfToolset", bound="BaseToolset")
+
+
 class BaseToolset(ABC):
   """Base class for toolset.
 
@@ -152,6 +158,22 @@ async def close(self) -> None:
     resources are properly released to prevent leaks.
     """
 
+  @classmethod
+  def from_config(
+      cls: Type[SelfToolset], config: ToolArgsConfig, config_abs_path: str
+  ) -> SelfToolset:
+    """Creates a toolset instance from a config.
+
+    Args:
+      config: The config for the tool.
+      config_abs_path: The absolute path to the config file that contains the
+        tool config.
+
+    Returns:
+      The toolset instance.
+    """
+    raise ValueError(f"from_config() not implemented for toolset: {cls}")
+
   def _is_tool_selected(
       self, tool: BaseTool, readonly_context: ReadonlyContext
   ) -> bool:

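The new from_config hook gives toolsets a uniform, opt-in construction path from declarative tool configs; the base implementation raises, so only subclasses that override it can be built this way. A hedged sketch of such an override follows; the StaticToolset class and the way it reads arguments off the config are assumptions for illustration, not part of the commit.

from typing import Optional

from google.adk.agents.readonly_context import ReadonlyContext
from google.adk.tools.base_tool import BaseTool
from google.adk.tools.base_toolset import BaseToolset


class StaticToolset(BaseToolset):
  """Illustrative toolset holding a fixed list of tools."""

  def __init__(self, tools: Optional[list[BaseTool]] = None):
    super().__init__()
    self._tools = tools or []

  async def get_tools(
      self, readonly_context: Optional[ReadonlyContext] = None
  ) -> list[BaseTool]:
    return self._tools

  async def close(self) -> None:
    pass  # nothing to release for a static list

  @classmethod
  def from_config(cls, config, config_abs_path: str) -> 'StaticToolset':
    # Assumption: the config behaves like a pydantic model whose dumped fields
    # carry this toolset's constructor arguments.
    args = config.model_dump() if hasattr(config, 'model_dump') else {}
    del config_abs_path  # could be used to resolve paths relative to the config
    return cls(tools=args.get('tools', []))
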
src/google/adk/tools/mcp_tool/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -21,15 +21,19 @@
 from .mcp_session_manager import StdioConnectionParams
 from .mcp_session_manager import StreamableHTTPConnectionParams
 from .mcp_tool import MCPTool
+from .mcp_tool import McpTool
 from .mcp_toolset import MCPToolset
+from .mcp_toolset import McpToolset
 
 __all__.extend([
     'adk_to_mcp_tool_type',
     'gemini_to_json_schema',
+    'McpTool',
     'MCPTool',
+    'McpToolset',
     'MCPToolset',
-    'StdioConnectionParams',
     'SseConnectionParams',
+    'StdioConnectionParams',
     'StreamableHTTPConnectionParams',
 ])
 

src/google/adk/tools/mcp_tool/mcp_tool.py

Lines changed: 14 additions & 1 deletion
@@ -17,6 +17,7 @@
 import base64
 import logging
 from typing import Optional
+import warnings
 
 from fastapi.openapi.models import APIKeyIn
 from google.genai.types import FunctionDeclaration
@@ -52,7 +53,7 @@
 logger = logging.getLogger("google_adk." + __name__)
 
 
-class MCPTool(BaseAuthenticatedTool):
+class McpTool(BaseAuthenticatedTool):
   """Turns an MCP Tool into an ADK Tool.
 
   Internally, the tool initializes from a MCP Tool, and uses the MCP Session to
@@ -216,3 +217,15 @@ async def _get_headers(
       )
 
     return headers
+
+
+class MCPTool(McpTool):
+  """Deprecated name, use `McpTool` instead."""
+
+  def __init__(self, *args, **kwargs):
+    warnings.warn(
+        "MCPTool class is deprecated, use `McpTool` instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    super().__init__(*args, **kwargs)

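The MCPTool shim keeps the old spelling importable while warning at construction time, and because it subclasses McpTool, isinstance checks against either name still pass. Below is a standalone reproduction of the same rename-with-shim pattern, so the behavior can be exercised without wiring up a real MCP session; the widget classes are placeholders.

import warnings


class McpWidget:  # stand-in for the newly named class
  def __init__(self, name: str):
    self.name = name


class MCPWidget(McpWidget):  # stand-in for the deprecated alias
  """Deprecated name, use `McpWidget` instead."""

  def __init__(self, *args, **kwargs):
    warnings.warn(
        'MCPWidget class is deprecated, use `McpWidget` instead.',
        DeprecationWarning,
        stacklevel=2,
    )
    super().__init__(*args, **kwargs)


with warnings.catch_warnings(record=True) as caught:
  warnings.simplefilter('always')
  legacy = MCPWidget('search')

assert isinstance(legacy, McpWidget)  # old name still satisfies the new type
assert caught and caught[0].category is DeprecationWarning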