State of PT2: Sep 8, 2023 edition
Previous update: State of symbolic shapes branch - #67 by ezyang
We were on break for two weeks because I went on vacation, and I didn’t have time to do a report before/after vacation lol.
Executive summary
- PyTorch 2.1 branch cut. The cut was three weeks ago (right when I went on vacation lol) and we’re reaching the end of the cherry-pick window. Track ongoing cherry picks at: https://github.com/pytorch/pytorch/issues/108055
- Blueberries offsite was this week! The blueberries workstream is focused on accelerating SOTA transformer models using PT2, quantization, sparsity and other techniques. Some highlights: MFU is coming to the benchmark suite, some direct improvements to important models, int8 dynamic quantization with tensor subclasses. Many of these are not published yet; keep your eyes peeled at PTC!
- PyTorch conference registration filling up fast. If you want to go and haven’t registered yet, you should register at PyTorch Conference | Linux Foundation Events
Composability sync
- Aug 24 https://www.youtube.com/watch?v=H6EUSsvDmbw - we spent time going over recent KJT progress (to be reduxed below), and Voz reported progress on tracing FSDP with hooks (also to be reduxed below)
- Aug 31 - not livestreamed publicly, I wasn’t there, but apparently there was some discussion about streams for tracing FSDP (no minutes alas)
Distributed and PT2
- Tracing FSDP with Voz is deep in the weeds on backwards hooks support. We are attempting to implement hooks in a way that doesn’t require consolidated forward-backwards. The general strategy is: (1) have Dynamo emit graphs that have register_hook calls on intermediates (register_hook calls on inputs must not go in the graph; they have to happen as part of residuals), (2) write these register_hook calls in such a way that when AOTAutograd runs, the actual hook code (which is arbitrary Python code and is not safe to run in tracing) is not run, but instead we run a meta function (which performs any needed metadata mutation) and then insert a call function to the original Python function (which will show up in backwards), (3) have compiled backwards take care of compiling this call function in the end. (The intermediates-vs-inputs distinction is sketched in the first code example after this list.)
- Per-parameter FSDP is looking pretty legit. Andrew Gu has been looking at the performance of per-parameter sharding (where parameters managed by FSDP aren’t shoved into a single flat buffer) and has found that we only really pay a penalty of 5% with per-parameter sharding but get better memory usage. Meta only link.
- DDP optimizer brittleness. We currently support pipelining DDP code with PT2 by manually splitting graphs into multiple AOTAutograd functions so that backwards isn’t run too soon. The code here is kind of janky: I ran into two separate bugs that only happened when optimize_ddp was on: [DDP PT2] TypeError: convert_frame_assert.<locals>._convert_frame_assert() missing 2 required positional arguments: 'hooks' and 'frame_state' · Issue #107637 · pytorch/pytorch · GitHub and [optimize_ddp] moco - NameError: name 's2' is not defined · Issue #108877 · pytorch/pytorch · GitHub. Pritam has also been complaining about the graph break strategy: torch.compile graph breaks should be independent of DDP buckets · Issue #108966 · pytorch/pytorch · GitHub. Will tells me that Chien-Chin is working on some new DDP strategy, but it appears to be centered around starting with a non-parallelized graph. Hopefully we can present it at composability this week. Note that DDP cannot be easily traced as it is implemented in C++. (The optimize_ddp toggle is sketched in the second code example after this list.)
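To make the intermediates-vs-inputs distinction above concrete, here is a minimal sketch of the shape of user code the hooks work has to handle (not the tracing machinery itself); the function and hooks are made up for illustration, and whether it traces without a graph break depends on how far the register_hook support has landed:

```python
import torch

def f(x):
    y = x * 2  # intermediate created inside the graph
    # Hook on an intermediate: the register_hook call can be emitted in the
    # Dynamo graph; its body is only recorded for backwards, not executed
    # while tracing.
    y.register_hook(lambda grad: grad * 0.5)
    return y.sum()

x = torch.randn(4, requires_grad=True)
# Hook on an *input* stays outside the compiled region (a residual).
x.register_hook(lambda grad: grad + 1)

torch.compile(f)(x).backward()
print(x.grad)  # y-hook halves the incoming grad, *2 from y = 2*x, +1 from the input hook -> 2.0
```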
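And since optimize_ddp keeps coming up, a minimal single-process sketch of the toggle (the gloo process group setup is purely for illustration); flipping the config off is a quick way to tell whether a failure is specific to the graph-split path:

```python
import os
import torch
import torch._dynamo
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process process group just so DDP can be constructed for the demo.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(torch.nn.Sequential(
    torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 1)))

# When enabled (the default), Dynamo splits the graph along DDP bucket
# boundaries so allreduces can overlap with backward; set to False to rule
# the split logic in or out when debugging.
torch._dynamo.config.optimize_ddp = True

loss = torch.compile(model)(torch.randn(4, 8)).sum()
loss.backward()

dist.destroy_process_group()
```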
Dynamic shapes
- Avik is proposing a change to the dynamic_dim API currently used to express dynamism in the export API. Instead, they will adopt a Python typing generics style solution, where you bind generic variables for dynamic dimensions, e.g. batch = Dim("batch", max=64), and then use them to annotate types on input tensors, e.g. x: TensorType[batch, K, N]. This is very reminiscent of [discussion] Expressing tensor dimension semantics / constraints through typing / constraints blocks. Constraints block could be scripted/traced and help for tracing/script execution and codegen · Issue #40373 · pytorch/pytorch · GitHub Meta only: https://docs.google.com/presentation/d/168U7XK72C_WSsZpGESP6Cho9udh193fi0gfjxCNcJ4E/edit (A sketch of the Dim-binding style follows after this list.)
- This is not really PT2 related, but there’s an interesting set of posts about the future of Sympy circulating around: Towards a new SymPy: part 1 - Outline — blog documentation Funnily enough, the part of Sympy which Oscar calls out as “overused” (the symbolic expression system) is precisely the part we actually care about. Maybe a good reason for us to figure out some way to not use this part (me, personally, I want a compact representation and hash consing).
- I discussed this in a bit of detail in composability three weeks ago, but work on supporting fine-grained KJTs is going very well. This week, I worked with Michael Suo to push APS sparse arch tracing all the way through; it made it end-to-end (though it then failed on a seemingly unrelated problem). So fine-grained tracing definitely seems like it will work, even if we generate tons of crappy guards. My plan for next week is to make a serious attempt at tracing multi-node model parallel sharded torchrec_dlrm.
- This week, when I had spare time at the offsites, I worked on fixing a few one-off bugs. There were several that were pretty easy to nail (a torch._check_is_size usage sketch follows after this list):
- Don’t fastpath conj copy when conj/neg bit mismatch
- Fix setitem with SymInt
- Add support for symbolic repeat_interleave
- Add torch._check_is_size
- Meta implementation for nms by ezyang · Pull Request #7944 · pytorch/vision · GitHub
- Avoid creating a tensor of shape when not tracing by ezyang · Pull Request #7942 · pytorch/vision · GitHub
- While working on the meta implementation for nms, I played around with Richard Zou’s opcheck testing: https://github.com/pytorch/pytorch/pull/106903 It still needs some improvements ([RFC] Run only one pytest parametrization when generating optest by ezyang · Pull Request #108936 · pytorch/pytorch · GitHub, Use a bit-identical test for mutation test by ezyang · Pull Request #108935 · pytorch/pytorch · GitHub, optests improvements based on torchvision usage on nms by ezyang · Pull Request #108929 · pytorch/pytorch · GitHub), but I was able to get it to work end-to-end. Seems pretty promising!
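To make the Dim-binding style above concrete, here is a hedged sketch using the torch.export.Dim API that later shipped, which follows the same "bind a named dimension, then attach it to inputs" pattern; the TensorType[batch, K, N] annotation syntax itself is still just a proposal and is not shown:

```python
import torch
from torch.export import Dim, export

class MatMul(torch.nn.Module):
    def forward(self, x, w):
        return x @ w

# Bind a named symbolic dimension with an upper bound, in the spirit of
# batch = Dim("batch", max=64) from the proposal.
batch = Dim("batch", max=64)

x = torch.randn(8, 16)
w = torch.randn(16, 32)

# Attach the symbol to dim 0 of x; w stays fully static.
ep = export(MatMul(), (x, w), dynamic_shapes={"x": {0: batch}, "w": None})
print(ep)  # the exported graph carries a symbolic batch size bounded by 64
```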
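And for the torch._check_is_size item in the bug-fix list above, a minimal sketch of the data-dependent pattern it exists for, assuming capture_scalar_outputs is enabled so .item() stays in the graph:

```python
import torch
import torch._dynamo

# Needed so .item() becomes an unbacked SymInt instead of a graph break.
torch._dynamo.config.capture_scalar_outputs = True

@torch.compile(fullgraph=True)
def make_buffer(length):
    n = length.item()         # data-dependent value, unknown while tracing
    torch._check_is_size(n)   # tell the compiler n is a valid size (n >= 0)
    return torch.zeros(n)     # n can now be used as a shape

print(make_buffer(torch.tensor(5)).shape)  # torch.Size([5])
```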
Inductor fun
- Peter Bell is very close to landing Inductor IR support for scan https://github.com/pytorch/pytorch/pull/106581 which allows for native cumsum/cumprod support. Now all we need is for someone to add a higher order op that feeds into it and we will have torch.scan! (Reference scan semantics are sketched after this list.)
- Someone should add a “realize” operator to PT2, which would force materializing a tensor rather than allowing fusions across it. Christian Puhrsch would find this useful for ensuring epilogue fusion occurs on int8 mm (today, regular fusion causes the pointwise operation to get fused into a later reduction, instead of fusing the pointwise into the matmul).
- ABI compatibility for AOT Inductor continues to move forward, slowly, but one agreement is that we’re probably going to have only the ABI-compatible codegen in OSS as well.
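For reference on what a scan higher order op would compute: torch.scan does not exist yet, so this is just an illustrative pure-PyTorch loop, with the combine function assumed to be associative (add recovers cumsum, mul recovers cumprod):

```python
import torch

def scan_reference(combine, x, dim=0):
    # Inclusive scan: out[0] = x[0], out[i] = combine(out[i-1], x[i]).
    # An Inductor-lowered higher order op would compute this without the
    # Python loop.
    slices = x.unbind(dim)
    acc = [slices[0]]
    for s in slices[1:]:
        acc.append(combine(acc[-1], s))
    return torch.stack(acc, dim)

x = torch.randn(6, 4)
assert torch.allclose(scan_reference(torch.add, x), torch.cumsum(x, 0))
assert torch.allclose(scan_reference(torch.mul, x), torch.cumprod(x, 0))
```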
Performance
- Flash Attention 2 is close to landing: Flash Attention v2 by drisspg · Pull Request #105602 · pytorch/pytorch · GitHub but it is currently stuck because it takes a lot of memory to compile, causing CI problems.
- In the PT2 weekly meeting, we discussed H100 benchmarking. There are a lot of interlocking parts to this: we need to upgrade Triton to get their H100 improvements, and not everyone on the PyTorch team has access to an H100. Still looking for someone to sign up for this.
- CUDA graph updates are a thing now: 1. Introduction — CUDA C Programming Guide There may be some opportunities here. Elias says: “It mostly helps with eliding input copies. For the most part, removing input copies only really matters when you torch.compile only part of your model and leave the rest of the model in eager. This use case is pretty unlikely to train well anyway since you’ll still need to bifurcate the memory pool.” However, personally, I also think CUDA graph updates could be pretty useful for allowing you to deallocate the pool of memory needed by a CUDA graph, only reallocating it when it’s time to run the CUDA graph again.
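For context on the input-copy point, a minimal torch.cuda.CUDAGraph capture/replay sketch (requires a CUDA device): the replay runs against fixed memory addresses, which is why inputs get copied into static buffers and why the graph's memory pool stays allocated between replays; that is the gap graph updates and pool deallocation could help close.

```python
import torch

assert torch.cuda.is_available()

# Static input buffer: the captured graph is tied to this tensor's address.
static_x = torch.randn(8, device="cuda")

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_y = static_x * 2 + 1  # work recorded into the graph

# Running on new data means copying into the static buffer first; this is
# the copy that CUDA graph updates could potentially elide.
new_input = torch.randn(8, device="cuda")
static_x.copy_(new_input)
g.replay()

print(torch.allclose(static_y, new_input * 2 + 1))
```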
Dynamo
- There was a pretty notable pytree API BC breakage which caused some internal problems: Serialize pytree to json string by angelayi · Pull Request #106116 · pytorch/pytorch · GitHub
- Some big refactors that are in progress: refactoring skipfiles / allowed functions (talk to Yanbo), refactoring guard trees (talk to Animesh)
- A bunch of new contributors being onboarded to Dynamo: Quansight is working more on Dynamo issues, and Jack Cao from PyTorch XLA is looking to help us with consolidated forwards-backwards-optimizer support in Dynamo as it is essential for XLA Dynamo perf.
Numbers are on break this week because the A100 runners are down: apt-get install nvidia-docker2, Could not get lock /var/lib/dpkg/lock-frontend · Issue #108862 · pytorch/pytorch · GitHub