-
Notifications
You must be signed in to change notification settings - Fork 725
Reduce allocation overhead in quantized sdpa #15610
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: gh/kimishpatel/202/base
Are you sure you want to change the base?
Conversation
For small models dequantizing portions of v cache causes extra alloc overhead. Probably a better way to handle this is to dequantize entire v cache outside the model There isnt significant perf advantage from this yet but subsequent diffs will use caching allocator where this refactor help. Differential Revision: [D85532077](https://our.internmc.facebook.com/intern/diff/D85532077/) [ghstack-poisoned]
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/15610
Note: Links to docs will display an error until the docs builds have been completed. ❌ 9 New Failures, 4 Unrelated FailuresAs of commit 602a3a7 with merge base 7600df8 ( NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
This PR needs a
|
For small models dequantizing portions of v cache causes extra alloc overhead. Probably a better way to handle this is to dequantize entire v cache outside the model There isnt significant perf advantage from this yet but subsequent diffs will use caching allocator where this refactor help. Differential Revision: [D85532077](https://our.internmc.facebook.com/intern/diff/D85532077/) [ghstack-poisoned]
Pull Request resolved: #15610 For small models dequantizing portions of v cache causes extra alloc overhead. Probably a better way to handle this is to dequantize entire v cache outside the model There isnt significant perf advantage from this yet but subsequent diffs will use caching allocator where this refactor help. ghstack-source-id: 321455128 @exported-using-ghexport Differential Revision: [D85532077](https://our.internmc.facebook.com/intern/diff/D85532077/)
For small models dequantizing portions of v cache causes extra alloc overhead. Probably a better way to handle this is to dequantize entire v cache outside the model There isnt significant perf advantage from this yet but subsequent diffs will use caching allocator where this refactor help. Differential Revision: [D85532077](https://our.internmc.facebook.com/intern/diff/D85532077/) [ghstack-poisoned]
For small models dequantizing portions of v cache causes extra alloc overhead. Probably a better way to handle this is to dequantize entire v cache outside the model There isnt significant perf advantage from this yet but subsequent diffs will use caching allocator where this refactor help. Differential Revision: [D85532077](https://our.internmc.facebook.com/intern/diff/D85532077/) [ghstack-poisoned]
For small models dequantizing portions of v cache causes extra alloc overhead. Probably a better way to handle this is to dequantize entire v cache outside the model There isnt significant perf advantage from this yet but subsequent diffs will use caching allocator where this refactor help. Differential Revision: [D85532077](https://our.internmc.facebook.com/intern/diff/D85532077/) [ghstack-poisoned]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR refactors the quantized scaled dot-product attention (SDPA) implementation to reduce allocation overhead by moving the dequantization buffer allocation from inside the dequant_and_gemm function to the outer cpu_flash_attention scope. Instead of allocating a new std::vector for each dequantization operation, a pre-allocated per-thread scratch buffer is now shared across iterations.
Key changes:
- Added
buf_qdq_ptrparameter todequant_and_gemmand_qk_at_v_gemmfunctions to accept externally allocated dequantization buffers - Allocated a shared scratch buffer (
scratch_for_quant_dequant) incpu_flash_attentionwith per-thread partitioning - Removed the local
std::vector<float> dequantized_v_dataallocation fromdequant_and_gemm
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| // {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0}, | ||
| // query.options()); | ||
| int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize; | ||
| // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads, |
Copilot
AI
Nov 17, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The comment says "align to 64 bytes" but kAlignment = 32 aligns to 32 elements. Since size_per_thread_qdq_vec is an element count (not byte count), and assuming accum_t is float (4 bytes), this aligns to 128 bytes (32 * 4), not 64 bytes.
Either:
- Change
kAlignmentto 16 if 64-byte alignment is desired, or - Update the comment to say "align to 32 elements" or "align to 128 bytes (for float)"
| // Lets align size_per_thread_qdq_vec to 64 bytes, for coalesced cache reads, | |
| // Lets align size_per_thread_qdq_vec to 32 elements (128 bytes for float), for coalesced cache reads, |
| // by padding with right number of per thread elements | ||
| constexpr int64_t kAlignment = 32; | ||
| size_per_thread_qdq_vec = | ||
| (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1)); |
Copilot
AI
Nov 17, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The alignment calculation is incorrect. The formula (x + kAlignment - 1) & (-(kAlignment - 1)) uses the wrong mask.
For aligning to a power-of-2 boundary, the correct formula is:
(size_per_thread_qdq_vec + kAlignment - 1) & (-kAlignment)or equivalently:
(size_per_thread_qdq_vec + kAlignment - 1) & ~(kAlignment - 1)The current code uses -(kAlignment - 1) which equals -31 = 0xFFFFFFE1, but the correct mask should be -32 = 0xFFFFFFE0 to properly zero out the bottom 5 bits.
| (size_per_thread_qdq_vec + kAlignment - 1) & (-(kAlignment - 1)); | |
| (size_per_thread_qdq_vec + kAlignment - 1) & -kAlignment; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@copilot are you sure? Please double check again
| // at::Tensor buf_reduced = at::empty( | ||
| // {num_thread, qSplitSize, is_reduced_type ? kvSplitSize : 0}, | ||
| // query.options()); | ||
| int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize; |
Copilot
AI
Nov 17, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The buffer size calculation appears to be larger than necessary. The dequantize operation needs kvBlockSize * headSize elements (at most kvSplitSize * headSize), but this allocates qSplitSize * kvSplitSize * headSize. The extra qSplitSize factor seems unnecessary and wastes memory per thread.
Consider changing to:
int64_t size_per_thread_qdq_vec = kvSplitSize * headSize;| int64_t size_per_thread_qdq_vec = qSplitSize * kvSplitSize * headSize; | |
| int64_t size_per_thread_qdq_vec = kvSplitSize * headSize; |
|
@mergennachin I've opened a new pull request, #15852, to work on those changes. Once the pull request is ready, I'll request review from you. |
Stack from ghstack (oldest at bottom):
For small models dequantizing portions of v cache causes extra alloc overhead.
Probably a better way to handle this is to dequantize entire v cache outside the model
There isnt significant perf advantage from this yet but subsequent diffs will use caching allocator where this refactor help.
Differential Revision: D85532077