Correctly using wait() in process group communications like all_gather

torch::Tensor ParallelOps::GatherFromTensorModelParallelRegion(
    const torch::Tensor& input /*[in]*/,
    const c10::intrusive_ptr<c10d::ProcessGroup>& process_group /*[in]*/) {
  std::int64_t world_size = process_group->getSize();
  if (world_size == 1) {
    return input;  // nothing to gather with a single rank
  }

  // One receive buffer per rank, shaped like the local shard.
  std::vector<at::Tensor> output_tensors;
  for (std::int64_t i = 0; i < world_size; i++) {
    output_tensors.push_back(torch::empty_like(input));
  }

  std::vector<std::vector<at::Tensor>> output_tensors_vec{output_tensors};
  std::vector<at::Tensor> input_vec{input};

  // Enqueue the collective, then synchronize before concatenating.
  auto work = process_group->allgather(output_tensors_vec, input_vec);
  work->wait();
  return torch::cat(output_tensors, input.dim() - 1 /*last_dim*/);
}

Given the above code snippet, suppose I have acquired a custom CUDA stream right before this function is called. To achieve comm-compute overlap (or to have fine-grained control over when I collect the results of the allgather), I want to enqueue the communication task on the comm stream and collect this function's results right before the next op that uses them.

I want to understand what work->wait() does here. Does it wait for the communication op to finish on the CUDA stream before allowing the program to continue execution, or does it just wait for the host to finish enqueuing the job and then continue running the host code (while the task on the CUDA stream runs asynchronously)?

If it is the latter, then I fail to see the point of work->wait() altogether. Isn't that the default behavior of CUDA streams? Also, what would be the correct way to block on the outputs of this stream before they can be used by the next op?

AFAIK, work->wait() makes the current stream wait for the NCCL events (i.e. the NCCL kernels) to complete, so it won't block your program by default. Async NCCL ops run on dedicated NCCL streams, so you have to wait for them explicitly.

So I don't think you have to create a custom CUDA stream for a single allgather op. Just call work->wait() when you need to use its output.
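
To make the "wait only when you need the output" pattern concrete, here is a minimal sketch. The AsyncGatherResult struct and the LaunchGather name are mine, not from c10d; the point is that the Work handle escapes the function and wait() is deferred to the call site.

struct AsyncGatherResult {
  std::vector<at::Tensor> outputs;      // one receive buffer per rank
  c10::intrusive_ptr<c10d::Work> work;  // handle to the in-flight collective
};

AsyncGatherResult LaunchGather(
    const torch::Tensor& input,
    const c10::intrusive_ptr<c10d::ProcessGroup>& process_group) {
  std::vector<at::Tensor> outputs;
  for (std::int64_t i = 0; i < process_group->getSize(); i++) {
    outputs.push_back(torch::empty_like(input));
  }
  std::vector<std::vector<at::Tensor>> outputs_vec{outputs};
  std::vector<at::Tensor> input_vec{input};
  auto work = process_group->allgather(outputs_vec, input_vec);  // enqueue only
  return {std::move(outputs), std::move(work)};                  // no wait() here
}

// Caller side:
//   auto gathered = LaunchGather(input, process_group);
//   auto y = IndependentCompute(x);  // hypothetical op that overlaps with the allgather
//   gathered.work->wait();           // current stream now waits on the NCCL kernels
//   auto out = torch::cat(gathered.outputs, input.dim() - 1);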

By the way, in the code you provided, the output tensors are created on your custom CUDA stream. That may go wrong if you don't manage their lifetime explicitly through record_stream.
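
For reference, a minimal sketch of that record_stream point, assuming comm_stream is the custom stream you acquired (the variable names here are mine). The caching allocator tracks allocations per stream, so a tensor allocated under comm_stream but consumed on another stream should have that consumer stream recorded on it:

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

std::vector<at::Tensor> output_tensors;  // filled inside the guarded scope
c10::cuda::CUDAStream comm_stream = c10::cuda::getStreamFromPool();
{
  c10::cuda::CUDAStreamGuard guard(comm_stream);
  // torch::empty_like(...) and process_group->allgather(...) run here, so the
  // caching allocator associates the output buffers with comm_stream.
}
// When the outputs are later consumed on a different stream, record that
// stream on each tensor so the memory is not reused too early after free:
for (auto& t : output_tensors) {
  t.record_stream(c10::cuda::getCurrentCUDAStream());
}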

Thanks for the response!

When you say "Async NCCL ops run on dedicated NCCL streams, so you have to wait for them explicitly", does this imply that in the code snippet shared above, the allgather (or any NCCL comm) would run on streams specially allocated for NCCL ops, and not on the stream that I might have defined earlier?

Also, is this the right way to put it: wait() ensures that the outputs of this op are complete before they are used, without blocking execution on the host?

1. Yes. You can check this. After 2.8.0 the collective will run on the current stream when asyncOp=false (and before 2.8.0 work.wait() was called implicitly when the op was synchronous, so there isn't much difference).
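
If it helps, that option can be set from C++ roughly as follows. This is an assumption on my side: I'm taking the field name asyncOp from recent c10d option structs, so check Types.hpp for your PyTorch version:

c10d::AllgatherOptions opts;
opts.asyncOp = false;  // assumed field name; verify against your c10d headers
// With asyncOp=false (>= 2.8.0) the collective launches on the current stream,
// so no extra stream-level synchronization is needed before using the outputs.
auto work = process_group->allgather(output_tensors_vec, input_vec, opts);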

2. Yes. wait() only makes the current CUDA stream wait for the NCCL kernels, not the host (except for barrier and blocking-wait mode). But be careful when you manage multiple streams yourself.
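
As an illustration of that multi-stream caveat, one way to order a custom comm stream against the compute stream is with CUDA events instead of a host sync. This is only a sketch, with stream and variable names of my own choosing:

#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

auto compute_stream = c10::cuda::getCurrentCUDAStream();
auto comm_stream = c10::cuda::getStreamFromPool();

// 1. Make sure `input` is fully produced before the comm stream reads it.
at::cuda::CUDAEvent input_ready;
input_ready.record(compute_stream);
input_ready.block(comm_stream);

// 2. Launch the collective from the comm stream.
c10::intrusive_ptr<c10d::Work> work;
{
  c10::cuda::CUDAStreamGuard guard(comm_stream);
  work = process_group->allgather(output_tensors_vec, input_vec);
}

// 3. Just before consuming the outputs: wait() inserts a stream-level wait on
//    the *current* stream (compute_stream here), and record_stream protects
//    buffers that were allocated under comm_stream.
work->wait();
for (auto& t : output_tensors) {
  t.record_stream(compute_stream);
}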