Hey!
Thank you very much for your interest and reaching out.
I have a couple of high-level questions for you here (sorry if they are not directly linked to what you’re asking about):
- To what extent does this formulation extend to other quantities that can be computed alongside backprop (like gradient statistics or Hessian approximations, as done in [1912.10985] BackPACK: Packing more into backprop for example; there are other projects doing this, but that one is the first that comes to mind)?
- Similarly, another extension that I thought about (but never explored) for your idea is to include “invertible networks”. To get there, I think of backprop as being parametrized by 3 things: what a chain-rule op does (one step of vector-Jacobian product for backprop; computing the inverse of the forward for invertible networks), how you aggregate gradients from different paths (via sum for backprop; any single path’s value for invertible networks, since the next step should make sure only one such path is used anyway), and (optionally) a path-finding algorithm (find all paths for backprop; find the shortest (in flops?) path for invertible networks). I wonder now if this could be expressed as a ring? Or does the path selection break the analogy? I sketch this parametrization in toy code right after this list.
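For what it’s worth, here is a toy, PyTorch-free sketch of that parametrization (all the names here are made up for illustration, this is not any real API). It matches the numbers of the MaxProdTensor example further down for out = x**2 + x at x = 1:

from dataclasses import dataclass
from typing import Callable

@dataclass
class Ring:
    combine: Callable    # one chain-rule step (the vjp for backprop)
    aggregate: Callable  # how gradients from parallel paths are merged

sum_prod = Ring(combine=lambda a, b: a * b, aggregate=lambda a, b: a + b)
max_prod = Ring(combine=lambda a, b: a * b, aggregate=max)

def backprop_two_paths(ring, upstream, local_grad1, local_grad2):
    # Two paths from the output to the same input, e.g. out = x**2 + x
    return ring.aggregate(
        ring.combine(upstream, local_grad1),
        ring.combine(upstream, local_grad2),
    )

# out = x**2 + x at x = 1: the two paths contribute 2*x = 2 and 1
print(backprop_two_paths(sum_prod, 1.0, 2.0, 1.0))  # 3.0, the usual gradient
print(backprop_two_paths(max_prod, 1.0, 2.0, 1.0))  # 2.0, the max-prod version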
In terms of implementation, I think there are 3 main pieces: the ring space, the + and the x.
Let’s go from simpler to harder:
The + is indeed relatively simple to update. It happens here for the accumulations inside the graph and here for the ones “outside” the graph when working with the .grad field.
I’m sure we could make these configurable without too much trouble.
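To make that concrete, here is a rough, hypothetical illustration of what “making the + configurable” would mean at those two sites. This is not PyTorch’s actual code, just the concept:

import torch

# `accumulate` is the hypothetical knob; today it is hard-coded to a sum.
accumulate = torch.add  # could become torch.maximum, torch.logaddexp, ...

# 1) Inside the graph: a tensor used by several consumers receives one
#    gradient per consumer, and these are merged before its own backward runs.
grad_from_path1 = torch.tensor([2.0, 0.2])  # e.g. the contribution of x**2
grad_from_path2 = torch.tensor([1.0, 1.0])  # e.g. the contribution of x
grad_for_node = accumulate(grad_from_path1, grad_from_path2)

# 2) Outside the graph: accumulation into the .grad field of a leaf.
leaf = torch.ones(2, requires_grad=True)
leaf.grad = grad_for_node if leaf.grad is None else accumulate(leaf.grad, grad_for_node)

print(grad_for_node)  # tensor([3.0000, 1.2000]) with the default sum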
The x is conceptually simple to change but a lot of work. The way our backprop formulas are defined in derivatives.yaml, each one encodes one step of the backprop algorithm, meaning that the x operator is built in.
So the simple solution is to rewrite all the formulas to use another x operator and swap them in.
This is technically feasible by swapping out the formulas being used, but it requires re-writing all of the formulas, which is… a lot of work.
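To illustrate why the x is “built in”, here is a toy custom Function that mimics what a backward formula for mul encodes. This is not the actual derivatives.yaml machinery, just an analogy:

import torch

class MyMul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        # The `*` below is the x of the ring. Using another ring means
        # rewriting this line (and its equivalent in every other formula).
        return grad_out * b, grad_out * a

a = torch.rand(2, requires_grad=True)
b = torch.rand(2, requires_grad=True)
MyMul.apply(a, b).sum().backward()
print(a.grad, b.grad)  # a.grad equals b's values, b.grad equals a's values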
The ring space change (definitely needed for the entropy case, and in a simpler form for the other ones as well) is an interesting one.
We actually have a concept of “Tensor subclass” that you can implement to define whatever semantics you want. For example, a complex Tensor in that world can be implemented like this (of course, this is missing the actual implementation for most ops). We have some more detailed docs about subclasses here (they are not the best, but it is a tricky subject to document).
Using this, you can implement a Tensor subclass for which all computations behave according to the ring you want. Then you can backprop with that Tensor as the incoming gradient to change the “usual computations” into the ones you want.
Of course, you want to be careful that the new semantics are only used to compute the chain rule and not the Jacobian itself.
A minimal example of this is:
import torch
from torch import autograd
from torch.utils._pytree import tree_map


class MaxProdTensor(torch.Tensor):
    def __new__(cls, t):
        # We would need to define autograd on this ring, inception!
        assert t.requires_grad == False
        res = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            size=t.size(),
            strides=t.stride(),
            storage_offset=0,
            dtype=t.dtype,
            layout=t.layout,
            device=t.device,
            requires_grad=False,
        )
        return res

    def __init__(self, t):
        self.elem = t

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        def unwrap(t):
            if isinstance(t, cls):
                return t.elem
            else:
                return t

        def wrap(t):
            if isinstance(t, torch.Tensor) and not isinstance(t, cls):
                return cls(t)
            else:
                return t

        def run_with_usual_semantic():
            # Unpack as plain Tensor, run the normal op and repack
            return tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))

        # Ops whose meaning in this ring is the same as the usual one
        func_with_same_semantic = [
            "mul",
            "expand",
        ]

        if func.overloadpacket.__name__ in func_with_same_semantic:
            return run_with_usual_semantic()
        elif func is torch.ops.aten.add.Tensor:
            # This ring's add is the usual maximum
            func = torch.ops.aten.maximum
            return run_with_usual_semantic()
        else:
            raise NotImplementedError(f"todo {func}")

    def __repr__(self):
        return f"MaxProdTensor({self.elem})"


t1 = MaxProdTensor(torch.rand(2))
t2 = MaxProdTensor(torch.rand(2))
print("My Tensors")
print(t1)
print(t2)
print("t1 * t2")
print(t1 * t2)
print("t1 + t2")
print(t1 + t2)


def get_loss(inp):
    out = inp ** 2 + inp
    return out.sum()


def run(inp):
    print("")
    print(f"Inp == {inp}")
    print("Regular autograd")
    print("Gradient is 2 * inp + 1")
    # The gradient flowing back here is on the regular (sum, prod) ring
    loss = get_loss(inp)
    grad_out = torch.ones_like(loss)
    print(autograd.grad(loss, inp, grad_out))

    # The gradient flowing back here is on the (max, prod) ring
    loss = get_loss(inp)
    grad_out = MaxProdTensor(torch.ones_like(loss))
    print("Max prod grad is max(2 * inp, 1)")
    print(autograd.grad(loss, inp, grad_out))


inp = torch.ones(2, requires_grad=True)
run(inp)
inp = torch.ones(2, requires_grad=True) * 0.1
run(inp)
Which gives:
My Tensors
MaxProdTensor(tensor([0.2503, 0.2757]))
MaxProdTensor(tensor([0.9741, 0.5299]))
t1 * t2
MaxProdTensor(tensor([0.2438, 0.1461]))
t1 + t2
MaxProdTensor(tensor([0.9741, 0.5299]))
Inp == tensor([1., 1.], requires_grad=True)
Regular autograd
Gradient is 2 * inp + 1
(tensor([3., 3.]),)
Max prod grad is max(2 * inp, 1)
(MaxProdTensor(tensor([2., 2.])),)
Inp == tensor([0.1000, 0.1000], grad_fn=<MulBackward0>)
Regular autograd
Gradient is 2 * inp + 1
(tensor([1.2000, 1.2000]),)
Max prod grad is max(2 * inp, 1)
(MaxProdTensor(tensor([1., 1.])),)
As you can see, the custom backward computes the max-prod quantity instead of the sum-prod one.
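Note that the NotImplementedError branch above means that, on a bigger graph, you will discover the missing ops one at a time (the error tells you which aten op to handle); for a real network you would keep growing the func_with_same_semantic list and special-case the ops whose ring semantics differ, like add here.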