Skip to content

Commit ab12dee

Browse files
committed
fix dsa
Signed-off-by: MengqingCao <[email protected]>
1 parent a34fc72 commit ab12dee

File tree

1 file changed

+27
-12
lines changed

1 file changed

+27
-12
lines changed

vllm_ascend/worker/model_runner_v1.py

Lines changed: 27 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -3222,23 +3222,38 @@ def _reshape_kv_cache_tensors(
3222 3222
else:
3223 3223
# k_cache: nope_cache v_cache: rope_cache
3224 3224
mla_num_blocks, mla_block_size, num_kv_heads, _ = kv_cache_shape
3225-
k_shape = [
3226-
mla_num_blocks, mla_block_size, num_kv_heads,
3227-
self.model_config.hf_text_config.kv_lora_rank
3228-
]
3229-
v_shape = [
3230-
mla_num_blocks, mla_block_size, num_kv_heads,
3231-
self.model_config.hf_text_config.qk_rope_head_dim
3232-
]
3233-
3225+
if not self.use_sparse:
3226+
k_shape = [
3227+
mla_num_blocks, mla_block_size, num_kv_heads,
3228+
self.model_config.hf_text_config.kv_lora_rank
3229+
]
3230+
v_shape = [
3231+
mla_num_blocks, mla_block_size, num_kv_heads,
3232+
self.model_config.hf_text_config.
3233+
qk_rope_head_dim
3234+
]
3235+
else:
3236+
k_shape = [
3237+
mla_num_blocks, mla_block_size, num_kv_heads,
3238+
self.model_config.hf_text_config.kv_lora_rank
3239+
]
3240+
v_shape = [
3241+
mla_num_blocks, mla_block_size, num_kv_heads,
3242+
self.model_config.hf_text_config.
3243+
qk_rope_head_dim
3244+
]
3234 3245
k_cache = raw_k_tensor.view(dtype).view(k_shape)
3235 3246
k_cache = self._convert_torch_format(k_cache)
3236 3247
v_cache = raw_v_tensor.view(dtype).view(v_shape)
3237 3248
v_cache = self._convert_torch_format(v_cache)
3238 3249
if self.use_sparse and raw_dsa_k_cache is not None:
3239-
dsa_k_cache_shape = (num_blocks, block_size, 1, 128)
3240-
dsa_k_cache = raw_dsa_k_cache.view(dtype).view(
3241-
dsa_k_cache_shape)
3250+
dsa_k_cache_shape = (num_blocks,
3251+
kv_cache_spec.block_size, 1, 128)
3252+
dsa_k_cache_size = (
3253+
num_blocks
3254+
) * kv_cache_spec.block_size * 128 * dtype.itemsize
3255+
dsa_k_cache = raw_dsa_k_cache[:dsa_k_cache_size].view(
3256+
dtype).view(dsa_k_cache_shape)
3242 3257
kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
3243 3258
else:
3244 3259
kv_caches[layer_name] = (k_cache, v_cache)

0 commit comments

Comments
 (0)