How is pattern matching in inductor/fx implemented?

Inductor and FX heavily use pattern matching to match and replace subgraph patterns for optimization. However, subgraph matching (finding a subgraph that is isomorphic to a given graph) is a well-known NP-hard problem. What is the mental model for inductor/FX to deal with the problem? Do they focus on some restricted set of patterns?

To be honest, reading code directly in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/pattern_matcher.py is quite a pain.


Hello Kaichao,

Indeed, you’re correct. The pattern matcher serves as a tool to revise the torch.fx graph module, transforming it according to a set of predefined patterns. For instance, take the well-known flash attention optimization, also known as fuse attention: if the original attention subgraph is detected in the fx graph module, inductor will replace it with a single call to the aten.scaled_dot_product_attention operator.
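At a high level, each such rewrite is registered as a pair of traceable Python functions: a search pattern and its replacement. Below is a rough sketch of the registration, loosely modeled on torch/_inductor/fx_passes/fuse_attention.py; the function names here are illustrative, and exact signatures vary across PyTorch versions:

import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

def sdpa_pattern(query, key, value):
    # The subgraph to search for (unscaled attention, for simplicity).
    return (
        torch.matmul(query, key.transpose(-2, -1))
        .softmax(dim=-1)
        .matmul(value)
    )

def sdpa_replacement(query, key, value):
    # What matched subgraphs are rewritten to; scale=1.0 keeps it
    # numerically equivalent to the unscaled pattern above.
    return torch.ops.aten.scaled_dot_product_attention(
        query, key, value, scale=1.0
    )

patterns = PatternMatcherPass()
example_inputs = [torch.empty(2, 4, 8, 16) for _ in range(3)]
register_replacement(
    sdpa_pattern, sdpa_replacement, example_inputs, fwd_only, patterns
)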

There are numerous patterns yet to be explored, particularly for GEMM-like operators. Separately from pattern matching, if multiple operators are of pointwise or reduction type, the inductor scheduler will fuse them into a single kernel, and then invoke the codegen backend, such as C++/OpenMP for CPU or Triton for GPU, to generate code for it. Exploring how to identify and create efficient patterns, as well as designing a general scheduler and IR suitable for a wide range of ML workloads, remains an intriguing and complex challenge.
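As a tiny self-contained illustration of that scheduler-level fusion (independent of the pattern matcher), a chain of pointwise ops followed by a reduction typically compiles down to one generated kernel; running with TORCH_LOGS=output_code prints the code Inductor emits:

import torch

def f(x):
    # pointwise -> pointwise -> reduction: the scheduler can fuse this
    # chain into a single Triton (GPU) or C++/OpenMP (CPU) kernel.
    return (x.relu() * 2.0).sum(dim=-1)

compiled = torch.compile(f)
out = compiled(torch.randn(8, 16))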

I hope this information is helpful to you. Looking forward to further discussion.


@fishmingyu Thanks for the information. The general workflow is clear. My question is how “if the original attention subgraph is detected” is implemented. I mean, subgraph detection is in general very hard.

Take the fuse attention as an example:

def _sfdp_pattern_1(query, key, value, inv_scale):
    return (
        torch.matmul(query, key.transpose(-2, -1))
        .div(inv_scale)
        .softmax(dim=-1)
        .matmul(value)
    )

This subgraph is quite complex. How do we match such a graph? Do we start by looking for a transpose operator, followed by the matmul / div / softmax / matmul sequence?

And what is the level of graph we are matching? Do we match aten ops, or torch.xxx API or torch.nn modules? I suppose subgraph matching in a graph with diverse node types would be very difficult.

Put it this way:

If I have a computation graph g, and a pattern search_fn represented by another computation graph s, does inductor promise to find all occurrences (up to isomorphism) of s in g?

If the answer is yes, are there any constraints on the structure of s, either by limiting the number of nodes or limiting the number of edges?

If the answer is no, are there any suggestions for writing search_fn so that inductor works better?

Inductor and FX heavily use pattern matching to match and replace subgraph patterns for optimization.

I’d disagree with the “heavily” here. You can entirely disable pattern matching and the performance impact is small overall and zero on many programs. The main usage is for creating calls to aten.scaled_dot_product_attention.
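For anyone who wants to measure this on their own model, the pattern-matching passes can be toggled with a config flag (it lives in torch/_inductor/config.py; double-check the name in your torch version):

import torch
import torch._inductor.config as inductor_config

# Disable Inductor's pattern-matching passes to compare performance.
# `my_model` here is a stand-in for any function or torch.nn.Module.
inductor_config.pattern_matcher = False
compiled = torch.compile(my_model)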

The key part of the algorithm to do matching is a recursive algorithm that walks the two graphs and compares them, starting from an “anchor” node, which is the output of the pattern.
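In outline the recursion looks something like this (a simplified sketch of the idea, not the actual pattern_matcher.py code; the helpers is_wildcard and is_pattern are illustrative):

def matches(pattern, node):
    # Wildcard leaves (Arg / KeywordArg / Ignored) accept any node.
    if is_wildcard(pattern):
        return True
    # The op at this level must agree (e.g. the call_function target).
    if node.op != "call_function" or node.target not in pattern.fns:
        return False
    if len(pattern.args) != len(node.args):
        return False
    # Recurse into the arguments, pairing pattern args with node args
    # positionally; argument order is significant.
    return all(
        matches(p, a) if is_pattern(p) else p == a
        for p, a in zip(pattern.args, node.args)
    )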

While in the general case graph matching is NP-hard, it is relatively easy for the types of graphs that make sense as machine learning programs. What we have here is closer to the tree inclusion problem which can be done in polynomial time.


Thanks, Jason, the information is super useful!

By tree inclusion problem, do you mean search_fn always represents a rooted tree?

Indeed I can imagine finding a rooted tree in a DAG (the computation graph) should be much easier than general graph matching. And it gets even better because child nodes are ordered in the computation graph (i.e. the order is the order of arguments to a function).

The pseudocode I can come up with:

def match_node(tree_node, graph_node):
    # local_match is a stand-in for "op type and args agree".
    if not local_match(tree_node, graph_node):
        return False
    # Children are compared in argument order.
    for tree_child, graph_child in zip(tree_node.children, graph_node.children):
        if not match_node(tree_child, graph_child):
            return False
    return True

def find_tree(tree, graph):
    # Try every graph node as a candidate anchor for the tree root.
    for node in graph:
        if match_node(tree, node):
            return node
    return None

Yeah that is roughly it


Hi,

I tried adding a custom SDPA pattern in torch/_inductor/fx_passes/fuse_attention.py based on BERT_pytorch (which is different from hf_Bert).
The pattern fails to match even against itself (whether via an added unit test or a simple Python script based on the pattern).
Could you please give me some pointers on how to approach debugging this issue (i.e. some way of identifying what mismatched)? I’ve added a test script (which requires the serialized patterns to be rebuilt, since a new pattern is added) corresponding to a test branch that adds the new pattern.

Thank you! 🙂

EDIT: this issue has a workaround.

I am attempting to write a pattern to match fused operators, and I have run into an issue: neither search_fn nor search_fn_pattern provides a way to perform fuzzy matching of optional parameters. For example, optional parameters like offset and bias in the following operator prototype appear differently in the FX graph depending on whether they are passed or omitted.

The operator prototype is as follows:

npu_quant_matmul(x1, x2, scale, *, offset=None, pertoken_scale=None, bias=None) -> Tensor

Case 1 (with optional parameters):
torch.ops.npu.npu_quant_matmul.default(arg2_1, arg3_1, arg4_1, offset=arg5_1, bias=arg6_1);
Case 2 (without optional parameters):
torch.ops.npu.npu_quant_matmul.default(arg2_1, arg3_1, arg4_1);
Attempt 1: Using search_fn
The following search_fn only matches Case 1 and fails to match Case 2 (where optional parameters are omitted). To cover Case 2, I have to write a separate search_fn:

# Only matches Case 1
def search_fn(x1, x2, scale, offset, bias):
    trans = x2.transpose(0, 1)
    return torch.ops.npu.npu_quant_matmul.default(x1, trans, scale, offset=offset, bias=bias)

# Only matches Case 2
def search_fn(x1, x2, scale):
    trans = x2.transpose(0, 1)
    return torch.ops.npu.npu_quant_matmul.default(x1, trans, scale)

Attempt 2: Using search_fn_pattern (CallFunction)
I also tried providing a search_fn_pattern (using from torch._inductor.pattern_matcher import register_replacement as register) and wrote a CallFunction as shown below. However, using Ignored() produces a placeholder *, which causes a failure during matching when len(self.args) != len(node.args) is checked. It seems impossible to match both cases with a single CallFunction:

def _build_search_pattern():
    npu_transpose_func = CallFunction(
        torch.ops.aten.transpose.int,
        KeywordArg("x2"),
        Ignored(),
        Ignored(),
        _users=1,
    )

    output = CallFunction(
        torch.ops.npu.npu_quant_matmul.default,
        KeywordArg("x1"),
        npu_transpose_func,
        KeywordArg("scale"),
        # Writing Ignored() here instead also fails during matching,
        # when len(self.args) != len(node.args) is checked.
        offset=KeywordArg("offset"),
        bias=KeywordArg("bias"),
        _users=1,
    )

    return output

Impact:
Some operators may have more optional parameters (e.g., n of them). In that case I would need to write 2ⁿ search_fn or search_fn_pattern variants to cover all scenarios, which is extremely inefficient and unscalable (a possible programmatic workaround is sketched below).
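One way to avoid hand-writing all 2ⁿ variants might be to generate the CallFunction patterns programmatically, one per subset of optional kwargs. This is only a sketch (it reuses the npu_transpose_func pattern from above, and I have not verified it against a full registration flow):

import itertools

import torch
from torch._inductor.pattern_matcher import CallFunction, KeywordArg

OPTIONAL_KWARGS = ["offset", "bias"]  # extend with pertoken_scale, etc.

def build_search_patterns(npu_transpose_func):
    # One CallFunction per subset of optional kwargs; each variant only
    # matches call sites that pass exactly that subset.
    patterns = []
    for r in range(len(OPTIONAL_KWARGS) + 1):
        for subset in itertools.combinations(OPTIONAL_KWARGS, r):
            optional = {name: KeywordArg(name) for name in subset}
            patterns.append(
                CallFunction(
                    torch.ops.npu.npu_quant_matmul.default,
                    KeywordArg("x1"),
                    npu_transpose_func,
                    KeywordArg("scale"),
                    _users=1,
                    **optional,
                )
            )
    return patterns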
Questions:

  1. Is there any syntax to perform fuzzy matching in this scenario, so that both cases (with and without the optional parameters passed) can be matched by a single pattern?
  2. What is the difference between the following two ways of writing the pattern, in terms of the matching process:
    a. CallFunction(offset=KeywordArg('offset'))
    b. CallFunction(KeywordArg('offset'))?

Within inductor, the calling format for all the operators should be normalized by AOTAutograd at the start, so you shouldn’t need to match all combinations of args/kwargs. It sounds like you are doing some custom stuff, so I’d recommend running an arg normalization pass on the graph before pattern matching.
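Here is a sketch of such a normalization pass, using torch.fx.Node.normalized_arguments to rewrite every call so its arguments are passed uniformly as kwargs. It relies on the target op having a schema, so treat it as a starting point rather than a drop-in pass:

import torch

def normalize_call_args(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Rewrite each call_function node so all arguments become kwargs,
    # using the operator schema to resolve the positional ones.
    for node in gm.graph.nodes:
        if node.op != "call_function":
            continue
        normalized = node.normalized_arguments(
            gm, normalize_to_only_use_kwargs=True
        )
        if normalized is not None:  # None when no schema info is available
            node.args = normalized.args
            node.kwargs = normalized.kwargs
    gm.recompile()
    return gm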

CallFunction(offset=KeywordArg('offset')) matches only the kwarg variant, while CallFunction(KeywordArg('offset')) matches only the positional variant. The KeywordArg('offset') only controls how the arg is passed to the handler function; KeywordArg matches exactly the same things as Arg.

I think we also have a wildcard match that matches anything, and you can subclass to create your own custom pattern.