Hi, all. I want to substitute a custom module or op with my own implementation, and I want to do this through a custom backend.
However, according to my experiments and the doc here: Introduction to torch.compile — PyTorch Tutorials 2.4.0+cu121 documentation, the torch.fx.GraphModule that torch hands to the custom backend inlines the custom module.
Module code:
import torch

class CustomModule(torch.nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, q, k, v):
        return torch.softmax(q @ k * self.scale, dim=-1) @ v

class CustomModel(torch.nn.Module):
    def __init__(self, head_num=8, scale=0.125):
        super().__init__()
        self.proj = torch.nn.Linear(32, 32 * 3)
        self.head_num = head_num
        self.dim = 32 // head_num
        self.attn = CustomModule(scale)

    def forward(self, x: torch.Tensor):
        B, L = x.shape[0], x.shape[1]
        x = self.proj(x)
        q, k, v = x.chunk(3, -1)

        def transpose_for_score(x: torch.Tensor):
            # (B, L, head_num * dim) -> (B, head_num, L, dim)
            return x.view(B, L, self.head_num, -1).permute(0, 2, 1, 3)

        q = transpose_for_score(q)
        k = transpose_for_score(k)
        v = transpose_for_score(v)
        k = k.permute(0, 1, 3, 2)  # transpose last two dims for q @ k
        res = self.attn(q, k, v)
        return res.permute(0, 2, 1, 3).reshape(B, L, -1)
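For reference, the graph dump below comes from running the model through torch.compile with a minimal debug backend along these lines (a sketch; the input shape (16, 32, 32) is inferred from the sizes baked into the traced graph):

def debug_backend(gm: torch.fx.GraphModule, example_inputs):
    # Print the FX graph Dynamo hands to the backend, then fall back to eager.
    gm.graph.print_tabular()
    return gm.forward

model = CustomModel()
compiled = torch.compile(model, backend=debug_backend)
compiled(torch.randn(16, 32, 32))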
Output:
opcode name target args kwargs
------------- --------- ---------------------------------------------------------- ----------------------- -----------
placeholder l_x_ L_x_ () {}
call_module x L__self___proj (l_x_,) {}
call_method chunk chunk (x, 3, -1) {}
call_function q <built-in function getitem> (chunk, 0) {}
call_function k <built-in function getitem> (chunk, 1) {}
call_function v <built-in function getitem> (chunk, 2) {}
call_method view view (q, 16, 32, 8, -1) {}
call_method x_1 permute (view, 0, 2, 1, 3) {}
call_method view_1 view (k, 16, 32, 8, -1) {}
call_method x_2 permute (view_1, 0, 2, 1, 3) {}
call_method view_2 view (v, 16, 32, 8, -1) {}
call_method x_3 permute (view_2, 0, 2, 1, 3) {}
call_method k_1 permute (x_2, 0, 1, 3, 2) {}
call_function matmul <built-in function matmul> (x_1, k_1) {}
call_function mul <built-in function mul> (matmul, 0.125) {}
call_function softmax <built-in method softmax of type object at 0x7fe48783a500> (mul,) {'dim': -1}
call_function res <built-in function matmul> (softmax, x_3) {}
call_method permute_4 permute (res, 0, 2, 1, 3) {}
call_method reshape reshape (permute_4, 16, 32, -1) {}
output output output ((reshape,),) {}
Is it possible to make torch.dynamo treat the custom module as a built-in module so that the inlining is disabled? (Like the behavior of self.proj, which is not inlined and stays a single call_module node.)
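The closest thing I have found so far is torch._dynamo.allow_in_graph, which keeps a plain callable as a single call_function node instead of tracing into it. A minimal sketch, assuming I wrap the module's computation in a standalone function (I'm not sure this is the intended tool for whole nn.Modules, and it bypasses Dynamo's checks for the wrapped callable):

import torch

@torch._dynamo.allow_in_graph
def custom_attention(q, k, v, scale):
    # Dynamo should emit this as a single call_function node
    # that a custom backend can match on and replace.
    return torch.softmax(q @ k * scale, dim=-1) @ v

class CustomModule(torch.nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, q, k, v):
        return custom_attention(q, k, v, self.scale)

With this, the backend at least sees custom_attention as one node it can swap out, but I would still like to know whether there is a supported way to get the call_module behavior for a custom module itself.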