tl;dr
I am interested in adding support for one-sided distributed GEMMs to DTensor, and I'm seeking comments on the right way to integrate them. I see two main issues: 1) DTensor would likely need to allocate local tensors in symmetric memory (when enabled), and 2) it would need a way to select one-sided algorithms at dispatch time in cases where they are available and optimal.
Why One-Sided Algorithms?
There’s been a lot of work recently on one-sided algorithms for distributed GEMM. Instead of using collectives, these algorithms use one-sided put and get operations to move data. They directly use the put/get operations exposed by intra-node networks (like NVLink) and RDMA inter-node networks (like InfiniBand).
There are a few areas where one-sided algorithms have an edge over collective-based algorithms:
- Communication/Computation Overlap: In each iteration, one-sided GEMMs perform one or two get operations, a local GEMM, and possibly a put or remote accumulate. (The precise mix depends on the chosen data-movement strategy, which usually follows from the matrices' partitionings.) This makes it straightforward to overlap computation and communication inside the distributed GEMM, unlike when AllGather/ReduceScatter collectives are used. There is significant recent work on optimized one-sided algorithms, much of it in the context of GPU-initiated communication ([1], [2], [3]), where small tiled GEMMs can be overlapped with transfers within the kernel. Similar overlapping techniques can also be applied with CPU-initiated communication.
- Partitioning Flexibility: One-sided implementations support a wide range of partitionings, including some that are not straightforward to express with DTensor's current collective-based approach, such as multiplications between matrices with mixed replication factors and 2D-replicated partitionings. These partitionings achieve excellent performance on NVLink systems ([4], [5]).
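To make the overlap structure concrete, here is a single-process NumPy sketch of a get-based "ring" distributed GEMM for one simple partitioning (A row-sharded, B column-sharded). The `shards` lists stand in for symmetric-memory buffers and `.copy()` stands in for a one-sided get; in a real kernel the get for step s+1 would be issued while the GEMM for step s runs. All names here are illustrative, not DTensor APIs.

```python
import numpy as np

def ring_gemm(a_shards, b_shards):
    """Simulated rank p owns row block A_p and column block B_p.

    Rank p fills its row block of C by *getting* each peer's B_q and
    running a local GEMM; no AllGather or ReduceScatter is needed for
    this partitioning."""
    P = len(a_shards)
    widths = [b.shape[1] for b in b_shards]
    offsets = np.concatenate(([0], np.cumsum(widths)))
    n = offsets[-1]
    c_shards = []
    for p in range(P):
        c_p = np.zeros((a_shards[p].shape[0], n))
        for step in range(P):
            q = (p + step) % P        # ring order staggers peer traffic
            b_q = b_shards[q].copy()  # stands in for a one-sided get
            c_p[:, offsets[q]:offsets[q + 1]] = a_shards[p] @ b_q
        c_shards.append(c_p)
    return np.vstack(c_shards)
```

Stacking the per-rank row blocks recovers the full product, so the result can be checked directly against `A @ B`.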
What would be required for DTensor to support one-sided GEMMs?
I see two potential changes that would be required for DTensor to support distributed GEMMs that use one-sided communication.
Local Tensors Allocated in Symmetric Memory
Typically, data structures that will be used with one-sided communication are allocated in “pinned” symmetric memory. On NVLink, this usually involves passing a cudaIpcMemHandle_t object between processes to register the memory allocations on each process (or calling a library like NVSHMEM that does this for you). On inter-node interconnects like InfiniBand, this involves registering the memory region with the NIC and ensuring it will not be swapped out.
The most direct way to accomplish this would be to ensure that a DTensor's local tensor is always allocated in symmetric memory (at least when one-sided GEMMs are enabled). This could be done by modifying 1) DTensor's constructor and 2) distribute_tensor, so that both allocate the local tensor in symmetric memory. It's worth noting that allocating in pinned memory can have performance benefits even if you never issue one-sided put/get yourself: collectives that would otherwise have to pin memory on the fly (or stage through a bounce buffer) can skip that overhead. So supporting local tensors allocated in symmetric memory may be worthwhile independent of one-sided algorithms. The actual allocation could be done with SymmetricMemory or, in the future, TorchComms, which merges one-sided memory into the main communication backend.
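The property local tensors would need can be illustrated with a toy single-process pool: every simulated rank allocates a buffer under a shared name, and any peer can then address it by (rank, name) without the owner's participation, which is exactly what one-sided get/put relies on. `SymmetricPool` and its methods are hypothetical stand-ins, not SymmetricMemory's actual API.

```python
import numpy as np

class SymmetricPool:
    """Toy stand-in for symmetric-memory allocation (illustrative only)."""

    def __init__(self, world_size):
        self.world_size = world_size
        self.buffers = {}  # (rank, name) -> np.ndarray

    def alloc(self, rank, name, shape):
        # every rank allocates the same-named buffer, so peers can find it
        assert rank < self.world_size
        buf = np.zeros(shape)
        self.buffers[(rank, name)] = buf
        return buf

    def get(self, src_rank, name):
        # one-sided read of a peer's buffer; src_rank takes no action
        return self.buffers[(src_rank, name)].copy()

    def put(self, dst_rank, name, data):
        # one-sided write into a peer's buffer; dst_rank takes no action
        self.buffers[(dst_rank, name)][...] = data
```

In this picture, DTensor's constructor and distribute_tensor would route local-tensor allocation through such a pool instead of the default allocator.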
It should be noted that, at least with certain backends, it is possible to perform a rendezvous that registers memory after the fact, rather than at allocation time. This would allow one-sided operations to perform the rendezvous inside the algorithm call itself, avoiding modifications to DTensor, but at a performance cost.
Exposing One-Sided Algorithms
My understanding is that DTensor's algorithm selection for distributed GEMM is driven by sharding propagation. In essence, there are a number of possible sharding options, and each selected sharding implies a fixed execution flow (AllGather, local GEMM, ReduceScatter) with a well-defined cost. DTensor then considers the different sharding rules and picks the one with the lowest cost.
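One way to picture the extension being proposed: instead of ranking candidates that pair an output sharding with a fixed collective flow, the search space would pair each sharding with an algorithm variant, and infeasible variants (e.g., a one-sided path without symmetric memory) would be filtered out before the cost comparison. This is a toy sketch with hypothetical names, not DTensor's actual sharding-propagation code.

```python
from dataclasses import dataclass

@dataclass
class Candidate:
    out_sharding: str       # e.g. "Shard(0)" or "Replicate"
    algorithm: str          # "collective" (allgather + mm) or "one_sided"
    comm_bytes: int         # modeled communication volume (the cost)
    feasible: bool = True   # e.g. one-sided path needs symmetric memory

def pick(candidates):
    """Choose the cheapest feasible (sharding, algorithm) pair."""
    viable = [c for c in candidates if c.feasible]
    return min(viable, key=lambda c: c.comm_bytes)
```

A real cost model would also weigh overlap potential, not just bytes moved, but the selection structure would be the same.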
It seems to me that incorporating new types of algorithms would require modifying the sharding propagation logic, since it would need to select not only a sharding but also an algorithm (e.g., the original DTensor path or the one-sided path, where available).
I have implemented a prototype that adds support for one-sided algorithms to DTensor by registering a custom op handler with DTensor's op dispatcher. The handler intercepts an ATen op such as addmm.out and uses simple heuristics to decide whether a one-sided algorithm is supported and likely optimal for the case at hand; if not, it falls back to the existing DTensor path. While this works, real upstream support would probably need to integrate with the sharding propagation logic instead.
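The dispatch shape of the prototype can be modeled in a few lines: a registry maps op names to handlers, and the matmul handler checks a couple of preconditions before choosing the one-sided path, otherwise deferring to the default path. The registry, the `spec` dictionary, and its keys are all illustrative stand-ins, not DTensor's dispatcher API.

```python
import numpy as np

HANDLERS = {}

def register(op_name):
    """Register a custom handler for op_name (illustrative registry)."""
    def deco(fn):
        HANDLERS[op_name] = fn
        return fn
    return deco

def default_mm(a, b, spec):
    # stand-in for DTensor's existing collective-based path
    return ("collective", a @ b)

@register("aten.mm")
def mm_handler(a, b, spec):
    # Take the one-sided path only when local tensors live in symmetric
    # memory and the partitioning is one the one-sided kernel supports.
    if spec.get("symmetric") and spec.get("layout") in {"rowcol", "2d_replicated"}:
        return ("one_sided", a @ b)  # stand-in for the one-sided kernel
    return default_mm(a, b, spec)    # fall back to the existing path

def dispatch(op_name, a, b, spec):
    return HANDLERS.get(op_name, default_mm)(a, b, spec)
```

Both paths produce the same numerics; only the chosen execution strategy differs.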
To end succinctly, here are my questions for the community:
- Is modifying DTensor to allocate local tensors with SymmetricMemory (or TorchComms, once upstreamed) tenable? Or is there another mechanism we should consider?
- What is the proper way to integrate new classes of algorithms, such as one-sided algorithms, into DTensor's algorithm dispatch? Would extending the sharding propagation logic to select among algorithm variants as well as shardings be the right way to go?
If this seems like a reasonable direction, I will develop my NVSHMEM prototype into something more robust built on top of SymmetricMemory to upstream.
[1] Optimizing Distributed ML Communication with Fused Computation-Collective Operations (arXiv:2305.06942)
[2] FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion (arXiv:2406.06858)
[3] Iris: First-Class Multi-GPU Programming Experience in Triton (arXiv:2511.12500)