Enhance type annotation on native_functions.yaml

While exploring a possible ONNX exporter based on lowering an FX graph, leveraging ONNX Script as a high-level ONNX API, we hit a conundrum.

Currently, aten/src/ATen/native/native_functions.yaml already provides type annotations for both inputs and outputs, but some of the types, such as Tensor, do not specify the actual data type (dtype) they support, e.g., Tensor[torch.float32], Tensor[torch.float64], or even Tensor[torch.bool] if the tensor is a boolean mask.

Although the Tensor type generalization does not pose a problem for PyTorch, since a tensor carries dtype and data attributes that let users handle/represent any underlying data type, ONNX Script leverages ONNX Functions in a way that requires the exact Tensor.dtype to map an FX symbol, say aten::scaled_dot_product_attention, to either def _aten_scaled_dot_product_attention(query: TFloat, key: TFloat, value: TFloat, attn_mask: BOOL, ...) or def aten_scaled_dot_product_attention(query: TFloat, key: TFloat, value: TFloat, attn_mask: TFloat, ...). The former ONNX function expects attn_mask to be a torch.bool tensor while the latter expects a torch.float32 tensor.

One could work around this limitation by mapping several ONNX functions to a single ATen operator and then implementing extra logic to dispatch to the appropriate dtype variant. That costs time and adds extra complexity, and we wonder if we could do better by expanding the current type annotations in native_functions.yaml to include such dtype information.
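For illustration, the dispatch workaround could be shaped roughly like the sketch below. Every name here (the overload table, the dtype keys, the variant names) is hypothetical, not an actual exporter API:

```python
# Hypothetical sketch of dispatching one ATen op to one of several ONNX
# Script variants based on the dtypes of its inputs. The table and the
# variant names are made up for illustration.

# Each ATen op maps to an ordered list of (predicate, onnx_function_name).
_ONNX_OVERLOADS = {
    "aten::scaled_dot_product_attention": [
        # Variant taking a boolean attention mask.
        (lambda dtypes: dtypes.get("attn_mask") == "bool",
         "_aten_scaled_dot_product_attention_bool_mask"),
        # Fallback: float attention mask.
        (lambda dtypes: True,
         "aten_scaled_dot_product_attention"),
    ],
}

def dispatch(aten_op: str, input_dtypes: dict) -> str:
    """Pick the first ONNX function variant whose predicate matches."""
    for predicate, onnx_fn in _ONNX_OVERLOADS[aten_op]:
        if predicate(input_dtypes):
            return onnx_fn
    raise KeyError(f"no ONNX overload for {aten_op} with {input_dtypes}")
```

This is exactly the extra layer of logic the paragraph above refers to: it works, but every new dtype-sensitive op needs a hand-written entry in the table.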

Taking aten::scaled_dot_product_attention as an example, the original declaration might change from

- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
  python_module: nn
  variants: function
  autogen: scaled_dot_product_attention.out


to

- func: scaled_dot_product_attention(Tensor[torch.float32] query, Tensor[torch.float32] key, Tensor[torch.float32] value, Tensor[torch.float32,torch.bool]? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor[torch.float32]
  python_module: nn
  variants: function
  autogen: scaled_dot_product_attention.out

The notation above is quite verbose, as it needs to be repeated over and over across all input and output arguments and also across all funcs.

Another attempt would be:

- func: scaled_dot_product_attention(T query, T key, T value, U? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> T
  python_module: nn
  variants: function
  autogen: scaled_dot_product_attention.out
  type_constraints: {T: [torch.float32,torch.float64], U: [torch.float32, torch.bool]}

In this approach, type_constraints is declared once per func and defines all dtype variants for each type variable, which can then be reused among all arguments within the func. That does not solve the repetition across funcs, though.
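Expanding such a type_constraints field into concrete schemas is just a Cartesian product over the type variables. A minimal sketch, assuming the (currently nonexistent) field has already been parsed from the yaml into a dict:

```python
from itertools import product

def expand_type_constraints(type_constraints):
    """Expand {type_var: [dtypes]} into every concrete assignment.

    E.g. {"T": ["float32", "float64"], "U": ["float32", "bool"]}
    yields 4 assignments such as {"T": "float32", "U": "bool"},
    one per concrete schema a code generator would have to emit.
    """
    names = list(type_constraints)
    return [
        dict(zip(names, values))
        for values in product(*(type_constraints[n] for n in names))
    ]
```

Each returned assignment would correspond to one statically typed ONNX Function schema to generate for the func.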

Yet another approach would extend the previous one with global definitions of types that all funcs could use, reserving type_constraints for very specific scenarios.

Again, the main idea is to parse the yaml file and learn all the type variations each ATen func might have so that we can provision an ONNX implementation for each. This is needed because ONNX/ONNX Script is statically typed, which prevents us from having a single ONNX implementation covering all type variations for a given input.

However, this approach could be expanded to, for example, implement the cast policy for Autocast. Instead of manually registering the cast policy for each operator, we could annotate the func entries with that information and have the autocast code automatically generated from it, as suggested by pytorch/autocast_mode.cpp.

Another application might be enforcing runtime type checks at the dispatch level to make sure kernels are not being called with unexpected inputs.
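Such a dispatch-level check could be sketched as below; the allowed-dtypes table and its keys are made up for illustration, as if parsed from annotated native_functions.yaml entries:

```python
def check_input_dtypes(allowed, actual):
    """Raise if any input's dtype is outside its annotated set.

    `allowed` maps argument name -> set of permitted dtype names
    (hypothetically derived from annotated native_functions.yaml);
    `actual` maps argument name -> the dtype actually passed in.
    """
    for name, dtype in actual.items():
        permitted = allowed.get(name)
        if permitted is not None and dtype not in permitted:
            raise TypeError(
                f"argument {name!r} has dtype {dtype}, "
                f"expected one of {sorted(permitted)}"
            )
```

In a debug build this could run on every dispatch, surfacing the undefined-behavior cases mentioned later in the thread at the call site instead of deep inside a kernel.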

Any thoughts? @ezyang @albanD @jansel


Why not just run the graph with a FakeTensor to annotate the types/devices? That will also give you symbolic shapes and strides.

This would be useful (and we are leveraging it) at a later stage, when we are actually lowering an FX node to ONNX.

However, the problem above arises before the actual lowering. In order to implement the ONNX operator registry, we need to know all the possible dtype variations for each aten/prim operator so that we can implement and register them.

We could do this manually or by creating yet another yaml file to annotate all the variations, but a single yaml as the source of truth would be easier to maintain in the long run.

Even if we did all this annotation (which would be a massive amount of work), I think there is a high likelihood the annotations would be wrong, since the ultimate source of truth would be eager mode, not the annotation.

I’d suggest utilizing OpInfo testing to ensure ONNX captures the behavior of every operator. This would also capture many more subtle behaviors that go beyond just dtypes.

I do agree it might be a lot of work to annotate all of them. At first, we could default all funcs to AllTypes and eventually start plumbing in the specifics, as needed.

Regarding the mismatch between eager mode and the annotations, we have two options: 1) eager mode always uses the type annotations to ensure users are not messing up, or, if that brings a perf issue, 2) we could use the debug build flag to make eager mode do the type validation based on the annotations.

This way, eager mode would have stricter type documentation and users would learn which inputs could cause undefined behavior. Today, undefined behavior due to input types can be hard to pinpoint.

OpInfo is great and we intend to use it, but it still does not address the problem. Let me rephrase the problem statement as follows:

“Given a FX func f, how to determine all input and output types combinations in terms of their data types?”

Given a single function_a(Tensor inp0, Tensor inp1) -> Tensor, and assuming Tensor is limited to dtype: Union[torch.float32, torch.float64] for simplicity, the input/output dtype combinations could be:

function_a(Tensor[torch.float32], Tensor[torch.float32]) -> Tensor[torch.float32]
function_a(Tensor[torch.float32], Tensor[torch.float32]) -> Tensor[torch.float64]
function_a(Tensor[torch.float32], Tensor[torch.float64]) -> Tensor[torch.float32]
function_a(Tensor[torch.float64], Tensor[torch.float32]) -> Tensor[torch.float32]
function_a(Tensor[torch.float32], Tensor[torch.float64]) -> Tensor[torch.float64]
function_a(Tensor[torch.float64], Tensor[torch.float32]) -> Tensor[torch.float64]
function_a(Tensor[torch.float64], Tensor[torch.float64]) -> Tensor[torch.float32]
function_a(Tensor[torch.float64], Tensor[torch.float64]) -> Tensor[torch.float64]

As can be seen, a single func entry in native_functions.yaml could lead to 8 different ONNX Function schemas that we need to generate. This is a simplistic example, as Tensor can include more dtypes. Moreover, depending on the function spec, some of these combinations might be invalid, but without additional metadata, we can’t tell; e.g., function_a(Tensor[torch.float64], Tensor[torch.float64]) -> Tensor[torch.float32] might be invalid outside the autocast context manager.
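The enumeration above is just a Cartesian product over the dtype sets of each tensor argument plus the output; a quick sketch of how the 8 schemas arise:

```python
from itertools import product

def enumerate_schemas(input_dtype_sets, output_dtypes):
    """Yield every ((input dtypes...), output dtype) combination."""
    for combo in product(*input_dtype_sets, output_dtypes):
        *inputs, output = combo
        yield tuple(inputs), output

dtypes = ["torch.float32", "torch.float64"]
# function_a has two Tensor inputs and one Tensor output:
schemas = list(enumerate_schemas([dtypes, dtypes], dtypes))
# two inputs and one output over two dtypes -> 2 * 2 * 2 = 8 combinations
```

With the full set of PyTorch dtypes the product grows quickly, which is why pruning the invalid combinations matters.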

I hope this helps clarify the issue we are trying to tackle.

Without extra annotations, we have to manually read native_functions.yaml and specify and implement all the dtype combinations for each func in a 1:n mapping. Even using a script to auto-generate it, a manual step is needed to delete the invalid combinations. We would also have the burden of making sure our registry stays up to date with the funcs in native_functions.yaml.

However we solve the problem above, the rest could work as you mentioned. Once the set of ONNX schemas for each ATen schema is discovered and implemented, we are ready to do the proper FX → ONNX lowering. Each ONNX function can then be tested using OpInfo to assert its correct behavior.

This doesn’t really fit the data model of PyTorch; in theory it would be perfectly legal for me to register an op that did something like:

def my_custom_op(x):
    if x.size(1) % 2:
        return x.to(torch.float32)
    return x.to(torch.float16)

That is a contrived example, but there are ops that change dtypes based on non-tensor args and whose behavior changes based on global amp modes. The source of truth is the implementation, and I wouldn’t want to change the source of truth to some external file we needed to keep in sync with the implementations. That sounds like a maintenance burden.

We solved this same problem in a more general way with meta functions, fake tensors, and symints. I could write a meta function for the above example that queried x.size(1) % 2 (guarding in the case of a symint) and provided accurate metadata.

I don’t think ONNX should try to manually build such a yaml file on the side either. You should just automatically generate it based on OpInfo testing. OpInfo tests contain example inputs for each operator, including all the dtypes those ops support. You could just write a script to scrape the data you need out of that. If there are some coverage gaps in the OpInfo inputs, you could also write a script to just exhaustively try every dtype on every op – using OpInfo inputs as a starting place.
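Such a scraping script might be shaped like the sketch below. The OpInfo objects here are stand-ins; in PyTorch they would come from torch.testing._internal.common_methods_invocations.op_db, and the exact attribute and method names should be checked against the real OpInfo class:

```python
def collect_dtype_table(op_db, device_type="cpu"):
    """Build {op name: sorted supported dtype names} from OpInfo-like
    entries, each assumed to expose a name and a supported_dtypes()
    method per backend (as the real OpInfo class does)."""
    table = {}
    for op in op_db:
        dtypes = op.supported_dtypes(device_type)
        table[op.name] = sorted(str(d) for d in dtypes)
    return table
```

Running this once per backend (cpu, cuda, ...) would yield the dtype registry the exporter needs, regenerated from the tests rather than maintained by hand.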

I think @jansel is completely right here. OpInfos tackle this exact problem, and in fact, there is a test within them that makes sure that the dtypes that the OpInfos declare as correct are in fact the correct ones.

When looking at the OpInfos, you can see that there are too many exceptions for a typing system like that to work. First of all, the set of supported dtypes may change with the backend (CPU/CUDA/ROCm). Even more, these supported dtypes may not work for all inputs. As such, as Jason mentioned, you need to peek at the metadata of the tensors in a non-trivial way to figure out which dtypes the inputs may accept.

As Jason mentioned, I’d recommend implementing this by using the inputs from the OpInfos and tracing through them to get inputs that work for the ATen operators. If you want to do this manually, the function opinfo.utils.get_supported_dtypes may be useful.

Thanks, I will look into OpInfo and check how it can be leveraged to auto-generate the ATen op signature variants.