10 changes: 5 additions & 5 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -762,11 +762,11 @@ def __init__(
         self.global_num_experts = num_experts + num_redundant_experts

         # we pad globally so EP buffer allocation works
-        if (quant_config and quant_config.get_name() == "mxfp4"
-                and (current_platform.is_rocm()
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)):
-            hidden_size = round_up(hidden_size, 256)
+        if quant_config and quant_config.get_name() == "mxfp4":
+            from vllm.model_executor.layers.quantization.mxfp4 import (  # noqa: E501
+                should_use_flashinfer_mxfp4)
+            if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
+                hidden_size = round_up(hidden_size, 256)

         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config
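Both branches end in the same `round_up` call, which aligns the hidden size to the next multiple of 256 so the expert-parallel (EP) buffer allocation works out uniformly across ranks. A minimal sketch of that arithmetic, assuming `vllm.utils.round_up` rounds its first argument up to the nearest multiple of the second (local re-implementation for illustration, not vLLM's own helper):

```python
def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple of `multiple`.
    return ((x + multiple - 1) // multiple) * multiple

# A hidden size of 2880 (as in gpt-oss) is not 256-aligned, so it gets padded;
# an already-aligned size is left unchanged.
assert round_up(2880, 256) == 3072
assert round_up(4096, 256) == 4096
```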
60 changes: 46 additions & 14 deletions vllm/model_executor/layers/quantization/mxfp4.py
@@ -6,6 +6,7 @@
 from torch.nn.parameter import Parameter

 from vllm import envs
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                   FusedMoEMethodBase)
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
@@ -26,12 +27,38 @@
 from vllm.scalar_type import scalar_types
 from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
                         next_power_of_2, round_up)
+from vllm.utils.flashinfer import has_flashinfer

-if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
-    # from flashinfer.fused_moe import cutlass_fused_moe
-    from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
-                            shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
+logger = init_logger(__name__)
+
+
+def _should_use_flashinfer_mxfp4_bf16():
+    """Determine if FlashInfer MXFP4 BF16 should be used."""
+    # If explicitly set, respect the setting
+    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
+        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
+
+    # Enable by default on SM100 if MXFP8 is not explicitly enabled
+    if (current_platform.is_device_capability(100) and has_flashinfer()
+            and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
+        logger.info_once(
+            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
+            "For faster performance, consider setting "
+            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
+            "though this may impact accuracy.")
+        return True
+
+    return False
+
+
+def _should_use_flashinfer_mxfp4_mxfp8():
+    """Determine if FlashInfer MXFP4 MXFP8 should be used."""
+    return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+
+
+def should_use_flashinfer_mxfp4():
+    return (_should_use_flashinfer_mxfp4_mxfp8()
+            or _should_use_flashinfer_mxfp4_bf16())


 class Mxfp4Config(QuantizationConfig):
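As a reading aid for the new helpers, here is a small, self-contained sketch of the precedence they implement together with the `apply()` branch later in this diff. The boolean parameters are hypothetical stand-ins for the `envs.is_set()` / env-value reads and the platform checks; it is illustrative only and assumes both environment variables default to False when unset.

```python
def pick_activation_path(mxfp8_set: bool, mxfp8_val: bool,
                         bf16_set: bool, bf16_val: bool,
                         is_sm100: bool, flashinfer_available: bool):
    """Return "bf16", "mxfp8", or None (FlashInfer MXFP4 not used)."""
    use_bf16 = bf16_val if bf16_set else (
        is_sm100 and flashinfer_available and not mxfp8_set)
    use_mxfp8 = mxfp8_val if mxfp8_set else False
    if use_bf16:           # apply() checks the BF16 helper first
        return "bf16"
    if use_mxfp8:
        return "mxfp8"
    return None

# Default on Blackwell (SM100) with FlashInfer installed: BF16 path.
assert pick_activation_path(False, False, False, False, True, True) == "bf16"
# Explicitly opting into MXFP8 activations selects the MXFP8 path instead.
assert pick_activation_path(True, True, False, False, True, True) == "mxfp8"
# On non-SM100 hardware with nothing set, FlashInfer MXFP4 is not selected.
assert pick_activation_path(False, False, False, False, False, True) is None
```

Net effect: on Blackwell with FlashInfer available, the BF16 activation path is now enabled by default, and VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1 remains the explicit opt-in for the faster MXFP8 activations.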
@@ -87,12 +114,18 @@ def __init__(self, moe: FusedMoEConfig):
         self.moe = moe
         self.use_marlin = self._should_use_marlin()

+        if current_platform.is_device_capability(100) and not has_flashinfer():
+            logger.warning_once(
+                "MXFP4 MoE is enabled on Blackwell but FlashInfer "
+                "is not available. This may result in degraded performance. "
+                "Please `pip install vllm[flashinfer]` for best results.")
+
     def _should_use_marlin(self):
         if envs.VLLM_MXFP4_USE_MARLIN is not None:
             return envs.VLLM_MXFP4_USE_MARLIN
         if current_platform.is_cuda() and \
-                not current_platform.has_device_capability(100):
-            if not current_platform.is_device_capability(90):
+                not current_platform.is_device_capability(100):

[Review comment] There are SM120 parts such as the RTX 6000 Pro Blackwell; shouldn't we include those as well?

+            if not current_platform.has_device_capability(90):
                 # marlin kernel has better performance on ampere
                 return True
             if not has_triton_kernels():
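Swapping `has_device_capability(100)` for `is_device_capability(100)` (and the reverse for 90) changes which GPUs reach the Marlin-eligibility checks. A hedged sketch of the semantics assumed in this reading, using hypothetical free functions rather than vLLM's Platform methods, with capabilities encoded as major*10 + minor:

```python
# Hypothetical helpers mirroring the assumed semantics: is_* matches exactly,
# has_* means "this capability or newer".
def is_device_capability(device_cap: int, required: int) -> bool:
    return device_cap == required

def has_device_capability(device_cap: int, required: int) -> bool:
    return device_cap >= required

# Under that reading, an SM120 GPU such as the RTX 6000 Pro Blackwell
# (capability 120) is no longer filtered out by "not is_device_capability(100)",
# so it now falls through to the Marlin / triton-kernels checks, while SM100
# itself still skips Marlin entirely. That is the scenario the review question
# above is asking about.
sm120 = 120
assert not is_device_capability(sm120, 100)  # not exactly SM100
assert has_device_capability(sm120, 90)      # at least SM90, so no Ampere
                                             # early-return to Marlin
```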
@@ -138,8 +171,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             layer.hidden_size = hidden_size
             layer.intermediate_size_per_partition = \
                 intermediate_size_per_partition_after_pad
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        elif should_use_flashinfer_mxfp4():
             # pad the intermediate size to be a multiple of 2 * mxfp4_block
             # to hold non-uniform sharded tensors as well as swizzling
             # other padding to increase performance
@@ -230,8 +262,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
     def process_weights_after_loading(self, layer):
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        elif should_use_flashinfer_mxfp4():
+            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
             layer.gemm1_alpha = Parameter(torch.tensor(
                 [1.702] * self.num_experts, dtype=torch.float32).cuda(),
                                           requires_grad=False)
@@ -478,11 +510,11 @@ def apply(
                 logical_replica_count), (
             "MXFP4 are not supported with this configuration.")

-        if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        if should_use_flashinfer_mxfp4():
+            from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
             assert not self.moe.use_ep, (
                 "EP is not supported for flashinfer mxfp4 moe backend yet.")
-            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
+            if _should_use_flashinfer_mxfp4_bf16():
                 assert x.dtype == torch.bfloat16
                 x_quant = x
                 x_scale = None
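Finally, a hypothetical operator-side sketch (not part of this diff) of pinning the activation path through the environment variables this change reads; the variable names come from the diff, while the model id and engine setup are illustrative only.

```python
import os

# Opt in to MXFP8 activations for the FlashInfer MXFP4 MoE path; per the new
# log message this can be faster but may impact accuracy.
os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] = "1"
# Or pin the BF16 activation path explicitly instead:
# os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"] = "1"

from vllm import LLM  # set the flags before vLLM reads its env config

llm = LLM(model="openai/gpt-oss-20b")  # illustrative MXFP4-quantized model
print(llm.generate("Hello")[0].outputs[0].text)
```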