[RFC] New Python operator registration API

Hi,
The general approach sounds great, and is simpler than torch.library :slight_smile:

There are three things I find missing in this API:

  1. (small) Automatic schema deduction from type annotations - in many cases it's easy and reduces the
    effort, and it's not hard to verify the deduced schema against the string-based one (see the first
    sketch after this list).

  2. (Maybe out of scope?) Having an option for “native” autograd would be very useful.
    In my use case I have operators that dynamo cannot handle and that cause graph breaks.
    I want torch.compile to generate a full graph for the forward pass (to get accelerated inference using
    a custom compiler), but I do not care about the backward speed and am willing to fall back to eager mode
    for it. I really don't want to manually write the backward for these operators. Having
    something like:

    @staticmethod
    @autograd
    def eager(x):
        if x.mean() > 0.5:
            return torch.sin(x)
        else:
            return torch.cos(x)
    

    would enable me to use torch.export/torch.compile to get a full graph, pass it to my compiler, and get
    fast inference code, while still running normal eager mode with autograd support for training (the
    second sketch after this list shows the kind of graph break I hit today).

  3. (Not sure) Allowing a “schema”/configuration object to be passed to the operator. Many operators I use
    have complex schemas (e.g. a quantization schema with multiple config options that is selected per
    layer in the network to give the best accuracy), and I want a way to pass some configuration object.
    The two options I see now are to serialize it to a string and deserialize it inside the op
    (performance? not sure), or to dynamically create a custom op class per configuration (tons of custom
    ops - almost one per layer). The third sketch after this list shows what the string option would look
    like.
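
To make point 1 concrete, here is a rough sketch of the kind of deduction I mean. The deduce_schema helper
and the type mapping are only my own illustration, not something I expect the API to expose in exactly this
form:

    import inspect
    from typing import Optional
    from torch import Tensor

    # Illustrative mapping from Python annotations to schema type strings.
    TYPE_TO_SCHEMA = {
        Tensor: "Tensor",
        Optional[Tensor]: "Tensor?",
        int: "int",
        float: "float",
        bool: "bool",
    }

    def deduce_schema(fn, op_name):
        # Build a schema string like "my_op(Tensor x, float eps) -> Tensor"
        # from the signature, so it can also be checked against a hand-written
        # schema string when one is provided.
        sig = inspect.signature(fn)
        args = ", ".join(
            f"{TYPE_TO_SCHEMA[p.annotation]} {name}"
            for name, p in sig.parameters.items()
        )
        ret = TYPE_TO_SCHEMA[sig.return_annotation]
        return f"{op_name}({args}) -> {ret}"

    def my_op(x: Tensor, eps: float) -> Tensor:
        return x * eps

    # deduce_schema(my_op, "my_op") == "my_op(Tensor x, float eps) -> Tensor"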

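For point 2, this is the flavor of thing that breaks for me today (data-dependent control flow); with
fullgraph=True the graph break turns into an error instead of a silent fallback:

    import torch

    def eager(x):
        # Data-dependent branch: dynamo cannot put this into one graph.
        if x.mean() > 0.5:
            return torch.sin(x)
        return torch.cos(x)

    compiled = torch.compile(eager, fullgraph=True)
    try:
        compiled(torch.randn(8))
    except Exception as e:
        print("full-graph capture failed:", type(e).__name__)

Registering such a function as a custom op keeps it opaque in the forward graph; the missing piece for me is
not having to hand-write its backward.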
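
For point 3, the serialize-to-string option would look roughly like this (the config fields and the op body
are made up, it is only meant to show the round trip and where the per-call deserialization cost sits):

    import json
    from dataclasses import dataclass, asdict
    import torch

    # Made-up per-layer quantization config.
    @dataclass
    class QuantConfig:
        num_bits: int = 8
        symmetric: bool = True
        group_size: int = 64

    def quantize(x: torch.Tensor, config_json: str) -> torch.Tensor:
        # The config has to be re-parsed from the string on every call,
        # which is the overhead I am not sure about.
        cfg = QuantConfig(**json.loads(config_json))
        scale = x.abs().amax() / (2 ** (cfg.num_bits - 1) - 1)
        return torch.round(x / scale) * scale

    cfg_str = json.dumps(asdict(QuantConfig(num_bits=4)))
    y = quantize(torch.randn(8, 8), cfg_str)

Being able to pass the config object (or at least an opaque handle to it) directly would avoid both the
per-call parsing and the per-layer custom op classes.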