Recently, there has been a surge of interest in addressing PyTorch’s operator problem, ranging from Zachary Devito’s MinTorch to various efforts from other PyTorch teams (Frontend, Compiler, etc.). All of these efforts try to address the same problem:
PyTorch’s operator surface is too large
Specifically, there are 2055 entries in native_functions.yaml (as of this post), and in many cases, the amount of work scales linearly with that count. Here are some examples:
- Vmap: We need to implement batching rules for each one of these operators.
- NNC/XLA/Any other compiler: We need to implement lowerings for each one of these operators.
- Future Model Parallel efforts: We need to implement model-parallelism for each one of these operators.
- MetaTensor: We need to implement shape propagation rules for each one of these operators.
Given that these cover some of the most important projects facing PyTorch today, and given that competing frameworks like Jax have ~200 operators, I would consider resolving PyTorch’s operator problem to be one of the most important challenges facing PyTorch today. Just imagine how much easier it would be to implement vmap, metatensor, namedtensors, or anything else if we only needed to worry about 200 operators instead of 2000.
Exacerbating the problem, our operator surface grows with every release.
However, despite the importance of this problem and the amount of people interested in it, there remains a lot of uncertainty about where our operators come from. For example, how many of our 2055 operators are pointwise operators? How many of them are overloads? How much of the reason that we have more operators than Jax is because we simply support more functionality?
Funnily enough, the reason these questions are so hard to answer is because we have so many operators. Nobody would spend hours combing through the operators to classify them, right?
Right?
A Taxonomy of PyTorch Operators (2055)
The above is an overview of how the 2055 PyTorch operators (+ 143 operators that only have overloaded versions and are thus double counted) break down. Let’s dive into the data.
Overloads (840)
If we remove the overloaded versions of operators, we drop from 2055 operators down to the much more manageable 1215 (which means 840 entries exist only as overloads). Of the 983 different overload names, 306 are out versions of operators. This is the full list of how the overloads break down. Although there are some major contributors, such as out variants, scalar variants, and grad_input variants, there is also a fairly long tail of miscellaneous overload types.
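For reference, here’s roughly the kind of counting I mean, as a minimal sketch in Python (the path to native_functions.yaml and the “out variant” heuristic are my assumptions, not the exact script used for the numbers above):

```python
import yaml  # pip install pyyaml

# Each entry's `func` field looks like "add.Tensor(Tensor self, ...) -> Tensor";
# the name before "(" is "base_name.overload_name".
with open("aten/src/ATen/native/native_functions.yaml") as f:
    entries = yaml.safe_load(f)

names = [e["func"].split("(")[0] for e in entries]            # e.g. "add.Tensor"
base_names = {n.split(".")[0] for n in names}                 # e.g. "add"
overload_names = [n.split(".", 1)[1] for n in names if "." in n]

# Rough heuristic: an overload is an "out" variant if its overload name mentions "out".
out_variants = sum(1 for o in overload_names if "out" in o)

print(f"total entries:     {len(names)}")
print(f"unique base names: {len(base_names)}")
print(f"out variants:      {out_variants}")
```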
Base Operators (1215 total)
Now, we’re left with a list of 1215 unique operators. These break down as follows - note that the categories are disjoint. That is, if an operator is convolution_backward, I’m only going to count it for convolution.
Convolution Operators (67)
Some people might remember that I said there were 89 convolution operators previously. This number is different since this is the count after we remove overloads. This is a list of all the convolution operators. The list is primarily bloated by 1/2/3d, different backends, and backwards.
Pooling (42) / Batch Norm (15)
As with conv, these lists are similarly bloated.
Private (233)
This is a list of all private operators. These primarily break down into 1. foreach operators (72), 2. implementations that other operators decompose into during the autograd pass, and 3. miscellaneous things (_add_batch_dim, etc.).
In-Place (183)
These are all operators like relu_, add_, etc.
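The classification here is mostly a naming convention: in-place ops end with a single trailing underscore. A quick heuristic (my own guess at the rule, not anything official) for mapping them back to their base operator:

```python
from typing import Optional

def inplace_base(name: str) -> Optional[str]:
    # In-place ops follow the trailing-underscore convention:
    # relu_ -> relu, add_.Tensor -> add. Dunder ops like __and__ don't count.
    base = name.split(".")[0]
    if base.endswith("_") and not base.endswith("__"):
        return base[:-1]
    return None

assert inplace_base("relu_") == "relu"
assert inplace_base("add_.Tensor") == "add"
assert inplace_base("relu") is None
```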
Backward (78)
This is a list of all backward operators that aren’t already covered by the other categories.
Weird Stuff (40)
I also filtered out a bunch of operators with unusual properties, e.g. they 1. don’t return tensors, or 2. are backend-specific. These are the operators that I filtered out.
Core (557)
Now, we’re finally left with 557 operators that are, essentially, the core of PyTorch’s functionality. Modulo some weird/private operators plus conv/batch norm/pooling, every other operator can be related back to one of these core 557 operators, whether through overloads, backwards, or in-place variants.
So, how do these break down? At this point, there were no obvious ways to break it down further programmatically, so I went through and manually annotated each function. Some of these were guesses, so feel free to look through my annotations here!
Alias (40)
First of all, there are 40 operators that are simply aliases of other operators. However, I didn’t see any easy way of determining this programmatically. It would be nice if there were an easy way of mapping all aliases back to their original operator.
Composite Reduction (78)
These are all the operators that I think could be replaced by a generalized reduction operator. This includes ops like all, sum, binary_cross_entropy, or layer_norm. It does not include things like median, which I didn’t see how to implement with a general reduction.
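To make “generalized reduction” concrete, here’s a toy sketch (a Python loop purely to illustrate the semantics; an actual primitive would be a single fused kernel) showing sum and all expressed through the same reduce primitive:

```python
import torch

def generalized_reduce(x, dim, init, combine):
    # Fold `combine` over one dimension, starting from `init`.
    acc = torch.full_like(x.select(dim, 0), init)
    for i in range(x.size(dim)):
        acc = combine(acc, x.select(dim, i))
    return acc

x = torch.randn(4, 5)
# sum and all as instances of the same primitive:
assert torch.allclose(generalized_reduce(x, 1, 0.0, torch.add), x.sum(dim=1))
b = x > 0
assert torch.equal(generalized_reduce(b, 1, True, torch.logical_and), b.all(dim=1))
```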
Primitive Pointwise (50)
These are all the pointwise operators that I couldn’t easily see being implemented as compositions of other pointwise operators. This includes ops like tan, lgamma, or special_i1.
Note that it’s very possible this list could be reduced further; it’s merely meant to give a general sense of how many operators we need to maintain functionality.
Composite Pointwise (87)
These are all the operators I think are implementable in terms of primitive pointwise ops. This includes ops like maximum, elu, dropout, etc.
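For example, here’s one way elu could be decomposed into primitive pointwise ops (a sketch of the idea, not how ATen actually implements it):

```python
import torch
import torch.nn.functional as F

def elu_decomposed(x, alpha=1.0):
    # elu(x) = x if x > 0, else alpha * (exp(x) - 1),
    # built only from primitive pointwise ops: expm1, mul, where.
    return torch.where(x > 0, x, alpha * torch.expm1(x))

x = torch.randn(100)
# Loose tolerance to allow for minor numerical differences vs. the fused kernel.
assert torch.allclose(elu_decomposed(x), F.elu(x), atol=1e-6)
```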
Composite Matmul (13)
These are the operators that contain a matmul in them. This includes einsum, matmul, or addmm.
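Similarly, addmm is just a matmul primitive plus some pointwise ops (again, a sketch of the decomposition, not ATen’s actual kernel):

```python
import torch

def addmm_decomposed(bias, mat1, mat2, beta=1.0, alpha=1.0):
    # addmm(bias, mat1, mat2) = beta * bias + alpha * (mat1 @ mat2)
    return beta * bias + alpha * torch.mm(mat1, mat2)

x = torch.randn(3, 5)
a = torch.randn(3, 4)
b = torch.randn(4, 5)
assert torch.allclose(addmm_decomposed(x, a, b), torch.addmm(x, a, b))
```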
View/Reshape (70)
I’m stretching the definition here, but I put in this category all ops that primarily consist of copying tensors into other shapes or creating views from existing tensors. This includes operators like clone, cat, diag, or tile.
I didn’t subdivide which ones I thought could be implemented in terms of the other ones, but I suspect many could.
Factory (39)
I put in this category all ops that return tensors and either 1. don’t take tensors as input, or 2. only use the input tensors as “metadata” for the resulting output. This includes ops like randn, rand_like, bartlett_window, and poisson (maybe stretching it a little for poisson).
Named (5) / Complex (8) / Linalg (31) / Sparse (13) / FFT (20) / RNN (12) / Quantization (11) / Scatter + Gather (15) / FBGemm (7)
These are all the ops that are used for particular categories. Some examples are:
- Named: rename, align_to, refine_names
- Complex: view_as_real, real, imag
- Linalg: inverse, logdet, matrix_rank, symeig
- Sparse: smm, sparse_csr_tensor, indices
- FFT: fft_fft, fft_ifft, fft_hfft
- RNN: lstm, rnn_tanh, quantized_lstm_cell
- Quantization: quantize_per_tensor, int_repr, fake_quantize_per_tensor_affine
- Scatter + Gather: index, permute, scatter
- FBGemm: fbgemm_linear_int8_weight_fp32_activation, fbgemm...
So, what is our real operator surface?
In some sense, the original count (2055) is our real operator surface. Regardless of whether it’s an alias or an overload, it still requires some amount of effort to handle. For example, look at the number of lines in functorch’s binary batching rules that need to deal with Scalar overloads, or the lines + effort needed to determine how to unify sum and sum.dim_IntList. This stuff wasn’t that difficult, but it did require some amount of brainpower and made the experience more painful than it needed to be.
If you assume that all overloads are expressing fundamentally similar behavior, then we’re closer to 1215 operators.
If we restrict ourselves to only counting user-facing operators, that’s closer to 557 + some convolution/pooling/batch norm operators.
There’s one more simplification we can make that I neglected in the above analysis. In PyTorch, we have the notion of a CompositeImplicitAutograd operator: a PyTorch operator that is already implemented in terms of other operators. In most cases, that means we’re able to decompose these operators into other PyTorch operators with no issue.
If we remove the core operators that are composite, we’re down to only 304 operators.
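Roughly, these can be counted from native_functions.yaml as well; below is my approximation of the rule (the real source of truth is the dispatcher/codegen, and the path and dispatch-key check are assumptions):

```python
import yaml  # pip install pyyaml

with open("aten/src/ATen/native/native_functions.yaml") as f:
    entries = yaml.safe_load(f)

def is_composite_implicit_autograd(entry) -> bool:
    # Approximation: an entry with no dispatch section (or one keyed only on
    # CompositeImplicitAutograd) decomposes into other operators.
    dispatch = entry.get("dispatch")
    if dispatch is None:
        return True
    return set(dispatch) == {"CompositeImplicitAutograd"}

composite = [e for e in entries if is_composite_implicit_autograd(e)]
print(f"{len(composite)} of {len(entries)} entries look composite")
```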
TL;DR: 2055 total ops, 1312 after removing composite operators, 703 after removing in-place/out variants, 376 “core” operators
Conclusion + Takeaways
So, there you go. Unfortunately, looking at this breakdown, we can see that PyTorch’s operator problem is not something that can be resolved by tackling only one part of it. A couple of places on my mind to look:
- Scalar overloads. We have about 100 overloads that mention Scalar in them. However, I noticed that next to many of these entries, there’s a note that says “for C++ only, until we have conversion from C++ numbers to Tensor” (dated 2 years ago). Has that “until” happened yet? Can we make it happen?
- Overloads for BC(?). Gregory Chanan mentioned that a significant number of our overloads exist only because we’re unable to add new default arguments for TS/Mobile compatibility reasons. That sounds like a problem that should be resolvable!
- Convolution Operators. We have 89 total and 67 unique ones, and all of them are part of our operator surface. This shouldn’t be the case, and I’m looking forward to seeing somebody tackle this. The same goes for pooling operators (of which we have 42 unique ones).
- Composite pointwise/reductions. It’s possible to break these down into their constituent operators; the primary issue is how we can recover performance. Zachary Devito’s MinTorch has one proposal, and perhaps there are others as well. Note that for some uses, performance doesn’t actually matter! For example, if you just want to propagate shapes or other metadata (see the sketch below).
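As a quick illustration of that last point, shapes and dtypes can be propagated through a decomposition without doing any real compute, assuming the ops involved support the meta backend:

```python
import torch

# "meta" tensors carry only metadata (shape, dtype, device), so running a
# decomposition on them propagates shapes without touching real data.
x = torch.empty(32, 128, device="meta")
w = torch.empty(64, 128, device="meta")

# A composite op written as a matmul primitive plus a pointwise primitive:
out = torch.relu(x @ w.t())
print(out.shape, out.dtype)  # torch.Size([32, 64]) torch.float32
```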