Tracing with Primitives: Update 0

Hey PyTorch Community!

This is the first in a new series of posts about how we’re introducing “primitive operators” (“prims”) to PyTorch. Prims are relatively simple “building blocks” that can be used to implement every other PyTorch operator and module in Python. Breaking complicated operators into simpler building blocks in Python has two tremendous advantages:

  • Tensor libraries, program transforms, and analysis tools can focus exclusively on prims and still support all of PyTorch’s operators.
  • Writing in composable Python makes PyTorch operators easier to read, add, modify and extend.

But it also has two significant drawbacks:

  • Operators written in Python and split into prims may be slower than those written in C++ with custom kernels.
  • Operators written in terms of prims may have numerical accuracy issues if the prims are not carefully designed.

These challenges are why PyTorch operators are written in C++ and frequently use custom kernels today. It’s incredibly important that PyTorch operators are fast and numerically accurate.

It’s possible, however, that tracing can address that first drawback, and that clever prim design can address the second. The PyTorch ecosystem has several mechanisms for tracing programs today, including torchscript and fx, and research projects like TorchDynamo and torchy. These tracers (elaborated on in the next section) convert Python programs into sequences of PyTorch operations called “traces.” Traces are interesting because they’re easy to modify and execute, and some trace executors, like torchscript or XLA (usable from PyTorch with PyTorch/XLA), can even execute traces faster than usual by dynamically generating new kernels optimized for a specific sequence of operations.

Representing traces using simpler primitive operations makes them even easier to modify and execute. XLA already works by reasoning almost exclusively about its own set of primitive operations. If a trace executor can optimize a trace of PyTorch’s prims, then that trace might execute even faster than PyTorch’s C++ operators today!

The second drawback simply requires careful design of PyTorch’s prims and the system that represents and transforms them, and in this series we’ll discuss both the prims’ design as well as how they’re working with trace acquisition methods and trace executors to run as quickly as possible.

Tracing in PyTorch

As mentioned above, there are a variety of tracers in the PyTorch ecosystem today, but we plan to acquire our traces using TorchDynamo, a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. TorchDynamo has been in development for some time and regularly posts updates. We’ll describe our integration with TorchDynamo more in future posts, but for now let’s just define what a trace is.

A “trace” is a sequence of PyTorch operations, typically constructed by observing the operations performed by a program. The FX documentation provides the following example of how a trace of a module can be constructed and represented:

import torch

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace

# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""

Here FX creates a trace (what it calls a “graph”) by tracing through the module MyModule with a “symbolic proxy,” a placeholder for a real tensor. TorchDynamo, like FX, also creates FX traces, and these traces can be transformed (not shown, but see the FX documentation for details) and then executed with real inputs:

t = torch.randn(1, 3, 4)
symbolic_traced(t)
: tensor([[[0.0000, 0.0000, 0.1651, 0.0000, 0.0000], 
           [0.0000, 0.0000, 0.0000, 0.4662, 0.0000], 
           [0.0000, 0.0000, 0.3520, 0.0000, 0.0000]]], grad_fn=<ClampBackward1>)

Today these traces are typically executed by calling into PyTorch’s ATen tensor library, which runs each operation separately.
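Although it isn’t shown above, FX traces can also be inspected and transformed programmatically before they’re executed. Here’s a trivial sketch (see the FX documentation for real graph transforms) that just walks the nodes of the traced graph above and prints what each one does:

# Walk the nodes of the FX graph produced above and print each operation
for node in symbolic_traced.graph.nodes:
    print(node.op, node.target)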

The operations in this trace — the addition, linear module, and clamp — can be split into prims. The call to the linear module, for instance, can be rewritten as a matrix multiply and an addition, and even those operations can be further split apart. In the next section we’ll look in detail at how even the relatively simple addition operation can be written in Python to describe its broadcasting, type promotion and computation behavior, plus its handling of the alpha and out arguments.

Writing torch.add using prims

It may seem surprising that an operation as simple as torch.add can be split into even simpler primitive operations. But the add operation in PyTorch actually does several things. Let’s look at how we might write it in Python in terms of simpler operations:

def add(
    a: Union[Tensor, Number],
    b: Union[Tensor, Number],
    *,
    alpha: Optional[Number] = None,
    out: Optional[Tensor] = None
):
    """
    Python implementation of torch.add using prims
    """
    computation_dtype, result_dtype = _elementwise_dtypes(
        a, b, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH
    )
    a, b = _convert_dtype(a, b, dtype=computation_dtype)

    a, b = broadcast(a, b)

    if alpha is not None:
        alpha_promotion_type = utils.dtype_to_type(computation_dtype)
        b = prims.mul(b, alpha_promotion_type(alpha))

    result = prims.add(a, b)

    result = _convert_dtype(result, dtype=result_dtype)

    if out is not None:
        out = _maybe_resize_out(out, result.shape)
        return copy_to(out, result, allow_cross_device=False)

    return result

There’s a lot going on here, and helpers like _elementwise_dtypes are independently interesting, but to keep our exposition focused let’s step through this without reviewing the details of the helpers.

The first thing this implementation does is type promotion with the following lines:

computation_dtype, result_dtype = _elementwise_dtypes(
    a, b, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH
)
a, b = _convert_dtype(a, b, dtype=computation_dtype)

...

if alpha is not None:
    alpha_promotion_type = utils.dtype_to_type(computation_dtype)

...

result = _convert_dtype(result, dtype=result_dtype)

In PyTorch, like NumPy, when two tensors (arrays) with different datatypes are added together they go through “type promotion,” where a common datatype for the operation is chosen. For example:

# Type promotion examples
a = torch.tensor((1, 2, 3), dtype=torch.float32)
b = torch.tensor((4, 5, 6), dtype=torch.float64)
a + b
: tensor([5., 7., 9.], dtype=torch.float64)

a = torch.tensor((1, 2, 3), dtype=torch.long)
b = torch.tensor((4, 5, 6), dtype=torch.float16)
a + b
: tensor([5., 7., 9.], dtype=torch.float16)

In addition to the type promotion rules applied to arguments of different types, some operations also use a different precision for their internal computations. Many operations with float16 and bfloat16 inputs, including torch.add, actually upcast their inputs to float32 to compute, then write the result back to float16 or bfloat16. This is why both a computation_dtype and a result_dtype are computed above: the addition is performed in the computation_dtype, and the result is then cast back to the result_dtype. The alpha argument, if given, is handled differently because it’s a Python scalar.
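As a concrete illustration, a minimal sketch of this dtype selection might look like the following. This is an assumption-laden stand-in, not PyTorch’s actual _elementwise_dtypes helper, which also handles scalars and other promotion kinds:

import torch

def sketch_elementwise_dtypes(a: torch.Tensor, b: torch.Tensor):
    # The result dtype follows the usual promotion rules...
    result_dtype = torch.promote_types(a.dtype, b.dtype)
    # ...but low-precision floats are upcast to float32 for the math itself
    computation_dtype = result_dtype
    if result_dtype in (torch.float16, torch.bfloat16):
        computation_dtype = torch.float32
    return computation_dtype, result_dtype

a = torch.randn(3).to(torch.float16)
b = torch.randn(3).to(torch.float16)
print(sketch_elementwise_dtypes(a, b))  # (torch.float32, torch.float16)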

After the initial type promotion comes broadcasting. Broadcasting happens when two tensors have different shapes that can be broadcast to a common shape by adding new outermost dimensions or expanding existing dimensions of length one. For example:

# Broadcasting example
a = torch.randn((3, 1, 4))
b = torch.randn((5, 4))
torch.add(a, b).shape
: torch.Size([3, 5, 4])

Broadcasting is handled by the call to the broadcast helper.
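A minimal sketch of what such a helper might do (an illustration only; the actual prims helper may differ) is to compute the common shape and expand both tensors to it without copying data:

import torch

def sketch_broadcast(a: torch.Tensor, b: torch.Tensor):
    # Compute the common broadcast shape and expand both inputs to it
    common_shape = torch.broadcast_shapes(a.shape, b.shape)
    return a.expand(common_shape), b.expand(common_shape)

a = torch.randn(3, 1, 4)
b = torch.randn(5, 4)
x, y = sketch_broadcast(a, b)
print(x.shape, y.shape)  # torch.Size([3, 5, 4]) torch.Size([3, 5, 4])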

The actual computation is handled by an optional call to prims.mul and a call to prims.add. This is the first time we’re seeing primitive operations used directly. These operations do not perform type promotion, broadcasting, or computation datatype selection; they just perform the requested operation on the inputs as provided. And prims.add, unlike torch.add, does not have an alpha parameter.
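As a quick sanity check of that alpha handling (using today’s torch.add rather than the prims):

import torch

a = torch.randn(4, 5)
b = torch.randn(4, 5)
# alpha scales the second input before the addition, just like the explicit
# multiply in the Python implementation above
print(torch.allclose(torch.add(a, b, alpha=2), a + 2 * b))  # True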

Finally, if the out argument is provided it may be resized with _maybe_resize_out before being copied to with prims.copy_to, which implements NumPy-like “safe casting” to prevent unexpected data conversions, as shown here:

# Safe casting example
a = torch.randn((3, 1, 4))
b = torch.randn((5, 4))
out = torch.empty((3, 5, 4), dtype=torch.long)
torch.add(a, b, out=out)
: RuntimeError: result type Float can't be cast to the desired output type Long
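To make the safe-casting rule concrete, here’s a minimal sketch of the kind of check copy_to might perform before writing into out (an illustration only; the real prim also handles details like cross-device copies):

import torch

def sketch_copy_to(out: torch.Tensor, result: torch.Tensor) -> torch.Tensor:
    # Refuse "unsafe" casts, mirroring the RuntimeError shown above
    if not torch.can_cast(result.dtype, out.dtype):
        raise RuntimeError(
            f"result type {result.dtype} can't be cast to the "
            f"desired output type {out.dtype}"
        )
    return out.copy_(result)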

Writing torch.add in Python as a series of simpler operations makes its type promotion, broadcasting, and internal computation behavior clear. Calling all these operations one after another, however, is much slower than just calling torch.add today, because PyTorch’s torch.add uses precompiled kernels that can fuse the data conversions, the multiplication, the addition, and the final copy to the out tensor. Getting similar performance from our Python implementation requires tracing.

Tracing the Python

Creating a program by tracing is very different from eagerly executing a series of operations. Traces are simply sequences of PyTorch operations, and because they don’t have helper functions or control flow they’re relatively easy to analyze, transform, and optimize.

Continuing with our Python implementation of torch.add from before, let’s consider calling it with just two float tensors that have the same shape. On these inputs our program can skip almost all of the operations described above: there’s no type promotion, no broadcasting, no multiplication with alpha, and no copying to an out tensor. When traced, the program simplifies to just a call to prims.add:

# trace for Python add
# a=torch.randn((4, 5)), b=torch.randn((4, 5))

result = prims.add(a, b)
return result

Tracing lets us remove unnecessary operations, and executing this trace is just as fast as calling torch.add directly. On other inputs, however, more operations may appear in the trace:

# trace for Python add
# a=torch.randn((4, 5), dtype=torch.float16), b=torch.randn((4, 5))

a = prims.convert_element_type(a, torch.float32)
result = prims.add(a, b)
return result

And this second trace is already slower than torch.add if executed operation by operation, because converting a from float16 to float32 requires launching an extra kernel.

This performance can be recovered if the trace executor can fuse the two kernels together, just like torch.add does. This is a type of fusion that NVIDIA’s experimental nvFuser and XLA can perform, and these simplified traces, which better represent what an operation like torch.add is actually doing, are easier for those systems to reason about. Internally, nvFuser and XLA have their own even more primitive components that represent hardware details, and without a simplified trace like the ones above, which accurately represents all the semantics of torch.add, they would have to implement that same logic themselves before optimizing. We’ll describe nvFuser’s own primitives in more detail in a future update.

Summary and What’s Next

In this first “tracing with primitives” update we introduced the concept of “primitive operations,” simple building blocks that can describe all of PyTorch’s existing modules and operators in Python. These primitives are incredibly interesting because systems that support them can then support all of PyTorch, and because they’re written in Python they’re easy to read, add, modify and extend. We looked in detail at how torch.add can be written in Python and split into even simpler operations, and how that Python implementation would appear in traces for different inputs.

Without tracing, writing all of PyTorch’s operations in Python and using these prims would be slow, but with tracing and clever trace executors like nvFuser we expect traced programs to be as fast as, if not faster than, PyTorch’s existing operators. That said, we’re not planning on getting rid of PyTorch’s existing implementations! While we’re excited about tracing and letting users use prims, we still want PyTorch’s eager mode to be great and fast. We think the development of primitives is interesting for users who use eager mode exclusively, too, since it should be easier to read an operation’s corresponding Python to understand how it works in detail, and that Python can still be copied and modified to facilitate writing new operations.

In future updates we’ll elaborate more on our prims and their design philosophy, plus focus more on transforming and optimizing traces. An experimental set of primitives is in development now and will be available in PyTorch soon.

If you have thoughts, questions, or suggestions, please comment below! We look forward to hearing from you.


We have a few tracing systems in PyTorch which each have their fun pros and cons; which are we thinking about using for prim tracing?

TorchDynamo, as mentioned briefly in the “Tracing in PyTorch” section above.


What is the relationship to the functorch “decompositions” effort?

Also, do these “prims” decompositions sit above or below the Torch dispatcher?

CC @Chillee


I’ll answer the second question first, since it’s simpler 🙂

Also, do these “prims” decompositions sit above or below the Torch dispatcher?

They sit above the Torch dispatcher.

What is the relationship to the functorch “decompositions” effort?

Functorch decompositions focus on decomposing ATen operators into simpler forms. In that sense, we expect that many of the functorch decompositions will largely be shared/reused for the “reference” implementations in PrimTorch.

In addition, the current plan is to write the functorch decompositions in terms of PyTorch (Python) ops, which would allow us to map from ATen decompositions to Prims.

TL;DR: PrimTorch decompositions are much more “primitive” than functorch’s decompositions, and we hope to allow use of both functorch and PrimTorch decompositions together.

What is the API that PyTorch is trying to expose to backends / compiler writers? Is the idea that, when all these “decompositioney” things are rolled together that backends only need to handle prims?

Yep, that’s the key idea!

We do want to support selective “takeover” of entire operations that haven’t been converted to prims as an option, but by supporting just the prims you should enable all of PyTorch’s functionality.

Okay cool. Most of our backends support enough fusion that prims sounds like a great fit.

For us in Torch-MLIR it appears that it will be important that all these decompositioney things are TorchScript’able. Is that the case for both the prims decompositions and the functorch decompositions?

The FX traces of prims we’re creating can probably be converted into a torchscript representation if that’s what you’d like – we might want to look more carefully at the details together as the project matures, however.

Is tracing the only supported use case for decomposition to prims? That doesn’t work very well for Ahead-of-Time compilation to standalone deployable artifacts (e.g. for mobile, embedded, hermetic server deployments, etc.).

Currently we’re targeting tracing in our e2e flow, although the decompositions exist independently so they could be used in other scenarios. I’d be curious to hear more about the AoT scenarios you’re thinking of, too. Do you want to ping me on PyTorch Slack sometime?

@_sean_silva - I’d say tracing is generally a preferred way for export-y, AOT-y workflows, as it erases a lot of the pythonism that is hard to remove with static analysis (a common experience with TorchScript). There’s some early work on making tracing more predictable and sound: allowing users to be selective about dynamic things (e.g. varying sizes with fixed rank) and erroring out if some dynamic properties escape into Python in a non-traceable way. There might be a sweet spot of mostly tracing and selective control flow at higher levels of the model code. @suo is working on some of this and can comment more.

I’d be very curious about your AOT non-traceable use cases too.

The most challenging use cases that I am aiming towards are generally NLP-ish end-to-end workloads: beam search, input pre/post processing, stateful processing, etc.

We have found that a very significant amount of development time in production deployments is dedicated to the “non-traceable” parts of those workloads, so supporting them directly in the compiler is important to achieving large development velocity gains for those users. It does make things more difficult for the compiler 🙂

I agree that mostly tracing with selective control flow is likely a sweet spot. This is something that JAX by its nature kind of forces (I’ve seen some prototypes for AoT workloads with JAX that build surrounding compilable code for control flow, state management, etc., which then calls into the traced graphs). But it seems that most of the PyTorch ecosystem is not written that way. Thoughts? It would definitely make my life much easier if I only had to handle traces plus non-Tensor-compute side routines that call into the traces.


Beam search is stereotypically the kind of thing you can’t trace away; I remember when TorchScript was originally under development beam search was the exemplar use case for loops and stuff.

PyTorch is all about tracing because we started off as an eager framework and added graph stuff on top. Tracing is the easiest way to reuse preexisting eager code. Obviously MLIR is coming at it from the other end, which is also valid, but it lives on the other end of the tradeoff curve.


What’s the current direction/vision for how to deploy PyTorch models in demanding cases like those requiring beam search on the edge? In particular I’d like direction/vision from PT on a programming model for the edge ecosystem (possibly mixing tracing/non-tracing) that I can build a compiler for. Right now, I have been defaulting to building a TorchScript compiler, effectively, but I think we have the possibility of building something better / more predictable for users.

Note: Torch-MLIR itself is not strictly about edge/TorchScript – we are bringing up LTC and considering TorchDynamo. I personally have another hat that I wear which is very invested in the edge compiler ecosystem though, and I would very much like to unlock PyTorch for that ecosystem.


torch.jit already allows a mix of tracing and AOT via combining torch.jit.trace, torch.jit.script and torch.jit.script_if_tracing. Will there be a path for something like torch.jit.trace to produce the new primitive ops rather than ATen ops?

@_sean_silva was there a follow-up to this discussion in Slack somewhere? If so could you please post a link or a summary?

I think TorchDynamo would be the thing analogous to jit.trace for the “tracing with primitives” project

jit.trace for the “tracing with primitives” project

@mruberry did you mean TorchDynamo will call into PrimTorch first by default?

After reading through the discussions in this thread, my understanding is that PrimTorch is about a set of primitive ATen operations, and the decomposition passes into them. Is my understanding of PrimTorch correct? cc @ezyang @Chillee