Skip to content

Commit

Permalink
Merge pull request #1090 from pmgbergen/value-of-ad-functions
Browse files Browse the repository at this point in the history
BUG: Allow .value calls for ad functions
  • Loading branch information
IvarStefansson authored Jan 4, 2024
2 parents a476249 + 482ff19 commit 5cd06cc
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 52 deletions.
76 changes: 38 additions & 38 deletions src/porepy/numerics/ad/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,19 @@
"""
from __future__ import annotations

from typing import Callable
from typing import Callable, TypeVar

import numpy as np
import scipy.sparse as sps

import porepy as pp
from porepy.numerics.ad.forward_mode import AdArray

FloatType = TypeVar("FloatType", AdArray, np.ndarray, float)

__all__ = [
"exp",
"log",
"sign",
"abs",
"l2_norm",
"sin",
Expand All @@ -54,8 +55,8 @@
]


# %% Exponential and logarithmic functions
def exp(var):
# Exponential and logarithmic functions
def exp(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.exp(var.val)
der = var._diagvec_mul_jac(np.exp(var.val))
Expand All @@ -64,7 +65,7 @@ def exp(var):
return np.exp(var)


def log(var):
def log(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.log(var.val)
der = var._diagvec_mul_jac(1 / var.val)
Expand All @@ -73,18 +74,13 @@ def log(var):
return np.log(var)


# %% Sign and absolute value functions and l2_norm
def sign(var):
if not isinstance(var, AdArray):
return np.sign(var)
else:
return np.sign(var.val)
# Absolute value and l2_norm


def abs(var):
def abs(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.abs(var.val)
jac = var._diagvec_mul_jac(sign(var))
jac = var._diagvec_mul_jac(np.sign(var.val))
return AdArray(val, jac)
else:
return np.abs(var)
Expand All @@ -103,13 +99,15 @@ def l2_norm(dim: int, var: pp.ad.AdArray) -> pp.ad.AdArray:
Parameters:
dim: Dimension, i.e. number of vector components.
var: Ad operator (variable or expression) which is argument of the norm function.
var: Ad operator which is argument of the norm function.
Returns:
The norm of var with appropriate val and jac attributes.
"""

if not isinstance(var, AdArray):
resh = np.reshape(var, (dim, -1), order="F")
return np.linalg.norm(resh, axis=0)
if dim == 1:
# For scalar variables, the cell-wise L2 norm is equivalent to
# taking the absolute value.
Expand All @@ -124,8 +122,8 @@ def l2_norm(dim: int, var: pp.ad.AdArray) -> pp.ad.AdArray:
# Prepare for left multiplication with var.jac to yield
# norm(var).jac = var/norm(var) * var.jac
dim_size = var.val.size
# Check that size of var is compatible with the given dimension, e.g. all 'cells' have
# the same number of values assigned
# Check that size of var is compatible with the given dimension, e.g. all 'cells'
# have the same number of values assigned
assert dim_size % dim == 0
size = int(dim_size / dim)
local_inds_t = np.arange(dim_size)
Expand All @@ -141,8 +139,8 @@ def l2_norm(dim: int, var: pp.ad.AdArray) -> pp.ad.AdArray:
return pp.ad.AdArray(vals, jac)


