
Commit 8b15be5

yushangdi authored and facebook-github-bot committed
Add debug handle to inductor provenance tracking (pytorch#161110)
Summary:
Pull Request resolved: pytorch#161110

Use a debug handle on kernel names to distinguish different calls to the same kernel.

Previous kernel name: `kernel_name`
New kernel name: `kernel_name:debug_handle`

We add the debug handle to the tlparse artifacts `inductor_provenance_tracking_node_mappings` and `inductor_provenance_tracking_kernel_stack_traces`. We also add the debug handles in comments of the generated code, so we can map to them in the provenance tracking highlighter tool: pytorch/tlparse#134

Example output code:

```
# Topologically Sorted Source Nodes: [y, z], Original ATen: [aten.addmm, aten.gelu]
# [Provenance debug handles] triton_poi_fused_addmm_gelu_2:3
stream0 = get_raw_stream(0)
triton_poi_fused_addmm_gelu_2.run(buf4, primals_5, 300, stream=stream0)
```

The debug handles will also be used by downstream profilers such as zoomer.

Test Plan:

```
buck run mode/opt fbcode//caffe2/test/inductor:provenance_tracing
```

Rollback Plan:

Reviewed By: angelayi

Differential Revision: D78994959
1 parent a4fb657 commit 8b15be5
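To make the naming scheme in the commit message concrete: each kernel call gets a monotonically increasing handle appended as `kernel_name:debug_handle`. The sketch below is only an illustration under that assumption; `next_kernel_debug_handle` and `label_kernel_call` are hypothetical helpers, and the body of `reset_inductor_kernel_provenance_debug_handle` (a real helper imported in the test diff below) is simplified here.

```python
import itertools

# Hypothetical module-level counter; the real bookkeeping lives in Inductor's
# provenance-tracking utilities and may be implemented differently.
_DEBUG_HANDLE_COUNTER = itertools.count(1)


def next_kernel_debug_handle() -> int:
    """Return the next debug handle (1, 2, 3, ...) for a kernel call."""
    return next(_DEBUG_HANDLE_COUNTER)


def reset_inductor_kernel_provenance_debug_handle() -> None:
    """Restart handle numbering, e.g. between compilations in a test."""
    global _DEBUG_HANDLE_COUNTER
    _DEBUG_HANDLE_COUNTER = itertools.count(1)


def label_kernel_call(kernel_name: str) -> str:
    """Produce the "kernel_name:debug_handle" label used in the tlparse artifacts."""
    return f"{kernel_name}:{next_kernel_debug_handle()}"


# Two calls to the same kernel now get distinct labels.
print(label_kernel_call("triton_poi_fused_addmm_gelu_2"))  # triton_poi_fused_addmm_gelu_2:1
print(label_kernel_call("triton_poi_fused_addmm_gelu_2"))  # triton_poi_fused_addmm_gelu_2:2
```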

File tree

6 files changed (+135, -78 lines)


test/inductor/test_provenance_tracing.py

Lines changed: 62 additions & 58 deletions
```diff
@@ -8,7 +8,6 @@
 import re
 import shutil
 import tempfile
-import unittest
 import zipfile
 from pathlib import Path
 
@@ -19,11 +18,11 @@
     create_kernel_information_json,
     create_mapping_pre_post_grad_nodes,
     create_node_mapping_kernel_to_post_grad,
+    reset_inductor_kernel_provenance_debug_handle,
 )
 from torch._inductor.fx_passes.post_grad import post_grad_passes
 from torch._inductor.test_case import run_tests, TestCase
 from torch._inductor.virtualized import V
-from torch.testing._internal.inductor_utils import HAS_GPU
 from torch.testing._internal.triton_utils import requires_cuda_and_triton
 
 
@@ -94,11 +93,12 @@ class TestProvenanceTracingArtifact(TestCase):
     corresponding "inductor triton kernel node" is expected.
     """
 
-    def _check_provenance_tracing_artifact(self, filepath, expected_data):
+    def _check_provenance_tracing_kernel_to_post_grad(self, filepath, expected_data):
         self.assertTrue(filepath.is_dir())
-        filename = Path(filepath) / "inductor_generated_kernel_to_post_grad_nodes.json"
+        filename = Path(filepath) / "inductor_provenance_tracking_node_mappings.json"
         with open(filename) as f:
             actual_data = json.load(f)
+        actual_data = actual_data["cppCodeToPost"]
         # check that the generated provenance tracing artifact is expected
         self.assertEqual(sorted(actual_data.items()), sorted(expected_data.items()))
 
@@ -116,10 +116,11 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
         c = torch.randn(10, 30, device=device)
         example_inputs = (a, b, c)
 
-        model = Model()
+        model = Model().to(device)
         filepath = None
 
         for backend in ["aot_inductor", "inductor"]:
+            reset_inductor_kernel_provenance_debug_handle()
             try:
                 with config.patch(
                     {
@@ -142,28 +143,12 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
                 self.assertTrue(m)
                 filepath = Path(m.group(1))
                 if device == "cuda":
-                    expected_data = {
-                        "triton_poi_fused_mul_0": ["mul"],
-                        "triton_poi_fused_addmm_gelu_1": [
-                            "mul_3",
-                            "mul_1",
-                            "add_tensor",
-                            "add",
-                            "erf",
-                            "mul_2",
-                        ],
-                    }
-                    if backend == "aot_inductor":
-                        expected_data["aoti_torch_cuda_mm_out"] = ["mm_default"]
-                    else:
-                        expected_data["extern_kernels.mm"] = ["mm_default"]
-                    self._check_provenance_tracing_artifact(filepath, expected_data)
                     expected_mapping = [
                         (
                             "cppCodeToPost",
                             {
-                                "triton_poi_fused_mul_0": ["mul"],
-                                "triton_poi_fused_addmm_gelu_1": [
+                                "triton_poi_fused_mul_0:1": ["mul"],
+                                "triton_poi_fused_addmm_gelu_1:2": [
                                     "mul_3",
                                     "mul_1",
                                     "add_tensor",
@@ -176,13 +161,13 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
                         (
                             "postToCppCode",
                             {
-                                "mul": ["triton_poi_fused_mul_0"],
-                                "mul_3": ["triton_poi_fused_addmm_gelu_1"],
-                                "mul_1": ["triton_poi_fused_addmm_gelu_1"],
-                                "add_tensor": ["triton_poi_fused_addmm_gelu_1"],
-                                "add": ["triton_poi_fused_addmm_gelu_1"],
-                                "erf": ["triton_poi_fused_addmm_gelu_1"],
-                                "mul_2": ["triton_poi_fused_addmm_gelu_1"],
+                                "mul": ["triton_poi_fused_mul_0:1"],
+                                "mul_3": ["triton_poi_fused_addmm_gelu_1:2"],
+                                "mul_1": ["triton_poi_fused_addmm_gelu_1:2"],
+                                "add_tensor": ["triton_poi_fused_addmm_gelu_1:2"],
+                                "add": ["triton_poi_fused_addmm_gelu_1:2"],
+                                "erf": ["triton_poi_fused_addmm_gelu_1:2"],
+                                "mul_2": ["triton_poi_fused_addmm_gelu_1:2"],
                             },
                         ),
                         (
@@ -208,15 +193,19 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
                         ),
                     ]
                     if backend == "aot_inductor":
-                        expected_mapping[0][1]["aoti_torch_cuda_mm_out"] = [
+                        expected_mapping[0][1]["aoti_torch_cuda_mm_out:3"] = [
                             "mm_default"
                         ]
                         expected_mapping[1][1]["mm_default"] = [
-                            "aoti_torch_cuda_mm_out"
+                            "aoti_torch_cuda_mm_out:3"
                         ]
                     else:
-                        expected_mapping[0][1]["extern_kernels.mm"] = ["mm_default"]
-                        expected_mapping[1][1]["mm_default"] = ["extern_kernels.mm"]
+                        expected_mapping[0][1]["extern_kernels.mm:3"] = [
+                            "mm_default"
+                        ]
+                        expected_mapping[1][1]["mm_default"] = [
+                            "extern_kernels.mm:3"
+                        ]
                     self._check_provenance_tracking_node_mappings(
                         filepath, expected_mapping
                     )
@@ -225,9 +214,9 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
                     # check the inductor kernel to post grad nodes mapping is expected for cpu
                     if backend == "aot_inductor":
                         expected_data = {
-                            "cpp_fused_mul_0": ["mul"],
-                            "aoti_torch_cpu_addmm_out": ["addmm"],
-                            "cpp_fused_gelu_1": [
+                            "cpp_fused_mul_0:1": ["mul"],
+                            "aoti_torch_cpu_addmm_out:3": ["addmm"],
+                            "cpp_fused_gelu_1:2": [
                                 "mul_3",
                                 "mul_1",
                                 "add",
@@ -238,17 +227,19 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
                     else:
                         # backend == "inductor"
                         expected_data = {
-                            "cpp_fused_mul_0": ["mul"],
-                            "cpp_fused_gelu_1": [
+                            "cpp_fused_mul_0:1": ["mul"],
+                            "cpp_fused_gelu_1:2": [
                                 "mul_3",
                                 "mul_1",
                                 "add",
                                 "erf",
                                 "mul_2",
                             ],
-                            "extern_kernels.addmm": ["addmm"],
+                            "extern_kernels.addmm:3": ["addmm"],
                         }
-                    self._check_provenance_tracing_artifact(filepath, expected_data)
+                    self._check_provenance_tracing_kernel_to_post_grad(
+                        filepath, expected_data
+                    )
 
             finally:
                 if filepath:
@@ -258,7 +249,6 @@ def _test_triton_kernel_to_post_grad_tracing(self, device):
     def test_triton_kernel_to_post_grad_tracing_cuda(self):
         self._test_triton_kernel_to_post_grad_tracing(device="cuda")
 
-    @unittest.skipIf(HAS_GPU, "the test is only for cpu")
     def test_triton_kernel_to_post_grad_tracing_cpu(self):
         self._test_triton_kernel_to_post_grad_tracing(device="cpu")
 
@@ -274,6 +264,7 @@ def test_triton_kernel_to_post_grad_tracing_extern_kernel(self):
         filepath = None
 
         for backend in ["aot_inductor", "inductor"]:
+            reset_inductor_kernel_provenance_debug_handle()
             try:
                 with config.patch(
                     {
@@ -297,15 +288,17 @@ def test_triton_kernel_to_post_grad_tracing_extern_kernel(self):
                 filepath = Path(m.group(1))
                 if backend == "inductor":
                     expected_data = {
-                        "extern_kernels.addmm": ["addmm"],
+                        "extern_kernels.addmm:1": ["addmm"],
                     }
                 else:
                     # backend = aot_inductor
                     expected_data = {
-                        "aoti_torch_cuda_addmm_out": ["addmm"],
-                        "triton_poi_fused_0": ["_tensor_constant1"],
+                        "aoti_torch_cuda_addmm_out:2": ["addmm"],
+                        "triton_poi_fused_0:1": ["_tensor_constant1"],
                     }
-                self._check_provenance_tracing_artifact(filepath, expected_data)
+                self._check_provenance_tracing_kernel_to_post_grad(
+                    filepath, expected_data
+                )
             finally:
                 if filepath:
                     shutil.rmtree(filepath)
@@ -319,6 +312,7 @@ def _test_pt_tracing_combo_kernel(self, backend):
         example_inputs = (a, b, c)
 
         model = Model2()
+        reset_inductor_kernel_provenance_debug_handle()
 
         with config.patch(
             {
@@ -342,8 +336,8 @@ def _test_pt_tracing_combo_kernel(self, backend):
         m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
         self.assertTrue(m)
         filepath = Path(m.group(1)).resolve()
-        expected_data = {"triton_poi_fused_0": ["relu", "sigmoid", "tanh"]}
-        self._check_provenance_tracing_artifact(filepath, expected_data)
+        expected_data = {"triton_poi_fused_0:1": ["relu", "sigmoid", "tanh"]}
+        self._check_provenance_tracing_kernel_to_post_grad(filepath, expected_data)
 
     @requires_cuda_and_triton
     def test_triton_kernel_to_post_grad_tracing_combo_kernel(self):
@@ -556,25 +550,28 @@ def test_tlparse_kernel_stack_traces(self):
         example_inputs = (x, a, b, c)
 
         expected = {
-            "triton_poi_fused_addmm_relu_sigmoid_threshold_backward_0": [
+            "triton_poi_fused_addmm_relu_sigmoid_threshold_backward_0:1": [
                 "x = self.sigmoid(x)",
                 "x = self.fc1(x)",
                 "x = self.relu(x)",
             ],
-            "triton_poi_fused_mul_1": [
+            "triton_poi_fused_mul_1:2": [
                 "d = a * 3.14",
             ],
-            "triton_poi_fused_addmm_gelu_2": [
+            "triton_poi_fused_addmm_gelu_2:3": [
                 "z = torch.nn.functional.gelu(y)",
                 "y = torch.addmm(c, d, b)",
             ],
-            "extern_kernels.mm": [
+            "extern_kernels.mm:4": [
                 "x = self.fc1(x)",
+            ],
+            "extern_kernels.mm:5": [
                 "y = torch.addmm(c, d, b)",
             ],
         }
 
         with self._setup_provenance_capture() as payload_buffer:
+            reset_inductor_kernel_provenance_debug_handle()
             compiled = torch.compile(model)
             compiled(*example_inputs)
             payload_content = payload_buffer.getvalue().strip()
@@ -623,6 +620,7 @@ def test_kernel_information_generation(self):
         with tempfile.TemporaryDirectory() as temp_dir:
             ep = torch.export.export(model, inputs, strict=False)
             pt2_file = os.path.join(temp_dir, "model.pt2")
+            reset_inductor_kernel_provenance_debug_handle()
             torch._inductor.aoti_compile_and_package(ep, package_path=pt2_file)
 
             # Extract and check kernel_information.json exists in the package
@@ -646,7 +644,7 @@ def test_kernel_information_generation(self):
                 kernel_info = json.load(f)
 
             expected = {
-                "triton_poi_fused_addmm_relu_sigmoid_0": {
+                "triton_poi_fused_addmm_relu_sigmoid_0:1": {
                     "stack_traces": [
                         "x = self.sigmoid(x)",
                         "x = self.fc1(x)",
@@ -655,14 +653,14 @@ def test_kernel_information_generation(self):
                     "post_grad_nodes": ["sigmoid", "relu", "add_tensor_1"],
                     "pre_grad_nodes": ["sigmoid", "relu", "linear"],
                 },
-                "triton_poi_fused_mul_1": {
+                "triton_poi_fused_mul_1:2": {
                     "stack_traces": [
                         "d = a * 3.14",
                     ],
                    "post_grad_nodes": ["mul"],
                    "pre_grad_nodes": ["mul"],
                },
-                "triton_poi_fused_addmm_gelu_2": {
+                "triton_poi_fused_addmm_gelu_2:3": {
                    "stack_traces": [
                        "z = torch.nn.functional.gelu(y)",
                        "y = torch.addmm(c, d, b)",
@@ -677,13 +675,19 @@ def test_kernel_information_generation(self):
                     ],
                     "pre_grad_nodes": ["gelu", "addmm"],
                 },
-                "aoti_torch_cuda_mm_out": {
+                "aoti_torch_cuda_mm_out:4": {
                     "stack_traces": [
                         "x = self.fc1(x)",
+                    ],
+                    "post_grad_nodes": ["mm_default_1"],
+                    "pre_grad_nodes": ["linear"],
+                },
+                "aoti_torch_cuda_mm_out:5": {
+                    "stack_traces": [
                         "y = torch.addmm(c, d, b)",
                     ],
-                    "post_grad_nodes": ["mm_default_1", "mm_default"],
-                    "pre_grad_nodes": ["linear", "addmm"],
+                    "post_grad_nodes": ["mm_default"],
+                    "pre_grad_nodes": ["addmm"],
                 },
             }
 
```
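The renamed helper `_check_provenance_tracing_kernel_to_post_grad` above reads `inductor_provenance_tracking_node_mappings.json` and compares only its `cppCodeToPost` section, whose keys now carry the `:debug_handle` suffix. A standalone sketch of the same check, with an illustrative expected mapping borrowed from the updated test; the function name and `trace_dir` argument are placeholders:

```python
import json
from pathlib import Path


def check_kernel_to_post_grad(trace_dir: str, expected: dict[str, list[str]]) -> None:
    """Compare the cppCodeToPost mapping in a debug-trace directory against
    an expected {"kernel_name:debug_handle": [post-grad node names]} dict."""
    filename = Path(trace_dir) / "inductor_provenance_tracking_node_mappings.json"
    with open(filename) as f:
        actual = json.load(f)["cppCodeToPost"]
    assert sorted(actual.items()) == sorted(expected.items()), (actual, expected)


# Illustrative expectation mirroring the test above; trace_dir would come from
# the "debug trace: ..." line in the compile logs.
expected = {
    "triton_poi_fused_mul_0:1": ["mul"],
    "extern_kernels.mm:3": ["mm_default"],
}
# check_kernel_to_post_grad(trace_dir, expected)
```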
torch/_inductor/codegen/cpp.py

Lines changed: 15 additions & 7 deletions
```diff
@@ -5392,10 +5392,6 @@ def define_kernel(self, src_code, nodes, kernel_args=None):
             else ""
         )
         kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
-        # below add provenance tracing info for cpu CppKernel types
-        if config.trace.provenance_tracking_level != 0:
-            set_kernel_post_grad_provenance_tracing(nodes, kernel_name)
-
         kernel_decl_name = kernel_name if V.graph.cpp_wrapper else "kernel"
         src_code = src_code.replace(str(Placeholder.KERNEL_NAME), kernel_decl_name)
         src_code = src_code.replace(str(Placeholder.DESCRIPTIVE_NAME), kernel_name)
@@ -5434,7 +5430,15 @@ def flush(self):
             kernel_name = self.define_kernel(
                 src_code, self.kernel_group.scheduled_nodes
             )
-            self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name)
+            # below add provenance tracing info for cpu CppKernel types
+            debug_handle: Optional[int] = None
+            if config.trace.provenance_tracking_level != 0:
+                debug_handle = set_kernel_post_grad_provenance_tracing(
+                    self.kernel_group.scheduled_nodes, kernel_name
+                )
+            self.kernel_group.call_kernel(
+                V.graph.wrapper_code, kernel_name, debug_handle=debug_handle
+            )
             self.reset_kernel_group()
             self._set_flush_status(False)
 
@@ -5509,10 +5513,14 @@ def codegen_group(self, name=None) -> str:
         code.splice(self.loops_code)
         return code.getvalue()
 
-    def call_kernel(self, wrapper, kernel_name):
+    def call_kernel(self, wrapper, kernel_name, debug_handle: Optional[int] = None):
         _, call_args, arg_types = self.args.cpp_argdefs()
         wrapper.generate_kernel_call(
-            kernel_name, call_args, triton=False, arg_types=arg_types
+            kernel_name,
+            call_args,
+            triton=False,
+            arg_types=arg_types,
+            debug_handle=debug_handle,
         )
 
 
```
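In prose, the `cpp.py` change moves provenance registration from `define_kernel` into `flush`: each flush records the scheduled nodes under the kernel name, gets a debug handle back, and threads it through `call_kernel` into `generate_kernel_call`. A minimal sketch of that control flow under simplifying assumptions; the stubs below stand in for the real Inductor internals and print instead of emitting code:

```python
from typing import Optional

PROVENANCE_TRACKING = True  # stand-in for config.trace.provenance_tracking_level != 0
_next_handle = 0


def set_kernel_post_grad_provenance_tracing(nodes, kernel_name: str) -> int:
    """Simplified stub: record the nodes for kernel_name and return a debug handle."""
    global _next_handle
    _next_handle += 1
    print(f"provenance: {kernel_name}:{_next_handle} -> {nodes}")
    return _next_handle


def call_kernel(kernel_name: str, debug_handle: Optional[int] = None) -> None:
    """Stub for the wrapper call; mirrors the '# [Provenance debug handles]' comment."""
    suffix = "" if debug_handle is None else f":{debug_handle}"
    print(f"# [Provenance debug handles] {kernel_name}{suffix}")
    print(f"{kernel_name}.run(...)")


def flush(scheduled_nodes, kernel_name: str) -> None:
    # Mirrors the new flush() logic: compute the handle only when tracking is on,
    # then forward it to the kernel call.
    debug_handle: Optional[int] = None
    if PROVENANCE_TRACKING:
        debug_handle = set_kernel_post_grad_provenance_tracing(scheduled_nodes, kernel_name)
    call_kernel(kernel_name, debug_handle=debug_handle)


flush(["mul", "add"], "cpp_fused_mul_0")
```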
torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 8 additions & 2 deletions
```diff
@@ -1219,6 +1219,7 @@ def generate_c_shim_extern_kernel_call(
         device: str,
         *,
         debug_args: Optional[list[str]] = None,
+        debug_handle: Optional[int] = None,
     ) -> None:
         """debug_args kwarg allows CppWrapperCpuArrayRef to pass in wrapped arguments in
         place of args while preserving debug printer output."""
@@ -1235,14 +1236,16 @@ def generate_c_shim_extern_kernel_call(
         ]
         with debug_printer_manager:
             shim_fn = self.get_c_shim_func_name(kernel, device)
+            self.write_provenance_debug_handle(shim_fn, debug_handle)
             shim_fn_codes = (
                 f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"
             )
             if enable_kernel_profile:
+                debug_handle_str = "" if debug_handle is None else f":{debug_handle}"
                 shim_fn_codes = textwrap.dedent(
                     f"""
                     {{
-                        RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}", nullptr);
+                        RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}{debug_handle_str}", nullptr);
                         {shim_fn_codes}
                     }}
                     """
@@ -1338,6 +1341,7 @@ def _generate_extern_kernel_out_helper(
         out_view: Optional[str],
         args: list[str],
         device: str,
+        debug_handle: Optional[int] = None,
     ) -> None:
         if out_view:
             out_name = f"{out}_as_strided"
@@ -1346,7 +1350,9 @@ def _generate_extern_kernel_out_helper(
         else:
             args.insert(0, out)
 
-        self.generate_c_shim_extern_kernel_call(kernel, args, device)
+        self.generate_c_shim_extern_kernel_call(
+            kernel, args, device, debug_handle=debug_handle
+        )
 
     def generate_scatter_fallback(
         self,
```
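On the C shim path, when kernel profiling is enabled, the handle is appended to the record-function name so profilers can distinguish repeated calls to the same shim. A rough sketch of just the string construction; the wrapper function here is hypothetical, while the macro and RAII handle names follow the diff above:

```python
import textwrap
from typing import Optional


def profiled_shim_call(shim_fn: str, args: list[str], debug_handle: Optional[int]) -> str:
    """Build the profiled C++ snippet for an extern-kernel shim call."""
    call = f"AOTI_TORCH_ERROR_CODE_CHECK({shim_fn}({', '.join(args)}));"
    debug_handle_str = "" if debug_handle is None else f":{debug_handle}"
    return textwrap.dedent(
        f"""
        {{
            RAIIAtenRecordFunctionHandle record_{shim_fn}_("{shim_fn}{debug_handle_str}", nullptr);
            {call}
        }}
        """
    )


print(profiled_shim_call("aoti_torch_cuda_mm_out", ["buf0", "arg0", "arg1"], 4))
```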