5959
6060logger = get_logger (__name__ )
6161if is_torch_npu_available ():
62+ import torch_npu
6263 torch .npu .config .allow_internal_format = False
6364
6465DATASET_NAME_MAPPING = {
@@ -531,7 +532,7 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
531532 return {"prompt_embeds" : prompt_embeds .cpu (), "pooled_prompt_embeds" : pooled_prompt_embeds .cpu ()}
532533
533534
534- def compute_vae_encodings (batch , vae ):
535+ def compute_vae_encodings (batch , accelerator , vae ):
535536 images = batch .pop ("pixel_values" )
536537 pixel_values = torch .stack (list (images ))
537538 pixel_values = pixel_values .to (memory_format = torch .contiguous_format ).float ()
@@ -540,7 +541,7 @@ def compute_vae_encodings(batch, vae):
540541 with torch .no_grad ():
541542 model_input = vae .encode (pixel_values ).latent_dist .sample ()
542543 model_input = model_input * vae .config .scaling_factor
543- return {"model_input" : model_input . cpu ( )}
544+ return {"model_input" : accelerator . gather ( model_input )}
544545
545546
546547def generate_timestep_weights (args , num_timesteps ):
@@ -910,7 +911,7 @@ def preprocess_train(examples):
910911 proportion_empty_prompts = args .proportion_empty_prompts ,
911912 caption_column = args .caption_column ,
912913 )
913- compute_vae_encodings_fn = functools .partial (compute_vae_encodings , vae = vae )
914+ compute_vae_encodings_fn = functools .partial (compute_vae_encodings , accelerator = accelerator , vae = vae )
914915 with accelerator .main_process_first ():
915916 from datasets .fingerprint import Hasher
916917
@@ -935,7 +936,10 @@ def preprocess_train(examples):
935936 del compute_vae_encodings_fn , compute_embeddings_fn , text_encoder_one , text_encoder_two
936937 del text_encoders , tokenizers , vae
937938 gc .collect ()
938- torch .cuda .empty_cache ()
939+ if is_torch_npu_available ():
940+ torch_npu .npu .empty_cache ()
941+ else :
942+ torch .cuda .empty_cache ()
939943
940944 def collate_fn (examples ):
941945 model_input = torch .stack ([torch .tensor (example ["model_input" ]) for example in examples ])
@@ -1091,8 +1095,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
10911095 # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
10921096 target_size = (args .resolution , args .resolution )
10931097 add_time_ids = list (original_size + crops_coords_top_left + target_size )
1094- add_time_ids = torch .tensor ([add_time_ids ])
1095- add_time_ids = add_time_ids .to (accelerator .device , dtype = weight_dtype )
1098+ add_time_ids = torch .tensor ([add_time_ids ], device = accelerator .device , dtype = weight_dtype )
10961099 return add_time_ids
10971100
10981101 add_time_ids = torch .cat (
@@ -1261,7 +1264,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
12611264 )
12621265
12631266 del pipeline
1264- torch .cuda .empty_cache ()
1267+ if is_torch_npu_available ():
1268+ torch_npu .npu .empty_cache ()
1269+ else :
1270+ torch .cuda .empty_cache ()
12651271
12661272 if args .use_ema :
12671273 # Switch back to the original UNet parameters.
0 commit comments