I understand that can_fuse and score_fusion are responsible for fusing buffers at the Inductor IR level. However, I also see some fusions in the fx_passes, which run before the graph is lowered to Inductor IR. Can someone explain the entire flow from a fusion perspective?
Another question: is it possible to access the Dict from score_fusion, or the output of self.get_possible_fusions_with_highest_priority? Basically, I want a list of all possible fusions along with their scores.
The pattern matching in fx_passes is different from fusion and matters much less for performance. Much of the pattern matching consists of graph rewrites (such as removing redundant ops) rather than fusing ops together.
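To make the rewrite-versus-fusion distinction concrete, here is a toy sketch, not Inductor's actual pattern matcher. It assumes ops are plain strings in a straight-line chain; the op names (mul_one, transpose) and the function simplify_chain are hypothetical. Each rule deletes work from the graph instead of merging ops into a single kernel.

```python
from typing import List


def simplify_chain(ops: List[str]) -> List[str]:
    """Apply two toy rewrite rules to a straight-line chain of ops.

    Each op consumes the output of the op before it (hypothetical model).
    """
    out: List[str] = []
    for op in ops:
        if op == "mul_one":
            # x * 1 == x, so the op does no useful work: drop it.
            continue
        if op == "transpose" and out and out[-1] == "transpose":
            # transpose(transpose(x)) == x for a 2-D x: both ops disappear.
            out.pop()
            continue
        out.append(op)
    return out


print(simplify_chain(["transpose", "mul_one", "transpose", "relu"]))  # ['relu']
```

Dropping mul_one makes the two transposes adjacent, so one rewrite exposes another. The result has fewer ops, but nothing has been fused.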
Inlining happens during lowering, and it is closer to what fusion does. It is controlled by buffer.realize(): any access to a buffer that hasn’t been realized is inlined into the consumer (possibly recomputing it in multiple consumers). An unrealized buffer is never allocated or computed on its own.
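Below is a toy model of this inlining behavior. It assumes a simplified Buffer class with a realized flag; the class, its fields, and render are hypothetical and are not Inductor's IR classes. Rendering a consumer substitutes the full expression of any unrealized producer. This is one way several pointwise ops can end up in a single buffer before the scheduler's fusion pass runs. The actual decision about when to realize a buffer comes from Inductor's own lowering heuristics, which this sketch does not model.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Buffer:
    name: str
    template: str                # e.g. "({0} * 2)": {k} is the k-th input
    inputs: List[str] = field(default_factory=list)
    realized: bool = False       # realized == stored in memory on its own


def render(name: str, buffers: Dict[str, Buffer]) -> str:
    """Expression for one element of `name`, inlining unrealized producers."""
    buf = buffers[name]
    args = []
    for dep in buf.inputs:
        if buffers[dep].realized:
            args.append(f"load({dep})")        # read the stored result
        else:
            args.append(render(dep, buffers))  # recompute it inline
    return buf.template.format(*args)


buffers = {
    "x": Buffer("x", "", realized=True),       # graph input: always in memory
    "a": Buffer("a", "({0} * 2)", ["x"]),      # intermediate, not realized
    "b": Buffer("b", "({0} + 1)", ["a"], realized=True),
    "c": Buffer("c", "({0} - 1)", ["a"], realized=True),
}
print(render("b", buffers))  # ((load(x) * 2) + 1)
print(render("c", buffers))  # ((load(x) * 2) - 1): `a` is recomputed here too
buffers["a"].realized = True
print(render("b", buffers))  # (load(a) + 1)
```

Realizing `a` trades the recomputation in `b` and `c` for one extra allocation and a memory round trip.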
The list of all possible fusions (ordered by score_fusion) is computed in this function:
You could call that function from a debug environment or print it out.
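As a rough mental model of what that function returns, the sketch below mimics the shape of the computation on toy data. It enumerates node pairs, keeps those that pass a can_fuse check, and sorts them by score. The Node class and both rules are hypothetical stand-ins; here the number of shared buffers is used as a crude proxy for memory traffic saved. Inductor's real can_fuse and score_fusion apply many more conditions.

```python
from dataclasses import dataclass
from itertools import combinations
from typing import FrozenSet, Iterable, List, Tuple


@dataclass(frozen=True)
class Node:
    name: str
    reads: FrozenSet[str]
    writes: FrozenSet[str]

    def buffers(self) -> FrozenSet[str]:
        return self.reads | self.writes


def can_fuse(a: Node, b: Node) -> bool:
    # Hypothetical rule: only nodes that touch a common buffer are candidates.
    return bool(a.buffers() & b.buffers())


def score_fusion(a: Node, b: Node) -> int:
    # Hypothetical score: more shared buffers ~ more memory traffic saved.
    return len(a.buffers() & b.buffers())


def possible_fusions(nodes: List[Node]) -> List[Tuple[int, str, str]]:
    """All fusable pairs with their scores, highest score first (stable)."""
    pairs = [
        (score_fusion(a, b), a.name, b.name)
        for a, b in combinations(nodes, 2)
        if can_fuse(a, b)
    ]
    return sorted(pairs, key=lambda p: p[0], reverse=True)


def node(name: str, reads: Iterable[str], writes: Iterable[str]) -> Node:
    return Node(name, frozenset(reads), frozenset(writes))


nodes = [
    node("n0", ["x"], ["buf0"]),
    node("n1", ["buf0"], ["buf1"]),
    node("n2", ["buf0", "buf1"], ["buf2"]),
    node("n3", ["y"], ["buf3"]),  # unrelated: never a candidate
]
for score, a, b in possible_fusions(nodes):
    print(score, a, b)
# 2 n1 n2
# 1 n0 n1
# 1 n0 n2
```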
The fusions in fuse_fx() are disabled by default. It only does something if you set permute_fusion=True (GPU only, three patterns) or freezing=True (CPU only, one pattern). The permute_fusion patterns aren’t really fusions at all; they are rewrite rules that simplify cases of matmul().permute(). The freezing one is closer to constant folding.
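The following is not what Inductor's permute_fusion patterns literally do, since those match specific op sequences in the FX graph. It only illustrates why a matmul().permute() chain can be rewritten at all: the transpose can be absorbed into how the matmul writes its output, so no separate transpose pass is needed. It uses pure-Python 2-D lists, and all function names are hypothetical.

```python
from typing import List

Matrix = List[List[int]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain a @ b: out[i][j] = sum_k a[i][k] * b[k][j]."""
    inner, cols = len(b), len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)] for row in a]


def transpose(m: Matrix) -> Matrix:
    """Materialize m.T as a new matrix: the extra pass the rewrite avoids."""
    return [list(col) for col in zip(*m)]


def matmul_transposed_out(a: Matrix, b: Matrix) -> Matrix:
    """Compute (a @ b).T in one pass by writing out[j][i] instead of out[i][j]."""
    rows, inner, cols = len(a), len(b), len(b[0])
    return [
        [sum(a[i][k] * b[k][j] for k in range(inner)) for i in range(rows)]
        for j in range(cols)
    ]


A = [[1, 2], [3, 4], [5, 6]]
B = [[7, 8, 9], [10, 11, 12]]
print(matmul_transposed_out(A, B))  # [[27, 61, 95], [30, 68, 106], [33, 75, 117]]
```

The one-pass version agrees with transpose(matmul(A, B)) and with matmul(transpose(B), transpose(A)). It simply never stores the product in its untransposed layout.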
Focusing on that is likely a distraction; it is inconsequential for performance unless your model contains some fairly uncommon patterns.
Do pointwise ops get fused even before Inductor lowering happens? I ask because, in the buffers generated by Triton codegen, my manual RMSNorm implementation was already a single buffer, even in the Inductor IR generated pre_fusion.
Sure, the decompositions part makes sense, but what about the ComputedBuffers that aren’t just a single pointwise or reduction op but a combination of multiple nodes?