@@ -245,7 +245,6 @@ def _offload_to_memory(self):
245245 param .data = self .cpu_param_dict [param ]
246246 for buffer in self .buffers :
247247 buffer .data = self .cpu_param_dict [buffer ]
248-
249248 else :
250249 for group_module in self .modules :
251250 group_module .to (self .offload_device , non_blocking = False )
@@ -303,9 +302,23 @@ def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
303302 if self .group .onload_leader == module :
304303 if self .group .onload_self :
305304 self .group .onload_ ()
306- if self .next_group is not None and not self .next_group .onload_self :
305+
306+ should_onload_next_group = self .next_group is not None and not self .next_group .onload_self
307+ if should_onload_next_group :
307308 self .next_group .onload_ ()
308309
310+ should_synchronize = (
311+ not self .group .onload_self and self .group .stream is not None and not should_onload_next_group
312+ )
313+ if should_synchronize :
314+ # If this group didn't onload itself, it was asynchronously onloaded by the
315+ # previous group. We need to synchronize the side stream to ensure the parameters
316+ # are completely loaded before proceeding with the forward pass. Without this,
317+ # uninitialized weights can be used in the computation, leading to incorrect results.
318+ # We only synchronize here if the sync call in self.next_group.onload_ does not already
319+ # do it, hence the `not should_onload_next_group` check.
320+ self .group .stream .synchronize ()
321+
309322 args = send_to_device (args , self .group .onload_device , non_blocking = self .group .non_blocking )
310323 kwargs = send_to_device (kwargs , self .group .onload_device , non_blocking = self .group .non_blocking )
311324 return args , kwargs
0 commit comments