Skip to content

Commit 7cd57d2

Browse files
Grain Teamcopybara-github
authored andcommitted
Internal
PiperOrigin-RevId: 799280753
1 parent 3b2a332 commit 7cd57d2

File tree

4 files changed

+110
-4
lines changed

4 files changed

+110
-4
lines changed

grain/_src/python/BUILD

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ py_test(
2525
srcs = [
2626
"data_sources_test.py",
2727
],
28-
args = ["--test_srcdir=grain/_src/python"],
28+
args = ["--test_srcdir=grain/_src/python/testdata"],
2929
data = [
3030
"//grain/_src/python/testdata:digits.array_record-00000-of-00002",
3131
"//grain/_src/python/testdata:digits.array_record-00001-of-00002",
@@ -180,7 +180,7 @@ py_test(
180180
srcs = [
181181
"data_loader_test.py",
182182
],
183-
args = ["--test_srcdir=grain/_src/python"],
183+
args = ["--test_srcdir=grain/_src/python/testdata"],
184184
data = [
185185
"//grain/_src/python/testdata:digits.array_record-00000-of-00002",
186186
"//grain/_src/python/testdata:digits.array_record-00001-of-00002",

grain/_src/python/data_loader_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ class DataLoaderTest(absl_parameterized.TestCase):
165165

166166
def setUp(self):
167167
super().setUp()
168-
self.testdata_dir = pathlib.Path(FLAGS.test_srcdir) / "testdata"
168+
self.testdata_dir = pathlib.Path(FLAGS.test_srcdir)
169169
self.read_options = (
170170
options.ReadOptions(num_threads=self.num_threads_per_worker)
171171
if (self.num_threads_per_worker is not None)

grain/_src/python/data_sources.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ def __getitem__(self, record_key: SupportsIndex) -> bytes:
122122
_bytes_read_counter.IncrementBy(len(data), "ArrayRecordDataSource")
123123
return data
124124

125+
@dataset_stats.trace_input_pipeline(stage_category=dataset_stats.IPL_CAT_READ)
126+
def getitems(self, record_keys: Sequence[SupportsIndex]) -> Sequence[Any]:
127+
# TODO: Modify the implementation to use underlying
128+
# ArrayRecordReader to read multiple records at once.
129+
return [self.__getitem__(record_key) for record_key in record_keys]
130+
125131
@property
126132
def paths(self) -> ArrayRecordDataSourcePaths:
127133
return self._paths
@@ -156,6 +162,25 @@ def __getitem__(self, record_key: SupportsIndex) -> T:
156162
"""
157163

158164

165+
@typing.runtime_checkable
166+
class RandomAccessDataSourceWithBatchedRead(
167+
RandomAccessDataSource[T], Protocol, Generic[T]
168+
):
169+
"""Interface for datasources that support batched reads."""
170+
171+
def getitems(self, record_keys: Sequence[SupportsIndex]) -> Sequence[T]:
172+
"""Returns the values for the given record_keys.
173+
174+
This method must be threadsafe and deterministic.
175+
176+
Arguments:
177+
record_keys: A sequence of integers in [0, len(self)-1].
178+
179+
Returns:
180+
The sequence of corresponding records.
181+
"""
182+
183+
159184
class RangeDataSource:
160185
"""Range data source, similar to python range() function."""
161186

grain/_src/python/data_sources_test.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ def setUp(self):
4545
super().setUp()
4646
self.testdata_dir = pathlib.Path(FLAGS.test_srcdir)
4747

48+
def _to_full_paths(
49+
self, paths: str | Sequence[str]
50+
) -> pathlib.Path | Sequence[pathlib.Path]:
51+
if isinstance(paths, epath.PathLike):
52+
return self.testdata_dir / paths
53+
else:
54+
return [(self.testdata_dir / path) for path in paths]
55+
4856

4957
class RangeDataSourceTest(DataSourceTest):
5058

@@ -137,15 +145,88 @@ def test_str(self):
137145

138146
class ArrayRecordDataSourceTest(DataSourceTest):
139147

140-
def test_array_record_data_implements_random_access(self):
148+
def test_array_record_data_implements_random_access_with_batched_read(self):
141149
assert issubclass(
142150
data_sources.ArrayRecordDataSource, data_sources.RandomAccessDataSource
143151
)
152+
assert issubclass(
153+
data_sources.ArrayRecordDataSource,
154+
data_sources.RandomAccessDataSourceWithBatchedRead,
155+
)
144156

145157
def test_array_record_source_empty_sequence(self):
146158
with self.assertRaises(ValueError):
147159
data_sources.ArrayRecordDataSource([])
148160

161+
@parameterized.parameters(
162+
("digits.array_record-00000-of-00002", False),
163+
(
164+
(
165+
"digits.array_record-00000-of-00002",
166+
"digits.array_record-00001-of-00002",
167+
),
168+
True,
169+
),
170+
(["digits.array_record-00000-of-00002"], False),
171+
)
172+
def test_array_record_data_source_sequential_get(
173+
self, paths: str | Sequence[str], expect_both_shards: bool
174+
):
175+
ar_ds = data_sources.ArrayRecordDataSource(paths=self._to_full_paths(paths))
176+
expected_data = [b"0", b"1", b"2", b"3", b"4"]
177+
if expect_both_shards:
178+
expected_data += [b"5", b"6", b"7", b"8", b"9"]
179+
actual_data = [ar_ds[i] for i in range(len(ar_ds))]
180+
self.assertEqual(expected_data, actual_data)
181+
182+
@parameterized.parameters(
183+
("digits.array_record-00000-of-00002", False),
184+
(
185+
(
186+
"digits.array_record-00000-of-00002",
187+
"digits.array_record-00001-of-00002",
188+
),
189+
True,
190+
),
191+
(["digits.array_record-00000-of-00002"], False),
192+
)
193+
def test_array_record_data_source_get_all_items(
194+
self, paths: str | Sequence[str], expect_both_shards: bool
195+
):
196+
ar_ds = data_sources.ArrayRecordDataSource(paths=self._to_full_paths(paths))
197+
data = [b"0", b"1", b"2", b"3", b"4"]
198+
if expect_both_shards:
199+
data += [b"5", b"6", b"7", b"8", b"9"]
200+
201+
indices_to_read = list(range(len(ar_ds)))
202+
expected_data = [data[idx] for idx in indices_to_read]
203+
actual_data = ar_ds.getitems(indices_to_read)
204+
self.assertEqual(expected_data, actual_data)
205+
206+
@parameterized.parameters(
207+
("digits.array_record-00000-of-00002", False),
208+
(
209+
(
210+
"digits.array_record-00000-of-00002",
211+
"digits.array_record-00001-of-00002",
212+
),
213+
True,
214+
),
215+
(["digits.array_record-00000-of-00002"], False),
216+
)
217+
def test_array_record_data_source_get_multiple_items(
218+
self, paths: str | Sequence[str], expect_both_shards: bool
219+
):
220+
ar_ds = data_sources.ArrayRecordDataSource(paths=self._to_full_paths(paths))
221+
data = [b"0", b"1", b"2", b"3", b"4"]
222+
if expect_both_shards:
223+
data += [b"5", b"6", b"7", b"8", b"9"]
224+
225+
indices_to_read = list(range(len(ar_ds)))[slice(0, None, 2)]
226+
expected_data = [data[idx] for idx in indices_to_read]
227+
actual_data = ar_ds.getitems(indices_to_read)
228+
self.assertEqual(expected_data, actual_data)
229+
149230

150231
if __name__ == "__main__":
151232
absltest.main()

0 commit comments

Comments
 (0)