|
6 | 6 | from torch.nn.parameter import Parameter |
7 | 7 |
|
8 | 8 | from vllm import envs |
| 9 | +from vllm.logger import init_logger |
9 | 10 | from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEConfig, |
10 | 11 | FusedMoEMethodBase) |
11 | 12 | from vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe import ( |
|
26 | 27 | from vllm.scalar_type import scalar_types |
27 | 28 | from vllm.utils import (has_triton_kernels, is_torch_equal_or_newer, |
28 | 29 | next_power_of_2, round_up) |
| 30 | +from vllm.utils.flashinfer import has_flashinfer |
29 | 31 |
|
30 | | -if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 |
31 | | - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): |
32 | | - # from flashinfer.fused_moe import cutlass_fused_moe |
33 | | - from flashinfer import (mxfp8_quantize, shuffle_matrix_a, |
34 | | - shuffle_matrix_sf_a, trtllm_fp4_block_scale_moe) |
| 32 | +logger = init_logger(__name__) |
| 33 | + |
| 34 | + |
| 35 | +def _should_use_flashinfer_mxfp4_bf16(): |
| 36 | + """Determine if FlashInfer MXFP4 BF16 should be used.""" |
| 37 | + # If explicitly set, respect the setting |
| 38 | + if envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16"): |
| 39 | + return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16 |
| 40 | + |
| 41 | + # Enable by default on SM100 if MXFP8 is not explicitly enabled |
| 42 | + if (current_platform.is_device_capability(100) and has_flashinfer() |
| 43 | + and not envs.is_set("VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8")): |
| 44 | + logger.info_once( |
| 45 | + "Enabling FlashInfer MXFP4 BF16 backend by default for Blackwell. " |
| 46 | + "For faster performance, consider setting " |
| 47 | + "VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8=1, " |
| 48 | + "though this may impact accuracy.") |
| 49 | + return True |
| 50 | + |
| 51 | + return False |
| 52 | + |
| 53 | + |
| 54 | +def _should_use_flashinfer_mxfp4_mxfp8(): |
| 55 | + """Determine if FlashInfer MXFP4 MXFP8 should be used.""" |
| 56 | + return envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 |
| 57 | + |
| 58 | + |
| 59 | +def should_use_flashinfer_mxfp4(): |
| 60 | + return (_should_use_flashinfer_mxfp4_mxfp8() |
| 61 | + or _should_use_flashinfer_mxfp4_bf16()) |
35 | 62 |
|
36 | 63 |
|
37 | 64 | class Mxfp4Config(QuantizationConfig): |
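
The new helpers centralize the backend choice that was previously an eager module-level env check. A minimal sketch of the precedence they implement (an explicit `VLLM_USE_FLASHINFER_MOE_MXFP4_BF16` setting always wins, otherwise BF16 defaults on for SM100 when FlashInfer is installed and MXFP8 was not explicitly configured); `resolve_mxfp4_backend` and its parameters are illustrative stand-ins, not part of the PR:

```python
from typing import Optional

def resolve_mxfp4_backend(bf16_env: Optional[bool], mxfp8_env: Optional[bool],
                          is_sm100: bool, flashinfer_available: bool) -> str:
    """Mirror the helper precedence; None means the env var is unset."""
    use_mxfp8 = bool(mxfp8_env)
    if bf16_env is not None:
        use_bf16 = bf16_env          # explicit setting always wins
    else:
        # default-on for Blackwell when FlashInfer is present and MXFP8
        # was not explicitly configured
        use_bf16 = is_sm100 and flashinfer_available and mxfp8_env is None
    return "flashinfer" if (use_mxfp8 or use_bf16) else "marlin-or-triton"

# Unset env vars on an SM100 box with FlashInfer -> BF16 backend by default.
assert resolve_mxfp4_backend(None, None, True, True) == "flashinfer"
# Explicitly disabling BF16 (MXFP8 still unset) -> non-FlashInfer path.
assert resolve_mxfp4_backend(False, None, True, True) == "marlin-or-triton"
```
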
@@ -87,12 +114,18 @@ def __init__(self, moe: FusedMoEConfig): |
87 | 114 | self.moe = moe |
88 | 115 | self.use_marlin = self._should_use_marlin() |
89 | 116 |
|
| 117 | + if current_platform.is_device_capability(100) and not has_flashinfer(): |
| 118 | + logger.warning_once( |
| 119 | + "MXFP4 MoE is enabled on Blackwell but FlashInfer " |
| 120 | + "is not available. This may result in degraded performance. " |
| 121 | + "Please `pip install vllm[flashinfer]` for best results.") |
| 122 | + |
90 | 123 | def _should_use_marlin(self): |
91 | 124 | if envs.VLLM_MXFP4_USE_MARLIN is not None: |
92 | 125 | return envs.VLLM_MXFP4_USE_MARLIN |
93 | 126 | if current_platform.is_cuda() and \ |
94 | | - not current_platform.has_device_capability(100): |
95 | | - if not current_platform.is_device_capability(90): |
| 127 | + not current_platform.is_device_capability(100): |
| 128 | + if not current_platform.has_device_capability(90): |
96 | 129 | # marlin kernel has better performance on ampere |
97 | 130 | return True |
98 | 131 | if not has_triton_kernels(): |
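
The swap between `is_device_capability` and `has_device_capability` changes which devices take the Marlin path. A minimal sketch of the two predicates' semantics as I understand them from `vllm.platforms` (exact match vs. at-least); these are stand-in functions, not the vLLM implementations:

```python
def is_device_capability(current: int, target: int) -> bool:
    return current == target      # exact match, e.g. exactly SM100

def has_device_capability(current: int, target: int) -> bool:
    return current >= target      # at-least check, e.g. SM90 or newer

# After the swap, the outer branch covers every CUDA device that is not
# exactly SM100, while the inner check selects Marlin only below SM90
# (Ampere and older).
assert not is_device_capability(120, 100)     # SM120 now enters the branch
assert not has_device_capability(80, 90)      # Ampere still picks Marlin
```
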
@@ -138,8 +171,7 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, |
138 | 171 | layer.hidden_size = hidden_size |
139 | 172 | layer.intermediate_size_per_partition = \ |
140 | 173 | intermediate_size_per_partition_after_pad |
141 | | - elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 |
142 | | - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): |
| 174 | + elif should_use_flashinfer_mxfp4(): |
143 | 175 | # pad the intermediate size to be a multiple of 2 * mxfp4_block |
144 | 176 | # to hold non-uniform sharded tensors as well as swizzling |
145 | 177 | # other padding to increase performance |
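
For reference, a sketch of the padding arithmetic the comment above describes. `round_up` here mirrors `vllm.utils.round_up`, and `mxfp4_block = 32` is an assumption based on the 32-element MXFP4 scaling block:

```python
def round_up(x: int, align: int) -> int:
    return (x + align - 1) // align * align

mxfp4_block = 32
intermediate_size_per_partition = 1440            # example post-sharding value
padded = round_up(intermediate_size_per_partition, 2 * mxfp4_block)
assert padded == 1472                             # next multiple of 64
```
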
@@ -230,8 +262,8 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, |
230 | 262 | def process_weights_after_loading(self, layer): |
231 | 263 | if self.use_marlin: |
232 | 264 | prepare_moe_fp4_layer_for_marlin(layer) |
233 | | - elif (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 |
234 | | - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): |
| 265 | + elif should_use_flashinfer_mxfp4(): |
| 266 | + from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a |
235 | 267 | layer.gemm1_alpha = Parameter(torch.tensor( |
236 | 268 | [1.702] * self.num_experts, dtype=torch.float32).cuda(), |
237 | 269 | requires_grad=False) |
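
Moving the flashinfer imports from module scope into the methods that need them keeps the module importable (and the Marlin/Triton paths usable) on machines without FlashInfer installed. A minimal sketch of the deferred-import pattern, with the hypothetical `optional_dep` standing in for flashinfer:

```python
def run(use_optional_backend: bool) -> str:
    if use_optional_backend:
        # Resolved lazily, only when this branch is actually taken.
        from optional_dep import fast_kernel  # hypothetical stand-in
        return fast_kernel()
    return "fallback path, no optional dependency needed"

print(run(False))  # works even if optional_dep is not installed
```
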
@@ -478,11 +510,11 @@ def apply( |
478 | 510 | logical_replica_count), ( |
479 | 511 | "MXFP4 is not supported with this configuration.") |
480 | 512 |
|
481 | | - if (envs.VLLM_USE_FLASHINFER_MOE_MXFP4_MXFP8 |
482 | | - or envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16): |
| 513 | + if should_use_flashinfer_mxfp4(): |
| 514 | + from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe |
483 | 515 | assert not self.moe.use_ep, ( |
484 | 516 | "EP is not supported for flashinfer mxfp4 moe backend yet.") |
485 | | - if envs.VLLM_USE_FLASHINFER_MOE_MXFP4_BF16: |
| 517 | + if _should_use_flashinfer_mxfp4_bf16(): |
486 | 518 | assert x.dtype == torch.bfloat16 |
487 | 519 | x_quant = x |
488 | 520 | x_scale = None |
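
The BF16 and MXFP8 activation paths differ only in whether the input is quantized before the fused-MoE kernel. A hedged sketch of that dispatch; the MXFP8 branch is elided rather than guessing flashinfer's `mxfp8_quantize` signature:

```python
import torch

def prepare_activations(x: torch.Tensor, use_bf16_path: bool):
    if use_bf16_path:
        # BF16 path: activations are passed through unquantized, no scale.
        assert x.dtype == torch.bfloat16
        return x, None
    # MXFP8 path: activations would be quantized to MXFP8 with per-block
    # scales (e.g. via flashinfer's mxfp8_quantize), trading accuracy for speed.
    raise NotImplementedError("mxfp8 quantization elided in this sketch")

x = torch.randn(4, 8).to(torch.bfloat16)
x_quant, x_scale = prepare_activations(x, use_bf16_path=True)
```
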
|