What I want to accomplish
Given a non-core ATen operator P, decompose it into core ATen ops using a decomposition table. If the decomposition in the table doesn't consist only of core ATen ops, keep going (recursively) until only core ATen ops are left.
Why I want to do that
I am implementing a PyTorch backend in Python through tensor subclasses and torch dispatch modes (details described here: Embrace tensor subclass as a Python device registration API). To avoid having to implement all 2000+ ops in PyTorch, I rely heavily on the decompositions provided by PyTorch in torch._decomp.core_aten_decompositions().
Current approach
I currently achieve this by recursively applying the dispatch mode, basically:
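def __torch_dispatch__(self, func, types, args=(), kwargs=None):
    decomp = lookup_decomposition(func)
    # Re-enable this mode so that ops called by the decomposition are intercepted too.
    with self:
        return decomp(*args, **(kwargs or {}))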
The idea is that if decomp is a callable equivalent to func but implemented in terms of other ops, then those ops come back to my dispatch mode (re-enabled by with self:). If any of them have further decompositions, the process keeps going until everything is core ATen.
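To make the intended behavior concrete, here is a toy, pure-Python model of that recursion. It does not use torch at all: the op names (P, Q, core_a, ...), CORE_OPS, DECOMP_TABLE, and expand_to_core are all made up for illustration, and the table simply maps an op to the list of ops its decomposition would call.

```python
from typing import Dict, List, Set

# Hypothetical set of "core" ops that the backend implements directly.
CORE_OPS: Set[str] = {"core_a", "core_b", "core_c"}

# Hypothetical decomposition table: op name -> ops called by its decomposition.
DECOMP_TABLE: Dict[str, List[str]] = {
    "P": ["Q", "core_a"],       # P decomposes into a non-core op Q plus a core op
    "Q": ["core_b", "core_c"],  # Q decomposes into core ops only
}


def expand_to_core(op: str) -> List[str]:
    """Recursively expand `op` until only core ops remain."""
    if op in CORE_OPS:
        return [op]
    result: List[str] = []
    for sub_op in DECOMP_TABLE[op]:
        # Each sub-op "comes back" to the same routine, like re-entering the mode.
        result.extend(expand_to_core(sub_op))
    return result


print(expand_to_core("P"))  # ['core_b', 'core_c', 'core_a']
```

In the real dispatch mode the recursion is implicit: the decomposition calls other ops, and with self: routes them back into __torch_dispatch__.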
This strategy worked well enough until torch 2.6.
In torch 2.6, the content of core_aten_decompositions() changed to include the _special_op_to_decompose_cia closure defined here: pytorch/torch/_export/utils.py at 60a45eb862d5e8b4ba2dd435d34ef04ae231e885 · pytorch/pytorch · GitHub
At the call kernel._op_dk(...), the same op is captured by my dispatch mode again, which causes an infinite loop.
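The failure mode can be reproduced in a toy, pure-Python model (again, no torch involved). Here the hypothetical table entry for "cia_op" re-dispatches "cia_op" itself, standing in for a decomposition that calls back into the same op. The names cia_op, TABLE, CORE, and expand_with_cycle_check are assumptions made for this sketch; the in-progress set only makes the loop visible as an error instead of unbounded recursion, and is not a recommended fix.

```python
from typing import Dict, List, Optional, Set

# Hypothetical core op set and table; "cia_op" re-dispatches itself.
CORE: Set[str] = {"core_a"}
TABLE: Dict[str, List[str]] = {
    "cia_op": ["cia_op"],
    "fine_op": ["core_a", "core_a"],
}


def expand_with_cycle_check(op: str, in_progress: Optional[Set[str]] = None) -> List[str]:
    """Expand `op` to core ops, raising if an op re-dispatches itself."""
    in_progress = set() if in_progress is None else in_progress
    if op in CORE:
        return [op]
    if op in in_progress:
        # Without this check, the recursion below would never terminate.
        raise RuntimeError(f"decomposition of {op!r} dispatched {op!r} again")
    in_progress.add(op)
    try:
        result: List[str] = []
        for sub_op in TABLE[op]:
            result.extend(expand_with_cycle_check(sub_op, in_progress))
        return result
    finally:
        in_progress.discard(op)


print(expand_with_cycle_check("fine_op"))  # ['core_a', 'core_a']
try:
    expand_with_cycle_check("cia_op")
except RuntimeError as err:
    print("loop detected:", err)
```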
I understand that it is desirable to also use the decompositions defined in C++ for CIA (CompositeImplicitAutograd) ops, and I do want to take advantage of these decompositions.
What is the recommended way to accomplish this decomposition in the presence of the new CIA decompositions?
Thanks in advance.
cc. @gmagogsfm @angelayi