Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions src/backend/chat/custom/tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@
from backend.model_deployments.base import BaseDeployment
from backend.schemas.context import Context
from backend.services.logger.utils import LoggerFactory
from backend.tools.base import ToolAuthException, ToolError, ToolErrorCode
from backend.tools.base import (
ToolAuthException,
ToolErrorCode,
)

TIMEOUT_SECONDS = 60

Expand Down Expand Up @@ -110,19 +113,17 @@ async def _call_tool_async(
{
"call": tool_call,
"outputs": tool.get_tool_error(
ToolError(
text="Tool authentication failed",
details=str(e),
type=ToolErrorCode.AUTH,
)
details=str(e),
text="Tool authentication failed",
error_type=ToolErrorCode.AUTH,
),
}
]
except Exception as e:
return [
{
"call": tool_call,
"outputs": tool.get_tool_error(ToolError(text=str(e))),
"outputs": tool.get_tool_error(details=str(e)),
}
]

Expand Down
18 changes: 8 additions & 10 deletions src/backend/tests/unit/chat/test_tool_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,10 @@ async def call(
"name": "toolkit_calculator",
"parameters": {"code": "6*7"},
},
"outputs": [{'type': 'other', 'success': False, 'text': 'Calculator failed', 'details': ''}],
"outputs": [{'type': 'other', 'success': False, 'text': 'Error calling tool toolkit_calculator.', 'details': 'Calculator failed'}],
},
]


@patch("backend.chat.custom.tool_calls.TIMEOUT_SECONDS", 1)
def test_async_call_tools_timeout(mock_get_available_tools) -> None:
class MockCalculator(BaseTool):
Expand Down Expand Up @@ -249,8 +248,8 @@ async def call(
)

assert {'call': {'name': 'web_scrape', 'parameters': {'code': '6*7'}}, 'outputs': [
{'details': '', 'success': False, 'text': "Model didn't pass required parameter: url", 'type'
: 'other'}]} in results
{'type': 'other', 'success': False, 'text': 'Error calling tool web_scrape.',
'details': "Model didn't pass required parameter: url"}]} in results
assert {
"call": {"name": "toolkit_calculator", "parameters": {"code": "6*7"}},
"outputs": [{"result": 42}],
Expand Down Expand Up @@ -299,7 +298,7 @@ async def call(
async_call_tools(chat_history, MockCohereDeployment(), ctx)
)
assert {'call': {'name': 'toolkit_calculator', 'parameters': {'invalid_param': '6*7'}}, 'outputs': [
{'details': '', 'success': False, 'text': "Model didn't pass required parameter: code",
{'details': "Model didn't pass required parameter: code", 'success': False, 'text': 'Error calling tool toolkit_calculator.',
'type': 'other'}]} in results

def test_tools_params_checker_invalid_param_type(mock_get_available_tools) -> None:
Expand Down Expand Up @@ -343,9 +342,8 @@ async def call(
async_call_tools(chat_history, MockCohereDeployment(), ctx)
)
assert {'call': {'name': 'toolkit_calculator', 'parameters': {'code': 6}}, 'outputs': [
{'details': '', 'success': False,
'text': "Model passed invalid parameter. Parameter 'code' must be of type str, but got int",
'type': 'other'}]} in results
{'type': 'other', 'success': False, 'text': 'Error calling tool toolkit_calculator.',
'details': "Model passed invalid parameter. Parameter 'code' must be of type str, but got int"}]} in results

def test_tools_params_checker_required_param_empty(mock_get_available_tools) -> None:
class MockCalculator(BaseTool):
Expand Down Expand Up @@ -388,5 +386,5 @@ async def call(
async_call_tools(chat_history, MockCohereDeployment(), ctx)
)
assert {'call': {'name': 'toolkit_calculator', 'parameters': {'code': ''}}, 'outputs': [
{'details': '', 'success': False, 'text': 'Model passed empty value for required parameter: code',
'type': 'other'}]} in results
{'details': 'Model passed empty value for required parameter: code', 'success': False,
'text': 'Error calling tool toolkit_calculator.', 'type': 'other'}]} in results
3 changes: 2 additions & 1 deletion src/backend/tests/unit/tools/test_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,5 @@ async def test_calculator_invalid_syntax() -> None:
ctx = Context()
calculator = Calculator()
result = await calculator.call({"code": "2+"}, ctx)
assert result == {"text": "Parsing error - syntax not allowed."}

assert result == [{'details': 'parse error [column 2]: parity, expression: 2+', 'success': False, 'text': 'Error calling tool toolkit_calculator.', 'type': 'other'}]
5 changes: 3 additions & 2 deletions src/backend/tests/unit/tools/test_lang_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ async def test_wiki_retriever_no_docs() -> None:
):
result = await retriever.call({"query": query}, ctx)

assert result == []
assert result == [{'details': '','success': False,'text': 'No results found.','type': 'other'}]



