5656def translate_keras_rs_configuration (
5757 feature_configs : types .Nested [FeatureConfig ],
5858 table_stacking : str | Sequence [str ] | Sequence [Sequence [str ]],
59+ num_replicas_in_sync : int ,
5960) -> tuple [
6061 types .Nested [tf .tpu .experimental .embedding .FeatureConfig ],
6162 tf .tpu .experimental .embedding .SparseCoreEmbeddingConfig ,
@@ -72,7 +73,10 @@ def translate_keras_rs_configuration(
7273 """
7374 tables : dict [TableConfig , tf .tpu .experimental .embedding .TableConfig ] = {}
7475 feature_configs = keras .tree .map_structure (
75- lambda f : translate_keras_rs_feature_config (f , tables ), feature_configs
76+ lambda f : translate_keras_rs_feature_config (
77+ f , tables , num_replicas_in_sync
78+ ),
79+ feature_configs ,
7680 )
7781
7882 # max_ids_per_chip_per_sample
@@ -107,6 +111,7 @@ def translate_keras_rs_configuration(
def translate_keras_rs_feature_config(
    feature_config: FeatureConfig,
    tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig],
    num_replicas_in_sync: int,
) -> tf.tpu.experimental.embedding.FeatureConfig:
    """Translates a Keras RS feature config to a TensorFlow TPU feature config.

    Table configs are deduplicated through `tables`: the first feature that
    references a Keras RS table creates the TensorFlow table config and stores
    it in `tables`; subsequent features referencing the same table reuse the
    stored config, so table sharing is preserved after translation.

    Args:
        feature_config: The Keras RS feature config to translate.
        tables: Mapping from Keras RS table configs to already-translated
            TensorFlow table configs. Updated in place as a cache.
        num_replicas_in_sync: The number of TPU replicas. The global batch
            size (`output_shape[0]`) must be a multiple of it; TensorFlow's
            TPUEmbedding expects the per-replica batch size.

    Returns:
        The equivalent `tf.tpu.experimental.embedding.FeatureConfig`.

    Raises:
        ValueError: If the batch size in `output_shape` is not a multiple of
            `num_replicas_in_sync`.
    """
    # Reuse the translated table if another feature already referenced the
    # same Keras RS table; otherwise translate it and cache it.
    table = tables.get(feature_config.table)
    if table is None:
        table = translate_keras_rs_table_config(feature_config.table)
        tables[feature_config.table] = table

    # Exclude last dimension, TensorFlow's TPUEmbedding doesn't want it.
    output_shape = list(feature_config.output_shape[0:-1])

    batch_size = output_shape[0]
    per_replica_batch_size: int | None = None
    if batch_size is not None:
        if batch_size % num_replicas_in_sync != 0:
            raise ValueError(
                f"Invalid `output_shape` {feature_config.output_shape} in "
                f"`FeatureConfig` {feature_config}. Batch size {batch_size} is "
                # Bug fix: this fragment was missing its `f` prefix, so the
                # replica count was never interpolated into the message.
                f"not a multiple of the number of TPUs {num_replicas_in_sync}."
            )
        per_replica_batch_size = batch_size // num_replicas_in_sync

    # TensorFlow's TPUEmbedding wants the per replica batch size.
    output_shape = [per_replica_batch_size] + output_shape[1:]

    # NOTE: max_sequence_length is not translated yet.
    return tf.tpu.experimental.embedding.FeatureConfig(
        name=feature_config.name,
        table=table,
        output_shape=output_shape,
    )
136156
137157
0 commit comments