Flax NNX: neural nets

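Flax NNX provides neural net functionality on top of JAX, such as a module abstraction and pre-defined layers, through a Pythonic, object-oriented API. NNX lets you write stateful model code, where modules hold their parameters and other state as ordinary Python attributes, that can still take advantage of JAX's function transforms (through NNX's lifted versions such as nnx.jit and nnx.grad) and other features.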

NNX has native integration with Optax, JAX's gradient-processing and optimization library.

Main Flax NNX site: flax.readthedocs.io

If you’d like to learn more about NNX beyond what’s covered in the Getting started with JAX for ML guide, we recommend starting with Flax basics.

The Flax NNX docs cover many other useful topics including: