diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 0000000..713d4a7
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,70 @@
+# Contributing
+
+Contributions (pull requests) are very welcome! Here's how to get started.
+
+---
+
+**Getting started**
+
+First fork the library on GitHub.
+
+Then clone and install the library in development mode:
+
+```bash
+git clone https://github.com/your-username-here/distreqx.git
+cd distreqx
+pip install -e .
+```
+
+Then install the pre-commit hooks:
+
+```bash
+pip install pre-commit
+pre-commit install
+```
+
+These hooks use Black and isort to format the code, and flake8 to lint it.
+
+---
+
+**If you're making changes to the code:**
+
+Now make your changes. Make sure to include additional tests if necessary.
+
+If you add a new feature, there are three required classes of tests (sketched below):
+- Correctness: tests against analytic or known solutions that ensure the computation is correct
+- Compatibility: tests that check the feature is `jit`-, `vmap`-, and `grad`-compatible and behaves as expected
+- Edge cases: tests that make sure edge cases (e.g. large/small numerics, unexpected dtypes) are either dealt with or fail in an expected manner
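+
+For example, a new feature's test file might follow this sketch (using the existing `Tanh` bijector as a stand-in for your feature; names are illustrative only):
+
+```python
+"""Sketch of the three required test classes."""
+from unittest import TestCase
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from distreqx.bijectors import Tanh  # stand-in for your new feature
+
+
+class NewFeatureTest(TestCase):
+    def test_correctness(self):
+        # Correctness: compare against a known analytic solution.
+        x = jnp.array([0.0, 1.0])
+        np.testing.assert_allclose(Tanh().forward(x), jnp.tanh(x), rtol=1e-6)
+
+    def test_compatibility(self):
+        # Compatibility: the feature should survive jit, vmap, and grad.
+        f = jax.jit(jax.vmap(jax.grad(lambda x: Tanh().forward(x))))
+        self.assertEqual(f(jnp.zeros(3)).shape, (3,))
+
+    def test_edge_cases(self):
+        # Edge cases: large inputs should saturate rather than overflow.
+        self.assertTrue(jnp.isfinite(Tanh().forward(jnp.array(1e6))))
+```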
+
+Next verify the tests all pass:
+
+```bash
+pip install pytest
+pytest tests
+```
+
+Then push your changes back to your fork of the repository:
+
+```bash
+git push
+```
+
+Finally, open a pull request on GitHub!
+
+---
+
+**If you're making changes to the documentation:**
+
+Make your changes. You can then build the documentation by running
+
+```bash
+pip install -r docs/requirements.txt
+mkdocs serve
+```
+then pressing `Control-C`, and running:
+```
+mkdocs serve
+```
+again. (So you run `mkdocs serve` twice.)
+
+You can then see your local copy of the documentation by navigating to `localhost:8000` in a web browser.
\ No newline at end of file
diff --git a/README.md b/README.md
index 5e03149..4b657ce 100644
--- a/README.md
+++ b/README.md
@@ -53,8 +53,6 @@ If you found this library useful in academic research, please cite:
 [Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
 [sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
 [diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.
-
-**Awesome JAX**
 [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.
 
 ## Original distrax copyright
diff --git a/distreqx/bijectors/__init__.py b/distreqx/bijectors/__init__.py
index 158b67b..f31850c 100644
--- a/distreqx/bijectors/__init__.py
+++ b/distreqx/bijectors/__init__.py
@@ -5,4 +5,5 @@ from .diag_linear import DiagLinear as DiagLinear
 from .scalar_affine import ScalarAffine as ScalarAffine
 from .shift import Shift as Shift
+from .sigmoid import Sigmoid as Sigmoid
 from .tanh import Tanh as Tanh
diff --git a/distreqx/bijectors/sigmoid.py b/distreqx/bijectors/sigmoid.py
new file mode 100644
index 0000000..554c5aa
--- /dev/null
+++ b/distreqx/bijectors/sigmoid.py
@@ -0,0 +1,71 @@
+"""Sigmoid bijector."""
+
+from typing import Tuple
+
+import jax
+import jax.numpy as jnp
+from jaxtyping import Array
+
+from ._bijector import AbstractBijector
+
+
+class Sigmoid(AbstractBijector):
+    """A bijector that computes the logistic sigmoid.
+
+    The log-determinant implementation in this bijector is more numerically
+    stable than relying on the automatic differentiation approach used by
+    Lambda, so this bijector should be preferred over Lambda(jax.nn.sigmoid)
+    where possible.
+
+    Note that the underlying implementation of `jax.nn.sigmoid` used by the
+    `forward` function of this bijector does not support inputs of integer
+    type. To invoke the forward function of this bijector on an argument of
+    integer type, it should first be cast explicitly to a floating point type.
+
+    When the absolute value of the input is large, `Sigmoid` becomes close to
+    a constant, so that it is not possible to recover the input `x` from the
+    output `y` within machine precision. In cases where it is necessary to
+    compute both the forward mapping and the backward mapping one after the
+    other to recover the original input `x`, it is the user's responsibility
+    to simplify the operation to avoid numerical issues. One example of such a
+    case is using the bijector within a `Transformed` distribution and
+    obtaining the log-probability of samples drawn from the distribution's
+    `sample` method. For samples where the inverse bijector cannot be applied
+    accurately, `log_prob` returns NaN. This can be avoided by using
+    `sample_and_log_prob` instead of `sample` followed by `log_prob`.
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+
+    def forward_log_det_jacobian(self, x: Array) -> Array:
+        """Computes log|det J(f)(x)|."""
+        return -_more_stable_softplus(-x) - _more_stable_softplus(x)
+
+    def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
+        """Computes y = f(x) and log|det J(f)(x)|."""
+        return _more_stable_sigmoid(x), self.forward_log_det_jacobian(x)
+
+    def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
+        """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
+        x = jnp.log(y) - jnp.log1p(-y)
+        return x, -self.forward_log_det_jacobian(x)
+
+    def same_as(self, other: AbstractBijector) -> bool:
+        """Returns True if this bijector is guaranteed to be the same as `other`."""
+        return type(other) is Sigmoid
+
+
+def _more_stable_sigmoid(x: Array) -> Array:
+    """Where extremely negatively saturated, approximate sigmoid with exp(x)."""
+    ret = jnp.where(x < -9, jnp.exp(x), jax.nn.sigmoid(x))
+    if not isinstance(ret, Array):
+        raise TypeError("ret is not an Array")
+    return ret
+
+
+def _more_stable_softplus(x: Array) -> Array:
+    """Where extremely negatively saturated, compute softplus directly as log1p(exp(x))."""
+    ret = jnp.where(x < -9, jnp.log1p(jnp.exp(x)), jax.nn.softplus(x))
+    if not isinstance(ret, Array):
+        raise TypeError("ret is not an Array")
+    return ret
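
The identity used by `forward_log_det_jacobian` above is log σ'(x) = −softplus(−x) − softplus(x). A minimal standalone sketch (not part of the diff) of why this is preferred over the naive formulation:

```python
import jax
import jax.numpy as jnp

def naive_fldj(x):
    # log sigma'(x) computed directly: underflows for large |x|.
    s = jax.nn.sigmoid(x)
    return jnp.log(s * (1.0 - s))

def stable_fldj(x):
    # The softplus form used in this PR: finite wherever x is finite.
    return -jax.nn.softplus(-x) - jax.nn.softplus(x)

x = jnp.array([0.0, 5.0, 40.0])
print(naive_fldj(x))   # last entry is -inf: sigmoid(40) rounds to 1 in float32
print(stable_fldj(x))  # all entries finite; last is approximately -40
```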
""" - def __init__(self) -> None: - """Initialize the TanH bijector.""" + def __init__(self): super().__init__() def forward_log_det_jacobian(self, x: Array) -> Array: diff --git a/distreqx/distributions/__init__.py b/distreqx/distributions/__init__.py index 0eec17d..875ccb7 100644 --- a/distreqx/distributions/__init__.py +++ b/distreqx/distributions/__init__.py @@ -1,7 +1,6 @@ from ._distribution import ( AbstractDistribution as AbstractDistribution, ) -from ._transformed import Transformed as Transformed from .bernoulli import Bernoulli as Bernoulli from .independent import Independent as Independent from .mvn_diag import MultivariateNormalDiag as MultivariateNormalDiag @@ -9,3 +8,4 @@ MultivariateNormalFromBijector as MultivariateNormalFromBijector, ) from .normal import Normal as Normal +from .transformed import Transformed as Transformed diff --git a/distreqx/distributions/mvn_diag.py b/distreqx/distributions/mvn_diag.py index 84c5b3c..2291094 100644 --- a/distreqx/distributions/mvn_diag.py +++ b/distreqx/distributions/mvn_diag.py @@ -54,14 +54,12 @@ def __init__(self, loc: Optional[Array] = None, scale_diag: Optional[Array] = No elif loc is None and scale_diag is not None: loc = jnp.zeros(scale_diag.shape[-1], scale_diag.dtype) - assert loc is not None - assert scale_diag is not None - - broadcasted_shapes = jnp.broadcast_shapes(loc.shape, scale_diag.shape) - loc = jnp.expand_dims(loc, axis=list(range(len(broadcasted_shapes) - loc.ndim))) - scale_diag = jnp.expand_dims( - scale_diag, axis=list(range(len(broadcasted_shapes) - scale_diag.ndim)) - ) + if loc is None: + raise ValueError("loc is None") + if scale_diag is None: + raise ValueError("scale_diag is None") + if scale_diag.ndim != 1: + raise ValueError("scale_diag must be a vector!") scale = DiagLinear(scale_diag) super().__init__(loc=loc, scale=scale) diff --git a/distreqx/distributions/mvn_from_bijector.py b/distreqx/distributions/mvn_from_bijector.py index dd613a7..233ff52 100644 --- a/distreqx/distributions/mvn_from_bijector.py +++ b/distreqx/distributions/mvn_from_bijector.py @@ -8,9 +8,9 @@ from jaxtyping import Array from ..bijectors import AbstractLinearBijector, Block, Chain, DiagLinear, Shift -from ._transformed import Transformed from .independent import Independent from .normal import Normal +from .transformed import Transformed def _check_input_parameters_are_valid( diff --git a/distreqx/distributions/_transformed.py b/distreqx/distributions/transformed.py similarity index 100% rename from distreqx/distributions/_transformed.py rename to distreqx/distributions/transformed.py diff --git a/docs/api/bijectors/sigmoid.md b/docs/api/bijectors/sigmoid.md new file mode 100644 index 0000000..ba8ae98 --- /dev/null +++ b/docs/api/bijectors/sigmoid.md @@ -0,0 +1,6 @@ +# Sigmoid Bijector + +::: distreqx.bijectors.sigmoid.Sigmoid + selection: + members: None +--- \ No newline at end of file diff --git a/docs/api/bijectors/tanh.md b/docs/api/bijectors/tanh.md index b24c8b7..0e0c6b9 100644 --- a/docs/api/bijectors/tanh.md +++ b/docs/api/bijectors/tanh.md @@ -2,6 +2,5 @@ ::: distreqx.bijectors.tanh.Tanh selection: - members: - - __init__ + members: None --- \ No newline at end of file diff --git a/docs/api/distributions/_transformed.md b/docs/api/distributions/transformed.md similarity index 75% rename from docs/api/distributions/_transformed.md rename to docs/api/distributions/transformed.md index 36ac378..20a5d0c 100644 --- a/docs/api/distributions/_transformed.md +++ b/docs/api/distributions/transformed.md @@ -1,6 +1,6 @@ # Base 
diff --git a/distreqx/distributions/mvn_from_bijector.py b/distreqx/distributions/mvn_from_bijector.py
index dd613a7..233ff52 100644
--- a/distreqx/distributions/mvn_from_bijector.py
+++ b/distreqx/distributions/mvn_from_bijector.py
@@ -8,9 +8,9 @@
 from jaxtyping import Array
 
 from ..bijectors import AbstractLinearBijector, Block, Chain, DiagLinear, Shift
-from ._transformed import Transformed
 from .independent import Independent
 from .normal import Normal
+from .transformed import Transformed
 
 
 def _check_input_parameters_are_valid(
diff --git a/distreqx/distributions/_transformed.py b/distreqx/distributions/transformed.py
similarity index 100%
rename from distreqx/distributions/_transformed.py
rename to distreqx/distributions/transformed.py
diff --git a/docs/api/bijectors/sigmoid.md b/docs/api/bijectors/sigmoid.md
new file mode 100644
index 0000000..ba8ae98
--- /dev/null
+++ b/docs/api/bijectors/sigmoid.md
@@ -0,0 +1,6 @@
+# Sigmoid Bijector
+
+::: distreqx.bijectors.sigmoid.Sigmoid
+    selection:
+        members: None
+---
\ No newline at end of file
diff --git a/docs/api/bijectors/tanh.md b/docs/api/bijectors/tanh.md
index b24c8b7..0e0c6b9 100644
--- a/docs/api/bijectors/tanh.md
+++ b/docs/api/bijectors/tanh.md
@@ -2,6 +2,5 @@
 ::: distreqx.bijectors.tanh.Tanh
     selection:
-        members:
-            - __init__
+        members: None
 ---
\ No newline at end of file
diff --git a/docs/api/distributions/_transformed.md b/docs/api/distributions/transformed.md
similarity index 75%
rename from docs/api/distributions/_transformed.md
rename to docs/api/distributions/transformed.md
index 36ac378..20a5d0c 100644
--- a/docs/api/distributions/_transformed.md
+++ b/docs/api/distributions/transformed.md
@@ -1,6 +1,6 @@
 # Base Transformed
 
-::: distreqx.distributions._transformed.Transformed
+::: distreqx.distributions.transformed.Transformed
     selection:
         members:
             - __init__
diff --git a/mkdocs.yml b/mkdocs.yml
index b07f47f..f6ad60d 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -95,9 +95,9 @@ nav:
     - API:
         - Distributions:
            - 'api/distributions/_distribution.md'
-           - 'api/distributions/_transformed.md'
            - 'api/distributions/independent.md'
            - 'api/distributions/bernoulli.md'
+           - 'api/distributions/transformed.md'
        - Gaussians:
            - 'api/distributions/normal.md'
            - 'api/distributions/mvn_diag.md'
@@ -110,6 +110,7 @@
            - 'api/bijectors/diag_linear.md'
            - 'api/bijectors/scalar_affine.md'
            - 'api/bijectors/shift.md'
+           - 'api/bijectors/sigmoid.md'
            - 'api/bijectors/tanh.md'
        - Utilities:
            - 'api/utils/math.md'
diff --git a/tests/abstractlinear_test.py b/tests/abstractlinear_test.py
new file mode 100644
index 0000000..6c8e1f9
--- /dev/null
+++ b/tests/abstractlinear_test.py
@@ -0,0 +1,31 @@
+"""Tests for `linear.py`."""
+
+from unittest import TestCase
+
+from parameterized import parameterized  # type: ignore
+
+from distreqx.bijectors import AbstractLinearBijector
+
+
+class MockLinear(AbstractLinearBijector):
+    def __init__(self, dims):
+        super().__init__(dims)
+
+    def forward_and_log_det(self, x):
+        raise Exception
+
+
+class LinearTest(TestCase):
+    @parameterized.expand(
+        [
+            ("1", 1),
+            ("10", 10),
+        ]
+    )
+    def test_properties(self, name, event_dims):
+        bij = MockLinear(event_dims)
+        self.assertTrue(bij.is_constant_jacobian)
+        self.assertTrue(bij.is_constant_log_det)
+        self.assertEqual(bij.event_dims, event_dims)
+        with self.assertRaises(NotImplementedError):
+            bij.matrix
diff --git a/tests/block_test.py b/tests/block_test.py
new file mode 100644
index 0000000..291fbe9
--- /dev/null
+++ b/tests/block_test.py
@@ -0,0 +1,90 @@
+"""Tests for `block.py`."""
+
+from unittest import TestCase
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from parameterized import parameterized  # type: ignore
+
+from distreqx.bijectors import AbstractBijector, Block, ScalarAffine, Tanh
+
+
+RTOL = 1e-6
+seed = jax.random.PRNGKey(1234)
+
+
+class BlockTest(TestCase):
+    def test_properties(self):
+        bijct = Tanh()
+        block = Block(bijct, 1)
+        self.assertEqual(block.ndims, 1)
+        self.assertIsInstance(block.bijector, AbstractBijector)
+
+    def test_invalid_properties(self):
+        bijct = Tanh()
+        with self.assertRaises(ValueError):
+            Block(bijct, -1)
+
+    @parameterized.expand(
+        [
+            ("dx_tanh_0", Tanh, 0),
+            ("dx_tanh_1", Tanh, 1),
+            ("dx_tanh_2", Tanh, 2),
+        ]
+    )
+    def test_forward_inverse_work_as_expected(self, name, bijector_fn, ndims):
+        bijct = bijector_fn()
+        x = jax.random.normal(seed, [2, 3])
+        block = Block(bijct, ndims)
+        np.testing.assert_array_equal(bijct.forward(x), block.forward(x))
+        np.testing.assert_array_equal(bijct.inverse(x), block.inverse(x))
+        np.testing.assert_allclose(
+            bijct.forward_and_log_det(x)[0], block.forward_and_log_det(x)[0], atol=2e-7
+        )
+        np.testing.assert_array_equal(
+            bijct.inverse_and_log_det(x)[0], block.inverse_and_log_det(x)[0]
+        )
+
+    @parameterized.expand(
+        [
+            ("dx_tanh_0", Tanh, 0),
+            ("dx_tanh_1", Tanh, 1),
+            ("dx_tanh_2", Tanh, 2),
+        ]
+    )
+    def test_log_det_jacobian_works_as_expected(self, name, bijector_fn, ndims):
+        bijct = bijector_fn()
+        x = jax.random.normal(seed, [2, 3])
+        block = Block(bijct, ndims)
+        axes = tuple(range(-ndims, 0))
+        np.testing.assert_allclose(
+            bijct.forward_log_det_jacobian(x).sum(axes),
+            block.forward_log_det_jacobian(x),
+            rtol=RTOL,
+        )
+        np.testing.assert_allclose(
+            bijct.inverse_log_det_jacobian(x).sum(axes),
+            block.inverse_log_det_jacobian(x),
+            rtol=RTOL,
+        )
+        np.testing.assert_allclose(
+            bijct.forward_and_log_det(x)[1].sum(axes),
+            block.forward_and_log_det(x)[1],
+            rtol=RTOL,
+        )
+        np.testing.assert_allclose(
+            bijct.inverse_and_log_det(x)[1].sum(axes),
+            block.inverse_and_log_det(x)[1],
+            rtol=RTOL,
+        )
+
+    def test_jittable(self):
+        @jax.jit
+        def f(x, b):
+            return b.forward(x)
+
+        bijector = Block(ScalarAffine(jnp.array(0)), 1)
+        x = jnp.zeros((2, 3))
+        y = f(x, bijector)
+        self.assertIsInstance(y, jax.Array)
diff --git a/tests/chain_test.py b/tests/chain_test.py
new file mode 100644
index 0000000..fe67e22
--- /dev/null
+++ b/tests/chain_test.py
@@ -0,0 +1,36 @@
+"""Tests for `chain.py`."""
+
+from unittest import TestCase
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from distreqx.bijectors import AbstractBijector, Chain, ScalarAffine, Tanh
+
+
+RTOL = 1e-2
+
+
+class ChainTest(TestCase):
+    def setUp(self):
+        self.seed = jax.random.PRNGKey(1234)
+
+    def test_properties(self):
+        bijector = Chain([Tanh()])
+        for bij in bijector.bijectors:
+            assert isinstance(bij, AbstractBijector)
+
+    def test_raises_on_empty_list(self):
+        with self.assertRaises(ValueError):
+            Chain([])
+
+    def test_jittable(self):
+        @jax.jit
+        def f(x, b):
+            return b.forward(x)
+
+        bijector = Chain([ScalarAffine(jnp.array(0.0), jnp.array(1.0))])
+        x = np.zeros(())
+        y = f(x, bijector)
+        self.assertIsInstance(y, jax.Array)
diff --git a/tests/diag_linear_test.py b/tests/diag_linear_test.py
new file mode 100644
index 0000000..4e6cf5f
--- /dev/null
+++ b/tests/diag_linear_test.py
@@ -0,0 +1,120 @@
+"""Tests for `diag_linear.py`."""
+
+from unittest import TestCase
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from distreqx.bijectors import DiagLinear, Tanh
+
+
+class DiagLinearTest(TestCase):
+    def test_static_properties(self):
+        bij = DiagLinear(diag=jnp.ones((4,)))
+        self.assertTrue(bij.is_constant_jacobian)
+        self.assertTrue(bij.is_constant_log_det)
+        self.assertEqual(bij.event_dims, 4)
+
+    def test_properties(self):
+        bij = DiagLinear(diag=jnp.ones((4,)))
+        self.assertEqual(bij.event_dims, 4)
+        self.assertEqual(bij.diag.shape, (4,))
+        self.assertEqual(bij.matrix.shape, (4, 4))
+        np.testing.assert_allclose(bij.diag, 1.0, atol=1e-6)
+        np.testing.assert_allclose(bij.matrix, np.eye(4), atol=1e-6)
+
+    def test_raises_with_invalid_parameters(self):
+        with self.assertRaises(ValueError):
+            DiagLinear(diag=jnp.ones(()))
+
+    def test_parameters(self):
+        prng = jax.random.PRNGKey(42)
+        prng = jax.random.split(prng, 2)
+        diag = jax.random.uniform(prng[0], (4,)) + 0.5
+        bij = DiagLinear(diag)
+
+        x = jax.random.normal(prng[1], (4,))
+        y, logdet_fwd = bij.forward_and_log_det(x)
+        z, logdet_inv = bij.inverse_and_log_det(x)
+
+        self.assertEqual(y.shape, (4,))
+        self.assertEqual(z.shape, (4,))
+        self.assertEqual(logdet_fwd.shape, ())
+        self.assertEqual(logdet_inv.shape, ())
+
+    def test_identity_initialization(self):
+        bij = DiagLinear(diag=jnp.ones((4,)))
+        prng = jax.random.PRNGKey(42)
+        x = jax.random.normal(prng, (4,))
+
+        # Forward methods.
+        y, logdet = bij.forward_and_log_det(x)
+        np.testing.assert_array_equal(y, x)
+        np.testing.assert_array_equal(logdet, jnp.zeros(1))
+
+        # Inverse methods.
+        x_rec, logdet = bij.inverse_and_log_det(y)
+        np.testing.assert_array_equal(x_rec, y)
+        np.testing.assert_array_equal(logdet, jnp.zeros(1))
+
+    def test_inverse_methods(self):
+        prng = jax.random.PRNGKey(42)
+        prng = jax.random.split(prng, 2)
+        diag = jax.random.uniform(prng[0], (4,)) + 0.5
+        bij = DiagLinear(diag)
+        x = jax.random.normal(prng[1], (4,))
+        y, logdet_fwd = bij.forward_and_log_det(x)
+        x_rec, logdet_inv = bij.inverse_and_log_det(y)
+        np.testing.assert_allclose(x_rec, x, atol=1e-6)
+        np.testing.assert_allclose(logdet_fwd, -logdet_inv, atol=1e-6)
+
+    def test_forward_jacobian_det(self):
+        prng = jax.random.PRNGKey(42)
+        prng = jax.random.split(prng, 3)
+        diag = jax.random.uniform(prng[0], (4,)) + 0.5
+        bij = DiagLinear(diag)
+
+        batched_x = jax.random.normal(prng[1], (10, 4))
+        single_x = jax.random.normal(prng[2], (4,))
+        batched_logdet = eqx.filter_vmap(bij.forward_log_det_jacobian)(batched_x)
+
+        jacobian_fn = jax.jacfwd(bij.forward)
+        logdet_numerical = jnp.linalg.slogdet(jacobian_fn(single_x))[1]
+        for logdet in batched_logdet:
+            np.testing.assert_allclose(logdet, logdet_numerical, atol=5e-4)
+
+    def test_inverse_jacobian_det(self):
+        prng = jax.random.PRNGKey(42)
+        prng = jax.random.split(prng, 3)
+        diag = jax.random.uniform(prng[0], (4,)) + 0.5
+        bij = DiagLinear(diag)
+
+        batched_y = jax.random.normal(prng[1], (10, 4))
+        single_y = jax.random.normal(prng[2], (4,))
+        batched_logdet = eqx.filter_vmap(bij.inverse_log_det_jacobian)(batched_y)
+
+        jacobian_fn = jax.jacfwd(bij.inverse)
+        logdet_numerical = jnp.linalg.slogdet(jacobian_fn(single_y))[1]
+        for logdet in batched_logdet:
+            np.testing.assert_allclose(logdet, logdet_numerical, atol=5e-4)
+
+    def test_jittable(self):
+        @eqx.filter_jit
+        def f(x, b):
+            return b.forward(x)
+
+        bij = DiagLinear(diag=jnp.ones((4,)))
+        x = jnp.zeros((4,))
+        f(x, bij)
+
+    def test_same_as_itself(self):
+        bij = DiagLinear(diag=jnp.ones((4,)))
+        self.assertTrue(bij.same_as(bij))
+
+    def test_not_same_as_others(self):
+        bij = DiagLinear(diag=jnp.ones((4,)))
+        other = DiagLinear(diag=2.0 * jnp.ones((4,)))
+        self.assertFalse(bij.same_as(other))
+        self.assertFalse(bij.same_as(Tanh()))
diff --git a/tests/independent_test.py b/tests/independent_test.py
new file mode 100644
index 0000000..bc33274
--- /dev/null
+++ b/tests/independent_test.py
@@ -0,0 +1,27 @@
+"""Tests for `independent.py`."""
+
+from unittest import TestCase
+
+import equinox as eqx
+import jax.numpy as jnp
+import numpy as np
+
+from distreqx.distributions import Independent, Normal
+
+
+class IndependentTest(TestCase):
+    """Class to test miscellaneous methods of the `Independent` distribution."""
+
+    def setUp(self):
+        self.loc = jnp.array(np.random.randn(2, 3, 4))
+        self.scale = jnp.array(np.abs(np.random.randn(2, 3, 4)))
+        self.base = Normal(loc=self.loc, scale=self.scale)
+        self.dist = Independent(self.base)
+
+    def assertion_fn(self, rtol):
+        return lambda x, y: np.testing.assert_allclose(x, y, rtol=rtol)
+
+    def test_constructor_is_jittable_given_ndims(self):
+        constructor = lambda d: Independent(d)
+        model = eqx.filter_jit(constructor)(self.base)
+        self.assertIsInstance(model, Independent)
diff --git a/tests/mvn_diag_test.py b/tests/mvn_diag_test.py
new file mode 100644
index 0000000..69a411b
--- /dev/null
+++ b/tests/mvn_diag_test.py
@@ -0,0 +1,64 @@
+from unittest import TestCase
+
+import jax
+
+
+jax.config.update("jax_enable_x64", True)
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import numpy as np
+from parameterized import parameterized  # type: ignore
+
+from distreqx.distributions import MultivariateNormalDiag
+
+
+class MultivariateNormalDiagTest(TestCase):
+    def setUp(self):
+        self.key = jax.random.PRNGKey(0)
+
+    def _test_raises_error(self, dist_kwargs):
+        with self.assertRaises(ValueError):
+            dist = MultivariateNormalDiag(**dist_kwargs)
+            dist.sample(key=self.key)
+
+    def test_invalid_parameters(self):
+        self._test_raises_error(dist_kwargs={"loc": None, "scale_diag": None})
+        self._test_raises_error(dist_kwargs={"loc": None, "scale_diag": jnp.array(1.0)})
+        self._test_raises_error(dist_kwargs={"loc": jnp.array(1.0), "scale_diag": None})
+        self._test_raises_error(
+            dist_kwargs={"loc": jnp.zeros((3, 5)), "scale_diag": jnp.ones((3, 4))}
+        )
+
+    @parameterized.expand([("float32", jnp.float32), ("float64", jnp.float64)])
+    def test_sample_dtype(self, name, dtype):
+        dist_params = {
+            "loc": jnp.array([0.0, 0.0], dtype),
+            "scale_diag": jnp.array([1.0, 1.0], dtype),
+        }
+        dist = MultivariateNormalDiag(**dist_params)
+        samples = dist.sample(key=self.key)
+        self.assertEqual(samples.dtype, dist.dtype)
+        self.assertEqual(samples.dtype, dtype)
+
+    def test_median(self):
+        dist_params = {
+            "loc": jnp.array([0.3, -0.1, 0.0]),
+            "scale_diag": jnp.array([0.1, 1.4, 0.5]),
+        }
+        dist = MultivariateNormalDiag(**dist_params)
+        np.testing.assert_allclose(dist.median(), dist.mean(), rtol=1e-3)
+
+    def test_jittable(self):
+        @eqx.filter_jit
+        def f(dist):
+            return dist.sample(key=jax.random.PRNGKey(0))
+
+        dist_params = {"loc": jnp.zeros(2), "scale_diag": jnp.ones(2)}
+        dist = MultivariateNormalDiag(**dist_params)
+        y = f(dist)
+        self.assertIsInstance(y, jax.Array)
+
+    def assertion_fn(self, rtol):
+        return lambda x, y: np.testing.assert_allclose(x, y, rtol=rtol)
diff --git a/tests/mvn_from_bijector_test.py b/tests/mvn_from_bijector_test.py
new file mode 100644
index 0000000..6a1cc9d
--- /dev/null
+++ b/tests/mvn_from_bijector_test.py
@@ -0,0 +1,128 @@
+"""Tests for `mvn_from_bijector.py`."""
+
+from unittest import TestCase
+
+import equinox as eqx
+import jax
+import jax.numpy as jnp
+import numpy as np
+from parameterized import parameterized  # type: ignore
+
+from distreqx.bijectors import AbstractLinearBijector, DiagLinear
+from distreqx.distributions import MultivariateNormalFromBijector
+
+
+class MockLinear(AbstractLinearBijector):
+    """A mock linear bijector."""
+
+    def __init__(self, event_dims: int):
+        super().__init__(event_dims)
+
+    def forward_and_log_det(self, x):
+        """Computes y = f(x) and log|det J(f)(x)|."""
+        return x, jnp.zeros_like(x)[:-1]
+
+
+class MultivariateNormalFromBijectorTest(TestCase):
+    @parameterized.expand(
+        [
+            ("loc is 0d", 4, jnp.zeros(shape=())),
+            ("loc and scale dims not compatible", 3, jnp.zeros((4,))),
+        ]
+    )
+    def test_raises_on_wrong_inputs(self, name, event_dims, loc):
+        bij = MockLinear(event_dims)
+        with self.assertRaises(ValueError):
+            MultivariateNormalFromBijector(loc, bij)
+
+    @parameterized.expand([("no broadcast", jnp.ones((4,)), jnp.zeros((4,)), (4,))])
+    def test_loc_scale_and_shapes(self, name, diag, loc, expected_shape):
+        scale = DiagLinear(diag)
+        dist = MultivariateNormalFromBijector(loc, scale)
+        np.testing.assert_allclose(dist.loc, np.zeros(expected_shape))
+        self.assertTrue(scale.same_as(dist.scale))
+        self.assertEqual(dist.event_shape, (4,))
+
+    def test_sample(self):
+        prng = jax.random.PRNGKey(42)
+        keys = jax.random.split(prng, 2)
+        diag = 0.5 + jax.random.uniform(keys[0], (4,))
+        loc = jax.random.normal(keys[1], (4,))
+        scale = DiagLinear(diag)
+        dist = MultivariateNormalFromBijector(loc, scale)
+        num_samples = 100_000
+        sample_fn = lambda seed: dist.sample(key=seed)
+        samples = eqx.filter_vmap(sample_fn)(jax.random.split(prng, num_samples))
+        self.assertEqual(samples.shape, (num_samples, 4))
+        np.testing.assert_allclose(jnp.mean(samples, axis=0), loc, rtol=0.1)
+        np.testing.assert_allclose(jnp.std(samples, axis=0), diag, rtol=0.1)
+
+    @parameterized.expand(
+        [
+            ("no broadcast", (4,), (4,)),
+        ]
+    )
+    def test_mean_median_mode(self, name, diag_shape, loc_shape):
+        prng = jax.random.PRNGKey(42)
+        diag = jax.random.normal(prng, diag_shape)
+        loc = jax.random.normal(prng, loc_shape)
+        scale = DiagLinear(diag)
+        batch_shape = jnp.broadcast_shapes(diag_shape, loc_shape)[:-1]
+        dist = MultivariateNormalFromBijector(loc, scale)
+        for method in ["mean", "median", "mode"]:
+            with self.subTest(method=method):
+                fn = getattr(dist, method)
+                np.testing.assert_allclose(
+                    fn(), jnp.broadcast_to(loc, batch_shape + loc.shape[-1:])
+                )
+
+    @parameterized.expand(
+        [
+            ("kl distreqx_to_distreqx", "kl_divergence"),
+            ("cross-ent distreqx_to_distreqx", "cross_entropy"),
+        ]
+    )
+    def test_with_two_distributions(self, name, function_string):
+        rng = np.random.default_rng(42)
+        rng1 = np.random.default_rng(42)
+        rng2 = np.random.default_rng(42)
+
+        dist1_kwargs = {
+            "loc": jnp.array(rng.normal(size=(5,)).astype(np.float32)),
+            "scale": DiagLinear(
+                0.1 + jnp.array(rng1.uniform(size=(5,)).astype(np.float32))
+            ),
+        }
+        dist2_kwargs = {
+            "loc": jnp.asarray([-2.4, -1.0, 0.0, 1.2, 6.5]).astype(np.float32),
+            "scale": DiagLinear(
+                0.1 + jnp.array(rng2.uniform(size=(5,)).astype(np.float32))
+            ),
+        }
+
+        dist1 = MultivariateNormalFromBijector(**dist1_kwargs)
+        dist2 = MultivariateNormalFromBijector(**dist2_kwargs)
+
+        if function_string == "kl_divergence":
+            result1 = dist1.kl_divergence(dist2)
+            result2 = dist2.kl_divergence(dist1)
+        elif function_string == "cross_entropy":
+            result1 = dist1.cross_entropy(dist2)
+            result2 = dist2.cross_entropy(dist1)
+        else:
+            raise ValueError(f"Unsupported function string: {function_string}")
+        np.testing.assert_allclose(result1, result2, rtol=1e-3)
+
+    def test_kl_divergence_raises_on_incompatible_distributions(self):
+        dim = 4
+        dist1 = MultivariateNormalFromBijector(
+            loc=jnp.zeros((dim,)),
+            scale=DiagLinear(diag=jnp.ones((dim,))),
+        )
+        dim = 5
+        dist2 = MultivariateNormalFromBijector(
+            loc=jnp.zeros((dim,)),
+            scale=DiagLinear(diag=jnp.ones((dim,))),
+        )
+        with self.assertRaises(TypeError):
+            dist1.kl_divergence(dist2)
diff --git a/tests/scalar_affine_test.py b/tests/scalar_affine_test.py
new file mode 100644
index 0000000..b3b6c0d
--- /dev/null
+++ b/tests/scalar_affine_test.py
@@ -0,0 +1,102 @@
+"""Tests for `scalar_affine.py`."""
+
+from unittest import TestCase
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+
+from distreqx.bijectors import ScalarAffine
+
+
+class ScalarAffineTest(TestCase):
+    def test_properties(self):
+        bij = ScalarAffine(shift=jnp.array(0.0), scale=jnp.array(1.0))
+        self.assertTrue(bij.is_constant_jacobian)
+        self.assertTrue(bij.is_constant_log_det)
+        np.testing.assert_allclose(bij.shift, 0.0)
+        np.testing.assert_allclose(bij.scale, 1.0)
+        np.testing.assert_allclose(bij.log_scale, 0.0)
+
+    def test_raises_if_both_scale_and_log_scale_are_specified(self):
+        with self.assertRaises(ValueError):
+            ScalarAffine(
+                shift=jnp.array(0.0), scale=jnp.array(1.0), log_scale=jnp.array(0.0)
+            )
+
+    def test_shapes_are_correct(self):
+        k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4)
+        x = jax.random.normal(k1, (3, 4, 5))
+        shift = jax.random.normal(k2, (3, 4, 5))
+        scale = jax.random.uniform(k3, (3, 4, 5)) + 0.1
+        log_scale = jax.random.normal(k4, (3, 4, 5))
+        bij_no_scale = ScalarAffine(shift)
+        bij_with_scale = ScalarAffine(shift, scale=scale)
+        bij_with_log_scale = ScalarAffine(shift, log_scale=log_scale)
+        for bij in [bij_no_scale, bij_with_scale, bij_with_log_scale]:
+            # Forward methods.
+            y, logdet = bij.forward_and_log_det(x)
+            self.assertEqual(y.shape, (3, 4, 5))
+            self.assertEqual(logdet.shape, (3, 4, 5))
+            # Inverse methods.
+            x, logdet = bij.inverse_and_log_det(y)
+            self.assertEqual(x.shape, (3, 4, 5))
+            self.assertEqual(logdet.shape, (3, 4, 5))
+
+    def test_forward_methods_are_correct(self):
+        key = jax.random.PRNGKey(42)
+        x = jax.random.normal(key, (2, 3, 4, 5))
+        bij_no_scale = ScalarAffine(shift=jnp.array(3.0))
+        bij_with_scale = ScalarAffine(shift=jnp.array(3.0), scale=jnp.array(1.0))
+        bij_with_log_scale = ScalarAffine(
+            shift=jnp.array(3.0), log_scale=jnp.array(0.0)
+        )
+        for bij in [bij_no_scale, bij_with_scale, bij_with_log_scale]:
+            y, logdet = bij.forward_and_log_det(x)
+            np.testing.assert_allclose(y, x + 3.0, atol=1e-8)
+            np.testing.assert_allclose(logdet, 0.0, atol=1e-8)
+
+    def test_inverse_methods_are_correct(self):
+        k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4)
+        x = jax.random.normal(k1, (2, 3, 4, 5))
+        shift = jax.random.normal(k2, (4, 5))
+        scale = jax.random.uniform(k3, (3, 4, 5)) + 0.1
+        log_scale = jax.random.normal(k4, (3, 4, 5))
+        bij_no_scale = ScalarAffine(shift)
+        bij_with_scale = ScalarAffine(shift, scale=scale)
+        bij_with_log_scale = ScalarAffine(shift, log_scale=log_scale)
+        for bij in [bij_no_scale, bij_with_scale, bij_with_log_scale]:
+            y, logdet_fwd = bij.forward_and_log_det(x)
+            x_rec, logdet_inv = bij.inverse_and_log_det(y)
+            np.testing.assert_allclose(x_rec, x, atol=1e-5)
+            np.testing.assert_allclose(logdet_fwd, -logdet_inv, atol=3e-6)
+
+    def test_composite_methods_are_consistent(self):
+        k1, k2, k3, k4 = jax.random.split(jax.random.PRNGKey(42), 4)
+        bij = ScalarAffine(
+            shift=jax.random.normal(k1, (4, 5)), log_scale=jax.random.normal(k2, (4, 5))
+        )
+        # Forward methods.
+        x = jax.random.normal(k3, (2, 3, 4, 5))
+        y1 = bij.forward(x)
+        logdet1 = bij.forward_log_det_jacobian(x)
+        y2, logdet2 = bij.forward_and_log_det(x)
+        np.testing.assert_allclose(y1, y2, atol=1e-12)
+        np.testing.assert_allclose(logdet1, logdet2, atol=1e-12)
+        # Inverse methods.
+        y = jax.random.normal(k4, (2, 3, 4, 5))
+        x1 = bij.inverse(y)
+        logdet1 = bij.inverse_log_det_jacobian(y)
+        x2, logdet2 = bij.inverse_and_log_det(y)
+        np.testing.assert_allclose(x1, x2, atol=1e-12)
+        np.testing.assert_allclose(logdet1, logdet2, atol=1e-12)
+
+    def test_jittable(self):
+        @jax.jit
+        def f(x, b):
+            return b.forward(x)
+
+        bijector = ScalarAffine(jnp.array(0.0), jnp.array(1.0))
+        x = jnp.zeros(())
+        y = f(x, bijector)
+        self.assertIsInstance(y, jax.Array)
diff --git a/tests/sigmoid_test.py b/tests/sigmoid_test.py
new file mode 100644
index 0000000..4171460
--- /dev/null
+++ b/tests/sigmoid_test.py
@@ -0,0 +1,119 @@
+"""Tests for `sigmoid.py`."""
+
+from unittest import TestCase
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from parameterized import parameterized  # type: ignore
+
+from distreqx.bijectors import Sigmoid, Tanh
+
+
+RTOL = 1e-5
+
+
+class SigmoidTest(TestCase):
+    def setUp(self):
+        self.seed = jax.random.PRNGKey(1234)
+
+    def test_properties(self):
+        bijector = Sigmoid()
+        self.assertFalse(bijector.is_constant_jacobian)
+        self.assertFalse(bijector.is_constant_log_det)
+
+    @parameterized.expand(
+        [("x_shape", (2,)), ("x_shape", (2, 3)), ("x_shape", (2, 3, 4))]
+    )
+    def test_forward_shapes(self, name, x_shape):
+        x = jnp.zeros(shape=x_shape)
+        bijector = Sigmoid()
+        y1 = bijector.forward(x)
+        logdet1 = bijector.forward_log_det_jacobian(x)
+        y2, logdet2 = bijector.forward_and_log_det(x)
+        self.assertEqual(y1.shape, x_shape)
+        self.assertEqual(y2.shape, x_shape)
+        self.assertEqual(logdet1.shape, x_shape)
+        self.assertEqual(logdet2.shape, x_shape)
+
+    @parameterized.expand(
+        [("y_shape", (2,)), ("y_shape", (2, 3)), ("y_shape", (2, 3, 4))]
+    )
+    def test_inverse_shapes(self, name, y_shape):
+        y = jnp.zeros(shape=y_shape)
+        bijector = Sigmoid()
+        x1 = bijector.inverse(y)
+        logdet1 = bijector.inverse_log_det_jacobian(y)
+        x2, logdet2 = bijector.inverse_and_log_det(y)
+        self.assertEqual(x1.shape, y_shape)
+        self.assertEqual(x2.shape, y_shape)
+        self.assertEqual(logdet1.shape, y_shape)
+        self.assertEqual(logdet2.shape, y_shape)
+
+    def test_forward(self):
+        prng = jax.random.PRNGKey(42)
+        x = jax.random.normal(prng, (100,))
+        bijector = Sigmoid()
+        y = bijector.forward(x)
+        np.testing.assert_allclose(y, jax.nn.sigmoid(x), rtol=RTOL)
+
+    def test_forward_log_det_jacobian(self):
+        prng = jax.random.PRNGKey(42)
+        x = jax.random.normal(prng, (100,))
+        bijector = Sigmoid()
+        fwd_logdet = bijector.forward_log_det_jacobian(x)
+        actual = jnp.log(jax.vmap(jax.grad(bijector.forward))(x))
+        np.testing.assert_allclose(fwd_logdet, actual, rtol=1e-3)
+
+    def test_forward_and_log_det(self):
+        prng = jax.random.PRNGKey(42)
+        x = jax.random.normal(prng, (100,))
+        bijector = Sigmoid()
+        y1 = bijector.forward(x)
+        logdet1 = bijector.forward_log_det_jacobian(x)
+        y2, logdet2 = bijector.forward_and_log_det(x)
+        np.testing.assert_allclose(y1, y2, rtol=RTOL)
+        np.testing.assert_allclose(logdet1, logdet2, rtol=RTOL)
+
+    def test_inverse(self):
+        prng = jax.random.PRNGKey(42)
+        x = jax.random.normal(prng, (100,))
+        bijector = Sigmoid()
+        y = bijector.forward(x)
+        x_rec = bijector.inverse(y)
+        np.testing.assert_allclose(x_rec, x, rtol=1e-3)
+
+    def test_inverse_log_det_jacobian(self):
+        prng = jax.random.PRNGKey(42)
+        x = jax.random.normal(prng, (100,))
+        bijector = Sigmoid()
+        y = bijector.forward(x)
+        fwd_logdet = bijector.forward_log_det_jacobian(x)
+        inv_logdet = bijector.inverse_log_det_jacobian(y)
+        np.testing.assert_allclose(inv_logdet, -fwd_logdet, rtol=1e-4)
+    def test_inverse_and_log_det(self):
+        prng = jax.random.PRNGKey(42)
+        y = jax.random.normal(prng, (100,))
+        bijector = Sigmoid()
+        x1 = bijector.inverse(y)
+        logdet1 = bijector.inverse_log_det_jacobian(y)
+        x2, logdet2 = bijector.inverse_and_log_det(y)
+        np.testing.assert_allclose(x1, x2, rtol=RTOL)
+        np.testing.assert_allclose(logdet1, logdet2, rtol=RTOL)
+
+    def test_jittable(self):
+        @jax.jit
+        def f(x, b):
+            return b.forward(x)
+
+        bijector = Sigmoid()
+        x = jnp.zeros(())
+        y = f(x, bijector)
+        self.assertIsInstance(y, jax.Array)
+
+    def test_same_as(self):
+        bijector = Sigmoid()
+        self.assertTrue(bijector.same_as(bijector))
+        self.assertTrue(bijector.same_as(Sigmoid()))
+        self.assertFalse(bijector.same_as(Tanh()))
diff --git a/tests/tanh_test.py b/tests/tanh_test.py
new file mode 100644
index 0000000..136d750
--- /dev/null
+++ b/tests/tanh_test.py
@@ -0,0 +1,129 @@
+"""Tests for `tanh.py`."""
+
+from unittest import TestCase
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from parameterized import parameterized  # type: ignore
+
+from distreqx.bijectors import Sigmoid, Tanh
+
+
+RTOL = 1e-5
+
+
+class TanhTest(TestCase):
+    def setUp(self):
+        self.seed = jax.random.PRNGKey(1234)
+
+    def test_properties(self):
+        bijector = Tanh()
+        self.assertFalse(bijector.is_constant_jacobian)
+        self.assertFalse(bijector.is_constant_log_det)
+
+    @parameterized.expand(
+        [("x_shape", (2,)), ("x_shape", (2, 3)), ("x_shape", (2, 3, 4))]
+    )
+    def test_forward_shapes(self, name, x_shape):
+        x = jnp.zeros(shape=x_shape)
+        bijector = Tanh()
+        y1 = bijector.forward(x)
+        logdet1 = bijector.forward_log_det_jacobian(x)
+        y2, logdet2 = bijector.forward_and_log_det(x)
+        self.assertEqual(y1.shape, x_shape)
+        self.assertEqual(y2.shape, x_shape)
+        self.assertEqual(logdet1.shape, x_shape)
+        self.assertEqual(logdet2.shape, x_shape)
+
+    @parameterized.expand(
+        [("y_shape", (2,)), ("y_shape", (2, 3)), ("y_shape", (2, 3, 4))]
+    )
+    def test_inverse_shapes(self, name, y_shape):
+        y = jnp.zeros(shape=y_shape)
+        bijector = Tanh()
+        x1 = bijector.inverse(y)
+        logdet1 = bijector.inverse_log_det_jacobian(y)
+        x2, logdet2 = bijector.inverse_and_log_det(y)
+        self.assertEqual(x1.shape, y_shape)
+        self.assertEqual(x2.shape, y_shape)
+        self.assertEqual(logdet1.shape, y_shape)
+        self.assertEqual(logdet2.shape, y_shape)
+
+    def test_forward(self):
+        x = jax.random.normal(self.seed, (100,))
+        bijector = Tanh()
+        y = bijector.forward(x)
+        np.testing.assert_allclose(y, jnp.tanh(x), rtol=RTOL)
+
+    def test_forward_log_det_jacobian(self):
+        x = jax.random.normal(self.seed, (100,))
+        bijector = Tanh()
+        fwd_logdet = bijector.forward_log_det_jacobian(x)
+        actual = jnp.log(jax.vmap(jax.grad(bijector.forward))(x))
+        np.testing.assert_allclose(fwd_logdet, actual, rtol=1e-2)
+
+    def test_forward_and_log_det(self):
+        x = jax.random.normal(self.seed, (100,))
+        bijector = Tanh()
+        y1 = bijector.forward(x)
+        logdet1 = bijector.forward_log_det_jacobian(x)
+        y2, logdet2 = bijector.forward_and_log_det(x)
+        np.testing.assert_allclose(y1, y2, rtol=RTOL)
+        np.testing.assert_allclose(logdet1, logdet2, rtol=RTOL)
+
+    def test_inverse(self):
+        x = jax.random.normal(self.seed, (100,))
+        bijector = Tanh()
+        y = bijector.forward(x)
+        x_rec = bijector.inverse(y)
+        np.testing.assert_allclose(x_rec, x, rtol=1e-3)
+
+    def test_inverse_log_det_jacobian(self):
+        x = jax.random.normal(self.seed, (100,))
+        bijector = Tanh()
+        y = bijector.forward(x)
+        fwd_logdet = bijector.forward_log_det_jacobian(x)
+        inv_logdet = bijector.inverse_log_det_jacobian(y)
+        np.testing.assert_allclose(inv_logdet, -fwd_logdet, rtol=1e-3)
+
+    def test_inverse_and_log_det(self):
+        y = jax.random.normal(self.seed, (100,))
+        bijector = Tanh()
+        x1 = bijector.inverse(y)
+        logdet1 = bijector.inverse_log_det_jacobian(y)
+        x2, logdet2 = bijector.inverse_and_log_det(y)
+        np.testing.assert_allclose(x1, x2, rtol=RTOL)
+        np.testing.assert_allclose(logdet1, logdet2, rtol=RTOL)
+
+    @parameterized.expand(
+        [
+            ("int16", jnp.array([0, 0], dtype=jnp.int16)),
+            ("int32", jnp.array([0, 0], dtype=jnp.int32)),
+        ]
+    )
+    def test_integer_inputs(self, name, inputs):
+        bijector = Tanh()
+        output, log_det = bijector.forward_and_log_det(inputs)
+
+        expected_out = jnp.tanh(inputs).astype(jnp.float32)
+        expected_log_det = jnp.zeros_like(inputs, dtype=jnp.float32)
+
+        np.testing.assert_array_equal(output, expected_out)
+        np.testing.assert_array_equal(log_det, expected_log_det)
+
+    def test_jittable(self):
+        @jax.jit
+        def f(x, b):
+            return b.forward(x)
+
+        bijector = Tanh()
+        x = jnp.zeros(())
+        y = f(x, bijector)
+        self.assertIsInstance(y, jax.Array)
+
+    def test_same_as(self):
+        bijector = Tanh()
+        self.assertTrue(bijector.same_as(bijector))
+        self.assertTrue(bijector.same_as(Tanh()))
+        self.assertFalse(bijector.same_as(Sigmoid()))
diff --git a/tests/transformed_test.py b/tests/transformed_test.py
new file mode 100644
index 0000000..ddff006
--- /dev/null
+++ b/tests/transformed_test.py
@@ -0,0 +1,98 @@
+"""Tests for `transformed.py`."""
+
+from unittest import TestCase
+
+import jax
+import jax.numpy as jnp
+import numpy as np
+from parameterized import parameterized  # type: ignore
+
+from distreqx.bijectors import ScalarAffine, Sigmoid
+from distreqx.distributions import Normal, Transformed
+
+
+class TransformedTest(TestCase):
+    def setUp(self):
+        self.seed = jax.random.PRNGKey(1234)
+
+    @parameterized.expand(
+        [
+            ("int16", jnp.array([0, 0], dtype=np.int16), Normal),
+            ("int32", jnp.array([0, 0], dtype=np.int32), Normal),
+            ("int64", jnp.array([0, 0], dtype=np.int64), Normal),
+        ]
+    )
+    def test_integer_inputs(self, name, inputs, base_dist):
+        base = base_dist(
+            jnp.zeros_like(inputs, dtype=jnp.float32),
+            jnp.ones_like(inputs, dtype=jnp.float32),
+        )
+        bijector = ScalarAffine(shift=jnp.array(0.0))
+        dist = Transformed(base, bijector)
+
+        log_prob = dist.log_prob(inputs)
+
+        standard_normal_log_prob_of_zero = jnp.array(-0.9189385)
+        expected_log_prob = jnp.full_like(
+            inputs, standard_normal_log_prob_of_zero, dtype=jnp.float32
+        )
+
+        np.testing.assert_array_equal(log_prob, expected_log_prob)
+
+    @parameterized.expand(
+        [
+            ("kl distreqx_to_distreqx", "distreqx_to_distreqx"),
+        ]
+    )
+    def test_kl_divergence(self, name, mode_string):
+        base_dist1 = Normal(
+            loc=jnp.array([0.1, 0.5, 0.9]), scale=jnp.array([0.1, 1.1, 2.5])
+        )
+        base_dist2 = Normal(
+            loc=jnp.array([-0.1, -0.5, 0.9]), scale=jnp.array([0.1, -1.1, 2.5])
+        )
+        bij_distreqx1 = ScalarAffine(shift=jnp.array(0.0))
+        bij_distreqx2 = ScalarAffine(shift=jnp.array(0.0))
+        distreqx_dist1 = Transformed(base_dist1, bij_distreqx1)
+        distreqx_dist2 = Transformed(base_dist2, bij_distreqx2)
+
+        expected_result_fwd = base_dist1.kl_divergence(base_dist2)
+        expected_result_inv = base_dist2.kl_divergence(base_dist1)
+
+        if mode_string == "distreqx_to_distreqx":
+            result_fwd = distreqx_dist1.kl_divergence(distreqx_dist2)
+            result_inv = distreqx_dist2.kl_divergence(distreqx_dist1)
+        else:
+            raise ValueError(f"Unsupported mode string: {mode_string}")
+
+        np.testing.assert_allclose(result_fwd, expected_result_fwd, rtol=1e-2)
+        np.testing.assert_allclose(result_inv, expected_result_inv, rtol=1e-2)
+
+    def test_kl_divergence_on_same_instance_of_distreqx_bijector(self):
+        base_dist1 = Normal(
+            loc=jnp.array([0.1, 0.5, 0.9]), scale=jnp.array([0.1, 1.1, 2.5])
+        )
+        base_dist2 = Normal(
+            loc=jnp.array([-0.1, -0.5, 0.9]), scale=jnp.array([0.1, -1.1, 2.5])
+        )
+        bij_distreqx = Sigmoid()
+        distreqx_dist1 = Transformed(base_dist1, bij_distreqx)
+        distreqx_dist2 = Transformed(base_dist2, bij_distreqx)
+        expected_result_fwd = base_dist1.kl_divergence(base_dist2)
+        expected_result_inv = base_dist2.kl_divergence(base_dist1)
+        result_fwd = distreqx_dist1.kl_divergence(distreqx_dist2)
+        result_inv = distreqx_dist2.kl_divergence(distreqx_dist1)
+        np.testing.assert_allclose(result_fwd, expected_result_fwd, rtol=1e-2)
+        np.testing.assert_allclose(result_inv, expected_result_inv, rtol=1e-2)
+
+    def test_jittable(self):
+        @jax.jit
+        def f(x, d):
+            return d.log_prob(x)
+
+        base = Normal(jnp.array(0.0), jnp.array(1.0))
+        bijector = ScalarAffine(jnp.array(0.0), jnp.array(1.0))
+        dist = Transformed(base, bijector)
+        x = jnp.zeros(())
+        y = f(x, dist)
+        self.assertIsInstance(y, jax.Array)
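
The NaN failure mode described in the `Sigmoid` and `Tanh` docstrings in this PR can be reproduced with a sketch like the following (assuming `Transformed` provides `sample_and_log_prob`, as those docstrings indicate):

```python
import jax
import jax.numpy as jnp

from distreqx.bijectors import Tanh
from distreqx.distributions import Normal, Transformed

# A wide base distribution pushes samples into tanh's saturated region,
# where y rounds to exactly +/-1 in float32 and the inverse is infinite.
dist = Transformed(Normal(jnp.array(0.0), jnp.array(10.0)), Tanh())

print(dist.log_prob(jnp.array(1.0)))  # non-finite: arctanh(1) = inf

# sample_and_log_prob computes the density from the base sample directly,
# so it stays finite even when the sample saturates.
key = jax.random.PRNGKey(0)
y, log_prob = dist.sample_and_log_prob(key=key)
print(jnp.isfinite(log_prob))
```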