17
17
18
18
from absl .testing import absltest
19
19
from absl .testing import parameterized
20
- from grain ._src .core import exceptions
21
20
from grain ._src .python .dataset import dataset
22
21
from grain ._src .python .dataset .transformations import packing_concat_then_split
23
22
from grain ._src .python .dataset .transformations import source
@@ -43,11 +42,14 @@ class ConcatThenSplitIterDatasetTest(parameterized.TestCase):
43
42
44
43
# observations will be [
45
44
# [1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5],
46
- # [6, 6, 6, 6, 6, 6], [1], [2, 2], [3, 3, 3], ...
45
+ # [6, 6, 6, 6, 6, 6], [7, 7, 7, 7, 7, 7, 7], [ 1], [2, 2], [3, 3, 3], ...
47
46
# ].
48
47
def dummy_iter_dataset (self , * , num_observations : int ) -> dataset .IterDataset :
49
48
return (
50
- source .RangeMapDataset (1 , 7 )
49
+ # On purpose, we have observations longer (length=7) than the packing
50
+ # sequence length of most test cases (6), so we can test splitting long
51
+ # features.
52
+ source .RangeMapDataset (1 , 8 )
51
53
.repeat ()
52
54
.map_with_index (
53
55
lambda index , value : {
@@ -92,16 +94,22 @@ def test_meta_features_not_restricting_when_splitting_full_length_features(
92
94
"index" : np .asarray ([5 , 6 , 0 , 0 , 0 , 0 ]),
93
95
},
94
96
{
95
- "observation" : np .asarray ([6 , 6 , 6 , 1 , 2 , 2 ]),
96
- "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 2 , 3 , 3 ]),
97
- "observation_positions" : np .asarray ([0 , 1 , 2 , 0 , 0 , 1 ]),
98
- "index" : np .asarray ([6 , 7 , 8 , 0 , 0 , 0 ]),
97
+ "observation" : np .asarray ([6 , 6 , 6 , 7 , 7 , 7 ]),
98
+ "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 2 , 2 , 2 ]),
99
+ "observation_positions" : np .asarray ([0 , 1 , 2 , 0 , 1 , 2 ]),
100
+ "index" : np .asarray ([6 , 7 , 0 , 0 , 0 , 0 ]),
101
+ },
102
+ {
103
+ "observation" : np .asarray ([7 , 7 , 7 , 7 , 1 , 2 ]),
104
+ "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 1 , 2 , 3 ]),
105
+ "observation_positions" : np .asarray ([0 , 1 , 2 , 3 , 0 , 0 ]),
106
+ "index" : np .asarray ([7 , 8 , 9 , 0 , 0 , 0 ]),
99
107
},
100
108
# Reached end.
101
109
{
102
- "observation" : np .asarray ([3 , 3 , 3 , 0 , 0 , 0 ]),
103
- "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 0 , 0 , 0 ]),
104
- "observation_positions" : np .asarray ([0 , 1 , 2 , 0 , 0 , 0 ]),
110
+ "observation" : np .asarray ([2 , 0 , 0 , 0 , 0 , 0 ]),
111
+ "observation_segment_ids" : np .asarray ([1 , 0 , 0 , 0 , 0 , 0 ]),
112
+ "observation_positions" : np .asarray ([0 , 0 , 0 , 0 , 0 , 0 ]),
105
113
"index" : np .asarray ([9 , 0 , 0 , 0 , 0 , 0 ]),
106
114
},
107
115
],
@@ -132,10 +140,10 @@ def test_meta_features_not_restricting(self):
132
140
"index" : np .asarray ([4 , 5 , 0 , 0 , 0 , 0 ]),
133
141
},
134
142
{
135
- "observation" : np .asarray ([5 , 5 , 5 , 1 , 2 , 2 ]),
136
- "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 2 , 3 , 3 ]),
137
- "observation_positions" : np .asarray ([0 , 1 , 2 , 0 , 0 , 1 ]),
138
- "index" : np .asarray ([5 , 7 , 8 , 0 , 0 , 0 ]),
143
+ "observation" : np .asarray ([5 , 5 , 5 , 7 , 7 , 7 ]),
144
+ "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 2 , 2 , 2 ]),
145
+ "observation_positions" : np .asarray ([0 , 1 , 2 , 0 , 1 , 2 ]),
146
+ "index" : np .asarray ([5 , 7 , 0 , 0 , 0 , 0 ]),
139
147
},
140
148
# Fully packed example comes without being split.
141
149
{
@@ -144,11 +152,17 @@ def test_meta_features_not_restricting(self):
144
152
"observation_positions" : np .asarray ([0 , 1 , 2 , 3 , 4 , 5 ]),
145
153
"index" : np .asarray ([6 , 0 , 0 , 0 , 0 , 0 ]),
146
154
},
155
+ {
156
+ "observation" : np .asarray ([7 , 7 , 7 , 7 , 1 , 2 ]),
157
+ "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 1 , 2 , 3 ]),
158
+ "observation_positions" : np .asarray ([0 , 1 , 2 , 3 , 0 , 0 ]),
159
+ "index" : np .asarray ([7 , 8 , 9 , 0 , 0 , 0 ]),
160
+ },
147
161
# Reached end.
148
162
{
149
- "observation" : np .asarray ([3 , 3 , 3 , 0 , 0 , 0 ]),
150
- "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 0 , 0 , 0 ]),
151
- "observation_positions" : np .asarray ([0 , 1 , 2 , 0 , 0 , 0 ]),
163
+ "observation" : np .asarray ([2 , 0 , 0 , 0 , 0 , 0 ]),
164
+ "observation_segment_ids" : np .asarray ([1 , 0 , 0 , 0 , 0 , 0 ]),
165
+ "observation_positions" : np .asarray ([0 , 0 , 0 , 0 , 0 , 0 ]),
152
166
"index" : np .asarray ([9 , 0 , 0 , 0 , 0 , 0 ]),
153
167
},
154
168
],
@@ -191,15 +205,21 @@ def test_meta_features_restricting(self):
191
205
"index" : np .asarray ([6 , 0 ]),
192
206
},
193
207
{
194
- "observation" : np .asarray ([1 , 2 , 2 , 0 , 0 , 0 ]),
195
- "observation_segment_ids" : np .asarray ([1 , 2 , 2 , 0 , 0 , 0 ]),
196
- "observation_positions" : np .asarray ([0 , 0 , 1 , 0 , 0 , 0 ]),
208
+ "observation" : np .asarray ([7 , 7 , 7 , 7 , 7 , 7 ]),
209
+ "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 1 , 1 , 1 ]),
210
+ "observation_positions" : np .asarray ([0 , 1 , 2 , 3 , 4 , 5 ]),
211
+ "index" : np .asarray ([7 , 0 ]),
212
+ },
213
+ {
214
+ "observation" : np .asarray ([7 , 1 , 0 , 0 , 0 , 0 ]),
215
+ "observation_segment_ids" : np .asarray ([1 , 2 , 0 , 0 , 0 , 0 ]),
216
+ "observation_positions" : np .asarray ([0 , 0 , 0 , 0 , 0 , 0 ]),
197
217
"index" : np .asarray ([7 , 8 ]),
198
218
},
199
219
{
200
- "observation" : np .asarray ([3 , 3 , 3 , 0 , 0 , 0 ]),
201
- "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 0 , 0 , 0 ]),
202
- "observation_positions" : np .asarray ([0 , 1 , 2 , 0 , 0 , 0 ]),
220
+ "observation" : np .asarray ([2 , 2 , 0 , 0 , 0 , 0 ]),
221
+ "observation_segment_ids" : np .asarray ([1 , 1 , 0 , 0 , 0 , 0 ]),
222
+ "observation_positions" : np .asarray ([0 , 1 , 0 , 0 , 0 , 0 ]),
203
223
"index" : np .asarray ([9 , 0 ]),
204
224
},
205
225
],
@@ -233,10 +253,10 @@ def test_replace_first_token_with_bos(self):
233
253
"index" : np .asarray ([4 , 5 , 0 , 0 , 0 , 0 ]),
234
254
},
235
255
{
236
- "observation" : np .asarray ([1000 , 5 , 5 , 1000 , 1000 , 2 ]),
237
- "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 2 , 3 , 3 ]),
238
- "observation_positions" : np .asarray ([0 , 1 , 2 , 0 , 0 , 1 ]),
239
- "index" : np .asarray ([5 , 7 , 8 , 0 , 0 , 0 ]),
256
+ "observation" : np .asarray ([1000 , 5 , 5 , 1000 , 7 , 7 ]),
257
+ "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 2 , 2 , 2 ]),
258
+ "observation_positions" : np .asarray ([0 , 1 , 2 , 0 , 1 , 2 ]),
259
+ "index" : np .asarray ([5 , 7 , 0 , 0 , 0 , 0 ]),
240
260
},
241
261
# Fully packed example comes without being split.
242
262
{
@@ -245,11 +265,17 @@ def test_replace_first_token_with_bos(self):
245
265
"observation_positions" : np .asarray ([0 , 1 , 2 , 3 , 4 , 5 ]),
246
266
"index" : np .asarray ([6 , 0 , 0 , 0 , 0 , 0 ]),
247
267
},
268
+ {
269
+ "observation" : np .asarray ([1000 , 7 , 7 , 7 , 1000 , 1000 ]),
270
+ "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 1 , 2 , 3 ]),
271
+ "observation_positions" : np .asarray ([0 , 1 , 2 , 3 , 0 , 0 ]),
272
+ "index" : np .asarray ([7 , 8 , 9 , 0 , 0 , 0 ]),
273
+ },
248
274
# Reached end.
249
275
{
250
- "observation" : np .asarray ([1000 , 3 , 3 , 0 , 0 , 0 ]),
251
- "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 0 , 0 , 0 ]),
252
- "observation_positions" : np .asarray ([0 , 1 , 2 , 0 , 0 , 0 ]),
276
+ "observation" : np .asarray ([1000 , 0 , 0 , 0 , 0 , 0 ]),
277
+ "observation_segment_ids" : np .asarray ([1 , 0 , 0 , 0 , 0 , 0 ]),
278
+ "observation_positions" : np .asarray ([0 , 0 , 0 , 0 , 0 , 0 ]),
253
279
"index" : np .asarray ([9 , 0 , 0 , 0 , 0 , 0 ]),
254
280
},
255
281
],
@@ -320,27 +346,27 @@ def _create_iter(state: dict[str, Any] | None):
320
346
"index" : np .asarray ([4 , 5 , 6 , 0 , 0 , 0 ]),
321
347
},
322
348
{
323
- "observation" : np .asarray ([6 , 6 , 6 , 6 , 6 , 1 , 2 , 2 ]),
324
- "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 1 , 1 , 2 , 3 , 3 ]),
325
- "observation_positions" : np .asarray ([0 , 1 , 2 , 3 , 4 , 0 , 0 , 1 ]),
326
- "index" : np .asarray ([6 , 7 , 8 , 0 , 0 , 0 ]),
349
+ "observation" : np .asarray ([6 , 6 , 6 , 6 , 6 , 7 , 7 , 7 ]),
350
+ "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 1 , 1 , 2 , 2 , 2 ]),
351
+ "observation_positions" : np .asarray ([0 , 1 , 2 , 3 , 4 , 0 , 1 , 2 ]),
352
+ "index" : np .asarray ([6 , 7 , 0 , 0 , 0 , 0 ]),
327
353
},
328
354
{
329
- "observation" : np .asarray ([3 , 3 , 3 , 4 , 4 , 4 , 4 , 5 ]),
330
- "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 2 , 2 , 2 , 2 , 3 ]),
331
- "observation_positions" : np .asarray ([0 , 1 , 2 , 0 , 1 , 2 , 3 , 0 ]),
332
- "index" : np .asarray ([9 , 10 , 11 , 0 , 0 , 0 ]),
355
+ "observation" : np .asarray ([7 , 7 , 7 , 7 , 1 , 2 , 2 , 3 ]),
356
+ "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 1 , 2 , 3 , 3 , 4 ]),
357
+ "observation_positions" : np .asarray ([0 , 1 , 2 , 3 , 0 , 0 , 1 , 0 ]),
358
+ "index" : np .asarray ([7 , 8 , 9 , 10 , 0 , 0 ]),
333
359
},
334
360
{
335
- "observation" : np .asarray ([5 , 5 , 5 , 5 , 6 , 6 , 6 , 6 ]),
336
- "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 1 , 2 , 2 , 2 , 2 ]),
337
- "observation_positions" : np .asarray ([0 , 1 , 2 , 3 , 0 , 1 , 2 , 3 ]),
338
- "index" : np .asarray ([11 , 12 , 0 , 0 , 0 , 0 ]),
361
+ "observation" : np .asarray ([3 , 3 , 4 , 4 , 4 , 4 , 5 , 5 ]),
362
+ "observation_segment_ids" : np .asarray ([1 , 1 , 2 , 2 , 2 , 2 , 3 , 3 ]),
363
+ "observation_positions" : np .asarray ([0 , 1 , 0 , 1 , 2 , 3 , 0 , 1 ]),
364
+ "index" : np .asarray ([10 , 11 , 12 , 0 , 0 , 0 ]),
339
365
},
340
366
{
341
- "observation" : np .asarray ([6 , 6 , 0 , 0 , 0 , 0 , 0 , 0 ]),
342
- "observation_segment_ids" : np .asarray ([1 , 1 , 0 , 0 , 0 , 0 , 0 , 0 ]),
343
- "observation_positions" : np .asarray ([0 , 1 , 0 , 0 , 0 , 0 , 0 , 0 ]),
367
+ "observation" : np .asarray ([5 , 5 , 5 , 0 , 0 , 0 , 0 , 0 ]),
368
+ "observation_segment_ids" : np .asarray ([1 , 1 , 1 , 0 , 0 , 0 , 0 , 0 ]),
369
+ "observation_positions" : np .asarray ([0 , 1 , 2 , 0 , 0 , 0 , 0 , 0 ]),
344
370
"index" : np .asarray ([12 , 0 , 0 , 0 , 0 , 0 ]),
345
371
},
346
372
],
@@ -406,34 +432,6 @@ def test_checkpointing_using_grain_built_in_tools(
406
432
)
407
433
)
408
434
409
- @parameterized .product (
410
- bos_handling = list (BOSHandling ),
411
- )
412
- def test_pack_sequence_longer_than_sequence_length (self , bos_handling ):
413
- sequence_length = 10
414
- if bos_handling == BOSHandling .REPLACE_FIRST_TOKEN_WITH_BOS :
415
- bos_token_id = 1000
416
- bos_features = {"observation" }
417
- else :
418
- bos_token_id = None
419
- bos_features = {}
420
- ds = dataset .MapDataset .source ([
421
- {"observation" : np .repeat (1 , 100 )}, # 100 > sequence_length
422
- ]).to_iter_dataset ()
423
- ds = packing_concat_then_split .ConcatThenSplitIterDataset (
424
- ds ,
425
- length_struct = {"observation" : sequence_length },
426
- split_full_length_features = False ,
427
- bos_handling = bos_handling ,
428
- bos_token_id = bos_token_id ,
429
- bos_features = bos_features ,
430
- )
431
- with self .assertRaisesWithPredicateMatch (
432
- exceptions .PyGrainInternalError ,
433
- lambda _ : "Feature 'observation' has 100 tokens" ,
434
- ):
435
- next (iter (ds ))
436
-
437
435
def assert_equal_elements (
438
436
self ,
439
437
actual_elements : list [dict [str , np .ndarray ]],
0 commit comments