diff --git a/torax/constants.py b/torax/constants.py index f12ff117..27b64a66 100644 --- a/torax/constants.py +++ b/torax/constants.py @@ -51,6 +51,7 @@ class Constants: epsilon0: chex.Numeric mu0: chex.Numeric eps: chex.Numeric + c: chex.Numeric CONSTANTS: Final[Constants] = Constants( diff --git a/torax/transport_model/tests/tglf_based_transport_model.py b/torax/transport_model/tests/tglf_based_transport_model.py new file mode 100644 index 00000000..016aa827 --- /dev/null +++ b/torax/transport_model/tests/tglf_based_transport_model.py @@ -0,0 +1,207 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for torax.transport_model.tglf_based_transport_model.""" +from absl.testing import absltest +from absl.testing import parameterized +import chex +import jax.numpy as jnp +from torax import core_profile_setters +from torax import state +from torax.config import runtime_params as general_runtime_params +from torax.config import runtime_params_slice +from torax.geometry import geometry +from torax.pedestal_model import pedestal_model as pedestal_model_lib +from torax.pedestal_model import set_tped_nped +from torax.sources import source_models as source_models_lib +from torax.transport_model import tglf_based_transport_model +from torax.transport_model import quasilinear_transport_model +from torax.transport_model import runtime_params as runtime_params_lib + + +def _get_model_inputs(transport: tglf_based_transport_model.RuntimeParams): + """Returns the model inputs for testing.""" + runtime_params = general_runtime_params.GeneralRuntimeParams() + geo = geometry.build_circular_geometry() + source_models_builder = source_models_lib.SourceModelsBuilder() + source_models = source_models_builder() + pedestal_model_builder = ( + set_tped_nped.SetTemperatureDensityPedestalModelBuilder() + ) + dynamic_runtime_params_slice = ( + runtime_params_slice.DynamicRuntimeParamsSliceProvider( + runtime_params=runtime_params, + transport=transport, + sources=source_models_builder.runtime_params, + pedestal=pedestal_model_builder.runtime_params, + torax_mesh=geo.torax_mesh, + )( + t=runtime_params.numerics.t_initial, + ) + ) + static_slice = runtime_params_slice.build_static_runtime_params_slice( + runtime_params=runtime_params, + source_runtime_params=source_models_builder.runtime_params, + torax_mesh=geo.torax_mesh, + ) + core_profiles = core_profile_setters.initial_core_profiles( + dynamic_runtime_params_slice=dynamic_runtime_params_slice, + static_runtime_params_slice=static_slice, + geo=geo, + source_models=source_models, + ) + return dynamic_runtime_params_slice, geo, core_profiles + + +class TGLFBasedTransportModelTest(parameterized.TestCase): + """Unit tests for the `torax.transport_model.tglf_based_transport_model` module.""" + + def test_tglf_based_transport_model_output_shapes(self): + """Tests that the core transport output has the right shapes.""" + transport = tglf_based_transport_model.RuntimeParams( + **runtime_params_lib.RuntimeParams() + ) + transport_model = FakeTGLFBasedTransportModel() + dynamic_runtime_params_slice, geo, core_profiles = _get_model_inputs( + transport + ) + pedestal_model = set_tped_nped.SetTemperatureDensityPedestalModel() + pedestal_model_outputs = pedestal_model( + dynamic_runtime_params_slice, geo, core_profiles + ) + + core_transport = transport_model( + dynamic_runtime_params_slice, geo, core_profiles, pedestal_model_outputs + ) + expected_shape = geo.rho_face_norm.shape + self.assertEqual(core_transport.chi_face_ion.shape, expected_shape) + self.assertEqual(core_transport.chi_face_el.shape, expected_shape) + self.assertEqual(core_transport.d_face_el.shape, expected_shape) + self.assertEqual(core_transport.v_face_el.shape, expected_shape) + + def test_tglf_based_transport_model_prepare_tglf_inputs_shapes(self): + """Tests that the tglf inputs have the expected shapes.""" + transport = tglf_based_transport_model.RuntimeParams( + **runtime_params_lib.RuntimeParams() + ) + dynamic_runtime_params_slice, geo, core_profiles = _get_model_inputs( + transport + ) + transport_model = FakeTGLFBasedTransportModel() + tglf_inputs = transport_model._prepare_tglf_inputs( + Zeff_face=dynamic_runtime_params_slice.plasma_composition.Zeff_face, + q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor, + geo=geo, + core_profiles=core_profiles, + ) + + # Inputs that are 1D + vector_keys = [ + 'chiGB', + 'lref_over_lti', + 'lref_over_lte', + 'lref_over_lne', + 'lref_over_lni0', + 'lref_over_lni1', + 'Ti_over_Te', + 'drmaj', + 'q', + 's_hat', + 'nu_ee', + 'kappa', + 'kappa_shear', + 'delta', + 'delta_shear', + 'beta_e', + 'Zeff', + ] + # Inputs that are 0D + scalar_keys = ['Rmaj', 'Rmin'] + + expected_vector_length = geo.rho_face_norm.shape[0] + for key in vector_keys: + try: + self.assertEqual( + getattr(tglf_inputs, key).shape, (expected_vector_length,) + ) + except Exception as e: + print(key, getattr(tglf_inputs, key)) + raise e + for key in scalar_keys: + self.assertEqual(getattr(tglf_inputs, key).shape, ()) + + +class FakeTGLFBasedTransportModel( + tglf_based_transport_model.TGLFBasedTransportModel +): + """Fake TGLFBasedTransportModel for testing purposes.""" + + def __init__(self): + super().__init__() + self._frozen = True + + # pylint: disable=invalid-name + def prepare_tglf_inputs( + self, + Zeff_face: chex.Array, + q_correction_factor: chex.Numeric, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + ) -> tglf_based_transport_model.TGLFInputs: + """Exposing prepare_tglf_inputs for testing.""" + return self._prepare_tglf_inputs( + Zeff_face=Zeff_face, + q_correction_factor=q_correction_factor, + geo=geo, + core_profiles=core_profiles, + ) + + # pylint: enable=invalid-name + + def _call_implementation( + self, + dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice, + geo: geometry.Geometry, + core_profiles: state.CoreProfiles, + pedestal_model_output: pedestal_model_lib.PedestalModelOutput, + ) -> state.CoreTransport: + tglf_inputs = self._prepare_tglf_inputs( + Zeff_face=dynamic_runtime_params_slice.plasma_composition.Zeff_face, + q_correction_factor=dynamic_runtime_params_slice.numerics.q_correction_factor, + geo=geo, + core_profiles=core_profiles, + ) + + transport = dynamic_runtime_params_slice.transport + # Assert required for pytype. + assert isinstance( + transport, + tglf_based_transport_model.DynamicRuntimeParams, + ) + + return self._make_core_transport( + qi=jnp.ones(geo.rho_face_norm.shape) * 0.4, + qe=jnp.ones(geo.rho_face_norm.shape) * 0.5, + pfe=jnp.ones(geo.rho_face_norm.shape) * 1.6, + quasilinear_inputs=tglf_inputs, + transport=transport, + geo=geo, + core_profiles=core_profiles, + gradient_reference_length=geo.Rmaj, # TODO + gyrobohm_flux_reference_length=geo.Rmin, # TODO + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/torax/transport_model/tglf_based_transport_model.py b/torax/transport_model/tglf_based_transport_model.py index f3de687d..acb17f73 100644 --- a/torax/transport_model/tglf_based_transport_model.py +++ b/torax/transport_model/tglf_based_transport_model.py @@ -16,7 +16,7 @@ import chex from jax import numpy as jnp -from torax import geometry +from torax.geometry import geometry from torax import physics from torax import state from torax.constants import CONSTANTS @@ -60,8 +60,8 @@ class TGLFInputs(quasilinear_transport_model.QuasilinearInputs): # Ti/Te Ti_over_Te: chex.Array - # dRmaj/dr - dRmaj: chex.Array + # drmaj/dr (flux surface centroid major radius gradient) + drmaj: chex.Array # q q: chex.Array # r/q dq/dr @@ -88,17 +88,18 @@ class TGLFBasedTransportModel( """Base class for TGLF-based transport models.""" def _prepare_tglf_inputs( + self, Zeff_face: chex.Array, q_correction_factor: chex.Numeric, geo: geometry.Geometry, core_profiles: state.CoreProfiles, ) -> TGLFInputs: - ## Shorthand 'standard' variables + # Shorthand 'standard' variables Te_keV = core_profiles.temp_el.face_value() Te_eV = Te_keV * 1e3 Te_J = Te_keV * CONSTANTS.keV2J Ti_keV = core_profiles.temp_ion.face_value() - ne = core_profiles.ne * core_profiles.nref + ne = core_profiles.ne.face_value() * core_profiles.nref # q must be recalculated since in the nonlinear solver psi has intermediate # states in the iterative solve q, _ = physics.calc_q_from_psi( @@ -107,29 +108,33 @@ def _prepare_tglf_inputs( q_correction_factor=q_correction_factor, ) - ## Reference values used for TGLF-specific normalisation - # https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization - # https://gafusion.github.io/doc/geometry.html#effective-field - # B_unit = 1/r d(psi_tor)/dr = q/r dpsi/dr - # Note: TGLF uses geo.rmid = (Rmax - Rmin)/2 as the radial coordinate - # This means all gradients are calculated w.r.t. rmid - m_D_amu = 2.014 # Mass of deuterium + # Reference values used for TGLF-specific normalisation + # - 'a' in TGLF means the minor radius at the LCFS + # - 'r' in TGLF means the flux surface centroid minor radius. Gradients are + # taken w.r.t. r + # https://gafusion.github.io/doc/tglf/tglf_list.html#rmin-loc + # - B_unit = 1/r d(psi_tor)/dr = q/r dpsi/dr + # https://gafusion.github.io/doc/geometry.html#effective-field + # - c_s (ion sound speed) + # https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization + m_D_amu = 2.014 # Mass of deuterium - TODO: load from lookup table m_D = m_D_amu * CONSTANTS.mp # Mass of deuterium c_s = (Te_J / m_D) ** 0.5 - a = geo.Rmin[-1] # Minor radius at LCFS - B_unit = q / (geo.rmid) * jnp.gradient(core_profiles.psi, geo.rmid) + a = geo.Rmin # Device minor radius at LCFS + r = geo.rmid_face # Flux surface centroid minor radius + B_unit = q / r * jnp.gradient(core_profiles.psi.face_value(), r) - ## Dimensionless gradients, eg lref_over_lti where lref=amin, lti = -ti / (dti/dr) + # Dimensionless gradients normalized_log_gradients = quasilinear_transport_model.NormalizedLogarithmicGradients.from_profiles( core_profiles=core_profiles, - radial_coordinate=geo.rmid, + radial_coordinate=geo.rmid, # TODO: Why does this have to be a variable on the cell grid? reference_length=a, ) - ## Dimensionless temperature ratio + # Dimensionless temperature ratio Ti_over_Te = Ti_keV / Te_keV - ## Dimensionless electron-electron collision frequency = nu_ee / (c_s/a) + # Dimensionless electron-electron collision frequency = nu_ee / (c_s/a) # https://gafusion.github.io/doc/tglf/tglf_list.html#xnue # https://gafusion.github.io/doc/cgyro/cgyro_list.html#cgyro-nu-ee # Note: In the TGLF docs, XNUE is mislabelled as electron-ion collision frequency. @@ -143,35 +148,39 @@ def _prepare_tglf_inputs( ) nu_ee = normalised_nu_ee / (c_s / a) - ## Safety factor, q + # Safety factor, q # https://gafusion.github.io/doc/tglf/tglf_list.html#q-sa # defined before - ## Safety factor shear, s_hat = r/q dq/dr + # Safety factor shear, s_hat = r/q dq/dr # https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-shat-sa # Note: calc_s_from_psi_rmid gives rq dq/dr, so we divide by q**2 + # r_mid = r s_hat = physics.calc_s_from_psi_rmid(geo, core_profiles.psi) / q**2 - ## Electron beta + # Electron beta # https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-betae # Note: Te in eV beta_e = 8 * jnp.pi * ne * Te_eV / B_unit**2 - ## Major radius shear = dRmaj/dr + # Major radius shear = drmaj/drmin, where 'rmaj' is the flux surface centroid + # major radius and 'rmin' the flux surface centroid minor radius # https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-drmajdx-loc - dRmaj = jnp.gradient(geo.Rmaj, geo.rmid) + rmaj = ( + geo.Rin_face + geo.Rout_face + ) / 2 # Flux surface centroid maj radius + drmaj = jnp.gradient(rmaj, r) - ## Elongation shear = r/kappa dkappa/dr + # Elongation shear = r/kappa dkappa/dr # https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-s-kappa-loc kappa = geo.elongation_face - kappa_shear = geo.rmid_face / kappa * jnp.gradient(kappa, geo.rmid_face) + kappa_shear = geo.rmid_face / kappa * jnp.gradient(kappa, r) - ## Triangularity shear = r ddelta/dr + # Triangularity shear = r ddelta/dr # https://gafusion.github.io/doc/tglf/tglf_list.html#tglf-s-delta-loc - delta = geo.delta_face - delta_shear = geo.rmid_face * jnp.gradient(delta, geo.rmid_face) + delta_shear = r * jnp.gradient(geo.delta_face, r) - ## Gyrobohm diffusivity + # Gyrobohm diffusivity # https://gafusion.github.io/doc/tglf/tglf_table.html#id7 # https://gafusion.github.io/doc/cgyro/outputs.html#output-normalization # Note: TGLF uses the same normalisation as CGYRO @@ -198,13 +207,13 @@ def _prepare_tglf_inputs( lref_over_lni1=normalized_log_gradients.lref_over_lni1, # From TGLFInputs Ti_over_Te=Ti_over_Te, - dRmaj=dRmaj, + drmaj=drmaj, q=q, s_hat=s_hat, nu_ee=nu_ee, kappa=kappa, kappa_shear=kappa_shear, - delta=delta, + delta=geo.delta_face, delta_shear=delta_shear, beta_e=beta_e, Zeff=Zeff_face,