Commit

Add tests (#7)
* testing files

* complete tests

* add to contrib

* typing
lockwo authored May 31, 2024
1 parent be48de1 commit c2189c6
Showing 24 changed files with 1,105 additions and 18 deletions.
70 changes: 70 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,70 @@
# Contributing

Contributions (pull requests) are very welcome! Here's how to get started.

---

**Getting started**

First fork the library on GitHub.

Then clone and install the library in development mode:

```bash
git clone https://github.com/your-username-here/distreqx.git
cd distreqx
pip install -e .
```

Then install the pre-commit hook:

```bash
pip install pre-commit
pre-commit install
```

These hooks use Black and isort to format the code, and flake8 to lint it.

---

**If you're making changes to the code:**

Now make your changes. Make sure to include additional tests if necessary.

If you include a new feature, there are three required classes of tests:
- Correctness: tests against analytic or known solutions that ensure the computation is correct
- Compatibility: tests that check the feature is `jit`-, `vmap`-, and `grad`-compatible and behaves as expected under those transformations (see the sketch below)
- Edge cases: tests that make sure edge cases (e.g. large/small numerics, unexpected dtypes) are either dealt with or fail in an expected manner
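
For example, a compatibility test might look roughly like the following sketch (illustrative only: it assumes a distrax-style `Normal(loc, scale)` constructor with a `log_prob` method and pytest-style test functions; adapt it to the feature you are adding):

```python
import jax
import jax.numpy as jnp

from distreqx.distributions import Normal


def test_normal_jit_vmap_grad():
    # Sketch of a "compatibility" test: log_prob should work under jit, vmap, and grad.
    loc, scale = jnp.zeros(3), jnp.ones(3)

    def logp(x):
        return Normal(loc, scale).log_prob(x).sum()

    x = jnp.ones(3)
    assert jnp.isfinite(jax.jit(logp)(x))                  # jit-able
    assert jax.vmap(logp)(jnp.ones((5, 3))).shape == (5,)  # vmap-able
    assert jnp.all(jnp.isfinite(jax.grad(logp)(x)))        # grad-able
```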

Next verify the tests all pass:

```bash
pip install pytest
pytest tests
```

Then push your changes back to your fork of the repository:

```bash
git push
```

Finally, open a pull request on GitHub!

---

**If you're making changes to the documentation:**

Make your changes. You can then build the documentation by doing

```bash
pip install -r docs/requirements.txt
mkdocs serve
```
Then hit `Control-C`, and run:
```bash
mkdocs serve
```
(So you run `mkdocs serve` twice.)

You can then see your local copy of the documentation by navigating to `localhost:8000` in a web browser.
2 changes: 0 additions & 2 deletions README.md
@@ -53,8 +53,6 @@ If you found this library useful in academic research, please cite:
[Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
[diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.

**Awesome JAX**
[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.

## Original distrax copyright
1 change: 1 addition & 0 deletions distreqx/bijectors/__init__.py
@@ -5,4 +5,5 @@
from .diag_linear import DiagLinear as DiagLinear
from .scalar_affine import ScalarAffine as ScalarAffine
from .shift import Shift as Shift
from .sigmoid import Sigmoid as Sigmoid
from .tanh import Tanh as Tanh
71 changes: 71 additions & 0 deletions distreqx/bijectors/sigmoid.py
@@ -0,0 +1,71 @@
"""Sigmoid bijector."""

from typing import Tuple

import jax
import jax.numpy as jnp
from jaxtyping import Array

from ._bijector import AbstractBijector


class Sigmoid(AbstractBijector):
    """A bijector that computes the logistic sigmoid.

    The log-determinant implementation in this bijector is more numerically stable
    than relying on the automatic differentiation approach used by Lambda, so this
    bijector should be preferred over Lambda(jax.nn.sigmoid) where possible.

    Note that the underlying implementation of `jax.nn.sigmoid` used by the
    `forward` function of this bijector does not support inputs of integer type.
    To invoke the forward function of this bijector on an argument of integer
    type, it should first be cast explicitly to a floating point type.

    When the absolute value of the input is large, `Sigmoid` becomes close to a
    constant, so that it is not possible to recover the input `x` from the output
    `y` within machine precision. In cases where it is needed to compute both the
    forward mapping and the backward mapping one after the other to recover the
    original input `x`, it is the user's responsibility to simplify the operation
    to avoid numerical issues. One example of such a case is to use the bijector
    within a `Transformed` distribution and to obtain the log-probability of
    samples obtained from the distribution's `sample` method. For values of the
    samples for which it is not possible to apply the inverse bijector accurately,
    `log_prob` returns NaN. This can be avoided by using `sample_and_log_prob`
    instead of `sample` followed by `log_prob`.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward_log_det_jacobian(self, x: Array) -> Array:
        """Computes log|det J(f)(x)|."""
        return -_more_stable_softplus(-x) - _more_stable_softplus(x)

    def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
        """Computes y = f(x) and log|det J(f)(x)|."""
        return _more_stable_sigmoid(x), self.forward_log_det_jacobian(x)

    def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
        """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
        x = jnp.log(y) - jnp.log1p(-y)
        return x, -self.forward_log_det_jacobian(x)

    def same_as(self, other: AbstractBijector) -> bool:
        """Returns True if this bijector is guaranteed to be the same as `other`."""
        return type(other) is Sigmoid


def _more_stable_sigmoid(x: Array) -> Array:
    """Where extremely negatively saturated, approximate sigmoid with exp(x)."""
    ret = jnp.where(x < -9, jnp.exp(x), jax.nn.sigmoid(x))
    if not isinstance(ret, Array):
        raise TypeError("ret is not an Array")
    return ret


def _more_stable_softplus(x: Array) -> Array:
    """Where extremely saturated, approximate softplus with log1p(exp(x))."""
    ret = jnp.where(x < -9, jnp.log1p(jnp.exp(x)), jax.nn.softplus(x))
    if not isinstance(ret, Array):
        raise TypeError("ret is not an Array")
    return ret
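
The identity behind `forward_log_det_jacobian` is that log sigmoid'(x) = log sigmoid(x) + log(1 - sigmoid(x)) = -softplus(-x) - softplus(x). As a quick illustration of the saturation caveat in the docstring, here is a minimal sketch (assuming `Sigmoid` is re-exported from `distreqx.bijectors`, as the updated `__init__.py` above suggests; the printed values are indicative only):

```python
import jax.numpy as jnp

from distreqx.bijectors import Sigmoid

bij = Sigmoid()
x = jnp.array([-20.0, 0.0, 20.0])

y, fldj = bij.forward_and_log_det(x)
x_back, ildj = bij.inverse_and_log_det(y)

# The 0.0 entry round-trips accurately and fldj + ildj is ~0 there; the 20.0
# entry saturates to y ~= 1.0 in float32 (JAX's default), so x cannot be
# recovered from y.
print(x_back, fldj + ildj)
```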
3 changes: 1 addition & 2 deletions distreqx/bijectors/tanh.py
@@ -29,8 +29,7 @@ class Tanh(AbstractBijector):
instead of `sample` followed by `log_prob`.
"""

def __init__(self) -> None:
"""Initialize the TanH bijector."""
def __init__(self):
super().__init__()

def forward_log_det_jacobian(self, x: Array) -> Array:
2 changes: 1 addition & 1 deletion distreqx/distributions/__init__.py
@@ -1,11 +1,11 @@
from ._distribution import (
AbstractDistribution as AbstractDistribution,
)
from ._transformed import Transformed as Transformed
from .bernoulli import Bernoulli as Bernoulli
from .independent import Independent as Independent
from .mvn_diag import MultivariateNormalDiag as MultivariateNormalDiag
from .mvn_from_bijector import (
MultivariateNormalFromBijector as MultivariateNormalFromBijector,
)
from .normal import Normal as Normal
from .transformed import Transformed as Transformed
14 changes: 6 additions & 8 deletions distreqx/distributions/mvn_diag.py
@@ -54,14 +54,12 @@ def __init__(self, loc: Optional[Array] = None, scale_diag: Optional[Array] = No
elif loc is None and scale_diag is not None:
loc = jnp.zeros(scale_diag.shape[-1], scale_diag.dtype)

assert loc is not None
assert scale_diag is not None

broadcasted_shapes = jnp.broadcast_shapes(loc.shape, scale_diag.shape)
loc = jnp.expand_dims(loc, axis=list(range(len(broadcasted_shapes) - loc.ndim)))
scale_diag = jnp.expand_dims(
scale_diag, axis=list(range(len(broadcasted_shapes) - scale_diag.ndim))
)
if loc is None:
raise ValueError("loc is None")
if scale_diag is None:
raise ValueError("scale_diag is None")
if scale_diag.ndim != 1:
raise ValueError("scale_diag must be a vector!")

scale = DiagLinear(scale_diag)
super().__init__(loc=loc, scale=scale)
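
For context, a minimal sketch of what the new validation means for callers (hypothetical usage, assuming the constructor signature shown in the hunk header above; the error message is taken from the diff):

```python
import jax.numpy as jnp

from distreqx.distributions import MultivariateNormalDiag

# A 1D scale_diag is accepted; per the branch above, loc defaults to zeros.
dist = MultivariateNormalDiag(scale_diag=jnp.ones(3))

# A non-vector scale_diag now raises a ValueError instead of being broadcast.
try:
    MultivariateNormalDiag(loc=jnp.zeros(3), scale_diag=jnp.ones((2, 3)))
except ValueError as err:
    print(err)  # scale_diag must be a vector!
```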
2 changes: 1 addition & 1 deletion distreqx/distributions/mvn_from_bijector.py
@@ -8,9 +8,9 @@
from jaxtyping import Array

from ..bijectors import AbstractLinearBijector, Block, Chain, DiagLinear, Shift
from ._transformed import Transformed
from .independent import Independent
from .normal import Normal
from .transformed import Transformed


def _check_input_parameters_are_valid(
File renamed without changes.
6 changes: 6 additions & 0 deletions docs/api/bijectors/sigmoid.md
@@ -0,0 +1,6 @@
# Sigmoid Bijector

::: distreqx.bijectors.sigmoid.Sigmoid
selection:
members: None
---
3 changes: 1 addition & 2 deletions docs/api/bijectors/tanh.md
@@ -2,6 +2,5 @@

::: distreqx.bijectors.tanh.Tanh
selection:
members:
- __init__
members: None
---
@@ -1,6 +1,6 @@
# Base Transformed

::: distreqx.distributions._transformed.Transformed
::: distreqx.distributions.transformed.Transformed
selection:
members:
- __init__
3 changes: 2 additions & 1 deletion mkdocs.yml
@@ -95,9 +95,9 @@ nav:
- API:
- Distributions:
- 'api/distributions/_distribution.md'
- 'api/distributions/_transformed.md'
- 'api/distributions/independent.md'
- 'api/distributions/bernoulli.md'
- 'api/distributions/transformed.md'
- Gaussians:
- 'api/distributions/normal.md'
- 'api/distributions/mvn_diag.md'
@@ -110,6 +110,7 @@
- 'api/bijectors/diag_linear.md'
- 'api/bijectors/scalar_affine.md'
- 'api/bijectors/shift.md'
- 'api/bijectors/sigmoid.md'
- 'api/bijectors/tanh.md'
- Utilities:
- 'api/utils/math.md'
31 changes: 31 additions & 0 deletions tests/abstractlinear_test.py
@@ -0,0 +1,31 @@
"""Tests for `linear.py`."""

from unittest import TestCase

from parameterized import parameterized # type: ignore

from distreqx.bijectors import AbstractLinearBijector


class MockLinear(AbstractLinearBijector):
    def __init__(self, dims):
        super().__init__(dims)

    def forward_and_log_det(self, x):
        raise Exception


class LinearTest(TestCase):
    @parameterized.expand(
        [
            ("1", 1),
            ("10", 10),
        ]
    )
    def test_properties(self, name, event_dims):
        bij = MockLinear(event_dims)
        self.assertTrue(bij.is_constant_jacobian)
        self.assertTrue(bij.is_constant_log_det)
        self.assertEqual(bij.event_dims, event_dims)
        with self.assertRaises(NotImplementedError):
            bij.matrix
90 changes: 90 additions & 0 deletions tests/block_test.py
@@ -0,0 +1,90 @@
"""Tests for `block.py`."""

from unittest import TestCase

import jax
import jax.numpy as jnp
import numpy as np
from parameterized import parameterized # type: ignore

from distreqx.bijectors import AbstractBijector, Block, ScalarAffine, Tanh


RTOL = 1e-6
seed = jax.random.PRNGKey(1234)


class BlockTest(TestCase):
    def test_properties(self):
        bijct = Tanh()
        block = Block(bijct, 1)
        self.assertEqual(block.ndims, 1)
        self.assertIsInstance(block.bijector, AbstractBijector)

    def test_invalid_properties(self):
        bijct = Tanh()
        with self.assertRaises(ValueError):
            Block(bijct, -1)

    @parameterized.expand(
        [
            ("dx_tanh_0", Tanh, 0),
            ("dx_tanh_1", Tanh, 1),
            ("dx_tanh_2", Tanh, 2),
        ]
    )
    def test_forward_inverse_work_as_expected(self, name, bijector_fn, ndims):
        bijct = bijector_fn()
        x = jax.random.normal(seed, [2, 3])
        block = Block(bijct, ndims)
        np.testing.assert_array_equal(bijct.forward(x), block.forward(x))
        np.testing.assert_array_equal(bijct.inverse(x), block.inverse(x))
        np.testing.assert_allclose(
            bijct.forward_and_log_det(x)[0], block.forward_and_log_det(x)[0], atol=2e-7
        )
        np.testing.assert_array_equal(
            bijct.inverse_and_log_det(x)[0], block.inverse_and_log_det(x)[0]
        )

    @parameterized.expand(
        [
            ("dx_tanh_0", Tanh, 0),
            ("dx_tanh_1", Tanh, 1),
            ("dx_tanh_2", Tanh, 2),
        ]
    )
    def test_log_det_jacobian_works_as_expected(self, name, bijector_fn, ndims):
        bijct = bijector_fn()
        x = jax.random.normal(seed, [2, 3])
        block = Block(bijct, ndims)
        axes = tuple(range(-ndims, 0))
        np.testing.assert_allclose(
            bijct.forward_log_det_jacobian(x).sum(axes),
            block.forward_log_det_jacobian(x),
            rtol=RTOL,
        )
        np.testing.assert_allclose(
            bijct.inverse_log_det_jacobian(x).sum(axes),
            block.inverse_log_det_jacobian(x),
            rtol=RTOL,
        )
        np.testing.assert_allclose(
            bijct.forward_and_log_det(x)[1].sum(axes),
            block.forward_and_log_det(x)[1],
            rtol=RTOL,
        )
        np.testing.assert_allclose(
            bijct.inverse_and_log_det(x)[1].sum(axes),
            block.inverse_and_log_det(x)[1],
            rtol=RTOL,
        )

    def test_jittable(self):
        @jax.jit
        def f(x, b):
            return b.forward(x)

        bijector = Block(ScalarAffine(jnp.array(0)), 1)
        x = jnp.zeros((2, 3))
        y = f(x, bijector)
        self.assertIsInstance(y, jax.Array)