AssertionError: useidx !== nothing #319

Open
willtebbutt opened this issue Oct 29, 2024 · 2 comments

willtebbutt commented Oct 29, 2024

There is a known issue with Core.Compiler.populate_def_use_map!, as discussed here: JuliaLang/julia#56193

This can result in problems like that encountered by @penelopeysm in this PR: TuringLang/Bijectors.jl#338

For posterity, here is the kind of error that you see if you encounter it:

  Got exception outside of a @test
  AssertionError: useidx !== nothing
  Stacktrace:
    [1] kill_def_use!(tpdum::Core.Compiler.TwoPhaseDefUseMap, def::Int64, use::Int64)
      @ Core.Compiler ./compiler/inferencestate.jl:147
    [2] kill_def_use!
      @ ./compiler/inferencestate.jl:158 [inlined]
    [3] reprocess_instruction!(interp::Core.Compiler.NativeInterpreter, inst::Core.Compiler.Instruction, idx::Int64, bb::Nothing, irsv::Core.Compiler.IRInterpretationState)
      @ Core.Compiler ./compiler/ssair/irinterp.jl:119
    [4] _ir_abstract_constant_propagation(interp::Core.Compiler.NativeInterpreter, irsv::Core.Compiler.IRInterpretationState; externally_refined::Nothing)
      @ Core.Compiler ./compiler/ssair/irinterp.jl:391
    [5] _ir_abstract_constant_propagation
      @ ./compiler/ssair/irinterp.jl:280 [inlined]
    [6] __infer_ir!
      @ ~/.julia/packages/Mooncake/LKJK9/src/interpreter/ir_utils.jl:117 [inlined]
    [7] optimise_ir!(ir::Core.Compiler.IRCode; show_ir::Bool, do_inline::Bool)
      @ Mooncake ~/.julia/packages/Mooncake/LKJK9/src/interpreter/ir_utils.jl:150
    [8] optimise_ir!
      @ ~/.julia/packages/Mooncake/LKJK9/src/interpreter/ir_utils.jl:139 [inlined]
    [9] build_rrule(interp::Mooncake.MooncakeInterpreter{Mooncake.DefaultCtx}, sig_or_mi::Type; debug_mode::Bool, silence_debug_messages::Bool)
      @ Mooncake ~/.julia/packages/Mooncake/LKJK9/src/interpreter/s2s_reverse_mode_ad.jl:925
   [10] build_rrule
      @ ~/.julia/packages/Mooncake/LKJK9/src/interpreter/s2s_reverse_mode_ad.jl:861 [inlined]
   [11] #build_rrule#210
      @ ~/.julia/packages/Mooncake/LKJK9/src/interpreter/s2s_reverse_mode_ad.jl:841 [inlined]
   [12] build_rrule
      @ ~/.julia/packages/Mooncake/LKJK9/src/interpreter/s2s_reverse_mode_ad.jl:839 [inlined]

As the stack trace shows, the problem occurs in _ir_abstract_constant_propagation, which performs type inference, constant propagation, and a few other things (I think).

A minimal reproducer on Julia 1.11.1 is:

using Pkg
Pkg.activate(; temp=true)
# The original post pinned an exact version of each package; those version numbers are not shown here.
Pkg.@pkg_str"add Bijectors Distributions LinearAlgebra Mooncake"
using Bijectors, Distributions, LinearAlgebra, Mooncake

function f(θ)
    layer = PlanarLayer(θ[1:2], θ[3:4], θ[5:5])
    flow = transformed(MvNormal(zeros(2), I), layer)
    x = θ[6:7]
    return logpdf(flow.dist, x) - logabsdetjac(flow.transform, x)
end

rule = Mooncake.build_rrule(Tuple{typeof(f), Vector{Float64}})

We need either a general strategy for working around this whenever it is encountered, or a patch to the compiler internals that we can apply until it is fixed upstream. Unfortunately, just fixing CC.populate_def_use_map! does not appear to resolve the problem, so I'm assuming that this is only the surface of the problem. On Slack, Keno mentioned that the code path on which this occurs is likely under-tested :(

edit: I've managed to find a more minimal reproducer:

using Pkg
Pkg.activate(; temp=true)
# The original post pinned an exact Mooncake version; the version number is not shown here.
Pkg.@pkg_str"add Mooncake"
using Mooncake

function h(θ)
    d = [0.0, 0.0]
    x = θ[1:2]
    return d
end
Mooncake.build_rrule(Tuple{typeof(h), Vector{Float64}})

willtebbutt commented Feb 6, 2025

I've managed to figure out what's going on here, and have discussed it in a comment on the upstream Julia issue.

The upshot is that you can patch this with the following code:

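# Apply fix from Jules Merck.
# `CC` must refer to Core.Compiler for the `@eval CC` calls below to work.
const CC = Core.Compiler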
@eval CC function populate_def_use_map!(tpdum::TwoPhaseDefUseMap, scanner::BBScanner)
    scan!(scanner, false) do inst::Instruction, lstmt::Int, bb::Int
        for ur in userefs(inst[:stmt]) # replace inst with inst[:stmt]
            val = ur[]
            if isa(val, SSAValue)
                push!(tpdum[val.id], inst.idx)
            end
        end
        return true
    end
end

@eval CC function ((; sv)::ScanStmt)(inst::Instruction, lstmt::Int, bb::Int)
    stmt = inst[:stmt]

    if isa(stmt, EnterNode)
        # try/catch not yet modeled
        give_up_refinements!(sv)
        return true # don't bail out early -- replaces `nothing` with `true` 
    end

    scan_non_dataflow_flags!(inst, sv)

    stmt_inconsistent = scan_inconsistency!(inst, sv)

    if stmt_inconsistent
        if !has_flag(inst[:flag], IR_FLAG_NOTHROW)
            # Taint :consistent if this statement may raise since :consistent requires
            # consistent termination. TODO: Separate :consistent_return and :consistent_termination from :consistent.
            sv.all_retpaths_consistent = false
        end
        if inst.idx == lstmt
            if isa(stmt, ReturnNode) && isdefined(stmt, :val)
                sv.all_retpaths_consistent = false
            elseif isa(stmt, GotoIfNot)
                # Conditional Branch with inconsistent condition.
                # If we do not know this function terminates, taint consistency, now,
                # :consistent requires consistent termination. TODO: Just look at the
                # inconsistent region.
                if !sv.result.ipo_effects.terminates
                    sv.all_retpaths_consistent = false
                elseif visit_conditional_successors(sv.lazypostdomtree, sv.ir, bb) do succ::Int
                        return any_stmt_may_throw(sv.ir, succ)
                    end
                    # check if this `GotoIfNot` leads to conditional throws, which taints consistency
                    sv.all_retpaths_consistent = false
                else
                    (; cfg, domtree) = get!(sv.lazyagdomtree)
                    for succ in iterated_dominance_frontier(cfg, BlockLiveness(sv.ir.cfg.blocks[bb].succs, nothing), domtree)
                        if succ == length(cfg.blocks)
                            # Phi node in the virtual exit -> We have a conditional
                            # return. TODO: Check if all the retvals are egal.
                            sv.all_retpaths_consistent = false
                        else
                            visit_bb_phis!(sv.ir, succ) do phiidx::Int
                                push!(sv.inconsistent, phiidx)
                            end
                        end
                    end
                end
            end
        end
    end

    # Do not bail out early, as this can cause tpdum counts to be off.
    # # bail out early if there are no possibilities to refine the effects
    # if !any_refinable(sv)
    #     return nothing
    # end

    return true
end

@eval CC function scan_inconsistency!(inst::Instruction, sv::PostOptAnalysisState)
    flag = inst[:flag]
    stmt_inconsistent = !has_flag(flag, IR_FLAG_CONSISTENT)
    stmt = inst[:stmt]
    # Special case: For `getfield` and memory operations, we allow inconsistency of the :boundscheck argument
    (; inconsistent, tpdum) = sv
    # Main.@show stmt
    if iscall_with_boundscheck(stmt, sv)
        for i = 1:length(stmt.args) # explore all args -- don't assume boundscheck is not an SSA
            val = stmt.args[i]
            if isa(val, SSAValue)
                stmt_inconsistent |= val.id in inconsistent
                count!(tpdum, val)
            end
        end
    else
        for ur in userefs(stmt)
            val = ur[]
            if isa(val, SSAValue)
                stmt_inconsistent |= val.id in inconsistent
                count!(tpdum, val)
            end
        end
    end
    stmt_inconsistent && push!(inconsistent, inst.idx)
    return stmt_inconsistent
end

I've verified locally that this seems to fix the problem (there are two commented-out examples tagged with this issue's number, and both are fixed by the above patch).

edit2: patches updated to cover another couple of holes I found while trying to get this to work on TemporalGPs.jl. See the comment on the Julia issue linked above for more discussion.
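
To illustrate why an early bail-out during scanning matters, here is a toy sketch. It is not Core.Compiler's `TwoPhaseDefUseMap`; `ToyDefUseMap`, `populate!`, `kill_use!`, and the `stop_after` keyword are hypothetical names invented for this example. It only models the general idea: if the scan that records uses stops before reaching every statement, a later attempt to remove a use that was never recorded trips the same kind of assertion as in the stack trace.

```julia
# Toy model only: a simplified stand-in, not Core.Compiler's TwoPhaseDefUseMap.
struct ToyDefUseMap
    uses::Vector{Vector{Int}}  # uses[def] lists the statements that read SSA value `def`
end

ToyDefUseMap(nstmts::Int) = ToyDefUseMap([Int[] for _ in 1:nstmts])

# `stmt_args[i]` holds the SSA ids read by statement `i`. If `stop_after` is
# smaller than the number of statements, the scan bails out early.
function populate!(m::ToyDefUseMap, stmt_args::Vector{Vector{Int}};
                   stop_after::Int=length(stmt_args))
    for (idx, args) in enumerate(stmt_args)
        idx > stop_after && break
        for def in args
            push!(m.uses[def], idx)
        end
    end
    return m
end

# Remove the record that statement `use` reads SSA value `def`.
function kill_use!(m::ToyDefUseMap, def::Int, use::Int)
    useidx = findfirst(==(use), m.uses[def])
    useidx === nothing && throw(AssertionError("useidx !== nothing"))
    deleteat!(m.uses[def], useidx)
    return m
end

# %1 reads nothing, %2 reads %1, %3 reads %1 and %2.
stmt_args = [Int[], [1], [1, 2]]

# Full scan: the use of %1 by %3 was recorded, so removing it succeeds.
complete = populate!(ToyDefUseMap(3), stmt_args)
kill_use!(complete, 1, 3)

# Truncated scan: %3 was never visited, so its use of %1 is missing.
truncated = populate!(ToyDefUseMap(3), stmt_args; stop_after=2)
failed = try
    kill_use!(truncated, 1, 3)
    false
catch err
    err isa AssertionError
end
```

In this toy, `failed` ends up `true` only for the truncated scan, which mirrors the patch's change from returning `nothing` to returning `true` so that every statement is visited.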

willtebbutt commented Feb 6, 2025

Since this problem occurs in the compiler, the best option for users at the minute is literally to copy + paste the above patch as part of their session (if they encounter this problem). Now that I think I know what's going on, I'm going to push to get the fix in and released in the next patch versions of 1.10 and 1.11. This issue needs to remain open until this happens, so it will remain open for at least a couple more months.
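
To decide whether a session actually needs the patch, one option is to wrap the failing call and check for this specific assertion. The helper below is a hypothetical sketch, not part of Mooncake; it assumes the error surfaces as a bare `AssertionError` with the message `useidx !== nothing`, as in the stack trace at the top of this issue.

```julia
# Hypothetical helper (not part of Mooncake): run `thunk` and report whether it
# fails with the `useidx !== nothing` assertion. Any other error is rethrown.
function hits_useidx_assertion(thunk)
    try
        thunk()
        return false
    catch err
        if err isa AssertionError && occursin("useidx !== nothing", err.msg)
            return true
        end
        rethrow()
    end
end

# Stand-in thunks so the sketch runs without Mooncake installed.
hits_useidx_assertion(() -> throw(AssertionError("useidx !== nothing")))  # true
hits_useidx_assertion(() -> 1 + 1)                                        # false
```

With Mooncake loaded and `h` defined as in the minimal reproducer, the check would be `hits_useidx_assertion(() -> Mooncake.build_rrule(Tuple{typeof(h), Vector{Float64}}))`; if it returns `true`, paste the patch and try again.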
