Expand docs (#30)
* add docs

* setuptools

* readme
lockwo authored Sep 15, 2024
1 parent 3a0d20f commit 4e4b106
Showing 7 changed files with 210 additions and 8 deletions.
28 changes: 26 additions & 2 deletions README.md
@@ -3,7 +3,7 @@

distreqx (pronounced "dist-rex") is a [JAX](https://github.com/google/jax)-based library providing implementations of distributions, bijectors, and tools for statistical and probabilistic machine learning, with all the benefits of JAX (native GPU/TPU acceleration, differentiability, vectorization, distributed workloads, XLA compilation, etc.).

The origin of this repo is a reimplementation of [distrax](https://github.com/google-deepmind/distrax), (which is a subset of [TensorFlow Probability (TFP)](https://github.com/tensorflow/probability), with some new features and emphasis on jax compatibility) using [equinox](https://github.com/patrick-kidger/equinox). As a result, much of the original code/comments/documentation/tests are directly taken or adapted from distrax (original distrax copyright available at end of README.)
The origin of this package is a reimplementation of [distrax](https://github.com/google-deepmind/distrax) (itself a subset of [TensorFlow Probability (TFP)](https://github.com/tensorflow/probability), with some new features and an emphasis on JAX compatibility) using [equinox](https://github.com/patrick-kidger/equinox). As a result, much of the original code/comments/documentation/tests is directly taken or adapted from distrax (the original distrax copyright is available at the end of the README).

Current features include:

@@ -13,6 +13,12 @@ Current features include:

## Installation

```
pip install distreqx
```

or

```
git clone https://github.com/lockwo/distreqx.git
cd distreqx
@@ -28,7 +34,17 @@ Available at https://lockwo.github.io/distreqx/.
## Quick example

```python
import jax
import jax.numpy as jnp

from distreqx import distributions

key = jax.random.PRNGKey(1234)
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])

dist = distributions.MultivariateNormalDiag(mu, sigma)

samples = dist.sample(key)

print(dist.log_prob(samples))
```
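
Because `distreqx` distributions are equinox modules (pytrees), they compose directly with JAX transformations. Below is a minimal, hedged sketch of differentiating and jit-compiling a log-density with respect to the distribution's parameters; it assumes only the `MultivariateNormalDiag` constructor shown above plus equinox's `filter_jit`/`filter_grad`, and is illustrative rather than canonical.

```python
import equinox as eqx
import jax.numpy as jnp

from distreqx import distributions

dist = distributions.MultivariateNormalDiag(jnp.zeros(3), jnp.ones(3))
x = jnp.array([0.5, -0.2, 1.0])

# Gradients of the negative log-density with respect to the distribution's
# parameters (its array leaves), with the whole computation jit-compiled.
@eqx.filter_jit
@eqx.filter_grad
def neg_log_prob(d, value):
    return -d.log_prob(value)

param_grads = neg_log_prob(dist, x)
```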

## Differences with Distrax
@@ -43,6 +59,12 @@ from distreqx import
If you found this library useful in academic research, please cite:

```bibtex
@software{lockwood2024distreqx,
  title = {distreqx: Distributions and Bijectors in Jax},
  author = {Owen Lockwood},
  url = {https://github.com/lockwo/distreqx},
  doi = {[tbd]},
}
```

(Also consider starring the project on GitHub.)
@@ -51,6 +73,8 @@ If you found this library useful in academic research, please cite:

[GPJax](https://github.com/JaxGaussianProcesses/GPJax): Gaussian processes in JAX.

[flowjax](https://github.com/danielward27/flowjax): Normalizing flows in JAX.

[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.

[Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
47 changes: 46 additions & 1 deletion docs/index.md
@@ -1,3 +1,48 @@
# distreqx

Distrax + equinox
distreqx (pronounced "dist-rex") is a [JAX](https://github.com/google/jax)-based library providing implementations of distributions, bijectors, and tools for statistical and probabilistic machine learning, with all the benefits of JAX (native GPU/TPU acceleration, differentiability, vectorization, distributed workloads, XLA compilation, etc.).

The origin of this package is a reimplementation of [distrax](https://github.com/google-deepmind/distrax) (itself a subset of [TensorFlow Probability (TFP)](https://github.com/tensorflow/probability), with some new features and an emphasis on JAX compatibility) using [equinox](https://github.com/patrick-kidger/equinox). As a result, much of the original code/comments/documentation/tests is directly taken or adapted from distrax (the original distrax copyright is available at the end of the README).

Current features include:

- Probability distributions
- Bijectors


## Installation

```
git clone https://github.com/lockwo/distreqx.git
cd distreqx
pip install -e .
```

Requires Python 3.9+, JAX 0.4.11+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.0+.

## Documentation

Available at https://lockwo.github.io/distreqx/.

## Quick example

```python
import jax
import jax.numpy as jnp

from distreqx import distributions

key = jax.random.PRNGKey(1234)
mu = jnp.array([-1., 0., 1.])
sigma = jnp.array([0.1, 0.2, 0.3])

dist = distributions.MultivariateNormalDiag(mu, sigma)

samples = dist.sample(key)

print(dist.log_prob(samples))
```

## Differences with Distrax

- No official support/interoperability with TFP
- The concept of a batch dimension is dropped. If you want to operate on a batch, use `vmap` (note, this can also be used at construction time, e.g. [vmapping the construction](https://docs.kidger.site/equinox/tricks/#ensembling) of a `ScalarAffine`); see the sketch after this list
- Broader pytree enablement
- Strict [abstract/final](https://docs.kidger.site/equinox/pattern/) design pattern
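
As a rough illustration of the `vmap`-first design mentioned above, here is a hedged sketch of building a "batch" of distributions by vmapping the constructor (the equinox ensembling trick) and then vmapping the evaluation. The constructor arguments follow the quick example; the exact idiom may differ.

```python
import equinox as eqx
import jax.numpy as jnp

from distreqx import distributions

locs = jnp.zeros((4, 3))
scales = jnp.ones((4, 3))

# Build 4 diagonal Gaussians at once by vmapping the construction itself.
batched_dist = eqx.filter_vmap(
    lambda loc, scale: distributions.MultivariateNormalDiag(loc, scale)
)(locs, scales)

# Evaluate log_prob per-distribution by vmapping over the modules and the data.
xs = jnp.ones((4, 3))
log_probs = eqx.filter_vmap(lambda d, x: d.log_prob(x))(batched_dist, xs)  # shape (4,)
```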
4 changes: 4 additions & 0 deletions docs/misc/faq.md
@@ -11,3 +11,7 @@ The simple answer to that question is "I tried". Distrax is the product of a l
## Why use equinox?

The `Jittable` class is basically an equinox module (if you squint), and while we could reimplement a custom Module class (like GPJax does), why reinvent the wheel? Equinox is actively developed, and should it ever become inactive, it would still be possible to maintain.

## What about flowjax?

When I started this project, I was unaware of [flowjax](https://github.com/danielward27/flowjax). Although flowjax provides a lot of advanced tooling for normalizing flows (NFs) and bijections, there are notable differences: `distreqx` is less specialized and provides a broader baseline set of tools (e.g. distributions), whereas flowjax has more advanced NF tooling. `distreqx` also adheres to an abstract/final design pattern on the development side, and flowjax approaches the concept of "transformed" distributions in a different manner.
45 changes: 44 additions & 1 deletion examples/01_vae.ipynb
@@ -1,5 +1,15 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "fd0b1b93",
"metadata": {},
"source": [
"# Variational Autoencoder\n",
"\n",
"In this example we will be implementing a [variational autoencoder](https://arxiv.org/abs/1312.6114) using distreqx."
]
},
{
"cell_type": "code",
"execution_count": 81,
@@ -18,6 +28,14 @@
"from distreqx import distributions"
]
},
{
"cell_type": "markdown",
"id": "85dba433",
"metadata": {},
"source": [
"First, we need to create a standard small encoder and decoder module. The shapes are hard coded for the MNIST dataset we will be using."
]
},
{
"cell_type": "code",
"execution_count": 82,
@@ -63,6 +81,14 @@
" return logits"
]
},
{
"cell_type": "markdown",
"id": "5342d497",
"metadata": {},
"source": [
"Next we can construct the VAE object. It consists of an encoder and decoder, the encoder provides the mean and variance of the multivariate Gaussian prior. The output of the decoder represents the logits of a bernoulli distribution over the pixel space. Note that the `Independent` here is a bit of a legacy artifact. In general, `distreqx` encourages `vmap` based approaches to distributions and offloads any batching to the user. However, it is often possible to implicitly batch computations for certain disributions (sometimes even correctly). `Independent` is merely a helper that sums over dimensions, so even though we don't `vmap` over the bernoulli (like we often should), we can still sum over batch dimensions (since the event shape of a bernoulli is ())."
]
},
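
To make the `Independent` remark above concrete without leaning on any particular constructor signature, here is a small numerical sketch of what it amounts to in this model: summing elementwise Bernoulli log-probabilities so each image gets a single log-probability. The helper below is purely illustrative and is not part of distreqx.

```python
import jax
import jax.numpy as jnp

def bernoulli_log_prob(x, logits):
    # Stable log p(x) for x in {0, 1} under a Bernoulli parameterised by logits.
    return x * jax.nn.log_sigmoid(logits) + (1 - x) * jax.nn.log_sigmoid(-logits)

logits = jnp.zeros((28 * 28,))   # decoder output for one flattened image
image = jnp.zeros((28 * 28,))    # one binarised image

# "Independent" behaviour: sum the per-pixel log-probabilities into one scalar.
image_log_prob = jnp.sum(bernoulli_log_prob(image, logits))
```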
{
"cell_type": "code",
"execution_count": 83,
@@ -115,6 +141,14 @@
" return VAEOutput(variational_distrib, likelihood_distrib, image)"
]
},
{
"cell_type": "markdown",
"id": "162706d5",
"metadata": {},
"source": [
"Now we can train our model with the standard ELBO. Keep in mind, here we require some `vmap`ing over the distribution, since we now have an additional batch dimension (that we do not want to have `Independent` sum over)."
]
},
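
As a reference for the training step that follows, here is a hedged sketch of the per-example ELBO being maximised: the reconstruction log-likelihood minus the KL divergence from the diagonal-Gaussian variational distribution to the standard normal prior. The notebook computes these terms with distreqx distributions; the closed-form KL below is standard and only for orientation.

```python
import jax.numpy as jnp

def kl_diag_gaussian_to_standard_normal(mu, sigma):
    # KL( N(mu, diag(sigma^2)) || N(0, I) ), in closed form.
    return 0.5 * jnp.sum(sigma**2 + mu**2 - 1.0 - 2.0 * jnp.log(sigma))

def elbo(reconstruction_log_likelihood, mu, sigma):
    return reconstruction_log_likelihood - kl_diag_gaussian_to_standard_normal(mu, sigma)

print(elbo(-100.0, jnp.zeros(8), jnp.ones(8)))  # KL term is 0 here, so this prints -100.0
```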
{
"cell_type": "code",
"execution_count": 84,
@@ -179,7 +213,8 @@
" prior_z = distributions.MultivariateNormalDiag(\n",
" loc=jnp.zeros(latent_size), scale_diag=jnp.ones(latent_size)\n",
" )\n",
" # we need to make surve to vmap over the distribution itself! (https://docs.kidger.site/equinox/tricks/#ensembling)\n",
" # we need to make surve to vmap over the distribution itself!\n",
" # see also: https://docs.kidger.site/equinox/tricks/#ensembling\n",
" log_likelihood = eqx.filter_vmap(lambda x, y: x.log_prob(y))(\n",
" outputs.likelihood_distrib, batch\n",
" )\n",
@@ -270,6 +305,14 @@
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "b5abd45f",
"metadata": {},
"source": [
"For such a small latent space, we can visualize a nice representation of the output. "
]
},
{
"cell_type": "code",
"execution_count": 91,
38 changes: 37 additions & 1 deletion examples/02_mixture_models.ipynb
@@ -1,5 +1,15 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "13775a10",
"metadata": {},
"source": [
"# Mixture Models\n",
"\n",
"In this tutorial, we will look at Gaussian and Bernoulli mixture models. These mixture models are defined by mixing distribution (categorical) which is responsible for defining the distribution over the component distributions. We will train these models to generatively model MNIST data sets. We can train them with both expectation maximization (EM) or gradient descent based approaches."
]
},
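
For orientation, here is a hedged, self-contained sketch of the quantity both EM and the gradient-based training below are concerned with: the mixture log-likelihood log p(x) = logsumexp_k(log pi_k + log p_k(x)), written out with 1-D Gaussian components. The notebook itself uses distreqx distributions rather than these hand-rolled densities.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def gaussian_log_prob(x, mu, sigma):
    return -0.5 * ((x - mu) / sigma) ** 2 - jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi)

def mixture_log_prob(x, log_weights, mus, sigmas):
    # log p(x) = logsumexp_k( log pi_k + log N(x | mu_k, sigma_k) )
    return logsumexp(log_weights + gaussian_log_prob(x, mus, sigmas))

log_weights = jnp.log(jnp.array([0.3, 0.7]))
print(mixture_log_prob(0.5, log_weights, jnp.array([-1.0, 1.0]), jnp.array([0.5, 0.5])))
```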
{
"cell_type": "code",
"execution_count": 1,
@@ -142,6 +152,14 @@
" plt.show()"
]
},
{
"cell_type": "markdown",
"id": "6e57e4a6",
"metadata": {},
"source": [
"Here we have some manual updates, this is in general not necessary, but can be helpful to have more direct control over the parameters (especially given the nature of modules to be very deep)."
]
},
{
"cell_type": "code",
"execution_count": 5,
@@ -237,6 +255,14 @@
" model.plot(5, 4)"
]
},
{
"cell_type": "markdown",
"id": "234c9de5",
"metadata": {},
"source": [
"Now lets plot the EM loss and the final generated images for the GMM."
]
},
{
"cell_type": "code",
"execution_count": 6,
@@ -275,6 +301,14 @@
"model.plot(4, 5)"
]
},
{
"cell_type": "markdown",
"id": "878b606d",
"metadata": {},
"source": [
"Now we can repeat the process, but with SGD this time instead of EM. Notice the different failure mode? Difficulties in training mixtures models with SGD are well known (and there are many variants of EM to help overcome these failure modes)."
]
},
{
"cell_type": "code",
"execution_count": 7,
@@ -434,7 +468,9 @@
"id": "21921308",
"metadata": {},
"source": [
"## Bernoulli Mixture Models"
"## Bernoulli Mixture Models\n",
"\n",
"Now we can repeat the exact process as before, but with the component distribution being a Bernoulli. Once you've completed this tutorial, try it with a different distribution and see how it works!"
]
},
{
52 changes: 51 additions & 1 deletion examples/03_normalizing_flow.ipynb
@@ -1,5 +1,15 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "9397982b",
"metadata": {},
"source": [
"# Normalizing Flows\n",
"\n",
"In this tutorial, adapted from: https://gebob19.github.io/normalizing-flows/, we will implement a simple [RealNVP](https://arxiv.org/abs/1605.08803) normalizing flow. Normalizing flows are a class of generative models which are advantageous due to their explicit representation of densities and likelihoods, but come at a cost of requiring computable jacobian determinants and invertible layers. For an introduction to normalizing flows, see https://arxiv.org/abs/1912.02762."
]
},
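
As background for the RealNVP construction below, here is a hedged toy sketch of the change-of-variables identity a flow relies on, log p_X(x) = log p_Z(F(x)) + log|det dF/dx|, using a scalar affine map so the Jacobian determinant is trivial; nothing here is distreqx-specific.

```python
import jax.numpy as jnp

a, b = 2.0, 0.5
forward = lambda x: a * x + b  # toy invertible map F

def flow_log_prob(x):
    z = forward(x)
    base_log_prob = -0.5 * z**2 - 0.5 * jnp.log(2 * jnp.pi)  # standard normal base p(z)
    log_det_jacobian = jnp.log(jnp.abs(a))                   # log |dF/dx|
    return base_log_prob + log_det_jacobian

print(flow_log_prob(jnp.array(0.1)))
```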
{
"cell_type": "code",
"execution_count": 69,
@@ -23,7 +33,7 @@
"id": "267074ae",
"metadata": {},
"source": [
"adapted from: https://gebob19.github.io/normalizing-flows/"
"Let's define our simple dataset."
]
},
{
@@ -64,6 +74,14 @@
"plt.ylim(ylim)"
]
},
{
"cell_type": "markdown",
"id": "ce5c2cdc",
"metadata": {},
"source": [
"Now we can program our custom bijector."
]
},
{
"cell_type": "code",
"execution_count": 71,
@@ -152,6 +170,14 @@
" return type(other) is RNVP"
]
},
{
"cell_type": "markdown",
"id": "9b5c4c64",
"metadata": {},
"source": [
"Since we want to stack these together, we can use a chain bijector to accomplish this."
]
},
{
"cell_type": "code",
"execution_count": 72,
@@ -165,6 +191,14 @@
"bijector_chain = bijectors.Chain([RNVP(2, 1, i % 2, keys[i], 600) for i in range(n)])"
]
},
{
"cell_type": "markdown",
"id": "4a31df9b",
"metadata": {},
"source": [
"Flows map p(x) -> p(z) via a function F (samples are generated via F^-1(z)). In general, p(z) is chosen to have some tractable form for sampling and calculating log probabilities. A common choice is Gaussian, which we go with here."
]
},
{
"cell_type": "code",
"execution_count": 73,
@@ -177,6 +211,14 @@
"base_distribution_log_prob = eqx.filter_vmap(base_distribution.log_prob)"
]
},
{
"cell_type": "markdown",
"id": "5e65a110",
"metadata": {},
"source": [
"Here we plot the initial, untrained, samples."
]
},
{
"cell_type": "code",
"execution_count": 74,
@@ -296,6 +338,14 @@
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "0111f495",
"metadata": {},
"source": [
"After training we can plot both F(x) (to see where the true data ends up in our sampled space) and F^-1(z) to generate new samples."
]
},
{
"cell_type": "code",
"execution_count": 77,
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -28,8 +34,8 @@ urls = {repository = "https://github.com/lockwo/distreqx"}
dependencies = ["jax>=0.4.11", "jaxtyping>=0.2.20", "equinox>=0.11.0"]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
requires = ["setuptools >= 61.0"]
build-backend = "setuptools.build_meta"

[tool.hatch.build]
include = ["distreqx/*"]
