Skip to content

Commit 9fe538c

Browse files
[Feature] Support Objects365 Dataset (#9600)
1 parent f8d056d commit 9fe538c

19 files changed

+1401
-26
lines changed

.dev_scripts/gather_models.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,9 @@ def get_dataset_name(config):
143143
VOCDataset='Pascal VOC',
144144
WIDERFaceDataset='WIDER Face',
145145
OpenImagesDataset='OpenImagesDataset',
146-
OpenImagesChallengeDataset='OpenImagesChallengeDataset')
146+
OpenImagesChallengeDataset='OpenImagesChallengeDataset',
147+
Objects365V1Dataset='Objects365 v1',
148+
Objects365V2Dataset='Objects365 v2')
147149
cfg = Config.fromfile('./configs/' + config)
148150
return name_map[cfg.dataset_type]
149151

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# dataset settings
2+
dataset_type = 'Objects365V1Dataset'
3+
data_root = 'data/Objects365/Obj365_v1/'
4+
5+
# file_client_args = dict(
6+
# backend='petrel',
7+
# path_mapping=dict({
8+
# './data/': 's3://openmmlab/datasets/detection/',
9+
# 'data/': 's3://openmmlab/datasets/detection/'
10+
# }))
11+
file_client_args = dict(backend='disk')
12+
13+
train_pipeline = [
14+
dict(type='LoadImageFromFile', file_client_args=file_client_args),
15+
dict(type='LoadAnnotations', with_bbox=True),
16+
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
17+
dict(type='RandomFlip', prob=0.5),
18+
dict(type='PackDetInputs')
19+
]
20+
test_pipeline = [
21+
dict(type='LoadImageFromFile', file_client_args=file_client_args),
22+
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
23+
# If you don't have a gt annotation, delete the LoadAnnotations step below
24+
dict(type='LoadAnnotations', with_bbox=True),
25+
dict(
26+
type='PackDetInputs',
27+
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
28+
'scale_factor'))
29+
]
30+
train_dataloader = dict(
31+
batch_size=2,
32+
num_workers=2,
33+
persistent_workers=True,
34+
sampler=dict(type='DefaultSampler', shuffle=True),
35+
batch_sampler=dict(type='AspectRatioBatchSampler'),
36+
dataset=dict(
37+
type=dataset_type,
38+
data_root=data_root,
39+
ann_file='annotations/objects365_train.json',
40+
data_prefix=dict(img='train/'),
41+
filter_cfg=dict(filter_empty_gt=True, min_size=32),
42+
pipeline=train_pipeline))
43+
val_dataloader = dict(
44+
batch_size=1,
45+
num_workers=2,
46+
persistent_workers=True,
47+
drop_last=False,
48+
sampler=dict(type='DefaultSampler', shuffle=False),
49+
dataset=dict(
50+
type=dataset_type,
51+
data_root=data_root,
52+
ann_file='annotations/objects365_val.json',
53+
data_prefix=dict(img='val/'),
54+
test_mode=True,
55+
pipeline=test_pipeline))
56+
test_dataloader = val_dataloader
57+
58+
val_evaluator = dict(
59+
type='CocoMetric',
60+
ann_file=data_root + 'annotations/objects365_val.json',
61+
metric='bbox',
62+
sort_categories=True,
63+
format_only=False)
64+
test_evaluator = val_evaluator
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# dataset settings
2+
dataset_type = 'Objects365V2Dataset'
3+
data_root = 'data/Objects365/Obj365_v2/'
4+
5+
# file_client_args = dict(
6+
# backend='petrel',
7+
# path_mapping=dict({
8+
# './data/': 's3://openmmlab/datasets/detection/',
9+
# 'data/': 's3://openmmlab/datasets/detection/'
10+
# }))
11+
file_client_args = dict(backend='disk')
12+
13+
train_pipeline = [
14+
dict(type='LoadImageFromFile', file_client_args=file_client_args),
15+
dict(type='LoadAnnotations', with_bbox=True),
16+
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
17+
dict(type='RandomFlip', prob=0.5),
18+
dict(type='PackDetInputs')
19+
]
20+
test_pipeline = [
21+
dict(type='LoadImageFromFile', file_client_args=file_client_args),
22+
dict(type='Resize', scale=(1333, 800), keep_ratio=True),
23+
# If you don't have a gt annotation, delete the LoadAnnotations step below
24+
dict(type='LoadAnnotations', with_bbox=True),
25+
dict(
26+
type='PackDetInputs',
27+
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
28+
'scale_factor'))
29+
]
30+
train_dataloader = dict(
31+
batch_size=2,
32+
num_workers=2,
33+
persistent_workers=True,
34+
sampler=dict(type='DefaultSampler', shuffle=True),
35+
batch_sampler=dict(type='AspectRatioBatchSampler'),
36+
dataset=dict(
37+
type=dataset_type,
38+
data_root=data_root,
39+
ann_file='annotations/zhiyuan_objv2_train.json',
40+
data_prefix=dict(img='train/'),
41+
filter_cfg=dict(filter_empty_gt=True, min_size=32),
42+
pipeline=train_pipeline))
43+
val_dataloader = dict(
44+
batch_size=1,
45+
num_workers=2,
46+
persistent_workers=True,
47+
drop_last=False,
48+
sampler=dict(type='DefaultSampler', shuffle=False),
49+
dataset=dict(
50+
type=dataset_type,
51+
data_root=data_root,
52+
ann_file='annotations/zhiyuan_objv2_val.json',
53+
data_prefix=dict(img='val/'),
54+
test_mode=True,
55+
pipeline=test_pipeline))
56+
test_dataloader = val_dataloader
57+
58+
val_evaluator = dict(
59+
type='CocoMetric',
60+
ann_file=data_root + 'annotations/zhiyuan_objv2_val.json',
61+
metric='bbox',
62+
format_only=False)
63+
test_evaluator = val_evaluator

