According to the PyTorch 2.0 docs, ATen ops and Prims ops appear to be independent IRs. But the FX graph produced by aot_autograd looks like a mix of both ATen and Prims ops.
So how can one get pure ATen ops (or pure Prims ops) after aot_autograd? Or am I misunderstanding the design of PrimTorch?
Incorrect.
In my understanding, Prims ops only show up for ops that have “hardened decompositions” in the torch._refs module (as of Feb. 2023). For example, if you call torch._refs.abs(x) and capture its FX graph, you may get pure Prims IR such as abs_1 = torch.ops.prims.abs.default(add_2).
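A minimal sketch of that experiment, assuming make_fx from torch.fx.experimental.proxy_tensor as the tracer (the post does not say which tracer was used). It traces torch._refs.abs on a float tensor and collects the namespace of every op in the resulting graph; for a ref built on prims you would expect prims ops rather than aten ops.

```python
import torch
import torch._refs
from torch.fx.experimental.proxy_tensor import make_fx


def ref_abs(x):
    # torch._refs.abs is written in terms of prims, so tracing it should
    # record torch.ops.prims.abs instead of torch.ops.aten.abs
    return torch._refs.abs(x)


gm = make_fx(ref_abs)(torch.randn(4))

# Collect the namespace ("aten", "prims", ...) of every op call in the graph;
# str() of an OpOverload looks like "prims.abs.default"
namespaces = {
    str(node.target).split(".")[0]
    for node in gm.graph.nodes
    if node.op == "call_function"
}
print(gm.graph)
print(namespaces)
```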
The answer to this question is a little complicated because we have not spent enough effort on this part of the UX, but the short answer is that you can pick which decomps you use, so in principle you can get only ATen or only Prims ops in your graph by choosing your decomps carefully. How to choose? Well, uh, we should have an API for this, but we don't. We could probably figure this out automatically by analyzing each decomp and checking whether it calls a prim. This would be a pretty useful PR if you want to submit it.
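The "analyze each decomp" idea could be sketched roughly as below. calls_prim, silu_via_aten and the candidates table are hypothetical names for illustration, not PyTorch APIs. The sketch traces each candidate decomposition with make_fx on an example input and keeps only those whose traced body contains no prims.* op; a real tool would need suitable example inputs for every op in a table such as core_aten_decompositions().

```python
import torch
import torch._refs
from torch.fx.experimental.proxy_tensor import make_fx


def calls_prim(fn, *example_args):
    """Hypothetical helper: trace fn on example inputs and report whether
    any recorded call targets an op in the prims namespace."""
    gm = make_fx(fn)(*example_args)
    return any(
        str(node.target).startswith("prims.")
        for node in gm.graph.nodes
        if node.op == "call_function"
    )


def silu_via_aten(x):
    # Hypothetical hand-written decomposition using only torch-level (ATen) ops
    return x * torch.sigmoid(x)


x = torch.randn(8)

# Candidate decompositions keyed by the ATen op they would replace
candidates = {
    torch.ops.aten.silu.default: silu_via_aten,
    torch.ops.aten.abs.default: torch._refs.abs,  # a prims-based ref
}

# Keep only the decompositions whose traced body stays in ATen
aten_only_table = {
    op: fn for op, fn in candidates.items() if not calls_prim(fn, x)
}
print(list(aten_only_table))
```

The filtered dict has the same shape as the tables in torch._decomp, so it could be passed wherever a decomposition table is accepted.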
from torch._decomp import core_aten_decompositions
core_aten_decompositions() should give you an ATen-only opset.
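A sketch of plugging that table into AOTAutograd, assuming functorch.compile.aot_function and its decompositions argument. capture_compiler is a hypothetical pass-through compiler that only records the forward graph so its ops can be inspected. If the table behaves as described, every op in the captured graph should come from the aten namespace.

```python
import torch
import torch.nn.functional as F
from functorch.compile import aot_function
from torch._decomp import core_aten_decompositions

captured = []


def capture_compiler(gm, example_inputs):
    # Hypothetical pass-through "compiler": keep the FX graph for inspection
    # and return it unchanged so it still runs
    captured.append(gm)
    return gm


def f(x):
    return F.silu(x) + x.abs()


compiled_f = aot_function(
    f,
    fw_compiler=capture_compiler,
    decompositions=core_aten_decompositions(),
)

x = torch.randn(4, requires_grad=True)
out = compiled_f(x)

fw_graph = captured[0]  # the forward graph is compiled on the first call
namespaces = {
    str(node.target).split(".")[0]
    for node in fw_graph.graph.nodes
    if node.op == "call_function"
}
print(fw_graph.code)
print(namespaces)
```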
In fact, I tried core_aten_decompositions(), but still got ops outside the Core ATen IR.
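One way to list exactly which ops fall outside the core set, assuming your PyTorch build tags core ATen operators with torch.Tag.core and exposes them through OpOverload.tags (recent releases do; this check may not exist in the early-2023 version discussed here). non_core_ops is a hypothetical helper, not a PyTorch API.

```python
import torch
import torch.nn.functional as F
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import make_fx


def non_core_ops(gm):
    """Hypothetical helper: return the ops in gm that lack the core ATen tag.

    Assumes OpOverload.tags and torch.Tag.core are available."""
    found = set()
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        target = node.target
        if not isinstance(target, torch._ops.OpOverload):
            continue  # skip Python callables such as operator.getitem
        if torch.Tag.core not in target.tags:
            found.add(str(target))  # e.g. "prims.abs.default"
    return sorted(found)


def f(x):
    return F.silu(x) + x.abs()


gm = make_fx(f, decomposition_table=core_aten_decompositions())(torch.randn(4))
print(non_core_ops(gm))
```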
Could you please report an issue in the PyTorch repo?
Core ATen IR is still under development; we might be missing some operators in the core ATen opset.