
Gradient of jnp.linalg.norm does not respect absolute homogeneity. #26248

tillahoffmann opened this issue Jan 31, 2025 · 1 comment
tillahoffmann commented Jan 31, 2025

Description

The norm should satisfy $d = \left\Vert a \mathbf{b}\right\Vert = \left\vert a\right\vert \left\Vert\mathbf{b}\right\Vert$ for scalar $a$ and vector $\mathbf{b}$. Consequently, $\frac{\partial d}{\partial a} = \mathrm{sign}\left(a\right) \left\Vert\mathbf{b}\right\Vert$. However, the gradient of the standard implementation of the $p$-norm, $\left(\sum_{i} \left\vert b_i\right\vert^p\right)^{1/p}$, is not defined at $\mathbf{b}=\mathbf{0}$ because the chain rule produces terms involving negative powers of zero.
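
To make the failure concrete, take $p = 2$: the chain rule gives $\frac{\partial}{\partial b_i} \left\Vert\mathbf{b}\right\Vert = \frac{b_i}{\left\Vert\mathbf{b}\right\Vert}$, which evaluates to $0/0$ at $\mathbf{b}=\mathbf{0}$ and therefore propagates through reverse-mode autodiff as nan: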

>>> import jax
>>> from jax import numpy as jnp

>>> def func1(a, x):
...     return jnp.linalg.norm(a * x)


>>> def func2(a, x):
...     return jnp.abs(a) * jnp.linalg.norm(x)


>>> funcs = [func1, func2]
>>> a = 1.3
>>> x = jnp.zeros(2)

>>> for func in funcs:
...     print(f"{func.__name__}({a}, {x}) = {func(a, x)}")
...     print(f"jax.grad({func.__name__})({a}, {x}) = {jax.grad(func)(a, x)}")

func1(1.3, [0. 0.]) = 0.0
jax.grad(func1)(1.3, [0. 0.]) = nan
func2(1.3, [0. 0.]) = 0.0
jax.grad(func2)(1.3, [0. 0.]) = 0.0
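
A common workaround is the double-where trick: mask an all-zero input out of the norm so the backward pass never evaluates a negative power of zero, then restore the correct primal value. This is a sketch, not part of JAX's public API; safe_norm and func3 are illustrative names, and the zero gradient it yields at the origin is one valid subgradient choice.

>>> def safe_norm(x):
...     # Treat an all-zero input specially so the backward pass of the
...     # norm never divides by zero. The inner where feeds the norm a
...     # nonzero placeholder; the outer where restores the true value.
...     is_zero = jnp.all(x == 0)
...     masked = jnp.where(is_zero, jnp.ones_like(x), x)
...     return jnp.where(is_zero, 0.0, jnp.linalg.norm(masked))


>>> def func3(a, x):
...     return safe_norm(a * x)


>>> print(jax.grad(func3)(a, x))  # 0.0 rather than nan
0.0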

This situation arises, for example, when evaluating the diagonal of the Matérn covariance kernel for Gaussian processes in more than one dimension. Specifically, for heterogeneous length_scales, the rescaled distance between two points x and y is jnp.linalg.norm((x - y) / length_scales), and the derivative fails for x == y (see footnote 1). This is not an issue for homogeneous length_scales because the expression can be rewritten as jnp.linalg.norm(x - y) / length_scales.
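
Under the same assumptions, the safe_norm sketch above also repairs this case; scaled_distance and the particular length_scales below are illustrative:

>>> length_scales = jnp.array([0.5, 2.0])

>>> def scaled_distance(x, y):
...     return safe_norm((x - y) / length_scales)


>>> x0 = jnp.array([1.0, -1.0])
>>> print(jax.grad(scaled_distance)(x0, x0))  # zeros rather than nans
[0. 0.]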

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.0
jaxlib: 0.5.0
numpy:  2.2.1
python: 3.11.5 (main, Dec  8 2023, 17:04:09) [Clang 15.0.0 (clang-1500.0.40.1)]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='Tills-MacBook-Pro-3.local', release='24.2.0', version='Darwin Kernel Version 24.2.0: Fri Dec  6 18:40:14 PST 2024; root:xnu-11215.61.5~2/RELEASE_ARM64_T8103', machine='arm64')

Footnotes

  1. If we use a Mahalanobis-style distance rather than just considering the product of kernels in different dimensions.

tillahoffmann added the bug label Jan 31, 2025

jakevdp commented Jan 31, 2025

Ping @mattjj because he daydreams about autodiff corner cases
