41 changes: 37 additions & 4 deletions keras_rs/src/layers/embedding/tensorflow/config_conversion.py
@@ -56,6 +56,7 @@
def translate_keras_rs_configuration(
feature_configs: types.Nested[FeatureConfig],
table_stacking: str | Sequence[str] | Sequence[Sequence[str]],
num_replicas_in_sync: int,
) -> tuple[
types.Nested[tf.tpu.experimental.embedding.FeatureConfig],
tf.tpu.experimental.embedding.SparseCoreEmbeddingConfig,
@@ -72,7 +73,10 @@ def translate_keras_rs_configuration(
"""
tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig] = {}
feature_configs = keras.tree.map_structure(
lambda f: translate_keras_rs_feature_config(f, tables), feature_configs
lambda f: translate_keras_rs_feature_config(
f, tables, num_replicas_in_sync
),
feature_configs,
)

# max_ids_per_chip_per_sample
@@ -107,6 +111,7 @@ def translate_keras_rs_configuration(
def translate_keras_rs_feature_config(
feature_config: FeatureConfig,
tables: dict[TableConfig, tf.tpu.experimental.embedding.TableConfig],
num_replicas_in_sync: int,
) -> tf.tpu.experimental.embedding.FeatureConfig:
"""Translates a Keras RS feature config to a TensorFlow TPU feature config.

@@ -120,18 +125,46 @@ def translate_keras_rs_feature_config(
Returns:
The TensorFlow TPU feature config.
"""
if num_replicas_in_sync <= 0:
raise ValueError(
"`num_replicas_in_sync` must be positive, "
f"but got {num_replicas_in_sync}."
)

table = tables.get(feature_config.table, None)
if table is None:
table = translate_keras_rs_table_config(feature_config.table)
tables[feature_config.table] = table

if len(feature_config.output_shape) < 2:
raise ValueError(
f"Invalid `output_shape` {feature_config.output_shape} in "
f"`FeatureConfig` {feature_config}. It must have at least 2 "
"dimensions: a batch dimension and an embedding dimension."
)

# Exclude the last dimension; TensorFlow's TPUEmbedding doesn't want it.
output_shape = list(feature_config.output_shape[0:-1])

batch_size = output_shape[0]
per_replica_batch_size: int | None = None
if batch_size is not None:
if batch_size % num_replicas_in_sync != 0:
raise ValueError(
f"Invalid `output_shape` {feature_config.output_shape} in "
f"`FeatureConfig` {feature_config}. Batch size {batch_size} is "
f"not a multiple of the number of TPUs {num_replicas_in_sync}."
)
per_replica_batch_size = batch_size // num_replicas_in_sync

# TensorFlow's TPUEmbedding wants the per-replica batch size.
output_shape = [per_replica_batch_size] + output_shape[1:]

# max_sequence_length
return tf.tpu.experimental.embedding.FeatureConfig(
name=feature_config.name,
table=table,
output_shape=feature_config.output_shape[
0:-1
], # exclude last dimension
output_shape=output_shape,
)
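
As a quick illustration of the per-replica batch size logic introduced above (this sketch is not part of the PR; the replica count and output shape below are hypothetical values chosen for the example):

num_replicas_in_sync = 4    # e.g. a 4-replica TPU slice
output_shape = (32, 16, 8)  # (global batch, sequence length, embedding dim)

# Drop the trailing embedding dimension, as TPUEmbedding expects.
shape = list(output_shape[:-1])

batch_size = shape[0]
assert batch_size % num_replicas_in_sync == 0, "global batch must divide evenly"

# TPUEmbedding receives the per-replica batch size: 32 // 4 == 8.
per_replica_shape = [batch_size // num_replicas_in_sync] + shape[1:]
assert per_replica_shape == [8, 16]

Shapes whose batch dimension is not a multiple of the replica count now fail fast with the ValueError added in this change, rather than producing a mis-sharded embedding lookup later.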


@@ -107,7 +107,9 @@ def _sparsecore_init(
)
self._tpu_feature_configs, self._sparse_core_embedding_config = (
config_conversion.translate_keras_rs_configuration(
feature_configs, table_stacking
feature_configs,
table_stacking,
strategy.num_replicas_in_sync,
)
)
if tpu_embedding_feature == EMBEDDING_FEATURE_V1:
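
For context, `_sparsecore_init` runs under a TPU distribution strategy, which is where `strategy.num_replicas_in_sync` comes from. A hedged sketch of that setup (the strategy initialization below is illustrative and not taken from this PR):

import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

# The config translation now receives the replica count so it can derive
# per-replica batch sizes from the global output shapes.
num_replicas = strategy.num_replicas_in_sync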