DistributedEmbedding: do not hardcode the number of SparseCores with TF. #147
Conversation
Summary of Changes
Hello @hertschuh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the DistributedEmbedding layer's compatibility and robustness by dynamically determining the number of SparseCore chips per TPU from the TensorFlow distribution strategy. This change eliminates a previously hardcoded value, making the system more adaptable to diverse TPU hardware configurations, including V6e. Additionally, it introduces improved error handling to guide users in correctly utilizing the function within a TPUStrategy scope, preventing potential misuse.
Highlights
- Dynamic SparseCore Count: The number of SparseCore chips per TPU is no longer a hardcoded value but is now dynamically retrieved from the TensorFlow distribution strategy's hardware features, improving adaptability across different TPU versions (see the sketch after this list).
- Improved TPU Strategy Handling: A runtime check has been added to ensure that `get_embedding_tables` is called within the correct `TPUStrategy` scope, providing a helpful error message and guidance if this condition is not met.
- Test Compatibility: The `distributed_embedding_test.py` file was updated to wrap the `get_embedding_tables` call within the strategy's scope, ensuring tests pass correctly on various TPU architectures, such as V6e.
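As a rough sketch of the approach described above (not the PR's actual code), the SparseCore count can be derived from the strategy's TPU hardware features at initialization time instead of being hardcoded. The helper name `sparsecores_per_chip` and the `num_embedding_devices_per_chip` attribute are assumptions for illustration; the exact field the PR reads may differ.

```python
import tensorflow as tf


def sparsecores_per_chip(strategy: tf.distribute.TPUStrategy) -> int:
    """Sketch: derive the SparseCore count from the TPU strategy.

    `strategy.extended.tpu_hardware_feature` is a TPUStrategy property that
    exposes TPU hardware information, but `num_embedding_devices_per_chip`
    is assumed here and may not match the attribute the PR actually uses.
    """
    hardware_feature = strategy.extended.tpu_hardware_feature
    return hardware_feature.num_embedding_devices_per_chip
```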
Code Review
This pull request refactors the DistributedEmbedding layer to dynamically determine the number of SparseCores from the TPU strategy, removing a hardcoded value. This is a great improvement for hardware portability. The changes are well-implemented, including adding a runtime check to ensure get_embedding_tables is called within the correct strategy scope. I have one suggestion to make this new check even more robust by ensuring the strategy is the same instance used during layer initialization.
```python
if not self._is_tpu_strategy(strategy):
    raise RuntimeError(
        "`DistributedEmbedding.get_embedding_tables` needs to be "
        "called under the TPUStrategy that DistributedEmbedding was "
        f"created with, but is being called under strategy {strategy}. "
        "Please use `with strategy.scope()` when calling "
        "`get_embedding_tables`."
    )
```
This check is a good addition for robustness. However, it only verifies that the current strategy is a TPUStrategy, not that it's the same one the layer was initialized with. This could lead to subtle bugs if multiple TPUStrategy objects are in use, as properties like num_replicas_in_sync might differ between strategies, causing inconsistencies.
To make this check more robust and align with the error message's intent, I recommend storing the strategy during initialization and comparing against that stored instance here.
You would first need to add `self._tpu_strategy = strategy` in `_sparsecore_init`. Then you can apply the following suggestion:
```diff
-if not self._is_tpu_strategy(strategy):
-    raise RuntimeError(
-        "`DistributedEmbedding.get_embedding_tables` needs to be "
-        "called under the TPUStrategy that DistributedEmbedding was "
-        f"created with, but is being called under strategy {strategy}. "
-        "Please use `with strategy.scope()` when calling "
-        "`get_embedding_tables`."
-    )
+strategy = tf.distribute.get_strategy()
+if not hasattr(self, "_tpu_strategy") or strategy is not self._tpu_strategy:
+    raise RuntimeError(
+        "`DistributedEmbedding.get_embedding_tables` must be called "
+        "under the same TPUStrategy that DistributedEmbedding was "
+        f"created with, but is being called under strategy {strategy}. "
+        "Please use `with strategy.scope()` when calling "
+        "`get_embedding_tables`."
+    )
```
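For context, here is a minimal usage sketch of the pattern the error message asks for, assuming `strategy` is the `tf.distribute.TPUStrategy` the layer was created under; `build_distributed_embedding()` is a hypothetical stand-in for however the layer is actually constructed.

```python
import tensorflow as tf

# Standard TPU setup; the empty string resolves the locally attached TPU.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)
strategy = tf.distribute.TPUStrategy(resolver)

with strategy.scope():
    # Hypothetical helper standing in for the actual layer construction.
    layer = build_distributed_embedding()

# The check discussed above requires this call to happen under the same
# strategy's scope; outside of it, the RuntimeError is raised.
with strategy.scope():
    tables = layer.get_embedding_tables()
```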
Force-pushed from 5fe94a1 to 9738727:
The number of SparseCore chips per TPU is now retrieved from the strategy. This makes the `distributed_embedding_tests.py` pass on V6e.
Force-pushed from 9738727 to 77afe83.
DistributedEmbedding: do not hardcode the number of SparseCores with TF. (keras-team#147)

The number of SparseCore chips per TPU is now retrieved from the strategy. This makes the `distributed_embedding_tests.py` pass on V6e.