| 6 | 6 | from torch.nn.parameter import Parameter | 
| 7 | 7 |  | 
| 8 | 8 | from vllm import envs | 
|  | 9 | +from vllm.logger import init_logger | 
| 9 | 10 | from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, | 
| 10 | 11 |                                                   FusedMoEMethodBase) | 
| 11 | 12 | from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( | 
| 26 | 27 | from vllm.scalar_type import scalar_types | 
| 27 | 28 | from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer, | 
| 28 | 29 |                         next_power_of_2, round_up) | 
|  | 30 | +from vllm.utils.flashinfer import has_flashinfer | 
| 29 | 31 |  | 
| 30 |  | -if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 | 
| 31 |  | -        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): | 
| 32 |  | -    # from flashinfer.fused_moe import cutlass_fused_moe | 
| 33 |  | -    from flashinfer import (mxfp8_quantize, shuffle_matrix_a, | 
| 34 |  | -                            shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) | 
|  | 32 | +logger = init_logger(__name__) | 
|  | 33 | + | 
|  | 34 | + | 
|  | 35 | +def _should_use_flashinfer_mxfp4_bf16(): | 
|  | 36 | +    """Determine if FlashInfer MXFP4 BF16 should be used.""" | 
|  | 37 | +    # If explicitly set, respect the setting | 
|  | 38 | +    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): | 
|  | 39 | +        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 | 
|  | 40 | + | 
|  | 41 | +    # Enable by default on SM100 if MXFP8 is not explicitly enabled | 
|  | 42 | +    if (current_platform.is_device_capability(100) and has_flashinfer() | 
|  | 43 | +            and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")): | 
|  | 44 | +        logger.info_once( | 
|  | 45 | +            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. " | 
|  | 46 | +            "For faster performance, consider setting " | 
|  | 47 | +            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, " | 
|  | 48 | +            "though this may impact accuracy.") | 
|  | 49 | +        return True | 
|  | 50 | + | 
|  | 51 | +    return False | 
|  | 52 | + | 
|  | 53 | + | 
|  | 54 | +def _should_use_flashinfer_mxfp4_mxfp8(): | 
|  | 55 | +    """Determine if FlashInfer MXFP4 MXFP8 should be used.""" | 
|  | 56 | +    return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 | 
|  | 57 | + | 
|  | 58 | + | 
|  | 59 | +def should_use_flashinfer_mxfp4(): | 
|  | 60 | +    return (_should_use_flashinfer_mxfp4_mxfp8() | 
|  | 61 | +            or _should_use_flashinfer_mxfp4_bf16()) | 
| 35 | 62 |  | 
| 36 | 63 |  | 
| 37 | 64 | class Mxfp4Config(QuantizationConfig): | 
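The three helpers added above encode a small decision table: the MXFP8 variant is purely opt-in via its env var, the BF16 variant honours an explicit setting and otherwise defaults on for SM100 with FlashInfer installed (as long as MXFP8 was not explicitly configured), and `apply()` further down checks the BF16 path first when both resolve true. A minimal, self-contained sketch of that logic, with plain arguments standing in for `vllm.envs` and `current_platform` (illustration only, not vLLM code):

```python
from typing import Optional

def resolve_mxfp4_path(bf16_env: Optional[bool],
                       mxfp8_env: Optional[bool],
                       is_sm100: bool,
                       flashinfer_installed: bool) -> str:
    """Sketch of should_use_flashinfer_mxfp4(); None means 'env var unset'."""
    # _should_use_flashinfer_mxfp4_bf16(): an explicit setting wins; otherwise
    # default-on for Blackwell (SM100) when FlashInfer is present and MXFP8
    # was not explicitly configured either way.
    if bf16_env is not None:
        use_bf16 = bf16_env
    else:
        use_bf16 = is_sm100 and flashinfer_installed and mxfp8_env is None

    # _should_use_flashinfer_mxfp4_mxfp8(): simply the env var's value.
    use_mxfp8 = bool(mxfp8_env)

    if not (use_bf16 or use_mxfp8):  # should_use_flashinfer_mxfp4()
        return "non-flashinfer fallback"
    # apply() tests the BF16 variant first, so it wins if both are enabled.
    return ("flashinfer, bf16 activations"
            if use_bf16 else "flashinfer, mxfp8 activations")

# Blackwell + FlashInfer with nothing set: BF16 FlashInfer path by default.
assert resolve_mxfp4_path(None, None, True, True) == "flashinfer, bf16 activations"
# Only MXFP8 explicitly enabled: the MXFP8 activation path is used instead.
assert resolve_mxfp4_path(None, True, True, True) == "flashinfer, mxfp8 activations"
```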
| @@ -87,12 +114,18 @@ def __init__(self, moe: FusedMoEConfig): | 
| 87 | 114 |         self.moe = moe | 
| 88 | 115 |         self.use_marlin = self._should_use_marlin() | 
| 89 | 116 |  | 
|  | 117 | +        if current_platform.is_device_capability(100) and not has_flashinfer(): | 
|  | 118 | +            logger.warning_once( | 
|  | 119 | +                "MXFP4 MoE is enabled on Blackwell but FlashInfer " | 
|  | 120 | +                "is not available. This may result in degraded performance. " | 
|  | 121 | +                "Please `pip install vllm[flashinfer]` for best results.") | 
|  | 122 | + | 
| 90 | 123 |     def _should_use_marlin(self): | 
| 91 | 124 |         if envs.VLLM_MXFP4_USE_MARLIN is not None: | 
| 92 | 125 |             return envs.VLLM_MXFP4_USE_MARLIN | 
| 93 | 126 |         if current_platform.is_cuda() and \ | 
| 94 |  | -                not current_platform.has_device_capability(100): | 
| 95 |  | -            if not current_platform.is_device_capability(90): | 
|  | 127 | +                not current_platform.is_device_capability(100): | 
|  | 128 | +            if not current_platform.has_device_capability(90): | 
| 96 | 129 |                 # marlin kernel has better performance on ampere | 
| 97 | 130 |                 return True | 
| 98 | 131 |             if not has_triton_kernels(): | 
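The swap of `has_device_capability` and `is_device_capability` in the hunk above matters because the two predicates differ: as far as I can tell, `is_device_capability(x)` checks for exactly `x` while `has_device_capability(x)` checks for at least `x` (an assumption about vLLM's platform API, not stated in the diff itself). A standalone sketch of the corrected branch, with the rest of the real method elided:

```python
def should_use_marlin_sketch(capability: int, triton_kernels: bool) -> bool:
    """Illustration only; `capability` is the SM number, e.g. 80, 90, 100."""
    if capability == 100:     # is_device_capability(100): exactly Blackwell
        return False          # prefer the FlashInfer / native MXFP4 path
    if capability < 90:       # not has_device_capability(90): pre-Hopper
        return True           # marlin performs better on Ampere
    if not triton_kernels:    # presumably fall back to marlin on Hopper
        return True           # when triton_kernels is unavailable
    return False              # remaining branches of the real method elided

assert should_use_marlin_sketch(80, True) is True     # Ampere -> marlin
assert should_use_marlin_sketch(100, True) is False   # Blackwell -> not marlin
```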
| @@ -138,8 +171,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, | 
| 138 | 171 |             layer.hidden_size = hidden_size | 
| 139 | 172 |             layer.intermediate_size_per_partition = \ | 
| 140 | 173 |                 intermediate_size_per_partition_after_pad | 
| 141 |  | -        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 | 
| 142 |  | -              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): | 
|  | 174 | +        elif should_use_flashinfer_mxfp4(): | 
| 143 | 175 |             # pad the intermediate size to be a multiple of 2 * mxfp4_block | 
| 144 | 176 |             # to hold non-uniformly sharded tensors, as well as swizzling | 
| 145 | 177 |             # and other padding to increase performance | 
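The padding comment above refers to the MXFP4 scale-block granularity: each scale covers a block of 32 elements (taking `mxfp4_block == 32`, the OCP MX block size, as an assumption), so the intermediate size per partition is rounded up to a multiple of 2 * 32 = 64. A small worked example of that rounding:

```python
MXFP4_BLOCK = 32  # assumed MXFP4 scale-block size (OCP MX spec)

def pad_intermediate_size(size: int, block: int = MXFP4_BLOCK) -> int:
    """Round the per-partition intermediate size up to a multiple of 2*block."""
    multiple = 2 * block
    return -(-size // multiple) * multiple  # ceiling division, then rescale

assert pad_intermediate_size(2880) == 2880  # already a multiple of 64
assert pad_intermediate_size(2900) == 2944  # padded up to the next multiple
```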
| @@ -230,8 +262,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, | 
| 230 | 262 |     def process_weights_after_loading(self, layer): | 
| 231 | 263 |         if self.use_marlin: | 
| 232 | 264 |             prepare_moe_fp4_layer_for_marlin(layer) | 
| 233 |  | -        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 | 
| 234 |  | -              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): | 
|  | 265 | +        elif should_use_flashinfer_mxfp4(): | 
|  | 266 | +            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a | 
| 235 | 267 |             layer.gemm1_alpha = Parameter(torch.tensor( | 
| 236 | 268 |                 [1.702] * self.num_experts, dtype=torch.float32).cuda(), | 
| 237 | 269 |                                           requires_grad=False) | 
| @@ -478,11 +510,11 @@ def apply( | 
| 478 | 510 |             logical_replica_count), ( | 
| 479 | 511 |                 "MXFP4 are not supported with this configuration.") | 
| 480 | 512 |  | 
| 481 |  | -        if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 | 
| 482 |  | -                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): | 
|  | 513 | +        if should_use_flashinfer_mxfp4(): | 
|  | 514 | +            from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe | 
| 483 | 515 |             assert not self.moe.use_ep, ( | 
| 484 | 516 |                 "EP is not supported for flashinfer mxfp4 moe backend yet.") | 
| 485 |  | -            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: | 
|  | 517 | +            if _should_use_flashinfer_mxfp4_bf16(): | 
| 486 | 518 |                 assert x.dtype == torch.bfloat16 | 
| 487 | 519 |                 x_quant = x | 
| 488 | 520 |                 x_scale = None | 