At Facebook, the PyTorch Compiler team has been responsible for a large part of the backend development of PyTorch. We built TorchScript and have recently been focusing on “unbundling TorchScript” into a collection of more focused, modular products, including:
- PyTorch FX: enabling user defined program transformations
- torch.package and torch::deploy: shipping Python to production environments and bypassing the Python GIL
- Lazy Tensor Core: building new extension points for accelerators and compilers with an eager experience
- A more focused TorchScript targeting mobile devices and some performance critical applications
In addition to these user-facing changes, we have been investing in many parts of PyTorch, including compiler analysis, optimizations, frameworks, and runtimes.
In this post, we want to talk about what comes next for PyTorch Compilers. There are huge challenges ahead: AI is shifting extremely fast, and our team needs to stay nimble to keep up. To that end, we are introducing the following big bets that will shape our investments in the coming years:
- Compilers in Eager Mode. Using compiler technology to change how we implement PyTorch, both at compile time and at runtime.
- Edge Devices. Help adapt PyTorch to the industry trend of explosive growth in running ML workloads on phones and other smart devices.
- Next Generation Accelerators. Help catalyze an industry shift towards more flexible, eager-mode-first accelerators.
- Exploratory Research. New programming models often disrupt the AI industry, and we need to stay ahead of changing trends.
PyTorch’s success as a framework comes in large part from the usability benefits of eager mode, and eager mode continues to be the predominant way people use PyTorch. PyTorch has been a leader in this area and has proven that eager mode performance can beat less flexible graph mode frameworks on many workloads. Other frameworks have been playing catch-up, and we need to lean into this key competitive advantage.
A big investment in this area is using compilers to author operators in PyTorch and then backing those operators with a JIT compiler that partially specializes them. This will help us solve the “too many operators” problem, and it can also lead to performance wins. The plans here are still being finalized, but for more details on the thinking see Python Operator Authoring w/ NNC.
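As a rough illustration of “author the operator once, let a JIT specialize it”, here is a toy, pure-Python sketch. It is not NNC and uses no PyTorch API; the `pointwise_op` decorator and the use of Python lists as stand-in tensors are hypothetical. The operator is written as a scalar function. The first call with a given pattern of tensor and scalar arguments builds a kernel for that pattern and caches it, and later calls reuse the kernel whatever the input length. That is the sense of “partial” specialization here: the kernel is specialized on some input properties (which arguments are tensors) but not on others (sizes).

```python
from typing import Any, Callable, Dict, List, Tuple

# Hypothetical sketch: Python lists stand in for 1-D tensors.
Key = Tuple[bool, ...]  # which positional arguments are "tensors"


def pointwise_op(scalar_fn: Callable[..., Any]) -> Callable[..., Any]:
    """Turn a scalar function into an elementwise operator whose kernels are
    built on first use and cached per tensor/scalar argument pattern."""
    cache: Dict[Key, Callable[..., List[Any]]] = {}

    def build_kernel(key: Key) -> Callable[..., List[Any]]:
        # A real JIT would emit native code here; we just build a closure
        # that already "knows" which arguments to index and which to broadcast.
        def kernel(*args: Any) -> List[Any]:
            sizes = {len(a) for a, is_tensor in zip(args, key) if is_tensor}
            if len(sizes) != 1:
                raise ValueError(f"size mismatch: {sorted(sizes)}")
            (n,) = sizes
            return [
                scalar_fn(*(a[i] if is_tensor else a for a, is_tensor in zip(args, key)))
                for i in range(n)
            ]

        return kernel

    def op(*args: Any) -> Any:
        key = tuple(isinstance(a, list) for a in args)
        if not any(key):
            return scalar_fn(*args)  # all scalars: nothing to specialize
        if key not in cache:  # sizes are deliberately NOT part of the key
            cache[key] = build_kernel(key)
        return cache[key](*args)

    op.cache = cache  # exposed only so the example can inspect it
    return op


@pointwise_op
def scaled_add(x, y, alpha):
    return x + alpha * y


print(scaled_add([1, 2, 3], [10, 20, 30], 2))  # [21, 42, 63]
print(scaled_add([1.0, 2.0], [3.0, 4.0], 0.5))  # [2.5, 4.0]  (same kernel reused)
print(scaled_add([1, 2], 5, 1))                # [6, 7]      (new pattern, new kernel)
print(len(scaled_add.cache))                   # 2
```

Keeping sizes out of the cache key is what avoids recompiling for every new input shape; a real system would choose which properties (dtype, rank, contiguity, and so on) are worth specializing on.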
Another effort is Lazy Tensors, which has the potential to unlock speedups through fusion. This work is still exploratory, but we are especially interested in lighter-weight Lazy Tensors that only look at a sliding window of operations. Lazy Tensors provide a more accessible and painless way to compile ‘eager’ programs that have proven difficult to capture by other means. This half, we will focus on evaluating the design space and solidifying a design that makes the right trade-offs for users in terms of extensibility, hackability, low latency, and predictable performance. We plan to use TorchScript IR as a backend extension point for Lazy Tensors.
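To show why deferring operations enables fusion, here is a toy, pure-Python lazy tensor. It is a hypothetical sketch, not Lazy Tensor Core, and it only handles a chain of pointwise ops on a list of floats. Ops are recorded instead of executed; at most `window` of them are held back before they are fused into one pass over the data (one simulated “kernel launch”), and reading the value forces whatever is still pending.

```python
from typing import Callable, Iterable, List

Fn = Callable[[float], float]


class LazyTensor:
    """Toy lazy tensor over a Python list of floats.

    Pointwise ops are recorded rather than executed. Pending ops are fused
    into a single pass over the data (one simulated "kernel launch") when
    the window fills up or when a value is actually read.
    """

    launches = 0  # simulated kernel launches, counted across all instances

    def __init__(self, data: Iterable[float], window: int = 4) -> None:
        self._data: List[float] = list(data)
        self._window = window
        self._pending: List[Fn] = []

    def _record(self, fn: Fn) -> "LazyTensor":
        out = LazyTensor.__new__(LazyTensor)
        out._data = self._data               # shared, never mutated in place
        out._window = self._window
        out._pending = self._pending + [fn]  # cheap: no data touched yet
        if len(out._pending) >= out._window:
            out._flush()                     # window full: run what we have
        return out

    def _flush(self) -> None:
        if not self._pending:
            return
        fns = self._pending

        def fused(v: float) -> float:
            for fn in fns:                   # all pending ops, one pass
                v = fn(v)
            return v

        self._data = [fused(v) for v in self._data]
        self._pending = []
        LazyTensor.launches += 1

    def __add__(self, c: float) -> "LazyTensor":
        return self._record(lambda v: v + c)

    def __mul__(self, c: float) -> "LazyTensor":
        return self._record(lambda v: v * c)

    def relu(self) -> "LazyTensor":
        return self._record(lambda v: max(v, 0.0))

    def tolist(self) -> List[float]:
        self._flush()                        # reading forces pending ops to run
        return list(self._data)


t = LazyTensor([-1.0, 2.0, 3.0])
y = ((t * 2.0) + 1.0).relu()                 # three ops recorded, none executed
print(y.tolist())                            # [0.0, 5.0, 7.0]
print(LazyTensor.launches)                   # 1 (eager execution would need 3)
```

The window size is the knob behind the trade-offs mentioned above: a larger window exposes more fusion opportunities, while a smaller one keeps latency low and the amount of deferred work bounded.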
Another trend that we only expect to accelerate is the shift towards mobile and edge devices. There is ample evidence for this shift as inference workloads increasingly move from servers to user devices. Major vendors are now putting machine learning accelerators in phones. Edge ML also helps safeguard user privacy by keeping sensitive data on the user’s device rather than sending it to servers. Edge is important today, and that importance will only grow in the years to come.
To facilitate this shift, the primary focus of the TorchScript stack (whole-program capture) will become mobile/edge. TorchScript is well suited to the current needs of edge devices, and we will prioritize improvements in key areas such as forward/backward compatibility and hardware devices (e.g., AR/VR).
Additionally, we need to think one step ahead about what edge devices will look like in the future. Today, size (and other) requirements make running the full PyTorch eager experience impossible on many, though not all, devices. We will explore and prototype possibilities that may help bring the full PyTorch eager experience to a wider range of environments.
By far the most successful machine learning accelerator to date is the GPU. GPUs have been so successful for exactly the same reason PyTorch is successful: usability. A long list of accelerators from other companies have failed because they sacrifice too much of the user experience and are too inflexible. Even the industry accelerators in use today suffer from enormous usability issues, so much so that many researchers with easy access to them choose to use GPUs instead. Existing accelerators run a narrow set of workloads very well, but they are not a model we want to emulate.
We believe next-generation accelerators across the industry should be eager-mode-first and look more like GPUs. They should have a streaming programming model to hide kernel launch overheads. They should support general-purpose workloads to allow flexibility and exploration of new model architectures. They should have advanced memory subsystems to support fast reconfiguration and dynamic allocation of storage. Techniques to capture partial graphs, such as Lazy Tensors and explicit fusion, can provide speedups, but vendors should not rely on these techniques to achieve competitive performance.
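To make “a streaming programming model hides kernel launch overheads” concrete, here is a toy, pure-Python model of an in-order stream, loosely patterned on CUDA streams (where kernel launches are asynchronous with respect to the host and run in issue order). The `Stream` class and its methods are hypothetical and only simulate a device with a worker thread. The point is that `launch` only enqueues work, so the host’s per-launch cost overlaps with the device executing earlier kernels; the host waits only at `synchronize`.

```python
import queue
import threading
from typing import Any, Callable, List


class Stream:
    """Toy in-order stream.

    launch() only enqueues a kernel and returns immediately; a worker thread
    plays the role of the device and runs kernels one at a time, in order.
    The host pays only the enqueue cost per launch and can keep issuing work
    while earlier kernels are still running.
    """

    def __init__(self) -> None:
        self._queue: queue.Queue = queue.Queue()
        self.errors: List[Exception] = []
        threading.Thread(target=self._device_loop, daemon=True).start()

    def _device_loop(self) -> None:
        while True:
            kernel, args = self._queue.get()
            try:
                kernel(*args)
            except Exception as exc:  # record instead of killing the "device"
                self.errors.append(exc)
            finally:
                self._queue.task_done()

    def launch(self, kernel: Callable[..., Any], *args: Any) -> None:
        self._queue.put((kernel, args))  # asynchronous w.r.t. the host

    def synchronize(self) -> None:
        self._queue.join()  # block until every launched kernel has finished


def add_inplace(buf: List[float], value: float) -> None:
    for i in range(len(buf)):
        buf[i] += value


stream = Stream()
buf = [0.0] * 4
for _ in range(3):
    stream.launch(add_inplace, buf, 1.0)  # returns without waiting
# ... the host could prepare the next launches here ...
stream.synchronize()
print(buf)  # [3.0, 3.0, 3.0, 3.0]
```

Because kernels on one stream run in order, eager code can keep issuing small operations back to back without the host ever stalling on each one.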
The biggest challenge in making this easier for accelerator vendors, and our main technical focus, is dealing with the 2000+ PyTorch operators: PyTorch lacks a minimal integration surface. Compilers and the operator authoring efforts described above can help here! By redefining PyTorch operators in a higher-level language, we can provide a smaller integration surface, along with tools that let vendors codegen implementations of all operators from a small core integration. This has many benefits beyond backend extensibility: it will make PyTorch easier to maintain and also enable other types of extensions that would otherwise require O(num-operators) work.
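Here is a toy sketch of how a small core plus decompositions shrinks the integration surface. All of it is hypothetical: the six “core” ops, the `DECOMPOSITIONS` table, and `build_backend` are made up for illustration and do not reflect PyTorch’s actual core operator set or codegen. Higher-level operators are written once in terms of other operators, and a vendor supplies only the core ops. The sketch composes the decompositions at runtime rather than generating code from them, but the effect on the vendor is the same: adding an operator costs zero vendor work.

```python
import math
from typing import Callable, Dict, List

Tensor = List[float]  # stand-in for a tensor
Ops = Dict[str, Callable[..., Tensor]]

# The only ops a backend must implement in this sketch (hypothetical set).
CORE_OPS = {"add", "mul", "neg", "exp", "reciprocal", "add_scalar"}

# Higher-level operators, written once in terms of other operators and
# shared by every backend. Each takes the backend's op table first.
DECOMPOSITIONS: Dict[str, Callable[..., Tensor]] = {
    "sub": lambda ops, a, b: ops["add"](a, ops["neg"](b)),
    "div": lambda ops, a, b: ops["mul"](a, ops["reciprocal"](b)),
    "square": lambda ops, a: ops["mul"](a, a),
    "sigmoid": lambda ops, a: ops["reciprocal"](
        ops["add_scalar"](ops["exp"](ops["neg"](a)), 1.0)
    ),
}


def build_backend(core: Ops) -> Ops:
    """Derive a full operator table from a vendor's core implementations."""
    missing = CORE_OPS - set(core)
    if missing:
        raise ValueError(f"backend is missing core ops: {sorted(missing)}")
    ops: Ops = dict(core)
    for name, decomp in DECOMPOSITIONS.items():
        # Bind `decomp` now; `ops` is looked up at call time, so
        # decompositions may also be built on top of each other.
        ops[name] = lambda *args, _d=decomp: _d(ops, *args)
    return ops


# A vendor implements only the six core ops (here: naive Python loops).
list_backend = build_backend({
    "add": lambda a, b: [x + y for x, y in zip(a, b)],
    "mul": lambda a, b: [x * y for x, y in zip(a, b)],
    "neg": lambda a: [-x for x in a],
    "exp": lambda a: [math.exp(x) for x in a],
    "reciprocal": lambda a: [1.0 / x for x in a],
    "add_scalar": lambda a, c: [x + c for x in a],
})

print(list_backend["sub"]([5.0, 3.0], [2.0, 1.0]))  # [3.0, 2.0]
print(list_backend["sigmoid"]([0.0]))              # [0.5]
```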
The final big bet is around new programming models (many of them popularized by JAX). The AI industry has been going through radical shifts, with a new framework arriving every few years and disrupting everything. PyTorch itself was one of these disruptors in the past, but to stay relevant we can’t sit still. Staying relevant means understanding the use cases that are driving the success of other frameworks, and also finding new ways to innovate where our users have pain points.
Many of the projects in this area will be exploratory prototypes whose main goal is to learn something and help advance the state of the art in research. Some projects in this area will graduate from the incubator and ship to production.
functorch is the main example here so far, but we plan to dramatically scale up investments in this area. There are a few main categories of interest:
- Transformations (grad/vmap/etc.)
  - The main work in this area is continuing functorch and making it more production-ready.
  - There is also research into new transformations, such as masking and control flow ops.
- AOT compilation compatible with autograd (similar to jax.jit and NNC)
  - This is key to getting performance in overhead-bound use cases.
  - There are many motivating use cases where people have been drawn to other frameworks because of the performance that systems like this provide.
- Distributed is another fast-growing area that needs innovation in programming models. Scaling models to use more data and compute has recently led to outsized wins in fields like Natural Language Processing (and, arguably, has always led to wins in AI), and it is crucial that PyTorch provide best-in-class tools and interfaces for scaling models across multiple devices and nodes. Contemporary frameworks such as TensorFlow and JAX use graph representations of programs and compiler transforms to implement distributed computation. However, we want to push the boundary of usability in this domain: how can we reconcile PyTorch’s expressive and easy-to-use eager interface with scaling? We are working with the PyTorch Distributed team to research and design the best way to realize this, via techniques such as sharding via multiple dispatch and generalizing definitions of sharding/placement to encompass more of the Python language.
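To show what “function transformations” means in practice, here is a toy, pure-Python sketch of two composable transforms. The `grad` and `vmap` below are hypothetical stand-ins that borrow functorch’s names, not its implementation: `grad` uses forward-mode dual numbers for brevity (functorch’s `grad` is reverse-mode autograd), and `vmap` simply loops, whereas functorch batches the underlying operations. What carries over is the shape of the API: each transform takes a function and returns a function, so transforms compose.

```python
from typing import Callable, List, Union

Number = Union[int, float]


class Dual:
    """Dual number (value, derivative) for forward-mode differentiation."""

    def __init__(self, val: float, der: float = 0.0) -> None:
        self.val, self.der = val, der

    @staticmethod
    def lift(x: Union["Dual", Number]) -> "Dual":
        # Constants have zero derivative.
        return x if isinstance(x, Dual) else Dual(float(x))

    def __add__(self, other: Union["Dual", Number]) -> "Dual":
        o = Dual.lift(other)
        return Dual(self.val + o.val, self.der + o.der)

    __radd__ = __add__

    def __mul__(self, other: Union["Dual", Number]) -> "Dual":
        o = Dual.lift(other)
        # Product rule: (uv)' = u'v + uv'
        return Dual(self.val * o.val, self.der * o.val + self.val * o.der)

    __rmul__ = __mul__


def grad(f: Callable) -> Callable[[float], float]:
    """Transform scalar f into a function that returns df/dx."""
    return lambda x: Dual.lift(f(Dual(x, 1.0))).der


def vmap(f: Callable[[float], float]) -> Callable[[List[float]], List[float]]:
    """Transform a per-example function into one over a batch (leading axis)."""
    return lambda xs: [f(x) for x in xs]


def f(x):
    return x * x * x + 2 * x  # f'(x) = 3x^2 + 2


print(grad(f)(2.0))                    # 14.0
print(vmap(grad(f))([0.0, 1.0, 2.0]))  # [2.0, 5.0, 14.0]
```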
I am incredibly excited about the road ahead for PyTorch Compilers and PyTorch overall. We can’t build that future without you, so if you are interested, please reach out to me.