I have seen increasing usage of JAX/FLAX code on GitHub. However, as a person with a PyTorch brain, I find it difficult to understand JAX code at first glance, not to mention porting JAX code to PyTorch.
After spending some time learning JAX/FLAX, I find that they are actually very similar to PyTorch, and that we can build a simple mapping between PyTorch code and JAX/FLAX code.
To be clear, JAX (https://jax.readthedocs.io/) is the computation engine, and FLAX (https://flax.readthedocs.io/) is the neural network library built on top of JAX.
My first observation: JAX is stateless PyTorch.
JAX favors functional programming, so it hates stateful functions. By squeezing the state (mainly params) out of PyTorch, we can get the JAX API:
The middle column, labeled “core function,” represents a set of high-level functions or operations that are typical in a deep learning training loop. The PyTorch code and the JAX code are mapped to these core functions to highlight the similarities and differences between them.
- Initialization of Models and Parameters:
  - PyTorch: the model is initialized with `MyMod(arg_model)`, and its parameters are accessible via `model.params`.
  - JAX: the model parameters are initialized directly with `init_params(key)`.
- Initialization of Optimizers:
  - PyTorch: the optimizer is set up with the model's parameters via `MyOpt(model.params, arg_opt)`.
  - JAX: the optimizer is initialized with its arguments and the parameters via `MyOpt(arg_opt, params)`, and its state with `opt.init(params)`.
- Forward Pass:
  - PyTorch: the forward pass is done directly via `model(x)`.
  - JAX: the forward pass is executed through a function, `model.apply(params, x)`.
- Loss Calculation:
  - Both PyTorch and JAX use `loss_f(y, target)` to calculate the loss.
- Gradient Computation:
  - PyTorch: gradients are computed implicitly with `loss.backward()`.
  - JAX: gradients are computed explicitly with `jax.value_and_grad(loss_func)(params, x, target)`.
- Parameter Updates:
  - PyTorch: the optimizer's `step()` method is used to update the model's parameters.
  - JAX: the update is performed via `opt.step(grads, opt_state, params)`, producing new optimizer state and parameters.
From this mapping, we can draw the conclusion that “JAX is stateless PyTorch” because:
- Explicitness: JAX is more explicit in its operations, such as gradient computation and parameter updates.
- Statelessness: Instead of modifying state in place (like PyTorch), JAX employs a functional approach where you explicitly pass states (like `params` and `opt_state`) and get back new states.
This mapping to core functions emphasizes the functional nature of JAX versus the more object-oriented nature of PyTorch.
Disclaimer: the code is conceptual only. Detailed function signatures/names are simplified.
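To make the mapping concrete, here is a minimal sketch of one training step in each framework. On the JAX side I use optax as a stand-in for the figure's simplified `MyOpt`/`opt.step`; `MyMod`, `arg_model`, and `loss_f` are the same conceptual placeholders as above, and `x_example` stands for a sample input batch.

```python
import torch

# PyTorch: state (params, grads, optimizer buffers) lives inside objects.
model = MyMod(arg_model)                        # conceptual module from the figure
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

def train_step(x, target):
    y = model(x)                                # forward pass uses the internal params
    loss = loss_f(y, target)
    opt.zero_grad()
    loss.backward()                             # grads are stored on the params
    opt.step()                                  # params are updated in place
    return loss.item()
```

And the stateless JAX/FLAX counterpart:

```python
import jax
import optax

# JAX/FLAX: all state is passed in and returned explicitly.
model = MyMod(arg_model)                        # flax.linen.Module, holds no params
params = model.init(jax.random.PRNGKey(0), x_example)  # params are a separate pytree
opt = optax.sgd(learning_rate=1e-3)
opt_state = opt.init(params)

def loss_func(params, x, target):
    y = model.apply(params, x)                  # forward pass takes params explicitly
    return loss_f(y, target)

@jax.jit
def train_step(params, opt_state, x, target):
    loss, grads = jax.value_and_grad(loss_func)(params, x, target)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss              # new states are returned, nothing is mutated
```

Note how the JAX `train_step` returns new `params` and `opt_state` instead of mutating anything, which is exactly the statelessness point above.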
My second observation: PyTorch does its work when you create new `nn.Module` instances, while FLAX does its work when you subclass `flax.linen.Module`.
PyTorch mainly works through the `super().__init__()` call, which is explicit and easy to understand. FLAX, however, mainly works through `class MyMod(flax.linen.Module)`. Yes, it is true that a lot can happen while a subclass is being created, but most Python users never use this feature.
FLAX heavily uses `__init_subclass__` to transform every subclass of `flax.linen.Module`. That's why people feel FLAX is black magic.
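To take some of the magic out of it, here is a tiny, self-contained sketch, my own toy analogy rather than FLAX's actual implementation, of how `__init_subclass__` lets a base class rewrite every subclass at definition time, turning type annotations into constructor fields:

```python
import dataclasses

class Module:
    """Toy base class: every subclass is transformed the moment it is defined."""
    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Turn the subclass into a frozen dataclass: the type annotations on the
        # class body become constructor arguments, so no __init__ is needed.
        dataclasses.dataclass(frozen=True)(cls)

class MyMod(Module):
    features: int = 16        # annotation becomes a field, much like in flax.linen

m = MyMod(features=32)
print(m)                      # MyMod(features=32), although MyMod defines no __init__
```

FLAX's hook does considerably more than this, but the mechanism is the same: the transformation happens at `class MyMod(flax.linen.Module)` time, not at `super().__init__()` time.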
This figure shows the difference between a PyTorch model definition and a FLAX model definition:
Here are five differences between the PyTorch and FLAX model definitions shown in the figure:
- In the FLAX model, there is no `__init__` function, unlike the PyTorch model (it relies on type annotations instead).
- The FLAX model uses a single function, `init_params`, to unify initialization and retrieval of parameters. In contrast, the PyTorch model sets up parameters directly within the `__init__` method. (Note that this function name is conceptual only, and the function can be omitted if the module has no parameters.)
- In the FLAX model, the `__call__` function combines initialization/retrieval of parameters and the forward pass into one function. The PyTorch model separates these into the `__init__` and `forward` methods. (FLAX can also use `setup` to create submodules, but it does not initialize parameters either.)
- For random initialization in FLAX, an explicit seed (key) and an example input are required, whereas in PyTorch they are not.
- The forward pass in the FLAX model requires the parameters to be passed explicitly, while in PyTorch the parameters are bound to the instance and do not need to be passed during the forward pass.
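To illustrate these differences side by side, here is a minimal sketch of the same small model in both frameworks; the model, its layer sizes, and the input shape are my own examples rather than the ones in the figure.

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, hidden: int, out: int):
        super().__init__()                       # explicit __init__; params are created here
        self.fc1 = nn.Linear(32, hidden)
        self.fc2 = nn.Linear(hidden, out)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = MLP(hidden=64, out=10)
y = model(torch.randn(2, 32))                    # no key or params needed: state lives inside
```

And the FLAX version:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
    hidden: int                                  # type annotations replace __init__
    out: int

    @nn.compact
    def __call__(self, x):                       # init/retrieval of params + forward pass
        x = nn.relu(nn.Dense(self.hidden)(x))
        return nn.Dense(self.out)(x)

model = MLP(hidden=64, out=10)
x = jnp.ones((2, 32))
params = model.init(jax.random.PRNGKey(0), x)    # explicit key and example input
y = model.apply(params, x)                       # params passed explicitly on every call
```

Note how the FLAX version has no `__init__` and stores no parameters: everything flows through `init` and `apply`.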
I hope this blog helps you if you want to port JAX/FLAX code to PyTorch, or, vice versa, PyTorch code to JAX/FLAX.