Skip to content

Commit 3fb5aa3

Browse files
authored
Merge 5577f50 into aa10365
2 parents aa10365 + 5577f50 commit 3fb5aa3

File tree

17 files changed

+475
-114
lines changed

17 files changed

+475
-114
lines changed

configs/_base_/datasets/s3dis-seg.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -73,25 +73,6 @@
7373
with_seg_3d=True,
7474
backend_args=backend_args),
7575
dict(type='NormalizePointsColor', color_mean=None),
76-
dict(
77-
# a wrapper in order to successfully call test function
78-
# actually we don't perform test-time-aug
79-
type='MultiScaleFlipAug3D',
80-
img_scale=(1333, 800),
81-
pts_scale_ratio=1,
82-
flip=False,
83-
transforms=[
84-
dict(
85-
type='GlobalRotScaleTrans',
86-
rot_range=[0, 0],
87-
scale_ratio_range=[1., 1.],
88-
translation_std=[0, 0, 0]),
89-
dict(
90-
type='RandomFlip3D',
91-
sync_2d=False,
92-
flip_ratio_bev_horizontal=0.0,
93-
flip_ratio_bev_vertical=0.0),
94-
]),
9576
dict(type='Pack3DDetInputs', keys=['points'])
9677
]
9778
# construct a pipeline for data and gt loading in show function
@@ -109,6 +90,33 @@
10990
dict(type='NormalizePointsColor', color_mean=None),
11091
dict(type='Pack3DDetInputs', keys=['points'])
11192
]
93+
tta_pipeline = [
94+
dict(
95+
type='LoadPointsFromFile',
96+
coord_type='DEPTH',
97+
shift_height=False,
98+
use_color=True,
99+
load_dim=6,
100+
use_dim=[0, 1, 2, 3, 4, 5],
101+
backend_args=backend_args),
102+
dict(
103+
type='LoadAnnotations3D',
104+
with_bbox_3d=False,
105+
with_label_3d=False,
106+
with_mask_3d=False,
107+
with_seg_3d=True,
108+
backend_args=backend_args),
109+
dict(type='NormalizePointsColor', color_mean=None),
110+
dict(
111+
type='TestTimeAug',
112+
transforms=[[
113+
dict(
114+
type='RandomFlip3D',
115+
sync_2d=False,
116+
flip_ratio_bev_horizontal=0.,
117+
flip_ratio_bev_vertical=0.)
118+
], [dict(type='Pack3DDetInputs', keys=['points'])]])
119+
]
112120

113121
# train on area 1, 2, 3, 4, 6
114122
# test on area 5
@@ -157,3 +165,5 @@
157165
vis_backends = [dict(type='LocalVisBackend')]
158166
visualizer = dict(
159167
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
168+
169+
tta_model = dict(type='Seg3DTTAModel')

configs/_base_/datasets/scannet-seg.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -73,25 +73,6 @@
7373
with_seg_3d=True,
7474
backend_args=backend_args),
7575
dict(type='NormalizePointsColor', color_mean=None),
76-
dict(
77-
# a wrapper in order to successfully call test function
78-
# actually we don't perform test-time-aug
79-
type='MultiScaleFlipAug3D',
80-
img_scale=(1333, 800),
81-
pts_scale_ratio=1,
82-
flip=False,
83-
transforms=[
84-
dict(
85-
type='GlobalRotScaleTrans',
86-
rot_range=[0, 0],
87-
scale_ratio_range=[1., 1.],
88-
translation_std=[0, 0, 0]),
89-
dict(
90-
type='RandomFlip3D',
91-
sync_2d=False,
92-
flip_ratio_bev_horizontal=0.0,
93-
flip_ratio_bev_vertical=0.0),
94-
]),
9576
dict(type='Pack3DDetInputs', keys=['points'])
9677
]
9778
# construct a pipeline for data and gt loading in show function
@@ -109,6 +90,33 @@
10990
dict(type='NormalizePointsColor', color_mean=None),
11091
dict(type='Pack3DDetInputs', keys=['points'])
11192
]
93+
tta_pipeline = [
94+
dict(
95+
type='LoadPointsFromFile',
96+
coord_type='DEPTH',
97+
shift_height=False,
98+
use_color=True,
99+
load_dim=6,
100+
use_dim=[0, 1, 2, 3, 4, 5],
101+
backend_args=backend_args),
102+
dict(
103+
type='LoadAnnotations3D',
104+
with_bbox_3d=False,
105+
with_label_3d=False,
106+
with_mask_3d=False,
107+
with_seg_3d=True,
108+
backend_args=backend_args),
109+
dict(type='NormalizePointsColor', color_mean=None),
110+
dict(
111+
type='TestTimeAug',
112+
transforms=[[
113+
dict(
114+
type='RandomFlip3D',
115+
sync_2d=False,
116+
flip_ratio_bev_horizontal=0.,
117+
flip_ratio_bev_vertical=0.)
118+
], [dict(type='Pack3DDetInputs', keys=['points'])]])
119+
]
112120

