Skip to content

Commit

Permalink
Fix recompliation of model
Browse files Browse the repository at this point in the history
  • Loading branch information
HCookie committed Dec 10, 2024
1 parent 18e7458 commit 104096d
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/ai_models_gencast/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def run_forward(
LOG.info("Model license: %s", self.ckpt.license)

jax.jit(self._with_configs(run_forward.init))
self.model = self._drop_state(self._with_params(jax.jit(self._with_configs(run_forward.apply))))
self.model = xarray_jax.pmap(
self._drop_state(self._with_params(jax.jit(self._with_configs(run_forward.apply)))), dim="sample"
)

def run(self):

Expand Down Expand Up @@ -290,7 +292,7 @@ def run(self):

chunks = []
for chunk in rollout.chunked_prediction_generator_multiple_runs(
xarray_jax.pmap(self.model, dim="sample"),
self.model,
rngs=rngs,
inputs=input_xr,
targets_template=template * np.nan,
Expand Down

0 comments on commit 104096d

Please sign in to comment.