Skip to content

Commit e2f0a7b

Browse files
committed
Improve the performance and suitable for NPU
1 parent 38a3e4d commit e2f0a7b

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

examples/text_to_image/train_text_to_image_sdxl.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@
5959

6060
logger = get_logger(__name__)
6161
if is_torch_npu_available():
62+
import torch_npu
6263
torch.npu.config.allow_internal_format = False
6364

6465
DATASET_NAME_MAPPING = {
@@ -531,7 +532,7 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca
531532
return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()}
532533

533534

534-
def compute_vae_encodings(batch, vae):
535+
def compute_vae_encodings(batch, accelerator, vae):
535536
images = batch.pop("pixel_values")
536537
pixel_values = torch.stack(list(images))
537538
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
@@ -540,7 +541,7 @@ def compute_vae_encodings(batch, vae):
540541
with torch.no_grad():
541542
model_input = vae.encode(pixel_values).latent_dist.sample()
542543
model_input = model_input * vae.config.scaling_factor
543-
return {"model_input": model_input.cpu()}
544+
return {"model_input": accelerator.gather(model_input)}
544545

545546

546547
def generate_timestep_weights(args, num_timesteps):
@@ -910,7 +911,7 @@ def preprocess_train(examples):
910911
proportion_empty_prompts=args.proportion_empty_prompts,
911912
caption_column=args.caption_column,
912913
)
913-
compute_vae_encodings_fn = functools.partial(compute_vae_encodings, vae=vae)
914+
compute_vae_encodings_fn = functools.partial(compute_vae_encodings, accelerator=accelerator, vae=vae)
914915
with accelerator.main_process_first():
915916
from datasets.fingerprint import Hasher
916917

@@ -935,7 +936,10 @@ def preprocess_train(examples):
935936
del compute_vae_encodings_fn, compute_embeddings_fn, text_encoder_one, text_encoder_two
936937
del text_encoders, tokenizers, vae
937938
gc.collect()
938-
torch.cuda.empty_cache()
939+
if is_torch_npu_available():
940+
torch_npu.npu.empty_cache()
941+
else:
942+
torch.cuda.empty_cache()
939943

940944
def collate_fn(examples):
941945
model_input = torch.stack([torch.tensor(example["model_input"]) for example in examples])
@@ -1091,8 +1095,7 @@ def compute_time_ids(original_size, crops_coords_top_left):
10911095
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
10921096
target_size = (args.resolution, args.resolution)
10931097
add_time_ids = list(original_size + crops_coords_top_left + target_size)
1094-
add_time_ids = torch.tensor([add_time_ids])
1095-
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
1098+
add_time_ids = torch.tensor([add_time_ids], device=accelerator.device, dtype=weight_dtype)
10961099
return add_time_ids
10971100

10981101
add_time_ids = torch.cat(
@@ -1261,7 +1264,10 @@ def compute_time_ids(original_size, crops_coords_top_left):
12611264
)
12621265

12631266
del pipeline
1264-
torch.cuda.empty_cache()
1267+
if is_torch_npu_available():
1268+
torch_npu.npu.empty_cache()
1269+
else:
1270+
torch.cuda.empty_cache()
12651271

12661272
if args.use_ema:
12671273
# Switch back to the original UNet parameters.

0 commit comments

Comments
 (0)