# -*- coding: utf-8 -*-

"""
``torch.compile`` End-to-End Tutorial
======================================
**Author:** William Wen
"""

######################################################################
# ``torch.compile`` is the new way to speed up your PyTorch code!
# ``torch.compile`` makes PyTorch code run faster by
# JIT-compiling PyTorch code into optimized kernels,
# while requiring minimal code changes.
#
# This tutorial covers an end-to-end example of training and evaluating a
# real model with ``torch.compile``. For a gentle introduction to ``torch.compile``,
# please check out `the introduction to torch.compile tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.
#
# **Required pip Dependencies**
#
# - ``torch >= 2.0``
# - ``torchvision``
#
# .. grid:: 2
#
#     .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
#        :class-card: card-prerequisites
#
#        * How to apply ``torch.compile`` to a real model
#        * ``torch.compile`` speedups on a real model
#        * ``torch.compile``'s first few iterations are expected to be slower due to compilation overhead
#
#     .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
#        :class-card: card-prerequisites
#
#        * `Introduction to torch.compile <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__

# NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in
# order to reproduce the speedup numbers shown below and documented elsewhere.

import torch
import warnings

gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    if device_cap in ((7, 0), (8, 0), (9, 0)):
        gpu_ok = True

if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )


######################################################################
# Let's demonstrate how using ``torch.compile`` can speed up a real model.
# We will compare standard eager mode and
# ``torch.compile`` by evaluating and training a ``torchvision`` model on random data.
#
# Before we start, we need to define some utility functions.


# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )


N_ITERS = 10

from torchvision.models import densenet121


def init_model():
    return densenet121().cuda()


######################################################################
# First, let's compare inference.
#
# Note that when we compile the model, we pass the additional
# ``mode`` argument, which we will discuss shortly.

model = init_model()

# Note that we generally recommend directly compiling a torch.nn.Module by calling
# its .compile() method.
model_opt = init_model()
model_opt.compile(mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])

######################################################################
# Notice that ``torch.compile`` takes a lot longer to complete its first run
# compared to eager. This is because ``torch.compile`` compiles
# the model into optimized kernels as it executes. In our example, the
# structure of the model doesn't change, so recompilation is not
# needed. If we run our optimized model several more times, we should
# see a significant improvement compared to eager.

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)

######################################################################
# And indeed, we can see that running our model with ``torch.compile``
# results in a significant speedup. Speedup mainly comes from reducing Python overhead and
# GPU read/writes, so the observed speedup may vary depending on factors such as model
# architecture and batch size. For example, if a model's architecture is simple
# and the amount of data is large, then the bottleneck would be
# GPU compute and the observed speedup may be less significant. The sketch below
# illustrates this by timing the same models at a larger batch size.
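#
# The following is a minimal, illustrative sketch (not part of the benchmark
# above). It reuses the ``model``, ``model_opt``, ``generate_data``, and ``timed``
# helpers defined earlier and assumes enough GPU memory for a batch of 128; reduce
# the batch size if needed. The compiled model recompiles once for the new input
# shape, so we run a few warm-up calls before timing. Exact numbers will vary by GPU.

big_batch = generate_data(128)[0]
with torch.no_grad():
    # Warm-up calls trigger recompilation for the new shape and CUDA graph capture.
    for _ in range(3):
        model_opt(big_batch)
    _, eager_big = timed(lambda: model(big_batch))
    _, compile_big = timed(lambda: model_opt(big_batch))
print(f"(batch 128) eager: {eager_big}, compile: {compile_big}")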

######################################################################
# You may also see different speedup results depending on the chosen ``mode``
# argument. The ``"reduce-overhead"`` mode uses CUDA graphs to further reduce
# the overhead of Python. For your own models,
# you may need to experiment with different modes to maximize speedup. You can
# read more about modes `here <https://pytorch.org/get-started/pytorch-2.0/#user-experience>`__.
# A different mode is tried in the sketch below.
#
# You might also notice that the second time we run our model with ``torch.compile``,
# it is significantly slower than the later runs, although it is much faster than the
# first run. This is because the ``"reduce-overhead"``
# mode runs a few warm-up iterations for CUDA graphs.
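#
# Below is a minimal sketch of trying the ``"max-autotune"`` mode, which spends extra
# compile time searching for faster kernels. This is illustrative only: compile time
# and speedup depend heavily on the model and GPU, and the first call is expected to
# be slow while compilation runs.

model_autotune = init_model()
model_autotune.compile(mode="max-autotune")
inp = generate_data(16)[0]
with torch.no_grad():
    model_autotune(inp)  # first call compiles the model (slow)
    print("max-autotune:", timed(lambda: model_autotune(inp))[1])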

######################################################################
# Now, let's compare training.

model = init_model()
opt = torch.optim.Adam(model.parameters())


# Performs a single training step: forward pass, loss, backward pass, optimizer update.
def train(mod, data):
    opt.zero_grad(True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()


eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

model = init_model()
opt = torch.optim.Adam(model.parameters())

# Note that because we are compiling a regular Python function rather than an
# nn.Module, we call torch.compile directly instead of a .compile() method.
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)

######################################################################
# Again, we can see that ``torch.compile`` takes longer in the first
# iteration, as it must compile the model, but in subsequent iterations, we see
# significant speedups compared to eager.
#
# We remark that the speedup numbers presented in this tutorial are for
# demonstration purposes only. Official speedup values can be seen at the
# `TorchInductor performance dashboard <https://hud.pytorch.org/benchmark/compilers>`__.

######################################################################
# Conclusion
# ------------
#
# In this tutorial, we applied ``torch.compile`` to training and inference on a real model,
# demonstrating speedups.
#
# Importantly, we note that the first few iterations of a compiled model
# are slower than eager mode due to compilation overhead, but subsequent
# iterations are expected to be faster.
#
# For a gentle introduction to ``torch.compile``, please check out `the introduction to torch.compile tutorial <https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html>`__.
#
# To troubleshoot issues and to gain a deeper understanding of how to apply ``torch.compile`` to your code, check out `the torch.compile programming model <https://docs.pytorch.org/docs/main/compile/programming_model.html>`__.
#
# We hope that you will give ``torch.compile`` a try!