Update: One solution is to define custom ops and wrap them in an autograd.Function.
PyTorch 2.4.0 provides new APIs (e.g. torch.library.custom_op) that make defining custom ops more convenient, while the older APIs still work.
The runnable demo further below works well for me (PyTorch 2.3.0) and sticks to the older API.
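For reference, a rough sketch of what the 2.4-style registration of the forward op could look like (illustrative only, not part of the demo; it is an alternative to the mylib::bar definition below, and the backward would then be attached with torch.library.register_autograd):

import torch

@torch.library.custom_op("mylib::bar", mutates_args=())
def bar(x: torch.Tensor) -> torch.Tensor:
    # Real kernel: same element-wise square as in the demo below
    x_np = x.detach().numpy()
    return torch.from_numpy(x_np * x_np)

@bar.register_fake
def _(x):
    # Fake (meta) kernel used during tracing
    return torch.empty_like(x)

The full demo: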
import numpy as np
import torch
from torch._functorch.aot_autograd import aot_export_module
from torch._functorch.partitioners import default_partition
# Define forward op
torch.library.define("mylib::bar", "(Tensor x) -> Tensor")
@torch.library.impl("mylib::bar", "default")
def bar_impl(x):
    # Real kernel: element-wise square computed in NumPy
    x_np = x.detach().numpy()
    z_np = np.multiply(x_np, x_np)
    return torch.tensor(z_np)

# Abstract (fake) impl: gives the tracer the output shape/dtype without running the real kernel
@torch.library.impl_abstract("mylib::bar")
def bar_abstract(x):
    return torch.empty_like(x)

# Define backward op
torch.library.define("mylib::bar_backward", "(Tensor grad, Tensor x) -> Tensor")
@torch.library.impl("mylib::bar_backward", "default")
def bar_backward(grad, x):
    # The formula here is arbitrary (it is not the analytic gradient of x * x);
    # the point is only to show a custom op being used in the backward pass.
    grad_x = torch.ops.mylib.bar.default(grad + x)
    return grad_x

@torch.library.impl_abstract("mylib::bar_backward")
def bar_backward_abstract(grad, x):
    return torch.empty_like(x)

# Create an autograd.Function with the forward and backward
class CustomFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return torch.ops.mylib.bar.default(x)

    @staticmethod
    def backward(ctx, grad):
        x = ctx.saved_tensors[0]
        return torch.ops.mylib.bar_backward.default(grad, x)

def custom_func(x):
    return CustomFunc.apply(x)

class CustomModel(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.w1 = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))

    def forward(self, x):
        x = custom_func(x)
        x = torch.mm(x, self.w1)
        x = custom_func(x)
        x = x.sum()
        return (x,)

if __name__ == "__main__":
    hidden_size = 1024
    model = CustomModel(hidden_size)
    inp = torch.zeros(2, hidden_size, requires_grad=False)
    # Export the joint forward+backward graph, then split it into forward and backward graphs
    m, _ = aot_export_module(model, [inp], trace_joint=True, output_loss_index=0, decompositions=None)
    fwd, bwd = default_partition(m, [inp], num_fwd_outputs=1)
    fwd.graph.print_tabular()
    bwd.graph.print_tabular()
And it will print two graphs:
opcode name target args kwargs
------------- ------ ----------------- -------------------------- --------
placeholder arg0_1 arg0_1 () {}
placeholder arg1_1 arg1_1 () {}
call_function bar mylib.bar.default (arg1_1,) {}
call_function mm aten.mm.default (bar, arg0_1) {}
call_function bar_1 mylib.bar.default (mm,) {}
call_function sum_1 aten.sum.default (bar_1,) {}
output output output ([sum_1, bar, mm, sum_1],) {}
opcode name target args kwargs
------------- ------------ -------------------------- ---------------------- -------------------------------------------------------------
placeholder bar bar () {}
placeholder mm mm () {}
placeholder sum_1 sum_1 () {}
call_function ones_like aten.ones_like.default (sum_1,) {'pin_memory': False, 'memory_format': torch.preserve_format}
call_function expand aten.expand.default (ones_like, [2, 1024]) {}
call_function bar_backward mylib.bar_backward.default (expand, mm) {}
call_function t aten.t.default (bar,) {}
call_function mm_1 aten.mm.default (t, bar_backward) {}
output output output ([mm_1],) {}
As the printed graphs show, mylib.bar.default and mylib.bar_backward.default appear as single nodes in the forward and backward graphs, so the custom ops are wrapped correctly.
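As an extra, illustrative sanity check (not part of the demo above), the wrapped op can also be exercised in eager mode, using the definitions from the demo, to confirm that autograd routes through mylib::bar_backward:

x = torch.randn(4, 8, requires_grad=True)
y = custom_func(x).sum()
y.backward()            # dispatches to mylib::bar_backward via CustomFunc.backward
print(x.grad.shape)     # torch.Size([4, 8])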