#include <cstdint>
#include <vector>

#include <torch/torch.h>
// Header path may differ depending on the PyTorch version.
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>

torch::Tensor ParallelOps::GatherFromTensorModelParallelRegion(
    const torch::Tensor& input /*[in]*/,
    const c10::intrusive_ptr<c10d::ProcessGroup>& process_group /*[in]*/) {
  std::int64_t world_size = process_group->getSize();
  if (world_size == 1) {
    // Nothing to gather in a single-rank group.
    return input;
  }
  // One output buffer per rank for the all-gather.
  std::vector<at::Tensor> output_tensors;
  output_tensors.reserve(world_size);
  for (std::int64_t i = 0; i < world_size; i++) {
    output_tensors.push_back(torch::empty_like(input));
  }
  std::vector<std::vector<at::Tensor>> output_tensors_vec{output_tensors};
  std::vector<at::Tensor> input_vec{input};
  auto work = process_group->allgather(output_tensors_vec, input_vec);
  work->wait();
  // Concatenate the per-rank shards along the last dimension.
  return torch::cat(output_tensors, input.dim() - 1 /*last_dim*/);
}
Given the above code snippet, assume I have acquired a custom CUDA stream right before this function is called. To achieve comm-compute overlap (or to have fine-grained control over when I collect the allgather results), I want to enqueue the communication on that comm stream and only collect the results of this function right before the next op that uses them, roughly the pattern sketched below.
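For concreteness, this is roughly the call-site pattern I have in mind. It is only a sketch: comm_stream, the stream guard, and CallerSketch are my own additions for illustration and are not part of the snippet above.

#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void CallerSketch(const torch::Tensor& input,
                  const c10::intrusive_ptr<c10d::ProcessGroup>& process_group) {
  // Side stream dedicated to communication, taken from the stream pool.
  c10::cuda::CUDAStream comm_stream = c10::cuda::getStreamFromPool();

  torch::Tensor gathered;
  {
    // Make the comm stream current so the allgather work is enqueued on it.
    c10::cuda::CUDAStreamGuard guard(comm_stream);
    gathered = ParallelOps::GatherFromTensorModelParallelRegion(input, process_group);
  }

  // ... independent compute on the default stream should be able to overlap here ...

  // Just before the next op that consumes `gathered`, I want to be sure the
  // allgather has actually completed (this is the part I am asking about).
}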
I want to understand what work->wait(); does here. Does it block until the communication op has actually finished on the CUDA stream before letting the program continue, or does it merely wait for the host to finish enqueuing the collective and then return, while the work on the CUDA stream keeps running asynchronously?
If it is the latter, then I fail to see the point of work->wait(); altogether. Isn't that the default behavior of CUDA streams anyway? And what would be the correct way to block on the outputs of this stream before they are used by the next op, something like the event-based hand-off sketched below?
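To make that last question concrete, here is the kind of thing I am considering: record a CUDA event on the comm stream and make the consumer stream wait on it, rather than blocking the host. comm_stream and gathered come from the sketch above, and next_op is just a placeholder for the consuming op.

#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>

// Mark the point on the comm stream where the allgather results are ready...
at::cuda::CUDAEvent comm_done;
comm_done.record(comm_stream);

// ...and make the current (compute) stream wait on that event, so kernels that
// consume `gathered` are ordered after the allgather without stalling the host.
comm_done.block(c10::cuda::getCurrentCUDAStream());
auto result = next_op(gathered);  // next_op is a placeholder for the consumer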