CUDAGraphs In PyTorch 2.0
TL;DR
The new CUDA graph implementation raises the HuggingFace speedup from 1.73x to 1.85x and the memory compression ratio from 0.88 to 1.13.
If you are using torch.compile, especially to lower the entire model, cudagraphs may provide speedups, even if the model has dynamism.
Try: torch.compile(mode="reduce-overhead")
CUDAGraph Background
For a longer background on CUDAGraphs, read "Accelerating PyTorch with CUDA Graphs".
CUDA Graphs, which made its debut in CUDA 10, lets a series of CUDA kernels be defined and encapsulated as a single unit, i.e., a graph of operations, rather than a sequence of individually launched operations. It provides a mechanism to launch multiple GPU operations through a single CPU operation, and hence reduces launch overhead.
CUDA Graphs can give large speedups, especially for models with high CPU overhead. They also come with a number of limitations, because a replay must run the same kernels with the same arguments, dependencies, and memory addresses:
- Control flow is not possible
- Operations that force the host to synchronize with the device (such as .item()) raise an error
- All input arguments to kernels are fixed to what they were when recorded
- CUDA memory addresses are fixed; however, the values of the memory at those addresses can change
- No essential CPU ops or CPU side effects
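As a rough illustration of the fixed-address constraint, the sketch below captures a small computation with `torch.cuda.CUDAGraph` and replays it on new data by copying into the recorded input buffer. The buffer size and the arithmetic are arbitrary choices for illustration, and the import guard is an assumption added only so that the snippet is a no-op on machines without PyTorch or a CUDA device.

```python
try:
    import torch
except ImportError:  # lets the snippet load on machines without PyTorch
    torch = None


def capture_and_replay():
    # Static input buffer: its address is baked into the recorded graph.
    static_x = torch.zeros(1024, device="cuda")

    # Warm up on a side stream before capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        for _ in range(3):
            static_y = static_x * 2 + 1
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        # Kernels are recorded here, not executed.
        static_y = static_x * 2 + 1

    # New values are copied *into* the recorded address, then a single
    # CPU-side launch replays every recorded kernel.
    static_x.copy_(torch.arange(1024, device="cuda", dtype=torch.float32))
    graph.replay()
    return static_y.clone()


if torch is not None and torch.cuda.is_available():
    out = capture_and_replay()
    assert out[3].item() == 7.0  # 3 * 2 + 1
```

Passing a freshly allocated tensor instead of copying into `static_x` would have no effect on the replay, since the graph only ever reads the address it recorded.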
PyTorch CUDAGraph Integration
PyTorch provides a convenience wrapper around CUDAGraphs that handles a couple of tricky interactions with PyTorch’s caching allocator.
During CUDAGraph recording, the caching allocator uses a separate memory pool for all new allocations. Within that pool, memory is accounted for, allocated, and freed exactly as during an eager run.
On replay, only the kernels are invoked, and there are no changes to the allocator. After the initial recording, the allocator does not know which memory is actively being used in user programs.
NOTE: Using a separate memory pool for eager allocations and cudagraph allocations may increase the memory usage of your program if there is substantial memory allocated to both.
Make Graphed Callables
Make Graphed Callables (torch.cuda.make_graphed_callables) is a PyTorch abstraction for sharing a single memory pool over a series of callables. It takes advantage of the fact that during CUDA Graph recording, memory is exactly accounted for by the caching allocator, which makes it safe to share memory between separate CUDA Graph recordings. In each invocation, outputs are preserved as live memory, preventing one callable from overwriting the live memory of another. Graphed Callables can only be invoked in a single order; memory addresses from the first run are burned into the second, and so forth.
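A minimal usage sketch of `torch.cuda.make_graphed_callables` follows. The two `Linear` modules, their sizes, and the helper name `graph_two_modules` are hypothetical; the import guard is an assumption added so the snippet does nothing without PyTorch and a CUDA device.

```python
try:
    import torch
except ImportError:  # lets the snippet load on machines without PyTorch
    torch = None


def graph_two_modules():
    # Two hypothetical modules that always run in this order.
    first = torch.nn.Linear(16, 16).cuda()
    second = torch.nn.Linear(16, 4).cuda()

    # Sample inputs fix the shapes and requires_grad state used at capture time.
    sample_x = torch.randn(8, 16, device="cuda")
    sample_h = torch.randn(8, 16, device="cuda", requires_grad=True)

    # Callables are passed in the order they will be invoked;
    # their captures share one memory pool.
    first, second = torch.cuda.make_graphed_callables(
        (first, second), ((sample_x,), (sample_h,))
    )

    x = torch.randn(8, 16, device="cuda")
    hidden = first(x)       # replays the first captured graph
    return second(hidden)   # replays the second, reading the first's live output


if torch is not None and torch.cuda.is_available():
    out = graph_two_modules()
    assert out.shape == (8, 4)
```

Calling `second` before `first`, or with a different input shape, would violate the single-order and fixed-argument assumptions described above.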
TorchDynamo Previous CUDA Graphs Integration
Running with cudagraph_trees=False does not reuse memory across separate graph captures, which can lead to large memory regressions. Even a model that has no graph breaks is affected: the forward and backward are separate graph captures, so the memory pools for forward and backward are not shared. In particular, memory for activations that are saved in the forward cannot be reclaimed in the backward. The general case is shown below. Parameters are not copied into the CUDA graph memory pool because they are assumed to have static addresses.
CUDAGraph Trees Integration
Like Graphed Callables, CUDA Graph Trees use a single memory pool across all graph captures. However, instead of requiring a single sequence of invocations, CUDA Graph Trees create separate trees of CUDA graph captures.
- A CUDAGraphNode contains a single recording of a torchinductor compilation into a CUDA graph. It contains metadata about the inputs, outputs, and parent/child relationships to other nodes in the graph tree.
- A path through a CUDA graph tree is a unique path taken through the model. The path keeps track of all outputs at each node along the path.
- A CUDAGraphNode specializes on the memory allocation patterns and tensor lifetimes observed when it was recorded, and uses them to check whether it can re-execute. It checks that the same tensors are still alive from its parent node, that any tensors that died after recording die again on re-execution, and that the path to the root is the same on execution as on recording.
Let’s take a look at an illustrative example:
In this example, there are two separate paths that we make through the function: 1 → 2 → 4, or 1 → 3 → 4.
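A function with two such paths could be sketched as follows. The body is an assumption chosen for illustration: a data-dependent branch forces a graph break after graph 1, and `torch._dynamo.graph_break()` separates the branch from graph 4. The helper names `build_foo` and `run_both_paths` are hypothetical, and the import guard only keeps the snippet inert without PyTorch and a CUDA device.

```python
try:
    import torch
    import torch._dynamo
except ImportError:  # lets the snippet load on machines without PyTorch
    torch = None


def build_foo():
    @torch.compile(mode="reduce-overhead")
    def foo(x):
        # Graph 1
        y = x * x * x
        # Data-dependent branch: TorchDynamo breaks the graph here.
        if y.sum() > 0:
            z = y ** y                # Graph 2
        else:
            z = y.abs() ** y.abs()    # Graph 3
        torch._dynamo.graph_break()
        return z * torch.rand_like(z)  # Graph 4

    return foo


def run_both_paths():
    foo = build_foo()
    positive = torch.linspace(0.1, 1.0, 10, device="cuda")
    for _ in range(3):   # warm up, record, then replay path 1 -> 2 -> 4
        out = foo(positive)
    for _ in range(3):   # the same three stages for path 1 -> 3 -> 4
        out = foo(-positive)
    return out.shape


if torch is not None and torch.cuda.is_available():
    assert run_both_paths() == (10,)
```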
We share all of the memory in a single memory pool between separate recordings by building up a tape of CUDA Graph recordings, in this instance, 1 → 2 → 4. We add invariants to ensure that memory is always in the same location as it was when recorded, and that no live tensors exist in user programs that might be overwritten.
- Same constraints from CUDA Graphs apply: same kernels must be invoked with the same arguments (static sizes, addresses, etc)
- The same pattern of memory must be observed between recording and replay: if a tensor output of one graph dies subsequent to another graph during recording, it must also do so during replay.
- Live memory in the cuda pool forces a dependence between two recordings
- These recordings can only be invoked in a single order: 1 → 2 → 4
All of the memory is shared in a single memory pool, so there is no additional memory overhead compared to eager. Now, what happens if we hit a new path and run Graph 3?
Graph 1 gets replayed, and then we hit Graph 3, which we have not yet recorded. On graph replays the private memory pool is not updated, so y is not reflected in the allocator. Without care we would overwrite it. To support reusing the same memory pool after replaying other graphs, we checkpoint the memory pool back to its state at the end of graph 1. Checkpointing both updates the CUDA caching allocator to reflect the currently live tensors, and adds a deleter function to the live tensors so that when they die, the allocator will mark their memory as free. Now that our live tensors are reflected in the caching allocator, we are safe to run a new graph.
First we would hit the optimized CUDAGraph.replay() path that we have already recorded in graph 1. Then we would hit Graph 3. Just as before, we need to warm up the graph once before recording. On the warmup run, the memory addresses are not fixed, so graph 4 will also fall back to the inductor, non-cudagraph invocation.
Now, the second time we hit graph 3 we are warmed up and ready to record. We record graph 3, and then record graph 4 again, since the input memory addresses have changed. This creates a tree of CUDA Graph recordings, all using the same memory pool.
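The warm up, record, and replay sequence described above can be traced with a small pure-Python toy model. It is not the PyTorch implementation: `ToyGraphTree` and its policy are assumptions that only mirror the walk-through (each new path is warmed up once, any graph after a warm-up in the same run also falls back, and recordings are stored as children of the recording that preceded them).

```python
class ToyGraphTree:
    """Toy model of a tree of recordings; graph ids are plain integers."""

    def __init__(self):
        self.roots = {}          # nested dicts: graph id -> child recordings
        self.warmed_up = set()   # path prefixes that have had a warm-up run

    def run(self, path):
        """Return the action taken for each graph along `path`."""
        actions = []
        children = self.roots    # recordings reachable from the current position
        prefix = ()
        fell_back = False        # after an eager run, downstream addresses are not fixed
        for graph_id in path:
            prefix += (graph_id,)
            if not fell_back and graph_id in children:
                actions.append("replay")
                children = children[graph_id]
            elif not fell_back and prefix in self.warmed_up:
                children[graph_id] = {}   # new recording under the current node
                actions.append("record")
                children = children[graph_id]
            else:
                self.warmed_up.add(prefix)
                actions.append("warmup")
                fell_back = True
        return actions


tree = ToyGraphTree()
first_path = [tree.run((1, 2, 4)) for _ in range(3)]
second_path = [tree.run((1, 3, 4)) for _ in range(3)]
print(first_path)
print(second_path)
```

In this model graph 4 ends up recorded twice, once under graph 2 and once under graph 3, which is the tree shape the text describes.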
Dynamic Shapes Support:
Because CUDA Graph Trees uses a single memory pool for every capture, it works well with dynamic shapes. With recently landed changes, we record a new CUDA graph for each new shape but compile only a single torchinductor graph. Recording a new CUDA graph is an order of magnitude faster than a new inductor compilation, so this significantly reduces compilation time. You should see speedups as long as you can warm up your set of dynamic shapes for inference, or, for training, the dynamism is limited enough that you start to see shapes you have already compiled. For cm3_leon inference, this led to a 6x speedup.
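The bookkeeping behind this can be pictured with a toy cache, shown below. It is a pure-Python illustration under simplifying assumptions (no warm-up runs, shapes as plain tuples); `ToyShapeCache` is a hypothetical name and does not correspond to any PyTorch class.

```python
class ToyShapeCache:
    """Toy model: one shape-generic compile, one recording per concrete shape."""

    def __init__(self):
        self.compile_count = 0
        self.recorded_shapes = set()

    def run(self, shape):
        if self.compile_count == 0:
            self.compile_count = 1       # stands in for the single inductor compilation
        if shape in self.recorded_shapes:
            return "replay"
        self.recorded_shapes.add(shape)  # stands in for a cheaper CUDA graph recording
        return "record"


cache = ToyShapeCache()
actions = [cache.run(s) for s in [(8, 128), (8, 256), (8, 128), (8, 256)]]
print(actions, cache.compile_count)
```

Once every shape in the working set has been seen, all further calls are replays, which is why warming up the expected shapes ahead of time pays off.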
Limitations
Because CUDA Graph fixes memory addresses, CUDA Graphs do not have a great way of handling live tensors from a previous invocation.
Let’s say we are benchmarking running inference with the following code:
In the previous CUDA Graph implementation, the output from the first invocation would be overwritten by the second invocation. In CUDA Graph Trees, naively, the live output of the first run would force a dependency between the first run and the second run, and we would never hit the optimized cudagraph replay invocation. To avoid this, CUDA Graph Trees ignores outputs from a previous run of torch.compile and does not force a memory dependency. In training, outputs from a previous run of torch.compile are not ignored while there are pending backward passes that have not yet been invoked.
CUDA Graph Trees tries to detect the beginning of the training loop, but if it gets it wrong you will get a hard error. In that case, you can call torch._inductor.cudagraph_mark_step_begin() to mark the start of a new iteration explicitly.
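A benchmarking loop that respects these rules might look like the sketch below. The model body and the helper name `benchmark_two_calls` are hypothetical. Cloning the first output before the next run is an assumption about how a caller would keep a result alive across runs, not something the text above prescribes; the import guard only keeps the snippet inert without PyTorch and a CUDA device.

```python
try:
    import torch
    import torch._inductor
except ImportError:  # lets the snippet load on machines without PyTorch
    torch = None


def benchmark_two_calls():
    @torch.compile(mode="reduce-overhead")
    def my_model(x):
        return torch.matmul(x, x)

    x = torch.randn(10, 10, device="cuda")

    # Explicitly mark the start of each iteration instead of relying on detection.
    torch._inductor.cudagraph_mark_step_begin()
    y1 = my_model(x).clone()  # copy out of the CUDA graph pool to keep it past the next run

    torch._inductor.cudagraph_mark_step_begin()
    y2 = my_model(x)
    return y1, y2


if torch is not None and torch.cuda.is_available():
    y1, y2 = benchmark_two_calls()
    assert y1.shape == y2.shape == (10, 10)
```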
Comparisons
Footguns | Separate CudaGraph | CUDAGraph Trees |
Memory Can Increase | On each graph compilation - new sizes, etc. | If you are also running non cudagraph memory |
Recordings | On any new invocation of a graph | Will re-record on any new, unique path you take through your program |
Footguns | Invocation of one graph will overwrite prior invocation | Cannot persist memory between separate runs through your model - one training loop, or one run of inference |
Results:
Today, comparing cudagraph_trees on the OSS benchmarks to the previous implementation gives the following changes:
Metric | Hugging Face | Timm | Torchbench |
Memory Compression Ratio (higher is better) | .88 -> 1.13 | .88 -> 1.00 | .61 -> .74 |
Performance | 1.73 -> 1.85 | 1.74x | 1.74x |
Note: Torchbench memory is tested with extremely low memory usage and is disproportionately affected by a 250 MB cache-clearing allocation in Triton autotuning, which does not scale with the rest of the model’s memory usage.
Next Steps:
Support for FSDP:
As part of the PT2 efforts to trace through FSDP, we should also make sure that FSDP composes with cudagraphs.
Potential Other Directions:
Compiler Padding For Dynamic Shapes: In the future we could also consider automatically padding dynamic shapes to a user-specified multiple, so that we would reuse existing cudagraphs more frequently instead of hitting another recording.
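The effect of such padding on the number of recordings can be estimated with a few lines of arithmetic. The helper `pad_to_multiple` and the multiple of 64 are illustrative assumptions, not an existing compiler option.

```python
def pad_to_multiple(size: int, multiple: int) -> int:
    """Round `size` up to the nearest multiple of `multiple`."""
    if size < 0 or multiple <= 0:
        raise ValueError("size must be non-negative and multiple must be positive")
    return ((size + multiple - 1) // multiple) * multiple


# Sequence lengths 1..256 would need 256 recordings unpadded,
# but collapse to far fewer distinct shapes when padded to a multiple of 64.
lengths = range(1, 257)
distinct_padded = {pad_to_multiple(n, 64) for n in lengths}
print(len(lengths), "->", len(distinct_padded))
```

The trade-off is wasted compute on the padded elements, so the multiple would need to be chosen per workload.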
Memory Planning:
Memory planning as it is usually implemented by the compiler runs into many difficulties: fallback operators must provide out variants, and allocations internal to an operator are not visible to memory planning. Doing the planning in the caching allocator, in a similar fashion to mobile’s profiling allocator, avoids these difficulties. CUDAGraph is especially well suited to this type of planning because the pattern of allocation and deallocation of memory is fixed, and CUDAGraph itself is an API for getting this pattern without touching any real memory. You can make an initial recording, throw it out, and re-record with a preplanned allocation scheme.
Partitioning around Unsupported Operators
Today we will not cudagraph an inductor graph if it has CPU operators in it or contains an unsupported operator. We could instead partition the inductor graph into multiple subgraphs so that the supported portions can still be CUDA graphed.
Thanks to Zachary Devito, Edward Yang, Jason Ansel, Alban Desmaison, and Peng Wu for their help along the way, without which this wouldn’t have been possible.