import functools as ft
import os
import random
import re
from collections import defaultdict
from datetime import datetime, timedelta
from pathlib import Path

import h5py
import numpy as np
import pandas as pd
import torch
from torch import Tensor
from torch.utils.data import Dataset


def preproc(
    batch: list[dict], padding: dict[str, tuple[int, int]]
) -> dict[str, Tensor]:
    """Preprocessing function for the MERRA2 Dataset.

    Args:
        batch: List of training samples. Each sample should be a dictionary
            with the following keys::

            'sur_static': Numpy array of shape (3, lat, lon). For each pixel
                (lat, lon), the first dimension indexes sin(lat), cos(lon),
                sin(lon).
            'sur_vals': Torch tensor of shape (parameter, time, lat, lon).
            'sur_tars': Torch tensor of shape (parameter, time, lat, lon).
            'ulv_vals': Torch tensor of shape
                (parameter, level, time, lat, lon).
            'ulv_tars': Torch tensor of shape
                (parameter, level, time, lat, lon).
            'sur_climate': Torch tensor of shape (parameter, lat, lon).
            'ulv_climate': Torch tensor of shape (parameter, level, lat, lon).
            'lead_time': Integer.
            'input_time': Integer.
        padding: Dictionary with keys 'level', 'lat', 'lon', each a tuple of
            two ints giving (left, right) padding amounts.

    Returns:
        Dictionary with the following keys::

            'x': [batch, time, parameter, lat, lon]
            'y': [batch, parameter, lat, lon]
            'static': [batch, parameter, lat, lon]
            'lead_time': [batch]
            'input_time': [batch]
            'climate' (optional): [batch, parameter, lat, lon]

    Note:
        Here, for x and y, 'parameter' is [surface parameter, upper-level
        parameter x level]. Similarly, for the static information we have
        [sin(lat), cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod),
        ...].
    """
    b0 = batch[0]
    nbatch = len(batch)
    data_keys = set(b0.keys())

    essential_keys = {
        "sur_static",
        "sur_vals",
        "sur_tars",
        "ulv_vals",
        "ulv_tars",
        "input_time",
        "lead_time",
    }

    climate_keys = {
        "sur_climate",
        "ulv_climate",
    }

    all_keys = essential_keys | climate_keys

    if not essential_keys.issubset(data_keys):
        raise ValueError("Missing essential keys.")

    if not data_keys.issubset(all_keys):
        raise ValueError("Unexpected keys in batch.")

    # Bring all tensors from the batch into a single tensor
    upl_x = torch.empty((nbatch, *b0["ulv_vals"].shape))
    upl_y = torch.empty((nbatch, *b0["ulv_tars"].shape))

    sur_x = torch.empty((nbatch, *b0["sur_vals"].shape))
    sur_y = torch.empty((nbatch, *b0["sur_tars"].shape))

    sur_sta = torch.empty((nbatch, *b0["sur_static"].shape))

    lead_time = torch.empty((nbatch,), dtype=torch.float32)
    input_time = torch.empty((nbatch,), dtype=torch.float32)

    for i, rec in enumerate(batch):
        sur_x[i] = rec["sur_vals"]
        sur_y[i] = rec["sur_tars"]

        upl_x[i] = rec["ulv_vals"]
        upl_y[i] = rec["ulv_tars"]

        sur_sta[i] = rec["sur_static"]

        lead_time[i] = rec["lead_time"]
        input_time[i] = rec["input_time"]

    return_value = {
        "lead_time": lead_time,
        "input_time": input_time,
    }

    # Reshape (batch, parameter, level, time, lat, lon) ->
    #  (batch, time, parameter, level, lat, lon)
    upl_x = upl_x.permute((0, 3, 1, 2, 4, 5))
    upl_y = upl_y.permute((0, 3, 1, 2, 4, 5))
    # Reshape (batch, parameter, time, lat, lon) ->
    #  (batch, time, parameter, lat, lon)
    sur_x = sur_x.permute((0, 2, 1, 3, 4))
    sur_y = sur_y.permute((0, 2, 1, 3, 4))

    # Pad
    padding_2d = (*padding["lon"], *padding["lat"])

    def pad2d(x):
        return torch.nn.functional.pad(x, padding_2d, mode="constant", value=0)

    padding_3d = (*padding["lon"], *padding["lat"], *padding["level"])

    def pad3d(x):
        return torch.nn.functional.pad(x, padding_3d, mode="constant", value=0)

    sur_x = pad2d(sur_x).contiguous()
    upl_x = pad3d(upl_x).contiguous()
    sur_y = pad2d(sur_y).contiguous()
    upl_y = pad3d(upl_y).contiguous()
    return_value["static"] = pad2d(sur_sta).contiguous()

    # Remove time for targets
    upl_y = torch.squeeze(upl_y, 1)
    sur_y = torch.squeeze(sur_y, 1)

    # We stack along the combined parameter x level dimension
    return_value["x"] = torch.cat(
        (sur_x, upl_x.view(*upl_x.shape[:2], -1, *upl_x.shape[4:])), dim=2
    )
    return_value["y"] = torch.cat(
        (sur_y, upl_y.view(upl_y.shape[0], -1, *upl_y.shape[3:])), dim=1
    )

    if climate_keys.issubset(data_keys):
        sur_climate = torch.empty((nbatch, *b0["sur_climate"].shape))
        ulv_climate = torch.empty((nbatch, *b0["ulv_climate"].shape))
        for i, rec in enumerate(batch):
            sur_climate[i] = rec["sur_climate"]
            ulv_climate[i] = rec["ulv_climate"]

        sur_climate = pad2d(sur_climate)
        ulv_climate = pad3d(ulv_climate)

        return_value["climate"] = torch.cat(
            (
                sur_climate,
                ulv_climate.view(nbatch, -1, *ulv_climate.shape[3:]),
            ),
            dim=1,
        )

    return return_value
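

# Example (sketch): collating two synthetic samples with `preproc`. All shapes
# below are illustrative assumptions, not the real MERRA2 grid; they only show
# how the padding keys map to (left, right) amounts per axis and how surface
# and upper-level parameters are concatenated.
#
#   sample = {
#       "sur_static": torch.rand(3, 8, 16),
#       "sur_vals": torch.rand(2, 2, 8, 16),     # (param, time, lat, lon)
#       "sur_tars": torch.rand(2, 1, 8, 16),
#       "ulv_vals": torch.rand(2, 4, 2, 8, 16),  # (param, lev, time, lat, lon)
#       "ulv_tars": torch.rand(2, 4, 1, 8, 16),
#       "lead_time": 12,
#       "input_time": -3,
#   }
#   out = preproc(
#       [sample, sample],
#       padding={"level": (0, 1), "lat": (0, 0), "lon": (0, 0)},
#   )
#   # 2 surface params + 2 upper-level params x (4 + 1 padded) levels = 12
#   assert out["x"].shape == (2, 2, 12, 8, 16)
#   assert out["y"].shape == (2, 12, 8, 16)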


def input_scalers(
    surf_vars: list[str],
    vert_vars: list[str],
    levels: list[float],
    surf_path: str | Path,
    vert_path: str | Path,
) -> tuple[Tensor, Tensor]:
    """Reads the input scalers.

    Args:
        surf_vars: Surface variables to be used.
        vert_vars: Vertical variables to be used.
        levels: MERRA2 levels to use.
        surf_path: Path to surface scalers file.
        vert_path: Path to vertical level scalers file.

    Returns:
        mu (Tensor): Mean values.
        sig (Tensor): Standard deviation values.
    """
    with h5py.File(Path(surf_path), "r", libver="latest") as surf_file:
        stats = [x.decode().lower() for x in surf_file["statistic"][()]]
        mu_idx = stats.index("mu")
        sig_idx = stats.index("sigma")

        s_mu = torch.tensor([surf_file[k][()][mu_idx] for k in surf_vars])
        s_sig = torch.tensor([surf_file[k][()][sig_idx] for k in surf_vars])

    with h5py.File(Path(vert_path), "r", libver="latest") as vert_file:
        stats = [x.decode().lower() for x in vert_file["statistic"][()]]
        mu_idx = stats.index("mu")
        sig_idx = stats.index("sigma")

        lvl = vert_file["lev"][()]
        l_idx = [np.where(lvl == v)[0].item() for v in levels]

        v_mu = np.array([vert_file[k][()][mu_idx, l_idx] for k in vert_vars])
        v_sig = np.array([vert_file[k][()][sig_idx, l_idx] for k in vert_vars])

    v_mu = torch.from_numpy(v_mu).view(-1)
    v_sig = torch.from_numpy(v_sig).view(-1)

    mu = torch.cat((s_mu, v_mu), dim=0).to(torch.float32)
    sig = torch.cat((s_sig, v_sig), dim=0).to(torch.float32).clamp(1e-4, 1e4)
    return mu, sig


def static_input_scalers(
    scalar_path: str | Path, stat_vars: list[str], unscaled_params: int = 7
) -> tuple[Tensor, Tensor]:
    """Reads the static input scalers. The first `unscaled_params` channels
    (the position and time encodings) pass through unscaled: mu 0, sigma 1.
    """
    scalar_path = Path(scalar_path)

    with h5py.File(scalar_path, "r", libver="latest") as scaler_file:
        stats = [x.decode().lower() for x in scaler_file["statistic"][()]]
        mu_idx = stats.index("mu")
        sig_idx = stats.index("sigma")

        mu = torch.tensor([scaler_file[k][()][mu_idx] for k in stat_vars])
        sig = torch.tensor([scaler_file[k][()][sig_idx] for k in stat_vars])

    z = torch.zeros(unscaled_params, dtype=mu.dtype, device=mu.device)
    o = torch.ones(unscaled_params, dtype=sig.dtype, device=sig.device)
    mu = torch.cat((z, mu), dim=0).to(torch.float32)
    sig = torch.cat((o, sig), dim=0).to(torch.float32)

    return mu, sig.clamp(1e-4, 1e4)


def output_scalers(
    surf_vars: list[str],
    vert_vars: list[str],
    levels: list[float],
    surf_path: str | Path,
    vert_path: str | Path,
) -> Tensor:
    """Reads the output scalers. Values are ordered [surface,
    vertical x level], matching the combined parameter dimension.
    """
    surf_path = Path(surf_path)
    vert_path = Path(vert_path)

    with h5py.File(surf_path, "r", libver="latest") as surf_file:
        svars = torch.tensor([surf_file[k][()] for k in surf_vars])

    with h5py.File(vert_path, "r", libver="latest") as vert_file:
        lvl = vert_file["lev"][()]
        l_idx = [np.where(lvl == v)[0].item() for v in levels]
        vvars = np.array([vert_file[k][()][l_idx] for k in vert_vars])
        vvars = torch.from_numpy(vvars).view(-1)

    var = torch.cat((svars, vvars), dim=0).to(torch.float32).clamp(1e-7, 1e7)

    return var
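

# Example (sketch): applying the input scalers for normalization. The file
# names below are placeholders, not confirmed paths; `x` stands for the "x"
# entry returned by `preproc`. mu and sig are flat over the combined
# [surface, vertical x level] parameter ordering produced by `preproc`, so
# they broadcast against its parameter dimension.
#
#   mu, sig = input_scalers(
#       surf_vars=["T2M", "U10M"],
#       vert_vars=["T", "U"],
#       levels=[34.0, 39.0],
#       surf_path="scalers/musigma_surface.h5",
#       vert_path="scalers/musigma_vertical.h5",
#   )
#   # x: [batch, time, parameter, lat, lon] with parameter = 2 + 2 * 2 = 6
#   shape = (1, 1, -1, 1, 1)
#   x_scaled = (x - mu.view(shape)) / sig.view(shape)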


class SampleSpec:
    """
    A data class to collect the information used to define a sample.
    """

    def __init__(
        self,
        inputs: tuple[pd.Timestamp, pd.Timestamp],
        lead_time: int,
        target: pd.Timestamp | list[pd.Timestamp],
    ):
        """
        Args:
            inputs: Tuple of timestamps. In ascending order.
            lead_time: Lead time. In hours.
            target: Timestamp of the target. Can be before or after the
                inputs.
        """
        if not inputs[0] < inputs[1]:
            raise ValueError(
                "Timestamps in `inputs` should be in strictly ascending order."
            )

        self.inputs = inputs
        self.input_time = (inputs[1] - inputs[0]).total_seconds() / 3600
        self.lead_time = lead_time
        self.target = target

        self.times = [*inputs, target]
        self.stat_times = [inputs[-1]]

    @property
    def climatology_info(self) -> tuple[int, int]:
        """Get the required climatology info.

        :return: Information required to obtain climatology data. Essentially
            this is the day of the year and hour of the day of the target
            timestamp, with the former restricted to the interval [1, 365].
        :rtype: tuple
        """
        return (min(self.target.dayofyear, 365), self.target.hour)

    @property
    def year(self) -> int:
        return self.inputs[1].year

    @property
    def dayofyear(self) -> int:
        return self.inputs[1].dayofyear

    @property
    def hourofday(self) -> int:
        return self.inputs[1].hour

    def _info_str(self) -> str:
        iso_8601 = "%Y-%m-%dT%H:%M:%S"

        return (
            f"Issue time: {self.inputs[1].strftime(iso_8601)}\n"
            f"Lead time: {self.lead_time} hours ahead\n"
            f"Input delta: {self.input_time} hours\n"
            f"Target time: {self.target.strftime(iso_8601)}"
        )

    @classmethod
    def get(cls, timestamp: pd.Timestamp, dt: int, lead_time: int):
        """Given a timestamp and lead time, generates a SampleSpec object
        describing the sample further.

        Args:
            timestamp: Timestamp of the sample, i.e. the larger of the two
                input timestamps.
            dt: Time between input samples, in hours. Must be positive.
            lead_time: Lead time. In hours.

        Returns:
            SampleSpec: The sample specification.
        """
        assert dt > 0, "dt should be positive"
        lt = pd.to_timedelta(lead_time, unit="h")
        dt = pd.to_timedelta(dt, unit="h")

        if lead_time >= 0:
            timestamp_target = timestamp + lt
        else:
            timestamp_target = timestamp - dt + lt

        spec = cls(
            inputs=(timestamp - dt, timestamp),
            lead_time=lead_time,
            target=timestamp_target,
        )

        return spec

    def __repr__(self) -> str:
        return self._info_str()

    def __str__(self) -> str:
        return self._info_str()
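

# Example (sketch): building a spec for a sample issued at 2020-01-01T06:00
# with inputs 3 hours apart and a 12 hour lead time. Values are illustrative.
#
#   spec = SampleSpec.get(pd.Timestamp("2020-01-01T06:00"), dt=3, lead_time=12)
#   spec.inputs      # (2020-01-01 03:00, 2020-01-01 06:00)
#   spec.target      # 2020-01-01 18:00
#   spec.input_time  # 3.0 hours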


class Merra2Dataset(Dataset):
    """MERRA2 dataset. The dataset unifies surface and vertical data as well
    as optional climatology.

    Samples come in the form of a dictionary. Not all keys support all
    variables, yet the general ordering of dimensions is
    parameter, level, time, lat, lon.

    Note:
        Data is assumed to be in NetCDF files containing daily data at
        3-hourly intervals. These follow the naming patterns
        MERRA2_sfc_YYYYMMDD.nc and MERRA_pres_YYYYMMDD.nc and can be located
        in two different locations. Optional climatology data comes from
        files climate_surface_doyDOY_hourHOD.nc and
        climate_vertical_doyDOY_hourHOD.nc.

    Note:
        `valid_timestamps` assembles a set of all timestamps for which there
        is data (at 3-hourly resolution). `valid_climate_timestamps` does the
        same for climatology data. Based on this information, `samples`
        generates a list of valid samples. Here the format is::

            [
                [
                    (timestamp 1, lead time A),
                    (timestamp 1, lead time B),
                    (timestamp 1, lead time C),
                ],
                [
                    (timestamp 2, lead time D),
                    (timestamp 2, lead time E),
                ]
            ]

        That is, the outer list iterates over timestamps (init times), the
        inner over lead times. Only valid entries are stored.
    """

    valid_vertical_vars = [
        "CLOUD",
        "H",
        "OMEGA",
        "PL",
        "QI",
        "QL",
        "QV",
        "T",
        "U",
        "V",
    ]

    valid_surface_vars = [
        "EFLUX",
        "GWETROOT",
        "HFLUX",
        "LAI",
        "LWGAB",
        "LWGEM",
        "LWTUP",
        "PRECTOT",
        "PS",
        "QV2M",
        "SLP",
        "SWGNT",
        "SWTNT",
        "T2M",
        "TQI",
        "TQL",
        "TQV",
        "TS",
        "U10M",
        "V10M",
        "Z0M",
    ]

    valid_static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]

    valid_levels = [
        34.0,
        39.0,
        41.0,
        43.0,
        44.0,
        45.0,
        48.0,
        51.0,
        53.0,
        56.0,
        63.0,
        68.0,
        71.0,
        72.0,
    ]

    timedelta_input = pd.to_timedelta(3, unit="h")

    def __init__(
        self,
        time_range: tuple[str | pd.Timestamp, str | pd.Timestamp],
        lead_times: list[int],
        input_times: list[int],
        data_path_surface: str | Path,
        data_path_vertical: str | Path,
        climatology_path_surface: str | Path | None = None,
        climatology_path_vertical: str | Path | None = None,
        surface_vars: list[str] | None = None,
        static_surface_vars: list[str] | None = None,
        vertical_vars: list[str] | None = None,
        levels: list[float] | None = None,
        roll_longitudes: int = 0,
        positional_encoding: str = "absolute",
        rtype: type = np.float32,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        """
        Args:
            time_range: Used to subset data.
            lead_times: Lead times for generalized forecasting, in hours.
            input_times: Times between inputs, in hours (negative values).
            data_path_surface: Location of surface data.
            data_path_vertical: Location of vertical data.
            climatology_path_surface: Location of (optional) surface
                climatology.
            climatology_path_vertical: Location of (optional) vertical
                climatology.
            surface_vars: Surface variables.
            static_surface_vars: Static surface variables.
            vertical_vars: Vertical variables.
            levels: Levels.
            roll_longitudes: Set to non-zero value to roll data by a random
                amount along the longitude dimension.
            positional_encoding: Possible values are
                ['absolute' (default), 'fourier'].
                'absolute' returns lat/lon encoded in 3 dimensions using sine
                and cosine.
                'fourier' (or any other value) returns raw lat/lon to be
                encoded by the model.
            rtype: Numpy data type used during reads.
            dtype: Torch data type of the data output.
        """
        self.time_range = (
            pd.to_datetime(time_range[0]),
            pd.to_datetime(time_range[1]),
        )
        self.lead_times = lead_times
        self.input_times = input_times
        self._roll_longitudes = list(range(roll_longitudes + 1))

        self._uvars = vertical_vars or self.valid_vertical_vars
        self._level = levels or self.valid_levels
        self._svars = surface_vars or self.valid_surface_vars
        self._sstat = static_surface_vars or self.valid_static_surface_vars
        self._nuvars = len(self._uvars)
        self._nlevel = len(self._level)
        self._nsvars = len(self._svars)
        self._nsstat = len(self._sstat)

        self.rtype = rtype
        self.dtype = dtype

        self.positional_encoding = positional_encoding

        self._data_path_surface = Path(data_path_surface)
        self._data_path_vertical = Path(data_path_vertical)

        self.dir_exists(self._data_path_surface)
        self.dir_exists(self._data_path_vertical)

        self._get_coordinates()

        # `Path(None)` raises, so only wrap the climatology paths when given.
        self._climatology_path_surface = (
            Path(climatology_path_surface)
            if climatology_path_surface is not None
            else None
        )
        self._climatology_path_vertical = (
            Path(climatology_path_vertical)
            if climatology_path_vertical is not None
            else None
        )
        self._require_clim = (
            self._climatology_path_surface is not None
            and self._climatology_path_vertical is not None
        )

        if self._require_clim:
            self.dir_exists(self._climatology_path_surface)
            self.dir_exists(self._climatology_path_vertical)
        elif (
            climatology_path_surface is None
            and climatology_path_vertical is None
        ):
            self._climatology_path_surface = None
            self._climatology_path_vertical = None
        else:
            raise ValueError(
                "Either both or neither of `climatology_path_surface` and "
                "`climatology_path_vertical` should be None."
            )

        if not set(self._svars).issubset(set(self.valid_surface_vars)):
            raise ValueError("Invalid surface variable.")

        if not set(self._sstat).issubset(set(self.valid_static_surface_vars)):
            raise ValueError("Invalid static surface variable.")

        if not set(self._uvars).issubset(set(self.valid_vertical_vars)):
            raise ValueError("Invalid vertical variable.")

        if not set(self._level).issubset(set(self.valid_levels)):
            raise ValueError("Invalid level.")

    @staticmethod
    def dir_exists(path: Path) -> None:
        if not path.is_dir():
            raise ValueError(f"Directory {path} does not exist.")

    @property
    def upper_shape(self) -> tuple:
        """Returns the vertical variables shape.

        Returns:
            tuple: Vertical variable shape in the following order::

                [VAR, LEV, TIME, LAT, LON]
        """
        return self._nuvars, self._nlevel, 2, 361, 576

    @property
    def surface_shape(self) -> tuple:
        """Returns the surface variables shape.

        Returns:
            tuple: Surface shape in the following order::

                [VAR, TIME, LAT, LON]
        """
        return self._nsvars, 2, 361, 576

    def data_file_surface(self, timestamp: pd.Timestamp) -> Path:
        """Builds the surface data file name based on the timestamp.

        Args:
            timestamp: A timestamp.

        Returns:
            Path: Constructed path.
        """
        pattern = "MERRA2_sfc_%Y%m%d.nc"
        data_file = self._data_path_surface / timestamp.strftime(pattern)
        return data_file

    def data_file_vertical(self, timestamp: pd.Timestamp) -> Path:
        """Builds the vertical data file name based on the timestamp.

        Args:
            timestamp: A timestamp.

        Returns:
            Path: Constructed path.
        """
        pattern = "MERRA_pres_%Y%m%d.nc"
        data_file = self._data_path_vertical / timestamp.strftime(pattern)
        return data_file

    def data_file_surface_climate(
        self,
        timestamp: pd.Timestamp | None = None,
        dayofyear: int | None = None,
        hourofday: int | None = None,
    ) -> Path:
        """
        Returns the path to a climatology file based either on a timestamp or
        the dayofyear / hourofday combination.

        Args:
            timestamp: A timestamp.
            dayofyear: Day of the year. 1 to 366.
            hourofday: Hour of the day. 0 to 23.

        Returns:
            Path: Path to climatology file.
        """
        if timestamp is not None and (
            (dayofyear is not None) or (hourofday is not None)
        ):
            raise ValueError(
                "Provide either a timestamp or dayofyear / hourofday, "
                "not both."
            )

        if timestamp is not None:
            dayofyear = min(timestamp.dayofyear, 365)
            hourofday = timestamp.hour

        file_name = f"climate_surface_doy{dayofyear:03}_hour{hourofday:02}.nc"
        data_file = self._climatology_path_surface / file_name
        return data_file

    def data_file_vertical_climate(
        self,
        timestamp: pd.Timestamp | None = None,
        dayofyear: int | None = None,
        hourofday: int | None = None,
    ) -> Path:
        """
        Returns the path to a climatology file based either on a timestamp or
        the dayofyear / hourofday combination.

        Args:
            timestamp: A timestamp.
            dayofyear: Day of the year. 1 to 366.
            hourofday: Hour of the day. 0 to 23.

        Returns:
            Path: Path to climatology file.
        """
        if timestamp is not None and (
            (dayofyear is not None) or (hourofday is not None)
        ):
            raise ValueError(
                "Provide either a timestamp or dayofyear / hourofday, "
                "not both."
            )

        if timestamp is not None:
            dayofyear = min(timestamp.dayofyear, 365)
            hourofday = timestamp.hour

        file_name = f"climate_vertical_doy{dayofyear:03}_hour{hourofday:02}.nc"
        data_file = self._climatology_path_vertical / file_name
        return data_file

    def _get_coordinates(self) -> None:
        """
        Obtains the coordinates (latitudes and longitudes) from a single data
        file.
        """
        timestamp = next(iter(self.valid_timestamps))

        file = self.data_file_surface(timestamp)
        with h5py.File(file, "r", libver="latest") as handle:
            self.lats = lats = handle["lat"][()].astype(self.rtype)
            self.lons = lons = handle["lon"][()].astype(self.rtype)

        deg_to_rad = np.pi / 180
        self._embed_lat = np.sin(lats * deg_to_rad).reshape(-1, 1)

        self._embed_lon = np.empty((2, 1, len(lons)), dtype=self.rtype)
        self._embed_lon[0, 0] = np.cos(lons * deg_to_rad)
        self._embed_lon[1, 0] = np.sin(lons * deg_to_rad)

    @ft.cached_property
    def lats(self) -> np.ndarray:
        timestamp = next(iter(self.valid_timestamps))

        file = self.data_file_surface(timestamp)
        with h5py.File(file, "r", libver="latest") as handle:
            return handle["lat"][()].astype(self.rtype)

    @ft.cached_property
    def lons(self) -> np.ndarray:
        timestamp = next(iter(self.valid_timestamps))

        file = self.data_file_surface(timestamp)
        with h5py.File(file, "r", libver="latest") as handle:
            return handle["lon"][()].astype(self.rtype)

    @ft.cached_property
    def position_signal(self) -> np.ndarray:
        """Generates the "position signal" that is part of the static
        features.

        Returns:
            np.ndarray: Numpy array of dimension (parameter, lat, lon)
                containing sin(lat), cos(lon), sin(lon) for the 'absolute'
                encoding, or raw lat/lon otherwise.
        """
        latitudes, longitudes = np.meshgrid(
            self.lats, self.lons, indexing="ij"
        )

        if self.positional_encoding == "absolute":
            latitudes = latitudes / 360 * 2.0 * np.pi
            longitudes = longitudes / 360 * 2.0 * np.pi
            sur_static = np.stack(
                [np.sin(latitudes), np.cos(longitudes), np.sin(longitudes)],
                axis=0,
            )
        else:
            sur_static = np.stack([latitudes, longitudes], axis=0)

        sur_static = sur_static.astype(self.rtype)

        return sur_static

    @ft.cached_property
    def valid_timestamps(self) -> set[pd.Timestamp]:
        """Generates the set of valid timestamps based on available files.
        Only timestamps for which both surface and vertical information is
        available are considered valid.

        Returns:
            set: Set of timestamps.
        """
        s_glob = self._data_path_surface.glob("MERRA2_sfc_????????.nc")
        s_files = [os.path.basename(f) for f in s_glob]
        v_glob = self._data_path_vertical.glob("MERRA_pres_????????.nc")
        v_files = [os.path.basename(f) for f in v_glob]

        s_re = re.compile(r"MERRA2_sfc_(\d{8})\.nc\Z")
        v_re = re.compile(r"MERRA_pres_(\d{8})\.nc\Z")
        fmt = "%Y%m%d"

        s_times = {
            (datetime.strptime(m[1], fmt))
            for f in s_files
            if (m := s_re.match(f))
        }
        v_times = {
            (datetime.strptime(m[1], fmt))
            for f in v_files
            if (m := v_re.match(f))
        }

        times = s_times.intersection(v_times)

        # Each file contains a day at 3 hour intervals
        times = {
            t + timedelta(hours=i) for i in range(0, 24, 3) for t in times
        }

        start_time, end_time = self.time_range
        times = {pd.Timestamp(t) for t in times if start_time <= t <= end_time}

        return times

    @ft.cached_property
    def valid_climate_timestamps(self) -> set[tuple[int, int]]:
        """Generates the set of "timestamps" (dayofyear, hourofday) for which
        climatology data is present. Only instances for which both surface
        and vertical data is available are considered valid.

        Returns:
            set: Set of tuples describing valid climatology instances.
        """
        if not self._require_clim:
            return set()

        s_glob = self._climatology_path_surface.glob(
            "climate_surface_doy???_hour??.nc"
        )
        s_files = [os.path.basename(f) for f in s_glob]

        v_glob = self._climatology_path_vertical.glob(
            "climate_vertical_doy???_hour??.nc"
        )
        v_files = [os.path.basename(f) for f in v_glob]

        s_re = re.compile(r"climate_surface_doy(\d{3})_hour(\d{2})\.nc\Z")
        v_re = re.compile(r"climate_vertical_doy(\d{3})_hour(\d{2})\.nc\Z")

        s_times = {
            (int(m[1]), int(m[2])) for f in s_files if (m := s_re.match(f))
        }
        v_times = {
            (int(m[1]), int(m[2])) for f in v_files if (m := v_re.match(f))
        }

        times = s_times.intersection(v_times)

        return times

    def _data_available(self, spec: SampleSpec) -> bool:
        """
        Checks whether data is available for a given SampleSpec object. Does
        so using the internal sets of available data constructed previously,
        not by checking the file system.

        Args:
            spec: SampleSpec object as returned by SampleSpec.get.

        Returns:
            bool: True if data is available.
        """
        valid = set(spec.times).issubset(self.valid_timestamps)

        if self._require_clim:
            sci = spec.climatology_info
            ci = set(sci) if isinstance(sci, list) else {sci}
            valid &= ci.issubset(self.valid_climate_timestamps)

        return valid

    @ft.cached_property
    def samples(self) -> list[list[tuple[pd.Timestamp, int, int]]]:
        """
        Generates the list of all valid samples.

        Returns:
            list: List of lists of tuples (timestamp, input time, lead time).
        """
        valid_samples = []
        dts = [(it, lt) for it in self.input_times for lt in self.lead_times]

        for timestamp in sorted(self.valid_timestamps):
            timestamp_samples = []
            for it, lt in dts:
                spec = SampleSpec.get(timestamp, -it, lt)

                if self._data_available(spec):
                    timestamp_samples.append((timestamp, it, lt))

            if timestamp_samples:
                valid_samples.append(timestamp_samples)

        return valid_samples

    def _to_torch(
        self,
        data: dict[str, Tensor | list[Tensor]],
        dtype: torch.dtype = torch.float32,
    ) -> dict[str, Tensor | list[Tensor]]:
        out = {}
        for k, v in data.items():
            if isinstance(v, list):
                out[k] = [torch.from_numpy(x).to(dtype) for x in v]
            else:
                out[k] = torch.from_numpy(v).to(dtype)

        return out

    def _lat_roll(
        self, data: dict[str, Tensor | list[Tensor]], n: int
    ) -> dict[str, Tensor | list[Tensor]]:
        # Rolls all tensors by `n` along the last (longitude) dimension.
        out = {}
        for k, v in data.items():
            if isinstance(v, list):
                out[k] = [torch.roll(x, shifts=n, dims=-1) for x in v]
            else:
                out[k] = torch.roll(v, shifts=n, dims=-1)

        return out

    def _read_static_data(
        self, file: str | Path, doy: int, hod: int
    ) -> np.ndarray:
        with h5py.File(file, "r", libver="latest") as handle:
            lats_surf = handle["lat"]
            lons_surf = handle["lon"]

            nll = (len(lats_surf), len(lons_surf))

            npos = len(self.position_signal)
            ntime = 4
            nstat = npos + ntime + self._nsstat

            data = np.empty((nstat, *nll), dtype=self.rtype)
            for i, key in enumerate(self._sstat, start=npos + ntime):
                data[i] = handle[key][()].astype(dtype=self.rtype)

        # [position signal], cos(doy), sin(doy), cos(hod), sin(hod)
        data[0:npos] = self.position_signal
        data[npos + 0] = np.cos(2 * np.pi * doy / 366)
        data[npos + 1] = np.sin(2 * np.pi * doy / 366)
        data[npos + 2] = np.cos(2 * np.pi * hod / 24)
        data[npos + 3] = np.sin(2 * np.pi * hod / 24)

        return data

    def _read_surface(
        self, tidx: int, nll: tuple[int, int], handle: h5py.File
    ) -> np.ndarray:
        data = np.empty((self._nsvars, *nll), dtype=self.rtype)

        for i, key in enumerate(self._svars):
            data[i] = handle[key][tidx][()].astype(dtype=self.rtype)

        return data

    def _read_levels(
        self, tidx: int, nll: tuple[int, int], handle: h5py.File
    ) -> np.ndarray:
        lvls = handle["lev"][()]
        lidx = self._level_idxs(lvls)

        data = np.empty((self._nuvars, self._nlevel, *nll), dtype=self.rtype)
        for i, key in enumerate(self._uvars):
            data[i] = handle[key][tidx, lidx][()].astype(dtype=self.rtype)

        # Reverse the level ordering
        return np.ascontiguousarray(np.flip(data, axis=1))

    def _level_idxs(self, lvls):
        lidx = [np.argwhere(lvls == int(lvl)).item() for lvl in self._level]
        return sorted(lidx)

    @staticmethod
    def _date_to_tidx(date: datetime | pd.Timestamp, handle: h5py.File) -> int:
        # Reconstructs the file's timestamps from the `begin_date` /
        # `begin_time` attributes (minute offsets) and finds the index of
        # `date` among them.
        if isinstance(date, pd.Timestamp):
            date = date.to_pydatetime()

        time = handle["time"]
        t0 = time.attrs["begin_time"][()].item()
        d0 = f"{time.attrs['begin_date'][()].item()}"

        offset = datetime.strptime(d0, "%Y%m%d")
        times = [offset + timedelta(minutes=int(t + t0)) for t in time[()]]
        return times.index(date)
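
    # Example (sketch): for a file with time attrs begin_date=20200101 and
    # begin_time=0 and time values [0, 180, 360, ...] (minutes), a `date` of
    # 2020-01-01 06:00 resolves to tidx 2. The attribute layout is assumed
    # from the reads above; actual MERRA2 files may differ.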

    def _read_data(
        self, file_pair: tuple[str, str], date: datetime
    ) -> dict[str, np.ndarray]:
        s_file, v_file = file_pair

        with h5py.File(s_file, "r", libver="latest") as shandle:
            lats_surf = shandle["lat"]
            lons_surf = shandle["lon"]

            nll = (len(lats_surf), len(lons_surf))

            tidx = self._date_to_tidx(date, shandle)
            sdata = self._read_surface(tidx, nll, shandle)

        with h5py.File(v_file, "r", libver="latest") as vhandle:
            lats_vert = vhandle["lat"]
            lons_vert = vhandle["lon"]

            nll = (len(lats_vert), len(lons_vert))

            tidx = self._date_to_tidx(date, vhandle)
            vdata = self._read_levels(tidx, nll, vhandle)

        data = {"vert": vdata, "surf": sdata}

        return data

    def _read_climate(
        self, file_pair: tuple[str, str]
    ) -> dict[str, np.ndarray]:
        s_file, v_file = file_pair

        with h5py.File(s_file, "r", libver="latest") as shandle:
            lats_surf = shandle["lat"]
            lons_surf = shandle["lon"]

            nll = (len(lats_surf), len(lons_surf))

            sdata = np.empty((self._nsvars, *nll), dtype=self.rtype)
            for i, key in enumerate(self._svars):
                sdata[i] = shandle[key][()].astype(dtype=self.rtype)

        with h5py.File(v_file, "r", libver="latest") as vhandle:
            lats_vert = vhandle["lat"]
            lons_vert = vhandle["lon"]

            nll = (len(lats_vert), len(lons_vert))

            lvls = vhandle["lev"][()]
            lidx = self._level_idxs(lvls)

            vdata = np.empty(
                (self._nuvars, self._nlevel, *nll), dtype=self.rtype
            )
            for i, key in enumerate(self._uvars):
                vdata[i] = vhandle[key][lidx][()].astype(dtype=self.rtype)

        data = {
            "vert": np.ascontiguousarray(np.flip(vdata, axis=1)),
            "surf": sdata,
        }

        return data

    def get_data_from_sample_spec(
        self, spec: SampleSpec
    ) -> dict[str, Tensor | int | float]:
        """Loads and assembles sample data given a SampleSpec object.

        Args:
            spec (SampleSpec): Full details regarding the data to be loaded.

        Returns:
            dict: Dictionary with the following keys::

                'sur_static': Torch tensor of shape [parameter, lat, lon]. For
                    each pixel (lat, lon), the first 7 dimensions index
                    sin(lat), cos(lon), sin(lon), cos(doy), sin(doy),
                    cos(hod), sin(hod), where doy is the day of the year
                    [1, 366] and hod the hour of the day [0, 23].
                'sur_vals': Torch tensor of shape [parameter, time, lat, lon].
                'sur_tars': Torch tensor of shape [parameter, time, lat, lon].
                'ulv_vals': Torch tensor of shape
                    [parameter, level, time, lat, lon].
                'ulv_tars': Torch tensor of shape
                    [parameter, level, time, lat, lon].
                'sur_climate': Torch tensor of shape [parameter, lat, lon].
                'ulv_climate': Torch tensor of shape
                    [parameter, level, lat, lon].
                'lead_time': Float.
                'input_time': Float.
        """
        # We assemble the unique timestamps for which we need data.
        vals_required = {*spec.times}
        stat_required = {*spec.stat_times}

        # We assemble the unique data files from which we need value data
        vals_file_map = defaultdict(list)
        for t in vals_required:
            data_files = (
                self.data_file_surface(t),
                self.data_file_vertical(t),
            )
            vals_file_map[data_files].append(t)

        # We assemble the unique data files from which we need static data
        stat_file_map = defaultdict(list)
        for t in stat_required:
            data_files = (
                self.data_file_surface(t),
                self.data_file_vertical(t),
            )
            stat_file_map[data_files].append(t)

        # Load the value data
        data = {}
        for data_files, times in vals_file_map.items():
            for time in times:
                data[time] = self._read_data(data_files, time)

        # Combine times
        sample_data = {}

        input_upl = np.stack([data[t]["vert"] for t in spec.inputs], axis=2)
        sample_data["ulv_vals"] = input_upl
        target_upl = data[spec.target]["vert"]
        sample_data["ulv_tars"] = target_upl[:, :, None]

        input_sur = np.stack([data[t]["surf"] for t in spec.inputs], axis=1)
        sample_data["sur_vals"] = input_sur
        target_sur = data[spec.target]["surf"]
        sample_data["sur_tars"] = target_sur[:, None]

        # Load the static data
        data_files, times = stat_file_map.popitem()
        time = times[0].dayofyear, times[0].hour
        sample_data["sur_static"] = self._read_static_data(
            data_files[0], *time
        )

        # If required, load the climatology data
        if self._require_clim:
            ci_doy, ci_hour = spec.climatology_info
            surf_file = self.data_file_surface_climate(
                dayofyear=ci_doy,
                hourofday=ci_hour,
            )
            vert_file = self.data_file_vertical_climate(
                dayofyear=ci_doy,
                hourofday=ci_hour,
            )

            clim_data = self._read_climate((surf_file, vert_file))

            sample_data["sur_climate"] = clim_data["surf"]
            sample_data["ulv_climate"] = clim_data["vert"]

        # Move the data from numpy to torch
        sample_data = self._to_torch(sample_data, dtype=self.dtype)

        # Optionally roll along the longitude dimension
        if len(self._roll_longitudes) > 0:
            roll_by = random.choice(self._roll_longitudes)
            sample_data = self._lat_roll(sample_data, roll_by)

        # Add the lead time and input time
        sample_data["lead_time"] = spec.lead_time
        sample_data["input_time"] = spec.input_time

        return sample_data

    def get_data(
        self, timestamp: pd.Timestamp, input_time: int, lead_time: int
    ) -> dict[str, Tensor | int]:
        """
        Loads data based on timestamp and lead time.

        Args:
            timestamp: Timestamp.
            input_time: Time between input samples, in hours.
            lead_time: Lead time, in hours.

        Returns:
            Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',
            'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',
            'lead_time', 'input_time'.
        """
        spec = SampleSpec.get(timestamp, -input_time, lead_time)
        sample_data = self.get_data_from_sample_spec(spec)
        return sample_data

    def __getitem__(self, idx: int) -> dict[str, Tensor | int]:
        """
        Loads data based on the sample index and a random choice among the
        valid (input time, lead time) combinations for that timestamp.

        Args:
            idx: Sample index.

        Returns:
            Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',
            'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',
            'lead_time', 'input_time'.
        """
        sample_set = self.samples[idx]
        timestamp, input_time, lead_time = random.choice(sample_set)
        sample_data = self.get_data(timestamp, input_time, lead_time)
        return sample_data

    def __len__(self):
        return len(self.samples)
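

# Example (sketch): constructing the dataset and collating batches. The paths
# below are placeholders; the surface/vertical directories are expected to
# contain MERRA2_sfc_YYYYMMDD.nc / MERRA_pres_YYYYMMDD.nc files as described
# in the class docstring.
#
#   from torch.utils.data import DataLoader
#
#   dataset = Merra2Dataset(
#       time_range=("2020-01-01", "2020-01-31"),
#       lead_times=[12],
#       input_times=[-3],
#       data_path_surface="data/merra2/surface",
#       data_path_vertical="data/merra2/vertical",
#   )
#   loader = DataLoader(
#       dataset,
#       batch_size=2,
#       collate_fn=lambda b: preproc(
#           b, padding={"level": (0, 0), "lat": (0, 0), "lon": (0, 0)}
#       ),
#   )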