Skip to content

Commit e73ec1f

Browse files
authored
Merge pull request #60 from bonlime/dev
Huge release. Many things improved
2 parents 3119665 + 0568102 commit e73ec1f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

60 files changed

+2476
-560
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ Tool box for PyTorch for fast prototyping.
1212
* [TTA wrapper](./pytorch_tools/tta_wrapper/) - wrapper for easy test-time augmentation
1313

1414
# Installation
15-
Requeres GPU drivers and CUDA already installed.
15+
Requires GPU drivers and CUDA already installed.
1616

1717
`pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" git+https://github.com/NVIDIA/apex.git`
1818
`pip install git+https://github.com/bonlime/pytorch-tools.git@master`

pytorch_tools/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.1.4"
1+
__version__ = "0.1.5"
22

33
from . import fit_wrapper
44
from . import losses
@@ -9,3 +9,4 @@
99
from . import segmentation_models
1010
from . import tta_wrapper
1111
from . import utils
12+
from . import detection_models
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from .retinanet import RetinaNet
2+
from .retinanet import retinanet_r50_fpn
3+
from .retinanet import retinanet_r101_fpn
4+
5+
from .efficientdet import EfficientDet
6+
from .efficientdet import efficientdet_d0
7+
from .efficientdet import efficientdet_d1
8+
from .efficientdet import efficientdet_d2
9+
from .efficientdet import efficientdet_d3
10+
from .efficientdet import efficientdet_d4
11+
from .efficientdet import efficientdet_d5
12+
from .efficientdet import efficientdet_d6
Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
import logging
2+
from copy import deepcopy
3+
from functools import wraps
4+
5+
import torch
6+
import torch.nn as nn
7+
from torch.hub import load_state_dict_from_url
8+
9+
from pytorch_tools.modules import ABN
10+
from pytorch_tools.modules.bifpn import BiFPN
11+
from pytorch_tools.modules import bn_from_name
12+
from pytorch_tools.modules.residual import conv1x1
13+
from pytorch_tools.modules.residual import conv3x3
14+
from pytorch_tools.modules.residual import DepthwiseSeparableConv
15+
from pytorch_tools.modules.tf_same_ops import conv_to_same_conv
16+
from pytorch_tools.modules.tf_same_ops import maxpool_to_same_maxpool
17+
18+
from pytorch_tools.segmentation_models.encoders import get_encoder
19+
20+
import pytorch_tools.utils.box as box_utils
21+
from pytorch_tools.utils.misc import DEFAULT_IMAGENET_SETTINGS
22+
from pytorch_tools.utils.misc import initialize_iterator
23+
24+
25+
def patch_bn(module):
    """Set BN `eps` / `momentum` on every ABN layer inside `module` (in place).

    Weights ported from TF were trained with a slightly different BN epsilon,
    so the values are adjusted here to get better performance with them.
    """
    # `modules()` yields `module` itself followed by all of its descendants
    for layer in module.modules():
        if not isinstance(layer, ABN):
            continue
        layer.eps = 1e-3
        layer.momentum = 1e-2