# %% Trigonometric functions
def sin(var):
# Trigonometric functions
def sin(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.sin(var.val)
jac = var._diagvec_mul_jac(np.cos(var.val))
Expand All @@ -151,7 +149,7 @@ def sin(var):
return np.sin(var)


def cos(var):
def cos(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.cos(var.val)
jac = var._diagvec_mul_jac(-np.sin(var.val))
Expand All @@ -160,7 +158,7 @@ def cos(var):
return np.cos(var)


def tan(var):
def tan(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.tan(var.val)
jac = var._diagvec_mul_jac((np.cos(var.val) ** 2) ** (-1))
Expand All @@ -169,7 +167,7 @@ def tan(var):
return np.tan(var)


def arcsin(var):
def arcsin(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.arcsin(var.val)
jac = var._diagvec_mul_jac((1 - var.val**2) ** (-0.5))
Expand All @@ -178,7 +176,7 @@ def arcsin(var):
return np.arcsin(var)


def arccos(var):
def arccos(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.arccos(var.val)
jac = var._diagvec_mul_jac(-((1 - var.val**2) ** (-0.5)))
Expand All @@ -187,7 +185,7 @@ def arccos(var):
return np.arccos(var)


def arctan(var):
def arctan(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.arctan(var.val)
jac = var._diagvec_mul_jac((var.val**2 + 1) ** (-1))
Expand All @@ -196,8 +194,8 @@ def arctan(var):
return np.arctan(var)


# %% Hyperbolic functions
def sinh(var):
# Hyperbolic functions
def sinh(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.sinh(var.val)
jac = var._diagvec_mul_jac(np.cosh(var.val))
Expand All @@ -206,7 +204,7 @@ def sinh(var):
return np.sinh(var)


def cosh(var):
def cosh(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.cosh(var.val)
jac = var._diagvec_mul_jac(np.sinh(var.val))
Expand All @@ -215,7 +213,7 @@ def cosh(var):
return np.cosh(var)


def tanh(var):
def tanh(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.tanh(var.val)
jac = var._diagvec_mul_jac(np.cosh(var.val) ** (-2))
Expand All @@ -224,7 +222,7 @@ def tanh(var):
return np.tanh(var)


def arcsinh(var):
def arcsinh(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.arcsinh(var.val)
jac = var._diagvec_mul_jac((var.val**2 + 1) ** (-0.5))
Expand All @@ -233,7 +231,7 @@ def arcsinh(var):
return np.arcsinh(var)


def arccosh(var):
def arccosh(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.arccosh(var.val)
den1 = (var.val - 1) ** (-0.5)
Expand All @@ -244,7 +242,7 @@ def arccosh(var):
return np.arccosh(var)


def arctanh(var):
def arctanh(var: FloatType) -> FloatType:
if isinstance(var, AdArray):
val = np.arctanh(var.val)
jac = var._diagvec_mul_jac((1 - var.val**2) ** (-1))
Expand All @@ -253,7 +251,7 @@ def arctanh(var):
return np.arctanh(var)


# %% Step and Heaviside functions
# Step and Heaviside functions
def heaviside(var, zerovalue: float = 0.5):
if isinstance(var, AdArray):
return np.heaviside(var.val, zerovalue)
Expand Down Expand Up @@ -306,7 +304,7 @@ def __call__(self, var, zerovalue: float = 0.5):
return np.heaviside(var) # type: ignore


def maximum(var_0: pp.ad.AdArray, var_1: pp.ad.AdArray | np.ndarray) -> pp.ad.AdArray:
def maximum(var_0: FloatType, var_1: FloatType) -> FloatType:
"""Ad maximum function represented as an AdArray.
The arguments can be either AdArrays or ndarrays, this duality is needed to allow
Expand Down Expand Up @@ -387,7 +385,7 @@ def maximum(var_0: pp.ad.AdArray, var_1: pp.ad.AdArray | np.ndarray) -> pp.ad.Ad
if isinstance(jacs[0], (float, int)):
assert np.isclose(jacs[0], 0)
assert np.isclose(jacs[1], 0)
return pp.ad.AdArray(max_val, 0)
return AdArray(max_val, 0)

# Start from var_0, then change entries corresponding to inds.
max_jac = jacs[0].copy()
Expand All @@ -401,10 +399,10 @@ def maximum(var_0: pp.ad.AdArray, var_1: pp.ad.AdArray | np.ndarray) -> pp.ad.Ad
else:
max_jac[inds] = jacs[1][inds]

return pp.ad.AdArray(max_val, max_jac)
return AdArray(max_val, max_jac)


def characteristic_function(tol: float, var: pp.ad.AdArray):
def characteristic_function(tol: float, var: FloatType) -> FloatType:
"""Characteristic function of an ad variable.
Returns 1 if ``var.val`` is within absolute tolerance = ``tol`` of zero.
Expand All @@ -422,8 +420,10 @@ def characteristic_function(tol: float, var: pp.ad.AdArray):
The characteristic function of var with appropriate val and jac attributes.
"""
if not isinstance(var, AdArray):
return np.isclose(var, 0, atol=tol).astype(float)
vals = np.zeros(var.val.size)
zero_inds = np.isclose(var.val, 0, atol=tol)
vals[zero_inds] = 1
vals[zero_inds] = 1.0
jac = sps.csr_matrix(var.jac.shape)
return pp.ad.AdArray(vals, jac)
return AdArray(vals, jac)
14 changes: 0 additions & 14 deletions tests/numerics/ad/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,20 +109,6 @@ def test_log_scalar_times_ad_var():
assert np.all(a.val == [1, 2, 3]) and np.all(a.jac.A == J.A)


# Function: sign
def test_sign_no_advar():
a = np.array([1, -10, 3, -np.pi])
sign = af.sign(a)
assert np.all(sign == [1, -1, 1, -1])


def test_sign_advar():
a = AdArray(np.array([1, -10, 3, -np.pi]), np.eye(4))
sign = af.sign(a)
assert np.all(sign == [1, -1, 1, -1])
assert np.allclose(a.val, [1, -10, 3, -np.pi]) and np.allclose(a.jac, np.eye(4))


# Function: abs
def test_abs_no_advar():
a = np.array([1, -10, 3, -np.pi])
Expand Down

0 comments on commit 5cd06cc

Please sign in to comment.