Commit 48c59d3

[nnx] add cache_args

1 parent 429033b · commit 48c59d3


42 files changed: +2754 −1066 lines

benchmarks/nnx_graph_overhead.py

Lines changed: 46 additions & 17 deletions
@@ -24,31 +24,52 @@
 from absl import app
 
 FLAGS = flags.FLAGS
-flags.DEFINE_enum('mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in')
+flags.DEFINE_enum(
+  'mode', 'nnx', ['all', 'nnx', 'jax'], 'Mode to run the script in'
+)
 flags.DEFINE_integer('total_steps', 100, 'Total number of training steps')
 flags.DEFINE_integer('width', 32, 'Hidden layer size')
 flags.DEFINE_integer('depth', 5, 'Depth of the model')
 
 
-
 class Linear(nnx.Module):
   def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
-    self.list = [
-      nnx.Param(jax.random.uniform(rngs.params(), (din, dout))),
-      nnx.Param(jnp.zeros((dout,))),
-    ]
-    self.dict = {
-      'w': nnx.Param(jax.random.uniform(rngs.params(), (din, dout))),
-      'b': nnx.Param(jnp.zeros((dout,))),
-    }
+    self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))
+    self.b = nnx.Param(jnp.zeros((dout,)))
+
+  def __call__(self, x):
+    return x @ self.w + self.b
+
+
+class Block(nnx.Module):
+  def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
+    self.linear = Linear(din, dout, rngs=rngs)
+    self.bn = nnx.BatchNorm(dout, rngs=rngs)
+
+  def __call__(self, x):
+    return nnx.relu(self.bn(self.linear(x)))
+
 
+class Count(nnx.Variable):
+  pass
 
 
 class MLP(nnx.Module):
-  def __init__(self, depth, *, rngs: nnx.Rngs):
+  def __init__(self, din, dhidden, dout, depth, *, rngs: nnx.Rngs):
+    self.count = Count(jnp.array(0))
+    self.linear_in = Block(din, dhidden, rngs=rngs)
     self.intermediates = [
-      Linear(10, 10, rngs=rngs) for _ in range(depth)
+      Block(dhidden, dhidden, rngs=rngs) for _ in range(depth - 2)
     ]
+    self.linear_out = Block(dhidden, dout, rngs=rngs)
+
+  def __call__(self, x):
+    self.count.value += 1
+    x = nnx.relu(self.linear_in(x))
+    for layer in self.intermediates:
+      x = nnx.relu(layer(x))
+    x = self.linear_out(x)
+    return x
 
 
 def main(argv):
@@ -63,21 +84,24 @@ def main(argv):
   X = np.linspace(0, 1, 100)[:, None]
   Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape)
 
-  model = MLP(depth=depth, rngs=nnx.Rngs(0))
-  tx = optax.sgd(1e-3)
-  optimizer = nnx.Optimizer(model, tx)
-
   #------------------------------------------------------------
   # NNX
   #------------------------------------------------------------
   if mode in ['all', 'nnx']:
+    model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
+    tx = optax.sgd(1e-3)
+    optimizer = nnx.Optimizer(model, tx)
+    t0 = time()
+
     @nnx.jit
     def step_nnx(model: MLP, optimizer: nnx.Optimizer):
       pass
 
+    cached_step_nnx = nnx.cached_partial(step_nnx, model, optimizer)
+
     t0 = time()
     for _ in range(total_steps):
-      step_nnx(model, optimizer)
+      cached_step_nnx()
 
     total_time = time() - t0
     time_per_step = total_time / total_steps
@@ -93,6 +117,11 @@ def step_nnx(model: MLP, optimizer: nnx.Optimizer):
   #------------------------------------------------------------
 
   if mode in ['all', 'jax']:
+    model = MLP(din=1, dhidden=width, dout=1, depth=depth, rngs=nnx.Rngs(0))
+    tx = optax.sgd(1e-3)
+    optimizer = nnx.Optimizer(model, tx)
+    t0 = time()
+
     @jax.jit
     def step_jax(graphdef, state):
       return graphdef, state
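
For context, a minimal standalone sketch of the calling pattern the updated benchmark now times. It uses only calls that appear in this commit's diffs (nnx.jit, nnx.Optimizer, nnx.cached_partial) plus nnx.Linear as a stand-in model; the comments describe the intent suggested by the benchmark, not a documented guarantee.

import optax
from flax import nnx

# A small model/optimizer pair, standing in for the MLP the benchmark builds.
model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.sgd(1e-3))

@nnx.jit
def step(model, optimizer):
  pass  # empty body, as in the benchmark: only per-call overhead is measured

# Uncached path: each call hands the graph objects to nnx.jit again.
for _ in range(100):
  step(model, optimizer)

# Cached path: nnx.cached_partial binds model and optimizer once up front,
# mirroring `cached_step_nnx = nnx.cached_partial(step_nnx, model, optimizer)`
# in the diff above, so the timed loop calls a zero-argument function.
cached_step = nnx.cached_partial(step, model, optimizer)
for _ in range(100):
  cached_step()

In the benchmark itself, the only change inside the timed loop is the swap from step_nnx(model, optimizer) to cached_step_nnx().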
New file — Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# %%
from functools import partial
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import numpy as np
from einop import einop
from time import time
from tqdm import tqdm

from flax import nnx

from absl import flags
from absl import app

FLAGS = flags.FLAGS
flags.DEFINE_enum(
  'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
)
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
flags.DEFINE_integer('depth', 4, 'Depth of the model')


class MlpBlock(nnx.Module):
  def __init__(self, din: int, mlp_dim: int, rngs: nnx.Rngs):
    self.din, self.mlp_dim = din, mlp_dim
    self.linear_in = nnx.Linear(din, mlp_dim, rngs=rngs)
    self.linear_out = nnx.Linear(mlp_dim, din, rngs=rngs)

  def __call__(self, x):
    return self.linear_out(nnx.gelu(self.linear_in(x)))


class MixerBlock(nnx.Module):
  def __init__(
    self,
    tokens_mlp_dim: int,
    channels_mlp_dim: int,
    hidden_dim: int,
    rngs: nnx.Rngs,
  ):
    self.tokens_mlp_dim = tokens_mlp_dim
    self.channels_mlp_dim = channels_mlp_dim
    self.hidden_dim = hidden_dim
    self.token_mixing = MlpBlock(tokens_mlp_dim, hidden_dim, rngs=rngs)
    self.channel_mixing = MlpBlock(channels_mlp_dim, hidden_dim, rngs=rngs)
    self.ln1 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)
    self.ln2 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)

  def __call__(self, x):
    y = self.ln1(x)
    y = y.swapaxes(1, 2)
    y = self.token_mixing(y)
    y = y.swapaxes(1, 2)
    x = x + y
    y = self.ln2(x)
    return x + self.channel_mixing(y)


class MlpMixer(nnx.Module):
  def __init__(
    self,
    din: int,
    kernel_size: tuple[int, int],
    strides: tuple[int, int],
    num_blocks: int,
    hidden_dim: int,
    tokens_mlp_dim: int,
    channels_mlp_dim: int,
    rngs: nnx.Rngs,
  ):
    self.din = din
    self.kernel_size = kernel_size
    self.num_blocks = num_blocks
    self.hidden_dim = hidden_dim
    self.tokens_mlp_dim = tokens_mlp_dim
    self.channels_mlp_dim = channels_mlp_dim
    self.stem = nnx.Conv(
      din + 1,
      channels_mlp_dim,
      kernel_size=kernel_size,
      strides=strides,
      rngs=rngs,
    )
    self.blocks = [
      MixerBlock(tokens_mlp_dim, channels_mlp_dim, hidden_dim, rngs=rngs)
      for _ in range(num_blocks)
    ]
    self.pre_head_layer_norm = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)
    self.conv_t = nnx.ConvTranspose(
      channels_mlp_dim, din, kernel_size=kernel_size, strides=strides, rngs=rngs
    )

  def __call__(self, *, x, t):
    # add time feature to input
    t = einop(t, 'n -> n h w c', h=x.shape[1], w=x.shape[2], c=1)
    x = jnp.concatenate([x, t], axis=-1)
    # create patches
    x = self.stem(x)
    h, w = x.shape[1], x.shape[2]
    x = einop(x, 'n h w c -> n (h w) c')
    # apply blocks
    for block in self.blocks:
      x = block(x)
    x = self.pre_head_layer_norm(x)
    # recreate image
    x = einop(x, 'n (h w) c -> n h w c', h=h, w=w)
    x = self.conv_t(x)
    return x


def main(argv):
  print(argv)
  mode: str = FLAGS.mode
  total_steps: int = FLAGS.total_steps
  batch_size: int = FLAGS.batch_size
  width: int = FLAGS.width
  depth: int = FLAGS.depth

  print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}')

  X = np.random.uniform(size=(batch_size, 28, 28, 1))

  if mode == 'nnx' or mode == 'all':
    rngs = nnx.Rngs(0)
    flow = MlpMixer(
      din=1,
      kernel_size=(2, 2),
      strides=(2, 2),
      num_blocks=4,
      hidden_dim=512,
      tokens_mlp_dim=196,
      channels_mlp_dim=512,
      rngs=rngs,
    )
    optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4))
    t0 = time()

    mse = lambda a, b: jnp.mean((a - b) ** 2)

    @nnx.jit(donate_argnums=(0, 1, 2))
    def train_step_nnx(flow, optimizer, rngs, x_1):
      print('JITTING NNX')
      x_0 = jax.random.normal(rngs(), x_1.shape)
      t = jax.random.uniform(rngs(), (len(x_1),))

      x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t)
      dx_t = x_1 - x_0

      loss, grads = nnx.value_and_grad(
        lambda flow: mse(flow(x=x_t, t=t), dx_t)
      )(flow)
      optimizer.update(grads)
      return loss

    losses = []
    t0 = time()
    for step in tqdm(range(total_steps), desc='NNX'):
      loss = train_step_nnx(flow, optimizer, rngs, X)
      losses.append(loss)

    total_time = time() - t0
    print('### NNX ###')
    print(f'final loss: {losses[-1]}')
    print('total time:', total_time)
    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')

  if mode == 'jax' or mode == 'all':
    rngs = nnx.Rngs(0)
    flow = MlpMixer(
      din=1,
      kernel_size=(2, 2),
      strides=(2, 2),
      num_blocks=depth,
      hidden_dim=width,
      tokens_mlp_dim=196,
      channels_mlp_dim=width,
      rngs=rngs,
    )
    optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4))
    graphdef, state = nnx.split((flow, optimizer, rngs))
    t0 = time()

    mse = lambda a, b: jnp.mean((a - b) ** 2)

    @partial(nnx.jit, donate_argnums=0)
    def train_step_jax(state, x_1):
      print('JITTING JAX')
      flow, optimizer, rngs = nnx.merge(graphdef, state)
      x_0 = jax.random.normal(rngs(), x_1.shape)
      t = jax.random.uniform(rngs(), (len(x_1),))

      x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t)
      dx_t = x_1 - x_0

      loss, grads = nnx.value_and_grad(
        lambda flow: mse(flow(x=x_t, t=t), dx_t)
      )(flow)
      optimizer.update(grads)
      state = nnx.state((flow, optimizer, rngs))
      return loss, state

    losses = []
    t0 = time()
    for step in tqdm(range(total_steps), desc='JAX'):
      loss, state = train_step_jax(state, X)
      losses.append(loss)

    nnx.update((flow, optimizer, rngs), state)
    total_time = time() - t0
    print('### JAX ###')
    print(f'final loss: {losses[-1]}')
    print('total time:', total_time)
    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')


if __name__ == '__main__':
  app.run(main)
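
The NNX branch of this new benchmark calls train_step_nnx(flow, optimizer, rngs, X) directly on every iteration. A hypothetical variant, not part of this commit, could reuse the nnx.cached_partial pattern from benchmarks/nnx_graph_overhead.py to bind the graph arguments once; whether that binding composes with the donate_argnums setting used above is an assumption here, not something the diff establishes.

# Hypothetical variant (not in the committed file): bind flow, optimizer and
# rngs once, then pass only the batch inside the timed loop. Assumes
# nnx.cached_partial allows leaving trailing array arguments unbound, as in the
# other benchmark, and that it composes with the donate_argnums used above.
cached_train_step_nnx = nnx.cached_partial(train_step_nnx, flow, optimizer, rngs)

losses = []
t0 = time()
for step in tqdm(range(total_steps), desc='NNX (cached)'):
  loss = cached_train_step_nnx(X)
  losses.append(loss)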
