State of symbolic shapes branch

State of symbolic-shapes branch: Sep 17 edition

The symbolic-shapes branches (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ; torchdynamo: [WIP branch] symbolic shape hacking by Chillee · Pull Request #1180 · pytorch/torchdynamo · GitHub) are long-running branches in PyTorch/torchdynamo containing a large number of features and bugfixes related to dynamic shapes support in PyTorch.

Commit IDs at time of writing: pytorch e508e5ce3adaa3464f210e26e738e53d4ec4718c; torchdynamo 3ddb46e873c2bdd1c59217a128b9b2b7af8696fe

Executive summary

We started this branch three weeks ago to move more quickly on adding dynamic shapes support to PyTorch, as getting past master CI was a bottleneck for our work. We made a lot of progress: this branch successfully runs pytorch_BERT forward/backward with the no-op AOTAutograd backend, and the forward pass is compilable by Inductor, producing a kernel that we have verified works with varying batch sizes without inducing recompilation.
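
To give a concrete feel for what "works with varying batch sizes without inducing recompilation" means from the user's side, here is a minimal sketch. It uses the torch.compile(dynamic=True) API that shipped later in PyTorch 2.0; at the time of this post the equivalent entry point was torchdynamo.optimize plus the TORCHDYNAMO_DYNAMIC_SHAPES=1 environment variable, so treat this purely as an illustration.

    # Compile once with dynamic shapes enabled, then run several batch sizes.
    # With static shapes each new batch size would trigger a recompile;
    # with dynamic shapes the same compiled kernel is reused.
    import torch

    def f(x):
        return (x * x).sum(dim=-1)

    compiled = torch.compile(f, backend="inductor", dynamic=True)

    for batch_size in (2, 8, 32):              # varying batch size
        x = torch.randn(batch_size, 128)
        print(batch_size, compiled(x).shape)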

From this work, we discovered that tracing with dynamic shapes is quite slow. Over the last week, we made a lot of progress optimizing tracing-time overhead (on devfair040, we went from 116.69s to 56.68s for E2E pytorch_BERT forwards-backwards on aot-nop); most of that overhead stemmed from generating too many FX nodes for symbolic ints.

Based on our progress and discussions with Jason Ansel, we are now pivoting to enabling dynamic shapes by default for all our benchmark models, so that we can start exercising inductor on dynamic shapes. This involves merging the symbolic-shapes branch to master and blitzing operator support for torchbench.

Current branch diff: 74 files changed, 2880 insertions(+), 506 deletions(-)

What should I expect to work?

We are currently using the following command to test end-to-end:

TORCHDYNAMO_DYNAMIC_SHAPES=1 AOT_DYNAMIC_SHAPES=1 python benchmarks/torchbench.py --only BERT_pytorch --accuracy-aot-nop --training

What’s new in the branch this week?

NB: some changes were added to the branch and then merged to master; those are listed in the merged-to-master section only.

What’s made it to master this week?

What’s coming next?

High priority:

  • Merge into master. Items to merge are annotated with comments on the symbolic-shapes PR. This includes some subtasks:
    • Get min-cut partitioner working with symbolic shapes
    • Figure out what tests are failing on the branch
  • Get fake tensors and AOTAutograd working / Integrate inductor/dynamo dynamic shape analysis “properly” (related to sharing fake tensors) https://github.com/pytorch/pytorch/pull/85233
  • Full operator coverage for all benchmark models
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators (can be done later)

Low priority:

  • Figure out why accuracy fails on E2E BERT
  • Get inductor working E2E with training on BERT
  • Get hf_BERT working (pytorch_BERT is different)
  • (Low priority atm) Get more models working

Awesome work!

I’m wondering if this project might be helpful for supporting dynamic shapes in LTC, given that many ops have been migrated to support SymInts now.

There is a decent amount of LTC dynamic shapes logic that exists, but it was mostly added to appease the build/CI; we aren’t working on actually making LTC work end to end.

State of symbolic shapes branch: Sep 25 edition

The symbolic-shapes branches (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ; torchdynamo: [WIP branch] symbolic shape hacking by Chillee · Pull Request #1180 · pytorch/torchdynamo · GitHub) are long-running branches in PyTorch/torchdynamo containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch

Commit IDs at time of writing: pytorch 538031e232635ce1cd8c8d3ec54f4da14142d4d8; torchdynamo 3ddb46e873c2bdd1c59217a128b9b2b7af8696fe (unchanged)

Executive summary

We spent a lot of time this week merging changes to master, reducing the number of inserted lines in the branch by 50%. This merging process was not entirely smooth; more details in the next section. We did not make much progress in increasing operator coverage; however, Alban Desmaison and Anjali Chourdia are onboarding onto the project, and Nick Korovaiko is pivoting to working on model enablement, so we expect the pace to pick up soon.

We are currently working on updating the runbooks for symbolic shapes. Here are the relevant docs, in various states of up-to-dateness:

Previous branch diff: 74 files changed, 2880 insertions(+), 506 deletions(-)
Current branch diff: 60 files changed, 1464 insertions(+), 389 deletions(-)

Lowlight this week: master currently has a 5x regression on trace time due to https://github.com/pytorch/pytorch/pull/85239; a portion of this change is currently reverted on the branch by Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub

Retrospective on merge to master

In this section, I want to call out unexpected complications that occurred while merging PRs to master.

  • OpInfo for Slice - adding a new OpInfo for a torch.ops OpOverload caused a lot of spurious failures in preexisting OpInfo tests, which assumed that all operators exist in the torch.X namespace. To land this PR, we modified many of the offending OpInfo tests to skip such OpInfos; contrary to our initial expectations, this was not that painful (using -k slice to filter tests to only run our new slice OpInfo helped a lot).
  • More SymFloat support - this triggered the public bindings test, but due to Test public bindings in CI gives weird output on error · Issue #83393 · pytorch/pytorch · GitHub the CI failure is difficult to understand. Furthermore, many of us compile with distributed disabled, but the test does not run in that configuration. We root-caused the public bindings failure (it’s because CI retries the test) and also fixed the test to run even when distributed is disabled.
  • Correctly handle duplicate arguments to AOTAutograd - the initial version of this patch was fine, but the suggested refactor to move it further up the stack failed due to unclear invariants about whether or not None tensor arguments are allowed in AOTAutograd. A careful refactor of AOTAutograd here should help. Note that this patch was not strictly needed for dynamic shapes, but we had to fix this underlying bug to land Setup fake tensor and symbolic shapes once at beginning of AOTAutograd
  • as_strided symbolic support - This patch uncovered two problems. First, our solution for not requiring XLA submodule updates was not actually working, because call_fallback_fn utilized the at::_ops interface, whose types eagerly update when you change operators to SymInt. We changed this function to strip SymInt (use call_fallback_fn_symint to preserve symints), and we were able to land this PR without an XLA submodule update. Second, this PR triggered a Valgrind error on old versions of clang; after a day of investigating, we decided that this was a clang9 bug and disabled Valgrind on this version of clang. However, it also indicates that we perhaps shouldn’t be passing c10::SymInt by value, which is what we currently do.
  • Ported matmul compositeimplicitautograd impl into core also turned on the Python dispatcher by default, which triggered a latent bug in linalg_lstsq that we had to fix in Fixed memory issues in linalg_lstsq. The memory corruption was difficult to diagnose; it was eventually solved by reading the code and noticing improper overwriting of a Tensor&. It would be good to audit the codebase for further instances of assigning over a mutable tensor reference.
  • Symintifying slice ops triggered Vulkan failures, as a Vulkan test was passing MIN_INT as an argument (for no good reason), and MIN_INT is not representable as a SymInt. This suggests that perhaps we should also allow extremely negative integer constants. An added confounder on the diff was that the PR author attempted to fix the problem by changing at::slice into at::native::slice, which caused more problems because the Vulkan slice implementation was bypassed.

Overall, it looks like we were able to address the root cause for most of the problems we encountered when landing PRs, which suggests that future landing should be smoother.

What should I expect to work?

The end-to-end command for BERT_pytorch is unchanged:

TORCHDYNAMO_DYNAMIC_SHAPES=1 AOT_DYNAMIC_SHAPES=1 python benchmarks/torchbench.py --only BERT_pytorch --accuracy-aot-nop --training

Note that if you are on a more recent commit of dynamo (which is OK if you’re not using inductor), the command line flags have changed. You will instead have to run:

TORCHDYNAMO_DYNAMIC_SHAPES=1 AOT_DYNAMIC_SHAPES=1 python benchmarks/torchbench.py --only BERT_pytorch --accuracy --backend aot_eager --training

We recommend using TORCH_SHOW_CPP_STACKTRACES=1 and TORCH_SHOW_CPP_STACKTRACES_WITH_LINENO=1 for more informative C++ stack traces.

What’s new in the branch this week?

As always, merged to master PRs are only listed in the master section.

What’s made it to master this week?

NB: cpp pytree was removed from the branch, without merging to master

What’s coming next?

  • Add SymInt to Scalar Add SymInt to Scalar by eellison · Pull Request #84958 · pytorch/pytorch · GitHub
  • Get testing for backwards and symbolic shapes working (assigned to @Chillee )
  • More merge to master. Subtasks from last week still apply:
    • Get min-cut partitioner working with symbolic shapes
    • Figure out what tests are failing on the branch
  • Full operator coverage for all benchmark models
    • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators (can be done later)

State of symbolic shapes branch: Oct 2 edition

The symbolic-shapes branches (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ; torchdynamo: [WIP branch] symbolic shape hacking by Chillee · Pull Request #1180 · pytorch/torchdynamo · GitHub) are long-running branches in PyTorch/torchdynamo containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch

Commit IDs at time of writing: pytorch eea837d4fa2f287c6d8633387c5fa7e7ed8dc9e9; torchdynamo 3ddb46e873c2bdd1c59217a128b9b2b7af8696fe (unchanged)

Executive summary

We made a lot of progress:

Runbooks:

Previous branch diff: 60 files changed, 1464 insertions(+), 389 deletions(-)
Current branch diff: 41 files changed, 1044 insertions(+), 448 deletions(-)

Retrospective on merge to master

What were unexpected complications when merging PRs to master?

  • Enable convolution_backward with bias and symints required adding support for another SymInt type (SymInt[]?). Fortunately, this was relatively symmetric with the other type support we’ve added, but it took Nick a day or so to do the first pieces, and there were two more subtle bits that Nick didn’t manage on his first try. First, aten/src/ATen/core/jit_type.h needed to be modified with care: whenever a SymInt-including type (or container type) is written for getTypePtr_, you must instead implement it on getMaybeFakeTypePtr_ and ensure you either recursively pass on the fake template argument, or resolve it to an int (if fake is true) or a SymInt (if fake is false). This lets us maintain the illusion for TorchScript that SymInt C++ types turn into conventional int-only schema. Second, aten/src/ATen/core/boxing/impl/make_boxed_from_unboxed_functor.h needed another case to allow an OptionalArray to be provided by an int-list IValue (covariant typing). We hope that this is the very last C++ type we need to add for SymInts (we now have SymInt, SymInt[], SymInt? and SymInt[]?, which cover all used permutations of optional and array types on SymInt; there are no uses of int?[] in native_functions.yaml).
  • Consistent compute numel/contiguous strategy with SymInts involved refactoring some code that introduced an unsigned underflow bug: previously we had the expression int64_t i = dim() - 1, where dim() returned an int64_t; this was refactored into int64_t i = sizes.size() - 1, where sizes.size() returns a size_t (an unsigned type). When the size is zero, this underflows into a large number, which then triggers signed-conversion UB (as the very large unsigned value is not representable as a signed integer). The UB appears to resolve to a negative number as desired on most platforms, but not on Android. It is not clear why UBSAN did not catch the error on conventional platforms. This suggests we need more unit testing of TensorImpl (not done yet).
  • Revert “Revert “Symintifying slice ops (#85196)”” was reverted for breaking Executorch fbcode-only builds. This is because Executorch was translating SymInt JIT type into SymInt C++ type, but the Executorch runtime does not (and should not) support SymInt. This was fixed by refactoring the codegen to offer an “int-only codegen mode” (by passing symint=False argument to relevant functions), and then having Executorch opt out of SymInt codegen.
  • Revert “Revert “Symintified mmm/addmm derivative formulas (#85794)”” was reverted because it broke Mac AOTAutograd tests. It was resolved by marking linalg.lu_solve as flaky, but it’s not clear what the root cause was.
  • removed compile cache and static argnums didn’t cause any problems on the way in, but we got complaints from users that they were relying on the cache, and so removing it broke their code, e.g., Functorch memory_efficient_fusion gives wrong output batch size · Issue #86020 · pytorch/pytorch · GitHub

Though not directly caused by the work on symbolic shapes, Edward had to spend some time this week fixing an iOS SEV caused by insufficient CI for the Metal backend: Insufficient CI for metal backend · Issue #84172 · pytorch/pytorch · GitHub This is relevant to us because we caused a similar iOS SEV by changing the signature of C++ functions; however, we subsequently made changes that don’t require us to update C++ signatures, reducing the likelihood of a repeat SEV along these lines. There is also now a PR up to test Metal in OSS CI.

What should I expect to work?

The end-to-end command for BERT_pytorch is unchanged:

TORCHDYNAMO_DYNAMIC_SHAPES=1 AOT_DYNAMIC_SHAPES=1 python benchmarks/torchbench.py --only BERT_pytorch --accuracy --backend aot_eager --training

We recommend using TORCH_SHOW_CPP_STACKTRACES=1 . The other models listed in Operators that need symbolic support - Google Sheets are expected to work, but we do not have CI at the moment, so some minor regressions may occur during development.

What’s new in the branch this week?

As always, merged to master PRs are only listed in the master section.

What’s made it to master this week?

PRs marked with X have discussion in “Retrospective on merge to master”

What’s coming next?

  • Improved testing
    • Make sure we are testing meta functions that aren’t registered to the dispatcher
    • Improve stride testing coverage
  • More merge to master.
  • Get min-cut partitioner working with symbolic shapes (we are in better shape thanks to @wconstab but there are still problems)
  • Full operator coverage for all benchmark models (as infra is nearly all landed on master, it is a good time to start also merging these operators into master)
  • Consolidated guard strategy across torchdynamo/AOTAutograd/torchinductor (@Chillee is probably going to work on this)
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators (can be done later)

State of symbolic shapes branch: Oct 7 edition

The symbolic-shapes branch (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ) is a long running branch containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch - #4 by ezyang

Commit IDs at time of writing: pytorch 06b089d271c4e9844d6aa31e93f0c4a07b044ec8

Today’s update is early because I’m going on PTO next week.

Executive summary

Technical accomplishments:

  • pytorch_BERT forward-only is working on master. We are in striking distance of pytorch_BERT training working on master, pending a consolidated guard story across dynamo-aotautograd-inductor (@Chillee) and partitioner fixes (@wconstab; see https://github.com/pytorch/pytorch/pull/86425).
  • We increased the number of torchbench models passing training with dynamic shapes enabled from 10 to 36 (out of 48). Thank you @wconstab for performing nightly runs of the entire model suite to track our progress; you can see the run results in Operators that need symbolic support - Google Sheets . We have definitely reached critical mass on operator coverage; @anjali411 has started looking at HF models.
  • OpInfo tests on the branch:
    • For test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive, we are at 297 passed, 272 failed, 35 skipped
    • For test_aotdispatch -k test_aot_autograd_symbolic_exhaustive, we are at 194 passed, 287 failed, 123 skipped (this test has more fails, as it exercises backwards as well)
  • @albanD removed the getitem monkey hack from the branch, and @wconstab merged the WIP partitioner change to the branch, meaning the branch, in principle, is fully correct (in practice, there may be bugs, as branch code is not CR’ed yet).

Procedural news:

  • With a lot of new activity on the branch, we’ve introduced a Kanban-style board Operators that need symbolic support - Google Sheets to keep track of what people are actively working on.
  • Master merges are lagging behind activity on the branch, so our diff increased (mostly because people have been adding more code to the branch).
  • It is now possible to template over functions that call tensor operations, using at::symint::op<T> (where T can be int64_t or SymInt). This makes it easier to define compatibility native::op definitions for BC.
  • We fixed OptionalSymIntArrayRef support for good; these ops should now work
  • If you want to get CC’ed on all dynamic shapes PR reviews, open a PR like this one: add myself for dynamic shapes PR review

Previous branch diff: 41 files changed, 1044 insertions(+), 448 deletions(-)
Current branch diff: 39 files changed, 1309 insertions(+), 233 deletions(-)

How to help

Runbooks:

Run a model with dynamic shapes with:

TORCHDYNAMO_DYNAMIC_SHAPES=1 AOT_DYNAMIC_SHAPES=1 python benchmarks/torchbench.py --only BERT_pytorch --accuracy --backend aot_eager --training

Many models are working, so you do not have to test only BERT_pytorch. You can also swap torchbench.py for the other benchmark suites. If your model is failing in an unusual way, try running without TORCHDYNAMO_DYNAMIC_SHAPES=1 first to see if that solves the issue.

Retrospective on merge to master

What were unexpected complications when merging PRs to master?

  • Symbolic shapes mega merge PR (Oct 3) (@ezyang, with contributions from @bdhirsh and @anjali411) - this land raced with a different enablement PR, and as a result master was broken with “unexpected successes” (some operator tests required both changes to start passing). This was fixed by quickly landing a forward fix
  • Ported reshape to symints and added a shim for BC (@Chillee) - the original version of this patch broke internal uses of native::reshape. We introduced a dedicated header for the old definition and duplicated the code in question to keep old sites working. With at::symint namespace you can now also template this code instead of duplicating it. This diff also internally broke Executorch, as Executorch defines custom out operators that must be synchronized with their internal variants. Thanks David Bort for forward fixing it.
  • A set of three PRs from different authors was unlanded, because the first PR renamed native::narrow breaking internal call sites (similar to native::reshape above), and then the other PRs had to be unlanded as they depended on the first PR to be landed. We had some discussion with the GH1 oncall @seemethere about what a good protocol for this situation should be. @albanD was appointed merge captain and will be solely responsible for merging enablement PRs.
  • A PR was incorrectly identified as breaking internal ARVR Windows builds (in fact, it was an unrelated and longstanding breakage.) @ezyang spent some time getting a Windows build with an alternate version of MSVC on this internal config (see FB only: Redirecting... ) and fixed the root cause error.

Overall, we didn’t have any really nasty merge to master problems (certainly far less than last week), and our pains are mostly from the rapid tempo at which we want to be merging things to master.

What’s new in the branch this week?

Unlike previously, I’ve opted to include all changes, even if they were merged into master, as PRs merging to master are now at coarser granularity so it is not obvious that there’s overlap.

What’s made it to master this week?

PRs marked with X have discussion in “Retrospective on merge to master”

What’s coming next


State of symbolic shapes branch: Oct 16 edition

The symbolic-shapes branch (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ) is a long running branch containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch - #5 by ezyang

Commit ID at time of writing: 95cb550231fd36e0fb0e3283c033dee384e3397b

Executive summary

  • 15 out of 48 torchbench models are passing with training on master (data collected by @wconstab); compare with 36 out of 48 on branch from last week. This means, modulo inductor, we have hit :tada: :tada: :tada: our goal for dynamic shapes (the goal was “10+ e2e (including BERT) training demonstration on PyTorch master (no recompilation with variable batch size)”)
  • @Chillee got 9 torchbench models to run E2E in inference with inductor on Monday. The dynamic shape aware timings are comparable, although in some cases slower. These numbers don’t include compilation time.
  • OpInfo tests on branch:
    • For test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive, we are at 305 passed (+8 week over week (WoW)), 267 failed (-5 WoW), 35 (unchanged) skipped
    • For test_aotdispatch -k test_aot_autograd_symbolic_exhaustive , we are at 209 passed (+15 WoW), 271 failed (-16 WoW), 127 skipped (+4 WoW). The new skips are probably nn.functional.batch_norm (0 is not tracked with proxy) and some more operators identified as having data-dependent control flow.
  • Notable bug fixes:
  • Nick Korovaiko is transitioning off dynamic shapes and moving to help with inductor burndown.

Previous branch diff: 39 files changed, 1309 insertions(+), 233 deletions(-)
Current branch diff: 30 files changed, 1209 insertions(+), 225 deletions(-)

We briefly were at <900 insertions on Monday, before reverts and more pushes to the branch brought it up again.

Retrospective on merge to master

How to run models E2E

Dynamo has merged into the PyTorch repo, so the benchmark instructions are simplified:

TORCHDYNAMO_DYNAMIC_SHAPES=1 AOT_DYNAMIC_SHAPES=1 python benchmarks/dynamo/torchbench.py --only BERT_pytorch --accuracy --backend aot_eager --training

What’s new on the branch this week?

Like last week, all changes are included even if they were merged into master

What’s made it to master this week?

Some PRs were merged by people other than their authors; the original authors are noted in parentheses

Currently open PRs

What’s coming next?

The prime directives (these haven’t really changed):

  • E2E training on master with inductor.
    • Plumb fake tensors up to torchdynamo
    • Plumb ShapeEnv guards up to torchdynamo guards
    • Resolve strategy for sharing ShapeEnv between forward and backwards (support passing in symint as input?)
  • Full operator coverage for all benchmark models on the branch
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing

Some miscellaneous tactical stuff:

  • Redundant guards involving FloorDiv are not simplifying away (seen in resnet18, discovered by @Chillee)
  • Fix PT with torchdeploy/multipy by making Python op registration work with multiple Python interpreters (@ezyang)
  • Get item() tracing working with symbolic floats for Executorch tracing (Michael Voznesensky)

State of symbolic shapes branch: Oct 22 edition

The symbolic-shapes branch (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ) is a long running branch containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch - #7 by ezyang

Commit ID at time of writing: 5f11aa560bdb406e6826e355edf19bda6174d63f

Executive summary

We focused on model enablement this week. We’re starting to hit the last mile on training, which means our pace is slowing as we spend time fixing more complicated bugs, though the pace on the branch is still faster than the rate we are merging code to master.

  • This major bug deserves a bullet on its own: we identified that torchdynamo was over-suppressing errors (even when dynamic shapes was off), artificially inflating the PASS rate of the dynamic shapes dashboard. This was fixed in [dynamo] Unify raise_on_* config to suppress_errors and raise by default by ezyang · Pull Request #87440 · pytorch/pytorch · GitHub . This update post will report training status with errors suppressed, but in subsequent update posts we will stop suppressing errors for a more accurate depiction of our status. (On the plus side, error suppression would not have affected inductor speedup numbers, so even if we were skipping blocks to compile, we are still getting good speedups.)
  • Model training status on symbolic-shapes (run by @ezyang); see also Operators that need symbolic support - Google Sheets for hand-audited results (not completely up to date atm, but actively used for task assignment). This run was done prior to fixing error suppression, so errors were suppressed.
    • torchbench: 44 out of 55 (+8 WoW)
    • huggingface: 34 out of 44 (new)
    • timm: 38 out of 62 (new)
  • Model inference status on master (run by @ezyang); these runs are all WITHOUT suppressing errors. The pass rate is artificially depressed because the PyTorch build used for these runs lacked numpy support.
    • torchbench inductor: 27 out of 46 (+18 WoW)
    • torchbench aot_eager: 35 out of 46 (+20 WoW)
    • torchbench aot_eager, without dynamic shapes (baseline): 39 out of 46 (new)
  • OpInfo tests on symbolic-shapes (@ezyang)

Previous branch diff: 30 files changed, 1209 insertions(+), 225 deletions(-)
Current branch diff: 70 files changed, 1681 insertions(+), 430 deletions(-)

Notable bug fixes

  • [HACK] Don’t clone, instead don’t clobber proxies. This is a multi-week saga. The beginning of the story is that we noticed partitioning was failing, because occasionally size computation in the forward pass depended on size expressions on backward inputs. @wconstab identified that this was because we were overwriting proxies when a SymIntNode was reused on multiple tensors. To fix this, Will introduced a clone when we set SymInts into tensors (Clone symint on set_sizes_and_strides by wconstab · Pull Request #85878 · pytorch/pytorch · GitHub), so that when we assigned proxies after a tensor was returned by an operator, we would always have fresh SymIntNodes, preventing overwriting. However, after this fix landed, we noticed that occasionally SymIntNodes would show up that didn’t have any proxies at all! Brian fixed one particular instance of this (functionalization: skip meta reference compute for aot autograd), but there were others. We still don’t know how this situation arises (the minifier didn’t work), but @ezyang is testing an alternate fix on the branch, where instead of cloning SymInts, we simply avoid overwriting proxies if one already exists. This is a hack because it means we can’t faithfully report user programs (e.g., if you write r = torch.add(x, y); s = r.size(0), this might end up reporting as s = x.size(0)), but after a few days of testing it seems to have fixed the problem and not caused any other regressions on benchmark models.
  • Fixed FakeTensor not calling CompositeImplicitAutograd decomps sometimes. @anjali411 noticed that one of her models was stuck in an infinite loop involving _to_copy decomps. She and @ngimel unsuccessfully tried to debug it. @Chillee eventually pinned down the root cause by tracing through each decomposition the loop went through, and breaking the loop with a small change to fake tensor. This spawned a discussion about our use of decomposition tables being too complicated, which @SherlockNoMad has been working to refactor.
  • fix minifier and minor improvements to minifier. In many situations, the minifier would fail to minify errors involving dynamic shapes. These fixes make more programs minify successfully; many bugs fixed this week were fixed with help from the updated minifier.
  • Added some hacks to unblock hf_reformer. This bug manifested as another infinite loop in Sympy. @Chillee put in some hacks to fix it; I don’t really understand how it worked.
  • Support symbolic as_strided_, and make unsqueeze_ correct. The bug here is pretty simple: unsqueeze_ didn’t actually modify the tensor in place. It was very difficult to diagnose without the minifier; after minification, it became clear that there were some unsqueeze_ shenanigans. @albanD helped by remembering that unsqueeze_ was implemented incorrectly on the branch. Going forward, I request people avoid committing known wrong logic to the branch.
  • Correct dtype calc when arange gets SymInt as argument. This was tricky to diagnose because arange passes fine (with an incorrect return dtype), there is a graph break, and then the program finally fails in another subgraph. The bug itself was also subtle, and we added a lint rule to catch future occurrences of it (Audit for error prone isinstance int/float and add lint)
  • properly return NotImplemented on binOp(symInt, tensor) calls. This bug was tricky to diagnose because understanding why it’s gone wrong requires knowing how reversible magic methods in Python work: Python first attempts __mul__, and if that returns NotImplemented, it silently moves on and tries __rmul__ on the other operand (a small self-contained sketch of this protocol follows this list). Previously this was accidentally working, but changes @bdhirsh made caused it to stop working.
  • Add support for torch.tensor(scalar_sym{int/float}). The short term fix involves just guarding on the contents of the SymInt, but @anjali411 and @ezyang discussed a more permanent solution, which involves introducing a new operator int_tensor(SymInt[] data) which can be used to directly propagate small amounts of tensor data through without guarding.
  • Bugfixes for metas. A lot of meta operations are incorrectly implemented because they use .to(memory_format=...), which doesn’t do what you think it does. An old issue `x.to(memory_format=torch.contiguous_format)` does not always return a contiguous tensor · Issue #62027 · pytorch/pytorch · GitHub has been revived and we are discussing what to do about this at the API level.
  • Improve argument printing. This is a dumb cosmetic problem that @ezyang eventually got fed up with and fixed. Sometimes, we would report an error like “Argument types did not match: expected tuple of ints but got tuple.” This is because the argument parser did not consistently report the inner type of the tuple/list that caused the problem. This is now fixed.
  • Convert torch.Size() argument to sym size in test_proxy_tensor. Some bugs in factory functions were not caught because, while we symintify input tensors, we don’t symintify inputs that aren’t tensors (such as explicit torch.Size arguments). @albanD worked to beef up our testing here, taking advantage of the fact that OpInfos explicitly denote Size arguments with the Size tuple subclass. However, we still need to increase testing for int arguments that aren’t Size.
  • Add inplace function testing to test_proxy_tensor. We didn’t have any testing for inplace operators at all; in fact, OpInfo inplace testing is very poorly exercised. @albanD added more coverage here.
  • [discussion] fix for aot autograd outputs that dont require grad. This is a kind of embarrassing bug in AOTAutograd where we just didn’t think hard enough about what to do in various edge cases involving requires_grad. @soumith suggested that we should do a more careful audit of AOTAutograd for other edge cases along these lines.
  • symintify nll loss fns (#86915) by anjali411 · Pull Request #87095 · pytorch/pytorch · GitHub uncovered a bug in our default argument handling (specifically, we didn’t handle it at all for SymIntList). It was difficult to diagnose because we had uninitialized memory that would unpredictably fail asserts after using the default. @anjali411 has a follow up to make it explicitly error for any unrecognized argument types so that this doesn’t occur in the future.
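
To illustrate the reversible magic method protocol referenced in the binOp(symInt, tensor) item above, here is a small self-contained sketch. SymIntish and Tensorish are made-up stand-ins, not real PyTorch classes; the point is only to show that returning NotImplemented is what lets Python fall back to the other operand's __rmul__.

    class SymIntish:
        def __init__(self, v):
            self.v = v
        def __mul__(self, other):
            if not isinstance(other, (int, SymIntish)):
                return NotImplemented          # tells Python to try other.__rmul__(self)
            ov = other.v if isinstance(other, SymIntish) else other
            return SymIntish(self.v * ov)

    class Tensorish:
        def __init__(self, data):
            self.data = data
        def __rmul__(self, other):             # called after the left operand "gives up"
            return Tensorish([other.v * d for d in self.data])

    out = SymIntish(3) * Tensorish([1, 2])     # dispatches to Tensorish.__rmul__
    print(out.data)                            # [3, 6]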

Merge to master retrospective

We had very few reverts this week (hooray!)

What’s new on the branch this week?

What’s made it to master this week?

What’s coming next?

  • E2E training on master with inductor.
  • All benchmark models are passing aot_eager training on branch; tracked at Operators that need symbolic support - Google Sheets
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing

We also had some discussions about “unbacked” symbolic integers (e.g., as produced by item()). @Lezcano has a proposal for how to organize decompositions Proposal for a property-based tag system for prims, refs, and decompositions - Google Docs
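
For context on why item() needs special treatment: its result is a Python scalar whose value depends on tensor data, not on input shapes, so under symbolic tracing there is nothing to back it with. A minimal illustration of the kind of program that forces the issue (eager code, shown only to make the problem concrete):

    import torch

    def f(x):
        n = int(x.sum().item())   # data-dependent scalar: unknown at trace time
        return torch.zeros(n)     # the output *shape* depends on tensor *data*

    print(f(torch.ones(3)).shape)  # torch.Size([3]) in eager mode
    # Under tracing, n would have to become an "unbacked" symbolic integer
    # (or force a guard on data / a graph break), which is what the discussion
    # above is about.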


State of symbolic shapes branch: Oct 30 edition

The symbolic-shapes branch (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ) is a long running branch containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch - #8 by ezyang

Commit ID at time of writing: 121e8ebcc2fc50e5ca28cfb3ad437596084424e1

Executive summary

Voz enabled propagation of symbolic shapes in dynamo (previously, symbolic shapes were only propagated in AOTAutograd), and this uncovered a large number of previously undiscovered bugs and coverage problems in TORCHDYNAMO_DYNAMIC_SHAPES=1 itself. The team plans to pivot to working on the torchdynamo codebase to help fix these problems.

  • @SherlockNoMad found in Meta OpInfo Test for stride correctness by SherlockNoMad · Pull Request #87849 · pytorch/pytorch · GitHub that ~70 aten ops have mismatched stride values between their meta function and eager implementation. Usually, this is the result of an incomplete Python meta function or decomposition (our unit tests were not asserting on stride correctness, and we didn’t have enough test cases for strided inputs). Sherlock compiled all the failures into this tracker sheet Stride Mismatch Tracker - Google Sheets . Bugs on this sheet are open season for burn down.
  • Model training status on symbolic-shapes. (@ezyang) See also Operators that need symbolic support - Google Sheets (out of date).
    • aot_eager, no TDS: 149 out of 165 (new) - logs
    • aot_eager, with TDS: 118 out of 163 (+2 WoW; heavily depressed due to new bugs discovered in torchdynamo) - logs
    • inductor, with TDS: 45 out of 163 (new) - logs
  • Model inference status on symbolic-shapes. (@ezyang)
    • inductor, with TDS: 69 out of 177 (new) - logs
  • Model inference status on master. (@ezyang)
    • This week is really bad, as we haven’t yet picked up all the branch fixes after the dynamo symbolic shapes fallout. With TORCHDYNAMO_DYNAMIC_SHAPES=1: 4 out of 177 (-23 WoW) - logs
  • OpInfo tests on symbolic-shapes. (@ezyang)
    • pytest -v test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 347 passed (+8 WoW), 374 failed (-7 WoW), 497 skipped (+1 WoW)
    • pytest -v test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 238 passed (+19 WoW), 246 failed (-18 WoW), 125 skipped (+0 WoW)

Previous branch diff: 70 files changed, 1681 insertions(+), 430 deletions(-)
Current branch diff: 110 files changed, 2788 insertions(+), 2114 deletions(-)

Notable bug fixes

  • Dynamo’s dynamic shapes support is seriously buggy, and it doesn’t seem like there is a clear enough conceptual framework for how things should be implemented to make bug fixes simple. For example, voz in Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub needed to write a page of code just to get dynamic size(i) method calls working. We intend to have a knowledge sharing session with Voz on Tuesday to get the team up to speed on how to approach dynamo bugs.
  • We’re still fixing incorrect stride bugs. Delete incorrect and duplicate meta_add_ in particular was a symbolic-shapes branch only howler. It is still quite difficult to diagnose these without a minifier; ezyang proposed we have a mode where aot_eager checks the real tensors have metadata consistent with the trace, but this still isn’t implemented yet. Sherlock has a workstream for fixing strides based on the test suite: Stride Mismatch Tracker - Google Sheets There is still a live problem with softmax: minimal repro gist:54f03e02fd36069bf9693ae2ab707d10 · GitHub
  • In parallel to Sherlock’s burn down of stride bugs, @ezyang is investigating whether or not we can make incorrect stride bugs less severe by preserving reshapes during tracing. This is tricky to do, because we still have to do autograd, and autograd doesn’t directly support reshape (it sometimes returns a view and sometimes returns a fresh tensor; see the sketch after this list). Our plan is to transform reshape into an always-copying operation, and to add enough extra sanity checking in functionalization to detect if this is not semantics preserving. To do this, we need to run functionalization and autograd at the same time; thus functionalize and compute joint simultaneously. Implementing this was a doozy, as it triggered multiple functionalization bugs.
  • We made a LOT of quality of life improvements to the branch this week. A lot of it was simply dogfooding the software and making adjustments when we noticed things could be improved. These range from paper cuts (spammy warnings, overly verbose exceptions) to important debugging tools (printing more program state on failure, e.g., the incomplete make_fx traced graph or the strides of variables in a graph) to important unblocking features (adding a timeout so that sympy infinite loops don’t block model sweeps).
  • Use functionalize in symbolic exhaustive tests, not sound otherwise is a howler; someone intentionally turned off functionalization on the test suite, so many tests were failing because AOTAutograd without functionalization isn’t actually sound. Fixing this helped a number of test cases pass. The moral of the story is, don’t give users configuration knobs that are known unsound, unless you make it really obvious (or at least, obvious enough that a core dev doesn’t turn that knob on by default in tests, and another core dev approves it in CR.)
  • Fix bernoulli functionalization was an annoying typofix in the bernoulli implementation for a longstanding problem (long enough that XLA had manually worked around it.) It had to be diagnosed manually (by looking at the graphs and noticing some operators were missing from DCE); it could have been immediately caught by testing for no mutating ops after functionalization, but the PR that implemented this is still not landed aot_autograd: add assert for functional-only graph by bdhirsh · Pull Request #85681 · pytorch/pytorch · GitHub (I must emphasize how important it is to actually land the PRs you write!)
  • A runtime error complaining about shape mismatch in backwards turned out to be a functionalization bug, where we were not copying enough metadata when doing a shallow copy and detach of FunctionalTensorWrapper (which happens when a variable is saved for backwards.) Both bdhirsh and ezyang root caused the problem at about the same time. Fixed in functionalization: fix detach() and tested by Saved tensor that is view test case functionalization
  • This is not a bug fix per se, but an infrastructure improvement: basic bin_op(symint, symint) support was initially done in a brute force way by adding each overload for mixed operations one-by-one, but on subsequent code review we didn’t like it. Unify SymIntNode and SymFloatNode into SymNode instead removed the static types from C++, which meant you don’t have to add overloads. This produced a patch that was functionally equivalent but a net decrease in total LOC. https://twitter.com/ezyang/status/1585373693852934144
  • Prevent use of as_strided in backwards was discovered when inspecting some backward graphs by hand and noticing they had a lot of as_strided calls in them. For reference, typical user programs don’t really ever use as_strided, and neither do most of our autograd derivative formulas, so it is weird that they were showing up. In fact, this is due to how our view mutation logic works, which by default converts all views into as_strided operations so you don’t have to track the provenance of any single view. This is a nice performance optimization for eager, but it generates code that’s worse for compilers; in fact, XLA already had a way of disabling this shortcut and maintaining the views by hand. Turning this on for PT2 eliminates these as_strided calls. BTW, as_strided calls are preferably avoided in IR as they are extremely input-stride sensitive; if the incoming tensor has a different stride, as_strided will silently corrupt your data. This would be less of a problem with a “compositional” variant of as_strided that respects input strides (e.g., if a dim already had stride 2, restriding it by 2 would result in stride 2 * 2 = 4).
  • We have a lot of sympy infinite loops. Fix another sympy infinite loop whackamoles one of them, but there are still more. @Chillee we need to figure out a more sustainable strategy for these.
  • Disable aot autograd cache is a nice stopgap: some of our models are now failing because they do exercise dynamic shapes, but our guards are insufficient (because AOTAutograd guards aren’t propagated up to torchdynamo yet.) Disabling the AOTAutograd cache means we always recompile at AOTAutograd until this can be fixed.
  • Unify meta tensor and fake tensor converter conversion broke a number of inductor model runs on master (in the non-dynamic-shapes configuration). I was able to fix this by disabling the meta tensor converter’s view preservation (its logic for remaking a tensor into a view if it was originally a view). But in principle, it should be OK to do this. This is very perplexing. Some of the current investigation notes are at AOTAutograd has soundness bug when passing views as inputs · Issue #1815 · pytorch/torchdynamo · GitHub
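
As background for the reshape item above, here is a small demonstration of why reshape is awkward for autograd and functionalization: depending on the input’s strides, it either aliases its input or silently copies.

    import torch

    a = torch.randn(2, 3)
    v = a.reshape(3, 2)          # contiguous input: reshape returns a view
    print(v._base is a)          # True  -- v aliases a's storage

    b = a.t().reshape(6)         # non-contiguous input: reshape must copy
    print(b._base is a)          # False -- b is a fresh tensor

    # The plan described above is to trace reshape as an always-copying op and rely
    # on extra checks in functionalization to confirm this is semantics preserving.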

What’s new on the branch this week?

This time, instead of doing commits chronologically, I’ve tried grouping them by theme.

Meta support

SymInt support

Quality of life

Dynamo

Functionalization

Infrastructure

Merge to master retrospective

  • Many symintifications was reverted because it broke an internal Executorch build, due to an Executorch-only YAML file defining an out variant of an operator (Redirecting...). The fbcode change ended up being trivial, so we relanded the PR by landing it via GH1 and then ninja’ing the fbcode fix after it was imported in the diff train. @albanD was initially concerned about merge conflicts between fbcode master and GH master because this was a large patch, but this strategy neatly sidestepped the problem (the biggest annoyance being ensuring the diff train was sufficiently imported to import this diff).
  • Fix bernoulli functionalization required non-trivial XLA changes, but at time of writing @ezyang wasn’t able to compile XLA successfully. Fortunately, JackCaoG helped do the XLA-side patch (which was relatively long, but not too difficult; just undoing XLA’s bernoulli hack).
  • Unify meta tensor and fake tensor converter conversion was reverted on master because inductor tests were not run on PR. This will be fixed by enabling inductor CI on all fake tensor changes.

What’s made it to master this week?

What’s coming next?

  • Educate the team on how torchdynamo dynamic shapes is supposed to work, and spend a lot of time fixing issues here
  • E2E training on master with inductor.
  • Reshape unification
  • All benchmark models are passing aot_eager training on branch; tracked at Operators that need symbolic support - Google Sheets
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing

“Operators that need symbolic support” doesn’t seem to have operators listed?

Yeah the name is wrong now lol. I’ll rename the sheet.

Will “Educate the team on how torchdynamo dynamic shapes is supposed to work” be a public cast? Would love to watch it, if possible

Unfortunately we recorded it internally, so we can’t share it. I’ll try to write a post that summarizes what we discussed during the session.


Which version is expected to receive symbolic shapes?

The current plan is that we will have a convincing alpha in the next release (we have a lot of stuff working in the branch, but it needs to be merged to master and we’re not bug free enough to turn it on by default), and hopefully in two releases it will be on by default.


As promised: On the architecture of torchdynamo - Google Docs

State of symbolic shapes branch: Nov 5 edition

The symbolic-shapes branch (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ) is a long running branch containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch - #9 by ezyang

Commit ID at time of writing: 1f5fac1d10df2e4a054740abc92bcf9d6a6553eb

Executive summary

This week was relatively light on both commits to the branch and merges to master; the bulk of our time was spent on late-breaking infrastructure problems (including addressing problems that affect the non-dynamic-shapes path as well) and onboarding onto Dynamo. On the plus side, we have solutions to a number of longstanding problems in the overall compilation stack, and we managed to claw back a bit of aot_eager TDS model coverage (+19 passing models) by fixing Dynamo bugs.

  • Milestone: the Edge team imported the symbolic-shapes branch into fbcode, and used it to successfully export one of their models which makes use of symbolic shapes. This is without any specific work to support their use case, and is a nice external validation of the work we’re doing. (Also, @suo would like to know when we will finish merging everything to master thankyouverymuch.)
  • New design: AOTAutograd 2.0 - Google Docs - This design doc describes how to make AOTAutograd work with graphs that have input mutation (see the sketch after this list for what input mutation looks like). This is a release-blocking issue: right now instance norm and batch norm are miscompiled by functionalization because their running stats updates are being removed ([PrimTorch] Functionalization pass removes Instance Norm / Batch Norm running stats transformations · Issue #88375 · pytorch/pytorch · GitHub), and some models are failing to get optimized because they have input mutation. @bdhirsh is working on the input mutation piece of the story.
  • New design: Stride agnostic IR (reshape preservation) - Google Docs - This design doc describes how to make our compiler stack less sensitive to stride propagation bugs, by observing the fact that accurate strides are specifically needed only by view/reshape in most user models, and we can remove this dependence by introducing infallible view() and input/output freezing reshape.
  • The team spent time onboarding onto Dynamo this week. There was a nice session led by Voz at Redirecting... (FB only), and a public writeup about what we learned from Voz and @jansel at On the architecture of torchdynamo - Google Docs @bdhirsh and anjali411 were able to make their first contributions to Dynamo this week, yay!
  • Preserve reshapes in AOTAutograd by ezyang · Pull Request #88361 · pytorch/pytorch · GitHub is the first example of overriding a pre-autograd decomposition with a user defined decomposition. You may find the pattern useful in other contexts.
  • We have more automation for running sweeps and updating the spreadsheet, so the spreadsheets are now up-to-date and tracking both models and tests.
  • Model training status on symbolic-shapes. (@ezyang) See also Symbolic shapes work items tracker - Google Sheets (up to date!)
    • aot_eager, with TDS: 137 out of 163 (+19 WoW) logs
    • inductor, with TDS: 8 out of 163 (-37??? WoW; I looked at the old logs and I’m not sure how I got 45 passes last week; the same models that were passing last week are passing this week, so maybe there was a setup problem) logs
    • Lowlight: we have two new accuracy failures on our models (cait_m36_384, xcit_large_24_p8_224). This is bad, because it means we do not have enough asserts to catch when we are doing incorrect things. These models should be investigated.
    • Lowlight: a number of models are failing due to sympy timeouts. We need to figure out a strategy for solving this once and for all. @Chillee has suggested that we may want to try rewriting sympy.solve for our use case.
  • OpInfo tests on symbolic-shapes. (@ezyang)
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 350 passed (+3 WoW), 370 failed (-4 WoW), 499 skipped (+2 WoW) logs
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 240 passed (+2 WoW), 228 failed (-18 WoW), 127 skipped (+2 WoW) logs
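
To make the input-mutation problem concrete (referenced in the AOTAutograd 2.0 item above), here is a hypothetical miniature of what batch norm’s running stats do. A purely functional trace of f computes a new value for running_mean, but unless the compiled artifact also writes that value back into the caller’s tensor, the side effect is lost, which is exactly the miscompile described above.

    import torch

    def f(x, running_mean):
        # Mutates one of its inputs, like batch norm updating running statistics.
        running_mean.mul_(0.9).add_(x.mean(), alpha=0.1)
        return x - running_mean

    # Functionalized, this becomes roughly:
    def f_functional(x, running_mean):
        new_running_mean = running_mean * 0.9 + x.mean() * 0.1
        out = x - new_running_mean
        # AOTAutograd must additionally copy new_running_mean back into the caller's
        # running_mean tensor (the "input mutation" handling being designed above);
        # otherwise the update is silently dropped.
        return out, new_running_mean

    x, rm = torch.randn(4, 8), torch.zeros(8)
    f(x, rm)                        # rm is updated in place in eager mode
    print(bool(rm.abs().sum() > 0)) # True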

Previous branch diff: 110 files changed, 2788 insertions(+), 2114 deletions(-)
Current branch diff: 76 files changed, 2341 insertions(+), 533 deletions(-)

Notable bugs

What’s new on the branch this week?

SymInt support

Dynamo

Functionalization

Infrastructure

Quality of life

Merge to master retrospective

  • Do not use unsafe restriding for subclasses was reverted because it broke some internal fbcode tests.
  • Revert “Revert “Put Python Dispatcher cache in dict, clear it on new registrations. (#88329)”” was reverted for making test times balloon by 2hrs! The root cause of the problem was a refactor that reassigned the variable used as the cache key before caching (key = resolve_key(key); cache(key)), causing the cache to never get hit and massively slowing down test runtime (a minimal reconstruction is sketched after this list). @ezyang figured out the problem by guessing it was a cache problem and then reproducing it.
  • Reland 2 Many symintifications (#87604) was successfully landed to fbcode, but it turns out it actually broke static runtime. This is because tensor_split only had one overload ported, and the IntArrayRef overload was actually accepting int64_t arguments, causing call sites that intended to go to the other overload to go to the wrong one. This was point-fixed by porting the other overload to have an explicit signature. It’s not clear there’s a structural fix for this problem; please be aware of it when adding BC native:: signatures for operators with multiple overloads.
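
A minimal, hypothetical reconstruction of the cache bug described above (resolve_key and compute are made-up stand-ins for the real dispatcher logic): because the key variable is reassigned before the store, results are cached under a different key than the one future calls look up, so the cache never hits.

    cache = {}

    def resolve_key(key):            # stand-in for the real key-resolution step
        return ("resolved", key)

    def compute(key):                # stand-in for the expensive work being cached
        return key

    def dispatch_buggy(key):
        if key in cache:
            return cache[key]
        key = resolve_key(key)       # reassigns `key`...
        cache[key] = compute(key)    # ...so the result is stored under the resolved key
        return cache[key]            # future calls look up the unresolved key: always a miss

    def dispatch_fixed(key):
        if key in cache:
            return cache[key]
        resolved = resolve_key(key)  # keep the original key for the cache entry
        cache[key] = compute(resolved)
        return cache[key]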

What’s made it to master this week?

What’s coming next?

  • Brian, you still have a lot of open PRs lol, please keep landing them
  • Voz, we need a plan for how we are going to land Dynamo fixes to master. Need to discuss with Alban if merge captain should also be doing Dynamo fixes; the main difficulty is they need proper review. (@jansel spot checked some of the changes on the branches and found a number of suggestions, so we will need to be rigorous about this.)
  • Fix input mutations for AOTAutograd (bdhirsh)
  • Stride agnostic IR (ezyang)
  • E2E training on master with inductor.
  • All benchmark models are passing aot_eager training on branch; tracked at Operators that need symbolic support - Google Sheets
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing

State of symbolic shapes branch: Nov 12 edition

The symbolic-shapes branch (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ) is a long running branch containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch - #16 by ezyang

Commit ID at time of writing: 807a62fc61bea26707c3dc09a12bad204e375a95

Executive summary

This was a chaotic week. Meta had layoffs for the first time in its history (this workstream was not directly affected.) We still made progress on this workstream (merge to master, inductor support, operator coverage), but we also discovered more work to do (more dynamo bugs, dynamic shape guard problems, more accuracy failures). Some big work items (dynamo merge to master, input mutation, copy-on-write tensors) have progressed, but are still short of actually landing to master (or even to the branch, as the case may be). Merge to master is also slow going as we often have to first add OpInfo support before we can merge our changes.

  • Staffing. Nikita Karetnikov from Quansight is pivoting from general PrimTorch work to working on decompositions / meta implementations that the dynamic shapes workstream can specifically benefit from. Welcome Nikita! In other news, Edward is on jury duty next week.
  • Design. We have shifted the “stride agnostic IR” concept to a more expansive “stride agnostic PyTorch” concept, where we make eager mode PyTorch as a whole less sensitive to stride changes. This includes a new design for Copy-on-write tensors for eager PyTorch - Google Docs which aims to eventually make the BC-breaking change to reshape()/contiguous()/etc. to have these functions always return contiguous tensors. A PoC PR for the entire design exists Copy on write reshape by ezyang · Pull Request #88774 · pytorch/pytorch · GitHub and fully passes non-trunk CI, but there are some unresolved questions, such as whether or not to more deeply integrate data pointer reference counting into Storage to reduce the overall level of indirection, and whether the proposed warning strategy is too loud. This pair of proposals was discussed in the most recent Composability meeting; there were no major objections, but also a desire to better understand the implications of the change.
  • Make silent errors noisy. A big reason why our aot_eager pass rate regressed this week is that we turned on more stringent error checking in the branch, to try to transform potential bugs into assertion failures. This week, we: (1) assert that sizes/strides of intermediate tensors are consistent between fake and real tensors, (2) assert a functional-only graph after lowering (this turns the batch norm problem we observed last week into a hard error; to bypass some of these errors, we disabled AOTAutograd from running on subgraphs with BatchNorm; a sketch of this kind of check follows this list), (3) assert that all guards correspond to tensors dynamo knows about (this flushed out a problem with symbolic shapes guards, where dynamo was not tracking enough guards; we fixed one of the problems, but that didn’t hit all of the issues, so we also have a way of suppressing these failures). Unfortunately, while these changes did nail some accuracy problems, we still have new accuracy failures on the latest model runs.
  • Inductor has problems. The branch now has some quick and dirty hacks which substantially improved the inductor pass rate (+16 working models), but there are still many bugs causing lots of models to fail for similar reasons. The conclusion is that some fundamental pieces of the dynamic shapes inductor integration don’t exist yet (though @Chillee still assures me they’re easy to do). On the bright side, the uniformity of inductor errors means there probably aren’t that many distinct bugs to fix.
  • Model training status on symbolic-shapes. (@ezyang) See also Symbolic shapes work items tracker - Google Sheets
    • aot_eager, with TDS: 135 out of 163 (-2 WoW) logs csv Note that this run skips all subgraphs with batch norm and suppresses dynamo guard asserts (in the spreadsheet, this is noted as BN+IA)
    • inductor, with TDS: 24 out of 163 (+16 WoW) (logs too long) csv
    • Lowlight: jx_nest_base and twins_pcpvt_base are failing with accuracy errors. Interestingly, this pair of models previously was failing with accuracy errors without dynamic shapes; both were fixed by a single PR https://github.com/pytorch/pytorch/pull/85417 . jx_nest_base is minifiable, although the minifier failed with a reduction error on int when I tried running it (I did not try very hard). twins_pcpvt_base was passing on 10/28 and regressed into a timeout in 11/2 (this is after voz’s major dynamo change hit master); jx_nest_base has never passed.
    • Highlight: the cait_m36_384 and mobilenet_v2_quantized_qat accuracy failures turned into non-accuracy failures after we added DebugInterpreter to aot_eager. cait_m36_384 is now passing; no one has had a chance to investigate mobilenet_v2_quantized_qat.
  • OpInfo tests on symbolic-shapes. (@ezyang)
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 388 passed (+33 WoW), 334 failed (-36 WoW), 501 skipped (+2 WoW) logs csv
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 255 passed (+15 WoW), 213 failed (-15 WoW), 129 skipped (+2 WoW) logs csv
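The size/stride consistency checking mentioned in item (1) above is conceptually just an FX interpreter that re-runs the graph on real tensors and cross-checks each node’s output against the metadata recorded during fake tensor tracing. Here is a minimal sketch of that idea (not the actual DebugInterpreter implementation; it assumes concrete, non-symbolic metadata in node.meta["val"], whereas the real checker also has to substitute hints for SymInt-valued sizes/strides):

```python
import torch
import torch.fx

class SizeStrideChecker(torch.fx.Interpreter):
    """Re-run a traced GraphModule on real inputs and verify that each node's
    real output sizes/strides match the fake-tensor metadata in node.meta["val"]."""

    def run_node(self, n: torch.fx.Node):
        result = super().run_node(n)
        expected = n.meta.get("val")
        if isinstance(expected, torch.Tensor) and isinstance(result, torch.Tensor):
            assert tuple(result.shape) == tuple(expected.shape), (
                f"{n.format_node()}: size mismatch "
                f"{tuple(result.shape)} vs {tuple(expected.shape)}"
            )
            assert result.stride() == expected.stride(), (
                f"{n.format_node()}: stride mismatch "
                f"{result.stride()} vs {expected.stride()}"
            )
        return result

# Usage: SizeStrideChecker(gm).run(*real_inputs)
```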

Previous branch diff: 76 files changed, 2341 insertions(+), 533 deletions(-)
Current branch diff: 82 files changed, 2236 insertions(+), 747 deletions(-)

Notable bugs

  • Fix buggy unsqueeze_ implementation. I found this error because the unsqueeze OpInfo test was failing (hooray unit tests). I didn’t actually debug the bug directly; I just replaced the code wholesale with a straight port of the C++ code (I think the bug was in how dimension wrapping was implemented, though! See the dim-wrapping sketch after this list.)
  • Fix stride on softmax_backward_data, fixes cait_m36_384 was found via the DebugInterpreter. The interesting thing about this fix was that I had actually made a repro last week, but no one had picked it up and fixed it, and when I went to look at it again the repro no longer actually failed. Fortunately, DebugInterpreter confirmed that it was indeed a problem with softmax_backward_data.
  • More correct fix for upsample_bilinear2d decompositions is a fixup of a previous PR, where I attempted to register a composite implementation for upsample_bilinear2d.default in the Python dispatcher. I initially tried CompositeImplicitAutograd; this did not work, because this operator has an explicit autograd formula (Autograd key registration), and Autograd is higher precedence than CompositeImplicitAutograd. This is easy to work around if you know the correct semantics, but you might expect Python registrations to “override” their C++ variants; we should look into potential API changes to remove this footgun.
  • Suppress guards when constructing fake tensors was discovered by chasing down an assert failure from Dynamo when it needed to construct a shape guard involving a symbolic variable that it didn’t know the source of. The problem is that we do non-trivial tensor operations to make, e.g., view fake tensors actually views, but this means the base tensor gets its own set of fresh symbolic variables that dynamo doesn’t currently track. Fixing this in Dynamo is a fairly involved refactor, especially because we’d like some design that makes it hard for Dynamo to forget to track tensors (e.g., tie it to fake tensor conversion). In the meantime, we added a config driven via TORCHDYNAMO_IGNORE_ASSERT to allow dynamo to suppress these errors for now.
  • call call_method instead of _call_operator_builtin - operator calls in Dynamo don’t properly handle NotImplemented. Brian/Anjali tried to patch over this, but the bugfix was buggy, so Voz yanked it out of the branch again. This problem still needs to be solved properly.
  • Some infra/QoL commits were about making it easier to debug errors when things fail. For example, if exec fails to compile code, print the code that failed to compile (see the sketch after this list). Some design choices also produce unreadable debug output; for example, Change how sympy guard formatting works now prints symbolic shape guards with newlines instead of mashing them into one giant unholy expression.
  • Fix call_range and remove weird idiom of fallthorugh for dynamic in f…. This failed in a very strange way, and even after Edward identified what the problem was, he couldn’t figure out how to fix it (Voz fixed it later.) It would be good to have a better understanding of what the Right™ way to fix these issues in Dynamo is.
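For reference, here is the dimension-wrapping rule for unsqueeze written out as a standalone sketch (an illustration of the rule, not the actual ported code): unsqueeze accepts dims in [-(ndim + 1), ndim] because the insertion position is relative to the output rank, which is an easy place to be off by one.

```python
def wrap_unsqueeze_dim(dim: int, ndim: int) -> int:
    # Valid range is [-(ndim + 1), ndim]: the new dimension can be inserted at
    # any position of the *output* tensor, which has ndim + 1 dimensions.
    if dim < -(ndim + 1) or dim > ndim:
        raise IndexError(
            f"Dimension out of range (expected to be in range of "
            f"[{-(ndim + 1)}, {ndim}], but got {dim})"
        )
    # Negative dims wrap around the output rank, not the input rank.
    return dim + ndim + 1 if dim < 0 else dim

assert wrap_unsqueeze_dim(-1, ndim=2) == 2   # append a trailing dim
assert wrap_unsqueeze_dim(0, ndim=2) == 0    # prepend a leading dim
```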
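And a hypothetical illustration of the “print the code that failed to compile” quality-of-life change mentioned above: the pattern is simply to catch the compile failure and dump the generated source before re-raising, rather than surfacing an opaque SyntaxError from deep inside the stack.

```python
def exec_generated_code(src: str, globals_dict: dict) -> None:
    # Sketch of the debugging aid, not the actual Dynamo code.
    try:
        code = compile(src, "<generated code>", "exec")
    except SyntaxError:
        # Show the offending generated source so the failure is debuggable.
        print("Failed to compile generated code:\n" + src)
        raise
    exec(code, globals_dict)
```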

What’s new on the branch this week?

Meta/decomp support

SymInt support

Infrastructure

Quality of life

Dynamo

Inductor

Merge to master retrospective

  • Meta registrations, Nov 7 edition, part 1 had to be split into small constituent PRs, because CI on the mega PR was failing on a fixed set of seemingly unrelated tests, even after rebase. However, the problem evaporated when everything was landed individually, so it’s not really clear what the moral of the story here is.
  • reland “fix as_strided_scatter_backward (#87646)” was reverted because it broke some trunk jobs. It looks like this was resolved by adding more xfails/skips: reland "fix as_strided_scatter_backward (#87646)" by bdhirsh · Pull Request #88342 · pytorch/pytorch · GitHub
  • Some meta function merges are slow because OpInfo coverage is insufficient. There are two common themes. First, sometimes there is an OpInfo for a given operation, but the OpInfo covers many different overloads/dim specializations of the function, and we only implemented one overload in a PR; to test it easily, we have to extract one particular overload from the OpInfo, or go ahead and implement all of the overloads so the mega OpInfo works. (Arguably, this is a bad decision in OpInfo design.) Second, many missing OpInfos relate to backward functions. Hypothetically, these can be tested indirectly via the forward-backward call that test_aotdispatch.py performs, but in practice it seems to be easier to just add the backward OpInfo. (For context, a minimal sketch of the kind of meta kernel these PRs add appears after this list.)
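For readers unfamiliar with what these meta registration PRs contain, here is a minimal, hypothetical sketch using the torch.library.Library API; the demo namespace, the pad_to op, and its schema are all made up for illustration and are not ops from the branch. The point is that a meta kernel only computes output metadata (shape, dtype, device) without touching data, which is exactly what the OpInfo-driven tests above exercise.

```python
import torch
from torch.library import Library

lib = Library("demo", "DEF")  # hypothetical namespace for illustration
lib.define("pad_to(Tensor x, int size) -> Tensor")

def pad_to_cpu(x, size):
    # Real kernel: flatten x and zero-pad it to the requested length.
    out = x.new_zeros(size)
    out[: x.numel()] = x.reshape(-1)
    return out

def pad_to_meta(x, size):
    # Meta kernel: only produce correctly-shaped output metadata.
    return x.new_empty(size)

lib.impl("pad_to", pad_to_cpu, "CPU")
lib.impl("pad_to", pad_to_meta, "Meta")

print(torch.ops.demo.pad_to(torch.randn(3, device="meta"), 8).shape)  # torch.Size([8])
```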

What’s made it to master this week?

What’s coming next

By person:

  • bdhirsh: fix input mutation for AOTAutograd (same as last week, progress at first draft of input mutation handling for aot autograd by bdhirsh · Pull Request #88817 · pytorch/pytorch · GitHub )
  • voz: merging the first round of dynamo patches to master; afterwards, aot autograd cache, restore builtins support in dynamo, fix symbolic shape guard construction in dynamo, refactor dynamo storage of dynamic shapes
  • ezyang: run some more experiments on copy-on-write
  • Chillee: fix sympy infinite loops, get inductor training working with dynamic shapes
  • anjali411: land more branch changes to master, run sweeps this week
  • nkaretnikov: continue working on test_proxy_tensor coverage

Our north star:

  • All benchmark models are passing aot_eager and inductor training on branch
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing
  • Dynamic shapes on by default for developers / users

State of symbolic shapes branch: Nov 20 edition

The symbolic-shapes branch (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ) is a long running branch containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch - #18 by ezyang

Commit ID at time of writing: 41c314473272b2622e20c261f19549b4bd3f1d8f

Executive summary

This week, we made two major merges into the branch: (1) @bdhirsh’s AOTAutograd input mutations PR and (2) voz’s fake tensor plumbing to AOTAutograd PR. While not all of the bugs from these two PRs have been totally resolved (and in particular, bugs in input mutations are suppressing the pass rate on the branch), the features work well enough that we only suffered a minor regression in aot_eager pass rate. Additionally, we have BERT_pytorch working end-to-end for BOTH training and inductor, meaning that on branch (but not master) we have achieved our goal for inductor. :tada::tada::tada:

  • PSA: We now have int64_t, SymInt overloads for all binary operators in C++, so you no longer have to rewrite 2 + symint into symint + 2; both work now.
  • PSA: DebugInterpreter is now actually verifying stride equality; at time of writing, this is revealing seven models which have incorrect sizes/strides. These bugs are ripe for picking!
  • We have more clarity about what is missing to properly integrate inductor with dynamic shapes (instead of all of the hacks that are currently on the branch.) A big question mark is whether ShapeEnv should be shared between AOTAutograd’s forward and backward; when the ShapeEnv is shared, Inductor cannot necessarily infer all of the shape variables, because a shape variable may only occur inside a more complex shape expression from the input, e.g., s0 * 2 (a toy illustration appears after this list). ref @Chillee is pivoting back to working on this, after having spent time working on channels last in the general workstream. There’s also some discussion about forward-backwards staging confusion at [aot_eager] [hf_Longformer] Cannot view a tensor with shape · Issue #1888 · pytorch/torchdynamo · GitHub
  • Some progress was made designing a more precise warning mechanism for Copy-on-write tensors for eager PyTorch - Google Docs see bottom of doc (though no implementation progress)
  • Model training status on symbolic-shapes. See also Symbolic shapes work items tracker - Google Sheets
    • aot_eager: 128 out of 163 (-7 WoW) logs csv. The regression is primarily due to AOTAutograd input mutation support hitting our branch while still having two bugs (accidentally using SymInt sizes to perform manipulations on real tensors, and needing to regenerate views of intermediates rather than returning them directly).
    • inductor: 36 out of 163 (+12 WoW) logs csv; notably, BERT_pytorch is passing with inductor now, and a spot check of generated Triton code suggests the output is dynamic: BERT_pytorch dynamic Triton · GitHub
    • End-to-end BERT_pytorch working with dynamic shapes and inductor by ezyang · Pull Request #89313 · pytorch/pytorch · GitHub demonstrates the minimal set of changes from our branch necessary to get BERT_pytorch working end-to-end on master. The PR doesn’t pass CI as is; our current thinking is to do some necessary refactors first which will simplify the final form of this PR.
  • OpInfo tests on symbolic shapes.
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 494 passed (+106 WoW), 229 failed (-105 WoW), 512 skipped (+11 WoW) logs csv. This improvement is partially due to Towards unifying symbolic and non symbolic fake tensor, which allows us to attempt C++ meta functions even if they’re not known to be SymInt-aware; it turns out many of them still work correctly anyway. This is at the cost of worse error messages when SymInt is not supported.
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 249 passed (-6 WoW), 221 failed (+8 WoW), 133 skipped (+4 WoW) logs csv. The regression here is from outstanding bugs on AOTAutograd input mutation changes on the branch; the numbers should improve once that regression is fixed.
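To make the “Inductor cannot necessarily infer all of the shape variables” point above concrete, here is a toy illustration using sympy directly (not Inductor’s actual machinery): if the only size visible at the graph boundary is the expression 2*s0 rather than s0 itself, the compiler has to recover s0 from the concrete runtime value before it can use s0 in generated code, which is only straightforward for simple expressions.

```python
import sympy

s0 = sympy.Symbol("s0", positive=True, integer=True)

# The backward graph's input size is an expression in s0, not s0 itself.
boundary_size = 2 * s0

# At runtime we only observe the concrete value of that expression...
observed = 128

# ...so recovering s0 means solving an equation.
print(sympy.solve(sympy.Eq(boundary_size, observed), s0))  # [64]
```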

Previous branch diff: 82 files changed, 2236 insertions(+), 747 deletions(-)
Current branch diff: 68 files changed, 2612 insertions(+), 554 deletions(-)

Notable bugs

  • DebugInterpreter actually works now: Fix cat striding in PrimTorch was identified by looking at a model failing a stride equality assert in the debug interpreter. I figured out that DebugInterpreter was not actually testing strides when I fixed a separate stride problem that was causing an accuracy failure and then attempted to diagnose why DebugInterpreter hadn’t caught it earlier; it turns out the stride matching function returns a tuple, and non-empty tuples are always truthy :rage: (a distilled example of this pitfall appears after this list)
  • Testing is important: unrelated refactoring on master broke our previously symint-ified upsample_nearest2d, and had to be fixed again in SymIntify upsample_nearest2d again after composite-ification. If we had appropriate testing on master, it probably could have been caught at regression time.
  • Simplify cudnn rnn support greatly untangles a complicated situation with cudnn SymInt support. The root cause of the problem is that cudnn_rnn is a composite function that calls a black box cudnn function to figure out what the output size of the workspace should be. This cannot be SymInt’ified, but in the original attempt to make this all work, our poor intrepid engineer tried SymInt’ifying all of the Descriptor code responsible for driving calls to the cudnn API. This patch undoes all of that in favor of an earlier guard. However, this is not enough to fix tts_angular, as it eventually calls _cudnn_rnn, which requires a meta implementation, and that meta implementation requires querying cudnn APIs. Handling this black box correctly may require something similar to how we plan to handle missing meta functions for custom functions (guess relations between inputs and outputs, and verify the black box function acts consistently for concrete values we haven’t seen before). Or we can just guard (but guarding requires us to be able to call into the cudnn API from our Python implementation.)
  • Set functionalization storage to size 0 for sparse tensor wrapping is a lucky hack that works around a very delicate situation, which is that functionalization really doesn’t know how to handle sparse tensors. Fortunately, we don’t actually need to functionalize operations on sparse tensors, so hacking an empty storage “works”, but it is likely to fail if someone stress tests us on sparse support.
  • first draft of input mutation handling for aot autograd is affected by two bugs right now.
    • “RuntimeError: Expected !is_symbolic() to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)” The problem here is that the patch accidentally uses fake tensor sizes (which contain SymInts) to reshape real tensors, propagating SymInts into real tensors (a big no-no! This also points at missing asserts; we should add them.) At time of writing, this is close to fixed but not quite, using the idea of having the forward return sizes/strides/storage_offset directly as SymInt outputs of the forward. Most of the tests pass on the branch, but while hf_Longformer and tacotron2 no longer hit the error, they’re both OOM’ing even with a batch size of 1… this warrants further investigation. A suggestion is to use zdevito’s memory profiling tools to diagnose: Debugging PyTorch memory use with snapshots | Zach’s Blog
    • “RuntimeError: Output 0 of AsStridedBackward0 is a view of a view which was created in no_grad mode and is being modified inplace with grad mode enabled.” This is the thing where our graph should return a graph intermediate, and manually regenerate the view(s) of the intermediate in the epilogue. @bdhirsh hasn’t started working on fixing it yet.
  • Detach fake tensors into val, so they aren’t affected by metadata mut… was briefly discussed in composability; it has the potential to break users who are relying on accurate autograd metadata or tensor identity on meta[‘val’], but it turns out this doesn’t work anyway.
  • [HACK] don’t graph break on math.sqrt: we noticed that BERT_pytorch was repeatedly recompiling on every iteration. Voz tracked this down to the id of an NN module not staying stable over iterations; but it turned out that this was because we were graph breaking too much, due to an overly aggressive unsupported() call in Dynamo. Reducing the graph breaking fixed the problem!
    • Disable explain for now: we wanted to collect graph break stats to make it easier to tell if BERT_pytorch situation was happening elsewhere. Dynamo’s preexisting explain feature makes this possible. However, when we turned it on many huggingface models started failing with “AttributeError: ‘str’ object has no attribute ‘size’”. This has not been investigated yet.
  • Patch out suspicious get_item_dyn logic for jx_nest_base is an example of very long code that is also wrong. Long code for a special case is a smell; there’s probably a conceptually simpler way to do it!
  • Dynamo was affected by a lot of “missing symbol s0” assertion errors. This assertion error tests if Dynamo knows how to compute a given SymInt from the set of tensors being guarded on. These missing symbols came from a variety of places: Dynamo’s special handling of NN parameters/buffers [dynamo] [dynamic shapes] Fix register_buffer (and all module associa…, as well as dependence on _base tensors due to functionalization Record TensorReference for ._base tensors, but don’t install guards o…. The latter is fixed in a tricky way: we don’t ever actually guard on the base of a tensor, because according to voz this caused a lot of recompilation. More investigation necessary…
  • Remove bad string parsing assert is an example of why it is bad to do string matching on strings that represent structured data types (like language expressions); a toy example of the pitfall appears after this list.
  • Hide ConvParams struct from ConvUtils.h and the other PRs in the stack are a nice example of how a large, involved fix was split into a series of manageable refactors, followed by a short fix at the end. I attempted to fix it in one go at first, but then realized there was a simpler way to do it (template the ConvParams struct.)
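Here is a distilled version of the tuple-truthiness pitfall described in the DebugInterpreter item above (an illustration of the bug pattern, not the actual stride-matching code):

```python
def strides_match(actual, expected):
    # Returns (matched, human-readable reason) -- seems convenient, but...
    return actual == expected, f"{actual} vs {expected}"

ok = strides_match((4, 1), (1, 4))

# ...a non-empty tuple is always truthy, so a check like `if not ok: raise`
# never fires, even though the strides clearly do not match.
assert bool(ok) is True

matched, reason = ok
assert matched is False  # the actual comparison result has to be unpacked
```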
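And a toy example of the string-matching pitfall called out above: substring checks against the printed form of a symbolic expression confuse distinct symbols, whereas querying the structured expression does not.

```python
import sympy

s0, s01 = sympy.symbols("s0 s01")
expr = s01 + 1

print("s0" in str(expr))        # True -- false positive from substring matching
print(s0 in expr.free_symbols)  # False -- the structured query gets it right
```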

What’s new on the branch this week?

Meta/decomp support

SymInt support

Infrastructure

QOL

Dynamo

Inductor

Merge to master retrospective

What’s made it to master this week?

What’s coming next?

By Person:

  • Chillee: proper integration of inductor with dynamic shapes (he’ll actually work on it this week!!)
  • voz: merge aot autograd plumbing to master, with a lot of refactor
  • ezyang: maybe improved implementation of reshape CoW warnings (but only if I get some spare time)
  • bdhirsh: continue fixing input mutation for AOTAutograd (last mile two bugs)
  • jbschlosser: maybe one of the assert failures on aot_eager

Our north star:

  • All benchmark models are passing aot_eager and inductor training on branch
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing
  • Dynamic shapes on by default for developers / users

State of symbolic shapes branch: Dec 1 edition (eve of PyTorch Conference)

The symbolic-shapes branch (PyTorch: Symbolic shapes by ezyang · Pull Request #84246 · pytorch/pytorch · GitHub ) is a long running branch containing a large number of features and bugfixes related to dynamic shapes support in PyTorch. Previous update: State of symbolic shapes branch - #18 by ezyang

Commit ID at time of writing: a05b7b1c73247ff562a82aac0edca79bbaebc2bd

Executive summary

It is the eve of the PyTorch Conference and we have been busy getting things ready for some big announcements. :wink: Before and after Thanksgiving, many folks involved with dynamic shapes were deputized to help fix some major release blockers in the general compiler workstream; Brian and Jane landed all of the pieces needed to properly update batch norm running stats, and Alban and Edward found and fixed some more fairly major AOTAutograd bugs. On the dynamic shapes front, Voz has been steadily working on getting all of the Dynamo changes passing CI on master; half of the preparatory changes have been landed so far, and the branch has been resync’ed after those merges. There is some regression in the aot_eager pass rate as we remove hacks and redo fixes properly.

Previous branch diff: 68 files changed, 2612 insertions(+), 554 deletions(-)
Current branch diff: 68 files changed, 1440 insertions(+), 290 deletions(-)

What’s new on the branch these two weeks?

Metas/decompositions

Infrastructure

Debug interpreter

Dynamo

Inductor

QOL

Merge to master retrospective

  • Reland “Add single process version of dynamo distributed hf_Bert tests (#89721)” - this got bounced because not enough tests ran on the PR. We added more files to the set that automatically triggers inductor tests.
  • Refactor how AOTAutograd backends are defined - this is just one example of a few cases where folks ran inductor CI, got an accuracy failure on a model, and then spent a bunch of time trying to debug what had happened, when in fact the failure was a preexisting master failure. It is not easy to identify these because ciflow/inductor does not run on every master commit.
  • Change aot_module_simplified to take take arguments directly - this broke a timm model, and led us on a pretty big chase that eventually revealed that example inputs being passed to backends did not have the correct requires_grad because they were being cloned. This was fixed by refactoring the AOTAutograd-Dynamo integration to not clone example inputs.
  • Remove fake_tensor_propagation - this nearly got bounced because it broke some internal users who didn’t have fake tensor support for some operations. Averted because (1) their tests weren’t in CI and (2) it turned out to be pretty easy to add meta tensor support.
  • Don’t unsafely clone autograd meta - this couldn’t be landed because it broke an inductor model, causing it to raise an error where previously it passed. This led to a very long debugging session by Alban until we finally nailed the problem.

What’s made it to master this week?

ezyang

bdhirsh

anjali411

nkaretnikov

voz

albanD

What’s coming next?

  • Land fake tensor propagation from Dynamo to AOTAutograd (voz)
  • ShapeEnv revamp to get guards for duck sizing (ezyang)
  • GuardEnv for non-shape related extra guards produced by AOTAutograd (voz)
  • Address CI comments for AOTAutograd input mutation, factoring it to be more modular (bdhirsh)
  • Proper inductor integration (Chillee didn’t end up working on it, unallocated; mildly blocked on ShapeEnv revamp)

Our north star:

  • All benchmark models are passing aot_eager and inductor training on branch
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing
  • Dynamic shapes on by default for developers / users