Description
I am trying to implement some basic quantize/dequantize operations using NNX that would allow me to quantize activations and gradients. For now, I am trying to use nnx.custom_vjp to manage the references of a quantizer module so that I can propagate quantization metadata between the quantization/dequantization fwd/bwd calls. It is a very basic implementation with several limitations (such as being limited to the same dtype for primals and tangents, to name one) that I hope to build upon once I get this working correctly.
I have previously used something similar in combination with rematerialization in Linen to reduce the size of the checkpointed values, which worked quite nicely. I have added the code for that implementation at the end.
Currently, the issue I am having with the NNX implementation is that bwd_scale and bwd_amax_hist are not updated correctly after the dq_bwd call, even though the returned updates are correct within dq_bwd, meaning that q_bwd uses the wrong bwd_scale value when dequantizing the gradient. One can also see from the output (below the code sample) that the module's bwd_scale and bwd_amax_hist do not seem to be traced over correctly.
I have tried to follow the nnx.custom_vjp docstring example; however, I suspect I am getting something wrong with regard to how module state needs to be updated within a bwd function and how reference semantics are handled under nnx.custom_vjp.
Any help pointing me in the right direction would be greatly appreciated!
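For reference, here is a quick numerical sanity check of the amax-based scale math used by the helpers below (a standalone sketch, not part of the module code): for float8_e4m3fn the representable maximum is 448, so an input whose amax is 5.0 yields a scale of 5 / 448 ≈ 0.01116, which is the fwd_scale primal visible in the traced output further down.

import jax.numpy as jnp

# Standalone round trip using the same math as update_metadata/quantize/dequantize below.
q_dtype = jnp.float8_e4m3fn
qmax = jnp.finfo(q_dtype).max.astype(jnp.float32)       # 448.0 for float8_e4m3fn
x = jnp.array(5.0, jnp.float32)
scale = jnp.max(jnp.abs(x)).astype(jnp.float32) / qmax  # ~0.01116071
x_q = jnp.clip(x / scale, -qmax, qmax).astype(q_dtype)  # quantized value
x_dq = x_q.astype(jnp.float32) * scale                  # ~5.0 after the round trip
print(scale, x_q, x_dq)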
NNX Implementation
import jax
import jax.numpy as jnp
from flax import nnx
from jax import random


# Example pure function implementations of amax quantization and dequantization
def update_metadata(x, amax_hist, q_dtype):
    amax = jnp.max(jnp.abs(x))
    qmax = jnp.finfo(q_dtype).max.astype(jnp.float32)
    new_scale = amax.astype(jnp.float32) / qmax
    new_amax_hist = jnp.roll(amax_hist, 1).at[0].set(amax)
    return new_scale, new_amax_hist


def quantize(x, scale, q_dtype):
    qmax = jnp.finfo(q_dtype).max.astype(jnp.float32)
    x_q = jnp.clip(x / scale, -qmax, qmax).astype(q_dtype)
    return x_q


def dequantize(x_q, scale, dq_dtype):
    x = x_q.astype(dq_dtype) * scale
    return x


# NNX wrappers for quantization and dequantization
class QuantizationMetadata(nnx.Variable):
    pass


class Quantizer(nnx.Module):
    def __init__(self, q_dtype, dq_dtype, amax_hist_len=128):
        self.q_dtype = q_dtype
        self.dq_dtype = dq_dtype
        self.fwd_amax_hist = QuantizationMetadata(jnp.zeros((amax_hist_len,)))
        self.fwd_scale = QuantizationMetadata(1)
        self.bwd_amax_hist = QuantizationMetadata(jnp.zeros((amax_hist_len,)))
        self.bwd_scale = QuantizationMetadata(1)

    def quantize(self, x):
        xq = nnx_quantize(self, x)
        return xq

    def dequantize(self, xq):
        x = nnx_dequantize(self, xq)
        return x


@nnx.custom_vjp
def nnx_quantize(m: Quantizer, x):
    m.fwd_scale.value, m.fwd_amax_hist.value = update_metadata(x, m.fwd_amax_hist.value, m.q_dtype)
    xq = quantize(x, m.fwd_scale.value, m.q_dtype)
    return xq


def q_fwd(m: Quantizer, x):
    xq = nnx_quantize(m, x)
    return xq, m


def q_bwd(res, g):
    m = res
    (in_mg, in_xg), xq_g = g
    xg = dequantize(xq_g, m.bwd_scale.value, m.dq_dtype)
    return in_mg, xg


nnx_quantize.defvjp(q_fwd, q_bwd)


@nnx.custom_vjp
def nnx_dequantize(m: Quantizer, qx):
    x = dequantize(qx, m.fwd_scale.value, m.dq_dtype)
    return x


def dq_fwd(m: Quantizer, qx):
    x = nnx_dequantize(m, qx)
    return x, m


def dq_bwd(res, g):
    m = res
    (in_mg, in_xg), x_g = g
    # copy tree and update variables
    bwd_scale, bwd_amax_hist = update_metadata(x_g, m.bwd_amax_hist.value, m.q_dtype)
    mg = jax.tree.map(lambda x: x, in_mg)
    mg['bwd_scale'].value = bwd_scale
    mg['bwd_amax_hist'].value = bwd_amax_hist
    qx_g = quantize(x_g, bwd_scale, m.q_dtype)
    return mg, qx_g


nnx_dequantize.defvjp(dq_fwd, dq_bwd)


if __name__ == "__main__":
    quantizer = Quantizer(jnp.float8_e4m3fn, jnp.float32)

    def qdq(x):
        xq = quantizer.quantize(x)
        dqx = quantizer.dequantize(xq)
        return dqx

    x = jnp.array(5, jnp.float32)
    y, g = nnx.value_and_grad(qdq)(x)
    print("x:", x)
    print("y:", y)
    print("g:", g)
    print("quantizer:", quantizer)
Output:
x: 5.0
y: 5.0
g: 448.0
quantizer: Quantizer(
  bwd_amax_hist=QuantizationMetadata(
    value=Array(shape=(128,), dtype=float32)
  ),
  bwd_scale=QuantizationMetadata(
    value=1
  ),
  dq_dtype=<class 'jax.numpy.float32'>,
  fwd_amax_hist=QuantizationMetadata(
    value=Array(shape=(128,), dtype=float32)
  ),
  fwd_scale=QuantizationMetadata(
    value=Traced<ShapedArray(float32[])>with<JVPTrace> with
      primal = Array(0.01116071, dtype=float32)
      tangent = Traced<ShapedArray(float32[])>with<JaxprTrace> with
        pval = (ShapedArray(float32[]), None)
recipe = JaxprEqnRecipe(eqn_id=<object object at 0x1108f69d0>, in_tracers=(Traced<ShapedArray(float32[128]):JaxprTrace>, Traced<ShapedArray(int32[], weak_type=True):JaxprTrace>, Traced<ShapedArray(float32[128]):JaxprTrace>, Traced<ShapedArray(float32[]):JaxprTrace>, Traced<ShapedArray(float32[128]):JaxprTrace>, Traced<ShapedArray(float0[]):JaxprTrace>, Traced<ShapedArray(float32[128]):JaxprTrace>, Traced<ShapedArray(float32[]):JaxprTrace>, Traced<ShapedArray(float8_e4m3fn[]):JaxprTrace>), out_tracer_refs=[<weakref at 0x1117907c0; to 'JaxprTracer' at 0x111740040>, <weakref at 0x1117921b0; dead>, <weakref at 0x111792160; to 'JaxprTracer' at 0x111790860>, <weakref at 0x111792020; to 'JaxprTracer' at 0x111790900>, <weakref at 0x111791f80; dead>], out_avals=[ShapedArray(float32[128]), ShapedArray(float0[], weak_type=True), ShapedArray(float32[128]), ShapedArray(float32[]), ShapedArray(float32[])], primitive=custom_lin, params={'num_res': 4, 'bwd': Wrapped function:
0 : _flatten_bwd (PyTreeDef((CustomNode(NodeStates[(None,)], [None, (CustomNode(State[('bwd_amax_hist', 'bwd_scale', 'fwd_amax_hist', 'fwd_scale')], [CustomNode(VariableState[(<class '__main__.QuantizationMetadata'>, (('get_value_hooks', ()), ('set_value_hooks', ()), ('create_value_hooks', ()), ('add_axis_hooks', ()), ('remove_axis_hooks', ())))], [*]), CustomNode(VariableState[(<class '__main__.QuantizationMetadata'>, (('get_value_hooks', ()), ('set_value_hooks', ()), ('create_value_hooks', ()), ('add_axis_hooks', ()), ('remove_axis_hooks', ())))], [*]), CustomNode(VariableState[(<class '__main__.QuantizationMetadata'>, (('get_value_hooks', ()), ('set_value_hooks', ()), ('create_value_hooks', ()), ('add_axis_hooks', ()), ('remove_axis_hooks', ())))], [*]), CustomNode(VariableState[(<class '__main__.QuantizationMetadata'>, (('get_value_hooks', ()), ('set_value_hooks', ()), ('create_value_hooks', ()), ('add_axis_hooks', ()), ('remove_axis_hooks', ())))], [*])]),)]), *)), [ShapedArray(float32[128]), ShapedArray(int32[], weak_type=True), ShapedArray(float32[128]), ShapedArray(float32[]), ShapedArray(float8_e4m3fn[])], <function transformation_with_aux2.<locals>.<lambda> at 0x111784720>)
1 : _get_result_paths_thunk ()
Core: dq_bwd
, 'out_avals': [ShapedArray(float32[128]), ShapedArray(float0[], weak_type=True), ShapedArray(float32[128]), ShapedArray(float32[]), ShapedArray(float32[])], 'symbolic_zeros': False}, effects=frozenset(), source_info=<jax._src.source_info_util.SourceInfo object at 0x111780b50>, ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types={}), xla_metadata=None))
  ),
  q_dtype=<class 'jax.numpy.float8_e4m3fn'>
)
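For completeness, one way to check whether the metadata is actually updated is to snapshot the QuantizationMetadata state before and after the differentiated call (a minimal sketch appended to the __main__ block above; it assumes that nnx.state with a Variable-type filter returns only those variables):

# Hypothetical debugging aid: snapshot the quantization metadata around the grad call
# to see whether bwd_scale / bwd_amax_hist change.
state_before = nnx.state(quantizer, QuantizationMetadata)
y, g = nnx.value_and_grad(qdq)(x)
state_after = nnx.state(quantizer, QuantizationMetadata)
print("metadata before:", state_before)
print("metadata after:", state_after)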
Linen Implementation
from functools import partial

import jax
from flax import linen as nn
from jax import custom_vjp
import jax.numpy as jnp
from jax.experimental import primal_tangent_dtype


def quantize_fn(primal_dtype, quantized_primal_dtype):
    def quantize_fwd(x, primal_scale, primal_amax_history):
        new_primal_scale, new_primal_amax_history = nn.fp8_ops.update_fp8_meta(
            x,
            quantized_primal_dtype,
            primal_scale,
            primal_amax_history
        )
        qx = nn.fp8_ops.quantize(
            x,
            quantized_primal_dtype,
            nn.fp8_ops._fm32_to_float32(new_primal_scale),
            primal_dtype
        )
        return (qx, new_primal_scale), (new_primal_scale, new_primal_amax_history)

    def quantize_bwd(res, g):
        qx, reverse_scale = g
        new_primal_scale, new_primal_amax_history = res
        dqx = nn.fp8_ops.dequantize(
            qx,
            primal_dtype,
            nn.fp8_ops._fm32_to_float32(reverse_scale)
        )
        return dqx, new_primal_scale, new_primal_amax_history

    @custom_vjp
    def quantize(x, primal_scale, primal_amax_history):
        (qx, new_primal_scale), _ = quantize_fwd(x, primal_scale, primal_amax_history)
        return qx, new_primal_scale

    quantize.defvjp(quantize_fwd, quantize_bwd)
    return quantize


def dequantize_fn(primal_dtype, quantized_primal_dtype):
    def dequantize_fwd(x, new_primal_scale, tangent_scale, tangent_amax_history):
        assert x.dtype == quantized_primal_dtype
        dqx = nn.fp8_ops.dequantize(
            x,
            primal_dtype,
            nn.fp8_ops._fm32_to_float32(new_primal_scale)
        )
        return dqx, (tangent_scale, tangent_amax_history)

    def dequantize_bwd(res, g):
        assert g.dtype == primal_dtype  # gradient is in primal dtype
        tangent_scale, tangent_amax_history = res
        new_tangent_scale, new_tangent_amax_history = nn.fp8_ops.update_fp8_meta(
            g,
            quantized_primal_dtype,
            tangent_scale,
            tangent_amax_history
        )
        qg = nn.fp8_ops.quantize(
            g,
            quantized_primal_dtype,
            nn.fp8_ops._fm32_to_float32(new_tangent_scale),
            primal_dtype
        )
        reverse_scale = new_tangent_scale  # hack to get new_tangent_scale into quantize_bwd
        return qg, reverse_scale, new_tangent_scale, new_tangent_amax_history

    @custom_vjp
    def dequantize(x, new_primal_scale, tangent_scale, tangent_amax_history):
        dqx, _ = dequantize_fwd(x, new_primal_scale, tangent_scale, tangent_amax_history)
        return dqx

    dequantize.defvjp(dequantize_fwd, dequantize_bwd)
    return dequantize


def create_primal_variables(module: nn.Module, amax_hist_length: int = 1024):
    """Create scale and amax history variables for quantization."""
    primal_scale = module.variable(
        nn.fp8_ops.OVERWRITE_WITH_GRADIENT,
        "primal_scale",
        nn.initializers.ones_init(),
        jax.random.key(0),
        (1,),
    )
    primal_amax_history = module.variable(
        nn.fp8_ops.OVERWRITE_WITH_GRADIENT,
        "primal_amax_history",
        nn.initializers.zeros_init(),
        jax.random.key(0),
        (amax_hist_length,),
    )
    return primal_scale, primal_amax_history


def create_tangent_variables(module: nn.Module, amax_hist_length: int = 1024):
    """Create tangent scale and amax history variables for gradient quantization."""
    tangent_scale = module.variable(
        nn.fp8_ops.OVERWRITE_WITH_GRADIENT,
        "tangent_scale",
        nn.initializers.ones_init(),
        jax.random.key(0),
        (1,),
    )
    tangent_amax_history = module.variable(
        nn.fp8_ops.OVERWRITE_WITH_GRADIENT,
        "tangent_amax_history",
        nn.initializers.zeros_init(),
        jax.random.key(0),
        (amax_hist_length,),
    )
    return tangent_scale, tangent_amax_history
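For illustration, a hypothetical way to wire these helpers inside a Linen module (an untested sketch; the QDQDense module, the nn.Dense layer, and the chosen dtypes are illustrative only and not part of the original code):

class QDQDense(nn.Module):
    """Illustrative only: quantize/dequantize the activations around a Dense layer."""
    features: int

    @nn.compact
    def __call__(self, x):
        primal_scale, primal_amax_history = create_primal_variables(self)
        tangent_scale, tangent_amax_history = create_tangent_variables(self)
        quantize = quantize_fn(jnp.float32, jnp.float8_e4m3fn)
        dequantize = dequantize_fn(jnp.float32, jnp.float8_e4m3fn)
        # Quantize the forward activation; the gradient is quantized in dequantize's bwd rule.
        xq, new_primal_scale = quantize(x, primal_scale.value, primal_amax_history.value)
        x = dequantize(xq, new_primal_scale, tangent_scale.value, tangent_amax_history.value)
        return nn.Dense(self.features)(x)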