State of symbolic shapes branch

State of symbolic shapes: Feb 19 edition

Previous update: State of symbolic shapes branch - #42 by ezyang

Executive summary

The branch cut for PyTorch 2.0 has come and gone. Training support did not land in master in time for the branch cut, so it is unlikely to be working in the official PT2 release (on the bright side: inference functionalization made it in before the branch cut! We think this fixed some inductor inference accuracy problems). Stick to testing dynamic shapes on inference for now, sorry folks. A lot of this week’s work focused on inference export scenarios: there was quite a bit of progress on unbacked SymInts, the beginnings of a fine-grained dynamic shape API are in master, and we also covered the 0/1 specialization problem in composability sync.

  • Unbacked SymInt progress. We made a lot of progress on unbacked SymInt support; most notably, value range analysis for unbacked SymInts has made it to trunk, so you can now use constrain_range on an unbacked SymInt to discharge guards (e.g., by marking an unbacked SymInt as >= 0). The entirety of CrystalDPR has successfully been traced end-to-end (Meta-only link), and I’m currently in the process of landing the necessary changes / cleaning up hacks. The biggest newly discovered item we have to handle is when user code tests whether a mask has any True elements in it and, if so, performs a boolean mask (otherwise, it lets all the elements through); to faithfully export this, we must use torch.cond, and we must also be able to join divergent shape variables together (into an unbacked SymInt); see the first sketch after this list. We also identified that guard-free implementations of PyTorch must be allowed to change stride behavior sometimes; we will justify this under stride-agnostic PyTorch (which says that changes in strides are not allowed to affect program semantics).
  • Fine-grained dynamic shapes. Voz has landed an initial version of fine-grained dynamic shape control in https://github.com/pytorch/pytorch/pull/94787 . This is a compromise API: you still have to set dynamic=True and assume_static_by_default=False, pending some refactors that will remove this requirement. The PR as landed still uses backed SymInts, and only checks whether a dynamic dimension was overconstrained at the end of trace time; switching it to use unbacked SymInts for export, along with better diagnostics for why guards occurred, is upcoming work. See the second sketch after this list for the intended usage.
  • 0/1 specialization. Export is moving to using unbacked SymInts for exporting dynamic dimensions, to ensure that the resulting programs are not 0/1 specialized. This is the outcome of discussing the doc “The 0/1 specialization problem in Pt2 Export” (Google Docs) in composability sync this week; a short illustration of the underlying problem follows this list. If you’re curious what it all means, there’s a podcast giving basic background on the problem here: https://pytorch-dev-podcast.simplecast.com/episodes/zero-one-specialization
  • Model status on master.
    • aot_eager inference: -1 (+1 WoW). The last holdout is vision_maskrcnn, which fails due to an FX printing problem involving constants and in-place addition.
    • aot_eager training: 0 (unchanged). No regressions!
    • inductor inference: -5 (+3 WoW). We no longer have accuracy failures with dynamic shapes (we never root caused this, but I bet it’s related to inference functionalization); pytorch_unet was fixed by more improvements from Natalia.
    • inductor training: still waiting on Horace to land his patch
  • OpInfo tests on symbolic shapes.
    • 557 passed (+4 WoW), 524 skipped (+1 WoW), 197 xfailed (+1 WoW). New OpInfo for _upsample_bilinear2d_aa among others.
    • 306 passed (+1 WoW), 148 skipped (+2 WoW), 185 xfailed (no change).
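
To make the unbacked SymInt work above concrete, here is a minimal sketch (hedged: these are made-up functions for illustration, not CrystalDPR code, and the constrain_range import path is the one in trunk at time of writing). The first function uses constrain_range to assert that a data-dependent size is non-negative so guards on it can be discharged; the other two show the data-dependent mask pattern and a hypothetical torch.cond rewrite, assuming a cond that can join the branches’ divergent output sizes into an unbacked SymInt (exactly the join described above).

```python
import torch
from torch.fx.experimental.symbolic_shapes import constrain_range

def total_zeros(counts: torch.Tensor, x: torch.Tensor):
    # counts: an integer tensor. Under symbolic tracing, .item() produces an
    # unbacked SymInt; constrain_range marks it as >= 0 so that size guards
    # like `n >= 0` are discharged by value range analysis instead of
    # raising a data-dependent guard error.
    n = counts.sum().item()
    constrain_range(n, min=0)
    return x.new_zeros(n)

def filter_rows(x: torch.Tensor, mask: torch.Tensor):
    # The user-code pattern described above: a data-dependent branch on a mask.
    if mask.any():      # whether we take this branch depends on runtime data
        return x[mask]  # boolean masking: output length is data-dependent
    return x            # otherwise, let all the elements through

def filter_rows_exportable(x: torch.Tensor, mask: torch.Tensor):
    # Hypothetical export-friendly rewrite via torch.cond: both branches are
    # captured, and faithfully exporting this requires joining their divergent
    # output sizes into a single unbacked SymInt.
    return torch.cond(
        mask.any(),
        lambda x, mask: x[mask],
        lambda x, mask: x,
        (x, mask),
    )
```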
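
Similarly, a rough sketch of the intended usage of the fine-grained dynamic shape control from #94787 (flag names follow the post; treat the exact config knobs as subject to change as the compromise API gets refactored):

```python
import torch
import torch._dynamo as dynamo

# Compromise API: the global dynamic-shapes config still has to be on and
# assume_static_by_default off before per-dimension marking takes effect.
dynamo.config.dynamic_shapes = True
dynamo.config.assume_static_by_default = False

def f(x):
    return (x * 2).sum(dim=1)

x = torch.randn(4, 3)
# Ask the compiler to keep dimension 0 symbolic; with the PR as landed, you
# only find out at the end of trace time if guards overconstrained this
# dimension back to a static size.
dynamo.mark_dynamic(x, 0)

compiled = torch.compile(f, backend="eager")
out = compiled(x)  # dim 0 is traced as a (backed) SymInt rather than hard-coded to 4
```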
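
Finally, a tiny piece of background on why 0/1 sizes are special at all (an illustration only, not the export mechanism): a size-1 dimension broadcasts, so a program traced from an example input whose batch dimension happens to be 1 can bake in behavior that does not hold for a general batch size N, which is why export wants dynamic dimensions represented by unbacked SymInts rather than inferred from 0/1-sized example inputs.

```python
import torch

x1 = torch.randn(1, 4)  # example input whose batch dimension happens to be 1
x8 = torch.randn(8, 4)
w  = torch.randn(8, 4)

print((x1 + w).shape)  # torch.Size([8, 4]): the size-1 batch dim broadcasts against 8
print((x8 + w).shape)  # torch.Size([8, 4]): same shapes, plain elementwise add
# A trace captured from x1 that relied on the broadcast would not describe the
# x8 case, hence the tracer must either specialize on size 1 or avoid it.
```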

What’s made it to master since last time?

ezyang

voz

nkaretnikov

ngimel

bdhirsh

What’s coming next?

  • ezyang: landing the rest of the CrystalDPR enablement fixes, presenting on how unbacked SymInt enablement works, then probably more on calling convention (though I might also want to work on some real model enablement for a bit; depends on how much time I have)
  • Chillee: still landing inductor training
  • bdhirsh: per-dispatch key mode stacks and then torchquant/DTensor PT2 support
  • jbschlosser: not sure
  • voz: not sure