-
Wait...
-
I am really confused too. It seems now that
-
In addition, I have another question. The release notes say that, for some simple functions, we no longer need to call merge and split in the code. How efficient is this? Is there still the relatively large overhead mentioned in the earlier performance considerations? (I noticed that when I decorate a function with jax.jit, it does seem noticeably faster than with nnx.jit.)
-
The split and merge paradigm is still there and won't be deprecated. Here is an example of its usage: flax/examples/nnx_toy_examples/mutable_array_basic.py, lines 62 to 71 at commit 2bf5748. In the special case where a model contains only parameters, one can simplify the code and remove the split/merge calls.
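To make the reply above concrete without relying on any particular Flax API, here is a minimal plain-JAX sketch. TinyLinear is a hypothetical params-only layer registered as a pytree; it is not an NNX class. It can be passed to jax.jit directly, or flattened into array leaves plus a tree definition and rebuilt inside the jitted function, which mirrors the split/merge pattern. Both paths trace the same computation over the same arrays. Whether NNX's own transforms add overhead on top of this is not something the sketch measures.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class TinyLinear:
    """Hypothetical params-only layer registered as a pytree (not a Flax class)."""

    def __init__(self, w, b):
        self.w = w  # weight matrix, shape (in, out)
        self.b = b  # bias vector, shape (out,)

    def __call__(self, x):
        return x @ self.w + self.b

    def tree_flatten(self):
        # Array leaves first, then static auxiliary data (none here).
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


model = TinyLinear(jnp.ones((3, 2)), jnp.zeros((2,)))
x = jnp.arange(3.0)

# Direct style: the pytree model is an ordinary jax.jit argument.
direct = jax.jit(lambda m, x: m(x))(model, x)

# Split/merge style: separate the structure from the array leaves,
# jit over the leaves only, and rebuild the object inside the function.
leaves, treedef = jax.tree_util.tree_flatten(model)


@jax.jit
def apply_from_leaves(leaves, x):
    rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)  # the "merge" step
    return rebuilt(x)


via_leaves = apply_from_leaves(leaves, x)
```

Because jax.jit already flattens pytree arguments into their leaves, the explicit flatten/rebuild step is redundant for a params-only object like this. It earns its keep when the object carries state that has to be separated out and written back.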
-
I'm sure the upcoming docs will cover everything, but I'm curious about the
-
v0.11.0 - Pytrees, MutableArrays, and more!
This version of Flax introduces some changes to improve interop with native JAX and adds support for the new jax.experimental.MutableArray. More on this soon! However, some breaking changes to align with the JAX way of doing things were necessary. Most code should remain intact; however, the following changes deviate from the current behavior:
- Rngs in standard layers: standard layers no longer hold a shared reference to the rngs object given in the constructor; instead, they now keep a fork-ed copy of the Rngs or RngStream objects. This impacts Using Rngs in NNX Transforms and Loading Checkpoints with RNGs.
- Optimizer updates: the Optimizer no longer holds a reference to the model, to avoid reference sharing; instead, the model must be provided as the first argument to update.
- Modules as Pytrees: NNX Objects are now Pytrees, which avoids unnecessary use of split and merge when interacting trivially with raw JAX transforms (state must still be manually propagated if not using MutableArrays, and referential transparency is still an issue). This affects code that operates on Pytrees containing NNX Objects with jax.tree.* APIs.
Check out the full NNX 0.10 to NNX 0.11 migration guide.
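The Optimizer change and the note that state must still be propagated manually both follow the functional style of plain JAX. The sketch below assumes a hypothetical SGD class and dict-based parameters; it is not nnx.Optimizer, which in 0.11 takes the model as the first argument to update. It shows an optimizer that keeps no reference to the model: the parameters are passed in explicitly, and the updated values come back as a return value that the caller must reassign.

```python
import jax
import jax.numpy as jnp


class SGD:
    """Hypothetical optimizer: holds hyperparameters, but no model reference."""

    def __init__(self, learning_rate):
        self.learning_rate = learning_rate

    def update(self, params, grads):
        # Parameters arrive as an explicit argument and new ones are returned;
        # nothing is mutated in place, so there is no shared reference to track.
        return jax.tree_util.tree_map(
            lambda p, g: p - self.learning_rate * g, params, grads
        )


def loss_fn(params, x, y):
    # Mean squared error of a single affine layer.
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)


optimizer = SGD(learning_rate=0.1)


@jax.jit
def train_step(params, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    return optimizer.update(params, grads)


params = {"w": jnp.zeros((2, 1)), "b": jnp.zeros((1,))}
x = jnp.array([[1.0, 2.0]])
y = jnp.array([[1.0]])

# The jitted step cannot mutate `params`; the caller propagates the new state.
new_params = train_step(params, x, y)
```

Keeping the optimizer free of model references means the same instance can be closed over by a jitted step without any hidden aliasing; the trade-off is that every step has to hand the new state back explicitly.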
In the near future we'll share more information about new ways of using NNX with JAX transforms directly by leveraging the new Pytree and MutableArray support. Stay tuned!
What's Changed
- Fix failing CI jobs: trailing whitespace, deprecated .type usage by @vfdev-5 in #4823
- refactor: move usages of .value to [...] in modules_test.py by @lukeyeh in #4815
- Migrate transforms_test.py from .value to [...] by @lukeyeh in #4841
New Contributors
- @lukeyeh made their first contribution in #4815
Full Changelog: v0.10.7...v0.11.0
This discussion was created from the release v0.11.0.