@@ -45,6 +45,14 @@ def setUp(self):
45
45
super ().setUp ()
46
46
self .testdata_dir = pathlib .Path (FLAGS .test_srcdir )
47
47
48
+ def _to_full_paths (
49
+ self , paths : str | Sequence [str ]
50
+ ) -> pathlib .Path | Sequence [pathlib .Path ]:
51
+ if isinstance (paths , epath .PathLike ):
52
+ return self .testdata_dir / paths
53
+ else :
54
+ return [(self .testdata_dir / path ) for path in paths ]
55
+
48
56
49
57
class RangeDataSourceTest (DataSourceTest ):
50
58
@@ -137,15 +145,88 @@ def test_str(self):
137
145
138
146
class ArrayRecordDataSourceTest(DataSourceTest):
  """Tests for `data_sources.ArrayRecordDataSource` on the `digits` shards."""

  # Records the tests expect to read from each shard, in order.
  _FIRST_SHARD_RECORDS = (b"0", b"1", b"2", b"3", b"4")
  _SECOND_SHARD_RECORDS = (b"5", b"6", b"7", b"8", b"9")

  # (paths, expect_both_shards) cases shared by all read tests: a bare path,
  # a tuple holding both shards, and a single-element list.
  _PATH_CASES = (
      ("digits.array_record-00000-of-00002", False),
      (
          (
              "digits.array_record-00000-of-00002",
              "digits.array_record-00001-of-00002",
          ),
          True,
      ),
      (["digits.array_record-00000-of-00002"], False),
  )

  def _expected_records(self, expect_both_shards: bool) -> list[bytes]:
    """Returns the records expected from the first shard or from both shards."""
    records = list(self._FIRST_SHARD_RECORDS)
    if expect_both_shards:
      records.extend(self._SECOND_SHARD_RECORDS)
    return records

  def test_array_record_data_implements_random_access_with_batched_read(self):
    # assertTrue (rather than a bare `assert`) keeps the check active even
    # when Python runs with assertions disabled (-O).
    self.assertTrue(
        issubclass(
            data_sources.ArrayRecordDataSource,
            data_sources.RandomAccessDataSource,
        )
    )
    self.assertTrue(
        issubclass(
            data_sources.ArrayRecordDataSource,
            data_sources.RandomAccessDataSourceWithBatchedRead,
        )
    )

  def test_array_record_source_empty_sequence(self):
    with self.assertRaises(ValueError):
      data_sources.ArrayRecordDataSource([])

  @parameterized.parameters(*_PATH_CASES)
  def test_array_record_data_source_sequential_get(
      self, paths: str | Sequence[str], expect_both_shards: bool
  ):
    ar_ds = data_sources.ArrayRecordDataSource(
        paths=self._to_full_paths(paths)
    )
    expected_data = self._expected_records(expect_both_shards)
    # Index-based access is deliberate: this test exercises `__getitem__`
    # together with `__len__`.
    actual_data = [ar_ds[i] for i in range(len(ar_ds))]
    self.assertEqual(expected_data, actual_data)

  @parameterized.parameters(*_PATH_CASES)
  def test_array_record_data_source_get_all_items(
      self, paths: str | Sequence[str], expect_both_shards: bool
  ):
    ar_ds = data_sources.ArrayRecordDataSource(
        paths=self._to_full_paths(paths)
    )
    expected_data = self._expected_records(expect_both_shards)
    # The indices below are derived from len(ar_ds); check the length first
    # so a truncated data source cannot pass unnoticed.
    self.assertEqual(len(expected_data), len(ar_ds))

    indices_to_read = list(range(len(ar_ds)))
    actual_data = ar_ds.getitems(indices_to_read)
    self.assertEqual(expected_data, actual_data)

  @parameterized.parameters(*_PATH_CASES)
  def test_array_record_data_source_get_multiple_items(
      self, paths: str | Sequence[str], expect_both_shards: bool
  ):
    ar_ds = data_sources.ArrayRecordDataSource(
        paths=self._to_full_paths(paths)
    )
    data = self._expected_records(expect_both_shards)
    # The indices below are derived from len(ar_ds); check the length first
    # so a truncated data source cannot pass unnoticed.
    self.assertEqual(len(data), len(ar_ds))

    # Read every second record, starting from the first.
    indices_to_read = list(range(0, len(ar_ds), 2))
    expected_data = [data[idx] for idx in indices_to_read]
    actual_data = ar_ds.getitems(indices_to_read)
    self.assertEqual(expected_data, actual_data)
229
+
149
230
150
231
if __name__ == "__main__":
  # Hand control to the absl test runner when this file is run as a script.
  absltest.main()
0 commit comments