I have been going through the exercise of taking a widely used model and converting it to TorchScript. In particular, I have been trying to convert the entire model, both to enable fusions across software blocks and to let a user simply wrap the whole model in torch.jit.script() instead of limiting themselves to specific parts.
One common idiom I have discovered in the HuggingFace BERT model that has caused me trouble is the use of Optional parameters in a Module's forward method. These parameters describe the configuration of the Module through a runtime argument, even though how the parameter is used is constant for a given Module instance. Here is an example where the parameter mask is either a Tensor or None, conveying whether a mask is applied to the input Tensor.
class ExampleModule(torch.nn.Module):
    def __init__(self, hidden_size, scale_factor):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)

    def forward(
        self,
        inputs,
        mask=None
    ):
        tmp = inputs / self.scale
        if mask is not None:
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
        return outputs
The expectation is that this Module should fuse. Out of the box, it does not, because the forward method's mask parameter has a default value of None. This confuses TorchScript, since mask can be of type torch.Tensor or of type None, and it produces the following error.
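Simply scripting the module is enough to trigger it (the constructor arguments here are just illustrative):

# Hypothetical repro: scripting the un-annotated module fails because
# TorchScript infers mask's type as Tensor, which conflicts with the
# default value of None.
model = ExampleModule(1024, 16)
jit_model = torch.jit.script(model)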
RuntimeError:
Expected a default value of type Tensor (inferred) on parameter "mask".Because "mask" was not annotated with an explicit type it is assumed to be type 'Tensor'.:
This requires the first modification to the module: adding type hints, since not all of the parameters are of type torch.Tensor.
class ExampleModule(torch.nn.Module):
    def __init__(self, hidden_size, scale_factor):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)

    def forward(
        self,
        inputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        tmp = inputs / self.scale
        if mask is not None:
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
        return outputs
While TorchScript now accepts the Module without error, it does not fuse the operators in the forward method. One way to tell that no fusion was made is to print the IR graph and look for a FusionGroup node.
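For example (a sketch that mirrors the full script at the end of this post; the input shapes are illustrative), the graph specialized for a set of inputs can be printed with graph_for after a warm-up call:

# Sketch: run the scripted module once so the profiling executor can
# specialize the graph, then print the graph recorded for these inputs.
jit_model = torch.jit.script(ExampleModule(1024, 16).to('cuda'))
inputs = torch.randn([128, 16, 64, 64], device='cuda')
mask = torch.randn([128, 1, 1, 64], device='cuda')
with torch.jit.fuser("fuser2"):
    jit_model(inputs, mask)  # warm-up pass
    print(jit_model.forward.graph_for(inputs, mask))

The printed graph lacks a FusionGroup; instead, a prim::If node blocks the fusion: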
graph(%self : __torch__.ExampleModule,
      %inputs.1 : Tensor,
      %mask.1 : Tensor?):
  %5 : None = prim::Constant() # example.py:15:23
  %4 : int = prim::Constant[value=1]()
  %3 : int = prim::Constant[value=-1]() # example.py:17:55
  %6 : int = prim::GetAttr[name="scale"](%self)
  %tmp.2 : Tensor = aten::div(%inputs.1, %6) # example.py:14:14
  %9 : bool = aten::__isnot__(%mask.1, %5) # example.py:15:11
  %tmp : Tensor = prim::If(%9) # example.py:15:8
    block0():
      %mask.4 : Tensor = prim::unchecked_cast(%mask.1)
      %tmp.3 : Tensor = aten::add_(%tmp.2, %mask.4, %4) # example.py:16:12
      -> (%tmp.3)
    block1():
      -> (%tmp.2)
  %ret.2 : Tensor = aten::softmax(%tmp, %3, %5) #
  return (%ret.2)
The second code modification to enable fusion is to replace the None check on the runtime mask parameter with a class attribute marked as constant (Final), as suggested in the documentation.
class ExampleModule(torch.nn.Module):
    use_mask: Final[bool]

    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ):
        tmp = inputs / self.scale
        if self.use_mask:
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
        return outputs
This modification fails to fuse as well, and it trips an error in TorchScript because the mask parameter is still seen as type Optional[torch.Tensor] rather than torch.Tensor.
RuntimeError:
Arguments for call are not valid.
The following variants are available:
  aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor):
    Expected a value of type 'Tensor' for argument 'other' but instead found type 'Optional[Tensor]'.
  aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor):
    Expected a value of type 'number' for argument 'other' but instead found type 'Optional[Tensor]'.
  aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> (Tensor(a!)):
    Expected a value of type 'Tensor' for argument 'other' but instead found type 'Optional[Tensor]'.
  aten::add.t(t[] a, t[] b) -> (t[]):
    Could not match type Tensor to List[t] in argument 'a': Cannot match List[t] to Tensor.
  aten::add.str(str a, str b) -> (str):
    Expected a value of type 'str' for argument 'a' but instead found type 'Tensor'.
  aten::add.int(int a, int b) -> (int):
    Expected a value of type 'int' for argument 'b' but instead found type 'Optional[Tensor]'.
  aten::add.float(float a, float b) -> (float):
    Expected a value of type 'float' for argument 'b' but instead found type 'Optional[Tensor]'.
  aten::add.int_float(int a, float b) -> (float):
    Expected a value of type 'float' for argument 'b' but instead found type 'Optional[Tensor]'.
  aten::add.float_int(float a, int b) -> (float):
    Expected a value of type 'int' for argument 'b' but instead found type 'Optional[Tensor]'.
  aten::add(Scalar a, Scalar b) -> (Scalar):
    Expected a value of type 'number' for argument 'b' but instead found type 'Optional[Tensor]'.
  add(float a, Tensor b) -> (Tensor):
    Expected a value of type 'Tensor' for argument 'b' but instead found type 'Optional[Tensor]'.
  add(int a, Tensor b) -> (Tensor):
    Expected a value of type 'Tensor' for argument 'b' but instead found type 'Optional[Tensor]'.

The original call is:
  File "example.py", line 18
        tmp = inputs / self.scale
        if self.use_mask :
            tmp = tmp + mask
                  ~~~~~~~~~~ <--- HERE
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
One option to resolve this problem is to apply an assert on mask when it could be None, since the assert strips away the Optional and refines the type to Tensor (a sketch of this variant follows below). However, the result is identical to the conditional on None: the fusion is still blocked. Another option is to make the parameter's default a zero-dimension Tensor. This works, but is it the best solution?
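The assert variant, for reference (a sketch; it type-checks, but the scripted graph still contains a prim::If for the None check):

    # Sketch of the assert-based forward. The assert refines
    # Optional[Tensor] to Tensor so the addition type-checks, but
    # TorchScript still emits a prim::If for the None check, so the
    # fusion remains blocked.
    def forward(
        self,
        inputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        tmp = inputs / self.scale
        if self.use_mask:
            assert mask is not None
            tmp = tmp + mask
        return torch.nn.functional.softmax(tmp, dim=-1)

The zero-dimension Tensor version, which does fuse, is below: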
class ExampleModule(torch.nn.Module):
    use_mask: Final[bool]

    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs: torch.Tensor,
        mask: torch.Tensor = torch.Tensor()
    ):
        tmp = inputs / self.scale
        if self.use_mask:
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
        return outputs
If I have a parameter that is only conditionally used, it seems more natural to leave it as an Optional-typed argument, but there appears to be no way to do this that allows a fusion to occur. One idea was to cast the Optional to a Tensor when the conditional is True, but that does not work. Perhaps having TorchScript support cast would be better? Are there better suggestions for handling Optional parameters such that they produce a fusion?
class ExampleModule(torch.nn.Module):
    use_mask: Final[bool]

    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ):
        tmp = inputs / self.scale
        if self.use_mask:
            tmp = tmp + cast(torch.Tensor, mask)
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
        return outputs
Resulting error from using cast:
RuntimeError:
builtin cannot be used as a value:
  File "example.py", line 18
        tmp = inputs / self.scale
        if self.use_mask :
            tmp = tmp + cast(torch.Tensor, mask)
                        ~~~~~~~~~~~~ <--- HERE
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
Full working code example that produces a fusion:
import torch
from typing import Optional, Final, cast

class ExampleModule(torch.nn.Module):
    use_mask: Final[bool]

    def __init__(self, hidden_size, scale_factor, use_mask):
        super().__init__()
        self.scale = int(hidden_size / scale_factor)
        self.use_mask = use_mask

    def forward(
        self,
        inputs: torch.Tensor,
        mask: torch.Tensor = torch.Tensor()
    ):
        tmp = inputs / self.scale
        if self.use_mask:
            tmp = tmp + mask
        outputs = torch.nn.functional.softmax(tmp, dim=-1)
        return outputs

model = ExampleModule(1024, 16, True).to('cuda')
jit_model = torch.jit.script(model)

inputs = torch.randn([128, 16, 64, 64], device='cuda', requires_grad=True)
mask = torch.randn([128, 1, 1, 64], device='cuda', requires_grad=False)

with torch.jit.fuser("fuser2"):
    for idx in range(2):
        if idx == 1:
            print(jit_model.forward.graph_for(inputs, mask))
        out = jit_model.forward(inputs, mask)
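Note that the graph is printed on the second loop iteration: with the profiling executor, the first iteration serves as a warm-up pass that records tensor shapes, and only after it does graph_for return the optimized graph containing the fusion group.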