Skip to content

Commit

Permalink
AbstractInterpreter: refactor the lifetimes of OptimizationState
Browse files Browse the repository at this point in the history
…and `IRCode`

This commit limits the lifetimes of `OptimizationState` and `IRCode`
to make the dataflow clearer. It also avoids duplicated calls to `ir_to_codeinf!`.

Note that external `AbstractInterpreter`s can still extend their
lifetimes to cache additional information, as described by this
newly added documentation of `finish!`:

>     finish!(interp::AbstractInterpreter,
>         opt::OptimizationState, ir::IRCode, caller::InferenceResult)
>
> Runs post-Julia-level optimization process and caches information for later uses:
> - computes "purity" (i.e. side-effect-freeness) of the optimized frame
> - computes inlining cost and caches the inlineability in `opt.src.inlineable`
> - stores the result of optimization in `caller.src`
> * by default, `caller.src` will be an optimized `CodeInfo` object transformed from `ir`
> * in a case when this frame has been proven pure, `ConstAPI` object wrapping the constant
> value will be kept in `caller.src` instead, so that the runtime system will use
> the constant calling convention
>
> !!! note
>     The lifetimes of `opt` and `ir` end by the end of this process.
>     External `AbstractInterpreter`s can still override this method as necessary to cache them.
>     Note that `transform_result_for_cache` should also be overloaded in such cases,
>     otherwise the default `transform_result_for_cache` implementation will discard any information
>     other than `CodeInfo`, `Vector{UInt8}` or `ConstAPI`.

This commit also adds a new overload, `infresult_iterator`, so that external
interpreters can tweak the post-processing behavior of `_typeinf`.
In particular, this change is motivated by JET, whose post-optimization
processing needs references to `InferenceState`.
  • Loading branch information
aviatesk committed Mar 26, 2022
1 parent 5dc6155 commit 8e88697
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 147 deletions.
229 changes: 125 additions & 104 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,9 @@ end

include("compiler/ssair/driver.jl")

