While exploring a possible ONNX exporter based on lowering an FX graph and leveraging ONNX Script as a high-level ONNX API, we hit a conundrum.
Currently, `aten/src/ATen/native/native_functions.yaml` already provides type annotations for both inputs and outputs, but some of the types, such as `Tensor`, do not specify the actual data type (`dtype`) they support, e.g. `Tensor[torch.float32]`, `Tensor[torch.float64]` or even `Tensor[torch.bool]` if that tensor is a boolean mask.
Although for PyTorch this `Tensor` type generalization does not pose a problem, since the `dtype` and `data` attributes let users handle/represent any underlying data type, ONNX Script leverages ONNX Functions in a way that the exact `Tensor.dtype` is needed to map an FX symbol, say `aten::scaled_dot_product_attention`, to either `def _aten_scaled_dot_product_attention(query: TFloat, key: TFloat, value: TFloat, attn_mask: BOOL, ...)` or `def aten_scaled_dot_product_attention(query: TFloat, key: TFloat, value: TFloat, attn_mask: TFloat, ...)`. The former ONNX function expects `attn_mask` to be a `torch.bool` tensor, while the latter expects `torch.float32`.
One could work around this limitation by mapping several ONNX functions to a single ATen operator and then implementing extra logic to dispatch to the appropriate dtype variant, but that costs time and adds extra complexity. We wonder if we could do better by expanding the current type annotations in `native_functions.yaml` to include such `dtype` information.
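For illustration, here is a minimal sketch of that workaround; the registry and function names are hypothetical (only `torch` is assumed, and the ONNX Script bodies are elided):

```python
import torch

# Two statically typed ONNX Script variants for one ATen operator; the names
# mirror the signatures quoted above, but the bodies are elided.
def _aten_scaled_dot_product_attention_bool_mask(query, key, value, attn_mask, **kwargs):
    ...  # variant expecting attn_mask as a boolean tensor

def _aten_scaled_dot_product_attention_float_mask(query, key, value, attn_mask, **kwargs):
    ...  # variant expecting attn_mask as a float tensor

# Hypothetical registry keyed by the runtime dtype of attn_mask.
_SDPA_VARIANTS = {
    torch.bool: _aten_scaled_dot_product_attention_bool_mask,
    torch.float32: _aten_scaled_dot_product_attention_float_mask,
    torch.float64: _aten_scaled_dot_product_attention_float_mask,
}

def export_scaled_dot_product_attention(query, key, value, attn_mask, **kwargs):
    # The extra dispatch logic the exporter has to carry for every such operator.
    onnx_fn = _SDPA_VARIANTS[attn_mask.dtype]
    return onnx_fn(query, key, value, attn_mask, **kwargs)
```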
Taking `aten::scaled_dot_product_attention` as an example, the original declaration might change from
```yaml
- func: scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor
  python_module: nn
  variants: function
  autogen: scaled_dot_product_attention.out
```
to
```yaml
- func: scaled_dot_product_attention(Tensor[torch.float32] query, Tensor[torch.float32] key, Tensor[torch.float32] value, Tensor[torch.float32,torch.bool]? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> Tensor[torch.float32]
  python_module: nn
  variants: function
  autogen: scaled_dot_product_attention.out
```
The notation above is quite verbose, as it needs to be repeated over and over across all input and output arguments, and also across all `func` entries.
Another attempt would be:
```yaml
- func: scaled_dot_product_attention(T query, T key, T value, U? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> T
  python_module: nn
  variants: function
  autogen: scaled_dot_product_attention.out
  type_constraints: {T: [torch.float32, torch.float64], U: [torch.float32, torch.bool]}
```
In this approach, `type_constraints` is declared once per `func` and defines the allowed dtypes for each type variable, which can then be reused among all arguments within the `func`. That does not solve the repetition across `func`s, though.
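As a rough sketch of how a consumer could expand such a declaration, the concrete dtype combinations are just the cartesian product over the type variables; the parsed `type_constraints` and the argument-to-variable mapping below are assumptions for illustration:

```python
from itertools import product

# Hypothetical parsed form of the type_constraints entry proposed above.
type_constraints = {
    "T": ["torch.float32", "torch.float64"],
    "U": ["torch.float32", "torch.bool"],
}

# Which type variable constrains each tensor argument (and the output).
arg_to_var = {"query": "T", "key": "T", "value": "T", "attn_mask": "U", "output": "T"}

# Every concrete dtype assignment the constraints allow; each combination
# would need its own statically typed ONNX Script overload.
variables = sorted(type_constraints)
for assignment in product(*(type_constraints[v] for v in variables)):
    binding = dict(zip(variables, assignment))
    print({arg: binding[var] for arg, var in arg_to_var.items()})
```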
Yet another approach would be extending the previous one with global definitions of types that all `func`s could use, reserving `type_constraints` for very specific scenarios.
Again, the main idea is to parse the YAML file and learn all the type variations each ATen func might have, so that we can provision an ONNX implementation for each. This is needed because ONNX/ONNX Script are statically typed, which prevents a single ONNX implementation from covering all type variations of a given input.
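A sketch of that exporter-side step, assuming the `type_constraints` extension above existed and using PyYAML to read the file:

```python
import yaml  # PyYAML

# Read the (hypothetically annotated) native_functions.yaml and collect the
# dtype constraints declared for every func, so the exporter knows how many
# statically typed ONNX implementations it has to provision.
with open("aten/src/ATen/native/native_functions.yaml") as f:
    entries = yaml.safe_load(f)

variants_per_func = {}
for entry in entries:
    name = entry["func"].split("(")[0]
    # "type_constraints" is the proposed extension; it does not exist today.
    variants_per_func[name] = entry.get("type_constraints", {})

print(variants_per_func.get("scaled_dot_product_attention"))
```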
However, this approach could be expanded further, for example to implement the cast policy for Autocast. Instead of manually registering the cast policy for each operator, we could annotate the func entries with that information and have the autocast code automatically generated from it, as suggested by pytorch/autocast_mode.cpp.
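A very rough sketch of that idea, where both the `autocast` annotation and the emitted registration string are purely illustrative (the real registrations in autocast_mode.cpp use different macros):

```python
# Emit one registration line per annotated func; the "autocast" key is a
# hypothetical extension and the KERNEL(...) string only illustrates the
# shape of the generated code, not the real macro.
def emit_autocast_registrations(entries):
    lines = []
    for entry in entries:
        policy = entry.get("autocast")
        if policy is None:
            continue
        op_name = entry["func"].split("(")[0]
        lines.append(f"KERNEL(aten::{op_name}, {policy})")
    return "\n".join(lines)
```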
Another application might be enforcing runtime type checks at the dispatch level to make sure kernels are not called with unexpected inputs.
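As a last illustration, a minimal sketch of such a check, reusing the hypothetical `type_constraints` and `arg_to_var` mappings from the expansion sketch above:

```python
import torch

def check_dtypes(op_name, args, type_constraints, arg_to_var):
    # Verify that every tensor argument carries one of the dtypes declared
    # for its type variable before the kernel is dispatched.
    for arg, value in args.items():
        if not isinstance(value, torch.Tensor):
            continue
        allowed = type_constraints.get(arg_to_var.get(arg, ""))
        if allowed is not None and str(value.dtype) not in allowed:
            raise TypeError(
                f"{op_name}: argument '{arg}' has dtype {value.dtype}, "
                f"expected one of {allowed}"
            )
```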