
Commit 4b6e145

Changes (#53)
* fix bug in cutmix (again)
* support output_stride in BiFPN
* fixed cutmix bug connected to `prev_data` shape
1 parent: 0a7c4da

File tree

pytorch_tools/__init__.py
pytorch_tools/fit_wrapper/callbacks.py
pytorch_tools/modules/bifpn.py
pytorch_tools/segmentation_models/segm_fpn.py

4 files changed: +19 lines, -14 lines

pytorch_tools/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,4 +1,4 @@
-__version__ = "0.1.0"
+__version__ = "0.1.1"
 
 from . import fit_wrapper
 from . import losses

pytorch_tools/fit_wrapper/callbacks.py

Lines changed: 10 additions & 9 deletions

@@ -526,11 +526,11 @@ def mixup(self, data, target):
         target_one_hot = target
         if not self.state.is_train or np.random.rand() > self.prob:
             return data, target_one_hot
-        prev_data, prev_target = data, target_one_hot if self.prev_input is None else self.prev_input
+        prev_data, prev_target = (data, target_one_hot) if self.prev_input is None else self.prev_input
+        self.prev_input = data, target_one_hot
         c = self.tb.sample()
         md = c * data + (1 - c) * prev_data
         mt = c * target_one_hot + (1 - c) * prev_target
-        self.prev_input = data, target_one_hot
         return md, mt


@@ -569,16 +569,17 @@ def cutmix(self, data, target):
         target_one_hot = target
         if not self.state.is_train or np.random.rand() > self.prob:
             return data, target_one_hot
-        prev_data, prev_target = data, target_one_hot if self.prev_input is None else self.prev_input
-        _, _, H, W = data.size()
+        prev_data, prev_target = (data, target_one_hot) if self.prev_input is None else self.prev_input
+        self.prev_input = data, target_one_hot
+        # prev_data shape can be different from current. so need to take min
+        H, W = min(data.size(2), prev_data.size(2)), min(data.size(3), prev_data.size(3))
         lam = self.tb.sample()
         lam = min([lam, 1 - lam])
         bbh1, bbw1, bbh2, bbw2 = self.rand_bbox(H, W, lam)
         # real lambda may be diffrent from sampled. adjust for it
         lam = (bbh2 - bbh1) * (bbw2 - bbw1) / (H * W)
-        data[:, bbh1:bbh2, bbw1:bbw2] = prev_data[:, bbh1:bbh2, bbw1:bbw2]
+        data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[:, :, bbh1:bbh2, bbw1:bbw2]
         mixed_target = (1 - lam) * target_one_hot + lam * prev_target
-        self.prev_input = data, target_one_hot
         return data, mixed_target

     @staticmethod
@@ -607,12 +608,12 @@ def __init__(self, alpha=1.0, prob=0.5):
     def cutmix(self, data, target):
         if not self.state.is_train or np.random.rand() > self.prob:
             return data, target
-        prev_data, prev_target = data, target if self.prev_input is None else self.prev_input
-        _, _, H, W = data.size()
+        prev_data, prev_target = (data, target) if self.prev_input is None else self.prev_input
+        self.prev_input = data, target
+        H, W = min(data.size(2), prev_data.size(2)), min(data.size(3), prev_data.size(3))
         lam = self.tb.sample()
         lam = min([lam, 1 - lam])
         bbh1, bbw1, bbh2, bbw2 = self.rand_bbox(H, W, lam)
         data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[:, :, bbh1:bbh2, bbw1:bbw2]
         target[:, :, bbh1:bbh2, bbw1:bbw2] = prev_target[:, :, bbh1:bbh2, bbw1:bbw2]
-        self.prev_input = data, target
         return data, target
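
The core of the mixup/cutmix fix is Python operator precedence: a conditional expression binds tighter than the tuple comma, so the old unparenthesized assignment always set `prev_data` to the current batch and, once a previous batch existed, set `prev_target` to the whole stored tuple. A minimal standalone sketch of the difference (plain strings stand in for the tensors; this is illustrative, not the library's code):

# Stand-ins for the current batch and the stored (data, target) pair.
data, target_one_hot = "new_data", "new_target"
prev_input = ("old_data", "old_target")

# Old form: parsed as `data, (target_one_hot if prev_input is None else prev_input)`.
prev_data, prev_target = data, target_one_hot if prev_input is None else prev_input
print(prev_data, prev_target)  # new_data ('old_data', 'old_target')

# Fixed form: the conditional picks one of two complete (data, target) pairs.
prev_data, prev_target = (data, target_one_hot) if prev_input is None else prev_input
print(prev_data, prev_target)  # old_data old_target

The other two changes in this file follow the same diff: the extra `:` pastes the patch over the spatial dimensions of an NCHW tensor instead of over channels, and taking the min of the two batches' heights and widths keeps the sampled box inside both tensors when consecutive batches differ in resolution (the `prev_data` shape bug from the commit message).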

pytorch_tools/modules/bifpn.py

Lines changed: 7 additions & 4 deletions

@@ -41,13 +41,15 @@ class BiFPNLayer(nn.Module):
         p_out: features processed by 1 layer of BiFPN
     """

-    def __init__(self, channels=64, upsample_mode="nearest", **bn_args):
+    def __init__(self, channels=64, output_stride=32, upsample_mode="nearest", **bn_args):
         super(BiFPNLayer, self).__init__()

         self.up = nn.Upsample(scale_factor=2, mode=upsample_mode)
+        self.first_up = self.up if output_stride == 32 else nn.Identity()
+        last_stride = 2 if output_stride == 32 else 1
         self.down_p2 = DepthwiseSeparableConv(channels, channels, stride=2, **bn_args)
         self.down_p3 = DepthwiseSeparableConv(channels, channels, stride=2, **bn_args)
-        self.down_p4 = DepthwiseSeparableConv(channels, channels, stride=2, **bn_args)
+        self.down_p4 = DepthwiseSeparableConv(channels, channels, stride=last_stride, **bn_args)

         ## TODO (jamil) 11.02.2020 Rewrite this using list comprehensions
         self.fuse_p4_td = FastNormalizedFusion(in_nodes=2)
@@ -75,7 +77,7 @@ def forward(self, features):
         p5_inp, p4_inp, p3_inp, p2_inp, p1_inp = features

         # Top-down pathway
-        p4_td = self.p4_td(self.fuse_p4_td(p4_inp, self.up(p5_inp)))
+        p4_td = self.p4_td(self.fuse_p4_td(p4_inp, self.first_up(p5_inp)))
         p3_td = self.p3_td(self.fuse_p3_td(p3_inp, self.up(p4_td)))
         p2_out = self.p2_td(self.fuse_p2_td(p2_inp, self.up(p3_td)))

@@ -134,6 +136,7 @@ def __init__(
         encoder_channels,
         pyramid_channels=64,
         num_layers=1,
+        output_stride=32,
         **bn_args,
     ):
         super(BiFPN, self).__init__()
@@ -142,7 +145,7 @@ def __init__(

         bifpns = []
         for _ in range(num_layers):
-            bifpns.append(BiFPNLayer(pyramid_channels, **bn_args))
+            bifpns.append(BiFPNLayer(pyramid_channels, output_stride, **bn_args))
         self.bifpn = nn.Sequential(*bifpns)

     def forward(self, features):
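
The `output_stride` switch encodes a shape argument: with a dilated encoder (`output_stride=16`) the deepest map p5 comes out at the same resolution as p4, so the first top-down upsample must become a no-op and the last bottom-up downsample must keep stride 1, otherwise the fused tensors would not match. A small sketch of that reasoning with dummy tensors (illustrative sizes, not the library's modules):

import torch
import torch.nn as nn

# Hypothetical encoder outputs for a 256x256 input with 64 pyramid channels.
p4 = torch.randn(1, 64, 16, 16)        # stride-16 level
p5_os32 = torch.randn(1, 64, 8, 8)     # output_stride=32: p5 is half the size of p4
p5_os16 = torch.randn(1, 64, 16, 16)   # output_stride=16 (dilated): p5 matches p4

up = nn.Upsample(scale_factor=2, mode="nearest")

# output_stride == 32: p5 needs upsampling before fusion with p4.
assert up(p5_os32).shape[-2:] == p4.shape[-2:]

# output_stride == 16: p5 already matches p4, so `first_up` is nn.Identity()
# and the p4 -> p5 downsample conv runs with stride 1 instead of 2.
assert nn.Identity()(p5_os16).shape[-2:] == p4.shape[-2:]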

pytorch_tools/segmentation_models/segm_fpn.py

Lines changed: 1 addition & 0 deletions

@@ -97,6 +97,7 @@ def __init__(
             self.encoder.out_shapes,
             pyramid_channels=pyramid_channels,
             num_layers=num_fpn_layers,
+            output_stride=output_stride,
             **bn_args,
         )
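
The one-line change just forwards the segmentation model's `output_stride` into the BiFPN constructor shown above, so the decoder agrees with the (possibly dilated) encoder about the deepest level's resolution. A hedged sketch of that forwarding pattern; the class name `ToySegmModel`, the encoder argument, and the import path are illustrative assumptions, not this library's confirmed API:

import torch.nn as nn

# Import path assumed from the file layout above; not verified against the package exports.
from pytorch_tools.modules.bifpn import BiFPN


class ToySegmModel(nn.Module):
    """Toy wrapper showing how output_stride is threaded from the model into the decoder."""

    def __init__(self, encoder, pyramid_channels=64, num_fpn_layers=1, output_stride=32):
        super().__init__()
        self.encoder = encoder  # assumed to expose `out_shapes`, as in the diff above
        self.bifpn = BiFPN(
            self.encoder.out_shapes,
            pyramid_channels=pyramid_channels,
            num_layers=num_fpn_layers,
            output_stride=output_stride,  # the kwarg this commit adds
        )

    def forward(self, x):
        # The encoder is assumed to return the feature tuple BiFPN expects.
        return self.bifpn(self.encoder(x))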
