State of symbolic shapes branch

State of symbolic shapes: Jan 11 edition

Previous update: State of symbolic shapes branch - #22 by ezyang

Commit ID at time of writing: f6c7cf1bf579cc42ea7e21bd557168625648a3e9

Executive summary

Welcome back from holiday and PSC, folks! The holidays were sleepy, but not that sleepy: @jbschlosser landed some nice coverage wins (+2 working models on master), and @tugsbayasgalan has been working on a number of dynamic shape fixes that show up in PyTorch Edge export use cases. Since it’s the new year, you might be wondering what you ought to be doing this year. This status post is to help you figure that out.

High priority things that need to be done (excluding export):

Things to be done sourced from PyTorch Edge export workstream (Meta only):

Things to be done sourced from the generic export workstream (@SherlockNoMad):

Status report:

  • Model training status on master. See also Symbolic shapes work items tracker - Google Sheets
    • aot_eager: -15 (+2 MoM). These newly working models are thanks to work by @jbschlosser (see the skip list).
    • inductor: still really broken on master :rotating_light::rotating_light::rotating_light:
  • OpInfo tests on symbolic shapes.
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 516 passed (+3 MoM), 522 skipped (no change), 224 xfailed (-3 MoM)
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 291 passed (+5 MoM), 143 skipped (+1 MoM), 197 xfailed (-5 MoM)

Guard simplification

As you can see from the chatter before this post, Nuno is working on guard simplification. Nuno looked at ResNet and found that range analysis and a quadratic solver would discharge most of the nontrivial constraints we were generating, greatly simplifying the resulting guards. However, there were also suggestions that we could simplify the guards at generation time.
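
For a flavor of what range analysis buys us, take the first example constraint below: Eq((s2 - 1)//16 + 1, 1) is equivalent to the range condition 1 <= s2 <= 16, so in principle the guard could be rewritten as a simple bounds check on s2. A quick numeric sanity check of that equivalence (plain Python, not the actual simplifier):

assert all(
    ((s2 - 1) // 16 + 1 == 1) == (1 <= s2 <= 16)
    for s2 in range(1, 1000)
)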

Example constraint: Eq((s2 - 1)//16 + 1, 1). This is produced by empty_strided calling compute_contiguous on the passed-in strides; the constraint comes specifically from the size_d != 1 test used when checking contiguity. We should avoid guarding here, à la cell 18 of subclass_zoo/dynamic_strides.ipynb at main · albanD/subclass_zoo · GitHub (but we need SymBool for this!):

T z = 1;
// NB: make sure we do signed arithmetic
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
  const auto& size_d = sizes[d];
  if (size_d != 1) {
    if (strides[d] == z) {
      z *= size_d;
    } else {
      is_contiguous = false;
      break;
    }
  }
}
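
To make the mechanism concrete, here is a toy sketch of how evaluating size_d != 1 on a symbolic size forces a guard to be recorded (FakeSymInt, guards, and hint are made-up names for illustration, not the real SymInt/ShapeEnv API):

import sympy

guards = []

class FakeSymInt:
    # Toy stand-in for a SymInt: carries a sympy expression plus a concrete
    # example value ("hint"); any data-dependent branch records a guard.
    def __init__(self, expr, hint):
        self.expr, self.hint = expr, hint

    def __ne__(self, other):
        if self.hint != other:
            guards.append(sympy.Ne(self.expr, other))
            return True
        guards.append(sympy.Eq(self.expr, other))
        return False

s2 = sympy.Symbol("s2", integer=True, positive=True)

# A dimension whose size is (s2 - 1)//16 + 1; with the example input s2 = 7,
# the concrete hint is (7 - 1)//16 + 1 == 1.
size_d = FakeSymInt((s2 - 1) // 16 + 1, hint=(7 - 1) // 16 + 1)

if size_d != 1:  # the size_d != 1 test from the loop above
    pass

print(guards)  # [Eq(floor((s2 - 1)/16) + 1, 1)], i.e. the constraint above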

Example constraint: (s2 - 1)//2 + 1 < (s2 - 1)//2**2 + 2*(s2 - 1)//2 + 1. This comes from:

File "/Users/ezyang/Dev/pytorch-cpu/torch/_prims/__init__.py", line 348, in _elementwise_meta
  strides = utils.compute_elementwise_output_strides(*args_)
 File "/Users/ezyang/Dev/pytorch-cpu/torch/_prims_common/__init__.py", line 407, in compute_elementwise_output_strides
  comparison = should_swap(perm[dim0], perm[dim1])
 File "/Users/ezyang/Dev/pytorch-cpu/torch/_prims_common/__init__.py", line 387, in should_swap
  if stride_a < stride_b:

The easiest fix is probably to make sure we don’t run the sorting algorithm at all when the inputs are contiguous. Even better would be to introduce a contiguity guard (one that tests that a tensor is contiguous as a whole), which should eliminate these guards entirely.
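
As a rough sketch of that first fix (plain Python with made-up helper names, not the real _prims_common code): if the input already has contiguous strides, reuse them directly, so under symbolic shapes we would emit a single "is contiguous" guard rather than one guard per stride comparison in the sorting pass.

def contiguous_strides(sizes):
    # Row-major strides for the given sizes; a single contiguity guard would
    # assert that the input's strides already look like this.
    strides = [1] * len(sizes)
    for d in range(len(sizes) - 2, -1, -1):
        strides[d] = strides[d + 1] * max(sizes[d + 1], 1)
    return strides

def output_strides(sizes, strides):
    if list(strides) == contiguous_strides(sizes):
        # Contiguous input: skip the should_swap sorting pass entirely, so no
        # stride_a < stride_b guards are ever generated.
        return contiguous_strides(sizes)
    # Non-contiguous inputs would fall through to the eager sorting logic
    # (elided in this sketch).
    raise NotImplementedError

print(output_strides([64, 3, 7, 7], [147, 49, 7, 1]))  # [147, 49, 7, 1]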

You can reproduce these experiments with the following code:

import torch
import torch._dynamo
import torchvision

model = torchvision.models.resnet18(weights="DEFAULT")
model = torch._dynamo.optimize("eager")(model)
model.eval()
model(torch.rand([64, 3, 7, 7]))

Something that is not so easy right now is finding out what produced a given guard expression; e.g., I see x ** 2 // y blah blah, but where did it come from? More detailed logging at the Dynamo per-op level would help.
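
One cheap way to get that provenance would be to stash a stack trace whenever a guard expression is created; the hook below is a hypothetical sketch (record_guard and guard_provenance are made-up names, not existing APIs):

import traceback

guard_provenance = {}  # guard expression (as a string) -> where it was created

def record_guard(expr):
    # Hypothetical hook, called wherever a guard expression gets appended, so
    # "where did x ** 2 // y come from?" can be answered after the fact.
    guard_provenance[str(expr)] = traceback.format_stack(limit=10)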

What’s made it to master since last time?

ezyang

voz

jbschlosser

tugsbayasgalan

nkaretnikov

What’s coming next?

  • ezyang: catching up on code review
  • voz: changing dynamo backend api to take aot autograd directly
  • bdhirsh: on vacation until next week
  • Chillee: inductor integration

Our north star:

  • Dynamic shapes enabled by default for PT2 release
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators :rotating_light::rotating_light::rotating_light:
  • All OpInfo tests passing
  • Dynamic shapes on by default for developers / users