Gradient of jnp.linalg.norm does not respect absolute homogeneity. #26248
Labels
bug (Something isn't working)
Description
The norm should satisfy $d = \left\Vert a \mathbf{b}\right\Vert = \left\vert a\right\vert \left\Vert\mathbf{b}\right\Vert$ for scalar $a$ and vector $\mathbf{b}$. Consequently, $\frac{\partial d}{\partial a} = \mathrm{sign}\left(a\right) \left\Vert\mathbf{b}\right\Vert$. However, the gradient of the standard implementation of the $p$-norm, $\left(\sum_{i} \left\vert b_i \right\vert^p\right)^{1/p}$, is not defined at $\mathbf{b} = \mathbf{0}$ because applying the chain rule introduces terms involving negative powers of zero.
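A minimal sketch of the problem (the function and variable names here are illustrative, not from an actual reproduction in this report):

```python
import jax
import jax.numpy as jnp

def scaled_norm(a, b):
    # d = ||a * b||; by absolute homogeneity, d = |a| * ||b||,
    # so dd/da should equal sign(a) * ||b||, which is 0 when b == 0.
    return jnp.linalg.norm(a * b)

b = jnp.zeros(3)
print(scaled_norm(2.0, b))            # 0.0, as expected
print(jax.grad(scaled_norm)(2.0, b))  # expected 0.0, but evaluates to nan
```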
This particular situation arises, for example, in the evaluation of the diagonal of the Matern covariance kernel for Gaussian processes in more than one dimension. Specifically, for heterogeneous `length_scales`, the rescaled distance between two points `x` and `y` is `jnp.linalg.norm((x - y) / length_scales)`, and the derivative fails for `x == y`[^1]. This is not an issue for homogeneous `length_scales` because we can rewrite the expression as `jnp.linalg.norm(x - y) / length_scales`.
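A short sketch of the Gaussian-process case described above (the `length_scales` values and the point `x` are made up for illustration):

```python
import jax
import jax.numpy as jnp

# Hypothetical heterogeneous length scales for a 2-d input.
length_scales = jnp.array([1.0, 2.0])

def rescaled_distance(x, y):
    # Rescaled distance that appears on the diagonal of a Matern kernel.
    return jnp.linalg.norm((x - y) / length_scales)

x = jnp.array([0.5, -0.3])
print(jax.grad(rescaled_distance)(x, x))  # evaluates to [nan nan] at x == y
```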
System info (python version, jaxlib version, accelerator, etc.)
Footnotes
[^1]: If we use a Mahalanobis-style distance rather than just considering the product of kernels in different dimensions.