|
1 | | -# 将单阶段检测器作为 RPN(待更新) |
| 1 | +# 将单阶段检测器作为 RPN |
| 2 | + |
| 3 | +候选区域网络 (Region Proposal Network, RPN) 作为 [Faster R-CNN](https://arxiv.org/abs/1506.01497) 的一个子模块,将为 Faster R-CNN 的第二阶段产生候选区域。在 MMDetection 里大多数的二阶段检测器使用 [`RPNHead`](../../../mmdet/models/dense_heads/rpn_head.py)作为候选区域网络来产生候选区域。然而,任何的单阶段检测器都可以作为候选区域网络,是因为他们对边界框的预测可以被视为是一种候选区域,并且因此能够在 R-CNN 中得到改进。因此在 MMDetection v3.0 中会支持将单阶段检测器作为 RPN 使用。 |
| 4 | + |
| 5 | +接下来我们通过一个例子,即如何在 [Faster R-CNN](../../../configs/faster_rcnn/faster-rcnn_r50_fpn_fcos-rpn_1x_coco.py) 中使用一个无锚框的单阶段的检测器模型 [FCOS](../../../configs/fcos/fcos_r50-caffe_fpn_gn-head_1x_coco.py) 作为 RPN ,详细阐述具体的全部流程。 |
| 6 | + |
| 7 | +主要流程如下: |
| 8 | + |
| 9 | +1. 在 Faster R-CNN 中使用 `FCOSHead` 作为 `RPNHead` |
| 10 | +2. 评估候选区域 |
| 11 | +3. 用预先训练的 FCOS 训练定制的 Faster R-CNN |
| 12 | + |
| 13 | +## 在 Faster R-CNN 中使用 `FCOSHead` 作为` RPNHead` |
| 14 | + |
| 15 | +为了在 Faster R-CNN 中使用 `FCOSHead` 作为 `RPNHead` ,我们应该创建一个名为 `configs/faster_rcnn/faster-rcnn_r50_fpn_fcos-rpn_1x_coco.py` 的配置文件,并且在 `configs/faster_rcnn/faster-rcnn_r50_fpn_fcos-rpn_1x_coco.py` 中将 `rpn_head` 的设置替换为 `bbox_head` 的设置,此外我们仍然使用 FCOS 的瓶颈设置,步幅为`[8,16,32,64,128]`,并且更新 `bbox_roi_extractor` 的 `featmap_stride` 为 ` [8,16,32,64,128]`。为了避免损失变慢,我们在前1000次迭代而不是前500次迭代中应用预热,这意味着 lr 增长得更慢。相关配置如下: |
| 16 | + |
| 17 | +```python |
| 18 | +_base_ = [ |
| 19 | + '../_base_/models/faster-rcnn_r50_fpn.py', |
| 20 | + '../_base_/datasets/coco_detection.py', |
| 21 | + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' |
| 22 | +] |
| 23 | +model = dict( |
| 24 | + # 从 configs/fcos/fcos_r50-caffe_fpn_gn-head_1x_coco.py 复制 |
| 25 | + neck=dict( |
| 26 | + start_level=1, |
| 27 | + add_extra_convs='on_output', # 使用 P5 |
| 28 | + relu_before_extra_convs=True), |
| 29 | + rpn_head=dict( |
| 30 | + _delete_=True, # 忽略未使用的旧设置 |
| 31 | + type='FCOSHead', |
| 32 | + num_classes=1, # 对于 rpn, num_classes = 1,如果 num_classes > 1,它将在 TwoStageDetector 中自动设置为1 |
| 33 | + in_channels=256, |
| 34 | + stacked_convs=4, |
| 35 | + feat_channels=256, |
| 36 | + strides=[8, 16, 32, 64, 128], |
| 37 | + loss_cls=dict( |
| 38 | + type='FocalLoss', |
| 39 | + use_sigmoid=True, |
| 40 | + gamma=2.0, |
| 41 | + alpha=0.25, |
| 42 | + loss_weight=1.0), |
| 43 | + loss_bbox=dict(type='IoULoss', loss_weight=1.0), |
| 44 | + loss_centerness=dict( |
| 45 | + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), |
| 46 | + roi_head=dict( # featmap_strides 的更新取决于于颈部的步伐 |
| 47 | + bbox_roi_extractor=dict(featmap_strides=[8, 16, 32, 64, 128]))) |
| 48 | +# 学习率 |
| 49 | +param_scheduler = [ |
| 50 | + dict( |
| 51 | + type='LinearLR', start_factor=0.001, by_epoch=False, begin=0, |
| 52 | + end=1000), # 慢慢增加 lr,否则损失变成 NAN |
| 53 | + dict( |
| 54 | + type='MultiStepLR', |
| 55 | + begin=0, |
| 56 | + end=12, |
| 57 | + by_epoch=True, |
| 58 | + milestones=[8, 11], |
| 59 | + gamma=0.1) |
| 60 | +] |
| 61 | +``` |
| 62 | + |
| 63 | +然后,我们可以使用下面的命令来训练我们的定制模型。更多训练命令,请参考[这里](train.md)。 |
| 64 | + |
| 65 | +```python |
| 66 | +# 使用8个 GPU 进行训练 |
| 67 | +bash |
| 68 | +tools/dist_train.sh |
| 69 | +configs/faster_rcnn/faster-rcnn_r50_fpn_fcos-rpn_1x_coco.py |
| 70 | +--work-dir /work_dirs/faster-rcnn_r50_fpn_fcos-rpn_1x_coco |
| 71 | +``` |
| 72 | + |
| 73 | +## 评估候选区域 |
| 74 | + |
| 75 | +候选区域的质量对检测器的性能有重要影响,因此,我们也提供了一种评估候选区域的方法。和上面一样创建一个新的名为 `configs/rpn/fcos-rpn_r50_fpn_1x_coco.py` 的配置文件,并且在 `configs/rpn/fcos-rpn_r50_fpn_1x_coco.py` 中将 `rpn_head` 的设置替换为 `bbox_head` 的设置。 |
| 76 | + |
| 77 | +```python |
| 78 | +_base_ = [ |
| 79 | + '../_base_/models/rpn_r50_fpn.py', '../_base_/datasets/coco_detection.py', |
| 80 | + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' |
| 81 | +] |
| 82 | +val_evaluator = dict(metric='proposal_fast') |
| 83 | +test_evaluator = val_evaluator |
| 84 | +model = dict( |
| 85 | + # 从 configs/fcos/fcos_r50-caffe_fpn_gn-head_1x_coco.py 复制 |
| 86 | + neck=dict( |
| 87 | + start_level=1, |
| 88 | + add_extra_convs='on_output', # 使用 P5 |
| 89 | + relu_before_extra_convs=True), |
| 90 | + rpn_head=dict( |
| 91 | + _delete_=True, # 忽略未使用的旧设置 |
| 92 | + type='FCOSHead', |
| 93 | + num_classes=1, # 对于 rpn, num_classes = 1,如果 num_classes >为1,它将在 rpn 中自动设置为1 |
| 94 | + in_channels=256, |
| 95 | + stacked_convs=4, |
| 96 | + feat_channels=256, |
| 97 | + strides=[8, 16, 32, 64, 128], |
| 98 | + loss_cls=dict( |
| 99 | + type='FocalLoss', |
| 100 | + use_sigmoid=True, |
| 101 | + gamma=2.0, |
| 102 | + alpha=0.25, |
| 103 | + loss_weight=1.0), |
| 104 | + loss_bbox=dict(type='IoULoss', loss_weight=1.0), |
| 105 | + loss_centerness=dict( |
| 106 | + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0))) |
| 107 | +``` |
| 108 | + |
| 109 | +假设我们在训练之后有检查点 `./work_dirs/faster-rcnn_r50_fpn_fcos-rpn_1x_coco/epoch_12.pth` ,然后,我们可以使用下面的命令来评估建议的质量。 |
| 110 | + |
| 111 | +```python |
| 112 | +# 使用8个 GPU 进行测试 |
| 113 | +bash |
| 114 | +tools/dist_test.sh |
| 115 | +configs/rpn/fcos-rpn_r50_fpn_1x_coco.py |
| 116 | +--work_dirs /faster-rcnn_r50_fpn_fcos-rpn_1x_coco/epoch_12.pth |
| 117 | +``` |
| 118 | + |
| 119 | +## 用预先训练的 FCOS 训练定制的 Faster R-CNN |
| 120 | + |
| 121 | +预训练不仅加快了训练的收敛速度,而且提高了检测器的性能。因此,我们在这里给出一个例子来说明如何使用预先训练的 FCOS 作为 RPN 来加速训练和提高精度。假设我们想在 Faster R-CNN 中使用 `FCOSHead` 作为 `rpn_head`,并加载预先训练权重来进行训练 [`fcos_r50-caffe_fpn_gn-head_1x_coco`](https://download.openmmlab.com/mmdetection/v2.0/fcos/fcos_r50_caffe_fpn_gn-head_1x_coco/fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth)。 配置文件 `configs/faster_rcnn/faster-rcnn_r50-caffe_fpn_fcos- rpn_1x_copy .py` 的内容如下所示。注意,`fcos_r50-caffe_fpn_gn-head_1x_coco` 使用 ResNet50 的 caffe 版本,因此需要更新 `data_preprocessor` 中的像素平均值和 std。 |
| 122 | + |
| 123 | +```python |
| 124 | +_base_ = [ |
| 125 | + '../_base_/models/faster-rcnn_r50_fpn.py', |
| 126 | + '../_base_/datasets/coco_detection.py', |
| 127 | + '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py' |
| 128 | +] |
| 129 | +model = dict( |
| 130 | + data_preprocessor=dict( |
| 131 | + mean=[103.530, 116.280, 123.675], |
| 132 | + std=[1.0, 1.0, 1.0], |
| 133 | + bgr_to_rgb=False), |
| 134 | + backbone=dict( |
| 135 | + norm_cfg=dict(type='BN', requires_grad=False), |
| 136 | + style='caffe', |
| 137 | + init_cfg=None), # the checkpoint in ``load_from`` contains the weights of backbone |
| 138 | + neck=dict( |
| 139 | + start_level=1, |
| 140 | + add_extra_convs='on_output', # 使用 P5 |
| 141 | + relu_before_extra_convs=True), |
| 142 | + rpn_head=dict( |
| 143 | + _delete_=True, # 忽略未使用的旧设置 |
| 144 | + type='FCOSHead', |
| 145 | + num_classes=1, # 对于 rpn, num_classes = 1,如果 num_classes > 1,它将在 TwoStageDetector 中自动设置为1 |
| 146 | + in_channels=256, |
| 147 | + stacked_convs=4, |
| 148 | + feat_channels=256, |
| 149 | + strides=[8, 16, 32, 64, 128], |
| 150 | + loss_cls=dict( |
| 151 | + type='FocalLoss', |
| 152 | + use_sigmoid=True, |
| 153 | + gamma=2.0, |
| 154 | + alpha=0.25, |
| 155 | + loss_weight=1.0), |
| 156 | + loss_bbox=dict(type='IoULoss', loss_weight=1.0), |
| 157 | + loss_centerness=dict( |
| 158 | + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)), |
| 159 | + roi_head=dict( # update featmap_strides due to the strides in neck |
| 160 | + bbox_roi_extractor=dict(featmap_strides=[8, 16, 32, 64, 128]))) |
| 161 | +load_from = 'https://download.openmmlab.com/mmdetection/v2.0/fcos/fcos_r50_caffe_fpn_gn-head_1x_coco/fcos_r50_caffe_fpn_gn-head_1x_coco-821213aa.pth' |
| 162 | +``` |
| 163 | + |
| 164 | +训练命令如下。 |
| 165 | + |
| 166 | +```python |
| 167 | +bash |
| 168 | +tools/dist_train.sh |
| 169 | +configs/faster_rcnn/faster-rcnn_r50-caffe_fpn_fcos-rpn_1x_coco.py \ |
| 170 | +--work-dir /work_dirs/faster-rcnn_r50-caffe_fpn_fcos-rpn_1x_coco |
| 171 | +``` |
0 commit comments