I have always had trouble explaining to outsiders what PyTorch compilers truly are. In my previous careers, I could point to an anchoring infrastructure. PyTorch compilers, however, have been constantly morphing, with new solution stacks or IRs popping up every half, each partially overlapping with previous solutions. Why is that?
PyTorch compilers are plural because there is no single mechanism to convert PyTorch programs into IR forms (or graphs).
Other ML frameworks have either graph abstractions built into the programming model (e.g., TF) or the evaluation model (e.g., TVM), or a language frontend (e.g., Relay) that can be deterministically converted into IRs. In contrast, graph capture for an eager-first ML framework like PyTorch is non-trivial and a design space in itself. To a large extent, the solution space of PyTorch compilers reflects the evolution or fragmentation of PyTorch graph capture mechanisms.
The rest of the post dives into the nuances of the PyTorch graph captures and tries to offer a framework for understanding all of them.
The following table buckets different PyTorch graph captures based on UX.
We gauge user experiences based on the following metrics:
- Amount of user effort required or allowed
  - Flip-switch: minimal user effort, but not much customization of the capture scope;
  - User-directed: needs deeper knowledge to successfully capture a graph; it often implies adoption barriers (negative) and/or customizability (positive);
- Whether graph capture is guaranteed to succeed (i.e., always-succeed vs best-effort )
- Whether graph (re)play is sound or not (i.e., sound vs unsound )
- Whether whole-graphs or multiple partial graphs are captured, which determines characteristics of the replay execution (e.g., whether Python fall-backs are needed)
To aid reference, we give each bucket a short name:
Out-of-box: flip-switch, always-succeed capture (not necessarily whole-graph), and sound replay w/ Python fallbacks. This one is our best UX option.
- Examples: Lazy Tensor, TorchDynamo
Human-in-the-loop: user-directed, best-effort whole-graph capture. This option may have a steep tryout cost upfront but allows customization.
Best-effort: flip-switch, always-succeed whole-graph capture, and unsound replay. This option is easy to try out.
Two traits influence the soundness and usability of a graph capture & replay system.
The ability to (transparently) skip unwanted ops in the capture scope, which determines whether capture is guaranteed to succeed. Unless we intend to develop a Python compiler, the graph IR for an ML compiler cannot be the same as Python IR. Thus, a sound graph capture must be able to exclude Python ops that are not supported by the graph IR, preferably transparently.
- Multi-graph w/ Python fallback: the system may capture multiple partial graphs to safely skip any unsupported Python ops in the capture scope. Because Python IRs and multiple partial graphs are interleaved, the (re)play system often supports Python execution (or fallback), where the control to enter and exit the execution of a partial graph is often transparent to programmers.
- Whole-graph: the system captures a single graph for the entire capture scope. If the capture scope contains Python constructs unsupported by the graph IR, the system may fail to capture a graph. In a single-graph system, users can explicitly replay a captured graph, and the replay system does not have to support Python execution. All existing capture-and-replay systems are whole-graph systems.
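To make the two capture styles concrete, here is a minimal toy sketch in plain Python. It is not how any PyTorch system is implemented: the `SUPPORTED` set, the `(name, fn)` op encoding, and all function names are assumptions made up for illustration. It only shows that whole-graph capture fails on an unsupported op, while multi-graph capture splits around it and falls back to Python.

```python
from typing import Callable, List, Tuple

Op = Tuple[str, Callable[[float], float]]

# Hypothetical set of ops that the toy "graph IR" can represent.
SUPPORTED = {"add1", "double", "relu"}


def capture_whole_graph(program: List[Op]) -> List[Op]:
    """Whole-graph capture: fail if any op is not representable in the IR."""
    for name, _ in program:
        if name not in SUPPORTED:
            raise ValueError(f"cannot capture unsupported op: {name}")
    return list(program)


def capture_multi_graph(program: List[Op]):
    """Split the program into partial graphs separated by Python fallbacks."""
    segments = []  # entries are ("graph", [ops]) or ("python", op)
    current: List[Op] = []
    for op in program:
        if op[0] in SUPPORTED:
            current.append(op)
        else:
            if current:
                segments.append(("graph", current))
                current = []
            segments.append(("python", op))
    if current:
        segments.append(("graph", current))
    return segments


def run_segments(segments, x: float) -> float:
    """Play partial graphs and Python fallbacks in their original order."""
    for kind, payload in segments:
        ops = payload if kind == "graph" else [payload]
        for _, fn in ops:
            x = fn(x)
    return x


program = [
    ("add1", lambda x: x + 1),
    ("double", lambda x: x * 2),
    ("log", lambda x: x),  # stands in for an unsupported Python side effect
    ("relu", lambda x: max(x, 0.0)),
]

segments = capture_multi_graph(program)
result = run_segments(segments, 3.0)
```

In this sketch the caller of `run_segments` never sees where a partial graph ends and the fallback begins, which is the transparency described above.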
Interaction between capture and replay, which determines the soundness of replay.
- Recapture-and-play (i.e., capture many times and play one or more times). Such systems check whether captured graphs match the replay context and re-capture if mismatched.
- Capture-and-replay (i.e., capture once and replay many times). Such a system requires users to guarantee the soundness of the replay.
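The difference in soundness can be seen in a small toy tracer, sketched below in plain Python. Everything here (`Recorder`, `capture`, `replay`, `RecaptureAndPlay`, and the guard encoding) is a hypothetical illustration, not a PyTorch API. The model has a data-dependent branch; a graph captured once silently bakes in one side of the branch, while the recapture variant records the branch outcome as a guard and re-captures when the guard fails.

```python
class Recorder:
    """Wraps a number, records arithmetic as ops and branch outcomes as guards."""

    def __init__(self, value, ops, guards):
        self.value, self.ops, self.guards = value, ops, guards

    def __gt__(self, bound):
        # The comparison is evaluated eagerly; only its outcome is remembered,
        # together with how many ops had been recorded at that point.
        outcome = self.value > bound
        self.guards.append((len(self.ops), bound, outcome))
        return outcome

    def __mul__(self, other):
        self.ops.append(("mul", other))
        return Recorder(self.value * other, self.ops, self.guards)

    def __sub__(self, other):
        self.ops.append(("sub", other))
        return Recorder(self.value - other, self.ops, self.guards)


def model(x):
    if x > 0:
        return x * 2
    return x - 1


def capture(fn, x):
    ops, guards = [], []
    fn(Recorder(x, ops, guards))
    return ops, guards


def replay(ops, x):
    for name, arg in ops:
        x = x * arg if name == "mul" else x - arg
    return x


def guards_hold(ops, guards, x):
    # Re-evaluate each recorded comparison on the new input.
    return all((replay(ops[:n], x) > bound) == outcome
               for n, bound, outcome in guards)


class RecaptureAndPlay:
    def __init__(self, fn):
        self.fn, self.cache, self.captures = fn, [], 0

    def __call__(self, x):
        for ops, guards in self.cache:
            if guards_hold(ops, guards, x):
                return replay(ops, x)
        ops, guards = capture(self.fn, x)  # mismatch: capture again
        self.captures += 1
        self.cache.append((ops, guards))
        return replay(ops, x)


# Capture-and-replay: capture once on a positive input, replay on a negative one.
ops_once, _ = capture(model, 3.0)
unsound = replay(ops_once, -3.0)  # differs from eager model(-3.0)

# Recapture-and-play: guards trigger a second capture for the negative input.
compiled = RecaptureAndPlay(model)
sound = [compiled(v) for v in (3.0, -3.0, 5.0)]
```

The guard check is the recurring cost of recapture-and-play; capture-and-replay avoids it by leaving the soundness obligation with the user.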
The following table classifies existing graph capture mechanisms along the two dimensions.
- Quadrant (IV) (i.e., Out-of-box) is sound because it can transparently skip unwanted Python constructs (through multi-graph capture w/ Python fallback) and support sound replay (via recapture-and-play).
- Quadrant (II) (i.e., human-in-the-loop, best-effort) is good for the export path because it’s easier to export a whole graph than multiple partial graphs. More importantly, if the execution environment does not support Python fallback, then Quadrant (II) is the only viable solution.
- Quadrant (I) is the best of both worlds. We do not have any solution in this quadrant, but there may be space for innovation to improve Quadrant (IV) solutions to capture more and more whole graphs (perhaps via user intervention).
- Quadrant (III) does not make a lot of sense.
A system can capture graphs at different stages of the model execution life cycle, which leads to different overhead, IR semantics, and composability.
Before execution, by examining Python bytecode (e.g., TorchDynamo) or the AST (e.g., `torch.jit.script`). This is also called zero-overhead graph capture.
Tracing-based, which captures graphs during the execution of Python programs. PyTorch provides two tracing mechanisms that capture IRs w/ different semantics:
- Python-level tracing via `__torch_function__`, which captures IR at the Python level (e.g., FX).
- C++-level dispatcher tracing via either a custom dispatcher key (e.g., Lazy Tensor) or `__torch_dispatch__` (e.g., AOTAutograd), which captures streams of aten ops (i.e., aten IRs).
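The difference in IR semantics between the two tracing levels can be illustrated with a toy analogy in plain Python. This is only a sketch: the functions below are made-up stand-ins, not PyTorch ops, and the two lists merely mimic what a hook at the API layer versus a hook at the primitive layer would observe for the same execution.

```python
# What a hook at the user-facing API layer would record (Python-level analogy).
python_level_trace = []
# What a hook at the primitive layer would record (dispatcher-level analogy).
dispatcher_level_trace = []


# Hypothetical primitives, standing in for low-level ops.
def mul(a, b):
    dispatcher_level_trace.append("mul")
    return a * b


def add(a, b):
    dispatcher_level_trace.append("add")
    return a + b


def clamp_min(a, lo):
    dispatcher_level_trace.append("clamp_min")
    return max(a, lo)


# Hypothetical user-facing functions, each decomposed into primitives.
def linear(x, w, b):
    python_level_trace.append("linear")
    return add(mul(x, w), b)


def relu(x):
    python_level_trace.append("relu")
    return clamp_min(x, 0.0)


# One execution of a tiny "model" yields two different traces.
y = relu(linear(2.0, 3.0, -10.0))
```

The same run produces a coarse trace (`linear`, `relu`) and a fine-grained one (`mul`, `add`, `clamp_min`), which is the sense in which the two mechanisms capture IRs w/ different semantics.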
On overhead
- Before-execution capture is zero-overhead;
- Tracing-based systems always incur overhead, either one-time during warm-up as in capture-and-replay systems, or recurring as in recapture-and-play systems.
On composability w/ PyTorch core extension points
- Before-execution graph capture is the least composable because it 1) takes additional handling to “see through” functions; 2) cannot access C++ (dispatcher) level semantics;
- Tracing-based graph capture is composable w/ functional transforms as it may naturally trace through functions (incl. first, higher-order functions);
- Furthermore, dispatcher-level tracing is the most composable as it can transparently incorporate dispatcher-level semantics like autograd and vmap.
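The "see through" point can be sketched with the standard library alone. The example below is a toy under stated assumptions: `SOURCE`, `ops_seen_statically`, and `Traced` are hypothetical names, and the static pass is a deliberately naive AST walk rather than what TorchDynamo or `torch.jit.script` do. It shows that a static look at `model` only finds an opaque call to `helper`, while a tracer that runs the code records the ops inside `helper` for free.

```python
import ast

# The program is kept as a string so it can be both parsed and executed.
SOURCE = '''
def helper(x):
    return x * 2 + 1

def model(x):
    return helper(x) - 3
'''


def ops_seen_statically(source, func_name):
    """List operators and direct calls in one function body, without running it."""
    tree = ast.parse(source)
    func = next(node for node in tree.body
                if isinstance(node, ast.FunctionDef) and node.name == func_name)
    seen = []
    for node in ast.walk(func):
        if isinstance(node, ast.BinOp):
            seen.append(type(node.op).__name__)
        elif isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
            seen.append(f"call:{node.func.id}")
    return sorted(seen)


class Traced:
    """Toy proxy value that logs each arithmetic op applied to it."""

    def __init__(self, log):
        self.log = log

    def __mul__(self, other):
        self.log.append("Mult")
        return self

    def __add__(self, other):
        self.log.append("Add")
        return self

    def __sub__(self, other):
        self.log.append("Sub")
        return self


static_view = ops_seen_statically(SOURCE, "model")

# Execute the same source and run it on a proxy to obtain the traced view.
namespace = {}
exec(SOURCE, namespace)
traced_view = []
namespace["model"](Traced(traced_view))
```

To recover the ops inside `helper`, the static pass would need extra machinery (for example, resolving and inlining the callee), which is the additional handling mentioned above.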
On lowering to aten IRs
Dispatcher-level tracing has the huge advantage of lowering to aten IRs in a way that is naturally consistent w/ eager execution. For instance, TorchDynamo uses `torch.jit.trace()` to lower from captured Python bytecode graphs to TorchScript IR. This process is sound because graphs captured by TorchDynamo contain no control flow and are shape-specialized. The ongoing explorations of combining TorchDynamo w/ AOTAutograd (to capture forward and backward graphs together) and of combining TorchDynamo w/ Lazy Tensor tracing are both examples of pairing before-execution graph capture with tracing-based graph capture while remaining sound.
The following picture summarizes different aspects of evaluating a graph capture. This post is just the starting point of understanding the existing design space of PyTorch compilers to lay the foundation for building a truly composable and more stable PyTorch compiler stack.