From be48de1a45754a1ac8822ad9752898778838a782 Mon Sep 17 00:00:00 2001
From: Owen Lockwood <42878312+lockwo@users.noreply.github.com>
Date: Wed, 1 May 2024 10:37:15 -0600
Subject: [PATCH] documentation improvements (#3)
---
.github/workflows/deploy-docs.yml | 2 +-
README.md | 21 +++----
distreqx/bijectors/_bijector.py | 12 ++--
distreqx/bijectors/block.py | 1 +
distreqx/bijectors/chain.py | 1 +
distreqx/bijectors/shift.py | 5 +-
distreqx/bijectors/tanh.py | 3 +-
distreqx/distributions/_distribution.py | 16 ++---
distreqx/distributions/_transformed.py | 66 +++++++++------------
distreqx/distributions/mvn_from_bijector.py | 10 ++--
distreqx/utils/math.py | 60 +++++++++----------
docs/api/bijectors/_bijector.md | 9 ++-
docs/api/bijectors/_linear.md | 2 +-
docs/api/bijectors/tanh.md | 3 +-
docs/api/distributions/_distribution.md | 3 +-
docs/api/distributions/_transformed.md | 1 +
docs/api/distributions/mvn_from_bijector.md | 1 +
docs/api/distributions/normal.md | 1 +
docs/misc/faq.md | 13 ++++
mkdocs.yml | 6 +-
20 files changed, 128 insertions(+), 108 deletions(-)
create mode 100644 docs/misc/faq.md
diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml
index 90ab62f..eb898f7 100644
--- a/.github/workflows/deploy-docs.yml
+++ b/.github/workflows/deploy-docs.yml
@@ -26,7 +26,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install .
python -m pip install -r docs/requirements.txt
-
+ # https://github.com/mhausenblas/mkdocs-deploy-gh-pages/blob/master/action.sh
- name: Build docs
run: |
mkdocs build
diff --git a/README.md b/README.md
index 2a25647..5e03149 100644
--- a/README.md
+++ b/README.md
@@ -1,11 +1,11 @@
distreqx
Distrax + Equinox = distreqx. Easy Pytree probability distributions and bijectors.
-distreqx is a [JAX](https://github.com/google/jax)-based library providing implementations of a subset of [TensorFlow Probability (TFP)](https://github.com/tensorflow/probability), with some new features and emphasis on jax compatibility.
+distreqx (pronounced "dist-rex") is a [JAX](https://github.com/google/jax)-based library providing implementations of distributions, bijectors, and tools for statistical and probabilistic machine learning, with all the benefits of JAX (native GPU/TPU acceleration, differentiability, vectorization, distributed workloads, XLA compilation, etc.).
-This is a largely as reimplementation of [distrax](https://github.com/google-deepmind/distrax) using [equinox](https://github.com/patrick-kidger/equinox), much of the code/comments/documentation/tests are directly taken or adapted from distrax so all credit to the DeepMind team.
+The origin of this repo is a reimplementation of [distrax](https://github.com/google-deepmind/distrax) (which is a subset of [TensorFlow Probability (TFP)](https://github.com/tensorflow/probability), with some new features and an emphasis on JAX compatibility) using [equinox](https://github.com/patrick-kidger/equinox). As a result, much of the original code/comments/documentation/tests is directly taken or adapted from distrax (the original distrax copyright is available at the end of the README).
-Features include:
+Current features include:
- Probability distributions
- Bijectors
@@ -14,14 +14,16 @@ Features include:
## Installation
```
-pip install distreqx
+git clone https://github.com/lockwo/distreqx.git
+cd distreqx
+pip install -e .
```
-Requires Python 3.9+, JAX 0.4.13+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.0+.
+Requires Python 3.9+, JAX 0.4.11+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.0+.
## Documentation
-Available at .
+Available at https://lockwo.github.io/distreqx/.
## Quick example
@@ -31,9 +33,9 @@ from distreqx import
## Differences with Distrax
-- No support for TFP
-- Broader pytree support
+- No official support/interoperability with TFP
+- The concept of a batch dimension is dropped. If you want to operate on a batch, use `vmap` (note that this can be used in construction as well, e.g. [vmapping the construction](https://docs.kidger.site/equinox/tricks/#ensembling) of a `ScalarAffine`; see the sketch below)
+- Broader pytree support (e.g. bijectors operate directly on pytrees)
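+
+As a minimal sketch of vmapped construction (assuming `ScalarAffine` is exported from `distreqx.bijectors` and accepts `shift` and `scale` arguments, which may differ slightly from the actual signature):
+
+```python
+import equinox as eqx
+import jax.numpy as jnp
+
+from distreqx.bijectors import ScalarAffine
+
+shifts = jnp.linspace(0.0, 1.0, 8)
+scales = jnp.full(8, 2.0)
+
+# Build 8 ScalarAffine bijectors at once by vmapping the constructor,
+# following the equinox ensembling trick linked above.
+bijectors = eqx.filter_vmap(lambda s, c: ScalarAffine(shift=s, scale=c))(shifts, scales)
+
+# Apply them elementwise to a batch of inputs with another vmap.
+ys = eqx.filter_vmap(lambda b, x: b.forward(x))(bijectors, jnp.zeros(8))
+```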
## Citation
@@ -46,8 +48,7 @@ If you found this library useful in academic research, please cite:
## See also: other libraries in the JAX ecosystem
-[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
-[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.
+[GPJax](https://github.com/JaxGaussianProcesses/GPJax): Gaussian processes in JAX.
[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
[Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
diff --git a/distreqx/bijectors/_bijector.py b/distreqx/bijectors/_bijector.py
index cfb392b..3cbdf26 100644
--- a/distreqx/bijectors/_bijector.py
+++ b/distreqx/bijectors/_bijector.py
@@ -63,34 +63,34 @@ def __init__(
self._is_constant_log_det = is_constant_log_det
def forward(self, x: PyTree) -> PyTree:
- """Computes y = f(x)."""
+ r"""Computes $y = f(x)$."""
y, _ = self.forward_and_log_det(x)
return y
def inverse(self, y: PyTree) -> PyTree:
- """Computes x = f^{-1}(y)."""
+ r"""Computes $x = f^{-1}(y)$."""
x, _ = self.inverse_and_log_det(y)
return x
def forward_log_det_jacobian(self, x: PyTree) -> PyTree:
- """Computes log|det J(f)(x)|."""
+ r"""Computes $\log|\det J(f)(x)|$."""
_, logdet = self.forward_and_log_det(x)
return logdet
def inverse_log_det_jacobian(self, y: PyTree) -> PyTree:
- """Computes log|det J(f^{-1})(y)|."""
+ r"""Computes $\log|\det J(f^{-1})(y)|$."""
_, logdet = self.inverse_and_log_det(y)
return logdet
@abstractmethod
def forward_and_log_det(self, x: PyTree) -> Tuple[PyTree, PyTree]:
- """Computes y = f(x) and log|det J(f)(x)|."""
+ r"""Computes $y = f(x)$ and $\log|\det J(f)(x)|$."""
raise NotImplementedError(
f"Bijector {self.name} does not implement `forward_and_log_det`."
)
def inverse_and_log_det(self, y: Array) -> Tuple[PyTree, PyTree]:
- """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
+ r"""Computes $x = f^{-1}(y)$ and $\log|\det J(f^{-1})(y)|$."""
raise NotImplementedError(
f"Bijector {self.name} does not implement `inverse_and_log_det`."
)
diff --git a/distreqx/bijectors/block.py b/distreqx/bijectors/block.py
index 62ddf39..bbc4749 100644
--- a/distreqx/bijectors/block.py
+++ b/distreqx/bijectors/block.py
@@ -38,6 +38,7 @@ def __init__(self, bijector: AbstractBijector, ndims: int):
"""Initializes a Block.
**Arguments:**
+
- `bijector`: the bijector to be promoted to a block bijector. It can be a
distreqx bijector or a callable to be wrapped by `Lambda`.
- `ndims`: number of dimensions to promote to event dimensions.
diff --git a/distreqx/bijectors/chain.py b/distreqx/bijectors/chain.py
index 0ed7eea..3d2a98f 100644
--- a/distreqx/bijectors/chain.py
+++ b/distreqx/bijectors/chain.py
@@ -36,6 +36,7 @@ def __init__(self, bijectors: Sequence[AbstractBijector]):
"""Initializes a Chain bijector.
**Arguments:**
+
- `bijectors`: a sequence of bijectors to be composed into one. Each bijector
can be a distreqx bijector or a callable to be wrapped
by `Lambda`. The sequence must contain at least one bijector.
diff --git a/distreqx/bijectors/shift.py b/distreqx/bijectors/shift.py
index 68d9533..d0d68d9 100644
--- a/distreqx/bijectors/shift.py
+++ b/distreqx/bijectors/shift.py
@@ -22,8 +22,9 @@ class Shift(ScalarAffine):
def __init__(self, shift: Array):
"""Initializes a `Shift` bijector.
- Args:
- shift: the bijector's shift parameter. Can also be batched.
+ **Arguments:**
+
+ - `shift`: the bijector's shift parameter.
"""
super().__init__(shift=shift)
diff --git a/distreqx/bijectors/tanh.py b/distreqx/bijectors/tanh.py
index 4292e70..a225380 100644
--- a/distreqx/bijectors/tanh.py
+++ b/distreqx/bijectors/tanh.py
@@ -29,7 +29,8 @@ class Tanh(AbstractBijector):
instead of `sample` followed by `log_prob`.
"""
- def __init__(self):
+ def __init__(self) -> None:
+ """Initializes the Tanh bijector."""
super().__init__()
def forward_log_det_jacobian(self, x: Array) -> Array:
diff --git a/distreqx/distributions/_distribution.py b/distreqx/distributions/_distribution.py
index 29b36e0..9041c62 100644
--- a/distreqx/distributions/_distribution.py
+++ b/distreqx/distributions/_distribution.py
@@ -32,7 +32,7 @@ def sample_and_log_prob(
**Returns:**
- A tuple of a sample and their log probs.
+ - A tuple of a sample and its log probs.
"""
samples = self.sample(key)
log_prob = self.log_prob(samples)
@@ -48,7 +48,7 @@ def log_prob(self, value: PyTree[Array]) -> PyTree[Array]:
**Returns:**
- The log probability log P(value).
+ - The log probability log P(value).
"""
raise NotImplementedError
@@ -78,7 +78,7 @@ def prob(self, value: PyTree[Array]) -> PyTree[Array]:
**Returns:**
- The probability P(value).
+ - The probability P(value).
"""
return jnp.exp(self.log_prob(value))
@@ -109,7 +109,7 @@ def cdf(self, value: PyTree[Array]) -> PyTree[Array]:
**Returns:**
- The CDF evaluated at value, i.e. P[X <= value].
+ - The CDF evaluated at value, i.e. P[X <= value].
"""
return jnp.exp(self.log_cdf(value))
@@ -127,7 +127,7 @@ def survival_function(self, value: PyTree[Array]) -> PyTree[Array]:
**Returns:**
- The survival function evaluated at `value`, i.e. P[X > value]
+ - The survival function evaluated at `value`, i.e. P[X > value]
"""
return 1.0 - self.cdf(value)
@@ -145,7 +145,7 @@ def log_survival_function(self, value: PyTree[Array]) -> PyTree[Array]:
**Returns:**
- The log of the survival function evaluated at `value`, i.e.
+ - The log of the survival function evaluated at `value`, i.e.
log P[X > value]
"""
return jnp.log1p(-self.cdf(value))
@@ -188,7 +188,7 @@ def kl_divergence(self, other_dist, **kwargs) -> PyTree[Array]:
**Returns:**
- The KL divergence `KL(self || other_dist)`.
+ - The KL divergence `KL(self || other_dist)`.
"""
raise NotImplementedError(
f"Distribution `{self.name}` does not implement `kl_divergence`."
@@ -204,6 +204,6 @@ def cross_entropy(self, other_dist, **kwargs) -> Array:
**Returns:**
- The cross entropy `H(self || other_dist)`.
+ - The cross entropy `H(self || other_dist)`.
"""
return self.kl_divergence(other_dist, **kwargs) + self.entropy()
diff --git a/distreqx/distributions/_transformed.py b/distreqx/distributions/_transformed.py
index e80d98c..a63bc14 100644
--- a/distreqx/distributions/_transformed.py
+++ b/distreqx/distributions/_transformed.py
@@ -111,7 +111,7 @@ def sample_and_log_prob(self, key: PRNGKeyArray) -> Tuple[Array, Array]:
**Returns:**
- A tuple of a sample and its log probs.
+ - A tuple of a sample and its log probs.
"""
x, lp_x = self.distribution.sample_and_log_prob(key)
y, fldj = self.bijector.forward_and_log_det(x)
@@ -153,7 +153,7 @@ def entropy(self, input_hint: Optional[Array] = None) -> Array:
**Returns:**
- The entropy of the distribution.
+ - The entropy of the distribution.
**Raises:**
@@ -175,16 +175,37 @@ def entropy(self, input_hint: Optional[Array] = None) -> Array:
)
def kl_divergence(self, other_dist, **kwargs) -> Array:
- """Calculates the KL divergence to another distribution.
+ """Obtains the KL divergence between two Transformed distributions.
+
+ This computes the KL divergence between two Transformed distributions with the
+ same bijector. If the two Transformed distributions do not have the same
+ bijector, an error is raised. To determine if the bijectors are equal, this
+ method proceeds as follows:
+ - If both bijectors are the same instance of a distreqx bijector, then they are
+ declared equal.
+ - If not the same instance, we check if they are equal according to their
+ `same_as` predicate.
+ - Otherwise, the string representation of the Jaxpr of the `forward` method
+ of each bijector is compared. If both string representations are equal, the
+ bijectors are declared equal.
+ - Otherwise, the bijectors cannot be guaranteed to be equal and an error is
+ raised.
**Arguments:**
- - `other_dist`: A compatible disteqx distribution.
- - `kwargs`: Additional kwargs, can accept an `input_hint`.
+ - `other_dist`: A Transformed distribution.
+ - `input_hint`: keyword argument, an example sample from the base distribution,
+ used to trace the `forward` method. If not specified, it is computed using
+ a zero array of the shape and dtype of a sample from the base distribution.
**Returns:**
- The KL divergence `KL(self || other_dist)`.
+ - `KL(self || other_dist)`.
+
+ **Raises:**
+
+ - `NotImplementedError`: If bijectors are not known to be equal.
+ - `ValueError`: If the base distributions do not have the same `event_shape`.
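+
+ As a minimal sketch (assuming `Normal` and `Transformed` are exported from `distreqx.distributions` and `Tanh` from `distreqx.bijectors`; the exact import paths and constructor signatures are assumptions):
+
+ ```python
+ import jax.numpy as jnp
+ from distreqx.bijectors import Tanh
+ from distreqx.distributions import Normal, Transformed
+
+ # Both distributions share the same bijector instance, so the
+ # equality check described above succeeds without tracing `forward`.
+ bij = Tanh()
+ dist1 = Transformed(Normal(jnp.zeros(3), jnp.ones(3)), bij)
+ dist2 = Transformed(Normal(jnp.ones(3), jnp.ones(3)), bij)
+ kl = dist1.kl_divergence(dist2)
+ ```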
"""
return _kl_divergence_transformed_transformed(self, other_dist, **kwargs)
@@ -196,39 +217,6 @@ def _kl_divergence_transformed_transformed(
input_hint: Optional[Array] = None,
**unused_kwargs,
) -> Array:
- """Obtains the KL divergence between two Transformed distributions.
-
- This computes the KL divergence between two Transformed distributions with the
- same bijector. If the two Transformed distributions do not have the same
- bijector, an error is raised. To determine if the bijectors are equal, this
- method proceeds as follows:
- - If both bijectors are the same instance of a distreqx bijector, then they are
- declared equal.
- - If not the same instance, we check if they are equal according to their
- `same_as` predicate.
- - Otherwise, the string representation of the Jaxpr of the `forward` method
- of each bijector is compared. If both string representations are equal, the
- bijectors are declared equal.
- - Otherwise, the bijectors cannot be guaranteed to be equal and an error is
- raised.
-
- **Arguments:**
-
- - `dist1`: A Transformed distribution.
- - `dist2`: A Transformed distribution.
- - `input_hint`: an example sample from the base distribution, used to trace the
- `forward` method. If not specified, it is computed using a zero array of
- the shape and dtype of a sample from the base distribution.
-
- **Returns:**
-
- `KL(dist1 || dist2)`.
-
- **Raises:**
-
- - `NotImplementedError`: If bijectors are not known to be equal.
- - `ValueError`: If the base distributions do not have the same `event_shape`.
- """
if dist1.distribution.event_shape != dist2.distribution.event_shape:
raise ValueError(
f"The two base distributions do not have the same event shape: "
diff --git a/distreqx/distributions/mvn_from_bijector.py b/distreqx/distributions/mvn_from_bijector.py
index aecd74e..dd613a7 100644
--- a/distreqx/distributions/mvn_from_bijector.py
+++ b/distreqx/distributions/mvn_from_bijector.py
@@ -177,11 +177,13 @@ def _kl_divergence_mvn_mvn(
"""Divergence KL(dist1 || dist2) between multivariate normal distributions.
**Arguments:**
- dist1: A multivariate normal distribution.
- dist2: A multivariate normal distribution.
- Returns:
- Batchwise `KL(dist1 || dist2)`.
+ - `dist1`: A multivariate normal distribution.
+ - `dist2`: A multivariate normal distribution.
+
+ **Returns:**
+
+ - `KL(dist1 || dist2)`.
"""
num_dims = dist1.event_shape[-1]
diff --git a/distreqx/utils/math.py b/distreqx/utils/math.py
index a0d47ae..5b752ad 100644
--- a/distreqx/utils/math.py
+++ b/distreqx/utils/math.py
@@ -14,16 +14,16 @@ def multiply_no_nan(x: Array, y: Array) -> Array:
**Arguments:**
- - `x`: First input.
- - `y`: Second input.
+ - `x`: First input.
+ - `y`: Second input.
**Returns:**
- The product of `x` and `y`.
+ - The product of `x` and `y`.
**Raises:**
- ValueError if the shapes of `x` and `y` do not match.
+ - `ValueError`: if the shapes of `x` and `y` do not match.
"""
dtype = jnp.result_type(x, y)
return jnp.where(y == 0, jnp.zeros((), dtype=dtype), x * y)
@@ -37,12 +37,12 @@ def multiply_no_nan_jvp(
**Arguments:**
- - `primals`: A tuple containing the primal values of `x` and `y`.
- - `tangents`: A tuple containing the tangent values of `x` and `y`.
+ - `primals`: A tuple containing the primal values of `x` and `y`.
+ - `tangents`: A tuple containing the tangent values of `x` and `y`.
**Returns:**
- A tuple containing the output of the primal and tangent operations.
+ - A tuple containing the output of the primal and tangent operations.
"""
x, y = primals
x_dot, y_dot = tangents
@@ -58,12 +58,12 @@ def power_no_nan(x: Array, y: Array) -> Array:
**Arguments:**
- - `x`: First input.
- - `y`: Second input.
+ - `x`: First input.
+ - `y`: Second input.
**Returns:**
- The power `x ** y`.
+ - The power `x ** y`.
"""
dtype = jnp.result_type(x, y)
return jnp.where(y == 0, jnp.ones((), dtype=dtype), jnp.power(x, y))
@@ -77,12 +77,12 @@ def power_no_nan_jvp(
**Arguments:**
- - `primals`: A tuple containing the primal values of `x` and `y`.
- - `tangents`: A tuple containing the tangent values of `x` and `y`.
+ - `primals`: A tuple containing the primal values of `x` and `y`.
+ - `tangents`: A tuple containing the tangent values of `x` and `y`.
**Returns:**
- A tuple containing the output of the primal and tangent operations.
+ - A tuple containing the output of the primal and tangent operations.
"""
x, y = primals
x_dot, y_dot = tangents
@@ -96,12 +96,12 @@ def mul_exp(x: Array, logp: Array) -> Array:
**Arguments:**
- - `x`: An array.
- - `logp`: An array representing logarithms.
+ - `x`: An array.
+ - `logp`: An array representing logarithms.
**Returns:**
- The result of `x * exp(logp)`.
+ - The result of `x * exp(logp)`.
"""
p = jnp.exp(logp)
x = jnp.where(p == 0, 0.0, x)
@@ -115,12 +115,12 @@ def normalize(
**Arguments:**
- - `probs`: Probability values.
- - `logits`: Logit values.
+ - `probs`: Probability values.
+ - `logits`: Logit values.
**Returns:**
- Normalized probabilities or logits.
+ - Normalized probabilities or logits.
"""
if logits is None:
if probs is None:
@@ -139,12 +139,12 @@ def sum_last(x: Array, ndims: int) -> Array:
**Arguments:**
- - `x`: An array.
- - `ndims`: The number of last dimensions to sum.
+ - `x`: An array.
+ - `ndims`: The number of last dimensions to sum.
**Returns:**
- The sum of the last `ndims` dimensions of `x`.
+ - The sum of the last `ndims` dimensions of `x`.
"""
axes_to_sum = tuple(range(-ndims, 0))
return jnp.sum(x, axis=axes_to_sum)
@@ -155,12 +155,12 @@ def log_expbig_minus_expsmall(big: Array, small: Array) -> Array:
**Arguments:**
- - `big`: First input.
- - `small`: Second input. It must be `small <= big`.
+ - `big`: First input.
+ - `small`: Second input. It must be `small <= big`.
**Returns:**
- The resulting `log(exp(big) - exp(small))`.
+ - The resulting `log(exp(big) - exp(small))`.
"""
return big + jnp.log1p(-jnp.exp(small - big))
@@ -170,12 +170,12 @@ def log_beta(a: Array, b: Array) -> Array:
**Arguments:**
- - `a`: First input. It must be positive.
- - `b`: Second input. It must be positive.
+ - `a`: First input. It must be positive.
+ - `b`: Second input. It must be positive.
**Returns:**
- The value `log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b)`,
+ - The value `log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b)`,
where `Gamma` is the Gamma function, obtained through stable
computation of `log Gamma`.
"""
@@ -187,11 +187,11 @@ def log_beta_multivariate(a: Array) -> Array:
**Arguments:**
- - `a`: An array of length `K` containing positive values.
+ - `a`: An array of length `K` containing positive values.
**Returns:**
- The value
+ - The value
`log B(a) = sum_{k=1}^{K} log Gamma(a_k) - log Gamma(sum_{k=1}^{K} a_k)`,
where `Gamma` is the Gamma function, obtained through stable
computation of `log Gamma`.
diff --git a/docs/api/bijectors/_bijector.md b/docs/api/bijectors/_bijector.md
index 382875a..eb34a7b 100644
--- a/docs/api/bijectors/_bijector.md
+++ b/docs/api/bijectors/_bijector.md
@@ -1,7 +1,14 @@
-# Base Bijector Class
+# Base Bijector
::: distreqx.bijectors._bijector.AbstractBijector
selection:
members:
- __init__
+ - forward
+ - inverse
+ - forward_log_det_jacobian
+ - inverse_log_det_jacobian
+ - forward_and_log_det
+ - inverse_and_log_det
+ - same_as
---
\ No newline at end of file
diff --git a/docs/api/bijectors/_linear.md b/docs/api/bijectors/_linear.md
index efd722b..9a42931 100644
--- a/docs/api/bijectors/_linear.md
+++ b/docs/api/bijectors/_linear.md
@@ -1,4 +1,4 @@
-# Base Linear Bijector Class
+# Base Linear Bijector
::: distreqx.bijectors._linear.AbstractLinearBijector
selection:
diff --git a/docs/api/bijectors/tanh.md b/docs/api/bijectors/tanh.md
index 0e0c6b9..b24c8b7 100644
--- a/docs/api/bijectors/tanh.md
+++ b/docs/api/bijectors/tanh.md
@@ -2,5 +2,6 @@
::: distreqx.bijectors.tanh.Tanh
selection:
- members: None
+ members:
+ - __init__
---
\ No newline at end of file
diff --git a/docs/api/distributions/_distribution.md b/docs/api/distributions/_distribution.md
index 0b6f169..06c5dc9 100644
--- a/docs/api/distributions/_distribution.md
+++ b/docs/api/distributions/_distribution.md
@@ -1,4 +1,4 @@
-# Base Distribution Class
+# Base Distribution
::: distreqx.distributions._distribution.AbstractDistribution
selection:
@@ -9,5 +9,4 @@
- cdf
- survival_function
- log_survival_function
- - stddev
---
\ No newline at end of file
diff --git a/docs/api/distributions/_transformed.md b/docs/api/distributions/_transformed.md
index 6d9887f..36ac378 100644
--- a/docs/api/distributions/_transformed.md
+++ b/docs/api/distributions/_transformed.md
@@ -6,4 +6,5 @@
- __init__
- sample_and_log_prob
- entropy
+ - kl_divergence
---
\ No newline at end of file
diff --git a/docs/api/distributions/mvn_from_bijector.md b/docs/api/distributions/mvn_from_bijector.md
index 61b62f2..bb7985d 100644
--- a/docs/api/distributions/mvn_from_bijector.md
+++ b/docs/api/distributions/mvn_from_bijector.md
@@ -4,4 +4,5 @@
selection:
members:
- __init__
+ - covariance
---
\ No newline at end of file
diff --git a/docs/api/distributions/normal.md b/docs/api/distributions/normal.md
index 657c67f..02ac189 100644
--- a/docs/api/distributions/normal.md
+++ b/docs/api/distributions/normal.md
@@ -4,4 +4,5 @@
selection:
members:
- __init__
+ - entropy
---
\ No newline at end of file
diff --git a/docs/misc/faq.md b/docs/misc/faq.md
new file mode 100644
index 0000000..d511aa0
--- /dev/null
+++ b/docs/misc/faq.md
@@ -0,0 +1,13 @@
+# FAQ
+
+## Why not just use distrax?
+
+The simple answer to that question is "I tried". Distrax is the product of a lot of great work, and is especially helpful for working with TFP, but in the current era of JAX packages it lacks some important elements:
+
+- It's only semi-maintained (there have been no responses to any issues in the last >6 months)
+- It doesn't always play nice with other jax packages and can be slow (see: [#193](https://github.com/google-deepmind/distrax/issues/193), [#383](https://github.com/patrick-kidger/diffrax/issues/383), [#252](https://github.com/patrick-kidger/equinox/issues/252), [#269](https://github.com/patrick-kidger/equinox/issues/269), [#16](https://github.com/JaxGaussianProcesses/JaxUtils/issues/16), [#16170](https://github.com/google/jax/issues/16170))
+- You need TensorFlow to use it
+
+## Why use equinox?
+
+The `Jittable` class is basically an equinox module (if you squint), and while we could reimplement a custom Module class (like GPJax does), why reinvent the wheel? Equinox is actively developed and, should it become inactive, would still be possible to maintain.
diff --git a/mkdocs.yml b/mkdocs.yml
index 9babf77..b07f47f 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -90,6 +90,8 @@ plugins:
nav:
- 'index.md'
+ - Examples:
+ - Binary MNIST VAE: 'examples/01_vae.ipynb'
- API:
- Distributions:
- 'api/distributions/_distribution.md'
@@ -111,5 +113,5 @@ nav:
- 'api/bijectors/tanh.md'
- Utilities:
- 'api/utils/math.md'
- - Examples:
- - Binary MNIST VAE: 'examples/01_vae.ipynb'
\ No newline at end of file
+ - Further Details:
+ - 'misc/faq.md'
\ No newline at end of file