Hi,
I’m interested in PyTorch’s operator fusion. I have been reading the Scheduler’s source code and noticed a strange case where it seems to consider the fusability of a pair of nodes node1, node2 in the normal order, self.can_fuse(node1, node2), but also in the reverse order, self.can_fuse(node2, node1). Specifically, I’m referring to this code here:
def check_all_pairs(nodes):
    for node1_index, node1 in enumerate(nodes):
        for node2 in nodes[node1_index + 1 :]:
            key = (node1, node2)
            if key in seen:
                continue
            seen.add(key)
            if self.can_fuse(node1, node2):
                possible_fusions.append(key)
            elif (node2.is_template() or node2.is_foreach()) and self.can_fuse(
                node2, node1
            ):
                # foreach fusions and epilogue fusions are order dependent
                possible_fusions.append((node2, node1))
I inspected this function while compiling some trivial functions via torch.compile(), and it seems that nodes is typically a set containing a graph node n and all the nodes m_1, ..., m_k that it outputs to (it’s probably not that simple, so feel free to correct me there). So we have a set (n, m_1, ..., m_k). So in the self.can_fuse(node1, node2) case we have:
- Check whether we can fuse n with some m_{i+1} (probably vertical fusion)
- Check whether we can fuse m_i with m_{i+1} (probably horizontal fusion)

for some i in 1...k.
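To make the enumeration concrete, here is a minimal sketch of which pairs the nested loop visits, assuming nodes is the list [n, m_1, m_2, m_3]. The string names are placeholders I made up, not real scheduler nodes, and the seen check is left out.

```python
# Placeholder group of nodes; strings stand in for real scheduler nodes.
nodes = ["n", "m_1", "m_2", "m_3"]


def enumerate_pairs(nodes):
    """Return the (node1, node2) pairs in the order the nested loop visits them."""
    pairs = []
    for node1_index, node1 in enumerate(nodes):
        # node2 only ranges over nodes that come after node1 in the list.
        for node2 in nodes[node1_index + 1:]:
            pairs.append((node1, node2))
    return pairs


pairs = enumerate_pairs(nodes)
for node1, node2 in pairs:
    print(f"can_fuse({node1}, {node2})")
```

In this sketch node1 always precedes node2 in the list, so the forward call never sees a pair such as (m_2, n), and non-adjacent pairs such as (m_1, m_3) are visited as well.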
Now my question is, when is self.can_fuse(node2, node1) meaningful? Is this specifically for cases where the order of operations doesn’t affect the correctness? Although in that case self.can_fuse(node1, node2) should also be sufficient.
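The only scenario I can come up with is one where can_fuse itself is asymmetric. Below is a toy sketch of that guess. ToyNode and toy_can_fuse are hypothetical names, and the rule inside toy_can_fuse is purely my assumption based on the "order dependent" comment, not Inductor's actual logic.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyNode:
    """Hypothetical stand-in for a scheduler node (not Inductor's class)."""
    name: str
    template: bool = False

    def is_template(self):
        return self.template

    def is_foreach(self):
        return False


def toy_can_fuse(node1, node2):
    # Assumed rule, for illustration only: a template may take a following
    # node as its epilogue, but a template is never accepted as node2.
    return not node2.is_template()


def check_all_pairs(nodes, can_fuse):
    # Same control flow as the quoted snippet, with state passed explicitly.
    seen = set()
    possible_fusions = []
    for node1_index, node1 in enumerate(nodes):
        for node2 in nodes[node1_index + 1:]:
            key = (node1, node2)
            if key in seen:
                continue
            seen.add(key)
            if can_fuse(node1, node2):
                possible_fusions.append(key)
            elif (node2.is_template() or node2.is_foreach()) and can_fuse(
                node2, node1
            ):
                possible_fusions.append((node2, node1))
    return possible_fusions


pointwise = ToyNode("pointwise")
template = ToyNode("template", template=True)

# The template comes second in the list, so only the reversed check succeeds.
fusions = check_all_pairs([pointwise, template], toy_can_fuse)
print([(a.name, b.name) for a, b in fusions])
```

If something like this is what happens, the reverse call would matter exactly when the list order of nodes disagrees with the order that can_fuse requires, but I'd appreciate confirmation.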
Would appreciate someone’s insight here. Many thanks!
FYI: this is a re-post of this question on the PyTorch forum. I’m hoping that’s ok given it didn’t get any response there.