State of symbolic shapes branch

State of symbolic shapes branch: Dec 12 edition

Previous update: State of symbolic shapes branch - #19 by ezyang

Commit ID at time of writing: bcb284d77fe865373b2f1617867320fb32ea68af

Executive summary

We’re on master now, peeps! We are officially retiring the symbolic-shapes branch. There are still some changes that need to be merged to master (see the “Merge to master” sheet in Symbolic shapes work items tracker - Google Sheets), but the major fixes for shape guard creation have landed on master. All that remains on the branch is some QOL fixes (in particular, the debug interpreter is no longer on by default), a little bit of op coverage, and some experimental code (especially for inductor integration) that needs to be rewritten anyway.

Previous branch diff: 68 files changed, 2612 insertions(+), 554 deletions(-)
Current branch diff: 0 files changed, 0 insertions(+), 0 deletions(-)

Notable bugs

  • It turns out checkpointing doesn’t operate on ShapeEnv, but it should. Dynamo uses checkpoints to roll back its internal state when executing an instruction (or many instructions, in the case of an inlined function call) fails, so that it can pretend those instructions never executed. However, because ShapeEnv isn’t checkpointed, shape guards created during the rolled-back instructions still end up getting installed. This can result in a hard error if the guards refer to variables that the outer context doesn’t know about (we think hf_Reformer and swin_base_patch4_window7_224 are affected by this). Checkpointing the ShapeEnv performantly is nontrivial, as we refine the context with equalities and use them to drive sympy simplification, all of which would need to be undone. This bug is still unfixed.
  • Preserve original GraphArgs for shape guard codegen and Rewrite dynamo cond() handling to not recursively call export are both fixes for pretty interesting bugs, if I do say so myself. Go check out their PR descriptions for more details.
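To make the checkpointing bug concrete, here is a toy sketch. It is not Dynamo’s ShapeEnv or checkpoint API; ToyShapeEnv, checkpoint, restore, and speculate are hypothetical names. It shows that a guard recorded during a failed inlined call leaks unless the guard list (and the learned equalities) are part of the snapshot. The real difficulty, undoing sympy simplification that already consumed those equalities, is not modeled here.

```python
import copy


class ToyShapeEnv:
    """Hypothetical stand-in for ShapeEnv: records guards and learned equalities."""

    def __init__(self):
        self.guards = []        # installed guard expressions (strings)
        self.replacements = {}  # symbol -> expression learned from equality guards

    def add_equality(self, sym, expr):
        # The real ShapeEnv would also use this to drive sympy simplification.
        self.replacements[sym] = expr
        self.guards.append(f"Eq({sym}, {expr})")

    def checkpoint(self):
        # Copy mutable state so later mutations do not leak into the snapshot.
        return (list(self.guards), copy.deepcopy(self.replacements))

    def restore(self, state):
        guards, replacements = state
        self.guards = list(guards)
        self.replacements = copy.deepcopy(replacements)


def speculate(env, instructions, checkpoint_env=True):
    """Run instructions; on failure roll back, including the env if requested."""
    state = env.checkpoint() if checkpoint_env else None
    try:
        for instr in instructions:
            instr(env)
    except RuntimeError:
        if checkpoint_env:
            env.restore(state)
        return False
    return True


def inlined_call(env):
    env.add_equality("s1", "s0")       # guard mentions s1, local to the inlined frame
    raise RuntimeError("graph break")  # ...and then the inlined call fails


leaky = ToyShapeEnv()
speculate(leaky, [inlined_call], checkpoint_env=False)
fixed = ToyShapeEnv()
speculate(fixed, [inlined_call])
print(leaky.guards, fixed.guards)  # ['Eq(s1, s0)'] []
```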

What’s made it to master this week?

ezyang

nkaretnikov

voz

SherlockNoMad

What’s coming next?

By Person:

  • voz: Guard refactor in dynamo
  • ezyang: burn down symbolic shapes, fix bugs, work on exporting all shape expressions to form a benchmark, aot autograd default api maybe?
  • bdhirsh: continue fixing AOTAutograd v2 follow up bugs
  • jbschlosser: merge to master tasks, burn down symbolic shapes
  • unallocated: inductor integration

Our north star:

  • All benchmark models are passing aot_eager and inductor training on branch
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing
  • Dynamic shapes on by default for developers / users

State of symbolic shapes: Dec 19 edition

Previous update: State of symbolic shapes branch - #20 by ezyang

Commit ID at time of writing: 212873c615dd3455a24d390605335aeeebd76236

Executive summary

This week, we turned on dynamic shapes with aot_eager on CI in the inductor job. Compared with static shapes aot_eager, we have a difference of only 17 failures on master! Inductor remains in bad shape on master, as we are still waiting on @Chillee to submit his PR with fixes.

In other news, @ezyang has released a benchmark for reasoning on shape computation: GitHub - ezyang/SMT-LIB-benchmarks-pytorch-shapes: SMT-LIB benchmarks for shape computations from deep learning models in PyTorch If you work on SMT solvers or like symbolic reasoning systems, check it out! It offers an easy way to test out new ideas about how to symbolically reason over shape compute. We still have a number of infinite loops in Sympy, although this week we are now just suppressing all stack overflows induced by Sympy.

  • Model training status on master. See also Symbolic shapes work items tracker - Google Sheets
  • OpInfo tests on symbolic shapes.
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 513 passed (+5 WoW), 522 skipped (no change), 227 xfailed (-3 WoW)
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 286 passed (+5 WoW), 142 skipped (+1 WoW), 203 xfailed (-5 WoW)

Notable bugs

  • Despite overhauling ShapeEnv guard production in Dynamo two weeks ago, there were still stragglers that had to be addressed this week. The main source of problems was a mismatch between when we added a tensor to GraphArgs (because it is an FX graph input) and when we allocated dynamic shapes for a tensor (at which point we may need to determine the source of its symbolic shape). This led to more refactoring in Dynamo so that we could guarantee that whenever a tensor had symbolic shapes allocated for it, we also tracked it for the purposes of guard creation. This fixed all of the bugs except one(!), for which @ezyang has an open PR set (involving more refactoring).
  • Assert for functional graph is FINALLY in master, and it caught more bugs in inductor lowerings when it landed. Hooray for more stringent asserts.

What’s made it to master this week?

ezyang

jbschlosser

voz

bdhirsh

What’s coming next?

By Person:

  • voz: vacation
  • ezyang: vacation
  • bdhirsh: continue AOTAutograd v2 follow up
  • jbschlosser: merge to master and burn down
  • Chillee: inductor integration (apparently, Horace “has a few fixes” but they’re still not posted yet)

Our north star:

  • All benchmark models are passing aot_eager and inductor training on branch
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators
  • All OpInfo tests passing
  • Dynamic shapes on by default for developers / users

(SMT/verification geek here)

Can you give a bit more context about this?
Are these SMT files created by Sympy or is this an alternative code path in PyTorch that bypasses Sympy entirely?

The goal of the symbolic shapes code, I think, is twofold:

  • Do type/shape checking. This is a simple satisfiability check to ensure the program is type-safe
  • Compute symbolic shapes for each operator, with a simplified expression (whatever that means). This is not easily solvable with SMT solvers.

Also, could you give an example of a complex shape computation that requires such heavy machinery?

Happy to help with this stuff if needed, but I need a bit more context.
I’ve started to play with symbolic stuff today, with partial success. I still see some crashes, especially when mixing with torch.compile. I’m probably doing something that is not supported…

Thanks,
Nuno

Happy to! The SMT files here bypass sympy entirely; they contain exactly the shape computations that the original user program / operators had. The asserts correspond to places where people (usually operators) did control flow.

The goal of the reasoning code is a little difficult to express exactly. The smt2 instances test one particular thing that is easy to formulate for SMT solvers: given a sequence of shape computations and asserts, find an assignment of input shapes that satisfies all the asserts. This could be used, for example, to take a model and regenerate a small version of it (with small parameters and inputs) for easy testing. But IMO this is not really the most pressing problem for dynamic shapes in PyTorch as it stands. The PyTorch reasoning code tends to do two other things: find simple canonical forms for size expressions, so that we can eventually simplify the loop bounds / indexing computations we generate in inductor, and determine whether guards are redundant (since we spend a lot of time testing conditionals, which naively must be retested before we can reuse compiled code). And we would like to do all of this in a simple Python implementation, which precludes more typical approaches like throwing LLVM’s optimizer at the problem. “Simple” canonical forms are not easy to express in smt-lib2 nomenclature; redundant guards are technically expressible but awkward (and better suited as a compiler benchmark).

We currently have a baseline sympy+unification implementation (which I have provided a crappy hookup for), which also applies some extra simplifications I haven’t encoded in the problem instances yet (duck sizing, 0/1 specialization). If you treat sympy as a black box, it’s pretty simple, but empirically our biggest problem is that it is slow. I’ll post some measurements I took, but in some pathological cases we can spend minutes just crunching sympy simplifications. As a team, we have a bit of a disagreement about how to proceed. I kind of want to chuck sympy entirely and roll our own domain-specific algebra system, but @Chillee thinks this is too much, and that we should be able to patch over the specific sympy badness more quickly than we could rewrite everything from scratch. We also kind of need an entailment system, which an SMT solver would provide, but we don’t have a good sense of how to do this integration. Should Z3 really be a mandatory dependency for dynamic shapes?
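The satisfiability question the smt2 instances pose can be illustrated without any solver. This is a brute-force sketch, not how an SMT solver works (a solver searches far more cleverly), and the shape computation and asserts below are made up for illustration, loosely modeled on a strided window size with floor division.

```python
from itertools import product


def find_small_shapes(n_inputs, asserts, max_size=8):
    """Return the first input-size tuple (lexicographic order) satisfying all asserts."""
    for sizes in product(range(1, max_size + 1), repeat=n_inputs):
        if all(check(*sizes) for check in asserts):
            return sizes
    return None


# Hypothetical traced shape compute: a width-3, stride-2 window over dim s0,
# followed by a reshape that requires s1 == 2 * out.
def out_size(s0):
    return (s0 - 3) // 2 + 1


asserts = [
    lambda s0, s1: out_size(s0) >= 2,       # output must have at least 2 elements
    lambda s0, s1: s1 == 2 * out_size(s0),  # reshape compatibility
]

print(find_small_shapes(2, asserts))  # (5, 4)
```

The first tuple found is the smallest model inputs that exercise every traced assert, which is the "regenerate a small version of the model" use case described above.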

Re crashes, post up what you are doing. Inductor still doesn’t work on master but you should be able to play around with aot_eager

These are the timings I took on the baseline implementation, showing which models have unusually horrible perf in sympy in symbolic shapes today.

Z3 has a decent Python API. So, in theory this can be integrated easily.
Z3 is a dependency for a lot of things these days, including clang’s static analyzer (optional, but a lot of distros ship it enabled).

Z3 can also simplify expressions. It has a bunch of tactics that can be applied to expressions to simplify them. It can also do quantifier elimination, which can be useful for expression simplification.

I understand the guards & bounds checks simplification problems. Those can be mapped into entailment checks as you say.
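As a rough illustration of that mapping: a new guard is redundant if every assignment satisfying the already-installed guards also satisfies it. The sketch below checks this by exhaustive enumeration over a bounded range; that bound is an assumption of the sketch, and unlike an SMT entailment query it proves nothing about sizes beyond max_size. is_redundant is a hypothetical name.

```python
from itertools import product


def is_redundant(existing, new_guard, n_vars, max_size=64):
    """True if every assignment in [1, max_size]^n_vars that satisfies all
    `existing` guards also satisfies `new_guard` (bounded check, not a proof)."""
    for env in product(range(1, max_size + 1), repeat=n_vars):
        if all(g(*env) for g in existing) and not new_guard(*env):
            return False  # found a counterexample
    return True


existing = [lambda s0, s1: s0 == s1, lambda s0, s1: s0 >= 2]
print(is_redundant(existing, lambda s0, s1: s1 > 1, 2))  # True
print(is_redundant(existing, lambda s0, s1: s1 > 2, 2))  # False (s0 = s1 = 2)
```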

What I don’t understand yet is why these shape inference expressions are so complicated. I was just looking at Torchy’s code that does shape+stride inference, and it doesn’t look that complicated (there are min, max, equalities, mul, add, reduce). I think the worst part is that in a lot of cases you need to know the number of dimensions. If that’s also symbolic, things become very non-trivial. Do you need to give an upper bound for the dimensions? Or do you generate predicates that are valid in the unbounded case?

I’m happy to help with this stuff (designing, optimizing, whatever) if needed. I’ve experience with symbolic reasoning.

Regarding crashes:
this one issues a warning:

import torch

@torch.compile
def fn(a):
    b = a * a
    return b

# [WARNING] Unsupported: meta converter nyi with fake tensor propagation
print(fn(torch.tensor([1.0, 2.0], device="meta")))

Assert fails:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

@torch.compile
def fn(a):
    b = a * a
    return b

with FakeTensorMode() as mode:
    print(fn(torch.empty([2,3])))

#   File "_dynamo/variables/builder.py", line 611, in wrap_tensor
#    assert type(value) in (torch.Tensor, torch.nn.Parameter)

Assert fails; could give a nice error message:

# without allow_meta=True
with FakeTensorMode() as mode:
    print(fn(torch.empty([2,3], device="meta")))

# File "_subclasses/fake_tensor.py", line 587, in __init__
#    assert device.type != "meta"

Another crash:

import torch
import torch._dynamo
import torch._functorch.config

torch._dynamo.config.dynamic_shapes = True
torch._functorch.config.use_dynamic_shapes = True

@torch.compile
def fn(a):
    b = a * a
    return b

print(fn(torch.tensor([1.0, 2.0])))

#   File "sympy/core/cache.py", line 70, in wrapper
#    retval = cfunc(*args, **kwargs)
# TypeError: unhashable type: 'SymInt'
# RuntimeError: Trying to extract a concrete int out of a symbolic int

All I was trying was to print some complicated symbolic shape expressions so I could understand the problem a bit better. But then I hit all these crashes, so I must be doing something very wrong…

I definitely think Z3 as an optional dependency is not too hard a sell. The turning point would be if we could make our design substantially simpler if we could assume Z3 was always available; then the calculation would turn to whether or not PyTorch by default has Z3 as a dep, which, enhhhhh. It’s easier for a distro to do it probably.

Nope, everything is fixed dimensionality. So honestly the computations in the smt2 files are not that complicated, but they can be very repetitive because we are just tracing the real shape compute PyTorch does which was not necessarily written in a nice way for symbolic analysis.

A big gap I see from the operators you quote here is floor division (shows up a bit in pooling operations) and modulus (not sure where these come from actually.)

Definitely looking for collabs!

These are crashing for silly reasons haha. In order:

  1. Meta tensors intentionally don’t work with fake tensor (which is what PT2 will do.) In principle they actually should work fine but real world user code doesn’t actually need to optimize code computing on meta tensors, and when we were working on fake tensor it was usually a bug to try to fakeify a meta tensor, soooo yeah. @eellison we probably should fix this eventually
  2. You don’t need to explicitly fakeify before calling torch.compile; it does that for you. So what happened here is you said hey PT2 compile this with a tensor subclass input and this doesn’t work. Shouldn’t be an assert though; we should file a bug on this.
  3. This is (1), and yeah, let’s beef up the message
  4. This one I think is because inductor with dynamic shapes is busted. So you need to make sure you don’t use inductor. Easiest is to use torch._dynamo.optimize("aot_eager") instead of torch.compile

For playing around with simple examples, your best bet is to look at the tests with Symbolic in their class name in test/test_proxy_tensor.py. In particular, most of these call make_fx with tracing mode symbolic. This will give you the smallest slice of the system that is doing something interesting with dynamic shapes.


Thank you!

Just a couple more crashes (I know, I’m a magnet for them…):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def fn(a):
    b = a * a
    if b.sum():
        return a
    return b

print(fn(torch.tensor([1.0, 2.0])))
traced_f = make_fx(fn, tracing_mode="symbolic")(torch.tensor([1.0, 2.0]))
print(traced_f)

# RuntimeError: tried to get Double out of SymFloat

If using this instead:

def fn(a):
    b = a * a
    if b.sum() >= 1:
        return a
    return b

# NotImplementedError: local_scalar_dense/item NYI for torch.bool

I’ve attempted to fix it by patching fake tensor’s local_scalar_dense:

    elif is_integer_dtype(arg.dtype) or is_boolean_dtype(arg.dtype):
        return fake_mode.shape_env.create_unbacked_symint()

But no luck; still crashes (tried to get Long out of SymInt).

Re the latest crashes, these are all due to attempting to do control flow on data dependent values (in this case, the float in the tensor). This is expected but we can make the error messages better.

State of symbolic shapes: Jan 11 edition

Previous update: State of symbolic shapes branch - #22 by ezyang

Commit ID at time of writing: f6c7cf1bf579cc42ea7e21bd557168625648a3e9

Executive summary

Welcome back from holiday and PSC, folks! The holidays were sleepy, but not that sleepy: @jbschlosser landed some nice coverage wins (+2 working models on master) and @tugsbayasgalan has been working on a number of dynamic shape fixes that show up in PyTorch Edge export use cases. Since it’s a new year, you might be wondering what you ought to be working on. This status post is here to help you figure that out.

High priority things that need to be done (excluding export):

Things to be done sourced from PyTorch Edge export workstream (Meta only):

Things to be done sourced by generic export workstream (@SherlockNoMad)

Status report:

  • Model training status on master. See also Symbolic shapes work items tracker - Google Sheets
    • aot_eager: -15 (+2 MoM). These newly working models are thanks to work by @jbschlosser (see the skip list)
    • inductor: still really broken on master :rotating_light::rotating_light::rotating_light:
  • OpInfo tests on symbolic shapes.
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 516 passed (+3 MoM), 522 skipped (no change), 224 xfailed (-3 MoM)
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 291 passed (+5 MoM), 143 skipped (+1 MoM), 197 xfailed (-5 MoM)

Guard simplification

As you can see from the chatter before this post, Nuno is working on guard simplification. Nuno looked at resnet, and found that range analysis and a quadratic solver would solve most of the nontrivial constraints we were generating, greatly simplifying the guards we generate. However, there were also suggestions that we could simplify the guards at generation time.

Example constraint: Eq((s2 - 1)//16 + 1, 1). This is produced by empty_strided calling compute_contiguous on the passed-in strides. Specifically, the constraint comes from the size_d != 1 test when testing contiguity. We should avoid guarding here, à la subclass_zoo/dynamic_strides.ipynb at main · albanD/subclass_zoo · GitHub cell 18 (but we need SymBool for this!)

T z = 1;
// NB: make sure we do signed arithmetic
for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
  const auto& size_d = sizes[d];
  if (size_d != 1) {
    if (strides[d] == z) {
      z *= size_d;
    } else {
      is_contiguous = false;
      break;
    }
  }
}
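For experimenting with which comparisons fire, here is a Python transliteration of the loop above operating on plain ints. It is a sketch, not the c10 function: I added an early return for empty tensors, which I believe the full C++ function performs before this loop, but treat that as an assumption. With symbolic sizes, the size_d != 1 comparison is what produces guards like Eq((s2 - 1)//16 + 1, 1).

```python
def is_contiguous(sizes, strides):
    """Python port of the stride-walking contiguity check."""
    if any(s == 0 for s in sizes):
        return True  # assumption: empty tensors count as contiguous
    z = 1
    for d in range(len(sizes) - 1, -1, -1):
        size_d = sizes[d]
        if size_d != 1:              # with a symbolic size, this comparison guards
            if strides[d] == z:
                z *= size_d
            else:
                return False
    return True


print(is_contiguous([2, 3], [3, 1]))  # True: row-major
print(is_contiguous([2, 3], [1, 2]))  # False: transposed
```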

Example constraint: (s2 - 1)//2 + 1 < (s2 - 1)//2**2 + 2*(s2 - 1)//2 + 1. This comes from:

  File "/Users/ezyang/Dev/pytorch-cpu/torch/_prims/__init__.py", line 348, in _elementwise_meta
    strides = utils.compute_elementwise_output_strides(*args_)
  File "/Users/ezyang/Dev/pytorch-cpu/torch/_prims_common/__init__.py", line 407, in compute_elementwise_output_strides
    comparison = should_swap(perm[dim0], perm[dim1])
  File "/Users/ezyang/Dev/pytorch-cpu/torch/_prims_common/__init__.py", line 387, in should_swap
    if stride_a < stride_b:

The easiest fix is probably to make sure we don’t run the sorting algorithm when we have contiguous inputs. But even better would be to introduce a contiguity guard (which tests that a tensor is contiguous), which should be able to eliminate these guards entirely.
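For the first example constraint, range analysis alone can often decide the guard: for positive s2, Eq((s2 - 1)//16 + 1, 1) holds exactly when 1 <= s2 <= 16. Below is a minimal interval-arithmetic sketch of that idea. It is not Nuno’s implementation; the Interval class is hypothetical and supports only the operators this expression needs, assuming positive constant divisors.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Interval:
    lo: int
    hi: int

    def __add__(self, c: int) -> "Interval":
        return Interval(self.lo + c, self.hi + c)

    def __sub__(self, c: int) -> "Interval":
        return Interval(self.lo - c, self.hi - c)

    def __floordiv__(self, c: int) -> "Interval":
        # Floor division by a positive constant is monotone non-decreasing.
        assert c > 0
        return Interval(self.lo // c, self.hi // c)


def decide_eq(iv: Interval, value: int):
    """True/False if Eq(expr, value) is decided by the range alone, else None."""
    if iv.lo == iv.hi == value:
        return True
    if value < iv.lo or value > iv.hi:
        return False
    return None


# Eq((s2 - 1)//16 + 1, 1) under three different assumed ranges for s2.
print(decide_eq((Interval(1, 16) - 1) // 16 + 1, 1))   # True: guard is redundant
print(decide_eq((Interval(17, 64) - 1) // 16 + 1, 1))  # False: guard never holds
print(decide_eq((Interval(2, 64) - 1) // 16 + 1, 1))   # None: need a real guard
```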

You can reproduce these experiments with the following code:

import torch
import torch._dynamo
import torchvision

model = torchvision.models.resnet18(weights="DEFAULT")
model = torch._dynamo.optimize("eager")(model)
model.eval()
model(torch.rand([64, 3, 7, 7]))

Something that is not so easy right now is finding out what produced guard expressions; e.g., I see x ** 2 // y blah blah, where did it come from? More detailed logging at the Dynamo per-op level would help.
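One cheap way to answer “where did this guard come from?” is to capture a stack trace every time a guard is recorded. The sketch below uses only the standard library; GuardLog, record, and explain are hypothetical names, not Dynamo’s API, and the should_swap stand-in just mimics the frame from the traceback above.

```python
import traceback


class GuardLog:
    """Hypothetical guard recorder that remembers where each guard was created."""

    def __init__(self):
        self.entries = []  # (guard expression, captured stack frames)

    def record(self, expr):
        # Drop our own frame so the innermost entry is the code that guarded.
        self.entries.append((expr, traceback.extract_stack()[:-1]))

    def explain(self, expr, depth=3):
        for guard, stack in self.entries:
            if guard == expr:
                return "".join(traceback.format_list(stack[-depth:]))
        return f"no guard {expr!r} recorded"


log = GuardLog()


def should_swap(stride_a, stride_b):
    # Stand-in for the prims helper in the traceback above.
    log.record(f"{stride_a} < {stride_b}")
    return stride_a < stride_b


should_swap(6, 2)
print(log.explain("6 < 2"))  # innermost frame listed is should_swap
```

Capturing a full stack per guard is not free, so in practice this would likely sit behind a logging flag.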

What’s made it to master since last time?

ezyang

voz

jbschlosser

tugsbayasgalan

nkaretnikov

What’s coming next?

  • ezyang: catching up on code review
  • voz: changing dynamo backend api to take aot autograd directly
  • bdhirsh: on vacation until next week
  • Chillee: inductor integration

Our north star:

  • Dynamic shapes enabled by default for PT2 release
  • Fallback implementation for custom operators without symbolic shape propagation, inferred by running fallback on real operators :rotating_light::rotating_light::rotating_light:
  • All OpInfo tests passing
  • Dynamic shapes on by default for developers / users

Mini-update: Status of unbacked SymInts on Jan 16

I recently presented our progress on unbacked SymInts, our strategy for data-dependent output sizes, in the most recent composability meeting (meeting notes: Composability meeting notes - Google Docs; recording: Meta only). This status post will recap what I described in the meeting, and also explain what you should expect on unbacked SymInts in the near future.

tl;dr I (@ezyang) will be deprioritizing proper support for unbacked SymInts, because it looks like there are fundamental infrastructure improvements in Sympy reasoning and tracing performance we need to work on first. Also, unbacked SymInts are not launch blocking for dynamic shapes in PT2. Fortunately, we have identified a few short term unblocking steps that can help immediate users of unbacked SymInts make progress, albeit at the cost of some slight unsoundness (which we don’t expect to matter in practice).

Background

In PyTorch’s tracing model, we ordinarily try to treat the shapes of input tensors symbolically. However, if we need to perform control flow on an expression involving one of these symbolic sizes, we peek at the true values to resolve the condition to true/false, and install a guard saying that our trace is only valid if the condition evaluates the same way in the future. This is motivated by the fact that in a tracing framework, we cannot easily trace both sides of the conditional (you could use something like thermometer continuations to run the trace as if the condition evaluated true, and then rewind and rerun the trace as if the condition evaluated false, but you still have the problem of a combinatorial explosion of possible paths you could go down).

Guarding works well for statically known sizes, but if you call an operator like torch.nonzero, it will produce a size that is only known at runtime. Our idea for how to handle this case is simple: we produce a symbolic size that has no underlying value (aka is “unbacked”), and instead error if we attempt to guard on this symbolic size.
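The guard-on-demand model and the unbacked case fit in a few lines. This is a toy, not PyTorch’s SymInt/ShapeEnv: ToySymInt carries an expression string plus an optional hint (the real value), converting a comparison to bool consults the hint and records a guard, and an unbacked value (hint None) raises instead.

```python
class GuardOnUnbacked(RuntimeError):
    pass


class ToySymInt:
    def __init__(self, expr, hint, guards):
        self.expr = expr      # symbolic expression, e.g. "s0"
        self.hint = hint      # concrete example value, or None if unbacked
        self.guards = guards  # shared list of installed guards

    def __ge__(self, other):
        hint = None if self.hint is None else self.hint >= other
        return ToySymBool(f"{self.expr} >= {other}", hint, self.guards)


class ToySymBool:
    def __init__(self, expr, hint, guards):
        self.expr, self.hint, self.guards = expr, hint, guards

    def __bool__(self):
        if self.hint is None:
            raise GuardOnUnbacked(f"cannot guard on unbacked expression {self.expr}")
        # Specialize on the observed outcome and remember it as a guard.
        self.guards.append(self.expr if self.hint else f"Not({self.expr})")
        return self.hint


def program(n):
    return "big" if n >= 8 else "small"


guards = []
print(program(ToySymInt("s0", 10, guards)), guards)  # big ['s0 >= 8']
try:
    program(ToySymInt("u0", None, guards))  # e.g. the size produced by nonzero()
except GuardOnUnbacked as e:
    print(e)
```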

Some further reading that you may find helpful: subclass_zoo/dynamic_strides.ipynb at main · albanD/subclass_zoo · GitHub (about dynamic shapes and strides) and subclass_zoo/dynamic_shapes.ipynb at main · albanD/subclass_zoo · GitHub (about dynamic shapes in general)

Current state

Here is the state of unbacked SymInts on master:

Where exactly are these guards coming from? Here is a mostly complete accounting:

    case MemoryFormat::Contiguous: {
      // dim_ is a virtual call, don't repeat it
      const auto dim_ = dim();
      extra_meta_->strides_.resize(dim_);
      if (dim_ > 0) {
        const auto last_idx = dim_ - 1;
        extra_meta_->strides_[last_idx] = c10::SymInt(1);
        for (auto i = last_idx - 1; i >= 0; --i) {
          extra_meta_->strides_[i] =
              extra_meta_->strides_[i + 1] * extra_meta_->sizes_[i + 1].max(1);
        }
      }
      break;
    }
  • When we construct a TensorImpl, we have to compute whether or not it is contiguous. It turns out that we don’t short-circuit this computation when torch.empty (whose output is guaranteed to be contiguous) is called. This computation branches on size and stride:
  for (int64_t d = int64_t(sizes.size()) - 1; d >= 0; d--) {
    const auto& size_d = sizes[d];
    if (size_d != 1) {
      if (strides[d] == z) {
        z *= size_d;
      } else {
        is_contiguous = false;
        break;
      }
    }
  }
  • Even if we were to short-circuit the contiguity computation, we would still need to compute whether or not the tensor is channels-last contiguous. This is because some contiguous tensors are ALSO channels-last contiguous, e.g., (1, 1, 1, 1). This computation branches on size and stride in a similar way:
      T expected = 1;
      for (auto& d : {1, 3, 2, 0}) {
        const auto& size_d = sizes[d];
        if (size_d != 1) {
          if (strides[d] != expected) {
            return bool_is_channels_last_contiguous(false);
          }
          expected *= size_d;
        }
      }
      return bool_is_channels_last_contiguous(true);
  • We also compute whether or not a tensor is non-overlapping and dense. The computation here is quite involved: we have to sort the strides (sorting network, anyone?). Fortunately, it can also be short-circuited in the common case (since a tensor is definitely non-overlapping and dense if it is contiguous). However, it cannot always be short-circuited; for example, if you allocate a tensor with an unbacked SymInt size and then take a view on it, the view may not be contiguous, and so you have to do the full computation in that case. This is exactly what happens in the case of boolean indexing (in the meta implementation of index.Tensor). One extra problem is that the call to nonzero here is in library code, so if you want to add asserts on the result of index, you can’t easily do this.
            if index.dtype in [torch.int8, torch.bool]:
                nonzero = index.nonzero()
                k = len(result)
                check(
                    k + index.ndim <= self.ndim,
                    lambda: f"too many indices for tensor of dimension {self.ndim}",
                    IndexError,
                )
                for j in range(index.ndim):
                    check(
                        index.shape[j] == self.shape[k + j],
                        lambda: f"The shape of the mask {index.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                        IndexError,
                    )
                    result.append(nonzero.select(1, j))
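To see which branches turn into guards, here is a Python transliteration of the contiguous-stride computation and the channels-last check quoted above, operating on plain ints. It is a sketch for experimentation, not the c10 code; every max(), == and != on a size or stride marks a spot where a symbolic size would produce a guard.

```python
def contiguous_strides(sizes):
    """Row-major strides; max(size, 1) keeps zero-size dims from zeroing strides."""
    strides = [0] * len(sizes)
    if sizes:
        strides[-1] = 1
        for i in range(len(sizes) - 2, -1, -1):
            strides[i] = strides[i + 1] * max(sizes[i + 1], 1)  # max() guards
    return strides


def is_channels_last_contiguous(sizes, strides):
    """4-d NHWC check: walk dims in C, W, H, N order."""
    if len(sizes) != 4:
        return False
    expected = 1
    for d in (1, 3, 2, 0):
        if sizes[d] != 1:               # guard on size
            if strides[d] != expected:  # guard on stride
                return False
            expected *= sizes[d]
    return True


print(contiguous_strides([2, 3, 4, 5]))                           # [60, 20, 5, 1]
print(is_channels_last_contiguous([1, 1, 1, 1], [1, 1, 1, 1]))    # True
print(is_channels_last_contiguous([2, 3, 4, 5], [60, 20, 5, 1]))  # False
```

The (1, 1, 1, 1) case shows why the channels-last check cannot simply be skipped for contiguous tensors: its contiguous strides also pass the channels-last test.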

To support “proper” unbacked SymInts, we must modify our framework code to avoid guarding in all of these cases. We also need a mechanism by which users can insert extra runtime assertions to ensure that if a guard is unavoidable (e.g., size >= 0) but will always resolve one direction, we can manually guide tracing down one branch of the guard and have a runtime test to verify that we would have gone down that path at runtime as well.
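The runtime-assertion mechanism might look like this in miniature. Everything here is hypothetical (make_tracer, trace_program, compile_with_asserts): when tracing hits a guard on an unbacked value that the user has promised always resolves one way, we commit to that branch and emit a check that re-validates the promise when the compiled code runs.

```python
class RuntimeAssertFailed(AssertionError):
    pass


def make_tracer(assumptions):
    """assumptions maps a guard expression to a predicate on the runtime value.
    Tracing answers each such guard with True (the user promises it holds) and
    records the predicate so it can be re-checked when the compiled code runs."""
    checks = []

    def guard(expr):
        if expr not in assumptions:
            raise RuntimeError(f"cannot guard on unbacked expression {expr!r}")
        checks.append((expr, assumptions[expr]))
        return True

    return guard, checks


def trace_program(guard):
    # The traced program branches on u0, e.g. the size returned by nonzero().
    if guard("u0 >= 0"):
        return lambda u0: u0 + 1  # the only branch that gets traced
    return lambda u0: 0


def compile_with_asserts(traced, checks):
    def run(u0):
        for expr, pred in checks:  # runtime assertions protect the assumption
            if not pred(u0):
                raise RuntimeAssertFailed(expr)
        return traced(u0)
    return run


guard, checks = make_tracer({"u0 >= 0": lambda u0: u0 >= 0})
fn = compile_with_asserts(trace_program(guard), checks)
print(fn(3))  # 4
```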

Progress towards unbacked SymInts

The diff stack at Switch is_contiguous to use guardless implementation by ezyang · Pull Request #92234 · pytorch/pytorch · GitHub was my attempt to directly remove all of these guards. The infrastructure pieces (first three PRs) were not too bad (and I intend to land them); however, they do not actually remove the guards. The actual guard removal has run into at least two problems:

  • When I remove the stride computation guard (max between size and 1), it makes test_aot_autograd_symbolic_exhaustive_nn_functional_max_pool2d_cpu_float32 take a long time to run. This suggests that one of the guards from stride computation was essential for simplifying our size variables and prevented Sympy from doing a ton of unnecessary work.

  • When I remove the compute contiguous / non-overlapping guards, along with causing problems with max_pool2d, it also makes test_aot_autograd_symbolic_exhaustive_nn_functional_unfold_cpu_float32 take a long time to run.

While I can believe that the ultimate fixes for these two problems may be quite short, the way we thread the needle with our Sympy processing is quite opaque to me, and I’m not confident I could discover the fixes in the time available. Combined with our deadline of getting dynamic shapes turned on by default for the PT2 release, as well as the fact that we need to invest in speeding up Sympy simplification (e.g., Nuno’s suggestions), I think it makes sense for me to deprioritize making full unbacked SymInts work in the short term. If a brave soul wants to step up and try to fix these problems, I can give advice.

The short term unblock

The Executorch folks are attempting to export models with symbolic shapes in them. If we don’t have true unbacked SymInts, what can you do? A lot, it turns out. Here are the workarounds I suggest trying, from least sophisticated (aka most unsound), to most sophisticated (but requiring more work).

  1. If you need to implement a meta function for an operator with dynamic output size, an extremely simple first step is to simply make the function return a fixed size (zero, or the maximum possible value, are good choices) instead of an unbacked SymInt. This is unsound for the following reasons: (1) the fixed size may get burned into the graph (this is more likely to happen with zero; the max size is likely itself a symbolic size, so the burn-in in that case is less severe), (2) conditionals will be burned in according to this size, which means you may miss noticing a branch that actually needs to be replaced with a cond() operator to allow both branches to be exported. However, this is really good for “getting past” a missing output meta and seeing what needs to be done next.

  2. The next step up from (1) is to return an unbacked SymInt, but allow guards to be resolved with respect to an example size in certain contexts. The basic idea is to return an unbacked SymInt, but then “mock up” size_hint so that when we try to guard on the unbacked SymInt, we make use of an example size instead. This prevents some burn-in (as we are still passing around a fresh, unbacked SymInt for tracing purposes); we only burn in conditional paths. A user could then use logging / manual inspection to see if any conditional paths need to be treated specially. You can also restrict the mocking to certain fragments of code (e.g., while executing a tensor factory) where the example value is known to generalize to all possible values. I can help implement this, though I’m not exactly sure when to do it.
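The two workarounds can be contrasted with a toy shape function for nonzero. None of this is PyTorch’s meta registration API; all names are hypothetical. nonzero_meta_fixed returns the maximum possible count as a plain int (workaround 1), while UnbackedEnv hands out a fresh unbacked symbol and answers size_hint from an example value only inside a mock_hints() block (workaround 2). Where the example value comes from (e.g., running the operator on real example inputs) is an assumption left to the caller.

```python
from contextlib import contextmanager
from itertools import count
from math import prod


def nonzero_meta_fixed(shape):
    """Workaround 1: report the maximum possible count as a plain int.
    Simple, but the bound can be burned into the traced graph."""
    return (prod(shape), len(shape))


class UnbackedEnv:
    """Workaround 2 (toy): fresh unbacked symbols whose size hints are
    mocked with example values only inside `mock_hints()`."""

    def __init__(self):
        self._ids = count()
        self._examples = {}
        self._mocking = False

    def nonzero_meta(self, shape, example_count):
        sym = f"u{next(self._ids)}"
        self._examples[sym] = example_count
        return (sym, len(shape))

    @contextmanager
    def mock_hints(self):
        prev, self._mocking = self._mocking, True
        try:
            yield
        finally:
            self._mocking = prev

    def size_hint(self, sym):
        if self._mocking:
            # Only the branch decision is burned in; `sym` itself stays symbolic.
            return self._examples[sym]
        raise RuntimeError(f"data-dependent guard on unbacked symbol {sym}")


env = UnbackedEnv()
out_shape = env.nonzero_meta((4, 5), example_count=7)
with env.mock_hints():
    took_nonempty_branch = env.size_hint(out_shape[0]) > 0
print(nonzero_meta_fixed((4, 5)), out_shape, took_nonempty_branch)
```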

We can also in parallel implement the mechanism for adding runtime asserts, although without better guard simplification framework (e.g., Nuno’s range analysis) it is difficult to use these asserts to actually resolve guards on unbacked SymInts.

The long term plan

To be honest, I have not committed to a particular course of action for how to solve these problems. I can think of two plausible trajectories.

The conservative route. We land my PR stack as is, and try to hotfix the particular simplification problems case-by-case. Hopefully, they aren’t too hard to resolve and there aren’t too many of them.

The blow-it-up route. We rewrite ShapeEnv’s symbolic reasoning logic from scratch without Sympy, or using Sympy in a much more limited fashion, so that algorithmically it is obvious that we always get good performance. This would also help resolve performance issues in tracing from spending too much time in Sympy (according to @voz, in Bert we spend 30% of our time in Sympy simplification).

Open to comments from the peanut gallery. Honestly, this is going to depend a lot on who ends up doing this work.


How about learning from JAX and adding an optional static-sized variant of nonzero, so that at least some power users could get past this problem? AFAIK Tensor indexing still uses nonzero, so we could unblock those use cases if users are willing to change their code.

I raised this 18 months ago: [JIT] Support JAX-style statically shaped nonzero to avoid host-device synchronization · Issue #62320 · pytorch/pytorch · GitHub. It still seems relevant to me.

Hi!

I’m curious how indexing like t[mask] = t[mask] * 2 could be solved by adding this new size argument. Don’t you still need to get the number of nonzero values internally before being able to call nonzero?

I guess in principle it could be rewritten as t = torch.where(mask, t * 2, t).

See Alternative to array-based boolean indexing for jax.jit · Issue #2765 · google/jax · GitHub

FWIW, I do think we should have a static-sized variant for nonzero. But for example, in the boolean masking example, this wouldn’t actually help, as there is no nonzero call; instead, the nonzero call is implicit in the index by boolean mask. The torch.where works well for pointwise operation on top, but for the flip() example you need to do something more complicated.
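For reference, JAX’s jnp.nonzero accepts size and fill_value arguments so that its output shape is static. The pure-Python sketch below imitates that contract on a flat list (nonzero_static is a hypothetical name, not a PyTorch API) and contrasts it with the where-style rewrite of t[mask] = t[mask] * 2, which needs no data-dependent size at all.

```python
def nonzero_static(mask, size, fill_value=0):
    """Indices of truthy entries, truncated or padded to exactly `size`."""
    idx = [i for i, m in enumerate(mask) if m][:size]
    return idx + [fill_value] * (size - len(idx))


def masked_double_where(t, mask):
    """Equivalent of t[mask] = t[mask] * 2 written as where(mask, t * 2, t)."""
    return [x * 2 if m else x for x, m in zip(t, mask)]


t = [1.0, -2.0, 3.0, -4.0]
mask = [x > 0 for x in t]
print(nonzero_static(mask, size=3, fill_value=-1))  # [0, 2, -1]
print(masked_double_where(t, mask))                 # [2.0, -2.0, 6.0, -4.0]
```

Note the truncation: if more than size entries are set, the extra indices are silently dropped, so the caller has to pick a safe upper bound.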


State of symbolic shapes: Jan 20 edition

Previous update: State of symbolic shapes branch - #31 by ezyang

Executive summary

Volume-wise, there wasn’t that much activity, but there were three landed PRs that had a disproportionate effect on our metrics. First, we landed Brian’s AOTAutograd fixes, which fixed a large number of assert failures; second, Horace is finally back on dynamic shapes and landed a PR that fixes a few Inductor inference dynamic shapes problems (fixing inductor enough that we can start reporting master stats again); finally, I noticed an accounting problem in our stats, where many of the failures we were reporting actually had nothing to do with dynamic shapes. Overall, this pushed our delta for aot_eager down to TWO :tada::tada::tada: (one coverage, one timeout). This is fantastic, and we are turning our attention to other areas of dynamic shapes support:

  • Brian is spearheading an effort to track the number of extra graph breaks caused by dynamic shapes (tracked on the “extra graph breaks” sheet at Symbolic shapes work items tracker - Google Sheets). For now, we are only looking at torchbench. We don’t have a consolidated statistic to track week over week yet, but we will soon.
  • Horace is grinding down Inductor inference failures with dynamic shapes (tracked on the “inductor eval” sheet at Symbolic shapes work items tracker - Google Sheets; the “horace” sheet has results with Horace’s WIP stack). We are in the process of transitioning regular CI coverage from testing aot_eager training to testing inductor inference, which will allow us to report inductor metrics on master comparable to the aot_eager ones (this week we will have a one-off metric here).
  • Voz is working on improving our tracing time, which both OSS and internal users have called out as a problem. It is an especially big problem for dynamic shapes, which is ostensibly about improving compilation times. We are also preparing a consolidated statistic to track this week over week.

We also need to start working on inductor training support, which has its own unique challenges. We’ve also been discussing nested tensor / jagged tensor compilation with inductor (e.g., PyTorch Composability Sync: Nested/Jagged Tensor compilation - YouTube). We are deprioritizing work to characterize how dynamic/static our benchmark suite is, and instead intend to evaluate this ad hoc on use cases where users come to us and say “hey, I need this to be dynamic.” One example is this script from Timothee Lacroix (Meta only). There is some discussion about needing a more fine-grained way to turn on dynamic shapes (e.g., instead of turning it on for ALL local tensors, only turning it on for tensor dimensions that are known to be dynamic).
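As a rough illustration of the finer-grained idea, here is a hypothetical pure-Python sketch. Only the dimensions a user marks as dynamic become symbols. Every other dimension is specialized to its observed size, and a guard accepts a new input only if the specialized dimensions match exactly. The names make_shape_spec and matches, and the rule that a symbolic dimension accepts any positive size, are assumptions for illustration, not how Dynamo's guards are actually implemented.

```python
from typing import List, Sequence, Set, Union

# Hypothetical: an int is a specialized (static) size, a str names a symbol.
Dim = Union[int, str]


def make_shape_spec(shape: Sequence[int], dynamic_dims: Set[int]) -> List[Dim]:
    """Specialize every dimension except the ones explicitly marked dynamic."""
    return [f"s{i}" if i in dynamic_dims else size for i, size in enumerate(shape)]


def matches(spec: Sequence[Dim], shape: Sequence[int]) -> bool:
    """Toy guard: static dims must match exactly; symbolic dims accept any
    positive size (an assumption of this sketch)."""
    if len(spec) != len(shape):
        return False
    for expected, actual in zip(spec, shape):
        if isinstance(expected, int):
            if expected != actual:
                return False
        elif actual <= 0:
            return False
    return True


if __name__ == "__main__":
    # Mark only the sequence-length dimension (dim 1) as dynamic.
    spec = make_shape_spec((8, 128, 768), {1})
    print(spec)  # [8, 's1', 768]
    print(matches(spec, (8, 512, 768)))  # True: only seq len changed
    print(matches(spec, (16, 128, 768)))  # False: batch was specialized
```

Here a change in sequence length reuses the compiled artifact, while a change in batch size fails the guard and would trigger a recompile. That is the behavior you want when only one dimension is known to vary.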

Status report:

What’s made it to master since last time?

ezyang

voz (nothing dynamic shapes related)

Chillee

jbschlosser (nothing; just got back from PTO)

bdhirsh

nkaretnikov

What’s coming next?

  • ezyang: CI stuff, then probably trying to get inductor training going on master
  • Chillee: hosing down inductor inference errors
  • bdhirsh: working on dynamo graph breaks; also working on AOTDispatch enhancements for torchquant and nested tensor
  • jbschlosser: not sure yet
  • nkaretnikov: enabling dynamic shapes testing on inductor

Our north star: Dynamic shapes at feature parity with static shapes for PT2 release (but NOT turned on by default)


There is now a manual for all things dynamic shapes related, check it out here: The dynamic shapes manual - Google Docs


State of symbolic shapes: Jan 29 edition

Previous update: State of symbolic shapes branch - #37 by ezyang

Executive summary

We are two weeks away from branch cut for PyTorch 2.0. Enough of dynamic shapes has landed on master that we are not blocking the release: there is still a lot we want to get in before the release, but the most important stuff has landed. In particular, Horace landed more inference fixes, and we have also enabled CI for Inductor inference on master. There is a PR in progress for training (https://github.com/pytorch/pytorch/pull/93059), but our general thinking is that dynamic shapes is more important for inference (where you are more likely to want to vary sequence length) than for training.

Horace’s order of operations is: (1) basic training support, (2) inference performance on autoregressive generation, (3) other stuff; Edward will be working on general enablement here and there. Voz is still working on trace time performance (some improvements have landed, and some very promising work on short circuiting meta computation at [WIP] [RFC] add shape_preserving notions to decomps for fake_tensor specific short circuiting by voznesenskym · Pull Request #93118 · pytorch/pytorch · GitHub could also lead to speed wins with static shapes.) Brian and Joel are still working on Dynamo graph breaks, although none of the PRs from this workstream have landed yet (they are still working through Dynamo code review.)

  • Models outside of the benchmark suite. We took some fun models out for a spin last week. wav2vec2 is successfully running inference under torch.compile with dynamic shapes. maskrcnn is not in as good a state, but a lot of its blockers are things we know about and have been working on.
  • Accuracy failures. Background_Matting and LearningToPaint are failing accuracy with inductor inference with dynamic shapes, but not without dynamic shapes. These are a priority to fix.
  • Documentation. This got its own post, but in case you missed it: there is now a manual for dynamic shapes enablement: The dynamic shapes manual - Google Docs. Let us know if there’s anything you’d like to see in it.
  • How dynamic is the benchmark suite? Edward ran an experiment where he printed out the number of unique symbolic variables after tracing. Interestingly, most models only have one unique symbolic variable (likely the batch dimension.)
  • Why is tracing so slow? Voz added a bunch of extra instrumentation to help better characterize what exactly we’re doing when tracing, and Horace ran some experiments. One of the more interesting results was that in hf_Bert inference, Dynamo produces a graph with 570 nodes, but after AOTAutograd this balloons to 1528 nodes. Making matters worse, fake tensor is invoked 47000 times (16k before AOTAutograd, 31k after.) This is what pushed us in the direction of reducing fake tensor overhead with meta function short circuiting. Hacky experiments by Voz show we can get a 50-70% speedup this way. Also, pytree is slow; we are eagerly awaiting [WIP][POC][pytree] Use OpTree for PyTree manipulation by XuehaiPan · Pull Request #92679 · pytorch/pytorch · GitHub
  • Model training status on master. See also Symbolic shapes work items tracker - Google Sheets
    • aot_eager inference: -3 (NEW!). It turns out there are some models that are failing static shapes aot_eager training but not inference. These appear to be failing for straightforward coverage reasons and should be easily fixable.
    • aot_eager training: -1 (-1 WoW). The only remaining error is a timeout, which we hope will be resolved by trace time performance work.
    • inductor inference: -16 (NEW!). For a more direct comparison against Horace’s stack from last week, a manual sweep gives 143/160 (+43 WoW).
    • inductor training: with Horace’s patch, 49/129 (NEW!)
  • OpInfo tests on symbolic shapes.
    • pytest test/test_proxy_tensor.py -k test_make_fx_symbolic_exhaustive - 547 passed (+5 WoW), 523 skipped (no change), 196 xfailed (-5 WoW)
    • pytest test/functorch/test_aotdispatch.py -k test_aot_autograd_symbolic_exhaustive - 302 passed (+5 WoW), 143 skipped, 828 deselected, 188 xfailed (-5 WoW)
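To illustrate why meta function short circuiting (mentioned above under “Why is tracing so slow?”) helps, here is a toy pure-Python model. It is not the code from PR #93118. The Meta class, the op names, the STATS counter standing in for the cost of running a meta function, and the set of shape-preserving ops are all hypothetical. The sketch also assumes these ops preserve dtype, which is not true in general (type promotion can change it).

```python
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, Tuple


@dataclass(frozen=True)
class Meta:
    """Hypothetical stand-in for a fake tensor: only the metadata tracing needs."""
    shape: Tuple[int, ...]
    dtype: str


# Hypothetical: ops assumed to return output metadata identical to their input's.
SHAPE_PRESERVING = {"relu", "neg", "mul_scalar"}

# Counts how many times a (notionally expensive) meta function actually runs.
STATS = {"meta_calls": 0}


def _pointwise_meta(x: Meta) -> Meta:
    STATS["meta_calls"] += 1
    return Meta(x.shape, x.dtype)


def _sum_dim0_meta(x: Meta) -> Meta:
    STATS["meta_calls"] += 1
    return Meta(x.shape[1:], x.dtype)


META_FNS: Dict[str, Callable[[Meta], Meta]] = {
    "relu": _pointwise_meta,
    "neg": _pointwise_meta,
    "mul_scalar": _pointwise_meta,
    "sum_dim0": _sum_dim0_meta,
}


def propagate(op: str, x: Meta, short_circuit: bool = True) -> Meta:
    """Compute output metadata, skipping the meta function for shape-preserving ops."""
    if short_circuit and op in SHAPE_PRESERVING:
        return x  # output metadata equals input metadata; nothing to compute
    return META_FNS[op](x)


def trace(ops: Iterable[str], x: Meta, short_circuit: bool = True) -> Meta:
    for op in ops:
        x = propagate(op, x, short_circuit)
    return x


if __name__ == "__main__":
    ops = ["relu", "neg", "mul_scalar", "sum_dim0"]
    x = Meta((4, 8), "float32")
    STATS["meta_calls"] = 0
    print(trace(ops, x, short_circuit=False), STATS["meta_calls"])  # 4 meta calls
    STATS["meta_calls"] = 0
    print(trace(ops, x, short_circuit=True), STATS["meta_calls"])  # 1 meta call
```

Both paths produce the same output metadata. The short-circuited path simply avoids invoking the meta function for ops whose result metadata is known in advance, which is where the savings on tens of thousands of fake tensor invocations would come from.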

What’s made it to master since last time?

ezyang

Chillee

voz

jbschlosser

What’s coming next?

  • ezyang: inductor inference accuracy failures, popcorn enablement
  • Chillee: inductor training, autoregressive generation performance
  • bdhirsh: dynamo graph breaks, inference functionalization (it looks like we will still need to put copy_ in the graph)
  • jbschlosser: dynamo graph breaks
  • nkaretnikov: finally getting the floor div patch series in (it fixes real bugs!)

Our north star: Dynamic shapes at feature parity with static shapes (but NOT turned on by default.)

Mini-update: Dynamic shapes enablement for OpenNMT

vince62s has been interested in using dynamic shapes with OpenNMT. I haven’t had a chance to try the full model, but a small extracted example is (mostly) compiling successfully in a size-generic way when run with dynamic shapes. Check out my summary comment for more details!
