
Commit 78be9e9

[nnx] add flaxlib
1 parent: cf6db71

22 files changed: +1340, -463 lines

benchmarks/nnx_graph_overhead.py
Lines changed: 3 additions & 1 deletion

@@ -97,9 +97,11 @@ def main(argv):
   def step_nnx(model: MLP, optimizer: nnx.Optimizer):
     pass
 
+  cached_step_nnx = nnx.cache_args(step_nnx, model, optimizer)
+
   t0 = time()
   for _ in range(total_steps):
-    step_nnx(model, optimizer)
+    cached_step_nnx()
 
   total_time = time() - t0
   time_per_step = total_time / total_steps
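
The hunk above switches the benchmark loop to nnx.cache_args, a helper this commit also re-exports from flax.nnx (see flax/nnx/__init__.py below). The exact semantics are not shown in this diff; the sketch below only mirrors the calling pattern the benchmarks use, with an illustrative one-layer model: bind the graph objects once, then call the cached step with the remaining array arguments.

# Sketch of the nnx.cache_args calling pattern used by the updated benchmarks.
# Assumption: cache_args(fn, *graph_objects) returns a callable that takes only
# the remaining (non-cached) arguments, as the hunks in this commit imply.
import jax.numpy as jnp
import optax
from flax import nnx

model = nnx.Linear(4, 4, rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, tx=optax.sgd(1e-2))

@nnx.jit
def step(model, optimizer, x):
  loss, grads = nnx.value_and_grad(lambda m: jnp.mean((m(x) - x) ** 2))(model)
  optimizer.update(grads)
  return loss

cached_step = nnx.cache_args(step, model, optimizer)  # bind graph objects once
loss = cached_step(jnp.ones((8, 4)))                  # pass only array args per call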
New file
Lines changed: 235 additions & 0 deletions

# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# %%
from functools import partial
import jax
import jax.numpy as jnp
from flax import nnx
import optax
import numpy as np
from einop import einop
from time import time
from tqdm import tqdm

from flax import nnx

from absl import flags
from absl import app

FLAGS = flags.FLAGS
flags.DEFINE_enum(
  'mode', 'all', ['all', 'nnx', 'jax'], 'Mode to run the script in'
)
flags.DEFINE_integer('total_steps', 10_000, 'Total number of training steps')
flags.DEFINE_integer('batch_size', 32, 'Batch size')
flags.DEFINE_integer('width', 32, 'Hidden layer size')
flags.DEFINE_integer('depth', 4, 'Depth of the model')


class MlpBlock(nnx.Module):
  def __init__(self, din: int, mlp_dim: int, rngs: nnx.Rngs):
    self.din, self.mlp_dim = din, mlp_dim
    self.linear_in = nnx.Linear(din, mlp_dim, rngs=rngs)
    self.linear_out = nnx.Linear(mlp_dim, din, rngs=rngs)

  def __call__(self, x):
    return self.linear_out(nnx.gelu(self.linear_in(x)))


class MixerBlock(nnx.Module):
  def __init__(
    self,
    tokens_mlp_dim: int,
    channels_mlp_dim: int,
    hidden_dim: int,
    rngs: nnx.Rngs,
  ):
    self.tokens_mlp_dim = tokens_mlp_dim
    self.channels_mlp_dim = channels_mlp_dim
    self.hidden_dim = hidden_dim
    self.token_mixing = MlpBlock(tokens_mlp_dim, hidden_dim, rngs=rngs)
    self.channel_mixing = MlpBlock(channels_mlp_dim, hidden_dim, rngs=rngs)
    self.ln1 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)
    self.ln2 = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)

  def __call__(self, x):
    y = self.ln1(x)
    y = y.swapaxes(1, 2)
    y = self.token_mixing(y)
    y = y.swapaxes(1, 2)
    x = x + y
    y = self.ln2(x)
    return x + self.channel_mixing(y)


class MlpMixer(nnx.Module):
  def __init__(
    self,
    din: int,
    kernel_size: tuple[int, int],
    strides: tuple[int, int],
    num_blocks: int,
    hidden_dim: int,
    tokens_mlp_dim: int,
    channels_mlp_dim: int,
    rngs: nnx.Rngs,
  ):
    self.din = din
    self.kernel_size = kernel_size
    self.num_blocks = num_blocks
    self.hidden_dim = hidden_dim
    self.tokens_mlp_dim = tokens_mlp_dim
    self.channels_mlp_dim = channels_mlp_dim
    self.stem = nnx.Conv(
      din + 1,
      channels_mlp_dim,
      kernel_size=kernel_size,
      strides=strides,
      rngs=rngs,
    )
    self.blocks = [
      MixerBlock(tokens_mlp_dim, channels_mlp_dim, hidden_dim, rngs=rngs)
      for _ in range(num_blocks)
    ]
    self.pre_head_layer_norm = nnx.LayerNorm(channels_mlp_dim, rngs=rngs)
    self.conv_t = nnx.ConvTranspose(
      channels_mlp_dim, din, kernel_size=kernel_size, strides=strides, rngs=rngs
    )

  def __call__(self, *, x, t):
    # add time feature to input
    t = einop(t, 'n -> n h w c', h=x.shape[1], w=x.shape[2], c=1)
    x = jnp.concatenate([x, t], axis=-1)
    # create patches
    x = self.stem(x)
    h, w = x.shape[1], x.shape[2]
    x = einop(x, 'n h w c -> n (h w) c')
    # apply blocks
    for block in self.blocks:
      x = block(x)
    x = self.pre_head_layer_norm(x)
    # recreate image
    x = einop(x, 'n (h w) c -> n h w c', h=h, w=w)
    x = self.conv_t(x)
    return x


def main(argv):
  print(argv)
  mode: str = FLAGS.mode
  total_steps: int = FLAGS.total_steps
  batch_size: int = FLAGS.batch_size
  width: int = FLAGS.width
  depth: int = FLAGS.depth

  print(f'{mode=}, {total_steps=}, {batch_size=}, {width=}')

  X = np.random.uniform(size=(batch_size, 28, 28, 1))

  if mode == 'nnx' or mode == 'all':
    rngs = nnx.Rngs(0)
    flow = MlpMixer(
      din=1,
      kernel_size=(2, 2),
      strides=(2, 2),
      num_blocks=4,
      hidden_dim=512,
      tokens_mlp_dim=196,
      channels_mlp_dim=512,
      rngs=rngs,
    )
    optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4))
    t0 = time()

    mse = lambda a, b: jnp.mean((a - b) ** 2)

    @nnx.jit(donate_argnums=(0, 1, 2))
    def train_step_nnx(flow, optimizer, rngs, x_1):
      print('JITTING NNX')
      x_0 = jax.random.normal(rngs(), x_1.shape)
      t = jax.random.uniform(rngs(), (len(x_1),))

      x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t)
      dx_t = x_1 - x_0

      loss, grads = nnx.value_and_grad(
        lambda flow: mse(flow(x=x_t, t=t), dx_t)
      )(flow)
      optimizer.update(grads)
      return loss

    losses = []
    t0 = time()
    for step in tqdm(range(total_steps), desc='NNX'):
      loss = train_step_nnx(flow, optimizer, rngs, X)
      losses.append(loss)

    total_time = time() - t0
    print('### NNX ###')
    print(f'final loss: {losses[-1]}')
    print('total time:', total_time)
    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')

  if mode == 'jax' or mode == 'all':
    rngs = nnx.Rngs(0)
    flow = MlpMixer(
      din=1,
      kernel_size=(2, 2),
      strides=(2, 2),
      num_blocks=depth,
      hidden_dim=width,
      tokens_mlp_dim=196,
      channels_mlp_dim=width,
      rngs=rngs,
    )
    optimizer = nnx.Optimizer(flow, tx=optax.adamw(1e-4))
    graphdef, state = nnx.split((flow, optimizer, rngs))
    t0 = time()

    mse = lambda a, b: jnp.mean((a - b) ** 2)

    @partial(nnx.jit, donate_argnums=0)
    def train_step_jax(state, x_1):
      print('JITTING JAX')
      flow, optimizer, rngs = nnx.merge(graphdef, state)
      x_0 = jax.random.normal(rngs(), x_1.shape)
      t = jax.random.uniform(rngs(), (len(x_1),))

      x_t = jax.vmap(lambda x_0, x_1, t: (1 - t) * x_0 + t * x_1)(x_0, x_1, t)
      dx_t = x_1 - x_0

      loss, grads = nnx.value_and_grad(
        lambda flow: mse(flow(x=x_t, t=t), dx_t)
      )(flow)
      optimizer.update(grads)
      state = nnx.state((flow, optimizer, rngs))
      return loss, state

    losses = []
    t0 = time()
    for step in tqdm(range(total_steps), desc='JAX'):
      loss, state = train_step_jax(state, X)
      losses.append(loss)

    nnx.update((flow, optimizer, rngs), state)
    total_time = time() - t0
    print('### JAX ###')
    print(f'final loss: {losses[-1]}')
    print('total time:', total_time)
    print(f'time per step: {total_time / total_steps * 1e6:.2f} µs')


if __name__ == '__main__':
  app.run(main)
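
The script above times the same flow-matching update written two ways: the 'nnx' branch passes live NNX objects straight into nnx.jit with donation, while the 'jax' branch splits them into a static graphdef plus a state pytree and threads the state through the jitted step. A condensed sketch of that split/merge round trip, with a small illustrative model in place of the MlpMixer:

# Condensed sketch of the split/merge pattern timed in the 'jax' branch above.
# The Linear model is illustrative; the round trip is what matters.
import jax
import jax.numpy as jnp
import optax
from flax import nnx

rngs = nnx.Rngs(0)
model = nnx.Linear(4, 4, rngs=rngs)
optimizer = nnx.Optimizer(model, tx=optax.adamw(1e-4))

# 1) Split the NNX objects once into a static graphdef and a state pytree.
graphdef, state = nnx.split((model, optimizer, rngs))

@nnx.jit
def train_step(state, x):
  # 2) Rebuild live objects inside the jitted function from the closed-over graphdef.
  model, optimizer, rngs = nnx.merge(graphdef, state)
  target = jax.random.normal(rngs(), x.shape)
  loss, grads = nnx.value_and_grad(lambda m: jnp.mean((m(x) - target) ** 2))(model)
  optimizer.update(grads)
  # 3) Return the updated state pytree so the caller can thread it into the next step.
  return loss, nnx.state((model, optimizer, rngs))

loss, state = train_step(state, jnp.ones((8, 4)))
nnx.update((model, optimizer, rngs), state)  # sync the outer objects when done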

benchmarks/nnx_simple_training.py
Lines changed: 14 additions & 10 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 # %%
+from functools import partial
 import jax
 import jax.numpy as jnp
 import numpy as np
@@ -97,7 +98,7 @@ def main(argv):
   optimizer = nnx.Optimizer(model, tx)
   t0 = time()
 
-  @nnx.jit
+  @nnx.jit(donate_argnums=(0, 1))
   def train_step_nnx(model: MLP, optimizer: nnx.Optimizer, batch):
     x, y = batch
 
@@ -108,18 +109,21 @@ def loss_fn(model: MLP):
     grads: nnx.State = nnx.grad(loss_fn)(model)
     optimizer.update(grads)
 
-  @nnx.jit
+  @nnx.jit(donate_argnums=0)
   def test_step_nnx(model: MLP, batch):
     x, y = batch
     y_pred = model(x)
     loss = jnp.mean((y - y_pred) ** 2)
     return {'loss': loss}
 
+  cached_train_step_nnx = nnx.cache_args(train_step_nnx, model, optimizer)
+  cached_test_step_nnx = nnx.cache_args(test_step_nnx, model)
+
   for step, batch in enumerate(dataset(X, Y, batch_size)):
-    train_step_nnx(model, optimizer, batch)
+    cached_train_step_nnx(batch)
 
     if step % 1000 == 0:
-      logs = test_step_nnx(model, (X, Y))
+      logs = cached_test_step_nnx((X, Y))
 
     if step >= total_steps - 1:
       break
@@ -137,8 +141,8 @@ def test_step_nnx(model: MLP, batch):
   optimizer = nnx.Optimizer(model, tx)
   t0 = time()
 
-  @jax.jit
-  def train_step_jax(graphdef, state, batch):
+  @partial(jax.jit, donate_argnums=0)
+  def train_step_jax(state, batch):
     model, optimizer = nnx.merge(graphdef, state)
     x, y = batch
 
@@ -151,8 +155,8 @@ def loss_fn(model: MLP):
 
     return nnx.state((model, optimizer))
 
-  @jax.jit
-  def test_step_jax(graphdef, state, batch):
+  @partial(jax.jit, donate_argnums=0)
+  def test_step_jax(state, batch):
     model, optimizer = nnx.merge(graphdef, state)
     x, y = batch
     y_pred = model(x)
@@ -163,10 +167,10 @@ def test_step_jax(graphdef, state, batch):
   graphdef, state = nnx.split((model, optimizer))
 
   for step, batch in enumerate(dataset(X, Y, batch_size)):
-    state = train_step_jax(graphdef, state, batch)
+    state = train_step_jax(state, batch)
 
     if step % 1000 == 0:
-      state, logs = test_step_jax(graphdef, state, (X, Y))
+      state, logs = test_step_jax(state, (X, Y))
 
     if step >= total_steps - 1:
       break
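
Both NNX steps above now donate their inputs, and the JAX-style steps stop re-passing graphdef (it is captured by closure instead). As a refresher on what donate_argnums buys (this is plain JAX behavior, not something introduced by this commit): the donated argument's device buffers may be reused for the outputs, so the old value must not be read again and the variable should be rebound every iteration, exactly as the loops above do with state.

# Minimal sketch of buffer donation with jax.jit (plain JAX, not repo-specific).
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, donate_argnums=0)
def increment(state, delta):
  # state's device buffer may be reused for the output, avoiding a copy.
  return state + delta

state = jnp.zeros((1024,))
for _ in range(3):
  state = increment(state, 1.0)  # always rebind the donated argument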

examples/nnx_toy_examples/02_lifted_transforms.py
Lines changed: 4 additions & 2 deletions

@@ -82,13 +82,15 @@ def test_step(model: MLP, batch):
   loss = jnp.mean((y - y_pred) ** 2)
   return {'loss': loss}
 
+cached_train_step = nnx.cache_args(train_step, model, optimizer)
+cached_test_step = nnx.cache_args(test_step, model)
 
 total_steps = 10_000
 for step, batch in enumerate(dataset(32)):
-  train_step(model, optimizer, batch)
+  cached_train_step(batch)
 
   if step % 1000 == 0:
-    logs = test_step(model, (X, Y))
+    logs = cached_test_step((X, Y))
     print(f"step: {step}, loss: {logs['loss']}")
 
   if step >= total_steps - 1:

flax/configurations.py
Lines changed: 11 additions & 0 deletions

@@ -22,6 +22,7 @@
 
 
 class Config:
+  flax_use_flaxlib: bool
   # See https://google.github.io/pytype/faq.html.
   _HAS_DYNAMIC_ATTRIBUTES = True
 
@@ -62,6 +63,10 @@ def update(self, name_or_holder, value, /):
       raise LookupError(f'Unrecognized config option: {name}')
     self._values[name] = value
 
+  def __repr__(self):
+    values_repr = ', '.join(f'\n  {k}={v!r}' for k, v in self._values.items())
+    return f'Config({values_repr}\n)'
+
 
 config = Config()
 
@@ -201,3 +206,9 @@ def temp_flip_flag(var_name: str, var_value: bool):
     ' PRNG keys.'
   ),
 )
+
+flax_use_flaxlib = bool_flag(
+  name='flax_use_flaxlib',
+  default=False,
+  help='Whether to use flaxlib for C++ acceleration.',
+)
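
The new option goes through the same bool_flag machinery as the existing flags, so it should be readable and flippable like any other config value. A hedged sketch (attribute access assumes the dynamic-attribute behavior that the new flax_use_flaxlib: bool annotation on Config points at, and update is assumed to accept the flag name as a string, as the error message above suggests):

# Hedged sketch of toggling the new flag through the flax.configurations API
# shown above; flax.config is the module-level Config() instance.
import flax

print(flax.config.flax_use_flaxlib)          # False by default
flax.config.update('flax_use_flaxlib', True)
print(flax.config)                           # the new __repr__ lists every option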

flax/nnx/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -56,6 +56,7 @@
 from .graph import MergeContext as MergeContext
 from .graph import merge_context as merge_context
 from .graph import variables as variables
+from .graph import cache_args as cache_args
 from .nn import initializers as initializers
 from .nn.activations import celu as celu
 from .nn.activations import elu as elu
