Reverse Fusion of Node Pairs in Scheduler

Hi,

I’m interested in PyTorch’s operator fusion. While reading the Scheduler’s source code, I noticed that it checks the fusability of a pair of nodes node1, node2 not only in the normal order, self.can_fuse(node1, node2), but in some cases also in the reverse order, self.can_fuse(node2, node1). Specifically, I’m referring to this code:

def check_all_pairs(nodes):
    for node1_index, node1 in enumerate(nodes):
        for node2 in nodes[node1_index + 1 :]:
            key = (node1, node2)
            if key in seen:
                continue
            seen.add(key)

            if self.can_fuse(node1, node2):
                possible_fusions.append(key)
            elif (node2.is_template() or node2.is_foreach()) and self.can_fuse(
                node2, node1
            ):
                # foreach fusions and epilogue fusions are order dependent
                possible_fusions.append((node2, node1))

I traced this function on a few trivial functions compiled with torch.compile(), and it seems that nodes typically contains a graph node n together with all the nodes m_1, ..., m_k that consume its output (it’s probably not that simple, so feel free to correct me there). So nodes is (n, m_1, ..., m_k), and in the self.can_fuse(node1, node2) case we have:

  1. Check whether we can fuse n with some m_i (probably vertical fusion)
  2. Check whether we can fuse m_i with some later m_j (probably horizontal fusion)
    for 1 <= i < j <= k.
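
The pairs visited by the nested loop can be sketched in isolation. This is only a minimal illustration that uses string placeholders instead of scheduler nodes; the names n, m_1, m_2, m_3 and the helper ordered_pairs are made up for the example and are not part of PyTorch.

```python
from itertools import combinations


def ordered_pairs(nodes):
    """Return the (node1, node2) pairs in the order the nested loop visits them.

    Hypothetical helper for illustration. It mirrors
    `for i, node1 in enumerate(nodes): for node2 in nodes[i + 1:]`.
    """
    return list(combinations(nodes, 2))


# Placeholder names: one producer n and three consumers m_1..m_3.
nodes = ["n", "m_1", "m_2", "m_3"]
pairs = ordered_pairs(nodes)
for node1, node2 in pairs:
    print(node1, node2)
```

Every pair shows up exactly once, with the earlier list element first, so the reverse order (node2, node1) is never tried by the loop itself; only the explicit elif branch tries it.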

Now my question is: when is self.can_fuse(node2, node1) meaningful? Is it specifically for cases where the order of operations doesn’t affect correctness? In that case, though, self.can_fuse(node1, node2) alone should already be sufficient.
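
To make the question concrete, here is a toy model of what the "order dependent" comment in the snippet might mean. Everything in it is an assumption for illustration: ToyNode, toy_can_fuse and candidate_fusions are invented names, the rule inside toy_can_fuse is made up, and data dependencies are ignored entirely. It is not a description of how Inductor's real can_fuse behaves.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyNode:
    # Hypothetical stand-in for a scheduler node; not a PyTorch class.
    name: str
    template: bool = False

    def is_template(self):
        return self.template


def toy_can_fuse(first, second):
    # Assumed rule, for illustration only: a template may be the first
    # node of a fusion (taking the other node as an epilogue), but it
    # may never be the second node.
    return not second.is_template()


def candidate_fusions(nodes):
    # Same control flow as the quoted loop, minus the `seen` bookkeeping
    # and the foreach case.
    found = []
    for i, node1 in enumerate(nodes):
        for node2 in nodes[i + 1:]:
            if toy_can_fuse(node1, node2):
                found.append((node1.name, node2.name))
            elif node2.is_template() and toy_can_fuse(node2, node1):
                found.append((node2.name, node1.name))
    return found


relu = ToyNode("relu")
matmul = ToyNode("matmul", template=True)
print(candidate_fusions([relu, matmul]))
print(candidate_fusions([matmul, relu]))
```

Under this assumed asymmetric rule, the pair is found regardless of where the template happens to sit in nodes, which is the only purpose I can see for the reverse check. Whether that matches the real motivation is exactly what I am asking.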

Would appreciate someone’s insight here. Many thanks!

FYI: this is a re-post of this question from the PyTorch forum. I’m hoping that’s OK, given that it didn’t receive any response there.