Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions llama-index-core/llama_index/core/tools/function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from llama_index.core.bridge.pydantic import BaseModel, FieldInfo
from llama_index.core.tools.types import AsyncBaseTool, ToolMetadata, ToolOutput
from llama_index.core.tools.utils import create_schema_from_function
from llama_index.core.schema import BaseNode, Document
from llama_index.core.workflow.context import Context

AsyncCallable = Callable[..., Awaitable[Any]]
Expand Down Expand Up @@ -295,6 +296,12 @@ def _parse_tool_output(self, raw_output: Any) -> List[ContentBlock]:
for item in raw_output
):
return raw_output
elif isinstance(raw_output, (BaseNode, Document)):
return [TextBlock(text=raw_output.get_content())]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will break any time the output is not a document/node object. We need an additional if/else check rather than doing this in the final else block (since here, the type is still "any")

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pushed a more type-safe fix. Going to add a test or two as well

elif isinstance(raw_output, list) and all(
isinstance(item, (BaseNode, Document)) for item in raw_output
):
return [TextBlock(text=item.get_content()) for item in raw_output]
else:
return [TextBlock(text=str(raw_output))]

Expand Down
29 changes: 29 additions & 0 deletions llama-index-core/tests/tools/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from llama_index.core.bridge.pydantic import BaseModel, Field
from llama_index.core.llms import TextBlock, ImageBlock
from llama_index.core.tools.function_tool import FunctionTool
from llama_index.core.schema import Document, TextNode
from llama_index.core.workflow.context import Context
from llama_index.core.workflow import Context

Expand Down Expand Up @@ -425,3 +426,31 @@ def my_method(self, ctx: Context, a: int) -> str:
assert "a" in fields
assert fields["a"].description == "some input value"
assert "self" not in fields


def test_function_tool_output_document_and_nodes():
def get_document() -> Document:
return Document(text="Hello" * 1024)

def get_node() -> TextNode:
return TextNode(text="Hello" * 1024)

def get_documents() -> List[Document]:
return [Document(text="Hello" * 1024), Document(text="World" * 1024)]

def get_nodes() -> List[TextNode]:
return [TextNode(text="Hello" * 1024), TextNode(text="World" * 1024)]

tool = FunctionTool.from_defaults(get_document)
assert tool.call().content == "Hello" * 1024

tool = FunctionTool.from_defaults(get_node)
assert tool.call().content == "Hello" * 1024

tool = FunctionTool.from_defaults(get_documents)
assert "Hello" * 1024 in tool.call().content
assert "World" * 1024 in tool.call().content

tool = FunctionTool.from_defaults(get_nodes)
assert "Hello" * 1024 in tool.call().content
assert "World" * 1024 in tool.call().content
Loading