72 changes: 61 additions & 11 deletions keras/src/backend/tensorflow/trainer.py
@@ -68,7 +68,9 @@ def train_step(self, data):
             )
         self._loss_tracker.update_state(
             loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
Collaborator:
This is probably not a valid use case, but in theory "i for i in tree.flatten(x) if i is not None" could run out, in which case tf.shape can't be called.

Contributor Author:
Indeed, this assumes that not all inputs are both optional and actually None at the same time (otherwise, I don't see how to infer the batch size anyway...). In any case, this edit makes this line of code (which was present long before my 2 PRs) more robust, now that some optional inputs can be None when using model.fit/evaluate/predict. Without this change, if the first input received (e.g. first alphabetic key in a dict input) is optional and None, it would crash.

+            )[0],
         )
         if self.optimizer is not None:
             loss = self.optimizer.scale_loss(loss)
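To make the thread above concrete, here is a minimal sketch (not part of the diff; the dict input and shapes are illustrative) of the failure mode and the fix:

import tensorflow as tf
from keras.src import tree

# A dict input whose alphabetically-first key is an optional input that
# happens to be None for this batch.
x = {"a": None, "b": tf.ones((32, 4))}
flat = tree.flatten(x)  # None leaves are kept

# Old code: tf.shape(flat[0])[0] fails because flat[0] is None.
# New code: take the batch dimension from the first non-None leaf.
batch_size = tf.shape(next(i for i in flat if i is not None))[0]  # 32

# Reviewer's caveat: if every leaf were None, next() would raise
# StopIteration, but then there is nothing to infer a batch size from.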
@@ -96,7 +98,9 @@ def test_step(self, data):
             )
         self._loss_tracker.update_state(
             loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
Collaborator:
Same here.

Contributor Author:
Same reply here :)

+            )[0],
         )
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)

@@ -109,17 +113,63 @@ def predict_step(self, data):
         return y_pred

     def _autoconvert_optionals(self, step_func):
-        # Wrapper converting (nested) TF Optional in input data to None
+        # Wrapper converting (nested) TF Optional in input data to tensor/None
         @functools.wraps(step_func)
         def wrapper(data):
-            converted_data = tree.map_structure(
-                lambda i: (
-                    None if isinstance(i, tf.experimental.Optional) else i
-                ),
-                data,
-            )
-            result = step_func(converted_data)
-            return result
+            # Flatten inputs
+            flat = tree.flatten(data)
+
+            # List positions of optional inputs
+            opt_pos = [
+                i
+                for i, x in enumerate(flat)
+                if isinstance(x, tf.experimental.Optional)
+            ]
+            if not opt_pos:  # if nothing optional, just call on data (shortcut)
+                return step_func(data)
+
+            # Build bitmask for optionals (1=present, 0=empty)
+            opts = [flat[i] for i in opt_pos]
+            flags = [o.has_value() for o in opts]  # 1 Tensor[bool] per optional
+            flag_vec = tf.cast(tf.stack(flags), tf.int32)  # shape [n]
+
+            # Compute bitmask index via TF ops (traceable with symbolic tensors)
+            n = len(flags)  # number of optional inputs
+            shifts = tf.range(n, dtype=tf.int32)  # [0, 1, 2, ..., n-1]
+            terms = tf.bitwise.left_shift(flag_vec, shifts)  # shape [n]
+            index = tf.reduce_sum(terms)  # scalar int32 in [0, 2^n - 1]
+            ncases = 1 << n  # = 2^n total cases (efficiently computed)
Contributor:
[medium] The number of branches created (ncases) grows exponentially with the number of optional inputs (n). While this approach is correct, it could lead to performance issues (long compilation time, large graph size) if a model has a high number of optional inputs. Consider adding a warning if n exceeds a certain threshold (e.g., 10) to inform the user about the potential performance impact.

Suggested change:

            ncases = 1 << n  # = 2^n total cases (efficiently computed)
            if n > 10:
                warnings.warn(
                    f"Model has {n} optional inputs. This will create 2**{n} "
                    f"branches in the computational graph, which may be slow to "
                    f"compile and consume a lot of memory."
                )

Contributor Author:
This is absolutely correct, but after many attempts to solve this issue, this solution is the only valid one I could find despite this potential performance impact. However, note that this could only meaningfully impact performance in this very special case: fit/evaluate/predict with jit_compile=True of a TF model with many optional inputs from a generator with mixed values (whereas today this case is not supported at all). I agree that a warning could be a good idea though; I will add it in a new commit in this PR.

Contributor Author:
Done in this commit.

+            if n > 10:
+                warnings.warn(
+                    f"Model has {n} optional inputs. This will create 2^{n} "
+                    "branches in the computational graph, which may be slow to "
+                    "compile and consume a lot of memory."
+                )
+
+            # Create a branch function for each possible bitmask combination
+            def make_branch(mask: int):
+                def branch():
+                    # Unwrap optional inputs to tensor/None in flat inputs
+                    inputs = list(flat)
+                    for j, i in enumerate(opt_pos):
+                        if inputs[i].element_spec is None:
+                            inputs[i] = None  # special case: always None
+                        else:
+                            present = ((mask >> j) & 1) == 1
+                            inputs[i] = opts[j].get_value() if present else None
+
+                    # Pack rebuilt inputs like original data
+                    struct_inputs = tree.pack_sequence_as(data, inputs)
+
+                    # Call step_func (same output shapes for all branches)
+                    return step_func(struct_inputs)
+
+                return branch
+
+            branches = [make_branch(m) for m in range(ncases)]
+
+            # Compute result with switch case
+            return tf.switch_case(index, branch_fns=branches)
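As a standalone illustration of the bitmask dispatch above (values made up): with three optional inputs whose presence flags are [True, False, True], the computed index selects branch 5 of the 2^3 = 8 branches:

import tensorflow as tf

flags = [tf.constant(True), tf.constant(False), tf.constant(True)]
flag_vec = tf.cast(tf.stack(flags), tf.int32)  # [1, 0, 1]
shifts = tf.range(3, dtype=tf.int32)           # [0, 1, 2]
index = tf.reduce_sum(tf.bitwise.left_shift(flag_vec, shifts))
print(index.numpy())  # 5 == 1*2**0 + 0*2**1 + 1*2**2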
Collaborator:
Wow. All this just to unpack tf.Optionals.

First I'm impressed that you got this to work. But I'm hoping we can find something simpler. I played with it a bit but couldn't find anything so far. What I'm confused about is why we can't just do if x.has_value().

The related question is whether it will interfere with the existing usage of optionals here: https://github.com/keras-team/keras/blob/master/keras/src/backend/tensorflow/trainer.py#L238

Contributor Author (@neo-alex, Aug 27, 2025):
"All this just to unpack tf.Optionals" --> I know, believe me I tried many other things to make this edge case work, without success - but if you find a better alternative, it would be great!

"What I'm confused about is why we can't just do if x.has_value()" --> I understand your confusion, especially since a similar x.has_value() is used without problem in a Python if in the line of code you mention. Actually, this works there because that part of the code is never traced by TF Autograph (since it isn't wrapped in tf.function). Unfortunately, when using a data generator, our x of interest is fetched one item at a time from the iterator inside multi_step_on_iterator/one_step_on_data, which are both potentially traced by TF Autograph (when run_eagerly=False, which is the default setting). During tracing, x.has_value() is then a symbolic boolean tensor (which cannot simply be evaluated as a boolean in Python since it doesn't yet hold any actual value), so I believe we have no choice but to use a TF conditional operator to handle whether x has a value or not. Additionally, since TF conditional operators (like tf.cond/tf.switch_case) don't support having some branch(es) output None when other branch(es) output a tensor, we cannot use a simple tf.cond(x.has_value(), lambda: x.get_value(), lambda: None) to unwrap the required value in the format Keras expects. The only workaround I found is then to call the model inside the conditional branches, since it has consistent tensor outputs (no more possible None in outputs, as opposed to inputs), which is compatible with TF conditional operators... at the cost of creating a big tf.switch_case covering all possible optional input combinations.

"The related question is whether it will interfere with the existing usage of optionals here" --> I don't think that it creates any interference... The solution I propose is just heavier than I wished for (I would also like something simpler), but hopefully my verbose explanation above explains why I did it this way. Of course, I am more than open to a better way if you find one that still enables the new unit test I added in this PR.


         return wrapper

11 changes: 8 additions & 3 deletions keras/src/models/model_test.py
@@ -1254,16 +1254,21 @@ def test_functional_optional_inputs(self, is_optional_none):
         model.predict(x={"x1": x1, "x2": x2})

     @parameterized.named_parameters(
-        ("optional_none", True), ("optional_tensor", False)
+        ("optional_none", True),
+        ("optional_tensor", False),
+        ("optional_mixed", "sometimes"),
     )
     def test_functional_optional_inputs_generator(self, is_optional_none):
         model = _get_model_optional_inputs()
         x1 = np.ones((2, 2))
-        x2 = None if is_optional_none else np.ones((2, 2))
         y_true = np.ones((2, 2))

         def data_generator(with_y=True):
-            for _ in range(4):
+            for i in range(4):
+                if is_optional_none == "sometimes":
+                    x2 = None if i % 2 == 0 else np.ones((2, 2))
+                else:
+                    x2 = None if is_optional_none else np.ones((2, 2))
                 yield ({"x1": x1, "x2": x2},) + ((y_true,) if with_y else ())

         model.compile(loss="mse", optimizer="adam")
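For reference, a sketch of how such a mixed generator would be consumed (illustrative only; the test's exact assertions are not shown in this diff):

model.fit(data_generator(), steps_per_epoch=4)
model.evaluate(data_generator(), steps=4)
model.predict(data_generator(with_y=False), steps=4)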
42 changes: 31 additions & 11 deletions keras/src/trainers/data_adapters/data_adapter_utils.py
@@ -149,7 +149,15 @@ def get_keras_tensor_spec(batches):
         A nested structure of `KerasTensor`.
     """

-    def get_single_tensor_spec(*tensors):
+    def get_single_tensor_spec(*tensors_or_none):
+        # Filter out None values (possible for optional inputs)
+        tensors = [t for t in tensors_or_none if t is not None]
+        if len(tensors) == 0:
+            return None
+
+        # Detect optional input when some tensors are None
+        is_optional = len(tensors_or_none) > len(tensors)
+
         x = tensors[0]
         if not hasattr(x, "shape"):
             # Try to convert to a numpy array.
@@ -176,21 +184,26 @@ def get_single_tensor_spec(*tensors):

         dtype = backend.standardize_dtype(x.dtype)
         if is_tensorflow_ragged(x):
-            return backend.KerasTensor(
+            tensor_spec = backend.KerasTensor(
                 shape=shape,
                 dtype=dtype,
                 ragged=True,
                 ragged_rank=x.ragged_rank,
                 row_splits_dtype=x.row_splits.dtype,
             )
-        if is_tensorflow_sparse(x) or is_scipy_sparse(x) or is_jax_sparse(x):
-            return backend.KerasTensor(shape=shape, dtype=dtype, sparse=True)
+        elif is_tensorflow_sparse(x) or is_scipy_sparse(x) or is_jax_sparse(x):
+            tensor_spec = backend.KerasTensor(
+                shape=shape, dtype=dtype, sparse=True
+            )
         else:
-            return backend.KerasTensor(shape=shape, dtype=dtype)
+            tensor_spec = backend.KerasTensor(shape=shape, dtype=dtype)
+
+        backend.common.tensor_attributes.set_tensor_attr(
+            tensor_spec, "_keras_optional", is_optional
+        )
+        return tensor_spec

-    return tree.map_structure(get_single_tensor_spec, *batches)
+    return tree.map_structure(
+        get_single_tensor_spec, *batches, none_is_leaf=False
+    )
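A small sketch of the resulting behavior (hypothetical batches; shapes are illustrative):

import numpy as np
from keras.src.trainers.data_adapters import data_adapter_utils

batch_1 = {"x1": np.ones((2, 2)), "x2": None}
batch_2 = {"x1": np.ones((2, 2)), "x2": np.ones((2, 2))}

specs = data_adapter_utils.get_keras_tensor_spec([batch_1, batch_2])
# specs["x2"] is marked optional (its "_keras_optional" attribute is
# True) because it is None in one peeked batch and a tensor in the
# other; specs["x1"] is not. An input that is None in all peeked
# batches yields a spec of None.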


@@ -214,16 +227,23 @@ def convert_to_tf_tensor_spec(keras_tensor, batch_axis_to_none=True):
     if batch_axis_to_none:
         shape[0] = None
     if keras_tensor.ragged:
-        return tf.RaggedTensorSpec(
+        tf_tensor_spec = tf.RaggedTensorSpec(
             shape=shape,
             dtype=keras_tensor.dtype,
             ragged_rank=keras_tensor.ragged_rank,
             row_splits_dtype=keras_tensor.row_splits_dtype,
         )
     elif keras_tensor.sparse:
-        return tf.SparseTensorSpec(shape=shape, dtype=keras_tensor.dtype)
+        tf_tensor_spec = tf.SparseTensorSpec(
+            shape=shape, dtype=keras_tensor.dtype
+        )
     else:
-        return tf.TensorSpec(shape=shape, dtype=keras_tensor.dtype)
+        tf_tensor_spec = tf.TensorSpec(shape=shape, dtype=keras_tensor.dtype)
+    if backend.common.tensor_attributes.get_tensor_attr(
+        keras_tensor, "_keras_optional"
+    ):
+        tf_tensor_spec = tf.OptionalSpec(tf_tensor_spec)
+    return tf_tensor_spec
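A matching sketch for the TF conversion (assumes the attribute helpers used above):

import tensorflow as tf
from keras.src import backend
from keras.src.trainers.data_adapters import data_adapter_utils

kt = backend.KerasTensor(shape=(2, 2), dtype="float32")
backend.common.tensor_attributes.set_tensor_attr(kt, "_keras_optional", True)

spec = data_adapter_utils.convert_to_tf_tensor_spec(kt)
# -> tf.OptionalSpec(tf.TensorSpec(shape=(None, 2), dtype=tf.float32)):
# the optional flag survives as an OptionalSpec wrapper around the
# regular TensorSpec (batch axis set to None by default).
print(spec)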


 def get_tensor_spec(batches):
23 changes: 22 additions & 1 deletion keras/src/trainers/data_adapters/generator_data_adapter.py
@@ -32,8 +32,27 @@ def get_tf_dataset(self):
         from keras.src.utils.module_utils import tensorflow as tf

         def convert_to_tf(x, spec):
+            is_optional = isinstance(spec, tf.OptionalSpec)
             if x is None:
-                return tf.experimental.Optional.empty(None)
+                if not is_optional:
+                    raise TypeError(
+                        "Generator yielded a `None` element where a tensor of "
+                        f"shape {spec.shape} was expected. For every optional "
+                        "tensor your generator provides, make sure that the "
+                        "generator's first two batches include a `None` value "
+                        "and an actual tensor."
+                    )
+                return tf.experimental.Optional.empty(spec._element_spec)
+            if is_optional:
+                spec = spec._element_spec
+                if spec is None:
+                    raise TypeError(
+                        f"Generator yielded a tensor of shape {x.shape} where "
+                        "a `None` element was expected. For every optional "
+                        "tensor your generator provides, make sure that the "
+                        "generator's first two batches include a `None` value "
+                        "and an actual tensor."
+                    )
             if data_adapter_utils.is_scipy_sparse(x):
                 x = data_adapter_utils.scipy_sparse_to_tf_sparse(x)
             elif data_adapter_utils.is_jax_sparse(x):
@@ -48,6 +67,8 @@ def convert_to_tf(x, spec):
"dimension value wherever there is a variable input "
"dimension."
)
if is_optional:
return tf.experimental.Optional.from_value(x)
return x
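Finally, a standalone sketch of the tf.experimental.Optional round-trip that this adapter and the trainer rely on (toy spec and values):

import tensorflow as tf

spec = tf.TensorSpec(shape=(None, 2), dtype=tf.float32)

empty = tf.experimental.Optional.empty(spec)  # stands in for None
full = tf.experimental.Optional.from_value(tf.ones((2, 2)))

print(empty.has_value().numpy())  # False -> trainer branch passes None
print(full.has_value().numpy())   # True  -> trainer branch unwraps:
print(full.get_value().shape)     # (2, 2)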

         def get_tf_iterator():