As part of adding support for autocast + scripting (JIT scripting & Autocast), we need to implement a special “promote” policy: cast all the input tensors to the widest type (this is limited to fp16/fp32 types).
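To make the intended semantics concrete, here is a minimal eager-mode sketch of the promote policy (the helper name autocast_promote and the exact eligibility check are hypothetical illustrations, not an existing API):

import torch

def is_eligible(t):
    # only fp16/fp32 CUDA tensors participate in the promotion
    return t.is_cuda and t.dtype in (torch.float16, torch.float32)

def autocast_promote(*tensors):
    # if any eligible input is fp32, promote the eligible inputs to fp32,
    # otherwise cast the eligible inputs to fp16
    widest = torch.float16
    for t in tensors:
        if is_eligible(t) and t.dtype == torch.float32:
            widest = torch.float32
    return tuple(t.to(widest) if is_eligible(t) else t for t in tensors)

For example, mixing an fp16 and an fp32 tensor would yield two fp32 tensors, while two fp16 tensors would pass through unchanged.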
Unlike a regular cast, which maps a single value to another value, this promote operation needs to inspect a variable number of inputs. One option would be an operation that takes a TensorList input and returns another TensorList. Unfortunately, the resulting IR would be rather messy, since we’d need to pack the inputs into a list and then index into it to extract each value.
A cleaner alternative could be a specialized built-in operation, for example:
%a.2 : Tensor, %b.2 : Tensor = prim::autocast_promote(%a, %b)
This would be a truly variadic operation and would model the intention directly in the IR. Here is a complete, hypothetical illustration:
Before Autocast:
graph(%a.1 : Tensor,
      %b.1 : Tensor,
      %c.1 : Tensor,
      %d.1 : Tensor):
  %4 : bool = prim::Constant[value=1]()
  %5 : float = prim::Constant[value=0.1]()
  %6 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::CreateObject()
   = prim::SetAttr[name="_enabled"](%6, %4)
  %8 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::Enter(%6)
  %e.1 : Tensor = aten::mm(%a.1, %b.1)
  %f.1 : Tensor = aten::addcmul(%e.1, %c.1, %d.1, %5)
  %11 : Tensor = prim::Exit(%6)
  %12 : (Tensor, Tensor) = prim::TupleConstruct(%e.1, %f.1)
  return (%12)
After Autocast:
graph(%a.1 : Tensor,
      %b.1 : Tensor,
      %c.1 : Tensor,
      %d.1 : Tensor):
  %4 : bool = prim::Constant[value=1]()
  %5 : float = prim::Constant[value=0.1]()
  %6 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::CreateObject()
   = prim::SetAttr[name="_enabled"](%6, %4)
  %8 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::Enter(%6)
  %14 : Tensor = aten::autocast_to_fp16(%b.1)
  %15 : Tensor = aten::autocast_to_fp16(%a.1)
  %e.1 : Tensor = aten::mm(%15, %14)
  # this is how prim::autocast_promote might look in the IR
  %16 : Tensor, %17 : Tensor, %18 : Tensor = prim::autocast_promote(%d.1, %c.1, %e.1)
  %f.1 : Tensor = aten::addcmul(%18, %17, %16, %5)
  %11 : Tensor = prim::Exit(%6)
  %12 : (Tensor, Tensor) = prim::TupleConstruct(%e.1, %f.1)
  return (%12)
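For reference, the graphs above would roughly correspond to scripting a function like the following (a hypothetical reconstruction; supporting autocast inside torch.jit.script is exactly what this work is enabling):

import torch
from torch.cuda.amp import autocast

def fn(a, b, c, d):
    # the with-block corresponds to the prim::Enter / prim::Exit pair in the IR
    with autocast():
        e = torch.mm(a, b)                     # “cast” policy: both inputs go to fp16
        f = torch.addcmul(e, c, d, value=0.1)  # “promote” policy: inputs go to the widest type
    return e, f

# once with-autocast is scriptable, torch.jit.script(fn).graph should look
# roughly like the “Before Autocast” graph above (before the autocast pass runs)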
I’d love to get feedback from the TorchScript experts: does such an operation make sense? Are there other alternatives to consider? Would something like prim::TupleUnpack be relevant here? (My impression is that the answer is “no”, since tuples are not really variadic.)