JAX

JAX is an open-source numerical computing library from Google, first released in 2018. Its starting point is deliberately familiar: a NumPy-compatible API, so that array code written against NumPy can often run under JAX with little change. What JAX adds is a set of function transformations that can be composed with each other, turning ordinary Python functions into differentiated, compiled, and vectorized versions of themselves.
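To illustrate the NumPy-compatible front end, the sketch below writes the same softmax twice: once against NumPy and once against jax.numpy. The function names are hypothetical, and the two bodies differ only in the module prefix. JAX uses 32-bit floats by default, so the two results agree to single precision rather than bit for bit.

```python
import numpy as np
import jax.numpy as jnp

def softmax_np(x):
    # Numerically stable softmax written against NumPy
    e = np.exp(x - np.max(x))
    return e / np.sum(e)

def softmax_jnp(x):
    # Identical body, with the np prefix swapped for jnp
    e = jnp.exp(x - jnp.max(x))
    return e / jnp.sum(e)

x = np.array([1.0, 2.0, 3.0])
print(softmax_np(x))                # float64 result from NumPy
print(softmax_jnp(jnp.asarray(x)))  # same values, computed in float32 by default
```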

Three transforms define the library. grad performs reverse-mode automatic differentiation: it returns a new function that computes the gradient of a scalar-valued original, and it can differentiate through Python loops, branches, recursion, and closures. jit compiles a function end to end through XLA, Google’s Accelerated Linear Algebra compiler, which fuses operations and generates fast code for CPU, GPU, or TPU from the same unchanged source. vmap adds a batch dimension automatically by pushing the loop down into the function’s primitive operations, so code written for a single example can run over a batch without being rewritten.
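A minimal sketch of each transform applied on its own, using small hypothetical functions (piecewise, poly, dot). grad is applied outside jit here, which is what allows the value-dependent Python if statement; jit compiles on the first call and reuses the compiled code for later calls with matching shapes and dtypes; vmap turns a function of two vectors into one over batches of vectors.

```python
import jax
import jax.numpy as jnp

def piecewise(x):
    # Ordinary Python branch; grad follows the branch taken for this input
    if x > 0:
        return x ** 3
    return -x

def poly(x):
    # Ordinary Python loop computing x + x**2 + x**3
    total = 0.0
    for k in range(1, 4):
        total = total + x ** k
    return total

d_piecewise = jax.grad(piecewise)  # reverse-mode derivative of a scalar function
fast_poly = jax.jit(poly)          # compiled through XLA on first call, then cached

def dot(a, b):
    # Written for single vectors of shape (n,)
    return jnp.dot(a, b)

batched_dot = jax.vmap(dot)        # maps over the leading axis of both arguments

A = jnp.arange(6.0).reshape(3, 2)  # rows [0, 1], [2, 3], [4, 5]
B = jnp.ones((3, 2))

print(d_piecewise(2.0))     # 3 * 2**2
print(d_piecewise(-1.0))    # derivative of -x
print(jax.grad(poly)(2.0))  # 1 + 2*2 + 3*2**2
print(fast_poly(2.0))       # 2 + 4 + 8
print(batched_dot(A, B))    # one dot product per row
```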

The power of the design is that these transforms compose. Because each one takes a function and returns a function, they can be nested in whatever order a task needs, as in jit(vmap(grad(loss))), which produces a compiled, batched function returning one gradient per example (vmap’s in_axes argument marks which inputs carry the batch dimension). This composability is the conceptual core that distinguishes JAX from frameworks where differentiation, batching, and compilation are separate special-purpose machinery.
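A sketch of the composed transform, assuming a hypothetical squared-error loss for a linear model with parameters w and a single example (x, y). grad differentiates with respect to the first argument by default; in_axes=(None, 0, 0) tells vmap to share w across the batch while mapping over the leading axis of x and y; jit then compiles the whole batched-gradient computation.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Hypothetical squared-error loss for one example of a linear model
    pred = jnp.dot(x, w)
    return (pred - y) ** 2

# grad w.r.t. w, batched over (x, y) with w shared, then compiled as one unit
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 1.0])

g = per_example_grads(w, xs, ys)  # shape (3, 2): one gradient row per example
print(g)
```

Each row of g is the gradient 2 * (x·w - y) * x for one example: the same result a Python loop over examples would give, but produced by a single compiled call.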

To make transforms behave predictably, JAX leans toward a functional programming style. It works best with pure functions and immutable arrays, and it handles randomness through explicit pseudorandom-number keys that the caller passes in and splits, rather than through hidden global state. That discipline is what lets the same function be traced, differentiated, and compiled cleanly, and it gives JAX programs a different flavor from the imperative, stateful style of a typical PyTorch model.
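The sketch below illustrates these habits with a hypothetical add_noise function. Randomness comes from a key that is passed in and split, so calling the function twice with the same key yields identical samples, and updating an array uses the .at[...].set(...) syntax, which returns a new array instead of modifying the original.

```python
import jax
import jax.numpy as jnp

def add_noise(key, x, scale):
    # Pure function: the output depends only on its arguments, including the key
    return x + scale * jax.random.normal(key, x.shape)

key = jax.random.PRNGKey(0)          # explicit random state, no hidden global seed
key, subkey = jax.random.split(key)  # derive a fresh key instead of reusing one

x = jnp.zeros(3)
a = add_noise(subkey, x, 0.1)
b = add_noise(subkey, x, 0.1)        # same key, so the same samples as a

# Arrays are immutable: an indexed update returns a new array
y = x.at[1].set(5.0)
try:
    x[1] = 5.0                       # in-place assignment raises TypeError
except TypeError:
    print("in-place update rejected; x is still", x)
```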

JAX is intentionally a low-level substrate rather than a full deep-learning framework, which is why an ecosystem of higher-level libraries grew on top of it for building neural networks, optimizers, and training loops. The official repository describes JAX as a library for accelerator-oriented array computation and program transformation, and its combination of a NumPy front end with composable, XLA-backed transforms made it a favored tool for research that needs both flexibility and high performance on TPUs and GPUs.

Last verified June 8, 2026