Torch.nn H2 2021 Lookback and H1 2022 Lookahead

Hey everyone! I wanted to post some quick highlights from the torch.nn work during H2 2021 and the upcoming projects we are looking into for H1 2022. There were some big accomplishments this past half, and we have quite a bit planned for the new half. Suggestions, comments, and feedback are always welcome!

H2 2021 Highlights

  • Convolution consolidation. As expected, this project to reduce the number of convolution-related ops within PyTorch was a bear: convolution is a very important and sensitive area of the code, and the changes touched countless areas of both OSS and internal code. Some of the notable successes:
    • Starting op count: 85. Final op count for the half: 47.
    • A single forward / backward op pair now handles routing to the various supported backends. For projects that need to support convolution (e.g. functorch, Lazy Tensor Core, forward mode AD formulas in autograd), this means only a single pair of ops (convolution / convolution_backward) needs to be targeted to get convolution support. In fact, this has already been done for the aforementioned projects.
    • The backend routing logic is now privately exposed to Python via torch._C._select_conv_backend, with a brand-new, robust set of tests ensuring that the correct backend is selected for a given set of inputs. The function takes the same inputs as the convolution op and returns an enum indicating the selected backend; see the example in the test logic (cc Taylor Robie, who mentioned this may be useful for profiling).
    • The work uncovered several bugs. Some examples:
      • Strided _convolution_double_backward was broken for 3D input / weight; fixed in #67283.
      • #68034: 1D convolution was broken for MKL-DNN tensors; fixed in #68166.
      • #68755: MKL-DNN convolution does not currently support transposed convolution, but there is work from Intel adding this in #58348.
      • functorch’s convolution batching rule needed to change to avoid hitting the batched fallback; see here for more context.
    • Rewriting the bias gradient computation logic led to a modest speedup for our naive, non-accelerated convolution algorithms (i.e., not cuDNN, MKL-DNN, etc.): I estimate an amortized ~10% decrease in time for a forward / backward pass.
    • While a big portion of the impact has been achieved, there are still some things to do:
      • Drop the rest of the backend-specific ops.
      • Make convolution / convolution_backward structured.
      • Split CPU / CUDA routing logic across dispatch key entries.
      • Remove the _convolution op that TorchScript serializes. Move XLA from overriding convolution_overrideable / convolution_backward_overrideable to implementing convolution / convolution_backward directly, then drop the former pair.
    • #thanks Richard Zou, Alban Desmaison, Natalia Gimelshein, Brian Hirsh, Will Constable, Bin Bao, Horace He, Edward Yang, Gregory Chanan, and Peter Bell (Quansight) for reviews, comments, debugging, and collaboration toward this effort!
  • Support for inputs without batch dimensions. This support was added throughout the torch.nn modules for composability with vmap. Nearly every module now supports this, with a few BC-breaking cases remaining to be knocked out this half.
    • #thanks Richard Zou for the design discussion, George Qi for adding support to several modules, and Quansight developers Thomas Fan and Kshiteej Kalambarkar for their extensive contributions!
  • Label Smoothing for CrossEntropyLoss. This was called out in a previous post, but it is worth highlighting again: support for both label smoothing and class probability distribution targets for CrossEntropyLoss is available as of PyTorch 1.10.0 (despite the fact that a Google search for “pytorch label smoothing” still doesn’t make this clear; who do I ping to get #3428 fixed?).
  • Transformer API flexibility improvements. Several compiler- and performance-oriented people are doing vital work in parallel to improve core’s Transformer implementation. Meanwhile, on the frontend side, a couple of API-related changes have landed that improve MultiheadAttention (MHA) flexibility and make it more usable for research.
  • Testing improvements.
    • ModuleInfo-based testing. As of the end of the half, 90% of the old-style test logic has been ported to a new ModuleInfo-based approach, moving towards a full replacement of the confusing test infrastructure.
      • The design was described in more detail in a previous post, but check out the tests here to see how to easily write tests that run across modules using the database of ModuleInfo entries.
      • We still need full ModuleInfo coverage for all torch.nn modules! Any help in writing these is much appreciated :slight_smile: Check out the issue for the gaps if you’d like to contribute.
    • Test parametrization mechanism. This was described in detail in a previous post; shamelessly plugging it here again. The TL;DR is that there is now a test decorator that allows you to parametrize tests, making it easy to achieve thorough test coverage with subtests. Check out the docs here and some real-world examples here, here, here, and here.
    • #thanks Mike Ruberry (reviews / design), Richard Zou (reviews, design, usage feedback), Mikayla Gawarecki (adding module testing across memory_formats), and Quansight developers Thomas Fan and Kshiteej Kalambarkar for the extensive help in this important area!
  • Documentation improvements. We landed quite a few documentation improvements during the last half. To make it easy to find and use information within the torch.nn docs, I put together an opinionated module documentation style guide that is now followed by the majority of torch.nn modules. Work toward 100% compliance is continuing as we speak - if there is anything else you think would make our docs more usable, please reach out!
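
As a concrete illustration of the convolution consolidation above, the sketch below calls the consolidated op pair directly through torch.ops.aten and checks it against F.conv2d and autograd. It assumes a build where aten::convolution_backward exists (PyTorch 1.11 or later); the argument order follows the native function schemas, and calling the aten ops directly like this is for illustration only, not a recommended user-facing API.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8, requires_grad=True)
w = torch.randn(4, 3, 3, 3, requires_grad=True)
b = torch.randn(4, requires_grad=True)

# Convolution hyperparameters shared by the forward and backward ops.
stride, padding, dilation = [1, 1], [1, 1], [1, 1]
transposed, output_padding, groups = False, [0, 0], 1

# Forward: one op covers all backends; routing happens internally.
out = torch.ops.aten.convolution(
    x, w, b, stride, padding, dilation, transposed, output_padding, groups
)
ref = F.conv2d(x, w, b, stride=1, padding=1)

# Backward: one op returns (grad_input, grad_weight, grad_bias).
# bias_sizes tells it the bias shape; output_mask picks which grads to compute.
grad_out = torch.randn_like(out)
gi, gw, gb = torch.ops.aten.convolution_backward(
    grad_out, x, w, [b.shape[0]], stride, padding, dilation,
    transposed, output_padding, groups, [True, True, True]
)

# Reference gradients from autograd for the same upstream gradient.
ref_gi, ref_gw, ref_gb = torch.autograd.grad(ref, (x, w, b), grad_out)
```

Transposed convolution goes through the same pair by flipping the transposed flag, which is why projects like functorch only need to target these two ops.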
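
To illustrate the no-batch-dimension support, here is a minimal sketch assuming a recent release in which nn.Conv2d accepts an unbatched (C, H, W) input and functorch's vmap is available in core as torch.func.vmap. An unbatched call matches a batched call with a size-1 batch, and vmapping the module over per-sample inputs matches the ordinary batched call.

```python
import torch
import torch.nn as nn
from torch.func import vmap

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3, padding=1)

x = torch.randn(3, 8, 8)  # a single (C, H, W) sample, no batch dimension
y = conv(x)               # the output is unbatched too
print(y.shape)            # torch.Size([4, 8, 8])

# Same result as adding a size-1 batch dimension and removing it afterwards.
y_ref = conv(x.unsqueeze(0)).squeeze(0)

# Unbatched support is what lets vmap map a module over per-sample inputs:
# inside vmap, conv only ever sees (C, H, W) tensors.
xb = torch.randn(5, 3, 8, 8)
yb_vmap = vmap(conv)(xb)
yb_ref = conv(xb)
```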
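
The two CrossEntropyLoss features mentioned above are related: label smoothing with factor eps is equivalent to using a class-probability target that mixes the one-hot target with a uniform distribution, (1 - eps) * onehot + eps / C. A minimal sketch checking this, assuming PyTorch 1.10 or later:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, eps = 5, 0.1
logits = torch.randn(4, num_classes)
target = torch.tensor([0, 2, 1, 4])  # class-index targets

# Label smoothing applied to class-index targets.
smoothed = nn.CrossEntropyLoss(label_smoothing=eps)(logits, target)

# The same loss expressed with class-probability targets.
probs = (1 - eps) * F.one_hot(target, num_classes).float() + eps / num_classes
soft = nn.CrossEntropyLoss()(logits, probs)

print(torch.allclose(smoothed, soft))  # True
```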

H1 2022 Projects

A key theme for H1 within torch.nn is accelerating research, both by unblocking research-focused use cases and by streamlining their UX. In particular, we want to open up the custom tensor extension point by moving towards first-class support within torch.nn, provide composable building blocks for meta-learning (a key area where many argue JAX currently has an advantage), and explore a new API for invertible ops with the potential to unlock significant memory savings for RevNet-type models.

  • Broadened support for custom tensors. Support for custom tensors / tensor subclasses defined within Python is a big focus of H1 2022, motivated by their expanding use within functorch, MaskedTensor, etc.
    • Support for custom tensors as parameters. As mentioned in a previous post, custom tensors are currently not usable as module parameters or buffers, mostly due to internal implementation details. Developer bandwidth limits prevented this from being addressed last half, but it is a high-priority focus for the beginning of H1 2022. This is a popular request that should open up the custom tensor extension point to further use cases.
    • Composability of tensor subclasses. In H2 2021, Alban Desmaison put together a PoC exploring a mechanism for composing multiple tensor subclasses, ensuring reasonable semantics. Our plan for the half is to tie up the loose ends and provide this mechanism officially within core.
  • Improved meta-learning support. The higher library was created to facilitate meta-learning research. It provides (in a non-official way) building blocks for the functionality needed to implement meta-learning algorithms. Our plan is to instead provide these building blocks as official core APIs, streamlining the supported UX for meta-learning within PyTorch and freeing the higher project to focus on the implementation of particular meta-learning algorithms instead. The hope is that this makes meta-learning research more easily achievable within PyTorch.
    • “Stateless” / “functional” API for modules. Higher-order gradient calculation requires operating with multiple sets of parameters across “optimization timesteps”. Thus, it is useful to have a mechanism that decouples module logic from the set of parameters, allowing module logic to be run on a given set of parameters instead of operating mutably on its single, internal set (referred to here as a “stateless” or “functional” API for modules). Note that higher and functorch have historically had to provide their own versions of this feature. We have been working through the design of a core-provided API throughout the last half, and there is currently a beta version that we plan to promote to a stable feature this half.
    • Differentiable optimizers. This is another core requirement for computing higher-order gradients that higher currently provides. We’d like to provide our own set of differentiable optimizers within core as part of an official meta-learning UX. The plan is to produce a working PoC tutorial implementing a SotA meta-learning algorithm with all core-provided building blocks.
  • Support for invertible ops. Ops that are properly invertible don’t need to save input tensors between the forward and backward passes, as they can be recomputed from the outputs. Support for even larger models can be unlocked through such a tradeoff between computation time and memory usage. Our goal this half is to explore an official API for specifying computation of inverses in such a way that unlocks the memory usage savings for invertible models. We plan to produce an initially module-centric, out-of-core PoC demonstrating a clear memory savings advantage on a RevNet implementation.
  • Further documentation and testing improvements. Documentation and testing play a huge role in ensuring PyTorch provides a world-class user experience, so we plan to continue with a strong focus here in H1 2022.
    • Ideally, documentation provides a streamlined way for users to frame their ML use cases and research questions in terms of PyTorch framework concepts to quickly produce functioning code. How well does this process work now? What are some of the big obstacles to this process going as planned? We are already starting to explore the answers to these questions, with the goal of ensuring that up-to-date, high-quality information is readily available to users. Thoughts on the highest-impact improvements that can be made in this area, whether restricted to torch.nn or more broadly, are most welcome!
    • We will continue the process of clearing out old-style module tests, replacing them with the new-and-improved ModuleInfo-based tests. One key thing to address here is finishing out the database of ModuleInfo entries (tracked here) to fully cover the ~150 modules provided by torch.nn. Once this is done, the old-style dicts can be dropped entirely, cleaning up a huge amount of difficult-to-understand legacy test code.
    • Another obstacle to test understandability and maintenance is the monstrosity that is . It’s difficult to operate on due to its massive size, and arguably quite a bit of the coverage of the ad-hoc tests there should be replaced by ModuleInfo-generated tests. I don’t think we can fully get rid of ad-hoc tests while maintaining the same level of coverage, but they should at least be split into maintainable and well-organized chunks.
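
To make the stateless / functional module API concrete, here is a minimal sketch. At the time of writing the beta lives at torch.nn.utils.stateless.functional_call; the sketch assumes a recent release where the same idea is exposed as torch.func.functional_call, and the eventual stable API's location and signature may differ.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

torch.manual_seed(0)
model = nn.Linear(3, 2)
x = torch.randn(4, 3)

# An alternative set of parameters, e.g. produced by an inner-loop update.
new_params = {name: torch.randn_like(p) for name, p in model.named_parameters()}

# Run the module's forward logic on new_params instead of its own state;
# the parameters stored on the module are left untouched.
out = functional_call(model, new_params, (x,))

# What nn.Linear computes with the substituted weight and bias.
expected = x @ new_params["weight"].T + new_params["bias"]
```

Because the parameters are passed in rather than mutated, the same module can be evaluated on the parameter sets from several optimization timesteps within one autograd graph.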
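
The differentiable-optimizer requirement boils down to keeping the parameter update itself on the autograd graph. The sketch below is not the planned core API: differentiable_sgd_step is a hypothetical helper that makes one plain SGD step differentiable via create_graph=True, applied to a toy scalar problem where the meta-gradient can be checked by hand.

```python
import torch

def differentiable_sgd_step(params, loss, lr):
    """Hypothetical helper: one SGD step that stays on the autograd graph."""
    # create_graph=True records the gradient computation itself, so the
    # updated parameters remain differentiable w.r.t. the originals.
    grads = torch.autograd.grad(loss, params, create_graph=True)
    return [p - lr * g for p, g in zip(params, grads)]

# Toy problem: scalar weight w with inner loss (w * x - y) ** 2.
w = torch.tensor(1.0, requires_grad=True)
x, y, lr = torch.tensor(2.0), torch.tensor(3.0), 0.1

inner_loss = (w * x - y) ** 2
(w_new,) = differentiable_sgd_step([w], inner_loss, lr)

# Outer (meta) loss at the updated weight, backpropagated through the
# inner update to the initial weight.
outer_loss = (w_new * x - y) ** 2
(meta_grad,) = torch.autograd.grad(outer_loss, w)
```

By hand: the inner gradient is 2x(wx - y) = -4, so w_new = 1.4 and dw_new/dw = 1 - 2 * lr * x^2 = 0.2; the outer gradient at w_new is 2x(w_new * x - y) = -0.8, giving a meta-gradient of -0.8 * 0.2 = -0.16. Without create_graph=True the update would be treated as a constant and the result would be -0.8, dropping the second-order term.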
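
The property that invertible ops rely on shows up in a RevNet-style additive coupling block: the inputs can be recovered from the outputs even though the sub-networks f and g are not themselves invertible. This is a hypothetical illustration (ReversibleBlock is not a planned API) and only shows the inverse computation; realizing the memory savings would additionally require something like a custom torch.autograd.Function that discards the inputs after the forward pass and recomputes them from the outputs during backward.

```python
import torch
import torch.nn as nn

class ReversibleBlock(nn.Module):
    """Hypothetical RevNet-style additive coupling block (illustration only)."""

    def __init__(self, dim):
        super().__init__()
        # f and g can be arbitrary, non-invertible sub-networks.
        self.f = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
        self.g = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())

    def forward(self, x1, x2):
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return y1, y2

    def inverse(self, y1, y2):
        # Undo the coupling in reverse order using only the outputs.
        x2 = y2 - self.g(y1)
        x1 = y1 - self.f(x2)
        return x1, x2

torch.manual_seed(0)
block = ReversibleBlock(8)
x1, x2 = torch.randn(4, 8), torch.randn(4, 8)
with torch.no_grad():
    y1, y2 = block(x1, x2)
    r1, r2 = block.inverse(y1, y2)  # reconstructed inputs
```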

Look out for more in-depth treatments of the above topics in the near future as we begin to explore them in more detail! Note that we will happily accept contributions towards these ends, both in the form of ideas and code, if you’d like to help out :slight_smile: Thanks everyone for the collaboration and support throughout 2021, and I’m excited to continue the work in 2022!