import argparse
import time
from datetime import datetime
+from itertools import product
from typing import Any, Dict, List, Tuple, TypedDict

import ray

from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.platforms import current_platform
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, is_navi
+
+FP8_DTYPE = (torch.float8_e4m3fnuz if current_platform.is_rocm()
+             and not is_navi() else torch.float8_e4m3fn)
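# FP8_DTYPE picks the FP8 encoding per platform: CDNA accelerators such as
# MI300 use the "fnuz" variant (the gfx9 hardware FP8 format), while Navi
# GPUs and CUDA devices fall back to the standard OCP float8_e4m3fn.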


class BenchmarkConfig(TypedDict):
@@ -80,8 +84,8 @@ def benchmark_config(
        a1_scale = torch.randn(1, dtype=torch.float32)
        a2_scale = torch.randn(1, dtype=torch.float32)

-        w1 = w1.to(torch.float8_e4m3fn)
-        w2 = w2.to(torch.float8_e4m3fn)
+        w1 = w1.to(FP8_DTYPE)
+        w2 = w2.to(FP8_DTYPE)

    input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)

@@ -141,35 +145,183 @@ def run():
    return avg


-def get_configs_compute_bound() -> List[Dict[str, int]]:
-    # Reduced search space for faster tuning.
-    # TODO(woosuk): Increase the search space and use a performance model to
-    # prune the search space.
+def get_rocm_tuning_space(use_fp16):
+    block_mn_range = [16, 32, 64, 128, 256]
+    block_k_range = [16, 32, 64, 128, 256]
+    if not use_fp16:
+        block_k_range.remove(16)  # BLOCK_K=16 not supported for fp8
+    num_warps_range = [1, 2, 4, 8]
+    group_m_range = [1, 4, 8, 16, 32]
+    num_stage_range = [2]
+    waves_per_eu_range = [0]
+    matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
+    kpack_range = [1, 2] if use_fp16 else []
+
+    param_ranges = {
+        "BLOCK_SIZE_M": block_mn_range,
+        "BLOCK_SIZE_N": block_mn_range,
+        "BLOCK_SIZE_K": block_k_range,
+        "GROUP_SIZE_M": group_m_range,
+        "num_warps": num_warps_range,
+        "num_stages": num_stage_range,
+        "waves_per_eu": waves_per_eu_range,
+    }
+    if use_fp16:
+        param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
+        param_ranges["kpack"] = kpack_range
+
+    return param_ranges
+
+
+def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
    configs: List[BenchmarkConfig] = []
-    for num_stages in [2, 3, 4, 5]:
-        for block_m in [16, 32, 64, 128, 256]:
-            for block_k in [64, 128, 256]:
-                for block_n in [32, 64, 128, 256]:
-                    for num_warps in [4, 8]:
-                        for group_size in [1, 16, 32, 64]:
-                            configs.append({
-                                "BLOCK_SIZE_M": block_m,
-                                "BLOCK_SIZE_N": block_n,
-                                "BLOCK_SIZE_K": block_k,
-                                "GROUP_SIZE_M": group_size,
-                                "num_warps": num_warps,
-                                "num_stages": num_stages,
-                            })
+
+    if current_platform.is_rocm():
+        param_ranges = get_rocm_tuning_space(use_fp16)
+    else:
+        # Reduced search space for faster tuning.
+        # TODO(woosuk): Increase the search space and use a performance model
+        # to prune the search space.
+        block_m_range = [16, 32, 64, 128, 256]
+        block_n_range = [32, 64, 128, 256]
+        block_k_range = [64, 128, 256]
+        num_warps_range = [4, 8]
+        group_m_range = [1, 16, 32, 64]
+        num_stage_range = [2, 3, 4, 5]
+
+        param_ranges = {
+            "BLOCK_SIZE_M": block_m_range,
+            "BLOCK_SIZE_N": block_n_range,
+            "BLOCK_SIZE_K": block_k_range,
+            "GROUP_SIZE_M": group_m_range,
+            "num_warps": num_warps_range,
+            "num_stages": num_stage_range,
+        }
+
+    keys, values = zip(*param_ranges.items())
+    for config_values in product(*values):
+        config = dict(zip(keys, config_values))
+        configs.append(config)
    return configs
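# A rough sketch of what the itertools.product expansion above yields, derived
# from the ranges listed in this file: each config is a flat dict with one
# value per tuning knob.  On ROCm with fp16 the first generated entry is
#     {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 16, "BLOCK_SIZE_K": 16,
#      "GROUP_SIZE_M": 1, "num_warps": 1, "num_stages": 2,
#      "waves_per_eu": 0, "matrix_instr_nonkdim": 16, "kpack": 1}
# and the full grid has 5*5*5*5*4*1*1*2*2 = 10000 entries (2000 for fp8,
# which drops BLOCK_K=16 and the two MFMA knobs); the CUDA grid stays at
# 5*4*3*4*2*4 = 1920 entries.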


+def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
+                            search_space, is_fp16):
+    N1, K1 = shard_intermediate_size, hidden_size
+    N2, K2 = hidden_size, shard_intermediate_size // 2
+    pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
+                                        is_fp16)
+    pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
+                                        is_fp16)
+    search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
+    return search_space
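# The two pruning passes above mirror the two GEMMs inside the fused MoE
# kernel: w1 (the gate/up projection, N = shard_intermediate_size,
# K = hidden_size) and w2 (the down projection, N = hidden_size,
# K = shard_intermediate_size // 2); any config that survives either shape
# is kept in the merged search space.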
+
+
+# The following code is inspired by the ROCm/Triton GEMM tuning script:
+# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
+def prune_rocm_configs(M, N, K, configs, is_fp16=True):
+    pruned_configs = []
+    elemBytes_a = 2 if is_fp16 else 1
+    elemBytes_b = 2 if is_fp16 else 1
+
+    mfma = 16 if M < 32 or N < 32 else 32
+
+    # TODO (zhanglx): figure out the boundary between large and small gemms
+    large_gemm = False
+    if M >= 2048 and N >= 2048:
+        large_gemm = True
+
+    for config in configs:
+        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
+        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
+        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
+        num_warps = config.get("num_warps")
+
+        if is_fp16:
+            matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
+            if matrix_instr_nonkdim > mfma:
+                continue
+        if mfma == 4 and BLOCK_SIZE_K < 64:
+            continue
+        # Some layouts do not work properly when the number of elements
+        # per thread is less than 1.
+        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
+            continue
+        SPLIT_K = config.get("SPLIT_K", 1)
+        GROUP_M = config.get("GROUP_SIZE_M")
+        if is_fp16:
+            if (matrix_instr_nonkdim > BLOCK_SIZE_M
+                    or matrix_instr_nonkdim > BLOCK_SIZE_N):
+                continue
+            if (matrix_instr_nonkdim >= M
+                    and matrix_instr_nonkdim != BLOCK_SIZE_M):
+                continue
+            if (matrix_instr_nonkdim >= N
+                    and matrix_instr_nonkdim != BLOCK_SIZE_N):
+                continue
+        # Skip BLOCK_SIZEs that are too large compared to M/N,
+        # unless BLOCK_SIZE is already small enough.
+        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
+            continue
+        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
+            continue
+        # Skip large SPLIT_K when it is not necessary.
+        if SPLIT_K != 1 and not need_split_k(M, N, K):
+            continue
+        # Skip SPLIT_K values that lead to EVEN_K = false.
+        leap = SPLIT_K * BLOCK_SIZE_K
+        modv = K % leap
+        if modv != 0:
+            continue
+        # Skip large GROUP_M.
+        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
+            continue
+        # Out of shared memory resource.
+        # TODO (zhanglx): This does not consider the LDS usage in the epilogue.
+        LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
+               BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
+        if LDS > 65536:
+            continue
+        # Skip small block sizes and num_warps for large gemms:
+        # for fp16 and fp8 we only want to use BLOCK_SIZE >= 64.
+        if large_gemm:
+            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
+                continue
+            if BLOCK_SIZE_K < 64:
+                continue
+            if num_warps < 4:
+                continue
+
+        pruned_configs.append(config)
+
+    return pruned_configs
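# Worked example of the LDS check above: a candidate with BLOCK_SIZE_M = 256,
# BLOCK_SIZE_N = 256 and BLOCK_SIZE_K = 256 in fp16 needs
# 256*256*2 + 256*256*2 = 262144 bytes of LDS, well over the 64 KiB limit, so
# it is pruned; the same tile in fp8 needs 131072 bytes and is pruned too.
# Note that SPLIT_K is not part of the MoE tuning space here, so
# config.get("SPLIT_K", 1) always returns 1, the need_split_k() branch never
# fires, and the EVEN_K check reduces to requiring K % BLOCK_SIZE_K == 0.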
+
+
+def need_split_k(SIZE_M, SIZE_N, SIZE_K):
+    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
+
+
+def merge_unique_dicts(list1, list2):
+    result = []
+    combined_list = list1.copy()
+    combined_list.extend(list2)
+    for dictionary in combined_list:
+        if dictionary not in result:
+            result.append(dictionary)
+    return result
+
+
@ray.remote(num_gpus=1)
class BenchmarkWorker:

    def __init__(self, seed: int) -> None:
        torch.set_default_device("cuda")
        current_platform.seed_everything(seed)
        self.seed = seed
+        # Get the device ID to allocate tensors and kernels
+        # on the respective GPU. This is required for Ray to work
+        # correctly with multi-GPU tuning on the ROCm platform.
+        self.device_id = int(ray.get_gpu_ids()[0])

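    # ray.get_gpu_ids() returns the GPU IDs Ray assigned to this actor;
    # tune() below wraps its benchmarking loop in
    # torch.cuda.device(self.device_id) so every candidate kernel runs on
    # that GPU.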
    def benchmark(
        self,
@@ -217,25 +369,33 @@ def tune(
    ) -> Dict[str, int]:
        best_config = None
        best_time = float("inf")
-        for config in tqdm(search_space):
-            try:
-                kernel_time = benchmark_config(config,
-                                               num_tokens,
-                                               num_experts,
-                                               shard_intermediate_size,
-                                               hidden_size,
-                                               topk,
-                                               dtype,
-                                               use_fp8_w8a8,
-                                               use_int8_w8a16,
-                                               num_iters=10)
-            except triton.runtime.autotuner.OutOfResources:
-                # Some configurations may be invalid and fail to compile.
-                continue
-
-            if kernel_time < best_time:
-                best_time = kernel_time
-                best_config = config
+        if current_platform.is_rocm():
+            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
+            search_space = prune_rocm_search_space(num_tokens,
+                                                   shard_intermediate_size,
+                                                   hidden_size, search_space,
+                                                   is_fp16)
+
+        with torch.cuda.device(self.device_id):
+            for config in tqdm(search_space):
+                try:
+                    kernel_time = benchmark_config(config,
+                                                   num_tokens,
+                                                   num_experts,
+                                                   shard_intermediate_size,
+                                                   hidden_size,
+                                                   topk,
+                                                   dtype,
+                                                   use_fp8_w8a8,
+                                                   use_int8_w8a16,
+                                                   num_iters=20)
+                except triton.runtime.autotuner.OutOfResources:
+                    # Some configurations may be invalid and fail to compile.
+                    continue
+
+                if kernel_time < best_time:
+                    best_time = kernel_time
+                    best_config = config
        now = datetime.now()
        print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
        assert best_config is not None
@@ -244,12 +404,27 @@ def tune(

def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    return {
-        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
-        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
-        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
-        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
-        "num_warps": config["num_warps"],
-        "num_stages": config["num_stages"],
+        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
+        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
+        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
+        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
+        "num_warps": config["num_warps"],
+        "num_stages": config["num_stages"],
+        **({"waves_per_eu": config["waves_per_eu"]}
+           if "waves_per_eu" in config else {}),
+        **({"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]}
+           if "matrix_instr_nonkdim" in config else {}),
+        **({"kpack": config["kpack"]} if "kpack" in config else {}),
    }


@@ -294,7 +469,7 @@ def main(args: argparse.Namespace):
        shard_intermediate_size = 2 * intermediate_size // args.tp_size

    hidden_size = config.hidden_size
-    dtype = config.torch_dtype
+    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
    use_fp8_w8a8 = args.dtype == "fp8_w8a8"
    use_int8_w8a16 = args.dtype == "int8_w8a16"

@@ -322,7 +497,8 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
        return ray.get(outputs)

    if args.tune:
-        search_space = get_configs_compute_bound()
+        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
+        search_space = get_configs_compute_bound(is_fp16)
        print(f"Start tuning over {len(search_space)} configurations...")

        start = time.time()
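# Example invocation (hypothetical model choice, assuming the script's
# existing CLI flags such as --model, --tp-size, --dtype and --tune):
#
#   python benchmark_moe.py --model mistralai/Mixtral-8x7B-Instruct-v0.1 \
#       --tp-size 2 --dtype fp8_w8a8 --tune
#
# On ROCm this now tunes over the pruned ROCm search space and should save
# the best configs in the usual fused-MoE JSON config format.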