|  | 
| 6 | 6 | from torch.nn.parameter import Parameter | 
| 7 | 7 | 
 | 
| 8 | 8 | from vllm import envs | 
|  | 9 | +from vllm.logger import init_logger | 
| 9 | 10 | from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, | 
| 10 | 11 |                                                   FusedMoEMethodBase) | 
| 11 | 12 | from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( | 
|  | 
| 26 | 27 | from vllm.scalar_type import scalar_types | 
| 27 | 28 | from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer, | 
| 28 | 29 |                         next_power_of_2, round_up) | 
|  | 30 | +from vllm.utils.flashinfer import has_flashinfer | 
| 29 | 31 | 
 | 
| 30 |  | -if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 | 
| 31 |  | -        or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): | 
| 32 |  | -    from flashinfer.fused_moe import cutlass_fused_moe | 
| 33 |  | -    from flashinfer.autotuner import autotune | 
| 34 |  | -    from flashinfer import (mxfp8_quantize, shuffle_matrix_a, | 
| 35 |  | -                            shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) | 
|  | 32 | +logger = init_logger(__name__) | 
|  | 33 | + | 
|  | 34 | + | 
|  | 35 | +def _should_use_flashinfer_mxfp4_bf16(): | 
|  | 36 | +    """Determine if FlashInfer MXFP4 BF16 should be used.""" | 
|  | 37 | +    # If explicitly set, respect the setting | 
|  | 38 | +    if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): | 
|  | 39 | +        return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 | 
|  | 40 | + | 
|  | 41 | +    # Enable by default on SM100 if MXFP8 is not explicitly enabled | 
|  | 42 | +    if (current_platform.is_device_capability(100) and has_flashinfer() | 
|  | 43 | +            and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")): | 
|  | 44 | +        logger.info_once( | 
|  | 45 | +            "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. " | 
|  | 46 | +            "For faster performance, consider setting " | 
|  | 47 | +            "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, " | 
|  | 48 | +            "though this may impact accuracy.") | 
|  | 49 | +        return True | 
|  | 50 | + | 
|  | 51 | +    return False | 
|  | 52 | + | 
|  | 53 | + | 
|  | 54 | +def _should_use_flashinfer_mxfp4_mxfp8(): | 
|  | 55 | +    """Determine if FlashInfer MXFP4 MXFP8 should be used.""" | 
|  | 56 | +    return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 | 
|  | 57 | + | 
|  | 58 | + | 
|  | 59 | +def should_use_flashinfer_mxfp4(): | 
|  | 60 | +    return (_should_use_flashinfer_mxfp4_mxfp8() | 
|  | 61 | +            or _should_use_flashinfer_mxfp4_bf16()) | 
| 36 | 62 | 
 | 
| 37 | 63 | 
 | 
| 38 | 64 | class Mxfp4Config(QuantizationConfig): | 
| @@ -89,12 +115,18 @@ def __init__(self, moe: FusedMoEConfig): | 
| 89 | 115 |         self.use_marlin = self._should_use_marlin() | 
| 90 | 116 |         self.flashinfer_autotune = True | 
| 91 | 117 | 
 | 
|  | 118 | +        if current_platform.is_device_capability(100) and not has_flashinfer(): | 
|  | 119 | +            logger.warning_once( | 
|  | 120 | +                "MXFP4 MoE is enabled on Blackwell but FlashInfer " | 
|  | 121 | +                "is not available. This may result in degraded performance. " | 
|  | 122 | +                "Please `pip install vllm[flashinfer]` for best results.") | 
|  | 123 | + | 
| 92 | 124 |     def _should_use_marlin(self): | 
| 93 | 125 |         if envs.VLLM_MXFP4_USE_MARLIN is not None: | 
| 94 | 126 |             return envs.VLLM_MXFP4_USE_MARLIN | 
| 95 | 127 |         if current_platform.is_cuda() and \ | 
| 96 |  | -                not current_platform.has_device_capability(100): | 
| 97 |  | -            if not current_platform.is_device_capability(90): | 
|  | 128 | +                not current_platform.is_device_capability(100): | 
|  | 129 | +            if not current_platform.has_device_capability(90): | 
| 98 | 130 |                 # marlin kernel has better performance on ampere | 
| 99 | 131 |                 return True | 
| 100 | 132 |             if not has_triton_kernels(): | 
| @@ -140,8 +172,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, | 
| 140 | 172 |             layer.hidden_size = hidden_size | 
| 141 | 173 |             layer.intermediate_size_per_partition = \ | 
| 142 | 174 |                 intermediate_size_per_partition_after_pad | 
| 143 |  | -        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 | 
| 144 |  | -              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16) and current_platform.is_device_capability(100): | 
|  | 175 | +        elif should_use_flashinfer_mxfp4(): | 
| 145 | 176 |             # pad the intermediate size to be a multiple of 2 * mxfp4_block | 
| 146 | 177 |             # to hold non-uniform sharded tensors, as well as swizzling | 
| 147 | 178 |             # and other padding, to increase performance | 
| @@ -235,8 +266,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, | 
| 235 | 266 |     def process_weights_after_loading(self, layer): | 
| 236 | 267 |         if self.use_marlin: | 
| 237 | 268 |             prepare_moe_fp4_layer_for_marlin(layer) | 
| 238 |  | -        elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 | 
| 239 |  | -              or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): | 
|  | 269 | +        elif should_use_flashinfer_mxfp4(): | 
|  | 270 | +            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a | 
| 240 | 271 |             layer.gemm1_alpha = Parameter(torch.tensor( | 
| 241 | 272 |                 [1.702] * self.num_experts, dtype=torch.float32).cuda(), | 
| 242 | 273 |                                           requires_grad=False) | 
| @@ -573,11 +604,11 @@ def apply( | 
| 573 | 604 |             logical_replica_count), ( | 
| 574 | 605 |                 "MXFP4 is not supported with this configuration.") | 
| 575 | 606 | 
 | 
| 576 |  | -        if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 | 
| 577 |  | -                or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16) and current_platform.is_device_capability(100): | 
|  | 607 | +        if should_use_flashinfer_mxfp4(): | 
|  | 608 | +            from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe | 
| 578 | 609 |             assert not self.moe.use_ep, ( | 
| 579 | 610 |                 "EP is not supported for flashinfer mxfp4 moe backend yet.") | 
| 580 |  | -            if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: | 
|  | 611 | +            if _should_use_flashinfer_mxfp4_bf16(): | 
| 581 | 612 |                 assert x.dtype == torch.bfloat16 | 
| 582 | 613 |                 x_quant = x | 
| 583 | 614 |                 x_scale = None | 
|  | 
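A minimal usage sketch (not part of the commit) of how the backend selection introduced here is driven from the environment. The variable names, the SM100/Blackwell default, and the Marlin override come from the diff above; the model name and engine call are illustrative assumptions.

```python
# Sketch: selecting the FlashInfer MXFP4 MoE path via environment variables.
# Set these before constructing the engine; names are taken from the diff above.
import os

# Explicitly opt into the MXFP8-activation path (faster, may impact accuracy):
os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8"] = "1"

# Or force the BF16-activation path instead:
# os.environ["VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"] = "1"

# With neither variable set, should_use_flashinfer_mxfp4() now enables the BF16
# path by default on SM100 (Blackwell) when FlashInfer is available; Marlin can
# still be forced via VLLM_MXFP4_USE_MARLIN.
from vllm import LLM  # assumes a vLLM build that includes this change

llm = LLM(model="openai/gpt-oss-20b")  # example MXFP4-quantized MoE model
```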