
[Feature][WideEP]: More "masked-m" GEMM and quant kernels for use with DeepEP LowLatency + PPLX All2Alls #25484

@tlrmchlsmth


🚀 The feature, motivation and pitch

The dispatch operation in DeepEP LowLatency and PPLX-kernels produces a hidden_states tensor with shape [num_local_experts, max_num_tokens_per_expert, hidden_size].

max_num_tokens_per_expert may be much larger than the actual number of tokens routed to each expert. To avoid wasted work, we need efficient "masked-m" kernels that operate only on the valid rows of each expert's slab.
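
For concreteness, here is a minimal pure-PyTorch sketch of the semantics a masked-m grouped GEMM needs to implement (the `masked_m` / `expert_weights` names are illustrative, not an existing API):

```python
import torch

def masked_grouped_gemm_ref(
    hidden_states: torch.Tensor,   # [num_local_experts, max_num_tokens_per_expert, hidden_size]
    expert_weights: torch.Tensor,  # [num_local_experts, hidden_size, out_size]
    masked_m: torch.Tensor,        # [num_local_experts], valid token count per expert
) -> torch.Tensor:
    """Reference semantics only: the first masked_m[e] rows of each expert's slab are
    meaningful, the rest is padding that a real masked-m kernel should skip entirely."""
    num_experts, max_tokens, _ = hidden_states.shape
    out = torch.zeros(
        num_experts, max_tokens, expert_weights.shape[-1],
        dtype=hidden_states.dtype, device=hidden_states.device,
    )
    for e in range(num_experts):
        m = int(masked_m[e])
        # A real kernel launches work only for these m rows per expert.
        out[e, :m] = hidden_states[e, :m] @ expert_weights[e]
    return out
```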

Currently we use DeepGEMM's masked grouped GEMM and a fused SiLU-mul-quant kernel for blocked fp8 formats. We need to expand on this to support more architectures and more quantization formats.
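
As a rough reference for what the fused activation-and-quant step computes, here is a sketch assuming 128-wide quantization groups along the hidden dimension and e4m3 fp8; the function and argument names are illustrative, not the existing kernel's API:

```python
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def silu_mul_blocked_fp8_ref(
    gate_up: torch.Tensor,   # [num_local_experts, max_num_tokens_per_expert, 2 * inter_size]
    masked_m: torch.Tensor,  # [num_local_experts], valid token count per expert
    group_size: int = 128,   # assumed blocked-fp8 group width along the hidden dim
):
    """Reference for a fused SiLU-mul + blocked fp8 quant applied only to valid rows."""
    E, T, two_n = gate_up.shape
    n = two_n // 2
    out = torch.zeros(E, T, n, dtype=torch.float8_e4m3fn, device=gate_up.device)
    scales = torch.ones(E, T, n // group_size, dtype=torch.float32, device=gate_up.device)
    for e in range(E):
        m = int(masked_m[e])
        gate, up = gate_up[e, :m, :n], gate_up[e, :m, n:]
        act = torch.nn.functional.silu(gate.float()) * up.float()
        # One scale per (token, group_size-wide channel block), as in blockwise fp8 formats.
        grouped = act.view(m, n // group_size, group_size)
        amax = grouped.abs().amax(dim=-1).clamp(min=1e-12)
        s = amax / FP8_MAX
        q = (grouped / s.unsqueeze(-1)).clamp(-FP8_MAX, FP8_MAX)
        out[e, :m] = q.view(m, n).to(torch.float8_e4m3fn)
        scales[e, :m] = s
    return out, scales
```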

Alternatives

Alternatively, we could use the DeepEP high-throughput kernels, which use a different format that has no masking and could work with ordinary grouped GEMM kernels instead. However, DeepEP HT doesn't support CUDA Graphs, so we'd need to figure out how to support CUDA Graphs with this format.

Additional context

No response

