Reversible Networks in PyTorch

There has been some discussion around adding a reversible network API to PyTorch core in issues/23756. This note will cover some of the learnings that came out of looking into RevNets.

Contents:

  1. Background and some examples of reversible and invertible networks
  2. Examples of reversible and invertible networks libraries on top of PyTorch that already exist in the ecosystem
  3. Proposed changes to core that could improve support for reversible networks, and discussion of whether they should exist in core

This note is titled “reversible” networks, but we will also talk a little about “invertible” networks. The two terms are sometimes conflated to both mean “being able to compute the inputs from the outputs”. In this note we’ll make the distinction between the two and define invertibility to mean the stronger notion of bijectivity, and “reversibility” to mean “you are able to compute the inputs from the outputs”. This means that all invertible networks are reversible, but not all reversible networks are invertible.

(Note that this investigation was done early this half, so there may have been new progress made in the space that is not covered here.)

Thanks to Alban and others who gave valuable feedback and insights on earlier drafts.

Why care about reversible networks?

The main benefit of having a reversible network is reduced memory usage during training. The key idea is that if you can recompute a layer's input activations from its output activations, you can do so on the fly during the backward pass and avoid saving those tensors for backward. See Alban’s Colab for a demonstration of this implemented simply using custom autograd Functions.
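
To make the mechanism concrete, here is a minimal sketch (not the Colab code itself) of an additive-coupling block written as a custom autograd Function: only the outputs are saved for backward, and the inputs are recomputed from them during the backward pass. To keep it short, f and g are assumed to be parameter-free callables; propagating gradients to parameters inside f and g is where a real implementation (e.g. the Colab or the libraries surveyed below) has to do more work.

```python
import torch

class AdditiveCouplingFn(torch.autograd.Function):
    # y1 = x1 + f(x2); y2 = x2 + g(y1)
    # Only the outputs are saved for backward; the inputs (and the activations
    # inside f and g) are recomputed during the backward pass.
    # Simplification: f and g are assumed to be parameter-free callables.

    @staticmethod
    def forward(ctx, x1, x2, f, g):
        ctx.f, ctx.g = f, g
        with torch.no_grad():
            y1 = x1 + f(x2)
            y2 = x2 + g(y1)
        ctx.save_for_backward(y1, y2)
        return y1, y2

    @staticmethod
    def backward(ctx, dy1, dy2):
        f, g = ctx.f, ctx.g
        y1, y2 = ctx.saved_tensors
        # Reverse step 1: recover x2 = y2 - g(y1), re-running g under grad mode.
        with torch.enable_grad():
            y1_ = y1.detach().requires_grad_()
            gy1 = g(y1_)
        x2 = y2 - gy1.detach()
        # Total gradient w.r.t. y1: the direct contribution plus the path through g.
        dy1_total = dy1 + torch.autograd.grad(gy1, y1_, dy2)[0]
        # Reverse step 2: re-run f on the recovered x2 to backprop through it.
        # (x1 itself would be recovered as y1 - f(x2) if an earlier block needed it.)
        with torch.enable_grad():
            x2_ = x2.detach().requires_grad_()
            fx2 = f(x2_)
        dx1 = dy1_total
        dx2 = dy2 + torch.autograd.grad(fx2, x2_, dy1_total)[0]
        return dx1, dx2, None, None


# Parameter-free usage, so the simplification above holds:
x1 = torch.randn(8, 16, requires_grad=True)
x2 = torch.randn(8, 16, requires_grad=True)
y1, y2 = AdditiveCouplingFn.apply(x1, x2, torch.sin, torch.tanh)
(y1.sum() + y2.sum()).backward()
```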

Examples:

  • RevNet is a variant of ResNet that borrows the coupling block idea from NICE. Coupling blocks wrap an arbitrary function to add a residual connection and enable inputs to be recomputed from the outputs. This allows fewer activations to be saved during training, reducing memory usage. Full reversibility is not a requirement for getting memory savings. This is one of the earliest examples of reversible networks applied to memory reduction, and it popularized the use of coupling layers for memory reduction purposes. The following two examples build off the idea from RevNet.
  • Momentum Residual Networks provides a drop-in way to reduce memory consumption. Unlike RevNet you are able to reuse the same weights. MomentumNet has also been shown to be more expressive than RevNet.
  • Reformer: Reversible architecture applied to the transformer

As mentioned above, the stricter notion of “reversibility” is “invertibility”. Invertibility is one of the requirements for training normalizing flows (NFs), which are used for density estimation and generative modeling. In addition to requiring that transforms be bijective, NFs require that 1) the inverse is differentiable and 2) the log-absolute-determinant of the Jacobian (LADJ) is tractable and differentiable. For most functions the LADJ is not easy to compute, but coupling (see appendix) and other ways of restricting one’s function can make this value easier to compute. Some models like i-ResNet choose not to restrict the function and rely on approximating this value instead.

Examples:

  • NICE: this inspired RevNet’s coupling layers
  • Real-NVP: compared to NICE, we want our transformations to scale the space instead of having unit determinant
  • i-ResNet: shows the viability of non-coupling invertible nets

Survey of reversible network libraries

We give a couple of examples of reversible and invertible network libraries, discuss which use cases they cater to, and note some patterns that emerge. See the appendix for a 2-minute introduction to normalizing flows.

RevLib:

  • Very coupling-oriented API
  • For memory reduction

FlowTorch

  • Library for learning and sampling from complex distributions that was originally part of Pyro
  • Does not provide constructions for memory reduction
  • It extends the torch.distributions module from PyTorch, which already has some pieces (e.g. invertible transforms) that would seem useful for reversible nets. Mirroring how a normalizing flow is built, torch.distributions also provides TransformedDistribution, an extension of the Distribution class that applies a sequence of transforms to a base distribution (see the snippet after this list).
  • Also provides helpers for coupling, but like RevLib, coupling is baked into how the entire library is used.
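
As a reference point, here is what torch.distributions already supports today: composing fixed (non-trainable) transforms on top of a base distribution, where sampling pushes samples through the transforms and log_prob uses their inverses and LADJs via the change of variables rule. Making these transforms trainable is essentially the gap flowtorch fills.

```python
import torch
import torch.distributions as D

base = D.Normal(torch.zeros(2), torch.ones(2))

# TransformedDistribution applies the transforms in order when sampling and
# uses their inverse + log_abs_det_jacobian when evaluating log_prob.
flow = D.TransformedDistribution(
    base,
    [D.AffineTransform(loc=1.0, scale=2.0), D.SigmoidTransform()],
)

x = flow.sample((5,))       # base samples pushed through the transforms
log_px = flow.log_prob(x)   # transforms inverted, LADJ terms accumulated
```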

MemCNN

  • For memory reduction
  • Coupling logic is factored out a little more cleanly compared to RevLib

Discussion:

  • On coupling function support (see the Appendix for a quick intro to coupling): almost all libraries have converged on supporting this technique, since many recent models (e.g. RevNet, Reformer) take such an approach. It is not clear how relevant this approach will remain, as more recent papers like i-ResNet have shown the viability of non-coupling-based approaches that are just as expressive.
  • On the implementations: RevLib and MemCNN don’t do anything hacky or unexpected to achieve memory reduction. It is implemented very similarly to how checkpointing is done in core.

Some potential next steps

These are some potential changes to core, e.g. changes that would upstream a useful construct that many libraries have converged on, or that would simplify the implementation of constructs that libraries need but currently implement in a hacky way.

  • Approaches that improve normalizing flows support:
    • (1) Currently, torch.distributions has the notion of transforms, which conform to an API providing .inv, .log_abs_det_jacobian(), etc. They are composable too, but not trainable. One idea is to upstream trainable transforms from flowtorch (formerly part of Pyro). This would basically mean upstreaming flowtorch, since we wouldn’t want to introduce trainable transforms without providing some of them ourselves.
    • (2) To take idea (1) a step further, we could potentially unify torch.distributions transforms and aten operations to deduplicate code. It is unclear whether the payoff of deduplicating that logic is worth the amount of effort.
    • Discussion:
      • A possible reason this shouldn’t be in core is that the flowtorch library and the space of normalizing flows research are still evolving quickly; we don’t want to prematurely bake things into core that may change in 1-2 years.
  • Approach that improves support for the memory reduction use case:
    • Introduce a reversible sequential container (see Alban’s Colab). This canonicalizes the .inverse() API at the module level. Any network whose modules implement .inverse() would trivially get some reduction in memory use during training (see the sketch after this list of approaches).
    • Discussion:
      • The argument for bringing this to core is that it is analogous to checkpointing: checkpointing can be implemented using existing PyTorch APIs in Python (custom autograd Function OR saved tensor hooks), but it belongs in core because there are tricky interactions with existing PyTorch features to handle. This line of reasoning would also apply to a reversible sequential container, which is implemented in a similar fashion.
      • Bringing a reversible sequential container to core would commit us to adopting a .inverse() API at the module level. Alternatively it could be at the aten level (see open questions below).
  • An approach that can potentially improve support for both use cases (from Alban):
    • “PyTorch should support an API that generalizes the concept of computing quantities like forward grad and LADJ alongside the forward pass, and quantities like inverse alongside the backward pass (similar to BackPACK).”
    • My notes on the topic are that:
      • Flowtorch currently uses tensor subclassing (__torch_function__) to compute the LADJ along with the forward. Perhaps that is a clean enough solution?
      • There is a complexity in that the different quantities we want to compute may depend on each other, e.g. if we want to compute both the gradients and the inverse, the vjp function would need access to the output of the inverse. There would also need to be a way to disable saving tensors for backward in the first place (via saved tensor hooks?).
    • This path seems promising but there are still many questions. A small prototype would probably be a good next step.
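
To make the module-level .inverse() idea above concrete, here is a rough sketch of what such a container could look like. The names are hypothetical, and the autograd machinery that would actually produce the memory savings (recomputing each child’s input from its output during backward, as in the coupling Function sketch earlier) is deliberately omitted; only the API shape is shown.

```python
import torch
from torch import nn

class InvertibleLeakyReLU(nn.Module):
    # Example of a child module that knows how to invert itself.
    def __init__(self, slope=0.5):
        super().__init__()
        self.slope = slope

    def forward(self, x):
        return torch.where(x >= 0, x, self.slope * x)

    def inverse(self, y):
        return torch.where(y >= 0, y, y / self.slope)


class ReversibleSequential(nn.Sequential):
    # Hypothetical container: forward() behaves like nn.Sequential, and
    # inverse() walks the children in reverse order. A core implementation
    # would additionally avoid saving each child's input for backward and
    # recompute it via inverse() instead.
    def inverse(self, y):
        for module in reversed(list(self)):
            y = module.inverse(y)
        return y


net = ReversibleSequential(InvertibleLeakyReLU(0.5), InvertibleLeakyReLU(0.25))
x = torch.randn(4, 3)
assert torch.allclose(net.inverse(net(x)), x)
```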

Related (open) questions:

  • If we’d like to bring a .inverse() API to core, should it work at the nn.functional level, module level, or aten level? A .inverse() API at the module level seems to require the least amount of changes to core.
  • Since reversibility is a weaker notion than invertibility/bijectivity, can we build a construct like the reversible sequential container such that it also brings benefit for the normalizing flows use case? Or maybe it is unproductive to talk about both reversible and invertible networks together in the same context at all?

Conclusion:

This note gives some background on reversible networks and discusses how PyTorch could better support their implementation. We conclude that (1) though there may be tweaks to the core API that could improve reversible networks support, they come with their own trade-offs and may require further investigation, and (2) existing libraries seem to be mostly well served by existing PyTorch APIs. We currently do not have plans to make API changes to core in the near term specifically to simplify the implementation of reversible networks in PyTorch.

Appendix: more background

Normalizing flows (NF):

  • TLDR: use NFs to flexibly build complex distributions that you can train and sample from
  • What problem does it solve:
    • Normalizing flows is a technique initially applied to density estimation. (More recently normalizing flows has been applied to generative modeling, e.g. image generation.)
    • In density estimation, one fits a parametrized known distribution by maximizing the likelihood of observed data. However, it is often unclear what family of distributions should be used to fit the data.
    • With normalizing flows, one begins with a simple distribution but can freely compose a series of transformations to build more complex distributions using the change of variables rule.
  • More concretely, how do I train it (i.e., compute/maximize log probability)? or sample from it?
    • To sample from a NF, begin with some simple distribution P (usually Gaussian) with pdf f, then apply a series of smooth bijective maps t1, t2, … to it.
    • To compute the log probability of a NF we apply the change of variables theorem for distributions. If we let x denote the inverse of the composition of these transformations, the probability density function (pdf) of the “data” distribution g is given by g(y) = f(x(y)) · |det Jx(y)|, i.e. log g(y) = log f(x(y)) + log |det Jx(y)|, where Jx is the Jacobian of x.

  • One can also define the inverses the other way around, depending on which way is easier to compute.
  • A new problem:
    • Now the issue is finding smooth bijective maps that have a tractable Jacobian determinant and inverse, but that are also expressive enough to model our data

Coupling layer:

  • TLDR: Coupling layers solve the problem of engineering layers that 1) can be very expressive (f can be anything!) AND 2) have a tractable inverse and Jacobian determinant (for example, for use with NFs above).
  • What:
    • Given an arbitrary function f (call it the coupling function) and some special function g which we call the coupling law, split the input x into (x1, x2) and compute y1 = x1, y2 = g(x2, f(x1)). Because y1 = x1 passes through unchanged, x2 can be recovered by inverting the coupling law given f(y1), so the layer is invertible no matter what f is (a minimal sketch follows this list).
    • There are restrictions on the coupling law g: it must be invertible in its first argument and have a tractable Jacobian determinant (e.g., its Jacobian is triangular).
  • How?
    • Observation: the determinant of a triangular matrix is the product of its diagonal entries, so a triangular Jacobian has a cheap determinant
    • The Jacobian is triangular as long as the coupling law is triangular. If the coupling law is simple addition (i.e., we have additive coupling), then the determinant of the Jacobian is simply 1.
  • Problems:
    • Forces you to have a residual connection
    • Coupling makes design choices that are not well understood yet, e.g. how to choose dimensions to partition. There is some discussion of this in the i-ResNet paper.
  • Further reading: NICE/Real-NVP papers
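
A minimal sketch of a NICE-style additive coupling layer as a module (illustrative names only, assuming an even feature dimension), showing why both the inverse and the LADJ are cheap even though f is an arbitrary network:

```python
import torch
from torch import nn

class AdditiveCoupling(nn.Module):
    # Split x into (x1, x2); y1 = x1, y2 = x2 + f(x1).
    # f (the coupling function) can be an arbitrary network.
    def __init__(self, dim, hidden=64):
        super().__init__()
        half = dim // 2  # assumes dim is even
        self.f = nn.Sequential(
            nn.Linear(half, hidden), nn.ReLU(), nn.Linear(hidden, half)
        )

    def forward(self, x):
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([x1, x2 + self.f(x1)], dim=-1)

    def inverse(self, y):
        # y1 = x1 passes through unchanged, so f(y1) can be recomputed exactly.
        y1, y2 = y.chunk(2, dim=-1)
        return torch.cat([y1, y2 - self.f(y1)], dim=-1)

    def log_abs_det_jacobian(self, x):
        # The Jacobian is triangular with ones on the diagonal, so its
        # determinant is 1 and the LADJ is identically zero.
        return x.new_zeros(x.shape[:-1])


layer = AdditiveCoupling(dim=8)
x = torch.randn(4, 8)
assert torch.allclose(layer.inverse(layer(x)), x, atol=1e-6)
```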

How come parametrisations are not considered in “potential next steps”?
Note that we already designed them with an inverse.

Adding to this, I think that torch.distributions could be improved by using parametrisations. This would make them trainable, and would make a stronger point for the first of the proposals in “potential next steps”.

Ahh, that is a great point that I missed. Using parametrizations as building blocks seems like a great way to support normalizing flow models indeed.