State of symbolic shapes branch

State of symbolic shapes: Mar 5 edition

Previous update: State of symbolic shapes branch - #43 by ezyang

Executive summary

The tl;dr:

  • Training support has landed in master, BUT you need to set torch._functorch.config.use_dynamic_shapes = True to use it (will be fixed soon)
  • specialize_int_float is renamed to specialize_int, and it now actually works. It is temporarily defaulted to True (but this will change soon)
  • Only int inputs will be allowed in inductor; floats must be passed as 0d tensors. More at Handling int/float inputs/intermediates in Inductor - Google Docs

The details:

  • Training support has hit master… sort of. Horace’s patch to add training support passed CI and has landed, but it turns out our testing was insufficient and you need to manually turn on torch._functorch.config.use_dynamic_shapes = True for it to actually work (see the sketch after this list). This is fixed in Fix training enablement in AOTAutograd by ezyang · Pull Request #95975 · pytorch/pytorch · GitHub, which hasn’t landed yet. A large number of models pass, but there are also many models which fail or have accuracy failures; we plan on burning these issues down shortly.
  • Handling int/float inputs/intermediates in Inductor. One of the major points of contention is how exactly non-Tensor ints/floats are supposed to be passed to Inductor. Should they be passed as SymInts or 0d tensors? We’ve finally resolved the question at Handling int/float inputs/intermediates in Inductor - Google Docs. The overall idea is that sizevar computation should be expressed as non-Tensor computation (or, more specifically, with sympy expressions), but everything else should just be Tensor compute, for ease of lowering. In practice, this means ints are passed as ints (as we must track their sympy expressions in case they are used in a sizevar computation), but we have made a policy decision that input floats can NEVER be used in sizevar computation, and thus they can always be passed as 0d tensors (see the sketch after this list).
  • Int unspecialization actually works now. Previously, there was a knob specialize_int_float which, hypothetically, if set to False, allowed you to generate Dynamo graphs which didn’t specialize on int/float inputs. In practice, this knob didn’t actually do anything, as every integer between 0-17 was specialized anyway. This matters in practice; for example, the overspecialization in torch._dynamo.exc.Unsupported: dynamic shapes: arange · Issue #93468 · pytorch/pytorch · GitHub was due to Dynamo deciding that a batch size of 10 was small enough to specialize on. Make int unspecialization actually work fixes that problem. However, in the process this uncovered a pile of bugs in Inductor regarding unspecialized ints. Right now, int unspecialization is not turned on by default, but we intend to shift to it soon, allowing for some regressions in CI (see the sketch after this list).
  • We now allow implicit specialization via int conversion. Previously, if you ran int(symint), we would raise a warning saying that this would cause a specialization, and that if you really wanted it, you should explicitly call guard_int. We have now relaxed this restriction: we will implicitly convert SymInts to ints and introduce guards as necessary. This switch is profitable because there are a number of APIs which cannot, even in principle, support dynamic shapes, and allowing these implicit conversions makes these APIs work (as opposed to failing, as they do today). See the sketch after this list.
  • Guards depending on unbacked SymInts. @tugsbayasgalan took the new nonzero export support in master, and he noticed one major gap: in some cases, we would generate guards that depended on unbacked SymInts, which is a big no-no, because guards must be executed at the very beginning of a model, but an unbacked SymInt may only become known midway through execution. The fix for this, Don't generate guards that refer to unbacked SymInts by ezyang · Pull Request #95732 · pytorch/pytorch · GitHub, revealed that there are a number of guards with an odd property: (1) if you replace the backed shape variables (s0, s1, …) with their size hints, you can statically determine what the guard should evaluate to given the example inputs, but… (2) without this replacement, it’s not clear if the guard is true or not. For example, Ne(36*i0, 36) is trivially True when i0 > 1, but the real expression in context is Ne(i0*((s2//2 - 1)//8)**2 + 2*i0*((s2//2 - 1)//8) + i0, ((s2//2 - 1)//8)**2 + 2*((s2//2 - 1)//8) + 1) (which Tugsuu also thinks is True, but sympy can’t figure it out; see the sketch after this list). Another example is Eq(i0*s3**2, 9*i0), which should result in a guard that s3 = 3, but sympy once again cannot figure it out. Our current hypothesis is that many of these shape variables are actually static at the end, but at the time the guard is created we don’t know what they are; so either deferring the checks until later or encouraging users to set assume_static_by_default = True should help. Tugsuu will validate this next week.
  • PSA: size_hint vs evaluate_expr. We found out this week that some internal teams are making use of ShapeEnv, and were misusing evaluate_expr. Contrary to what its name suggests, this not only evaluates an expression, it ALSO installs a guard. If you want to peek at the value of an expression without guarding, you should use size_hint instead (see the sketch after this list).
  • State of real world model enablement.
    • Mark Saroufim has tentatively volunteered to add LLaMa and InstructPix2Pix to torchbench, which will help us track whether they work or not with dynamic shapes.
    • OpenNMT’s arange minimal repro no longer overspecializes, but it now fails in Inductor with assert isinstance(numer, int) and isinstance( at torch/_inductor/utils.py:83. This failure also affects fastNLP_Bert, speech_transformer and yolov3 inductor inference.
    • No updates: MaskRCNN, Detectron2, wav2vec2, fused boolean mask update
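
A few sketches for the items above follow. First, the training enablement workaround: a minimal sketch of what flipping torch._functorch.config.use_dynamic_shapes = True looks like in practice. The toy model, optimizer, and batch sizes are placeholders, and torch.compile(dynamic=True) is assumed as the entry point; only the config flag itself comes from the item above.

```python
import torch
import torch._functorch.config

# Workaround from the training item above: until PR #95975 lands, AOTAutograd
# only picks up dynamic shapes for training if this flag is flipped manually.
torch._functorch.config.use_dynamic_shapes = True

model = torch.nn.Linear(8, 8)                      # placeholder model
opt = torch.optim.SGD(model.parameters(), lr=0.01)
compiled = torch.compile(model, dynamic=True)      # assumed entry point

for step in range(3):
    x = torch.randn(4 + step, 8)   # varying batch size exercises dynamic shapes
    loss = compiled(x).sum()
    loss.backward()
    opt.step()
    opt.zero_grad()
```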
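
To illustrate the int/float policy at the user level (this is a made-up function showing the division of labor, not the internal Inductor calling convention): ints that can feed size computations stay as plain ints/SymInts, while floats are only ever tensor compute, so they can be handed over as 0d tensors.

```python
import torch
import torch.nn.functional as F

def scale_and_pad(x, scale, pad):
    # `pad` feeds a size computation, so it stays an int (a SymInt when traced
    # symbolically); `scale` is pure pointwise compute, so a 0d tensor is fine.
    return F.pad(x * scale, (0, pad))

x = torch.randn(4, 8)
scale = torch.tensor(0.5)   # float input passed as a 0d tensor
pad = 2                     # int input passed as a plain int
out = torch.compile(scale_and_pad, dynamic=True)(x, scale, pad)
```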
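
For the int unspecialization item, the knob is just a config flag, sketched below. It currently defaults to True, and flipping it still runs into the Inductor bugs mentioned above, so treat this as an illustration of the intended behavior rather than a recommendation.

```python
import torch
import torch._dynamo

# Renamed knob from the item above (formerly specialize_int_float).
torch._dynamo.config.specialize_int = False

@torch.compile(dynamic=True)
def f(x, n):
    # `n` is a plain Python int input.  With specialize_int = False it should
    # be traced symbolically instead of burned into the graph as a constant,
    # so the two calls below ideally share one compiled graph.
    return x * n + n

print(f(torch.randn(5), 10))
print(f(torch.randn(5), 37))
```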
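
For the implicit int conversion item, here is roughly what the behavior change looks like under symbolic tracing; make_fx with tracing_mode="symbolic" is used only to get SymInt sizes, and the traced function is made up.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    # Under symbolic tracing, x.shape[0] is a SymInt.  Calling int() on it used
    # to warn and ask for an explicit guard_int(); it now converts implicitly
    # and installs a guard specializing the trace to this particular size.
    n = int(x.shape[0])
    return x.new_zeros(n)

gm = make_fx(f, tracing_mode="symbolic")(torch.randn(4, 3))
print(gm.graph)
```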
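
For the unbacked SymInt item, the sympy limitation can be reproduced with plain sympy (standalone, not the actual ShapeEnv machinery), using the first example expression from that item:

```python
import sympy

i0, s2 = sympy.symbols("i0 s2", integer=True, positive=True)

k = (s2 // 2 - 1) // 8   # sympy renders // as floor()
lhs = i0 * k**2 + 2 * i0 * k + i0
rhs = k**2 + 2 * k + 1

# The difference factors as (i0 - 1) * (k + 1)**2, so the inequality holds
# whenever i0 > 1 (and k != -1), but sympy does not resolve it to True.
print(sympy.simplify(sympy.Ne(lhs, rhs)))
```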
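
Finally, a sketch of the size_hint vs evaluate_expr PSA. Obtaining a ShapeEnv and a sympy expression over its size variables is elided here; the two method calls are the ones named in the PSA, and the helper function is purely illustrative.

```python
import sympy
from torch.fx.experimental.symbolic_shapes import ShapeEnv

def peek_then_maybe_guard(shape_env: ShapeEnv, expr: sympy.Expr):
    # size_hint: look up the concrete value implied by the example inputs
    # WITHOUT recording anything in the ShapeEnv.
    hint = shape_env.size_hint(expr)

    # evaluate_expr: same lookup, but it ALSO installs a guard pinning the
    # result; only use it when you want that specialization recorded.
    guarded = shape_env.evaluate_expr(expr)
    return hint, guarded
```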

The numbers:

  • Model status on master.
    • aot_eager inference: 0 (+1 WoW), Joel’s PR made it in.
    • aot_eager training: 0 (unchanged). No regressions!
    • inductor inference: -4 (+1 WoW). swin fix made it in.
    • inductor inference unspec: -10 (NEW!). This number tracks inductor inference with specialize_int = False, now that unspecialization actually does something. We plan to subsume the old inductor inference number with this one, as unspecialization is important for avoiding overspecialization across graph breaks in practice. We should probably switch the aot_eager stats to this as well.
    • inductor training: -42ish (NEW!). We don’t have official CI numbers yet, and an important bug fix hasn’t landed (Fix training enablement in AOTAutograd by ezyang · Pull Request #95975 · pytorch/pytorch · GitHub); the number here is based on a sweep with that PR patched in.
  • Opinfo tests on symbolic shapes.
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 562 passed (unchanged), 523 skipped (unchanged), 195 xfailed (unchanged)
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 320 passed (+2 WoW), 147 skipped (unchanged), 173 xfailed (-2 WoW)
  • Graph breaks on master. 0ish (+3 WoW). A sweep on 2/23 revealed extra graph breaks on hf_Longformer and AllenaiLongformerBase, but Voz manually confirmed that the static model also graph breaks. @wconstab added the ability for CI to record graph breaks (Add dynamo graph break stats to CI by wconstab · Pull Request #95635 · pytorch/pytorch · GitHub), so hopefully we can just start testing the number of graph breaks in CI and ensure we avoid regressions this way.
  • Tracing cost of enabling dynamic shapes (aot_eager). Mean: 20s (-1s WoW), Max: 240s (-14s WoW). This looks within noise.

Known problems

  • Inductor guards are silently discarded; this could cause accuracy problems
  • CI is not testing if we actually successfully generate dynamic code, so we could silently regress this (right now, we validate the generated code by manual inspection)
  • We are not testing performance of dynamic shapes code; it could be heinously slow (fixing this is blocked on GCP dashboard runs)
  • Minifier does not work with dynamic shapes
  • Split reductions in Inductor do not work with dynamic shapes
  • Python profiling gives misleading results for what is causing slowdowns

What’s coming next?

  • ezyang: probably a mix of unspec int / training burndown, and discovering more issues from our E2E models. The assert isinstance(numer, int) and isinstance( failure is a particular blocker for OpenNMT.
  • Chillee: returning to dynamic shapes, probably something symbolic reasoning related, also some training burndown
  • msaroufim: I asked Mark to add LLaMa and InstructPix2Pix to torchbench; we’ll see if he gets to it or not
  • bdhirsh: still aotautograd refactor for training export, as well as per-dispatch key mode stacks and then torchquant/DTensor PT2 support (was working on reference cycle issue this week)
  • voz: inline shape refinement with torch.constrain, make dynamo.export not run the real model, use original arg names on export (Use original arg names if possible by voznesenskym · Pull Request #95898 · pytorch/pytorch · GitHub)