Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions vllm/v1/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
device="cpu",
pin_memory=pin_memory)
self.paged_kv_indptr_np = self.paged_kv_indptr_cpu.numpy()
self.paged_kv_indptr_buffer = torch.zeros_like(
self.paged_kv_indptr_cpu, pin_memory=pin_memory)
self.paged_kv_indices_cpu = torch.zeros(max_num_pages,
dtype=torch.int32,
device="cpu",
Expand Down Expand Up @@ -361,12 +363,18 @@ def build(self,
dtype=np.int32,
out=self.paged_kv_indptr_np[1:num_reqs + 1],
)
# NOTE(woosuk): Because self.paged_kv_indptr_cpu can be modified
# after this line (e.g., for cuda graphs), we need to copy the data to
# self.paged_kv_indptr_buffer to avoid race condition.
self.paged_kv_indptr_buffer[:num_reqs +
1] = (self.paged_kv_indptr_cpu[:num_reqs +
1])
paged_kv_indptr = self.paged_kv_indptr[:num_reqs + 1]
paged_kv_indptr.copy_(self.paged_kv_indptr_cpu[:num_reqs + 1],
paged_kv_indptr.copy_(self.paged_kv_indptr_buffer[:num_reqs + 1],
non_blocking=True)

# write self.paged_kv_indices inplace
num_actual_pages = num_blocks_np.sum().item()
num_actual_pages = self.paged_kv_indptr_np[num_reqs]
paged_kv_indices = self.paged_kv_indices[:num_actual_pages]
_copy_page_indices_kernel[(num_reqs, )](
paged_kv_indices,
Expand Down