State of symbolic shapes branch

Thank you!

Just a couple more crashes (I know, I’m a magnet for them…):

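import torch
from torch.fx.experimental.proxy_tensor import make_fx

def fn(a):
    b = a * a
    # Data-dependent branch: the condition depends on tensor values
    if b.sum():
        return a
    return b

print(fn(torch.tensor([1.0, 2.0])))  # eager mode works
traced_f = make_fx(fn, tracing_mode="symbolic")(torch.tensor([1.0, 2.0]))
print(traced_f)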

# RuntimeError: tried to get Double out of SymFloat

If I use this version instead:

def fn(a):
    b = a * a
    if b.sum() >= 1:
        return a
    return b

# NotImplementedError: local_scalar_dense/item NYI for torch.bool

I tried to fix this by patching the fake tensor implementation of local_scalar_dense:

    elif is_integer_dtype(arg.dtype) or is_boolean_dtype(arg.dtype):
        return fake_mode.shape_env.create_unbacked_symint()

No luck, though: it still crashes, this time with "tried to get Long out of SymInt".
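For comparison only (this does not address the fake tensor crash itself): both failures appear to come from a Python `if` that needs a concrete value from a tensor. A minimal sketch, assuming the goal is just to get a traceable graph, is to express the second variant without a Python branch, using `torch.where` with a 0-d boolean tensor as the condition. `fn_where` is a hypothetical name. Note that this evaluates both `a` and `b` and only works as a drop-in here because they have the same shape.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx

def fn_where(a):
    b = a * a
    # Keep the condition as a 0-d bool tensor instead of a Python bool,
    # so tracing never has to extract a concrete scalar value.
    cond = b.sum() >= 1
    # torch.where broadcasts the 0-d condition to a's shape and selects a or b.
    return torch.where(cond, a, b)

x = torch.tensor([1.0, 2.0])
print(fn_where(x))
traced = make_fx(fn_where, tracing_mode="symbolic")(x)
print(traced.code)
```

The traced graph should contain the comparison and the `where` as ordinary ops rather than a Python-level branch, so it does not depend on the value of `b.sum()` at trace time.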