Thinking in JAX
This is a beginner-level book to help you get into the groove with Google JAX.
What is JAX?
We attempt to describe JAX by comparing it with its predecessors and peers.
In short, JAX is a new Python library for numerical computing. It provides a NumPy-like API, which makes it easy for existing NumPy users to transition to JAX. At the same time, it is built from the ground up on functional programming principles. Thus, data structures like JAX arrays are immutable.
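The sketch below contrasts the two libraries. It assumes both numpy and jax are installed. The same calls (arange, sum) work in both libraries, but only the NumPy array accepts in-place assignment. For JAX, the sketch uses the functional update syntax .at[...].set(...), which returns a new array instead of modifying the old one.

```python
import numpy as np
import jax.numpy as jnp

# The same array-creation and reduction calls exist in both libraries.
x_np = np.arange(5.0)
x_jax = jnp.arange(5.0)
print(np.sum(x_np ** 2), jnp.sum(x_jax ** 2))  # both sum 0 + 1 + 4 + 9 + 16

# NumPy arrays are mutable: in-place item assignment works.
x_np[0] = 10.0

# JAX arrays are immutable: in-place item assignment raises a TypeError.
try:
    x_jax[0] = 10.0
    raised = False
except TypeError:
    raised = True
print("in-place assignment rejected:", raised)

# The functional alternative: .at[...].set(...) returns a new array
# and leaves the original untouched.
y_jax = x_jax.at[0].set(10.0)
print(x_jax)  # still starts with 0.0
print(y_jax)  # starts with 10.0
```

This immutability is what lets JAX treat array code as pure functions, which its transformations (differentiation, compilation, vectorization) rely on.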
So, although its API resembles NumPy, it is not quite the same: in NumPy, arrays are mutable.

A major focus area for JAX is deep learning, and it is regularly compared with TensorFlow and PyTorch. However, JAX itself provides the low-level plumbing for building machine learning libraries. The libraries that provide the actual ML building blocks are built on top of it: dm-haiku, flax, rlax, trax, etc.

The JAX core is not tied to machine learning; it can be used for any scientific computing problem.

JAX provides built-in support for automatic differentiation (AD). AD is key to building and training large deep learning networks, since gradient-based training needs the derivative of the loss with respect to every parameter.

JAX provides a just-in-time compiler (jit). Thanks to XLA, it can compile Python code written according to JAX conventions (functional programming style, etc.) into efficient machine code for a variety of hardware architectures. As a result, JAX lets you write code in Python that runs efficiently on CPU, GPU and TPU architectures.

JAX has built-in support (vmap) for vectorizing a function over different dimensions of its input data.
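The sketch below makes these points concrete. It assumes a recent JAX release and composes the three transformations mentioned above: jax.grad for automatic differentiation, jax.vmap for vectorization and jax.jit for compilation. The loss function, the data and names such as per_sample_grads are hypothetical. They were chosen so the expected gradients, 2(wx - y)x, are easy to check by hand.

```python
import jax
import jax.numpy as jnp

# A pure function: its output depends only on its inputs (no side effects).
def loss(w, x, y):
    """Squared error of a 1-parameter linear model y ≈ w * x, for one sample."""
    return (w * x - y) ** 2

# Automatic differentiation: jax.grad returns a new function computing
# d(loss)/dw = 2 * (w * x - y) * x (it differentiates w.r.t. the first argument).
dloss_dw = jax.grad(loss)

# Vectorization: jax.vmap maps the per-sample gradient over a batch.
# in_axes=(None, 0, 0) shares w across the batch and maps over axis 0 of x and y.
per_sample_grads = jax.vmap(dloss_dw, in_axes=(None, 0, 0))

# Just-in-time compilation: jax.jit traces the batched function on its first
# call and compiles it with XLA; later calls with the same shapes/dtypes reuse it.
fast_per_sample_grads = jax.jit(per_sample_grads)

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([2.0, 4.0, 6.0])
w = 1.5

grads = fast_per_sample_grads(w, xs, ys)
print(grads)          # -1, -4, -9 (one gradient per sample)
print(jax.devices())  # devices JAX can run on (CPU, GPU or TPU)
```

Wrapping jit around the already vectorized function lets XLA compile the whole batched computation in one go. The code never names a device, so the same script runs on whichever backend JAX finds.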