NNC walkthrough: how PyTorch ops get fused

In this note we’ll examine how a simple PyTorch program is transformed by the JIT, all the way down to an LLVM-compiled binary, via NNC. I hope it helps to understand what each stage of the pipeline does and how one can examine it in more detail.
The test program we’ll be using is extremely simple:

$ cat test.py

import torch

def foo(a):
    b = torch.conv2d(a, torch.randn(1, 1, 1, 1)) # not fusible
    x = torch.mul(b, b)                          # fusible
    y = torch.sin(x)                             # fusible
    z = torch.mul(y, y)                          # fusible
    return z

torch._C._jit_override_can_fuse_on_cpu(True)

a = torch.randn(1, 1, 128, 128)

scripted = torch.jit.script(foo)

# do several runs, so that the profiling runs can record input shapes before optimization kicks in:
for _ in range(10):
    scripted(a)

To begin with, let’s see what’s getting fused. To do that, we will enable logging in the TE fuser pass:

$ PYTORCH_JIT_LOG_LEVEL="tensorexpr_fuser.cpp" python test.py

First, we see the original JIT graph, before the fuser has touched it:

Before TExprFuser:
graph(%a.1 : Tensor):
  %5 : int = prim::Constant[value=1]()
  %4 : None = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %1 : int[] = prim::Constant[value=[1, 1, 1, 1]]()
  %6 : Tensor = aten::randn(%1, %4, %4, %4, %4)
  %b.1 : Float(1, 1, 128, 128) = aten::conv2d(%a.1, %6, %4, %3, %2, %3, %5)
  %x.1 : Float(1, 1, 128, 128) = aten::mul(%b.1, %b.1)
  %y.1 : Float(1, 1, 128, 128) = aten::sin(%x.1)
  %z.1 : Float(1, 1, 128, 128) = aten::mul(%y.1, %y.1)
  return (%z.1)

Further down in the dump, we see how a fusion group was formed. Note that non-fusible ops, like aten::conv2d, are not in it:

After creating fusion groups:
graph(%a.1 : Tensor):
  %5 : int = prim::Constant[value=1]()
  %4 : None = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %1 : int[] = prim::Constant[value=[1, 1, 1, 1]]()
  %6 : Tensor = aten::randn(%1, %4, %4, %4, %4)
  %b.1 : Float(1, 1, 128, 128) = aten::conv2d(%a.1, %6, %4, %3, %2, %3, %5)
  %z.2 : Float(1, 1, 128, 128) = prim::TensorExprGroup_0(%b.1)
  return (%z.2)
with prim::TensorExprGroup_0 = graph(%4 : Float(1, 1, 128, 128)):
  %x.2 : Float(1, 1, 128, 128) = aten::mul(%4, %4)
  %y.2 : Float(1, 1, 128, 128) = aten::sin(%x.2)
  %z.2 : Float(1, 1, 128, 128) = aten::mul(%y.2, %y.2)
  return (%z.2)
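To make the grouping step concrete, here is a toy model in plain Python. It is a sketch under strong simplifying assumptions: the real TE fuser operates on the JIT graph (a DAG with aliasing and device constraints), not on a flat list of op names, and the FUSIBLE set and the form_groups helper below are hypothetical, chosen only to reproduce the grouping seen in this example.

```python
# Toy model of fusion-group formation. FUSIBLE and form_groups are
# hypothetical; the real fuser works on the JIT graph, not a flat list.
FUSIBLE = {"aten::mul", "aten::sin"}  # assumed subset, for this example only


def form_groups(ops):
    """Split a straight-line op sequence into maximal runs of fusible ops.

    Returns a list of (kind, ops) pairs, where kind is "fused" or "unfused".
    """
    groups = []
    for op in ops:
        kind = "fused" if op in FUSIBLE else "unfused"
        # Extend the current group only if both it and this op are fusible.
        if groups and kind == "fused" and groups[-1][0] == "fused":
            groups[-1][1].append(op)
        else:
            groups.append((kind, [op]))
    return groups


ops = ["aten::conv2d", "aten::mul", "aten::sin", "aten::mul"]
for kind, members in form_groups(ops):
    print(kind, members)
```

As in the dump, conv2d stays outside while the three elementwise ops end up in a single group.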

Fusion groups are only legal to run when the inputs have exactly the same shapes (and other tensor properties, such as dtype) as observed during the profiling runs; these were recorded in the JIT IR before the fuser pass. Hence, the fuser adds a runtime check (prim::TypeCheck) and a fallback path (prim::FallbackGraph), a non-optimized version of the fused subgraph that runs when the check fails.

After guarding fusion groups:
graph(%a.1 : Tensor):
  %5 : int = prim::Constant[value=1]()
  %4 : None = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %1 : int[] = prim::Constant[value=[1, 1, 1, 1]]()
  %6 : Tensor = aten::randn(%1, %4, %4, %4, %4)
  %b.1 : Float(1, 1, 128, 128) = aten::conv2d(%a.1, %6, %4, %3, %2, %3, %5)
  %25 : Float(1, 1, 128, 128), %26 : bool = prim::TypeCheck[types=[Float(1, 1, 128, 128)]](%b.1)
  %27 : Float(1, 1, 128, 128) = prim::If(%26)
    block0():
      %z.3 : Float(1, 1, 128, 128) = prim::TensorExprGroup_0(%25)
      -> (%z.3)
    block1():
      %z.2 : Tensor = prim::FallbackGraph_1(%b.1)
      -> (%z.2)
  %21 : int[] = aten::size(%b.1)
  %22 : int[] = aten::size(%27)
  %23 : int[] = prim::BroadcastSizes(%21, %21)
  %24 : int[] = prim::BroadcastSizes(%23, %23)
  return (%27)
with prim::TensorExprGroup_0 = graph(%4 : Float(1, 1, 128, 128)):
  %x.2 : Float(1, 1, 128, 128) = aten::mul(%4, %4)
  %y.2 : Float(1, 1, 128, 128) = aten::sin(%x.2)
  %z.2 : Float(1, 1, 128, 128) = aten::mul(%y.2, %y.2)
  return (%z.2)
with prim::FallbackGraph_1 = graph(%b.1 : Float(1, 1, 128, 128)):
  %x.2 : Tensor = aten::mul(%b.1, %b.1)
  %y.2 : Tensor = aten::sin(%x.2)
  %z.2 : Tensor = aten::mul(%y.2, %y.2)
  return (%z.2)
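The control flow of the guarded graph above can be modeled in a few lines of plain Python. This is only an illustration of the TypeCheck/If pattern: the dict-based "tensor type", the PROFILED record, and all helper names are hypothetical stand-ins, and the real prim::TypeCheck compares full tensor types (including strides and device), not just the two fields used here.

```python
# Toy model of the guard: prim::TypeCheck followed by prim::If.
# The dict-based "tensor type" and helper names are hypothetical.
PROFILED = {"shape": (1, 1, 128, 128), "dtype": "float32"}  # what profiling saw


def type_check(inp):
    """Mimics prim::TypeCheck: returns (input, whether it matches the profile)."""
    ok = inp["shape"] == PROFILED["shape"] and inp["dtype"] == PROFILED["dtype"]
    return inp, ok


def fused_kernel(inp):      # stands in for prim::TensorExprGroup_0
    return "fused"


def fallback_graph(inp):    # stands in for prim::FallbackGraph_1
    return "fallback"


def guarded(inp):
    checked, ok = type_check(inp)
    # prim::If: block0 runs the fused group, block1 the unoptimized fallback.
    return fused_kernel(checked) if ok else fallback_graph(inp)


print(guarded({"shape": (1, 1, 128, 128), "dtype": "float32"}))  # fused
print(guarded({"shape": (1, 1, 64, 64), "dtype": "float32"}))    # fallback
```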

Finally, after some additional cleanups, we end up with the following IR:

After TExprFuser:
graph(%a.1 : Tensor):
  %5 : int = prim::Constant[value=1]()
  %4 : None = prim::Constant()
  %3 : int[] = prim::Constant[value=[1, 1]]()
  %2 : int[] = prim::Constant[value=[0, 0]]()
  %1 : int[] = prim::Constant[value=[1, 1, 1, 1]]()
  %6 : Tensor = aten::randn(%1, %4, %4, %4, %4)
  %b.1 : Tensor = aten::conv2d(%a.1, %6, %4, %3, %2, %3, %5)
  %25 : Float(1, 1, 128, 128), %26 : bool = prim::TypeCheck[types=[Float(1, 1, 128, 128)]](%b.1)
  %27 : Tensor = prim::If(%26)
    block0():
      %z.3 : Float(1, 1, 128, 128) = prim::TensorExprGroup_0(%25)
      -> (%z.3)
    block1():
      %z.2 : Tensor = prim::FallbackGraph_1(%b.1)
      -> (%z.2)
  return (%27)
with prim::TensorExprGroup_0 = graph(%4 : Float(1, 1, 128, 128)):
  %x.2 : Float(1, 1, 128, 128) = aten::mul(%4, %4)
  %y.2 : Float(1, 1, 128, 128) = aten::sin(%x.2)
  %z.2 : Float(1, 1, 128, 128) = aten::mul(%y.2, %y.2)
  return (%z.2)
with prim::FallbackGraph_1 = graph(%b.1 : Float(1, 1, 128, 128, strides=[16384, 16384, 128, 1], requires_grad=0, device=cpu)):
  %x.2 : Tensor = aten::mul(%b.1, %b.1)
  %y.2 : Tensor = aten::sin(%x.2)
  %z.2 : Tensor = aten::mul(%y.2, %y.2)
  return (%z.2)

The prim::TensorExprGroup is a special node: when the JIT interpreter reaches it, it invokes NNC to execute the subgraph.
Let’s now look at how NNC generates code for this subgraph. To do that, let’s enable logging in kernel.cpp (the >> prefix requests the most verbose, debug-level output for that file):

$ PYTORCH_JIT_LOG_LEVEL=">>kernel.cpp"  python test.py

The first thing we see is the fused subgraph:

TensorExprKernel graph:
graph(%0 : Float(1, 1, 128, 128)):
  %x.2 : Float(1, 1, 128, 128) = aten::mul(%0, %0)
  %y.2 : Float(1, 1, 128, 128) = aten::sin(%x.2)
  %z.2 : Float(1, 1, 128, 128) = aten::mul(%y.2, %y.2)
  return (%z.2)

It is the same graph we saw inside the prim::TensorExprGroup node, and it is the input to NNC. First, it is lowered to tensor expressions:

Original Stmt:
{
  for (int i0 = 0; i0 < 1; i0++) {
    for (int i1 = 0; i1 < 1; i1++) {
      for (int i2 = 0; i2 < 128; i2++) {
        for (int i3 = 0; i3 < 128; i3++) {
          input1[i0, i1, i2, i3] = t0[(((0 + i0 * 16384) + i1 * 16384) + i2 * 128) + i3 * 1];
        }
      }
    }
  }
  for (int v = 0; v < 1; v++) {
    for (int v_1 = 0; v_1 < 1; v_1++) {
      for (int v_2 = 0; v_2 < 128; v_2++) {
        for (int v_3 = 0; v_3 < 128; v_3++) {
          aten_mul[v, v_1, v_2, v_3] = (input1(0, 0, v_2, v_3)) * (input1(0, 0, v_2, v_3));
        }
      }
    }
  }
  for (int v_4 = 0; v_4 < 1; v_4++) {
    for (int v_5 = 0; v_5 < 1; v_5++) {
      for (int v_6 = 0; v_6 < 128; v_6++) {
        for (int v_7 = 0; v_7 < 128; v_7++) {
          aten_sin[v_4, v_5, v_6, v_7] = sin(aten_mul(0, 0, v_6, v_7));
        }
      }
    }
  }
  for (int v_8 = 0; v_8 < 1; v_8++) {
    for (int v_9 = 0; v_9 < 1; v_9++) {
      for (int v_10 = 0; v_10 < 128; v_10++) {
        for (int v_11 = 0; v_11 < 128; v_11++) {
          aten_mul_1[v_8, v_9, v_10, v_11] = (aten_sin(0, 0, v_10, v_11)) * (aten_sin(0, 0, v_10, v_11));
        }
      }
    }
  }
}

We had 3 nodes in the fused graph, but the generated tensor expression has 4 loop nests: the extra one copies the input tensor. No need to worry about that redundant copy; it is removed later.
That tensor expression (or, more accurately, tensor statement) is then transformed by NNC. The most notable transformations are:
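To make the cost of the unoptimized statement tangible, here is a plain-Python model of "Original Stmt". It is a sketch, not NNC output: buffers are modeled as flat Python lists, and each 1x1x128x128 loop nest is collapsed into a single loop over 16384 elements, since the two leading dimensions of size 1 contribute only one iteration each. Every loop nest is a full pass over memory.

```python
# Plain-Python model of "Original Stmt": one loop nest per buffer.
# Each 1x1x128x128 nest is collapsed into one loop over 16384 elements.
import math

SIZE = 1 * 1 * 128 * 128


def original_stmt(t0):
    input1 = [0.0] * SIZE
    aten_mul = [0.0] * SIZE
    aten_sin = [0.0] * SIZE
    aten_mul_1 = [0.0] * SIZE
    for i in range(SIZE):      # nest 1: copy the input t0 into input1
        input1[i] = t0[i]
    for i in range(SIZE):      # nest 2: x = b * b
        aten_mul[i] = input1[i] * input1[i]
    for i in range(SIZE):      # nest 3: y = sin(x)
        aten_sin[i] = math.sin(aten_mul[i])
    for i in range(SIZE):      # nest 4: z = y * y
        aten_mul_1[i] = aten_sin[i] * aten_sin[i]
    return aten_mul_1
```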

  • inlining
  • vectorization
  • index flattening

Inlining propagates the definition of a tensor into its uses. For instance, in the third loop nest, instead of loading a value from the buffer aten_mul, we could use its definition: aten_mul[i,j,k,l] = input1[i,j,k,l]*input1[i,j,k,l]. Similarly, we could inline input1 into aten_mul.
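At its core, inlining is expression substitution. The sketch below assumes a hypothetical encoding of expressions as nested Python tuples, where ("load", name) reads buffer name at the current loop indices; NNC's actual IR is a C++ class hierarchy, and the inline helper here is purely illustrative.

```python
# Inlining as substitution over a hypothetical tuple-based expression IR.
# ("load", name) reads buffer `name` at the current loop indices.


def inline(expr, name, definition):
    """Replace every load of buffer `name` in `expr` with `definition`."""
    if expr == ("load", name):
        return definition
    if isinstance(expr, tuple):
        return tuple(inline(e, name, definition) for e in expr)
    return expr  # op names and buffer names are left unchanged


input1_def = ("load", "t0")
aten_mul_def = ("mul", ("load", "input1"), ("load", "input1"))
aten_sin_def = ("sin", ("load", "aten_mul"))

# Inline aten_mul into aten_sin, then input1 into the result.
step1 = inline(aten_sin_def, "aten_mul", aten_mul_def)
step2 = inline(step1, "input1", input1_def)
print(step2)  # ('sin', ('mul', ('load', 't0'), ('load', 't0')))
```

The result has the same shape as the first loop body in the final statement shown later: sin(t0 * t0).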

Vectorization replaces scalar accesses with their vector counterparts.
E.g.

for (int i = 0; i < 128; i++) {
  a[i] = b[i] + c[i];
}

is transformed by vectorization to:

for (int i = 0; i < 16; i++) {
  a[i*8:i*8+8] = b[i*8:i*8+8] + c[i*8:i*8+8];
}
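The semantics of this rewrite can be checked with a plain-Python model in which one "vector" is an 8-element list slice. This only mirrors the loop structure: Python gains no SIMD speedup from it, and the sketch assumes the trip count divides evenly by the lane count, so no remainder loop is needed. The function names are hypothetical.

```python
# Plain-Python model of 8-wide vectorization: each outer iteration
# processes one 8-lane "vector", modeled as a list slice.
LANES = 8


def add_scalar(b, c):
    a = [0.0] * len(b)
    for i in range(len(b)):
        a[i] = b[i] + c[i]
    return a


def add_vectorized(b, c):
    assert len(b) % LANES == 0, "this sketch assumes no remainder loop"
    a = [0.0] * len(b)
    for i in range(len(b) // LANES):        # 128 // 8 == 16 iterations
        lo, hi = i * LANES, i * LANES + LANES
        a[lo:hi] = [x + y for x, y in zip(b[lo:hi], c[lo:hi])]
    return a


b = [float(i) for i in range(128)]
c = [2.0 * i for i in range(128)]
print(add_vectorized(b, c) == add_scalar(b, c))  # True
```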

Lastly, N-dimensional tensor accesses are flattened into 1-D accesses in preparation for LLVM or CUDA codegen.
E.g.

for (int i = 0; i < 128; i++) {
  for (int j = 0; j < 100; j++) {
    a[i,j] = 0;
  }
}

is transformed into

for (int i = 0; i < 128; i++) {
  for (int j = 0; j < 100; j++) {
    a[i*100 + j] = 0;
  }
}
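For a contiguous tensor, the flattened index is the dot product of the indices with row-major strides. The helpers below are illustrative names, not NNC functions; they reproduce both the a[i*100 + j] example above and the strides [16384, 16384, 128, 1] that appear in the input copy of "Original Stmt" and in the FallbackGraph signature.

```python
# Row-major index flattening: a[i0, ..., ik] -> a[sum(i_d * stride_d)].


def row_major_strides(shape):
    strides = []
    acc = 1
    for dim in reversed(shape):   # the innermost dimension has stride 1
        strides.append(acc)
        acc *= dim
    return list(reversed(strides))


def flatten_index(indices, shape):
    return sum(i * s for i, s in zip(indices, row_major_strides(shape)))


print(row_major_strides((128, 100)))        # [100, 1] -> a[i*100 + j]
print(row_major_strides((1, 1, 128, 128)))  # [16384, 16384, 128, 1]
```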

We can see all of these transformations in our example:

Final Stmt:
{
  Allocate(aten_sin, float, {16384});
  for (int v = 0; v < 128; v++) {
    for (int _outer = 0; _outer < 16; _outer++) {
      aten_sin[Ramp(8 * (_outer + 16 * v), 1, 8)] =
         sin(t0[Ramp(8 * (_outer + 16 * v), 1, 8)] *
             t0[Ramp(8 * (_outer + 16 * v), 1, 8)]);
    }
  }
  for (int v_1 = 0; v_1 < 128; v_1++) {
    for (int _outer_1 = 0; _outer_1 < 16; _outer_1++) {
      aten_mul[Ramp(8 * (16 * v_1 + _outer_1), 1, 8)] =
         aten_sin[Ramp(8 * (16 * v_1 + _outer_1), 1, 8)] *
         aten_sin[Ramp(8 * (16 * v_1 + _outer_1), 1, 8)];
    }
  }
  Free(aten_sin);
}
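The final statement can be replayed in plain Python to confirm that the two vectorized loop nests cover every element exactly once and compute sin(b*b)^2. This is a sketch that assumes Ramp(base, stride, lanes) denotes the index vector base, base+stride, ..., base+(lanes-1)*stride, which is how the printed IR reads; the 8 lanes are executed as an inner Python loop rather than as a real vector operation.

```python
# Plain-Python replay of "Final Stmt". ramp() models Ramp(base, stride, lanes)
# as a list of indices; the 8 lanes run as an inner loop here.
import math


def ramp(base, stride, lanes):
    return [base + k * stride for k in range(lanes)]


def final_stmt(t0):
    aten_sin = [0.0] * 16384              # Allocate(aten_sin, float, {16384})
    aten_mul = [0.0] * 16384              # output buffer
    for v in range(128):                  # pass 1: sin(t0 * t0), inlined
        for outer in range(16):
            for idx in ramp(8 * (outer + 16 * v), 1, 8):
                aten_sin[idx] = math.sin(t0[idx] * t0[idx])
    for v_1 in range(128):                # pass 2: aten_sin * aten_sin
        for outer_1 in range(16):
            for idx in ramp(8 * (16 * v_1 + outer_1), 1, 8):
                aten_mul[idx] = aten_sin[idx] * aten_sin[idx]
    return aten_mul                       # aten_sin is dropped, like Free()


print(ramp(8, 1, 8))  # [8, 9, 10, 11, 12, 13, 14, 15]
```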

Originally we had 3 operators, which meant 3 passes through memory; after the NNC transformations we have just 2 loop nests, i.e. 2 passes. It is possible to do it in a single pass as well (in fact, NNC does so for CUDA), but on CPU a heuristic prevents inlining potentially expensive computations, sin in this case. Note that aten_sin is read twice by the final multiplication, so inlining it would compute sin twice per element.
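A rule of this kind can be sketched as follows. This is a hypothetical simplification for illustration, not NNC's actual heuristic: the EXPENSIVE set, the tuple-based expression encoding, and the allow_duplicated_work parameter (standing in for the CPU/CUDA difference described above) are all assumptions. The idea is to inline a producer unless its result is read more than once and it contains an expensive operation, since inlining would then duplicate that work.

```python
# Hypothetical inlining heuristic: refuse to inline an expensive producer
# that the consumer reads more than once (it would duplicate the work).
EXPENSIVE = {"sin", "cos", "exp", "log"}  # assumed set of costly intrinsics


def contains_expensive(expr):
    if isinstance(expr, tuple):
        return expr[0] in EXPENSIVE or any(contains_expensive(e) for e in expr[1:])
    return False


def count_loads(expr, name):
    if expr == ("load", name):
        return 1
    if isinstance(expr, tuple):
        return sum(count_loads(e, name) for e in expr)
    return 0


def should_inline(producer_def, consumer_def, producer_name, allow_duplicated_work):
    loads = count_loads(consumer_def, producer_name)
    if loads <= 1 or allow_duplicated_work:
        return True
    return not contains_expensive(producer_def)


aten_sin_def = ("sin", ("mul", ("load", "t0"), ("load", "t0")))
aten_mul_1_def = ("mul", ("load", "aten_sin"), ("load", "aten_sin"))
print(should_inline(aten_sin_def, aten_mul_1_def, "aten_sin", False))  # False (CPU-like)
print(should_inline(aten_sin_def, aten_mul_1_def, "aten_sin", True))   # True (CUDA-like)
```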

The final step in this process is invoking LLVM or CUDA to compile this tensor expression into an executable binary. We can inspect that step in detail by running our test as follows:

$ PYTORCH_JIT_LOG_LEVEL=">>llvm_codegen"  python test.py

Note: PyTorch needs to be built with USE_LLVM for this command to work.

The output is quite large, but in it one can find the LLVM IR right after lowering from the tensor expression (before LLVM optimizations), the LLVM IR after LLVM optimizations, and the final assembly code.
If we used CUDA, we could similarly peek into the CUDA codegen by using PYTORCH_JIT_LOG_LEVEL=">>cuda_codegen".
