Multi-GPU management extension

Hello, I am working on a small extension to allow quick scalability testing with multiple GPUs without making many changes to the original code. Specifically, I am facing an issue with the autograd backward call.

For data management, tensors are transferred between GPUs. The forward pass works properly, but during the backward pass there is a mismatch with the device that is “expected” by the autograd check. In reality, the module takes care that everything ends up on a single device; it is just that, for some tensors, that device may differ from the one autograd recorded.

To do this, I am creating a tensor subclass. When I dug through the autograd code, I came across this GitHub issue: Fix autograd engine checks · Issue #65016 · pytorch/pytorch (github.com)

Even though I am subclassing the tensor, for some reason it still compares the metadata device against the actual grad device:

if (grad.device() != metadata.device()) {
  // quick hack for: https://github.com/pytorch/pytorch/issues/65016 but
  // should be eventually removed
  if (!(metadata.is_tensor_subclass() ||
        grad.unsafeGetTensorImpl()->is_python_dispatch())) {
    if (grad.dim() == 0) {
      grad = grad.to(metadata.device());
    } else {
      std::stringstream ss;
      ss << "invalid gradient at index " << i << " - expected device ";
      ss << metadata.device() << " but got " << grad.device();
      AT_ERROR(format_error(ss.str()));
    }
  }
}

I am unsure how I can disable the device check.

Your help would be much appreciated.

Hi,

I’m afraid there is no way to disable this device test.
But if you have a subclass, you should be able to make the subclass “pretend” to be on the right device even if you store the data on a different device under the hood?
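Roughly something like this untested sketch (the class name and the _elem attribute are made up here; it relies on the _make_wrapper_subclass + __torch_dispatch__ pattern, where the metadata of the outer wrapper is what autograd sees while the real data can live on another device):

import torch
from torch.utils._pytree import tree_map

class PretendDeviceTensor(torch.Tensor):
    # Wrapper subclass: autograd only ever sees the wrapper's metadata
    # (including .device), while the actual data lives in self._elem.
    @staticmethod
    def __new__(cls, elem, public_device):
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), strides=elem.stride(), dtype=elem.dtype,
            device=public_device, requires_grad=elem.requires_grad,
        )
        r._elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Unwrap to the real tensors, run the op, and re-wrap the results.
        def unwrap(t):
            return t._elem if isinstance(t, PretendDeviceTensor) else t

        def wrap(t):
            # For simplicity this advertises the result's actual device; your
            # version would report whatever device you want autograd to see.
            return PretendDeviceTensor(t, t.device) if isinstance(t, torch.Tensor) else t

        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))
        return tree_map(wrap, out)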

Hi,

I tried doing that but failed. Here is a minimal example of the exact problem:

import torch
import torch.nn as nn

l1 = wrapped_module(nn.Linear(5, 5))
l2 = wrapped_module(nn.Linear(5, 1))
in1 = torch.rand(5).as_subclass(myclass)
in2 = torch.rand(5).cuda().as_subclass(myclass)
loss1 = l2(l1(in1))   # happens on cpu
loss2 = l2(l1(in2))   # l1, l2 transferred to gpu
loss = loss1 + loss2  # loss1 data transferred to gpu
loss.backward()       # error comes from validate_outputs in engine.cpp

Here, wrapped_module converts all the tensors in the module to myclass tensors. During the forward pass, l1 and l2 have their weights on cpu while working with in1. However, when they work with in2, the weights are shifted to a cuda device.

The forward pass works perfectly. For the backward pass, I ensure that any saved tensors are moved to the cuda device before the node is processed. However, when the graph is constructed, the node records that the device of the tensor is cpu. The grad is constructed correctly on the cuda device, but the node's check raises an error because it expects the grad to be on cpu.

In the condition check, I was not able to understand how exactly metadata.is_tensor_subclass() and is_python_dispatch() get set. I thought that if I could somehow flip those flags, I could escape the check.

P.S.: I thought a possible way might be using __torch_dispatch__, but I could not figure out how.

Thanks for the details!
How do you do “wrapped_module converts all the tensors in the module to myclass tensors” here exactly? What most likely happens is that you change the Tensors in place, and thus the Tensors on CPU that are needed just don’t exist anymore.

You could try torch.__future__.set_overwrite_module_params_on_conversion(True) if you use the Module.to() method to change the device.
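For reference, something like this (minimal sketch):

import torch
import torch.nn as nn

# With this flag set, Module.to() replaces the parameters with new tensors
# instead of modifying them in place, so the original CPU tensors stay intact.
torch.__future__.set_overwrite_module_params_on_conversion(True)

m = nn.Linear(5, 5)
m.to("cuda")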

Thanks! What you suggested worked on a conceptual level. First, to answer your question: wrapped_module changes the __class__ attribute of the parameter tensors. It's not a good way to do it, but every other trick I tried either caused errors in the functional dispatch (e.g. in F.linear) or did not change the class at all.
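In code, it is roughly the following (a simplified sketch of my current hack):

def wrapped_module(module):
    # Crude hack: swap the class of every parameter in place so that the
    # subclass machinery kicks in without creating new tensors.
    for p in module.parameters():
        p.__class__ = myclass
    return module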

With the option you mentioned turned on, it was not possible to directly change the class. There are two things that I am not sure about yet:

  1. Is there any way to use it for tensors rather than module parameters?
  2. I am changing the tensor device in place, but I keep a record of the device each tensor came from for the particular autograd node (a Python-side implementation that stores some metadata, similar to what autograd does). You mentioned that the Tensors on CPU are required, and I am not sure why. When the option is turned on and we change the module device, does it keep a reference to the CPU tensors and do a back-and-forth for gradient accumulation?

Some more detail for the 2nd point: I checked the autograd graph and it shows a <CopyBackward0> node, which makes sense if it is keeping a copy until backward is called. Let me add that I experimented further with in-place tensor shifting, and in cases where I can properly circumvent the device check I mentioned, there was no need to keep the original tensor. The only issue comes when the parents of a node are on different devices (unless it is doing a ref_inc that I missed). In the example I provided, this happens at the summation of loss1 and loss2. Here the node <AddBackward0> has two parents: one on gpu and the other on cpu. In theory I can use node hooks and prehooks to correct the device if required; however, the output check (for <AddBackward0>) happens before the hook is executed.
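For illustration, this is roughly what I had in mind (hypothetical sketch; expected_devices is my own bookkeeping, and it does not help because validate_outputs runs before the hook):

# Hypothetical record of the device each branch of <AddBackward0> expects.
expected_devices = ["cuda:0", "cpu"]

def fix_grad_device(grad_inputs, grad_outputs):
    # Move each outgoing grad to the device recorded for that branch.
    return tuple(
        g.to(d) if g is not None else None
        for g, d in zip(grad_inputs, expected_devices)
    )

loss.grad_fn.register_hook(fix_grad_device)  # node hook on <AddBackward0>
loss.backward()  # still errors: the device check fires before the hook runs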

Please do let me know if there is a better way to wrap modules.

P.S.: If you feel that this is a minor issue and can be addressed, a simple fix would be to execute node hooks before doing the device check (or any check, for that matter).