Dynamo allow_in_graph fails on tensor.view(dtype)

Hi team, I am trying to run Dynamo on a Python function that contains a tensor.view(dtype) operation.
I know this can't be captured by Dynamo, so I wrapped it in a function foo and added foo to the list of allowed functions with dynamo.allow_in_graph(foo).
However, Dynamo still fails. Is this a bug or expected behavior? It would be great if someone could take a look.

cc @ezyang, who might know the answer

import torch
import torch._dynamo as dynamo

def foo(a):
    # reinterpret a's bytes as int8 (same storage, no copy)
    b = a.view(torch.int8)
    return b

def bar(x):
    x = x + 1
    y = foo(x)
    z = y * 2
    return z

torch.manual_seed(42)
x = torch.randn(1,3)

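# register foo so Dynamo records it as a single graph node instead of tracing into it
dynamo.allow_in_graph(foo)
# report the graphs and graph breaks Dynamo produces for bar(x)
dynamo.explain(bar, x)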
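For reference, this is what foo computes in eager mode: view(dtype) reinterprets the same storage under a new dtype, and when the element size differs, the last dimension is scaled proportionally. A minimal check using the same input shape as the repro:

```python
import torch

a = torch.randn(1, 3)            # float32: 4 bytes per element
b = a.view(torch.int8)           # int8: 1 byte per element, same bytes

print(b.dtype)                   # torch.int8
print(b.shape)                   # torch.Size([1, 12]): last dim scaled by 4 / 1
print(b.data_ptr() == a.data_ptr())  # True: same storage, no copy
```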

Full stack trace:

fake tensor raised TypeError
Traceback (most recent call last):
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 987, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1170, in dispatch
    r = func(*args, **kwargs)
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/_ops.py", line 287, in __call__
    return self._op(*args, **kwargs or {})
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/_refs/__init__.py", line 3988, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/_refs/__init__.py", line 3152, in _reshape_view_helper
    shape = utils.infer_size(shape, a.numel())
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/_prims_common/__init__.py", line 725, in infer_size
    elif d >= 0:
TypeError: '>=' not supported between instances of 'torch.dtype' and 'int'
Traceback (most recent call last):
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 1194, in run_node
    return node.target(*args, **kwargs)
  File "/home/thonle/dev/dynamo/repro_fp16.py", line 5, in foo
    b = a.view(torch.int8)
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 987, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/_subclasses/fake_tensor.py", line 1170, in dispatch
    r = func(*args, **kwargs)
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/_ops.py", line 287, in __call__
    return self._op(*args, **kwargs or {})
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/_refs/__init__.py", line 3988, in view
    return _reshape_view_helper(a, *shape, allow_copy=False)
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/_refs/__init__.py", line 3152, in _reshape_view_helper
    shape = utils.infer_size(shape, a.numel())
  File "/home/thonle/.local/lib/python3.8/site-packages/torch/_prims_common/__init__.py", line 725, in infer_size
    elif d >= 0:
TypeError: '>=' not supported between instances of 'torch.dtype' and 'int'
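Reading the trace, the run_node frame shows that even with allow_in_graph, Dynamo still executes foo on fake tensors to compute output metadata. There, the reference implementation of view forwards its arguments as a shape (_reshape_view_helper(a, *shape, ...)), so torch.int8 reaches infer_size as if it were a dimension. The sketch below is a simplified, hypothetical stand-in for torch._prims_common.infer_size (not the real function, which does more checking) that reproduces the same TypeError:

```python
import torch

def infer_size_sketch(shape, numel):
    """Hypothetical, simplified stand-in for torch._prims_common.infer_size.

    Resolves at most one -1 entry so that the product of dims equals numel.
    """
    infer_dim = None
    known = 1
    for i, d in enumerate(shape):
        if d == -1:              # a torch.dtype compares unequal to -1, no error yet
            infer_dim = i
        elif d >= 0:             # a torch.dtype has no ordering: TypeError here
            known *= d
        else:
            raise ValueError(f"invalid shape dimension {d}")
    if infer_dim is None:
        if known != numel:
            raise ValueError(f"shape {tuple(shape)} is invalid for input of size {numel}")
        return tuple(shape)
    resolved = list(shape)
    resolved[infer_dim] = numel // known
    return tuple(resolved)

print(infer_size_sketch((3, -1), 12))    # (3, 4)
try:
    # what view(torch.int8) looks like once it is treated as view(*shape)
    infer_size_sketch((torch.int8,), 3)
except TypeError as e:
    print(e)  # '>=' not supported between instances of 'torch.dtype' and 'int'
```

In other words, the failure comes from the view(dtype) overload not being handled by fake-tensor propagation in this build, not from allow_in_graph being ignored.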

Are you using a build that includes [Inductor] Fix x.view(dtype) decomp and make inductor support it by yanboliang · Pull Request #102920 · pytorch/pytorch · GitHub?
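One way to check which build is installed is to print its version string and the commit hash it was built from; comparing that hash against the PR's merge commit on GitHub is a manual step and is assumed here, not automated:

```python
import torch

# Release version string of the installed build
print(torch.__version__)
# Commit hash the build was made from
print(torch.version.git_version)
```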
