Skip to content

Commit 1fa7640

Browse files
committed
fix(tool/decorator): simplify validation logic
1 parent 3e6d530 commit 1fa7640

File tree

2 files changed

+28
-48
lines changed

2 files changed

+28
-48
lines changed

src/strands/tools/decorator.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,6 @@ def my_tool(param1: str, param2: int = 42) -> dict:
5454
TypeVar,
5555
Union,
5656
cast,
57-
get_args,
58-
get_origin,
5957
get_type_hints,
6058
overload,
6159
)
@@ -101,39 +99,12 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
10199
self.type_hints = get_type_hints(func)
102100
self._context_param = context_param
103101

102+
self._validate_signature()
103+
104104
# Parse the docstring with docstring_parser
105105
doc_str = inspect.getdoc(func) or ""
106106
self.doc = docstring_parser.parse(doc_str)
107107

108-
def _contains_tool_context(tp: Any) -> bool:
109-
"""Return True if the annotation `tp` (possibly Union/Optional) includes ToolContext."""
110-
if tp is None:
111-
return False
112-
origin = get_origin(tp)
113-
if origin is Union:
114-
return any(_contains_tool_context(a) for a in get_args(tp))
115-
# Handle direct ToolContext type
116-
return tp is ToolContext
117-
118-
for param in self.signature.parameters.values():
119-
# Prefer resolved type hints (handles forward refs); fall back to annotation
120-
ann = self.type_hints.get(param.name, param.annotation)
121-
if ann is inspect._empty:
122-
continue
123-
124-
if _contains_tool_context(ann):
125-
# If decorator didn't opt-in to context injection, complain
126-
if self._context_param is None:
127-
raise TypeError(
128-
f"Parameter '{param.name}' is of type 'ToolContext' but '@tool(context=True)' is missing."
129-
)
130-
# If decorator specified a different param name, complain
131-
if param.name != self._context_param:
132-
raise TypeError(
133-
f"Parameter '{param.name}' is of type 'ToolContext' but has the wrong name. "
134-
f"It should be named '{self._context_param}'."
135-
)
136-
137108
# Get parameter descriptions from parsed docstring
138109
self.param_descriptions = {
139110
param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params
@@ -142,6 +113,21 @@ def _contains_tool_context(tp: Any) -> bool:
142113
# Create a Pydantic model for validation
143114
self.input_model = self._create_input_model()
144115

116+
def _validate_signature(self) -> None:
117+
"""Verify that ToolContext is used correctly in the function signature."""
118+
# Find and validate the ToolContext parameter
119+
for param in self.signature.parameters.values():
120+
if param.annotation is ToolContext:
121+
if self._context_param is None:
122+
raise ValueError("@tool(context=True) must be set if passing in ToolContext param")
123+
124+
if param.name != self._context_param:
125+
raise ValueError(
126+
f"param_name=<{param.name}> | ToolContext param must be named '{self._context_param}'"
127+
)
128+
# Found the parameter, no need to check further
129+
break
130+
145131
def _create_input_model(self) -> Type[BaseModel]:
146132
"""Create a Pydantic model from function signature for input validation.
147133

tests/strands/tools/test_decorator.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,41 +1366,35 @@ async def async_generator() -> AsyncGenerator:
13661366

13671367

13681368
def test_tool_with_mismatched_tool_context_param_name_raises_error():
1369-
"""Verify that a TypeError is raised for a mismatched tool_context parameter name."""
1370-
with pytest.raises(TypeError) as excinfo:
1369+
"""Verify that a ValueError is raised for a mismatched tool_context parameter name."""
1370+
with pytest.raises(ValueError) as excinfo:
13711371

13721372
@strands.tool(context=True)
13731373
def my_tool(context: ToolContext):
13741374
pass
13751375

1376-
assert (
1377-
"Parameter 'context' is of type 'ToolContext' but has the wrong name. It should be named 'tool_context'."
1378-
in str(excinfo.value)
1379-
)
1376+
assert "ToolContext param must be named 'tool_context'" in str(excinfo.value)
1377+
assert "param_name=<context>" in str(excinfo.value)
13801378

13811379

13821380
def test_tool_with_tool_context_but_no_context_flag_raises_error():
1383-
"""Verify that a TypeError is raised if ToolContext is used without context=True."""
1384-
with pytest.raises(TypeError) as excinfo:
1381+
"""Verify that a ValueError is raised if ToolContext is used without context=True."""
1382+
with pytest.raises(ValueError) as excinfo:
13851383

13861384
@strands.tool
13871385
def my_tool(tool_context: ToolContext):
13881386
pass
13891387

1390-
assert "Parameter 'tool_context' is of type 'ToolContext' but '@tool(context=True)' is missing." in str(
1391-
excinfo.value
1392-
)
1388+
assert "@tool(context=True) must be set" in str(excinfo.value)
13931389

13941390

13951391
def test_tool_with_tool_context_named_custom_context_raises_error_if_mismatched():
1396-
"""Verify that a TypeError is raised when context param name doesn't match the decorator value."""
1397-
with pytest.raises(TypeError) as excinfo:
1392+
"""Verify that a ValueError is raised when context param name doesn't match the decorator value."""
1393+
with pytest.raises(ValueError) as excinfo:
13981394

13991395
@strands.tool(context="my_context")
14001396
def my_tool(tool_context: ToolContext):
14011397
pass
14021398

1403-
assert (
1404-
"Parameter 'tool_context' is of type 'ToolContext' but has the wrong name. It should be named 'my_context'."
1405-
in str(excinfo.value)
1406-
)
1399+
assert "ToolContext param must be named 'my_context'" in str(excinfo.value)
1400+
assert "param_name=<tool_context>" in str(excinfo.value)

0 commit comments

Comments
 (0)