TorchDynamo Update 8: TorchDynamo passed the correctness check on 7k+ GitHub models

Recently we successfully ran TorchDynamo on 1K+ GitHub projects (a total of 7k+ models/test cases) collected using a crawling script. This is an important milestone: it demonstrates that TorchDynamo is the most reliable out-of-the-box (OOB) graph capture for PyTorch to date.

This post offers more details on this work, including the quality of the graphs captured and the kinds of problems we fixed along the way.

TorchDynamo

If you are new to TorchDynamo, the links below will help you catch up on this new exploration. TorchDynamo generates FX graphs from Python bytecode, and various backends are integrated with TorchDynamo to complete inference/training of the model. In the future, with the help of a cost model, TorchDynamo could automate the selection of the best backend for each subgraph to achieve optimal performance.
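To build intuition for what "generating a graph from Python bytecode" means, here is a toy sketch. It is NOT TorchDynamo's actual implementation — real Dynamo symbolically interprets bytecode and emits an FX graph — but it shows the core idea that a function's bytecode, not its source text, is the thing being inspected. The helper `capture_ops` and the `model`/`add`/`mul` functions are made up for illustration.

```python
import dis

def add(a, b):
    return a + b

def mul(a, b):
    return a * b

def capture_ops(fn):
    """Toy stand-in for graph capture: scan a function's bytecode and
    record the global functions it references, in bytecode order.
    (Real TorchDynamo symbolically executes the bytecode and builds an
    FX graph of tensor operations instead.)"""
    called = []
    for instr in dis.get_instructions(fn):
        if instr.opname == "LOAD_GLOBAL":
            called.append(instr.argval)
    return called

def model(x):
    return mul(add(x, 1), 2)

# The "graph" here is just the sequence of global calls found in bytecode:
# 'mul' is loaded first, then 'add', because call targets are pushed
# before their arguments are evaluated.
print(capture_ops(model))  # ['mul', 'add']
```

The key point the sketch preserves: capture happens by reading bytecode before execution, which is why constructs the interpreter-level tracer does not understand lead to graph breaks rather than silent miscompilation.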

How did we set up Dynamo’s 1K+ GitHub project evaluation?

Model selection criteria

  • Any GitHub project with 100+ stars that includes “PyTorch” as a keyword.

Testing goal

  • No exceptions thrown
  • Getting correct results

Testing data

Testing tool

Running TorchDynamo in default mode – Test w/ Graph Break

We first ran the models using the default mode of TorchDynamo. In this mode, graph capture falls back to eager execution for any Python construct not supported by the compiler backend, potentially causing graph breaks. This mode has the best UX (completely OOB), at the expense of sometimes capturing partial graphs instead of whole graphs.
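The effect of a graph break in default mode can be sketched in toy terms (hypothetical names, not TorchDynamo's internals): the program is split into subgraphs at each construct the capture cannot handle, and the unsupported piece runs in eager mode between them.

```python
# Ops the hypothetical backend can compile; anything else causes a break.
SUPPORTED = {"add", "mul", "relu"}

def split_into_subgraphs(ops):
    """Toy model of default-mode capture: given a linear list of op names,
    close the current partial graph at every unsupported op. The
    unsupported op itself falls back to eager execution."""
    graphs, current = [], []
    for op in ops:
        if op in SUPPORTED:
            current.append(op)
        else:
            if current:
                graphs.append(current)  # graph break: close partial graph
                current = []
            # 'op' runs in eager mode here
    if current:
        graphs.append(current)
    return graphs

print(split_into_subgraphs(["add", "mul", "print", "relu"]))
# → [['add', 'mul'], ['relu']] — two graphs, with a break at 'print'
```

A model with no unsupported constructs yields exactly one graph, which is what the "unique graphs per model" statistic later in this post counts.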

Starting point – the 1st run on May 1st

The following table shows our first evaluation conducted on May 1st. It showed that TorchDynamo already achieved a pretty high success rate.

          total   passing   success rate
projects   1111      1035          93.2%
tests      7549      7399          98.0%

As we dug into the errors, we identified 7 distinct bugs that accounted for the 141 runtime errors and 4 distinct bugs that accounted for the 9 correctness errors. The following list gives examples of the kinds of bugs/issues we discovered.

End-point on June 10th

On June 10th, after fixing all the bugs, we hit the 100% goal!

          total   passing   success rate
projects   1112      1112         100.0%
tests      7560      7560         100.0%

Graph and Graph Break Characteristics

So what is the quality of the graphs captured by TorchDynamo? The following stats shed some light:

  • Average number of unique graphs per model: 1.5
  • Largest number of PyTorch operators in a single model: 4516
  • Average number of PyTorch operators run inside TorchDynamo per model: 33

We did observe that some models generated many graph breaks. The following lists the top 10 models with the largest number of graphs captured. These models will be the focus of our future work to improve the full-graph mode of Dynamo.

Running Dynamo in the full graph mode (aka nopython=True)

We next evaluated Dynamo’s ability to capture full graphs using the flag nopython=True. In this mode, instead of breaking the graph and falling back to Eager when encountering an unsupported Python construct, Dynamo deliberately aborts and provides hints to users to fix the graph break. This mode is especially important for providing a smooth UX transition from the partial-graph capture (default, OOB, eager) to the full-graph capture (human-in-the-loop, export) using the same toolchain.
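The contrast with default mode can be put in the same toy terms (again hypothetical, not Dynamo's real internals): instead of splitting the graph at an unsupported construct, full-graph capture aborts immediately and tells the user where the break would have been.

```python
# Ops the hypothetical backend can compile.
SUPPORTED = {"add", "mul", "relu"}

def capture_full_graph(ops):
    """Toy model of nopython=True: abort with a hint at the first
    unsupported op instead of breaking the graph and falling back
    to eager execution."""
    for op in ops:
        if op not in SUPPORTED:
            raise RuntimeError(
                f"graph break at {op!r}: rewrite the model to avoid it, "
                f"or run without full-graph mode"
            )
    return list(ops)

print(capture_full_graph(["add", "relu"]))  # ['add', 'relu'] — one whole graph
# capture_full_graph(["add", "print"]) would raise:
#   RuntimeError: graph break at 'print': ...
```

The error message is the "hint" mentioned above: it points the user at the exact construct to fix, which is what makes the same toolchain usable for both OOB partial capture and human-in-the-loop export.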

The 3rd round run - aborting on graph breaks

This table shows the coverage using the full-graph mode. As expected, the success rate dropped from 100%.

          total   passing   success rate
projects   1112       704          63.3%
tests      7561      6383          84.4%

For all the models that passed without Python fallback (i.e., exactly one unique graph per model):

Future Work

Looking at the graph break reasons, these are some of the top ones:

  • Non-const NNModule method
  • Call function in skip files, e.g., collections
  • Data dependency and control flow
  • Usage of non-PyTorch libraries, e.g., NumPy.

Some of these constraints must be respected, while others need support to be added, so we need to categorize them and treat them differently:

  • Wrap these exceptions (e.g., unimplemented) in more readable exceptions to provide a better user experience.
  • Prioritize the top k graph break reasons and implement the missing features to avoid those breaks.
  • Open issues to track graph break reasons that will take time to implement.

This is amazing work! I look forward to using this in Torch-MLIR!


I’m coming from Julia where we have fast custom data structures, some of which can be stack allocated or inlined into memory.

To what extent is this going to be possible with torch dynamo?

I’m looking for one or more of: mutable or immutable “struct-like” things with dynamically dispatched methods, statically dispatched methods, and inlined methods.

Use cases range from fast graphs and writing to some state in an inner loop to tree structures, etc.

The backend compilers aren’t really optimized to have custom data structures (forget fast, lol).
They’re more like ML compilers, still very opinionated about what a Tensor is, how the memory layouts are, etc.

I was wondering how we generated valid input shapes for all these models.

Apparently @jansel wrote a fairly nice deducer with a funny search that reads the error messages and tries to prune the space down.

Had a lot of fun reading through the code.


In many cases TorchDynamo will flatten Python data structures to present a simple list of input tensors to the backend, then restore any output data structures and side effects after the backend graph runs.

This could make things faster. For example, if you construct lists/tuples/namedtuples/dicts and those don’t escape the bytecode being compiled, TorchDynamo will optimize them away. If you do many updates to a data structure, those will be collapsed into a single update.
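The flatten/restore idea can be sketched in plain Python (the `flatten`/`unflatten` helpers below are hypothetical stand-ins, not TorchDynamo internals): pull the leaf values out of nested containers so the backend sees a flat list, keep a structure "spec", then rebuild the containers from the backend's outputs.

```python
def flatten(obj):
    """Return (leaves, spec): leaves is a flat list of non-container
    values; spec records enough structure to rebuild obj."""
    if isinstance(obj, dict):
        items = [(k, *flatten(v)) for k, v in obj.items()]
        leaves = [leaf for _, ls, _ in items for leaf in ls]
        return leaves, ("dict", [(k, s) for k, _, s in items])
    if isinstance(obj, (list, tuple)):
        parts = [flatten(v) for v in obj]
        leaves = [leaf for ls, _ in parts for leaf in ls]
        return leaves, (type(obj).__name__, [s for _, s in parts])
    return [obj], "leaf"   # non-container: a single leaf

def unflatten(leaves, spec):
    """Rebuild the original nested structure from a flat list of leaves."""
    it = iter(leaves)
    def build(s):
        if s == "leaf":
            return next(it)
        kind, children = s
        if kind == "dict":
            return {k: build(cs) for k, cs in children}
        seq = [build(cs) for cs in children]
        return tuple(seq) if kind == "tuple" else seq
    return build(spec)

nested = {"w": [1, 2], "b": (3,)}
leaves, spec = flatten(nested)
print(leaves)  # [1, 2, 3] — what a backend would see as graph inputs
# "Run the graph" on the flat leaves, then restore the structure:
print(unflatten([x * 10 for x in leaves], spec))  # {'w': [10, 20], 'b': (30,)}
```

Because the containers never reach the backend, any intermediate lists/dicts built inside the compiled region simply disappear from the generated graph.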

Ah, interesting.

Does this include custom classes, particularly if I don’t add fields at runtime (or if I want to preclude that with some decorator or slots)?

Would be cool if it could get memory layout help with type hints maybe with semantic enforcement. Also Jax’s PyTrees are really useful. And Chex’s dataclasses. There’s a lot in the space with Jax.

Yes, there is support for custom classes and even limited support for mutating attributes of custom classes. TorchDynamo will try to extract all the graph inputs from the custom classes, then queue up all the mutation/side effects and apply them after the graph executes. It definitely doesn’t handle everything, but for simple cases it should work.
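The "queue up mutations, apply after the graph runs" idea can be sketched like this (a toy `SideEffectRecorder`, invented for illustration — TorchDynamo's real side-effect tracking is more involved):

```python
class SideEffectRecorder:
    """Toy side-effect tracker: attribute writes are recorded instead of
    applied; reads see the pending writes; apply() replays them onto the
    real object, standing in for the post-graph side-effect replay."""

    def __init__(self, target):
        self._target = target
        self._pending = []  # queued (attr, value) writes

    def __setattr__(self, name, value):
        if name.startswith("_"):
            object.__setattr__(self, name, value)  # recorder's own state
        else:
            self._pending.append((name, value))    # defer the mutation

    def __getattr__(self, name):
        # Reads observe the most recent pending write, else the real object.
        for attr, value in reversed(self._pending):
            if attr == name:
                return value
        return getattr(self._target, name)

    def apply(self):
        """Replay queued mutations onto the real object ("after the graph")."""
        for attr, value in self._pending:
            setattr(self._target, attr, value)
        self._pending.clear()

class Counter:
    def __init__(self):
        self.n = 0

c = Counter()
rec = SideEffectRecorder(c)
rec.n = rec.n + 1   # mutation is recorded, not applied
print(c.n)          # 0 — real object untouched during "capture"
rec.apply()         # the graph has run: replay the side effects
print(c.n)          # 1
```

This is why the mutation support is described as limited: only effects the tracker can faithfully record and replay are safe to defer.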

PyTrees are awesome, we use them in functorch/AOT Autograd.