From 1e58e59e02b5087ec69db1f5a300d1893ed9ce3c Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Thu, 6 Nov 2025 21:29:26 -0500 Subject: [PATCH 1/3] swarm - switch to handoff node only after current node stops --- src/strands/multiagent/swarm.py | 41 ++++++++++++-------------- src/strands/session/session_manager.py | 4 +-- 2 files changed, 21 insertions(+), 24 deletions(-) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index accd56463..833cd240a 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -156,6 +156,7 @@ class SwarmState: # Total metrics across all agents accumulated_metrics: Metrics = field(default_factory=lambda: Metrics(latencyMs=0)) execution_time: int = 0 # Total execution time in milliseconds + handoff_node: SwarmNode | None = None # The agent to execute next handoff_message: str | None = None # Message passed during agent handoff def should_continue( @@ -537,7 +538,7 @@ def handoff_to_agent(agent_name: str, message: str, context: dict[str, Any] | No # Execute handoff swarm_ref._handle_handoff(target_node, message, context) - return {"status": "success", "content": [{"text": f"Handed off to {agent_name}: {message}"}]} + return {"status": "success", "content": [{"text": f"Handing off to {agent_name}: {message}"}]} except Exception as e: return {"status": "error", "content": [{"text": f"Error in handoff: {str(e)}"}]} @@ -553,21 +554,19 @@ def _handle_handoff(self, target_node: SwarmNode, message: str, context: dict[st ) return - # Update swarm state - previous_agent = cast(SwarmNode, self.state.current_node) - self.state.current_node = target_node + current_node = cast(SwarmNode, self.state.current_node) - # Store handoff message for the target agent + self.state.handoff_node = target_node self.state.handoff_message = message # Store handoff context as shared context if context: for key, value in context.items(): - self.shared_context.add_context(previous_agent, key, value) + self.shared_context.add_context(current_node, key, value) logger.debug( - "from_node=<%s>, to_node=<%s> | handed off from agent to agent", - previous_agent.node_id, + "from_node=<%s>, to_node=<%s> | handing off from agent to agent", + current_node.node_id, target_node.node_id, ) @@ -667,7 +666,6 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato logger.debug("reason=<%s> | stopping execution", reason) break - # Get current node current_node = self.state.current_node if not current_node or current_node.node_id not in self.nodes: logger.error("node=<%s> | node not found", current_node.node_id if current_node else "None") @@ -680,14 +678,10 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato len(self.state.node_history) + 1, ) - # Store the current node before execution to detect handoffs - previous_node = current_node - - # Execute node with timeout protection # TODO: Implement cancellation token to stop _execute_node from continuing try: - # Execute with timeout wrapper for async generator streaming self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, current_node.node_id, invocation_state)) + node_stream = self._stream_with_timeout( self._execute_node(current_node, self.state.task, invocation_state), self.node_timeout, @@ -697,28 +691,31 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato yield event self.state.node_history.append(current_node) - - # After self.state add current node, swarm state finish updating, we persist here self.hooks.invoke_callbacks(AfterNodeCallEvent(self, current_node.node_id, invocation_state)) logger.debug("node=<%s> | node execution completed", current_node.node_id) - # Check if handoff occurred during execution - if self.state.current_node is not None and self.state.current_node != previous_node: - # Emit handoff event (single node transition in Swarm) + # Check if handoff requested during execution + if self.state.handoff_node: + previous_node = current_node + current_node = self.state.handoff_node + + self.state.handoff_node = None + self.state.current_node = current_node + handoff_event = MultiAgentHandoffEvent( from_node_ids=[previous_node.node_id], - to_node_ids=[self.state.current_node.node_id], + to_node_ids=[current_node.node_id], message=self.state.handoff_message or "Agent handoff occurred", ) yield handoff_event logger.debug( "from_node=<%s>, to_node=<%s> | handoff detected", previous_node.node_id, - self.state.current_node.node_id, + current_node.node_id, ) + else: - # No handoff occurred, mark swarm as complete logger.debug("node=<%s> | no handoff occurred, marking swarm as complete", current_node.node_id) self.state.completion_status = Status.COMPLETED break diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index fb9132828..d4bc72c80 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -6,7 +6,7 @@ from ..experimental.hooks.multiagent.events import ( AfterMultiAgentInvocationEvent, - AfterNodeCallEvent, + BeforeNodeCallEvent, MultiAgentInitializedEvent, ) from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent @@ -44,7 +44,7 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source)) - registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) + registry.add_callback(BeforeNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source)) @abstractmethod From ce88c29343cddc26098311d76de846373753c20f Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Sun, 16 Nov 2025 16:51:26 -0500 Subject: [PATCH 2/3] set next nodes to hand off --- src/strands/multiagent/swarm.py | 11 ++++++----- src/strands/session/session_manager.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index a7dca1a0c..3913cd837 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -862,11 +862,12 @@ def _build_result(self) -> SwarmResult: def serialize_state(self) -> dict[str, Any]: """Serialize the current swarm state to a dictionary.""" status_str = self.state.completion_status.value - next_nodes = ( - [self.state.current_node.node_id] - if self.state.completion_status == Status.EXECUTING and self.state.current_node - else [] - ) + if self.state.handoff_node: + next_nodes = [self.state.handoff_node.node_id] + elif self.state.completion_status == Status.EXECUTING and self.state.current_node: + next_nodes = [self.state.current_node.node_id] + else: + next_nodes = [] return { "type": "swarm", diff --git a/src/strands/session/session_manager.py b/src/strands/session/session_manager.py index d4bc72c80..fb9132828 100644 --- a/src/strands/session/session_manager.py +++ b/src/strands/session/session_manager.py @@ -6,7 +6,7 @@ from ..experimental.hooks.multiagent.events import ( AfterMultiAgentInvocationEvent, - BeforeNodeCallEvent, + AfterNodeCallEvent, MultiAgentInitializedEvent, ) from ..hooks.events import AfterInvocationEvent, AgentInitializedEvent, MessageAddedEvent @@ -44,7 +44,7 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: registry.add_callback(AfterInvocationEvent, lambda event: self.sync_agent(event.agent)) registry.add_callback(MultiAgentInitializedEvent, lambda event: self.initialize_multi_agent(event.source)) - registry.add_callback(BeforeNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) + registry.add_callback(AfterNodeCallEvent, lambda event: self.sync_multi_agent(event.source)) registry.add_callback(AfterMultiAgentInvocationEvent, lambda event: self.sync_multi_agent(event.source)) @abstractmethod From 54cce6f9157703ff71bb62d975f3d36d6ffd33a1 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Mon, 17 Nov 2025 10:11:44 -0500 Subject: [PATCH 3/3] unit test --- tests/strands/multiagent/test_swarm.py | 27 ++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/strands/multiagent/test_swarm.py b/tests/strands/multiagent/test_swarm.py index e8a6a5f79..008b2954d 100644 --- a/tests/strands/multiagent/test_swarm.py +++ b/tests/strands/multiagent/test_swarm.py @@ -1149,3 +1149,30 @@ async def test_swarm_persistence(mock_strands_tracer, mock_use_span): assert final_state["status"] == "completed" assert len(final_state["node_history"]) == 1 assert "test_agent" in final_state["node_results"] + + +@pytest.mark.asyncio +async def test_swarm_handle_handoff(): + first_agent = create_mock_agent("first") + second_agent = create_mock_agent("second") + + swarm = Swarm([first_agent, second_agent]) + + async def handoff_stream(*args, **kwargs): + yield {"agent_start": True} + + swarm._handle_handoff(swarm.nodes["second"], "test message", {}) + + assert swarm.state.current_node.node_id == "first" + assert swarm.state.handoff_node.node_id == "second" + + yield {"result": first_agent.return_value} + + first_agent.stream_async = Mock(side_effect=handoff_stream) + + result = await swarm.invoke_async("test") + assert result.status == Status.COMPLETED + + tru_node_order = [node.node_id for node in result.node_history] + exp_node_order = ["first", "second"] + assert tru_node_order == exp_node_order