State of PyTorch core: September 2021 edition
There are a lot of projects currently going on in PyTorch core, and it can be difficult to keep track of all of them or how they relate to each other. Here is my personal understanding of everything that is going on, organized around the people who are working on these projects and how I think the projects relate to each other.
This document is organized into a number of categories which I found useful for grouping related projects together. Unless otherwise noted with [SHIPPED], all projects are works in progress.
Features
Flagship features
Let’s talk about really big features that are coming down the pipe. I’ve limited the selection here to those that are particularly relevant to efforts in core.
- functorch (Horace He, Richard Zou). In this half, functorch is aiming to provide composable grad and vmap combinators similar to those provided by JAX, allowing users to transform single example programs into batched programs and compose this with automatic differentiation. In the longer term, functorch also aims to provide a way to JIT compile PyTorch programs, and a general API for users to create their own, custom functional transformations on PyTorch programs. (A minimal usage sketch follows at the end of this list.)
- Sharded tensor in distributed (Pritam Damania, Shen Li). Large scale models (esp transformer models) often introduce intra-layer parallelism by sharding parameters across multiple nodes. Sharded tensor seeks to provide sharded operations natively in PyTorch to support this use case. A primary use case is training personalization/ranking models.
- Unified sharded model parallel and collectives (Zachary DeVito, James Reed). This is a continuation of the sharded tensor work, imagining an API that unifies sharding with other preexisting concepts in distributed into a single, uniform API, aka DistributedTensor.
- Lazy tensor (Will Constable). Many backends do not support eager mode execution, and it is necessary to lazily record a sequence of operations and then forward it on to the backend as a graph. This is a continuation of Alex Suhan’s LTC work, but with a new emphasis on integrating well with PyTorch core.
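To give a flavor of the grad and vmap combinators mentioned in the functorch item above, here is a minimal sketch. It assumes the separately installed functorch package; the per-example-gradient pattern shown is the standard JAX-style use of these combinators, not code taken from the project itself.

```python
import torch
from functorch import grad, vmap

# A loss written for a single example: w are the parameters, x is one input.
def loss(w, x):
    return (x @ w).sum()

w = torch.randn(3)
xs = torch.randn(5, 3)  # a batch of 5 single examples

# grad(loss) differentiates with respect to the first argument (w);
# vmap maps that gradient function over the batch dimension of xs,
# yielding per-example gradients of shape (5, 3).
per_example_grads = vmap(grad(loss), in_dims=(None, 0))(w, xs)
print(per_example_grads.shape)  # torch.Size([5, 3])
```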
Backends
Backends are alternative implementations of PyTorch operations for other hardware (or less commonly, other runtimes). These also scale with the number of operators in PyTorch, although usually in a very simple way. These typically are driven by external partners.
- [SHIPPED] XLA (Google). XLA is Google's optimizing graph compiler that TensorFlow, JAX and torch_xla use as a backend. It is the only compiler that supports executing on Google's TPU hardware offering, and XLA's first-to-market, minimal operator set makes it a popular target for hardware startups who need some graph definition to target.
- Recently, XLA has been looking to add dynamic shape support based on bounded dynamic shapes supported by TensorFlow. https://docs.google.com/document/d/1c3fdF0P5fbdvLagFzbOqSbIaIcIKGCQ47Oxdm3qHaNY/edit
- ONNXRT (Microsoft). ONNX Runtime is Microsoft's inference runtime built around the ONNX model format. Microsoft has been looking to collaborate more closely with us on PyTorch's internal integration points.
- [SHIPPED] IPEX (Intel). Intel works on CPU performance in PyTorch, and they use IPEX as a staging ground for upcoming CPU optimizations and novel memory layouts (e.g., MKLDNN layout).
- Vulkan (Facebook). Vulkan allows us to program GPUs on Android devices. Vulkan does not support Tensor views.
Alternative tensors
While not as big as the flagship features, features that add a new type of tensor to PyTorch still typically interact closely with core.
- [SHIPPED] Conjugate views (Anjali Chourdia). Conjugate views allow for a lazy, O(1) conjugation operation on complex tensors, allowing downstream kernels to choose to fuse the conjugation with the operation that follows it. (A short sketch follows at the end of this list.)
- Zero tensors (Anjali Chourdia, Alban Desmaison). Zero tensors are an immutable, O(1) representation for zero tensors of arbitrary shapes. These simplify autograd computation (which frequently needs to materialize zero tensors in cases where there is no gradient).
- Linear operators (Anjali Chourdia). Linear operators encode the linear algebra structure of tensors, allowing for optimizations based on linear algebra equalities. A preexisting, out-of-core implementation of these concepts currently exists in GPyTorch. https://gpytorch.ai/
- [SHIPPED] Meta tensors (Edward Yang). Meta tensors are tensors without any data associated with them. They can be used to do shape inference (sketched at the end of this list). Structured kernels implicitly support meta tensors; meta implementations of operators can also be written by hand.
- [SHIPPED] Crypten. CrypTensors are privacy preserving tensors whose data is not directly accessible but instead secret shared across multiple nodes. This is an out-of-tree research project built on top of PyTorch. GitHub - facebookresearch/CrypTen: A framework for Privacy Preserving Machine Learning
- Nested tensors (Christian Puhrsch). Nested tensors are tensors with irregular sizes, for working with data that doesn’t have uniform size/length. GitHub - pytorch/nestedtensor: [Prototype] Tools for the concurrent manipulation of variably sized Tensors.
- Masked tensors (Christian Puhrsch). Masked tensors rationalize reduction and normalization operators which take a mask saying what inputs are valid. RFC-0016: Masked reductions and normalizations by cpuhrsch · Pull Request #27 · pytorch/rfcs · GitHub
- Finite tensors (Edward Yang). A proof of concept tensor that is guaranteed to be finite, for debugging exploding gradients. A non-NaN tensor could be implemented similarly. Add example of FiniteTensor with __torch_dispatch__ by ezyang · Pull Request #62819 · pytorch/pytorch · GitHub
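Two of the shipped items above are easy to demonstrate directly. First, a minimal sketch of conjugate views (assuming PyTorch 1.10+, where conj() on a complex tensor returns a lazy view):

```python
import torch

x = torch.randn(3, dtype=torch.complex64)
y = x.conj()          # O(1): no data is copied, only a conjugate bit is set on the view
print(y.is_conj())    # True
z = y.resolve_conj()  # explicitly materialize the conjugated values if a kernel needs them
print(z.is_conj())    # False
```

Second, a minimal sketch of shape inference with meta tensors (operator coverage on the meta device varies by version):

```python
import torch

a = torch.empty(8, 32, device="meta")  # carries shape/dtype, but no data
b = torch.empty(32, 16, device="meta")
c = a @ b                              # runs only shape inference, no compute
print(c.shape, c.device)               # torch.Size([8, 16]) meta
```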
Infrastructure
One of the goals of PyTorch core is to provide key, low level abstractions that help us build the features described above. In this section, I’ll talk about both old and new infrastructural pieces in PyTorch, to help build a picture about what tools are available to you if you are embarking on a project.
- [SHIPPED] Dispatcher / TORCH_LIBRARY operator registration (Composability team). The C++ dispatcher is our API that allows backends to register their own implementations of operators. This registration interface is a closed universe: there is a fixed set of permissible backends hardcoded into PyTorch, with a cap on the total number of backends we can support (64). The backwards compatibility story for this API is not great; BC-compatible changes to operators translate into BC-breaking changes for backends. Registering a Dispatched Operator in C++ — PyTorch Tutorials 1.9.0+cu102 documentation
- Most backends make use of the operator registration API to register their backend-specific implementations of operators; similarly, functorch, lazy tensor, conjugate views, meta tensors and nested tensors use the dispatcher to register their customized behavior.
- [SHIPPED] Boxed fallback (Composability team). Boxed fallbacks let you write a single, generic implementation that can be used for many operators, instead of having to template metaprogram each operator individually. This comes at a slight performance cost (10%), but boxed fallbacks typically have much smaller binary size than their unboxed counterparts. pytorch/backend_fallback_test.cpp at master · pytorch/pytorch · GitHub
- functorch used to write template metaprograms for their batching rules, but switched to boxed fallback; in their words: “Very great way for handling large classes of operators”.
- Conjugate views use boxed fallback to materialize conjugated tensors before running operators that don’t support fused conjugate operations. Zero tensors are likely to use a similar mechanism to materialize traditional zero-filled tensors.
- [SHIPPED] __torch_function__ (Composability team). __torch_function__ provides the ability to override the behavior of torch namespace functions from Python. This mechanism bypasses PyTorch entirely: it effectively replaces programs that call torch.add with your_add. Programs using __torch_function__ integrate with PyTorch in a shallow way: for example, you cannot pass a __torch_function__-bearing object as a gradient in PyTorch's autograd pass. (A minimal sketch of the override pattern follows this item's sub-bullets.) Extending PyTorch — PyTorch 1.9.0 documentation
- torch.fx (not listed above) is entirely implemented using __torch_function__. I mention it here because it is a good illustration of the benefits and downsides of __torch_function__. Benefit: as __torch_function__ is an entirely Python level interposition mechanism, torch.fx can also trace dynamic sizes (as it receives all Python objects as-is) and integrate with Python-only concepts like nn.Module. Downside: you cannot get autograd differentiated programs with torch.fx.
- Crypten uses __torch_function__ to allow use of torch namespace functions with CrypTensors. Because of the lack of AD support, Crypten had to reimplement their own version of the autograd engine for their project.
- __torch_function__ provides a simple way of prototyping alternative tensors which are simply syntax sugar on top of preexisting PyTorch operations. For example, masked tensors were initially implemented in this way.
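Here is the minimal sketch promised above. This is the standard subclass override pattern from the Extending PyTorch documentation; the class name is made up for illustration:

```python
import torch

class LoggingTensor(torch.Tensor):
    # Called for every torch namespace function (and Tensor method) that sees
    # a LoggingTensor argument, before PyTorch's C++ core is ever involved.
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        print(f"intercepted {func.__name__}")
        return super().__torch_function__(func, types, args, kwargs)

x = torch.randn(3).as_subclass(LoggingTensor)
y = torch.add(x, x)  # prints "intercepted add", then runs the normal implementation
```

And since torch.fx is called out above as the flagship user of this mechanism, a short taste of it:

```python
import torch
import torch.fx

def f(x):
    return torch.relu(x) + 1

traced = torch.fx.symbolic_trace(f)
print(traced.graph)  # the recorded torch-level operations, captured at the Python layer
```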
- [SHIPPED] __torch_dispatch__ (Edward Yang, Horace He). __torch_dispatch__ lets users define custom tensor types with custom operator behavior that they implement in Python. There are two ways to compare __torch_dispatch__ to the other mechanisms described above: you can think of it as a Python version of the C++ dispatcher API, or you can think of it as a more deeply integrated variant of __torch_function__. __torch_dispatch__ is a completely open extension mechanism (users can create as many Tensor subclasses as they want). __torch_dispatch__ integrates smoothly with other PyTorch functionality (e.g., autograd, batching, automatic mixed precision); but because of this tight integration, the operators you have to implement don't necessarily correspond to the high level user API. (A minimal sketch follows this item's sub-bullets.) Dispatch to Python · Issue #59049 · pytorch/pytorch · GitHub
- For performance reasons, functorch implements all of their batching rules in C++. To permit FX style tracing of functorch programs after applying all functional transformations, their tracer uses __torch_dispatch__ to call back to FX from C++.
- There is a POC showing how to remove Crypten's autograd implementation using __torch_dispatch__ at POC: Use dispatch to Python to implement Crypten autograd by ezyang · Pull Request #290 · facebookresearch/CrypTen · GitHub. The POC shows how to combine __torch_function__ and __torch_dispatch__: __torch_dispatch__ is used to do the bulk of operator overriding, while __torch_function__ is occasionally used to override high level Python concepts when necessary (at the cost of needing to, e.g., reimplement autograd and everything else from scratch).
- __torch_dispatch__ is a good choice for prototyping alternative tensor projects that need a deeper integration with PyTorch, e.g., with autograd.
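Here is the minimal sketch promised above, modeled on the LoggingTensor used in PyTorch's own __torch_dispatch__ tests. The wrapper-subclass helper _make_wrapper_subclass and the exact form of func vary a bit across versions, so treat this as an illustration of the shape of the API rather than a pinned-down recipe:

```python
import torch
from torch.utils._pytree import tree_map

class DispatchLogger(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        # A "wrapper" subclass: the visible tensor holds no data itself;
        # the real data lives in the plain tensor stored on .elem.
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device)
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # func here is an ATen-level operator (below autograd), not a torch.* function.
        def unwrap(t):
            return t.elem if isinstance(t, DispatchLogger) else t
        def wrap(t):
            return DispatchLogger(t) if isinstance(t, torch.Tensor) else t
        print(f"aten call: {func}")
        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(wrap, out)

x = DispatchLogger(torch.randn(3))
y = x + x  # logs the underlying ATen add call
```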
- [ALMOST SHIPPED] Python mode key (Richard Zou). __torch_dispatch__ ordinarily is only triggered if a Tensor input to a function has a custom dispatch function defined. The Python mode key makes it possible to globally override the behavior of all operators, including factory functions, with a context manager, so you don't have to do value dependent dispatch. In JAX, this is referred to as omnistaging. [Reland] Add python mode by zou3519 · Pull Request #64360 · pytorch/pytorch · GitHub
- functorch can make use of the Python mode key to override the behavior of factory functions. It can be used by distributed to override factory functions in nn.Modules to return meta tensors rather than CPU tensors.
- Finite tensors are better implemented as a mode as it makes it easier to just verify that ALL operations don’t return infinity (as opposed to only tensors which are propagated as finite tensors).
- C++ open multiple dispatch (Alban Desmaison, Edward Yang). __torch_dispatch__ is a very expressive, open extensibility mechanism, but it can only be operated from Python. In some situations, we want the expressivity of __torch_dispatch__, but with the performance and portability of C++. C++ open multiple dispatch is a planned project to bring the analog of __torch_dispatch__ as a C++ API. Python-style open multiple dispatch in C++
- functorch would make use of this mechanism after it is implemented. It may also be a possible implementation strategy for other alternative tensor projects that need to be implemented solely in C++ for performance reasons. It can also serve as a way for out-of-tree backends to experiment without having to request a dispatch key from core.
- Open device registration (Brian Hirsh). The dispatcher supports only 64 dispatch keys and we are constantly butting up against this limit. Part of the reason we hit this limit is because we use O(n^2) keys to represent the Cartesian product of backends and functionality (e.g., AutogradXLA). This project seeks to represent these separately, so the number of keys scales linearly in the number of backends + functionality. An out-of-date initial exploration doc is Open device registration
- Conjugate views would switch to per-backend keys to allow backends to independently vary whether or not they support fused conjugate operations.
- It would be easier for us to give backends dispatch keys when they request it, without worrying about running out of space.
Operator decomposition / introspection (MinTorch in PyTorch)
Operator decomposition work seeks to reduce the number of operators necessary for backends to implement to get a working backend. Operator introspection seeks to make it easier to tell when multiple operators do “the same thing” so that they can be treated uniformly. A large number of ideas for what to do on this front are recorded at MinTorch in PyTorch, but here I’ll talk solely about efforts that are actively being worked on right now.
Decomposition:
- Functionalization (Brian Hirsh). Functionalization eliminates the need to implement inplace/out variants of functions; after being functionalized, a PyTorch program consists solely of functional operations (this nearly halves the number of operators you need to support). (A before/after illustration follows this item's sub-bullets.)
- It is too difficult for functorch to directly implement batching rules for inplace operations, and they intend to rely on functionalization to support PyTorch programs that have mutation.
- XLA already implements functionalization to target the XLA language (which does not support mutation); the functionalization pass would subsume this preexisting functionality and alleviate Google from needing to maintain this functionality.
- Vulkan does not support Tensor views; functionalization allows programs that mutate aliased tensors to work correctly on mobile.
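To make concrete what the pass buys you, here is a hand-written before/after illustration (the real pass operates on ATen IR, not Python source; the function names here are made up):

```python
import torch

# Before: a program with views and in-place mutation, as a user might write it.
def f(x):
    y = x.view(-1)  # y aliases x
    y.add_(1)       # in-place update, visible through x
    return x

# After functionalization (conceptually): the same values, computed with
# purely functional ops -- no views, no mutation.
def f_functionalized(x):
    y = x.reshape(-1)
    y = y.add(1)
    return y.reshape(x.shape)

a, b = torch.zeros(2, 2), torch.zeros(2, 2)
print(torch.equal(f(a), f_functionalized(b)))  # True
```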
- Convolution consolidation (Joel Schlosser). There are over 80 convolution operators, and the goal is to unify these to greatly reduce the number.
- After this work, backends would no longer have to override a special convolution_overrideable operator; they would be able to just override convolution directly. functorch and other cross-cutting features would only need to write a batching rule for convolution a few times (rather than eighty times). It will also simplify the user experience for users of __torch_dispatch__.
Introspection:
- Operator tagging (Edward Yang). Currently, native_functions.yaml is an undifferentiated mass of operators. Multiple efforts have been made to categorize the operators, but none in a durable place that would be kept up-to-date. The goal is to incorporate these categories directly into PyTorch and make them computationally relevant, so that we can maintain them and keep them up-to-date. WIP: A tagging system for operators
- Backends could make use of these tags to help understand the set of all operators in PyTorch, and what they need to cover (this was explicitly asked for by Microsoft).
- functorch made use of their earlier categorization work to help understand what kinds of different regimes operators fall into, and handle multiple operators all at once with generic implementations.
Operator authoring
Operator authoring work seeks to change how we write operators internally in PyTorch, to enable various cross-cutting use cases.
Low level (how do you actually write a kernel):
- Ufunc codegen (Edward Yang). The code written for pointwise operations using TensorIterator is very regular; the goal of ufunc codegen is to greatly reduce the boilerplate required for these operators, so that we can more easily reuse them in other contexts.
- After this work, there will be a canonical definition of all pointwise operators in PyTorch, making it easier for backends and functorch to write a single implementation that covers all pointwise operations.
- This work may make use of operator tagging to express information about pointwise operators that is necessary to drive codegen (e.g., is the result tensor the same dtype as the input tensors).
- A reduced API surface area for TensorIterator may also make it easier to refactor TensorIterator for performance.
- Authoring operators in NNC (Bert Maher, Jason Ansel). NNC is a JIT compiler, which allows it to compile more specialized implementations of operators, greatly reducing overhead and in some cases achieving asymptotic speedups. The goal of this workstream is to make it possible to directly author eager mode operators in NNC, turning on JIT compilation by default for eager mode execution.
- This project takes the “replace TensorIterator wholesale with an entirely new component” approach, whereas ufunc codegen and other efforts to refactor TensorIterator are a more conservative “incrementally improve”. One decision we need to make is whether or not TensorIterator should be replaced with an approach that is loosely coupled (go straight to NNC) or tightly coupled (refactor TensorIterator). More discussion at: Sep 2 authoring operators in NNC
- Interpreting operators directly in CUDA (Zachary DeVito). Conventionally, you have to compile a full CUDA kernel for every combination of pointwise operations you want to fuse together. Could you instead write a single interpreter in CUDA that lets you dynamically vary what operations you do, while maintaining performance?
High level (how do you put multiple kernels together):
- [NOT BEING WORKED ON] Authoring to support dynamic shapes (Richard Zou). Today in PyTorch, mechanisms that trace through PyTorch C++ lose the ability to track the provenance of integer shapes, because they are represented as int and not as a symbolic integer. This project looks to find short term updates, including rewriting some composites as Python operators and changing autograd derivative formulas to be more friendly to dynamic tracing. This proposal is purely in the discussion phase and no one is working on it right now.
- Sub-project: Support dynamic shapes in lazy tensor (Nick Korovaiko). This is specifically about making it possible to track dynamic shapes from the PyTorch frontend, by overriding size() to return a symbolic size tuple / tensor and then overriding operators to intercept calls that make use of sizes and handle symbolic inputs directly. Nick is working on this in the context of LTC, in Dynamic Shapes; we also recommended that the Google XLA team attempt this for their bounded dynamic shape support.
- Authoring composites in Python (Mike Ruberry, Natalia Gimelshein). Authoring to support dynamic shapes proposes just tactically duplicating functions in Python when necessary. To eliminate this duplication, authoring composites in Python proposes writing composite operators only in Python and then transpiling them to C++. This would make it easier to prototype new composites, and would also make it possible to eliminate intermediate dispatches when transpiled to C++. (A hypothetical flavor of such a definition is sketched after this item's sub-bullets.) [WIP] Ops in Python by mruberry · Pull Request #63643 · pytorch/pytorch · GitHub
- This project is often compared to authoring operators in NNC, but it operates at a higher level, targeting only composites, and not actual kernel definitions of operators. It is a long term approach to solving dynamic shapes because Python code is easier to abstract interpret than C++ code.
- This potentially could integrate with authoring operators in NNC or ufunc codegen by offering an easy API for representing compositions of multiple pointwise operations, and then invoking the compiler to generate the fused version of the operator.
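As a flavor of what a composite written in Python might look like (a hypothetical reference definition of addcmul; this is not the actual code from the linked PR):

```python
import torch

# addcmul is a composite of pointwise ops; a Python definition like this could be
# transpiled to C++ once, or handed to a compiler to generate a fused kernel,
# instead of being maintained by hand in C++.
def addcmul(input, tensor1, tensor2, *, value=1):
    return input + value * tensor1 * tensor2

x, a, b = torch.randn(3), torch.randn(3), torch.randn(3)
print(torch.allclose(addcmul(x, a, b, value=0.5),
                     torch.addcmul(x, a, b, value=0.5)))  # True
```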
- Authoring shape functions in Python (Elias Ellison). Shape functions are typically written in C++, but if they were written in Python and then parsed by the TorchScript frontend, they could be evaluated symbolically and, e.g., fed to Z3 or otherwise reasoned about in the compiler. (A small sketch of the style follows this item's sub-bullets.)
- This is similar to authoring composites in Python in that, in both cases, code that was previously written in C++ is written in Python; however, the nature of the code is quite different (composite operators versus shape-checking code). Additionally, the two projects use different technology for processing the Python code: authoring composites uses a lightweight transpiler based directly on Python ASTs, whereas shape functions are processed by the TorchScript compiler stack.
- Symbolic shape functions are potentially useful for XLA which needs to know how to propagate shape upper bounds; right now, they simply generalize their lowering functions on a per operator basis to propagate upper bounds. Similarly, symbolic shape functions are helpful for dynamic shapes, as they can properly propagate non-integer shapes.
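For a concrete flavor of the style (a sketch only: shape functions in this scheme are TorchScript-compatible Python over List[int]; the function below is illustrative, not copied from the actual implementation):

```python
from typing import List

def mm_shape(self: List[int], mat2: List[int]) -> List[int]:
    # A shape function for torch.mm written purely over shapes, so a compiler
    # can evaluate it symbolically instead of needing a C++ shape inference rule.
    assert len(self) == 2 and len(mat2) == 2, "mm expects 2D inputs"
    assert self[1] == mat2[0], "inner dimensions must match"
    return [self[0], mat2[1]]

print(mm_shape([4, 8], [8, 16]))  # [4, 16]
```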
Takeaways
Having written everything down here, I have some high level takeaways that I’d like people to think about:
- We seem generally understaffed in operator decomposition / introspection, with a lot of our existing bandwidth tied up in making various functorch mechanisms production ready for all of PyTorch. We also feel understaffed in supporting external backend efforts.
- There is still a tension in that we have two ways of processing Python source: a lightweight Python-level transpiler (authoring composites in Python) versus the TorchScript frontend (authoring shape functions in Python). The decisions made by both projects locally make sense, but it would be nice to have a long term plan here.
- There are already a ton of alternative tensor projects, but I bet there could be even more.
- We need to choose between a high coupling (ufunc codegen) and a low coupling (authoring operators in NNC) strategy for TensorIterator replacement. And in general, there is not much current work on what it would take to turn on laziness by default (as envisioned by MinTorch in the limit).
- There is a lot of uncertainty in the space of distributed, especially regarding sharded tensors versus unifying abstractions.