
Commit 5defbd9

Authored by williamwen42, AlannaBurke, and svekars
Update torch.compile tutorial (#3363)
Update torch.compile tutorial for 2.8/new torch.compile programming model

---------

Co-authored-by: Alanna Burke <[email protected]>
Co-authored-by: Svetlana Karslioglu <[email protected]>
1 parent 2360b06 commit 5defbd9

File tree

4 files changed: +466 -408 lines changed

compilers_index.rst

Lines changed: 8 additions & 0 deletions
@@ -40,6 +40,13 @@ control, as well as third-party backend solutions.
    :link: intermediate/torch_compile_tutorial.html
    :tags: Model-Optimization,torch.compile

+.. customcarditem::
+   :header: torch.compile End-to-End Tutorial
+   :card_description: An example of applying torch.compile to a real model, demonstrating speedups.
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: intermediate/torch_compile_full_example.html
+   :tags: Model-Optimization,torch.compile
+
 .. customcarditem::
    :header: Compiled Autograd: Capturing a larger backward graph for torch.compile
    :card_description: Learn how to use compiled autograd to capture a larger backward graph.

@@ -177,6 +184,7 @@ control, as well as third-party backend solutions.
    :caption: torch.compile

    intermediate/torch_compile_tutorial
+   intermediate/torch_compile_full_example
    intermediate/compiled_autograd_tutorial
    intermediate/inductor_debug_cpu
    recipes/torch_compiler_set_stance_tutorial

index.rst

Lines changed: 7 additions & 0 deletions
@@ -536,6 +536,13 @@ Welcome to PyTorch Tutorials
    :link: intermediate/torch_compile_tutorial.html
    :tags: Model-Optimization

+.. customcarditem::
+   :header: torch.compile End-to-End Tutorial
+   :card_description: An example of applying torch.compile to a real model, demonstrating speedups.
+   :image: _static/img/thumbnails/cropped/generic-pytorch-logo.png
+   :link: intermediate/torch_compile_full_example.html
+   :tags: Model-Optimization
+
 .. customcarditem::
    :header: Building a Convolution/Batch Norm fuser in torch.compile
    :card_description: Build a simple pattern matcher pass that fuses batch norm into convolution to improve performance during inference.
torch_compile_full_example.py

Lines changed: 240 additions & 0 deletions
@@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-

"""
``torch.compile`` End-to-End Tutorial
======================================
**Author:** William Wen
"""

######################################################################
# ``torch.compile`` is the new way to speed up your PyTorch code!
# ``torch.compile`` makes PyTorch code run faster by
# JIT-compiling PyTorch code into optimized kernels,
# while requiring minimal code changes.
#
# This tutorial covers an end-to-end example of training and evaluating a
# real model with ``torch.compile``. For a gentle introduction to ``torch.compile``,
# please check out `the introduction to torch.compile tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.
#
# **Required pip Dependencies**
#
# - ``torch >= 2.0``
# - ``torchvision``
#
# .. grid:: 2
#
#    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
#       :class-card: card-prerequisites
#
#       * How to apply ``torch.compile`` to a real model
#       * ``torch.compile`` speedups on a real model
#       * ``torch.compile``'s first few iterations are expected to be slower due to compilation overhead
#
#    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
#       :class-card: card-prerequisites
#
#       * `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__

# NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in
# order to reproduce the speedup numbers shown below and documented elsewhere.

import torch
import warnings

gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    if device_cap in ((7, 0), (8, 0), (9, 0)):
        gpu_ok = True

if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )


######################################################################
# Let's demonstrate how using ``torch.compile`` can speed up a real model.
# We will compare standard eager mode and
# ``torch.compile`` by evaluating and training a ``torchvision`` model on random data.
#
# Before we start, we need to define some utility functions.


# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )


N_ITERS = 10

from torchvision.models import densenet121


def init_model():
    return densenet121().cuda()

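######################################################################
# As a quick, optional sanity check (a minimal sketch, not part of the
# measurements below; the ``sanity_*`` names are illustrative only), the
# utilities above can be exercised on a single eager-mode forward pass:

sanity_model = init_model()
sanity_inp, _ = generate_data(4)
with torch.no_grad():
    # time one uncompiled forward pass using the CUDA-event-based timer
    _, sanity_time = timed(lambda: sanity_model(sanity_inp))
print(f"(sanity check) single eager forward pass: {sanity_time:.4f} s")
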
######################################################################
# First, let's compare inference.
#
# Note that in the compile call, we pass the additional
# ``mode`` argument, which we will discuss below.

model = init_model()

# Note that we generally recommend directly compiling a torch.nn.Module by calling
# its .compile() method.
model_opt = init_model()
model_opt.compile(mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])

######################################################################
# Notice that ``torch.compile`` takes a lot longer to complete
# compared to eager. This is because ``torch.compile`` compiles
# the model into optimized kernels as it executes. In our example, the
# structure of the model doesn't change, and so recompilation is not
# needed. So if we run our optimized model several more times, we should
# see a significant improvement compared to eager.

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)

######################################################################
# And indeed, we can see that running our model with ``torch.compile``
# results in a significant speedup. Speedup mainly comes from reducing Python overhead and
# GPU read/writes, and so the observed speedup may vary depending on factors such as model
# architecture and batch size. For example, if a model's architecture is simple
# and the amount of data is large, then the bottleneck would be
# GPU compute and the observed speedup may be less significant.
#
# You may also see different speedup results depending on the chosen ``mode``
# argument. The ``"reduce-overhead"`` mode uses CUDA graphs to further reduce
# the overhead of Python. For your own models,
# you may need to experiment with different modes to maximize speedup. You can
# read more about modes `here <https://pytorch.org/get-started/pytorch-2.0/#user-experience>`__.
#
# You might also notice that the second time we run our model with ``torch.compile`` is significantly
# slower than the other runs, although it is much faster than the first run. This is because the ``"reduce-overhead"``
# mode runs a few warm-up iterations for CUDA graphs.
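
######################################################################
# As an optional, hedged illustration (a sketch, not part of the measurements in this
# tutorial; ``model_autotune`` is an illustrative name): other ``mode`` values can be
# tried in exactly the same way. For example, ``"max-autotune"`` spends more time
# compiling in order to search for faster kernels; whether it ends up faster than
# ``"reduce-overhead"`` depends on the model and hardware.

model_autotune = init_model()
model_autotune.compile(mode="max-autotune")
with torch.no_grad():
    # the first call triggers the (longer) compilation; the second call is the interesting one
    timed(lambda: model_autotune(generate_data(16)[0]))
    print("max-autotune eval:", timed(lambda: model_autotune(generate_data(16)[0]))[1])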

######################################################################
# Now, let's compare training.

model = init_model()
opt = torch.optim.Adam(model.parameters())


def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()


eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

model = init_model()
opt = torch.optim.Adam(model.parameters())

# Note that because we are compiling a regular Python function, we do not
# call any .compile() method.
train_opt = torch.compile(train, mode="reduce-overhead")

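######################################################################
# As an aside (a hedged sketch, not used in the timings below; ``train_decorated``
# is an illustrative name): ``torch.compile`` can also be applied to a regular
# Python function as a decorator, which is equivalent to wrapping the function as above.

@torch.compile(mode="reduce-overhead")
def train_decorated(mod, data):
    # same body as ``train`` above; compiled because of the decorator
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()
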
compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)

######################################################################
# Again, we can see that ``torch.compile`` takes longer in the first
# iteration, as it must compile the model, but in subsequent iterations, we see
# significant speedups compared to eager.
#
# We remark that the speedup numbers presented in this tutorial are for
# demonstration purposes only. Official speedup values can be seen at the
# `TorchInductor performance dashboard <https://hud.pytorch.org/benchmark/compilers>`__.

######################################################################
# Conclusion
# ------------
#
# In this tutorial, we applied ``torch.compile`` to training and inference on a real model,
# demonstrating speedups.
#
# Importantly, we note that the first few iterations of a compiled model
# are slower than eager mode due to compilation overhead, but subsequent iterations
# are expected to be faster.
#
# For a gentle introduction to ``torch.compile``, please check out `the introduction to torch.compile tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.
#
# To troubleshoot issues and to gain a deeper understanding of how to apply ``torch.compile`` to your code, check out `the torch.compile programming model <https://docs.pytorch.org/docs/main/compile/programming_model.html>`__.
#
# We hope that you will give ``torch.compile`` a try!
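
######################################################################
# As a final, hedged aside building on the troubleshooting pointer above (a minimal
# sketch, assuming ``torch._dynamo.explain`` is available in your PyTorch build):
# it reports how many graphs were captured and why any graph breaks occurred, which
# is a useful first diagnostic when a compiled function is slower than expected.

explanation = torch._dynamo.explain(train)(model, generate_data(16))
# Printing the result summarizes graph count, graph break count, and break reasons.
print(explanation)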
