Memcpy-based P2P communication for pipeline parallelism instead of NCCL

TL;DR:

Memcpy-based communication (e.g., tensor.to and tensor.copy_) can be far better than the NCCL P2P APIs for pipeline parallelism, but how can we enable it for multi-node, multi-process training with torchrun?

Context

I’ve been building a tool for automatic pipeline parallelism that slices an FX graph produced by torch.compile. To find the optimal pipeline-parallel configuration, I need to estimate the total runtime of a pipeline, which consists of computation time and communication time between devices.

The core part of pipeline parallelism is sending the result of a subgraph to the next subgraph. This relies on P2P communication between two devices. In PyTorch, P2P communication is realized by two mechanisms: tensor.to (or tensor.copy_) and torch.distributed.send. The former uses cudaMemcpy, while the latter uses NCCL P2P kernels under the hood.
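As a minimal sketch of the two mechanisms (shapes and device indices are just for illustration; the memcpy path assumes both GPUs are visible to one process, the NCCL path assumes one process per device in an already-initialized process group):

import torch
import torch.distributed as dist

# Memcpy path: both devices are visible to the same process.
src = torch.randn(2**24, device="cuda:0")
dst = torch.empty(2**24, device="cuda:2")
dst.copy_(src, non_blocking=True)   # cudaMemcpy: D2D over NVLink, or D2H + H2D over PCIe

# NCCL path: one process per device (e.g. launched with torchrun),
# dist.init_process_group("nccl") already called.
if dist.get_rank() == 0:
    dist.send(src, dst=1)           # NCCL P2P send kernel
elif dist.get_rank() == 1:
    buf = torch.empty(2**24, device="cuda")
    dist.recv(buf, src=0)           # NCCL P2P recv kernel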

Here’s the issue: if we want to use more than 8 GPUs, we inevitably have to rely on multi-node training, usually with torchrun or something similar such as MPI. As far as I know, there is no PyTorch function that allows access to a remote tensor across nodes other than the NCCL-backed dist.send (cf. RPC/RRef still does not support CUDA tensors).

However, I found some drawbacks of the current torch.distributed APIs compared to tensor.copy_ that block optimal training with pipeline parallelism:

  1. NCCL P2P can be almost 3 times slower than memcpy over PCIe.
  2. Multiple NCCL kernels cannot be overlapped and must follow a predefined strict total order. More precisely, concurrent NCCL kernels are unsafe: Creating a Communicator — NCCL 2.21.5 documentation
  • This limitation forces users to carefully schedule every communication to and from the devices within a single (logical) stream.
  • If we want to use more axes of parallelism (e.g., 3D parallelism), we must fix a linear order over every communication: FSDP’s all-gather, TP’s all-reduce, and PP’s P2P.
  3. When we overlap NCCL P2P communication and computation, both become slower; memcpy-based P2P doesn’t.
  4. The synchronization models of dist.send and dist.isend are easy to use, but they differ from the documentation. dist.send does not block the Python process; rather, it inserts a synchronization barrier between the NCCL communication stream and the main computation stream (or the stream that was current when dist.send was called).

I conducted some experiments comparing the NCCL P2P APIs to memcpy to confirm the list above. Then I propose features that could improve multi-node pipeline parallelism.

Communication and computations of pipeline parallelism

In brief, pipeline parallelism is a type of model parallelism that divides the computation graph into N subgraphs, and sequentially executes the subgraphs with N devices and M micro-batches.

The figure above displays an Nsight Systems log of a single training iteration of the 1F1B (1-forward 1-backward) pipeline with 4 micro-batches and the ViT-g/14 model. The blue regions with gray NVTX boxes below represent the forward passes, while the remaining areas indicate the backward passes.

Let’s examine what is happening between the forward and backward passes. The above shows the timelines of devices 2 to 4, and there are four pairs of communications: D2H (purple) / H2D (green) through PCIe, and D2D (red) through NVLink. From device 2 to 3, the activation tensor resulting from the forward pass in device 2 is moved to device 3, and the gradient tensor of the backward pass in device 3 is sent back immediately after.

We can observe that the second device almost simultaneously receives and sends tensors, and we cannot accurately predict the order of communication because it depends on the timing of each previous forward/backward computation.

The above figures are captured from my research project about automatic pipeline parallelism. What would happen if we change the communication scheme to NCCL?

This figure illustrates the forward passes of ViT-g/14 training with Colossal-AI’s pipeline parallelism. Since ‘Send 2→3’ is scheduled later than ‘Send 1→2’ on device 2, the subsequent forward pass on device 3 is blocked until ‘Send 1→2’ is completed.

This behavior complicates pipeline optimization; it’s not enough to consider only the order of communication on each individual device. It’s also necessary to examine all communication dependencies across devices and nodes. How can we determine whether ‘Send 1→2’ should start before ‘Send 2→3’ to reduce total latency?

Experiments

Environment

I conducted the experiments using 8 RTX 3090 GPUs and 2 CPU NUMA nodes in a single rack server. There are 4 NVLink bridges, each connecting an even-ranked device to the adjacent odd-ranked device (0↔1, 2↔3, 4↔5, 6↔7); all other device pairs are connected via PCIe bridges.

## dmidecode -t 2
# dmidecode 3.3
Getting SMBIOS data from sysfs.
SMBIOS 3.2.0 present.

Handle 0x0002, DMI type 2, 15 bytes
Base Board Information
	Manufacturer: Supermicro
	Product Name: H12DSG-O-CPU
	Version: 1.01A
	Serial Number: UM208S600092
	Asset Tag: To be filled by O.E.M.
	Features:
		Board is a hosting board
		Board is removable
		Board is replaceable
	Location In Chassis: To be filled by O.E.M.
	Chassis Handle: 0x0003
	Type: Motherboard
	Contained Object Handles: 0

## nvidia-smi topo -m
GPU0	GPU1	GPU2	GPU3	GPU4	GPU5	GPU6	GPU7	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	NV4	SYS	SYS	SYS	SYS	SYS	SYS	0-15,32-47	0		N/A
GPU1	NV4	 X 	SYS	SYS	SYS	SYS	SYS	SYS	0-15,32-47	0		N/A
GPU2	SYS	SYS	 X 	NV4	SYS	SYS	SYS	SYS	0-15,32-47	0		N/A
GPU3	SYS	SYS	NV4	 X 	SYS	SYS	SYS	SYS	0-15,32-47	0		N/A
GPU4	SYS	SYS	SYS	SYS	 X 	NV4	SYS	SYS	16-31,48-63	1		N/A
GPU5	SYS	SYS	SYS	SYS	NV4	 X 	SYS	SYS	16-31,48-63	1		N/A
GPU6	SYS	SYS	SYS	SYS	SYS	SYS	 X 	NV4	16-31,48-63	1		N/A
GPU7	SYS	SYS	SYS	SYS	SYS	SYS	NV4	 X 	16-31,48-63	1		N/A

## nvidia-smi
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.67                 Driver Version: 550.67         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0 Off |                  N/A |
| 30%   38C    P8             33W /  350W |      14MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
# ...
# ----------------....................................-------------------------------
# ...
|   7  NVIDIA GeForce RTX 3090        Off |   00000000:E1:00.0 Off |                  N/A |
| 30%   34C    P8             27W /  350W |      14MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

## p2pBandwidthLatencyTest
Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
   D\D     0      1      2      3      4      5      6      7 
     0 838.25 101.78   8.44   8.48   7.60   7.95   7.61   7.91 
     1 101.74 838.03   8.74   8.73   7.60   8.03   7.65   8.00 
     2   8.45   8.75 837.80 101.65   7.87   8.17   7.91   8.13 
     3   8.50   8.74 101.68 837.13   7.69   8.08   7.79   8.08 
     4   7.58   7.62   7.85   7.69 838.93 101.78   7.40   7.40 
     5   7.91   8.03   8.15   8.00 101.50 837.80   7.48   7.84 
     6   7.58   7.67   7.86   7.70   7.45   7.60 838.90 101.25 
     7   7.86   8.02   8.14   8.09   7.36   7.75 101.58 838.25 

For my experiments, I test the communication channels between devices 0↔2 and 0↔3, because the NVLink channel is too fast to observe meaningful differences in communication throughput.

Communication-Communication Overlap vs NCCL

In this scenario, I ran five forward passes of ViT-g/14 with a batch size of 32, and then sent two sets of five 64MB tensors (2**24 floats each) from device 0 to devices 2 and 3.

You can find the code for the experiments below at this link: P2P communication test · GitHub

First, I sent the tensors with the tensor.to method, interleaved. The total communication latency was 97.441ms (~5.8GB/s).

Then, I overlapped the two communication channels by manually creating two separate communication streams. The total latency dropped to 0.75× of the baseline (75.075ms), but the per-channel throughput also decreased.
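A rough sketch of the two variants (not the exact code from the linked gist; v0 is a list of source tensors on cuda:0, and buf2 / buf3 are lists of preallocated destination tensors on cuda:2 and cuda:3):

import torch

s2 = torch.cuda.Stream(device="cuda:0")   # stream for the 0 -> 2 channel
s3 = torch.cuda.Stream(device="cuda:0")   # stream for the 0 -> 3 channel

# Interleaved (sequential) version: all copies enqueued on the current stream.
for v, b2, b3 in zip(v0, buf2, buf3):
    b2.copy_(v, non_blocking=True)
    b3.copy_(v, non_blocking=True)

# Overlapped version: each channel gets its own stream.
for v, b2, b3 in zip(v0, buf2, buf3):
    with torch.cuda.stream(s2):
        b2.copy_(v, non_blocking=True)
    with torch.cuda.stream(s3):
        b3.copy_(v, non_blocking=True)
torch.cuda.synchronize()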

This can be predicted with simple math. Assume the read/write bandwidth is the same everywhere and that there is a single communication channel between each device’s RAM and the host’s RAM. If we overlap two sends, the memcpy D2H (device to host) transfers contend and take twice as long, but the two subsequent memcpy H2D (host to device) transfers can run concurrently with each other. Thus, the latency of a single communication round becomes 1.5 times longer than before. Since the number of sequential rounds is halved, the total latency becomes 0.75×: e.g., 100ms → 75ms.
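As a back-of-the-envelope check of that argument (numbers are rounded, assuming D2H and H2D each take ~5ms per 64MB transfer):

t = 5.0 + 5.0              # ms, one non-overlapped transfer (D2H then H2D)
seq_total = 10 * t         # 10 transfers back to back                     -> 100 ms
con_round = 2 * 5.0 + 5.0  # overlapped pair: the two D2H copies contend (2x),
                           # then the two H2D copies run concurrently      -> 15 ms per round
con_total = 5 * con_round  # 5 rounds of two overlapped transfers each     -> 75 ms (0.75x)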


However, the latency becomes almost 3 times longer (97.441ms → 270.14ms) if we change tensor.to to dist.send. If we give up interleaving and switch to batch_isend_irecv, the total latency increases further (270.14ms → 298.44ms).
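For reference, a sketch of the batch_isend_irecv variant on rank 0 (tensors_to_2 and tensors_to_3 are placeholder lists of tensors; ranks 2 and 3 would post the matching irecv batches):

import torch.distributed as dist

ops = []
for t2, t3 in zip(tensors_to_2, tensors_to_3):
    ops.append(dist.P2POp(dist.isend, t2, 2))   # send to rank 2
    ops.append(dist.P2POp(dist.isend, t3, 3))   # send to rank 3
reqs = dist.batch_isend_irecv(ops)              # coalesced launch of all P2P ops
for req in reqs:
    req.wait()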

Note: I wonder if the NCCL slowdown comes from P2P being disabled on retail GPUs (RTX 3090). I’ll repeat this on an A100 cluster.

Computation-Communication Overlap vs NCCL

Parallelization frameworks such as DeepSpeed usually support communication-computation overlap to enhance throughput. But how much does it actually help?


The above figures show timelines of two sets of five ViT-g/14 forward passes with 10 interleaved communications (64MB per send, GPU 0 to GPU 2 over PCIe).
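A rough sketch of this kind of overlap (not the exact code from the gist; model, micro_batches, and recv_buffers are placeholders, and the NCCL variant simply replaces the copy_ with dist.isend):

import torch

compute_stream = torch.cuda.current_stream("cuda:0")
comm_stream = torch.cuda.Stream(device="cuda:0")

for x, buf in zip(micro_batches, recv_buffers):
    y = model(x)                                # forward pass on the default (compute) stream
    comm_stream.wait_stream(compute_stream)     # the transfer must see the finished forward
    with torch.cuda.stream(comm_stream):
        buf.copy_(y, non_blocking=True)         # memcpy-based "send"; overlaps with the next forward
    y.record_stream(comm_stream)                # keep y's memory alive until the copy finishes
torch.cuda.synchronize()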

Without comm-comp overlap, communication takes 266.9ms and computation 396.8ms. If we overlap them, they take 277.8ms and 584.6ms respectively. In other words, the overlapped portion of the computation becomes about 3 times slower (100ms → ~300ms).

Overlapping certainly provides benefits, but it disrupts the assumptions of pipeline schedule optimizers (e.g., Alpa, DAPPLE) that assume computation and communication are independent and parallelizable.

This slowdown is not observed with memcpy-based communication: its communication time barely changes under overlap (133ms → 137ms), and it remains about two times faster than NCCL.

Here we can also see a delay attributed to “Command Buffer Full”. I suppose the communication load was too large (640MB in total); if I halve the communication size, the gap between the two computation sets disappears.

Documentation errors of dist.send and dist.isend

From the figures of the NCCL comm-comp overlap, we can see the timing of the NCCL kernel invocations from the CPU on the left side of the images (blue vertical bars).

The only difference between the two images is changing dist.send to dist.isend. Contrary to the documentation of dist.send (Distributed communication package - torch.distributed — PyTorch 2.3 documentation), dist.send does not block the Python process, so the following computation kernels are launched seamlessly thanks to asynchronous CUDA kernel launches.

Also, the wait method of the work object returned by dist.isend does not block the Python process; it behaves like torch.cuda.Event.wait(stream), guaranteeing that all kernels subsequently enqueued on stream execute only after the communication has finished.
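In other words (a sketch of the behavior I observed, with tensor, model, and x as placeholders), the sequence looks like this:

import torch
import torch.distributed as dist

work = dist.isend(tensor, dst=1)  # returns immediately; the NCCL kernel runs on the PG's internal stream
y = model(x)                      # enqueued on the current stream, can overlap with the send
work.wait()                       # does NOT block Python; makes the current stream wait for the send
z = y * 2                         # enqueued after wait(), so it runs only after the send completes
torch.cuda.synchronize()          # only this line actually blocks the CPU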

Proposal & Questions

  • NCCL P2P APIs require a strict total order of every NCCL kernel. Can we build out-of-order P2P APIs for multi-node systems?
  • We may be able to implement harmless computation-P2P-communication overlap for multi-node training, since tensor communication should not consume CUDA threads, as memcpy-based pipeline parallelism shows.
    • We might use GPUDirect RDMA or revive tensorpipe for that.
    • RPC is based on tensorpipe, but it does not support direct P2P communication of CUDA tensors. Can we build simpler APIs just for P2P communication between processes?

Cool post! I would like to follow up and read this more carefully after vacation. From a quick read, there are a couple of things I wanted to say.

  1. I think some of the things you assumed about nccl p2p op serialization / strict ordering should not be true. I’d like to understand better how you launch the nccl ops. Within pytorch, we create a new nccl communicator + nccl stream for every pair of send+recv ranks within a given process group, precisely because this is supposed to allow overlap between communication operations with different peers. Specifically, a recv from rank 0 to rank 1 should be able to go in parallel with a send from rank 1 to rank 2. However, a send from rank 1 to rank 2 would not go in parallel with a recv from rank 2 to rank 1; they would use the same communicator and be serialized. Using batch_isend_irecv in this case alone makes sense: it would allow nccl to run these operations together in one kernel. I did a little benchmarking a while back to show this works at least at small scale. Note that you would not want to use batch for ops between different sets of peers (don’t batch a recv from 0 to 1 with a send from 1 to 2, as this would cause strict coupling).

  2. @yifuwang has recently done some work on DMA-based communication primitives. You may want to take a look at this.

Generally, we’ve been focusing more on pipelining for training in large settings where we assume the pipeline would be the outer-most dimension using the slowest network link. I’d be curious to see if you have data suggesting memcpy-based primitives would beat nccl in that setting. But at the same time we’re getting more interested in pipelining for distributed inference of a model too large for one gpu, so it would be a good time to collaborate more on faster comms for the local setting. Feel free to reach out to me on pytorch-slack if you want to discuss more directly.

That’s surprising to know. If we have a process group of size 4, and:

  • they have an allreduce operation
  • rank 0 sends to rank 1
  • rank 0 sends to rank 2

Do you mean the first allreduce will create 4 communicators (1 for each), and each of the following send/recv will create 2 communicators, in total there will be 8 communicators created?

Not quite. For all reduce I would usually recommend using a dedicated process group per parallelism strategy but that actually doesn’t matter in this case. If you have a process group with 4 ranks, it will do the following:

  • for any collectives it will use a 4-rank nccl communicator and cuda stream
  • for p2p operations, it will create a new 2-rank communicator and nccl stream for any unique pair of ranks

So using one process group, you should be able to overlap an all reduce and several p2p operations among different peers.

That said, the all reduce and the p2p operations may use some of the same resources in the gpu, I’m not totally sure about that part.
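For concreteness, a minimal sketch of this pattern (grad_bucket, act, and buf are placeholder tensors, all in a single 4-rank process group):

import torch.distributed as dist

rank = dist.get_rank()

ar_work = dist.all_reduce(grad_bucket, async_op=True)  # 4-rank communicator + its own stream
p2p_works = []
if rank == 0:
    p2p_works.append(dist.isend(act, dst=1))   # uses the (0, 1) pairwise communicator/stream
    p2p_works.append(dist.isend(act, dst=2))   # uses the (0, 2) pairwise communicator/stream
elif rank in (1, 2):
    p2p_works.append(dist.irecv(buf, src=0))

ar_work.wait()
for w in p2p_works:
    w.wait()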

I’m counting the total number of communicators across all processes. According to what you said, it seems my counting should be correct, 8 communicators in total, right?

W.r.t. the stream, when you say a collective would use a 4-rank nccl communicator and cuda stream, does this cuda stream respect the “current cuda stream” in pytorch? I mean, when I change the current stream, will the stream used to launch the nccl collective change? Or does it just use its own stream?

This has great significance w.r.t. cuda graph capture.

PyTorch ProcessGroupNCCL creates internal cuda streams. I described above how different communicators and streams are created.

When performing a collective, the collective stream for that process group is used. The collective is immediately launched to the cuda driver on this stream, and a new cuda event is created to mark the completion of this collective. When users call work.wait(), the current/default stream does an event sync to cause the user stream to block until the collective finishes.

Blocking is entirely within the gpu. The CPU will not block unless something else synchronizes the cpu to cuda, such as printing a tensor or explicitly calling a cuda sync api.

This is also the main difference between send/Isend or allreduce with async_op=true/false. The more ‘blocking’ version of these ops immediately blocks the default stream (not cpu), while the asynchronous version lets the user schedule additional compute on the default stream before the collective finishes.

I’m not an expert on cuda graphs, but I think it should be possible to use them with collectives as long as all the streams sync back to the main/default stream within the same graph capture region. In other words, it should be ok to use asynchronous comm ops as long as you call work.wait before ending graph capture.

Thanks for reply @wconstab!

I withdraw the second claim (that two NCCL P2P commands cannot overlap). After changing the warm-up parts and replacing dist.send/recv with dist.isend/irecv, I can overlap two P2P streams.

However, isn’t it still unsafe to launch many concurrent NCCL sends at the same time, according to the NCCL documentation?
In my experience, some Transformer models spread the result of the first operation to all of the following pipeline stages if we directly slice a compiled FX graph into N parts and run a pipeline on N devices (maybe it comes from the result of the embedding layer). Then we need to execute (N-1) dist.isend calls at once after the first forward pass.


(Sequential / 2 → 0 ← 3)

(Parallel / 2 → 0 ← 3)

(Parallel / 2 → 0 → 3)

Also, I found that two concurrent NCCL P2P streams may or may not harm each other, depending on the direction and the device IDs.

I invoked 5 rounds of the P2P 1 and P2P 2 operations, either sequentially (5× back-to-back dist.send for each channel) or concurrently (5× dist.isend for each channel), as in the snippet below.

import torch.distributed as dist

# rank is this process's global rank; v_inl / v_outl are lists of
# preallocated CUDA tensors (send / receive buffers) on this rank's device.
works = []
if rank == 0:
    # rank 0 posts 5 non-blocking sends to rank 1
    for i in range(5):
        w = dist.isend(v_inl[i], 1)
        works.append(w)
elif rank == 1:
    # rank 1 posts 10 non-blocking receives, alternating between peers 0 and 2
    for i in range(10):
        w = dist.irecv(v_outl[i], 0 if i % 2 == 0 else 2)
        works.append(w)
else:
    # rank 2 also posts 5 non-blocking sends to rank 1
    for i in range(5):
        w = dist.isend(v_inl[i], 1)
        works.append(w)

# wait() makes the current stream wait for each communication (it does not block the CPU)
for w in works:
    w.wait()

Here’s the result (same workload as in the main post).

| SEQ/CON | P2P 1 direction (device id) | P2P 2 direction | P2P 1 mean | P2P 2 mean | Total |
|---|---|---|---|---|---|
| Sequential | 0 → 2 | 0 → 3 | 27ms | 27ms | 269ms |
| Concurrent | 0 → 2 | 0 → 3 | 57ms | 57ms | 294ms |
| Sequential | 2 → 0 | 3 → 0 | 26ms | 21.5ms | 232ms |
| Concurrent | 2 → 0 | 3 → 0 | 27ms | 22.5ms | 132ms |
| Sequential | 2 → 0 | 0 → 3 | 26.5ms | 26ms | 261ms |
| Concurrent | 2 → 0 | 0 → 3 | 46ms → 26.5ms | 29ms | 193ms |
| Sequential (memcpy) | 0 → 2 | 0 → 3 | 9.5ms | 9.5ms | 97.5ms |
| Concurrent (memcpy) | 0 → 2 | 0 → 3 | 14.5ms | 14.5ms | 72.5ms |

We can see that two isend operations can overlap, but there is no advantage in doing so, whereas two irecv operations run concurrently without performance degradation. There is also a strange point: cuda:0 ↔ cuda:3 does not show bidirectionally equal throughput, and with concurrent irecv and isend, only the irecv side is harmed.

Nevertheless, memcpy is still the best choice if we run pipeline parallelism on a single PCIe-connected node.

Does @yifuwang know about the internal behavior of NCCL? I’ll contact you on PyTorch Slack afterwards.

Updated results on A100 clusters

I recently managed to use A100 clusters and here’s the result:

  • The throughput of NCCL inter-node and intra-node P2P communication on the A100 is much more stable than on the RTX 3090.
  • The A100 supports direct P2P memcpy, which I think is why memcpy and NCCL P2P send/recv achieve the same throughput.
  • I suspect there is a serious flaw in the implementation of the NCCL P2P APIs for some environments, including the recent retail RTX series.

Environment

I conducted the experiments on two BullSequana X2415 nodes (4× A100 40GB each). The nodes are connected via InfiniBand HDR, and each pair of devices within a node has its own fast interconnect (the specification does not say it is NVLink or NVSwitch, and its throughput is also much lower than NVLink).

Since the interconnects are too fast to measure with the original workload, I increased the communication size from 64MB (2^24 floats) to 128MB and 256MB (2^25 and 2^26 floats) and checked both throughput and overhead latency.

Communication-Communication overlap





| SEQ/CON | P2P 1 direction (node id, device id) | P2P 2 direction | P2P 1 mean (128MB / 256MB) | P2P 2 mean (128MB / 256MB) |
|---|---|---|---|---|
| Sequential | Cross-node NCCL send (0, 0) → (1, 0) | Cross-node NCCL send (0, 0) → (1, 1) | 8.5ms / 17.0ms | 8.7ms / 17.5ms |
| Concurrent | Cross-node NCCL send | Cross-node NCCL send | 15.5ms / 32.0ms | 16.5ms / 33.0ms |
| Sequential | Intra-node NCCL send (0, 0) → (0, 1) | Intra-node NCCL send (0, 0) → (0, 2) | 1.2ms / 2.7ms | 1.2ms / 1.7ms |
| Concurrent | Intra-node NCCL send | Intra-node NCCL send | 1.5ms / 3.0ms | 1.5ms / 3.9ms |
| Sequential | Intra-node memcpy send | Intra-node memcpy send | 1.4ms / 2.8ms | 1.4ms / 2.8ms |
| Concurrent | Intra-node memcpy send | Intra-node memcpy send | 1.4ms / 2.8ms | 1.4ms / 2.8ms |

Contrary to the original benchmark set, we can observe that intra-node NCCL is slightly faster than the memcpy version. Also, the measured throughput of cross-node communication (120Gbps) is close to the throughput of InfiniBand HDR (200Gbps), so we may conclude that cross-node memcpy is not superior to NCCL, although I couldn’t find a way to achieve cross-node communication between Slurm-launched processes without NCCL.

I think the difference between the 8× RTX 3090 results and the 4× A100 results comes down to support for direct P2P communication over PCIe. Although the X2415 blade is not equipped with NVLink (600GB/s), as the measured intra-node throughput was about 90GB/s, the tensor.copy_ method uses a Memcpy PtoP kernel whose throughput is the same as NCCL P2P send/recv. In contrast, communication between RTX 3090s is a pair of Memcpy DtoH and Memcpy HtoD (PCIe) or a Memcpy DtoD (NVLink).

Communication-Computation overlap




I also checked whether overlapped computation and NCCL communication are harmed on the A100 cluster, and I found that they aren’t.

Across the overall benchmarks, I found that there is a flaw in the combination of the NCCL P2P APIs and recent retail RTX GPUs (3090/4090). I suspect the RTX 3090 has much less capacity for holding concurrent NCCL P2P kernels (in terms of CUDA threads) than the A100.

Some troubles in dist.barrier

I had a hard time obtaining the results above, because dist.barrier sometimes did and sometimes didn’t block the main Python process, even within the same experiment or the same process.

In the Nsight Systems profiler, dist.barrier appears as a single NCCL all-reduce, and I suspect that dist.barrier is internally implemented as an all-reduce used to synchronize the streams of the separate devices. I don’t know whether this implementation faithfully realizes a barrier.
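If the goal is a rendezvous that reliably blocks the CPU as well, one option (just a sketch of a workaround, not what the docs promise) is to pair the barrier with an explicit device synchronization:

import torch
import torch.distributed as dist

dist.barrier()             # with the NCCL backend, this shows up as a small all-reduce
torch.cuda.synchronize()   # force the CPU to wait until that all-reduce has actually finished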

Conclusion

We should beware of environments that do not support direct P2P communication. Although the results with InfiniBand and ML-specialized devices show sufficient throughput for pipeline parallelism, we should check whether there is an alternative channel to increase communication throughput when we are not equipped with top-notch devices.

This is a pretty cool and insightful post about torch NCCL! (And I do think there is a need for more detailed official documentation :slight_smile:)

Hello~
Would you mind showing me your code for the implementation of NCCL comm overlapping with computation? I’m having an issue where NCCL P2P doesn’t start immediately during computation.