63 changes: 62 additions & 1 deletion vllm/v1/core/kv_cache_utils.py
@@ -735,7 +735,29 @@
    return kv_cache_groups


def create_kv_cache_group_specs_for_variable_sizes(
        kv_cache_spec: dict[str, KVCacheSpec],
        grouped_layer_names: list[list[str]]) -> list[KVCacheGroupSpec]:
    """Create one KVCacheGroupSpec per group of layer names, where the
    layers within a group may have different sizes (e.g. head_size)."""
    kv_cache_groups = []
    for layer_names_one_group in grouped_layer_names:
        layer_specs = [
            kv_cache_spec[layer_name] for layer_name in layer_names_one_group
        ]
        merged_layer_spec = layer_specs[0]
Contributor

Severity: high
Using the spec of the first layer (layer_specs[0]) to represent the entire group can be problematic when layers have different head_size values, as indicated by the PR's purpose. This leads to an incorrect KVCacheGroupSpec for the group.

For instance, get_max_concurrency_for_kv_cache_config uses group.kv_cache_spec.page_size_bytes for its calculation. If the first layer has a smaller head_size than other layers in the group, its page_size_bytes will be smaller than the maximum page size used for allocation in _get_kv_cache_config_allFullAttentionSpec_type. This will cause max_concurrency to be overestimated, which could lead to out-of-memory errors at runtime.

To ensure correctness, the merged_layer_spec should represent the most demanding configuration within the group, which corresponds to the layer with the largest page_size_bytes.

Suggested change:
-        merged_layer_spec = layer_specs[0]
+        merged_layer_spec = max(layer_specs, key=lambda spec: spec.page_size_bytes)
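To make the overestimation concrete, here is a small self-contained sketch; the block size, head sizes, KV-head count, and memory budget are made-up numbers, and page_size_bytes is a hypothetical helper, not vLLM's real spec object:

# Hypothetical numbers for illustration only; not vLLM's real API.
block_size = 16           # tokens per KV cache block
num_kv_heads = 8
bytes_per_elem = 2        # fp16
available_memory = 2**30  # 1 GiB budget for the KV cache

def page_size_bytes(head_size: int) -> int:
    # Bytes for K and V of one block of one layer.
    return 2 * block_size * num_kv_heads * head_size * bytes_per_elem

small = page_size_bytes(64)   # first layer in the group
large = page_size_bytes(128)  # another layer in the same group

# Allocation sizes every block for the largest page in the group:
num_blocks_allocated = available_memory // large   # 16384

# A group spec taken from layer_specs[0] reports the smaller page size,
# so concurrency math derived from it assumes twice as many blocks fit:
num_blocks_assumed = available_memory // small     # 32768

print(num_blocks_allocated, num_blocks_assumed)

With max_concurrency derived from the assumed block count, the scheduler would admit roughly twice the load the allocation can actually serve, which is exactly the out-of-memory risk described above.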

        kv_cache_groups.append(
            KVCacheGroupSpec(layer_names_one_group, merged_layer_spec))
    return kv_cache_groups


def is_kv_cache_instance_all_FullAttentionSpec(
        kv_cache_spec: dict[str, KVCacheSpec]) -> bool:
    if all(isinstance(spec, FullAttentionSpec)
           for spec in kv_cache_spec.values()):
        return True
    else:
        return False


def is_kv_cache_type_uniform(kv_cache_spec: dict[str, KVCacheSpec]) -> bool:

Check failure on line 760 in vllm/v1/core/kv_cache_utils.py
GitHub Actions / pre-commit: Ruff (SIM103)
vllm/v1/core/kv_cache_utils.py:755:5: SIM103 Return the condition directly
"""
Whether all layers in the given KVCacheSpec have the same KV cache spec.
Note that we regard FullAttentionSpec with and without sliding window as
@@ -806,6 +828,42 @@
    return page_sizes.pop()


def _get_kv_cache_config_allFullAttentionSpec_type(
        vllm_config: VllmConfig, kv_cache_spec: dict[str, KVCacheSpec],
        available_memory: int) -> KVCacheConfig:
    # Size every block for the largest page in the model so that the
    # layer with the biggest head_size still fits.
    page_size = max(layer.page_size_bytes for layer in kv_cache_spec.values())
    num_blocks = get_num_blocks(vllm_config, len(kv_cache_spec),
                                available_memory, page_size)

    per_layer_size = page_size * num_blocks
    # All layers are full attention, so we create one kv cache group
    # for all layers.
    grouped_layer_names = [list(kv_cache_spec.keys())]

    # Each layer uses a separate Tensor to store its KV cache.
    kv_cache_tensors = [
        KVCacheTensor(size=per_layer_size, shared_by=[layer_name])
        for layer_name in kv_cache_spec
    ]

    kv_cache_config = KVCacheConfig(
        num_blocks=num_blocks,
        kv_cache_tensors=kv_cache_tensors,
        kv_cache_groups=create_kv_cache_group_specs_for_variable_sizes(
            kv_cache_spec, grouped_layer_names),
    )

    num_tokens = num_blocks * vllm_config.cache_config.block_size
    num_tokens_str = f"{num_tokens:,}"
    logger.info("GPU KV cache size: %s tokens", num_tokens_str)
    max_model_len_str = f"{vllm_config.model_config.max_model_len:,}"
    max_concurrency = get_max_concurrency_for_kv_cache_config(
        vllm_config, kv_cache_config)
    logger.info("Maximum concurrency for %s tokens per request: %.2fx",
                max_model_len_str, max_concurrency)
    return kv_cache_config


def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                                      kv_cache_spec: dict[str, KVCacheSpec],
                                      available_memory: int) -> KVCacheConfig:
@@ -1071,7 +1129,7 @@
                attention_chunk_size=spec.attention_chunk_size,
            )

-    if not is_kv_cache_type_uniform(kv_cache_spec):
+    if (not is_kv_cache_type_uniform(kv_cache_spec)
+            and not is_kv_cache_instance_all_FullAttentionSpec(kv_cache_spec)):
        raise ValueError("Hybrid KV cache manager is disabled but failed to "
                         "convert the KV cache specs to one unified type.")

@@ -1114,6 +1172,9 @@
        return _get_kv_cache_config_uniform_page_size(vllm_config,
                                                      kv_cache_spec,
                                                      available_memory)
    elif is_kv_cache_instance_all_FullAttentionSpec(kv_cache_spec):
        return _get_kv_cache_config_allFullAttentionSpec_type(
            vllm_config, kv_cache_spec, available_memory)

    raise NotImplementedError
