If a module mutates its buffers in its forward, one would expect torch.export.export to functionalize the graph and return an ExportedProgram with a functional graph along with extra return values (as specified in https://github.com/pytorch/pytorch/blob/010064159b347369fd0212e69ff4618266b2ac9e/torch/export/__init__.py#L165).
However, this only works with explicit calls to in-place ops such as `add_`, not with assignments.
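For instance, a buffer mutated through an explicit in-place op exports as expected. A minimal sketch, assuming the tensor is registered as a buffer (`Counter` is a hypothetical name, not from the original code):

```python
import torch
from torch.export import export

class Counter(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # registered buffer, so export treats it as module state
        self.register_buffer("a", torch.zeros(100, 100))

    def forward(self, b):
        # explicit in-place op: functionalization lifts this mutation
        # into an extra output recorded in the graph signature
        self.a.add_(b)
        return self.a

exported = export(Counter(), (torch.randn(100, 100),))
print(exported.graph_module.code)
```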
By contrast, the module below raises `AssertionError: Mutating module attribute a during export`:
```python
import torch
from torch.export import export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.randn((100, 100))

    def forward(self, b):
        self.a = self.a + b
        return self.a

exported = export(M(), (torch.randn(100, 100),))
print(exported.graph_module.code)
```
One can argue that `self.a = self.a + b` is not the same as `self.a.add_(b)`, since the former rebinds the reference `self.a`. However, replacing that line with `self.a[:, 1] = torch.ones((100,))` also fails with the same error, even though the reference (`id(self.a)`) does not change in that case.
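A quick eager-mode check of that claim (a plain script outside of export), showing that the slice assignment mutates the same tensor object rather than rebinding the attribute:

```python
import torch

a = torch.randn(100, 100)
before = id(a)
a[:, 1] = torch.ones(100)  # in-place slice assignment; no rebinding
assert id(a) == before     # same tensor object, only its contents changed
```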
The latter pattern is very common in LLMs: for example, Llama2 uses a kv cache (https://github.com/facebookresearch/llama/blob/main/llama/model.py#L164) and is currently not exportable as-is (even after replacing all the fairscale layers with their vanilla torch.nn.* equivalents).
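A stripped-down sketch of that kv-cache pattern (hypothetical module and names, not the actual Llama code) that hits the same kind of failure under export:

```python
import torch
from torch.export import export

class TinyKVCache(torch.nn.Module):
    """Stand-in for the Llama2 kv-cache pattern (hypothetical, simplified)."""

    def __init__(self, max_seq_len: int = 16, dim: int = 8):
        super().__init__()
        # pre-allocated cache kept as a plain tensor attribute, as in Llama
        self.cache_k = torch.zeros(max_seq_len, dim)

    def forward(self, start_pos: int, xk: torch.Tensor):
        # slice assignment into the cache -- the pattern export currently rejects
        self.cache_k[start_pos : start_pos + xk.shape[0]] = xk
        return self.cache_k

# At the time of writing, exporting this fails with a module-attribute-mutation error:
# export(TinyKVCache(), (0, torch.randn(4, 8)))
```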
Ideally, the functionalization pass should capture both rebinding and in-place modification (and produce the extra returns). However, dynamo runs before aot_autograd in export, and dynamo tracing seems to treat the assignment operator as a change of reference, which it does not allow.
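For concreteness, this is roughly the functional form one would hope export produces for M above, written by hand (a sketch, not actual export output): the state becomes an explicit input, and its updated value becomes an extra output that the caller or runtime writes back.

```python
import torch

# Hand-functionalized version of M.forward: the mutation `self.a = self.a + b`
# becomes a pure computation, and the updated state is returned alongside the
# user-visible output instead of being written into the module.
def functionalized_forward(a: torch.Tensor, b: torch.Tensor):
    new_a = a + b
    return new_a, new_a  # (user output, updated buffer)
```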
My questions are:
- Is supporting assignment desired at all? Based on these previous issues (torchdynamo.export error message is not clear. · Issue #1475 · pytorch/torchdynamo · GitHub; Disallow module attribute mutation by mergennachin · Pull Request #88354 · pytorch/pytorch · GitHub), it seems like the answer is "no"?
- If the previous answer is a hard no, can we at least solve the case of in-place assignment, as in the llama2 kv cache?
Thanks!