Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion mcpgateway/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
GatewayUpdate,
JsonPathModifier,
PromptCreate,
PromptExecuteArgs,
PromptRead,
PromptUpdate,
ResourceCreate,
Expand Down Expand Up @@ -1632,7 +1633,13 @@ async def get_prompt(
Rendered prompt or metadata.
"""
logger.debug(f"User: {user} requested prompt: {name} with args={args}")
return await prompt_service.get_prompt(db, name, args)
try:
PromptExecuteArgs(args=args)
return await prompt_service.get_prompt(db, name, args)
except Exception as ex:
logger.error(f"Error retrieving prompt {name}: {ex}")
if isinstance(ex, ValueError):
return JSONResponse(content={"message": "Prompt execution arguments contains HTML tags that may cause security issues"}, status_code=422)


@prompt_router.get("/{name}")
Expand Down
28 changes: 28 additions & 0 deletions mcpgateway/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1433,6 +1433,34 @@ def validate_arguments(cls, v: Dict[str, Any]) -> Dict[str, Any]:
return v


class PromptExecuteArgs(BaseModel):
"""
Schema for args executing a prompt

Attributes:
args (Dict[str, str]): Arguments for prompt execution.
"""

model_config = ConfigDict(str_strip_whitespace=True)

args: Dict[str, str] = Field(default_factory=dict, description="Arguments for prompt execution")

@field_validator("args")
@classmethod
def validate_args(cls, v: dict) -> dict:
"""Ensure prompt arguments pass XSS validation

Args:
v (dict): Value to validate

Returns:
dict: Value if validated as safe
"""
for val in v.values():
SecurityValidator.validate_no_xss(val, "Prompt execution arguments")
return v


class PromptUpdate(BaseModelWithConfigDict):
"""Schema for updating an existing prompt.

Expand Down
18 changes: 18 additions & 0 deletions mcpgateway/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,24 @@ def validate_url(cls, value: str, field_name: str = "URL") -> str:

return value

@classmethod
def validate_no_xss(cls, value: str, field_name: str) -> None:
"""
Validate that a string does not contain XSS patterns.

Args:
value (str): Value to validate.
field_name (str): Name of the field being validated.

Raises:
ValueError: If the value contains XSS patterns.
"""
if not value:
return # Empty values are considered safe
# Check for dangerous HTML tags
if re.search(cls.DANGEROUS_HTML_PATTERN, value, re.IGNORECASE):
raise ValueError(f"{field_name} contains HTML tags that may cause security issues")

@classmethod
def validate_json_depth(
cls,
Expand Down
Loading