Skip to content

Commit

Permalink
BUG: Update Ad operator evaluation
Browse files Browse the repository at this point in the history
Should have been included in previous commits
  • Loading branch information
keileg committed Feb 7, 2025
1 parent f931949 commit a4123a6
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
10 changes: 5 additions & 5 deletions src/porepy/numerics/nonlinear/line_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,14 +341,14 @@ def compute_constraint_weights(
f_1_vals = cast(
np.ndarray,
model.equation_system.evaluate(
constraint_function, x_0 + solution_update * b
constraint_function, state=x_0 + solution_update * b
),
)
f_1 = f_1_vals[crossing_inds]

def f(x):
return model.equation_system.evaluate(
constraint_function, x_0 + solution_update * x
constraint_function, state=x_0 + solution_update * x
)[crossing_inds]

alpha, a, b = self.recursive_spline_interpolation(
Expand Down Expand Up @@ -658,12 +658,12 @@ def constraint_weights(
f_1 = cast(
np.ndarray,
model.equation_system.evaluate(
constraint_function, x_0 + max_weight * solution_update
constraint_function, state=x_0 + max_weight * solution_update
),
)
weight = max_weight
weights = max_weight * np.ones(f_1.shape)
f_0 = model.equation_system.evaluate(constraint_function, x_0)
f_0 = model.equation_system.evaluate(constraint_function, state=x_0)
active_inds = np.ones(f_1.shape, dtype=bool)
for i in range(10):
# Only consider dofs where the constraint indicator has changed sign.
Expand Down Expand Up @@ -711,7 +711,7 @@ def constraint_weights(
f_1 = cast(
np.ndarray,
model.equation_system.evaluate(
constraint_function, x_0 + weight * solution_update
constraint_function, state=x_0 + weight * solution_update
),
)
active_inds = np.logical_and(
Expand Down
20 changes: 13 additions & 7 deletions tests/numerics/ad/test_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,15 +574,17 @@ def _compare_ad_objects(a, b):
inds_var = np.hstack(
[eq_system.dofs_of(eq_system.get_variables([var], [g])) for g in subdomains]
)
assert np.allclose(true_iterate[inds_var], eq_system.evaluate(var_ad, true_iterate))
assert np.allclose(
true_iterate[inds_var], eq_system.evaluate(var_ad, state=true_iterate)
)

# Check evaluation when no state is passed to the parser, and information must
# instead be glued together from the MixedDimensionalGrid
assert np.allclose(true_iterate[inds_var], eq_system.evaluate(var_ad))

# Evaluate the equation using the double iterate
assert np.allclose(
2 * true_iterate[inds_var], eq_system.evaluate(var_ad, double_iterate)
2 * true_iterate[inds_var], eq_system.evaluate(var_ad, state=double_iterate)
)

# Represent the variable on the previous time step. This should be a numpy array
Expand All @@ -593,7 +595,9 @@ def _compare_ad_objects(a, b):

# Also check that state values given to the ad parser are ignored for previous
# values
assert np.allclose(prev_evaluated, eq_system.evaluate(prev_var_ad, double_iterate))
assert np.allclose(
prev_evaluated, eq_system.evaluate(prev_var_ad, state=double_iterate)
)

## Next, test edge variables. This should be much the same as the grid variables,
# so the testing is less thorough.
Expand All @@ -607,7 +611,7 @@ def _compare_ad_objects(a, b):
[eq_system.dofs_of([var]) for var in variable_interfaces]
)
interface_values = np.hstack(
[eq_system.evaluate(var, true_iterate) for var in variable_interfaces]
[eq_system.evaluate(var, state=true_iterate) for var in variable_interfaces]
)
assert np.allclose(
true_iterate[interface_inds],
Expand All @@ -622,11 +626,13 @@ def _compare_ad_objects(a, b):
ind1 = eq_system.dofs_of(eq_system.get_variables([var], [g]))
ind2 = eq_system.dofs_of(eq_system.get_variables([var2], [g]))

assert np.allclose(true_iterate[ind1], eq_system.evaluate(v1, true_iterate))
assert np.allclose(true_iterate[ind2], eq_system.evaluate(v2, true_iterate))
assert np.allclose(true_iterate[ind1], eq_system.evaluate(v1, state=true_iterate))
assert np.allclose(true_iterate[ind2], eq_system.evaluate(v2, state=true_iterate))

v1_prev = v1.previous_timestep()
assert np.allclose(true_state[ind1], eq_system.evaluate(v1_prev, true_iterate))
assert np.allclose(
true_state[ind1], eq_system.evaluate(v1_prev, state=true_iterate)
)


@pytest.mark.parametrize("prev_time", [True, False])
Expand Down

0 comments on commit a4123a6

Please sign in to comment.