
Commit 900086f

[HybridKV][Bugfix] Fix Hybrid kvcache sharing bug in same attention type (#3760)
### What this PR does / why we need it?
Part of #3106.

Fix a hybrid KV cache sharing bug within the same attention type. Change the `shared_by` logic so that layers with the same attention spec can share the same buffer instead of each allocating additional HBM. After this PR, KV cache memory usage on Qwen3-Next drops by 50% compared with before (`self_attn:linear_attn=1:3` in an `attn_group`), and `gpu_memory_utilization` can be raised to `0.8` when running Qwen3-Next on A2 64G/card with tp4.

<img width="2833" height="1540" alt="image" src="https://github.com/user-attachments/assets/2a91fa99-fb0f-447c-9e8b-acd587890fbe" />

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
Passes the latest e2e test case on Qwen3-Next.

- vLLM version: v0.11.0rc3
- vLLM main: vllm-project/vllm@c9461e0

---------

Signed-off-by: MengqingCao <[email protected]>
1 parent 789ba4c commit 900086f
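
Before diving into the diff, here is a minimal, self-contained sketch of the sharing rule this commit switches to. It is not the vllm-ascend implementation: `KVCacheTensor`, `allocate_shared_kv_buffers`, and the layer names below are simplified, hypothetical stand-ins for the structures handled in `initialize_kv_cache_tensors`, and the int8 byte-buffer layout only loosely mirrors the real code.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple, Union

import torch


@dataclass
class KVCacheTensor:
    """Simplified stand-in: one backing allocation shared by several layers."""
    size: int             # buffer size in bytes
    shared_by: List[str]  # names of the layers that share this allocation


def allocate_shared_kv_buffers(
    kv_cache_tensors: List[KVCacheTensor],
    device: str = "cpu",
) -> Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]]:
    """Post-fix rule: layers of the same attention type in one `shared_by`
    group point at ONE buffer instead of each getting a fresh allocation."""
    raw_tensors: Dict[str, Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]] = {}
    for kv_cache_tensor in kv_cache_tensors:
        linear_layers = [
            name for name in kv_cache_tensor.shared_by if "linear_attn" in name
        ]
        full_attn_layers = [
            name for name in kv_cache_tensor.shared_by
            if "attn" in name and "linear_attn" not in name
        ]
        if linear_layers:
            # One byte buffer shared by every mamba/linear-attention layer in the group.
            buf = torch.zeros(kv_cache_tensor.size, dtype=torch.int8, device=device)
            for name in linear_layers:
                raw_tensors[name] = buf
        if full_attn_layers:
            # One (k, v) pair shared by every full-attention layer in the group.
            k = torch.zeros(kv_cache_tensor.size // 2, dtype=torch.int8, device=device)
            v = torch.zeros(kv_cache_tensor.size // 2, dtype=torch.int8, device=device)
            for name in full_attn_layers:
                raw_tensors[name] = (k, v)
    return raw_tensors


if __name__ == "__main__":
    # Illustrative layer names for the case in the description: self_attn:linear_attn = 1:3.
    group = KVCacheTensor(size=1024,
                          shared_by=[
                              "model.layers.3.self_attn",
                              "model.layers.0.linear_attn",
                              "model.layers.1.linear_attn",
                              "model.layers.2.linear_attn",
                          ])
    tensors = allocate_shared_kv_buffers([group])
    # All linear_attn entries reference the same storage -> no extra HBM per layer.
    assert tensors["model.layers.0.linear_attn"] is tensors["model.layers.2.linear_attn"]
```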

File tree

2 files changed: +26 -20 lines
tests/e2e/multicard/test_qwen3_next.py

Lines changed: 2 additions & 2 deletions
@@ -27,12 +27,12 @@
 def test_models_distributed_Qwen3_NEXT_TP4():
     example_prompts = [
         "Hello, my name is",
-    ]
+    ] * 4
     max_tokens = 5
     with VllmRunner("Qwen/Qwen3-Next-80B-A3B-Instruct",
                     tensor_parallel_size=4,
                     max_model_len=4096,
-                    gpu_memory_utilization=0.7,
+                    gpu_memory_utilization=0.8,
                     distributed_executor_backend="mp",
                     enforce_eager=True) as vllm_model:
         vllm_model.generate_greedy(example_prompts, max_tokens)
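
For anyone reproducing this locally: the updated case is an ordinary pytest test, so on a 4-card machine it should be runnable with something like `pytest -s tests/e2e/multicard/test_qwen3_next.py::test_models_distributed_Qwen3_NEXT_TP4` (the exact flags and environment setup are assumptions here; defer to the vllm-ascend e2e test instructions).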

vllm_ascend/worker/model_runner_v1.py

Lines changed: 24 additions & 18 deletions
@@ -3225,25 +3225,26 @@ def initialize_kv_cache_tensors(
             # TODO: REFACTOR ME to sharing hybrid cache
             for idx in range(len(kv_cache_tensor.shared_by)):
                 layer_name = kv_cache_tensor.shared_by[idx]
-                if "linear_attn" in layer_name:
+                if "linear_attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
+                ):
                     # for mamba linear attention
+                    if self.vllm_config.kv_transfer_config is None:
+                        tensor = torch.zeros(kv_cache_tensor.size,
+                                             dtype=torch.int8,
+                                             device=self.device)
+                    else:
+                        cache_size_aligned = kv_cache_tensor.size + alignment
+                        tensor = torch.zeros(cache_size_aligned,
+                                             dtype=torch.int8,
+                                             device=self.device)
+                        tensor = self._align_memory(
+                            tensor, alignment)[:kv_cache_tensor.size]
                     for layer_name_inner in kv_cache_tensor.shared_by:
-                        if ("attn" in layer_name_inner and "linear_attn" not in layer_name_inner) or \
-                            layer_name_inner in kv_cache_raw_tensors.keys():
-                            continue
-                        if self.vllm_config.kv_transfer_config is None:
-                            tensor = torch.zeros(kv_cache_tensor.size,
-                                                 dtype=torch.int8,
-                                                 device=self.device)
-                        else:
-                            cache_size_aligned = kv_cache_tensor.size + alignment
-                            tensor = torch.zeros(cache_size_aligned,
-                                                 dtype=torch.int8,
-                                                 device=self.device)
-                            tensor = self._align_memory(
-                                tensor, alignment)[:kv_cache_tensor.size]
-                        kv_cache_raw_tensors[layer_name_inner] = tensor
-                elif "attn" in layer_name:
+                        # shared the kvcache between the linear_attn specs in the same group
+                        if "linear_attn" in layer_name_inner:
+                            kv_cache_raw_tensors[layer_name_inner] = tensor
+                elif "attn" in layer_name and layer_name not in kv_cache_raw_tensors.keys(
+                ):
                     # for other attentions, e.g., self_attn, sliding window attn
                     if self.vllm_config.kv_transfer_config is None:
                         k_tensor = torch.zeros(kv_cache_tensor.size // 2,
@@ -3265,7 +3266,12 @@ def initialize_kv_cache_tensors(
                                                       alignment)[:cache_size]
                         v_tensor = self._align_memory(v_tensor,
                                                       alignment)[:cache_size]
-                    kv_cache_raw_tensors[layer_name] = (k_tensor, v_tensor)
+                    for layer_name_inner in kv_cache_tensor.shared_by:
+                        # shared the kvcache between the self_attn specs in the same group
+                        if ("attn" in layer_name_inner
+                                and "linear_attn" not in layer_name_inner):
+                            kv_cache_raw_tensors[layer_name_inner] = (k_tensor,
+                                                                      v_tensor)
 
         layer_names = set()
         for group in kv_cache_config.kv_cache_groups:
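
A quick way to see where the "50% KV cache memory saved" number in the description comes from: assuming one `attn_group` holds 1 `self_attn` layer and 3 `linear_attn` layers that are all listed in the same `kv_cache_tensor.shared_by`, the old code allocated one buffer per `linear_attn` layer, while the new code allocates one buffer shared by all of them. A hedged back-of-the-envelope sketch with normalized sizes (not the real allocator):

```python
# Assumed setup from the PR description: self_attn:linear_attn = 1:3 in one
# attn_group, every layer pointing at a KVCacheTensor of (normalized) size S.
S = 1.0

before = 1 * S + 3 * S  # old: one k/v buffer for self_attn + one buffer PER linear_attn layer
after = 1 * S + 1 * S   # new: one k/v buffer shared by self_attn + ONE buffer shared by linear_attn

print(f"saved: {1 - after / before:.0%}")  # -> saved: 50%
```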
