5353# KerasRS to TensorFlow
5454
5555
56- def translate_keras_rs_configuration (
56+ def keras_to_tf_tpu_configuration (
5757 feature_configs : types .Nested [FeatureConfig ],
5858 table_stacking : str | Sequence [str ] | Sequence [Sequence [str ]],
5959 num_replicas_in_sync : int ,
@@ -66,14 +66,15 @@ def translate_keras_rs_configuration(
6666 Args:
6767 feature_configs: The nested Keras RS feature configs.
6868 table_stacking: The Keras RS table stacking.
69+ num_replicas_in_sync: The number of replicas in sync from the strategy.
6970
7071 Returns:
7172 A tuple containing the TensorFlow TPU feature configs and the TensorFlow
7273 TPU sparse core embedding config.
7374 """
74- tables : dict [TableConfig , tf .tpu .experimental .embedding .TableConfig ] = {}
75+ tables : dict [int , tf .tpu .experimental .embedding .TableConfig ] = {}
7576 feature_configs = keras .tree .map_structure (
76- lambda f : translate_keras_rs_feature_config (
77+ lambda f : keras_to_tf_tpu_feature_config (
7778 f , tables , num_replicas_in_sync
7879 ),
7980 feature_configs ,
@@ -108,9 +109,9 @@ def translate_keras_rs_configuration(
108109 return feature_configs , sparse_core_embedding_config
109110
110111
111- def translate_keras_rs_feature_config (
112+ def keras_to_tf_tpu_feature_config (
112113 feature_config : FeatureConfig ,
113- tables : dict [TableConfig , tf .tpu .experimental .embedding .TableConfig ],
114+ tables : dict [int , tf .tpu .experimental .embedding .TableConfig ],
114115 num_replicas_in_sync : int ,
115116) -> tf .tpu .experimental .embedding .FeatureConfig :
116117 """Translates a Keras RS feature config to a TensorFlow TPU feature config.
@@ -120,7 +121,8 @@ def translate_keras_rs_feature_config(
120121
121122 Args:
122123 feature_config: The Keras RS feature config to translate.
123- tables: A mapping of KerasRS table configs to TF TPU table configs.
124+ tables: A mapping of KerasRS table config ids to TF TPU table configs.
125+ num_replicas_in_sync: The number of replicas in sync from the strategy.
124126
125127 Returns:
126128 The TensorFlow TPU feature config.
@@ -131,10 +133,10 @@ def translate_keras_rs_feature_config(
131133 f"but got { num_replicas_in_sync } ."
132134 )
133135
134- table = tables .get (feature_config .table , None )
136+ table = tables .get (id ( feature_config .table ) , None )
135137 if table is None :
136- table = translate_keras_rs_table_config (feature_config .table )
137- tables [feature_config .table ] = table
138+ table = keras_to_tf_tpu_table_config (feature_config .table )
139+ tables [id ( feature_config .table ) ] = table
138140
139141 if len (feature_config .output_shape ) < 2 :
140142 raise ValueError (
@@ -168,7 +170,7 @@ def translate_keras_rs_feature_config(
168170 )
169171
170172
171- def translate_keras_rs_table_config (
173+ def keras_to_tf_tpu_table_config (
172174 table_config : TableConfig ,
173175) -> tf .tpu .experimental .embedding .TableConfig :
174176 initializer = table_config .initializer
@@ -179,13 +181,13 @@ def translate_keras_rs_table_config(
179181 vocabulary_size = table_config .vocabulary_size ,
180182 dim = table_config .embedding_dim ,
181183 initializer = initializer ,
182- optimizer = translate_optimizer (table_config .optimizer ),
184+ optimizer = to_tf_tpu_optimizer (table_config .optimizer ),
183185 combiner = table_config .combiner ,
184186 name = table_config .name ,
185187 )
186188
187189
188- def translate_keras_optimizer (
190+ def keras_to_tf_tpu_optimizer (
189191 optimizer : keras .optimizers .Optimizer ,
190192) -> TfTpuOptimizer :
191193 """Translates a Keras optimizer to a TensorFlow TPU `_Optimizer`.
@@ -258,7 +260,7 @@ def translate_keras_optimizer(
258260 return optimizer_mapping .tpu_optimizer_class (** tpu_optimizer_kwargs )
259261
260262
261- def translate_optimizer (
263+ def to_tf_tpu_optimizer (
262264 optimizer : str | keras .optimizers .Optimizer | TfTpuOptimizer | None ,
263265) -> TfTpuOptimizer :
264266 """Translates a Keras optimizer into a TensorFlow TPU `_Optimizer`.
@@ -299,7 +301,7 @@ def translate_optimizer(
299301 "'sgd', 'adagrad', 'adam', or 'ftrl'"
300302 )
301303 elif isinstance (optimizer , keras .optimizers .Optimizer ):
302- return translate_keras_optimizer (optimizer )
304+ return keras_to_tf_tpu_optimizer (optimizer )
303305 else :
304306 raise ValueError (
305307 f"Unknown optimizer type { type (optimizer )} . Please pass an "
@@ -312,7 +314,7 @@ def translate_optimizer(
312314# TensorFlow to TensorFlow
313315
314316
315- def clone_tf_feature_configs (
317+ def clone_tf_tpu_feature_configs (
316318 feature_configs : types .Nested [tf .tpu .experimental .embedding .FeatureConfig ],
317319) -> types .Nested [tf .tpu .experimental .embedding .FeatureConfig ]:
318320 """Clones and resolves TensorFlow TPU feature configs.
@@ -327,7 +329,7 @@ def clone_tf_feature_configs(
327329 """
328330 table_configs_dict = {}
329331
330- def clone_and_resolve_tf_feature_config (
332+ def clone_and_resolve_tf_tpu_feature_config (
331333 fc : tf .tpu .experimental .embedding .FeatureConfig ,
332334 ) -> tf .tpu .experimental .embedding .FeatureConfig :
333335 if fc .table not in table_configs_dict :
@@ -336,7 +338,7 @@ def clone_and_resolve_tf_feature_config(
336338 vocabulary_size = fc .table .vocabulary_size ,
337339 dim = fc .table .dim ,
338340 initializer = fc .table .initializer ,
339- optimizer = translate_optimizer (fc .table .optimizer ),
341+ optimizer = to_tf_tpu_optimizer (fc .table .optimizer ),
340342 combiner = fc .table .combiner ,
341343 name = fc .table .name ,
342344 quantization_config = fc .table .quantization_config ,
@@ -352,5 +354,5 @@ def clone_and_resolve_tf_feature_config(
352354 )
353355
354356 return keras .tree .map_structure (
355- clone_and_resolve_tf_feature_config , feature_configs
357+ clone_and_resolve_tf_tpu_feature_config , feature_configs
356358 )
0 commit comments