|
import streamlit as st |
|
import torch |
|
import random |
|
import numpy as np |
|
import yaml |
|
from pathlib import Path |
|
from io import BytesIO |
|
import random |
|
from pathlib import Path |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
import tempfile |
|
import traceback |
|
import functools as ft |
|
import os |
|
import random |
|
import re |
|
from collections import defaultdict |
|
from datetime import datetime, timedelta |
|
from pathlib import Path |
|
import h5py |
|
import numpy as np |
|
import pandas as pd |
|
import torch |
|
from torch import Tensor |
|
from torch.utils.data import Dataset |
|
import logging |
|
from Prithvi import * |
|
|
|
def preproc(batch: list[dict], padding: dict[tuple[int]]) -> dict[str, Tensor]: |
|
"""Prepressing function for MERRA2 Dataset |
|
|
|
Args: |
|
batch (dict): List of training samples, each sample should be a |
|
dictionary with the following keys:: |
|
|
|
'sur_static': Numpy array of shape (3, lat, lon). For each pixel (lat, lon), the first dimension indexes sin(lat), cos(lon), sin(lon). |
|
'sur_vals': Torch tensor of shape (parameter, time, lat, lon). |
|
'sur_tars': Torch tensor of shape (parameter, time, lat, lon). |
|
'ulv_vals': Torch tensor of shape (parameter, level, time, lat, lon). |
|
'ulv_tars': Torch tensor of shape (parameter, level, time, lat, lon). |
|
'sur_climate': Torch tensor of shape (parameter, lat, lon) |
|
'ulv_climate': Torch tensor of shape (parameter, level, lat, lon) |
|
'lead_time': Integer. |
|
'input_time': Integer. |
|
|
|
padding: Dictionary with keys 'level', 'lat', 'lon', each of dim 2. |
|
|
|
Returns: |
|
Dictionary with the following keys:: |
|
|
|
'x': [batch, time, parameter, lat, lon] |
|
'y': [batch, parameter, lat, lon] |
|
'static': [batch, parameter, lat, lon] |
|
'lead_time': [batch] |
|
'input_time': [batch] |
|
'climate (Optional)': [batch, parameter, lat, lon] |
|
|
|
Note: |
|
Here, for x and y, 'parameter' is [surface parameter, upper level, |
|
parameter x level]. Similarly for the static information we have |
|
[sin(lat), cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod), |
|
...]. |
|
""" |
|
b0 = batch[0] |
|
nbatch = len(batch) |
|
data_keys = set(b0.keys()) |
|
|
|
essential_keys = { |
|
"sur_static", |
|
"sur_vals", |
|
"sur_tars", |
|
"ulv_vals", |
|
"ulv_tars", |
|
"input_time", |
|
"lead_time", |
|
} |
|
|
|
climate_keys = { |
|
"sur_climate", |
|
"ulv_climate", |
|
} |
|
|
|
all_keys = essential_keys | climate_keys |
|
|
|
if not essential_keys.issubset(data_keys): |
|
raise ValueError("Missing essential keys.") |
|
|
|
if not data_keys.issubset(all_keys): |
|
raise ValueError("Unexpected keys in batch.") |
|
|
|
|
|
upl_x = torch.empty((nbatch, *b0["ulv_vals"].shape)) |
|
upl_y = torch.empty((nbatch, *b0["ulv_tars"].shape)) |
|
|
|
sur_x = torch.empty((nbatch, *b0["sur_vals"].shape)) |
|
sur_y = torch.empty((nbatch, *b0["sur_tars"].shape)) |
|
|
|
sur_sta = torch.empty((nbatch, *b0["sur_static"].shape)) |
|
|
|
lead_time = torch.empty((nbatch,), dtype=torch.float32) |
|
input_time = torch.empty((nbatch,), dtype=torch.float32) |
|
|
|
for i, rec in enumerate(batch): |
|
sur_x[i] = rec["sur_vals"] |
|
sur_y[i] = rec["sur_tars"] |
|
|
|
upl_x[i] = rec["ulv_vals"] |
|
upl_y[i] = rec["ulv_tars"] |
|
|
|
sur_sta[i] = rec["sur_static"] |
|
|
|
lead_time[i] = rec["lead_time"] |
|
input_time[i] = rec["input_time"] |
|
|
|
return_value = { |
|
"lead_time": lead_time, |
|
"input_time": input_time, |
|
} |
|
|
|
|
|
|
|
upl_x = upl_x.permute((0, 3, 1, 2, 4, 5)) |
|
upl_y = upl_y.permute((0, 3, 1, 2, 4, 5)) |
|
|
|
|
|
sur_x = sur_x.permute((0, 2, 1, 3, 4)) |
|
sur_y = sur_y.permute((0, 2, 1, 3, 4)) |
|
|
|
|
|
padding_2d = (*padding["lon"], *padding["lat"]) |
|
|
|
def pad2d(x): |
|
return torch.nn.functional.pad(x, padding_2d, mode="constant", value=0) |
|
|
|
padding_3d = (*padding["lon"], *padding["lat"], *padding["level"]) |
|
|
|
def pad3d(x): |
|
return torch.nn.functional.pad(x, padding_3d, mode="constant", value=0) |
|
|
|
sur_x = pad2d(sur_x).contiguous() |
|
upl_x = pad3d(upl_x).contiguous() |
|
sur_y = pad2d(sur_y).contiguous() |
|
upl_y = pad3d(upl_y).contiguous() |
|
return_value["static"] = pad2d(sur_sta).contiguous() |
|
|
|
|
|
upl_y = torch.squeeze(upl_y, 1) |
|
sur_y = torch.squeeze(sur_y, 1) |
|
|
|
|
|
return_value["x"] = torch.cat( |
|
(sur_x, upl_x.view(*upl_x.shape[:2], -1, *upl_x.shape[4:])), dim=2 |
|
) |
|
return_value["y"] = torch.cat( |
|
(sur_y, upl_y.view(upl_y.shape[0], -1, *upl_y.shape[3:])), dim=1 |
|
) |
|
|
|
if climate_keys.issubset(data_keys): |
|
sur_climate = torch.empty((nbatch, *b0["sur_climate"].shape)) |
|
ulv_climate = torch.empty((nbatch, *b0["ulv_climate"].shape)) |
|
for i, rec in enumerate(batch): |
|
sur_climate[i] = rec["sur_climate"] |
|
ulv_climate[i] = rec["ulv_climate"] |
|
sur_climate = pad2d(sur_climate) |
|
ulv_climate = pad3d(ulv_climate) |
|
|
|
return_value["climate"] = torch.cat( |
|
( |
|
sur_climate, |
|
ulv_climate.view(nbatch, -1, *ulv_climate.shape[3:]), |
|
), |
|
dim=1, |
|
) |
|
|
|
return return_value |
|
|
|
|
|
def input_scalers( |
|
surf_vars: list[str], |
|
vert_vars: list[str], |
|
levels: list[float], |
|
surf_path: str | Path, |
|
vert_path: str | Path, |
|
) -> tuple[Tensor, Tensor]: |
|
"""Reads the input scalers |
|
|
|
Args: |
|
surf_vars: surface variables to be used. |
|
vert_vars: vertical variables to be used. |
|
levels: MERRA2 levels to use. |
|
surf_path: path to surface scalers file. |
|
vert_path: path to vertical level scalers file. |
|
|
|
Returns: |
|
mu (Tensor): mean values |
|
var (Tensor): varience values |
|
""" |
|
with h5py.File(Path(surf_path), "r", libver="latest") as surf_file: |
|
stats = [x.decode().lower() for x in surf_file["statistic"][()]] |
|
mu_idx = stats.index("mu") |
|
sig_idx = stats.index("sigma") |
|
|
|
s_mu = torch.tensor([surf_file[k][()][mu_idx] for k in surf_vars]) |
|
s_sig = torch.tensor([surf_file[k][()][sig_idx] for k in surf_vars]) |
|
|
|
with h5py.File(Path(vert_path), "r", libver="latest") as vert_file: |
|
stats = [x.decode().lower() for x in vert_file["statistic"][()]] |
|
mu_idx = stats.index("mu") |
|
sig_idx = stats.index("sigma") |
|
|
|
lvl = vert_file["lev"][()] |
|
l_idx = [np.where(lvl == v)[0].item() for v in levels] |
|
|
|
v_mu = np.array([vert_file[k][()][mu_idx, l_idx] for k in vert_vars]) |
|
v_sig = np.array([vert_file[k][()][sig_idx, l_idx] for k in vert_vars]) |
|
|
|
v_mu = torch.from_numpy(v_mu).view(-1) |
|
v_sig = torch.from_numpy(v_sig).view(-1) |
|
|
|
mu = torch.cat((s_mu, v_mu), dim=0).to(torch.float32) |
|
sig = torch.cat((s_sig, v_sig), dim=0).to(torch.float32).clamp(1e-4, 1e4) |
|
return mu, sig |
|
|
|
|
|
def static_input_scalers( |
|
scalar_path: str | Path, stat_vars: list[str], unscaled_params: int = 7 |
|
) -> tuple[Tensor, Tensor]: |
|
scalar_path = Path(scalar_path) |
|
|
|
with h5py.File(scalar_path, "r", libver="latest") as scaler_file: |
|
stats = [x.decode().lower() for x in scaler_file["statistic"][()]] |
|
mu_idx = stats.index("mu") |
|
sig_idx = stats.index("sigma") |
|
|
|
mu = torch.tensor([scaler_file[k][()][mu_idx] for k in stat_vars]) |
|
sig = torch.tensor([scaler_file[k][()][sig_idx] for k in stat_vars]) |
|
|
|
z = torch.zeros(unscaled_params, dtype=mu.dtype, device=mu.device) |
|
o = torch.ones(unscaled_params, dtype=sig.dtype, device=sig.device) |
|
mu = torch.cat((z, mu), dim=0).to(torch.float32) |
|
sig = torch.cat((o, sig), dim=0).to(torch.float32) |
|
|
|
return mu, sig.clamp(1e-4, 1e4) |
|
|
|
|
|
def output_scalers( |
|
surf_vars: list[str], |
|
vert_vars: list[str], |
|
levels: list[float], |
|
surf_path: str | Path, |
|
vert_path: str | Path, |
|
) -> Tensor: |
|
surf_path = Path(surf_path) |
|
vert_path = Path(vert_path) |
|
|
|
with h5py.File(surf_path, "r", libver="latest") as surf_file: |
|
svars = torch.tensor([surf_file[k][()] for k in surf_vars]) |
|
|
|
with h5py.File(vert_path, "r", libver="latest") as vert_file: |
|
lvl = vert_file["lev"][()] |
|
l_idx = [np.where(lvl == v)[0].item() for v in levels] |
|
vvars = np.array([vert_file[k][()][l_idx] for k in vert_vars]) |
|
vvars = torch.from_numpy(vvars).view(-1) |
|
|
|
var = torch.cat((svars, vvars), dim=0).to(torch.float32).clamp(1e-7, 1e7) |
|
|
|
return var |
|
|
|
|
|
class SampleSpec: |
|
""" |
|
A data class to collect the information used to define a sample. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
inputs: tuple[pd.Timestamp, pd.Timestamp], |
|
lead_time: int, |
|
target: pd.Timestamp | list[pd.Timestamp], |
|
): |
|
""" |
|
Args: |
|
inputs: Tuple of timestamps. In ascending order. |
|
lead_time: Lead time. In hours. |
|
target: Timestamp of the target. Can be before or after the inputs. |
|
""" |
|
if not inputs[0] < inputs[1]: |
|
raise ValueError( |
|
"Timestamps in `inputs` should be in strictly ascending order." |
|
) |
|
|
|
self.inputs = inputs |
|
self.input_time = (inputs[1] - inputs[0]).total_seconds() / 3600 |
|
self.lead_time = lead_time |
|
self.target = target |
|
|
|
self.times = [*inputs, target] |
|
self.stat_times = [inputs[-1]] |
|
|
|
@property |
|
def climatology_info(self) -> tuple[int, int]: |
|
"""Get the required climatology info. |
|
|
|
:return: information required to obtain climatology data. Essentially |
|
this is the day of the year and hour of the day of the target |
|
timestamp, with the former restricted to the interval [1, 365]. |
|
:rtype: tuple |
|
""" |
|
return (min(self.target.dayofyear, 365), self.target.hour) |
|
|
|
@property |
|
def year(self) -> int: |
|
return self.inputs[1].year |
|
|
|
@property |
|
def dayofyear(self) -> int: |
|
return self.inputs[1].dayofyear |
|
|
|
@property |
|
def hourofday(self) -> int: |
|
return self.inputs[1].hour |
|
|
|
def _info_str(self) -> str: |
|
iso_8601 = "%Y-%m-%dT%H:%M:%S" |
|
|
|
return ( |
|
f"Issue time: {self.inputs[1].strftime(iso_8601)}\n" |
|
f"Lead time: {self.lead_time} hours ahead\n" |
|
f"Input delta: {self.input_time} hours\n" |
|
f"Target time: {self.target.strftime(iso_8601)}" |
|
) |
|
|
|
@classmethod |
|
def get(cls, timestamp: pd.Timestamp, dt: int, lead_time: int): |
|
"""Given a timestamp and lead time, generates a SampleSpec object |
|
describing the sample further. |
|
|
|
Args: |
|
timestamp: Timstamp of the sample, Ie this is the larger of the two |
|
input timstamps. |
|
dt: Time between input samples, in hours. |
|
lead_time: Lead time. In hours. |
|
|
|
Returns: |
|
SampleSpec |
|
""" |
|
assert dt > 0, "dt should be possitive" |
|
lt = pd.to_timedelta(lead_time, unit="h") |
|
dt = pd.to_timedelta(dt, unit="h") |
|
|
|
if lead_time >= 0: |
|
timestamp_target = timestamp + lt |
|
else: |
|
timestamp_target = timestamp - dt + lt |
|
|
|
spec = cls( |
|
inputs=(timestamp - dt, timestamp), |
|
lead_time=lead_time, |
|
target=timestamp_target, |
|
) |
|
|
|
return spec |
|
|
|
def __repr__(self) -> str: |
|
return self._info_str() |
|
|
|
def __str__(self) -> str: |
|
return self._info_str() |
|
|
|
|
|
class Merra2Dataset(Dataset): |
|
"""MERRA2 dataset. The dataset unifies surface and vertical data as well as |
|
optional climatology. |
|
|
|
Samples come in the form of a dictionary. Not all keys support all |
|
variables, yet the general ordering of dimensions is |
|
parameter, level, time, lat, lon |
|
|
|
Note: |
|
Data is assumed to be in NetCDF files containing daily data at 3-hourly |
|
intervals. These follow the naming patterns |
|
MERRA2_sfc_YYYYMMHH.nc and MERRA_pres_YYYYMMHH.nc and can be located in |
|
two different locations. Optional climatology data comes from files |
|
climate_surface_doyDOY_hourHOD.nc and |
|
climate_vertical_doyDOY_hourHOD.nc. |
|
|
|
|
|
Note: |
|
`_get_valid_timestamps` assembles a set of all timestamps for which |
|
there is data (with hourly resolutions). The result is stored in |
|
`_valid_timestamps`. `_get_valid_climate_timestamps` does the same with |
|
climatology data and stores it in `_valid_climate_timestamps`. |
|
|
|
Based on this information, `samples` generates a list of valid samples, |
|
stored in `samples`. Here the format is:: |
|
|
|
[ |
|
[ |
|
(timestamp 1, lead time A), |
|
(timestamp 1, lead time B), |
|
(timestamp 1, lead time C), |
|
], |
|
[ |
|
(timestamp 2, lead time D), |
|
(timestamp 2, lead time E), |
|
] |
|
] |
|
|
|
That is, the outer list iterates over timestamps (init times), the |
|
inner over lead times. Only valid entries are stored. |
|
""" |
|
|
|
valid_vertical_vars = [ |
|
"CLOUD", |
|
"H", |
|
"OMEGA", |
|
"PL", |
|
"QI", |
|
"QL", |
|
"QV", |
|
"T", |
|
"U", |
|
"V", |
|
] |
|
valid_surface_vars = [ |
|
"EFLUX", |
|
"GWETROOT", |
|
"HFLUX", |
|
"LAI", |
|
"LWGAB", |
|
"LWGEM", |
|
"LWTUP", |
|
"PRECTOT", |
|
"PS", |
|
"QV2M", |
|
"SLP", |
|
"SWGNT", |
|
"SWTNT", |
|
"T2M", |
|
"TQI", |
|
"TQL", |
|
"TQV", |
|
"TS", |
|
"U10M", |
|
"V10M", |
|
"Z0M", |
|
] |
|
valid_static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"] |
|
|
|
valid_levels = [ |
|
34.0, |
|
39.0, |
|
41.0, |
|
43.0, |
|
44.0, |
|
45.0, |
|
48.0, |
|
51.0, |
|
53.0, |
|
56.0, |
|
63.0, |
|
68.0, |
|
71.0, |
|
72.0, |
|
] |
|
|
|
timedelta_input = pd.to_timedelta(3, unit="h") |
|
|
|
def __init__( |
|
self, |
|
time_range: tuple[str | pd.Timestamp, str | pd.Timestamp], |
|
lead_times: list[int], |
|
input_times: list[int], |
|
data_path_surface: str | Path, |
|
data_path_vertical: str | Path, |
|
climatology_path_surface: str | Path | None = None, |
|
climatology_path_vertical: str | Path | None = None, |
|
surface_vars: list[str] | None = None, |
|
static_surface_vars: list[str] | None = None, |
|
vertical_vars: list[str] | None = None, |
|
levels: list[float] | None = None, |
|
roll_longitudes: int = 0, |
|
positional_encoding: str = "absolute", |
|
rtype: type = np.float32, |
|
dtype: torch.dtype = torch.float32, |
|
) -> None: |
|
""" |
|
Args: |
|
data_path_surface: Location of surface data. |
|
data_path_vertical: Location of vertical data. |
|
climatology_path_surface: Location of (optional) surface |
|
climatology. |
|
climatology_path_vertical: Location of (optional) vertical |
|
climatology. |
|
surface_vars: Surface variables. |
|
static_surface_vars: Static surface variables. |
|
vertical_vars: Vertical variables. |
|
levels: Levels. |
|
time_range: Used to subset data. |
|
lead_times: Lead times for generalized forecasting. |
|
roll_longitudes: Set to non-zero value to data by random amount |
|
along longitude dimension. |
|
position_encoding: possible values are |
|
['absolute' (default), 'fourier']. |
|
'absolute' returns lat lon encoded in 3 dimensions using sine |
|
and cosine |
|
'fourier' returns lat/lon to be encoded by model |
|
<any other key> returns lat/lon to be encoded by model |
|
rtype: numpy data type used during read |
|
dtype: torch data type of data output |
|
""" |
|
|
|
self.time_range = ( |
|
pd.to_datetime(time_range[0]), |
|
pd.to_datetime(time_range[1]), |
|
) |
|
self.lead_times = lead_times |
|
self.input_times = input_times |
|
self._roll_longitudes = list(range(roll_longitudes + 1)) |
|
|
|
self._uvars = vertical_vars or self.valid_vertical_vars |
|
self._level = levels or self.valid_levels |
|
self._svars = surface_vars or self.valid_surface_vars |
|
self._sstat = static_surface_vars or self.valid_static_surface_vars |
|
self._nuvars = len(self._uvars) |
|
self._nlevel = len(self._level) |
|
self._nsvars = len(self._svars) |
|
self._nsstat = len(self._sstat) |
|
|
|
self.rtype = rtype |
|
self.dtype = dtype |
|
|
|
self.positional_encoding = positional_encoding |
|
|
|
self._data_path_surface = Path(data_path_surface) |
|
self._data_path_vertical = Path(data_path_vertical) |
|
|
|
self.dir_exists(self._data_path_surface) |
|
self.dir_exists(self._data_path_vertical) |
|
|
|
self._get_coordinates() |
|
|
|
self._climatology_path_surface = Path(climatology_path_surface) or None |
|
self._climatology_path_vertical = ( |
|
Path(climatology_path_vertical) or None |
|
) |
|
self._require_clim = ( |
|
self._climatology_path_surface is not None |
|
and self._climatology_path_vertical is not None |
|
) |
|
|
|
if self._require_clim: |
|
self.dir_exists(self._climatology_path_surface) |
|
self.dir_exists(self._climatology_path_vertical) |
|
elif ( |
|
climatology_path_surface is None |
|
and climatology_path_vertical is None |
|
): |
|
self._climatology_path_surface = None |
|
self._climatology_path_vertical = None |
|
else: |
|
raise ValueError( |
|
"Either both or neither of" |
|
"`climatology_path_surface` and" |
|
"`climatology_path_vertical` should be None." |
|
) |
|
|
|
if not set(self._svars).issubset(set(self.valid_surface_vars)): |
|
raise ValueError("Invalid surface variable.") |
|
|
|
if not set(self._sstat).issubset(set(self.valid_static_surface_vars)): |
|
raise ValueError("Invalid static surface variable.") |
|
|
|
if not set(self._uvars).issubset(set(self.valid_vertical_vars)): |
|
raise ValueError("Inalid vertical variable.") |
|
|
|
if not set(self._level).issubset(set(self.valid_levels)): |
|
raise ValueError("Invalid level.") |
|
|
|
@staticmethod |
|
def dir_exists(path: Path) -> None: |
|
if not path.is_dir(): |
|
raise ValueError(f"Directory {path} does not exist.") |
|
|
|
@property |
|
def upper_shape(self) -> tuple: |
|
"""Returns the vertical variables shape |
|
Returns: |
|
tuple: vertical variable shape in the following order:: |
|
|
|
[VAR, LEV, TIME, LAT, LON] |
|
""" |
|
return self._nuvars, self._nlevel, 2, 361, 576 |
|
|
|
@property |
|
def surface_shape(self) -> tuple: |
|
"""Returns the surface variables shape |
|
|
|
Returns: |
|
tuple: surafce shape in the following order:: |
|
|
|
[VAR, LEV, TIME, LAT, LON] |
|
""" |
|
return self._nsvars, 2, 361, 576 |
|
|
|
def data_file_surface(self, timestamp: pd.Timestamp) -> Path: |
|
"""Build the surfcae data file name based on timestamp |
|
|
|
Args: |
|
timestamp: a timestamp |
|
|
|
Returns: |
|
Path: constructed path |
|
""" |
|
pattern = "MERRA2_sfc_%Y%m%d.nc" |
|
data_file = self._data_path_surface / timestamp.strftime(pattern) |
|
return data_file |
|
|
|
def data_file_vertical(self, timestamp: pd.Timestamp) -> Path: |
|
"""Build the vertical data file name based on timestamp |
|
|
|
Args: |
|
timestamp: a timestamp |
|
|
|
Returns: |
|
Path: constructed path |
|
""" |
|
pattern = "MERRA_pres_%Y%m%d.nc" |
|
data_file = self._data_path_vertical / timestamp.strftime(pattern) |
|
return data_file |
|
|
|
def data_file_surface_climate( |
|
self, |
|
timestamp: pd.Timestamp | None = None, |
|
dayofyear: int | None = None, |
|
hourofday: int | None = None, |
|
) -> Path: |
|
""" |
|
Returns the path to a climatology file based either on a timestamp or |
|
the dayofyear / hourofday combination. |
|
Args: |
|
timestamp: A timestamp. |
|
dayofyear: Day of the year. 1 to 366. |
|
hourofday: Hour of the day. 0 to 23. |
|
Returns: |
|
Path: Path to climatology file. |
|
""" |
|
if timestamp is not None and ( |
|
(dayofyear is not None) or (hourofday is not None) |
|
): |
|
raise ValueError( |
|
"Provide either timestamp or both dayofyear and hourofday." |
|
) |
|
|
|
if timestamp is not None: |
|
dayofyear = min(timestamp.dayofyear, 365) |
|
hourofday = timestamp.hour |
|
|
|
file_name = f"climate_surface_doy{dayofyear:03}_hour{hourofday:02}.nc" |
|
data_file = self._climatology_path_surface / file_name |
|
return data_file |
|
|
|
def data_file_vertical_climate( |
|
self, |
|
timestamp: pd.Timestamp | None = None, |
|
dayofyear: int | None = None, |
|
hourofday: int | None = None, |
|
) -> Path: |
|
"""Returns the path to a climatology file based either on a timestamp |
|
or the dayofyear / hourofday combination. |
|
|
|
Args: |
|
timestamp: A timestamp. dayofyear: Day of the year. 1 to 366. |
|
hourofday: Hour of the day. 0 to 23. |
|
Returns: |
|
Path: Path to climatology file. |
|
""" |
|
if timestamp is not None and ( |
|
(dayofyear is not None) or (hourofday is not None) |
|
): |
|
raise ValueError( |
|
"Provide either timestamp or both dayofyear and hourofday." |
|
) |
|
|
|
if timestamp is not None: |
|
dayofyear = min(timestamp.dayofyear, 365) |
|
hourofday = timestamp.hour |
|
|
|
file_name = f"climate_vertical_doy{dayofyear:03}_hour{hourofday:02}.nc" |
|
data_file = self._climatology_path_vertical / file_name |
|
return data_file |
|
|
|
def _get_coordinates(self) -> None: |
|
""" |
|
Obtains the coordiantes (latitudes and longitudes) from a single data |
|
file. |
|
""" |
|
timestamp = next(iter(self.valid_timestamps)) |
|
|
|
file = self.data_file_surface(timestamp) |
|
with h5py.File(file, "r", libver="latest") as handle: |
|
self.lats = lats = handle["lat"][()].astype(self.rtype) |
|
self.lons = lons = handle["lon"][()].astype(self.rtype) |
|
|
|
deg_to_rad = np.pi / 180 |
|
self._embed_lat = np.sin(lats * deg_to_rad).reshape(-1, 1) |
|
|
|
self._embed_lon = np.empty((2, 1, len(lons)), dtype=self.rtype) |
|
self._embed_lon[0, 0] = np.cos(lons * deg_to_rad) |
|
self._embed_lon[1, 0] = np.sin(lons * deg_to_rad) |
|
|
|
@ft.cached_property |
|
def lats(self) -> np.ndarray: |
|
timestamp = next(iter(self.valid_timestamps)) |
|
|
|
file = self.data_file_surface(timestamp) |
|
with h5py.File(file, "r", libver="latest") as handle: |
|
return handle["lat"][()].astype(self.rtype) |
|
|
|
@ft.cached_property |
|
def lons(self) -> np.ndarray: |
|
timestamp = next(iter(self.valid_timestamps)) |
|
|
|
file = self.data_file_surface(timestamp) |
|
with h5py.File(file, "r", libver="latest") as handle: |
|
return handle["lon"][()].astype(self.rtype) |
|
|
|
@ft.cached_property |
|
def position_signal(self) -> np.ndarray: |
|
"""Generates the "position signal" that is part of the static |
|
features. |
|
|
|
Returns: |
|
Tensor: Torch tensor of dimension (parameter, lat, lon) containing |
|
sin(lat), cos(lon), sin(lon). |
|
""" |
|
|
|
latitudes, longitudes = np.meshgrid( |
|
self.lats, self.lons, indexing="ij" |
|
) |
|
|
|
if self.positional_encoding == "absolute": |
|
latitudes = latitudes / 360 * 2.0 * np.pi |
|
longitudes = longitudes / 360 * 2.0 * np.pi |
|
sur_static = np.stack( |
|
[np.sin(latitudes), np.cos(longitudes), np.sin(longitudes)], |
|
axis=0, |
|
) |
|
else: |
|
sur_static = np.stack([latitudes, longitudes], axis=0) |
|
|
|
sur_static = sur_static.astype(self.rtype) |
|
|
|
return sur_static |
|
|
|
@ft.cached_property |
|
def valid_timestamps(self) -> set[pd.Timestamp]: |
|
"""Generates list of valid timestamps based on available files. Only |
|
timestamps for which both surface and vertical information is available |
|
are considered valid. |
|
Returns: |
|
list: list of timestamps |
|
""" |
|
|
|
s_glob = self._data_path_surface.glob("MERRA2_sfc_????????.nc") |
|
s_files = [os.path.basename(f) for f in s_glob] |
|
v_glob = self._data_path_surface.glob("MERRA_pres_????????.nc") |
|
v_files = [os.path.basename(f) for f in v_glob] |
|
|
|
s_re = re.compile(r"MERRA2_sfc_(\d{8}).nc\Z") |
|
v_re = re.compile(r"MERRA_pres_(\d{8}).nc\Z") |
|
fmt = "%Y%m%d" |
|
|
|
s_times = { |
|
(datetime.strptime(m[1], fmt)) |
|
for f in s_files |
|
if (m := s_re.match(f)) |
|
} |
|
v_times = { |
|
(datetime.strptime(m[1], fmt)) |
|
for f in v_files |
|
if (m := v_re.match(f)) |
|
} |
|
|
|
times = s_times.intersection(v_times) |
|
|
|
|
|
times = { |
|
t + timedelta(hours=i) for i in range(0, 24, 3) for t in times |
|
} |
|
|
|
start_time, end_time = self.time_range |
|
times = {pd.Timestamp(t) for t in times if start_time <= t <= end_time} |
|
|
|
return times |
|
|
|
@ft.cached_property |
|
def valid_climate_timestamps(self) -> set[tuple[int, int]]: |
|
"""Generates list of "timestamps" (dayofyear, hourofday) for which |
|
climatology data is present. Only instances for which surface and |
|
vertical data is available are considered valid. |
|
Returns: |
|
list: List of tuples describing valid climatology instances. |
|
""" |
|
if not self._require_clim: |
|
return set() |
|
|
|
s_glob = self._climatology_path_surface.glob( |
|
"climate_surface_doy???_hour??.nc" |
|
) |
|
s_files = [os.path.basename(f) for f in s_glob] |
|
|
|
v_glob = self._climatology_path_vertical.glob( |
|
"climate_vertical_doy???_hour??.nc" |
|
) |
|
v_files = [os.path.basename(f) for f in v_glob] |
|
|
|
s_re = re.compile(r"climate_surface_doy(\d{3})_hour(\d{2}).nc\Z") |
|
v_re = re.compile(r"climate_vertical_doy(\d{3})_hour(\d{2}).nc\Z") |
|
|
|
s_times = { |
|
(int(m[1]), int(m[2])) for f in s_files if (m := s_re.match(f)) |
|
} |
|
v_times = { |
|
(int(m[1]), int(m[2])) for f in v_files if (m := v_re.match(f)) |
|
} |
|
|
|
times = s_times.intersection(v_times) |
|
|
|
return times |
|
|
|
def _data_available(self, spec: SampleSpec) -> bool: |
|
""" |
|
Checks whether data is available for a given SampleSpec object. Does so |
|
using the internal sets with available data previously constructed. Not |
|
by checking the file system. |
|
Args: |
|
spec: SampleSpec object as returned by SampleSpec.get |
|
Returns: |
|
bool: if data is availability. |
|
""" |
|
valid = set(spec.times).issubset(self.valid_timestamps) |
|
|
|
if self._require_clim: |
|
sci = spec.climatology_info |
|
ci = set(sci) if isinstance(sci, list) else set([sci]) |
|
valid &= ci.issubset(self.valid_climate_timestamps) |
|
|
|
return valid |
|
|
|
@ft.cached_property |
|
def samples(self) -> list[tuple[pd.Timestamp, int, int]]: |
|
""" |
|
Generates list of all valid samlpes. |
|
Returns: |
|
list: List of tuples (timestamp, input time, lead time). |
|
""" |
|
valid_samples = [] |
|
dts = [(it, lt) for it in self.input_times for lt in self.lead_times] |
|
|
|
for timestamp in sorted(self.valid_timestamps): |
|
timestamp_samples = [] |
|
for it, lt in dts: |
|
spec = SampleSpec.get(timestamp, -it, lt) |
|
|
|
if self._data_available(spec): |
|
timestamp_samples.append((timestamp, it, lt)) |
|
|
|
if timestamp_samples: |
|
valid_samples.append(timestamp_samples) |
|
|
|
return valid_samples |
|
|
|
def _to_torch( |
|
self, |
|
data: dict[str, Tensor | list[Tensor]], |
|
dtype: torch.dtype = torch.float32, |
|
) -> dict[str, Tensor | list[Tensor]]: |
|
out = {} |
|
for k, v in data.items(): |
|
if isinstance(v, list): |
|
out[k] = [torch.from_numpy(x).to(dtype) for x in v] |
|
else: |
|
out[k] = torch.from_numpy(v).to(dtype) |
|
|
|
return out |
|
|
|
def _lat_roll( |
|
self, data: dict[str, Tensor | list[Tensor]], n: int |
|
) -> dict[str, Tensor | list[Tensor]]: |
|
out = {} |
|
for k, v in data.items(): |
|
if isinstance(v, list): |
|
out[k] = [torch.roll(x, shifts=n, dims=-1) for x in v] |
|
else: |
|
out[k] = torch.roll(v, shifts=n, dims=-1) |
|
|
|
return out |
|
|
|
def _read_static_data( |
|
self, file: str | Path, doy: int, hod: int |
|
) -> np.ndarray: |
|
with h5py.File(file, "r", libver="latest") as handle: |
|
lats_surf = handle["lat"] |
|
lons_surf = handle["lon"] |
|
|
|
nll = (len(lats_surf), len(lons_surf)) |
|
|
|
npos = len(self.position_signal) |
|
ntime = 4 |
|
|
|
nstat = npos + ntime + self._nsstat |
|
data = np.empty((nstat, *nll), dtype=self.rtype) |
|
|
|
for i, key in enumerate(self._sstat, start=npos + ntime): |
|
data[i] = handle[key][()].astype(dtype=self.rtype) |
|
|
|
|
|
data[0:npos] = self.position_signal |
|
data[npos + 0] = np.cos(2 * np.pi * doy / 366) |
|
data[npos + 1] = np.sin(2 * np.pi * doy / 366) |
|
data[npos + 2] = np.cos(2 * np.pi * hod / 24) |
|
data[npos + 3] = np.sin(2 * np.pi * hod / 24) |
|
|
|
return data |
|
|
|
def _read_surface( |
|
self, tidx: int, nll: tuple[int, int], handle: h5py.File |
|
) -> np.ndarray: |
|
data = np.empty((self._nsvars, *nll), dtype=self.rtype) |
|
|
|
for i, key in enumerate(self._svars): |
|
data[i] = handle[key][tidx][()].astype(dtype=self.rtype) |
|
|
|
return data |
|
|
|
def _read_levels( |
|
self, tidx: int, nll: tuple[int, int], handle: h5py.File |
|
) -> np.ndarray: |
|
lvls = handle["lev"][()] |
|
lidx = self._level_idxs(lvls) |
|
|
|
data = np.empty((self._nuvars, self._nlevel, *nll), dtype=self.rtype) |
|
|
|
for i, key in enumerate(self._uvars): |
|
data[i] = handle[key][tidx, lidx][()].astype(dtype=self.rtype) |
|
|
|
return np.ascontiguousarray(np.flip(data, axis=1)) |
|
|
|
def _level_idxs(self, lvls): |
|
lidx = [np.argwhere(lvls == int(lvl)).item() for lvl in self._level] |
|
return sorted(lidx) |
|
|
|
@staticmethod |
|
def _date_to_tidx(date: datetime | pd.Timestamp, handle: h5py.File) -> int: |
|
if isinstance(date, pd.Timestamp): |
|
date = date.to_pydatetime() |
|
|
|
time = handle["time"] |
|
|
|
t0 = time.attrs["begin_time"][()].item() |
|
d0 = f"{time.attrs['begin_date'][()].item()}" |
|
|
|
offset = datetime.strptime(d0, "%Y%m%d") |
|
|
|
times = [offset + timedelta(minutes=int(t + t0)) for t in time[()]] |
|
return times.index(date) |
|
|
|
def _read_data( |
|
self, file_pair: tuple[str, str], date: datetime |
|
) -> dict[str, np.ndarray]: |
|
s_file, v_file = file_pair |
|
|
|
with h5py.File(s_file, "r", libver="latest") as shandle: |
|
lats_surf = shandle["lat"] |
|
lons_surf = shandle["lon"] |
|
|
|
nll = (len(lats_surf), len(lons_surf)) |
|
|
|
tidx = self._date_to_tidx(date, shandle) |
|
|
|
sdata = self._read_surface(tidx, nll, shandle) |
|
|
|
with h5py.File(v_file, "r", libver="latest") as vhandle: |
|
lats_vert = vhandle["lat"] |
|
lons_vert = vhandle["lon"] |
|
|
|
nll = (len(lats_vert), len(lons_vert)) |
|
|
|
tidx = self._date_to_tidx(date, vhandle) |
|
|
|
vdata = self._read_levels(tidx, nll, vhandle) |
|
|
|
data = {"vert": vdata, "surf": sdata} |
|
|
|
return data |
|
|
|
def _read_climate( |
|
self, file_pair: tuple[str, str] |
|
) -> dict[str, np.ndarray]: |
|
s_file, v_file = file_pair |
|
|
|
with h5py.File(s_file, "r", libver="latest") as shandle: |
|
lats_surf = shandle["lat"] |
|
lons_surf = shandle["lon"] |
|
|
|
nll = (len(lats_surf), len(lons_surf)) |
|
|
|
sdata = np.empty((self._nsvars, *nll), dtype=self.rtype) |
|
|
|
for i, key in enumerate(self._svars): |
|
sdata[i] = shandle[key][()].astype(dtype=self.rtype) |
|
|
|
with h5py.File(v_file, "r", libver="latest") as vhandle: |
|
lats_vert = vhandle["lat"] |
|
lons_vert = vhandle["lon"] |
|
|
|
nll = (len(lats_vert), len(lons_vert)) |
|
|
|
lvls = vhandle["lev"][()] |
|
lidx = self._level_idxs(lvls) |
|
|
|
vdata = np.empty( |
|
(self._nuvars, self._nlevel, *nll), dtype=self.rtype |
|
) |
|
|
|
for i, key in enumerate(self._uvars): |
|
vdata[i] = vhandle[key][lidx][()].astype(dtype=self.rtype) |
|
|
|
data = { |
|
"vert": np.ascontiguousarray(np.flip(vdata, axis=1)), |
|
"surf": sdata, |
|
} |
|
|
|
return data |
|
|
|
def get_data_from_sample_spec( |
|
self, spec: SampleSpec |
|
) -> dict[str, Tensor | int | float]: |
|
"""Loads and assembles sample data given a SampleSpec object. |
|
|
|
Args: |
|
spec (SampleSpec): Full details regarding the data to be loaded |
|
Returns: |
|
dict: Dictionary with the following keys:: |
|
|
|
'sur_static': Torch tensor of shape [parameter, lat, lon]. For |
|
each pixel (lat, lon), the first 7 dimensions index sin(lat), |
|
cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod). |
|
Where doy is the day of the year [1, 366] and hod the hour of |
|
the day [0, 23]. |
|
'sur_vals': Torch tensor of shape [parameter, time, lat, lon]. |
|
'sur_tars': Torch tensor of shape [parameter, time, lat, lon]. |
|
'ulv_vals': Torch tensor of shape [parameter, level, time, lat, lon]. |
|
'ulv_tars': Torch tensor of shape [parameter, level, time, lat, lon]. |
|
'sur_climate': Torch tensor of shape [parameter, lat, lon]. |
|
'ulv_climate': Torch tensor of shape [paramter, level, lat, lon]. |
|
'lead_time': Float. |
|
'input_time': Float. |
|
|
|
""" |
|
|
|
|
|
vals_required = {*spec.times} |
|
stat_required = {*spec.stat_times} |
|
|
|
|
|
vals_file_map = defaultdict(list) |
|
for t in vals_required: |
|
data_files = ( |
|
self.data_file_surface(t), |
|
self.data_file_vertical(t), |
|
) |
|
vals_file_map[data_files].append(t) |
|
|
|
|
|
stat_file_map = defaultdict(list) |
|
for t in stat_required: |
|
data_files = ( |
|
self.data_file_surface(t), |
|
self.data_file_vertical(t), |
|
) |
|
stat_file_map[data_files].append(t) |
|
|
|
|
|
data = {} |
|
for data_files, times in vals_file_map.items(): |
|
for time in times: |
|
data[time] = self._read_data(data_files, time) |
|
|
|
|
|
sample_data = {} |
|
|
|
input_upl = np.stack([data[t]["vert"] for t in spec.inputs], axis=2) |
|
sample_data["ulv_vals"] = input_upl |
|
|
|
target_upl = data[spec.target]["vert"] |
|
sample_data["ulv_tars"] = target_upl[:, :, None] |
|
|
|
input_sur = np.stack([data[t]["surf"] for t in spec.inputs], axis=1) |
|
sample_data["sur_vals"] = input_sur |
|
|
|
target_sur = data[spec.target]["surf"] |
|
sample_data["sur_tars"] = target_sur[:, None] |
|
|
|
|
|
data_files, times = stat_file_map.popitem() |
|
time = times[0].dayofyear, times[0].hour |
|
sample_data["sur_static"] = self._read_static_data( |
|
data_files[0], *time |
|
) |
|
|
|
|
|
if self._require_clim: |
|
ci_year, ci_hour = spec.climatology_info |
|
|
|
surf_file = self.data_file_surface_climate( |
|
dayofyear=ci_year, |
|
hourofday=ci_hour, |
|
) |
|
|
|
vert_file = self.data_file_vertical_climate( |
|
dayofyear=ci_year, |
|
hourofday=ci_hour, |
|
) |
|
|
|
clim_data = self._read_climate((surf_file, vert_file)) |
|
|
|
sample_data["sur_climate"] = clim_data["surf"] |
|
sample_data["ulv_climate"] = clim_data["vert"] |
|
|
|
|
|
sample_data = self._to_torch(sample_data, dtype=self.dtype) |
|
|
|
|
|
if len(self._roll_longitudes) > 0: |
|
roll_by = random.choice(self._roll_longitudes) |
|
sample_data = self._lat_roll(sample_data, roll_by) |
|
|
|
|
|
sample_data["lead_time"] = spec.lead_time |
|
sample_data["input_time"] = spec.input_time |
|
|
|
return sample_data |
|
|
|
def get_data( |
|
self, timestamp: pd.Timestamp, input_time: int, lead_time: int |
|
) -> dict[str, Tensor | int]: |
|
""" |
|
Loads data based on timestamp and lead time. |
|
Args: |
|
timestamp: Timestamp. |
|
input_time: time between input samples. |
|
lead_time: lead time. |
|
Returns: |
|
Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars', |
|
'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate', |
|
'lead_time'. |
|
""" |
|
spec = SampleSpec.get(timestamp, -input_time, lead_time) |
|
sample_data = self.get_data_from_sample_spec(spec) |
|
return sample_data |
|
|
|
def __getitem__(self, idx: int) -> dict[str, Tensor | int]: |
|
""" |
|
Loads data based on sample index and random choice of sample. |
|
Args: |
|
idx: Sample index. |
|
Returns: |
|
Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars', |
|
'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate', |
|
'lead_time', 'input_time'. |
|
""" |
|
sample_set = self.samples[idx] |
|
timestamp, input_time, lead_time, *nsteps = random.choice(sample_set) |
|
sample_data = self.get_data(timestamp, input_time, lead_time) |
|
return sample_data |
|
|
|
def __len__(self): |
|
return len(self.samples) |
|
|
|
from functools import cached_property |
|
from importlib.metadata import version |
|
|
|
from torch import Tensor |
|
from torch.utils.checkpoint import checkpoint |
|
|
|
if version("torch") > "2.3.0": |
|
from torch.nn.attention import SDPBackend, sdpa_kernel |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
|
def drop_path( |
|
x: Tensor, |
|
drop_prob: float = 0.0, |
|
training: bool = False, |
|
scale_by_keep: bool = True, |
|
) -> Tensor: |
|
"""Drop paths (Stochastic Depth) per sample (when applied in main path of |
|
residual blocks). Taken form timm. |
|
|
|
Args: |
|
x (Tensor): Input tensor. |
|
drop_prob (float): Probability of dropping `x`, defaults to 0. |
|
training (bool): Whether model is in in traingin of eval mode, |
|
defaults to False. |
|
scale_by_keep (bool): Whether the output should scaled by |
|
(`1 - drop_prob`), defaults to True. |
|
Returns: |
|
Tensor: Tensor that may have randomly dropped with proability |
|
`drop_path` |
|
""" |
|
if drop_prob == 0.0 or not training: |
|
return x |
|
keep_prob = 1 - drop_prob |
|
shape = (x.shape[0],) + (1,) * (x.ndim - 1) |
|
random_tensor = x.new_empty(shape).bernoulli_(keep_prob) |
|
if keep_prob > 0.0 and scale_by_keep: |
|
random_tensor.div_(keep_prob) |
|
return x * random_tensor |
|
|
|
|
|
class DropPath(nn.Module): |
|
""" |
|
Drop paths (Stochastic Depth) per sample (when applied in main path of |
|
residual blocks). |
|
""" |
|
|
|
def __init__( |
|
self, drop_prob: float | None = None, scale_by_keep: bool = True |
|
) -> None: |
|
super(DropPath, self).__init__() |
|
self.drop_prob = drop_prob |
|
self.scale_by_keep = scale_by_keep |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
"""Runs drop path on input tensor |
|
|
|
Args: |
|
x: input |
|
|
|
Returns: |
|
tensor: output after drop_path |
|
""" |
|
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) |
|
|
|
|
|
class Mlp(nn.Module): |
|
""" |
|
Multi layer perceptron. |
|
""" |
|
|
|
def __init__( |
|
self, features: int, hidden_features: int, dropout: float = 0.0 |
|
) -> None: |
|
""" |
|
Args: |
|
features: Input/output dimension. |
|
hidden_features: Hidden dimension. |
|
dropout: Dropout. |
|
""" |
|
super().__init__() |
|
self.net = nn.Sequential( |
|
nn.Linear(features, hidden_features), |
|
nn.GELU(), |
|
nn.Dropout(dropout), |
|
nn.Linear(hidden_features, features), |
|
nn.Dropout(dropout), |
|
) |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
""" |
|
Args: |
|
x (Tesnor): Tensor of shape [..., channel] |
|
Returns: |
|
Tenosr: Tensor of same shape as x. |
|
""" |
|
return self.net(x) |
|
|
|
|
|
class LayerNormPassThrough(nn.LayerNorm): |
|
"""Normalising layer that allows the attention mask to be passed through""" |
|
|
|
def __init__(self, *args, **kwargs): |
|
super().__init__(*args, **kwargs) |
|
|
|
def forward( |
|
self, d: tuple[Tensor, Tensor | None] |
|
) -> tuple[Tensor, Tensor | None]: |
|
"""Forwards function |
|
|
|
Args: |
|
d (tuple): tuple of the data tensor and the attention mask |
|
Returns: |
|
output (Tensor): normalised output data |
|
attn_mask (Tensor): the attention mask that was passed in |
|
""" |
|
input, attn_mask = d |
|
output = F.layer_norm( |
|
input, self.normalized_shape, self.weight, self.bias, self.eps |
|
) |
|
return output, attn_mask |
|
|
|
|
|
class MultiheadAttention(nn.Module): |
|
"""Multihead attention layer for inputs of shape |
|
[..., sequence, features]. |
|
""" |
|
|
|
def __init__(self, features: int, n_heads: int, dropout: float) -> None: |
|
""" |
|
Args: |
|
features: Number of features for inputs to the layer. |
|
n_heads: Number of attention heads. Should be a factor of features. |
|
(I.e. the layer uses features // n_heads.) |
|
dropout: Dropout. |
|
""" |
|
super().__init__() |
|
|
|
if (features % n_heads) != 0: |
|
raise ValueError( |
|
f"Features '{features}' is not divisible by heads '{n_heads}'." |
|
) |
|
|
|
self.features = features |
|
self.n_heads = n_heads |
|
self.dropout = dropout |
|
|
|
self.qkv_layer = torch.nn.Linear(features, features * 3, bias=False) |
|
self.w_layer = torch.nn.Linear(features, features, bias=False) |
|
|
|
def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor: |
|
""" |
|
Args: |
|
d (tuple): tuple containing Tensor of shape [..., sequence, features] and the attention mask |
|
Returns: |
|
Tensor: Tensor of shape [..., sequence, features] |
|
""" |
|
x, attn_mask = d |
|
|
|
if not x.shape[-1] == self.features: |
|
raise ValueError( |
|
f"Expecting tensor with last dimension size {self.features}." |
|
) |
|
|
|
passenger_dims = x.shape[:-2] |
|
B = passenger_dims.numel() |
|
S = x.shape[-2] |
|
C = x.shape[-1] |
|
x = x.reshape(B, S, C) |
|
|
|
|
|
|
|
q, k, v = ( |
|
self.qkv_layer(x) |
|
.view(B, S, self.n_heads, 3 * (C // self.n_heads)) |
|
.transpose(1, 2) |
|
.chunk(chunks=3, dim=3) |
|
) |
|
|
|
|
|
if version("torch") > "2.3.0": |
|
with sdpa_kernel( |
|
[SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION] |
|
): |
|
|
|
x = F.scaled_dot_product_attention( |
|
q, k, v, attn_mask=attn_mask, dropout_p=self.dropout |
|
) |
|
else: |
|
with torch.backends.cuda.sdp_kernel( |
|
enable_flash=True, enable_math=False, enable_mem_efficient=True |
|
): |
|
|
|
x = F.scaled_dot_product_attention( |
|
q, k, v, dropout_p=self.dropout |
|
) |
|
|
|
|
|
x = x.transpose(1, 2).view(B, S, C) |
|
|
|
|
|
x = self.w_layer(x) |
|
|
|
|
|
x = x.view(*passenger_dims, S, self.features) |
|
return x |
|
|
|
|
|
class Transformer(nn.Module): |
|
""" |
|
Transformer for inputs of shape [..., S, features]. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
features: int, |
|
mlp_multiplier: int, |
|
n_heads: int, |
|
dropout: float, |
|
drop_path: float, |
|
) -> None: |
|
""" |
|
Args: |
|
features: Number of features for inputs to the layer. |
|
mlp_multiplier: Model uses features*mlp_multiplier hidden units. |
|
n_heads: Number of attention heads. Should be a factor of features. |
|
(I.e. the layer uses features // n_heads.) dropout: Dropout. |
|
drop_path: DropPath. |
|
""" |
|
super().__init__() |
|
|
|
self.features = features |
|
self.mlp_multiplier = mlp_multiplier |
|
self.n_heads = n_heads |
|
self.dropout = dropout |
|
self.drop_path = ( |
|
DropPath(drop_path) if drop_path > 0.0 else nn.Identity() |
|
) |
|
|
|
self.attention = nn.Sequential( |
|
LayerNormPassThrough(features), |
|
MultiheadAttention(features, n_heads, dropout), |
|
) |
|
|
|
self.ff = nn.Sequential( |
|
nn.LayerNorm(features), |
|
Mlp( |
|
features=features, |
|
hidden_features=features * mlp_multiplier, |
|
dropout=dropout, |
|
), |
|
) |
|
|
|
def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor: |
|
""" |
|
Args: |
|
x: Tensor of shape [..., sequence, features] |
|
Returns: |
|
Tensor: Tensor of shape [..., sequence, features] |
|
""" |
|
x, attn_mask = d |
|
if not x.shape[-1] == self.features: |
|
raise ValueError( |
|
f"Expecting tensor with last dimension size {self.features}." |
|
) |
|
|
|
attention_x = self.attention(d) |
|
|
|
x = x + self.drop_path(attention_x) |
|
x = x + self.drop_path(self.ff(x)) |
|
|
|
return x |
|
|
|
|
|
class _Shift(nn.Module): |
|
"""Private base class for the shifter. This allows some behaviour to be |
|
easily handled when the shifter isn't used. |
|
""" |
|
|
|
def __init__(self): |
|
super().__init__() |
|
|
|
self._shifted = False |
|
|
|
@torch.no_grad() |
|
def reset(self) -> None: |
|
""" |
|
Resets the bool tracking whether the data is shifted |
|
""" |
|
self._shifted: bool = False |
|
|
|
def forward(self, data: Tensor) -> tuple[Tensor, dict[bool, None]]: |
|
return data, {True: None, False: None} |
|
|
|
|
|
class SWINShift(_Shift): |
|
""" |
|
Handles the shifting of patches similar to how SWIN works. However if we |
|
shift the latitudes then the poles will wrap and potentially that might be |
|
problematic. The possition tokens should handle it but masking is safer. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
mu_shape: tuple[int, int], |
|
global_shape: tuple[int, int], |
|
local_shape: tuple[int, int], |
|
patch_shape: tuple[int, int], |
|
n_context_tokens: int = 2, |
|
) -> None: |
|
""" |
|
Args: |
|
mu_shape: the shape to the masking units |
|
global_shape: number of global patches in lat and lon |
|
local_shape: size of the local patches |
|
patch_shape: patch size |
|
n_context_token: number of additional context tokens at start of |
|
_each_ local sequence |
|
""" |
|
super().__init__() |
|
|
|
self._mu_shape = ms = mu_shape |
|
self._g_shape = gs = global_shape |
|
self._l_shape = ls = local_shape |
|
self._p_shape = ps = patch_shape |
|
self._lat_patch = (gs[0], ls[0], gs[1], ls[1]) |
|
self._n_context_tokens = n_context_tokens |
|
|
|
self._g_shift_to = tuple( |
|
int(0.5 * x / p) for x, p in zip(ms, ps, strict=False) |
|
) |
|
self._g_shift_from = tuple( |
|
-int(0.5 * x / p) for x, p in zip(ms, ps, strict=False) |
|
) |
|
|
|
|
|
nglobal = global_shape[0] * global_shape[1] |
|
nlocal = ( |
|
local_shape[0] * local_shape[1] + self._n_context_tokens |
|
) |
|
|
|
lm = torch.ones((nglobal, 1, nlocal, nlocal), dtype=bool) |
|
mwidth = int(0.5 * local_shape[1]) * local_shape[0] |
|
lm[ |
|
: gs[1], |
|
:, |
|
self._n_context_tokens : mwidth + self._n_context_tokens, |
|
self._n_context_tokens : mwidth + self._n_context_tokens, |
|
] = False |
|
self.register_buffer("local_mask", lm) |
|
|
|
gm = torch.ones((nlocal, 1, nglobal, nglobal), dtype=bool) |
|
gm[: int(0.5 * ls[1]) * ls[0], :, : gs[1], : gs[1]] = False |
|
self.register_buffer("global_mask", gm) |
|
|
|
def _to_grid_global(self, x: Tensor) -> Tensor: |
|
""" |
|
Shuffle and reshape the data from the global/local setting back to the |
|
lat/lon grid setting |
|
Args: |
|
x: the data tensor to be shuffled. |
|
Returns: |
|
x: data in the global/local setting |
|
""" |
|
nbatch, *other = x.shape |
|
|
|
y1 = x.view(nbatch, *self._g_shape, *self._l_shape, -1) |
|
y2 = y1.permute(0, 5, 1, 3, 2, 4).contiguous() |
|
|
|
s = y2.shape |
|
return y2.view((nbatch, -1, s[2] * s[3], s[4] * s[5])) |
|
|
|
def _to_grid_local(self, x: Tensor) -> Tensor: |
|
""" |
|
Shuffle and reshape the data from the local/global setting to the |
|
lat/lon grid setting |
|
Args: |
|
x: the data tensor to be shuffled. |
|
Returns: |
|
x: data in the lat/lon setting. |
|
""" |
|
x = x.transpose(2, 1).contiguous() |
|
return self._to_grid_global(x) |
|
|
|
def _from_grid_global(self, x: Tensor) -> Tensor: |
|
""" |
|
Shuffle and reshape the data from the lat/lon grid to the global/local |
|
setting |
|
Args: |
|
x: the data tensor to be shuffled. |
|
Returns: |
|
x: data in the global/local setting |
|
""" |
|
nbatch, *other = x.shape |
|
|
|
z1 = x.view(nbatch, -1, *self._lat_patch) |
|
z2 = z1.permute(0, 2, 4, 3, 5, 1).contiguous() |
|
|
|
s = z2.shape |
|
return z2.view(nbatch, s[1] * s[2], s[3] * s[4], -1) |
|
|
|
def _from_grid_local(self, x: Tensor) -> Tensor: |
|
""" |
|
Shuffle and reshape the data from the lat/lon grid to the local/global |
|
setting |
|
Args: |
|
x: the data tensor to be shuffled. |
|
Returns: |
|
x: data in the local/global setting |
|
""" |
|
x = self._from_grid_global(x) |
|
return x.transpose(2, 1).contiguous() |
|
|
|
def _shift(self, x: Tensor) -> Tensor: |
|
""" |
|
Shifts data in the gridded lat/lon setting by half the mask unit shape |
|
Args: |
|
x: data to be shifted |
|
Returns: |
|
x: either the hsifted or unshifted data |
|
""" |
|
shift = self._g_shift_from if self._shifted else self._g_shift_to |
|
x_shifted = torch.roll(x, shift, (-2, -1)) |
|
|
|
self._shifted = not self._shifted |
|
return x_shifted |
|
|
|
def _sep_lt(self, x: Tensor) -> tuple[Tensor, Tensor]: |
|
""" |
|
Seperate off the leadtime from the local patches |
|
Args: |
|
x: data to have leadtime removed from |
|
Returns: |
|
lt: leadtime |
|
x: data without the lead time in the local patch |
|
""" |
|
lt_it = x[:, : self._n_context_tokens, :, :] |
|
x_stripped = x[:, self._n_context_tokens :, :, :] |
|
|
|
return lt_it, x_stripped |
|
|
|
def forward(self, data: Tensor) -> tuple[Tensor, Tensor]: |
|
"""Shift or unshift the the data depending on whether the data is |
|
already shifted, as defined by self._shifte. |
|
|
|
Args: |
|
data: data to be shifted |
|
Returns: |
|
Tensor: shifted data Tensor |
|
""" |
|
lt, x = self._sep_lt(data) |
|
|
|
x_grid = self._to_grid_local(x) |
|
x_shifted = self._shift(x_grid) |
|
x_patched = self._from_grid_local(x_shifted) |
|
|
|
|
|
n_batch = x_grid.shape[0] |
|
local_rep = [n_batch] + [1] * (self.local_mask.ndim - 1) |
|
global_rep = [n_batch] + [1] * (self.global_mask.ndim - 1) |
|
|
|
if self._shifted: |
|
attn_mask = { |
|
True: self.local_mask.repeat(local_rep), |
|
False: self.global_mask.repeat(global_rep), |
|
} |
|
else: |
|
attn_mask = {True: None, False: None} |
|
|
|
return torch.cat((lt, x_patched), axis=1), attn_mask |
|
|
|
|
|
class LocalGlobalLocalBlock(nn.Module): |
|
""" |
|
Applies alternating block and grid attention. Given a parameter n_blocks, |
|
the entire module contains 2*n_blocks+1 transformer blocks. The first, |
|
third, ..., last apply local (block) attention. The second, fourth, ... |
|
global (grid) attention. |
|
|
|
This is heavily inspired by |
|
Tu et al. "MaxViT: Multi-Axis Vision Transformer" |
|
(https://arxiv.org/abs/2204.01697). |
|
""" |
|
|
|
def __init__( |
|
self, |
|
features: int, |
|
mlp_multiplier: int, |
|
n_heads: int, |
|
dropout: float, |
|
n_blocks: int, |
|
drop_path: float, |
|
shifter: nn.Module | None = None, |
|
checkpoint: list[int] | None = None, |
|
) -> None: |
|
""" |
|
Args: |
|
features: Number of features for inputs to the layer. |
|
mlp_multiplier: Model uses features*mlp_multiplier hidden units. |
|
n_heads: Number of attention heads. Should be a factor of features. |
|
(I.e. the layer uses features // n_heads.) |
|
dropout: Dropout. |
|
drop_path: DropPath. |
|
n_blocks: Number of local-global transformer pairs. |
|
""" |
|
super().__init__() |
|
|
|
self.features = features |
|
self.mlp_multiplier = mlp_multiplier |
|
self.n_heads = n_heads |
|
self.dropout = dropout |
|
self.drop_path = drop_path |
|
self.n_blocks = n_blocks |
|
self._checkpoint = checkpoint or [] |
|
|
|
if not all(0 <= c < 2 * n_blocks + 1 for c in self._checkpoint): |
|
raise ValueError( |
|
"Checkpoints should be 0 <= i < 2*n_blocks+1. " |
|
f"{self._checkpoint=}." |
|
) |
|
|
|
self.transformers = nn.ModuleList( |
|
[ |
|
Transformer( |
|
features=features, |
|
mlp_multiplier=mlp_multiplier, |
|
n_heads=n_heads, |
|
dropout=dropout, |
|
drop_path=drop_path, |
|
) |
|
for _ in range(2 * n_blocks + 1) |
|
] |
|
) |
|
|
|
self.evaluator = [ |
|
self._checkpoint_wrapper |
|
if i in self._checkpoint |
|
else lambda m, x: m(x) |
|
for i, _ in enumerate(self.transformers) |
|
] |
|
|
|
self.shifter = shifter or _Shift() |
|
|
|
@staticmethod |
|
def _checkpoint_wrapper( |
|
model: nn.Module, data: tuple[Tensor, Tensor | None] |
|
) -> Tensor: |
|
return checkpoint(model, data, use_reentrant=False) |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
""" |
|
Args: |
|
x: Tensor of shape:: |
|
|
|
[batch, global_sequence, local_sequence, features] |
|
|
|
Returns: |
|
Tensor: Tensor of shape:: |
|
|
|
[batch, global_sequence, local_sequence, features] |
|
""" |
|
if x.shape[-1] != self.features: |
|
raise ValueError( |
|
f"Expecting tensor with last dimension size {self.features}." |
|
) |
|
if x.ndim != 4: |
|
raise ValueError( |
|
f"Expecting tensor with exactly four dimensions. {x.shape=}." |
|
) |
|
|
|
self.shifter.reset() |
|
local: bool = True |
|
attn_mask = {True: None, False: None} |
|
|
|
transformer_iter = zip(self.evaluator, self.transformers, strict=False) |
|
|
|
|
|
evaluator, transformer = next(transformer_iter) |
|
x = evaluator(transformer, (x, attn_mask[local])) |
|
|
|
for evaluator, transformer in transformer_iter: |
|
local = not local |
|
|
|
|
|
x = x.transpose(1, 2) |
|
|
|
x = evaluator(transformer, (x, attn_mask[local])) |
|
|
|
if not local: |
|
x, attn_mask = self.shifter(x) |
|
|
|
return x |
|
|
|
|
|
class PatchEmbed(nn.Module): |
|
""" |
|
Patch embedding via 2D convolution. |
|
""" |
|
|
|
def __init__( |
|
self, patch_size: int | tuple[int, ...], channels: int, embed_dim: int |
|
): |
|
super().__init__() |
|
|
|
self.patch_size = patch_size |
|
self.channels = channels |
|
self.embed_dim = embed_dim |
|
|
|
self.proj = nn.Conv2d( |
|
channels, |
|
embed_dim, |
|
kernel_size=patch_size, |
|
stride=patch_size, |
|
bias=True, |
|
) |
|
|
|
def forward(self, x: Tensor) -> Tensor: |
|
""" |
|
Args: |
|
x: Tensor of shape [batch, channels, lat, lon]. |
|
Returns: |
|
Tensor: Tensor with shape |
|
[batch, embed_dim, lat//patch_size, lon//patch_size] |
|
""" |
|
|
|
H, W = x.shape[-2:] |
|
|
|
if W % self.patch_size[1] != 0: |
|
raise ValueError( |
|
f"Cannot do patch embedding for tensor of shape {x.size()}" |
|
" with patch size {self.patch_size}. (Dimensions are BSCHW.)" |
|
) |
|
if H % self.patch_size[0] != 0: |
|
raise ValueError( |
|
f"Cannot do patch embedding for tensor of shape {x.size()}" |
|
f" with patch size {self.patch_size}. (Dimensions are BSCHW.)" |
|
) |
|
|
|
x = self.proj(x) |
|
|
|
return x |
|
|
|
|
|
class PrithviWxCEncoderDecoder(nn.Module): |
|
""" |
|
Hiera-MaxViT encoder/decoder code. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
embed_dim: int, |
|
n_blocks: int, |
|
mlp_multiplier: float, |
|
n_heads: int, |
|
dropout: float, |
|
drop_path: float, |
|
shifter: nn.Module | None = None, |
|
transformer_cp: list[int] | None = None, |
|
) -> None: |
|
""" |
|
Args: |
|
embed_dim: Embedding dimension |
|
n_blocks: Number of local-global transformer pairs. |
|
mlp_multiplier: MLP multiplier for hidden features in feed forward |
|
networks. |
|
n_heads: Number of attention heads. |
|
dropout: Dropout. |
|
drop_path: DropPath. |
|
""" |
|
super().__init__() |
|
|
|
self.embed_dim = embed_dim |
|
self.n_blocks = n_blocks |
|
self.mlp_multiplier = mlp_multiplier |
|
self.n_heads = n_heads |
|
self.dropout = dropout |
|
self._transformer_cp = transformer_cp |
|
|
|
self.lgl_block = LocalGlobalLocalBlock( |
|
features=embed_dim, |
|
mlp_multiplier=mlp_multiplier, |
|
n_heads=n_heads, |
|
dropout=dropout, |
|
drop_path=drop_path, |
|
n_blocks=n_blocks, |
|
shifter=shifter, |
|
checkpoint=transformer_cp, |
|
) |
|
|
|
def forward(self, x: torch.Tensor) -> torch.Tensor: |
|
""" |
|
Args: |
|
x: Tensor of shape |
|
[batch, global sequence, local sequence, embed_dim] |
|
Returns: |
|
Tensor of shape |
|
[batch, mask_unit_sequence, local_sequence, embed_dim]. |
|
Identical in shape to the input x. |
|
""" |
|
|
|
x = self.lgl_block(x) |
|
|
|
return x |
|
|
|
|
|
class PrithviWxC(nn.Module): |
|
"""Encoder-decoder fusing Hiera with MaxViT. See |
|
- Ryali et al. "Hiera: A Hierarchical Vision Transformer without the |
|
Bells-and-Whistles" (https://arxiv.org/abs/2306.00989) |
|
- Tu et al. "MaxViT: Multi-Axis Vision Transformer" |
|
(https://arxiv.org/abs/2204.01697) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_channels: int, |
|
input_size_time: int, |
|
in_channels_static: int, |
|
input_scalers_mu: Tensor, |
|
input_scalers_sigma: Tensor, |
|
input_scalers_epsilon: float, |
|
static_input_scalers_mu: Tensor, |
|
static_input_scalers_sigma: Tensor, |
|
static_input_scalers_epsilon: float, |
|
output_scalers: Tensor, |
|
n_lats_px: int, |
|
n_lons_px: int, |
|
patch_size_px: tuple[int], |
|
mask_unit_size_px: tuple[int], |
|
mask_ratio_inputs: float, |
|
embed_dim: int, |
|
n_blocks_encoder: int, |
|
n_blocks_decoder: int, |
|
mlp_multiplier: float, |
|
n_heads: int, |
|
dropout: float, |
|
drop_path: float, |
|
parameter_dropout: float, |
|
residual: str, |
|
masking_mode: str, |
|
positional_encoding: str, |
|
decoder_shifting: bool = False, |
|
checkpoint_encoder: list[int] | None = None, |
|
checkpoint_decoder: list[int] | None = None, |
|
) -> None: |
|
""" |
|
Args: |
|
in_channels: Number of input channels. |
|
input_size_time: Number of timestamps in input. |
|
in_channels_static: Number of input channels for static data. |
|
input_scalers_mu: Tensor of size (in_channels,). Used to rescale |
|
input. |
|
input_scalers_sigma: Tensor of size (in_channels,). Used to rescale |
|
input. |
|
input_scalers_epsilon: Float. Used to rescale input. |
|
static_input_scalers_mu: Tensor of size (in_channels_static). Used |
|
to rescale static inputs. |
|
static_input_scalers_sigma: Tensor of size (in_channels_static). |
|
Used to rescale static inputs. |
|
static_input_scalers_epsilon: Float. Used to rescale static inputs. |
|
output_scalers: Tensor of shape (in_channels,). Used to rescale |
|
output. |
|
n_lats_px: Total latitudes in data. In pixels. |
|
n_lons_px: Total longitudes in data. In pixels. |
|
patch_size_px: Patch size for tokenization. In pixels lat/lon. |
|
mask_unit_size_px: Size of each mask unit. In pixels lat/lon. |
|
mask_ratio_inputs: Masking ratio for inputs. 0 to 1. |
|
embed_dim: Embedding dimension |
|
n_blocks_encoder: Number of local-global transformer pairs in |
|
encoder. |
|
n_blocks_decoder: Number of local-global transformer pairs in |
|
decoder. |
|
mlp_multiplier: MLP multiplier for hidden features in feed forward |
|
networks. |
|
n_heads: Number of attention heads. |
|
dropout: Dropout. |
|
drop_path: DropPath. |
|
parameter_dropout: Dropout applied to parameters. |
|
residual: Indicates whether and how model should work as residual |
|
model. Accepted values are 'climate', 'temporal' and 'none' |
|
positional_encoding: possible values are |
|
['absolute' (default), 'fourier']. |
|
'absolute' lat lon encoded in 3 dimensions using sine and |
|
cosine |
|
'fourier' lat/lon to be encoded using various frequencies |
|
masking_mode: String ['local', 'global', 'both'] that controls the |
|
type of masking used. |
|
checkpoint_encoder: List of integers controlling if gradient |
|
checkpointing is used on encoder. |
|
Format: [] for no gradient checkpointing. [3, 7] for |
|
checkpointing after 4th and 8th layer etc. |
|
checkpoint_decoder: List of integers controlling if gradient |
|
checkpointing is used on decoder. |
|
Format: See `checkpoint_encoder`. |
|
masking_mode: The type of masking to use |
|
{'global', 'local', 'both'} |
|
decoder_shifting: Whether to use swin shifting in the decoder. |
|
""" |
|
super().__init__() |
|
|
|
self.in_channels = in_channels |
|
self.input_size_time = input_size_time |
|
self.in_channels_static = in_channels_static |
|
self.n_lats_px = n_lats_px |
|
self.n_lons_px = n_lons_px |
|
self.patch_size_px = patch_size_px |
|
self.mask_unit_size_px = mask_unit_size_px |
|
self.mask_ratio_inputs = mask_ratio_inputs |
|
self.embed_dim = embed_dim |
|
self.n_blocks_encoder = n_blocks_encoder |
|
self.n_blocks_decoder = n_blocks_decoder |
|
self.mlp_multiplier = mlp_multiplier |
|
self.n_heads = n_heads |
|
self.dropout = dropout |
|
self.drop_path = drop_path |
|
self.residual = residual |
|
self._decoder_shift = decoder_shifting |
|
self.positional_encoding = positional_encoding |
|
self._checkpoint_encoder = checkpoint_encoder |
|
self._checkpoint_decoder = checkpoint_decoder |
|
|
|
assert self.n_lats_px % self.mask_unit_size_px[0] == 0 |
|
assert self.n_lons_px % self.mask_unit_size_px[1] == 0 |
|
assert self.mask_unit_size_px[0] % self.patch_size_px[0] == 0 |
|
assert self.mask_unit_size_px[1] % self.patch_size_px[1] == 0 |
|
|
|
if self.patch_size_px[0] != self.patch_size_px[1]: |
|
raise NotImplementedError( |
|
"Current pixel shuffle symmetric patches." |
|
) |
|
|
|
self.local_shape_mu = ( |
|
self.mask_unit_size_px[0] // self.patch_size_px[0], |
|
self.mask_unit_size_px[1] // self.patch_size_px[1], |
|
) |
|
self.global_shape_mu = ( |
|
self.n_lats_px // self.mask_unit_size_px[0], |
|
self.n_lons_px // self.mask_unit_size_px[1], |
|
) |
|
|
|
assert input_scalers_mu.shape == (in_channels,) |
|
assert input_scalers_sigma.shape == (in_channels,) |
|
assert output_scalers.shape == (in_channels,) |
|
|
|
if self.positional_encoding != "fourier": |
|
assert static_input_scalers_mu.shape == (in_channels_static,) |
|
assert static_input_scalers_sigma.shape == (in_channels_static,) |
|
|
|
|
|
self.input_scalers_epsilon = input_scalers_epsilon |
|
self.register_buffer( |
|
"input_scalers_mu", input_scalers_mu.reshape(1, 1, -1, 1, 1) |
|
) |
|
self.register_buffer( |
|
"input_scalers_sigma", input_scalers_sigma.reshape(1, 1, -1, 1, 1) |
|
) |
|
|
|
|
|
self.static_input_scalers_epsilon = static_input_scalers_epsilon |
|
self.register_buffer( |
|
"static_input_scalers_mu", |
|
static_input_scalers_mu.reshape(1, -1, 1, 1), |
|
) |
|
self.register_buffer( |
|
"static_input_scalers_sigma", |
|
static_input_scalers_sigma.reshape(1, -1, 1, 1), |
|
) |
|
|
|
|
|
self.register_buffer( |
|
"output_scalers", output_scalers.reshape(1, -1, 1, 1) |
|
) |
|
|
|
self.parameter_dropout = nn.Dropout2d(p=parameter_dropout) |
|
|
|
self.patch_embedding = PatchEmbed( |
|
patch_size=patch_size_px, |
|
channels=in_channels * input_size_time, |
|
embed_dim=embed_dim, |
|
) |
|
|
|
if self.residual == "climate": |
|
self.patch_embedding_static = PatchEmbed( |
|
patch_size=patch_size_px, |
|
channels=in_channels + in_channels_static, |
|
embed_dim=embed_dim, |
|
) |
|
else: |
|
self.patch_embedding_static = PatchEmbed( |
|
patch_size=patch_size_px, |
|
channels=in_channels_static, |
|
embed_dim=embed_dim, |
|
) |
|
|
|
self.input_time_embedding = nn.Linear(1, embed_dim // 4, bias=True) |
|
self.lead_time_embedding = nn.Linear(1, embed_dim // 4, bias=True) |
|
|
|
self.mask_token = nn.Parameter(torch.randn(1, 1, 1, self.embed_dim)) |
|
self._nglobal_mu = np.prod(self.global_shape_mu) |
|
self._global_idx = torch.arange(self._nglobal_mu) |
|
|
|
self._nlocal_mu = np.prod(self.local_shape_mu) |
|
self._local_idx = torch.arange(self._nlocal_mu) |
|
|
|
self.encoder = PrithviWxCEncoderDecoder( |
|
embed_dim=embed_dim, |
|
n_blocks=n_blocks_encoder, |
|
mlp_multiplier=mlp_multiplier, |
|
n_heads=n_heads, |
|
dropout=dropout, |
|
drop_path=drop_path, |
|
transformer_cp=checkpoint_encoder, |
|
) |
|
|
|
if n_blocks_decoder != 0: |
|
if self._decoder_shift: |
|
self.decoder_shifter = d_shifter = SWINShift( |
|
self.mask_unit_size_px, |
|
self.global_shape_mu, |
|
self.local_shape_mu, |
|
self.patch_size_px, |
|
n_context_tokens=0, |
|
) |
|
else: |
|
self.decoder_shifter = d_shifter = None |
|
|
|
self.decoder = PrithviWxCEncoderDecoder( |
|
embed_dim=embed_dim, |
|
n_blocks=n_blocks_decoder, |
|
mlp_multiplier=mlp_multiplier, |
|
n_heads=n_heads, |
|
dropout=dropout, |
|
drop_path=0.0, |
|
shifter=d_shifter, |
|
transformer_cp=checkpoint_decoder, |
|
) |
|
|
|
self.unembed = nn.Linear( |
|
self.embed_dim, |
|
self.in_channels |
|
* self.patch_size_px[0] |
|
* self.patch_size_px[1], |
|
bias=True, |
|
) |
|
|
|
self.masking_mode = masking_mode.lower() |
|
match self.masking_mode: |
|
case "local": |
|
self.generate_mask = self._gen_mask_local |
|
case "global": |
|
self.generate_mask = self._gen_mask_global |
|
case "both": |
|
self._mask_both_local: bool = True |
|
self.generate_mask = self._gen_mask_both |
|
case _: |
|
raise ValueError( |
|
f"Masking mode '{masking_mode}' not supported" |
|
) |
|
|
|
def swap_masking(self) -> None: |
|
self._mask_both_local = not self._mask_both_local |
|
|
|
@cached_property |
|
def n_masked_global(self): |
|
return int(self.mask_ratio_inputs * np.prod(self.global_shape_mu)) |
|
|
|
@cached_property |
|
def n_masked_local(self): |
|
return int(self.mask_ratio_inputs * np.prod(self.local_shape_mu)) |
|
|
|
@staticmethod |
|
def _shuffle_along_axis(a, axis): |
|
idx = torch.argsort(input=torch.rand(*a.shape), dim=axis) |
|
return torch.gather(a, dim=axis, index=idx) |
|
|
|
def _gen_mask_local(self, sizes: tuple[int]) -> tuple[Tensor]: |
|
""" |
|
Args: |
|
batch_size: Number of elements in batch |
|
Returns: |
|
Tuple of torch tensors. [indices masked, indices unmasked]. |
|
Each of these is a tensor of shape (batch, global sequene) |
|
""" |
|
|
|
|
|
maskable_indices = self._local_idx.view(1, -1).expand(*sizes[:2], -1) |
|
|
|
maskable_indices = self._shuffle_along_axis(maskable_indices, 2) |
|
|
|
indices_masked = maskable_indices[:, :, : self.n_masked_local] |
|
indices_unmasked = maskable_indices[:, :, self.n_masked_local :] |
|
|
|
return indices_masked, indices_unmasked |
|
|
|
def _gen_mask_global(self, sizes: tuple[int]) -> tuple[Tensor]: |
|
""" |
|
Args: |
|
batch_size: Number of elements in batch |
|
Returns: |
|
Tuple of torch tensors. [indices masked, indices unmasked]. |
|
Each of these is a tensor of shape (batch, global sequene) |
|
""" |
|
|
|
|
|
maskable_indices = self._global_idx.view(1, -1).expand(*sizes[:1], -1) |
|
|
|
maskable_indices = self._shuffle_along_axis(maskable_indices, 1) |
|
|
|
indices_masked = maskable_indices[:, : self.n_masked_global] |
|
indices_unmasked = maskable_indices[:, self.n_masked_global :] |
|
|
|
return indices_masked, indices_unmasked |
|
|
|
def _gen_mask_both(self, sizes: tuple[int]) -> tuple[Tensor]: |
|
if self._mask_both_local: |
|
return self._gen_mask_local(sizes) |
|
else: |
|
return self._gen_mask_global(sizes) |
|
|
|
@staticmethod |
|
def reconstruct_batch( |
|
idx_masked: Tensor, |
|
idx_unmasked: Tensor, |
|
data_masked: Tensor, |
|
data_unmasked: Tensor, |
|
) -> Tensor: |
|
"""Reconstructs a tensor along the mask unit dimension. Batched |
|
version. |
|
|
|
Args: |
|
idx_masked: Tensor of shape `batch, mask unit sequence`. |
|
idx_unmasked: Tensor of shape `batch, mask unit sequence`. |
|
data_masked: Tensor of shape `batch, mask unit sequence, ...`. |
|
Should have same size along mask unit sequence dimension as |
|
idx_masked. Dimensions beyond the first two, marked here as ... |
|
will typically be `local_sequence, channel` or |
|
`channel, lat, lon`. These dimensions should agree with |
|
data_unmasked. |
|
data_unmasked: Tensor of shape `batch, mask unit sequence, ...`. |
|
Should have same size along mask unit sequence dimension as |
|
idx_unmasked. Dimensions beyond the first two, marked here as |
|
... will typically be `local_sequence, channel` or `channel, |
|
lat, lon`. These dimensions should agree with data_masked. |
|
Returns: |
|
Tensor: Tensor of same shape as inputs data_masked and |
|
data_unmasked. I.e. `batch, mask unit sequence, ...`. Index for |
|
the total data composed of the masked and the unmasked part. |
|
""" |
|
dim: int = idx_masked.ndim |
|
|
|
idx_total = torch.argsort( |
|
torch.cat([idx_masked, idx_unmasked], dim=-1), dim=-1 |
|
) |
|
idx_total = idx_total.view( |
|
*idx_total.shape, *[1] * (data_unmasked.ndim - dim) |
|
) |
|
idx_total = idx_total.expand( |
|
*idx_total.shape[:dim], *data_unmasked.shape[dim:] |
|
) |
|
|
|
data = torch.cat([data_masked, data_unmasked], dim=dim - 1) |
|
data = torch.gather(data, dim=dim - 1, index=idx_total) |
|
|
|
return data, idx_total |
|
|
|
def fourier_pos_encoding(self, x_static: Tensor) -> Tensor: |
|
""" |
|
Args |
|
x_static: B x C x H x W. first two channels are lat, and lon |
|
Returns |
|
Tensor: Tensor of shape B x E x H x W where E is the embedding |
|
dimension. |
|
""" |
|
|
|
|
|
latitudes_patch = F.avg_pool2d( |
|
x_static[:, [0]], |
|
kernel_size=self.patch_size_px, |
|
stride=self.patch_size_px, |
|
) |
|
longitudes_patch = F.avg_pool2d( |
|
x_static[:, [1]], |
|
kernel_size=self.patch_size_px, |
|
stride=self.patch_size_px, |
|
) |
|
|
|
modes = ( |
|
torch.arange(self.embed_dim // 4, device=x_static.device).view( |
|
1, -1, 1, 1 |
|
) |
|
+ 1.0 |
|
) |
|
pos_encoding = torch.cat( |
|
( |
|
torch.sin(latitudes_patch * modes), |
|
torch.sin(longitudes_patch * modes), |
|
torch.cos(latitudes_patch * modes), |
|
torch.cos(longitudes_patch * modes), |
|
), |
|
axis=1, |
|
) |
|
|
|
return pos_encoding |
|
|
|
def time_encoding(self, input_time, lead_time): |
|
""" |
|
Args: |
|
input_time: Tensor of shape [batch]. |
|
lead_time: Tensor of shape [batch]. |
|
Returns: |
|
Tensor: Tensor of shape [batch, embed_dim, 1, 1] |
|
""" |
|
input_time = self.input_time_embedding(input_time.view(-1, 1, 1, 1)) |
|
lead_time = self.lead_time_embedding(lead_time.view(-1, 1, 1, 1)) |
|
|
|
time_encoding = torch.cat( |
|
( |
|
torch.cos(input_time), |
|
torch.cos(lead_time), |
|
torch.sin(input_time), |
|
torch.sin(lead_time), |
|
), |
|
axis=3, |
|
) |
|
return time_encoding |
|
|
|
def to_patching(self, x: Tensor) -> Tensor: |
|
"""Transform data from lat/lon space to two axis patching |
|
|
|
Args: -> |
|
x: Tesnor in lat/lon space (N, C, Nlat//P_0, Nlon//P_1) |
|
|
|
Returns: |
|
Tensor in patch space (N, G, L, C) |
|
""" |
|
n_batch = x.shape[0] |
|
|
|
x = x.view( |
|
n_batch, |
|
-1, |
|
self.global_shape_mu[0], |
|
self.local_shape_mu[0], |
|
self.global_shape_mu[1], |
|
self.local_shape_mu[1], |
|
) |
|
x = x.permute(0, 2, 4, 3, 5, 1).contiguous() |
|
|
|
s = x.shape |
|
return x.view(n_batch, s[1] * s[2], s[3] * s[4], -1) |
|
|
|
def from_patching(self, x: Tensor) -> Tensor: |
|
"""Transform data from two axis patching to lat/lon space |
|
|
|
Args: |
|
x: Tensor in patch space with shape (N, G, L, C*P_0*P_1) |
|
|
|
Returns: |
|
Tensor: Tensor in lat/lon space |
|
(N, C*P_0*P_1, Nlat//P_0, Nlon // P_1) |
|
""" |
|
n_batch = x.shape[0] |
|
|
|
x = x.view( |
|
n_batch, |
|
self.global_shape_mu[0], |
|
self.global_shape_mu[1], |
|
self.local_shape_mu[0], |
|
self.local_shape_mu[1], |
|
-1, |
|
) |
|
x = x.permute(0, 5, 1, 3, 2, 4).contiguous() |
|
|
|
s = x.shape |
|
return x.view(n_batch, -1, s[2] * s[3], s[4] * s[5]) |
|
|
|
def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor: |
|
""" |
|
Args: |
|
batch: Dictionary the following keys:: |
|
|
|
'x': Tensor of shape [batch, time, parameter, lat, lon] |
|
'y': Tensor of shape [batch, parameter, lat, lon] |
|
'static': Tensor of shape [batch, channel_static, lat, lon] |
|
'climate': Optional tensor of shape [batch, parameter, lat, lon] |
|
'input_time': Tensor of shape [batch]. Or none. |
|
'lead_time': Tensor of shape [batch]. Or none. |
|
|
|
Returns: |
|
Tensor: Tensor of shape [batch, parameter, lat, lon]. |
|
""" |
|
x_rescaled = (batch["x"] - self.input_scalers_mu) / ( |
|
self.input_scalers_sigma + self.input_scalers_epsilon |
|
) |
|
batch_size = x_rescaled.shape[0] |
|
|
|
if self.positional_encoding == "fourier": |
|
x_static_pos = self.fourier_pos_encoding(batch["static"]) |
|
x_static = ( |
|
batch["static"][:, 2:] - self.static_input_scalers_mu[:, 3:] |
|
) / ( |
|
self.static_input_scalers_sigma[:, 3:] |
|
+ self.static_input_scalers_epsilon |
|
) |
|
else: |
|
x_static = (batch["static"] - self.static_input_scalers_mu) / ( |
|
self.static_input_scalers_sigma |
|
+ self.static_input_scalers_epsilon |
|
) |
|
|
|
if self.residual == "temporal": |
|
|
|
index = torch.where( |
|
batch["lead_time"] > 0, batch["x"].shape[1] - 1, 0 |
|
) |
|
index = index.view(-1, 1, 1, 1, 1) |
|
index = index.expand(batch_size, 1, *batch["x"].shape[2:]) |
|
x_hat = torch.gather(batch["x"], dim=1, index=index) |
|
x_hat = x_hat.squeeze(1) |
|
elif self.residual == "climate": |
|
climate_scaled = ( |
|
batch["climate"] - self.input_scalers_mu.view(1, -1, 1, 1) |
|
) / ( |
|
self.input_scalers_sigma.view(1, -1, 1, 1) |
|
+ self.input_scalers_epsilon |
|
) |
|
|
|
|
|
|
|
x_rescaled = x_rescaled.flatten(1, 2) |
|
|
|
x_rescaled = self.parameter_dropout(x_rescaled) |
|
|
|
x_embedded = self.patch_embedding(x_rescaled) |
|
|
|
if self.residual == "climate": |
|
static_embedded = self.patch_embedding_static( |
|
torch.cat((x_static, climate_scaled), dim=1) |
|
) |
|
else: |
|
static_embedded = self.patch_embedding_static(x_static) |
|
|
|
if self.positional_encoding == "fourier": |
|
static_embedded += x_static_pos |
|
|
|
x_embedded = self.to_patching(x_embedded) |
|
static_embedded = self.to_patching(static_embedded) |
|
|
|
time_encoding = self.time_encoding( |
|
batch["input_time"], batch["lead_time"] |
|
) |
|
|
|
tokens = x_embedded + static_embedded + time_encoding |
|
|
|
|
|
indices_masked, indices_unmasked = self.generate_mask( |
|
(batch_size, self._nglobal_mu) |
|
) |
|
indices_masked = indices_masked.to(device=tokens.device) |
|
indices_unmasked = indices_unmasked.to(device=tokens.device) |
|
maskdim: int = indices_masked.ndim |
|
|
|
|
|
unmask_view = (*indices_unmasked.shape, *[1] * (tokens.ndim - maskdim)) |
|
unmasked = torch.gather( |
|
tokens, |
|
dim=maskdim - 1, |
|
index=indices_unmasked.view(*unmask_view).expand( |
|
*indices_unmasked.shape, *tokens.shape[maskdim:] |
|
), |
|
) |
|
|
|
|
|
x_encoded = self.encoder(unmasked) |
|
|
|
|
|
|
|
|
|
mask_view = (*indices_masked.shape, *[1] * (tokens.ndim - maskdim)) |
|
masking = self.mask_token.repeat(*static_embedded.shape[:3], 1) |
|
masked = masking + static_embedded |
|
masked = torch.gather( |
|
masked, |
|
dim=maskdim - 1, |
|
index=indices_masked.view(*mask_view).expand( |
|
*indices_masked.shape, *tokens.shape[maskdim:] |
|
), |
|
) |
|
|
|
recon, _ = self.reconstruct_batch( |
|
indices_masked, indices_unmasked, masked, x_encoded |
|
) |
|
|
|
x_decoded = self.decoder(recon) |
|
|
|
|
|
|
|
x_unembed = self.unembed(x_decoded) |
|
|
|
|
|
|
|
x_out = self.from_patching(x_unembed) |
|
|
|
|
|
x_out = F.pixel_shuffle(x_out, self.patch_size_px[0]) |
|
|
|
if self.residual == "temporal": |
|
x_out = self.output_scalers * x_out + x_hat |
|
elif self.residual == "climate": |
|
x_out = self.output_scalers * x_out + batch["climate"] |
|
elif self.residual == "none": |
|
x_out = ( |
|
self.output_scalers * x_out |
|
+ self.input_scalers_mu.reshape(1, -1, 1, 1) |
|
) |
|
|
|
return x_out |
|
|