@@ -409,6 +409,7 @@ def __init__(
409
409
410
410
self .alpha_mask = alpha_mask
411
411
412
+
412
413
class DreamBoothSubset (BaseSubset ):
413
414
def __init__ (
414
415
self ,
@@ -417,13 +418,47 @@ def __init__(
417
418
class_tokens : Optional [str ],
418
419
caption_extension : str ,
419
420
cache_info : bool ,
420
- ** kwargs ,
421
+ num_repeats ,
422
+ shuffle_caption ,
423
+ caption_separator : str ,
424
+ keep_tokens ,
425
+ keep_tokens_separator ,
426
+ secondary_separator ,
427
+ enable_wildcard ,
428
+ color_aug ,
429
+ flip_aug ,
430
+ face_crop_aug_range ,
431
+ random_crop ,
432
+ caption_dropout_rate ,
433
+ caption_dropout_every_n_epochs ,
434
+ caption_tag_dropout_rate ,
435
+ caption_prefix ,
436
+ caption_suffix ,
437
+ token_warmup_min ,
438
+ token_warmup_step ,
421
439
) -> None :
422
440
assert image_dir is not None , "image_dir must be specified / image_dirは指定が必須です"
423
441
424
442
super ().__init__ (
425
443
image_dir ,
426
- ** kwargs ,
444
+ num_repeats ,
445
+ shuffle_caption ,
446
+ caption_separator ,
447
+ keep_tokens ,
448
+ keep_tokens_separator ,
449
+ secondary_separator ,
450
+ enable_wildcard ,
451
+ color_aug ,
452
+ flip_aug ,
453
+ face_crop_aug_range ,
454
+ random_crop ,
455
+ caption_dropout_rate ,
456
+ caption_dropout_every_n_epochs ,
457
+ caption_tag_dropout_rate ,
458
+ caption_prefix ,
459
+ caption_suffix ,
460
+ token_warmup_min ,
461
+ token_warmup_step ,
427
462
)
428
463
429
464
self .is_reg = is_reg
@@ -444,13 +479,47 @@ def __init__(
444
479
self ,
445
480
image_dir ,
446
481
metadata_file : str ,
447
- ** kwargs ,
482
+ num_repeats ,
483
+ shuffle_caption ,
484
+ caption_separator ,
485
+ keep_tokens ,
486
+ keep_tokens_separator ,
487
+ secondary_separator ,
488
+ enable_wildcard ,
489
+ color_aug ,
490
+ flip_aug ,
491
+ face_crop_aug_range ,
492
+ random_crop ,
493
+ caption_dropout_rate ,
494
+ caption_dropout_every_n_epochs ,
495
+ caption_tag_dropout_rate ,
496
+ caption_prefix ,
497
+ caption_suffix ,
498
+ token_warmup_min ,
499
+ token_warmup_step ,
448
500
) -> None :
449
501
assert metadata_file is not None , "metadata_file must be specified / metadata_fileは指定が必須です"
450
502
451
503
super ().__init__ (
452
504
image_dir ,
453
- ** kwargs ,
505
+ num_repeats ,
506
+ shuffle_caption ,
507
+ caption_separator ,
508
+ keep_tokens ,
509
+ keep_tokens_separator ,
510
+ secondary_separator ,
511
+ enable_wildcard ,
512
+ color_aug ,
513
+ flip_aug ,
514
+ face_crop_aug_range ,
515
+ random_crop ,
516
+ caption_dropout_rate ,
517
+ caption_dropout_every_n_epochs ,
518
+ caption_tag_dropout_rate ,
519
+ caption_prefix ,
520
+ caption_suffix ,
521
+ token_warmup_min ,
522
+ token_warmup_step ,
454
523
)
455
524
456
525
self .metadata_file = metadata_file
@@ -468,13 +537,47 @@ def __init__(
468
537
conditioning_data_dir : str ,
469
538
caption_extension : str ,
470
539
cache_info : bool ,
471
- ** kwargs ,
540
+ num_repeats ,
541
+ shuffle_caption ,
542
+ caption_separator ,
543
+ keep_tokens ,
544
+ keep_tokens_separator ,
545
+ secondary_separator ,
546
+ enable_wildcard ,
547
+ color_aug ,
548
+ flip_aug ,
549
+ face_crop_aug_range ,
550
+ random_crop ,
551
+ caption_dropout_rate ,
552
+ caption_dropout_every_n_epochs ,
553
+ caption_tag_dropout_rate ,
554
+ caption_prefix ,
555
+ caption_suffix ,
556
+ token_warmup_min ,
557
+ token_warmup_step ,
472
558
) -> None :
473
559
assert image_dir is not None , "image_dir must be specified / image_dirは指定が必須です"
474
560
475
561
super ().__init__ (
476
562
image_dir ,
477
- ** kwargs ,
563
+ num_repeats ,
564
+ shuffle_caption ,
565
+ caption_separator ,
566
+ keep_tokens ,
567
+ keep_tokens_separator ,
568
+ secondary_separator ,
569
+ enable_wildcard ,
570
+ color_aug ,
571
+ flip_aug ,
572
+ face_crop_aug_range ,
573
+ random_crop ,
574
+ caption_dropout_rate ,
575
+ caption_dropout_every_n_epochs ,
576
+ caption_tag_dropout_rate ,
577
+ caption_prefix ,
578
+ caption_suffix ,
579
+ token_warmup_min ,
580
+ token_warmup_step ,
478
581
)
479
582
480
583
self .conditioning_data_dir = conditioning_data_dir
@@ -1100,10 +1203,12 @@ def __getitem__(self, index):
1100
1203
else :
1101
1204
latents = image_info .latents_flipped
1102
1205
alpha_mask = image_info .alpha_mask_flipped
1103
-
1206
+
1104
1207
image = None
1105
1208
elif image_info .latents_npz is not None : # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
1106
- latents , original_size , crop_ltrb , flipped_latents , alpha_mask , flipped_alpha_mask = load_latents_from_disk (image_info .latents_npz )
1209
+ latents , original_size , crop_ltrb , flipped_latents , alpha_mask , flipped_alpha_mask = load_latents_from_disk (
1210
+ image_info .latents_npz
1211
+ )
1107
1212
if flipped :
1108
1213
latents = flipped_latents
1109
1214
alpha_mask = flipped_alpha_mask
@@ -1116,7 +1221,9 @@ def __getitem__(self, index):
1116
1221
image = None
1117
1222
else :
1118
1223
# 画像を読み込み、必要ならcropする
1119
- img , face_cx , face_cy , face_w , face_h = self .load_image_with_face_info (subset , image_info .absolute_path , subset .alpha_mask )
1224
+ img , face_cx , face_cy , face_w , face_h = self .load_image_with_face_info (
1225
+ subset , image_info .absolute_path , subset .alpha_mask
1226
+ )
1120
1227
im_h , im_w = img .shape [0 :2 ]
1121
1228
1122
1229
if self .enable_bucket :
@@ -1157,7 +1264,7 @@ def __getitem__(self, index):
1157
1264
if img .shape [2 ] == 4 :
1158
1265
alpha_mask = img [:, :, 3 ] # [W,H]
1159
1266
else :
1160
- alpha_mask = np .full ((im_w , im_h ), 255 , dtype = np .uint8 ) # [W,H]
1267
+ alpha_mask = np .full ((im_w , im_h ), 255 , dtype = np .uint8 ) # [W,H]
1161
1268
alpha_mask = transforms .ToTensor ()(alpha_mask )
1162
1269
else :
1163
1270
alpha_mask = None
@@ -2070,7 +2177,14 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
2070
2177
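# Returns (latents, original_size, crop_ltrb, flipped_latents, alpha_mask, flipped_alpha_mask); original_size is (width, height). Raises ValueError if the npz has no "latents" entry (old format).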
2071
2178
def load_latents_from_disk (
2072
2179
npz_path ,
2073
- ) -> Tuple [Optional [torch .Tensor ], Optional [List [int ]], Optional [List [int ]], Optional [torch .Tensor ], Optional [torch .Tensor ], Optional [torch .Tensor ]]:
2180
+ ) -> Tuple [
2181
+ Optional [torch .Tensor ],
2182
+ Optional [List [int ]],
2183
+ Optional [List [int ]],
2184
+ Optional [torch .Tensor ],
2185
+ Optional [torch .Tensor ],
2186
+ Optional [torch .Tensor ],
2187
+ ]:
2074
2188
npz = np .load (npz_path )
2075
2189
if "latents" not in npz :
2076
2190
raise ValueError (f"error: npz is old format. please re-generate { npz_path } " )
@@ -2084,7 +2198,9 @@ def load_latents_from_disk(
2084
2198
return latents , original_size , crop_ltrb , flipped_latents , alpha_mask , flipped_alpha_mask
2085
2199
2086
2200
2087
- def save_latents_to_disk (npz_path , latents_tensor , original_size , crop_ltrb , flipped_latents_tensor = None , alpha_mask = None , flipped_alpha_mask = None ):
2201
+ def save_latents_to_disk (
2202
+ npz_path , latents_tensor , original_size , crop_ltrb , flipped_latents_tensor = None , alpha_mask = None , flipped_alpha_mask = None
2203
+ ):
2088
2204
kwargs = {}
2089
2205
if flipped_latents_tensor is not None :
2090
2206
kwargs ["latents_flipped" ] = flipped_latents_tensor .float ().cpu ().numpy ()
@@ -2344,10 +2460,10 @@ def cache_batch_latents(
2344
2460
image , original_size , crop_ltrb = trim_and_resize_if_required (random_crop , image , info .bucket_reso , info .resized_size )
2345
2461
if info .use_alpha_mask :
2346
2462
if image .shape [2 ] == 4 :
2347
- alpha_mask = image [:, :, 3 ] # [W,H]
2463
+ alpha_mask = image [:, :, 3 ] # [W,H]
2348
2464
image = image [:, :, :3 ]
2349
2465
else :
2350
- alpha_mask = np .full_like (image [:, :, 0 ], 255 , dtype = np .uint8 ) # [W,H]
2466
+ alpha_mask = np .full_like (image [:, :, 0 ], 255 , dtype = np .uint8 ) # [W,H]
2351
2467
alpha_masks .append (transforms .ToTensor ()(alpha_mask ))
2352
2468
image = IMAGE_TRANSFORMS (image )
2353
2469
images .append (image )
@@ -2377,13 +2493,23 @@ def cache_batch_latents(
2377
2493
flipped_latents = [None ] * len (latents )
2378
2494
flipped_alpha_masks = [None ] * len (image_infos )
2379
2495
2380
- for info , latent , flipped_latent , alpha_mask , flipped_alpha_mask in zip (image_infos , latents , flipped_latents , alpha_masks , flipped_alpha_masks ):
2496
+ for info , latent , flipped_latent , alpha_mask , flipped_alpha_mask in zip (
2497
+ image_infos , latents , flipped_latents , alpha_masks , flipped_alpha_masks
2498
+ ):
2381
2499
# check NaN
2382
2500
if torch .isnan (latents ).any () or (flipped_latent is not None and torch .isnan (flipped_latent ).any ()):
2383
2501
raise RuntimeError (f"NaN detected in latents: { info .absolute_path } " )
2384
2502
2385
2503
if cache_to_disk :
2386
- save_latents_to_disk (info .latents_npz , latent , info .latents_original_size , info .latents_crop_ltrb , flipped_latent , alpha_mask , flipped_alpha_mask )
2504
+ save_latents_to_disk (
2505
+ info .latents_npz ,
2506
+ latent ,
2507
+ info .latents_original_size ,
2508
+ info .latents_crop_ltrb ,
2509
+ flipped_latent ,
2510
+ alpha_mask ,
2511
+ flipped_alpha_mask ,
2512
+ )
2387
2513
else :
2388
2514
info .latents = latent
2389
2515
if flip_aug :
0 commit comments