diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..9874ce4f89360f1a97e8f47808511b8f1d892cea 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,21 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_surface_doy001_hour00.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_surface_doy001_hour03.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_surface_doy001_hour06.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_surface_doy001_hour09.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_surface_doy001_hour12.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_surface_doy001_hour15.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_surface_doy001_hour18.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_surface_doy001_hour21.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_vertical_doy001_hour00.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_vertical_doy001_hour03.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_vertical_doy001_hour06.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_vertical_doy001_hour09.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_vertical_doy001_hour12.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_vertical_doy001_hour15.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_vertical_doy001_hour18.nc filter=lfs diff=lfs merge=lfs -text +examples/climatology/climate_vertical_doy001_hour21.nc filter=lfs diff=lfs merge=lfs -text +examples/merra-2/MERRA_pres_20200101.nc filter=lfs diff=lfs merge=lfs -text +examples/merra-2/MERRA2_sfc_20200101.nc filter=lfs diff=lfs merge=lfs 
-text diff --git a/examples/.cache/huggingface/.gitignore b/examples/.cache/huggingface/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f59ec20aabf5842d237244ece8c81ab184faeac1 --- /dev/null +++ b/examples/.cache/huggingface/.gitignore @@ -0,0 +1 @@ +* \ No newline at end of file diff --git a/examples/.cache/huggingface/download/climatology/anomaly_variance_surface.nc.metadata b/examples/.cache/huggingface/download/climatology/anomaly_variance_surface.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..51c64a690244bc524e3d99153b7c31bc6f4ec51e --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/anomaly_variance_surface.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +6c21b953f7f60f3bb80a1007e49227c5c018a26e3eb7d1a080a8d7bcf3ab7dfc +1731287317.7842605 diff --git a/examples/.cache/huggingface/download/climatology/anomaly_variance_vertical.nc.metadata b/examples/.cache/huggingface/download/climatology/anomaly_variance_vertical.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..a396fa2331c626fb1d8e2256b0b0c9bf1d5e3c08 --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/anomaly_variance_vertical.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +4317f4c8c89b1e604bca44841bf085603c91b8c70f92b37bee536f99c83222bb +1731287317.8256514 diff --git a/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour00.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour00.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..bf4bfea1a0806bc29fb5c6d167541f14bb224a0f --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour00.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +6df405a6178c3222c4abd98eea6573ced38aa2ccbe0966647a68ac18103a4d1d +1731111987.3265555 diff --git 
a/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour03.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour03.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..b266e296d327a228cf42ceb3374099e67d306597 --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour03.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +b67bcff5c6df5306677a65de3f7c08485a8ba1c3ecdc60b7f346cf0642caa7f1 +1731111985.428959 diff --git a/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour06.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour06.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..fd072a0e3d6e2496aaff128d94ae6b04d0e06a47 --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour06.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +080cb1f7e8dad508fc78e40a5bbd30d03564b7a48bdd2916f296dc3dfed30c60 +1731111986.869001 diff --git a/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour09.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour09.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..18a3cd3525ac468f20973a819c050b5656ec7a04 --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour09.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +134e194a8f38828ec067f98e8c5c7dc4aed0131046b6ed839f75bcf6b98b5492 +1731111985.1479104 diff --git a/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour12.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour12.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..b7dc1361be1b2ab56a7b7dc131a9c52b6da67059 --- /dev/null +++ 
b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour12.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +974a79a5e68096c356b42f15cd9c955f0a1306cf6a942c682e09d6504742a6d1 +1731111980.916947 diff --git a/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour15.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour15.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..fea603150a8f1cb49944b13043775d38e302e02f --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour15.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +dea669f097cf1b9d775956a01e2bae52ae1856dfcc7abae5c6ee394949275c5e +1731111986.7854862 diff --git a/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour18.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour18.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..c0a42ba8ead2027a59a65fa47a130b3186dcca47 --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour18.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +3b06aabcd1f14cc3a4b3c5dd355a225646fe4b9e3b5b676ee2fb6e5d05fe31c9 +1731111986.2445679 diff --git a/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour21.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour21.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..f1a1e036292d01c198f7d81d7a584f7ed8ce0260 --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_surface_doy001_hour21.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +ef90b12b58b73481f96c0caf5277d24cc05681003df0f6e83f5e85dc0b4d47b3 +1731111981.8565679 diff --git 
a/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour00.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour00.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..0e2aa5c1d1cad5cc2f6870174bf2f1e110cc5f07 --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour00.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +b5f75cc1a55b6f1beecf2271d3a8fe9914cb20fb87e471761b3b686a990ec0ef +1731112085.7071128 diff --git a/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour03.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour03.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..169dc54cc22423b00ef7d812ecd9b998bface555 --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour03.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +80a04c77e614e01ae26fdad1fc3c5a15fd6869d627ce4251447d64c3bb934916 +1731112080.982848 diff --git a/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour06.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour06.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..14f983905e36ff587b53572aa22112a0f0665506 --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour06.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +0f6fc10bbd055e1e7af63f70cfa24abd60794cd89822a95b9fb9194b07c7822f +1731112071.24669 diff --git a/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour09.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour09.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..54fd477dae6f4ef447075b99014129b87e59fa51 --- 
/dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour09.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +70cb7c3e30902548b24330b61ecf493fdf1949bb434f45e9258434a80d91ee6d +1731112053.3633041 diff --git a/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour12.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour12.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..8d4df47781acbd5b97343a9277a4f597afd7e16f --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour12.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +a597136ae5b374e55b6f1113e218faeb3090a3d474f95bd6c043b4923238a32c +1731112082.503257 diff --git a/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour15.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour15.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..1fbca83a54a5dda08664a5e58d53eab0fdb6332a --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour15.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +6e1765b3fcf01aed2338b9df110a76a9e439703bb87c366b0d3868473ccceb46 +1731112055.9265618 diff --git a/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour18.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour18.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..3db49ef8c035200c54ecbed000d73b0d65db486f --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour18.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +66b27837717a6cd37d72d55c8de3cc24b2122e25173348be37b3a76d98c8d580 +1731112076.9637496 diff --git 
a/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour21.nc.metadata b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour21.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..a3a2b4f49a9a7b6f7cfe65013c473d6386935792 --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/climate_vertical_doy001_hour21.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +65084e38b67342d03135e98092d893e6b8ad677a1b690ea39b3fbc4b91df6fdc +1731112084.0759785 diff --git a/examples/.cache/huggingface/download/climatology/musigma_surface.nc.metadata b/examples/.cache/huggingface/download/climatology/musigma_surface.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..e392145e502badb4253b803d953f08c34c536007 --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/musigma_surface.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +62f7145aa62d2d632ef9a32b3492a932b128006cffe61674a53ecb1f163ce0b6 +1731287317.6933951 diff --git a/examples/.cache/huggingface/download/climatology/musigma_vertical.nc.metadata b/examples/.cache/huggingface/download/climatology/musigma_vertical.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..2a77b7d68ae46a71320be590c55aa54668bcb55f --- /dev/null +++ b/examples/.cache/huggingface/download/climatology/musigma_vertical.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +b83b40a81a18a1e9dc1633f71f7186fbeda9b17a3c1ecc536afd1575e635a1a3 +1731287317.7353563 diff --git a/examples/.cache/huggingface/download/config.yaml.metadata b/examples/.cache/huggingface/download/config.yaml.metadata new file mode 100644 index 0000000000000000000000000000000000000000..72f725600e68a9a05a47c289343ba9b9125b322b --- /dev/null +++ b/examples/.cache/huggingface/download/config.yaml.metadata @@ -0,0 +1,3 @@ +514c3d061ad45e3338495da7c16b13aa20fa75b1 
+4435167a11fd412e2dfb565e135eed07a33c7663 +1731287317.8794053 diff --git a/examples/.cache/huggingface/download/merra-2/MERRA2_sfc_20200101.nc.metadata b/examples/.cache/huggingface/download/merra-2/MERRA2_sfc_20200101.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..4ab1d4433a1dee0a0d71b521636ce657b5c21fda --- /dev/null +++ b/examples/.cache/huggingface/download/merra-2/MERRA2_sfc_20200101.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +1de1638ca1f1b44ca95a7c6908573414c44e633f1bb2e212b8a46afc866c741e +1731111936.8741355 diff --git a/examples/.cache/huggingface/download/merra-2/MERRA_pres_20200101.nc.metadata b/examples/.cache/huggingface/download/merra-2/MERRA_pres_20200101.nc.metadata new file mode 100644 index 0000000000000000000000000000000000000000..441c56d0c4592ffbc82d90987b17818f08b35a96 --- /dev/null +++ b/examples/.cache/huggingface/download/merra-2/MERRA_pres_20200101.nc.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +2cd9b405aa1d388fc0c6cbe6d71104bea770b48e07ba4364f067dc425d025af8 +1731111969.7144666 diff --git a/examples/PrithviWxC_inference.ipynb b/examples/PrithviWxC_inference.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..bc3f80fdd52ee77574defbeabae0a8cc5d2da8ed --- /dev/null +++ b/examples/PrithviWxC_inference.ipynb @@ -0,0 +1,3186 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PrithviWxC\n", + "\n", + "This notebook will walk you through how to construct the model,\n", + "load the weights, build the dataset, and use the model for inference." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "from huggingface_hub import hf_hub_download, snapshot_download" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We now configure the backends and torch states, including setting the seeds for the RNGs." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "torch.jit.enable_onednn_fusion(True)\n", + "if torch.cuda.is_available():\n", + " print(f\"Using device: {torch.cuda.get_device_name()}\")\n", + " torch.backends.cudnn.benchmark = True\n", + " torch.backends.cudnn.deterministic = True\n", + "\n", + "random.seed(42)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed(42)\n", + "torch.manual_seed(42)\n", + "np.random.seed(42)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The model has approximately 2.3 billion parameters, so it\n", + "requires reasonable computational resources, but it is possible\n", + "to run it on a CPU." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "else:\n", + " device = torch.device(\"cpu\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Dataloader\n", + "### Variables and times\n", + "\n", + "With the environment ready to go, we now need to set up the task.\n", + "The core model expects a fixed set of variables from the MERRA-2\n", + "dataset, which are prescribed below. The variables are comprised\n", + "of surface variables, surface static variables, and variables at\n", + "various vertical levels within the atmosphere. 
More details on the\n", + "MERRA-2 dataset can be found\n", + "[here](https://gmao.gsfc.nasa.gov/reanalysis/MERRA-2/).\n", + "\n", + "The MERRA-2 dataset includes data at longitudes of $-180^\\circ$\n", + "and $+180^\\circ$. This represents duplicate data, so we set a\n", + "padding variable to remove it.\n", + "\n", + "The input to the core model consists of these variables at two\n", + "different times. The time difference in hours between these samples\n", + "is passed to the model and set in the input_time variable.\n", + "\n", + "The model's task is to predict the fixed set of variables at a\n", + "target time, given the input data.\n", + "\n", + "For example, if the input times are 0900 and 1200, resulting in\n", + "an input_time of -3, then a lead_time of 6 would result in a\n", + "target time of 1800." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "surface_vars = [\n", + " \"EFLUX\",\n", + " \"GWETROOT\",\n", + " \"HFLUX\",\n", + " \"LAI\",\n", + " \"LWGAB\",\n", + " \"LWGEM\",\n", + " \"LWTUP\",\n", + " \"PS\",\n", + " \"QV2M\",\n", + " \"SLP\",\n", + " \"SWGNT\",\n", + " \"SWTNT\",\n", + " \"T2M\",\n", + " \"TQI\",\n", + " \"TQL\",\n", + " \"TQV\",\n", + " \"TS\",\n", + " \"U10M\",\n", + " \"V10M\",\n", + " \"Z0M\",\n", + "]\n", + "static_surface_vars = [\"FRACI\", \"FRLAND\", \"FROCEAN\", \"PHIS\"]\n", + "vertical_vars = [\"CLOUD\", \"H\", \"OMEGA\", \"PL\", \"QI\", \"QL\", \"QV\", \"T\", \"U\", \"V\"]\n", + "levels = [\n", + " 34.0,\n", + " 39.0,\n", + " 41.0,\n", + " 43.0,\n", + " 44.0,\n", + " 45.0,\n", + " 48.0,\n", + " 51.0,\n", + " 53.0,\n", + " 56.0,\n", + " 63.0,\n", + " 68.0,\n", + " 71.0,\n", + " 72.0,\n", + "]\n", + "padding = {\"level\": [0, 0], \"lat\": [0, -1], \"lon\": [0, 0]}\n", + "\n", + "lead_times = [12] # This variable can be changed to change the task\n", + "input_times = [-6] # This variable can be changed to change the task" + ] + }, + { + "cell_type": "markdown", +
"metadata": {}, + "source": [ + "### Data file\n", + "MERRA-2 data is available from 1980 to the present day,\n", + "at 3-hour temporal resolution. The dataloader we have provided\n", + "expects the surface data and vertical data to be saved in\n", + "separate files, and when provided with the directories, will\n", + "search for the relevant data that falls within the provided time range.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "45d1a1486bdc4dff82597d5cf87095f0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 1 files: 0%| | 0/1 [00:00 dict[str, Tensor]:\n", + " \"\"\"Preprocessing function for MERRA2 Dataset\n", + "\n", + " Args:\n", + " batch (dict): List of training samples, each sample should be a\n", + " dictionary with the following keys::\n", + "\n", + " 'sur_static': Numpy array of shape (3, lat, lon). For each pixel (lat, lon), the first dimension indexes sin(lat), cos(lon), sin(lon).\n", + " 'sur_vals': Torch tensor of shape (parameter, time, lat, lon).\n", + " 'sur_tars': Torch tensor of shape (parameter, time, lat, lon).\n", + " 'ulv_vals': Torch tensor of shape (parameter, level, time, lat, lon).\n", + " 'ulv_tars': Torch tensor of shape (parameter, level, time, lat, lon).\n", + " 'sur_climate': Torch tensor of shape (parameter, lat, lon)\n", + " 'ulv_climate': Torch tensor of shape (parameter, level, lat, lon)\n", + " 'lead_time': Integer.\n", + " 'input_time': Integer.\n", + "\n", + " padding: Dictionary with keys 'level', 'lat', 'lon', each of dim 2.\n", + "\n", + " Returns:\n", + " Dictionary with the following keys::\n", + "\n", + " 'x': [batch, time, parameter, lat, lon]\n", + " 'y': [batch, parameter, lat, lon]\n", + " 'static': [batch, parameter, lat, lon]\n", + " 'lead_time': [batch]\n", + " 'input_time': [batch]\n", + " 'climate (Optional)': [batch, parameter, lat, lon]\n", + "\n",
+ " Note:\n", + " Here, for x and y, 'parameter' is [surface parameter, upper level,\n", + " parameter x level]. Similarly for the static information we have\n", + " [sin(lat), cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod),\n", + " ...].\n", + " \"\"\" # noqa: E501\n", + " b0 = batch[0]\n", + " nbatch = len(batch)\n", + " data_keys = set(b0.keys())\n", + "\n", + " essential_keys = {\n", + " \"sur_static\",\n", + " \"sur_vals\",\n", + " \"sur_tars\",\n", + " \"ulv_vals\",\n", + " \"ulv_tars\",\n", + " \"input_time\",\n", + " \"lead_time\",\n", + " }\n", + "\n", + " climate_keys = {\n", + " \"sur_climate\",\n", + " \"ulv_climate\",\n", + " }\n", + "\n", + " all_keys = essential_keys | climate_keys\n", + "\n", + " if not essential_keys.issubset(data_keys):\n", + " raise ValueError(\"Missing essential keys.\")\n", + "\n", + " if not data_keys.issubset(all_keys):\n", + " raise ValueError(\"Unexpected keys in batch.\")\n", + "\n", + " # Bring all tensors from the batch into a single tensor\n", + " upl_x = torch.empty((nbatch, *b0[\"ulv_vals\"].shape))\n", + " upl_y = torch.empty((nbatch, *b0[\"ulv_tars\"].shape))\n", + "\n", + " sur_x = torch.empty((nbatch, *b0[\"sur_vals\"].shape))\n", + " sur_y = torch.empty((nbatch, *b0[\"sur_tars\"].shape))\n", + "\n", + " sur_sta = torch.empty((nbatch, *b0[\"sur_static\"].shape))\n", + "\n", + " lead_time = torch.empty((nbatch,), dtype=torch.float32)\n", + " input_time = torch.empty((nbatch,), dtype=torch.float32)\n", + "\n", + " for i, rec in enumerate(batch):\n", + " sur_x[i] = rec[\"sur_vals\"]\n", + " sur_y[i] = rec[\"sur_tars\"]\n", + "\n", + " upl_x[i] = rec[\"ulv_vals\"]\n", + " upl_y[i] = rec[\"ulv_tars\"]\n", + "\n", + " sur_sta[i] = rec[\"sur_static\"]\n", + "\n", + " lead_time[i] = rec[\"lead_time\"]\n", + " input_time[i] = rec[\"input_time\"]\n", + "\n", + " return_value = {\n", + " \"lead_time\": lead_time,\n", + " \"input_time\": input_time,\n", + " }\n", + "\n", + " # Reshape (batch, parameter, level, 
time, lat, lon) ->\n", + " # (batch, time, parameter, level, lat, lon)\n", + " upl_x = upl_x.permute((0, 3, 1, 2, 4, 5))\n", + " upl_y = upl_y.permute((0, 3, 1, 2, 4, 5))\n", + " # Reshape (batch, parameter, time, lat, lon) ->\n", + " # (batch, time, parameter, lat, lon)\n", + " sur_x = sur_x.permute((0, 2, 1, 3, 4))\n", + " sur_y = sur_y.permute((0, 2, 1, 3, 4))\n", + "\n", + " # Pad\n", + " padding_2d = (*padding[\"lon\"], *padding[\"lat\"])\n", + "\n", + " def pad2d(x):\n", + " return torch.nn.functional.pad(x, padding_2d, mode=\"constant\", value=0)\n", + "\n", + " padding_3d = (*padding[\"lon\"], *padding[\"lat\"], *padding[\"level\"])\n", + "\n", + " def pad3d(x):\n", + " return torch.nn.functional.pad(x, padding_3d, mode=\"constant\", value=0)\n", + "\n", + " sur_x = pad2d(sur_x).contiguous()\n", + " upl_x = pad3d(upl_x).contiguous()\n", + " sur_y = pad2d(sur_y).contiguous()\n", + " upl_y = pad3d(upl_y).contiguous()\n", + " return_value[\"static\"] = pad2d(sur_sta).contiguous()\n", + "\n", + " # Remove time for targets\n", + " upl_y = torch.squeeze(upl_y, 1)\n", + " sur_y = torch.squeeze(sur_y, 1)\n", + "\n", + " # We stack along the combined parameter x level dimension\n", + " return_value[\"x\"] = torch.cat(\n", + " (sur_x, upl_x.view(*upl_x.shape[:2], -1, *upl_x.shape[4:])), dim=2\n", + " )\n", + " return_value[\"y\"] = torch.cat(\n", + " (sur_y, upl_y.view(upl_y.shape[0], -1, *upl_y.shape[3:])), dim=1\n", + " )\n", + "\n", + " if climate_keys.issubset(data_keys):\n", + " sur_climate = torch.empty((nbatch, *b0[\"sur_climate\"].shape))\n", + " ulv_climate = torch.empty((nbatch, *b0[\"ulv_climate\"].shape))\n", + " for i, rec in enumerate(batch):\n", + " sur_climate[i] = rec[\"sur_climate\"]\n", + " ulv_climate[i] = rec[\"ulv_climate\"]\n", + " sur_climate = pad2d(sur_climate)\n", + " ulv_climate = pad3d(ulv_climate)\n", + "\n", + " return_value[\"climate\"] = torch.cat(\n", + " (\n", + " sur_climate,\n", + " ulv_climate.view(nbatch, -1, 
*ulv_climate.shape[3:]),\n", + " ),\n", + " dim=1,\n", + " )\n", + "\n", + " return return_value\n", + "\n", + "\n", + "def input_scalers(\n", + " surf_vars: list[str],\n", + " vert_vars: list[str],\n", + " levels: list[float],\n", + " surf_path: str | Path,\n", + " vert_path: str | Path,\n", + ") -> tuple[Tensor, Tensor]:\n", + " \"\"\"Reads the input scalers\n", + "\n", + " Args:\n", + " surf_vars: surface variables to be used.\n", + " vert_vars: vertical variables to be used.\n", + " levels: MERRA2 levels to use.\n", + " surf_path: path to surface scalers file.\n", + " vert_path: path to vertical level scalers file.\n", + "\n", + " Returns:\n", + " mu (Tensor): mean values\n", + " sig (Tensor): sigma values, clamped to [1e-4, 1e4]\n", + " \"\"\"\n", + " with h5py.File(Path(surf_path), \"r\", libver=\"latest\") as surf_file:\n", + " stats = [x.decode().lower() for x in surf_file[\"statistic\"][()]]\n", + " mu_idx = stats.index(\"mu\")\n", + " sig_idx = stats.index(\"sigma\")\n", + "\n", + " s_mu = torch.tensor([surf_file[k][()][mu_idx] for k in surf_vars])\n", + " s_sig = torch.tensor([surf_file[k][()][sig_idx] for k in surf_vars])\n", + "\n", + " with h5py.File(Path(vert_path), \"r\", libver=\"latest\") as vert_file:\n", + " stats = [x.decode().lower() for x in vert_file[\"statistic\"][()]]\n", + " mu_idx = stats.index(\"mu\")\n", + " sig_idx = stats.index(\"sigma\")\n", + "\n", + " lvl = vert_file[\"lev\"][()]\n", + " l_idx = [np.where(lvl == v)[0].item() for v in levels]\n", + "\n", + " v_mu = np.array([vert_file[k][()][mu_idx, l_idx] for k in vert_vars])\n", + " v_sig = np.array([vert_file[k][()][sig_idx, l_idx] for k in vert_vars])\n", + "\n", + " v_mu = torch.from_numpy(v_mu).view(-1)\n", + " v_sig = torch.from_numpy(v_sig).view(-1)\n", + "\n", + " mu = torch.cat((s_mu, v_mu), dim=0).to(torch.float32)\n", + " sig = torch.cat((s_sig, v_sig), dim=0).to(torch.float32).clamp(1e-4, 1e4)\n", + " return mu, sig\n", + "\n", + "\n", + "def static_input_scalers(\n", + " scalar_path: str
| Path, stat_vars: list[str], unscaled_params: int = 7\n", + ") -> tuple[Tensor, Tensor]:\n", + " scalar_path = Path(scalar_path)\n", + "\n", + " with h5py.File(scalar_path, \"r\", libver=\"latest\") as scaler_file:\n", + " stats = [x.decode().lower() for x in scaler_file[\"statistic\"][()]]\n", + " mu_idx = stats.index(\"mu\")\n", + " sig_idx = stats.index(\"sigma\")\n", + "\n", + " mu = torch.tensor([scaler_file[k][()][mu_idx] for k in stat_vars])\n", + " sig = torch.tensor([scaler_file[k][()][sig_idx] for k in stat_vars])\n", + "\n", + " z = torch.zeros(unscaled_params, dtype=mu.dtype, device=mu.device)\n", + " o = torch.ones(unscaled_params, dtype=sig.dtype, device=sig.device)\n", + " mu = torch.cat((z, mu), dim=0).to(torch.float32)\n", + " sig = torch.cat((o, sig), dim=0).to(torch.float32)\n", + "\n", + " return mu, sig.clamp(1e-4, 1e4)\n", + "\n", + "\n", + "def output_scalers(\n", + " surf_vars: list[str],\n", + " vert_vars: list[str],\n", + " levels: list[float],\n", + " surf_path: str | Path,\n", + " vert_path: str | Path,\n", + ") -> Tensor:\n", + " surf_path = Path(surf_path)\n", + " vert_path = Path(vert_path)\n", + "\n", + " with h5py.File(surf_path, \"r\", libver=\"latest\") as surf_file:\n", + " svars = torch.tensor([surf_file[k][()] for k in surf_vars])\n", + "\n", + " with h5py.File(vert_path, \"r\", libver=\"latest\") as vert_file:\n", + " lvl = vert_file[\"lev\"][()]\n", + " l_idx = [np.where(lvl == v)[0].item() for v in levels]\n", + " vvars = np.array([vert_file[k][()][l_idx] for k in vert_vars])\n", + " vvars = torch.from_numpy(vvars).view(-1)\n", + "\n", + " var = torch.cat((svars, vvars), dim=0).to(torch.float32).clamp(1e-7, 1e7)\n", + "\n", + " return var\n", + "\n", + "\n", + "class SampleSpec:\n", + " \"\"\"\n", + " A data class to collect the information used to define a sample.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " inputs: tuple[pd.Timestamp, pd.Timestamp],\n", + " lead_time: int,\n", + " target: 
pd.Timestamp | list[pd.Timestamp],\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " inputs: Tuple of timestamps. In ascending order.\n", + " lead_time: Lead time. In hours.\n", + " target: Timestamp of the target. Can be before or after the inputs.\n", + " \"\"\"\n", + " if not inputs[0] < inputs[1]:\n", + " raise ValueError(\n", + " \"Timestamps in `inputs` should be in strictly ascending order.\"\n", + " )\n", + "\n", + " self.inputs = inputs\n", + " self.input_time = (inputs[1] - inputs[0]).total_seconds() / 3600\n", + " self.lead_time = lead_time\n", + " self.target = target\n", + "\n", + " self.times = [*inputs, target]\n", + " self.stat_times = [inputs[-1]]\n", + "\n", + " @property\n", + " def climatology_info(self) -> tuple[int, int]:\n", + " \"\"\"Get the required climatology info.\n", + "\n", + " :return: information required to obtain climatology data. Essentially\n", + " this is the day of the year and hour of the day of the target\n", + " timestamp, with the former restricted to the interval [1, 365].\n", + " :rtype: tuple\n", + " \"\"\"\n", + " return (min(self.target.dayofyear, 365), self.target.hour)\n", + "\n", + " @property\n", + " def year(self) -> int:\n", + " return self.inputs[1].year\n", + "\n", + " @property\n", + " def dayofyear(self) -> int:\n", + " return self.inputs[1].dayofyear\n", + "\n", + " @property\n", + " def hourofday(self) -> int:\n", + " return self.inputs[1].hour\n", + "\n", + " def _info_str(self) -> str:\n", + " iso_8601 = \"%Y-%m-%dT%H:%M:%S\"\n", + "\n", + " return (\n", + " f\"Issue time: {self.inputs[1].strftime(iso_8601)}\\n\"\n", + " f\"Lead time: {self.lead_time} hours ahead\\n\"\n", + " f\"Input delta: {self.input_time} hours\\n\"\n", + " f\"Target time: {self.target.strftime(iso_8601)}\"\n", + " )\n", + "\n", + " @classmethod\n", + " def get(cls, timestamp: pd.Timestamp, dt: int, lead_time: int):\n", + " \"\"\"Given a timestamp and lead time, generates a SampleSpec object\n", + " describing the sample further.\n", + 
"\n", + " Args:\n", + " timestamp: Timestamp of the sample, i.e. this is the larger of the two\n", + " input timestamps.\n", + " dt: Time between input samples, in hours.\n", + " lead_time: Lead time. In hours.\n", + "\n", + " Returns:\n", + " SampleSpec\n", + " \"\"\" # noqa: E501\n", + " assert dt > 0, \"dt should be possitive\"\n", + " lt = pd.to_timedelta(lead_time, unit=\"h\")\n", + " dt = pd.to_timedelta(dt, unit=\"h\")\n", + "\n", + " if lead_time >= 0:\n", + " timestamp_target = timestamp + lt\n", + " else:\n", + " timestamp_target = timestamp - dt + lt\n", + "\n", + " spec = cls(\n", + " inputs=(timestamp - dt, timestamp),\n", + " lead_time=lead_time,\n", + " target=timestamp_target,\n", + " )\n", + "\n", + " return spec\n", + "\n", + " def __repr__(self) -> str:\n", + " return self._info_str()\n", + "\n", + " def __str__(self) -> str:\n", + " return self._info_str()\n", + "\n", + "\n", + "class Merra2Dataset(Dataset):\n", + " \"\"\"MERRA2 dataset. The dataset unifies surface and vertical data as well as\n", + " optional climatology.\n", + "\n", + " Samples come in the form of a dictionary. Not all keys support all\n", + " variables, yet the general ordering of dimensions is\n", + " parameter, level, time, lat, lon\n", + "\n", + " Note:\n", + " Data is assumed to be in NetCDF files containing daily data at 3-hourly\n", + " intervals. These follow the naming patterns\n", + " MERRA2_sfc_YYYYMMDD.nc and MERRA_pres_YYYYMMDD.nc and can be located in\n", + " two different locations. Optional climatology data comes from files\n", + " climate_surface_doyDOY_hourHOD.nc and\n", + " climate_vertical_doyDOY_hourHOD.nc.\n", + "\n", + "\n", + " Note:\n", + " `_get_valid_timestamps` assembles a set of all timestamps for which\n", + " there is data (with hourly resolutions). The result is stored in\n", + " `_valid_timestamps`.
`_get_valid_climate_timestamps` does the same with\n", + " climatology data and stores it in `_valid_climate_timestamps`.\n", + "\n", + " Based on this information, `samples` generates a list of valid samples,\n", + " stored in `samples`. Here the format is::\n", + "\n", + " [\n", + " [\n", + " (timestamp 1, lead time A),\n", + " (timestamp 1, lead time B),\n", + " (timestamp 1, lead time C),\n", + " ],\n", + " [\n", + " (timestamp 2, lead time D),\n", + " (timestamp 2, lead time E),\n", + " ]\n", + " ]\n", + "\n", + " That is, the outer list iterates over timestamps (init times), the\n", + " inner over lead times. Only valid entries are stored.\n", + " \"\"\"\n", + "\n", + " valid_vertical_vars = [\n", + " \"CLOUD\",\n", + " \"H\",\n", + " \"OMEGA\",\n", + " \"PL\",\n", + " \"QI\",\n", + " \"QL\",\n", + " \"QV\",\n", + " \"T\",\n", + " \"U\",\n", + " \"V\",\n", + " ]\n", + " valid_surface_vars = [\n", + " \"EFLUX\",\n", + " \"GWETROOT\",\n", + " \"HFLUX\",\n", + " \"LAI\",\n", + " \"LWGAB\",\n", + " \"LWGEM\",\n", + " \"LWTUP\",\n", + " \"PRECTOT\",\n", + " \"PS\",\n", + " \"QV2M\",\n", + " \"SLP\",\n", + " \"SWGNT\",\n", + " \"SWTNT\",\n", + " \"T2M\",\n", + " \"TQI\",\n", + " \"TQL\",\n", + " \"TQV\",\n", + " \"TS\",\n", + " \"U10M\",\n", + " \"V10M\",\n", + " \"Z0M\",\n", + " ]\n", + " valid_static_surface_vars = [\"FRACI\", \"FRLAND\", \"FROCEAN\", \"PHIS\"]\n", + "\n", + " valid_levels = [\n", + " 34.0,\n", + " 39.0,\n", + " 41.0,\n", + " 43.0,\n", + " 44.0,\n", + " 45.0,\n", + " 48.0,\n", + " 51.0,\n", + " 53.0,\n", + " 56.0,\n", + " 63.0,\n", + " 68.0,\n", + " 71.0,\n", + " 72.0,\n", + " ]\n", + "\n", + " timedelta_input = pd.to_timedelta(3, unit=\"h\")\n", + "\n", + " def __init__(\n", + " self,\n", + " time_range: tuple[str | pd.Timestamp, str | pd.Timestamp],\n", + " lead_times: list[int],\n", + " input_times: list[int],\n", + " data_path_surface: str | Path,\n", + " data_path_vertical: str | Path,\n", + " climatology_path_surface: str | Path | None = 
None,\n", + " climatology_path_vertical: str | Path | None = None,\n", + " surface_vars: list[str] | None = None,\n", + " static_surface_vars: list[str] | None = None,\n", + " vertical_vars: list[str] | None = None,\n", + " levels: list[float] | None = None,\n", + " roll_longitudes: int = 0,\n", + " positional_encoding: str = \"absolute\",\n", + " rtype: type = np.float32,\n", + " dtype: torch.dtype = torch.float32,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " data_path_surface: Location of surface data.\n", + " data_path_vertical: Location of vertical data.\n", + " climatology_path_surface: Location of (optional) surface\n", + " climatology.\n", + " climatology_path_vertical: Location of (optional) vertical\n", + " climatology.\n", + " surface_vars: Surface variables.\n", + " static_surface_vars: Static surface variables.\n", + " vertical_vars: Vertical variables.\n", + " levels: Levels.\n", + " time_range: Used to subset data.\n", + " lead_times: Lead times for generalized forecasting.\n", + " roll_longitudes: Set to non-zero value to data by random amount\n", + " along longitude dimension.\n", + " position_encoding: possible values are\n", + " ['absolute' (default), 'fourier'].\n", + " 'absolute' returns lat lon encoded in 3 dimensions using sine\n", + " and cosine\n", + " 'fourier' returns lat/lon to be encoded by model\n", + " returns lat/lon to be encoded by model\n", + " rtype: numpy data type used during read\n", + " dtype: torch data type of data output\n", + " \"\"\"\n", + "\n", + " self.time_range = (\n", + " pd.to_datetime(time_range[0]),\n", + " pd.to_datetime(time_range[1]),\n", + " )\n", + " self.lead_times = lead_times\n", + " self.input_times = input_times\n", + " self._roll_longitudes = list(range(roll_longitudes + 1))\n", + "\n", + " self._uvars = vertical_vars or self.valid_vertical_vars\n", + " self._level = levels or self.valid_levels\n", + " self._svars = surface_vars or self.valid_surface_vars\n", + " self._sstat = 
static_surface_vars or self.valid_static_surface_vars\n", + " self._nuvars = len(self._uvars)\n", + " self._nlevel = len(self._level)\n", + " self._nsvars = len(self._svars)\n", + " self._nsstat = len(self._sstat)\n", + "\n", + " self.rtype = rtype\n", + " self.dtype = dtype\n", + "\n", + " self.positional_encoding = positional_encoding\n", + "\n", + " self._data_path_surface = Path(data_path_surface)\n", + " self._data_path_vertical = Path(data_path_vertical)\n", + "\n", + " self.dir_exists(self._data_path_surface)\n", + " self.dir_exists(self._data_path_vertical)\n", + "\n", + " self._get_coordinates()\n", + "\n", + " self._climatology_path_surface = Path(climatology_path_surface) or None\n", + " self._climatology_path_vertical = (\n", + " Path(climatology_path_vertical) or None\n", + " )\n", + " self._require_clim = (\n", + " self._climatology_path_surface is not None\n", + " and self._climatology_path_vertical is not None\n", + " )\n", + "\n", + " if self._require_clim:\n", + " self.dir_exists(self._climatology_path_surface)\n", + " self.dir_exists(self._climatology_path_vertical)\n", + " elif (\n", + " climatology_path_surface is None\n", + " and climatology_path_vertical is None\n", + " ):\n", + " self._climatology_path_surface = None\n", + " self._climatology_path_vertical = None\n", + " else:\n", + " raise ValueError(\n", + " \"Either both or neither of\"\n", + " \"`climatology_path_surface` and\"\n", + " \"`climatology_path_vertical` should be None.\"\n", + " )\n", + "\n", + " if not set(self._svars).issubset(set(self.valid_surface_vars)):\n", + " raise ValueError(\"Invalid surface variable.\")\n", + "\n", + " if not set(self._sstat).issubset(set(self.valid_static_surface_vars)):\n", + " raise ValueError(\"Invalid static surface variable.\")\n", + "\n", + " if not set(self._uvars).issubset(set(self.valid_vertical_vars)):\n", + " raise ValueError(\"Inalid vertical variable.\")\n", + "\n", + " if not set(self._level).issubset(set(self.valid_levels)):\n", + " 
raise ValueError(\"Invalid level.\")\n", + "\n", + " @staticmethod\n", + " def dir_exists(path: Path) -> None:\n", + " if not path.is_dir():\n", + " raise ValueError(f\"Directory {path} does not exist.\")\n", + "\n", + " @property\n", + " def upper_shape(self) -> tuple:\n", + " \"\"\"Returns the vertical variables shape\n", + " Returns:\n", + " tuple: vertical variable shape in the following order::\n", + "\n", + " [VAR, LEV, TIME, LAT, LON]\n", + " \"\"\"\n", + " return self._nuvars, self._nlevel, 2, 361, 576\n", + "\n", + " @property\n", + " def surface_shape(self) -> tuple:\n", + " \"\"\"Returns the surface variables shape\n", + "\n", + " Returns:\n", + " tuple: surafce shape in the following order::\n", + "\n", + " [VAR, LEV, TIME, LAT, LON]\n", + " \"\"\"\n", + " return self._nsvars, 2, 361, 576\n", + "\n", + " def data_file_surface(self, timestamp: pd.Timestamp) -> Path:\n", + " \"\"\"Build the surfcae data file name based on timestamp\n", + "\n", + " Args:\n", + " timestamp: a timestamp\n", + "\n", + " Returns:\n", + " Path: constructed path\n", + " \"\"\"\n", + " pattern = \"MERRA2_sfc_%Y%m%d.nc\"\n", + " data_file = self._data_path_surface / timestamp.strftime(pattern)\n", + " return data_file\n", + "\n", + " def data_file_vertical(self, timestamp: pd.Timestamp) -> Path:\n", + " \"\"\"Build the vertical data file name based on timestamp\n", + "\n", + " Args:\n", + " timestamp: a timestamp\n", + "\n", + " Returns:\n", + " Path: constructed path\n", + " \"\"\"\n", + " pattern = \"MERRA_pres_%Y%m%d.nc\"\n", + " data_file = self._data_path_vertical / timestamp.strftime(pattern)\n", + " return data_file\n", + "\n", + " def data_file_surface_climate(\n", + " self,\n", + " timestamp: pd.Timestamp | None = None,\n", + " dayofyear: int | None = None,\n", + " hourofday: int | None = None,\n", + " ) -> Path:\n", + " \"\"\"\n", + " Returns the path to a climatology file based either on a timestamp or\n", + " the dayofyear / hourofday combination.\n", + " Args:\n", + " 
timestamp: A timestamp.\n", + " dayofyear: Day of the year. 1 to 366.\n", + " hourofday: Hour of the day. 0 to 23.\n", + " Returns:\n", + " Path: Path to climatology file.\n", + " \"\"\"\n", + " if timestamp is not None and (\n", + " (dayofyear is not None) or (hourofday is not None)\n", + " ):\n", + " raise ValueError(\n", + " \"Provide either timestamp or both dayofyear and hourofday.\"\n", + " )\n", + "\n", + " if timestamp is not None:\n", + " dayofyear = min(timestamp.dayofyear, 365)\n", + " hourofday = timestamp.hour\n", + "\n", + " file_name = f\"climate_surface_doy{dayofyear:03}_hour{hourofday:02}.nc\"\n", + " data_file = self._climatology_path_surface / file_name\n", + " return data_file\n", + "\n", + " def data_file_vertical_climate(\n", + " self,\n", + " timestamp: pd.Timestamp | None = None,\n", + " dayofyear: int | None = None,\n", + " hourofday: int | None = None,\n", + " ) -> Path:\n", + " \"\"\"Returns the path to a climatology file based either on a timestamp\n", + " or the dayofyear / hourofday combination.\n", + "\n", + " Args:\n", + " timestamp: A timestamp. dayofyear: Day of the year. 1 to 366.\n", + " hourofday: Hour of the day. 
0 to 23.\n", + " Returns:\n", + " Path: Path to climatology file.\n", + " \"\"\"\n", + " if timestamp is not None and (\n", + " (dayofyear is not None) or (hourofday is not None)\n", + " ):\n", + " raise ValueError(\n", + " \"Provide either timestamp or both dayofyear and hourofday.\"\n", + " )\n", + "\n", + " if timestamp is not None:\n", + " dayofyear = min(timestamp.dayofyear, 365)\n", + " hourofday = timestamp.hour\n", + "\n", + " file_name = f\"climate_vertical_doy{dayofyear:03}_hour{hourofday:02}.nc\"\n", + " data_file = self._climatology_path_vertical / file_name\n", + " return data_file\n", + "\n", + " def _get_coordinates(self) -> None:\n", + " \"\"\"\n", + " Obtains the coordiantes (latitudes and longitudes) from a single data\n", + " file.\n", + " \"\"\"\n", + " timestamp = next(iter(self.valid_timestamps))\n", + "\n", + " file = self.data_file_surface(timestamp)\n", + " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", + " self.lats = lats = handle[\"lat\"][()].astype(self.rtype)\n", + " self.lons = lons = handle[\"lon\"][()].astype(self.rtype)\n", + "\n", + " deg_to_rad = np.pi / 180\n", + " self._embed_lat = np.sin(lats * deg_to_rad).reshape(-1, 1)\n", + "\n", + " self._embed_lon = np.empty((2, 1, len(lons)), dtype=self.rtype)\n", + " self._embed_lon[0, 0] = np.cos(lons * deg_to_rad)\n", + " self._embed_lon[1, 0] = np.sin(lons * deg_to_rad)\n", + "\n", + " @ft.cached_property\n", + " def lats(self) -> np.ndarray:\n", + " timestamp = next(iter(self.valid_timestamps))\n", + "\n", + " file = self.data_file_surface(timestamp)\n", + " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", + " return handle[\"lat\"][()].astype(self.rtype)\n", + "\n", + " @ft.cached_property\n", + " def lons(self) -> np.ndarray:\n", + " timestamp = next(iter(self.valid_timestamps))\n", + "\n", + " file = self.data_file_surface(timestamp)\n", + " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", + " return 
handle[\"lon\"][()].astype(self.rtype)\n", + "\n", + " @ft.cached_property\n", + " def position_signal(self) -> np.ndarray:\n", + " \"\"\"Generates the \"position signal\" that is part of the static\n", + " features.\n", + "\n", + " Returns:\n", + " Tensor: Torch tensor of dimension (parameter, lat, lon) containing\n", + " sin(lat), cos(lon), sin(lon).\n", + " \"\"\"\n", + "\n", + " latitudes, longitudes = np.meshgrid(\n", + " self.lats, self.lons, indexing=\"ij\"\n", + " )\n", + "\n", + " if self.positional_encoding == \"absolute\":\n", + " latitudes = latitudes / 360 * 2.0 * np.pi\n", + " longitudes = longitudes / 360 * 2.0 * np.pi\n", + " sur_static = np.stack(\n", + " [np.sin(latitudes), np.cos(longitudes), np.sin(longitudes)],\n", + " axis=0,\n", + " )\n", + " else:\n", + " sur_static = np.stack([latitudes, longitudes], axis=0)\n", + "\n", + " sur_static = sur_static.astype(self.rtype)\n", + "\n", + " return sur_static\n", + "\n", + " @ft.cached_property\n", + " def valid_timestamps(self) -> set[pd.Timestamp]:\n", + " \"\"\"Generates list of valid timestamps based on available files. 
Only\n", + " timestamps for which both surface and vertical information is available\n", + " are considered valid.\n", + " Returns:\n", + " list: list of timestamps\n", + " \"\"\"\n", + "\n", + " s_glob = self._data_path_surface.glob(\"MERRA2_sfc_????????.nc\")\n", + " s_files = [os.path.basename(f) for f in s_glob]\n", + " v_glob = self._data_path_surface.glob(\"MERRA_pres_????????.nc\")\n", + " v_files = [os.path.basename(f) for f in v_glob]\n", + "\n", + " s_re = re.compile(r\"MERRA2_sfc_(\\d{8}).nc\\Z\")\n", + " v_re = re.compile(r\"MERRA_pres_(\\d{8}).nc\\Z\")\n", + " fmt = \"%Y%m%d\"\n", + "\n", + " s_times = {\n", + " (datetime.strptime(m[1], fmt))\n", + " for f in s_files\n", + " if (m := s_re.match(f))\n", + " }\n", + " v_times = {\n", + " (datetime.strptime(m[1], fmt))\n", + " for f in v_files\n", + " if (m := v_re.match(f))\n", + " }\n", + "\n", + " times = s_times.intersection(v_times)\n", + "\n", + " # Each file contains a day at 3 hour intervals\n", + " times = {\n", + " t + timedelta(hours=i) for i in range(0, 24, 3) for t in times\n", + " }\n", + "\n", + " start_time, end_time = self.time_range\n", + " times = {pd.Timestamp(t) for t in times if start_time <= t <= end_time}\n", + "\n", + " return times\n", + "\n", + " @ft.cached_property\n", + " def valid_climate_timestamps(self) -> set[tuple[int, int]]:\n", + " \"\"\"Generates list of \"timestamps\" (dayofyear, hourofday) for which\n", + " climatology data is present. 
Only instances for which surface and\n", + " vertical data is available are considered valid.\n", + " Returns:\n", + " list: List of tuples describing valid climatology instances.\n", + " \"\"\"\n", + " if not self._require_clim:\n", + " return set()\n", + "\n", + " s_glob = self._climatology_path_surface.glob(\n", + " \"climate_surface_doy???_hour??.nc\"\n", + " )\n", + " s_files = [os.path.basename(f) for f in s_glob]\n", + "\n", + " v_glob = self._climatology_path_vertical.glob(\n", + " \"climate_vertical_doy???_hour??.nc\"\n", + " )\n", + " v_files = [os.path.basename(f) for f in v_glob]\n", + "\n", + " s_re = re.compile(r\"climate_surface_doy(\\d{3})_hour(\\d{2}).nc\\Z\")\n", + " v_re = re.compile(r\"climate_vertical_doy(\\d{3})_hour(\\d{2}).nc\\Z\")\n", + "\n", + " s_times = {\n", + " (int(m[1]), int(m[2])) for f in s_files if (m := s_re.match(f))\n", + " }\n", + " v_times = {\n", + " (int(m[1]), int(m[2])) for f in v_files if (m := v_re.match(f))\n", + " }\n", + "\n", + " times = s_times.intersection(v_times)\n", + "\n", + " return times\n", + "\n", + " def _data_available(self, spec: SampleSpec) -> bool:\n", + " \"\"\"\n", + " Checks whether data is available for a given SampleSpec object. Does so\n", + " using the internal sets with available data previously constructed. 
Not\n", + " by checking the file system.\n", + " Args:\n", + " spec: SampleSpec object as returned by SampleSpec.get\n", + " Returns:\n", + " bool: if data is availability.\n", + " \"\"\"\n", + " valid = set(spec.times).issubset(self.valid_timestamps)\n", + "\n", + " if self._require_clim:\n", + " sci = spec.climatology_info\n", + " ci = set(sci) if isinstance(sci, list) else set([sci]) # noqa: C405\n", + " valid &= ci.issubset(self.valid_climate_timestamps)\n", + "\n", + " return valid\n", + "\n", + " @ft.cached_property\n", + " def samples(self) -> list[tuple[pd.Timestamp, int, int]]:\n", + " \"\"\"\n", + " Generates list of all valid samlpes.\n", + " Returns:\n", + " list: List of tuples (timestamp, input time, lead time).\n", + " \"\"\"\n", + " valid_samples = []\n", + " dts = [(it, lt) for it in self.input_times for lt in self.lead_times]\n", + "\n", + " for timestamp in sorted(self.valid_timestamps):\n", + " timestamp_samples = []\n", + " for it, lt in dts:\n", + " spec = SampleSpec.get(timestamp, -it, lt)\n", + "\n", + " if self._data_available(spec):\n", + " timestamp_samples.append((timestamp, it, lt))\n", + "\n", + " if timestamp_samples:\n", + " valid_samples.append(timestamp_samples)\n", + "\n", + " return valid_samples\n", + "\n", + " def _to_torch(\n", + " self,\n", + " data: dict[str, Tensor | list[Tensor]],\n", + " dtype: torch.dtype = torch.float32,\n", + " ) -> dict[str, Tensor | list[Tensor]]:\n", + " out = {}\n", + " for k, v in data.items():\n", + " if isinstance(v, list):\n", + " out[k] = [torch.from_numpy(x).to(dtype) for x in v]\n", + " else:\n", + " out[k] = torch.from_numpy(v).to(dtype)\n", + "\n", + " return out\n", + "\n", + " def _lat_roll(\n", + " self, data: dict[str, Tensor | list[Tensor]], n: int\n", + " ) -> dict[str, Tensor | list[Tensor]]:\n", + " out = {}\n", + " for k, v in data.items():\n", + " if isinstance(v, list):\n", + " out[k] = [torch.roll(x, shifts=n, dims=-1) for x in v]\n", + " else:\n", + " out[k] = torch.roll(v, 
shifts=n, dims=-1)\n", + "\n", + " return out\n", + "\n", + " def _read_static_data(\n", + " self, file: str | Path, doy: int, hod: int\n", + " ) -> np.ndarray:\n", + " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", + " lats_surf = handle[\"lat\"]\n", + " lons_surf = handle[\"lon\"]\n", + "\n", + " nll = (len(lats_surf), len(lons_surf))\n", + "\n", + " npos = len(self.position_signal)\n", + " ntime = 4\n", + "\n", + " nstat = npos + ntime + self._nsstat\n", + " data = np.empty((nstat, *nll), dtype=self.rtype)\n", + "\n", + " for i, key in enumerate(self._sstat, start=npos + ntime):\n", + " data[i] = handle[key][()].astype(dtype=self.rtype)\n", + "\n", + " # [possition signal], cos(doy), sin(doy), cos(hod), sin(hod)\n", + " data[0:npos] = self.position_signal\n", + " data[npos + 0] = np.cos(2 * np.pi * doy / 366)\n", + " data[npos + 1] = np.sin(2 * np.pi * doy / 366)\n", + " data[npos + 2] = np.cos(2 * np.pi * hod / 24)\n", + " data[npos + 3] = np.sin(2 * np.pi * hod / 24)\n", + "\n", + " return data\n", + "\n", + " def _read_surface(\n", + " self, tidx: int, nll: tuple[int, int], handle: h5py.File\n", + " ) -> np.ndarray:\n", + " data = np.empty((self._nsvars, *nll), dtype=self.rtype)\n", + "\n", + " for i, key in enumerate(self._svars):\n", + " data[i] = handle[key][tidx][()].astype(dtype=self.rtype)\n", + "\n", + " return data\n", + "\n", + " def _read_levels(\n", + " self, tidx: int, nll: tuple[int, int], handle: h5py.File\n", + " ) -> np.ndarray:\n", + " lvls = handle[\"lev\"][()]\n", + " lidx = self._level_idxs(lvls)\n", + "\n", + " data = np.empty((self._nuvars, self._nlevel, *nll), dtype=self.rtype)\n", + "\n", + " for i, key in enumerate(self._uvars):\n", + " data[i] = handle[key][tidx, lidx][()].astype(dtype=self.rtype)\n", + "\n", + " return np.ascontiguousarray(np.flip(data, axis=1))\n", + "\n", + " def _level_idxs(self, lvls):\n", + " lidx = [np.argwhere(lvls == int(lvl)).item() for lvl in self._level]\n", + " return sorted(lidx)\n", + 
"\n", + " @staticmethod\n", + " def _date_to_tidx(date: datetime | pd.Timestamp, handle: h5py.File) -> int:\n", + " if isinstance(date, pd.Timestamp):\n", + " date = date.to_pydatetime()\n", + "\n", + " time = handle[\"time\"]\n", + "\n", + " t0 = time.attrs[\"begin_time\"][()].item()\n", + " d0 = f\"{time.attrs['begin_date'][()].item()}\"\n", + "\n", + " offset = datetime.strptime(d0, \"%Y%m%d\")\n", + "\n", + " times = [offset + timedelta(minutes=int(t + t0)) for t in time[()]]\n", + " return times.index(date)\n", + "\n", + " def _read_data(\n", + " self, file_pair: tuple[str, str], date: datetime\n", + " ) -> dict[str, np.ndarray]:\n", + " s_file, v_file = file_pair\n", + "\n", + " with h5py.File(s_file, \"r\", libver=\"latest\") as shandle:\n", + " lats_surf = shandle[\"lat\"]\n", + " lons_surf = shandle[\"lon\"]\n", + "\n", + " nll = (len(lats_surf), len(lons_surf))\n", + "\n", + " tidx = self._date_to_tidx(date, shandle)\n", + "\n", + " sdata = self._read_surface(tidx, nll, shandle)\n", + "\n", + " with h5py.File(v_file, \"r\", libver=\"latest\") as vhandle:\n", + " lats_vert = vhandle[\"lat\"]\n", + " lons_vert = vhandle[\"lon\"]\n", + "\n", + " nll = (len(lats_vert), len(lons_vert))\n", + "\n", + " tidx = self._date_to_tidx(date, vhandle)\n", + "\n", + " vdata = self._read_levels(tidx, nll, vhandle)\n", + "\n", + " data = {\"vert\": vdata, \"surf\": sdata}\n", + "\n", + " return data\n", + "\n", + " def _read_climate(\n", + " self, file_pair: tuple[str, str]\n", + " ) -> dict[str, np.ndarray]:\n", + " s_file, v_file = file_pair\n", + "\n", + " with h5py.File(s_file, \"r\", libver=\"latest\") as shandle:\n", + " lats_surf = shandle[\"lat\"]\n", + " lons_surf = shandle[\"lon\"]\n", + "\n", + " nll = (len(lats_surf), len(lons_surf))\n", + "\n", + " sdata = np.empty((self._nsvars, *nll), dtype=self.rtype)\n", + "\n", + " for i, key in enumerate(self._svars):\n", + " sdata[i] = shandle[key][()].astype(dtype=self.rtype)\n", + "\n", + " with h5py.File(v_file, 
\"r\", libver=\"latest\") as vhandle:\n", + " lats_vert = vhandle[\"lat\"]\n", + " lons_vert = vhandle[\"lon\"]\n", + "\n", + " nll = (len(lats_vert), len(lons_vert))\n", + "\n", + " lvls = vhandle[\"lev\"][()]\n", + " lidx = self._level_idxs(lvls)\n", + "\n", + " vdata = np.empty(\n", + " (self._nuvars, self._nlevel, *nll), dtype=self.rtype\n", + " )\n", + "\n", + " for i, key in enumerate(self._uvars):\n", + " vdata[i] = vhandle[key][lidx][()].astype(dtype=self.rtype)\n", + "\n", + " data = {\n", + " \"vert\": np.ascontiguousarray(np.flip(vdata, axis=1)),\n", + " \"surf\": sdata,\n", + " }\n", + "\n", + " return data\n", + "\n", + " def get_data_from_sample_spec(\n", + " self, spec: SampleSpec\n", + " ) -> dict[str, Tensor | int | float]:\n", + " \"\"\"Loads and assembles sample data given a SampleSpec object.\n", + "\n", + " Args:\n", + " spec (SampleSpec): Full details regarding the data to be loaded\n", + " Returns:\n", + " dict: Dictionary with the following keys::\n", + "\n", + " 'sur_static': Torch tensor of shape [parameter, lat, lon]. 
For\n", + " each pixel (lat, lon), the first 7 dimensions index sin(lat),\n", + " cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod).\n", + " Where doy is the day of the year [1, 366] and hod the hour of\n", + " the day [0, 23].\n", + " 'sur_vals': Torch tensor of shape [parameter, time, lat, lon].\n", + " 'sur_tars': Torch tensor of shape [parameter, time, lat, lon].\n", + " 'ulv_vals': Torch tensor of shape [parameter, level, time, lat, lon].\n", + " 'ulv_tars': Torch tensor of shape [parameter, level, time, lat, lon].\n", + " 'sur_climate': Torch tensor of shape [parameter, lat, lon].\n", + " 'ulv_climate': Torch tensor of shape [paramter, level, lat, lon].\n", + " 'lead_time': Float.\n", + " 'input_time': Float.\n", + "\n", + " \"\"\" # noqa: E501\n", + "\n", + " # We assemble the unique timestamps for which we need data.\n", + " vals_required = {*spec.times}\n", + " stat_required = {*spec.stat_times}\n", + "\n", + " # We assemble the unique data files from which we need value data\n", + " vals_file_map = defaultdict(list)\n", + " for t in vals_required:\n", + " data_files = (\n", + " self.data_file_surface(t),\n", + " self.data_file_vertical(t),\n", + " )\n", + " vals_file_map[data_files].append(t)\n", + "\n", + " # We assemble the unique data files from which we need static data\n", + " stat_file_map = defaultdict(list)\n", + " for t in stat_required:\n", + " data_files = (\n", + " self.data_file_surface(t),\n", + " self.data_file_vertical(t),\n", + " )\n", + " stat_file_map[data_files].append(t)\n", + "\n", + " # Load the value data\n", + " data = {}\n", + " for data_files, times in vals_file_map.items():\n", + " for time in times:\n", + " data[time] = self._read_data(data_files, time)\n", + "\n", + " # Combine times\n", + " sample_data = {}\n", + "\n", + " input_upl = np.stack([data[t][\"vert\"] for t in spec.inputs], axis=2)\n", + " sample_data[\"ulv_vals\"] = input_upl\n", + "\n", + " target_upl = data[spec.target][\"vert\"]\n", + " 
sample_data[\"ulv_tars\"] = target_upl[:, :, None]\n", + "\n", + " input_sur = np.stack([data[t][\"surf\"] for t in spec.inputs], axis=1)\n", + " sample_data[\"sur_vals\"] = input_sur\n", + "\n", + " target_sur = data[spec.target][\"surf\"]\n", + " sample_data[\"sur_tars\"] = target_sur[:, None]\n", + "\n", + " # Load the static data\n", + " data_files, times = stat_file_map.popitem()\n", + " time = times[0].dayofyear, times[0].hour\n", + " sample_data[\"sur_static\"] = self._read_static_data(\n", + " data_files[0], *time\n", + " )\n", + "\n", + " # If required load the surface data\n", + " if self._require_clim:\n", + " ci_year, ci_hour = spec.climatology_info\n", + "\n", + " surf_file = self.data_file_surface_climate(\n", + " dayofyear=ci_year,\n", + " hourofday=ci_hour,\n", + " )\n", + "\n", + " vert_file = self.data_file_vertical_climate(\n", + " dayofyear=ci_year,\n", + " hourofday=ci_hour,\n", + " )\n", + "\n", + " clim_data = self._read_climate((surf_file, vert_file))\n", + "\n", + " sample_data[\"sur_climate\"] = clim_data[\"surf\"]\n", + " sample_data[\"ulv_climate\"] = clim_data[\"vert\"]\n", + "\n", + " # Move the data from numpy to torch\n", + " sample_data = self._to_torch(sample_data, dtype=self.dtype)\n", + "\n", + " # Optionally roll\n", + " if len(self._roll_longitudes) > 0:\n", + " roll_by = random.choice(self._roll_longitudes)\n", + " sample_data = self._lat_roll(sample_data, roll_by)\n", + "\n", + " # Now that we have rolled, we can add the static data\n", + " sample_data[\"lead_time\"] = spec.lead_time\n", + " sample_data[\"input_time\"] = spec.input_time\n", + "\n", + " return sample_data\n", + "\n", + " def get_data(\n", + " self, timestamp: pd.Timestamp, input_time: int, lead_time: int\n", + " ) -> dict[str, Tensor | int]:\n", + " \"\"\"\n", + " Loads data based on timestamp and lead time.\n", + " Args:\n", + " timestamp: Timestamp.\n", + " input_time: time between input samples.\n", + " lead_time: lead time.\n", + " Returns:\n", + " 
Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", + " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',\n", + " 'lead_time'.\n", + " \"\"\"\n", + " spec = SampleSpec.get(timestamp, -input_time, lead_time)\n", + " sample_data = self.get_data_from_sample_spec(spec)\n", + " return sample_data\n", + "\n", + " def __getitem__(self, idx: int) -> dict[str, Tensor | int]:\n", + " \"\"\"\n", + " Loads data based on sample index and random choice of sample.\n", + " Args:\n", + " idx: Sample index.\n", + " Returns:\n", + " Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", + " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',\n", + " 'lead_time', 'input_time'.\n", + " \"\"\"\n", + " sample_set = self.samples[idx]\n", + " timestamp, input_time, lead_time, *nsteps = random.choice(sample_set)\n", + " sample_data = self.get_data(timestamp, input_time, lead_time)\n", + " return sample_data\n", + "\n", + " def __len__(self):\n", + " return len(self.samples)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "# from PrithviWxC.dataloaders.merra2 import Merra2Dataset\n", + "\n", + "dataset = Merra2Dataset(\n", + " time_range=time_range,\n", + " lead_times=lead_times,\n", + " input_times=input_times,\n", + " data_path_surface=surf_dir,\n", + " data_path_vertical=vert_dir,\n", + " climatology_path_surface=surf_clim_dir,\n", + " climatology_path_vertical=vert_clim_dir,\n", + " surface_vars=surface_vars,\n", + " static_surface_vars=static_surface_vars,\n", + " vertical_vars=vertical_vars,\n", + " levels=levels,\n", + " positional_encoding=positional_encoding,\n", + ")\n", + "assert len(dataset) > 0, \"There doesn't seem to be any valid data.\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## The model\n", + "We are now ready to build the model.\n", + "### Scalers\n", + "Additionally, the model takes as static parameters the mean\n", + "and variance values of the 
input variables and the variance\n", + "values of the target difference, i.e., the variance between\n", + "climatology and instantaneous variables. We have provided\n", + "data files containing these values, and here we load this data." + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "# from PrithviWxC.dataloaders.merra2 import (\n", + "# input_scalers,\n", + "# output_scalers,\n", + "# static_input_scalers,\n", + "# )\n", + "\n", + "surf_in_scal_path = Path(\"./climatology/musigma_surface.nc\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=f\"climatology/{surf_in_scal_path.name}\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "vert_in_scal_path = Path(\"./climatology/musigma_vertical.nc\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=f\"climatology/{vert_in_scal_path.name}\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "surf_out_scal_path = Path(\"./climatology/anomaly_variance_surface.nc\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=f\"climatology/{surf_out_scal_path.name}\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "vert_out_scal_path = Path(\"./climatology/anomaly_variance_vertical.nc\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=f\"climatology/{vert_out_scal_path.name}\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "in_mu, in_sig = input_scalers(\n", + " surface_vars,\n", + " vertical_vars,\n", + " levels,\n", + " surf_in_scal_path,\n", + " vert_in_scal_path,\n", + ")\n", + "\n", + "output_sig = output_scalers(\n", + " surface_vars,\n", + " vertical_vars,\n", + " levels,\n", + " surf_out_scal_path,\n", + " vert_out_scal_path,\n", + ")\n", + "\n", + "static_mu, static_sig = static_input_scalers(\n", + " surf_in_scal_path,\n", + " static_surface_vars,\n", + ")" + ] + }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "### Task and additional configs\n", + "As previously mentioned, the PrithviWxC model's pretext task\n", + "involved predicting the desired variable at a specific lead\n", + "time. This was achieved by calculating the difference (delta)\n", + "compared to the climatological average at that time. This\n", + "operational mode is activated using the residual flag. Although\n", + "the model includes additional residual options, the core model\n", + "weights were not trained using these modes.\n", + "\n", + "Additionally, for training and evaluation, it is possible to\n", + "mask tokens in the model. The masking occurs after tokenization,\n", + "prior to the encoder layers. The model utilizes multi-axis\n", + "attention, with data broken down into a hierarchy of local and\n", + "global patches. Consequently, masking can be configured to mask\n", + "either small local patches or larger global patches. This\n", + "configuration is achieved via the `masking_mode` flag. It is\n", + "possible to set `masking_mode=both`. This does not mix the modes\n", + "but rather allows both modes to be used and swapped between,\n", + "primarily for training purposes. For this demonstration, we will\n", + "adjust the masking ratio to showcase the reconstruction\n", + "capabilities of the model.\n", + "\n", + "Finally, we can set up shifting. Primarily utilized in the\n", + "decoder, this enables alternate shifting of the attention\n", + "windows, similar to the SWIN model. This option necessitates\n", + "an even number of decoder blocks and is incompatible with the\n", + "encoder when masking is also employed." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "residual = \"climate\"\n", + "masking_mode = \"local\"\n", + "decoder_shifting = True\n", + "masking_ratio = 0.99" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model init\n", + "We now have all the pieces to build the model. If you are\n", + "using the pretrained weights, a number of the model\n", + "hyperparameters are predetermined and included below. With\n", + "this configuration, the model will have approximately 2.3\n", + "billion parameters. Therefore, if you want to train the fully\n", + "unfrozen model, you will likely need to use a model distribution\n", + "approach, such as fully sharded data parallelism (FSDP). To\n", + "further reduce the memory usage of the model when gradients are\n", + "required, there are two variables — `checkpoint_encoder` and\n", + "`checkpoint_decoder` — which enable activation checkpointing of\n", + "desired transformer layers." + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import cached_property\n", + "from importlib.metadata import version\n", + "\n", + "from torch import Tensor\n", + "from torch.utils.checkpoint import checkpoint\n", + "\n", + "if version(\"torch\") > \"2.3.0\":\n", + " from torch.nn.attention import SDPBackend, sdpa_kernel\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "# DropPath code is straight from timm\n", + "# (https://huggingface.co/spaces/Roll20/pet_score/blame/main/lib/timm/models/layers/drop.py)\n", + "def drop_path(\n", + " x: Tensor,\n", + " drop_prob: float = 0.0,\n", + " training: bool = False,\n", + " scale_by_keep: bool = True,\n", + ") -> Tensor:\n", + " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of\n", + " residual blocks). 
Taken from timm.\n", + "\n", + " Args:\n", + " x (Tensor): Input tensor.\n", + " drop_prob (float): Probability of dropping `x`, defaults to 0.\n", + " training (bool): Whether model is in training or eval mode,\n", + " defaults to False.\n", + " scale_by_keep (bool): Whether the output should be scaled by\n", + " (`1 - drop_prob`), defaults to True.\n", + " Returns:\n", + " Tensor: Tensor that may have been randomly dropped with probability\n", + " `drop_prob`\n", + " \"\"\"\n", + " if drop_prob == 0.0 or not training:\n", + " return x\n", + " keep_prob = 1 - drop_prob\n", + " shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n", + " random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n", + " if keep_prob > 0.0 and scale_by_keep:\n", + " random_tensor.div_(keep_prob)\n", + " return x * random_tensor\n", + "\n", + "\n", + "class DropPath(nn.Module):\n", + " \"\"\"\n", + " Drop paths (Stochastic Depth) per sample (when applied in main path of\n", + " residual blocks).\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self, drop_prob: float | None = None, scale_by_keep: bool = True\n", + " ) -> None:\n", + " super(DropPath, self).__init__()\n", + " self.drop_prob = drop_prob\n", + " self.scale_by_keep = scale_by_keep\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"Runs drop path on input tensor\n", + "\n", + " Args:\n", + " x: input\n", + "\n", + " Returns:\n", + " tensor: output after drop_path\n", + " \"\"\"\n", + " return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)\n", + "\n", + "\n", + "class Mlp(nn.Module):\n", + " \"\"\"\n", + " Multi layer perceptron.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self, features: int, hidden_features: int, dropout: float = 0.0\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " features: Input/output dimension.\n", + " hidden_features: Hidden dimension.\n", + " dropout: Dropout.\n", + " \"\"\"\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " 
nn.Linear(features, hidden_features),\n", + " nn.GELU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(hidden_features, features),\n", + " nn.Dropout(dropout),\n", + " )\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x (Tensor): Tensor of shape [..., channel]\n", + " Returns:\n", + " Tensor: Tensor of same shape as x.\n", + " \"\"\"\n", + " return self.net(x)\n", + "\n", + "\n", + "class LayerNormPassThrough(nn.LayerNorm):\n", + " \"\"\"Normalising layer that allows the attention mask to be passed through\"\"\"\n", + "\n", + " def __init__(self, *args, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + "\n", + " def forward(\n", + " self, d: tuple[Tensor, Tensor | None]\n", + " ) -> tuple[Tensor, Tensor | None]:\n", + " \"\"\"Forwards function\n", + "\n", + " Args:\n", + " d (tuple): tuple of the data tensor and the attention mask\n", + " Returns:\n", + " output (Tensor): normalised output data\n", + " attn_mask (Tensor): the attention mask that was passed in\n", + " \"\"\"\n", + " input, attn_mask = d\n", + " output = F.layer_norm(\n", + " input, self.normalized_shape, self.weight, self.bias, self.eps\n", + " )\n", + " return output, attn_mask\n", + "\n", + "\n", + "class MultiheadAttention(nn.Module):\n", + " \"\"\"Multihead attention layer for inputs of shape\n", + " [..., sequence, features].\n", + " \"\"\"\n", + "\n", + " def __init__(self, features: int, n_heads: int, dropout: float) -> None:\n", + " \"\"\"\n", + " Args:\n", + " features: Number of features for inputs to the layer.\n", + " n_heads: Number of attention heads. Should be a factor of features.\n", + " (I.e. 
the layer uses features // n_heads.)\n", + " dropout: Dropout.\n", + " \"\"\" # noqa: E501\n", + " super().__init__()\n", + "\n", + " if (features % n_heads) != 0:\n", + " raise ValueError(\n", + " f\"Features '{features}' is not divisible by heads '{n_heads}'.\"\n", + " )\n", + "\n", + " self.features = features\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + "\n", + " self.qkv_layer = torch.nn.Linear(features, features * 3, bias=False)\n", + " self.w_layer = torch.nn.Linear(features, features, bias=False)\n", + "\n", + " def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " d (tuple): tuple containing Tensor of shape [..., sequence, features] and the attention mask\n", + " Returns:\n", + " Tensor: Tensor of shape [..., sequence, features]\n", + " \"\"\" # noqa: E501\n", + " x, attn_mask = d\n", + "\n", + " if not x.shape[-1] == self.features:\n", + " raise ValueError(\n", + " f\"Expecting tensor with last dimension size {self.features}.\"\n", + " )\n", + "\n", + " passenger_dims = x.shape[:-2]\n", + " B = passenger_dims.numel()\n", + " S = x.shape[-2]\n", + " C = x.shape[-1]\n", + " x = x.reshape(B, S, C)\n", + "\n", + " # x [B, S, C]\n", + " # q, k, v [B, H, S, C/H]\n", + " q, k, v = (\n", + " self.qkv_layer(x)\n", + " .view(B, S, self.n_heads, 3 * (C // self.n_heads))\n", + " .transpose(1, 2)\n", + " .chunk(chunks=3, dim=3)\n", + " )\n", + "\n", + " # Let us enforce either flash (A100+) or memory efficient attention.\n", + " if version(\"torch\") > \"2.3.0\":\n", + " with sdpa_kernel(\n", + " [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]\n", + " ):\n", + " # x [B, H, S, C//H]\n", + " x = F.scaled_dot_product_attention(\n", + " q, k, v, attn_mask=attn_mask, dropout_p=self.dropout\n", + " )\n", + " else:\n", + " with torch.backends.cuda.sdp_kernel(\n", + " enable_flash=True, enable_math=False, enable_mem_efficient=True\n", + " ):\n", + " # x [B, H, S, C//H]\n", + " x = 
F.scaled_dot_product_attention(\n", + " q, k, v, dropout_p=self.dropout\n", + " )\n", + "\n", + " # x [B, S, C]\n", + " x = x.transpose(1, 2).view(B, S, C)\n", + "\n", + " # x [B, S, C]\n", + " x = self.w_layer(x)\n", + "\n", + " # Back to input shape\n", + " x = x.view(*passenger_dims, S, self.features)\n", + " return x\n", + "\n", + "\n", + "class Transformer(nn.Module):\n", + " \"\"\"\n", + " Transformer for inputs of shape [..., S, features].\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " features: int,\n", + " mlp_multiplier: int,\n", + " n_heads: int,\n", + " dropout: float,\n", + " drop_path: float,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " features: Number of features for inputs to the layer.\n", + " mlp_multiplier: Model uses features*mlp_multiplier hidden units.\n", + " n_heads: Number of attention heads. Should be a factor of features.\n", + " (I.e. the layer uses features // n_heads.) dropout: Dropout.\n", + " drop_path: DropPath.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.features = features\n", + " self.mlp_multiplier = mlp_multiplier\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + " self.drop_path = (\n", + " DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n", + " )\n", + "\n", + " self.attention = nn.Sequential(\n", + " LayerNormPassThrough(features),\n", + " MultiheadAttention(features, n_heads, dropout),\n", + " )\n", + "\n", + " self.ff = nn.Sequential(\n", + " nn.LayerNorm(features),\n", + " Mlp(\n", + " features=features,\n", + " hidden_features=features * mlp_multiplier,\n", + " dropout=dropout,\n", + " ),\n", + " )\n", + "\n", + " def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x: Tensor of shape [..., sequence, features]\n", + " Returns:\n", + " Tensor: Tensor of shape [..., sequence, features]\n", + " \"\"\"\n", + " x, attn_mask = d\n", + " if not x.shape[-1] == self.features:\n", + " raise ValueError(\n", + 
" f\"Expecting tensor with last dimension size {self.features}.\"\n", + " )\n", + "\n", + " attention_x = self.attention(d)\n", + "\n", + " x = x + self.drop_path(attention_x)\n", + " x = x + self.drop_path(self.ff(x))\n", + "\n", + " return x\n", + "\n", + "\n", + "class _Shift(nn.Module):\n", + " \"\"\"Private base class for the shifter. This allows some behaviour to be\n", + " easily handled when the shifter isn't used.\n", + " \"\"\"\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " self._shifted = False\n", + "\n", + " @torch.no_grad()\n", + " def reset(self) -> None:\n", + " \"\"\"\n", + " Resets the bool tracking whether the data is shifted\n", + " \"\"\"\n", + " self._shifted: bool = False\n", + "\n", + " def forward(self, data: Tensor) -> tuple[Tensor, dict[bool, None]]:\n", + " return data, {True: None, False: None}\n", + "\n", + "\n", + "class SWINShift(_Shift):\n", + " \"\"\"\n", + " Handles the shifting of patches similar to how SWIN works. However if we\n", + " shift the latitudes then the poles will wrap and potentially that might be\n", + " problematic. 
The position tokens should handle it but masking is safer.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " mu_shape: tuple[int, int],\n", + " global_shape: tuple[int, int],\n", + " local_shape: tuple[int, int],\n", + " patch_shape: tuple[int, int],\n", + " n_context_tokens: int = 2,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " mu_shape: the shape of the masking units\n", + " global_shape: number of global patches in lat and lon\n", + " local_shape: size of the local patches\n", + " patch_shape: patch size\n", + " n_context_tokens: number of additional context tokens at start of\n", + " _each_ local sequence\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self._mu_shape = ms = mu_shape\n", + " self._g_shape = gs = global_shape\n", + " self._l_shape = ls = local_shape\n", + " self._p_shape = ps = patch_shape\n", + " self._lat_patch = (gs[0], ls[0], gs[1], ls[1])\n", + " self._n_context_tokens = n_context_tokens\n", + "\n", + " self._g_shift_to = tuple(\n", + " int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)\n", + " )\n", + " self._g_shift_from = tuple(\n", + " -int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)\n", + " )\n", + "\n", + " # Define the attention masks for the shifted MaxViT.\n", + " nglobal = global_shape[0] * global_shape[1]\n", + " nlocal = (\n", + " local_shape[0] * local_shape[1] + self._n_context_tokens\n", + " ) # \"+ 1\" for leadtime\n", + "\n", + " lm = torch.ones((nglobal, 1, nlocal, nlocal), dtype=bool)\n", + " mwidth = int(0.5 * local_shape[1]) * local_shape[0]\n", + " lm[\n", + " : gs[1],\n", + " :,\n", + " self._n_context_tokens : mwidth + self._n_context_tokens,\n", + " self._n_context_tokens : mwidth + self._n_context_tokens,\n", + " ] = False\n", + " self.register_buffer(\"local_mask\", lm)\n", + "\n", + " gm = torch.ones((nlocal, 1, nglobal, nglobal), dtype=bool)\n", + " gm[: int(0.5 * ls[1]) * ls[0], :, : gs[1], : gs[1]] = False\n", + " self.register_buffer(\"global_mask\", gm)\n", + 
"\n", + " def _to_grid_global(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shuffle and reshape the data from the global/local setting back to the\n", + " lat/lon grid setting\n", + " Args:\n", + " x: the data tensor to be shuffled.\n", + " Returns:\n", + " x: data in the global/local setting\n", + " \"\"\"\n", + " nbatch, *other = x.shape\n", + "\n", + " y1 = x.view(nbatch, *self._g_shape, *self._l_shape, -1)\n", + " y2 = y1.permute(0, 5, 1, 3, 2, 4).contiguous()\n", + "\n", + " s = y2.shape\n", + " return y2.view((nbatch, -1, s[2] * s[3], s[4] * s[5]))\n", + "\n", + " def _to_grid_local(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shuffle and reshape the data from the local/global setting to the\n", + " lat/lon grid setting\n", + " Args:\n", + " x: the data tensor to be shuffled.\n", + " Returns:\n", + " x: data in the lat/lon setting.\n", + " \"\"\"\n", + " x = x.transpose(2, 1).contiguous()\n", + " return self._to_grid_global(x)\n", + "\n", + " def _from_grid_global(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shuffle and reshape the data from the lat/lon grid to the global/local\n", + " setting\n", + " Args:\n", + " x: the data tensor to be shuffled.\n", + " Returns:\n", + " x: data in the global/local setting\n", + " \"\"\"\n", + " nbatch, *other = x.shape\n", + "\n", + " z1 = x.view(nbatch, -1, *self._lat_patch)\n", + " z2 = z1.permute(0, 2, 4, 3, 5, 1).contiguous()\n", + "\n", + " s = z2.shape\n", + " return z2.view(nbatch, s[1] * s[2], s[3] * s[4], -1)\n", + "\n", + " def _from_grid_local(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shuffle and reshape the data from the lat/lon grid to the local/global\n", + " setting\n", + " Args:\n", + " x: the data tensor to be shuffled.\n", + " Returns:\n", + " x: data in the local/global setting\n", + " \"\"\"\n", + " x = self._from_grid_global(x)\n", + " return x.transpose(2, 1).contiguous()\n", + "\n", + " def _shift(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shifts data in the gridded 
lat/lon setting by half the mask unit shape\n", + " Args:\n", + " x: data to be shifted\n", + " Returns:\n", + " x: either the shifted or unshifted data\n", + " \"\"\"\n", + " shift = self._g_shift_from if self._shifted else self._g_shift_to\n", + " x_shifted = torch.roll(x, shift, (-2, -1))\n", + "\n", + " self._shifted = not self._shifted\n", + " return x_shifted\n", + "\n", + " def _sep_lt(self, x: Tensor) -> tuple[Tensor, Tensor]:\n", + " \"\"\"\n", + " Separate off the leadtime from the local patches\n", + " Args:\n", + " x: data to have leadtime removed from\n", + " Returns:\n", + " lt: leadtime\n", + " x: data without the lead time in the local patch\n", + " \"\"\"\n", + " lt_it = x[:, : self._n_context_tokens, :, :]\n", + " x_stripped = x[:, self._n_context_tokens :, :, :]\n", + "\n", + " return lt_it, x_stripped\n", + "\n", + " def forward(self, data: Tensor) -> tuple[Tensor, Tensor]:\n", + " \"\"\"Shift or unshift the data depending on whether the data is\n", + " already shifted, as defined by self._shifted.\n", + "\n", + " Args:\n", + " data: data to be shifted\n", + " Returns:\n", + " Tensor: shifted data Tensor\n", + " \"\"\"\n", + " lt, x = self._sep_lt(data)\n", + "\n", + " x_grid = self._to_grid_local(x)\n", + " x_shifted = self._shift(x_grid)\n", + " x_patched = self._from_grid_local(x_shifted)\n", + "\n", + " # Mask has to be repeated based on batch size\n", + " n_batch = x_grid.shape[0]\n", + " local_rep = [n_batch] + [1] * (self.local_mask.ndim - 1)\n", + " global_rep = [n_batch] + [1] * (self.global_mask.ndim - 1)\n", + "\n", + " if self._shifted:\n", + " attn_mask = {\n", + " True: self.local_mask.repeat(local_rep),\n", + " False: self.global_mask.repeat(global_rep),\n", + " }\n", + " else:\n", + " attn_mask = {True: None, False: None}\n", + "\n", + " return torch.cat((lt, x_patched), axis=1), attn_mask\n", + "\n", + "\n", + "class LocalGlobalLocalBlock(nn.Module):\n", + " \"\"\"\n", + " Applies alternating block and grid attention. 
Given a parameter n_blocks,\n", + " the entire module contains 2*n_blocks+1 transformer blocks. The first,\n", + " third, ..., last apply local (block) attention. The second, fourth, ...\n", + " global (grid) attention.\n", + "\n", + " This is heavily inspired by\n", + " Tu et al. \"MaxViT: Multi-Axis Vision Transformer\"\n", + " (https://arxiv.org/abs/2204.01697).\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " features: int,\n", + " mlp_multiplier: int,\n", + " n_heads: int,\n", + " dropout: float,\n", + " n_blocks: int,\n", + " drop_path: float,\n", + " shifter: nn.Module | None = None,\n", + " checkpoint: list[int] | None = None,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " features: Number of features for inputs to the layer.\n", + " mlp_multiplier: Model uses features*mlp_multiplier hidden units.\n", + " n_heads: Number of attention heads. Should be a factor of features.\n", + " (I.e. the layer uses features // n_heads.)\n", + " dropout: Dropout.\n", + " drop_path: DropPath.\n", + " n_blocks: Number of local-global transformer pairs.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.features = features\n", + " self.mlp_multiplier = mlp_multiplier\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + " self.drop_path = drop_path\n", + " self.n_blocks = n_blocks\n", + " self._checkpoint = checkpoint or []\n", + "\n", + " if not all(0 <= c < 2 * n_blocks + 1 for c in self._checkpoint):\n", + " raise ValueError(\n", + " \"Checkpoints should be 0 <= i < 2*n_blocks+1. 
\"\n", + " f\"{self._checkpoint=}.\"\n", + " )\n", + "\n", + " self.transformers = nn.ModuleList(\n", + " [\n", + " Transformer(\n", + " features=features,\n", + " mlp_multiplier=mlp_multiplier,\n", + " n_heads=n_heads,\n", + " dropout=dropout,\n", + " drop_path=drop_path,\n", + " )\n", + " for _ in range(2 * n_blocks + 1)\n", + " ]\n", + " )\n", + "\n", + " self.evaluator = [\n", + " self._checkpoint_wrapper\n", + " if i in self._checkpoint\n", + " else lambda m, x: m(x)\n", + " for i, _ in enumerate(self.transformers)\n", + " ]\n", + "\n", + " self.shifter = shifter or _Shift()\n", + "\n", + " @staticmethod\n", + " def _checkpoint_wrapper(\n", + " model: nn.Module, data: tuple[Tensor, Tensor | None]\n", + " ) -> Tensor:\n", + " return checkpoint(model, data, use_reentrant=False)\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x: Tensor of shape::\n", + "\n", + " [batch, global_sequence, local_sequence, features]\n", + "\n", + " Returns:\n", + " Tensor: Tensor of shape::\n", + "\n", + " [batch, global_sequence, local_sequence, features]\n", + " \"\"\"\n", + " if x.shape[-1] != self.features:\n", + " raise ValueError(\n", + " f\"Expecting tensor with last dimension size {self.features}.\"\n", + " )\n", + " if x.ndim != 4:\n", + " raise ValueError(\n", + " f\"Expecting tensor with exactly four dimensions. 
{x.shape=}.\"\n", + " )\n", + "\n", + " self.shifter.reset()\n", + " local: bool = True\n", + " attn_mask = {True: None, False: None}\n", + "\n", + " transformer_iter = zip(self.evaluator, self.transformers, strict=False)\n", + "\n", + " # First local block\n", + " evaluator, transformer = next(transformer_iter)\n", + " x = evaluator(transformer, (x, attn_mask[local]))\n", + "\n", + " for evaluator, transformer in transformer_iter:\n", + " local = not local\n", + " # We are making exactly 2*n_blocks transposes.\n", + " # So the output has the same shape as input.\n", + " x = x.transpose(1, 2)\n", + "\n", + " x = evaluator(transformer, (x, attn_mask[local]))\n", + "\n", + " if not local:\n", + " x, attn_mask = self.shifter(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "class PatchEmbed(nn.Module):\n", + " \"\"\"\n", + " Patch embedding via 2D convolution.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self, patch_size: int | tuple[int, ...], channels: int, embed_dim: int\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.patch_size = patch_size\n", + " self.channels = channels\n", + " self.embed_dim = embed_dim\n", + "\n", + " self.proj = nn.Conv2d(\n", + " channels,\n", + " embed_dim,\n", + " kernel_size=patch_size,\n", + " stride=patch_size,\n", + " bias=True,\n", + " )\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x: Tensor of shape [batch, channels, lat, lon].\n", + " Returns:\n", + " Tensor: Tensor with shape\n", + " [batch, embed_dim, lat//patch_size, lon//patch_size]\n", + " \"\"\"\n", + "\n", + " H, W = x.shape[-2:]\n", + "\n", + " if W % self.patch_size[1] != 0:\n", + " raise ValueError(\n", + " f\"Cannot do patch embedding for tensor of shape {x.size()}\"\n", + " \" with patch size {self.patch_size}. 
(Dimensions are BSCHW.)\"\n", + " )\n", + " if H % self.patch_size[0] != 0:\n", + " raise ValueError(\n", + " f\"Cannot do patch embedding for tensor of shape {x.size()}\"\n", + " f\" with patch size {self.patch_size}. (Dimensions are BSCHW.)\"\n", + " )\n", + "\n", + " x = self.proj(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "class PrithviWxCEncoderDecoder(nn.Module):\n", + " \"\"\"\n", + " Hiera-MaxViT encoder/decoder code.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " embed_dim: int,\n", + " n_blocks: int,\n", + " mlp_multiplier: float,\n", + " n_heads: int,\n", + " dropout: float,\n", + " drop_path: float,\n", + " shifter: nn.Module | None = None,\n", + " transformer_cp: list[int] | None = None,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " embed_dim: Embedding dimension\n", + " n_blocks: Number of local-global transformer pairs.\n", + " mlp_multiplier: MLP multiplier for hidden features in feed forward\n", + " networks.\n", + " n_heads: Number of attention heads.\n", + " dropout: Dropout.\n", + " drop_path: DropPath.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.embed_dim = embed_dim\n", + " self.n_blocks = n_blocks\n", + " self.mlp_multiplier = mlp_multiplier\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + " self._transformer_cp = transformer_cp\n", + "\n", + " self.lgl_block = LocalGlobalLocalBlock(\n", + " features=embed_dim,\n", + " mlp_multiplier=mlp_multiplier,\n", + " n_heads=n_heads,\n", + " dropout=dropout,\n", + " drop_path=drop_path,\n", + " n_blocks=n_blocks,\n", + " shifter=shifter,\n", + " checkpoint=transformer_cp,\n", + " )\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x: Tensor of shape\n", + " [batch, global sequence, local sequence, embed_dim]\n", + " Returns:\n", + " Tensor of shape\n", + " [batch, mask_unit_sequence, local_sequence, embed_dim].\n", + " Identical in shape to the input x.\n", + " \"\"\"\n", + 
"\n", + " x = self.lgl_block(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "class PrithviWxC(nn.Module):\n", + " \"\"\"Encoder-decoder fusing Hiera with MaxViT. See\n", + " - Ryali et al. \"Hiera: A Hierarchical Vision Transformer without the\n", + " Bells-and-Whistles\" (https://arxiv.org/abs/2306.00989)\n", + " - Tu et al. \"MaxViT: Multi-Axis Vision Transformer\"\n", + " (https://arxiv.org/abs/2204.01697)\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " in_channels: int,\n", + " input_size_time: int,\n", + " in_channels_static: int,\n", + " input_scalers_mu: Tensor,\n", + " input_scalers_sigma: Tensor,\n", + " input_scalers_epsilon: float,\n", + " static_input_scalers_mu: Tensor,\n", + " static_input_scalers_sigma: Tensor,\n", + " static_input_scalers_epsilon: float,\n", + " output_scalers: Tensor,\n", + " n_lats_px: int,\n", + " n_lons_px: int,\n", + " patch_size_px: tuple[int],\n", + " mask_unit_size_px: tuple[int],\n", + " mask_ratio_inputs: float,\n", + " embed_dim: int,\n", + " n_blocks_encoder: int,\n", + " n_blocks_decoder: int,\n", + " mlp_multiplier: float,\n", + " n_heads: int,\n", + " dropout: float,\n", + " drop_path: float,\n", + " parameter_dropout: float,\n", + " residual: str,\n", + " masking_mode: str,\n", + " positional_encoding: str,\n", + " decoder_shifting: bool = False,\n", + " checkpoint_encoder: list[int] | None = None,\n", + " checkpoint_decoder: list[int] | None = None,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " in_channels: Number of input channels.\n", + " input_size_time: Number of timestamps in input.\n", + " in_channels_static: Number of input channels for static data.\n", + " input_scalers_mu: Tensor of size (in_channels,). Used to rescale\n", + " input.\n", + " input_scalers_sigma: Tensor of size (in_channels,). Used to rescale\n", + " input.\n", + " input_scalers_epsilon: Float. Used to rescale input.\n", + " static_input_scalers_mu: Tensor of size (in_channels_static). 
Used\n", + " to rescale static inputs.\n", + " static_input_scalers_sigma: Tensor of size (in_channels_static).\n", + " Used to rescale static inputs.\n", + " static_input_scalers_epsilon: Float. Used to rescale static inputs.\n", + " output_scalers: Tensor of shape (in_channels,). Used to rescale\n", + " output.\n", + " n_lats_px: Total latitudes in data. In pixels.\n", + " n_lons_px: Total longitudes in data. In pixels.\n", + " patch_size_px: Patch size for tokenization. In pixels lat/lon.\n", + " mask_unit_size_px: Size of each mask unit. In pixels lat/lon.\n", + " mask_ratio_inputs: Masking ratio for inputs. 0 to 1.\n", + " embed_dim: Embedding dimension\n", + " n_blocks_encoder: Number of local-global transformer pairs in\n", + " encoder.\n", + " n_blocks_decoder: Number of local-global transformer pairs in\n", + " decoder.\n", + " mlp_multiplier: MLP multiplier for hidden features in feed forward\n", + " networks.\n", + " n_heads: Number of attention heads.\n", + " dropout: Dropout.\n", + " drop_path: DropPath.\n", + " parameter_dropout: Dropout applied to parameters.\n", + " residual: Indicates whether and how model should work as residual\n", + " model. Accepted values are 'climate', 'temporal' and 'none'\n", + " positional_encoding: possible values are\n", + " ['absolute' (default), 'fourier'].\n", + " 'absolute' lat lon encoded in 3 dimensions using sine and\n", + " cosine\n", + " 'fourier' lat/lon to be encoded using various frequencies\n", + " masking_mode: String ['local', 'global', 'both'] that controls the\n", + " type of masking used.\n", + " checkpoint_encoder: List of integers controlling if gradient\n", + " checkpointing is used on encoder.\n", + " Format: [] for no gradient checkpointing. 
[3, 7] for\n", + " checkpointing after 4th and 8th layer etc.\n", + " checkpoint_decoder: List of integers controlling if gradient\n", + " checkpointing is used on decoder.\n", + " Format: See `checkpoint_encoder`.\n", + " masking_mode: The type of masking to use\n", + " {'global', 'local', 'both'}\n", + " decoder_shifting: Whether to use swin shifting in the decoder.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.in_channels = in_channels\n", + " self.input_size_time = input_size_time\n", + " self.in_channels_static = in_channels_static\n", + " self.n_lats_px = n_lats_px\n", + " self.n_lons_px = n_lons_px\n", + " self.patch_size_px = patch_size_px\n", + " self.mask_unit_size_px = mask_unit_size_px\n", + " self.mask_ratio_inputs = mask_ratio_inputs\n", + " self.embed_dim = embed_dim\n", + " self.n_blocks_encoder = n_blocks_encoder\n", + " self.n_blocks_decoder = n_blocks_decoder\n", + " self.mlp_multiplier = mlp_multiplier\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + " self.drop_path = drop_path\n", + " self.residual = residual\n", + " self._decoder_shift = decoder_shifting\n", + " self.positional_encoding = positional_encoding\n", + " self._checkpoint_encoder = checkpoint_encoder\n", + " self._checkpoint_decoder = checkpoint_decoder\n", + "\n", + " assert self.n_lats_px % self.mask_unit_size_px[0] == 0\n", + " assert self.n_lons_px % self.mask_unit_size_px[1] == 0\n", + " assert self.mask_unit_size_px[0] % self.patch_size_px[0] == 0\n", + " assert self.mask_unit_size_px[1] % self.patch_size_px[1] == 0\n", + "\n", + " if self.patch_size_px[0] != self.patch_size_px[1]:\n", + " raise NotImplementedError(\n", + " \"Current pixel shuffle symmetric patches.\"\n", + " )\n", + "\n", + " self.local_shape_mu = (\n", + " self.mask_unit_size_px[0] // self.patch_size_px[0],\n", + " self.mask_unit_size_px[1] // self.patch_size_px[1],\n", + " )\n", + " self.global_shape_mu = (\n", + " self.n_lats_px // self.mask_unit_size_px[0],\n", + " 
self.n_lons_px // self.mask_unit_size_px[1],\n", + " )\n", + "\n", + " assert input_scalers_mu.shape == (in_channels,)\n", + " assert input_scalers_sigma.shape == (in_channels,)\n", + " assert output_scalers.shape == (in_channels,)\n", + "\n", + " if self.positional_encoding != \"fourier\":\n", + " assert static_input_scalers_mu.shape == (in_channels_static,)\n", + " assert static_input_scalers_sigma.shape == (in_channels_static,)\n", + "\n", + " # Input shape [batch, time, parameter, lat, lon]\n", + " self.input_scalers_epsilon = input_scalers_epsilon\n", + " self.register_buffer(\n", + " \"input_scalers_mu\", input_scalers_mu.reshape(1, 1, -1, 1, 1)\n", + " )\n", + " self.register_buffer(\n", + " \"input_scalers_sigma\", input_scalers_sigma.reshape(1, 1, -1, 1, 1)\n", + " )\n", + "\n", + " # Static inputs shape [batch, parameter, lat, lon]\n", + " self.static_input_scalers_epsilon = static_input_scalers_epsilon\n", + " self.register_buffer(\n", + " \"static_input_scalers_mu\",\n", + " static_input_scalers_mu.reshape(1, -1, 1, 1),\n", + " )\n", + " self.register_buffer(\n", + " \"static_input_scalers_sigma\",\n", + " static_input_scalers_sigma.reshape(1, -1, 1, 1),\n", + " )\n", + "\n", + " # Output shape [batch, parameter, lat, lon]\n", + " self.register_buffer(\n", + " \"output_scalers\", output_scalers.reshape(1, -1, 1, 1)\n", + " )\n", + "\n", + " self.parameter_dropout = nn.Dropout2d(p=parameter_dropout)\n", + "\n", + " self.patch_embedding = PatchEmbed(\n", + " patch_size=patch_size_px,\n", + " channels=in_channels * input_size_time,\n", + " embed_dim=embed_dim,\n", + " )\n", + "\n", + " if self.residual == \"climate\":\n", + " self.patch_embedding_static = PatchEmbed(\n", + " patch_size=patch_size_px,\n", + " channels=in_channels + in_channels_static,\n", + " embed_dim=embed_dim,\n", + " )\n", + " else:\n", + " self.patch_embedding_static = PatchEmbed(\n", + " patch_size=patch_size_px,\n", + " channels=in_channels_static,\n", + " embed_dim=embed_dim,\n", + 
" )\n", + "\n", + " self.input_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)\n", + " self.lead_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)\n", + "\n", + " self.mask_token = nn.Parameter(torch.randn(1, 1, 1, self.embed_dim))\n", + " self._nglobal_mu = np.prod(self.global_shape_mu)\n", + " self._global_idx = torch.arange(self._nglobal_mu)\n", + "\n", + " self._nlocal_mu = np.prod(self.local_shape_mu)\n", + " self._local_idx = torch.arange(self._nlocal_mu)\n", + "\n", + " self.encoder = PrithviWxCEncoderDecoder(\n", + " embed_dim=embed_dim,\n", + " n_blocks=n_blocks_encoder,\n", + " mlp_multiplier=mlp_multiplier,\n", + " n_heads=n_heads,\n", + " dropout=dropout,\n", + " drop_path=drop_path,\n", + " transformer_cp=checkpoint_encoder,\n", + " )\n", + "\n", + " if n_blocks_decoder != 0:\n", + " if self._decoder_shift:\n", + " self.decoder_shifter = d_shifter = SWINShift(\n", + " self.mask_unit_size_px,\n", + " self.global_shape_mu,\n", + " self.local_shape_mu,\n", + " self.patch_size_px,\n", + " n_context_tokens=0,\n", + " )\n", + " else:\n", + " self.decoder_shifter = d_shifter = None\n", + "\n", + " self.decoder = PrithviWxCEncoderDecoder(\n", + " embed_dim=embed_dim,\n", + " n_blocks=n_blocks_decoder,\n", + " mlp_multiplier=mlp_multiplier,\n", + " n_heads=n_heads,\n", + " dropout=dropout,\n", + " drop_path=0.0,\n", + " shifter=d_shifter,\n", + " transformer_cp=checkpoint_decoder,\n", + " )\n", + "\n", + " self.unembed = nn.Linear(\n", + " self.embed_dim,\n", + " self.in_channels\n", + " * self.patch_size_px[0]\n", + " * self.patch_size_px[1],\n", + " bias=True,\n", + " )\n", + "\n", + " self.masking_mode = masking_mode.lower()\n", + " match self.masking_mode:\n", + " case \"local\":\n", + " self.generate_mask = self._gen_mask_local\n", + " case \"global\":\n", + " self.generate_mask = self._gen_mask_global\n", + " case \"both\":\n", + " self._mask_both_local: bool = True\n", + " self.generate_mask = self._gen_mask_both\n", + " case _:\n", + " 
raise ValueError(\n", + " f\"Masking mode '{masking_mode}' not supported\"\n", + " )\n", + "\n", + " def swap_masking(self) -> None:\n", + " self._mask_both_local = not self._mask_both_local\n", + "\n", + " @cached_property\n", + " def n_masked_global(self):\n", + " return int(self.mask_ratio_inputs * np.prod(self.global_shape_mu))\n", + "\n", + " @cached_property\n", + " def n_masked_local(self):\n", + " return int(self.mask_ratio_inputs * np.prod(self.local_shape_mu))\n", + "\n", + " @staticmethod\n", + " def _shuffle_along_axis(a, axis):\n", + " idx = torch.argsort(input=torch.rand(*a.shape), dim=axis)\n", + " return torch.gather(a, dim=axis, index=idx)\n", + "\n", + " def _gen_mask_local(self, sizes: tuple[int]) -> tuple[Tensor]:\n", + " \"\"\"\n", + " Args:\n", + " batch_size: Number of elements in batch\n", + " Returns:\n", + " Tuple of torch tensors. [indices masked, indices unmasked].\n", + " Each of these is a tensor of shape (batch, global sequene)\n", + " \"\"\"\n", + " # Identify which indices (values) should be masked\n", + "\n", + " maskable_indices = self._local_idx.view(1, -1).expand(*sizes[:2], -1)\n", + "\n", + " maskable_indices = self._shuffle_along_axis(maskable_indices, 2)\n", + "\n", + " indices_masked = maskable_indices[:, :, : self.n_masked_local]\n", + " indices_unmasked = maskable_indices[:, :, self.n_masked_local :]\n", + "\n", + " return indices_masked, indices_unmasked\n", + "\n", + " def _gen_mask_global(self, sizes: tuple[int]) -> tuple[Tensor]:\n", + " \"\"\"\n", + " Args:\n", + " batch_size: Number of elements in batch\n", + " Returns:\n", + " Tuple of torch tensors. 
[indices masked, indices unmasked].\n", + " Each of these is a tensor of shape (batch, global sequene)\n", + " \"\"\"\n", + " # Identify which indices (values) should be masked\n", + "\n", + " maskable_indices = self._global_idx.view(1, -1).expand(*sizes[:1], -1)\n", + "\n", + " maskable_indices = self._shuffle_along_axis(maskable_indices, 1)\n", + "\n", + " indices_masked = maskable_indices[:, : self.n_masked_global]\n", + " indices_unmasked = maskable_indices[:, self.n_masked_global :]\n", + "\n", + " return indices_masked, indices_unmasked\n", + "\n", + " def _gen_mask_both(self, sizes: tuple[int]) -> tuple[Tensor]:\n", + " if self._mask_both_local:\n", + " return self._gen_mask_local(sizes)\n", + " else:\n", + " return self._gen_mask_global(sizes)\n", + "\n", + " @staticmethod\n", + " def reconstruct_batch(\n", + " idx_masked: Tensor,\n", + " idx_unmasked: Tensor,\n", + " data_masked: Tensor,\n", + " data_unmasked: Tensor,\n", + " ) -> Tensor:\n", + " \"\"\"Reconstructs a tensor along the mask unit dimension. Batched\n", + " version.\n", + "\n", + " Args:\n", + " idx_masked: Tensor of shape `batch, mask unit sequence`.\n", + " idx_unmasked: Tensor of shape `batch, mask unit sequence`.\n", + " data_masked: Tensor of shape `batch, mask unit sequence, ...`.\n", + " Should have same size along mask unit sequence dimension as\n", + " idx_masked. Dimensions beyond the first two, marked here as ...\n", + " will typically be `local_sequence, channel` or\n", + " `channel, lat, lon`. These dimensions should agree with\n", + " data_unmasked.\n", + " data_unmasked: Tensor of shape `batch, mask unit sequence, ...`.\n", + " Should have same size along mask unit sequence dimension as\n", + " idx_unmasked. Dimensions beyond the first two, marked here as\n", + " ... will typically be `local_sequence, channel` or `channel,\n", + " lat, lon`. 
These dimensions should agree with data_masked.\n", + " Returns:\n", + " Tensor: Tensor of same shape as inputs data_masked and\n", + " data_unmasked. I.e. `batch, mask unit sequence, ...`. Index for\n", + " the total data composed of the masked and the unmasked part.\n", + " \"\"\"\n", + " dim: int = idx_masked.ndim\n", + "\n", + " idx_total = torch.argsort(\n", + " torch.cat([idx_masked, idx_unmasked], dim=-1), dim=-1\n", + " )\n", + " idx_total = idx_total.view(\n", + " *idx_total.shape, *[1] * (data_unmasked.ndim - dim)\n", + " )\n", + " idx_total = idx_total.expand(\n", + " *idx_total.shape[:dim], *data_unmasked.shape[dim:]\n", + " )\n", + "\n", + " data = torch.cat([data_masked, data_unmasked], dim=dim - 1)\n", + " data = torch.gather(data, dim=dim - 1, index=idx_total)\n", + "\n", + " return data, idx_total\n", + "\n", + " def fourier_pos_encoding(self, x_static: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Args\n", + " x_static: B x C x H x W. first two channels are lat, and lon\n", + " Returns\n", + " Tensor: Tensor of shape B x E x H x W where E is the embedding\n", + " dimension.\n", + " \"\"\"\n", + "\n", + " # B x C x H x W -> B x 1 x H/P x W/P\n", + " latitudes_patch = F.avg_pool2d(\n", + " x_static[:, [0]],\n", + " kernel_size=self.patch_size_px,\n", + " stride=self.patch_size_px,\n", + " )\n", + " longitudes_patch = F.avg_pool2d(\n", + " x_static[:, [1]],\n", + " kernel_size=self.patch_size_px,\n", + " stride=self.patch_size_px,\n", + " )\n", + "\n", + " modes = (\n", + " torch.arange(self.embed_dim // 4, device=x_static.device).view(\n", + " 1, -1, 1, 1\n", + " )\n", + " + 1.0\n", + " )\n", + " pos_encoding = torch.cat(\n", + " (\n", + " torch.sin(latitudes_patch * modes),\n", + " torch.sin(longitudes_patch * modes),\n", + " torch.cos(latitudes_patch * modes),\n", + " torch.cos(longitudes_patch * modes),\n", + " ),\n", + " axis=1,\n", + " )\n", + "\n", + " return pos_encoding # B x E x H/P x W/P\n", + "\n", + " def time_encoding(self, input_time, 
lead_time):\n", + " \"\"\"\n", + " Args:\n", + " input_time: Tensor of shape [batch].\n", + " lead_time: Tensor of shape [batch].\n", + " Returns:\n", + " Tensor: Tensor of shape [batch, embed_dim, 1, 1]\n", + " \"\"\"\n", + " input_time = self.input_time_embedding(input_time.view(-1, 1, 1, 1))\n", + " lead_time = self.lead_time_embedding(lead_time.view(-1, 1, 1, 1))\n", + "\n", + " time_encoding = torch.cat(\n", + " (\n", + " torch.cos(input_time),\n", + " torch.cos(lead_time),\n", + " torch.sin(input_time),\n", + " torch.sin(lead_time),\n", + " ),\n", + " axis=3,\n", + " )\n", + " return time_encoding\n", + "\n", + " def to_patching(self, x: Tensor) -> Tensor:\n", + " \"\"\"Transform data from lat/lon space to two axis patching\n", + "\n", + " Args: ->\n", + " x: Tesnor in lat/lon space (N, C, Nlat//P_0, Nlon//P_1)\n", + "\n", + " Returns:\n", + " Tensor in patch space (N, G, L, C)\n", + " \"\"\"\n", + " n_batch = x.shape[0]\n", + "\n", + " x = x.view(\n", + " n_batch,\n", + " -1,\n", + " self.global_shape_mu[0],\n", + " self.local_shape_mu[0],\n", + " self.global_shape_mu[1],\n", + " self.local_shape_mu[1],\n", + " )\n", + " x = x.permute(0, 2, 4, 3, 5, 1).contiguous()\n", + "\n", + " s = x.shape\n", + " return x.view(n_batch, s[1] * s[2], s[3] * s[4], -1)\n", + "\n", + " def from_patching(self, x: Tensor) -> Tensor:\n", + " \"\"\"Transform data from two axis patching to lat/lon space\n", + "\n", + " Args:\n", + " x: Tensor in patch space with shape (N, G, L, C*P_0*P_1)\n", + "\n", + " Returns:\n", + " Tensor: Tensor in lat/lon space\n", + " (N, C*P_0*P_1, Nlat//P_0, Nlon // P_1)\n", + " \"\"\"\n", + " n_batch = x.shape[0]\n", + "\n", + " x = x.view(\n", + " n_batch,\n", + " self.global_shape_mu[0],\n", + " self.global_shape_mu[1],\n", + " self.local_shape_mu[0],\n", + " self.local_shape_mu[1],\n", + " -1,\n", + " )\n", + " x = x.permute(0, 5, 1, 3, 2, 4).contiguous()\n", + "\n", + " s = x.shape\n", + " return x.view(n_batch, -1, s[2] * s[3], s[4] * s[5])\n", + 
"\n", + " def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " batch: Dictionary the following keys::\n", + "\n", + " 'x': Tensor of shape [batch, time, parameter, lat, lon]\n", + " 'y': Tensor of shape [batch, parameter, lat, lon]\n", + " 'static': Tensor of shape [batch, channel_static, lat, lon]\n", + " 'climate': Optional tensor of shape [batch, parameter, lat, lon]\n", + " 'input_time': Tensor of shape [batch]. Or none.\n", + " 'lead_time': Tensor of shape [batch]. Or none.\n", + "\n", + " Returns:\n", + " Tensor: Tensor of shape [batch, parameter, lat, lon].\n", + " \"\"\" # noqa: E501\n", + " x_rescaled = (batch[\"x\"] - self.input_scalers_mu) / (\n", + " self.input_scalers_sigma + self.input_scalers_epsilon\n", + " )\n", + " batch_size = x_rescaled.shape[0]\n", + "\n", + " if self.positional_encoding == \"fourier\":\n", + " x_static_pos = self.fourier_pos_encoding(batch[\"static\"])\n", + " x_static = (\n", + " batch[\"static\"][:, 2:] - self.static_input_scalers_mu[:, 3:]\n", + " ) / (\n", + " self.static_input_scalers_sigma[:, 3:]\n", + " + self.static_input_scalers_epsilon\n", + " )\n", + " else:\n", + " x_static = (batch[\"static\"] - self.static_input_scalers_mu) / (\n", + " self.static_input_scalers_sigma\n", + " + self.static_input_scalers_epsilon\n", + " )\n", + "\n", + " if self.residual == \"temporal\":\n", + " # We create a residual of same shape as y\n", + " index = torch.where(\n", + " batch[\"lead_time\"] > 0, batch[\"x\"].shape[1] - 1, 0\n", + " )\n", + " index = index.view(-1, 1, 1, 1, 1)\n", + " index = index.expand(batch_size, 1, *batch[\"x\"].shape[2:])\n", + " x_hat = torch.gather(batch[\"x\"], dim=1, index=index)\n", + " x_hat = x_hat.squeeze(1)\n", + " elif self.residual == \"climate\":\n", + " climate_scaled = (\n", + " batch[\"climate\"] - self.input_scalers_mu.view(1, -1, 1, 1)\n", + " ) / (\n", + " self.input_scalers_sigma.view(1, -1, 1, 1)\n", + " + self.input_scalers_epsilon\n", 
+ " )\n", + "\n", + " # [batch, time, parameter, lat, lon]\n", + " # -> [batch, time x parameter, lat, lon]\n", + " x_rescaled = x_rescaled.flatten(1, 2)\n", + " # Parameter dropout\n", + " x_rescaled = self.parameter_dropout(x_rescaled)\n", + "\n", + " x_embedded = self.patch_embedding(x_rescaled)\n", + "\n", + " if self.residual == \"climate\":\n", + " static_embedded = self.patch_embedding_static(\n", + " torch.cat((x_static, climate_scaled), dim=1)\n", + " )\n", + " else:\n", + " static_embedded = self.patch_embedding_static(x_static)\n", + "\n", + " if self.positional_encoding == \"fourier\":\n", + " static_embedded += x_static_pos\n", + "\n", + " x_embedded = self.to_patching(x_embedded)\n", + " static_embedded = self.to_patching(static_embedded)\n", + "\n", + " time_encoding = self.time_encoding(\n", + " batch[\"input_time\"], batch[\"lead_time\"]\n", + " )\n", + "\n", + " tokens = x_embedded + static_embedded + time_encoding\n", + "\n", + " # Now we generate masks based on masking_mode\n", + " indices_masked, indices_unmasked = self.generate_mask(\n", + " (batch_size, self._nglobal_mu)\n", + " )\n", + " indices_masked = indices_masked.to(device=tokens.device)\n", + " indices_unmasked = indices_unmasked.to(device=tokens.device)\n", + " maskdim: int = indices_masked.ndim\n", + "\n", + " # Unmasking\n", + " unmask_view = (*indices_unmasked.shape, *[1] * (tokens.ndim - maskdim))\n", + " unmasked = torch.gather(\n", + " tokens,\n", + " dim=maskdim - 1,\n", + " index=indices_unmasked.view(*unmask_view).expand(\n", + " *indices_unmasked.shape, *tokens.shape[maskdim:]\n", + " ),\n", + " )\n", + "\n", + " # Encoder\n", + " x_encoded = self.encoder(unmasked)\n", + "\n", + " # Generate and position encode the mask tokens\n", + " # [1, 1, 1, embed_dim]\n", + " # -> [batch, global_seq_masked, local seq, embed_dim]\n", + " mask_view = (*indices_masked.shape, *[1] * (tokens.ndim - maskdim))\n", + " masking = self.mask_token.repeat(*static_embedded.shape[:3], 1)\n", + " 
masked = masking + static_embedded\n", + " masked = torch.gather(\n", + " masked,\n", + " dim=maskdim - 1,\n", + " index=indices_masked.view(*mask_view).expand(\n", + " *indices_masked.shape, *tokens.shape[maskdim:]\n", + " ),\n", + " )\n", + "\n", + " recon, _ = self.reconstruct_batch(\n", + " indices_masked, indices_unmasked, masked, x_encoded\n", + " )\n", + "\n", + " x_decoded = self.decoder(recon)\n", + "\n", + " # Output: [batch, global sequence, local sequence,\n", + " # in_channels * patch_size[0] * patch_size[1]]\n", + " x_unembed = self.unembed(x_decoded)\n", + "\n", + " # Reshape to [batch, global_lat, global_lon, local_lat, local_lon,\n", + " # in_channels * patch_size[0] * patch_size[1]]\n", + " x_out = self.from_patching(x_unembed)\n", + "\n", + " # Pixel shuffle to [batch, in_channels, lat, lon]\n", + " x_out = F.pixel_shuffle(x_out, self.patch_size_px[0])\n", + "\n", + " if self.residual == \"temporal\":\n", + " x_out = self.output_scalers * x_out + x_hat\n", + " elif self.residual == \"climate\":\n", + " x_out = self.output_scalers * x_out + batch[\"climate\"]\n", + " elif self.residual == \"none\":\n", + " x_out = (\n", + " self.output_scalers * x_out\n", + " + self.input_scalers_mu.reshape(1, -1, 1, 1)\n", + " )\n", + "\n", + " return x_out\n" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "import yaml\n", + "\n", + "# from PrithviWxC.model import PrithviWxC\n", + "\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=\"config.yaml\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "with open(\"./config.yaml\", \"r\") as f:\n", + " config = yaml.safe_load(f)\n", + "\n", + "model = PrithviWxC(\n", + " in_channels=config[\"params\"][\"in_channels\"],\n", + " input_size_time=config[\"params\"][\"input_size_time\"],\n", + " in_channels_static=config[\"params\"][\"in_channels_static\"],\n", + " input_scalers_mu=in_mu,\n", + " 
input_scalers_sigma=in_sig,\n", + " input_scalers_epsilon=config[\"params\"][\"input_scalers_epsilon\"],\n", + " static_input_scalers_mu=static_mu,\n", + " static_input_scalers_sigma=static_sig,\n", + " static_input_scalers_epsilon=config[\"params\"][\n", + " \"static_input_scalers_epsilon\"\n", + " ],\n", + " output_scalers=output_sig**0.5,\n", + " n_lats_px=config[\"params\"][\"n_lats_px\"],\n", + " n_lons_px=config[\"params\"][\"n_lons_px\"],\n", + " patch_size_px=config[\"params\"][\"patch_size_px\"],\n", + " mask_unit_size_px=config[\"params\"][\"mask_unit_size_px\"],\n", + " mask_ratio_inputs=masking_ratio,\n", + " embed_dim=config[\"params\"][\"embed_dim\"],\n", + " n_blocks_encoder=config[\"params\"][\"n_blocks_encoder\"],\n", + " n_blocks_decoder=config[\"params\"][\"n_blocks_decoder\"],\n", + " mlp_multiplier=config[\"params\"][\"mlp_multiplier\"],\n", + " n_heads=config[\"params\"][\"n_heads\"],\n", + " dropout=config[\"params\"][\"dropout\"],\n", + " drop_path=config[\"params\"][\"drop_path\"],\n", + " parameter_dropout=config[\"params\"][\"parameter_dropout\"],\n", + " residual=residual,\n", + " masking_mode=masking_mode,\n", + " decoder_shifting=decoder_shifting,\n", + " positional_encoding=positional_encoding,\n", + " checkpoint_encoder=[],\n", + " checkpoint_decoder=[],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Load weights\n", + "We have provided unshared pretrained weights for the model,\n", + "which can now be loaded. The model can then be transferred\n", + "to the requested device." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "weights_path = Path(\"./weights/prithvi.wxc.2300m.v1.pt\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=weights_path.name,\n", + " local_dir=\"./weights\",\n", + ")\n", + "\n", + "state_dict = torch.load(weights_path, weights_only=False)\n", + "if \"model_state\" in state_dict:\n", + " state_dict = state_dict[\"model_state\"]\n", + "model.load_state_dict(state_dict, strict=True)\n", + "\n", + "if (hasattr(model, \"device\") and model.device != device) or not hasattr(\n", + " model, \"device\"\n", + "):\n", + " model = model.to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Inference\n", + "We are now ready to perform inference on the model. The data\n", + "returned from the dataset class requires additional\n", + "preprocessing; therefore, after polling the dataset, we process\n", + "the data using the `preproc` function. This processed data is\n", + "then transferred to the device. To recover the masking, we can\n", + "save the torch RNG state and use it later. Finally, we run the\n", + "model in evaluation mode without generating the gradient graph." + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "# from PrithviWxC.dataloaders.merra2 import preproc\n", + "\n", + "data = next(iter(dataset))\n", + "batch = preproc([data], padding)\n", + "\n", + "for k, v in batch.items():\n", + " if isinstance(v, torch.Tensor):\n", + " batch[k] = v.to(device)\n", + "\n", + "rng_state_1 = torch.get_rng_state()\n", + "with torch.no_grad():\n", + " model.eval()\n", + " out = model(batch)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plotting" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "t2m = out[0, 12].cpu().numpy()\n", + "\n", + "lat = np.linspace(-90, 90, out.shape[-2])\n", + "lon = np.linspace(-180, 180, out.shape[-1])\n", + "X, Y = np.meshgrid(lon, lat)\n", + "\n", + "plt.contourf(X, Y, t2m, 100)\n", + "plt.gca().set_aspect(\"equal\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/PrithviWxC_rollout.ipynb b/examples/PrithviWxC_rollout.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ba116a7a03039a94a63245fd2458bd93ca692bdb --- /dev/null +++ b/examples/PrithviWxC_rollout.ipynb @@ -0,0 +1,3670 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# PrithviWxC Rollout Inference\n", + "If you haven't already, take a look at the example for the PrithviWxC core\n", + "model, as we will pass over the points covered there.\n", + "\n", + "Here we will introduce the PrithviWxC model that was trained further for\n", + "autoregressive rollout, a common strategy to increase accuracy and stability of\n", + "models when applied to forecasting-type tasks. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "from pathlib import Path\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import torch\n", + "from huggingface_hub import hf_hub_download, snapshot_download\n", + "\n", + "# Set backend etc.\n", + "torch.jit.enable_onednn_fusion(True)\n", + "if torch.cuda.is_available():\n", + " torch.backends.cudnn.benchmark = True\n", + " torch.backends.cudnn.deterministic = True\n", + "\n", + "# Set seeds\n", + "random.seed(42)\n", + "if torch.cuda.is_available():\n", + " torch.cuda.manual_seed(42)\n", + "torch.manual_seed(42)\n", + "np.random.seed(42)\n", + "\n", + "# Set device\n", + "if torch.cuda.is_available():\n", + " device = torch.device(\"cuda\")\n", + "else:\n", + " device = torch.device(\"cpu\")\n", + "\n", + "# Set variables\n", + "surface_vars = [\n", + " \"EFLUX\",\n", + " \"GWETROOT\",\n", + " \"HFLUX\",\n", + " \"LAI\",\n", + " \"LWGAB\",\n", + " \"LWGEM\",\n", + " \"LWTUP\",\n", + " \"PS\",\n", + " \"QV2M\",\n", + " \"SLP\",\n", + " \"SWGNT\",\n", + " \"SWTNT\",\n", + " \"T2M\",\n", + " \"TQI\",\n", + " \"TQL\",\n", + " \"TQV\",\n", + " \"TS\",\n", + " \"U10M\",\n", + " \"V10M\",\n", + " \"Z0M\",\n", + "]\n", + "static_surface_vars = [\"FRACI\", \"FRLAND\", \"FROCEAN\", \"PHIS\"]\n", + "vertical_vars = [\"CLOUD\", \"H\", \"OMEGA\", \"PL\", \"QI\", \"QL\", \"QV\", \"T\", \"U\", \"V\"]\n", + "levels = [\n", + " 34.0,\n", + " 39.0,\n", + " 41.0,\n", + " 43.0,\n", + " 44.0,\n", + " 45.0,\n", + " 48.0,\n", + " 51.0,\n", + " 53.0,\n", + " 56.0,\n", + " 63.0,\n", + " 68.0,\n", + " 71.0,\n", + " 72.0,\n", + "]\n", + "padding = {\"level\": [0, 0], \"lat\": [0, -1], \"lon\": [0, 0]}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Lead time\n", + "When performing auto-regressive rollout, the intermediate steps require the\n", + "static data at those times and---if using 
`residual=climate`---the intermediate\n", + "climatology. We provide a dataloader that extends the MERRA2 loader of the\n", + "core model, adding in these additional terms. Further, it returns target data for\n", + "the intermediate steps if those are required for loss terms. \n", + "\n", + "The `lead_time` flag still sets the target time for the model; however, it is now\n", + "only a single value and must be a positive integer multiple of the `-input_time`. " + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "lead_time = 3 # This variable can be changed to change the task\n", + "input_time = -3 # This variable can be changed to change the task" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data file\n", + "MERRA-2 data is available from 1980 to the present day,\n", + "at 3-hour temporal resolution. The dataloader we have provided\n", + "expects the surface data and vertical data to be saved in\n", + "separate files, and when provided with the directories, will\n", + "search for the relevant data that falls within the provided time range.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "159bec6eee1846d680fe284324094487", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Fetching 1 files: 0%| | 0/1 [00:00 dict[str, Tensor]:\n", + " \"\"\"Preprocessing function for MERRA2 Dataset\n", + "\n", + " Args:\n", + " batch (dict): List of training samples, each sample should be a\n", + " dictionary with the following keys::\n", + "\n", + " 'sur_static': Numpy array of shape (3, lat, lon). 
For each pixel (lat, lon), the first dimension indexes sin(lat), cos(lon), sin(lon).\n", + " 'sur_vals': Torch tensor of shape (parameter, time, lat, lon).\n", + " 'sur_tars': Torch tensor of shape (parameter, time, lat, lon).\n", + " 'ulv_vals': Torch tensor of shape (parameter, level, time, lat, lon).\n", + " 'ulv_tars': Torch tensor of shape (parameter, level, time, lat, lon).\n", + " 'sur_climate': Torch tensor of shape (parameter, lat, lon)\n", + " 'ulv_climate': Torch tensor of shape (parameter, level, lat, lon)\n", + " 'lead_time': Integer.\n", + " 'input_time': Integer.\n", + "\n", + " padding: Dictionary with keys 'level', 'lat', 'lon', each of dim 2.\n", + "\n", + " Returns:\n", + " Dictionary with the following keys::\n", + "\n", + " 'x': [batch, time, parameter, lat, lon]\n", + " 'y': [batch, parameter, lat, lon]\n", + " 'static': [batch, parameter, lat, lon]\n", + " 'lead_time': [batch]\n", + " 'input_time': [batch]\n", + " 'climate (Optional)': [batch, parameter, lat, lon]\n", + "\n", + " Note:\n", + " Here, for x and y, 'parameter' is [surface parameter, upper level,\n", + " parameter x level]. 
Similarly for the static information we have\n", + " [sin(lat), cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod),\n", + " ...].\n", + " \"\"\" # noqa: E501\n", + " b0 = batch[0]\n", + " nbatch = len(batch)\n", + " data_keys = set(b0.keys())\n", + "\n", + " essential_keys = {\n", + " \"sur_static\",\n", + " \"sur_vals\",\n", + " \"sur_tars\",\n", + " \"ulv_vals\",\n", + " \"ulv_tars\",\n", + " \"input_time\",\n", + " \"lead_time\",\n", + " }\n", + "\n", + " climate_keys = {\n", + " \"sur_climate\",\n", + " \"ulv_climate\",\n", + " }\n", + "\n", + " all_keys = essential_keys | climate_keys\n", + "\n", + " if not essential_keys.issubset(data_keys):\n", + " raise ValueError(\"Missing essential keys.\")\n", + "\n", + " if not data_keys.issubset(all_keys):\n", + " raise ValueError(\"Unexpected keys in batch.\")\n", + "\n", + " # Bring all tensors from the batch into a single tensor\n", + " upl_x = torch.empty((nbatch, *b0[\"ulv_vals\"].shape))\n", + " upl_y = torch.empty((nbatch, *b0[\"ulv_tars\"].shape))\n", + "\n", + " sur_x = torch.empty((nbatch, *b0[\"sur_vals\"].shape))\n", + " sur_y = torch.empty((nbatch, *b0[\"sur_tars\"].shape))\n", + "\n", + " sur_sta = torch.empty((nbatch, *b0[\"sur_static\"].shape))\n", + "\n", + " lead_time = torch.empty((nbatch,), dtype=torch.float32)\n", + " input_time = torch.empty((nbatch,), dtype=torch.float32)\n", + "\n", + " for i, rec in enumerate(batch):\n", + " sur_x[i] = rec[\"sur_vals\"]\n", + " sur_y[i] = rec[\"sur_tars\"]\n", + "\n", + " upl_x[i] = rec[\"ulv_vals\"]\n", + " upl_y[i] = rec[\"ulv_tars\"]\n", + "\n", + " sur_sta[i] = rec[\"sur_static\"]\n", + "\n", + " lead_time[i] = rec[\"lead_time\"]\n", + " input_time[i] = rec[\"input_time\"]\n", + "\n", + " return_value = {\n", + " \"lead_time\": lead_time,\n", + " \"input_time\": input_time,\n", + " }\n", + "\n", + " # Reshape (batch, parameter, level, time, lat, lon) ->\n", + " # (batch, time, parameter, level, lat, lon)\n", + " upl_x = upl_x.permute((0, 3, 1, 2, 4, 
5))\n", + " upl_y = upl_y.permute((0, 3, 1, 2, 4, 5))\n", + " # Reshape (batch, parameter, time, lat, lon) ->\n", + " # (batch, time, parameter, lat, lon)\n", + " sur_x = sur_x.permute((0, 2, 1, 3, 4))\n", + " sur_y = sur_y.permute((0, 2, 1, 3, 4))\n", + "\n", + " # Pad\n", + " padding_2d = (*padding[\"lon\"], *padding[\"lat\"])\n", + "\n", + " def pad2d(x):\n", + " return torch.nn.functional.pad(x, padding_2d, mode=\"constant\", value=0)\n", + "\n", + " padding_3d = (*padding[\"lon\"], *padding[\"lat\"], *padding[\"level\"])\n", + "\n", + " def pad3d(x):\n", + " return torch.nn.functional.pad(x, padding_3d, mode=\"constant\", value=0)\n", + "\n", + " sur_x = pad2d(sur_x).contiguous()\n", + " upl_x = pad3d(upl_x).contiguous()\n", + " sur_y = pad2d(sur_y).contiguous()\n", + " upl_y = pad3d(upl_y).contiguous()\n", + " return_value[\"static\"] = pad2d(sur_sta).contiguous()\n", + "\n", + " # Remove time for targets\n", + " upl_y = torch.squeeze(upl_y, 1)\n", + " sur_y = torch.squeeze(sur_y, 1)\n", + "\n", + " # We stack along the combined parameter x level dimension\n", + " return_value[\"x\"] = torch.cat(\n", + " (sur_x, upl_x.view(*upl_x.shape[:2], -1, *upl_x.shape[4:])), dim=2\n", + " )\n", + " return_value[\"y\"] = torch.cat(\n", + " (sur_y, upl_y.view(upl_y.shape[0], -1, *upl_y.shape[3:])), dim=1\n", + " )\n", + "\n", + " if climate_keys.issubset(data_keys):\n", + " sur_climate = torch.empty((nbatch, *b0[\"sur_climate\"].shape))\n", + " ulv_climate = torch.empty((nbatch, *b0[\"ulv_climate\"].shape))\n", + " for i, rec in enumerate(batch):\n", + " sur_climate[i] = rec[\"sur_climate\"]\n", + " ulv_climate[i] = rec[\"ulv_climate\"]\n", + " sur_climate = pad2d(sur_climate)\n", + " ulv_climate = pad3d(ulv_climate)\n", + "\n", + " return_value[\"climate\"] = torch.cat(\n", + " (\n", + " sur_climate,\n", + " ulv_climate.view(nbatch, -1, *ulv_climate.shape[3:]),\n", + " ),\n", + " dim=1,\n", + " )\n", + "\n", + " return return_value\n", + "\n", + "\n", + "def 
input_scalers(\n", + " surf_vars: list[str],\n", + " vert_vars: list[str],\n", + " levels: list[float],\n", + " surf_path: str | Path,\n", + " vert_path: str | Path,\n", + ") -> tuple[Tensor, Tensor]:\n", + " \"\"\"Reads the input scalers\n", + "\n", + " Args:\n", + " surf_vars: surface variables to be used.\n", + " vert_vars: vertical variables to be used.\n", + " levels: MERRA2 levels to use.\n", + " surf_path: path to surface scalers file.\n", + " vert_path: path to vertical level scalers file.\n", + "\n", + " Returns:\n", + " mu (Tensor): mean values\n", + " var (Tensor): varience values\n", + " \"\"\"\n", + " with h5py.File(Path(surf_path), \"r\", libver=\"latest\") as surf_file:\n", + " stats = [x.decode().lower() for x in surf_file[\"statistic\"][()]]\n", + " mu_idx = stats.index(\"mu\")\n", + " sig_idx = stats.index(\"sigma\")\n", + "\n", + " s_mu = torch.tensor([surf_file[k][()][mu_idx] for k in surf_vars])\n", + " s_sig = torch.tensor([surf_file[k][()][sig_idx] for k in surf_vars])\n", + "\n", + " with h5py.File(Path(vert_path), \"r\", libver=\"latest\") as vert_file:\n", + " stats = [x.decode().lower() for x in vert_file[\"statistic\"][()]]\n", + " mu_idx = stats.index(\"mu\")\n", + " sig_idx = stats.index(\"sigma\")\n", + "\n", + " lvl = vert_file[\"lev\"][()]\n", + " l_idx = [np.where(lvl == v)[0].item() for v in levels]\n", + "\n", + " v_mu = np.array([vert_file[k][()][mu_idx, l_idx] for k in vert_vars])\n", + " v_sig = np.array([vert_file[k][()][sig_idx, l_idx] for k in vert_vars])\n", + "\n", + " v_mu = torch.from_numpy(v_mu).view(-1)\n", + " v_sig = torch.from_numpy(v_sig).view(-1)\n", + "\n", + " mu = torch.cat((s_mu, v_mu), dim=0).to(torch.float32)\n", + " sig = torch.cat((s_sig, v_sig), dim=0).to(torch.float32).clamp(1e-4, 1e4)\n", + " return mu, sig\n", + "\n", + "\n", + "def static_input_scalers(\n", + " scalar_path: str | Path, stat_vars: list[str], unscaled_params: int = 7\n", + ") -> tuple[Tensor, Tensor]:\n", + " scalar_path = 
Path(scalar_path)\n", + "\n", + " with h5py.File(scalar_path, \"r\", libver=\"latest\") as scaler_file:\n", + " stats = [x.decode().lower() for x in scaler_file[\"statistic\"][()]]\n", + " mu_idx = stats.index(\"mu\")\n", + " sig_idx = stats.index(\"sigma\")\n", + "\n", + " mu = torch.tensor([scaler_file[k][()][mu_idx] for k in stat_vars])\n", + " sig = torch.tensor([scaler_file[k][()][sig_idx] for k in stat_vars])\n", + "\n", + " z = torch.zeros(unscaled_params, dtype=mu.dtype, device=mu.device)\n", + " o = torch.ones(unscaled_params, dtype=sig.dtype, device=sig.device)\n", + " mu = torch.cat((z, mu), dim=0).to(torch.float32)\n", + " sig = torch.cat((o, sig), dim=0).to(torch.float32)\n", + "\n", + " return mu, sig.clamp(1e-4, 1e4)\n", + "\n", + "\n", + "def output_scalers(\n", + " surf_vars: list[str],\n", + " vert_vars: list[str],\n", + " levels: list[float],\n", + " surf_path: str | Path,\n", + " vert_path: str | Path,\n", + ") -> Tensor:\n", + " surf_path = Path(surf_path)\n", + " vert_path = Path(vert_path)\n", + "\n", + " with h5py.File(surf_path, \"r\", libver=\"latest\") as surf_file:\n", + " svars = torch.tensor([surf_file[k][()] for k in surf_vars])\n", + "\n", + " with h5py.File(vert_path, \"r\", libver=\"latest\") as vert_file:\n", + " lvl = vert_file[\"lev\"][()]\n", + " l_idx = [np.where(lvl == v)[0].item() for v in levels]\n", + " vvars = np.array([vert_file[k][()][l_idx] for k in vert_vars])\n", + " vvars = torch.from_numpy(vvars).view(-1)\n", + "\n", + " var = torch.cat((svars, vvars), dim=0).to(torch.float32).clamp(1e-7, 1e7)\n", + "\n", + " return var\n", + "\n", + "\n", + "class SampleSpec:\n", + " \"\"\"\n", + " A data class to collect the information used to define a sample.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " inputs: tuple[pd.Timestamp, pd.Timestamp],\n", + " lead_time: int,\n", + " target: pd.Timestamp | list[pd.Timestamp],\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " inputs: Tuple of timestamps. 
In ascending order.\n", + " lead_time: Lead time. In hours.\n", + " target: Timestamp of the target. Can be before or after the inputs.\n", + " \"\"\"\n", + " if not inputs[0] < inputs[1]:\n", + " raise ValueError(\n", + " \"Timestamps in `inputs` should be in strictly ascending order.\"\n", + " )\n", + "\n", + " self.inputs = inputs\n", + " self.input_time = (inputs[1] - inputs[0]).total_seconds() / 3600\n", + " self.lead_time = lead_time\n", + " self.target = target\n", + "\n", + " self.times = [*inputs, target]\n", + " self.stat_times = [inputs[-1]]\n", + "\n", + " @property\n", + " def climatology_info(self) -> tuple[int, int]:\n", + " \"\"\"Get the required climatology info.\n", + "\n", + " :return: information required to obtain climatology data. Essentially\n", + " this is the day of the year and hour of the day of the target\n", + " timestamp, with the former restricted to the interval [1, 365].\n", + " :rtype: tuple\n", + " \"\"\"\n", + " return (min(self.target.dayofyear, 365), self.target.hour)\n", + "\n", + " @property\n", + " def year(self) -> int:\n", + " return self.inputs[1].year\n", + "\n", + " @property\n", + " def dayofyear(self) -> int:\n", + " return self.inputs[1].dayofyear\n", + "\n", + " @property\n", + " def hourofday(self) -> int:\n", + " return self.inputs[1].hour\n", + "\n", + " def _info_str(self) -> str:\n", + " iso_8601 = \"%Y-%m-%dT%H:%M:%S\"\n", + "\n", + " return (\n", + " f\"Issue time: {self.inputs[1].strftime(iso_8601)}\\n\"\n", + " f\"Lead time: {self.lead_time} hours ahead\\n\"\n", + " f\"Input delta: {self.input_time} hours\\n\"\n", + " f\"Target time: {self.target.strftime(iso_8601)}\"\n", + " )\n", + "\n", + " @classmethod\n", + " def get(cls, timestamp: pd.Timestamp, dt: int, lead_time: int):\n", + " \"\"\"Given a timestamp and lead time, generates a SampleSpec object\n", + " describing the sample further.\n", + "\n", + " Args:\n", + " timestamp: Timstamp of the sample, Ie this is the larger of the two\n", + " input 
timstamps.\n", + " dt: Time between input samples, in hours.\n", + " lead_time: Lead time. In hours.\n", + "\n", + " Returns:\n", + " SampleSpec\n", + " \"\"\" # noqa: E501\n", + " assert dt > 0, \"dt should be possitive\"\n", + " lt = pd.to_timedelta(lead_time, unit=\"h\")\n", + " dt = pd.to_timedelta(dt, unit=\"h\")\n", + "\n", + " if lead_time >= 0:\n", + " timestamp_target = timestamp + lt\n", + " else:\n", + " timestamp_target = timestamp - dt + lt\n", + "\n", + " spec = cls(\n", + " inputs=(timestamp - dt, timestamp),\n", + " lead_time=lead_time,\n", + " target=timestamp_target,\n", + " )\n", + "\n", + " return spec\n", + "\n", + " def __repr__(self) -> str:\n", + " return self._info_str()\n", + "\n", + " def __str__(self) -> str:\n", + " return self._info_str()\n", + "\n", + "\n", + "class Merra2Dataset(Dataset):\n", + " \"\"\"MERRA2 dataset. The dataset unifies surface and vertical data as well as\n", + " optional climatology.\n", + "\n", + " Samples come in the form of a dictionary. Not all keys support all\n", + " variables, yet the general ordering of dimensions is\n", + " parameter, level, time, lat, lon\n", + "\n", + " Note:\n", + " Data is assumed to be in NetCDF files containing daily data at 3-hourly\n", + " intervals. These follow the naming patterns\n", + " MERRA2_sfc_YYYYMMHH.nc and MERRA_pres_YYYYMMHH.nc and can be located in\n", + " two different locations. Optional climatology data comes from files\n", + " climate_surface_doyDOY_hourHOD.nc and\n", + " climate_vertical_doyDOY_hourHOD.nc.\n", + "\n", + "\n", + " Note:\n", + " `_get_valid_timestamps` assembles a set of all timestamps for which\n", + " there is data (with hourly resolutions). The result is stored in\n", + " `_valid_timestamps`. `_get_valid_climate_timestamps` does the same with\n", + " climatology data and stores it in `_valid_climate_timestamps`.\n", + "\n", + " Based on this information, `samples` generates a list of valid samples,\n", + " stored in `samples`. 
Here the format is::\n", + "\n", + " [\n", + " [\n", + " (timestamp 1, lead time A),\n", + " (timestamp 1, lead time B),\n", + " (timestamp 1, lead time C),\n", + " ],\n", + " [\n", + " (timestamp 2, lead time D),\n", + " (timestamp 2, lead time E),\n", + " ]\n", + " ]\n", + "\n", + " That is, the outer list iterates over timestamps (init times), the\n", + " inner over lead times. Only valid entries are stored.\n", + " \"\"\"\n", + "\n", + " valid_vertical_vars = [\n", + " \"CLOUD\",\n", + " \"H\",\n", + " \"OMEGA\",\n", + " \"PL\",\n", + " \"QI\",\n", + " \"QL\",\n", + " \"QV\",\n", + " \"T\",\n", + " \"U\",\n", + " \"V\",\n", + " ]\n", + " valid_surface_vars = [\n", + " \"EFLUX\",\n", + " \"GWETROOT\",\n", + " \"HFLUX\",\n", + " \"LAI\",\n", + " \"LWGAB\",\n", + " \"LWGEM\",\n", + " \"LWTUP\",\n", + " \"PRECTOT\",\n", + " \"PS\",\n", + " \"QV2M\",\n", + " \"SLP\",\n", + " \"SWGNT\",\n", + " \"SWTNT\",\n", + " \"T2M\",\n", + " \"TQI\",\n", + " \"TQL\",\n", + " \"TQV\",\n", + " \"TS\",\n", + " \"U10M\",\n", + " \"V10M\",\n", + " \"Z0M\",\n", + " ]\n", + " valid_static_surface_vars = [\"FRACI\", \"FRLAND\", \"FROCEAN\", \"PHIS\"]\n", + "\n", + " valid_levels = [\n", + " 34.0,\n", + " 39.0,\n", + " 41.0,\n", + " 43.0,\n", + " 44.0,\n", + " 45.0,\n", + " 48.0,\n", + " 51.0,\n", + " 53.0,\n", + " 56.0,\n", + " 63.0,\n", + " 68.0,\n", + " 71.0,\n", + " 72.0,\n", + " ]\n", + "\n", + " timedelta_input = pd.to_timedelta(3, unit=\"h\")\n", + "\n", + " def __init__(\n", + " self,\n", + " time_range: tuple[str | pd.Timestamp, str | pd.Timestamp],\n", + " lead_times: list[int],\n", + " input_times: list[int],\n", + " data_path_surface: str | Path,\n", + " data_path_vertical: str | Path,\n", + " climatology_path_surface: str | Path | None = None,\n", + " climatology_path_vertical: str | Path | None = None,\n", + " surface_vars: list[str] | None = None,\n", + " static_surface_vars: list[str] | None = None,\n", + " vertical_vars: list[str] | None = None,\n", + " levels: 
list[float] | None = None,\n", + " roll_longitudes: int = 0,\n", + " positional_encoding: str = \"absolute\",\n", + " rtype: type = np.float32,\n", + " dtype: torch.dtype = torch.float32,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " data_path_surface: Location of surface data.\n", + " data_path_vertical: Location of vertical data.\n", + " climatology_path_surface: Location of (optional) surface\n", + " climatology.\n", + " climatology_path_vertical: Location of (optional) vertical\n", + " climatology.\n", + " surface_vars: Surface variables.\n", + " static_surface_vars: Static surface variables.\n", + " vertical_vars: Vertical variables.\n", + " levels: Levels.\n", + " time_range: Used to subset data.\n", + " lead_times: Lead times for generalized forecasting.\n", + " roll_longitudes: Set to non-zero value to roll data by random amount\n", + " along longitude dimension.\n", + " positional_encoding: possible values are\n", + " ['absolute' (default), 'fourier'].\n", + " 'absolute' returns lat lon encoded in 3 dimensions using sine\n", + " and cosine\n", + " 'fourier' returns lat/lon to be encoded by model\n", + " returns lat/lon to be encoded by model\n", + " rtype: numpy data type used during read\n", + " dtype: torch data type of data output\n", + " \"\"\"\n", + "\n", + " self.time_range = (\n", + " pd.to_datetime(time_range[0]),\n", + " pd.to_datetime(time_range[1]),\n", + " )\n", + " self.lead_times = lead_times\n", + " self.input_times = input_times\n", + " self._roll_longitudes = list(range(roll_longitudes + 1))\n", + "\n", + " self._uvars = vertical_vars or self.valid_vertical_vars\n", + " self._level = levels or self.valid_levels\n", + " self._svars = surface_vars or self.valid_surface_vars\n", + " self._sstat = static_surface_vars or self.valid_static_surface_vars\n", + " self._nuvars = len(self._uvars)\n", + " self._nlevel = len(self._level)\n", + " self._nsvars = len(self._svars)\n", + " self._nsstat = len(self._sstat)\n", + "\n", + " self.rtype = 
rtype\n", + " self.dtype = dtype\n", + "\n", + " self.positional_encoding = positional_encoding\n", + "\n", + " self._data_path_surface = Path(data_path_surface)\n", + " self._data_path_vertical = Path(data_path_vertical)\n", + "\n", + " self.dir_exists(self._data_path_surface)\n", + " self.dir_exists(self._data_path_vertical)\n", + "\n", + " self._get_coordinates()\n", + "\n", + " self._climatology_path_surface = Path(climatology_path_surface) or None\n", + " self._climatology_path_vertical = (\n", + " Path(climatology_path_vertical) or None\n", + " )\n", + " self._require_clim = (\n", + " self._climatology_path_surface is not None\n", + " and self._climatology_path_vertical is not None\n", + " )\n", + "\n", + " if self._require_clim:\n", + " self.dir_exists(self._climatology_path_surface)\n", + " self.dir_exists(self._climatology_path_vertical)\n", + " elif (\n", + " climatology_path_surface is None\n", + " and climatology_path_vertical is None\n", + " ):\n", + " self._climatology_path_surface = None\n", + " self._climatology_path_vertical = None\n", + " else:\n", + " raise ValueError(\n", + " \"Either both or neither of\"\n", + " \"`climatology_path_surface` and\"\n", + " \"`climatology_path_vertical` should be None.\"\n", + " )\n", + "\n", + " if not set(self._svars).issubset(set(self.valid_surface_vars)):\n", + " raise ValueError(\"Invalid surface variable.\")\n", + "\n", + " if not set(self._sstat).issubset(set(self.valid_static_surface_vars)):\n", + " raise ValueError(\"Invalid static surface variable.\")\n", + "\n", + " if not set(self._uvars).issubset(set(self.valid_vertical_vars)):\n", + " raise ValueError(\"Inalid vertical variable.\")\n", + "\n", + " if not set(self._level).issubset(set(self.valid_levels)):\n", + " raise ValueError(\"Invalid level.\")\n", + "\n", + " @staticmethod\n", + " def dir_exists(path: Path) -> None:\n", + " if not path.is_dir():\n", + " raise ValueError(f\"Directory {path} does not exist.\")\n", + "\n", + " @property\n", + " def 
upper_shape(self) -> tuple:\n", + " \"\"\"Returns the vertical variables shape\n", + " Returns:\n", + " tuple: vertical variable shape in the following order::\n", + "\n", + " [VAR, LEV, TIME, LAT, LON]\n", + " \"\"\"\n", + " return self._nuvars, self._nlevel, 2, 361, 576\n", + "\n", + " @property\n", + " def surface_shape(self) -> tuple:\n", + " \"\"\"Returns the surface variables shape\n", + "\n", + " Returns:\n", + " tuple: surface shape in the following order::\n", + "\n", + " [VAR, TIME, LAT, LON]\n", + " \"\"\"\n", + " return self._nsvars, 2, 361, 576\n", + "\n", + " def data_file_surface(self, timestamp: pd.Timestamp) -> Path:\n", + " \"\"\"Build the surface data file name based on timestamp\n", + "\n", + " Args:\n", + " timestamp: a timestamp\n", + "\n", + " Returns:\n", + " Path: constructed path\n", + " \"\"\"\n", + " pattern = \"MERRA2_sfc_%Y%m%d.nc\"\n", + " data_file = self._data_path_surface / timestamp.strftime(pattern)\n", + " return data_file\n", + "\n", + " def data_file_vertical(self, timestamp: pd.Timestamp) -> Path:\n", + " \"\"\"Build the vertical data file name based on timestamp\n", + "\n", + " Args:\n", + " timestamp: a timestamp\n", + "\n", + " Returns:\n", + " Path: constructed path\n", + " \"\"\"\n", + " pattern = \"MERRA_pres_%Y%m%d.nc\"\n", + " data_file = self._data_path_vertical / timestamp.strftime(pattern)\n", + " return data_file\n", + "\n", + " def data_file_surface_climate(\n", + " self,\n", + " timestamp: pd.Timestamp | None = None,\n", + " dayofyear: int | None = None,\n", + " hourofday: int | None = None,\n", + " ) -> Path:\n", + " \"\"\"\n", + " Returns the path to a climatology file based either on a timestamp or\n", + " the dayofyear / hourofday combination.\n", + " Args:\n", + " timestamp: A timestamp.\n", + " dayofyear: Day of the year. 1 to 366.\n", + " hourofday: Hour of the day. 
0 to 23.\n", + " Returns:\n", + " Path: Path to climatology file.\n", + " \"\"\"\n", + " if timestamp is not None and (\n", + " (dayofyear is not None) or (hourofday is not None)\n", + " ):\n", + " raise ValueError(\n", + " \"Provide either timestamp or both dayofyear and hourofday.\"\n", + " )\n", + "\n", + " if timestamp is not None:\n", + " dayofyear = min(timestamp.dayofyear, 365)\n", + " hourofday = timestamp.hour\n", + "\n", + " file_name = f\"climate_surface_doy{dayofyear:03}_hour{hourofday:02}.nc\"\n", + " data_file = self._climatology_path_surface / file_name\n", + " return data_file\n", + "\n", + " def data_file_vertical_climate(\n", + " self,\n", + " timestamp: pd.Timestamp | None = None,\n", + " dayofyear: int | None = None,\n", + " hourofday: int | None = None,\n", + " ) -> Path:\n", + " \"\"\"Returns the path to a climatology file based either on a timestamp\n", + " or the dayofyear / hourofday combination.\n", + "\n", + " Args:\n", + " timestamp: A timestamp. dayofyear: Day of the year. 1 to 366.\n", + " hourofday: Hour of the day. 
0 to 23.\n", + " Returns:\n", + " Path: Path to climatology file.\n", + " \"\"\"\n", + " if timestamp is not None and (\n", + " (dayofyear is not None) or (hourofday is not None)\n", + " ):\n", + " raise ValueError(\n", + " \"Provide either timestamp or both dayofyear and hourofday.\"\n", + " )\n", + "\n", + " if timestamp is not None:\n", + " dayofyear = min(timestamp.dayofyear, 365)\n", + " hourofday = timestamp.hour\n", + "\n", + " file_name = f\"climate_vertical_doy{dayofyear:03}_hour{hourofday:02}.nc\"\n", + " data_file = self._climatology_path_vertical / file_name\n", + " return data_file\n", + "\n", + " def _get_coordinates(self) -> None:\n", + " \"\"\"\n", + " Obtains the coordiantes (latitudes and longitudes) from a single data\n", + " file.\n", + " \"\"\"\n", + " timestamp = next(iter(self.valid_timestamps))\n", + "\n", + " file = self.data_file_surface(timestamp)\n", + " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", + " self.lats = lats = handle[\"lat\"][()].astype(self.rtype)\n", + " self.lons = lons = handle[\"lon\"][()].astype(self.rtype)\n", + "\n", + " deg_to_rad = np.pi / 180\n", + " self._embed_lat = np.sin(lats * deg_to_rad).reshape(-1, 1)\n", + "\n", + " self._embed_lon = np.empty((2, 1, len(lons)), dtype=self.rtype)\n", + " self._embed_lon[0, 0] = np.cos(lons * deg_to_rad)\n", + " self._embed_lon[1, 0] = np.sin(lons * deg_to_rad)\n", + "\n", + " @ft.cached_property\n", + " def lats(self) -> np.ndarray:\n", + " timestamp = next(iter(self.valid_timestamps))\n", + "\n", + " file = self.data_file_surface(timestamp)\n", + " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", + " return handle[\"lat\"][()].astype(self.rtype)\n", + "\n", + " @ft.cached_property\n", + " def lons(self) -> np.ndarray:\n", + " timestamp = next(iter(self.valid_timestamps))\n", + "\n", + " file = self.data_file_surface(timestamp)\n", + " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", + " return 
handle[\"lon\"][()].astype(self.rtype)\n", + "\n", + " @ft.cached_property\n", + " def position_signal(self) -> np.ndarray:\n", + " \"\"\"Generates the \"position signal\" that is part of the static\n", + " features.\n", + "\n", + " Returns:\n", + " Tensor: Torch tensor of dimension (parameter, lat, lon) containing\n", + " sin(lat), cos(lon), sin(lon).\n", + " \"\"\"\n", + "\n", + " latitudes, longitudes = np.meshgrid(\n", + " self.lats, self.lons, indexing=\"ij\"\n", + " )\n", + "\n", + " if self.positional_encoding == \"absolute\":\n", + " latitudes = latitudes / 360 * 2.0 * np.pi\n", + " longitudes = longitudes / 360 * 2.0 * np.pi\n", + " sur_static = np.stack(\n", + " [np.sin(latitudes), np.cos(longitudes), np.sin(longitudes)],\n", + " axis=0,\n", + " )\n", + " else:\n", + " sur_static = np.stack([latitudes, longitudes], axis=0)\n", + "\n", + " sur_static = sur_static.astype(self.rtype)\n", + "\n", + " return sur_static\n", + "\n", + " @ft.cached_property\n", + " def valid_timestamps(self) -> set[pd.Timestamp]:\n", + " \"\"\"Generates list of valid timestamps based on available files. 
Only\n", + " timestamps for which both surface and vertical information is available\n", + " are considered valid.\n", + " Returns:\n", + " list: list of timestamps\n", + " \"\"\"\n", + "\n", + " s_glob = self._data_path_surface.glob(\"MERRA2_sfc_????????.nc\")\n", + " s_files = [os.path.basename(f) for f in s_glob]\n", + " v_glob = self._data_path_surface.glob(\"MERRA_pres_????????.nc\")\n", + " v_files = [os.path.basename(f) for f in v_glob]\n", + "\n", + " s_re = re.compile(r\"MERRA2_sfc_(\\d{8}).nc\\Z\")\n", + " v_re = re.compile(r\"MERRA_pres_(\\d{8}).nc\\Z\")\n", + " fmt = \"%Y%m%d\"\n", + "\n", + " s_times = {\n", + " (datetime.strptime(m[1], fmt))\n", + " for f in s_files\n", + " if (m := s_re.match(f))\n", + " }\n", + " v_times = {\n", + " (datetime.strptime(m[1], fmt))\n", + " for f in v_files\n", + " if (m := v_re.match(f))\n", + " }\n", + "\n", + " times = s_times.intersection(v_times)\n", + "\n", + " # Each file contains a day at 3 hour intervals\n", + " times = {\n", + " t + timedelta(hours=i) for i in range(0, 24, 3) for t in times\n", + " }\n", + "\n", + " start_time, end_time = self.time_range\n", + " times = {pd.Timestamp(t) for t in times if start_time <= t <= end_time}\n", + "\n", + " return times\n", + "\n", + " @ft.cached_property\n", + " def valid_climate_timestamps(self) -> set[tuple[int, int]]:\n", + " \"\"\"Generates list of \"timestamps\" (dayofyear, hourofday) for which\n", + " climatology data is present. 
Only instances for which surface and\n", + " vertical data is available are considered valid.\n", + " Returns:\n", + " list: List of tuples describing valid climatology instances.\n", + " \"\"\"\n", + " if not self._require_clim:\n", + " return set()\n", + "\n", + " s_glob = self._climatology_path_surface.glob(\n", + " \"climate_surface_doy???_hour??.nc\"\n", + " )\n", + " s_files = [os.path.basename(f) for f in s_glob]\n", + "\n", + " v_glob = self._climatology_path_vertical.glob(\n", + " \"climate_vertical_doy???_hour??.nc\"\n", + " )\n", + " v_files = [os.path.basename(f) for f in v_glob]\n", + "\n", + " s_re = re.compile(r\"climate_surface_doy(\\d{3})_hour(\\d{2}).nc\\Z\")\n", + " v_re = re.compile(r\"climate_vertical_doy(\\d{3})_hour(\\d{2}).nc\\Z\")\n", + "\n", + " s_times = {\n", + " (int(m[1]), int(m[2])) for f in s_files if (m := s_re.match(f))\n", + " }\n", + " v_times = {\n", + " (int(m[1]), int(m[2])) for f in v_files if (m := v_re.match(f))\n", + " }\n", + "\n", + " times = s_times.intersection(v_times)\n", + "\n", + " return times\n", + "\n", + " def _data_available(self, spec: SampleSpec) -> bool:\n", + " \"\"\"\n", + " Checks whether data is available for a given SampleSpec object. Does so\n", + " using the internal sets with available data previously constructed. 
Not\n", + " by checking the file system.\n", + " Args:\n", + " spec: SampleSpec object as returned by SampleSpec.get\n", + " Returns:\n", + " bool: if data is availability.\n", + " \"\"\"\n", + " valid = set(spec.times).issubset(self.valid_timestamps)\n", + "\n", + " if self._require_clim:\n", + " sci = spec.climatology_info\n", + " ci = set(sci) if isinstance(sci, list) else set([sci]) # noqa: C405\n", + " valid &= ci.issubset(self.valid_climate_timestamps)\n", + "\n", + " return valid\n", + "\n", + " @ft.cached_property\n", + " def samples(self) -> list[tuple[pd.Timestamp, int, int]]:\n", + " \"\"\"\n", + " Generates list of all valid samlpes.\n", + " Returns:\n", + " list: List of tuples (timestamp, input time, lead time).\n", + " \"\"\"\n", + " valid_samples = []\n", + " dts = [(it, lt) for it in self.input_times for lt in self.lead_times]\n", + "\n", + " for timestamp in sorted(self.valid_timestamps):\n", + " timestamp_samples = []\n", + " for it, lt in dts:\n", + " spec = SampleSpec.get(timestamp, -it, lt)\n", + "\n", + " if self._data_available(spec):\n", + " timestamp_samples.append((timestamp, it, lt))\n", + "\n", + " if timestamp_samples:\n", + " valid_samples.append(timestamp_samples)\n", + "\n", + " return valid_samples\n", + "\n", + " def _to_torch(\n", + " self,\n", + " data: dict[str, Tensor | list[Tensor]],\n", + " dtype: torch.dtype = torch.float32,\n", + " ) -> dict[str, Tensor | list[Tensor]]:\n", + " out = {}\n", + " for k, v in data.items():\n", + " if isinstance(v, list):\n", + " out[k] = [torch.from_numpy(x).to(dtype) for x in v]\n", + " else:\n", + " out[k] = torch.from_numpy(v).to(dtype)\n", + "\n", + " return out\n", + "\n", + " def _lat_roll(\n", + " self, data: dict[str, Tensor | list[Tensor]], n: int\n", + " ) -> dict[str, Tensor | list[Tensor]]:\n", + " out = {}\n", + " for k, v in data.items():\n", + " if isinstance(v, list):\n", + " out[k] = [torch.roll(x, shifts=n, dims=-1) for x in v]\n", + " else:\n", + " out[k] = torch.roll(v, 
shifts=n, dims=-1)\n", + "\n", + " return out\n", + "\n", + " def _read_static_data(\n", + " self, file: str | Path, doy: int, hod: int\n", + " ) -> np.ndarray:\n", + " with h5py.File(file, \"r\", libver=\"latest\") as handle:\n", + " lats_surf = handle[\"lat\"]\n", + " lons_surf = handle[\"lon\"]\n", + "\n", + " nll = (len(lats_surf), len(lons_surf))\n", + "\n", + " npos = len(self.position_signal)\n", + " ntime = 4\n", + "\n", + " nstat = npos + ntime + self._nsstat\n", + " data = np.empty((nstat, *nll), dtype=self.rtype)\n", + "\n", + " for i, key in enumerate(self._sstat, start=npos + ntime):\n", + " data[i] = handle[key][()].astype(dtype=self.rtype)\n", + "\n", + " # [possition signal], cos(doy), sin(doy), cos(hod), sin(hod)\n", + " data[0:npos] = self.position_signal\n", + " data[npos + 0] = np.cos(2 * np.pi * doy / 366)\n", + " data[npos + 1] = np.sin(2 * np.pi * doy / 366)\n", + " data[npos + 2] = np.cos(2 * np.pi * hod / 24)\n", + " data[npos + 3] = np.sin(2 * np.pi * hod / 24)\n", + "\n", + " return data\n", + "\n", + " def _read_surface(\n", + " self, tidx: int, nll: tuple[int, int], handle: h5py.File\n", + " ) -> np.ndarray:\n", + " data = np.empty((self._nsvars, *nll), dtype=self.rtype)\n", + "\n", + " for i, key in enumerate(self._svars):\n", + " data[i] = handle[key][tidx][()].astype(dtype=self.rtype)\n", + "\n", + " return data\n", + "\n", + " def _read_levels(\n", + " self, tidx: int, nll: tuple[int, int], handle: h5py.File\n", + " ) -> np.ndarray:\n", + " lvls = handle[\"lev\"][()]\n", + " lidx = self._level_idxs(lvls)\n", + "\n", + " data = np.empty((self._nuvars, self._nlevel, *nll), dtype=self.rtype)\n", + "\n", + " for i, key in enumerate(self._uvars):\n", + " data[i] = handle[key][tidx, lidx][()].astype(dtype=self.rtype)\n", + "\n", + " return np.ascontiguousarray(np.flip(data, axis=1))\n", + "\n", + " def _level_idxs(self, lvls):\n", + " lidx = [np.argwhere(lvls == int(lvl)).item() for lvl in self._level]\n", + " return sorted(lidx)\n", + 
"\n", + " @staticmethod\n", + " def _date_to_tidx(date: datetime | pd.Timestamp, handle: h5py.File) -> int:\n", + " if isinstance(date, pd.Timestamp):\n", + " date = date.to_pydatetime()\n", + "\n", + " time = handle[\"time\"]\n", + "\n", + " t0 = time.attrs[\"begin_time\"][()].item()\n", + " d0 = f\"{time.attrs['begin_date'][()].item()}\"\n", + "\n", + " offset = datetime.strptime(d0, \"%Y%m%d\")\n", + "\n", + " times = [offset + timedelta(minutes=int(t + t0)) for t in time[()]]\n", + " return times.index(date)\n", + "\n", + " def _read_data(\n", + " self, file_pair: tuple[str, str], date: datetime\n", + " ) -> dict[str, np.ndarray]:\n", + " s_file, v_file = file_pair\n", + "\n", + " with h5py.File(s_file, \"r\", libver=\"latest\") as shandle:\n", + " lats_surf = shandle[\"lat\"]\n", + " lons_surf = shandle[\"lon\"]\n", + "\n", + " nll = (len(lats_surf), len(lons_surf))\n", + "\n", + " tidx = self._date_to_tidx(date, shandle)\n", + "\n", + " sdata = self._read_surface(tidx, nll, shandle)\n", + "\n", + " with h5py.File(v_file, \"r\", libver=\"latest\") as vhandle:\n", + " lats_vert = vhandle[\"lat\"]\n", + " lons_vert = vhandle[\"lon\"]\n", + "\n", + " nll = (len(lats_vert), len(lons_vert))\n", + "\n", + " tidx = self._date_to_tidx(date, vhandle)\n", + "\n", + " vdata = self._read_levels(tidx, nll, vhandle)\n", + "\n", + " data = {\"vert\": vdata, \"surf\": sdata}\n", + "\n", + " return data\n", + "\n", + " def _read_climate(\n", + " self, file_pair: tuple[str, str]\n", + " ) -> dict[str, np.ndarray]:\n", + " s_file, v_file = file_pair\n", + "\n", + " with h5py.File(s_file, \"r\", libver=\"latest\") as shandle:\n", + " lats_surf = shandle[\"lat\"]\n", + " lons_surf = shandle[\"lon\"]\n", + "\n", + " nll = (len(lats_surf), len(lons_surf))\n", + "\n", + " sdata = np.empty((self._nsvars, *nll), dtype=self.rtype)\n", + "\n", + " for i, key in enumerate(self._svars):\n", + " sdata[i] = shandle[key][()].astype(dtype=self.rtype)\n", + "\n", + " with h5py.File(v_file, 
\"r\", libver=\"latest\") as vhandle:\n", + " lats_vert = vhandle[\"lat\"]\n", + " lons_vert = vhandle[\"lon\"]\n", + "\n", + " nll = (len(lats_vert), len(lons_vert))\n", + "\n", + " lvls = vhandle[\"lev\"][()]\n", + " lidx = self._level_idxs(lvls)\n", + "\n", + " vdata = np.empty(\n", + " (self._nuvars, self._nlevel, *nll), dtype=self.rtype\n", + " )\n", + "\n", + " for i, key in enumerate(self._uvars):\n", + " vdata[i] = vhandle[key][lidx][()].astype(dtype=self.rtype)\n", + "\n", + " data = {\n", + " \"vert\": np.ascontiguousarray(np.flip(vdata, axis=1)),\n", + " \"surf\": sdata,\n", + " }\n", + "\n", + " return data\n", + "\n", + " def get_data_from_sample_spec(\n", + " self, spec: SampleSpec\n", + " ) -> dict[str, Tensor | int | float]:\n", + " \"\"\"Loads and assembles sample data given a SampleSpec object.\n", + "\n", + " Args:\n", + " spec (SampleSpec): Full details regarding the data to be loaded\n", + " Returns:\n", + " dict: Dictionary with the following keys::\n", + "\n", + " 'sur_static': Torch tensor of shape [parameter, lat, lon]. 
For\n", + " each pixel (lat, lon), the first 7 dimensions index sin(lat),\n", + " cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod).\n", + " Where doy is the day of the year [1, 366] and hod the hour of\n", + " the day [0, 23].\n", + " 'sur_vals': Torch tensor of shape [parameter, time, lat, lon].\n", + " 'sur_tars': Torch tensor of shape [parameter, time, lat, lon].\n", + " 'ulv_vals': Torch tensor of shape [parameter, level, time, lat, lon].\n", + " 'ulv_tars': Torch tensor of shape [parameter, level, time, lat, lon].\n", + " 'sur_climate': Torch tensor of shape [parameter, lat, lon].\n", + " 'ulv_climate': Torch tensor of shape [paramter, level, lat, lon].\n", + " 'lead_time': Float.\n", + " 'input_time': Float.\n", + "\n", + " \"\"\" # noqa: E501\n", + "\n", + " # We assemble the unique timestamps for which we need data.\n", + " vals_required = {*spec.times}\n", + " stat_required = {*spec.stat_times}\n", + "\n", + " # We assemble the unique data files from which we need value data\n", + " vals_file_map = defaultdict(list)\n", + " for t in vals_required:\n", + " data_files = (\n", + " self.data_file_surface(t),\n", + " self.data_file_vertical(t),\n", + " )\n", + " vals_file_map[data_files].append(t)\n", + "\n", + " # We assemble the unique data files from which we need static data\n", + " stat_file_map = defaultdict(list)\n", + " for t in stat_required:\n", + " data_files = (\n", + " self.data_file_surface(t),\n", + " self.data_file_vertical(t),\n", + " )\n", + " stat_file_map[data_files].append(t)\n", + "\n", + " # Load the value data\n", + " data = {}\n", + " for data_files, times in vals_file_map.items():\n", + " for time in times:\n", + " data[time] = self._read_data(data_files, time)\n", + "\n", + " # Combine times\n", + " sample_data = {}\n", + "\n", + " input_upl = np.stack([data[t][\"vert\"] for t in spec.inputs], axis=2)\n", + " sample_data[\"ulv_vals\"] = input_upl\n", + "\n", + " target_upl = data[spec.target][\"vert\"]\n", + " 
sample_data[\"ulv_tars\"] = target_upl[:, :, None]\n", + "\n", + " input_sur = np.stack([data[t][\"surf\"] for t in spec.inputs], axis=1)\n", + " sample_data[\"sur_vals\"] = input_sur\n", + "\n", + " target_sur = data[spec.target][\"surf\"]\n", + " sample_data[\"sur_tars\"] = target_sur[:, None]\n", + "\n", + " # Load the static data\n", + " data_files, times = stat_file_map.popitem()\n", + " time = times[0].dayofyear, times[0].hour\n", + " sample_data[\"sur_static\"] = self._read_static_data(\n", + " data_files[0], *time\n", + " )\n", + "\n", + " # If required load the surface data\n", + " if self._require_clim:\n", + " ci_year, ci_hour = spec.climatology_info\n", + "\n", + " surf_file = self.data_file_surface_climate(\n", + " dayofyear=ci_year,\n", + " hourofday=ci_hour,\n", + " )\n", + "\n", + " vert_file = self.data_file_vertical_climate(\n", + " dayofyear=ci_year,\n", + " hourofday=ci_hour,\n", + " )\n", + "\n", + " clim_data = self._read_climate((surf_file, vert_file))\n", + "\n", + " sample_data[\"sur_climate\"] = clim_data[\"surf\"]\n", + " sample_data[\"ulv_climate\"] = clim_data[\"vert\"]\n", + "\n", + " # Move the data from numpy to torch\n", + " sample_data = self._to_torch(sample_data, dtype=self.dtype)\n", + "\n", + " # Optionally roll\n", + " if len(self._roll_longitudes) > 0:\n", + " roll_by = random.choice(self._roll_longitudes)\n", + " sample_data = self._lat_roll(sample_data, roll_by)\n", + "\n", + " # Now that we have rolled, we can add the static data\n", + " sample_data[\"lead_time\"] = spec.lead_time\n", + " sample_data[\"input_time\"] = spec.input_time\n", + "\n", + " return sample_data\n", + "\n", + " def get_data(\n", + " self, timestamp: pd.Timestamp, input_time: int, lead_time: int\n", + " ) -> dict[str, Tensor | int]:\n", + " \"\"\"\n", + " Loads data based on timestamp and lead time.\n", + " Args:\n", + " timestamp: Timestamp.\n", + " input_time: time between input samples.\n", + " lead_time: lead time.\n", + " Returns:\n", + " 
Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", + " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',\n", + " 'lead_time'.\n", + " \"\"\"\n", + " spec = SampleSpec.get(timestamp, -input_time, lead_time)\n", + " sample_data = self.get_data_from_sample_spec(spec)\n", + " return sample_data\n", + "\n", + " def __getitem__(self, idx: int) -> dict[str, Tensor | int]:\n", + " \"\"\"\n", + " Loads data based on sample index and random choice of sample.\n", + " Args:\n", + " idx: Sample index.\n", + " Returns:\n", + " Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", + " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',\n", + " 'lead_time', 'input_time'.\n", + " \"\"\"\n", + " sample_set = self.samples[idx]\n", + " timestamp, input_time, lead_time, *nsteps = random.choice(sample_set)\n", + " sample_data = self.get_data(timestamp, input_time, lead_time)\n", + " return sample_data\n", + "\n", + " def __len__(self):\n", + " return len(self.samples)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import functools as ft\n", + "import random\n", + "from collections import defaultdict\n", + "from copy import deepcopy\n", + "from pathlib import Path\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "from torch import Tensor\n", + "\n", + "# from PrithviWxC.dataloaders.merra2 import Merra2Dataset, SampleSpec\n", + "\n", + "\n", + "def preproc(\n", + " batch: list[dict[str, int | float | Tensor]], padding: dict[tuple[int]]\n", + ") -> dict[str, Tensor]:\n", + " \"\"\"Prepressing function for MERRA2 Dataset\n", + "\n", + " Args:\n", + " batch (dict): List of training samples, each sample should be a\n", + " dictionary with the following keys::\n", + "\n", + " 'sur_static': Numpy array of shape (3, lat, lon). 
For each pixel (lat, lon), the first dimension indexes sin(lat), cos(lon), sin(lon).\n", + " 'sur_vals': Torch tensor of shape (parameter, time, lat, lon).\n", + " 'sur_tars': Torch tensor of shape (parameter, time, lat, lon).\n", + " 'ulv_vals': Torch tensor of shape (parameter, level, time, lat, lon).\n", + " 'ulv_tars': Torch tensor of shape (parameter, level, time, lat, lon).\n", + " 'sur_climate': Torch tensor of shape (nstep, parameter, lat, lon)\n", + " 'ulv_climate': Torch tensor of shape (nstep, parameter, level, lat, lon)\n", + " 'lead_time': Integer.\n", + " 'input_time': Integer.\n", + "\n", + " padding: Dictionary with keys 'level', 'lat', 'lon', each of dim 2.\n", + "\n", + " Returns:\n", + " Dictionary with the following keys::\n", + "\n", + " 'x': [batch, time, parameter, lat, lon]\n", + " 'ys': [batch, nsteps, parameter, lat, lon]\n", + " 'statics': [batch, nstep, parameter, lat, lon]\n", + " 'lead_time': [batch]\n", + " 'input_time': [batch]\n", + " 'climates (Optional)': [batch, nsteps, parameter, lat, lon]\n", + "\n", + " Note:\n", + " Here, for x and ys, 'parameter' is [surface parameter, upper level,\n", + " parameter x level]. 
Similarly for the static information we have\n", + " [sin(lat), cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod),\n", + " ...].\n", + " \"\"\" # noqa: E501\n", + "\n", + " b0 = batch[0]\n", + " nbatch = len(batch)\n", + " data_keys = set(b0.keys())\n", + "\n", + " essential_keys = {\n", + " \"sur_static\",\n", + " \"sur_vals\",\n", + " \"sur_tars\",\n", + " \"ulv_vals\",\n", + " \"ulv_tars\",\n", + " \"input_time\",\n", + " \"lead_time\",\n", + " }\n", + "\n", + " climate_keys = {\n", + " \"sur_climate\",\n", + " \"ulv_climate\",\n", + " }\n", + "\n", + " all_keys = essential_keys | climate_keys\n", + "\n", + " if not essential_keys.issubset(data_keys):\n", + " raise ValueError(\"Missing essential keys.\")\n", + "\n", + " if not data_keys.issubset(all_keys):\n", + " raise ValueError(\"Unexpected keys in batch.\")\n", + "\n", + " # Bring all tensors from the batch into a single tensor\n", + " upl_x = torch.empty((nbatch, *b0[\"ulv_vals\"].shape))\n", + " upl_y = torch.empty((nbatch, *b0[\"ulv_tars\"].shape))\n", + "\n", + " sur_x = torch.empty((nbatch, *b0[\"sur_vals\"].shape))\n", + " sur_y = torch.empty((nbatch, *b0[\"sur_tars\"].shape))\n", + "\n", + " sur_sta = torch.empty((nbatch, *b0[\"sur_static\"].shape))\n", + "\n", + " lead_time = torch.empty(\n", + " (nbatch, *b0[\"lead_time\"].shape),\n", + " dtype=torch.float32,\n", + " )\n", + " input_time = torch.empty((nbatch,), dtype=torch.float32)\n", + "\n", + " for i, rec in enumerate(batch):\n", + " sur_x[i] = torch.Tensor(rec[\"sur_vals\"])\n", + " sur_y[i] = torch.Tensor(rec[\"sur_tars\"])\n", + "\n", + " upl_x[i] = torch.Tensor(rec[\"ulv_vals\"])\n", + " upl_y[i] = torch.Tensor(rec[\"ulv_tars\"])\n", + "\n", + " sur_sta[i] = torch.Tensor(rec[\"sur_static\"])\n", + "\n", + " lead_time[i] = rec[\"lead_time\"]\n", + " input_time[i] = rec[\"input_time\"]\n", + "\n", + " return_value = {\n", + " \"lead_time\": lead_time,\n", + " \"input_time\": input_time,\n", + " \"target_time\": 
torch.sum(lead_time).reshape(-1),\n", + " }\n", + "\n", + " # Reshape (batch, parameter, level, time, lat, lon)\n", + " # -> (batch, time, parameter, level, lat, lon)\n", + " upl_x = upl_x.permute((0, 3, 1, 2, 4, 5))\n", + " upl_y = upl_y.permute((0, 3, 1, 2, 4, 5))\n", + "\n", + " # Reshape (batch, parameter, time, lat, lon)\n", + " # -> (batch, time, parameter, lat, lon)\n", + " sur_x = sur_x.permute((0, 2, 1, 3, 4))\n", + " sur_y = sur_y.permute((0, 2, 1, 3, 4))\n", + "\n", + " # Pad\n", + " padding_2d = (*padding[\"lon\"], *padding[\"lat\"])\n", + "\n", + " def pad2d(x):\n", + " return torch.nn.functional.pad(x, padding_2d, mode=\"constant\", value=0)\n", + "\n", + " padding_3d = (*padding[\"lon\"], *padding[\"lat\"], *padding[\"level\"])\n", + "\n", + " def pad3d(x):\n", + " return torch.nn.functional.pad(x, padding_3d, mode=\"constant\", value=0)\n", + "\n", + " sur_x = pad2d(sur_x).contiguous()\n", + " upl_x = pad3d(upl_x).contiguous()\n", + " sur_y = pad2d(sur_y).contiguous()\n", + " upl_y = pad3d(upl_y).contiguous()\n", + " return_value[\"statics\"] = pad2d(sur_sta).contiguous()\n", + "\n", + " # We stack along the combined parameter level dimension\n", + " return_value[\"x\"] = torch.cat(\n", + " (sur_x, upl_x.view(*upl_x.shape[:2], -1, *upl_x.shape[4:])), dim=2\n", + " )\n", + " return_value[\"ys\"] = torch.cat(\n", + " (sur_y, upl_y.view(*upl_y.shape[:2], -1, *upl_y.shape[4:])), dim=2\n", + " )\n", + "\n", + " if climate_keys.issubset(data_keys):\n", + " sur_climate = torch.empty((nbatch, *b0[\"sur_climate\"].shape))\n", + " ulv_climate = torch.empty((nbatch, *b0[\"ulv_climate\"].shape))\n", + " for i, rec in enumerate(batch):\n", + " sur_climate[i] = rec[\"sur_climate\"]\n", + " ulv_climate[i] = rec[\"ulv_climate\"]\n", + " sur_climate = pad2d(sur_climate)\n", + " ulv_climate = pad3d(ulv_climate)\n", + "\n", + " ulv_climate = ulv_climate.view(\n", + " *ulv_climate.shape[:2], -1, *ulv_climate.shape[4:]\n", + " )\n", + " return_value[\"climates\"] = 
torch.cat((sur_climate, ulv_climate), dim=2)\n", + "\n", + " return return_value\n", + "\n", + "\n", + "class RolloutSpec(SampleSpec):\n", + " \"\"\"\n", + " A data class to collect the information used to define a rollout sample.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " inputs: tuple[pd.Timestamp, pd.Timestamp],\n", + " lead_time: int,\n", + " target: pd.Timestamp,\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " inputs: Tuple of timestamps. In ascending order.\n", + " lead_time: Lead time. In hours.\n", + " target: Timestamp of the target. Can be before or after the inputs.\n", + " \"\"\"\n", + " super().__init__(inputs, lead_time, target)\n", + "\n", + " self.dt = dt = pd.Timedelta(lead_time, unit=\"h\")\n", + " self.inters = list(pd.date_range(inputs[-1], target, freq=dt))\n", + "\n", + " self._ctimes = deepcopy(self.inters)\n", + " self.stat_times = deepcopy(self.inters)\n", + "\n", + " self.stat_times.pop(-1)\n", + " self._ctimes.pop(0)\n", + " self.inters.pop(0)\n", + " self.inters.pop(-1)\n", + "\n", + " self.times = [*inputs, *self.inters, target]\n", + " self.targets = self.times[2:]\n", + " self.nsteps = len(self.times) - 2\n", + "\n", + " @property\n", + " def climatology_info(self) -> dict[pd.Timestamp, tuple[int, int]]:\n", + " \"\"\"Returns information required to obtain climatology data.\n", + " Returns:\n", + " list: list containing required climatology info.\n", + " \"\"\"\n", + " return [(min(t.dayofyear, 365), t.hour) for t in self._ctimes]\n", + "\n", + " def _info_str(self) -> str:\n", + " iso_8601 = \"%Y-%m-%dT%H:%M:%S\"\n", + "\n", + " inter_str = \"\\n\".join(t.strftime(iso_8601) for t in self.inters)\n", + "\n", + " return (\n", + " f\"Issue time: {self.inputs[1].strftime(iso_8601)}\\n\"\n", + " f\"Lead time: {self.lead_time} hours ahead\\n\"\n", + " f\"Target time: {self.target.strftime(iso_8601)}\\n\"\n", + " f\"Intermediate times: {inter_str}\"\n", + " )\n", + "\n", + " @classmethod\n", + " def get(cls, 
timestamp: pd.Timestamp, lead_time: int, nsteps: int):\n", + " \"\"\"Given a timestamp and lead time, generates a RolloutSpec object\n", + " describing the sample further.\n", + "\n", + " Args:\n", + " timestamp: Timestamp (issue time) of the sample.\n", + " lead_time: Lead time. In hours.\n", + "\n", + " Returns:\n", + " RolloutSpec object.\n", + " \"\"\"\n", + " if lead_time > 0:\n", + " dt = pd.to_timedelta(lead_time, unit=\"h\")\n", + " timestamp_target = timestamp + nsteps * dt\n", + " else:\n", + " raise ValueError(\"Rollout is only forwards\")\n", + "\n", + " spec = cls(\n", + " inputs=(timestamp - dt, timestamp),\n", + " lead_time=lead_time,\n", + " target=timestamp_target,\n", + " )\n", + "\n", + " return spec\n", + "\n", + " def __repr__(self) -> str:\n", + " return self._info_str()\n", + "\n", + " def __str__(self) -> str:\n", + " return self._info_str()\n", + "\n", + "\n", + "class Merra2RolloutDataset(Merra2Dataset):\n", + " \"\"\"Dataset class that reads MERRA2 data for performing rollout.\n", + "\n", + " Implementation details::\n", + "\n", + " Samples stores the list of valid samples. This takes the form\n", + " ```\n", + " [\n", + " [(timestamp 1, -input_time, n_steps)],\n", + " [(timestamp 2, -input_time, n_steps)],\n", + " ]\n", + " ```\n", + " The nested list is for compatibility reasons with Merra2Dataset. Note\n", + " that input time and n_steps are always the same value. 
For some reason\n", + " the sign of input_time is the opposite to that in Merra2Dataset\n", + " \"\"\"\n", + "\n", + " input_time_len = 2\n", + "\n", + " def __init__(\n", + " self,\n", + " time_range: tuple[str | pd.Timestamp, str | pd.Timestamp],\n", + " input_time: int | float | pd.Timedelta,\n", + " lead_time: int | float,\n", + " data_path_surface: str | Path,\n", + " data_path_vertical: str | Path,\n", + " climatology_path_surface: str | Path | None,\n", + " climatology_path_vertical: str | Path | None,\n", + " surface_vars: list[str],\n", + " static_surface_vars: list[str],\n", + " vertical_vars: list[str],\n", + " levels: list[float],\n", + " roll_longitudes: int = 0,\n", + " positional_encoding: str = \"absolute\",\n", + " ):\n", + " \"\"\"\n", + " Args:\n", + " time_range: time range to consider when building dataset\n", + " input_time: requested time between inputs\n", + " lead_time: requested time to predict\n", + " data_path_surface: path of surface data directory\n", + " data_path_vertical: path of vertical data directory\n", + " climatology_path_surface: path of surface climatology data\n", + " directory\n", + " climatology_path_vertical: path of vertical climatology data\n", + " directory\n", + " surface_vars: surface variables to return\n", + " static_surface_vars: static surface variables to return\n", + " vertical_vars: vertical variables to return\n", + " levels: MERRA2 vertical levels to consider\n", + " roll_longitudes: Whether and how much to randomly roll latitudes by.\n", + " Defaults to 0.\n", + " positional_encoding: The type of positional encoding to use.\n", + " Defaults to \"absolute\".\n", + "\n", + " Raises:\n", + " ValueError: If lead time is not integer multiple of input time\n", + " \"\"\"\n", + "\n", + " self._target_lead = lead_time\n", + "\n", + " if isinstance(input_time, int) or isinstance(input_time, float):\n", + " self.timedelta_input = pd.to_timedelta(-input_time, unit=\"h\")\n", + " else:\n", + " self.timedelta_input = 
-input_time\n", + "\n", + " lead_times = [self.timedelta_input / pd.to_timedelta(1, unit=\"h\")]\n", + "\n", + " super().__init__(\n", + " time_range,\n", + " lead_times,\n", + " [input_time],\n", + " data_path_surface,\n", + " data_path_vertical,\n", + " climatology_path_surface,\n", + " climatology_path_vertical,\n", + " surface_vars,\n", + " static_surface_vars,\n", + " vertical_vars,\n", + " levels,\n", + " roll_longitudes,\n", + " positional_encoding,\n", + " )\n", + "\n", + " nstep_float = (\n", + " pd.to_timedelta(self._target_lead, unit=\"h\") / self.timedelta_input\n", + " )\n", + "\n", + " if abs(nstep_float % 1) > 1e-5:\n", + " raise ValueError(\"Leadtime not multiple of input time\")\n", + "\n", + " self.nsteps = round(nstep_float)\n", + "\n", + " @ft.cached_property\n", + " def samples(self) -> list[tuple[pd.Timestamp, int, int]]:\n", + " \"\"\"Generates list of all valid samples.\n", + "\n", + " Returns:\n", + " List of lists of tuples (timestamp, input time, lead time, nsteps).\n", + " \"\"\"\n", + " valid_samples = []\n", + "\n", + " for timestamp in sorted(self.valid_timestamps):\n", + " timestamp_samples = []\n", + " for lt in self.lead_times:\n", + " spec = RolloutSpec.get(timestamp, lt, self.nsteps)\n", + "\n", + " if self._data_available(spec):\n", + " timestamp_samples.append(\n", + " (timestamp, self.input_times[0], lt, self.nsteps)\n", + " )\n", + "\n", + " if timestamp_samples:\n", + " valid_samples.append(timestamp_samples)\n", + "\n", + " return valid_samples\n", + "\n", + " def get_data_from_rollout_spec(\n", + " self, spec: RolloutSpec\n", + " ) -> dict[str, Tensor | int | float]:\n", + " \"\"\"Loads and assembles sample data given a RolloutSpec object.\n", + "\n", + " Args:\n", + " spec (RolloutSpec): Full details regarding the data to be loaded\n", + " Returns:\n", + " dict: Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", + " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate', 'lead_time',\n", + " 'input_time'. 
For each, the value is as follows::\n", + "\n", + " {\n", + " 'sur_static': Torch tensor of shape [parameter, lat, lon]. For\n", + " each pixel (lat, lon), the first 7 dimensions index sin(lat),\n", + " cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod).\n", + " Where doy is the day of the year [1, 366] and hod the hour of\n", + " the day [0, 23].\n", + " 'sur_vals': Torch tensor of shape [parameter, time, lat, lon].\n", + " 'sur_tars': Torch tensor of shape [parameter, time, lat, lon].\n", + " 'ulv_vals': Torch tensor of shape\n", + " [parameter, level, time, lat, lon].\n", + " 'ulv_tars': Torch tensor of shape\n", + " [nsteps, parameter, level, time, lat, lon].\n", + " 'sur_climate': Torch tensor of shape\n", + " [nsteps, parameter, lat, lon].\n", + " 'ulv_climate': Torch tensor of shape\n", + " [nsteps, paramter, level, lat, lon].\n", + " 'lead_time': Float.\n", + " 'input_time': Float.\n", + " }\n", + "\n", + " \"\"\"\n", + "\n", + " # We assemble the unique timestamps for which we need data.\n", + " vals_required = {*spec.times}\n", + " stat_required = {*spec.stat_times}\n", + "\n", + " # We assemble the unique data files from which we need value data\n", + " vals_file_map = defaultdict(list)\n", + " for t in vals_required:\n", + " data_files = (\n", + " self.data_file_surface(t),\n", + " self.data_file_vertical(t),\n", + " )\n", + " vals_file_map[data_files].append(t)\n", + "\n", + " # We assemble the unique data files from which we need static data\n", + " stat_file_map = defaultdict(list)\n", + " for t in stat_required:\n", + " data_files = (\n", + " self.data_file_surface(t),\n", + " self.data_file_vertical(t),\n", + " )\n", + " stat_file_map[data_files].append(t)\n", + "\n", + " # Load the value data\n", + " data = {}\n", + " for data_files, times in vals_file_map.items():\n", + " for time in times:\n", + " data[time] = self._read_data(data_files, time)\n", + "\n", + " # Load the static data\n", + " stat = {}\n", + " for data_files, times in 
stat_file_map.items():\n", + " for time in times:\n", + " hod, doy = time.hour, time.dayofyear\n", + " stat[time] = self._read_static_data(data_files[0], hod, doy)\n", + "\n", + " # Combine times\n", + " sample_data = {}\n", + "\n", + " input_upl = np.stack([data[t][\"vert\"] for t in spec.inputs], axis=2)\n", + " sample_data[\"ulv_vals\"] = input_upl\n", + "\n", + " target_upl = np.stack([data[t][\"vert\"] for t in spec.targets], axis=2)\n", + " sample_data[\"ulv_tars\"] = target_upl\n", + "\n", + " input_sur = np.stack([data[t][\"surf\"] for t in spec.inputs], axis=1)\n", + " sample_data[\"sur_vals\"] = input_sur\n", + "\n", + " target_sur = np.stack([data[t][\"surf\"] for t in spec.targets], axis=1)\n", + " sample_data[\"sur_tars\"] = target_sur\n", + "\n", + " # Load the static data\n", + " static = np.stack([stat[t] for t in spec.stat_times], axis=0)\n", + " sample_data[\"sur_static\"] = static\n", + "\n", + " # If required load the climate data\n", + " if self._require_clim:\n", + " clim_data = {}\n", + " for ci in spec.climatology_info:\n", + " ci_year, ci_hour = ci\n", + "\n", + " surf_file = self.data_file_surface_climate(\n", + " dayofyear=ci_year,\n", + " hourofday=ci_hour,\n", + " )\n", + "\n", + " vert_file = self.data_file_vertical_climate(\n", + " dayofyear=ci_year,\n", + " hourofday=ci_hour,\n", + " )\n", + "\n", + " clim_data[ci] = self._read_climate((surf_file, vert_file))\n", + "\n", + " clim_surf = [clim_data[ci][\"surf\"] for ci in spec.climatology_info]\n", + " sample_data[\"sur_climate\"] = np.stack(clim_surf, axis=0)\n", + "\n", + " clim_surf = [clim_data[ci][\"vert\"] for ci in spec.climatology_info]\n", + " sample_data[\"ulv_climate\"] = np.stack(clim_surf, axis=0)\n", + "\n", + " # Move the data from numpy to torch\n", + " sample_data = self._to_torch(sample_data, dtype=self.dtype)\n", + "\n", + " # Optionally roll\n", + " if len(self._roll_longitudes) > 0:\n", + " roll_by = random.choice(self._roll_longitudes)\n", + " sample_data = 
self._lat_roll(sample_data, roll_by)\n", + "\n", + " # Now that we have rolled, we can add the lead and input times\n", + " lt = torch.tensor([spec.lead_time] * self.nsteps).to(self.dtype)\n", + " sample_data[\"lead_time\"] = lt\n", + " sample_data[\"input_time\"] = spec.input_time\n", + "\n", + " return sample_data\n", + "\n", + " def get_data(\n", + " self, timestamp: pd.Timestamp, *args, **kwargs\n", + " ) -> dict[Tensor | int]:\n", + " \"\"\"Loads data based on timestamp and lead time.\n", + "\n", + " Args:\n", + " timestamp: Timestamp.\n", + " Returns:\n", + " Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',\n", + " 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',\n", + " 'lead_time', 'input_time'\n", + " \"\"\"\n", + " rollout_spec = RolloutSpec.get(\n", + " timestamp, self.lead_times[0], self.nsteps\n", + " )\n", + " sample_data = self.get_data_from_rollout_spec(rollout_spec)\n", + " return sample_data\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "# from PrithviWxC.dataloaders.merra2_rollout import Merra2RolloutDataset\n", + "\n", + "dataset = Merra2RolloutDataset(\n", + " time_range=time_range,\n", + " lead_time=lead_time,\n", + " input_time=input_time,\n", + " data_path_surface=surf_dir,\n", + " data_path_vertical=vert_dir,\n", + " climatology_path_surface=surf_clim_dir,\n", + " climatology_path_vertical=vert_clim_dir,\n", + " surface_vars=surface_vars,\n", + " static_surface_vars=static_surface_vars,\n", + " vertical_vars=vertical_vars,\n", + " levels=levels,\n", + " positional_encoding=positional_encoding,\n", + ")\n", + "assert len(dataset) > 0, \"There doesn't seem to be any valid data.\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model\n", + "### Scalers and other hyperparameters\n", + "Again, this setup is similar to before." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "# from PrithviWxC.dataloaders.merra2 import (\n", + "# input_scalers,\n", + "# output_scalers,\n", + "# static_input_scalers,\n", + "# )\n", + "\n", + "surf_in_scal_path = Path(\"./climatology/musigma_surface.nc\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=f\"climatology/{surf_in_scal_path.name}\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "vert_in_scal_path = Path(\"./climatology/musigma_vertical.nc\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=f\"climatology/{vert_in_scal_path.name}\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "surf_out_scal_path = Path(\"./climatology/anomaly_variance_surface.nc\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=f\"climatology/{surf_out_scal_path.name}\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "vert_out_scal_path = Path(\"./climatology/anomaly_variance_vertical.nc\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.2300m.v1\",\n", + " filename=f\"climatology/{vert_out_scal_path.name}\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.rollout.2300m.v1\",\n", + " filename=\"config.yaml\",\n", + " local_dir=\".\",\n", + ")\n", + "\n", + "in_mu, in_sig = input_scalers(\n", + " surface_vars,\n", + " vertical_vars,\n", + " levels,\n", + " surf_in_scal_path,\n", + " vert_in_scal_path,\n", + ")\n", + "\n", + "output_sig = output_scalers(\n", + " surface_vars,\n", + " vertical_vars,\n", + " levels,\n", + " surf_out_scal_path,\n", + " vert_out_scal_path,\n", + ")\n", + "\n", + "static_mu, static_sig = static_input_scalers(\n", + " surf_in_scal_path,\n", + " static_surface_vars,\n", + ")\n", + "\n", + "residual = \"none\"\n", + "masking_mode = \"local\"\n", + "decoder_shifting = 
True\n", + "masking_ratio = 0.99" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model init\n", + "We can now build and load the pretrained weights, note that you should use the\n", + "rollout version of the weights." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'weights\\\\prithvi.wxc.rollout.2300m.v1.pt'" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "weights_path = Path(\"./weights/prithvi.wxc.rollout.2300m.v1.pt\")\n", + "hf_hub_download(\n", + " repo_id=\"Prithvi-WxC/prithvi.wxc.rollout.2300m.v1\",\n", + " filename=weights_path.name,\n", + " local_dir=\"./weights\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from functools import cached_property\n", + "from importlib.metadata import version\n", + "\n", + "from torch import Tensor\n", + "from torch.utils.checkpoint import checkpoint\n", + "\n", + "if version(\"torch\") > \"2.3.0\":\n", + " from torch.nn.attention import SDPBackend, sdpa_kernel\n", + "import numpy as np\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "# DropPath code is straight from timm\n", + "# (https://huggingface.co/spaces/Roll20/pet_score/blame/main/lib/timm/models/layers/drop.py)\n", + "def drop_path(\n", + " x: Tensor,\n", + " drop_prob: float = 0.0,\n", + " training: bool = False,\n", + " scale_by_keep: bool = True,\n", + ") -> Tensor:\n", + " \"\"\"Drop paths (Stochastic Depth) per sample (when applied in main path of\n", + " residual blocks). 
Taken from timm.\n", + "\n", + " Args:\n", + " x (Tensor): Input tensor.\n", + " drop_prob (float): Probability of dropping `x`, defaults to 0.\n", + " training (bool): Whether the model is in training or eval mode,\n", + " defaults to False.\n", + " scale_by_keep (bool): Whether the output should be divided by\n", + " (`1 - drop_prob`), defaults to True.\n", + " Returns:\n", + " Tensor: Tensor whose samples may have been randomly dropped with probability\n", + " `drop_prob`\n", + " \"\"\"\n", + " if drop_prob == 0.0 or not training:\n", + " return x\n", + " keep_prob = 1 - drop_prob\n", + " shape = (x.shape[0],) + (1,) * (x.ndim - 1)\n", + " random_tensor = x.new_empty(shape).bernoulli_(keep_prob)\n", + " if keep_prob > 0.0 and scale_by_keep:\n", + " random_tensor.div_(keep_prob)\n", + " return x * random_tensor\n", + "\n", + "\n", + "class DropPath(nn.Module):\n", + " \"\"\"\n", + " Drop paths (Stochastic Depth) per sample (when applied in main path of\n", + " residual blocks).\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self, drop_prob: float | None = None, scale_by_keep: bool = True\n", + " ) -> None:\n", + " super(DropPath, self).__init__()\n", + " self.drop_prob = drop_prob\n", + " self.scale_by_keep = scale_by_keep\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"Runs drop path on input tensor\n", + "\n", + " Args:\n", + " x: input\n", + "\n", + " Returns:\n", + " tensor: output after drop_path\n", + " \"\"\"\n", + " return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)\n", + "\n", + "\n", + "class Mlp(nn.Module):\n", + " \"\"\"\n", + " Multi layer perceptron.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self, features: int, hidden_features: int, dropout: float = 0.0\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " features: Input/output dimension.\n", + " hidden_features: Hidden dimension.\n", + " dropout: Dropout.\n", + " \"\"\"\n", + " super().__init__()\n", + " self.net = nn.Sequential(\n", + " 
nn.Linear(features, hidden_features),\n", + " nn.GELU(),\n", + " nn.Dropout(dropout),\n", + " nn.Linear(hidden_features, features),\n", + " nn.Dropout(dropout),\n", + " )\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x (Tensor): Tensor of shape [..., channel]\n", + " Returns:\n", + " Tensor: Tensor of same shape as x.\n", + " \"\"\"\n", + " return self.net(x)\n", + "\n", + "\n", + "class LayerNormPassThrough(nn.LayerNorm):\n", + " \"\"\"Normalising layer that allows the attention mask to be passed through\"\"\"\n", + "\n", + " def __init__(self, *args, **kwargs):\n", + " super().__init__(*args, **kwargs)\n", + "\n", + " def forward(\n", + " self, d: tuple[Tensor, Tensor | None]\n", + " ) -> tuple[Tensor, Tensor | None]:\n", + " \"\"\"Forward function\n", + "\n", + " Args:\n", + " d (tuple): tuple of the data tensor and the attention mask\n", + " Returns:\n", + " output (Tensor): normalised output data\n", + " attn_mask (Tensor): the attention mask that was passed in\n", + " \"\"\"\n", + " input, attn_mask = d\n", + " output = F.layer_norm(\n", + " input, self.normalized_shape, self.weight, self.bias, self.eps\n", + " )\n", + " return output, attn_mask\n", + "\n", + "\n", + "class MultiheadAttention(nn.Module):\n", + " \"\"\"Multihead attention layer for inputs of shape\n", + " [..., sequence, features].\n", + " \"\"\"\n", + "\n", + " def __init__(self, features: int, n_heads: int, dropout: float) -> None:\n", + " \"\"\"\n", + " Args:\n", + " features: Number of features for inputs to the layer.\n", + " n_heads: Number of attention heads. Should be a factor of features.\n", + " (I.e. 
the layer uses features // n_heads.)\n", + " dropout: Dropout.\n", + " \"\"\" # noqa: E501\n", + " super().__init__()\n", + "\n", + " if (features % n_heads) != 0:\n", + " raise ValueError(\n", + " f\"Features '{features}' is not divisible by heads '{n_heads}'.\"\n", + " )\n", + "\n", + " self.features = features\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + "\n", + " self.qkv_layer = torch.nn.Linear(features, features * 3, bias=False)\n", + " self.w_layer = torch.nn.Linear(features, features, bias=False)\n", + "\n", + " def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " d (tuple): tuple containing Tensor of shape [..., sequence, features] and the attention mask\n", + " Returns:\n", + " Tensor: Tensor of shape [..., sequence, features]\n", + " \"\"\" # noqa: E501\n", + " x, attn_mask = d\n", + "\n", + " if not x.shape[-1] == self.features:\n", + " raise ValueError(\n", + " f\"Expecting tensor with last dimension size {self.features}.\"\n", + " )\n", + "\n", + " passenger_dims = x.shape[:-2]\n", + " B = passenger_dims.numel()\n", + " S = x.shape[-2]\n", + " C = x.shape[-1]\n", + " x = x.reshape(B, S, C)\n", + "\n", + " # x [B, S, C]\n", + " # q, k, v [B, H, S, C/H]\n", + " q, k, v = (\n", + " self.qkv_layer(x)\n", + " .view(B, S, self.n_heads, 3 * (C // self.n_heads))\n", + " .transpose(1, 2)\n", + " .chunk(chunks=3, dim=3)\n", + " )\n", + "\n", + " # Let us enforce either flash (A100+) or memory efficient attention.\n", + " if version(\"torch\") > \"2.3.0\":\n", + " with sdpa_kernel(\n", + " [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]\n", + " ):\n", + " # x [B, H, S, C//H]\n", + " x = F.scaled_dot_product_attention(\n", + " q, k, v, attn_mask=attn_mask, dropout_p=self.dropout\n", + " )\n", + " else:\n", + " with torch.backends.cuda.sdp_kernel(\n", + " enable_flash=True, enable_math=False, enable_mem_efficient=True\n", + " ):\n", + " # x [B, H, S, C//H]\n", + " x = 
F.scaled_dot_product_attention(\n", + " q, k, v, dropout_p=self.dropout\n", + " )\n", + "\n", + " # x [B, S, C]\n", + " x = x.transpose(1, 2).view(B, S, C)\n", + "\n", + " # x [B, S, C]\n", + " x = self.w_layer(x)\n", + "\n", + " # Back to input shape\n", + " x = x.view(*passenger_dims, S, self.features)\n", + " return x\n", + "\n", + "\n", + "class Transformer(nn.Module):\n", + " \"\"\"\n", + " Transformer for inputs of shape [..., S, features].\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " features: int,\n", + " mlp_multiplier: int,\n", + " n_heads: int,\n", + " dropout: float,\n", + " drop_path: float,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " features: Number of features for inputs to the layer.\n", + " mlp_multiplier: Model uses features*mlp_multiplier hidden units.\n", + " n_heads: Number of attention heads. Should be a factor of features.\n", + " (I.e. the layer uses features // n_heads.) dropout: Dropout.\n", + " drop_path: DropPath.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.features = features\n", + " self.mlp_multiplier = mlp_multiplier\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + " self.drop_path = (\n", + " DropPath(drop_path) if drop_path > 0.0 else nn.Identity()\n", + " )\n", + "\n", + " self.attention = nn.Sequential(\n", + " LayerNormPassThrough(features),\n", + " MultiheadAttention(features, n_heads, dropout),\n", + " )\n", + "\n", + " self.ff = nn.Sequential(\n", + " nn.LayerNorm(features),\n", + " Mlp(\n", + " features=features,\n", + " hidden_features=features * mlp_multiplier,\n", + " dropout=dropout,\n", + " ),\n", + " )\n", + "\n", + " def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x: Tensor of shape [..., sequence, features]\n", + " Returns:\n", + " Tensor: Tensor of shape [..., sequence, features]\n", + " \"\"\"\n", + " x, attn_mask = d\n", + " if not x.shape[-1] == self.features:\n", + " raise ValueError(\n", + 
" f\"Expecting tensor with last dimension size {self.features}.\"\n", + " )\n", + "\n", + " attention_x = self.attention(d)\n", + "\n", + " x = x + self.drop_path(attention_x)\n", + " x = x + self.drop_path(self.ff(x))\n", + "\n", + " return x\n", + "\n", + "\n", + "class _Shift(nn.Module):\n", + " \"\"\"Private base class for the shifter. This allows some behaviour to be\n", + " easily handled when the shifter isn't used.\n", + " \"\"\"\n", + "\n", + " def __init__(self):\n", + " super().__init__()\n", + "\n", + " self._shifted = False\n", + "\n", + " @torch.no_grad()\n", + " def reset(self) -> None:\n", + " \"\"\"\n", + " Resets the bool tracking whether the data is shifted\n", + " \"\"\"\n", + " self._shifted: bool = False\n", + "\n", + " def forward(self, data: Tensor) -> tuple[Tensor, dict[bool, None]]:\n", + " return data, {True: None, False: None}\n", + "\n", + "\n", + "class SWINShift(_Shift):\n", + " \"\"\"\n", + " Handles the shifting of patches similar to how SWIN works. However if we\n", + " shift the latitudes then the poles will wrap and potentially that might be\n", + " problematic. 
The possition tokens should handle it but masking is safer.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " mu_shape: tuple[int, int],\n", + " global_shape: tuple[int, int],\n", + " local_shape: tuple[int, int],\n", + " patch_shape: tuple[int, int],\n", + " n_context_tokens: int = 2,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " mu_shape: the shape to the masking units\n", + " global_shape: number of global patches in lat and lon\n", + " local_shape: size of the local patches\n", + " patch_shape: patch size\n", + " n_context_token: number of additional context tokens at start of\n", + " _each_ local sequence\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self._mu_shape = ms = mu_shape\n", + " self._g_shape = gs = global_shape\n", + " self._l_shape = ls = local_shape\n", + " self._p_shape = ps = patch_shape\n", + " self._lat_patch = (gs[0], ls[0], gs[1], ls[1])\n", + " self._n_context_tokens = n_context_tokens\n", + "\n", + " self._g_shift_to = tuple(\n", + " int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)\n", + " )\n", + " self._g_shift_from = tuple(\n", + " -int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)\n", + " )\n", + "\n", + " # Define the attention masks for the shifted MaxViT.\n", + " nglobal = global_shape[0] * global_shape[1]\n", + " nlocal = (\n", + " local_shape[0] * local_shape[1] + self._n_context_tokens\n", + " ) # \"+ 1\" for leadtime\n", + "\n", + " lm = torch.ones((nglobal, 1, nlocal, nlocal), dtype=bool)\n", + " mwidth = int(0.5 * local_shape[1]) * local_shape[0]\n", + " lm[\n", + " : gs[1],\n", + " :,\n", + " self._n_context_tokens : mwidth + self._n_context_tokens,\n", + " self._n_context_tokens : mwidth + self._n_context_tokens,\n", + " ] = False\n", + " self.register_buffer(\"local_mask\", lm)\n", + "\n", + " gm = torch.ones((nlocal, 1, nglobal, nglobal), dtype=bool)\n", + " gm[: int(0.5 * ls[1]) * ls[0], :, : gs[1], : gs[1]] = False\n", + " self.register_buffer(\"global_mask\", gm)\n", + 
"\n", + " def _to_grid_global(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shuffle and reshape the data from the global/local setting back to the\n", + " lat/lon grid setting\n", + " Args:\n", + " x: the data tensor to be shuffled.\n", + " Returns:\n", + " x: data in the global/local setting\n", + " \"\"\"\n", + " nbatch, *other = x.shape\n", + "\n", + " y1 = x.view(nbatch, *self._g_shape, *self._l_shape, -1)\n", + " y2 = y1.permute(0, 5, 1, 3, 2, 4).contiguous()\n", + "\n", + " s = y2.shape\n", + " return y2.view((nbatch, -1, s[2] * s[3], s[4] * s[5]))\n", + "\n", + " def _to_grid_local(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shuffle and reshape the data from the local/global setting to the\n", + " lat/lon grid setting\n", + " Args:\n", + " x: the data tensor to be shuffled.\n", + " Returns:\n", + " x: data in the lat/lon setting.\n", + " \"\"\"\n", + " x = x.transpose(2, 1).contiguous()\n", + " return self._to_grid_global(x)\n", + "\n", + " def _from_grid_global(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shuffle and reshape the data from the lat/lon grid to the global/local\n", + " setting\n", + " Args:\n", + " x: the data tensor to be shuffled.\n", + " Returns:\n", + " x: data in the global/local setting\n", + " \"\"\"\n", + " nbatch, *other = x.shape\n", + "\n", + " z1 = x.view(nbatch, -1, *self._lat_patch)\n", + " z2 = z1.permute(0, 2, 4, 3, 5, 1).contiguous()\n", + "\n", + " s = z2.shape\n", + " return z2.view(nbatch, s[1] * s[2], s[3] * s[4], -1)\n", + "\n", + " def _from_grid_local(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shuffle and reshape the data from the lat/lon grid to the local/global\n", + " setting\n", + " Args:\n", + " x: the data tensor to be shuffled.\n", + " Returns:\n", + " x: data in the local/global setting\n", + " \"\"\"\n", + " x = self._from_grid_global(x)\n", + " return x.transpose(2, 1).contiguous()\n", + "\n", + " def _shift(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Shifts data in the gridded 
lat/lon setting by half the mask unit shape\n", + " Args:\n", + " x: data to be shifted\n", + " Returns:\n", + " x: either the shifted or unshifted data\n", + " \"\"\"\n", + " shift = self._g_shift_from if self._shifted else self._g_shift_to\n", + " x_shifted = torch.roll(x, shift, (-2, -1))\n", + "\n", + " self._shifted = not self._shifted\n", + " return x_shifted\n", + "\n", + " def _sep_lt(self, x: Tensor) -> tuple[Tensor, Tensor]:\n", + " \"\"\"\n", + " Separate off the leadtime from the local patches\n", + " Args:\n", + " x: data to have leadtime removed from\n", + " Returns:\n", + " lt: leadtime\n", + " x: data without the lead time in the local patch\n", + " \"\"\"\n", + " lt_it = x[:, : self._n_context_tokens, :, :]\n", + " x_stripped = x[:, self._n_context_tokens :, :, :]\n", + "\n", + " return lt_it, x_stripped\n", + "\n", + " def forward(self, data: Tensor) -> tuple[Tensor, Tensor]:\n", + " \"\"\"Shift or unshift the data depending on whether the data is\n", + " already shifted, as defined by self._shifted.\n", + "\n", + " Args:\n", + " data: data to be shifted\n", + " Returns:\n", + " tuple: shifted data Tensor and dict of attention masks\n", + " \"\"\"\n", + " lt, x = self._sep_lt(data)\n", + "\n", + " x_grid = self._to_grid_local(x)\n", + " x_shifted = self._shift(x_grid)\n", + " x_patched = self._from_grid_local(x_shifted)\n", + "\n", + " # Mask has to be repeated based on batch size\n", + " n_batch = x_grid.shape[0]\n", + " local_rep = [n_batch] + [1] * (self.local_mask.ndim - 1)\n", + " global_rep = [n_batch] + [1] * (self.global_mask.ndim - 1)\n", + "\n", + " if self._shifted:\n", + " attn_mask = {\n", + " True: self.local_mask.repeat(local_rep),\n", + " False: self.global_mask.repeat(global_rep),\n", + " }\n", + " else:\n", + " attn_mask = {True: None, False: None}\n", + "\n", + " return torch.cat((lt, x_patched), axis=1), attn_mask\n", + "\n", + "\n", + "class LocalGlobalLocalBlock(nn.Module):\n", + " \"\"\"\n", + " Applies alternating block and grid attention. 
Given a parameter n_blocks,\n", + " the entire module contains 2*n_blocks+1 transformer blocks. The first,\n", + " third, ..., last apply local (block) attention. The second, fourth, ...\n", + " global (grid) attention.\n", + "\n", + " This is heavily inspired by\n", + " Tu et al. \"MaxViT: Multi-Axis Vision Transformer\"\n", + " (https://arxiv.org/abs/2204.01697).\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " features: int,\n", + " mlp_multiplier: int,\n", + " n_heads: int,\n", + " dropout: float,\n", + " n_blocks: int,\n", + " drop_path: float,\n", + " shifter: nn.Module | None = None,\n", + " checkpoint: list[int] | None = None,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " features: Number of features for inputs to the layer.\n", + " mlp_multiplier: Model uses features*mlp_multiplier hidden units.\n", + " n_heads: Number of attention heads. Should be a factor of features.\n", + " (I.e. the layer uses features // n_heads.)\n", + " dropout: Dropout.\n", + " drop_path: DropPath.\n", + " n_blocks: Number of local-global transformer pairs.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.features = features\n", + " self.mlp_multiplier = mlp_multiplier\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + " self.drop_path = drop_path\n", + " self.n_blocks = n_blocks\n", + " self._checkpoint = checkpoint or []\n", + "\n", + " if not all(0 <= c < 2 * n_blocks + 1 for c in self._checkpoint):\n", + " raise ValueError(\n", + " \"Checkpoints should be 0 <= i < 2*n_blocks+1. 
\"\n", + " f\"{self._checkpoint=}.\"\n", + " )\n", + "\n", + " self.transformers = nn.ModuleList(\n", + " [\n", + " Transformer(\n", + " features=features,\n", + " mlp_multiplier=mlp_multiplier,\n", + " n_heads=n_heads,\n", + " dropout=dropout,\n", + " drop_path=drop_path,\n", + " )\n", + " for _ in range(2 * n_blocks + 1)\n", + " ]\n", + " )\n", + "\n", + " self.evaluator = [\n", + " self._checkpoint_wrapper\n", + " if i in self._checkpoint\n", + " else lambda m, x: m(x)\n", + " for i, _ in enumerate(self.transformers)\n", + " ]\n", + "\n", + " self.shifter = shifter or _Shift()\n", + "\n", + " @staticmethod\n", + " def _checkpoint_wrapper(\n", + " model: nn.Module, data: tuple[Tensor, Tensor | None]\n", + " ) -> Tensor:\n", + " return checkpoint(model, data, use_reentrant=False)\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x: Tensor of shape::\n", + "\n", + " [batch, global_sequence, local_sequence, features]\n", + "\n", + " Returns:\n", + " Tensor: Tensor of shape::\n", + "\n", + " [batch, global_sequence, local_sequence, features]\n", + " \"\"\"\n", + " if x.shape[-1] != self.features:\n", + " raise ValueError(\n", + " f\"Expecting tensor with last dimension size {self.features}.\"\n", + " )\n", + " if x.ndim != 4:\n", + " raise ValueError(\n", + " f\"Expecting tensor with exactly four dimensions. 
{x.shape=}.\"\n", + " )\n", + "\n", + " self.shifter.reset()\n", + " local: bool = True\n", + " attn_mask = {True: None, False: None}\n", + "\n", + " transformer_iter = zip(self.evaluator, self.transformers, strict=False)\n", + "\n", + " # First local block\n", + " evaluator, transformer = next(transformer_iter)\n", + " x = evaluator(transformer, (x, attn_mask[local]))\n", + "\n", + " for evaluator, transformer in transformer_iter:\n", + " local = not local\n", + " # We are making exactly 2*n_blocks transposes.\n", + " # So the output has the same shape as input.\n", + " x = x.transpose(1, 2)\n", + "\n", + " x = evaluator(transformer, (x, attn_mask[local]))\n", + "\n", + " if not local:\n", + " x, attn_mask = self.shifter(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "class PatchEmbed(nn.Module):\n", + " \"\"\"\n", + " Patch embedding via 2D convolution.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self, patch_size: int | tuple[int, ...], channels: int, embed_dim: int\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.patch_size = patch_size\n", + " self.channels = channels\n", + " self.embed_dim = embed_dim\n", + "\n", + " self.proj = nn.Conv2d(\n", + " channels,\n", + " embed_dim,\n", + " kernel_size=patch_size,\n", + " stride=patch_size,\n", + " bias=True,\n", + " )\n", + "\n", + " def forward(self, x: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x: Tensor of shape [batch, channels, lat, lon].\n", + " Returns:\n", + " Tensor: Tensor with shape\n", + " [batch, embed_dim, lat//patch_size, lon//patch_size]\n", + " \"\"\"\n", + "\n", + " H, W = x.shape[-2:]\n", + "\n", + " if W % self.patch_size[1] != 0:\n", + " raise ValueError(\n", + " f\"Cannot do patch embedding for tensor of shape {x.size()}\"\n", + " \" with patch size {self.patch_size}. 
(Dimensions are BSCHW.)\"\n", + " )\n", + " if H % self.patch_size[0] != 0:\n", + " raise ValueError(\n", + " f\"Cannot do patch embedding for tensor of shape {x.size()}\"\n", + " f\" with patch size {self.patch_size}. (Dimensions are BSCHW.)\"\n", + " )\n", + "\n", + " x = self.proj(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "class PrithviWxCEncoderDecoder(nn.Module):\n", + " \"\"\"\n", + " Hiera-MaxViT encoder/decoder code.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " embed_dim: int,\n", + " n_blocks: int,\n", + " mlp_multiplier: float,\n", + " n_heads: int,\n", + " dropout: float,\n", + " drop_path: float,\n", + " shifter: nn.Module | None = None,\n", + " transformer_cp: list[int] | None = None,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " embed_dim: Embedding dimension\n", + " n_blocks: Number of local-global transformer pairs.\n", + " mlp_multiplier: MLP multiplier for hidden features in feed forward\n", + " networks.\n", + " n_heads: Number of attention heads.\n", + " dropout: Dropout.\n", + " drop_path: DropPath.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.embed_dim = embed_dim\n", + " self.n_blocks = n_blocks\n", + " self.mlp_multiplier = mlp_multiplier\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + " self._transformer_cp = transformer_cp\n", + "\n", + " self.lgl_block = LocalGlobalLocalBlock(\n", + " features=embed_dim,\n", + " mlp_multiplier=mlp_multiplier,\n", + " n_heads=n_heads,\n", + " dropout=dropout,\n", + " drop_path=drop_path,\n", + " n_blocks=n_blocks,\n", + " shifter=shifter,\n", + " checkpoint=transformer_cp,\n", + " )\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " x: Tensor of shape\n", + " [batch, global sequence, local sequence, embed_dim]\n", + " Returns:\n", + " Tensor of shape\n", + " [batch, mask_unit_sequence, local_sequence, embed_dim].\n", + " Identical in shape to the input x.\n", + " \"\"\"\n", + 
"\n", + " x = self.lgl_block(x)\n", + "\n", + " return x\n", + "\n", + "\n", + "class PrithviWxC(nn.Module):\n", + " \"\"\"Encoder-decoder fusing Hiera with MaxViT. See\n", + " - Ryali et al. \"Hiera: A Hierarchical Vision Transformer without the\n", + " Bells-and-Whistles\" (https://arxiv.org/abs/2306.00989)\n", + " - Tu et al. \"MaxViT: Multi-Axis Vision Transformer\"\n", + " (https://arxiv.org/abs/2204.01697)\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " in_channels: int,\n", + " input_size_time: int,\n", + " in_channels_static: int,\n", + " input_scalers_mu: Tensor,\n", + " input_scalers_sigma: Tensor,\n", + " input_scalers_epsilon: float,\n", + " static_input_scalers_mu: Tensor,\n", + " static_input_scalers_sigma: Tensor,\n", + " static_input_scalers_epsilon: float,\n", + " output_scalers: Tensor,\n", + " n_lats_px: int,\n", + " n_lons_px: int,\n", + " patch_size_px: tuple[int],\n", + " mask_unit_size_px: tuple[int],\n", + " mask_ratio_inputs: float,\n", + " embed_dim: int,\n", + " n_blocks_encoder: int,\n", + " n_blocks_decoder: int,\n", + " mlp_multiplier: float,\n", + " n_heads: int,\n", + " dropout: float,\n", + " drop_path: float,\n", + " parameter_dropout: float,\n", + " residual: str,\n", + " masking_mode: str,\n", + " positional_encoding: str,\n", + " decoder_shifting: bool = False,\n", + " checkpoint_encoder: list[int] | None = None,\n", + " checkpoint_decoder: list[int] | None = None,\n", + " ) -> None:\n", + " \"\"\"\n", + " Args:\n", + " in_channels: Number of input channels.\n", + " input_size_time: Number of timestamps in input.\n", + " in_channels_static: Number of input channels for static data.\n", + " input_scalers_mu: Tensor of size (in_channels,). Used to rescale\n", + " input.\n", + " input_scalers_sigma: Tensor of size (in_channels,). Used to rescale\n", + " input.\n", + " input_scalers_epsilon: Float. Used to rescale input.\n", + " static_input_scalers_mu: Tensor of size (in_channels_static). 
Used\n", + " to rescale static inputs.\n", + " static_input_scalers_sigma: Tensor of size (in_channels_static).\n", + " Used to rescale static inputs.\n", + " static_input_scalers_epsilon: Float. Used to rescale static inputs.\n", + " output_scalers: Tensor of shape (in_channels,). Used to rescale\n", + " output.\n", + " n_lats_px: Total latitudes in data. In pixels.\n", + " n_lons_px: Total longitudes in data. In pixels.\n", + " patch_size_px: Patch size for tokenization. In pixels lat/lon.\n", + " mask_unit_size_px: Size of each mask unit. In pixels lat/lon.\n", + " mask_ratio_inputs: Masking ratio for inputs. 0 to 1.\n", + " embed_dim: Embedding dimension\n", + " n_blocks_encoder: Number of local-global transformer pairs in\n", + " encoder.\n", + " n_blocks_decoder: Number of local-global transformer pairs in\n", + " decoder.\n", + " mlp_multiplier: MLP multiplier for hidden features in feed forward\n", + " networks.\n", + " n_heads: Number of attention heads.\n", + " dropout: Dropout.\n", + " drop_path: DropPath.\n", + " parameter_dropout: Dropout applied to parameters.\n", + " residual: Indicates whether and how model should work as residual\n", + " model. Accepted values are 'climate', 'temporal' and 'none'\n", + " positional_encoding: possible values are\n", + " ['absolute' (default), 'fourier'].\n", + " 'absolute' lat lon encoded in 3 dimensions using sine and\n", + " cosine\n", + " 'fourier' lat/lon to be encoded using various frequencies\n", + " masking_mode: String ['local', 'global', 'both'] that controls the\n", + " type of masking used.\n", + " checkpoint_encoder: List of integers controlling if gradient\n", + " checkpointing is used on encoder.\n", + " Format: [] for no gradient checkpointing. 
[3, 7] for\n", + " checkpointing after 4th and 8th layer etc.\n", + " checkpoint_decoder: List of integers controlling if gradient\n", + " checkpointing is used on decoder.\n", + " Format: See `checkpoint_encoder`.\n", + " masking_mode: The type of masking to use\n", + " {'global', 'local', 'both'}\n", + " decoder_shifting: Whether to use swin shifting in the decoder.\n", + " \"\"\"\n", + " super().__init__()\n", + "\n", + " self.in_channels = in_channels\n", + " self.input_size_time = input_size_time\n", + " self.in_channels_static = in_channels_static\n", + " self.n_lats_px = n_lats_px\n", + " self.n_lons_px = n_lons_px\n", + " self.patch_size_px = patch_size_px\n", + " self.mask_unit_size_px = mask_unit_size_px\n", + " self.mask_ratio_inputs = mask_ratio_inputs\n", + " self.embed_dim = embed_dim\n", + " self.n_blocks_encoder = n_blocks_encoder\n", + " self.n_blocks_decoder = n_blocks_decoder\n", + " self.mlp_multiplier = mlp_multiplier\n", + " self.n_heads = n_heads\n", + " self.dropout = dropout\n", + " self.drop_path = drop_path\n", + " self.residual = residual\n", + " self._decoder_shift = decoder_shifting\n", + " self.positional_encoding = positional_encoding\n", + " self._checkpoint_encoder = checkpoint_encoder\n", + " self._checkpoint_decoder = checkpoint_decoder\n", + "\n", + " assert self.n_lats_px % self.mask_unit_size_px[0] == 0\n", + " assert self.n_lons_px % self.mask_unit_size_px[1] == 0\n", + " assert self.mask_unit_size_px[0] % self.patch_size_px[0] == 0\n", + " assert self.mask_unit_size_px[1] % self.patch_size_px[1] == 0\n", + "\n", + " if self.patch_size_px[0] != self.patch_size_px[1]:\n", + " raise NotImplementedError(\n", + " \"Current pixel shuffle symmetric patches.\"\n", + " )\n", + "\n", + " self.local_shape_mu = (\n", + " self.mask_unit_size_px[0] // self.patch_size_px[0],\n", + " self.mask_unit_size_px[1] // self.patch_size_px[1],\n", + " )\n", + " self.global_shape_mu = (\n", + " self.n_lats_px // self.mask_unit_size_px[0],\n", + " 
self.n_lons_px // self.mask_unit_size_px[1],\n", + " )\n", + "\n", + " assert input_scalers_mu.shape == (in_channels,)\n", + " assert input_scalers_sigma.shape == (in_channels,)\n", + " assert output_scalers.shape == (in_channels,)\n", + "\n", + " if self.positional_encoding != \"fourier\":\n", + " assert static_input_scalers_mu.shape == (in_channels_static,)\n", + " assert static_input_scalers_sigma.shape == (in_channels_static,)\n", + "\n", + " # Input shape [batch, time, parameter, lat, lon]\n", + " self.input_scalers_epsilon = input_scalers_epsilon\n", + " self.register_buffer(\n", + " \"input_scalers_mu\", input_scalers_mu.reshape(1, 1, -1, 1, 1)\n", + " )\n", + " self.register_buffer(\n", + " \"input_scalers_sigma\", input_scalers_sigma.reshape(1, 1, -1, 1, 1)\n", + " )\n", + "\n", + " # Static inputs shape [batch, parameter, lat, lon]\n", + " self.static_input_scalers_epsilon = static_input_scalers_epsilon\n", + " self.register_buffer(\n", + " \"static_input_scalers_mu\",\n", + " static_input_scalers_mu.reshape(1, -1, 1, 1),\n", + " )\n", + " self.register_buffer(\n", + " \"static_input_scalers_sigma\",\n", + " static_input_scalers_sigma.reshape(1, -1, 1, 1),\n", + " )\n", + "\n", + " # Output shape [batch, parameter, lat, lon]\n", + " self.register_buffer(\n", + " \"output_scalers\", output_scalers.reshape(1, -1, 1, 1)\n", + " )\n", + "\n", + " self.parameter_dropout = nn.Dropout2d(p=parameter_dropout)\n", + "\n", + " self.patch_embedding = PatchEmbed(\n", + " patch_size=patch_size_px,\n", + " channels=in_channels * input_size_time,\n", + " embed_dim=embed_dim,\n", + " )\n", + "\n", + " if self.residual == \"climate\":\n", + " self.patch_embedding_static = PatchEmbed(\n", + " patch_size=patch_size_px,\n", + " channels=in_channels + in_channels_static,\n", + " embed_dim=embed_dim,\n", + " )\n", + " else:\n", + " self.patch_embedding_static = PatchEmbed(\n", + " patch_size=patch_size_px,\n", + " channels=in_channels_static,\n", + " embed_dim=embed_dim,\n", + 
" )\n", + "\n", + " self.input_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)\n", + " self.lead_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)\n", + "\n", + " self.mask_token = nn.Parameter(torch.randn(1, 1, 1, self.embed_dim))\n", + " self._nglobal_mu = np.prod(self.global_shape_mu)\n", + " self._global_idx = torch.arange(self._nglobal_mu)\n", + "\n", + " self._nlocal_mu = np.prod(self.local_shape_mu)\n", + " self._local_idx = torch.arange(self._nlocal_mu)\n", + "\n", + " self.encoder = PrithviWxCEncoderDecoder(\n", + " embed_dim=embed_dim,\n", + " n_blocks=n_blocks_encoder,\n", + " mlp_multiplier=mlp_multiplier,\n", + " n_heads=n_heads,\n", + " dropout=dropout,\n", + " drop_path=drop_path,\n", + " transformer_cp=checkpoint_encoder,\n", + " )\n", + "\n", + " if n_blocks_decoder != 0:\n", + " if self._decoder_shift:\n", + " self.decoder_shifter = d_shifter = SWINShift(\n", + " self.mask_unit_size_px,\n", + " self.global_shape_mu,\n", + " self.local_shape_mu,\n", + " self.patch_size_px,\n", + " n_context_tokens=0,\n", + " )\n", + " else:\n", + " self.decoder_shifter = d_shifter = None\n", + "\n", + " self.decoder = PrithviWxCEncoderDecoder(\n", + " embed_dim=embed_dim,\n", + " n_blocks=n_blocks_decoder,\n", + " mlp_multiplier=mlp_multiplier,\n", + " n_heads=n_heads,\n", + " dropout=dropout,\n", + " drop_path=0.0,\n", + " shifter=d_shifter,\n", + " transformer_cp=checkpoint_decoder,\n", + " )\n", + "\n", + " self.unembed = nn.Linear(\n", + " self.embed_dim,\n", + " self.in_channels\n", + " * self.patch_size_px[0]\n", + " * self.patch_size_px[1],\n", + " bias=True,\n", + " )\n", + "\n", + " self.masking_mode = masking_mode.lower()\n", + " match self.masking_mode:\n", + " case \"local\":\n", + " self.generate_mask = self._gen_mask_local\n", + " case \"global\":\n", + " self.generate_mask = self._gen_mask_global\n", + " case \"both\":\n", + " self._mask_both_local: bool = True\n", + " self.generate_mask = self._gen_mask_both\n", + " case _:\n", + " 
raise ValueError(\n", + " f\"Masking mode '{masking_mode}' not supported\"\n", + " )\n", + "\n", + " def swap_masking(self) -> None:\n", + " self._mask_both_local = not self._mask_both_local\n", + "\n", + " @cached_property\n", + " def n_masked_global(self):\n", + " return int(self.mask_ratio_inputs * np.prod(self.global_shape_mu))\n", + "\n", + " @cached_property\n", + " def n_masked_local(self):\n", + " return int(self.mask_ratio_inputs * np.prod(self.local_shape_mu))\n", + "\n", + " @staticmethod\n", + " def _shuffle_along_axis(a, axis):\n", + " idx = torch.argsort(input=torch.rand(*a.shape), dim=axis)\n", + " return torch.gather(a, dim=axis, index=idx)\n", + "\n", + " def _gen_mask_local(self, sizes: tuple[int]) -> tuple[Tensor]:\n", + " \"\"\"\n", + " Args:\n", + " batch_size: Number of elements in batch\n", + " Returns:\n", + " Tuple of torch tensors. [indices masked, indices unmasked].\n", + " Each of these is a tensor of shape (batch, global sequene)\n", + " \"\"\"\n", + " # Identify which indices (values) should be masked\n", + "\n", + " maskable_indices = self._local_idx.view(1, -1).expand(*sizes[:2], -1)\n", + "\n", + " maskable_indices = self._shuffle_along_axis(maskable_indices, 2)\n", + "\n", + " indices_masked = maskable_indices[:, :, : self.n_masked_local]\n", + " indices_unmasked = maskable_indices[:, :, self.n_masked_local :]\n", + "\n", + " return indices_masked, indices_unmasked\n", + "\n", + " def _gen_mask_global(self, sizes: tuple[int]) -> tuple[Tensor]:\n", + " \"\"\"\n", + " Args:\n", + " batch_size: Number of elements in batch\n", + " Returns:\n", + " Tuple of torch tensors. 
[indices masked, indices unmasked].\n", + " Each of these is a tensor of shape (batch, global sequene)\n", + " \"\"\"\n", + " # Identify which indices (values) should be masked\n", + "\n", + " maskable_indices = self._global_idx.view(1, -1).expand(*sizes[:1], -1)\n", + "\n", + " maskable_indices = self._shuffle_along_axis(maskable_indices, 1)\n", + "\n", + " indices_masked = maskable_indices[:, : self.n_masked_global]\n", + " indices_unmasked = maskable_indices[:, self.n_masked_global :]\n", + "\n", + " return indices_masked, indices_unmasked\n", + "\n", + " def _gen_mask_both(self, sizes: tuple[int]) -> tuple[Tensor]:\n", + " if self._mask_both_local:\n", + " return self._gen_mask_local(sizes)\n", + " else:\n", + " return self._gen_mask_global(sizes)\n", + "\n", + " @staticmethod\n", + " def reconstruct_batch(\n", + " idx_masked: Tensor,\n", + " idx_unmasked: Tensor,\n", + " data_masked: Tensor,\n", + " data_unmasked: Tensor,\n", + " ) -> Tensor:\n", + " \"\"\"Reconstructs a tensor along the mask unit dimension. Batched\n", + " version.\n", + "\n", + " Args:\n", + " idx_masked: Tensor of shape `batch, mask unit sequence`.\n", + " idx_unmasked: Tensor of shape `batch, mask unit sequence`.\n", + " data_masked: Tensor of shape `batch, mask unit sequence, ...`.\n", + " Should have same size along mask unit sequence dimension as\n", + " idx_masked. Dimensions beyond the first two, marked here as ...\n", + " will typically be `local_sequence, channel` or\n", + " `channel, lat, lon`. These dimensions should agree with\n", + " data_unmasked.\n", + " data_unmasked: Tensor of shape `batch, mask unit sequence, ...`.\n", + " Should have same size along mask unit sequence dimension as\n", + " idx_unmasked. Dimensions beyond the first two, marked here as\n", + " ... will typically be `local_sequence, channel` or `channel,\n", + " lat, lon`. 
These dimensions should agree with data_masked.\n", + " Returns:\n", + " Tensor: Tensor of same shape as inputs data_masked and\n", + " data_unmasked. I.e. `batch, mask unit sequence, ...`. Index for\n", + " the total data composed of the masked and the unmasked part.\n", + " \"\"\"\n", + " dim: int = idx_masked.ndim\n", + "\n", + " idx_total = torch.argsort(\n", + " torch.cat([idx_masked, idx_unmasked], dim=-1), dim=-1\n", + " )\n", + " idx_total = idx_total.view(\n", + " *idx_total.shape, *[1] * (data_unmasked.ndim - dim)\n", + " )\n", + " idx_total = idx_total.expand(\n", + " *idx_total.shape[:dim], *data_unmasked.shape[dim:]\n", + " )\n", + "\n", + " data = torch.cat([data_masked, data_unmasked], dim=dim - 1)\n", + " data = torch.gather(data, dim=dim - 1, index=idx_total)\n", + "\n", + " return data, idx_total\n", + "\n", + " def fourier_pos_encoding(self, x_static: Tensor) -> Tensor:\n", + " \"\"\"\n", + " Args\n", + " x_static: B x C x H x W. first two channels are lat, and lon\n", + " Returns\n", + " Tensor: Tensor of shape B x E x H x W where E is the embedding\n", + " dimension.\n", + " \"\"\"\n", + "\n", + " # B x C x H x W -> B x 1 x H/P x W/P\n", + " latitudes_patch = F.avg_pool2d(\n", + " x_static[:, [0]],\n", + " kernel_size=self.patch_size_px,\n", + " stride=self.patch_size_px,\n", + " )\n", + " longitudes_patch = F.avg_pool2d(\n", + " x_static[:, [1]],\n", + " kernel_size=self.patch_size_px,\n", + " stride=self.patch_size_px,\n", + " )\n", + "\n", + " modes = (\n", + " torch.arange(self.embed_dim // 4, device=x_static.device).view(\n", + " 1, -1, 1, 1\n", + " )\n", + " + 1.0\n", + " )\n", + " pos_encoding = torch.cat(\n", + " (\n", + " torch.sin(latitudes_patch * modes),\n", + " torch.sin(longitudes_patch * modes),\n", + " torch.cos(latitudes_patch * modes),\n", + " torch.cos(longitudes_patch * modes),\n", + " ),\n", + " axis=1,\n", + " )\n", + "\n", + " return pos_encoding # B x E x H/P x W/P\n", + "\n", + " def time_encoding(self, input_time, 
lead_time):\n", + " \"\"\"\n", + " Args:\n", + " input_time: Tensor of shape [batch].\n", + " lead_time: Tensor of shape [batch].\n", + " Returns:\n", + " Tensor: Tensor of shape [batch, embed_dim, 1, 1]\n", + " \"\"\"\n", + " input_time = self.input_time_embedding(input_time.view(-1, 1, 1, 1))\n", + " lead_time = self.lead_time_embedding(lead_time.view(-1, 1, 1, 1))\n", + "\n", + " time_encoding = torch.cat(\n", + " (\n", + " torch.cos(input_time),\n", + " torch.cos(lead_time),\n", + " torch.sin(input_time),\n", + " torch.sin(lead_time),\n", + " ),\n", + " axis=3,\n", + " )\n", + " return time_encoding\n", + "\n", + " def to_patching(self, x: Tensor) -> Tensor:\n", + " \"\"\"Transform data from lat/lon space to two axis patching\n", + "\n", + " Args: ->\n", + " x: Tesnor in lat/lon space (N, C, Nlat//P_0, Nlon//P_1)\n", + "\n", + " Returns:\n", + " Tensor in patch space (N, G, L, C)\n", + " \"\"\"\n", + " n_batch = x.shape[0]\n", + "\n", + " x = x.view(\n", + " n_batch,\n", + " -1,\n", + " self.global_shape_mu[0],\n", + " self.local_shape_mu[0],\n", + " self.global_shape_mu[1],\n", + " self.local_shape_mu[1],\n", + " )\n", + " x = x.permute(0, 2, 4, 3, 5, 1).contiguous()\n", + "\n", + " s = x.shape\n", + " return x.view(n_batch, s[1] * s[2], s[3] * s[4], -1)\n", + "\n", + " def from_patching(self, x: Tensor) -> Tensor:\n", + " \"\"\"Transform data from two axis patching to lat/lon space\n", + "\n", + " Args:\n", + " x: Tensor in patch space with shape (N, G, L, C*P_0*P_1)\n", + "\n", + " Returns:\n", + " Tensor: Tensor in lat/lon space\n", + " (N, C*P_0*P_1, Nlat//P_0, Nlon // P_1)\n", + " \"\"\"\n", + " n_batch = x.shape[0]\n", + "\n", + " x = x.view(\n", + " n_batch,\n", + " self.global_shape_mu[0],\n", + " self.global_shape_mu[1],\n", + " self.local_shape_mu[0],\n", + " self.local_shape_mu[1],\n", + " -1,\n", + " )\n", + " x = x.permute(0, 5, 1, 3, 2, 4).contiguous()\n", + "\n", + " s = x.shape\n", + " return x.view(n_batch, -1, s[2] * s[3], s[4] * s[5])\n", + 
"\n", + " def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:\n", + " \"\"\"\n", + " Args:\n", + " batch: Dictionary the following keys::\n", + "\n", + " 'x': Tensor of shape [batch, time, parameter, lat, lon]\n", + " 'y': Tensor of shape [batch, parameter, lat, lon]\n", + " 'static': Tensor of shape [batch, channel_static, lat, lon]\n", + " 'climate': Optional tensor of shape [batch, parameter, lat, lon]\n", + " 'input_time': Tensor of shape [batch]. Or none.\n", + " 'lead_time': Tensor of shape [batch]. Or none.\n", + "\n", + " Returns:\n", + " Tensor: Tensor of shape [batch, parameter, lat, lon].\n", + " \"\"\" # noqa: E501\n", + " x_rescaled = (batch[\"x\"] - self.input_scalers_mu) / (\n", + " self.input_scalers_sigma + self.input_scalers_epsilon\n", + " )\n", + " batch_size = x_rescaled.shape[0]\n", + "\n", + " if self.positional_encoding == \"fourier\":\n", + " x_static_pos = self.fourier_pos_encoding(batch[\"static\"])\n", + " x_static = (\n", + " batch[\"static\"][:, 2:] - self.static_input_scalers_mu[:, 3:]\n", + " ) / (\n", + " self.static_input_scalers_sigma[:, 3:]\n", + " + self.static_input_scalers_epsilon\n", + " )\n", + " else:\n", + " x_static = (batch[\"static\"] - self.static_input_scalers_mu) / (\n", + " self.static_input_scalers_sigma\n", + " + self.static_input_scalers_epsilon\n", + " )\n", + "\n", + " if self.residual == \"temporal\":\n", + " # We create a residual of same shape as y\n", + " index = torch.where(\n", + " batch[\"lead_time\"] > 0, batch[\"x\"].shape[1] - 1, 0\n", + " )\n", + " index = index.view(-1, 1, 1, 1, 1)\n", + " index = index.expand(batch_size, 1, *batch[\"x\"].shape[2:])\n", + " x_hat = torch.gather(batch[\"x\"], dim=1, index=index)\n", + " x_hat = x_hat.squeeze(1)\n", + " elif self.residual == \"climate\":\n", + " climate_scaled = (\n", + " batch[\"climate\"] - self.input_scalers_mu.view(1, -1, 1, 1)\n", + " ) / (\n", + " self.input_scalers_sigma.view(1, -1, 1, 1)\n", + " + self.input_scalers_epsilon\n", 
+ " )\n", + "\n", + " # [batch, time, parameter, lat, lon]\n", + " # -> [batch, time x parameter, lat, lon]\n", + " x_rescaled = x_rescaled.flatten(1, 2)\n", + " # Parameter dropout\n", + " x_rescaled = self.parameter_dropout(x_rescaled)\n", + "\n", + " x_embedded = self.patch_embedding(x_rescaled)\n", + "\n", + " if self.residual == \"climate\":\n", + " static_embedded = self.patch_embedding_static(\n", + " torch.cat((x_static, climate_scaled), dim=1)\n", + " )\n", + " else:\n", + " static_embedded = self.patch_embedding_static(x_static)\n", + "\n", + " if self.positional_encoding == \"fourier\":\n", + " static_embedded += x_static_pos\n", + "\n", + " x_embedded = self.to_patching(x_embedded)\n", + " static_embedded = self.to_patching(static_embedded)\n", + "\n", + " time_encoding = self.time_encoding(\n", + " batch[\"input_time\"], batch[\"lead_time\"]\n", + " )\n", + "\n", + " tokens = x_embedded + static_embedded + time_encoding\n", + "\n", + " # Now we generate masks based on masking_mode\n", + " indices_masked, indices_unmasked = self.generate_mask(\n", + " (batch_size, self._nglobal_mu)\n", + " )\n", + " indices_masked = indices_masked.to(device=tokens.device)\n", + " indices_unmasked = indices_unmasked.to(device=tokens.device)\n", + " maskdim: int = indices_masked.ndim\n", + "\n", + " # Unmasking\n", + " unmask_view = (*indices_unmasked.shape, *[1] * (tokens.ndim - maskdim))\n", + " unmasked = torch.gather(\n", + " tokens,\n", + " dim=maskdim - 1,\n", + " index=indices_unmasked.view(*unmask_view).expand(\n", + " *indices_unmasked.shape, *tokens.shape[maskdim:]\n", + " ),\n", + " )\n", + "\n", + " # Encoder\n", + " x_encoded = self.encoder(unmasked)\n", + "\n", + " # Generate and position encode the mask tokens\n", + " # [1, 1, 1, embed_dim]\n", + " # -> [batch, global_seq_masked, local seq, embed_dim]\n", + " mask_view = (*indices_masked.shape, *[1] * (tokens.ndim - maskdim))\n", + " masking = self.mask_token.repeat(*static_embedded.shape[:3], 1)\n", + " 
masked = masking + static_embedded\n", + " masked = torch.gather(\n", + " masked,\n", + " dim=maskdim - 1,\n", + " index=indices_masked.view(*mask_view).expand(\n", + " *indices_masked.shape, *tokens.shape[maskdim:]\n", + " ),\n", + " )\n", + "\n", + " recon, _ = self.reconstruct_batch(\n", + " indices_masked, indices_unmasked, masked, x_encoded\n", + " )\n", + "\n", + " x_decoded = self.decoder(recon)\n", + "\n", + " # Output: [batch, global sequence, local sequence,\n", + " # in_channels * patch_size[0] * patch_size[1]]\n", + " x_unembed = self.unembed(x_decoded)\n", + "\n", + " # Reshape to [batch, global_lat, global_lon, local_lat, local_lon,\n", + " # in_channels * patch_size[0] * patch_size[1]]\n", + " x_out = self.from_patching(x_unembed)\n", + "\n", + " # Pixel shuffle to [batch, in_channels, lat, lon]\n", + " x_out = F.pixel_shuffle(x_out, self.patch_size_px[0])\n", + "\n", + " if self.residual == \"temporal\":\n", + " x_out = self.output_scalers * x_out + x_hat\n", + " elif self.residual == \"climate\":\n", + " x_out = self.output_scalers * x_out + batch[\"climate\"]\n", + " elif self.residual == \"none\":\n", + " x_out = (\n", + " self.output_scalers * x_out\n", + " + self.input_scalers_mu.reshape(1, -1, 1, 1)\n", + " )\n", + "\n", + " return x_out\n" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n", + "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n", + "\u001b[1;31mClick here for more info. \n", + "\u001b[1;31mView Jupyter log for further details." 
+ ] + } + ], + "source": [ + "import yaml\n", + "\n", + "# from PrithviWxC.model import PrithviWxC\n", + "\n", + "with open(\"./config.yaml\", \"r\") as f:\n", + " config = yaml.safe_load(f)\n", + "\n", + "model = PrithviWxC(\n", + " in_channels=config[\"params\"][\"in_channels\"],\n", + " input_size_time=config[\"params\"][\"input_size_time\"],\n", + " in_channels_static=config[\"params\"][\"in_channels_static\"],\n", + " input_scalers_mu=in_mu,\n", + " input_scalers_sigma=in_sig,\n", + " input_scalers_epsilon=config[\"params\"][\"input_scalers_epsilon\"],\n", + " static_input_scalers_mu=static_mu,\n", + " static_input_scalers_sigma=static_sig,\n", + " static_input_scalers_epsilon=config[\"params\"][\n", + " \"static_input_scalers_epsilon\"\n", + " ],\n", + " output_scalers=output_sig**0.5,\n", + " n_lats_px=config[\"params\"][\"n_lats_px\"],\n", + " n_lons_px=config[\"params\"][\"n_lons_px\"],\n", + " patch_size_px=config[\"params\"][\"patch_size_px\"],\n", + " mask_unit_size_px=config[\"params\"][\"mask_unit_size_px\"],\n", + " mask_ratio_inputs=masking_ratio,\n", + " embed_dim=config[\"params\"][\"embed_dim\"],\n", + " n_blocks_encoder=config[\"params\"][\"n_blocks_encoder\"],\n", + " n_blocks_decoder=config[\"params\"][\"n_blocks_decoder\"],\n", + " mlp_multiplier=config[\"params\"][\"mlp_multiplier\"],\n", + " n_heads=config[\"params\"][\"n_heads\"],\n", + " dropout=config[\"params\"][\"dropout\"],\n", + " drop_path=config[\"params\"][\"drop_path\"],\n", + " parameter_dropout=config[\"params\"][\"parameter_dropout\"],\n", + " residual=residual,\n", + " masking_mode=masking_mode,\n", + " decoder_shifting=decoder_shifting,\n", + " positional_encoding=positional_encoding,\n", + " checkpoint_encoder=[],\n", + " checkpoint_decoder=[],\n", + ")\n", + "\n", + "\n", + "state_dict = torch.load(weights_path, weights_only=False)\n", + "if \"model_state\" in state_dict:\n", + " state_dict = state_dict[\"model_state\"]\n", + "model.load_state_dict(state_dict, 
strict=True)\n", + "\n", + "if (hasattr(model, \"device\") and model.device != device) or not hasattr(\n", + " model, \"device\"\n", + "):\n", + " model = model.to(device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Rollout\n", + "We are now ready to perform the rollout. Again, the data has to be run through a\n", + "preprocessor. However, this time we use a preprocessor that can handle the\n", + "additional intermediate data. Also, rather than calling the model directly, we\n", + "have a convenient wrapper function that performs the iteration. This also\n", + "simplifies the model loading when using a sharded checkpoint. If you attempt to\n", + "perform training steps using this function, you should use an aggressive number\n", + "of activation checkpoints as the memory consumption becomes quite high." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "from torch import Tensor, nn\n", + "\n", + "\n", + "def rollout_iter(\n", + " nsteps: int,\n", + " model: nn.Module,\n", + " batch: dict[str, Tensor | int | float],\n", + ") -> Tensor:\n", + " \"\"\"A helper function for performing autoregressive rollout.\n", + "\n", + " Args:\n", + " nsteps (int): The number of rollout steps to take\n", + " model (nn.Module): A model.\n", + " batch (dict): A data dictionary common to the Prithvi models.\n", + "\n", + " Raises:\n", + " ValueError: If the number of steps isn't positive.\n", + "\n", + " Returns:\n", + " Tensor: the output of the model after nsteps autoregressive iterations.\n", + " \"\"\"\n", + " if nsteps < 1:\n", + " raise ValueError(\"'nsteps' shouold be a positive int.\")\n", + "\n", + " xlast = batch[\"x\"][:, 1]\n", + " batch[\"lead_time\"] = batch[\"lead_time\"][..., 0]\n", + "\n", + " # Save the masking ratio to be restored later\n", + " mask_ratio_tmp = model.mask_ratio_inputs\n", + "\n", + " for step in range(nsteps):\n", + " # After first step, 
turn off masking\n", + " if step > 0:\n", + " model.mask_ratio_inputs = 0.0\n", + "\n", + " batch[\"static\"] = batch[\"statics\"][:, step]\n", + " batch[\"climate\"] = batch[\"climates\"][:, step]\n", + " batch[\"y\"] = batch[\"ys\"][:, step]\n", + "\n", + " out = model(batch)\n", + "\n", + " batch[\"x\"] = torch.cat((xlast[:, None], out[:, None]), dim=1)\n", + " xlast = out\n", + "\n", + " # Restore the masking ratio\n", + " model.mask_ratio_inputs = mask_ratio_tmp\n", + "\n", + " return xlast\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# from PrithviWxC.dataloaders.merra2_rollout import preproc\n", + "# from PrithviWxC.rollout import rollout_iter\n", + "\n", + "data = next(iter(dataset))\n", + "batch = preproc([data], padding)\n", + "\n", + "for k, v in batch.items():\n", + " if isinstance(v, torch.Tensor):\n", + " batch[k] = v.to(device)\n", + "\n", + "rng_state_1 = torch.get_rng_state()\n", + "with torch.no_grad():\n", + " model.eval()\n", + " out = rollout_iter(dataset.nsteps, model, batch)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Plotting" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "t2m = out[0, 12].cpu().numpy()\n", + "\n", + "lat = np.linspace(-90, 90, out.shape[-2])\n", + "lon = np.linspace(-180, 180, out.shape[-1])\n", + "X, Y = np.meshgrid(lon, lat)\n", + "\n", + "plt.contourf(X, Y, t2m, 100)\n", + "plt.gca().set_aspect(\"equal\")\n", + "plt.show()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/climatology/anomaly_variance_surface.nc b/examples/climatology/anomaly_variance_surface.nc new file mode 100644 index 0000000000000000000000000000000000000000..8d15de473e699d4fd89769dcd971119a03b300fc Binary files /dev/null and b/examples/climatology/anomaly_variance_surface.nc differ diff --git a/examples/climatology/anomaly_variance_vertical.nc b/examples/climatology/anomaly_variance_vertical.nc new file mode 100644 index 0000000000000000000000000000000000000000..61f1c52132380b79375d2f5f227dd027e64b1a37 Binary files /dev/null and b/examples/climatology/anomaly_variance_vertical.nc differ diff --git a/examples/climatology/climate_surface_doy001_hour00.nc b/examples/climatology/climate_surface_doy001_hour00.nc new file mode 100644 index 0000000000000000000000000000000000000000..4eee99dd0577ca1208dc01c26555d3f6e7eb52ad --- /dev/null +++ b/examples/climatology/climate_surface_doy001_hour00.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6df405a6178c3222c4abd98eea6573ced38aa2ccbe0966647a68ac18103a4d1d +size 20830158 diff --git a/examples/climatology/climate_surface_doy001_hour03.nc b/examples/climatology/climate_surface_doy001_hour03.nc new file 
mode 100644 index 0000000000000000000000000000000000000000..78d57e7a4d962164d547ab013ef825a80604ab3e --- /dev/null +++ b/examples/climatology/climate_surface_doy001_hour03.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b67bcff5c6df5306677a65de3f7c08485a8ba1c3ecdc60b7f346cf0642caa7f1 +size 20830158 diff --git a/examples/climatology/climate_surface_doy001_hour06.nc b/examples/climatology/climate_surface_doy001_hour06.nc new file mode 100644 index 0000000000000000000000000000000000000000..27ef0dcf515665782a8031aebb8e7aca8ff355c6 --- /dev/null +++ b/examples/climatology/climate_surface_doy001_hour06.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:080cb1f7e8dad508fc78e40a5bbd30d03564b7a48bdd2916f296dc3dfed30c60 +size 20830158 diff --git a/examples/climatology/climate_surface_doy001_hour09.nc b/examples/climatology/climate_surface_doy001_hour09.nc new file mode 100644 index 0000000000000000000000000000000000000000..efa0bf3ccc1d3d0c7e174d1fa57f1ff5683484bb --- /dev/null +++ b/examples/climatology/climate_surface_doy001_hour09.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:134e194a8f38828ec067f98e8c5c7dc4aed0131046b6ed839f75bcf6b98b5492 +size 20830158 diff --git a/examples/climatology/climate_surface_doy001_hour12.nc b/examples/climatology/climate_surface_doy001_hour12.nc new file mode 100644 index 0000000000000000000000000000000000000000..fa221976960847e2b5d17fa127d3cfb4c53a87fa --- /dev/null +++ b/examples/climatology/climate_surface_doy001_hour12.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:974a79a5e68096c356b42f15cd9c955f0a1306cf6a942c682e09d6504742a6d1 +size 20830158 diff --git a/examples/climatology/climate_surface_doy001_hour15.nc b/examples/climatology/climate_surface_doy001_hour15.nc new file mode 100644 index 0000000000000000000000000000000000000000..f8a630fcb5be85bed8bd10124cbc81ffbc5a2364 --- /dev/null +++ 
b/examples/climatology/climate_surface_doy001_hour15.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:dea669f097cf1b9d775956a01e2bae52ae1856dfcc7abae5c6ee394949275c5e +size 20830158 diff --git a/examples/climatology/climate_surface_doy001_hour18.nc b/examples/climatology/climate_surface_doy001_hour18.nc new file mode 100644 index 0000000000000000000000000000000000000000..2a5ceb6e6c2b3e23175085085a0ef578d691f203 --- /dev/null +++ b/examples/climatology/climate_surface_doy001_hour18.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3b06aabcd1f14cc3a4b3c5dd355a225646fe4b9e3b5b676ee2fb6e5d05fe31c9 +size 20830158 diff --git a/examples/climatology/climate_surface_doy001_hour21.nc b/examples/climatology/climate_surface_doy001_hour21.nc new file mode 100644 index 0000000000000000000000000000000000000000..13f366c3802e63beff6bbae03c5a5fdadb56ab85 --- /dev/null +++ b/examples/climatology/climate_surface_doy001_hour21.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef90b12b58b73481f96c0caf5277d24cc05681003df0f6e83f5e85dc0b4d47b3 +size 20830158 diff --git a/examples/climatology/climate_vertical_doy001_hour00.nc b/examples/climatology/climate_vertical_doy001_hour00.nc new file mode 100644 index 0000000000000000000000000000000000000000..9a51931d9b43d2f69d646a812db0109cd816eab0 --- /dev/null +++ b/examples/climatology/climate_vertical_doy001_hour00.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b5f75cc1a55b6f1beecf2271d3a8fe9914cb20fb87e471761b3b686a990ec0ef +size 116475398 diff --git a/examples/climatology/climate_vertical_doy001_hour03.nc b/examples/climatology/climate_vertical_doy001_hour03.nc new file mode 100644 index 0000000000000000000000000000000000000000..5fb988aff029f32fb5f36e83c86e8ca140ac3109 --- /dev/null +++ b/examples/climatology/climate_vertical_doy001_hour03.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:80a04c77e614e01ae26fdad1fc3c5a15fd6869d627ce4251447d64c3bb934916 +size 116475398 diff --git a/examples/climatology/climate_vertical_doy001_hour06.nc b/examples/climatology/climate_vertical_doy001_hour06.nc new file mode 100644 index 0000000000000000000000000000000000000000..d236e82f4f26f26c0f3d4ae15b53dbf5ec244253 --- /dev/null +++ b/examples/climatology/climate_vertical_doy001_hour06.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0f6fc10bbd055e1e7af63f70cfa24abd60794cd89822a95b9fb9194b07c7822f +size 116475398 diff --git a/examples/climatology/climate_vertical_doy001_hour09.nc b/examples/climatology/climate_vertical_doy001_hour09.nc new file mode 100644 index 0000000000000000000000000000000000000000..aa225cbc173a9eca35c393609730577e4000d164 --- /dev/null +++ b/examples/climatology/climate_vertical_doy001_hour09.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70cb7c3e30902548b24330b61ecf493fdf1949bb434f45e9258434a80d91ee6d +size 116475398 diff --git a/examples/climatology/climate_vertical_doy001_hour12.nc b/examples/climatology/climate_vertical_doy001_hour12.nc new file mode 100644 index 0000000000000000000000000000000000000000..f4b0b384e6ce8f2890536396f8edd9df0e061c8a --- /dev/null +++ b/examples/climatology/climate_vertical_doy001_hour12.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a597136ae5b374e55b6f1113e218faeb3090a3d474f95bd6c043b4923238a32c +size 116475398 diff --git a/examples/climatology/climate_vertical_doy001_hour15.nc b/examples/climatology/climate_vertical_doy001_hour15.nc new file mode 100644 index 0000000000000000000000000000000000000000..4551686abb82484bf2515a6f6dc3bf0d4c69afb0 --- /dev/null +++ b/examples/climatology/climate_vertical_doy001_hour15.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6e1765b3fcf01aed2338b9df110a76a9e439703bb87c366b0d3868473ccceb46 +size 116475398 diff --git 
a/examples/climatology/climate_vertical_doy001_hour18.nc b/examples/climatology/climate_vertical_doy001_hour18.nc new file mode 100644 index 0000000000000000000000000000000000000000..3b4342cea1f59679496ba50c3bf88b33474f251c --- /dev/null +++ b/examples/climatology/climate_vertical_doy001_hour18.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:66b27837717a6cd37d72d55c8de3cc24b2122e25173348be37b3a76d98c8d580 +size 116475398 diff --git a/examples/climatology/climate_vertical_doy001_hour21.nc b/examples/climatology/climate_vertical_doy001_hour21.nc new file mode 100644 index 0000000000000000000000000000000000000000..78aa2a001f9de1cf095e13d59ac5b97daa0776d2 --- /dev/null +++ b/examples/climatology/climate_vertical_doy001_hour21.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:65084e38b67342d03135e98092d893e6b8ad677a1b690ea39b3fbc4b91df6fdc +size 116475398 diff --git a/examples/climatology/musigma_surface.nc b/examples/climatology/musigma_surface.nc new file mode 100644 index 0000000000000000000000000000000000000000..4ba7c8bb12f55491531e06fe7ea3b0a184128bc0 Binary files /dev/null and b/examples/climatology/musigma_surface.nc differ diff --git a/examples/climatology/musigma_vertical.nc b/examples/climatology/musigma_vertical.nc new file mode 100644 index 0000000000000000000000000000000000000000..f6415aec93712677beefb17bced1faede757d54f Binary files /dev/null and b/examples/climatology/musigma_vertical.nc differ diff --git a/examples/config.yaml b/examples/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4435167a11fd412e2dfb565e135eed07a33c7663 --- /dev/null +++ b/examples/config.yaml @@ -0,0 +1,20 @@ +params: + in_channels: 160 + input_size_time: 2 + in_channels_static: 8 + input_scalers_epsilon: 0.0 + static_input_scalers_epsilon: 0.0 + n_lats_px: 360 + n_lons_px: 576 + patch_size_px: [2, 2] + mask_unit_size_px: [30, 32] + embed_dim: 2560 + n_blocks_encoder: 12 + n_blocks_decoder: 2 + 
mlp_multiplier: 4 + n_heads: 16 + dropout: 0.0 + drop_path: 0.0 + parameter_dropout: 0.0 + checkpoint_encoder: [] + checkpoint_decoder: [] \ No newline at end of file diff --git a/examples/merra-2/MERRA2_sfc_20200101.nc b/examples/merra-2/MERRA2_sfc_20200101.nc new file mode 100644 index 0000000000000000000000000000000000000000..e9a61489fc05397e0c8153f74832505a6ecc833d --- /dev/null +++ b/examples/merra-2/MERRA2_sfc_20200101.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1de1638ca1f1b44ca95a7c6908573414c44e633f1bb2e212b8a46afc866c741e +size 101525285 diff --git a/examples/merra-2/MERRA_pres_20200101.nc b/examples/merra-2/MERRA_pres_20200101.nc new file mode 100644 index 0000000000000000000000000000000000000000..9d58023e68ef715d31fd5284747de9ce86883cbd --- /dev/null +++ b/examples/merra-2/MERRA_pres_20200101.nc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2cd9b405aa1d388fc0c6cbe6d71104bea770b48e07ba4364f067dc425d025af8 +size 337127950 diff --git a/examples/weights/.cache/huggingface/.gitignore b/examples/weights/.cache/huggingface/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..f59ec20aabf5842d237244ece8c81ab184faeac1 --- /dev/null +++ b/examples/weights/.cache/huggingface/.gitignore @@ -0,0 +1 @@ +* \ No newline at end of file diff --git a/examples/weights/.cache/huggingface/download/prithvi.wxc.2300m.v1.pt.metadata b/examples/weights/.cache/huggingface/download/prithvi.wxc.2300m.v1.pt.metadata new file mode 100644 index 0000000000000000000000000000000000000000..9f2c1ddb8374e74ed5b105dd0797e168f662a711 --- /dev/null +++ b/examples/weights/.cache/huggingface/download/prithvi.wxc.2300m.v1.pt.metadata @@ -0,0 +1,3 @@ +e9fdd56d7011c98ae591166c56e8b624d4f39e4a +9b3617e91f164833e4c155dc5035178591266731ef7ce471b12d1caa34d1b8d8 +1731282990.9262006 diff --git a/examples/weights/.cache/huggingface/download/prithvi.wxc.rollout.2300m.v1.pt.metadata 
b/examples/weights/.cache/huggingface/download/prithvi.wxc.rollout.2300m.v1.pt.metadata new file mode 100644 index 0000000000000000000000000000000000000000..8a3f7eed11b0128f939411c60c13ce451d3fd869 --- /dev/null +++ b/examples/weights/.cache/huggingface/download/prithvi.wxc.rollout.2300m.v1.pt.metadata @@ -0,0 +1,3 @@ +514c3d061ad45e3338495da7c16b13aa20fa75b1 +e66ef85d4e404465a5359b729f115d4bd7a5c5e8016c7b86b5f773d72cef8efa +1731287319.8335612 diff --git a/examples/weights/prithvi.wxc.2300m.v1.pt b/examples/weights/prithvi.wxc.2300m.v1.pt new file mode 100644 index 0000000000000000000000000000000000000000..7d620ba813e14c24e5ec691ebe2b6b38ad0ed06e --- /dev/null +++ b/examples/weights/prithvi.wxc.2300m.v1.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9b3617e91f164833e4c155dc5035178591266731ef7ce471b12d1caa34d1b8d8 +size 28447290466 diff --git a/examples/weights/prithvi.wxc.rollout.2300m.v1.pt b/examples/weights/prithvi.wxc.rollout.2300m.v1.pt new file mode 100644 index 0000000000000000000000000000000000000000..ee0fa8cc489d6e787a2ca1ce98e0dabc7e8789ce --- /dev/null +++ b/examples/weights/prithvi.wxc.rollout.2300m.v1.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e66ef85d4e404465a5359b729f115d4bd7a5c5e8016c7b86b5f773d72cef8efa +size 28447289145