From 526942944dcf5b2a99168338b56749f0fd856dd6 Mon Sep 17 00:00:00 2001 From: Fabien Hertschuh <1091026+hertschuh@users.noreply.github.com> Date: Thu, 22 May 2025 11:18:38 -0700 Subject: [PATCH] Move `DistributedEmbedding` declaration to its own file. Having it in `__init__.py` doesn't play nice with pytype. --- keras_rs/api/layers/__init__.py | 2 +- keras_rs/src/layers/embedding/__init__.py | 21 ------------------- .../layers/embedding/distributed_embedding.py | 21 +++++++++++++++++++ .../embedding/distributed_embedding_test.py | 16 +++++++------- 4 files changed, 31 insertions(+), 29 deletions(-) create mode 100644 keras_rs/src/layers/embedding/distributed_embedding.py diff --git a/keras_rs/api/layers/__init__.py b/keras_rs/api/layers/__init__.py index 570550d9..8d740e85 100644 --- a/keras_rs/api/layers/__init__.py +++ b/keras_rs/api/layers/__init__.py @@ -4,7 +4,7 @@ since your modifications would be overwritten. """ -from keras_rs.src.layers.embedding import ( +from keras_rs.src.layers.embedding.distributed_embedding import ( DistributedEmbedding as DistributedEmbedding, ) from keras_rs.src.layers.embedding.distributed_embedding_config import ( diff --git a/keras_rs/src/layers/embedding/__init__.py b/keras_rs/src/layers/embedding/__init__.py index dfa4dcbb..e69de29b 100644 --- a/keras_rs/src/layers/embedding/__init__.py +++ b/keras_rs/src/layers/embedding/__init__.py @@ -1,21 +0,0 @@ -import keras - -from keras_rs.src.api_export import keras_rs_export - -if keras.backend.backend() == "jax": - from keras_rs.src.layers.embedding.jax.distributed_embedding import ( - DistributedEmbedding as BackendDistributedEmbedding, - ) -elif keras.backend.backend() == "tensorflow": - from keras_rs.src.layers.embedding.tensorflow.distributed_embedding import ( - DistributedEmbedding as BackendDistributedEmbedding, - ) -else: - from keras_rs.src.layers.embedding.base_distributed_embedding import ( - DistributedEmbedding as BackendDistributedEmbedding, - ) - - -@keras_rs_export("keras_rs.layers.DistributedEmbedding") -class DistributedEmbedding(BackendDistributedEmbedding): - pass diff --git a/keras_rs/src/layers/embedding/distributed_embedding.py b/keras_rs/src/layers/embedding/distributed_embedding.py new file mode 100644 index 00000000..dfa4dcbb --- /dev/null +++ b/keras_rs/src/layers/embedding/distributed_embedding.py @@ -0,0 +1,21 @@ +import keras + +from keras_rs.src.api_export import keras_rs_export + +if keras.backend.backend() == "jax": + from keras_rs.src.layers.embedding.jax.distributed_embedding import ( + DistributedEmbedding as BackendDistributedEmbedding, + ) +elif keras.backend.backend() == "tensorflow": + from keras_rs.src.layers.embedding.tensorflow.distributed_embedding import ( + DistributedEmbedding as BackendDistributedEmbedding, + ) +else: + from keras_rs.src.layers.embedding.base_distributed_embedding import ( + DistributedEmbedding as BackendDistributedEmbedding, + ) + + +@keras_rs_export("keras_rs.layers.DistributedEmbedding") +class DistributedEmbedding(BackendDistributedEmbedding): + pass diff --git a/keras_rs/src/layers/embedding/distributed_embedding_test.py b/keras_rs/src/layers/embedding/distributed_embedding_test.py index abfa12f6..a788aa77 100644 --- a/keras_rs/src/layers/embedding/distributed_embedding_test.py +++ b/keras_rs/src/layers/embedding/distributed_embedding_test.py @@ -15,7 +15,7 @@ from absl.testing import parameterized from keras_rs.src import testing -from keras_rs.src.layers import embedding +from keras_rs.src.layers.embedding import distributed_embedding from keras_rs.src.layers.embedding import distributed_embedding_config as config FLAGS = flags.FLAGS @@ -243,11 +243,11 @@ def test_basics(self, input_type, placement): if placement == "sparsecore" and not self.on_tpu: with self.assertRaisesRegex(Exception, "sparsecore"): with self._strategy.scope(): - embedding.DistributedEmbedding(feature_configs) + distributed_embedding.DistributedEmbedding(feature_configs) return with self._strategy.scope(): - layer = embedding.DistributedEmbedding(feature_configs) + layer = distributed_embedding.DistributedEmbedding(feature_configs) if keras.backend.backend() == "jax": preprocessed_inputs = layer.preprocess(inputs, weights) @@ -329,7 +329,7 @@ def test_model_fit(self, input_type, use_weights): ) with self._strategy.scope(): - layer = embedding.DistributedEmbedding(feature_configs) + layer = distributed_embedding.DistributedEmbedding(feature_configs) if keras.backend.backend() == "jax": # Set global distribution to ensure optimizer variables are @@ -560,7 +560,7 @@ def test_correctness( weights = None with self._strategy.scope(): - layer = embedding.DistributedEmbedding(feature_config) + layer = distributed_embedding.DistributedEmbedding(feature_config) if keras.backend.backend() == "jax": preprocessed = layer.preprocess(inputs, weights) @@ -675,7 +675,7 @@ def test_shared_table(self): ) with self._strategy.scope(): - layer = embedding.DistributedEmbedding(embedding_config) + layer = distributed_embedding.DistributedEmbedding(embedding_config) res = self.run_with_strategy(layer.__call__, inputs) @@ -709,7 +709,9 @@ def test_save_load_model(self): path = os.path.join(temp_dir, "model.keras") with self._strategy.scope(): - layer = embedding.DistributedEmbedding(feature_configs) + layer = distributed_embedding.DistributedEmbedding( + feature_configs + ) keras_outputs = layer(keras_inputs) model = keras.Model(inputs=keras_inputs, outputs=keras_outputs)