Torch.nn September Update

Hey everyone! Here’s an update with some torch.nn-related highlights:

Long-standing issues

  • Notably, support for label smoothing in nn.CrossEntropyLoss has finally landed! This was a popular request that was over 3 years old, and I believe the two-prong approach we went with gives a good mix of performance, flexibility, and convenience. To summarize:
    • nn.CrossEntropyLoss now supports class probability targets (AKA “soft” targets) in addition to class index targets. This allows for full flexibility in the target distribution, supporting both manually-applied label smoothing as well as alternate soft label techniques like mixup / cutmix within the core loss function.
    • nn.CrossEntropyLoss now supports a new label_smoothing kwarg that applies canonical label smoothing as defined in the original paper. It works with both class index targets and class probability targets.
      • If you just want canonical label smoothing, use the label_smoothing kwarg with class index targets. This performs better than label smoothing on class probability targets because it avoids materializing the full target distribution tensor.
      • If you need more flexibility in defining your target distribution, but would still like the labels to be smoothed, use class probability targets with the label_smoothing kwarg. This allows for easily applying mixup + canonical label smoothing, for example.
    • The new label smoothing functionality is already in use by torchvision (see here) for their “batteries included” initiative!
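
To make the two options above concrete, here is a minimal sketch, assuming a PyTorch release that includes the label_smoothing kwarg and class probability targets. The tensor shapes, the eps value, and the simplified "mixup" (only the targets are mixed, not the inputs) are illustrative choices, not taken from the original post.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_classes, eps = 5, 0.1
logits = torch.randn(8, num_classes)
index_targets = torch.randint(num_classes, (8,))

# Option 1: class index targets + label_smoothing kwarg (no dense target tensor).
loss_kwarg = nn.CrossEntropyLoss(label_smoothing=eps)(logits, index_targets)

# Option 2: manually smoothed class probability ("soft") targets.
one_hot = F.one_hot(index_targets, num_classes).float()
soft_targets = one_hot * (1 - eps) + eps / num_classes
loss_manual = nn.CrossEntropyLoss()(logits, soft_targets)

# Option 3: custom probability targets (mixed labels) + label_smoothing kwarg.
lam = 0.7  # illustrative mixing coefficient
perm = torch.randperm(8)
mixed_targets = lam * one_hot + (1 - lam) * one_hot[perm]
loss_mixed = nn.CrossEntropyLoss(label_smoothing=eps)(logits, mixed_targets)

# Reference for option 3: smooth the mixed targets by hand.
mixed_smoothed = mixed_targets * (1 - eps) + eps / num_classes
loss_mixed_manual = nn.CrossEntropyLoss()(logits, mixed_smoothed)
```

With the default mean reduction and no class weights, options 1 and 2 should agree up to floating-point error; option 1 simply avoids building the dense target tensor.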

Convolution consolidation

  • As mentioned in the previous post, I’m looking this half into consolidating the 85 convolution ops into a more manageable number to make things easier for vmap and other functionality that targets all ops. If you’re interested, I’ve put together a doc classifying / mapping out the various convolution ops as well as the plan for consolidation: [WIP] Convolution Consolidation. The first steps towards this are in progress now, but it’s never too late to leave feedback if you have questions or concerns.
  • TL;DR: The goal I’m targeting is to have a single general convolution op with a single backward implementation that subsumes the current backend-specific convolution ops with their individual backward implementations. Most of the high-level API op structure will remain the same, but I believe that only needing to deal with a single general forward / backward implementation provides maximum impact for projects like vmap.

Improved module testing

  • The initial version of ModuleInfo has landed and we are in the process of transitioning test coverage over from the old-style tests. Some highlights here for those interested in running tests over the set of modules:
    • Following in OpInfo's footsteps, the following now exist and can be used for your testing needs:
      • module_db - the database of ModuleInfo entries (analogous to op_db).
      • @modules decorator - used for specifying the set of modules a test should be run over (analogous to @ops).
      • New-style ModuleInfo tests in test/ (analogous to the OpInfo-based tests in test/).
      • ModuleInput - analogous to SampleInput, but encapsulates inputs to both the module constructor and its forward pass, as these are generally coupled.
    • Similar to OpInfo, a ModuleInfo entry in module_db defines metadata for a given module + a function for generating a set of samples to be used in generated tests. An example demonstrating usage of all of the above can be found here.
    • Note that we don’t yet have coverage over all modules within torch.nn, and we are focused on porting test coverage from the old-style testing before filling the module coverage gap. This helps keep test running time low for the time being, as currently tests are duplicated across old-style and ModuleInfo-style tests.
  • There has been some discussion around generating OpInfo entries for torch.nn.functional forms from their corresponding ModuleInfo entries. While this sounds great in theory, the functionality tested for modules is in practice fairly different from that tested by OpInfo. In particular, module testing tends to be more concerned with validating module state through the various mutations supported by modules (e.g. moving parameters across devices, changing parameter dtypes, serializing modules, etc.). Additionally, the relationship between modules and their functional forms is fairly ad-hoc / not consistent enough to easily do ModuleInfo-to-OpInfo generation without a lot of edge case handling / weird UX. ModuleInfo would need to take on lots of extra stuff that would only be used within OpInfo generation, complicating its interface.
    • With this in mind, it seemed to make sense to simply write OpInfo entries directly for torch.nn.functional forms, as they are closer to ops than modules.
    • ModuleInfo entries should be written for the module forms, with the focus on module-centric validation outside of the scope of the test coverage provided by OpInfo entries for the functional forms.
  • A mechanism for test parametrization is now available for general usage within PyTorch core. If you’ve ever used pytest.mark.parametrize, and wanted something like it for PyTorch core, you may enjoy this new functionality. Essentially, it generates tests across a set of inputs, making it easier to identify exactly which case failed. Since the usage of pytest within PyTorch core was previously found to be incompatible with the existing device type test class infrastructure, I put together a mechanism that provides test parametrization with a similar interface to pytest’s and does work with our device-specific machinery. Some examples:
from torch.testing._internal.common_utils import TestCase, parametrize, \
    instantiate_parametrized_tests

class TestFoo(TestCase):
    # Generates a test per x value in range [0, 5).
    @parametrize("x", range(5))
    def test_bar(self, x):
        ...  # test body goes here

    # Stacking parametrize decorators generates the product of inputs
    # e.g. ('a', False), ('a', True), ('b', False), ('b', True), ('c', False),
    # ('c', True).
    @parametrize("x", ['a', 'b', 'c'])
    @parametrize("y", [False, True])
    def test_baz(self, x, y):
        ...  # test body goes here

    # Explicit tuples can be used when the full product is not desired.
    @parametrize("x,y", [(1, 'a'), (2, 'b'), (3, 'c')])
    def test_two(self, x, y):
        ...  # test body goes here

# Required to generate the parametrized tests.
# Alternatively, for device / dtype specific tests, the usual
# instantiate_device_type_tests() works with parametrization
# so there is no need to call both.
instantiate_parametrized_tests(TestFoo)
  • In addition to the basic usage above, there’s also functionality for skipping / expecting a specific input to fail, customizing the names of the generated tests, and more. Examples in the docs demonstrate some of these usages. More documentation for all the above is in the works and will be accessible from the dev wiki!
  • As anyone who has dealt with it before knows, the main torch.nn test file can be a monstrosity to deal with. It’s large enough that tools handling it regularly break, including GitHub’s code viewer and vim’s syntax highlighting. Further, it contains a huge, relatively unorganized mess of tests for both specific modules and general torch.nn module behavior. Long term, it needs to be split up, and there is some discussion about how to best do that here.

No-batch-dim support

  • Progress on supporting inputs without batch dimensions (e.g. for interoperability with vmap) is moving along steadily. There’s not much beyond that to report here, but if you’re interested, check out the current status in the issue. Much thanks to Thomas and kshitij12345 from Quansight for attacking this :)
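
For a sense of what no-batch-dim support means in practice, here is a small sketch. It assumes a PyTorch release in which nn.Conv2d already accepts unbatched (C, H, W) inputs; the module choice and shapes are illustrative, and coverage for other modules depends on the status tracked in the issue.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
conv = nn.Conv2d(in_channels=3, out_channels=4, kernel_size=3)

# A single sample with no batch dimension: (C, H, W).
unbatched = torch.randn(3, 8, 8)
out_unbatched = conv(unbatched)

# The equivalent call that adds and then removes a batch dimension by hand.
out_batched = conv(unbatched.unsqueeze(0)).squeeze(0)
```

Both calls should produce the same (C_out, H_out, W_out) result; the unbatched form just removes the unsqueeze / squeeze boilerplate.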

Other Work

In fielding internal and external requests, some limitations with torch.nn have come to light that are ripe to be addressed in the near future. Below, some of these are listed. If any of these affect you and you’d like to see them addressed, please indicate your support within the linked issues or in the comments. This helps us prioritize development efforts effectively. Thanks!

  • It’s been requested repeatedly to provide more flexibility in defining what is part of a Module’s state_dict and thus serialized for the module. It is now possible to include arbitrary state within a state_dict by overriding the pair of Module.get_extra_state and Module.set_extra_state for your custom module. Examples can be found here and below:
from torch import nn

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        # Illustrative extra state: neither a parameter nor a buffer.
        self.foo = 5

    def get_extra_state(self):
        # Anything returned here is included as part of the module's
        # state_dict (i.e. it is serialized).
        return self.foo

    def set_extra_state(self, extra_state):
        # When loading the module via load_state_dict(), if extra state is
        # encountered, this method will be called.
        self.foo = extra_state
  • Currently, custom Tensor-like objects cannot be fully considered first-class citizens as they are not fully interchangeable with standard Tensors. One limitation here that has been called out previously is that custom Tensors are not usable as parameters. As more potential use cases for custom Tensors arise, it seems to be important that we change this. Please add your thoughts to the issue here if you’d like to see this happen.
  • Stas Bekman from HuggingFace recently brought up a difficulty in dealing with large models: Module.load_state_dict() generally requires two sets of the model parameters to be in memory at once (one within the module instance and one within the state_dict being loaded from). This is an obstacle to loading models large enough that 2x the memory won’t fit. We are looking into solutions for this involving meta tensors to avoid the extra memory usage. Please add your support to the issue if this impacts you or you’d like to see this addressed.
  • A recent request for automated profiling for Module.forward brought to light the fact that our current set of hooks isn’t quite enough. Full pre-backward hooks would allow for viewing / modifying gradients before the backward pass is called, filling in a gap here and allowing for something like the profiling described above to be implemented.
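
As a follow-up to the get_extra_state / set_extra_state item above, this sketch round-trips extra state through state_dict() and load_state_dict(). The StepCounter module and its steps attribute are hypothetical names used only for illustration; the "_extra_state" key is what current PyTorch releases appear to use for a top-level module, so treat it as an observed detail rather than a guarantee.

```python
import torch
import torch.nn as nn

class StepCounter(nn.Module):  # hypothetical module for illustration
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.steps = 0  # plain Python attribute, not a parameter or buffer

    def get_extra_state(self):
        # Returned object is stored in the state_dict alongside parameters.
        return {"steps": self.steps}

    def set_extra_state(self, extra_state):
        # Called by load_state_dict() when extra state is present.
        self.steps = extra_state["steps"]

src = StepCounter()
src.steps = 42
state = src.state_dict()

dst = StepCounter()
dst.load_state_dict(state)
```

After loading, dst.steps matches src.steps, and the linear layer's weights are copied as usual.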