-
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* testing files * complete tests * add to contrib * typing
- Loading branch information
Showing
24 changed files
with
1,105 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Contributing | ||
|
||
Contributions (pull requests) are very welcome! Here's how to get started. | ||
|
||
--- | ||
|
||
**Getting started** | ||
|
||
First fork the library on GitHub. | ||
|
||
Then clone and install the library in development mode: | ||
|
||
```bash | ||
git clone https://github.com/your-username-here/distreqx.git | ||
cd distreqx | ||
pip install -e . | ||
``` | ||
|
||
Then install the pre-commit hook: | ||
|
||
```bash | ||
pip install pre-commit | ||
pre-commit install | ||
``` | ||
|
||
These hooks use Black and isort to format the code, and flake8 to lint it. | ||
|
||
--- | ||
|
||
**If you're making changes to the code:** | ||
|
||
Now make your changes. Make sure to include additional tests if necessary. | ||
|
||
If you include a new features, there are 3 required classes of tests: | ||
- Correctness: tests the are against analytic or known solutions that ensure the computation is correct | ||
- Compatibility: tests that check for `jit`, `vmap`, and `grad`-ability of the feature to make sure they behave as expected | ||
- Edge cases: tests that make sure edge cases (e.g. large/small numerics, unexpected dtypes) are either dealt with or fail in an expected manner | ||
|
||
Next verify the tests all pass: | ||
|
||
```bash | ||
pip install pytest | ||
pytest tests | ||
``` | ||
|
||
Then push your changes back to your fork of the repository: | ||
|
||
```bash | ||
git push | ||
``` | ||
|
||
Finally, open a pull request on GitHub! | ||
|
||
--- | ||
|
||
**If you're making changes to the documentation:** | ||
|
||
Make your changes. You can then build the documentation by doing | ||
|
||
```bash | ||
pip install -r docs/requirements.txt | ||
mkdocs serve | ||
``` | ||
Then doing `Control-C`, and running: | ||
``` | ||
mkdocs serve | ||
``` | ||
(So you run `mkdocs serve` twice.) | ||
|
||
You can then see your local copy of the documentation by navigating to `localhost:8000` in a web browser. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
"""Sigmoid bijector.""" | ||
|
||
from typing import Tuple | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jaxtyping import Array | ||
|
||
from ._bijector import AbstractBijector | ||
|
||
|
||
class Sigmoid(AbstractBijector): | ||
"""A bijector that computes the logistic sigmoid. | ||
The log-determinant implementation in this bijector is more numerically stable | ||
than relying on the automatic differentiation approach used by Lambda, so this | ||
bijector should be preferred over Lambda(jax.nn.sigmoid) where possible. | ||
Note that the underlying implementation of `jax.nn.sigmoid` used by the | ||
`forward` function of this bijector does not support inputs of integer type. | ||
To invoke the forward function of this bijector on an argument of integer | ||
type, it should first be cast explicitly to a floating point type. | ||
When the absolute value of the input is large, `Sigmoid` becomes close to a | ||
constant, so that it is not possible to recover the input `x` from the output | ||
`y` within machine precision. In cases where it is needed to compute both the | ||
forward mapping and the backward mapping one after the other to recover the | ||
original input `x`, it is the user's responsibility to simplify the operation | ||
to avoid numerical issues. One example of such case is to use the bijector | ||
within a `Transformed` distribution and to obtain the log-probability of | ||
samples obtained from the distribution's `sample` method. For values of the | ||
samples for which it is not possible to apply the inverse bijector accurately, | ||
`log_prob` returns NaN. This can be avoided by using `sample_and_log_prob` | ||
instead of `sample` followed by `log_prob`. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
def forward_log_det_jacobian(self, x: Array) -> Array: | ||
"""Computes log|det J(f)(x)|.""" | ||
return -_more_stable_softplus(-x) - _more_stable_softplus(x) | ||
|
||
def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: | ||
"""Computes y = f(x) and log|det J(f)(x)|.""" | ||
return _more_stable_sigmoid(x), self.forward_log_det_jacobian(x) | ||
|
||
def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: | ||
"""Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" | ||
x = jnp.log(y) - jnp.log1p(-y) | ||
return x, -self.forward_log_det_jacobian(x) | ||
|
||
def same_as(self, other: AbstractBijector) -> bool: | ||
"""Returns True if this bijector is guaranteed to be the same as `other`.""" | ||
return type(other) is Sigmoid | ||
|
||
|
||
def _more_stable_sigmoid(x: Array) -> Array: | ||
"""Where extremely negatively saturated, approximate sigmoid with exp(x).""" | ||
ret = jnp.where(x < -9, jnp.exp(x), jax.nn.sigmoid(x)) | ||
if not isinstance(ret, Array): | ||
raise TypeError("ret is not an Array") | ||
return ret | ||
|
||
|
||
def _more_stable_softplus(x: Array) -> Array: | ||
"""Where extremely saturated, approximate softplus with log1p(exp(x)).""" | ||
ret = jnp.where(x < -9, jnp.log1p(jnp.exp(x)), jax.nn.softplus(x)) | ||
if not isinstance(ret, Array): | ||
raise TypeError("ret is not an Array") | ||
return ret |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,11 +1,11 @@ | ||
from ._distribution import ( | ||
AbstractDistribution as AbstractDistribution, | ||
) | ||
from ._transformed import Transformed as Transformed | ||
from .bernoulli import Bernoulli as Bernoulli | ||
from .independent import Independent as Independent | ||
from .mvn_diag import MultivariateNormalDiag as MultivariateNormalDiag | ||
from .mvn_from_bijector import ( | ||
MultivariateNormalFromBijector as MultivariateNormalFromBijector, | ||
) | ||
from .normal import Normal as Normal | ||
from .transformed import Transformed as Transformed |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Sigmoid Bijector | ||
|
||
::: distreqx.bijectors.sigmoid.Sigmoid | ||
selection: | ||
members: None | ||
--- |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,5 @@ | |
|
||
::: distreqx.bijectors.tanh.Tanh | ||
selection: | ||
members: | ||
- __init__ | ||
members: None | ||
--- |
2 changes: 1 addition & 1 deletion
2
docs/api/distributions/_transformed.md → docs/api/distributions/transformed.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
"""Tests for `linear.py`.""" | ||
|
||
from unittest import TestCase | ||
|
||
from parameterized import parameterized # type: ignore | ||
|
||
from distreqx.bijectors import AbstractLinearBijector | ||
|
||
|
||
class MockLinear(AbstractLinearBijector): | ||
def __init__(self, dims): | ||
super().__init__(dims) | ||
|
||
def forward_and_log_det(self, x): | ||
raise Exception | ||
|
||
|
||
class LinearTest(TestCase): | ||
@parameterized.expand( | ||
[ | ||
("1", 1), | ||
("10", 10), | ||
] | ||
) | ||
def test_properties(self, name, event_dims): | ||
bij = MockLinear(event_dims) | ||
self.assertTrue(bij.is_constant_jacobian) | ||
self.assertTrue(bij.is_constant_log_det) | ||
self.assertEqual(bij.event_dims, event_dims) | ||
with self.assertRaises(NotImplementedError): | ||
bij.matrix |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
"""Tests for `block.py`.""" | ||
|
||
from unittest import TestCase | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import numpy as np | ||
from parameterized import parameterized # type: ignore | ||
|
||
from distreqx.bijectors import AbstractBijector, Block, ScalarAffine, Tanh | ||
|
||
|
||
RTOL = 1e-6 | ||
seed = jax.random.PRNGKey(1234) | ||
|
||
|
||
class BlockTest(TestCase): | ||
def test_properties(self): | ||
bijct = Tanh() | ||
block = Block(bijct, 1) | ||
self.assertEqual(block.ndims, 1) | ||
self.assertIsInstance(block.bijector, AbstractBijector) | ||
|
||
def test_invalid_properties(self): | ||
bijct = Tanh() | ||
with self.assertRaises(ValueError): | ||
Block(bijct, -1) | ||
|
||
@parameterized.expand( | ||
[ | ||
("dx_tanh_0", Tanh, 0), | ||
("dx_tanh_1", Tanh, 1), | ||
("dx_tanh_2", Tanh, 2), | ||
] | ||
) | ||
def test_forward_inverse_work_as_expected(self, name, bijector_fn, ndims): | ||
bijct = bijector_fn() | ||
x = jax.random.normal(seed, [2, 3]) | ||
block = Block(bijct, ndims) | ||
np.testing.assert_array_equal(bijct.forward(x), block.forward(x)) | ||
np.testing.assert_array_equal(bijct.inverse(x), block.inverse(x)) | ||
np.testing.assert_allclose( | ||
bijct.forward_and_log_det(x)[0], block.forward_and_log_det(x)[0], atol=2e-7 | ||
) | ||
np.testing.assert_array_equal( | ||
bijct.inverse_and_log_det(x)[0], block.inverse_and_log_det(x)[0] | ||
) | ||
|
||
@parameterized.expand( | ||
[ | ||
("dx_tanh_0", Tanh, 0), | ||
("dx_tanh_1", Tanh, 1), | ||
("dx_tanh_2", Tanh, 2), | ||
] | ||
) | ||
def test_log_det_jacobian_works_as_expected(self, name, bijector_fn, ndims): | ||
bijct = bijector_fn() | ||
x = jax.random.normal(seed, [2, 3]) | ||
block = Block(bijct, ndims) | ||
axes = tuple(range(-ndims, 0)) | ||
np.testing.assert_allclose( | ||
bijct.forward_log_det_jacobian(x).sum(axes), | ||
block.forward_log_det_jacobian(x), | ||
rtol=RTOL, | ||
) | ||
np.testing.assert_allclose( | ||
bijct.inverse_log_det_jacobian(x).sum(axes), | ||
block.inverse_log_det_jacobian(x), | ||
rtol=RTOL, | ||
) | ||
np.testing.assert_allclose( | ||
bijct.forward_and_log_det(x)[1].sum(axes), | ||
block.forward_and_log_det(x)[1], | ||
rtol=RTOL, | ||
) | ||
np.testing.assert_allclose( | ||
bijct.inverse_and_log_det(x)[1].sum(axes), | ||
block.inverse_and_log_det(x)[1], | ||
rtol=RTOL, | ||
) | ||
|
||
def test_jittable(self): | ||
@jax.jit | ||
def f(x, b): | ||
return b.forward(x) | ||
|
||
bijector = Block(ScalarAffine(jnp.array(0)), 1) | ||
x = jnp.zeros((2, 3)) | ||
y = f(x, bijector) | ||
self.assertIsInstance(y, jax.Array) |
Oops, something went wrong.