
Initialization of Submodules Lifted with flax.nn.scan #2754

@daskol

Description


One more issue 😄. I promise this is the last one. There are a lot of questions about flax.nn.scan, and neither RTD nor the existing GitHub issues answer them.

With a very deep model, compilation times become insane: it takes about an hour to compile the model for the Nvidia runtime. So I decided to prevent loop unrolling with jax.lax.scan and its lifted counterpart flax.nn.scan (a minimal sketch of the kind of loop I am replacing is shown right below). However, I ran into multiple issues; an incomplete list follows.
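For context, this is roughly the kind of loop I am trying to replace. It is only a sketch with an identity block (not my actual model), but it shows why compile time explodes with a Python loop and why lax.scan avoids it; note that scan always returns a (carry, ys) pair even when there are no per-step outputs.

import jax
import jax.numpy as jnp


def block(carry, _):
    # A real transformer block would go here; identity for illustration.
    return carry, None


x0 = jnp.ones((3, 2))

# Under jit, this Python loop is unrolled: XLA sees 1250 copies of the body.
h = x0
for _ in range(1250):
    h, _ = block(h, None)

# With lax.scan the body is traced and compiled once, so compile time does
# not grow with depth. scan always returns (carry, ys); ys is None here.
carry, ys = jax.lax.scan(block, x0, None, length=1250)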

  1. There is no clear way to initialize scanned submodules. I came up with a workaround: pass everything as args and kwargs to the __call__ of the submodule (MLP in this case).
  2. The keyword arguments of flax.nn.scan do not match what RTD says.
  3. flax.nn.scan always returns a (carry, out) pair, even when there is only a carry and no output.
  4. RTD says that target should be either an nn.Module subclass or a function that accepts an nn.Module (a type?) as its first positional argument, but passing a function does not work (see the FAIL comment in the code below).
  5. If one specifies the names of the submodules inside MLP, an exception is thrown. This is a bit strange, because all per-step parameter trees are merged into a single parameter tree anyway.
import flax.linen as nn
import jax
import jax.numpy as jnp


def initializer(val):
    def init(key, shape, dtype):
        return jnp.full(shape, val, dtype)

    return init


class MLP(nn.Module):

    @nn.compact
    def __call__(self, xs, var):
        h = nn.Dense(features=2, kernel_init=initializer(var))(xs)
        h = nn.relu(h)
        h = nn.Dense(features=2, kernel_init=initializer(var))(h)
        return xs + h, None


class Transformer(nn.Module):

    length: int = 3

    def setup(self):
        def fn(self, *args, **kwargs):
            return MLP(self, *args, **kwargs)

        # FAIL: Function instead of derived type from nn.Module does not work.
        #
        #   ScanMLP = nn.scan(target=fn, ...)
        #
        #   jax._src.traceback_util.UnfilteredStackTrace: TypeError:
        #   Transformer.setup.<locals>.fn() missing 1 required positional
        #   argument: 'self'

        # OK: No problems.
        ScanMLP = nn.scan(target=MLP,
                          variable_axes={'params': 0},
                          variable_broadcast=False,
                          split_rngs={'params': True},
                          length=self.length)

        self.vars = jnp.arange(self.length)  # e.g. [0, 1, 2]
        self.mlp = ScanMLP()  # FAIL: ScanMLP(self.vars)

    @nn.compact  # OK: This decorator does nothing. Why?
    def __call__(self, xs):
        carry, out = self.mlp(xs, self.vars)  # OK: Axis 0 (implicitly).
        assert out is None
        return carry


model = Transformer(length=1250)
ys, state = jax.jit(model.init_with_output)(jax.random.PRNGKey(42),
                                            jnp.ones((3, 2)))
kernel = state['params']['mlp']['Dense_0']['kernel']
assert (kernel[0, ...] == jnp.zeros((2, 2))).all()
assert (kernel[1, ...] == jnp.ones((2, 2))).all()

In these experiments, flax v0.6.3 and jax v0.4.1 were used.
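For reference, my best guess at the intended pattern (based on the RTD LSTM example) is to call nn.scan on the Module class inside a @nn.compact __call__ rather than in setup. The sketch below reuses the MLP from above; CompactTransformer is just a name I made up for this variant, and I have not confirmed that it behaves any differently from the setup-based version.

class CompactTransformer(nn.Module):

    length: int = 3

    @nn.compact
    def __call__(self, xs):
        # Scan over the Module class itself; parameters get a leading axis
        # of size `length`, and the params RNG is split so every step
        # initializes its own slice.
        ScanMLP = nn.scan(MLP,
                          variable_axes={'params': 0},
                          split_rngs={'params': True},
                          length=self.length)
        # Per-step initializer values are passed as the scanned input,
        # mirroring the workaround from point 1 above.
        init_vals = jnp.arange(self.length)
        carry, out = ScanMLP()(xs, init_vals)
        return carry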
