State of symbolic shapes branch

I guess in principle it could be rewritten as t = torch.where(mask, t * 2, t).

See "Alternative to array-based boolean indexing for jax.jit" (google/jax issue #2765 on GitHub).