Skip to content

Commit

Permalink
Rework jax compilation order
Browse files Browse the repository at this point in the history
  • Loading branch information
HCookie committed Dec 11, 2024
1 parent 91d43c6 commit 82580b1
Showing 1 changed file with 2 additions and 7 deletions.
9 changes: 2 additions & 7 deletions src/ai_models_gencast/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,9 @@ def run_forward(
LOG.info("Model license: %s", self.ckpt.license)

jax.jit(self._with_configs(run_forward.init))
run_forward_jitted = jax.jit(
lambda rng, inputs, targets_template, forcings: run_forward.apply(
rng, inputs, targets_template, forcings
)
)
# We also produce a pmapped version for running in parallel.

self.model = xarray_jax.pmap(
self._with_params(self._with_configs(self._drop_state(run_forward_jitted))), dim="sample"
jax.jit(self._with_params(self._with_configs(self._drop_state(run_forward.apply)))), dim="sample"
)

def run(self):
Expand Down

0 comments on commit 82580b1

Please sign in to comment.