Dynamo/FX: patching a function to add more outputs not working

Hi,

I’ve been playing a bit with Dynamo+FX and I wanted to prototype a transformation that increases the number of outputs of a graph.
Example:

import operator
from typing import List
import torch
import torch._dynamo

def mycompiler(gm: torch.fx.GraphModule, inputs: List[torch.Tensor]):
    print(gm.code)
    new_output = ()

    for node in gm.graph.nodes:
        if node.op == 'call_function' and node.target == operator.iadd:
            # For every in-place add, clone both operands and remember the
            # clones so they can be appended to the graph's outputs.
            assert len(node.args) == 2
            with gm.graph.inserting_before(node):
                for arg in node.args:
                    n = gm.graph.call_method('clone', (arg,))
                    new_output += (n,)
        elif node.op == 'output':
            # Append the recorded clones to the existing outputs.
            node.args = (node.args[0] + new_output,)

    gm.graph.lint()
    gm.recompile()
    print(gm.graph)
    print(gm.code)
    return torch.jit.trace(gm, inputs)

def fn(x, y):
    x += y
    x += y
    return x.add(y), y.sub(x)


model = torch._dynamo.optimize(mycompiler, nopython=True)(fn)
print(model(torch.rand(1), torch.rand(1)))

The original code is:

def forward(self, x : torch.Tensor, y : torch.Tensor):
    x += y;  iadd = x;  x = None
    iadd += y;  iadd_1 = iadd;  iadd = None
    add = iadd_1.add(y)
    sub = y.sub(iadd_1);  y = iadd_1 = None
    return (add, sub)

The transformed code looks as expected:

def forward(self, x : torch.Tensor, y : torch.Tensor):
    clone = x.clone()
    clone_1 = y.clone()
    x += y;  iadd = x;  x = None
    clone_2 = iadd.clone()
    clone_3 = y.clone()
    iadd += y;  iadd_1 = iadd;  iadd = None
    add = iadd_1.add(y)
    sub = y.sub(iadd_1);  y = iadd_1 = None
    return (add, sub, clone, clone_1, clone_2, clone_3)

However, when executed, the function returns only the original 2 outputs, not the 6 I expected.
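
Concretely, a quick check like the following (the tensor values are random; only the tuple length matters) still shows two elements:

outs = model(torch.rand(1), torch.rand(1))
print(len(outs))  # prints 2, even though the rewritten graph returns 6 values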

Is this a bug or a feature that prevents users from changing the number of outputs of a function?

Thank you!
Nuno

Without nopython this is by design (though we probably should error here): the outputs are unpredictably assigned to various Python locals, so we have no idea what to do with the extra outputs. In export mode I suppose this could work; try using export instead.
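
For concreteness, a rough sketch of the export route, assuming the clone-insertion pass from the post above is applied directly to the GraphModule returned by torch._dynamo.export (add_clone_outputs is just that loop factored into a helper, and fn is the same function as before):

import operator
import torch
import torch._dynamo

def add_clone_outputs(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Same pass as above: clone the operands of every in-place add and
    # append the clones to the graph's outputs.
    new_output = ()
    for node in gm.graph.nodes:
        if node.op == 'call_function' and node.target == operator.iadd:
            with gm.graph.inserting_before(node):
                for arg in node.args:
                    new_output += (gm.graph.call_method('clone', (arg,)),)
        elif node.op == 'output':
            node.args = (tuple(node.args[0]) + new_output,)
    gm.graph.lint()
    gm.recompile()
    return gm

def fn(x, y):
    x += y
    x += y
    return x.add(y), y.sub(x)

# Export captures the whole function as one GraphModule, which can then be
# edited and called directly.
gm = torch._dynamo.export(fn, torch.rand(1), torch.rand(1))[0]
gm = add_clone_outputs(gm)
print(gm(torch.rand(1), torch.rand(1)))  # should now include the extra clones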

Thanks!
Removing the nopython option and using export didn’t make a difference, though:

gm = torch._dynamo.export(model, torch.rand(1), torch.rand(1))[0]
gm.print_readable()

Prints:

class GraphModule(torch.nn.Module):
    def forward(self, arg0, arg1):
        # File: experiments.py:36, code: x += y
        arg0 += arg1;  iadd = arg0;  arg0 = None

        # File: experiments.py:37, code: x += y
        iadd += arg1;  iadd_1 = iadd;  iadd = None

        # File: experiments.py:38, code: return x.add(y), y.sub(x)
        add = iadd_1.add(arg1)
        sub = arg1.sub(iadd_1);  arg1 = iadd_1 = None
        return (add, sub)

It’s pretty weird, but the instrumentation code is gone altogether. Any clues as to what might be happening?
I want to use Dynamo/FX for some exploratory research, so I’m looking for ways to automatically instrument the captured graphs and log some data.
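
For what it’s worth, one way to sketch that kind of instrumentation without changing the number of outputs at all is to interpret the captured graph node by node with torch.fx.Interpreter. A rough sketch (LoggingInterpreter and the choice of what to log are just illustrative):

import torch
import torch.fx
import torch._dynamo

class LoggingInterpreter(torch.fx.Interpreter):
    # Execute each node of the graph and log the shape of every tensor it produces.
    def run_node(self, n):
        result = super().run_node(n)
        if isinstance(result, torch.Tensor):
            print(f"{n.op} {n.name}: shape={tuple(result.shape)}")
        return result

def logging_backend(gm: torch.fx.GraphModule, inputs):
    # Instead of returning a compiled artifact, return a callable that
    # interprets the graph and logs as it runs; the outputs are left untouched.
    return LoggingInterpreter(gm).run

model = torch._dynamo.optimize(logging_backend)(fn)  # fn as defined in the first post
print(model(torch.rand(1), torch.rand(1)))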