Following a few issues with the MPS implementation of at::to and at::copy_, I wanted to take a quick dive into what these functions do and how they are implemented.
At the same time, there were some discussions about the behavior of the memory_format argument of at::to(), which is related and so is covered here as well.
at::to()
First, let’s define what these functions do. The at::to function has multiple overloads, which makes it quite confusing. As you might expect, its Python binding is written by hand (in the template for python_variable_methods.cpp), and that custom binding chooses which overload to call to avoid any issue with C++ function overloading.
This function first checks whether anything needs to be done, i.e. whether the requested property already holds for the Tensor (for example, calling .to("cuda") on a Tensor that is already on the CUDA device).
Memory format is special here: the call is a no-op if the given memory_format is Preserve, or if the input Tensor’s suggested memory format is the same as the one being passed.
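A minimal pure-Python sketch of this no-op check, assuming a hypothetical FakeTensor that only records dtype, device and the result of suggest_memory_format(). The real check in ATen also looks at layout and at whether a copy was explicitly requested; those are left out here. The point it illustrates is that an explicit memory_format is compared against the *suggested* format, not against the Tensor’s actual strides.

```python
from dataclasses import dataclass
from typing import Optional

PRESERVE = "preserve"  # stands in for MemoryFormat::Preserve


@dataclass
class FakeTensor:
    """Hypothetical stand-in that only records what the no-op check reads."""
    dtype: str
    device: str
    suggested_memory_format: str  # result of self.suggest_memory_format()


def to_is_noop(t: FakeTensor,
               dtype: Optional[str] = None,
               device: Optional[str] = None,
               memory_format: Optional[str] = None) -> bool:
    """Return True when .to(...) can return `t` itself instead of copying."""
    # An omitted argument means "keep whatever the Tensor already has".
    if dtype is not None and dtype != t.dtype:
        return False
    if device is not None and device != t.device:
        return False
    # Preserve (the default) never forces a copy; any other format is only
    # compared against the *suggested* format, not against the actual strides.
    fmt = PRESERVE if memory_format is None else memory_format
    return fmt == PRESERVE or fmt == t.suggested_memory_format


x = FakeTensor(dtype="float32", device="cuda:0", suggested_memory_format="contiguous")
print(to_is_noop(x, device="cuda:0"))             # True: already on that device
print(to_is_noop(x, dtype="float64"))             # False: dtype must change
print(to_is_noop(x, memory_format="contiguous"))  # True: matches the suggestion
```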
When something does need to change, it calls into at::_to_copy. This one has a single signature and is more straightforward. In particular:
- All arguments are optional and are kwarg-only
- It never returns a view of the input.
- The output is always a freshly created Tensor with the following properties:
  - If memory_format is MemoryFormat::Preserve, we preserve the input’s strides (if the input supports strides).
  - Otherwise, we create a new Tensor that is contiguous with respect to the provided memory_format or, when Preserve cannot be honored, with respect to the input’s suggested memory format (via self.suggest_memory_format()).
- Always uses output.copy_(input) to populate the newly created output.
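The stride choice above can be sketched in pure Python. The helper names are hypothetical, strides are counted in elements, and the input_supports_strides flag stands in for the "if the input supports strides" condition (in ATen this roughly corresponds to a non-overlapping, dense input on a device that supports strides). Channels-last strides are only modeled for 4D Tensors.

```python
from typing import List, Sequence


# Hypothetical helpers modeling how _to_copy picks output strides; not ATen functions.
def contiguous_strides(sizes: Sequence[int]) -> List[int]:
    """Row-major (MemoryFormat::Contiguous) strides, counted in elements."""
    strides = [1] * len(sizes)
    for i in range(len(sizes) - 2, -1, -1):
        strides[i] = strides[i + 1] * max(sizes[i + 1], 1)
    return strides


def channels_last_strides(sizes: Sequence[int]) -> List[int]:
    """NHWC strides for a 4D Tensor whose sizes are given in NCHW order."""
    _, c, h, w = (max(s, 1) for s in sizes)
    return [h * w * c, 1, w * c, c]


def to_copy_output_strides(sizes: Sequence[int], in_strides: Sequence[int],
                           memory_format: str, suggested: str,
                           input_supports_strides: bool = True) -> List[int]:
    """Strides of the fresh output Tensor allocated by this model of _to_copy."""
    if memory_format == "preserve":
        if input_supports_strides:
            return list(in_strides)  # keep the input's exact layout
        memory_format = suggested    # fall back to suggest_memory_format()
    if memory_format == "channels_last":
        return channels_last_strides(sizes)
    return contiguous_strides(sizes)


# A transposed 4x3 Tensor viewed as 3x4 (column-major in memory).
print(to_copy_output_strides([3, 4], [1, 3], "preserve", "contiguous"))    # [1, 3]
print(to_copy_output_strides([3, 4], [1, 3], "contiguous", "contiguous"))  # [4, 1]
```

Because Preserve copies the input strides verbatim, the following copy_ sees two Tensors with identical layouts.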
at::copy_
Second, let’s look at the at::copy_ function, ignoring the non_blocking argument as it is not relevant here.
There are a few special cases for meta, quantized, Vulkan and MPS devices, followed by the implementation that supports everything else.
Let’s look only at the CPU and CUDA cases, as they should be representative.
First, for CPU → CPU copies, there are special cases for conj/neg inputs and for copying a transposed Tensor into a normal Tensor (with all other properties being the same); every other case is deferred to TensorIterator with a trivial kernel that performs the copy. Note that a full memcpy of the whole buffer is never used.
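To make "a trivial kernel driven by TensorIterator" concrete, here is a pure-Python sketch of what such a kernel computes: for every logical index, it reads one element through the source strides and writes it through the destination strides. The name strided_copy is hypothetical, and the sketch deliberately ignores TensorIterator’s dimension coalescing, vectorization and parallelism.

```python
from itertools import product
from typing import MutableSequence, Sequence


# Hypothetical model of a trivial element-wise copy kernel over flat buffers.
def strided_copy(dst: MutableSequence, dst_strides: Sequence[int],
                 src: Sequence, src_strides: Sequence[int],
                 sizes: Sequence[int]) -> None:
    """Copy one element per logical index between two strided views."""
    for index in product(*(range(n) for n in sizes)):
        # Each view maps a logical index to a flat offset via its own strides.
        dst_off = sum(i * s for i, s in zip(index, dst_strides))
        src_off = sum(i * s for i, s in zip(index, src_strides))
        dst[dst_off] = src[src_off]


# src holds a contiguous 2x3 Tensor; read it through its transpose (3x2).
src = [0, 1, 2, 3, 4, 5]
dst = [None] * 6
strided_copy(dst, [2, 1], src, [1, 3], [3, 2])
print(dst)  # [0, 3, 1, 4, 2, 5]
```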
Second, for GPU:i → GPU:i or GPU:i → GPU:j (with p2p access) copies, we rely on TensorIterator to collapse dimensions and provide Tensor properties. If the two Tensors are perfectly aligned in memory (confusingly named is_contiguous() on the TensorIterator), we can use CUDA’s memcpy to perform the copy; otherwise, we use a trivial kernel with the TensorIterator.
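The "perfectly aligned" check can be approximated in pure Python as: drop size-1 dimensions, order the rest innermost-first, merge neighbors that are adjacent in memory for every operand, and accept a memcpy only if a single unit-stride dimension remains. This is a simplified model under stated assumptions: the helper names are hypothetical, dimensions are ordered by the destination’s strides only (TensorIterator’s real reordering also consults the inputs), and the dtype and conj/neg checks the real code performs before using memcpy are omitted.

```python
from math import prod
from typing import List, Sequence, Tuple


# Hypothetical, simplified model of TensorIterator-style dimension coalescing.
def coalesce(sizes: Sequence[int], strides_per_operand: Sequence[Sequence[int]]
             ) -> Tuple[List[int], List[List[int]]]:
    """Merge dimensions that are adjacent in memory for *every* operand."""
    # Size-1 dimensions never affect addressing, so drop them.
    dims = [d for d in range(len(sizes)) if sizes[d] != 1]
    # Simplification: order dims innermost-first using the first (output) operand.
    dims.sort(key=lambda d: strides_per_operand[0][d])
    shape: List[int] = []
    strides: List[List[int]] = [[] for _ in strides_per_operand]
    for d in dims:
        mergeable = bool(shape) and all(
            cur[-1] * shape[-1] == st[d]
            for cur, st in zip(strides, strides_per_operand))
        if mergeable:
            shape[-1] *= sizes[d]   # extend the current run; its inner stride stays
        else:
            shape.append(sizes[d])  # start a new coalesced dimension
            for cur, st in zip(strides, strides_per_operand):
                cur.append(st[d])
    return shape, strides


def can_use_memcpy(sizes, dst_strides, src_strides) -> bool:
    """True when both operands cover one dense run with unit stride."""
    if prod(sizes) <= 1:
        return True
    shape, (dst, src) = coalesce(sizes, [dst_strides, src_strides])
    return len(shape) == 1 and dst[0] == 1 and src[0] == 1


print(can_use_memcpy([3, 4], [1, 3], [1, 3]))  # True: both transposed the same way
print(can_use_memcpy([3, 4], [4, 1], [1, 3]))  # False: layouts differ
```

Note that the first case is a pair of non-contiguous Tensors that still qualifies, because both share the same layout in memory.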
Third, for CPU ↔ GPU or GPU:i → GPU:j (without p2p) copies, we need to use one of CUDA’s memcpy functions to perform the transfer and cannot use TensorIterator directly to handle a misaligned copy.
In this case, if the Tensors are not perfectly aligned (according to TensorIterator), we create contiguous copies of the non-contiguous Tensors and then use CUDA’s memcpy on these.
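A pure-Python sketch of that temporaries strategy, assuming each device is just a separate Python list and "memcpy" is a flat slice copy. All names are hypothetical. Unlike the real code, this model decides per Tensor with a contiguity check instead of asking TensorIterator whether the pair is already aligned, so it creates temporaries in some cases where the real code would not; dtype conversion and conj/neg handling are omitted.

```python
from itertools import product
from math import prod
from typing import List, MutableSequence, Sequence


def contiguous_strides(sizes: Sequence[int]) -> List[int]:
    strides = [1] * len(sizes)
    for i in range(len(sizes) - 2, -1, -1):
        strides[i] = strides[i + 1] * max(sizes[i + 1], 1)
    return strides


def is_contiguous(sizes: Sequence[int], strides: Sequence[int]) -> bool:
    # Usual contiguity rule (empty Tensors ignored): size-1 dims may have any stride.
    return all(n == 1 or s == c
               for n, s, c in zip(sizes, strides, contiguous_strides(sizes)))


def strided_copy(dst, dst_strides, src, src_strides, sizes) -> None:
    # Trivial element-wise kernel: one read and one write per logical index.
    for index in product(*(range(n) for n in sizes)):
        dst[sum(i * s for i, s in zip(index, dst_strides))] = \
            src[sum(i * s for i, s in zip(index, src_strides))]


def cross_device_copy(dst_buf: MutableSequence, dst_strides: Sequence[int],
                      src_buf: Sequence, src_strides: Sequence[int],
                      sizes: Sequence[int]) -> None:
    """Hypothetical model of the no-p2p path: the transfer is one flat memcpy."""
    n = prod(sizes)
    contig = contiguous_strides(sizes)
    # 1. On the source device: make a contiguous temporary if src is strided.
    if is_contiguous(sizes, src_strides):
        src_contig = src_buf
    else:
        src_contig = [None] * n
        strided_copy(src_contig, contig, src_buf, src_strides, sizes)
    # 2. On the destination device: write straight into dst when possible.
    dst_is_contig = is_contiguous(sizes, dst_strides)
    dst_contig = dst_buf if dst_is_contig else [None] * n
    # 3. The only cross-device step: a flat copy of n elements ("memcpy").
    dst_contig[:n] = src_contig[:n]
    # 4. If a temporary was used, scatter it into the strided destination.
    if not dst_is_contig:
        strided_copy(dst_buf, dst_strides, dst_contig, contig, sizes)


src = [0, 1, 2, 3, 4, 5]  # contiguous 2x3 on the "source device"
dst = [None] * 6          # contiguous 3x2 on the "destination device"
cross_device_copy(dst, [2, 1], src, [1, 3], [3, 2])  # copy src.t() into dst
print(dst)  # [0, 3, 1, 4, 2, 5]
```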
Conclusion
A few closing observations about at::to() and memory format:
- The default “MemoryFormat::Preserve” for .to() ensures that the copy kernel that follows always gets perfectly aligned inputs when moving across devices.
- A possible resolution for pytorch/pytorch issue #62027 (`x.to(memory_format=torch.contiguous_format)` does not always return a contiguous tensor) could be:
  - We add a new memory format, MemoryFormat::None, which is the memory format of a Tensor that is contiguous with respect to no other memory format.
  - Checking is_contiguous for the None memory format always returns True.
  - We add a method that returns the current memory format (always preferring any other memory format over None), with the invariant that t.is_contiguous(memory_format=t.memory_format) == True.
  - We update .to(memory_format=) to use this new method instead of suggest_memory_format() to check whether anything needs to be done. In particular, for the first example in the issue, the memory format will now be None, which won’t match the requested format, so a copy will be triggered.
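The proposal can be modeled in pure Python to check that it fixes the symptom from the issue title. Everything here is hypothetical: the "none" format, current_memory_format(), the priority order between contiguous and channels_last (the proposal only says real formats win over None), and suggest_memory_format(), which is a crude stand-in that is stricter than ATen’s real strides-like heuristic. The transposed 2D Tensor is an illustrative non-contiguous input, not necessarily the example from the issue.

```python
from typing import List, Sequence


def _dims_inner_to_outer(ndim: int, memory_format: str) -> List[int]:
    if memory_format == "channels_last":
        return [1, 3, 2, 0]  # C, W, H, N for a 4D NCHW-shaped Tensor
    return list(range(ndim - 1, -1, -1))


def is_contiguous(sizes: Sequence[int], strides: Sequence[int], memory_format: str) -> bool:
    if memory_format == "none":
        return True  # proposal: every Tensor is "contiguous" w.r.t. None
    if memory_format == "channels_last" and len(sizes) != 4:
        return False
    expected = 1
    for d in _dims_inner_to_outer(len(sizes), memory_format):
        if sizes[d] != 1:  # size-1 dims may have any stride
            if strides[d] != expected:
                return False
            expected *= sizes[d]
    return True


def current_memory_format(sizes, strides) -> str:
    """Hypothetical t.memory_format: first real format that matches, else 'none'."""
    for fmt in ("contiguous", "channels_last"):  # priority order is an assumption
        if is_contiguous(sizes, strides, fmt):
            return fmt
    return "none"


def suggest_memory_format(sizes, strides) -> str:
    """Crude stand-in for today's suggest_memory_format(): never returns 'none'."""
    if len(sizes) == 4 and is_contiguous(sizes, strides, "channels_last"):
        return "channels_last"
    return "contiguous"


def to_copies_today(sizes, strides, requested: str) -> bool:
    return requested != suggest_memory_format(sizes, strides)


def to_copies_proposed(sizes, strides, requested: str) -> bool:
    return requested != current_memory_format(sizes, strides)


# A transposed 4x3 Tensor viewed as 3x4: not contiguous in any real format.
print(current_memory_format([3, 4], [1, 3]))              # none
print(to_copies_today([3, 4], [1, 3], "contiguous"))      # False: result stays non-contiguous
print(to_copies_proposed([3, 4], [1, 3], "contiguous"))   # True: a copy is triggered
```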
A few closing observations about at::copy_:
- The updated CPU → MPS copy implementation from pytorch/pytorch PR #86956 ([MPS] Revamp copy_to_mps_ implementation, by malfet) now indeed matches the CUDA logic by using contiguous intermediaries. It differs, though, in that it does not use TensorIterator to detect whether the memcpy is possible. Since the TensorIterator has already been created by the time this custom MPS function is called from the CompositeExplicit at::copy_ function, we can update the MPS implementation to reuse the computations already performed by TensorIterator (note that this goes one step further, as it also properly detects non-contiguous Tensors that have the same layout in memory).
- The MPS → CPU and MPS → MPS implementations should be updated in the same way as CPU → MPS to fix similar edge cases.
- We could refactor some of the logic in Copy.cu to reduce code duplication between the CUDA and MPS implementations, as they have exactly the same properties.
- We should be able to remove the at::_copy_from MPS implementation since, MPS being a supported device, that function is never called from at::copy_. Maybe it is called directly from other places?