Helping a PyTorch brain understand JAX/FLAX code

I have seen increasing usage of JAX/FLAX code on GitHub. However, as a person with a PyTorch brain, I find it difficult to understand JAX code at first glance, let alone port JAX code to PyTorch.

After spending some time learning JAX/FLAX, I find that they are actually very similar to PyTorch, and we can build a simple mapping between PyTorch code and JAX/FLAX code.

To be clear, JAX (https://jax.readthedocs.io/) is the computation engine, and FLAX (https://flax.readthedocs.io/) is the neural network library built on top of JAX.

My first observation: JAX is stateless PyTorch.

JAX favors functional programming, so it hates stateful functions. By squeezing the state (mainly params) out of PyTorch, we get the JAX API:

The middle column of the figure, labeled “core function,” lists the high-level operations of a typical deep learning training loop. The PyTorch code and the JAX code are both mapped to these core functions to highlight their similarities and differences (a runnable sketch of the same mapping follows the list).

  1. Initialization of Models and Parameters:

    • PyTorch: The model is initialized with MyMod(arg_model) and its parameters are accessible via model.params.
    • JAX: The model parameters are initialized directly with init_params(key).
  2. Initialization of Optimizers:

    • PyTorch: The optimizer is set up with the model’s parameters via MyOpt(model.params, arg_opt).
    • JAX: The optimizer is initialized with arguments and parameters via MyOpt(arg_opt, params) and its state with opt.init(params).
  3. Forward Pass:

    • PyTorch: Forward pass is directly done via model(x).
    • JAX: The forward pass is executed by calling the pure function model.apply(params, x), with the parameters passed in explicitly.
  4. Loss Calculation:

    • Both PyTorch and JAX use loss_f(y, target) to calculate the loss.
  5. Gradient Computation:

    • PyTorch: loss.backward() computes the gradients and stores them implicitly on each parameter's .grad attribute.
    • JAX: Gradients are returned explicitly by jax.value_and_grad(loss_func)(params, x, target).
  6. Parameter Updates:

    • PyTorch: The optimizer's step() method updates the model's parameters in place.
    • JAX: The update is performed via opt.step(grads, opt_state, params), which returns new optimizer state and new parameters instead of mutating the old ones.
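
To make the mapping concrete, here is a minimal runnable sketch of one training step in both frameworks. The toy model, data, and hyperparameters are my own; on the JAX side I use Optax as the optimizer library (the MyOpt / opt.step names above are conceptual, not a real API).

```python
import torch

# ---- PyTorch: stateful training step ------------------------------------
model = torch.nn.Linear(3, 1)                      # params live inside the module
opt = torch.optim.SGD(model.parameters(), lr=0.1)  # optimizer state lives inside opt

x = torch.randn(8, 3)
target = torch.randn(8, 1)

y = model(x)                                       # forward pass: model(x)
loss = torch.nn.functional.mse_loss(y, target)
loss.backward()                                    # grads stored on each param.grad
opt.step()                                         # params updated in place
opt.zero_grad()

# ---- JAX: the same step, with all state passed around explicitly ---------
import jax
import jax.numpy as jnp
import optax                                       # assumption: Optax as the optimizer library

def forward(params, xs):                           # stands in for model.apply(params, x)
    return xs @ params["w"] + params["b"]

def loss_fn(params, xs, tgt):
    return jnp.mean((forward(params, xs) - tgt) ** 2)

key = jax.random.PRNGKey(0)
params = {"w": jax.random.normal(key, (3, 1)),     # init_params(key): explicit init
          "b": jnp.zeros((1,))}

opt = optax.sgd(learning_rate=0.1)
opt_state = opt.init(params)                       # optimizer state is an explicit value

xj = jax.random.normal(jax.random.PRNGKey(1), (8, 3))
targetj = jax.random.normal(jax.random.PRNGKey(2), (8, 1))

loss, grads = jax.value_and_grad(loss_fn)(params, xj, targetj)  # grads returned explicitly
updates, opt_state = opt.update(grads, opt_state, params)       # new optimizer state
params = optax.apply_updates(params, updates)                   # new params, old ones untouched
```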

From this mapping, we can draw the conclusion that “JAX is stateless PyTorch” because:

  • Explicitness: JAX is more explicit in its operations, such as gradient computation and parameter updates.
  • Statelessness: Instead of modifying state in-place (like PyTorch), JAX employs a functional approach where you explicitly pass states (like params and opt_state) and get back new states.

This mapping to core functions emphasizes the functional nature of JAX versus the more object-oriented nature of PyTorch.

Disclaimer: the code is conceptual only. Detailed function signatures/names are simplified.

My second observation: PyTorch does its work when you create new nn.Module instances, while FLAX does its work when you inherit from flax.linen.Module.

PyTorch mainly works through the super().__init__() call, which is explicit and easy to understand. FLAX, however, mainly works through class MyMod(flax.linen.Module). It is true that a lot can happen while a class is being inherited, but most Python users never touch this feature.

FLAX heavily uses __init_subclass__ to transform every subclass of flax.linen.Module. That is why FLAX feels like black magic to many people.
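
As a toy illustration of that mechanism (not Flax's actual implementation, which does far more), here is a base class whose __init_subclass__ hook turns every subclass into a dataclass, so subclasses get an __init__ they never wrote:

```python
import dataclasses

class AutoDataclass:
    # Toy version of the trick: the base class rewrites every subclass
    # at class-definition time. flax.linen.Module's real hook also sets up
    # naming, scoping, and parameter management.
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        dataclasses.dataclass(cls)          # the subclass becomes a dataclass in place

class MyMod(AutoDataclass):
    features: int = 16                      # looks like a bare type annotation...

m = MyMod(features=32)                      # ...but an __init__ was generated for us
print(m)                                    # MyMod(features=32)
```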

This figure shows the difference between a PyTorch model definition and a Flax model definition:

Here are five differences between the PyTorch and Flax model definitions shown in the figure (a minimal side-by-side sketch also follows the list):

  1. In the Flax model there is no __init__ method, unlike the PyTorch model (fields are declared with type annotations instead).
  2. The Flax model uses a single function, init_params, to unify initialization and retrieval of parameters, whereas the PyTorch model sets up its parameters directly in __init__. (Note that this function name is conceptual only, and it can be omitted if the module has no parameters.)
  3. In the Flax model, the __call__ function combines initialization/retrieval of parameters and the forward pass in one place; the PyTorch model separates these into __init__ and forward. (Flax can also use setup to create submodules, but setup does not initialize parameters either.)
  4. In Flax, random initialization requires an explicit seed (key) and an example input, whereas PyTorch needs neither.
  5. The forward pass in the Flax model requires explicit parameters to be passed, while in PyTorch, the parameters are bound to the instance and do not need to be explicitly passed during the forward pass.
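
To illustrate these differences, here is a minimal sketch of the same two-layer MLP written both ways. The layer sizes and the class names TorchMLP / FlaxMLP are my own; the Flax side uses the standard flax.linen API (nn.compact, model.init, model.apply).

```python
import torch

# ---- PyTorch: __init__ creates and owns the params, forward uses them ----
class TorchMLP(torch.nn.Module):
    def __init__(self, hidden, out):
        super().__init__()
        self.fc1 = torch.nn.Linear(4, hidden)    # params created here
        self.fc2 = torch.nn.Linear(hidden, out)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

torch_model = TorchMLP(hidden=8, out=1)
y = torch_model(torch.randn(2, 4))               # params are bound to the instance

# ---- Flax: no __init__; fields are type annotations; __call__ both -------
# ---- declares submodules and runs the forward pass -----------------------
import jax
import jax.numpy as jnp
import flax.linen as nn

class FlaxMLP(nn.Module):
    hidden: int
    out: int

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden)(x)             # submodule declared inline
        x = jax.nn.relu(x)
        return nn.Dense(self.out)(x)

flax_model = FlaxMLP(hidden=8, out=1)
x = jnp.ones((2, 4))
variables = flax_model.init(jax.random.PRNGKey(0), x)   # explicit key + example input
y = flax_model.apply(variables, x)                      # params passed explicitly
```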

Hope this post helps you if you want to port JAX/FLAX code to PyTorch, or vice versa, port PyTorch code to JAX/FLAX.
