import torch
from torch._subclasses.fake_tensor import FakeTensorMode
def model(x, threshold=10):
    """Return ``x + x`` if any element of *x* exceeds *threshold*, else row sums.

    The branch is selected with ``torch.where`` on a 0-dim boolean tensor
    instead of a Python ``if``. That avoids Python control flow on a tensor
    value, which is what allows this function to be traced under
    ``FakeTensorMode``. Note that ``torch.where`` evaluates both branches.

    Both branch tensors must broadcast to a common shape. ``x.sum(dim=1)``
    drops dim 1, giving ``[N]`` for an ``[N, M]`` input. That generally cannot
    broadcast against ``x + x`` (``[N, M]``): for ``[64, 1024]`` it raises
    "Attempting to broadcast a dimension of length 64". For square input it
    broadcasts the sums along the wrong axis. ``keepdim=True`` keeps the
    reduced dim as size 1 (``[N, 1]``), so the "else" branch yields each
    row's sum repeated across that row.

    Args:
        x: Tensor with at least two dimensions; dim 1 is summed.
        threshold: Value the global maximum of *x* is compared against.
            Defaults to 10, the original hard-coded constant.

    Returns:
        A tensor with the same shape as ``x``.
    """
    # 0-dim bool tensor; torch.where broadcasts it over both branches.
    take_double = torch.max(x) > threshold
    return torch.where(take_double, x + x, x.sum(dim=1, keepdim=True))
# Run model() on a tensor created inside FakeTensorMode, then report the
# propagated shapes. The last line shows the concrete tensor type produced.
with FakeTensorMode() as mode:
    fake_input = torch.randn(64, 1024)
    fake_output = model(fake_input)

# One print call with newline separators; this emits the same three lines as
# three separate print() calls.
print(
    f"Fake Input Shape: {fake_input.shape}",
    f"Fake Output Shape: {fake_output.shape}",
    f"Is it fake? {type(fake_output)}",
    sep="\n",
)
# Error observed when torch.where's third argument is x.sum(dim=1) (no keepdim):
# RuntimeError: Attempting to broadcast a dimension of length 64 at -1! Mismatching argument at index 2 had torch.Size([64]); but expected shape should be broadcastable to [64, 1024]