I have a problem using aot_autograd with my own backend and torch.compile in inference mode.
Example code:
import torch

def test1():
    with torch.inference_mode(True):
        @torch.compile(backend='my_backend')
        def fn(inp):
            instance_norm = torch.nn.InstanceNorm2d(3, device='myDevice')
            return instance_norm(inp)

        inp = torch.randn(1, 3, 4, 2).to(device='myDevice')
        print(fn(inp))
My backend:
from typing import List

from torch._dynamo.backends.common import aot_autograd
from torch._dynamo.backends.registry import register_backend

@register_backend
def my_backend(graph_module: torch.fx.GraphModule, example_inputs: List[torch.Tensor], **kwargs):
    return aot_autograd(
        fw_compiler=training_compiler_fw,
        bw_compiler=training_compiler_bw,
        inference_compiler=inference_compiler,
    )(graph_module, example_inputs)
Without aot_autograd, i.e. using inference_compiler directly as the backend, instance_norm is lowered to my custom op via DispatchKey::MyPrivateKey as expected. With aot_autograd, however, it is decomposed by the native CompositeImplicitAutograd kernel before my compiler ever sees it.
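The difference is easy to see with a throwaway compiler that just prints the graph it receives (a debugging stub for illustration only, not our real inference_compiler; the name is arbitrary):

import torch
from typing import List

def debug_inference_compiler(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # Print whatever the backend is handed, then run it eagerly.
    # Used directly as the torch.compile backend, the graph still calls
    # instance_norm at the torch level, so at runtime it dispatches to our
    # MyPrivateKey kernel. Used as aot_autograd's inference_compiler, the
    # instance_norm call has already been decomposed by its
    # CompositeImplicitAutograd kernel before it reaches this point.
    print(gm.graph)
    return gm.forward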
Example of registration:
TORCH_LIBRARY_IMPL(aten, MyPrivateKey, m) {
  m.impl("instance_norm", instance_norm_wrap);
}
Could you help resolve this issue? Put differently: how should a backend override a CompositeImplicitAutograd op for its own device only? For tensors on other devices we want to fall through to the default path, but re-dispatching from inside the override causes infinite recursion. Is there guidance on how this should be done?
The following code works, but we still need the native path for non-MyDevice devices (one possible shape of that is sketched after the snippet):
from torch._C import DispatchKey
from torch._meta_registrations import register_meta

@torch.ops.aten.instance_norm.default.py_impl(DispatchKey.CompositeImplicitAutograd)
def instance_norm(input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
    # Unconditionally redirect the decomposition to our custom op.
    return torch.ops.my_dev.instance_norm_wrap.default(input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled)

@register_meta([torch.ops.my_dev.instance_norm_wrap.default])
def instance_norm_wrap_meta(input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
    # Meta/fake implementation: the output has the same shape and dtype as the input.
    return torch.empty_like(input)
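A possible workaround, sketched below, is to branch on the input device inside the CompositeImplicitAutograd override and, for every other device, inline the usual composite math (instance norm expressed as batch norm over a reshaped view) so that aten.instance_norm is never re-entered. This is only a sketch: the "myDevice" device-type string is a placeholder for our device, and running-stat write-back is omitted.

import torch
from torch._C import DispatchKey

@torch.ops.aten.instance_norm.default.py_impl(DispatchKey.CompositeImplicitAutograd)
def instance_norm(input, weight, bias, running_mean, running_var,
                  use_input_stats, momentum, eps, cudnn_enabled):
    # On our device, keep the op as a single custom kernel.
    if input.device.type == "myDevice":  # placeholder device-type string
        return torch.ops.my_dev.instance_norm_wrap.default(
            input, weight, bias, running_mean, running_var,
            use_input_stats, momentum, eps, cudnn_enabled)

    # Everywhere else, reproduce the composite math without calling
    # aten.instance_norm again: instance norm over (N, C, *) is batch norm
    # over a (1, N*C, *) view with the per-channel tensors repeated N times.
    # (Running-stat write-back is skipped here for brevity.)
    n, c = input.shape[0], input.shape[1]

    def repeat(t):
        return t.repeat(n) if t is not None else None

    out = torch.nn.functional.batch_norm(
        input.reshape(1, n * c, *input.shape[2:]),
        repeat(running_mean), repeat(running_var),
        repeat(weight), repeat(bias),
        use_input_stats, momentum, eps,
    )
    return out.reshape(input.shape)

Is branching like this the recommended approach, or is there a supported way to fall back to the built-in CompositeImplicitAutograd kernel from inside the override without recursing?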