JIT scripting & Autocast

Autocast (aka Automatic Mixed Precision) is an optimization that helps take advantage of the storage and performance benefits of narrow types (float16) while preserving the additional range and numerical precision of float32.

Currently autocast is only supported in eager mode, but there’s interest in supporting autocast in TorchScript. The current autocast interface presents a few challenges for the JIT path, and I’d like to outline some of the pain points here and ask for feedback and guidance.

For reference, here’s a quick and minimal example of autocast usage:

import torch
from torch.cuda.amp import autocast

@torch.jit.script
def func(a, b):
    with autocast():
        return torch.mm(a, b)

a_float32 = torch.rand((8, 8), dtype=torch.float32, device="cuda")
b_float32 = torch.rand((8, 8), dtype=torch.float32, device="cuda")
result = func(a_float32, b_float32)
print(result.dtype) # expecting torch.float16

Thanks to the recent support for context managers (kudos to @SplitInfinity) plus a few small tweaks, we can model this code in TorchScript:

graph(%a.1 : Tensor,
      %b.1 : Tensor):
  %2 : bool = prim::Constant[value=1]()
  %3 : __torch__.torch.cuda.amp.autocast = prim::CreateObject()
   = prim::SetAttr[name="_enabled"](%3, %2)
  %5 : None = prim::Enter(%3)
  %6 : Tensor = aten::mm(%a.1, %b.1)
  %7 : Tensor = prim::Exit(%3)
  return (%6)

We have prim::Enter & prim::Exit defining the lexical scope for autocast. If we had the concrete dtypes for all the tensors, it would be straightforward to inject casts and (almost) call it a day.

...
%5 : None = prim::Enter(%3)
%10 : int = prim::Constant[value=5]()
%11 : bool = prim::Constant[value=0]()
%12 : None = prim::Constant()
%13 : Tensor = aten::to(%b.1, %10, %11, %11, %12)
%14 : Tensor = aten::to(%a.1, %10, %11, %11, %12)
%6 : Tensor = aten::mm(%14, %13)
%7 : Tensor = prim::Exit(%3)
...

The first challenge is that we start without the dtypes. While the autocast logic itself is straightforward (just a handful of policies), even the basic policies need the concrete types. For example, CastPolicy::fp16 casts the op's inputs to float16, but only the inputs that are float32.
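To make the dependency on concrete types explicit, here is a toy model of an fp16-style rule (narrow to float16 only when the dtype is float32). Dtypes are plain strings and `fp16_policy_dtype` is a hypothetical helper, not PyTorch code; it only shows that the decision is a function of the input dtype, which the IR does not have before profiling.

```python
# Toy model: dtypes are plain strings; this is NOT PyTorch's implementation.
FLOAT32 = "float32"
FLOAT16 = "float16"


def fp16_policy_dtype(input_dtype: str) -> str:
    """Dtype an input ends up with under a CastPolicy::fp16-like rule.

    Only float32 inputs are narrowed; any other dtype is left alone, so the
    rule cannot be applied without knowing the input dtype.
    """
    return FLOAT16 if input_dtype == FLOAT32 else input_dtype
```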

We also have “non-local” policies that depend on multiple input types. CastPolicy::promote casts to the widest type(*) among the input types (which obviously must be known).
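A similarly hedged toy model of a promote-style policy, simplified to float16/float32 only (other dtypes are simply ignored here). `promote_policy_dtype` is hypothetical; the point is that the result depends on all inputs at once, so no single input's dtype decides the cast.

```python
from typing import Iterable

# Simplified ranking: only float16 and float32 take part in promotion here.
_FLOAT_RANK = {"float16": 0, "float32": 1}


def promote_policy_dtype(input_dtypes: Iterable[str]) -> str:
    """Return the widest participating floating-point dtype among the inputs."""
    floats = [d for d in input_dtypes if d in _FLOAT_RANK]
    if not floats:
        raise ValueError("promote policy needs at least one float16/float32 input")
    # max() over the rank picks the widest type seen across *all* inputs.
    return max(floats, key=_FLOAT_RANK.__getitem__)
```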

Profiling provides concrete types, but then we would also need to implement the autocast logic while profiling. Mutating the graph while interpreting and profiling it doesn’t seem like a very attractive option (even if it were possible).

One alternative is to define a set of specialized “smart” casts:

...
%13 : Tensor = aten::autocast_to_fp16(%b.1, %10, %11, %11, %12)
%14 : Tensor = aten::autocast_to_fp16(%a.1, %10, %11, %11, %12)
%6 : Tensor = aten::mm(%14, %13)
...

The new cast operations would implement the autocast logic themselves, so they don’t need concrete types in the IR. This means they can be injected statically based on the lexical autocast scopes (and they would be invariant to profiling).
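A minimal sketch of the static injection, over a toy list-of-nodes IR rather than the real TorchScript graph API. It assumes every prim::Enter/prim::Exit belongs to an autocast scope, ignores nested autocast(enabled=False), and uses a hypothetical `FP16_OPS` set; `Node` and `inject_autocasts` are illustrative names. Note that it never looks at dtypes.

```python
from dataclasses import dataclass, field
from typing import List

# Hypothetical set of ops covered by an fp16-style policy.
FP16_OPS = {"aten::mm"}


@dataclass
class Node:
    kind: str                 # e.g. "aten::mm", "prim::Enter"
    inputs: List[str]         # names of consumed values
    outputs: List[str] = field(default_factory=list)


def inject_autocasts(nodes: List[Node]) -> List[Node]:
    """Insert aten::autocast_to_fp16 (hypothetical op) before FP16_OPS inside scopes.

    Only the lexical Enter/Exit structure is consulted, so this can run
    before any profiling information exists.
    """
    result: List[Node] = []
    depth = 0      # number of enclosing autocast scopes
    counter = 0    # for fresh value names
    for node in nodes:
        if node.kind == "prim::Enter":
            depth += 1
        elif node.kind == "prim::Exit":
            depth -= 1
        elif depth > 0 and node.kind in FP16_OPS:
            new_inputs = []
            for value in node.inputs:
                counter += 1
                casted = f"%autocast{counter}"
                result.append(Node("aten::autocast_to_fp16", [value], [casted]))
                new_inputs.append(casted)
            node = Node(node.kind, new_inputs, node.outputs)
        result.append(node)
    return result
```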

Thoughts? Any other ideas?

Cool! Yea, I agree that AOT graph manipulation seems like the way to go, as you’ve suggested. Embedding the actual conversion nodes in the IR makes it possible to fuse the generated autocasting nodes with the computation. I’m not sure how often AMP is used in inference, but embedding the AMP nodes in the IR also makes it possible to convert weights to the correct AMP dtype ahead of time instead of at runtime, as was recently suggested to me: https://github.com/pytorch/pytorch/pull/50222#issuecomment-775478422.

Thanks! Yes, one of the goals is to make it easy to optimize and fuse. The autocast nodes should allow both IR and fuser level optimizations, although there are some potential complications:

  1. The “smart” cast operations compute the result type at runtime, which is not ideal for an optimizer. Technically, this is no worse than having incomplete types, since with concrete types the autocast logic can be incorporated into the optimizer (practically this is not great either, but at least it’s doable).

  2. Since the cast nodes may or may not cast depending on the input type, we’d need conservative aliasing annotations (i.e., assume that casts always alias).

I think these are real limitations, but so far I don’t have any better answers, short of coming up with a JIT autocast interface different from eager mode, which seems even less desirable.
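Both points can be seen in a toy version of the runtime “smart” cast. `FakeTensor` and `autocast_to_fp16` here are hypothetical stand-ins: a float32 input yields a new float16 tensor, while any other input is returned as-is (similar to how `Tensor.to` returns `self` when no conversion is needed). Which case happens is only known at runtime, so alias analysis must assume the output may alias the input.

```python
from dataclasses import dataclass


@dataclass(eq=False)  # identity semantics, like real tensors
class FakeTensor:
    """Stand-in for a tensor; only the dtype matters here (hypothetical type)."""
    dtype: str


def autocast_to_fp16(t: FakeTensor) -> FakeTensor:
    """Runtime 'smart' cast (hypothetical op from the proposal).

    float32 -> a fresh float16 tensor (no aliasing);
    anything else -> the very same object (the output aliases the input).
    """
    if t.dtype == "float32":
        return FakeTensor("float16")
    return t
```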

The comment you linked to seems to refer to the current state where AMP mixes poorly with scripting: the dispatch key mechanism used by eager-mode autocast is invisible to the profiler, which leads to incorrect dtypes in the profiled graph. So this sounds like an AMP bug, although my understanding is that this is not a supported scenario (until we add proper support for AMP + JIT).

Cool things!

If the initial pattern contains the cast operations, you can then profile and guard an optimized (i.e., static) version of the autocast. In fact, if the fusers included the autocast nodes in their fusion groups, this would happen automatically.
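A rough sketch of that idea on dtype strings: after profiling, the dynamic cast decision is made once from the profiled dtype and protected by a guard; when the guard fails, the sketch falls back to the dynamic rule as a stand-in for deoptimization. `smart_cast_dtype` and `specialize_smart_cast` are hypothetical names, not JIT APIs.

```python
from typing import Callable, Optional, Tuple


def smart_cast_dtype(input_dtype: str) -> str:
    # Dynamic rule of the hypothetical aten::autocast_to_fp16.
    return "float16" if input_dtype == "float32" else input_dtype


def specialize_smart_cast(profiled_dtype: str) -> Callable[[str], Tuple[str, bool]]:
    """Build a guarded, statically decided version of the smart cast.

    The returned function yields (result_dtype, guard_passed).
    """
    # Decided once from profiling: either a fixed cast target or a no-op.
    static_target: Optional[str] = "float16" if profiled_dtype == "float32" else None

    def run(input_dtype: str) -> Tuple[str, bool]:
        if input_dtype != profiled_dtype:
            # Guard failed: fall back to the dynamic rule.
            return smart_cast_dtype(input_dtype), False
        # Guard passed: apply the precomputed decision, no further dtype logic.
        result = static_target if static_target is not None else input_dtype
        return result, True

    return run
```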

For the interop with eager mode, I think we eventually need to dispatch on this one way or another. As mentioned on Slack, I think we might have a good think about whether the issue of some global (well, or thread-local) state influencing the JIT graph is particular to autocast or whether it’ll show up more often.

Best regards

Thomas

Do you mean replacing the smart/dynamic casts with static casts after profiling? I like the idea - it would make a nice optimization!

Regarding interop with eager mode, I agree that’s an interesting topic in itself.

A couple things:

Computing the result at runtime shouldn’t be too much of a limitation. It’s not too different from fusing aten::to nodes, which you can see how to add here :slight_smile: I think the conservative aliasing should be fine; it’s not really any different from aten::to.

As Tom said, the tricky thing in my mind is how we handle (or don’t handle) the global state of AMP being set:

from torch.cuda.amp import autocast

with autocast():
    my_jitted_model()  # autocast enabled by the caller
my_jitted_model()      # autocast disabled

Whatever was profiled and specialized during the first invocation (with autocast enabled) would be invalid for the second invocation (with autocast disabled). We should think about how to design this so that it works, and is performant, both when autocast is enabled within the scripted model and when it is enabled outside of it. As far as I understand, it’s more common for autocast to be enabled outside of the model code, right?

Thanks @eellison for the pointers!

I was planning to save the eager / scripted / traced composability for a separate thread, but too late now :slight_smile: Yes, I’d expect to see code that wraps scripted code in an eager-mode autocast, especially if we want to make it easy to script existing code. I see two levels of support (excluding the broken behavior we see today):

Level 1: Only support autocast inside the script and check (at the executor level) that we’re not mixing eager and scripted code. This is limiting, but will not produce incorrect results and it’s easy to document / teach.

Level 2: Support mixing eager, traced and scripted code freely. To support this in a world where the casts are statically injected in the IR, the best idea so far is to specialize the graph: one version that assumes the initial state is autocast(true) and another that assumes autocast(false). The executor would dispatch to the right specialization depending on the current autocast state (this is what @tom was alluding to). Any other ideas?

Naturally, I’m all for level 2. :slight_smile: By the time we check, we might as well do the right thing.

Yea, I’m in favor of level 2. I don’t think we should overthink tracing here: if they trace the graph with autocast enabled, they bake in the type conversion nodes, and not much can be done about that.
As you said, specializing the graph makes sense to me. What other global state would it make sense to specialize on in the future? I guess there’s current_default_dtype and with no_grad, and that might be it? I think with no_grad works right now. current_default_dtype might have slight differences but probably isn’t worth focusing on; it’s never actually come up as a use case.

coming to the party late :wink:

Reading the context, I think we are leaning towards a static cast utilizing profiling information. A cast op would change tensor types and invalidate downstream profiling information, so we would still need to propagate tensor types. This seems to circle back to the original problem that profiling is trying to solve.

Yea, and we definitely only want to insert these cast nodes when we’re on GPU (and into a fusible node / fusion group). You’re right that we would need to propagate tensor types. I don’t think that would be very difficult here, since all tensor properties are preserved except for the dtype. Strides should follow the existing memory format, or, if the tensor is non-dense, they’ll be contiguous.
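A small sketch of that propagation rule on a toy profiled type. `TensorType`, `is_dense`, `contiguous_strides` and `propagate_cast` are hypothetical names; the density check ignores size-1 dimensions and does not special-case empty tensors. Everything but the dtype is copied; dense inputs keep their strides and non-dense inputs get contiguous strides.

```python
from dataclasses import dataclass, replace
from typing import Sequence, Tuple


@dataclass(frozen=True)
class TensorType:
    """Toy profiled tensor type (hypothetical)."""
    dtype: str
    sizes: Tuple[int, ...]
    strides: Tuple[int, ...]
    device: str


def contiguous_strides(sizes: Sequence[int]) -> Tuple[int, ...]:
    strides, acc = [], 1
    for s in reversed(sizes):
        strides.append(acc)
        acc *= max(s, 1)
    return tuple(reversed(strides))


def is_dense(sizes: Sequence[int], strides: Sequence[int]) -> bool:
    # Dense: the elements exactly fill one block in some dimension order.
    expected = 1
    for d in sorted(range(len(sizes)), key=lambda i: strides[i]):
        if sizes[d] == 1:
            continue  # size-1 dims do not affect the layout
        if strides[d] != expected:
            return False
        expected *= sizes[d]
    return True


def propagate_cast(t: TensorType, target_dtype: str) -> TensorType:
    """Output type of a cast: only dtype (and strides, if non-dense) change."""
    strides = t.strides if is_dense(t.sizes, t.strides) else contiguous_strides(t.sizes)
    return replace(t, dtype=target_dtype, strides=strides)
```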

@jjsjann123

Reading from the context I think we are leaning towards a static cast utilizing profiling information

Not quite. The proposal described in the original post is a mix of static and runtime parts:

  1. The location of the injected “auto-casts” is static, following the lexical scopes introduced by the torch.cuda.amp.autocast() context manager.

  2. The injected autocasts handle dtypes dynamically (at runtime). For example, the hypothetical aten::autocast_to_fp16 would cast float32 → float16 and would leave any other tensors untouched.

Part #1 requires the specialization discussed later in the thread in order to support mixing eager-mode and scripting.