Static Runtime - Design
Static Runtime was designed to empower rapid data flow optimizations without needing to consider the full space of valid TorchScript IR. It can exist within the TorchScript IR interpreter or as a standalone component capable of running full models. This interaction model fully embraces the idea that an interpreter is an elegant solution for a large class of high-complexity computation that does not dominate runtime. However, for high performance workloads, the ability to switch into a fully static domain can yield extra performance.
This note will cover the current design of the project. Planned development and features are noted in the text. Commentary on the ecosystem surrounding Static Runtime, including usage, composability and next steps, can be found at the bottom of this document in the section labeled “Bigger Picture.”
Static Runtime heavily leverages existing TorchScript infrastructure and largely builds on the optimizations exposed by the module freezing pass. Currently, Static Runtime targets inference workloads on CPU.
Please consider this simple function:
```python
def f(i0, i1, i2, i3, i4):
    a0 = torch.add(i0, i1)
    a1 = torch.mul(a0, i2)
    a2 = torch.mul(a1, i3)
    return torch.cat([a2, i4])
```
The associated IR of this function looks as follows:
Highlighted in grey are the user-exposed values. These values must adhere to the semantics described by the program.
Highlighted in green are the input-dependent values not exposed to the user. These values cannot be inferred from the program and will require execution to compute. Note that these values have no need to adhere to any semantics of the program as long as we can compute the correct output and leave the inputs unchanged.
In white are values fully specified by the program — they are constant relative to any inputs and can be used for various optimizations before the first run.
Shown below is the graph generated by Static Runtime for execution of this function.
Static Runtime (SR)
There are a couple of small ideas to cover.
First, we can see that the function input `Value`s have been realized as a vector of `IValue`s. These `IValue`s are actually empty before the first run and populated in place when the function is invoked.

Next, we see that the `Node`s of TorchScript IR are now complex boxes with multiple components. These complex boxes represent `ProcessedNode`s in the SR codebase. `ProcessedNode`s contain a vector of inputs (by reference), an executable function, and a vector of outputs (by value). After correctly linking the vector of inputs to the associated outputs of previous `ProcessedNode`s or to the vector of function inputs, we can easily execute each node's function in topological order to compute the correct output. This technique avoids refcount bumps and provides a stable representation that permits analysis and memory optimizations.
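The linkage described above can be sketched in plain Python. This is an illustrative model only, not the real C++ implementation; the class shape and slot layout here are assumptions made for the example:

```python
class ProcessedNode:
    def __init__(self, fn, input_slots, output_slots):
        self.fn = fn                    # executable function for this node
        self.input_slots = input_slots  # indices into the shared value table
        self.output_slots = output_slots

    def run(self, values):
        # Gather inputs by slot index, then write results back in place.
        args = [values[i] for i in self.input_slots]
        results = self.fn(*args)
        for slot, result in zip(self.output_slots, results):
            values[slot] = result

# Slots 0-3 hold the function inputs; slots 4-6 hold activations,
# empty before the first run and populated in place when invoked.
values = [1.0, 2.0, 3.0, 4.0, None, None, None]
nodes = [
    ProcessedNode(lambda a, b: (a + b,), [0, 1], [4]),  # a0 = add(i0, i1)
    ProcessedNode(lambda a, b: (a * b,), [4, 2], [5]),  # a1 = mul(a0, i2)
    ProcessedNode(lambda a, b: (a * b,), [5, 3], [6]),  # a2 = mul(a1, i3)
]
for node in nodes:  # execute in topological order
    node.run(values)
```

Because each node reads and writes fixed slots in one shared table, no per-call argument packing or refcount traffic is needed between nodes.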
A keen observer will notice that some of the original nodes in the TorchScript IR are gone. The first nodes to disappear are `prim::Constant`s, which can be run ahead of time and baked into the `ProcessedNode` functions either by lambda capture or codegen. Another class of nodes that disappear are `prim::ListConstruct` (as well as `prim::TupleConstruct`), as the SR-affiliated `aten::stack` functions can handle variable-length inputs. This is possible due to the static nature of the program: we can predetermine the number of inputs and skip the construction of a list.
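The list-elimination idea can be sketched as follows. `cat_vararg` is a hypothetical stand-in used for illustration, not an SR operator:

```python
# With a static graph, the number of inputs to a concatenation is known
# ahead of time, so a variadic op can consume its inputs directly and
# the prim::ListConstruct node that would have fed it can be dropped.

def cat_vararg(*parts):
    out = []
    for part in parts:
        out.extend(part)
    return out

# Equivalent to cat(list_construct(a, b)) without building a list node.
cat_vararg([1.0, 2.0], [3.0])
```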
Finally, note the two boxes in the top left labeled “Bytes.” These represent the actual memory backing the activations (activations are highlighted in green in the TorchScript IR graph). We can associate different tensors with the same memory as long as they are never live at the same time. This is straightforward to calculate with a Static Runtime graph, but it requires knowledge of the size of the activations. Currently, SR tracks the previous size and assumes it will be consistent run to run, with a fallback to `resize_` if it is not. This is equivalent to the Caffe2 memonger pass. Future work will likely adopt the shape propagation enabled by meta tensors to avoid this step.
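The lifetime-sharing idea can be sketched with a toy first-fit planner. This is a hedged illustration of the memonger-style concept, not SR's actual allocator; all function names and the interval representation here are hypothetical:

```python
def overlaps(a, b):
    # Two (first_use, last_use) intervals overlap if neither ends
    # before the other begins.
    return a[0] <= b[1] and b[0] <= a[1]

def plan_offsets(lifetimes, sizes):
    """Assign each activation an offset in one shared slab; activations
    whose lifetimes never overlap may share the same bytes."""
    offsets = {}
    for name in sorted(lifetimes, key=lambda n: lifetimes[n][0]):
        # Regions already claimed by activations live at the same time.
        conflicts = sorted(
            (offsets[other], offsets[other] + sizes[other])
            for other in offsets
            if overlaps(lifetimes[name], lifetimes[other])
        )
        # First fit: lowest offset that avoids every conflicting region.
        offset = 0
        for lo, hi in conflicts:
            if offset + sizes[name] <= lo:
                break
            offset = max(offset, hi)
        offsets[name] = offset
    return offsets
```

For the example function above, `a0` dies as `a1` is produced, so `a0` and `a2` can share the same region and the slab needs room for only two activations rather than three.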
Static Runtime makes a couple of important assumptions about models that users should be aware of. The assumptions attempt to reflect typical inference use cases. Any complex or subtle assumptions are verified and accounted for in the provided fusion pass.
Gradients in PyTorch use a tape-based system that is useful for eager mode but isn't necessary in graph mode. As a result, Static Runtime strictly ignores tape-based gradients. Training support, if planned, will likely require graph-based autodiff rather than the standard autograd used in eager-mode PyTorch.
CPU is currently assumed, but this will not be the case going forward. Support for GPU is planned for H1 2021.
Immutable Module Attributes
This is an extension of the module freezing pass in TorchScript. All attributes, including custom user defined classes, must not be mutated over the course of execution. This frees SR to move, copy and pre-compute the results of transformations to module attributes.
No “Hidden” Writes to Aliased Memory
In-place and view-producing operations are allowed by SR when they do not write to inputs. The provided fusion pass will ensure this, so users do not need to worry about this subtle assumption unless they are trying to use Static Runtime with a full model. View-producing operations are handled specially by SR, and they limit the runtime's ability to optimize the memory plan.
There are a number of performance opportunities exposed by the above assumptions.
Tensor object overhead reduction
By reusing tensor objects and only remapping the underlying memory run-to-run, Static Runtime can avoid many refcount bumps from the creation and copying of tensor objects. Refcounts still exist in the framework, as not all operators have been mapped to variants that are compatible with pre-allocated tensor objects.
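The mechanism behind such reuse can be sketched with an "out variant" in plain Python. This is an illustrative model of the pattern, not PyTorch's actual out-variant machinery; `mul_out` is a hypothetical name:

```python
def mul_out(a, b, out):
    # Write results into the caller-provided buffer instead of
    # allocating a fresh output object on every call.
    for i in range(len(out)):
        out[i] = a[i] * b[i]
    return out

out = [0.0] * 3                   # allocated once, before the first run
r1 = mul_out([1, 2, 3], [2, 2, 2], out)
r2 = mul_out([4, 5, 6], [2, 2, 2], out)
assert r1 is out and r2 is out    # the same object is reused every run
```

Because the output object survives across runs, each invocation skips the allocate/copy/destroy cycle (and the refcount bumps that come with it) that a fresh output would incur.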
Dispatch overhead reduction
By avoiding device selection, autograd tracking, and more generally, the dispatch logic required for PyTorch eager, Static Runtime can jump into kernels more directly.
Constants in the IR can be preprocessed, inlined into C++ lambdas, and even used directly for kernel generation. They can also help provide further refinement and reduce dispatch overhead. As an example, the scale factor in `aten::add` can be shown to have no impact if equal to one, permitting dispatch to a pure point-wise addition kernel.
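This kind of constant-driven refinement can be sketched as follows. The sketch is illustrative, not SR code; `make_add_kernel` is a hypothetical name, and real kernels would operate on tensors rather than lists:

```python
def make_add_kernel(alpha):
    # When the IR proves alpha == 1 before the first run, bind a pure
    # point-wise addition and skip the scaling multiply entirely.
    if alpha == 1:
        return lambda a, b: [x + y for x, y in zip(a, b)]
    # General kernel: second operand scaled by alpha.
    return lambda a, b: [x + alpha * y for x, y in zip(a, b)]

kernel = make_add_kernel(1)   # chosen once, ahead of time
kernel([1, 2], [3, 4])        # -> [4, 6]
```

The decision is made once at load time, so every subsequent run pays neither the branch nor the multiply.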
Static Runtime maps operators to variants it can use most efficiently. This typically includes PyTorch variants that can write to preallocated outputs, but can on the rare occasion include SR-specific implementations that are not representable in the TorchScript IR (such as a var-arg concatenation).
Memory overhead reduction
With a full view of the graph and a recording of the previous sizes associated with Tensors, Static Runtime can slab allocate memory at the beginning of each invocation. This reduces the number of calls to allocators as well as potential fragmentation. Although not enabled in all cases, it exposes the opportunity for run-to-run memory reuse.
On top of slab allocation, Static Runtime's holistic understanding of memory access patterns enables it to optimize memory reuse within the model. Caffe2's "memonger" employed a similar concept, and it is a crucial feature for high performance models in domains of high memory pressure.
Low overhead profiling
Static Runtime has its own low-overhead profiler to reveal the distribution of operator execution times without inadvertently including other overheads.
The above example shows Static Runtime optimizing an entire TorchScript IR graph. However, TorchScript IR is a subset of Python that Static Runtime does not attempt to handle in full. Instead, subgraphs of a TorchScript graph can be lowered to Static Runtime. In Python this can be done with `torch._C._fuse_to_static_runtime(model)`. This approach provides out-of-the-box compatibility with a large set of models.
TorchScript IR is the basis of Static Runtime, so any passes that apply to TorchScript IR can be run before handing execution to Static Runtime. This means performance gained from forms of tracing and transformation such as FX or torch.jit.trace is fully compatible with Static Runtime. Further, annotations to the TorchScript graph do not affect Static Runtime's ability to analyze and correctly lower subgraphs. Thus, profiled graphs are fundamentally additive (by providing increased constant semantics to the subgraphs) and will likely yield benefits in the future. Finally, passes that fuse the graph or lower into different types of nodes (like NNC or quantization) can still run and may benefit from the memory optimizations of Static Runtime.
Static Runtime comes with a set of assumptions that are used to support high performance. These assumptions are possible to prove in more general contexts and many ideas in Static Runtime can and will be upstreamed into TorchScript proper. The flexibility of Python makes it very hard to correctly optimize without explicit user input, so Static Runtime will remain a crucial staging ground for evolving performance ideas.