Custom Ops Under torch.compile: autograd.Function vs torch.library.custom_op

Yes, I’m going with custom ops.

If you go from having a composite op (one that is made of other ops) to a non-composite op (a custom kernel), then yes, you need to add implementations for the different components to tell them how the op works. Specifically: for compile → a fake impl, for autograd → a backward.
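
For reference, a minimal sketch of that through the C++ registration API (the op name and namespace are placeholders; in Python, torch.library.custom_op with register_fake and register_autograd plays the same role):

#include <ATen/ATen.h>
#include <torch/library.h>

// Fake (meta) implementation for compile/tracing: produces outputs with
// the correct shapes/dtypes, not the correct values.
at::Tensor fused_op_meta(const at::Tensor& x) {
  return at::empty_like(x);
}

TORCH_LIBRARY(myns, m) {
  m.def("fused_op(Tensor x) -> Tensor");
}

TORCH_LIBRARY_IMPL(myns, Meta, m) {
  m.impl("fused_op", &fused_op_meta);
}

// For autograd, you would additionally register a backward, e.g. via a
// torch::autograd::Function subclass registered under the Autograd key.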

What about the case where my custom op is just a composition of ATen ops? How can I use CompositeImplicitAutograd to tell the various torch components how my custom op works? In your example the “meta implementation” does not actually produce the correct values, just the correct shapes. I have a situation where I’m really just registering a custom op which is semantically identical to a fusion of torch-native ops. (My understanding is that this is the best / most supported way to register a device-specific fusion; please let me know if I’m using the wrong APIs.)

When I initially register a function to CompositeImplicitAutograd, everything works great: I can use inference mode and autograd on all devices.
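
Concretely, the initial registration is essentially this (names and the decomposition body are placeholders for my real op):

#include <ATen/ATen.h>
#include <torch/library.h>

TORCH_LIBRARY(myns, m) {
  m.def("forward_op(Tensor a, Tensor b) -> Tensor");
}

// A single registration covers everything: autograd, meta, and every
// backend are derived from the decomposition into ATen ops.
TORCH_LIBRARY_IMPL(myns, CompositeImplicitAutograd, m) {
  m.impl("forward_op", [](const at::Tensor& a, const at::Tensor& b) {
    return at::gelu(at::matmul(a, b));  // placeholder decomposition
  });
}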

Then, when I add a device-specific registration through the C++ API, inference mode still works, but backward for that device fails with an “autograd kernel was not registered” error.
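
The device-specific registration is just a backend kernel for the same schema (CUDA stands in for my device; fused_forward_cuda is a placeholder for the fused kernel):

#include <ATen/ATen.h>
#include <torch/library.h>

// Placeholder for the fast fused kernel.
at::Tensor fused_forward_cuda(const at::Tensor& a, const at::Tensor& b);

// Once this backend kernel exists, the dispatcher no longer derives
// autograd for this backend from the CompositeImplicitAutograd
// decomposition, which is where the missing-autograd-kernel error
// comes from.
TORCH_LIBRARY_IMPL(myns, CUDA, m) {
  m.impl("forward_op", &fused_forward_cuda);
}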

The nearest solution I have found is to register, for the device’s autograd key, a boxed function that redispatches to the CompositeImplicitAutograd implementation. Something like this:

#include <torch/library.h>

// Boxed kernel for the device's autograd key: redispatch to the
// CompositeImplicitAutograd implementation, so autograd differentiates
// through the decomposition into ATen ops.
void autograd_to_composite(
    const c10::OperatorHandle& op,
    torch::jit::Stack* stack) {
  op.callBoxedForDispatchKey(
      c10::DispatchKey::CompositeImplicitAutograd, *stack);
}

TORCH_LIBRARY_IMPL(..., AutogradDEVICE, m) {
  m.impl(
      "forward_op",
      torch::CppFunction::makeFromBoxedFunction<&autograd_to_composite>());
}

But it’s a little overzealous / conservative: whenever autograd is active, the forward always runs the CompositeImplicitAutograd implementation on the device, even though a fast fused impl is available. There’s an old article about figuring out which intermediate tensors to preserve in an autograd graph.

It seems that in this situation, “having a fast implementation of a fusion of ATen operations” is equivalent to “storing none of the intermediate nodes and instead recalculating all of them during the backward, [and replacing the forward op pattern with the fast fusion impl]”.

It would be really nice if there were some way I could get PyTorch autograd to work this way. (I could add a Tag to my custom operator to assert: “Yes, you can derive device-specific autograd from my CIA + forward. I understand I’m asking for magic / my CIA is accurate / is the definition of the op’s behavior.”)
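
For illustration only, something like this, where the tag is entirely hypothetical (no such tag exists in PyTorch today):

#include <torch/library.h>

TORCH_LIBRARY(myns, m) {
  // Hypothetical tag, meaning: derive per-device autograd from the
  // CompositeImplicitAutograd decomposition, but still prefer the
  // backend kernel for the forward. This tag does not exist.
  m.def("forward_op(Tensor a, Tensor b) -> Tensor",
        {at::Tag::derive_autograd_from_composite});
}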

The benefit would be that I would not have to duplicate generic autograd logic in the C++ registrations for each device backend, because doing so is verbose and error-prone, essentially redefining what is already asserted by registering the CompositeImplicitAutograd function.

Please let me know your thoughts on this; it’s okay if the answer is “that’s not a supported workflow” or the like. I’ve just been struggling with this for a while.