Hi,
I’ve been playing a bit with Dynamo+FX, and I wanted to prototype a transformation that increases the number of outputs of a graph.
Example:
import operator
from typing import List

import torch
import torch._dynamo

def mycompiler(gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
    print(gm.code)
    new_output = ()
    for node in gm.graph.nodes:
        if node.op == 'call_function' and node.target == operator.iadd:
            assert len(node.args) == 2
            # Clone both operands of the in-place add and remember the
            # clones so they can be appended to the graph's outputs.
            with gm.graph.inserting_before(node):
                for arg in node.args:
                    n = gm.graph.call_method('clone', (arg,))
                    new_output += (n,)
        elif node.op == 'output':
            # Extend the original output tuple with the cloned values.
            node.args = (node.args[0] + new_output,)
    gm.graph.lint()
    gm.recompile()
    print(gm.graph)
    print(gm.code)
    return torch.jit.trace(gm, inputs)

def fn(x, y):
    x += y
    x += y
    return x.add(y), y.sub(x)

model = torch._dynamo.optimize(mycompiler, nopython=True)(fn)
print(model(torch.rand(1), torch.rand(1)))
The original code is:
def forward(self, x : torch.Tensor, y : torch.Tensor):
    x += y; iadd = x; x = None
    iadd += y; iadd_1 = iadd; iadd = None
    add = iadd_1.add(y)
    sub = y.sub(iadd_1); y = iadd_1 = None
    return (add, sub)
The transformed code looks as expected:
def forward(self, x : torch.Tensor, y : torch.Tensor):
    clone = x.clone()
    clone_1 = y.clone()
    x += y; iadd = x; x = None
    clone_2 = iadd.clone()
    clone_3 = y.clone()
    iadd += y; iadd_1 = iadd; iadd = None
    add = iadd_1.add(y)
    sub = y.sub(iadd_1); y = iadd_1 = None
    return (add, sub, clone, clone_1, clone_2, clone_3)
However, when executed, the function returns only the original 2 outputs, not the 6 I expected.
Is this a bug or a feature that prevents users from changing the number of outputs of a function?
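One way I can think of to narrow this down is to stash the module the backend returns and call it directly, bypassing the wrapper code Dynamo generates. A minimal sketch, reusing mycompiler and fn from above (the captured list and capturing_compiler wrapper are just scaffolding I'm introducing for this test):

# Reset Dynamo's cache so fn gets recompiled with the new backend.
torch._dynamo.reset()

# Stash the compiled module so it can be called outside of Dynamo.
captured = []

def capturing_compiler(gm, inputs):
    compiled = mycompiler(gm, inputs)  # same transformation as above
    captured.append(compiled)
    return compiled

model = torch._dynamo.optimize(capturing_compiler, nopython=True)(fn)
model(torch.rand(1), torch.rand(1))  # triggers compilation

# If this prints 6 while the Dynamo-wrapped `model` returns 2 outputs,
# the extra outputs are being dropped by the code Dynamo generates
# around the compiled graph, not by the FX transformation itself.
out = captured[0](torch.rand(1), torch.rand(1))
print(len(out))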
Thank you!
Nuno