# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

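# Benchmark: trains a small MLP-Mixer flow-matching model on random data and
# compares two training-loop styles: calling nnx.jit directly on NNX objects
# ('nnx' mode) versus splitting the objects into a graphdef/state pair and
# jitting a pure function over the state pytree ('jax' mode).
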
# %%
from functools import partial
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import numpy as np
from einop import einop
from time import time
from tqdm import tqdm

from absl import flags
from absl import app

FLAGS = flags.FLAGS
flags.DEFINE_enum(
  'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
)
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
flags.DEFINE_integer('depth', 4, 'Depth of the model')


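# Two-layer MLP (Linear -> GELU -> Linear) used for both token mixing and
# channel mixing inside each mixer block.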
class MlpBlock(nnx.Module):
  def __init__(self, din: int, mlp_dim: int, rngs: nnx.Rngs):
    self.din, self.mlp_dim = din, mlp_dim
    self.linear_in = nnx.Linear(din, mlp_dim, rngs=rngs)
    self.linear_out = nnx.Linear(mlp_dim, din, rngs=rngs)

  def __call__(self, x):
    return self.linear_out(nnx.gelu(self.linear_in(x)))


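# Mixer block: a token-mixing MLP applied across patches followed by a
# channel-mixing MLP applied across features, each with a pre-LayerNorm and a
# residual connection.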
class MixerBlock(nnx.Module):
  def __init__(
    self,
    tokens_mlp_dim: int,
    channels_mlp_dim: int,
    hidden_dim: int,
    rngs: nnx.Rngs,
  ):
    self.tokens_mlp_dim = tokens_mlp_dim
    self.channels_mlp_dim = channels_mlp_dim
    self.hidden_dim = hidden_dim
    self.token_mixing = MlpBlock(tokens_mlp_dim, hidden_dim, rngs=rngs)
    self.channel_mixing = MlpBlock(channels_mlp_dim, hidden_dim, rngs=rngs)
    self.ln1 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)
    self.ln2 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)

  def __call__(self, x):
    # token mixing: transpose so the MLP acts across the patch dimension
    y = self.ln1(x)
    y = y.swapaxes(1, 2)
    y = self.token_mixing(y)
    y = y.swapaxes(1, 2)
    x = x + y
    # channel mixing: MLP acts across the feature dimension
    y = self.ln2(x)
    return x + self.channel_mixing(y)


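# MLP-Mixer that predicts a velocity field for flow matching: a time channel is
# concatenated to the input image, a strided Conv patchifies it, mixer blocks
# process the tokens, and a transposed Conv maps them back to image space.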
class MlpMixer(nnx.Module):
  def __init__(
    self,
    din: int,
    kernel_size: tuple[int, int],
    strides: tuple[int, int],
    num_blocks: int,
    hidden_dim: int,
    tokens_mlp_dim: int,
    channels_mlp_dim: int,
    rngs: nnx.Rngs,
  ):
    self.din = din
    self.kernel_size = kernel_size
    self.num_blocks = num_blocks
    self.hidden_dim = hidden_dim
    self.tokens_mlp_dim = tokens_mlp_dim
    self.channels_mlp_dim = channels_mlp_dim
    self.stem = nnx.Conv(
      din + 1,
      channels_mlp_dim,
      kernel_size=kernel_size,
      strides=strides,
      rngs=rngs,
    )
    self.blocks = [
      MixerBlock(tokens_mlp_dim, channels_mlp_dim, hidden_dim, rngs=rngs)
      for _ in range(num_blocks)
    ]
    self.pre_head_layer_norm = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)
    self.conv_t = nnx.ConvTranspose(
      channels_mlp_dim, din, kernel_size=kernel_size, strides=strides, rngs=rngs
    )

  def __call__(self, *, x, t):
    # add time feature to input
    t = einop(t, 'n -> n h w c', h=x.shape[1], w=x.shape[2], c=1)
    x = jnp.concatenate([x, t], axis=-1)
    # create patches
    x = self.stem(x)
    h, w = x.shape[1], x.shape[2]
    x = einop(x, 'n h w c -> n (h w) c')
    # apply blocks
    for block in self.blocks:
      x = block(x)
    x = self.pre_head_layer_norm(x)
    # recreate image
    x = einop(x, 'n (h w) c -> n h w c', h=h, w=w)
    x = self.conv_t(x)
    return x


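# Builds the model and optimizer from the command-line flags, runs `total_steps`
# training steps on a fixed random batch in one or both modes, and reports the
# final loss and the average time per step.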
def main(argv):
  print(argv)
  mode: str = FLAGS.mode
  total_steps: int = FLAGS.total_steps
  batch_size: int = FLAGS.batch_size
  width: int = FLAGS.width
  depth: int = FLAGS.depth

  print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}, {depth=}')

  X = np.random.uniform(size=(batch_size, 28, 28, 1))

  if mode == 'nnx' or mode == 'all':
    rngs = nnx.Rngs(0)
    # build the model from the width/depth flags so both modes train the same
    # architecture
    flow = MlpMixer(
      din=1,
      kernel_size=(2, 2),
      strides=(2, 2),
      num_blocks=depth,
      hidden_dim=width,
      tokens_mlp_dim=196,
      channels_mlp_dim=width,
      rngs=rngs,
    )
    optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4))

    mse = lambda a, b: jnp.mean((a - b) ** 2)

    @nnx.jit(donate_argnums=(0, 1, 2))
    def train_step_nnx(flow, optimizer, rngs, x_1):
      print('JITTING NNX')
      # flow matching: sample noise x_0 and a random time t per example
      x_0 = jax.random.normal(rngs(), x_1.shape)
      t = jax.random.uniform(rngs(), (len(x_1),))

      # interpolate between noise and data; the regression target is the
      # velocity dx_t = x_1 - x_0
      x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t)
      dx_t = x_1 - x_0

      loss, grads = nnx.value_and_grad(
        lambda flow: mse(flow(x=x_t, t=t), dx_t)
      )(flow)
      optimizer.update(grads)
      return loss

    losses = []
    t0 = time()
    for step in tqdm(range(total_steps), desc='NNX'):
      loss = train_step_nnx(flow, optimizer, rngs, X)
      losses.append(loss)

    total_time = time() - t0
    print('### NNX ###')
    print(f'final loss: {losses[-1]}')
    print('total time:', total_time)
    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')

  if mode == 'jax' or mode == 'all':
    rngs = nnx.Rngs(0)
    flow = MlpMixer(
      din=1,
      kernel_size=(2, 2),
      strides=(2, 2),
      num_blocks=depth,
      hidden_dim=width,
      tokens_mlp_dim=196,
      channels_mlp_dim=width,
      rngs=rngs,
    )
    optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4))
    # split the NNX objects into a static graphdef and a pytree of state so the
    # training step can be a pure function jitted with plain jax.jit
    graphdef, state = nnx.split((flow, optimizer, rngs))

    mse = lambda a, b: jnp.mean((a - b) ** 2)

    @partial(jax.jit, donate_argnums=0)
    def train_step_jax(state, x_1):
      print('JITTING JAX')
      # rebuild live objects from the state, run the same flow-matching step as
      # the NNX mode, then extract the updated state to return
      flow, optimizer, rngs = nnx.merge(graphdef, state)
      x_0 = jax.random.normal(rngs(), x_1.shape)
      t = jax.random.uniform(rngs(), (len(x_1),))

      x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t)
      dx_t = x_1 - x_0

      loss, grads = nnx.value_and_grad(
        lambda flow: mse(flow(x=x_t, t=t), dx_t)
      )(flow)
      optimizer.update(grads)
      state = nnx.state((flow, optimizer, rngs))
      return loss, state

    losses = []
    t0 = time()
    for step in tqdm(range(total_steps), desc='JAX'):
      loss, state = train_step_jax(state, X)
      losses.append(loss)

    nnx.update((flow, optimizer, rngs), state)
    total_time = time() - t0
    print('### JAX ###')
    print(f'final loss: {losses[-1]}')
    print('total time:', total_time)
    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')


if __name__ == '__main__':
  app.run(main)