
Issue with NNX custom_vjp in activation quantization implementation #4651

Description

@liamclarkza

I am trying to implement some basic quantize/dequantize operations in NNX that would allow me to quantize activations and gradients. For now, I am using nnx.custom_vjp to manage the references of a quantizer module so that I can propagate quantization metadata between the quantization/dequantization fwd/bwd calls. It's a very basic implementation with several limitations (for one, primals and tangents must share the same dtype) that I hope to build on once I get it working correctly.
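
For reference, the scheme below is plain amax scaling: the scale is the current amax divided by the target dtype's max, quantization divides by the scale and casts, and dequantization casts back and multiplies. A quick numeric sketch of the round trip (values assume float8_e4m3fn, whose finfo max is 448, as in the repro at the bottom):

import jax.numpy as jnp

x = jnp.float32(5.0)
qmax = jnp.finfo(jnp.float8_e4m3fn).max.astype(jnp.float32)      # 448.0
scale = jnp.abs(x) / qmax                                        # ~0.011161
xq = jnp.clip(x / scale, -qmax, qmax).astype(jnp.float8_e4m3fn)  # 448 in fp8
x_hat = xq.astype(jnp.float32) * scale                           # ~5.0 after the round trip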

I have previously used something similar in combination with rematerialization in linen to reduce the size of the checkpointed values, which worked quite nicely. I have added the code for this implementation at the end.

Currently, the issue I am having with the NNX implementation is that the bwd_scale and bwd_amax_hist are not being updated correctly after the dq_bwd call, even though the returned updates are correct within dq_bwd, meaning that the q_bwd has the wrong bwd_scale value for dequantizing the gradient. One can also see from the output (below the code sample) that the module's bwd_scale and bwd_amax_hist don't seem to be traced over correctly.
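
To make the symptom concrete, here is the backward pass worked through by hand with the pure functions from the code below (float8_e4m3fn, max 448; the output cotangent seeded by value_and_grad is 1.0):

import jax.numpy as jnp

qmax = 448.0                       # jnp.finfo(jnp.float8_e4m3fn).max
x_g = jnp.float32(1.0)             # cotangent arriving at dq_bwd
bwd_scale = x_g / qmax             # what dq_bwd computes and returns (~0.00223)
qx_g = jnp.clip(x_g / bwd_scale, -qmax, qmax).astype(jnp.float8_e4m3fn)  # 448 in fp8

# q_bwd should dequantize with the updated bwd_scale and recover ~1.0 ...
expected = qx_g.astype(jnp.float32) * bwd_scale  # ~1.0
# ... but it still sees the stale initial bwd_scale of 1, which is exactly
# the g = 448.0 printed in the output further down.
observed = qx_g.astype(jnp.float32) * 1.0        # 448.0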

I have tried to implement this in accordance with the nnx.custom_vjp docstring example; however, I suspect I am getting something wrong with regard to how module state needs to be updated within a bwd function and how the reference semantics are handled when using nnx.custom_vjp.

Any help pointing me in the right direction would be greatly appreciated!

NNX Implementation

import jax
import jax.numpy as jnp
from flax import nnx
from jax import random

# Example pure function implementations of amax quantization and dequantization

def update_metadata(x, amax_hist, q_dtype):
    amax = jnp.max(jnp.abs(x))
    qmax = jnp.finfo(q_dtype).max.astype(jnp.float32)
    new_scale = amax.astype(jnp.float32) / qmax
    new_amax_hist = jnp.roll(amax_hist, 1).at[0].set(amax)
    return new_scale, new_amax_hist

def quantize(x, scale, q_dtype):
    qmax = jnp.finfo(q_dtype).max.astype(jnp.float32)
    x_q = jnp.clip(x / scale, -qmax, qmax).astype(q_dtype)
    return x_q

def dequantize(x_q, scale, dq_dtype):
    x = x_q.astype(dq_dtype) * scale
    return x


# NNX wrappers for quantization and dequantization

class QuantizationMetadata(nnx.Variable):
    pass


class Quantizer(nnx.Module):

    def __init__(self, q_dtype, dq_dtype, amax_hist_len=128):
        self.q_dtype = q_dtype
        self.dq_dtype = dq_dtype
        self.fwd_amax_hist = QuantizationMetadata(jnp.zeros((amax_hist_len,)))
        self.fwd_scale = QuantizationMetadata(1)
        self.bwd_amax_hist = QuantizationMetadata(jnp.zeros((amax_hist_len,)))
        self.bwd_scale = QuantizationMetadata(1)

    def quantize(self, x):
        xq = nnx_quantize(self, x)
        return xq

    def dequantize(self, xq):
        x = nnx_dequantize(self, xq)
        return x

@nnx.custom_vjp
def nnx_quantize(m: Quantizer, x):
    m.fwd_scale.value, m.fwd_amax_hist.value = update_metadata(x, m.fwd_amax_hist.value, m.q_dtype)
    xq = quantize(x, m.fwd_scale.value, m.q_dtype)
    return xq

def q_fwd(m: Quantizer, x):
    xq = nnx_quantize(m, x)
    return xq, m

def q_bwd(res, g):
    m = res
    (in_mg, in_xg), xq_g = g
    xg = dequantize(xq_g, m.bwd_scale.value, m.dq_dtype)
    return in_mg, xg

nnx_quantize.defvjp(q_fwd, q_bwd)


@nnx.custom_vjp
def nnx_dequantize(m: Quantizer, qx):
    x = dequantize(qx, m.fwd_scale.value, m.dq_dtype)
    return x

def dq_fwd(m: Quantizer, qx):
    x = nnx_dequantize(m, qx)
    return x, m

def dq_bwd(res, g):
    m = res
    (in_mg, in_xg), x_g = g

    # copy tree and update variables
    bwd_scale, bwd_amax_hist = update_metadata(x_g, m.bwd_amax_hist.value, m.q_dtype)
    mg = jax.tree.map(lambda x: x, in_mg)
    mg['bwd_scale'].value = bwd_scale
    mg['bwd_amax_hist'].value = bwd_amax_hist

    qx_g = quantize(x_g, bwd_scale, m.q_dtype)
    return mg, qx_g

nnx_dequantize.defvjp(dq_fwd, dq_bwd)



if __name__ == "__main__":

    quantizer = Quantizer(jnp.float8_e4m3fn, jnp.float32)

    def qdq(x):
        xq = quantizer.quantize(x)
        dqx = quantizer.dequantize(xq)
        return dqx

    x = jnp.array(5, jnp.float32)
    y, g = nnx.value_and_grad(qdq)(x)
    print("x:", x)
    print("y:", y)
    print("g:", g)
    print("quantizer:", quantizer)

Output:

x: 5.0
y: 5.0
g: 448.0
quantizer: Quantizer(
  bwd_amax_hist=QuantizationMetadata(
    value=Array(shape=(128,), dtype=float32)
  ),
  bwd_scale=QuantizationMetadata(
    value=1
  ),
  dq_dtype=<class 'jax.numpy.float32'>,
  fwd_amax_hist=QuantizationMetadata(
    value=Array(shape=(128,), dtype=float32)
  ),
  fwd_scale=QuantizationMetadata(
    value=Traced<ShapedArray(float32[])>with<JVPTrace> with
      primal = Array(0.01116071, dtype=float32)
      tangent = Traced<ShapedArray(float32[])>with<JaxprTrace> with
        pval = (ShapedArray(float32[]), None)
        recipe = JaxprEqnRecipe(eqn_id=<object object at 0x1108f69d0>, in_tracers=(Traced<ShapedArray(float32[128]):JaxprTrace>, Traced<ShapedArray(int32[], weak_type=True):JaxprTrace>, Traced<ShapedArray(float32[128]):JaxprTrace>, Traced<ShapedArray(float32[]):JaxprTrace>, Traced<ShapedArray(float32[128]):JaxprTrace>, Traced<ShapedArray(float0[]):JaxprTrace>, Traced<ShapedArray(float32[128]):JaxprTrace>, Traced<ShapedArray(float32[]):JaxprTrace>, Traced<ShapedArray(float8_e4m3fn[]):JaxprTrace>), out_tracer_refs=[<weakref at 0x1117907c0; to 'JaxprTracer' at 0x111740040>, <weakref at 0x1117921b0; dead>, <weakref at 0x111792160; to 'JaxprTracer' at 0x111790860>, <weakref at 0x111792020; to 'JaxprTracer' at 0x111790900>, <weakref at 0x111791f80; dead>], out_avals=[ShapedArray(float32[128]), ShapedArray(float0[], weak_type=True), ShapedArray(float32[128]), ShapedArray(float32[]), ShapedArray(float32[])], primitive=custom_lin, params={'num_res': 4, 'bwd': Wrapped function:
    0   : _flatten_bwd   (PyTreeDef((CustomNode(NodeStates[(None,)], [None, (CustomNode(State[('bwd_amax_hist', 'bwd_scale', 'fwd_amax_hist', 'fwd_scale')], [CustomNode(VariableState[(<class '__main__.QuantizationMetadata'>, (('get_value_hooks', ()), ('set_value_hooks', ()), ('create_value_hooks', ()), ('add_axis_hooks', ()), ('remove_axis_hooks', ())))], [*]), CustomNode(VariableState[(<class '__main__.QuantizationMetadata'>, (('get_value_hooks', ()), ('set_value_hooks', ()), ('create_value_hooks', ()), ('add_axis_hooks', ()), ('remove_axis_hooks', ())))], [*]), CustomNode(VariableState[(<class '__main__.QuantizationMetadata'>, (('get_value_hooks', ()), ('set_value_hooks', ()), ('create_value_hooks', ()), ('add_axis_hooks', ()), ('remove_axis_hooks', ())))], [*]), CustomNode(VariableState[(<class '__main__.QuantizationMetadata'>, (('get_value_hooks', ()), ('set_value_hooks', ()), ('create_value_hooks', ()), ('add_axis_hooks', ()), ('remove_axis_hooks', ())))], [*])]),)]), *)), [ShapedArray(float32[128]), ShapedArray(int32[], weak_type=True), ShapedArray(float32[128]), ShapedArray(float32[]), ShapedArray(float8_e4m3fn[])], <function transformation_with_aux2.<locals>.<lambda> at 0x111784720>)
    1   : _get_result_paths_thunk   ()
    Core: dq_bwd
    , 'out_avals': [ShapedArray(float32[128]), ShapedArray(float0[], weak_type=True), ShapedArray(float32[128]), ShapedArray(float32[]), ShapedArray(float32[])], 'symbolic_zeros': False}, effects=frozenset(), source_info=<jax._src.source_info_util.SourceInfo object at 0x111780b50>, ctx=JaxprEqnContext(compute_type=None, threefry_partitionable=True, cur_abstract_mesh=AbstractMesh((), axis_types={}), xla_metadata=None))
  ),
  q_dtype=<class 'jax.numpy.float8_e4m3fn'>
)

Linen Implementation

from functools import partial
import jax
from flax import linen as nn
from jax import custom_vjp
import jax.numpy as jnp
from jax.experimental import primal_tangent_dtype


def quantize_fn(primal_dtype, quantized_primal_dtype):

    def quantize_fwd(x, primal_scale, primal_amax_history):
        new_primal_scale, new_primal_amax_history = nn.fp8_ops.update_fp8_meta(
            x,
            quantized_primal_dtype,
            primal_scale,
            primal_amax_history
        )
        qx = nn.fp8_ops.quantize(
            x,
            quantized_primal_dtype,
            nn.fp8_ops._fm32_to_float32(new_primal_scale),
            primal_dtype
        )
        return (qx, new_primal_scale), (new_primal_scale, new_primal_amax_history)

    def quantize_bwd(res, g):
        qx, reverse_scale = g
        new_primal_scale, new_primal_amax_history = res
        dqx = nn.fp8_ops.dequantize(
            qx,
            primal_dtype,
            nn.fp8_ops._fm32_to_float32(reverse_scale)
        )
        return dqx, new_primal_scale, new_primal_amax_history

    @custom_vjp
    def quantize(x, primal_scale, primal_amax_history):
        (qx, new_primal_scale), _ = quantize_fwd(x, primal_scale, primal_amax_history)
        return qx, new_primal_scale

    quantize.defvjp(quantize_fwd, quantize_bwd)

    return quantize


def dequantize_fn(primal_dtype, quantized_primal_dtype):

    def dequantize_fwd(x, new_primal_scale, tangent_scale, tangent_amax_history):
        assert x.dtype == quantized_primal_dtype
        dqx = nn.fp8_ops.dequantize(
            x,
            primal_dtype,
            nn.fp8_ops._fm32_to_float32(new_primal_scale)
        )
        return dqx, (tangent_scale, tangent_amax_history)

    def dequantize_bwd(res, g):
        assert g.dtype == primal_dtype # gradient is in primal dtype
        tangent_scale, tangent_amax_history = res
        new_tangent_scale, new_tangent_amax_history = nn.fp8_ops.update_fp8_meta(
            g,
            quantized_primal_dtype,
            tangent_scale,
            tangent_amax_history
        )
        qg = nn.fp8_ops.quantize(
            g,
            quantized_primal_dtype,
            nn.fp8_ops._fm32_to_float32(new_tangent_scale),
            primal_dtype
        )
        reverse_scale = new_tangent_scale # hack to get new_tangent_scale into quantize_bwd
        return qg, reverse_scale, new_tangent_scale, new_tangent_amax_history

    @custom_vjp
    def dequantize(x, new_primal_scale, tangent_scale, tangent_amax_history):
        dqx, _ = dequantize_fwd(x, new_primal_scale, tangent_scale, tangent_amax_history)
        return dqx

    dequantize.defvjp(dequantize_fwd, dequantize_bwd)

    return dequantize



def create_primal_variables(module: nn.Module, amax_hist_length: int = 1024):
    """Create scale and amax history variables for quantization."""
    primal_scale = module.variable(
        nn.fp8_ops.OVERWRITE_WITH_GRADIENT,
        "primal_scale",
        nn.initializers.ones_init(),
        jax.random.key(0),
        (1,),
    )
    primal_amax_history = module.variable(
        nn.fp8_ops.OVERWRITE_WITH_GRADIENT,
        "primal_amax_history",
        nn.initializers.zeros_init(),
        jax.random.key(0),
        (amax_hist_length,),
    )
    return primal_scale, primal_amax_history


def create_tangent_variables(module: nn.Module, amax_hist_length: int = 1024):
    """Create tangent scale and amax history variables for gradient quantization."""
    tangent_scale = module.variable(
        nn.fp8_ops.OVERWRITE_WITH_GRADIENT,
        "tangent_scale",
        nn.initializers.ones_init(),
        jax.random.key(0),
        (1,),
    )
    tangent_amax_history = module.variable(
        nn.fp8_ops.OVERWRITE_WITH_GRADIENT,
        "tangent_amax_history",
        nn.initializers.zeros_init(),
        jax.random.key(0),
        (amax_hist_length,),
    )
    return tangent_scale, tangent_amax_history
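
For context, this is roughly how I previously combined these helpers with rematerialization (the block below is a simplified, hypothetical reconstruction rather than my actual module; QDQDense and the use of nn.Dense are just for illustration):

class QDQDense(nn.Module):
    """Quantize/dequantize the input around a Dense layer."""
    features: int

    @nn.compact
    def __call__(self, x):
        quantize = quantize_fn(jnp.float32, jnp.float8_e4m3fn)
        dequantize = dequantize_fn(jnp.float32, jnp.float8_e4m3fn)
        primal_scale, primal_amax_history = create_primal_variables(self)
        tangent_scale, tangent_amax_history = create_tangent_variables(self)

        xq, new_scale = quantize(x, primal_scale.value, primal_amax_history.value)
        x_hat = dequantize(xq, new_scale, tangent_scale.value, tangent_amax_history.value)
        return nn.Dense(self.features)(x_hat)


# Wrapping the block in nn.remat recomputes it on the backward pass, so the
# stored residuals are mostly the fp8 values and the small scale/amax arrays
# rather than full-precision activations.
RematQDQDense = nn.remat(QDQDense)

The scale/amax variables live in the OVERWRITE_WITH_GRADIENT collection because the custom VJPs return the updated metadata as the "gradients" of those arguments; the intent is that the train step overwrites the variables with those "gradients" rather than feeding them to the optimizer.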
