Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions src/backend/alembic/versions/658924a376dd_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
"""empty message

Revision ID: 658924a376dd
Revises: a6efd9f047b4
Create Date: 2024-06-25 18:14:17.461015

"""

from typing import Sequence, Union

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "658924a376dd"
down_revision: Union[str, None] = "a6efd9f047b4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    """Apply this revision: create the tool_calls table, add a nullable
    messages.tool_plan column, and drop the unique tool_auth index."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table(
        "tool_calls",
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("parameters", sa.JSON(), nullable=True),
        sa.Column("message_id", sa.String(), nullable=False),
        sa.Column("id", sa.String(), nullable=False),
        sa.Column("created_at", sa.DateTime(), nullable=True),
        sa.Column("updated_at", sa.DateTime(), nullable=True),
        # Deleting a message deletes its tool calls at the DB level.
        sa.ForeignKeyConstraint(["message_id"], ["messages.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    # Non-unique lookup index on the foreign key (unique=False).
    op.create_index("tool_call_message_id", "tool_calls", ["message_id"], unique=False)
    op.add_column("messages", sa.Column("tool_plan", sa.String(), nullable=True))
    op.drop_index("tool_auth_index", table_name="tool_auth")
    # ### end Alembic commands ###


def downgrade() -> None:
    """Revert this revision: restore the unique tool_auth index, drop
    messages.tool_plan, and drop the tool_calls table."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_index("tool_auth_index", "tool_auth", ["user_id", "tool_id"], unique=True)
    op.drop_column("messages", "tool_plan")
    # Drop the index before dropping the table that owns it.
    op.drop_index("tool_call_message_id", table_name="tool_calls")
    op.drop_table("tool_calls")
    # ### end Alembic commands ###
135 changes: 75 additions & 60 deletions src/backend/chat/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,37 +44,29 @@ def chat(self, chat_request: CohereChatRequest, **kwargs: Any) -> Any:

self.chat_request = chat_request
self.is_first_start = True
should_break = False

for step in range(MAX_STEPS):
logger.info(f"Step {step + 1}")
try:
stream = self.call_chat(self.chat_request, deployment_model, **kwargs)

for event in stream:
result = self.handle_event(event, chat_request)

if result:
yield result

if event[
"event_type"
] == StreamEvent.STREAM_END and self.is_final_event(
event, chat_request
):
should_break = True
break
except Exception as e:
yield {
"event_type": StreamEvent.STREAM_END,
"finish_reason": "ERROR",
"error": str(e),
"status_code": 500,
}
should_break = True

if should_break:
break
try:
stream = self.call_chat(self.chat_request, deployment_model, **kwargs)

for event in stream:
result = self.handle_event(event, chat_request)

if result:
yield result

if event[
"event_type"
] == StreamEvent.STREAM_END and self.is_final_event(
event, chat_request
):
break
except Exception as e:
yield {
"event_type": StreamEvent.STREAM_END,
"finish_reason": "ERROR",
"error": str(e),
"status_code": 500,
}

def is_final_event(
self, event: Dict[str, Any], chat_request: CohereChatRequest
Expand Down Expand Up @@ -131,52 +123,75 @@ def call_chat(self, chat_request, deployment_model, **kwargs: Any):
agent_id = kwargs.get("agent_id", "")
managed_tools = self.get_managed_tools(chat_request)

# If tools are managed and not zero shot tools, replace the tools in the chat request
if len(managed_tools) == len(chat_request.tools):
tool_names = []
if managed_tools:
chat_request.tools = managed_tools
tool_names = [tool.name for tool in managed_tools]

# Get the tool calls stream and either return a direct answer or continue
tool_calls_stream = self.get_tool_calls(
managed_tools, chat_request.chat_history, deployment_model, **kwargs
)
is_direct_answer, new_chat_history, stream = self.handle_tool_calls_stream(
tool_calls_stream
)

for event in stream:
yield event
# Add files to chat history if the tool requires it
if ToolName.Read_File in tool_names or ToolName.Search_File in tool_names:
chat_request.chat_history = self.add_files_to_chat_history(
chat_request.chat_history,
kwargs.get("conversation_id"),
kwargs.get("session"),
kwargs.get("user_id"),
)

if is_direct_answer:
return
print(f"Chat history: {chat_request.chat_history}")

# If the stream contains tool calls, call the tools and update the chat history
tool_results = self.call_tools(new_chat_history, deployment_model, **kwargs)
chat_request.tool_results = [result for result in tool_results]
chat_request.chat_history = new_chat_history
# Loop until there are no new tool calls
for step in range(MAX_STEPS):
logger.info(f"Step {step + 1}")

# Remove the message if tool results are present
if tool_results:
chat_request.message = ""
# Invoke chat stream
has_tool_calls = False
for event in deployment_model.invoke_chat_stream(
chat_request, trace_id=trace_id, user_id=user_id, agent_id=agent_id
):
if event["event_type"] == StreamEvent.STREAM_END:
chat_request.chat_history = event["response"].get(
"chat_history", []
)
elif event["event_type"] == StreamEvent.TOOL_CALLS_GENERATION:
has_tool_calls = True

for event in deployment_model.invoke_chat_stream(
chat_request, trace_id=trace_id, user_id=user_id, agent_id=agent_id
):
if event["event_type"] != StreamEvent.STREAM_START:
yield event
if event["event_type"] == StreamEvent.STREAM_END:
chat_request.chat_history = event["response"].get("chat_history", [])

# Update the chat request and restore the message
# Check for new tool calls in the chat history
if has_tool_calls:
# Handle tool calls
tool_results = self.call_tools(
chat_request.chat_history, deployment_model, **kwargs
)

# Remove the message if tool results are present
if tool_results:
chat_request.tool_results = [result for result in tool_results]
chat_request.message = ""
else:
break # Exit loop if there are no new tool calls

# Restore the original chat request message if needed
self.chat_request = chat_request

def update_chat_history_with_tool_results(
    self, chat_request: Any, tool_results: List[Dict[str, Any]]
):
    """Append tool results to the request's chat history.

    Creates an empty ``chat_history`` list on the request first when the
    attribute is missing, then extends it with ``tool_results`` in place.
    """
    try:
        history = chat_request.chat_history
    except AttributeError:
        # No history yet on this request — start one (same object is bound
        # to both the attribute and the local name).
        history = chat_request.chat_history = []
    history.extend(tool_results)

def call_tools(self, chat_history, deployment_model, **kwargs: Any):
tool_results = []
if not hasattr(chat_history[-1], "tool_results"):
if not "tool_calls" in chat_history[-1]:
logging.warning("No tool calls found in chat history.")
return tool_results

tool_calls = chat_history[-1].tool_calls
tool_calls = chat_history[-1]["tool_calls"]
tool_plan = chat_history[-1].get("message", None)
logger.info(f"Tool calls: {tool_calls}")
logger.info(f"Tool plan: {tool_plan}")

# TODO: Call tools in parallel
for tool_call in tool_calls:
Expand Down
48 changes: 48 additions & 0 deletions src/backend/crud/tool_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from sqlalchemy.orm import Session

from backend.database_models.tool_call import ToolCall


def create_tool_call(db: Session, tool_call: ToolCall) -> ToolCall:
    """
    Persist a new tool call row.

    Args:
        db (Session): Database session.
        tool_call (ToolCall): Transient tool call instance to persist.

    Returns:
        ToolCall: The same instance, committed and refreshed so any
            database-generated defaults are populated.
    """
    db.add(tool_call)
    db.commit()
    # Re-read the row so values generated at flush time are visible.
    db.refresh(tool_call)
    return tool_call


def get_tool_call_by_id(db: Session, tool_call_id: str) -> ToolCall | None:
    """
    Get a tool call by its ID.

    Args:
        db (Session): Database session.
        tool_call_id (str): Tool call ID.

    Returns:
        ToolCall | None: Tool call with the given ID, or None when no row
            matches — ``Query.first()`` returns None on an empty result,
            so the previous ``-> ToolCall`` annotation was too narrow.
    """
    return db.query(ToolCall).filter(ToolCall.id == tool_call_id).first()


def list_tool_calls_by_message_id(db: Session, message_id: str) -> list[ToolCall]:
    """
    List every tool call attached to a given message.

    Args:
        db (Session): Database session.
        message_id (str): Message ID.

    Returns:
        list[ToolCall]: All tool calls whose message_id matches; empty
            list when the message has none.
    """
    query = db.query(ToolCall).filter(ToolCall.message_id == message_id)
    return query.all()
3 changes: 3 additions & 0 deletions src/backend/database_models/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from backend.database_models.citation import Citation
from backend.database_models.document import Document
from backend.database_models.file import File
from backend.database_models.tool_call import ToolCall


class MessageAgent(StrEnum):
Expand All @@ -32,10 +33,12 @@ class Message(Base):
position: Mapped[int]
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
generation_id: Mapped[str] = mapped_column(String, nullable=True)
tool_plan: Mapped[str] = mapped_column(String, nullable=True)

documents: Mapped[List["Document"]] = relationship()
citations: Mapped[List["Citation"]] = relationship()
files: Mapped[List["File"]] = relationship()
tool_calls: Mapped[List["ToolCall"]] = relationship()

agent: Mapped[MessageAgent] = mapped_column(
Enum(MessageAgent, native_enum=False),
Expand Down
20 changes: 20 additions & 0 deletions src/backend/database_models/tool_call.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from sqlalchemy import JSON, ForeignKey, Index, String
from sqlalchemy.orm import Mapped, mapped_column

from backend.database_models.base import Base


class ToolCall(Base):
    """
    Default ToolCall model for tool calls.

    One row per tool invocation attached to a message.
    """

    __tablename__ = "tool_calls"

    # Name of the tool that was called.
    name: Mapped[str]
    # JSON-encoded arguments passed to the tool; nullable.
    parameters: Mapped[dict] = mapped_column(JSON, nullable=True)
    # Owning message; ON DELETE CASCADE removes tool calls with the message.
    message_id: Mapped[str] = mapped_column(
        ForeignKey("messages.id", ondelete="CASCADE")
    )

    # Lookup index for fetching all tool calls of a message.
    __table_args__ = (Index("tool_call_message_id", message_id),)
5 changes: 5 additions & 0 deletions src/backend/routers/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ async def chat_stream(
should_store,
managed_tools,
deployment_config,
next_message_position,
) = process_chat(session, chat_request, request, agent_id)

return EventSourceResponse(
Expand All @@ -84,6 +85,7 @@ async def chat_stream(
conversation_id,
user_id,
should_store=should_store,
next_message_position=next_message_position,
),
media_type="text/event-stream",
)
Expand Down Expand Up @@ -125,6 +127,7 @@ async def chat(
should_store,
managed_tools,
deployment_config,
next_message_position,
) = process_chat(session, chat_request, request, agent_id)

return generate_chat_response(
Expand All @@ -144,6 +147,7 @@ async def chat(
conversation_id,
user_id,
should_store=should_store,
next_message_position=next_message_position,
)


Expand All @@ -167,6 +171,7 @@ def langchain_chat_stream(
should_store,
managed_tools,
_,
_,
) = process_chat(session, chat_request, request)

return EventSourceResponse(
Expand Down
24 changes: 21 additions & 3 deletions src/backend/schemas/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ class ChatMessage(BaseModel):
)
message: str | None = Field(
title="Contents of the chat message.",
default=None,
)
tool_plan: str | None = Field(
title="Contents of the tool plan.",
default=None,
)
tool_results: List[Dict[str, Any]] | None = Field(
title="Results from the tool call.",
Expand All @@ -52,8 +57,13 @@ class ChatMessage(BaseModel):
default=None,
)

def to_dict(self) -> Dict[str, str]:
return {"role": self.role, "message": self.message}
def to_dict(self) -> Dict[str, Any]:
    """Serialize this message to the plain-dict shape used in chat history.

    Emits exactly the role, message, tool_results, and tool_calls fields.
    """
    # NOTE(review): tool_plan is not serialized here — confirm that is
    # intentional before adding it.
    fields = ("role", "message", "tool_results", "tool_calls")
    return {name: getattr(self, name) for name in fields}


# TODO: fix titles of these types
Expand Down Expand Up @@ -184,7 +194,15 @@ class StreamEnd(ChatResponse):
title="List of tool calls generated for custom tools",
default=[],
)
# BUG FIX: the previous `= (Field(default=None),)` wrapped the FieldInfo in a
# one-element tuple (stray parentheses + trailing comma), so pydantic saw a
# literal tuple as the field's default instead of `None`.
finish_reason: str | None = Field(default=None)
chat_history: List[ChatMessage] | None = Field(
    default=None,
    title="A list of entries used to construct the conversation. If provided, these messages will be used to build the prompt and the conversation_id will be ignored so no data will be stored to maintain state.",
)
# Populated when the stream ends with finish_reason == "ERROR".
error: str | None = Field(
    title="Error message if the response is an error.",
    default=None,
)


class NonStreamedChatResponse(ChatResponse):
Expand Down
Loading