diff --git a/README.md b/README.md index e18b7f1..332878a 100644 --- a/README.md +++ b/README.md @@ -3,7 +3,7 @@ distreqx (pronounced "dist-rex") is a [JAX](https://github.com/google/jax)-based library providing implementations of distributions, bijectors, and tools for statistical and probabilistic machine learning with all benefits of jax (native GPU/TPU acceleration, differentiability, vectorization, distributing workloads, XLA compilation, etc.). -The origin of this repo is a reimplementation of [distrax](https://github.com/google-deepmind/distrax), (which is a subset of [TensorFlow Probability (TFP)](https://github.com/tensorflow/probability), with some new features and emphasis on jax compatibility) using [equinox](https://github.com/patrick-kidger/equinox). As a result, much of the original code/comments/documentation/tests are directly taken or adapted from distrax (original distrax copyright available at end of README.) +The origin of this package is a reimplementation of [distrax](https://github.com/google-deepmind/distrax), (which is a subset of [TensorFlow Probability (TFP)](https://github.com/tensorflow/probability), with some new features and emphasis on jax compatibility) using [equinox](https://github.com/patrick-kidger/equinox). As a result, much of the original code/comments/documentation/tests are directly taken or adapted from distrax (original distrax copyright available at end of README.) Current features include: @@ -13,6 +13,12 @@ Current features include: ## Installation +``` +pip install distreqx +``` + +or + ``` git clone https://github.com/lockwo/distreqx.git cd distreqx @@ -28,7 +34,17 @@ Available at https://lockwo.github.io/distreqx/. ## Quick example ```python -from distreqx import +from distreqx import distributions + +key = jax.random.PRNGKey(1234) +mu = jnp.array([-1., 0., 1.]) +sigma = jnp.array([0.1, 0.2, 0.3]) + +dist = distributions.MultivariateNormalDiag(mu, sigma) + +samples = dist_distrax.sample(key) + +print(dist_distrax.log_prob(samples)) ``` ## Differences with Distrax @@ -43,6 +59,12 @@ from distreqx import If you found this library useful in academic research, please cite: ```bibtex +@software{lockwood2024distreqx, + title = {distreqx: Distributions and Bijectors in Jax}, + author = {Owen Lockwood}, + url = {https://github.com/lockwo/distreqx}, + doi = {[tbd]}, +} ``` (Also consider starring the project on GitHub.) @@ -51,6 +73,8 @@ If you found this library useful in academic research, please cite: [GPJax](https://github.com/JaxGaussianProcesses/GPJax): Gaussian processes in JAX. +[flowjax](https://github.com/danielward27/flowjax): Normalizing flows in JAX. + [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares. [Lineax](https://github.com/patrick-kidger/lineax): linear solvers. diff --git a/docs/index.md b/docs/index.md index c3ccd2f..ca3b547 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,3 +1,48 @@ # distreqx -Distrax + equinox +distreqx (pronounced "dist-rex") is a [JAX](https://github.com/google/jax)-based library providing implementations of distributions, bijectors, and tools for statistical and probabilistic machine learning with all benefits of jax (native GPU/TPU acceleration, differentiability, vectorization, distributing workloads, XLA compilation, etc.). + +The origin of this package is a reimplementation of [distrax](https://github.com/google-deepmind/distrax), (which is a subset of [TensorFlow Probability (TFP)](https://github.com/tensorflow/probability), with some new features and emphasis on jax compatibility) using [equinox](https://github.com/patrick-kidger/equinox). As a result, much of the original code/comments/documentation/tests are directly taken or adapted from distrax (original distrax copyright available at end of README.) + +Current features include: + +- Probability distributions +- Bijectors + + +## Installation + +``` +git clone https://github.com/lockwo/distreqx.git +cd distreqx +pip install -e . +``` + +Requires Python 3.9+, JAX 0.4.11+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.0+. + +## Documentation + +Available at https://lockwo.github.io/distreqx/. + +## Quick example + +```python +from distreqx import distributions + +key = jax.random.PRNGKey(1234) +mu = jnp.array([-1., 0., 1.]) +sigma = jnp.array([0.1, 0.2, 0.3]) + +dist = distributions.MultivariateNormalDiag(mu, sigma) + +samples = dist_distrax.sample(key) + +print(dist_distrax.log_prob(samples)) +``` + +## Differences with Distrax + +- No official support/interoperability with TFP +- The concept of a batch dimension is dropped. If you want to operate on a batch, use `vmap` (note, this can be used in construction as well, e.g. [vmaping the construction](https://docs.kidger.site/equinox/tricks/#ensembling) of a `ScalarAffine`) +- Broader pytree enablement +- Strict [abstract/final](https://docs.kidger.site/equinox/pattern/) design pattern \ No newline at end of file diff --git a/docs/misc/faq.md b/docs/misc/faq.md index d511aa0..b4dba99 100644 --- a/docs/misc/faq.md +++ b/docs/misc/faq.md @@ -11,3 +11,7 @@ The simple answer to that question is "I tried". Distrax is a the product of a l ## Why use equinox? The `Jittable` class is basically an equinox module (if you squint) and while we could reimplement a custom Module class (like GPJax does), why reinvent the wheel? Equinox is actively being developed and should it become inactive is still possible to maintain. + +## What about flowjax? + +When I started this project, I was unaware of [flowjax](https://github.com/danielward27/flowjax). Although flowjax does provide a lot of advanced tooling for NFs and bijections, there are notable differences. `distreqx` is less specialized and provides a broader baseline set of tools (e.g. distributions). flowjax has more advanced NF tools. `distreqx` also adheres to an abstract/final design pattern from the development side. flowjax also approaches the concept of "transformed" distributions in a different manner. diff --git a/examples/01_vae.ipynb b/examples/01_vae.ipynb index 5d694ea..df2afe1 100644 --- a/examples/01_vae.ipynb +++ b/examples/01_vae.ipynb @@ -1,5 +1,15 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "fd0b1b93", + "metadata": {}, + "source": [ + "# Variational Autoencoder\n", + "\n", + "In this example we will be implementing a [variational autoencoder](https://arxiv.org/abs/1312.6114) using distreqx." + ] + }, { "cell_type": "code", "execution_count": 81, @@ -18,6 +28,14 @@ "from distreqx import distributions" ] }, + { + "cell_type": "markdown", + "id": "85dba433", + "metadata": {}, + "source": [ + "First, we need to create a standard small encoder and decoder module. The shapes are hard coded for the MNIST dataset we will be using." + ] + }, { "cell_type": "code", "execution_count": 82, @@ -63,6 +81,14 @@ " return logits" ] }, + { + "cell_type": "markdown", + "id": "5342d497", + "metadata": {}, + "source": [ + "Next we can construct the VAE object. It consists of an encoder and decoder, the encoder provides the mean and variance of the multivariate Gaussian prior. The output of the decoder represents the logits of a bernoulli distribution over the pixel space. Note that the `Independent` here is a bit of a legacy artifact. In general, `distreqx` encourages `vmap` based approaches to distributions and offloads any batching to the user. However, it is often possible to implicitly batch computations for certain disributions (sometimes even correctly). `Independent` is merely a helper that sums over dimensions, so even though we don't `vmap` over the bernoulli (like we often should), we can still sum over batch dimensions (since the event shape of a bernoulli is ())." + ] + }, { "cell_type": "code", "execution_count": 83, @@ -115,6 +141,14 @@ " return VAEOutput(variational_distrib, likelihood_distrib, image)" ] }, + { + "cell_type": "markdown", + "id": "162706d5", + "metadata": {}, + "source": [ + "Now we can train our model with the standard ELBO. Keep in mind, here we require some `vmap`ing over the distribution, since we now have an additional batch dimension (that we do not want to have `Independent` sum over)." + ] + }, { "cell_type": "code", "execution_count": 84, @@ -179,7 +213,8 @@ " prior_z = distributions.MultivariateNormalDiag(\n", " loc=jnp.zeros(latent_size), scale_diag=jnp.ones(latent_size)\n", " )\n", - " # we need to make surve to vmap over the distribution itself! (https://docs.kidger.site/equinox/tricks/#ensembling)\n", + " # we need to make surve to vmap over the distribution itself!\n", + " # see also: https://docs.kidger.site/equinox/tricks/#ensembling\n", " log_likelihood = eqx.filter_vmap(lambda x, y: x.log_prob(y))(\n", " outputs.likelihood_distrib, batch\n", " )\n", @@ -270,6 +305,14 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "b5abd45f", + "metadata": {}, + "source": [ + "For such a small latent space, we can visualize a nice representation of the output. " + ] + }, { "cell_type": "code", "execution_count": 91, diff --git a/examples/02_mixture_models.ipynb b/examples/02_mixture_models.ipynb index be11064..1c52036 100644 --- a/examples/02_mixture_models.ipynb +++ b/examples/02_mixture_models.ipynb @@ -1,5 +1,15 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "13775a10", + "metadata": {}, + "source": [ + "# Mixture Models\n", + "\n", + "In this tutorial, we will look at Gaussian and Bernoulli mixture models. These mixture models are defined by mixing distribution (categorical) which is responsible for defining the distribution over the component distributions. We will train these models to generatively model MNIST data sets. We can train them with both expectation maximization (EM) or gradient descent based approaches." + ] + }, { "cell_type": "code", "execution_count": 1, @@ -142,6 +152,14 @@ " plt.show()" ] }, + { + "cell_type": "markdown", + "id": "6e57e4a6", + "metadata": {}, + "source": [ + "Here we have some manual updates, this is in general not necessary, but can be helpful to have more direct control over the parameters (especially given the nature of modules to be very deep)." + ] + }, { "cell_type": "code", "execution_count": 5, @@ -237,6 +255,14 @@ " model.plot(5, 4)" ] }, + { + "cell_type": "markdown", + "id": "234c9de5", + "metadata": {}, + "source": [ + "Now lets plot the EM loss and the final generated images for the GMM." + ] + }, { "cell_type": "code", "execution_count": 6, @@ -275,6 +301,14 @@ "model.plot(4, 5)" ] }, + { + "cell_type": "markdown", + "id": "878b606d", + "metadata": {}, + "source": [ + "Now we can repeat the process, but with SGD this time instead of EM. Notice the different failure mode? Difficulties in training mixtures models with SGD are well known (and there are many variants of EM to help overcome these failure modes)." + ] + }, { "cell_type": "code", "execution_count": 7, @@ -434,7 +468,9 @@ "id": "21921308", "metadata": {}, "source": [ - "## Bernoulli Mixture Models" + "## Bernoulli Mixture Models\n", + "\n", + "Now we can repeat the exact process as before, but with the component distribution being a Bernoulli. Once you've completed this tutorial, try it with a different distribution and see how it works!" ] }, { diff --git a/examples/03_normalizing_flow.ipynb b/examples/03_normalizing_flow.ipynb index 6d99c32..ab28092 100644 --- a/examples/03_normalizing_flow.ipynb +++ b/examples/03_normalizing_flow.ipynb @@ -1,5 +1,15 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "9397982b", + "metadata": {}, + "source": [ + "# Normalizing Flows\n", + "\n", + "In this tutorial, adapted from: https://gebob19.github.io/normalizing-flows/, we will implement a simple [RealNVP](https://arxiv.org/abs/1605.08803) normalizing flow. Normalizing flows are a class of generative models which are advantageous due to their explicit representation of densities and likelihoods, but come at a cost of requiring computable jacobian determinants and invertible layers. For an introduction to normalizing flows, see https://arxiv.org/abs/1912.02762." + ] + }, { "cell_type": "code", "execution_count": 69, @@ -23,7 +33,7 @@ "id": "267074ae", "metadata": {}, "source": [ - "adapted from: https://gebob19.github.io/normalizing-flows/" + "Let's define our simple dataset." ] }, { @@ -64,6 +74,14 @@ "plt.ylim(ylim)" ] }, + { + "cell_type": "markdown", + "id": "ce5c2cdc", + "metadata": {}, + "source": [ + "Now we can program our custom bijector." + ] + }, { "cell_type": "code", "execution_count": 71, @@ -152,6 +170,14 @@ " return type(other) is RNVP" ] }, + { + "cell_type": "markdown", + "id": "9b5c4c64", + "metadata": {}, + "source": [ + "Since we want to stack these together, we can use a chain bijector to accomplish this." + ] + }, { "cell_type": "code", "execution_count": 72, @@ -165,6 +191,14 @@ "bijector_chain = bijectors.Chain([RNVP(2, 1, i % 2, keys[i], 600) for i in range(n)])" ] }, + { + "cell_type": "markdown", + "id": "4a31df9b", + "metadata": {}, + "source": [ + "Flows map p(x) -> p(z) via a function F (samples are generated via F^-1(z)). In general, p(z) is chosen to have some tractable form for sampling and calculating log probabilities. A common choice is Gaussian, which we go with here." + ] + }, { "cell_type": "code", "execution_count": 73, @@ -177,6 +211,14 @@ "base_distribution_log_prob = eqx.filter_vmap(base_distribution.log_prob)" ] }, + { + "cell_type": "markdown", + "id": "5e65a110", + "metadata": {}, + "source": [ + "Here we plot the initial, untrained, samples." + ] + }, { "cell_type": "code", "execution_count": 74, @@ -296,6 +338,14 @@ "plt.show()" ] }, + { + "cell_type": "markdown", + "id": "0111f495", + "metadata": {}, + "source": [ + "After training we can plot both F(x) (to see where the true data ends up in our sampled space) and F^-1(z) to generate new samples." + ] + }, { "cell_type": "code", "execution_count": 77, diff --git a/pyproject.toml b/pyproject.toml index 76a5bdc..5120a7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,8 +28,8 @@ urls = {repository = "https://github.com/lockwo/distreqx"} dependencies = ["jax>=0.4.11", "jaxtyping>=0.2.20", "equinox>=0.11.0"] [build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" +requires = ["setuptools >= 61.0"] +build-backend = "setuptools.build_meta" [tool.hatch.build] include = ["distreqx/*"]