{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PrithviWxC\n", "\n", "This notebook will walk you through how to construct the model,\n", "load the weights, build the dataset, and use the model for inference." ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "import random\n", "from pathlib import Path\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "from huggingface_hub import hf_hub_download, snapshot_download" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now configure the backends and torch states, including setting the seeds for the RNGs." ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "torch.jit.enable_onednn_fusion(True)\n", "if torch.cuda.is_available():\n", " print(f\"Using device: {torch.cuda.get_device_name()}\")\n", " torch.backends.cudnn.benchmark = True\n", " torch.backends.cudnn.deterministic = True\n", "\n", "random.seed(42)\n", "if torch.cuda.is_available():\n", " torch.cuda.manual_seed(42)\n", "torch.manual_seed(42)\n", "np.random.seed(42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model has approximately 2.3 billion parameters, so it\n", "requires reasonable computational resources, but it is possible\n", "to run it on a CPU." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "if torch.cuda.is_available():\n", " device = torch.device(\"cuda\")\n", "else:\n", " device = torch.device(\"cpu\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataloader\n", "### Variables and times\n", "\n", "With the environment ready to go, we now need to set up the task.\n", "The core model expects a fixed set of variables from the MERRA-2\n", "dataset, which are prescribed below. 
The variables consist\n", "of surface variables, surface static variables, and variables at\n", "various vertical levels within the atmosphere. More details on the\n", "MERRA-2 dataset can be found\n", "[here](https://gmao.gsfc.nasa.gov/reanalysis/MERRA-2/).\n", "\n", "The MERRA-2 dataset includes data at longitudes of $-180^\\circ$\n", "and $+180^\\circ$. This represents duplicate data, so we set a\n", "padding variable to remove it.\n", "\n", "The input to the core model consists of these variables at two\n", "different times. The time difference in hours between these samples\n", "is passed to the model and set in the input_time variable.\n", "\n", "The model's task is to predict the fixed set of variables at a\n", "target time, given the input data.\n", "\n", "For example, if the input times are 0900 and 1200, resulting in\n", "an input_time of -3, then a lead_time of 6 would result in a\n", "target time of 1800." ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "surface_vars = [\n", "    \"EFLUX\",\n", "    \"GWETROOT\",\n", "    \"HFLUX\",\n", "    \"LAI\",\n", "    \"LWGAB\",\n", "    \"LWGEM\",\n", "    \"LWTUP\",\n", "    \"PS\",\n", "    \"QV2M\",\n", "    \"SLP\",\n", "    \"SWGNT\",\n", "    \"SWTNT\",\n", "    \"T2M\",\n", "    \"TQI\",\n", "    \"TQL\",\n", "    \"TQV\",\n", "    \"TS\",\n", "    \"U10M\",\n", "    \"V10M\",\n", "    \"Z0M\",\n", "]\n", "static_surface_vars = [\"FRACI\", \"FRLAND\", \"FROCEAN\", \"PHIS\"]\n", "vertical_vars = [\"CLOUD\", \"H\", \"OMEGA\", \"PL\", \"QI\", \"QL\", \"QV\", \"T\", \"U\", \"V\"]\n", "levels = [\n", "    34.0,\n", "    39.0,\n", "    41.0,\n", "    43.0,\n", "    44.0,\n", "    45.0,\n", "    48.0,\n", "    51.0,\n", "    53.0,\n", "    56.0,\n", "    63.0,\n", "    68.0,\n", "    71.0,\n", "    72.0,\n", "]\n", "padding = {\"level\": [0, 0], \"lat\": [0, -1], \"lon\": [0, 0]}\n", "\n", "lead_times = [12]  # This variable can be changed to alter the task\n", "input_times = [-6]  # This variable can be changed to alter the task" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "### Data file\n", "MERRA-2 data is available from 1980 to the present day,\n", "at 3-hour temporal resolution. The dataloader we have provided\n", "expects the surface data and vertical data to be saved in\n", "separate files, and when provided with the directories, will\n", "search for the relevant data that falls within the provided time range.\n" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "45d1a1486bdc4dff82597d5cf87095f0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Fetching 1 files: 0%| | 0/1 [00:00 dict[str, Tensor]:\n", " \"\"\"Prepressing function for MERRA2 Dataset\n", "\n", " Args:\n", " batch (dict): List of training samples, each sample should be a\n", " dictionary with the following keys::\n", "\n", " 'sur_static': Numpy array of shape (3, lat, lon). For each pixel (lat, lon), the first dimension indexes sin(lat), cos(lon), sin(lon).\n", " 'sur_vals': Torch tensor of shape (parameter, time, lat, lon).\n", " 'sur_tars': Torch tensor of shape (parameter, time, lat, lon).\n", " 'ulv_vals': Torch tensor of shape (parameter, level, time, lat, lon).\n", " 'ulv_tars': Torch tensor of shape (parameter, level, time, lat, lon).\n", " 'sur_climate': Torch tensor of shape (parameter, lat, lon)\n", " 'ulv_climate': Torch tensor of shape (parameter, level, lat, lon)\n", " 'lead_time': Integer.\n", " 'input_time': Integer.\n", "\n", " padding: Dictionary with keys 'level', 'lat', 'lon', each of dim 2.\n", "\n", " Returns:\n", " Dictionary with the following keys::\n", "\n", " 'x': [batch, time, parameter, lat, lon]\n", " 'y': [batch, parameter, lat, lon]\n", " 'static': [batch, parameter, lat, lon]\n", " 'lead_time': [batch]\n", " 'input_time': [batch]\n", " 'climate (Optional)': [batch, parameter, lat, lon]\n", "\n", " Note:\n", " Here, for x and y, 'parameter' is [surface parameter, upper level,\n", " parameter x 
level]. Similarly for the static information we have\n", " [sin(lat), cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod),\n", " ...].\n", " \"\"\" # noqa: E501\n", " b0 = batch[0]\n", " nbatch = len(batch)\n", " data_keys = set(b0.keys())\n", "\n", " essential_keys = {\n", " \"sur_static\",\n", " \"sur_vals\",\n", " \"sur_tars\",\n", " \"ulv_vals\",\n", " \"ulv_tars\",\n", " \"input_time\",\n", " \"lead_time\",\n", " }\n", "\n", " climate_keys = {\n", " \"sur_climate\",\n", " \"ulv_climate\",\n", " }\n", "\n", " all_keys = essential_keys | climate_keys\n", "\n", " if not essential_keys.issubset(data_keys):\n", " raise ValueError(\"Missing essential keys.\")\n", "\n", " if not data_keys.issubset(all_keys):\n", " raise ValueError(\"Unexpected keys in batch.\")\n", "\n", " # Bring all tensors from the batch into a single tensor\n", " upl_x = torch.empty((nbatch, *b0[\"ulv_vals\"].shape))\n", " upl_y = torch.empty((nbatch, *b0[\"ulv_tars\"].shape))\n", "\n", " sur_x = torch.empty((nbatch, *b0[\"sur_vals\"].shape))\n", " sur_y = torch.empty((nbatch, *b0[\"sur_tars\"].shape))\n", "\n", " sur_sta = torch.empty((nbatch, *b0[\"sur_static\"].shape))\n", "\n", " lead_time = torch.empty((nbatch,), dtype=torch.float32)\n", " input_time = torch.empty((nbatch,), dtype=torch.float32)\n", "\n", " for i, rec in enumerate(batch):\n", " sur_x[i] = rec[\"sur_vals\"]\n", " sur_y[i] = rec[\"sur_tars\"]\n", "\n", " upl_x[i] = rec[\"ulv_vals\"]\n", " upl_y[i] = rec[\"ulv_tars\"]\n", "\n", " sur_sta[i] = rec[\"sur_static\"]\n", "\n", " lead_time[i] = rec[\"lead_time\"]\n", " input_time[i] = rec[\"input_time\"]\n", "\n", " return_value = {\n", " \"lead_time\": lead_time,\n", " \"input_time\": input_time,\n", " }\n", "\n", " # Reshape (batch, parameter, level, time, lat, lon) ->\n", " # (batch, time, parameter, level, lat, lon)\n", " upl_x = upl_x.permute((0, 3, 1, 2, 4, 5))\n", " upl_y = upl_y.permute((0, 3, 1, 2, 4, 5))\n", " # Reshape (batch, parameter, time, lat, lon) ->\n", " # 
(batch, time, parameter, lat, lon)\n", " sur_x = sur_x.permute((0, 2, 1, 3, 4))\n", " sur_y = sur_y.permute((0, 2, 1, 3, 4))\n", "\n", " # Pad\n", " padding_2d = (*padding[\"lon\"], *padding[\"lat\"])\n", "\n", " def pad2d(x):\n", " return torch.nn.functional.pad(x, padding_2d, mode=\"constant\", value=0)\n", "\n", " padding_3d = (*padding[\"lon\"], *padding[\"lat\"], *padding[\"level\"])\n", "\n", " def pad3d(x):\n", " return torch.nn.functional.pad(x, padding_3d, mode=\"constant\", value=0)\n", "\n", " sur_x = pad2d(sur_x).contiguous()\n", " upl_x = pad3d(upl_x).contiguous()\n", " sur_y = pad2d(sur_y).contiguous()\n", " upl_y = pad3d(upl_y).contiguous()\n", " return_value[\"static\"] = pad2d(sur_sta).contiguous()\n", "\n", " # Remove time for targets\n", " upl_y = torch.squeeze(upl_y, 1)\n", " sur_y = torch.squeeze(sur_y, 1)\n", "\n", " # We stack along the combined parameter x level dimension\n", " return_value[\"x\"] = torch.cat(\n", " (sur_x, upl_x.view(*upl_x.shape[:2], -1, *upl_x.shape[4:])), dim=2\n", " )\n", " return_value[\"y\"] = torch.cat(\n", " (sur_y, upl_y.view(upl_y.shape[0], -1, *upl_y.shape[3:])), dim=1\n", " )\n", "\n", " if climate_keys.issubset(data_keys):\n", " sur_climate = torch.empty((nbatch, *b0[\"sur_climate\"].shape))\n", " ulv_climate = torch.empty((nbatch, *b0[\"ulv_climate\"].shape))\n", " for i, rec in enumerate(batch):\n", " sur_climate[i] = rec[\"sur_climate\"]\n", " ulv_climate[i] = rec[\"ulv_climate\"]\n", " sur_climate = pad2d(sur_climate)\n", " ulv_climate = pad3d(ulv_climate)\n", "\n", " return_value[\"climate\"] = torch.cat(\n", " (\n", " sur_climate,\n", " ulv_climate.view(nbatch, -1, *ulv_climate.shape[3:]),\n", " ),\n", " dim=1,\n", " )\n", "\n", " return return_value\n", "\n", "\n", "def input_scalers(\n", " surf_vars: list[str],\n", " vert_vars: list[str],\n", " levels: list[float],\n", " surf_path: str | Path,\n", " vert_path: str | Path,\n", ") -> tuple[Tensor, Tensor]:\n", " \"\"\"Reads the input scalers\n", "\n", " 
Args:\n", "        surf_vars: surface variables to be used.\n", "        vert_vars: vertical variables to be used.\n", "        levels: MERRA2 levels to use.\n", "        surf_path: path to surface scalers file.\n", "        vert_path: path to vertical level scalers file.\n", "\n", "    Returns:\n", "        mu (Tensor): mean values\n", "        sig (Tensor): standard deviation (sigma) values, clamped to [1e-4, 1e4]\n", "    \"\"\"\n", "    with h5py.File(Path(surf_path), \"r\", libver=\"latest\") as surf_file:\n", "        stats = [x.decode().lower() for x in surf_file[\"statistic\"][()]]\n", "        mu_idx = stats.index(\"mu\")\n", "        sig_idx = stats.index(\"sigma\")\n", "\n", "        s_mu = torch.tensor([surf_file[k][()][mu_idx] for k in surf_vars])\n", "        s_sig = torch.tensor([surf_file[k][()][sig_idx] for k in surf_vars])\n", "\n", "    with h5py.File(Path(vert_path), \"r\", libver=\"latest\") as vert_file:\n", "        stats = [x.decode().lower() for x in vert_file[\"statistic\"][()]]\n", "        mu_idx = stats.index(\"mu\")\n", "        sig_idx = stats.index(\"sigma\")\n", "\n", "        lvl = vert_file[\"lev\"][()]\n", "        l_idx = [np.where(lvl == v)[0].item() for v in levels]\n", "\n", "        v_mu = np.array([vert_file[k][()][mu_idx, l_idx] for k in vert_vars])\n", "        v_sig = np.array([vert_file[k][()][sig_idx, l_idx] for k in vert_vars])\n", "\n", "    v_mu = torch.from_numpy(v_mu).view(-1)\n", "    v_sig = torch.from_numpy(v_sig).view(-1)\n", "\n", "    mu = torch.cat((s_mu, v_mu), dim=0).to(torch.float32)\n", "    sig = torch.cat((s_sig, v_sig), dim=0).to(torch.float32).clamp(1e-4, 1e4)\n", "    return mu, sig\n", "\n", "\n", "def static_input_scalers(\n", "    scalar_path: str | Path, stat_vars: list[str], unscaled_params: int = 7\n", ") -> tuple[Tensor, Tensor]:\n", "    scalar_path = Path(scalar_path)\n", "\n", "    with h5py.File(scalar_path, \"r\", libver=\"latest\") as scaler_file:\n", "        stats = [x.decode().lower() for x in scaler_file[\"statistic\"][()]]\n", "        mu_idx = stats.index(\"mu\")\n", "        sig_idx = stats.index(\"sigma\")\n", "\n", "        mu = torch.tensor([scaler_file[k][()][mu_idx] for k in stat_vars])\n", "        sig = 
torch.tensor([scaler_file[k][()][sig_idx] for k in stat_vars])\n", "\n", " z = torch.zeros(unscaled_params, dtype=mu.dtype, device=mu.device)\n", " o = torch.ones(unscaled_params, dtype=sig.dtype, device=sig.device)\n", " mu = torch.cat((z, mu), dim=0).to(torch.float32)\n", " sig = torch.cat((o, sig), dim=0).to(torch.float32)\n", "\n", " return mu, sig.clamp(1e-4, 1e4)\n", "\n", "\n", "def output_scalers(\n", " surf_vars: list[str],\n", " vert_vars: list[str],\n", " levels: list[float],\n", " surf_path: str | Path,\n", " vert_path: str | Path,\n", ") -> Tensor:\n", " surf_path = Path(surf_path)\n", " vert_path = Path(vert_path)\n", "\n", " with h5py.File(surf_path, \"r\", libver=\"latest\") as surf_file:\n", " svars = torch.tensor([surf_file[k][()] for k in surf_vars])\n", "\n", " with h5py.File(vert_path, \"r\", libver=\"latest\") as vert_file:\n", " lvl = vert_file[\"lev\"][()]\n", " l_idx = [np.where(lvl == v)[0].item() for v in levels]\n", " vvars = np.array([vert_file[k][()][l_idx] for k in vert_vars])\n", " vvars = torch.from_numpy(vvars).view(-1)\n", "\n", " var = torch.cat((svars, vvars), dim=0).to(torch.float32).clamp(1e-7, 1e7)\n", "\n", " return var\n", "\n", "\n", "class SampleSpec:\n", " \"\"\"\n", " A data class to collect the information used to define a sample.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " inputs: tuple[pd.Timestamp, pd.Timestamp],\n", " lead_time: int,\n", " target: pd.Timestamp | list[pd.Timestamp],\n", " ):\n", " \"\"\"\n", " Args:\n", " inputs: Tuple of timestamps. In ascending order.\n", " lead_time: Lead time. In hours.\n", " target: Timestamp of the target. 
Can be before or after the inputs.\n", "        \"\"\"\n", "        if not inputs[0] < inputs[1]:\n", "            raise ValueError(\n", "                \"Timestamps in `inputs` should be in strictly ascending order.\"\n", "            )\n", "\n", "        self.inputs = inputs\n", "        self.input_time = (inputs[1] - inputs[0]).total_seconds() / 3600\n", "        self.lead_time = lead_time\n", "        self.target = target\n", "\n", "        self.times = [*inputs, target]\n", "        self.stat_times = [inputs[-1]]\n", "\n", "    @property\n", "    def climatology_info(self) -> tuple[int, int]:\n", "        \"\"\"Get the required climatology info.\n", "\n", "        :return: information required to obtain climatology data. Essentially\n", "            this is the day of the year and hour of the day of the target\n", "            timestamp, with the former restricted to the interval [1, 365].\n", "        :rtype: tuple\n", "        \"\"\"\n", "        return (min(self.target.dayofyear, 365), self.target.hour)\n", "\n", "    @property\n", "    def year(self) -> int:\n", "        return self.inputs[1].year\n", "\n", "    @property\n", "    def dayofyear(self) -> int:\n", "        return self.inputs[1].dayofyear\n", "\n", "    @property\n", "    def hourofday(self) -> int:\n", "        return self.inputs[1].hour\n", "\n", "    def _info_str(self) -> str:\n", "        iso_8601 = \"%Y-%m-%dT%H:%M:%S\"\n", "\n", "        return (\n", "            f\"Issue time: {self.inputs[1].strftime(iso_8601)}\\n\"\n", "            f\"Lead time: {self.lead_time} hours ahead\\n\"\n", "            f\"Input delta: {self.input_time} hours\\n\"\n", "            f\"Target time: {self.target.strftime(iso_8601)}\"\n", "        )\n", "\n", "    @classmethod\n", "    def get(cls, timestamp: pd.Timestamp, dt: int, lead_time: int):\n", "        \"\"\"Given a timestamp and lead time, generates a SampleSpec object\n", "        describing the sample further.\n", "\n", "        Args:\n", "            timestamp: Timestamp of the sample, i.e. the later of the two\n", "                input timestamps.\n", "            dt: Time between input samples, in hours.\n", "            lead_time: Lead time. 
In hours.\n", "\n", "        Returns:\n", "            SampleSpec\n", "        \"\"\"  # noqa: E501\n", "        assert dt > 0, \"dt should be positive\"\n", "        lt = pd.to_timedelta(lead_time, unit=\"h\")\n", "        dt = pd.to_timedelta(dt, unit=\"h\")\n", "\n", "        if lead_time >= 0:\n", "            timestamp_target = timestamp + lt\n", "        else:\n", "            timestamp_target = timestamp - dt + lt\n", "\n", "        spec = cls(\n", "            inputs=(timestamp - dt, timestamp),\n", "            lead_time=lead_time,\n", "            target=timestamp_target,\n", "        )\n", "\n", "        return spec\n", "\n", "    def __repr__(self) -> str:\n", "        return self._info_str()\n", "\n", "    def __str__(self) -> str:\n", "        return self._info_str()\n", "\n", "\n", "class Merra2Dataset(Dataset):\n", "    \"\"\"MERRA2 dataset. The dataset unifies surface and vertical data as well as\n", "    optional climatology.\n", "\n", "    Samples come in the form of a dictionary. Not all keys support all\n", "    variables, yet the general ordering of dimensions is\n", "    parameter, level, time, lat, lon\n", "\n", "    Note:\n", "        Data is assumed to be in NetCDF files containing daily data at 3-hourly\n", "        intervals. These follow the naming patterns\n", "        MERRA2_sfc_YYYYMMDD.nc and MERRA_pres_YYYYMMDD.nc and can be located in\n", "        two different locations. Optional climatology data comes from files\n", "        climate_surface_doyDOY_hourHOD.nc and\n", "        climate_vertical_doyDOY_hourHOD.nc.\n", "\n", "\n", "    Note:\n", "        The cached property `valid_timestamps` assembles the set of all\n", "        timestamps for which there is data (at 3-hourly resolution).\n", "        `valid_climate_timestamps` does the same for climatology data,\n", "        as (dayofyear, hourofday) tuples.\n", "\n", "        Based on this information, `samples` generates a list of valid samples,\n", "        stored in `samples`. 
Here the format is::\n", "\n", " [\n", " [\n", " (timestamp 1, lead time A),\n", " (timestamp 1, lead time B),\n", " (timestamp 1, lead time C),\n", " ],\n", " [\n", " (timestamp 2, lead time D),\n", " (timestamp 2, lead time E),\n", " ]\n", " ]\n", "\n", " That is, the outer list iterates over timestamps (init times), the\n", " inner over lead times. Only valid entries are stored.\n", " \"\"\"\n", "\n", " valid_vertical_vars = [\n", " \"CLOUD\",\n", " \"H\",\n", " \"OMEGA\",\n", " \"PL\",\n", " \"QI\",\n", " \"QL\",\n", " \"QV\",\n", " \"T\",\n", " \"U\",\n", " \"V\",\n", " ]\n", " valid_surface_vars = [\n", " \"EFLUX\",\n", " \"GWETROOT\",\n", " \"HFLUX\",\n", " \"LAI\",\n", " \"LWGAB\",\n", " \"LWGEM\",\n", " \"LWTUP\",\n", " \"PRECTOT\",\n", " \"PS\",\n", " \"QV2M\",\n", " \"SLP\",\n", " \"SWGNT\",\n", " \"SWTNT\",\n", " \"T2M\",\n", " \"TQI\",\n", " \"TQL\",\n", " \"TQV\",\n", " \"TS\",\n", " \"U10M\",\n", " \"V10M\",\n", " \"Z0M\",\n", " ]\n", " valid_static_surface_vars = [\"FRACI\", \"FRLAND\", \"FROCEAN\", \"PHIS\"]\n", "\n", " valid_levels = [\n", " 34.0,\n", " 39.0,\n", " 41.0,\n", " 43.0,\n", " 44.0,\n", " 45.0,\n", " 48.0,\n", " 51.0,\n", " 53.0,\n", " 56.0,\n", " 63.0,\n", " 68.0,\n", " 71.0,\n", " 72.0,\n", " ]\n", "\n", " timedelta_input = pd.to_timedelta(3, unit=\"h\")\n", "\n", " def __init__(\n", " self,\n", " time_range: tuple[str | pd.Timestamp, str | pd.Timestamp],\n", " lead_times: list[int],\n", " input_times: list[int],\n", " data_path_surface: str | Path,\n", " data_path_vertical: str | Path,\n", " climatology_path_surface: str | Path | None = None,\n", " climatology_path_vertical: str | Path | None = None,\n", " surface_vars: list[str] | None = None,\n", " static_surface_vars: list[str] | None = None,\n", " vertical_vars: list[str] | None = None,\n", " levels: list[float] | None = None,\n", " roll_longitudes: int = 0,\n", " positional_encoding: str = \"absolute\",\n", " rtype: type = np.float32,\n", " dtype: torch.dtype = 
torch.float32,\n", "    ) -> None:\n", "        \"\"\"\n", "        Args:\n", "            data_path_surface: Location of surface data.\n", "            data_path_vertical: Location of vertical data.\n", "            climatology_path_surface: Location of (optional) surface\n", "                climatology.\n", "            climatology_path_vertical: Location of (optional) vertical\n", "                climatology.\n", "            surface_vars: Surface variables.\n", "            static_surface_vars: Static surface variables.\n", "            vertical_vars: Vertical variables.\n", "            levels: Levels.\n", "            time_range: Used to subset data.\n", "            lead_times: Lead times for generalized forecasting.\n", "            input_times: Time differences between the two input samples,\n", "                in hours (negative values, e.g. -6).\n", "            roll_longitudes: Set to a non-zero value to roll the data by a\n", "                random amount along the longitude dimension.\n", "            positional_encoding: possible values are\n", "                ['absolute' (default), 'fourier'].\n", "                'absolute' returns lat lon encoded in 3 dimensions using sine\n", "                and cosine\n", "                'fourier' returns lat/lon to be encoded by model\n", "            rtype: numpy data type used during read\n", "            dtype: torch data type of data output\n", "        \"\"\"\n", "\n", "        self.time_range = (\n", "            pd.to_datetime(time_range[0]),\n", "            pd.to_datetime(time_range[1]),\n", "        )\n", "        self.lead_times = lead_times\n", "        self.input_times = input_times\n", "        self._roll_longitudes = list(range(roll_longitudes + 1))\n", "\n", "        self._uvars = vertical_vars or self.valid_vertical_vars\n", "        self._level = levels or self.valid_levels\n", "        self._svars = surface_vars or self.valid_surface_vars\n", "        self._sstat = static_surface_vars or self.valid_static_surface_vars\n", "        self._nuvars = len(self._uvars)\n", "        self._nlevel = len(self._level)\n", "        self._nsvars = len(self._svars)\n", "        self._nsstat = len(self._sstat)\n", "\n", "        self.rtype = rtype\n", "        self.dtype = dtype\n", "\n", "        self.positional_encoding = positional_encoding\n", "\n", "        self._data_path_surface = Path(data_path_surface)\n", "        self._data_path_vertical = Path(data_path_vertical)\n", "\n", "        self.dir_exists(self._data_path_surface)\n", 
" self.dir_exists(self._data_path_vertical)\n", "\n", " self._get_coordinates()\n", "\n", " self._climatology_path_surface = Path(climatology_path_surface) or None\n", " self._climatology_path_vertical = (\n", " Path(climatology_path_vertical) or None\n", " )\n", " self._require_clim = (\n", " self._climatology_path_surface is not None\n", " and self._climatology_path_vertical is not None\n", " )\n", "\n", " if self._require_clim:\n", " self.dir_exists(self._climatology_path_surface)\n", " self.dir_exists(self._climatology_path_vertical)\n", " elif (\n", " climatology_path_surface is None\n", " and climatology_path_vertical is None\n", " ):\n", " self._climatology_path_surface = None\n", " self._climatology_path_vertical = None\n", " else:\n", " raise ValueError(\n", " \"Either both or neither of\"\n", " \"`climatology_path_surface` and\"\n", " \"`climatology_path_vertical` should be None.\"\n", " )\n", "\n", " if not set(self._svars).issubset(set(self.valid_surface_vars)):\n", " raise ValueError(\"Invalid surface variable.\")\n", "\n", " if not set(self._sstat).issubset(set(self.valid_static_surface_vars)):\n", " raise ValueError(\"Invalid static surface variable.\")\n", "\n", " if not set(self._uvars).issubset(set(self.valid_vertical_vars)):\n", " raise ValueError(\"Inalid vertical variable.\")\n", "\n", " if not set(self._level).issubset(set(self.valid_levels)):\n", " raise ValueError(\"Invalid level.\")\n", "\n", " @staticmethod\n", " def dir_exists(path: Path) -> None:\n", " if not path.is_dir():\n", " raise ValueError(f\"Directory {path} does not exist.\")\n", "\n", " @property\n", " def upper_shape(self) -> tuple:\n", " \"\"\"Returns the vertical variables shape\n", " Returns:\n", " tuple: vertical variable shape in the following order::\n", "\n", " [VAR, LEV, TIME, LAT, LON]\n", " \"\"\"\n", " return self._nuvars, self._nlevel, 2, 361, 576\n", "\n", " @property\n", " def surface_shape(self) -> tuple:\n", " \"\"\"Returns the surface variables shape\n", "\n", 
" Returns:\n", " tuple: surafce shape in the following order::\n", "\n", " [VAR, LEV, TIME, LAT, LON]\n", " \"\"\"\n", " return self._nsvars, 2, 361, 576\n", "\n", " def data_file_surface(self, timestamp: pd.Timestamp) -> Path:\n", " \"\"\"Build the surfcae data file name based on timestamp\n", "\n", " Args:\n", " timestamp: a timestamp\n", "\n", " Returns:\n", " Path: constructed path\n", " \"\"\"\n", " pattern = \"MERRA2_sfc_%Y%m%d.nc\"\n", " data_file = self._data_path_surface / timestamp.strftime(pattern)\n", " return data_file\n", "\n", " def data_file_vertical(self, timestamp: pd.Timestamp) -> Path:\n", " \"\"\"Build the vertical data file name based on timestamp\n", "\n", " Args:\n", " timestamp: a timestamp\n", "\n", " Returns:\n", " Path: constructed path\n", " \"\"\"\n", " pattern = \"MERRA_pres_%Y%m%d.nc\"\n", " data_file = self._data_path_vertical / timestamp.strftime(pattern)\n", " return data_file\n", "\n", " def data_file_surface_climate(\n", " self,\n", " timestamp: pd.Timestamp | None = None,\n", " dayofyear: int | None = None,\n", " hourofday: int | None = None,\n", " ) -> Path:\n", " \"\"\"\n", " Returns the path to a climatology file based either on a timestamp or\n", " the dayofyear / hourofday combination.\n", " Args:\n", " timestamp: A timestamp.\n", " dayofyear: Day of the year. 1 to 366.\n", " hourofday: Hour of the day. 
0 to 23.\n", " Returns:\n", " Path: Path to climatology file.\n", " \"\"\"\n", " if timestamp is not None and (\n", " (dayofyear is not None) or (hourofday is not None)\n", " ):\n", " raise ValueError(\n", " \"Provide either timestamp or both dayofyear and hourofday.\"\n", " )\n", "\n", " if timestamp is not None:\n", " dayofyear = min(timestamp.dayofyear, 365)\n", " hourofday = timestamp.hour\n", "\n", " file_name = f\"climate_surface_doy{dayofyear:03}_hour{hourofday:02}.nc\"\n", " data_file = self._climatology_path_surface / file_name\n", " return data_file\n", "\n", " def data_file_vertical_climate(\n", " self,\n", " timestamp: pd.Timestamp | None = None,\n", " dayofyear: int | None = None,\n", " hourofday: int | None = None,\n", " ) -> Path:\n", " \"\"\"Returns the path to a climatology file based either on a timestamp\n", " or the dayofyear / hourofday combination.\n", "\n", " Args:\n", " timestamp: A timestamp. dayofyear: Day of the year. 1 to 366.\n", " hourofday: Hour of the day. 
0 to 23.\n", "        Returns:\n", "            Path: Path to climatology file.\n", "        \"\"\"\n", "        if timestamp is not None and (\n", "            (dayofyear is not None) or (hourofday is not None)\n", "        ):\n", "            raise ValueError(\n", "                \"Provide either timestamp or both dayofyear and hourofday.\"\n", "            )\n", "\n", "        if timestamp is not None:\n", "            dayofyear = min(timestamp.dayofyear, 365)\n", "            hourofday = timestamp.hour\n", "\n", "        file_name = f\"climate_vertical_doy{dayofyear:03}_hour{hourofday:02}.nc\"\n", "        data_file = self._climatology_path_vertical / file_name\n", "        return data_file\n", "\n", "    def _get_coordinates(self) -> None:\n", "        \"\"\"\n", "        Obtains the coordinates (latitudes and longitudes) from a single data\n", "        file.\n", "        \"\"\"\n", "        timestamp = next(iter(self.valid_timestamps))\n", "\n", "        file = self.data_file_surface(timestamp)\n", "        with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", "            self.lats = lats = handle[\"lat\"][()].astype(self.rtype)\n", "            self.lons = lons = handle[\"lon\"][()].astype(self.rtype)\n", "\n", "        deg_to_rad = np.pi / 180\n", "        self._embed_lat = np.sin(lats * deg_to_rad).reshape(-1, 1)\n", "\n", "        self._embed_lon = np.empty((2, 1, len(lons)), dtype=self.rtype)\n", "        self._embed_lon[0, 0] = np.cos(lons * deg_to_rad)\n", "        self._embed_lon[1, 0] = np.sin(lons * deg_to_rad)\n", "\n", "    @ft.cached_property\n", "    def lats(self) -> np.ndarray:\n", "        timestamp = next(iter(self.valid_timestamps))\n", "\n", "        file = self.data_file_surface(timestamp)\n", "        with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", "            return handle[\"lat\"][()].astype(self.rtype)\n", "\n", "    @ft.cached_property\n", "    def lons(self) -> np.ndarray:\n", "        timestamp = next(iter(self.valid_timestamps))\n", "\n", "        file = self.data_file_surface(timestamp)\n", "        with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", "            return handle[\"lon\"][()].astype(self.rtype)\n", "\n", "    @ft.cached_property\n", "    def position_signal(self) -> np.ndarray:\n", "        \"\"\"Generates the 
\"position signal\" that is part of the static\n", " features.\n", "\n", " Returns:\n", " Tensor: Torch tensor of dimension (parameter, lat, lon) containing\n", " sin(lat), cos(lon), sin(lon).\n", " \"\"\"\n", "\n", " latitudes, longitudes = np.meshgrid(\n", " self.lats, self.lons, indexing=\"ij\"\n", " )\n", "\n", " if self.positional_encoding == \"absolute\":\n", " latitudes = latitudes / 360 * 2.0 * np.pi\n", " longitudes = longitudes / 360 * 2.0 * np.pi\n", " sur_static = np.stack(\n", " [np.sin(latitudes), np.cos(longitudes), np.sin(longitudes)],\n", " axis=0,\n", " )\n", " else:\n", " sur_static = np.stack([latitudes, longitudes], axis=0)\n", "\n", " sur_static = sur_static.astype(self.rtype)\n", "\n", " return sur_static\n", "\n", " @ft.cached_property\n", " def valid_timestamps(self) -> set[pd.Timestamp]:\n", " \"\"\"Generates list of valid timestamps based on available files. Only\n", " timestamps for which both surface and vertical information is available\n", " are considered valid.\n", " Returns:\n", " list: list of timestamps\n", " \"\"\"\n", "\n", " s_glob = self._data_path_surface.glob(\"MERRA2_sfc_????????.nc\")\n", " s_files = [os.path.basename(f) for f in s_glob]\n", " v_glob = self._data_path_surface.glob(\"MERRA_pres_????????.nc\")\n", " v_files = [os.path.basename(f) for f in v_glob]\n", "\n", " s_re = re.compile(r\"MERRA2_sfc_(\\d{8}).nc\\Z\")\n", " v_re = re.compile(r\"MERRA_pres_(\\d{8}).nc\\Z\")\n", " fmt = \"%Y%m%d\"\n", "\n", " s_times = {\n", " (datetime.strptime(m[1], fmt))\n", " for f in s_files\n", " if (m := s_re.match(f))\n", " }\n", " v_times = {\n", " (datetime.strptime(m[1], fmt))\n", " for f in v_files\n", " if (m := v_re.match(f))\n", " }\n", "\n", " times = s_times.intersection(v_times)\n", "\n", " # Each file contains a day at 3 hour intervals\n", " times = {\n", " t + timedelta(hours=i) for i in range(0, 24, 3) for t in times\n", " }\n", "\n", " start_time, end_time = self.time_range\n", " times = {pd.Timestamp(t) for t in 
times if start_time <= t <= end_time}\n", "\n", " return times\n", "\n", " @ft.cached_property\n", " def valid_climate_timestamps(self) -> set[tuple[int, int]]:\n", " \"\"\"Generates list of \"timestamps\" (dayofyear, hourofday) for which\n", " climatology data is present. Only instances for which surface and\n", " vertical data is available are considered valid.\n", " Returns:\n", " list: List of tuples describing valid climatology instances.\n", " \"\"\"\n", " if not self._require_clim:\n", " return set()\n", "\n", " s_glob = self._climatology_path_surface.glob(\n", " \"climate_surface_doy???_hour??.nc\"\n", " )\n", " s_files = [os.path.basename(f) for f in s_glob]\n", "\n", " v_glob = self._climatology_path_vertical.glob(\n", " \"climate_vertical_doy???_hour??.nc\"\n", " )\n", " v_files = [os.path.basename(f) for f in v_glob]\n", "\n", " s_re = re.compile(r\"climate_surface_doy(\\d{3})_hour(\\d{2}).nc\\Z\")\n", " v_re = re.compile(r\"climate_vertical_doy(\\d{3})_hour(\\d{2}).nc\\Z\")\n", "\n", " s_times = {\n", " (int(m[1]), int(m[2])) for f in s_files if (m := s_re.match(f))\n", " }\n", " v_times = {\n", " (int(m[1]), int(m[2])) for f in v_files if (m := v_re.match(f))\n", " }\n", "\n", " times = s_times.intersection(v_times)\n", "\n", " return times\n", "\n", " def _data_available(self, spec: SampleSpec) -> bool:\n", " \"\"\"\n", " Checks whether data is available for a given SampleSpec object. Does so\n", " using the internal sets with available data previously constructed. 
Not\n", "        by checking the file system.\n", "\n", "        Args:\n", "            spec: SampleSpec object as returned by SampleSpec.get\n", "        Returns:\n", "            bool: whether data is available.\n", "        \"\"\"\n", "        valid = set(spec.times).issubset(self.valid_timestamps)\n", "\n", "        if self._require_clim:\n", "            sci = spec.climatology_info\n", "            ci = set(sci) if isinstance(sci, list) else set([sci])  # noqa: C405\n", "            valid &= ci.issubset(self.valid_climate_timestamps)\n", "\n", "        return valid\n", "\n", "    @ft.cached_property\n", "    def samples(self) -> list[list[tuple[pd.Timestamp, int, int]]]:\n", "        \"\"\"\n", "        Generates list of all valid samples.\n", "\n", "        Returns:\n", "            list: One inner list per timestamp, each holding tuples\n", "            (timestamp, input time, lead time).\n", "        \"\"\"\n", "        valid_samples = []\n", "        dts = [(it, lt) for it in self.input_times for lt in self.lead_times]\n", "\n", "        for timestamp in sorted(self.valid_timestamps):\n", "            timestamp_samples = []\n", "            for it, lt in dts:\n", "                spec = SampleSpec.get(timestamp, -it, lt)\n", "\n", "                if self._data_available(spec):\n", "                    timestamp_samples.append((timestamp, it, lt))\n", "\n", "            if timestamp_samples:\n", "                valid_samples.append(timestamp_samples)\n", "\n", "        return valid_samples\n", "\n", "    def _to_torch(\n", "        self,\n", "        data: dict[str, Tensor | list[Tensor]],\n", "        dtype: torch.dtype = torch.float32,\n", "    ) -> dict[str, Tensor | list[Tensor]]:\n", "        out = {}\n", "        for k, v in data.items():\n", "            if isinstance(v, list):\n", "                out[k] = [torch.from_numpy(x).to(dtype) for x in v]\n", "            else:\n", "                out[k] = torch.from_numpy(v).to(dtype)\n", "\n", "        return out\n", "\n", "    def _lat_roll(\n", "        self, data: dict[str, Tensor | list[Tensor]], n: int\n", "    ) -> dict[str, Tensor | list[Tensor]]:\n", "        # Rolls along the last dimension of each tensor\n", "        out = {}\n", "        for k, v in data.items():\n", "            if isinstance(v, list):\n", "                out[k] = [torch.roll(x, shifts=n, dims=-1) for x in v]\n", "            else:\n", "                out[k] = torch.roll(v, shifts=n, dims=-1)\n", "\n", "        return out\n", "\n", "    def _read_static_data(\n", "        self, file: str | Path, doy: int, hod: 
int\n", " ) -> np.ndarray:\n", " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", " lats_surf = handle[\"lat\"]\n", " lons_surf = handle[\"lon\"]\n", "\n", " nll = (len(lats_surf), len(lons_surf))\n", "\n", " npos = len(self.position_signal)\n", " ntime = 4\n", "\n", " nstat = npos + ntime + self._nsstat\n", " data = np.empty((nstat, *nll), dtype=self.rtype)\n", "\n", " for i, key in enumerate(self._sstat, start=npos + ntime):\n", " data[i] = handle[key][()].astype(dtype=self.rtype)\n", "\n", " # [position signal], cos(doy), sin(doy), cos(hod), sin(hod)\n", " data[0:npos] = self.position_signal\n", " data[npos + 0] = np.cos(2 * np.pi * doy / 366)\n", " data[npos + 1] = np.sin(2 * np.pi * doy / 366)\n", " data[npos + 2] = np.cos(2 * np.pi * hod / 24)\n", " data[npos + 3] = np.sin(2 * np.pi * hod / 24)\n", "\n", " return data\n", "\n", " def _read_surface(\n", " self, tidx: int, nll: tuple[int, int], handle: h5py.File\n", " ) -> np.ndarray:\n", " data = np.empty((self._nsvars, *nll), dtype=self.rtype)\n", "\n", " for i, key in enumerate(self._svars):\n", " data[i] = handle[key][tidx][()].astype(dtype=self.rtype)\n", "\n", " return data\n", "\n", " def _read_levels(\n", " self, tidx: int, nll: tuple[int, int], handle: h5py.File\n", " ) -> np.ndarray:\n", " lvls = handle[\"lev\"][()]\n", " lidx = self._level_idxs(lvls)\n", "\n", " data = np.empty((self._nuvars, self._nlevel, *nll), dtype=self.rtype)\n", "\n", " for i, key in enumerate(self._uvars):\n", " data[i] = handle[key][tidx, lidx][()].astype(dtype=self.rtype)\n", "\n", " return np.ascontiguousarray(np.flip(data, axis=1))\n", "\n", " def _level_idxs(self, lvls):\n", " lidx = [np.argwhere(lvls == int(lvl)).item() for lvl in self._level]\n", " return sorted(lidx)\n", "\n", " @staticmethod\n", " def _date_to_tidx(date: datetime | pd.Timestamp, handle: h5py.File) -> int:\n", " if isinstance(date, pd.Timestamp):\n", " date = date.to_pydatetime()\n", "\n", " time = handle[\"time\"]\n", "\n", " t0 = 
time.attrs[\"begin_time\"][()].item()\n", " d0 = f\"{time.attrs['begin_date'][()].item()}\"\n", "\n", " offset = datetime.strptime(d0, \"%Y%m%d\")\n", "\n", " times = [offset + timedelta(minutes=int(t + t0)) for t in time[()]]\n", " return times.index(date)\n", "\n", " def _read_data(\n", " self, file_pair: tuple[str, str], date: datetime\n", " ) -> dict[str, np.ndarray]:\n", " s_file, v_file = file_pair\n", "\n", " with h5py.File(s_file, \"r\", libver=\"latest\") as shandle:\n", " lats_surf = shandle[\"lat\"]\n", " lons_surf = shandle[\"lon\"]\n", "\n", " nll = (len(lats_surf), len(lons_surf))\n", "\n", " tidx = self._date_to_tidx(date, shandle)\n", "\n", " sdata = self._read_surface(tidx, nll, shandle)\n", "\n", " with h5py.File(v_file, \"r\", libver=\"latest\") as vhandle:\n", " lats_vert = vhandle[\"lat\"]\n", " lons_vert = vhandle[\"lon\"]\n", "\n", " nll = (len(lats_vert), len(lons_vert))\n", "\n", " tidx = self._date_to_tidx(date, vhandle)\n", "\n", " vdata = self._read_levels(tidx, nll, vhandle)\n", "\n", " data = {\"vert\": vdata, \"surf\": sdata}\n", "\n", " return data\n", "\n", " def _read_climate(\n", " self, file_pair: tuple[str, str]\n", " ) -> dict[str, np.ndarray]:\n", " s_file, v_file = file_pair\n", "\n", " with h5py.File(s_file, \"r\", libver=\"latest\") as shandle:\n", " lats_surf = shandle[\"lat\"]\n", " lons_surf = shandle[\"lon\"]\n", "\n", " nll = (len(lats_surf), len(lons_surf))\n", "\n", " sdata = np.empty((self._nsvars, *nll), dtype=self.rtype)\n", "\n", " for i, key in enumerate(self._svars):\n", " sdata[i] = shandle[key][()].astype(dtype=self.rtype)\n", "\n", " with h5py.File(v_file, \"r\", libver=\"latest\") as vhandle:\n", " lats_vert = vhandle[\"lat\"]\n", " lons_vert = vhandle[\"lon\"]\n", "\n", " nll = (len(lats_vert), len(lons_vert))\n", "\n", " lvls = vhandle[\"lev\"][()]\n", " lidx = self._level_idxs(lvls)\n", "\n", " vdata = np.empty(\n", " (self._nuvars, self._nlevel, *nll), dtype=self.rtype\n", " )\n", "\n", " for i, key in 
enumerate(self._uvars):\n", " vdata[i] = vhandle[key][lidx][()].astype(dtype=self.rtype)\n", "\n", " data = {\n", " \"vert\": np.ascontiguousarray(np.flip(vdata, axis=1)),\n", " \"surf\": sdata,\n", " }\n", "\n", " return data\n", "\n", " def get_data_from_sample_spec(\n", " self, spec: SampleSpec\n", " ) -> dict[str, Tensor | int | float]:\n", " \"\"\"Loads and assembles sample data given a SampleSpec object.\n", "\n", " Args:\n", " spec (SampleSpec): Full details regarding the data to be loaded\n", " Returns:\n", " dict: Dictionary with the following keys::\n", "\n", " 'sur_static': Torch tensor of shape [parameter, lat, lon]. For\n", " each pixel (lat, lon), the first 7 dimensions index sin(lat),\n", " cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod).\n", " Where doy is the day of the year [1, 366] and hod the hour of\n", " the day [0, 23].\n", " 'sur_vals': Torch tensor of shape [parameter, time, lat, lon].\n", " 'sur_tars': Torch tensor of shape [parameter, time, lat, lon].\n", " 'ulv_vals': Torch tensor of shape [parameter, level, time, lat, lon].\n", " 'ulv_tars': Torch tensor of shape [parameter, level, time, lat, lon].\n", " 'sur_climate': Torch tensor of shape [parameter, lat, lon].\n", " 'ulv_climate': Torch tensor of shape [parameter, level, lat, lon].\n", " 'lead_time': Float.\n", " 'input_time': Float.\n", "\n", " \"\"\" # noqa: E501\n", "\n", " # We assemble the unique timestamps for which we need data.\n", " vals_required = {*spec.times}\n", " stat_required = {*spec.stat_times}\n", "\n", " # We assemble the unique data files from which we need value data\n", " vals_file_map = defaultdict(list)\n", " for t in vals_required:\n", " data_files = (\n", " self.data_file_surface(t),\n", " self.data_file_vertical(t),\n", " )\n", " vals_file_map[data_files].append(t)\n", "\n", " # We assemble the unique data files from which we need static data\n", " stat_file_map = defaultdict(list)\n", " for t in stat_required:\n", " data_files = (\n", " 
self.data_file_surface(t),\n", " self.data_file_vertical(t),\n", " )\n", " stat_file_map[data_files].append(t)\n", "\n", " # Load the value data\n", " data = {}\n", " for data_files, times in vals_file_map.items():\n", " for time in times:\n", " data[time] = self._read_data(data_files, time)\n", "\n", " # Combine times\n", " sample_data = {}\n", "\n", " input_upl = np.stack([data[t][\"vert\"] for t in spec.inputs], axis=2)\n", " sample_data[\"ulv_vals\"] = input_upl\n", "\n", " target_upl = data[spec.target][\"vert\"]\n", " sample_data[\"ulv_tars\"] = target_upl[:, :, None]\n", "\n", " input_sur = np.stack([data[t][\"surf\"] for t in spec.inputs], axis=1)\n", " sample_data[\"sur_vals\"] = input_sur\n", "\n", " target_sur = data[spec.target][\"surf\"]\n", " sample_data[\"sur_tars\"] = target_sur[:, None]\n", "\n", " # Load the static data\n", " data_files, times = stat_file_map.popitem()\n", " time = times[0].dayofyear, times[0].hour\n", " sample_data[\"sur_static\"] = self._read_static_data(\n", " data_files[0], *time\n", " )\n", "\n", " # If required, load the climatology data\n", " if self._require_clim:\n", " ci_year, ci_hour = spec.climatology_info\n", "\n", " surf_file = self.data_file_surface_climate(\n", " dayofyear=ci_year,\n", " hourofday=ci_hour,\n", " )\n", "\n", " vert_file = self.data_file_vertical_climate(\n", " dayofyear=ci_year,\n", " hourofday=ci_hour,\n", " )\n", "\n", " clim_data = self._read_climate((surf_file, vert_file))\n", "\n", " sample_data[\"sur_climate\"] = clim_data[\"surf\"]\n", " sample_data[\"ulv_climate\"] = clim_data[\"vert\"]\n", "\n", " # Move the data from numpy to torch\n", " sample_data = self._to_torch(sample_data, dtype=self.dtype)\n", "\n", " # Optionally roll along the last (longitude) axis\n", " if len(self._roll_longitudes) > 0:\n", " roll_by = random.choice(self._roll_longitudes)\n", " sample_data = self._lat_roll(sample_data, roll_by)\n", "\n", " # Now that we have rolled, we can add the scalar time values\n", " sample_data[\"lead_time\"] = spec.lead_time\n", " 
sample_data[\"input_time\"] = spec.input_time\n", "\n", " return sample_data\n", "\n", " def get_data(\n", " self, timestamp: pd.Timestamp, input_time: int, lead_time: int\n", " ) -> dict[str, Tensor | int]:\n", " \"\"\"\n", " Loads data based on timestamp, input time and lead time.\n", " Args:\n", " timestamp: Timestamp.\n", " input_time: time between input samples.\n", " lead_time: lead time.\n", " Returns:\n", " Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',\n", " 'lead_time', 'input_time'.\n", " \"\"\"\n", " spec = SampleSpec.get(timestamp, -input_time, lead_time)\n", " sample_data = self.get_data_from_sample_spec(spec)\n", " return sample_data\n", "\n", " def __getitem__(self, idx: int) -> dict[str, Tensor | int]:\n", " \"\"\"\n", " Loads data based on sample index and random choice of sample.\n", " Args:\n", " idx: Sample index.\n", " Returns:\n", " Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',\n", " 'lead_time', 'input_time'.\n", " \"\"\"\n", " sample_set = self.samples[idx]\n", " timestamp, input_time, lead_time, *nsteps = random.choice(sample_set)\n", " sample_data = self.get_data(timestamp, input_time, lead_time)\n", " return sample_data\n", "\n", " def __len__(self):\n", " return len(self.samples)\n" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "# from PrithviWxC.dataloaders.merra2 import Merra2Dataset\n", "\n", "dataset = Merra2Dataset(\n", " time_range=time_range,\n", " lead_times=lead_times,\n", " input_times=input_times,\n", " data_path_surface=surf_dir,\n", " data_path_vertical=vert_dir,\n", " climatology_path_surface=surf_clim_dir,\n", " climatology_path_vertical=vert_clim_dir,\n", " surface_vars=surface_vars,\n", " static_surface_vars=static_surface_vars,\n", " vertical_vars=vertical_vars,\n", " levels=levels,\n", " positional_encoding=positional_encoding,\n", ")\n", "assert 
len(dataset) > 0, \"There doesn't seem to be any valid data.\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The model\n", "We are now ready to build the model.\n", "### Scalers\n", "Additionally, the model takes as static parameters the mean\n", "and variance values of the input variables and the variance\n", "values of the target difference, i.e., the variance between\n", "climatology and instantaneous variables. We have provided\n", "data files containing these values, and here we load this data." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# from PrithviWxC.dataloaders.merra2 import (\n", "# input_scalers,\n", "# output_scalers,\n", "# static_input_scalers,\n", "# )\n", "\n", "surf_in_scal_path = Path(\"./climatology/musigma_surface.nc\")\n", "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", " filename=f\"climatology/{surf_in_scal_path.name}\",\n", " local_dir=\".\",\n", ")\n", "\n", "vert_in_scal_path = Path(\"./climatology/musigma_vertical.nc\")\n", "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", " filename=f\"climatology/{vert_in_scal_path.name}\",\n", " local_dir=\".\",\n", ")\n", "\n", "surf_out_scal_path = Path(\"./climatology/anomaly_variance_surface.nc\")\n", "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", " filename=f\"climatology/{surf_out_scal_path.name}\",\n", " local_dir=\".\",\n", ")\n", "\n", "vert_out_scal_path = Path(\"./climatology/anomaly_variance_vertical.nc\")\n", "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", " filename=f\"climatology/{vert_out_scal_path.name}\",\n", " local_dir=\".\",\n", ")\n", "\n", "in_mu, in_sig = input_scalers(\n", " surface_vars,\n", " vertical_vars,\n", " levels,\n", " surf_in_scal_path,\n", " vert_in_scal_path,\n", ")\n", "\n", "output_sig = output_scalers(\n", " surface_vars,\n", " vertical_vars,\n", " levels,\n", " surf_out_scal_path,\n", " 
vert_out_scal_path,\n", ")\n", "\n", "static_mu, static_sig = static_input_scalers(\n", " surf_in_scal_path,\n", " static_surface_vars,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Task and additional configs\n", "As previously mentioned, the PrithviWxC model's pretext task\n", "involved predicting the desired variable at a specific lead\n", "time. This was achieved by calculating the difference (delta)\n", "compared to the climatological average at that time. This\n", "operational mode is activated using the residual flag. Although\n", "the model includes additional residual options, the core model\n", "weights were not trained using these modes.\n", "\n", "Additionally, for training and evaluation, it is possible to\n", "mask tokens in the model. The masking occurs after tokenization,\n", "prior to the encoder layers. The model utilizes multi-axis\n", "attention, with data broken down into a hierarchy of local and\n", "global patches. Consequently, masking can be configured to mask\n", "either small local patches or larger global patches. This\n", "configuration is achieved via the `masking_mode` flag. It is\n", "possible to set `masking_mode=both`. This does not mix the modes\n", "but rather allows both modes to be used and swapped between,\n", "primarily for training purposes. For this demonstration, we will\n", "adjust the masking ratio to showcase the reconstruction\n", "capabilities of the model.\n", "\n", "Finally, we can set up shifting. Primarily utilized in the\n", "decoder, this enables alternate shifting of the attention\n", "windows, similar to the SWIN model. This option necessitates\n", "an even number of decoder blocks and is incompatible with the\n", "encoder when masking is also employed." 
] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "residual = \"climate\"\n", "masking_mode = \"local\"\n", "decoder_shifting = True\n", "masking_ratio = 0.99" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model init\n", "We now have all the pieces to build the model. If you are\n", "using the pretrained weights, a number of the model\n", "hyperparameters are predetermined and included below. With\n", "this configuration, the model will have approximately 2.3\n", "billion parameters. Therefore, if you want to train the fully\n", "unfrozen model, you will likely need to use a model distribution\n", "approach, such as fully sharded data parallelism (FSDP). To\n", "further reduce the memory usage of the model when gradients are\n", "required, there are two variables, `checkpoint_encoder` and\n", "`checkpoint_decoder`, which enable activation checkpointing of\n", "desired transformer layers." ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "from functools import cached_property\n", "from importlib.metadata import version\n", "\n", "from torch import Tensor\n", "from torch.utils.checkpoint import checkpoint\n", "\n", "if version(\"torch\") > \"2.3.0\":\n", " from torch.nn.attention import SDPBackend, sdpa_kernel\n", "import numpy as np\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "\n", "# DropPath code is straight from timm\n", "# (https://huggingface.co/spaces/Roll20/pet_score/blame/main/lib/timm/models/layers/drop.py)\n", "def drop_path(\n", " x: Tensor,\n", " drop_prob: float = 0.0,\n", " training: bool = False,\n", " scale_by_keep: bool = True,\n", ") -> Tensor:\n", " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of\n", " residual blocks). 
Taken from timm.\n", "\n", " Args:\n", " x (Tensor): Input tensor.\n", " drop_prob (float): Probability of dropping `x`, defaults to 0.\n", " training (bool): Whether the model is in training or eval mode,\n", " defaults to False.\n", " scale_by_keep (bool): Whether the output should be scaled by\n", " `1 / (1 - drop_prob)`, defaults to True.\n", " Returns:\n", " Tensor: Input with samples randomly dropped with probability\n", " `drop_prob`.\n", " \"\"\"\n", " if drop_prob == 0.0 or not training:\n", " return x\n", " keep_prob = 1 - drop_prob\n", " shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n", " random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n", " if keep_prob > 0.0 and scale_by_keep:\n", " random_tensor.div_(keep_prob)\n", " return x * random_tensor\n", "\n", "\n", "class DropPath(nn.Module):\n", " \"\"\"\n", " Drop paths (Stochastic Depth) per sample (when applied in main path of\n", " residual blocks).\n", " \"\"\"\n", "\n", " def __init__(\n", " self, drop_prob: float | None = None, scale_by_keep: bool = True\n", " ) -> None:\n", " super(DropPath, self).__init__()\n", " self.drop_prob = drop_prob\n", " self.scale_by_keep = scale_by_keep\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " \"\"\"Runs drop path on input tensor\n", "\n", " Args:\n", " x: input\n", "\n", " Returns:\n", " tensor: output after drop_path\n", " \"\"\"\n", " return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)\n", "\n", "\n", "class Mlp(nn.Module):\n", " \"\"\"\n", " Multi layer perceptron.\n", " \"\"\"\n", "\n", " def __init__(\n", " self, features: int, hidden_features: int, dropout: float = 0.0\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " features: Input/output dimension.\n", " hidden_features: Hidden dimension.\n", " dropout: Dropout.\n", " \"\"\"\n", " super().__init__()\n", " self.net = nn.Sequential(\n", " nn.Linear(features, hidden_features),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(hidden_features, features),\n", " 
nn.Dropout(dropout),\n", " )\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Args:\n", " x (Tensor): Tensor of shape [..., channel]\n", " Returns:\n", " Tensor: Tensor of same shape as x.\n", " \"\"\"\n", " return self.net(x)\n", "\n", "\n", "class LayerNormPassThrough(nn.LayerNorm):\n", " \"\"\"Normalising layer that allows the attention mask to be passed through\"\"\"\n", "\n", " def __init__(self, *args, **kwargs):\n", " super().__init__(*args, **kwargs)\n", "\n", " def forward(\n", " self, d: tuple[Tensor, Tensor | None]\n", " ) -> tuple[Tensor, Tensor | None]:\n", " \"\"\"Forward function\n", "\n", " Args:\n", " d (tuple): tuple of the data tensor and the attention mask\n", " Returns:\n", " output (Tensor): normalised output data\n", " attn_mask (Tensor): the attention mask that was passed in\n", " \"\"\"\n", " input, attn_mask = d\n", " output = F.layer_norm(\n", " input, self.normalized_shape, self.weight, self.bias, self.eps\n", " )\n", " return output, attn_mask\n", "\n", "\n", "class MultiheadAttention(nn.Module):\n", " \"\"\"Multihead attention layer for inputs of shape\n", " [..., sequence, features].\n", " \"\"\"\n", "\n", " def __init__(self, features: int, n_heads: int, dropout: float) -> None:\n", " \"\"\"\n", " Args:\n", " features: Number of features for inputs to the layer.\n", " n_heads: Number of attention heads. Should be a factor of features.\n", " (I.e. 
the layer uses features // n_heads.)\n", " dropout: Dropout.\n", " \"\"\" # noqa: E501\n", " super().__init__()\n", "\n", " if (features % n_heads) != 0:\n", " raise ValueError(\n", " f\"Features '{features}' is not divisible by heads '{n_heads}'.\"\n", " )\n", "\n", " self.features = features\n", " self.n_heads = n_heads\n", " self.dropout = dropout\n", "\n", " self.qkv_layer = torch.nn.Linear(features, features * 3, bias=False)\n", " self.w_layer = torch.nn.Linear(features, features, bias=False)\n", "\n", " def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:\n", " \"\"\"\n", " Args:\n", " d (tuple): tuple containing Tensor of shape [..., sequence, features] and the attention mask\n", " Returns:\n", " Tensor: Tensor of shape [..., sequence, features]\n", " \"\"\" # noqa: E501\n", " x, attn_mask = d\n", "\n", " if not x.shape[-1] == self.features:\n", " raise ValueError(\n", " f\"Expecting tensor with last dimension size {self.features}.\"\n", " )\n", "\n", " passenger_dims = x.shape[:-2]\n", " B = passenger_dims.numel()\n", " S = x.shape[-2]\n", " C = x.shape[-1]\n", " x = x.reshape(B, S, C)\n", "\n", " # x [B, S, C]\n", " # q, k, v [B, H, S, C/H]\n", " q, k, v = (\n", " self.qkv_layer(x)\n", " .view(B, S, self.n_heads, 3 * (C // self.n_heads))\n", " .transpose(1, 2)\n", " .chunk(chunks=3, dim=3)\n", " )\n", "\n", " # Let us enforce either flash (A100+) or memory efficient attention.\n", " if version(\"torch\") > \"2.3.0\":\n", " with sdpa_kernel(\n", " [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]\n", " ):\n", " # x [B, H, S, C//H]\n", " x = F.scaled_dot_product_attention(\n", " q, k, v, attn_mask=attn_mask, dropout_p=self.dropout\n", " )\n", " else:\n", " with torch.backends.cuda.sdp_kernel(\n", " enable_flash=True, enable_math=False, enable_mem_efficient=True\n", " ):\n", " # x [B, H, S, C//H]\n", " x = F.scaled_dot_product_attention(\n", " q, k, v, dropout_p=self.dropout\n", " )\n", "\n", " # x [B, S, C]\n", " x = x.transpose(1, 
2).view(B, S, C)\n", "\n", " # x [B, S, C]\n", " x = self.w_layer(x)\n", "\n", " # Back to input shape\n", " x = x.view(*passenger_dims, S, self.features)\n", " return x\n", "\n", "\n", "class Transformer(nn.Module):\n", " \"\"\"\n", " Transformer for inputs of shape [..., S, features].\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " features: int,\n", " mlp_multiplier: int,\n", " n_heads: int,\n", " dropout: float,\n", " drop_path: float,\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " features: Number of features for inputs to the layer.\n", " mlp_multiplier: Model uses features*mlp_multiplier hidden units.\n", " n_heads: Number of attention heads. Should be a factor of features.\n", " (I.e. the layer uses features // n_heads.)\n", " dropout: Dropout.\n", " drop_path: DropPath.\n", " \"\"\"\n", " super().__init__()\n", "\n", " self.features = features\n", " self.mlp_multiplier = mlp_multiplier\n", " self.n_heads = n_heads\n", " self.dropout = dropout\n", " self.drop_path = (\n", " DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n", " )\n", "\n", " self.attention = nn.Sequential(\n", " LayerNormPassThrough(features),\n", " MultiheadAttention(features, n_heads, dropout),\n", " )\n", "\n", " self.ff = nn.Sequential(\n", " nn.LayerNorm(features),\n", " Mlp(\n", " features=features,\n", " hidden_features=features * mlp_multiplier,\n", " dropout=dropout,\n", " ),\n", " )\n", "\n", " def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:\n", " \"\"\"\n", " Args:\n", " d (tuple): Tuple of the data tensor of shape\n", " [..., sequence, features] and the attention mask.\n", " Returns:\n", " Tensor: Tensor of shape [..., sequence, features]\n", " \"\"\"\n", " x, attn_mask = d\n", " if not x.shape[-1] == self.features:\n", " raise ValueError(\n", " f\"Expecting tensor with last dimension size {self.features}.\"\n", " )\n", "\n", " attention_x = self.attention(d)\n", "\n", " x = x + self.drop_path(attention_x)\n", " x = x + self.drop_path(self.ff(x))\n", "\n", " return x\n", "\n", "\n", "class 
_Shift(nn.Module):\n", " \"\"\"Private base class for the shifter. This allows some behaviour to be\n", " easily handled when the shifter isn't used.\n", " \"\"\"\n", "\n", " def __init__(self):\n", " super().__init__()\n", "\n", " self._shifted = False\n", "\n", " @torch.no_grad()\n", " def reset(self) -> None:\n", " \"\"\"\n", " Resets the bool tracking whether the data is shifted\n", " \"\"\"\n", " self._shifted: bool = False\n", "\n", " def forward(self, data: Tensor) -> tuple[Tensor, dict[bool, None]]:\n", " return data, {True: None, False: None}\n", "\n", "\n", "class SWINShift(_Shift):\n", " \"\"\"\n", " Handles the shifting of patches similar to how SWIN works. However if we\n", " shift the latitudes then the poles will wrap and potentially that might be\n", " problematic. The position tokens should handle it but masking is safer.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " mu_shape: tuple[int, int],\n", " global_shape: tuple[int, int],\n", " local_shape: tuple[int, int],\n", " patch_shape: tuple[int, int],\n", " n_context_tokens: int = 2,\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " mu_shape: the shape of the masking units\n", " global_shape: number of global patches in lat and lon\n", " local_shape: size of the local patches\n", " patch_shape: patch size\n", " n_context_tokens: number of additional context tokens at start of\n", " _each_ local sequence\n", " \"\"\"\n", " super().__init__()\n", "\n", " self._mu_shape = ms = mu_shape\n", " self._g_shape = gs = global_shape\n", " self._l_shape = ls = local_shape\n", " self._p_shape = ps = patch_shape\n", " self._lat_patch = (gs[0], ls[0], gs[1], ls[1])\n", " self._n_context_tokens = n_context_tokens\n", "\n", " self._g_shift_to = tuple(\n", " int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)\n", " )\n", " self._g_shift_from = tuple(\n", " -int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)\n", " )\n", "\n", " # Define the attention masks for the shifted MaxViT.\n", " nglobal = 
global_shape[0] * global_shape[1]\n", " nlocal = (\n", " local_shape[0] * local_shape[1] + self._n_context_tokens\n", " ) # plus the context tokens\n", "\n", " lm = torch.ones((nglobal, 1, nlocal, nlocal), dtype=bool)\n", " mwidth = int(0.5 * local_shape[1]) * local_shape[0]\n", " lm[\n", " : gs[1],\n", " :,\n", " self._n_context_tokens : mwidth + self._n_context_tokens,\n", " self._n_context_tokens : mwidth + self._n_context_tokens,\n", " ] = False\n", " self.register_buffer(\"local_mask\", lm)\n", "\n", " gm = torch.ones((nlocal, 1, nglobal, nglobal), dtype=bool)\n", " gm[: int(0.5 * ls[1]) * ls[0], :, : gs[1], : gs[1]] = False\n", " self.register_buffer(\"global_mask\", gm)\n", "\n", " def _to_grid_global(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Shuffle and reshape the data from the global/local setting back to the\n", " lat/lon grid setting\n", " Args:\n", " x: the data tensor to be shuffled.\n", " Returns:\n", " x: data in the lat/lon setting.\n", " \"\"\"\n", " nbatch, *other = x.shape\n", "\n", " y1 = x.view(nbatch, *self._g_shape, *self._l_shape, -1)\n", " y2 = y1.permute(0, 5, 1, 3, 2, 4).contiguous()\n", "\n", " s = y2.shape\n", " return y2.view((nbatch, -1, s[2] * s[3], s[4] * s[5]))\n", "\n", " def _to_grid_local(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Shuffle and reshape the data from the local/global setting to the\n", " lat/lon grid setting\n", " Args:\n", " x: the data tensor to be shuffled.\n", " Returns:\n", " x: data in the lat/lon setting.\n", " \"\"\"\n", " x = x.transpose(2, 1).contiguous()\n", " return self._to_grid_global(x)\n", "\n", " def _from_grid_global(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Shuffle and reshape the data from the lat/lon grid to the global/local\n", " setting\n", " Args:\n", " x: the data tensor to be shuffled.\n", " Returns:\n", " x: data in the global/local setting\n", " \"\"\"\n", " nbatch, *other = x.shape\n", "\n", " z1 = x.view(nbatch, -1, *self._lat_patch)\n", " z2 = z1.permute(0, 2, 4, 3, 5, 
1).contiguous()\n", "\n", " s = z2.shape\n", " return z2.view(nbatch, s[1] * s[2], s[3] * s[4], -1)\n", "\n", " def _from_grid_local(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Shuffle and reshape the data from the lat/lon grid to the local/global\n", " setting\n", " Args:\n", " x: the data tensor to be shuffled.\n", " Returns:\n", " x: data in the local/global setting\n", " \"\"\"\n", " x = self._from_grid_global(x)\n", " return x.transpose(2, 1).contiguous()\n", "\n", " def _shift(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Shifts data in the gridded lat/lon setting by half the mask unit shape\n", " Args:\n", " x: data to be shifted\n", " Returns:\n", " x: either the shifted or unshifted data\n", " \"\"\"\n", " shift = self._g_shift_from if self._shifted else self._g_shift_to\n", " x_shifted = torch.roll(x, shift, (-2, -1))\n", "\n", " self._shifted = not self._shifted\n", " return x_shifted\n", "\n", " def _sep_lt(self, x: Tensor) -> tuple[Tensor, Tensor]:\n", " \"\"\"\n", " Separate off the leadtime from the local patches\n", " Args:\n", " x: data to have leadtime removed from\n", " Returns:\n", " lt: leadtime\n", " x: data without the lead time in the local patch\n", " \"\"\"\n", " lt_it = x[:, : self._n_context_tokens, :, :]\n", " x_stripped = x[:, self._n_context_tokens :, :, :]\n", "\n", " return lt_it, x_stripped\n", "\n", " def forward(self, data: Tensor) -> tuple[Tensor, Tensor]:\n", " \"\"\"Shift or unshift the data depending on whether the data is\n", " already shifted, as defined by self._shifted.\n", "\n", " Args:\n", " data: data to be shifted\n", " Returns:\n", " Tensor: shifted data Tensor\n", " \"\"\"\n", " lt, x = self._sep_lt(data)\n", "\n", " x_grid = self._to_grid_local(x)\n", " x_shifted = self._shift(x_grid)\n", " x_patched = self._from_grid_local(x_shifted)\n", "\n", " # Mask has to be repeated based on batch size\n", " n_batch = x_grid.shape[0]\n", " local_rep = [n_batch] + [1] * (self.local_mask.ndim - 1)\n", " global_rep = 
[n_batch] + [1] * (self.global_mask.ndim - 1)\n", "\n", " if self._shifted:\n", " attn_mask = {\n", " True: self.local_mask.repeat(local_rep),\n", " False: self.global_mask.repeat(global_rep),\n", " }\n", " else:\n", " attn_mask = {True: None, False: None}\n", "\n", " return torch.cat((lt, x_patched), axis=1), attn_mask\n", "\n", "\n", "class LocalGlobalLocalBlock(nn.Module):\n", " \"\"\"\n", " Applies alternating block and grid attention. Given a parameter n_blocks,\n", " the entire module contains 2*n_blocks+1 transformer blocks. The first,\n", " third, ..., last apply local (block) attention. The second, fourth, ...\n", " global (grid) attention.\n", "\n", " This is heavily inspired by\n", " Tu et al. \"MaxViT: Multi-Axis Vision Transformer\"\n", " (https://arxiv.org/abs/2204.01697).\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " features: int,\n", " mlp_multiplier: int,\n", " n_heads: int,\n", " dropout: float,\n", " n_blocks: int,\n", " drop_path: float,\n", " shifter: nn.Module | None = None,\n", " checkpoint: list[int] | None = None,\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " features: Number of features for inputs to the layer.\n", " mlp_multiplier: Model uses features*mlp_multiplier hidden units.\n", " n_heads: Number of attention heads. Should be a factor of features.\n", " (I.e. the layer uses features // n_heads.)\n", " dropout: Dropout.\n", " drop_path: DropPath.\n", " n_blocks: Number of local-global transformer pairs.\n", " \"\"\"\n", " super().__init__()\n", "\n", " self.features = features\n", " self.mlp_multiplier = mlp_multiplier\n", " self.n_heads = n_heads\n", " self.dropout = dropout\n", " self.drop_path = drop_path\n", " self.n_blocks = n_blocks\n", " self._checkpoint = checkpoint or []\n", "\n", " if not all(0 <= c < 2 * n_blocks + 1 for c in self._checkpoint):\n", " raise ValueError(\n", " \"Checkpoints should be 0 <= i < 2*n_blocks+1. 
\"\n", " f\"{self._checkpoint=}.\"\n", " )\n", "\n", " self.transformers = nn.ModuleList(\n", " [\n", " Transformer(\n", " features=features,\n", " mlp_multiplier=mlp_multiplier,\n", " n_heads=n_heads,\n", " dropout=dropout,\n", " drop_path=drop_path,\n", " )\n", " for _ in range(2 * n_blocks + 1)\n", " ]\n", " )\n", "\n", " self.evaluator = [\n", " self._checkpoint_wrapper\n", " if i in self._checkpoint\n", " else lambda m, x: m(x)\n", " for i, _ in enumerate(self.transformers)\n", " ]\n", "\n", " self.shifter = shifter or _Shift()\n", "\n", " @staticmethod\n", " def _checkpoint_wrapper(\n", " model: nn.Module, data: tuple[Tensor, Tensor | None]\n", " ) -> Tensor:\n", " return checkpoint(model, data, use_reentrant=False)\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Args:\n", " x: Tensor of shape::\n", "\n", " [batch, global_sequence, local_sequence, features]\n", "\n", " Returns:\n", " Tensor: Tensor of shape::\n", "\n", " [batch, global_sequence, local_sequence, features]\n", " \"\"\"\n", " if x.shape[-1] != self.features:\n", " raise ValueError(\n", " f\"Expecting tensor with last dimension size {self.features}.\"\n", " )\n", " if x.ndim != 4:\n", " raise ValueError(\n", " f\"Expecting tensor with exactly four dimensions. 
{x.shape=}.\"\n", " )\n", "\n", " self.shifter.reset()\n", " local: bool = True\n", " attn_mask = {True: None, False: None}\n", "\n", " transformer_iter = zip(self.evaluator, self.transformers, strict=False)\n", "\n", " # First local block\n", " evaluator, transformer = next(transformer_iter)\n", " x = evaluator(transformer, (x, attn_mask[local]))\n", "\n", " for evaluator, transformer in transformer_iter:\n", " local = not local\n", " # We are making exactly 2*n_blocks transposes.\n", " # So the output has the same shape as input.\n", " x = x.transpose(1, 2)\n", "\n", " x = evaluator(transformer, (x, attn_mask[local]))\n", "\n", " if not local:\n", " x, attn_mask = self.shifter(x)\n", "\n", " return x\n", "\n", "\n", "class PatchEmbed(nn.Module):\n", " \"\"\"\n", " Patch embedding via 2D convolution.\n", " \"\"\"\n", "\n", " def __init__(\n", " self, patch_size: int | tuple[int, ...], channels: int, embed_dim: int\n", " ):\n", " super().__init__()\n", "\n", " self.patch_size = patch_size\n", " self.channels = channels\n", " self.embed_dim = embed_dim\n", "\n", " self.proj = nn.Conv2d(\n", " channels,\n", " embed_dim,\n", " kernel_size=patch_size,\n", " stride=patch_size,\n", " bias=True,\n", " )\n", "\n", " def forward(self, x: Tensor) -> Tensor:\n", " \"\"\"\n", " Args:\n", " x: Tensor of shape [batch, channels, lat, lon].\n", " Returns:\n", " Tensor: Tensor with shape\n", " [batch, embed_dim, lat//patch_size, lon//patch_size]\n", " \"\"\"\n", "\n", " H, W = x.shape[-2:]\n", "\n", " if W % self.patch_size[1] != 0:\n", " raise ValueError(\n", " f\"Cannot do patch embedding for tensor of shape {x.size()}\"\n", " f\" with patch size {self.patch_size}. (Dimensions are BSCHW.)\"\n", " )\n", " if H % self.patch_size[0] != 0:\n", " raise ValueError(\n", " f\"Cannot do patch embedding for tensor of shape {x.size()}\"\n", " f\" with patch size {self.patch_size}. 
(Dimensions are BSCHW.)\"\n", " )\n", "\n", " x = self.proj(x)\n", "\n", " return x\n", "\n", "\n", "class PrithviWxCEncoderDecoder(nn.Module):\n", " \"\"\"\n", " Hiera-MaxViT encoder/decoder code.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " embed_dim: int,\n", " n_blocks: int,\n", " mlp_multiplier: float,\n", " n_heads: int,\n", " dropout: float,\n", " drop_path: float,\n", " shifter: nn.Module | None = None,\n", " transformer_cp: list[int] | None = None,\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " embed_dim: Embedding dimension\n", " n_blocks: Number of local-global transformer pairs.\n", " mlp_multiplier: MLP multiplier for hidden features in feed forward\n", " networks.\n", " n_heads: Number of attention heads.\n", " dropout: Dropout.\n", " drop_path: DropPath.\n", " \"\"\"\n", " super().__init__()\n", "\n", " self.embed_dim = embed_dim\n", " self.n_blocks = n_blocks\n", " self.mlp_multiplier = mlp_multiplier\n", " self.n_heads = n_heads\n", " self.dropout = dropout\n", " self._transformer_cp = transformer_cp\n", "\n", " self.lgl_block = LocalGlobalLocalBlock(\n", " features=embed_dim,\n", " mlp_multiplier=mlp_multiplier,\n", " n_heads=n_heads,\n", " dropout=dropout,\n", " drop_path=drop_path,\n", " n_blocks=n_blocks,\n", " shifter=shifter,\n", " checkpoint=transformer_cp,\n", " )\n", "\n", " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", " \"\"\"\n", " Args:\n", " x: Tensor of shape\n", " [batch, global sequence, local sequence, embed_dim]\n", " Returns:\n", " Tensor of shape\n", " [batch, mask_unit_sequence, local_sequence, embed_dim].\n", " Identical in shape to the input x.\n", " \"\"\"\n", "\n", " x = self.lgl_block(x)\n", "\n", " return x\n", "\n", "\n", "class PrithviWxC(nn.Module):\n", " \"\"\"Encoder-decoder fusing Hiera with MaxViT. See\n", " - Ryali et al. \"Hiera: A Hierarchical Vision Transformer without the\n", " Bells-and-Whistles\" (https://arxiv.org/abs/2306.00989)\n", " - Tu et al. 
\"MaxViT: Multi-Axis Vision Transformer\"\n", " (https://arxiv.org/abs/2204.01697)\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " in_channels: int,\n", " input_size_time: int,\n", " in_channels_static: int,\n", " input_scalers_mu: Tensor,\n", " input_scalers_sigma: Tensor,\n", " input_scalers_epsilon: float,\n", " static_input_scalers_mu: Tensor,\n", " static_input_scalers_sigma: Tensor,\n", " static_input_scalers_epsilon: float,\n", " output_scalers: Tensor,\n", " n_lats_px: int,\n", " n_lons_px: int,\n", " patch_size_px: tuple[int],\n", " mask_unit_size_px: tuple[int],\n", " mask_ratio_inputs: float,\n", " embed_dim: int,\n", " n_blocks_encoder: int,\n", " n_blocks_decoder: int,\n", " mlp_multiplier: float,\n", " n_heads: int,\n", " dropout: float,\n", " drop_path: float,\n", " parameter_dropout: float,\n", " residual: str,\n", " masking_mode: str,\n", " positional_encoding: str,\n", " decoder_shifting: bool = False,\n", " checkpoint_encoder: list[int] | None = None,\n", " checkpoint_decoder: list[int] | None = None,\n", " ) -> None:\n", " \"\"\"\n", " Args:\n", " in_channels: Number of input channels.\n", " input_size_time: Number of timestamps in input.\n", " in_channels_static: Number of input channels for static data.\n", " input_scalers_mu: Tensor of size (in_channels,). Used to rescale\n", " input.\n", " input_scalers_sigma: Tensor of size (in_channels,). Used to rescale\n", " input.\n", " input_scalers_epsilon: Float. Used to rescale input.\n", " static_input_scalers_mu: Tensor of size (in_channels_static). Used\n", " to rescale static inputs.\n", " static_input_scalers_sigma: Tensor of size (in_channels_static).\n", " Used to rescale static inputs.\n", " static_input_scalers_epsilon: Float. Used to rescale static inputs.\n", " output_scalers: Tensor of shape (in_channels,). Used to rescale\n", " output.\n", " n_lats_px: Total latitudes in data. In pixels.\n", " n_lons_px: Total longitudes in data. 
In pixels.\n", " patch_size_px: Patch size for tokenization. In pixels lat/lon.\n", " mask_unit_size_px: Size of each mask unit. In pixels lat/lon.\n", " mask_ratio_inputs: Masking ratio for inputs. 0 to 1.\n", " embed_dim: Embedding dimension\n", " n_blocks_encoder: Number of local-global transformer pairs in\n", " encoder.\n", " n_blocks_decoder: Number of local-global transformer pairs in\n", " decoder.\n", " mlp_multiplier: MLP multiplier for hidden features in feed forward\n", " networks.\n", " n_heads: Number of attention heads.\n", " dropout: Dropout.\n", " drop_path: DropPath.\n", " parameter_dropout: Dropout applied to parameters.\n", " residual: Indicates whether and how model should work as residual\n", " model. Accepted values are 'climate', 'temporal' and 'none'\n", " positional_encoding: possible values are\n", " ['absolute' (default), 'fourier'].\n", " 'absolute' lat lon encoded in 3 dimensions using sine and\n", " cosine\n", " 'fourier' lat/lon to be encoded using various frequencies\n", " masking_mode: String ['local', 'global', 'both'] that controls the\n", " type of masking used.\n", " checkpoint_encoder: List of integers controlling if gradient\n", " checkpointing is used on encoder.\n", " Format: [] for no gradient checkpointing. 
[3, 7] for\n", " checkpointing after 4th and 8th layer etc.\n", " checkpoint_decoder: List of integers controlling if gradient\n", " checkpointing is used on decoder.\n", " Format: See `checkpoint_encoder`.\n", " decoder_shifting: Whether to use swin shifting in the decoder.\n", " \"\"\"\n", " super().__init__()\n", "\n", " self.in_channels = in_channels\n", " self.input_size_time = input_size_time\n", " self.in_channels_static = in_channels_static\n", " self.n_lats_px = n_lats_px\n", " self.n_lons_px = n_lons_px\n", " self.patch_size_px = patch_size_px\n", " self.mask_unit_size_px = mask_unit_size_px\n", " self.mask_ratio_inputs = mask_ratio_inputs\n", " self.embed_dim = embed_dim\n", " self.n_blocks_encoder = n_blocks_encoder\n", " self.n_blocks_decoder = n_blocks_decoder\n", " self.mlp_multiplier = mlp_multiplier\n", " self.n_heads = n_heads\n", " self.dropout = dropout\n", " self.drop_path = drop_path\n", " self.residual = residual\n", " self._decoder_shift = decoder_shifting\n", " self.positional_encoding = positional_encoding\n", " self._checkpoint_encoder = checkpoint_encoder\n", " self._checkpoint_decoder = checkpoint_decoder\n", "\n", " assert self.n_lats_px % self.mask_unit_size_px[0] == 0\n", " assert self.n_lons_px % self.mask_unit_size_px[1] == 0\n", " assert self.mask_unit_size_px[0] % self.patch_size_px[0] == 0\n", " assert self.mask_unit_size_px[1] % self.patch_size_px[1] == 0\n", "\n", " if self.patch_size_px[0] != self.patch_size_px[1]:\n", " raise NotImplementedError(\n", " \"Current pixel shuffle requires symmetric patches.\"\n", " )\n", "\n", " self.local_shape_mu = (\n", " self.mask_unit_size_px[0] // self.patch_size_px[0],\n", " self.mask_unit_size_px[1] // self.patch_size_px[1],\n", " )\n", " self.global_shape_mu = (\n", " self.n_lats_px // self.mask_unit_size_px[0],\n", " self.n_lons_px // self.mask_unit_size_px[1],\n", " )\n", "\n", " assert input_scalers_mu.shape == 
(in_channels,)\n", " assert input_scalers_sigma.shape == (in_channels,)\n", " assert output_scalers.shape == (in_channels,)\n", "\n", " if self.positional_encoding != \"fourier\":\n", " assert static_input_scalers_mu.shape == (in_channels_static,)\n", " assert static_input_scalers_sigma.shape == (in_channels_static,)\n", "\n", " # Input shape [batch, time, parameter, lat, lon]\n", " self.input_scalers_epsilon = input_scalers_epsilon\n", " self.register_buffer(\n", " \"input_scalers_mu\", input_scalers_mu.reshape(1, 1, -1, 1, 1)\n", " )\n", " self.register_buffer(\n", " \"input_scalers_sigma\", input_scalers_sigma.reshape(1, 1, -1, 1, 1)\n", " )\n", "\n", " # Static inputs shape [batch, parameter, lat, lon]\n", " self.static_input_scalers_epsilon = static_input_scalers_epsilon\n", " self.register_buffer(\n", " \"static_input_scalers_mu\",\n", " static_input_scalers_mu.reshape(1, -1, 1, 1),\n", " )\n", " self.register_buffer(\n", " \"static_input_scalers_sigma\",\n", " static_input_scalers_sigma.reshape(1, -1, 1, 1),\n", " )\n", "\n", " # Output shape [batch, parameter, lat, lon]\n", " self.register_buffer(\n", " \"output_scalers\", output_scalers.reshape(1, -1, 1, 1)\n", " )\n", "\n", " self.parameter_dropout = nn.Dropout2d(p=parameter_dropout)\n", "\n", " self.patch_embedding = PatchEmbed(\n", " patch_size=patch_size_px,\n", " channels=in_channels * input_size_time,\n", " embed_dim=embed_dim,\n", " )\n", "\n", " if self.residual == \"climate\":\n", " self.patch_embedding_static = PatchEmbed(\n", " patch_size=patch_size_px,\n", " channels=in_channels + in_channels_static,\n", " embed_dim=embed_dim,\n", " )\n", " else:\n", " self.patch_embedding_static = PatchEmbed(\n", " patch_size=patch_size_px,\n", " channels=in_channels_static,\n", " embed_dim=embed_dim,\n", " )\n", "\n", " self.input_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)\n", " self.lead_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)\n", "\n", " self.mask_token = 
nn.Parameter(torch.randn(1, 1, 1, self.embed_dim))\n", " self._nglobal_mu = np.prod(self.global_shape_mu)\n", " self._global_idx = torch.arange(self._nglobal_mu)\n", "\n", " self._nlocal_mu = np.prod(self.local_shape_mu)\n", " self._local_idx = torch.arange(self._nlocal_mu)\n", "\n", " self.encoder = PrithviWxCEncoderDecoder(\n", " embed_dim=embed_dim,\n", " n_blocks=n_blocks_encoder,\n", " mlp_multiplier=mlp_multiplier,\n", " n_heads=n_heads,\n", " dropout=dropout,\n", " drop_path=drop_path,\n", " transformer_cp=checkpoint_encoder,\n", " )\n", "\n", " if n_blocks_decoder != 0:\n", " if self._decoder_shift:\n", " self.decoder_shifter = d_shifter = SWINShift(\n", " self.mask_unit_size_px,\n", " self.global_shape_mu,\n", " self.local_shape_mu,\n", " self.patch_size_px,\n", " n_context_tokens=0,\n", " )\n", " else:\n", " self.decoder_shifter = d_shifter = None\n", "\n", " self.decoder = PrithviWxCEncoderDecoder(\n", " embed_dim=embed_dim,\n", " n_blocks=n_blocks_decoder,\n", " mlp_multiplier=mlp_multiplier,\n", " n_heads=n_heads,\n", " dropout=dropout,\n", " drop_path=0.0,\n", " shifter=d_shifter,\n", " transformer_cp=checkpoint_decoder,\n", " )\n", "\n", " self.unembed = nn.Linear(\n", " self.embed_dim,\n", " self.in_channels\n", " * self.patch_size_px[0]\n", " * self.patch_size_px[1],\n", " bias=True,\n", " )\n", "\n", " self.masking_mode = masking_mode.lower()\n", " match self.masking_mode:\n", " case \"local\":\n", " self.generate_mask = self._gen_mask_local\n", " case \"global\":\n", " self.generate_mask = self._gen_mask_global\n", " case \"both\":\n", " self._mask_both_local: bool = True\n", " self.generate_mask = self._gen_mask_both\n", " case _:\n", " raise ValueError(\n", " f\"Masking mode '{masking_mode}' not supported\"\n", " )\n", "\n", " def swap_masking(self) -> None:\n", " self._mask_both_local = not self._mask_both_local\n", "\n", " @cached_property\n", " def n_masked_global(self):\n", " return int(self.mask_ratio_inputs * 
np.prod(self.global_shape_mu))\n", "\n", " @cached_property\n", " def n_masked_local(self):\n", " return int(self.mask_ratio_inputs * np.prod(self.local_shape_mu))\n", "\n", " @staticmethod\n", " def _shuffle_along_axis(a, axis):\n", " idx = torch.argsort(input=torch.rand(*a.shape), dim=axis)\n", " return torch.gather(a, dim=axis, index=idx)\n", "\n", " def _gen_mask_local(self, sizes: tuple[int]) -> tuple[Tensor]:\n", " \"\"\"\n", " Args:\n", " sizes: Leading sizes (batch, global sequence) of the mask.\n", " Returns:\n", " Tuple of torch tensors. [indices masked, indices unmasked].\n", " Each has shape (batch, global sequence, local indices)\n", " \"\"\"\n", " # Identify which indices (values) should be masked\n", "\n", " maskable_indices = self._local_idx.view(1, -1).expand(*sizes[:2], -1)\n", "\n", " maskable_indices = self._shuffle_along_axis(maskable_indices, 2)\n", "\n", " indices_masked = maskable_indices[:, :, : self.n_masked_local]\n", " indices_unmasked = maskable_indices[:, :, self.n_masked_local :]\n", "\n", " return indices_masked, indices_unmasked\n", "\n", " def _gen_mask_global(self, sizes: tuple[int]) -> tuple[Tensor]:\n", " \"\"\"\n", " Args:\n", " sizes: Leading sizes; only sizes[0] (batch) is used.\n", " Returns:\n", " Tuple of torch tensors. 
[indices masked, indices unmasked].\n", " Each of these is a tensor of shape (batch, global sequence)\n", " \"\"\"\n", " # Identify which indices (values) should be masked\n", "\n", " maskable_indices = self._global_idx.view(1, -1).expand(*sizes[:1], -1)\n", "\n", " maskable_indices = self._shuffle_along_axis(maskable_indices, 1)\n", "\n", " indices_masked = maskable_indices[:, : self.n_masked_global]\n", " indices_unmasked = maskable_indices[:, self.n_masked_global :]\n", "\n", " return indices_masked, indices_unmasked\n", "\n", " def _gen_mask_both(self, sizes: tuple[int]) -> tuple[Tensor]:\n", " if self._mask_both_local:\n", " return self._gen_mask_local(sizes)\n", " else:\n", " return self._gen_mask_global(sizes)\n", "\n", " @staticmethod\n", " def reconstruct_batch(\n", " idx_masked: Tensor,\n", " idx_unmasked: Tensor,\n", " data_masked: Tensor,\n", " data_unmasked: Tensor,\n", " ) -> tuple[Tensor, Tensor]:\n", " \"\"\"Reconstructs a tensor along the mask unit dimension. Batched\n", " version.\n", "\n", " Args:\n", " idx_masked: Tensor of shape `batch, mask unit sequence`.\n", " idx_unmasked: Tensor of shape `batch, mask unit sequence`.\n", " data_masked: Tensor of shape `batch, mask unit sequence, ...`.\n", " Should have same size along mask unit sequence dimension as\n", " idx_masked. Dimensions beyond the first two, marked here as ...\n", " will typically be `local_sequence, channel` or\n", " `channel, lat, lon`. These dimensions should agree with\n", " data_unmasked.\n", " data_unmasked: Tensor of shape `batch, mask unit sequence, ...`.\n", " Should have same size along mask unit sequence dimension as\n", " idx_unmasked. Dimensions beyond the first two, marked here as\n", " ... will typically be `local_sequence, channel` or `channel,\n", " lat, lon`. These dimensions should agree with data_masked.\n", " Returns:\n", " Tensor: Tensor of same shape as inputs data_masked and\n", " data_unmasked. I.e. `batch, mask unit sequence, ...`. 
Index for\n", " the total data composed of the masked and the unmasked part.\n", " \"\"\"\n", " dim: int = idx_masked.ndim\n", "\n", " idx_total = torch.argsort(\n", " torch.cat([idx_masked, idx_unmasked], dim=-1), dim=-1\n", " )\n", " idx_total = idx_total.view(\n", " *idx_total.shape, *[1] * (data_unmasked.ndim - dim)\n", " )\n", " idx_total = idx_total.expand(\n", " *idx_total.shape[:dim], *data_unmasked.shape[dim:]\n", " )\n", "\n", " data = torch.cat([data_masked, data_unmasked], dim=dim - 1)\n", " data = torch.gather(data, dim=dim - 1, index=idx_total)\n", "\n", " return data, idx_total\n", "\n", " def fourier_pos_encoding(self, x_static: Tensor) -> Tensor:\n", " \"\"\"\n", " Args:\n", " x_static: B x C x H x W. The first two channels are lat and lon.\n", " Returns:\n", " Tensor: Tensor of shape B x E x H/P x W/P where E is the\n", " embedding dimension and P the patch size.\n", " \"\"\"\n", "\n", " # B x C x H x W -> B x 1 x H/P x W/P\n", " latitudes_patch = F.avg_pool2d(\n", " x_static[:, [0]],\n", " kernel_size=self.patch_size_px,\n", " stride=self.patch_size_px,\n", " )\n", " longitudes_patch = F.avg_pool2d(\n", " x_static[:, [1]],\n", " kernel_size=self.patch_size_px,\n", " stride=self.patch_size_px,\n", " )\n", "\n", " modes = (\n", " torch.arange(self.embed_dim // 4, device=x_static.device).view(\n", " 1, -1, 1, 1\n", " )\n", " + 1.0\n", " )\n", " pos_encoding = torch.cat(\n", " (\n", " torch.sin(latitudes_patch * modes),\n", " torch.sin(longitudes_patch * modes),\n", " torch.cos(latitudes_patch * modes),\n", " torch.cos(longitudes_patch * modes),\n", " ),\n", " axis=1,\n", " )\n", "\n", " return pos_encoding # B x E x H/P x W/P\n", "\n", " def time_encoding(self, input_time, lead_time):\n", " \"\"\"\n", " Args:\n", " input_time: Tensor of shape [batch].\n", " lead_time: Tensor of shape [batch].\n", " Returns:\n", " Tensor: Tensor of shape [batch, 1, 1, embed_dim]\n", " \"\"\"\n", " input_time = self.input_time_embedding(input_time.view(-1, 1, 
1, 1))\n", " lead_time = 
self.lead_time_embedding(lead_time.view(-1, 1, 1, 1))\n", "\n", " time_encoding = torch.cat(\n", " (\n", " torch.cos(input_time),\n", " torch.cos(lead_time),\n", " torch.sin(input_time),\n", " torch.sin(lead_time),\n", " ),\n", " axis=3,\n", " )\n", " return time_encoding\n", "\n", " def to_patching(self, x: Tensor) -> Tensor:\n", " \"\"\"Transform data from lat/lon space to two axis patching\n", "\n", " Args:\n", " x: Tensor in lat/lon space (N, C, Nlat//P_0, Nlon//P_1)\n", "\n", " Returns:\n", " Tensor in patch space (N, G, L, C)\n", " \"\"\"\n", " n_batch = x.shape[0]\n", "\n", " x = x.view(\n", " n_batch,\n", " -1,\n", " self.global_shape_mu[0],\n", " self.local_shape_mu[0],\n", " self.global_shape_mu[1],\n", " self.local_shape_mu[1],\n", " )\n", " x = x.permute(0, 2, 4, 3, 5, 1).contiguous()\n", "\n", " s = x.shape\n", " return x.view(n_batch, s[1] * s[2], s[3] * s[4], -1)\n", "\n", " def from_patching(self, x: Tensor) -> Tensor:\n", " \"\"\"Transform data from two axis patching to lat/lon space\n", "\n", " Args:\n", " x: Tensor in patch space with shape (N, G, L, C*P_0*P_1)\n", "\n", " Returns:\n", " Tensor: Tensor in lat/lon space\n", " (N, C*P_0*P_1, Nlat//P_0, Nlon // P_1)\n", " \"\"\"\n", " n_batch = x.shape[0]\n", "\n", " x = x.view(\n", " n_batch,\n", " self.global_shape_mu[0],\n", " self.global_shape_mu[1],\n", " self.local_shape_mu[0],\n", " self.local_shape_mu[1],\n", " -1,\n", " )\n", " x = x.permute(0, 5, 1, 3, 2, 4).contiguous()\n", "\n", " s = x.shape\n", " return x.view(n_batch, -1, s[2] * s[3], s[4] * s[5])\n", "\n", " def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:\n", " \"\"\"\n", " Args:\n", " batch: Dictionary with the following keys::\n", "\n", " 'x': Tensor of shape [batch, time, parameter, lat, lon]\n", " 'y': Tensor of shape [batch, parameter, lat, lon]\n", " 'static': Tensor of shape [batch, channel_static, lat, lon]\n", " 'climate': Optional tensor of shape [batch, parameter, lat, lon]\n", " 'input_time': Tensor of 
shape [batch]. Or none.\n", " 'lead_time': Tensor of shape [batch]. Or none.\n", "\n", " Returns:\n", " Tensor: Tensor of shape [batch, parameter, lat, lon].\n", " \"\"\" # noqa: E501\n", " x_rescaled = (batch[\"x\"] - self.input_scalers_mu) / (\n", " self.input_scalers_sigma + self.input_scalers_epsilon\n", " )\n", " batch_size = x_rescaled.shape[0]\n", "\n", " if self.positional_encoding == \"fourier\":\n", " x_static_pos = self.fourier_pos_encoding(batch[\"static\"])\n", " x_static = (\n", " batch[\"static\"][:, 2:] - self.static_input_scalers_mu[:, 3:]\n", " ) / (\n", " self.static_input_scalers_sigma[:, 3:]\n", " + self.static_input_scalers_epsilon\n", " )\n", " else:\n", " x_static = (batch[\"static\"] - self.static_input_scalers_mu) / (\n", " self.static_input_scalers_sigma\n", " + self.static_input_scalers_epsilon\n", " )\n", "\n", " if self.residual == \"temporal\":\n", " # We create a residual of same shape as y\n", " index = torch.where(\n", " batch[\"lead_time\"] > 0, batch[\"x\"].shape[1] - 1, 0\n", " )\n", " index = index.view(-1, 1, 1, 1, 1)\n", " index = index.expand(batch_size, 1, *batch[\"x\"].shape[2:])\n", " x_hat = torch.gather(batch[\"x\"], dim=1, index=index)\n", " x_hat = x_hat.squeeze(1)\n", " elif self.residual == \"climate\":\n", " climate_scaled = (\n", " batch[\"climate\"] - self.input_scalers_mu.view(1, -1, 1, 1)\n", " ) / (\n", " self.input_scalers_sigma.view(1, -1, 1, 1)\n", " + self.input_scalers_epsilon\n", " )\n", "\n", " # [batch, time, parameter, lat, lon]\n", " # -> [batch, time x parameter, lat, lon]\n", " x_rescaled = x_rescaled.flatten(1, 2)\n", " # Parameter dropout\n", " x_rescaled = self.parameter_dropout(x_rescaled)\n", "\n", " x_embedded = self.patch_embedding(x_rescaled)\n", "\n", " if self.residual == \"climate\":\n", " static_embedded = self.patch_embedding_static(\n", " torch.cat((x_static, climate_scaled), dim=1)\n", " )\n", " else:\n", " static_embedded = self.patch_embedding_static(x_static)\n", "\n", " if 
self.positional_encoding == \"fourier\":\n", " static_embedded += x_static_pos\n", "\n", " x_embedded = self.to_patching(x_embedded)\n", " static_embedded = self.to_patching(static_embedded)\n", "\n", " time_encoding = self.time_encoding(\n", " batch[\"input_time\"], batch[\"lead_time\"]\n", " )\n", "\n", " tokens = x_embedded + static_embedded + time_encoding\n", "\n", " # Now we generate masks based on masking_mode\n", " indices_masked, indices_unmasked = self.generate_mask(\n", " (batch_size, self._nglobal_mu)\n", " )\n", " indices_masked = indices_masked.to(device=tokens.device)\n", " indices_unmasked = indices_unmasked.to(device=tokens.device)\n", " maskdim: int = indices_masked.ndim\n", "\n", " # Unmasking\n", " unmask_view = (*indices_unmasked.shape, *[1] * (tokens.ndim - maskdim))\n", " unmasked = torch.gather(\n", " tokens,\n", " dim=maskdim - 1,\n", " index=indices_unmasked.view(*unmask_view).expand(\n", " *indices_unmasked.shape, *tokens.shape[maskdim:]\n", " ),\n", " )\n", "\n", " # Encoder\n", " x_encoded = self.encoder(unmasked)\n", "\n", " # Generate and position encode the mask tokens\n", " # [1, 1, 1, embed_dim]\n", " # -> [batch, global_seq_masked, local seq, embed_dim]\n", " mask_view = (*indices_masked.shape, *[1] * (tokens.ndim - maskdim))\n", " masking = self.mask_token.repeat(*static_embedded.shape[:3], 1)\n", " masked = masking + static_embedded\n", " masked = torch.gather(\n", " masked,\n", " dim=maskdim - 1,\n", " index=indices_masked.view(*mask_view).expand(\n", " *indices_masked.shape, *tokens.shape[maskdim:]\n", " ),\n", " )\n", "\n", " recon, _ = self.reconstruct_batch(\n", " indices_masked, indices_unmasked, masked, x_encoded\n", " )\n", "\n", " x_decoded = self.decoder(recon)\n", "\n", " # Output: [batch, global sequence, local sequence,\n", " # in_channels * patch_size[0] * patch_size[1]]\n", " x_unembed = self.unembed(x_decoded)\n", "\n", " # Reshape to [batch, global_lat, global_lon, local_lat, local_lon,\n", " # in_channels * 
patch_size[0] * patch_size[1]]\n", " x_out = self.from_patching(x_unembed)\n", "\n", " # Pixel shuffle to [batch, in_channels, lat, lon]\n", " x_out = F.pixel_shuffle(x_out, self.patch_size_px[0])\n", "\n", " if self.residual == \"temporal\":\n", " x_out = self.output_scalers * x_out + x_hat\n", " elif self.residual == \"climate\":\n", " x_out = self.output_scalers * x_out + batch[\"climate\"]\n", " elif self.residual == \"none\":\n", " x_out = (\n", " self.output_scalers * x_out\n", " + self.input_scalers_mu.reshape(1, -1, 1, 1)\n", " )\n", "\n", " return x_out\n" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "import yaml\n", "\n", "# from PrithviWxC.model import PrithviWxC\n", "\n", "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", " filename=\"config.yaml\",\n", " local_dir=\".\",\n", ")\n", "\n", "with open(\"./config.yaml\", \"r\") as f:\n", " config = yaml.safe_load(f)\n", "\n", "model = PrithviWxC(\n", " in_channels=config[\"params\"][\"in_channels\"],\n", " input_size_time=config[\"params\"][\"input_size_time\"],\n", " in_channels_static=config[\"params\"][\"in_channels_static\"],\n", " input_scalers_mu=in_mu,\n", " input_scalers_sigma=in_sig,\n", " input_scalers_epsilon=config[\"params\"][\"input_scalers_epsilon\"],\n", " static_input_scalers_mu=static_mu,\n", " static_input_scalers_sigma=static_sig,\n", " static_input_scalers_epsilon=config[\"params\"][\n", " \"static_input_scalers_epsilon\"\n", " ],\n", " output_scalers=output_sig**0.5,\n", " n_lats_px=config[\"params\"][\"n_lats_px\"],\n", " n_lons_px=config[\"params\"][\"n_lons_px\"],\n", " patch_size_px=config[\"params\"][\"patch_size_px\"],\n", " mask_unit_size_px=config[\"params\"][\"mask_unit_size_px\"],\n", " mask_ratio_inputs=masking_ratio,\n", " embed_dim=config[\"params\"][\"embed_dim\"],\n", " n_blocks_encoder=config[\"params\"][\"n_blocks_encoder\"],\n", " n_blocks_decoder=config[\"params\"][\"n_blocks_decoder\"],\n", " 
mlp_multiplier=config[\"params\"][\"mlp_multiplier\"],\n", " n_heads=config[\"params\"][\"n_heads\"],\n", " dropout=config[\"params\"][\"dropout\"],\n", " drop_path=config[\"params\"][\"drop_path\"],\n", " parameter_dropout=config[\"params\"][\"parameter_dropout\"],\n", " residual=residual,\n", " masking_mode=masking_mode,\n", " decoder_shifting=decoder_shifting,\n", " positional_encoding=positional_encoding,\n", " checkpoint_encoder=[],\n", " checkpoint_decoder=[],\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load weights\n", "We have provided unsharded pretrained weights for the model,\n", "which can now be loaded. The model can then be transferred\n", "to the requested device." ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "weights_path = Path(\"./weights/prithvi.wxc.2300m.v1.pt\")\n", "hf_hub_download(\n", " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", " filename=weights_path.name,\n", " local_dir=\"./weights\",\n", ")\n", "\n", "state_dict = torch.load(weights_path, weights_only=False)\n", "if \"model_state\" in state_dict:\n", " state_dict = state_dict[\"model_state\"]\n", "model.load_state_dict(state_dict, strict=True)\n", "\n", "if (hasattr(model, \"device\") and model.device != device) or not hasattr(\n", " model, \"device\"\n", "):\n", " model = model.to(device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inference\n", "We are now ready to perform inference on the model. The data\n", "returned from the dataset class requires additional\n", "preprocessing; therefore, after polling the dataset, we process\n", "the data using the `preproc` function. This processed data is\n", "then transferred to the device. To recover the masking, we can\n", "save the torch RNG state and use it later. Finally, we run the\n", "model in evaluation mode without generating the gradient graph."
] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [], "source": [ "# from PrithviWxC.dataloaders.merra2 import preproc\n", "\n", "data = next(iter(dataset))\n", "batch = preproc([data], padding)\n", "\n", "for k, v in batch.items():\n", " if isinstance(v, torch.Tensor):\n", " batch[k] = v.to(device)\n", "\n", "rng_state_1 = torch.get_rng_state()\n", "with torch.no_grad():\n", " model.eval()\n", " out = model(batch)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plotting" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "t2m = out[0, 12].cpu().numpy()\n", "\n", "lat = np.linspace(-90, 90, out.shape[-2])\n", "lon = np.linspace(-180, 180, out.shape[-1])\n", "X, Y = np.meshgrid(lon, lat)\n", "\n", "plt.contourf(X, Y, t2m, 100)\n", "plt.gca().set_aspect(\"equal\")\n", "plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.7" } }, "nbformat": 4, "nbformat_minor": 4 }