@@ -189,7 +189,7 @@ def save_pretrained(
189189 save_directory : Union [str , os .PathLike ],
190190 safe_serialization : bool = True ,
191191 variant : Optional [str ] = None ,
192- max_shard_size : Union [int , str ] = "10GB" ,
192+ max_shard_size : Optional [ Union [int , str ]] = None ,
193193 push_to_hub : bool = False ,
194194 ** kwargs ,
195195 ):
@@ -205,7 +205,7 @@ class implements both a save and loading method. The pipeline is easily reloaded
205205 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
206206 variant (`str`, *optional*):
207207 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
208- max_shard_size (`int` or `str`, defaults to `"10GB" `):
208+ max_shard_size (`int` or `str`, defaults to `None `):
209209 The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
210210 lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
211211 If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
@@ -293,7 +293,8 @@ def is_saveable_module(name, value):
293293 save_kwargs ["safe_serialization" ] = safe_serialization
294294 if save_method_accept_variant :
295295 save_kwargs ["variant" ] = variant
296- if save_method_accept_max_shard_size :
296+ if save_method_accept_max_shard_size and max_shard_size is not None :
297+ # max_shard_size is expected to not be None in ModelMixin
297298 save_kwargs ["max_shard_size" ] = max_shard_size
298299
299300 save_method (os .path .join (save_directory , pipeline_component_name ), ** save_kwargs )
0 commit comments