Commit be17775

bonlime and zakajd authored
add bit-M-ResNet (#67)
Co-authored-by: zakajd <[email protected]>
1 parent ad0975f commit be17775

File tree

3 files changed: +318 -1 lines changed


pytorch_tools/models/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -45,3 +45,10 @@
 from .hrnet import hrnet_w44
 from .hrnet import hrnet_w48
 from .hrnet import hrnet_w64
+
+from .bit_resnet import bit_m_50x1
+from .bit_resnet import bit_m_50x3
+from .bit_resnet import bit_m_101x1
+from .bit_resnet import bit_m_101x3
+from .bit_resnet import bit_m_152x2
+from .bit_resnet import bit_m_152x4
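With these re-exports in place, the new BiT-M constructors can be imported straight from the package namespace. A minimal sketch (the class count is illustrative; constructing a model downloads the pretrained .npz weights, see bit_resnet.py below):

    from pytorch_tools.models import bit_m_50x1

    model = bit_m_50x1(num_classes=10)  # builds ResNetV2(block_units=[3, 4, 6, 3], width_factor=1) and loads BiT-M-R50x1 weights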

pytorch_tools/models/bit_resnet.py

Lines changed: 310 additions & 0 deletions

@@ -0,0 +1,310 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Bottleneck ResNet v2 with GroupNorm and Weight Standardization."""
import os
import numpy as np
from copy import deepcopy
from functools import wraps
from urllib.parse import urlparse
from collections import OrderedDict  # pylint: disable=g-importing-member

import torch
import torch.nn as nn
import torch.nn.functional as F

from pytorch_tools.modules.weight_standartization import WS_Conv2d as StdConv2d


def conv3x3(cin, cout, stride=1, groups=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=3, stride=stride,
                     padding=1, bias=bias, groups=groups)


def conv1x1(cin, cout, stride=1, bias=False):
    return StdConv2d(cin, cout, kernel_size=1, stride=stride,
                     padding=0, bias=bias)


def tf2th(conv_weights):
    """Possibly convert HWIO to OIHW."""
    if conv_weights.ndim == 4:
        conv_weights = conv_weights.transpose([3, 2, 0, 1])
    return torch.from_numpy(conv_weights)
class PreActBottleneck(nn.Module):
    """Pre-activation (v2) bottleneck block.

    Follows the implementation of "Identity Mappings in Deep Residual Networks":
    https://github.com/KaimingHe/resnet-1k-layers/blob/master/resnet-pre-act.lua

    Except it puts the stride on 3x3 conv when available.
    """

    def __init__(self, cin, cout=None, cmid=None, stride=1):
        super().__init__()
        cout = cout or cin
        cmid = cmid or cout // 4

        self.gn1 = nn.GroupNorm(32, cin)
        self.conv1 = conv1x1(cin, cmid)
        self.gn2 = nn.GroupNorm(32, cmid)
        self.conv2 = conv3x3(cmid, cmid, stride)  # Original code has it on conv1!!
        self.gn3 = nn.GroupNorm(32, cmid)
        self.conv3 = conv1x1(cmid, cout)
        self.relu = nn.ReLU(inplace=True)

        if (stride != 1 or cin != cout):
            # Projection also with pre-activation according to paper.
            self.downsample = conv1x1(cin, cout, stride)

    def forward(self, x):
        out = self.relu(self.gn1(x))

        # Residual branch
        residual = x
        if hasattr(self, 'downsample'):
            residual = self.downsample(out)

        # Unit's branch
        out = self.conv1(out)
        out = self.conv2(self.relu(self.gn2(out)))
        out = self.conv3(self.relu(self.gn3(out)))

        return out + residual

    def load_from(self, weights, prefix=''):
        convname = 'standardized_conv2d'
        with torch.no_grad():
            self.conv1.weight.copy_(tf2th(weights[f'{prefix}a/{convname}/kernel']))
            self.conv2.weight.copy_(tf2th(weights[f'{prefix}b/{convname}/kernel']))
            self.conv3.weight.copy_(tf2th(weights[f'{prefix}c/{convname}/kernel']))
            self.gn1.weight.copy_(tf2th(weights[f'{prefix}a/group_norm/gamma']))
            self.gn2.weight.copy_(tf2th(weights[f'{prefix}b/group_norm/gamma']))
            self.gn3.weight.copy_(tf2th(weights[f'{prefix}c/group_norm/gamma']))
            self.gn1.bias.copy_(tf2th(weights[f'{prefix}a/group_norm/beta']))
            self.gn2.bias.copy_(tf2th(weights[f'{prefix}b/group_norm/beta']))
            self.gn3.bias.copy_(tf2th(weights[f'{prefix}c/group_norm/beta']))
            if hasattr(self, 'downsample'):
                w = weights[f'{prefix}a/proj/{convname}/kernel']
                self.downsample.weight.copy_(tf2th(w))

# These models are designed for transfer learning only, not for training from scratch
class ResNetV2(nn.Module):
    """
    Implementation of Pre-activation (v2) ResNet model.
    Used to create BiT-M-50/101/152 x1/2/3/4 models

    Args:
        num_classes (int): Number of classification classes. Defaults to 5
    """

    def __init__(
        self,
        block_units,
        width_factor,
        # in_channels=3, # TODO: add later
        num_classes=5,  # just a random number
        # encoder=False, # TODO: add later
    ):
        super().__init__()
        wf = width_factor  # shortcut 'cause we'll use it a lot.

        # The following will be unreadable if we split lines.
        # pylint: disable=line-too-long
        self.root = nn.Sequential(OrderedDict([
            ('conv', StdConv2d(3, 64*wf, kernel_size=7, stride=2, padding=3, bias=False)),
            ('pad', nn.ConstantPad2d(1, 0)),
            ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=0)),
            # The following is subtly not the same!
            # ('pool', nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        self.body = nn.Sequential(OrderedDict([
            ('block1', nn.Sequential(OrderedDict(
                [('unit01', PreActBottleneck(cin=64*wf, cout=256*wf, cmid=64*wf))] +
                [(f'unit{i:02d}', PreActBottleneck(cin=256*wf, cout=256*wf, cmid=64*wf)) for i in range(2, block_units[0] + 1)],
            ))),
            ('block2', nn.Sequential(OrderedDict(
                [('unit01', PreActBottleneck(cin=256*wf, cout=512*wf, cmid=128*wf, stride=2))] +
                [(f'unit{i:02d}', PreActBottleneck(cin=512*wf, cout=512*wf, cmid=128*wf)) for i in range(2, block_units[1] + 1)],
            ))),
            ('block3', nn.Sequential(OrderedDict(
                [('unit01', PreActBottleneck(cin=512*wf, cout=1024*wf, cmid=256*wf, stride=2))] +
                [(f'unit{i:02d}', PreActBottleneck(cin=1024*wf, cout=1024*wf, cmid=256*wf)) for i in range(2, block_units[2] + 1)],
            ))),
            ('block4', nn.Sequential(OrderedDict(
                [('unit01', PreActBottleneck(cin=1024*wf, cout=2048*wf, cmid=512*wf, stride=2))] +
                [(f'unit{i:02d}', PreActBottleneck(cin=2048*wf, cout=2048*wf, cmid=512*wf)) for i in range(2, block_units[3] + 1)],
            ))),
        ]))
        # pylint: enable=line-too-long

        self.head = nn.Sequential(OrderedDict([
            ('gn', nn.GroupNorm(32, 2048*wf)),
            ('relu', nn.ReLU(inplace=True)),
            ('avg', nn.AdaptiveAvgPool2d(output_size=1)),
            ('conv', nn.Conv2d(2048*wf, num_classes, kernel_size=1, bias=True)),
        ]))

    def features(self, x):
        return self.body(self.root(x))

    def logits(self, x):
        return self.head(x)

    def forward(self, x):
        x = self.logits(self.features(x))
        assert x.shape[-2:] == (1, 1)  # We should have no spatial shape left.
        return x[..., 0, 0]

    def load_from(self, weights, prefix='resnet/'):
        with torch.no_grad():
            self.root.conv.weight.copy_(tf2th(weights[f'{prefix}root_block/standardized_conv2d/kernel']))  # pylint: disable=line-too-long
            self.head.gn.weight.copy_(tf2th(weights[f'{prefix}group_norm/gamma']))
            self.head.gn.bias.copy_(tf2th(weights[f'{prefix}group_norm/beta']))
            # always zero_head
            nn.init.zeros_(self.head.conv.weight)
            nn.init.zeros_(self.head.conv.bias)

            for bname, block in self.body.named_children():
                for uname, unit in block.named_children():
                    unit.load_from(weights, prefix=f'{prefix}{bname}/{uname}/')

KNOWN_MODELS = OrderedDict([
    ('BiT-M-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)),
    ('BiT-M-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)),
    ('BiT-M-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)),
    ('BiT-M-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)),
    ('BiT-M-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)),
    ('BiT-M-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)),

    ('BiT-S-R50x1', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 1, *a, **kw)),
    ('BiT-S-R50x3', lambda *a, **kw: ResNetV2([3, 4, 6, 3], 3, *a, **kw)),
    ('BiT-S-R101x1', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 1, *a, **kw)),
    ('BiT-S-R101x3', lambda *a, **kw: ResNetV2([3, 4, 23, 3], 3, *a, **kw)),
    ('BiT-S-R152x2', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 2, *a, **kw)),
    ('BiT-S-R152x4', lambda *a, **kw: ResNetV2([3, 8, 36, 3], 4, *a, **kw)),
])


PRETRAIN_SETTINGS = {
    "input_space": "RGB",
    "input_size": [3, 448, 448],
    "input_range": [0, 1],
    "mean": [0.5, 0.5, 0.5],
    "std": [0.5, 0.5, 0.5],
    "num_classes": None,
}
# fmt: off
CFGS = {
    # weights are loaded by default
    "bit_m_50x1": {
        "default": {
            "params": {"block_units": [3, 4, 6, 3], "width_factor": 1},
            "url": "https://storage.googleapis.com/bit_models/BiT-M-R50x1.npz",
            **PRETRAIN_SETTINGS,
        },
    },
    "bit_m_50x3": {
        "default": {
            "params": {"block_units": [3, 4, 6, 3], "width_factor": 3},
            "url": "https://storage.googleapis.com/bit_models/BiT-M-R50x3.npz",
            **PRETRAIN_SETTINGS,
        },
    },
    "bit_m_101x1": {
        "default": {
            "params": {"block_units": [3, 4, 23, 3], "width_factor": 1},
            "url": "https://storage.googleapis.com/bit_models/BiT-M-R101x1.npz",
            **PRETRAIN_SETTINGS,
        },
    },
    "bit_m_101x3": {
        "default": {
            "params": {"block_units": [3, 4, 23, 3], "width_factor": 3},
            "url": "https://storage.googleapis.com/bit_models/BiT-M-R101x3.npz",
            **PRETRAIN_SETTINGS,
        },
    },
    "bit_m_152x2": {
        "default": {
            "params": {"block_units": [3, 8, 36, 3], "width_factor": 2},
            "url": "https://storage.googleapis.com/bit_models/BiT-M-R152x2.npz",
            **PRETRAIN_SETTINGS,
        },
    },
    "bit_m_152x4": {
        "default": {
            "params": {"block_units": [3, 8, 36, 3], "width_factor": 4},
            "url": "https://storage.googleapis.com/bit_models/BiT-M-R152x4.npz",
            **PRETRAIN_SETTINGS,
        },
    },
}
# fmt: on
def _bit_resnet(arch, pretrained=None, **kwargs):
    cfgs = deepcopy(CFGS)
    cfg_settings = cfgs[arch]["default"]
    cfg_params = cfg_settings.pop("params")
    cfg_url = cfg_settings.pop("url")
    kwargs.pop("pretrained", None)
    kwargs.update(cfg_params)
    model = ResNetV2(**kwargs)
    # `load_state_dict_from_url` is only used here to download the .npz into the torch
    # checkpoints folder; it then fails because the file is not a torch checkpoint
    try:
        torch.hub.load_state_dict_from_url(cfg_url)
    except RuntimeError:
        pass  # to avoid RuntimeError: Only one file(not dir) is allowed in the zipfile
    filename = os.path.basename(urlparse(cfg_url).path)
    torch_home = torch.hub._get_torch_home()
    cached_file = os.path.join(torch_home, 'checkpoints', filename)
    weights = np.load(cached_file)
    model.load_from(weights)
    return model
# only want M versions of models for fine-tuning
@wraps(ResNetV2)
def bit_m_50x1(**kwargs):
    return _bit_resnet("bit_m_50x1", **kwargs)

@wraps(ResNetV2)
def bit_m_50x3(**kwargs):
    return _bit_resnet("bit_m_50x3", **kwargs)

@wraps(ResNetV2)
def bit_m_101x1(**kwargs):
    return _bit_resnet("bit_m_101x1", **kwargs)

@wraps(ResNetV2)
def bit_m_101x3(**kwargs):
    return _bit_resnet("bit_m_101x3", **kwargs)

@wraps(ResNetV2)
def bit_m_152x2(**kwargs):
    return _bit_resnet("bit_m_152x2", **kwargs)

@wraps(ResNetV2)
def bit_m_152x4(**kwargs):
    return _bit_resnet("bit_m_152x4", **kwargs)
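A rough end-to-end sketch of how these constructors are meant to be used for fine-tuning (the batch size, class count, and dummy input are illustrative; the 448x448 input size and 0.5 mean/std follow PRETRAIN_SETTINGS above, and construction downloads the .npz weights on first use):

    import torch
    from pytorch_tools.models import bit_m_50x1

    model = bit_m_50x1(num_classes=10)    # downloads BiT-M-R50x1.npz and copies weights; head is zero-initialized
    images = torch.randn(2, 3, 448, 448)  # real images should be normalized with mean=std=0.5
    logits = model(images)                # -> tensor of shape (2, 10)
    loss = torch.nn.functional.cross_entropy(logits, torch.tensor([0, 1]))
    loss.backward()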

pytorch_tools/models/resnet.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@

 import torch
 import torch.nn as nn
-from torchvision.models.utils import load_state_dict_from_url
+from torch.hub import load_state_dict_from_url
 
 from pytorch_tools.modules import BasicBlock, Bottleneck
 from pytorch_tools.modules import GlobalPool2d, BlurPool
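The swap is drop-in because torch.hub ships the same helper; a minimal sketch (the URL is a placeholder):

    from torch.hub import load_state_dict_from_url

    state_dict = load_state_dict_from_url("https://example.com/checkpoint.pth", progress=True)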
