Idea and inspiration
Three years ago @albanD introduced the idea of using tensor subclasses to define a new
device, and published this illustrative example in his
subclass zoo: subclass_zoo/new_device.py at main · albanD/subclass_zoo · GitHub
The idea of new_device.py is simple and quite elegant:
- The backend has its own data representation as a Python object (`raw_data`, a numpy array in the example above).
- The backend tells PyTorch how each ATen op maps onto its own opset (numpy ops in the example above), in the form of a Python callable that calls its own ops (defined as Python functions).
- `__torch_dispatch__` then routes each ATen op to the corresponding implementation.
- A `__torch_function__` mode is used to capture tensor constructors.
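To make the pattern concrete, here is a condensed sketch of the idea (my own paraphrase, not a verbatim copy of new_device.py; the op table covers only two ops and hard-codes the dtype for brevity):

```python
import numpy as np
import torch

class NumpyTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, raw_data: np.ndarray):
        # device="meta" tells PyTorch not to touch the storage; the real data
        # lives in raw_data. dtype handling is hard-coded for brevity.
        t = torch.Tensor._make_wrapper_subclass(
            cls, raw_data.shape, dtype=torch.float32, device="meta")
        t.raw_data = raw_data
        return t

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Map the incoming ATen op to a numpy implementation, unwrap, run, rewrap.
        table = {torch.ops.aten.add.Tensor: np.add,
                 torch.ops.aten.mul.Tensor: np.multiply}
        raw = [a.raw_data if isinstance(a, NumpyTensor) else a for a in args]
        return NumpyTensor(table[func](*raw))

x = NumpyTensor(np.ones(3, dtype=np.float32))
print((x + x).raw_data)   # [2. 2. 2.]
```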
The existence of the Core ATen opset (IRs — PyTorch 2.6 documentation)
means that, theoretically, a backend only needs to implement those ~200 ops
instead of the 2000+ ATen or torch functions, since the other ATen ops decompose into
the core ATen set (such decomposition can happen in C++).
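One place this surfaces in the Python API is torch._decomp, which exposes a table of decompositions into core ATen (a quick look, assuming a recent PyTorch version):

```python
import torch
from torch._decomp import core_aten_decompositions

# Maps non-core ATen ops to implementations written in terms of core ATen ops;
# a backend only needs lowerings for whatever is left after these decompositions.
decomps = core_aten_decompositions()
print(f"{len(decomps)} ATen overloads have a registered core-ATen decomposition")
```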
Ideal backend registration API?
Let’s think from a backend provider’s perspective and ask the question: “What is the
minimum amount of information that a backend has to provide for PyTorch to run math
on that backend?”
The minimum set seems to be:
- A function that converts a CPU `torch.Tensor` to the backend’s object (`raw_data`). The backend’s object is opaque to PyTorch.
- A set of functions, one for every core ATen op, that take the backend’s `raw_data` and return `raw_data`, implementing the semantics of that core ATen op, i.e. lowerings from core ATen to the backend’s ops (a hypothetical registration sketch follows this list).
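As a purely hypothetical sketch of what that registration surface could look like (none of these functions exist in PyTorch today; `register_to_backend` and `register_lowering` are made-up names, with numpy standing in for the backend’s representation):

```python
import numpy as np
import torch

LOWERINGS = {}      # {aten OpOverload: lowering fn}
TO_BACKEND = None   # cpu torch.Tensor -> opaque raw_data

def register_to_backend(fn):
    """Item 1: how to turn a CPU torch.Tensor into the backend's raw_data."""
    global TO_BACKEND
    TO_BACKEND = fn
    return fn

def register_lowering(aten_op):
    """Item 2: a lowering (raw_data, ...) -> raw_data for one core ATen op."""
    def decorator(fn):
        LOWERINGS[aten_op] = fn
        return fn
    return decorator

@register_to_backend
def to_numpy(t: torch.Tensor) -> np.ndarray:
    return t.numpy()

@register_lowering(torch.ops.aten.mm.default)
def mm(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return a @ b
```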
With the above information, PyTorch the framework should be able to run an arbitrary model
on this backend, by doing the following:
- On every core ATen op, or any op with a registered definition, call the corresponding backend function.
- On new tensors, first construct on CPU and use the function provided in 1. to move them to device.
- Mutation / view ops etc. can be functionalized on the fly (see the functionalization sketch below).
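For the last point, torch.func.functionalize already performs this rewriting eagerly, so a backend only ever needs functional lowerings (a small self-contained demonstration):

```python
import torch
from torch.func import functionalize

def f(x):
    y = torch.ones(3)
    y.add_(x)            # in-place mutation
    return y.view(3, 1)  # view op

# functionalize replaces mutations with out-of-place equivalents on the fly,
# so a backend only needs to provide functional lowerings.
g = functionalize(f)
print(g(torch.tensor([1.0, 2.0, 3.0])))
```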
A backend can optionally do more to boost performance:
- Register all other tensor constructors (ones, zeros, randn etc.), so tensors can be constructed on device directly (see the constructor-capture sketch below).
- Register mutation and view ops.
- Override torch.* functions for which torch’s default decomposition is undesirable, etc.
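Constructor capture is typically done with a `__torch_function__` mode, roughly like this (a sketch that reuses the hypothetical NumpyTensor wrapper from the earlier sketch and handles only torch.ones):

```python
import numpy as np
import torch
from torch.overrides import TorchFunctionMode

class ConstructorCapture(TorchFunctionMode):
    """Intercept factory functions so tensors are created on the backend directly."""
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ones:
            # Build the backend's raw_data directly (NumpyTensor is the wrapper
            # subclass from the earlier sketch); no CPU tensor is ever allocated.
            return NumpyTensor(np.ones(args[0], dtype=np.float32))
        return func(*args, **kwargs)

with ConstructorCapture():
    t = torch.ones(4)
print(type(t).__name__, t.raw_data)   # NumpyTensor [1. 1. 1. 1.]
```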
Then, with the above, torch.compile with this backend should also work automatically,
by having PyTorch provide a `torch.fx.Interpreter` that calls the backend’s registered lowerings.
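Sketched out, it could look like the following (LOWERINGS is the hypothetical table from the registration sketch above; PyTorch does not ship such an interpreter for backends today, and in practice the graph would first need to be lowered to core ATen ops, e.g. via AOTAutograd, for the table to match):

```python
import torch
from torch.fx import GraphModule, Interpreter

class LoweringInterpreter(Interpreter):
    """Run an FX graph by calling the backend's registered lowerings."""
    def call_function(self, target, args, kwargs):
        if target in LOWERINGS:
            return LOWERINGS[target](*args, **kwargs)
        return super().call_function(target, args, kwargs)

def my_backend(gm: GraphModule, example_inputs):
    # torch.compile hands this callable an FX GraphModule plus example inputs;
    # we just interpret the graph, routing ops we have lowerings for to the backend.
    return lambda *args: LoweringInterpreter(gm).run(*args)

# compiled = torch.compile(model, backend=my_backend)
```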
Aside: What if I want to implement my backend lowerings in C++?
I believe that the C++ op registration API (Facilitating New Backend Integration by PrivateUse1 — PyTorch Tutorials 2.6.0+cu124 documentation)
still has its place and should be maintained alongside the newer API.
However, for the sake of argument, let’s look at the advantages of going the
Python route even when the lowerings are written in C++.
For now, assume that `raw_data` is a C++ object exposed to Python with pybind11,
and the lowerings are also exposed as Python functions with pybind11. Then you can use
the same Python-based API to register your backend. Why would one bother to do that?
For one, even though your backend still has torch as a dependency, it is now only a pip
(i.e. runtime) dependency and no longer a compile-time dependency.
This means:
- To build your backend, you don’t have to build torch from source
(to get the generated headers, if you happen to depend on those), which can take
quite a few minutes.
- How many times have you or your users seen this type of error message:
ImportError: ....so: undefined symbol: _ZN5torch8autograd12VariableInfoC1ERKN2at6TensorE
It usually happens when your torch version is not the one your backend was compiled against;
if your backend is pure Python, you will never see those errors again.
- As a consequence of 2, your library can now support multiple versions of torch at the same time
(even if torch’s Python API changes, you can have `if torch.__version__ > ...` checks in your
Python code, as in the small sketch below).
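For example, a trivial version gate (the (2, 4) threshold and the two branches are made up for illustration):

```python
import torch

# Hypothetical version gate; the (2, 4) threshold is purely illustrative.
TORCH_VERSION = tuple(int(p) for p in torch.__version__.split(".")[:2])

if TORCH_VERSION >= (2, 4):
    pass  # call the newer torch API here
else:
    pass  # fall back to the older torch API here
```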
Putting the theory into practice
A year ago I decided to put the above into practice and prototyped a backend to
run PyTorch on Google Cloud TPU by using `jax.Array` as the backend `raw_data`
(xla/torchax/torchax/tensor.py at master · pytorch/xla · GitHub), and by
writing lowerings of PyTorch ATen ops in JAX ops (xla/torchax/torchax/ops/jaten.py at master · pytorch/xla · GitHub).
See more at: xla/torchax/docs/how_it_works.md at master · pytorch/xla · GitHub
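The lowerings look roughly like this (a schematic example; the `lower` registration helper below is illustrative, not torchax’s actual decorator):

```python
import jax.numpy as jnp
import torch

ATEN_TO_JAX = {}  # {aten OpOverload: function written in terms of JAX ops}

def lower(aten_op):
    def decorator(fn):
        ATEN_TO_JAX[aten_op] = fn
        return fn
    return decorator

@lower(torch.ops.aten.add.Tensor)
def _aten_add(x, y, alpha=1):
    # x and y are jax.Array "raw_data"; the result is a jax.Array too.
    return x + y * alpha

@lower(torch.ops.aten.mm.default)
def _aten_mm(x, y):
    return jnp.matmul(x, y)
```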
The result is a library that can run PyTorch on TPU for some serious workloads, such
as training llama from torchtitan, and running inference workloads for llama and mixtral.
(Disclaimer: this backend is still in alpha stage; if you are a TPU user, please don’t
use it for anything serious yet. Although we do welcome experimentation and contributions!)
The experience of the above journey was overall pleasant, but it still has a few issues:
- torch.compile doesn’t work out of the box. I have described the issue in detail here:
[RFC] torch_xla2 dynamo integration · Issue #7255 · pytorch/xla · GitHub and discussed it with a few folks;
the issue is that torch.compile attempts to trace into `__torch_dispatch__`.
- Using the `meta` device as in subclass_zoo/new_device.py at main · albanD/subclass_zoo · GitHub
makes PyTorch think my device type is `meta`. This is usually fine, except when a
decomposition needs to create a new Tensor (such as aten::cov, which calls aten::full with device=meta),
in which case we end up constructing an actual meta Tensor (illustrated below). I attempted to pass in
privateuseone as the device for `_make_subclass` or `_make_wrapper_subclass` and got an error.
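To illustrate the second issue (a minimal repro of the mechanism, not of the subclass itself): anything a decomposition creates from the input’s device becomes a plain meta tensor:

```python
import torch

# The wrapper advertises device="meta", so when a decomposition creates a fresh
# tensor from the input's device (e.g. aten.full inside aten.cov's decomposition),
# we get a real meta tensor with no data instead of our wrapper subclass.
x = torch.randn(3, 4, device="meta")
print(torch.full((4,), 1.0, device=x.device))  # tensor(..., device='meta', size=(4,))
```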
The above 2 issues seem fixable by either:
- Defining my torch.Tensor subclass in C++. That C++ tensor would have one field of type
`PyObject*`, plus methods to set and get this object from Python, so that the operator definitions
and the device data representation (i.e. the `raw_data`) can still be in Python. However, that
removes the benefit of not depending on torch headers (compile time, ABI issues etc.).
- Doing whatever magic `FakeTensor` does: somehow `FakeTensor` is a meta Tensor that knows its device,
and it interacts well with torch.compile as well as decompositions.
My ask from the PyTorch community
Do people think it’s a good idea to standardize this approach and maybe introduce
a friendlier API for Python op registration (one that is meant to override ATen for a device,
not to introduce custom ops like [RFC] New Python operator registration API - #4 by gilfree)?
Any tips on how to accomplish 2? i.e. What exactly is FakeTensor’s magic?
If we define the custom C++ torch.Tensor subclass described in 1., can that class live
upstream so there’s no torch header dependency from a backend? This class
is generic enough for any backend to use. Maybe give it its own dispatch key so
we don’t need to overload privateuseone?
Any other suggestions?