Understanding CUDAGraph Trees

Hi! Seems like you have a great handle on things. Answering a few questions/clarifying a few things.

    • Now, suppose we reach a situation where we need to execute graph g3 after g1. And now, given the current situation of the memory pool, if we try capturing g3, it would lead to the dependency g1 → g2 → g3, right? Trying to replay in order g1 → g3 might result in an error because we are not following the capture order during replay.

We had previously captured g1 → g2, and now we are trying to execute g1 → g3. The reason we need to checkpoint the memory state is that during fast-path execution of CUDA graphs we don't do any memory accounting in the CUDA caching allocator: all of the allocations appear as deallocated. To capture a new graph that shares the memory pool, the CUDA caching allocator needs to know which tensors are live, so that new allocations made in that pool do not overwrite existing live tensors.
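As a toy illustration of that bookkeeping (this is not the real CUDACachingAllocator; `ToyPool` and its methods are hypothetical names for a minimal sketch), the key idea is that replay skips accounting entirely, so a checkpoint has to restore the allocator's view of which blocks are live before a new graph can be captured into the shared pool:

```python
# Toy model of the liveness bookkeeping described above.
# NOT the real CUDACachingAllocator -- just a sketch of the idea.

class ToyPool:
    def __init__(self, num_blocks):
        # True means the block holds a live tensor
        self.live = [False] * num_blocks

    def allocate(self):
        # New allocations must not overwrite live tensors.
        for i, is_live in enumerate(self.live):
            if not is_live:
                self.live[i] = True
                return i
        raise MemoryError("pool exhausted")

    def free(self, idx):
        self.live[idx] = False

    def fast_path_replay(self):
        # During fast-path replay no accounting happens: every block
        # appears deallocated to the allocator. Return a snapshot so
        # liveness can be restored later.
        snapshot = list(self.live)
        self.live = [False] * len(self.live)
        return snapshot

    def checkpoint(self, snapshot, still_live):
        # Before capturing a new graph into the shared pool, restore
        # the allocator's view of which blocks are actually live.
        self.live = [i in still_live for i in range(len(snapshot))]


pool = ToyPool(4)
a = pool.allocate()  # output of g1
b = pool.allocate()  # output of g1
snapshot = pool.fast_path_replay()
# Without checkpointing, capturing g3 here could hand out block `a`
# again even though the tensor backed by `a` is still alive.
pool.checkpoint(snapshot, still_live={a, b})
c = pool.allocate()
assert c not in (a, b)
```

The checkpoint is what makes capturing g3 after g1 safe even though g2 was captured in between: the allocator no longer believes everything is free.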

5/6. These are pretty much the same concept. Run the example below with TORCH_LOGS="cudagraphs":

import torch

@torch.compile(mode="reduce-overhead")
def foo(x):
    return x + 1, x + 2

@torch.compile(mode="reduce-overhead")
def fee(y):
    return y * 4

for _ in range(3):
    torch.compiler.cudagraph_mark_step_begin()
    inp = torch.rand([4], device="cuda")
    a, b = foo(inp)
    del a
    # a's memory can be reused here
    fee(b)

torch.compiler.cudagraph_mark_step_begin()
inp = torch.rand([4], device="cuda")
a, b = foo(inp)
print("Should checkpoint now")
# a is no longer deleted, so it is still live and its memory can't be reclaimed
fee(b)

7/8. A memory dependency is forced by live tensors: the caching allocator's liveness state depends on the outputs of prior graphs in the current tree.

Here are some examples.

import torch

@torch.compile(mode="reduce-overhead")
def foo(x):
    return x + 1, x + 2

@torch.compile(mode="reduce-overhead")
def fee(y):
    return y * 4

def get_curr_node(device):
    return torch._inductor.cudagraph_trees.get_container(device.index).tree_manager.current_node

for i in range(3):
    torch.compiler.cudagraph_mark_step_begin()
    inp = torch.rand([4], device="cuda")
    a, b = foo(inp)
    del a, b
    # no mem dependency
    fee(inp)
    assert get_curr_node(inp.device).parent is None

for i in range(3):
    torch.compiler.cudagraph_mark_step_begin()
    inp = torch.rand([6], device="cuda")
    a, b = foo(inp)
    # mem dependency
    fee(inp)
    assert get_curr_node(inp.device).parent is not None