Where we are headed and why it looks a lot like Julia (but not exactly like Julia)

When trying to predict how PyTorch would itself get disrupted, we used to joke a bit about the next version of PyTorch being written in Julia. This was not very serious: a huge factor in moving PyTorch from Lua to Python was to tap into Python’s immense ecosystem (an ecosystem that shows no signs of going away) and even today it is still hard to imagine how a new language can overcome the network effects of Python.

However, recently, I have been thinking about various projects we have going on in PyTorch, including:

  • functorch - write transformations like vmap/grad directly in Python, previously only possible to do as C++ extensions to the dispatcher
  • FX for graph transformations, previously only possible to do as C++ TorchScript passes
  • Python autograd implementation for doing experimental changes to our autograd implementation, previously only possible in C++
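To give a concrete feel for what these Python-level transforms do, here is a toy, pure-Python sketch of the ideas behind vmap (map a function over a batch dimension) and grad (differentiate a scalar function). These names and implementations are illustrative only; the real functorch versions operate on tensors via the dispatcher, with no Python-level loop and true autodiff rather than finite differences:

```python
def toy_vmap(fn):
    """Return a 'batched' version of fn that maps over the leading dimension.

    A toy model of the vmap idea using plain Python lists; functorch.vmap
    does this on tensors without a Python-level loop.
    """
    def batched(xs):
        return [fn(x) for x in xs]
    return batched

def toy_grad(fn, eps=1e-6):
    """A numerical-derivative stand-in for functorch.grad's autodiff."""
    def dfn(x):
        return (fn(x + eps) - fn(x - eps)) / (2 * eps)
    return dfn

square = lambda x: x * x
batched_square = toy_vmap(square)
print(batched_square([1, 2, 3]))  # [1, 4, 9]
print(toy_grad(square)(3.0))      # ~6.0
```

The point of the projects above is that transforms like these become ordinary Python functions you can read, modify, and compose, rather than C++ dispatcher extensions.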

What do all of these projects have in common? There’s some functionality that previously people could only do in C++, and the project in question makes it possible to do it in Python, increasing the hackability and ease of development. It’s important to remember that PyTorch used to be written in mostly Python, and we moved everything to C++ to make it run faster. So we are increasingly in a situation where we want to have our cake (hackability) and eat it too (performance).

This is the same story that Julia has been telling for nearly a decade now. Julia says:

  • A language must compile to efficient code, and we will add restrictions to the language (type stability) to make sure this is possible.
  • A language must allow post facto extensibility (multiple dispatch), and we will organize the ecosystem around JIT compilation to make this possible.
  • The combination of these two features gives you a system that has dynamic language level flexibility (because you have extensibility) but static language level performance (because you have efficient code).
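As a rough illustration of the "post facto extensibility" point, multiple dispatch can be sketched in a few lines of Python with a registry keyed on the types of all arguments (a toy model; Julia's dispatch and the PyTorch dispatcher are far more sophisticated, handling subtyping, specialization, and caching):

```python
# Toy multiple dispatch: implementations are looked up by the runtime types
# of *all* arguments, not just the first (unlike functools.singledispatch).
_impls = {}

def impl(*types):
    """Register an implementation of `add` for this argument-type signature."""
    def decorator(fn):
        _impls[types] = fn
        return fn
    return decorator

def add(x, y):
    fn = _impls.get((type(x), type(y)))
    if fn is None:
        raise TypeError(f"no add implementation for {type(x)}, {type(y)}")
    return fn(x, y)

@impl(int, int)
def _(x, y):
    return x + y

@impl(str, int)
def _(x, y):
    # Post facto extension: downstream code adds a new case without
    # touching the original definition of `add`.
    return x + str(y)

print(add(1, 2))     # 3
print(add("n=", 3))  # n=3
```

The key property is that the `str, int` case could live in an entirely different package from `add` itself, which is what makes the ecosystem extensible after the fact.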

We’ve already derived a lot of inspiration from Julia (for example, Zachary DeVito credits the original emphasis on multiple dispatch in our dispatcher to Julia), and I think in general Julia can serve as a very powerful vision of what could be possible, and also of what we have to be careful about (e.g., time to first plot). There’s also opportunity to improve on Julia for our domain; e.g., Julia often advertises the fact that you can directly write loops with mathematical operations and have them compile into efficient code. We don’t need to pursue this, because the cores of our kernels are quite complex and best implemented at a low level in any case.

Why not use Julia directly? We want the Julia vision, but we want it in Python (it’s the ecosystem!). There is tremendous potential in this direction, but also a lot of work and many unresolved design questions. I’m pretty excited about where we are headed next.

Credits to Gregory Chanan who has said many similar things in the past, including in his PTDC talk.

22 Likes

I wonder how technically feasible it would be to have the core in Julia instead of C++ but still have an interface in Python.

7 Likes

I currently use Julia with the Python ecosystem. It is my preferred environment. Julia calls Python transparently and can use any library. For example, I am implementing fast numerical Julia code on top of Huggingface models. There is no need to choose between Julia and the Python ecosystem - use both.

5 Likes

This would be very nice.

See also Where do the 2000+ PyTorch operators come from?: More than you wanted to know. This kind of composability issue is where Julia shines.

1 Like

I did not know that PyTorch uses multiple dispatch - where can I read more?

It makes sense not wanting to leave the Python ecosystem - as of today, of course. As a total newbie, I feel like PyTorch needs to push more into the production ecosystem, and Python is the strongest language for doing that (given the popularity of MLOps, data engineering and such, and libraries like Prefect). Still, I would not rule out tighter integration with Julia, even if just as a standard to look to when hacking on performance.

1 Like

I’d recommend looking at Let’s talk about the PyTorch dispatcher (ezyang’s blog)

and What (and Why) is __torch_dispatch__?.

4 Likes

That’s a real advantage of Python, and I think that’s a perfectly fine response. But it’s worth noticing that’s exactly how disruption happens. Everyone talks about how it would be too much effort to switch to the next big thing because of the sunk costs; then someone puts the effort into switching to the next big thing, and suddenly everyone else is left scrambling to catch up.

Arguably PyTorch has already been disrupted by JAX, and in particular, the huge performance benefits of JAX’s JIT compiler. The real question is where that disruption leaves PyTorch, and whether PyTorch’s better interface and design can let it survive, or if it needs to change at a fundamental level to keep up. That might require finding some way to combine the clean interface of PyTorch with the performance of JAX, e.g. taking full advantage of metaprogramming and multiple dispatch to generate more efficient code.

1 Like

That might require finding some way to combine the clean interface of PyTorch with the performance of JAX, e.g. taking full advantage of metaprogramming and multiple dispatch to generate more efficient code.

You may be interested in TorchDynamo: An Experiment in Dynamic Python Bytecode Transformation

1 Like

I find it pretty interesting in theory, but I’m skeptical this will really end up working, simply because compiling bytecode seems much harder than just compiling a language.
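For readers unfamiliar with what "compiling bytecode" means here: Python's standard-library `dis` module shows the instruction stream a tool like TorchDynamo has to intercept and analyze. This snippet only inspects bytecode and performs no transformation; exact opcode names vary across CPython versions:

```python
import dis

def f(x):
    return x * x + 1

# A bytecode-level tool sees f not as source text but as a stream of
# instructions like LOAD_FAST / BINARY_OP / RETURN_VALUE.
instructions = [ins.opname for ins in dis.get_instructions(f)]
print(instructions)
```

Working at this level means recovering program structure (loops, calls, data flow) that a source-level compiler gets for free, which is the source of the difficulty being debated here.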

At this point, so much time and effort has been put into optimizing Python that I think it would’ve just been faster to rewrite much of the Python ecosystem in another language, or adopt tools like PyCall for most use cases. As things stand, there’s just so much C++ code that has to be maintained, rewritten, or overhauled for performance reasons every time a new project like JAX or TorchDynamo comes around. Things would be much easier if we didn’t have to fight Python’s design every step of the way by taking crazy approaches like recompiling bytecode after it’s already been compiled.

1 Like