Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unused attributes from Geometry and remove Circular geometries. #693

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion torax/config/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def _build_circular_geometry_provider(
geometries[time] = circular_geometry.build_circular_geometry(
n_rho=kwargs['n_rho'], **c
)
return circular_geometry.CircularAnalyticalGeometryProvider.create_provider(
return geometry_provider.TimeDependentGeometryProvider.create_provider(
geometries
)
return geometry_provider.ConstantGeometryProvider(
Expand Down
2 changes: 0 additions & 2 deletions torax/config/tests/build_sim.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def test_build_sim_with_full_config(self):
)
with self.subTest('geometry'):
geo = sim.geometry_provider(sim.initial_state.t)
self.assertIsInstance(geo, circular_geometry.CircularAnalyticalGeometry)
self.assertEqual(geo.torax_mesh.nx, 5)
with self.subTest('sources'):
self.assertEqual(
Expand Down Expand Up @@ -219,7 +218,6 @@ def test_build_circular_geometry(self):
)
geo = geo_provider(t=0)
np.testing.assert_array_equal(geo_provider.torax_mesh.nx, 5)
self.assertIsInstance(geo, circular_geometry.CircularAnalyticalGeometry)
np.testing.assert_array_equal(geo.B0, 5.3) # test a default.

def test_build_geometry_from_chease(self):
Expand Down
8 changes: 1 addition & 7 deletions torax/core_profile_setters.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from torax.config import profile_conditions
from torax.config import runtime_params_slice
from torax.fvm import cell_variable
from torax.geometry import circular_geometry
from torax.geometry import geometry
from torax.geometry import standard_geometry
from torax.sources import ohmic_heat_source
Expand Down Expand Up @@ -649,10 +648,7 @@ def _init_psi_and_current(
source_models=source_models,
)
# Calculating j according to nu formula and psi from j.
elif (
isinstance(geo, circular_geometry.CircularAnalyticalGeometry)
or dynamic_runtime_params_slice.profile_conditions.initial_psi_from_j
):
else:
currents = _prescribe_currents_no_bootstrap(
dynamic_runtime_params_slice=dynamic_runtime_params_slice,
static_runtime_params_slice=static_runtime_params_slice,
Expand Down Expand Up @@ -685,8 +681,6 @@ def _init_psi_and_current(
psi,
)
currents = dataclasses.replace(currents, Ip_profile_face=Ip_profile_face)
else:
raise ValueError('Cannot compute psi for given config.')

core_profiles = dataclasses.replace(core_profiles, psi=psi, currents=currents)

Expand Down
44 changes: 5 additions & 39 deletions torax/geometry/circular_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,55 +17,25 @@

from __future__ import annotations

import chex
import numpy as np
from torax import interpolated_param
from torax.geometry import geometry
from torax.geometry import geometry_provider


# Using invalid-name because we are using the same naming convention as the
# external physics implementations
# pylint: disable=invalid-name


@chex.dataclass(frozen=True)
class CircularAnalyticalGeometry(geometry.Geometry):
"""Circular geometry type used for testing only.

Most users should default to using the Geometry class.
"""

elongation_hires: chex.Array


@chex.dataclass(frozen=True)
class CircularAnalyticalGeometryProvider(
geometry_provider.TimeDependentGeometryProvider):
"""Circular geometry type used for testing only.

Most users should default to using the GeometryProvider class.
"""

elongation_hires: interpolated_param.InterpolatedVarSingleAxis

def __call__(self, t: chex.Numeric) -> geometry.Geometry:
"""Returns a Geometry instance at the given time."""
return self._get_geometry_base(t, CircularAnalyticalGeometry)


def build_circular_geometry(
n_rho: int = 25,
elongation_LCFS: float = 1.72,
Rmaj: float = 6.2,
Rmin: float = 2.0,
B0: float = 5.3,
hires_fac: int = 4,
) -> CircularAnalyticalGeometry:
"""Constructs a CircularAnalyticalGeometry.
) -> geometry.Geometry:
"""Constructs a circular Geometry.

This is the standard entrypoint for building a circular geometry, not
CircularAnalyticalGeometry.__init__(). chex.dataclasses do not allow
Geometry.__init__(). chex.dataclasses do not allow
overriding __init__ functions with different parameters than the attributes of
the dataclass, so this builder function lives outside the class.

Expand All @@ -81,7 +51,7 @@ def build_circular_geometry(
calculations.

Returns:
A CircularAnalyticalGeometry instance.
A Geometry instance.
"""
# circular geometry assumption of r/Rmin = rho_norm, the normalized
# toroidal flux coordinate.
Expand Down Expand Up @@ -217,7 +187,7 @@ def build_circular_geometry(
F_hires = np.ones(len(rho_hires)) * B0 * Rmaj
g2g3_over_rhon_hires = 4 * np.pi**2 * vpr_hires * g3_hires * B0 / F_hires

return CircularAnalyticalGeometry(
return geometry.Geometry(
# Set the standard geometry params.
geometry_type=geometry.GeometryType.CIRCULAR.value,
drho_norm=np.asarray(drho_norm),
Expand Down Expand Up @@ -257,13 +227,9 @@ def build_circular_geometry(
# Set the circular geometry-specific params.
elongation=elongation,
elongation_face=elongation_face,
volume_hires=volume_hires,
area_hires=area_hires,
spr_hires=spr_hires,
rho_hires_norm=rho_hires_norm,
rho_hires=rho_hires,
elongation_hires=elongation_hires,
vpr_hires=vpr_hires,
# always initialize Phibdot as zero. It will be replaced once both geo_t
# and geo_t_plus_dt are provided, and set to be the same for geo_t and
# geo_t_plus_dt for each given time interval.
Expand Down
3 changes: 0 additions & 3 deletions torax/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,9 @@ class Geometry:
Rin_face: chex.Array
Rout: chex.Array
Rout_face: chex.Array
volume_hires: chex.Array
area_hires: chex.Array
spr_hires: chex.Array
rho_hires_norm: chex.Array
rho_hires: chex.Array
vpr_hires: chex.Array
Phibdot: chex.Array
_z_magnetic_axis: chex.Array | None

Expand Down
3 changes: 0 additions & 3 deletions torax/geometry/geometry_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,12 +148,9 @@ class TimeDependentGeometryProvider:
Rin_face: interpolated_param.InterpolatedVarSingleAxis
Rout: interpolated_param.InterpolatedVarSingleAxis
Rout_face: interpolated_param.InterpolatedVarSingleAxis
volume_hires: interpolated_param.InterpolatedVarSingleAxis
area_hires: interpolated_param.InterpolatedVarSingleAxis
spr_hires: interpolated_param.InterpolatedVarSingleAxis
rho_hires_norm: interpolated_param.InterpolatedVarSingleAxis
rho_hires: interpolated_param.InterpolatedVarSingleAxis
vpr_hires: interpolated_param.InterpolatedVarSingleAxis
_z_magnetic_axis: interpolated_param.InterpolatedVarSingleAxis | None

@classmethod
Expand Down
6 changes: 0 additions & 6 deletions torax/geometry/standard_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,7 +931,6 @@ def build_standard_geometry(
# V' for volume integrations on face grid
vpr_face = rhon_interpolation_func(rho_face_norm, vpr)
# V' for volume integrations on cell grid
vpr_hires = rhon_interpolation_func(rho_hires_norm, vpr)
vpr = rhon_interpolation_func(rho_norm, vpr)

# S' for area integrals on face grid
Expand Down Expand Up @@ -999,11 +998,9 @@ def build_standard_geometry(
g2g3_over_rhon = rhon_interpolation_func(rho_norm, g2g3_over_rhon)

volume_face = rhon_interpolation_func(rho_face_norm, volume_intermediate)
volume_hires = rhon_interpolation_func(rho_hires_norm, volume_intermediate)
volume = rhon_interpolation_func(rho_norm, volume_intermediate)

area_face = rhon_interpolation_func(rho_face_norm, area_intermediate)
area_hires = rhon_interpolation_func(rho_hires_norm, area_intermediate)
area = rhon_interpolation_func(rho_norm, area_intermediate)

return StandardGeometry(
Expand Down Expand Up @@ -1052,12 +1049,9 @@ def build_standard_geometry(
delta_lower_face=delta_lower_face,
elongation=elongation,
elongation_face=elongation_face,
volume_hires=volume_hires,
area_hires=area_hires,
spr_hires=spr_hires,
rho_hires_norm=rho_hires_norm,
rho_hires=rho_hires,
vpr_hires=vpr_hires,
# always initialize Phibdot as zero. It will be replaced once both geo_t
# and geo_t_plus_dt are provided, and set to be the same for geo_t and
# geo_t_plus_dt for each given time interval.
Expand Down
3 changes: 2 additions & 1 deletion torax/geometry/tests/circular_geometry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
from torax.geometry import circular_geometry
from torax.geometry import geometry
from torax.geometry import geometry_provider


class CircularGeometryTest(absltest.TestCase):
Expand All @@ -40,7 +41,7 @@ def test_build_geometry_provider_from_circular(self):
hires_fac=4,
)
provider = (
circular_geometry.CircularAnalyticalGeometryProvider.create_provider(
geometry_provider.TimeDependentGeometryProvider.create_provider(
{0.0: geo_0, 10.0: geo_1}
)
)
Expand Down
6 changes: 3 additions & 3 deletions torax/sources/runtime_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,16 @@
class Mode(enum.Enum):
"""Defines how to compute the source terms for this source/sink."""

# Source is set to zero always. This is an explicit source by definition.
# Source is set to zero always.
ZERO = 0

# Source values come from a model in code. These terms can be implicit or
# explicit depending on the model implementation.
MODEL_BASED = 1

# Source values come from a pre-determined set of values, that may evolve in
# time. Values can be drawn from a file or an array. These sources are always
# explicit.
# time.
# Prescribed doesn't work for multi term sources.
PRESCRIBED = 2


Expand Down
11 changes: 3 additions & 8 deletions torax/tests/physics.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,10 @@ def test_update_psi_from_j(
geo,
source_models=source_models,
)

# pylint: disable=protected-access
if isinstance(geo, circular_geometry.CircularAnalyticalGeometry):
if isinstance(geo, standard_geometry.StandardGeometry):
psi = geo.psi_from_Ip
else:
currents = core_profile_setters._prescribe_currents_no_bootstrap(
static_slice,
dynamic_runtime_params_slice,
Expand All @@ -147,13 +148,7 @@ def test_update_psi_from_j(
psi = core_profile_setters._update_psi_from_j(
dynamic_runtime_params_slice, geo, currents.jtot_hires
).value
elif isinstance(geo, standard_geometry.StandardGeometry):
psi = geo.psi_from_Ip
else:
raise ValueError(f'Unknown geometry type: {geo.geometry_type}')
# pylint: enable=protected-access
print(psi)

np.testing.assert_allclose(psi, references.psi.value)

@parameterized.parameters([
Expand Down