I previously posted this on the PyTorch Forums, but I think the topic is better suited to the dev forum.
I'm trying to understand the scope of `torch.fx.traceback.preserve_node_meta()`. Can someone please explain the scope of `torch.fx.Node`'s meta preservation in the context of the following example?
```python
import torch
import torch.fx


# A simple torch module
class SimpleAdd(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        return x + y


class SimpleTransform(torch.fx.Transformer):
    '''
    A simple FX transformer
    '''
    def __init__(self, module):
        super().__init__(module)

    def run_node(self, n: torch.fx.Node):
        '''
        Traverse FX graph and add a custom meta field for each node
        '''
        n.meta["node_op"] = n.op
        return super().run_node(n)


model = SimpleAdd()
with torch.fx.traceback.preserve_node_meta():
    fxg = torch.fx.symbolic_trace(model)

    # node meta prior to transform
    for node in fxg.graph.nodes:
        """
        Current node meta
        x {'seq_nr': -1}
        y {'seq_nr': -1}
        add {'seq_nr': -1}
        output {'seq_nr': -1}
        """
        print(node, node.meta)

    # node traversal to add new meta
    SimpleTransform(fxg).run()
    for node in fxg.graph.nodes:
        """
        Updated meta. Previous meta is preserved.
        x {'seq_nr': -1, 'node_op': 'placeholder'}
        y {'seq_nr': -1, 'node_op': 'placeholder'}
        add {'seq_nr': -1, 'node_op': 'call_function'}
        output {'seq_nr': -1, 'node_op': 'output'}
        """
        print(node, node.meta)

    # No real transformation, call super().transform()
    fxgn = SimpleTransform(fxg).transform()

    # node meta after transform
    for node in fxgn.graph.nodes:
        """
        Node meta after calling transform().
        x {}
        y {}
        add {'from_node': [('add', <built-in function add>)], 'seq_nr': -1}
        output {}
        """
        print(node, node.meta)
```
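To make the difference easier to read than the per-node printouts, the comparison can be condensed into a diff of meta keys. This is only a minimal sketch: `lost_meta_keys` is a hypothetical helper (not part of `torch.fx`), and it assumes that nodes keep the same names across `transform()`, which holds in the example above but is not guaranteed in general.

```python
import torch
import torch.fx


class SimpleAdd(torch.nn.Module):
    def forward(self, x, y):
        return x + y


def lost_meta_keys(old_gm, new_gm):
    """Hypothetical helper: map node name -> meta keys present in old_gm but not in new_gm.

    Nodes are matched by name, which is an assumption of this sketch.
    """
    new_meta = {n.name: n.meta for n in new_gm.graph.nodes}
    report = {}
    for n in old_gm.graph.nodes:
        missing = set(n.meta) - set(new_meta.get(n.name, {}))
        if missing:
            report[n.name] = missing
    return report


gm = torch.fx.symbolic_trace(SimpleAdd())
for n in gm.graph.nodes:
    n.meta["node_op"] = n.op  # same custom field as in the example

# Identity transform: the base Transformer rebuilds the graph unchanged
new_gm = torch.fx.Transformer(gm).transform()
print(lost_meta_keys(gm, new_gm))
```

The printed dictionary lists, per node, which keys did not make it into the new graph, so it can be rerun on different PyTorch commits to compare behavior.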
Question: Is this the correct behavior of `preserve_node_meta()`? Given that `transform()` returns a new graph, should it still preserve node meta from the old graph?
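In case the answer is that custom meta fields are not expected to survive `transform()`, one possible workaround is to copy the missing fields over by hand afterwards. The sketch below is not an official API: `copy_missing_meta` is a hypothetical helper, it matches nodes by name (an assumption that holds for an identity transform like the one above, but may not hold for a transform that adds or renames nodes), and it shares meta values by reference rather than deep-copying them.

```python
import torch
import torch.fx


class SimpleAdd(torch.nn.Module):
    def forward(self, x, y):
        return x + y


def copy_missing_meta(old_gm, new_gm):
    """Hypothetical helper: fill in meta keys that new_gm lacks, matching nodes by name."""
    old_meta = {n.name: n.meta for n in old_gm.graph.nodes}
    for n in new_gm.graph.nodes:
        for key, value in old_meta.get(n.name, {}).items():
            # setdefault never overwrites a key that transform() already set
            n.meta.setdefault(key, value)
    return new_gm


gm = torch.fx.symbolic_trace(SimpleAdd())
for n in gm.graph.nodes:
    n.meta["node_op"] = n.op  # same custom field as in the example

new_gm = copy_missing_meta(gm, torch.fx.Transformer(gm).transform())
for n in new_gm.graph.nodes:
    print(n, n.meta.get("node_op"))
```

Using `setdefault` keeps whatever `transform()` itself wrote (for example `from_node`) and only fills the gaps, so the helper should be harmless if a given PyTorch version already preserves some of the fields.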
The example code was run against PyTorch commit eff01bc.
Thanks!