@@ -726,23 +726,29 @@ def _caching_allocator_warmup(
     very large margin.
     """
     factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
-    # Remove disk and cpu devices, and cast to proper torch.device
+
+    # Keep only accelerator devices
     accelerator_device_map = {
         param: torch.device(device)
         for param, device in expanded_device_map.items()
         if str(device) not in ["cpu", "disk"]
     }
-    total_byte_count = defaultdict(lambda: 0)
+    if not accelerator_device_map:
+        return
+
+    elements_per_device = defaultdict(int)
     for param_name, device in accelerator_device_map.items():
         try:
-            param = model.get_parameter(param_name)
+            p = model.get_parameter(param_name)
         except AttributeError:
-            param = model.get_buffer(param_name)
-        # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
-        param_byte_count = param.numel() * param.element_size()
+            try:
+                p = model.get_buffer(param_name)
+            except AttributeError:
+                raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
         # TODO: account for TP when needed.
-        total_byte_count[device] += param_byte_count
+        elements_per_device[device] += p.numel()

     # This will kick off the caching allocator to avoid having to Malloc afterwards
-    for device, byte_count in total_byte_count.items():
-        _ = torch.empty(byte_count // factor, dtype=dtype, device=device, requires_grad=False)
+    for device, elem_count in elements_per_device.items():
+        warmup_elems = max(1, elem_count // factor)
+        _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
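For context, a minimal standalone sketch of the idea the new loop implements (the helper name and the example numbers below are illustrative, not part of the library): one large throwaway allocation per accelerator device primes torch's caching allocator, so the many smaller per-parameter allocations that follow during weight loading are served from the cache instead of triggering fresh device mallocs.

```python
import torch


def warmup_caching_allocator(elements_per_device, dtype=torch.float16, factor=2):
    # Allocate one large throwaway tensor per device; the freed block stays in
    # torch's caching allocator and is reused by later, smaller allocations.
    # `factor` mirrors the quantizer-dependent divisor in the diff above.
    for device, elem_count in elements_per_device.items():
        warmup_elems = max(1, elem_count // factor)
        _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)


if torch.cuda.is_available():
    # Illustrative numbers only: pretend ~10M fp16 elements will land on cuda:0.
    warmup_caching_allocator({torch.device("cuda:0"): 10_000_000})
```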