State of symbolic shapes branch

State of symbolic shapes: Aug 12, 2023 edition

Previous update: State of symbolic shapes branch - #66 by ezyang

Executive summary

I’m trying something a little different this time: expanding the update to cover a wider variety of topics beyond dynamic shapes, mostly centered around things I’m personally involved in (which is a lot of things, so you should be getting pretty good coverage this way!)

Benchmarking

  • Inductor CI/perf is upgraded to CUDA 12 / gcc 9. This doesn’t seem to have any appreciable effect on perf, but we did it so we could do the next item.
  • torchrec_dlrm is back. These benchmarks were disabled a few months ago because of fbgemm nightly-related flakiness, which has since been resolved by building fbgemm/torchrec from source in the Docker image. They are now installed as part of the general torchbench installs, and should help some of the work we are doing on jagged tensors (since many important operators are currently implemented in fbgemm).
  • Algorithmic efficiency. Frank Schneider posted about how PyTorch was slower than JAX in their upcoming algorithmic-efficiency benchmark suite. A bunch of us, spearheaded by @msaroufim, jumped in to take a look at what was going on. Status updates at https://docs.google.com/document/d/1okqKS32b0EhWQSFFoSV6IjGlYM4VhNYdxBPjdlFIw5w/edit (Meta-only). I personally have an interest in the dlrm side of things, since I’ve been working on sparse arch recently; after fixing some mild bugs, I was able to show parity on criteo1tb dlrm between PyTorch nightly and JAX on an A100x8 (PyTorch score: 7703.403180360794, JAX score: 7703.041719198227), although the number of evals varied, so I’m not sure if this is a threat to validity. Unfortunately, this does not necessarily help their problem, which was an OOM. To make further progress on this, we may need some tools to help us understand why torch.compile memory usage is higher than eager; see the sketch after this list for the kind of baseline comparison I have in mind.
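
As a starting point for the memory investigation, here is a minimal sketch (my own, not part of the algorithmic-efficiency suite) of comparing peak CUDA memory between eager and torch.compile on the same workload; the model and input are placeholders.

```python
import torch

# Minimal sketch (placeholder model/input) comparing peak CUDA memory between
# eager and torch.compile. A real tool would also want to attribute *where*
# the extra memory comes from, not just measure the peak.
def peak_memory(fn, *args):
    torch.cuda.reset_peak_memory_stats()
    fn(*args)
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated()

model = torch.nn.Linear(4096, 4096).cuda()    # placeholder model
x = torch.randn(1024, 4096, device="cuda")    # placeholder input

compiled = torch.compile(model)
compiled(x)  # warm up so compilation-time allocations don't skew the peak

print(f"eager:    {peak_memory(model, x) / 2**20:.1f} MiB")
print(f"compiled: {peak_memory(compiled, x) / 2**20:.1f} MiB")
```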

Export

Distributed

  • Tracing FSDP. @voz wrote a post (Meta-only) about the state of tracing FSDP in Dynamo. The key info is that on a branch, he can trace everything through and get results identical to eager on a single forward-backward. There are a lot of fixes that still need to land on main; from his post:
    1. Deciding between various small changes to FSDP to make this work vs. adding fixes in Dynamo (pretty easy, preferring Dynamo of course, but for some mostly no-op shuffling we do FSDP as well)
    2. TypedStorage - is it tensor-like/tensor-associated enough to go in the graph? Do we need to add some ops for doing tensor typed storage data ptr comparison / checking free, etc?
    3. Working through the CUDA stream story, in particular around wait_stream and such
    4. Lots of little bug fixes here and there
    5. Coverage for missing comparisons, bytecode ops, general coverage gaps like attr access on FSDP modules, setting data on a tensor, etc.
  • pytrees slow again for DTensor. Junjie and Rodrigo have been trying to improve DTensor’s eager perf, and we spent the first half of composability sync talking about it. Rodrigo had a hack to pre-compile pytree applications into Python code, but apparently this doesn’t help that much: gist:5427cabfab6421d4e104905345f94a50 · GitHub. Another suggestion from the meeting was that after Brian’s subclass support lands, maybe you could torch.compile each op individually with backend="eager" (see the first sketch after this list).
  • Data-dependent all2all. Will Feng got the all2all collective working in Inductor: https://github.com/pytorch/pytorch/pull/106655/. This is notable because all2all has a data-dependent output shape; it looks like unbacked symints worked here! (The second sketch after this list shows why the output shape is data-dependent.)
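
On the torch.compile-each-op-with-backend="eager" suggestion, here is a minimal sketch of the general pattern (plain tensors; whether this pays off for DTensor’s per-op dispatch is exactly the open question):

```python
import torch

# Minimal sketch: compile a single-op wrapper with the "eager" backend.
# Dynamo traces and caches the call (amortizing Python-side overhead such as
# pytree flattening on cache hits) while skipping Inductor codegen entirely.
@torch.compile(backend="eager")
def single_op_add(x, y):
    return torch.add(x, y)

a = torch.randn(4, 4)
b = torch.randn(4, 4)
out = single_op_add(a, b)   # first call compiles; later calls hit the cache
```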
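
For context on why all2all is interesting here, a minimal sketch of the usual pattern (assuming a process group is already initialized, e.g. via torchrun with the gloo backend): the output length is the sum of split sizes received from the other ranks, which is only known at runtime, hence unbacked SymInts under torch.compile.

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already run (e.g. under torchrun).
def variable_all2all(inp, input_split_sizes):
    # First exchange the split sizes themselves (one integer per rank).
    in_sizes = torch.tensor(input_split_sizes)
    out_sizes = torch.empty_like(in_sizes)
    dist.all_to_all_single(out_sizes, in_sizes)
    # The output length depends on the data exchanged above -- this is the
    # data-dependent output shape that becomes an unbacked SymInt when traced.
    out = inp.new_empty(int(out_sizes.sum()))
    dist.all_to_all_single(
        out, inp,
        output_split_sizes=out_sizes.tolist(),
        input_split_sizes=input_split_sizes,
    )
    return out
```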

Custom ops

  • Custom ops. Richard tells me he is going to add a class-based API for custom ops, to make it easier to define everything all in one place (today the schema and each implementation are registered in separate calls; see the sketch after this list). More on this soon, I assume!
  • Custom op testing. https://github.com/pytorch/pytorch/pull/106903 is here to make it easier to retrofit pre-existing test suites to also test for important operator properties.
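
For reference on why a class-based API is appealing, here is roughly how a custom op gets registered today with torch.library (the “mylib::my_sin” op is made up): the schema and each implementation live in separate calls, which a class-based API would presumably gather into one place.

```python
import torch

# Function-style registration of a hypothetical custom op: each piece
# (schema, per-dispatch-key implementations) is a separate call.
lib = torch.library.Library("mylib", "DEF")
lib.define("my_sin(Tensor x) -> Tensor")

def my_sin_impl(x):
    return torch.sin(x)

# Meta/autograd/etc. registrations would be additional calls in this style.
lib.impl("my_sin", my_sin_impl, "CompositeExplicitAutograd")

out = torch.ops.mylib.my_sin(torch.randn(3))
```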

Nested/jagged tensor

Dynamo

  • Pivot on per-NN module caching. @anijain2305 is working on having a separate code cache per NN module, but on Friday, with the help of @voz, we realized that the problem is actually separable into two pieces: (1) an enhanced cache size limit policy that knows about NN modules [RFC][dynamo] Separate cache sizes for nn module guard specialization by anijain2305 · Pull Request #107077 · pytorch/pytorch · GitHub and (2) improvements to cache lookup when there are a lot of cache entries (guard trees).
  • Dynamo eager mode cond. Yidi Wu: to support cond in eager mode, we plan to torch.compile the entire cond operator, manufacturing fresh code objects to ensure that the caches don’t interfere with each other. https://docs.google.com/document/d/1esmHEa0fiktiSw1lvRsPmsbnTYxDSc0t3V9V5V0xK7I/edit#heading=h.pajqpbewbdg7 (Meta-only) (A sketch of what cond usage looks like under torch.compile today follows this list.)
  • Time to get rid of functional VariableTracker? VariableTracker in Dynamo is an immutable data structure: when a mutation happens, you allocate a fresh VariableTracker and then replace old VariableTrackers with the new one. This is because we have checkpointing functionality that is used to rewind to old VariableTrackers. However, this is a bit of a pain from the modeling side, as every Python data structure has to be reimplemented to have purely functional operations. An alternate design is to allow direct mutation of VariableTrackers. To checkpoint, we would simply restart Dynamo analysis to “go back in time”, stopping execution at the point where we would have checkpointed (a deepcopy could also work, though I’m not a fan). Speculate subgraph would be implemented by simply denying all mutations, or by doing some crazy thermometer continuation thing. This would help make Dynamo more metacircular and reduce the work needed to support new container types, of which we often need to support a lot. (A toy illustration of the two styles follows this list.)
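
For reference, a minimal sketch of what cond looks like under torch.compile today (using the functorch.experimental.control_flow entry point); the eager-mode plan above is about making the uncompiled call path behave the same way.

```python
import torch
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

# Both branches are traced into the graph; which one runs is decided by the
# predicate at runtime, so we never specialize on the predicate's value.
@torch.compile
def f(pred, x):
    return cond(pred, true_fn, false_fn, [x])

x = torch.randn(4)
print(f(torch.tensor(True), x))
print(f(torch.tensor(False), x))
```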
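
To make the trade-off concrete, a toy illustration (these are not Dynamo’s real classes): the functional style copies on every “mutation” so old states survive for checkpointing, while the mutable style edits in place and recovers old states by restarting analysis.

```python
# Toy illustration only -- not Dynamo's actual VariableTracker hierarchy.

class FunctionalListVariable:
    """Functional style (today): mutation returns a new tracker, so earlier
    states remain reachable and can serve as checkpoints to rewind to."""
    def __init__(self, items=()):
        self.items = tuple(items)

    def append(self, item):
        return FunctionalListVariable(self.items + (item,))

class MutableListVariable:
    """Mutable style (proposal): mutate in place; rewinding means restarting
    Dynamo analysis and stopping earlier, rather than keeping old copies."""
    def __init__(self, items=()):
        self.items = list(items)

    def append(self, item):
        self.items.append(item)

v0 = FunctionalListVariable()
v1 = v0.append(1)           # v0 is untouched and usable as a checkpoint
m = MutableListVariable()
m.append(1)                 # m itself changed; no old state to roll back to
```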

Dynamic shapes

  • expect_true irrefutable guards. I talked through this in the last 20 minutes of composability sync. Check out https://github.com/pytorch/pytorch/pull/106720; this is enough to make splits on unbacked SymInts work (see the first sketch after this list).
  • Boolean masking, at last. @yanboliang is looking into a pre-autograd FX transform that replaces boolean mask updates with torch.where calls (see the second sketch after this list). One annoying detail: how do we deal with Dynamo tracing the boolean masks in the first place, when Inductor can’t handle boolean masks that we fail to eliminate? Our idea, in lieu of fixing Inductor to work with data-dependent shapes (which we are working on), is to attempt to eliminate all data-dependent ops in a pre-dispatch pass, and if that is not possible, restart Dynamo analysis saying “you need to graph break on this op next time.”
  • Notable fixes.
  • Notable new bugs.
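
Here is a minimal sketch of the kind of program the expect_true work unblocks (assuming scalar output capture is enabled): the split sizes are unbacked SymInts, and the check that they sum to the dimension size is irrefutable, so expect_true defers it to a runtime assert instead of failing at compile time.

```python
import torch
import torch._dynamo

# Data-dependent split: the sizes come from tensor data, so they are unbacked
# SymInts under torch.compile; "sum(sizes) == x.size(0)" cannot be decided at
# compile time and becomes a deferred runtime assert.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def split_by_counts(x, counts):
    sizes = counts.tolist()          # unbacked SymInts when compiled
    return torch.split(x, sizes)

x = torch.randn(10)
counts = torch.tensor([3, 7])
print([t.shape for t in split_by_counts(x, counts)])
```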
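
And a minimal before/after illustration of the boolean-mask rewrite (a made-up example): the masked update involves a data-dependent number of elements, while the torch.where form keeps every shape static.

```python
import torch

x = torch.randn(8)
mask = x > 0

# Before: boolean mask update; how many elements get written is data-dependent.
y = x.clone()
y[mask] = 0.0

# After: equivalent torch.where form; all shapes are static.
z = torch.where(mask, torch.zeros_like(x), x)

assert torch.equal(y, z)
```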

Numbers

Training. 03414081ff Dashboard

  • Some accuracy regressions. torchbench: hf_BigBird, vision_maskrcnn (flaky). It’s not clear what broke hf_BigBird; possibly the CUDA 12 upgrade. Need to investigate. AlbertForQuestionAnswering improved accuracy!
  • The huge perf improvement across the board is thanks to Peter Bell’s work optimizing split reductions: https://github.com/pytorch/pytorch/pull/106747. This is not full runtime split reductions: instead, Peter uses whatever the hint was at the time we compiled to plan the split reduction, and then we use that plan for all subsequent runs. This makes it more important to warm up Inductor with the “right” size hint to start (see the sketch after this list); see also Padded tensor subclass · Issue #105325 · pytorch/pytorch · GitHub. There was also another user complaining about other cases where we made suboptimal decisions when the first kernel we compiled with wasn’t representative.
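
As a practical illustration of the warm-up point, a minimal sketch (my example, not Peter’s): with a dynamic compile, the size hint seen at compile time plans the split reduction, and later sizes reuse that plan.

```python
import torch

# The size hint at compile time is what plans the split reduction, and the
# plan is reused for subsequent runs -- so warm up with a representative size
# rather than a tiny smoke-test input.
@torch.compile(dynamic=True)
def row_sums(x):
    return x.sum(dim=-1)

row_sums(torch.randn(64, 2**20, device="cuda"))  # representative size hint
row_sums(torch.randn(64, 2**22, device="cuda"))  # reuses the plan chosen above
```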

Inference. 03414081ff Dashboard

  • A lot of change on the last day; some improvements and some regressions (but mostly regressions). Maybe related to the CUDA 12 update; need to check. hf_BigBird is also failing here. RobertaForQuestionAnswering is now failing accuracy.