JAX arrays
JAX arrays are similar to NumPy arrays in many ways but differ in some crucial ones. Because JAX is built around functional programming, JAX arrays are immutable: once created, they cannot be modified in place. Porting NumPy code to JAX may therefore require a significant number of changes.
In this chapter, we will go through a number of examples explaining the similarities and differences.
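To make the immutability point concrete, here is a minimal sketch that contrasts in-place assignment in NumPy with JAX. It assumes only that `numpy` and `jax` are installed. The functional replacement uses the `.at[...]` property, which returns a new array and leaves the original untouched. The variable names are illustrative.

```python
import numpy as np
import jax.numpy as jnp

# NumPy arrays are mutable: in-place assignment works.
x_np = np.zeros(3)
x_np[0] = 1.0

# JAX arrays are immutable: the same assignment raises a TypeError.
x = jnp.zeros(3)
try:
    x[0] = 1.0
    assignment_failed = False
except TypeError:
    assignment_failed = True

# The functional alternative builds a new array and leaves x unchanged.
y = x.at[0].set(1.0)
y = y.at[1].add(5.0)

print(assignment_failed)  # True
print(x)  # [0. 0. 0.]
print(y)  # [1. 5. 0.]
```

Each `.at[...]` update produces a fresh array, so code that mutates arrays inside loops usually has to be restructured when moving to JAX.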
Importing
JAX includes several modules. The one that most closely resembles NumPy is jax.numpy. Just as we conventionally import numpy under the shorthand np (import numpy as np), the convention for importing jax.numpy is:
import jax.numpy as jnp
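Because jax.numpy mirrors much of the NumPy API, the same call often works with either module. Arrays can also be converted in both directions with `asarray`. The sketch below assumes default settings, with 64-bit mode off. In that mode, float64 NumPy data becomes float32 when converted to a JAX array.

```python
import numpy as np
import jax.numpy as jnp

# The same function name and arguments work in both modules.
a_np = np.linspace(0, 1, 5)
a_jnp = jnp.linspace(0, 1, 5)

# NumPy -> JAX (float64 data becomes float32 unless 64-bit mode is on).
from_np = jnp.asarray(a_np)

# JAX -> NumPy: np.asarray gives back a regular numpy.ndarray.
back_to_np = np.asarray(a_jnp)

print(type(back_to_np).__name__)      # ndarray
print(np.allclose(a_np, back_to_np))  # True
```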
1-D vectors
z = jnp.zeros(4)
print(z)
[0. 0. 0. 0.]
print(z.dtype, z.shape)
float32 (4,)
jnp.ones(4)
DeviceArray([1., 1., 1., 1.], dtype=float32)
jnp.empty(4)  # XLA cannot create uninitialized arrays, so JAX returns zeros here
DeviceArray([0., 0., 0., 0.], dtype=float32)
jnp.ones(4, dtype=int)  # Python's int maps to int32 while 64-bit mode is off
DeviceArray([1, 1, 1, 1], dtype=int32)
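A few more NumPy-style constructors carry over unchanged. This sketch shows `jnp.full` for an arbitrary fill value and the `*_like` variants, which copy the shape and dtype of an existing array. It assumes the default 32-bit mode.

```python
import jax.numpy as jnp

z = jnp.zeros(4)
f = jnp.full(4, 7.0)     # every entry set to the fill value
o = jnp.ones_like(z)     # same shape and dtype as z, filled with ones
zl = jnp.zeros_like(f)   # same shape and dtype as f, filled with zeros

print(f)  # [7. 7. 7. 7.]
print(o)  # [1. 1. 1. 1.]
```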
A range of integers
a = jnp.arange(5)
print(a)
[0 1 2 3 4]
print(a.dtype, a.shape)
int32 (5,)
# start and stop (the stop value is excluded)
jnp.arange(2, 8)
DeviceArray([2, 3, 4, 5, 6, 7], dtype=int32)
# start, stop and step size
jnp.arange(2, 8, 2)
DeviceArray([2, 4, 6], dtype=int32)
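As in NumPy, `jnp.arange` also accepts a negative step for counting down. Float arguments produce a floating-point array. This is a small sketch of both cases. With float steps, rounding can affect how many values you get, which is why a fixed count via `linspace` is often preferred.

```python
import jax.numpy as jnp

# stop is exclusive, and a negative step counts down.
down = jnp.arange(5, 0, -1)
print(down)  # [5 4 3 2 1]

# A float start, stop, or step produces a floating-point array.
frac = jnp.arange(0.0, 1.0, 0.25)
print(frac.tolist())  # [0.0, 0.25, 0.5, 0.75]
```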
Linearly spaced values
jnp.linspace(0, 1, num=5)
DeviceArray([0. , 0.25, 0.5 , 0.75, 1. ], dtype=float32)
# excluding the endpoint.
jnp.linspace(0, 1, num=5, endpoint=False)
DeviceArray([0. , 0.2, 0.4, 0.6, 0.8], dtype=float32)
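The two outputs above differ because of how the spacing is computed. With `endpoint=True`, `num` points span the closed interval, so the step is `(stop - start) / (num - 1)`. With `endpoint=False`, the interval is split into `num` equal pieces, so the step is `(stop - start) / num`. Passing `retstep=True` returns the step alongside the values, as this sketch shows.

```python
import jax.numpy as jnp

start, stop, num = 0.0, 1.0, 5

# endpoint=True (default): step = (1.0 - 0.0) / (5 - 1) = 0.25
closed, step = jnp.linspace(start, stop, num=num, retstep=True)

# endpoint=False: step = (1.0 - 0.0) / 5 = 0.2
half_open, step_open = jnp.linspace(start, stop, num=num, endpoint=False, retstep=True)

print(float(step))   # 0.25
print(closed.shape)  # (5,)
```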
Boolean vectors
jnp.ones(4, dtype=bool)
DeviceArray([ True, True, True, True], dtype=bool)
jnp.zeros(4, dtype=bool)
DeviceArray([False, False, False, False], dtype=bool)
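In practice, boolean vectors usually come from elementwise comparisons rather than from `ones` or `zeros`. The sketch below builds a mask, counts its `True` entries, and uses `jnp.where` to select values. `jnp.where` keeps the output shape fixed. That matters under `jax.jit`, where boolean-mask indexing such as `a[mask]` is not supported because its result shape depends on the data.

```python
import jax.numpy as jnp

a = jnp.arange(6)
mask = a % 2 == 0   # elementwise comparison yields a bool array

print(mask)                    # [ True False  True False  True False]
print(int(mask.sum()))         # 3, since True counts as 1
print(jnp.where(mask, a, -1))  # [ 0 -1  2 -1  4 -1]
```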
2-D matrices
jnp.zeros((4,4))
DeviceArray([[0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.],
             [0., 0., 0., 0.]], dtype=float32)
jnp.ones((4,4))
DeviceArray([[1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.],
             [1., 1., 1., 1.]], dtype=float32)
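Other common ways of building matrices also follow NumPy. This sketch creates an identity matrix with `jnp.eye` and reshapes a range into a 2×3 matrix, which is filled row by row. It also shows the transpose via `.T`. It assumes the default 32-bit mode.

```python
import jax.numpy as jnp

I = jnp.eye(3)                   # 3x3 identity matrix
M = jnp.arange(6).reshape(2, 3)  # filled row by row, as in NumPy
print(M)
# [[0 1 2]
#  [3 4 5]]
print(M.shape, M.T.shape)        # (2, 3) (3, 2)
print(bool(jnp.allclose(I @ I, I)))  # True
```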
64-bit support
By default, JAX uses 32-bit integers and floating-point numbers, so all calculations are carried out in 32-bit precision. If your calculations need 64-bit integers and floats, you have to enable 64-bit support explicitly.
Enable 64-bit support at the very beginning of your program, before any arrays are created, and do not toggle the setting partway through.
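Before turning the flag on, it helps to see what the 32-bit default does with a 64-bit request. JAX quietly falls back to the 32-bit type and emits a warning. This sketch assumes a fresh Python process in which 64-bit mode is off, meaning `JAX_ENABLE_X64` is not set. The exact warning text varies between JAX versions, so the sketch prints whatever was caught instead of assuming a message.

```python
import warnings
import jax.numpy as jnp

# 64-bit mode is off here, so the int64 request is truncated to int32.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    x = jnp.ones(4, dtype=jnp.int64)

print(x.dtype)  # int32
for w in caught:
    print(w.category.__name__, w.message)
```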
# enabling 64-bit support (do this before creating any JAX arrays)
import jax
jax.config.update("jax_enable_x64", True)
jnp.ones(4, dtype=jnp.int64)
DeviceArray([1, 1, 1, 1], dtype=int64)
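Enabling 64-bit mode affects more than explicit `int64`/`float64` requests. The default dtypes widen as well. This sketch assumes it runs as a standalone script, so the flag is set before any array exists. Setting the environment variable `JAX_ENABLE_X64=True` before starting Python is an alternative to calling `jax.config.update`.

```python
import jax

# Turn on 64-bit mode first, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Default dtypes widen too, not only explicit 64-bit requests.
print(jnp.zeros(3).dtype)   # float64
print(jnp.arange(3).dtype)  # int64

# Explicit 32-bit dtypes remain available when you want them.
print(jnp.ones(3, dtype=jnp.float32).dtype)  # float32
```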