Skip to content

Commit 8522a5b

Browse files
authored
Check for jax_tpu_embedding on JAX backend. (#114)
This is to allow users to potentially run Keras RS _without_ the dependency. If a user doesn't have `jax-tpu-embedding` installed, but are on `linux_x86_64` and has a sparsecore-capable TPU available, and if they try to use `auto` or `sparsecore` placement with distributed embedding, will raise an error informing them to install the dependency.
1 parent 3d78779 commit 8522a5b

File tree

3 files changed

+36
-8
lines changed

3 files changed

+36
-8
lines changed

keras_rs/src/layers/embedding/base_distributed_embedding.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import collections
2+
import importlib.util
23
import typing
34
from typing import Any, Sequence
45

@@ -304,7 +305,7 @@ def step(data):
304305
To use `DistributedEmbedding` on TPUs with JAX, one must create and set a
305306
Keras `Distribution`.
306307
```
307-
distribution = keras.distribution.DataParallel(devices=jax.device("tpu))
308+
distribution = keras.distribution.DataParallel(devices=jax.device("tpu"))
308309
keras.distribution.set_distribution(distribution)
309310
```
310311
@@ -901,6 +902,24 @@ def _default_device_get_embedding_tables(self) -> dict[str, types.Tensor]:
901902
return tables
902903

903904
def _has_sparsecore(self) -> bool:
905+
# Explicitly check for SparseCore availability.
906+
# We need this check here rather than in jax/distributed_embedding.py
907+
# so that we can warn the user about missing dependencies.
908+
if keras.backend.backend() == "jax":
909+
# Check if SparseCores are available.
910+
try:
911+
import jax
912+
913+
tpu_devices = jax.devices("tpu")
914+
except RuntimeError:
915+
# No TPUs available.
916+
return False
917+
918+
if len(tpu_devices) > 0:
919+
device_kind = tpu_devices[0].device_kind
920+
if device_kind in ["TPU v5", "TPU v6 lite"]:
921+
return True
922+
904923
return False
905924

906925
def _sparsecore_init(
@@ -909,6 +928,17 @@ def _sparsecore_init(
909928
table_stacking: str | Sequence[Sequence[str]],
910929
) -> None:
911930
del feature_configs, table_stacking
931+
932+
if keras.backend.backend() == "jax":
933+
jax_tpu_embedding_spec = importlib.util.find_spec(
934+
"jax_tpu_embedding"
935+
)
936+
if jax_tpu_embedding_spec is None:
937+
raise ImportError(
938+
"Please install jax-tpu-embedding to use "
939+
"DistributedEmbedding on sparsecore devices."
940+
)
941+
912942
raise self._unsupported_placement_error("sparsecore")
913943

914944
def _sparsecore_build(self, input_shapes: dict[str, types.Shape]) -> None:

keras_rs/src/layers/embedding/distributed_embedding.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
1+
import importlib.util
12
import platform
23
import sys
34

45
import keras
56

67
from keras_rs.src.api_export import keras_rs_export
78

8-
# JAX TPU embedding is only available on linux_x86_64.
9+
# JAX distributed embedding is only available on linux_x86_64, and only if
10+
# jax-tpu-embedding is installed.
11+
jax_tpu_embedding_spec = importlib.util.find_spec("jax_tpu_embedding")
912
if (
1013
keras.backend.backend() == "jax"
1114
and sys.platform == "linux"
1215
and platform.machine().lower() == "x86_64"
16+
and jax_tpu_embedding_spec is not None
1317
):
1418
from keras_rs.src.layers.embedding.jax.distributed_embedding import (
1519
DistributedEmbedding as BackendDistributedEmbedding,

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -347,12 +347,6 @@ def _add_table_variable(
347347

348348
return table_variable, slot_variables
349349

350-
def _has_sparsecore(self) -> bool:
351-
device_kind = jax.devices()[0].device_kind
352-
if device_kind in ["TPU v5", "TPU v6 lite"]:
353-
return True
354-
return False
355-
356350
@keras_utils.no_automatic_dependency_tracking
357351
def _sparsecore_init(
358352
self,

0 commit comments

Comments
 (0)