State of symbolic shapes: Jan 11 edition
Previous update: State of symbolic shapes branch - #22 by ezyang
Commit ID at time of writing: f6c7cf1bf579cc42ea7e21bd557168625648a3e9
Executive summary
Welcome back from the holidays and PSC, folks! The holidays were sleepy, but not that sleepy: @jbschlosser landed some nice coverage wins (+2 working models on master) and @tugsbayasgalan has been working on a number of dynamic shape fixes that show up in PyTorch Edge export use cases. Since it’s the new year, you might be wondering what you ought to be working on. This status post is to help you figure that out.
High priority things that need to be done (excluding export):
- https://github.com/pytorch/pytorch/pull/89532 is still not merged, as Brian Hirsh is on vacation until next week. If someone can help get it over the finish line in the next two days, that would be helpful.
- We should work on an updated manual for how to make things work for dynamic shapes, as our process has stabilized. The old docs are at Converting an op to accept SymInt arguments - Google Docs
- Tracing is pretty slow for some models. We have isolated some of the slowness specifically to our sympy compute in ShapeEnv; see the samples in Time to do sympy compute per model - Google Sheets. If you are interested in working on optimizing our solver, these are good cases to look at. We also have some evidence that fake tensors alone are very slow: an internal model from @suo is very large and traces very slowly even with just fake tensors. Repro with
buck run @mode/opt-split-dwarf @mode/inplace scripts/suo/frontend:main
on hg up a194c70b83ca (Meta only)
- We need to track how many graph breaks are in our sweep, and how this compares to eager. Explain produces most of this information, but it’s not durably recorded anywhere. This should be done by enhancing the benchmark script. @bertmaher was asking about these stats sans dynamic shapes; this is generally useful beyond dynamic shapes.
- We need to implement SymBool (instead of smuggling bools as integers), as work on guard simplification needs to be able to symbolically represent logical operations (see the sketch after this list).
- Nuno reported a bunch of bad error messages when doing things that symbolic shapes do not support. An easy task is to make these error messages better.
- GitHub issue request for pack_padded_sequence support: _pack_padded_sequence fails in dynamo due to requiring a non-fake 2nd argument · Issue #2024 · pytorch/torchdynamo · GitHub
- There are some intimations that dynamic shapes may be important for jagged tensors in some internal models. Will Feng and Mengchi Zhang to follow up.
- Add torch.tensor replacement and int_tensor prim by anjali411 · Pull Request #88221 · pytorch/pytorch · GitHub needs a new owner to get it past the finish line. This supports torch.tensor([symint]) style use cases.
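On the SymBool item above: a rough illustration (plain sympy, not the actual PyTorch representation) of what boolean-valued symbolic expressions buy guard simplification, compared to smuggling conditions as 0/1 integers:

import sympy

s0, s1 = sympy.symbols("s0 s1", integer=True, positive=True)

# Boolean-valued guards compose with real logical operators, and sympy can
# flatten and deduplicate them as booleans rather than as integer arithmetic.
g1 = sympy.Eq(s0 % 16, 0)
g2 = s0 < s1
combined = sympy.And(g1, g2, s0 < s1)  # the duplicate conjunct is dropped
print(combined)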
Things to be done sourced from PyTorch Edge export workstream (Meta only):
- @suo reported that when custom ops are missing meta implementations, you don’t get a nice error message saying “this op needs a meta implementation”; instead you get P590681504. The custom op was implemented at: Internal Login
- item() in dynamo doesn’t work all the time, e.g., Internal Login. But it works with capture_scalar_outputs; can we turn on capture_scalar_outputs by default? More follow-up at [dynamo] Fix bug in tensor.item fake tensor propogation by tugsbayasgalan · Pull Request #91668 · pytorch/pytorch · GitHub. In general, the item() implementation in Dynamo is suspicious (see the sketch after this list).
- Unbacked SymInt support will be needed soon for boolean masking. Example code that people write is at POC: Allow guarding on unbacked symints by ezyang · Pull Request #90985 · pytorch/pytorch · GitHub
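A hedged sketch of both items above: capture_scalar_outputs is a real Dynamo config flag (it defaults to False; whether to flip it is the open question), and the masking function is the kind of user code that forces unbacked SymInts:

import torch
import torch._dynamo

# With this on, Dynamo tries to keep .item() in the graph instead of graph
# breaking on it.
torch._dynamo.config.capture_scalar_outputs = True

@torch._dynamo.optimize("eager")
def uses_item(x):
    return x.sum().item() + 1  # data-dependent scalar

# Boolean masking is the canonical unbacked SymInt case: the length of x[mask]
# depends on the values in mask, so no guard on input shapes can pin it down.
def boolean_mask(x, mask):
    return x[mask]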
Things to be done sourced from the generic export workstream (@SherlockNoMad):
- SymInt operations show up in the graph. They ought to have schema, but currently they don’t (e.g., because we put operator.mul directly in the graph). Can we put them in native_functions.yaml? See Add sym_size/stride/numel/storage_offset to native_function.yaml by SherlockNoMad · Pull Request #91919 · pytorch/pytorch · GitHub
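One way to see the schema problem (a sketch assuming make_fx symbolic tracing; exact node names vary by commit):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # Under symbolic tracing the sizes are SymInts, so this multiplication is
    # recorded in the graph.
    return x.reshape(x.shape[0] * x.shape[1])

gm = make_fx(f, tracing_mode="symbolic")(torch.randn(3, 4))
gm.graph.print_tabular()
# The printed graph contains sym_size calls plus a bare operator.mul node; the
# latter has no schema, which is what the PR above proposes to fix.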
Status report:
- Model training status on master. See also Symbolic shapes work items tracker - Google Sheets
  - aot_eager: -15 (+2 MoM). These newly working models are thanks to work by @jbschlosser (skip list)
  - inductor: still really broken on master
- OpInfo tests on symbolic shapes.
  - pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive: 516 passed (+3 MoM), 522 skipped (no change), 224 xfailed (-3 MoM)
  - pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive: 291 passed (+5 MoM), 143 skipped (+1 MoM), 197 xfailed (-5 MoM)
Guard simplification
As you can see from the chatter before this post, Nuno is working on guard simplification. Nuno looked at resnet, and found that range analysis and a quadratic solver would solve most of the nontrivial constraints we were generating, greatly simplifying the guards we generate. However, there were also suggestions that we could simplify the guards at generation time.
Example constraint: Eq((s2 - 1)//16 + 1, 1). This is produced by empty_strided calling compute_contiguous on the passed-in strides. The constraint comes specifically from the size_d != 1 test when checking contiguity. We should avoid guarding here, along the lines of subclass_zoo/dynamic_strides.ipynb at main · albanD/subclass_zoo · GitHub, cell 18 (but we need SymBool for this!). A plain-sympy sketch of this guard follows the C++ snippet below.
// Excerpt from compute_contiguous: the size_d != 1 test below is what
// produces the Eq((s2 - 1)//16 + 1, 1) guard.
bool is_contiguous = true;
T z = 1;
// NB: make sure we do signed arithmetic
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
  const auto& size_d = sizes[d];
  if (size_d != 1) {
    if (strides[d] == z) {
      z *= size_d;
    } else {
      is_contiguous = false;
      break;
    }
  }
}
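For a concrete feel for why range analysis makes short work of this: for s2 >= 1, Eq((s2 - 1)//16 + 1, 1) is equivalent to s2 <= 16. A plain-sympy sketch (the ShapeEnv uses its own floor-division primitive, so this is only an approximation):

import sympy

s2 = sympy.Symbol("s2", integer=True, positive=True)
guard = sympy.Eq(sympy.floor((s2 - 1) / 16) + 1, 1)
print(guard.subs(s2, 16))  # True
print(guard.subs(s2, 17))  # False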
Example constraint: (s2 - 1)//2 + 1 < (s2 - 1)//2**2 + 2*(s2 - 1)//2 + 1. This comes from:
File "/Users/ezyang/Dev/pytorch-cpu/torch/_prims/__init__.py", line 348, in _elementwise_meta
strides = utils.compute_elementwise_output_strides(*args_)
File "/Users/ezyang/Dev/pytorch-cpu/torch/_prims_common/__init__.py", line 407, in compute_elementwise_output_strides
comparison = should_swap(perm[dim0], perm[dim1])
File "/Users/ezyang/Dev/pytorch-cpu/torch/_prims_common/__init__.py", line 387, in should_swap
if stride_a < stride_b:
The easiest fix is probably to make sure we don’t run the sorting algorithm when we have contiguous inputs. But even better would be to introduce a contiguity guard (which tests that a tensor is contiguous), which should be able to eliminate these guards entirely.
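A hedged sketch of the easier fix (hypothetical wrapper name; the real change would live inside compute_elementwise_output_strides itself):

import torch
import torch._prims_common as utils

# Hypothetical fast path: if every input is already contiguous, the output
# strides are just the contiguous strides of the broadcast shape, so we never
# reach should_swap and never emit guards like (s2 - 1)//2 + 1 < ...
def elementwise_output_strides_with_fast_path(*tensors):
    if all(t.is_contiguous() for t in tensors):
        shape = torch.broadcast_shapes(*(t.shape for t in tensors))
        return utils.make_contiguous_strides_for(shape)
    return utils.compute_elementwise_output_strides(*tensors)

Note that is_contiguous() on a tensor with symbolic sizes and strides still has to guard, which is why the contiguity guard described above is the better end state.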
You can reproduce these experiments with the following code:
import torch
import torch._dynamo
import torchvision

model = torchvision.models.resnet18(weights="DEFAULT")
model = torch._dynamo.optimize("eager")(model)
model.eval()
model(torch.rand([64, 3, 7, 7]))
Something that is not so easy right now is finding out what produced guard expressions; e.g., I see x ** 2 // y blah blah, where did it come from? More detailed logging at the Dynamo per-op level would help.
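A minimal sketch of the kind of provenance tracking that would help (hypothetical class, not the ShapeEnv API):

import traceback

# Hypothetical: stash a summarized Python stack at the moment a guard
# expression is created, so "where did x ** 2 // y come from?" can be answered
# from logs.
class GuardWithProvenance:
    def __init__(self, expr):
        self.expr = expr
        self.stack = traceback.format_list(traceback.extract_stack()[:-1])

    def __repr__(self):
        return f"{self.expr}\n  created at:\n" + "".join(self.stack[-3:])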
What’s made it to master since last time?
ezyang
- hrnet_w18, tts_angular works with dynamic shapes
- Delete dead intermediary_symbols
- Properly resolve source_ref when constructing shape guards
- Store source, not sname, in Symbol
- Restructure ShapeEnv so it uses GuardBuilder.SHAPE_ENV directly
- Propagate guard failures to userland
voz
jbschlosser
- SymIntify F.interpolate() with recompute_scale_factor=True
- Move sym_int and sym_float alongside SymInt / SymFloat in base torch package
- Decomps and meta registrations for upsample_nearest 1D / 2D / 3D
- Fix for RNN/LSTM/GRU modules to work with stateless.functional_call()
tugsbayasgalan
- Automatically convert real tensors to fake in dynamo export
- Make torch.split take symint as arg
- [dynamo] Fix bug in tensor.item fake tensor propogation
- [dynamo] Support dynamic slicing
- Symintify pytorch slicing logic
nkaretnikov
What’s coming next?
- ezyang: catching up on code review
- voz: changing dynamo backend api to take aot autograd directly
- bdhirsh: on vacation until next week
- Chillee: inductor integration
Our north star:
- Dynamic shapes enabled by default for PT2 release
- Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
- All OpInfo tests passing
- Dynamic shapes on by default for developers / users