Skip to content

Commit

Permalink
Add One hot categorical (#36)
Browse files Browse the repository at this point in the history
* add one_hot_categorical

* one_hot_categorical_test

* format tests

* format dist
  • Loading branch information
lockwo authored Dec 8, 2024
1 parent 1d3e2e6 commit 5759f4a
Show file tree
Hide file tree
Showing 6 changed files with 627 additions and 48 deletions.
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.7
rev: v0.2.2
hooks:
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- id: ruff # linter
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.368
rev: v1.1.350
hooks:
- id: pyright
additional_dependencies: ["equinox", "pytest", "jax", "jaxtyping"]
1 change: 1 addition & 0 deletions distreqx/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
)
from .mvn_tri import MultivariateNormalTri as MultivariateNormalTri
from .normal import Normal as Normal
from .one_hot_categorical import OneHotCategorical as OneHotCategorical
from .transformed import (
AbstractTransformed as AbstractTransformed,
Transformed as Transformed,
Expand Down
69 changes: 25 additions & 44 deletions distreqx/distributions/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,58 +175,39 @@ def mean(self):
raise NotImplementedError

def kl_divergence(self, other_dist, **kwargs) -> Array:
"""Calculates the KL divergence to another distribution.
"""Obtains the KL divergence `KL(dist1 || dist2)` between two Categoricals.
The KL computation takes into account that `0 * log(0) = 0`; therefore,
`dist1` may have zeros in its probability vector.
**Arguments:**
- `other_dist`: A compatible disteqx distribution.
- `kwargs`: Additional kwargs.
- `other_dist`: A Categorical distribution.
**Returns:**
The KL divergence `KL(self || other_dist)`.
"""
return _kl_divergence_categorical_categorical(self, other_dist)


def _kl_divergence_categorical_categorical(
dist1: Categorical,
dist2: Categorical,
*unused_args,
**unused_kwargs,
) -> Array:
"""Obtains the KL divergence `KL(dist1 || dist2)` between two Categoricals.
The KL computation takes into account that `0 * log(0) = 0`; therefore,
`dist1` may have zeros in its probability vector.
**Arguments:**
`KL(dist1 || dist2)`.
- `dist1`: A Categorical distribution.
- `dist2`: A Categorical distribution.
**Raises:**
**Returns:**
`KL(dist1 || dist2)`.
ValueError if the two distributions have different number of categories.
"""
if not isinstance(other_dist, Categorical):
raise TypeError("Only valid KL for both categoricals.")
logits1 = self.logits
logits2 = other_dist.logits

**Raises:**
num_categories1 = logits1.shape[-1]
num_categories2 = logits2.shape[-1]

ValueError if the two distributions have different number of categories.
"""
logits1 = dist1.logits
logits2 = dist2.logits

num_categories1 = logits1.shape[-1]
num_categories2 = logits2.shape[-1]

if num_categories1 != num_categories2:
raise ValueError(
f"Cannot obtain the KL between two Categorical distributions "
f"with different number of categories: the first distribution has "
f"{num_categories1} categories, while the second distribution has "
f"{num_categories2} categories."
)
if num_categories1 != num_categories2:
raise ValueError(
f"Cannot obtain the KL between two Categorical distributions "
f"with different number of categories: the first distribution has "
f"{num_categories1} categories, while the second distribution has "
f"{num_categories2} categories."
)

log_probs1 = jax.nn.log_softmax(logits1, axis=-1)
log_probs2 = jax.nn.log_softmax(logits2, axis=-1)
return jnp.sum(mul_exp(log_probs1 - log_probs2, log_probs1), axis=-1)
log_probs1 = jax.nn.log_softmax(logits1, axis=-1)
log_probs2 = jax.nn.log_softmax(logits2, axis=-1)
return jnp.sum(mul_exp(log_probs1 - log_probs2, log_probs1), axis=-1)
179 changes: 179 additions & 0 deletions distreqx/distributions/one_hot_categorical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""One hot categorical distribution."""

from typing import Optional, Union

import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray

from ..utils.math import mul_exp, multiply_no_nan, normalize
from ._distribution import (
AbstractSampleLogProbDistribution,
AbstractSTDDistribution,
AbstractSurivialDistribution,
)


class OneHotCategorical(
    AbstractSTDDistribution,
    AbstractSampleLogProbDistribution,
    AbstractSurivialDistribution,
    strict=True,
):
    """Distribution over one-hot encodings of `num_categories` categories.

    Parameterized by exactly one of `logits` or `probs`; samples are one-hot
    vectors of length `num_categories`.
    """

    _logits: Union[Array, None]
    _probs: Union[Array, None]

    def __init__(self, logits: Optional[Array] = None, probs: Optional[Array] = None):
        """Initializes a OneHotCategorical distribution.

        **Arguments:**

        - `logits`: Logit transform of the probability of each category. Only one
            of `logits` or `probs` can be specified.
        - `probs`: Probability of each category. Only one of `logits` or `probs`
            can be specified.
        """
        # Exactly one of the two parameterizations must be supplied.
        if (logits is None) == (probs is None):
            raise ValueError(
                f"One and exactly one of `logits` and `probs` should be `None`, "
                f"but `logits` is {logits} and `probs` is {probs}."
            )
        # Whichever argument was supplied must be a concrete jax array.
        if not (isinstance(logits, jax.Array) or isinstance(probs, jax.Array)):
            raise ValueError("`logits` and `probs` are not jax arrays.")

        if probs is None:
            self._probs = None
        else:
            self._probs = normalize(probs=probs)
        if logits is None:
            self._logits = None
        else:
            self._logits = normalize(logits=logits)

    @property
    def event_shape(self) -> tuple:
        """Shape of event of distribution samples."""
        return (self.num_categories,)

    @property
    def logits(self) -> Array:
        """The logits for each event."""
        if self._logits is None:
            if self._probs is None:
                # Unreachable given the __init__ check; keeps pyright happy.
                raise ValueError("_probs and _logits are None!")
            return jnp.log(self._probs)
        return self._logits

    @property
    def probs(self) -> Array:
        """The probabilities for each event."""
        if self._probs is None:
            if self._logits is None:
                # Unreachable given the __init__ check; keeps pyright happy.
                raise ValueError("_probs and _logits are None!")
            return jax.nn.softmax(self._logits, axis=-1)
        return self._probs

    @property
    def num_categories(self) -> int:
        """Number of categories."""
        if self._probs is not None:
            return self._probs.shape[-1]
        if self._logits is None:
            # Unreachable given the __init__ check; keeps pyright happy.
            raise ValueError("_probs and _logits are None!")
        return self._logits.shape[-1]

    def sample(self, key: PRNGKeyArray) -> Array:
        """See `Distribution.sample`."""
        probs = self.probs
        # A row is valid only if all of its probabilities are finite and
        # non-negative; invalid rows produce a sentinel sample of all -1s.
        valid = jnp.logical_and(
            jnp.all(jnp.isfinite(probs), axis=-1, keepdims=True),
            jnp.all(probs >= 0, axis=-1, keepdims=True),
        )
        indices = jax.random.categorical(key=key, logits=self.logits, axis=-1)
        one_hot = jax.nn.one_hot(indices, num_classes=self.num_categories)
        sentinel = -jnp.ones_like(one_hot)
        return jnp.where(valid, one_hot, sentinel).astype(jnp.int8)

    def log_prob(self, value: Array) -> Array:
        """See `Distribution.log_prob`."""
        # One-hot `value` selects the log-probability of its hot category.
        return jnp.sum(multiply_no_nan(self.logits, value), axis=-1)

    def prob(self, value: Array) -> Array:
        """See `Distribution.prob`."""
        # One-hot `value` selects the probability of its hot category.
        return jnp.sum(multiply_no_nan(self.probs, value), axis=-1)

    def entropy(self) -> Array:
        """See `Distribution.entropy`."""
        if self._logits is not None:
            log_probs = jax.nn.log_softmax(self._logits)
        else:
            if self._probs is None:
                # Unreachable given the __init__ check; keeps pyright happy.
                raise ValueError("_probs and _logits are None!")
            log_probs = jnp.log(self._probs)
        # mul_exp handles the 0 * log(0) = 0 convention.
        return -jnp.sum(mul_exp(log_probs, log_probs), axis=-1)

    def mode(self) -> Array:
        """See `Distribution.mode`."""
        # Prefer logits when available; argmax is identical in either
        # parameterization since softmax is monotone.
        scores = self._logits if self._logits is not None else self._probs
        assert scores is not None
        best = jnp.argmax(scores, axis=-1)
        return jax.nn.one_hot(best, self.num_categories)

    def cdf(self, value: Array) -> Array:
        """See `Distribution.cdf`."""
        # The one-hot `value` picks out the cumulative mass up to and
        # including its hot category.
        cumulative = jnp.cumsum(self.probs, axis=-1)
        return jnp.sum(multiply_no_nan(cumulative, value), axis=-1)

    def log_cdf(self, value: Array) -> Array:
        """See `Distribution.log_cdf`."""
        return jnp.log(self.cdf(value))

    def median(self):
        raise NotImplementedError

    def variance(self):
        raise NotImplementedError

    def mean(self):
        raise NotImplementedError

    def kl_divergence(self, other_dist, **kwargs) -> Array:
        """Obtains the KL divergence `KL(dist1 || dist2)` between two
        OneHotCategorical distributions.

        The KL computation takes into account that `0 * log(0) = 0`; therefore,
        `dist1` may have zeros in its probability vector.

        **Arguments:**

        - `other_dist`: A OneHotCategorical distribution.

        **Returns:**

        `KL(dist1 || dist2)`.

        **Raises:**

        TypeError if `other_dist` is not a OneHotCategorical; ValueError if the
        two distributions have a different number of categories.
        """
        if not isinstance(other_dist, OneHotCategorical):
            raise TypeError("Only valid KL for both categoricals.")
        self_logits = self.logits
        other_logits = other_dist.logits

        n_self = self_logits.shape[-1]
        n_other = other_logits.shape[-1]

        if n_self != n_other:
            raise ValueError(
                f"Cannot obtain the KL between two Categorical distributions "
                f"with different number of categories: the first distribution has "
                f"{n_self} categories, while the second distribution has "
                f"{n_other} categories."
            )

        self_log_probs = jax.nn.log_softmax(self_logits, axis=-1)
        other_log_probs = jax.nn.log_softmax(other_logits, axis=-1)
        # mul_exp implements p * (log p - log q) with the 0 * log(0) = 0 rule.
        return jnp.sum(
            mul_exp(self_log_probs - other_log_probs, self_log_probs), axis=-1
        )
6 changes: 6 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,12 @@ Current features include:

## Installation

```
pip install distreqx
```

or

```
git clone https://github.com/lockwo/distreqx.git
cd distreqx
Expand Down
Loading

0 comments on commit 5759f4a

Please sign in to comment.