Questions about handling non-scalar kwargs in pattern matcher and kwargs capture in replacements

I’m working with the pattern matcher in PyTorch, specifically using ‘register_replacement’ (not ‘gen_register_replacement’). I have two related questions about handling kwargs:

  1. How can we pass non-scalar kwargs to pattern matcher candidates?

Here’s a simplified example of what I’m trying to achieve:

import functools

import torch

# Pattern with a non-scalar kwarg (e.g., torch.dtype or bool)
def _my_pattern(arg_0, arg_1, kwarg_0):
    mm_0 = torch.ops.aten.mm(arg_0, arg_1)
    res = my_custom_op(mm_0, kw=kwarg_0)
    return res

# Replacement
def _my_replacement(arg_0, arg_1, kwarg_0):
    res = my_fused_op(arg_0, arg_1, kw=kwarg_0)
    return res

# Setting up candidates
arg_1 = functools.partial(torch.empty, (256, 32), device="cpu", requires_grad=True, dtype=torch.float)
arg_2 = functools.partial(torch.empty, (32, 512), device="cpu", requires_grad=True, dtype=torch.float)

# How do we handle non-scalar kwargs here?
# The scalar_workaround in register_replacement only works for values like alpha=1.3

  2. Alternative approach: Is it possible to capture kwargs from the pattern in the replacement?

If we don’t specify the kwargs in the pattern, can we capture them from the FX graph? For example:

# Pattern without explicit kwargs
def _my_pattern(arg_0, arg_1):
    mm_0 = torch.ops.aten.mm(arg_0, arg_1)
    res = my_custom_op(mm_0)  # has kwargs in the actual graph
    return res

# Replacement - can we capture kwargs from pattern?
def _my_replacement(arg_0, arg_1):
    # How to get kwargs from my_custom_op in the pattern?
    res = my_fused_op(arg_0, arg_1)
    return res

I’ve been looking at torch/_inductor/fx_passes/fuse_attention.py for implementation guidance but haven’t found a clear solution for either approach. Any suggestions would be appreciated.

For the second case, you can capture the kwargs inside the extra_check function, the callback that performs the final validation of a match. It receives the match object as a parameter, and the replacement occurs only if it returns True. From there you have two options: save the kwarg, let the replacement happen, and post-process the result; or return False from the check and process the match manually, executing the replacement either inside extra_check or later.
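
To make the capture step concrete, here is a rough sketch of such a callback. It assumes the match object exposes a `nodes` list whose entries carry `target` and `kwargs` attributes, as torch.fx nodes do. The helper name `make_kwarg_capturing_check`, the `captured` list, and the SimpleNamespace stand-ins are hypothetical and only there so the snippet runs without building a real graph. In real use you would pass the returned callback as `extra_check=` to `register_replacement`.

```python
from types import SimpleNamespace


def make_kwarg_capturing_check(target, captured):
    """Build an extra_check-style callback (hypothetical helper).

    The callback looks for the matched node that calls `target`, copies its
    kwargs into `captured`, and accepts the match. It rejects the match if
    no such node is present.
    """

    def extra_check(match):
        # Assumption: `match.nodes` lists the graph nodes covered by the match.
        for node in match.nodes:
            if node.target is target:
                captured.append(dict(node.kwargs))
                return True
        return False

    return extra_check


# Stand-in for the custom op from the question (hypothetical).
def my_custom_op(x, kw=None):
    return x


# Fake match object that mimics only the attributes used above.
fake_match = SimpleNamespace(
    nodes=[
        SimpleNamespace(target="mm", kwargs={}),
        SimpleNamespace(target=my_custom_op, kwargs={"kw": "float16"}),
    ]
)

captured = []
check = make_kwarg_capturing_check(my_custom_op, captured)
accepted = check(fake_match)
```

After the replacement has run, the saved kwargs in `captured` can be applied to the new node in a post-processing step. Alternatively, the callback can return False and the match can be rewritten by hand.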