113121
train_dataloader = dict(
114122
batch_size=8,
@@ -152,3 +160,5 @@
152160
vis_backends = [dict(type='LocalVisBackend')]
153161
visualizer = dict(
154162
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
163+
164+
tta_model = dict(type='Seg3DTTAModel')

configs/_base_/datasets/semantickitti.py

Lines changed: 70 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
seg_offset=2**16,
8383
dataset_type='semantickitti',
8484
backend_args=backend_args),
85-
dict(type='PointSegClassMapping', ),
85+
dict(type='PointSegClassMapping'),
8686
dict(
8787
type='RandomFlip3D',
8888
sync_2d=False,
@@ -112,12 +112,21 @@
112112
seg_offset=2**16,
113113
dataset_type='semantickitti',
114114
backend_args=backend_args),
115-
dict(type='PointSegClassMapping', ),
116-
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
115+
dict(type='PointSegClassMapping'),
116+
dict(type='Pack3DDetInputs', keys=['points'])
117117
]
118118
# construct a pipeline for data and gt loading in show function
119119
# please keep its loading function consistent with test_pipeline (e.g. client)
120120
eval_pipeline = [
121+
dict(
122+
type='LoadPointsFromFile',
123+
coord_type='LIDAR',
124+
load_dim=4,
125+
use_dim=4,
126+
backend_args=backend_args),
127+
dict(type='Pack3DDetInputs', keys=['points'])
128+
]
129+
tta_pipeline = [
121130
dict(
122131
type='LoadPointsFromFile',
123132
coord_type='LIDAR',
@@ -133,46 +142,75 @@
133142
seg_offset=2**16,
134143
dataset_type='semantickitti',
135144
backend_args=backend_args),
136-
dict(type='PointSegClassMapping', ),
137-
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
145+
dict(type='PointSegClassMapping'),
146+
dict(
147+
type='TestTimeAug',
148+
transforms=[[
149+
dict(
150+
type='RandomFlip3D',
151+
sync_2d=False,
152+
flip_ratio_bev_horizontal=0.,
153+
flip_ratio_bev_vertical=0.),
154+
dict(
155+
type='RandomFlip3D',
156+
sync_2d=False,
157+
flip_ratio_bev_horizontal=0.,
158+
flip_ratio_bev_vertical=1.),
159+
dict(
160+
type='RandomFlip3D',
161+
sync_2d=False,
162+
flip_ratio_bev_horizontal=1.,
163+
flip_ratio_bev_vertical=0.),
164+
dict(
165+
type='RandomFlip3D',
166+
sync_2d=False,
167+
flip_ratio_bev_horizontal=1.,
168+
flip_ratio_bev_vertical=1.)
169+
],
170+
[
171+
dict(
172+
type='GlobalRotScaleTrans',
173+
rot_range=[pcd_rotate_range, pcd_rotate_range],
174+
scale_ratio_range=[
175+
pcd_scale_factor, pcd_scale_factor
176+
],
177+
translation_std=[0, 0, 0])
178+
for pcd_rotate_range in [-0.78539816, 0.0, 0.78539816]
179+
for pcd_scale_factor in [0.95, 1.0, 1.05]
180+
], [dict(type='Pack3DDetInputs', keys=['points'])]])
138181
]
139182

