State of symbolic shapes branch

State of symbolic shapes: Aug 6, 2023 edition

Previous update: State of symbolic shapes branch - #65 by ezyang

Executive summary

  • More on KJT/torchrec. I had a nice discussion with Dennis van der Staay about torchrec and work on sparse arch. Some new information: (1) this workstream is almost certainly going to involve distributed later, because applying PT2 to post-torchrec sharded models is going to involve tracing past communication primitives; this also implies I’m going to want to get FakePG working on torchrec, (2) working on unit tests should be a pretty good idea, but there’s still some basic infra work to do (laid out last week), (3) we’re not really expecting concrete performance improvements, as sparse arch is typically going to be communication bound, so this is mostly a “we think this is promising, and the investment is not too big, because we’ve already done so much with dynamic shapes so far” situation.
  • Pre-dispatch export. We’ve agreed to let QAT publish, as a short-term measure, a new export interface that produces a pre-dispatch FX graph with ATen operators, suitable for graph transformations and training. The long-term goal is to have pre-dispatch functionalization, which is the invariant the export team wants before this can be worked into torch.export proper. Pre-dispatch will generate an ExportedModule so that the APIs match.
  • Fake export. Export now supports exporting entirely fake modules/inputs. This means that to export a model you don’t have to actually load its weights into memory; you can load it in a fake mode and still export it. As a consequence, we have some delicate code in Dynamo for dealing with two concurrent fake modes (but it’s not so bad: the outer fake mode is typically disabled while we do Dynamo analysis). Only ONNX supports torch.load’ing models in fake mode at the moment.
  • Improved user stacks in Dynamo. torch._guards.TracingContext.extract_stack() now always accurately reports a user stack from anywhere in Dynamo, and we reliably use it for reporting real stacks for exceptions (previously, they used an entirely different mechanism.)
  • Improved error messages for non-local inputs in export. See “Improve error message when export encounters non-local input” for the details. This isn’t complete; the follow-through is to make this work for outputs as well, and to work a little harder with the pytree representation (probably this week).
  • Dynamo change in attitude. Many folks are concerned that Dynamo is just “endless” bugs. I pitched Animesh and Voz on a new attitude to fixing Dynamo bugs, which is that we should imagine the platonic ideal implementation of Dynamo as a faithful reimplementation of CPython in Python. Then, fixing a bug should not just be moving code around to fix a particular problem, but instead improving the local vicinity of code to bring it closer in line with this ideal. An example I used a lot when explaining this was dict.keys support (the bug fix is changing its type from tuple to set; the real fix is to accurately model dict views). To do this well, you need to regularly look at CPython code, and Dynamo may need to grow some new abstractions (perhaps a proper implementation of Python’s object model, or Python traceable polyfills).
  • Notable new bugs.
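
To make the dict.keys example from the Dynamo bullet concrete, here is a small plain-Python sketch (no Dynamo involved) of the behavior a faithful model has to reproduce. It only shows what CPython does; it is not how Dynamo represents dict views internally. A tuple snapshot gets the ordering right but not the set operations, a set snapshot gets the set operations right but not the ordering, and neither one tracks later mutations of the dict.

```python
from collections.abc import Set

d = {"a": 1, "b": 2}
keys = d.keys()            # a live view onto d, not a copy

# Two simplified models of d.keys(), each taken as a snapshot.
as_tuple = tuple(d)        # ordered, but frozen and not set-like
as_set = set(d)            # set-like, but frozen and unordered

d["c"] = 3                 # mutate the dict after creating the view

live_len = len(keys)       # the view observes the new key
tuple_len = len(as_tuple)  # the snapshots do not
set_len = len(as_set)

common = keys & {"a", "c", "z"}      # views support set operations
ordered = list(keys)                 # and iterate in insertion order
is_set_like = isinstance(keys, Set)  # registered as an abstract Set
has_add = hasattr(keys, "add")       # but they are read-only
```

Any single one of these properties is easy to patch in as a one-off bug fix; modeling the view object itself gets all of them at once.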


As we haven’t been doing much on performance numbers recently, I am simplifying this section.

Training. 68cb854d73 Dashboard

Nothing much to report.

Inference. 68cb854d73 Dashboard

The big perf increase in torchbench is due to maml getting removed from the benchmark set (it slows down a lot under PT2 and was depressing the score). clip, hf_Whisper, llama_v2 are new models added thanks to @msaroufim!

What’s next?

There are a lot of things that need doing:

  • Finish overhauling export input/output pytree matching (probably not dumping the pytree in/out spec, but I think if we tree_map into positional identifiers we can reliably detect KJT missing situations)
  • Make unbacked SymInts work in Inductor (gist:1293a41299604c44310341b7540eabcb; biggest problem is unbacked SymInt binding in wrapper codegen and the hinting logic)
  • Irrefutable guards
  • Write up the plan for sparse arch / KJT
  • Land pytree support for KJT/JT
  • 0/1 specialization suppression for list of int in KJT
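
As a rough illustration of the “tree_map into positional identifiers” idea from the first item above, here is a toy sketch in plain Python. Everything in it is hypothetical: leaf_paths, find_unmatched, OpaqueBag (a stand-in for a container like KJT that the flattener does not know how to open) and T (a stand-in for a tensor) are made-up names, and the real implementation would go through the actual pytree machinery rather than this hand-rolled traversal. The point is only that once every leaf is labeled with its position, a graph input that corresponds to no position is easy to detect and report.

```python
class T:
    """Stand-in for a tensor (hypothetical)."""


class OpaqueBag:
    """Stand-in for a container the flattener cannot see inside (hypothetical)."""

    def __init__(self, values, lengths):
        self.values = values
        self.lengths = lengths


def leaf_paths(tree, prefix=()):
    """Map each positional path to the leaf found there."""
    if isinstance(tree, (list, tuple)):
        items = enumerate(tree)
    elif isinstance(tree, dict):
        items = tree.items()
    else:
        return {prefix: tree}  # anything else is treated as a leaf
    out = {}
    for key, child in items:
        out.update(leaf_paths(child, prefix + (key,)))
    return out


def find_unmatched(user_inputs, graph_inputs):
    """Return indices of graph inputs that are not a leaf of user_inputs."""
    known = {id(leaf) for leaf in leaf_paths(user_inputs).values()}
    return [i for i, g in enumerate(graph_inputs) if id(g) not in known]


x, vals, lens = T(), T(), T()
bag = OpaqueBag(vals, lens)
user_inputs = (x, {"features": bag})

paths = leaf_paths(user_inputs)
# The graph was traced against the tensors *inside* the bag.
unmatched = find_unmatched(user_inputs, [x, vals, lens])
```

Here the bag itself shows up as a leaf at position (1, "features"), while the two tensors inside it match no position, which is the “KJT missing” situation the error message needs to flag.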

Stuff that probably can wait until later?

  • Host side torch.cond
  • DynTensor