Skip to content

Commit 218b5ac

Browse files
authored
Merge pull request #54 from bonlime/dev
Merge last month commits
2 parents 4b6e145 + 1e86b17 commit 218b5ac

35 files changed

+1165
-534
lines changed

pytorch_tools/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.1.1"
1+
__version__ = "0.1.2"
22

33
from . import fit_wrapper
44
from . import losses

pytorch_tools/detection_models/__init__.py

Whitespace-only changes.
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from pytorch_tools.modules.fpn import FPN
5+
# from pytorch_tools.modules.bifpn import BiFPN
6+
from pytorch_tools.modules import bn_from_name
7+
# from pytorch_tools.modules.residual import conv1x1
8+
from pytorch_tools.modules.residual import conv3x3
9+
# from pytorch_tools.modules.decoder import SegmentationUpsample
10+
# from pytorch_tools.utils.misc import initialize
11+
from pytorch_tools.segmentation_models.encoders import get_encoder
12+
13+
14+
class RetinaNet(nn.Module):
    """RetinaNet-style detection model: encoder -> FPN -> shared class / box heads.

    Args:
        encoder_name (str): name of the encoder, passed to `get_encoder`
        encoder_weights (str): forwarded to `get_encoder` as `encoder_weights`
        pyramid_channels (int): number of channels used by the FPN, by the two
            extra pyramid levels (p6, p7) and by every conv inside the heads
        num_classes (int): number of class scores predicted per anchor
        norm_layer (str): name of the normalization layer, resolved with `bn_from_name`
        norm_act (str): activation name passed to the encoder and to the
            normalization layers inside the heads
        **encoder_params: extra keyword arguments forwarded to `get_encoder`

    `forward` returns a tuple `(class_outputs, box_outputs)`. Each element is a
    list with one tensor per pyramid level, ordered `[p7, p6, p5, p4, p3]`.
    """

    def __init__(
        self,
        encoder_name="resnet34",
        encoder_weights="imagenet",
        pyramid_channels=256,
        num_classes=80,
        norm_layer="abn",
        norm_act="relu",
        **encoder_params,
    ):
        super().__init__()
        self.encoder = get_encoder(
            encoder_name,
            norm_layer=norm_layer,
            norm_act=norm_act,
            encoder_weights=encoder_weights,
            **encoder_params,
        )
        # from here on `norm_layer` is the layer constructor, not its name
        norm_layer = bn_from_name(norm_layer)
        # Extra pyramid levels are computed from the FPN output, so they must use
        # the same width as the FPN instead of a hard-coded 256.
        # NOTE(review): assumes FPN emits `pyramid_channels` channels per level - confirm
        self.pyramid6 = conv3x3(pyramid_channels, pyramid_channels, 2, bias=True)
        self.pyramid7 = conv3x3(pyramid_channels, pyramid_channels, 2, bias=True)
        # only the first three encoder outputs are fed to the FPN (see `forward`)
        self.fpn = FPN(
            self.encoder.out_shapes[:-2],
            pyramid_channels=pyramid_channels,
        )

        def make_head(out_size):
            """Build a head: 4 x (3x3 conv + norm) followed by a 3x3 conv with `out_size` channels."""
            layers = []
            for _ in range(4):
                # some implementations don't use BN here (plain Conv + ReLU) but I think it's needed
                # TODO: test how it affects results
                layers += [
                    nn.Conv2d(pyramid_channels, pyramid_channels, 3, padding=1),
                    norm_layer(pyramid_channels, activation=norm_act),
                ]

            layers += [nn.Conv2d(pyramid_channels, out_size, 3, padding=1)]
            return nn.Sequential(*layers)

        self.ratios = [1.0, 2.0, 0.5]
        self.scales = [4 * 2 ** (i / 3) for i in range(3)]
        anchors = len(self.ratios) * len(self.scales)  # 9

        # the same head instance is applied to every pyramid level in `forward`
        self.cls_head = make_head(num_classes * anchors)
        self.box_head = make_head(4 * anchors)

    def forward(self, x):
        """Run the detector on `x` and return per-level class and box outputs."""
        # don't use p2 and p1
        p5, p4, p3, _, _ = self.encoder(x)
        # enhance features
        p5, p4, p3 = self.fpn([p5, p4, p3])
        # coarser FPN levels
        p6 = self.pyramid6(p5)
        p7 = self.pyramid7(F.relu(p6))
        features = [p7, p6, p5, p4, p3]
        # TODO: (18.03.20) TF implementation has additional BN here before class/box outputs
        class_outputs = [self.cls_head(f) for f in features]
        box_outputs = [self.box_head(f) for f in features]
        return class_outputs, box_outputs
72+
73+
74+

pytorch_tools/fit_wrapper/callbacks.py

