Thank you!
Just a couple more crashes (I know, I’m a magnet for them…):
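```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def fn(a):
    b = a * a
    if b.sum():  # Python-level branch on a tensor value
        return a
    return b

print(fn(torch.tensor([1.0, 2.0])))
traced_f = make_fx(fn, tracing_mode="symbolic")(torch.tensor([1.0, 2.0]))
print(traced_f)
# RuntimeError: tried to get Double out of SymFloat
```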
If I use this instead:
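```python
def fn(a):
    b = a * a
    if b.sum() >= 1:  # comparison produces a bool tensor
        return a
    return b

# traced with the same make_fx(..., tracing_mode="symbolic") call as above:
# NotImplementedError: local_scalar_dense/item NYI for torch.bool
```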
I’ve attempted to fix it by patching fake tensor’s `local_scalar_dense`:

```python
elif is_integer_dtype(arg.dtype) or is_boolean_dtype(arg.dtype):
    return fake_mode.shape_env.create_unbacked_symint()
```

But no luck; it still crashes (`tried to get Long out of SymInt`).
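Both crashes come from the Python `if`: calling `bool()` on a tensor goes through `.item()` (`local_scalar_dense`), which symbolic tracing has to turn into a concrete value. As a possible stopgap, assuming the goal is only to get a trace and not to exercise the branch itself, the selection can be kept inside the graph with `torch.where`, which broadcasts the 0-dim condition over `a` and `b`. The names `fn1_branchless` and `fn2_branchless` are hypothetical, and whether these trace cleanly under `tracing_mode="symbolic"` is not confirmed here; the sketch only checks that they match the eager results.

```python
import torch

def fn1(a):
    # First repro: branch on the truthiness of b.sum()
    b = a * a
    if b.sum():
        return a
    return b

def fn2(a):
    # Second repro: branch on a comparison
    b = a * a
    if b.sum() >= 1:
        return a
    return b

def fn1_branchless(a):
    # Hypothetical rewrite: "truthy" means nonzero, so compare against 0 and
    # let torch.where pick a or b without any Python-level bool()/.item()
    b = a * a
    return torch.where(b.sum() != 0, a, b)

def fn2_branchless(a):
    # Hypothetical rewrite: the 0-dim bool condition broadcasts over a and b
    b = a * a
    return torch.where(b.sum() >= 1, a, b)

x = torch.tensor([1.0, 2.0])
print(fn1(x), fn1_branchless(x))
print(fn2(x), fn2_branchless(x))
```

The trade-off is that both `a` and `b` are always computed, which is harmless here but changes cost if either branch were expensive.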