
Commit cf9a6e7

Merge pull request #59 from bonlime/dev
Bunch of fixes for hrnet
2 parents 5f73e3e + 843a03f commit cf9a6e7

File tree

8 files changed: +131 -18 lines changed


pytorch_tools/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-__version__ = "0.1.3"
+__version__ = "0.1.4"
 
 from . import fit_wrapper
 from . import losses

pytorch_tools/models/hrnet.py

Lines changed: 3 additions & 5 deletions
@@ -185,11 +185,9 @@ def __init__(self, pre_channels, norm_layer=ABN, norm_act="relu"):
 
     def forward(self, x):
        x = [self.incre_modules[i](x[i]) for i in range(4)]
-        y = x[0]
        for i in range(1, 4):
-            y = x[i] + self.downsamp_modules[i-1](y)
-        y = self.final_layer(y)
-        return y
+            x[i] = x[i] + self.downsamp_modules[i-1](x[i-1])
+        return self.final_layer(x[3])
 
 
 class HighResolutionNet(nn.Module):
@@ -359,7 +357,7 @@ def load_state_dict(self, state_dict, **kwargs):
    },
    "hrnet_w44": {
        "default": {"params": {"width": 44}, **DEFAULT_IMAGENET_SETTINGS,},
-        "imagenet": {"url": None},
+        "imagenet": {"url": "https://github.com/bonlime/pytorch-tools/releases/download/v0.1.2/hrnetv2_w44_imagenet_pretrained-8c55086c.pth"},
    },
    "hrnet_w48": {
        "default": {"params": {"width": 48}, **DEFAULT_IMAGENET_SETTINGS,},

pytorch_tools/modules/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,7 @@
 from .activations import Mish, MishNaive, Swish, SwishNaive
 
 from .activated_batch_norm import ABN
+from .activated_group_norm import AGN
 from inplace_abn import InPlaceABN, InPlaceABNSync
 
 def bn_from_name(norm_name):
@@ -32,5 +33,7 @@ def bn_from_name(norm_name):
        return InPlaceABNSync
    elif norm_name in ("frozen_abn", "frozenabn"):
        return partial(ABN, frozen=True)
+    elif norm_name in ("agn", "groupnorm", "group_norm"):
+        return AGN
    else:
        raise ValueError(f"Normalization {norm_name} not supported")
pytorch_tools/modules/activated_group_norm.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn import init
+from torch.nn.parameter import Parameter
+
+from .activations import ACT
+from .activations import ACT_FUNC_DICT
+
+class AGN(nn.Module):
+    """Activated Group Normalization
+    This gathers a GroupNorm and an activation function in a single module
+    Parameters
+    ----------
+    num_features : int
+        Number of feature channels in the input and output.
+    num_groups : int
+        Number of groups to separate the channels into.
+    eps : float
+        Small constant to prevent numerical issues.
+    affine : bool
+        If `True`, apply a learned scale and shift transformation after normalization.
+    activation : str
+        Name of the activation function, one of: `relu`, `leaky_relu`, `elu` or `identity`.
+    activation_param : float
+        Negative slope for the `leaky_relu` activation.
+    """
+
+    def __init__(
+        self,
+        num_features,
+        num_groups=32,
+        eps=1e-5,
+        affine=True,
+        activation="relu",
+        activation_param=0.01,
+    ):
+        super(AGN, self).__init__()
+        self.num_features = num_features
+        self.num_groups = num_groups
+        self.affine = affine
+        self.eps = eps
+        self.activation = ACT(activation)
+        self.activation_param = activation_param
+        if self.affine:
+            self.weight = nn.Parameter(torch.ones(num_features))
+            self.bias = nn.Parameter(torch.zeros(num_features))
+        else:
+            self.register_parameter("weight", None)
+            self.register_parameter("bias", None)
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        if self.affine:
+            nn.init.constant_(self.weight, 1)
+            nn.init.constant_(self.bias, 0)
+
+    def forward(self, x):
+        x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
+        func = ACT_FUNC_DICT[self.activation]
+        if self.activation == ACT.LEAKY_RELU:
+            return func(x, inplace=True, negative_slope=self.activation_param)
+        elif self.activation == ACT.ELU:
+            return func(x, inplace=True, alpha=self.activation_param)
+        else:
+            return func(x, inplace=True)
+
+    def extra_repr(self):
+        rep = "{num_features}, eps={eps}, affine={affine}, activation={activation}"
+        if self.activation in ["leaky_relu", "elu"]:
+            rep += "[{activation_param}]"
+        return rep.format(**self.__dict__)
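
A minimal usage sketch (hedged: it assumes the `ACT`/`ACT_FUNC_DICT` registry resolves "relu" the same way the rest of the library does, and that nothing beyond the constructor arguments above is needed):

import torch
from pytorch_tools.modules import AGN, bn_from_name

x = torch.randn(2, 64, 16, 16)              # (batch, channels, height, width)
agn = AGN(64, num_groups=32, activation="relu")
out = agn(x)                                # GroupNorm over 32 groups, then ReLU, in one call
print(out.shape)                            # torch.Size([2, 64, 16, 16])

# the same class comes back from the registry, so models that build their
# norm layers via `bn_from_name` can switch to GroupNorm by name alone
norm_layer = bn_from_name("group_norm")     # -> AGN
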
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+from torch import nn
+import torch.nn.functional as F
+
+# implements the idea from the `Weight Standardization` paper https://arxiv.org/abs/1903.10520
+# eps is inside the sqrt to avoid overflow, an idea from https://arxiv.org/abs/1911.05920
+class WS_Conv2d(nn.Conv2d):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
+                 padding=0, dilation=1, groups=1, bias=True):
+        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
+
+    def forward(self, x):
+        weight = self.weight
+        weight = weight.sub(weight.mean(dim=(1, 2, 3), keepdim=True))
+        std = weight.var(dim=(1, 2, 3), keepdim=True).add_(1e-7).sqrt_()
+        weight = weight.div(std.expand_as(weight))
+        return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
+
+# adapted from a GitHub issue: replaces convs followed by a norm layer with weight-standardized convs
+def convertConv2WeightStand(module, nextChild=None):
+    mod = module
+    norm_list = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm, nn.LayerNorm]
+    conv_list = [nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d]
+    for norm in norm_list:
+        for conv in conv_list:
+            if isinstance(mod, conv) and isinstance(nextChild, norm):
+                mod = WS_Conv2d(mod.in_channels, mod.out_channels, mod.kernel_size, mod.stride,
+                                mod.padding, mod.dilation, mod.groups, mod.bias is not None)
+
+    moduleChildList = list(module.named_children())
+    for index, [name, child] in enumerate(moduleChildList):
+        nextChild = None
+        if index < len(moduleChildList) - 1:
+            nextChild = moduleChildList[index + 1][1]
+        mod.add_module(name, convertConv2WeightStand(child, nextChild))
+
+    return mod
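
A hedged usage sketch for the new module (the import path is not shown in this diff, so it assumes `WS_Conv2d` and `convertConv2WeightStand` are importable from wherever the file above lives; the model is a placeholder):

import torch
import torch.nn as nn
# WS_Conv2d and convertConv2WeightStand as defined in the file above

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)
# the Conv2d is directly followed by BatchNorm2d, so it gets swapped for WS_Conv2d;
# note the replacement conv is freshly initialized, weights are not copied over
ws_model = convertConv2WeightStand(model)
print(ws_model)

x = torch.randn(1, 3, 32, 32)
print(ws_model(x).shape)                    # torch.Size([1, 16, 32, 32])
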

pytorch_tools/optim/__init__.py

Lines changed: 9 additions & 8 deletions
@@ -10,28 +10,29 @@
 
 from torch import optim
 
-
+# 2e-5 is the lowest epsilon that saves from overflow in fp16
 def optimizer_from_name(optim_name):
    optim_name = optim_name.lower()
    if optim_name == "sgd":
        return optim.SGD
    elif optim_name == "sgdw":
        return SGDW
    elif optim_name == "adam":
-        return optim.Adam
+        return partial(optim.Adam, eps=2e-5)
    elif optim_name == "adamw":
-        return optim.AdamW
+        return partial(AdamW_my, eps=2e-5)
    elif optim_name == "adamw_gc":
-        return partial(AdamW_my, center=True)
+        # in this implementation eps is inside the sqrt, so it can be smaller
+        return partial(AdamW_my, center=True, eps=1e-7)
    elif optim_name == "rmsprop":
-        return optim.RMSprop
+        return partial(optim.RMSprop, eps=2e-5)
    elif optim_name == "radam":
-        return RAdam
+        return partial(RAdam, eps=2e-5)
    elif optim_name in ["fused_sgd", "fusedsgd"]:
        return FusedSGD
    elif optim_name in ["fused_adam", "fusedadam"]:
-        return FusedAdam
+        return partial(FusedAdam, eps=2e-5)
    elif optim_name in ["fused_novograd", "fusednovograd", "novograd"]:
-        return FusedNovoGrad
+        return partial(FusedNovoGrad, eps=2e-5)
    else:
        raise ValueError(f"Optimizer {optim_name} not found")
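
A small sketch of how the factory is typically called (the model and hyperparameters are placeholders): the returned object is either the optimizer class itself or a `functools.partial` with the fp16-safe `eps` already bound, so call sites do not change.

import torch.nn as nn
from pytorch_tools.optim import optimizer_from_name

model = nn.Linear(10, 2)                        # placeholder model
opt_cls = optimizer_from_name("adam")           # partial(optim.Adam, eps=2e-5)
optimizer = opt_cls(model.parameters(), lr=1e-3)
print(optimizer.defaults["eps"])                # 2e-05
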

pytorch_tools/optim/adamw.py

Lines changed: 4 additions & 3 deletions
@@ -5,6 +5,7 @@
 # it's a copy from torch.optim with an additional `center` param
 # AdamW only differs from Adam in one line (where weight decay happens)
 # upd. flag `center` comes from the `Gradient Centralization` paper
+# upd. moved `eps` inside the sqrt to avoid nan in gradients
 class AdamW(Optimizer):
    r"""Implements AdamW algorithm.
 
@@ -79,7 +80,7 @@ def step(self, closure=None):
 
                # Gradient Centralization operation for Conv layers
                if group['center'] and len(list(grad.size())) > 3:
-                    grad.add_(-grad.mean(dim = tuple(range(1,len(list(grad.size())))), keepdim = True))
+                    grad.add_(-grad.mean(dim = tuple(range(1, grad.dim())), keepdim = True))
 
                state = self.state[p]
 
@@ -110,9 +111,9 @@ def step(self, closure=None):
                    # Maintains the maximum of all 2nd moment running avg. till now
                    torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)
                    # Use the max. for normalizing running avg. of gradient
-                    denom = (max_exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                    denom = (max_exp_avg_sq.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                else:
-                    denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
+                    denom = (exp_avg_sq.add_(group['eps']).sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
 
                step_size = group['lr'] / bias_correction1
pytorch_tools/segmentation_models/hrnet.py

Lines changed: 2 additions & 1 deletion
@@ -39,7 +39,6 @@ class HRNet(nn.Module):
 
    Args:
        encoder_name (str): name of classification model used as feature extractor to build segmentation model.
-            Models expects encoder to have output stride 16 or 8. Only Resnet and Effnet family models are supported for now
        encoder_weights (str): one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
        num_classes (int): a number of classes for output (output shape - ``(batch, classes, h, w)``).
        pretrained (Union[str, None]): hrnet_w48 and hrnet_w48+OCR have pretrained weights. init models using functions rather than
@@ -203,6 +202,8 @@ def _hrnet(arch, pretrained=None, **kwargs):
    )
    # if there is last_linear in state_dict, it's going to be overwritten
    if cfg_params.get("OCR", False):
+        state_dict["aux_head.2.weight"] = model.state_dict()["aux_head.2.weight"]
+        state_dict["aux_head.2.bias"] = model.state_dict()["aux_head.2.bias"]
        state_dict["head.weight"] = model.state_dict()["head.weight"]
        state_dict["head.bias"] = model.state_dict()["head.bias"]
    else:
