So what is so exciting about the new JAX framework?
Understanding JAX for Machine Learning
JAX is, in essence, Python's NumPy with automatic differentiation, optimized to run on GPUs. The seamless translation between writing NumPy code and writing JAX code has made JAX popular with machine learning practitioners.
JAX offers four main function transformations that make it efficient for deep learning workloads.
JAX's Four Function Transformations
grad - automatically differentiates a function for backpropagation. You can apply grad repeatedly to take derivatives of any order.
```python
from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743
```
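Because `grad` returns an ordinary Python function, you can apply it to its own output to get higher-order derivatives. Here is a minimal sketch that reuses the hand-written `tanh` from above, redefined so the block runs on its own, and checks the second derivative against the closed form -2 tanh(x)(1 - tanh(x)^2):

```python
import jax.numpy as jnp
from jax import grad

def tanh(x):
    # Same hand-written tanh as in the example above
    y = jnp.exp(-2.0 * x)
    return (1.0 - y) / (1.0 + y)

# Each application of grad differentiates the previous function once more
d1 = grad(tanh)               # first derivative: 1 - tanh(x)**2
d2 = grad(grad(tanh))         # second derivative: -2 tanh(x) (1 - tanh(x)**2)
d3 = grad(grad(grad(tanh)))   # third derivative

x = 1.0
print(d1(x), d2(x), d3(x))

# Cross-check the second derivative against its closed form
t = jnp.tanh(x)
print(jnp.allclose(d2(x), -2.0 * t * (1.0 - t**2), atol=1e-5))
```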
jit - compiles your functions with XLA so their operations run efficiently, for example by fusing element-wise operations. It can also be used as a function decorator.
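```python
import jax.numpy as jnp
from jax import jit

def slow_f(x):
    # Element-wise ops see a large benefit from fusion
    return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
# %timeit is IPython/Jupyter magic. JAX dispatches work asynchronously, so careful
# benchmarks should time fast_f(x).block_until_ready() instead.
%timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)
```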
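Here is the decorator form of `jit` mentioned above, as a minimal sketch with the same element-wise function. It uses a smaller array so it runs comfortably on a CPU. `block_until_ready()` is a JAX array method that waits for the asynchronously dispatched computation to finish.

```python
import jax.numpy as jnp
from jax import jit

@jit  # equivalent to defining f and then writing fast_f = jit(f)
def fast_f(x):
    # Same element-wise computation as slow_f above; XLA can fuse it into one kernel
    return x * x + x * 2.0

x = jnp.ones((1000, 1000))

# The first call traces and compiles; later calls with the same shape and dtype
# reuse the compiled code. block_until_ready() waits for the result to be computed.
out = fast_f(x).block_until_ready()
print(out[0, 0])  # every element is 1*1 + 1*2 = 3
```

The first call pays a one-time tracing and compilation cost, which is why benchmarks usually discard the first run.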
vmap - automatically vectorizes a function across an array axis. This means you don't have to track dimensions as carefully when, for example, passing a batch through a function written for a single example.
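```python
from jax import vmap  # predict, params and input_batch are assumed to be defined elsewhere
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)  # share params; map over axis 0 of input_batch
```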
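The one-liner above assumes that `predict`, `params`, and `input_batch` already exist. Here is a self-contained sketch in which all three are hypothetical stand-ins: a single dense layer written for one example, vectorized over a batch with `vmap`, and checked against an explicit Python loop.

```python
import jax.numpy as jnp
from jax import random, vmap

# Hypothetical single-example model: one dense layer with tanh.
# predict, params and input_batch are illustrative names, not a JAX API.
def predict(params, x):
    W, b = params
    # x has shape (4,): a single example with no batch dimension
    return jnp.tanh(W @ x + b)

key_w, key_x = random.split(random.PRNGKey(0))
params = (random.normal(key_w, (3, 4)), jnp.zeros(3))
input_batch = random.normal(key_x, (8, 4))  # batch of 8 examples

# in_axes=(None, 0): don't map over params, map over axis 0 of input_batch
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
print(predictions.shape)  # (8, 3)

# Same result as an explicit Python loop over the batch
looped = jnp.stack([predict(params, x) for x in input_batch])
print(jnp.allclose(predictions, looped, atol=1e-5))
```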
pmap - parallelizes a function across multiple devices, such as multiple GPUs, by running a copy of it on each device.
```python
from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU (requires 8 devices)
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
```
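The example above needs a machine with eight accelerators. The sketch below runs on any machine, including a CPU-only laptop where `jax.local_device_count()` is 1. It sizes the leading axis to the available devices and adds a cross-device collective, `jax.lax.psum`, which only works inside a mapped function that declares an `axis_name`. The name `"devices"` is arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import pmap, random

# Use however many local devices exist (1 on a plain CPU machine, 8 on an 8-GPU host)
n = jax.local_device_count()

# pmap maps over the leading axis, so its size must not exceed the device count
keys = random.split(random.PRNGKey(0), n)
mats = pmap(lambda key: random.normal(key, (500, 600)))(keys)  # shape (n, 500, 600)

def local_and_global_mean(x):
    # Each device computes its own matmul and mean
    local = jnp.mean(jnp.dot(x, x.T))
    # psum sums `local` across every device on the "devices" axis
    return local, jax.lax.psum(local, axis_name="devices") / n

local_means, global_means = pmap(local_and_global_mean, axis_name="devices")(mats)
print(local_means.shape)  # (n,)
print(global_means)       # the same averaged value, replicated on every device
```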
JAX vs PyTorch for Machine Learning
The machine learning framework closest to JAX is PyTorch, because both share roots in striving to be as "numpy-esque" as possible.
JAX's lower-level, composable function transformations make it preferable for certain research tasks.
That said, PyTorch offers a much greater breadth of libraries and utilities, pre-trained models and ready-made network definitions, a data loader, and portability to deployment destinations.
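To show why composable transformations appeal to researchers, the sketch below stacks three of them to compute one gradient per example in a batch. The squared-error `loss` and the random data are hypothetical placeholders.

```python
import jax.numpy as jnp
from jax import grad, jit, random, vmap

# Hypothetical linear-regression loss for a single (x, y) example
def loss(w, x, y):
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

key_x, key_y = random.split(random.PRNGKey(0))
w = jnp.zeros(3)
xs = random.normal(key_x, (16, 3))
ys = random.normal(key_y, (16,))

# grad: d(loss)/dw for one example; vmap: over the batch; jit: compile the whole thing
per_example_grads = jit(vmap(grad(loss), in_axes=(None, 0, 0)))(w, xs, ys)
print(per_example_grads.shape)  # (16, 3): one gradient per example

def mean_loss(w):
    # Batch-averaged loss, also built with vmap
    return jnp.mean(vmap(loss, in_axes=(None, 0, 0))(w, xs, ys))

# Averaging the per-example gradients recovers the gradient of the mean loss
print(jnp.allclose(per_example_grads.mean(axis=0), grad(mean_loss)(w), atol=1e-5))
```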
JAX vs TensorFlow for Machine Learning
JAX and TensorFlow were both developed by Google. From my initial experimentation, JAX seems much easier to develop in and more intuitive.
That said, JAX lacks the extensive infrastructure that TensorFlow has built over the years, whether open source projects, pre-trained models, tutorials, higher level abstractions (via Keras), or portability to deployment destinations.
What Does JAX Lack?
- A Data Loader - you'll need to implement your own or hop over to PyTorch to borrow one.
- Higher level model abstractions
- Deployment portability
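If you build your own data loader, it can start very small. This sketch assumes your dataset fits in memory as NumPy arrays. `numpy_batches` is a hypothetical helper, not part of JAX. It shuffles once per call and drops the final incomplete batch so that every batch has the same shape. This matters because `jit` recompiles a function when it sees a new input shape.

```python
import numpy as np
import jax.numpy as jnp

def numpy_batches(x, y, batch_size, *, shuffle=True, seed=0):
    """Hypothetical helper: yield (x_batch, y_batch) pairs as JAX arrays.

    Shuffles indices once per call (i.e. once per epoch) and drops the last
    incomplete batch so every batch has the same shape.
    """
    n = len(x)
    idx = np.random.default_rng(seed).permutation(n) if shuffle else np.arange(n)
    for start in range(0, n - batch_size + 1, batch_size):
        sel = idx[start:start + batch_size]
        yield jnp.asarray(x[sel]), jnp.asarray(y[sel])

x = np.arange(200, dtype=np.float32).reshape(100, 2)  # 100 toy examples, 2 features each
y = np.arange(100, dtype=np.float32)                   # label i belongs to row i

for epoch in range(2):
    # A new seed per epoch gives a different shuffle each time
    for xb, yb in numpy_batches(x, y, batch_size=32, seed=epoch):
        pass  # a jit-compiled training step would consume (xb, yb) here

batches = list(numpy_batches(x, y, batch_size=32))
print(len(batches), batches[0][0].shape)  # 3 (32, 2): the last 4 examples are dropped
```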
When Should I Use JAX?
JAX is a new machine learning framework that has been gaining popularity in machine learning research.
If you're operating in the research realm, JAX is a good option for your project.
If you're actively developing an application, PyTorch or TensorFlow will likely move your project along faster. And in computer vision, there is always a tradeoff to weigh between building and buying computer vision tooling.
Thanks for reading our writeup on JAX. Happy training, and of course, happy inferencing.