I have a question about the set of ops that a backend must register with PyTorch 2.0. As per the IRs — PyTorch master documentation, PyTorch 2.0 offers two sets of IRs for backends to interface with: Core Aten IR and Prims IR.
The Core Aten IR is fully functional and has no in-place or _out variants. However, does PyTorch 2.0 decompose torch ops into Core Aten IR ops only when a Python frame is compiled via torch.compile? If execution is not via torch.compile, or if torch.compile falls back to eager execution, is there a way for backends to still receive only Core Aten IR ops? If this isn’t possible, should a backend register ops outside of the Core Aten IR, including in-place and _out variants, to support eager execution?
@SherlockNoMad can you comment here?
At the moment, PT2 only runs decompositions in the compilation path. I don’t have a recommended way to run decompositions in eager mode. If you think this would be a useful feature, please raise a feature request via a GitHub issue.
In the torch.compile path, computation falls back to eager only when there is a Dynamo graph break.
Thanks for the clarification. My understanding is that it is not enough for a backend to enable only Core Aten IR / Prims IR ops, as eager-mode execution is still a valid user option, and even torch.compile can fall back to eager on graph breaks, as you mentioned.
Given this, the set of ops that needs to be implemented on a backend with 2.0 is still the entire aten op set. Is this a valid assumption?
Plus the prims set, if you would like your user to torch.compile at some point.
It’s not that hard to run decompositions in “eager mode”, so if you support core Aten IR/Prim IR it would be pretty easy to make it run in eager mode (which is essentially just a graph with a single element).
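To make that concrete, here is a minimal sketch of the idea (the function name single_op is just for illustration): trace one composite op — the “graph with a single element” — through the stock Core Aten decomposition table with make_fx, and the resulting graph contains only the decomposed ops.

```python
import torch
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def single_op(x):
    # A "graph with a single element": one composite op.
    return torch.nn.functional.silu(x)

# Trace the op through the Core Aten decomposition table.
gm = make_fx(single_op, decomposition_table=core_aten_decompositions())(torch.randn(4))

# gm's graph now holds the decomposed aten ops; print them to inspect.
print(gm.code)

x = torch.randn(4)
assert torch.allclose(gm(x), single_op(x))
```

The traced module gm behaves like the original op, so a backend that only implements the decomposed ops can execute it directly.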
You only need to support whatever prims/aten operators that make up operators you’re decomposing. For example, there are many prims that Inductor doesn’t support.
Will the decomposition be done in the framework before the ops are dispatched to backend, or should the backend handle it, maybe via a torch dispatch based decomposition?
There is also a potential performance impact: in torch.compile the decomposition happens only once, at graph compile time, but in the eager flow it would happen on every op execution.
should the backend handle it, maybe via a torch dispatch based decomposition?
Yeah that’s one reasonable option. Another option is to register a per-op kernel that’s precompiled based off of the decomposition. This could also resolve the issue you mentioned with “decomposition happening during every op execution”.
See Allow users to override kernels for existing C++ ops through Python by anjali411 · Pull Request #75905 · pytorch/pytorch · GitHub
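A sketch of what the second option could look like, using the Python override mechanism from that PR. Here aten::silu’s CPU kernel is replaced with a function derived from its decomposition; in a real backend this would be a precompiled or fused kernel rather than plain Python, and silu is just a stand-in example.

```python
import torch

# Open the aten library for kernel (re)registration from Python.
lib = torch.library.Library("aten", "IMPL")

def silu_from_decomposition(x):
    # Stand-in for a kernel precompiled from the decomposition:
    # silu(x) = x * sigmoid(x), so only sigmoid and mul are needed.
    return x * torch.sigmoid(x)

# Route aten::silu on CPU through the decomposition-derived kernel.
lib.impl("silu", silu_from_decomposition, "CPU")

x = torch.randn(4)
out = torch.nn.functional.silu(x)  # now served by the override
```

Because the decomposition is baked into the registered kernel ahead of time, it does not have to be re-derived on every op execution.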
That could work only because the existing CUDA backend supports the entire ATen op set in eager mode. It still seems to me that an alternative backend needs to support the entire op set for both ATen and Prims (neither of which is stabilized), and the number of ops a backend has to register strictly goes up from PT 1.x to 2.0.
Indeed, we see XLA and MPS maintaining their own decompositions. If backends still need to support the entire ATen op set anyway, I don’t see any incentive for them to decompose to Prims first.
Here’s an example of running decomposition in eager mode
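One way to do this is with a TorchDispatchMode that intercepts every aten op, runs its entry from torch._decomp’s global decomposition_table when one is registered, and otherwise falls through to the regular kernel. A minimal sketch, not production code:

```python
import torch
from torch._decomp import decomposition_table
from torch.utils._python_dispatch import TorchDispatchMode

class DecomposeMode(TorchDispatchMode):
    """Run each aten op's registered decomposition in eager mode."""
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func in decomposition_table:
            # The mode is popped while this handler runs, so the ops
            # inside the decomposition dispatch to regular kernels.
            return decomposition_table[func](*args, **kwargs)
        return func(*args, **kwargs)

with DecomposeMode():
    y = torch.nn.functional.silu(torch.randn(4))
```

As noted above, this pays the decomposition cost on every op call, so it is best suited for prototyping a backend against the decomposed op set.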