Skip to content

Commit c02f694

Browse files
authored
Merge pull request #102 from roboflow/feature/dataset_split
feature/dataset_split
2 parents 83d357d + 55f4f00 commit c02f694

File tree

4 files changed

+157
-4
lines changed

4 files changed

+157
-4
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
- Added [[#100](https://github.com/roboflow/supervision/pull/100)]: support for Dataset inheritance. Current `Dataset` got renamed to `DetectionDataset` and make it inherit from `BaseDataset`.
44
- Added [[#100](https://github.com/roboflow/supervision/pull/100)]: ability to save datasets in YOLO format using `DetectionDataset.as_yolo`.
5+
- Added [[#102](https://github.com/roboflow/supervision/pull/103)]: support for splitting `DetectionDataset`.
56
- Changed [[#100](https://github.com/roboflow/supervision/pull/100)]: default value of `approximation_percentage` parameter from `0.75` to `0.0` in `DetectionDataset.as_yolo` and `DetectionDataset.as_pascal_voc`.
67

78
### 0.7.0 <small>May 11, 2023</small>

supervision/dataset/core.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from abc import ABC, abstractmethod
34
from dataclasses import dataclass
45
from pathlib import Path
56
from typing import Dict, Iterator, List, Optional, Tuple
@@ -16,14 +17,22 @@
1617
save_data_yaml,
1718
save_yolo_annotations,
1819
)
19-
from supervision.dataset.ultils import save_dataset_images
20+
from supervision.dataset.ultils import save_dataset_images, train_test_split
2021
from supervision.detection.core import Detections
2122
from supervision.file import list_files_with_extensions
2223

2324

2425
@dataclass
25-
class BaseDataset:
26-
pass
26+
class BaseDataset(ABC):
27+
@abstractmethod
28+
def __len__(self) -> int:
29+
pass
30+
31+
@abstractmethod
32+
def split(
33+
self, split_ratio=0.8, random_state=None, shuffle: bool = True
34+
) -> Tuple[BaseDataset, BaseDataset]:
35+
pass
2736

2837

2938
@dataclass
@@ -61,6 +70,36 @@ def __iter__(self) -> Iterator[Tuple[str, np.ndarray, Detections]]:
6170
for image_name, image in self.images.items():
6271
yield image_name, image, self.annotations.get(image_name, None)
6372

73+
def split(
74+
self, split_ratio=0.8, random_state=None, shuffle: bool = True
75+
) -> Tuple[DetectionDataset, DetectionDataset]:
76+
"""
77+
Splits the dataset into two parts using the provided split_ratio.
78+
79+
Returns:
80+
Tuple[DetectionDataset, DetectionDataset]: The split datasets.
81+
"""
82+
83+
image_names = list(self.images.keys())
84+
train_names, test_names = train_test_split(
85+
data=image_names,
86+
train_ratio=split_ratio,
87+
random_state=random_state,
88+
shuffle=shuffle,
89+
)
90+
91+
train_dataset = DetectionDataset(
92+
classes=self.classes,
93+
images={name: self.images[name] for name in train_names},
94+
annotations={name: self.annotations[name] for name in train_names},
95+
)
96+
test_dataset = DetectionDataset(
97+
classes=self.classes,
98+
images={name: self.images[name] for name in test_names},
99+
annotations={name: self.annotations[name] for name in test_names},
100+
)
101+
return train_dataset, test_dataset
102+
64103
def as_pascal_voc(
65104
self,
66105
images_directory_path: Optional[str] = None,

supervision/dataset/ultils.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
2+
import random
23
from pathlib import Path
3-
from typing import Dict, List
4+
from typing import Dict, List, Optional, Tuple, TypeVar
45

56
import cv2
67
import numpy as np
@@ -11,6 +12,8 @@
1112
mask_to_polygons,
1213
)
1314

15+
T = TypeVar("T")
16+
1417

1518
def approximate_mask_with_polygons(
1619
mask: np.ndarray,
@@ -48,3 +51,31 @@ def save_dataset_images(
4851
for image_name, image in images.items():
4952
target_image_path = os.path.join(images_directory_path, image_name)
5053
cv2.imwrite(target_image_path, image)
54+
55+
56+
def train_test_split(
57+
data: List[T],
58+
train_ratio: float = 0.8,
59+
random_state: Optional[int] = None,
60+
shuffle: bool = True,
61+
) -> Tuple[List[T], List[T]]:
62+
"""
63+
Splits the data into two parts using the provided train_ratio.
64+
65+
Args:
66+
data (List[T]): The data to split.
67+
train_ratio (float): The ratio of the training set to the entire dataset.
68+
random_state (Optional[int]): The seed for the random number generator.
69+
shuffle (bool): Whether to shuffle the data before splitting.
70+
71+
Returns:
72+
Tuple[List[T], List[T]]: The split data.
73+
"""
74+
if random_state is not None:
75+
random.seed(random_state)
76+
77+
if shuffle:
78+
random.shuffle(data)
79+
80+
split_index = int(len(data) * train_ratio)
81+
return data[:split_index], data[split_index:]

test/dataset/test_utils.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from typing import List, TypeVar, Optional, Tuple
2+
from contextlib import ExitStack as DoesNotRaise
3+
4+
import pytest
5+
6+
from supervision.dataset.ultils import train_test_split
7+
8+
T = TypeVar("T")
9+
10+
11+
@pytest.mark.parametrize(
12+
'data, train_ratio, random_state, shuffle, expected_result, exception',
13+
[
14+
(
15+
[],
16+
0.5,
17+
None,
18+
False,
19+
([], []),
20+
DoesNotRaise()
21+
), # empty data
22+
(
23+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
24+
0.5,
25+
None,
26+
False,
27+
([0, 1, 2, 3, 4], [5, 6, 7, 8, 9]),
28+
DoesNotRaise()
29+
), # data with 10 numbers and 50% train split
30+
(
31+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
32+
1.0,
33+
None,
34+
False,
35+
([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], []),
36+
DoesNotRaise()
37+
), # data with 10 numbers and 100% train split
38+
(
39+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
40+
0.0,
41+
None,
42+
False,
43+
([], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
44+
DoesNotRaise()
45+
), # data with 10 numbers and 0% train split
46+
(
47+
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j'],
48+
0.5,
49+
None,
50+
False,
51+
(['a', 'b', 'c', 'd', 'e'], ['f', 'g', 'h', 'i', 'j']),
52+
DoesNotRaise()
53+
), # data with 10 chars and 50% train split
54+
(
55+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
56+
0.5,
57+
23,
58+
True,
59+
([7, 8, 5, 6, 3], [2, 9, 0, 1, 4]),
60+
DoesNotRaise()
61+
), # data with 10 numbers and 50% train split with 23 random seed
62+
(
63+
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
64+
0.5,
65+
32,
66+
True,
67+
([4, 6, 0, 8, 9], [5, 7, 2, 3, 1]),
68+
DoesNotRaise()
69+
), # data with 10 numbers and 50% train split with 23 random seed
70+
]
71+
)
72+
def test_train_test_split(
73+
data: List[T],
74+
train_ratio: float,
75+
random_state: int,
76+
shuffle: bool,
77+
expected_result: Optional[Tuple[List[T], List[T]]],
78+
exception: Exception
79+
) -> None:
80+
with exception:
81+
result = train_test_split(data=data, train_ratio=train_ratio, random_state=random_state, shuffle=shuffle)
82+
assert result == expected_result

0 commit comments

Comments
 (0)