
Commit 481be6d

mgoin authored and djmmoss committed
Use Blackwell FlashInfer MXFP4 MoE by default if available (vllm-project#23008)
Signed-off-by: mgoin <[email protected]>
1 parent 4d2db6f · commit 481be6d

File tree

2 files changed: +51, −20 lines


vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 5 additions & 5 deletions
@@ -762,11 +762,11 @@ def __init__(
         self.global_num_experts = num_experts + num_redundant_experts
 
         # we padding globally so EP buffer allocation works
-        if (quant_config and quant_config.get_name() == "mxfp4"
-                and (current_platform.is_rocm()
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                     or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16)):
-            hidden_size = round_up(hidden_size, 256)
+        if quant_config and quant_config.get_name() == "mxfp4":
+            from vllm.model_executor.layers.quantization.mxfp4 import (  # noqa: E501
+                should_use_flashinfer_mxfp4)
+            if current_platform.is_rocm() or should_use_flashinfer_mxfp4():
+                hidden_size = round_up(hidden_size, 256)
 
         # For smuggling this layer into the fused moe custom op
         compilation_config = vllm_config.compilation_config
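
With this change, the global padding applies to any FlashInfer MXFP4 path selected by should_use_flashinfer_mxfp4() (or on ROCm), not only when one of the two environment variables is set explicitly. A rough sketch of the padding arithmetic, assuming round_up is the usual ceiling-to-multiple helper from vllm.utils:

def round_up(x: int, multiple: int) -> int:
    # Round x up to the nearest multiple of `multiple` (mirrors vllm.utils.round_up).
    return ((x + multiple - 1) // multiple) * multiple

# Hidden sizes are padded to a multiple of 256 so EP buffer allocation lines up:
assert round_up(2880, 256) == 3072
assert round_up(4096, 256) == 4096  # already aligned, unchanged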

vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 46 additions & 15 deletions
@@ -6,6 +6,7 @@
 from torch.nn.parameter import Parameter
 
 from vllm import envs
+from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig,
                                                   FusedMoEMethodBase)
 from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import (
@@ -26,13 +27,38 @@
 from vllm.scalar_type import scalar_types
 from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer,
                         next_power_of_2, round_up)
+from vllm.utils.flashinfer import has_flashinfer
 
-if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
-    from flashinfer.fused_moe import cutlass_fused_moe
-    from flashinfer.autotuner import autotune
-    from flashinfer import (mxfp8_quantize, shuffle_matrix_a,
-                            shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe)
+logger = init_logger(__name__)
+
+
+def _should_use_flashinfer_mxfp4_bf16():
+    """Determine if FlashInfer MXFP4 BF16 should be used."""
+    # If explicitly set, respect the setting
+    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"):
+        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16
+
+    # Enable by default on SM100 if MXFP8 is not explicitly enabled
+    if (current_platform.is_device_capability(100) and has_flashinfer()
+            and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")):
+        logger.info_once(
+            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. "
+            "For faster performance, consider setting "
+            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, "
+            "though this may impact accuracy.")
+        return True
+
+    return False
+
+
+def _should_use_flashinfer_mxfp4_mxfp8():
+    """Determine if FlashInfer MXFP4 MXFP8 should be used."""
+    return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
+
+
+def should_use_flashinfer_mxfp4():
+    return (_should_use_flashinfer_mxfp4_mxfp8()
+            or _should_use_flashinfer_mxfp4_bf16())
 
 
 class Mxfp4Config(QuantizationConfig):
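
These helpers centralize the backend decision: on SM 10.0 (Blackwell) with FlashInfer installed, the BF16 path is now enabled by default, while both environment variables remain explicit overrides. A hedged usage sketch, assuming the variables are parsed as booleans by vllm.envs and are set in the environment before the model is loaded:

import os

# Opt into the faster MXFP8 activation path instead of the default BF16 path
# (the log message above notes this may affect accuracy):
os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] = "1"

# Or opt out of the new Blackwell default and keep the non-FlashInfer MXFP4 path:
# os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"] = "0"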
@@ -89,12 +115,18 @@ def __init__(self, moe: FusedMoEConfig):
         self.use_marlin = self._should_use_marlin()
         self.flashinfer_autotune = True
 
+        if current_platform.is_device_capability(100) and not has_flashinfer():
+            logger.warning_once(
+                "MXFP4 MoE is enabled on Blackwell but FlashInfer "
+                "is not available. This may result in degraded performance. "
+                "Please `pip install vllm[flashinfer]` for best results.")
+
     def _should_use_marlin(self):
         if envs.VLLM_MXFP4_USE_MARLIN is not None:
             return envs.VLLM_MXFP4_USE_MARLIN
         if current_platform.is_cuda() and \
-                not current_platform.has_device_capability(100):
-            if not current_platform.is_device_capability(90):
+                not current_platform.is_device_capability(100):
+            if not current_platform.has_device_capability(90):
                 # marlin kernel has better performance on ampere
                 return True
         if not has_triton_kernels():
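
The Marlin fallback also swaps has_device_capability and is_device_capability, which behave as a ">= capability" check and an exact "== capability" check respectively. A toy sketch of those semantics (not the real current_platform API, which reads the compute capability from the active device):

def has_capability(current: tuple[int, int], minimum: int) -> bool:
    # ">=" check: this capability or anything newer.
    return current >= divmod(minimum, 10)

def is_capability(current: tuple[int, int], exact: int) -> bool:
    # "==" check: exactly this capability.
    return current == divmod(exact, 10)

# Blackwell (SM 10.0) is now excluded from the Marlin path by the exact check,
# while only devices below SM 9.0 hit the "better on ampere" branch.
assert is_capability((10, 0), 100) and has_capability((10, 0), 90)
assert not is_capability((9, 0), 100) and has_capability((9, 0), 90)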
@@ -140,8 +172,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
             layer.hidden_size = hidden_size
             layer.intermediate_size_per_partition = \
                 intermediate_size_per_partition_after_pad
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16) and current_platform.is_device_capability(100):
+        elif should_use_flashinfer_mxfp4():
             # pad the intermediate size to be a multiple of 2 * mxfp4_block
             # for to hold non-uniform sharded tensor as well as swizzling
             # other padding to increase performance
@@ -235,8 +266,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
     def process_weights_after_loading(self, layer):
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
-        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16):
+        elif should_use_flashinfer_mxfp4():
+            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
             layer.gemm1_alpha = Parameter(torch.tensor(
                 [1.702] * self.num_experts, dtype=torch.float32).cuda(),
                                           requires_grad=False)
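
Moving the shuffle_matrix_a / shuffle_matrix_sf_a import inside the branch (and mxfp8_quantize / trtllm_fp4_block_scale_moe in apply below) defers the FlashInfer dependency to the moment the FlashInfer path is actually taken, so importing mxfp4.py no longer requires flashinfer at module load. A minimal sketch of that deferred-import pattern, using a hypothetical prepare_weights wrapper for illustration:

def prepare_weights(use_flashinfer: bool) -> None:
    if use_flashinfer:
        # Deferred import: evaluated only when this branch runs, so the
        # enclosing module imports cleanly without flashinfer installed.
        from flashinfer import shuffle_matrix_a  # noqa: F401
        # ... shuffle/swizzle the MXFP4 weights for the FlashInfer kernels ...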
@@ -573,11 +604,11 @@ def apply(
                 logical_replica_count), (
                     "MXFP4 are not supported with this configuration.")
 
-        if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8
-                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16) and current_platform.is_device_capability(100):
+        if should_use_flashinfer_mxfp4():
+            from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
             assert not self.moe.use_ep, (
                 "EP is not supported for flashinfer mxfp4 moe backend yet.")
-            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16:
+            if _should_use_flashinfer_mxfp4_bf16():
                 assert x.dtype == torch.bfloat16
                 x_quant = x
                 x_scale = None