32+
33+
34+
class EfficientDet(nn.Module):
    """EfficientDet detector: encoder -> extra pyramid levels -> BiFPN -> class / box heads.

    Args:
        pretrained: Not used by the class. Present only so the signature matches
            the factory functions that wrap this class.
        encoder_name: Name of the encoder, forwarded to `get_encoder`.
        encoder_weights: Forwarded to `get_encoder` as `encoder_weights`.
        pyramid_channels: Number of channels of every pyramid level and of the
            intermediate head convolutions.
        num_fpn_layers: Forwarded to `BiFPN` as `num_layers`.
        num_head_repeats: Number of depthwise separable convs in each head that
            precede the final prediction conv.
        num_classes: Number of classes predicted for every anchor.
        drop_connect_rate: Currently unused (drop connect in the heads is a TODO).
        encoder_norm_layer: Forwarded to `get_encoder` as `norm_layer`.
        encoder_norm_act: Forwarded to `get_encoder` as `norm_act`.
        decoder_norm_layer: Name resolved through `bn_from_name`; used for the
            pyramid, BiFPN and head normalisation layers.
        decoder_norm_act: Activation used by the BiFPN and head norm layers.
        match_tf_same_padding: If True, the model is passed through
            `conv_to_same_conv` and `maxpool_to_same_maxpool` after initialisation.
    """

    def __init__(
        self,
        pretrained="coco",  # Not used. here for proper signature
        # was "efficientnet_d0": a typo, every config in this file uses efficientnet_bN encoders
        encoder_name="efficientnet_b0",
        encoder_weights="imagenet",
        pyramid_channels=64,
        num_fpn_layers=3,
        num_head_repeats=3,
        num_classes=90,
        drop_connect_rate=0,
        encoder_norm_layer="abn",  # TODO: set to frozenabn when ready
        encoder_norm_act="swish",
        decoder_norm_layer="abn",
        decoder_norm_act="swish",
        match_tf_same_padding=False,
    ):
        super().__init__()
        self.encoder = get_encoder(
            encoder_name,
            norm_layer=encoder_norm_layer,
            norm_act=encoder_norm_act,
            encoder_weights=encoder_weights,
        )
        norm_layer = bn_from_name(decoder_norm_layer)
        bn_args = dict(norm_layer=norm_layer, norm_act=decoder_norm_act)
        # p6 is built from the first encoder output: 1x1 conv -> norm -> stride 2 maxpool
        self.pyramid6 = nn.Sequential(
            conv1x1(self.encoder.out_shapes[0], pyramid_channels, bias=True),
            norm_layer(pyramid_channels, activation="identity"),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        self.pyramid7 = nn.MaxPool2d(3, stride=2, padding=1)  # in EffDet it's a simple maxpool

        self.bifpn = BiFPN(
            self.encoder.out_shapes[:-2],
            pyramid_channels=pyramid_channels,
            num_layers=num_fpn_layers,
            **bn_args,
        )

        def make_head(out_size):
            """Return `num_head_repeats` convs followed by one conv with `out_size` outputs."""
            layers = []
            for _ in range(num_head_repeats):
                # TODO: add drop connect
                layers += [DepthwiseSeparableConv(pyramid_channels, pyramid_channels, use_norm=False)]
            layers += [DepthwiseSeparableConv(pyramid_channels, out_size, use_norm=False)]
            return nn.ModuleList(layers)

        # The convolution layers in the head are shared among all levels, but
        # each level has its batch normalization to capture the statistical
        # difference among different levels.
        def make_head_norm():
            """Return one list of norm layers per pyramid level (5 levels: p3..p7)."""
            return nn.ModuleList(
                [
                    nn.ModuleList(
                        [
                            norm_layer(pyramid_channels, activation=decoder_norm_act)
                            for _ in range(num_head_repeats)
                        ]
                        + [nn.Identity()]  # no bn after last depthwise conv
                    )
                    for _ in range(5)
                ]
            )

        anchors_per_location = 9  # TODO: maybe allow to pass this arg?
        self.cls_head_convs = make_head(num_classes * anchors_per_location)
        self.cls_head_norms = make_head_norm()
        self.box_head_convs = make_head(4 * anchors_per_location)
        self.box_head_norms = make_head_norm()
        self.num_classes = num_classes
        self.num_head_repeats = num_head_repeats

        patch_bn(self)
        self._initialize_weights()
        if match_tf_same_padding:
            conv_to_same_conv(self)
            maxpool_to_same_maxpool(self)

    # Name from mmdetection for convenience
    def extract_features(self, x):
        """Extract features from backbone + enhance with BiFPN"""
        # don't use p2 and p1
        p5, p4, p3, _, _ = self.encoder(x)
        # coarser FPN levels
        p6 = self.pyramid6(p5)
        p7 = self.pyramid7(p6)
        features = [p7, p6, p5, p4, p3]
        # enhance features
        features = self.bifpn(features)
        # want features from lowest OS to highest to align with `generate_anchors_boxes` function
        features = list(reversed(features))
        return features

    def forward(self, x):
        """Return raw `(class_outputs, box_outputs)` concatenated over all pyramid levels.

        `class_outputs` has `num_classes` values in its last dimension and
        `box_outputs` has 4; dimension 1 enumerates all anchors of all levels.
        """
        features = self.extract_features(x)
        class_outputs = []
        box_outputs = []
        # convs are shared between levels, norm layers are per-level
        for feat, (cls_bns, box_bns) in zip(features, zip(self.cls_head_norms, self.box_head_norms)):
            cls_feat, box_feat = feat, feat
            # it looks like that with drop_connect there is an additional residual here
            # TODO: need to investigate using pretrained weights
            for cls_conv, cls_bn in zip(self.cls_head_convs, cls_bns):
                cls_feat = cls_bn(cls_conv(cls_feat))
            for box_conv, box_bn in zip(self.box_head_convs, box_bns):
                box_feat = box_bn(box_conv(box_feat))

            # move channels last, then split the channel dimension into per-anchor predictions
            box_feat = box_feat.permute(0, 2, 3, 1)
            box_outputs.append(box_feat.contiguous().view(box_feat.shape[0], -1, 4))

            cls_feat = cls_feat.permute(0, 2, 3, 1)
            class_outputs.append(cls_feat.contiguous().view(cls_feat.shape[0], -1, self.num_classes))

        class_outputs = torch.cat(class_outputs, 1)
        box_outputs = torch.cat(box_outputs, 1)
        # my anchors are in [x1, y1, x2, y2] format while pretrained weights are in [y1, x1, y2, x2] format
        # it may be confusing to reorder x and y every time later so I do it once here. it gives
        # compatibility with pretrained weights from Google and doesn't affect training from scratch
        # box_outputs = box_outputs[..., [1, 0, 3, 2]] # TODO: return back
        return class_outputs, box_outputs

    @torch.no_grad()
    def predict(self, x):
        """Run forward on given images and decode raw prediction into bboxes
        Returns: bboxes, scores, classes
        """
        class_outputs, box_outputs = self.forward(x)
        # anchors are generated for the spatial size (last two dims) of the input
        anchors = box_utils.generate_anchors_boxes(x.shape[-2:])[0]
        return box_utils.decode(class_outputs, box_outputs, anchors)

    def _initialize_weights(self):
        """Initialise every module except the encoder and set the class-head prior bias."""
        # init everything except encoder
        no_encoder_m = [m for n, m in self.named_modules() if "encoder" not in n]
        initialize_iterator(no_encoder_m)
        # need to init last bias so that after sigmoid it's 0.01
        cls_bias_init = -torch.log(torch.tensor((1 - 0.01) / 0.01))  # -4.59
        # index 1 is presumably the pointwise conv of DepthwiseSeparableConv -- confirm against that module
        nn.init.constant_(self.cls_head_convs[-1][1].bias, cls_bias_init)
173+
174+
175+
# Settings shared by every pretrained variant; `input_size` is set per variant below.
PRETRAIN_SETTINGS = {
    **DEFAULT_IMAGENET_SETTINGS,
    "input_size": (512, 512),
    "crop_pct": 1,
    "num_classes": 90,
}


def _make_cfg(encoder_name, pyramid_channels, num_fpn_layers, num_head_repeats, input_size, url):
    """Build one CFGS entry: `default` (constructor params + settings) and `coco` (weights url)."""
    return {
        "default": {
            "params": {
                "encoder_name": encoder_name,
                "pyramid_channels": pyramid_channels,
                "num_fpn_layers": num_fpn_layers,
                "num_head_repeats": num_head_repeats,
            },
            **PRETRAIN_SETTINGS,
            # overrides the value coming from PRETRAIN_SETTINGS
            "input_size": input_size,
        },
        "coco": {"url": url},
    }


# Columns: encoder_name, pyramid_channels, num_fpn_layers, num_head_repeats, input_size, url
# fmt: off
CFGS = {
    "efficientdet_d0": _make_cfg(
        "efficientnet_b0", 64, 3, 3, (512, 512),
        "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d0.pth",
    ),
    "efficientdet_d1": _make_cfg(
        "efficientnet_b1", 88, 4, 3, (640, 640),
        "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d1.pth",
    ),
    "efficientdet_d2": _make_cfg(
        "efficientnet_b2", 112, 5, 3, (768, 768),
        "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d2.pth",
    ),
    "efficientdet_d3": _make_cfg(
        "efficientnet_b3", 160, 6, 4, (896, 896),
        "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d3.pth",
    ),
    "efficientdet_d4": _make_cfg(
        "efficientnet_b4", 224, 7, 4, (1024, 1024),
        "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d4.pth",
    ),
    "efficientdet_d5": _make_cfg(
        "efficientnet_b5", 288, 7, 4, (1280, 1280),
        "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d5.pth",
    ),
    "efficientdet_d6": _make_cfg(
        "efficientnet_b6", 384, 8, 5, (1280, 1280),
        "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d6.pth",
    ),
}
# fmt: on
271+
272+
273+
def _efficientdet(arch, pretrained=None, **kwargs):
    """Build the EfficientDet variant named `arch`, optionally loading pretrained weights.

    Args:
        arch: Key into `CFGS` (e.g. "efficientdet_d0").
        pretrained: Key of the weights entry inside `CFGS[arch]` (e.g. "coco"),
            or a falsy value to skip loading weights.
        **kwargs: Forwarded to `EfficientDet`. Keys present in the config's
            "params" are overwritten by the config values.

    Returns:
        The model, with the config's default settings attached as
        `pretrained_settings`.
    """
    # work on a copy so that popping "params" doesn't mutate the module-level CFGS
    arch_cfg = deepcopy(CFGS)[arch]
    cfg_settings = arch_cfg["default"]
    kwargs.update(cfg_settings.pop("params"))
    model = EfficientDet(**kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(arch_cfg[pretrained]["url"])
        kwargs_cls = kwargs.get("num_classes", None)
        if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]:
            logging.warning(
                f"Using model pretrained for {cfg_settings['num_classes']} classes with {kwargs_cls} classes. Last layer is initialized randomly"
            )
            # keep the freshly initialised last classification conv instead of the pretrained one
            last_conv_name = f"cls_head_convs.{kwargs['num_head_repeats']}.1"
            for suffix in ("weight", "bias"):
                param_key = f"{last_conv_name}.{suffix}"
                state_dict[param_key] = model.state_dict()[param_key]
        model.load_state_dict(state_dict)
    model.pretrained_settings = cfg_settings
    return model
292+
293+
294+
@wraps(EfficientDet)
def efficientdet_d0(pretrained="coco", **kwargs):
    # D0 factory. `wraps` copies EfficientDet's name/docstring onto this function,
    # so the description lives in a comment rather than a docstring.
    return _efficientdet("efficientdet_d0", pretrained=pretrained, **kwargs)
297+
298+
299+
@wraps(EfficientDet)
def efficientdet_d1(pretrained="coco", **kwargs):
    # D1 factory. `wraps` copies EfficientDet's name/docstring onto this function,
    # so the description lives in a comment rather than a docstring.
    return _efficientdet("efficientdet_d1", pretrained=pretrained, **kwargs)
302+
303+
304+
@wraps(EfficientDet)
def efficientdet_d2(pretrained="coco", **kwargs):
    # D2 factory. `wraps` copies EfficientDet's name/docstring onto this function,
    # so the description lives in a comment rather than a docstring.
    return _efficientdet("efficientdet_d2", pretrained=pretrained, **kwargs)
307+
308+
309+
@wraps(EfficientDet)
def efficientdet_d3(pretrained="coco", **kwargs):
    # D3 factory. `wraps` copies EfficientDet's name/docstring onto this function,
    # so the description lives in a comment rather than a docstring.
    return _efficientdet("efficientdet_d3", pretrained=pretrained, **kwargs)
312+
313+
314+
@wraps(EfficientDet)
def efficientdet_d4(pretrained="coco", **kwargs):
    # D4 factory. `wraps` copies EfficientDet's name/docstring onto this function,
    # so the description lives in a comment rather than a docstring.
    return _efficientdet("efficientdet_d4", pretrained=pretrained, **kwargs)
317+
318+
319+
@wraps(EfficientDet)
def efficientdet_d5(pretrained="coco", **kwargs):
    # D5 factory. `wraps` copies EfficientDet's name/docstring onto this function,
    # so the description lives in a comment rather than a docstring.
    return _efficientdet("efficientdet_d5", pretrained=pretrained, **kwargs)
322+
323+
324+
@wraps(EfficientDet)
def efficientdet_d6(pretrained="coco", **kwargs):
    # D6 factory. `wraps` copies EfficientDet's name/docstring onto this function,
    # so the description lives in a comment rather than a docstring.
    return _efficientdet("efficientdet_d6", pretrained=pretrained, **kwargs)
327+
328+
329+
# No D7 because it's the same model as D6 but with larger input

0 commit comments

Comments
 (0)