Using Nsight Systems to profile GPU workload

This topic describes a common workflow to profile workloads on the GPU using Nsight Systems.

As an example, let’s profile the forward, backward, and optimizer.step() methods using the resnet18 model from torchvision.

To annotate each part of the training we will use nvtx ranges via the torch.cuda.nvtx.range_push/.range_pop operations. These ranges work as a stack and can be nested.
Also, we are usually not interested in the first iterations, which might add overhead to the overall training due to memory allocations, cudnn benchmarking, etc. We therefore start profiling after a few warmup iterations via torch.cuda.cudart().cudaProfilerStart() and stop it at the end via .cudaProfilerStop().
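
If it helps, here is a minimal sketch (not from the original post) of how the range stack behaves; the labels are arbitrary:

import torch

torch.cuda.nvtx.range_push("outer")   # opens the "outer" range
torch.cuda.nvtx.range_push("inner")   # nested inside "outer"
# ... some work ...
torch.cuda.nvtx.range_pop()           # closes "inner"
torch.cuda.nvtx.range_pop()           # closes "outer"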

A complete code snippet can be seen here:

import torch
import torch.nn as nn
import torchvision.models as models

# setup
device = 'cuda:0'
model = models.resnet18().to(device)
data = torch.randn(64, 3, 224, 224, device=device)
target = torch.randint(0, 1000, (64,), device=device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

nb_iters = 20
warmup_iters = 10
for i in range(nb_iters):
    optimizer.zero_grad()

    # start profiling after 10 warmup iterations
    if i == warmup_iters: torch.cuda.cudart().cudaProfilerStart()

    # push range for current iteration
    if i >= warmup_iters: torch.cuda.nvtx.range_push("iteration{}".format(i))

    # push range for forward
    if i >= warmup_iters: torch.cuda.nvtx.range_push("forward")
    output = model(data)
    if i >= warmup_iters: torch.cuda.nvtx.range_pop()

    loss = criterion(output, target)

    if i >= warmup_iters: torch.cuda.nvtx.range_push("backward")
    loss.backward()
    if i >= warmup_iters: torch.cuda.nvtx.range_pop()

    if i >= warmup_iters: torch.cuda.nvtx.range_push("opt.step()")
    optimizer.step()
    if i >= warmup_iters: torch.cuda.nvtx.range_pop()

    # pop iteration range
    if i >= warmup_iters: torch.cuda.nvtx.range_pop()

torch.cuda.cudart().cudaProfilerStop()

To create the profile I’m using Nsight Systems 2020.4.3.7 via the CLI.
The CLI options for nsys profile can be found here. My “standard” command, which is also the one used to create the profile for this example, is:

nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s cpu  --capture-range=cudaProfilerApi --stop-on-range-end=true --cudabacktrace=true -x true -o my_profile python main.py

(Thanks to Michael Carilli for creating this command a while ago :wink: )

The arguments can be found in the linked CLI docs. A few interesting arguments are:

  • -t cuda,nvtx,osrt,cudnn,cublas: selects the APIs to be traced
  • --capture-range=cudaProfilerApi and --stop-on-range-end=true: profiling starts only once the cudaProfilerStart API is invoked and stops when the capture range ends, i.e. when cudaProfilerStop is called (see the short wrapper sketch after this list)
  • --cudabacktrace=true: when tracing CUDA APIs, collect a backtrace each time a CUDA API is invoked; this lets you hover over a call in the timeline and see where it originated. You can also specify a threshold in ns, which defines how long a call must run before its backtrace is collected.
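
As a side note (not from the original post): instead of calling the cudart API directly, the thin wrappers in torch.cuda.profiler should work as well, since they call cudaProfilerStart/cudaProfilerStop under the hood. A minimal sketch:

import torch

# torch.cuda.profiler.start()/stop() wrap cudaProfilerStart/cudaProfilerStop,
# so Nsight Systems launched with --capture-range=cudaProfilerApi will begin
# and end its capture at these calls.
torch.cuda.profiler.start()
# ... the iterations you want to capture ...
torch.cuda.profiler.stop()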

Note: if you are used to nvprof, you can try copying and pasting your nvprof command via:

nsys nvprof [options]

and Nsight Systems will try to translate the legacy nvprof command.

For this example the profile would look like this on a Titan V:

You can see the execution in different Python threads as well as different APIs executing kernels on the device.
We can zoom into a specific iteration and check the backtrace option, which is often helpful to isolate specific regressions and see which functions were invoked.

Feel free to add more useful arguments or any tips and tricks using Nsight Systems.

I have a gist with my preferred nsys commands for different scenarios and an explanation of each option.

There’s a gotcha to be aware of with Piotr’s command line:

nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s cpu  --capture-range=cudaProfilerApi --stop-on-range-end=true --cudabacktrace=true -x true -o my_profile python main.py

CPU sampling (-s cpu) is great for getting backtraces that show where particular timeline calls originate in the code, but it also inflates CPU overhead (sometimes dramatically, 2X or more). So with -s cpu, you shouldn’t expect a realistic view of CPU whitespace.

To get a better idea of bottlenecks, you should first create a profile without CPU sampling (-s none), e.g.:

nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s none -o nsight_report -f true -x true python script.py args...

Manual torch.cuda.nvtx.range_push/pop calls in your script are very helpful to orient yourself and immediately see where your code spends an unexpected amount of time (forward? backward? optimizer? between iterations, which usually means dataloader?).
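
For example, here is a rough sketch of separating the dataloader wait from the actual step (loader, model, criterion, and optimizer are placeholders, not part of the original snippet):

import torch

# Sketch only: "loader", "model", "criterion" and "optimizer" are assumed to exist.
loader_iter = iter(loader)
for i in range(len(loader)):
    torch.cuda.nvtx.range_push("data loading")
    data, target = next(loader_iter)   # time spent waiting for the dataloader
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("step")
    output = model(data.to('cuda'))
    loss = criterion(output, target.to('cuda'))
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    torch.cuda.nvtx.range_pop()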

Hi,

Is this supposed to work with multi-GPU scripts? I am using DataParallel in my case.

Right now I get the following output:

The application terminated before the collection started. No report was generated.
	Collection canceled.

Thanks in advance