Overlapping device-to-host copies with GPU collectives

Hello, I’m running tests on AWS P5 instances and I’m trying to asynchronously offload tensors from device to host while running GPU collectives. My understanding is that the DtoH copy goes over PCIe while the GPU collective communication goes over NVLink, so the two should not interfere with each other. However, what I observe is that the collective ops can be heavily impacted by the offloading, even when the offload runs synchronously (i.e., completes before the collectives start).
I created a simple script for my tests and launch it with torchrun --nproc_per_node 8, so it is a single-node test with 8 GPUs.

import torch
import torch.distributed as dist

NO_OFFLOAD = True
SYNC_OFFLOAD = False

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

tensor = torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16)
dist.all_reduce(tensor)  # warm up NCCL
tensor_list = [torch.randn(8192, 8192, device="cuda", dtype=torch.bfloat16) for _ in range(8)]  # all_gather outputs
cpu_tensor = torch.empty(tensor.shape, dtype=tensor.dtype, device=torch.device("cpu"), pin_memory=True)  # pinned host buffer for the offload
d2h_stream = torch.cuda.Stream(device=dist.get_rank(), priority=-1)  # dedicated high-priority stream for DtoH copies
torch.cuda.synchronize()
if not NO_OFFLOAD:
    # launch the async DtoH copies on the side stream
    with torch.cuda.stream(d2h_stream):
        with torch.no_grad():
            for i in range(5):
                cpu_tensor.copy_(tensor, non_blocking=True)
if SYNC_OFFLOAD:
    # wait for the copies to finish before starting the collectives
    torch.cuda.synchronize()
for i in range(32):
    dist.all_gather(tensor_list, tensor)

Firstly, with NO_OFFLOAD = True I see the profile below (Fig. 1):


Here each all-gather takes around 2.7 ms. Now I turn the offload on (NO_OFFLOAD = False, Fig. 2):

The first several all-gathers become very slow, up to 13 ms. Even after the offloading has finished, the remaining all-gathers still take ~4.4 ms. Next, I enable SYNC_OFFLOAD so that the all-gathers only start after the offloading is done (Fig. 3):

However, the first all-gather becomes extremely slow, ~110 ms. After that, the rest of the all-gathers go back to 2.7 ms.
Could someone help me understand this behavior? Any help is greatly appreciated! I’m on torch 2.2.0 with NCCL version 2.19.4+cuda12.1.
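(For reference, in case someone wants to reproduce the per-collective numbers without Nsight Systems, here is a rough sketch of how the all-gathers could also be timed with CUDA events. This is not part of my original script, and time_all_gathers is just an illustrative helper name.)

import torch
import torch.distributed as dist

def time_all_gathers(tensor_list, tensor, iters=32):
    # Record a CUDA event before and after each all_gather on the current stream,
    # then synchronize once at the end and read back the elapsed times (in ms).
    starts = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    ends = [torch.cuda.Event(enable_timing=True) for _ in range(iters)]
    for i in range(iters):
        starts[i].record()
        dist.all_gather(tensor_list, tensor)
        ends[i].record()
    torch.cuda.synchronize()
    return [s.elapsed_time(e) for s, e in zip(starts, ends)]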

Hi, for the 3rd profile (SYNC_OFFLOAD), could the long first all_gather be due to a straggler effect?
Would you mind adding a barrier here and seeing if the profile changes? Thanks!

if SYNC_OFFLOAD:
    torch.cuda.synchronize()
dist.barrier()  # wait for all ranks before starting the all-gathers

Thanks @kwen2501! Actually, I found that the weird behavior in the 3rd profile (SYNC_OFFLOAD) is because different processes have different offloading times. After applying the barrier, that behavior is gone. However, I still observe slowness in NCCL when it overlaps with the DtoH copies.
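For completeness, this is roughly the ordering I am running now, assuming the barrier sits between the host sync and the all-gather loop as suggested (a minimal sketch, not the full script):

if SYNC_OFFLOAD:
    torch.cuda.synchronize()
# Ensure every rank has finished (or skipped) the offload before the collectives start,
# so a slow rank does not show up as one long first all_gather.
dist.barrier()
for i in range(32):
    dist.all_gather(tensor_list, tensor)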


Great, thanks! That matches our expectation!

Now, back to Fig. 1 vs. Fig. 2: I wonder why Fig. 2 has one more line of all-gather kernels? I am referring to the kernels in the dotted box. Thanks!

@kwen2501 what do you mean by “one more line”? The total number of all-gathers is the same (32). If you are referring to the number of rows, I think that is just down to how I expanded the stream details in the Nsys app.

Thanks! I reported this to the NCCL team.
If you’d like, you can also open an issue on NCCL’s GitHub (GitHub - NVIDIA/nccl: Optimized primitives for collective multi-GPU communication) for easier tracking.