Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pydantic types for time-varying-scalars and time-varying-arrays. #684

Merged
merged 1 commit into from
Feb 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 88 additions & 0 deletions torax/torax_pydantic/interpolated_param_1d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes and functions for defining interpolated parameters."""

import functools
from typing import Any

import pydantic
from torax import interpolated_param
from torax.torax_pydantic import interpolated_param_common
from torax.torax_pydantic import model_base


class TimeVaryingScalar(interpolated_param_common.TimeVaryingBase):
"""Base class for time interpolated scalar types.

The Pydantic `.model_validate` constructor can accept a variety of input types
defined by the `TimeInterpolatedInput` type. See
https://torax.readthedocs.io/en/latest/configuration.html#time-varying-scalars
for more details.

Attributes:
time: A 1-dimensional NumPy array of times.
value: A NumPy array specifying the values to interpolate.
is_bool_param: If True, the input value is assumed to be a bool and is
converted to a float.
interpolation_mode: An InterpolationMode enum specifying the interpolation
mode to use.
"""

time: model_base.NumpyArray1D
value: model_base.NumpyArray
is_bool_param: bool = False
interpolation_mode: interpolated_param.InterpolationMode = (
interpolated_param.InterpolationMode.PIECEWISE_LINEAR
)

@pydantic.model_validator(mode='before')
@classmethod
def _conform_data(
cls, data: interpolated_param.TimeInterpolatedInput | dict[str, Any]
) -> dict[str, Any]:

if isinstance(data, dict):
# A workaround for https://github.com/pydantic/pydantic/issues/10477.
data.pop('_get_cached_interpolated_param', None)

# This is the standard constructor input. No conforming required.
if set(data.keys()).issubset(cls.model_fields.keys()):
return data # pytype: disable=bad-return-type

time, value, interpolation_mode, is_bool_param = (
interpolated_param.convert_input_to_xs_ys(data)
)
return dict(
time=time,
value=value,
interpolation_mode=interpolation_mode,
is_bool_param=is_bool_param,
)

@functools.cached_property
def _get_cached_interpolated_param(
self,
) -> interpolated_param.InterpolatedVarSingleAxis:
"""Interpolates the input param at time t.

Returns:
A constructed interpolated var.
"""

return interpolated_param.InterpolatedVarSingleAxis(
value=(self.time, self.value),
interpolation_mode=self.interpolation_mode,
is_bool_param=self.is_bool_param,
)
225 changes: 225 additions & 0 deletions torax/torax_pydantic/interpolated_param_2d.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Classes and functions for defining interpolated parameters."""

from collections.abc import Mapping
import functools
from typing import Any
import chex
import pydantic
from torax import interpolated_param
from torax.torax_pydantic import interpolated_param_common
from torax.torax_pydantic import model_base
import xarray as xr


class TimeVaryingArray(interpolated_param_common.TimeVaryingBase):
"""Base class for time interpolated array types.

The Pydantic `.model_validate` constructor can accept a variety of input types
defined by the `TimeRhoInterpolatedInput` type. See
https://torax.readthedocs.io/en/latest/configuration.html#time-varying-arrays
for more details.

Attributes:
value: A mapping of the form `{time: (rho_norm, values), ...}`, where
`rho_norm` and `values` are 1D NumPy arrays of equal length.
rho_interpolation_mode: The interpolation mode to use for the rho axis.
time_interpolation_mode: The interpolation mode to use for the time axis.
rho_norm_grid: The rho norm grid to use for the interpolation.
"""

value: Mapping[float, tuple[model_base.NumpyArray1D, model_base.NumpyArray1D]]
rho_interpolation_mode: interpolated_param.InterpolationMode = (
interpolated_param.InterpolationMode.PIECEWISE_LINEAR
)
time_interpolation_mode: interpolated_param.InterpolationMode = (
interpolated_param.InterpolationMode.PIECEWISE_LINEAR
)
rho_norm_grid: model_base.NumpyArray | None = None

@pydantic.model_validator(mode='before')
@classmethod
def _conform_data(
cls, data: interpolated_param.TimeRhoInterpolatedInput | dict[str, Any]
) -> dict[str, Any]:

if isinstance(data, dict):
# A workaround for https://github.com/pydantic/pydantic/issues/10477.
data.pop('_get_cached_interpolated_param', None)

# This is the standard constructor input. No conforming required.
if set(data.keys()).issubset(cls.model_fields.keys()):
return data

# Potentially parse the interpolation modes from the input.
time_interpolation_mode = (
interpolated_param.InterpolationMode.PIECEWISE_LINEAR
)
rho_interpolation_mode = (
interpolated_param.InterpolationMode.PIECEWISE_LINEAR
)

if isinstance(data, tuple):
if len(data) == 2 and isinstance(data[1], dict):
time_interpolation_mode = interpolated_param.InterpolationMode[
data[1]['time_interpolation_mode'].upper()
]
rho_interpolation_mode = interpolated_param.InterpolationMode[
data[1]['rho_interpolation_mode'].upper()
]
# First element in tuple assumed to be the input.
data = data[0]

if isinstance(data, xr.DataArray):
value = _load_from_xr_array(data)
elif isinstance(data, tuple) and all(
isinstance(v, chex.Array) for v in data
):
value = _load_from_arrays(
data,
)
elif isinstance(data, Mapping) or isinstance(data, (float, int)):
value = _load_from_primitives(data)
else:
raise ValueError('Input to TimeVaryingArray unsupported.')

return dict(
value=value,
time_interpolation_mode=time_interpolation_mode,
rho_interpolation_mode=rho_interpolation_mode,
)

