When I tried to trace silu and silu_backward with the code below:
import torch
import torch.nn as nn
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, PythonKeyTracer
from torch.fx._symbolic_trace import Tracer
from torch.utils.weak import WeakTensorKeyDictionary
import weakref

device = torch.device("cpu")

def inner_func():
    a = torch.tensor([2.0, 3.0], requires_grad=True, device=device)
    b = torch._C._nn.silu(a)
    c = torch.relu(b)
    c.backward(torch.tensor([1.0, 1.0], device=device))

python_fx_tracer = PythonKeyTracer()
with ProxyTorchDispatchMode(python_fx_tracer, 'real'):
    graph = python_fx_tracer.trace(inner_func)
print(graph)
I got the following graph:
I found that there are two silu_backward implementations in aten/src/ATen/native/Activation.cpp. silu_backward was decomposed and dispatched to the implementation shown in this link.
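For reference, the decomposition corresponds to the silu derivative. A rough Python equivalent of what the math_silu_backward kernel in Activation.cpp computes (my paraphrase, not the exact source) is:

import torch

def math_silu_backward_sketch(grad_output, x):
    # d/dx silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))),
    # which is why the traced graph contains sigmoid/mul/add nodes
    # instead of a single silu_backward node.
    x_sigmoid = torch.sigmoid(x)
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))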
Questions:
- Why is silu_backward dispatched to an implementation composed of other ops under ProxyTorchDispatchMode? In other words, why is silu_backward decomposed in the FX graph?
- If I want a single silu_backward node in the FX graph, without decomposition, what should I do?
The reason silu_backward gets decomposed unconditionally is that it has a “CompositeImplicitAutograd” dispatch path, and by design the “math_silu_backward” implementation is picked up by the AOTAutograd component of torch.compile.
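If you want to double-check that registration yourself, one way is to dump the operator's dispatch table and look for the CompositeImplicitAutograd entry. This relies on a private helper, so treat it as a sketch whose exact output may vary across PyTorch versions:

import torch

# Dump the kernels registered for silu_backward; the output should include a
# CompositeImplicitAutograd entry alongside the CPU/CUDA kernels.
print(torch._C._dispatch_dump("aten::silu_backward"))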
It would be great if there were a way to specify that certain “CompositeImplicitAutograd” operations shouldn't be decomposed and traced through, but I'm not aware of such a mechanism being implemented.
There's a related issue that unfortunately got no response: Nonoptimal trace of silu_backward with AOT Autograd · Issue #86612 · pytorch/pytorch · GitHub
I know the AOTAutograd component of torch.compile picks up the 'math_silu_backward' implementation on purpose, and that 'silu_backward' goes through a different dispatch path in many dispatch-mode contexts (for example, fake mode) than it does outside such contexts.
I'm trying to understand why silu_backward needs to take a different dispatch path in these special contexts, but I can't figure it out. Can you help explain what the purpose is?
Thanks.
@Chillee might be able to answer this.
It looks like part of the reason is that make_fx() will unconditionally run CompositeImplicitAutograd decompositions, if there are any.
Code that calls decompose: pytorch/torch/fx/experimental/proxy_tensor.py at ee83c646bb30d5f11b64013b54174768b733214b · pytorch/pytorch · GitHub
decompose impl: pytorch/torch/_ops.py at ee83c646bb30d5f11b64013b54174768b733214b · pytorch/pytorch · GitHub
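Roughly, the logic in those two links boils down to something like this (a paraphrase, not the exact source; func stands for an OpOverload such as torch.ops.aten.silu_backward.default):

def maybe_decompose(func, args, kwargs):
    # Try the CompositeImplicitAutograd decomposition first.
    r = func.decompose(*args, **kwargs)
    if r is not NotImplemented:
        # A decomposition existed, so its constituent ops get traced instead.
        return r
    # No decomposition: fall through and record func as a single graph node.
    return NotImplemented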
This isn't exactly the nicest solution, but you can see that we rely on whether or not NotImplemented is returned to determine whether decompose() resulted in an actual decomposition. One way to avoid the decomposition would be to use the Python dispatcher to register your own implementation of silu_backward() that just returns NotImplemented:
@torch.ops.aten.silu_backward.default.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)
def my_silu_backward(grad_out, self):
    # Use with some caution: this is only really valid to run in the context of proxy tensor tracing.
    return NotImplemented
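For example, a usage sketch (assuming a recent PyTorch where py_impl and make_fx behave as described above; f and my_silu_backward are just illustrative names) where registering the override before tracing should leave a single silu_backward node in the graph:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

@torch.ops.aten.silu_backward.default.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)
def my_silu_backward(grad_out, self):
    # Only really valid under proxy tensor tracing, as noted above.
    return NotImplemented

def f(grad_out, x):
    # Call the op directly so the trace exercises silu_backward.
    return torch.ops.aten.silu_backward(grad_out, x)

gm = make_fx(f)(torch.randn(2), torch.randn(2))
print(gm.graph)  # should now show a single call to aten.silu_backward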