Is it possible to disable inlining of a custom module for torch.compile?

Hi all. I want to substitute a custom module or op with my own implementation, and I want to achieve this through a custom backend.

However, according to my experiment and the doc here (Introduction to torch.compile — PyTorch Tutorials 2.4.0+cu121 documentation), it seems that the torch.fx.GraphModule that torch hands to the custom backend inlines the custom module.

Module Code

class CustomModule(torch.nn.Module):
    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, q, k, v):
        # Scaled dot-product attention; k arrives pre-transposed.
        return torch.softmax(q @ k * self.scale, dim=-1) @ v

class CustomModel(torch.nn.Module):
    def __init__(self, head_num=8, scale=0.125):
        super().__init__()
        self.proj = torch.nn.Linear(32, 32 * 3)
        self.head_num = head_num
        self.dim = 32 // head_num
        self.attn = CustomModule(scale)

    def forward(self, x: torch.Tensor):
        B = x.shape[0]
        L = x.shape[1]
        x = self.proj(x)
        q, k, v = x.chunk(3, -1)

        def transpose_for_score(x: torch.Tensor):
            x = x.view(B, L, self.head_num, -1).permute(0, 2, 1, 3)
            return x

        q = transpose_for_score(q)
        k = transpose_for_score(k)
        v = transpose_for_score(v)

        k = k.permute(0, 1, 3, 2)

        res = self.attn(q, k, v)

        return res.permute(0, 2, 1, 3).reshape(B, L, -1)
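
I compile the model with a custom backend that just prints the captured FX graph, roughly like this (a minimal sketch; my_backend is an illustrative name, and the input shape matches the B=16, L=32 sizes in the output below):

import torch

def my_backend(gm: torch.fx.GraphModule, example_inputs):
    # Print the FX graph that Dynamo hands to the backend
    # (this produces the tabular output shown below).
    gm.graph.print_tabular()
    # Run the captured graph unmodified.
    return gm.forward

model = CustomModel()
compiled = torch.compile(model, backend=my_backend)
compiled(torch.randn(16, 32, 32))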

Output:

opcode         name       target                                                      args                     kwargs
-------------  ---------  ----------------------------------------------------------  -----------------------  -----------
placeholder    l_x_       L_x_                                                        ()                       {}
call_module    x          L__self___proj                                              (l_x_,)                  {}
call_method    chunk      chunk                                                       (x, 3, -1)               {}
call_function  q          <built-in function getitem>                                 (chunk, 0)               {}
call_function  k          <built-in function getitem>                                 (chunk, 1)               {}
call_function  v          <built-in function getitem>                                 (chunk, 2)               {}
call_method    view       view                                                        (q, 16, 32, 8, -1)       {}
call_method    x_1        permute                                                     (view, 0, 2, 1, 3)       {}
call_method    view_1     view                                                        (k, 16, 32, 8, -1)       {}
call_method    x_2        permute                                                     (view_1, 0, 2, 1, 3)     {}
call_method    view_2     view                                                        (v, 16, 32, 8, -1)       {}
call_method    x_3        permute                                                     (view_2, 0, 2, 1, 3)     {}
call_method    k_1        permute                                                     (x_2, 0, 1, 3, 2)        {}
call_function  matmul     <built-in function matmul>                                  (x_1, k_1)               {}
call_function  mul        <built-in function mul>                                     (matmul, 0.125)          {}
call_function  softmax    <built-in method softmax of type object at 0x7fe48783a500>  (mul,)                   {'dim': -1}
call_function  res        <built-in function matmul>                                  (softmax, x_3)           {}
call_method    permute_4  permute                                                     (res, 0, 2, 1, 3)        {}
call_method    reshape    reshape                                                     (permute_4, 16, 32, -1)  {}
output         output     output                                                      ((reshape,),)            {}

Is it possible to make Dynamo treat the custom module as a built-in module so that it is not inlined (like self.proj, which appears as a single call_module node instead of being traced through)?

I think you might be looking for custom ops:
https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
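
A minimal sketch of what that could look like for the attention above, assuming PyTorch 2.4+ where torch.library.custom_op is available (the mylib::custom_attn name is my own choice, not anything prescribed). Dynamo records the op as a single node and never traces into its body:

import torch

@torch.library.custom_op("mylib::custom_attn", mutates_args=())
def custom_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, scale: float) -> torch.Tensor:
    # Your own implementation goes here; torch.compile treats
    # this op as opaque and will not inline it.
    return torch.softmax(q @ k * scale, dim=-1) @ v

@custom_attn.register_fake
def _(q, k, v, scale):
    # Shape/dtype propagation so torch.compile can trace through the op.
    # q is (B, H, L, D), k is pre-transposed, so the result is (B, H, L, Dv).
    return torch.empty(q.shape[:-1] + (v.shape[-1],), dtype=q.dtype, device=q.device)

CustomModule.forward would then call torch.ops.mylib.custom_attn(q, k, v, self.scale) instead of spelling out the math, and your backend can pattern-match on that node.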

There is also allow_in_graph:
https://pytorch.org/docs/stable/generated/torch.compiler.allow_in_graph.html
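
allow_in_graph applies to functions rather than modules, so one rough approach is to pull the attention math into a plain function and decorate it; Dynamo then records one call_function node for it instead of tracing into the body (with the soundness caveats noted in the docs). A sketch under that assumption:

import torch

@torch.compiler.allow_in_graph
def custom_attn(q, k, v, scale):
    # Recorded as a single call_function node in the FX graph;
    # Dynamo does not introspect the body.
    return torch.softmax(q @ k * scale, dim=-1) @ v

CustomModule.forward would call custom_attn(q, k, v, self.scale), and the custom backend could then match that node and swap in its own implementation.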
