JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators such as GPUs and TPUs.
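The snippet below is a minimal sketch of how that combination works, assuming only core `jax` is installed. A loss is written with the NumPy-like `jax.numpy` API, differentiated with `jax.grad`, and compiled with XLA through `jax.jit`. The linear-model loss and the toy data are illustrative choices, not part of any library listed here.

```python
import jax
import jax.numpy as jnp

# A NumPy-style function: mean squared error of a linear model (illustrative).
def loss(w, x, y):
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates w.r.t. the first argument (w);
# jax.jit compiles the resulting gradient function with XLA.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)  # gradient is [-7., -10.]
print(g)
```

The same `grad_loss` function runs unchanged on CPU, GPU, or TPU; the first call triggers compilation and later calls reuse the compiled kernel.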
This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!
Be sure to check out our (experimental) interactive web version: https://lockwo.github.io/awesome-jax/.
Why do we need another "awesome-jax" list? The existing ones are no longer maintained; this list is directly based on the inactive Awesome JAX repositories https://github.com/n2cholas/awesome-jax/ and https://github.com/mhlr/awesome-jax.
- Neural Network Libraries
- Reinforcement Learning Libraries
- Algorithms
- cleanrl - High-quality single-file implementations of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG).
- rlax - a library built on top of JAX that exposes useful building blocks for implementing reinforcement learning agents.
- purejaxrl - Really Fast End-to-End Jax RL Implementations.
- Mava - A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX.
- Environments
- Natural Language Processing Libraries
- JAX Utilities Libraries
- jaxtyping - Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays.
- chex - a library of utilities for helping to write reliable JAX code.
- mpi4jax - Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python.
- jax-tqdm - Add a tqdm progress bar to your JAX scans and loops.
- JAX-Toolbox - JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs.
- penzai - A JAX research toolkit for building, editing, and visualizing neural networks.
- orbax - Orbax provides common checkpointing and persistence utilities for JAX users.
- Computer Vision Libraries
- Distributions, Sampling, and Probabilistic Libraries
- distreqx - Distrax, but in Equinox. Lightweight JAX library of probability distributions and bijectors.
- distrax - a lightweight library of probability distributions and bijectors.
- flowjax - Distributions, bijections and normalizing flows using Equinox and JAX.
- blackjax - BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
- bayex - Minimal Implementation of Bayesian Optimization in JAX.
- efax - Exponential families for JAX.
- GPJax - Gaussian processes in JAX.
- tinygp - The tiniest of Gaussian Process libraries.
- Diffrax - Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.
- jax-md - Differentiable, Hardware Accelerated, Molecular Dynamics.
- lineax - Linear solvers in JAX and Equinox.
- optimistix - Nonlinear optimisation (root-finding, least squares, etc.) in JAX+Equinox.
- sympy2jax - Turn SymPy expressions into trainable JAX expressions.
- quax - Multiple dispatch over abstract array types in JAX.
- interpax - Interpolation and function approximation with JAX.
- quadax - Numerical quadrature with JAX.
- optax - Optax is a gradient processing and optimization library for JAX.
- dynamax - State Space Models library in JAX.
- dynamiqs - High-performance quantum systems simulation with JAX (GPU-accelerated & differentiable solvers).
- scico - Scientific Computational Imaging COde.
- exojax - Automatic differentiable spectrum modeling of exoplanets/brown dwarfs using JAX, compatible with NumPyro and Optax/JAXopt.
- PGMax - Loopy belief propagation for factor graphs on discrete variables in JAX.
- evosax - Evolution Strategies in JAX.
- evojax - EvoJAX is a scalable, general purpose, hardware-accelerated neuroevolution toolkit. Built on top of the JAX library, this toolkit enables neuroevolution algorithms to work with neural networks running in parallel across multiple TPU/GPUs.
- mctx - Monte Carlo tree search in JAX.
- kfac-jax - Second Order Optimization and Curvature Estimation with K-FAC in JAX.
- jwave - A JAX-based research framework for differentiable and parallelizable acoustic simulations, on CPU, GPUs and TPUs.
- jax_cosmo - A differentiable cosmology library in JAX.
- jaxlie - Rigid transforms + Lie groups in JAX.
- ott - Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
- XLB - XLB: Accelerated Lattice Boltzmann (XLB) for Physics-based ML.
- EasyDeL - Accelerate and optimize performance with streamlined training and serving options in JAX.
- QDax - Accelerated Quality-Diversity.
- paxml - Pax is a Jax-based machine learning framework for training large scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry leading model flop utilization rates.
- econpizza - Solve nonlinear heterogeneous agent models.
- fedjax - FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research.
- neural-tangents - Fast and Easy Infinite Neural Networks in Python.
- jax-fem - Differentiable Finite Element Method with JAX.
- veros - The versatile ocean simulator, in pure Python, powered by JAX.
- JAXFLUIDS - Differentiable Fluid Dynamics Package.
- traceax - Stochastic trace estimation using JAX.
- graphax - Cross-Country Elimination in JAX.
- cd_dynamax - Extension of dynamax repo to cases with continuous-time dynamics with measurements sampled at possibly irregular discrete times. Allows generic inference of dynamical systems parameters from partial noisy observations via auto-differentiable filtering, SGD, and HMC.
- Haiku - JAX-based neural network library.
- jraph - A Graph Neural Network Library in Jax.
- SymJAX - symbolic CPU/GPU/TPU programming.
- coax - Modular framework for Reinforcement Learning in Python.
- eqxvision - A Python package of computer vision models for the Equinox ecosystem.
- jaxfit - GPU/TPU accelerated nonlinear least-squares curve fitting using JAX.
- safejax - Serialize JAX, Flax, Haiku, or Objax model params with safetensors.
- kernex - Stencil computations in JAX.
- lorax - LoRA for arbitrary JAX models and functions.
- mcx - Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
- einshape - DSL-based reshaping library for JAX and other frameworks.
- jax-flows - Normalizing Flows in JAX.
- sklearn-jax-kernels - Composable kernels for scikit-learn implemented in JAX.
- deltapv - A photovoltaic simulator with automatic differentiation.
- cr-sparse - Functional models and algorithms for sparse signal processing.
- flaxvision - A selection of neural network models ported from torchvision for JAX & Flax.
- imax - Image augmentation library for Jax.
- jax-unirep - Reimplementation of the UniRep protein featurization model.
- parallax - Immutable Torch Modules for JAX.
- jax-resnet - Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
- elegy - A High Level API for Deep Learning in JAX.
- objax - Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base.
- jaxrl - JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with continuous action spaces.
- whisper-jax - JAX implementation of OpenAI's Whisper model for up to 70x speed-up on TPU.
- esm2quinox - An implementation of ESM2 in Equinox+JAX.
- Learning JAX as a PyTorch developer
- Massively parallel MCMC with JAX
- Achieving Over 4000x Speedups and Meta-Evolving Discoveries with PureJaxRL
- How to add a progress bar to JAX scans and loops
- MCMC in JAX with benchmarks: 3 ways to write a sampler
- Deterministic ADVI in JAX
- Exploring hyperparameter meta-loss landscapes with Jax
- Evolving Neural Networks in JAX
- Meta-Learning in 50 Lines of JAX
- Implementing NeRF in JAX
- Normalizing Flows in 100 Lines of JAX
- JAX vs Julia (vs PyTorch)
- From PyTorch to JAX: towards neural net frameworks that purify stateful code
- out of distribution detection using focal loss
- Differentiable Path Tracing on the GPU/TPU
- Getting started with JAX (MLPs, CNNs & RNNs)