Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
from ..evaluation.eval_sets_manager import EvalSetsManager
from ..events.event import Event
from ..memory.base_memory_service import BaseMemoryService
from ..plugins.base_plugin import BasePlugin
from ..runners import Runner
from ..sessions.base_session_service import BaseSessionService
from ..sessions.session import Session
Expand Down Expand Up @@ -250,6 +251,7 @@ def __init__(
eval_sets_manager: EvalSetsManager,
eval_set_results_manager: EvalSetResultsManager,
agents_dir: str,
plugins: Optional[list[BasePlugin]] = None,
):
self.agent_loader = agent_loader
self.session_service = session_service
Expand All @@ -259,6 +261,7 @@ def __init__(
self.eval_sets_manager = eval_sets_manager
self.eval_set_results_manager = eval_set_results_manager
self.agents_dir = agents_dir
self.plugins = plugins or []
# Internal propeties we want to allow being modified from callbacks.
self.runners_to_clean: set[str] = set()
self.current_app_name_ref: SharedValue[str] = SharedValue(value="")
Expand All @@ -282,6 +285,7 @@ async def get_runner_async(self, app_name: str) -> Runner:
session_service=self.session_service,
memory_service=self.memory_service,
credential_service=self.credential_service,
plugins=self.plugins,
)
self.runner_dict[app_name] = runner
return runner
Expand Down
14 changes: 14 additions & 0 deletions src/google/adk/cli/cli_tools_click.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import logging
import os
import tempfile
from typing import Any
from typing import Callable
from typing import Optional

import click
Expand Down Expand Up @@ -804,6 +806,16 @@ async def _lifespan(app: FastAPI):
)
@fast_api_common_options()
@adk_services_options()
# The --plugin option is currently ONLY for api_server
@click.option(
"--plugin",
"plugins",
multiple=True,
help=(
"Optional. The fully qualified path to a BasePlugin class to register, "
"e.g., 'my_agent.my_plugin.MyPlugin'. Can be specified multiple times."
),
)
@deprecated_adk_services_options()
def cli_api_server(
agents_dir: str,
Expand All @@ -821,6 +833,7 @@ def cli_api_server(
artifact_storage_uri: Optional[str] = None, # Deprecated
a2a: bool = False,
reload_agents: bool = False,
plugins: Optional[tuple[str]] = None,
):
"""Starts a FastAPI server for agents.

Expand Down Expand Up @@ -849,6 +862,7 @@ def cli_api_server(
host=host,
port=port,
reload_agents=reload_agents,
plugins=plugins,
),
host=host,
port=port,
Expand Down
25 changes: 25 additions & 0 deletions src/google/adk/cli/fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

from __future__ import annotations

import importlib
import json
import logging
import os
from pathlib import Path
import shutil
from typing import Any
from typing import List
from typing import Mapping
from typing import Optional

Expand All @@ -40,6 +42,7 @@
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
from ..plugins.base_plugin import BasePlugin
from ..runners import Runner
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.vertex_ai_session_service import VertexAiSessionService
Expand All @@ -53,6 +56,19 @@
logger = logging.getLogger("google_adk." + __name__)


def _load_plugin_class(path: str) -> type[BasePlugin]:
"""Dynamically imports and returns a class from a string path."""
try:
module_path, class_name = path.rsplit(".", 1)
module = importlib.import_module(module_path)
plugin_class = getattr(module, class_name)
if not issubclass(plugin_class, BasePlugin):
raise TypeError(f"Class at '{path}' is not a subclass of BasePlugin.")
return plugin_class
except (ImportError, AttributeError, ValueError, TypeError) as e:
raise click.ClickException(f"Failed to load plugin '{path}': {e}") from e


def get_fast_api_app(
*,
agents_dir: str,
Expand All @@ -69,6 +85,7 @@ def get_fast_api_app(
trace_to_cloud: bool = False,
reload_agents: bool = False,
lifespan: Optional[Lifespan[FastAPI]] = None,
plugins: Optional[tuple[str]] = None,
) -> FastAPI:
# Set up eval managers.
if eval_storage_uri:
Expand Down Expand Up @@ -177,6 +194,13 @@ def _parse_agent_engine_resource_name(agent_engine_id_or_resource_name):
# initialize Agent Loader
agent_loader = AgentLoader(agents_dir)

# Instantiate the plugins from their string paths
plugin_instances: List[BasePlugin] = []
if plugins:
for plugin_path in plugins:
plugin_class = _load_plugin_class(plugin_path)
plugin_instances.append(plugin_class(name=plugin_class.__name__))

adk_web_server = AdkWebServer(
agent_loader=agent_loader,
session_service=session_service,
Expand All @@ -186,6 +210,7 @@ def _parse_agent_engine_resource_name(agent_engine_id_or_resource_name):
eval_sets_manager=eval_sets_manager,
eval_set_results_manager=eval_set_results_manager,
agents_dir=agents_dir,
plugins=plugin_instances,
)

# Callbacks & other optional args for when constructing the FastAPI instance
Expand Down
Loading