
Rework docs about gradients
Jollywatt committed Feb 5, 2025
1 parent f7d6ba0 commit 83be391
Showing 1 changed file with 94 additions and 20 deletions.
114 changes: 94 additions & 20 deletions docs/src/understanding_mooncake/algorithmic_differentiation.md
@@ -410,36 +410,110 @@ This "vector-Jacobian product" expression is commonly used to explain AD, and is

# Directional Derivatives and Gradients

Now we turn to using forwards- and reverse-mode AD to compute the gradient of a function.

Recall that if ``D f[x] : \mathcal{X} \to \mathbb{R}`` is the Frechet derivative discussed above, then ``D f[x](\dot{x})`` is the _directional derivative_ of ``f`` at ``x`` in the ``\dot{x}`` direction.

The _gradient_ of ``f : \mathcal{X} \to \mathbb{R}`` at ``x`` is defined to be the vector ``\nabla f (x) \in \mathcal{X}`` such that
```math
\langle \nabla f (x), \dot{x} \rangle = D f[x](\dot{x})
```
for any direction ``\dot{x}``.
In other words, the vector ``\nabla f (x)`` encodes all the information about the directional derivatives of ``f`` at ``x``, and we use the inner product to retrieve each one.

An alternative characterisation is that ``\nabla f(x)`` is the vector pointing in the direction of steepest ascent of ``f`` at ``x``, with magnitude equal to the directional derivative in that direction; this follows from the Cauchy-Schwarz inequality, since ``D f[x](\dot{x}) = \langle \nabla f (x), \dot{x} \rangle`` is largest over unit-norm ``\dot{x}`` when ``\dot{x}`` points along ``\nabla f (x)``.

_**Aside: The choice of inner product**_

Notice that the value of the gradient depends on how the inner product on ``\mathcal{X}`` is defined.
Indeed, different choices of inner product result in different values of ``\nabla f``.
Adjoints such as ``D f[x]^*`` are also inner product dependent.
However, the actual derivative ``D f[x]`` is of course invariant -- it makes no reference to the inner product.

In practice, Mooncake uses the Euclidean inner product, extended in the "obvious way" to other composite data types (that is, as if everything is flattened and embedded in ``\mathbb{R}^N``).
But we endeavour to keep the discussion general in order to make the role of the inner product explicit.
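
As a concrete illustration of this dependence (the notation ``M``, ``g`` and ``\nabla_M f`` below is introduced just for this example), take ``\mathcal{X} = \mathbb{R}^n`` with the weighted inner product ``\langle u, v \rangle_M := u^\top M v`` for some symmetric positive-definite matrix ``M``, and write ``g`` for the familiar Euclidean gradient.
The defining property of the gradient under ``\langle \cdot, \cdot \rangle_M`` then reads
```math
\langle \nabla_M f(x), \dot{x} \rangle_M = D f[x](\dot{x}) = g^\top \dot{x} \quad \text{for all } \dot{x} ,
```
which forces ``\nabla_M f(x) = M^{-1} g``: a different inner product gives a different gradient, even though the derivative ``D f[x]`` is unchanged.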



#### Computing the gradient from forwards-mode

To compute the gradient in forwards-mode, we need to evaluate the forwards pass ``\dim \mathcal{X}`` times.
We also need to refer to a basis ``\{\mathbf{e}_i\}`` of ``\mathcal{X}`` and its reciprocal basis ``\{\mathbf{e}^i\}`` defined by ``\langle \mathbf{e}_i, \mathbf{e}^j \rangle = \delta_i^j``.
(For any basis there exists such a reciprocal basis, and they are the same if the basis is orthonormal.)
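
As a small numerical aside (assuming the Euclidean inner product on ``\mathbb{R}^n``, and not using anything from Mooncake), a reciprocal basis can be computed by inverting the matrix whose columns are the basis vectors:
```julia
using LinearAlgebra

B = [1.0 1.0; 0.0 1.0]  # columns of B form a (non-orthonormal) basis of ℝ²
R = inv(B)'             # columns of R form the corresponding reciprocal basis

B' * R ≈ I              # true: the (i, j) entry is ⟨𝐞ᵢ, 𝐞ʲ⟩ = δᵢʲ
```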

Equipped with such a pair of bases, we can always decompose a vector ``x = \sum_i x^i \mathbf{e}_i`` into its components ``x^i = \langle x, \mathbf{e}^i \rangle``.
Therefore, the gradient is given by
```math
\nabla f(x)
= \sum_i \langle \nabla f(x), \mathbf{e}^i \rangle \mathbf{e}_i
= \sum_i D f[x](\mathbf{e}^i) \, \mathbf{e}_i
```
where the second equality follows from the gradient's implicit definition.

If the inner product is Euclidean, then ``\mathbf{e}^i = \mathbf{e}_i`` and we can interpret the ``i``th component of ``\nabla f`` as the directional derivative when moving in the ``i``th direction.

_**Example**_

Consider again the Julia function
```julia
f(x::Float64, y::Tuple{Float64, Float64}) = x + y[1] * y[2]
```
corresponding to ``f(x, y) = x + y_1 y_2``.
An orthonormal basis for the function's domain ``\mathbb{R} \times \mathbb{R}^2`` is
```math
\mathbf{e}_1 = \mathbf{e}^1 = (1, (0, 0)), \quad
\mathbf{e}_2 = \mathbf{e}^2 = (0, (1, 0)), \quad
\mathbf{e}_3 = \mathbf{e}^3 = (0, (0, 1))
```
so the gradient is
```math
\begin{align*}
\nabla f(x, y)
&= \sum_i D f[x, y](\mathbf{e}^i) \mathbf{e}_i \\
&= \Big(D f[x, y](1, (0, 0)), \big(D f[x, y](0, (1, 0)), D f[x, y](0, (0, 1))\big)\Big) \\
&= (1, (y_2, y_1))
\end{align*}
```
where we have used the form of ``D f[x, y]`` given [above](#AD-of-a-Julia-function:-a-slightly-less-trivial-example).
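
The same computation can be sketched as code.
This is only an illustration: the directional derivative ``D f[x, y]`` is written out by hand from the worked example above, and the helper names are made up rather than part of Mooncake's interface.
```julia
# Directional derivative of f(x, y) = x + y[1] * y[2], written out by hand:
# D f[x, y](ẋ, ẏ) = ẋ + ẏ[1] * y[2] + y[1] * ẏ[2].
Df(x, y, ẋ, ẏ) = ẋ + ẏ[1] * y[2] + y[1] * ẏ[2]

# One forwards evaluation per basis vector of ℝ × ℝ².
function gradient_via_forwards(x, y)
    e1 = (1.0, (0.0, 0.0))
    e2 = (0.0, (1.0, 0.0))
    e3 = (0.0, (0.0, 1.0))
    components = map(e -> Df(x, y, e[1], e[2]), (e1, e2, e3))
    # Reassemble the components into the same shape as the argument (x, y).
    return (components[1], (components[2], components[3]))
end

gradient_via_forwards(5.0, (2.0, 3.0))  # (1.0, (3.0, 2.0)), i.e. (1, (y₂, y₁))
```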

#### Computing the gradient from reverse-mode
If we perform a single reverse-pass on a function ``f : \mathcal{X} \to \RR`` to obtain ``D f[x]^\ast``, then the gradient is simply
```math
\nabla f (x) = D f[x]^\ast (1) .
```

To show this, note that ``D f [x] (\dot{x}) = \langle 1, D f[x] (\dot{x}) \rangle = \langle D f[x]^\ast (1), \dot{x} \rangle`` using the definition of the adjoint.
Then, the definition of the gradient gives
```math
\langle \nabla f (x), \dot{x} \rangle = \langle D f[x]^\ast (1), \dot{x} \rangle
```
which implies ``\nabla f (x) = D f[x]^\ast (1)`` since ``\dot{x}`` is arbitrary.

_**Example**_

The adjoint derivative of ``f(x, y) = x + y_1 y_2`` (see [above](#AD-of-a-Julia-function:-a-slightly-less-trivial-example)) immediately gives
```math
\nabla f(x, y) = D f[x, y]^\ast (1) = (1, (y_2, y_1)) .
```
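
As code, this is a single application of the adjoint (again just a sketch: the adjoint derivative is written out by hand from the example above, and the helper names are not Mooncake's interface):
```julia
# Adjoint derivative of f(x, y) = x + y[1] * y[2], written out by hand:
# D f[x, y]*(ȳ) = (ȳ, (ȳ * y[2], ȳ * y[1])).
adjoint_Df(x, y, ȳ) = (ȳ, (ȳ * y[2], ȳ * y[1]))

# The gradient is a single reverse pass with ȳ = 1.
gradient_via_reverse(x, y) = adjoint_Df(x, y, 1.0)

gradient_via_reverse(5.0, (2.0, 3.0))  # (1.0, (3.0, 2.0)), i.e. (1, (y₂, y₁))
```
Note the contrast with the forwards-mode sketch above: one pass here, versus ``\dim \mathcal{X}`` passes there.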

_**Aside: Adjoint Derivatives as Gradients**_

It is interesting to note that the value ``D f[x]^\ast (\bar{y})`` returned by performing reverse-mode AD on a function ``f : \mathcal{X} \to \mathcal{Y}`` can always be viewed as the gradient of another function ``F : \mathcal{X} \to \mathbb{R}``.

Let ``F \coloneqq h_{\bar{y}} \circ f`` where ``h_{\bar{y}}(y) = \langle y, \bar{y}\rangle``.
One can show that ``D h_{\bar{y}}[y]^\ast (1) = \bar{y}``, since ``\langle D h_{\bar{y}}[y]^\ast (1), \dot{y} \rangle = \langle 1, D h_{\bar{y}}[y](\dot{y}) \rangle = \langle \bar{y}, \dot{y} \rangle`` for every ``\dot{y}``.
Then, since
```math
\begin{align*}
\langle \nabla F(x), \dot{x} \rangle
&= \langle D F[x]^\ast (1), \dot{x} \rangle \\
&= \langle D f[x]^\ast (D h_{\bar{y}}[f(x)]^\ast (1)), \dot{x} \rangle \\
&= \langle D f[x]^\ast (\bar{y}), \dot{x} \rangle
\end{align*}
```
we have that ``\nabla F(x) = D f[x]^\ast (\bar{y})``.

The consequence is that we can always view the computation performed by reverse-mode AD as computing the gradient of the composition of the function in question and an inner product with the argument to the adjoint.
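
A tiny numerical check of this fact, with a made-up function ``g : \mathbb{R}^2 \to \mathbb{R}^2`` standing in for ``f``, and with the adjoint derivative, ``F``, and ``\nabla F`` all written out by hand rather than computed by Mooncake:
```julia
# g : ℝ² → ℝ² and its adjoint derivative, both written out by hand.
g(x) = (x[1]^2, x[1] * x[2])
adjoint_Dg(x, ȳ) = (2 * x[1] * ȳ[1] + x[2] * ȳ[2], x[1] * ȳ[2])

# F(x) = ⟨g(x), ȳ⟩ for a fixed ȳ, and its gradient, also written out by hand.
ȳ = (2.0, 3.0)
F(x) = g(x)[1] * ȳ[1] + g(x)[2] * ȳ[2]
∇F(x) = (2 * ȳ[1] * x[1] + ȳ[2] * x[2], ȳ[2] * x[1])

x0 = (1.5, -0.5)
adjoint_Dg(x0, ȳ) == ∇F(x0)  # true: D g[x]^*(ȳ) coincides with ∇F(x)
```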



