Stack overview
The JAX AI Stack consists of the following packages:
JAX: high-performance array computing
Flax NNX: object-oriented neural nets
Optax: optimizers
Orbax: checkpointing and model export
Grain: JAX-native data loading
Chex: JAX test utilities
The jax-ai-stack metapackage installs compatible versions of all of these libraries, along with compatible versions of their shared dependencies.
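After installing the metapackage (for example with `pip install jax-ai-stack`), one way to see which versions were resolved is to query the installed distribution metadata. The sketch below uses the standard-library `importlib.metadata` module. The PyPI distribution names in `STACK_DISTRIBUTIONS` are an assumption (in particular, Orbax is split into `orbax-checkpoint` and `orbax-export` here). Check them against the metapackage's declared dependencies. The helper `installed_versions` is hypothetical.

```python
from importlib import metadata

# Assumed PyPI distribution names for the stack components; verify them
# against the dependencies declared by the jax-ai-stack metapackage.
STACK_DISTRIBUTIONS = (
    "jax",
    "flax",
    "optax",
    "orbax-checkpoint",
    "orbax-export",
    "grain",
    "chex",
)


def installed_versions(names=STACK_DISTRIBUTIONS):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in the current environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:18} {version or 'not installed'}")
```

Printing this report in a fresh environment is a quick way to confirm that every component came from the same metapackage install rather than from separate, possibly mismatched, installs.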
In addition, an optional jax-ai-stack[tfds] installation includes TensorFlow Datasets (TFDS) for those who want to use TFDS for data loading. It also installs a compatible version of TensorFlow.
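Code that can optionally use TFDS may prefer to check for it at runtime rather than assume the `[tfds]` extra was installed. The sketch below uses the standard-library `importlib.util.find_spec`, which returns `None` when a top-level module cannot be found. It assumes the standard import names `tensorflow_datasets` and `tensorflow`. The helper `tfds_available` is hypothetical.

```python
import importlib.util

# Import names of the optional packages expected from the [tfds] extra,
# installed with: pip install "jax-ai-stack[tfds]"
TFDS_MODULES = ("tensorflow_datasets", "tensorflow")


def tfds_available(modules=TFDS_MODULES):
    """Return True only if every listed top-level module can be found."""
    return all(importlib.util.find_spec(name) is not None for name in modules)


if __name__ == "__main__":
    if tfds_available():
        print("TFDS is available for data loading.")
    else:
        print('TFDS not found; install it with: pip install "jax-ai-stack[tfds]"')
```

Checking with `find_spec` avoids importing TensorFlow just to test for its presence, which keeps the check cheap when TFDS is not needed.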