TorchDynamo Update 3: GPU Inference Edition

Since September 2021, we have been working on an experimental project called TorchDynamo. TorchDynamo is a Python-level JIT compiler designed to make unmodified PyTorch programs faster. It hooks into the frame evaluation API in CPython to dynamically modify Python bytecode right before it is executed. It rewrites Python bytecode in order to extract sequences of PyTorch operations into an FX Graph which is then just-in-time compiled with an ensemble of different backends and autotuning. It creates this FX Graph through bytecode analysis, not tracing, and is designed to generate smaller graph fragments that can be mixed with Python execution.
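To make this concrete, here is a minimal sketch of the user-facing API from the earlier posts: a context manager (or decorator) that takes a backend compile function. The trivial backend below just prints each captured FX graph and runs it unmodified.

import torch
import torchdynamo  # standalone package at the time; later merged into torch._dynamo

def my_compiler(gm: torch.fx.GraphModule, example_inputs):
    # Each extracted graph fragment arrives as an FX GraphModule plus example inputs.
    gm.graph.print_tabular()
    # Return any callable; returning gm.forward just runs the captured graph unmodified.
    return gm.forward

def fn(x, y):
    return torch.relu(x) + y * 2

with torchdynamo.optimize(my_compiler):
    fn(torch.randn(8), torch.randn(8))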

Our first post, TorchDynamo: An Experiment in Dynamic Python Bytecode Transformation, introduced the concept and the approach.

Our second post, TorchDynamo Update: 1.48x geomean speedup on TorchBench CPU Inference, shared the first performance results and introduced our ensemble-based proof of concept backend.

This post is primarily a performance update, as we have continued to build on the strategy described in the earlier posts. Notable changes since then:

  • Added GPU support.
  • Expanded support for Python bytecodes.
  • Added new backends, including: nvfuser, cudagraphs, onnxruntime-gpu, tensorrt (fx2trt/torch2trt/onnx2trt), and tensorflow/xla (via onnx).
  • Imported new benchmarks added to TorchBenchmark, including 2 that TorchDynamo fails on, which should be fixed soon.
  • Switched to measuring on different machines, which made GPU measurements possible and also changed the CPU results due to a jump from 12 to 96 threads and the addition of AVX512 hardware.

Performance Results

With that, on to the numbers! Attached you will find updated performance results for both GPU and CPU inference, including several baselines for comparison.

Each number is the median of 100 measurements and is normalized to speedup over eager mode.

The first thing that still jumps out at me is the difference in model coverage between TorchScript-based backends and TorchDynamo/eager. Except for eager mode (100%) and TorchDynamo (96%), no backend works on more than half of the models. This reflects a massive usability gap between eager mode and existing graph mode backends.

TorchDynamo provides larger average speedups than the other backends shown: a 1.29x geometric mean speedup on GPU and a 1.71x geometric mean speedup on CPU. These results show TorchDynamo is faster on average while maintaining high model coverage.
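For readers unfamiliar with the metric, the geometric mean speedup is the n-th root of the product of the per-model speedups, each normalized to eager. A tiny sketch with made-up numbers (not benchmark data):

import math

def geomean_speedup(speedups):
    # speedups are per-model ratios over eager (1.0 == same speed as eager)
    return math.exp(sum(math.log(s) for s in speedups) / len(speedups))

print(geomean_speedup([1.5, 1.0, 1.1]))  # ~1.18x, illustrative numbers only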

For a bit more raw data to figure out what TorchDynamo is doing, the following are the counts of how often TorchDynamo used each GPU backend:

('eager', 323)
('cudagraphs', 161)
('nvfuser', 62)
('ofi', 58)
('nnc', 50)
('tensorrt', 36)
('onnx2tf', 2)
('onnxrt_cuda', 1)

And here are the same counts for CPU backends:

('eager', 314)
('ts', 157)
('ofi', 149)
('onnxrt_cpu', 24)
('onnx2tf', 24)

One should take these numbers with a grain of salt. The sizes of these graphs vary dramatically, and a small subset of the graphs matter much more than others. Eager is the most commonly selected backend because 1) other backends don't support many graphs; and 2) eager often outperforms graph-based backends. The biggest area I see for further performance improvement is to break graphs at unsupported ops in order to increase backend choice.
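To make the "graph fragments mixed with Python execution" idea more concrete, here is a sketch (again using the torchdynamo API from the earlier posts) of a backend that simply counts the fragments it receives; the data-dependent branch forces a graph break, so more than one fragment is handed to the backend.

import torch
import torchdynamo

fragments = []

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Record each fragment TorchDynamo extracts, then run it unmodified (eager semantics).
    fragments.append(len(list(gm.graph.nodes)))
    return gm.forward

def model(x):
    x = torch.sin(x) + torch.cos(x)
    if x.sum() > 0:  # data-dependent control flow cannot be captured, so Dynamo breaks the graph
        x = torch.relu(x)
    return x * 2

with torchdynamo.optimize(counting_backend):
    model(torch.randn(16))

print(len(fragments))  # expect more than one fragment due to the graph break

The backend counts above are per fragment like these, which is why a single model can contribute several entries.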

Conclusions

It is still early days for TorchDynamo, and it is not ready for production usage; however, these early results continue to be extremely promising. They show that the best of both worlds is possible, where we can support the full dynamism and user experience of PyTorch and Python, but still get performance similar to or better than more restrictive graph mode runtimes.


What are you counting with the backend counts for the GPU backends?

The number of subgraphs each backend was used on across all benchmarks.

Hey, I’m looking at some of the regressions reported on nvfuser + OFI.

Not totally sure if I’m running the right things. I’m running the model with this:

./torchbench.py --no-skip -d cuda -n 200 --speedup-ts --nvfuser -k "pyhpc_isoneutral"
checking models:  pyhpc_isoneutral_mixing
cuda pyhpc_isoneutral...  /opt/conda/lib/python3.8/site-packages/librosa/cache.py:49: DeprecationWarning: The 'cachedir' attribute has been deprecated in version 0.12 and will be removed in version 0.14.
Use os.path.join(memory.location, 'joblib') attribute instead.

  if self.cachedir is not None and self.level >= level:
  0/  0 frames (+ 0), 0.996x SAME 0.288x p=0.06

MEAN SPEEDUP [    0.99638     0.28841]
GEOMEAN SPEEDUP [    0.99638     0.28841]

That 0/0 frames sounds like nothing was executed.

@jjsjann123

The 0/0 frame thing is expected for --speedup-ts and you can ignore it. 0/0 frames is the number of frames TorchDynamo processed, but TorchDynamo is disabled for --speedup-ts (baseline) so it will always be 0/0.

0.996x SAME

is the speedup with TorchScript (and nvfuser), along with the result of a t-test for statistical significance.

0.288x p=0.06

is the speedup (a slowdown in this case) with TorchScript + OFI (and nvfuser), with a p-value of 0.06.


@jjsjann123 also make sure you are using the latest TorchBench (or my branch if you want some minor fixes).

This PR (Don't use jit when jit=False for pyhpc benchmarks by jansel · Pull Request #675 · pytorch/benchmark · GitHub)
will likely fix the issue causing you to see ~1x performance for TorchScript on that benchmark.


Hi @jansel, I was trying to add a new backend for TorchDynamo, namely the newly landed meta-schedule in TVM, and there are some issues in the log that maybe you can take a look at.

My repro looks like ./torchbench.py --fast --backend tvm_meta_schedule and in the logs, at the end of each model run, sometimes there is only
========== End debug info ========== 0.998x SAME

and sometimes there is
21.967x p=0.00
without the End debug info. Based on your explanation above, does that mean that if there is only 0.998x SAME, the backend I added for TorchDynamo actually failed or crashed, so only TorchScript ran as a fallback? And in the second case, did my backend actually successfully optimize the given model with a 21.967x speedup? I wonder why those numbers aren’t printed together as @jjsjann123 showed in their response. Thanks in advance!

The 1x numbers are likely caused by exceptions coming from your compile_fn, which make it fall back to eager. You can verify this by catching exceptions yourself and calling sys.exit(), or by looking for error printouts.
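For example (hypothetical names, just illustrating the suggestion above), you could wrap your compile_fn so failures are loud instead of silently falling back:

import sys
import traceback

def make_loud(compile_fn):
    # Wrap a backend compile_fn so any exception is printed and stops the run,
    # rather than TorchDynamo silently falling back to eager (which shows up as ~1x).
    def wrapped(gm, example_inputs):
        try:
            return compile_fn(gm, example_inputs)
        except Exception:
            traceback.print_exc()
            sys.exit(1)
    return wrapped

# usage (hypothetical backend function name):
# loud_tvm_backend = make_loud(tvm_meta_schedule_compile)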

The multiple speedups in a single line are specific to --speedup-ts (which is testing multiple things).

Thanks! So the 21.967x p=0.00 speedup is the result of comparing my new backend to torchscript. Is that the case here?

For TorchDynamo, is there a retry mechanism if one subgraph is not optimized successfully? My understanding is that TorchDynamo will divide the model into subgraphs based on control flow(?); in other words, if there is no control flow within the model, the model itself will be passed into the backend as a whole subgraph(?). And if the compile_fn does not return successfully, it will fall back to TorchScript(?) without any retry. Would be great if you could give me some pointers here. Thanks in advance!

Everything is compared to eager, not torchscript. You can measure the performance of torchscript with --speedup-ts.

For most (80%) of models, torchdynamo just gets one whole-program graph. You can run with --coverage to see graph counts. If the backend throws an exception, it just runs the eager mode fallback.

I’m having issues seeing speedups when trying TorchDynamo backends on GPU inference (I am working on a few OCR models).

Thanks for this awesome update!
What’s the blocker on Huggingface LLMs? As I’m sure you know, compile() is producing errors on hf_T5, and all the others you listed, such as:

  File "/home/kastanday/utils/mambaforge3/envs/torch2.0/lib/python3.10/site-packages/transformers/models/t5/modeling_t5.py", line 949, in <graph break in forward>
    raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
ValueError: You have to specify either input_ids or inputs_embeds

For accelerated inference on huggingface models, would export() work? Likely not, since the models can’t be compiled (yet). Any advice here? Thanks.

# API Not Final
exported_model = torch._dynamo.export(model, input)
torch.save(exported_model, "foo.pt")

hf_T5 runs fine in our nightly benchmarks, so there must be something different about your setup. Feel free to submit an issue with more details.

Export should work fine, but it gives you an FX graph which can’t be torch.save()d, so you aren’t using it in the intended way. Export is not designed to be a solution for model serialization. It gives you a graph intended to be passed to other frameworks and compiled.
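For context, a rough sketch of the intended flow, assuming the (not yet final) export API returns an FX GraphModule plus guards:

import torch
import torch._dynamo

model = torch.nn.Linear(4, 4)
example_input = torch.randn(2, 4)

# export() captures a single whole-program FX graph (API not final at the time of writing).
graph_module, guards = torch._dynamo.export(model, example_input)

print(graph_module.graph)          # inspect the captured graph
out = graph_module(example_input)  # or hand it to another compiler/runtime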
