I am writing a pattern matcher for fused operators and have run into an issue: neither search_fn nor search_fn_pattern seems to offer a way to match optional parameters fuzzily. For example, the optional parameters offset and bias in the operator prototype below appear differently in the FX graph depending on whether they are passed or omitted.
Operator prototype:
npu_quant_matmul(x1, x2, scale, *, offset=None, pertoken_scale=None, bias=None) -> Tensor
Case 1 (with optional parameters):
torch.ops.npu.npu_quant_matmul.default(arg2_1, arg3_1, arg4_1, offset=arg5_1, bias=arg6_1);
Case 2 (without optional parameters):
torch.ops.npu.npu_quant_matmul.default(arg2_1, arg3_1, arg4_1);
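Because the prototype places offset, pertoken_scale and bias after the *, they are keyword-only. This is consistent with Case 1 recording them as offset=arg5_1, bias=arg6_1 rather than as extra positional arguments. The plain-Python stand-in below shows the split and which call shapes the signature accepts. The names npu_quant_matmul_stub, split_by_kind and accepts_call are hypothetical, and the stub never calls the real NPU op:

```python
import inspect


# Hypothetical stand-in that mirrors only the prototype's signature:
# npu_quant_matmul(x1, x2, scale, *, offset=None, pertoken_scale=None, bias=None)
def npu_quant_matmul_stub(x1, x2, scale, *, offset=None, pertoken_scale=None, bias=None):
    raise NotImplementedError("signature stand-in only")


def split_by_kind(fn):
    """Split parameter names into positional-or-keyword and keyword-only lists."""
    positional, keyword_only = [], []
    for name, param in inspect.signature(fn).parameters.items():
        if param.kind is inspect.Parameter.KEYWORD_ONLY:
            keyword_only.append(name)
        else:
            positional.append(name)
    return positional, keyword_only


def accepts_call(fn, *args, **kwargs):
    """Return True if fn's signature can bind these arguments (fn is not called)."""
    try:
        inspect.signature(fn).bind(*args, **kwargs)
        return True
    except TypeError:
        return False


positional, keyword_only = split_by_kind(npu_quant_matmul_stub)
print(positional)    # ['x1', 'x2', 'scale']
print(keyword_only)  # ['offset', 'pertoken_scale', 'bias']

print(accepts_call(npu_quant_matmul_stub, "x1", "x2", "scale", offset="o", bias="b"))  # True, Case 1 shape
print(accepts_call(npu_quant_matmul_stub, "x1", "x2", "scale"))  # True, Case 2 shape
print(accepts_call(npu_quant_matmul_stub, "x1", "x2", "scale", "o"))  # False, offset is keyword-only
```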
Attempt 1: Using search_fn
The following search_fn only matches Case 1 and fails to match Case 2 (where optional parameters are omitted). To cover Case 2, I have to write a separate search_fn:
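import torch
import torch_npu  # noqa: F401  (registers the torch.ops.npu namespace)

# Only matches Case 1 (offset and bias passed as keywords)
def search_fn_with_optional(x1, x2, scale, offset, bias):
    trans = x2.transpose(0, 1)
    return torch.ops.npu.npu_quant_matmul.default(x1, trans, scale, offset=offset, bias=bias)

# Only matches Case 2 (optional parameters omitted)
def search_fn_without_optional(x1, x2, scale):
    trans = x2.transpose(0, 1)
    return torch.ops.npu.npu_quant_matmul.default(x1, trans, scale)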
Attempt 2: Using search_fn_pattern (CallFunction)
I also tried the search_fn_pattern argument of register_replacement (from torch._inductor.pattern_matcher import register_replacement as register) with the CallFunction shown below. However, Ignored() shows up as a * placeholder in the pattern, and matching then fails at the len(self.args) != len(node.args) check. It seems impossible to match both cases with a single CallFunction:
import torch
import torch_npu  # noqa: F401  (registers the torch.ops.npu namespace)
from torch._inductor.pattern_matcher import CallFunction, Ignored, KeywordArg

def _build_search_pattern():
    # transpose(x2, dim0, dim1); the two dims are left unconstrained
    npu_transpose_func = CallFunction(
        torch.ops.aten.transpose.int,
        KeywordArg("x2"),
        Ignored(),
        Ignored(),
        _users=1,
    )
    output = CallFunction(
        torch.ops.npu.npu_quant_matmul.default,
        KeywordArg("x1"),
        npu_transpose_func,
        KeywordArg("scale"),
        # Writing Ignored() here instead also fails during matching,
        # at the len(self.args) != len(node.args) check
        offset=KeywordArg("offset"),
        bias=KeywordArg("bias"),
        _users=1,
    )
    return output
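The arity check mentioned above can be illustrated with a deliberately simplified model. Call and toy_match are hypothetical stand-ins, not the real torch._inductor classes. The real matcher does more (for example, it may normalize arguments against the operator schema), so this only models the two checks implied by the behavior described here: positional arity must agree, and every keyword the pattern names must be present on the node. It compares a pattern that binds offset/bias as keywords with one that passes offset as a fourth positional argument.

```python
from dataclasses import dataclass, field


@dataclass
class Call:
    """Hypothetical stand-in for both an FX call node and a CallFunction pattern."""
    args: tuple
    kwargs: dict = field(default_factory=dict)


def toy_match(pattern: Call, node: Call) -> bool:
    # Check 1: positional arity must agree (the len(self.args) != len(node.args) check).
    if len(pattern.args) != len(node.args):
        return False
    # Check 2 (simplified): every keyword the pattern names must be present on the node.
    return all(name in node.kwargs for name in pattern.kwargs)


# FX nodes for the two cases: keyword-only arguments are recorded in kwargs.
case1 = Call(args=("arg2_1", "trans", "arg4_1"), kwargs={"offset": "arg5_1", "bias": "arg6_1"})
case2 = Call(args=("arg2_1", "trans", "arg4_1"))

# Models CallFunction(op, x1, trans, scale, offset=KeywordArg("offset"), bias=KeywordArg("bias"))
keyword_form = Call(args=("x1", "trans", "scale"), kwargs={"offset": "offset", "bias": "bias"})
# Models CallFunction(op, x1, trans, scale, KeywordArg("offset"))
positional_form = Call(args=("x1", "trans", "scale", "offset"))

print(toy_match(keyword_form, case1))     # True
print(toy_match(keyword_form, case2))     # False: node has no offset/bias
print(toy_match(positional_form, case1))  # False: 4 positional slots vs 3 args
print(toy_match(positional_form, case2))  # False: same arity mismatch
```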
Impact:
Some operators have even more optional parameters. With n optional parameters, I would need to write 2^n search_fn or search_fn_pattern variants to cover every combination, which is tedious and does not scale.
Questions:
- Is there any syntax to perform fuzzy matching for this scenario, so that both cases (with/without optional parameters passed) can be matched with a single pattern?
- What is the difference in the matching process between these two ways of writing the pattern?
  1. CallFunction(..., offset=KeywordArg('offset'))
  2. CallFunction(..., KeywordArg('offset'))