A couple of things:
Computing the result at runtime shouldn't be too much of a limitation. It's not too different from fusing aten::to nodes, which you can see how to add here. I think the conservative aliasing should be fine; it's not really any different from aten::to.
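To make the comparison with aten::to concrete, here is a minimal sketch in plain Python. Everything in it is a hypothetical mock: `_AUTOCAST_ENABLED`, the `autocast` context manager, `Tensor` and `autocast_to` stand in for the real AMP state and graph op and are not the torch API. The op decides whether to cast when it executes rather than when the graph is compiled. Like `Tensor.to`, it may return its input unchanged, so its output has to be treated as possibly aliasing the input. That is the same conservative aliasing aten::to already gets.

```python
from contextlib import contextmanager

# Hypothetical stand-in for the global AMP flag (not torch's real API).
_AUTOCAST_ENABLED = False


@contextmanager
def autocast(enabled=True):
    """Mock autocast: set the global flag for the block, then restore it."""
    global _AUTOCAST_ENABLED
    prev = _AUTOCAST_ENABLED
    _AUTOCAST_ENABLED = enabled
    try:
        yield
    finally:
        _AUTOCAST_ENABLED = prev


class Tensor:
    """Tiny mock tensor that only tracks a dtype name."""

    def __init__(self, dtype):
        self.dtype = dtype

    def to(self, dtype):
        # Like aten::to: return self (an alias) when no conversion is needed.
        return self if dtype == self.dtype else Tensor(dtype)


def autocast_to(t, amp_dtype="float16"):
    """Hypothetical graph op: the cast decision is made at run time."""
    return t.to(amp_dtype) if _AUTOCAST_ENABLED else t
```

Outside `autocast()` the op returns its input as-is, and inside it returns a new half-precision tensor. Because the output may or may not be the input, alias analysis has to assume "may alias".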
As Tom said, the tricky thing in my mind is how we handle (or don't handle) the global AMP state being set, as in:
with autocast():
    my_jitted_model()
my_jitted_model()
The results of the first model invocation would be invalid for the second run, since autocast is no longer enabled by then. We should think about how to design this so that it works, and is performant, both when autocast is enabled within the scripted model and when it is enabled outside of it. As far as I understand, it's more common for autocast to be enabled outside of the model code, right?
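One possible design for the "enabled outside the model" case is to guard on the autocast state at call time and keep one compiled graph per observed state. The sketch below is a plain-Python mock of that idea and is not a description of how the JIT works today. `_AUTOCAST_ENABLED`, `autocast`, `JittedModel` and `compile_model` are all hypothetical names.

```python
from contextlib import contextmanager

# Hypothetical stand-in for the global AMP flag (not torch's real API).
_AUTOCAST_ENABLED = False


@contextmanager
def autocast(enabled=True):
    """Mock autocast: set the global flag for the block, then restore it."""
    global _AUTOCAST_ENABLED
    prev = _AUTOCAST_ENABLED
    _AUTOCAST_ENABLED = enabled
    try:
        yield
    finally:
        _AUTOCAST_ENABLED = prev


class JittedModel:
    """Hypothetical wrapper: one compiled graph per observed autocast state."""

    def __init__(self, compile_fn):
        self._compile_fn = compile_fn
        self._cache = {}
        self.compile_count = 0

    def __call__(self, *args):
        key = _AUTOCAST_ENABLED  # guard on the global state at call time
        graph = self._cache.get(key)
        if graph is None:
            # First call under this state: specialize and remember the graph.
            graph = self._compile_fn(autocast_enabled=key)
            self._cache[key] = graph
            self.compile_count += 1
        return graph(*args)


def compile_model(autocast_enabled):
    """Stand-in 'compiler': the returned graph bakes in the autocast choice."""
    dtype = "float16" if autocast_enabled else "float32"
    return lambda x: (x, dtype)


my_jitted_model = JittedModel(compile_model)
with autocast():
    my_jitted_model(1)  # compiles and runs the autocast specialization
my_jitted_model(1)      # compiles and runs the non-autocast specialization
```

With this approach, the per-call overhead is a single state check and dictionary lookup, and each graph can still be fully optimized for its state. Autocast enabled inside the scripted model would not need the guard, because it would be part of the graph itself.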