configs/objects365/README.md

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Objects365 Dataset
2+
3+
> [Objects365 Dataset](https://openaccess.thecvf.com/content_ICCV_2019/papers/Shao_Objects365_A_Large-Scale_High-Quality_Dataset_for_Object_Detection_ICCV_2019_paper.pdf)
4+
5+
<!-- [DATASET] -->
6+
7+
## Abstract
8+
9+
<!-- [ABSTRACT] -->
10+
11+
#### Objects365 Dataset V1
12+
13+
[Objects365 Dataset V1](http://www.objects365.org/overview.html) is a brand new dataset,
14+
designed to spur object detection research with a focus on diverse objects in the Wild.
15+
It has 365 object categories over 600K training images. More than 10 million, high-quality bounding boxes are manually labeled through a three-step, carefully designed annotation pipeline. It is the largest object detection dataset (with full annotation) so far and establishes a more challenging benchmark for the community. Objects365 can serve as a better feature learning dataset for localization-sensitive tasks like object detection
16+
and semantic segmentation.
17+
18+
<!-- [IMAGE] -->
19+
20+
<div align=center>
21+
<img src="https://user-images.githubusercontent.com/48282753/208368046-b7573022-06c9-4a99-af17-a6ac7407e3d8.png" height="400"/>
22+
</div>
23+
24+
#### Objects365 Dataset V2
25+
26+
[Objects365 Dataset V2](http://www.objects365.org/overview.html) is based on the V1 release of the Objects365 dataset.
27+
Objects 365 annotated 365 object classes on more than 1800k images, with more than 29 million bounding boxes in the training set, surpassing PASCAL VOC, ImageNet, and COCO datasets.
28+
Objects 365 includes 11 categories of people, clothing, living room, bathroom, kitchen, office/medical, electrical appliances, transportation, food, animals, sports/musical instruments, and each category has dozens of subcategories.
29+
30+
## Citation
31+
32+
```
33+
@inproceedings{shao2019objects365,
34+
title={Objects365: A large-scale, high-quality dataset for object detection},
35+
author={Shao, Shuai and Li, Zeming and Zhang, Tianyuan and Peng, Chao and Yu, Gang and Zhang, Xiangyu and Li, Jing and Sun, Jian},
36+
booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
37+
pages={8430--8439},
38+
year={2019}
39+
}
40+
```
41+
42+
## Prepare Dataset
43+
44+
1. You need to download and extract the Objects365 dataset. Users can download Objects365 V2 by using `tools/misc/download_dataset.py`.
45+
46+
**Usage**
47+
48+
```shell
49+
python tools/misc/download_dataset.py --dataset-name objects365v2 \
50+
--save-dir ${SAVING PATH} \
51+
--unzip \
52+
--delete # Optional, delete the downloaded zip file
53+
```
54+
55+
**Note:** There is no download link for Objects365 V1 right now. If you would like to download Objects365-V1, please visit the [official website](http://www.objects365.org/) to contact the author.
56+
57+
2. The directory should be like this:
58+
59+
```none
60+
mmdetection
61+
├── mmdet
62+
├── tools
63+
├── configs
64+
├── data
65+
│ ├── Objects365
66+
│ │ ├── Obj365_v1
67+
│ │ │ ├── annotations
68+
│ │ │ │ ├── objects365_train.json
69+
│ │ │ │ ├── objects365_val.json
70+
│ │ │ ├── train # training images
71+
│ │ │ ├── val # validation images
72+
│ │ ├── Obj365_v2
73+
│ │ │ ├── annotations
74+
│ │ │ │ ├── zhiyuan_objv2_train.json
75+
│ │ │ │ ├── zhiyuan_objv2_val.json
76+
│ │ │ ├── train # training images
77+
│ │ │ │ ├── patch0
78+
│ │ │ │ ├── patch1
79+
│ │ │ │ ├── ...
80+
│ │ │ ├── val # validation images
81+
│ │ │ │ ├── patch0
82+
│ │ │ │ ├── patch1
83+
│ │ │ │ ├── ...
84+
```
85+
86+
## Results and Models
87+
88+
### Objects365 V1
89+
90+
| Architecture | Backbone | Style | Lr schd | Mem (GB) | box AP | Config | Download |
91+
| :----------: | :------: | :-----: | :-----: | :------: | :----: | :------------------------------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
92+
| Faster R-CNN | R-50 | pytorch | 1x | - | 19.6 | [config](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/objects365/faster-rcnn_r50_fpn_16xb4-1x_objects365v1.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v1/faster_rcnn_r50_fpn_16x4_1x_obj365v1_20221219_181226-9ff10f95.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v1/faster_rcnn_r50_fpn_16x4_1x_obj365v1_20221219_181226.log.json) |
93+
| Faster R-CNN | R-50 | pytorch | 1350K | - | 22.3 | [config](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/objects365/faster-rcnn_r50-syncbn_fpn_1350k_objects365v1.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_syncbn_1350k_obj365v1/faster_rcnn_r50_fpn_syncbn_1350k_obj365v1_20220510_142457-337d8965.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_syncbn_1350k_obj365v1/faster_rcnn_r50_fpn_syncbn_1350k_obj365v1_20220510_142457.log.json) |
94+
| Retinanet | R-50 | pytorch | 1x | - | 14.8 | [config](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/objects365/retinanet_r50_fpn_1x_objects365v1.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_1x_obj365v1/retinanet_r50_fpn_1x_obj365v1_20221219_181859-ba3e3dd5.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_1x_obj365v1/retinanet_r50_fpn_1x_obj365v1_20221219_181859.log.json) |
95+
| Retinanet | R-50 | pytorch | 1350K | - | 18.0 | [config](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/objects365/retinanet_r50-syncbn_fpn_1350k_objects365v1.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_syncbn_1350k_obj365v1/retinanet_r50_fpn_syncbn_1350k_obj365v1_20220513_111237-7517c576.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_syncbn_1350k_obj365v1/retinanet_r50_fpn_syncbn_1350k_obj365v1_20220513_111237.log.json) |
96+
97+
### Objects365 V2
98+
99+
| Architecture | Backbone | Style | Lr schd | Mem (GB) | box AP | Config | Download |
100+
| :----------: | :------: | :-----: | :-----: | :------: | :----: | :--------------------------------------------------------------------------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
101+
| Faster R-CNN | R-50 | pytorch | 1x | - | 19.8 | [config](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/objects365/faster-rcnn_r50_fpn_16xb4-1x_objects365v2.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v2/faster_rcnn_r50_fpn_16x4_1x_obj365v2_20221220_175040-5910b015.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/objects365/faster_rcnn_r50_fpn_16x4_1x_obj365v2/faster_rcnn_r50_fpn_16x4_1x_obj365v2_20221220_175040.log.json) |
102+
| Retinanet | R-50 | pytorch | 1x | - | 16.7 | [config](https://github.com/open-mmlab/mmdetection/tree/3.x/configs/objects365/retinanet_r50_fpn_1x_objects365v2.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_1x_obj365v2/retinanet_r50_fpn_1x_obj365v2_20221223_122105-d9b191f1.pth) \| [log](https://download.openmmlab.com/mmdetection/v2.0/objects365/retinanet_r50_fpn_1x_obj365v2/retinanet_r50_fpn_1x_obj365v2_20221223_122105.log.json) |
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
_base_ = [
2+
'../_base_/models/faster-rcnn_r50_fpn.py',
3+
'../_base_/datasets/objects365v2_detection.py',
4+
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
5+
]
6+
7+
model = dict(
8+
backbone=dict(norm_cfg=dict(type='SyncBN', requires_grad=True)),
9+
roi_head=dict(bbox_head=dict(num_classes=365)))
10+
11+
# training schedule for 1350K
12+
train_cfg = dict(
13+
_delete_=True,
14+
type='IterBasedTrainLoop',
15+
max_iters=1350000, # 36 epochs
16+
val_interval=150000)
17+
18+
# Using 8 GPUS while training
19+
optim_wrapper = dict(
20+
type='OptimWrapper',
21+
optimizer=dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001),
22+
clip_grad=dict(max_norm=35, norm_type=2))
23+
24+
# learning rate policy
25+
param_scheduler = [
26+
dict(
27+
type='LinearLR',
28+
start_factor=1.0 / 1000,
29+
by_epoch=False,
30+
begin=0,
31+
end=1000),
32+
dict(
33+
type='MultiStepLR',
34+
begin=0,
35+
end=1350000,
36+
by_epoch=False,
37+
milestones=[900000, 1200000],
38+
gamma=0.1)
39+
]
40+
41+
train_dataloader = dict(sampler=dict(type='InfiniteSampler'))
42+
default_hooks = dict(checkpoint=dict(by_epoch=False, interval=150000))
43+
44+
log_processor = dict(by_epoch=False)
45+
46+
# NOTE: `auto_scale_lr` is for automatically scaling LR,
47+
# USER SHOULD NOT CHANGE ITS VALUES.
48+
# base_batch_size = (8 GPUs) x (2 samples per GPU)
49+
auto_scale_lr = dict(base_batch_size=16)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
_base_ = [
2+
'../_base_/models/faster-rcnn_r50_fpn.py',
3+
'../_base_/datasets/objects365v1_detection.py',
4+
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
5+
]
6+
7+
model = dict(roi_head=dict(bbox_head=dict(num_classes=365)))
8+
9+
train_dataloader = dict(
10+
batch_size=4, # using 16 GPUS while training. total batch size is 16 x 4)
11+
)
12+
13+
# Using 16 GPUs while training
14+
optim_wrapper = dict(
15+
type='OptimWrapper',
16+
optimizer=dict(type='SGD', lr=0.08, momentum=0.9, weight_decay=0.0001),
17+
clip_grad=dict(max_norm=35, norm_type=2))
18+
19+
# learning rate
20+
param_scheduler = [
21+
dict(
22+
type='LinearLR',
23+
start_factor=1.0 / 1000,
24+
by_epoch=False,
25+
begin=0,
26+
end=1000),
27+
dict(
28+
type='MultiStepLR',
29+
begin=0,
30+
end=12,
31+
by_epoch=True,
32+
milestones=[8, 11],
33+
gamma=0.1)
34+
]
35+
36+
# NOTE: `auto_scale_lr` is for automatically scaling LR,
37+
# USER SHOULD NOT CHANGE ITS VALUES.
38+
# base_batch_size = (16 GPUs) x (4 samples per GPU)
39+
auto_scale_lr = dict(base_batch_size=64)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
_base_ = [
2+
'../_base_/models/faster-rcnn_r50_fpn.py',
3+
'../_base_/datasets/objects365v2_detection.py',
4+
'../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
5+
]
6+
7+
model = dict(roi_head=dict(bbox_head=dict(num_classes=365)))
8+
9+
train_dataloader = dict(
10+
batch_size=4, # using 16 GPUS while training. total batch size is 16 x 4)
11+
)
12+
13+
# Using 16 GPUs while training
14+
optim_wrapper = dict(
15+
type='OptimWrapper',
16+
optimizer=dict(type='SGD', lr=0.08, momentum=0.9, weight_decay=0.0001),
17+
clip_grad=dict(max_norm=35, norm_type=2))
18+
19+
# learning rate
20+
param_scheduler = [
21+
dict(
22+
type='LinearLR',
23+
start_factor=1.0 / 1000,
24+
by_epoch=False,
25+
begin=0,
26+
end=1000),
27+
dict(
28+
type='MultiStepLR',
29+
begin=0,
30+
end=12,
31+
by_epoch=True,
32+
milestones=[8, 11],
33+
gamma=0.1)
34+
]
35+
36+
# NOTE: `auto_scale_lr` is for automatically scaling LR,
37+
# USER SHOULD NOT CHANGE ITS VALUES.
38+
# base_batch_size = (16 GPUs) x (4 samples per GPU)
39+
auto_scale_lr = dict(base_batch_size=64)

0 commit comments

Comments
 (0)