State of symbolic shapes: Jul 29, 2023 edition
Previous update: State of symbolic shapes branch - #63 by ezyang
Executive summary
- Data dependent shape support in Inductor. I got an end-to-end PoC of a pointwise and then a reduction working in Inductor (with hacks): gist:1293a41299604c44310341b7540eabcb · GitHub. The main gaps: (1) optional optimizations failing to retrieve hints (Triton size hints: pick 8192 to prevent the block size from shrinking; multiple-of-16 hints: pick something that is not a multiple of 16; 32-bit indexing), (2) buffer reuse (keying on the str rep is fine, use sympy_str), (3) updating wrapper codegen to create bindings to i0 variables. In general, it seems pretty useful to have accurate maximum size information, for which ValueRanges is an incomplete fix because we don’t support symbols (s0) in bounds. Another trick we plan to implement is a special irrefutable guard: if we guard on an unbacked symint, we instead just assume the guard is true and add a runtime assertion. One open question is whether we can always get dynamic shapes working no matter what. In Inductor, it seems we can usually just turn off optimizations to avoid guards, so we just need to get host-side torch.cond working to handle everything else. Some fixes for these are in: If we can’t statically prove 32-bit indexing OK, only add guard if hint exists; Provide a refined upper bound for nonzero when original numel is static
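To make the irrefutable-guard idea concrete, here is a minimal framework-free sketch (all class and method names are hypothetical, not the actual ShapeEnv API): when a guard mentions an unbacked symbol with no hint, we assume it holds at trace time and record a runtime assertion instead of failing.

```python
# Hypothetical sketch of "irrefutable guards": guards on unbacked symbols
# (no example value / hint available) are assumed true at trace time and
# re-checked at runtime instead.

class ShapeEnv:
    def __init__(self):
        self.hints = {}            # symbol -> concrete example value (backed symbols)
        self.runtime_asserts = []  # deferred checks for unbacked symbols

    def guard(self, symbols, predicate, expr_str):
        """Evaluate a guard if all symbols are backed; otherwise defer it."""
        if all(s in self.hints for s in symbols):
            # Backed case: evaluate the predicate on the hints right now.
            return predicate({s: self.hints[s] for s in symbols})
        # Unbacked case: irrefutable guard. Assume True, assert at runtime.
        self.runtime_asserts.append((symbols, predicate, expr_str))
        return True

    def run_asserts(self, actual_values):
        """At runtime, check all deferred guards against the real sizes."""
        for symbols, predicate, expr_str in self.runtime_asserts:
            env = {s: actual_values[s] for s in symbols}
            assert predicate(env), f"runtime assert failed: {expr_str}"

env = ShapeEnv()
env.hints["s0"] = 8  # s0 is backed
assert env.guard(["s0"], lambda e: e["s0"] > 1, "s0 > 1")    # evaluated now
assert env.guard(["i0"], lambda e: e["i0"] >= 0, "i0 >= 0")  # deferred
env.run_asserts({"i0": 5})  # passes once the real value is known
```

The key property is that a failing deferred guard implies the program would have errored anyway, so assuming it true during tracing is sound as long as the runtime assertion is emitted.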
- An initial plan for KeyedJaggedTensor. After studying some of the models that use KJT and trying to get export working on them, here are some of the initial findings:
- You can remove the list of integers from a KJT before tracing a model, which will cause the model to perform a data-dependent access to populate these integers as unbacked integers. However, when we try to use these integers to do a tensor_split, we immediately hit guards we cannot prove. The guards should be provable via sum(lengths) == values.shape[0], but our symbolic reasoning is not strong enough. These guards are for errors, so they should be bypassable by irrefutable guards (guards which, if they fail, imply you would have errored anyway; in this case you can convert the guard into a runtime test). This is worth pursuing further. In any case, expect on the order of 500 unbacked symints, so symbolic reasoning must be fast enough to deal with that.
- If you don’t remove the list of integers, you need some way to prevent them from 0/1 specializing. In export, you can simply require every sparse feature to be populated with size 2 and hope it generalizes to 0/1. In eager, we will probably just have to specialize KJT to treat these integers specially. The big benefit of this strategy is that you’re not hard-blocked on guards on unbacked SymInts, since there is always a hint; nor do you need any sum(lengths) reasoning, since guards are discharged by checking the underlying values. We cannot actually do this in export today because export does not support SymInt inputs; I plan to fix this.
- Export with KJTs doesn’t work because KJTs are not a supported input. Direct fix: Add pytree support to KeyedJaggedTensor by ezyang · Pull Request #1287 · pytorch/torchrec · GitHub; indirect fix is rewriting the export calling convention from pytree specs to a dictionary of “FQN” (really Source.name()) to Tensor. In that case, to pass a KJT named `id_list_features`, you would actually pass three tensors: `id_list_features._values`, etc.
- More details at Meta-only doc (sorry, non-public due to details about Meta prod models).
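A framework-free sketch of that proposed calling convention (the class and field names here are stand-ins for torchrec's KeyedJaggedTensor, and the flattening helper is hypothetical): a structured KJT input gets flattened into a dictionary mapping fully-qualified names to its constituent tensors.

```python
# Hypothetical sketch of the FQN-dictionary export calling convention:
# instead of passing a structured KJT via a pytree spec, flatten it into
# a dict mapping fully-qualified names (FQNs) to constituent tensors.
# (KJT fields are modeled as plain lists; names are illustrative.)

class FakeKJT:
    """Stand-in for torchrec's KeyedJaggedTensor with tensor fields."""
    def __init__(self, values, lengths, weights=None):
        self._values = values
        self._lengths = lengths
        self._weights = weights

def flatten_inputs(name, kjt):
    """Flatten a KJT argument into an FQN -> tensor dictionary."""
    flat = {
        f"{name}._values": kjt._values,
        f"{name}._lengths": kjt._lengths,
    }
    if kjt._weights is not None:
        flat[f"{name}._weights"] = kjt._weights
    return flat

kjt = FakeKJT(values=[3, 1, 4, 1, 5], lengths=[2, 3])
flat = flatten_inputs("id_list_features", kjt)
assert sorted(flat) == ["id_list_features._lengths", "id_list_features._values"]
# The invariant the tensor_split guards above rely on:
assert sum(kjt._lengths) == len(kjt._values)
```

The exported program would then accept these flat tensors directly, sidestepping the need for KJT pytree support at the export boundary.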
- Translation validation bisection. We had an internal case of the hint disagreeing with sympy simplification; we’ve also had instances of this in open source (integer vs. real equality), see dynamo: fix the issue of aten.expand when the source and expaned size are all symbolic size by XiaobingSuper · Pull Request #101173 · pytorch/pytorch · GitHub. Yukio is thinking of implementing a bisection mechanism for translation validation, so we can find the first guard that actually caused a correctness problem.
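A guard bisection mechanism of this sort could look like the following minimal sketch (hypothetical, not the actual design): given the ordered list of guards added during tracing, and a validator that checks whether a prefix of guards still validates, binary-search for the first guard that introduces the problem.

```python
def bisect_first_bad_guard(guards, prefix_is_valid):
    """Binary search for the first guard whose inclusion breaks validation.

    guards: ordered list of guards added during tracing.
    prefix_is_valid(k): True if applying only guards[:k] validates correctly.
    Assumes validity is monotone: once a prefix is invalid, longer ones are too.
    Returns the index of the first bad guard, or None if everything validates.
    """
    if prefix_is_valid(len(guards)):
        return None
    lo, hi = 0, len(guards)  # invariant: prefix lo valid, prefix hi invalid
    while hi - lo > 1:
        mid = (lo + hi) // 2
        if prefix_is_valid(mid):
            lo = mid
        else:
            hi = mid
    return hi - 1  # guards[hi-1] is the first guard that breaks validation

# Example: the guard at index 3 is the culprit.
guards = ["s0 > 1", "s1 > 1", "s0*s1 < 2**31", "Eq(s0, s1)", "s1 % 16 == 0"]
assert bisect_first_bad_guard(guards, lambda k: k <= 3) == 3
```

In practice the validator would be the translation validation pass itself (re-running the SMT check with only a prefix of guards), so each probe is expensive and the logarithmic number of probes matters.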
- Export for QAT. QAT wants to do whole-graph transformations on a pre-autograd FX graph. Export sort of supports this with `pre_dispatch` export. What is likely going to happen is that this turns into the IR format that export is going to use. Pre-autograd functionalization is unlikely to happen; you only get some mild normalization. Still unresolved is how to work this into the overall QAT workflow API, since export isn’t really keen on exposing this mid-point IR (which is kind of incoherent).
- Notable bug fixes.
- Change _dynamo.export to be export(f)(*args, **kwargs) and Change _dynamo.explain to be explain(f)(*args, **kwargs) help avoid ambiguity between user kwargs and export/explain kwargs. It is technically BC-breaking, but only if you exported a module with no arguments (quite rare!)
- Turn on capture_dynamic_output_shape_ops/capture_scalar_outputs by default for export. Not sure why we hadn’t done this before…
- Make _CURRENT_TRACING_CONTEXT thread local. This occasionally caused a race that typically looked like “fake tensor mode mismatch.”
- Improve FakeTensor to work with mixed meta-cpu embedding bag arguments. This is for you reco system peeps using meta embedding tables with CPU inputs for testing.
- Tweak dynamic=False behavior is in.
- Add missing evaluate_expr for slice_scatter, slight refactor; fixes slice_scatter with SymInt start/end
- Support dynamic shapes in TritonTemplates by ipiszy · Pull Request #105295 · pytorch/pytorch · GitHub - responsible for decent TIMM improvement
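The export/explain API change above is about keyword-argument ambiguity: with export(f, *args, **kwargs), a kwarg intended for the user function could collide with export's own options. A toy stand-in (not the real torch._dynamo implementation, and the option name here is illustrative) shows how currying removes the ambiguity:

```python
# Toy illustration of the curried calling convention used by the new
# _dynamo.export/explain APIs (stand-in code, not the real implementation).
# Old style: export(f, *args, aten_graph=True) -- is aten_graph an export
# option or a kwarg meant for f? Currying makes ownership unambiguous.

def export(f, aten_graph=False):
    """All kwargs at this call site belong to export itself."""
    def run(*args, **kwargs):
        # All args/kwargs here belong unambiguously to the user function f.
        result = f(*args, **kwargs)
        return {"result": result, "aten_graph": aten_graph}
    return run

def model(x, scale=1):
    return x * scale

out = export(model, aten_graph=True)(10, scale=3)
assert out == {"result": 30, "aten_graph": True}
```

The only code this breaks is an export of a module taking no arguments, where the old single-call form and the new curried form were previously indistinguishable.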
- Notable new bugs.
- [dynamo.export] symbolic_shapes.GuardOnDataDependentSymNode - lively discussion about irrefutable guards
- llama model failed for dynamic shape path - this is cpu backend specifically
- Tensors always get 0/1 specialization guards, even if they’re not used - discovered by Animesh
- Bug when dealing with fallbacks on CPU · Issue #105853 · pytorch/pytorch · GitHub - not really sure what’s going on with this one
CI skips. -3, -1, -1, -2 (no change).
Training dashboard (as of 1da4115702). This week on HUD
| Metric | Torchbench | Huggingface | TIMM models | Dynamic |
|---|---|---|---|---|
| Passrate | 92%, 59/64 | 96%, 44/46 | 98%, 59/60 | 100%, 8/8 |
| Speedup | 1.54x → 1.56x | 1.69x | 1.28x → 1.35x | 1.97x → 2.04x |
| Comptime | 81s | 107s → 108s | 142s | 38s → 39s |
| Memory | 0.79x | 0.96x | 1.01x | 0.69x |
- TIMM models had the most change. Some of this is from cait_m36_384, which had its batch size changed. Some is from Support dynamic shapes in TritonTemplates by ipiszy · Pull Request #105295 · pytorch/pytorch · GitHub (ghostnet_100). Some are across-the-board improvements (e.g., resmlp_12_224).
Inference dashboard (as of 1da4115702). This week on HUD
| Metric | Torchbench | Huggingface | TIMM models | Dynamic |
|---|---|---|---|---|
| Passrate | 88%, 65/74 | 98%, 45/46 | 100%, 60/60 | 58%, 7/12 |
| Speedup | 1.55x → 1.54x | 1.78x → 1.77x | 1.79x → 1.80x | 3.03x → 3.08x |
| Comptime | 35s → 36s | 44s → 45s | 36s | 72s → 75s |
| Memory | 0.68x | 1.11x | 0.84x → 0.85x | 0.87x |
Looks all within noise.
What’s next
- Rewriting export input/output spec flattening
- Irrefutable guards
- Generally more pushing on KJT stuff