Tracing with Primitives: Update 2

Hey PyTorch developers and community members!

The “tracing with primitives” program continues to grow, with over a dozen direct contributors who have added Python reference implementations for more than a quarter of PyTorch’s torch, torch.nn.functional, torch.linalg, torch.fft, and torch.special operations! (To learn more about why we’re writing Python reference implementations of PyTorch operators see “Tracing with Primitives: Update 0”.)

As the project expands it uncovers interesting challenges and opportunities. Some of those — like implementing a new tensor subclass by just defining the primitive operations — will be elaborated on in future updates. This update, however, is about engaging the entire PyTorch community and some architectural decisions we’ve made. So if you’d like to help us implement PyTorch’s logic in a more readable, hackable, and composable way then see below!

Python References as PyTorch Puzzles

In April, Sasha Rush challenged the PyTorch community to write single-line Python implementations for 16 PyTorch operators using a limited subset of other operations. That’s a really cool and educational challenge, and it speaks directly to one of our goals with primitive operations: making PyTorch’s operators easier to understand by having fewer and simpler underlying concepts. While the tracing with primitives program isn’t limiting itself to single-line Python reference implementations (just take a look at our torch.roll reference), we do want to suggest a few operators to the community as “PyTorch Puzzles.” Our suggestions are:

  • inner
  • kaiser_window
  • kron
  • linalg.cross
  • trapezoid

If you’re interested in writing a Python implementation for one (or more) of these operators then please submit a PR. The first correct (and tested!) implementation will be accepted and merged into PyTorch. See PyTorch’s contribution guide before getting started. You may need to build PyTorch from source and read the C++ source for each of these operations to ensure they’re implemented correctly and with high fidelity to their current ATen implementations.
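As a taste of the single-line style, here is a sketch of a reference for torch.outer, which is deliberately not one of the puzzles above. The name outer_ref is hypothetical, and a reference intended for PyTorch would also need to match the eager operator’s dtype handling and error checking.

```python
import torch

def outer_ref(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Broadcasting a column (n, 1) against a row (1, m) yields all pairwise products
    return a.unsqueeze(1) * b.unsqueeze(0)

a = torch.tensor([1., 2., 3.])
b = torch.tensor([4., 5.])
print(torch.equal(outer_ref(a, b), torch.outer(a, b)))  # True
```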

We hope some of the community is interested in joining us on this project, and these operators are just the start of what we could work on together. Especially if you’re interested in a particular type of operation (like random operations or distributions) then don’t hesitate to reach out.

Architectural Updates

The following sections elaborate on some recent or planned updates for the “architecture” of the tracing with primitives project. This includes updates to handling strides, type promotion, gradients, PyTorch’s out parameter, and registering Python references as decompositions.

Strides in PyTorch

In PyTorch, like NumPy, tensors have a shape and strides that describe how to access the elements of the tensor in a contiguous block of memory. (You can learn more about strides from its PyTorch Podcast episode.) And strides, since they’re mostly about how memory is accessed, are typically just a performance concern. Sometimes in PyTorch, however, strides are a semantic concern, too, as the following snippet shows:

a = torch.randn(2, 2, 2)
b = torch.reshape(a, (8,))

# Reshaping a like the above always
# produces a tensor of the specified shape
b.shape
: torch.Size([8])

# Because a can be viewed as a 1D tensor, 
# b is a view of a
b._base is a
: True

# However if the dimensions of a are permuted...
a = torch.randn(2, 2, 2).permute(1, 0, 2)
b = torch.reshape(a, (8,))

# The resulting tensor b is no longer a view of a
b._base is a
: False

Whether b is a view of a depends on the strides of a. PyTorch’s documentation for reshape actually warns users not to rely on this behavior, but similar operations, like flatten and contiguous, also sometimes return a view or a copy depending on the strides of their input, and they have no such warning.
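To make the role of strides concrete, the sketch below looks up an element of a permuted tensor by hand: element (i0, i1, ...) lives at storage_offset + i0 * stride0 + i1 * stride1 + ... in the underlying memory. The helper element_via_strides is hypothetical and assumes the caller still holds the contiguous tensor that owns the memory.

```python
import torch

memory = torch.arange(24.)        # one contiguous block of 24 elements
a = memory.view(2, 3, 4)          # strides (12, 4, 1)
p = a.permute(2, 0, 1)            # same memory, strides (1, 12, 4)

def element_via_strides(t, index, memory):
    # Element (i0, i1, ...) is at storage_offset + sum(i_k * stride_k)
    offset = t.storage_offset() + sum(i * s for i, s in zip(index, t.stride()))
    return memory[offset]

print(a.stride(), p.stride())     # (12, 4, 1) (1, 12, 4)
print(p.is_contiguous())          # False
print(element_via_strides(p, (3, 1, 2), memory), p[3, 1, 2])  # tensor(23.) tensor(23.)
```

Because p is not contiguous, no single stride can walk its elements in order, which is exactly the situation where reshape has to copy.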

Returning a view instead of always producing a copy is an important performance optimization for eager PyTorch, but modeling that same behavior with primitives is very tricky. PyTorch’s stride propagation has numerous vagaries, and its algorithms for doing so often aren’t commutative or associative. Non-commutative and non-associative properties are particularly tricky to mimic with primitives, because reference implementations often change the number and type of functions called. This is best illustrated by an example, like this reference for the clamp operator:

def clamp(
    a: TensorLikeType,
    min: Optional[TensorOrNumberLikeType] = None,
    max: Optional[TensorOrNumberLikeType] = None,
) -> TensorLikeType:
    a, min, max = _maybe_broadcast(a, min, max)

    if min is not None and max is not None:
        return minimum(maximum(a, min), max)
    if min is not None:
        return maximum(a, min)
    if max is not None:
        return minimum(a, max)

    msg = "clamp called but both min and max are none!"
    raise ValueError(msg)

This Python reference implementation follows the documentation for clamp, but its behavior is distinct from the actual clamp operation because calling two elementwise binary operations — minimum and maximum — is distinct from calling a single elementwise ternary operation (clamp), as the following snippet shows:

a = torch.randn(2, 1, 2)
b = torch.randn(1, 2, 1)
c = torch.randn(2, 2, 2).permute(1, 2, 0)

torch.minimum(torch.maximum(a, b), c).stride()
: (4, 2, 1)

torch.clamp(a, b, c).stride()
: (4, 1, 2)

This example may seem niche — it involves broadcasting and noncontiguity — but PyTorch needs to execute all its programs correctly, whether they’re run eagerly or traced to primitives.
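A quick way to see that the discrepancy is about strides rather than values is to compare the two computations directly. This is a sketch assuming a PyTorch version whose torch.clamp accepts tensor min and max arguments; clamp_via_min_max is a hypothetical stand-in for the reference above, and the printed strides are left unannotated because they depend on PyTorch’s stride propagation.

```python
import torch

def clamp_via_min_max(a, min_t, max_t):
    # Two binary elementwise operations, like the clamp reference above
    return torch.minimum(torch.maximum(a, min_t), max_t)

torch.manual_seed(0)
a = torch.randn(2, 1, 2)
b = torch.randn(1, 2, 1)
c = torch.randn(2, 2, 2).permute(1, 2, 0)

ref = clamp_via_min_max(a, b, c)
eager = torch.clamp(a, b, c)

# The values agree...
print(torch.equal(ref, eager))  # True
# ...but the strides come from different stride propagation logic
print(ref.stride(), eager.stride())
```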

This presents a bit of a quandary: modeling PyTorch’s stride behavior today is at best difficult and at worst in direct conflict with our goals for the tracing with primitives program, but failing to model it means we can’t determine whether operations like reshape will return a copy or a view. If PyTorch didn’t support inplace operations, this wouldn’t be an issue, but inplace operations (operations that modify a tensor’s storage directly) behave differently on views than on copies.
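The sketch below shows why inplace operations make the view vs. copy distinction observable. The helper reshape_then_mutate is hypothetical; it mutates the result of reshape and returns the original input so the effect (or lack of one) can be inspected.

```python
import torch

def reshape_then_mutate(a: torch.Tensor) -> torch.Tensor:
    # Flatten with reshape, modify the result inplace, return the original input
    b = torch.reshape(a, (a.numel(),))
    b.add_(1)
    return a

# Contiguous input: reshape returns a view, so the inplace add is visible in a
contiguous = reshape_then_mutate(torch.zeros(2, 2, 2))
print(contiguous.sum().item())  # 8.0

# Permuted input: reshape has to copy, so a is left unchanged
permuted = reshape_then_mutate(torch.zeros(2, 2, 2).permute(1, 0, 2))
print(permuted.sum().item())  # 0.0
```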

To resolve this issue, our current thinking is to make PyTorch’s operator semantics stride agnostic when tracing. This is a discrepancy with eager PyTorch that we may need to address in the future (possibly by introducing copy-on-write behavior to operations like reshape), but we think reliance on view vs. copy behavior is very, very rare. See this RFC for more details on stride updates and feel free to add your thoughts there, too!

Type Promotion

Type promotion occurs when certain operations, most notably elementwise operations, are performed on numbers or tensors with different datatypes. Essentially, it computes a common “computation datatype” in which the operation is performed. Type promotion can be extremely convenient, and it occurs in languages like C++ and Python, too:

# Python int x float type promotion
5 + 3.
: 8.0

# PyTorch int x float tensor type promotion
a = torch.tensor(5)
b = torch.tensor(3.)
a + b
: tensor(8.)

Type promotion is often intuitive, although its details are a little convoluted, involving multiple type promotion “kinds” and special cases. See the documentation for the Python reference implementation of elementwise type promotion here for details.

What’s important to know for our discussion, however, is that:

  • PyTorch considers numbers, tensors with zero dimensions (“scalar tensors”), and tensors with one or more dimensions with different priority when performing type promotion
  • Low precision floating point and complex datatypes (float16, bfloat16, and complex32) may be “upcast” to float32 or complex64 while the operation is performed and then “downcast” back to float16, bfloat16, or complex32 when the operation completes
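torch.result_type can be used to inspect these priorities directly. A minimal sketch, assuming the default dtype is float32:

```python
import torch

int_tensor = torch.ones(2, dtype=torch.int64)
half_tensor = torch.ones(2, dtype=torch.float16)
double_scalar_tensor = torch.tensor(2.5, dtype=torch.float64)  # zero dimensions

# A float number outranks an integer tensor's category: default float dtype
print(torch.result_type(int_tensor, 2.5))                    # torch.float32
# ...but doesn't widen a float16 tensor of the same category
print(torch.result_type(half_tensor, 2.5))                   # torch.float16
# A zero-dimensional float64 tensor doesn't widen a 1D float16 tensor either
print(torch.result_type(half_tensor, double_scalar_tensor))  # torch.float16
# Two tensors with dimensions promote as usual
print(torch.result_type(half_tensor, torch.ones(2, dtype=torch.float64)))  # torch.float64
```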

There are two major challenges with emulating this behavior in Python references. First, because numbers, tensors with zero dimensions, and tensors with one or more dimensions have different type promotion priorities, type promotion is not associative. For example:

a = 5.
b = torch.randint(0, 9, (2, 2), dtype=torch.long)
c = torch.randn((2, 2), dtype=torch.float16)

((a + b) + c).dtype
: torch.float32

(a + (b + c)).dtype
: torch.float16

And, as mentioned above with strides, properties that aren’t associative are tricky to preserve when writing a reference.

Second, low precision floating point and complex types are typically upcast on CUDA devices, but not on CPUs. We could model this divergence by writing device-specific code in the references, but the references are generally intended to be canonical, device-agnostic implementations of PyTorch operators.

Preserving the output datatypes is incredibly important for staying consistent with eager PyTorch. These datatypes are user-facing and affect future computations. To model them consistently despite their non-associativity, we decided to implement a decorator that performs type promotion before an operation is called, and then performs any output conversions after the operation is called. This decorator ensures that output datatypes are consistent with PyTorch’s eager mode, but it does mean that Python references can be more accurate than their eager counterparts for low precision floating point and complex datatypes (for example, on CPU, where eager PyTorch may not upcast). While Python references are supposed to be consistent with eager PyTorch, we think being more accurate in some cases is an OK divergence.
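A minimal sketch of the idea, assuming an elementwise binary operation on two tensors: promote to a common dtype, compute low precision dtypes in a wider dtype, then cast the result back. The decorator promote_then_convert and its upcast table are hypothetical illustrations, not the project’s actual decorator, which also has to handle numbers and the multiple type promotion kinds described above.

```python
import functools

import torch

# Hypothetical table: low precision dtypes and the wider dtype to compute in
_COMPUTE_DTYPE = {
    torch.float16: torch.float32,
    torch.bfloat16: torch.float32,
    torch.complex32: torch.complex64,
}

def promote_then_convert(fn):
    """Hypothetical decorator for binary elementwise ops on two tensors."""
    @functools.wraps(fn)
    def wrapper(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        # Before the call: pick the result dtype eager PyTorch would produce
        result_dtype = torch.result_type(a, b)
        # ...and upcast low precision dtypes for the computation itself
        compute_dtype = _COMPUTE_DTYPE.get(result_dtype, result_dtype)
        result = fn(a.to(compute_dtype), b.to(compute_dtype))
        # After the call: convert back so the output dtype matches eager mode
        return result.to(result_dtype)
    return wrapper

@promote_then_convert
def add(a, b):
    return torch.add(a, b)

h = torch.tensor([1.0, 2.0], dtype=torch.float16)
i = torch.tensor([3, 4])
print(add(h, i))  # tensor([4., 6.], dtype=torch.float16)
```

Note that this sketch upcasts regardless of device, which is where the “more accurate than eager” divergence comes from.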

Consistent Gradients

Gradients computed by PyTorch are often tricky to mimic when writing reference implementations, too. Consider the computed gradient of the relu operation:

a = torch.tensor(((-2, -1), (0, 1)), dtype=torch.float32, requires_grad=True)
: tensor([[-2., -1.], 
          [ 0., 1.]], requires_grad=True)

b = torch.relu(a)
b.sum().backward()
a.grad
: tensor([[0., 0.], 
          [0., 1.]])

Note that no gradient propagates back to the element of a with value 0.

A natural Python reference for relu uses the maximum operation, but in PyTorch that produces a different gradient:

a = torch.tensor(((-2, -1), (0, 1)), dtype=torch.float32, requires_grad=True)
c = torch.maximum(torch.zeros_like(a), a)
c.sum().backward()
a.grad
: tensor([[0.0000, 0.0000], 
          [0.5000, 1.0000]])

Conceptually, the backward formula for maximum “splits” the incoming gradient when the two values it’s comparing are equal. This is unlike relu, which always produces a gradient of zero for values less than or equal to zero.

A correct reference implementation of relu has to use where instead:

a = torch.tensor(((-2, -1), (0, 1)), dtype=torch.float32, requires_grad=True)
d = torch.where(a <= 0, 0, a)
d.sum().backward()
a.grad
: tensor([[0., 0.], 
          [0., 1.]])

The implementation above also correctly propagates nan values from the input, since nan <= 0 is false and where then selects the nan. While there may be a few issues with PyTorch’s gradient handling, we are generally happy with how gradients are computed, and we’ve decided to embrace the difference in gradients between operations like maximum and where, even though they can produce identical forward results while having distinct backward formulas. Python references just have to be careful to implement the same gradient and nan propagation behavior as the referenced operation.
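One way to catch gradient mismatches like this is to compare a candidate reference’s forward result and gradient against the eager operator. A sketch, using hypothetical helper names (grad_of, relu_via_maximum, relu_via_where):

```python
import torch

def grad_of(fn, values):
    # Returns fn's forward result and the gradient of fn(a).sum() w.r.t. a
    a = torch.tensor(values, dtype=torch.float32, requires_grad=True)
    out = fn(a)
    out.sum().backward()
    return out.detach(), a.grad

def relu_via_maximum(a):
    return torch.maximum(torch.zeros_like(a), a)

def relu_via_where(a):
    return torch.where(a <= 0, 0, a)

values = [-2.0, -1.0, 0.0, 1.0]
expected_out, expected_grad = grad_of(torch.relu, values)

for ref in (relu_via_maximum, relu_via_where):
    out, grad = grad_of(ref, values)
    same_out = torch.equal(out, expected_out)
    same_grad = torch.equal(grad, expected_grad)
    print(ref.__name__, same_out, same_grad)
# relu_via_maximum True False
# relu_via_where True True
```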

The out Parameter

Many operations in PyTorch accept a keyword-only out parameter, where a tensor can be provided to an operation and modified inplace to hold its output:

out = torch.empty(3)
a = torch.tensor((1, 2, 3.))
b = torch.tensor((4, 5, 6.))

torch.add(a, b, out=out)
: tensor([5., 7., 9.])

Some operations take advantage of the out argument if it’s provided and reuse its memory, although this is not always possible. Historically in PyTorch the out argument didn’t even have to have the same size as the output, although this behavior is now deprecated since it was often indicative of an error:

too_small = torch.empty(2)
torch.add(a, b, out=too_small)
: UserWarning: An output with one or more elements was resized since 
  it had shape [2], which does not match the required output shape [3].
  This behavior is deprecated, and in a future PyTorch release outputs will 
  not be resized unless they have zero elements. 
  You can explicitly reuse an out tensor t by resizing it, inplace, 
  to zero elements with t.resize_(0). 
  (Triggered internally at ../aten/src/ATen/native/Resize.cpp:24.)
: tensor([5., 7., 9.])

Full details of how the out parameter is intended to work can be found in PyTorch’s Developer FAQ, although a few operators still have to be updated to implement the behavior described there.

One of the advantages of working in Python is that it’s much easier to enforce consistency through composable transforms, and for the out parameter we’ve implemented another decorator that adds the out parameter to a Python reference. This also means that Python references can’t implement optimizations with their out arguments, and are reliant on trace executors to optimize their memory usage.
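A minimal sketch of such a decorator, assuming the reference returns a single tensor. The name add_out_param is hypothetical, and this version is stricter than eager PyTorch: it silently resizes only zero-element out tensors and raises on any other shape mismatch instead of warning, and it doesn’t check dtypes. It also illustrates the point above: the result is computed into fresh memory and then copied into out, so out’s memory is never used for the computation itself.

```python
import functools
from typing import Optional

import torch

def add_out_param(fn):
    """Hypothetical decorator giving a reference a keyword-only out argument."""
    @functools.wraps(fn)
    def wrapper(*args, out: Optional[torch.Tensor] = None, **kwargs):
        result = fn(*args, **kwargs)
        if out is None:
            return result
        if out.numel() == 0:
            # Zero-element outs may be resized without complaint
            out.resize_(result.shape)
        elif out.shape != result.shape:
            # Stricter than eager PyTorch, which currently warns and resizes
            raise RuntimeError(
                f"out has shape {tuple(out.shape)}, expected {tuple(result.shape)}"
            )
        # The result already lives in fresh memory; copy it into out inplace
        out.copy_(result)
        return out
    return wrapper

@add_out_param
def add(a, b):
    return torch.add(a, b)

out = torch.empty(0)
add(torch.tensor((1, 2, 3.)), torch.tensor((4, 5, 6.)), out=out)
print(out)  # tensor([5., 7., 9.])
```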

Registering Python References as Decompositions

Python reference implementations for operators are only as useful as they are usable, and today in PyTorch we have several mechanisms where you can start experimenting with them:

  • Call the implementations in the private torch._refs module directly
  • Set torch._prims.context.TorchRefsMode (see details here) to remap calls to PyTorch operations so they call the references instead. The context can be strict (if there’s no reference for an operation, the operation will fail) or not (the operation will run as usual if there is no reference)
  • Look them up in torch._decomps.decomposition_table. References for ATen operators are registered as decompositions using the @register_decomposition decorator. __torch_dispatch__ users, like AoTAutograd, use the references this way…

Remember to register your refs as decompositions when writing them! It doesn’t happen automatically; you have to add the decorator.
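The sketch below illustrates why registration has to be explicit, using a hypothetical dictionary-based registry (reference_table, register_reference, and call_op are illustrative names, not PyTorch APIs). It also mimics the strict and non-strict fallback behavior described above.

```python
from typing import Callable, Dict

import torch

# Hypothetical registry standing in for a decomposition table
reference_table: Dict[str, Callable] = {}

def register_reference(op_name: str):
    # Registration only happens when this decorator is applied
    def decorator(fn: Callable) -> Callable:
        reference_table[op_name] = fn
        return fn
    return decorator

@register_reference("relu")
def relu_ref(a):
    return torch.where(a <= 0, 0, a)

def neg_ref(a):  # decorator forgotten, so this is never registered
    return -a

def call_op(op_name: str, *args, strict: bool = False):
    ref = reference_table.get(op_name)
    if ref is not None:
        return ref(*args)
    if strict:
        # Strict mode: a missing reference is an error
        raise KeyError(f"no reference registered for {op_name}")
    # Non-strict mode: fall back to the regular torch operation
    return getattr(torch, op_name)(*args)

t = torch.tensor([-1.0, 0.0, 2.0])
print(call_op("relu", t))  # tensor([0., 0., 2.])
print(call_op("neg", t))   # falls back to torch.neg, since neg_ref isn't registered
```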

What’s Next

The tracing with primitives program is continuing to add more Python reference implementations, and in our next update we’ll also look at the end-to-end process of how PyTorch traces of primitive operations are acquired, represented, transformed, and then executed by trace executors like nvFuser. We’ll even have a prototype you can try!

Thoughts/Questions/Suggestions? Don’t hesitate to reach out and comment below!


Does this mean we’ll be able to use custom compilers for primitives like triton-lang? It’s been a dream of mine to write pytorch in a device agnostic language like triton so that pytorch is fairly decoupled from the hardware.

Some context:


Yep, it’s absolutely a goal to better support use cases like that!


Really excited to see this move along!

The stride issue is interesting – I wonder whether @bdhirsh’s functionalization pass could help? Could maybe help with the out situation too.
