Skip to content

Commit 80bbf9d

Browse files
authored
Support for prompts and resources in streamablehttp client side connection (#795)
* added support for prompts and resources for client side connection in streamable http Signed-off-by: Keval Mahajan <[email protected]> * Minor change Signed-off-by: Keval Mahajan <[email protected]> * support for gateway connection from client without accessing virtual servers Signed-off-by: Keval Mahajan <[email protected]> * update read_resource Signed-off-by: Keval Mahajan <[email protected]> * fix read_resource Signed-off-by: Keval Mahajan <[email protected]> --------- Signed-off-by: Keval Mahajan <[email protected]>
1 parent 4870313 commit 80bbf9d

File tree

3 files changed

+177
-13
lines changed

3 files changed

+177
-13
lines changed

mcpgateway/transports/streamablehttp_transport.py

Lines changed: 164 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from mcp.server.streamable_http import EventCallback, EventId, EventMessage, EventStore, StreamId
4646
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
4747
from mcp.types import JSONRPCMessage
48+
from pydantic import AnyUrl
4849
from sqlalchemy.orm import Session
4950
from starlette.datastructures import Headers
5051
from starlette.responses import JSONResponse
@@ -55,16 +56,21 @@
5556
from mcpgateway.config import settings
5657
from mcpgateway.db import SessionLocal
5758
from mcpgateway.services.logging_service import LoggingService
59+
from mcpgateway.services.prompt_service import PromptService
60+
from mcpgateway.services.resource_service import ResourceService
5861
from mcpgateway.services.tool_service import ToolService
5962
from mcpgateway.utils.verify_credentials import verify_credentials
6063

6164
# Initialize logging service first
6265
logging_service = LoggingService()
6366
logger = logging_service.get_logger(__name__)
6467

65-
# Initialize ToolService and MCP Server
68+
# Initialize ToolService, PromptService and MCP Server
6669
tool_service: ToolService = ToolService()
67-
mcp_app: Server[Any] = Server("mcp-streamable-http-stateless")
70+
prompt_service: PromptService = PromptService()
71+
resource_service: ResourceService = ResourceService()
72+
73+
mcp_app: Server[Any] = Server("mcp-streamable-http")
6874

6975
server_id_var: contextvars.ContextVar[str] = contextvars.ContextVar("server_id", default="default_server_id")
7076
request_headers_var: contextvars.ContextVar[dict[str, Any]] = contextvars.ContextVar("request_headers", default={})
@@ -410,6 +416,160 @@ async def list_tools() -> List[types.Tool]:
410416
return []
411417

412418

419+
@mcp_app.list_prompts()
420+
async def list_prompts() -> List[types.Prompt]:
421+
"""
422+
Lists all prompts available to the MCP Server.
423+
424+
Returns:
425+
A list of Prompt objects containing metadata such as name, description, and arguments.
426+
Logs and returns an empty list on failure.
427+
428+
Examples:
429+
>>> import inspect
430+
>>> sig = inspect.signature(list_prompts)
431+
>>> list(sig.parameters.keys())
432+
[]
433+
>>> sig.return_annotation
434+
typing.List[mcp.types.Prompt]
435+
"""
436+
437+
server_id = server_id_var.get()
438+
439+
if server_id:
440+
try:
441+
async with get_db() as db:
442+
prompts = await prompt_service.list_server_prompts(db, server_id)
443+
return [types.Prompt(name=prompt.name, description=prompt.description, arguments=prompt.arguments) for prompt in prompts]
444+
except Exception as e:
445+
logger.exception(f"Error listing Prompts:{e}")
446+
return []
447+
else:
448+
try:
449+
async with get_db() as db:
450+
prompts = await prompt_service.list_prompts(db, False, None, None)
451+
return [types.Prompt(name=prompt.name, description=prompt.description, arguments=prompt.arguments) for prompt in prompts]
452+
except Exception as e:
453+
logger.exception(f"Error listing prompts:{e}")
454+
return []
455+
456+
457+
@mcp_app.get_prompt()
458+
async def get_prompt(name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
459+
"""
460+
Retrieves a prompt by name, optionally substituting arguments.
461+
462+
Args:
463+
name (str): The name of the prompt to retrieve.
464+
arguments (Optional[dict[str, str]]): Optional dictionary of arguments to substitute into the prompt.
465+
466+
Returns:
467+
GetPromptResult: Object containing the prompt messages and description.
468+
Returns an empty list on failure or if no prompt content is found.
469+
470+
Logs exceptions if any errors occur during retrieval.
471+
472+
Examples:
473+
>>> import inspect
474+
>>> sig = inspect.signature(get_prompt)
475+
>>> list(sig.parameters.keys())
476+
['name', 'arguments']
477+
>>> sig.return_annotation.__name__
478+
'GetPromptResult'
479+
"""
480+
try:
481+
async with get_db() as db:
482+
try:
483+
result = await prompt_service.get_prompt(db=db, name=name, arguments=arguments)
484+
except Exception as e:
485+
logger.exception(f"Error getting prompt '{name}': {e}")
486+
return []
487+
if not result or not result.messages:
488+
logger.warning(f"No content returned by prompt: {name}")
489+
return []
490+
message_dicts = [message.dict() for message in result.messages]
491+
return types.GetPromptResult(messages=message_dicts, description=result.description)
492+
except Exception as e:
493+
logger.exception(f"Error getting prompt '{name}': {e}")
494+
return []
495+
496+
497+
@mcp_app.list_resources()
498+
async def list_resources() -> List[types.Resource]:
499+
"""
500+
Lists all resources available to the MCP Server.
501+
502+
Returns:
503+
A list of Resource objects containing metadata such as uri, name, description, and mimeType.
504+
Logs and returns an empty list on failure.
505+
506+
Examples:
507+
>>> import inspect
508+
>>> sig = inspect.signature(list_resources)
509+
>>> list(sig.parameters.keys())
510+
[]
511+
>>> sig.return_annotation
512+
typing.List[mcp.types.Resource]
513+
"""
514+
server_id = server_id_var.get()
515+
516+
if server_id:
517+
try:
518+
async with get_db() as db:
519+
resources = await resource_service.list_server_resources(db, server_id)
520+
return [types.Resource(uri=resource.uri, name=resource.name, description=resource.description, mimeType=resource.mime_type) for resource in resources]
521+
except Exception as e:
522+
logger.exception(f"Error listing Resources:{e}")
523+
return []
524+
else:
525+
try:
526+
async with get_db() as db:
527+
resources = await resource_service.list_resources(db, False)
528+
return [types.Resource(uri=resource.uri, name=resource.name, description=resource.description, mimeType=resource.mime_type) for resource in resources]
529+
except Exception as e:
530+
logger.exception(f"Error listing resources:{e}")
531+
return []
532+
533+
534+
@mcp_app.read_resource()
535+
async def read_resource(uri: AnyUrl) -> Union[str, bytes]:
536+
"""
537+
Reads the content of a resource specified by its URI.
538+
539+
Args:
540+
uri (AnyUrl): The URI of the resource to read.
541+
542+
Returns:
543+
Union[str, bytes]: The content of the resource, typically as text.
544+
Returns an empty list on failure or if no content is found.
545+
546+
Logs exceptions if any errors occur during reading.
547+
548+
Examples:
549+
>>> import inspect
550+
>>> sig = inspect.signature(read_resource)
551+
>>> list(sig.parameters.keys())
552+
['uri']
553+
>>> sig.return_annotation
554+
typing.Union[str, bytes]
555+
"""
556+
try:
557+
async with get_db() as db:
558+
try:
559+
result = await resource_service.read_resource(db=db, uri=str(uri))
560+
except Exception as e:
561+
logger.exception(f"Error reading resource '{uri}': {e}")
562+
return []
563+
if not result or not result.text:
564+
logger.warning(f"No content returned by resource: {uri}")
565+
return []
566+
567+
return result.text
568+
except Exception as e:
569+
logger.exception(f"Error reading resource '{uri}': {e}")
570+
return []
571+
572+
413573
class SessionManagerWrapper:
414574
"""
415575
Wrapper class for managing the lifecycle of a StreamableHTTPSessionManager instance.
@@ -526,6 +686,8 @@ async def handle_streamable_http(self, scope: Scope, receive: Receive, send: Sen
526686
if match:
527687
server_id = match.group("server_id")
528688
server_id_var.set(server_id)
689+
else:
690+
server_id_var.set(None)
529691

530692
try:
531693
await self.session_manager.handle_request(scope, receive, send)

mcpgateway/wrapper.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,17 +165,19 @@ def convert_url(url: str) -> str:
165165
166166
Examples:
167167
>>> convert_url("http://localhost:4444/servers/uuid")
168-
'http://localhost:4444/servers/uuid/mcp'
168+
'http://localhost:4444/servers/uuid/mcp/'
169169
>>> convert_url("http://localhost:4444/servers/uuid/sse")
170-
'http://localhost:4444/servers/uuid/mcp'
170+
'http://localhost:4444/servers/uuid/mcp/'
171171
>>> convert_url("http://localhost:4444/servers/uuid/mcp")
172-
'http://localhost:4444/servers/uuid/mcp'
172+
'http://localhost:4444/servers/uuid/mcp/'
173173
"""
174174
if url.endswith("/mcp") or url.endswith("/mcp/"):
175+
if url.endswith("/mcp"):
176+
return url + "/"
175177
return url
176178
if url.endswith("/sse"):
177-
return url.replace("/sse", "/mcp")
178-
return url + "/mcp"
179+
return url.replace("/sse", "/mcp/")
180+
return url + "/mcp/"
179181

180182

181183
def send_to_stdout(obj: Union[dict, str]) -> None:
@@ -630,7 +632,7 @@ def parse_args() -> Settings:
630632
>>> sys.argv = ["prog", "--url", "http://localhost:4444/servers/u"]
631633
>>> try:
632634
... s = parse_args()
633-
... s.server_url.endswith("/mcp")
635+
... s.server_url.endswith("/mcp/")
634636
... finally:
635637
... sys.argv = _argv
636638
True

tests/unit/mcpgateway/test_wrapper.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ def setup_function():
3030
# Utilities
3131
# -------------------
3232
def test_convert_url_variants():
33-
assert wrapper.convert_url("http://x/servers/uuid") == "http://x/servers/uuid/mcp"
34-
assert wrapper.convert_url("http://x/servers/uuid/") == "http://x/servers/uuid//mcp"
35-
assert wrapper.convert_url("http://x/servers/uuid/mcp") == "http://x/servers/uuid/mcp"
36-
assert wrapper.convert_url("http://x/servers/uuid/sse") == "http://x/servers/uuid/mcp"
33+
assert wrapper.convert_url("http://x/servers/uuid") == "http://x/servers/uuid/mcp/"
34+
assert wrapper.convert_url("http://x/servers/uuid/") == "http://x/servers/uuid//mcp/"
35+
assert wrapper.convert_url("http://x/servers/uuid/mcp") == "http://x/servers/uuid/mcp/"
36+
assert wrapper.convert_url("http://x/servers/uuid/sse") == "http://x/servers/uuid/mcp/"
3737

3838

3939
def test_make_error_defaults_and_data():
@@ -145,7 +145,7 @@ def test_parse_args_with_env(monkeypatch):
145145
sys.argv = ["prog"]
146146
try:
147147
s = wrapper.parse_args()
148-
assert s.server_url.endswith("/mcp")
148+
assert s.server_url.endswith("/mcp/")
149149
assert s.auth_header == "Bearer 123"
150150
finally:
151151
sys.argv = sys_argv

0 commit comments

Comments
 (0)