
Commit 0fb5c4b

bnellnm authored and choprahetarth committed
[Kernels] Overlap shared experts with combine instead of dispatch (vllm-project#24254)
Signed-off-by: Bill Nell <[email protected]>
1 parent 707d394 commit 0fb5c4b

File tree

4 files changed (+203, -36 lines)

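Before the per-file diffs, the shape of the change in miniature: the combine (finalize) step is now started asynchronously, the shared experts run while the all-to-all is in flight, and a returned receiver callback is invoked to wait for the combined output. A minimal, self-contained sketch of that ordering; FakePrepareFinalize and the list arithmetic below are illustrative stand-ins for the real vLLM objects and tensor ops, not the actual implementation:

from typing import Callable

class FakePrepareFinalize:
    def finalize_async(self, output: list, fused_out: list) -> Callable:
        # Start the combine; return a receiver that performs the wait.
        # Real implementations kick off an async all-to-all here.
        def receiver():
            output[:] = fused_out  # results land only once we wait
        return receiver

def forward(pf: FakePrepareFinalize, shared_experts, a1: list) -> tuple:
    output: list = [None] * len(a1)
    fused_out = [x * 2 for x in a1]      # stand-in for routed expert output
    recv_hook = pf.finalize_async(output, fused_out)
    shared_output = shared_experts(a1)   # overlaps with the combine
    recv_hook()                          # output is valid after this
    return shared_output, output

print(forward(FakePrepareFinalize(), lambda xs: [x + 1 for x in xs], [1, 2, 3]))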

vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py

Lines changed: 45 additions & 5 deletions
@@ -240,15 +240,16 @@ def prepare(
             quant_config)
         return receiver()

-    def finalize(
+    def _finalize(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> None:
+        do_async: bool,
+    ) -> Optional[Callable]:

         assert self.handle is not None

@@ -271,7 +272,46 @@ def finalize(
             topk_weights=None,
             config=self._get_combine_config(),
             previous_event=None,
-            async_finish=False,
+            async_finish=do_async,
             allocate_on_comm_stream=False)
-        # Respect inplace outputs.
-        output.copy_(combined_x, non_blocking=True)
+
+        if do_async:
+
+            def _receiver():
+                event.current_stream_wait()
+                # Respect inplace outputs.
+                output.copy_(combined_x, non_blocking=True)
+
+            return lambda: _receiver()
+        else:
+            # Respect inplace outputs.
+            output.copy_(combined_x, non_blocking=True)
+            return None
+
+    def finalize_async(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> Callable:
+        receiver = self._finalize(output, fused_expert_output, topk_weights,
+                                  topk_ids, apply_router_weight_on_input,
+                                  weight_and_reduce_impl, True)
+        assert receiver is not None
+        return receiver
+
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        self._finalize(output, fused_expert_output, topk_weights, topk_ids,
+                       apply_router_weight_on_input, weight_and_reduce_impl,
+                       False)
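The high-throughput path above splits finalize on a DeepEP event: combine(..., async_finish=True) returns an event, and the receiver waits on it before copying combined_x into the caller's in-place output buffer. A self-contained mock of that event-then-copy contract; MockEvent and async_combine are illustrative stand-ins, with a Python thread replacing DeepEP's stream machinery:

import threading

class MockEvent:
    # Stand-in for DeepEP's event object and its current_stream_wait().
    def __init__(self):
        self._done = threading.Event()
    def record(self):
        self._done.set()
    def current_stream_wait(self):
        self._done.wait()

def async_combine(data: list) -> tuple:
    # Pretend all-to-all combine: results materialize on another thread,
    # and the event is recorded once they are ready.
    event, combined = MockEvent(), []
    def worker():
        combined.extend(x * 2 for x in data)
        event.record()
    threading.Thread(target=worker).start()
    return combined, event

output: list = []
combined_x, event = async_combine([1, 2, 3])
# ... overlap shared-expert work here ...
event.current_stream_wait()   # mirrors the receiver's event wait
output[:] = combined_x        # "respect inplace outputs"
print(output)                 # [2, 4, 6]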

vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py

Lines changed: 47 additions & 8 deletions
@@ -12,8 +12,7 @@
 from vllm.model_executor.layers.fused_moe.utils import (
     moe_kernel_quantize_input, normalize_batched_scales_shape)
 from vllm.v1.worker.ubatching import (dbo_current_ubatch_id, dbo_enabled,
-                                      dbo_maybe_run_recv_hook,
-                                      dbo_register_recv_hook, dbo_yield)
+                                      dbo_maybe_run_recv_hook)

 # DeepEP kernels quantize dispatch inputs in 128 element chunks.
 DEEPEP_QUANT_BLOCK_SIZE = 128
@@ -198,21 +197,22 @@ def prepare(
             hook()
         return receiver()

-    def finalize(
+    def _finalize(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> None:
+        do_async: bool,
+    ) -> Optional[Callable]:
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")

         a2a_idx = dbo_current_ubatch_id()
-        do_recv_hook = dbo_enabled()
+        do_recv_hook = dbo_enabled() or do_async
         handle = self.handles[a2a_idx]
         assert handle is not None

@@ -232,6 +232,45 @@ def finalize(
             zero_copy=False,
             return_recv_hook=do_recv_hook,
             out=output)
-        if recv_hook is not None:
-            dbo_register_recv_hook(recv_hook)
-        dbo_yield()
+
+        return recv_hook
+
+    def finalize_async(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> Callable:
+        recv_hook = self._finalize(
+            output,
+            fused_expert_output,
+            topk_weights,
+            topk_ids,
+            apply_router_weight_on_input,
+            weight_and_reduce_impl,
+            do_async=True,
+        )
+        assert recv_hook is not None
+        return recv_hook
+
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        self._finalize(
+            output,
+            fused_expert_output,
+            topk_weights,
+            topk_ids,
+            apply_router_weight_on_input,
+            weight_and_reduce_impl,
+            do_async=False,
+        )
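The low-latency path splits differently: low_latency_combine(..., return_recv_hook=True) already hands back the receive hook, so _finalize simply returns it (note do_recv_hook = dbo_enabled() or do_async) and finalize_async asserts it is present. A self-contained mock of that hook-returning contract; mock_ll_combine is an illustrative stand-in, not the DeepEP API:

from typing import Callable, Optional

def mock_ll_combine(out: list, payload: list,
                    return_recv_hook: bool) -> Optional[Callable]:
    if not return_recv_hook:
        out[:] = payload           # synchronous: results land immediately
        return None
    def recv_hook():
        out[:] = payload           # deferred: results land when hook runs
    return recv_hook

def _finalize(out: list, payload: list, do_async: bool) -> Optional[Callable]:
    # Mirrors the diff's: do_recv_hook = dbo_enabled() or do_async
    return mock_ll_combine(out, payload, return_recv_hook=do_async)

output: list = []
hook = _finalize(output, [1, 2, 3], do_async=True)
assert hook is not None and output == []   # output not valid yet
hook()
print(output)                              # [1, 2, 3]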

vllm/model_executor/layers/fused_moe/modular_kernel.py

Lines changed: 76 additions & 19 deletions
@@ -209,7 +209,8 @@ def prepare(

     def supports_async(self) -> bool:
         """
-        Indicates whether or not this class implements prepare_async.
+        Indicates whether or not this class implements prepare_async and
+        finalize_async.
         """
         return False

@@ -275,6 +276,42 @@ def finalize(
         """
         raise NotImplementedError

+    def finalize_async(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: TopKWeightAndReduce,
+    ) -> Callable:
+        """
+        Perform any combine plus apply weights and perform a reduction on the
+        fused experts output but do not wait for results from other workers.
+        - output: The output tensor, written in place. Must be (M, K) shape.
+        - fused_expert_output: The unweighted, unreduced output of the fused
+          experts, it will have (M, topk, K) shape.
+        - topk_weights: The weights to be applied to the fused_experts_output.
+        - topk_ids: The topk_ids.
+        - apply_router_weight_on_input: When False, apply the weights to
+          fused_expert_output.
+        - weight_and_reduce_impl: An optional TopKWeightAndReduce
+          implementation.
+
+        Returns a callback that when invoked waits for results from other
+        workers and has the same return signature as `finalize`, e.g.
+
+        receiver = obj.finalize_async(output, ...)
+        ... output not valid yet ...
+        receiver()
+        ... output valid here ...
+
+        is equivalent to:
+
+        obj.finalize(output, ...)
+        """
+        raise NotImplementedError
+
     @property
     @abstractmethod
     def activation_format(self) -> FusedMoEActivationFormat:
@@ -814,23 +851,20 @@ def forward(
         """

         a1 = hidden_states
-        output = a1 if inplace else torch.zeros_like(a1)
+        if inplace and self.shared_experts is None:
+            output = a1
+        else:
+            output = torch.zeros_like(a1)

         local_num_experts = w1.size(0)
         if global_num_experts == -1:
             global_num_experts = local_num_experts

-        shared_output: torch.Tensor
-
         if not self.prepare_finalize.supports_async():
             # We shouldn't be running an a2a kernel that doesn't
             # support async prepare/finalize
             assert not dbo_enabled()

-            # Run shared experts serially with dispatch.
-            if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
-
             (a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
              _expert_topk_weights) = self.prepare_finalize.prepare(
                  a1,
@@ -854,9 +888,6 @@ def forward(
                 self.fused_experts.quant_config,
             )

-            if self.shared_experts is not None:
-                shared_output = self.shared_experts(a1)
-
             # If DBO is being used, register the hook with the ubatch context
             # and call it in dbo_maybe_run_recv_hook instead of passing it to
             # the receiver.
@@ -900,16 +931,42 @@ def forward(
             apply_router_weight_on_input=apply_router_weight_on_input,
         )

-        self.prepare_finalize.finalize(
-            output,
-            fused_out,
-            topk_weights,
-            topk_ids,
-            apply_router_weight_on_input,
-            self.fused_experts.finalize_weight_and_reduce_impl(),
-        )
+        shared_output: Optional[torch.Tensor] = None
+
+        if not self.prepare_finalize.supports_async():
+            assert not dbo_enabled()
+
+            self.prepare_finalize.finalize(
+                output,
+                fused_out,
+                topk_weights,
+                topk_ids,
+                apply_router_weight_on_input,
+                self.fused_experts.finalize_weight_and_reduce_impl(),
+            )
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(a1)
+        else:
+            recv_hook = self.prepare_finalize.finalize_async(
+                output,
+                fused_out,
+                topk_weights,
+                topk_ids,
+                apply_router_weight_on_input,
+                self.fused_experts.finalize_weight_and_reduce_impl(),
+            )
+
+            if self.shared_experts is not None:
+                shared_output = self.shared_experts(a1)
+
+            assert recv_hook is not None
+            dbo_register_recv_hook(recv_hook)
+            dbo_yield()
+            if not dbo_enabled():
+                recv_hook()

         if self.shared_experts is None:
             return output
         else:
+            assert shared_output is not None
             return shared_output, output
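The docstring above pins down the contract: finalize_async(output, ...) followed by the returned receiver must be observably equivalent to finalize(output, ...). A minimal stub satisfying that contract, with Python lists and a toy row-sum reduction standing in for tensors and the real weight-and-reduce step:

from typing import Callable

class StubPrepareFinalize:
    # Stand-in for a PrepareAndFinalize implementation that opts into the
    # new async contract.
    def supports_async(self) -> bool:
        return True

    def finalize_async(self, output: list, fused: list) -> Callable:
        def receiver():
            output[:] = [sum(row) for row in fused]  # toy reduction
        return receiver

    def finalize(self, output: list, fused: list) -> None:
        # Required equivalence: finalize == finalize_async + receiver()
        self.finalize_async(output, fused)()

out_a: list = []
out_b: list = []
pf = StubPrepareFinalize()
pf.finalize(out_a, [[1, 2], [3, 4]])
recv = pf.finalize_async(out_b, [[1, 2], [3, 4]])
recv()
assert out_a == out_b == [3, 7]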

vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py

Lines changed: 35 additions & 4 deletions
@@ -272,15 +272,15 @@ def prepare(
             hook()
         return receiver()

-    def finalize(
+    def finalize_async(
         self,
         output: torch.Tensor,
         fused_expert_output: torch.Tensor,
         topk_weights: torch.Tensor,
         topk_ids: torch.Tensor,
         apply_router_weight_on_input: bool,
         weight_and_reduce_impl: mk.TopKWeightAndReduce,
-    ) -> None:
+    ) -> Callable:
         assert isinstance(
             weight_and_reduce_impl, TopKWeightAndReduceDelegate
         ), ("Weight application and reduction happens in the combine kernel.")
@@ -303,8 +303,39 @@ def finalize(
         if apply_router_weight_on_input:
             topk_weights = torch.ones_like(topk_weights)

+        topk_ids_u32 = topk_ids.view(dtype=torch.uint32)
+
         self.a2a.combine(out_tokens=output,
-                         indices=topk_ids.view(dtype=torch.uint32),
+                         indices=topk_ids_u32,
                          weights=topk_weights,
                          expert_y=fused_expert_output,
-                         bound_m=bound_m)
+                         bound_m=bound_m,
+                         do_send=True,
+                         do_recv=False)
+
+        return lambda: self.a2a.combine(out_tokens=output,
+                                        indices=topk_ids_u32,
+                                        weights=topk_weights,
+                                        expert_y=fused_expert_output,
+                                        bound_m=bound_m,
+                                        do_send=False,
+                                        do_recv=True)
+
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        receiver = self.finalize_async(
+            output,
+            fused_expert_output,
+            topk_weights,
+            topk_ids,
+            apply_router_weight_on_input,
+            weight_and_reduce_impl,
+        )
+        receiver()
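The pplx backend gets its split from the combine kernel's do_send/do_recv flags: the first call posts the sends, the returned lambda completes the receives, and finalize is then just finalize_async plus an immediate receiver() call. A self-contained mock of that two-phase combine; MockA2A is an illustrative stand-in for the pplx all-to-all object, not its real API:

from typing import Callable

class MockA2A:
    def __init__(self):
        self._inflight = None

    def combine(self, out_tokens: list, expert_y: list,
                do_send: bool, do_recv: bool) -> None:
        if do_send:
            self._inflight = list(expert_y)   # post the sends
        if do_recv:
            out_tokens[:] = self._inflight    # drain the receives
            self._inflight = None

def finalize_async(a2a: MockA2A, output: list, fused: list) -> Callable:
    a2a.combine(out_tokens=output, expert_y=fused, do_send=True, do_recv=False)
    return lambda: a2a.combine(out_tokens=output, expert_y=fused,
                               do_send=False, do_recv=True)

a2a, output = MockA2A(), []
receiver = finalize_async(a2a, output, [5, 6, 7])
# ... overlap shared experts here ...
receiver()
print(output)  # [5, 6, 7]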
