
Hey! I'm guessing you want to replicate the weights but have different RNGs on each device. To do this, you can use the nnx.split_rngs decorator to split the RNGs before entering pmap, and use StateAxes to specify the parallelization axes for the substates of your Module. In this case, map RngState to 0 and the rest (...) to None:

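from flax import nnx
import jax.numpy as jnp

# RNG state gets a device axis (0); everything else (...) is replicated (None)
state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})

# splits must equal the size of the pmapped axis (here 1, the leading axis of x)
@nnx.split_rngs(splits=1)
@nnx.pmap(in_axes=(state_axes, 0))
def forward(model, x):
  return model(x)

# model is the nnx.Module from your question; x's leading axis is the device axis
out = forward(model, jnp.ones((1, 16, 2)))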
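If it helps to see what's going on under the hood, here's a minimal sketch of the same idea in plain JAX, without NNX. Everything in it is hypothetical and only for illustration: the `state` dict with `"params"` and `"rng"` entries, the tiny linear layer with inverted dropout, and the weight shapes. The `{"params": None, "rng": 0}` pytree passed to `in_axes` plays roughly the role StateAxes plays above, and `jax.random.split` produces one key per device, roughly what nnx.split_rngs does for you.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

def forward(state, x):
    # Per-device view: state["rng"] is one key, state["params"] is the full replicated weights.
    params, key = state["params"], state["rng"]
    keep = jax.random.bernoulli(key, 0.5, x.shape)  # dropout mask differs per device
    x = jnp.where(keep, x / 0.5, 0.0)               # inverted dropout with rate 0.5
    return x @ params["w"] + params["b"]

state = {
    # Hypothetical weights, stored once (not stacked per device).
    "params": {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))},
    # One key per device, stacked along axis 0.
    "rng": jax.random.split(jax.random.PRNGKey(0), n_devices),
}

# Per-substate axes: map the keys over axis 0, broadcast the params (None).
state_axes = {"params": None, "rng": 0}

out = jax.pmap(forward, in_axes=(state_axes, 0))(state, jnp.ones((n_devices, 16, 2)))
print(out.shape)  # (n_devices, 16, 3)
```

Using None for the params means pmap hands the same weights to every device instead of expecting a stacked copy per device, while the keys are sliced along axis 0 so each device draws a different dropout mask.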

For more info on how filters like RngState and ... select substates, check out the Flax NNX Filters guide.

Answer selected by maxxxzdn
Category
Q&A
2 participants