Why doesn't Dynamo graph capture produce `get_attr` nodes?

I’m exploring how graph capture differs between TorchDynamo and torch.fx symbolic tracing. Specifically, I’ve noticed that when a model is traced through Dynamo, the get_attr nodes are automatically converted to placeholder nodes.

Here’s a simple runnable demo that illustrates what I’m observing:

import torch
from torch.fx import Tracer
from torch.export import export

class CustomModel(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.w1 = torch.nn.Parameter(torch.empty(hidden_size, hidden_size))

    def forward(self, x):
        x = torch.mm(x, self.w1)
        # x = torch.nn.functional.gelu(x)  # Uncomment to add a non-linear operation
        x = x.sum()
        return (x,)

if __name__ == "__main__":
    torch.set_default_dtype(torch.bfloat16)
    with torch.device("meta"):
        hidden_size = 1024
        model = CustomModel(hidden_size)
        inp = torch.zeros(2, hidden_size, requires_grad=True)
        
        tracer = Tracer()
        graph = tracer.trace(model)
        graph.print_tabular()
        
        exported_program: torch.export.ExportedProgram = export(model, args=(inp,))
        gm = exported_program.graph_module
        gm.graph.print_tabular()

The torch.fx trace and the export trace (which I presume uses Dynamo internally) behave differently: in the FX trace the parameter is kept as a get_attr node, whereas in the Dynamo-based trace it becomes a placeholder.

Output using FX:

opcode         name    target                                                 args         kwargs
-------------  ------  -----------------------------------------------------  -----------  --------
placeholder    x       x                                                      ()           {}
get_attr       w1      w1                                                     ()           {}
call_function  mm      <built-in method mm of type object at 0x7f661c873500>  (x, w1)      {}
call_method    sum_1   sum                                                    (mm,)        {}
output         output  output                                                 ((sum_1,),)  {}

Output using Dynamo:

opcode         name    target            args         kwargs
-------------  ------  ----------------  -----------  --------
placeholder    p_w1    p_w1              ()           {}
placeholder    x       x                 ()           {}
call_function  mm      aten.mm.default   (x, p_w1)    {}
call_function  sum_1   aten.sum.default  (mm,)        {}
output         output  output            ((sum_1,),)  {}

I am curious why this happens, and whether there’s a way to control or prevent this behavior when using Dynamo. Any insights or recommendations on how to handle this discrepancy would be greatly appreciated.

Dynamo is “promoting” all weights to inputs to generate a “functional” graph. This is why you lose all the get_attr nodes: every tensor, weights included, is now a graph input.
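To make “promoting weights to inputs” concrete, here is a small illustration using torch.func.functional_call. It is an analogy, not what Dynamo does internally: the same forward is computed once reading w1 from the module (the get_attr view) and once with the weights passed in explicitly (the lifted-placeholder view). The model mirrors CustomModel from the question, assuming random CPU weights and a hidden size of 4 chosen only for illustration.

```python
import torch
from torch.func import functional_call

class CustomModel(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.w1 = torch.nn.Parameter(torch.randn(hidden_size, hidden_size))

    def forward(self, x):
        return (torch.mm(x, self.w1).sum(),)

model = CustomModel(4)
x = torch.randn(2, 4)

# Stateful call: w1 is read from the module, which is what a get_attr node represents.
(stateful_out,) = model(x)

# Functional view: the weights are an explicit argument, like the lifted p_w1 placeholder.
params = {name: p.detach() for name, p in model.named_parameters()}

def pure_forward(params, x):
    return functional_call(model, params, (x,))

(functional_out,) = pure_forward(params, x)
print(torch.allclose(stateful_out, functional_out))  # True

# Passing different weights changes the result without touching the module itself.
(zero_out,) = pure_forward({"w1": torch.zeros(4, 4)}, x)
print(zero_out.item())  # 0.0
```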

Thank you for your answer!

Indeed, I currently need to differentiate between model inputs and model parameters in the generated graph. I may also need to distinguish parameters that come from different modules, such as attention versus MLP parameters.

Dynamo together with AOTAutograd names all placeholders in the form arg0_1, which makes it hard to tell where each one comes from. I can read tensor metadata for these placeholders from node.meta, but that still doesn't separate model inputs from parameters when they have identical shapes. Is there a way to make this distinction?

I would look at torch.export.unflatten(exported). It restores the module hierarchy with get_attr nodes and may fit your use case.

That helps a lot, thanks.

Update:

Dynamo automatically functionalizes the graph: all parameters and buffers are lifted to graph inputs, and the whole graph is treated as one large forward function.

If you use export to get an ExportedProgram, you can then call torch.export.unflatten(exported) to turn it back into an UnflattenedModule. The resulting graph is hierarchical again, with get_attr and call_module nodes.

However, if you want to capture the joint forward and backward graph, you would use aot_export_module (with trace_joint=True), which returns a torch.fx.GraphModule and a GraphSignature. This path traces the module through AOTAutograd, at a lower level than export.

At this point, getting an ExportedProgram and calling unflatten is not feasible, because the TreeSpec information in the GraphSignature is actually empty.

Nonetheless, the GraphSignature contains fields such as inputs_to_parameters, so we can still recover the source of each placeholder manually; we just cannot rebuild the submodule structure.
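The manual recovery can be sketched as a small helper that takes the placeholder names plus the inputs_to_parameters and inputs_to_buffers dictionaries from the GraphSignature and labels each placeholder. classify_placeholders and group_by_submodule are hypothetical helpers, and the FQNs in the usage example are made up. Grouping by name prefix only tags parameters (e.g. attention vs. MLP); it does not rebuild the submodule hierarchy.

```python
from typing import Dict, List, Optional, Tuple

def classify_placeholders(
    placeholder_names: List[str],
    inputs_to_parameters: Dict[str, str],
    inputs_to_buffers: Dict[str, str],
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each placeholder name to (kind, FQN of the module state it comes from)."""
    result: Dict[str, Tuple[str, Optional[str]]] = {}
    for name in placeholder_names:
        if name in inputs_to_parameters:
            result[name] = ("parameter", inputs_to_parameters[name])
        elif name in inputs_to_buffers:
            result[name] = ("buffer", inputs_to_buffers[name])
        else:
            # Assumption: anything not backed by module state is a user input.
            result[name] = ("user_input", None)
    return result

def group_by_submodule(classified: Dict[str, Tuple[str, Optional[str]]]) -> Dict[str, List[str]]:
    """Group parameter placeholders by the owning module path (FQN minus the last part)."""
    groups: Dict[str, List[str]] = {}
    for name, (kind, fqn) in classified.items():
        if kind == "parameter" and fqn is not None:
            prefix = fqn.rsplit(".", 1)[0] if "." in fqn else "<root>"
            groups.setdefault(prefix, []).append(name)
    return groups

# Hypothetical values shaped like sig.inputs_to_parameters / sig.inputs_to_buffers.
placeholders = ["arg0_1", "arg1_1", "arg2_1", "arg3_1"]
params = {"arg0_1": "layers.0.attn.weight", "arg1_1": "layers.0.mlp.weight"}
buffers = {"arg2_1": "layers.0.norm.running_mean"}

classified = classify_placeholders(placeholders, params, buffers)
print(classified["arg3_1"])           # ('user_input', None)
print(group_by_submodule(classified))  # {'layers.0.attn': ['arg0_1'], 'layers.0.mlp': ['arg1_1']}
```

With a real capture, the same helper would be called as classify_placeholders(placeholder_names, sig.inputs_to_parameters, sig.inputs_to_buffers).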
