prim::autocast_promote operation?

As part of adding support for autocast + scripting (JIT scripting & Autocast), we need to implement a special “promote” policy: cast all the input tensors to the widest input type (this is limited to fp16/fp32 types).
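
For reference, this is the eager-mode behavior we need to match. A minimal sketch, assuming a CUDA device is available, and using addcmul, which is one of the ops covered by the promote policy:

import torch

# Inputs with mixed dtypes: the widest type among them is fp32
a = torch.randn(4, 4, device="cuda", dtype=torch.float16)
b = torch.randn(4, 4, device="cuda", dtype=torch.float32)
c = torch.randn(4, 4, device="cuda", dtype=torch.float32)

with torch.cuda.amp.autocast():
    # addcmul follows the "promote" policy: all tensor inputs are cast
    # to the widest input type (fp32 here) before the op runs
    out = torch.addcmul(a, b, c, value=0.1)

print(out.dtype)  # torch.float32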

Unlike a regular cast, which maps a single value to another value, this promote operation needs to inspect a variable number of inputs. One option would be an operation which takes a TensorList input and returns another TensorList. Unfortunately, the resulting IR would be rather messy, since we’d need to index into the list to extract each value.
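
A small sketch of the indexing problem (the scripted list argument just stands in for the output of a hypothetical list-based promote op):

import torch
from typing import List

@torch.jit.script
def consume(xs: List[torch.Tensor]) -> torch.Tensor:
    # Every element use needs an explicit index, and each index shows up
    # as a separate aten::__getitem__ node in the graph
    return torch.addcmul(xs[0], xs[1], xs[2], value=0.1)

print(consume.graph)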

A cleaner alternative could be a specialized built-in operation, for example:

%a.2 : Tensor, %b.2 : Tensor = prim::autocast_promote(%a, %b)

This would be a truly variadic operation and would model the intention directly in the IR. Here’s a complete, hypothetical illustration:

Before Autocast: 
graph(%a.1 : Tensor,
      %b.1 : Tensor,
      %c.1 : Tensor,
      %d.1 : Tensor):
  %4 : bool = prim::Constant[value=1]()
  %5 : float = prim::Constant[value=0.1]()
  %6 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::CreateObject()
   = prim::SetAttr[name="_enabled"](%6, %4)
  %8 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::Enter(%6)
  %e.1 : Tensor = aten::mm(%a.1, %b.1)
  %f.1 : Tensor = aten::addcmul(%e.1, %c.1, %d.1, %5)
  %11 : Tensor = prim::Exit(%6)
  %12 : (Tensor, Tensor) = prim::TupleConstruct(%e.1, %f.1)
  return (%12)
After Autocast: 
graph(%a.1 : Tensor,
      %b.1 : Tensor,
      %c.1 : Tensor,
      %d.1 : Tensor):
  %4 : bool = prim::Constant[value=1]()
  %5 : float = prim::Constant[value=0.1]()
  %6 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::CreateObject()
   = prim::SetAttr[name="_enabled"](%6, %4)
  %8 : __torch__.torch.cuda.amp.autocast_mode.autocast = prim::Enter(%6)
  %14 : Tensor = aten::autocast_to_fp16(%b.1)
  %15 : Tensor = aten::autocast_to_fp16(%a.1)
  %e.1 : Tensor = aten::mm(%15, %14)

  # this is how prim::autocast_promote might look in the IR
  %16 : Tensor, %17 : Tensor, %18 : Tensor = prim::autocast_promote(%d.1, %c.1, %e.1)

  %f.1 : Tensor = aten::addcmul(%18, %17, %16, %5)
  %11 : Tensor = prim::Exit(%6)
  %12 : (Tensor, Tensor) = prim::TupleConstruct(%e.1, %f.1)
  return (%12)
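
For context, here’s a rough sketch of a scripted function that the “Before Autocast” graph above could come from (hypothetical source; the exact graph depends on the in-progress autocast + scripting support):

import torch
from typing import Tuple

@torch.jit.script
def fn(a: torch.Tensor, b: torch.Tensor,
       c: torch.Tensor, d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    # The `with` statement lowers to the prim::Enter / prim::Exit pair seen above
    with torch.cuda.amp.autocast(enabled=True):
        e = torch.mm(a, b)
        f = torch.addcmul(e, c, d, value=0.1)
    return e, f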

I’d love to get feedback from some of the TorchScript experts: does such an operation make sense? Any other alternatives to consider? Would something like prim::TupleUnpack be relevant here? (my impression is that the answer would be “no” since tuples are not really variadic)

A Tuple is not variadic once it’s created, but I feel that for autocast_promote we are not going to update the number of arguments/outputs once we’ve inserted the node, so that should be fine?

Having said that, I actually hate having a Tuple there. It makes profiling/aliasing analysis messy in the unfortunate PyTorch world. Is there any benefit to having a tuple as output compared to your prototype above? I think it’s quite common in the current JIT IR to have variadic operations (guard / requires_grad_check, etc.)


I agree. My current favorite is the hypothetical prim::autocast_promote described above: the IR is cleaner and optimizations / profiling / fusing should be easier as well.

What I’m looking for is either validation that this is a reasonable direction, or pointers to potential complications and alternatives I may be missing. I haven’t implemented prim::autocast_promote yet, and if I’m doing something completely stupid I’d rather learn that before spending the time to implement it.

Yeah, I also really don’t like the TupleUnpack. It’s annoying that you can’t register an operator which returns two unboxed values. I’m going to file an issue about that. Until that exists, I think we should do it variadically.

This looks good!


We are currently statically casting tensors to the widest type among all inputs: pytorch/autocast.cpp at master · pytorch/pytorch · GitHub

The issue is that the decision is based on the static type information embedded in the graph. The autocast pass inserts casting ops into the graph, which changes the tensor types seen by downstream nodes, and we don’t yet have a reliable type propagation pass to propagate the updated scalar types.
This means the static scalar types we see in the graph can be wrong after autocast, and the casting decision we make could diverge from eager mode.
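
To make the divergence concrete, here’s roughly what happens in eager mode for the example above (a sketch, assuming a CUDA device): the runtime dtype of an intermediate like e no longer matches a stale fp32 annotation in the graph.

import torch

a = torch.randn(4, 4, device="cuda")   # fp32
b = torch.randn(4, 4, device="cuda")   # fp32
c = torch.randn(4, 4, device="cuda")   # fp32
d = torch.randn(4, 4, device="cuda")   # fp32

with torch.cuda.amp.autocast():
    e = torch.mm(a, b)                     # autocast runs mm in fp16
    print(e.dtype)                         # torch.float16, not the fp32 the static graph type would claim
    f = torch.addcmul(e, c, d, value=0.1)  # promote decides based on the runtime dtypes
    print(f.dtype)                         # torch.float32 (widest among e, c, d)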

One more reason to implement prim::autocast_promote.
