import re
import shutil
import tempfile
-import unittest
import zipfile
from pathlib import Path

    create_kernel_information_json,
    create_mapping_pre_post_grad_nodes,
    create_node_mapping_kernel_to_post_grad,
+    reset_inductor_kernel_provenance_debug_handle,
)
from torch._inductor.fx_passes.post_grad import post_grad_passes
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.virtualized import V
-from torch.testing._internal.inductor_utils import HAS_GPU
from torch.testing._internal.triton_utils import requires_cuda_and_triton

@@ -94,11 +93,12 @@ class TestProvenanceTracingArtifact(TestCase):
        corresponding "inductor triton kernel node" is expected.
        """

-    def _check_provenance_tracing_artifact(self, filepath, expected_data):
+    def _check_provenance_tracing_kernel_to_post_grad(self, filepath, expected_data):
        self.assertTrue(filepath.is_dir())
-        filename = Path(filepath) / "inductor_generated_kernel_to_post_grad_nodes.json"
+        filename = Path(filepath) / "inductor_provenance_tracking_node_mappings.json"
        with open(filename) as f:
            actual_data = json.load(f)
+        actual_data = actual_data["cppCodeToPost"]
        # check that the generated provenance tracing artifact is expected
        self.assertEqual(sorted(actual_data.items()), sorted(expected_data.items()))

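The renamed helper now reads the consolidated provenance file instead of the old per-artifact inductor_generated_kernel_to_post_grad_nodes.json. A minimal sketch of the consumer side, assuming only what the test itself asserts (the file name and the "cppCodeToPost" field appear in the diff; the loader function is hypothetical):

    import json
    from pathlib import Path

    def load_kernel_to_post_grad(debug_dir):
        # The debug trace directory now holds one combined mappings file.
        path = Path(debug_dir) / "inductor_provenance_tracking_node_mappings.json"
        with open(path) as f:
            mappings = json.load(f)
        # The kernel-to-post-grad view lives under the "cppCodeToPost" key;
        # other views (e.g. "postToCppCode") sit beside it in the same file.
        return mappings["cppCodeToPost"]
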
@@ -116,10 +116,11 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
        c = torch.randn(10, 30, device=device)
        example_inputs = (a, b, c)

-        model = Model()
+        model = Model().to(device)
        filepath = None

        for backend in ["aot_inductor", "inductor"]:
+            reset_inductor_kernel_provenance_debug_handle()
            try:
                with config.patch(
                    {
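The reset at the top of the backend loop matters because the expected dicts below restart the handle numbering at 1 for each backend. A hedged illustration of the handle scheme, assuming it is a simple global counter — consistent with the 1, 2, 3 numbering these tests assert, but an assumption about the implementation:

    # Assumed counter semantics; not the real implementation.
    _handle = 0

    def _next_debug_handle(kernel_name: str) -> str:
        global _handle
        _handle += 1
        return f"{kernel_name}:{_handle}"

    _next_debug_handle("triton_poi_fused_mul_0")         # "triton_poi_fused_mul_0:1"
    _next_debug_handle("triton_poi_fused_addmm_gelu_1")  # "triton_poi_fused_addmm_gelu_1:2"
    # Without reset_inductor_kernel_provenance_debug_handle() between backends,
    # the second compilation would continue at ":3" and the assertions would fail.
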
@@ -142,28 +143,12 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
                    self.assertTrue(m)
                    filepath = Path(m.group(1))
                    if device == "cuda":
-                        expected_data = {
-                            "triton_poi_fused_mul_0": ["mul"],
-                            "triton_poi_fused_addmm_gelu_1": [
-                                "mul_3",
-                                "mul_1",
-                                "add_tensor",
-                                "add",
-                                "erf",
-                                "mul_2",
-                            ],
-                        }
-                        if backend == "aot_inductor":
-                            expected_data["aoti_torch_cuda_mm_out"] = ["mm_default"]
-                        else:
-                            expected_data["extern_kernels.mm"] = ["mm_default"]
-                        self._check_provenance_tracing_artifact(filepath, expected_data)
                        expected_mapping = [
                            (
                                "cppCodeToPost",
                                {
-                                    "triton_poi_fused_mul_0": ["mul"],
-                                    "triton_poi_fused_addmm_gelu_1": [
+                                    "triton_poi_fused_mul_0:1": ["mul"],
+                                    "triton_poi_fused_addmm_gelu_1:2": [
                                        "mul_3",
                                        "mul_1",
                                        "add_tensor",
@@ -176,13 +161,13 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
                            (
                                "postToCppCode",
                                {
-                                    "mul": ["triton_poi_fused_mul_0"],
-                                    "mul_3": ["triton_poi_fused_addmm_gelu_1"],
-                                    "mul_1": ["triton_poi_fused_addmm_gelu_1"],
-                                    "add_tensor": ["triton_poi_fused_addmm_gelu_1"],
-                                    "add": ["triton_poi_fused_addmm_gelu_1"],
-                                    "erf": ["triton_poi_fused_addmm_gelu_1"],
-                                    "mul_2": ["triton_poi_fused_addmm_gelu_1"],
+                                    "mul": ["triton_poi_fused_mul_0:1"],
+                                    "mul_3": ["triton_poi_fused_addmm_gelu_1:2"],
+                                    "mul_1": ["triton_poi_fused_addmm_gelu_1:2"],
+                                    "add_tensor": ["triton_poi_fused_addmm_gelu_1:2"],
+                                    "add": ["triton_poi_fused_addmm_gelu_1:2"],
+                                    "erf": ["triton_poi_fused_addmm_gelu_1:2"],
+                                    "mul_2": ["triton_poi_fused_addmm_gelu_1:2"],
                                },
                            ),
                            (
@@ -208,15 +193,19 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
                            ),
                        ]
                        if backend == "aot_inductor":
-                            expected_mapping[0][1]["aoti_torch_cuda_mm_out"] = [
+                            expected_mapping[0][1]["aoti_torch_cuda_mm_out:3"] = [
                                "mm_default"
                            ]
                            expected_mapping[1][1]["mm_default"] = [
-                                "aoti_torch_cuda_mm_out"
+                                "aoti_torch_cuda_mm_out:3"
                            ]
                        else:
-                            expected_mapping[0][1]["extern_kernels.mm"] = ["mm_default"]
-                            expected_mapping[1][1]["mm_default"] = ["extern_kernels.mm"]
+                            expected_mapping[0][1]["extern_kernels.mm:3"] = [
+                                "mm_default"
+                            ]
+                            expected_mapping[1][1]["mm_default"] = [
+                                "extern_kernels.mm:3"
+                            ]
                        self._check_provenance_tracking_node_mappings(
                            filepath, expected_mapping
                        )
@@ -225,9 +214,9 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
                        # check the inductor kernel to post grad nodes mapping is expected for cpu
                        if backend == "aot_inductor":
                            expected_data = {
-                                "cpp_fused_mul_0": ["mul"],
-                                "aoti_torch_cpu_addmm_out": ["addmm"],
-                                "cpp_fused_gelu_1": [
+                                "cpp_fused_mul_0:1": ["mul"],
+                                "aoti_torch_cpu_addmm_out:3": ["addmm"],
+                                "cpp_fused_gelu_1:2": [
                                    "mul_3",
                                    "mul_1",
                                    "add",
@@ -238,17 +227,19 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
                        else:
                            # backend == "inductor"
                            expected_data = {
-                                "cpp_fused_mul_0": ["mul"],
-                                "cpp_fused_gelu_1": [
+                                "cpp_fused_mul_0:1": ["mul"],
+                                "cpp_fused_gelu_1:2": [
                                    "mul_3",
                                    "mul_1",
                                    "add",
                                    "erf",
                                    "mul_2",
                                ],
-                                "extern_kernels.addmm": ["addmm"],
+                                "extern_kernels.addmm:3": ["addmm"],
                            }
-                        self._check_provenance_tracing_artifact(filepath, expected_data)
+                        self._check_provenance_tracing_kernel_to_post_grad(
+                            filepath, expected_data
+                        )

            finally:
                if filepath:
@@ -258,7 +249,6 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
    def test_triton_kernel_to_post_grad_tracing_cuda(self):
        self._test_triton_kernel_to_post_grad_tracing(device="cuda")

-    @unittest.skipIf(HAS_GPU, "the test is only for cpu")
    def test_triton_kernel_to_post_grad_tracing_cpu(self):
        self._test_triton_kernel_to_post_grad_tracing(device="cpu")

@@ -274,6 +264,7 @@ def test_triton_kernel_to_post_grad_tracing_extern_kernel(self):
        filepath = None

        for backend in ["aot_inductor", "inductor"]:
+            reset_inductor_kernel_provenance_debug_handle()
            try:
                with config.patch(
                    {
@@ -297,15 +288,17 @@ def test_triton_kernel_to_post_grad_tracing_extern_kernel(self):
                    filepath = Path(m.group(1))
                    if backend == "inductor":
                        expected_data = {
-                            "extern_kernels.addmm": ["addmm"],
+                            "extern_kernels.addmm:1": ["addmm"],
                        }
                    else:
                        # backend = aot_inductor
                        expected_data = {
-                            "aoti_torch_cuda_addmm_out": ["addmm"],
-                            "triton_poi_fused_0": ["_tensor_constant1"],
+                            "aoti_torch_cuda_addmm_out:2": ["addmm"],
+                            "triton_poi_fused_0:1": ["_tensor_constant1"],
                        }
-                    self._check_provenance_tracing_artifact(filepath, expected_data)
+                    self._check_provenance_tracing_kernel_to_post_grad(
+                        filepath, expected_data
+                    )
            finally:
                if filepath:
                    shutil.rmtree(filepath)
@@ -319,6 +312,7 @@ def _test_pt_tracing_combo_kernel(self, backend):
        example_inputs = (a, b, c)

        model = Model2()
+        reset_inductor_kernel_provenance_debug_handle()

        with config.patch(
            {
@@ -342,8 +336,8 @@ def _test_pt_tracing_combo_kernel(self, backend):
            m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
            self.assertTrue(m)
            filepath = Path(m.group(1)).resolve()
-            expected_data = {"triton_poi_fused_0": ["relu", "sigmoid", "tanh"]}
-            self._check_provenance_tracing_artifact(filepath, expected_data)
+            expected_data = {"triton_poi_fused_0:1": ["relu", "sigmoid", "tanh"]}
+            self._check_provenance_tracing_kernel_to_post_grad(filepath, expected_data)

    @requires_cuda_and_triton
    def test_triton_kernel_to_post_grad_tracing_combo_kernel(self):
@@ -556,25 +550,28 @@ def test_tlparse_kernel_stack_traces(self):
        example_inputs = (x, a, b, c)

        expected = {
-            "triton_poi_fused_addmm_relu_sigmoid_threshold_backward_0": [
+            "triton_poi_fused_addmm_relu_sigmoid_threshold_backward_0:1": [
                "x = self.sigmoid(x)",
                "x = self.fc1(x)",
                "x = self.relu(x)",
            ],
-            "triton_poi_fused_mul_1": [
+            "triton_poi_fused_mul_1:2": [
                "d = a * 3.14",
            ],
-            "triton_poi_fused_addmm_gelu_2": [
+            "triton_poi_fused_addmm_gelu_2:3": [
                "z = torch.nn.functional.gelu(y)",
                "y = torch.addmm(c, d, b)",
            ],
-            "extern_kernels.mm": [
+            "extern_kernels.mm:4": [
                "x = self.fc1(x)",
+            ],
+            "extern_kernels.mm:5": [
                "y = torch.addmm(c, d, b)",
            ],
        }

        with self._setup_provenance_capture() as payload_buffer:
+            reset_inductor_kernel_provenance_debug_handle()
            compiled = torch.compile(model)
            compiled(*example_inputs)
            payload_content = payload_buffer.getvalue().strip()
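Note the payoff of the handle suffix in this expected dict: the two call sites that both lower to extern_kernels.mm no longer collapse into a single entry, so ":4" keeps "x = self.fc1(x)" separate from ":5"'s "y = torch.addmm(c, d, b)". A tiny sketch of splitting a key back into kernel name and handle — the rsplit convention is inferred from the key format, not taken from an API in the diff:

    key = "extern_kernels.mm:5"
    kernel_name, handle = key.rsplit(":", 1)
    assert kernel_name == "extern_kernels.mm" and handle == "5"
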
@@ -623,6 +620,7 @@ def test_kernel_information_generation(self):
        with tempfile.TemporaryDirectory() as temp_dir:
            ep = torch.export.export(model, inputs, strict=False)
            pt2_file = os.path.join(temp_dir, "model.pt2")
+            reset_inductor_kernel_provenance_debug_handle()
            torch._inductor.aoti_compile_and_package(ep, package_path=pt2_file)

            # Extract and check kernel_information.json exists in the package
@@ -646,7 +644,7 @@ def test_kernel_information_generation(self):
                kernel_info = json.load(f)

            expected = {
-                "triton_poi_fused_addmm_relu_sigmoid_0": {
+                "triton_poi_fused_addmm_relu_sigmoid_0:1": {
                    "stack_traces": [
                        "x = self.sigmoid(x)",
                        "x = self.fc1(x)",
@@ -655,14 +653,14 @@ def test_kernel_information_generation(self):
                    "post_grad_nodes": ["sigmoid", "relu", "add_tensor_1"],
                    "pre_grad_nodes": ["sigmoid", "relu", "linear"],
                },
-                "triton_poi_fused_mul_1": {
+                "triton_poi_fused_mul_1:2": {
                    "stack_traces": [
                        "d = a * 3.14",
                    ],
                    "post_grad_nodes": ["mul"],
                    "pre_grad_nodes": ["mul"],
                },
-                "triton_poi_fused_addmm_gelu_2": {
+                "triton_poi_fused_addmm_gelu_2:3": {
                    "stack_traces": [
                        "z = torch.nn.functional.gelu(y)",
                        "y = torch.addmm(c, d, b)",
@@ -677,13 +675,19 @@ def test_kernel_information_generation(self):
                    ],
                    "pre_grad_nodes": ["gelu", "addmm"],
                },
-                "aoti_torch_cuda_mm_out": {
+                "aoti_torch_cuda_mm_out:4": {
                    "stack_traces": [
                        "x = self.fc1(x)",
+                    ],
+                    "post_grad_nodes": ["mm_default_1"],
+                    "pre_grad_nodes": ["linear"],
+                },
+                "aoti_torch_cuda_mm_out:5": {
+                    "stack_traces": [
                        "y = torch.addmm(c, d, b)",
                    ],
-                    "post_grad_nodes": ["mm_default_1", "mm_default"],
-                    "pre_grad_nodes": ["linear", "addmm"],
+                    "post_grad_nodes": ["mm_default"],
+                    "pre_grad_nodes": ["addmm"],
                },
            }

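The same de-duplication shows up in kernel_information.json: aoti_torch_cuda_mm_out previously carried the merged lists ["mm_default_1", "mm_default"] and ["linear", "addmm"], while ":4" and ":5" now each own exactly one stack trace and one node pair. A hedged sketch of a consumer-side helper for the new key format — the helper name is hypothetical; only the "name:integer" shape is taken from the expected data above:

    def strip_debug_handle(key: str) -> tuple[str, int]:
        # "aoti_torch_cuda_mm_out:5" -> ("aoti_torch_cuda_mm_out", 5)
        name, _, handle = key.rpartition(":")
        return name, int(handle)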