Hey PyTorch Community!
This is the first in a new series of posts about how we’re introducing “primitive operators” (“prims”) to PyTorch. Prims are relatively simple “building blocks” that can be used to implement every other PyTorch operator and module in Python. Breaking complicated operators into simpler building blocks in Python has two tremendous advantages:
- Tensor libraries, program transforms, and analysis tools can focus exclusively on prims and still support all of PyTorch’s operators.
- Writing in composable Python makes PyTorch operators easier to read, add, modify and extend.
But it also has two significant drawbacks:
- Operators written in Python and split into prims may be slower than those written in C++ with custom kernels.
- Operators written in terms of prims may have numerical accuracy issues if the prims are not carefully designed.
These challenges are why PyTorch operators are written in C++ and frequently use custom kernels today. It’s incredibly important that PyTorch operators are fast and numerically accurate.
It’s possible, however, that tracing can address the first drawback, and that clever prim design can address the second. The PyTorch ecosystem has several mechanisms for tracing programs today, including TorchScript and FX, as well as research projects like TorchDynamo and Torchy. These tracers (elaborated on in the next section) convert Python programs into sequences of PyTorch operations called “traces.” Traces are interesting because they’re easy to modify and execute, and some trace executors, like TorchScript or XLA (usable from PyTorch with PyTorch/XLA), can even execute traces faster than usual by dynamically generating new kernels optimized for a specific sequence of operations.
Representing traces using simpler primitive operations makes them even easier to modify and execute. XLA already works by reasoning almost exclusively about its own set of primitive operations. If a trace executor can optimize a trace of PyTorch’s prims, then that trace might execute even faster than PyTorch’s C++ operators today!
The second drawback simply requires careful design of PyTorch’s prims and the system that represents and transforms them. In this series we’ll discuss both the design of the prims and how they work with trace acquisition methods and trace executors to run as quickly as possible.
Tracing in PyTorch
As mentioned above, there are a variety of tracers in the PyTorch ecosystem today, but we plan to acquire our traces using TorchDynamo, a new tracer that uses Python’s frame evaluation API to automatically create FX traces from existing PyTorch programs. TorchDynamo has been in development for some time, and its developers regularly post updates. We’ll describe our integration with TorchDynamo more in future posts, but for now let’s just define what a trace is.
A “trace” is a sequence of PyTorch operations, typically constructed by observing the operations performed by a program. The FX documentation provides the following example of how a trace of a module can be constructed and represented:
class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.rand(3, 4))
        self.linear = torch.nn.Linear(4, 5)

    def forward(self, x):
        return self.linear(x + self.param).clamp(min=0.0, max=1.0)

module = MyModule()

from torch.fx import symbolic_trace
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced = symbolic_trace(module)

# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)
"""
graph():
    %x : [#users=1] = placeholder[target=x]
    %param : [#users=1] = get_attr[target=param]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {})
    %linear : [#users=1] = call_module[target=linear](args = (%add,), kwargs = {})
    %clamp : [#users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0})
    return clamp
"""
Here FX creates a trace (what it calls a “graph”) by tracing through the module MyModule with a “symbolic proxy,” a placeholder for a real tensor. TorchDynamo also produces FX traces, and these traces can be transformed (not shown, but see the FX documentation for details) and then executed with real inputs:
t = torch.randn(1, 3, 4)
symbolic_traced(t)
: tensor([[[0.0000, 0.0000, 0.1651, 0.0000, 0.0000],
           [0.0000, 0.0000, 0.0000, 0.4662, 0.0000],
           [0.0000, 0.0000, 0.3520, 0.0000, 0.0000]]], grad_fn=<ClampBackward1>)
Today these traces are typically executed by calling into PyTorch’s ATen tensor library, which runs each operation separately.
The operations in this trace (the addition, the linear module, and the clamp) can be split into prims. The call to the linear module, for instance, can be rewritten as a matrix multiply and an addition (see the sketch below), and even those operations can be split apart further. In the next section we’ll look in detail at how even the relatively simple addition operation can be written in Python to describe its broadcasting, type promotion, and computation behavior, plus how it handles its alpha and out arguments.
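Before digging into add, here’s a minimal sketch of that linear rewrite using ordinary torch calls. linear_decomposed is a hypothetical helper name, and the actual prims decomposition may differ in its details:
# Hypothetical sketch: decomposing torch.nn.Linear's forward into a matrix
# multiply and an addition (not the actual prims decomposition)
import torch

def linear_decomposed(x, weight, bias):
    # torch.nn.Linear computes x @ weight.T + bias
    result = torch.matmul(x, weight.t())
    if bias is not None:
        result = result + bias
    return result

linear = torch.nn.Linear(4, 5)
x = torch.randn(1, 3, 4)
torch.testing.assert_close(linear_decomposed(x, linear.weight, linear.bias), linear(x))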
Writing torch.add using prims
It may seem surprising that an operation as simple as torch.add can be split into even simpler primitive operations. But the add operation in PyTorch actually does several things. Let’s look at how we might write it in Python in terms of simpler operations:
def add(
    a: Union[Tensor, Number],
    b: Union[Tensor, Number],
    *,
    alpha: Optional[Number] = None,
    out: Optional[Tensor] = None
):
    """
    Python implementation of torch.add using prims
    """
    computation_dtype, result_dtype = _elementwise_dtypes(
        a, b, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH
    )
    a, b = _convert_dtype(a, b, dtype=computation_dtype)
    a, b = broadcast(a, b)

    if alpha is not None:
        alpha_promotion_type = utils.dtype_to_type(computation_dtype)
        b = prims.mul(b, alpha_promotion_type(alpha))

    result = prims.add(a, b)

    result = _convert_dtype(result, dtype=result_dtype)

    if out is not None:
        out = _maybe_resize_out(out, result.shape)
        return copy_to(out, result, allow_cross_device=False)

    return result
There’s a lot going on here, and helpers like _elementwise_dtypes are independently interesting, but to keep our exposition focused let’s step through this without reviewing the details of the helpers.
The first thing this implementation does is type promotion with the following lines:
computation_dtype, result_dtype = _elementwise_dtypes(
    a, b, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.OP_MATH
)
a, b = _convert_dtype(a, b, dtype=computation_dtype)

...

if alpha is not None:
    alpha_promotion_type = utils.dtype_to_type(computation_dtype)

...

result = _convert_dtype(result, dtype=result_dtype)
In PyTorch, as in NumPy, when two tensors (arrays) with different datatypes are added together they go through “type promotion,” where a common datatype for the operation is chosen. For example:
# Type promotion examples
a = torch.tensor((1, 2, 3), dtype=torch.float32)
b = torch.tensor((4, 5, 6), dtype=torch.float64)
a + b
: tensor([5., 7., 9.], dtype=torch.float64)
a = torch.tensor((1, 2, 3), dtype=torch.long)
b = torch.tensor((4, 5, 6), dtype=torch.float16)
a + b
: tensor([5., 7., 9.], dtype=torch.float16)
In addition to the type promotion rules applied to arguments of different types, some operations also use a different precision for their internal computation. Many operations with float16 and bfloat16 inputs, including torch.add, actually upcast their inputs to float32 to compute, then write the result back to float16 or bfloat16. This is why the implementation above computes both a computation_dtype and a result_dtype: the addition is performed in the computation_dtype, and the result is then cast back to the result_dtype. The alpha argument, if given, is handled differently because it’s a Python scalar.
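To make this concrete, here’s a simplified, hypothetical sketch of the kind of selection _elementwise_dtypes might perform for this promotion kind; the actual helper handles scalars, complex types, and many more cases:
# Hypothetical sketch of choosing a computation and result dtype;
# the real _elementwise_dtypes helper is more general
import torch

def elementwise_dtypes_sketch(a, b):
    result_dtype = torch.result_type(a, b)  # standard elementwise type promotion
    # low-precision floating point inputs are upcast to float32 for the computation
    if result_dtype in (torch.float16, torch.bfloat16):
        return torch.float32, result_dtype
    return result_dtype, result_dtype

a = torch.randn(3, dtype=torch.float16)
b = torch.randn(3, dtype=torch.float16)
print(elementwise_dtypes_sketch(a, b))  # (torch.float32, torch.float16)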
After the initial type promotion comes broadcasting. Broadcasting happens when two tensors have different shapes that can be broadcast to a common shape by adding new outermost dimensions or expanding existing dimensions of length one. For example:
# Broadcasting example
a = torch.randn((3, 1, 4))
b = torch.randn((5, 4))
torch.add(a, b).shape
: torch.Size([3, 5, 4])
Broadcasting is handled by the call to the broadcast helper.
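A helper like this one could implement that step using existing PyTorch calls; this is a sketch under that assumption, not the actual implementation:
# Hypothetical sketch of a broadcast helper built from existing PyTorch calls
import torch

def broadcast_sketch(a, b):
    # compute the common broadcasted shape, then expand both inputs to it
    common_shape = torch.broadcast_shapes(a.shape, b.shape)
    return a.expand(common_shape), b.expand(common_shape)

a = torch.randn(3, 1, 4)
b = torch.randn(5, 4)
a_b, b_b = broadcast_sketch(a, b)
print(a_b.shape, b_b.shape)  # torch.Size([3, 5, 4]) torch.Size([3, 5, 4])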
The actual computation is handled by an optional call to prims.mul and a call to prims.add. This is the first time we’re seeing primitive operations used directly. These operations do not perform type promotion, broadcasting, or computation-datatype selection; they simply perform the requested operation on the inputs as they’re provided. And prims.add, unlike the torch.add operation, does not have an alpha parameter.
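Conceptually, a primitive elementwise operation behaves something like the following hypothetical sketch; the real prims.add lives inside PyTorch and is not implemented this way:
# Hypothetical sketch of the contract a primitive add enforces
import torch

def prims_add_sketch(a, b):
    # no type promotion: inputs must already have the same dtype
    assert a.dtype == b.dtype, "primitive ops expect matching dtypes"
    # no broadcasting: inputs must already have the same shape
    assert a.shape == b.shape, "primitive ops expect matching shapes"
    # stand-in for the underlying elementwise kernel
    return torch.add(a, b)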
Finally, if the out argument is provided it may be resized with _maybe_resize_out before being copied to with prims.copy_to. prims.copy_to actually implements NumPy-like “safe casting” that prevents unexpected data conversions, as shown here:
# Safe casting example
a = torch.randn((3, 1, 4))
b = torch.randn((5, 4))
out = torch.empty((3, 5, 4), dtype=torch.long)
torch.add(a, b, out=out)
: RuntimeError: result type Float can't be cast to the desired output type Long
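A closely related check is exposed as torch.can_cast, which can be queried directly to see the same NumPy-like rule at work:
# Querying the safe casting rule directly
print(torch.can_cast(torch.float32, torch.long))     # False: would discard fractional values
print(torch.can_cast(torch.float32, torch.float64))  # True: widening is safe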
Writing torch.add in Python as a series of simpler operations makes its type promotion, broadcasting, and internal computation behavior clear. Calling all these operations one after another, however, is much slower than just calling torch.add today, because torch.add uses precompiled kernels that fuse the data conversions, the multiplication and addition, and the final copy to the out tensor. Getting similar performance from our Python implementation requires tracing.
Tracing the Python
Creating a program by tracing is very different from eagerly executing a series of operations. Traces are simply sequences of PyTorch operations, and because they don’t have helper functions or control flow they’re relatively easy to analyze, transform, and optimize.
Continuing with our Python implementation of torch.add from before, let’s consider calling it with just two float tensors that have the same shape. On these inputs our program needs almost none of the operations described above: there’s no type promotion, no broadcasting, no multiplication with alpha, and no copying to an out tensor. When traced, the program simplifies to just a call to prims.add:
# trace for Python add
# a=torch.randn((4, 5)), b=torch.randn((4, 5))
result = prims.add(a, b)
return result
Tracing lets us remove unnecessary operations, and executing this trace is just as fast as calling torch.add directly. On other inputs, however, more operations may appear in the trace:
# trace for Python add
# a=torch.randn((4, 5), dtype=torch.float16), b=torch.randn((4, 5))
a = prims.convert_element_type(a, torch.float32)
result = prims.add(a, b)
return result
This second trace, if executed operation by operation, is already slower than torch.add, because converting a from float16 to float32 requires launching an extra kernel.
That performance can be recovered if the trace executor fuses the two kernels together, just like torch.add does. This is the kind of fusion that NVIDIA’s experimental nvFuser and XLA can perform, and these simplified traces, which better represent what an operation like torch.add is actually doing, are easier for those systems to reason about. Internally, nvFuser and XLA have their own even more primitive components that represent hardware details; without a simplified trace like the ones above that accurately captures all of torch.add’s semantics, they would have to implement that same logic themselves before optimizing. We’ll describe nvFuser’s own primitives in more detail in a future update.
Summary and What’s Next
In this first “tracing with primitives” update we introduced the concept of “primitive operations,” simple building blocks that can describe all of PyTorch’s existing modules and operators in Python. These primitives are incredibly interesting because systems that support them can then support all of PyTorch, and because they’re written in Python they’re easy to read, add, modify, and extend. We looked in detail at how torch.add can be written in Python and split into even simpler operations, and how that Python implementation would appear in traces for different inputs.
Without tracing, writing all of PyTorch’s operations in Python and using these prims would be slow, but with tracing and clever trace executors like nvFuser we expect it to be as fast as, if not faster than, PyTorch’s existing operators. That said, we’re not planning on getting rid of PyTorch’s existing implementations! While we’re excited about tracing and letting users use prims, we still want PyTorch’s eager mode to be great and fast. We think the development of primitives is interesting for users who work exclusively in eager mode, too, since it should be easier to read an operation’s corresponding Python to understand how it works in detail, and that Python can still be copied and modified to facilitate writing new operations.
In future updates we’ll elaborate more on our prims and their design philosophy, plus focus more on transforming and optimizing traces. An experimental set of primitives is in development now and will be available in PyTorch soon.
If you have thoughts, questions, or suggestions, please comment below! We look forward to hearing from you.