Skip to content

Commit

Permalink
Consolidate jitting in post_processing to make_outputs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 726429835
  • Loading branch information
jcitrin authored and Torax team committed Feb 13, 2025
1 parent 13d10de commit 6890048
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
7 changes: 6 additions & 1 deletion torax/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,12 @@ def _save_geometry(

# Get the variables from dataclass fields.
for field_name, data in dataclasses.asdict(self.geometry).items():
if "hires" in field_name or not isinstance(data, jax.Array):
if (
"hires" in field_name
or field_name == "geometry_type"
or field_name == "Ip_from_parameters"
or not isinstance(data, jax.Array)
):
continue
data_array = self._pack_into_data_array(
field_name,
Expand Down
7 changes: 1 addition & 6 deletions torax/post_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
}


@jax_utils.jit
def _compute_pressure(
core_profiles: state.CoreProfiles,
) -> tuple[array_typing.ArrayFloat, ...]:
Expand Down Expand Up @@ -88,7 +87,6 @@ def _compute_pressure(
)


@jax_utils.jit
def _compute_pprime(
core_profiles: state.CoreProfiles,
) -> array_typing.ArrayFloat:
Expand Down Expand Up @@ -137,7 +135,6 @@ def _compute_pprime(


# pylint: disable=invalid-name
@jax_utils.jit
def _compute_FFprime(
core_profiles: state.CoreProfiles,
geo: geometry.Geometry,
Expand Down Expand Up @@ -178,7 +175,6 @@ def _compute_FFprime(
# pylint: enable=invalid-name


@jax_utils.jit
def _compute_stored_thermal_energy(
p_el: array_typing.ArrayFloat,
p_ion: array_typing.ArrayFloat,
Expand All @@ -205,7 +201,6 @@ def _compute_stored_thermal_energy(
return wth_el, wth_ion, wth_tot


@jax_utils.jit
def _calculate_integrated_sources(
geo: geometry.Geometry,
core_profiles: state.CoreProfiles,
Expand Down Expand Up @@ -301,7 +296,6 @@ def _calculate_integrated_sources(
return integrated


@jax_utils.jit
def _calculate_q95(
psi_norm_face: array_typing.ArrayFloat,
core_profiles: state.CoreProfiles,
Expand All @@ -320,6 +314,7 @@ def _calculate_q95(
return q95


@jax_utils.jit
def make_outputs(
sim_state: state.ToraxSimState,
geo: geometry.Geometry,
Expand Down

0 comments on commit 6890048

Please sign in to comment.