Next Steps for PyTorch Compilers

At Facebook, the PyTorch Compiler team has been responsible for a large part of the backend development of PyTorch. We built TorchScript, and have recently been focusing on “unbundling TorchScript” into a collection of more focused modular products including:

  • PyTorch FX: enabling user defined program transformations
  • torch.package and torch::deploy: shipping Python to production environments and bypassing the Python GIL
  • Lazy Tensor Core: building new extension points for accelerators and compilers with an eager experience
  • A more focused TorchScript targeting mobile devices and some performance critical applications

In addition to these user facing changes, we have been investing in many parts of PyTorch, including: compiler analysis, optimizations, frameworks, and runtimes.

In this post, we want to talk about what comes next for PyTorch Compilers. There are huge challenges ahead: AI is shifting extremely fast, and our team needs to stay nimble in order to keep up. To that end, we are introducing the following big bets that will shape our investments in the coming years:

  • Compilers in Eager Mode. Using compiler technology to change how we implement PyTorch, both at compile time and at runtime.
  • Edge Devices. Help adapt PyTorch to the industry trend of explosive growth in running ML workloads on phones and other smart devices.
  • Next Generation Accelerators. Help catalyze an industry shift towards more flexible eager mode first accelerators.
  • Exploratory Research. New programming models often disrupt the AI industry, and we need to stay ahead of changing trends.

Compilers in Eager Mode

PyTorch’s success as a framework in large part comes from the usability benefits of eager mode and eager continues to be the predominant way people use PyTorch. PyTorch has been a leader in this area, and has proven that eager mode performance can beat less flexible graph mode frameworks on many workloads. Other frameworks have been playing catch up, and we need to lean into this key competitive advantage.

Across the industry, AI Compilers have been slow to adapt to this trend that PyTorch started. Almost all large AI Compiler projects (GLOW, XLA, TVM, etc.) assume access to large and relatively restricted program graphs, under the false assumption that one needs graphs to achieve industry leading performance. Meanwhile, our work on Package and Deploy has proven that a Python-first, eager approach to deployment is a good option for most users. We must reimagine what an AI Compiler looks like when deeply integrated with eager-mode execution, and explore more dynamic and flexible approaches, taking inspiration from successful JIT compilers in other domains (such as JavaScript).

A big investment in this area is using compilers to author operators in PyTorch, then back those operators with a JIT compiler that partially specializes. This will help us solve the “too many operators” problem, but can also lead to performance wins. The plans here are still being finalized, but for more details on the thinking see Python Operator Authoring w/ NNC.

There is also Lazy Tensors, which has the potential to unlock speedups through fusion. This is still exploratory, but we are especially interested in lighter weight Lazy Tensors that only look at a sliding window of operations. Lazy Tensors provide a more accessible and painless way to compile ‘eager’ programs that have proven difficult to capture by other means. This half, we will focus on evaluating the design space and solidifying a design that makes the right trade-offs in terms of extensibility, hackability, low latency, and predictable performance for users. Our plan is to use TorchScript IR as a backend extension point for Lazy Tensors.
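To make the sliding-window idea concrete, here is a minimal pure-Python sketch. It is a toy under stated assumptions, not the Lazy Tensor Core design: `LazyVector`, its `window` parameter, and the `passes` counter are hypothetical names invented for illustration. Elementwise ops are queued instead of executed, and each time the window fills (or a value is requested) the queued ops run fused in a single pass over the data.

```python
from typing import Callable, List


class LazyVector:
    """Toy lazy tensor (hypothetical): queues elementwise ops, runs them fused."""

    def __init__(self, data: List[float], window: int = 4):
        self._data = list(data)
        self._pending: List[Callable[[float], float]] = []
        self._window = window
        self.passes = 0  # loops over the data, a stand-in for kernel launches

    def apply(self, fn: Callable[[float], float]) -> "LazyVector":
        # Record the op; only execute once the sliding window is full.
        self._pending.append(fn)
        if len(self._pending) >= self._window:
            self._flush()
        return self

    def _flush(self) -> None:
        if not self._pending:
            return
        ops, self._pending = self._pending, []
        fused = []
        for x in self._data:
            for op in ops:  # all queued ops applied in one pass ("fusion")
                x = op(x)
            fused.append(x)
        self._data = fused
        self.passes += 1

    def materialize(self) -> List[float]:
        # Reading the values forces any remaining queued ops to run.
        self._flush()
        return list(self._data)


v = LazyVector([1.0, 2.0, 3.0], window=3)
v.apply(lambda x: x + 1).apply(lambda x: x * 2).apply(lambda x: x - 3)
v.apply(lambda x: x * x)
result = v.materialize()
```

Four ops run in two passes here instead of four. A small window bounds how long work is deferred, which is one way to trade fusion opportunities against latency and predictability.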

Edge Devices

Another trend that we can only imagine accelerating is the shift towards mobile and edge devices. There is enormous evidence for this shift as inference workloads increasingly move from servers to user devices. Major vendors are now putting machine learning accelerators in phones. Edge device ML also helps safeguard user privacy by keeping sensitive data on the user device and not sending it to servers. Edge is important today, but that importance will only grow in the years to come.

To facilitate this shift, the primary focus for the TorchScript stack (whole program capture) will become Mobile/Edge. TorchScript is well suited to the current needs of Edge devices, and we will prioritize supporting improvements in key areas like forwards/backwards compatibility and hardware devices (like AR/VR).

Additionally, we need to think one step ahead to what edge devices will look like in the future. Today, size (and other) requirements make running the full PyTorch eager experience impossible on many devices, but not all devices. We will explore and prototype possibilities that may help bring the full PyTorch eager experience to a wider range of environments.

Next Generation Accelerators

By far the most successful machine learning accelerator to date is the GPU. GPUs have been so successful for the exact same reason PyTorch is successful: usability. A long list of accelerators from other companies have failed because they make too many sacrifices to the user experience and are too inflexible. Industry accelerators that are in use today suffer from enormous usability issues. Many researchers with easy access to accelerators choose to use GPUs instead because of the usability restrictions of existing accelerators. Existing accelerators run a narrow set of workloads very well, but are not a model we want to emulate.

We believe next generation accelerators across the industry should be eager mode first and look more like GPUs. They should have a streaming programming model to hide kernel launch overheads. They should support general purpose workloads to allow flexibility and exploration of new model architectures. They should have advanced memory subsystems to support fast reconfiguration and dynamic allocation of storage. Techniques to get partial graphs, such as Lazy Tensors and explicit fusion, can provide speedups, but vendors should not rely on these techniques to have competitive performance.

The biggest challenge, and the main technical focus to make this easier, is dealing with the 2000+ PyTorch operators. PyTorch lacks a minimal integration surface. Compilers and the operator authoring efforts described above can help here! By redefining PyTorch operators in a higher level language, we can provide a smaller integration surface that gives vendors tools to codegen implementations for all operators from a smaller core integration. This has many benefits in addition to backend extensibility: it will make PyTorch easier to maintain and also allow other types of extensions that would otherwise require O(num-operators) work.
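As a rough illustration of the smaller-integration-surface idea, the sketch below defines composite operators once in terms of a handful of core ops, so a backend only has to supply the core set. Everything here is hypothetical and chosen for brevity (the `CORE_OPS` names, the list-of-floats "tensors", `build_operator_table`); it is not PyTorch's actual operator set or registration mechanism.

```python
from typing import Callable, Dict, List

Vector = List[float]
Backend = Dict[str, Callable[..., Vector]]

# Hypothetical core set: the only ops a backend must implement by hand.
CORE_OPS = ("add", "mul", "neg", "maximum")


# Composite ops are written once, using core ops only.
def sub(be: Backend, a: Vector, b: Vector) -> Vector:
    return be["add"](a, be["neg"](b))


def relu(be: Backend, a: Vector) -> Vector:
    return be["maximum"](a, [0.0] * len(a))


def square(be: Backend, a: Vector) -> Vector:
    return be["mul"](a, a)


def abs_(be: Backend, a: Vector) -> Vector:
    return be["maximum"](a, be["neg"](a))


COMPOSITE_OPS = {"sub": sub, "relu": relu, "square": square, "abs": abs_}


def build_operator_table(core_impl: Backend) -> Backend:
    """Derive the full operator table from a backend's core implementations."""
    missing = [name for name in CORE_OPS if name not in core_impl]
    if missing:
        raise ValueError(f"backend is missing core ops: {missing}")
    table = dict(core_impl)
    for name, fn in COMPOSITE_OPS.items():
        # Bind fn per iteration so each entry calls its own composite.
        table[name] = (lambda f: lambda *args: f(core_impl, *args))(fn)
    return table


# A reference "backend" that implements just the four core ops.
reference_backend: Backend = {
    "add": lambda a, b: [x + y for x, y in zip(a, b)],
    "mul": lambda a, b: [x * y for x, y in zip(a, b)],
    "neg": lambda a: [-x for x in a],
    "maximum": lambda a, b: [max(x, y) for x, y in zip(a, b)],
}

ops = build_operator_table(reference_backend)
```

Adding a new composite operator costs one definition rather than one implementation per backend, which is the O(num-operators) saving described above.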

Exploratory Research

The final big bet is around new programming models (many of those popularized by JAX). The AI industry has been going through radical shifts where every few years a new framework comes in and disrupts everything. PyTorch itself was one of these disruptors in the past, but to stay relevant we can’t sit still. Staying relevant means we need to understand the use cases that are driving the success of other frameworks, but also think about new ways to innovate in places where our users have pain points.

Many of the projects in this area will be exploratory prototypes, whose main goal is to learn something and help advance the state of the art in research. Some projects in this area will graduate from the incubator and ship to production.

Functorch is the main example here so far, but we plan to dramatically scale up investments in this area. There are a few main categories of interest:

  • Transformations (grad/vmap/etc)
    • The main work in this area is around continuing the work of functorch and making it more production ready.
    • There is also research into new transformations like masking and control flow ops
  • AOT Compilation compatible with autograd (similar to jax.jit and NNC)
    • This is key to get performance in overhead bound use cases
    • There are many motivating use cases where people have been drawn to other frameworks because of the performance that systems like this provide.
  • Distributed
    • Distributed is another fast growing area that needs innovation in programming models. Scaling models to use more data and compute has led to outsized wins recently in fields like Natural Language Processing (and, arguably, has always led to wins in AI) and it is crucial that PyTorch provide best-in-class tools and interfaces for scaling models across multiple devices and nodes. Contemporary frameworks such as TensorFlow and JAX use graph representations of programs and compiler transforms to implement distributed computation. However, we want to push the boundary of usability in this domain: how can we reconcile PyTorch’s expressive and easy-to-use Eager interface with scaling? We are working with the PyTorch Distributed team to do research and design into the best way to realize this via such techniques as sharding via multiple-dispatch and generalizing definitions of sharding/placement to encompass more of the Python language.
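For readers unfamiliar with the transformations category above, here is a toy pure-Python sketch of what composable function transformations look like. It is not functorch: this `grad` uses forward-mode dual numbers for brevity (an assumption of the sketch, and it only supports `+` and `*` on scalars), and this `vmap` simply loops over a batch rather than vectorizing. The point is only that each transformation takes a function and returns a function, so they nest.

```python
from dataclasses import dataclass
from typing import Callable, List, Union

Number = Union[int, float]


@dataclass(frozen=True)
class Dual:
    """A value paired with its derivative (forward-mode autodiff)."""

    value: float
    tangent: float

    def __add__(self, other: Union["Dual", Number]) -> "Dual":
        other = _lift(other)
        return Dual(self.value + other.value, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other: Union["Dual", Number]) -> "Dual":
        other = _lift(other)
        # Product rule: (uv)' = u v' + u' v
        return Dual(
            self.value * other.value,
            self.value * other.tangent + self.tangent * other.value,
        )

    __rmul__ = __mul__


def _lift(x: Union[Dual, Number]) -> Dual:
    # Constants carry a zero derivative.
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)


def grad(f: Callable) -> Callable[[Number], float]:
    """Return a function computing df/dx for scalar f (toy, forward mode)."""

    def df(x: Number) -> float:
        return _lift(f(Dual(float(x), 1.0))).tangent

    return df


def vmap(f: Callable) -> Callable[[List], List]:
    """Return a function applying f to each element of a batch (toy loop)."""

    def batched(xs: List) -> List:
        return [f(x) for x in xs]

    return batched


def f(x):
    return x * x * x + 2 * x  # f'(x) = 3x^2 + 2


per_example_grads = vmap(grad(f))([0.0, 1.0, 2.0])
```

Composing the two gives per-example gradients without changing `f`, which is the style of programming this line of work aims to make production ready.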

We are hiring!

I am incredibly excited about the road to come for PyTorch Compilers and PyTorch overall. We can’t build that future without you, so if you are interested please reach out to me.

If you take a look at the past 10 years, people tried adding JIT support to Python itself, e.g. PyPy, Pyston, Cinder. None of them support NumPy and other numeric/scientific/ML Python libraries very well. These libraries rely heavily on the CPython C API as an extension point, which in theory is only an implementation detail of CPython, but has become the de facto contract and requires unnecessarily or even prohibitively high support effort. On the other hand, given that these libraries are already heavily optimized natively, JIT code-gen never became a performant alternative, and probably never will. We are now stuck with the CPython GIL and a fragmented AOT/JIT story (Cython, Numba, Dask, CuPy, etc.) for the whole Python ecosystem.

I really hope we don’t have to repeat this again and again.

Very clear and insightful! A few pieces of feedback that align with the proposal:

1. FX: One best practice you may want to consider is to define a suite of FX APIs, in an object-oriented way, to traverse, dependency-analyze, replace, and create graph nodes efficiently. Combined with profiler and visibility tools, FX would serve not only as a way to bring your own compiler or optimization, but also as a very easy-to-use tool for production as well as research.
2. Efficient deployment: For edge, you are betting on whole graph capture. With a hybrid runtime, the same stack makes sense for serving on public cloud too.
3. Exploratory Research: By leveraging JAX/XLA, we explored whole-graph capture compilation with automatic distribution. The pro is performance; the con is capacity (“usability”). For extreme performance with extra codesign effort, fully static programs (shape and control flow) are preferred, both for single device performance and for distribution. To be more general, besides the proposed topics, a distribution schedule decided beforehand, with a memory (or latency) constrained compilation flow, is required (to be designed). It would also be good to have async tensors and scoped in-code communication primitives.

Together, we can do more!

@jansel thanks for the write-up!

Could someone give an example of a program transformation that was done with TorchScript in the past but is better done with torch.fx today? For that example, how is torch.fx better?

To be honest I’m now a bit confused with all these different ways to compile/package a model.
As far as I understand we now have:
torch.jit.trace (pytorch 1.0+)
torch.jit.script (pytorch 1.0+)
torch.fx (pytorch 1.8+)
torch.package (pytorch 1.9+)

Could you do a short description of the pros/cons of each, what it’s meant for, what it can do/cannot do?

As far as I understand:
torch.jit is the historic way to export/import/deploy models in a single package
torch.fx is more meant to tweak existing models, but isn’t designed to import/export.
torch.package is… meant to be a replacement for torch.jit ? Not clear.

Also, what’s the best way to get a clear graph description of a model, with the output tensor size at each node?

@jansel @wconstab

One thing I like about FX is the low barrier to entry. For example, I wrote an FX-based profiler in a few hours as a quick experiment to explore fusion opportunities. We have seen an explosion in cool applications of FX like that. With TorchScript there is a lot more to learn in order to get started, and TorchScript graphs look a lot less like the original program than FX graphs.
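To give a feel for why this kind of tool is quick to write, here is a self-contained toy that mimics the general idea behind FX in plain Python: trace a function with proxy objects to get a flat list of nodes, then walk the nodes and time each one. This is a hypothetical stand-in, not the torch.fx API; the `Node`, `Proxy`, `Graph`, `trace`, and `profile` names below are invented for the sketch and support only `+` and `*`. In torch.fx itself, the corresponding entry points are `torch.fx.symbolic_trace` and `torch.fx.Interpreter`.

```python
import operator
import time
from typing import Any, Callable, Dict, List, Tuple


class Node:
    """One recorded operation: a placeholder input or a call to `target`."""

    def __init__(self, name: str, target: Any, args: Tuple[Any, ...]):
        self.name = name
        self.target = target  # None marks a placeholder (graph input)
        self.args = args


class Proxy:
    """Stands in for a real value while tracing; records ops applied to it."""

    def __init__(self, graph: "Graph", node: Node):
        self.graph = graph
        self.node = node

    def __add__(self, other: Any) -> "Proxy":
        return self.graph.call(operator.add, self, other)

    def __mul__(self, other: Any) -> "Proxy":
        return self.graph.call(operator.mul, self, other)


class Graph:
    def __init__(self) -> None:
        self.nodes: List[Node] = []

    def placeholder(self, name: str) -> Proxy:
        node = Node(name, None, ())
        self.nodes.append(node)
        return Proxy(self, node)

    def call(self, target: Callable, *args: Any) -> Proxy:
        # Store Nodes (not Proxies) as arguments; constants are kept as-is.
        raw = tuple(a.node if isinstance(a, Proxy) else a for a in args)
        node = Node(f"{target.__name__}_{len(self.nodes)}", target, raw)
        self.nodes.append(node)
        return Proxy(self, node)


def trace(fn: Callable, *arg_names: str) -> Tuple[Graph, Node]:
    """Run fn on proxies to capture its operations as a graph."""
    graph = Graph()
    out = fn(*(graph.placeholder(name) for name in arg_names))
    return graph, out.node


def profile(graph: Graph, output: Node, inputs: Dict[str, Any]):
    """Execute the graph node by node, recording wall time per call node."""
    env: Dict[str, Any] = {}
    timings: Dict[str, float] = {}
    for node in graph.nodes:
        if node.target is None:
            env[node.name] = inputs[node.name]
            continue
        args = [env[a.name] if isinstance(a, Node) else a for a in node.args]
        start = time.perf_counter()
        env[node.name] = node.target(*args)
        timings[node.name] = time.perf_counter() - start
    return env[output.name], timings


def model(x, w):
    return x * w + x


graph, out = trace(model, "x", "w")
result, timings = profile(graph, out, {"x": 3.0, "w": 2.0})
```

Because the captured graph is just a list of nodes that mirrors the original program, the profiler is a single loop.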

torch.fx isn’t a packaging format or a compiler. To package an FX model you can use either TorchScript or torch.package.

TorchScript takes your model outside of Python. Because of that, the functionality is limited: you lose access to many features of PyTorch/Python, and some models won’t work or will produce incorrect results.

torch.package leaves the model in Python, so it supports all the features of PyTorch. You have the option of TorchScripting the model after you load it from the package, so you could think of TorchScript more as a runtime than a storage format.

Thanks @jansel for the explanation, I understand better now.
Looks like an appropriate tool for my software TorchStudio, which needs to export any kind of local model for local or remote training.

However, looking at the tutorial it seems torch.package needs some extra input from the user to specify what modules are extern or intern.
1/ In order to automate this, can I provide it a list of all known installed packages (or at least the big ones such as numpy) to be considered as extern, even if the model doesn’t use them?
2/ …and then can I tell it to consider everything I didn’t list as intern?
3/ What about torch::deploy, is it (or will it be) part of libtorch? Does it rely on extra dependencies other than what’s provided with libtorch? I guess if a model uses an extern module, then it won’t work?

1/ In order to automate this, can I provide it a list of all known installed packages (or at least the big ones such as numpy) to be considered as extern, even if the model doesn’t use them?

That would work, yes.

2/ …and then can I tell it to consider everything I didn’t list as intern ?

Yes, a pattern like intern("**") will catch everything.
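To illustrate how an extern list followed by a catch-all intern resolves, here is a small pure-Python model of the matching. It is a simplified sketch, not torch.package's implementation: `matches` and `classify` are hypothetical helpers, and it assumes that `*` matches exactly one dotted segment, `**` matches zero or more segments, and the first matching rule wins, so the extern rules must be listed before the catch-all.

```python
from typing import List, Sequence, Tuple


def matches(pattern: str, module: str) -> bool:
    """Glob match on dotted module names (simplified model)."""
    return _match(pattern.split("."), module.split("."))


def _match(pat: Sequence[str], segs: Sequence[str]) -> bool:
    if not pat:
        return not segs
    head, rest = pat[0], pat[1:]
    if head == "**":
        # "**" may absorb zero or more whole segments.
        return any(_match(rest, segs[i:]) for i in range(len(segs) + 1))
    if not segs:
        return False
    if head == "*" or head == segs[0]:
        return _match(rest, segs[1:])
    return False


def classify(module: str, rules: List[Tuple[str, str]]) -> str:
    """Return the action of the first rule whose pattern matches."""
    for action, pattern in rules:
        if matches(pattern, module):
            return action
    return "unhandled"


# Known third-party packages first, then everything else as intern.
rules = [
    ("extern", "numpy.**"),
    ("extern", "scipy.**"),
    ("intern", "**"),
]

numpy_action = classify("numpy.linalg", rules)
model_action = classify("my_project.models.resnet", rules)
```

Here `my_project` is a made-up module name. Listing an extern package the model never imports is harmless in this model, since a rule only matters for modules that actually appear.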

3/ What about torch::deploy, is it (or will it be) part of libtorch? Does it rely on extra dependencies other than what’s provided with libtorch? I guess if a model uses an extern module, then it won’t work?

Today it is an optional part of the libtorch build (controlled by the USE_DEPLOY flag). Using torch and any part of the Python standard library will work today, but external modules won’t work without changing the PyTorch build. Our plan is to handle external dependencies better before releasing torch::deploy publicly.