State of symbolic shapes: Aug 12, 2023 edition
Previous update: State of symbolic shapes branch - #66 by ezyang
Executive summary
I’m trying something a little different: this update covers a wider variety of topics beyond dynamic shapes, mostly centered on things I am personally involved in (which is a lot of things, so you should still be getting pretty good coverage this way!).
Benchmarking
- Inductor CI/perf is upgraded to CUDA 12 / gcc 9. This doesn’t seem to have any appreciable effect on perf, but we did it so we could do the next item.
- torchrec_dlrm is back. They were disabled a few months ago because of fbgemm nightly related flakiness. The flakiness has been resolved by building fbgemm/torchrec from source in the Docker image. These are now installed as part of the general torchbench installs, and should help some of the work we are doing on jagged tensors (since many important operators are currently implemented in fbgemm).
- Algorithmic efficiency. Frank Schneider posted about how PyTorch was slower than JAX in their upcoming algorithmic-efficiency benchmark suite. A bunch of us, spearheaded by @msaroufim, jumped in to take a look at what was going on. Status updates at https://docs.google.com/document/d/1okqKS32b0EhWQSFFoSV6IjGlYM4VhNYdxBPjdlFIw5w/edit (Meta-only). I personally have an interest in the dlrm side of things, since I’ve been working on sparse arch recently; after fixing some mild bugs, I was able to show parity on criteo1tb dlrm between PyTorch nightly and JAX on an A100x8 (PyTorch score: 7703.403180360794, JAX score: 7703.041719198227), although the number of evals varied, so I’m not sure if this is a threat to validity. Unfortunately, this does not necessarily help their problem, which was an OOM. To make further progress on this, we may need some tools to help us understand why torch.compile memory usage is higher.
Export
- Pre-dispatch export, part 2. We had more discussion about pre-dispatch export in the Friday export meeting. @suo in particular argued that, from a frontend perspective, it would make more sense to export pre-dispatch IR by default, and have the further post-dispatch lowerings be an extra pass on top that backends opt into. One identified barrier to doing this is pre-dispatch functionalization; the other is nondifferentiable decomps. nkaretnikov is going to take a look at `core_aten_decompositions` to see which of these are differentiable and which are not. In other news, torch.export is going platinum: Expose torch.export() API by gmagogsfm · Pull Request #106904 · pytorch/pytorch · GitHub
- dim order coming to Tensor. We probably should have added this API a long time ago, but export really wants this on Tensor, so in it goes. [PyTorch][Tensor] Introduce tensor.dim_order by digantdesai · Pull Request #106835 · pytorch/pytorch · GitHub
Distributed
- Tracing FSDP. @voz wrote a post Redirecting... (Meta-only) about the state of tracing FSDP in Dynamo. The key info is that on a branch, he can trace everything through and get identical results on a single forward-backward to eager. There’s a lot of fixes that need to land to main; from his post:
- Weighing various small changes to FSDP itself against adding fixes in Dynamo (pretty easy; we prefer Dynamo, but for some mostly no-op shuffling we change FSDP as well)
- TypedStorage - is it tensor-like/tensor-associated enough to go in the graph? Do we need to add some ops for doing tensor typed storage data ptr comparison / checking free, etc?
- Working through the cudastream story, in particular around wait_stream and such
- Lots of little bug fixes here and there
- Coverage for missing comparisons, bytecode ops, general coverage gaps like attr access on FSDP modules, setting data on a tensor, etc.
- pytrees slow again for DTensor. Junjie and Rodrigo have been trying to improve DTensor’s eager perf, and we spent the first half of composability sync talking about it. Rodrigo had a hack to pre-compile pytree applications into Python code, but apparently this doesn’t help that much: gist:5427cabfab6421d4e104905345f94a50 · GitHub. Another suggestion from the meeting was that after Brian’s subclass support lands, maybe you could torch.compile each op individually with backend="eager".
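The per-op compile suggestion can be sketched as follows; plain tensors stand in for DTensor here to keep the example self-contained, and the idea is that Dynamo caches the Python-side overhead (pytree flattening, dispatch) per op, while backend="eager" avoids any Inductor codegen:

```python
import torch

# Hedged sketch: compile a single op with the "eager" backend so the
# traced graph just runs eagerly, but Dynamo's cache skips the Python
# overhead on subsequent calls. In the real DTensor case x and y would
# be DTensor subclasses.
@torch.compile(backend="eager")
def dtensor_add(x, y):
    return x + y

out = dtensor_add(torch.ones(2), torch.ones(2))
assert torch.equal(out, torch.full((2,), 2.0))
```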
- Data-dependent all2all. Will Feng got the all2all collective working in Inductor: https://github.com/pytorch/pytorch/pull/106655/. This is notable because the all2all collective has a data-dependent output shape. It looks like unbacked symints worked here!
Custom ops
- Custom ops. Richard tells me he is going to add a class-based API for custom ops, to make it easier to define everything all in one place. More on this soon I assume!
- Custom op testing. https://github.com/pytorch/pytorch/pull/106903 is here to make it easier to retrofit pre-existing test suites to also test for important operator properties.
Nested/jagged tensor
- SkolemSymNodeImpl. @jw3468 is going to make size() work on jagged tensor by introducing a new concept to SymInt provisionally called SkolemSymNodeImpl. This is a special SymInt which is not symbolic (it can show up in eager mode) but only compares equal to itself (aka is a skolem variable). We will use this to represent jagged dimensions. All jagged tensors that have the same offsets tensor get assigned the same skolem variable, if you have different offsets tensors you can’t add them together because their skolem variables don’t match. More details at https://docs.google.com/document/d/1e-R_818YA4VlVTlozu5eyzRIV6TzyvSPDm9DMEw_4xg/edit (Meta-only)
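The skolem-variable equality semantics can be sketched in pure Python (class and variable names here are hypothetical, for illustration only):

```python
# Hedged sketch: a skolem symbol is not symbolic in the usual sense --
# it only compares equal to itself (identity, not structural, equality).
class SkolemSym:
    def __init__(self, name):
        self.name = name

    def __eq__(self, other):
        return self is other  # equal only to the exact same object

    def __hash__(self):
        return id(self)

# All jagged tensors built from the same offsets tensor share one skolem.
jagged_dim_a = SkolemSym("j0")  # offsets tensor A
jagged_dim_b = SkolemSym("j1")  # offsets tensor B

assert jagged_dim_a == jagged_dim_a        # same offsets: shapes match
assert not (jagged_dim_a == jagged_dim_b)  # different offsets: no adding
```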
- SAM single batch, vmap for nested tensor. @jbschlosser has been working on integrating nested tensor with SAM, and one challenge with SAM is that it is written in a single-batch style, so the first problem is batchifying the model in the first place. Last week, an idea was to use vmap to automatically convert single-batch to multi-batch, and there is a PoC for this (WIP) PoC for vmap + NT by jbschlosser · Pull Request #106786 · pytorch/pytorch · GitHub but there are still a number of spots in SAM which are not so easy to vmap across https://docs.google.com/document/d/1_yiHOBbaI4qFWqBfebjWPHOhkxKW3v-CHu3lj5apv1Y/edit . Joel is going to try a few more days on this, and then pivot if it is still not looking promising.
Dynamo
- Pivot on per-NN module caching. @anijain2305 is working on having a separate code cache per NN module, but on Friday, with the help of @voz, we realized that the problem is actually separable into two pieces: (1) an enhanced cache size limit policy that knows about NN modules [RFC][dynamo] Separate cache sizes for nn module guard specialization by anijain2305 · Pull Request #107077 · pytorch/pytorch · GitHub and (2) improvements to cache lookup when there are a lot of cache entries (guard trees).
- Dynamo eager mode cond. Yidi Wu: to support cond in eager mode, we plan to torch.compile the entire cond operator, manufacturing fresh code objects to ensure that the caches don’t interfere with each other. https://docs.google.com/document/d/1esmHEa0fiktiSw1lvRsPmsbnTYxDSc0t3V9V5V0xK7I/edit#heading=h.pajqpbewbdg7 (Meta-only)
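The "manufacturing fresh code objects" trick can be illustrated in pure Python (helper names hypothetical): torch.compile keys its cache on the code object, so giving each call site a copy of the branch's code keeps the caches from interfering.

```python
import types

# Hedged sketch: clone a function with a brand-new code object, so a
# per-code-object compile cache treats each clone as a distinct entry.
def fresh_function(fn):
    code = fn.__code__.replace(co_name=fn.__code__.co_name + "_fresh")
    return types.FunctionType(
        code, fn.__globals__, fn.__name__ + "_fresh",
        fn.__defaults__, fn.__closure__,
    )

def true_branch(x):
    return x + 1

site_a = fresh_function(true_branch)
site_b = fresh_function(true_branch)
assert site_a(1) == 2 and site_b(1) == 2       # identical behavior
assert site_a.__code__ is not site_b.__code__  # distinct cache keys
```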
- Time to get rid of functional VariableTracker? VariableTracker in Dynamo is an immutable data structure: when a mutation happens, you allocate a fresh VariableTracker and then replace old VariableTrackers with the new one. This is because we have checkpointing functionality that is used to rewind to old VariableTrackers. However, this is a bit of a pain from the modeling side, as every Python data structure has to be reimplemented to have purely functional operations. An alternate design is to allow direct mutation of VariableTrackers. To checkpoint, we simply restart Dynamo analysis to “go back in time”, stopping execution at the point where we would have checkpointed (a deepcopy could also work, though I’m not a fan). Speculate subgraph would be implemented by simply denying all mutations or doing some crazy thermometer continuation thing. This would help make Dynamo more metacircular and reduce the work needed to support new container types, of which we often need to support a lot.
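A stripped-down sketch of the two designs (class names hypothetical, illustrating only the mutation pattern, not Dynamo itself):

```python
# Today: purely functional. Every "mutation" allocates a fresh tracker,
# so old states survive and can be used for checkpoint/rewind.
class ImmutableListTracker:
    def __init__(self, items=()):
        self.items = tuple(items)

    def append(self, item):
        return ImmutableListTracker(self.items + (item,))  # new object

# Proposal: mutate in place. Instead of rewinding old objects,
# "checkpointing" would restart analysis and stop at the save point.
class MutableListTracker:
    def __init__(self, items=()):
        self.items = list(items)

    def append(self, item):
        self.items.append(item)  # in place, no copy

a = ImmutableListTracker()
b = a.append(1)
assert a.items == () and b.items == (1,)  # old state preserved

m = MutableListTracker()
m.append(1)
assert m.items == [1]  # old state gone; only the current state exists
```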
Dynamic shapes
- expect_true irrefutable guards. I talked through this in the last 20min of composability sync. Check https://github.com/pytorch/pytorch/pull/106720 ; this is enough to make splits on unbacked SymInts work.
- Boolean masking, at last. @yanboliang is looking into a pre-autograd FX transform that replaces boolean mask updates with torch.where calls. One annoying detail: Dynamo will trace the boolean masks in the first place, but Inductor can’t deal with boolean masks if they can’t be eliminated. Our idea, in lieu of fixing Inductor to work with data-dependent shapes (which we are working on), is to attempt to eliminate all data-dependent ops in a pre-dispatch pass, and if that is not possible, restart Dynamo analysis saying “you need to graph break on this op next time.”
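A hedged illustration of the rewrite (my toy functions, not the actual FX transform): a boolean-mask update writes a data-dependent number of elements, while the torch.where form keeps every shape static.

```python
import torch

# Boolean-mask update: how many elements are written depends on the
# data in `mask`, which Inductor cannot see at compile time.
def masked_update(x, mask, value):
    y = x.clone()
    y[mask] = value  # data-dependent write
    return y

# Equivalent torch.where form: same result, static shapes throughout.
def where_update(x, mask, value):
    return torch.where(mask, torch.full_like(x, value), x)

x = torch.arange(6.0)
mask = x > 2
assert torch.equal(masked_update(x, mask, 9.0), where_update(x, mask, 9.0))
```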
- Notable fixes.
- SymInt’ify tile. This one was needed for the algorithmic-efficiency criteo1tb dlrm.
- [export] Refactor `constrain_as_value` and `constrain_as_size` from Tugsuu (was bounced, needs relanding)
- Notable new bugs.
Numbers
Training. 03414081ff Dashboard
- Some accuracy regressions. torchbench: hf_BigBird, vision_maskrcnn (flaky). It’s not clear what broke hf_BigBird; possibly the CUDA 12 upgrade. Need to investigate. AlbertForQuestionAnswering improved accuracy!
- The huge perf improvement across the board is thanks to Peter Bell’s work https://github.com/pytorch/pytorch/pull/106747 optimizing split reductions. This is not full runtime split reductions: instead, Peter uses whatever the hint was at compile time to plan the split reduction, and then we use that plan for all subsequent runs. This makes it more important to warm up Inductor with the “right” size hint to start; see also Padded tensor subclass · Issue #105325 · pytorch/pytorch · GitHub. Another user also complained about cases where we made suboptimal decisions when the first size we compiled with wasn’t representative.
Inference. Dashboard 03414081ff
- A lot of change on the last day; some improvements and some regressions (but mostly regressions). Maybe related to the CUDA 12 update; need to check. hf_BigBird is failing here too, and RobertaForQuestionAnswering is now failing accuracy.