JAX is a Python library, developed by researchers at Google, for accelerator-oriented array computation and program transformation. Its README sums it up as "composable transformations of Python+NumPy programs." The core idea is that you write ordinary numerical code using a NumPy-compatible API, then apply transformations that automatically turn that code into new functions, such as one that computes its gradient or a compiled version that runs on an accelerator.
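To make the "NumPy-compatible API" point concrete, here is a minimal sketch. The same softplus function is written once with NumPy and once with `jax.numpy`. The two versions differ only in the import they use. `jax.make_jaxpr` then shows the traced intermediate program that JAX's transformations operate on. The function names `softplus_np` and `softplus_jnp` are illustrative, not part of any library.

```python
import numpy as np
import jax
import jax.numpy as jnp

def softplus_np(x):
    # Plain NumPy implementation
    return np.log1p(np.exp(x))

def softplus_jnp(x):
    # Identical code, with jax.numpy swapped in for numpy
    return jnp.log1p(jnp.exp(x))

x = np.linspace(-3.0, 3.0, 7)

# JAX defaults to float32, so compare with a tolerance rather than exact equality
print(np.allclose(softplus_np(x), softplus_jnp(x)))

# make_jaxpr traces the function into JAX's intermediate representation,
# the program that transformations such as grad, jit, and vmap rewrite
print(jax.make_jaxpr(softplus_jnp)(x))
```

The printed jaxpr is a small, explicit list of primitive operations. Working on that representation, rather than on Python source, is what lets JAX transform functions mechanically.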
Three transformations define JAX. jax.grad performs automatic differentiation, computing gradients of native Python and NumPy functions to floating-point precision rather than by numerical approximation; gradients are the core quantity needed to train a neural network. jax.jit just-in-time compiles functions through XLA, an open-source machine-learning compiler originally developed at Google, so the same code can run efficiently on CPUs, GPUs, and TPUs. jax.vmap automatically vectorizes a function, so a routine written for a single example runs efficiently over a whole batch without manual rewriting. Because these transformations compose, you can apply them to the same function in any combination, for example compiling a vectorized gradient.
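The composition described above can be sketched in a few lines. The function below is a minimal example, assuming a toy squared-error loss for a linear model on a single example. It stacks the three transformations as `jit(vmap(grad(loss)))` to compute one gradient per example in a batch. The names `loss`, `grad_loss`, and `per_example_grads` are hypothetical; `jax.grad`, `jax.vmap` (with its `in_axes` argument), and `jax.jit` are the real JAX functions.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model on ONE example (x: vector, y: scalar)
    pred = jnp.dot(w, x)
    return (pred - y) ** 2

# grad: derivative of loss with respect to its first argument, w
grad_loss = jax.grad(loss)

# vmap: run grad_loss over a batch of (x, y) pairs while sharing one w
# in_axes=(None, 0, 0) means "don't batch w; batch x and y along axis 0"
per_example_grads = jax.vmap(grad_loss, in_axes=(None, 0, 0))

# jit: compile the composed function with XLA for the available backend
fast_per_example_grads = jax.jit(per_example_grads)

w = jnp.array([1.0, -2.0, 0.5])
xs = jnp.ones((4, 3))   # batch of 4 examples, 3 features each
ys = jnp.arange(4.0)    # 4 scalar targets

g = fast_per_example_grads(w, xs, ys)
print(g.shape)  # (4, 3): one gradient vector per example
```

Note that `loss` was written for a single example and never mentions batching; `vmap` adds the batch dimension and `jit` compiles the result, which is the "compose rather than rewrite" design the paragraph describes.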
JAX occupies a different niche from PyTorch and TensorFlow. Rather than a full framework with built-in layers, it is a lower-level numerical foundation favored in research and high-performance settings; libraries such as Flax and Haiku build neural-network abstractions on top of it. Several frontier models, including some from Google DeepMind, were trained using JAX.
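What "lower-level" means in practice can be sketched with plain JAX. Without a layer library, parameters are ordinary data passed explicitly into pure functions, and the training step is a function you write yourself. This is a hand-rolled illustration, not the API of Flax or Haiku, which exist precisely to package this kind of bookkeeping. The names `init_params`, `predict`, `mse`, `sgd_step`, and `LEARNING_RATE`, as well as the toy data, are all hypothetical.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size for this toy problem

def init_params(key, in_dim, out_dim):
    # Parameters live in a plain dict (a pytree), not inside a layer object
    return {
        "w": jax.random.normal(key, (in_dim, out_dim)) * 0.1,
        "b": jnp.zeros(out_dim),
    }

def predict(params, x):
    # A dense (fully connected) layer written by hand as a pure function
    return x @ params["w"] + params["b"]

def mse(params, x, y):
    # Mean squared error over the batch
    return jnp.mean((predict(params, x) - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    # Gradient of the loss with respect to every leaf of the params dict
    grads = jax.grad(mse)(params, x, y)
    # Plain gradient descent, applied leaf by leaf across the pytree
    return jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)

# Toy regression data: targets come from a known linear map, with no noise
x = jax.random.normal(jax.random.PRNGKey(1), (32, 3))
true_w = jnp.array([[2.0], [-1.0], [0.5]])
y = x @ true_w

params = init_params(jax.random.PRNGKey(0), in_dim=3, out_dim=1)
loss_before = mse(params, x, y)
for _ in range(200):
    params = sgd_step(params, x, y)
loss_after = mse(params, x, y)
print(loss_before, loss_after)
```

Everything here (parameter initialization, the forward pass, the update rule) is explicit user code. That explicitness is what makes JAX flexible for research, and it is also why higher-level libraries are usually layered on top for larger models.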
Why business readers should care: JAX is the substrate behind a meaningful share of cutting-edge model training, and its compile-and-scale design is part of how labs squeeze performance out of expensive accelerator clusters.