[RFC][c10d] a new Pytorch API (split_group) to create a process group through ncclCommSplit

This is a repost of the RFC on GitHub: [RFC][c10d] a new Pytorch API (split_group) to create a process group through ncclCommSplit · Issue #130407 · pytorch/pytorch · GitHub

Motivation
In current PyTorch/c10d, the new_group API is used to create a new process group from the default pg. When device_id is specified in init_process_group and NCCL is used as the backend, the new_group call uses ncclCommSplit to create the NCCL communicators in order to save communicator resources. This approach has a few drawbacks:

  1. Redundant calls
    Suppose the default group has 256 ranks and we need 32 child PGs, each with 8 ranks. In this case, each rank has to call new_group and ncclCommSplit 32 times because of how the new_group API is implemented and because ncclCommSplit is a collective over the parent communicator. For a given global rank, 31 of those ncclCommSplit calls are no-color splits and only 1 is a colored split. With the proposed split_group API, only 1 call of split_group/ncclCommSplit is needed per rank in this example.
  2. new_group can only split from the default pg
    Ideally, a new pg should be able to be split from any existing pg.

With the new split_group API, users can create new PGs via ncclCommSplit with far fewer calls and initialize the PG eagerly. This is also useful for Pipeline Parallelism, where P2P communicators need to be created efficiently. A rough sketch of the call-count difference is shown below.
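
For concreteness, here is a minimal sketch of the call-count difference for the 256-rank example above. This is illustrative only: it assumes the proposed signature and that the default pg was already initialized with 256 ranks and a device_id so that ncclCommSplit can be used.

import torch.distributed as dist

# Today: every rank must enter every new_group call, even for the 31 subgroups
# it does not belong to, so each rank ends up issuing 32 ncclCommSplit calls
# (31 of them as no-color splits).
subgroup_ranks = [list(range(i * 8, (i + 1) * 8)) for i in range(32)]  # 32 disjoint groups of 8
subgroups = [dist.new_group(ranks=r) for r in subgroup_ranks]

# Proposed: a single call per rank; each rank issues exactly one ncclCommSplit,
# colored according to the one split it belongs to.
my_pg = dist.split_group(split_ranks=subgroup_ranks)  # parent_pg defaults to the default pg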

Proposal
A new c10d API:

def split_group(
    parent_pg=None,
    split_ranks=None,
    timeout=None,
    pg_options=None,
    group_desc=None,
):
    """
    Create a new distributed group split from the given parent group.

    Users of this API must guarantee that all ranks in the parent group enter this API call,
    and that the split of the group is the same across all ranks.

    Args:
        parent_pg (ProcessGroup, optional): The parent process group. If None,
            the default process group will be used. Users need to guarantee that
            the parent group is fully initialized (e.g., its communicators are initialized).
        split_ranks (list[list[int]]): the split ranks, which is a list of lists of ranks.
            Users need to make sure the splits are valid, i.e., one split
            (represented by one inner list of ints) does not overlap with any other split.
            Note: the ranks in each split are the group ranks in the parent pg.
        timeout (timedelta, optional): see `init_process_group` for details and default value.
        pg_options (ProcessGroupOptions, optional): additional backend-specific options
            for the new process group.
        group_desc (str, optional): a string to describe the process group.

    Returns:
        A handle of the distributed group that can be given to collective calls, or
        GroupMember.NON_GROUP_MEMBER if the rank is not part of any split_ranks.

    """

Example usage of the API

# suppose parent_pg has 8 ranks [0,1,2,3,4,5,6,7]
new_pg = dist.split_group(parent_pg, [[0, 1, 2, 3], [4, 5, 6, 7]])
# call a collective within the new pg on each rank
dist.all_reduce(tensor, group=new_pg)
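
For a fuller picture, below is a self-contained sketch of end-to-end usage. It is illustrative only: the script name and launch command are hypothetical, and it assumes the proposed split_group API with one CUDA device per rank. Note that device_id must be passed to init_process_group so the default pg's NCCL communicator is created eagerly and can later be split via ncclCommSplit.

# Hypothetical launch: torchrun --nproc_per_node=8 split_group_demo.py
import os
import torch
import torch.distributed as dist

rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")

# device_id is needed so the default pg's NCCL communicator is initialized
# eagerly and can later be split via ncclCommSplit.
dist.init_process_group("nccl", device_id=device)

# one split_group call per rank; each rank lands in exactly one 4-rank subgroup
new_pg = dist.split_group(split_ranks=[[0, 1, 2, 3], [4, 5, 6, 7]])

tensor = torch.ones(4, device=device)
dist.all_reduce(tensor, group=new_pg)  # reduces only within the 4-rank subgroup

dist.destroy_process_group()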