JAX 001 - A closer look at its background

I recently wanted to get into parallel programming so that I could optimize one of the optics simulation projects I was working on (it was built without much focus on distributed training). So, with all the buzz around JAX and some cool applications I found (listed below), I thought I'd give it a try.

Why JAX?

With all the established libraries such as PyTorch and TensorFlow, why do we need a new library in the first place? Let's find out.

We all know that increasing FLOPS (floating point operations per second) is a huge deal in machine learning for training models efficiently. JAX helps with this goal by enabling researchers to write Python programs that are automatically compiled and scaled to utilize accelerators (GPUs/TPUs). It is often hard to write optimized Python code that leverages the full potential of hardware accelerators, and JAX aims to strike a balance between a research-friendly programming experience and hardware acceleration.

To do so, JAX accelerates pure-and-statically-composed (PSC) subroutines. This is done through a just-in-time (JIT) compiler that traces PSC routines. To learn more about PSC subroutines, please refer to the NeurIPS 2020: JAX Ecosystem Meetup video. For this compilation process, the execution of the code only needs to be monitored once; hence the name JAX, which stands for "Just After eXecution". (Source: JAX paper, https://cs.stanford.edu/~rfrostig/pubs/jax-mlsys2018.pdf)
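To get a feel for what that means in code, here is a toy sketch of my own (not taken from the paper): a simple PSC function handed to jax.jit, which traces it once and compiles it for the available accelerator.

```python
import jax
import jax.numpy as jnp

# A toy "pure and statically composed" (PSC) function: its output depends only
# on its inputs, and the sequence of primitive operations it performs does not
# depend on the input values.
def predict(w, x):
    return jnp.tanh(jnp.dot(x, w))

# jax.jit traces the function once (with abstract values), hands the traced
# computation to the compiler, and reuses the compiled executable afterwards.
predict_fast = jax.jit(predict)

w = jnp.ones((3, 2))
x = jnp.ones((4, 3))
print(predict_fast(w, x).shape)  # (4, 2)
```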

Key features of JAX

JAX addresses several limitations of NumPy. On top of a NumPy-like API, it adds automatic differentiation (autodiff), JIT compilation, vectorization, and parallelization:

> JAX is Autograd and XLA, brought together for high-performance numerical computing and machine learning research. It provides composable transformations of Python+NumPy programs: differentiate, vectorize, parallelize, Just-In-Time compile to GPU/TPU, and more.

Figure 1: Credit: {mattjj, frostig, leary, dougalm, phawkins, skyewm, jekbradbury, necula} @google.com. Full set of slides can be found here.
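To make the quote concrete, here is a small sketch of my own that composes these transformations (grad, vmap, and jit) on a toy loss function:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    pred = jnp.dot(x, w)
    return jnp.mean((pred - y) ** 2)

# Composable transformations: differentiate the loss, vectorize the gradient
# over a batch of candidate weight vectors, and JIT-compile the result.
grad_loss = jax.grad(loss)                                   # d(loss)/d(w)
batched_grads = jax.vmap(grad_loss, in_axes=(0, None, None))
fast_batched_grads = jax.jit(batched_grads)

ws = jnp.ones((8, 3))   # 8 candidate weight vectors
x = jnp.ones((16, 3))   # 16 data points with 3 features each
y = jnp.zeros(16)
print(fast_batched_grads(ws, x, y).shape)  # (8, 3)
```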

Getting familiar with related concepts and history

There were a few loosely defined terms in my mind when I first read through the documentation, so I thought I'd look into them and get a clearer idea of what they mean.

Autograd - Autograd is an automatic differentiation library released in 2014. It is a lightweight tool to automatically differentiate native Python and NumPy code, but it doesn't target hardware accelerators such as GPUs/TPUs. Since 2018, the main developers of Autograd have therefore been focusing on JAX, which has more functionality. Dougal, one of the main authors of Autograd, calls JAX the second generation of the project that started with Autograd.

XLA - "XLA (Accelerated Linear Algebra) is a compiler-based linear algebra execution engine. It is the backend that powers machine learning frameworks such as TensorFlow and JAX at Google, on a variety of devices including CPUs, GPUs, and TPUs." (to read more, see the XLA documentation)
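As a quick self-made illustration of this hand-off, jax.make_jaxpr shows the intermediate representation (a "jaxpr") that JAX traces a function into before it is lowered to and compiled by XLA:

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0 + 1.0

# Prints something like:
# { lambda ; a:f32[3]. let
#     b:f32[3] = sin a
#     c:f32[3] = mul b 2.0
#     d:f32[3] = add c 1.0
#   in (d,) }
print(jax.make_jaxpr(f)(jnp.ones(3)))
```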

JIT - (Just-in-time) compilation: this compiles a Python function written with JAX so that it can be executed efficiently by XLA.
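Here is a minimal sketch (again my own example) of jit in action. The first call triggers tracing and XLA compilation; later calls with the same shapes and dtypes reuse the cached executable:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def normalize(x):
    return (x - x.mean()) / x.std()

x = jnp.arange(1_000_000, dtype=jnp.float32)

# First call: tracing + XLA compilation + execution.
t0 = time.perf_counter()
normalize(x).block_until_ready()
print("first call :", time.perf_counter() - t0)

# Second call: the compiled executable is reused, so it is much faster.
t0 = time.perf_counter()
normalize(x).block_until_ready()
print("second call:", time.perf_counter() - t0)
```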

Digging into these concepts made me realize that while the libraries we use make it really easy to build end-to-end differentiable models, under the hood these tools are the result of years-long efforts by amazing teams, with cool design choices.

In the next write-up, I'm going to share my experience with the device parallelization (pmap) and automatic vectorization (vmap) aspects of JAX.

Resources:

  1. JAX documentation: https://jax.readthedocs.io/en/latest/index.html
  2. Compiling machine learning programs via high-level tracing [pdf]
  3. Matthew Johnson - JAX: accelerated ML research via composable function transformations (video) [slides]
  4. JAX at DeepMind (video) [slides]
  5. Lecture 6: Automatic Differentiation [pdf]
  6. 🌟 Accelerate your cutting-edge machine learning research with free Cloud TPUs (apply here).