NNC walkthrough: how PyTorch ops get fused

In this note we’ll examine how a simple PyTorch program is transformed by the JIT all the way to an LLVM-compiled binary through NNC. My hope is that it will help you understand what each stage of the pipeline is doing and how you can examine it in more detail.
The test program we’ll be using is extremely simple:

$ cat test.py

import torch

def foo(a):
    b = torch.conv2d(a, torch.randn(1, 1, 1, 1)) # not fusible
    x = torch.mul(b, b)                          # fusible
    y = torch.sin(x)                             # fusible
    z = torch.mul(y, y)                          # fusible
    return z

torch._C._jit_override_can_fuse_on_cpu(True)

a = torch.randn(1, 1, 128, 128)

scripted = torch.jit.script(foo)

# do several runs:
for _ in range(10):
    scripted(a)

To begin with, let’s see what’s getting fused. To do that, we will enable logging in the TE fuser pass:

$ PYTORCH_JIT_LOG_LEVEL="tensorexpr_fuser.cpp" python test.py

First, we see the original JIT graph before the fuser got its hands on it:

Before TExprFuser:
graph(%a.1 : Tensor):
  %5 : int = prim::Constant[value=1]()
  %4 : None = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %1 : int[] = prim::Constant[value=[1, 1, 1, 1]]()
  %6 : Tensor = aten::randn(%1, %4, %4, %4, %4)
  %b.1 : Float(1, 1, 128, 128) = aten::conv2d(%a.1, %6, %4, %3, %2, %3, %5)
  %x.1 : Float(1, 1, 128, 128) = aten::mul(%b.1, %b.1)
  %y.1 : Float(1, 1, 128, 128) = aten::sin(%x.1)
  %z.1 : Float(1, 1, 128, 128) = aten::mul(%y.1, %y.1)
  return (%z.1)

If we scroll further down the dump, we’ll see how a fusion group was formed. Note that non-fusible ops, like aten::conv2d, are not in it:

After creating fusion groups:
graph(%a.1 : Tensor):
  %5 : int = prim::Constant[value=1]()
  %4 : None = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %1 : int[] = prim::Constant[value=[1, 1, 1, 1]]()
  %6 : Tensor = aten::randn(%1, %4, %4, %4, %4)
  %b.1 : Float(1, 1, 128, 128) = aten::conv2d(%a.1, %6, %4, %3, %2, %3, %5)
  %z.2 : Float(1, 1, 128, 128) = prim::TensorExprGroup_0(%b.1)
  return (%z.2)
with prim::TensorExprGroup_0 = graph(%4 : Float(1, 1, 128, 128)):
  %x.2 : Float(1, 1, 128, 128) = aten::mul(%4, %4)
  %y.2 : Float(1, 1, 128, 128) = aten::sin(%x.2)
  %z.2 : Float(1, 1, 128, 128) = aten::mul(%y.2, %y.2)
  return (%z.2)

Fusion groups are only legal to run when the input shapes exactly match those seen during the profiling runs (those expected shapes were encoded in the JIT IR before the fuser pass). Hence, we need to add runtime checks that dispatch to a fallback path - a non-optimized version of the fused subgraph - whenever the shapes don’t match.
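Conceptually, the guard behaves like the following sketch (all names here are hypothetical; the real mechanism is the prim::TypeCheck/prim::If pair visible in the dump below):

/* A minimal sketch of the guard's runtime dispatch; all names are hypothetical. */
typedef struct Tensor Tensor;

int     type_check(const Tensor *b);    /* does b match the profiled dtype/shape/strides/device? */
Tensor *run_nnc_kernel(Tensor *b);      /* the compiled prim::TensorExprGroup_0 */
Tensor *run_fallback_graph(Tensor *b);  /* prim::FallbackGraph_1: the original unfused ops */

Tensor *guarded_run(Tensor *b) {
  return type_check(b) ? run_nnc_kernel(b) : run_fallback_graph(b);
}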

After guarding fusion groups:
graph(%a.1 : Tensor):
  %5 : int = prim::Constant[value=1]()
  %4 : None = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %1 : int[] = prim::Constant[value=[1, 1, 1, 1]]()
  %6 : Tensor = aten::randn(%1, %4, %4, %4, %4)
  %b.1 : Float(1, 1, 128, 128) = aten::conv2d(%a.1, %6, %4, %3, %2, %3, %5)
  %25 : Float(1, 1, 128, 128), %26 : bool = prim::TypeCheck[types=[Float(1, 1, 128, 128)]](%b.1)
  %27 : Float(1, 1, 128, 128) = prim::If(%26)
    block0():
      %z.3 : Float(1, 1, 128, 128) = prim::TensorExprGroup_0(%25)
      -> (%z.3)
    block1():
      %z.2 : Tensor = prim::FallbackGraph_1(%b.1)
      -> (%z.2)
  %21 : int[] = aten::size(%b.1)
  %22 : int[] = aten::size(%27)
  %23 : int[] = prim::BroadcastSizes(%21, %21)
  %24 : int[] = prim::BroadcastSizes(%23, %23)
  return (%27)
with prim::TensorExprGroup_0 = graph(%4 : Float(1, 1, 128, 128)):
  %x.2 : Float(1, 1, 128, 128) = aten::mul(%4, %4)
  %y.2 : Float(1, 1, 128, 128) = aten::sin(%x.2)
  %z.2 : Float(1, 1, 128, 128) = aten::mul(%y.2, %y.2)
  return (%z.2)
with prim::FallbackGraph_1 = graph(%b.1 : Float(1, 1, 128, 128)):
  %x.2 : Tensor = aten::mul(%b.1, %b.1)
  %y.2 : Tensor = aten::sin(%x.2)
  %z.2 : Tensor = aten::mul(%y.2, %y.2)
  return (%z.2)

Finally, after some additional cleanups, we end up with the following IR:

After TExprFuser:
graph(%a.1 : Tensor):
  %5 : int = prim::Constant[value=1]()
  %4 : None = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %1 : int[] = prim::Constant[value=[1, 1, 1, 1]]()
  %6 : Tensor = aten::randn(%1, %4, %4, %4, %4)
  %b.1 : Tensor = aten::conv2d(%a.1, %6, %4, %3, %2, %3, %5)
  %25 : Float(1, 1, 128, 128), %26 : bool = prim::TypeCheck[types=[Float(1, 1, 128, 128)]](%b.1)
  %27 : Tensor = prim::If(%26)
    block0():
      %z.3 : Float(1, 1, 128, 128) = prim::TensorExprGroup_0(%25)
      -> (%z.3)
    block1():
      %z.2 : Tensor = prim::FallbackGraph_1(%b.1)
      -> (%z.2)
  return (%27)
with prim::TensorExprGroup_0 = graph(%4 : Float(1, 1, 128, 128)):
  %x.2 : Float(1, 1, 128, 128) = aten::mul(%4, %4)
  %y.2 : Float(1, 1, 128, 128) = aten::sin(%x.2)
  %z.2 : Float(1, 1, 128, 128) = aten::mul(%y.2, %y.2)
  return (%z.2)
with prim::FallbackGraph_1 = graph(%b.1 : Float(1, 1, 128, 128, strides=[16384, 16384, 128, 1], requires_grad=0, device=cpu)):
  %x.2 : Tensor = aten::mul(%b.1, %b.1)
  %y.2 : Tensor = aten::sin(%x.2)
  %z.2 : Tensor = aten::mul(%y.2, %y.2)
  return (%z.2)

prim::TensorExprGroup is a special node: when the JIT interpreter reaches it, it invokes NNC to execute the subgraph.
Let’s now look at how NNC generates code for this subgraph. To do that, let’s enable logging in kernel.cpp:

$ PYTORCH_JIT_LOG_LEVEL=">>kernel.cpp"  python test.py

The first thing we see is the fused subgraph:

TensorExprKernel graph:
graph(%0 : Float(1, 1, 128, 128)):
  %x.2 : Float(1, 1, 128, 128) = aten::mul(%0, %0)
  %y.2 : Float(1, 1, 128, 128) = aten::sin(%x.2)
  %z.2 : Float(1, 1, 128, 128) = aten::mul(%y.2, %y.2)
  return (%z.2)

It is the same graph we saw inside the prim::TensorExprGroup node, and it is the ‘input’ for NNC. First, it is lowered to tensor expressions:

Original Stmt:
{
  for (int i0 = 0; i0 < 1; i0++) {
    for (int i1 = 0; i1 < 1; i1++) {
      for (int i2 = 0; i2 < 128; i2++) {
        for (int i3 = 0; i3 < 128; i3++) {
          input1[i0, i1, i2, i3] = t0[(((0 + i0 * 16384) + i1 * 16384) + i2 * 128) + i3 * 1];
        }
      }
    }
  }
  for (int v = 0; v < 1; v++) {
    for (int v_1 = 0; v_1 < 1; v_1++) {
      for (int v_2 = 0; v_2 < 128; v_2++) {
        for (int v_3 = 0; v_3 < 128; v_3++) {
          aten_mul[v, v_1, v_2, v_3] = (input1(0, 0, v_2, v_3)) * (input1(0, 0, v_2, v_3));
        }
      }
    }
  }
  for (int v_4 = 0; v_4 < 1; v_4++) {
    for (int v_5 = 0; v_5 < 1; v_5++) {
      for (int v_6 = 0; v_6 < 128; v_6++) {
        for (int v_7 = 0; v_7 < 128; v_7++) {
          aten_sin[v_4, v_5, v_6, v_7] = sin(aten_mul(0, 0, v_6, v_7));
        }
      }
    }
  }
  for (int v_8 = 0; v_8 < 1; v_8++) {
    for (int v_9 = 0; v_9 < 1; v_9++) {
      for (int v_10 = 0; v_10 < 128; v_10++) {
        for (int v_11 = 0; v_11 < 128; v_11++) {
          aten_mul_1[v_8, v_9, v_10, v_11] = (aten_sin(0, 0, v_10, v_11)) * (aten_sin(0, 0, v_10, v_11));
        }
      }
    }
  }
}

We had 3 nodes in the fused graph, but the generated tensor expression has 4 loop nests: the extra one copies the input tensor. There is no need to worry about that redundant copy - it will be removed later.
That tensor expression (or, more accurately, tensor statement) is then transformed by NNC; the most notable transformations are:

  • inlining
  • vectorization
  • index flattening

Inlining propagates the definition of a tensor into its uses. For instance, in the third loop, instead of loading a value from the buffer aten_mul, we could substitute its definition: aten_mul[i,j,k,l] = input1[i,j,k,l]*input1[i,j,k,l]. Similarly, we could inline input1 into aten_mul.
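As a rough 1-D illustration (a simplification of the 4-D loop nests above; the float buffers are assumed to be allocated elsewhere), inlining does the following:

// Before inlining: aten_mul is materialized and then re-read by the next loop.
for (int i = 0; i < 16384; i++) {
  aten_mul[i] = input1[i] * input1[i];
}
for (int i = 0; i < 16384; i++) {
  aten_sin[i] = sin(aten_mul[i]);
}

// After inlining the definition of aten_mul into its use, the intermediate
// buffer and the extra pass over memory disappear:
for (int i = 0; i < 16384; i++) {
  aten_sin[i] = sin(input1[i] * input1[i]);
}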

Vectorization replaces scalar accesses with their vector counterparts.
E.g.

for (int i = 0; i < 128; i++) {
  a[i] = b[i] + c[i]
}

is transformed by vectorization to:

for (int i = 0; i < 16; i++) {
  a[i*8:i*8+8] = b[i*8:i*8+8] + c[i*8:i*8+8]
}

Lastly, N-dimensional tensor accesses are flattened into 1-D accesses in preparation for LLVM or CUDA codegen.
E.g.

for (int i = 0; i < 128; i++) {
  for (int j = 0; j < 100; j++) {
    a[i,j] = 0
  }
}

is transformed into

for (int i = 0; i < 128; i++) {
  for (int j = 0; j < 100; j++) {
    a[i*100 + j] = 0
  }
}

We can see all of these transformations in our example:

Final Stmt:
{
  Allocate(aten_sin, float, {16384});
  for (int v = 0; v < 128; v++) {
    for (int _outer = 0; _outer < 16; _outer++) {
      aten_sin[Ramp(8 * (_outer + 16 * v), 1, 8)] =
         sin(t0[Ramp(8 * (_outer + 16 * v), 1, 8)] *
             t0[Ramp(8 * (_outer + 16 * v), 1, 8)]);
    }
  }
  for (int v_1 = 0; v_1 < 128; v_1++) {
    for (int _outer_1 = 0; _outer_1 < 16; _outer_1++) {
      aten_mul[Ramp(8 * (16 * v_1 + _outer_1), 1, 8)] =
         aten_sin[Ramp(8 * (16 * v_1 + _outer_1), 1, 8)] *
         aten_sin[Ramp(8 * (16 * v_1 + _outer_1), 1, 8)];
    }
  }
  Free(aten_sin);
}

Originally we had 3 operators (which meant 3 passes over the memory), and after NNC’s transformations we have just 2 loops (which means 2 passes over the memory). It is possible to do it in a single sweep as well - in fact, NNC does that for CUDA - but on CPU we have a heuristic that prevents inlining potentially expensive computations (sin in this case).
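For comparison, here is a rough 1-D sketch (in the same simplified style as the inlining example above, not actual NNC output) of what a fully inlined, single-loop version would look like. Since aten_sin is used twice, inlining it duplicates the sin call per element, which is the kind of recomputation the CPU heuristic is trying to avoid:

// Single pass over memory, but sin(t0[i] * t0[i]) now appears twice per element.
for (int i = 0; i < 16384; i++) {
  aten_mul_1[i] = sin(t0[i] * t0[i]) * sin(t0[i] * t0[i]);
}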

The final step in this process is invoking the LLVM or CUDA codegen to compile this tensor expression into an executable binary. We can inspect that step in detail by running our test as follows:

$ PYTORCH_JIT_LOG_LEVEL=">>llvm_codegen"  python test.py

Note: PyTorch needs to be built with USE_LLVM for this command to work.

The output is pretty big, but in it one can find the LLVM IR right after lowering from the tensor expression (before LLVM optimizations), the LLVM IR after LLVM optimizations, and the final assembly code.
If we used CUDA, then we could similarly peek into CUDA codegen by using PYTORCH_JIT_LOG_LEVEL=">>cuda_codegen".

Apologies for digging up such an old thread, but I was just trying your example and I don’t see any fusion going on with PyTorch 1.10 (git of late August).
Has the fuser code been disabled in the meantime? Do I need to compile with some flag?

Just to confirm, I do see output like:

[DUMP tensorexpr_fuser.cpp:617] After guarding fusion groups

But zero changes between passes.

I would appreciate it if you could help. Thank you!

We can increase the verbosity of the fuser’s debug dumps by adding ‘>>’ in front of ‘tensorexpr_fuser’: PYTORCH_JIT_LOG_LEVEL=">>tensorexpr_fuser". This way it will show why it is not fusing. I think the fuser should generally be enabled by default for CPU and GPU in 1.10, except on Windows - is that what you’re using?

Thanks. Now I do see some debug output. I’m on Linux.
I get this:

[DEBUG tensorexpr_fuser.cpp:696] Considering node:%z.1 : Float(1, 1, 128, 128, strides=[16384, 16384, 128, 1], requires_grad=0, device=cpu) = aten::mul(%y.1, %y.1) # xx.py:7:8
[DEBUG tensorexpr_fuser.cpp:1088] Failed cond isFusableOnDevice(node)
[DEBUG tensorexpr_fuser.cpp:696] Considering node:%y.1 : Float(1, 1, 128, 128, strides=[16384, 16384, 128, 1], requires_grad=0, device=cpu) = aten::sin(%x.1) # xx.py:6:8
[DEBUG tensorexpr_fuser.cpp:1088] Failed cond isFusableOnDevice(node)
[DEBUG tensorexpr_fuser.cpp:696] Considering node:%x.1 : Float(1, 1, 128, 128, strides=[16384, 16384, 128, 1], requires_grad=0, device=cpu) = aten::mul(%b.1, %b.1) # xx.py:5:8
[DEBUG tensorexpr_fuser.cpp:1088] Failed cond isFusableOnDevice(node)
[DEBUG tensorexpr_fuser.cpp:696] Considering node:%b.1 : Float(1, 1, 128, 128, strides=[16384, 16384, 128, 1], requires_grad=0, device=cpu) = aten::conv2d(%a.1, %6, %4, %3, %2, %3, %5) # xx.py:4:8
[DEBUG tensorexpr_fuser.cpp:1088] Failed cond isFusableOnDevice(node)
...

Investigating canFuseOnDevice, we have this:

    if (device->is_cpu()) {
      // CPU fusion is only supported for single-thread.
      if (!canFuseOnCPU()) {
        return false;
      }
      if (at::get_num_threads() == 1 || texprParallelCPUEnabled()) {
        return true;
      }
      return false;
    }

So CPU only supports fusion in single-thread mode. Doing export OMP_NUM_THREADS=1 did the trick. I’ll try CUDA next.
Thank you!

FWIW, this is the test driver I’m using to test the different fusers:

  elif arg == '--fuser-nnc':
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)
    torch._C._jit_set_texpr_parallel_cpu_enabled(True)
    torch._C._jit_set_te_must_use_llvm_cpu(False)
    os.environ['PYTORCH_TENSOREXPR_DONT_USE_LLVM'] = '1'
  elif arg == '--fuser-nnc-llvm':
    torch._C._jit_override_can_fuse_on_cpu(True)
    torch._C._jit_override_can_fuse_on_gpu(True)
    torch._C._jit_set_texpr_parallel_cpu_enabled(True)
  elif arg == '--nvfuser':
    #os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1'
    torch._C._jit_override_can_fuse_on_cpu(False)
    torch._C._jit_override_can_fuse_on_gpu(False)
    torch._C._jit_set_texpr_fuser_enabled(False)
    torch._C._jit_set_nvfuser_enabled(True)

Not seeing great results so far, to be honest.

Hm, I think we’ve enabled multi-threading support too. Looking at the sources for PyTorch 1.10, I don’t see a check for get_num_threads() == 1 (pytorch/tensorexpr_fuser.cpp at release/1.10 · pytorch/pytorch · GitHub):

    if (device->is_cpu()) {
      return canFuseOnCPU();
    } else if (device->is_cuda()) {
      return canFuseOnGPU();
    } ...

Could you please double-check which version you are using?

As for the test driver, I don’t think there is value in benchmarking CPU fusion with PYTORCH_TENSOREXPR_DONT_USE_LLVM (for CUDA it should be fine) - on CPU it would essentially interpret the tensorexpr AST, which would be terribly slow. This setting is meant for debugging purposes only.

Also, I’d be happy to look into the performance issues you’re seeing - could you please paste complete repros somewhere?

Sorry, my August checkout was too old. I’ve updated to 1.10+ and it all works as you say. Plus, all the crashes I was getting seem to be fixed now.

I will try to get you some complete repros. I’m getting these programs using a lazy tensor scheme on Hugging Face & TensorVision. I do see some fusions happening there, but very few. Give me a few more days to investigate.

For CUDA, I thought TorchScript could produce a .cu file for the whole function and compile it, but that doesn’t seem to be the case.

For CUDA, I thought TorchScript could produce a .cu file for the whole function and compile it, but that doesn’t seem to be the case.

We don’t generate code for the entire function, but we do generate the CUDA kernel. You can inspect it with
PYTORCH_JIT_LOG_LEVEL=">>>cuda_codegen" when you use the TensorExpr fuser, and with PYTORCH_CUDA_FUSER_DEBUG=1 when you use NVFuser. The CUDA source is compiled on the fly; we do not save it to an actual file.

The first issue I’ve noticed is that there’s no support for in-place ops.
For example, this graph without in-place ops:

graph(%0 : Long(10000, 13000, strides=[13000, 1], requires_grad=0, device=cuda:0),
      %1 : Long(10000, 13000, strides=[13000, 1], requires_grad=0, device=cuda:0),
      %2 : Int(10000, 13000, strides=[13000, 1], requires_grad=0, device=cuda:0)):
  %11 : int = prim::Constant[value=4]()
  %8 : int = prim::Constant[value=1]()
  %6 : bool = prim::Constant[value=0]()
  %5 : Device = prim::Constant[value="cuda:0"]()
  %4 : NoneType = prim::Constant()
  %3 : int = prim::Constant[value=3]()
  %7 : Tensor = aten::_to_copy(%0, %3, %4, %5, %4, %6, %4)
  %9 : Tensor = aten::add(%7, %1, %8)
  %10 : Tensor = aten::mul(%9, %2)
  %14 : Tensor = aten::to(%10, %11, %6, %6, %4)
  return (%14)

Gets fused:

[DEBUG cuda_codegen.cpp:1032] extern "C" __global__
[DEBUG cuda_codegen.cpp:1032] void fused_add_mul(int* tv_, int* tv__, long long* tv___, long long* aten_mul) {
[DEBUG cuda_codegen.cpp:1032] {
[DEBUG cuda_codegen.cpp:1032] if ((long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)<130000000ll ? 1 : 0) {
[DEBUG cuda_codegen.cpp:1032]     int tv__1 = __ldg(tv_ + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
[DEBUG cuda_codegen.cpp:1032]     long long v = (long long)(__ldg(tv__ + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)));
[DEBUG cuda_codegen.cpp:1032]     long long v_1 = __ldg(tv___ + (long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x));
[DEBUG cuda_codegen.cpp:1032]     aten_mul[(long long)(threadIdx.x) + 512ll * (long long)(blockIdx.x)] = (long long)(tv__1) * v + (long long)(tv__1) * v_1;
[DEBUG cuda_codegen.cpp:1032]   }}
[DEBUG cuda_codegen.cpp:1032] }

But if you change the mul to mul_, no fusion happens. Note that changing add → add_ and mul → mul_ yields an equivalent program. (Side comment: why doesn’t TorchScript optimize non-in-place ops into in-place ones?)
I also wonder why it changes the expression (a + b) * c into (a * c) + (b * c); is that actually better?

I’ve also noticed that sometimes similar programs don’t get fully fused - only the mul + to are. I haven’t yet managed to find out why the add is sometimes left out.

Right, we do not support in-place ops. TorchScript, in fact, has a pass that replaces in-place ops with their out-of-place equivalents, but it is not run in the default pipeline (because it can easily be a pessimization rather than an optimization). NNC itself doesn’t currently support in-place ops either, but we’re considering changing that.

I also wonder why it changes the expression (a + b) * c into (a * c) + (b * c)

Hm, I’m surprised by this as well :) If you’re curious to investigate, I’d suggest running the test with PYTORCH_JIT_LOG_LEVEL=">>kernel:>>cuda_codegen" to see where this happens.

Btw, this might be interesting to you too: there is an API to invoke NNC on a graph directly, without going through the fuser pass (if the graph has unsupported ops, it will fail). It’s not a public API, but you may find it convenient for your experiments. Here is an example of how it’s used (and I can provide more if needed):

I found one reason why NNC doesn’t fuse much for CUDA with Hugging Face: some tensors are on the CPU, and even if the operation runs on the GPU, NNC seems to bail out when operands have mixed devices.
E.g.:

graph(%0 : Int(1, 9, strides=[9, 1], requires_grad=0, device=cuda:0),
      %1 : Long(requires_grad=0, device=cpu),
      %2 : Int(1, 9, strides=[9, 1], requires_grad=0, device=cuda:0)):
  %9 : NoneType = prim::Constant()
  %7 : bool = prim::Constant[value=0]()
  %6 : int = prim::Constant[value=4]()
  %3 : int = prim::Constant[value=1]()
  %4 : Tensor = aten::add(%0, %1, %3)
  %5 : Tensor = aten::mul(%4, %2)
  %10 : Tensor = aten::to(%5, %6, %7, %7, %9)
  return (%10)

If you change input %1 to device=cuda, the TensorExprKernel graph gets all 3 instructions instead of just 2.
Adding explicit copies to move operands to the device where the operation will execute seems like a no-op, as the data needs to be moved anyway. Plus, it then enables fusion with the current code unmodified. Is this something you are considering implementing?