An Efficient Way to Compute Per Sample Gradients

tl;dr We’ve added an API for computing efficient per-sample (or per-example) gradients, called ExpandedWeights, which looks like call_for_per_sample_grads(module)(input). Unlike other systems, this is implemented in core PyTorch, meaning that you don’t have to install any new packages or rewrite your existing model. We’re actively working with the Opacus team to incorporate ExpandedWeights as the per-sample-gradient mechanism for Opacus.

Background

Per-sample gradients are used in differential privacy (1), optimization research (2), and other emerging research use cases. Because of this, we’ve gotten many requests to support easy and efficient computation of per-sample gradients of PyTorch modules. Until now, you could either use the Opacus library, which is focused on extending PyTorch for differential privacy, or functorch.

ExpandedWeights is designed to work out-of-the-box with existing PyTorch models, unlike functorch and Opacus, which both require some extra user intervention. The Opacus library is aimed at users interested in differential privacy, making it difficult to extract per-sample gradients without training a model using the full Opacus system. With functorch, a user needs to rewrite their model to not include a batch dimension and make sure they aren’t using any imperative constructs, like mutation of Python data structures.

Introduction

This new prototype system, called ExpandedWeights, offers a lightweight approach that requires minimal changes to user code and is focused solely on computing per-sample gradients for supported PyTorch nn.Modules. It’s currently available in PyTorch’s nightly build or on fbcode master.

In practice, we’ve seen that this is at least 2x faster than Opacus’ current backward-hook implementation while using the same algorithms. The speedup comes from using autograd.Function, which avoids unnecessarily computing the batched gradients for the weights and lets us avoid holding onto unneeded inputs. We’re actively working with Opacus to have them adopt the ExpandedWeights system.
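
As a rough sketch of the underlying idea (not the exact upstreamed code), the Opacus-style algorithm for a Linear layer recovers every sample’s weight gradient from the saved input and the incoming gradient with a single einsum, instead of looping over the batch. The shapes and names below are purely illustrative:

import torch

# Sketch: per-sample gradients of a Linear layer y = x @ W.T + b.
# For sample n, dL/dW restricted to that sample is the outer product of
# grad_output[n] and x[n], so one einsum yields all N per-sample gradients.
N, d_in, d_out = 8, 5, 3
x = torch.randn(N, d_in)             # saved layer input, batch first
grad_output = torch.randn(N, d_out)  # gradient flowing back from the loss

grad_sample_weight = torch.einsum("no,ni->noi", grad_output, x)  # (N, d_out, d_in)
grad_sample_bias = grad_output                                   # (N, d_out)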

Below are two benchmarks, based on the ResNet benchmark example and the IMDB example, both from the Opacus repo. For each benchmark we used toy data and only computed and then discarded the per-sample gradients, since Opacus will still handle the differentially private training.

Note that all of these graphs only show batch sizes up to 128. In both examples, when we tried a batch size of 256, Opacus hit OOM errors while ExpandedWeights was able to run.

Usage

The usage mimics a normal model call in PyTorch. For a basic training step, a user would write:

from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

res = call_for_per_sample_grads(model)(input)
loss(res, labels).backward()
for param in model.parameters():
    param.grad_sample  # instead of .grad, the per-sample gradients are stored in the .grad_sample field
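
To make that concrete, here is a minimal end-to-end sketch. The model, data, and loss are just illustrative, and the exact prototype signature may differ across nightlies; the shapes in the comments follow the grad_sample convention of a leading batch dimension:

import torch
import torch.nn as nn
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

batch_size = 4
model = nn.Linear(10, 2)
input = torch.randn(batch_size, 10)
labels = torch.randint(0, 2, (batch_size,))
loss_fn = nn.CrossEntropyLoss(reduction="sum")  # sum reduction keeps per-sample grads unscaled

res = call_for_per_sample_grads(model)(input)
loss_fn(res, labels).backward()

for name, param in model.named_parameters():
    # grad_sample has an extra leading batch dimension compared to the parameter
    print(name, param.grad_sample.shape)  # weight: (4, 2, 10), bias: (4, 2)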

This currently works for Linear, Conv{1, 2, 3}D, Embedding, LayerNorm, InstanceNorm{1, 2, 3}D, and GroupNorm layers, by upstreaming Opacus’ algorithms to PyTorch and switching them to use autograd.Function instead of backward hooks. Additionally, it works with any module without weights, like activation functions or pooling layers. This covers all but one of the examples in the Opacus repo.
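
One way to see what “per-sample gradient” means here is to compare grad_sample against a naive loop that runs backward one sample at a time. The following is a hedged sketch with an illustrative model and loss, not a test from the PyTorch repo:

import torch
import torch.nn as nn
from torch.nn.utils._per_sample_grad import call_for_per_sample_grads

model = nn.Linear(10, 2)
input = torch.randn(4, 10)
targets = torch.randn(4, 2)
loss_fn = nn.MSELoss(reduction="sum")

# ExpandedWeights path: one forward/backward over the whole batch
out = call_for_per_sample_grads(model)(input)
loss_fn(out, targets).backward()
fast_grads = model.weight.grad_sample.detach().clone()

# Naive path: one forward/backward per sample, collecting .grad each time
naive_grads = []
for i in range(input.shape[0]):
    model.zero_grad()
    loss_fn(model(input[i:i + 1]), targets[i:i + 1]).backward()
    naive_grads.append(model.weight.grad.detach().clone())

torch.testing.assert_close(fast_grads, torch.stack(naive_grads))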

If ExpandedWeights tries to compute batched gradients for an unsupported operation, it will currently error and ask the user to file an issue requesting support. In particular, this currently errors on RNNs, which will require more extensive changes to PyTorch’s RNN function; those changes are our next focus for the ExpandedWeights system. If there are other layers we should prioritize, please share any requests.

When should I use this instead of functorch or Opacus?

Essentially, the focus of ExpandedWeights is to add native support for per-sample gradients of PyTorch modules so that users can easily build on it for any use case. The other options, functorch and Opacus, either require more changes to user code or target a more specific audience.

Functorch is much more powerful, letting a user arbitrarily compose transforms and use any PyTorch function instead of our limited set of modules. However, it requires a user to write their model without a batch dimension, make sure their model is functional, and run it without side effects, like arbitrary mutation of Python data structures. ExpandedWeights is much lighter weight, only requiring a couple of changed lines in the original model, and works with the arbitrary mutation that functorch doesn’t support.
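
For contrast, here is a rough sketch of what the functorch route looks like, following its documented make_functional/vmap/grad pattern (the model, loss, and data are illustrative, and API details may differ across functorch versions):

import torch
import torch.nn as nn
from functorch import make_functional, vmap, grad

model = nn.Linear(10, 2)
func_model, params = make_functional(model)

def compute_loss(params, sample, target):
    # functorch requires writing the computation for a single, unbatched sample
    out = func_model(params, sample.unsqueeze(0)).squeeze(0)
    return nn.functional.mse_loss(out, target)

inputs = torch.randn(4, 10)
targets = torch.randn(4, 2)

# vmap over the batch dimension of inputs/targets; params are shared (in_dims=None)
per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, inputs, targets)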

Opacus is a much more focused library, enabling researchers to quickly experiment with incorporating differential privacy into their workflows. ExpandedWeights is a more generic solution, focusing only on computing per-sample gradients and not on the DP-specific parts of the library. Because of the speedups and clear error messaging that ExpandedWeights provides, we are actively working with the Opacus team to use ExpandedWeights for Opacus’ per-sample-gradient computations.

Future Work

  • RNN support for ExpandedWeights is a top priority that our Opacus collaborators have already asked about. To do this, we need RNN to be decomposable into a couple of linear calls, which we’ll describe in an upcoming RFC. With this, we’ll also add the ability for ExpandedWeights to handle inputs where the batch dimension is the second dimension, as happens with RNNs.

  • We also plan to introduce stronger verification of the module passed in. Opacus has already pointed out a couple of places where our current verification lets through unsupported cross-batch interaction.

  • Finally, we are looking longer term at automatically detecting allowable cross-batch interaction using another tensor subclass. Currently, Opacus has flags that ask the user whether the first dimension is the batch dimension and what type of loss reduction is used. With the current ExpandedWeights design, we would need the same user intervention to get correct per-sample gradients. We’re exploring adding another tensor subclass for the inputs that would let us detect these behaviors automatically.
