PyTorch 2.0 User Empathy Day Recap

The rules and the outputs

We hosted the first PyTorch 2.0 User Empathy Day at Meta, where participants picked trending OSS models, applied torch.compile as a USER, and fixed as many problems as possible without touching PT2 internals. The goal was to better understand the gaps and pain points from the user’s perspective and to gather signals that help us prioritize feature development.

By the end of the whole day’s hacking, we emerged with:

  • 5 models compiled successfully and showed decent speedups out of the box.
  • 3 models eventually got torch.compile working.
  • 7 models never got torch.compile working.

Candidate models were selected for being popular and commonly used by the community, although some of them don’t necessarily work well in eager mode. We were not trying to move a metric with this user empathy day; the numbers above are only meant to give a holistic view of how torch.compile works on these models.

But we think the gaps and pain points are the most important thing we learned from the user empathy day, and we discuss them in the following two sections.

Common themes that people had opinions on

Logging:

  • Too verbose for users to figure out the actual problem:
    • The existing TORCH_LOGS="+dynamo" is too spammy for users, and when errors are thrown, the stack traces spanning dynamo/aotautograd/inductor are too long.
    • Need a better structure and abstraction for user view logging (separate from developer logging).
    • Ed’s structured traced logs probably can improve this in the future.
  • Too many (10+) artifacts in TORCH_LOGS, users don’t know which should be used.
  • Unsure how to take action on the error message:
    • E.g., “CPU tensor but the Triton is accessing CUDA” showed up a couple of times, but the log info is not actionable.
  • Unable to tell if torch.compile is taking a while or if something is hanging. Would be good to have a progress bar.
  • Dynamo resume functions are hard to understand and to follow in stack traces. E.g., a generated resume function has different input arguments, and users don’t know where they come from unless they have a basic mental model of how resume execution works.

Documentation:

  • It’s better to move existing Google docs (e.g., TORCH_LOGS documentation) to the actual website – this makes them google-search-able and user friendly.

  • Some documentation is ambiguous (e.g., it is unclear whether max_autotune is a backend or a mode).

  • Profiler is an important tool for identifying performance issues; we need docs or tutorials on how to profile compiled model performance.

Dependency and environment setup:

  • Many people had trouble with setting up environments (versioning between different pip packages, conflicts between conda packages, xformers, flash-attention, etc).
    • This is not a PT2-specific issue; users suffer from this on PT as well.
  • More and more libraries use Python 3.12 by default, but this isn’t supported by Dynamo yet.

Bug reproducing and reporting:

  • When an error or graph break happens, it’s hard to create repros for filing bugs against PyTorch.
  • Some graph breaks are easy to fix by looking at the error stack and figuring out the unsupported feature, e.g., missing support for itertools.zip_longest.
  • But other issues need more information to reproduce exactly, e.g., a normal dict incorrectly wrapped as UserDefinedObjectVariable(dict) when it should be ConstDictVariable.

Going deeper

Model accuracy:

  • We didn’t observe accuracy issues in any of the models for which we could compare inference results.
  • For training, we did see the training loss decrease.
  • We used a very limited input dataset, so a scaling benchmark could still reveal accuracy issues, but at least there is no silly accuracy issue so far.
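
The kind of inference comparison described above can be sketched as follows. This is not the procedure used on the day, just a minimal example: the small nn.Sequential model is a placeholder, and backend="aot_eager" is an assumption made so the snippet runs on CPU without a compiler (the default inductor backend is what one would check in practice).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Placeholder model; a real check would use the actual model and real inputs.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()
compiled_model = torch.compile(model, backend="aot_eager")

x = torch.randn(8, 16)
with torch.no_grad():
    expected = model(x)          # eager reference
    actual = compiled_model(x)   # compiled output

# Raises an AssertionError with a mismatch summary if the outputs diverge.
torch.testing.assert_close(actual, expected)
max_abs_diff = (actual - expected).abs().max().item()
print("max abs diff:", max_abs_diff)
```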

Graph breaks:

  • einops/einx are used widely in these trending models. Dynamo marks einops functions as allowed in graph to avoid graph breaks, but we still observed many issues around this.
    • A possible enhancement is to pre-dispatch trace the allowed einops functions.
  • Unsupported features in torch.distributed.
    • Ring attention initializes distributed state and makes calls inside autograd.Function, where a few functions are not torch.compile friendly.
  • Python coverage issues, like zip_longest, frozen dataclasses, etc.
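
To find where a model breaks the graph, one option is to make graph breaks fatal. A minimal sketch, assuming torch.compile's fullgraph=True flag; the function f and its forced break are hypothetical, and the exact exception type is not relied on here.

```python
import torch
import torch._dynamo


def f(x):
    x = x + 1
    # Stand-in for an unsupported Python feature or library call.
    torch._dynamo.graph_break()
    return x * 2


# fullgraph=True turns the first graph break into an error instead of
# silently splitting the function into several graphs.
strict_f = torch.compile(f, backend="eager", fullgraph=True)

error = None
try:
    strict_f(torch.ones(3))
except Exception as exc:  # exact exception class varies across versions
    error = exc

print(type(error).__name__)
```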

Re-compilation:

  • It’s not clear why a function was re-compiled.
    • This echoes the point about better user-facing logging: we should deliver actionable messages to users, such as “because of X, function Y is re-compiled”, and this log should not be buried in the flood of (developer) logs.
  • We should provide an API to help users work around re-compilations, e.g., we can make torch._dynamo.disable a public API.
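
To make re-compilation visible without reading developer logs, a user can count how often the backend is invoked. The sketch below uses a hypothetical counting_backend (following the custom-backend convention of taking an FX GraphModule plus example inputs and returning a callable) and triggers one re-compilation by changing the input dtype.

```python
import torch
import torch._dynamo

compile_events = []


def counting_backend(gm, example_inputs):
    # Called once per compiled graph; record the event and run the graph as-is.
    compile_events.append(len(example_inputs))
    return gm.forward


@torch.compile(backend=counting_backend)
def double(x):
    return x * 2


double(torch.ones(4, dtype=torch.float32))  # first call: compiles
double(torch.ones(4, dtype=torch.float32))  # same guards: reuses the compiled code
double(torch.ones(4, dtype=torch.float64))  # dtype guard fails: re-compiles

print("number of compilations:", len(compile_events))
```

Here the cause is obvious because the example is tiny; in a real model the same counter only says that a re-compilation happened, which is why an actionable "because of X" message matters.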

Compilation time:

  • Compilation time is a critical metric that impacts the user’s experience. E.g., using the Whisper model to transcribe an entire movie takes 2 mins in eager mode, but torch.compile takes 2 mins just to compile.
    • This is not solid evidence of slow compilation; it just emphasizes how much compilation time affects the user experience.
  • We observed the FX graph cache and fake tensor cache reducing compile time by 2x on the StableCascade model, whose FX graph has 22k nodes.
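
A simple way for a user to see the compile cost is to time the first call separately from later calls. This is a minimal sketch with a toy function; backend="aot_eager" is an assumption made so it runs without a compiler, which means the measured cold time covers only tracing, not Inductor code generation.

```python
import time

import torch


def f(x):
    return torch.sin(x) + torch.cos(x)


compiled_f = torch.compile(f, backend="aot_eager")
x = torch.randn(1024)

# First call pays for tracing and compilation.
start = time.perf_counter()
first = compiled_f(x)
cold_s = time.perf_counter() - start

# Second call with the same input properties reuses the compiled code.
start = time.perf_counter()
second = compiled_f(x)
warm_s = time.perf_counter() - start

print(f"cold call: {cold_s:.4f}s, warm call: {warm_s:.6f}s")
```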

A more comprehensive benchmark

We realized the limitations of the existing benchmarks and came up with some ideas for improving them at the debrief session:

  • The existing TB/HF/TIMM model benchmarks use aggregated performance numbers and compare against eager mode, which could let small regressions slip through the cracks.
  • For the most important or popular models (e.g., llama, mixtral), we should benchmark against SOTA numbers.
  • We should build microbenchmarks to make whole-model performance predictable and to capture even small regressions.
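
As an illustration of what a single microbenchmark might look like, the sketch below times one small op pattern in eager and compiled form with torch.utils.benchmark.Timer. The function, tensor size, and backend="aot_eager" are assumptions chosen so it runs quickly on CPU; it is not the benchmark design discussed at the debrief.

```python
import torch
from torch.utils.benchmark import Timer


def f(x):
    return torch.nn.functional.gelu(x) * x


compiled_f = torch.compile(f, backend="aot_eager")
x = torch.randn(256, 256)
compiled_f(x)  # warm up so compilation is excluded from the measurement

# blocked_autorange picks the number of runs and reports robust statistics.
eager_result = Timer(stmt="fn(x)", globals={"fn": f, "x": x}).blocked_autorange(
    min_run_time=0.2
)
compiled_result = Timer(
    stmt="fn(x)", globals={"fn": compiled_f, "x": x}
).blocked_autorange(min_run_time=0.2)

print(f"eager median:    {eager_result.median * 1e6:.1f} us")
print(f"compiled median: {compiled_result.median * 1e6:.1f} us")
```

Tracking medians like these per op pattern over time is one way small regressions could be caught before they disappear into an aggregated model-level number.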

Next steps

  • Please open GitHub issues for all the problems you hit and tag them with “empathy-day”. We should prioritize these issues since they are from trending models.
    • Kudos to the folks who already filed issues!
  • The issues listed under logging, documentation, graph breaks and re-compilation are pretty actionable; we will find resources for these items. Feel free to let us know if you are interested in contributing.
  • We don’t have a clear idea of how to address the issues listed under dependency and environment setup and under bug reproducing and reporting. We should brainstorm ideas on these topics, and we welcome any suggestions.
  • Building a more comprehensive benchmark is important, and Yanbo has signed up for it.

We would like to thank everyone who participated in the user empathy day. Thanks for sharing your valuable insights and empathizing with each other’s challenges and successes in using PyTorch. We are making PyTorch better!

Are you thinking about releasing some pre-built Docker images, or are there already some available?

We have a Docker image at pytorch/pytorch on GitHub, but the problem is that users have custom requirements for the dependent Python libraries when working with different AI models.