Lines changed: 71 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -271,38 +271,49 @@ class CheckpointSaver(Callback):
271271
272272
Args:
273273
save_dir (str): path to folder where to save the model
274-
save_name (str): name of the saved model. can additionally
274+
save_name (str): name of the saved model. can additionally
275275
add epoch and metric to model save name
276+
monitor (str): quantity to monitor. Implicitly prefers validation metrics over train. One of:
277+
`loss` or name of any metric passed to the runner.
276278
mode (str): one of "min" or "max". Whether to decide to save based
277279
on minimizing or maximizing loss
278-
include_optimizer (bool): if True would also save `optimizers` state_dict.
280+
include_optimizer (bool): if True would also save `optimizers` state_dict.
279281
This increases checkpoint size 2x times.
282+
verbose (bool): If `True` reports each time new best is found
280283
"""
281284

282285
def __init__(
283-
self, save_dir, save_name="model_{ep}_{metric:.2f}.chpn", mode="min", include_optimizer=False
286+
self,
287+
save_dir,
288+
save_name="model_{ep}_{metric:.2f}.chpn",
289+
monitor="loss",
290+
mode="min",
291+
include_optimizer=False,
292+
verbose=True,
284293
):
285294
super().__init__()
286295
self.save_dir = save_dir
287296
self.save_name = save_name
288-
self.mode = ReduceMode(mode)
289-
self.best = float("inf") if self.mode == ReduceMode.MIN else -float("inf")
297+
self.monitor = monitor
298+
mode = ReduceMode(mode)
299+
if mode == ReduceMode.MIN:
300+
self.best = np.inf
301+
self.monitor_op = np.less
302+
elif mode == ReduceMode.MAX:
303+
self.best = -np.inf
304+
self.monitor_op = np.greater
290305
self.include_optimizer = include_optimizer
306+
self.verbose = verbose
291307

292308
def on_begin(self):
293309
os.makedirs(self.save_dir, exist_ok=True)
294310

295311
def on_epoch_end(self):
296-
# TODO zakirov(1.11.19) Add support for saving based on metric
297-
if self.state.val_loss is not None:
298-
current = self.state.val_loss.avg
299-
else:
300-
current = self.state.train_loss.avg
301-
if (self.mode == ReduceMode.MIN and current < self.best) or (
302-
self.mode == ReduceMode.MAX and current > self.best
303-
):
304-
ep = self.state.epoch
305-
# print(f"Epoch {ep}: best loss improved from {self.best:.4f} to {current:.4f}")
312+
current = self.get_monitor_value()
313+
if self.monitor_op(current, self.best):
314+
ep = self.state.epoch_log
315+
if self.verbose:
316+
print(f"Epoch {ep:2d}: best {self.monitor} improved from {self.best:.4f} to {current:.4f}")
306317
self.best = current
307318
save_name = os.path.join(self.save_dir, self.save_name.format(ep=ep, metric=current))
308319
self._save_checkpoint(save_name)
@@ -317,6 +328,18 @@ def _save_checkpoint(self, path):
317328
save_dict["optimizer"] = self.state.optimizer.state_dict()
318329
torch.save(save_dict, path)
319330

331+
def get_monitor_value(self):
    """Return the current average of the quantity named by `self.monitor`.

    `"loss"` is read from `self.state.loss_meter`; any other name is looked up
    among `self.state.metric_meters` by the meter's `name` attribute. If several
    meters share the name, the last one wins.

    Raises:
        ValueError: if no value was found for `self.monitor`
    """
    if self.monitor == "loss":
        found = self.state.loss_meter.avg
    else:
        matching = [meter.avg for meter in self.state.metric_meters if meter.name == self.monitor]
        found = matching[-1] if matching else None

    if found is not None:
        return found
    raise ValueError(f"CheckpointSaver can't find {self.monitor} value to monitor")
342+
320343

321344
class TensorBoard(Callback):
322345
"""
@@ -407,7 +430,7 @@ def on_batch_end(self):
407430

408431
def on_loader_end(self):
409432
super().on_loader_end()
410-
f = plot_confusion_matrix(self.cmap, self.class_names, show=False)
433+
f = plot_confusion_matrix(self.cmap, self.class_names, normalize=True, show=False)
411434
cm_img = render_figure_to_tensor(f)
412435
if self.state.is_train:
413436
self.train_cm_img = cm_img
@@ -527,10 +550,11 @@ def mixup(self, data, target):
527550
if not self.state.is_train or np.random.rand() > self.prob:
528551
return data, target_one_hot
529552
prev_data, prev_target = (data, target_one_hot) if self.prev_input is None else self.prev_input
530-
self.prev_input = data, target_one_hot
553+
self.prev_input = data.clone(), target_one_hot.clone()
554+
perm = torch.randperm(data.size(0)).cuda()
531555
c = self.tb.sample()
532-
md = c * data + (1 - c) * prev_data
533-
mt = c * target_one_hot + (1 - c) * prev_target
556+
md = c * data + (1 - c) * prev_data[perm]
557+
mt = c * target_one_hot + (1 - c) * prev_target[perm]
534558
return md, mt
535559

