Skip to content

Commit

Permalink
mvn tri distribution (#11)
Browse files Browse the repository at this point in the history
* mvn tri

* mvn tri v2
  • Loading branch information
mayalenE authored Jun 10, 2024
1 parent c2189c6 commit 7bb389b
Show file tree
Hide file tree
Showing 9 changed files with 427 additions and 0 deletions.
1 change: 1 addition & 0 deletions distreqx/bijectors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .shift import Shift as Shift
from .sigmoid import Sigmoid as Sigmoid
from .tanh import Tanh as Tanh
from .triangular_linear import TriangularLinear as TriangularLinear
104 changes: 104 additions & 0 deletions distreqx/bijectors/triangular_linear.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Triangular linear bijector."""

import jax
import jax.numpy as jnp
from jaxtyping import Array

from ._bijector import AbstractBijector
from ._linear import AbstractLinearBijector


def _triangular_logdet(matrix: Array) -> Array:
    """Returns `log|det(matrix)|` for a single triangular `matrix`.

    Only the main diagonal of `matrix` is read: the result is the sum of the
    logarithms of the absolute values of the diagonal entries.
    """
    diagonal_magnitudes = jnp.abs(jnp.diag(matrix))
    return jnp.log(diagonal_magnitudes).sum()


class TriangularLinear(AbstractLinearBijector):
    """Linear bijector `f(x) = Ax` whose `DxD` weight matrix `A` is triangular.

    Since `A` is triangular, the log absolute Jacobian determinant is
    `log|det A| = sum(log|diag(A)|)`, and the inverse is obtained by solving
    the triangular system `Ax = y`.

    The map is invertible exactly when every diagonal element of `A` is
    non-zero. This class does not verify that condition; providing a suitable
    matrix is the caller's responsibility.
    """

    _matrix: Array
    _is_lower: bool

    def __init__(self, matrix: Array, is_lower: bool = True):
        """Initializes a `TriangularLinear` bijector.

        **Arguments:**

        - `matrix`: a square matrix whose triangular part defines `A`. Can also
            be a batch of matrices. Which triangle is kept is decided by
            `is_lower`; the other triangle is zeroed out.
        - `is_lower`: if True, `A` is the lower triangular part of `matrix`;
            if False, `A` is the upper triangular part of `matrix`.

        **Raises:**

        - `ValueError`: if `matrix` has fewer than two dimensions, or if its
            last two dimensions differ.
        """
        rank = matrix.ndim
        if rank < 2:
            raise ValueError(
                f"`matrix` must have at least 2 dimensions, got {matrix.ndim}."
            )

        num_rows, num_cols = matrix.shape[-2:]
        if num_rows != num_cols:
            raise ValueError(
                f"`matrix` must be square; instead, it has shape {matrix.shape[-2:]}."
            )

        super().__init__(event_dims=matrix.shape[-1])

        # Discard whichever triangle is not part of `A`.
        keep_triangle = jnp.tril if is_lower else jnp.triu
        self._matrix = keep_triangle(matrix)
        self._is_lower = is_lower

    @property
    def matrix(self) -> Array:
        """The triangular matrix `A` applied by the bijector."""
        return self._matrix

    @property
    def is_lower(self) -> bool:
        """Whether `A` is lower triangular (True) or upper triangular (False)."""
        return self._is_lower

    def forward(self, x: Array) -> Array:
        """Computes y = f(x) = A @ x."""
        weights = self._matrix
        return weights @ x

    def forward_log_det_jacobian(self, x: Array) -> Array:
        """Computes log|det J(f)(x)|; the value does not depend on `x`."""
        # Vectorise over any leading axes of the stored matrix.
        return jnp.vectorize(_triangular_logdet, signature="(m,m)->()")(self._matrix)

    def forward_and_log_det(self, x: Array) -> tuple[Array, Array]:
        """Computes y = f(x) together with log|det J(f)(x)|."""
        y = self.forward(x)
        log_det = self.forward_log_det_jacobian(x)
        return y, log_det

    def inverse(self, y: Array) -> Array:
        """Computes x = f^{-1}(y) by solving the triangular system `A x = y`."""
        solve = jax.scipy.linalg.solve_triangular
        return solve(self._matrix, y, lower=self.is_lower)

    def inverse_log_det_jacobian(self, y: Array) -> Array:
        """Computes log|det J(f^{-1})(y)| as the negated forward log-det."""
        forward_log_det = self.forward_log_det_jacobian(y)
        return -forward_log_det

    def inverse_and_log_det(self, y: Array) -> tuple[Array, Array]:
        """Computes x = f^{-1}(y) together with log|det J(f^{-1})(y)|."""
        x = self.inverse(y)
        log_det = self.inverse_log_det_jacobian(y)
        return x, log_det

    def same_as(self, other: AbstractBijector) -> bool:
        """Returns True if this bijector is guaranteed to be the same as `other`.

        Only an exact `TriangularLinear` instance that holds the very same
        matrix object and the same `is_lower` flag qualifies.
        """
        if type(other) is not TriangularLinear:  # pylint: disable=unidiomatic-typecheck
            return False
        return self.matrix is other.matrix and self.is_lower is other.is_lower
1 change: 1 addition & 0 deletions distreqx/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
from .mvn_from_bijector import (
MultivariateNormalFromBijector as MultivariateNormalFromBijector,
)
from .mvn_tri import MultivariateNormalTri as MultivariateNormalTri
from .normal import Normal as Normal
from .transformed import Transformed as Transformed
107 changes: 107 additions & 0 deletions distreqx/distributions/mvn_tri.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""MultivariateNormalTri distribution."""

from typing import Optional

import jax.numpy as jnp
from jaxtyping import Array

from ..bijectors import DiagLinear, TriangularLinear
from .mvn_from_bijector import MultivariateNormalFromBijector


def _check_parameters(loc: Optional[Array], scale_tri: Optional[Array]) -> None:
    """Validates the `loc` / `scale_tri` arguments of `MultivariateNormalTri`.

    Raises a `ValueError` when both arguments are missing, when `loc` has no
    dimensions, when `scale_tri` has fewer than two dimensions or is not
    square in its last two dimensions, or when the trailing dimension of
    `scale_tri` differs from the trailing dimension of `loc`.
    """
    has_loc = loc is not None
    has_scale = scale_tri is not None

    if not (has_loc or has_scale):
        raise ValueError("At least one of `loc` and `scale_tri` must be specified.")

    if has_loc and loc.ndim < 1:
        raise ValueError("The parameter `loc` must have at least one dimension.")

    if not has_scale:
        # Nothing further to validate without a scale matrix.
        return

    if scale_tri.ndim < 2:
        raise ValueError(
            f"The parameter `scale_tri` must have at least two dimensions, but "
            f"`scale_tri.shape = {scale_tri.shape}`."
        )

    num_rows, num_cols = scale_tri.shape[-2:]
    if num_rows != num_cols:
        raise ValueError(
            f"The parameter `scale_tri` must be a square matrix, but "
            f"`scale_tri.shape = {scale_tri.shape}`."
        )

    if has_loc and num_cols != loc.shape[-1]:
        raise ValueError(
            f"Shapes are not compatible: `loc.shape = {loc.shape}` and "
            f"`scale_tri.shape = {scale_tri.shape}`."
        )


class MultivariateNormalTri(MultivariateNormalFromBijector):
    """Multivariate normal distribution on `R^k` with a triangular scale.

    The distribution is parameterized by a `k`-length location (mean) vector
    `b` and a (lower or upper) triangular scale matrix `S` of size `k x k`.
    The covariance matrix is `C = S @ S.T`.
    """

    _scale_tri: Array
    _is_lower: bool

    def __init__(
        self,
        loc: Optional[Array] = None,
        scale_tri: Optional[Array] = None,
        is_lower: bool = True,
    ):
        """Initializes a MultivariateNormalTri distribution.

        **Arguments:**

        - `loc`: Mean vector of the distribution of shape `k`. Defaults to
            zeros when omitted.
        - `scale_tri`: The scale matrix `S`, a `k x k` triangular matrix. If it
            is not triangular, the entries on the wrong side of the main
            diagonal (as selected by `is_lower`) are ignored. The caller must
            ensure the diagonal contains only non-zero elements; this class
            does not verify that. Defaults to the identity when omitted.
        - `is_lower`: Indicates if `scale_tri` is lower (if True) or upper
            (if False) triangular.

        **Raises:**

        - `ValueError`: if the arguments are rejected by `_check_parameters`.
        """
        _check_parameters(loc, scale_tri)

        # `loc` takes precedence when inferring the event size.
        if loc is None:
            if scale_tri is None:
                # `_check_parameters` already rejects this combination; the
                # guard remains so both names are narrowed below.
                raise ValueError
            num_dims = scale_tri.shape[-1]
        else:
            num_dims = loc.shape[-1]

        provided = [param for param in (loc, scale_tri) if param is not None]
        dtype = jnp.result_type(*provided)

        if loc is None:
            loc = jnp.zeros((num_dims,), dtype=dtype)

        self._is_lower = is_lower
        if scale_tri is not None:
            keep_triangle = jnp.tril if is_lower else jnp.triu
            self._scale_tri = keep_triangle(scale_tri)
            scale = TriangularLinear(matrix=self._scale_tri, is_lower=is_lower)
        else:
            # No scale given: store an identity matrix and use a unit
            # diagonal bijector as the scale.
            self._scale_tri = jnp.eye(num_dims, dtype=dtype)
            scale = DiagLinear(diag=jnp.ones(loc.shape[-1:], dtype=dtype))

        super().__init__(loc=loc, scale=scale)

    @property
    def scale_tri(self) -> Array:
        """Triangular scale matrix `S`, broadcast to `event_shape + event_shape`."""
        matrix_shape = self.event_shape + self.event_shape
        return jnp.broadcast_to(self._scale_tri, matrix_shape)

    @property
    def is_lower(self) -> bool:
        """Whether the `scale_tri` matrix is lower triangular."""
        return self._is_lower
7 changes: 7 additions & 0 deletions docs/api/bijectors/triangular_linear.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Triangular Linear Bijector

::: distreqx.bijectors.triangular_linear.TriangularLinear
selection:
members:
- __init__
---
7 changes: 7 additions & 0 deletions docs/api/distributions/mvn_tri.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Triangular Multivariate Normal

::: distreqx.distributions.mvn_tri.MultivariateNormalTri
selection:
members:
- __init__
---
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ nav:
- 'api/distributions/normal.md'
- 'api/distributions/mvn_diag.md'
- 'api/distributions/mvn_from_bijector.md'
- 'api/distributions/mvn_tri.md'
- Bijectors:
- 'api/bijectors/_bijector.md'
- 'api/bijectors/_linear.md'
Expand All @@ -112,6 +113,7 @@ nav:
- 'api/bijectors/shift.md'
- 'api/bijectors/sigmoid.md'
- 'api/bijectors/tanh.md'
- 'api/bijectors/triangular_linear.md'
- Utilities:
- 'api/utils/math.md'
- Further Details:
Expand Down
66 changes: 66 additions & 0 deletions tests/mvn_tri_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from unittest import TestCase

import jax


jax.config.update("jax_enable_x64", True)

import equinox as eqx
import jax
import jax.numpy as jnp
import numpy as np
from parameterized import parameterized # type: ignore

from distreqx.distributions import MultivariateNormalTri


class MultivariateNormalTriTest(TestCase):
    """Unit tests for `MultivariateNormalTri`."""

    def setUp(self):
        """Creates the PRNG key shared by the sampling tests."""
        self.key = jax.random.PRNGKey(0)

    def _test_raises_error(self, dist_kwargs):
        """Asserts that building and sampling with `dist_kwargs` raises ValueError."""
        with self.assertRaises(ValueError):
            MultivariateNormalTri(**dist_kwargs).sample(key=self.key)

    def test_invalid_parameters(self):
        """Each invalid `loc` / `scale_tri` combination must be rejected."""
        non_square = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
        invalid_cases = [
            {"loc": None, "scale_tri": None},
            {"loc": jnp.array(1.0), "scale_tri": None},
            {"loc": None, "scale_tri": jnp.array(1.0)},
            {"loc": None, "scale_tri": jnp.array([1.0])},
            {"loc": None, "scale_tri": non_square},
            {"loc": jnp.zeros((5,)), "scale_tri": jnp.ones((4, 4))},
        ]
        for bad_kwargs in invalid_cases:
            self._test_raises_error(dist_kwargs=bad_kwargs)

    @parameterized.expand([("float32", jnp.float32), ("float64", jnp.float64)])
    def test_sample_dtype(self, name, dtype):
        """Samples carry the dtype of the parameters."""
        dist = MultivariateNormalTri(
            loc=jnp.array([0.0, 0.0], dtype),
            scale_tri=jnp.array([[1.0, 0.0], [0.0, 1.0]], dtype),
            is_lower=True,
        )
        samples = dist.sample(key=self.key)
        self.assertEqual(samples.dtype, dist.dtype)
        self.assertEqual(samples.dtype, dtype)

    def test_jittable(self):
        """Sampling works when the distribution is passed through `eqx.filter_jit`."""

        @eqx.filter_jit
        def draw_sample(dist):
            return dist.sample(key=jax.random.PRNGKey(0))

        dist = MultivariateNormalTri(
            loc=jnp.zeros(2), scale_tri=jnp.eye(2), is_lower=True
        )
        result = draw_sample(dist)
        self.assertIsInstance(result, jax.Array)

    def assertion_fn(self, rtol):
        """Returns a two-argument closeness assertion bound to `rtol`."""

        def _assert_close(x, y):
            return np.testing.assert_allclose(x, y, rtol=rtol)

        return _assert_close
Loading

0 comments on commit 7bb389b

Please sign in to comment.