Commit f2dd43e

revert kwargs to explicit declaration
1 parent db67529 commit f2dd43e

File tree

1 file changed: +142 −16 lines changed


library/train_util.py

Lines changed: 142 additions & 16 deletions
@@ -409,6 +409,7 @@ def __init__(
 
         self.alpha_mask = alpha_mask
 
+
 class DreamBoothSubset(BaseSubset):
     def __init__(
         self,
@@ -417,13 +418,47 @@ def __init__(
         class_tokens: Optional[str],
         caption_extension: str,
         cache_info: bool,
-        **kwargs,
+        num_repeats,
+        shuffle_caption,
+        caption_separator: str,
+        keep_tokens,
+        keep_tokens_separator,
+        secondary_separator,
+        enable_wildcard,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
         super().__init__(
             image_dir,
-            **kwargs,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            secondary_separator,
+            enable_wildcard,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.is_reg = is_reg
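
The hunk above is the core of the revert: the single **kwargs catch-all in DreamBoothSubset is replaced by the full parameter list, both in the signature and in the super().__init__() call. A minimal sketch of the two styles with a toy base class (illustrative only, not the actual subset hierarchy):

# Before the revert: the subclass accepts anything and forwards it blindly.
class Base:
    def __init__(self, image_dir, num_repeats, shuffle_caption):
        self.image_dir = image_dir
        self.num_repeats = num_repeats
        self.shuffle_caption = shuffle_caption

class ChildKwargs(Base):
    def __init__(self, image_dir, **kwargs):
        super().__init__(image_dir, **kwargs)

# After the revert: every parameter is spelled out, so the signature
# documents itself and a misspelled keyword fails at the call site.
class ChildExplicit(Base):
    def __init__(self, image_dir, num_repeats, shuffle_caption):
        super().__init__(image_dir, num_repeats, shuffle_caption)

The cost of the explicit style is the long, repetitive argument lists in this and the following hunks; the benefit is that each subset class now declares exactly what it accepts.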
@@ -444,13 +479,47 @@ def __init__(
         self,
         image_dir,
         metadata_file: str,
-        **kwargs,
+        num_repeats,
+        shuffle_caption,
+        caption_separator,
+        keep_tokens,
+        keep_tokens_separator,
+        secondary_separator,
+        enable_wildcard,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert metadata_file is not None, "metadata_file must be specified / metadata_fileは指定が必須です"
 
         super().__init__(
             image_dir,
-            **kwargs,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            secondary_separator,
+            enable_wildcard,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.metadata_file = metadata_file
@@ -468,13 +537,47 @@ def __init__(
         conditioning_data_dir: str,
         caption_extension: str,
         cache_info: bool,
-        **kwargs,
+        num_repeats,
+        shuffle_caption,
+        caption_separator,
+        keep_tokens,
+        keep_tokens_separator,
+        secondary_separator,
+        enable_wildcard,
+        color_aug,
+        flip_aug,
+        face_crop_aug_range,
+        random_crop,
+        caption_dropout_rate,
+        caption_dropout_every_n_epochs,
+        caption_tag_dropout_rate,
+        caption_prefix,
+        caption_suffix,
+        token_warmup_min,
+        token_warmup_step,
     ) -> None:
         assert image_dir is not None, "image_dir must be specified / image_dirは指定が必須です"
 
         super().__init__(
             image_dir,
-            **kwargs,
+            num_repeats,
+            shuffle_caption,
+            caption_separator,
+            keep_tokens,
+            keep_tokens_separator,
+            secondary_separator,
+            enable_wildcard,
+            color_aug,
+            flip_aug,
+            face_crop_aug_range,
+            random_crop,
+            caption_dropout_rate,
+            caption_dropout_every_n_epochs,
+            caption_tag_dropout_rate,
+            caption_prefix,
+            caption_suffix,
+            token_warmup_min,
+            token_warmup_step,
         )
 
         self.conditioning_data_dir = conditioning_data_dir
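
The same revert is applied to all three subset classes (the DreamBooth, fine-tuning, and conditioning-data subsets). One practical consequence is that the accepted parameters are now visible to introspection and editor completion, which **kwargs hides. A small illustration with hypothetical stand-in functions, not the library API:

import inspect

def make_subset_kwargs(image_dir, **kwargs):  # opaque to tooling
    pass

def make_subset_explicit(image_dir, num_repeats, shuffle_caption):  # self-documenting
    pass

print(inspect.signature(make_subset_kwargs))    # (image_dir, **kwargs)
print(inspect.signature(make_subset_explicit))  # (image_dir, num_repeats, shuffle_caption)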
@@ -1100,10 +1203,12 @@ def __getitem__(self, index):
                 else:
                     latents = image_info.latents_flipped
                     alpha_mask = image_info.alpha_mask_flipped
-
+
                 image = None
             elif image_info.latents_npz is not None:  # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
-                latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(image_info.latents_npz)
+                latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask = load_latents_from_disk(
+                    image_info.latents_npz
+                )
                 if flipped:
                     latents = flipped_latents
                     alpha_mask = flipped_alpha_mask
@@ -1116,7 +1221,9 @@ def __getitem__(self, index):
                 image = None
             else:
                 # 画像を読み込み、必要ならcropする
-                img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(subset, image_info.absolute_path, subset.alpha_mask)
+                img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(
+                    subset, image_info.absolute_path, subset.alpha_mask
+                )
                 im_h, im_w = img.shape[0:2]
 
                 if self.enable_bucket:
@@ -1157,7 +1264,7 @@ def __getitem__(self, index):
                 if img.shape[2] == 4:
                     alpha_mask = img[:, :, 3]  # [W,H]
                 else:
-                    alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8) # [W,H]
+                    alpha_mask = np.full((im_w, im_h), 255, dtype=np.uint8)  # [W,H]
                 alpha_mask = transforms.ToTensor()(alpha_mask)
             else:
                 alpha_mask = None
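
This branch keeps the image's own alpha channel when the loaded array has four channels and otherwise synthesizes a fully opaque mask before converting it to a tensor. A self-contained sketch of that logic with a toy RGBA array (shapes are illustrative; the real code takes im_w and im_h from the loaded image):

import numpy as np
from torchvision import transforms

img = np.zeros((4, 6, 4), dtype=np.uint8)  # toy (H, W, RGBA) image
img[..., 3] = 128  # semi-transparent alpha channel

if img.shape[2] == 4:
    alpha_mask = img[:, :, 3]  # use the real alpha channel
else:
    alpha_mask = np.full(img.shape[:2], 255, dtype=np.uint8)  # fully opaque fallback

mask_tensor = transforms.ToTensor()(alpha_mask)  # uint8 scaled to [0, 1], shape (1, H, W)
print(mask_tensor.shape, float(mask_tensor.max()))  # torch.Size([1, 4, 6]) ~0.502

Note that the diff's fallback builds np.full((im_w, im_h), ...) while the comment reads [W,H]; the sketch above uses img.shape[:2], i.e. (H, W), the usual NumPy convention.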
@@ -2070,7 +2177,14 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
 # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
 def load_latents_from_disk(
     npz_path,
-) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
+) -> Tuple[
+    Optional[torch.Tensor],
+    Optional[List[int]],
+    Optional[List[int]],
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+    Optional[torch.Tensor],
+]:
     npz = np.load(npz_path)
     if "latents" not in npz:
         raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
@@ -2084,7 +2198,9 @@ def load_latents_from_disk(
     return latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask
 
 
-def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None):
+def save_latents_to_disk(
+    npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, alpha_mask=None, flipped_alpha_mask=None
+):
     kwargs = {}
     if flipped_latents_tensor is not None:
         kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
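
From the visible context, save_latents_to_disk collects only the optional arrays that are actually present into a kwargs dict before writing, so absent arrays never become keys in the archive. A minimal sketch of that conditional-save pattern with a simplified signature (not the full function):

import numpy as np

def save_npz_sketch(npz_path, latents, flipped_latents=None, alpha_mask=None):
    kwargs = {}
    if flipped_latents is not None:
        kwargs["latents_flipped"] = flipped_latents
    if alpha_mask is not None:
        kwargs["alpha_mask"] = alpha_mask
    # only the provided arrays end up as keys in the .npz archive
    np.savez(npz_path, latents=latents, **kwargs)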
@@ -2344,10 +2460,10 @@ def cache_batch_latents(
         image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
         if info.use_alpha_mask:
             if image.shape[2] == 4:
-                alpha_mask = image[:, :, 3] # [W,H]
+                alpha_mask = image[:, :, 3]  # [W,H]
                 image = image[:, :, :3]
             else:
-                alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8) # [W,H]
+                alpha_mask = np.full_like(image[:, :, 0], 255, dtype=np.uint8)  # [W,H]
             alpha_masks.append(transforms.ToTensor()(alpha_mask))
         image = IMAGE_TRANSFORMS(image)
         images.append(image)
@@ -2377,13 +2493,23 @@ def cache_batch_latents(
     flipped_latents = [None] * len(latents)
     flipped_alpha_masks = [None] * len(image_infos)
 
-    for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks):
+    for info, latent, flipped_latent, alpha_mask, flipped_alpha_mask in zip(
+        image_infos, latents, flipped_latents, alpha_masks, flipped_alpha_masks
+    ):
         # check NaN
         if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
             raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
 
         if cache_to_disk:
-            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, alpha_mask, flipped_alpha_mask)
+            save_latents_to_disk(
+                info.latents_npz,
+                latent,
+                info.latents_original_size,
+                info.latents_crop_ltrb,
+                flipped_latent,
+                alpha_mask,
+                flipped_alpha_mask,
+            )
         else:
             info.latents = latent
         if flip_aug:
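
Together with load_latents_from_disk above, the cache_to_disk path amounts to a plain .npz round trip for the cached latents. A hedged end-to-end sketch with a hypothetical file name and toy shapes:

import numpy as np

npz_path = "sample_latents.npz"  # hypothetical path, for illustration only
np.savez(
    npz_path,
    latents=np.zeros((4, 64, 64), dtype=np.float32),
    original_size=[512, 512],
    crop_ltrb=[0, 0, 0, 0],
)

npz = np.load(npz_path)
if "latents" not in npz:  # the same old-format guard as load_latents_from_disk
    raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
print(npz["latents"].shape, npz["original_size"].tolist())  # (4, 64, 64) [512, 512]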
