One more issue 😄. I promise this is the last one. There are a lot of questions about `flax.nn.scan`, and neither RTD nor the existing GitHub issues answer them.

With a very deep model, compilation times become insane: it takes about an hour to compile the model for the Nvidia runtime. So I decided to prevent loop unrolling with `jax.lax.scan` and its lifted counterpart `flax.nn.scan` (a small sketch of the difference follows the list below). However, I ran into multiple issues. An incomplete list of the issues follows.
- There is no clear way to initialize scanned submodules. I came up with the workaround of passing everything as `args` and `kwargs` to `__call__` of the submodule (in this case `MLP`).
- `flax.nn.scan` does not accept the keyword arguments that RTD describes.
- The function `flax.nn.scan` always returns `(carry, args)`, even if there is only a `carry` and no `args`.
- RTD says that `target` should be either a type derived from `nn.Module` or a function which accepts an `nn.Module` (the type?) as its first positional argument, but the function variant does not work (see the `FAIL` comment in the reproduction below).
- If one specifies the `name` of the modules in `MLP`, an exception is thrown. This is a bit strange because all parameter trees are merged into a single parameter tree anyway (a sketch of what I mean follows the reproduction below).
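Before the full reproduction, here is a minimal sketch (my own, not part of the repro; the `step` body is just a placeholder) of why I reach for `jax.lax.scan` in the first place: a Python loop is inlined once per repetition, while `scan` traces the step function a single time.

```python
import jax
import jax.numpy as jnp

def step(carry, _):
    # Placeholder for one "layer" of work; keeps the carry shape unchanged.
    return carry @ carry.T @ carry, None

def unrolled(x, depth):
    # Every iteration is inlined into the jaxpr, so compile time grows with depth.
    for _ in range(depth):
        x, _ = step(x, None)
    return x

def scanned(x, depth):
    # lax.scan traces `step` once and repeats it, keeping compile time flat.
    x, _ = jax.lax.scan(step, x, xs=None, length=depth)
    return x
```

The full reproduction of the `flax.nn.scan` problems: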
```python
import flax.linen as nn
import jax
import jax.numpy as jnp


def initializer(val):
    def init(key, shape, dtype):
        return jnp.full(shape, val, dtype)
    return init


class MLP(nn.Module):
    @nn.compact
    def __call__(self, xs, var):
        h = nn.Dense(features=2, kernel_init=initializer(var))(xs)
        h = nn.relu(h)
        h = nn.Dense(features=2, kernel_init=initializer(var))(h)
        return xs + h, None


class Transformer(nn.Module):
    length: int = 3

    def setup(self):
        def fn(self, *args, **kwargs):
            return MLP(self, *args, **kwargs)

        # FAIL: A function instead of a type derived from nn.Module does not work.
        #
        #   ScanMLP = nn.scan(target=fn, ...)
        #
        #   jax._src.traceback_util.UnfilteredStackTrace: TypeError:
        #   Transformer.setup.<locals>.fn() missing 1 required positional
        #   argument: 'self'

        # OK: No problems.
        ScanMLP = nn.scan(target=MLP,
                          variable_axes={'params': 0},
                          variable_broadcast=False,
                          split_rngs={'params': True},
                          length=self.length)

        self.vars = jnp.arange(self.length)  # e.g. [0, 1, 2]
        self.mlp = ScanMLP()  # FAIL: ScanMLP(self.vars)

    @nn.compact  # OK: This decorator does nothing. Why?
    def __call__(self, xs):
        carry, out = self.mlp(xs, self.vars)  # OK: Axis 0 (implicitly).
        assert out is None
        return carry


model = Transformer(length=1250)
ys, state = jax.jit(model.init_with_output)(jax.random.PRNGKey(42),
                                             jnp.ones((3, 2)))

kernel = state['params']['mlp']['Dense_0']['kernel']
assert (kernel[0, ...] == jnp.zeros((2, 2))).all()
assert (kernel[1, ...] == jnp.ones((2, 2))).all()
```
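As a sanity check on the "merged into a single parameter tree" point, this is how I inspect the result (a small sketch; the printed shapes are my assumption, not observed output):

```python
# Every leaf of the merged tree gains a leading scan axis of size `length`
# because of variable_axes={'params': 0}.
shapes = jax.tree_util.tree_map(jnp.shape, state['params'])
print(shapes)
# Presumably something like:
#   {'mlp': {'Dense_0': {'bias': (1250, 2), 'kernel': (1250, 2, 2)},
#            'Dense_1': {'bias': (1250, 2), 'kernel': (1250, 2, 2)}}}
```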
In these experiments, flax v0.6.3 and jax v0.4.1 were used.
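For completeness, regarding the last bullet about module names: the variant below is what I mean by specifying `name` inside `MLP` (a sketch with hypothetical names; this is the shape of the code that raises an exception for me).

```python
class NamedMLP(nn.Module):
    @nn.compact
    def __call__(self, xs, var):
        # Same MLP, but the submodules get explicit names; under nn.scan this
        # throws, even though the parameter trees are merged into one tree anyway.
        h = nn.Dense(features=2, kernel_init=initializer(var), name='dense_in')(xs)
        h = nn.relu(h)
        h = nn.Dense(features=2, kernel_init=initializer(var), name='dense_out')(h)
        return xs + h, None
```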