-
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* mvn tri * mvn tri v2
- Loading branch information
Showing
9 changed files
with
427 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
"""Triangular linear bijector.""" | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jaxtyping import Array | ||
|
||
from ._bijector import AbstractBijector | ||
from ._linear import AbstractLinearBijector | ||
|
||
|
||
def _triangular_logdet(matrix: Array) -> Array: | ||
"""Computes the log absolute determinant of a triangular matrix.""" | ||
return jnp.sum(jnp.log(jnp.abs(jnp.diag(matrix)))) | ||
|
||
|
||
class TriangularLinear(AbstractLinearBijector): | ||
"""A linear bijector whose weight matrix is triangular. | ||
The bijector is defined as `f(x) = Ax` where `A` is a DxD triangular matrix. | ||
The Jacobian determinant can be computed in O(D) as follows: | ||
log|det J(x)| = log|det A| = sum(log|diag(A)|) | ||
The inverse is computed in O(D^2) by solving the triangular system `Ax = y`. | ||
The bijector is invertible if and only if all diagonal elements of `A` are | ||
non-zero. It is the responsibility of the user to make sure that this is the | ||
case; the class will make no attempt to verify that the bijector is | ||
invertible. | ||
""" | ||
|
||
_matrix: Array | ||
_is_lower: bool | ||
|
||
def __init__(self, matrix: Array, is_lower: bool = True): | ||
"""Initializes a `TriangularLinear` bijector. | ||
**Arguments:** | ||
- `matrix`: a square matrix whose triangular part defines `A`. Can also be a | ||
batch of matrices. Whether `A` is the lower or upper triangular part of | ||
`matrix` is determined by `is_lower`. | ||
- `is_lower`: if True, `A` is set to the lower triangular part of `matrix`. If | ||
False, `A` is set to the upper triangular part of `matrix`. | ||
""" | ||
if matrix.ndim < 2: | ||
raise ValueError( | ||
f"`matrix` must have at least 2 dimensions, got {matrix.ndim}." | ||
) | ||
if matrix.shape[-2] != matrix.shape[-1]: | ||
raise ValueError( | ||
f"`matrix` must be square; instead, it has shape {matrix.shape[-2:]}." | ||
) | ||
super().__init__(event_dims=matrix.shape[-1]) | ||
|
||
self._matrix = jnp.tril(matrix) if is_lower else jnp.triu(matrix) | ||
self._is_lower = is_lower | ||
|
||
@property | ||
def matrix(self) -> Array: | ||
"""The triangular matrix `A` of the transformation.""" | ||
return self._matrix | ||
|
||
@property | ||
def is_lower(self) -> bool: | ||
"""True if `A` is lower triangular, False if upper triangular.""" | ||
return self._is_lower | ||
|
||
def forward(self, x: Array) -> Array: | ||
"""Computes y = f(x).""" | ||
return self._matrix @ x | ||
|
||
def forward_log_det_jacobian(self, x: Array) -> Array: | ||
"""Computes log|det J(f)(x)|.""" | ||
triangular_logdet = jnp.vectorize(_triangular_logdet, signature="(m,m)->()") | ||
return triangular_logdet(self._matrix) | ||
|
||
def forward_and_log_det(self, x: Array) -> tuple[Array, Array]: | ||
"""Computes y = f(x) and log|det J(f)(x)|.""" | ||
return self.forward(x), self.forward_log_det_jacobian(x) | ||
|
||
def inverse(self, y: Array) -> Array: | ||
"""Computes x = f^{-1}(y).""" | ||
return jax.scipy.linalg.solve_triangular(self._matrix, y, lower=self.is_lower) | ||
|
||
def inverse_log_det_jacobian(self, y: Array) -> Array: | ||
"""Computes log|det J(f^{-1})(y)|.""" | ||
return -self.forward_log_det_jacobian(y) | ||
|
||
def inverse_and_log_det(self, y: Array) -> tuple[Array, Array]: | ||
"""Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" | ||
return self.inverse(y), self.inverse_log_det_jacobian(y) | ||
|
||
def same_as(self, other: AbstractBijector) -> bool: | ||
"""Returns True if this bijector is guaranteed to be the same as `other`.""" | ||
if type(other) is TriangularLinear: # pylint: disable=unidiomatic-typecheck | ||
return all( | ||
( | ||
self.matrix is other.matrix, | ||
self.is_lower is other.is_lower, | ||
) | ||
) | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
"""MultivariateNormalTri distribution.""" | ||
|
||
from typing import Optional | ||
|
||
import jax.numpy as jnp | ||
from jaxtyping import Array | ||
|
||
from ..bijectors import DiagLinear, TriangularLinear | ||
from .mvn_from_bijector import MultivariateNormalFromBijector | ||
|
||
|
||
def _check_parameters(loc: Optional[Array], scale_tri: Optional[Array]) -> None: | ||
"""Checks that the inputs are correct.""" | ||
if loc is None and scale_tri is None: | ||
raise ValueError("At least one of `loc` and `scale_tri` must be specified.") | ||
|
||
if loc is not None and loc.ndim < 1: | ||
raise ValueError("The parameter `loc` must have at least one dimension.") | ||
|
||
if scale_tri is not None and scale_tri.ndim < 2: | ||
raise ValueError( | ||
f"The parameter `scale_tri` must have at least two dimensions, but " | ||
f"`scale_tri.shape = {scale_tri.shape}`." | ||
) | ||
|
||
if scale_tri is not None and scale_tri.shape[-1] != scale_tri.shape[-2]: | ||
raise ValueError( | ||
f"The parameter `scale_tri` must be a square matrix, but " | ||
f"`scale_tri.shape = {scale_tri.shape}`." | ||
) | ||
|
||
if loc is not None: | ||
num_dims = loc.shape[-1] | ||
if scale_tri is not None and scale_tri.shape[-1] != num_dims: | ||
raise ValueError( | ||
f"Shapes are not compatible: `loc.shape = {loc.shape}` and " | ||
f"`scale_tri.shape = {scale_tri.shape}`." | ||
) | ||
|
||
|
||
class MultivariateNormalTri(MultivariateNormalFromBijector): | ||
"""Multivariate normal distribution on `R^k`. | ||
The `MultivariateNormalTri` distribution is parameterized by a `k`-length | ||
location (mean) vector `b` and a (lower or upper) triangular scale matrix `S` | ||
of size `k x k`. The covariance matrix is `C = S @ S.T`. | ||
""" | ||
|
||
_scale_tri: Array | ||
_is_lower: bool | ||
|
||
def __init__( | ||
self, | ||
loc: Optional[Array] = None, | ||
scale_tri: Optional[Array] = None, | ||
is_lower: bool = True, | ||
): | ||
"""Initializes a MultivariateNormalTri distribution. | ||
**Arguments:** | ||
- `loc`: Mean vector of the distribution of shape `k`. | ||
If not specified, it defaults to zeros. | ||
- `scale_tri`: The scale matrix `S`. It must be a `k x k` triangular matrix. | ||
If `scale_tri` is not triangular, the entries above or below the main | ||
diagonal will be ignored. The parameter `is_lower` specifies if `scale_tri` | ||
is lower or upper triangular. It is the responsibility of the user to make | ||
sure that `scale_tri` only contains non-zero elements in its diagonal; | ||
this class makes no attempt to verify that. If `scale_tri` is not specified, | ||
it defaults to the identity. | ||
- `is_lower`: Indicates if `scale_tri` is lower (if True) or upper (if False) | ||
triangular. | ||
""" | ||
_check_parameters(loc, scale_tri) | ||
|
||
if loc is not None: | ||
num_dims = loc.shape[-1] | ||
elif scale_tri is not None: | ||
num_dims = scale_tri.shape[-1] | ||
else: | ||
raise ValueError | ||
|
||
dtype = jnp.result_type(*[x for x in [loc, scale_tri] if x is not None]) | ||
|
||
if loc is None: | ||
loc = jnp.zeros((num_dims,), dtype=dtype) | ||
|
||
if scale_tri is None: | ||
self._scale_tri = jnp.eye(num_dims, dtype=dtype) | ||
scale = DiagLinear(diag=jnp.ones(loc.shape[-1:], dtype=dtype)) | ||
else: | ||
tri_fn = jnp.tril if is_lower else jnp.triu | ||
self._scale_tri = tri_fn(scale_tri) | ||
scale = TriangularLinear(matrix=self._scale_tri, is_lower=is_lower) | ||
self._is_lower = is_lower | ||
|
||
super().__init__(loc=loc, scale=scale) | ||
|
||
@property | ||
def scale_tri(self) -> Array: | ||
"""Triangular scale matrix `S`.""" | ||
return jnp.broadcast_to(self._scale_tri, self.event_shape + self.event_shape) | ||
|
||
@property | ||
def is_lower(self) -> bool: | ||
"""Whether the `scale_tri` matrix is lower triangular.""" | ||
return self._is_lower |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Triangular Linear Bijector | ||
|
||
::: distreqx.bijectors.triangular_linear.TriangularLinear | ||
selection: | ||
members: | ||
- __init__ | ||
--- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Triangular Multivariate Normal | ||
|
||
::: distreqx.distributions.mvn_tri.MultivariateNormalTri | ||
selection: | ||
members: | ||
- __init__ | ||
--- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
from unittest import TestCase | ||
|
||
import jax | ||
|
||
|
||
jax.config.update("jax_enable_x64", True) | ||
|
||
import equinox as eqx | ||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
from parameterized import parameterized # type: ignore | ||
|
||
from distreqx.distributions import MultivariateNormalTri | ||
|
||
|
||
class MultivariateNormalTriTest(TestCase): | ||
def setUp(self): | ||
self.key = jax.random.PRNGKey(0) | ||
|
||
def _test_raises_error(self, dist_kwargs): | ||
with self.assertRaises(ValueError): | ||
dist = MultivariateNormalTri(**dist_kwargs) | ||
dist.sample(key=self.key) | ||
|
||
def test_invalid_parameters(self): | ||
self._test_raises_error(dist_kwargs={"loc": None, "scale_tri": None}) | ||
self._test_raises_error(dist_kwargs={"loc": jnp.array(1.0), "scale_tri": None}) | ||
self._test_raises_error(dist_kwargs={"loc": None, "scale_tri": jnp.array(1.0)}) | ||
self._test_raises_error( | ||
dist_kwargs={"loc": None, "scale_tri": jnp.array([1.0])} | ||
) | ||
self._test_raises_error( | ||
dist_kwargs={ | ||
"loc": None, | ||
"scale_tri": jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), | ||
} | ||
) | ||
self._test_raises_error( | ||
dist_kwargs={"loc": jnp.zeros((5,)), "scale_tri": jnp.ones((4, 4))} | ||
) | ||
|
||
@parameterized.expand([("float32", jnp.float32), ("float64", jnp.float64)]) | ||
def test_sample_dtype(self, name, dtype): | ||
dist_params = { | ||
"loc": jnp.array([0.0, 0.0], dtype), | ||
"scale_tri": jnp.array([[1.0, 0.0], [0.0, 1.0]], dtype), | ||
"is_lower": True, | ||
} | ||
dist = MultivariateNormalTri(**dist_params) | ||
samples = dist.sample(key=self.key) | ||
self.assertEqual(samples.dtype, dist.dtype) | ||
self.assertEqual(samples.dtype, dtype) | ||
|
||
def test_jittable(self): | ||
@eqx.filter_jit | ||
def f(dist): | ||
return dist.sample(key=jax.random.PRNGKey(0)) | ||
|
||
dist_params = {"loc": jnp.zeros(2), "scale_tri": jnp.eye(2), "is_lower": True} | ||
dist = MultivariateNormalTri(**dist_params) | ||
y = f(dist) | ||
self.assertIsInstance(y, jax.Array) | ||
|
||
def assertion_fn(self, rtol): | ||
return lambda x, y: np.testing.assert_allclose(x, y, rtol=rtol) |
Oops, something went wrong.