mutable struct OptimizationState
struct OptimizationState
linfo::MethodInstance
src::CodeInfo
ir::Union{Nothing, IRCode}
stmt_info::Vector{Any}
mod::Module
sptypes::Vector{Any} # static parameters
Expand All @@ -99,8 +98,7 @@ mutable struct OptimizationState
EdgeTracker(s_edges, frame.valid_worlds),
WorldView(code_cache(interp), frame.world),
interp)
return new(frame.linfo,
frame.src, nothing, frame.stmt_info, frame.mod,
return new(frame.linfo, frame.src, frame.stmt_info, frame.mod,
frame.sptypes, frame.slottypes, inlining)
end
function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::OptimizationParams, interp::AbstractInterpreter)
Expand All @@ -127,8 +125,7 @@ mutable struct OptimizationState
nothing,
WorldView(code_cache(interp), get_world_counter()),
interp)
return new(linfo,
src, nothing, stmt_info, mod,
return new(linfo, src, stmt_info, mod,
sptypes_from_meth_instance(linfo), slottypes, inlining)
end
end
Expand All @@ -139,11 +136,10 @@ function OptimizationState(linfo::MethodInstance, params::OptimizationParams, in
return OptimizationState(linfo, src, params, interp)
end

function ir_to_codeinf!(opt::OptimizationState)
function ir_to_codeinf!(opt::OptimizationState, ir::IRCode)
(; linfo, src) = opt
optdef = linfo.def
replace_code_newstyle!(src, opt.ir::IRCode, isa(optdef, Method) ? Int(optdef.nargs) : 0)
opt.ir = nothing
replace_code_newstyle!(src, ir, isa(optdef, Method) ? Int(optdef.nargs) : 0)
widen_all_consts!(src)
src.inferred = true
# finish updating the result struct
Expand Down Expand Up @@ -380,130 +376,155 @@ struct ConstAPI
end

"""
finish(interp::AbstractInterpreter, opt::OptimizationState,
params::OptimizationParams, ir::IRCode, caller::InferenceResult) -> analyzed::Union{Nothing,ConstAPI}
Post process information derived by Julia-level optimizations for later uses:
- computes "purity", i.e. side-effect-freeness
- computes inlining cost
In a case when the purity is proven, `finish` can return `ConstAPI` object wrapping the constant
value so that the runtime system will use the constant calling convention for the method calls.
finish!(interp::AbstractInterpreter,
opt::OptimizationState, ir::IRCode, caller::InferenceResult)
Runs post-Julia-level optimization process and caches information for later uses:
- computes "purity" (i.e. side-effect-freeness) of the optimized frame
- computes inlining cost and caches the inlineability in `opt.src.inlineable`
- stores the result of optimization in `caller.src`
* by default, `caller.src` will be an optimized `CodeInfo` object transformed from `ir`
* in a case when this frame has been proven pure, `ConstAPI` object wrapping the constant
value will be kept in `caller.src` instead, so that the runtime system will use
the constant calling convention
!!! note
The lifetimes of `opt` and `ir` end by the end of this process.
Still external `AbstractInterpreter` can override `transform_optresult_for_cache`
as necessary to cache them. Note that `transform_result_for_cache` should be overloaded
in such cases as well, otherwise the default implementation of `transform_result_for_cache`
will discard any information other than `CodeInfo`, `Vector{UInt8}` or `ConstAPI`.
"""
function finish(interp::AbstractInterpreter, opt::OptimizationState,
params::OptimizationParams, ir::IRCode, caller::InferenceResult)
(; src, linfo) = opt
(; def, specTypes) = linfo

analyzed = nothing # `ConstAPI` if this call can use constant calling convention
force_noinline = _any(@nospecialize(x) -> isexpr(x, :meta) && x.args[1] === :noinline, ir.meta)
function finish!(interp::AbstractInterpreter,
opt::OptimizationState, ir::IRCode, caller::InferenceResult)
src = opt.src

# compute inlining and other related optimizations
result = caller.result
@assert !(result isa LimitedAccuracy)
result = isa(result, InterConditional) ? widenconditional(result) : result
if (isa(result, Const) || isconstType(result))
proven_pure = false
# must be proven pure to use constant calling convention;
# otherwise we might skip throwing errors (issue #20704)
# TODO: Improve this analysis; if a function is marked @pure we should really
# only care about certain errors (e.g. method errors and type errors).
if length(ir.stmts) < 15
proven_pure = true
for i in 1:length(ir.stmts)
node = ir.stmts[i]
stmt = node[:inst]
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, node[:type], ir)
proven_pure = false
break
end
end
if proven_pure
for fl in src.slotflags
if (fl & SLOT_USEDUNDEF) != 0
proven_pure = false
break
end
end
end
end

if proven_pure
# use constant calling convention
# Do not emit `jl_fptr_const_return` if coverage is enabled
# so that we don't need to add coverage support
# to the `jl_call_method_internal` fast path
# Still set pure flag to make sure `inference` tests pass
# and to possibly enable more optimization in the future
src.pure = true
newresult = nothing # ConstAPI if this call can use constant calling convention
if isa(result, Const) || isconstType(result)
# computes "purity" (i.e. side-effect-freeness)
if compute_purity(ir, src)
src.inlineable = src.pure = true

# must be proven pure to use constant calling convention;
# otherwise we might skip throwing errors (issue #20704)
if isa(result, Const)
val = result.val
if is_inlineable_constant(val)
analyzed = ConstAPI(val)
newresult = ConstAPI(val)
end
else
@assert isconstType(result)
analyzed = ConstAPI(result.parameters[1])
newresult = ConstAPI(result.parameters[1])
end
force_noinline || (src.inlineable = true)
end
end

opt.ir = ir

# determine and cache inlineability
union_penalties = false
if !force_noinline
sig = unwrap_unionall(specTypes)
if isa(sig, DataType) && sig.name === Tuple.name
for P in sig.parameters
P = unwrap_unionall(P)
if isa(P, Union)
union_penalties = true
break
end
src.inlineable = compute_inlineability(ir, opt, result, src.inlineable)

caller.valid_worlds = (opt.inlining.et::EdgeTracker).valid_worlds[]

caller.src = transform_optresult_for_cache(interp, opt, ir, newresult)

return nothing
end

function compute_purity(ir::IRCode, src::CodeInfo)
# TODO: Improve this analysis; if a function is marked @pure we should really
# only care about certain errors (e.g. method errors and type errors).
if length(ir.stmts) < 15
for i in 1:length(ir.stmts)
node = ir.stmts[i]
stmt = node[:inst]
if stmt_affects_purity(stmt, ir) && !stmt_effect_free(stmt, node[:type], ir)
return false
end
else
force_noinline = true
end
if !src.inlineable && result === Bottom
force_noinline = true
for flag in src.slotflags
if (flag & SLOT_USEDUNDEF) != 0
return false
end
end
return true
end
if force_noinline
src.inlineable = false
elseif isa(def, Method)
if src.inlineable && isdispatchtuple(specTypes)
# obey @inline declaration if a dispatch barrier would not help
else
# compute the cost (size) of inlining this code
cost_threshold = default = params.inline_cost_threshold
if result Tuple && !isconcretetype(widenconst(result))
cost_threshold += params.inline_tupleret_bonus
end
# if the method is declared as `@inline`, increase the cost threshold 20x
if src.inlineable
cost_threshold += 19*default
end
# a few functions get special treatment
if def.module === _topmod(def.module)
name = def.name
if name === :iterate || name === :unsafe_convert || name === :cconvert
cost_threshold += 4*default
end
return false
end

function compute_inlineability(ir::IRCode, opt::OptimizationState, @nospecialize(result),
declared_inlineability::Bool)
(; def, specTypes) = opt.linfo
force_noinline = _any(@nospecialize(x) -> isexpr(x, :meta) && x.args[1] === :noinline, ir.meta)
force_noinline && return false
union_penalties = false
sig = unwrap_unionall(specTypes)
if isa(sig, DataType) && sig.name === Tuple.name
for P in sig.parameters
P = unwrap_unionall(P)
if isa(P, Union)
union_penalties = true
break
end
src.inlineable = inline_worthy(ir, params, union_penalties, cost_threshold)
end
else
return false
end
if !declared_inlineability && result === Bottom
return false
end
isa(def, Method) || return declared_inlineability
if declared_inlineability && isdispatchtuple(specTypes)
# obey @inline declaration if a dispatch barrier would not help
return true
end
# compute the cost (size) of inlining this code
params = opt.inlining.params
cost_threshold = default = params.inline_cost_threshold
if result Tuple && !isconcretetype(widenconst(result))
cost_threshold += params.inline_tupleret_bonus
end
# if the method is declared as `@inline`, increase the cost threshold 20x
if declared_inlineability
cost_threshold += 19*default
end
# a few functions get special treatment
if def.module === _topmod(def.module)
name = def.name
if name === :iterate || name === :unsafe_convert || name === :cconvert
cost_threshold += 4*default
end
end
return inline_worthy(ir, params, union_penalties, cost_threshold)
end

return analyzed
function transform_optresult_for_cache(::AbstractInterpreter,
opt::OptimizationState, ir::IRCode, @nospecialize(newresult))
if isa(newresult, ConstAPI)
# use constant calling convention
# Do not emit `jl_fptr_const_return` if coverage is enabled
# so that we don't need to add coverage support
# to the `jl_call_method_internal` fast path
# Still set pure flag to make sure `inference` tests pass
# and to possibly enable more optimization in the future

# XXX: The work in ir_to_codeinf! is essentially wasted. The only reason
# we're doing it is so that code_llvm can return the code
# for the `return ...::Const` (which never runs anyway). We should do this
# as a post processing step instead.
ir_to_codeinf!(opt, ir)
return newresult
end
return ir_to_codeinf!(opt, ir)
end

# run the optimization work
function optimize(interp::AbstractInterpreter, opt::OptimizationState,
params::OptimizationParams, caller::InferenceResult)
function optimize!(interp::AbstractInterpreter,
opt::OptimizationState, caller::InferenceResult)
@timeit "optimizer" ir = run_passes(opt.src, opt, caller)
return finish(interp, opt, params, ir, caller)
@timeit "finish!" finish!(interp, opt, ir, caller)
end

using .EscapeAnalysis
Expand Down
Loading

0 comments on commit 8e88697

Please sign in to comment.