Skip to content

Commit 777b011

Browse files
committed
fixed import errors in test
1 parent 307ea9b commit 777b011

File tree

3 files changed

+11
-9
lines changed

3 files changed

+11
-9
lines changed

keras_rs/src/layers/embedding/jax/checkpoint_utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""A Wrapper over orbax CheckpointManager for Keras3 Jax TPU Embeddings."""
22

33
from typing import Any
4-
from etils import epath
5-
import jax
4+
65
import keras
76
import orbax.checkpoint as ocp
7+
from etils import epath
88

99

1010
class JaxKeras3CheckpointManager(ocp.CheckpointManager):
@@ -47,20 +47,18 @@ def save_state(self, epoch: int) -> None:
4747
state, metrics = self._get_state()
4848
self.save(
4949
epoch * self._steps_per_epoch,
50-
args=ocp.args.StandardSave(state),
51-
metrics=ocp.args.StandardSave(metrics),
50+
args=ocp.args.StandardSave(item=state),
5251
)
5352

5453
def restore_state(self, step: int | None = None) -> None:
5554
"""Restores the model from the checkpoint directory.
5655
5756
Args:
58-
step: The step number to restore the state from. Default=None
57+
step: The step .number to restore the state from. Default=None
5958
restores the latest step.
6059
"""
6160
if step is None:
6261
step = self.latest_step()
63-
6462
# Restore the model state only, not metrics.
6563
state, _ = self._get_state()
6664
restored_state = self.restore(

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def sparsecore_build(
384384
) -> None:
385385
del input_shapes # Unused.
386386

387-
if self._sparsecore_built:
387+
if hasattr(self, "_sparsecore_built") and self._sparsecore_built:
388388
return
389389

390390
feature_specs = config_conversion.keras_to_jte_feature_configs(

keras_rs/src/layers/embedding/jax/distributed_embedding_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919
from jax_tpu_embedding.sparsecore.utils import utils as jte_utils
2020

2121
from keras_rs.src.layers.embedding import test_utils as keras_test_utils
22-
from keras_rs.src.layers.embedding.jax import config_conversion
2322
from keras_rs.src.layers.embedding.jax import checkpoint_utils
23+
from keras_rs.src.layers.embedding.jax import config_conversion
2424
from keras_rs.src.layers.embedding.jax import (
2525
distributed_embedding as jax_distributed_embedding,
2626
)
@@ -564,7 +564,10 @@ def loss_fn(y_true, y_pred):
564564
table_stacking=target_stacking,
565565
name=embedding_layer_name,
566566
)
567-
layer_2._sparsecore_build()
567+
input_shapes = jax.tree.map(
568+
lambda f: f.input_shape, feature_configs_with_zero_init
569+
)
570+
layer_2.build(input_shapes)
568571
model_for_restore = keras.Sequential([layer_2])
569572
manager_for_restore = checkpoint_utils.JaxKeras3CheckpointManager(
570573
model_for_restore,
@@ -574,6 +577,7 @@ def loss_fn(y_true, y_pred):
574577
)
575578
model_for_restore.compile(jit_compile=True, loss=loss_fn)
576579
model_for_restore.build()
580+
model_for_restore.optimizer.build(model_for_restore.trainable_variables)
577581
manager_for_restore.restore_state()
578582
jax.tree.map(
579583
np.testing.assert_array_equal,

0 commit comments

Comments
 (0)