That might require finding some way to combine the clean interface of PyTorch with the performance of JAX, e.g. taking full advantage of metaprogramming and multiple dispatch to generate more efficient code.
You may be interested in TorchDynamo ("TorchDynamo: An Experiment in Dynamic Python Bytecode Transformation").