Supporting mutations in torch.export.export

If a module mutates its buffers in its forward, one would expect torch.export.export to functionalize the graph and return an ExportedProgram with a functional graph along with extra return values (as specified in https://github.com/pytorch/pytorch/blob/010064159b347369fd0212e69ff4618266b2ac9e/torch/export/__init__.py#L165).

However, this only works with explicit calls to in-place ops such as add_, not with assignments.

For example, the module below raises AssertionError: Mutating module attribute a during export:

import torch
from torch.export import export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.randn((100, 100))

    def forward(self, b):
        self.a = self.a + b  # rebinds the attribute rather than mutating it in place
        return self.a

exported = export(M(), (torch.randn(100, 100),))
print(exported.graph_module.code)

One could argue that self.a = self.a + b is not the same as self.a.add_(b), since the former rebinds the reference self.a. However, replacing that line with self.a[:, 1] = torch.ones((100, )) fails with the same error, even though the reference (id(self.a)) does not change in that case.
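To make the contrast concrete, here is a rough sketch of the two cases (my assumption: the working variant registers a as a buffer, since the linked ExportedProgram docs describe mutations in terms of buffers; the failing variant mirrors the module above with only the assignment line changed):

import torch
from torch.export import export

class MInplace(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Assumption: register a as a buffer so export can track its mutation.
        self.register_buffer("a", torch.randn(100, 100))

    def forward(self, b):
        self.a.add_(b)  # explicit in-place op: functionalizes, mutation lifted to an extra return
        return self.a

class MSliceAssign(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.randn((100, 100))

    def forward(self, b):
        self.a[:, 1] = torch.ones(100)  # id(self.a) is unchanged, yet export still rejects it
        return self.a

# Works and lifts the buffer mutation into the graph outputs:
print(export(MInplace(), (torch.randn(100, 100),)).graph_module.code)

# Raises the same AssertionError: Mutating module attribute a during export:
# export(MSliceAssign(), (torch.randn(100, 100),))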

The latter pattern is very useful in LLMs. For example, Llama2 uses a kv cache (https://github.com/facebookresearch/llama/blob/main/llama/model.py#L164) and is currently not exportable as-is (even after replacing all the fairscale layers with their vanilla torch.nn.* equivalents).
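For reference, the kv-cache update in the linked Llama code amounts to slice assignments into cache tensors inside forward. A stripped-down sketch of that pattern (shapes and names simplified; this is an illustration, not the actual Llama implementation):

import torch

class AttentionWithCache(torch.nn.Module):
    def __init__(self, max_batch, max_seq, n_heads, head_dim):
        super().__init__()
        # Cache tensors held as module state, as in the linked model.py.
        self.cache_k = torch.zeros(max_batch, max_seq, n_heads, head_dim)
        self.cache_v = torch.zeros(max_batch, max_seq, n_heads, head_dim)

    def forward(self, xk, xv, start_pos: int):
        bsz, seqlen = xk.shape[0], xk.shape[1]
        # In-place slice writes into module state -- the pattern that
        # currently trips up torch.export.export.
        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
        keys = self.cache_k[:bsz, : start_pos + seqlen]
        values = self.cache_v[:bsz, : start_pos + seqlen]
        return keys, values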

Ideally, we would capture both rebinding and in-place modification in the functionalization pass (and produce the extra returns). However, dynamo runs before aot_autograd during export, and dynamo tracing seems to treat the assignment operator as a change of reference, which it does not allow.

My questions are: is this a known limitation, and are there plans to support these mutation patterns in torch.export.export?

Thanks!

It is actually desired to support this; it should be ready in the next few weeks.


This still seems to be unsupported as of PyTorch 2.2. Any idea when it might be supported?