TL;DR:
Memcpy-based communication (e.g., tensor.to and tensor.copy_) performs far better than the NCCL P2P APIs for pipeline parallelism, but how can we enable it for multi-node, multi-process training with torchrun?
Context
I’ve been constructing a tool for automatic pipeline parallelism by slicing an FX graph produced by torch.compile. To find the optimal configuration of pipeline parallelism, I need to estimate the total runtime of a pipeline, which consists of computation time and communication time between devices.
The core part of pipeline parallelism involves sending the result of a subgraph to the next subgraph. This process relies on P2P communication between two devices. In PyTorch, P2P communication is realized by two functions: tensor.to (or tensor.copy_) and torch.distributed.send. The former uses cudaMemcpy, while the latter employs NCCL P2P kernels under the hood.
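For concreteness, here is a minimal sketch of the two paths (illustrative only; it assumes two visible GPUs in a single process for the memcpy path and an already-initialized process group for the NCCL path):

```python
import torch
import torch.distributed as dist

def send_via_memcpy(src: torch.Tensor, dst_device: torch.device) -> torch.Tensor:
    # tensor.to / tensor.copy_ lowers to cudaMemcpy (peer-to-peer, or staged through host memory).
    return src.to(dst_device, non_blocking=True)

def send_via_nccl(src: torch.Tensor, dst_rank: int) -> None:
    # dist.send launches an NCCL P2P kernel on the NCCL communication stream.
    dist.send(src, dst=dst_rank)

if __name__ == "__main__":
    x = torch.randn(2**24, device="cuda:0")           # 64 MB of float32
    y = send_via_memcpy(x, torch.device("cuda:1"))    # works within a single process
    torch.cuda.synchronize()
```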
Here’s the issue: if we want to use more than 8 GPUs, we inevitably have to rely on multi-node training, usually with torchrun or something similar such as MPI. As far as I know, there is no PyTorch function that allows access to a remote tensor across nodes, except for the NCCL-backed dist.send (cf. RPC/RRef still does not support CUDA tensors).
However, I found some drawbacks of the current torch.distributed APIs compared to tensor.copy_ that block optimal training with pipeline parallelism:
- NCCL P2P can be almost 3 times slower than memcpy on PCIe.
- Multiple NCCL kernels cannot be overlapped and must follow a predefined strict total order. More precisely, concurrent NCCL kernels are unsafe: Creating a Communicator — NCCL 2.21.5 documentation. (See the sketch after this list.)
  - This limitation forces users to carefully schedule every communication from and to the devices within a single (logical) stream.
  - If we want to use more axes of parallelism (e.g., 3D parallelism), we must define a linear order over every communication: FSDP’s all-gather, TP’s all-reduce, and PP’s P2P.
- When we overlap NCCL P2P communication and computation, both become slower, whereas memcpy-based P2P does not slow down.
- The synchronization models of dist.send and dist.isend are easy to use, but they differ from the documentation. dist.send does not block the Python process; instead, it inserts a synchronization barrier between the NCCL communication stream and the main computation stream (or the context of the stream where dist.send is called).
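To illustrate the ordering constraint from the list above, here is a minimal sketch (assuming a 2-rank NCCL process group; not a benchmark):

```python
import torch
import torch.distributed as dist

def pipeline_exchange(rank: int, activation: torch.Tensor, grad: torch.Tensor) -> None:
    # With NCCL, both ranks must enqueue their P2P calls in one agreed global order;
    # reordering the send/recv pairs on only one rank is unsafe.
    if rank == 0:
        dist.send(activation, dst=1)   # forward activation: 0 -> 1
        dist.recv(grad, src=1)         # backward gradient:  1 -> 0
    elif rank == 1:
        dist.recv(activation, src=0)
        dist.send(grad, dst=0)
```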
I conducted some experiments to compare the NCCL P2P APIs with memcpy and to confirm the points above. Afterwards, I propose features that could improve multi-node pipeline parallelism.
Communication and computation in pipeline parallelism
In brief, pipeline parallelism is a type of model parallelism that divides the computation graph into N subgraphs, and sequentially executes the subgraphs with N devices and M micro-batches.
The figure above displays an Nsight Systems log of a single training iteration of the 1F1B (1-forward 1-backward) pipeline with 4 micro-batches and the ViT-g/14 model. The blue regions with gray NVTX boxes below represent the forward passes, while the remaining areas indicate the backward passes.
Let’s examine what is happening between the forward and backward passes. The above shows the timelines of devices 2 to 4, and there are four pairs of communications: D2H (purple) / H2D (green) through PCIe, and D2D (red) through NVLink. From device 2 to 3, the activation tensor resulting from the forward pass in device 2 is moved to device 3, and the gradient tensor of the backward pass in device 3 is sent back immediately after.
We can observe that the second device almost simultaneously receives and sends tensors, and we cannot accurately predict the order of communication because it depends on the timing of each previous forward/backward computation.
The above figures are captured from my research project about automatic pipeline parallelism. What would happen if we change the communication scheme to NCCL?
This figure illustrates the forward passes of ViT-g/14 training with Colossal-AI’s pipeline parallelism. Since ‘Send 2→3’ is scheduled later than ‘Send 1→2’ on device 2, the subsequent forward pass on device 3 is blocked until ‘Send 1→2’ is completed.
This behavior complicates pipeline optimization; it’s not enough to consider only the order of communication on each individual device. It’s also necessary to examine all communication dependencies across devices and nodes. How can we determine whether ‘Send 1→2’ should start before ‘Send 2→3’ to reduce total latency?
Experiments
Environment
I conducted the experiments using 8 RTX 3090 GPUs and 2 CPU NUMA nodes in a single rack server. There are 4 NVLink bridges, each connecting a device pair (0↔1, 2↔3, 4↔5, 6↔7); all other device pairs are connected via PCIe bridges.
## dmidecode -t 2
# dmidecode 3.3
Getting SMBIOS data from sysfs.
SMBIOS 3.2.0 present.
Handle 0x0002, DMI type 2, 15 bytes
Base Board Information
Manufacturer: Supermicro
Product Name: H12DSG-O-CPU
Version: 1.01A
Serial Number: UM208S600092
Asset Tag: To be filled by O.E.M.
Features:
Board is a hosting board
Board is removable
Board is replaceable
Location In Chassis: To be filled by O.E.M.
Chassis Handle: 0x0003
Type: Motherboard
Contained Object Handles: 0
## nvidia-smi topo -m
GPU0 GPU1 GPU2 GPU3 GPU4 GPU5 GPU6 GPU7 CPU Affinity NUMA Affinity GPU NUMA ID
GPU0 X NV4 SYS SYS SYS SYS SYS SYS 0-15,32-47 0 N/A
GPU1 NV4 X SYS SYS SYS SYS SYS SYS 0-15,32-47 0 N/A
GPU2 SYS SYS X NV4 SYS SYS SYS SYS 0-15,32-47 0 N/A
GPU3 SYS SYS NV4 X SYS SYS SYS SYS 0-15,32-47 0 N/A
GPU4 SYS SYS SYS SYS X NV4 SYS SYS 16-31,48-63 1 N/A
GPU5 SYS SYS SYS SYS NV4 X SYS SYS 16-31,48-63 1 N/A
GPU6 SYS SYS SYS SYS SYS SYS X NV4 16-31,48-63 1 N/A
GPU7 SYS SYS SYS SYS SYS SYS NV4 X 16-31,48-63 1 N/A
## nvidia-smi
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.67 Driver Version: 550.67 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 3090 Off | 00000000:01:00.0 Off | N/A |
| 30% 38C P8 33W / 350W | 14MiB / 24576MiB | 0% Default |
| | | N/A |
# ...
# ----------------....................................-------------------------------
# ...
| 7 NVIDIA GeForce RTX 3090 Off | 00000000:E1:00.0 Off | N/A |
| 30% 34C P8 27W / 350W | 14MiB / 24576MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
## p2pBandwidthLatencyTest
Bidirectional P2P=Enabled Bandwidth Matrix (GB/s)
D\D 0 1 2 3 4 5 6 7
0 838.25 101.78 8.44 8.48 7.60 7.95 7.61 7.91
1 101.74 838.03 8.74 8.73 7.60 8.03 7.65 8.00
2 8.45 8.75 837.80 101.65 7.87 8.17 7.91 8.13
3 8.50 8.74 101.68 837.13 7.69 8.08 7.79 8.08
4 7.58 7.62 7.85 7.69 838.93 101.78 7.40 7.40
5 7.91 8.03 8.15 8.00 101.50 837.80 7.48 7.84
6 7.58 7.67 7.86 7.70 7.45 7.60 838.90 101.25
7 7.86 8.02 8.14 8.09 7.36 7.75 101.58 838.25
In my experiments, I test the communication channels between devices 0↔2 and 0↔3, because the NVLink channel (0↔1) is too fast to observe the communication throughput.
Communication-Communication Overlap vs NCCL
In this scenario, I ran five forward passes of ViT-g/14 with a batch size of 32, and then sent two sets of five 64MB tensors (2**24 floats) from device 0 to devices 2 and 3.
You can find the code for the experiments below at this link: P2P communication test · GitHub
First, I sent the tensors with the tensor.to method, interleaving the two destinations. The total communication latency was 97.441 ms (~5.8 GB/s).
Then, I overlapped the two communication channels by manually creating two separate communication streams. The total latency is reduced to 0.75× (75.075 ms), but the per-transfer throughput also decreases.
This can be predicted with simple math. Let’s assume that the read/write bandwidth is the same everywhere and that each device has a single communication channel between its RAM and the host’s RAM. If we overlap two sends, the memcpy D2H (device to host) takes twice as long, since both transfers originate from device 0 and share its channel, but the two subsequent memcpy H2D (host to device) transfers target different devices and can run concurrently. Thus, the latency of a single communication becomes 1.5 times longer than before. Since the number of communication rounds is halved, the total latency becomes 0.75×: 100 ms to 75 ms.
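The same arithmetic as a tiny Python model (illustrative only, not measured data):

```python
# Each send is a D2H leg followed by an H2D leg of equal duration t. Both sends
# originate on device 0, so overlapped D2H legs share device 0's channel, while
# the H2D legs target different devices and run concurrently.
t = 1.0                                               # one D2H or H2D leg, arbitrary units
n_sends = 10                                          # ten 64 MB transfers in total

serial_total = n_sends * (t + t)                      # 2t per send               -> 20t
overlapped_pair = 2 * t + t                           # shared D2H, parallel H2D  -> 3t
overlapped_total = (n_sends // 2) * overlapped_pair   # five pairs                -> 15t

print(overlapped_total / serial_total)                # 0.75
```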
However, the transfer becomes about 3 times slower (97.441 ms → 270.14 ms) if we change tensor.to to dist.send. If we give up interleaving and switch to batch_isend_irecv, the total latency increases further (270.14 ms → 298.44 ms).
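For reference, the NCCL variants look roughly like this (a sketch with the same tensor shapes as above, not the code from the linked gist):

```python
import torch
import torch.distributed as dist

def interleaved_nccl_send(tensors, dst_ranks):
    # One NCCL P2P kernel per tensor, enqueued in call order.
    reqs = [dist.isend(t, dst=dst) for t, dst in zip(tensors, dst_ranks)]
    for req in reqs:
        req.wait()

def batched_nccl_send(tensors, dst_ranks):
    # batch_isend_irecv groups the P2P ops and launches them together.
    ops = [dist.P2POp(dist.isend, t, peer=dst) for t, dst in zip(tensors, dst_ranks)]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
```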
Note: I wonder whether the NCCL slowdown comes from P2P access being disabled on consumer GPUs (RTX 3090). I will repeat this experiment on an A100 cluster.
Computation-Communication Overlap vs NCCL
Parallelization frameworks such as DeepSpeed usually support communication-computation overlap to enhance throughput. But by how much?
The above figures are timelines of two sets of five ViT-g/14 forward passes with 10 interleaved communications (64 MB per send, GPU 0 to GPU 2 through PCIe).
Without comm-comp overlap, communication takes 266.9 ms and computation 396.8 ms. However, if we overlap them, they take 277.8 ms and 584.6 ms, respectively. This means the overlapped portion of the computation becomes about 3 times slower (100 ms → ~300 ms).
Overlapping certainly provides benefits, but it breaks the assumption of pipeline schedule optimizers (e.g., Alpa, DAPPLE) that computation and communication are independent and can run in parallel.
This slowdown is not observed with memcpy-based communication (133 ms → 137 ms with overlap), which is also about twice as fast as NCCL.
Here, we can see a delay caused by “Command Buffer Full”. I suppose the communication load was too large (640 MB in total); if I reduce the communication size by half, the gap between the two computation sets disappears.
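A minimal sketch of the memcpy-based overlap pattern measured here (illustrative only: a small Linear layer stands in for the ViT-g/14 forward pass, and it assumes at least three visible GPUs so that cuda:0 → cuda:2 crosses PCIe as in the topology above):

```python
import torch

dev0, dev2 = torch.device("cuda:0"), torch.device("cuda:2")
comm_stream = torch.cuda.Stream(device=dev0)            # dedicated communication stream

model = torch.nn.Linear(1024, 1024).to(dev0)            # stand-in for the ViT-g/14 forward pass
x = torch.randn(32, 1024, device=dev0)
payload = torch.randn(2**24, device=dev0)               # 64 MB activation to ship to cuda:2

comm_stream.wait_stream(torch.cuda.current_stream(dev0))  # payload is ready before we copy it
with torch.cuda.stream(comm_stream):
    received = payload.to(dev2, non_blocking=True)      # memcpy runs on the side stream

for _ in range(5):
    x = model(x)                                        # compute stays on the default stream

torch.cuda.synchronize(dev0)
torch.cuda.synchronize(dev2)
```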
Documentation errors of dist.send and dist.isend
From the figures of the NCCL comm-comp overlap, we can see the timing of the NCCL kernel invocations from the CPU on the left side of the images (blue vertical bars).
The only difference between the two images is changing dist.send to dist.isend. However, contrary to the documentation of dist.send (Distributed communication package - torch.distributed — PyTorch 2.3 documentation), it does not block the Python process, so the following computation kernels are launched seamlessly thanks to asynchronous CUDA kernel launches.
Also, the wait method of the request returned by dist.isend does not block the Python process; it works similarly to torch.cuda.Event.wait(stream), which guarantees that all subsequent kernels enqueued on stream are executed after the wait.
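One way to observe this behavior (a sketch; it assumes an already-initialized NCCL process group, runs on the sending rank, and rank 1 posts the matching recv):

```python
import time
import torch
import torch.distributed as dist

x = torch.randn(2**24, device="cuda")    # 64 MB payload

req = dist.isend(x, dst=1)               # rank 1 must post a matching recv/irecv

t0 = time.perf_counter()
req.wait()                               # returns almost immediately on the host
host_ms = (time.perf_counter() - t0) * 1e3

torch.cuda.synchronize()                 # the transfer itself completes here
print(f"host-side wait() took {host_ms:.3f} ms")
```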
Proposal & Questions
- NCCL P2P APIs require a strict total order of every NCCL kernel. Can we build out-of-order P2P APIs for multi-node systems?
- We may be able to implement harmless computation/P2P-communication overlap for multi-node training, since tensor communication should not consume CUDA threads, as we can see from memcpy-based pipeline parallelism.
- We might use GPUDirect RDMA or revive tensorpipe for that.
- RPC is based on tensorpipe, but it does not support direct P2P communication of CUDA tensors. Can we build simpler APIs just for P2P communication between processes? (See the hypothetical sketch below.)
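To make the last question concrete, here is a purely hypothetical shape such an API could take; none of these names exist in PyTorch today:

```python
import torch

class P2PHandle:
    """Hypothetical handle for an in-flight transfer."""

    def wait_stream(self, stream: torch.cuda.Stream) -> None:
        """Order `stream` after the transfer without blocking the Python process."""
        raise NotImplementedError  # hypothetical

class P2PChannel:
    """Hypothetical memcpy/RDMA-backed channel between two ranks.

    Transfers would be out of order with respect to NCCL collectives and would
    only synchronize with the streams the user explicitly names.
    """

    def __init__(self, peer_rank: int, device: torch.device):
        self.peer_rank = peer_rank
        self.device = device

    def isend(self, tensor: torch.Tensor) -> P2PHandle:
        raise NotImplementedError  # hypothetical

    def irecv(self, out: torch.Tensor) -> P2PHandle:
        raise NotImplementedError  # hypothetical
```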