536560

@@ -570,16 +594,17 @@ def cutmix(self, data, target):
570594
if not self.state.is_train or np.random.rand() > self.prob:
571595
return data, target_one_hot
572596
prev_data, prev_target = (data, target_one_hot) if self.prev_input is None else self.prev_input
573-
self.prev_input = data, target_one_hot
597+
self.prev_input = data.clone(), target_one_hot.clone()
574598
# prev_data shape can be different from current. so need to take min
575599
H, W = min(data.size(2), prev_data.size(2)), min(data.size(3), prev_data.size(3))
600+
perm = torch.randperm(data.size(0)).cuda()
576601
lam = self.tb.sample()
577602
lam = min([lam, 1 - lam])
578603
bbh1, bbw1, bbh2, bbw2 = self.rand_bbox(H, W, lam)
579604
# real lambda may be different from sampled. adjust for it
580605
lam = (bbh2 - bbh1) * (bbw2 - bbw1) / (H * W)
581-
data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[:, :, bbh1:bbh2, bbw1:bbw2]
582-
mixed_target = (1 - lam) * target_one_hot + lam * prev_target
606+
data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[perm, :, bbh1:bbh2, bbw1:bbw2]
607+
mixed_target = (1 - lam) * target_one_hot + lam * prev_target[perm]
583608
return data, mixed_target
584609

585610
@staticmethod
@@ -609,11 +634,32 @@ def cutmix(self, data, target):
609634
if not self.state.is_train or np.random.rand() > self.prob:
610635
return data, target
611636
prev_data, prev_target = (data, target) if self.prev_input is None else self.prev_input
612-
self.prev_input = data, target
637+
self.prev_input = data.clone(), target.clone()
613638
H, W = min(data.size(2), prev_data.size(2)), min(data.size(3), prev_data.size(3))
639+
perm = torch.randperm(data.size(0)).cuda()
614640
lam = self.tb.sample()
615641
lam = min([lam, 1 - lam])
616642
bbh1, bbw1, bbh2, bbw2 = self.rand_bbox(H, W, lam)
617-
data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[:, :, bbh1:bbh2, bbw1:bbw2]
618-
target[:, :, bbh1:bbh2, bbw1:bbw2] = prev_target[:, :, bbh1:bbh2, bbw1:bbw2]
643+
data[:, :, bbh1:bbh2, bbw1:bbw2] = prev_data[perm, :, bbh1:bbh2, bbw1:bbw2]
644+
target[:, :, bbh1:bbh2, bbw1:bbw2] = prev_target[perm, :, bbh1:bbh2, bbw1:bbw2]
619645
return data, target
646+
647+
648+
class ScheduledDropout(Callback):
    def __init__(self, drop_rate=0.1, epochs=30, attr_name="dropout.p"):
        """
        Slowly changes dropout value for `attr_name` each epoch.
        Ref: https://arxiv.org/abs/1703.06229
        Args:
            drop_rate (float): max dropout rate
            epochs (int): number of epochs after which the max dropout rate is reached
            attr_name (str): dotted path, relative to the model, of the attribute
                that holds the dropout rate (e.g. `dropout.p` -> `model.dropout.p`)
        """
        super().__init__()
        self.drop_rate = drop_rate
        self.epochs = epochs
        self.attr_name = attr_name

    def on_epoch_end(self):
        """Set the scheduled dropout rate on the model at the end of each epoch."""
        # rate grows linearly with the epoch number and is capped at `drop_rate`
        current_rate = self.drop_rate * min(1, self.state.epoch / self.epochs)
        # `setattr(model, "dropout.p", rate)` would only create an attribute that is
        # literally named "dropout.p" on the model and leave `model.dropout.p`
        # untouched, so resolve the dotted path down to the owning object first.
        *parent_names, leaf_name = self.attr_name.split(".")
        owner = self.state.model
        for name in parent_names:
            owner = getattr(owner, name)
        setattr(owner, leaf_name, current_rate)

