What is the difference between the two design goals?
__torch_function__ solves the problem of “I can overload the meaning of tensor.add() with Python duck typing, but I can’t overload the meaning of torch.add(tensor) because this isn’t a method call.” It lets you define a method that overloads the meaning of torch.* API calls. It works even if you don’t subclass Tensor.
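To make the "works even if you don’t subclass Tensor" point concrete, here is a minimal sketch. The `Wrapper` class and its `value` attribute are hypothetical names made up for illustration; the only PyTorch feature relied on is the `__torch_function__` protocol itself.

```python
import torch


class Wrapper:
    """Hypothetical container that is NOT a torch.Tensor subclass."""

    def __init__(self, value):
        self.value = value

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Unwrap any Wrapper arguments, call the real torch function on
        # plain tensors, then wrap the result again.
        unwrapped = [a.value if isinstance(a, Wrapper) else a for a in args]
        return Wrapper(func(*unwrapped, **kwargs))


w = Wrapper(torch.ones(3))
out = torch.add(w, w)  # routed through Wrapper.__torch_function__
print(type(out).__name__, out.value)
```

Because `torch.add` is a plain function rather than a method on `Wrapper`, ordinary duck typing could not intercept it; the protocol is what hands the call to the class.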
__torch_dispatch__ solves the problem of “PyTorch has a big pile of C++ code which implements important subsystems like autograd, and I can’t interpose on it.” It offers a callback into Python after these subsystems have been processed. A torch dispatch Tensor subclass (these must be Tensor subclasses) can, for example, change the behavior of operations called by the autograd engine.
You can also see more detailed (though not 100% finished yet) documentation at Extending PyTorch — PyTorch main documentation
What you said may be wrong. The following code works well:
import torch

class PyTensor(torch.Tensor):
    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # func is an aten-level op here
        print(f"PyTensor into {func}")

class MyTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs=None):
        # func is the torch.* Python API function here
        print(f"MyTensor into {func}")

x = torch.randn(8)
y1 = PyTensor(x)
y2 = MyTensor(x)
y1 = torch.add(y1, y1)
y2 = torch.add(y2, y2)
The output is:
PyTensor into aten.add.Tensor
MyTensor into <built-in method add of type object at 0x7f2f100ea7e0>
Great! I got the answer from the link you posted. Thank you! Copying it here:
In a similar way where __torch_function__ is able to interpose on all of torch’s Python API and Tensor methods, __torch_dispatch__ is able to intercept all calls into the aten native API.
You can refer to this website for this question.
https://pytorch.org/docs/main/notes/extending.html#extending-torch-native-api
While __torch_function__ allows one to effectively extend PyTorch’s pure Python components’ behavior, it does not allow one to extend the parts of PyTorch implemented in C++. To that end, a Tensor subclass can also define __torch_dispatch__ which will be able to override the behavior at the C++ level.
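For anyone who wants to try the C++-level hook, here is a rough sketch of a wrapper subclass that records every aten op it sees and then forwards it to the wrapped tensor. `LoggingTensor`, `elem`, and `seen` are hypothetical names for illustration. It leans on `torch.Tensor._make_wrapper_subclass` and `torch.utils._pytree.tree_map`, which are underscore-prefixed helpers, so treat this as an assumption-laden sketch rather than a stable recipe.

```python
import torch
from torch.utils._pytree import tree_map

seen = []  # names of the aten ops that reached __torch_dispatch__


class LoggingTensor(torch.Tensor):
    """Hypothetical wrapper subclass that logs aten ops."""

    @staticmethod
    def __new__(cls, elem):
        # Build a tensor "shell" with the same metadata as elem.
        return torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device
        )

    def __init__(self, elem):
        self.elem = elem  # the real tensor holding the data

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        seen.append(str(func))

        def unwrap(t):
            return t.elem if isinstance(t, LoggingTensor) else t

        def wrap(t):
            return LoggingTensor(t) if isinstance(t, torch.Tensor) else t

        # Run the op on the plain tensors, then re-wrap tensor outputs.
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        return tree_map(wrap, out)


x = LoggingTensor(torch.ones(3))
y = torch.add(x, x)
print(seen)
```

Unlike the `__torch_function__` case, what gets recorded is the aten overload rather than the `torch.add` Python function, which matches the difference in the output posted earlier in the thread.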