Skip to content

Commit 8dfa56d

Browse files
authored
Add out_sharding argument WrappedKerasInitializer. (#102)
This is for forward-compatibility. Latest versions of JAX introduce the `out_sharding` argument.
1 parent b6b6325 commit 8dfa56d

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

keras_rs/src/layers/embedding/jax/config_conversion.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,13 @@ def key(self) -> Union[jax.Array, None]:
3636
return None
3737

3838
def __call__(
39-
self, key: Any, shape: Any, dtype: Any = jnp.float_
39+
self,
40+
key: Any,
41+
shape: Any,
42+
dtype: Any = jnp.float_,
43+
out_sharding: Any = None,
4044
) -> jax.Array:
45+
del out_sharding
4146
# Force use of provided key. The JAX backend for random initializers
4247
# forwards the `seed` attribute to the underlying JAX random functions.
4348
if key is not None and hasattr(self.initializer, "seed"):

0 commit comments

Comments
 (0)