State of symbolic shapes branch

State of symbolic shapes: Feb 11 edition

Previous update: State of symbolic shapes branch - #40 by ezyang

Executive summary

It’s the eve of the branch cut for PyTorch 2.0. Training support has still not landed on master, though Horace says he’s made progress against some of the FSDP bugs plaguing his PR. Brian’s functionalization for inference is on the cusp of landing, after working through considerable flakiness in Inductor. The general mood is that we need more Inductor developers. We had a lot of forward-looking design discussion this week; check the bullets for details.

  • Tracing speed improved on master. Fast path binary ops in fake tensor by ezyang · Pull Request #94047 · pytorch/pytorch · GitHub has made it to master; on some models, it can cut dynamic shapes tracing time in half (a simplified sketch of the fast-path idea appears after this list). There is some low-hanging fruit for extending the approach to more operators. In general, we’ve found improving trace time challenging, as Python profiles tend not to identify problems directly (one thing we note is that we don’t get aggregation on a per-op basis because of all the metaprogramming we do; another hypothesis for our difficulty is the number of Python-C++ crossings we tend to do). We don’t think Sympy is the main cause of slowdown, but Sympy can cause other problems when it gets into pathological cases (see below). We need to start tracking this metric in our runs.
  • Dynamic shape induced graph breaks are way down. Voz and Joel have done a lot of good work chasing down extra graph breaks from dynamic shapes. The remaining holdout is avoiding graph breaks on negation; fixing the graph break is easy, but the resulting graph appears to cause Inductor to loop forever, and Voz was unsuccessful at diagnosing the problem.
  • Dynamic by default. jansel has argued us back into making dynamic-by-default a goal. The general idea is that we should be leaning into PyTorch’s historic flexibility and support for dynamic shapes. Jason’s straw proposal is that the first time we compile, we assume shapes are static, and upon recompile we compile with dynamic shapes (a toy sketch of this policy appears after this list). Elias cautions that if there aren’t too many distinct shapes in practice, it may be better to generate separate specialized kernels for each and retain use of CUDA graphs and autotuning. Most of us were generally positive on this idea, but concerned about the implementation costs: dynamic shapes’ longer trace time, greater bug surface, and lower generated code quality. Still, we can probably keep moving PT2 in this direction, but more data is also necessary.
  • Fine-grained dynamic dimensions. We want an API for marking dimensions as explicitly dynamic. There are two primary users for this: (1) export, where we want to reject specialization/guards on dimensions that are explicitly dynamic, as these usually indicate tracing problems, and (2) eager mode power users, who want fine-grained control over which dimensions should be compiled dynamically, e.g., to maximize performance, avoid unnecessary recompilation, or just diagnose why a model is recompiling when it shouldn’t. The fine-grained API is not intended to be the starting point for regular users: normal users should be able to annotate nothing and still get intelligent results (see the bullet above). We went through a lot of API variations, but our current plan is to ship an API that lets you mark tensors as having dynamic dimensions, which affects how torch.compile does compilation (a hypothetical sketch appears after this list). In the long term, we intend for eager mode to propagate these annotations using symbolic meta formulas (which is important for making this work across graph breaks). These annotations will NOT do anything if inserted inside a model; they only work at the “top level” where you trigger compilation. Minutes for this discussion are at https://docs.google.com/document/d/1aoIyYE8_6cYpWqS25thzVoIiKsT5aaUEOiiPwbIXt8k/edit. Sherlock has an out-of-date prototype of the API at Fine grain dynamic shape by SherlockNoMad · Pull Request #93813 · pytorch/pytorch · GitHub
  • torch.constrain. When you mark a dimension as dynamic, we reject any user code that would constrain the dimension. But what if the constraint is benign, say that the dimension != 0? You would like a way to say that this constraint is acceptable. torch.constrain is an API that can be called inside the model to indicate that these constraints should be accepted. Originally, Voz and Horace were imagining that you could put arbitrary Python expressions in these constraints. We’ve negotiated to only allow a more limited set of constraints to start: min/max and multiple-of. We plan not to allow “relational” constraints that relate two separate symbolic variables: we simply assume that these are always OK. This means that export still needs some mechanism for communicating “implicit” guards (much in the same way we implicitly guard on the dtypes of all tensor inputs). This API will likely also be used by unbacked SymInts. A big benefit of only allowing min/max constraints is that they are easy to implement: you do not need Z3 (see the value-range sketch after this list).
  • Unspecialize is adding too many float/int arguments. Horace noticed that forward graphs with dynamic shapes turned on have a lot of SymInt/SymFloat arguments. This is because we set specialize_int_float = False with dynamic shapes, but it turns out there are quite a few int/floats we SHOULD be specializing on. Horace is looking into this more.
  • SymInts/SymFloats in the backend compiler calling convention. In several conversations this week, we discussed whether Inductor should accept graphs that ONLY have Tensor inputs, or whether ints/floats are also valid inputs to the graphs. jansel and Natalia argued that turning ints/floats into 0d CPU scalar tensors works well and that Inductor supports it well. However, Horace pointed out that this does not work in general: if you have a backward graph for x.sum(), you need to do an x.expand(sym_int), where the SymInt isn’t necessarily derivable from any input to the backward graph (see the sum-backward sketch after this list). You can make it work, but the price is an FX graph that isn’t executable from Python anymore. So our current thinking is that we are doubling down on int/float inputs, unless jansel manages to change our minds.
  • Unbacked SymInts. Lots of progress: the stack at Get boolean masking to work with unbacked SymInts by ezyang · Pull Request #94523 · pytorch/pytorch · GitHub gets boolean masking working, and all but the last PR are passing CI (a small example of why masking needs unbacked SymInts follows this list). The general flavor of the work here is that you run into a lot of guards that fail on unbacked SymInts, but they all tend to be workaroundable one way or another. Edward next plans to tackle one of the internal models that needs this for mobile export, as well as getting range analysis going.
  • Model status on master. See also Symbolic shapes work items tracker - Google Sheets
    • aot_eager inference: -2 (+4 WoW). The CUDA 11.7 upgrade regression has been resolved. The remaining two errors are iadd-related and should be resolved by General in-place binary op support in dynamo by jbschlosser · Pull Request #94203 · pytorch/pytorch · GitHub
    • aot_eager training: 0 (unchanged). No regressions!
    • inductor inference: -8 (+2 WoW). The improvements come from Natalia’s hoisting work; there are also some newer fixes from Natalia which should help with the remaining errors.
    • inductor training: still waiting on Horace to land his patch
  • Opinfo tests on symbolic shapes.
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 553 passed (+3 WoW), 523 skipped (+1 WoW), 196 xfailed (+4 WoW). The increase in xfails are from some new RNG opinfos contributed by min-jean-cho.
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 305 passed (+1 WoW), 146 skipped (+2 WoW), 185 xfailed (no change)
  • Graph breaks on master. -2 (+17 WoW). We made a lot of progress! timm_efficientdet and XLNetLMHeadModel are the last stragglers.
  • Tracing cost of enabling dynamic shapes (NEW!) Mean: 28s (-5s WoW), Max: 660s (-145s WoW; swin_base_patch4_window7_224). TODO: explain
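
On trace time: here is a heavily simplified sketch of the fast-path idea from the tracing-speed bullet. It is not the actual code in #94047; it only shows the core shortcut, namely computing a pointwise binary op's output shape by broadcasting the input shapes directly instead of dispatching through the full meta-kernel machinery. The real implementation also has to deal with symbolic sizes (and which guards get introduced), type promotion, memory format, and falling back when the fast path does not apply.

```python
def broadcast_shapes(a, b):
    """Right-aligned broadcast of two shapes; sizes may be ints or SymInts."""
    out = []
    for i in range(max(len(a), len(b))):
        da = a[-1 - i] if i < len(a) else 1
        db = b[-1 - i] if i < len(b) else 1
        # Prefer the non-1 size; the real fast path must be careful about
        # which equality guards this introduces on symbolic sizes.
        out.append(db if da == 1 else da)
    return tuple(reversed(out))

assert broadcast_shapes((8, 1, 4), (3, 1)) == (8, 3, 4)
```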
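
For the dynamic-by-default bullet, a toy sketch of the "static first, dynamic on recompile" policy. All names here are illustrative; this is not the actual torch._dynamo cache logic, just the shape of the idea.

```python
class StaticThenDynamicCache:
    """Compile specialized on the first shapes seen; if shapes ever change,
    recompile once with dynamic shapes and reuse that artifact thereafter."""

    def __init__(self, compile_static, compile_dynamic):
        self.compile_static = compile_static    # specializes on exact shapes
        self.compile_dynamic = compile_dynamic  # compiles with symbolic shapes
        self.static_entry = None                # (shape_key, compiled_fn)
        self.dynamic_fn = None

    def __call__(self, fn, *args):
        key = tuple(tuple(a.shape) for a in args)
        if self.static_entry is None:
            self.static_entry = (key, self.compile_static(fn, args))
        cached_key, compiled = self.static_entry
        if key == cached_key:
            return compiled(*args)
        if self.dynamic_fn is None:             # shapes changed: go dynamic
            self.dynamic_fn = self.compile_dynamic(fn, args)
        return self.dynamic_fn(*args)
```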
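
For the fine-grained dynamic dimensions bullet, a hypothetical sketch of how the annotation might be spelled. The mark_dynamic name and signature are assumptions (the real API is still being designed; see Sherlock's prototype PR), but the properties shown come from the discussion: the annotation is applied to inputs at the top level, it only influences torch.compile, and it would have no effect if inserted inside the model body.

```python
import torch
import torch._dynamo  # the annotation entry point below is an assumed spelling

def model(x):
    # toy model for illustration
    return torch.relu(x) @ x.t()

x = torch.randn(4, 512)

# Hypothetical annotation: mark dim 0 (batch) as dynamic, leave dim 1 static.
# The exact name/signature is an assumption, not the finalized API.
torch._dynamo.mark_dynamic(x, 0)

opt_model = torch.compile(model)
out = opt_model(x)   # compilation should not specialize on x.shape[0]
```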
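
An illustrative sketch (not the proposed torch.constrain implementation) of why limiting constraints to min/max keeps things cheap: each symbolic size can carry a value range, applying a constraint is just an interval intersection, and discharging a guard like s0 != 0 becomes a range lookup with no SMT solver involved. Relational constraints between two symbols are exactly what would force something heavier like Z3.

```python
from dataclasses import dataclass

@dataclass
class SymbolRange:
    lo: int
    hi: int  # use a huge sentinel for "effectively unbounded"

    def constrain(self, lo=None, hi=None):
        new_lo = self.lo if lo is None else max(self.lo, lo)
        new_hi = self.hi if hi is None else min(self.hi, hi)
        if new_lo > new_hi:
            raise ValueError("constraint is unsatisfiable")
        return SymbolRange(new_lo, new_hi)

# A batch size s0 the user promises is never 0:
s0 = SymbolRange(0, 2**62).constrain(lo=1)
assert s0.lo >= 1   # the guard "s0 != 0" can now be discharged from the range
```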
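
Horace's x.sum() example from the calling-convention bullet, written out by hand (an illustration, not generated compiler output): the backward of a full reduction needs the original size, and that size is not recoverable from any tensor input to the backward graph, so it has to come in as a SymInt argument.

```python
# Forward: y = x.sum(), with x: f32[s0]. grad_output is 0-d, so s0 must be
# passed to the backward graph explicitly as a SymInt input.
def sum_backward(grad_output, s0):
    return grad_output.expand(s0)

# If only Tensor inputs were allowed, s0 would have to be smuggled in as a
# 0-d CPU tensor and converted back, and the FX graph would no longer be
# directly executable as plain Python.
```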
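
Finally, a small eager-mode example of why boolean masking forces unbacked SymInts: the output length of x[mask] depends on the values in mask, so when tracing with fake tensors there is no hint to back the new size, and any guard that inspects it has to be avoided or worked around.

```python
import torch

x = torch.randn(8)
mask = x > 0
y = x[mask]        # shape is (mask.sum(),): data-dependent, unknown at trace time
print(y.shape)     # e.g. torch.Size([5]); varies from run to run
```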

What’s made it to master since last time?

Commits landed by: ezyang, voz, jbschlosser, nkaretnikov, ngimel, bdhirsh.

What’s coming next?

  • ezyang: Unbacked SymInt range analysis, and then probably moving into the calling convention area
  • Chillee: landing inductor training, ??? not really sure ???
  • bdhirsh: landing inference functionalization, into per-dispatch key mode stacks and then torchquant/DTensor PT2 support
  • jbschlosser: helping with the core/export “fake sprint” https://docs.google.com/document/d/1NT2eCCT13QRqet_iCZzPXcjkBOK9qiuQKbaQsd8o4Tk/edit#heading=h.uh5wfznvwt0c ; straggler failures that aren’t captured in CI (moco, hf_BigBird)
  • voz: graph breaks, into dynamic to static guard rejection. Some minor rambling around with Z3.

Our north star: dynamic shapes at feature parity with static shapes. We are actively discussing how to get back to “turned on by default, for some definition of default” for PT 2.1.
