State of symbolic shapes branch: Nov 20 edition
The symbolic-shapes branch (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub) is a long-running branch containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch - #18 by ezyang
Commit ID at time of writing: 41c314473272b2622e20c261f19549b4bd3f1d8f
Executive summary
This week, we made two major merges into the branch: (1) @bdhirsh’s AOTAutograd input mutations PR and (2) voz’s fake tensor plumbing to AOTAutograd PR. While not all of the bugs from these two PRs have been totally resolved (and in particular, bugs in input mutations are suppressing the pass rate on the branch), the features work well enough that we only suffered a minor regression in aot_eager pass rate. Additionally, we have BERT_pytorch working end-to-end for BOTH training and inductor, meaning that on branch (but not master) we have achieved our goal for inductor.
- PSA: We now have int64_t, SymInt overloads for all binary operators in C++, so you no longer have to rewrite 2 + symint into symint + 2; both work now.
- PSA: DebugInterpreter is now actually verifying stride equality; at time of writing, this is revealing seven models which have incorrect sizes/strides. These bugs are ripe for picking!
- We have more clarity about what is missing to properly integrate inductor with dynamic shapes (instead of all of the hacks that are currently on branch). A big question mark is whether or not ShapeEnv should be shared between AOTAutograd’s forward and backward; when the ShapeEnv is shared, Inductor cannot necessarily infer all of the shape variables, because a shape variable may only occur inside a more complex shape expression on an input (e.g., s0 * 2); a small sketch below illustrates the inference problem. @Chillee is pivoting back to working on this, after having spent time working on channels last in the general workstream. There is also some discussion about forward-backwards staging confusion at [aot_eager] [hf_Longformer] Cannot view a tensor with shape · Issue #1888 · pytorch/torchdynamo · GitHub
- Some progress was made designing a more precise warning mechanism for Copy-on-write tensors for eager PyTorch - Google Docs (see bottom of doc), though there was no implementation progress.
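To make the inference problem concrete, here is a small standalone sketch using sympy directly (not Inductor's actual machinery): if the only place s0 appears on an input is inside the expression 2*s0, the backend has to invert that expression to recover s0 from the concrete runtime size, which is only possible for sufficiently simple expressions.

    import sympy

    s0 = sympy.Symbol("s0", positive=True, integer=True)
    input_dim = 2 * s0        # the size expression recorded on the input
    runtime_size = 128        # the concrete size observed at runtime

    # Solvable here (s0 = 64), but arbitrary expressions are not always invertible.
    print(sympy.solve(sympy.Eq(input_dim, runtime_size), s0))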
- Model training status on symbolic-shapes. See also Symbolic shapes work items tracker - Google Sheets
- aot_eager: 128 out of 163 (-7 WoW) logs csv. The regression is primarily due to AOTAutograd input mutation support hitting our branch with two outstanding bugs: accidentally using SymInt sizes to perform manipulations on real tensors, and needing to regenerate views of intermediates rather than returning them directly.
- inductor: 36 out of 163 (+12 WoW) logs csv; notably, BERT_pytorch is passing with inductor now, and a spot check of the generated Triton code suggests the output is dynamic: BERT_pytorch dynamic Triton · GitHub (a small sketch of driving inductor with the new dynamic kwarg appears below).
- End-to-end BERT_pytorch working with dynamic shapes and inductor by ezyang · Pull Request #89313 · pytorch/pytorch · GitHub demonstrates the minimal set of changes from our branch necessary to get BERT_pytorch working end-to-end on master. The PR doesn’t pass CI as is; our current thinking is to do some necessary refactors first which will simplify the final form of this PR.
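For readers who want to poke at this themselves, here is a minimal, hypothetical sketch (a toy module, not the benchmark harness) of compiling through inductor with the dynamic kwarg that landed on master this week; exact flags and behavior on the branch may differ.

    import torch
    import torch._dynamo as dynamo

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.nn.functional.relu(x @ x.transpose(-1, -2))

    # dynamic=True asks dynamo/inductor to treat input sizes symbolically.
    opt_model = dynamo.optimize("inductor", dynamic=True)(Toy())
    print(opt_model(torch.randn(4, 8)).shape)
    print(opt_model(torch.randn(7, 8)).shape)  # ideally served by the same dynamic kernel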
- OpInfo tests on symbolic shapes.
- pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive: 494 passed (+106 WoW), 229 failed (-105 WoW), 512 skipped (+11 WoW) logs csv. This improvement is partially due to Towards unifying symbolic and non symbolic fake tensor, which allows us to attempt C++ meta functions even if they’re not known to be SymInt-aware; it turns out many of them still work correctly anyway. This comes at the cost of worse error messages when SymInt is not supported.
- pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive: 249 passed (-6 WoW), 221 failed (+8 WoW), 133 skipped (+4 WoW) logs csv. The regression here is from the outstanding bugs in the AOTAutograd input mutation changes on the branch; the numbers should improve once that regression is fixed.
- Previous branch diff: 82 files changed, 2236 insertions(+), 747 deletions(-)
- Current branch diff: 68 files changed, 2612 insertions(+), 554 deletions(-)
Notable bugs
- DebugInterpreter actually works now: Fix cat striding in PrimTorch was identified by looking at a model failing a stride equality assert in DebugInterpreter. I only figured out that DebugInterpreter was not correctly checking strides when I fixed a separate stride problem that was causing an accuracy failure and then tried to work out why DebugInterpreter hadn’t caught it earlier; it turns out the stride matching function returned a tuple, and a non-empty tuple is always truthy, so the assert never fired. A rough sketch of the kind of check DebugInterpreter performs is below.
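For intuition, here is a hypothetical sketch (the class name and details are made up; this is not the actual DebugInterpreter) of an FX Interpreter that cross-checks real outputs against the fake tensors recorded in node.meta["val"]:

    import torch
    from torch.fx import Interpreter

    class StrideCheckingInterpreter(Interpreter):
        def run_node(self, n):
            result = super().run_node(n)
            fake = n.meta.get("val")
            if isinstance(fake, torch.Tensor) and isinstance(result, torch.Tensor):
                # Assert on booleans, not on a returned tuple: a non-empty tuple
                # is always truthy, which is exactly the bug described above.
                assert tuple(result.shape) == tuple(fake.shape), (n, result.shape, fake.shape)
                assert tuple(result.stride()) == tuple(fake.stride()), (n, result.stride(), fake.stride())
            return result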
- Testing is important: unrelated refactoring on master broke our previously SymInt-ified upsample_nearest2d, which had to be fixed again in SymIntify upsample_nearest2d again after composite-ification. If we had appropriate testing on master, the breakage probably would have been caught at regression time.
- Simplify cudnn rnn support greatly untangles a complicated situation with cudnn SymInt support. The root cause of the problem is that cudnn_rnn is a composite function that calls a black-box cudnn function to figure out what the output size of the workspace should be. This cannot be SymInt’ified, but in the original attempt to make this all work, our poor intrepid engineer tried SymInt’ifying all of the Descriptor code responsible for driving calls to the cudnn API. This patch undoes all of that in favor of an earlier guard. However, this is not enough to fix tts_angular, as it eventually calls _cudnn_rnn, which requires a meta implementation, but that meta implementation requires querying cudnn APIs. Handling this black box correctly may require something similar to how we plan to handle missing meta functions for custom functions: guess relations between inputs and outputs, and verify the black box function acts consistently for concrete values we haven’t seen before (a toy sketch of this guess-and-verify idea is below). Or we can just guard (but guarding requires us to be able to call into the cudnn API from our Python implementation).
- Set functionalization storage to size 0 for sparse tensor wrapping is a lucky hack that works around a very delicate situation, which is that functionalization really doesn’t know how to handle sparse tensors. Fortunately, we don’t actually need to functionalize operations on sparse tensors, so hacking an empty storage “works”, but it is likely to fail if someone stress tests us on sparse support.
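Returning to the cudnn_rnn workspace question, here is a toy, hypothetical sketch of the guess-and-verify idea (all names are made up; the real cudnn query is a C API, not a Python function): probe the black box at a couple of concrete sizes, guess an affine relation, and re-verify the guess whenever a new concrete value shows up.

    # Stand-in for the black-box size query (e.g., the cudnn workspace size).
    def workspace_size(n):
        return 4 * n + 128

    def guess_affine_relation(black_box, probes=(16, 32)):
        n0, n1 = probes
        y0, y1 = black_box(n0), black_box(n1)
        a = (y1 - y0) // (n1 - n0)
        b = y0 - a * n0
        return lambda n: a * n + b

    def run_with_verification(black_box, guess, n):
        expected = guess(n)
        actual = black_box(n)
        assert actual == expected, f"black box inconsistent at n={n}: {actual} != {expected}"
        return actual

    guess = guess_affine_relation(workspace_size)
    print(run_with_verification(workspace_size, guess, 64))  # 384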
- first draft of input mutation handling for aot autograd is affected by two bugs right now.
- “RuntimeError: Expected !is_symbolic() to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)” The problem here is that the patch accidentally uses fake tensor sizes (which are SymInts) to perform manipulations on real tensors, propagating SymInts into real tensors (a big no-no! And a sign of missing asserts; we should add some.) At time of writing, this is close to fixed but not quite, using the idea of having the forward return sizes/strides/storage_offset directly as SymInt outputs of the forward. Most of the tests pass on the branch, but while hf_Longformer and tacotron2 no longer hit the error, they’re both OOM’ing even with a batch size of 1… this warrants further investigation. A suggestion is to use zdevito’s memory profiling tools to diagnose: Debugging PyTorch memory use with snapshots | Zach’s Blog
- “RuntimeError: Output 0 of AsStridedBackward0 is a view of a view which was created in no_grad mode and is being modified inplace with grad mode enabled.” This is the case where our graph should return a graph intermediate and manually regenerate the view(s) of the intermediate in the epilogue (roughly sketched below). @bdhirsh hasn’t started working on fixing it yet.
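A very rough, hypothetical sketch of that idea (toy functions, not the actual AOTAutograd code): rather than returning a view created inside the no_grad compiled region, return the base intermediate and regenerate the view in an epilogue that runs under normal autograd.

    import torch

    def compiled_region(x):
        with torch.no_grad():
            base = x * 2      # graph intermediate produced under no_grad
        return base           # return the base, not base.view(...)

    def epilogue(base):
        return base.view(-1)  # regenerate the view outside the no_grad region

    out = epilogue(compiled_region(torch.randn(2, 3)))
    print(out.shape)          # torch.Size([6])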
- Detach fake tensors into val, so they aren’t affected by metadata mut… was briefly discussed in composability; it has the potential to break users who are relying on accurate autograd metadata or tensor identity on meta[‘val’], but it turns out this doesn’t work anyway.
- [HACK] don’t graph break on math.sqrt: we noticed that BERT_pytorch was repeatedly recompiling on every iteration. Voz tracked this down to the id of an NN module not staying stable over iterations; but it turned out that this was because we were graph breaking too much, due to an overly aggressive unsupported() call in Dynamo. Reducing the graph breaking fixed the problem!
- Disable explain for now: we wanted to collect graph break stats to make it easier to tell if the BERT_pytorch situation was happening elsewhere. Dynamo’s preexisting explain feature makes this possible (see the sketch below). However, when we turned it on, many huggingface models started failing with “AttributeError: ‘str’ object has no attribute ‘size’”. This has not been investigated yet.
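For reference, here is a rough sketch of invoking explain on a toy function; the exact signature and return value of torch._dynamo.explain have changed over time, so treat the details as illustrative.

    import math
    import torch
    import torch._dynamo as dynamo

    def f(x):
        # math.sqrt on a Python number is exactly the kind of construct that can
        # trigger graph breaks depending on how it is traced.
        return x * math.sqrt(x.shape[0])

    # explain reports captured graphs, graph break counts, and break reasons.
    print(dynamo.explain(f, torch.randn(8)))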
- Patch out suspicious get_item_dyn logic for jx_nest_base is an example of very long code that is also wrong. Long code for a special case is a smell; there’s probably a conceptually simpler way to do it!
- Dynamo was affected by a lot of “missing symbol s0” assertion errors. This assertion tests whether Dynamo knows how to compute a given SymInt from the set of tensors being guarded on. These missing symbols came from a variety of places: Dynamo’s special handling of NN parameters/buffers ([dynamo] [dynamic shapes] Fix register_buffer (and all module associa…), as well as dependence on _base tensors due to functionalization (Record TensorReference for ._base tensors, but don’t install guards o…). The latter is fixed in a tricky way: we don’t actually ever guard on the base of a tensor, because according to voz this caused a lot of recompilation. More investigation necessary…
- Remove bad string parsing assert is an example of why it is bad to do string matching on strings that represent structured data types (like language expressions); a toy illustration is below.
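A toy illustration of the pitfall (made-up example, not the actual Dynamo code): a substring match on the printed form of an expression can report false positives, whereas querying the structured object cannot.

    import sympy

    s1, s10 = sympy.symbols("s1 s10")
    expr = s10 + 1

    print("s1" in str(expr))        # True  -- false positive from matching "s10 + 1"
    print(s1 in expr.free_symbols)  # False -- the structured check is correct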
- Hide ConvParams struct from ConvUtils.h and the other PRs in the stack are a nice example of how a large, involved fix was split into a series of manageable refactors followed by a short fix at the end. I attempted to fix it in one go at first, but then realized there was a simpler way to do it (template the ConvParams struct).
What’s new on the branch this week?
Meta/decomp support
- Meta implementation for _thnn_fused_lstm_cell_backward_impl ezyang
- meta support for allgather_ anjali411, but was reverted in disable allgather_ meta
- fix *_scatter meta impls for weird strides bdhirsh
- Fix cat striding in PrimTorch ezyang
SymInt support
- SymIntify upsample_nearest2d again after composite-ification ezyang
- Symintify view_as_complex and view_as_real ezyang
- Simplify cudnn rnn support greatly ezyang
- Fix assert error
- Symintify obeys_layout_contract ezyang
Infrastructure
- Set functionalization storage to size 0 for sparse tensor wrapping ezyang
- Fix bug in debug interpreter where it doesn’t actually check strides ezyang
- first draft of input mutation handling for aot autograd bdhirsh
- Towards unifying symbolic and non symbolic fake tensor ezyang
- Don’t trace when we track_tensor_tree ezyang
- Detach fake tensors into val, so they aren’t affected by metadata mut… ezyang
- Reenable cache ezyang
QOL
- Greatly improve node printing with detailed kwarg ezyang
- Restore nice error message for non functional graph ezyang
- Relax debug interpreter check ezyang
- Turn on explain by default for runall [QOL] make explain a little easier to parse ezyang, but disabled in Double timeout, disable explain for now
- Setup proxy in the test scripts ezyang
- Make log_extract.py able to deal with NotImplementedError ezyang
Dynamo
- Patch out suspicious get_item_dyn logic for jx_nest_base ezyang; and then Attempt to make get_item_dyn way simpler voz
- Restore minifier int support ezyang
- Turn on debug interpreter by default ezyang
- [dynamo] [dynamic shapes] Fix register_buffer (and all module associa… voz
- Remove bad string parsing assert voz
- dynamo: graph break on torch.foo(symint, symint) bdhirsh
- Pass dynamo’s fake_mode down to aot_autograd, remove duplicate fake t… voz
- Record TensorReference for ._base tensors, but don’t install guards o… voz
- Use nn module source to make buffers on modules static shape fake ten… voz
- Get non-dynamic shapes codepath working again ezyang
- [HACK] don’t graph break on math.sqrt ezyang
- Add guard_source for RandomValueSource ezyang
Inductor
Merge to master retrospective
- Reland 2 “Towards unifying symbolic and non symbolic fake tensor (#89038) (#89143)” - One of the reverts was just because of an xfail land race. However, the other revert was a very scary Executorch segfault that didn’t reproduce locally. Full discussion at Redirecting... (FB only), but the way I debugged it in the end was getting it to repro in Sandcastle and then ablating pieces of the patch until I got something that worked. (What I ablated in the end was the registration of view ops to the meta key.)
- Reland “SymIntify convolution backend calculation (#89069)”. This was originally reverted because I changed a function into a template, and some internal call sites no longer successfully matched against the template. Fixed via Maintain overload selection BC for expand_param_if_needed
- Symintify numel(), infer_size, prims.elementwise_meta was reverted because it broke torch.numel; previously it returned an int, now it returned a Tensor! wrap() is probably doing something naughty for SymInt inputs, which still needs to be fixed. (A quick sanity check is sketched after this list.)
- reland “Do not use unsafe restriding for subclasses (#87610)” was reverted because it broke an internal test. It’s not entirely clear what we’ll do differently next time around.
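As a quick sanity check for the torch.numel item above (ordinary PyTorch API, nothing branch-specific), numel() must keep returning a plain Python int for ordinary tensors:

    import torch

    n = torch.numel(torch.randn(2, 3))
    assert isinstance(n, int) and n == 6
    print(type(n), n)  # <class 'int'> 6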
What’s made it to master this week?
- ezyang
- Also include MKL_THREAD_LIB in link libraries for caffe2::mkl and Set INTERFACE_LINK_DIRECTORIES on caffe2::mkl because I got annoyed that torchaudio wouldn’t build by default when I link against MKL lol
- Reland 2 “Towards unifying symbolic and non symbolic fake tensor (#89038) (#89143)”
- Fix cat striding in PrimTorch
- Add support for dynamic kwarg to torch._dynamo.optimize
- Reland “SymIntify convolution backend calculation (#89069)”
- Detach fake tensors into val, so they aren’t affected by metadata mutation
- Don’t trace when we track_tensor_tree
- Symintify obeys_layout_contract
- SymIntArrayRef type caster
- Add int64_t, SymInt overloads for all binary operators in C++
- Move ConvParams methods directly on struct
- Hide ConvParams struct from ConvUtils.h
- Fix some naughty uses of reshape/flatten
- SherlockNoMad
- bdhirsh
- anjali411
- nkaretnikov
- Krovatkin (helping us out with the XLA integration!)
- voz
What’s coming next?
By Person:
- Chillee: proper integration of inductor with dynamic shapes (he’ll actually work on it this week!!)
- voz: merge aot autograd plumbing to master, with a lot of refactor
- ezyang: maybe improved implementation of reshape CoW warnings (but only if I get some spare time)
- bdhirsh: continue fixing input mutation for AOTAutograd (last mile: the two bugs described above)
- jbschlosser: maybe one of the assert failures on aot_eager
Our north star:
- All benchmark models are passing aot_eager and inductor training on branch
- Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
- All OpInfo tests passing
- Dynamic shapes on by default for developers / users