
Commit 002038a

divakar-amd authored and Isotr0py committed
[ROCm][MoE] moe tuning support for rocm (vllm-project#12049)
Signed-off-by: Divakar Verma <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
1 parent 5c34bbd commit 002038a

File tree

1 file changed: +224 -48 lines changed

benchmarks/kernels/benchmark_moe.py

Lines changed: 224 additions & 48 deletions
@@ -1,6 +1,7 @@
 import argparse
 import time
 from datetime import datetime
+from itertools import product
 from typing import Any, Dict, List, Tuple, TypedDict
 
 import ray
@@ -11,7 +12,10 @@
 
 from vllm.model_executor.layers.fused_moe.fused_moe import *
 from vllm.platforms import current_platform
-from vllm.utils import FlexibleArgumentParser
+from vllm.utils import FlexibleArgumentParser, is_navi
+
+FP8_DTYPE = torch.float8_e4m3fnuz if current_platform.is_rocm(
+) and not is_navi() else torch.float8_e4m3fn
 
 
 class BenchmarkConfig(TypedDict):
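
Note: the FP8_DTYPE selection above resolves to the e4m3fnuz variant on ROCm (non-Navi) devices and to the standard e4m3fn elsewhere. Below is a minimal sketch of the same decision, with the platform checks replaced by plain booleans purely for illustration (the real script uses current_platform.is_rocm() and is_navi()); it assumes a PyTorch build that exposes both float8 dtypes.

import torch

def pick_fp8_dtype(is_rocm: bool, is_navi: bool) -> torch.dtype:
    # CDNA GPUs (e.g. MI300) use the "fnuz" FP8 encoding; Navi and
    # non-ROCm platforms use the OCP e4m3fn encoding.
    return torch.float8_e4m3fnuz if is_rocm and not is_navi else torch.float8_e4m3fn

print(pick_fp8_dtype(is_rocm=True, is_navi=False))   # torch.float8_e4m3fnuz
print(pick_fp8_dtype(is_rocm=False, is_navi=False))  # torch.float8_e4m3fn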
@@ -80,8 +84,8 @@ def benchmark_config(
         a1_scale = torch.randn(1, dtype=torch.float32)
         a2_scale = torch.randn(1, dtype=torch.float32)
 
-        w1 = w1.to(torch.float8_e4m3fn)
-        w2 = w2.to(torch.float8_e4m3fn)
+        w1 = w1.to(FP8_DTYPE)
+        w2 = w2.to(FP8_DTYPE)
 
     input_gating = torch.empty(num_tokens, num_experts, dtype=torch.float32)
 
@@ -141,35 +145,183 @@ def run():
     return avg
 
 
-def get_configs_compute_bound() -> List[Dict[str, int]]:
-    # Reduced search space for faster tuning.
-    # TODO(woosuk): Increase the search space and use a performance model to
-    # prune the search space.
+def get_rocm_tuning_space(use_fp16):
+    block_mn_range = [16, 32, 64, 128, 256]
+    block_k_range = [16, 32, 64, 128, 256]
+    if not use_fp16:
+        block_k_range.remove(16)  # BLOCK_K=16 not supported for fp8
+    num_warps_range = [1, 2, 4, 8]
+    group_m_range = [1, 4, 8, 16, 32]
+    num_stage_range = [2]
+    waves_per_eu_range = [0]
+    matrix_instr_nonkdim_range = [16, 32] if use_fp16 else []
+    kpack_range = [1, 2] if use_fp16 else []
+
+    param_ranges = {
+        "BLOCK_SIZE_M": block_mn_range,
+        "BLOCK_SIZE_N": block_mn_range,
+        "BLOCK_SIZE_K": block_k_range,
+        "GROUP_SIZE_M": group_m_range,
+        "num_warps": num_warps_range,
+        "num_stages": num_stage_range,
+        "waves_per_eu": waves_per_eu_range,
+    }
+    if use_fp16:
+        param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
+        param_ranges["kpack"] = kpack_range
+
+    return param_ranges
+
+
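
Note: the fp16/fp8 switch noticeably changes the size of the ROCm search space (BLOCK_SIZE_K drops 16, and matrix_instr_nonkdim/kpack are only swept for fp16). The sketch below simply recomputes the raw config counts from the ranges listed above; the totals are derived here, not stated in the patch.

from math import prod

def rocm_space_size(use_fp16: bool) -> int:
    # Mirror the ranges in get_rocm_tuning_space().
    block_mn = [16, 32, 64, 128, 256]
    block_k = [16, 32, 64, 128, 256] if use_fp16 else [32, 64, 128, 256]
    num_warps = [1, 2, 4, 8]
    group_m = [1, 4, 8, 16, 32]
    num_stages = [2]
    waves_per_eu = [0]
    ranges = [block_mn, block_mn, block_k, group_m, num_warps, num_stages,
              waves_per_eu]
    if use_fp16:
        ranges += [[16, 32], [1, 2]]  # matrix_instr_nonkdim, kpack
    return prod(len(r) for r in ranges)

print(rocm_space_size(use_fp16=True))   # 10000 raw configs before pruning
print(rocm_space_size(use_fp16=False))  # 2000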
+def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
     configs: List[BenchmarkConfig] = []
-    for num_stages in [2, 3, 4, 5]:
-        for block_m in [16, 32, 64, 128, 256]:
-            for block_k in [64, 128, 256]:
-                for block_n in [32, 64, 128, 256]:
-                    for num_warps in [4, 8]:
-                        for group_size in [1, 16, 32, 64]:
-                            configs.append({
-                                "BLOCK_SIZE_M": block_m,
-                                "BLOCK_SIZE_N": block_n,
-                                "BLOCK_SIZE_K": block_k,
-                                "GROUP_SIZE_M": group_size,
-                                "num_warps": num_warps,
-                                "num_stages": num_stages,
-                            })
+
+    if current_platform.is_rocm():
+        param_ranges = get_rocm_tuning_space(use_fp16)
+    else:
+        # Reduced search space for faster tuning.
+        # TODO(woosuk): Increase the search space and use a performance model to
+        # prune the search space.
+        block_m_range = [16, 32, 64, 128, 256]
+        block_n_range = [32, 64, 128, 256]
+        block_k_range = [64, 128, 256]
+        num_warps_range = [4, 8]
+        group_m_range = [1, 16, 32, 64]
+        num_stage_range = [2, 3, 4, 5]
+
+        param_ranges = {
+            "BLOCK_SIZE_M": block_m_range,
+            "BLOCK_SIZE_N": block_n_range,
+            "BLOCK_SIZE_K": block_k_range,
+            "GROUP_SIZE_M": group_m_range,
+            "num_warps": num_warps_range,
+            "num_stages": num_stage_range,
+        }
+
+    keys, values = zip(*param_ranges.items())
+    for config_values in product(*values):
+        config = dict(zip(keys, config_values))
+        configs.append(config)
     return configs
 
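Note: the zip/product idiom above replaces the old hand-written nested loops; each element of the Cartesian product is zipped back with the keys to form one config dict. A tiny standalone illustration with toy ranges (not the real ones):

from itertools import product

param_ranges = {
    "BLOCK_SIZE_M": [16, 32],
    "BLOCK_SIZE_N": [32],
    "num_warps": [4, 8],
}
keys, values = zip(*param_ranges.items())
configs = [dict(zip(keys, combo)) for combo in product(*values)]
print(len(configs))  # 4
print(configs[0])    # {'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, 'num_warps': 4}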

+def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
+                            search_space, is_fp16):
+    N1, K1 = shard_intermediate_size, hidden_size
+    N2, K2 = hidden_size, shard_intermediate_size // 2
+    pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
+                                        is_fp16)
+    pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
+                                        is_fp16)
+    search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
+    return search_space
+
+
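
Note: prune_rocm_search_space prunes against both GEMMs of the fused MoE layer, roughly corresponding to the fused gate/up projection (N = shard_intermediate_size, K = hidden_size) and the down projection (N = hidden_size, K = shard_intermediate_size // 2), and keeps the union of survivors. A worked example of the two (M, N, K) shapes with Mixtral-8x7B-like dimensions (hidden_size 4096, intermediate_size 14336, TP = 1; illustrative values, not taken from the patch):

hidden_size = 4096
intermediate_size = 14336
tp_size = 1
shard_intermediate_size = 2 * intermediate_size // tp_size  # 28672
num_tokens = 1024

# GEMM 1: fused gate/up projection
gemm1 = (num_tokens * 2, shard_intermediate_size, hidden_size)
# GEMM 2: down projection
gemm2 = (num_tokens * 2, hidden_size, shard_intermediate_size // 2)
print(gemm1)  # (2048, 28672, 4096)
print(gemm2)  # (2048, 4096, 14336)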
+# The following code is inspired by ROCm/Triton GEMM tuning script:
+# https://github.com/ROCm/triton/blob/triton-mlir/scripts/amd/gemm/tune_gemm.py#L89
+def prune_rocm_configs(M, N, K, configs, is_fp16=True):
+    pruned_configs = []
+    elemBytes_a = 2 if is_fp16 else 1
+    elemBytes_b = 2 if is_fp16 else 1
+
+    mfma = 16 if M < 32 or N < 32 else 32
+
+    # TODO (zhanglx): figure out the boundary between large and small gemms
+    large_gemm = False
+    if M >= 2048 and N >= 2048:
+        large_gemm = True
+
+    for config in configs:
+        BLOCK_SIZE_M = config.get("BLOCK_SIZE_M")
+        BLOCK_SIZE_N = config.get("BLOCK_SIZE_N")
+        BLOCK_SIZE_K = config.get("BLOCK_SIZE_K")
+        num_warps = config.get("num_warps")
+
+        if is_fp16:
+            matrix_instr_nonkdim = config.get("matrix_instr_nonkdim")
+            if matrix_instr_nonkdim > mfma:
+                continue
+        if mfma == 4 and BLOCK_SIZE_K < 64:
+            continue
+        # some layouts could not work properly in case
+        # number elements per thread is less 1
+        if BLOCK_SIZE_M * BLOCK_SIZE_N < 64:
+            continue
+        SPLIT_K = config.get("SPLIT_K", 1)
+        GROUP_M = config.get("GROUP_SIZE_M")
+        if is_fp16:
+            if (matrix_instr_nonkdim > BLOCK_SIZE_M
+                    or matrix_instr_nonkdim > BLOCK_SIZE_N):
+                continue
+            if (matrix_instr_nonkdim >= M
+                    and matrix_instr_nonkdim != BLOCK_SIZE_M):
+                continue
+            if (matrix_instr_nonkdim >= N
+                    and matrix_instr_nonkdim != BLOCK_SIZE_N):
+                continue
+        # Skip BLOCK_SIZE that is too large compare to M/N
+        # unless BLOCK_SIZE is already small enough
+        if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16:
+            continue
+        if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16:
+            continue
+        # skip large split_k when not necessary
+        if SPLIT_K != 1 and not need_split_k(M, N, K):
+            continue
+        # skip split_k that leads to EVEN_K = false
+        leap = SPLIT_K * BLOCK_SIZE_K
+        modv = K % leap
+        if modv != 0:
+            continue
+        # skip large GROUP_M
+        if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1:
+            continue
+        # out of shared memory resource
+        # TODO (zhanglx): This does not consider the LDS usage in the epilogue
+        LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a +
+               BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b)
+        if LDS > 65536:
+            continue
+        # Skip small block sizes and num_warps for large gemm
+        # For fp16 and f8, we want to only use BLOCK_SIZE >= 64
+        if large_gemm:
+            if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64:
+                continue
+            if BLOCK_SIZE_K < 64:
+                continue
+            if num_warps < 4:
+                continue
+
+        pruned_configs.append(config)
+
+    return pruned_configs
+
+
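
Note: the shared-memory check above bounds the per-workgroup LDS footprint of the two input tiles at 64 KiB. A small arithmetic sketch of that constraint (2 bytes per element for fp16, 1 for fp8):

def lds_bytes(block_m, block_n, block_k, is_fp16=True):
    elem_bytes = 2 if is_fp16 else 1
    # One BLOCK_M x BLOCK_K tile of A plus one BLOCK_K x BLOCK_N tile of B.
    return block_k * block_m * elem_bytes + block_k * block_n * elem_bytes

print(lds_bytes(128, 128, 64))           # 32768  -> kept (<= 65536)
print(lds_bytes(256, 256, 128))          # 131072 -> pruned for fp16 (> 65536)
print(lds_bytes(256, 256, 128, False))   # 65536  -> kept for fp8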
+def need_split_k(SIZE_M, SIZE_N, SIZE_K):
+    return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024
+
+
+def merge_unique_dicts(list1, list2):
+    result = []
+    combined_list = list1.copy()
+    combined_list.extend(list2)
+    for dictionary in combined_list:
+        if dictionary not in result:
+            result.append(dictionary)
+    return result
+
+
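
Note: merge_unique_dicts is a plain order-preserving union of the two pruned lists, so a config that survives both prunes appears only once. A small usage example (using the merge_unique_dicts defined above):

a = [{"BLOCK_SIZE_M": 64}, {"BLOCK_SIZE_M": 128}]
b = [{"BLOCK_SIZE_M": 128}, {"BLOCK_SIZE_M": 256}]

merged = merge_unique_dicts(a, b)
print(merged)
# [{'BLOCK_SIZE_M': 64}, {'BLOCK_SIZE_M': 128}, {'BLOCK_SIZE_M': 256}]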
 @ray.remote(num_gpus=1)
 class BenchmarkWorker:
 
     def __init__(self, seed: int) -> None:
         torch.set_default_device("cuda")
         current_platform.seed_everything(seed)
         self.seed = seed
+        # Get the device ID to allocate tensors and kernels
+        # on the respective GPU. This is required for Ray to work
+        # correctly with multi-GPU tuning on the ROCm platform.
+        self.device_id = int(ray.get_gpu_ids()[0])
 
     def benchmark(
         self,
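
Note: the stored device_id lets each Ray actor pin its work to the GPU Ray assigned it (it is used with torch.cuda.device in tune() below). The following is a minimal standalone sketch of the same pattern, assuming Ray is installed and at least one GPU is visible; PinnedWorker and report are illustrative names, not part of the patch.

import ray
import torch

@ray.remote(num_gpus=1)
class PinnedWorker:

    def __init__(self) -> None:
        # Inside a Ray actor, ray.get_gpu_ids() reports which GPU(s) the
        # scheduler assigned to this actor.
        self.device_id = int(ray.get_gpu_ids()[0])

    def report(self):
        # Pairing the Ray-assigned ID with torch's current device shows
        # whether the actor would otherwise run on the default device.
        return self.device_id, torch.cuda.current_device()

ray.init()
print(ray.get(PinnedWorker.remote().report.remote()))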
@@ -217,25 +369,33 @@ def tune(
     ) -> Dict[str, int]:
         best_config = None
         best_time = float("inf")
-        for config in tqdm(search_space):
-            try:
-                kernel_time = benchmark_config(config,
-                                               num_tokens,
-                                               num_experts,
-                                               shard_intermediate_size,
-                                               hidden_size,
-                                               topk,
-                                               dtype,
-                                               use_fp8_w8a8,
-                                               use_int8_w8a16,
-                                               num_iters=10)
-            except triton.runtime.autotuner.OutOfResources:
-                # Some configurations may be invalid and fail to compile.
-                continue
-
-            if kernel_time < best_time:
-                best_time = kernel_time
-                best_config = config
+        if current_platform.is_rocm():
+            is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
+            search_space = prune_rocm_search_space(num_tokens,
+                                                   shard_intermediate_size,
+                                                   hidden_size, search_space,
+                                                   is_fp16)
+
+        with torch.cuda.device(self.device_id):
+            for config in tqdm(search_space):
+                try:
+                    kernel_time = benchmark_config(config,
+                                                   num_tokens,
+                                                   num_experts,
+                                                   shard_intermediate_size,
+                                                   hidden_size,
+                                                   topk,
+                                                   dtype,
+                                                   use_fp8_w8a8,
+                                                   use_int8_w8a16,
+                                                   num_iters=20)
+                except triton.runtime.autotuner.OutOfResources:
+                    # Some configurations may be invalid and fail to compile.
+                    continue
+
+                if kernel_time < best_time:
+                    best_time = kernel_time
+                    best_config = config
         now = datetime.now()
         print(f"{now.ctime()}] Completed tuning for batch_size={num_tokens}")
         assert best_config is not None
@@ -244,12 +404,27 @@ def tune(
 
 def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
     return {
-        "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
-        "BLOCK_SIZE_N": config["BLOCK_SIZE_N"],
-        "BLOCK_SIZE_K": config["BLOCK_SIZE_K"],
-        "GROUP_SIZE_M": config["GROUP_SIZE_M"],
-        "num_warps": config["num_warps"],
-        "num_stages": config["num_stages"],
+        "BLOCK_SIZE_M":
+        config["BLOCK_SIZE_M"],
+        "BLOCK_SIZE_N":
+        config["BLOCK_SIZE_N"],
+        "BLOCK_SIZE_K":
+        config["BLOCK_SIZE_K"],
+        "GROUP_SIZE_M":
+        config["GROUP_SIZE_M"],
+        "num_warps":
+        config["num_warps"],
+        "num_stages":
+        config["num_stages"],
+        **({
+            "waves_per_eu": config["waves_per_eu"]
+        } if "waves_per_eu" in config else {}),
+        **({
+            "matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
+        } if "matrix_instr_nonkdim" in config else {}),
+        **({
+            "kpack": config["kpack"]
+        } if "kpack" in config else {}),
     }
 
 
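
Note: sort_config now carries the ROCm-only keys through only when they are present, using conditional dict unpacking. The pattern in isolation, with a toy config:

config = {"BLOCK_SIZE_M": 64, "num_warps": 8, "kpack": 2}

sorted_config = {
    "BLOCK_SIZE_M": config["BLOCK_SIZE_M"],
    "num_warps": config["num_warps"],
    # Only include ROCm-specific keys if the config actually has them.
    **({"kpack": config["kpack"]} if "kpack" in config else {}),
    **({"waves_per_eu": config["waves_per_eu"]} if "waves_per_eu" in config else {}),
}
print(sorted_config)  # {'BLOCK_SIZE_M': 64, 'num_warps': 8, 'kpack': 2}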
@@ -294,7 +469,7 @@ def main(args: argparse.Namespace):
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
 
     hidden_size = config.hidden_size
-    dtype = config.torch_dtype
+    dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
     use_fp8_w8a8 = args.dtype == "fp8_w8a8"
     use_int8_w8a16 = args.dtype == "int8_w8a16"
 
@@ -322,7 +497,8 @@ def _distribute(method: str, inputs: List[Any]) -> List[Any]:
         return ray.get(outputs)
 
     if args.tune:
-        search_space = get_configs_compute_bound()
+        is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
+        search_space = get_configs_compute_bound(is_fp16)
         print(f"Start tuning over {len(search_space)} configurations...")
 
         start = time.time()
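
Note: the tuning entry point now derives the fp16-versus-quantized decision from the same args.dtype flags used elsewhere in the script and forwards it into get_configs_compute_bound. A small sketch of that derivation; "auto" here stands in for any non-quantized --dtype setting and is an assumption about the script's accepted choices.

# How the tuner decides fp16 vs quantized search space from the CLI dtype:
for dtype_arg in ("auto", "fp8_w8a8", "int8_w8a16"):
    use_fp8_w8a8 = dtype_arg == "fp8_w8a8"
    use_int8_w8a16 = dtype_arg == "int8_w8a16"
    is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
    print(dtype_arg, "-> fp16 space" if is_fp16 else "-> quantized space")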
