 
 import csv
 import os
-import random
 from datetime import datetime
+from typing import Optional
 
 import flashinfer
 import torch
 
 FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
-
-# KV Cache Layout for TRT-LLM
-# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
+FP8_DTYPE = torch.float8_e4m3fn
 
 
 def to_float8(x, dtype=torch.float8_e4m3fn):
@@ -26,149 +24,168 @@ def to_float8(x, dtype=torch.float8_e4m3fn):
 
 @torch.no_grad()
 def benchmark_decode(
-    num_seqs,
-    max_seq_len,
-    page_size=16,
-    dtype=torch.bfloat16,
-    kv_layout="HND",
-    num_kv_heads=8,
-    kv_cache_dtype="auto",
-    head_dim=128,
-    warmup=10,
-    trials=20,
+    dtype: torch.dtype,
+    quant_dtypes: tuple[
+        Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
+    ],
+    batch_size: int,
+    max_seq_len: int,
+    num_heads: tuple[int, int] = (64, 8),
+    head_size: int = 128,
+    kv_layout: str = "HND",
+    block_size: int = 16,
+    warmup: int = 10,
+    trials: int = 20,
 ):
     torch.set_default_device("cuda")
-    device = "cuda"
     torch.manual_seed(0)
 
-    HEAD_GRP_SIZE = 8
-    MAX_SEQ_LEN = max_seq_len
-
-    # large number to reduce kv_cache reuse
-    NUM_BLOCKS = int(256000 / page_size)
-
-    workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
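+    # A None entry in quant_dtypes means "keep that tensor in the base dtype";
+    # otherwise the tensor is quantized (FP8 e4m3 here) with a per-tensor scale.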
+    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
+    q_quant_dtype = q_quant_dtype or dtype
+    kv_quant_dtype = kv_quant_dtype or dtype
+    o_quant_dtype = o_quant_dtype or dtype
 
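+    # Unpack (num_qo_heads, num_kv_heads); GQA requires the query-head count
+    # to be a multiple of the KV-head count.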
-    # For decode, batch_size is num_decode_token
-    num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
-    sm_scale = float(1.0 / (head_dim**0.5))
-    q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
-    kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
+    num_qo_heads, num_kv_heads = num_heads
+    assert num_qo_heads % num_kv_heads == 0
 
-    max_kv_len = max(kv_lens)
-    kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
-    max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
+    sm_scale = float(1.0 / (head_size**0.5))
 
+    # large number to reduce kv_cache reuse
+    NUM_BLOCKS = int(256000 / block_size)
+
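+    # NHD orders each page as (block_size, num_kv_heads, head_size); HND
+    # swaps the token and head dims. HND is the default used here.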
+    kv_cache_shape = None
+    if kv_layout == "NHD":
+        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
+    elif kv_layout == "HND":
+        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
+    else:
+        raise ValueError(f"Invalid kv_layout: {kv_layout}")
+
+    query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
+    if q_quant_dtype == FP8_DTYPE:
+        query, q_scale = to_float8(query)
+        ref_query = query.to(dtype) * q_scale
+    else:
+        q_scale = 1.0
+        ref_query = query
+
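+    # Random KV lengths in [1, max_seq_len); the last sequence is pinned to
+    # max_seq_len so every run exercises the longest case.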
+    kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
+    kv_lens[-1] = max_seq_len
+
+    seq_lens = kv_lens
+    max_seq_len = torch.max(seq_lens).item()
+
+    kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
+    if kv_quant_dtype == FP8_DTYPE:
+        kv_cache, kv_scale = to_float8(kv_cache)
+        ref_kv_cache = kv_cache.to(dtype) * kv_scale
+    else:
+        kv_scale = 1.0
+        ref_kv_cache = kv_cache
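+    # The same per-tensor scale is reused for the K and V halves of the cache.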
+    k_scale = v_scale = kv_scale
+
+    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
     block_tables = torch.randint(
-        0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
+        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
     )
-
-    kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
-    kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
-    k_scale = v_scale = 1.0
-
-    if kv_cache_dtype.startswith("fp8"):
-        kv_cache, _ = to_float8(kv_cache)
-
-    output_trtllm = torch.empty(q.shape, dtype=dtype)
-
-    # Benchmark TRT decode
-    def trt_decode():
-        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
-            q,
-            kv_cache,
-            workspace_buffer,
-            block_tables,
-            kv_lens_tensor,
-            max_kv_len,
-            bmm1_scale=k_scale * sm_scale,
-            bmm2_scale=v_scale,
-            out=output_trtllm,
-        )
-
-    def time_fn(fn, warmup=10, trials=20):
-        torch.cuda.synchronize()
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
-        times = []
-        for i in range(warmup):
-            fn()
-        for i in range(trials):
-            start.record()
-            fn()
-            end.record()
-            torch.cuda.synchronize()
-            times.append(start.elapsed_time(end))  # ms
-        return sum(times) / len(times), torch.std(torch.tensor(times))
-
-    # TRT Decode
-    trt_mean, trt_std = time_fn(trt_decode)
-
     kv_indptr = [0]
     kv_indices = []
     kv_last_page_lens = []
-    for i in range(num_seqs):
-        seq_len = kv_lens[i]
+    for i in range(batch_size):
+        seq_len = seq_lens[i]
         assert seq_len > 0
-        num_blocks = (seq_len + page_size - 1) // page_size
+        num_blocks = (seq_len + block_size - 1) // block_size
         kv_indices.extend(block_tables[i, :num_blocks])
         kv_indptr.append(kv_indptr[-1] + num_blocks)
-        kv_last_page_len = seq_len % page_size
+        kv_last_page_len = seq_len % block_size
         if kv_last_page_len == 0:
-            kv_last_page_len = page_size
+            kv_last_page_len = block_size
         kv_last_page_lens.append(kv_last_page_len)
 
     kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
     kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
     kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
-
-    output_baseline = torch.empty(q.shape, dtype=dtype)
+    workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)
 
     wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
         workspace_buffer,
         kv_layout,
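+        # Use the tensor-core decode kernel when the GQA group size
+        # (num_qo_heads // num_kv_heads) exceeds 4.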
         use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
     )
-
     wrapper.plan(
         kv_indptr,
         kv_indices,
         kv_last_page_lens,
         num_qo_heads,
         num_kv_heads,
-        head_dim,
-        page_size,
+        head_size,
+        block_size,
         "NONE",
+        sm_scale=sm_scale,
         q_data_type=dtype,
-        kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
+        kv_data_type=dtype,
     )
 
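+    # Time a callable with CUDA events; returns (mean, std) in milliseconds.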
+    def time_fn(fn, warmup=10, trials=20):
+        torch.cuda.synchronize()
+        start = torch.cuda.Event(enable_timing=True)
+        end = torch.cuda.Event(enable_timing=True)
+        times = []
+        for i in range(warmup):
+            fn()
+        for i in range(trials):
+            start.record()
+            fn()
+            end.record()
+            torch.cuda.synchronize()
+            times.append(start.elapsed_time(end))  # ms
+        return sum(times) / len(times), torch.std(torch.tensor(times))
+
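+    # The baseline runs on the dequantized reference tensors; the TRT-LLM
+    # kernel consumes the quantized tensors directly and can emit FP8 output
+    # (the output scale o_scale is fixed at 1.0 here).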
+    o_scale = 1.0
+    output_baseline = torch.empty(ref_query.shape, dtype=dtype)
+    output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)
+
     def baseline_decode():
-        return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale, output_baseline)
+        return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)
+
+    def trtllm_decode():
+        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
+            query=query,
+            kv_cache=kv_cache,
+            workspace_buffer=workspace_buffer,
+            block_tables=block_tables,
+            seq_lens=seq_lens,
+            max_seq_len=max_seq_len,
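+            # Fold the dequant scales into the kernel's two matmuls: Q*K^T
+            # picks up q_scale * k_scale (plus the softmax scale), and P*V
+            # applies v_scale divided by the output scale.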
+            bmm1_scale=q_scale * k_scale * sm_scale,
+            bmm2_scale=v_scale / o_scale,
+            out=output_trtllm,
+        )
 
     baseline_mean, baseline_std = time_fn(baseline_decode)
+    trtllm_mean, trtllm_std = time_fn(trtllm_decode)
 
     # Calculate percentage speedup (positive means TRT is faster)
-    speedup_percent = (baseline_mean - trt_mean) / baseline_mean
+    speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean
 
     print(
-        f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
+        f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
         f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
     )
 
     # Return results for CSV writing
     return {
-        "num_seqs": num_seqs,
-        "trt_mean": trt_mean,
-        "trt_std": trt_std.item(),
+        "batch_size": batch_size,
+        "trtllm_mean": trtllm_mean,
+        "trtllm_std": trtllm_std.item(),
         "baseline_mean": baseline_mean,
         "baseline_std": baseline_std.item(),
         "speedup_percent": speedup_percent,
-        "q_dtype": str(dtype),
-        "kv_cache_dtype": kv_cache_dtype,
-        "page_size": page_size,
+        "q_dtype": str(q_quant_dtype),
+        "kv_cache_dtype": str(kv_quant_dtype),
+        "output_dtype": str(o_quant_dtype),
+        "block_size": block_size,
         "num_kv_heads": num_kv_heads,
-        "head_dim": head_dim,
+        "head_size": head_size,
         "max_seq_len": max_seq_len,
     }
 
@@ -180,17 +197,18 @@ def write_results_to_csv(results, filename=None):
         filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
 
     fieldnames = [
-        "num_seqs",
-        "trt_mean",
-        "trt_std",
+        "batch_size",
+        "trtllm_mean",
+        "trtllm_std",
         "baseline_mean",
         "baseline_std",
         "speedup_percent",
         "q_dtype",
         "kv_cache_dtype",
-        "page_size",
+        "output_dtype",
+        "block_size",
         "num_kv_heads",
-        "head_dim",
+        "head_size",
         "max_seq_len",
     ]
 
@@ -209,45 +227,42 @@ def write_results_to_csv(results, filename=None):
 
 
 if __name__ == "__main__":
-    num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
+    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
     max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
     all_results = []
 
-    print(
-        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: bfloat16, "
-        "output_dtype: bfloat16"
-    )
-    print(
-        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
-        "baseline_std\tspeedup_percent"
-    )
-    for max_seq_len in max_seq_lens:
-        for bs in num_seqs:
-            result = benchmark_decode(
-                bs,
-                max_seq_len,
-                dtype=torch.bfloat16,
-                kv_cache_dtype="auto",
-            )
-            all_results.append(result)
+    dtype = torch.bfloat16
+    quant_dtypes = [
+        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
+        (None, None, None),
+        (None, FP8_DTYPE, None),
+        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
+    ]
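+    # The tuples above cover three configs: all-bf16, FP8 KV cache only,
+    # and fully FP8 (query, KV cache, and output).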
 
-    print(
-        "Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8, "
-        "output_dtype: bfloat16"
-    )
-    print(
-        "\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\t"
-        "baseline_std\tspeedup_percent"
-    )
-    for max_seq_len in max_seq_lens:
-        for bs in num_seqs:
-            result = benchmark_decode(
-                bs,
-                max_seq_len,
-                dtype=torch.bfloat16,
-                kv_cache_dtype="fp8",
-            )
-            all_results.append(result)
+    for quant_dtype in quant_dtypes:
+        q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
+        q_quant_dtype = q_quant_dtype or dtype
+        kv_quant_dtype = kv_quant_dtype or dtype
+        o_quant_dtype = o_quant_dtype or dtype
+
+        print(
+            f"Running benchmark for q_dtype = {q_quant_dtype}, "
+            f"kv_cache_dtype: {kv_quant_dtype}, "
+            f"output_dtype: {o_quant_dtype}"
+        )
+        print(
+            "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
+            "baseline_std\tspeedup_percent"
+        )
+        for max_seq_len in max_seq_lens:
+            for bs in batch_sizes:
+                result = benchmark_decode(
+                    dtype=dtype,
+                    quant_dtypes=quant_dtype,
+                    batch_size=bs,
+                    max_seq_len=max_seq_len,
+                )
+                all_results.append(result)
 
     # Write all results to CSV
     write_results_to_csv(all_results)