Implementing Generalized Backpropagation

Hello, I’m a student at ETH Zürich. Kevin Du published a paper last year proposing Generalizing Backpropagation for Gradient-Based Interpretability. A colleague and I are currently researching how we would implement this in PyTorch. There’s also a YouTube video explaining the paper.

I read the contribution guide and we know that the first step would be to submit an RFC. However, since this would be our first contribution to PyTorch and a bigger change that touches a lot of different parts of PyTorch, we wanted to open a discussion before we submit the RFC.


In backpropagation, we basically have a summation and a multiplication of gradients. The summation can be found in torch/csrc/autograd/functions/accumulate_grad.h::accumulateGrad(), while the multiplication can be found in the apply() of the generated autograd functions.

A user would now have the ability to create arbitrary semirings by providing both of those operations. The summation operation would be passed to accumulateGrad() and the multiplication operation would be passed to the autograd function, replacing the currently implemented summation and multiplication.
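To make the proposal concrete, here is a minimal, framework-free sketch of backprop parametrized by a semiring. The names `Semiring`, `SUM_PROD`, `MAX_PROD` and `generalized_grad` are hypothetical and not part of PyTorch, and the graph is given as an explicit edge list with hand-written local derivatives.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple


@dataclass(frozen=True)
class Semiring:
    """Hypothetical container for the two operations used by backprop."""
    plus: Callable[[float, float], float]   # aggregates contributions from different paths
    times: Callable[[float, float], float]  # chains local derivatives along one path
    zero: float                             # identity element of `plus`
    one: float                              # identity element of `times`


SUM_PROD = Semiring(lambda a, b: a + b, lambda a, b: a * b, 0.0, 1.0)
MAX_PROD = Semiring(max, lambda a, b: a * b, float("-inf"), 1.0)

# Edges in reverse topological order: (child, parent, d child / d parent).
Edge = Tuple[str, str, float]


def generalized_grad(edges: List[Edge], output: str, ring: Semiring) -> Dict[str, float]:
    """Propagate the backward message from `output` using `ring`."""
    grads: Dict[str, float] = {output: ring.one}
    for child, parent, local in edges:
        message = ring.times(grads[child], local)
        grads[parent] = ring.plus(grads.get(parent, ring.zero), message)
    return grads


# y = x**2 + x at x = 1: the path through x**2 has local derivative 2,
# the direct path has local derivative 1.
edges = [("y", "sq", 1.0), ("y", "x", 1.0), ("sq", "x", 2.0)]
sum_prod = generalized_grad(edges, "y", SUM_PROD)["x"]  # 1 + 2
max_prod = generalized_grad(edges, "y", MAX_PROD)["x"]  # max(1, 2)
print(sum_prod, max_prod)
```

Only `plus` and `times` differ between the two runs; the traversal itself is unchanged, which is the property the proposal relies on.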

I’d like to get input to the following questions:

  1. Can we somehow implement this without actually changing the implementation of backprop? I looked at hooks but unless I misunderstood something, I can’t use them to achieve this goal. Please correct me if I am wrong.
  2. Assuming we have to touch the backpropagation implementation, what would be a good way of letting the user define such semiring operations? I know that PyTorch uses precompiled functions but I’m not familiar with it enough at the moment to really judge if we can reuse that. One of the reasons why we would like to implement it into PyTorch is performance. The current implementation is based on a very basic DL framework that is designed for teaching purposes.


Thank you very much for your interest and reaching out.

I would have a couple of high level questions for you here (sorry if that’s not really directly linked to what you’re asking for):

  • To what extent does this formulation extend to other quantities that can be computed alongside backprop, like gradient statistics or Hessian approximations as done in [1912.10985] BackPACK: Packing more into backprop? There are other projects doing this, but this one is the first that comes to mind.
  • Similarly, another extension that I thought about (but never explored) for your idea is to include “invertible networks”. To get there, I think of backprop as being parametrized by 3 things: what a chain rule op does (one step of vector-Jacobian product for backprop, computing the inverse of the forward for an invertible network), how you aggregate gradients from different paths (via sum for backprop; any of the paths’ values for invertible, where the next step should make sure only one such path is used anyway) and (optionally) a path-finding algorithm (find all paths for backprop, find the shortest (in flops?) path for an invertible network). I wonder now if this could be expressed as a ring? Or does the path selection break the analogy?

In terms of implementation, I think there are 3 main pieces: the ring space, the + and the x.
Let’s go from simpler to harder:

The + is indeed relatively simple to update. It happens here for the ones happening inside the graph and here for the ones “outside” the graph when working with the .grad field.
I’m sure we could make these configurable without too much trouble.

The x is simple to change but a lot of work. Our backprop formulas in derivatives.yaml each encode one step of the backprop algorithm, meaning that the x operator is built in.
So the simple solution is to rewrite all the formulas to use another x operator and swap them out.
This is technically feasible, but it requires rewriting all of the formulas, which is… a lot of work.

The ring space change (needed for entropy for sure but in a simpler way the other ones as well) is an interesting one.
We actually have a concept of “Tensor subclass” that you can implement and define whatever semantics you want on it. For example a complex Tensor in that world can be implemented like this (of course this is missing the actual implementation for most ops). We have some more detailed docs about subclasses in here (they are not the best, but it is a tricky subject to document).
Using this, you can actually implement a Tensor subclass for which all computation will behave according to the ring you want. And then you can backprop that Tensor to change the “usual computations” to the ones you want.
Of course you want to be careful that the new semantics are only used to compute the chain rule and not the Jacobian.

A minimal example of this is:

import torch
from torch import autograd
from torch.utils._pytree import tree_map

class MaxProdTensor(torch.Tensor):
    def __new__(cls, t):
        # We would need to define autograd on this ring, inception!
        assert t.requires_grad == False
        res = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            t.size(),
            dtype=t.dtype,
            device=t.device,
            requires_grad=False,
        )
        return res

    def __init__(self, t):
        self.elem = t

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(t):
            if isinstance(t, cls):
                return t.elem
            else:
                return t

        def wrap(t):
            if isinstance(t, torch.Tensor) and not isinstance(t, cls):
                return cls(t)
            else:
                return t

        def run_with_usual_semantic():
            # Unpack as plain Tensor, run the normal op and repack
            return tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))

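        # Ops that keep their usual meaning on this ring. The entries below are
        # an assumption (the original list is not shown); extend as needed.
        func_with_same_semantic = [
            "mul",
            "expand",
            "detach",
        ]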

        if func.overloadpacket.__name__ in func_with_same_semantic:
            return run_with_usual_semantic()
        elif func is torch.ops.aten.add.Tensor:
            # Add is the same as Max for the usual semantic
            func = torch.ops.aten.maximum
            return run_with_usual_semantic()
        else:
            raise NotImplementedError(f"todo {func}")

    def __repr__(self):
        return f"MaxProdTensor({self.elem})"

t1 = MaxProdTensor(torch.rand(2))
t2 = MaxProdTensor(torch.rand(2))

print("My Tensors")
print(t1)
print(t2)

print("t1 * t2")
print(t1 * t2)
print("t1 + t2")
print(t1 + t2)

def get_loss(inp):
    out = inp ** 2 + inp
    return out.sum()

def run(inp):
    print(f"Inp == {inp}")
    print("Regular autograd")
    print("Gradient is 2 * inp + 1")
    # The gradient flowing is on the regular real ring
    loss = get_loss(inp)
    grad_out = torch.ones_like(loss)
    print(autograd.grad(loss, inp, grad_out))

    # The gradient flowing is on the max prod ring
    loss = get_loss(inp)
    grad_out = MaxProdTensor(torch.ones_like(loss))
    print("Max prod grad is max(2 * inp, 1)")
    print(autograd.grad(loss, inp, grad_out))

inp = torch.ones(2, requires_grad=True)
run(inp)

inp = torch.ones(2, requires_grad=True) * 0.1
run(inp)

Which gives:

My Tensors
MaxProdTensor(tensor([0.2503, 0.2757]))
MaxProdTensor(tensor([0.9741, 0.5299]))
t1 * t2
MaxProdTensor(tensor([0.2438, 0.1461]))
t1 + t2
MaxProdTensor(tensor([0.9741, 0.5299]))

Inp == tensor([1., 1.], requires_grad=True)
Regular autograd
Gradient is 2 * inp + 1
(tensor([3., 3.]),)
Max prod grad is max(2 * inp, 1)
(MaxProdTensor(tensor([2., 2.])),)

Inp == tensor([0.1000, 0.1000], grad_fn=<MulBackward0>)
Regular autograd
Gradient is 2 * inp + 1
(tensor([1.2000, 1.2000]),)
Max prod grad is max(2 * inp, 1)
(MaxProdTensor(tensor([1., 1.])),)

As you can see, the custom backward computes the max-prod quantity instead of sum-prod :)


Hello Alban,

Thanks for the response and please excuse the late reply! We wanted to dig through your response deeply to reply thoughtfully, which took some more time than expected.

To answer your questions:

  • First, in this framework, we would be interested in computing hessians as well because it would be useful to be able to compute the gradient of values computed via backprop (e.g., entropy). However, we weren’t sure what you meant by “gradient statistics”, could you elaborate a bit here?
  • Second, we aren’t sure what you meant by invertible networks here/how they apply to parameterizing backprop. Would you be able to share some references/elaborate more on what you meant? Generally, we agree that two of the main parameters of backprop are the +, x; however, we’re not sure we’re following re: the path finding bit. In our understanding, you could find the shortest path in backprop with a modification of the (max, x) semiring that also tracks the parent node, but that involves substituting the value of the + rather than introducing a new parameter.

Thanks a lot for your tips on the implementation! This helps a great deal. We’ll first start implementing the MaxProd semiring following your example and try to reproduce some of the results, and after that think about the more complicated cases (e.g., entropy).

We have a few followup questions here:

  • We see you’re using autograd.grad() and pass it preallocated memory and set the type of grad_out to the MaxProdTensor. Is this to ensure that the operator overloading only affects the gradient graph computations (as opposed to the forward pass, etc)? We were thinking that another alternative to ensure that we’re only using MaxProdTensor on the gradient graph is to use hooks and replace the tensor. Unfortunately, none of us can really judge which implementation strategy offers the best trade-off in terms of “ease of implementation” vs. performance. Do you have any opinions on which alternative might be better? :) (Since we are planning on doing this for large language models like Llama 2, we consider performance to be a priority.)

  • We know that standard backprop is highly optimized on CUDA kernels; do you have a sense of how much of a performance hit we should expect if we implement max-prod backprop without any device-specific optimizations?

Thanks a lot in advance!

However, we weren’t sure what you meant by “gradient statistics”, could you elaborate a bit here?

The Backpack paper mentioned above considers the following quantities as being interesting.

Would you be able to share some references/elaborate more on what you meant?

This issue has a lot of discussion and links to relevant papers: [feature request] Core API for invertible/inplace and flow-like ops + memory-saving (hookless?) reversible sequential container for RevNets to allow for much larger batch-sizes in academic setting · Issue #23756 · pytorch/pytorch · GitHub

however, we’re not sure we’re following re: the path finding bit.

I’m not sure how to explain this so let me use a small example:

a = torch.rand(2, requires_grad=True)
b = a * 2
c = a * 3
loss = my_op(b, c)

In the usual definition of backprop, the branch for b and c each contribute some value and the value is “aggregated” using the + operation as: grad_a = grad_a_from_b + grad_a_from_c.
In the context of computing the inverse, I think what you want to do is new_a = new_a_from_b or new_a_from_c because you expect new_a_from_b == new_a_from_c. But it is a bit of an abuse of notation, so this is weird indeed. The issue linked above has more details on this idea.
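The aggregation in the usual backprop case can be checked directly. In the sketch below, `my_op` is a hypothetical stand-in (a plain sum of both inputs), chosen only so that the per-branch contributions are easy to read.

```python
import torch


def my_op(b, c):
    # Hypothetical stand-in for the op in the example above.
    return (b + c).sum()


a = torch.rand(2, requires_grad=True)
b = a * 2
c = a * 3
loss = my_op(b, c)

# Backward messages arriving at b and c.
grad_b, grad_c = torch.autograd.grad(loss, (b, c), retain_graph=True)

# Contribution of each branch to a, computed separately.
(grad_a_from_b,) = torch.autograd.grad(b, a, grad_b, retain_graph=True)
(grad_a_from_c,) = torch.autograd.grad(c, a, grad_c, retain_graph=True)

# The engine aggregates both branches with +.
(grad_a,) = torch.autograd.grad(loss, a)
print(grad_a_from_b, grad_a_from_c, grad_a)
```

With this choice of `my_op`, the branch through b contributes 2 per element and the branch through c contributes 3, so the aggregated gradient is 5 per element.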

We see you’re using autograd.grad()

Using this or .backward() doesn’t change anything.

Is this to ensure that the operator overloading only affects the gradient graph computations

Yes. Because all the other ops that are happening should be happening on regular Tensor and only the computation involving the “backward message” should be done in the special ring.

use hooks and replace the tensor.

The problem with hooks is that you cannot prevent the regular backward formula from running. So it won’t be good performance wise as you might be doing a lot of extra work.
Note that Backpack which is implemented with hooks does not have this issue because they explicitly expect the “real gradient” to be computed alongside their quantities. So they need the regular backward formula to run.
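A small sketch of that limitation, using a tensor hook (`Tensor.register_hook`): the hook can replace the backward message on an intermediate tensor, but the regular gradient has already been computed by the time it fires, and every node further along still applies its regular sum-product formula to whatever the hook returns. The `calls` list is only instrumentation for this example.

```python
import torch

calls = []

x = torch.tensor([1.0, 2.0], requires_grad=True)
h = x * 3  # regular backward formula: grad_x = grad_h * 3
loss = (h ** 2).sum()


def replace_message(grad):
    # The regular gradient has already been computed when the hook fires.
    calls.append(grad.clone())
    # Swap in a different backward message for the rest of the graph.
    return torch.ones_like(grad)


h.register_hook(replace_message)
loss.backward()

print(calls[0])  # regular gradient d loss / d h = 2 * h
print(x.grad)    # the hook's message, still multiplied by 3 by the regular formula
```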

how much of a performance hit we should expect if we implement max-prod backprop without any device-specific optimizations?

The implementation I shared above actually calls back into other torch ops (it just changes torch.add into torch.maximum), so it will use the optimized kernel for that op.
In general, you should try very hard to re-use existing ops (even if you do a bit of extra work) and define a custom function with manual device code only if there is no other choice.
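As an illustration of re-using existing ops at the cost of some extra work, a (max, ×) analogue of a matrix multiply can be written with broadcasting and `amax`. The helper name `max_prod_matmul` is hypothetical, and materializing the full (n, k, m) intermediate is the kind of extra memory cost one would accept before considering device code.

```python
import torch


def max_prod_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """(max, x) analogue of a @ b: out[i, j] = max_k a[i, k] * b[k, j]."""
    # Broadcast to shape (n, k, m), multiply elementwise,
    # then reduce over k with max instead of sum.
    products = a.unsqueeze(2) * b.unsqueeze(0)
    return products.amax(dim=1)


a = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
b = torch.tensor([[5.0, 6.0], [7.0, 8.0]])

out = max_prod_matmul(a, b)
print(out)

# Swapping amax for sum recovers the ordinary matmul.
regular = (a.unsqueeze(2) * b.unsqueeze(0)).sum(dim=1)
print(regular)
```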

I think some + can also be hidden in backprop formulas. For example, in a broadcast addition, z (with shape [3, 2]) = x (with shape [2]) + y (with shape [3, 2]), grad_x should be grad_z.max(dim=0), not grad_z.sum(dim=0), and this can’t be changed by accumulateGrad or accumulate.
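The hidden sum can be seen with regular autograd. In this sketch the incoming `grad_z` values are arbitrary numbers chosen for illustration; the max-based reduction at the end is computed by hand to show what a (max, ×) semiring would require, and is not something autograd does.

```python
import torch

x = torch.zeros(2, requires_grad=True)
y = torch.zeros(3, 2, requires_grad=True)
z = x + y  # x is broadcast from shape [2] to shape [3, 2]

grad_z = torch.tensor([[1.0, 6.0], [2.0, 5.0], [3.0, 4.0]])
(grad_x,) = torch.autograd.grad(z, x, grad_z)

sum_reduced = grad_z.sum(dim=0)
max_reduced = grad_z.max(dim=0).values  # .max(dim=...) returns (values, indices)

print(grad_x)       # regular autograd reduces the broadcast dim with a sum
print(sum_reduced)
print(max_reduced)  # what a (max, x) semiring would need instead
```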

This is pretty tricky, yes.
I think I am getting confused by the notation in

Where I think of the left most sum as the sum being done in the engine and the product being the matrix multiply being done in the backward of a matrix multiply for example.
I do think it is trickier than that: inside the matmul done in the backward, the elementwise products are the product part and the reduction in the matmul corresponds to the sum part.

In that case, indeed, the pointers I shared are the ones where cross-op re-use leads to summation, but there are cases where intra-op re-use (like a matmul) also leads to these summations.
So yeah, I do think the subclass based approach is the only way to override all the other + if you don’t want to rewrite every single formula in the autograd engine.

Yes, it’s much more complicated than the original backprop. If my understanding is right, the notation is confusing because the value of generalized backprop may not be invariant when the computation graph changes (or, put differently, for the original backprop the differential/sum/product parts on the RHS can be exchanged when the computation graph changes, depending on how you define an operation). For example, consider:

  1. y = 2x
  2. y = x + x (or “a = x, b = x, y = a + b” to make a diamond shape graph)

They are the same mathematically, but for a (max, ×)-gradient they give different answers:

  1. grad_x = 2 * grad_y
  2. grad_x = max(grad_a, grad_b) = grad_y

it’s kind of weird.
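The discrepancy can be reproduced in a few lines of plain Python. This is a sketch with hypothetical helper names (`sum_prod_grad`, `max_prod_grad`); each graph is written as a list of paths from y to x, each path being the list of local derivatives along it, and grad_y is taken to be 1.

```python
def path_product(local_derivatives):
    """Product of the local derivatives along one path."""
    product = 1.0
    for d in local_derivatives:
        product *= d
    return product


def sum_prod_grad(paths):
    """Ordinary backprop: sum over paths of the path products."""
    return sum(path_product(p) for p in paths)


def max_prod_grad(paths):
    """(max, x) backprop: max over paths of the path products."""
    return max(path_product(p) for p in paths)


# Graph 1: y = 2x has a single path with local derivative 2.
graph_1 = [[2.0]]
# Graph 2: a = x, b = x, y = a + b has two paths, each with local derivatives 1 and 1.
graph_2 = [[1.0, 1.0], [1.0, 1.0]]

print(sum_prod_grad(graph_1), sum_prod_grad(graph_2))  # both graphs agree
print(max_prod_grad(graph_1), max_prod_grad(graph_2))  # the graphs disagree
```

The sum-product value is the same for both graphs, while the max-product value depends on how the same function was decomposed into ops.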