Skip to content

Commit

Permalink
Unit tests for TGLFInputs
Browse files Browse the repository at this point in the history
  • Loading branch information
theo-brown committed Feb 13, 2025
1 parent 6b62f36 commit e63a27f
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 31 deletions.
1 change: 1 addition & 0 deletions torax/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ class Constants:
epsilon0: chex.Numeric
mu0: chex.Numeric
eps: chex.Numeric
c: chex.Numeric


CONSTANTS: Final[Constants] = Constants(
Expand Down
207 changes: 207 additions & 0 deletions torax/transport_model/tests/tglf_based_transport_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for torax.transport_model.tglf_based_transport_model."""
from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax.numpy as jnp
from torax import core_profile_setters
from torax import state
from torax.config import runtime_params as general_runtime_params
from torax.config import runtime_params_slice
from torax.geometry import geometry
from torax.pedestal_model import pedestal_model as pedestal_model_lib
from torax.pedestal_model import set_tped_nped
from torax.sources import source_models as source_models_lib
from torax.transport_model import tglf_based_transport_model
from torax.transport_model import quasilinear_transport_model
from torax.transport_model import runtime_params as runtime_params_lib


def _get_model_inputs(transport: tglf_based_transport_model.RuntimeParams):
"""Returns the model inputs for testing."""
runtime_params = general_runtime_params.GeneralRuntimeParams()
geo = geometry.build_circular_geometry()
source_models_builder = source_models_lib.SourceModelsBuilder()
source_models = source_models_builder()
pedestal_model_builder = (
set_tped_nped.SetTemperatureDensityPedestalModelBuilder()
)
dynamic_runtime_params_slice = (
runtime_params_slice.DynamicRuntimeParamsSliceProvider(
runtime_params=runtime_params,
transport=transport,
sources=source_models_builder.runtime_params,
pedestal=pedestal_model_builder.runtime_params,
torax_mesh=geo.torax_mesh,
)(
t=runtime_params.numerics.t_initial,
)
)
static_slice = runtime_params_slice.build_static_runtime_params_slice(
runtime_params=runtime_params,
source_runtime_params=source_models_builder.runtime_params,
torax_mesh=geo.torax_mesh,
)
core_profiles = core_profile_setters.initial_core_profiles(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_slice,
geo=geo,
source_models=source_models,
)
return dynamic_runtime_params_slice, geo, core_profiles


class TGLFBasedTransportModelTest(parameterized.TestCase):
"""Unit tests for the `torax.transport_model.tglf_based_transport_model` module."""

def test_tglf_based_transport_model_output_shapes(self):
"""Tests that the core transport output has the right shapes."""
transport = tglf_based_transport_model.RuntimeParams(
**runtime_params_lib.RuntimeParams()
)
transport_model = FakeTGLFBasedTransportModel()
dynamic_runtime_params_slice, geo, core_profiles = _get_model_inputs(
transport
)
pedestal_model = set_tped_nped.SetTemperatureDensityPedestalModel()
pedestal_model_outputs = pedestal_model(
dynamic_runtime_params_slice, geo, core_profiles
)

core_transport = transport_model(
dynamic_runtime_params_slice, geo, core_profiles, pedestal_model_outputs
)
expected_shape = geo.rho_face_norm.shape
self.assertEqual(core_transport.chi_face_ion.shape, expected_shape)
self.assertEqual(core_transport.chi_face_el.shape, expected_shape)
self.assertEqual(core_transport.d_face_el.shape, expected_shape)
self.assertEqual(core_transport.v_face_el.shape, expected_shape)

def test_tglf_based_transport_model_prepare_tglf_inputs_shapes(self):
"""Tests that the tglf inputs have the expected shapes."""
transport = tglf_based_transport_model.RuntimeParams(
**runtime_params_lib.RuntimeParams()
)
dynamic_runtime_params_slice, geo, core_profiles = _get_model_inputs(
transport
)
transport_model = FakeTGLFBasedTransportModel()
tglf_inputs = transport_model._prepare_tglf_inputs(
Zeff_face=dynamic_runtime_params_slice.plasma_composition.Zeff_face,
q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor,
geo=geo,
core_profiles=core_profiles,
)

# Inputs that are 1D
vector_keys = [
'chiGB',
'lref_over_lti',
'lref_over_lte',
'lref_over_lne',
'lref_over_lni0',
'lref_over_lni1',
'Ti_over_Te',
'drmaj',
'q',
's_hat',
'nu_ee',
'kappa',
'kappa_shear',
'delta',
'delta_shear',
'beta_e',
'Zeff',
]
# Inputs that are 0D
scalar_keys = ['Rmaj', 'Rmin']

expected_vector_length = geo.rho_face_norm.shape[0]
for key in vector_keys:
try:
self.assertEqual(
getattr(tglf_inputs, key).shape, (expected_vector_length,)
)
except Exception as e:
print(key, getattr(tglf_inputs, key))
raise e
for key in scalar_keys:
self.assertEqual(getattr(tglf_inputs, key).shape, ())


class FakeTGLFBasedTransportModel(
tglf_based_transport_model.TGLFBasedTransportModel
):
"""Fake TGLFBasedTransportModel for testing purposes."""

def __init__(self):
super().__init__()
self._frozen = True

# pylint: disable=invalid-name
def prepare_tglf_inputs(
self,
Zeff_face: chex.Array,
q_correction_factor: chex.Numeric,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
) -> tglf_based_transport_model.TGLFInputs:
"""Exposing prepare_tglf_inputs for testing."""
return self._prepare_tglf_inputs(
Zeff_face=Zeff_face,
q_correction_factor=q_correction_factor,
geo=geo,
core_profiles=core_profiles,
)

# pylint: enable=invalid-name

def _call_implementation(
self,
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
pedestal_model_output: pedestal_model_lib.PedestalModelOutput,
) -> state.CoreTransport:
tglf_inputs = self._prepare_tglf_inputs(
Zeff_face=dynamic_runtime_params_slice.plasma_composition.Zeff_face,
q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor,
geo=geo,
core_profiles=core_profiles,
)

transport = dynamic_runtime_params_slice.transport
# Assert required for pytype.
assert isinstance(
transport,
tglf_based_transport_model.DynamicRuntimeParams,
)

return self._make_core_transport(
qi=jnp.ones(geo.rho_face_norm.shape) * 0.4,
qe=jnp.ones(geo.rho_face_norm.shape) * 0.5,
pfe=jnp.ones(geo.rho_face_norm.shape) * 1.6,
quasilinear_inputs=tglf_inputs,
transport=transport,
geo=geo,
core_profiles=core_profiles,
gradient_reference_length=geo.Rmaj, # TODO
gyrobohm_flux_reference_length=geo.Rmin, # TODO
)


if __name__ == '__main__':
absltest.main()
71 changes: 40 additions & 31 deletions torax/transport_model/tglf_based_transport_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import chex
from jax import numpy as jnp

from torax import geometry
from torax.geometry import geometry
from torax import physics
from torax import state
from torax.constants import CONSTANTS
Expand Down Expand Up @@ -60,8 +60,8 @@ class TGLFInputs(quasilinear_transport_model.QuasilinearInputs):

# Ti/Te
Ti_over_Te: chex.Array
# dRmaj/dr
dRmaj: chex.Array
# drmaj/dr (flux surface centroid major radius gradient)
drmaj: chex.Array
# q
q: chex.Array
# r/q dq/dr
Expand All @@ -88,17 +88,18 @@ class TGLFBasedTransportModel(
"""Base class for TGLF-based transport models."""

def _prepare_tglf_inputs(
self,
Zeff_face: chex.Array,
q_correction_factor: chex.Numeric,
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
) -> TGLFInputs:
## Shorthand 'standard' variables
# Shorthand 'standard' variables
Te_keV = core_profiles.temp_el.face_value()
Te_eV = Te_keV * 1e3
Te_J = Te_keV * CONSTANTS.keV2J
Ti_keV = core_profiles.temp_ion.face_value()
ne = core_profiles.ne * core_profiles.nref
ne = core_profiles.ne.face_value() * core_profiles.nref
# q must be recalculated since in the nonlinear solver psi has intermediate
# states in the iterative solve
q, _ = physics.calc_q_from_psi(
Expand All @@ -107,29 +108,33 @@ def _prepare_tglf_inputs(
q_correction_factor=q_correction_factor,
)

## Reference values used for TGLF-specific normalisation
# https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization
# https://gafusion.github.io/doc/geometry.html#effective-field
# B_unit = 1/r d(psi_tor)/dr = q/r dpsi/dr
# Note: TGLF uses geo.rmid = (Rmax - Rmin)/2 as the radial coordinate
# This means all gradients are calculated w.r.t. rmid
m_D_amu = 2.014 # Mass of deuterium
# Reference values used for TGLF-specific normalisation
# - 'a' in TGLF means the minor radius at the LCFS
# - 'r' in TGLF means the flux surface centroid minor radius. Gradients are
# taken w.r.t. r
# https://gafusion.github.io/doc/tglf/tglf_list.html#rmin-loc
# - B_unit = 1/r d(psi_tor)/dr = q/r dpsi/dr
# https://gafusion.github.io/doc/geometry.html#effective-field
# - c_s (ion sound speed)
# https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization
m_D_amu = 2.014 # Mass of deuterium - TODO: load from lookup table
m_D = m_D_amu * CONSTANTS.mp # Mass of deuterium
c_s = (Te_J / m_D) ** 0.5
a = geo.Rmin[-1] # Minor radius at LCFS
B_unit = q / (geo.rmid) * jnp.gradient(core_profiles.psi, geo.rmid)
a = geo.Rmin # Device minor radius at LCFS
r = geo.rmid_face # Flux surface centroid minor radius
B_unit = q / r * jnp.gradient(core_profiles.psi.face_value(), r)

## Dimensionless gradients, eg lref_over_lti where lref=amin, lti = -ti / (dti/dr)
# Dimensionless gradients
normalized_log_gradients = quasilinear_transport_model.NormalizedLogarithmicGradients.from_profiles(
core_profiles=core_profiles,
radial_coordinate=geo.rmid,
radial_coordinate=geo.rmid, # TODO: Why does this have to be a variable on the cell grid?
reference_length=a,
)

## Dimensionless temperature ratio
# Dimensionless temperature ratio
Ti_over_Te = Ti_keV / Te_keV

## Dimensionless electron-electron collision frequency = nu_ee / (c_s/a)
# Dimensionless electron-electron collision frequency = nu_ee / (c_s/a)
# https://gafusion.github.io/doc/tglf/tglf_list.html#xnue
# https://gafusion.github.io/doc/cgyro/cgyro_list.html#cgyro-nu-ee
# Note: In the TGLF docs, XNUE is mislabelled as electron-ion collision frequency.
Expand All @@ -143,35 +148,39 @@ def _prepare_tglf_inputs(
)
nu_ee = normalised_nu_ee / (c_s / a)

## Safety factor, q
# Safety factor, q
# https://gafusion.github.io/doc/tglf/tglf_list.html#q-sa
# defined before

## Safety factor shear, s_hat = r/q dq/dr
# Safety factor shear, s_hat = r/q dq/dr
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-shat-sa
# Note: calc_s_from_psi_rmid gives rq dq/dr, so we divide by q**2
# r_mid = r
s_hat = physics.calc_s_from_psi_rmid(geo, core_profiles.psi) / q**2

## Electron beta
# Electron beta
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-betae
# Note: Te in eV
beta_e = 8 * jnp.pi * ne * Te_eV / B_unit**2

## Major radius shear = dRmaj/dr
# Major radius shear = drmaj/drmin, where 'rmaj' is the flux surface centroid
# major radius and 'rmin' the flux surface centroid minor radius
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-drmajdx-loc
dRmaj = jnp.gradient(geo.Rmaj, geo.rmid)
rmaj = (
geo.Rin_face + geo.Rout_face
) / 2 # Flux surface centroid maj radius
drmaj = jnp.gradient(rmaj, r)

## Elongation shear = r/kappa dkappa/dr
# Elongation shear = r/kappa dkappa/dr
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-s-kappa-loc
kappa = geo.elongation_face
kappa_shear = geo.rmid_face / kappa * jnp.gradient(kappa, geo.rmid_face)
kappa_shear = geo.rmid_face / kappa * jnp.gradient(kappa, r)

## Triangularity shear = r ddelta/dr
# Triangularity shear = r ddelta/dr
# https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-s-delta-loc
delta = geo.delta_face
delta_shear = geo.rmid_face * jnp.gradient(delta, geo.rmid_face)
delta_shear = r * jnp.gradient(geo.delta_face, r)

## Gyrobohm diffusivity
# Gyrobohm diffusivity
# https://gafusion.github.io/doc/tglf/tglf_table.html#id7
# https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization
# Note: TGLF uses the same normalisation as CGYRO
Expand All @@ -198,13 +207,13 @@ def _prepare_tglf_inputs(
lref_over_lni1=normalized_log_gradients.lref_over_lni1,
# From TGLFInputs
Ti_over_Te=Ti_over_Te,
dRmaj=dRmaj,
drmaj=drmaj,
q=q,
s_hat=s_hat,
nu_ee=nu_ee,
kappa=kappa,
kappa_shear=kappa_shear,
delta=delta,
delta=geo.delta_face,
delta_shear=delta_shear,
beta_e=beta_e,
Zeff=Zeff_face,
Expand Down

0 comments on commit e63a27f

Please sign in to comment.