Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
_base_ = [
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py',
'./centernet_tta.py'
]

dataset_type = 'CocoDataset'
Expand Down
41 changes: 41 additions & 0 deletions configs/centernet/centernet_tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# TTA settings for CenterNet.
# NOTE: this is different from the TTA of official CenterNet.

# Config of the TTA wrapper model: NMS settings and the per-image cap on
# kept detections that are handed to ``DetTTAModel``.
tta_model = dict(
    type='DetTTAModel',
    tta_cfg=dict(
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=100,
    ),
)

# Test-time pipeline: load the image as float32, then let ``TestTimeAug``
# run the listed augmentation stages.
tta_pipeline = [
    dict(
        type='LoadImageFromFile',
        to_float32=True,
        file_client_args=dict(backend='disk'),
    ),
    dict(
        type='TestTimeAug',
        transforms=[
            # Flipped and non-flipped variants. ``RandomFlip`` has to come
            # before ``RandomCenterCropPad``; otherwise the bounding box
            # coordinates cannot be recovered correctly after flipping.
            [dict(type='RandomFlip', prob=p) for p in (1., 0.)],
            [
                dict(
                    type='RandomCenterCropPad',
                    ratios=None,
                    border=None,
                    mean=[0, 0, 0],
                    std=[1, 1, 1],
                    to_rgb=True,
                    test_mode=True,
                    test_pad_mode=['logical_or', 31],
                    test_pad_add_pix=1,
                ),
            ],
            [
                dict(
                    type='PackDetInputs',
                    meta_keys=(
                        'img_id', 'img_path', 'ori_shape', 'img_shape',
                        'flip', 'flip_direction', 'border',
                    ),
                ),
            ],
        ],
    ),
]
3 changes: 2 additions & 1 deletion configs/retinanet/retinanet_r50_fpn_1x_coco.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
_base_ = [
'../_base_/models/retinanet_r50_fpn.py',
'../_base_/datasets/coco_detection.py',
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py',
'./retinanet_tta.py'
]

# optimizer
Expand Down
23 changes: 23 additions & 0 deletions configs/retinanet/retinanet_tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Config of the TTA wrapper model: NMS settings and the per-image cap on
# kept detections that are handed to ``DetTTAModel``.
tta_model = dict(
    type='DetTTAModel',
    tta_cfg=dict(
        nms=dict(type='nms', iou_threshold=0.5),
        max_per_img=100,
    ),
)

# Scales used for the multi-scale ``Resize`` stage below.
img_scales = [(1333, 800), (666, 400), (2000, 1200)]

# Test-time pipeline: load the image, then let ``TestTimeAug`` run the
# resize, flip and packing stages.
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            # One aspect-ratio-preserving resize per entry of ``img_scales``.
            [
                dict(type='Resize', scale=scale, keep_ratio=True)
                for scale in img_scales
            ],
            # Flipped and non-flipped variants.
            [dict(type='RandomFlip', prob=p) for p in (1., 0.)],
            [
                dict(
                    type='PackDetInputs',
                    meta_keys=(
                        'img_id', 'img_path', 'ori_shape', 'img_shape',
                        'scale_factor', 'flip', 'flip_direction',
                    ),
                ),
            ],
        ],
    ),
]
2 changes: 1 addition & 1 deletion configs/rtmdet/rtmdet_l_8xb32-300e_coco.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
_base_ = [
'../_base_/default_runtime.py', '../_base_/schedules/schedule_1x.py',
'../_base_/datasets/coco_detection.py'
'../_base_/datasets/coco_detection.py', './rtmdet_tta.py'
]
model = dict(
type='RTMDet',
Expand Down
35 changes: 35 additions & 0 deletions configs/rtmdet/rtmdet_tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Config of the TTA wrapper model: NMS settings and the per-image cap on
# kept detections that are handed to ``DetTTAModel``.
tta_model = dict(
    type='DetTTAModel',
    tta_cfg=dict(
        nms=dict(type='nms', iou_threshold=0.6),
        max_per_img=100,
    ),
)

# Scales used for the multi-scale ``Resize`` stage below.
img_scales = [(640, 640), (320, 320), (960, 960)]

# Test-time pipeline: load the image, then let ``TestTimeAug`` run the
# resize, flip, pad and packing stages.
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            # One aspect-ratio-preserving resize per entry of ``img_scales``.
            [
                dict(type='Resize', scale=scale, keep_ratio=True)
                for scale in img_scales
            ],
            # Flipped and non-flipped variants. ``RandomFlip`` has to come
            # before ``Pad``; otherwise the bounding box coordinates cannot
            # be recovered correctly after flipping.
            [dict(type='RandomFlip', prob=p) for p in (1., 0.)],
            [
                dict(
                    type='Pad',
                    size=(960, 960),
                    pad_val=dict(img=(114, 114, 114)),
                ),
            ],
            [
                dict(
                    type='PackDetInputs',
                    meta_keys=(
                        'img_id', 'img_path', 'ori_shape', 'img_shape',
                        'scale_factor', 'flip', 'flip_direction',
                    ),
                ),
            ],
        ],
    ),
]
5 changes: 4 additions & 1 deletion configs/yolox/yolox_s_8xb8-300e_coco.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
_base_ = ['../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py']
_base_ = [
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py',
'./yolox_tta.py'
]

