Understanding CUDAGraph Trees

I was listening to the podcast by @ezyang about CUDAGraph Trees. There, he mentions the memory bloat when using CUDA Graphs, which led to the design of CUDAGraph Trees. I am looking for some clarification regarding the approach.

In the eager CUDAGraph approach (not using torch.compile), we can deal with this memory bloat by sharing memory across captures, as described in the CUDA Graphs documentation.

The example shown there is as follows:

import torch

g1 = torch.cuda.CUDAGraph()
g2 = torch.cuda.CUDAGraph()

# (create static inputs for g1 and g2, run warmups of their workloads...)

# Captures g1
with torch.cuda.graph(g1):
    static_out_1 = g1_workload(static_in_1)

# Captures g2, hinting that g2 may share a memory pool with g1
with torch.cuda.graph(g2, pool=g1.pool()):
    static_out_2 = g2_workload(static_in_2)

static_in_1.copy_(real_data_1)
static_in_2.copy_(real_data_2)
g1.replay()
g2.replay()

After the first CUDAGraph capture in g1, we have a situation like the following:

  • We have a CUDAGraph g1, which uses the static placeholder 1 to get its inputs. (The memory address of this placeholder is baked into the CUDA graph.)
  • For its internal allocations, g1 uses the CUDAGraph Memory Pool.
  • The output1 we get after graph replay is a pointer to a location in the CUDAGraph Memory Pool (see the sketch after this list).
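To make the baked-in addresses concrete, here is a minimal sketch (it follows the capture pattern from the CUDA Graphs docs; the tiny add-one workload is purely illustrative):

import torch

# Static input placeholder: its address is what gets baked into the graph.
static_in = torch.zeros(4, device="cuda")

# Warmup on a side stream, as the CUDA Graphs docs recommend before capture.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    _ = static_in + 1
torch.cuda.current_stream().wait_stream(s)

g1 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g1):
    static_out = static_in + 1  # static_out lives inside g1's private pool

# Replay reads whatever currently sits in the placeholder and writes the
# result into the same fixed output location inside the pool; the Python
# code above is never re-run.
static_in.copy_(torch.arange(4.0, device="cuda"))
g1.replay()
print(static_out)  # tensor([1., 2., 3., 4.], device='cuda:0')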

Next, we capture g2 and make it use the same memory pool as g1. The situation is then something like the following:

  • g2 uses the same CUDAGraph Memory Pool as g1 for its internal operations.
  • g2 gets its input from a different static placeholder 2.
  • The output of g2 is again a pointer to a location in the CUDAGraph Memory Pool.

From the post:

With torch.cuda.make_graphed_callables(), if you want to graph several callables and you know they’ll always run in the same order (and never concurrently) pass them as a tuple in the same order they’ll run in the live workload, and make_graphed_callables() will capture their graphs using a shared private pool.

This means that if the capture sequence is g1 → g2, we should follow the order of g1 → g2 during replay. (why?)
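For reference, here is a minimal sketch of that pattern, loosely following the make_graphed_callables example in the CUDA Graphs docs (the Linear modules and sizes are just illustrative):

import torch

N, D_in, H, D_out = 64, 256, 128, 32
module1 = torch.nn.Linear(D_in, H).cuda()
module2 = torch.nn.Linear(H, D_out).cuda()

# Sample args must match the shapes/requires_grad of the live workload;
# h stands in for module1's output.
x = torch.randn(N, D_in, device="cuda", requires_grad=True)
h = torch.randn(N, H, device="cuda", requires_grad=True)

# Passing the callables as a tuple in execution order lets them capture
# into a shared private pool.
module1, module2 = torch.cuda.make_graphed_callables(
    (module1, module2), ((x,), (h,))
)

# The live workload must then invoke them in that same order (g1 -> g2).
real_x = torch.randn(N, D_in, device="cuda", requires_grad=True)
out1 = module1(real_x)
out2 = module2(out1)
out2.sum().backward()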

I want to understand a few details about the approach:

  1. I guess copying data into the static placeholder is not required in CUDAGraph Trees (based on the diagrams and the lightning talk by @eellison); the output address of g1 is baked into graph g2 as its (static) input.
  2. If we have a chain, say g1 → g2, I guess g1 and g2 share the same CUDA memory pool in the same manner as in the case of eager CUDA Graph capture.
  3. Now, as in the post by @eellison, if following g1 → g2 we have another graph g4 (a chain g1 → g2 → g4), I guess g4 also shares the same memory pool as g1 and g2, right?
  4. I am trying to understand why we need to branch out and checkpoint the memory state.
    • Let us say that we are in a situation where we have captured g1 followed by g2, and g2 shares the same memory pool as g1.
    • Now, suppose we reach a situation where we need to execute graph g3 after g1. Given the current state of the memory pool, if we try capturing g3 it would lead to the dependency g1 → g2 → g3, right? Trying to replay in the order g1 → g3 might then result in an error, because we are not following the capture order during replay. (See the sketch after this list.)
  5. Regarding the above quote, can a simple example be provided illustrating its importance?
  6. Similarly, for the above quote, a simple example would help to understand it.
  7. An example of the “live memory” discussed above, and of how it causes a dependency between two recordings, would help.
  8. If I can understand the idea of “live tensors” (as asked in point 7), then the above quote will become clearer, I guess.
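For point 4, here is a minimal sketch of the branching scenario (g1, g2 and g3 are just illustrative torch.compile regions rather than literal torch.cuda.CUDAGraph objects; running it with TORCH_LOGS="cudagraphs" should show what gets recorded):

import torch

@torch.compile(mode="reduce-overhead")
def g1(x):
    return x.relu()

@torch.compile(mode="reduce-overhead")
def g2(x):
    return x * 2

@torch.compile(mode="reduce-overhead")
def g3(x):
    return x - 1

for i in range(6):
    torch.compiler.cudagraph_mark_step_begin()
    inp = torch.rand([4], device="cuda")
    y = g1(inp)
    # Runtime control flow decides which recording follows g1, so the tree
    # ends up with two children of g1 (g1 -> g2 and g1 -> g3) sharing g1's
    # memory pool.
    out = g2(y) if i % 2 == 0 else g3(y)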

I am fascinated by the concept of torch.compile, and I find this section dealing with CUDAGraphs pretty interesting, so I couldn't help but write down my doubts.


Hi! Seems like you have a great handle on things. Answering a few questions/clarifying a few things.

    • Now, suppose we reach a situation where we need to execute graph g3 after g1. Given the current state of the memory pool, if we try capturing g3 it would lead to the dependency g1 → g2 → g3, right? Trying to replay in the order g1 → g3 might then result in an error, because we are not following the capture order during replay.

We had previously captured g1 → g2, and now we are trying to execute g1 → g3. The reason we need to checkpoint the memory state is that during the fast-path execution of CUDA graphs we don't do any memory accounting in the CUDA caching allocator: all of the allocations appear as deallocated. To capture a new graph that shares the memory pool, the CUDA caching allocator needs to know which tensors are live, so that new allocations made in that pool do not overwrite existing live tensors.

5/6 are pretty much the same concept. Run the code below with TORCH_LOGS="cudagraphs":

import torch

@torch.compile(mode="reduce-overhead")
def foo(x):
    return x + 1, x + 2

@torch.compile(mode="reduce-overhead")
def fee(y):
    return y * 4

for _ in range(3):
    torch.compiler.cudagraph_mark_step_begin()
    inp = torch.rand([4], device="cuda")
    a, b = foo(inp)
    del a
    # a's memory can be reused here
    fee(b)

torch.compiler.cudagraph_mark_step_begin()
inp = torch.rand([4], device="cuda")
a, b = foo(inp)
print("Should checkpoint now")
# a is no longer deleted and is still live, so its memory can't be reclaimed
fee(b)

7/8. Memory dependency is forced by live tensors. The caching allocator liveness state depends on prior outputs in the current tree.

Here are some examples.

import torch

@torch.compile(mode="reduce-overhead")
def foo(x):
    return x + 1, x + 2

@torch.compile(mode="reduce-overhead")
def fee(y):
    return y * 4

def get_curr_node(device):
    return torch._inductor.cudagraph_trees.get_container(device.index).tree_manager.current_node

for i in range(3):
    torch.compiler.cudagraph_mark_step_begin()
    inp = torch.rand([4], device="cuda")
    a, b = foo(inp)
    del a, b
    # no mem dependency
    fee(inp)
    assert get_curr_node(inp.device).parent is None

for i in range(3):
    torch.compiler.cudagraph_mark_step_begin()
    inp = torch.rand([6], device="cuda")
    a, b = foo(inp)
    # mem dependency
    fee(inp)
    assert get_curr_node(inp.device).parent is not None

Thanks @eellison for your insightful explanation.

Just to clarify this:

In the example figure above, we capture g1 → g2 and then replay g1 → g2. As a result, we have live tensors Output1 and Output2 in the CUDAGraph Memory Pool. (Since you said that the allocations are not accounted for during graph replay, i.e. the fast path, and they appear as deallocated, I assume Output1 and Output2 are just pointers to memory locations in the CUDAGraph Memory Pool, which can logically be thought of as one long stretch of memory.)

Next, if we try to use the same CUDAGraph Memory Pool for the capture of CUDAGraph g3, the allocations of g3 could simply overwrite the memory locations pointed to by Output1 and/or Output2, since the CachingAllocator has no accounting information and the live tensors (Output1 and Output2) do not appear to be allocated. We need to preserve the state of the live tensors, and therefore we checkpoint. Right?

I think I now see why we need to checkpoint. I would like some clarification on how the checkpointing is done and how it solves the problem.

How does checkpointing the memory pool back to the state it was in at the end of graph g1 solve the issue of maintaining the live-tensor state? If we checkpoint like that, won't the state of the live tensors be lost? I am not getting this part.

The comment in CUDACachingAllocator.cpp seems to mention the same thing, but I am unable to follow it.

How does checkpointing the memory pool back to the state it was in at the end of graph g1 solve the issue of maintaining the live-tensor state? If we checkpoint like that, won't the state of the live tensors be lost? I am not getting this part.

The test here might be helpful.

After del outputs, if we were to run live_blocks(pool_id), it would equal 0. However, after we run checkpointing on the captured CUDA Caching Allocator state, it correctly accounts for 2 live allocated blocks of memory. The _cuda_cudaCachingAllocator_raw_delete call is a stand-in for what happens in CUDAGraph Trees, where the tensors call raw_delete in their deleter_fn.

Note that we checkpoint before we record g3, and we apply any deltas in liveness that might have occurred between the end of g1 and the start of g3.
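Conceptually, and only as hypothetical pseudocode (the real logic lives in torch/_inductor/cudagraph_trees.py; the helper names below are made up), the sequence is something like:

# Hypothetical pseudocode, not real PyTorch APIs.
def prepare_pool_for_g3(pool, state_at_end_of_g1, outputs_dead_since_g1):
    # 1. Checkpoint: restore the caching allocator's bookkeeping for this
    #    pool to the state captured at the end of g1, so every block that
    #    was live then is marked allocated again.
    set_allocator_state(pool, state_at_end_of_g1)   # hypothetical helper

    # 2. Apply the liveness deltas: any g1 output the user has since
    #    dropped is freed again via its raw delete, so its block becomes
    #    reusable for g3's allocations.
    for tensor in outputs_dead_since_g1:
        raw_delete(tensor)                          # hypothetical helper

    # 3. g3 can now be recorded into the shared pool without the risk of
    #    overwriting tensors that are still live.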


Thanks @eellison.

Do you have any CUDAGraph Trees design doc apart from this? For example, a document that the CUDAGraph Trees source code is based on.