Skip to content

Commit

Permalink
AbstractInterpreter: refactor the lifetimes of OptimizationState
Browse files Browse the repository at this point in the history
…and `IRCode`

This commit limits the lifetimes of `OptimizationState` and `IRCode`
for clearer dataflow. It also avoids duplicated calls to `ir_to_codeinf!`.

Note that external `AbstractInterpreter`s can still extend their
lifetimes to cache additional information, as described by this
newly added documentation of `finish!`:

>     finish!(interp::AbstractInterpreter,
>         opt::OptimizationState, ir::IRCode, caller::InferenceResult)
>
> Runs post-Julia-level optimization process and caches information for later uses:
> - computes "purity" (i.e. side-effect-freeness) of the optimized frame
> - computes inlining cost and cache the inlineability in `opt.src.inlineable`
> - stores the result of optimization in `caller.src`
> * by default, `caller.src` will be an optimized `CodeInfo` object transformed from `ir`
> * in a case when this frame has been proven pure, `ConstAPI` object wrapping the constant
> value will be kept in `caller.src` instead, so that the runtime system will use
> the constant calling convention
>
> !!! note
>     The lifetimes of `opt` and `ir` end by the end of this process.
>     Still, external `AbstractInterpreter`s can override this method as necessary to cache them.
>     Note that `transform_result_for_cache` should be overloaded also in such cases,
>     otherwise the default `transform_result_for_cache` implementation will discard any information
>     other than `CodeInfo`, `Vector{UInt8}` or `ConstAPI`.

This commit also adds a new overload `infresult_iterator` so that external
interpreters can tweak the post-processing behavior of `_typeinf`.
In particular, this change is motivated by the needs of JET, whose post-optimization
processing requires references to `InferenceState`.
  • Loading branch information
aviatesk committed Jan 31, 2022
1 parent a7beb93 commit 7569d68
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 64 deletions.
69 changes: 43 additions & 26 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,9 @@ end

include("compiler/ssair/driver.jl")