img_scale = (640, 640) # width, height

Expand Down
35 changes: 35 additions & 0 deletions configs/yolox/yolox_tta.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Config of the TTA wrapper model: NMS settings and the per-image cap on
# kept detections that are handed to ``DetTTAModel``.
tta_model = dict(
    type='DetTTAModel',
    tta_cfg=dict(
        nms=dict(type='nms', iou_threshold=0.65),
        max_per_img=100,
    ),
)

# Scales used for the multi-scale ``Resize`` stage below.
img_scales = [(640, 640), (320, 320), (960, 960)]

# Test-time pipeline: load the image, then let ``TestTimeAug`` run the
# resize, flip, pad and packing stages.
tta_pipeline = [
    dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
    dict(
        type='TestTimeAug',
        transforms=[
            # One aspect-ratio-preserving resize per entry of ``img_scales``.
            [
                dict(type='Resize', scale=scale, keep_ratio=True)
                for scale in img_scales
            ],
            # Flipped and non-flipped variants. ``RandomFlip`` has to come
            # before ``Pad``; otherwise the bounding box coordinates cannot
            # be recovered correctly after flipping.
            [dict(type='RandomFlip', prob=p) for p in (1., 0.)],
            [
                dict(
                    type='Pad',
                    pad_to_square=True,
                    pad_val=dict(img=(114.0, 114.0, 114.0)),
                ),
            ],
            [
                dict(
                    type='PackDetInputs',
                    meta_keys=(
                        'img_id', 'img_path', 'ori_shape', 'img_shape',
                        'scale_factor', 'flip', 'flip_direction',
                    ),
                ),
            ],
        ],
    ),
]
103 changes: 103 additions & 0 deletions docs/en/user_guides/test.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,3 +198,106 @@ data = dict(train_dataloader=dict(...), val_dataloader=dict(...), test_dataloade
```

Or you can set it through `--cfg-options` as `--cfg-options test_dataloader.batch_size=2`

## Test Time Augmentation (TTA)

Test time augmentation (TTA) is a data augmentation strategy used during the test phase. It applies different augmentations, such as flipping and scaling, to the same image for model inference, and then merges the predictions of each augmented image to obtain more accurate predictions. To make it easier for users to use TTA, MMEngine provides the [BaseTTAModel](https://mmengine.readthedocs.io/en/latest/api/generated/mmengine.model.BaseTTAModel.html#mmengine.model.BaseTTAModel) class, which allows users to implement different TTA strategies by simply extending the BaseTTAModel class according to their needs.

In MMDetection, we provide the [DetTTAModel](../../../mmdet/models/test_time_augs/det_tta.py) class, which inherits from BaseTTAModel.

### Use case

Using TTA requires two steps. First, you need to add `tta_model` and `tta_pipeline` in the configuration file:

```python
tta_model = dict(
type='DetTTAModel',
tta_cfg=dict(nms=dict(
type='nms',
iou_threshold=0.5),
max_per_img=100))

tta_pipeline = [
dict(type='LoadImageFromFile',
file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[[
dict(type='Resize', scale=(1333, 800), keep_ratio=True)
], [ # It uses 2 flipping transformations (flipping and not flipping).
dict(type='RandomFlip', prob=1.),
dict(type='RandomFlip', prob=0.)
], [
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape',
'img_shape', 'scale_factor', 'flip',
'flip_direction'))
]])]
```

Second, set `--tta` when running the test scripts, as in the examples below:

```shell
# Single-gpu testing
python tools/test.py \
${CONFIG_FILE} \
${CHECKPOINT_FILE} \
[--tta]

