State of PT2: Sep 15, 2023 edition
Previous update: State of symbolic shapes branch - #69 by ezyang
Executive summary
Dynamo
- KJT tracing updates: Tracing torchrec_dlrm with distributed sharding manages to get to the wait on the sharded embedding table lookups, at which point we are stuck on a complicated custom autograd function. Voz will take a look after finishing up intermediate backward hooks. In other news, the production folks on this workstream have finished getting rid of layer splitting for disables only, so they're now quite interested in compiling through as well. Lots of foundational work still needs to be done; we're hoping for Q4, but that is a very aggressive timeline! Meta only: https://docs.google.com/document/d/1VTGEh0MqadAsuRy0s5u39wQhNwMSVgCgYewivMcBbuU/edit#heading=h.jknt1mqmztph
- Animesh is going to be working on improving guard evaluation overhead, but there is still some disagreement among Voz, Jason and Edward about two major things: (1) should we port guards to C++ and do the rest of the scheme all in one go, and (2) should we stay in the “one compiled function, one check function” regime, or go straight to Voz’s one shared check function for everything? (A conceptual sketch of the two regimes follows this list.)
- Some folks from the Cinder team came to the PT2 weekly to talk about the challenges of running PyTorch with lazy imports. One big problem is the way Dynamo implements skipfiles: it traverses modules to find all identifiers attached to them, which plays poorly with lazy imports (see the skipfiles sketch after this list). Other trouble spots include decorators that put identifiers into global state, and our dispatcher registration mechanism.
- Horace is complaining that compile time is still fairly slow while he’s been working on llama. Profiling shows pytree is still a big culprit (20%); we also spend a lot of time doing map_aggregate in FX (10%). There was some discussion about reviving our fake tensor propagation rules caching idea (sketched after this list).
- Meta only: We’ve had a lot of PT2 related SEVs recently. There’s been some initial investigation classifying what happened https://docs.google.com/document/d/1bMoQEoBlZ4vwsUztH1dEeEETHuNJXEj1uItQH8Cd7jo/edit#heading=h.dnms1ad3rdvu and some suggestions on what to do next https://docs.google.com/document/d/1jhwgscFWe_G8JDSSRbFpkZKy02cw8wipyuNYaHXF4Rg/edit . A lot of the problem stems from insufficient / flaky downstream testing. Michael Suo is leading the charge here.
- Unrelatedly, there is also some external feedback (Meta only: https://docs.google.com/document/d/1Ss3idfGSTV4GWElOe6pgw9JkJ_b0p5AaEdkf-Yzil7M/edit) that PT2 speedups are promising but hard to get working reliably. A lot of the problems have to do with distributed, e.g., torch.compile/triton holding GIL during compilation and CompiledKernel call results in deadlocks during distributed training · Issue #109074 · pytorch/pytorch · GitHub and TorchInductor workers use "fork" which doesn't work in a multithreaded environment · Issue #108586 · pytorch/pytorch · GitHub
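To make the guard discussion concrete, here is a conceptual sketch of the status quo “one compiled function, one check function” regime; none of these names are Dynamo’s actual internals. The alternative replaces the linear scan below with one shared check function that dispatches over all guarded properties at once, which is also the kind of thing a C++ port would make cheap:

```python
# Conceptual sketch only -- not Dynamo's real guard machinery.

def make_check(expected_shape, expected_dtype):
    # One check function per compiled function: re-verify the
    # assumptions that compilation specialized on.
    def check(x):
        return x.shape == expected_shape and x.dtype == expected_dtype
    return check

cache = []  # list of (check_fn, compiled_fn) entries, scanned on every call

def call_compiled(x, compile_fn):
    for check, compiled in cache:   # linear scan: guard overhead grows
        if check(x):                # with the number of cache entries
            return compiled(x)
    compiled = compile_fn(x)
    cache.append((make_check(x.shape, x.dtype), compiled))
    return compiled(x)
```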
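Here is a simplified illustration of why the skipfiles identifier traversal defeats lazy imports; this is not the real skipfiles code, just the shape of the problem:

```python
import inspect

def collect_identifiers(mod):
    # Walk a module's namespace to collect identifiers to skip tracing.
    names = set()
    for name in dir(mod):
        obj = getattr(mod, name)  # under lazy imports, this getattr forces
                                  # any lazily-deferred submodule to load
                                  # eagerly, defeating the optimization
        if inspect.isfunction(obj) or inspect.ismodule(obj):
            names.add(f"{mod.__name__}.{name}")
    return names
```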
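And a rough sketch of the fake tensor propagation caching idea: key the propagated output metadata on the op plus each input’s metadata, so repeated propagation of the same signature skips re-running the meta kernel. All names here are hypothetical:

```python
_prop_cache = {}

def _meta_key(t):
    # The metadata that fake propagation both consumes and produces.
    return (tuple(t.shape), tuple(t.stride()), t.dtype, t.device)

def cached_fake_prop(op, *fake_tensors):
    key = (op, tuple(_meta_key(t) for t in fake_tensors))
    if key not in _prop_cache:
        # Cache miss: run the real propagation rule once and record
        # the output metadata; hits skip the meta kernel entirely.
        _prop_cache[key] = _meta_key(op(*fake_tensors))
    return _prop_cache[key]
```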
Inductor
- ABI-compatible AOTInductor made a bit of progress this week, with [inductor] Make AOTInductor runtime interface ABI compatible by desertfire · Pull Request #109450 · pytorch/pytorch · GitHub and https://github.com/pytorch/pytorch/pull/109391 by Bin Bao.
- Will Feng is looking into improved item() and tolist() support in Inductor: [WIP] Add .item() and .tolist() support in Dynamo/Inductor without graph break by yf225 · Pull Request #109262 · pytorch/pytorch · GitHub
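As a hedged illustration of what that support enables (this is not code from the PR; at the time, capturing scalar outputs was gated behind a Dynamo config flag and could still graph break in some cases):

```python
import torch
import torch._dynamo

# Opt in to capturing scalar outputs as symbolic values rather than
# graph breaking at .item().
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def f(x):
    n = x.sum().item()  # historically forced a graph break
    return x * n

print(f(torch.ones(4)))
```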
Composability sync hit a lot of topics this week: Composability meeting notes - Google Docs. Topics that weren’t otherwise covered in this doc:
- Elias told us about how SDPA pattern matches (and others; both inference and training patterns are supported) are now compiled ahead of time, making it a lot cheaper to support lots of patterns. We took advantage of this to add many more patterns matching other SDPA variants. Add Python serialization to Pattern Matcher patterns by eellison · Pull Request #108894 · pytorch/pytorch · GitHub
- Chien-Chin told us about the new PT2 DDP plans. We cannot directly trace DDP because it is implemented in C++, and we cannot easily port it to Python because the implementation is complicated by bucketing. So the idea is to implement a Python non-bucketed DDP, and rely on compile to optimize it away (see the DDP sketch after this list).
- Horace told us about developments in LLMs. One thing he wants is dequant primitives in PT2: a way to take int3/int4 packed values and unpack them into a larger tensor, with the idea that PT2 would compile away the memory traffic (a dequant sketch also follows this list). In general, he doesn’t think we should do this directly in PT, as there are so many quantization formats.
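To make the DDP idea concrete, here is a minimal sketch of a Python non-bucketed DDP; the class name and structure are hypothetical, not the actual prototype:

```python
import torch
import torch.distributed as dist

class NaiveDDP(torch.nn.Module):
    # Hypothetical sketch: one all_reduce per parameter, no bucketing.
    # The bet is that torch.compile can trace this and recover bucketing
    # (and comm/compute overlap) as a compiler optimization.
    def __init__(self, module):
        super().__init__()
        self.module = module
        # Start all ranks from rank 0's parameters.
        for p in self.module.parameters():
            dist.broadcast(p.detach(), src=0)

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def sync_grads(self):
        # Called after backward(): average every gradient individually.
        world_size = dist.get_world_size()
        for p in self.module.parameters():
            if p.grad is not None:
                dist.all_reduce(p.grad)
                p.grad /= world_size
```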
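And a sketch of the kind of dequant primitive Horace means, for int4 values packed two to a byte; the packing convention and function name here are assumptions for illustration:

```python
import torch

def unpack_int4(packed: torch.Tensor) -> torch.Tensor:
    # packed: uint8 tensor where each byte holds two 4-bit values
    # (low nibble first). Under PT2, the hope is that this unpack fuses
    # into the consuming kernel so the full-size int8 tensor never
    # materializes in memory.
    low = packed & 0x0F
    high = (packed >> 4) & 0x0F
    # Restore element order: [lo0, hi0, lo1, hi1, ...]
    out = torch.stack([low, high], dim=-1).flatten(-2)
    return out.to(torch.int8) - 8  # map [0, 15] to signed [-8, 7]

vals = unpack_int4(torch.tensor([0x21, 0x43], dtype=torch.uint8))
# tensor([-7, -6, -5, -4], dtype=torch.int8)
```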
Dynamic shapes
- Last week I mentioned opcheck testing is usable, but Richard Zou is still evolving it based on user feedback. One recent change is to put the xfails into a JSON file so it can easily be updated automatically. However, there are still complaints from folks that it’s too hard to understand what went wrong when a test crashes. Richard is going to investigate a two-stage process, whereby we separate generating test inputs from actually running the tests. To ensure the generated test inputs are kept up to date, we only need a single new test which runs all of the tests in the test file in one go and cross-references which tests were exercised against what we have recorded.
- Horace wants a version of Tensor where some of the sizes are stored on device. This would allow you to perform a data-dependent operation without synchronizing; you would still save on memory traffic because kernels would mask out memory loads when they go out of bounds of the dynamic shape (a sketch follows this list). In some sense, this is a specialization of jagged tensor where everything in the jagged dimension has the same size.
- Notable bug fixes:
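To illustrate the sizes-on-device idea for a 1-D buffer (everything below is a sketch, not a real PyTorch API):

```python
import torch

def masked_relu(buf: torch.Tensor, length: torch.Tensor) -> torch.Tensor:
    # buf: 1-D buffer padded to a static max size.
    # length: 0-dim integer tensor living on the same device; we never
    # call length.item(), which would force a host-device sync.
    idx = torch.arange(buf.numel(), device=buf.device)
    mask = idx < length  # comparison stays on device
    # The out-of-bounds tail is masked to zero instead of sliced off,
    # so the output shape stays static while the "real" size is dynamic.
    return torch.where(mask, torch.relu(buf), buf.new_zeros(()))
```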
Numbers
This is nearly a month’s worth of numbers!
Training. 34ddf08f27 dashboard
- mobilevit_s in timm models no longer runs; it looks like it’s due to flash attention. Elias will fix it along with the pattern matcher PRs.
- Performance regression in torchbench between 0cfc5899f9bade72c7e18666e2006b003b5848bc..3a79621c9dce17f77fbddc06aab21f6bc477f313. Testing: inductor-A100-perf-nightly · pytorch/pytorch@ad74286 · GitHub and inductor-A100-perf-nightly · pytorch/pytorch@4948bbc · GitHub
- Hugging Face improvement from flash attention v2 landing again
- Meaningful compile time regression everywhere, with no obvious culprit. Example model:
Inference. 34ddf08f27 dashboard
- A lot of torchbench improvements: detectron2_fcos_r_50_fpn, doctr_reco_predictor, drq, llama, pyhpc_turbulent_kinetic_energy all now pass accuracy.
- cudagraphs + freezing accuracy improvements in timm models, likely from some major bugfixes for freezing
- pytorch_stargan had a huge perf improvement between c2ac0da445cfe3d848342926f9cd4422bd35bfe2..781b7ebe912ec24cbd917cd548b748b1650ab6a2
- HuggingFace regression due to pin update: Problems hit when upgrading the version of HF used in CI · Issue #108145 · pytorch/pytorch · GitHub
- Fairly large AOTInductor regression due to ABI changes.