Embrace tensor subclass as a Python device registration API

Idea and inspiration

3 years ago, @albanD introduced the idea of using a tensor subclass to define a new
device, and published this illustrative example in his
subclass zoo: subclass_zoo/new_device.py at main · albanD/subclass_zoo · GitHub

The idea of new_device.py is simple and quite elegant:

  • The backend has its own data representation as a Python object (raw_data, a numpy array in the example above).
  • The backend tells PyTorch how each ATen op maps onto its own opset (numpy ops in the example above), in the form of a Python callable that calls its own ops (defined as Python functions).
  • Then __torch_dispatch__ routes each ATen op to the corresponding implementation.
  • __torch_function__ modes are used to capture tensor constructors.
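A minimal sketch of that pattern, assuming a recent PyTorch and NumPy. NumpyTensor and LOWERINGS are illustrative names (not the code from new_device.py), only two ATen ops are lowered, and constructor capture is left out:

```python
import numpy as np
import torch
from torch.utils._pytree import tree_map

aten = torch.ops.aten


class NumpyTensor(torch.Tensor):
    """PyTorch sees a meta tensor with the right shape/dtype; the data lives in raw_data."""

    @staticmethod
    def __new__(cls, raw_data):
        torch_dtype = torch.from_numpy(np.empty(0, dtype=raw_data.dtype)).dtype
        return torch.Tensor._make_wrapper_subclass(
            cls, raw_data.shape, dtype=torch_dtype, device="meta"
        )

    def __init__(self, raw_data):
        self.raw_data = raw_data  # opaque backend object (a numpy array here)

    # Disable __torch_function__ so every call reaches __torch_dispatch__ as an ATen op.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        impl = LOWERINGS.get(func)
        if impl is None:
            raise NotImplementedError(f"no numpy lowering for {func}")
        unwrap = lambda t: t.raw_data if isinstance(t, NumpyTensor) else t
        out = impl(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        # Re-wrap numpy results so callers keep getting NumpyTensors.
        return tree_map(lambda r: NumpyTensor(r) if isinstance(r, np.ndarray) else r, out)


# ATen op -> numpy implementation (two ops only, for illustration).
LOWERINGS = {
    aten.add.Tensor: lambda a, b, alpha=1: a + alpha * b,
    aten.mul.Tensor: lambda a, b: a * b,
}

x = NumpyTensor(np.array([1.0, 2.0, 3.0]))
y = NumpyTensor(np.array([4.0, 5.0, 6.0]))
z = x * y + x  # dispatched as aten.mul.Tensor, then aten.add.Tensor
```

Note that z.device reports meta even though its data is a NumPy array; that is the device-type issue described later in this post.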

The existence of the Core ATen opset (IRs, PyTorch 2.6 documentation)
means that, in theory, a backend only needs to implement those ~200 ops
instead of the 2000+ ATen ops or torch functions, since the other ATen ops decompose into
the Core ATen set (and that decomposition can happen in C++).

Ideal backend registration API?

Let’s think from a backend provider’s perspective and ask the question: “What is the
minimum amount of information that a backend has to provide for PyTorch to run math
on that backend?”

The minimum set seems to be:

  1. A function that converts a CPU torch.Tensor to the backend’s object (raw_data).
    The backend’s object is opaque to PyTorch.

  2. A set of functions, one for every Core ATen op, each taking the backend’s raw_data and
    returning raw_data that implements the semantics of that Core ATen op, i.e. lowerings from
    Core ATen to the backend’s ops.
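One hypothetical shape for such a registration API, with NumPy standing in for the backend. BackendSpec and its lowering decorator are made-up names for illustration, not an existing PyTorch API:

```python
from dataclasses import dataclass, field
from typing import Any, Callable, Dict

import numpy as np
import torch

aten = torch.ops.aten


@dataclass
class BackendSpec:
    """Hypothetical: the minimum a backend hands to PyTorch."""

    # 1. CPU torch.Tensor -> opaque backend object (raw_data).
    to_backend: Callable[[torch.Tensor], Any]
    # 2. One lowering per Core ATen op: raw_data in, raw_data out.
    lowerings: Dict[Any, Callable[..., Any]] = field(default_factory=dict)

    def lowering(self, op):
        """Decorator that registers fn as the lowering for a Core ATen op."""
        def register(fn):
            self.lowerings[op] = fn
            return fn
        return register


numpy_backend = BackendSpec(to_backend=lambda t: t.detach().cpu().numpy())


@numpy_backend.lowering(aten.add.Tensor)
def _add(a, b, alpha=1):
    return a + alpha * b


@numpy_backend.lowering(aten.relu.default)
def _relu(a):
    return np.maximum(a, 0)
```

Everything else (dispatching, wrapping, constructors) would be PyTorch’s job in this model.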

With the above information, PyTorch the framework should be able to run arbitrary models
on this backend by doing the following:

  • On every Core ATen op, or any other op the backend provides a definition for, call the corresponding backend function.
  • For a new tensor, first construct it on CPU, then use the function from 1. to move it
    to the device.
  • Mutations, views, etc. can be functionalized on the fly.
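A sketch of the framework side of the last two bullets, assuming the same NumPy stand-in backend: tensors are built on CPU and moved with from_cpu, and an in-place op with no registered lowering is functionalized by running its out-of-place lowering and rebinding raw_data. DeviceTensor, LOWERINGS and out_of_place_variant are hypothetical names, and the helper relies on the private func._schema attribute:

```python
import numpy as np
import torch
from torch.utils._pytree import tree_map

aten = torch.ops.aten

# Hypothetical backend registration: only out-of-place lowerings are provided.
LOWERINGS = {
    aten.add.Tensor: lambda a, b, alpha=1: a + alpha * b,
    aten.mul.Tensor: lambda a, b: a * b,
}


def out_of_place_variant(func):
    """aten.add_.Tensor -> aten.add.Tensor; None if func is not an in-place op."""
    name = func._schema.name.split("::")[-1]
    if not name.endswith("_"):
        return None
    try:
        packet = getattr(torch.ops.aten, name[:-1])
        return getattr(packet, func._schema.overload_name or "default")
    except (AttributeError, RuntimeError):
        return None


class DeviceTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, raw_data):
        dtype = torch.from_numpy(np.empty(0, dtype=raw_data.dtype)).dtype
        return torch.Tensor._make_wrapper_subclass(cls, raw_data.shape, dtype=dtype, device="meta")

    def __init__(self, raw_data):
        self.raw_data = raw_data

    @classmethod
    def from_cpu(cls, t):
        # "Construct on CPU first, then move": the backend's to_backend function.
        return cls(t.detach().cpu().numpy())

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        unwrap = lambda t: t.raw_data if isinstance(t, DeviceTensor) else t
        raw_args, raw_kwargs = tree_map(unwrap, args), tree_map(unwrap, kwargs or {})
        if func in LOWERINGS:
            out = LOWERINGS[func](*raw_args, **raw_kwargs)
            return tree_map(lambda r: cls(r) if isinstance(r, np.ndarray) else r, out)
        functional = out_of_place_variant(func)
        if functional in LOWERINGS:
            # Functionalize on the fly: compute out of place, then rebind self's data.
            # (Ignores aliasing: other views of self would not see the update.)
            args[0].raw_data = LOWERINGS[functional](*raw_args, **raw_kwargs)
            return args[0]
        raise NotImplementedError(f"no lowering for {func}")
```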

A backend can optionally do more to boost performance:

  • register all other tensor constructors (ones, zeros, randn, etc.), so tensors can be
    constructed on the device directly.
  • register mutation and view ops.
  • override the torch.* functions whose default torch decomposition is undesirable, etc.
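For the first bullet, a torch.overrides.TorchFunctionMode can intercept factory functions such as torch.ones before anything is allocated on CPU. A sketch under stated assumptions: NumpyTensor, CONSTRUCTORS and OnDeviceConstructors are hypothetical names, only ones/zeros are handled, and dtype/device keyword arguments are ignored (everything is float32):

```python
import numpy as np
import torch
from torch.overrides import TorchFunctionMode


class NumpyTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, raw_data):
        dtype = torch.from_numpy(np.empty(0, dtype=raw_data.dtype)).dtype
        return torch.Tensor._make_wrapper_subclass(cls, raw_data.shape, dtype=dtype, device="meta")

    def __init__(self, raw_data):
        self.raw_data = raw_data

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError(f"{func} is not lowered in this sketch")


def _shape(size):
    # torch.ones accepts both ones(2, 3) and ones((2, 3)).
    if len(size) == 1 and isinstance(size[0], (tuple, list, torch.Size)):
        return tuple(size[0])
    return tuple(size)


# Built directly by the backend: no CPU tensor, no copy. kwargs are ignored here.
CONSTRUCTORS = {
    torch.ones: lambda *size, **kw: np.ones(_shape(size), dtype=np.float32),
    torch.zeros: lambda *size, **kw: np.zeros(_shape(size), dtype=np.float32),
}


class OnDeviceConstructors(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in CONSTRUCTORS:
            return NumpyTensor(CONSTRUCTORS[func](*args, **kwargs))
        return func(*args, **kwargs)  # everything else passes through unchanged
```

Usage: inside `with OnDeviceConstructors():`, torch.ones(2, 3) returns a NumpyTensor whose raw_data is np.ones((2, 3)).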

Then, with the above, torch.compile with this backend should also work automatically,
by having PyTorch provide a torch.fx.Interpreter that calls the backend’s registered lowerings.
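A rough sketch of that interpreter idea. A real torch.compile backend would receive its graph from Dynamo; to stay self-contained, this sketch gets an ATen-level graph from make_fx (in the semi-private torch.fx.experimental.proxy_tensor module) and replays it on NumPy arrays with a torch.fx.Interpreter. LOWERINGS and LoweringInterpreter are hypothetical names:

```python
import numpy as np
import torch
from torch.fx import Interpreter
from torch.fx.experimental.proxy_tensor import make_fx

aten = torch.ops.aten

# Hypothetical backend lowerings: ATen op -> numpy implementation.
LOWERINGS = {
    aten.add.Tensor: lambda a, b, alpha=1: a + alpha * b,
    aten.mul.Tensor: lambda a, b: a * b,
    aten.relu.default: lambda a: np.maximum(a, 0),
}


class LoweringInterpreter(Interpreter):
    """Runs an ATen-level FX graph by calling the backend's lowerings for each node."""

    def call_function(self, target, args, kwargs):
        if target not in LOWERINGS:
            raise NotImplementedError(f"no lowering for {target}")
        return LOWERINGS[target](*args, **kwargs)


def model(x, y):
    return torch.relu(x * y + x)


# Trace once with example CPU tensors to get the ATen graph (mul, add, relu).
graph_module = make_fx(model)(torch.randn(3), torch.randn(3))
```

Replaying the graph then needs only backend data: `LoweringInterpreter(graph_module).run(a, b)` with NumPy arrays a and b.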

Aside: What if I want to implement my backend lowerings in C++

I believe that the C++ op registration API (Facilitating New Backend Integration by PrivateUse1, PyTorch Tutorials 2.6.0+cu124 documentation)
still has its place and should be maintained alongside the newer API.
However, for the sake of argument, let’s look at the advantages of going the
Python route even if the lowerings are written in C++.
For now, assume that raw_data is a C++ object exposed to Python with pybind11,
and that the lowerings are also exposed as Python functions with pybind11. Then you can use
the same Python-based API to register your backend. Why would one bother to do that?

For one, even though your backend still has torch as a dependency, this dependency
is now only a pip (i.e. runtime) dependency and no longer a compile-time dependency.

This means:

  1. To build your backend, you don’t have to build torch from source
    (to get the generated headers, if you happen to depend on those), which can take
    quite a few minutes.

  2. How many times have you or your users seen error messages like this one?
    ImportError: ....so: undefined symbol: _ZN5torch8autograd12VariableInfoC1ERKN2at6TensorE
    This usually happens when your torch version is not the one your backend was compiled against.
    If your backend is pure Python, you will never see those errors again.

  3. As a consequence of 2, your library can now support multiple versions of torch at the same time
    (even if torch’s Python API changes, you can write if torch.__version__ > .... in your Python code).

Putting the theory into practice

1 year ago I decided to put the above into practice and prototyped a backend that
runs PyTorch on Google Cloud TPU, using jax.Array as the backend raw_data (xla/torchax/torchax/tensor.py at master · pytorch/xla · GitHub) and
writing lowerings of PyTorch ATen ops in JAX ops (xla/torchax/torchax/ops/jaten.py at master · pytorch/xla · GitHub).
See more at: xla/torchax/docs/how_it_works.md at master · pytorch/xla · GitHub

The result is a library that can run PyTorch on TPU for some serious workloads, such
as training Llama from torchtitan and
running inference workloads for Llama and Mixtral.

(Disclaimer: this backend is still in alpha. If you are a TPU user, please don’t
use it for anything serious yet, although we do welcome experimentation and contributions!)

The experience of the above journey was pleasant overall, but there are still a few issues:

  1. torch.compile doesn’t work out of the box. I have described the issue in detail here:
    [RFC] torch_xla2 dynamo integration · Issue #7255 · pytorch/xla · GitHub and discussed it with a few folks.
    The issue is that torch.compile attempts to trace into __torch_dispatch__.

  2. Using the meta device as in subclass_zoo/new_device.py at main · albanD/subclass_zoo · GitHub
    makes PyTorch think my device type is meta. This is usually fine, except when a
    decomposition needs to create a new Tensor (such as aten::cov, which calls aten::full with device=meta),
    in which case we end up constructing an actual meta Tensor. I tried passing privateuseone
    as the device to _make_subclass or _make_wrapper_subclass and got an error.
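A small reproduction of the problem, assuming a new_device.py-style wrapper (NumpyTensor is a hypothetical name). The torch.full call stands in for what a decomposition like aten::cov does when it allocates a tensor on the input’s device:

```python
import numpy as np
import torch


class NumpyTensor(torch.Tensor):
    """Minimal new_device.py-style wrapper: data in numpy, reported device is meta."""

    @staticmethod
    def __new__(cls, raw_data):
        dtype = torch.from_numpy(np.empty(0, dtype=raw_data.dtype)).dtype
        return torch.Tensor._make_wrapper_subclass(cls, raw_data.shape, dtype=dtype, device="meta")

    def __init__(self, raw_data):
        self.raw_data = raw_data

    __torch_function__ = torch._C._disabled_torch_function_impl

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        raise NotImplementedError(f"{func} is not lowered in this sketch")


x = NumpyTensor(np.ones(3))
# Allocating "on the input's device" yields a plain meta tensor with no data at all,
# not a NumpyTensor, so the backend never sees this allocation.
filler = torch.full((3,), 2.0, device=x.device)
```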

These 2 issues seem fixable by either:

  1. Defining my torch.Tensor subclass in C++, with one field of type PyObject*
    and methods to set and get this object from Python, so that the operator definitions
    and the device data representation (i.e. the raw_data) can still live in Python. However, that
    removes the benefit of not depending on torch headers (e.g. the compile-time and ABI issues above).

  2. Doing whatever magic FakeTensor does: somehow FakeTensor is a meta Tensor that knows its device,
    and it interacts well with torch.compile as well as with decompositions.
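For comparison, FakeTensor’s observable behavior, using the private torch._subclasses.fake_tensor module (details may change between releases). If I read fake_tensor.py correctly, FakeTensor wraps a meta tensor via torch.Tensor._make_subclass with dispatch_device=True and answers device queries from its fake_device attribute; the sketch below only checks the visible effect, including that a tensor created on fake.device inside the mode is again a FakeTensor rather than a bare meta tensor:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

mode = FakeTensorMode()
real = torch.randn(4, 4)
fake = mode.from_tensor(real)  # no real storage, but reports the original device

with mode:
    out = fake @ fake  # shape/dtype/device propagate without computing anything
    # Unlike the meta-device wrapper, allocating "on the input's device" stays fake.
    filler = torch.full((2,), 1.0, device=fake.device)
```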

My ask from the PyTorch community

Do people think it’s a good idea to standardize this approach, and maybe introduce a
friendlier API for Python op registration (one meant to override ATen for a device,
not to introduce custom ops like [RFC] New Python operator registration API - #4 by gilfree)?

Any tips on how to accomplish 2? I.e., what exactly is FakeTensor’s magic?

If we define the custom C++ torch.Tensor subclass described in 1., can that class live
upstream so that a backend has no torch header dependency? This class
is generic enough for any backend to use. Maybe give it its own dispatch key so
we don’t need to overload privateuseone?

Any other suggestions?

cc @ezyang @albanD @zou3519
Thanks!
Han from PyTorch/XLA.

Hey!

Thanks for taking the time to write down this proposal!

> The experience of the above journey was pleasant overall, but there are still a few issues:

  1. This is roughly expected as of today, in particular because we have been focusing torch.compile/subclass support on subclasses that are “wrappers around other Tensors that eventually desugar into ops on plain Tensors”.
    I think you can make your design fit this, but it might be awkward. In particular, you would store all the data in another Tensor that is just a holder, and translate everything into your own custom ops (as in PyTorch Custom Operators, PyTorch Tutorials 2.6.0+cu124 documentation).
    Then compile will “desugar” into these ops and run them as black boxes.

  2. IIRC the FakeTensor trick is pretty simple: pytorch/torch/_subclasses/fake_tensor.py at main · pytorch/pytorch · GitHub
    I would agree with Ed on the issue that it is better if the device accurately reflects the device you want to use. It is also a bit tricky and, while FakeTensor helped us clean up a lot of things, there are most likely a few rough edges left.

While we’re working with @janeyx99 to provide ABI stability for a subset of libtorch, we are focusing on custom kernel writers right now, not out-of-tree devices.
This work will help if you go down the path of a Python-only Tensor object plus custom ops in C++.
But if you need to use the PrivateUse1 extension points, those will not be covered by the current plan.

> If we define the custom C++ torch.Tensor subclass described in 1., can that class live
> upstream so that a backend has no torch header dependency? This class
> is generic enough for any backend to use. Maybe give it its own dispatch key so
> we don’t need to overload privateuseone?

I’m not sure I understand what you mean here, and I have a couple of questions:

  • What blocks you today from doing all of this with the subclass?
  • The second concern for most backend writers once they have something that works is performance. I am not sure what the actual performance characteristics of this approach will be, and whether the overhead will be acceptable.

> What blocks you today from doing all of this with the subclass?

Nothing, really. However, with any C++ component we lose the advantage of not having to deal with a C++ build system or ABIs.

So if such a subclass proves to be effective and generic enough, it would be great if it could be upstreamed.

> The second concern for most backend writers once they have something that works is performance. I am not sure what the actual performance characteristics of this approach will be, and whether the overhead will be acceptable.

For us it’s not an issue, because we can apply jax.jit to it. What happens then is that the raw_data object inside my tensor subclass becomes the Tracer object that jax.jit uses. The end result is a pure StableHLO graph that can be compiled and executed by XLA.
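A sketch of why jax.jit gives a graph for free, with the torch subclass left out and a hypothetical string-keyed lowering table in the spirit of torchax’s jaten.py: the same lowerings run eagerly on jax.Array values and, under jax.jit, on tracers, and jitted.lower(...).as_text() shows the resulting StableHLO module:

```python
import jax
import jax.numpy as jnp

# Hypothetical lowering table (string keys for brevity; torchax keys on real ATen ops).
LOWERINGS = {
    "aten::mul": lambda a, b: a * b,
    "aten::add": lambda a, b, alpha=1: a + alpha * b,
    "aten::relu": lambda a: jnp.maximum(a, 0),
}


def run_model(x, y):
    # Roughly what the subclass does for relu(x * y + x): each ATen op becomes a
    # lowering call on raw_data. Under jax.jit, x and y are tracers, so these calls
    # are recorded into a graph instead of executed eagerly.
    h = LOWERINGS["aten::mul"](x, y)
    h = LOWERINGS["aten::add"](h, x)
    return LOWERINGS["aten::relu"](h)


jitted = jax.jit(run_model)
example = (jnp.array([1.0, -2.0, 3.0]), jnp.array([2.0, 2.0, -2.0]))
stablehlo_text = jitted.lower(*example).as_text()  # the traced graph as MLIR text
```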

I suspect this works with many similar systems, such as Apple’s MLX.

For backends that do not bundle a graph mechanism with their Python API, using torch.compile is probably needed for a similar experience.

If a backend is (or wants to be) eager-only, then the way to reduce the overhead of decompositions is probably to use a __torch_function__ mode to capture higher-level concepts and implement them directly.
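A sketch of that approach: a torch.overrides.TorchFunctionMode that catches torch.nn.functional.layer_norm before it is decomposed into ATen ops and routes it to a single backend implementation. fused_layer_norm and CaptureHighLevelOps are hypothetical names, and the “fused kernel” is written with plain torch ops just to keep the example runnable:

```python
import torch
import torch.nn.functional as F
from torch.overrides import TorchFunctionMode


def fused_layer_norm(x, normalized_shape, weight=None, bias=None, eps=1e-5):
    # Stand-in for one fused backend kernel (written with torch ops here).
    dims = tuple(range(-len(normalized_shape), 0))
    mean = x.mean(dim=dims, keepdim=True)
    var = x.var(dim=dims, unbiased=False, keepdim=True)
    out = (x - mean) / torch.sqrt(var + eps)
    if weight is not None:
        out = out * weight
    if bias is not None:
        out = out + bias
    return out


class CaptureHighLevelOps(TorchFunctionMode):
    """Intercepts selected torch.* / F.* calls before they reach ATen."""

    def __init__(self):
        super().__init__()
        self.captured = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.layer_norm:
            self.captured.append(func.__name__)
            # The mode is popped while this handler runs, so the torch ops
            # inside fused_layer_norm are not intercepted again.
            return fused_layer_norm(*args, **kwargs)
        return func(*args, **kwargs)
```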