Autocast (aka Automatic Mixed Precision) is an optimization that helps take advantage of the storage and performance benefits of narrow types (float16) while preserving the additional range and numerical precision of float32.
Currently autocast is only supported in eager mode, but there’s interest in supporting autocast in TorchScript. The current autocast interface presents a few challenges for the JIT path, and I’d like to outline some of the pain points here and ask for feedback and guidance.
For reference, here’s a quick and minimal example of autocast usage:
import torch
from torch.cuda.amp import autocast

@torch.jit.script
def func(a, b):
    with autocast():
        return torch.mm(a, b)

a_float32 = torch.rand((8, 8), dtype=torch.float32, device="cuda")
b_float32 = torch.rand((8, 8), dtype=torch.float32, device="cuda")
result = func(a_float32, b_float32)
print(result.dtype)  # expecting torch.float16
Thanks to the recent support for context managers (kudos to @SplitInfinity), plus a few small tweaks, we can model this code in TorchScript:
graph(%a.1 : Tensor,
      %b.1 : Tensor):
  %2 : bool = prim::Constant[value=1]()
  %3 : __torch__.torch.cuda.amp.autocast = prim::CreateObject()
   = prim::SetAttr[name="_enabled"](%3, %2)
  %5 : None = prim::Enter(%3)
  %6 : Tensor = aten::mm(%a.1, %b.1)
  %7 : Tensor = prim::Exit(%3)
  return (%6)
We have prim::Enter & prim::Exit defining the lexical scope for autocast. If we had the concrete dtypes for all the tensors, it would be straightforward to inject casts and (almost) call it a day.
  ...
  %5 : None = prim::Enter(%3)
  %10 : int = prim::Constant[value=5]()
  %11 : bool = prim::Constant[value=0]()
  %12 : None = prim::Constant()
  %13 : Tensor = aten::to(%b.1, %10, %11, %11, %12)
  %14 : Tensor = aten::to(%a.1, %10, %11, %11, %12)
  %6 : Tensor = aten::mm(%14, %13)
  %7 : Tensor = prim::Exit(%3)
  ...
The first challenge is that we start without the dtypes, and while the autocast logic is straightforward (just a handful of policies), it needs the concrete types even for the basic policies. For example, CastPolicy::fp16 is defined to cast the output to float16, but only if the input is float32.
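To make the dependency on concrete types explicit, here's a rough Python sketch of that rule (the real logic lives in the eager-mode autocast implementation; the helper name fp16_cast_policy is made up for illustration):

import torch

def fp16_cast_policy(t: torch.Tensor) -> torch.Tensor:
    # Lower to float16 only when the input is a float32 CUDA tensor;
    # everything else passes through unchanged.
    if t.is_cuda and t.dtype == torch.float32:
        return t.to(torch.float16)
    return t

Without knowing the input dtype, the condition simply cannot be evaluated, which is exactly the situation in the IR above.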
Also, we have "non-local" policies that depend on multiple input types: CastPolicy::promote defines a cast to the widest type(*) among the input types (which obviously must be known).
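Again as a sketch only (promote_inputs is a hypothetical helper, not the actual implementation), the promote rule looks roughly like this, and it clearly needs every input dtype up front:

import torch

def promote_inputs(*tensors: torch.Tensor):
    # Find the widest dtype among all inputs ...
    widest = tensors[0].dtype
    for t in tensors[1:]:
        widest = torch.promote_types(widest, t.dtype)
    # ... then cast every input up to it.
    return tuple(t.to(widest) for t in tensors)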
Profiling provides concrete types, but then we would also need to implement the autocast logic while profiling, and mutating the graph while interpreting and profiling it doesn't seem like a very attractive option (even if it were possible).
One alternative is to define a set of specialized “smart” casts:
  ...
  %13 : Tensor = aten::autocast_to_fp16(%b.1, %10, %11, %11, %12)
  %14 : Tensor = aten::autocast_to_fp16(%a.1, %10, %11, %11, %12)
  %6 : Tensor = aten::mm(%14, %13)
  ...
The new cast operations would implement the autocast logic themselves, so they wouldn't need concrete types in the IR. This means they could be injected statically, based on the lexical autocast scopes (and they would be invariant to profiling).
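To make the idea concrete, here's a rough Python model of the runtime semantics such a "smart" cast might have (the actual operator would be implemented in C++; the simplified signature below is just for illustration):

import torch

def autocast_to_fp16(t: torch.Tensor) -> torch.Tensor:
    # The dtype check happens at runtime, inside the op itself, so the node
    # can be injected into the graph without knowing the input dtype statically.
    if t.is_floating_point() and t.dtype == torch.float32:
        return t.to(torch.float16)
    return t

Since the decision is deferred to runtime, the injected nodes stay valid regardless of which dtypes profiling later observes.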
Thoughts? Any other ideas?