Skip to content

Commit 8ddb245

Browse files
committed
update
1 parent ee584f4 commit 8ddb245

File tree

4 files changed

+125
-73
lines changed

4 files changed

+125
-73
lines changed

configs/centernet/centernet_r18-dcnv2_8xb16-crop512-140e_coco.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -145,27 +145,29 @@
145145
file_client_args=dict(backend='disk')),
146146
dict(
147147
type='TestTimeAug',
148-
transforms=[[
149-
dict(type='RandomFlip', prob=1.),
150-
dict(type='RandomFlip', prob=0.)
151-
],
152-
[
153-
dict(
154-
type='RandomCenterCropPad',
155-
ratios=None,
156-
border=None,
157-
mean=[0, 0, 0],
158-
std=[1, 1, 1],
159-
to_rgb=True,
160-
test_mode=True,
161-
test_pad_mode=['logical_or', 31],
162-
test_pad_add_pix=1),
163-
],
164-
[
165-
dict(
166-
type='PackDetInputs',
167-
meta_keys=('img_id', 'img_path', 'ori_shape',
168-
'img_shape', 'flip', 'flip_direction',
169-
'border'))
170-
]])
148+
transforms=[
149+
[
150+
# ``RandomFlip`` must be placed before ``RandomCenterCropPad``
151+
dict(type='RandomFlip', prob=1.),
152+
dict(type='RandomFlip', prob=0.)
153+
],
154+
[
155+
dict(
156+
type='RandomCenterCropPad',
157+
ratios=None,
158+
border=None,
159+
mean=[0, 0, 0],
160+
std=[1, 1, 1],
161+
to_rgb=True,
162+
test_mode=True,
163+
test_pad_mode=['logical_or', 31],
164+
test_pad_add_pix=1),
165+
],
166+
[
167+
dict(
168+
type='PackDetInputs',
169+
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
170+
'flip', 'flip_direction', 'border'))
171+
]
172+
])
171173
]

configs/rtmdet/rtmdet_l_8xb32-300e_coco.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -191,25 +191,28 @@
191191
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
192192
dict(
193193
type='TestTimeAug',
194-
transforms=[[
195-
dict(type='Resize', scale=(640, 640), keep_ratio=True),
196-
dict(type='Resize', scale=(672, 672), keep_ratio=True),
197-
dict(type='Resize', scale=(608, 608), keep_ratio=True),
198-
], [
199-
dict(type='RandomFlip', prob=1.),
200-
dict(type='RandomFlip', prob=0.)
201-
],
202-
[
203-
dict(
204-
type='Pad',
205-
size=(640, 640),
206-
pad_val=dict(img=(114, 114, 114))),
207-
], [dict(type='LoadAnnotations', with_bbox=True)],
208-
[
209-
dict(
210-
type='PackDetInputs',
211-
meta_keys=('img_id', 'img_path', 'ori_shape',
212-
'img_shape', 'scale_factor', 'flip',
213-
'flip_direction'))
214-
]])
194+
transforms=[
195+
[
196+
dict(type='Resize', scale=(640, 640), keep_ratio=True),
197+
dict(type='Resize', scale=(672, 672), keep_ratio=True),
198+
dict(type='Resize', scale=(608, 608), keep_ratio=True),
199+
],
200+
[
201+
# ``RandomFlip`` must be placed before ``Pad``
202+
dict(type='RandomFlip', prob=1.),
203+
dict(type='RandomFlip', prob=0.)
204+
],
205+
[
206+
dict(
207+
type='Pad',
208+
size=(640, 640),
209+
pad_val=dict(img=(114, 114, 114))),
210+
],
211+
[
212+
dict(
213+
type='PackDetInputs',
214+
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
215+
'scale_factor', 'flip', 'flip_direction'))
216+
]
217+
])
215218
]

configs/yolox/yolox_s_8xb8-300e_coco.py

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -244,25 +244,28 @@
244244
dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
245245
dict(
246246
type='TestTimeAug',
247-
transforms=[[
248-
dict(type='Resize', scale=(416, 416), keep_ratio=True),
249-
dict(type='Resize', scale=(384, 384), keep_ratio=True),
250-
dict(type='Resize', scale=(448, 448), keep_ratio=True),
251-
], [
252-
dict(type='RandomFlip', prob=1.),
253-
dict(type='RandomFlip', prob=0.)
254-
],
255-
[
256-
dict(
257-
type='Pad',
258-
pad_to_square=True,
259-
pad_val=dict(img=(114.0, 114.0, 114.0))),
260-
],
261-
[
262-
dict(
263-
type='PackDetInputs',
264-
meta_keys=('img_id', 'img_path', 'ori_shape',
265-
'img_shape', 'scale_factor', 'flip',
266-
'flip_direction'))
267-
]])
247+
transforms=[
248+
[
249+
dict(type='Resize', scale=(416, 416), keep_ratio=True),
250+
dict(type='Resize', scale=(384, 384), keep_ratio=True),
251+
dict(type='Resize', scale=(448, 448), keep_ratio=True),
252+
],
253+
[
254+
# ``RandomFlip`` must be placed before ``Pad``
255+
dict(type='RandomFlip', prob=1.),
256+
dict(type='RandomFlip', prob=0.)
257+
],
258+
[
259+
dict(
260+
type='Pad',
261+
pad_to_square=True,
262+
pad_val=dict(img=(114.0, 114.0, 114.0))),
263+
],
264+
[
265+
dict(
266+
type='PackDetInputs',
267+
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
268+
'scale_factor', 'flip', 'flip_direction'))
269+
]
270+
])
268271
]

