Skip to content

Commit b6b6325

Browse files
authored
Only enable JAX on linux_x86_64. (#101)
1 parent 6dc2af9 commit b6b6325

File tree

3 files changed

+11
-3
lines changed

3 files changed

+11
-3
lines changed

keras_rs/src/layers/embedding/distributed_embedding.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1+
import platform
2+
import sys
3+
14
import keras
25

36
from keras_rs.src.api_export import keras_rs_export
47

5-
if keras.backend.backend() == "jax":
8+
# JAX TPU embedding is only available on linux_x86_64.
9+
if (
10+
keras.backend.backend() == "jax"
11+
and sys.platform == "linux"
12+
and platform.machine().lower() == "x86_64"
13+
):
614
from keras_rs.src.layers.embedding.jax.distributed_embedding import (
715
DistributedEmbedding as BackendDistributedEmbedding,
816
)

requirements-jax-cuda.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ torch>=2.1.0
1111
jax[cuda12_pip]==0.6.0
1212

1313
# Support for large embeddings.
14-
jax-tpu-embedding
14+
jax-tpu-embedding;sys_platform == 'linux' and platform_machine == 'x86_64'
1515

1616
-r requirements-common.txt

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ torch>=2.1.0
88

99
# Jax.
1010
jax[cpu]
11-
jax-tpu-embedding
11+
jax-tpu-embedding;sys_platform == 'linux' and platform_machine == 'x86_64'
1212

1313
# pre-commit checks (formatting, linting, etc.)
1414
pre-commit

0 commit comments

Comments
 (0)