How to replace a callable with an opaque Python callable in an Inductor pass?

Hi,

HigherOrderOperators seem to be helpful for tracing models that contain opaque Python code during a torch.compile invocation.

But can a HigherOrderOperator that wraps some arbitrary Python code be used to replace a Node’s callable (or to create a new node with a custom callable that replaces an existing node) in a torch.fx graph during an Inductor pass, or is there some other existing mechanism for doing so?

Basically, I’m creating a Python class whose objects override the __call__() method (and are hence callables); these objects are created during an Inductor pass and are subsequently used to replace some torch.fx Nodes’ callables.
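To make that concrete, here’s a rough sketch of the shape of what I’m doing (all names here are made up for illustration; the real pass matches different nodes):

```python
import torch

class MyOpaqueCallable:
    """Wraps arbitrary Python code that Inductor should treat as opaque."""

    def __init__(self, tag: str):
        self.tag = tag

    def __call__(self, x):
        return x + 1  # stand-in for the real opaque logic

def swap_in_opaque_callables(gm: torch.fx.GraphModule) -> None:
    # Replace the target of matching call_function nodes with an instance
    # of the opaque callable; Inductor's lowering then walks this graph.
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.ops.aten.relu.default:
            node.target = MyOpaqueCallable(tag=node.name)
```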

Currently, I’m making such an implementation work locally with some code in Inductor’s graph.py, ir.py & lowering.py, which adds conditional checks for the custom Python class I created, similar to the conditional checks that already exist for HigherOrderOps. Basically, make_fallback() is called for the callables I create, but calling make_fallback() alone is not enough, as the aforementioned conditional checks also need to be added.

Please advise whether adding such conditional code would be okay, or whether there’s currently some other way to do this.
If there’s no existing mechanism, could we create a generic one that could also be (re)used by other fallback ops in the future (since it seems make_fallback() alone may be insufficient)?

Thanks!

cc @ezyang @jansel

If you want to embed an arbitrary callable in Inductor’s output, then make_fallback() should work. HigherOrderOperators do not exist past Inductor’s lowering step, so while they might be useful at the ATen FX-graph level, you need to map them to something in Inductor’s IR, and a fallback would be the easiest thing to map to since it requires very little code.
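For reference, explicit registration is just this (assuming a custom op mylib::myop already registered via torch.library; make_fallback lives in torch._inductor.lowering):

```python
import torch
from torch._inductor.lowering import make_fallback

# Emit a plain call to the op in the generated code instead of
# trying to generate a kernel for it.
make_fallback(torch.ops.mylib.myop.default)
```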

What are the conditional checks you want to add? Could you provide an example?

If they are compile-time checks, then you could either do them as an fx pass, or by defining a custom lowering function.
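A compile-time check as an fx pass might look roughly like this (the dtype condition is just a placeholder for whatever property you need to verify):

```python
import torch

def check_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in gm.graph.nodes:
        if node.op == "call_function":
            # Post-grad fx nodes usually carry FakeTensor metadata in meta["val"].
            fake_val = node.meta.get("val")
            if isinstance(fake_val, torch.Tensor) and fake_val.dtype == torch.float64:
                raise RuntimeError(f"unsupported dtype at compile time: {node}")
    return gm
```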

If they are run-time checks, you could put the check inside your custom Python function and just use make_fallback to put that function in the generated code. You could also define your own type of IR node to codegen the conditional in the output code, then write a lowering to that.
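For the run-time case, the check can just live inside the callable itself, e.g. (sketch; the size limit is a placeholder condition):

```python
class GuardedCallable:
    def __init__(self, fn, max_numel=2**20):
        self.fn = fn
        self.max_numel = max_numel  # placeholder run-time condition

    def __call__(self, x):
        # Executed at run time inside the fallback, not codegen'd by Inductor.
        if x.numel() > self.max_numel:
            raise RuntimeError("fallback path: input too large")
        return self.fn(x)
```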


Thanks a lot for your response, @jansel!

I meant that despite using make_fallback(), in order to make Inductor compilation actually work, I still had to modify some assert statements to add an exception for the class of the Python callables I’m adding.

Here’s one of them, from ir.py.
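It’s essentially this type check (paraphrased; the exact wording in ir.py may differ):

```python
# Paraphrased from ir.py's FallbackKernel.__init__; exact code may differ.
assert isinstance(
    kernel,
    (torch._ops.OpOverload, torch._ops.HigherOrderOperator),
), f"FallbackKernel does not support {type(kernel)}"
```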

Thanks!

Ah, I think it is fine to relax that assert. I’d suggest adding a strict=False flag to make_fallback that just skips the assert entirely. For implicit fallbacks (created on an unknown IR node) I think we want the assert, but if someone calls make_fallback directly it is not needed.
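Sketching the idea (illustrative only; the real make_fallback in lowering.py has more logic, and the assert shown here is a stand-in for the actual one):

```python
decompositions = {}  # stand-in for Inductor's decomposition table

def make_fallback(op, layout_constraint=None, warn=True, strict=True):
    if strict:
        # Existing sanity check, now skippable for explicit callers.
        assert op not in decompositions, f"both a decomp and a fallback for {op}"
    # ... register the fallback lowering for op ...
```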

Feel free to submit a PR for that.


Thanks again for the clarification, @jansel! :)

Sorry, I should’ve mentioned that my implementation is calling make_fallback() in the implicit fallback path.

While adding a strict=False flag to make_fallback() would avoid having to modify an assert in make_fallback() for this use case, it would not resolve the problem that FallbackKernel’s __init__() requires the kernel to be an instance of either HigherOrderOperator or OpOverload.

Is it possible to allow a FallbackKernel to be created for an instance of another class besides these two? Thanks!

Another option is that you could create a PR that adds a class like this:

```python
class InductorImplicitFallback:
    def __call__(self, *args, **kwargs):
        raise NotImplementedError
```

Then add that class to the assert, and make your class a subclass of that one.

That way people can add this type of thing out of tree without needing to change PyTorch.
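Out of tree, that would look something like this (hypothetical subclass):

```python
class MyFallback(InductorImplicitFallback):
    def __call__(self, x):
        return x * 2  # arbitrary opaque Python logic

# Inductor's relaxed assert would then accept it via:
#   isinstance(kernel, InductorImplicitFallback)
```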


Thanks again, @jansel! :)

I’ll work on this solution.