Using Nsight Systems to profile GPU workload

This topic describes a common workflow to profile workloads on the GPU using Nsight Systems.

As an example, let’s profile the forward pass, backward pass, and optimizer.step() using the resnet18 model from torchvision.

To annotate each part of the training we will use NVTX ranges via torch.cuda.nvtx.range_push/range_pop. These ranges work as a stack and can be nested.
Also, we are usually not interested in the first iterations, which might add overhead to the overall training due to memory allocations, cuDNN benchmarking, etc. We therefore start profiling after a few warmup iterations via torch.cuda.cudart().cudaProfilerStart() and stop it at the end via torch.cuda.cudart().cudaProfilerStop().

A complete code snippet can be seen here:

import torch
import torch.nn as nn
import torchvision.models as models

# setup
device = 'cuda:0'
model = models.resnet18().to(device)
data = torch.randn(64, 3, 224, 224, device=device)
target = torch.randint(0, 1000, (64,), device=device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

nb_iters = 20
warmup_iters = 10
for i in range(nb_iters):
    optimizer.zero_grad()

    # start profiling after 10 warmup iterations
    if i == warmup_iters: torch.cuda.cudart().cudaProfilerStart()

    # push range for current iteration
    if i >= warmup_iters: torch.cuda.nvtx.range_push("iteration{}".format(i))

    # push range for forward
    if i >= warmup_iters: torch.cuda.nvtx.range_push("forward")
    output = model(data)
    if i >= warmup_iters: torch.cuda.nvtx.range_pop()

    loss = criterion(output, target)

    if i >= warmup_iters: torch.cuda.nvtx.range_push("backward")
    loss.backward()
    if i >= warmup_iters: torch.cuda.nvtx.range_pop()

    if i >= warmup_iters: torch.cuda.nvtx.range_push("opt.step()")
    optimizer.step()
    if i >= warmup_iters: torch.cuda.nvtx.range_pop()

    # pop iteration range
    if i >= warmup_iters: torch.cuda.nvtx.range_pop()

# stop profiling after the last iteration
torch.cuda.cudart().cudaProfilerStop()

To create the profile I’m using Nsight Systems 2020.4.3.7 via the CLI.
The CLI options for nsys profile can be found here; my “standard” command, which is also the one used to create the profile for this example, is:

nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s cpu  --capture-range=cudaProfilerApi --stop-on-range-end=true --cudabacktrace=true -x true -o my_profile python main.py

(Thanks to Michael Carilli for creating this command a while ago :wink: )

The arguments are described in the linked CLI docs. A few interesting ones are:

  • -t cuda,nvtx,osrt,cudnn,cublas: selects the APIs to be traced
  • --capture-range=cudaProfilerApi and --stop-on-range-end=true: profiling starts only when the cudaProfilerStart API is invoked and stops when the capture range ends (i.e. when cudaProfilerStop is called)
  • --cudabacktrace=true: when tracing CUDA APIs, collect a backtrace whenever a CUDA API call is invoked, which lets you hover over a call in the timeline and see where it originated. You can also specify a threshold in nanoseconds that a call must exceed before its backtrace is collected.
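
Once the run finishes, the report written via -o can be opened in the Nsight Systems GUI or summarized on the command line. A minimal sketch, assuming your nsys version ships the stats subcommand (the report extension is .qdrep in older releases and .nsys-rep in newer ones):

# summarize the generated report on the command line (or open it in the GUI)
nsys stats my_profile.qdrep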

Note: if you are used to nvprof, try to copy/paste your nvprof command via:

nsys nvprof [options]

and Nsight Systems will try to translate the legacy nvprof options.
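
A minimal sketch of such a translation, using a hypothetical legacy nvprof-style command (not from the original post) that starts profiling at cudaProfilerStart and writes an output file:

# legacy nvprof-style options; nsys tries to map them to its own flags
nsys nvprof --profile-from-start off -o my_profile python main.py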

For this example the profile would look like this on a TitanV:

You can see the execution in different Python threads as well as different APIs executing kernels on the device.
We can zoom into a specific iteration and check the backtrace option, which is often helpful to isolate specific regressions and see which functions were invoked.

Feel free to add more useful arguments or any tips and tricks using Nsight Systems.

I have a gist with my preferred nsys commands for different scenarios and an explanation of each option.

There’s a gotcha to be aware of with Piotr’s command line

nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s cpu  --capture-range=cudaProfilerApi --stop-on-range-end=true --cudabacktrace=true -x true -o my_profile python main.py

CPU sampling (-s cpu) is great for getting backtraces that show where particular timeline calls originate in the code, but it also inflates CPU overhead (sometimes dramatically, 2X or more). So with -s cpu, you shouldn’t expect a realistic view of CPU whitespace.

To get a better idea of bottlenecks, you should first create a profile without CPU sampling (-s none), e.g.:

nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s none -o nsight_report -f true -x true python script.py args...

Manual torch.cuda.nvtx.range_push/pop calls in your script are very helpful to orient yourself and immediately see where your code spends an unexpected amount of time (forward? backward? optimizer? between iterations, which usually means dataloader?).
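
As a minimal sketch of such annotations (assuming a recent PyTorch where torch.cuda.nvtx.range is available as a context manager; otherwise the explicit range_push/range_pop pairs from the snippet at the top of the thread work the same way):

import torch

# annotate one training step with nested NVTX ranges
# assumes model, criterion, optimizer, data and target are defined as in the
# snippet at the top of this thread
with torch.cuda.nvtx.range("iteration"):
    with torch.cuda.nvtx.range("forward"):
        output = model(data)
        loss = criterion(output, target)
    with torch.cuda.nvtx.range("backward"):
        loss.backward()
    with torch.cuda.nvtx.range("opt.step()"):
        optimizer.step()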

Hi,

Is this supposed to work with multi-GPU scripts? I am using DataParallel in my case.

Right now I get the following output:

The application terminated before the collection started. No report was generated.
	Collection canceled.

Thanks in advance

Hi,

is it necessary to perform a device synchronization if I am interested in the CUDA kernel time for each op of the layer?

In newer PyTorch/nsys versions I don’t see cuDNN info anymore.
Can anyone help with that?

I also posted on NVIDIA forums here:

If I build PyTorch from source, then there is cuDNN info in the nsys traces.

We can zoom into a specific iteration and check the backtrace option, which is often helpful to isolate specific regressions and see which functions were invoked.

I can’t quite manage to get the backtrace feature to work; most of the kernels have no call stack when hovering over them. Did anyone encounter this issue? I posted on the NVIDIA forums as well to see what’s wrong: Call stack is visible/captured only for some CUDA kernels (broken backtraces) - Profiling Linux Targets - NVIDIA Developer Forums

I’m wondering whether I should build PyTorch from source with -fno-omit-frame-pointer for this to work.

I have noticed that when profiling my networks with nsys, the CPU is always running at 100% during loss.backward(). The graph looks similar to the one in the first picture here. I was hoping someone could explain what is happening, because I am trying to find bottlenecks in my routine, as I don’t seem to be getting the expected performance boosts when increasing the batch size, using AMP, and so forth. It also doesn’t seem to matter whether I preload my data onto the GPU or use workers to transfer it from the CPU at runtime; in both cases, I only have appreciable CPU activity during loss.backward(). I was under the impression that if I preloaded all of my data and only retrieved shuffled batches at training time via, e.g., torch.gather(), I would not be using the CPU at all. Any help? Thanks!

Hey folks, thanks for starting this thread. It proved very helpful in trying to profile my application. I am currently following the PyTorch Lightning guide: Find bottlenecks in your code (intermediate) — PyTorch Lightning 2.0.0 documentation and use nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s none --capture-range-end stop --capture-range=cudaProfilerApi --cudabacktrace=true -x true poetry run python main_graph.py as the command to collect the emitted information. However, I am getting a 300MB file just when doing one step of training and a 1.7GB file doing thirty steps. Can anyone give me a hint whether this is due to a huge misconfiguration on my side, or owed to the fact that I am using PyTorch Lightning and PyTorch Geometric?
Any input would be greatly appreciated :hugs:

In Nsight Systems 2023.3.1, one option is outdated.

nsys profile -w true -t cuda,nvtx,cudnn,cublas --capture-range=cudaProfilerApi --stop-on-range-end=true -x true -o 512 python ladies_e2e.py
unrecognised option '--stop-on-range-end=true'

usage: nsys profile [<args>] [application] [<application args>]
Try 'nsys profile --help' for more information.
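
In newer nsys releases the same behavior is selected via --capture-range-end (as used in the Lightning command a few posts above). A sketch of the updated command, assuming the remaining options are kept unchanged:

nsys profile -w true -t cuda,nvtx,cudnn,cublas --capture-range=cudaProfilerApi --capture-range-end=stop -x true -o 512 python ladies_e2e.py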