import csv
import os
from datetime import datetime
from typing import Optional

import flashinfer
import torch

FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
FP8_DTYPE = torch.float8_e4m3fn


def to_float8(x, dtype=torch.float8_e4m3fn):
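    # The function body is elided in this hunk; below is a minimal sketch of
    # the standard symmetric per-tensor FP8 quantization helper the script
    # relies on (assumed implementation): scale values into the dtype's
    # representable range and return the quantized tensor together with the
    # inverse (dequantization) scale.
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    return x_scl_sat.to(dtype), scale.float().reciprocal()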


@torch.no_grad()
def benchmark_decode(
    dtype: torch.dtype,
    quant_dtypes: tuple[
        Optional[torch.dtype], Optional[torch.dtype], Optional[torch.dtype]
    ],
    batch_size: int,
    max_seq_len: int,
    num_heads: tuple[int, int] = (64, 8),
    head_size: int = 128,
    kv_layout: str = "HND",
    block_size: int = 16,
    warmup: int = 10,
    trials: int = 20,
):
    torch.set_default_device("cuda")
    torch.manual_seed(0)
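    # Resolve per-tensor quantization dtypes: a None entry means that tensor
    # (query, KV cache, or output, respectively) stays in the base dtype.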
    q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtypes
    q_quant_dtype = q_quant_dtype or dtype
    kv_quant_dtype = kv_quant_dtype or dtype
    o_quant_dtype = o_quant_dtype or dtype

    num_qo_heads, num_kv_heads = num_heads
    assert num_qo_heads % num_kv_heads == 0

    sm_scale = float(1.0 / (head_size**0.5))

    # Use a large block pool to reduce KV-cache block reuse across sequences.
    NUM_BLOCKS = int(256000 / block_size)

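    # Each cache entry packs K and V along dim 1; "NHD" orders a page as
    # (page_size, num_heads, head_dim), "HND" as (num_heads, page_size, head_dim).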
    kv_cache_shape = None
    if kv_layout == "NHD":
        kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
    elif kv_layout == "HND":
        kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
    else:
        raise ValueError(f"Invalid kv_layout: {kv_layout}")
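    # The baseline kernel runs in the base dtype, so keep a dequantized
    # reference copy of any tensor that gets quantized to FP8.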
    query = torch.randn(batch_size, num_qo_heads, head_size, dtype=dtype)
    if q_quant_dtype == FP8_DTYPE:
        query, q_scale = to_float8(query)
        ref_query = query.to(dtype) * q_scale
    else:
        q_scale = 1.0
        ref_query = query
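    # Draw random per-sequence KV lengths, pinning the last sequence to
    # max_seq_len so every run exercises the full sequence length.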
    kv_lens = torch.randint(1, max_seq_len, (batch_size,), dtype=torch.int32)
    kv_lens[-1] = max_seq_len

    seq_lens = kv_lens
    max_seq_len = torch.max(seq_lens).item()

    kv_cache = torch.randn(kv_cache_shape, dtype=dtype)
    if kv_quant_dtype == FP8_DTYPE:
        kv_cache, kv_scale = to_float8(kv_cache)
        ref_kv_cache = kv_cache.to(dtype) * kv_scale
    else:
        kv_scale = 1.0
        ref_kv_cache = kv_cache
    k_scale = v_scale = kv_scale

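    # Random page tables: each sequence gets max_num_blocks_per_seq page ids
    # drawn from the shared pool (collisions are acceptable for benchmarking).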
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size
    block_tables = torch.randint(
        0, NUM_BLOCKS, (batch_size, max_num_blocks_per_seq), dtype=torch.int32
    )
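
    # Build flashinfer's CSR-style paged-KV metadata: kv_indptr gives each
    # sequence's offset into kv_indices (the flattened page ids), and
    # kv_last_page_lens records how many tokens occupy each final page.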
    kv_indptr = [0]
    kv_indices = []
    kv_last_page_lens = []
    for i in range(batch_size):
        seq_len = seq_lens[i]
        assert seq_len > 0
        num_blocks = (seq_len + block_size - 1) // block_size
        kv_indices.extend(block_tables[i, :num_blocks])
        kv_indptr.append(kv_indptr[-1] + num_blocks)
        kv_last_page_len = seq_len % block_size
        if kv_last_page_len == 0:
            kv_last_page_len = block_size
        kv_last_page_lens.append(kv_last_page_len)

    kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
    kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
    kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
    workspace_buffer = torch.zeros(1024 * 1024 * 1024, dtype=torch.int8)

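    # Baseline: flashinfer's paged-KV decode wrapper. The tensor-core path is
    # enabled only when the GQA group size exceeds 4, a heuristic for when it
    # outperforms the CUDA-core path.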
    wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
        workspace_buffer,
        kv_layout,
        use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
    )
    wrapper.plan(
        kv_indptr,
        kv_indices,
        kv_last_page_lens,
        num_qo_heads,
        num_kv_heads,
        head_size,
        block_size,
        "NONE",
        sm_scale=sm_scale,
        q_data_type=dtype,
        kv_data_type=dtype,
    )
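    # Time a callable with CUDA events: run warmup iterations first, then
    # average device-side elapsed time over the timed trials.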
    def time_fn(fn, warmup=10, trials=20):
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        times = []
        for i in range(warmup):
            fn()
        for i in range(trials):
            start.record()
            fn()
            end.record()
            torch.cuda.synchronize()
            times.append(start.elapsed_time(end))  # ms
        return sum(times) / len(times), torch.std(torch.tensor(times))

    o_scale = 1.0
    output_baseline = torch.empty(ref_query.shape, dtype=dtype)
    output_trtllm = torch.empty(query.shape, dtype=o_quant_dtype)

    def baseline_decode():
        return wrapper.run(ref_query, ref_kv_cache, out=output_baseline)

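    # The TRT-LLM kernel folds dequantization into its two batched matmuls:
    # bmm1_scale combines the q/k dequant scales with the softmax scale, and
    # bmm2_scale combines the v dequant scale with the output quant scale.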
    def trtllm_decode():
        return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
            query=query,
            kv_cache=kv_cache,
            workspace_buffer=workspace_buffer,
            block_tables=block_tables,
            seq_lens=seq_lens,
            max_seq_len=max_seq_len,
            bmm1_scale=q_scale * k_scale * sm_scale,
            bmm2_scale=v_scale / o_scale,
            out=output_trtllm,
        )

    baseline_mean, baseline_std = time_fn(baseline_decode, warmup, trials)
    trtllm_mean, trtllm_std = time_fn(trtllm_decode, warmup, trials)

    # Relative speedup (positive means the TRT-LLM kernel is faster).
    speedup_percent = (baseline_mean - trtllm_mean) / baseline_mean

    print(
        f"\t{batch_size}\t{max_seq_len}\t{trtllm_mean:.3f}\t{trtllm_std.item():.3f}"
        f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
    )

    # Return results for CSV writing
    return {
        "batch_size": batch_size,
        "trtllm_mean": trtllm_mean,
        "trtllm_std": trtllm_std.item(),
        "baseline_mean": baseline_mean,
        "baseline_std": baseline_std.item(),
        "speedup_percent": speedup_percent,
        "q_dtype": str(q_quant_dtype),
        "kv_cache_dtype": str(kv_quant_dtype),
        "output_dtype": str(o_quant_dtype),
        "block_size": block_size,
        "num_kv_heads": num_kv_heads,
        "head_size": head_size,
        "max_seq_len": max_seq_len,
    }


def write_results_to_csv(results, filename=None):
    if filename is None:
        # Default to a timestamped CSV filename (timestamp format assumed;
        # these lines are elided in the diff hunk).
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"

    fieldnames = [
        "batch_size",
        "trtllm_mean",
        "trtllm_std",
        "baseline_mean",
        "baseline_std",
        "speedup_percent",
        "q_dtype",
        "kv_cache_dtype",
        "output_dtype",
        "block_size",
        "num_kv_heads",
        "head_size",
        "max_seq_len",
    ]

    # Append rows, writing the header only for a fresh file (a minimal sketch;
    # the writer body is elided in the diff hunk).
    file_exists = os.path.exists(filename)
    with open(filename, "a", newline="") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if not file_exists:
            writer.writeheader()
        for result in results:
            writer.writerow(result)

    print(f"Results written to {filename}")


if __name__ == "__main__":
    batch_sizes = [1, 4, 8, 16, 32, 64, 128, 256]
    max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
    all_results = []

    dtype = torch.bfloat16
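    # Quantization configs to sweep: full bf16, fp8 KV cache only, and fp8
    # query/KV-cache/output.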
    quant_dtypes = [
        # (q_quant_dtype, kv_quant_dtype, o_quant_dtype)
        (None, None, None),
        (None, FP8_DTYPE, None),
        (FP8_DTYPE, FP8_DTYPE, FP8_DTYPE),
    ]

    for quant_dtype in quant_dtypes:
        q_quant_dtype, kv_quant_dtype, o_quant_dtype = quant_dtype
        q_quant_dtype = q_quant_dtype or dtype
        kv_quant_dtype = kv_quant_dtype or dtype
        o_quant_dtype = o_quant_dtype or dtype

        print(
            f"Running benchmark for q_dtype = {q_quant_dtype}, "
            f"kv_cache_dtype: {kv_quant_dtype}, "
            f"output_dtype: {o_quant_dtype}"
        )
        print(
            "\tbatch_size\tmax_seq_len\ttrtllm_mean\ttrtllm_std\tbaseline_mean\t"
            "baseline_std\tspeedup_percent"
        )
        for max_seq_len in max_seq_lens:
            for bs in batch_sizes:
                result = benchmark_decode(
                    dtype=dtype,
                    quant_dtypes=quant_dtype,
                    batch_size=bs,
                    max_seq_len=max_seq_len,
                )
                all_results.append(result)

    # Write all results to CSV
    write_results_to_csv(all_results)