6262 load_or_create_model_card ,
6363 populate_model_card ,
6464)
65+ from ..utils .torch_utils import device_synchronize , empty_device_cache
6566from .model_loading_utils import (
67+ _caching_allocator_warmup ,
6668 _determine_device_map ,
69+ _expand_device_map ,
6770 _fetch_index_file ,
6871 _fetch_index_file_legacy ,
72+ _find_mismatched_keys ,
6973 _load_state_dict_into_model ,
7074 load_model_dict_into_meta ,
7175 load_state_dict ,
@@ -1469,11 +1473,6 @@ def _load_pretrained_model(
14691473 for pat in cls ._keys_to_ignore_on_load_unexpected :
14701474 unexpected_keys = [k for k in unexpected_keys if re .search (pat , k ) is None ]
14711475
1472- mismatched_keys = []
1473-
1474- assign_to_params_buffers = None
1475- error_msgs = []
1476-
14771476 # Deal with offload
14781477 if device_map is not None and "disk" in device_map .values ():
14791478 if offload_folder is None :
@@ -1482,18 +1481,27 @@ def _load_pretrained_model(
14821481 " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
14831482 " offers the weights in this format."
14841483 )
1485- if offload_folder is not None :
1484+ else :
14861485 os .makedirs (offload_folder , exist_ok = True )
14871486 if offload_state_dict is None :
14881487 offload_state_dict = True
14891488
1489+ # If a device map has been used, we can speedup the load time by warming up the device caching allocator.
1490+ # If we don't warmup, each tensor allocation on device calls to the allocator for memory (effectively, a
1491+ # lot of individual calls to device malloc). We can, however, preallocate the memory required by the
1492+ # tensors using their expected shape and not performing any initialization of the memory (empty data).
1493+ # When the actual device allocations happen, the allocator already has a pool of unused device memory
1494+ # that it can re-use for faster loading of the model.
1495+ # TODO: add support for warmup with hf_quantizer
1496+ if device_map is not None and hf_quantizer is None :
1497+ expanded_device_map = _expand_device_map (device_map , expected_keys )
1498+ _caching_allocator_warmup (model , expanded_device_map , dtype )
1499+
14901500 offload_index = {} if device_map is not None and "disk" in device_map .values () else None
1501+ state_dict_folder , state_dict_index = None , None
14911502 if offload_state_dict :
14921503 state_dict_folder = tempfile .mkdtemp ()
14931504 state_dict_index = {}
1494- else :
1495- state_dict_folder = None
1496- state_dict_index = None
14971505
14981506 if state_dict is not None :
14991507 # load_state_dict will manage the case where we pass a dict instead of a file
@@ -1503,38 +1511,14 @@ def _load_pretrained_model(
15031511 if len (resolved_model_file ) > 1 :
15041512 resolved_model_file = logging .tqdm (resolved_model_file , desc = "Loading checkpoint shards" )
15051513
1514+ mismatched_keys = []
1515+ assign_to_params_buffers = None
1516+ error_msgs = []
1517+
15061518 for shard_file in resolved_model_file :
15071519 state_dict = load_state_dict (shard_file , dduf_entries = dduf_entries )
1508-
1509- def _find_mismatched_keys (
1510- state_dict ,
1511- model_state_dict ,
1512- loaded_keys ,
1513- ignore_mismatched_sizes ,
1514- ):
1515- mismatched_keys = []
1516- if ignore_mismatched_sizes :
1517- for checkpoint_key in loaded_keys :
1518- model_key = checkpoint_key
1519- # If the checkpoint is sharded, we may not have the key here.
1520- if checkpoint_key not in state_dict :
1521- continue
1522-
1523- if (
1524- model_key in model_state_dict
1525- and state_dict [checkpoint_key ].shape != model_state_dict [model_key ].shape
1526- ):
1527- mismatched_keys .append (
1528- (checkpoint_key , state_dict [checkpoint_key ].shape , model_state_dict [model_key ].shape )
1529- )
1530- del state_dict [checkpoint_key ]
1531- return mismatched_keys
1532-
15331520 mismatched_keys += _find_mismatched_keys (
1534- state_dict ,
1535- model_state_dict ,
1536- loaded_keys ,
1537- ignore_mismatched_sizes ,
1521+ state_dict , model_state_dict , loaded_keys , ignore_mismatched_sizes
15381522 )
15391523
15401524 if low_cpu_mem_usage :
@@ -1554,9 +1538,13 @@ def _find_mismatched_keys(
15541538 else :
15551539 if assign_to_params_buffers is None :
15561540 assign_to_params_buffers = check_support_param_buffer_assignment (model , state_dict )
1557-
15581541 error_msgs += _load_state_dict_into_model (model , state_dict , assign_to_params_buffers )
15591542
1543+ # Ensure tensors are correctly placed on device by synchronizing before returning control to user. This is
1544+ # required because we move tensors with non_blocking=True, which is slightly faster for model loading.
1545+ empty_device_cache ()
1546+ device_synchronize ()
1547+
15601548 if offload_index is not None and len (offload_index ) > 0 :
15611549 save_offload_index (offload_index , offload_folder )
15621550 offload_index = None
0 commit comments