From be48de1a45754a1ac8822ad9752898778838a782 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Wed, 1 May 2024 10:37:15 -0600 Subject: [PATCH] documentation improvements (#3) --- .github/workflows/deploy-docs.yml | 2 +- README.md | 21 +++---- distreqx/bijectors/_bijector.py | 12 ++-- distreqx/bijectors/block.py | 1 + distreqx/bijectors/chain.py | 1 + distreqx/bijectors/shift.py | 5 +- distreqx/bijectors/tanh.py | 3 +- distreqx/distributions/_distribution.py | 16 ++--- distreqx/distributions/_transformed.py | 66 +++++++++------------ distreqx/distributions/mvn_from_bijector.py | 10 ++-- distreqx/utils/math.py | 60 +++++++++---------- docs/api/bijectors/_bijector.md | 9 ++- docs/api/bijectors/_linear.md | 2 +- docs/api/bijectors/tanh.md | 3 +- docs/api/distributions/_distribution.md | 3 +- docs/api/distributions/_transformed.md | 1 + docs/api/distributions/mvn_from_bijector.md | 1 + docs/api/distributions/normal.md | 1 + docs/misc/faq.md | 13 ++++ mkdocs.yml | 6 +- 20 files changed, 128 insertions(+), 108 deletions(-) create mode 100644 docs/misc/faq.md diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml index 90ab62f..eb898f7 100644 --- a/.github/workflows/deploy-docs.yml +++ b/.github/workflows/deploy-docs.yml @@ -26,7 +26,7 @@ jobs: python -m pip install --upgrade pip python -m pip install . python -m pip install -r docs/requirements.txt - + # https://github.com/mhausenblas/mkdocs-deploy-gh-pages/blob/master/action.sh - name: Build docs run: | mkdocs build diff --git a/README.md b/README.md index 2a25647..5e03149 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,11 @@

distreqx

Distrax + Equinox = distreqx. Easy Pytree probability distributions and bijectors.

-distreqx is a [JAX](https://github.com/google/jax)-based library providing implementations of a subset of [TensorFlow Probability (TFP)](https://github.com/tensorflow/probability), with some new features and emphasis on jax compatibility. +distreqx (pronounced "dist-rex") is a [JAX](https://github.com/google/jax)-based library providing implementations of distributions, bijectors, and tools for statistical and probabilistic machine learning with all the benefits of jax (native GPU/TPU acceleration, differentiability, vectorization, distributing workloads, XLA compilation, etc.). -This is a largely as reimplementation of [distrax](https://github.com/google-deepmind/distrax) using [equinox](https://github.com/patrick-kidger/equinox), much of the code/comments/documentation/tests are directly taken or adapted from distrax so all credit to the DeepMind team. +The origin of this repo is a reimplementation of [distrax](https://github.com/google-deepmind/distrax) (which is a subset of [TensorFlow Probability (TFP)](https://github.com/tensorflow/probability), with some new features and emphasis on jax compatibility) using [equinox](https://github.com/patrick-kidger/equinox). As a result, much of the original code/comments/documentation/tests are directly taken or adapted from distrax (original distrax copyright available at the end of the README). -Features include: +Current features include: - Probability distributions - Bijectors @@ -14,14 +14,16 @@ Features include: ## Installation ``` -pip install distreqx +git clone https://github.com/lockwo/distreqx.git +cd distreqx +pip install -e . ``` -Requires Python 3.9+, JAX 0.4.13+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.0+. +Requires Python 3.9+, JAX 0.4.11+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.0+. ## Documentation -Available at . +Available at https://lockwo.github.io/distreqx/. ## Quick example @@ -31,9 +33,9 @@ from distreqx import ## Differences with Distrax -- No support for TFP -- Broader pytree support +- No official support/interoperability with TFP - The concept of a batch dimension is dropped. If you want to operate on a batch, use `vmap` (note, this can be used in construction as well, e.g. [vmaping the construction](https://docs.kidger.site/equinox/tricks/#ensembling) of a `ScalarAffine`) +- Broader pytree enablement ## Citation @@ -46,8 +48,7 @@ If you found this library useful in academic research, please cite: ## See also: other libraries in the JAX ecosystem -[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX! -[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays. +[GPJax](https://github.com/JaxGaussianProcesses/GPJax): Gaussian processes in JAX. [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares. [Lineax](https://github.com/patrick-kidger/lineax): linear solvers. [sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent. 
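The README diff above drops batch dimensions in favour of `vmap` (including vmapped construction via the equinox ensembling trick it links to), and its quick-example section is elided in this hunk. The following is a hedged sketch of what that usage could look like: the `Normal` import path and its `(loc, scale)` constructor are assumptions not taken from this patch, while `sample_and_log_prob` and the ensembling pattern are grounded in the `_distribution.py` changes below and the linked equinox docs.

```python
# Hedged sketch only: import path and Normal(loc, scale) signature are assumptions.
import jax
import jax.numpy as jnp
import equinox as eqx
from distreqx.distributions import Normal  # assumed import path

key = jax.random.PRNGKey(0)

# A single distribution over a 3-dimensional event (no batch dimension).
dist = Normal(jnp.zeros(3), jnp.ones(3))  # assumed (loc, scale) constructor
sample, log_prob = dist.sample_and_log_prob(key)

# A "batch" of 5 distributions, built by vmapping the constructor
# (the equinox ensembling trick linked in the README).
locs = jnp.linspace(-1.0, 1.0, 5)[:, None] * jnp.ones(3)
batched = eqx.filter_vmap(lambda loc: Normal(loc, jnp.ones(3)))(locs)

# Evaluate every member of the batch at the same point.
log_probs = eqx.filter_vmap(lambda d: d.log_prob(jnp.zeros(3)))(batched)
```

Because distributions here are equinox modules (i.e. pytrees), `batched` is itself a pytree of stacked parameters and can be passed through `jax.jit` or further `vmap`s like any other pytree.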
diff --git a/distreqx/bijectors/_bijector.py b/distreqx/bijectors/_bijector.py index cfb392b..3cbdf26 100644 --- a/distreqx/bijectors/_bijector.py +++ b/distreqx/bijectors/_bijector.py @@ -63,34 +63,34 @@ def __init__( self._is_constant_log_det = is_constant_log_det def forward(self, x: PyTree) -> PyTree: - """Computes y = f(x).""" + R"""Computes $y = f(x)$.""" y, _ = self.forward_and_log_det(x) return y def inverse(self, y: PyTree) -> PyTree: - """Computes x = f^{-1}(y).""" + r"""Computes $x = f^{-1}(y)$.""" x, _ = self.inverse_and_log_det(y) return x def forward_log_det_jacobian(self, x: PyTree) -> PyTree: - """Computes log|det J(f)(x)|.""" + r"""Computes $\log|\det J(f)(x)|$.""" _, logdet = self.forward_and_log_det(x) return logdet def inverse_log_det_jacobian(self, y: PyTree) -> PyTree: - """Computes log|det J(f^{-1})(y)|.""" + r"""Computes $\log|\det J(f^{-1})(y)|$.""" _, logdet = self.inverse_and_log_det(y) return logdet @abstractmethod def forward_and_log_det(self, x: PyTree) -> Tuple[PyTree, PyTree]: - """Computes y = f(x) and log|det J(f)(x)|.""" + r"""Computes $y = f(x)$ and $\log|\det J(f)(x)|$.""" raise NotImplementedError( f"Bijector {self.name} does not implement `forward_and_log_det`." ) def inverse_and_log_det(self, y: Array) -> Tuple[PyTree, PyTree]: - """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" + r"""Computes $x = f^{-1}(y)$ and $\log|\det J(f^{-1})(y)|$.""" raise NotImplementedError( f"Bijector {self.name} does not implement `inverse_and_log_det`." ) diff --git a/distreqx/bijectors/block.py b/distreqx/bijectors/block.py index 62ddf39..bbc4749 100644 --- a/distreqx/bijectors/block.py +++ b/distreqx/bijectors/block.py @@ -38,6 +38,7 @@ def __init__(self, bijector: AbstractBijector, ndims: int): """Initializes a Block. **Arguments:** + - `bijector`: the bijector to be promoted to a block bijector. It can be a distreqx bijector or a callable to be wrapped by `Lambda`. - `ndims`: number of dimensions to promote to event dimensions. diff --git a/distreqx/bijectors/chain.py b/distreqx/bijectors/chain.py index 0ed7eea..3d2a98f 100644 --- a/distreqx/bijectors/chain.py +++ b/distreqx/bijectors/chain.py @@ -36,6 +36,7 @@ def __init__(self, bijectors: Sequence[AbstractBijector]): """Initializes a Chain bijector. **Arguments:** + - `bijectors`: a sequence of bijectors to be composed into one. Each bijector can be a distreqx bijector or a callable to be wrapped by `Lambda`. The sequence must contain at least one bijector. diff --git a/distreqx/bijectors/shift.py b/distreqx/bijectors/shift.py index 68d9533..d0d68d9 100644 --- a/distreqx/bijectors/shift.py +++ b/distreqx/bijectors/shift.py @@ -22,8 +22,9 @@ class Shift(ScalarAffine): def __init__(self, shift: Array): """Initializes a `Shift` bijector. - Args: - shift: the bijector's shift parameter. Can also be batched. + **Arguments:** + + - `shift`: the bijector's shift parameter. """ super().__init__(shift=shift) diff --git a/distreqx/bijectors/tanh.py b/distreqx/bijectors/tanh.py index 4292e70..a225380 100644 --- a/distreqx/bijectors/tanh.py +++ b/distreqx/bijectors/tanh.py @@ -29,7 +29,8 @@ class Tanh(AbstractBijector): instead of `sample` followed by `log_prob`. 
""" - def __init__(self): + def __init__(self) -> None: + """Initialize the TanH bijector.""" super().__init__() def forward_log_det_jacobian(self, x: Array) -> Array: diff --git a/distreqx/distributions/_distribution.py b/distreqx/distributions/_distribution.py index 29b36e0..9041c62 100644 --- a/distreqx/distributions/_distribution.py +++ b/distreqx/distributions/_distribution.py @@ -32,7 +32,7 @@ def sample_and_log_prob( **Returns:** - A tuple of a sample and their log probs. + - A tuple of a sample and their log probs. """ samples = self.sample(key) log_prob = self.log_prob(samples) @@ -48,7 +48,7 @@ def log_prob(self, value: PyTree[Array]) -> PyTree[Array]: **Returns:** - The log probability log P(value). + - The log probability log P(value). """ raise NotImplementedError @@ -78,7 +78,7 @@ def prob(self, value: PyTree[Array]) -> PyTree[Array]: **Returns:** - The probability P(value). + - The probability P(value). """ return jnp.exp(self.log_prob(value)) @@ -109,7 +109,7 @@ def cdf(self, value: PyTree[Array]) -> PyTree[Array]: **Returns:** - The CDF evaluated at value, i.e. P[X <= value]. + - The CDF evaluated at value, i.e. P[X <= value]. """ return jnp.exp(self.log_cdf(value)) @@ -127,7 +127,7 @@ def survival_function(self, value: PyTree[Array]) -> PyTree[Array]: **Returns:** - The survival function evaluated at `value`, i.e. P[X > value] + - The survival function evaluated at `value`, i.e. P[X > value] """ return 1.0 - self.cdf(value) @@ -145,7 +145,7 @@ def log_survival_function(self, value: PyTree[Array]) -> PyTree[Array]: **Returns:** - The log of the survival function evaluated at `value`, i.e. + - The log of the survival function evaluated at `value`, i.e. log P[X > value] """ return jnp.log1p(-self.cdf(value)) @@ -188,7 +188,7 @@ def kl_divergence(self, other_dist, **kwargs) -> PyTree[Array]: **Returns:** - The KL divergence `KL(self || other_dist)`. + - The KL divergence `KL(self || other_dist)`. """ raise NotImplementedError( f"Distribution `{self.name}` does not implement `kl_divergence`." @@ -204,6 +204,6 @@ def cross_entropy(self, other_dist, **kwargs) -> Array: **Returns:** - The cross entropy `H(self || other_dist)`. + - The cross entropy `H(self || other_dist)`. """ return self.kl_divergence(other_dist, **kwargs) + self.entropy() diff --git a/distreqx/distributions/_transformed.py b/distreqx/distributions/_transformed.py index e80d98c..a63bc14 100644 --- a/distreqx/distributions/_transformed.py +++ b/distreqx/distributions/_transformed.py @@ -111,7 +111,7 @@ def sample_and_log_prob(self, key: PRNGKeyArray) -> Tuple[Array, Array]: **Returns:** - A tuple of a sample and its log probs. + - A tuple of a sample and its log probs. """ x, lp_x = self.distribution.sample_and_log_prob(key) y, fldj = self.bijector.forward_and_log_det(x) @@ -153,7 +153,7 @@ def entropy(self, input_hint: Optional[Array] = None) -> Array: **Returns:** - The entropy of the distribution. + - The entropy of the distribution. **Raises:** @@ -175,16 +175,37 @@ def entropy(self, input_hint: Optional[Array] = None) -> Array: ) def kl_divergence(self, other_dist, **kwargs) -> Array: - """Calculates the KL divergence to another distribution. + """Obtains the KL divergence between two Transformed distributions. + + This computes the KL divergence between two Transformed distributions with the + same bijector. If the two Transformed distributions do not have the same + bijector, an error is raised. 
To determine if the bijectors are equal, this + method proceeds as follows: + - If both bijectors are the same instance of a distreqx bijector, then they are + declared equal. + - If not the same instance, we check if they are equal according to their + `same_as` predicate. + - Otherwise, the string representation of the Jaxpr of the `forward` method + of each bijector is compared. If both string representations are equal, the + bijectors are declared equal. + - Otherwise, the bijectors cannot be guaranteed to be equal and an error is + raised. **Arguments:** - - `other_dist`: A compatible disteqx distribution. - - `kwargs`: Additional kwargs, can accept an `input_hint`. + - `other_dist`: A Transformed distribution. + - `input_hint`: keyword argument, an example sample from the base distribution, + used to trace the `forward` method. If not specified, it is computed using + a zero array of the shape and dtype of a sample from the base distribution. **Returns:** - The KL divergence `KL(self || other_dist)`. + - `KL(dist1 || dist2)`. + + **Raises:** + + - `NotImplementedError`: If bijectors are not known to be equal. + - `ValueError`: If the base distributions do not have the same `event_shape`. """ return _kl_divergence_transformed_transformed(self, other_dist, **kwargs) @@ -196,39 +217,6 @@ def _kl_divergence_transformed_transformed( input_hint: Optional[Array] = None, **unused_kwargs, ) -> Array: - """Obtains the KL divergence between two Transformed distributions. - - This computes the KL divergence between two Transformed distributions with the - same bijector. If the two Transformed distributions do not have the same - bijector, an error is raised. To determine if the bijectors are equal, this - method proceeds as follows: - - If both bijectors are the same instance of a distreqx bijector, then they are - declared equal. - - If not the same instance, we check if they are equal according to their - `same_as` predicate. - - Otherwise, the string representation of the Jaxpr of the `forward` method - of each bijector is compared. If both string representations are equal, the - bijectors are declared equal. - - Otherwise, the bijectors cannot be guaranteed to be equal and an error is - raised. - - **Arguments:** - - - `dist1`: A Transformed distribution. - - `dist2`: A Transformed distribution. - - `input_hint`: an example sample from the base distribution, used to trace the - `forward` method. If not specified, it is computed using a zero array of - the shape and dtype of a sample from the base distribution. - - **Returns:** - - `KL(dist1 || dist2)`. - - **Raises:** - - - `NotImplementedError`: If bijectors are not known to be equal. - - `ValueError`: If the base distributions do not have the same `event_shape`. - """ if dist1.distribution.event_shape != dist2.distribution.event_shape: raise ValueError( f"The two base distributions do not have the same event shape: " diff --git a/distreqx/distributions/mvn_from_bijector.py b/distreqx/distributions/mvn_from_bijector.py index aecd74e..dd613a7 100644 --- a/distreqx/distributions/mvn_from_bijector.py +++ b/distreqx/distributions/mvn_from_bijector.py @@ -177,11 +177,13 @@ def _kl_divergence_mvn_mvn( """Divergence KL(dist1 || dist2) between multivariate normal distributions. **Arguments:** - dist1: A multivariate normal distribution. - dist2: A multivariate normal distribution. - Returns: - Batchwise `KL(dist1 || dist2)`. + - `dist1`: A multivariate normal distribution. + - `dist2`: A multivariate normal distribution. 
+ + **Returns:** + + - `KL(dist1 || dist2)`. """ num_dims = dist1.event_shape[-1] diff --git a/distreqx/utils/math.py b/distreqx/utils/math.py index a0d47ae..5b752ad 100644 --- a/distreqx/utils/math.py +++ b/distreqx/utils/math.py @@ -14,16 +14,16 @@ def multiply_no_nan(x: Array, y: Array) -> Array: **Arguments:** - - `x`: First input. - - `y`: Second input. + - `x`: First input. + - `y`: Second input. **Returns:** - The product of `x` and `y`. + - The product of `x` and `y`. **Raises:** - ValueError if the shapes of `x` and `y` do not match. + - ValueError if the shapes of `x` and `y` do not match. """ dtype = jnp.result_type(x, y) return jnp.where(y == 0, jnp.zeros((), dtype=dtype), x * y) @@ -37,12 +37,12 @@ def multiply_no_nan_jvp( **Arguments:** - - `primals`: A tuple containing the primal values of `x` and `y`. - - `tangents`: A tuple containing the tangent values of `x` and `y`. + - `primals`: A tuple containing the primal values of `x` and `y`. + - `tangents`: A tuple containing the tangent values of `x` and `y`. **Returns:** - A tuple containing the output of the primal and tangent operations. + - A tuple containing the output of the primal and tangent operations. """ x, y = primals x_dot, y_dot = tangents @@ -58,12 +58,12 @@ def power_no_nan(x: Array, y: Array) -> Array: **Arguments:** - - `x`: First input. - - `y`: Second input. + - `x`: First input. + - `y`: Second input. **Returns:** - The power `x ** y`. + - The power `x ** y`. """ dtype = jnp.result_type(x, y) return jnp.where(y == 0, jnp.ones((), dtype=dtype), jnp.power(x, y)) @@ -77,12 +77,12 @@ def power_no_nan_jvp( **Arguments:** - - `primals`: A tuple containing the primal values of `x` and `y`. - - `tangents`: A tuple containing the tangent values of `x` and `y`. + - `primals`: A tuple containing the primal values of `x` and `y`. + - `tangents`: A tuple containing the tangent values of `x` and `y`. **Returns:** - A tuple containing the output of the primal and tangent operations. + - A tuple containing the output of the primal and tangent operations. """ x, y = primals x_dot, y_dot = tangents @@ -96,12 +96,12 @@ def mul_exp(x: Array, logp: Array) -> Array: **Arguments:** - - `x`: An array. - - `logp`: An array representing logarithms. + - `x`: An array. + - `logp`: An array representing logarithms. **Returns:** - The result of `x * exp(logp)`. + - The result of `x * exp(logp)`. """ p = jnp.exp(logp) x = jnp.where(p == 0, 0.0, x) @@ -115,12 +115,12 @@ def normalize( **Arguments:** - - `probs`: Probability values. - - `logits`: Logit values. + - `probs`: Probability values. + - `logits`: Logit values. **Returns:** - Normalized probabilities or logits. + - Normalized probabilities or logits. """ if logits is None: if probs is None: @@ -139,12 +139,12 @@ def sum_last(x: Array, ndims: int) -> Array: **Arguments:** - - `x`: An array. - - `ndims`: The number of last dimensions to sum. + - `x`: An array. + - `ndims`: The number of last dimensions to sum. **Returns:** - The sum of the last `ndims` dimensions of `x`. + - The sum of the last `ndims` dimensions of `x`. """ axes_to_sum = tuple(range(-ndims, 0)) return jnp.sum(x, axis=axes_to_sum) @@ -155,12 +155,12 @@ def log_expbig_minus_expsmall(big: Array, small: Array) -> Array: **Arguments:** - - `big`: First input. - - `small`: Second input. It must be `small <= big`. + - `big`: First input. + - `small`: Second input. It must be `small <= big`. **Returns:** - The resulting `log(exp(big) - exp(small))`. + - The resulting `log(exp(big) - exp(small))`. 
""" return big + jnp.log1p(-jnp.exp(small - big)) @@ -170,12 +170,12 @@ def log_beta(a: Array, b: Array) -> Array: **Arguments:** - - `a`: First input. It must be positive. - - `b`: Second input. It must be positive. + - `a`: First input. It must be positive. + - `b`: Second input. It must be positive. **Returns:** - The value `log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b)`, + - The value `log B(a, b) = log Gamma(a) + log Gamma(b) - log Gamma(a + b)`, where `Gamma` is the Gamma function, obtained through stable computation of `log Gamma`. """ @@ -187,11 +187,11 @@ def log_beta_multivariate(a: Array) -> Array: **Arguments:** - - `a`: An array of length `K` containing positive values. + - `a`: An array of length `K` containing positive values. **Returns:** - The value + - The value `log B(a) = sum_{k=1}^{K} log Gamma(a_k) - log Gamma(sum_{k=1}^{K} a_k)`, where `Gamma` is the Gamma function, obtained through stable computation of `log Gamma`. diff --git a/docs/api/bijectors/_bijector.md b/docs/api/bijectors/_bijector.md index 382875a..eb34a7b 100644 --- a/docs/api/bijectors/_bijector.md +++ b/docs/api/bijectors/_bijector.md @@ -1,7 +1,14 @@ -# Base Bijector Class +# Base Bijector ::: distreqx.bijectors._bijector.AbstractBijector selection: members: - __init__ + - forward + - inverse + - forward_log_det_jacobian + - inverse_log_det_jacobian + - forward_and_log_det + - inverse_and_log_det + - same_as --- \ No newline at end of file diff --git a/docs/api/bijectors/_linear.md b/docs/api/bijectors/_linear.md index efd722b..9a42931 100644 --- a/docs/api/bijectors/_linear.md +++ b/docs/api/bijectors/_linear.md @@ -1,4 +1,4 @@ -# Base Linear Bijector Class +# Base Linear Bijector ::: distreqx.bijectors._linear.AbstractLinearBijector selection: diff --git a/docs/api/bijectors/tanh.md b/docs/api/bijectors/tanh.md index 0e0c6b9..b24c8b7 100644 --- a/docs/api/bijectors/tanh.md +++ b/docs/api/bijectors/tanh.md @@ -2,5 +2,6 @@ ::: distreqx.bijectors.tanh.Tanh selection: - members: None + members: + - __init__ --- \ No newline at end of file diff --git a/docs/api/distributions/_distribution.md b/docs/api/distributions/_distribution.md index 0b6f169..06c5dc9 100644 --- a/docs/api/distributions/_distribution.md +++ b/docs/api/distributions/_distribution.md @@ -1,4 +1,4 @@ -# Base Distribution Class +# Base Distribution ::: distreqx.distributions._distribution.AbstractDistribution selection: @@ -9,5 +9,4 @@ - cdf - survival_function - log_survival_function - - stddev --- \ No newline at end of file diff --git a/docs/api/distributions/_transformed.md b/docs/api/distributions/_transformed.md index 6d9887f..36ac378 100644 --- a/docs/api/distributions/_transformed.md +++ b/docs/api/distributions/_transformed.md @@ -6,4 +6,5 @@ - __init__ - sample_and_log_prob - entropy + - kl_divergence --- \ No newline at end of file diff --git a/docs/api/distributions/mvn_from_bijector.md b/docs/api/distributions/mvn_from_bijector.md index 61b62f2..bb7985d 100644 --- a/docs/api/distributions/mvn_from_bijector.md +++ b/docs/api/distributions/mvn_from_bijector.md @@ -4,4 +4,5 @@ selection: members: - __init__ + - covariance --- \ No newline at end of file diff --git a/docs/api/distributions/normal.md b/docs/api/distributions/normal.md index 657c67f..02ac189 100644 --- a/docs/api/distributions/normal.md +++ b/docs/api/distributions/normal.md @@ -4,4 +4,5 @@ selection: members: - __init__ + - entropy --- \ No newline at end of file diff --git a/docs/misc/faq.md b/docs/misc/faq.md new file mode 100644 index 
0000000..d511aa0 --- /dev/null +++ b/docs/misc/faq.md @@ -0,0 +1,13 @@ +# FAQ + +## Why not just use distrax? + +The simple answer to that question is "I tried". Distrax is the product of a lot of great work, especially helpful for working with TFP, but in the current era of jax packages it lacks important elements: + +- It's only semi-maintained (there have been no responses to any issues in the last >6 months) +- It doesn't always play nice with other jax packages and can be slow (see: [#193](https://github.com/google-deepmind/distrax/issues/193), [#383](https://github.com/patrick-kidger/diffrax/issues/383), [#252](https://github.com/patrick-kidger/equinox/issues/252), [#269](https://github.com/patrick-kidger/equinox/issues/269), [#16](https://github.com/JaxGaussianProcesses/JaxUtils/issues/16), [#16170](https://github.com/google/jax/issues/16170)) +- You need TensorFlow to use it + +## Why use equinox? + +The `Jittable` class is basically an equinox module (if you squint), and while we could reimplement a custom Module class (like GPJax does), why reinvent the wheel? Equinox is actively being developed, and should it become inactive, it is still possible to maintain. diff --git a/mkdocs.yml b/mkdocs.yml index 9babf77..b07f47f 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -90,6 +90,8 @@ plugins: nav: - 'index.md' + - Examples: + - Binary MNIST VAE: 'examples/01_vae.ipynb' - API: - Distributions: - 'api/distributions/_distribution.md' @@ -111,5 +113,5 @@ nav: - 'api/bijectors/tanh.md' - Utilities: - 'api/utils/math.md' - - Examples: - - Binary MNIST VAE: 'examples/01_vae.ipynb' \ No newline at end of file + - Further Details: + - 'misc/faq.md' \ No newline at end of file
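The `kl_divergence` docstring moved into `_transformed.py` earlier in this patch describes a three-step bijector-equality check: same instance, then the `same_as` predicate, then comparison of the string form of each bijector's traced `forward` Jaxpr. A rough sketch of that logic is below; the method names `same_as` and `forward` come from the patch, but the tracing details (using `jax.make_jaxpr` on an `input_hint` array) are assumptions about the implementation, which is not shown here.

```python
# Illustrative sketch of the equality check described in the docstring,
# not the library's actual implementation.
import jax
import jax.numpy as jnp


def bijectors_probably_equal(bij1, bij2, input_hint: jnp.ndarray) -> bool:
    # Same instance of a distreqx bijector -> declared equal.
    if bij1 is bij2:
        return True
    # Bijector-specific equality predicate.
    if bij1.same_as(bij2) or bij2.same_as(bij1):
        return True
    # Fallback: compare the string form of each `forward` Jaxpr, traced on the
    # same example input; equal strings -> declared equal.
    jaxpr1 = str(jax.make_jaxpr(bij1.forward)(input_hint))
    jaxpr2 = str(jax.make_jaxpr(bij2.forward)(input_hint))
    return jaxpr1 == jaxpr2
```

Per the docstring, when this check cannot establish equality the KL computation raises `NotImplementedError` rather than returning a potentially incorrect divergence.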