Skip to content

Commit

Permalink
Add support for power transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Dec 5, 2022
1 parent 473c1e6 commit d87ef92
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 8 deletions.
42 changes: 34 additions & 8 deletions aeppl/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
from aesara.tensor.math import add, exp, log, mul, reciprocal, sub, true_div
from aesara.tensor.math import add, exp, log, mul, pow, reciprocal, sub, true_div
from aesara.tensor.rewriting.basic import (
register_specialize,
register_stabilize,
Expand Down Expand Up @@ -422,8 +422,20 @@ def transform(measurable_input, *other_inputs):
def measurable_reciprocal(fgraph, node):
"""Rewrite a `reciprocal` node to a `MeasurableVariable`."""

def transform(measurable_input, *other_inputs):
return ReciprocalTransform(), (measurable_input,)
new_node = at.power(node.inputs[0], at.as_tensor(-1)).owner
return measurable_pow.transform(fgraph, new_node)


@register_measurable_ir
@node_rewriter([pow])
def measurable_pow(fgraph, node):
"""Rewrite a `power` node to a `MeasurableVariable`."""

def transform(measurable_input, *args):
return PowerTransform(transform_args_fn=lambda *inputs: inputs[-1]), (
measurable_input,
*args,
)

return construct_elemwise_transform(fgraph, node, transform)

Expand Down Expand Up @@ -579,17 +591,31 @@ def log_jac_det(self, value, *inputs):
return -at.log(value)


class ReciprocalTransform(RVTransform):
name = "reciprocal"
class PowerTransform(RVTransform):
name = "power"

def __init__(self, transform_args_fn):
self.transform_args_fn = transform_args_fn

def forward(self, value, *inputs):
return at.reciprocal(value)
power = self.transform_args_fn(*inputs)
return at.power(value, power)

def backward(self, value, *inputs):
return at.reciprocal(value)
power = self.transform_args_fn(*inputs)

inv_power = at.reciprocal(power)
return at.switch(
at.eq(at.mod(power, 2), 0),
at.power(value, inv_power),
at.sgn(value) * at.power(at.abs(value), inv_power),
)

def log_jac_det(self, value, *inputs):
return -2 * at.log(value)
from aeppl.logprob import xlogy0

power = self.transform_args_fn(*inputs)
return at.log(at.abs(power)) + xlogy0((power - 1), at.abs(value))


class IntervalTransform(RVTransform):
Expand Down
19 changes: 19 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,3 +763,22 @@ def test_transform_measurable_sub():

with pytest.raises(RuntimeError, match="The logprob terms"):
joint_logprob(Z_rv, X_rv)


@pytest.mark.parametrize(
"pow_fn, exp_val_fn",
[
(lambda x: x**2, lambda z: sp.stats.chi2(df=1).logpdf(z))
# TODO: Add more cases.
],
)
def test_transform_measurable_pow(pow_fn, exp_val_fn):
X_rv = at.random.normal(0, 1, name="X")
Z_rv = pow_fn(X_rv)
Z_rv.name = "Z"

z_logp, (z_vv,) = conditional_logprob(Z_rv)
z_logp_fn = aesara.function([z_vv], z_logp[Z_rv])

z_val = 0.5
assert np.allclose(z_logp_fn(z_val), exp_val_fn(z_val))

0 comments on commit d87ef92

Please sign in to comment.