@@ -54,8 +54,6 @@ def my_tool(param1: str, param2: int = 42) -> dict:
5454 TypeVar ,
5555 Union ,
5656 cast ,
57- get_args ,
58- get_origin ,
5957 get_type_hints ,
6058 overload ,
6159)
@@ -101,39 +99,12 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) -
10199 self .type_hints = get_type_hints (func )
102100 self ._context_param = context_param
103101
102+ self ._validate_signature ()
103+
104104 # Parse the docstring with docstring_parser
105105 doc_str = inspect .getdoc (func ) or ""
106106 self .doc = docstring_parser .parse (doc_str )
107107
108- def _contains_tool_context (tp : Any ) -> bool :
109- """Return True if the annotation `tp` (possibly Union/Optional) includes ToolContext."""
110- if tp is None :
111- return False
112- origin = get_origin (tp )
113- if origin is Union :
114- return any (_contains_tool_context (a ) for a in get_args (tp ))
115- # Handle direct ToolContext type
116- return tp is ToolContext
117-
118- for param in self .signature .parameters .values ():
119- # Prefer resolved type hints (handles forward refs); fall back to annotation
120- ann = self .type_hints .get (param .name , param .annotation )
121- if ann is inspect ._empty :
122- continue
123-
124- if _contains_tool_context (ann ):
125- # If decorator didn't opt-in to context injection, complain
126- if self ._context_param is None :
127- raise TypeError (
128- f"Parameter '{ param .name } ' is of type 'ToolContext' but '@tool(context=True)' is missing."
129- )
130- # If decorator specified a different param name, complain
131- if param .name != self ._context_param :
132- raise TypeError (
133- f"Parameter '{ param .name } ' is of type 'ToolContext' but has the wrong name. "
134- f"It should be named '{ self ._context_param } '."
135- )
136-
137108 # Get parameter descriptions from parsed docstring
138109 self .param_descriptions = {
139110 param .arg_name : param .description or f"Parameter { param .arg_name } " for param in self .doc .params
@@ -142,6 +113,21 @@ def _contains_tool_context(tp: Any) -> bool:
142113 # Create a Pydantic model for validation
143114 self .input_model = self ._create_input_model ()
144115
116+ def _validate_signature (self ) -> None :
117+ """Verify that ToolContext is used correctly in the function signature."""
118+ # Find and validate the ToolContext parameter
119+ for param in self .signature .parameters .values ():
120+ if param .annotation is ToolContext :
121+ if self ._context_param is None :
122+ raise ValueError ("@tool(context=True) must be set if passing in ToolContext param" )
123+
124+ if param .name != self ._context_param :
125+ raise ValueError (
126+ f"param_name=<{ param .name } > | ToolContext param must be named '{ self ._context_param } '"
127+ )
128+ # Found the parameter, no need to check further
129+ break
130+
145131 def _create_input_model (self ) -> Type [BaseModel ]:
146132 """Create a Pydantic model from function signature for input validation.
147133
0 commit comments