Here is an example that uses `nnx.split` and `nnx.merge` together with `jax2tf` to export a Flax NNX Module as a TensorFlow SavedModel:

```python
from flax import nnx
import jax
import jax.numpy as jnp
from jax.experimental import jax2tf
import tensorflow as tf

model = nnx.Linear(3, 4, rngs=nnx.Rngs(0))

# Split the Module into a static graph definition and its parameter state.
graphdef, state = nnx.split(model)
state = jax.tree.leaves(state)  # flatten the state into a list of arrays

# Wrap each parameter array in a tf.Variable.
tf_state = tf.nest.map_structure(tf.Variable, state)

def forward_jax(state, x):
  # Rebuild the Module from the graph definition and the flat state.
  model = nnx.merge(graphdef, state)
  return model(x)

# Test the forward pass in JAX.
y = forward_jax(state, jnp.ones((3,)))

def predict_tf(x):
  return jax2tf.convert(forward_jax)(tf_state, x)

tf_model = tf.Module()
# Tell the model saver which variables to save.
tf_model._v…
```

Answer selected by 5c4lar