Are there any gotchas or Flax NNX mechanics to be aware of in the context of the following two training loop styles?

**Internal**

```python
import jax, jax.numpy as jnp, flax.nnx as nnx, optax, time
from typing import Sequence
from functools import partial
from tqdm import tqdm

jax.devices()

class Module(nnx.Module):
    def __init__(self, x_dim: int, hidden: Sequence[int], rngs: nnx.Rngs):
        layers = [nnx.Linear(x_dim, hidden[0], rngs=rngs), nnx.swish]
        for i in range(len(hidden) - 1):
            layers += [nnx.Linear(hidden[i], hidden[i+1], rngs=rngs)]
            layers += [nnx.swish]
        layers += [nnx.Linear(hidden[-1], 1, rngs=rngs)]
        self.module = nnx.Sequential(*layers)
        self.x_dim = x_dim

    @partial(nnx.jit, static_argnames='n')
    def step(self, key: jax.Array, n: int, optimizer: nnx.Optimizer) -> jax.Array:
        def loss_fn(self):
            x = jax.random.uniform(key, (n, self.x_dim), minval=-1.0, maxval=1.0)  # (n, x_dim)
            y = jnp.exp(-jnp.sum(x**2, axis=1, keepdims=True))                     # (n, 1)
            yhat = self.module(x)
            loss = jnp.mean((y - yhat)**2)
            return loss
        loss, grads = nnx.value_and_grad(loss_fn)(self)
        optimizer.update(grads)
        return loss

    def fit(self, tx: optax.GradientTransformation, n_epochs: int, key: jax.Array, n: int):
        # Usually not a jitable method, e.g. w/ passing in dataloader
        t, optimizer = [], nnx.Optimizer(self, tx)
        for _ in (pb := tqdm(range(n_epochs), desc="Training")):
            t0 = time.time()
            key, trainkey = jax.random.split(key)
            loss = self.step(trainkey, n, optimizer)
            pb.set_postfix({'Loss': loss.item()})
            t1 = time.time()
            t.append(t1 - t0)
        t = sum(t) / n_epochs
        return t

module = Module(1000, [500, 500, 500], nnx.Rngs(0))
module.fit(optax.adam(1e-4), 1000, jax.random.PRNGKey(0), 1000)
```

**External**

```python
@partial(nnx.jit, static_argnames='n')
def step(module: Module, key: jax.Array, n: int, optimizer: nnx.Optimizer) -> jax.Array:
    def loss_fn(module):
        x = jax.random.uniform(key, (n, module.x_dim), minval=-1.0, maxval=1.0)  # (n, x_dim)
        y = jnp.exp(-jnp.sum(x**2, axis=1, keepdims=True))                       # (n, 1)
        yhat = module.module(x)
        loss = jnp.mean((y - yhat)**2)
        return loss
    loss, grads = nnx.value_and_grad(loss_fn)(module)
    optimizer.update(grads)
    return loss

def fit(module, tx: optax.GradientTransformation, n_epochs: int, key: jax.Array, n: int):
    # Usually not a jitable method, e.g. w/ passing in dataloader
    t, optimizer = [], nnx.Optimizer(module, tx)
    for _ in (pb := tqdm(range(n_epochs), desc="Training")):
        t0 = time.time()
        key, trainkey = jax.random.split(key)
        loss = step(module, trainkey, n, optimizer)
        pb.set_postfix({'Loss': loss.item()})
        t1 = time.time()
        t.append(t1 - t0)
    t = sum(t) / n_epochs
    return t

module = Module(1000, [500, 500, 500], nnx.Rngs(0))
fit(module, optax.adam(1e-4), 1000, jax.random.PRNGKey(0), 1000)
```
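
A quick way to convince yourself that the two styles are mechanically equivalent is to seed both identically and compare the first training step. This is a minimal sketch, assuming the class and the external `step` definition above are already in scope:

```python
# Minimal equivalence check (assumes the two snippets above have been run in
# the same session). With identical seeds, both styles should trace the same
# computation under nnx.jit and produce the same first-step loss.
m_a = Module(1000, [500, 500, 500], nnx.Rngs(0))
m_b = Module(1000, [500, 500, 500], nnx.Rngs(0))
opt_a = nnx.Optimizer(m_a, optax.adam(1e-4))
opt_b = nnx.Optimizer(m_b, optax.adam(1e-4))
key = jax.random.PRNGKey(0)
loss_a = m_a.step(key, 1000, opt_a)    # internal (method) style
loss_b = step(m_b, key, 1000, opt_b)   # external (function) style
assert jnp.allclose(loss_a, loss_b)
```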
Answered by cgarciae, Jul 24, 2025
@cisprague there is no preference here.
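
A rough way to sanity-check this, reusing the `fit` helpers defined in the question (illustrative only; exact numbers depend on hardware and are not from the original thread):

```python
# Rough timing comparison using the helpers from the question. Both styles
# trace the same computation under nnx.jit, so mean per-epoch times should
# be comparable.
t_internal = Module(1000, [500, 500, 500], nnx.Rngs(0)).fit(
    optax.adam(1e-4), 100, jax.random.PRNGKey(0), 1000)
t_external = fit(Module(1000, [500, 500, 500], nnx.Rngs(0)),
                 optax.adam(1e-4), 100, jax.random.PRNGKey(0), 1000)
print(f"mean epoch time: internal={t_internal:.4f}s, external={t_external:.4f}s")
```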
Answer selected by cisprague