documentation improvements (#3)
lockwo authored May 1, 2024
1 parent d9bc792 commit be48de1
Showing 20 changed files with 128 additions and 108 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/deploy-docs.yml
@@ -26,7 +26,7 @@ jobs:
python -m pip install --upgrade pip
python -m pip install .
python -m pip install -r docs/requirements.txt
# https://github.com/mhausenblas/mkdocs-deploy-gh-pages/blob/master/action.sh
- name: Build docs
run: |
mkdocs build
21 changes: 11 additions & 10 deletions README.md
@@ -1,11 +1,11 @@
<h1 align='center'>distreqx</h1>
<h2 align='center'>Distrax + Equinox = distreqx. Easy Pytree probability distributions and bijectors.</h2>

distreqx is a [JAX](https://github.com/google/jax)-based library providing implementations of a subset of [TensorFlow Probability (TFP)](https://github.com/tensorflow/probability), with some new features and emphasis on jax compatibility.
distreqx (pronounced "dist-rex") is a [JAX](https://github.com/google/jax)-based library providing implementations of distributions, bijectors, and tools for statistical and probabilistic machine learning, with all the benefits of JAX (native GPU/TPU acceleration, differentiability, vectorization, distributed workloads, XLA compilation, etc.).

This is a largely as reimplementation of [distrax](https://github.com/google-deepmind/distrax) using [equinox](https://github.com/patrick-kidger/equinox), much of the code/comments/documentation/tests are directly taken or adapted from distrax so all credit to the DeepMind team.
The origin of this repo is a reimplementation of [distrax](https://github.com/google-deepmind/distrax) (which is a subset of [TensorFlow Probability (TFP)](https://github.com/tensorflow/probability), with some new features and an emphasis on JAX compatibility) using [equinox](https://github.com/patrick-kidger/equinox). As a result, much of the original code/comments/documentation/tests are directly taken or adapted from distrax (the original distrax copyright is available at the end of the README).

Features include:
Current features include:

- Probability distributions
- Bijectors
@@ -14,14 +14,16 @@ Features include:
## Installation

```
pip install distreqx
git clone https://github.com/lockwo/distreqx.git
cd distreqx
pip install -e .
```

Requires Python 3.9+, JAX 0.4.13+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.0+.
Requires Python 3.9+, JAX 0.4.11+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.11.0+.

## Documentation

Available at .
Available at https://lockwo.github.io/distreqx/.

## Quick example

@@ -31,9 +33,9 @@ from distreqx import

## Differences with Distrax

- No support for TFP
- Broader pytree support
- No official support/interoperability with TFP
- The concept of a batch dimension is dropped. If you want to operate on a batch, use `vmap` (note, this can be used in construction as well, e.g. [vmapping the construction](https://docs.kidger.site/equinox/tricks/#ensembling) of a `ScalarAffine`; see the sketch below)
- Broader pytree enablement
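
A minimal sketch of the vmapped-construction pattern mentioned above, following the equinox "ensembling" trick. It assumes `ScalarAffine` is exported from `distreqx.bijectors` and accepts `shift` and `scale` arguments (distrax-style); treat these names and signatures as assumptions rather than the library's confirmed API.

```python
# Hedged sketch: vmapping the *construction* of ScalarAffine bijectors.
# Import path and constructor arguments are assumptions.
import equinox as eqx
import jax.numpy as jnp

from distreqx.bijectors import ScalarAffine  # assumed export

shifts = jnp.arange(4.0)        # one bijector per element
scales = 2.0 * jnp.ones(4)

@eqx.filter_vmap
def make_affine(shift, scale):
    return ScalarAffine(shift=shift, scale=scale)

batched = make_affine(shifts, scales)   # a single pytree holding 4 bijectors

# Applying the batched bijector also goes through vmap:
@eqx.filter_vmap
def apply_forward(bijector, x):
    return bijector.forward(x)

ys = apply_forward(batched, jnp.ones(4))   # shape (4,)
```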

## Citation

@@ -46,8 +48,7 @@ If you found this library useful in academic research, please cite:

## See also: other libraries in the JAX ecosystem

[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.
[GPJax](https://github.com/JaxGaussianProcesses/GPJax): Gaussian processes in JAX.
[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
[Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
12 changes: 6 additions & 6 deletions distreqx/bijectors/_bijector.py
@@ -63,34 +63,34 @@ def __init__(
self._is_constant_log_det = is_constant_log_det

def forward(self, x: PyTree) -> PyTree:
"""Computes y = f(x)."""
R"""Computes $y = f(x)$."""
y, _ = self.forward_and_log_det(x)
return y

def inverse(self, y: PyTree) -> PyTree:
"""Computes x = f^{-1}(y)."""
r"""Computes $x = f^{-1}(y)$."""
x, _ = self.inverse_and_log_det(y)
return x

def forward_log_det_jacobian(self, x: PyTree) -> PyTree:
"""Computes log|det J(f)(x)|."""
r"""Computes $\log|\det J(f)(x)|$."""
_, logdet = self.forward_and_log_det(x)
return logdet

def inverse_log_det_jacobian(self, y: PyTree) -> PyTree:
"""Computes log|det J(f^{-1})(y)|."""
r"""Computes $\log|\det J(f^{-1})(y)|$."""
_, logdet = self.inverse_and_log_det(y)
return logdet

@abstractmethod
def forward_and_log_det(self, x: PyTree) -> Tuple[PyTree, PyTree]:
"""Computes y = f(x) and log|det J(f)(x)|."""
r"""Computes $y = f(x)$ and $\log|\det J(f)(x)|$."""
raise NotImplementedError(
f"Bijector {self.name} does not implement `forward_and_log_det`."
)

def inverse_and_log_det(self, y: Array) -> Tuple[PyTree, PyTree]:
"""Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
r"""Computes $x = f^{-1}(y)$ and $\log|\det J(f^{-1})(y)|$."""
raise NotImplementedError(
f"Bijector {self.name} does not implement `inverse_and_log_det`."
)
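
To make these method relationships concrete, here is a hedged sketch of a tiny concrete bijector. It assumes `AbstractBijector` is exported from `distreqx.bijectors`, that its constructor takes no required arguments, and that `forward_and_log_det` is the only abstract method, as the excerpt above suggests; this is an illustration, not distreqx's own code.

```python
# Hedged sketch of a concrete bijector built on the abstract interface above.
import jax.numpy as jnp

from distreqx.bijectors import AbstractBijector  # assumed export


class Exp(AbstractBijector):
    """y = exp(x), so log|det J(f)(x)| = x elementwise."""

    def __init__(self):
        super().__init__()

    def forward_and_log_det(self, x):
        # y = exp(x); per-element log-det of the Jacobian is x.
        return jnp.exp(x), x

    def inverse_and_log_det(self, y):
        # x = log(y); per-element log-det of the inverse Jacobian is -log(y).
        return jnp.log(y), -jnp.log(y)


# forward/inverse/log-det helpers then come for free from the base class:
exp_bij = Exp()
y = exp_bij.forward(jnp.array([0.0, 1.0]))      # exp(x)
ldj = exp_bij.inverse_log_det_jacobian(y)       # -log(y)
```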
1 change: 1 addition & 0 deletions distreqx/bijectors/block.py
@@ -38,6 +38,7 @@ def __init__(self, bijector: AbstractBijector, ndims: int):
"""Initializes a Block.
**Arguments:**
- `bijector`: the bijector to be promoted to a block bijector. It can be a
distreqx bijector or a callable to be wrapped by `Lambda`.
- `ndims`: number of dimensions to promote to event dimensions.
1 change: 1 addition & 0 deletions distreqx/bijectors/chain.py
@@ -36,6 +36,7 @@ def __init__(self, bijectors: Sequence[AbstractBijector]):
"""Initializes a Chain bijector.
**Arguments:**
- `bijectors`: a sequence of bijectors to be composed into one. Each bijector
can be a distreqx bijector or a callable to be wrapped
by `Lambda`. The sequence must contain at least one bijector.
5 changes: 3 additions & 2 deletions distreqx/bijectors/shift.py
@@ -22,8 +22,9 @@ class Shift(ScalarAffine):
def __init__(self, shift: Array):
"""Initializes a `Shift` bijector.
Args:
shift: the bijector's shift parameter. Can also be batched.
**Arguments:**
- `shift`: the bijector's shift parameter.
"""
super().__init__(shift=shift)

3 changes: 2 additions & 1 deletion distreqx/bijectors/tanh.py
@@ -29,7 +29,8 @@ class Tanh(AbstractBijector):
instead of `sample` followed by `log_prob`.
"""

def __init__(self):
def __init__(self) -> None:
"""Initialize the TanH bijector."""
super().__init__()

def forward_log_det_jacobian(self, x: Array) -> Array:
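
A hedged illustration of the note above about preferring `sample_and_log_prob` over `sample` followed by `log_prob`: a tanh-transformed normal saturates near ±1, where the two-step route can lose precision. The `Normal`/`Transformed` import paths and constructor signatures below are assumptions modeled on distrax.

```python
# Hedged sketch; import paths and constructors are assumptions based on distrax.
import jax
import jax.numpy as jnp

from distreqx.bijectors import Tanh
from distreqx.distributions import Normal, Transformed  # assumed exports

key = jax.random.PRNGKey(0)
base = Normal(loc=jnp.zeros(3), scale=10.0 * jnp.ones(3))  # wide base -> tanh saturates
dist = Transformed(base, Tanh())

# Preferred: one pass, numerically stable even when |y| is close to 1.
y, log_prob = dist.sample_and_log_prob(key)

# Works, but inverting tanh near saturation can lose precision:
y2 = dist.sample(key)
log_prob2 = dist.log_prob(y2)
```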
16 changes: 8 additions & 8 deletions distreqx/distributions/_distribution.py
@@ -32,7 +32,7 @@ def sample_and_log_prob(
**Returns:**
A tuple of a sample and their log probs.
- A tuple of a sample and their log probs.
"""
samples = self.sample(key)
log_prob = self.log_prob(samples)
Expand All @@ -48,7 +48,7 @@ def log_prob(self, value: PyTree[Array]) -> PyTree[Array]:
**Returns:**
The log probability log P(value).
- The log probability log P(value).
"""
raise NotImplementedError

@@ -78,7 +78,7 @@ def prob(self, value: PyTree[Array]) -> PyTree[Array]:
**Returns:**
The probability P(value).
- The probability P(value).
"""
return jnp.exp(self.log_prob(value))

@@ -109,7 +109,7 @@ def cdf(self, value: PyTree[Array]) -> PyTree[Array]:
**Returns:**
The CDF evaluated at value, i.e. P[X <= value].
- The CDF evaluated at value, i.e. P[X <= value].
"""
return jnp.exp(self.log_cdf(value))

@@ -127,7 +127,7 @@ def survival_function(self, value: PyTree[Array]) -> PyTree[Array]:
**Returns:**
The survival function evaluated at `value`, i.e. P[X > value]
- The survival function evaluated at `value`, i.e. P[X > value]
"""
return 1.0 - self.cdf(value)

@@ -145,7 +145,7 @@ def log_survival_function(self, value: PyTree[Array]) -> PyTree[Array]:
**Returns:**
The log of the survival function evaluated at `value`, i.e.
- The log of the survival function evaluated at `value`, i.e.
log P[X > value]
"""
return jnp.log1p(-self.cdf(value))
@@ -188,7 +188,7 @@ def kl_divergence(self, other_dist, **kwargs) -> PyTree[Array]:
**Returns:**
The KL divergence `KL(self || other_dist)`.
- The KL divergence `KL(self || other_dist)`.
"""
raise NotImplementedError(
f"Distribution `{self.name}` does not implement `kl_divergence`."
@@ -204,6 +204,6 @@ def cross_entropy(self, other_dist, **kwargs) -> Array:
**Returns:**
The cross entropy `H(self || other_dist)`.
- The cross entropy `H(self || other_dist)`.
"""
return self.kl_divergence(other_dist, **kwargs) + self.entropy()
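
The default implementations above are tied together by a few identities: `prob = exp(log_prob)`, `cdf = exp(log_cdf)`, `survival_function(v) = 1 - cdf(v)`, `log_survival_function(v) = log1p(-cdf(v))`, and `cross_entropy = kl_divergence + entropy`. A hedged usage sketch, assuming a `Normal` distribution is exported from `distreqx.distributions` with `loc`/`scale` arguments:

```python
# Hedged sketch; the Normal import and its constructor are assumptions.
import jax
import jax.numpy as jnp

from distreqx.distributions import Normal  # assumed export

key = jax.random.PRNGKey(0)
dist = Normal(loc=jnp.zeros(3), scale=jnp.ones(3))

x, log_p = dist.sample_and_log_prob(key)        # sample plus its log-density
assert jnp.allclose(dist.prob(x), jnp.exp(log_p))

v = jnp.zeros(3)
assert jnp.allclose(dist.survival_function(v), 1.0 - dist.cdf(v))
assert jnp.allclose(dist.log_survival_function(v), jnp.log1p(-dist.cdf(v)))
```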
66 changes: 27 additions & 39 deletions distreqx/distributions/_transformed.py
@@ -111,7 +111,7 @@ def sample_and_log_prob(self, key: PRNGKeyArray) -> Tuple[Array, Array]:
**Returns:**
A tuple of a sample and its log probs.
- A tuple of a sample and its log probs.
"""
x, lp_x = self.distribution.sample_and_log_prob(key)
y, fldj = self.bijector.forward_and_log_det(x)
@@ -153,7 +153,7 @@ def entropy(self, input_hint: Optional[Array] = None) -> Array:
**Returns:**
The entropy of the distribution.
- The entropy of the distribution.
**Raises:**
Expand All @@ -175,16 +175,37 @@ def entropy(self, input_hint: Optional[Array] = None) -> Array:
)

def kl_divergence(self, other_dist, **kwargs) -> Array:
"""Calculates the KL divergence to another distribution.
"""Obtains the KL divergence between two Transformed distributions.
This computes the KL divergence between two Transformed distributions with the
same bijector. If the two Transformed distributions do not have the same
bijector, an error is raised. To determine if the bijectors are equal, this
method proceeds as follows:
- If both bijectors are the same instance of a distreqx bijector, then they are
declared equal.
- If not the same instance, we check if they are equal according to their
`same_as` predicate.
- Otherwise, the string representation of the Jaxpr of the `forward` method
of each bijector is compared. If both string representations are equal, the
bijectors are declared equal.
- Otherwise, the bijectors cannot be guaranteed to be equal and an error is
raised.
**Arguments:**
- `other_dist`: A compatible disteqx distribution.
- `kwargs`: Additional kwargs, can accept an `input_hint`.
- `other_dist`: A Transformed distribution.
- `input_hint`: keyword argument, an example sample from the base distribution,
used to trace the `forward` method. If not specified, it is computed using
a zero array of the shape and dtype of a sample from the base distribution.
**Returns:**
The KL divergence `KL(self || other_dist)`.
- `KL(dist1 || dist2)`.
**Raises:**
- `NotImplementedError`: If bijectors are not known to be equal.
- `ValueError`: If the base distributions do not have the same `event_shape`.
"""
return _kl_divergence_transformed_transformed(self, other_dist, **kwargs)
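
The Jaxpr-based equality check described in the docstring can be pictured with a few lines of plain JAX. This is only a sketch of the idea, not distreqx's implementation; the `Tanh` import is assumed.

```python
# Sketch of the bijector-equality idea: trace `forward` with an input hint and
# compare the string form of the resulting Jaxprs. Not distreqx's actual code.
import jax
import jax.numpy as jnp

from distreqx.bijectors import Tanh  # assumed export


def jaxprs_match(bij1, bij2, input_hint):
    jaxpr1 = jax.make_jaxpr(bij1.forward)(input_hint)
    jaxpr2 = jax.make_jaxpr(bij2.forward)(input_hint)
    return str(jaxpr1) == str(jaxpr2)


hint = jnp.zeros(3)  # plays the role of `input_hint` above
print(jaxprs_match(Tanh(), Tanh(), hint))  # True: identical traced computations
```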

@@ -196,39 +217,6 @@ def _kl_divergence_transformed_transformed(
input_hint: Optional[Array] = None,
**unused_kwargs,
) -> Array:
"""Obtains the KL divergence between two Transformed distributions.
This computes the KL divergence between two Transformed distributions with the
same bijector. If the two Transformed distributions do not have the same
bijector, an error is raised. To determine if the bijectors are equal, this
method proceeds as follows:
- If both bijectors are the same instance of a distreqx bijector, then they are
declared equal.
- If not the same instance, we check if they are equal according to their
`same_as` predicate.
- Otherwise, the string representation of the Jaxpr of the `forward` method
of each bijector is compared. If both string representations are equal, the
bijectors are declared equal.
- Otherwise, the bijectors cannot be guaranteed to be equal and an error is
raised.
**Arguments:**
- `dist1`: A Transformed distribution.
- `dist2`: A Transformed distribution.
- `input_hint`: an example sample from the base distribution, used to trace the
`forward` method. If not specified, it is computed using a zero array of
the shape and dtype of a sample from the base distribution.
**Returns:**
`KL(dist1 || dist2)`.
**Raises:**
- `NotImplementedError`: If bijectors are not known to be equal.
- `ValueError`: If the base distributions do not have the same `event_shape`.
"""
if dist1.distribution.event_shape != dist2.distribution.event_shape:
raise ValueError(
f"The two base distributions do not have the same event shape: "
10 changes: 6 additions & 4 deletions distreqx/distributions/mvn_from_bijector.py
@@ -177,11 +177,13 @@ def _kl_divergence_mvn_mvn(
"""Divergence KL(dist1 || dist2) between multivariate normal distributions.
**Arguments:**
dist1: A multivariate normal distribution.
dist2: A multivariate normal distribution.
Returns:
Batchwise `KL(dist1 || dist2)`.
- `dist1`: A multivariate normal distribution.
- `dist2`: A multivariate normal distribution.
**Returns:**
- `KL(dist1 || dist2)`.
"""
num_dims = dist1.event_shape[-1]