140183
train_dataloader = dict(
141184
batch_size=2,
142185
num_workers=4,
186+
persistent_workers=True,
143187
sampler=dict(type='DefaultSampler', shuffle=True),
144188
dataset=dict(
145-
type='RepeatDataset',
146-
times=1,
147-
dataset=dict(
148-
type=dataset_type,
149-
data_root=data_root,
150-
ann_file='semantickitti_infos_train.pkl',
151-
pipeline=train_pipeline,
152-
metainfo=metainfo,
153-
modality=input_modality,
154-
ignore_index=19,
155-
backend_args=backend_args)),
156-
)
189+
type=dataset_type,
190+
data_root=data_root,
191+
ann_file='semantickitti_infos_train.pkl',
192+
pipeline=train_pipeline,
193+
metainfo=metainfo,
194+
modality=input_modality,
195+
ignore_index=19,
196+
backend_args=backend_args))
157197

158198
test_dataloader = dict(
159199
batch_size=1,
160200
num_workers=1,
201+
persistent_workers=True,
202+
drop_last=False,
161203
sampler=dict(type='DefaultSampler', shuffle=False),
162204
dataset=dict(
163-
type='RepeatDataset',
164-
times=1,
165-
dataset=dict(
166-
type=dataset_type,
167-
data_root=data_root,
168-
ann_file='semantickitti_infos_val.pkl',
169-
pipeline=test_pipeline,
170-
metainfo=metainfo,
171-
modality=input_modality,
172-
ignore_index=19,
173-
test_mode=True,
174-
backend_args=backend_args)),
175-
)
205+
type=dataset_type,
206+
data_root=data_root,
207+
ann_file='semantickitti_infos_val.pkl',
208+
pipeline=test_pipeline,
209+
metainfo=metainfo,
210+
modality=input_modality,
211+
ignore_index=19,
212+
test_mode=True,
213+
backend_args=backend_args))
176214

177215
val_dataloader = test_dataloader
178216

@@ -182,3 +220,5 @@
182220
vis_backends = [dict(type='LocalVisBackend')]
183221
visualizer = dict(
184222
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')
223+
224+
tta_model = dict(type='Seg3DTTAModel')

configs/minkunet/minkunet_w32_8xb2-15e_semantickitti.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
]
2525

2626
train_dataloader = dict(
27-
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))
27+
sampler=dict(seed=0), dataset=dict(pipeline=train_pipeline))
2828

2929
lr = 0.24
3030
optim_wrapper = dict(

configs/spvcnn/spvcnn_w32_8xb2-15e_semantickitti.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
]
2525

2626
train_dataloader = dict(
27-
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))
27+
sampler=dict(seed=0), dataset=dict(pipeline=train_pipeline))
2828

