What’s the rationale for disabling mutations to user inputs while tracing via `torch.export`?
Is it a soundness issue? Not implemented? Something else?
Context: I’m trying to run a modified version of Llama through `torch.export`, where the Attention module needs to update the kv-cache, which is in turn passed as an argument to `forward`. This restriction therefore rules out the AOT flow for this model.
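One possible workaround is to move the kv-cache out of `forward`'s arguments and into the module as registered buffers. This sketch assumes that `torch.export` accepts in-place updates to buffers (it appears to record them, e.g. BatchNorm running statistics, as buffer mutations in the graph signature) while rejecting them on user inputs. The class `CachedKV`, its `[batch, heads, max_seq, head_dim]` layout and the `start_pos` argument are hypothetical and not taken from the Llama code.

```python
import torch
from torch import nn


class CachedKV(nn.Module):
    """Hypothetical kv-cache holder: the cache is module state (buffers),
    not a tensor handed to forward() by the caller."""

    def __init__(self, batch: int, heads: int, max_seq: int, head_dim: int):
        super().__init__()
        # Assumed layout: [batch, heads, max_seq, head_dim]
        self.register_buffer("k_cache", torch.zeros(batch, heads, max_seq, head_dim))
        self.register_buffer("v_cache", torch.zeros(batch, heads, max_seq, head_dim))

    def forward(self, k_new: torch.Tensor, v_new: torch.Tensor, start_pos: int):
        end = start_pos + k_new.shape[2]
        # In-place writes go into module-owned buffers, not into user inputs
        self.k_cache[:, :, start_pos:end] = k_new
        self.v_cache[:, :, start_pos:end] = v_new
        # Return the filled prefix of the cache for use in attention
        return self.k_cache[:, :, :end], self.v_cache[:, :, :end]
```

The trade-off is that the cache then lives inside the exported program's state instead of being supplied by the caller, which may not suit a runtime that wants to own the cache memory itself.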
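For example, exporting this module, which writes into its first input in place:

```python
import torch
from torch import export, nn


class AdvancedIndexingModule(nn.Module):
    def __init__(self, x):
        super().__init__()
        self.x = x

    def forward(self, x, y):
        # In-place write into the user input `x`
        x[:10, 5 : 10 + 5] = y
        return self.x


module = AdvancedIndexingModule(torch.rand(20, 20))
inputs = (
    torch.rand(20, 20),
    torch.rand(10, 10),
)
ep = export.export(module, inputs)
print(ep)
```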
produces:

```
Traceback (most recent call last):
  File "/home/anieto/Groq/Groq/Compiler/test/end_to_end/demos/llm_tools/test.py", line 33, in <module>
    ep = export.export(module, inputs)
  File "/home/anieto/.local/lib/python3.10/site-packages/torch/export/__init__.py", line 449, in export
    return export__RC__(
  File "/home/anieto/.local/lib/python3.10/site-packages/torch/_export/__init__.py", line 258, in export__RC__
    return _export(
  File "/home/anieto/.local/lib/python3.10/site-packages/torch/_export/__init__.py", line 567, in wrapper
    return fn(*args, **kwargs)
  File "/home/anieto/.local/lib/python3.10/site-packages/torch/_export/__init__.py", line 713, in _export
    gm, graph_signature = aot_export_module(
  File "/home/anieto/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 5157, in aot_export_module
    fx_g, metadata, in_spec, out_spec = _aot_export_function(
  File "/home/anieto/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 5339, in _aot_export_function
    fx_g, meta = create_aot_dispatcher_function(
  File "/home/anieto/.local/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/anieto/.local/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 4569, in create_aot_dispatcher_function
    raise RuntimeError(f"""
RuntimeError:
Found following user inputs located at [0] are mutated. This is currently banned in the aot_export workflow.
If you need this functionality, please file a github issue.
```
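If the goal is only to get an exportable graph, a minimal sketch of a functional rewrite of the example is to write into a copy of `x` and return it. The mutation then happens on an intermediate tensor rather than on a user input. The class name `FunctionalIndexingModule` is hypothetical.

```python
import torch
from torch import nn


class FunctionalIndexingModule(nn.Module):
    """Hypothetical functional variant: never mutates its inputs."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Copy first so the caller's `x` is left untouched
        out = x.clone()
        # This in-place write only touches the intermediate `out`
        out[:10, 5 : 10 + 5] = y
        return out
```

Passing `FunctionalIndexingModule()` and the same `inputs` to `export.export` should then trace without tripping the user-input-mutation check. The cost is an extra 20x20 copy per call. The caller must also use the returned tensor instead of relying on an in-place update; for a kv-cache, that means feeding the returned cache back in on the next step.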