Skip to content

Commit

Permalink
Switch geometry to using IntEnums.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726053299
  • Loading branch information
sbodenstein authored and Torax team committed Feb 12, 2025
1 parent ca3a492 commit aca365c
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 25 deletions.
2 changes: 1 addition & 1 deletion torax/geometry/circular_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def build_circular_geometry(

return CircularAnalyticalGeometry(
# Set the standard geometry params.
geometry_type=geometry.GeometryType.CIRCULAR.value,
geometry_type=geometry.GeometryType.CIRCULAR,
torax_mesh=mesh,
Phi=Phi,
Phi_face=Phi_face,
Expand Down
34 changes: 13 additions & 21 deletions torax/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
import jax
import jax.numpy as jnp
import numpy as np
from torax import jax_utils
import pydantic
from torax.torax_pydantic import torax_pydantic


@chex.dataclass(frozen=True)
class Grid1D:
class Grid1D(torax_pydantic.BaseModelFrozen):
"""Data structure defining a 1-D grid of cells with faces.
Construct via `construct` classmethod.
Expand All @@ -41,16 +41,10 @@ class Grid1D:
cell_centers: Coordinates of cell centers.
"""

nx: int
dx: float
face_centers: chex.Array
cell_centers: chex.Array

def __post_init__(self):
jax_utils.assert_rank(self.nx, 0)
jax_utils.assert_rank(self.dx, 0)
jax_utils.assert_rank(self.face_centers, 1)
jax_utils.assert_rank(self.cell_centers, 1)
nx: pydantic.PositiveInt
dx: pydantic.PositiveFloat
face_centers: torax_pydantic.NumpyArray1D
cell_centers: torax_pydantic.NumpyArray1D

def __eq__(self, other: Grid1D) -> bool:
return (
Expand Down Expand Up @@ -102,7 +96,7 @@ def face_to_cell(face: chex.Array) -> chex.Array:


@enum.unique
class GeometryType(enum.Enum):
class GeometryType(enum.IntEnum):
"""Integer enum for geometry type.
This type can be used within JAX expressions to access the geometry type
Expand Down Expand Up @@ -130,7 +124,7 @@ class Geometry:
"""

# TODO(b/356356966): extend documentation to define what each attribute is.
geometry_type: int
geometry_type: GeometryType
torax_mesh: Grid1D
Phi: chex.Array
Phi_face: chex.Array
Expand Down Expand Up @@ -248,7 +242,7 @@ def g1_over_vpr_face(self) -> jax.Array:

@property
def g1_over_vpr2_face(self) -> jax.Array:
bulk = self.g1_face[..., 1:] / self.vpr_face[..., 1:]**2
bulk = self.g1_face[..., 1:] / self.vpr_face[..., 1:] ** 2
# Correct value on-axis is 1/rho_b**2
first_element = jnp.ones_like(self.rho_b) / self.rho_b**2
return jnp.concatenate(
Expand All @@ -260,17 +254,15 @@ def z_magnetic_axis(self) -> chex.Numeric:
if z_magnetic_axis is not None:
return z_magnetic_axis
else:
raise ValueError(
'Geometry does not have a z magnetic axis.'
)
raise ValueError('Geometry does not have a z magnetic axis.')


def stack_geometries(geometries: Sequence[Geometry]) -> Geometry:
"""Batch together a sequence of geometries.
Args:
geometries: A sequence of geometries to stack. The geometries must have
the same mesh, geometry type.
geometries: A sequence of geometries to stack. The geometries must have the
same mesh, geometry type.
Returns:
A Geometry object, where each array attribute has an additional
Expand Down
2 changes: 1 addition & 1 deletion torax/geometry/geometry_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def torax_mesh(self) -> geometry.Grid1D:
class TimeDependentGeometryProvider:
"""A geometry provider which holds values to interpolate based on time."""

geometry_type: int
geometry_type: geometry.GeometryType
torax_mesh: geometry.Grid1D
drho_norm: interpolated_param.InterpolatedVarSingleAxis
Phi: interpolated_param.InterpolatedVarSingleAxis
Expand Down
2 changes: 1 addition & 1 deletion torax/geometry/standard_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -1007,7 +1007,7 @@ def build_standard_geometry(
area = rhon_interpolation_func(rho_norm, area_intermediate)

return StandardGeometry(
geometry_type=intermediate.geometry_type.value,
geometry_type=intermediate.geometry_type,
torax_mesh=mesh,
Phi=Phi,
Phi_face=Phi_face,
Expand Down
4 changes: 3 additions & 1 deletion torax/geometry/tests/geometry_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ def test_stack_geometries_error_handling_different_mesh_sizes(self):
def test_stack_geometries_error_handling_different_geometry_types(self):
"""Test different geometry type error handling for stack_geometries."""
geo0 = circular_geometry.build_circular_geometry(Rmaj=1.0, B0=2.0, n_rho=10)
geo_diff_geometry_type = dataclasses.replace(geo0, geometry_type=3)
geo_diff_geometry_type = dataclasses.replace(
geo0, geometry_type=geometry.GeometryType(3)
)
with self.assertRaisesRegex(
ValueError, 'All geometries must have the same geometry type'
):
Expand Down

0 comments on commit aca365c

Please sign in to comment.