Inductor and FX heavily use pattern matching to match and replace subgraph patterns for optimization. However, subgraph matching (finding a subgraph that is isomorphic to a given graph) is a well-known NP-hard problem. What mental model do Inductor and FX use to deal with this problem? Do they focus on a restricted set of patterns?
Indeed, you’re correct. The pattern matcher here serves as a tool to rewrite the torch.fx graph module, transforming it according to specific patterns outlined here. For instance, take the well-known flash attention optimization, fuse_attention. If the original fx graph module is detected (i.e., the unfused attention subgraph), TorchInductor will replace it with a single call to a new operator, aten.scaled_dot_product_attention.
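A minimal sketch of the detect-and-replace idea, assuming a toy representation in which an expression is a nested tuple `(op, *args)` and plain strings are graph inputs. `match`, `ATTENTION`, and `replace_attention` are hypothetical names; Inductor itself works on torch.fx nodes with its own pattern classes, and the fused node name below only mirrors aten.scaled_dot_product_attention (how the real call receives the scale is not modeled). The pattern encodes the unfused chain matmul → div → softmax → matmul, and each match is swapped for one fused node:

```python
# Toy expression trees: ("op", arg0, arg1, ...); a bare string is a graph input.
# Pattern strings starting with "?" are wildcards.


def match(pattern, expr, env):
    """Return True if pattern matches expr, recording wildcard bindings in env."""
    if isinstance(pattern, str) and pattern.startswith("?"):
        if pattern in env:  # a repeated wildcard must bind to an equal subexpression
            return env[pattern] == expr
        env[pattern] = expr
        return True
    if isinstance(pattern, tuple) and isinstance(expr, tuple):
        return (len(pattern) == len(expr) and pattern[0] == expr[0]
                and all(match(p, e, env) for p, e in zip(pattern[1:], expr[1:])))
    return pattern == expr


# Unfused attention: matmul(softmax(div(matmul(q, transpose(k)), scale)), v)
ATTENTION = ("matmul", ("softmax", ("div", ("matmul", "?q", ("transpose", "?k")), "?s")), "?v")


def replace_attention(expr):
    """Rewrite children first, then swap a matching subtree for one fused call."""
    if not isinstance(expr, tuple):
        return expr
    expr = (expr[0], *[replace_attention(arg) for arg in expr[1:]])
    env = {}
    if match(ATTENTION, expr, env):
        return ("scaled_dot_product_attention", env["?q"], env["?k"], env["?v"], env["?s"])
    return expr


expr = ("add", ("matmul", ("softmax", ("div", ("matmul", "Q", ("transpose", "K")), "S")), "V"), "bias")
print(replace_attention(expr))
# prints ('add', ('scaled_dot_product_attention', 'Q', 'K', 'V', 'S'), 'bias')
```

Because children are rewritten before their parent, an attention block nested inside a larger expression is still found, while a variant that scales with `mul` instead of `div` is left untouched.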
There are numerous patterns yet to be explored, particularly for GEMM-like operators. When multiple operators are pointwise or reduction operations, the Inductor scheduler can fuse them into a single kernel. It then invokes a codegen backend, such as C++ with OpenMP for CPU or Triton for GPU, to generate code for the fused kernel. Exploring how to identify and create efficient patterns, as well as designing a general scheduler and IR suitable for a wide range of ML workloads, remains an intriguing and complex challenge.
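To make the fusion idea concrete, here is a heavily simplified toy, not Inductor's scheduler. It assumes the ops form a single chain in execution order, greedily groups consecutive pointwise ops (optionally closed by one reduction) into one kernel, and leaves GEMM-like library calls unfused. `Op`, its `kind` values, and `group_into_kernels` are hypothetical names; a real scheduler also has to check data dependencies and loop shapes before fusing anything.

```python
from dataclasses import dataclass


@dataclass
class Op:
    name: str
    kind: str  # "pointwise", "reduction", or "extern" (e.g. a GEMM library call)


def group_into_kernels(ops):
    """Greedily group a chain of ops (in execution order) into fused kernels.

    Toy rules: consecutive pointwise ops share a kernel, a reduction closes the
    kernel it joins, and extern ops always run on their own.
    """
    kernels, current = [], []
    for op in ops:
        if op.kind == "extern":
            if current:  # flush the pending fused kernel first
                kernels.append(current)
                current = []
            kernels.append([op])
        else:
            current.append(op)
            if op.kind == "reduction":
                kernels.append(current)
                current = []
    if current:
        kernels.append(current)
    return kernels


ops = [Op("mm", "extern"), Op("add", "pointwise"), Op("relu", "pointwise"),
       Op("sum", "reduction"), Op("mul", "pointwise")]
print([[op.name for op in kernel] for kernel in group_into_kernels(ops)])
# prints [['mm'], ['add', 'relu', 'sum'], ['mul']]
```

Five ops become three kernels here; the pointwise `add` and `relu` are computed inside the reduction's kernel instead of each writing an intermediate buffer.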
I hope this information is helpful to you. Looking forward to further discussion.
@fishmingyu Thanks for the information. The general workflow is clear. My question is how “If the original fx graph module is detected” is implemented. I mean, subgraph detection is in general very hard.
This subgraph is quite complex. How do we match such a graph? Do we start by looking for a transpose operator, followed by a matmul / div / softmax / matmul sequence?
And at what level is the graph being matched? Do we match aten ops, torch.xxx API calls, or torch.nn modules? I suppose subgraph matching in a graph with diverse node types would be very difficult.
If I have a computation graph g, and a pattern search_fn represented by another computation graph s, does Inductor guarantee to find all subgraphs of g that match s (up to isomorphism)?
If the answer is yes, are there any constraints on the structure of s, such as a limit on the number of nodes or the number of edges?
If the answer is no, are there any suggestions for writing search_fn so that inductor works better?
Inductor and FX heavily use pattern matching to match and replace subgraph patterns for optimization.
I’d disagree with the “heavily” here. You can entirely disable pattern matching and the performance impact is small overall and zero on many programs. The main usage is for creating calls to aten.scaled_dot_product_attention.
The key part of the matching is a recursive algorithm that walks the two graphs and compares them, starting from an “anchor” node, which is the output of the pattern:
While graph matching is NP-hard in the general case, it is relatively easy for the kinds of graphs that make sense as machine learning programs. What we have here is closer to the (ordered) tree inclusion problem, which can be solved in polynomial time.
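A minimal sketch of anchor-based recursive matching, assuming a toy graph of `Node(op, args)` objects rather than torch.fx nodes; `Node`, `Wild`, `match`, and `match_at_anchor` are hypothetical names, not Inductor's API. The walk starts at the pattern's output (the anchor) and follows argument edges in order. One detail a plain tree walk misses is that a pattern node can be used more than once, e.g. `x` in `x * sigmoid(x)`, so the sketch records bindings and requires a repeated pattern node to map to the same graph node:

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # eq=False: nodes compare and hash by identity, like graph nodes
class Node:
    op: str
    args: list = field(default_factory=list)  # inputs, in argument order


@dataclass(eq=False)
class Wild:
    """Pattern input that matches any graph node, bound under a name."""
    name: str


def match(pattern, node, bindings):
    """Recursively match pattern against node, filling bindings on success."""
    key = pattern.name if isinstance(pattern, Wild) else id(pattern)
    if key in bindings:
        # Already matched elsewhere: a shared pattern node must map to the same graph node.
        return bindings[key] is node
    if isinstance(pattern, Node):
        if pattern.op != node.op or len(pattern.args) != len(node.args):
            return False
        if not all(match(p, g, bindings) for p, g in zip(pattern.args, node.args)):
            return False
    bindings[key] = node
    return True


def match_at_anchor(pattern, anchor):
    """Try the pattern with its output at anchor; return the bindings or None."""
    bindings = {}
    return bindings if match(pattern, anchor, bindings) else None


x = Wild("x")
silu = Node("mul", [x, Node("sigmoid", [x])])  # x * sigmoid(x): x is shared
a, b = Node("placeholder"), Node("placeholder")
print(match_at_anchor(silu, Node("mul", [a, Node("sigmoid", [a])])) is not None)  # True
print(match_at_anchor(silu, Node("mul", [a, Node("sigmoid", [b])])) is not None)  # False
```

Because a pattern node that is already bound is only checked by identity, each attempt expands every pattern node at most once, so trying every graph node as an anchor costs O(|graph| · |pattern|) in this sketch.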
By tree inclusion problem, do you mean search_fn always represents a rooted tree?
Indeed, I can imagine that finding a rooted tree in a DAG (computation graph) is much easier than general graph matching. And it gets even better because child nodes are ordered in a computation graph (i.e., their order is the order of the arguments to a function).
Here is the pseudocode I came up with:
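    from dataclasses import dataclass, field

    @dataclass(eq=False)
    class Node:
        op: str                                       # op type, e.g. "matmul"
        children: list = field(default_factory=list)  # inputs, in argument order

    def match_node(tree_node, graph_node):
        # Op type and arity must agree before the (ordered) children are compared.
        if tree_node.op != graph_node.op or len(tree_node.children) != len(graph_node.children):
            return False
        # Compare each pattern child with the graph child in the same argument slot.
        for tree_child, graph_child in zip(tree_node.children, graph_node.children):
            if not match_node(tree_child, graph_child):
                return False
        return True

    def find_tree(tree, graph):
        # graph is any iterable of nodes; return the first node where the tree matches.
        for node in graph:
            if match_node(tree, node):
                return node
        return None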
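Building on the pseudocode above, a sketch of one way to return all matches rather than the first: bucket graph nodes by op, then run the recursive match only at candidates whose op equals the pattern's output op. This assumes a toy `Node(op, args)` graph and treats the pattern as an ordered tree with a hypothetical `"*"` wildcard op; it does not enforce consistent bindings for shared pattern nodes or resolve overlapping matches, which a real rewriter would have to handle.

```python
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass(eq=False)
class Node:
    op: str
    args: list = field(default_factory=list)  # inputs, in argument order


ANY = "*"  # hypothetical wildcard op: a pattern node with this op matches any graph node


def match_node(pat, node):
    """Ordered tree match: same op, same arity, children matched in argument order."""
    if pat.op == ANY:
        return True
    if pat.op != node.op or len(pat.args) != len(node.args):
        return False
    return all(match_node(p, g) for p, g in zip(pat.args, node.args))


def find_all(pattern, graph_nodes):
    """Return every node in graph_nodes where pattern matches, in graph order.

    The pattern's output op must be concrete (not ANY): only nodes with that op
    are tried as anchors.
    """
    by_op = defaultdict(list)
    for n in graph_nodes:
        by_op[n.op].append(n)
    return [n for n in by_op[pattern.op] if match_node(pattern, n)]


x, y = Node("input"), Node("input")
add1 = Node("add", [x, x]); a = Node("relu", [add1])
mul1 = Node("mul", [x, x]); b = Node("relu", [mul1])
add2 = Node("add", [x, y]); c = Node("relu", [add2])
pattern = Node("relu", [Node("add", [Node(ANY), Node(ANY)])])
print(len(find_all(pattern, [x, y, add1, a, mul1, b, add2, c])))  # 2 (a and c)
```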
I tried adding a custom SDPA pattern in torch/_inductor/fx_passes/fuse_attention.py based on BERT_pytorch (which is different from hf_Bert).
The pattern fails to match even against itself (both when a unit test is added and when a simple Python script based on the pattern is run). Could you give me some pointers on how to approach debugging this, i.e., how to identify what mismatched? I’ve added a test script (it requires the serialized patterns to be rebuilt, since a new pattern is added) corresponding to a test branch that adds the new pattern.
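One generic way to localize such a failure, sketched under toy assumptions (a `Node(op, args)` graph, with pattern inputs represented as `"placeholder"` nodes that match anything; `explain_mismatch` and `inp` are hypothetical helpers, not Inductor APIs): walk the pattern and the graph together from the anchor and report the first position where the op or the argument count differs. For example, a scale applied with `mul` in the model but `div` in the pattern would show up like this:

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Node:
    op: str
    args: list = field(default_factory=list)


def explain_mismatch(pat, node, path="output"):
    """Return None if pat matches node, else a description of the first mismatch.

    path spells out the route from the anchor, e.g. "output.args[0].args[1]",
    so it points directly at the pattern node to inspect.
    """
    if pat.op == "placeholder":
        return None  # pattern inputs match any graph node in this toy
    if pat.op != node.op:
        return f"{path}: expected op {pat.op!r}, found {node.op!r}"
    if len(pat.args) != len(node.args):
        return f"{path}: expected {len(pat.args)} args, found {len(node.args)}"
    for i, (p, g) in enumerate(zip(pat.args, node.args)):
        reason = explain_mismatch(p, g, f"{path}.args[{i}]")
        if reason is not None:
            return reason
    return None


def inp():
    return Node("placeholder")


pattern = Node("softmax", [Node("div", [Node("matmul", [inp(), inp()]), inp()])])
graph = Node("softmax", [Node("mul", [Node("matmul", [inp(), inp()]), inp()])])
print(explain_mismatch(pattern, graph))
# prints output.args[0]: expected op 'div', found 'mul'
```

Arity mismatches (for instance, an extra `dim` or dtype argument on one side) are reported the same way, which often explains why a traced pattern fails against a graph that looks identical at the Python level.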