
Conversation

@hertschuh
Collaborator

The number of SparseCore chips per TPU is now retrieved from the strategy. This makes `distributed_embedding_test.py` pass on V6e.

@hertschuh hertschuh requested a review from cantonios September 15, 2025 17:45

@gemini-code-assist (bot) left a comment


Summary of Changes

Hello @hertschuh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request makes the DistributedEmbedding layer determine the number of SparseCore chips per TPU dynamically from the TensorFlow distribution strategy's hardware features. This removes a previously hardcoded value, making the layer adaptable to different TPU hardware configurations, including V6e. It also adds a runtime check that guides users to call get_embedding_tables within the correct TPUStrategy scope, preventing misuse.

Highlights

  • Dynamic SparseCore Count: The number of SparseCore chips per TPU is no longer hardcoded; it is now retrieved dynamically from the TensorFlow distribution strategy's hardware features, improving adaptability across different TPU versions (a rough sketch follows this list).
  • Improved TPU Strategy Handling: A runtime check has been added to ensure that get_embedding_tables is called within the correct TPUStrategy scope, providing a helpful error message and guidance if this condition is not met.
  • Test Compatibility: The distributed_embedding_test.py file was updated to wrap the get_embedding_tables call within the strategy's scope, ensuring tests pass correctly on various TPU architectures, such as V6e.
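
To make the first highlight concrete, here is a minimal sketch of how a layer might query the SparseCore count from the active strategy instead of hardcoding it. The `tpu_hardware_feature` / `num_embedding_devices_per_chip` lookup and the fallback value are assumptions about the TensorFlow API, not code taken from this PR's diff.

import tensorflow as tf

def _num_sparsecores_per_device(default=4):
    # Best-effort lookup of the SparseCore count from the current strategy.
    # Assumption: `strategy.extended.tpu_hardware_feature` exposes a
    # `num_embedding_devices_per_chip` field on recent TensorFlow versions;
    # the `default` fallback used when it does not is arbitrary here.
    strategy = tf.distribute.get_strategy()
    if not isinstance(strategy, tf.distribute.TPUStrategy):
        raise RuntimeError(
            "Expected to be called under a `tf.distribute.TPUStrategy`, got "
            f"{type(strategy).__name__}."
        )
    hardware_feature = getattr(strategy.extended, "tpu_hardware_feature", None)
    num_sc = getattr(hardware_feature, "num_embedding_devices_per_chip", None)
    return num_sc if num_sc else default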


@gemini-code-assist (bot) left a comment


Code Review

This pull request refactors the DistributedEmbedding layer to dynamically determine the number of SparseCores from the TPU strategy, removing a hardcoded value. This is a great improvement for hardware portability. The changes are well-implemented, including adding a runtime check to ensure get_embedding_tables is called within the correct strategy scope. I have one suggestion to make this new check even more robust by ensuring the strategy is the same instance used during layer initialization.

Comment on lines +284 to +291
if not self._is_tpu_strategy(strategy):
    raise RuntimeError(
        "`DistributedEmbedding.get_embedding_tables` needs to be "
        "called under the TPUStrategy that DistributedEmbedding was "
        f"created with, but is being called under strategy {strategy}. "
        "Please use `with strategy.scope()` when calling "
        "`get_embedding_tables`."
    )


Severity: medium

This check is a good addition for robustness. However, it only verifies that the current strategy is a TPUStrategy, not that it's the same one the layer was initialized with. This could lead to subtle bugs if multiple TPUStrategy objects are in use, as properties like num_replicas_in_sync might differ between strategies, causing inconsistencies.

To make this check more robust and align with the error message's intent, I recommend storing the strategy during initialization and comparing against that stored instance here.
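
For reference, storing the strategy during initialization could look roughly like this; the method signature and the trailing setup are placeholders rather than the layer's actual code:

def _sparsecore_init(self, strategy, *args, **kwargs):
    # Keep a handle to the TPUStrategy the layer was created under so that
    # `get_embedding_tables` can later verify it runs under the same instance.
    self._tpu_strategy = strategy
    # ... the existing SparseCore setup would continue here ...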

You would first need to add `self._tpu_strategy = strategy` in `_sparsecore_init` (as sketched above). Then you can apply the following suggestion:

Suggested change

Replace:

if not self._is_tpu_strategy(strategy):
    raise RuntimeError(
        "`DistributedEmbedding.get_embedding_tables` needs to be "
        "called under the TPUStrategy that DistributedEmbedding was "
        f"created with, but is being called under strategy {strategy}. "
        "Please use `with strategy.scope()` when calling "
        "`get_embedding_tables`."
    )

with:

strategy = tf.distribute.get_strategy()
if not hasattr(self, "_tpu_strategy") or strategy is not self._tpu_strategy:
    raise RuntimeError(
        "`DistributedEmbedding.get_embedding_tables` must be called "
        "under the same TPUStrategy that DistributedEmbedding was "
        f"created with, but is being called under strategy {strategy}. "
        "Please use `with strategy.scope()` when calling "
        "`get_embedding_tables`."
    )
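
As a usage note, the scoping requirement in the error message amounts to something like the snippet below. The TPU setup follows the standard TensorFlow pattern, and the `keras_rs.layers.DistributedEmbedding(feature_configs)` construction is shown only schematically (the feature configuration is elided):

import tensorflow as tf
import keras_rs

# Standard TPU initialization.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

feature_configs = ...  # placeholder: the layer's feature configuration

# Create the layer under the TPUStrategy.
with strategy.scope():
    embedding = keras_rs.layers.DistributedEmbedding(feature_configs)

# ... build and train the model ...

# Retrieve the tables under the same strategy's scope.
with strategy.scope():
    tables = embedding.get_embedding_tables()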

@hertschuh hertschuh merged commit 2fbdc2c into keras-team:main Sep 16, 2025
5 checks passed
@hertschuh hertschuh deleted the tf_num_sc branch September 16, 2025 17:13
LakshmiKalaKadali pushed a commit to LakshmiKalaKadali/keras-rs that referenced this pull request Oct 23, 2025
…TF. (keras-team#147)

