
Commit

Add a specify_shape to the split values in `logprob_join_constant_shapes`

This avoids aesara-devs/aeppl#191
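
The fix suggests the underlying problem is that the outputs of `at.split` lose their static shape information. A minimal sketch of that symptom and of the `specify_shape` remedy the commit applies, assuming a contemporary aesara; the variable names and sizes are illustrative, not taken from the commit:

import aesara.tensor as at

value = at.vector("value")
parts = at.split(value, splits_size=[2, 3], n_splits=2, axis=0)
# The split outputs typically carry no static shape, even with constant sizes
print(parts[0].type.shape)  # (None,)

# specify_shape re-attaches the known size to the graph output
fixed = at.specify_shape(parts[0], (2,))
print(fixed.type.shape)  # (2,)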
ricardoV94 committed Oct 8, 2022
1 parent ec6ea14 commit 5133cc5
Showing 1 changed file with 21 additions and 9 deletions.
pymc/distributions/logprob.py
@@ -27,7 +27,7 @@
 from aeppl.tensor import MeasurableJoin
 from aeppl.transforms import TransformValuesRewrite
 from aesara import tensor as at
-from aesara.graph.basic import graph_inputs, io_toposort
+from aesara.graph.basic import Variable, graph_inputs, io_toposort
 from aesara.tensor.random.op import RandomVariable
 from aesara.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -328,23 +328,35 @@ def ignore_logprob(rv: TensorVariable) -> TensorVariable:
 def logprob_join_constant_shapes(op, values, axis, *base_vars, **kwargs):
     """Compute the log-likelihood graph for a `Join`.
-    This overrides the implementation in Aeppl, to constant fold the shapes
-    of the base vars so that RandomVariables do not show up in the logp graph,
+    This overrides the implementation in Aeppl, to constant fold the shapes of
+    the split value var so that RandomVariables do not show up in the logp graph,
     which is a requirement enforced by `pymc.distributions.logprob.joint_logp`
     """
     (value,) = values
 
-    base_var_shapes = [base_var.shape[axis] for base_var in base_vars]
-
+    # We flatten the base var shapes, so that we can constant_fold with more granularity
+    flat_base_var_shapes = [s for base_var in base_vars for s in base_var.shape]
     # We don't need the graph to be constant, just to have RandomVariables removed
-    base_var_shapes = constant_fold(base_var_shapes, raise_not_constant=False)
-
-    split_values = at.split(
+    flat_base_var_shapes = constant_fold(flat_base_var_shapes, raise_not_constant=False)
+    # Unflatten base var shapes
+    base_var_shapes = []
+    i = 0
+    for base_var in base_vars:
+        ndim = base_var.type.ndim
+        base_var_shapes.append(flat_base_var_shapes[i : i + ndim])
+        i += ndim
+
+    split_values_raw = at.split(
         value,
-        splits_size=[base_var_shape for base_var_shape in base_var_shapes],
+        splits_size=[at.stack(base_var_shape)[axis] for base_var_shape in base_var_shapes],
         n_splits=len(base_vars),
         axis=axis,
     )
+    # We need to add specify_shape due to https://github.com/aesara-devs/aeppl/issues/191
+    split_values = []
+    for split_value, shape in zip(split_values_raw, base_var_shapes):
+        split_value_shape = [None if isinstance(s, Variable) else s for s in shape]
+        split_values.append(at.specify_shape(split_value, split_value_shape))
 
     logps = [
         logprob(base_var, split_value) for base_var, split_value in zip(base_vars, split_values)
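
Taken together, the new code splits the joined value and then pins any shape entries that folded to constants back onto each split output. A condensed, self-contained sketch of that pattern, assuming aesara; `base_var_shapes`, the sizes, and the symbolic `n` are illustrative stand-ins for the folded shapes, not the actual values from the logp graph:

import aesara.tensor as at
from aesara.graph.basic import Variable

# Stand-ins for the folded shapes: ints where constant folding succeeded,
# symbolic Variables where it did not
base_var_shapes = [(2,), (at.lscalar("n"),)]

joined = at.vector("joined")
split_values_raw = at.split(
    joined,
    splits_size=[at.stack(shape)[0] for shape in base_var_shapes],
    n_splits=2,
    axis=0,
)

split_values = []
for split_value, shape in zip(split_values_raw, base_var_shapes):
    # Constant entries become static shape info; symbolic ones stay unknown
    static_shape = [None if isinstance(s, Variable) else s for s in shape]
    split_values.append(at.specify_shape(split_value, static_shape))

print([v.type.shape for v in split_values])  # [(2,), (None,)]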