# CPU: disable GPUs and run single-gpu testing script
export CUDA_VISIBLE_DEVICES=-1
python tools/test.py \
${CONFIG_FILE} \
${CHECKPOINT_FILE} \
[--out ${RESULT_FILE}] \
[--tta]

# Multi-gpu testing
bash tools/dist_test.sh \
${CONFIG_FILE} \
${CHECKPOINT_FILE} \
${GPU_NUM} \
[--tta]
```

You can also modify the TTA config yourself, for example by adding multi-scale augmentation:

```python
tta_model = dict(
type='DetTTAModel',
tta_cfg=dict(nms=dict(
type='nms',
iou_threshold=0.5),
max_per_img=100))

img_scales = [(1333, 800), (666, 400), (2000, 1200)]
tta_pipeline = [
dict(type='LoadImageFromFile',
file_client_args=dict(backend='disk')),
dict(
type='TestTimeAug',
transforms=[[
dict(type='Resize', scale=s, keep_ratio=True) for s in img_scales
], [
dict(type='RandomFlip', prob=1.),
dict(type='RandomFlip', prob=0.)
], [
dict(
type='PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape',
'img_shape', 'scale_factor', 'flip',
'flip_direction'))
]])]
```

The above data augmentation pipeline first resizes the image to 3 different scales, and then applies 2 flipping transformations (flipping and not flipping) to each resized image, giving 3 x 2 = 6 augmented images in total. Finally, each augmented image is packed into the final result using `PackDetInputs`.

Here are more TTA use cases for your reference:

- [RetinaNet](../../../configs/retinanet/retinanet_tta.py)
- [CenterNet](../../../configs/centernet/centernet_tta.py)
- [YOLOX](../../../configs/yolox/yolox_tta.py)
- [RTMDet](../../../configs/rtmdet/rtmdet_tta.py)

For more advanced usage and data flow of TTA, please refer to [MMEngine](https://mmengine.readthedocs.io/en/latest/advanced_tutorials/test_time_augmentation.html#data-flow). We will support instance segmentation TTA later.
6 changes: 3 additions & 3 deletions mmdet/models/data_preprocessors/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from mmengine.logging import MessageHub
from mmengine.model import BaseDataPreprocessor, ImgDataPreprocessor
from mmengine.structures import PixelData
from mmengine.utils import is_list_of
from mmengine.utils import is_seq_of
from torch import Tensor

from mmdet.models.utils import unfold_wo_center
Expand Down Expand Up @@ -149,7 +149,7 @@ def _get_pad_shape(self, data: dict) -> List[tuple]:
pad_size_divisor."""
_batch_inputs = data['inputs']
# Process data with `pseudo_collate`.
if is_list_of(_batch_inputs, torch.Tensor):
if is_seq_of(_batch_inputs, torch.Tensor):
batch_pad_shape = []
for ori_input in _batch_inputs:
pad_h = int(
Expand All @@ -173,7 +173,7 @@ def _get_pad_shape(self, data: dict) -> List[tuple]:
self.pad_size_divisor)) * self.pad_size_divisor
batch_pad_shape = [(pad_h, pad_w)] * _batch_inputs.shape[0]
else:
raise TypeError('Output of `cast_data` should be a list of dict '
raise TypeError('Output of `cast_data` should be a dict '
'or a tuple with inputs and data_samples, but got'
f'{type(data)}: {data}')
return batch_pad_shape
Expand Down
3 changes: 2 additions & 1 deletion mmdet/models/test_time_augs/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .det_tta import DetTTAModel
from .merge_augs import (merge_aug_bboxes, merge_aug_masks,
merge_aug_proposals, merge_aug_results,
merge_aug_scores)

__all__ = [
'merge_aug_bboxes', 'merge_aug_masks', 'merge_aug_proposals',
'merge_aug_scores', 'merge_aug_results'
'merge_aug_scores', 'merge_aug_results', 'DetTTAModel'
]
Loading