pytorch_tools/fit_wrapper/wrapper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ def __init__(self, model, optimizer, criterion, metrics=None, callbacks=ConsoleL
2222
super().__init__()
2323

2424
if not hasattr(amp._amp_state, "opt_properties"):
25-
model, optimizer = amp.initialize(model, optimizer, enabled=False)
25+
model_optimizer = amp.initialize(model, optimizer, enabled=False)
26+
model, optimizer = (model_optimizer, None) if optimizer is None else model_optimizer
2627

2728
self.state = RunnerState(model=model, optimizer=optimizer, criterion=criterion, metrics=metrics,)
2829
self.callbacks = Callbacks(callbacks)

pytorch_tools/losses/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
import torch.nn as nn
33

44
from .base import Loss
5-
from .focal import BinaryFocalLoss, FocalLoss
6-
from .dice_jaccard import DiceLoss, JaccardLoss
5+
from .focal import FocalLoss
6+
from .dice_jaccard import DiceLoss
7+
from .dice_jaccard import JaccardLoss
78
from .lovasz import LovaszLoss
89
from .wing_loss import WingLoss
910
from .vgg_loss import ContentLoss, StyleLoss
1011
from .smooth import CrossEntropyLoss
1112
from .hinge import BinaryHinge
1213

13-
from .functional import sigmoid_focal_loss
14+
from .functional import focal_loss_with_logits
1415
from .functional import soft_dice_score
1516
from .functional import soft_jaccard_score
1617
from .functional import wing_loss

pytorch_tools/losses/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,11 @@ class Mode(Enum):
88
MULTICLASS = "multiclass"
99
MULTILABEL = "multilabel"
1010

11-
11+
class Reduction(Enum):
    """Names of the supported reduction modes; members are looked up by their string value."""

    SUM = "sum"  # selected with "sum"
    MEAN = "mean"  # selected with "mean"
    NONE = "none"  # selected with "none"
15+
1216
class Loss(_Loss):
1317
"""Loss which supports addition and multiplication"""
1418

pytorch_tools/losses/dice_jaccard.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,22 @@ class DiceLoss(Loss):
77
"""
88
Implementation of Dice loss for image segmentation task.
99
It supports binary, multiclass and multilabel cases
10+
11+
Args:
12+
mode (str): Target mode {'binary', 'multiclass', 'multilabel'}
13+
'multilabel' - expects y_true of shape [N, C, H, W]
14+
'multiclass', 'binary' - expects y_true of shape [N, H, W]
15+
log_loss (bool): If True, loss computed as `-log(dice)`; otherwise `1 - dice`
16+
from_logits (bool): If True assumes input is raw logits
17+
eps (float): small epsilon for numerical stability
18+
Shape:
19+
y_pred: [N, C, H, W]
20+
y_true: [N, C, H, W] or [N, H, W] depending on mode
1021
"""
1122

1223
IOU_FUNCTION = soft_dice_score
1324

1425
def __init__(self, mode="binary", log_loss=False, from_logits=True, eps=1.):
15-
"""
16-
Args:
17-
mode (str): Target mode {'binary', 'multiclass', 'multilabel'}
18-
'multilabel' - expects y_true of shape [N, C, H, W]
19-
'multiclass', 'binary' - expects y_true of shape [N, H, W]
20-
log_loss (bool): If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard`
21-
from_logits (bool): If True assumes input is raw logits
22-
eps (float): small epsilon for numerical stability
23-
Shape:
24-
y_pred: [N, C, H, W]
25-
y_true: [N, C, H, W] or [N, H, W] depending on mode
26-
"""
27-
2826
super(DiceLoss, self).__init__()
2927
self.mode = Mode(mode) # raises an error if not valid
3028
self.log_loss = log_loss
@@ -34,9 +32,9 @@ def __init__(self, mode="binary", log_loss=False, from_logits=True, eps=1.):
3432
def forward(self, y_pred, y_true):
3533
if self.from_logits:
3634
# Apply activations to get [0..1] class probabilities
37-
if self.mode == Mode.BINARY:
35+
if self.mode == Mode.BINARY or self.mode == Mode.MULTILABEL:
3836
y_pred = y_pred.sigmoid()
39-
else:
37+
elif self.mode == Mode.MULTICLASS:
4038
y_pred = y_pred.softmax(dim=1)
4139

4240
bs = y_true.size(0)
@@ -74,6 +72,17 @@ class JaccardLoss(DiceLoss):
7472
"""
7573
Implementation of Jaccard loss for image segmentation task.
7674
It supports binary, multiclass and multilabel cases
75+
76+
Args:
77+
mode (str): Target mode {'binary', 'multiclass', 'multilabel'}
78+
'multilabel' - expects y_true of shape [N, C, H, W]
79+
'multiclass', 'binary' - expects y_true of shape [N, H, W]
80+
log_loss (bool): If True, loss computed as `-log(jaccard)`; otherwise `1 - jaccard`
81+
from_logits (bool): If True assumes input is raw logits
82+
eps (float): small epsilon for numerical stability
83+
Shape:
84+
y_pred: [N, C, H, W]
85+
y_true: [N, C, H, W] or [N, H, W] depending on mode
7786
"""
7887

7988
# the only difference is which function to use

0 commit comments

Comments
 (0)