State of symbolic shapes branch

State of symbolic shapes: Jul 29, 2023 edition

Previous update: State of symbolic shapes branch - #63 by ezyang

Executive summary

  • Data dependent shape support in Inductor. I got an end-to-end PoC of a pointwise and then a reduction working in Inductor (with hacks): gist:1293a41299604c44310341b7540eabcb · GitHub. The main gaps: (1) optional optimizations failing to retrieve hints (Triton size hints (pick 8192 to prevent the block size from shrinking), multiple-of-16 hints (pick something that is not a multiple of 16), 32-bit indexing); (2) buffer reuse (keying on the str representation is fine, use sympy_str); (3) updating wrapper codegen to create bindings to i0 variables. In general, it seems pretty useful to have accurate maximum size information, for which ValueRanges is an incomplete fix because we don’t support symbols (s0) in bounds. Another trick we plan to implement is a special irrefutable guard: when we would guard on an unbacked symint, we instead assume the guard is True and add a runtime assertion. One question is whether we can always get dynamic shapes working no matter what. In Inductor, it seems we can usually just turn off optimizations to avoid guards, so we mainly need to get host-side torch.cond working to handle everything else. Some fixes for these are in: If we can’t statically prove 32-bit indexing OK, only add guard if hint exists, Provide a refined upper bound for nonzero when original numel is static
  • An initial plan for KeyedJaggedTensor. After studying some of the models that use KJT and trying to get export working on them, here are some of the initial findings:
    • You can remove the list of integers from KJT before tracing a model, which will cause the model to perform a data-dependent access to populate these integers as unbacked integers. However, when we try to use these integers to do a tensor_split, we immediately hit guards we cannot prove. The guards should be provable via sum(lengths) == values.shape[0], but our symbolic reasoning is not strong enough. These guards are for errors, so they should be bypassable by irrefutable guards (guards which, if they fail, imply you would have errored anyway; in this case you can convert the guard into a runtime test). This is worth pursuing further. In any case, you can expect to have on the order of 500 unbacked symints, so symbolic reasoning must be fast enough to deal with them.
    • If you don’t remove the list of integers, you need some way to prevent them from 0/1 specializing. In export, you can simply require every sparse feature to be populated to size 2 and hope it generalizes to 0/1. In eager, we probably will just have to specialize KJT to treat these integers specially. The big benefit of this strategy is that you’re not hard-blocked on guards on unbacked SymInts, since there’s always a hint; you also don’t need any sum(lengths) reasoning, since guards are discharged by checking the underlying values. You cannot actually do this in export today because export does not support SymInt inputs–I plan to fix this.
    • Export with KJTs doesn’t work because KJTs are not a supported input. Direct fix: Add pytree support to KeyedJaggedTensor by ezyang · Pull Request #1287 · meta-pytorch/torchrec · GitHub; the indirect fix is rewriting the export calling convention from pytree specs to a dictionary of “FQN” (Source.name()) strings to Tensor. In that scheme, to pass a KJT named id_list_features, you would actually pass three tensors, id_list_features._values, etc.
    • More details at Meta-only doc (sorry, non-public due to details about Meta prod models).
  • Translation validation bisection. We had a case of a hint disagreeing with sympy simplification internally; we’ve also had instances of this in open source, see [integer and real equality](https://github.com/pytorch/pytorch/pull/101173). Yukio is thinking of implementing a bisection mechanism for translation validation, so we can find the first guard that actually caused a correctness problem.
  • Export for QAT. QAT wants to do whole-graph transformations on a pre-autograd FX graph. Export sort of supports this with pre_dispatch export. What is likely going to happen is that this turns into the IR format that export is going to use. Pre-autograd functionalization is unlikely to happen; you only get some mild normalization. It is still unresolved how to work this into the overall QAT workflow API, since export isn’t really keen on exposing this mid-point IR (which is kind of incoherent).
  • Notable bug fixes.
  • Notable new bugs.
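The 32-bit indexing point above can be made concrete with a toy sketch. This is not Inductor’s actual ValueRanges class (which lives in torch.utils._sympy.value_ranges and works over sympy expressions); it is a hypothetical plain-Python interval type showing why a refined upper bound on an unbacked symint (e.g. nonzero’s output size bounded by the original numel) is what lets you statically prove that 32-bit indexing is safe:

```python
# Hypothetical sketch of ValueRanges-style interval reasoning (not the
# real Inductor implementation). To emit 32-bit indexing we must prove
# numel < 2**31; with an unbacked symint i0 we only know a range.

class Interval:
    def __init__(self, lo, hi):
        self.lo, self.hi = lo, hi

    def __mul__(self, other):
        # Interval multiplication: take min/max over endpoint products.
        prods = [self.lo * other.lo, self.lo * other.hi,
                 self.hi * other.lo, self.hi * other.hi]
        return Interval(min(prods), max(prods))

# i0 = nonzero(x).shape[0], where x has 1024*1024 elements, so the
# refined upper bound for i0 is the (static) original numel.
i0 = Interval(0, 1024 * 1024)
row = Interval(128, 128)          # a static dimension
numel = i0 * row
assert numel.hi < 2**31           # safe to use 32-bit indexing
```

Without the refined upper bound on i0, the range would be unbounded and the compiler would have to either guard (impossible without a hint) or fall back to 64-bit indexing.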
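The KJT tensor_split problem above can also be illustrated in eager mode. The sketch below (hypothetical helper names, not torchrec code) shows the invariant sum(lengths) == values.shape[0] that our symbolic reasoning cannot yet prove; under the proposed irrefutable-guard scheme, such a guard would be demoted to the runtime check written explicitly here:

```python
import torch

def split_jagged(values, lengths):
    # The guard export cannot prove symbolically:
    # sum(lengths) == values.shape[0]. It is "irrefutable": if it
    # fails, the tensor_split below would have errored anyway, so it
    # is safe to turn it into a runtime assertion instead of a guard.
    if int(lengths.sum()) != values.shape[0]:
        raise RuntimeError("lengths must sum to values length")
    # Convert per-feature lengths into split offsets.
    offsets = torch.cumsum(lengths, dim=0)[:-1]
    return list(torch.tensor_split(values, offsets.tolist()))

values = torch.arange(6)
lengths = torch.tensor([2, 1, 3])
parts = split_jagged(values, lengths)
# parts: tensors [0, 1], [2], [3, 4, 5]
```

When lengths come from a data-dependent read, each `lengths[i]` becomes an unbacked symint under tracing, which is where the scale concern (hundreds of unbacked symints per model) comes from.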
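The pytree direction for KJT inputs can be sketched with a toy class. ToyKJT below is a hypothetical stand-in for torchrec’s KeyedJaggedTensor (the real PR registers the actual class); the point is that once registered, a structured input flattens into plain tensors, which is also the shape of the proposed “FQN → Tensor” calling convention:

```python
import torch
import torch.utils._pytree as pytree

class ToyKJT:
    # Hypothetical stand-in for torchrec's KeyedJaggedTensor.
    def __init__(self, values, lengths):
        self._values = values
        self._lengths = lengths

# The registration entry point was renamed across releases; accept both.
register = getattr(pytree, "register_pytree_node", None) \
    or pytree._register_pytree_node
register(
    ToyKJT,
    lambda kjt: ([kjt._values, kjt._lengths], None),   # flatten
    lambda children, ctx: ToyKJT(*children),           # unflatten
)

inputs = {"id_list_features": ToyKJT(torch.arange(6),
                                     torch.tensor([2, 1, 3]))}
leaves, spec = pytree.tree_flatten(inputs)
# leaves are the two underlying tensors; under the FQN convention they
# would be addressed as id_list_features._values, id_list_features._lengths.
```

This is only a sketch of the mechanism: the indirect fix described above would bypass pytree specs entirely and key the flat tensors by their Source.name() strings.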

CI skips. -3, -1, -1, -2 (no change).

Training dashboard (as of 1da4115702). This week on HUD

| Metric   | Torchbench    | Huggingface | TIMM models   | Dynamic       |
|----------|---------------|-------------|---------------|---------------|
| Passrate | 92%, 59/64    | 96%, 44/46  | 98%, 59/60    | 100%, 8/8     |
| Speedup  | 1.54x → 1.56x | 1.69x       | 1.28x → 1.35x | 1.97x → 2.04x |
| Comptime | 81s           | 107s → 108s | 142s          | 38s → 39s     |
| Memory   | 0.79x         | 0.96x       | 1.01x         | 0.69x         |

Inference dashboard (as of 1da4115702). This week on HUD

| Metric   | Torchbench    | Huggingface   | TIMM models   | Dynamic       |
|----------|---------------|---------------|---------------|---------------|
| Passrate | 88%, 65/74    | 98%, 45/46    | 100%, 60/60   | 58%, 7/12     |
| Speedup  | 1.55x → 1.54x | 1.78x → 1.77x | 1.79x → 1.80x | 3.03x → 3.08x |
| Comptime | 35s → 36s     | 44s → 45s     | 36s           | 72s → 75s     |
| Memory   | 0.68x         | 1.11x         | 0.84x → 0.85x | 0.87x         |

All looks within noise.

What’s next

  • Rewriting export input/output spec flattening
  • Irrefutable guards
  • Generally more pushing on KJT stuff