Adapting Models to use TorchScript and Getting them to Produce Fusions

I have been going through the exercise of taking a widely used model and converting it to TorchScript. In particular, I have focused on converting the entire model, both to enable fusions across software blocks and to make it easier for a user to simply wrap the model in torch.jit.script() instead of having to limit themselves to specific parts.

One common idiom I have discovered in the HuggingFace Bert model that has caused me trouble is the use of Optional parameters in a Module’s forward method. These parameters describe the configuration of the Module via a runtime argument even though their usage is constant for a given Module instance. Here is an example where the parameter mask is either a Tensor or None, conveying whether a mask is applied to the input Tensor.

class ExampleModule(torch.nn.Module):
    def __init__(self, hidden_size, scale_factor):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)

    def forward(
        self,
        inputs,
        mask=None
    ):
        tmp = inputs / self.scale
        if mask is not None :
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)

        return outputs

The expectation is that this Module should fuse. Out of the box it does not, because the forward method’s mask parameter has a default value of None. This confuses TorchScript, since mask could be of type torch.Tensor or of type None, and it produces the following error.

RuntimeError: 
Expected a default value of type Tensor (inferred) on parameter "mask".Because "mask" was not annotated with an explicit type it is assumed to be type 'Tensor'.:

This requires the first modification to the module, which is to add type hints, given that not all parameters are of type torch.Tensor.

class ExampleModule(torch.nn.Module):
    def __init__(self, hidden_size, scale_factor):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)

    def forward(
        self,
        inputs : torch.Tensor,
        mask : Optional[torch.Tensor]=None
    ) -> torch.Tensor :
        tmp = inputs / self.scale
        if mask is not None :
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)

        return outputs

While TorchScript now accepts the Module without error, it does not fuse the operators in the forward method. One way to tell that no fusion was made is to print the IR graph: it lacks a Fusion Group, and you see a prim::If node that blocks fusion instead.
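For reference, one way to inspect the graph is to script the module, run it a couple of times so the profiling executor specializes it, and then print the graph it actually executes. The following is a minimal sketch along the lines of the full working example at the end of this post (the sizes and the fuser selection are the same assumptions used there):

model = ExampleModule(1024, 16).to('cuda')
jit_model = torch.jit.script(model)

inputs = torch.randn([128, 16, 64, 64], device='cuda', requires_grad=True)
mask = torch.randn([128, 1, 1, 64], device='cuda', requires_grad=False)

# Warm up so the profiling executor specializes the graph, then print it and
# look for a FusionGroup / CudaFusionGroup node (absent here; a prim::If shows up instead).
with torch.jit.fuser("fuser2"):
    for _ in range(2):
        out = jit_model(inputs, mask)
    print(jit_model.forward.graph_for(inputs, mask))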

graph(%self : __torch__.ExampleModule,
      %inputs.1 : Tensor,
      %mask.1 : Tensor?):
  %5 : None = prim::Constant() # example.py:15:23
  %4 : int = prim::Constant[value=1]()
  %3 : int = prim::Constant[value=-1]() # example.py:17:55
  %6 : int = prim::GetAttr[name="scale"](%self)
  %tmp.2 : Tensor = aten::div(%inputs.1, %6) # example.py:14:14
  %9 : bool = aten::__isnot__(%mask.1, %5) # example.py:15:11
  %tmp : Tensor = prim::If(%9) # example.py:15:8
    block0():
      %mask.4 : Tensor = prim::unchecked_cast(%mask.1)
      %tmp.3 : Tensor = aten::add_(%tmp.2, %mask.4, %4) # example.py:16:12
      -> (%tmp.3)
    block1():
      -> (%tmp.2)
  %ret.2 : Tensor = aten::softmax(%tmp, %3, %5) # 
  return (%ret.2)

The second code modification to enable fusion is to replace the None check on the runtime mask parameter with a class attribute that is marked as constant, as suggested in the documentation.

class ExampleModule(torch.nn.Module):
    use_mask : Final[bool]
    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs : torch.Tensor,
        mask : Optional[torch.Tensor]=None
    ):
        tmp = inputs / self.scale
        if self.use_mask :
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)

        return outputs

This modification fails to fuse as well, and it trips an error in TorchScript because the mask parameter is still seen as type Optional[torch.Tensor] instead of torch.Tensor.

RuntimeError: 
Arguments for call are not valid.
The following variants are available:
  
  aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor):
  Expected a value of type 'Tensor' for argument 'other' but instead found type 'Optional[Tensor]'.
  
  aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor):
  Expected a value of type 'number' for argument 'other' but instead found type 'Optional[Tensor]'.
  
  aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> (Tensor(a!)):
  Expected a value of type 'Tensor' for argument 'other' but instead found type 'Optional[Tensor]'.
  
  aten::add.t(t[] a, t[] b) -> (t[]):
  Could not match type Tensor to List[t] in argument 'a': Cannot match List[t] to Tensor.
  
  aten::add.str(str a, str b) -> (str):
  Expected a value of type 'str' for argument 'a' but instead found type 'Tensor'.
  
  aten::add.int(int a, int b) -> (int):
  Expected a value of type 'int' for argument 'b' but instead found type 'Optional[Tensor]'.
  
  aten::add.float(float a, float b) -> (float):
  Expected a value of type 'float' for argument 'b' but instead found type 'Optional[Tensor]'.
  
  aten::add.int_float(int a, float b) -> (float):
  Expected a value of type 'float' for argument 'b' but instead found type 'Optional[Tensor]'.
  
  aten::add.float_int(float a, int b) -> (float):
  Expected a value of type 'int' for argument 'b' but instead found type 'Optional[Tensor]'.
  
  aten::add(Scalar a, Scalar b) -> (Scalar):
  Expected a value of type 'number' for argument 'b' but instead found type 'Optional[Tensor]'.
  
  add(float a, Tensor b) -> (Tensor):
  Expected a value of type 'Tensor' for argument 'b' but instead found type 'Optional[Tensor]'.
  
  add(int a, Tensor b) -> (Tensor):
  Expected a value of type 'Tensor' for argument 'b' but instead found type 'Optional[Tensor]'.

The original call is:
  File "example.py", line 18
        tmp = inputs / self.scale
        if self.use_mask :
            tmp = tmp + mask
                  ~~~~~~~~~~ <--- HERE
        outputs = torch.nn.functional.softmax(tmp, dim=-1)

One option to resolve this problem is to assert that mask is not None, since the assert strips away the Optional and refines the type to Tensor. However, the result is identical to the conditional on None: the fusion is still blocked. Another option is to make the parameter default a zero-dimension Tensor. This works, but is it the best solution?
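For reference, the assert variant would look something like the sketch below (my reconstruction of the approach described above, as a drop-in replacement for the forward of the module in question; it scripts cleanly, but the None check still ends up in the graph and blocks fusion):

    def forward(
        self,
        inputs : torch.Tensor,
        mask : Optional[torch.Tensor]=None
    ) -> torch.Tensor :
        tmp = inputs / self.scale
        if self.use_mask :
            # The assert refines Optional[Tensor] to Tensor for TorchScript,
            # but the graph still contains the None check, so fusion is blocked.
            assert mask is not None
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)

        return outputs

The zero-dimension Tensor default, in contrast, looks like this: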

class ExampleModule(torch.nn.Module):
    use_mask : Final[bool]
    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs : torch.Tensor,
        mask : torch.Tensor=torch.Tensor()
    ):
        tmp = inputs / self.scale
        if self.use_mask :
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)

        return outputs

If I have a parameter that is only conditionally used, it seems more natural to leave it as an Optional-typed argument, but there appears to be no way to do this that allows a fusion to occur. One idea was to cast the Optional to a Tensor when the conditional is True, but that does not work either. Perhaps having TorchScript support cast would be better? Are there better suggestions for handling Optional parameters such that they produce a fusion?

class ExampleModule(torch.nn.Module):
    use_mask : Final[bool]
    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs : torch.Tensor,
        mask : Optional[torch.Tensor]=None
    ):
        tmp = inputs / self.scale
        if self.use_mask :
            tmp = tmp + cast(torch.Tensor, mask)
        outputs = torch.nn.functional.softmax(tmp, dim=-1)

        return outputs

Resulting error from using cast:

RuntimeError: 
builtin cannot be used as a value:
  File "example.py", line 18
        tmp = inputs / self.scale
        if self.use_mask :
            tmp = tmp + cast(torch.Tensor, mask)
                             ~~~~~~~~~~~~ <--- HERE
        outputs = torch.nn.functional.softmax(tmp, dim=-1)

Full working code example that produces a fusion:

import torch
from typing import Optional, Final, cast

class ExampleModule(torch.nn.Module):
    use_mask : Final[bool]
    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs : torch.Tensor,
        mask : torch.Tensor=torch.Tensor()
    ):
        tmp = inputs / self.scale
        if self.use_mask :
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)

        return outputs

model = ExampleModule(1024, 16, True).to('cuda')
jit_model = torch.jit.script(model)

inputs = torch.randn([128, 16, 64, 64], device='cuda', requires_grad=True)
mask = torch.randn([128, 1, 1, 64], device='cuda', requires_grad=False)

with torch.jit.fuser("fuser2"):
    for idx in range(2) :
        if idx == 1 :
            print(jit_model.forward.graph_for(inputs, mask))
        out = jit_model.forward(inputs, mask)

Hi Kevin,

great observation!

One thing to note is that we used to fuse this idiom with the old executor because we did specialize graphs on the presence of optional arguments. For all the goodness it brings, this is one of the things that were lost when we moved to the new profiling executor.
The old code did so in the argument_spec.

As this would seem to be a common case, also for functions scripted with PyTorch, I wonder if it might make sense to continue specializing functions on the presence of optional parameters.
@eellison ?

Best regards

Thomas


Great post!

For the second example module - if self.use_mask - I think there are a few possible solutions here. We have a flag on master with our Freezing optimization so that buffers, parameters, the train variable, and any attributes that get written to are all preserved as attributes, and everything else gets inlined. This is basically an inversion of __constants__ - attributes which are mutated are inferred to be preserved, and we could add a __preserved__ attribute to the module for user-specified attributes. We would probably have to be a little more careful about escape analysis / constant propagation to turn this on by default. There haven’t been any tests yet, but if we could get some data that this speeds up training we could work on promoting it more & put more work into it.

This pretty much works on master with the not-yet-public preserveParameters flag (should maybe be named preserveForTraining) in torch._C._freeze_module; it just requires the following small change-set:

+++ b/torch/csrc/jit/passes/freeze_module.cpp
@@ -438,6 +438,9 @@ class AttributePropagator {
             if (iter2 != iter->second.end())
               paramConst = iter2->second;
           }
+          if (name == "training") {
+            continue;
+          }
           if (!paramConst) {
             auto attr = attrModule.attr(name);
             if (!isEval || preserveParameters_) {

If you run it with batchnorm, you’ll see that all the necessary fields (running_mean, etc.) are preserved, and when run with the if self.use_mask example module you see the correct fields get inlined.

mod = torch.jit.script(ExampleModule(2, 3, True))
out = torch.jit._script.RecursiveScriptModule(torch._C._freeze_module(mod._c, preserveParameters=True))
print(out.graph)
graph(%self : __torch__.___torch_mangle_0.ExampleModule,
      %inputs.1 : Tensor,
      %mask.1 : Tensor):
  %7 : int = prim::Constant[value=1]()
  %6 : int = prim::Constant[value=-1]() # test/example.py:18:55
  %4 : None = prim::Constant()
  %self.scale : int = prim::Constant[value=0]()
  %tmp.1 : Tensor = aten::div(%inputs.1, %self.scale) # test/example.py:15:14
  %tmp.3 : Tensor = aten::add(%tmp.1, %mask.1, %7) # test/example.py:17:18
  %ret.1 : Tensor = aten::softmax(%tmp.3, %6, %4) # /home/eellison/.conda/envs/work/lib/python3.7/site-packages/torch/nn/functional.py:1583:14
  return (%ret.1)

The other solution here would be to profile IValues that don’t change and optimize the graph with bailouts. That sort of dovetails with the other module, with the optional input… One, about the annotation, there’s an issue out that would make it unnecessary; somehow it slipped through triaging. I just added it to be re-triaged, hopefully it will be fixed soon. Two - I’m not sure how much of an issue this is in practice, since typically modules are part of a larger graph where the is None check gets optimized out, but it is a good thing to think about, especially in relation to things like ScriptTorch / partial model scripting.

But, yes, it would be great to profile and optimize the graph more, and I think one way to do that is with Bailouts. This was how the profiling executor was originally designed, but it was simplified in an effort to make shipping it easier. We could also consider profiling the graph and then specializing on the inputs & adding a cache similar to our existing legacy executor. I do think that once there is more evidence that non-fusion optimizations (including graph cleanup), but also memory reuse, pre-computed dispatch, etc., pay off, it would make sense to work more along these lines. Personally I think this is probably a ways off, and it requires a rethinking & reimplementation of how we go about statically analyzing Tensor Types. The existing shape analysis pass isn’t tested, only works for one particular use case, bundles shape and device/dtype inference in a way that’s inseparable, and was a major pain to maintain. I think for us to get to Bailouts & graph optimization we probably need to redesign Shape & Dtype/Device inference. I have a concrete plan for a new Shape analysis that I’ve hinted at to a few people and will be RFC’ing somewhat soon; it should go part of the way toward clearing up shape control flow & enabling more future optimization possibilities.

Another thing is, there may be simpler specializations we could enable along the way to more full-fledged ones - such as specializing on the optionality of inputs as Tom suggested, or just on the sizes but not dtypes/devices (the stuff that typically gets in the way of optimization). Finding benchmarks where the legacy executor is faster than the current executor is a good way to motivate these changes.

Curious to hear people’s thoughts, especially about the __constants__ / __preserved__ inversion.

Hey Elias thanks for the response!

I am not following what the intention of __preserved__ is. In the case of __constants__, it makes sense to me to mark something constant to let the compiler know it is safe to optimize away. However, if something can mutate, that sounds like it is not safe to optimize away?

Also, were you suggesting that IValue profiling could allow a fusion to be created in the case where the conditional asks whether the parameter is None?

class ExampleModule(torch.nn.Module):
    def __init__(self, hidden_size, scale_factor):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)

    def forward(
        self,
        inputs : torch.Tensor,
        mask : Optional[torch.Tensor]=None
    ) -> torch.Tensor :
        tmp = inputs / self.scale
        if mask is not None :
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)

        return outputs

If I were to grade code, I would at minimum prefer to see Module structure communicated via class attributes when a conditional does not vary at runtime, as it is a “better” way of describing the intention of the code, and it would therefore be good to see that cleanly supported. However, what is a good way, in Python, of doing this when a parameter may not be needed and I don’t want to transfer a Tensor? Is this a good way?

class ExampleModule(torch.nn.Module):
    use_mask : Final[bool]
    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs : torch.Tensor,
        mask : Optional[torch.Tensor]=None
    ):
        tmp = inputs / self.scale
        if self.use_mask :
            tmp = tmp + cast(torch.Tensor, mask)
        outputs = torch.nn.functional.softmax(tmp, dim=-1)

        return outputs

Here are some other issues I have run into.

The first involves view. The code writers used the following idiom a few times. The resulting error is that view was unable to infer the size of the list. That makes sense, given that the list comes from a slice of the Tensor’s size, whose length can’t be known until runtime.

class ExampleModule(torch.nn.Module):
    def __init__(self, hidden_size, scale_factor):
        super().__init__()
        self.hidden_size = hidden_size
        self.scale_factor = scale_factor

    def forward(
        self,
        inputs : torch.Tensor,
    ):
        shape = inputs.size()[:-1] + [self.scale_factor, int(self.hidden_size / self.scale_factor)]
        outputs = inputs.view(*shape)
        outputs = outputs.permute(0, 2 , 1 , 3)

        return outputs

The resulting error:

RuntimeError: 
cannot statically infer the expected size of a list in this context:
  File "example_view.py", line 15
    ):
        shape = inputs.size()[:-1] + [self.scale_factor, int(self.hidden_size / self.scale_factor)]
        outputs = inputs.view(*shape)
                              ~~~~~~ <--- HERE
        outputs = outputs.permute(0, 2 , 1 , 3)

One possible fix was to define a fixed-size list. I am not sure why this is not inferable. Is this perhaps a bug?

class ExampleModule(torch.nn.Module):
    def __init__(self, hidden_size, scale_factor):
        super().__init__()
        self.hidden_size = hidden_size
        self.scale_factor = scale_factor

    def forward(
        self,
        inputs : torch.Tensor,
    ):
        shape = [inputs.size(0), inputs.size(1), self.scale_factor, int(self.hidden_size / self.scale_factor)]
        outputs = inputs.view(*shape)
        outputs = outputs.permute(0, 2 , 1 , 3)

        return outputs

To compile this particular case, you can actually just pass the list to view (inputs.view(shape)); incidentally, this is just what the resulting graph does anyway. In fact, when I last looked (a looong time ago), translating to a list argument in C++ was the only way this “vararg in python” pattern was supported by the codegen.
Now, we might like to support that pattern better.
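In other words, the first version of the view example should compile with a change along these lines (a sketch of the suggestion as a replacement for that forward method; note that under torch.jit.script, inputs.size() is treated as a List[int], which is why concatenating it with a list literal and handing the result straight to view is accepted):

    def forward(
        self,
        inputs : torch.Tensor,
    ):
        shape = inputs.size()[:-1] + [self.scale_factor, int(self.hidden_size / self.scale_factor)]
        # Pass the list itself to view instead of unpacking it with *shape.
        outputs = inputs.view(shape)
        outputs = outputs.permute(0, 2, 1, 3)

        return outputs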

So, freezing is the process of inlining attributes which are analyzed to not be mutated in the forward. Because we do the analysis, it works with modules like batchnorm which update attributes during training. We haven’t really gotten to it, but it’s a TODO to promote this API for training (right now it is just used for inference). The only difference for training is that we would preserve parameters, buffers, and the “training” flag.

__preserved__ would basically be enabling freezing for training by default - we’re going to inline non-parameters and non-buffers as constants unless they get mutated, or you explicitly tell us not to with __preserved__. Maybe that’s a step too far, and we should just promote the freezing-for-training API more. In your use_mask : Final[bool] example, the user wouldn’t have to add the Final annotation (which is easy to forget) with the freezing-for-training API.

We can open an issue about the inputs.view(*shape) error. While that case may be difficult to support, it might not be too hard to add a better error message & suggest Tom’s alternative.

Thanks for the clarification, Elias! It sounds much easier for the user to more or less assume every conditional is constant for the run, versus having to explicitly mark things, which is extra work to enable fusion out-of-the-box. I think having advanced features for advanced users, like the ability to mark something as mutable with __preserved__, is good if there is a benefit to giving you that information up front.

I had interpreted the view documentation as suggesting that it didn’t take a list directly; therefore, I didn’t try it. Thanks for pointing that out, Tom! I had worked around it by providing the variable arguments to view directly.
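For completeness, that workaround would look something like this (my reconstruction of what is meant by providing the variable arguments directly, again as a replacement for the forward of the view example):

    def forward(
        self,
        inputs : torch.Tensor,
    ):
        # Hand each dimension to view() as a separate argument rather than building a list.
        outputs = inputs.view(inputs.size(0), inputs.size(1),
                              self.scale_factor, int(self.hidden_size / self.scale_factor))
        outputs = outputs.permute(0, 2, 1, 3)

        return outputs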


Another obstacle in getting a Bert model through TorchScript has been the handling of a ModuleList that is used to define a stack of layers for the Bert Model. For instance, the Bert Large model contains 24 identical layers.

One issue is that when the ModuleList is partitioned into DifferentiableGraphs, TorchScript makes a unique DifferentiableGraph for each instance. Is this a bug, or is there a way to reuse DifferentiableGraphs? Depending on one’s view of the responsibility of the backend fuser, this can trigger a unique fusion for each unique DifferentiableGraph. When TorchScript is wrapped around only specific idioms rather than the whole ModuleList, better reuse happens.

Here is an example. I cut down the Multihead Attention Module into an Example Module.

import math
import torch
import torch.nn as nn

class ExampleSubmodule(nn.Module):
    def __init__(self, hidden_size, head_size, dropout_prob):
        super().__init__()
        self.input_linear = nn.Linear(hidden_size, hidden_size)
        self.dropout1 = nn.Dropout(dropout_prob)
        self.dropout2 = nn.Dropout(dropout_prob)

        self.heads = int(hidden_size / head_size)
        self.head_size = head_size
        self.hidden_size = hidden_size

    def forward(self, inputs):
        output1 = self.input_linear(inputs)

        output1 = output1.view(output1.size(0), output1.size(1), self.heads, self.head_size)
        output1 = output1.permute(0, 2, 1, 3)

        # Fusion1 - Div+Dropout
        output2 = output1 / math.sqrt(self.heads)
        output3 = self.dropout1(output2)

        output3 = output3.permute(0, 2, 1, 3).contiguous()
        output3 = output3.view(output3.size(0), output3.size(1), self.hidden_size)

        # Fusion2 - Dropout+Add
        output4 = self.dropout2(output3)
        output5 = output4 + inputs

        return output5

class ExampleModule(nn.Module):
    def __init__(self, hidden_size, head_size, dropout_prob, layers):
        super().__init__()
        self.layers = nn.ModuleList([ExampleSubmodule(hidden_size, head_size, dropout_prob) for x in range(0,layers)])

    def forward(self, inputs):
        for layer in self.layers :
            inputs = layer(inputs)
        return inputs
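
A driver for the two-layer case might look something like the following (my reconstruction; the hidden size, head size, dropout probability, and input shape are guesses inferred from the tensor types and the 1/(1-p) dropout constant in the graphs below):

model = ExampleModule(hidden_size=1024, head_size=16, dropout_prob=0.1, layers=2).to('cuda')
jit_model = torch.jit.script(model)

inputs = torch.randn([128, 128, 1024], device='cuda', requires_grad=True)

with torch.jit.fuser("fuser2"):
    # Warm up so the profiling executor specializes and partitions the graph
    # (the inputs require grad, so DifferentiableGraph subgraphs get created),
    # then print the optimized graph together with its subgraphs.
    for _ in range(2):
        out = jit_model(inputs)
    print(jit_model.forward.graph_for(inputs))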

When I execute 2 layers of the submodule above, I get 4 DifferentiableGraphs, and two of those graphs are identical. Here is what the graphs look like:

Diff Group 0 and Diff Group 2 look identical.

with prim::DifferentiableGraph_0 = graph(%24 : bool,
      %29 : float,
      %38 : Tensor):
  %5 : int = prim::Constant[value=0]()
  %121 : float = prim::Constant[value=1.1111111111111112]()
  %34 : int[] = prim::Constant[value=[0, 2, 1, 3]]()
  %output1.9 : Tensor = aten::permute(%38, %34) # example_module_list.py:30:18
  %130 : bool = prim::CudaFusionGuard[types=[Float(128, 64, 128, 16, strides=[131072, 16, 1024, 1], requires_grad=0, device=cuda:0)]](%output1.9)
  %138 : bool = prim::Constant[value=1]()
  %139 : bool = prim::Constant[value=1]()
  %140 : bool = aten::__xor__(%24, %139)
  %141 : bool = aten::__xor__(%140, %138)
  %142 : bool = aten::__and__(%141, %130)
  %128 : Tensor, %129 : Tensor = prim::If(%142)
    block0():
      %res.2 : Tensor, %mask.4 : Tensor = prim::CudaFusionGroup_0[cache_id=1](%output1.9, %29)
      -> (%res.2, %mask.4)
    block1():
      %143 : Function = prim::Constant[name="fallback_function", fallback=1]()
      %144 : (Float(128, 64, 128, 16, strides=[131072, 16, 1024, 1], requires_grad=0, device=cuda:0), Bool(128, 64, 128, 16, strides=[131072, 16, 1024, 1], requires_grad=0, device=cuda:0)) = prim::CallFunction(%143, %output1.9, %24, %29)
      %145 : Float(128, 64, 128, 16, strides=[131072, 16, 1024, 1], requires_grad=0, device=cuda:0), %146 : Bool(128, 64, 128, 16, strides=[131072, 16, 1024, 1], requires_grad=0, device=cuda:0) = prim::TupleUnpack(%144)
      -> (%145, %146)
  %60 : Tensor = aten::permute(%128, %34) # <string>:226:19
  %51 : Tensor = aten::contiguous(%60, %5) # <string>:29:19
  return (%51, %51, %51, %121, %129)
with prim::DifferentiableGraph_2 = graph(%24 : bool,
      %29 : float,
      %38 : Tensor):
  %5 : int = prim::Constant[value=0]()
  %121 : float = prim::Constant[value=1.1111111111111112]()
  %34 : int[] = prim::Constant[value=[0, 2, 1, 3]]()
  %output1.9 : Tensor = aten::permute(%38, %34) # example_module_list.py:30:18
  %130 : bool = prim::CudaFusionGuard[types=[Float(128, 64, 128, 16, strides=[131072, 16, 1024, 1], requires_grad=0, device=cuda:0)]](%output1.9)
  %138 : bool = prim::Constant[value=1]()
  %139 : bool = prim::Constant[value=1]()
  %140 : bool = aten::__xor__(%24, %139)
  %141 : bool = aten::__xor__(%140, %138)
  %142 : bool = aten::__and__(%141, %130)
  %128 : Tensor, %129 : Tensor = prim::If(%142)
    block0():
      %res.2 : Tensor, %mask.4 : Tensor = prim::CudaFusionGroup_0[cache_id=1](%output1.9, %29)
      -> (%res.2, %mask.4)
    block1():
      %143 : Function = prim::Constant[name="fallback_function", fallback=1]()
      %144 : (Float(128, 64, 128, 16, strides=[131072, 16, 1024, 1], requires_grad=0, device=cuda:0), Bool(128, 64, 128, 16, strides=[131072, 16, 1024, 1], requires_grad=0, device=cuda:0)) = prim::CallFunction(%143, %output1.9, %24, %29)
      %145 : Float(128, 64, 128, 16, strides=[131072, 16, 1024, 1], requires_grad=0, device=cuda:0), %146 : Bool(128, 64, 128, 16, strides=[131072, 16, 1024, 1], requires_grad=0, device=cuda:0) = prim::TupleUnpack(%144)
      -> (%145, %146)
  %60 : Tensor = aten::permute(%128, %34) # <string>:226:19
  %51 : Tensor = aten::contiguous(%60, %5) # <string>:29:19
  return (%51, %51, %51, %121, %129)

Is there a reason that both are needed?

As a follow-up to the example in the ModuleList post, is there a reason why view operators create a partition in the DifferentiableGraph?

The reason is that views are special in autograd in that the view’s autograd node is linked to the base Tensor’s autograd node. This isn’t currently handled by the autograd node for differentiable graphs; the TODO is here:

That said, last I looked, I was of the opinion that we could probably use the x.view_as(x) trick to get around that - view_as returns something which doesn’t share the backward node. This is done by autograd for custom functions returning inputs that are views:

I had tried my hand at this and I think I had it figured out, but I managed to shoot myself in the foot while testing it by producing illegal backwards. I was confused enough by my testing problems that I couldn’t figure out whether this would be a solution. Maybe @albanD knows if that could work.

Best regards

Thomas


I think the main reasons are:

  • view + inplace is not supported by the autodiff, so these ops are excluded.
  • views alone are not supported either, as you show in your comment. In particular, it would require marking the outputs as proper autograd views, which the autodiff doesn’t know how to do.

That said, last I looked, I was of the opinion that we could probably use the x.view_as(x) -trick to get around that

Not sure what you mean by that?
I think the main issue here is that:

  • it is not trivial to get all the potential aliasing
  • it is not trivial to handle taking views of tensors created outside (version counter sharing + view metadata)
  • it is not trivial to return the right thing when you return one or more Tensors that are views of the same base. Also, this base needs to exist, which can be problematic if it was optimized away
  • combinations with inplace ops outside of the autodiff block would make the whole backward fail, as rewriting the graph for such a “macro Node” is impossible.

I am probably missing something but my understanding was that we have several scenarios here:
Outputs can be

  1. plain (non-views) created in the graph
  2. views of something created in the graph
  3. views of (view or non-view) inputs
  4. non-view or view inputs that are returned as is (?)

So it seems 1. is the unproblematic case. But given that we do our own differentiation and won’t be using autograd to get the backward, I thought that 2. probably could be treated just as 1.
My understanding was that 3. is the difficult case, and that identifying 3. is difficult. But I also thought that user-defined functions had a similar problem, and that x.view_as(x) made them “have separate backward” in the code line I linked above.
I’m not sure if 4. happens and if it would be problematic/required some treatment, too.

But, as I said, I’m probably missing something.

The thing is that the custom Functions are properly tracking views (because views are still tracked during the forward) and properly forbidding inplace on the result if it is any kind of view (via CreationMeta).

  1. is not a problem as we can do a simple as_view, as done in the custom function, to go back to 3.

The problem with 2 is that the autograd needs to know if your two outputs are views of each other or not (and what their base is). Otherwise, when inplace happens, we don’t know how to handle it properly.

@tom I have read your blog post on integrating a C++ extension into TorchScript. I was wondering if there is a workaround for custom Python functions where you define the forward and backward pass. I have a forward-pass implementation of Multihead Attention from Transformer networks that defines an output shape that makes it impossible to use autograd and autodiff.


Nowadays, when you call autograd.Functions from TorchScript, it’ll insert a Python fallback operator (I forget if that’s scripting, tracing, or both).
Another thing that I experimented with was to construct a DifferentiableGraph manually from scripted functions for forward and backward (but I haven’t yet written that up on the blog).
Of course, the long-term goal would be to have some form of autograd.Function in TorchScript, but that’s a bit messy because you need to delay inlining until after autodiff or some such. This might be easier than the last time I did a prototype for this (~2 years ago?).

Best regards

Thomas
