What is the relationship among nn.module API / nn.functional API / aten API?

With respect to the PyTorch frontend, I can come up with three sets of APIs:

  • nn.module API, such as torch.nn.Dropout
  • nn.functional API, such as torch.nn.functional.dropout
  • aten ops, such as torch.ops.aten.dropout.default

What is the relationship among these APIs?

Here is what I think, and please correct me if I’m wrong:

  • aten ops are the kernel-level implementations
  • nn.functional APIs are stateless: they wrap aten ops and add some sanity checks on the arguments
  • nn.module APIs hold state and pass it to the stateless functional APIs (a rough sketch of this layering follows the list)
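
To make that layering concrete, here is a rough sketch of what I have in mind for dropout (simplified and hand-written for illustration; MyDropout and my_functional_dropout are made-up names, and the real code in torch/nn/modules/dropout.py and torch/nn/functional.py does more than this):

import torch

def my_functional_dropout(input, p=0.5, training=True):
    # nn.functional-style wrapper: sanity-check arguments, then call the aten op
    if p < 0.0 or p > 1.0:
        raise ValueError(f"dropout probability has to be between 0 and 1, got {p}")
    return torch.ops.aten.dropout.default(input, p, training)

class MyDropout(torch.nn.Module):
    # nn.Module-style wrapper: hold the state (p, plus self.training from nn.Module)
    # and forward it to the stateless functional API
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, input):
        return my_functional_dropout(input, self.p, self.training)

print(MyDropout(p=0.5)(torch.randn(5)))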

And when we step into the PyTorch compiler, the counterpart is eager mode. Which of these API levels are considered “eager mode”?

When we switch from eager mode to the PyTorch compiler with torch.compile(backend="eager"), which level of APIs does it trace? Does it trace the wrapper logic in nn.module and nn.functional?

These questions arose when I wanted to fix a bug confirmed by @Chillee, but the PR caused some dispute among @jansel, @peterbell10, and @bdhirsh.

Furthermore, I find that even the meaning of aten ops is not clear.

I can get all the aten ops and overloads by:

import torch

aten_ops = []                 # OpOverloadPacket objects under torch.ops.aten
aten_ops_with_overloads = []  # the individual OpOverload objects they resolve to

for k in dir(torch.ops.aten):
    possible_op = getattr(torch.ops.aten, k)
    # real op packets expose .overloads(); skip anything else dir() returns
    if hasattr(possible_op, "overloads"):
        aten_ops.append(possible_op)
        for ovl in possible_op.overloads():
            aten_ops_with_overloads.append(getattr(possible_op, ovl))
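
As a quick sanity check on the enumeration, you can print what the loop found and peek at a single packet (the exact numbers and overload names depend on the build):

print(f"{len(aten_ops)} op packets, {len(aten_ops_with_overloads)} overloads")
print(torch.ops.aten.dropout.overloads())  # the overload names of one packet, e.g. ['default']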

And I see some overloads like:

<OpOverload(op='aten.__and__', overload='Scalar')>,
 <OpOverload(op='aten.__and__', overload='Tensor')>,
 <OpOverload(op='aten.__and__', overload='bool')>,
 <OpOverload(op='aten.__and__', overload='int')>,
 <OpOverload(op='aten.__iand__', overload='Tensor')>,
 <OpOverload(op='aten.__iand__', overload='Scalar')>,

Clearly there must be another layer of logic for dispatching an operator call to one of its overloads. Does Dynamo also deal with this dispatch logic?
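
To make the packet-vs-overload distinction concrete, here is a small illustration (the comment about which overload gets picked is my expectation rather than anything documented):

import torch

x = torch.tensor([6, 7])

packet = torch.ops.aten.__and__                  # OpOverloadPacket: overload chosen at call time
scalar_overload = torch.ops.aten.__and__.Scalar  # OpOverload: one fixed schema

print(packet(x, 3))             # the packet resolves this call (I expect the Scalar overload here)
print(scalar_overload(x, 3))    # calling the specific overload directly
print(scalar_overload._schema)  # the schema this particular overload implements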

I was expecting that, as we go from the eager backend to aot_eager and further to inductor, Dynamo would introduce more and more guards, but I find that the guards for eager are the same as for aot_eager. That suggests the dispatch logic is not pushed into Dynamo guards.
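
One way I compared the guards (assuming a PyTorch 2.x build where the TORCH_LOGS logging facility is available) was to run the same script twice with guard logging enabled, flipping only the backend string:

# run as: TORCH_LOGS=guards python this_script.py
# once with backend="eager" and once with backend="aot_eager", then diff the printed guards
import torch

@torch.compile(backend="eager")  # swap to backend="aot_eager" for the second run
def f(x):
    return torch.nn.functional.dropout(x, p=0.5)

f(torch.randn(5))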

Hey!

I think you’re actually missing a 4th layer: torch.dropout in your example above.

In order from lowest level to highest level (a concrete dropout example follows the list):

  • torch.ops.* is a raw binding to our C++ dispatcher-based API. It contains all the “native” ops, including the ones dynamically registered via the torch library API (both from Python and C++). All features working at the dispatcher level will see these ops (or a particular subset of them): torch_dispatch classes and modes, AOTAutograd, Inductor, etc.
  • torch.* is our main Python API; it is either written in Python or a direct binding to some C++ ops. This is where torch_function happens, as well as Dynamo tracing and a lot of the FX tracing.
  • torch.nn.functional.* is the functional API for nn; it is usually logic wrapping a call to torch.*.
  • torch.nn.Module is the stateful Module; it holds onto state (params, buffers) and calls into functional or torch.* functions.
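
To tie this back to the dropout example, all four layers are directly callable; here is a quick sketch (the lower-level calls take p and train positionally, as I understand their schemas):

import torch
import torch.nn.functional as F

x = torch.randn(5)

# 1. raw dispatcher op (torch.ops.*)
y1 = torch.ops.aten.dropout.default(x, 0.5, True)
# 2. main Python API (torch.*), a thin binding to the native op
y2 = torch.dropout(x, 0.5, True)
# 3. functional nn API, wrapping the torch.* call with argument checks
y3 = F.dropout(x, p=0.5, training=True)
# 4. stateful nn.Module, holding p (and self.training) and calling the functional API
y4 = torch.nn.Dropout(p=0.5)(x)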

Thanks, Alban!

It’s strange that torch.dropout has no docstring, but torch.nn.functional.dropout does, which makes torch.nn.functional.dropout look more formal and user-facing. That’s why I ignored torch.dropout :cry:

With respect to Dynamo tracing, it seems to just record what it sees:

import torch

@torch.compile(backend="eager")
def f1(x):  # nn.functional call
    return torch.nn.functional.dropout(x, p=0.5)

@torch.compile(backend="eager")
def f2(x, mod):  # nn.Module passed in and called
    return mod(x)

@torch.compile(backend="eager")
def f3(x):  # aten op call
    return torch.ops.aten.relu(x)

@torch.compile(backend="eager")
def f4(x):  # torch.* API call
    return torch.add(x, 1)

f1(torch.randn(5))
f2(torch.randn(5), torch.nn.ReLU())
f3(torch.randn(5))
f4(torch.randn(5))

With the above functions, I can see nn.Module / nn.functional / torch.* / torch.ops.aten.* calls in the Dynamo-captured graphs.
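
In case it helps anyone reproduce this, here is a minimal sketch of a custom backend that prints whatever Dynamo captured and then runs it eagerly (print_graph_backend is just a name I made up; passing a callable to backend= is the standard custom-backend hook):

import torch

def print_graph_backend(gm: torch.fx.GraphModule, example_inputs):
    # Dynamo hands us the captured FX graph; print it, then fall back to eager execution
    gm.print_readable()
    return gm.forward

@torch.compile(backend=print_graph_backend)
def f(x):
    return torch.nn.functional.dropout(x, p=0.5)

f(torch.randn(5))  # the printed graph shows the traced dropout call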

Nothing special, most likely an oversight: dropout was added as a native op before we added the CI that ensures all functions are properly documented :frowning: