lockwo/awesome-jax

Awesome JAX

JAX brings automatic differentiation and the XLA compiler together through a NumPy-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
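A minimal sketch of how those three pieces combine, assuming a recent `jax` install. The loss function, data, and variable names are illustrative only. The model is written with `jax.numpy`, differentiated with `jax.grad`, and compiled by XLA through `jax.jit`:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Mean squared error of a linear model, written with the NumPy-like jax.numpy API
    pred = x @ w
    return jnp.mean((pred - y) ** 2)

# jax.grad differentiates with respect to the first argument (w);
# jax.jit compiles the resulting gradient function with XLA for the default backend.
grad_loss = jax.jit(jax.grad(loss))

# Illustrative toy data
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([1.0, 2.0])
w = jnp.zeros(2)

g = grad_loss(w, x, y)
print(g)  # gradient of the MSE at w = 0
```

The same function runs unchanged on CPU, GPU, or TPU. JAX picks the backend, and XLA compiles the traced computation for it.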

This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!

Be sure to check out our (experimental) interactive web version: https://lockwo.github.io/awesome-jax/.

Why do we need another "awesome-jax" list? The existing ones are inactive; this list is directly based on the no-longer-active Awesome JAX repos https://github.com/n2cholas/awesome-jax/ and https://github.com/mhlr/awesome-jax.

Contents

Libraries

  • Neural Network Libraries

    • Flax - Flax is a neural network library for JAX that is designed for flexibility.
    • Equinox - Elegant easy-to-use neural networks + scientific computing in JAX.
  • Reinforcement Learning Libraries

    • Algorithms
      • cleanrl - High-quality single-file implementations of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG).
      • rlax - a library built on top of JAX that exposes useful building blocks for implementing reinforcement learning agents.
      • purejaxrl - Really Fast End-to-End JAX RL Implementations.
      • Mava - 🦁 A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX.
    • Environments
      • pgx - Vectorized RL game environments in JAX.
      • jumanji - 🕹️ A diverse suite of scalable reinforcement learning environments in JAX.
      • gymnax - RL Environments in JAX 🌍.
      • brax - Massively parallel rigid-body physics simulation on accelerator hardware.
  • Natural Language Processing Libraries

    • levanter - Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX.
    • maxtext - A simple, performant and scalable JAX LLM!
    • EasyLM - Large language models (LLMs) made easy. EasyLM is a one-stop solution for pre-training, finetuning, evaluating and serving LLMs in JAX/Flax.
  • JAX Utilities Libraries

    • jaxtyping - Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays.
    • chex - a library of utilities for helping to write reliable JAX code.
    • mpi4jax - Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡.
    • jax-tqdm - Add a tqdm progress bar to your JAX scans and loops.
    • JAX-Toolbox - JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs.
    • penzai - A JAX research toolkit for building, editing, and visualizing neural networks.
    • orbax - Orbax provides common checkpointing and persistence utilities for JAX users.
  • Computer Vision Libraries

    • Scenic - A JAX library for computer vision research and beyond.
    • dm_pix - PIX is an image processing library in JAX, for JAX.
  • Distributions, Sampling, and Probabilistic Libraries

    • distreqx - Distrax, but in equinox. Lightweight JAX library of probability distributions and bijectors.
    • distrax - a lightweight library of probability distributions and bijectors.
    • flowjax - Distributions, bijections and normalizing flows using Equinox and JAX.
    • blackjax - BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
    • bayex - Minimal Implementation of Bayesian Optimization in JAX.
    • efax - Exponential families for JAX.
  • GPJax - Gaussian processes in JAX.

  • tinygp - The tiniest of Gaussian Process libraries.

  • Diffrax - Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

  • jax-md - Differentiable, Hardware Accelerated, Molecular Dynamics.

  • lineax - Linear solvers in JAX and Equinox.

  • optimistix - Nonlinear optimisation (root-finding, least squares, etc.) in JAX+Equinox.

  • sympy2jax - Turn SymPy expressions into trainable JAX expressions.

  • quax - Multiple dispatch over abstract array types in JAX.

  • interpax - Interpolation and function approximation with JAX.

  • quadax - Numerical quadrature with JAX.

  • optax - Optax is a gradient processing and optimization library for JAX.

  • dynamax - State Space Models library in JAX.

  • dynamiqs - High-performance quantum systems simulation with JAX (GPU-accelerated & differentiable solvers).

  • scico - Scientific Computational Imaging COde.

  • exojax - 🐈 Automatic differentiable spectrum modeling of exoplanets/brown dwarfs using JAX, compatible with NumPyro and Optax/JAXopt.

  • PGMax - Loopy belief propagation for factor graphs on discrete variables in JAX.

  • evosax - Evolution Strategies in JAX 🦎.

  • evojax - EvoJAX is a scalable, general-purpose, hardware-accelerated neuroevolution toolkit. Built on top of the JAX library, this toolkit enables neuroevolution algorithms to work with neural networks running in parallel across multiple TPUs/GPUs.

  • mctx - Monte Carlo tree search in JAX.

  • kfac-jax - Second Order Optimization and Curvature Estimation with K-FAC in JAX.

  • jwave - A JAX-based research framework for differentiable and parallelizable acoustic simulations, on CPU, GPUs and TPUs.

  • jax_cosmo - A differentiable cosmology library in JAX.

  • jaxlie - Rigid transforms + Lie groups in JAX.

  • ott - Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.

  • XLB - Accelerated Lattice Boltzmann (XLB) for physics-based ML.

  • EasyDeL - Accelerate and optimize performance with streamlined training and serving options in JAX.

  • QDax - Accelerated Quality-Diversity.

  • paxml - Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.

  • econpizza - Solve nonlinear heterogeneous agent models.

  • fedjax - FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research.

  • neural-tangents - Fast and Easy Infinite Neural Networks in Python.

  • jax-fem - Differentiable Finite Element Method with JAX.

  • veros - The versatile ocean simulator, in pure Python, powered by JAX.

  • JAXFLUIDS - Differentiable Fluid Dynamics Package.

Up and Coming Libraries

  • traceax - Stochastic trace estimation using JAX.
  • graphax - Cross-Country Elimination in JAX.
  • cd_dynamax - Extension of dynamax repo to cases with continuous-time dynamics with measurements sampled at possibly irregular discrete times. Allows generic inference of dynamical systems parameters from partial noisy observations via auto-differentiable filtering, SGD, and HMC.

Inactive Libraries

  • Haiku - JAX-based neural network library.
  • jraph - A Graph Neural Network Library in JAX.
  • SymJAX - symbolic CPU/GPU/TPU programming.
  • coax - Modular framework for Reinforcement Learning in python.
  • eqxvision - A Python package of computer vision models for the Equinox ecosystem.
  • jaxfit - GPU/TPU accelerated nonlinear least-squares curve fitting using JAX.
  • safejax - Serialize JAX, Flax, Haiku, or Objax model params with 🤗 safetensors.
  • kernex - Stencil computations in JAX.
  • lorax - LoRA for arbitrary JAX models and functions.
  • mcx - Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
  • einshape - DSL-based reshaping library for JAX and other frameworks.
  • jax-flows - Normalizing Flows in JAX 🌊.
  • sklearn-jax-kernels - Composable kernels for scikit-learn implemented in JAX.
  • deltapv - A photovoltaic simulator with automatic differentiation.
  • cr-sparse - Functional models and algorithms for sparse signal processing.
  • flaxvision - A selection of neural network models ported from torchvision for JAX & Flax.
  • imax - Image augmentation library for JAX.
  • jax-unirep - Reimplementation of the UniRep protein featurization model.
  • parallax - Immutable Torch Modules for JAX.
  • jax-resnet - Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
  • elegy - A High Level API for Deep Learning in JAX.
  • objax - Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base.
  • jaxrl - JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with continuous action spaces.

Models and Projects

  • whisper-jax - JAX implementation of OpenAI's Whisper model for up to 70x speed-up on TPU.
  • esm2quinox - An implementation of ESM2 in Equinox+JAX.

Tutorials and Blog Posts

Videos

Community