You've probably heard of TensorFlow and PyTorch, and maybe even MXNet, but there is a new kid on the block of machine learning frameworks: Google's JAX.

Over the last two years, JAX has been taking deep learning research by storm, facilitating the implementation of Google's Vision Transformer (ViT) and powering research at DeepMind.

So what is so exciting about the new JAX framework?

JAX at Large

Boiled down, JAX is Python's NumPy with automatic differentiation, compiled to run efficiently on accelerators such as GPUs and TPUs. Because writing JAX feels almost the same as writing NumPy, it has become popular with machine learning practitioners.
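To make the NumPy comparison concrete, here is a minimal sketch that writes the same function once with numpy and once with jax.numpy. The softplus function and the variable names are illustrative choices, not something taken from the JAX documentation.

```python
import numpy as np
import jax.numpy as jnp

def softplus_np(x):
    # Plain NumPy version
    return np.log1p(np.exp(x))

def softplus_jnp(x):
    # Same body, with np swapped for jnp
    return jnp.log1p(jnp.exp(x))

x = np.linspace(-2.0, 2.0, 5).astype(np.float32)
out_np = softplus_np(x)
out_jnp = softplus_jnp(x)  # accepts the NumPy array and returns a JAX array

# The two results should agree up to float32 rounding
print(np.allclose(out_np, np.asarray(out_jnp), atol=1e-5))
```

Only the module prefix changes between the two versions, which is the sense in which moving from NumPy to JAX is close to seamless for simple numerical code.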

JAX offers four main function transformations that make it efficient to use when executing deep learning workloads.

JAX's Four Function Transformations

grad - automatically differentiates a function, which is what backpropagation needs. Because grad returns an ordinary function, you can apply it repeatedly to get derivatives of any order.

from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743
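As a sketch of what taking grad to a higher order looks like, the snippet below applies grad twice to the same tanh function and compares the results against the closed-form derivatives. The names d1, d2, and the expected_* variables are introduced here for illustration.

```python
import jax.numpy as jnp
from jax import grad

def tanh(x):
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

d1 = grad(tanh)  # first derivative
d2 = grad(d1)    # second derivative: grad applied to the gradient function

x = 1.0
first = d1(x)
second = d2(x)

# Closed forms for comparison:
#   d/dx tanh(x)   = 1 - tanh(x)^2
#   d2/dx2 tanh(x) = -2 * tanh(x) * (1 - tanh(x)^2)
t = jnp.tanh(x)
expected_first = 1.0 - t ** 2
expected_second = -2.0 * t * (1.0 - t ** 2)

print(first, expected_first)
print(second, expected_second)
```

Each call to grad returns a regular Python function, so the same pattern extends to third and higher derivatives.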

jit - compiles a function with XLA so that its operations run efficiently, for example by fusing element-wise operations. It can also be used as a function decorator.

import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

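x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
# %timeit is an IPython/Jupyter magic. JAX dispatches work asynchronously,
# so block_until_ready() makes each timing include the actual computation.
%timeit -n10 -r3 fast_f(x).block_until_ready()  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x).block_until_ready()  # ~ 14.5 ms / loop (also on GPU via JAX)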
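The decorator form mentioned above might look like the following minimal sketch. It reuses the same arithmetic as slow_f, with a smaller array so that it runs quickly on a CPU as well; the array size is an arbitrary choice.

```python
import jax.numpy as jnp
from jax import jit

@jit
def fast_f(x):
    # Same element-wise arithmetic as slow_f, compiled on the first call
    return x * x + x * 2.0

x = jnp.ones((1000, 1000))

# The first call traces and compiles; later calls with the same input
# shape and dtype reuse the compiled version.
out = fast_f(x).block_until_ready()
print(out.shape, out[0, 0])
```

Calling block_until_ready() waits for the result to finish computing, which matters mainly when you are measuring run time.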

vmap - vectorizes a function by mapping it over an array axis. You write the function for a single example, and vmap handles the batch dimension for you.

from jax import vmap  # predict, params, and input_batch are defined elsewhere
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
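The one-liner above leaves predict, params, and input_batch undefined. Here is a self-contained sketch with hypothetical stand-ins for all three: a single dense layer as the model, random weights, and a batch of 8 four-dimensional inputs.

```python
import jax.numpy as jnp
from jax import random, vmap

def predict(params, x):
    # Hypothetical model written for ONE example: dense layer followed by tanh
    W, b = params
    return jnp.tanh(jnp.dot(W, x) + b)

key_w, key_x = random.split(random.PRNGKey(0))
params = (random.normal(key_w, (3, 4)), jnp.zeros(3))
input_batch = random.normal(key_x, (8, 4))  # 8 examples, 4 features each

# in_axes=(None, 0): share params across the batch, map over axis 0 of the inputs
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)

# Reference result from an explicit Python loop over the batch
manual = jnp.stack([predict(params, x) for x in input_batch])

print(predictions.shape)
print(jnp.allclose(predictions, manual, atol=1e-5))
```

The predict function never mentions the batch dimension; vmap adds it, and the result should match the explicit loop.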

pmap - maps a function across multiple devices, such as the GPUs in a multi-GPU machine, and runs the copies in parallel.

from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
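The example above assumes a machine with 8 devices. As a hedged variation, the sketch below sizes the mapped axis from jax.local_device_count(), so it should also run on a single CPU or GPU. The smaller matrix shape is an arbitrary choice to keep it fast.

```python
import jax
import jax.numpy as jnp
from jax import pmap, random

# Number of devices visible to this process (1 on a typical laptop)
n = jax.local_device_count()

# One PRNG key and one random matrix per device
keys = random.split(random.PRNGKey(0), n)
mats = pmap(lambda key: random.normal(key, (50, 60)))(keys)

# Local matmul on each device, then a per-device mean
result = pmap(lambda x: jnp.dot(x, x.T))(mats)
means = pmap(jnp.mean)(result)

print(result.shape, means.shape)
```

The leading axis of every pmap input has to be no larger than the number of available devices, which is why the example derives it at run time instead of hard-coding 8.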

JAX vs PyTorch

The closest machine learning framework to JAX is PyTorch, because both strive to be as "NumPy-esque" as possible.

JAX's composable, lower-level function transformations make it preferable for certain research tasks.

That said, PyTorch offers a much greater breadth of libraries and utilities, pre-trained and pre-written network definitions, a data loader, and portability to deployment destinations.

JAX vs TensorFlow

JAX and TensorFlow were both written by Google. From my initial experimentation, JAX seems much easier to develop in and is more intuitive.

That said, JAX lacks the extensive infrastructure that TensorFlow has built over the years: open source projects, pre-trained models, tutorials, higher-level abstractions (via Keras), and portability to deployment destinations.

What does JAX lack?

  • A Data Loader - you'll need to implement your own or hop over to TensorFlow or PyTorch to borrow one.
  • Higher level model abstractions
  • Deployment portability

When should I use JAX?

JAX is a new machine learning framework that has been gaining popularity in machine learning research.

If you're operating in the research realm, JAX is a good option for your project.

If you're actively developing an application, PyTorch and TensorFlow will get you there faster. And of course, in computer vision there is always a tradeoff to weigh between building and buying your tooling.

Thanks for reading our writeup on JAX! Happy training, and of course, happy inferencing!