This topic describes a common workflow to profile workloads on the GPU using Nsight Systems.
As an example, let’s profile the forward, backward, and optimizer.step() methods using the resnet18 model from torchvision.
To annotate each part of the training, we will use nvtx ranges via the torch.cuda.nvtx.range_push/.range_pop operations. These ranges work as a stack and can be nested.
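To illustrate the stack behavior, here is a minimal sketch of a context manager that pairs each push with a pop. The helper name nvtx_range and the fake_push/fake_pop recorder are hypothetical stand-ins so the example runs without a GPU; in a real script you would presumably pass torch.cuda.nvtx.range_push and torch.cuda.nvtx.range_pop instead.

```python
from contextlib import contextmanager


@contextmanager
def nvtx_range(name, push, pop, enabled=True):
    """Hypothetical helper: push a named range on entry, pop it on exit."""
    if enabled:
        push(name)
    try:
        yield
    finally:
        # The pop always matches the most recent push, even on exceptions.
        if enabled:
            pop()


# Stand-in recorder (an assumption for illustration, not part of PyTorch).
stack = []
events = []


def fake_push(name):
    stack.append(name)
    events.append(("push", name, len(stack)))


def fake_pop():
    name = stack.pop()
    events.append(("pop", name, len(stack)))


# Nested ranges: "forward" and "backward" live inside "iteration0".
with nvtx_range("iteration0", fake_push, fake_pop):
    with nvtx_range("forward", fake_push, fake_pop):
        pass  # output = model(data)
    with nvtx_range("backward", fake_push, fake_pop):
        pass  # loss.backward()

for event in events:
    print(event)
```

The third field of each event is the stack depth after the operation, which shows that the inner ranges are closed before the outer iteration range.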
Also, we are usually not interested in the first iterations, which might add overhead to the overall training due to memory allocations, cudnn benchmarking, etc. We therefore start the profiling after a few warmup iterations via torch.cuda.cudart().cudaProfilerStart() and stop it at the end via torch.cuda.cudart().cudaProfilerStop().
A complete code snippet can be seen here:
import torch
import torch.nn as nn
import torchvision.models as models

# setup
device = 'cuda:0'
model = models.resnet18().to(device)
data = torch.randn(64, 3, 224, 224, device=device)
target = torch.randint(0, 1000, (64,), device=device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

nb_iters = 20
warmup_iters = 10
for i in range(nb_iters):
    optimizer.zero_grad()

    # start profiling after 10 warmup iterations
    if i == warmup_iters: torch.cuda.cudart().cudaProfilerStart()

    # push range for current iteration
    if i >= warmup_iters: torch.cuda.nvtx.range_push("iteration{}".format(i))

    # push range for forward
    if i >= warmup_iters: torch.cuda.nvtx.range_push("forward")
    output = model(data)
    if i >= warmup_iters: torch.cuda.nvtx.range_pop()

    loss = criterion(output, target)

    # push range for backward
    if i >= warmup_iters: torch.cuda.nvtx.range_push("backward")
    loss.backward()
    if i >= warmup_iters: torch.cuda.nvtx.range_pop()

    # push range for the optimizer step
    if i >= warmup_iters: torch.cuda.nvtx.range_push("opt.step()")
    optimizer.step()
    if i >= warmup_iters: torch.cuda.nvtx.range_pop()

    # pop iteration range
    if i >= warmup_iters: torch.cuda.nvtx.range_pop()

# stop profiling after the last iteration
torch.cuda.cudart().cudaProfilerStop()
To create the profile I’m using Nsight Systems 2020.4.3.7 via the CLI.
The CLI options for nsys profile can be found here. My “standard” command, which is also the one used to create the profile for this example, is:
nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s cpu --capture-range=cudaProfilerApi --stop-on-range-end=true --cudabacktrace=true -x true -o my_profile python main.py
(Thanks to Michael Carilli for creating this command a while ago.)
The arguments can be found in the linked CLI docs. A few interesting arguments are:
- -t cuda,nvtx,osrt,cudnn,cublas: selects the APIs to be traced.
- --capture-range=cudaProfilerApi and --stop-on-range-end=true: profiling starts only when the cudaProfilerStart API is invoked and stops when the capture range ends.
- --cudabacktrace=true: when tracing CUDA APIs, enables the collection of a backtrace when a CUDA API is invoked (this allows you to hover over a call and get the backtrace). You can also specify a threshold in ns, which defines how long the kernel must execute before backtraces are collected.
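If you launch profiling runs from a script, the command above can be assembled programmatically. The following is only a sketch: build_nsys_command is a hypothetical helper name, it uses just the flags already shown in this post, and it only builds and prints the command rather than executing it.

```python
import shlex


def build_nsys_command(script, output="my_profile",
                       traces=("cuda", "nvtx", "osrt", "cudnn", "cublas")):
    """Hypothetical helper: assemble the nsys command from this post as a list."""
    return [
        "nsys", "profile",
        "-w", "true",
        "-t", ",".join(traces),             # APIs to be traced
        "-s", "cpu",
        "--capture-range=cudaProfilerApi",  # start on cudaProfilerStart
        "--stop-on-range-end=true",         # stop when the capture range ends
        "--cudabacktrace=true",             # collect backtraces for CUDA API calls
        "-x", "true",
        "-o", output,
        "python", script,
    ]


cmd = build_nsys_command("main.py")
cmd_str = shlex.join(cmd)
print(cmd_str)
```

To actually run it, something like subprocess.run(cmd, check=True) should work, assuming nsys is on your PATH.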
Note: if you are used to nvprof, try to copy/paste your nvprof command via:
nsys nvprof [options]
and Nsight Systems will try to translate the legacy nvprof command.
For this example the profile would look like this on a TitanV:
You can see the execution in different Python threads as well as different APIs executing kernels on the device.
We can zoom into a specific iteration and check the backtrace option, which is often helpful to isolate specific regressions and see which functions were invoked.
Feel free to add more useful arguments or any tips and tricks using Nsight Systems.