Where do the 2000+ PyTorch operators come from?: More than you wanted to know

Recently, there has been a surge of interest in addressing PyTorch’s operator problem, ranging from Zachary Devito’s MinTorch to various efforts from other PyTorch teams (Frontend, Compiler, etc.). All of these try to address the same problem:

PyTorch’s operator surface is too large

Specifically, there are 2055 entries in native_functions.yaml (as of this post), and in many cases, the amount of work a project needs to do scales linearly with that number. Here are some examples (a sketch of what one such per-operator rule looks like follows the list):

  1. Vmap: We need to implement batching rules for each one of these operators.
  2. NNC/XLA/Any other compiler: We need to implement lowerings for each one of these operators.
  3. Future Model Parallel efforts: We need to implement model-parallelism for each one of these operators.
  4. MetaTensor: We need to implement shape propagation rules for each one of these operators.
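To make the “scales linearly” point concrete, here is a minimal sketch of the kind of per-operator rule each of these projects has to write: in this case, a MetaTensor-style shape propagation rule for mm. This is my illustration, not PyTorch’s actual registration API:

```python
import torch

def meta_mm(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # mm: (n, k) @ (k, m) -> (n, m). Only metadata is produced, no compute.
    assert a.ndim == 2 and b.ndim == 2 and a.shape[1] == b.shape[0]
    return torch.empty((a.shape[0], b.shape[1]), dtype=a.dtype, device="meta")

a = torch.empty(4, 5, device="meta")
b = torch.empty(5, 3, device="meta")
print(meta_mm(a, b).shape)  # torch.Size([4, 3])
```

An analogous hand-written rule is needed for add, conv2d, cumsum, and every other operator, which is why the work grows with the size of the operator surface.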

Given that these cover some of the most important projects facing PyTorch today, and given that a framework like our competition, Jax, has ~200 operators, I would consider resolving PyTorch’s operator problem to be one of the most important challenges facing PyTorch today. Just imagine how much easier it would be to implement vmap, metatensor, namedtensors, or anything else if we only needed to worry about 200 operators instead of 2000.

Exacerbating the problem, our operator surface grows with every release.


However, despite the importance of this problem and the number of people interested in it, there remains a lot of uncertainty about where our operators come from. For example, how many of our 2055 operators are pointwise operators? How many of them are overloads? How much of the reason we have more operators than Jax is that we simply support more functionality?

Funnily enough, the reason these questions are so hard to answer is that we have so many operators. Nobody would spend hours combing through the operators to classify them, right?

Right?

A Taxonomy of PyTorch Operators (2055)

What follows is an overview of how the 2055 PyTorch operators (+ 143 operators that only have overloaded versions and are thus double counted) break down. Let’s dive into the data.

Overloads (840)

If we remove the overloaded versions of operators, we drop from 2055 operators down to the much more manageable 1215 (meaning 840 entries exist only as overloads). Of the 983 different overload names, 306 are out versions of operators. This is the full list of how the overloads break down. Although there are some major contributors, such as out variants, scalar variants, and grad_input variants, there is also a fairly long tail of miscellaneous overload types.
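If you want to reproduce this kind of count yourself, here is a rough sketch (assuming a local PyTorch checkout and PyYAML; this is not the exact script I used) of how to tally entries and overload names from native_functions.yaml:

```python
from collections import Counter
import yaml

# Adjust the path to wherever your PyTorch checkout lives.
path = "pytorch/aten/src/ATen/native/native_functions.yaml"
with open(path) as f:
    entries = yaml.safe_load(f)

# Each entry's schema looks like "add.Tensor(Tensor self, Tensor other, ...) -> Tensor".
names = [e["func"].split("(")[0] for e in entries]
base_names = {n.split(".")[0] for n in names}
print(f"{len(names)} total entries, {len(base_names)} unique base operators")

# Which overload names are most common (out, Scalar, grad_input, ...)?
overloads = Counter(n.split(".")[1] for n in names if "." in n)
print(overloads.most_common(10))
```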

Base Operators (1215 total)

Now, we’re left with a list of 1215 unique operators. These break down as follows - note that the sets are disjoint. That is, if an operator is convolution_backward, I’m only going to count it under convolution.

Convolution Operators (67)

Some people might remember that I said there were 89 convolution operators previously. This number is different because it is the count after we remove overloads. This is a list of all the convolution operators. The list is primarily bloated by 1d/2d/3d variants, different backends, and backward variants.

Pooling (42) / Batch Norm (15)

As with conv, these lists are similarly bloated.

Private (233)

This is a list of all private operators. These primarily break down into 1. foreach operators (72), 2. implementations of various operators (usually what those operators decompose into during the autograd pass), and 3. random things (_add_batch_dim, etc.)

In-Place (183)

These are all operators like relu_, add_, etc.

Backward (78)

This is a list of all backward operators that aren’t already covered by the other categories.
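The private, in-place, and backward buckets above mostly fall out of simple name-based filters. A self-contained sketch of roughly how that slicing can be done (my approximation, not the exact rules behind the counts above):

```python
import yaml

with open("pytorch/aten/src/ATen/native/native_functions.yaml") as f:
    names = {e["func"].split("(")[0].split(".")[0] for e in yaml.safe_load(f)}

private = {n for n in names if n.startswith("_")}                      # _foreach_add, _add_batch_dim, ...
inplace = {n for n in names if n.endswith("_")} - private              # relu_, add_, ...
backward = {n for n in names if "backward" in n} - private - inplace   # tanh_backward, ...
print(len(private), len(inplace), len(backward))
```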

Weird Stuff (40)

I also filtered out a bunch of operators that, for example, 1. don’t return tensors, or 2. are backend specific. These are the operators that I filtered out.

Core (557)

Now, we’re finally left with 557 operators that are, essentially, the core of PyTorch functionality. Modulo some weird/private operators + conv/batch norm/pooling, all other operators can be related to these core 557 operators, whether through overloads, backward variants, or in-place variants.

Essentially, this is what constitutes the core of PyTorch’s functionality.

So, how do these break down? At this point, there was no obvious way to break them down further programmatically, so I went through and manually annotated each function. Some of these were guesses, so feel free to look through my annotations here!

Alias (40)

First of all, there are 40 operators that are simply aliases of other operators. However, I didn’t see any easy way of determining this programmatically. It would be nice if there were an easy way of mapping all aliases back to their original operator.

Composite Reduction (78)

These are all the operators that I think could be replaced by a generalized reduction operator. This includes ops like all, sum, binary_cross_entropy, or layer_norm. It does not include things like median, which I did not see how to implement with a general reduction.
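To give a sense of what a “generalized reduction operator” could look like, here is a hedged sketch of the idea (my illustration, not an existing or proposed PyTorch API). Ops like binary_cross_entropy or layer_norm would additionally apply pointwise ops before/after the reduction step:

```python
import torch

def reduce(x, dims, init, combine):
    # Fold the tensor along each requested dimension with a binary combine fn.
    out = x
    for d in sorted(dims, reverse=True):
        acc = torch.full_like(out.select(d, 0), init)
        for s in out.unbind(d):
            acc = combine(acc, s)
        out = acc
    return out

x = torch.randn(3, 4)
my_sum = reduce(x, dims=[0, 1], init=0.0, combine=torch.add)
my_all = reduce(x > 0, dims=[0, 1], init=True, combine=torch.logical_and)
assert torch.allclose(my_sum, x.sum())
assert bool(my_all) == bool((x > 0).all())
```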

Primitive Pointwise (50)

These are all the pointwise operators that I couldn’t easily see being implementable as compositions of other pointwise operators. This includes ops like tan, lgamma, or special_i1. Note that it’s very possible this list could be reduced further; this is merely meant to give a general sense of how many operators we need to maintain functionality.

Composite Pointwise (87)

These are all the operators I think are implementable in terms of primitive pointwise ops. This includes ops like maximum, elu, dropout, etc.
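As a concrete example of what “composite pointwise” means here, elu can be written purely in terms of primitive pointwise ops. This follows the standard elu formula; the real ATen kernel is a fused implementation rather than this decomposition:

```python
import torch
import torch.nn.functional as F

def my_elu(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # elu(x) = x if x > 0, else alpha * (exp(x) - 1)
    return torch.where(x > 0, x, alpha * torch.expm1(x))

x = torch.randn(1000)
assert torch.allclose(my_elu(x), F.elu(x), atol=1e-6)
```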

Composite Matmul (13)

These are the operators that contain a matmul in them. This includes einsum, matmul, or addmm.

View/Reshape (70)

I’m stretching the definition here, but I put in this category all ops that primarily consist of copying tensors into other shapes or creating views of existing tensors. This includes operators like clone, cat, diag, or tile. I didn’t subdivide which ones I thought could be implemented in terms of the others, but I suspect many could.
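As one example of why I suspect many of these collapse into each other: a lot of view ops are expressible with as_strided, which is roughly the view primitive everything else could lower to. A sketch for the diagonal of a 2-D tensor (my illustration; the actual ATen implementations differ):

```python
import torch

x = torch.randn(4, 5)
# The diagonal is just a strided view: one element every stride(0) + stride(1) steps.
diag_view = x.as_strided((min(x.shape),), (x.stride(0) + x.stride(1),))
assert torch.equal(diag_view, torch.diagonal(x))
```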

Factory (39)

I put in this category all ops that return tensors and either 1. don’t take tensors as input, or 2. only use the input tensors as “metadata” for the resulting output. This includes ops like randn, rand_like, bartlett_window, and poisson (maybe stretching it a little for poisson).
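“Only using the input tensors as metadata” means something like the following: rand_like never reads its input’s values, so (ignoring layout and memory-format arguments) it can be phrased in terms of a more primitive factory op. A sketch:

```python
import torch

def my_rand_like(x: torch.Tensor) -> torch.Tensor:
    # Only x's shape/dtype/device are consulted, never its values.
    return torch.rand(x.shape, dtype=x.dtype, device=x.device)

x = torch.empty(2, 3, dtype=torch.float64)
y = my_rand_like(x)
assert y.shape == x.shape and y.dtype == x.dtype and y.device == x.device
```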

Named (5) / Complex (8) / Linalg (31) / Sparse (13) / FFT (20) / RNN (12) / Quantization (11) / Scatter + Gather (15) / FBGemm (7)

These are all the ops that are used for particular categories. Some examples are:

  • Named: rename, align_to, refine_names
  • Complex: view_as_real, real, imag
  • Linalg: inverse, logdet, matrix_rank, symeig
  • Sparse: smm, sparse_csr_tensor, indices
  • FFT: fft_fft, fft_ifft, fft_hfft
  • RNN: lstm, rnn_tanh, quantized_lstm_cell
  • Quantization: quantize_per_tensor, int_repr, fake_quantize_per_tensor_affine
  • Scatter + Gather: index, permute, scatter
  • FBGemm: fbgemm_linear_int8_weight_fp32_activation, fbgemm........

So, what is our real operator surface?

In some sense, the original count (2055) is our real operator surface. Regardless of whether an operator is an alias or an overload, it still requires some amount of effort to handle. For example, look at the number of lines in functorch’s binary batching rules that need to deal with Scalar overloads, or the lines + effort needed to determine how to unify sum and sum.dim_IntList. This stuff wasn’t that difficult, but it did require some amount of brainpower and made the experience more painful than it needed to be.
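For reference, here is roughly how those overloads show up from Python; each line below hits a different entry in native_functions.yaml even though the logical operation is the same (overload names per the schemas in the yaml file):

```python
import torch

x = torch.randn(2, 3)
torch.sum(x)                            # the no-dim "sum" overload -> 0-dim tensor
torch.sum(x, dim=1)                     # "sum.dim_IntList"         -> shape (2,)
torch.add(x, x)                         # "add.Tensor"
torch.add(x, x, out=torch.empty(2, 3))  # "add.out"
```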

If you assume that all overloads are expressing fundamentally similar behavior, then we’re closer to 1215 operators.

If we restrict ourselves to only counting user-facing operators, that’s closer to 557 + some convolution/pooling/batch norm operators.

There’s one more simplification we can make that I neglected in the above analysis. In PyTorch, we have the notion of a CompositeImplicitAutograd operator. These are PyTorch operators that are already implemented in terms of other operators. In most cases, that means we’re able to decompose these operators into other PyTorch operators with no issue.
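To illustrate the flavor of a CompositeImplicitAutograd operator, here is a hand-written analogue (my illustration, not the actual ATen source of any operator): an op defined purely in terms of other operators, so autograd, and anything else that can trace through it, needs no dedicated rule for it:

```python
import torch

def softplus_like(x: torch.Tensor) -> torch.Tensor:
    # Built entirely out of existing ops, so the backward comes "for free"
    # by differentiating through log1p and exp.
    return torch.log1p(torch.exp(x))

x = torch.randn(4, requires_grad=True)
softplus_like(x).sum().backward()
print(x.grad)  # equals sigmoid(x), with no hand-written backward rule
```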

If we remove the core operators that are composite, we’re down to only 304 operators.

TL;DR: 2055 total ops, 1312 after removing composite operators, 703 after removing in-place/out variants, 376 “core” operators

Conclusion + Takeaways

So, there you go :slight_smile: Unfortunately, looking at this breakdown, we can see that PyTorch’s operator problem is not something that can be resolved by tackling only one part of it. A couple of places on my mind to look:

  1. Scalar overloads. We have about 100 overloads that mention Scalar in them. However, I noticed that next to many of these entries, there’s a note that says “for C++ only, until we have conversion from C++ numbers to Tensor” (dated 2 years ago). Has that “until” happened yet? Can we make it happen?
  2. Overloads for BC(?). Gregory Chanan mentioned that a significant number of our overloads exist only because we’re unable to add new default arguments for TS/Mobile compatibility reasons. That sounds like a problem that should be resolvable!
  3. Convolution Operators. We have 89 total, and 67 unique ones. And all of them are part of our operator surface. This shouldn’t be the case, and I’m looking forward to seeing somebody tackle this. The same goes for pooling operators (of which we have 42 unique ones).
  4. Composite pointwise/reductions. It’s possible to break these down into their constituent operators; the primary issue is how we can recover performance. Zachary Devito’s MinTorch has one proposal, and perhaps there are others as well. Note that for some uses, performance doesn’t actually matter! For example, if you just want to propagate shapes or other metadata.

This is great work and gives great visibility into the layers living under the PyTorch op surface (like torch-mlir).

I actually noticed that I’d prefer scalars and tensors to be handled in separate functions instead of one. For example, I noticed that at::add(Tensor,Tensor) was dispatched instead of at::add(Tensor,Scalar) when the second tensor was actually a scalar, and I need to handle it differently. Why is it important? Because a tensor naturally lives in the device’s storage (GPU memory) while a scalar is always on the CPU. So instead of letting the dispatcher handle it, I need to check in my code whether the 2nd tensor is actually a CPU scalar tensor and apply an entirely different kernel to the one that is a scalar.

Also, great summary, and it is very relevant for me (I work on an OpenCL backend)

Can you share the actual lists of operators you extracted (so I can understand the scope of work I need to do)?

Can you share the actual lists of operators you extracted (so I can understand the scope of work I need to do)?

I think you’re just looking for this spreadsheet? FuncTorch Batching Rules Tracker - Google Sheets

@Chillee More than the current set of ops in the spreadsheet, I would be interested in the scripts you used to generate this list and the graphs. If you can share them - either as part of PyTorch itself or as something standalone - I think this needs to be tracked (as an FYI) with every commit going into PyTorch, so that downstream consumers like @artyom-beilis for OpenCL or us for #torch-mlir are aware of these changes and can start enabling the corresponding support. Thanks again for the great work.


The code for generating this isn’t actually that bad; see functorch/gen_data.py at main · pytorch/functorch · GitHub

It makes use of some (relatively) new APIs to print out, say, composite ops, as well as some manual annotations, but is otherwise just a bunch of if/else rules.

There have been some discussions about providing better annotations on ops, more integrated into PyTorch core, to make it easier for developers to interface with PyTorch, but I’m not sure about concrete designs for that yet.

If there were some more details on what specifically you’d like to keep track of, I might have a better idea of how I can adjust these scripts to help out.

oh this is great. let me poke around and get back with specifics or PRs :smile:

The long-term goal is to validate lowering of all (or most) PyTorch ops down to torch-mlir so we can support any PyTorch workload across CPUs/GPUs/accelerators.

This is great @Chillee! I agree - having this large operator surface area presents a lot of challenges. Speaking for Torch-MLIR, having these normalized or defined out of existence would be ideal.

I think a lot of this comes from being designed around op-by-op device dispatches, rather than being designed from the ground up around compiler fusion support and other things that compose nicely in a compiler. So adding new ops ends up being the main extension point. We went down this path in TensorFlow as well (for similar reasons, I think), with a similar set of downsides.

The analogy I like to think of is: imagine if your C++ compiler didn’t support inlining (kind of analogous to fusion) – how much more convoluted our C++ code would be to work around that, manually inlining everything everywhere, having to use macros, etc.

@Chillee
Thanks for this post, it’s really helpful!
How’d you get the count of operators in JAX?

They were primarily taken from this page: jax.lax package — JAX documentation

This is great @Chillee. I cannot agree more, speaking from the standpoint of an AI chip manufacturer. The categorization methodology in this post is based on the PyTorch code base, right? However, maybe there are some drawbacks to the current methodology.

  • An operator can be classified into more than one type.
  • Many vague words are used, like misc, weird, and long tail.
  • There is no unified underlying categorization methodology.

Do you have any more ideas about the taxonomy, for example based on computation patterns or something else? Thank you for your great work again.


However, maybe there are some drawbacks to the current methodology.

Yes, I certainly agree. This taxonomy was primarily meant to get a high-level view of what the operators in PyTorch consist of - I certainly wouldn’t even guarantee that these tags are necessarily correct :stuck_out_tongue:

There is some current work (cc: @anjali411) on providing a more structured tagging system for operators, as well as a lot of ongoing work on providing decompositions of our operator set (see Tracing with Primitives: Update 0).

Thank you for your reply and for telling me about the ongoing work.