Skip to content

Commit 8a5cfca

Browse files
committed
Debug
1 parent 6c3f8ce commit 8a5cfca

File tree

2 files changed

+4
-0
lines changed

2 files changed

+4
-0
lines changed

examples/ml_perf/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,11 @@ def call(self, inputs: dict[str, Tensor]) -> Tensor:
196196
for small_emb_feature in small_emb_inputs.keys():
197197
small_emb_input = small_emb_inputs[small_emb_feature]
198198
embedding_layer = self.small_embedding_layers[small_emb_feature]
199+
199200
embedding = embedding_layer(small_emb_input)
200201
embedding = ops.sum(embedding, axis=-2)
201202
small_embeddings.append(embedding)
203+
202204
small_embeddings = ops.concatenate(small_embeddings, axis=-1)
203205

204206
# Interaction

keras_rs/src/layers/embedding/jax/distributed_embedding.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,8 @@ def sparsecore_build(
412412
feature_specs = config_conversion.keras_to_jte_feature_configs(
413413
self._sc_feature_configs
414414
)
415+
print(f"--->{self._sc_feature_configs=}")
416+
print(f"--->{feature_specs=}")
415417

416418
# Distribution for sparsecore operations.
417419
sparsecore_distribution, sparsecore_layout = (

0 commit comments

Comments
 (0)