Skip to content

Commit

Permalink
Allow for more control over member number
Browse files Browse the repository at this point in the history
- fcst mode
- Control rng seed
  • Loading branch information
HCookie committed Dec 10, 2024
1 parent daacf27 commit 18e7458
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 15 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,15 @@ For the slower CPU usage:
```
pip install -r requirements.txt
```

## Specifying ensemble numbers

There are three ways to control the ensemble members and behaviour of the `GenCast` `ai-model`.

| Description | Args | Result |
| ----------- | ---- | ------ |
| `type=fc`, single member | `--num-ensemble-members 0` | Will create a `grib` file of `type=fc` |
| N members per process with ID = `range(1, num-ensemble-members + 1)` | `--num-ensemble-members $N>1` | N ensemble members created all in the same process, with IDs taken from that range |
| N members per process with controlled ID | `--num-ensemble-members $N>1` `--member-number 1,2...N` | N ensemble members created all in the same process, with IDs controlled by `member-number` |

With these approaches it is possible to create either a single forecast, many ensemble members in a single process, or many ensemble members spread over many processes.
35 changes: 33 additions & 2 deletions src/ai_models_gencast/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,20 @@ def __init__(self, **kwargs):
f"{param}{level}" for param in self.param_level_pl[0] for level in self.param_level_pl[1]
]

if isinstance(self.member_number, str):
self.member_number = list(map(int, self.member_number.split(",")))
elif isinstance(self.member_number, int):
self.member_number = [int(self.member_number)]
elif self.member_number is None:
self.member_number = list(range(1, self.num_ensemble_members + 1))
else:
raise TypeError(f"`member_number` must be a string or int, not {type(self.member_number)}")

if not len(self.member_number) == self.num_ensemble_members:
raise ValueError(
f"Number of ensemble members must match `member_number`,\nNot {self.num_ensemble_members=} and {self.member_number=}"
)

# Jax doesn't seem to like passing configs as args through the jit. Passing it
# in via partial (instead of capture by closure) forces jax to invalidate the
# jit cache if you change configs.
Expand Down Expand Up @@ -207,14 +221,26 @@ def run_forward(

def run(self):

oper_fcst: bool = False
if self.num_ensemble_members == 0:
oper_fcst = True
# Set the number of ensemble members to 1, and id to 0.
self.num_ensemble_members = 1
self.member_number = [0]
self.grib_extra_metadata = {"type": "fc", "stream": "oper"}

if not (self.num_ensemble_members % len(jax.local_devices())) == 0:
raise ValueError(
f"Number of ensemble members must be divisible by number of devices, not {self.num_ensemble_members} and {len(jax.local_devices())}"
)

# Remove extra metadata to save input fields
_metadata = self.grib_extra_metadata
self.grib_extra_metadata = {}
# We ignore 'tp' so that we make sure that step 0 is a field of zero values
self.write_input_fields(self.fields_sfc, ignore=["tp"], accumulations=["tp"])
self.write_input_fields(self.fields_pl)
self.grib_extra_metadata = _metadata

with self.timer("Building model"):
self.load_model()
Expand Down Expand Up @@ -260,7 +286,7 @@ def run(self):
# We fold-in the ensemble member, this way the first N members should always
# match across different runs which take the same inputs
# regardless of total ensemble size.
rngs = np.stack([jax.random.fold_in(rng, i) for i in range(self.num_ensemble_members)], axis=0)
rngs = np.stack([jax.random.fold_in(rng, i) for i in self.member_number], axis=0)

chunks = []
for chunk in rollout.chunked_prediction_generator_multiple_runs(
Expand Down Expand Up @@ -290,6 +316,8 @@ def run(self):
hour_steps=self.hour_steps,
num_ensemble_members=self.num_ensemble_members,
lagged=self.lagged,
oper_fcst=oper_fcst,
member_number=self.member_number,
)

def patch_retrieve_request(self, r):
Expand Down Expand Up @@ -320,7 +348,10 @@ def parse_model_args(self, args):
import argparse

parser = argparse.ArgumentParser("ai-models gencast")
parser.add_argument("--num_ensemble_members", type=int, help="Number of ensemble members to run", default=1)
parser.add_argument("--num-ensemble-members", type=int, help="Number of ensemble members to run", default=1)
parser.add_argument(
"--member-number", help="Member Number, can only be used if `num_ensemble_members` == 1", default=None
)
parser.add_argument("--use-an", action="store_true")
parser.add_argument("--override-constants")
return parser.parse_args(args)
Expand Down
27 changes: 14 additions & 13 deletions src/ai_models_gencast/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def save_output_xarray(
hour_steps,
num_ensemble_members,
lagged,
oper_fcst,
member_number,
):
LOG.info("Converting output xarray to GRIB and saving")

Expand All @@ -41,37 +43,36 @@ def save_output_xarray(
for time in range(lead_time // hour_steps):
for fs in all_fields[: len(all_fields) // len(lagged)]:
param, level = fs.metadata("shortName"), fs.metadata("levelist", default=None)
for ensemble_member in range(num_ensemble_members):
for i in range(num_ensemble_members):
ensemble_member = member_number[i]

if level is not None:
param = GRIB_TO_XARRAY_PL.get(param, param)
if param not in target_variables:
continue
values = output.isel(time=time).sel(level=level).sel(sample=ensemble_member).data_vars[param].values
values = output.isel(time=time).sel(level=level).sel(sample=i).data_vars[param].values
else:
param = GRIB_TO_CF.get(param, param)
param = GRIB_TO_XARRAY_SFC.get(param, param)
if param not in target_variables:
continue
values = output.isel(time=time).sel(sample=ensemble_member).data_vars[param].values
values = output.isel(time=time).sel(sample=i).data_vars[param].values

# We want the field ordered north=>south

values = np.flipud(values.reshape(fs.shape))

if oper_fcst:
extra_write_kwargs = {}
else:
extra_write_kwargs = dict(number=ensemble_member)

if param == "total_precipitation_12hr":
write(
values,
template=fs,
startStep=0,
# Offset to align with GRIB numbering, e.g. 0 is control, 1+ is ensemble
number=ensemble_member + 1,
endStep=(time + 1) * hour_steps,
)
write(values, template=fs, startStep=0, endStep=(time + 1) * hour_steps, **extra_write_kwargs)
else:
write(
values,
template=fs,
step=(time + 1) * hour_steps,
# Offset to align with GRIB numbering, e.g. 0 is control, 1+ is ensemble
number=ensemble_member + 1,
**extra_write_kwargs,
)

0 comments on commit 18e7458

Please sign in to comment.