Add TF support for mixed optional inputs (None/tensor) from generator in model.fit/evaluate/predict #21609
base: master
Changes from all commits: 7d6c672, 0448a38, 0e9c75d, e23dba4, 63e8758
```diff
@@ -68,7 +68,9 @@ def train_step(self, data):
         )
         self._loss_tracker.update_state(
             loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
+            )[0],
         )
         if self.optimizer is not None:
             loss = self.optimizer.scale_loss(loss)
```
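Why this change matters is easiest to see in isolation. Below is a minimal standalone sketch (the dict `x` and its keys are illustrative, not from the PR), using `dm-tree`, which, like the `tree` API used above, keeps `None` entries as leaves:

```python
import tensorflow as tf
import tree  # dm-tree: like the tree API used above, it keeps None leaves

# Illustrative batch where optional input "a" is absent on this step.
x = {"a": None, "b": tf.zeros([32, 4])}

# Old logic: dicts are flattened in sorted key order, so tree.flatten(x)[0]
# is None here, and tf.shape(None) raises an error.

# New logic: infer the batch size from the first non-None input. (If every
# input were None, next() would raise StopIteration; see the review
# discussion at the end.)
batch_size = tf.shape(next(i for i in tree.flatten(x) if i is not None))[0]
print(int(batch_size))  # 32
```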
```diff
@@ -96,7 +98,9 @@ def test_step(self, data):
         )
         self._loss_tracker.update_state(
             loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
+            )[0],
         )
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
```

Reviewer: Same here

Author: Same reply here :)
```diff
@@ -109,17 +113,63 @@ def predict_step(self, data):
         return y_pred
 
     def _autoconvert_optionals(self, step_func):
-        # Wrapper converting (nested) TF Optional in input data to None
+        # Wrapper converting (nested) TF Optional in input data to tensor/None
         @functools.wraps(step_func)
         def wrapper(data):
-            converted_data = tree.map_structure(
-                lambda i: (
-                    None if isinstance(i, tf.experimental.Optional) else i
-                ),
-                data,
-            )
-            result = step_func(converted_data)
-            return result
+            # Flatten inputs
+            flat = tree.flatten(data)
+
+            # List positions of optional inputs
+            opt_pos = [
+                i
+                for i, x in enumerate(flat)
+                if isinstance(x, tf.experimental.Optional)
+            ]
+            if not opt_pos:  # if nothing optional, just call on data (shortcut)
+                return step_func(data)
+
+            # Build bitmask for optionals (1=present, 0=empty)
+            opts = [flat[i] for i in opt_pos]
+            flags = [o.has_value() for o in opts]  # 1 Tensor[bool] per optional
+            flag_vec = tf.cast(tf.stack(flags), tf.int32)  # shape [n]
+
+            # Compute bitmask index via TF ops (traceable with symbolic tensors)
+            n = len(flags)  # number of optional inputs
+            shifts = tf.range(n, dtype=tf.int32)  # [0, 1, 2, ..., n-1]
+            terms = tf.bitwise.left_shift(flag_vec, shifts)  # shape [n]
+            index = tf.reduce_sum(terms)  # scalar int32 in [0, 2^n - 1]
+            ncases = 1 << n  # = 2^n total cases (efficiently computed)
+
+            if n > 10:
+                warnings.warn(
+                    f"Model has {n} optional inputs. This will create 2^{n} "
+                    "branches in the computational graph, which may be slow to "
+                    "compile and consume a lot of memory."
+                )
+
+            # Create a branch function for each possible bitmask combination
+            def make_branch(mask: int):
+                def branch():
+                    # Unwrap optional inputs to tensor/None in flat inputs
+                    inputs = list(flat)
+                    for j, i in enumerate(opt_pos):
+                        if inputs[i].element_spec is None:
+                            inputs[i] = None  # special case: always None
+                        else:
+                            present = ((mask >> j) & 1) == 1
+                            inputs[i] = opts[j].get_value() if present else None
+
+                    # Pack rebuilt inputs like original data
+                    struct_inputs = tree.pack_sequence_as(data, inputs)
+
+                    # Call step_func (same output shapes for all branches)
+                    return step_func(struct_inputs)
+
+                return branch
+
+            branches = [make_branch(m) for m in range(ncases)]
+
+            # Compute result with switch case
+            return tf.switch_case(index, branch_fns=branches)
 
         return wrapper
```

Reviewer (on `ncases = 1 << n`): The number of branches created (`ncases = 1 << n`) is 2**n, i.e. exponential in the number of optional inputs. Suggestion:

```python
ncases = 1 << n  # = 2^n total cases (efficiently computed)
if n > 10:
    warnings.warn(
        f"Model has {n} optional inputs. This will create 2**{n} "
        f"branches in the computational graph, which may be slow to "
        f"compile and consume a lot of memory."
    )
```

Author: This is absolutely correct, but after many attempts to solve this issue, this solution is the only valid one I could find despite this potential performance impact. However, note that this could only meaningfully impact performance in this very special case: fit/evaluate/predict with …

Author: Done in this commit.

Reviewer: Wow. All this just to unpack `tf.Optional`s. First, I'm impressed that you got this to work. But I'm hoping we can find something simpler. I played with it a bit but couldn't find anything so far. What I'm confused about is why we can't just do … The related question is whether it will interfere with the existing usage of optionals here: https://github.com/keras-team/keras/blob/master/keras/src/backend/tensorflow/trainer.py#L238

Author: "All this just to unpack tf.Optionals" --> I know, believe me, I tried many other things to make this edge case work before, without success - but if you find a better alternative, it would be great!
"What I'm confused about is why we can't just do `if …`" and "The related question is whether it will interfere with the existing usage of optionals here" --> I don't think that it creates any interference. The solution I propose is just heavier than I wished for (I would also like something simpler), but hopefully my verbose explanation above explains why I did it this way. Of course, I am more than open to a better way if you find one that still enables the new unit test I added in this PR.
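A note on the question above about using a plain Python `if`: under `tf.function` tracing, `Optional.has_value()` returns a symbolic boolean tensor, so Python control flow cannot branch on it; instead, all 2**n present/absent combinations are traced up front and the right one is selected at run time. Below is a reduced, standalone sketch of that bitmask-plus-`tf.switch_case` pattern with two hypothetical scalar optionals (illustrative code, not the PR's):

```python
import tensorflow as tf

@tf.function
def dispatch(opt_a, opt_b):
    opts = [opt_a, opt_b]
    # has_value() yields symbolic bool tensors here, so a Python `if`
    # cannot branch on them; pack the flags into one integer index.
    flags = tf.cast(tf.stack([o.has_value() for o in opts]), tf.int32)
    shifts = tf.range(len(opts), dtype=tf.int32)
    index = tf.reduce_sum(tf.bitwise.left_shift(flags, shifts))

    def make_branch(mask):
        def branch():
            # In this branch, optional j is statically known present iff
            # bit j of mask is set; absent optionals fall back to zero.
            vals = [
                o.get_value() if (mask >> j) & 1 else tf.constant(0.0)
                for j, o in enumerate(opts)
            ]
            return vals[0] + vals[1]
        return branch

    # Trace one branch per combination (2**n); pick one at run time.
    return tf.switch_case(index, branch_fns=[make_branch(m) for m in range(4)])

some = tf.experimental.Optional.from_value(tf.constant(2.0))
empty = tf.experimental.Optional.empty(tf.TensorSpec([], tf.float32))
print(dispatch(some, empty).numpy())  # 2.0
print(dispatch(some, some).numpy())   # 4.0
```

Because each branch is traced with every optional statically known to be present or absent, `get_value()` is only ever executed in branches where the value actually exists.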
Reviewer: This is probably not a valid use case, but in theory `i for i in tree.flatten(x) if i is not None` could run out, in which case `tf.shape` can't be called.

Author: Indeed, this assumes that not all inputs are both optional and actually None at the same time (otherwise, I don't see how to infer the batch size anyway...). In any case, this edit makes this line of code (which was present long before my two PRs) more robust, now that some optional inputs can be None when using model.fit/evaluate/predict. Without this change, if the first input received (e.g. the first alphabetical key in a dict input) were optional and None, it would crash.
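To make the end-to-end behavior concrete, here is a hedged sketch of the use case this PR targets: a generator that yields `None` for one input on some batches and a tensor on others, consumed by `model.fit`. The model, shapes, and the `None` fallback in `call` are illustrative assumptions, not code from this PR; the authoritative contract is the unit test added in the PR.

```python
import numpy as np
import keras
from keras import ops

class TwoInputModel(keras.Model):
    # Hypothetical model: input "b" may be absent (None) on some batches.
    def __init__(self):
        super().__init__()
        self.dense = keras.layers.Dense(1)

    def call(self, inputs):
        a, b = inputs["a"], inputs["b"]
        if b is None:  # in each traced branch, b is either None or a tensor
            b = ops.zeros_like(a)
        return self.dense(a + b)

def gen():
    for step in range(8):
        a = np.random.rand(16, 4).astype("float32")
        b = None if step % 2 else np.random.rand(16, 4).astype("float32")
        yield {"a": a, "b": b}, np.random.rand(16, 1).astype("float32")

model = TwoInputModel()
model.compile(optimizer="adam", loss="mse")
model.fit(gen(), epochs=1)  # mixed None/tensor batches, enabled by this PR
```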