Add TF support for mixed optional inputs (None/tensor) from generator in model.fit/evaluate/predict #21609
base: master
Conversation
Summary of Changes
Hello @neo-alex, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request enhances the TensorFlow backend in Keras to support generators that provide mixed None and tensor values for optional model inputs during model.fit, evaluate, and predict operations. Previously, the TensorFlow backend struggled with this specific edge case due to tf.data.Dataset's inability to directly handle None values. The changes introduce robust handling of tf.experimental.Optional types, enabling Keras to dynamically adapt to the presence or absence of optional inputs within a batch, ensuring seamless training and inference with diverse data streams.
Highlights
- Enhanced Optional Input Handling: Implemented support for generators yielding a mix of None and tensor values for optional model inputs in TensorFlow backend operations (fit, evaluate, predict).
- tf.experimental.Optional Integration: Improved conversion logic between Python None/tensor values and TensorFlow's tf.experimental.Optional type to correctly represent and process optional inputs within the TensorFlow graph.
- Dynamic Input Branching: Utilized tf.switch_case to create dynamic execution paths based on the presence or absence of optional inputs within a batch, allowing the model to adapt its computation graph at runtime.
- Improved Sample Weight Calculation: Adjusted train_step and test_step to correctly determine batch size for sample_weight calculation, even when inputs contain None values.
- Extended Test Coverage: Added a new test case to model_test.py specifically validating the behavior of generators providing mixed None and tensor values for optional inputs.
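For context on the mechanism the summary refers to, here is a minimal standalone sketch (not code from the PR) of how `tf.experimental.Optional` can represent an absent input as an empty optional and a present one as a wrapped tensor:

```python
import tensorflow as tf

# An "empty" optional stands in for None; "from_value" wraps an actual tensor.
spec = tf.TensorSpec(shape=(None, 4), dtype=tf.float32)
absent = tf.experimental.Optional.empty(spec)
present = tf.experimental.Optional.from_value(tf.ones((2, 4)))

print(absent.has_value().numpy())   # False
print(present.has_value().numpy())  # True
print(present.get_value().shape)    # (2, 4)
```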
Code Review
This pull request introduces support for mixed optional inputs (i.e., inputs that can be either a tensor or `None` across different batches) from generators in `model.fit`, `evaluate`, and `predict` when using the TensorFlow backend. This is a significant enhancement that addresses a previously unsupported edge case.
The implementation is robust, involving several key changes:
- Optional Input Inference: The data adapter utilities now infer optional inputs by observing the first few batches from a generator.
- TensorFlow Spec Conversion: `KerasTensor` specs are correctly converted to `tf.OptionalSpec` for TensorFlow's `tf.data` pipeline.
- Dynamic Input Handling: A sophisticated wrapper, `_autoconvert_optionals`, is introduced in the TensorFlow trainer. It uses `tf.switch_case` to dynamically handle all possible combinations of present/absent optional inputs within a `tf.function`, which is necessary for graph-mode execution.
- Error Handling and Testing: The changes include excellent, user-friendly error messages for misconfigured generators and a new test case to validate the mixed optional input scenario.
The overall approach is well-designed to handle the complexities of TensorFlow's static graph execution. My main feedback is a suggestion to add a warning for models with a large number of optional inputs, as the current implementation's complexity grows exponentially, which could impact performance.
```python
shifts = tf.range(n, dtype=tf.int32)             # [0, 1, 2, ..., n-1]
terms = tf.bitwise.left_shift(flag_vec, shifts)  # shape [n]
index = tf.reduce_sum(terms)                     # scalar int32 in [0, 2^n - 1]
ncases = 1 << n  # = 2^n total cases (efficiently computed)
```
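To make the bit-packing concrete, a standalone sketch (not from the PR): with 3 optional inputs where the 1st and 3rd are present, the flags pack into branch index 5, i.e. 0b101:

```python
import tensorflow as tf

# 3 optional inputs; the 1st and 3rd are present (flags read low bit first).
flag_vec = tf.constant([1, 0, 1], dtype=tf.int32)
shifts = tf.range(3, dtype=tf.int32)             # [0, 1, 2]
terms = tf.bitwise.left_shift(flag_vec, shifts)  # [1, 0, 4]
print(int(tf.reduce_sum(terms)))                 # 5 -> branch 0b101 of 2**3
```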
The number of branches created (`ncases`) grows exponentially with the number of optional inputs (`n`). While this approach is correct, it could lead to performance issues (long compilation time, large graph size) if a model has a high number of optional inputs. Consider adding a warning if `n` exceeds a certain threshold (e.g., 10) to inform the user about the potential performance impact.
```python
import warnings  # assumed to be available at module level

ncases = 1 << n  # = 2^n total cases (efficiently computed)
if n > 10:
    warnings.warn(
        f"Model has {n} optional inputs. This will create 2**{n} "
        "branches in the computational graph, which may be slow to "
        "compile and consume a lot of memory."
    )
```
This is absolutely correct, but after many attempts to solve this issue, this solution is the only valid one I could find despite the potential performance impact. However, note that this could only meaningfully impact performance in one very specific case: fit/evaluate/predict with `jit_compile=True` of a TF model with many optional inputs, fed from a generator with mixed values (whereas today this case is not supported at all). I agree that a warning could be a good idea though; I will add it in a new commit in this PR.
Done in this commit.
Codecov Report

❌ Patch coverage is …

```diff
@@            Coverage Diff             @@
##           master   #21609      +/-   ##
==========================================
- Coverage   82.73%   82.44%   -0.30%
==========================================
  Files         572      572
  Lines       57261    57379     +118
  Branches     8961     8980      +19
==========================================
- Hits        47377    47308      -69
- Misses       7672     7771      +99
- Partials     2212     2300      +88
```
One more side comment with regard to auto-inferring optional inputs by observing the first generated data batches: currently, only the first 2 batches are observed, but we could increase this to 3 by default to cover detection of optional inputs with variable dimensions (beyond the batch one), e.g. a generator starting with None / a tensor of shape (3, 4) / a tensor of shape (3, 5). I can make the change if you think it is a good idea.
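For illustration, a hypothetical generator of this kind (names and shapes invented): with only the first 2 batches observed, the optional input's spec would be pinned to shape (3, 4), and only the 3rd batch reveals the variable dimension:

```python
import numpy as np

def gen():
    yield {"opt": None}              # batch 1: optional input absent
    yield {"opt": np.zeros((3, 4))}  # batch 2: spec would be fixed to (3, 4)
    yield {"opt": np.zeros((3, 5))}  # batch 3: reveals a variable dimension
```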
```diff
@@ -68,7 +68,9 @@ def train_step(self, data):
     )
     self._loss_tracker.update_state(
         loss_module.unscale_loss_for_distribution(loss),
-        sample_weight=tf.shape(tree.flatten(x)[0])[0],
+        sample_weight=tf.shape(
+            next(i for i in tree.flatten(x) if i is not None)
```
This is probably not a valid use case, but in theory `i for i in tree.flatten(x) if i is not None` could be exhausted (raising `StopIteration`), in which case `tf.shape` can't be called.
Indeed, this assumes that not all inputs are simultaneously optional and actually None (otherwise, I don't see how the batch size could be inferred anyway...). In any case, this edit makes this line of code (which predates my 2 PRs) more robust, now that some optional inputs can be None when using model.fit/evaluate/predict. Without this change, if the first input received (e.g. the first alphabetical key in a dict input) were optional and None, it would crash.
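A small standalone sketch of the behavior discussed here, assuming the public `keras.tree` module (the trainer uses the internal equivalent):

```python
import tensorflow as tf
from keras import tree  # public counterpart of the tree module in the trainer

x = {"a": None, "b": tf.ones((8, 4))}  # "a" flattens first (alphabetical)
first = next(i for i in tree.flatten(x) if i is not None)
batch_size = tf.shape(first)[0]  # 8; the old code would have hit None
```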
```diff
@@ -96,7 +98,9 @@ def test_step(self, data):
     )
     self._loss_tracker.update_state(
         loss_module.unscale_loss_for_distribution(loss),
-        sample_weight=tf.shape(tree.flatten(x)[0])[0],
+        sample_weight=tf.shape(
+            next(i for i in tree.flatten(x) if i is not None)
```
Same here
Same reply here :)
```python
branches = [make_branch(m) for m in range(ncases)]

# Compute result with switch case
return tf.switch_case(index, branch_fns=branches)
```
Wow. All this just to unpack `tf.Optional`s.
First, I'm impressed that you got this to work. But I'm hoping we can find something simpler. I played with it a bit but couldn't find anything so far. What I'm confused about is why we can't just do `if x.has_value()`.
The related question is whether it will interfere with the existing usage of optionals here: https://github.com/keras-team/keras/blob/master/keras/src/backend/tensorflow/trainer.py#L238
"All this just to unpack tf.Optionals" --> I know, believe me I tried many other things before to make this edge case work, without success - but if you find a better alternative, it would be great!
"What I'm confused about is why we can't just do if x.has_value()
" --> I understand your confusion, especially since a similar x.has_value()
is used without problem in a Python if
in the line of code you mention. Actually, this works here because this part of the code is never traced by TF Autograph (since it isn't wrapped in tf.function
). Unfortunately, when using a data generator, our x
of interest is fetched one item at a time from the iterator inside multi_step_on_iterator/one_step_on_data which are both potentially traced by TF Autograph here/there (when run_eagerly=False
, which is the default setting). During tracing, x.has_value()
is then a symbolic boolean tensor (which cannot be simply evaluated as a boolean in Python since it doesn't yet hold any actual value), so I believe we have no choice but use a TF conditional operator to handle whether x has value or not. Additionally, since TF conditional operators (like tf.cond
/tf.switch_case
) don't support having some branch(es) outputting None
when other branch(es) output a tensor, we cannot use a simple tf.cond(x.has_value(), lambda: x.get_value(), lambda: None)
to unwrap the required value in the format Keras expect. The only workaround I found is then to call the model inside the conditional branches, since it has consistent tensor outputs (no more possible None
in outputs, as opposed to inputs) which is compatible with TF conditional operators... at the cost of creating a big tf.switch_case
covering all possible optional input combinations.
"The related question is whether it will interfere with the existing usage of optionals here" --> I don't think that it creates any interference... The solution I propose is just heavier than I wished for (I would also like something simpler), but hopefully my verbose explanation above explains why I did it this way. Of course, I am more than open to a better way if you find one that still enables the new unit test I added in this PR.
This PR is a follow-up to the previous PR #21548 in order to address the last uncovered edge case: supporting optional inputs in model.fit/evaluate/predict with the TensorFlow backend when mixed values (sometimes None, sometimes a tensor) are produced by the same generator.
Here is a concrete example:
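A minimal sketch of such a setup (illustrative layer, names, and shapes; with this PR's changes, this is expected to run on the TensorFlow backend):

```python
import numpy as np
import keras
from keras import layers

class AddIfPresent(layers.Layer):
    # Tolerates a missing optional input by falling back to the required one.
    def call(self, x, extra=None):
        return x if extra is None else x + extra

required = keras.Input(shape=(4,), name="required")
optional = keras.Input(shape=(4,), name="optional", optional=True)
outputs = layers.Dense(1)(AddIfPresent()(required, optional))
model = keras.Model({"required": required, "optional": optional}, outputs)
model.compile(optimizer="adam", loss="mse")

def gen():
    # Alternates between providing and omitting the optional input.
    while True:
        x1 = {"required": np.ones((8, 4)), "optional": np.ones((8, 4))}
        yield x1, np.ones((8, 1))
        x2 = {"required": np.ones((8, 4)), "optional": None}
        yield x2, np.ones((8, 1))

model.fit(gen(), steps_per_epoch=4, epochs=1)
```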
As already discussed with @hertschuh in the previous PR, covering this case in the TensorFlow backend is not trivial for a few reasons:

- `optional.has_value()` & `optional.get_value()` must be used to unwrap the correct content; however, they return a tensor which can be symbolic when the code is traced (with the default `run_eagerly=False`), so they can't be reliably tested as a Python bool and can only be used inside a TF control-flow operator like `tf.cond`/`tf.switch_case`
- additionally, these TF operators don't support returning `None` in one branch and a tensor in the other, so we suggest both converting optional inputs & calling the step function (which returns fixed output shapes, without risk of `None`) inside a big `tf.switch_case` with as many branches as needed to cover every possible optional-presence combination (a condensed sketch of this pattern follows below)

An existing unit test was also extended to test for the newly supported case.