|
| 1 | +import logging |
| 2 | +from copy import deepcopy |
| 3 | +from functools import wraps |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.nn as nn |
| 7 | +from torch.hub import load_state_dict_from_url |
| 8 | + |
| 9 | +from pytorch_tools.modules import ABN |
| 10 | +from pytorch_tools.modules.bifpn import BiFPN |
| 11 | +from pytorch_tools.modules import bn_from_name |
| 12 | +from pytorch_tools.modules.residual import conv1x1 |
| 13 | +from pytorch_tools.modules.residual import conv3x3 |
| 14 | +from pytorch_tools.modules.residual import DepthwiseSeparableConv |
| 15 | +from pytorch_tools.modules.tf_same_ops import conv_to_same_conv |
| 16 | +from pytorch_tools.modules.tf_same_ops import maxpool_to_same_maxpool |
| 17 | + |
| 18 | +from pytorch_tools.segmentation_models.encoders import get_encoder |
| 19 | + |
| 20 | +import pytorch_tools.utils.box as box_utils |
| 21 | +from pytorch_tools.utils.misc import DEFAULT_IMAGENET_SETTINGS |
| 22 | +from pytorch_tools.utils.misc import initialize_iterator |
| 23 | + |
| 24 | + |
| 25 | +def patch_bn(module): |
| 26 | + """TF ported weights use slightly different eps in BN. Need to adjust for better performance""" |
| 27 | + if isinstance(module, ABN): |
| 28 | + module.eps = 1e-3 |
| 29 | + module.momentum = 1e-2 |
| 30 | + for m in module.children(): |
| 31 | + patch_bn(m) |
| 32 | + |
| 33 | + |
| 34 | +class EfficientDet(nn.Module): |
| 35 | + """TODO: add docstring""" |
| 36 | + |
| 37 | + def __init__( |
| 38 | + self, |
| 39 | + pretrained="coco", # Not used. here for proper signature |
| 40 | + encoder_name="efficientnet_d0", |
| 41 | + encoder_weights="imagenet", |
| 42 | + pyramid_channels=64, |
| 43 | + num_fpn_layers=3, |
| 44 | + num_head_repeats=3, |
| 45 | + num_classes=90, |
| 46 | + drop_connect_rate=0, |
| 47 | + encoder_norm_layer="abn", # TODO: set to frozenabn when ready |
| 48 | + encoder_norm_act="swish", |
| 49 | + decoder_norm_layer="abn", |
| 50 | + decoder_norm_act="swish", |
| 51 | + match_tf_same_padding=False, |
| 52 | + ): |
| 53 | + super().__init__() |
| 54 | + self.encoder = get_encoder( |
| 55 | + encoder_name, |
| 56 | + norm_layer=encoder_norm_layer, |
| 57 | + norm_act=encoder_norm_act, |
| 58 | + encoder_weights=encoder_weights, |
| 59 | + ) |
| 60 | + norm_layer = bn_from_name(decoder_norm_layer) |
| 61 | + bn_args = dict(norm_layer=norm_layer, norm_act=decoder_norm_act) |
| 62 | + self.pyramid6 = nn.Sequential( |
| 63 | + conv1x1(self.encoder.out_shapes[0], pyramid_channels, bias=True), |
| 64 | + norm_layer(pyramid_channels, activation="identity"), |
| 65 | + nn.MaxPool2d(3, stride=2, padding=1), |
| 66 | + ) |
| 67 | + self.pyramid7 = nn.MaxPool2d(3, stride=2, padding=1) # in EffDet it's a simple maxpool |
| 68 | + |
| 69 | + self.bifpn = BiFPN( |
| 70 | + self.encoder.out_shapes[:-2], |
| 71 | + pyramid_channels=pyramid_channels, |
| 72 | + num_layers=num_fpn_layers, |
| 73 | + **bn_args, |
| 74 | + ) |
| 75 | + |
| 76 | + def make_head(out_size): |
| 77 | + layers = [] |
| 78 | + for _ in range(num_head_repeats): |
| 79 | + # TODO: add drop connect |
| 80 | + layers += [DepthwiseSeparableConv(pyramid_channels, pyramid_channels, use_norm=False)] |
| 81 | + layers += [DepthwiseSeparableConv(pyramid_channels, out_size, use_norm=False)] |
| 82 | + return nn.ModuleList(layers) |
| 83 | + |
| 84 | + # The convolution layers in the head are shared among all levels, but |
| 85 | + # each level has its batch normalization to capture the statistical |
| 86 | + # difference among different levels. |
| 87 | + def make_head_norm(): |
| 88 | + return nn.ModuleList( |
| 89 | + [ |
| 90 | + nn.ModuleList( |
| 91 | + [ |
| 92 | + norm_layer(pyramid_channels, activation=decoder_norm_act) |
| 93 | + for _ in range(num_head_repeats) |
| 94 | + ] |
| 95 | + + [nn.Identity()] # no bn after last depthwise conv |
| 96 | + ) |
| 97 | + for _ in range(5) |
| 98 | + ] |
| 99 | + ) |
| 100 | + |
| 101 | + anchors_per_location = 9 # TODO: maybe allow to pass this arg? |
| 102 | + self.cls_head_convs = make_head(num_classes * anchors_per_location) |
| 103 | + self.cls_head_norms = make_head_norm() |
| 104 | + self.box_head_convs = make_head(4 * anchors_per_location) |
| 105 | + self.box_head_norms = make_head_norm() |
| 106 | + self.num_classes = num_classes |
| 107 | + self.num_head_repeats = num_head_repeats |
| 108 | + |
| 109 | + patch_bn(self) |
| 110 | + self._initialize_weights() |
| 111 | + if match_tf_same_padding: |
| 112 | + conv_to_same_conv(self) |
| 113 | + maxpool_to_same_maxpool(self) |
| 114 | + |
| 115 | + # Name from mmdetectin for convenience |
| 116 | + def extract_features(self, x): |
| 117 | + """Extract features from backbone + enchance with BiFPN""" |
| 118 | + # don't use p2 and p1 |
| 119 | + p5, p4, p3, _, _ = self.encoder(x) |
| 120 | + # coarser FPN levels |
| 121 | + p6 = self.pyramid6(p5) |
| 122 | + p7 = self.pyramid7(p6) |
| 123 | + features = [p7, p6, p5, p4, p3] |
| 124 | + # enhance features |
| 125 | + features = self.bifpn(features) |
| 126 | + # want features from lowest OS to highest to align with `generate_anchors_boxes` function |
| 127 | + features = list(reversed(features)) |
| 128 | + return features |
| 129 | + |
| 130 | + def forward(self, x): |
| 131 | + features = self.extract_features(x) |
| 132 | + class_outputs = [] |
| 133 | + box_outputs = [] |
| 134 | + for feat, (cls_bns, box_bns) in zip(features, zip(self.cls_head_norms, self.box_head_norms)): |
| 135 | + cls_feat, box_feat = feat, feat |
| 136 | + # it looks like that with drop_connect there is an additional residual here |
| 137 | + # TODO: need to investigate using pretrained weights |
| 138 | + for cls_conv, cls_bn in zip(self.cls_head_convs, cls_bns): |
| 139 | + cls_feat = cls_bn(cls_conv(cls_feat)) |
| 140 | + for box_conv, box_bn in zip(self.box_head_convs, box_bns): |
| 141 | + box_feat = box_bn(box_conv(box_feat)) |
| 142 | + |
| 143 | + box_feat = box_feat.permute(0, 2, 3, 1) |
| 144 | + box_outputs.append(box_feat.contiguous().view(box_feat.shape[0], -1, 4)) |
| 145 | + |
| 146 | + cls_feat = cls_feat.permute(0, 2, 3, 1) |
| 147 | + class_outputs.append(cls_feat.contiguous().view(cls_feat.shape[0], -1, self.num_classes)) |
| 148 | + |
| 149 | + class_outputs = torch.cat(class_outputs, 1) |
| 150 | + box_outputs = torch.cat(box_outputs, 1) |
| 151 | + # my anchors are in [x1, y1, x2,y2] format while pretrained weights are in [y1, x1, y2, x2] format |
| 152 | + # it may be confusing to reorder x and y every time later so I do it once here. it gives |
| 153 | + # compatability with pretrained weigths from Google and doesn't affect training from scratch |
| 154 | + # box_outputs = box_outputs[..., [1, 0, 3, 2]] # TODO: return back |
| 155 | + return class_outputs, box_outputs |
| 156 | + |
| 157 | + @torch.no_grad() |
| 158 | + def predict(self, x): |
| 159 | + """Run forward on given images and decode raw prediction into bboxes |
| 160 | + Returns: bboxes, scores, classes |
| 161 | + """ |
| 162 | + class_outputs, box_outputs = self.forward(x) |
| 163 | + anchors = box_utils.generate_anchors_boxes(x.shape[-2:])[0] |
| 164 | + return box_utils.decode(class_outputs, box_outputs, anchors) |
| 165 | + |
| 166 | + def _initialize_weights(self): |
| 167 | + # init everything except encoder |
| 168 | + no_encoder_m = [m for n, m in self.named_modules() if not "encoder" in n] |
| 169 | + initialize_iterator(no_encoder_m) |
| 170 | + # need to init last bias so that after sigmoid it's 0.01 |
| 171 | + cls_bias_init = -torch.log(torch.tensor((1 - 0.01) / 0.01)) # -4.59 |
| 172 | + nn.init.constant_(self.cls_head_convs[-1][1].bias, cls_bias_init) |
| 173 | + |
| 174 | + |
| 175 | +PRETRAIN_SETTINGS = {**DEFAULT_IMAGENET_SETTINGS, "input_size": (512, 512), "crop_pct": 1, "num_classes": 90} |
| 176 | + |
| 177 | +# fmt: off |
| 178 | +CFGS = { |
| 179 | + "efficientdet_d0": { |
| 180 | + "default": { |
| 181 | + "params": { |
| 182 | + "encoder_name":"efficientnet_b0", |
| 183 | + "pyramid_channels":64, |
| 184 | + "num_fpn_layers":3, |
| 185 | + "num_head_repeats":3, |
| 186 | + }, |
| 187 | + **PRETRAIN_SETTINGS, |
| 188 | + }, |
| 189 | + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d0.pth",}, |
| 190 | + }, |
| 191 | + "efficientdet_d1": { |
| 192 | + "default": { |
| 193 | + "params": { |
| 194 | + "encoder_name":"efficientnet_b1", |
| 195 | + "pyramid_channels":88, |
| 196 | + "num_fpn_layers":4, |
| 197 | + "num_head_repeats":3, |
| 198 | + }, |
| 199 | + **PRETRAIN_SETTINGS, |
| 200 | + "input_size": (640, 640), |
| 201 | + }, |
| 202 | + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d1.pth",}, |
| 203 | + }, |
| 204 | + "efficientdet_d2": { |
| 205 | + "default": { |
| 206 | + "params": { |
| 207 | + "encoder_name":"efficientnet_b2", |
| 208 | + "pyramid_channels":112, |
| 209 | + "num_fpn_layers":5, |
| 210 | + "num_head_repeats":3, |
| 211 | + }, |
| 212 | + **PRETRAIN_SETTINGS, |
| 213 | + "input_size": (768, 768), |
| 214 | + }, |
| 215 | + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d2.pth",}, |
| 216 | + }, |
| 217 | + "efficientdet_d3": { |
| 218 | + "default": { |
| 219 | + "params": { |
| 220 | + "encoder_name":"efficientnet_b3", |
| 221 | + "pyramid_channels":160, |
| 222 | + "num_fpn_layers":6, |
| 223 | + "num_head_repeats":4, |
| 224 | + }, |
| 225 | + **PRETRAIN_SETTINGS, |
| 226 | + "input_size": (896, 896), |
| 227 | + }, |
| 228 | + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d3.pth",}, |
| 229 | + }, |
| 230 | + "efficientdet_d4": { |
| 231 | + "default": { |
| 232 | + "params": { |
| 233 | + "encoder_name":"efficientnet_b4", |
| 234 | + "pyramid_channels":224, |
| 235 | + "num_fpn_layers":7, |
| 236 | + "num_head_repeats":4, |
| 237 | + }, |
| 238 | + **PRETRAIN_SETTINGS, |
| 239 | + "input_size": (1024, 1024), |
| 240 | + }, |
| 241 | + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d4.pth",}, |
| 242 | + }, |
| 243 | + "efficientdet_d5": { |
| 244 | + "default": { |
| 245 | + "params": { |
| 246 | + "encoder_name":"efficientnet_b5", |
| 247 | + "pyramid_channels":288, |
| 248 | + "num_fpn_layers":7, |
| 249 | + "num_head_repeats":4, |
| 250 | + }, |
| 251 | + **PRETRAIN_SETTINGS, |
| 252 | + "input_size": (1280, 1280), |
| 253 | + }, |
| 254 | + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d5.pth",}, |
| 255 | + }, |
| 256 | + "efficientdet_d6": { |
| 257 | + "default": { |
| 258 | + "params": { |
| 259 | + "encoder_name":"efficientnet_b6", |
| 260 | + "pyramid_channels":384, |
| 261 | + "num_fpn_layers":8, |
| 262 | + "num_head_repeats":5, |
| 263 | + }, |
| 264 | + **PRETRAIN_SETTINGS, |
| 265 | + "input_size": (1280, 1280), |
| 266 | + }, |
| 267 | + "coco": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.5/efficientdet-d6.pth",}, |
| 268 | + }, |
| 269 | +} |
| 270 | +# fmt: on |
| 271 | + |
| 272 | + |
| 273 | +def _efficientdet(arch, pretrained=None, **kwargs): |
| 274 | + cfgs = deepcopy(CFGS) |
| 275 | + cfg_settings = cfgs[arch]["default"] |
| 276 | + cfg_params = cfg_settings.pop("params") |
| 277 | + kwargs.update(cfg_params) |
| 278 | + model = EfficientDet(**kwargs) |
| 279 | + if pretrained: |
| 280 | + state_dict = load_state_dict_from_url(cfgs[arch][pretrained]["url"]) |
| 281 | + kwargs_cls = kwargs.get("num_classes", None) |
| 282 | + if kwargs_cls and kwargs_cls != cfg_settings["num_classes"]: |
| 283 | + logging.warning( |
| 284 | + f"Using model pretrained for {cfg_settings['num_classes']} classes with {kwargs_cls} classes. Last layer is initialized randomly" |
| 285 | + ) |
| 286 | + last_conv_name = f"cls_head_convs.{kwargs['num_head_repeats']}.1" |
| 287 | + state_dict[f"{last_conv_name}.weight"] = model.state_dict()[f"{last_conv_name}.weight"] |
| 288 | + state_dict[f"{last_conv_name}.bias"] = model.state_dict()[f"{last_conv_name}.bias"] |
| 289 | + model.load_state_dict(state_dict) |
| 290 | + setattr(model, "pretrained_settings", cfg_settings) |
| 291 | + return model |
| 292 | + |
| 293 | + |
| 294 | +@wraps(EfficientDet) |
| 295 | +def efficientdet_d0(pretrained="coco", **kwargs): |
| 296 | + return _efficientdet("efficientdet_d0", pretrained, **kwargs) |
| 297 | + |
| 298 | + |
| 299 | +@wraps(EfficientDet) |
| 300 | +def efficientdet_d1(pretrained="coco", **kwargs): |
| 301 | + return _efficientdet("efficientdet_d1", pretrained, **kwargs) |
| 302 | + |
| 303 | + |
| 304 | +@wraps(EfficientDet) |
| 305 | +def efficientdet_d2(pretrained="coco", **kwargs): |
| 306 | + return _efficientdet("efficientdet_d2", pretrained, **kwargs) |
| 307 | + |
| 308 | + |
| 309 | +@wraps(EfficientDet) |
| 310 | +def efficientdet_d3(pretrained="coco", **kwargs): |
| 311 | + return _efficientdet("efficientdet_d3", pretrained, **kwargs) |
| 312 | + |
| 313 | + |
| 314 | +@wraps(EfficientDet) |
| 315 | +def efficientdet_d4(pretrained="coco", **kwargs): |
| 316 | + return _efficientdet("efficientdet_d4", pretrained, **kwargs) |
| 317 | + |
| 318 | + |
| 319 | +@wraps(EfficientDet) |
| 320 | +def efficientdet_d5(pretrained="coco", **kwargs): |
| 321 | + return _efficientdet("efficientdet_d5", pretrained, **kwargs) |
| 322 | + |
| 323 | + |
| 324 | +@wraps(EfficientDet) |
| 325 | +def efficientdet_d6(pretrained="coco", **kwargs): |
| 326 | + return _efficientdet("efficientdet_d6", pretrained, **kwargs) |
| 327 | + |
| 328 | + |
| 329 | +# No B7 because it's the same model as B6 but with larger input |
0 commit comments