@pytest.mark.skipif(not is_cohere_env_set, reason="Cohere API key not set")
Expand Down Expand Up @@ -163,4 +164,4 @@ async def test_vector_db_retriever_no_docs() -> None:
mock_db.as_retriever().get_relevant_documents.return_value = mock_docs
result = await retriever.call({"query": query}, ctx)

assert result == []
assert result == [{'details': '', 'success': False, 'text': 'No results found.', 'type': 'other'}]
8 changes: 4 additions & 4 deletions src/backend/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(self, message, tool_id: str):
self.message = message
self.tool_id = tool_id


class ToolError(BaseModel, extra="allow"):
type: ToolErrorCode = ToolErrorCode.OTHER
success: bool = False
Expand All @@ -38,6 +37,7 @@ class ToolError(BaseModel, extra="allow"):
class ToolArgument(StrEnum):
DOMAIN_FILTER = "domain_filter"
SITE_FILTER = "site_filter"

class ParametersValidationMeta(type):
"""
Metaclass to decorate all tools `call` methods with the parameter checker.
Expand Down Expand Up @@ -90,14 +90,14 @@ def _handle_tool_specific_errors(cls, error: Exception, **kwargs: Any) -> None:
...

@classmethod
def get_tool_error(cls, err: ToolError):
tool_error = err.model_dump()
def get_tool_error(cls, details: str, text: str = "Error calling tool", error_type: ToolErrorCode = ToolErrorCode.OTHER):
tool_error = ToolError(text=f"{text} {cls.ID}.", details=details, type=error_type).model_dump()
logger.error(event=f"Error calling tool {cls.ID}", error=tool_error)
return [tool_error]

@classmethod
def get_no_results_error(cls):
return cls.get_tool_error(ToolError(text="No results found."))
return ToolError(text="No results found.", details="No results found for the given params.")

@abstractmethod
async def call(
Expand Down
12 changes: 8 additions & 4 deletions src/backend/tools/brave_search/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,19 @@ async def call(
# Get domain filtering from kwargs
filtered_domains = kwargs.get(ToolArgument.DOMAIN_FILTER, [])

response = await self.client.search_async(
q=query, count=self.num_results, include_domains=filtered_domains
)
try:
response = await self.client.search_async(
q=query, count=self.num_results, include_domains=filtered_domains
)
except Exception as e:
return self.get_tool_error(details=str(e))

response = dict(response)

results = response.get("web", {}).get("results", [])

if not results:
self.get_no_results_error()
return self.get_no_results_error()

tool_results = []
for result in results:
Expand Down
6 changes: 3 additions & 3 deletions src/backend/tools/calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ async def call(

to_evaluate = expression.replace("pi", "PI").replace("e", "E")

result = []
try:
result = {"text": math_parser.parse(to_evaluate).evaluate({})}
except Exception as e:
logger.error(event=f"[Calculator] Error parsing expression: {e}")
result = {"text": "Parsing error - syntax not allowed."}
return self.get_tool_error(details=str(e))

return result

return result # type: ignore
13 changes: 9 additions & 4 deletions src/backend/tools/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]:
session = kwargs.get("session")
user_id = kwargs.get("user_id")
if not file:
return []
return self.get_tool_error(details="Files are not passed in model generated params")

_, file_id = file
retrieved_file = file_crud.get_file(session, file_id, user_id)
if not retrieved_file:
return []
return self.get_tool_error(details="The wrong files were passed in the tool parameters, or files were not found")

return [
{
Expand Down Expand Up @@ -125,13 +125,15 @@ async def call(
user_id = kwargs.get("user_id")

if not query or not files:
return []
return self.get_tool_error(
details="Missing query or files. The wrong files might have been passed in the tool parameters")

file_ids = [file_id for _, file_id in files]
retrieved_files = file_crud.get_files_by_ids(session, file_ids, user_id)

if not retrieved_files:
return []
return self.get_tool_error(
details="Missing files. The wrong files might have been passed in the tool parameters")

results = []
for file in retrieved_files:
Expand All @@ -142,4 +144,7 @@ async def call(
"url": file.file_name,
}
)
if not results:
return self.get_no_results_error()

return results
37 changes: 21 additions & 16 deletions src/backend/tools/google_drive/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,17 @@ async def call(self, parameters: dict, **kwargs: Any) -> List[Dict[str, Any]]:
# Search Google Drive
logger.info(event="[Google Drive] Defaulting to raw Google Drive search.")
agent_tool_metadata = kwargs["agent_tool_metadata"]
documents = await _default_gdrive_list_files(
user_id=user_id, query=query, agent_tool_metadata=agent_tool_metadata
)
try:
documents = await _default_gdrive_list_files(
user_id=user_id, query=query, agent_tool_metadata=agent_tool_metadata
)
except Exception as e:
return self.get_tool_error(details=str(e))

if not documents:
logger.info(event="[Google Drive] No documents found.")
return self.get_no_results_error()

return documents


Expand Down Expand Up @@ -141,20 +149,17 @@ async def _default_gdrive_list_files(
fields = f"nextPageToken, files({DOC_FIELDS})"

search_results = []
try:
search_results = (
service.files()
.list(
pageSize=SEARCH_LIMIT,
q=q,
includeItemsFromAllDrives=True,
supportsAllDrives=True,
fields=fields,
)
.execute()
search_results = (
service.files()
.list(
pageSize=SEARCH_LIMIT,
q=q,
includeItemsFromAllDrives=True,
supportsAllDrives=True,
fields=fields,
)
except Exception as error:
logger.error(event="[Google Drive] Error searching files", error=error)
.execute()
)

files = search_results.get("files", [])
if not files:
Expand Down
8 changes: 5 additions & 3 deletions src/backend/tools/google_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ async def call(
# Get domain filtering from kwargs
filtered_domains = kwargs.get(ToolArgument.DOMAIN_FILTER, [])
domain_filters = [f"site:{domain}" for domain in filtered_domains]

response = cse.list(q=query, cx=self.CSE_ID, orTerms=domain_filters).execute()
search_results = response.get("items", [])
try:
response = cse.list(q=query, cx=self.CSE_ID, orTerms=domain_filters).execute()
search_results = response.get("items", [])
except Exception as e:
return self.get_tool_error(details=str(e))

if not search_results:
return self.get_no_results_error()
Expand Down
3 changes: 3 additions & 0 deletions src/backend/tools/hybrid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ async def call(
**kwargs,
)

if not reranked_results:
return self.get_no_results_error()

return reranked_results

async def rerank_results(
Expand Down
37 changes: 24 additions & 13 deletions src/backend/tools/lang_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,17 @@ async def call(
) -> List[Dict[str, Any]]:
wiki_retriever = WikipediaRetriever()
query = parameters.get("query", "")
docs = wiki_retriever.get_relevant_documents(query)
text_splitter = CharacterTextSplitter(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
)
documents = text_splitter.split_documents(docs)
try:
docs = wiki_retriever.get_relevant_documents(query)
text_splitter = CharacterTextSplitter(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
)
documents = text_splitter.split_documents(docs)
except Exception as e:
return self.get_tool_error(details=str(e))

if not documents:
return self.get_no_results_error()

return [
{
Expand Down Expand Up @@ -115,13 +121,18 @@ async def call(
cohere_embeddings = CohereEmbeddings(cohere_api_key=self.COHERE_API_KEY)

# Load text files and split into chunks
loader = PyPDFLoader(self.filepath)
text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0)
pages = loader.load_and_split(text_splitter)

# Create a vector store from the documents
db = Chroma.from_documents(documents=pages, embedding=cohere_embeddings)
query = parameters.get("query", "")
input_docs = db.as_retriever().get_relevant_documents(query)
try:
loader = PyPDFLoader(self.filepath)
text_splitter = CharacterTextSplitter(chunk_size=300, chunk_overlap=0)
pages = loader.load_and_split(text_splitter)

# Create a vector store from the documents
db = Chroma.from_documents(documents=pages, embedding=cohere_embeddings)
query = parameters.get("query", "")
input_docs = db.as_retriever().get_relevant_documents(query)
except Exception as e:
return self.get_tool_error(details=str(e))
if not input_docs:
return self.get_no_results_error()

return [{"text": doc.page_content} for doc in input_docs]
13 changes: 10 additions & 3 deletions src/backend/tools/python_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,15 @@ async def call(self, parameters: dict, ctx: Any, **kwargs: Any):
raise Exception("Python Interpreter tool called while URL not set")

code = parameters.get("code", "")
res = requests.post(self.INTERPRETER_URL, json={"code": code})
try:
res = requests.post(self.INTERPRETER_URL, json={"code": code})
clean_res = self._clean_response(res.json())
except Exception as e:
return self.get_tool_error(details=str(e))

if not clean_res:
return self.get_no_results_error()

clean_res = self._clean_response(res.json())
return clean_res

def _clean_response(self, result: Any) -> Dict[str, str]:
Expand All @@ -82,7 +88,8 @@ def _clean_response(self, result: Any) -> Dict[str, str]:
r.setdefault("text", r.get("std_out"))
elif r.get("success") is False:
error_message = r.get("error", {}).get("message", "")
r.setdefault("text", error_message)
# r.setdefault("text", error_message)
return self.get_tool_error(details=error_message)
elif r.get("output_file") and r.get("output_file").get("filename"):
if r["output_file"]["filename"] != "":
r.setdefault(
Expand Down
12 changes: 10 additions & 2 deletions src/backend/tools/slack/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ async def call(self, parameters: dict, ctx: Any, **kwargs: Any) -> List[Dict[str

# Search Slack
slack_service = get_slack_service(user_id=user_id, search_limit=SEARCH_LIMIT)
all_results = slack_service.search_all(query=query)
return slack_service.serialize_results(all_results)
try:
all_results = slack_service.search_all(query=query)
results = slack_service.serialize_results(all_results)
except Exception as e:
return self.get_tool_error(details=str(e))

if not results:
return self.get_no_results_error()

return results

Loading
Loading