Skip to content

Commit

Permalink
ai-models-gencast
Browse files Browse the repository at this point in the history
  • Loading branch information
HCookie committed Dec 9, 2024
1 parent 7c866fa commit fd5125a
Show file tree
Hide file tree
Showing 11 changed files with 720 additions and 6 deletions.
13 changes: 8 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
# ai-models-gencast

`ai-models-gencast` is an [ai-models](https://github.com/ecmwf-lab/ai-models) plugin to run Google Deepmind's [Gencast](https://github.com/deepmind/graphcast).
`ai-models-gencast` is an [ai-models](https://github.com/ecmwf-lab/ai-models) plugin to run Google Deepmind's [GenCast](https://github.com/deepmind/graphcast).

GenCast: Diffusion-based ensemble forecasting for medium-range weather, arXiv preprint: 2312.15796, 2023. https://arxiv.org/abs/2312.15796
GenCast: Diffusion-based ensemble forecasting for medium-range weather, arXiv preprint: 2312.15796, 2024. https://arxiv.org/abs/2312.15796

Gencast was created by Ilan Price, Alvaro Sanchez-Gonzalez, Ferran Alet, Tom R. Andersson, Andrew El-Kadi, Dominic Masters, Timo Ewalds, Jacklynn Stott, Shakir Mohamed, Peter Battaglia, Remi Lam, Matthew Willson
GenCast was created by Ilan Price, Alvaro Sanchez-Gonzalez, Ferran Alet, Tom R. Andersson, Andrew El-Kadi, Dominic Masters, Timo Ewalds, Jacklynn Stott, Shakir Mohamed, Peter Battaglia, Remi Lam, Matthew Willson

The model weights are made available for use under the terms of the Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0). You may obtain a copy of the License at: https://creativecommons.org/licenses/by-nc-sa/4.0/.

Expand All @@ -20,11 +20,14 @@ This will install the package and most of its dependencies.

Then to install gencast dependencies (and Jax on GPU):

> [!CAUTION]
> GenCast requires significant GPU & Memory Resources.
> See [here](https://github.com/google-deepmind/graphcast/blob/main/docs/cloud_vm_setup.md#gencast-memory-requirements)

### Gencast and Jax
### GenCast and Jax

Gencast depends on Jax, which needs special installation instructions for your specific hardware.
GenCast depends on Jax, which needs special installation instructions for your specific hardware.

Please see the [installation guide](https://github.com/google/jax#installation) to follow the correct instructions.

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ dynamic = [
# JAX requirements are in requirements.txt

dependencies = [
"ai-models>=0.4.0",
"ai-models>=0.7.4",
"dm-tree",
"dm-haiku",
]
Expand Down
2 changes: 2 additions & 0 deletions requirements-gpu.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
jax[cuda11_pip]==0.4.36
git+https://github.com/deepmind/graphcast.git
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
jax==0.4.36
git+https://github.com/deepmind/graphcast.git
8 changes: 8 additions & 0 deletions src/ai_models_gencast/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

__version__ = "0.0.7"
42 changes: 42 additions & 0 deletions src/ai_models_gencast/convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

GRIB_TO_XARRAY_SFC = {
"t2m": "2m_temperature",
"sst": "sea_surface_temperature",
"msl": "mean_sea_level_pressure",
"u10": "10m_u_component_of_wind",
"v10": "10m_v_component_of_wind",
"tp": "total_precipitation_12hr",
"z": "geopotential_at_surface",
"lsm": "land_sea_mask",
"latitude": "lat",
"longitude": "lon",
# "step": "batch",
"valid_time": "datetime",
}

GRIB_TO_XARRAY_PL = {
"t": "temperature",
"z": "geopotential",
"u": "u_component_of_wind",
"v": "v_component_of_wind",
"w": "vertical_velocity",
"q": "specific_humidity",
"isobaricInhPa": "level",
"latitude": "lat",
"longitude": "lon",
# "step": "batch",
"valid_time": "datetime",
}


GRIB_TO_CF = {
"2t": "t2m",
"10u": "u10",
"10v": "v10",
}
194 changes: 194 additions & 0 deletions src/ai_models_gencast/input.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import datetime
import logging
from collections import defaultdict

import earthkit.data as ekd
import numpy as np
import xarray as xr

LOG = logging.getLogger(__name__)

CF_NAME_SFC = {
"10u": "10m_u_component_of_wind",
"10v": "10m_v_component_of_wind",
"2t": "2m_temperature",
"sst": "sea_surface_temperature",
"lsm": "land_sea_mask",
"msl": "mean_sea_level_pressure",
"tp": "total_precipitation_12hr",
"z": "geopotential_at_surface",
}

CF_NAME_PL = {
"q": "specific_humidity",
"t": "temperature",
"u": "u_component_of_wind",
"v": "v_component_of_wind",
"w": "vertical_velocity",
"z": "geopotential",
}


def forcing_variables_numpy(sample, forcing_variables, dates):
"""Generate variables from earthkit-data
Args:
date (datetime): Datetime of current time step in forecast
params (List[str]): Parameters to calculate as constants
Returns:
torch.Tensor: Tensor with constants
"""
ds = ekd.from_source(
"forcings",
sample,
date=dates,
param=forcing_variables,
)

return ds.order_by(param=forcing_variables, valid_datetime="ascending").to_numpy(dtype=np.float32)


def create_training_xarray(
*,
fields_sfc,
fields_pl,
lagged,
start_date,
hour_steps,
lead_time,
forcing_variables,
constants,
timer,
context,
):
time_deltas = [
datetime.timedelta(hours=h)
for h in lagged + [hour for hour in range(hour_steps, lead_time + hour_steps, hour_steps)]
]

all_datetimes = [start_date + time_delta for time_delta in time_deltas]

with timer("Creating forcing variables"):
forcing_numpy = forcing_variables_numpy(fields_sfc, forcing_variables, all_datetimes)

with timer("Converting GRIB to xarray"):
# Create Input dataset

lat = fields_sfc[0].metadata("distinctLatitudes")
lon = fields_sfc[0].metadata("distinctLongitudes")

forcing_numpy = forcing_numpy.reshape(len(forcing_variables), len(all_datetimes), len(lat), len(lon))

# SURFACE FIELDS

fields_sfc = fields_sfc.order_by("param", "valid_datetime")
sfc = defaultdict(list)
given_datetimes = set()
for field in fields_sfc:
given_datetimes.add(field.metadata("valid_datetime"))
sfc[field.metadata("param")].append(field)

# PRESSURE LEVEL FIELDS

fields_pl = fields_pl.order_by("param", "valid_datetime", "level")
pl = defaultdict(list)
levels = set()
given_datetimes = set()
for field in fields_pl:
given_datetimes.add(field.metadata("valid_datetime"))
pl[field.metadata("param")].append(field)
levels.add(field.metadata("level"))

data_vars = {}

for param, fields in sfc.items():
if param in ("z", "lsm"):
data_vars[CF_NAME_SFC[param]] = (["lat", "lon"], fields[0].to_numpy())
continue

data = np.stack([field.to_numpy(dtype=np.float32) for field in fields]).reshape(
1,
len(given_datetimes),
len(lat),
len(lon),
)

data = np.pad(
data,
(
(0, 0),
(0, len(all_datetimes) - len(given_datetimes)),
(0, 0),
(0, 0),
),
constant_values=(np.nan,),
)

data_vars[CF_NAME_SFC[param]] = (["batch", "time", "lat", "lon"], data)

for param, fields in pl.items():
data = np.stack([field.to_numpy(dtype=np.float32) for field in fields]).reshape(
1,
len(given_datetimes),
len(levels),
len(lat),
len(lon),
)
data = np.pad(
data,
(
(0, 0),
(0, len(all_datetimes) - len(given_datetimes)),
(0, 0),
(0, 0),
(0, 0),
),
constant_values=(np.nan,),
)

data_vars[CF_NAME_PL[param]] = (
["batch", "time", "level", "lat", "lon"],
data,
)

data_vars["toa_incident_solar_radiation"] = (
["batch", "time", "lat", "lon"],
forcing_numpy[0:1, :, :, :],
)

training_xarray = xr.Dataset(
data_vars=data_vars,
coords=dict(
lon=lon,
lat=lat,
time=time_deltas,
datetime=(
("batch", "time"),
[all_datetimes],
),
level=sorted(levels),
),
)

with timer("Reindexing"):
# And we want the grid south to north
training_xarray = training_xarray.reindex(lat=sorted(training_xarray.lat.values), copy=False)

if constants:
# Add geopotential_at_surface and land_sea_mask back in
x = xr.load_dataset(constants)

for patch in ("geopotential_at_surface", "land_sea_mask"):
LOG.info("PATCHING %s", patch)
training_xarray[patch] = x[patch]

return training_xarray, time_deltas
Loading

0 comments on commit fd5125a

Please sign in to comment.