Weight sharing on cuda

Hi,
I wanted to understand how CUDA manages to keep weight sharing enabled across a module.to(). For example:

  import torch

  class MyModule(torch.nn.Module):
      def __init__(self):
          super(MyModule, self).__init__()
          self.a = torch.nn.Parameter(torch.ones([1]))
          self.b = torch.nn.Parameter(torch.ones([1]))

      def forward(self, input):
          c = self.a * input + self.b * input
          return c

  device = "cuda"  # or "hpu", "xla", ...
  mod = MyModule()
  mod.a = mod.b
  mod.to(device)

The weight sharing (mod.a = mod.b) is retained after mod.to() only when device is "cuda" above.

On backends like HPU, this doesn't work. Similarly, XLA documents this as a limitation in TPU training (Advanced) — PyTorch Lightning 1.7.0dev documentation:
"PyTorch XLA requires these weights to be tied/shared after moving the model to the XLA device."
It would be very helpful to know how this works on CUDA alone and what can be done to enable it similarly on other backends.


First, mod.a == mod.b is not a correct check for weight-sharing. The correct check is mod.a.data_ptr() == mod.b.data_ptr().
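For example, applied to the module from the original post (a quick check, assuming a CUDA build is available):

mod = MyModule()
mod.a = mod.b
mod.to("cuda")

# mod.a == mod.b is element-wise value equality and returns a bool tensor,
# so it says nothing about whether the storage is shared.
print(mod.a.data_ptr() == mod.b.data_ptr())  # True here means the sharing survived .to()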

Secondly, to answer your question, it’s described here:

If you actually run the following code first:

torch.__future__.set_overwrite_module_params_on_conversion(True)

then you will see that the weight sharing won't be preserved anymore.
This is not an intended behavior; it just happens as a side-effect of some old decisions that weren't really meant to have this effect.
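A minimal sketch of what that looks like, reusing the MyModule class from the original post (assumes a CUDA build):

import torch

torch.__future__.set_overwrite_module_params_on_conversion(True)

mod = MyModule()
mod.a = mod.b
mod.to("cuda")

# With the flag set, .to() creates fresh Parameters instead of swapping .data in place,
# so the tie is lost:
print(mod.a.data_ptr() == mod.b.data_ptr())  # False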

To give some more details: the weight sharing is preserved for CUDA because we used to have a concept called Variable that wrapped a Tensor. Tensor didn't have a concept of .grad or grad_fn; only Variable did.
This "boxing" of a Tensor by a box called Variable had the following side-effect:

from torch import randn
from torch.autograd import Variable  # legacy API, from before Variable and Tensor were merged

a = Variable(randn(10))
b = a  # b refers to the same Variable object (the "box") as a

a.data = a.data.to(device='cuda')  # swap the Tensor inside the box, in place
# a and b are still the same object. A member of a was changed, so it's reflected in b.

Once people started relying on this boxing, and kept using .data long after we merged Variable and Tensor into a single object, we had the unfortunate job of preserving this behavior to keep backward compatibility.
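In today's API, the same effect survives through the .data = assignment; a simplified sketch of the idea (not the actual Module._apply source):

import torch

p = torch.nn.Parameter(torch.ones(1))
q = p                       # "tied": both names reference the same Parameter object
p.data = p.data.to("cuda")  # the tensor inside is swapped in place; object identity is kept
print(q.device)             # the move is visible through q as well (cuda:0)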

Thank you for your response @smth. If I understand correctly, this is a legacy behavior where Variable updates the weight tensors in place by overwriting their contents from another tensor via a shallow copy. However, the comment in torch._has_compatible_shallow_copy_type seems to indicate that the future behavior would be to overwrite the existing tensor. Does this mean that, in the future, the weight sharing will eventually need to be done explicitly again after moving the model to the target device, even on CUDA?
The reason for my question is that there are existing user models and infrastructure that rely on the legacy Variable-based behavior, and customers expect this to work on other backends too. For a backend other than CUDA, what should the direction be: should we work on enabling the weight sharing similar to CUDA, or can we point customers to an RFC where this legacy Variable-based behavior will be removed for CUDA as well? Let me know if opening an RFC to discuss this makes sense.

Yes. Though we don’t have a timeline on when we’ll switch this over.

Once @albanD is back from his vacation, I'd like to hear his thoughts on whether we should make the boxing official behavior across backends (instead of its current "implicit" form).
I believe these have deep implications in the autodiff engine.

It would be good to know whether the official behavior would be changed for all backends. We face issues when existing models rely on the weight sharing being retained across a model.to() to the target device. The sharing is not retained on backends other than CUDA, and this happens silently without any error or warning, causing wrong results.

Hey!

Does this mean that, in the future, the weight sharing will eventually need to be done explicitly again after moving the model to the target device, even on CUDA?

This was the plan 3+ years ago when this future concept was added and we wanted to remove the use of .data = as much as possible.
Unfortunately, as you mentioned, many (if not most) models rely on this behavior, and so we never moved forward with the removal.

For the HPU case, if you're using a plain TensorImpl in C++ (not a subclass), then you could enable the same behavior as CUDA (and XPU, MPS and HIP, by the way). You just need to update the C++ function has_compatible_shallow_copy_type, which says which backends can be shallow-copied into each other (all the ones that use a plain TensorImpl).
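From the Python side, the same compatibility check can be probed through the private binding torch._has_compatible_shallow_copy_type, which Module._apply consults before deciding to keep the existing Parameter object; a small sketch (private API, may change between releases):

import torch

cpu_t = torch.ones(1)
cuda_t = torch.ones(1, device="cuda")

# True means the destination TensorImpl can be shallow-copied into,
# which is what allows .to() to preserve a tied Parameter object.
print(torch._has_compatible_shallow_copy_type(cpu_t, cuda_t))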

For a more general approach (one that would also include Tensor subclasses in Python), we are actually considering making nn.Parameter() a "true" subclass that can wrap any Tensor. This would re-create a pair of objects similar to the Variable/Tensor pair, so we would be able to swap any Tensor in there. But as you might imagine, such a major change to Parameter has a lot of side effects and is hard to do as well. cc @ezyang

Hope this helps clarify things!


Imma just drop this here: subclass_zoo/flat_view_tensor.py at main · albanD/subclass_zoo · GitHub

We tried to enable this on HPU @albanD, but unfortunately a direct shallow copy via the set_data approach doesn't work, as we have a TensorImpl subclass for the HPU lazy tensors.

Given that the general approach of modifying nn.Parameter() is a difficult change, will disabling the weight sharing for all backends then be taken up at some point?

If this is not viable, is there a way to enable this weight sharing for HPU? We were evaluating whether changing the TensorImpl in an existing TensorBase is a possible solution; let us know if you see any concerns with this.

I don't think there is any plan to do that, no. It would be "massively" BC-breaking with no benefit for a large set of our users.

If this is not viable, is there a way to enable this weight sharing for HPU? We were evaluating whether changing the TensorImpl in an existing TensorBase is a possible solution; let us know if you see any concerns with this.

Moving to a plain TensorImpl will for sure allow this .data = pattern to work.
If that is not possible, the code shared above by Ed is a Python-side way to do this by introducing the indirection at the Python level. Your users would then simply opt in to this new style of Parameter and the sharing would be preserved.

@albanD Is there a way to get weight tying to work with a Parameter subclass without having to perform the tying after nn.Module.to() or using torch.__future__.set_overwrite_module_params_on_conversion()?

import torch
from torch import nn

class MyParameter(nn.Parameter):
    ...

class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.param = nn.Parameter(torch.randn(1)).requires_grad_(True)
        self.tied_param: nn.Parameter
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.param * x + self.tied_param

my_module = MyModule()
opt = torch.optim.Adam(my_module.parameters())

device = "cuda:0"
my_module.tied_param = MyParameter(my_module.param)  # if you swap this line and the next, the weight tying works as expected
my_module = my_module.to(device)

opt.zero_grad()
my_module(torch.randn(10, device=device)).sum().backward()
opt.step()

print(my_module.param)
print(my_module.tied_param)

What makes this work with Parameter but not a Parameter subclass?

That’s because the subclass doesn’t pass this check? https://github.com/pytorch/pytorch/blob/150088a9cd638faa8b222e727166cd5c3932998d/torch/nn/modules/module.py#L812
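A quick way to see this in the example above, together with the ordering workaround already noted in its comment (a sketch reusing my_module and MyParameter from that snippet):

# With the original ordering (tie first, then .to()), the tie is silently lost:
print(my_module.param.data_ptr() == my_module.tied_param.data_ptr())  # False

# Re-tying after the move restores the sharing:
my_module.tied_param = MyParameter(my_module.param)
print(my_module.param.data_ptr() == my_module.tied_param.data_ptr())  # True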