How to wrap a function when using TorchDynamo graph capture?

I am trying to capture both forward and backward graphs using:

from torch._functorch.aot_autograd import aot_export_module
from torch._functorch.partitioners import default_partition
m, _ = aot_export_module(model, [inp], trace_joint=True, output_loss_index=0, decompositions=None)
fwd, bwd = default_partition(m, [inp], num_fwd_outputs=1)

The functions get decomposed, and the nodes in these graphs end up at the level of aten ops.

But I actually need some functions to remain in the graphs as single nodes, so I wonder whether there is a way to mark certain functions as black boxes.

With the older way of tracing a graph, fx.Tracer, there is an autowrap_functions parameter that prevents a function from being traced through, and a function can also be registered with @torch.fx.wrap to achieve the same effect (torch.fx — PyTorch 2.3 documentation). A minimal sketch of what I mean is below (the function and module names are just for illustration):
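import torch
import torch.fx

@torch.fx.wrap  # keep this function as a single call_function node during symbolic tracing
def my_leaf(x):
    return x * x + 1

class M(torch.nn.Module):
    def forward(self, x):
        return my_leaf(x).sum()

gm = torch.fx.symbolic_trace(M())
gm.graph.print_tabular()  # my_leaf appears as one node instead of its decomposition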

Hi, you could try allow_in_graph; this might help.
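Roughly something like this (a sketch, assuming the torch._dynamo.allow_in_graph entry point and a toy opaque_fn):

import torch
import torch._dynamo as dynamo

def opaque_fn(x):
    return x * x + 1

dynamo.allow_in_graph(opaque_fn)  # ask Dynamo to keep the call as a single node in its graph

@torch.compile(backend="eager")
def f(x):
    return opaque_fn(x).sum()

f(torch.randn(4))

With backend="eager" the Dynamo-level graph keeps opaque_fn as one node; as noted further down in the thread, AOTAutograd may still trace through it.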

Will try, thank you.

Update: One solution is to define custom ops and wrap them in an autograd.Function.

PyTorch 2.4.0 provides some new APIs to define custom ops more conveniently, while the older-version APIs still work.
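For example, my understanding is that the forward op below could be written with the 2.4-style torch.library.custom_op API roughly like this (untested on my side; the name mylib::bar_v2 is just for illustration):

import numpy as np
import torch

@torch.library.custom_op("mylib::bar_v2", mutates_args=())
def bar_v2(x: torch.Tensor) -> torch.Tensor:
    x_np = x.detach().numpy()
    return torch.from_numpy(np.multiply(x_np, x_np))

@bar_v2.register_fake
def _(x):
    # Fake/meta implementation used during tracing.
    return torch.empty_like(x)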

I am also providing some runnable code below as a demo. It works well for me (PyTorch 2.3.0).

import numpy as np
import torch
from torch._functorch.aot_autograd import aot_export_module
from torch._functorch.partitioners import default_partition

# Define forward op
torch.library.define("mylib::bar", "(Tensor x) -> Tensor")

@torch.library.impl("mylib::bar", "default")
def bar_impl(x):
    x_np = x.detach().numpy()
    z_np = np.multiply(x_np, x_np)
    return torch.tensor(z_np)

@torch.library.impl_abstract("mylib::bar")
def bar_abstract(x):
    # Fake/meta implementation used during tracing
    return torch.empty_like(x)


# Define backward op
torch.library.define("mylib::bar_backward", "(Tensor grad, Tensor x) -> Tensor")

@torch.library.impl("mylib::bar_backward", "default")
def bar_backward(grad, x):
    # NOTE: this backward is only illustrative; it is not the true gradient of x * x
    grad_x = torch.ops.mylib.bar.default(grad + x)
    return grad_x

@torch.library.impl_abstract("mylib::bar_backward")
def bar_backward_abstract(grad, x):
    return torch.empty_like(x)


# Create an autograd.Function with the forward and backward
class CustomFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.ops.mylib.bar.default(x)

    @staticmethod
    def backward(ctx, grad):
        x = ctx.saved_tensors[0]
        return torch.ops.mylib.bar_backward.default(grad, x)

def custom_func(x):
    return CustomFunc.apply(x)


class CustomModel(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.w1 = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))

    def forward(self, x):
        x = custom_func(x)
        x = torch.mm(x, self.w1)
        x = custom_func(x)
        x = x.sum()
        return (x,)


if __name__ == "__main__":
    hidden_size = 1024
    model = CustomModel(hidden_size)
    inp = torch.zeros(2, hidden_size, requires_grad=False)
    m, _ = aot_export_module(model, [inp], trace_joint=True, output_loss_index=0, decompositions=None)
    fwd, bwd = default_partition(m, [inp], num_fwd_outputs=1)
    
    fwd.graph.print_tabular()
    bwd.graph.print_tabular()

And it will print two graphs:

opcode         name    target             args                        kwargs
-------------  ------  -----------------  --------------------------  --------
placeholder    arg0_1  arg0_1             ()                          {}
placeholder    arg1_1  arg1_1             ()                          {}
call_function  bar     mylib.bar.default  (arg1_1,)                   {}
call_function  mm      aten.mm.default    (bar, arg0_1)               {}
call_function  bar_1   mylib.bar.default  (mm,)                       {}
call_function  sum_1   aten.sum.default   (bar_1,)                    {}
output         output  output             ([sum_1, bar, mm, sum_1],)  {}
opcode         name          target                      args                    kwargs
-------------  ------------  --------------------------  ----------------------  -------------------------------------------------------------
placeholder    bar           bar                         ()                      {}
placeholder    mm            mm                          ()                      {}
placeholder    sum_1         sum_1                       ()                      {}
call_function  ones_like     aten.ones_like.default      (sum_1,)                {'pin_memory': False, 'memory_format': torch.preserve_format}
call_function  expand        aten.expand.default         (ones_like, [2, 1024])  {}
call_function  bar_backward  mylib.bar_backward.default  (expand, mm)            {}
call_function  t             aten.t.default              (bar,)                  {}
call_function  mm_1          aten.mm.default             (t, bar_backward)       {}
output         output        output                      ([mm_1],)               {}

And the custom ops are kept as single nodes (not decomposed) in both the forward and backward graphs.

I think allow_in_graph works for the forward graph but not for the backward graph, since AOTAutograd will still decompose the function even when it is marked allow_in_graph. Thanks anyway!
