I am looking to decompose ops in my model graph traced by aot_module in the following way (adapting code from "AOT Autograd - How to use and optimize?" in the functorch nightly documentation):
from torch._decomp import core_aten_decompositions
from functorch.compile import aot_module
aot_module(model, fw_compiler=compiler_fn, bw_compiler=compiler_fn, decompositions=core_aten_decompositions())
- Is this the correct way to do this?
- My assumption was that core ATen ops won't be decomposed any further, and yet when I check the ATen ops that result from this decomposition, ops like aten.embedding_dense_backward and aten.embedding get decomposed even though they are listed as core ATen ops. What am I missing?
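For context, this is roughly how I check which ops survive decomposition: a minimal sketch with a toy model (made up for illustration, standing in for my real one) and a pass-through compiler that just prints the graph's call_function targets before returning the module unchanged.

```python
import torch
import torch.nn as nn
from functorch.compile import aot_module
from torch._decomp import core_aten_decompositions

def print_ops_compiler(fx_module, example_inputs):
    # Print every ATen op left in the traced graph after decomposition,
    # then return the module unchanged so execution proceeds normally.
    for node in fx_module.graph.nodes:
        if node.op == "call_function":
            print(node.target)
    return fx_module

# Toy model for illustration; the embedding makes aten.embedding /
# aten.embedding_dense_backward show up in the forward/backward traces.
model = nn.Sequential(nn.Embedding(10, 4), nn.Flatten(start_dim=0))
compiled = aot_module(model, fw_compiler=print_ops_compiler,
                      bw_compiler=print_ops_compiler,
                      decompositions=core_aten_decompositions())
out = compiled(torch.randint(0, 10, (3,)))
out.sum().backward()  # trigger the backward trace as well
```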