I’m working with the pattern matcher in PyTorch (`torch._inductor.pattern_matcher`), specifically using `register_replacement` (not `gen_register_replacement`). I have two related questions about handling kwargs:
- How can we pass non-scalar kwargs to pattern matcher candidates?
Here’s a simplified example of what I’m trying to achieve:
```python
import functools
import torch

# Pattern with a non-scalar kwarg (e.g., a torch.dtype or bool)
def _my_pattern(arg_0, arg_1, kwarg_0):
    mm_0 = torch.ops.aten.mm(arg_0, arg_1)
    res = my_custom_op(mm_0, kw=kwarg_0)
    return res

# Replacement
def _my_replacement(arg_0, arg_1, kwarg_0):
    res = my_fused_op(arg_0, arg_1, kw=kwarg_0)
    return res

# Setting up candidates
arg_0 = functools.partial(torch.empty, (256, 32), device="cpu", requires_grad=True, dtype=torch.float)
arg_1 = functools.partial(torch.empty, (32, 512), device="cpu", requires_grad=True, dtype=torch.float)
# How do we handle non-scalar kwargs here?
# The scalar_workaround in register_replacement only works for scalar values like alpha=1.3
```
- Alternative approach: Is it possible to capture kwargs from the pattern in the replacement?
If we don’t specify the kwargs in the pattern, can we capture them from the FX graph? For example:
```python
# Pattern without explicit kwargs
def _my_pattern(arg_0, arg_1):
    mm_0 = torch.ops.aten.mm(arg_0, arg_1)
    res = my_custom_op(mm_0)  # has kwargs in the actual graph
    return res

# Replacement - can we capture kwargs from the pattern?
def _my_replacement(arg_0, arg_1):
    # How to get kwargs from my_custom_op in the pattern?
    res = my_fused_op(arg_0, arg_1)
    return res
```
I’ve been looking at `torch/_inductor/fx_passes/fuse_attention.py` for implementation guidance, but haven’t found a clear solution for either approach. Any suggestions would be appreciated.