mutable struct OptimizationState
struct OptimizationState
linfo::MethodInstance
src::CodeInfo
ir::Union{Nothing, IRCode}
stmt_info::Vector{Any}
mod::Module
sptypes::Vector{Any} # static parameters
Expand All @@ -69,8 +68,7 @@ mutable struct OptimizationState
EdgeTracker(s_edges, frame.valid_worlds),
WorldView(code_cache(interp), frame.world),
interp)
return new(frame.linfo,
frame.src, nothing, frame.stmt_info, frame.mod,
return new(frame.linfo, frame.src, frame.stmt_info, frame.mod,
frame.sptypes, frame.slottypes, inlining)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
Expand All @@ -97,8 +95,7 @@ mutable struct OptimizationState
nothing,
WorldView(code_cache(interp), get_world_counter()),
interp)
return new(linfo,
src, nothing, stmt_info, mod,
return new(linfo, src, stmt_info, mod,
sptypes_from_meth_instance(linfo), slottypes, inlining)
end
end
Expand All @@ -109,11 +106,10 @@ function OptimizationState(linfo::MethodInstance, params::OptimizationParams, in
return OptimizationState(linfo, src, params, interp)
end

function ir_to_codeinf!(opt::OptimizationState)
function ir_to_codeinf!(opt::OptimizationState, ir::IRCode)
(; linfo, src) = opt
optdef = linfo.def
replace_code_newstyle!(src, opt.ir::IRCode, isa(optdef, Method) ? Int(optdef.nargs) : 0)
opt.ir = nothing
replace_code_newstyle!(src, ir, isa(optdef, Method) ? Int(optdef.nargs) : 0)
widen_all_consts!(src)
src.inferred = true
# finish updating the result struct
Expand Down Expand Up @@ -383,18 +379,27 @@ struct ConstAPI
end

"""
finish(interp::AbstractInterpreter, opt::OptimizationState,
params::OptimizationParams, ir::IRCode, caller::InferenceResult) -> analyzed::Union{Nothing,ConstAPI}
Post process information derived by Julia-level optimizations for later uses:
- computes "purity", i.e. side-effect-freeness
- computes inlining cost
In a case when the purity is proven, `finish` can return `ConstAPI` object wrapping the constant
value so that the runtime system will use the constant calling convention for the method calls.
finish!(interp::AbstractInterpreter,
opt::OptimizationState, ir::IRCode, caller::InferenceResult)
Runs post-Julia-level optimization process and caches information for later uses:
- computes "purity" (i.e. side-effect-freeness) of the optimized frame
- computes inlining cost and cache the inlineability in `opt.src.inlineable`
- stores the result of optimization in `caller.src`
* by default, `caller.src` will be an optimized `CodeInfo` object transformed from `ir`
* in a case when this frame has been proven pure, `ConstAPI` object wrapping the constant
value will be kept in `caller.src` instead, so that the runtime system will use
the constant calling convention
!!! note
The lifetimes of `opt` and `ir` end by the end of this process.
Still external `AbstractInterpreter` can override this method as necessary to cache them.
Note that `transform_result_for_cache` should be overloaded also in such cases,
otherwise the default `transform_result_for_cache` implmentation will discard any information
other than `CodeInfo`, `Vector{UInt8}` or `ConstAPI`.
"""
function finish(interp::AbstractInterpreter, opt::OptimizationState,
params::OptimizationParams, ir::IRCode, caller::InferenceResult)
function finish!(interp::AbstractInterpreter,
opt::OptimizationState, ir::IRCode, caller::InferenceResult)
(; src, linfo) = opt
(; def, specTypes) = linfo

Expand Down Expand Up @@ -452,8 +457,6 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
end
end

opt.ir = ir

# determine and cache inlineability
union_penalties = false
if !force_noinline
Expand All @@ -480,6 +483,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
# obey @inline declaration if a dispatch barrier would not help
else
# compute the cost (size) of inlining this code
params = opt.inlining.params
cost_threshold = default = params.inline_cost_threshold
if result Tuple && !isconcretetype(widenconst(result))
cost_threshold += params.inline_tupleret_bonus
Expand All @@ -499,14 +503,27 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
end
end

return analyzed
caller.valid_worlds = (opt.inlining.et::EdgeTracker).valid_worlds[]

if isa(analyzed, ConstAPI)
# XXX: The work in ir_to_codeinf! is essentially wasted. The only reason
# we're doing it is so that code_llvm can return the code
# for the `return ...::Const` (which never runs anyway). We should do this
# as a post processing step instead.
ir_to_codeinf!(opt, ir)
caller.src = analyzed
else
caller.src = ir_to_codeinf!(opt, ir)
end

return nothing
end

# run the optimization work
function optimize(interp::AbstractInterpreter, opt::OptimizationState,
params::OptimizationParams, caller::InferenceResult)
function optimize!(interp::AbstractInterpreter,
opt::OptimizationState, caller::InferenceResult)
@timeit "optimizer" ir = run_passes(opt.src, opt)
return finish(interp, opt, params, ir, caller)
@timeit "finish!" finish!(interp, opt, ir, caller)
end

function run_passes(ci::CodeInfo, sv::OptimizationState)
Expand Down
72 changes: 34 additions & 38 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,23 +210,20 @@ function typeinf(interp::AbstractInterpreter, frame::InferenceState)
end
end

function finish!(interp::AbstractInterpreter, caller::InferenceResult)
# If we didn't transform the src for caching, we may have to transform
# it anyway for users like typeinf_ext. Do that here.
opt = caller.src
if opt isa OptimizationState # implies `may_optimize(interp) === true`
if opt.ir !== nothing
caller.src = ir_to_codeinf!(opt)
end
end
return caller.src
end

function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
typeinf_nocycle(interp, frame) || return false # frame is now part of a higher cycle
# with no active ip's, frame is done
frames = frame.callers_in_cycle
isempty(frames) && push!(frames, frame)
finish_infstates!(interp, frames)
# collect results for the new expanded frame
results = infresult_iterator(interp, frames)
optimize!(interp, results)
cache_results!(interp, results)
return true
end

function finish_infstates!(interp::AbstractInterpreter, frames::Vector{InferenceState})
valid_worlds = WorldRange()
for caller in frames
@assert !(caller.dont_work_on_me)
Expand All @@ -240,29 +237,35 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
# finalize and record the linfo result
caller.inferred = true
end
# collect results for the new expanded frame
results = Tuple{InferenceResult, Vector{Any}, Bool}[
( frames[i].result,
frames[i].stmt_edges[1]::Vector{Any},
frames[i].cached )
for i in 1:length(frames) ]
empty!(frames)
for (caller, _, _) in results
end

struct InfResultInfo
caller::InferenceResult
edges::Vector{Any}
cached::Bool
end

# returns iterator on which `optimize!` and `postopt_process!` work on
function infresult_iterator(_::AbstractInterpreter, frames::Vector{InferenceState})
results = InfResultInfo[ InfResultInfo(
frames[i].result,
frames[i].stmt_edges[1]::Vector{Any},
frames[i].cached ) for i in 1:length(frames) ]
empty!(frames) # discard `InferenceState` now
return results
end

function optimize!(interp::AbstractInterpreter, results::Vector{InfResultInfo})
for (; caller) in results
opt = caller.src
if opt isa OptimizationState # implies `may_optimize(interp) === true`
analyzed = optimize(interp, opt, OptimizationParams(interp), caller)
if isa(analyzed, ConstAPI)
# XXX: The work in ir_to_codeinf! is essentially wasted. The only reason
# we're doing it is so that code_llvm can return the code
# for the `return ...::Const` (which never runs anyway). We should do this
# as a post processing step instead.
ir_to_codeinf!(opt)
caller.src = analyzed
end
caller.valid_worlds = (opt.inlining.et::EdgeTracker).valid_worlds[]
optimize!(interp, opt, caller)
end
end
for (caller, edges, cached) in results
end

function cache_results!(interp::AbstractInterpreter, results::Vector{InfResultInfo})
for (; caller, edges, cached) in results
valid_worlds = caller.valid_worlds
if last(valid_worlds) >= get_world_counter()
# if we aren't cached, we don't need this edge
Expand All @@ -272,9 +275,7 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
if cached
cache_result!(interp, caller)
end
finish!(interp, caller)
end
return true
end

function CodeInstance(result::InferenceResult, @nospecialize(inferred_result),
Expand Down Expand Up @@ -349,11 +350,6 @@ end

function transform_result_for_cache(interp::AbstractInterpreter, linfo::MethodInstance,
valid_worlds::WorldRange, @nospecialize(inferred_result))
# If we decided not to optimize, drop the OptimizationState now.
# External interpreters can override as necessary to cache additional information
if inferred_result isa OptimizationState
inferred_result = ir_to_codeinf!(inferred_result)
end
if inferred_result isa CodeInfo
inferred_result.min_world = first(valid_worlds)
inferred_result.max_world = last(valid_worlds)
Expand Down

0 comments on commit 7569d68

Please sign in to comment.