
Commit fbee279

fix: remove duplicate layer multiplication in KV cache size calculation (#6481)
Signed-off-by: Jaedeok Kim <[email protected]>
1 parent 7bb0a78 commit fbee279

3 files changed: +27 -19 lines changed


cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 2 additions & 2 deletions

@@ -477,9 +477,9 @@ std::map<SizeType32, float> BlockManager::calculateWindowSizeToShare(
         windowSizeToContribution[windowSize] = cacheSizeWeight;
     }

-    for (auto const& [windowSize, layers] : windowSizeToLayers)
+    for (auto const& [windowSize, _] : windowSizeToLayers)
     {
-        windowSizeToContribution.at(windowSize) *= windowSize * layers.size();
+        windowSizeToContribution.at(windowSize) *= windowSize;
     }
     auto const windowSizesTotalSum = std::accumulate(windowSizeToContribution.begin(), windowSizeToContribution.end(),
         0.0, [](auto sum, auto const& windowSize) { return sum + windowSize.second; });
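
The change is small but meaningful: the per-window weight computed just above this loop already accumulates the cache size per token across all layers of the window (as the new test comment below notes), so multiplying by layers.size() again double-counted the layer dimension. As a rough illustration only (not the library code; the free function name computeWindowShares and its bare std::map signature are assumptions for the example), the corrected weighting amounts to:

// Minimal sketch of the corrected share calculation, assuming cacheSizePerTokenPerWindow
// already aggregates the per-token cache size across all layers of each window.
#include <cstdint>
#include <iostream>
#include <map>
#include <numeric>

using SizeType32 = std::int32_t;

std::map<SizeType32, float> computeWindowShares(std::map<SizeType32, SizeType32> const& cacheSizePerTokenPerWindow)
{
    std::map<SizeType32, float> contributions;
    for (auto const& [windowSize, cacheSizePerToken] : cacheSizePerTokenPerWindow)
    {
        // Weight by windowSize * cacheSizePerToken only; the per-layer factor is
        // already folded into cacheSizePerToken, so no extra layers.size() term.
        contributions[windowSize] = static_cast<float>(windowSize) * static_cast<float>(cacheSizePerToken);
    }
    float const total = std::accumulate(contributions.begin(), contributions.end(), 0.0f,
        [](float sum, auto const& kv) { return sum + kv.second; });
    for (auto& entry : contributions)
    {
        entry.second /= total; // Normalize contributions into shares that sum to 1.
    }
    return contributions;
}

int main()
{
    // Same inputs as the first test case below: identical cache size per token per window.
    for (auto const& [windowSize, share] : computeWindowShares({{1024, 1}, {4096, 1}, {8192, 1}}))
    {
        std::cout << windowSize << " -> " << share << '\n'; // ~0.0769, ~0.3077, ~0.6154
    }
    return 0;
}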

cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp

Lines changed: 18 additions & 17 deletions

@@ -414,8 +414,8 @@ TEST_F(KVCacheManagerTest, BlockManagerTestWindowSizeToShare)
     {
         std::map<SizeType32, std::vector<SizeType32>> windowSizeToLayers{
             {1024, {1}},       // contribution = 1024*1 = 1024
-            {4096, {0, 4, 5}}, // contribution = 4096*3 = 12288
-            {8192, {2, 3}},    // contribution = 8192*2 = 16384
+            {4096, {0, 4, 5}}, // contribution = 4096*1 = 4096
+            {8192, {2, 3}},    // contribution = 8192*1 = 8192
         };
         // Use identical cache size per token across window sizes for simplicity.
         std::map<SizeType32, SizeType32> cacheSizePerTokenPerWindow{{1024, 1}, {4096, 1}, {8192, 1}};

@@ -431,9 +431,9 @@ TEST_F(KVCacheManagerTest, BlockManagerTestWindowSizeToShare)
         // Calculate expected shares based on contributions.
         std::map<SizeType32, float> expectedShares;
         std::map<SizeType32, SizeType32> contributions;
-        for (auto const& [windowSize, layers] : windowSizeToLayers)
+        for (auto const& [windowSize, _] : windowSizeToLayers)
         {
-            contributions[windowSize] = windowSize * static_cast<SizeType32>(layers.size());
+            contributions[windowSize] = windowSize * 1.0f;
         }
         auto const totalContribution = std::accumulate(contributions.begin(), contributions.end(), 0.0f,
             [](float sum, auto const& kv) { return sum + kv.second; });

@@ -445,27 +445,28 @@ TEST_F(KVCacheManagerTest, BlockManagerTestWindowSizeToShare)
         }

         // Verify the exact hard-coded values mentioned in the comment
-        EXPECT_NEAR(result.at(1024), 0.0345f, 1e-4f);
-        EXPECT_NEAR(result.at(4096), 0.4138f, 1e-4f);
-        EXPECT_NEAR(result.at(8192), 0.5517f, 1e-4f);
+        EXPECT_NEAR(result.at(1024), 0.0769f, 1e-4f);
+        EXPECT_NEAR(result.at(4096), 0.3077f, 1e-4f);
+        EXPECT_NEAR(result.at(8192), 0.6154f, 1e-4f);

         // Verify that when shares are converted to actual block counts, they match expected values.
         auto getRoundedBlocks
             = [&](float share) { return static_cast<SizeType32>(std::round(share * numPrimaryBlocks)); };
-        EXPECT_EQ(getRoundedBlocks(result.at(1024)), 565);
-        EXPECT_EQ(getRoundedBlocks(result.at(4096)), 6780);
-        EXPECT_EQ(getRoundedBlocks(result.at(8192)), 9039);
+        EXPECT_EQ(getRoundedBlocks(result.at(1024)), 1260);
+        EXPECT_EQ(getRoundedBlocks(result.at(4096)), 5041);
+        EXPECT_EQ(getRoundedBlocks(result.at(8192)), 10082);
     }

     // Variable window size with different cache sizes per token per window
     {
         std::map<SizeType32, std::vector<SizeType32>> windowSizeToLayers{
-            {1024, {1}},       // contribution = 1024*1*2 = 2048 (cache size per token = 2)
-            {4096, {0, 4, 5}}, // contribution = 4096*3*4 = 49152 (cache size per token = 4)
-            {8192, {2, 3}},    // contribution = 8192*2*1 = 16384 (cache size per token = 1)
+            {1024, {1}},       // contribution = 1024*(1*2) = 2048 (cache size per token per layer = 2)
+            {4096, {0, 4, 5}}, // contribution = 4096*(3*4) = 49152 (cache size per token per layer = 4)
+            {8192, {2, 3}},    // contribution = 8192*(2*1) = 16384 (cache size per token per layer = 1)
         };
-        // Different cache sizes per token per window
-        std::map<SizeType32, SizeType32> cacheSizePerTokenPerWindow{{1024, 2}, {4096, 4}, {8192, 1}};
+        // Different cache sizes per token per window.
+        // cacheSizePerTokenPerWindow is accumulated across the layers of given window size.
+        std::map<SizeType32, SizeType32> cacheSizePerTokenPerWindow{{1024, 2}, {4096, 12}, {8192, 2}};

         auto result = BlockManager::calculateWindowSizeToShare(windowSizeToLayers, cacheSizePerTokenPerWindow);
         EXPECT_EQ(result.size(), 3);

@@ -478,10 +479,10 @@ TEST_F(KVCacheManagerTest, BlockManagerTestWindowSizeToShare)
         // Calculate expected shares based on contributions with different cache sizes per token.
         std::map<SizeType32, float> expectedShares;
         std::map<SizeType32, SizeType32> contributions;
-        for (auto const& [windowSize, layers] : windowSizeToLayers)
+        for (auto const& [windowSize, _] : windowSizeToLayers)
         {
             auto const cacheSizePerToken = cacheSizePerTokenPerWindow.at(windowSize);
-            contributions[windowSize] = windowSize * static_cast<SizeType32>(layers.size()) * cacheSizePerToken;
+            contributions[windowSize] = windowSize * cacheSizePerToken;
         }
         auto const totalContribution = std::accumulate(contributions.begin(), contributions.end(), 0.0f,
             [](float sum, auto const& kv) { return sum + kv.second; });
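
The updated expectations follow directly from the corrected weighting. In the first test case every window has cache size per token 1, so the contributions are simply 1024, 4096, and 8192, totalling 13312; the shares are 1024/13312 ≈ 0.0769, 4096/13312 ≈ 0.3077, and 8192/13312 ≈ 0.6154, and scaling by the test's numPrimaryBlocks (16384, judging by the rounded counts) gives 1260, 5041, and 10082 blocks. In the second case the contributions are 1024*2 = 2048, 4096*12 = 49152, and 8192*2 = 16384, for a total of 67584, matching the per-window comments in the test.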

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 7 additions & 0 deletions

@@ -206,6 +206,13 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
         self.max_attention_window_vec = kv_cache_config.max_attention_window.copy(
         )  # Make a copy to avoid modifying original

+        # Clamp all window sizes to max_seq_len before calculating the
+        # number of KV cache blocks. This prevents the KV cache pool from
+        # being skewed by the largest window values.
+        self.max_attention_window_vec = [
+            min(max_seq_len, w) for w in self.max_attention_window_vec
+        ]
+
         sink_token_length = (kv_cache_config.sink_token_length
                              if kv_cache_config.sink_token_length is not None
                              else 0)
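
The Python-side change caps each configured attention window at max_seq_len before the block counts are derived, so a single oversized window can no longer skew the pool split; for example, with max_seq_len = 2048, a window vector of [1024, 4096, 8192] would become [1024, 2048, 2048].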
