Hi,
The general approach sounds great, and is simpler than torch.library.
There are three things I am missing in this API:
- (small) Automatic schema deduction from type annotations - in many cases it’s easy and reduces the effort, and it’s not hard to verify the deduced schema against the string-like schema (the first sketch after this list illustrates what I mean).
- (Maybe out of scope?) Having an option for “native” autograd would be very useful. In my use case I have operators that dynamo cannot handle and that cause graph breaks. I want torch.compile to generate a full graph for the forward pass (to have accelerated inference using a custom compiler), but I do not care about the backward speed and am willing to fall back to eager mode for it. I really don’t want to manually write the backward for these operators (the second sketch after this list shows what that looks like today). Having something like:

  ```python
  @staticmethod
  @autograd
  def eager(x):
      if x.mean() > 0.5:
          return torch.sin(x)
      else:
          return torch.cos(x)
  ```

  would enable me to use torch.export/torch.compile to get a full graph, pass it to my compiler, and get fast inference code, while running normal eager mode with autograd support for training.
- (Not sure) Allowing a “schema” (configuration) to be passed to the operator. Many operators I use have complex schemas (e.g. a quantization schema with multiple config options that is selected per layer in the network to give the best accuracy). I want to have a way to pass some configuration object. The two options I see now are to serialize it to a string and deserialize it from the string inside the op (performance? not sure - the third sketch after this list shows what I mean), or to dynamically create a custom op class per configuration (tons of custom ops - almost one per layer).
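
On the first point, a rough sketch of the deduction I have in mind (`deduce_schema` and the annotation-to-schema mapping below are made up for illustration, not part of the proposal):

```python
import inspect

import torch
from torch import Tensor

# Hypothetical mapping from Python annotations to schema type names.
_ANNOTATION_TO_SCHEMA = {Tensor: "Tensor", int: "int", float: "float", bool: "bool", str: "str"}

def deduce_schema(fn) -> str:
    # Build a schema string purely from the function's type annotations.
    sig = inspect.signature(fn)
    args = ", ".join(
        f"{_ANNOTATION_TO_SCHEMA[p.annotation]} {name}"
        for name, p in sig.parameters.items()
    )
    ret = _ANNOTATION_TO_SCHEMA[sig.return_annotation]
    return f"({args}) -> {ret}"

def my_op(x: Tensor, scale: float) -> Tensor:
    return x * scale

assert deduce_schema(my_op) == "(Tensor x, float scale) -> Tensor"
```

The same deduced string could also just be compared against a user-provided schema as a sanity check.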
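On the second point, this is roughly what I end up writing today for each such operator: a plain torch.autograd.Function with a hand-derived backward, which is exactly the part the `@autograd` fallback above would let me skip:

```python
import torch

class SinOrCos(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Data-dependent control flow like this is what breaks the graph.
        use_sin = bool(x.mean() > 0.5)
        ctx.save_for_backward(x)
        ctx.use_sin = use_sin
        return torch.sin(x) if use_sin else torch.cos(x)

    @staticmethod
    def backward(ctx, grad_out):
        # Hand-derived gradient: d/dx sin(x) = cos(x), d/dx cos(x) = -sin(x).
        (x,) = ctx.saved_tensors
        local_grad = torch.cos(x) if ctx.use_sin else -torch.sin(x)
        return grad_out * local_grad

x = torch.rand(4, requires_grad=True)
y = SinOrCos.apply(x)
y.sum().backward()
```

For a trivial op this is fine, but for my real operators deriving and maintaining the backward by hand is the expensive part, and eager-mode autograd over the `eager` function would already be correct.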
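On the third point, the string workaround I mentioned would look roughly like this (`QuantConfig`, `encode` and `decode` are made-up names; my concern is the per-call round trip, and whether a plain string argument in the op schema is really the intended way to carry this):

```python
import json
from dataclasses import asdict, dataclass

# Hypothetical per-layer quantization config.
@dataclass
class QuantConfig:
    bits: int = 8
    symmetric: bool = True
    group_size: int = 64

def encode(cfg: QuantConfig) -> str:
    # Caller side: flatten the config into a string the schema can accept.
    return json.dumps(asdict(cfg))

def decode(s: str) -> QuantConfig:
    # Op side: rebuild the config object on every call.
    return QuantConfig(**json.loads(s))

cfg_str = encode(QuantConfig(bits=4, group_size=128))
cfg = decode(cfg_str)
```

The alternative, registering a separate custom op per configuration, avoids the round trip but explodes the number of ops to almost one per layer.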