State of symbolic shapes branch

Which version is expected to receive symbolic shapes?

The current plan is that we will have a convincing alpha in the next release (we have a lot of stuff working in the branch, but it needs to be merged to master and we’re not bug free enough to turn it on by default), and hopefully in two releases it will be on by default.


As promised: On the architecture of torchdynamo - Google Docs

State of symbolic shapes branch: Nov 5 edition

The symbolic-shapes branch (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ) is a long running branch containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch - #9 by ezyang

Commit ID at time of writing: 1f5fac1d10df2e4a054740abc92bcf9d6a6553eb

Executive summary

This week was relatively light on both commits to the branch and merges to master; the bulk of our time was spent on late-breaking infrastructure problems (including addressing problems that affect non-dynamic shapes) and onboarding onto Dynamo. On the plus side, we have solutions to a number of longstanding problems in the overall compilation stack, and we managed to claw back a bit of aot_eager TDS model coverage (+19 passing models) by fixing Dynamo bugs.

  • Milestone: the Edge team imported the symbolic-shapes branch into fbcode, and used it to successfully export one of their models which makes use of symbolic shapes. This is without any specific work to support their use case, and is a nice external validation of the work we’re doing. (Also, @suo would like to know when we will finish merging everything to master thankyouverymuch.)
  • New design: AOTAutograd 2.0 - Google Docs - This design doc describes how to make AOTAutograd work with graphs that have input mutation. This is a release blocking issue: right now, instance norm and batch norm are miscompiled by functionalization because their running stats updates are being removed ([PrimTorch] Functionalization pass removes Instance Norm / Batch Norm running stats transformations · Issue #88375 · pytorch/pytorch · GitHub), and some models are failing to get optimized because they have input mutation. @bdhirsh is working on the input mutation piece of the story.
  • New design: Stride agnostic IR (reshape preservation) - Google Docs - This design doc describes how to make our compiler stack less sensitive to stride propagation bugs, by observing the fact that accurate strides are specifically needed only by view/reshape in most user models, and we can remove this dependence by introducing infallible view() and input/output freezing reshape.
  • The team spent time onboarding onto Dynamo this week. There was a nice session led by Voz at Redirecting... (FB only), and a public writeup about what we learned from Voz and @jansel at On the architecture of torchdynamo - Google Docs. @bdhirsh and anjali411 were able to make their first contributions to Dynamo this week, yay!
  • Preserve reshapes in AOTAutograd by ezyang · Pull Request #88361 · pytorch/pytorch · GitHub is the first example of overriding a pre-autograd decomposition with a user defined decomposition. You may find the pattern useful in other contexts.
  • We have more automation for running sweeps and updating the spreadsheet, so the spreadsheets are now up-to-date and tracking both models and tests.
  • Model training status on symbolic-shapes. (@ezyang) See also Symbolic shapes work items tracker - Google Sheets (up to date!)
    • aot_eager, with TDS: 137 out of 163 (+19 WoW) logs
    • inductor, with TDS: 8 out of 163 (-37??? WoW; I looked at the old logs and I’m not sure how I got 45 passes last week; the same models that were passing last week are passing this week, so maybe there was a setup problem) logs
    • Lowlight: we have two new accuracy failures on our models (cait_m36_384, xcit_large_24_p8_224). This is bad, because it means we do not have enough asserts to catch when we are doing incorrect things. These models should be investigated.
    • Lowlight: a number of models are failing due to sympy timeouts. We need to figure out a strategy for solving this once and for all. @Chillee has suggested that we may want to try rewriting sympy.solve for our use case.
  • OpInfo tests on symbolic-shapes. (@ezyang)
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 350 passed (+3 WoW), 370 failed (-4 WoW), 499 skipped (+2 WoW) logs
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 240 passed (+2 WoW), 228 failed (-18 WoW), 127 skipped (+2 WoW) logs

Previous branch diff: 110 files changed, 2788 insertions(+), 2114 deletions(-)
Current branch diff: 76 files changed, 2341 insertions(+), 533 deletions(-)

Notable bugs

What’s new on the branch this week?

SymInt support

Dynamo

Functionalization

Infrastructure

Quality of life

Merge to master retrospective

  • Do not use unsafe restriding for subclasses was reverted because it broke some internal fbcode tests.
  • Revert “Revert “Put Python Dispatcher cache in dict, clear it on new registrations. (#88329)”” was reverted for making test times balloon by 2hrs! The root cause of the problem was a refactor that made the cache use a variable that had already been reassigned (key = resolve_key(key); cache(key)), so the cache never got hit and test runtime slowed down massively; a minimal sketch of this failure mode appears after this list. @ezyang figured out the problem by guessing it was a cache problem and then reproducing it.
  • Reland 2 Many symintifications (#87604) was successfully landed to fbcode, but it turns out it actually broke static runtime. This is because tensor_split only had one overload ported, and the IntArrayRef overload was actually accepting int64_t arguments, causing call sites that intended to go to the other overload to go to the wrong one. This was point-fixed by just porting the other overload to have an explicit signature. It’s not clear there’s a structural fix for this problem; please be aware of it when adding BC native:: signatures for operators with multiple overloads.
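To illustrate the dispatcher cache bug above, here is a minimal, hypothetical sketch (not the actual dispatcher code): because the variable used as the cache key is reassigned before the cache write, entries are stored under a different key than the one used for lookups, so the cache never hits.

cache = {}

def resolve_key(key):
    # stand-in for the expensive dispatch key resolution we want to cache
    return ("resolved", key)

def compute_kernel(key):
    return f"kernel for {key}"

def dispatch(key):
    if key in cache:          # lookup happens under the raw key
        return cache[key]
    key = resolve_key(key)    # BUG: reassigns the variable later used as the cache key
    result = compute_kernel(key)
    cache[key] = result       # write happens under the resolved key, so lookups never hit
    return result

Calling dispatch("add") twice recomputes the resolution and kernel lookup both times; no entry is ever found.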

What’s made it to master this week?

What’s coming next?

  • Brian, you still have a lot of open PRs lol, please keep landing them
  • Voz, we need a plan for how we are going to land Dynamo fixes to master. Need to discuss with Alban if merge captain should also be doing Dynamo fixes; the main difficulty is they need proper review. (@jansel spot checked some of the changes on the branches and found a number of suggestions, so we will need to be rigorous about this.)
  • Fix input mutations for AOTAutograd (bdhirsh)
  • Stride agnostic IR (ezyang)
  • E2E training on master with inductor.
  • All benchmark models are passing aot_eager training on branch; tracked at Operators that need symbolic support - Google Sheets
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing

State of symbolic shapes branch: Nov 12 edition

The symbolic-shapes branch (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ) is a long running branch containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch - #16 by ezyang

Commit ID at time of writing: 807a62fc61bea26707c3dc09a12bad204e375a95

Executive summary

This was a chaotic week. Meta had layoffs for the first time in its history (this workstream was not directly affected.) We still made progress on this workstream (merge to master, inductor support, operator coverage), but we also discovered more work to do (more dynamo bugs, dynamic shape guard problems, more accuracy failures). Some big work items (dynamo merge to master, input mutation, copy-on-write tensors) have progressed, but are still short of actually landing to master (or even to the branch, as the case may be). Merge to master is also slow going as we often have to first add OpInfo support before we can merge our changes.

  • Staffing. Nikita Karetnikov from Quansight is pivoting from general PrimTorch work to working on decompositions / meta implementations that the dynamic shapes workstream can specifically benefit from. Welcome Nikita! In other news, Edward is on jury duty next week.
  • Design. We have shifted the “stride agnostic IR” concept to a more expansive “stride agnostic PyTorch” concept, where we make eager mode PyTorch as a whole less sensitive to stride changes. This includes a new design for Copy-on-write tensors for eager PyTorch - Google Docs which aims to eventually make the BC-breaking change to reshape()/contiguous()/etc so that these functions always return contiguous tensors. A PoC PR for the entire design exists (Copy on write reshape by ezyang · Pull Request #88774 · pytorch/pytorch · GitHub) and fully passes non-trunk CI, but there are some unresolved questions, such as whether to more deeply integrate data pointer reference counting into Storage to reduce the overall level of indirection, and whether the proposed warning strategy is too loud. This pair of proposals was discussed in the most recent Composability meeting; there were no major objections, but also a desire to better understand the implications of the change.
  • Make silent errors noisy. A big reason why our aot_eager pass rate regressed this week is that we turned on more stringent error checking in the branch, to try to transform potential bugs into assertion failures. This week, we: (1) assert that sizes/strides of intermediate tensors are consistent between fake and real tensors (a condensed sketch of this fake-vs-real checking idea appears after this list), (2) assert a functional-only graph after lowering (this turns the batch norm problem we observed last week into a hard error; to bypass some of these errors, we disabled AOTAutograd from running on subgraphs with BatchNorm), (3) assert that all guards correspond to tensors Dynamo knows about (this flushed out a problem with symbolic shapes guards, where Dynamo was not tracking enough guards; we fixed one of the problems, but that didn’t hit all of the issues, so we also have a way of suppressing these failures). Unfortunately, while these changes did nail some accuracy problems, we still have new accuracy failures on the latest model runs.
  • Inductor has problems. The branch now has some quick and dirty hacks which substantially improved the inductor pass rate (+16 working models), but there are still many bugs that are causing lots of models to fail for similar reasons. The conclusion is that there are some fundamental pieces in the dynamic shapes inductor integration that don’t exist yet (though @Chillee still assures me they’re easy to do). On the bright side, the uniformity of inductor errors means there probably aren’t that many distinct bugs to fix.
  • Model training status on symbolic-shapes. (@ezyang) See also Symbolic shapes work items tracker - Google Sheets
    • aot_eager, with TDS: 135 out of 163 (-2 WoW) logs csv Note that this run is skipping all subgraphs with batchnorm and with dynamo guard asserts suppressed (in the spreadsheet, this is noted as BN+IA)
    • inductor, with TDS: 24 out of 163 (+16 WoW) (logs too long) csv
    • Lowlight: jx_nest_base and twins_pcpvt_base are failing with accuracy errors. Interestingly, this pair of models previously failed with accuracy errors without dynamic shapes; both were fixed by a single PR https://github.com/pytorch/pytorch/pull/85417 . jx_nest_base is minifiable, although the minifier failed with a reduction error on int when I tried running it (I did not try very hard). twins_pcpvt_base was passing on 10/28 and regressed into a timeout on 11/2 (this is after voz’s major dynamo change hit master); jx_nest_base has never passed.
    • Highlight: cait_m36_384 and mobilenet_v2_quantized_qat accuracy failures turned into non-accuracy failures after we added DebugInterpreter to aot_eager. cait_m36_384 is now passing; no one has had a chance to investigate mobilenet_v2_quantized_qat.
  • OpInfo tests on symbolic-shapes. (@ezyang)
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 388 passed (+33 WoW), 334 failed (-36 WoW), 501 skipped (+2 WoW) logs csv
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 255 passed (+15 WoW), 213 failed (-15 WoW), 129 skipped (+2 WoW) logs csv
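Below is a condensed sketch of the fake-vs-real consistency check referenced in the “make silent errors noisy” item above. This is not the actual DebugInterpreter; it just illustrates the idea of re-running the traced graph with an FX interpreter and asserting that the real outputs match the metadata recorded during fake tracing.

import torch
import torch.fx

class CheckingInterpreter(torch.fx.Interpreter):
    def run_node(self, n):
        result = super().run_node(n)
        expected = n.meta.get("val")  # fake tensor recorded at trace time
        if isinstance(expected, torch.Tensor) and isinstance(result, torch.Tensor):
            # compare the metadata the compiler relied on against reality
            assert tuple(result.shape) == tuple(expected.shape), n.format_node()
            assert result.stride() == expected.stride(), n.format_node()
        return result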

Previous branch diff: 76 files changed, 2341 insertions(+), 533 deletions(-)
Current branch diff: 82 files changed, 2236 insertions(+), 747 deletions(-)

Notable bugs

  • Fix buggy unsqueeze_ implementation. I found this error because the unsqueeze OpInfo test was failing (hooray unit tests). I didn’t actually directly debug the bug; I just replaced the code wholesale with a straight port of the C++ code (I think the bug was in how dimension wrapping was implemented, though!)
  • Fix stride on softmax_backward_data, fixes cait_m36_384 was found via the DebugInterpreter. The interesting thing about this fix was that I had actually made a repro last week, but no one had picked it up and fixed it, and when I went to look at it again the repro no longer actually failed. Fortunately, DebugInterpreter confirmed that it was indeed a problem with softmax_backward_data.
  • More correct fix for upsample_bilinear2d decompositions is a fixup of a previous PR, where I attempted to register a composite implementation for upsample_bilinear2d.default in the Python dispatcher. I initially tried CompositeImplicitAutograd; this did not work, because this operator has an explicit autograd formula (Autograd key registration), and Autograd is higher precedence than CompositeImplicitAutograd. This is easy to work around if you know the correct semantics, but you might expect Python registrations to “override” their C++ variants; we should look into potential API changes to remove this footgun.
  • Suppress guards when constructing fake tensors was discovered by chasing down an assert failure from Dynamo when it needed to construct a shape guard involving a symbolic variable that it didn’t know the source of. The problem is that we do non-trivial tensor operations to make, e.g., view fake tensors actually views, but this means the base tensor gets its own set of fresh symbolic variables that dynamo doesn’t currently track. Fixing this in Dynamo is a fairly involved refactor, especially because we’d like some design that makes it hard for Dynamo to forget to track tensors (e.g., tie it to fake tensor conversion). In the meantime, we added a config driven via TORCHDYNAMO_IGNORE_ASSERT to allow dynamo to suppress these errors for now.
  • call call_method instead of _call_operator_builtin - operator calls in Dynamo don’t properly handle NotImplemented. Brian/Anjali tried to patch over this, but the bugfix was buggy, so Voz yanked it out from the branch again. This problem needs to be solved again properly.
  • Some infra/QoL commits were about making it easier to debug errors when things fail. For example, if exec fails to compile code, we now print the code that failed to compile. We also fixed some design choices that produced unreadable debug output; for example, Change how sympy guard formatting works modifies symbolic shape guards to print with newlines, instead of being mashed into a giant unholy expression.
  • Fix call_range and remove weird idiom of fallthorugh for dynamic in f…. This failed in a very strange way, and even after Edward identified what the problem was, he couldn’t figure out how to fix it (Voz fixed it later). It would be good to have a better understanding of what the Right™ way to fix these issues in Dynamo is.

What’s new on the branch this week?

Meta/decomp support

SymInt support

Infrastructure

Quality of life

Dynamo

Inductor

Merge to master retrospective

  • Meta registrations, Nov 7 edition, part 1 had to be split into small constituent PRs, because CI on the mega PR was failing on a fixed set of seemingly unrelated tests, even after rebase. However, the problem evaporated when everything was landed individually, so it’s not really clear what the moral of the story here is.
  • reland “fix as_strided_scatter_backward (#87646)” was reverted because it broke some trunk jobs. It looks like this was resolved by adding more xfails/skips: reland "fix as_strided_scatter_backward (#87646)" by bdhirsh · Pull Request #88342 · pytorch/pytorch · GitHub
  • Some meta function merges are slow because OpInfo coverage is insufficient. There are two common themes: first, sometimes there is an OpInfo for a given operation, but the OpInfo covers a lot of different overloads/dim specializations of the function, and we only implemented one overload in a PR. To test easily, we have to extract one particular overload from the OpInfo, or go ahead and implement all of the overloads so the mega OpInfo works. (Arguably, this is a bad decision in OpInfo design.) Second, many missing OpInfos relate to backward functions. Hypothetically, these can be tested indirectly via the forward-backward call that test_aotdispatch.py performs, but in practice it seems to be easier to just add the backward OpInfo.

What’s made it to master this week?

What’s coming next

By person:

  • bdhirsh: fix input mutation for AOTAutograd (same as last week, progress at first draft of input mutation handling for aot autograd by bdhirsh · Pull Request #88817 · pytorch/pytorch · GitHub )
  • voz: merging the first round of dynamo patches to master; afterwards, aot autograd cache, restore builtins support in dynamo, fix symbolic shape guard construction in dynamo, refactor dynamo storage of dynamic shapes
  • ezyang: run some more experiments on copy-on-write
  • Chillee: fix sympy infinite loops, get inductor training working with dynamic shapes
  • anjali411: land more branch changes to master, run sweeps this week
  • nkaretnikov: continue working on test_proxy_tensor coverage

Our north star:

  • All benchmark models are passing aot_eager and inductor training on branch
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing
  • Dynamic shapes on by default for developers / users

State of symbolic shapes branch: Nov 20 edition

The symbolic-shapes branch (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ) is a long running branch containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch - #18 by ezyang

Commit ID at time of writing: 41c314473272b2622e20c261f19549b4bd3f1d8f

Executive summary

This week, we made two major merges into the branch: (1) @bdhirsh’s AOTAutograd input mutations PR and (2) voz’s fake tensor plumbing to AOTAutograd PR. While not all of the bugs from these two PRs have been totally resolved (and in particular, bugs in input mutations are suppressing the pass rate on the branch), the features work well enough that we only suffered a minor regression in aot_eager pass rate. Additionally, we have BERT_pytorch working end-to-end for BOTH training and inductor, meaning that on branch (but not master) we have achieved our goal for inductor. :tada::tada::tada:

  • PSA: We now have int64_t, SymInt overloads for all binary operators in C++, so you no longer have to rewrite 2 + symint into symint + 2; both work now.
  • PSA: DebugInterpreter is now actually verifying stride equality; at time of writing, this is revealing seven models which have incorrect sizes/strides. These bugs are ripe for picking!
  • We have more clarity about what is missing to properly integrate inductor with dynamic shapes (instead of all of the hacks that are currently on the branch). A big question mark is whether or not ShapeEnv should be shared between AOTAutograd’s forward and backward; when the ShapeEnv is shared, Inductor cannot necessarily infer all of the shape variables, because a shape variable may only occur inside a more complex shape expression from the input, e.g., s0 * 2 (see the toy illustration after this list). ref @Chillee is pivoting back to working on this, after having spent time working on channels last in the general workstream. There’s also some discussion about forward-backwards staging confusion at [aot_eager] [hf_Longformer] Cannot view a tensor with shape · Issue #1888 · pytorch/torchdynamo · GitHub
  • Some progress was made designing a more precise warning mechanism for Copy-on-write tensors for eager PyTorch - Google Docs see bottom of doc (though no implementation progress)
  • Model training status on symbolic-shapes. See also Symbolic shapes work items tracker - Google Sheets
    • aot_eager: 128 out of 163 (-7 WoW) logs csv. Regression is primarily due to AOTAutograd input mutation support hitting our branch, but having two bugs (accidentally using SymInt sizes to perform manipulations on real tensors, and needing to regenerate views of intermediates rather than returning them directly).
    • inductor: 36 out of 163 (+12 WoW) logs csv; notably, BERT_pytorch is passing with inductor now, and a spot check of generated Triton code suggests the output is dynamic: BERT_pytorch dynamic Triton · GitHub
    • End-to-end BERT_pytorch working with dynamic shapes and inductor by ezyang · Pull Request #89313 · pytorch/pytorch · GitHub demonstrates the minimal set of changes from our branch necessary to get BERT_pytorch working end-to-end on master. The PR doesn’t pass CI as is; our current thinking is to do some necessary refactors first which will simplify the final form of this PR.
  • OpInfo tests on symbolic shapes.
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 494 passed (+106 WoW), 229 failed (-105 WoW), 512 skipped (+11 WoW) logs csv. This improvement is partially due to Towards unifying symbolic and non symbolic fake tensor, which allows us to attempt C++ meta functions even if they’re not known to be SymInt-aware; it turns out many of them still work correctly anyway. This is at the cost of worse error messages when SymInt is not supported.
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 249 passed (-6 WoW), 221 failed (+8 WoW), 133 skipped (+4 WoW) logs csv. The regression here is from outstanding bugs on AOTAutograd input mutation changes on the branch; the numbers should improve once that regression is fixed.
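Here is a toy illustration of the shared-ShapeEnv issue mentioned above (not Inductor code): if the backward graph only receives an input whose size is the composite expression 2*s0, the backend has to solve for s0 to recover the underlying variable, which, as noted above, Inductor cannot necessarily do today.

import sympy

s0 = sympy.Symbol("s0", positive=True, integer=True)
u = sympy.Symbol("u", positive=True, integer=True)

# the only size the backward sees is u = 2*s0; recovering s0 requires solving
print(sympy.solve(sympy.Eq(2 * s0, u), s0))  # prints: [u/2]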

Previous branch diff: 82 files changed, 2236 insertions(+), 747 deletions(-)
Current branch diff: 68 files changed, 2612 insertions(+), 554 deletions(-)

Notable bugs

  • DebugInterpreter actually works now: Fix cat striding in PrimTorch was identified by looking at a model failing a stride quality assert in the debug interpreter. I figured out that DebugInterpreter was not correctly testing strides for correctness when I fixed a separate stride problem that was causing an accuracy failure, and then attempted to diagnose why DebugInterpreter hadn’t caught it earlier; it turns out the stride matching function returns a tuple, and a non-empty tuple is always truthy :rage: (a minimal illustration of this pitfall appears after this list)
  • Testing is important: unrelated refactoring on master broke our previously symint-ified upsample_nearest2d, and had to be fixed again in SymIntify upsample_nearest2d again after composite-ification. If we had appropriate testing on master, it probably could have been caught at regression time.
  • Simplify cudnn rnn support greatly untangles a complicated situation with cudnn SymInt support. The root cause of the problem is that cudnn_rnn is a composite function that calls a black box cudnn function to figure out what the output size of the workspace should be. This cannot be SymInt’ified, but in the original attempt to make this all work, our poor intrepid engineer tried SymInt’ifying all of the Descriptor code responsible for driving calls to the cudnn API. This patch undoes all of that in favor of an earlier guard. However, this is not enough to fix tts_angular, as this eventually calls _cudnn_rnn, which requires a meta implementation, but this meta implementation requires querying cudnn APIs. Handling this black box correctly may require something similar to how we plan to handle missing meta functions for custom functions (guess relations between inputs and outputs, and verify the black box function acts consistently for concrete values we haven’t seen before). Or we can just guard (but guarding requires us to be able to call into the cudnn API from our Python implementation.)
  • Set functionalization storage to size 0 for sparse tensor wrapping is a lucky hack that works around a very delicate situation, which is that functionalization really doesn’t know how to handle sparse tensors. Fortunately, we don’t actually need to functionalize operations on sparse tensors, so hacking an empty storage “works”, but it is likely to fail if someone stress tests us on sparse support.
  • first draft of input mutation handling for aot autograd is affected by two bugs right now.
    • “RuntimeError: Expected !is_symbolic() to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)” The problem here is that the patch accidentally tries to use fake tensor sizes (with SymInts) to reshape real tensors, propagating SymInts into real tensors (a big no-no! And also missing asserts; we should add asserts.) At time of writing, this is close to fixed but not quite, using the idea of having the forward return sizes/strides/storage_offset directly as symint outputs of the forward. Most of the tests pass on the branch, but while hf_Longformer and tacotron2 no longer have the error, they’re both OOM’ing even when the batch size is set to 1… this warrants further investigation. A suggestion is to use zdevito’s memory profiling tools to diagnose: Debugging PyTorch memory use with snapshots | Zach’s Blog
    • “RuntimeError: Output 0 of AsStridedBackward0 is a view of a view which was created in no_grad mode and is being modified inplace with grad mode enabled.” This is the thing where our graph should return a graph intermediate, and manually regenerate the view(s) of the intermediate in the epilogue. @bdhirsh hasn’t started working on fixing it yet.
  • Detach fake tensors into val, so they aren’t affected by metadata mut… was briefly discussed in composability; it has the potential to break users who are relying on accurate autograd metadata or tensor identity on meta[‘val’], but it turns out this doesn’t work anyway.
  • [HACK] don’t graph break on math.sqrt: we noticed that BERT_pytorch was repeatedly recompiling on every iteration. Voz tracked this down to the id of an NN module not staying stable over iterations; but it turned out that this was because we were graph breaking too much, due to an overly aggressive unsupported() call in Dynamo. Reducing the graph breaking fixed the problem!
    • Disable explain for now: we wanted to collect graph break stats to make it easier to tell if BERT_pytorch situation was happening elsewhere. Dynamo’s preexisting explain feature makes this possible. However, when we turned it on many huggingface models started failing with “AttributeError: ‘str’ object has no attribute ‘size’”. This has not been investigated yet.
  • Patch out suspicious get_item_dyn logic for jx_nest_base is an example of very long code that is also wrong. Long code for a special case is a smell, there’s probably a conceptually simpler way to do it!
  • Dynamo was affected by a lot of “missing symbol s0” assertion errors. This assertion error tests whether Dynamo knows how to compute a given SymInt from the set of tensors being guarded on. These missing symbols came from a variety of places: Dynamo’s special handling of NN parameters/buffers [dynamo] [dynamic shapes] Fix register_buffer (and all module associa…, as well as dependence on _base tensors due to functionalization Record TensorReference for ._base tensors, but don’t install guards o…. The latter is fixed in a tricky way: we don’t ever actually guard on the base of a tensor, because according to voz this caused a lot of recompilation. More investigation necessary…
  • Remove bad string parsing assert is an example of why it is bad to do string matching on strings that represent structured data types (like language expressions).
  • Hide ConvParams struct from ConvUtils.h and the other PRs in the stack are a nice example of how a large, involved fix was split into a series of manageable refactors, followed by a short fix at the end. I attempted to fix it in one go at first, but then realized there was a simpler way to do it (template the ConvParams struct).
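As a minimal illustration of the DebugInterpreter pitfall mentioned above (hypothetical helper name; not the actual code), a comparison helper that returns a tuple instead of a bool makes any assert on it vacuously pass:

def strides_match(expected, actual):
    # returns details for debugging, but as a tuple rather than a bool
    return (expected == actual, expected, actual)

# A non-empty tuple is always truthy, so this assert passes even though
# the strides clearly differ.
assert strides_match((4, 1), (1, 4))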

What’s new on the branch this week?

Meta/decomp support

SymInt support

Infrastructure

QOL

Dynamo

Inductor

Merge to master retrospective

What’s made it to master this week?

What’s coming next?

By Person:

  • Chillee: proper integration of inductor with dynamic shapes (he’ll actually work on it this week!!)
  • voz: merge aot autograd plumbing to master, with a lot of refactor
  • ezyang: maybe improved implementation of reshape CoW warnings (but only if I get some spare time)
  • bdhirsh: continue fixing input mutation for AOTAutograd (last mile two bugs)
  • jbschlosser: maybe one of the assert failures on aot_eager

Our north star:

  • All benchmark models are passing aot_eager and inductor training on branch
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing
  • Dynamic shapes on by default for developers / users

State of symbolic shapes branch: Dec 1 edition (eve of PyTorch Conference)

The symbolic-shapes branch (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ) is a long running branch containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch - #18 by ezyang

Commit ID at time of writing: a05b7b1c73247ff562a82aac0edca79bbaebc2bd

Executive summary

It is the eve of the PyTorch Conference and we have been busy getting things ready for some big announcements. :wink: Before and after Thanksgiving, many folks involved with dynamic shapes were deputized to help fix some major release blockers in the general compiler workstream; Brian and Jane landed all of the pieces needed to properly update batch norm running stats, and Alban and Edward found and fixed some more fairly major AOTAutograd bugs. On the dynamic shapes front, Voz has been steadily working on getting all of the Dynamo changes passing CI on master; half of the preparatory changes have been landed so far, and the branch has been resync’ed after those merges. There is some regression in the aot_eager pass rate as we remove hacks and redo fixes properly.

Previous branch diff: 68 files changed, 2612 insertions(+), 554 deletions(-)
Current branch diff: 68 files changed, 1440 insertions(+), 290 deletions(-)

What’s new on the branch these two weeks?

Metas/decompositions

Infrastructure

Debug interpreter

Dynamo

Inductor

QOL

Merge to master retrospective

  • Reland “Add single process version of dynamo distributed hf_Bert tests (#89721)” - this got bounced because not enough tests ran on PR. We added more files to automatically trigger inductor tests.
  • Refactor how AOTAutograd backends are defined - this is just an example of a few cases where folks ran inductor CI, got accuracy failure on a model, and then spent a bunch of time trying to debug what had happened; when in fact, the failure was a preexisting master failure. It is not easy to identify these because ciflow/inductor does not run on every master commit.
  • Change aot_module_simplified to take arguments directly - this broke a timm model, and led us on a pretty big chase that eventually revealed that example inputs being passed to backends did not have correct requires_grad because they were being cloned. This was fixed by refactoring the AOTAutograd-Dynamo integration to not clone example inputs.
  • Remove fake_tensor_propagation - this nearly got bounced because it broke some internal users who didn’t have fake tensor support for some operations. Averted because (1) their tests weren’t in CI and (2) it turned out to be pretty easy to add meta tensor support.
  • Don’t unsafely clone autograd meta - this couldn’t be landed because it broke an inductor model, causing it to raise an error where previously it passed. This led to a very long debugging session by Alban until we finally nailed the problem.

What’s made it to master this week?

ezyang

bdhirsh

anjali411

nkaretnikov

voz

albanD

What’s coming next?

  • Land fake tensor propagation from Dynamo to AOTAutograd (voz)
  • ShapeEnv revamp to get guards for duck sizing (ezyang)
  • GuardEnv for non-shape related extra guards produced by AOTAutograd (voz)
  • Address CI comments for AOTAutograd input mutation, factoring it to be more modular (bdhirsh)
  • Proper inductor integration (Chillee didn’t end up working on it, unallocated; mildly blocked on ShapeEnv revamp)

Our north star:

  • All benchmark models are passing aot_eager and inductor training on branch
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing
  • Dynamic shapes on by default for developers / users

State of symbolic shapes branch: Dec 12 edition

Previous update: State of symbolic shapes branch - #19 by ezyang

Commit ID at time of writing: bcb284d77fe865373b2f1617867320fb32ea68af

Executive summary

We master now peeps! We are officially retiring the symbolic-shapes branch. There are still some changes that need to be merged to master (Symbolic shapes work items tracker - Google Sheets “Merge to master” sheet), but the major fixes for shape guard creation have landed to master, so all that remains on the branch are some QOL fixes (in particular the debug interpreter is no longer on by default), a little bit of op coverage, and some experimental code (esp for inductor integration) that needs to be rewritten anyway.

Previous branch diff: 68 files changed, 2612 insertions(+), 554 deletions(-)
Current branch diff: 0 files changed, 0 insertions(+), 0 deletions(-)

Notable bugs

  • It turns out checkpointing doesn’t operate on ShapeEnv, but it should. Dynamo uses checkpoints to roll back its internal state when executing an instruction (or many instructions, in the case of an inlined function call) fails, so that it can pretend those instructions never executed. However, because ShapeEnv isn’t checkpointed, shape guards that occur during the rolled-back instructions still end up getting installed. This can result in a hard error if the guards refer to variables that we don’t know about from the outer context (we think hf_Reformer and swin_base_patch4_window7_224 are affected by this). Checkpointing the ShapeEnv performantly is nontrivial, as we refine the context with equalities and use that to drive sympy simplification, all of which would need to be undone. This bug is still unfixed.
  • Preserve original GraphArgs for shape guard codegen and Rewrite dynamo cond() handling to not recursively call export are both fixes for pretty interesting bugs, if I do say so myself. Go check out their PR descriptions for more details.

What’s made it to master this week?

ezyang

nkaretnikov

voz

SherlockNoMad

What’s coming next?

By Person:

  • voz: Guard refactor in dynamo
  • ezyang: burn down symbolic shapes, fix bugs, work on exporting all shape expressions to form a benchmark, aot autograd default api maybe?
  • bdhirsh: continue fixing AOTAutograd v2 follow up bugs
  • jbschlosser: merge to master tasks, burn down symbolic shapes
  • unallocated: inductor integration

Our north star:

  • All benchmark models are passing aot_eager and inductor training on branch
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing
  • Dynamic shapes on by default for developers / users

State of symbolic shapes: Dec 19 edition

Previous update: State of symbolic shapes branch - #20 by ezyang

Commit ID at time of writing: 212873c615dd3455a24d390605335aeeebd76236

Executive summary

This week, we turned on dynamic shapes with aot_eager in CI in the inductor job. Compared with static shapes aot_eager, we only have a 17-failure difference on master! Inductor remains in bad shape on master, as we are still waiting on @Chillee to submit his PR with fixes.

In other news, @ezyang has released a benchmark for reasoning on shape computation: GitHub - ezyang/SMT-LIB-benchmarks-pytorch-shapes: SMT-LIB benchmarks for shape computations from deep learning models in PyTorch If you work on SMT solvers or like symbolic reasoning systems, check it out! It offers an easy way to test out new ideas about how to symbolically reason over shape compute. We still have a number of infinite loops in Sympy, although this week we are now just suppressing all stack overflows induced by Sympy.

  • Model training status on master. See also Symbolic shapes work items tracker - Google Sheets
  • OpInfo tests on symbolic shapes.
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 513 passed (+5 WoW), 522 skipped (no change), 227 xfailed (-3 WoW)
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 286 passed (+5 WoW), 142 skipped (+1 WoW), 203 xfailed (-5 WoW)

Notable bugs

  • Despite overhauling ShapeEnv guard production in Dynamo two weeks ago, there were still more stragglers that had to be addressed this week. The main source of problems was a mismatch between when we add a tensor to GraphArgs (as it is an FX graph input) and when we allocate dynamic shapes for a tensor (so we may need to determine the source of its symbolic shape). This led to more refactoring in Dynamo so that we could guarantee that whenever a tensor had symbolic shapes allocated for it, we also tracked it for the purposes of guard creation. This fixed all bugs, except one(!), which @ezyang has an open PR set for (involving more refactoring).
  • Assert for functional graph is FINALLY in master, and it caught more bugs in inductor lowerings when it landed. Hooray for more stringent asserts.

What’s made it to master this week?

ezyang

jbschlosser

voz

bdhirsh

What’s coming next?

By Person:

  • voz: vacation
  • ezyang: vacation
  • bdhirsh: continue AOTAutograd v2 follow up
  • jbschlosser: merge to master and burn down
  • Chillee: inductor integration (apparently, Horace “has a few fixes” but they’re still not posted yet)

Our north star:

  • All benchmark models are passing aot_eager and inductor training on branch
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing
  • Dynamic shapes on by default for developers / users

(SMT/verification geek here)

Can you give a bit more context about this?
Are these SMT files created by Sympy or is this an alternative code path in PyTorch that bypasses Sympy entirely?

The goal of the symbolic shapes code, I think, is two fold:

  • Do type/shape checking. This is a simple satisfiability check to ensure the program is type-safe
  • Compute symbolic shapes for each operator, with a simplified expression (whatever that means). This is not easily solvable with SMT solvers.

Also, could you give an example of a complex shape computation that requires such heavy machinery?

Happy to help with this stuff if needed, but I need a bit more context.
I’ve started to play with symbolic stuff today, with partial success. I still see some crashes, especially when mixing with torch.compile. I’m probably doing something that is not supported…

Thanks,
Nuno

Happy to! The SMT files here bypass sympy entirely; they contain exactly the shape computations that the original user program / operators had. The asserts correspond to places where people (usually operators) did control flow.

The goal of the reasoning code is a little difficult to express exactly. The smt2 instances test one particular thing that is easy to formulate for SMT solvers: given a sequence of shape computations and asserts, find a satisfying assignment of input shapes that satisfies all the asserts. This could be used to, for example, take a model and regenerate a small version of it (with small parameters and inputs) for easy testing. But this is actually not really the most pressing problem for dynamic shapes in PyTorch as is, IMO. The PyTorch reasoning code tends to do two other things: find simple canonical forms for size expressions, so that we can eventually simplify the loop bound / indexing computation we generate in inductor, and determine if guards are redundant or not (since we spend a lot of time testing conditionals, which naively must be retested before we can reuse compiled code.) And we would like to do this all in a simple Python implementation, which precludes more typical approaches like throwing LLVM’s optimizer at the problem. “Simple” canonical forms is not so easy to express in smt-lib2 nomenclature; redundant guards is technically expressible but awkward (and better suited as a compiler benchmark.)

We currently have a baseline sympy+unification implementation (which I have provided a crappy hookup for) which also takes some extra simplifications I haven’t encoded in the problem instance yet (duck sizing, 0/1 specialization), and while it’s pretty simple if you treat sympy as a black box, empirically our biggest problem is that it is slow. I’ll post some measurements I took, but in some pathological cases we can spend minutes just crunching sympy simplifications. And as a team we have a bit of a disagreement about how to proceed. I kind of want to chuck sympy entirely and roll our own domain-specific algebra system, but @Chillee thinks this is too much and we should be able to patch over the specific sympy badness more quickly than having to rewrite everything from scratch. We also kind of need an entailment system, which an SMT solver would provide, but we don’t have a good sense for how to do this integration. Should Z3 really be a mandatory dependency for dynamic shapes?
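For concreteness, here is a toy sketch (not code from the branch) of how the redundant-guard question could be phrased as an entailment check with Z3’s Python API: a candidate guard is redundant if its negation is unsatisfiable together with the guards already installed.

import z3

s0 = z3.Int("s0")
existing = [s0 >= 2]       # guards already installed for this compiled code
candidate = 2 * s0 > 2     # new guard we would like to avoid re-checking

solver = z3.Solver()
solver.add(*existing)
solver.add(z3.Not(candidate))
# unsat means the existing guards entail the candidate, i.e. it is redundant
print(solver.check())      # prints: unsat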

Re crashes, post up what you are doing. Inductor still doesn’t work on master but you should be able to play around with aot_eager

These are the timings I took on the baseline implementation, showing which models have unusually horrible perf in sympy in symbolic shapes today.

Z3 has a decent Python API. So, in theory this can be integrated easily.
Z3 is a dependency for a lot of things these days, including clang’s static analyzer (optional, but a lot of distros ship it enabled).

Z3 can also simplify expressions. It has a bunch of tactics that can be applied to expressions to simplify them. It can also do quantifier elimination, which can be useful for expression simplification.

I understand the guards & bounds checks simplification problems. Those can be mapped into entailment checks as you say.

What I don’t understand yet is why these shape inference expressions are so complicated. I was just looking at Torchy’s code that does shape+stride inference, and it doesn’t look that complicated (there are min, max, equalities, mul, add, reduce). I think the worst part is that in a lot of cases you need to know the number of dimensions. If that’s also symbolic, things become very non-trivial. Do you need to give an upper bound for the dimensions? Or do you generate predicates that are valid in the unbounded case?

I’m happy to help with this stuff (designing, optimizing, whatever) if needed. I’ve experience with symbolic reasoning.

Regarding crashes:
this one issues a warning:

import torch

@torch.compile
def fn(a):
    b = a * a
    return b

# [WARNING] Unsupported: meta converter nyi with fake tensor propagation
print(fn(torch.tensor([1.0, 2.0], device="meta")))

Assert fails:

from torch._subclasses.fake_tensor import FakeTensorMode

@torch.compile
def fn(a):
    b = a * a
    return b

with FakeTensorMode() as mode:
    print(fn(torch.empty([2,3])))

#   File "_dynamo/variables/builder.py", line 611, in wrap_tensor
#    assert type(value) in (torch.Tensor, torch.nn.Parameter)

Assert fails; could give a nice error message:

# without allow_meta=True
with FakeTensorMode() as mode:
    print(fn(torch.empty([2,3], device="meta")))

# File "_subclasses/fake_tensor.py", line 587, in __init__
#    assert device.type != "meta"

Another crash:

torch._dynamo.config.dynamic_shapes = True
torch._functorch.config.use_dynamic_shapes = True

@torch.compile
def fn(a):
    b = a * a
    return b

print(fn(torch.tensor([1.0, 2.0])))

#   File "sympy/core/cache.py", line 70, in wrapper
#    retval = cfunc(*args, **kwargs)
# TypeError: unhashable type: 'SymInt'
# RuntimeError: Trying to extract a concrete int out of a symbolic int

All I was trying was to print some complicated symbolic shape expressions so I could understand the problem a bit better. But then I hit all these crashes, so I must be doing something very wrong…

I definitely think Z3 as an optional dependency is not too hard a sell. The turning point would be if we could make our design substantially simpler by assuming Z3 is always available; then the calculation would turn to whether or not PyTorch has Z3 as a dependency by default, which, enhhhhh. It’s probably easier for a distro to do it.

Nope, everything is fixed dimensionality. So honestly the computations in the smt2 files are not that complicated, but they can be very repetitive because we are just tracing the real shape compute PyTorch does which was not necessarily written in a nice way for symbolic analysis.

A big gap I see from the operators you quote here is floor division (shows up a bit in pooling operations) and modulus (not sure where these come from actually.)

Definitely looking for collabs!

These are crashing for silly reasons haha. In order:

  1. Meta tensors intentionally don’t work with fake tensor (which is what PT2 will do.) In principle they actually should work fine but real world user code doesn’t actually need to optimize code computing on meta tensors, and when we were working on fake tensor it was usually a bug to try to fakeify a meta tensor, soooo yeah. @eellison we probably should fix this eventually
  2. You don’t need to explicitly fakeify before calling torch.compile; it does that for you. So what happened here is you said hey PT2 compile this with a tensor subclass input and this doesn’t work. Shouldn’t be an assert though; we should file a bug on this.
  3. This is (1) again, and yeah, let’s beef up the message
  4. This one I think is because inductor with dynamic shapes is busted. So you need to make sure you don’t use inductor. Easiest is to use torch._dynamo.optimize("aot_eager") instead of compile

For playing around with simple examples, your best bet is to look at the tests with Symbolic in their class name in test/test_proxy_tensor.py. In particular, most of these call make_fx with tracing mode symbolic. This will give you the smallest slice of the system that is doing something interesting with dynamic shapes.
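For example, a minimal invocation along these lines (a sketch with a toy function) looks like this; the input’s sizes are traced symbolically:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(a):
    return a * 2

# trace with symbolic sizes for the input and print the captured graph
g = make_fx(f, tracing_mode="symbolic")(torch.randn(3))
print(g.graph)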


Thank you!

Just a couple more crashes (I know, I’m a magnet for them…):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def fn(a):
    b = a * a
    if b.sum():
        return a
    return b

print(fn(torch.tensor([1.0, 2.0])))
traced_f = make_fx(fn, tracing_mode="symbolic")(torch.tensor([1.0, 2.0]))
print(traced_f)

# RuntimeError: tried to get Double out of SymFloat

If using this instead:

def fn(a):
    b = a * a
    if b.sum() >= 1:
        return a
    return b

# NotImplementedError: local_scalar_dense/item NYI for torch.bool

I’ve attempted to fix it by patching fake tensor’s local_scalar_dense:

    elif is_integer_dtype(arg.dtype) or is_boolean_dtype(arg.dtype):
        return fake_mode.shape_env.create_unbacked_symint()

But no luck; still crashes (tried to get Long out of SymInt).

Re the latest crashes, these are all due to attempting to do control flow on data dependent values (in this case, the float in the tensor). This is expected but we can make the error messages better.

State of symbolic shapes: Jan 11 edition

Previous update: State of symbolic shapes branch - #22 by ezyang

Commit ID at time of writing: f6c7cf1bf579cc42ea7e21bd557168625648a3e9

Executive summary

Welcome back from holiday and PSC folks! The holidays were sleepy, but not that sleepy: @jbschlosser landed some nice coverage wins (+2 working models on master) and @tugsbayasgalan has been working on a number of dynamic shape fixes that show up in PyTorch Edge export use cases. Since it’s the new year, you might be wondering what you ought to be doing in the new year. This status post is to help you figure that out.

High priority things that need to be done (excluding export):

Things to be done sourced from PyTorch Edge export workstream (Meta only):

Things to be done sourced by generic export workstream (@SherlockNoMad)

Status report:

  • Model training status on master. See also Symbolic shapes work items tracker - Google Sheets
    • aot_eager: -15 (+2 MoM). These newly working models are thanks to work by @jbschlosser skip list
    • inductor: still really broken on master :rotating_light::rotating_light::rotating_light:
  • OpInfo tests on symbolic shapes.
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 516 passed (+3 MoM), 522 skipped (no change), 224 xfailed (-3 MoM)
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 291 passed (+5 MoM), 143 skipped (+1 MoM), 197 xfailed (-5 MoM)

Guard simplification

As you can see from the chatter before this post, Nuno is working on guard simplification. Nuno looked at resnet, and found that range analysis and a quadratic solver would solve most of the nontrivial constraints we were generating, greatly simplifying the guards we generate. However, there were also suggestions that we could simplify the guards at generation time.

Example constraint: Eq((s2 - 1)//16 + 1, 1). This is produced by empty_strided calling compute_contiguous on the passed-in strides. The constraint is specifically the size_d != 1 test when testing contiguity. We should avoid guarding here, a la subclass_zoo/dynamic_strides.ipynb at main · albanD/subclass_zoo · GitHub cell 18 (but we need SymBool for this!)

 T z = 1;
  // NB: make sure we do signed arithmetic
  for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
    const auto& size_d = sizes[d];
    if (size_d != 1) {
      if (strides[d] == z) {
        z *= size_d;
      } else {
        is_contiguous = false;
        break;
      }
    }
  }

Example constraint: (s2 - 1)//2 + 1 < (s2 - 1)//2**2 + 2*(s2 - 1)//2 + 1. This comes from:

File "/Users/ezyang/Dev/pytorch-cpu/torch/_prims/__init__.py", line 348, in _elementwise_meta
  strides = utils.compute_elementwise_output_strides(*args_)
 File "/Users/ezyang/Dev/pytorch-cpu/torch/_prims_common/__init__.py", line 407, in compute_elementwise_output_strides
  comparison = should_swap(perm[dim0], perm[dim1])
 File "/Users/ezyang/Dev/pytorch-cpu/torch/_prims_common/__init__.py", line 387, in should_swap
  if stride_a < stride_b:

The easiest fix is probably to make sure we don’t run the sorting algorithm when we have contiguous inputs. But even better would be to introduce a contiguity guard (which tests that a tensor is contiguous), which should be able to eliminate these guards entirely.

You can reproduce these experiments with the following code:

import torch
import torch._dynamo
import torchvision

model = torchvision.models.resnet18(weights="DEFAULT")
model = torch._dynamo.optimize("eager")(model)
model.eval()
model(torch.rand([64, 3, 7, 7]))

Something that is not so easy right now is finding out what produced guard expressions; e.g., I see x ** 2 // y blah blah, where did it come from? More detailed logging at the Dynamo per-op level would help.

What’s made it to master since last time?

ezyang

voz

jbschlosser

tugsbayasgalan

nkaretnikov

What’s coming next?

  • ezyang: catching up on code review
  • voz: changing dynamo backend api to take aot autograd directly
  • bdhirsh: on vacation until next week
  • Chillee: inductor integration

Our north star:

  • Dynamic shapes enabled by default for PT2 release
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators :rotating_light::rotating_light::rotating_light:
  • All OpInfo tests passing
  • Dynamic shapes on by default for developers / users

Mini-update: Status of unbacked SymInts on Jan 16

I recently presented our progress on unbacked SymInts, our strategy for data-dependent output sizes, in the most recent composability meeting (meeting notes: Composability meeting notes - Google Docs , Meta only recording: Redirecting...). This status post will recap what I described in the meeting, and also explain what you should expect on unbacked symints in the near future.

tl;dr I (@ezyang) will be deprioritizing proper support for unbacked SymInts, because it looks like there are fundamental infrastructure improvements in Sympy reasoning and tracing performance we need work on first. Also, unbacked SymInts are not launch blocking for dynamic shapes in PT2. Fortunately, we have identified a few short term unblocking steps that can help immediate users of unbacked SymInts make progress, albeit at the cost of some slight unsoundness (which we don’t expect to matter in practice.)

Background

In PyTorch’s tracing model, we ordinarily try to treat the shapes of input tensors symbolically. However, if we need to perform control flow on an expression involving one of these symbolic sizes, we peek at the true values to resolve the condition to true/false, and install a guard saying that our trace is only valid if the condition evaluates equivalently in the future. This is motivated by the fact that in a tracing framework, we cannot easily trace both sides of the conditional (you could use something like thermometer continuations to run the trace as if the condition evaluated true, and then rewind and rerun the trace as if the condition evaluated false, but you still have the problem of a combinatorial explosion of possible paths you could go down.)

Guarding works well for statically known sizes, but if you call an operator like torch.nonzero, it will produce a size that is only known at runtime. Our idea for how to handle this case is simple: we produce a symbolic size that has no underlying value (aka is “unbacked”), and instead error if we attempt to guard on this symbolic size.
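As a concrete, sketched example of the kind of user code that forces an unbacked SymInt, consider branching on the size of a nonzero() result. Under symbolic tracing the branch cannot be resolved by peeking at a concrete value, so (depending on configuration) tracing either reports the unsupported dynamic output shape or errors when asked to guard on the unbacked size:

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    nz = torch.nonzero(x)   # output size is only known at runtime ("unbacked")
    if nz.shape[0] > 2:     # control flow on the data-dependent size
        return nz * 2
    return nz

# expected to raise rather than trace through the data-dependent branch
make_fx(f, tracing_mode="symbolic")(torch.tensor([0.0, 1.0, 2.0, 3.0]))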

Some further reading that you may find helpful: subclass_zoo/dynamic_strides.ipynb at main · albanD/subclass_zoo · GitHub (about dynamic shapes and strides) and subclass_zoo/dynamic_shapes.ipynb at main · albanD/subclass_zoo · GitHub (about dynamic shapes in general)

Current state

Here is the state of unbacked SymInts on master:

Where exactly are these guards coming from? Here is a mostly complete accounting:

  • When we construct a contiguous tensor, we compute its strides from its sizes; this stride computation takes a max between each size and 1, which can guard on symbolic sizes:
    case MemoryFormat::Contiguous: {
      // dim_ is a virtual call, don't repeat it
      const auto dim_ = dim();
      extra_meta_->strides_.resize(dim_);
      if (dim_ > 0) {
        const auto last_idx = dim_ - 1;
        extra_meta_->strides_[last_idx] = c10::SymInt(1);
        for (auto i = last_idx - 1; i >= 0; --i) {
          extra_meta_->strides_[i] =
              extra_meta_->strides_[i + 1] * extra_meta_->sizes_[i + 1].max(1);
        }
      }
  • When we construct a TensorImpl, we have to compute whether or not it is contiguous. It turns out that we don’t short circuit this computation when torch.empty (whose output is guaranteed to be contiguous) is called. This computation branches on size and stride:
  for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
    const auto& size_d = sizes[d];
    if (size_d != 1) {
      if (strides[d] == z) {
        z *= size_d;
      } else {
        is_contiguous = false;
        break;
      }
    }
  }
  • Even if we were to short-circuit the contiguity call, we still need to compute whether or not the tensor is channels-last contiguous. This is because some contiguous tensors are ALSO channels-last contiguous, e.g., (1, 1, 1, 1). This computation branches on size and stride in a similar way:
      T expected = 1;
      for (auto& d : {1, 3, 2, 0}) {
        const auto& size_d = sizes[d];
        if (size_d != 1) {
          if (strides[d] != expected) {
            return bool_is_channels_last_contiguous(false);
          }
          expected *= size_d;
        }
      }
      return bool_is_channels_last_contiguous(true);
  • We also compute whether or not a tensor is non-overlapping and dense. The computation here is quite involved: we have to do a sort on the strides (sorting network, anyone?) But fortunately, it also can be short-circuited in the common case (since a tensor is definitely non-overlapping and dense if it is contiguous.) However, it cannot always be short-circuited; for example, if you allocate a tensor with an unbacked SymInt size and then take a view on it, the view may not be contiguous, and so you have to do the full computation in that case. This is exactly what happens in the case of boolean indexing (in the meta implementation of index.Tensor). One extra problem is that the call to nonzero here is in library code, so if you want to add asserts on the result of index, you can’t easily do so.
            if index.dtype in [torch.int8, torch.bool]:
                nonzero = index.nonzero()
                k = len(result)
                check(
                    k + index.ndim <= self.ndim,
                    lambda: f"too many indices for tensor of dimension {self.ndim}",
                    IndexError,
                )
                for j in range(index.ndim):
                    check(
                        index.shape[j] == self.shape[k + j],
                        lambda: f"The shape of the mask {index.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                        IndexError,
                    )
                    result.append(nonzero.select(1, j))

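Two quick eager-mode checks of the claims above (nothing here is specific to the symbolic-shapes work): a size-(1, 1, 1, 1) tensor satisfies both contiguity predicates at once, and boolean indexing routes through nonzero(), which is where the data-dependent size enters:

import torch

t = torch.empty(1, 1, 1, 1)
# Both predicates hold for this tensor, so we cannot skip the channels-last
# check merely because the tensor is already contiguous.
print(t.is_contiguous())                                   # True
print(t.is_contiguous(memory_format=torch.channels_last))  # True

x = torch.randn(4, 3)
mask = x[:, 0] > 0
# Boolean indexing goes through nonzero() (as in the meta code quoted above),
# so the first dimension of `out` depends on the data in `mask`.
out = x[mask]
print(out.shape)  # (mask.sum(), 3): data dependent, not derivable from input shapes
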
To support “proper” unbacked SymInts, we must modify our framework code to avoid guarding in all of these cases. We also need a mechanism by which users can insert extra runtime assertions: if a guard is unavoidable (e.g., size >= 0) but will always resolve in one direction, the assertion lets us manually guide tracing down that branch of the guard, with a runtime test verifying that execution would actually have gone down that path. A sketch of what this could look like for users follows.
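
This mechanism does not exist yet; the following is purely illustrative. runtime_assert is a hypothetical helper (name and semantics made up for this sketch), stubbed out as a plain assert so the snippet runs in eager mode:

import torch

def runtime_assert(cond):
    # Hypothetical stand-in: the real mechanism would record the condition in
    # the ShapeEnv during tracing and emit a runtime check in the program.
    assert cond

def first_nonzero_index(x):
    nz = x.nonzero()
    u = nz.size(0)  # data-dependent; an unbacked SymInt under tracing
    # We know our callers always pass a tensor with at least one nonzero
    # element, so record that fact instead of guarding on it; tracing can
    # then follow the "u >= 1" branch, and the check happens at runtime.
    runtime_assert(u >= 1)
    return nz[0]    # the bounds check here would otherwise guard on u >= 1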

Progress towards unbacked SymInts

The diff stack at Switch is_contiguous to use guardless implementation by ezyang · Pull Request #92234 · pytorch/pytorch · GitHub was my attempt to directly remove all of these guards. The infrastructure pieces (the first three PRs) were not too bad, and I intend to land them; however, they do not by themselves remove the guards. The actual guard removal has run into at least two problems:

  • When I remove the stride computation guard (max between size and 1), it makes test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 take a long time to run. This suggests that one of the guards from stride computation was essential for simplifying our size variables and prevented Sympy from doing a ton of unnecessary work.

  • When I remove the compute-contiguous / non-overlapping guards, in addition to causing problems for max_pool2d, it also makes test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32 take a long time to run.

While I could believe that the ultimate fixes for these two problems may turn out to be quite short, the way we thread the needle with our Sympy processing is quite opaque to me, and I am not confident I could discover those fixes quickly. Combined with our deadline of getting dynamic shapes turned on by default for the PT2 release, as well as the fact that we need to invest in speeding up Sympy simplification (e.g., Nuno’s suggestions), I think it makes sense for me to deprioritize making full unbacked SymInts work in the short term. If a brave soul wants to step up to try to fix these problems, I can help give advice.

The short term unblock

The Executorch folks are attempting to export models with symbolic shapes in them. If we don’t have true unbacked SymInts, what can you do? A lot, it turns out. Here are the workarounds I suggest trying, from least sophisticated (aka most unsound), to most sophisticated (but requiring more work).

  1. If you need to implement a meta function for an operator with dynamic output size, an extremely simple first step is to make the function return a fixed size (zero, or the maximum possible value, are good choices) instead of an unbacked SymInt. This is unsound for two reasons: (1) the fixed size may get burned into the graph (this is more likely to happen with zero; the maximum size is likely itself a symbolic size, so the burn-in in that case is less severe), and (2) conditionals will be resolved according to this size, which means you may miss noticing a branch that actually needs to be replaced with a cond() operator to allow both branches to be exported. However, this is really good for “getting past” a missing output meta and seeing what needs to be done next. (A minimal sketch of this workaround appears after this list.)

  2. The next step up from (1) is to return an unbacked SymInt, but allow guards to be resolved against an example size in certain contexts. The basic idea is to return an unbacked SymInt, but then “mock up” size_hint so that when we try to guard on the unbacked SymInt, we make use of an example size instead. This prevents some burn-in (we are still passing around a fresh, unbacked SymInt for tracing purposes); we only burn in conditional paths. A user can then use logging / manual inspection to see whether any conditional paths need to be treated specially. You can also restrict the mocking to certain fragments of code (e.g., while executing a tensor factory) where the example value is known to generalize to all possible values. I can help implement this, though I’m not exactly sure when to do it.
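
As a sketch of workaround (1) only, and not the real registration for any operator, here is roughly what a meta function for a nonzero-like operator might look like if it returns the maximum possible size (the function name is made up):

import torch

def meta_nonzero_like_max_size(self):
    # At most every element can be nonzero, so numel() is an upper bound on
    # the number of rows. numel() may itself be symbolic, which makes the
    # burn-in milder than returning a literal constant like zero.
    return self.new_empty((self.numel(), self.dim()), dtype=torch.long)

Downstream tracing then sees a well-defined (if overly large) shape, which is usually enough to surface the next missing piece.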

We can also implement the mechanism for adding runtime asserts in parallel, although without a better guard simplification framework (e.g., Nuno’s range analysis) it is difficult to use these asserts to actually resolve guards on unbacked SymInts.

The long term plan

To be honest, I have not committed to a particular course of action for how to solve these problems. I can think of two plausible trajectories.

The conservative route. We land my PR stack as is, and try to hotfix the particular simplification problems case-by-case. Hopefully, they aren’t too hard to resolve and there aren’t too many of them.

The blow-it-up route. We rewrite ShapeEnv’s symbolic reasoning logic from scratch, either without Sympy or using Sympy in a much more limited fashion, so that it is algorithmically obvious that we always get good performance. This would also help resolve performance issues in tracing that come from spending too much time in Sympy (according to @voz, in Bert we spend 30% of our time in Sympy simplification).

Open to comments from the peanut gallery. Honestly, this is going to depend a lot on who ends up doing this work.

2 Likes

How about learning from JAX and adding an optional static-sized variant of nonzero, so that at least some power users could get past this problem? AFAIK Tensor indexing still uses nonzero, so we could unblock those use cases if users are willing to change their code.

I raised this 18 months ago: [JIT] Support JAX-style statically shaped nonzero to avoid host-device synchronization · Issue #62320 · pytorch/pytorch · GitHub. It still seems relevant to me.
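
For reference, the JAX-style API being suggested looks roughly like this (a small illustration with jax.numpy.nonzero; the static size and padding value are chosen by the caller):

import jax.numpy as jnp

x = jnp.array([0, 1, 0, 2])
# With a static `size`, the output shape is known at trace time; if there are
# fewer nonzero elements than `size`, the result is padded with `fill_value`.
(idx,) = jnp.nonzero(x, size=3, fill_value=0)
print(idx)  # [1 3 0]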