mmdet/models/test_time_augs/det_tta.py

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from typing import List
2+
from typing import List, Tuple
33

44
import torch
55
from mmcv.ops import batched_nms
66
from mmengine.model import BaseTTAModel
77
from mmengine.registry import MODELS
88
from mmengine.structures import InstanceData
9+
from torch import Tensor
910

1011
from mmdet.structures import DetDataSample
1112
from mmdet.structures.bbox import bbox_flip
@@ -14,13 +15,44 @@
1415
@MODELS.register_module()
1516
class DetTTAModel(BaseTTAModel):
1617
"""Merge augmented detection results, only bboxes corresponding score under
17-
flipping and multi-scale resizing can be processed now."""
18+
flipping and multi-scale resizing can be processed now.
19+
20+
Examples:
21+
>>> tta_model = dict(
22+
>>> type='DetTTAModel',
23+
>>> tta_cfg=dict(nms=dict(
24+
>>> type='nms',
25+
>>> iou_threshold=0.5),
26+
>>> max_per_img=100)
27+
>>>
28+
>>> tta_pipeline = [
29+
>>> dict(type='LoadImageFromFile',
30+
>>> file_client_args=dict(backend='disk')),
31+
>>> dict(
32+
>>> type='TestTimeAug',
33+
>>> transforms=[[
34+
>>> dict(type='Resize',
35+
>>> scale=(1333, 800),
36+
>>> keep_ratio=True),
37+
>>> ], [
38+
>>> dict(type='RandomFlip', prob=1.),
39+
>>> dict(type='RandomFlip', prob=0.)
40+
>>> ], [
41+
>>> dict(
42+
>>> type='PackDetInputs',
43+
>>> meta_keys=('img_id', 'img_path', 'ori_shape',
44+
>>> 'img_shape', 'scale_factor', 'flip',
45+
>>> 'flip_direction'))
46+
>>> ]])]
47+
"""
1848

1949
def __init__(self, tta_cfg=None, **kwargs):
2050
super().__init__(**kwargs)
2151
self.tta_cfg = tta_cfg
2252

23-
def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
53+
def merge_aug_bboxes(self, aug_bboxes: List[Tensor],
54+
aug_scores: List[Tensor],
55+
img_metas: List[str]) -> Tuple[Tensor, Tensor]:
2456
"""Merge augmented detection bboxes and scores.
2557
2658
Args:
@@ -50,20 +82,32 @@ def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
5082
return bboxes, scores
5183

5284
def merge_preds(self, data_samples_list: List[List[DetDataSample]]):
53-
"""Merge predictions of enhanced data to one prediction.
85+
"""Merge batch predictions of enhanced data.
5486
5587
Args:
56-
data_samples_list (List[List[ClsDataSample]]): List of predictions
57-
of all enhanced data.
88+
data_samples_list (List[List[DetDataSample]]): List of predictions
89+
of all enhanced data. The outer list indicates images, and the
90+
inner list corresponds to the different views of one image.
91+
Each element of the inner list is a ``DetDataSample``.
5892
Returns:
59-
List[ClsDataSample]: Merged prediction.
93+
List[DetDataSample]: Merged batch prediction.
6094
"""
6195
merged_data_samples = []
6296
for data_samples in data_samples_list:
6397
merged_data_samples.append(self._merge_single_sample(data_samples))
6498
return merged_data_samples
6599

66-
def _merge_single_sample(self, data_samples):
100+
def _merge_single_sample(
101+
self, data_samples: List[DetDataSample]) -> DetDataSample:
102+
"""Merge predictions which come form the different views of one image
103+
to one prediction.
104+
105+
Args:
106+
data_samples_list (List[DetDataSample]): List of predictions
107+
of enhanced data which come form one image.
108+
Returns:
109+
List[DetDataSample]: Merged prediction.
110+
"""
67111
aug_bboxes = []
68112
aug_scores = []
69113
aug_labels = []

0 commit comments

Comments
 (0)