@functools.cached_property
def _get_cached_interpolated_param(
self,
) -> interpolated_param.InterpolatedVarTimeRho:
if self.rho_norm_grid is None:
raise ValueError('grid must be set.')
return interpolated_param.InterpolatedVarTimeRho(
self.value,
rho_norm=self.rho_norm_grid,
time_interpolation_mode=self.time_interpolation_mode,
rho_interpolation_mode=self.rho_interpolation_mode,
)


def _load_from_primitives(
primitive_values: (
Mapping[float, interpolated_param.InterpolatedVarSingleAxisInput]
| float
),
) -> Mapping[float, tuple[chex.Array, chex.Array]]:
"""Loads the data from primitives.

Three cases are supported:
1. A float is passed in, describes constant initial condition profile.
2. A non-nested dict is passed in, it will describe the radial profile for
the initial condition.
3. A nested dict is passed in, it will describe a time-dependent radial
profile providing both initial condition and prescribed values at times beyond

Args:
primitive_values: The python primitive values to load.

Returns:
A mapping from time to (rho_norm, values) where rho_norm and values are both
arrays of equal length.
"""
# Float case.
if isinstance(primitive_values, (float, int)):
primitive_values = {0.0: {0.0: primitive_values}}
# Non-nested dict.
if isinstance(primitive_values, Mapping) and all(
isinstance(v, float) for v in primitive_values.values()
):
primitive_values = {0.0: primitive_values}

if len(set(primitive_values.keys())) != len(primitive_values):
raise ValueError('Indicies in values mapping must be unique.')
if not primitive_values:
raise ValueError('Values mapping must not be empty.')

loaded_values = {}
for t, v in primitive_values.items():
x, y, _, _ = interpolated_param.convert_input_to_xs_ys(v)
loaded_values[t] = (x, y)

return loaded_values


def _load_from_xr_array(
xr_array: xr.DataArray,
) -> Mapping[float, tuple[chex.Array, chex.Array]]:
"""Loads the data from an xr.DataArray."""
if 'time' not in xr_array.coords:
raise ValueError('"time" must be a coordinate in given dataset.')
if interpolated_param.RHO_NORM not in xr_array.coords:
raise ValueError(
f'"{interpolated_param.RHO_NORM}" must be a coordinate in given'
' dataset.'
)
values = {
t: (
xr_array.rho_norm.data,
xr_array.sel(time=t).values,
)
for t in xr_array.time.data
}
return values


def _load_from_arrays(
arrays: tuple[chex.Array, ...],
) -> Mapping[float, tuple[chex.Array, chex.Array]]:
"""Loads the data from numpy arrays.

Args:
arrays: A tuple of (times, rho_norm, values) or (rho_norm, values). - In the
former case times and rho_norm are assumed to be 1D arrays of equal
length, values is a 2D array with shape (len(times), len(rho_norm)). - In
the latter case rho_norm and values are assumed to be 1D arrays of equal
length (shortcut for initial condition profile).

Returns:
A mapping from time to (rho_norm, values)
"""
if len(arrays) == 2:
# Shortcut for initial condition profile.
rho_norm, values = arrays
if len(rho_norm.shape) != 1:
raise ValueError(f'rho_norm must be a 1D array. Given: {rho_norm.shape}.')
if len(values.shape) != 1:
raise ValueError(f'values must be a 1D array. Given: {values.shape}.')
if rho_norm.shape != values.shape:
raise ValueError(
'rho_norm and values must be of the same shape. Given: '
f'{rho_norm.shape} and {values.shape}.'
)
return {0.0: (rho_norm, values)}
if len(arrays) == 3:
times, rho_norm, values = arrays
if len(times.shape) != 1:
raise ValueError(f'times must be a 1D array. Given: {times.shape}.')
if len(rho_norm.shape) != 1:
raise ValueError(f'rho_norm must be a 1D array. Given: {rho_norm.shape}.')
if values.shape != (len(times), len(rho_norm)):
raise ValueError(
'values must be of shape (len(times), len(rho_norm)). Given: '
f'{values.shape}.'
)
return {t: (rho_norm, values[i, :]) for i, t in enumerate(times)}
else:
raise ValueError(f'arrays must be length 2 or 3. Given: {len(arrays)}.')
68 changes: 68 additions & 0 deletions torax/torax_pydantic/interpolated_param_common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common functions and classes for interpolated parameters."""

import abc
import functools
import chex
import pydantic
from torax import interpolated_param
from torax.torax_pydantic import model_base
from typing_extensions import Self


class TimeVaryingBase(model_base.BaseModelMutable):
"""Base class for time varying interpolated parameters."""

def get_value(self, x: chex.Numeric) -> chex.Array:
"""Returns the value of this parameter interpolated at x=time.

Requires self.grid to be set.

Args:
x: An array of times to interpolate at.

Returns:
An array of interpolated values.
"""
return self._get_cached_interpolated_param.get_value(x)

def __eq__(self, other):
"""Custom equality check."""

try:
chex.assert_trees_all_equal(vars(self), vars(other))
return True
except AssertionError:
return False

@functools.cached_property
@abc.abstractmethod
def _get_cached_interpolated_param(
self,
) -> (
interpolated_param.InterpolatedVarSingleAxis
| interpolated_param.InterpolatedVarTimeRho
):
"""Returns the value of this parameter interpolated at x=time."""
...

@pydantic.model_validator(mode='after')
def clear_cached_property(self) -> Self:
try:
del self._get_cached_interpolated_param
except AttributeError:
pass
return self
Loading