If a module mutates its buffers in its `forward`, one would expect `torch.export.export` to functionalize the graph and return an `ExportedProgram` with a functional graph plus extra return values (as specified in https://github.com/pytorch/pytorch/blob/010064159b347369fd0212e69ff4618266b2ac9e/torch/export/__init__.py#L165).
However, this only works with explicit calls to in-place ops such as `add_`, not with assignments.
For example, the code below raises `AssertionError: Mutating module attribute a during export.`:

```python
import torch
from torch.export import export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.randn((100, 100))

    def forward(self, b):
        self.a = self.a + b
        return self.a

exported = export(M(), (torch.randn(100, 100),))
print(exported.graph_module.code)
```
One can argue that `self.a = self.a + b` is not the same as `self.a.add_(b)`, since the former rebinds the reference of `self.a`. However, replacing that line with `self.a[:, 1] = torch.ones((100,))` also fails with the same error, even though the reference (`id(self.a)`) does not change in this case.
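The rebind-vs-mutate distinction is plain Python semantics; it can be illustrated with lists rather than tensors (`torch.Tensor` slice assignment behaves the same way with respect to object identity):

```python
# Assignment rebinds the name to a new object:
a = [1, 2, 3]
before = id(a)
a = a + [4]      # builds a new list, then rebinds `a`
assert id(a) != before

# Slice assignment mutates the existing object in place:
a = [1, 2, 3]
before = id(a)
a[1:] = [9, 9]   # same object, contents changed
assert id(a) == before
print(a)  # [1, 9, 9]
```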
The latter pattern is very useful in LLMs: for example, Llama2 uses a kv cache (https://github.com/facebookresearch/llama/blob/main/llama/model.py#L164) and is currently not exportable as-is (even after replacing all the fairscale layers with vanilla `torch.nn.*` equivalents).
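A minimal sketch of that kv-cache pattern (shapes and names simplified from the Llama2 code; `Attention`, `cache_k`, and `max_seq_len` here are illustrative, not the upstream API): a preallocated buffer is updated with slice assignment on each decoding step.

```python
import torch

class Attention(torch.nn.Module):
    def __init__(self, max_seq_len: int = 16, dim: int = 8):
        super().__init__()
        # Preallocated kv cache, updated in place each forward call.
        self.register_buffer("cache_k", torch.zeros(max_seq_len, dim))

    def forward(self, xk: torch.Tensor, start_pos: int) -> torch.Tensor:
        seqlen = xk.shape[0]
        # In-place slice assignment: id(self.cache_k) is unchanged,
        # yet torch.export raises the same
        # "Mutating module attribute" error on this line.
        self.cache_k[start_pos : start_pos + seqlen] = xk
        return self.cache_k[: start_pos + seqlen]
```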
Ideally, the functionalization pass should capture both rebinding and in-place modification (and produce the extra returns). However, in export, dynamo runs before aot_autograd, and dynamo tracing treats the assignment operator as a change of reference and rejects it.
My questions are:
Is supporting assignment desired at all? Based on these previous issues (torchdynamo.export error message is not clear. · Issue #1475 · pytorch/torchdynamo · GitHub; Disallow module attribute mutation by mergennachin · Pull Request #88354 · pytorch/pytorch · GitHub), it seems the answer is "no"?
If the answer to the previous question is a hard no, can we at least support in-place slice assignment, as in the Llama2 kv cache case?