Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 28 additions & 23 deletions mmdet/models/layers/res_layer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule, Sequential
from torch import Tensor
from torch import nn as nn

from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig


class ResLayer(Sequential):
"""ResLayer to build ResNet style backbone.
Expand All @@ -24,16 +29,16 @@ class ResLayer(Sequential):
"""

def __init__(self,
block,
inplanes,
planes,
num_blocks,
stride=1,
avg_down=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
downsample_first=True,
**kwargs):
block: BaseModule,
inplanes: int,
planes: int,
num_blocks: int,
stride: int = 1,
avg_down: bool = False,
conv_cfg: OptConfigType = None,
norm_cfg: ConfigType = dict(type='BN'),
downsample_first: bool = True,
**kwargs) -> None:
self.block = block

downsample = None
Expand Down Expand Up @@ -114,18 +119,18 @@ class SimplifiedBasicBlock(BaseModule):
expansion = 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line 109: super().init(*layers)


def __init__(self,
inplanes,
planes,
stride=1,
dilation=1,
downsample=None,
style='pytorch',
with_cp=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
dcn=None,
plugins=None,
init_fg=None):
inplanes: int,
planes: int,
stride: int = 1,
dilation: int = 1,
downsample: Optional[Sequential] = None,
style: ConfigType = 'pytorch',
with_cp: bool = False,
conv_cfg: OptConfigType = None,
norm_cfg: ConfigType = dict(type='BN'),
dcn: OptConfigType = None,
plugins: OptConfigType = None,
init_fg: OptMultiConfig = None) -> None:
super(SimplifiedBasicBlock, self).__init__(init_fg)
assert dcn is None, 'Not implemented yet.'
assert plugins is None, 'Not implemented yet.'
Expand Down Expand Up @@ -168,7 +173,7 @@ def norm2(self):
"""nn.Module: normalization layer after the second convolution layer"""
return getattr(self, self.norm2_name) if self.with_norm else None

def forward(self, x):
def forward(self, x: Tensor) -> Tensor:
"""Forward function."""

identity = x
Expand Down