Skip to content

Commit

Permalink
Add DIRK-IMEX schemes (#106)
Browse files Browse the repository at this point in the history
* First elements of DIRK IMEX

* Cleanup

* Remove thingies

* Cleanups

* Fix BCs / indent error

* Adding first test

* Fix IMEX-Euler syntax and test it

* Add convergence test

* heat -> convection-diffusion

* Add monodomain demo

* Fix typos in demos

* Lighter-weight mass solver

* Rename property

* Less cryptic comment

* Better name

* Add feedback on failed assertion

* Reorganize loops to prep for cases

* Add general finalize method

* Special case when last explicit stage is not needed

* Update demo

* Adding stiffly accurate finalize method

* Introducing factory code

* Typos in demo

* Tweak docstring

* Sign convention
  • Loading branch information
ScottMacLachlan authored Jan 14, 2025
1 parent cbf6af4 commit bfd9468
Show file tree
Hide file tree
Showing 8 changed files with 664 additions and 11 deletions.
10 changes: 5 additions & 5 deletions demos/monodomain/demo_monodomain_FHN.py.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ The basic form of the equation is:
\chi \left( C_m u_t + I_{ion}(u) \right) = \nabla \cdot \sigma \nabla u
where :math:`u` is the membrane potential, :math:`\sigma` is the conductivity tensor, :math:`C_m` is the specific capacitance of the cell membrane, and :math:`\chi` is the surface area to volume ration. The term :math:`I_{ion}` is current due to ionic flows through channels in the cell membranes, and may couple to a complicated reaction network. In our case, we take the relatively simple model due to Fitzhugh and Nagumo. Here, we have a separate concentration variable :math:`c` satisfying the reaction equation:
where :math:`u` is the membrane potential, :math:`\sigma` is the conductivity tensor, :math:`C_m` is the specific capacitance of the cell membrane, and :math:`\chi` is the surface area to volume ratio. The term :math:`I_{ion}` is current due to ionic flows through channels in the cell membranes, and may couple to a complicated reaction network. In our case, we take the relatively simple model due to Fitzhugh and Nagumo. Here, we have a separate concentration variable :math:`c` satisfying the reaction equation:

.. math::
Expand Down Expand Up @@ -58,15 +58,15 @@ Specify the physical constants and initial conditions::
sigma = as_matrix([[sigma1, 0.0], [0.0, sigma2]])

InitialPotential = conditional(x < 3.5, Constant(2.0), Constant(-1.28791))
InitialCell = conditional(And(And(31 <= x, x < 39), And(0 <= y, y < 35)),
initial_potential = conditional(x < 3.5, Constant(2.0), Constant(-1.28791))
initial_cell = conditional(And(And(31 <= x, x < 39), And(0 <= y, y < 35)),
Constant(2.0), Constant(-0.5758))


uu = Function(Z)
vu, vc = TestFunctions(Z)
uu.sub(0).interpolate(InitialPotential)
uu.sub(1).interpolate(InitialCell)
uu.sub(0).interpolate(initial_potential)
uu.sub(1).interpolate(initial_cell)

(u, c) = split(uu)
Expand Down
157 changes: 157 additions & 0 deletions demos/monodomain/demo_monodomain_FHN_dirkimex.py.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
Solving monodomain equations with Fitzhugh-Nagumo reaction and a DIRK-IMEX method
=================================================================================

We're solving monodomain (reaction-diffusion) with a particular reaction term.
The basic form of the equation is:

.. math::
\chi \left( C_m u_t + I_{ion}(u) \right) = \nabla \cdot \sigma \nabla u
where :math:`u` is the membrane potential, :math:`\sigma` is the conductivity tensor, :math:`C_m` is the specific capacitance of the cell membrane, and :math:`\chi` is the surface area to volume ratio. The term :math:`I_{ion}` is current due to ionic flows through channels in the cell membranes, and may couple to a complicated reaction network. In our case, we take the relatively simple model due to Fitzhugh and Nagumo. Here, we have a separate concentration variable :math:`c` satisfying the reaction equation:

.. math::
c_t = \epsilon( u + \beta - \gamma c)
for certain positive parameters :math:`\beta` and :math:`\gamma`, and the current takes the form of:

.. math::
I_{ion}(u, c) = \tfrac{1}{\epsilon} \left( u - \tfrac{u^3}{3} - c \right)
so that we have an overall system of two equations. One of them is linear but stiff/diffusive, and the other is nonstiff but nonlinear. This combination makes the system a good candidate for IMEX-type methods.


We start with standard Firedrake/Irksome imports::

import copy

from firedrake import (And, Constant, File, Function, FunctionSpace,
RectangleMesh, SpatialCoordinate, TestFunctions,
as_matrix, conditional, dx, grad, inner, split)
from irksome import Dt, MeshConstant, DIRK_IMEX, TimeStepper

And we set up the mesh and function space.::
mesh = RectangleMesh(20, 20, 70, 70, quadrilateral=True)
polyOrder = 2
V = FunctionSpace(mesh, "CG", 2)
Z = V * V

x, y = SpatialCoordinate(mesh)
MC = MeshConstant(mesh)
dt = MC.Constant(0.05)
t = MC.Constant(0.0)

Specify the physical constants and initial conditions::

eps = Constant(0.1)
beta = Constant(1.0)
gamma = Constant(0.5)

chi = Constant(1.0)
capacitance = Constant(1.0)

sigma1 = sigma2 = 1.0
sigma = as_matrix([[sigma1, 0.0], [0.0, sigma2]])

initial_potential = conditional(x < 3.5, Constant(2.0), Constant(-1.28791))
initial_cell = conditional(And(And(31 <= x, x < 39), And(0 <= y, y < 35)),
Constant(2.0), Constant(-0.5758))


uu = Function(Z)
vu, vc = TestFunctions(Z)
uu.sub(0).interpolate(initial_potential)
uu.sub(1).interpolate(initial_cell)

(u, c) = split(uu)

This sets up the Butcher tableau. Here, we use the DIRK-IMEX methods proposed
by Ascher, Ruuth, and Spiteri in their 1997 Applied Numerical Mathematics paper.
For this case, We use a four-stage method.::
butcher_tableau = DIRK_IMEX(4, 4, 3)
ns = butcher_tableau.num_stages

To access an IMEX method, we need to separately specify the implicit and explicit parts of the operator.
The part to be handled implicitly is taken to contain the time derivatives as well::
F1 = (inner(chi * capacitance * Dt(u), vu)*dx
+ inner(grad(u), sigma * grad(vu))*dx
+ inner(Dt(c), vc)*dx - inner(eps * u, vc)*dx
- inner(beta * eps, vc)*dx + inner(gamma * eps * c, vc)*dx)

This is the part to be handled explicitly.::
F2 = inner((chi/eps) * (-u + (u**3 / 3) + c), vu)*dx

If we wanted to use a fully implicit method, we would just take
F = F1 + F2.

Now, set up solver parameters. Since we're using a DIRK-IMEX scheme, we can
specify only parameters for each stage. We use an additive Schwarz (fieldsplit) method that applies AMG to the potential block and incomplete Cholesky to the cell block independently for each stage::
params = {"snes_type": "ksponly",
"ksp_monitor": None,
"mat_type": "aij",
"ksp_type": "fgmres",
"pc_type": "fieldsplit",
"pc_fieldsplit_type": "additive",
"fieldsplit_0": {
"ksp_type": "preonly",
"pc_type": "gamg",
},
"fieldsplit_1": {
"ksp_type": "preonly",
"pc_type": "icc",
}}


The DIRK-IMEX schemes also require a mass-matrix solver. Here, we just use an incomplete Cholesky preconditioner for CG on the coupled system, which works fine.::

mass_params = {"snes_type": "ksponly",
"ksp_rtol": 1.e-8,
"ksp_monitor": None,
"mat_type": "aij",
"ksp_type": "cg",
"pc_type": "icc",
}

Now, we access the IMEX method via the `TimeStepper` as with other methods. Note that we specify somewhat different kwargs, needing to specify the implicit and explicit parts separately as well as separate solver options for the implicit and mass solvers.::
stepper = TimeStepper(F1, butcher_tableau, t, dt, uu,
stage_type="dirkimex",
solver_parameters=params,
mass_parameters=mass_params,
Fexp=F2)

uFinal, cFinal = uu.split()
outfile1 = File("FHN_results/FHN_2d_u.pvd")
outfile2 = File("FHN_results/FHN_2d_c.pvd")
outfile1.write(uFinal, time=0)
outfile2.write(cFinal, time=0)

for j in range(12):
print(f"{float(t)}")
stepper.advance()
t.assign(float(t) + float(dt))

if (j % 5 == 0):
outfile1.write(uFinal, time=j * float(dt))
outfile2.write(cFinal, time=j * float(dt))


We can print out some solver statistics here. We expect one implicit solve per stage per timestep, and that's what we see with the four-stage method. For this Butcher Tableau, we can avoid computing the final explicit stage (since it's coefficient in the next stage reconstruction is zero), so we see the same number of mass solves.::

nsteps, n_nonlin, n_lin, n_nonlin_mass, n_lin_mass = stepper.solver_stats()
print(f"Time steps taken: {nsteps}")
print(f" {n_nonlin} nonlinear steps in implicit stage solves (should be {nsteps*ns})")
print(f" {n_lin} linear steps in implicit stage solves")
print(f" {n_nonlin_mass} nonlinear steps in mass solves (should be {nsteps*ns})")
print(f" {n_lin_mass} linear steps in mass solves")

3 changes: 2 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -106,12 +106,13 @@ and for adaptive IRK methods:
demos/demo_heat_adapt.py


Or check out an IMEX-type method for the monodomain equations:
Or check out two IMEX-type methods for the monodomain equations:

.. toctree::
:maxdepth: 1

demos/demo_monodomain_FHN.py
demos/demo_monodomain_FHN_dirkimex.py

Advanced demos
--------------
Expand Down
2 changes: 2 additions & 0 deletions irksome/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@
from .ButcherTableaux import RadauIIA # noqa: F401
from .pep_explicit_rk import PEPRK # noqa: F401
from .deriv import Dt # noqa: F401
from .dirk_imex_tableaux import DIRK_IMEX # noqa: F401
from .dirk_stepper import DIRKTimeStepper # noqa: F401
from .getForm import getForm # noqa: F401
from .imex import RadauIIAIMEXMethod # noqa: F401
from .imex import DIRKIMEXMethod # noqa: F401
from .pc import RanaBase, RanaDU, RanaLD # noqa: F401
from .pc import IRKAuxiliaryOperatorPC # noqa: F401
from .stage import StageValueTimeStepper # noqa: F401
Expand Down
71 changes: 71 additions & 0 deletions irksome/dirk_imex_tableaux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from .ButcherTableaux import ButcherTableau
import numpy as np

# For the implicit scheme, the full Butcher Table is given as A, b, c.

# For the explicit scheme, the full b_hat and c_hat are given, but (to
# avoid a lot of offset-by-ones in the code we store only the
# lower-left ns x ns block of A_hat

# IMEX Butcher tableau for 1 stage
imex111A = np.array([[1.0]])
imex111A_hat = np.array([[1.0]])
imex111b = np.array([1.0])
imex111b_hat = np.array([1.0, 0.0])
imex111c = np.array([1.0])
imex111c_hat = np.array([0.0, 1.0])


# IMEX Butcher tableau for s = 2
gamma = (2 - np.sqrt(2)) / 2
delta = -2 * np.sqrt(2) / 3
imex232A = np.array([[gamma, 0], [1 - gamma, gamma]])
imex232A_hat = np.array([[gamma, 0], [delta, 1 - delta]])
imex232b = np.array([1 - gamma, gamma])
imex232b_hat = np.array([0, 1 - gamma, gamma])
imex232c = np.array([gamma, 1.0])
imex232c_hat = np.array([0, gamma, 1.0])

# IMEX Butcher tableau for 3 stages
imex343A = np.array([[0.4358665215, 0, 0], [0.2820667392, 0.4358665215, 0], [1.208496649, -0.644363171, 0.4358665215]])
imex343A_hat = np.array([[0.4358665215, 0, 0], [0.3212788860, 0.3966543747, 0], [-0.105858296, 0.5529291479, 0.5529291479]])
imex343b = np.array([1.208496649, -0.644363171, 0.4358665215])
imex343b_hat = np.array([0, 1.208496649, -0.644363171, 0.4358665215])
imex343c = np.array([0.4358665215, 0.7179332608, 1])
imex343c_hat = np.array([0, 0.4358665215, 0.7179332608, 1.0])


# IMEX Butcher tableau for 4 stages
imex443A = np.array([[1/2, 0, 0, 0],
[1/6, 1/2, 0, 0],
[-1/2, 1/2, 1/2, 0],
[3/2, -3/2, 1/2, 1/2]])
imex443A_hat = np.array([[1/2, 0, 0, 0],
[11/18, 1/18, 0, 0],
[5/6, -5/6, 1/2, 0],
[1/4, 7/4, 3/4, -7/4]])
imex443b = np.array([3/2, -3/2, 1/2, 1/2])
imex443b_hat = np.array([1/4, 7/4, 3/4, -7/4, 0])
imex443c = np.array([1/2, 2/3, 1/2, 1])
imex443c_hat = np.array([0, 1/2, 2/3, 1/2, 1])

dirk_imex_dict = {
(1, 1, 1): (imex111A, imex111b, imex111c, imex111A_hat, imex111b_hat, imex111c_hat),
(2, 3, 2): (imex232A, imex232b, imex232c, imex232A_hat, imex232b_hat, imex232c_hat),
(3, 4, 3): (imex343A, imex343b, imex343c, imex343A_hat, imex343b_hat, imex343c_hat),
(4, 4, 3): (imex443A, imex443b, imex443c, imex443A_hat, imex443b_hat, imex443c_hat)
}


class DIRK_IMEX(ButcherTableau):
def __init__(self, ns_imp, ns_exp, order):
try:
A, b, c, A_hat, b_hat, c_hat = dirk_imex_dict[ns_imp, ns_exp, order]
except KeyError:
raise NotImplementedError("No DIRK-IMEX method for that combination of implicit and explicit stages and order")
self.order = order
super(DIRK_IMEX, self).__init__(A, b, None, c, order, None, None)
self.A_hat = A_hat
self.b_hat = b_hat
self.c_hat = c_hat
self.is_dirk_imex = True # Mark this as a DIRK-IMEX scheme
Loading

0 comments on commit bfd9468

Please sign in to comment.