State of symbolic shapes branch

State of PT2: Nov 3, 2023 edition

Previous update: State of symbolic shapes branch - #71 by ezyang

Sorry about the month’s delay! Between more vacation and PTC, there wasn’t much time to do a writeup over the weekend.

Executive summary

Big tickets

Dynamo

Core libraries

  • Ying Liu has been working on a tensor subclass for async execution; we discussed it in composability sync. The idea is that you can trigger an operation (typically communication) on a side stream, along with some follow-on operations, without having to literally move those follow-on operations to the point where a sync happens. This also means that code in torchrec that currently has to be manually written as a pair of custom autograd functions for req/wait can be written in an intuitive, autograd style (a rough sketch of the idea is after this list). We have a version that does this manually with callbacks (kernels are only queued onto the stream at some known later point in time), and Ying is working on another version that uses streams only. One interesting thing we noticed is that when you schedule an allreduce first in forwards, backwards will naturally schedule it last, but you actually want the allreduce to happen ASAP! @albanD suggested we may be able to add an API to modify the priority order of autograd backwards, which could be useful.
  • There will be a new repo https://github.com/pytorch-labs/ao for some of the new quantization schemes we’re working on. We discussed this in composability sync.
  • I did a long-overdue update to the record_stream docs at Add a note about performant record_stream use. by ezyang · Pull Request #112526 · pytorch/pytorch · GitHub, after having some more discussions about it with @eellison, who was trying to get CUDA graph trees to work with record_stream. (There’s a short refresher on the record_stream pattern after this list.)
  • We’ve been talking about this with Vincent for a while, but there is now a proposed PR to add TensorDict to PyTorch core; check it out: [RFC] Tensordict integration by vmoens · Pull Request #112441 · pytorch/pytorch · GitHub. (A quick taste of the API is below.)
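To make the lazy-sync idea above concrete, here is a minimal sketch, not Ying’s actual implementation: `AsyncTensor` and `async_all_reduce` are hypothetical names, and the real version has to handle a lot more (autograd, multiple devices, etc.). The trick is that the subclass carries a CUDA event recorded on the producer’s side stream, and consumers only wait on that event at the point of first real use:

```python
import torch
from torch.utils._pytree import tree_map

class AsyncTensor(torch.Tensor):
    """Tensor produced by a kernel on a side stream; syncs lazily on first use."""

    @staticmethod
    def __new__(cls, elem, ready_event):
        t = torch.Tensor._make_subclass(cls, elem)
        t._ready_event = ready_event  # recorded on the side stream after the op
        return t

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(x):
            if isinstance(x, AsyncTensor):
                # First real use: make the consumer's stream wait on the
                # producer's event instead of blocking the host.
                x._ready_event.wait(torch.cuda.current_stream())
                return x.as_subclass(torch.Tensor)
            return x

        args = tree_map(unwrap, args)
        kwargs = tree_map(unwrap, kwargs or {})
        return func(*args, **kwargs)

def async_all_reduce(t, side_stream):
    # Hypothetical helper: kick off an all_reduce on a side stream; follow-on
    # ops on the result can then be written in ordinary, autograd-friendly style.
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        torch.distributed.all_reduce(t)  # assumes a process group is initialized
        ev = torch.cuda.Event()
        ev.record(side_stream)
    t.record_stream(side_stream)  # t's memory is now in use on side_stream
    return AsyncTensor(t, ev)
```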
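As a refresher on what record_stream is for (this is the textbook pattern, not a summary of the new note): when memory allocated while one stream was current is consumed by a kernel on another stream, the caching allocator needs to be told, or it may reclaim and reuse that memory too early once the Python reference dies.

```python
import torch

side = torch.cuda.Stream()

x = torch.randn(1 << 20, device="cuda")      # allocated on the default stream
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    y = x.relu()                             # side-stream kernel reads x
x.record_stream(side)  # tell the allocator x is in use on `side`, so its
                       # memory isn't handed back to the default stream while
                       # the relu may still be running
torch.cuda.current_stream().wait_stream(side)
```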
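And if you haven’t seen TensorDict before, here is a quick taste of the API from the standalone tensordict package that the PR proposes upstreaming (shapes here are illustrative):

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {"obs": torch.randn(4, 84, 84), "reward": torch.zeros(4, 1)},
    batch_size=[4],  # leading batch dims shared by every entry
)
td = td.to("cpu")        # device/dtype ops apply to all entries at once
row = td[0]              # so does indexing: row has batch_size []
print(row["obs"].shape)  # torch.Size([84, 84])
```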

Dynamic shapes

Numbers

Training. 64f326097b dashboard

  • The TIMM improvement is from the channels-last optimization (generic illustration of the technique below)
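For context, channels-last refers to the NHWC memory format; this is a generic illustration of the technique, not inductor’s internal layout pass:

```python
import torch

# Channels-last (NHWC) layout lets cudnn pick faster tensor-core kernels for
# many convolutions; the compiler can apply this layout change automatically.
model = torch.nn.Conv2d(3, 64, 3).cuda().to(memory_format=torch.channels_last)
x = torch.randn(16, 3, 224, 224, device="cuda").to(memory_format=torch.channels_last)
y = model(x)  # output stays channels-last
```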

Inference. 64f326097b dashboard

  • 3% HF improvement from concat codegen on inference