loop_tool's lazy frontend - experimenting with symbolic laziness

I’ve been experimenting with symbolic kernel definitions and indexing math. The result is a system that provides a fair amount of useful information, including automatic forward/backward shape inference and calculation of stride information for use with torch.as_strided. This might be useful for compilation as well as for debugging or defining extremely generic networks. I’m posting here to explain the idea and maybe motivate adding the concept to PyTorch in some capacity.

Symbolic Indexing

I recently added symbolic indexing to loop_tool (an experimental loop toolkit) to define operations like convolutions. The idea is that symbolic indices can always be 1) unified with other constraints to calculate actual values and 2) differentiated to determine stride information.

For example, indexing logic like x = x_o + k yields a convolution with a window of size |k| over input dimension x. If we dive into the shape implications of that equation, we can determine that |x_o| must be equal to |x| - |k| + 1. This means we know the value of |x_o| only if we know |k| and |x|. The same is true for |x|: we know |x| if we have |x_o| and |k|. This implies that shape inference can happen "backwards" to some degree.
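As a rough sketch of this bidirectional inference (using sympy as a stand-in, not loop_tool's actual constraint solver, and with the symbols here standing for the sizes |x|, |x_o|, |k|):

import sympy as sp

size_x, size_x_o, size_k = sp.symbols("size_x size_x_o size_k", positive=True, integer=True)

# |x_o| = |x| - |k| + 1 for an unpadded convolution
constraint = sp.Eq(size_x_o, size_x - size_k + 1)

# forward: |x| = 10, |k| = 3  =>  |x_o| = 8
print(sp.solve(constraint.subs({size_x: 10, size_k: 3}), size_x_o))  # [8]

# backward: |x_o| = 8, |k| = 3  =>  |x| = 10
print(sp.solve(constraint.subs({size_x_o: 8, size_k: 3}), size_x))   # [10]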

Further, we can determine how x changes with respect to x_o because we have the equation right there. If the convolution were to be strided, e.g. x = 2 * x_o + k, we know that stride(x_o) = 2 * stride(x) by differentiation.
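The same can be sketched with sympy (again just an analogy, not loop_tool's API): differentiating the index expression recovers the stride relationships.

import sympy as sp

x_o, k = sp.symbols("x_o k")
x = 2 * x_o + k  # index equation for a strided convolution

# dx/dx_o = 2: stepping x_o by one steps x by two, so stride(x_o) = 2 * stride(x)
print(sp.diff(x, x_o))  # 2

# dx/dk = 1: the window dimension walks memory at stride(x)
print(sp.diff(x, k))    # 1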

These two ideas (symbolic constraint unification and differentiation) let us determine information that is certainly useful to a compiler, such as how much memory to allocate and how that memory will be accessed. Even in dynamic settings, the symbolic solutions to unified and differentiated indices can be emitted as code for fast dynamic indexing.

The experiment presented here is to expose these ideas directly in the frontend. In an eager use case, this lets us check shapes on the fly as we construct tensors, already a well-loved feature of PyTorch. In a lazy setting, however, users can do tricky and potentially useful things like automatically deriving input shapes from output shapes.

Interactive Notebook

An example is given below. If you’d like to play with it immediately, there’s a notebook on Google Colab.

Example

First, we start with symbolically sized tensors:

X = lt.Tensor(lt.Symbol("x"))
W = lt.Tensor(lt.Symbol("k"))

Then we define a contrived function to show the idea:

def broadcast_and_conv_dilated(X, W):
  # local names for dims
  channel, spatial, spatial_out, window = [lt.Symbol(_) for _ in ["c", "s", "s_o", "w"]]

  # an arbitrary, concretely sized tensor to broadcast against
  C = lt.Tensor(128, 20)
  X_broadcast = C.to(channel, spatial) * X.to(spatial)

  # convolve into `spatial_out` along `spatial` with window `window` dilated by 2

  # spatial = spatial_out + 2 * window <-- index equation for dilated convolution
  idx_fn = (spatial, spatial_out + lt.Expr(2) * window)

  # first, take a view into X so that its shape is [channel, spatial_out, window]
  X_col = X_broadcast.to(channel, spatial_out, window, constraints=[idx_fn])

  # then, reduce over the window dimension
  Y = (X_col * W.to(window)).sum(window)

  # shape is [channel, spatial_out]
  return Y

With symbolic indexing, convolutional ideas can be written in a simple, einsum-like notation. It’s functionally equivalent to views with custom strides (à la torch.as_strided), but the strides are calculated for you.
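For comparison, here’s roughly what the hand-written torch.as_strided equivalent of that [channel, spatial_out, window] view looks like with this example’s concrete sizes (the strides are worked out by hand here; with loop_tool they fall out of the index equation):

import torch

channels, spatial, spatial_out, window, dilation = 128, 20, 16, 3, 2
x = torch.randn(channels, spatial)

# view x as [channel, spatial_out, window] so that x_col[c, s_o, w] == x[c, s_o + 2 * w]
x_col = x.as_strided(
    size=(channels, spatial_out, window),
    stride=(x.stride(0), x.stride(1), dilation * x.stride(1)),
)

w = torch.randn(window)
y = (x_col * w).sum(-1)  # shape [128, 16], i.e. [channel, spatial_out]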

Next, we invoke the contrived function above and set the expected output size.

Y = broadcast_and_conv_dilated(X, W)
# set expected output size
Y.set_size(128, 16)

Finally, we unify the constraints in the program. This is the “lazy” aspect of the idea: we aren’t computing anything yet, but a graph has been built for us under the hood. After calling unify(), we can access the shapes of the inputs, which would have been a pain to calculate by hand.

Y.unify()

# everything gets derived:
# |spatial| = |spatial_out| + 2 * (|window| - 1), i.e. 20 = 16 + 2 * (3 - 1)
assert X.shape == [20]
assert W.shape == [3]

# and a compiler can immediately use it
print(Y.loop_tree)

And that’s it, thanks for reading this far! I’d appreciate any feedback on this idea. :slight_smile:
