Commit ed0771e

marcenacp authored and copybara-github committed
Internal change.
PiperOrigin-RevId: 784102873
1 parent d81f1e0 commit ed0771e

3 files changed: +88 -85 lines changed

grain/_src/python/dataset/transformations/BUILD

Lines changed: 0 additions & 1 deletion
@@ -285,7 +285,6 @@ py_test(
     srcs_version = "PY3",
     deps = [
         ":packing_concat_then_split",
-        "//grain/_src/core:exceptions",
         "//grain/_src/python/dataset",
         "//grain/_src/python/testing:experimental",
         "@abseil-py//absl/testing:absltest",

grain/_src/python/dataset/transformations/packing_concat_then_split.py

Lines changed: 17 additions & 11 deletions
@@ -140,8 +140,10 @@ class _CtsElement:
     parent_state: The state of the parent iterator *before* __next__() was
       called.
     features: Features as returned by calling __next__() on the parent iterator.
-    slices: If set then maps the feature name to the `slice` object for the
-      split features.
+    slices: Maps the feature name to a tuple (start, stop) representing the
+      slice of the feature to use (in case this element represents a partial
+      element resulting from a split). A slice of (-1, -1) represents the whole
+      feature.
   """
 
   parent_state: dict[str, Any]
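
For illustration, a minimal sketch of the slice convention this hunk documents. `_EMPTY_SLICE` is the sentinel used later in this change; `resolve_slice` is a hypothetical stand-in for the real `get_sliced_features` logic:

_EMPTY_SLICE = (-1, -1)  # sentinel: the element was never split

def resolve_slice(feature, feature_slice):
  # (-1, -1) selects the whole feature; any other (start, stop)
  # selects feature[start:stop] in the original feature's coordinates.
  if feature_slice == _EMPTY_SLICE:
    return feature
  start, stop = feature_slice
  return feature[start:stop]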
@@ -151,7 +153,15 @@ class _CtsElement:
   def split(
       self, split_points: Mapping[str, int]
   ) -> tuple[_CtsElement | None, _CtsElement]:
-    """Splits the element into two elements."""
+    """Splits the element into two elements.
+
+    Args:
+      split_points: A mapping from feature name to the desired split index.
+
+    Returns:
+      The left and right elements. If the element is not split, returns None
+      for the left element and the original element for the right element.
+    """
     # We split at the very beginning.
     if all(x == 0 for x in split_points.values()):
       return None, self
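
A self-contained analogue of the documented contract, reduced to one list-valued feature (the real method operates on a whole _CtsElement and its parent state):

def split_feature(tokens, split_point):
  # Splitting at 0 is a no-op: no left element, original element on the right.
  if split_point == 0:
    return None, tokens
  return tokens[:split_point], tokens[split_point:]

left, right = split_feature([7] * 7, 6)
assert left == [7] * 6 and right == [7]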
@@ -256,15 +266,8 @@ def _has_full_length_feature(self, element: _CtsElement) -> bool:
     for key, target_sequence_length in self._config.length_struct.items():
       feature = element.get_sliced_features(key)
       sequence_length = 1 if np.ndim(feature) == 0 else len(feature)
-      if sequence_length < target_sequence_length:
-        continue
       if sequence_length == target_sequence_length:
         return True
-      if sequence_length > target_sequence_length:
-        raise exceptions.PyGrainInternalError(
-            f"Feature '{key}' has {sequence_length} tokens but target length is"
-            f" only {target_sequence_length}. The element should be split."
-        )
     return False
 
   def _pack_elements(
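
The error branch is deleted because over-long features are now split before this check runs, so the method reduces to an exact-length test. A minimal sketch, assuming list-valued features:

def has_full_length_feature(features, length_struct):
  # True iff some feature exactly fills its target length. Longer
  # features can no longer reach this point: they are split upstream.
  return any(
      len(features[key]) == target_length
      for key, target_length in length_struct.items()
  )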
@@ -371,7 +374,10 @@ def _maybe_add_to_buffer(
       else:
         if sequence_length > available_tokens:
           needs_splitting = True
-          split_points[key] = available_tokens
+          start_index = 0
+          if element.slices[key] != _EMPTY_SLICE:
+            start_index = element.slices[key][0]
+          split_points[key] = start_index + available_tokens
           new_tokens_in_buffer[key] = available_tokens
         else:
           # No splitting.
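
Split points are now expressed in the coordinates of the original, unsliced feature, so an element that is itself the remainder of an earlier split must offset by its slice start. A small worked example with illustrative values:

_EMPTY_SLICE = (-1, -1)

# Remainder of an earlier split: tokens [4:10) of a 10-token feature.
element_slice = (4, 10)
available_tokens = 3  # room left in the packing buffer

start_index = 0 if element_slice == _EMPTY_SLICE else element_slice[0]
split_point = start_index + available_tokens
assert split_point == 7  # left piece covers [4:7), remainder covers [7:10)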

grain/_src/python/dataset/transformations/packing_concat_then_split_test.py

Lines changed: 71 additions & 73 deletions
@@ -17,7 +17,6 @@
 
 from absl.testing import absltest
 from absl.testing import parameterized
-from grain._src.core import exceptions
 from grain._src.python.dataset import dataset
 from grain._src.python.dataset.transformations import packing_concat_then_split
 from grain._src.python.dataset.transformations import source
@@ -43,11 +42,14 @@ class ConcatThenSplitIterDatasetTest(parameterized.TestCase):
 
   # observations will be [
   #   [1], [2, 2], [3, 3, 3], [4, 4, 4, 4], [5, 5, 5, 5, 5],
-  #   [6, 6, 6, 6, 6, 6], [1], [2, 2], [3, 3, 3], ...
+  #   [6, 6, 6, 6, 6, 6], [7, 7, 7, 7, 7, 7, 7], [1], [2, 2], [3, 3, 3], ...
   # ].
   def dummy_iter_dataset(self, *, num_observations: int) -> dataset.IterDataset:
     return (
-        source.RangeMapDataset(1, 7)
+        # On purpose, we have observations longer (length=7) than the packing
+        # sequence length of most test cases (6), so we can test splitting long
+        # features.
+        source.RangeMapDataset(1, 8)
         .repeat()
         .map_with_index(
             lambda index, value: {
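
A plain-Python analogue of this dummy dataset; the exact values of the `index` feature come from the elided lambda body, so the numbering below is illustrative:

import numpy as np

def dummy_observations(num_observations):
  # Values 1..7 repeat forever; observation k holds value k repeated k
  # times, so the length-7 observation exceeds the packing sequence
  # length of 6 used by most test cases and must be split.
  for index in range(num_observations):
    value = index % 7 + 1
    yield {"observation": np.repeat(value, value), "index": index}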
@@ -92,16 +94,22 @@ def test_meta_features_not_restricting_when_splitting_full_length_features(
             "index": np.asarray([5, 6, 0, 0, 0, 0]),
         },
         {
-            "observation": np.asarray([6, 6, 6, 1, 2, 2]),
-            "observation_segment_ids": np.asarray([1, 1, 1, 2, 3, 3]),
-            "observation_positions": np.asarray([0, 1, 2, 0, 0, 1]),
-            "index": np.asarray([6, 7, 8, 0, 0, 0]),
+            "observation": np.asarray([6, 6, 6, 7, 7, 7]),
+            "observation_segment_ids": np.asarray([1, 1, 1, 2, 2, 2]),
+            "observation_positions": np.asarray([0, 1, 2, 0, 1, 2]),
+            "index": np.asarray([6, 7, 0, 0, 0, 0]),
+        },
+        {
+            "observation": np.asarray([7, 7, 7, 7, 1, 2]),
+            "observation_segment_ids": np.asarray([1, 1, 1, 1, 2, 3]),
+            "observation_positions": np.asarray([0, 1, 2, 3, 0, 0]),
+            "index": np.asarray([7, 8, 9, 0, 0, 0]),
         },
         # Reached end.
         {
-            "observation": np.asarray([3, 3, 3, 0, 0, 0]),
-            "observation_segment_ids": np.asarray([1, 1, 1, 0, 0, 0]),
-            "observation_positions": np.asarray([0, 1, 2, 0, 0, 0]),
+            "observation": np.asarray([2, 0, 0, 0, 0, 0]),
+            "observation_segment_ids": np.asarray([1, 0, 0, 0, 0, 0]),
+            "observation_positions": np.asarray([0, 0, 0, 0, 0, 0]),
             "index": np.asarray([9, 0, 0, 0, 0, 0]),
         },
     ],
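
For reference, the transformation under test is constructed roughly as in the removed test at the bottom of this diff; arguments beyond `length_struct` (BOS handling, meta features) vary per test case:

from grain._src.python.dataset.transformations import packing_concat_then_split

ds = packing_concat_then_split.ConcatThenSplitIterDataset(
    dummy_ds,  # the dataset built by dummy_iter_dataset() above
    length_struct={"observation": 6},
)
packed = list(ds)  # dicts shaped like the expected elements above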
@@ -132,10 +140,10 @@ def test_meta_features_not_restricting(self):
             "index": np.asarray([4, 5, 0, 0, 0, 0]),
         },
         {
-            "observation": np.asarray([5, 5, 5, 1, 2, 2]),
-            "observation_segment_ids": np.asarray([1, 1, 1, 2, 3, 3]),
-            "observation_positions": np.asarray([0, 1, 2, 0, 0, 1]),
-            "index": np.asarray([5, 7, 8, 0, 0, 0]),
+            "observation": np.asarray([5, 5, 5, 7, 7, 7]),
+            "observation_segment_ids": np.asarray([1, 1, 1, 2, 2, 2]),
+            "observation_positions": np.asarray([0, 1, 2, 0, 1, 2]),
+            "index": np.asarray([5, 7, 0, 0, 0, 0]),
         },
         # Fully packed example comes without being split.
         {
@@ -144,11 +152,17 @@ def test_meta_features_not_restricting(self):
             "observation_positions": np.asarray([0, 1, 2, 3, 4, 5]),
             "index": np.asarray([6, 0, 0, 0, 0, 0]),
         },
+        {
+            "observation": np.asarray([7, 7, 7, 7, 1, 2]),
+            "observation_segment_ids": np.asarray([1, 1, 1, 1, 2, 3]),
+            "observation_positions": np.asarray([0, 1, 2, 3, 0, 0]),
+            "index": np.asarray([7, 8, 9, 0, 0, 0]),
+        },
         # Reached end.
         {
-            "observation": np.asarray([3, 3, 3, 0, 0, 0]),
-            "observation_segment_ids": np.asarray([1, 1, 1, 0, 0, 0]),
-            "observation_positions": np.asarray([0, 1, 2, 0, 0, 0]),
+            "observation": np.asarray([2, 0, 0, 0, 0, 0]),
+            "observation_segment_ids": np.asarray([1, 0, 0, 0, 0, 0]),
+            "observation_positions": np.asarray([0, 0, 0, 0, 0, 0]),
             "index": np.asarray([9, 0, 0, 0, 0, 0]),
         },
     ],
@@ -191,15 +205,21 @@ def test_meta_features_restricting(self):
             "index": np.asarray([6, 0]),
         },
         {
-            "observation": np.asarray([1, 2, 2, 0, 0, 0]),
-            "observation_segment_ids": np.asarray([1, 2, 2, 0, 0, 0]),
-            "observation_positions": np.asarray([0, 0, 1, 0, 0, 0]),
+            "observation": np.asarray([7, 7, 7, 7, 7, 7]),
+            "observation_segment_ids": np.asarray([1, 1, 1, 1, 1, 1]),
+            "observation_positions": np.asarray([0, 1, 2, 3, 4, 5]),
+            "index": np.asarray([7, 0]),
+        },
+        {
+            "observation": np.asarray([7, 1, 0, 0, 0, 0]),
+            "observation_segment_ids": np.asarray([1, 2, 0, 0, 0, 0]),
+            "observation_positions": np.asarray([0, 0, 0, 0, 0, 0]),
             "index": np.asarray([7, 8]),
         },
         {
-            "observation": np.asarray([3, 3, 3, 0, 0, 0]),
-            "observation_segment_ids": np.asarray([1, 1, 1, 0, 0, 0]),
-            "observation_positions": np.asarray([0, 1, 2, 0, 0, 0]),
+            "observation": np.asarray([2, 2, 0, 0, 0, 0]),
+            "observation_segment_ids": np.asarray([1, 1, 0, 0, 0, 0]),
+            "observation_positions": np.asarray([0, 1, 0, 0, 0, 0]),
             "index": np.asarray([9, 0]),
         },
     ],
@@ -233,10 +253,10 @@ def test_replace_first_token_with_bos(self):
             "index": np.asarray([4, 5, 0, 0, 0, 0]),
         },
         {
-            "observation": np.asarray([1000, 5, 5, 1000, 1000, 2]),
-            "observation_segment_ids": np.asarray([1, 1, 1, 2, 3, 3]),
-            "observation_positions": np.asarray([0, 1, 2, 0, 0, 1]),
-            "index": np.asarray([5, 7, 8, 0, 0, 0]),
+            "observation": np.asarray([1000, 5, 5, 1000, 7, 7]),
+            "observation_segment_ids": np.asarray([1, 1, 1, 2, 2, 2]),
+            "observation_positions": np.asarray([0, 1, 2, 0, 1, 2]),
+            "index": np.asarray([5, 7, 0, 0, 0, 0]),
         },
         # Fully packed example comes without being split.
         {
@@ -245,11 +265,17 @@ def test_replace_first_token_with_bos(self):
             "observation_positions": np.asarray([0, 1, 2, 3, 4, 5]),
             "index": np.asarray([6, 0, 0, 0, 0, 0]),
         },
+        {
+            "observation": np.asarray([1000, 7, 7, 7, 1000, 1000]),
+            "observation_segment_ids": np.asarray([1, 1, 1, 1, 2, 3]),
+            "observation_positions": np.asarray([0, 1, 2, 3, 0, 0]),
+            "index": np.asarray([7, 8, 9, 0, 0, 0]),
+        },
         # Reached end.
         {
-            "observation": np.asarray([1000, 3, 3, 0, 0, 0]),
-            "observation_segment_ids": np.asarray([1, 1, 1, 0, 0, 0]),
-            "observation_positions": np.asarray([0, 1, 2, 0, 0, 0]),
+            "observation": np.asarray([1000, 0, 0, 0, 0, 0]),
+            "observation_segment_ids": np.asarray([1, 0, 0, 0, 0, 0]),
+            "observation_positions": np.asarray([0, 0, 0, 0, 0, 0]),
             "index": np.asarray([9, 0, 0, 0, 0, 0]),
         },
     ],
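
All the BOS expectations follow one rule: with REPLACE_FIRST_TOKEN_WITH_BOS, the first token of every packed segment is overwritten with bos_token_id (1000 in this test). A self-contained sketch of that rule:

import numpy as np

def replace_first_token_with_bos(observation, segment_ids, bos_token_id=1000):
  out = observation.copy()
  for seg in np.unique(segment_ids):
    if seg == 0:  # 0 marks padding, which gets no BOS token
      continue
    # The first position belonging to this segment gets the BOS token.
    out[np.argmax(segment_ids == seg)] = bos_token_id
  return out

obs = np.asarray([5, 5, 5, 7, 7, 7])
segs = np.asarray([1, 1, 1, 2, 2, 2])
print(replace_first_token_with_bos(obs, segs))  # [1000 5 5 1000 7 7]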
@@ -320,27 +346,27 @@ def _create_iter(state: dict[str, Any] | None):
             "index": np.asarray([4, 5, 6, 0, 0, 0]),
         },
         {
-            "observation": np.asarray([6, 6, 6, 6, 6, 1, 2, 2]),
-            "observation_segment_ids": np.asarray([1, 1, 1, 1, 1, 2, 3, 3]),
-            "observation_positions": np.asarray([0, 1, 2, 3, 4, 0, 0, 1]),
-            "index": np.asarray([6, 7, 8, 0, 0, 0]),
+            "observation": np.asarray([6, 6, 6, 6, 6, 7, 7, 7]),
+            "observation_segment_ids": np.asarray([1, 1, 1, 1, 1, 2, 2, 2]),
+            "observation_positions": np.asarray([0, 1, 2, 3, 4, 0, 1, 2]),
+            "index": np.asarray([6, 7, 0, 0, 0, 0]),
         },
         {
-            "observation": np.asarray([3, 3, 3, 4, 4, 4, 4, 5]),
-            "observation_segment_ids": np.asarray([1, 1, 1, 2, 2, 2, 2, 3]),
-            "observation_positions": np.asarray([0, 1, 2, 0, 1, 2, 3, 0]),
-            "index": np.asarray([9, 10, 11, 0, 0, 0]),
+            "observation": np.asarray([7, 7, 7, 7, 1, 2, 2, 3]),
+            "observation_segment_ids": np.asarray([1, 1, 1, 1, 2, 3, 3, 4]),
+            "observation_positions": np.asarray([0, 1, 2, 3, 0, 0, 1, 0]),
+            "index": np.asarray([7, 8, 9, 10, 0, 0]),
         },
         {
-            "observation": np.asarray([5, 5, 5, 5, 6, 6, 6, 6]),
-            "observation_segment_ids": np.asarray([1, 1, 1, 1, 2, 2, 2, 2]),
-            "observation_positions": np.asarray([0, 1, 2, 3, 0, 1, 2, 3]),
-            "index": np.asarray([11, 12, 0, 0, 0, 0]),
+            "observation": np.asarray([3, 3, 4, 4, 4, 4, 5, 5]),
+            "observation_segment_ids": np.asarray([1, 1, 2, 2, 2, 2, 3, 3]),
+            "observation_positions": np.asarray([0, 1, 0, 1, 2, 3, 0, 1]),
+            "index": np.asarray([10, 11, 12, 0, 0, 0]),
         },
         {
-            "observation": np.asarray([6, 6, 0, 0, 0, 0, 0, 0]),
-            "observation_segment_ids": np.asarray([1, 1, 0, 0, 0, 0, 0, 0]),
-            "observation_positions": np.asarray([0, 1, 0, 0, 0, 0, 0, 0]),
+            "observation": np.asarray([5, 5, 5, 0, 0, 0, 0, 0]),
+            "observation_segment_ids": np.asarray([1, 1, 1, 0, 0, 0, 0, 0]),
+            "observation_positions": np.asarray([0, 1, 2, 0, 0, 0, 0, 0]),
             "index": np.asarray([12, 0, 0, 0, 0, 0]),
         },
     ],
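
The checkpointing test exercises the usual Grain iterator contract: capture the state mid-stream, then resume from it. A hedged sketch, assuming the standard DatasetIterator get_state()/set_state() API:

it = iter(ds)  # ds: the ConcatThenSplitIterDataset under test
first = next(it)
state = it.get_state()  # state captured *after* producing `first`
second = next(it)

restored = iter(ds)
restored.set_state(state)
# The restored iterator resumes exactly where the state was taken.
np.testing.assert_array_equal(
    next(restored)["observation"], second["observation"]
)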
@@ -406,34 +432,6 @@ def test_checkpointing_using_grain_built_in_tools(
         )
     )
 
-  @parameterized.product(
-      bos_handling=list(BOSHandling),
-  )
-  def test_pack_sequence_longer_than_sequence_length(self, bos_handling):
-    sequence_length = 10
-    if bos_handling == BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS:
-      bos_token_id = 1000
-      bos_features = {"observation"}
-    else:
-      bos_token_id = None
-      bos_features = {}
-    ds = dataset.MapDataset.source([
-        {"observation": np.repeat(1, 100)},  # 100 > sequence_length
-    ]).to_iter_dataset()
-    ds = packing_concat_then_split.ConcatThenSplitIterDataset(
-        ds,
-        length_struct={"observation": sequence_length},
-        split_full_length_features=False,
-        bos_handling=bos_handling,
-        bos_token_id=bos_token_id,
-        bos_features=bos_features,
-    )
-    with self.assertRaisesWithPredicateMatch(
-        exceptions.PyGrainInternalError,
-        lambda _: "Feature 'observation' has 100 tokens",
-    ):
-      next(iter(ds))
-
   def assert_equal_elements(
       self,
       actual_elements: list[dict[str, np.ndarray]],
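
The removed test pinned the old behavior: a feature longer than sequence_length raised PyGrainInternalError. With the slice-based splitting introduced in this change that error path no longer exists; a hedged sketch of what the same input is now expected to do:

ds = dataset.MapDataset.source([
    {"observation": np.repeat(1, 100)},  # 100 > sequence_length
]).to_iter_dataset()
ds = packing_concat_then_split.ConcatThenSplitIterDataset(
    ds, length_struct={"observation": 10}
)
# Instead of raising, the 100-token feature should now be split into
# ten fully packed rows of length 10.
rows = list(ds)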
