I have been working through the exercise of converting a widely used model to TorchScript. In particular, I have been trying to convert the entire model, both to enable fusions across software blocks and to make it easy for a user to simply wrap the model in `torch.jit.script()` instead of having to limit themselves to specific parts.

One common idiom in the HuggingFace Bert model that has caused me trouble is the use of `Optional` parameters in a Module’s `forward` method. These parameters describe the configuration of the Module through a runtime argument, even though the way the parameter is used is constant for a given Module instance. Here is an example where the parameter `mask` is either a Tensor or `None`, to indicate whether a mask is applied to the input Tensor.

```
import torch


class ExampleModule(torch.nn.Module):
    def __init__(self, hidden_size, scale_factor):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)

    def forward(
        self,
        inputs,
        mask=None
    ):
        tmp = inputs / self.scale
        # The mask is applied only when the caller passes one in.
        if mask is not None:
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
        return outputs
```

The expectation is that this Module should fuse. Out of the box it does not even compile, because the `forward` method’s `mask` parameter has a default value of `None`. TorchScript assumes an unannotated parameter is a `torch.Tensor`, so the `None` default conflicts with the inferred type and produces the following error.

```
RuntimeError:
Expected a default value of type Tensor (inferred) on parameter "mask".Because "mask" was not annotated with an explicit type it is assumed to be type 'Tensor'.:
```

This requires the first modification to the module: adding type hints, since not every parameter is a plain `torch.Tensor`.

```
from typing import Optional

import torch


class ExampleModule(torch.nn.Module):
    def __init__(self, hidden_size, scale_factor):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)

    def forward(
        self,
        inputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        tmp = inputs / self.scale
        # The None check refines Optional[Tensor] to Tensor inside the branch.
        if mask is not None:
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
        return outputs
```

While TorchScript now accepts the Module without error, it does not fuse the operators in the `forward` method. One way to tell that no fusion was made is to print the IR graph and see that it lacks a Fusion Group. Instead, you see a `prim::If` node that blocks fusion.

```
graph(%self : __torch__.ExampleModule,
      %inputs.1 : Tensor,
      %mask.1 : Tensor?):
  %5 : None = prim::Constant() # example.py:15:23
  %4 : int = prim::Constant[value=1]()
  %3 : int = prim::Constant[value=-1]() # example.py:17:55
  %6 : int = prim::GetAttr[name="scale"](%self)
  %tmp.2 : Tensor = aten::div(%inputs.1, %6) # example.py:14:14
  %9 : bool = aten::__isnot__(%mask.1, %5) # example.py:15:11
  %tmp : Tensor = prim::If(%9) # example.py:15:8
    block0():
      %mask.4 : Tensor = prim::unchecked_cast(%mask.1)
      %tmp.3 : Tensor = aten::add_(%tmp.2, %mask.4, %4) # example.py:16:12
      -> (%tmp.3)
    block1():
      -> (%tmp.2)
  %ret.2 : Tensor = aten::softmax(%tmp, %3, %5) #
  return (%ret.2)
```
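The same check can be made without reading the graph by eye. The following is a minimal sketch that scripts a copy of the type-hinted module and searches the printed IR for the branch node. The class name `MaskCheckModule` and the variable names are hypothetical, chosen only for this illustration.

```python
from typing import Optional

import torch


class MaskCheckModule(torch.nn.Module):
    # Hypothetical name; mirrors the type-hinted ExampleModule.
    def __init__(self, hidden_size, scale_factor):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)

    def forward(
        self,
        inputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        tmp = inputs / self.scale
        if mask is not None:
            tmp = tmp + mask
        return torch.nn.functional.softmax(tmp, dim=-1)


scripted = torch.jit.script(MaskCheckModule(1024, 16))

# `graph` is the IR of the scripted forward method, before any
# profiling or fusion passes have run.
graph_text = str(scripted.graph)

# The runtime None check shows up as a prim::If node in the IR.
has_branch = "prim::If" in graph_text
print("graph contains prim::If:", has_branch)
```

This only inspects the unoptimized graph, so it shows the branch that gets in the way of fusion, not whether a fuser actually ran.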

The second code modification to enable fusion is to replace the `None` check on the runtime parameter `mask` with a conditional on a class attribute that is marked as constant (`Final`), as suggested in the documentation.

```
from typing import Final, Optional

import torch


class ExampleModule(torch.nn.Module):
    # Marked Final so TorchScript treats it as a compile-time constant.
    use_mask: Final[bool]

    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ):
        tmp = inputs / self.scale
        if self.use_mask:
            # mask is still Optional[Tensor] here, so this line fails to compile.
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
        return outputs
```

This modification also fails to fuse. It trips an error in TorchScript because the `mask` parameter is seen as type `Optional[torch.Tensor]` instead of `torch.Tensor`.

```
RuntimeError:
Arguments for call are not valid.
The following variants are available:
aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor):
Expected a value of type 'Tensor' for argument 'other' but instead found type 'Optional[Tensor]'.
aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor):
Expected a value of type 'number' for argument 'other' but instead found type 'Optional[Tensor]'.
aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> (Tensor(a!)):
Expected a value of type 'Tensor' for argument 'other' but instead found type 'Optional[Tensor]'.
aten::add.t(t[] a, t[] b) -> (t[]):
Could not match type Tensor to List[t] in argument 'a': Cannot match List[t] to Tensor.
aten::add.str(str a, str b) -> (str):
Expected a value of type 'str' for argument 'a' but instead found type 'Tensor'.
aten::add.int(int a, int b) -> (int):
Expected a value of type 'int' for argument 'b' but instead found type 'Optional[Tensor]'.
aten::add.float(float a, float b) -> (float):
Expected a value of type 'float' for argument 'b' but instead found type 'Optional[Tensor]'.
aten::add.int_float(int a, float b) -> (float):
Expected a value of type 'float' for argument 'b' but instead found type 'Optional[Tensor]'.
aten::add.float_int(float a, int b) -> (float):
Expected a value of type 'int' for argument 'b' but instead found type 'Optional[Tensor]'.
aten::add(Scalar a, Scalar b) -> (Scalar):
Expected a value of type 'number' for argument 'b' but instead found type 'Optional[Tensor]'.
add(float a, Tensor b) -> (Tensor):
Expected a value of type 'Tensor' for argument 'b' but instead found type 'Optional[Tensor]'.
add(int a, Tensor b) -> (Tensor):
Expected a value of type 'Tensor' for argument 'b' but instead found type 'Optional[Tensor]'.
The original call is:
File "example.py", line 18
tmp = inputs / self.scale
if self.use_mask :
tmp = tmp + mask
~~~~~~~~~~ <--- HERE
outputs = torch.nn.functional.softmax(tmp, dim=-1)
```

One option to resolve this problem is to `assert` that `mask` is not `None`, since the `assert` refines the `Optional` to its inner type. However, the result is identical to a conditional on `None`: the fusion is blocked. Another option is to make the parameter default an empty Tensor (`torch.Tensor()`). **This works, but is it the best solution?**

```
from typing import Final

import torch


class ExampleModule(torch.nn.Module):
    use_mask: Final[bool]

    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs: torch.Tensor,
        # An empty Tensor stands in for "no mask" so the type stays Tensor.
        mask: torch.Tensor = torch.Tensor()
    ):
        tmp = inputs / self.scale
        if self.use_mask:
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
        return outputs
```
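For comparison, the `assert` option mentioned above can be sketched as follows. This is only an illustration of the type refinement: TorchScript accepts the module because the `assert` narrows `Optional[torch.Tensor]` to `torch.Tensor`, but as noted, the check still blocks the fusion. The class name `AssertMaskModule` is hypothetical.

```python
from typing import Final, Optional

import torch


class AssertMaskModule(torch.nn.Module):
    # Hypothetical name; same body as ExampleModule plus an assert.
    use_mask: Final[bool]

    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        tmp = inputs / self.scale
        if self.use_mask:
            # After this assert, TorchScript treats mask as a plain Tensor.
            assert mask is not None
            tmp = tmp + mask
        return torch.nn.functional.softmax(tmp, dim=-1)


assert_scripted = torch.jit.script(AssertMaskModule(1024, 16, True))

# Small CPU inputs are enough to confirm that the module compiles and runs.
x = torch.randn(2, 4)
m = torch.randn(2, 4)
out = assert_scripted(x, m)
```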

If I have a parameter that is only conditionally used, it seems more natural to leave it as an `Optional` typed argument, but there seems to be no way to do this that still allows a fusion to occur. One idea was to `cast` the `Optional` to a Tensor when the conditional is True, but that doesn’t work. Perhaps having TorchScript support `cast` would be better? **Are there better suggestions for how to handle Optional parameters such that they produce a fusion?**

```
from typing import Final, Optional, cast

import torch


class ExampleModule(torch.nn.Module):
    use_mask: Final[bool]

    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ):
        tmp = inputs / self.scale
        if self.use_mask:
            # typing.cast is not understood by TorchScript; this line errors.
            tmp = tmp + cast(torch.Tensor, mask)
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
        return outputs
```

Resulting error from using `cast`:

```
RuntimeError:
builtin cannot be used as a value:
File "example.py", line 18
tmp = inputs / self.scale
if self.use_mask :
tmp = tmp + cast(torch.Tensor, mask)
~~~~~~~~~~~~ <--- HERE
outputs = torch.nn.functional.softmax(tmp, dim=-1)
```
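`typing.cast` is rejected, but TorchScript does register a helper with a similar purpose, `torch.jit._unwrap_optional`, which converts an `Optional[T]` to `T` and raises at runtime if the value is `None`. The sketch below assumes that this private, underscore-prefixed function is available in the installed PyTorch version; it is not a documented public API and may change. The class name `UnwrapMaskModule` is hypothetical, and whether this form fuses is not verified here.

```python
from typing import Final, Optional

import torch


class UnwrapMaskModule(torch.nn.Module):
    # Hypothetical name; same structure as the cast attempt.
    use_mask: Final[bool]

    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        tmp = inputs / self.scale
        if self.use_mask:
            # Private helper (assumption: present in this PyTorch build).
            # It yields a Tensor and raises if mask is None at runtime.
            tmp = tmp + torch.jit._unwrap_optional(mask)
        return torch.nn.functional.softmax(tmp, dim=-1)


unwrap_scripted = torch.jit.script(UnwrapMaskModule(1024, 16, True))

x = torch.randn(2, 4)
m = torch.randn(2, 4)
out = unwrap_scripted(x, m)
```

The unwrap is still a runtime check on the argument, so it may affect fusion in the same way the `assert` does.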

**Full working code example that produces a fusion:**

```
import torch
from typing import Final


class ExampleModule(torch.nn.Module):
    use_mask: Final[bool]

    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs: torch.Tensor,
        mask: torch.Tensor = torch.Tensor()
    ):
        tmp = inputs / self.scale
        if self.use_mask:
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
        return outputs


model = ExampleModule(1024, 16, True).to('cuda')
jit_model = torch.jit.script(model)

inputs = torch.randn([128, 16, 64, 64], device='cuda', requires_grad=True)
mask = torch.randn([128, 1, 1, 64], device='cuda', requires_grad=False)

with torch.jit.fuser("fuser2"):
    for idx in range(2):
        # Print the graph on the second iteration, after the first run.
        if idx == 1:
            print(jit_model.forward.graph_for(inputs, mask))
        out = jit_model.forward(inputs, mask)
```
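Rather than scanning the printed graph by eye, the text returned by `graph_for` can be searched for fusion nodes. The helper below is a small sketch; the function name `find_fusion_nodes` is hypothetical. The node kinds it looks for are assumptions based on how the TorchScript fusers label their groups: `prim::FusionGroup` for the legacy fuser, `prim::TensorExprGroup` for `fuser1`, and `prim::CudaFusionGroup` for `fuser2`.

```python
from typing import List

# Node kinds assumed to mark a fused region in a printed TorchScript graph.
FUSION_NODE_KINDS = (
    "prim::FusionGroup",      # legacy fuser
    "prim::TensorExprGroup",  # fuser1
    "prim::CudaFusionGroup",  # fuser2
)


def find_fusion_nodes(graph_text: str) -> List[str]:
    """Return the fusion node kinds that appear in the printed graph."""
    return [kind for kind in FUSION_NODE_KINDS if kind in graph_text]


# Illustrative one-line graph fragments, written by hand for this example.
fused_line = "%5 : Tensor = prim::CudaFusionGroup_0(%inputs.1, %mask.1)"
branch_line = "%tmp : Tensor = prim::If(%9)"

print(find_fusion_nodes(fused_line))
print(find_fusion_nodes(branch_line))
```

With the full example, it would be called as `find_fusion_nodes(str(jit_model.forward.graph_for(inputs, mask)))` after the warm-up iteration. An empty list means no fused region was found.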