I guess in principle it could be rewritten as `t = torch.where(mask, t * 2, t)`.
See "Alternative to array-based boolean indexing for jax.jit" (google/jax issue #2765 on GitHub).
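A minimal sketch of the equivalence, assuming the original code doubled the masked elements with boolean-index assignment (the post doesn't show it, so that form is an assumption, and the helper names `double_masked_inplace` / `double_masked_where` are hypothetical). The boolean-index version selects a data-dependent number of elements. The `torch.where` version always produces a tensor with the same shape as `t`, which is what tracing compilers such as `jax.jit` need.

```python
import torch


def double_masked_inplace(t: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Hypothetical "before" version: boolean-index assignment.
    # out[mask] has a length that depends on how many mask entries are True.
    out = t.clone()  # clone so the caller's tensor is not mutated
    out[mask] = out[mask] * 2
    return out


def double_masked_where(t: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # The rewrite from the post: output shape always equals t's shape.
    # Where mask is True take t * 2, otherwise keep the original value.
    return torch.where(mask, t * 2, t)


if __name__ == "__main__":
    # Example data is an assumption: mask selects the positive entries.
    t = torch.tensor([1.0, -2.0, 3.0, -4.0])
    mask = t > 0
    print(double_masked_where(t, mask).tolist())  # [2.0, -2.0, 6.0, -4.0]
    print(torch.equal(double_masked_where(t, mask), double_masked_inplace(t, mask)))  # True
```

One trade-off: `torch.where` evaluates `t * 2` for every element, including the ones it then discards, so it does a bit more arithmetic in exchange for a fixed output shape.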