AOTAutograd incorrectly lowers composite ops in inference_mode

I have a problem when using aot_autograd with my own backend and torch.compile in inference mode.

Example code:

import torch

def test1():
    with torch.inference_mode(True):
        @torch.compile(backend='my_backend')
        def fn(inp):
            instance_norm = torch.nn.InstanceNorm2d(3, device='myDevice')
            return instance_norm(inp)
        inp = torch.randn(1, 3, 4, 2).to(device='myDevice')
        print(fn(inp))

My backend:

from typing import List

import torch
from torch._dynamo import register_backend
from torch._dynamo.backends.common import aot_autograd

@register_backend
def my_backend(graph_module: torch.fx.GraphModule, example_inputs: List[torch.Tensor], **kwargs):
    return aot_autograd(
        fw_compiler=training_compiler_fw,
        bw_compiler=training_compiler_bw,
        inference_compiler=inference_compiler,
    )(graph_module, example_inputs)

Without aot_autograd (using only inference_compiler), instance_norm is lowered to my custom op under DispatchKey::MyPrivateKey as expected.
However, when using aot_autograd, it is decomposed via the native CompositeImplicitAutograd lowering instead.

Example of registration:


TORCH_LIBRARY_IMPL(aten, MyPrivateKey, m) {
    m.impl("instance_norm_wrap", instance_norm_wrap);
}
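
For context, the Python snippets later in this thread call torch.ops.my_dev.instance_norm_wrap.default, which assumes that op's schema has been declared somewhere. A minimal, purely illustrative sketch of such a declaration from Python (the my_dev namespace and instance_norm_wrap name just mirror the rest of the thread; the signature copies aten::instance_norm):

import torch

# Hypothetical schema declaration so that torch.ops.my_dev.instance_norm_wrap
# resolves from Python; the signature mirrors aten::instance_norm.
my_lib = torch.library.Library("my_dev", "DEF")
my_lib.define(
    "instance_norm_wrap(Tensor input, Tensor? weight, Tensor? bias, "
    "Tensor? running_mean, Tensor? running_var, bool use_input_stats, "
    "float momentum, float eps, bool cudnn_enabled) -> Tensor"
)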

Could you help resolve this issue?

Alternatively, how should a backend override a CompositeImplicitAutograd op for a specific backend only? For other backends,
we would like dispatch to go through the default path, but doing that causes recursion. Is there guidance on how this should be done?

The following code works, but we still need the native path for non-MyDevice devices:

from torch._C import DispatchKey
from torch._meta_registrations import register_meta

@torch.ops.aten.instance_norm.default.py_impl(DispatchKey.CompositeImplicitAutograd)
def instance_norm(input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
    return torch.ops.my_dev.instance_norm_wrap.default(input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled)

@register_meta([torch.ops.my_dev.instance_norm_wrap.default])
def instance_norm_wrap_meta(input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled):
    out = torch.empty_like(input)
    return out

You’re right that AOTDispatcher (previously AOTAutograd, since we think the new name is a bit more fitting) doesn’t have a great out-of-the-box way to prevent specific CompositeImplicitAutograd ops from getting decomposed during tracing under an inference setting - they will always get decomposed during tracing.
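
As a quick illustration, here is a minimal sketch (plain CPU tensors, no custom backend; the exact decomposition can vary across PyTorch versions) showing that the graph handed to the compiler no longer contains aten.instance_norm, only the ops from its CompositeImplicitAutograd decomposition:

import torch
from torch._dynamo.backends.common import aot_autograd

def print_fw(gm, example_inputs):
    # inspect which aten ops survived tracing: aten.instance_norm will not
    # appear here, only the batch_norm-related ops it decomposes into
    print(gm.code)
    return gm.forward

@torch.compile(backend=aot_autograd(fw_compiler=print_fw, inference_compiler=print_fw))
def fn(x):
    return torch.nn.functional.instance_norm(x)

with torch.inference_mode():
    fn(torch.randn(1, 3, 4, 2))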

We could offer a better API for this - but for now, one thing you can do is modify the state of the python dispatcher (similar to what you did above with py_impl, but under a context manager so you only temporarily modify it during your compilation):

from contextlib import contextmanager

import torch

@contextmanager
def override_instance_norm():
    op = torch.ops.aten.instance_norm.default
    old_table = op.py_kernels.copy()

    def new_impl(*args, **kwargs):
        return torch.ops.my_dev.instance_norm_wrap.default(*args, **kwargs)

    # temporarily put a decomp to your custom op in the dispatcher
    # careful - note that your new op will not work with training/autograd
    # unless you manually write and register your own derivative rule for it.
    op.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)(new_impl)

    try:
        yield
    finally:
        op.py_kernels.clear()
        op.py_kernels.update(old_table)
        # private APIs
        op._dispatch_cache.clear()

...

@register_backend
def my_backend(graph_module: torch.fx.GraphModule, example_inputs: List[torch.Tensor], **kwargs):
    # temporarily tweak how the dispatcher traces instance_norm when AOTDispatcher traces
    with override_instance_norm():
        return aot_autograd(
            fw_compiler=training_compiler_fw,
            bw_compiler=training_compiler_bw,
            inference_compiler=inference_compiler,
        )(graph_module, example_inputs)

@bdhirsh, thank you for your reply. The example you shared shows how to create a custom CompositeImplicitAutograd op for instance_norm. However, to enable it only for a specific backend, we would still need dispatches on other devices (for example, CPU) to go through the default path even during graph compilation, in case there are still instance_norm calls on CPU for some reason.

One possible way is to write the custom op "new_impl" to handle this by checking the input tensor device: call the backend-specific op when the inputs are on our device, and redirect dispatch for other devices back to the default. Since we already register our own CompositeImplicitAutograd dispatch with py_impl, we would need another context manager that pops our CompositeImplicitAutograd entry, dispatches to the default instance_norm, and then pushes our entry back afterwards, to avoid the recursion @wdziurdz mentioned.
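
For concreteness, a rough sketch of that pop/dispatch/push idea, reusing the same private py_kernels / _dispatch_cache state as the override above (illustration only - these are private APIs and may change between PyTorch versions):

from contextlib import contextmanager

import torch
from torch._C import DispatchKey

@contextmanager
def without_instance_norm_override():
    # temporarily remove our CompositeImplicitAutograd entry so the call below
    # falls back to the default decomposition, then put it back afterwards
    op = torch.ops.aten.instance_norm.default
    saved = op.py_kernels.pop(DispatchKey.CompositeImplicitAutograd, None)
    op._dispatch_cache.clear()  # private API: drop the cached kernel entry
    try:
        yield
    finally:
        if saved is not None:
            op.py_kernels[DispatchKey.CompositeImplicitAutograd] = saved
        op._dispatch_cache.clear()

# inside new_impl, for inputs that are not on MyDevice:
#     with without_instance_norm_override():
#         return torch.ops.aten.instance_norm.default(*args, **kwargs)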

While this is doable, do you suggest any better way to define this? For example, is there any way to qualify a custom op for dispatch keys like CompositeImplicitAutograd for a specific backend only?


That’s a good question - I don’t think there’s an easy way to do this today using the “usual” per-backend dispatch key registration. Our tracing logic has special handling to automatically hook into CompositeImplicitAutograd decompositions and run them, and CompositeImplicitAutograd decomps are usually expected to be backend-agnostic. So I think having the “new decomp” that you register branch on device is probably… a reasonable workaround:

    # drop-in replacement for `new_impl` inside override_instance_norm above;
    # `op` is the torch.ops.aten.instance_norm.default captured there
    def new_impl(*args, **kwargs):
        if get_device(args, kwargs) == 'my_backend_name':
            return torch.ops.my_dev.instance_norm_wrap.default(*args, **kwargs)
        else:
            # otherwise, dispatch to the original C++ decomp from core.
            # I think this should avoid the recursion issues - _op_dk() will call directly
            # into the CompositeImplicitAutograd registration from C++
            return op._op_dk(torch._C.DispatchKey.CompositeImplicitAutograd, *args, **kwargs)
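
(get_device isn't defined in the snippet above - one possible, purely hypothetical helper would just return the device type of the first tensor argument:)

import torch
from torch.utils._pytree import tree_flatten

# Hypothetical helper: device type string of the first Tensor found in the
# arguments, e.g. 'cpu', 'cuda', or your backend's device type.
def get_device(args, kwargs):
    flat, _ = tree_flatten((args, kwargs))
    for a in flat:
        if isinstance(a, torch.Tensor):
            return a.device.type
    return None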