NNC walkthrough: how PyTorch ops get fused

The first issue I’ve noticed is that there’s no support for inplace ops.
For example, this graph without inplace ops:

graph(%0 : Long(10000, 13000, strides=[13000, 1], requires_grad=0, device=cuda:0),
      %1 : Long(10000, 13000, strides=[13000, 1], requires_grad=0, device=cuda:0),
      %2 : Int(10000, 13000, strides=[13000, 1], requires_grad=0, device=cuda:0)):
  %11 : int = prim::Constant[value=4]()
  %8 : int = prim::Constant[value=1]()
  %6 : bool = prim::Constant[value=0]()
  %5 : Device = prim::Constant[value="cuda:0"]()
  %4 : NoneType = prim::Constant()
  %3 : int = prim::Constant[value=3]()
  %7 : Tensor = aten::_to_copy(%0, %3, %4, %5, %4, %6, %4)
  %9 : Tensor = aten::add(%7, %1, %8)
  %10 : Tensor = aten::mul(%9, %2)
  %14 : Tensor = aten::to(%10, %11, %6, %6, %4)
  return (%14)

gets fused into a single kernel:

[DEBUG cuda_codegen.cpp:1032] extern "C" __global__
[DEBUG cuda_codegen.cpp:1032] void fused_add_mul(int* tv_, int* tv__, long long* tv___, long long* aten_mul) {
[DEBUG cuda_codegen.cpp:1032] {
[DEBUG cuda_codegen.cpp:1032] if ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)<130000000ll ? 1 : 0) {
[DEBUG cuda_codegen.cpp:1032]     int tv__1 = __ldg(tv_ + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
[DEBUG cuda_codegen.cpp:1032]     long long v = (long long)(__ldg(tv__ + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)));
[DEBUG cuda_codegen.cpp:1032]     long long v_1 = __ldg(tv___ + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
[DEBUG cuda_codegen.cpp:1032]     aten_mul[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = (long long)(tv__1) * v + (long long)(tv__1) * v_1;
[DEBUG cuda_codegen.cpp:1032]   }}
[DEBUG cuda_codegen.cpp:1032] }

But if you change the mul to mul_, no fusion happens at all, even though changing add → add_ and mul → mul_ yields an equivalent program. (As a side note: why doesn't TorchScript rewrite non-inplace ops into inplace ones where that's safe?)
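For reference, here's a minimal sketch of the two variants I'm comparing (shapes and dtypes are made up and smaller than in the graph above; assumes a PyTorch build with the NNC/TensorExpr fuser). Scripted with torch.jit.script and run on CUDA, the out-of-place version produces the fused kernel; the inplace one stays unfused, even though both compute identical results:

```python
import torch

def out_of_place(a, b, c):
    # Mirrors the graph above: cast, add, mul, cast back.
    return ((a.to(torch.long) + b) * c).to(torch.int)

def in_place(a, b, c):
    t = a.to(torch.long)  # to() allocates, so the inputs stay untouched
    t.add_(b)
    t.mul_(c)             # aten::mul_ is what blocks fusion today
    return t.to(torch.int)

a = torch.randint(0, 100, (8, 8), dtype=torch.int32)
b = torch.randint(0, 100, (8, 8), dtype=torch.int64)
c = torch.randint(0, 100, (8, 8), dtype=torch.int64)

# Both formulations compute the same values.
assert torch.equal(out_of_place(a, b, c), in_place(a, b, c))
```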
I also wonder why it rewrites the expression (a + b) * c into (a * c) + (b * c); is that actually better? It trades one multiplication for two.
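On the correctness side at least, the rewrite is safe even under overflow: the kernel's long long arithmetic is arithmetic mod 2**64, and distributivity holds in modular arithmetic, so both forms wrap to the same values. A plain-Python check of that claim:

```python
# Two's-complement 64-bit multiply/add is arithmetic mod 2**64; since
# distributivity holds mod M, (a + b) * c and a*c + b*c agree even
# when the intermediate products overflow 64 bits.
M = 2 ** 64  # width of the kernel's `long long`

def wrap(x):
    return x % M

a, b, c = 2 ** 40, 3 ** 25, 7 ** 20  # chosen so a * c overflows 64 bits
assert a * c >= M                     # overflow really happens

lhs = wrap(wrap(a + b) * c)                # (a + b) * c, as in the source graph
rhs = wrap(wrap(a * c) + wrap(b * c))      # a*c + b*c, as in the fused kernel
assert lhs == rhs
```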

I’ve also noticed that similar programs sometimes don’t get fully fused: only the mul + to pair is fused, and the add is left out. I haven’t yet figured out why that happens.
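One way to dig into this is to dump the fuser's decisions alongside the codegen output. A sketch, assuming a PyTorch build with JIT logging compiled in (repro.py stands in for your script; the log names are source-file basenames under torch/csrc/jit, and ">>" raises the verbosity):

```shell
# Log which nodes the TensorExpr fuser merges into fusion groups,
# plus the CUDA codegen output (the [DEBUG cuda_codegen.cpp] lines above).
PYTORCH_JIT_LOG_LEVEL=">>tensorexpr_fuser:>>cuda_codegen" python repro.py
```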
