Merged
2 changes: 1 addition & 1 deletion keras_rs/api/layers/__init__.py
@@ -4,7 +4,7 @@
 since your modifications would be overwritten.
 """
 
-from keras_rs.src.layers.embedding import (
+from keras_rs.src.layers.embedding.distributed_embedding import (
     DistributedEmbedding as DistributedEmbedding,
 )
 from keras_rs.src.layers.embedding.distributed_embedding_config import (
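The public export is unchanged by this edit; only the internal module it resolves to moves. A minimal sketch of the downstream import that keeps working (for context, not part of the diff):

    import keras_rs

    # Resolves through keras_rs/api/layers/__init__.py; the exported name
    # "keras_rs.layers.DistributedEmbedding" is the same before and after.
    DistributedEmbedding = keras_rs.layers.DistributedEmbedding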
21 changes: 0 additions & 21 deletions keras_rs/src/layers/embedding/__init__.py
@@ -1,21 +0,0 @@
-import keras
-
-from keras_rs.src.api_export import keras_rs_export
-
-if keras.backend.backend() == "jax":
-    from keras_rs.src.layers.embedding.jax.distributed_embedding import (
-        DistributedEmbedding as BackendDistributedEmbedding,
-    )
-elif keras.backend.backend() == "tensorflow":
-    from keras_rs.src.layers.embedding.tensorflow.distributed_embedding import (
-        DistributedEmbedding as BackendDistributedEmbedding,
-    )
-else:
-    from keras_rs.src.layers.embedding.base_distributed_embedding import (
-        DistributedEmbedding as BackendDistributedEmbedding,
-    )
-
-
-@keras_rs_export("keras_rs.layers.DistributedEmbedding")
-class DistributedEmbedding(BackendDistributedEmbedding):
-    pass
21 changes: 21 additions & 0 deletions keras_rs/src/layers/embedding/distributed_embedding.py
@@ -0,0 +1,21 @@
+import keras
+
+from keras_rs.src.api_export import keras_rs_export
+
+if keras.backend.backend() == "jax":
+    from keras_rs.src.layers.embedding.jax.distributed_embedding import (
+        DistributedEmbedding as BackendDistributedEmbedding,
+    )
+elif keras.backend.backend() == "tensorflow":
+    from keras_rs.src.layers.embedding.tensorflow.distributed_embedding import (
+        DistributedEmbedding as BackendDistributedEmbedding,
+    )
+else:
+    from keras_rs.src.layers.embedding.base_distributed_embedding import (
+        DistributedEmbedding as BackendDistributedEmbedding,
+    )
+
+
+@keras_rs_export("keras_rs.layers.DistributedEmbedding")
+class DistributedEmbedding(BackendDistributedEmbedding):
+    pass
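Since the class now lives in distributed_embedding.py rather than the package __init__.py, the backend dispatch above runs when this module is first imported. A small sketch of what that selection implies (illustrative, not part of the diff):

    import keras
    from keras_rs.src.layers.embedding import distributed_embedding

    # DistributedEmbedding subclasses the backend-specific implementation
    # chosen at import time, so its immediate base class reveals which of
    # the JAX, TensorFlow, or base variants was selected.
    print(keras.backend.backend())
    print(distributed_embedding.DistributedEmbedding.__bases__[0])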
16 changes: 9 additions & 7 deletions keras_rs/src/layers/embedding/distributed_embedding_test.py
@@ -15,7 +15,7 @@
 from absl.testing import parameterized
 
 from keras_rs.src import testing
-from keras_rs.src.layers import embedding
+from keras_rs.src.layers.embedding import distributed_embedding
 from keras_rs.src.layers.embedding import distributed_embedding_config as config
 
 FLAGS = flags.FLAGS
@@ -243,11 +243,11 @@ def test_basics(self, input_type, placement):
         if placement == "sparsecore" and not self.on_tpu:
             with self.assertRaisesRegex(Exception, "sparsecore"):
                 with self._strategy.scope():
-                    embedding.DistributedEmbedding(feature_configs)
+                    distributed_embedding.DistributedEmbedding(feature_configs)
             return
 
         with self._strategy.scope():
-            layer = embedding.DistributedEmbedding(feature_configs)
+            layer = distributed_embedding.DistributedEmbedding(feature_configs)
 
         if keras.backend.backend() == "jax":
             preprocessed_inputs = layer.preprocess(inputs, weights)
@@ -329,7 +329,7 @@ def test_model_fit(self, input_type, use_weights):
         )
 
         with self._strategy.scope():
-            layer = embedding.DistributedEmbedding(feature_configs)
+            layer = distributed_embedding.DistributedEmbedding(feature_configs)
 
         if keras.backend.backend() == "jax":
             # Set global distribution to ensure optimizer variables are
@@ -560,7 +560,7 @@ def test_correctness(
         weights = None
 
         with self._strategy.scope():
-            layer = embedding.DistributedEmbedding(feature_config)
+            layer = distributed_embedding.DistributedEmbedding(feature_config)
 
         if keras.backend.backend() == "jax":
             preprocessed = layer.preprocess(inputs, weights)
@@ -675,7 +675,7 @@ def test_shared_table(self):
         )
 
         with self._strategy.scope():
-            layer = embedding.DistributedEmbedding(embedding_config)
+            layer = distributed_embedding.DistributedEmbedding(embedding_config)
 
         res = self.run_with_strategy(layer.__call__, inputs)
 
@@ -709,7 +709,9 @@ def test_save_load_model(self):
         path = os.path.join(temp_dir, "model.keras")
 
         with self._strategy.scope():
-            layer = embedding.DistributedEmbedding(feature_configs)
+            layer = distributed_embedding.DistributedEmbedding(
+                feature_configs
+            )
             keras_outputs = layer(keras_inputs)
             model = keras.Model(inputs=keras_inputs, outputs=keras_outputs)
 
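With the rename, the tests construct the layer through the module-qualified name instead of the package name. A hypothetical sketch of the resulting call pattern; the table and feature parameters below are illustrative assumptions, not values taken from this PR:

    from keras_rs.src.layers.embedding import distributed_embedding
    from keras_rs.src.layers.embedding import (
        distributed_embedding_config as config,
    )

    # Assumed illustrative config; the real tests build these from
    # parameterized test inputs.
    table = config.TableConfig(
        name="table", vocabulary_size=100, embedding_dim=8
    )
    feature_configs = {
        "feature": config.FeatureConfig(
            name="feature",
            table=table,
            input_shape=(16, 1),
            output_shape=(16, 8),
        )
    }
    layer = distributed_embedding.DistributedEmbedding(feature_configs)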