Custom CUDA extension support in Inductor

Per the thread Inductor Triton Custom Op, Inductor supports custom kernels written in Triton. What about general CUDA extensions/kernels, then? Is it possible to integrate them into Inductor? How can Inductor understand custom kernels and possibly fuse them?

cc (might be relevant) @oulgen @zou3519 @jansel

You can register a standard PyTorch op via the Custom C++ and CUDA Extensions tutorial (PyTorch Tutorials 2.2.1+cu121 documentation), e.g. torch.ops.yournamespace.yourop.

This should work with torch.compile, though you may need to define a “meta” function to tell PT2 the output size of your custom op. In the op you can check tensor.device() == kMeta and return an empty tensor of the correct size, or register a different implementation for the “meta” device.
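For example, a minimal Python-side sketch (the op name mylib::my_op is hypothetical; it assumes the C++/CUDA extension already registered the op via TORCH_LIBRARY and that the output matches the input's shape and dtype):

```python
import torch

# Hypothetical: "mylib::my_op" was defined in a C++/CUDA extension via
# TORCH_LIBRARY. Registering a Meta implementation lets PT2 infer output
# shapes without launching the CUDA kernel.
meta_lib = torch.library.Library("mylib", "IMPL", "Meta")

def my_op_meta(x):
    # Assumption: my_op returns a tensor with the same shape/dtype as x.
    return torch.empty_like(x)

meta_lib.impl("my_op", my_op_meta)
```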

@zou3519 is working on a new custom op API that should be cleaner to use, though the above should work for C++/CUDA today.

Thanks for the quick reply! Regarding integration with torch.compile, I think there are two concerns:

  • The custom op will not cause a graph break in Dynamo.
  • The custom op can be fused into a larger kernel by Inductor.

It seems we are mainly talking about the first point (i.e. integration with Dynamo). Are there any usage examples for the second point (i.e. integration with Inductor)?

Inductor can’t fuse with CUDA kernels because it generates Triton, not CUDA, and there is no way to modify the source code of a custom op.

You could define a Triton template (for example pytorch/torch/_inductor/kernel/unpack_mixed_mm.py at main · pytorch/pytorch · GitHub); then Inductor will be able to fuse epilogues into it.
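Inductor’s TritonTemplate machinery is internal, but as a rough illustration of why Triton is fusible where opaque CUDA is not, here is a minimal user-defined Triton kernel used under torch.compile (kernel name and BLOCK size are arbitrary choices for this sketch):

```python
import torch
import triton
import triton.language as tl

# A user-defined Triton kernel. Because Inductor also emits Triton, it can
# reason about kernels like this, unlike a precompiled CUDA extension.
@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    y = tl.load(y_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, x + y, mask=mask)

def triton_add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)
    add_kernel[grid](x, y, out, n, BLOCK=1024)
    return out

@torch.compile
def f(x, y):
    # The relu epilogue is a candidate for Inductor to fuse around the kernel.
    return torch.relu(triton_add(x, y))
```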

Then I have an interesting question: how does Inductor understand PyTorch builtin ops? They are also backed by CUDA kernels. Does Inductor have some “Inductor language” into which all builtin ops are rewritten?

That’s also the thing that blocks my understanding of Inductor. I’m not aware of any good documentation/tutorials for diving into Inductor.

Builtin ops are handled in one of three ways:

  1. Decomposed into other ops: pytorch/torch/_inductor/decomposition.py at main · pytorch/pytorch · GitHub
  2. Lowered to Inductor’s IR: pytorch/torch/_inductor/lowering.py at main · pytorch/pytorch · GitHub
  3. Or (rarely) become a “fallback” op where we just run the original operator unmodified

User-defined custom ops are handled the same way as case 3; see the decomposition sketch below for what case 1 looks like.
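As an illustration of case 1, a decomposition is just a Python function that rewrites an op in terms of other ops (sketch only; mylib::gelu_mul is a hypothetical custom op, and register_decomposition is an internal API):

```python
import torch
from torch._decomp import register_decomposition

# Hypothetical op "mylib::gelu_mul(Tensor x, Tensor y) -> Tensor".
# Decomposing it into builtin ops (case 1) lets Inductor lower and fuse it.
@register_decomposition(torch.ops.mylib.gelu_mul)
def gelu_mul_decomp(x, y):
    return torch.nn.functional.gelu(x) * y
```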

Then is it possible to lower a user-defined custom op into Inductor IR by registering it? I would love any documentation/tutorials/examples!

Yes, it is possible (using @register_lowering from lowering.py, which can be called from out of tree), though it is not officially supported: we don’t guarantee forward/backward compatibility of our compiler IRs, and it would require calling internal PyTorch APIs.

Feel free to give it a try; lowering.py contains lots of examples.
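A hedged sketch of what an out-of-tree lowering might look like (mylib::my_relu is hypothetical, and both register_lowering and the lowerings dict are internal APIs with no compatibility guarantees):

```python
import torch
from torch._inductor.lowering import lowerings, register_lowering

# Hypothetical custom op "mylib::my_relu". Lowering it into Inductor IR
# (case 2) instead of leaving it as a fallback makes it fusible with
# neighboring ops.
@register_lowering(torch.ops.mylib.my_relu.default)
def my_relu_lowering(x):
    # Delegate to Inductor's builtin relu lowering, which produces
    # pointwise IR that Inductor can fuse.
    return lowerings[torch.ops.aten.relu.default](x)
```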