Hey folks, I’m struggling to produce an OutputKind.BUFFER_MUTATION in the signature of a torch.export’ed program. I swear this used to work.
I took the exact example from the docs:
If I run this script:
import torch

class MyMod(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Define a parameter
        self.my_parameter = torch.nn.Parameter(torch.tensor(2.0))
        # Define two buffers
        self.register_buffer('my_buffer1', torch.tensor(3.0))
        self.register_buffer('my_buffer2', torch.tensor(4.0))

    def forward(self, x1, x2):
        # Use the parameter, buffers, and both inputs in the forward method
        output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2
        # Mutate one of the buffers (e.g., increment it by 1)
        self.my_buffer2.add_(1.0)  # In-place addition
        return output

exported_program = torch.export.export(MyMod(), (torch.zeros([1]), torch.zeros([1])))
print(exported_program.graph_signature)
The documentation there says (and my expectation was) that I would get:
output_specs=[
    OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_2'), target='my_buffer2'),
    OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)
]
But instead I get:
output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)]
with no BUFFER_MUTATION entry.
The reason seems to be that the graph isn't functionalized: there is still an in-place aten.add_ node in the graph.
graph():
    %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter]
    %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1]
    %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2]
    %x1 : [num_users=1] = placeholder[target=x1]
    %x2 : [num_users=1] = placeholder[target=x2]
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {})
    %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {})
    %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {})
    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {})
    %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {})
    return (add_1,)
This is with torch==2.7.1. So I guess my question is: what changed (and why), and is there a way to get the previous functionalized graph back?