I'm building a Transformer decoder in Flax using `nn.scan` to improve compilation times. With `decode=True` in `MultiHeadDotProductAttention`, the `cache` mutable variable is not initialized on its own, causing a pytree structure mismatch during execution.
The issue is basically equivalent to #2754, except with a `cache` twist on it.
Error message:

```
TypeError: scan body function carry input and carry output must have the same pytree structure, but they differ:
The input carry component c[0][0] is a <class 'dict'> with 0 child but the corresponding component of the carry output is a <class 'dict'> with 1 child, so the numbers of children do not match, with the symmetric difference of key sets: {'cache'}.
```
The issue is that I don't know how to initialize the `cache` variable for a `scan`. I've done this with generic submodules before, but not with scanned ones.
Currently `TransformersBlock` accounts for all of the compilation time.
How would I do this? I see `nnx` has `MultiHeadAttention.init_cache`, but Linen does not.