@@ -526,11 +526,11 @@ def mixup(self, data, target):
526
526
target_one_hot = target
527
527
if not self .state .is_train or np .random .rand () > self .prob :
528
528
return data , target_one_hot
529
- prev_data , prev_target = data , target_one_hot if self .prev_input is None else self .prev_input
529
+ prev_data , prev_target = (data , target_one_hot ) if self .prev_input is None else self .prev_input
530
+ self .prev_input = data , target_one_hot
530
531
c = self .tb .sample ()
531
532
md = c * data + (1 - c ) * prev_data
532
533
mt = c * target_one_hot + (1 - c ) * prev_target
533
- self .prev_input = data , target_one_hot
534
534
return md , mt
535
535
536
536
@@ -569,16 +569,17 @@ def cutmix(self, data, target):
569
569
target_one_hot = target
570
570
if not self .state .is_train or np .random .rand () > self .prob :
571
571
return data , target_one_hot
572
- prev_data , prev_target = data , target_one_hot if self .prev_input is None else self .prev_input
573
- _ , _ , H , W = data .size ()
572
+ prev_data , prev_target = (data , target_one_hot ) if self .prev_input is None else self .prev_input
573
+ self .prev_input = data , target_one_hot
574
+ # prev_data's shape can differ from the current batch's, so take the min of each of H and W
575
+ H , W = min (data .size (2 ), prev_data .size (2 )), min (data .size (3 ), prev_data .size (3 ))
574
576
lam = self .tb .sample ()
575
577
lam = min ([lam , 1 - lam ])
576
578
bbh1 , bbw1 , bbh2 , bbw2 = self .rand_bbox (H , W , lam )
577
579
# real lambda may be different from sampled. adjust for it
578
580
lam = (bbh2 - bbh1 ) * (bbw2 - bbw1 ) / (H * W )
579
- data [:, bbh1 :bbh2 , bbw1 :bbw2 ] = prev_data [:, bbh1 :bbh2 , bbw1 :bbw2 ]
581
+ data [:, :, bbh1 :bbh2 , bbw1 :bbw2 ] = prev_data [:, :, bbh1 :bbh2 , bbw1 :bbw2 ]
580
582
mixed_target = (1 - lam ) * target_one_hot + lam * prev_target
581
- self .prev_input = data , target_one_hot
582
583
return data , mixed_target
583
584
584
585
@staticmethod
@@ -607,12 +608,12 @@ def __init__(self, alpha=1.0, prob=0.5):
607
608
def cutmix(self, data, target):
    """Paste one random box from the previously seen batch into this batch.

    The same box is copied into both ``data`` and ``target``, which are
    indexed here as 4-D tensors (``[:, :, h, w]``). Both are modified in
    place and returned.

    The call is a no-op (inputs returned untouched, ``self.prev_input``
    not updated) when ``self.state.is_train`` is false or when a uniform
    draw from ``np.random.rand()`` exceeds ``self.prob``.

    Args:
        data: 4-D tensor; dimensions 2 and 3 are treated as height/width.
        target: tensor sliced with the same 4-D index pattern as ``data``.

    Returns:
        The ``(data, target)`` pair, with the sampled box overwritten by the
        corresponding region of the previous batch.
    """
    if not self.state.is_train or np.random.rand() > self.prob:
        return data, target

    # On the very first call there is no history, so the batch is mixed with
    # itself (which leaves it unchanged).
    if self.prev_input is None:
        prev_data, prev_target = data, target
    else:
        prev_data, prev_target = self.prev_input

    # Store copies: the slice assignments below write into ``data`` and
    # ``target`` in place, so keeping references would make the remembered
    # "previous batch" already contain pasted regions. Costs one extra copy
    # of the batch.
    self.prev_input = data.clone(), target.clone()

    # The previous batch may differ in shape from the current one, so clamp
    # the batch, height and width extents to what both tensors can provide.
    batch = min(data.size(0), prev_data.size(0))
    H = min(data.size(2), prev_data.size(2))
    W = min(data.size(3), prev_data.size(3))

    lam = self.tb.sample()
    lam = min([lam, 1 - lam])
    bbh1, bbw1, bbh2, bbw2 = self.rand_bbox(H, W, lam)

    data[:batch, :, bbh1:bbh2, bbw1:bbw2] = prev_data[:batch, :, bbh1:bbh2, bbw1:bbw2]
    target[:batch, :, bbh1:bbh2, bbw1:bbw2] = prev_target[:batch, :, bbh1:bbh2, bbw1:bbw2]
    return data, target
0 commit comments