I am currently working on a custom pass in PyTorch FX where I need to create an empty tensor whose shape matches the output shape of a specific node. However, I am running into an issue because that node's output shape is symbolic.
Here’s an outline of the situation:
1. Problem Description:
In my pass, I create an empty tensor using the dimensions derived from a node's output, after which the original node is removed from the graph. However, since the dimension sizes are symbolic integers (SymInt), creating the empty tensor fails during Inductor lowering.
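Condensed from the reproducer at the bottom (graph and node are as in the pass there), the failing variant passes the SymInt sizes from the node's FakeTensor meta straight into the new aten.empty node; the full reproducer uses the guard_int workaround instead:

val = node.meta["val"]  # FakeTensor; shape is e.g. (s0, s1) with SymInt sizes
with graph.inserting_after(node):
    empty_node = graph.call_function(
        torch.ops.aten.empty.memory_format,
        args=(list(val.shape),),  # SymInts end up in args[0] -> lowering asserts
        kwargs=dict(dtype=val.dtype, device=val.device),
    )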
2. Error Encountered:
When I attempt to create the tensor, I receive the following error:
File "torch/_inductor/ir.py", line 3098, in __init__
    assert all(isinstance(s, (Expr, int)) for s in size)
LoweringException: AssertionError:
  target: aten.empty.memory_format
  args[0]: (s0, s1)
3. Initial Solution Attempt:
I tried using guard_int to convert each SymInt, but guard_int specializes: it returns the concrete runtime value and installs a guard on it, so the symbolic shapes are replaced with the actual sizes and the graph recompiles every time the shape changes. This defeats the purpose of dynamic shape handling in the graph.
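For reference, a minimal standalone sketch of that specializing behavior. This relies on private APIs (FakeTensorMode, ShapeEnv), so details may differ across PyTorch versions:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_int

# Illustrative only: build a fake tensor with symbolic sizes, then specialize.
with FakeTensorMode(shape_env=ShapeEnv()) as mode:
    t = mode.from_tensor(torch.randn(2, 3))  # sizes become SymInts, e.g. (s0, s1)

print(t.shape)                          # symbolic sizes
print([guard_int(s) for s in t.shape])  # [2, 3] -- concrete ints; guards now pin s0=2, s1=3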
Questions:
- Is there a way to create an empty tensor with symbolic dimensions without losing the symbolic nature of those dimensions or causing recompilation?
- Are there recommended strategies for working with symbolic shapes in custom FX passes?
Below is the reproducer for the issue:
import torch
import torch.fx
import torch.nn as nn
from torch._dynamo import register_backend
from torch._inductor.compile_fx import compile_fx, compile_fx_inner
from typing import List, Optional
from torch.fx.experimental.symbolic_shapes import guard_int
# --- Step 1: Model ---
class AddModel(nn.Module):
    def forward(self, a, b):
        return torch.add(a, b)
# --- Step 2: FX Compile Pass ---
def insert_add_out_pass(gm: torch.fx.GraphModule):
    graph = gm.graph
    for node in list(graph.nodes):
        if node.target == torch.ops.aten.add.Tensor:
            # Get shape, dtype, device from the FakeTensor meta
            shape = node.meta["val"].shape
            dtype = node.meta["val"].dtype
            device = node.meta["val"].device
            print("Found the shape to be:", shape)
            # guard_int specializes each SymInt to its concrete runtime value
            shape_exprs = [guard_int(s) for s in shape]
            # Create empty tensor node
            with graph.inserting_after(node):
                empty_node = graph.call_function(
                    torch.ops.aten.empty.memory_format,
                    args=(shape_exprs,),
                    kwargs=dict(
                        dtype=dtype,
                        device=device,
                        layout=torch.strided,
                        pin_memory=False,
                        memory_format=torch.contiguous_format,
                    ),
                )
            with graph.inserting_after(empty_node):
                add_node = graph.call_function(
                    torch.ops.aten.add.Tensor, args=(node.args[0], node.args[1])
                )
            with graph.inserting_after(add_node):
                copy_node = graph.call_function(
                    torch.ops.aten.copy_.default, args=(empty_node, add_node)
                )
            # Redirect consumers to the copied-into buffer, then drop the original add
            node.replace_all_uses_with(copy_node)
            graph.erase_node(node)
    graph.lint()
    gm.recompile()
    print(graph)
    return gm
# --- Step 3: Custom Compile & Backend ---
def custom_compile(gm: torch.fx.GraphModule, example_inputs,
                   cudagraphs=None, static_input_idxs: Optional[List[int]] = None,
                   num_fixed: int = 0,
                   is_backward: bool = False,
                   graph_id: Optional[int] = None,
                   cpp_wrapper: bool = False,
                   aot_mode: bool = False,
                   is_inference: bool = False,
                   boxed_forward_device_index=None,
                   user_visible_outputs=frozenset(),
                   layout_opt: Optional[bool] = None):
    print("Running custom compile")
    gm = insert_add_out_pass(gm)
    return compile_fx_inner(gm, example_inputs)
@register_backend
def custom_backend(gm: torch.fx.GraphModule, example_inputs):
    print("Running custom backend")
    return compile_fx(gm, example_inputs, inner_compile=custom_compile)
# --- Step 4: Compile & Run ---
model = AddModel()
compiled_model = torch.compile(model, backend=custom_backend, dynamic=True)

# Run with different shapes
a1 = torch.randn(2, 3)
b1 = torch.randn(2, 3)
with torch.no_grad():
    out1 = compiled_model(a1, b1)
print("Out1 shape:", out1.shape)

a2 = torch.randn(4, 5)
b2 = torch.randn(4, 5)
with torch.no_grad():
    out2 = compiled_model(a2, b2)
print("Out2 shape:", out2.shape)