Commit b4dfd3b

Fix batch sizes in distributed_embedding_test.py.
Fixed inconsistencies between the per-core TPU batch size and the global batch size.
1 parent 7c36879 commit b4dfd3b
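
In other words, every shape in the test must use the global batch size, which is the per-core batch size multiplied by the number of replicas in sync. A minimal sketch of that relationship, with assumed values (the real BATCH_SIZE_PER_CORE and num_replicas_in_sync come from the test module and the strategy, not from here):

    # Illustration of the invariant this commit enforces; the concrete
    # numbers below are assumptions, not values from the repository.
    BATCH_SIZE_PER_CORE = 8       # hypothetical per-core batch size
    num_replicas_in_sync = 4      # e.g. a 4-core TPU slice

    # Global batch size, as now computed once in setUp():
    batch_size = BATCH_SIZE_PER_CORE * num_replicas_in_sync
    assert batch_size == 32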

File tree

1 file changed: +56 -50 lines changed


keras_rs/src/layers/embedding/distributed_embedding_test.py

Lines changed: 56 additions & 50 deletions
@@ -47,6 +47,12 @@ def experimental_distribute_dataset(self, dataset, options=None):
         return dataset
 
 
+class JaxDummyStrategy(DummyStrategy):
+    @property
+    def num_replicas_in_sync(self):
+        return len(jax.devices("tpu"))
+
+
 class DistributedEmbeddingTest(testing.TestCase, parameterized.TestCase):
     def setUp(self):
         super().setUp()
@@ -80,9 +86,15 @@ def setUp(self):
             )
             print("### num_replicas", self._strategy.num_replicas_in_sync)
             self.addCleanup(tf.tpu.experimental.shutdown_tpu_system, resolver)
+        elif keras.backend.backend() == "jax" and self.on_tpu:
+            self._strategy = JaxDummyStrategy()
         else:
             self._strategy = DummyStrategy()
 
+        self.batch_size = (
+            BATCH_SIZE_PER_CORE * self._strategy.num_replicas_in_sync
+        )
+
     def run_with_strategy(self, fn, *args, jit_compile=False):
         """Wrapper for running a function under a strategy."""
 
@@ -120,31 +132,31 @@ def get_embedding_config(self, input_type, placement):
         feature_group["feature1"] = config.FeatureConfig(
             name="feature1",
             table=feature1_table,
-            input_shape=(BATCH_SIZE_PER_CORE, sequence_length),
-            output_shape=(BATCH_SIZE_PER_CORE, FEATURE1_EMBEDDING_OUTPUT_DIM),
+            input_shape=(self.batch_size, sequence_length),
+            output_shape=(self.batch_size, FEATURE1_EMBEDDING_OUTPUT_DIM),
         )
         feature_group["feature2"] = config.FeatureConfig(
             name="feature2",
             table=feature2_table,
-            input_shape=(BATCH_SIZE_PER_CORE, sequence_length),
-            output_shape=(BATCH_SIZE_PER_CORE, FEATURE2_EMBEDDING_OUTPUT_DIM),
+            input_shape=(self.batch_size, sequence_length),
+            output_shape=(self.batch_size, FEATURE2_EMBEDDING_OUTPUT_DIM),
         )
         return {"feature_group": feature_group}
 
     def create_inputs_weights_and_labels(
-        self, batch_size, input_type, feature_configs, backend=None
+        self, input_type, feature_configs, backend=None
     ):
         backend = backend or keras.backend.backend()
 
         if input_type == "dense":
 
             def create_tensor(feature_config, op):
-                sequence_length = feature_config.input_shape[-1]
-                return op((batch_size, sequence_length))
+                return op(feature_config.input_shape)
 
         elif input_type == "ragged":
 
             def create_tensor(feature_config, op):
+                batch_size = feature_config.input_shape[0]
                 sequence_length = feature_config.input_shape[-1]
                 row_lengths = [
                     1 + (i % sequence_length) for i in range(batch_size)
@@ -157,6 +169,7 @@ def create_tensor(feature_config, op):
         elif input_type == "sparse" and backend == "tensorflow":
 
             def create_tensor(feature_config, op):
+                batch_size = feature_config.input_shape[0]
                 sequence_length = feature_config.input_shape[-1]
                 indices = [[i, i % sequence_length] for i in range(batch_size)]
                 return tf.sparse.reorder(
@@ -170,6 +183,7 @@ def create_tensor(feature_config, op):
         elif input_type == "sparse" and backend == "jax":
 
             def create_tensor(feature_config, op):
+                batch_size = feature_config.input_shape[0]
                 sequence_length = feature_config.input_shape[-1]
                 indices = [[i, i % sequence_length] for i in range(batch_size)]
                 return jax_sparse.BCOO(
@@ -197,9 +211,7 @@ def create_tensor(feature_config, op):
             feature_configs,
         )
         labels = keras.tree.map_structure(
-            lambda fc: np.ones(
-                (batch_size,) + fc.output_shape[1:], dtype=np.float32
-            ),
+            lambda fc: np.ones(fc.output_shape, dtype=np.float32),
            feature_configs,
        )
        return inputs, weights, labels
@@ -228,10 +240,9 @@ def test_basics(self, input_type, placement):
         ):
             self.skipTest("Ragged and sparse are not compilable on TPU.")
 
-        batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
         feature_configs = self.get_embedding_config(input_type, placement)
         inputs, weights, _ = self.create_inputs_weights_and_labels(
-            batch_size, input_type, feature_configs
+            input_type, feature_configs
         )
 
         if placement == "sparsecore" and not self.on_tpu:
@@ -264,11 +275,11 @@ def test_basics(self, input_type, placement):
 
         self.assertEqual(
             res["feature_group"]["feature1"].shape,
-            (batch_size, FEATURE1_EMBEDDING_OUTPUT_DIM),
+            (self.batch_size, FEATURE1_EMBEDDING_OUTPUT_DIM),
         )
         self.assertEqual(
             res["feature_group"]["feature2"].shape,
-            (batch_size, FEATURE2_EMBEDDING_OUTPUT_DIM),
+            (self.batch_size, FEATURE2_EMBEDDING_OUTPUT_DIM),
         )
 
     @parameterized.named_parameters(
@@ -289,16 +300,15 @@ def test_model_fit(self, input_type, use_weights):
                 f"{input_type} not supported on {keras.backend.backend()}"
             )
 
-        batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
         feature_configs = self.get_embedding_config(input_type, self.placement)
         train_inputs, train_weights, train_labels = (
             self.create_inputs_weights_and_labels(
-                batch_size, input_type, feature_configs, backend="tensorflow"
+                input_type, feature_configs, backend="tensorflow"
             )
         )
         test_inputs, test_weights, test_labels = (
             self.create_inputs_weights_and_labels(
-                batch_size, input_type, feature_configs, backend="tensorflow"
+                input_type, feature_configs, backend="tensorflow"
             )
         )
 
@@ -482,12 +492,11 @@ def test_correctness(
         feature_config = config.FeatureConfig(
             name="feature",
             table=table,
-            input_shape=(BATCH_SIZE_PER_CORE, sequence_length),
-            output_shape=(BATCH_SIZE_PER_CORE, EMBEDDING_OUTPUT_DIM),
+            input_shape=(self.batch_size, sequence_length),
+            output_shape=(self.batch_size, EMBEDDING_OUTPUT_DIM),
         )
 
-        batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
-        num_repeats = batch_size // 2
+        num_repeats = self.batch_size // 2
         if input_type == "dense" and input_rank == 1:
             inputs = keras.ops.convert_to_tensor([2, 3] * num_repeats)
             weights = keras.ops.convert_to_tensor([1.0, 2.0] * num_repeats)
@@ -512,14 +521,14 @@ def test_correctness(
                 tf.SparseTensor(
                     indices,
                     [1, 2, 3, 4, 5] * num_repeats,
-                    dense_shape=(batch_size, 4),
+                    dense_shape=(self.batch_size, 4),
                 )
             )
             weights = tf.sparse.reorder(
                 tf.SparseTensor(
                     indices,
                     [1.0, 1.0, 2.0, 3.0, 4.0] * num_repeats,
-                    dense_shape=(batch_size, 4),
+                    dense_shape=(self.batch_size, 4),
                 )
             )
         elif keras.backend.backend() == "jax":
@@ -528,15 +537,15 @@ def test_correctness(
                     jnp.asarray([1, 2, 3, 4, 5] * num_repeats),
                     jnp.asarray(indices),
                 ),
-                shape=(batch_size, 4),
+                shape=(self.batch_size, 4),
                 unique_indices=True,
             )
             weights = jax_sparse.BCOO(
                 (
                     jnp.asarray([1.0, 1.0, 2.0, 3.0, 4.0] * num_repeats),
                     jnp.asarray(indices),
                 ),
-                shape=(batch_size, 4),
+                shape=(self.batch_size, 4),
                 unique_indices=True,
             )
         else:
@@ -600,7 +609,7 @@ def test_correctness(
             layer.__call__, inputs, weights, jit_compile=jit_compile
         )
 
-        self.assertEqual(res.shape, (batch_size, EMBEDDING_OUTPUT_DIM))
+        self.assertEqual(res.shape, (self.batch_size, EMBEDDING_OUTPUT_DIM))
 
         tables = layer.get_embedding_tables()
         emb = tables["table"]
@@ -644,26 +653,25 @@ def test_shared_table(self):
             "feature1": config.FeatureConfig(
                 name="feature1",
                 table=table1,
-                input_shape=(BATCH_SIZE_PER_CORE, 1),
-                output_shape=(BATCH_SIZE_PER_CORE, EMBEDDING_OUTPUT_DIM),
+                input_shape=(self.batch_size, 1),
+                output_shape=(self.batch_size, EMBEDDING_OUTPUT_DIM),
             ),
             "feature2": config.FeatureConfig(
                 name="feature2",
                 table=table1,
-                input_shape=(BATCH_SIZE_PER_CORE, 1),
-                output_shape=(BATCH_SIZE_PER_CORE, EMBEDDING_OUTPUT_DIM),
+                input_shape=(self.batch_size, 1),
+                output_shape=(self.batch_size, EMBEDDING_OUTPUT_DIM),
             ),
             "feature3": config.FeatureConfig(
                 name="feature3",
                 table=table1,
-                input_shape=(BATCH_SIZE_PER_CORE, 1),
-                output_shape=(BATCH_SIZE_PER_CORE, EMBEDDING_OUTPUT_DIM),
+                input_shape=(self.batch_size, 1),
+                output_shape=(self.batch_size, EMBEDDING_OUTPUT_DIM),
             ),
         }
 
-        batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
         inputs, _, _ = self.create_inputs_weights_and_labels(
-            batch_size, "dense", embedding_config
+            "dense", embedding_config
         )
 
         with self._strategy.scope():
@@ -676,13 +684,13 @@ def test_shared_table(self):
         self.assertLen(layer.trainable_variables, 1)
 
         self.assertEqual(
-            res["feature1"].shape, (batch_size, EMBEDDING_OUTPUT_DIM)
+            res["feature1"].shape, (self.batch_size, EMBEDDING_OUTPUT_DIM)
         )
         self.assertEqual(
-            res["feature2"].shape, (batch_size, EMBEDDING_OUTPUT_DIM)
+            res["feature2"].shape, (self.batch_size, EMBEDDING_OUTPUT_DIM)
         )
         self.assertEqual(
-            res["feature3"].shape, (batch_size, EMBEDDING_OUTPUT_DIM)
+            res["feature3"].shape, (self.batch_size, EMBEDDING_OUTPUT_DIM)
         )
 
     def test_mixed_placement(self):
@@ -719,26 +727,25 @@ def test_mixed_placement(self):
             "feature1": config.FeatureConfig(
                 name="feature1",
                 table=table1,
-                input_shape=(BATCH_SIZE_PER_CORE, 1),
-                output_shape=(BATCH_SIZE_PER_CORE, embedding_output_dim1),
+                input_shape=(self.batch_size, 1),
+                output_shape=(self.batch_size, embedding_output_dim1),
             ),
             "feature2": config.FeatureConfig(
                 name="feature2",
                 table=table2,
-                input_shape=(BATCH_SIZE_PER_CORE, 1),
-                output_shape=(BATCH_SIZE_PER_CORE, embedding_output_dim2),
+                input_shape=(self.batch_size, 1),
+                output_shape=(self.batch_size, embedding_output_dim2),
             ),
             "feature3": config.FeatureConfig(
                 name="feature3",
                 table=table3,
-                input_shape=(BATCH_SIZE_PER_CORE, 1),
-                output_shape=(BATCH_SIZE_PER_CORE, embedding_output_dim3),
+                input_shape=(self.batch_size, 1),
+                output_shape=(self.batch_size, embedding_output_dim3),
             ),
         }
 
-        batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
         inputs, _, _ = self.create_inputs_weights_and_labels(
-            batch_size, "dense", embedding_config
+            "dense", embedding_config
         )
 
         with self._strategy.scope():
@@ -747,20 +754,19 @@ def test_mixed_placement(self):
         res = self.run_with_strategy(layer.__call__, inputs)
 
         self.assertEqual(
-            res["feature1"].shape, (batch_size, embedding_output_dim1)
+            res["feature1"].shape, (self.batch_size, embedding_output_dim1)
         )
         self.assertEqual(
-            res["feature2"].shape, (batch_size, embedding_output_dim2)
+            res["feature2"].shape, (self.batch_size, embedding_output_dim2)
         )
         self.assertEqual(
-            res["feature3"].shape, (batch_size, embedding_output_dim3)
+            res["feature3"].shape, (self.batch_size, embedding_output_dim3)
         )
 
     def test_save_load_model(self):
-        batch_size = self._strategy.num_replicas_in_sync * BATCH_SIZE_PER_CORE
         feature_configs = self.get_embedding_config("dense", self.placement)
         inputs, _, _ = self.create_inputs_weights_and_labels(
-            batch_size, "dense", feature_configs
+            "dense", feature_configs
        )
 
         keras_inputs = keras.tree.map_structure(
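
For readers of the diff, here is a minimal, self-contained sketch of the calling convention the change moves to: the global batch size now lives in each FeatureConfig's shapes, so create_inputs_weights_and_labels no longer takes a batch_size argument. FeatureConfigStub and create_dense_tensor below are hypothetical stand-ins for illustration, not keras_rs APIs:

    import numpy as np

    class FeatureConfigStub:
        # Hypothetical stand-in for keras_rs config.FeatureConfig; it models
        # only the shape fields that the test helper reads.
        def __init__(self, input_shape, output_shape):
            self.input_shape = input_shape
            self.output_shape = output_shape

    def create_dense_tensor(feature_config):
        # After this commit, the global batch size is baked into input_shape,
        # so no separate batch_size parameter is threaded through.
        return np.zeros(feature_config.input_shape, dtype=np.int64)

    # Global batch size 32 = 8 per core x 4 replicas (assumed values).
    fc = FeatureConfigStub(input_shape=(32, 16), output_shape=(32, 8))
    assert create_dense_tensor(fc).shape == (32, 16)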
