Hi,
Are there any plans to support `set_` on a graph input when that input also has other mutations in the graph?
Running the reproducer below fails with:
AssertionError: Encountered a set_ on a graph input, but the input has other mutations that we cannot keep in the graph. This is not supported today. Current state:
keep_input_mutations=True
mutates_data=False
mutates_metadata=True
mutations_hidden_from_autograd=False
mutations_under_no_grad_or_inference_mode=False
mutation_inductor_storage_resize=False
requires_grad=True
import torch
import torch.nn as nn
class SomeModule(nn.Module):
    """Minimal reproducer: rebinds a parameter's underlying tensor in ``forward``.

    Under ``torch.compile`` this pattern is reported to fail with
    ``AssertionError: Encountered a set_ on a graph input, but the input has
    other mutations that we cannot keep in the graph``. The reported state
    shows mutates_metadata=True, mutates_data=False and requires_grad=True.
    """

    def __init__(self):
        super().__init__()
        # 5x5 trainable parameter drawn from torch.randn. forward() replaces
        # the tensor backing it.
        self.param = nn.Parameter(torch.randn(5, 5))

    def forward(self, x):
        """Copy ``x`` into a freshly allocated buffer and rebind ``self.param`` to it.

        Args:
            x: tensor copied into the new buffer. Presumably the same shape as
                ``self.param`` (the reproducer passes a 5x5 tensor) -- TODO confirm.

        Returns:
            The same ``nn.Parameter`` object as ``self.param``, now backed by
            the new buffer.
        """
        # Uninitialized buffer matching the parameter's shape, dtype and device.
        r1 = torch.empty(
            self.param.shape,
            dtype=self.param.dtype,
            device=self.param.device,
        )
        # Fill the buffer in place from the input.
        r1.copy_(x)
        # Assigning to .data swaps the tensor backing the parameter rather than
        # writing into its existing values. This is presumably the metadata
        # mutation torch.compile reports as a set_ on a graph input.
        self.param.data = r1.data
        return self.param
# Build the module eagerly. Its __init__ draws the parameter from the global
# RNG before the input below is drawn, so statement order affects the values.
model = SomeModule()
# Wrap the module with torch.compile.
compiled_model = torch.compile(model)
# 5x5 input, matching the 5x5 parameter created in SomeModule.__init__.
input_tensor = torch.randn(5,5)
# Calling the compiled module is presumably where the reported AssertionError
# ("Encountered a set_ on a graph input ...") surfaces.
output = compiled_model(input_tensor)