2929
lr = 0.24
3030
optim_wrapper = dict(

mmdet3d/datasets/transforms/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
from .dbsampler import DataBaseSampler
33
from .formating import Pack3DDetInputs
4-
from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D,
5-
LoadMultiViewImageFromFiles, LoadPointsFromDict,
6-
LoadPointsFromFile, LoadPointsFromMultiSweeps,
7-
MonoDet3DInferencerLoader,
4+
from .loading import (LidarDet3DInferencerLoader, LoadAnnotations3D,
5+
LoadImageFromFileMono3D, LoadMultiViewImageFromFiles,
6+
LoadPointsFromDict, LoadPointsFromFile,
7+
LoadPointsFromMultiSweeps, MonoDet3DInferencerLoader,
88
MultiModalityDet3DInferencerLoader, NormalizePointsColor,
99
PointSegClassMapping)
1010
from .test_time_aug import MultiScaleFlipAug3D

mmdet3d/models/segmentors/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,9 @@
33
from .cylinder3d import Cylinder3D
44
from .encoder_decoder import EncoderDecoder3D
55
from .minkunet import MinkUNet
6+
from .seg3d_tta import Seg3DTTAModel
67

7-
__all__ = ['Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D', 'MinkUNet']
8+
__all__ = [
9+
'Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D', 'MinkUNet',
10+
'Seg3DTTAModel'
11+
]

mmdet3d/models/segmentors/base.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -132,17 +132,12 @@ def _forward(self,
132132
"""
133133
pass
134134

135-
@abstractmethod
136-
def aug_test(self, batch_inputs, batch_data_samples):
137-
"""Placeholder for augmentation test."""
138-
pass
139-
140-
def postprocess_result(self, seg_pred_list: List[dict],
135+
def postprocess_result(self, seg_logits_list: List[Tensor],
141136
batch_data_samples: SampleList) -> SampleList:
142137
"""Convert results list to `Det3DDataSample`.
143138
144139
Args:
145-
seg_logits_list (List[dict]): List of segmentation results,
140+
seg_logits_list (List[Tensor]): List of segmentation results,
146141
seg_logits from model of each input point clouds sample.
147142
batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
148143
samples. It usually includes information such as `metainfo` and
@@ -152,12 +147,19 @@ def postprocess_result(self, seg_pred_list: List[dict],
152147
List[:obj:`Det3DDataSample`]: Segmentation results of the input
153148
points. Each Det3DDataSample usually contains:
154149
155-
- ``pred_pts_seg`` (PixelData): Prediction of 3D semantic
150+
- ``pred_pts_seg`` (PointData): Prediction of 3D semantic
156151
segmentation.
152+
- ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic
153+
segmentation before normalization.
157154
"""
158155

159-
for i in range(len(seg_pred_list)):
160-
seg_pred = seg_pred_list[i]
161-
batch_data_samples[i].set_data(
162-
{'pred_pts_seg': PointData(**{'pts_semantic_mask': seg_pred})})
156+
for i in range(len(seg_logits_list)):
157+
seg_logits = seg_logits_list[i]
158+
seg_pred = seg_logits.argmax(dim=0)
159+
batch_data_samples[i].set_data({
160+
'pts_seg_logits':
161+
PointData(**{'pts_seg_logits': seg_logits}),
162+
'pred_pts_seg':
163+
PointData(**{'pts_semantic_mask': seg_pred})
164+
})
163165
return batch_data_samples

mmdet3d/models/segmentors/cylinder3d.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,16 +127,18 @@ def predict(self,
127127
List[:obj:`Det3DDataSample`]: Segmentation results of the input
128128
points. Each Det3DDataSample usually contains:
129129
130-
- ``pred_pts_seg`` (PixelData): Prediction of 3D semantic
130+
- ``pred_pts_seg`` (PointData): Prediction of 3D semantic
131131
segmentation.
132+
- ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic
133+
segmentation before normalization.
132134
"""
133135
# 3D segmentation requires per-point prediction, so it's impossible
134136
# to use down-sampling to get a batch of scenes with same num_points
135137
# therefore, we only support testing one scene every time
136138
x = self.extract_feat(batch_inputs_dict)
137-
seg_pred_list = self.decode_head.predict(x, batch_inputs_dict,
138-
batch_data_samples)
139-
for i in range(len(seg_pred_list)):
140-
seg_pred_list[i] = seg_pred_list[i].argmax(1).cpu()
139+
seg_logits_list = self.decode_head.predict(x, batch_inputs_dict,
140+
batch_data_samples)
141+
for i in range(len(seg_logits_list)):
142+
seg_logits_list[i] = seg_logits_list[i].transpose(0, 1)
141143

142-
return self.postprocess_result(seg_pred_list, batch_data_samples)
144+
return self.postprocess_result(seg_logits_list, batch_data_samples)

0 commit comments

Comments
 (0)