[RFC] New Python operator registration API

We’re developing a new high-level Python operator registration API, designed to be positioned above the existing low-level Python torch.library and C++ TORCH_LIBRARY APIs. We expect this API to be the first API that users reach for in the future when bringing a custom operator to PyTorch. After some initial prototypes and feedback, we’ve settled on the following design. Please let us know your thoughts.

Design doc: [PUBLIC] Python Custom Ops (2024) - Google Docs

Hi,

The goal seems to be to provide another mechanism for defining new ops, not to override the behavior of existing ATen ops. Is that the intent?

Follow-up: should potential new core ATen ops (say, scan, while_loop, etc.) be implemented this way?

One piece of feedback: it would be nice for such a framework to be extensible to future, externally defined functions that correspond to existing ops (and not be confined to forward/backward/double-backward). For example, it would be useful to be able to define right_inverse/inverse functions for existing ops in the future, and maybe to have higher-order generic functions like torch.func.inverse(op), which would return a function computing the inverse.

Hi,
The general approach sounds great and is simpler than torch.library :)

There are three things I am missing in this API:

  1. (small) Automatic schema deduction from type annotations. In many cases it's easy and reduces
    effort, and it's not hard to verify it against the string-like schema.

  2. (Maybe out of scope?) An option for “native” autograd would be very useful.
    In my use case, I have operators that Dynamo cannot handle, and they cause graph breaks.
    I want torch.compile to generate a full graph for the forward pass (to get accelerated inference from
    a custom compiler), but I do not care about backward speed and am willing to fall back to eager mode
    for it. I really don’t want to manually write the backward for these operators. Having
    something like:

    @staticmethod
    @autograd  # proposed decorator; it does not exist today
    def eager(x):
        if x.mean() > 0.5:
            return torch.sin(x)
        else:
            return torch.cos(x)

    would let me use torch.export/torch.compile to get a full graph, pass it to my compiler, and get fast
    inference code, while still running normal eager mode with autograd support for training.

  3. (Not sure) Allowing a configuration (“schema”) to be passed to the operator. Many operators I use have complex configurations
    (e.g. a quantization scheme with multiple config options, selected per layer in the network to give the
    best accuracy). I want a way to pass some configuration object. The two
    options I see now are to serialize to a string and deserialize from it (performance? not sure),
    or to dynamically create a custom op class per configuration (tons of custom ops, almost one per
    layer).
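For item 1 in the list above, a rough pure-Python sketch of deducing a schema-like string from annotations and checking it against a handwritten one. `Tensor` here is a stand-in class, and `schema_from_annotations`/`check_schema` are hypothetical helpers, not PyTorch APIs; the type-name mapping is an assumption.

```python
import inspect
from typing import get_type_hints


class Tensor:
    """Stand-in for torch.Tensor so this sketch runs without PyTorch."""


# Assumed mapping from Python annotations to schema type names.
_SCHEMA_TYPES = {Tensor: "Tensor", int: "int", float: "float", bool: "bool", str: "str"}


def schema_from_annotations(fn) -> str:
    """Build a schema-like string from a function's type annotations."""
    hints = get_type_hints(fn)
    args = []
    for name in inspect.signature(fn).parameters:
        if name not in hints:
            raise TypeError(f"parameter {name!r} is missing a type annotation")
        args.append(f"{_SCHEMA_TYPES[hints[name]]} {name}")
    ret = _SCHEMA_TYPES[hints["return"]]
    return f"{fn.__name__}({', '.join(args)}) -> {ret}"


def check_schema(fn, expected: str) -> None:
    """Verify a deduced schema against a handwritten schema string."""
    deduced = schema_from_annotations(fn)
    if deduced != expected:
        raise ValueError(f"schema mismatch: {deduced!r} != {expected!r}")


def scaled_add(x: Tensor, y: Tensor, alpha: float) -> Tensor:
    # The body is irrelevant for schema deduction.
    return x
```

Here `schema_from_annotations(scaled_add)` returns `'scaled_add(Tensor x, Tensor y, float alpha) -> Tensor'`. Note that the result says nothing about which inputs the function mutates; plain annotations do not carry that information.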

Yes, the goal is to give a mechanism for defining new ops, not for overriding existing ATen ops.

New core ATen ops should still be implemented the way they are implemented inside PyTorch today (i.e. in C++ or via the higher-order operator (HOP) mechanism).

We plan to add more overridable staticmethods for more things that appear in the future. If we need an inverse function, then we’ll add an inverse() method to the class.
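A toy, pure-Python sketch of how an optional inverse() staticmethod on an op class could back a generic higher-order helper in the spirit of the torch.func.inverse(op) idea. `Exp`, `Square`, `forward`, and `inverse` are hypothetical illustrations, not part of the proposed API or of torch.func, and plain floats stand in for tensors.

```python
import math


class Exp:
    """Toy op: each capability is an optional staticmethod (hypothetical layout)."""

    @staticmethod
    def forward(x: float) -> float:
        return math.exp(x)

    @staticmethod
    def inverse(y: float) -> float:
        # log undoes exp on exp's whole range (y > 0).
        return math.log(y)


class Square:
    """Toy op without inverse(): x * x is not invertible over all reals."""

    @staticmethod
    def forward(x: float) -> float:
        return x * x


def inverse(op):
    """Return the op's inverse function, or raise if the op does not define one."""
    fn = getattr(op, "inverse", None)
    if fn is None:
        raise NotImplementedError(f"{op.__name__} does not define inverse()")
    return fn
```

With this layout, `inverse(Exp)(Exp.forward(1.5))` recovers 1.5 up to floating-point rounding, while `inverse(Square)` raises. Ops that never define the method simply opt out.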

@gilfree

  1. We’ve thought about this. There are two main problems: (a) Which static method do we deduce the schema from? A user can define cpu() for the CPU implementation, cuda() for the CUDA implementation, etc. (b) Consistency. The schema is not just about type information; it also carries mutability information (e.g. if the operator mutates an input, this shows up in the schema). For consistency, it’s better to have one way to define the schema (via a string).

  2. This is an interesting request… but the better solution is probably to use torch.cond in this situation. Creating a custom op from Python means that the code will not work with AOTInductor, so it depends on what your inference pipeline looks like. Using torch.cond to represent both true/false branches means that the code will work in all situations (inference vs training, AOTInductor vs export)

  3. What does the configuration object look like?
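To make the mutability point concrete: PyTorch schema strings mark a mutated input with an alias annotation ending in `!`, such as `Tensor(a!) x`, which a plain annotation like `x: Tensor` does not express. The parser below is a hypothetical helper (not a PyTorch API) that reads that marker back out; it handles only the simplest `Tensor(a!) name` form.

```python
import re

# An alias annotation ending in "!" (as in "Tensor(a!) x") marks an
# argument that the operator writes to.
_MUTATED_TENSOR_ARG = re.compile(r"Tensor\(\w+!\)\s+(\w+)")


def mutated_args(schema: str) -> list[str]:
    """Return the names of arguments that a schema string marks as mutated.

    Simplified: optional or list forms such as "Tensor(a!)?" are not handled.
    """
    arguments = schema.split("->", 1)[0]  # ignore the return type
    return _MUTATED_TENSOR_ARG.findall(arguments)
```

For example, `mutated_args("mylib::add_one_(Tensor(a!) x, Tensor y) -> ()")` returns `['x']`, whereas the same signature written with annotations alone would look identical to a non-mutating op.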

By the way, if these threads get too long, please feel free to comment directly on the gdoc! Your feedback is very appreciated.

Hi @zou3519, thanks for your response.

1 - OK, got it. I guess the abstract method could be used for this, or even a function named signature() or something similar, with type annotations such as x: Out[Tensor], but that makes it less useful.
2 - I am aware that AOTInductor will not work; I do not plan to use Inductor as a backend. We have our own backend, and we want the Dynamo graph for the forward pass in order to compile it with that backend. As for the backward pass, I’m OK with eager; it’s the best I can expect. Dynamo utterly fails on our code: I get ~80 graph breaks per custom convolution layer. Data-dependent conditions are just an illustrative example.
3 - We usually work with simple dataclasses, but dicts can also be ok.
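For a simple dataclass, the serialize-to-string option can be made cheap by caching the decode step. A sketch assuming a hypothetical `QuantConfig` with illustrative fields; the encoded string could then be passed as an ordinary `str` argument, and `functools.lru_cache` means each distinct configuration string is parsed roughly once per process.

```python
import functools
import json
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class QuantConfig:
    """Hypothetical per-layer quantization config; field names are illustrative."""
    bits: int = 8
    symmetric: bool = True
    channel_axes: tuple[int, ...] = ()


def config_to_str(cfg: QuantConfig) -> str:
    # sort_keys makes the encoding deterministic: equal configs -> equal strings.
    return json.dumps(asdict(cfg), sort_keys=True)


@functools.lru_cache(maxsize=None)
def config_from_str(s: str) -> QuantConfig:
    # Cached, so repeated calls with the same string skip JSON parsing.
    # Sharing the cached object is safe because QuantConfig is frozen.
    d = json.loads(s)
    d["channel_axes"] = tuple(d["channel_axes"])  # JSON has no tuple type
    return QuantConfig(**d)
```

The frozen dataclass with a tuple field is a deliberate choice here: a mutable cached object could be modified by one caller and silently affect every other layer using the same string.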

Thanks again,

If there is anything I can do to help push this API forward, let me know; I will do my best.

@gilfree

  1. It should be simple to build this API on top of the Python Custom Op proposal, but I’m not sure that we’ll build it into the API. What you’d want to do here is just switch on whether the input Tensors have requires_grad=True: if they do, call a Python function (which autograd can differentiate through); if they don’t, call a custom op.

  2. Is it possible to splat the dataclass so that it can be passed into an operator?
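One hedged reading of the requires_grad suggestion, sketched in plain Python: `StubTensor` and `make_switching_op` are hypothetical stand-ins for real tensors and for a wrapper around an actual custom op. Because no backward is registered for the custom op, inputs that require grad are routed to the plain Python function, which eager autograd can differentiate; otherwise the call goes to the custom op, so a traced inference graph contains one opaque node.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class StubTensor:
    """Stand-in exposing only requires_grad, so the sketch runs without torch."""
    value: float
    requires_grad: bool = False


def make_switching_op(
    custom_op: Callable[..., Any], python_fn: Callable[..., Any]
) -> Callable[..., Any]:
    """Build a wrapper that picks an implementation per call (hypothetical helper)."""

    def op(*inputs):
        if any(getattr(t, "requires_grad", False) for t in inputs):
            # Training: run the plain Python code so eager autograd can
            # differentiate through it without a handwritten backward.
            return python_fn(*inputs)
        # Inference/tracing: a single custom-op call, which a compiler
        # sees as one node instead of many graph breaks.
        return custom_op(*inputs)

    return op
```

With real tensors, one would likely also check `torch.is_grad_enabled()`, so that code running under `torch.no_grad()` takes the custom-op path even when parameters have requires_grad set.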

Hi @zou3519

  1. Hmm. I think I did a poor job explaining myself; let me try again.
    The basic idea is multi-level compilation: first create a graph at some level of abstraction, apply some transformations to it (e.g. conv-bn fusion, but with a custom conv), and then either pass it on for further compilation, or stop at this level and train eagerly with autograd.

    My goal is to separate the level at which the ops graph is represented from the level at which manual gradients are required (like the autowrap_functions argument in torch.fx).

  2. If by splatting you mean decomposing it into trivial arguments, I'm not sure. We have some lists there and a large number of members, which would make the schemas complex. Serializing to a string is always an option, but it would be much nicer to be able to pass a generic Python object.
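If splatting is tried, the long argument list need not be written by hand: it can be generated from the dataclass fields. A sketch assuming a hypothetical `LayerConfig` and hypothetical helper names; the schema-style type names are an assumption, only flat `int`/`float`/`bool`/`str` fields and lists of them are handled, and nested objects would still need a string encoding.

```python
from dataclasses import dataclass, fields
from typing import get_args, get_origin

# Assumed mapping from Python types to schema-style type names.
_SCALAR_TYPES = {int: "int", float: "float", bool: "bool", str: "str"}


def schema_type(tp) -> str:
    """Map an annotation such as int or list[int] to a schema-style type name."""
    if get_origin(tp) is list:
        (elem,) = get_args(tp)
        return f"{_SCALAR_TYPES[elem]}[]"
    return _SCALAR_TYPES[tp]


def config_schema(cls) -> str:
    """Generate the argument-list fragment of a schema from a dataclass."""
    return ", ".join(f"{schema_type(f.type)} {f.name}" for f in fields(cls))


def splat(cfg) -> tuple:
    """Flatten a config into positional arguments, in field declaration order."""
    return tuple(getattr(cfg, f.name) for f in fields(cfg))


@dataclass
class LayerConfig:
    """Hypothetical per-layer config with scalar and list members."""
    bits: int
    scale: float
    per_channel: bool
    skip_channels: list[int]
```

Here `config_schema(LayerConfig)` yields `'int bits, float scale, bool per_channel, int[] skip_channels'`, and the op implementation can rebuild the object with `LayerConfig(*args)`. Adding a field then only changes the dataclass, not a handwritten schema.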

Thanks!