
Conversation

neo-alex (Contributor)

This PR is a follow-up to the previous PR #21548, addressing the last uncovered edge case: supporting optional inputs in model.fit/evaluate/predict with the TensorFlow backend when mixed values (sometimes None, sometimes a tensor) are produced by the same generator.

Here is a concrete example:

import numpy as np
from keras import Input, Model, layers, losses

class OptionalInputLayer(layers.Layer):
    def __init__(self):
        super().__init__()
        self.dense = layers.Dense(2)

    def call(self, x, y=None):
        z = x if y is None else x + y
        return self.dense(z)

# Create model with 2 inputs (the second one being optional)
i1 = Input((2,), name="input1")
i2 = Input((2,), name="input2", optional=True)
outputs = OptionalInputLayer()(i1, i2)
model = Model({"input1": i1, "input2": i2}, outputs)
model.compile(loss=losses.MeanSquaredError())

# Train from generator (optional input always None or always a tensor)
data_generator1 = (
    ({"input1": np.ones((2, 2)), "input2": None}, np.ones((2, 2)))
    for _ in range(4)
)
model.fit(x=data_generator1)  # WORKS

# Train from generator (optional input with mixed None/tensor values)
data_generator2 = (
    (
        {
            "input1": np.ones((2, 2)),
            "input2": None if i % 2 == 0 else np.ones((2, 2)),
        },
        np.ones((2, 2)),
    )
    for i in range(4)
)
model.fit(x=data_generator2)  # DOESN'T WORK IN TENSORFLOW (UNTIL THIS PR)

As already discussed with @hertschuh in the previous PR, covering this case in the TensorFlow backend is not trivial, for a few reasons:

  • tf.data.Dataset (used by default to feed data in the TF backend) doesn't support None values directly; it requires tf.experimental.Optional instead, so values must be converted back and forth between these formats
  • for the None/tensor -> tf.experimental.Optional conversion: since data adapters are independent from the model specification, optional inputs have to be inferred from the data itself by observing the first batches produced by the generator, as already done in the present code
  • for the tf.experimental.Optional -> None/tensor conversion: the methods optional.has_value() & optional.get_value() must be used to unwrap the content; however, they return a tensor which can be symbolic when the code is traced (with the default run_eagerly=False), so it cannot be reliably tested as a Python bool and can only be used inside a TF control-flow operator like tf.cond/tf.switch_case. Additionally, these TF operators don't support returning None in one branch and a tensor in another, so we suggest both converting the optional inputs & calling the step function (which returns fixed output shapes with no risk of None) inside one big tf.switch_case, with as many branches as needed to cover every possible combination of optional input presence (see the sketch below)
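
To make that last point concrete, here is a minimal sketch of the dispatch (not the actual PR code; optionals is assumed to be a flat list of tf.experimental.Optional values and step_fn stands in for the traced step function):

import tensorflow as tf

def dispatch_on_presence(optionals, step_fn):
    n = len(optionals)
    # 1 where the optional holds a value, 0 where it is absent
    flags = tf.stack([tf.cast(opt.has_value(), tf.int32) for opt in optionals])
    # Encode the presence pattern as a single integer in [0, 2**n - 1]
    shifts = tf.range(n, dtype=tf.int32)
    index = tf.reduce_sum(tf.bitwise.left_shift(flags, shifts))

    def make_branch(mask):
        def branch():
            # Unwrap each optional according to this branch's fixed
            # presence pattern (mask is a plain Python int)
            inputs = [
                opt.get_value() if (mask >> i) & 1 else None
                for i, opt in enumerate(optionals)
            ]
            return step_fn(inputs)
        return branch

    branches = [make_branch(m) for m in range(1 << n)]
    return tf.switch_case(index, branch_fns=branches)

Since every branch ends by calling step_fn, all 2**n branches return the same tensor structure, which is exactly what tf.cond/tf.switch_case require.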

An existing unit test was also extended to test for the newly supported case.

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @neo-alex, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the TensorFlow backend in Keras to support generators that provide mixed None and tensor values for optional model inputs during model.fit, evaluate, and predict operations. Previously, the TensorFlow backend struggled with this specific edge case due to tf.data.Dataset's inability to directly handle None values. The changes introduce robust handling of tf.experimental.Optional types, enabling Keras to dynamically adapt to the presence or absence of optional inputs within a batch, ensuring seamless training and inference with diverse data streams.

Highlights

  • Enhanced Optional Input Handling: Implemented support for generators yielding a mix of None and tensor values for optional model inputs in TensorFlow backend operations (fit, evaluate, predict).
  • tf.experimental.Optional Integration: Improved conversion logic between Python None/tensor values and TensorFlow's tf.experimental.Optional type to correctly represent and process optional inputs within the TensorFlow graph.
  • Dynamic Input Branching: Utilized tf.switch_case to create dynamic execution paths based on the presence or absence of optional inputs within a batch, allowing the model to adapt its computation graph at runtime.
  • Improved Sample Weight Calculation: Adjusted train_step and test_step to correctly determine batch size for sample_weight calculation, even when inputs contain None values.
  • Extended Test Coverage: Added a new test case to model_test.py specifically validating the behavior of generators providing mixed None and tensor values for optional inputs.

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces support for mixed optional inputs (i.e., inputs that can be either a tensor or None across different batches) from generators in model.fit, evaluate, and predict when using the TensorFlow backend. This is a significant enhancement that addresses a previously unsupported edge case.

The implementation is robust, involving several key changes:

  1. Optional Input Inference: The data adapter utilities now infer optional inputs by observing the first few batches from a generator.
  2. TensorFlow Spec Conversion: KerasTensor specs are correctly converted to tf.OptionalSpec for TensorFlow's tf.data pipeline.
  3. Dynamic Input Handling: A sophisticated wrapper, _autoconvert_optionals, is introduced in the TensorFlow trainer. It uses tf.switch_case to dynamically handle all possible combinations of present/absent optional inputs within a tf.function, which is necessary for graph-mode execution.
  4. Error Handling and Testing: The changes include excellent, user-friendly error messages for misconfigured generators and a new test case to validate the mixed optional input scenario.

The overall approach is well-designed to handle the complexities of TensorFlow's static graph execution. My main feedback is a suggestion to add a warning for models with a large number of optional inputs, as the current implementation's complexity grows exponentially, which could impact performance.

shifts = tf.range(n, dtype=tf.int32)  # [0, 1, 2, ..., n-1]
terms = tf.bitwise.left_shift(flag_vec, shifts)  # shape [n]
index = tf.reduce_sum(terms)  # scalar int32 in [0, 2^n - 1]
ncases = 1 << n  # = 2^n total cases (efficiently computed)
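# Hypothetical worked example: with n = 3 and only the 1st and 3rd optionals
# present, flag_vec = [1, 0, 1], so index = 1 + 0 + 4 = 5 (one of
# ncases = 8 possible presence combinations)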
@gemini-code-assist (bot):

Severity: medium

The number of branches created (ncases) grows exponentially with the number of optional inputs (n). While this approach is correct, it could lead to performance issues (long compilation time, large graph size) if a model has a high number of optional inputs. Consider adding a warning if n exceeds a certain threshold (e.g., 10) to inform the user about the potential performance impact.

            ncases = 1 << n  # = 2^n total cases (efficiently computed)
            if n > 10:
                warnings.warn(
                    f"Model has {n} optional inputs. This will create 2**{n} "
                    f"branches in the computational graph, which may be slow to "
                    f"compile and consume a lot of memory."
                )

neo-alex (Contributor, Author):

This is absolutely correct, but after many attempts to solve this issue, this solution is the only valid one I could find despite the potential performance impact. Note, however, that this can only meaningfully impact performance in one very specific case: fit/evaluate/predict with jit_compile=True of a TF model with many optional inputs, fed from a generator with mixed values (whereas today this case is not supported at all). I agree that a warning could be a good idea though; I will add it in a new commit in this PR.

neo-alex (Contributor, Author):

Done in this commit.

codecov-commenter commented Aug 22, 2025

Codecov Report

❌ Patch coverage is 78.18182% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.44%. Comparing base (19367bc) to head (63e8758).
⚠️ Report is 2 commits behind head on master.

Files with missing lines                                   Patch %   Lines
...c/trainers/data_adapters/generator_data_adapter.py       0.00%   10 Missing ⚠️
keras/src/backend/tensorflow/trainer.py                    92.59%   1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #21609      +/-   ##
==========================================
- Coverage   82.73%   82.44%   -0.30%     
==========================================
  Files         572      572              
  Lines       57261    57379     +118     
  Branches     8961     8980      +19     
==========================================
- Hits        47377    47308      -69     
- Misses       7672     7771      +99     
- Partials     2212     2300      +88     
Flag              Coverage Δ
keras             82.25% <78.18%> (-0.30%) ⬇️
keras-jax         63.54% <16.36%> (-0.24%) ⬇️
keras-numpy       57.83% <16.36%> (-0.18%) ⬇️
keras-openvino    34.32% <0.00%> (-0.10%) ⬇️
keras-tensorflow  64.22% <78.18%> (-0.14%) ⬇️
keras-torch       63.75% <16.36%> (-0.18%) ⬇️

Flags with carried forward coverage won't be shown.


neo-alex (Contributor, Author):

One more side comment regarding auto-inferring optional inputs from the first generated batches: currently, only the first 2 batches are observed for this, but we could increase that to 3 by default to also detect optional inputs with variable dimensions (beyond the batch one), e.g. a generator starting with None / a tensor of shape (3, 4) / a tensor of shape (3, 5), as illustrated below. I can make the change if you think it is a good idea.
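
For illustration, a hypothetical generator of that kind (made-up shapes):

import numpy as np

def gen():
    # Observing only the first 2 batches would see None and a (3, 4) tensor,
    # pinning the last dimension to 4; only the 3rd batch reveals it varies.
    yield {"input1": np.ones((3, 4)), "input2": None}, np.ones((3, 1))
    yield {"input1": np.ones((3, 4)), "input2": np.ones((3, 4))}, np.ones((3, 1))
    yield {"input1": np.ones((3, 4)), "input2": np.ones((3, 5))}, np.ones((3, 1))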

@@ -68,7 +68,9 @@ def train_step(self, data):
         )
         self._loss_tracker.update_state(
             loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
+            )[0],
Collaborator:

This is probably not a valid use case, but in theory i for i in tree.flatten(x) if i is not None could be exhausted, in which case tf.shape can't be called.

neo-alex (Contributor, Author):

Indeed, this assumes that not all inputs are simultaneously optional and actually None (otherwise, I don't see how the batch size could be inferred anyway...). In any case, this edit makes this line of code (which was present long before my 2 PRs) more robust, now that some optional inputs can be None when using model.fit/evaluate/predict: without this change, if the first input received (e.g. the first alphabetical key in a dict input) were optional and None, it would crash.
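
For illustration, a minimal sketch of the pattern under discussion (made-up inputs; keras.tree is the public structure-flattening utility, which treats None as a leaf):

import tensorflow as tf
from keras import tree

x = {"input1": None, "input2": tf.ones((8, 2))}  # first flattened entry is None
batch = next(i for i in tree.flatten(x) if i is not None)  # the (8, 2) tensor
batch_size = tf.shape(batch)[0]  # 8; tree.flatten(x)[0] would have crashed here
# If every input were None, next(...) would raise StopIteration instead.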

@@ -96,7 +98,9 @@ def test_step(self, data):
         )
         self._loss_tracker.update_state(
             loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
+            )[0],
Collaborator:

Same here

neo-alex (Contributor, Author):

Same reply here :)

branches = [make_branch(m) for m in range(ncases)]

# Compute result with switch case
return tf.switch_case(index, branch_fns=branches)
Collaborator:

Wow. All this just to unpack tf.Optionals.

First I'm impressed that you got this to work. But I'm hoping we can find something simpler. I played with it a bit but couldn't find anything so far. What I'm confused about is why we can't just do if x.has_value().

The related question is whether it will interfere with the existing usage of optionals here: https://github.com/keras-team/keras/blob/master/keras/src/backend/tensorflow/trainer.py#L238

neo-alex (Contributor, Author) commented Aug 27, 2025:

"All this just to unpack tf.Optionals" --> I know, believe me I tried many other things before to make this edge case work, without success - but if you find a better alternative, it would be great!

"What I'm confused about is why we can't just do if x.has_value()" --> I understand your confusion, especially since a similar x.has_value() is used without problem in a Python if in the line of code you mention. Actually, this works here because this part of the code is never traced by TF Autograph (since it isn't wrapped in tf.function). Unfortunately, when using a data generator, our x of interest is fetched one item at a time from the iterator inside multi_step_on_iterator/one_step_on_data which are both potentially traced by TF Autograph here/there (when run_eagerly=False, which is the default setting). During tracing, x.has_value() is then a symbolic boolean tensor (which cannot be simply evaluated as a boolean in Python since it doesn't yet hold any actual value), so I believe we have no choice but use a TF conditional operator to handle whether x has value or not. Additionally, since TF conditional operators (like tf.cond/tf.switch_case) don't support having some branch(es) outputting None when other branch(es) output a tensor, we cannot use a simple tf.cond(x.has_value(), lambda: x.get_value(), lambda: None) to unwrap the required value in the format Keras expect. The only workaround I found is then to call the model inside the conditional branches, since it has consistent tensor outputs (no more possible None in outputs, as opposed to inputs) which is compatible with TF conditional operators... at the cost of creating a big tf.switch_case covering all possible optional input combinations.

"The related question is whether it will interfere with the existing usage of optionals here" --> I don't think that it creates any interference... The solution I propose is just heavier than I wished for (I would also like something simpler), but hopefully my verbose explanation above explains why I did it this way. Of course, I am more than open to a better way if you find one that still enables the new unit test I added in this PR.
