@@ -166,6 +166,8 @@ class KVCacheBlock
166
166
public:
167
167
using IdType = std::int32_t ;
168
168
169
+ static constexpr IdType kCachedBlocksRootId = -1 ;
170
+
169
171
explicit KVCacheBlock (IdType blockId, kernels::KVCacheIndex blockIdx);
170
172
171
173
void startScheduling ();
@@ -379,6 +381,16 @@ class GenerationRequest
379
381
return mKvCacheRetentionConfig .getDecodeDurationMs ();
380
382
}
381
383
384
/// @brief Report whether this request's context has been flagged as requiring the cyclic KV cache.
/// @return The current value of mContextRequiresCyclicKvCache.
+ [[nodiscard]] bool getContextRequiresCyclicKvCache () const
385
+ {
386
+ return mContextRequiresCyclicKvCache ;
387
+ }
388
+
389
/// @brief Record whether this request's context requires the cyclic KV cache.
/// @param contextRequiresCyclicKvCache New value stored in mContextRequiresCyclicKvCache.
+ void setContextRequiresCyclicKvCache (bool contextRequiresCyclicKvCache)
390
+ {
391
+ mContextRequiresCyclicKvCache = contextRequiresCyclicKvCache;
392
+ }
393
+
382
394
private:
383
395
// Request id of the sequence
384
396
LlmRequest::RequestIdType mRequestId ;
@@ -392,6 +404,9 @@ class GenerationRequest
392
404
runtime::ITensor::SharedPtr mCacheBlockIndices ;
393
405
// The retention priority to assign to decode blocks
394
406
executor::KvCacheRetentionConfig mKvCacheRetentionConfig ;
407
+
408
+ // Whether the context is long enough to require the use of the cyclic KV cache.
409
+ bool mContextRequiresCyclicKvCache {false };
395
410
};
396
411
397
412
// attach metadata to a pool pointer
@@ -443,7 +458,7 @@ class BlockManager
443
458
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks,
444
459
CacheType cacheType = CacheType::kSELF ,
445
460
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt ,
446
- std::shared_ptr<KVCacheEventManager> eventManager = nullptr );
461
+ std::shared_ptr<KVCacheEventManager> eventManager = nullptr , bool enableHashKey = false );
447
462
448
463
~BlockManager ();
449
464
@@ -712,6 +727,9 @@ class BlockManager
712
727
SizeType32 mMissedBlocks ;
713
728
std::set<KVCacheBlock::IdType> reusedBlockIds;
714
729
730
+ // Whether or not to maintain a hashmap of blocks.
+ // Default-initialized to false (the same default as the constructor's `enableHashKey` parameter)
+ // so the flag never holds an indeterminate value.
731
+ bool mEnableHashKey{false};
732
+
715
733
private:
716
734
friend class KVCacheManager ;
717
735
};
@@ -818,16 +836,18 @@ class BaseKVCacheManager
818
836
//! \details These blocks become reusable from the next step.
819
837
virtual void storeContextBlocks (LlmRequest const & llmRequest) = 0;
820
838
821
- virtual bool schedulingHasFreeBlocks (SizeType32 numRequired = 1 ) const = 0;
839
+ [[nodiscard]] virtual bool schedulingHasFreeBlocks (SizeType32 numRequired = 1 ) const = 0;
822
840
823
- virtual std::vector<std::vector<SizeType32>> const & getCacheBlockIds (LlmRequest::RequestIdType requestId) const = 0;
841
+ [[nodiscard]] virtual std::vector<std::vector<SizeType32>> const & getCacheBlockIds (
842
+ LlmRequest::RequestIdType requestId) const
843
+ = 0;
824
844
825
- virtual std::vector<std::vector<std::vector<SizeType32>>> getBatchCacheBlockIds (
845
+ [[nodiscard]] virtual std::vector<std::vector<std::vector<SizeType32>>> getBatchCacheBlockIds (
826
846
std::vector<LlmRequest::RequestIdType> const & requestIds) const
827
847
= 0;
828
848
829
- virtual runtime::ITensor::SharedPtr getPrimaryPool (SizeType32 layer_idx) const = 0;
830
- virtual SizeType32 getPoolLayerIdx (SizeType32 layer_idx) const = 0;
849
+ [[nodiscard]] virtual runtime::ITensor::SharedPtr getPrimaryPool (SizeType32 layer_idx) const = 0;
850
+ [[nodiscard]] virtual SizeType32 getPoolLayerIdx (SizeType32 layer_idx) const = 0;
831
851
832
852
virtual void refreshBlocks () = 0;
833
853
virtual void flushIterationEvents () = 0;
@@ -846,7 +866,7 @@ class BaseKVCacheManager
846
866
* 2 * modelConfig.getSizePerHead ();
847
867
}
848
868
849
- [[nodiscard]] static std::tuple<SizeType32, SizeType32> const calculateMaxNumBlocks (KvCacheConfig const & config,
869
+ [[nodiscard]] static std::tuple<SizeType32, SizeType32> calculateMaxNumBlocks (KvCacheConfig const & config,
850
870
nvinfer1::DataType dtype, tensorrt_llm::runtime::ModelConfig const & modelConfig,
851
871
tensorrt_llm::runtime::WorldConfig const & worldConfig, runtime::BufferManager const & bufferManager);
852
872
@@ -924,7 +944,7 @@ class KVCacheManager : public BaseKVCacheManager
924
944
return mBlockManager .getNumFreeBlocks ();
925
945
}
926
946
927
/// @brief Number of pools, as reported by the underlying block manager.
/// @return The value returned by mBlockManager.getNumPools().
// The added signature drops `virtual`: `override` already requires the function to be virtual.
- [[nodiscard]] virtual SizeType32 getNumPools () const override
947
+ [[nodiscard]] SizeType32 getNumPools () const override
928
948
{
929
949
return mBlockManager .getNumPools ();
930
950
}
@@ -994,8 +1014,6 @@ class KVCacheManager : public BaseKVCacheManager
994
1014
/// @return The number of blocks
995
1015
[[nodiscard]] SizeType32 getRemainingBlocksToCompletion (LlmRequest const & req) const override ;
996
1016
997
- void addContextTokens (LlmRequest::RequestIdType requestId, SizeType32 numTokens);
998
-
999
1017
/// @brief Increase size for request with requestId. Allocate new KV cache block(s) if needed.
1000
1018
void addToken (LlmRequest::RequestIdType requestId) override ;
1001
1019
0 commit comments