Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds @returned_quantities macro #696

Open
wants to merge 27 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
5c746c4
Added `@returned_quantities` macro
torfjelde Oct 23, 2024
0b081b7
Added `@returned_quantities` to the docs
torfjelde Oct 23, 2024
dc699a5
Fixed names of doctests for `@returned_quantities`
torfjelde Oct 23, 2024
7067695
Update src/submodel_macro.jl
torfjelde Oct 24, 2024
8cb0796
Added `@prefix` macro which calls `prefix` with a `Val` argument to
torfjelde Oct 29, 2024
2d887c9
Convert the result of `prefix_expr` in `@prefix` into a `Sybmol`
torfjelde Oct 29, 2024
692cfff
Export `prefix` and `@prefix`
torfjelde Oct 29, 2024
32fd6b9
Updated docstring for `@returned_quantities`
torfjelde Oct 29, 2024
5478fb3
Fixed bug in `rand` for `Model` where it would duplicate the non-leaf
torfjelde Oct 29, 2024
5fe65b3
Merge remote-tracking branch 'origin/torfjelde/returned-quantities-ma…
torfjelde Oct 29, 2024
9e0730f
Update src/contexts.jl
torfjelde Oct 29, 2024
cc3af46
Added `prefix` and `@prefix` to docs
torfjelde Oct 29, 2024
720053a
removed the prefix=... syntax for `@returned_quantities`
torfjelde Oct 31, 2024
fe0403f
added deprecation.jl + deprecated `generated_quantities` in favour of…
torfjelde Oct 31, 2024
55b95a1
removed export of `prefix` and `generated_quantities` (the latter is
torfjelde Oct 31, 2024
34fb6bd
updated `DynamicPPLMCMCChainsExt` to define `returned_quantities`
torfjelde Oct 31, 2024
9a7e18f
updated docs
torfjelde Oct 31, 2024
7aef65b
Update docs/src/api.md
torfjelde Nov 1, 2024
5ee727b
improved docstring for `prefix` and `@prefix`
torfjelde Nov 6, 2024
d92141c
added `@returned_quantities` macro taking two arguments + removed
torfjelde Nov 6, 2024
64b519d
updated docs to reflect the new two-argument `@returned_quantities`
torfjelde Nov 6, 2024
1b48f65
added depwarn to `@submodel` macro
torfjelde Nov 6, 2024
db2102c
fixed reference
torfjelde Nov 6, 2024
da95aba
fixed reference to `@prefix` in `@returned_quantities` macro
torfjelde Nov 6, 2024
c8d567f
actually fixed doc references
torfjelde Nov 6, 2024
d477137
updated doctests for `@submodel` to include the depwarn + added
torfjelde Nov 8, 2024
4896793
Merge branch 'master' into torfjelde/returned-quantities-macro
torfjelde Nov 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,18 @@ These statements are rewritten by `@model` as calls of [internal functions](@ref
@model
```

One can nest models and call another model inside the model function with [`@submodel`](@ref).
One can nest models and call another model inside the model function with [`@submodel`](@ref) and [`@returned_quantities(model)`](@ref).

```@docs
@submodel
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the plan to keep @submodel indefinitely, even though @returned_quantities does the same job?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would think we'd remove @submodel at some point. @yebai ?

Copy link
Member

@yebai yebai Oct 31, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, let's depreciate-then-delete @submodel in favour of @returned_quantities

@returned_quantities(model)
```

In the context of nesting models, it's also useful to prefix the variables in sub-models to avoid variable names clashing:

```@docs
@prefix
DynamicPPL.prefix
```

### Type
Expand Down Expand Up @@ -118,10 +126,11 @@ It is possible to manually increase (or decrease) the accumulated log density fr
@addlogprob!
```

Return values of the model function for a collection of samples can be obtained with [`generated_quantities`](@ref).
Return values of the model function for a collection of samples can be obtained with [`@returned_quantities`](@ref).

```@docs
generated_quantities
@returned_quantities(model, input)
DynamicPPL.returned_quantities
```

For a chain of samples, one can compute the pointwise log-likelihoods of each observed random variable with [`pointwise_loglikelihoods`](@ref). Similarly, the log-densities of the priors using
Expand Down
10 changes: 5 additions & 5 deletions ext/DynamicPPLMCMCChainsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ function DynamicPPL.varnames(c::MCMCChains.Chains)
end

"""
generated_quantities(model::Model, chain::MCMCChains.Chains)
returned_quantities(model::Model, chain::MCMCChains.Chains)

Execute `model` for each of the samples in `chain` and return an array of the values
returned by the `model` for each sample.
Expand All @@ -63,12 +63,12 @@ m = demo(data)
chain = sample(m, alg, n)
# To inspect the `interesting_quantity(θ, x)` where `θ` is replaced by samples
# from the posterior/`chain`:
generated_quantities(m, chain) # <= results in a `Vector` of returned values
returned_quantities(m, chain) # <= results in a `Vector` of returned values
# from `interesting_quantity(θ, x)`
```
## Concrete (and simple)
```julia
julia> using DynamicPPL, Turing
julia> using Turing

julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
Expand All @@ -87,7 +87,7 @@ julia> model = demo(randn(10));

julia> chain = sample(model, MH(), 10);

julia> generated_quantities(model, chain)
julia> DynamicPPL.returned_quantities(model, chain)
10×1 Array{Tuple{Float64},2}:
(2.1964758025119338,)
(2.1964758025119338,)
Expand All @@ -101,7 +101,7 @@ julia> generated_quantities(model, chain)
(-0.16489786710222099,)
```
"""
function DynamicPPL.generated_quantities(
function DynamicPPL.returned_quantities(
model::DynamicPPL.Model, chain_full::MCMCChains.Chains
)
chain = MCMCChains.get_sections(chain_full, :parameters)
Expand Down
5 changes: 4 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ export AbstractVarInfo,
Model,
getmissings,
getargnames,
generated_quantities,
extract_priors,
values_as_in_model,
# Samplers
Expand Down Expand Up @@ -125,6 +124,8 @@ export AbstractVarInfo,
# Convenience macros
@addlogprob!,
@submodel,
@returned_quantities,
@prefix,
value_iterator_from_chain,
check_model,
check_model_and_trace,
Expand Down Expand Up @@ -196,6 +197,8 @@ include("values_as_in_model.jl")
include("debug_utils.jl")
using .DebugUtils

include("deprecated.jl")

if !isdefined(Base, :get_extension)
using Requires
end
Expand Down
54 changes: 54 additions & 0 deletions src/contexts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,60 @@ function prefix(::PrefixContext{Prefix}, vn::VarName{Sym}) where {Prefix,Sym}
end
end

"""
prefix(model::Model, x)

Return `model` but with all random variables prefixed by `x`.

If `x` is known at compile-time, use `Val{x}()` to avoid runtime overheads for prefixing.

# Examples

```jldoctest
julia> using DynamicPPL: prefix

julia> @model demo() = x ~ Dirac(1)
demo (generic function with 2 methods)

julia> rand(prefix(demo(), :my_prefix))
(var"my_prefix.x" = 1,)

julia> # One can also use `Val` to avoid runtime overheads.
rand(prefix(demo(), Val(:my_prefix)))
(var"my_prefix.x" = 1,)
```
"""
prefix(model::Model, x) = contextualize(model, PrefixContext{Symbol(x)}(model.context))
function prefix(model::Model, ::Val{x}) where {x}
return contextualize(model, PrefixContext{Symbol(x)}(model.context))
end

"""
@prefix(model, prefix_expr)

Return `model` but with all random variables prefixed by `prefix_expr`.

The result of `prefix_expr` must will be converted to a `Symbol` and used as the prefix.

!!! note
This is effectively just a convenience macro for the method [`DynamicPPL.prefix(::Model, x)`](@ref),
which automatically converts the result of `prefix_expr` into a `Val` to avoid runtime overheads
for static prefixes. For more control over the prefixing, use the method directly.

# Examples

```jldoctest
julia> @model demo() = x ~ Dirac(1)
demo (generic function with 2 methods)

julia> rand(@prefix(demo(), :my_prefix))
(var"my_prefix.x" = 1,)
```
"""
macro prefix(model, prefix_expr)
return :($prefix($(esc(model)), $Val{$Symbol($(esc(prefix_expr)))}()))
end

struct ConditionContext{Values,Ctx<:AbstractContext} <: AbstractContext
values::Values
context::Ctx
Expand Down
1 change: 1 addition & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
@deprecate generated_quantities returned_quantities
34 changes: 25 additions & 9 deletions src/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1051,7 +1051,9 @@ function Base.rand(rng::Random.AbstractRNG, ::Type{T}, model::Model) where {T}
evaluate!!(
model,
SimpleVarInfo{Float64}(OrderedDict()),
SamplingContext(rng, SampleFromPrior(), model.context),
# NOTE: Use `leafcontext` here so we a) avoid overriding the leaf context of `model`,
# and b) avoid double-stacking the parent contexts.
SamplingContext(rng, SampleFromPrior(), leafcontext(model.context)),
),
)
return values_as(x, T)
Expand Down Expand Up @@ -1204,9 +1206,9 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
end

"""
generated_quantities(model::Model, parameters::NamedTuple)
generated_quantities(model::Model, values, keys)
generated_quantities(model::Model, values, keys)
returned_quantities(model::Model, parameters::NamedTuple)
returned_quantities(model::Model, values, keys)
returned_quantities(model::Model, values, keys)

Execute `model` with variables `keys` set to `values` and return the values returned by the `model`.

Expand All @@ -1216,6 +1218,8 @@ If a `NamedTuple` is given, `keys=keys(parameters)` and `values=values(parameter
```jldoctest
julia> using DynamicPPL, Distributions

julia> using DynamicPPL: returned_quantities

julia> @model function demo(xs)
s ~ InverseGamma(2, 3)
m_shifted ~ Normal(10, √s)
Expand All @@ -1231,18 +1235,30 @@ julia> model = demo(randn(10));

julia> parameters = (; s = 1.0, m_shifted=10.0);

julia> generated_quantities(model, parameters)
julia> returned_quantities(model, parameters)
(0.0,)

julia> generated_quantities(model, values(parameters), keys(parameters))
julia> returned_quantities(model, values(parameters), keys(parameters))
(0.0,)
```
"""
function generated_quantities(model::Model, parameters::NamedTuple)
function returned_quantities(model::Model, parameters::NamedTuple)
fixed_model = fix(model, parameters)
return fixed_model()
end

function generated_quantities(model::Model, values, keys)
return generated_quantities(model, NamedTuple{keys}(values))
function returned_quantities(model::Model, values, keys)
return returned_quantities(model, NamedTuple{keys}(values))
end

"""
@returned_quantities(model, input)

Execute `model` and extract the return-values of `model` for `input`.

!!! note
This macro is in fact a simple wrapper around the method [`DynamicPPL.returned_quantities`](@ref).
"""
macro returned_quantities(model_expr, input_expr)
return :($returned_quantities($(esc(model_expr)), $(esc(input_expr))))
end
Loading
Loading