JAX AI Stack
Grain: data loading