What is the relation between PrimTorch and Decompositions?

I have read some discussions and source code about PrimTorch and decompositions, but I'm still confused about the following:

  1. Decompositions are defined across multiple directories - _refs, _decomp, _inductor. Why? Is there a historical reason?

  2. Some decompositions target prim ops but others target aten ops, such as logaddexp:

def logaddexp(a: TensorLikeType, b: TensorLikeType) -> TensorLikeType:
    # Nb. this implementation does not distribute the gradients evenly when a == b
    mask = torch.real(a) >= torch.real(b)
    max_ = torch.where(mask, a, b)
    min_ = torch.where(mask, b, a)
    inf_mask = torch.logical_and(
        torch.logical_not(torch.isfinite(torch.real(a))), torch.real(a) == torch.real(b)
    )
...

Should all decompositions target prim ops? If not, what is the relation between PrimTorch and decompositions?

I would really appreciate it if someone could enlighten me.


I’ve also been confused about the second question for a long time. In my current understanding, one should be able to choose a set of decompositions, and there should be decomposition sets made up entirely of core-aten ops or entirely of prim ops. But judging from this repo, the API for choosing the decomposition set is not currently available, and core_aten_decompositions() is not a pure core-aten decomposition set.

About the relation between PrimTorch and decompositions: PrimTorch = Core Aten + Prims (the APIs of Core Aten and Prims can be found here), so it is fair to say "aten ops will be decomposed to core-aten and/or prims" or "aten ops will be decomposed to PrimTorch".


So, there are a couple of general principles behind our decompositions. One is that each of them is generally “minimal”. So, for example, layer_norm will be decomposed into aten.var_mean, which will then be decomposed into aten.var and aten.mean, and so on. Then, at the end, they will often get decomposed into prims. The idea here is that the decompositions are kind of like a bus journey, and we want to provide as many “bus stops” as possible along the way.
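A minimal sketch of one such "bus stop" (assuming a recent PyTorch build; `torch._decomp` and `make_fx` are private/experimental APIs and may shift between releases). Passing only the `var_mean` decomposition to the tracer rides the bus exactly one stop further, so `aten.var_mean` turns into `aten.var` and `aten.mean`, which stay as-is:

```python
import torch
from torch._decomp import get_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.var_mean(x)

# With no table, the trace stops at the aten.var_mean "bus stop".
g0 = make_fx(f)(torch.randn(8))

# With a table containing only var_mean, it decomposes one step further,
# into aten.var and aten.mean (which are left alone).
table = get_decompositions([torch.ops.aten.var_mean])
g1 = make_fx(f, decomposition_table=table)(torch.randn(8))
```

The decomposed graph still computes the same values as the original function; a backend can stop at whichever level of the graph it understands.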

There are a couple of places where we decompose into some useful prim ops early on. For example, prims.convert_element_type or prims.device_put.
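These prims can also be called eagerly, which makes them easy to poke at (a small sketch; `torch._prims` is a private module, so the import path is version-dependent):

```python
import torch
import torch._prims as prims

x = torch.arange(4)  # int64 tensor
# prims.convert_element_type is the low-level primitive behind dtype casts.
y = prims.convert_element_type(x, torch.float32)
```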

  1. Decompositions are defined across multiple directories - _refs, _decomp, _inductor. Why? Is there a historical reason?

Concretely, decompositions defined in _refs target both “torch” operators and “aten” operators. Decompositions defined in _inductor are Inductor-specific (and might be, say, an optimization implemented as a decomposition).

  1. Should all decompositions target prim ops?

As mentioned above, all operators should eventually be decomposable into prim ops. But they may decompose into many other operators along the way.


Hi Chillee, thanks for your reply.

I thought a decomposition could only be applied once to an aten op, but it seems an op can be decomposed multiple times until it eventually ends up in a small set of operators (maybe core aten and prims?).

Concretely, decompositions defined in _refs target both “torch” operators and “aten” operators. Decompositions defined in _inductor are Inductor-specific (and might be, say, an optimization implemented as a decomposition).

How about the decompositions defined in _decomp? Are they the same as those in _refs?

it can be decomposed multiple times until it eventually ends up in a small set of operators

Yes, that’s the idea. And the recursive decomposition also allows backends to easily “fall back” to a higher level representation if they can’t handle the lower level one.
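As a sketch of riding the bus all the way to one common terminus (hedged: the exact op coverage of `core_aten_decompositions()` varies across PyTorch versions, and these are private APIs), a backend that only understands the core aten set can decompose a composite op like silu down to it:

```python
import torch
from torch._decomp import core_aten_decompositions
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.nn.functional.silu(x)

# silu is not a core aten op, so with the core table it decomposes
# (into sigmoid and mul in current builds).
gm = make_fx(f, decomposition_table=core_aten_decompositions())(torch.randn(4))
```

If the backend later finds it can handle `aten.silu` natively, it simply traces without that entry in the table - that's the "fall back to a higher level representation" escape hatch.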

How about the decompositions defined in _decomp? Are they the same as those in _refs?

In principle, yes, they’re pretty similar. Decompositions defined in _refs have the additional property that you can also use them to decompose PyTorch’s Python API (i.e. torch.foo operators), while the decompositions in _decomp are used for decomposing PyTorch’s C++ API (i.e. aten.foo operators).

Luckily, there’s a lot of overlap, so a lot of “_refs” have also been registered as decompositions.
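For instance (a small sketch; `torch._refs` is a private module, so names can shift between releases), the very logaddexp ref quoted at the top of the thread is an ordinary Python function that runs on eager tensors, and the same function is registered as the decomposition for aten.logaddexp:

```python
import torch
import torch._refs as refs

x = torch.randn(3)
y = torch.randn(3)
# The ref mirrors torch.logaddexp but is written in terms of
# torch-level ops (where, isfinite, logical_and, ...).
out = refs.logaddexp(x, y)
```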


Decompositions defined in _refs have the additional property that you can also use them to decompose PyTorch’s python API (i.e. torch.foo operators)

Thanks!!! That’s very helpful.

In my view, decomposition transforms aten ops into smaller and simpler ops, and thus reduces the complexity of AI compilers.
It seems that the Python refs serve the same purpose; the only difference is that refs also cover the Python API. Are there other differences or example usages?

That’s 100% correct.