@@ -3222,23 +3222,38 @@ def _reshape_kv_cache_tensors(
         else:
             # k_cache: nope_cache v_cache: rope_cache
             mla_num_blocks, mla_block_size, num_kv_heads, _ = kv_cache_shape
-            k_shape = [
-                mla_num_blocks, mla_block_size, num_kv_heads,
-                self.model_config.hf_text_config.kv_lora_rank
-            ]
-            v_shape = [
-                mla_num_blocks, mla_block_size, num_kv_heads,
-                self.model_config.hf_text_config.qk_rope_head_dim
-            ]
-
+            if not self.use_sparse:
+                k_shape = [
+                    mla_num_blocks, mla_block_size, num_kv_heads,
+                    self.model_config.hf_text_config.kv_lora_rank
+                ]
+                v_shape = [
+                    mla_num_blocks, mla_block_size, num_kv_heads,
+                    self.model_config.hf_text_config.
+                    qk_rope_head_dim
+                ]
+            else:
+                k_shape = [
+                    mla_num_blocks, mla_block_size, num_kv_heads,
+                    self.model_config.hf_text_config.kv_lora_rank
+                ]
+                v_shape = [
+                    mla_num_blocks, mla_block_size, num_kv_heads,
+                    self.model_config.hf_text_config.
+                    qk_rope_head_dim
+                ]
             k_cache = raw_k_tensor.view(dtype).view(k_shape)
             k_cache = self._convert_torch_format(k_cache)
             v_cache = raw_v_tensor.view(dtype).view(v_shape)
             v_cache = self._convert_torch_format(v_cache)
             if self.use_sparse and raw_dsa_k_cache is not None:
-                dsa_k_cache_shape = (num_blocks, block_size, 1, 128)
-                dsa_k_cache = raw_dsa_k_cache.view(dtype).view(
-                    dsa_k_cache_shape)
+                dsa_k_cache_shape = (num_blocks,
+                                     kv_cache_spec.block_size, 1, 128)
+                dsa_k_cache_size = (
+                    num_blocks
+                ) * kv_cache_spec.block_size * 128 * dtype.itemsize
+                dsa_k_cache = raw_dsa_k_cache[:dsa_k_cache_size].view(
+                    dtype).view(dsa_k_cache_shape)
                 kv_caches[layer_name] = (k_cache, v_cache, dsa_k_cache)
             else:
                 kv_caches[layer_name] = (k_cache, v_cache)
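The key behavioral change is in the sparse (DSA) branch: the indexer k-cache shape now uses `kv_cache_spec.block_size` instead of the local MLA `block_size`, and the cache is carved out of the raw byte buffer by slicing exactly the number of bytes it needs before reinterpreting. A minimal standalone sketch of that slice-then-view pattern is below; the block counts, block size, and dtype are illustrative placeholders, not values taken from the real `kv_cache_spec` or model config.

```python
import torch

# Hypothetical sizes standing in for kv_cache_spec values.
num_blocks, block_size, head_dim = 4, 16, 128
dtype = torch.bfloat16

# The raw allocation is a flat uint8 buffer that may be larger than
# the DSA indexer cache actually needs.
raw_dsa_k_cache = torch.zeros(
    2 * num_blocks * block_size * head_dim * dtype.itemsize,
    dtype=torch.uint8)

# Take exactly the bytes the cache needs, reinterpret them as the KV
# dtype, then reshape to (num_blocks, block_size, 1, head_dim).
dsa_k_cache_size = num_blocks * block_size * head_dim * dtype.itemsize
dsa_k_cache = raw_dsa_k_cache[:dsa_k_cache_size].view(dtype).view(
    num_blocks, block_size, 1, head_dim)

assert dsa_k_cache.shape == (num_blocks, block_size, 1, head_dim)
```

Slicing before the dtype view matters: viewing the whole oversized buffer as `dtype` and then reshaping would fail (or silently read past the intended region) whenever the raw allocation is larger than `num_blocks * block_size * 128 * dtype.itemsize` bytes.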