Activity Analysis #455

Open
3 tasks
willtebbutt opened this issue Feb 3, 2025 · 4 comments
Labels
enhancement (performance) Would reduce the time it takes to run some bit of the code enhancement New feature or request

Comments

@willtebbutt
Member

#452 highlighted a situation in which our lack of activity analysis can cause substantial slow-downs. The purpose of this issue is to sketch out how activity analysis might be implemented in Mooncake.

Todo:

  • explain what activity analysis lets you do
  • explain what you need to consider when implementing it in a mutation-friendly AD system
  • explain how it might be implemented in Mooncake
@willtebbutt willtebbutt added enhancement New feature or request enhancement (performance) Would reduce the time it takes to run some bit of the code labels Feb 3, 2025
@gdalle
Collaborator

gdalle commented Feb 3, 2025

Based on our previous discussion in #452, I feel like supporting inactive arguments might be slightly simpler than we anticipated. We could decide that every argument which is not a (Co)Dual is inactive by default. For read-only objects this is straightforward, and for writable objects we just need to ensure that we don't accidentally write active data into an inactive object. But if the analogy with ForwardDiff is valid, such a write attempt will just error?

@willtebbutt
Member Author

But if the analogy with ForwardDiff is valid, such a write attempt will just error?

Yeah, I think you're right that we can ensure that the behaviour is to error. For example, the rule for setindex! (or whatever the low-level primitive is called) would error if the array being written to is inactive but the value being written to it is active.
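
To make the erroring behaviour concrete, here is a minimal sketch of a write primitive that refuses to store an active value in an inactive array. The `Active` and `Inactive` wrappers and the `write_rule!` function are hypothetical names invented for this illustration; they are not Mooncake types or rules, and tangent handling is left out entirely.

```julia
# Hypothetical wrappers used only for this illustration (not Mooncake's types).
struct Active{T}
    value::T
end
struct Inactive{T}
    value::T
end

# Writing an active value into an inactive array would silently drop
# derivative information, so this method errors instead.
function write_rule!(dest::Inactive{<:Array}, v::Active, i::Int)
    throw(ArgumentError("cannot write an active value into an inactive array"))
end

# Writing inactive data into an inactive array is an ordinary write.
function write_rule!(dest::Inactive{<:Array}, v::Inactive, i::Int)
    dest.value[i] = v.value
    return dest
end

# Writing into an active array is allowed. A real rule would also have to
# update the array's tangent, which this sketch omits.
function write_rule!(dest::Active{<:Array}, v::Union{Active,Inactive}, i::Int)
    dest.value[i] = v.value
    return dest
end
```

Dispatch selects the erroring method purely from the argument types, so the check costs nothing at run time in the cases that are allowed.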

In terms of design, I suspect that we would be better off having a wrapper type that explicitly states that a value is inactive, otherwise we risk issues inside nested AD (I think). This is a detail we can figure out though.

Note that adding this kind of functionality would make it possible to resolve #412. The various SpecialFunctions.jl rules that we cannot currently support all involve read-only data (always floats, I believe) which is typically inactive (this is the case that Zygote can support). It would be very nice to have this resolved.

@yebai
Contributor

yebai commented Feb 3, 2025

One slightly broader question about inactive wrapper types: can we use them to detect and track the loop invariants discussed in #156 (e.g. array refs and induction variables, which are common and often don't contribute to gradients), and then skip pushing these inactive variables to block stacks?

@willtebbutt
Member Author

I'm not sure that we can do this in general, but it will definitely have an effect on what we need to store in adjoints. This will be operation-dependent, so let's consider a couple of examples.

The rule for *(::Float64, ::Float64) is something like

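function rrule!!(::CoDual{typeof(*)}, _x::CoDual{Float64}, _y::CoDual{Float64})
    x = primal(_x)
    y = primal(_y)
    z = x * y
    # The adjoint closes over both x and y, since each is needed for the other's gradient.
    mul_adjoint(dz::Float64) = NoRData(), dz * y, dz * x
    # Return the output CoDual together with the adjoint.
    return CoDual(z, NoFData()), mul_adjoint
end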

Importantly, both x and y are required on the reverse-pass, meaning that mul_adjoint contains 16 bytes of data (two Float64s).

Suppose that y is inactive. This means that we do not need to compute its gradient on the reverse-pass. In this case, the rule would be something like

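function rrule!!(::CoDual{typeof(*)}, _x::CoDual{Float64}, _y::Inactive{Float64})
    x = primal(_x)
    y = primal(_y)
    z = x * y
    # Only y is captured: the gradient w.r.t. y is never computed, so x is not needed.
    mul_adjoint(dz::Float64) = NoRData(), dz * y, GradNotRequired()
    # Return the output CoDual together with the adjoint.
    return CoDual(z, NoFData()), mul_adjoint
end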

The size of the adjoint is now only 8 bytes, so the memory it occupies is halved. This is probably quite valuable if this operation is called inside a large loop. Of course, if both x and y are inactive, then mul_adjoint will be both a singleton and a no-op.
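
The size claims can be checked with a small standalone sketch. It does not use Mooncake: `Active`, `Inactive`, `GradNotRequired` and `mul_rule` are hypothetical stand-ins written for this illustration, and the byte counts assume that Julia stores each captured `Float64` directly in the closure, which it does when the captured variable is assigned only once.

```julia
# Hypothetical stand-ins for illustration only (not Mooncake's types).
struct Active{T}
    value::T
end
struct Inactive{T}
    value::T
end
struct GradNotRequired end

# Both arguments active: the pullback must capture both primals.
function mul_rule(x::Active{Float64}, y::Active{Float64})
    a, b = x.value, y.value
    pb(dz::Float64) = (dz * b, dz * a)
    return a * b, pb
end

# y inactive: only b is needed (for the gradient w.r.t. x), so only b is captured.
function mul_rule(x::Active{Float64}, y::Inactive{Float64})
    a, b = x.value, y.value
    pb(dz::Float64) = (dz * b, GradNotRequired())
    return a * b, pb
end

# Both inactive: the pullback captures nothing at all.
function mul_rule(x::Inactive{Float64}, y::Inactive{Float64})
    pb(dz::Float64) = (GradNotRequired(), GradNotRequired())
    return x.value * y.value, pb
end

_, pb_both = mul_rule(Active(2.0), Active(3.0))
_, pb_one = mul_rule(Active(2.0), Inactive(3.0))
_, pb_none = mul_rule(Inactive(2.0), Inactive(3.0))

# Sizes of the three closures, in bytes.
sizes = (sizeof(pb_both), sizeof(pb_one), sizeof(pb_none))
```

Under the assumption above, `sizes` should come out as `(16, 8, 0)`, and `Base.issingletontype(typeof(pb_none))` should be `true`, matching the three cases described in the text.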

Similarly, the adjoint returned by the rule for getindex(::Array, ::Int) (or whatever the low-level primitive is called that we actually implement rules for) will become a no-op and a singleton if the Array argument is inactive, meaning that it will not appear on the stack.
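
The same effect can be sketched for an indexing primitive. As before, this is an illustration under stated assumptions rather than Mooncake's actual rule: `getindex_rule`, `Active` and `Inactive` are hypothetical, and the gradient buffer `dx` is passed explicitly to keep the example self-contained.

```julia
# Hypothetical wrappers for illustration only (not Mooncake's types).
struct Active{T}
    value::T
end
struct Inactive{T}
    value::T
end

# Active array: the pullback has to remember the gradient buffer and the index
# so that it can accumulate into the right slot on the reverse pass.
function getindex_rule(x::Active{Vector{Float64}}, dx::Vector{Float64}, i::Int)
    pb(dy::Float64) = (dx[i] += dy; nothing)
    return x.value[i], pb
end

# Inactive array: nothing needs to be accumulated, so the pullback captures
# neither the array nor the index.
function getindex_rule(x::Inactive{Vector{Float64}}, i::Int)
    pb(dy::Float64) = nothing
    return x.value[i], pb
end

dx = zeros(3)
y_active, pb_active = getindex_rule(Active([1.0, 2.0, 3.0]), dx, 2)
pb_active(0.5)  # accumulates 0.5 into dx[2]

y_inactive, pb_inactive = getindex_rule(Inactive([1.0, 2.0, 3.0]), 2)
pb_inactive(0.5)  # no-op
```

Because `pb_inactive` has no fields, it is a singleton and there is nothing to push onto a stack for it.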

Is this the kind of thing that you had in mind @yebai ?

3 participants