Stack overview

The JAX AI Stack consists of the following packages:

  • JAX: high-performance array computing

  • Flax NNX: object-oriented neural nets

  • Optax: optimizers

  • Orbax: checkpointing and model export

  • Grain: JAX-native data loading

  • Chex: JAX test utilities

The jax-ai-stack metapackage installs compatible versions of all of these libraries, as well as compatible versions of the dependencies they share.
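To see which versions the metapackage actually resolved in a given environment, you can query the installed distribution metadata. The sketch below uses only the Python standard library. The PyPI distribution names in `STACK_DISTRIBUTIONS` (for example, `orbax-checkpoint` for Orbax and `grain` for Grain) are assumptions and may not match exactly what jax-ai-stack pins. `installed_versions` is a hypothetical helper, not part of any stack package.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed PyPI distribution names for the stack components; check the
# jax-ai-stack package metadata for the authoritative list.
STACK_DISTRIBUTIONS = [
    "jax",
    "flax",
    "optax",
    "orbax-checkpoint",
    "grain",
    "chex",
]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in this environment.
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions(STACK_DISTRIBUTIONS).items():
        print(f"{name:20} {ver or 'not installed'}")
```

Running this after installing jax-ai-stack prints one line per component, which is a quick way to confirm the resolved versions when you report an issue.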

In addition, an optional jax-ai-stack[tfds] installation also includes TensorFlow Datasets (TFDS), for users who want to use TFDS for data loading, along with a compatible version of TensorFlow.
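If some of your code should use TFDS only when the optional extra is present, one approach is to check whether its modules are importable before using them. This minimal sketch relies on the standard library's `importlib.util.find_spec`. The helper name `tfds_extra_available` and the module list are illustrative assumptions, not part of the stack's API.

```python
import importlib.util

# Import names of the packages that the [tfds] extra adds (TFDS and TensorFlow).
TFDS_EXTRA_MODULES = ("tensorflow_datasets", "tensorflow")


def tfds_extra_available(modules=TFDS_EXTRA_MODULES):
    """Return True if every listed top-level module can be found for import."""
    return all(importlib.util.find_spec(m) is not None for m in modules)


if __name__ == "__main__":
    if tfds_extra_available():
        print("TFDS extra is installed; TFDS-based loaders can be used.")
    else:
        print("TFDS extra not found; fall back to another data loader such as Grain.")
```

`find_spec` only locates the modules without importing them, so the check stays cheap even though TensorFlow itself is slow to import.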