qq1990 commited on
Commit
60840ab
·
verified ·
1 Parent(s): 3db5513

Upload 6 files

Browse files
PrithviWxC/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ """Prithvi-WxC - Weather and climate foundational model."""
2
+
3
+ __version__ = "1.0.0"
4
+
5
+ from . import dataloaders, model
6
+
7
+ __all__ = [
8
+ "dataloaders",
9
+ "model",
10
+ ]
PrithviWxC/dataloaders/__init__.py ADDED
File without changes
PrithviWxC/dataloaders/merra2.py ADDED
@@ -0,0 +1,1168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools as ft
2
+ import os
3
+ import random
4
+ import re
5
+ from collections import defaultdict
6
+ from datetime import datetime, timedelta
7
+ from pathlib import Path
8
+
9
+ import h5py
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ from torch import Tensor
14
+ from torch.utils.data import Dataset
15
+
16
+
17
+ def preproc(batch: list[dict], padding: dict[tuple[int]]) -> dict[str, Tensor]:
18
+ """Prepressing function for MERRA2 Dataset
19
+
20
+ Args:
21
+ batch (dict): List of training samples, each sample should be a
22
+ dictionary with the following keys::
23
+
24
+ 'sur_static': Numpy array of shape (3, lat, lon). For each pixel (lat, lon), the first dimension indexes sin(lat), cos(lon), sin(lon).
25
+ 'sur_vals': Torch tensor of shape (parameter, time, lat, lon).
26
+ 'sur_tars': Torch tensor of shape (parameter, time, lat, lon).
27
+ 'ulv_vals': Torch tensor of shape (parameter, level, time, lat, lon).
28
+ 'ulv_tars': Torch tensor of shape (parameter, level, time, lat, lon).
29
+ 'sur_climate': Torch tensor of shape (parameter, lat, lon)
30
+ 'ulv_climate': Torch tensor of shape (parameter, level, lat, lon)
31
+ 'lead_time': Integer.
32
+ 'input_time': Integer.
33
+
34
+ padding: Dictionary with keys 'level', 'lat', 'lon', each of dim 2.
35
+
36
+ Returns:
37
+ Dictionary with the following keys::
38
+
39
+ 'x': [batch, time, parameter, lat, lon]
40
+ 'y': [batch, parameter, lat, lon]
41
+ 'static': [batch, parameter, lat, lon]
42
+ 'lead_time': [batch]
43
+ 'input_time': [batch]
44
+ 'climate (Optional)': [batch, parameter, lat, lon]
45
+
46
+ Note:
47
+ Here, for x and y, 'parameter' is [surface parameter, upper level,
48
+ parameter x level]. Similarly for the static information we have
49
+ [sin(lat), cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod),
50
+ ...].
51
+ """ # noqa: E501
52
+ b0 = batch[0]
53
+ nbatch = len(batch)
54
+ data_keys = set(b0.keys())
55
+
56
+ essential_keys = {
57
+ "sur_static",
58
+ "sur_vals",
59
+ "sur_tars",
60
+ "ulv_vals",
61
+ "ulv_tars",
62
+ "input_time",
63
+ "lead_time",
64
+ }
65
+
66
+ climate_keys = {
67
+ "sur_climate",
68
+ "ulv_climate",
69
+ }
70
+
71
+ all_keys = essential_keys | climate_keys
72
+
73
+ if not essential_keys.issubset(data_keys):
74
+ raise ValueError("Missing essential keys.")
75
+
76
+ if not data_keys.issubset(all_keys):
77
+ raise ValueError("Unexpected keys in batch.")
78
+
79
+ # Bring all tensors from the batch into a single tensor
80
+ upl_x = torch.empty((nbatch, *b0["ulv_vals"].shape))
81
+ upl_y = torch.empty((nbatch, *b0["ulv_tars"].shape))
82
+
83
+ sur_x = torch.empty((nbatch, *b0["sur_vals"].shape))
84
+ sur_y = torch.empty((nbatch, *b0["sur_tars"].shape))
85
+
86
+ sur_sta = torch.empty((nbatch, *b0["sur_static"].shape))
87
+
88
+ lead_time = torch.empty((nbatch,), dtype=torch.float32)
89
+ input_time = torch.empty((nbatch,), dtype=torch.float32)
90
+
91
+ for i, rec in enumerate(batch):
92
+ sur_x[i] = rec["sur_vals"]
93
+ sur_y[i] = rec["sur_tars"]
94
+
95
+ upl_x[i] = rec["ulv_vals"]
96
+ upl_y[i] = rec["ulv_tars"]
97
+
98
+ sur_sta[i] = rec["sur_static"]
99
+
100
+ lead_time[i] = rec["lead_time"]
101
+ input_time[i] = rec["input_time"]
102
+
103
+ return_value = {
104
+ "lead_time": lead_time,
105
+ "input_time": input_time,
106
+ }
107
+
108
+ # Reshape (batch, parameter, level, time, lat, lon) ->
109
+ # (batch, time, parameter, level, lat, lon)
110
+ upl_x = upl_x.permute((0, 3, 1, 2, 4, 5))
111
+ upl_y = upl_y.permute((0, 3, 1, 2, 4, 5))
112
+ # Reshape (batch, parameter, time, lat, lon) ->
113
+ # (batch, time, parameter, lat, lon)
114
+ sur_x = sur_x.permute((0, 2, 1, 3, 4))
115
+ sur_y = sur_y.permute((0, 2, 1, 3, 4))
116
+
117
+ # Pad
118
+ padding_2d = (*padding["lon"], *padding["lat"])
119
+
120
+ def pad2d(x):
121
+ return torch.nn.functional.pad(x, padding_2d, mode="constant", value=0)
122
+
123
+ padding_3d = (*padding["lon"], *padding["lat"], *padding["level"])
124
+
125
+ def pad3d(x):
126
+ return torch.nn.functional.pad(x, padding_3d, mode="constant", value=0)
127
+
128
+ sur_x = pad2d(sur_x).contiguous()
129
+ upl_x = pad3d(upl_x).contiguous()
130
+ sur_y = pad2d(sur_y).contiguous()
131
+ upl_y = pad3d(upl_y).contiguous()
132
+ return_value["static"] = pad2d(sur_sta).contiguous()
133
+
134
+ # Remove time for targets
135
+ upl_y = torch.squeeze(upl_y, 1)
136
+ sur_y = torch.squeeze(sur_y, 1)
137
+
138
+ # We stack along the combined parameter x level dimension
139
+ return_value["x"] = torch.cat(
140
+ (sur_x, upl_x.view(*upl_x.shape[:2], -1, *upl_x.shape[4:])), dim=2
141
+ )
142
+ return_value["y"] = torch.cat(
143
+ (sur_y, upl_y.view(upl_y.shape[0], -1, *upl_y.shape[3:])), dim=1
144
+ )
145
+
146
+ if climate_keys.issubset(data_keys):
147
+ sur_climate = torch.empty((nbatch, *b0["sur_climate"].shape))
148
+ ulv_climate = torch.empty((nbatch, *b0["ulv_climate"].shape))
149
+ for i, rec in enumerate(batch):
150
+ sur_climate[i] = rec["sur_climate"]
151
+ ulv_climate[i] = rec["ulv_climate"]
152
+ sur_climate = pad2d(sur_climate)
153
+ ulv_climate = pad3d(ulv_climate)
154
+
155
+ return_value["climate"] = torch.cat(
156
+ (
157
+ sur_climate,
158
+ ulv_climate.view(nbatch, -1, *ulv_climate.shape[3:]),
159
+ ),
160
+ dim=1,
161
+ )
162
+
163
+ return return_value
164
+
165
+
166
+ def input_scalers(
167
+ surf_vars: list[str],
168
+ vert_vars: list[str],
169
+ levels: list[float],
170
+ surf_path: str | Path,
171
+ vert_path: str | Path,
172
+ ) -> tuple[Tensor, Tensor]:
173
+ """Reads the input scalers
174
+
175
+ Args:
176
+ surf_vars: surface variables to be used.
177
+ vert_vars: vertical variables to be used.
178
+ levels: MERRA2 levels to use.
179
+ surf_path: path to surface scalers file.
180
+ vert_path: path to vertical level scalers file.
181
+
182
+ Returns:
183
+ mu (Tensor): mean values
184
+ var (Tensor): varience values
185
+ """
186
+ with h5py.File(Path(surf_path), "r", libver="latest") as surf_file:
187
+ stats = [x.decode().lower() for x in surf_file["statistic"][()]]
188
+ mu_idx = stats.index("mu")
189
+ sig_idx = stats.index("sigma")
190
+
191
+ s_mu = torch.tensor([surf_file[k][()][mu_idx] for k in surf_vars])
192
+ s_sig = torch.tensor([surf_file[k][()][sig_idx] for k in surf_vars])
193
+
194
+ with h5py.File(Path(vert_path), "r", libver="latest") as vert_file:
195
+ stats = [x.decode().lower() for x in vert_file["statistic"][()]]
196
+ mu_idx = stats.index("mu")
197
+ sig_idx = stats.index("sigma")
198
+
199
+ lvl = vert_file["lev"][()]
200
+ l_idx = [np.where(lvl == v)[0].item() for v in levels]
201
+
202
+ v_mu = np.array([vert_file[k][()][mu_idx, l_idx] for k in vert_vars])
203
+ v_sig = np.array([vert_file[k][()][sig_idx, l_idx] for k in vert_vars])
204
+
205
+ v_mu = torch.from_numpy(v_mu).view(-1)
206
+ v_sig = torch.from_numpy(v_sig).view(-1)
207
+
208
+ mu = torch.cat((s_mu, v_mu), dim=0).to(torch.float32)
209
+ sig = torch.cat((s_sig, v_sig), dim=0).to(torch.float32).clamp(1e-4, 1e4)
210
+ return mu, sig
211
+
212
+
213
+ def static_input_scalers(
214
+ scalar_path: str | Path, stat_vars: list[str], unscaled_params: int = 7
215
+ ) -> tuple[Tensor, Tensor]:
216
+ scalar_path = Path(scalar_path)
217
+
218
+ with h5py.File(scalar_path, "r", libver="latest") as scaler_file:
219
+ stats = [x.decode().lower() for x in scaler_file["statistic"][()]]
220
+ mu_idx = stats.index("mu")
221
+ sig_idx = stats.index("sigma")
222
+
223
+ mu = torch.tensor([scaler_file[k][()][mu_idx] for k in stat_vars])
224
+ sig = torch.tensor([scaler_file[k][()][sig_idx] for k in stat_vars])
225
+
226
+ z = torch.zeros(unscaled_params, dtype=mu.dtype, device=mu.device)
227
+ o = torch.ones(unscaled_params, dtype=sig.dtype, device=sig.device)
228
+ mu = torch.cat((z, mu), dim=0).to(torch.float32)
229
+ sig = torch.cat((o, sig), dim=0).to(torch.float32)
230
+
231
+ return mu, sig.clamp(1e-4, 1e4)
232
+
233
+
234
+ def output_scalers(
235
+ surf_vars: list[str],
236
+ vert_vars: list[str],
237
+ levels: list[float],
238
+ surf_path: str | Path,
239
+ vert_path: str | Path,
240
+ ) -> Tensor:
241
+ surf_path = Path(surf_path)
242
+ vert_path = Path(vert_path)
243
+
244
+ with h5py.File(surf_path, "r", libver="latest") as surf_file:
245
+ svars = torch.tensor([surf_file[k][()] for k in surf_vars])
246
+
247
+ with h5py.File(vert_path, "r", libver="latest") as vert_file:
248
+ lvl = vert_file["lev"][()]
249
+ l_idx = [np.where(lvl == v)[0].item() for v in levels]
250
+ vvars = np.array([vert_file[k][()][l_idx] for k in vert_vars])
251
+ vvars = torch.from_numpy(vvars).view(-1)
252
+
253
+ var = torch.cat((svars, vvars), dim=0).to(torch.float32).clamp(1e-7, 1e7)
254
+
255
+ return var
256
+
257
+
258
+ class SampleSpec:
259
+ """
260
+ A data class to collect the information used to define a sample.
261
+ """
262
+
263
+ def __init__(
264
+ self,
265
+ inputs: tuple[pd.Timestamp, pd.Timestamp],
266
+ lead_time: int,
267
+ target: pd.Timestamp | list[pd.Timestamp],
268
+ ):
269
+ """
270
+ Args:
271
+ inputs: Tuple of timestamps. In ascending order.
272
+ lead_time: Lead time. In hours.
273
+ target: Timestamp of the target. Can be before or after the inputs.
274
+ """
275
+ if not inputs[0] < inputs[1]:
276
+ raise ValueError(
277
+ "Timestamps in `inputs` should be in strictly ascending order."
278
+ )
279
+
280
+ self.inputs = inputs
281
+ self.input_time = (inputs[1] - inputs[0]).total_seconds() / 3600
282
+ self.lead_time = lead_time
283
+ self.target = target
284
+
285
+ self.times = [*inputs, target]
286
+ self.stat_times = [inputs[-1]]
287
+
288
+ @property
289
+ def climatology_info(self) -> tuple[int, int]:
290
+ """Get the required climatology info.
291
+
292
+ :return: information required to obtain climatology data. Essentially
293
+ this is the day of the year and hour of the day of the target
294
+ timestamp, with the former restricted to the interval [1, 365].
295
+ :rtype: tuple
296
+ """
297
+ return (min(self.target.dayofyear, 365), self.target.hour)
298
+
299
+ @property
300
+ def year(self) -> int:
301
+ return self.inputs[1].year
302
+
303
+ @property
304
+ def dayofyear(self) -> int:
305
+ return self.inputs[1].dayofyear
306
+
307
+ @property
308
+ def hourofday(self) -> int:
309
+ return self.inputs[1].hour
310
+
311
+ def _info_str(self) -> str:
312
+ iso_8601 = "%Y-%m-%dT%H:%M:%S"
313
+
314
+ return (
315
+ f"Issue time: {self.inputs[1].strftime(iso_8601)}\n"
316
+ f"Lead time: {self.lead_time} hours ahead\n"
317
+ f"Input delta: {self.input_time} hours\n"
318
+ f"Target time: {self.target.strftime(iso_8601)}"
319
+ )
320
+
321
+ @classmethod
322
+ def get(cls, timestamp: pd.Timestamp, dt: int, lead_time: int):
323
+ """Given a timestamp and lead time, generates a SampleSpec object
324
+ describing the sample further.
325
+
326
+ Args:
327
+ timestamp: Timstamp of the sample, Ie this is the larger of the two
328
+ input timstamps.
329
+ dt: Time between input samples, in hours.
330
+ lead_time: Lead time. In hours.
331
+
332
+ Returns:
333
+ SampleSpec
334
+ """ # noqa: E501
335
+ assert dt > 0, "dt should be possitive"
336
+ lt = pd.to_timedelta(lead_time, unit="h")
337
+ dt = pd.to_timedelta(dt, unit="h")
338
+
339
+ if lead_time >= 0:
340
+ timestamp_target = timestamp + lt
341
+ else:
342
+ timestamp_target = timestamp - dt + lt
343
+
344
+ spec = cls(
345
+ inputs=(timestamp - dt, timestamp),
346
+ lead_time=lead_time,
347
+ target=timestamp_target,
348
+ )
349
+
350
+ return spec
351
+
352
+ def __repr__(self) -> str:
353
+ return self._info_str()
354
+
355
+ def __str__(self) -> str:
356
+ return self._info_str()
357
+
358
+
359
+ class Merra2Dataset(Dataset):
360
+ """MERRA2 dataset. The dataset unifies surface and vertical data as well as
361
+ optional climatology.
362
+
363
+ Samples come in the form of a dictionary. Not all keys support all
364
+ variables, yet the general ordering of dimensions is
365
+ parameter, level, time, lat, lon
366
+
367
+ Note:
368
+ Data is assumed to be in NetCDF files containing daily data at 3-hourly
369
+ intervals. These follow the naming patterns
370
+ MERRA2_sfc_YYYYMMHH.nc and MERRA_pres_YYYYMMHH.nc and can be located in
371
+ two different locations. Optional climatology data comes from files
372
+ climate_surface_doyDOY_hourHOD.nc and
373
+ climate_vertical_doyDOY_hourHOD.nc.
374
+
375
+
376
+ Note:
377
+ `_get_valid_timestamps` assembles a set of all timestamps for which
378
+ there is data (with hourly resolutions). The result is stored in
379
+ `_valid_timestamps`. `_get_valid_climate_timestamps` does the same with
380
+ climatology data and stores it in `_valid_climate_timestamps`.
381
+
382
+ Based on this information, `samples` generates a list of valid samples,
383
+ stored in `samples`. Here the format is::
384
+
385
+ [
386
+ [
387
+ (timestamp 1, lead time A),
388
+ (timestamp 1, lead time B),
389
+ (timestamp 1, lead time C),
390
+ ],
391
+ [
392
+ (timestamp 2, lead time D),
393
+ (timestamp 2, lead time E),
394
+ ]
395
+ ]
396
+
397
+ That is, the outer list iterates over timestamps (init times), the
398
+ inner over lead times. Only valid entries are stored.
399
+ """
400
+
401
+ valid_vertical_vars = [
402
+ "CLOUD",
403
+ "H",
404
+ "OMEGA",
405
+ "PL",
406
+ "QI",
407
+ "QL",
408
+ "QV",
409
+ "T",
410
+ "U",
411
+ "V",
412
+ ]
413
+ valid_surface_vars = [
414
+ "EFLUX",
415
+ "GWETROOT",
416
+ "HFLUX",
417
+ "LAI",
418
+ "LWGAB",
419
+ "LWGEM",
420
+ "LWTUP",
421
+ "PRECTOT",
422
+ "PS",
423
+ "QV2M",
424
+ "SLP",
425
+ "SWGNT",
426
+ "SWTNT",
427
+ "T2M",
428
+ "TQI",
429
+ "TQL",
430
+ "TQV",
431
+ "TS",
432
+ "U10M",
433
+ "V10M",
434
+ "Z0M",
435
+ ]
436
+ valid_static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
437
+
438
+ valid_levels = [
439
+ 34.0,
440
+ 39.0,
441
+ 41.0,
442
+ 43.0,
443
+ 44.0,
444
+ 45.0,
445
+ 48.0,
446
+ 51.0,
447
+ 53.0,
448
+ 56.0,
449
+ 63.0,
450
+ 68.0,
451
+ 71.0,
452
+ 72.0,
453
+ ]
454
+
455
+ timedelta_input = pd.to_timedelta(3, unit="h")
456
+
457
+ def __init__(
458
+ self,
459
+ time_range: tuple[str | pd.Timestamp, str | pd.Timestamp],
460
+ lead_times: list[int],
461
+ input_times: list[int],
462
+ data_path_surface: str | Path,
463
+ data_path_vertical: str | Path,
464
+ climatology_path_surface: str | Path | None = None,
465
+ climatology_path_vertical: str | Path | None = None,
466
+ surface_vars: list[str] | None = None,
467
+ static_surface_vars: list[str] | None = None,
468
+ vertical_vars: list[str] | None = None,
469
+ levels: list[float] | None = None,
470
+ roll_longitudes: int = 0,
471
+ positional_encoding: str = "absolute",
472
+ rtype: type = np.float32,
473
+ dtype: torch.dtype = torch.float32,
474
+ ) -> None:
475
+ """
476
+ Args:
477
+ data_path_surface: Location of surface data.
478
+ data_path_vertical: Location of vertical data.
479
+ climatology_path_surface: Location of (optional) surface
480
+ climatology.
481
+ climatology_path_vertical: Location of (optional) vertical
482
+ climatology.
483
+ surface_vars: Surface variables.
484
+ static_surface_vars: Static surface variables.
485
+ vertical_vars: Vertical variables.
486
+ levels: Levels.
487
+ time_range: Used to subset data.
488
+ lead_times: Lead times for generalized forecasting.
489
+ roll_longitudes: Set to non-zero value to data by random amount
490
+ along longitude dimension.
491
+ position_encoding: possible values are
492
+ ['absolute' (default), 'fourier'].
493
+ 'absolute' returns lat lon encoded in 3 dimensions using sine
494
+ and cosine
495
+ 'fourier' returns lat/lon to be encoded by model
496
+ <any other key> returns lat/lon to be encoded by model
497
+ rtype: numpy data type used during read
498
+ dtype: torch data type of data output
499
+ """
500
+
501
+ self.time_range = (
502
+ pd.to_datetime(time_range[0]),
503
+ pd.to_datetime(time_range[1]),
504
+ )
505
+ self.lead_times = lead_times
506
+ self.input_times = input_times
507
+ self._roll_longitudes = list(range(roll_longitudes + 1))
508
+
509
+ self._uvars = vertical_vars or self.valid_vertical_vars
510
+ self._level = levels or self.valid_levels
511
+ self._svars = surface_vars or self.valid_surface_vars
512
+ self._sstat = static_surface_vars or self.valid_static_surface_vars
513
+ self._nuvars = len(self._uvars)
514
+ self._nlevel = len(self._level)
515
+ self._nsvars = len(self._svars)
516
+ self._nsstat = len(self._sstat)
517
+
518
+ self.rtype = rtype
519
+ self.dtype = dtype
520
+
521
+ self.positional_encoding = positional_encoding
522
+
523
+ self._data_path_surface = Path(data_path_surface)
524
+ self._data_path_vertical = Path(data_path_vertical)
525
+
526
+ self.dir_exists(self._data_path_surface)
527
+ self.dir_exists(self._data_path_vertical)
528
+
529
+ self._get_coordinates()
530
+
531
+ self._climatology_path_surface = Path(climatology_path_surface) or None
532
+ self._climatology_path_vertical = (
533
+ Path(climatology_path_vertical) or None
534
+ )
535
+ self._require_clim = (
536
+ self._climatology_path_surface is not None
537
+ and self._climatology_path_vertical is not None
538
+ )
539
+
540
+ if self._require_clim:
541
+ self.dir_exists(self._climatology_path_surface)
542
+ self.dir_exists(self._climatology_path_vertical)
543
+ elif (
544
+ climatology_path_surface is None
545
+ and climatology_path_vertical is None
546
+ ):
547
+ self._climatology_path_surface = None
548
+ self._climatology_path_vertical = None
549
+ else:
550
+ raise ValueError(
551
+ "Either both or neither of"
552
+ "`climatology_path_surface` and"
553
+ "`climatology_path_vertical` should be None."
554
+ )
555
+
556
+ if not set(self._svars).issubset(set(self.valid_surface_vars)):
557
+ raise ValueError("Invalid surface variable.")
558
+
559
+ if not set(self._sstat).issubset(set(self.valid_static_surface_vars)):
560
+ raise ValueError("Invalid static surface variable.")
561
+
562
+ if not set(self._uvars).issubset(set(self.valid_vertical_vars)):
563
+ raise ValueError("Inalid vertical variable.")
564
+
565
+ if not set(self._level).issubset(set(self.valid_levels)):
566
+ raise ValueError("Invalid level.")
567
+
568
+ @staticmethod
569
+ def dir_exists(path: Path) -> None:
570
+ if not path.is_dir():
571
+ raise ValueError(f"Directory {path} does not exist.")
572
+
573
+ @property
574
+ def upper_shape(self) -> tuple:
575
+ """Returns the vertical variables shape
576
+ Returns:
577
+ tuple: vertical variable shape in the following order::
578
+
579
+ [VAR, LEV, TIME, LAT, LON]
580
+ """
581
+ return self._nuvars, self._nlevel, 2, 361, 576
582
+
583
+ @property
584
+ def surface_shape(self) -> tuple:
585
+ """Returns the surface variables shape
586
+
587
+ Returns:
588
+ tuple: surafce shape in the following order::
589
+
590
+ [VAR, LEV, TIME, LAT, LON]
591
+ """
592
+ return self._nsvars, 2, 361, 576
593
+
594
+ def data_file_surface(self, timestamp: pd.Timestamp) -> Path:
595
+ """Build the surfcae data file name based on timestamp
596
+
597
+ Args:
598
+ timestamp: a timestamp
599
+
600
+ Returns:
601
+ Path: constructed path
602
+ """
603
+ pattern = "MERRA2_sfc_%Y%m%d.nc"
604
+ data_file = self._data_path_surface / timestamp.strftime(pattern)
605
+ return data_file
606
+
607
+ def data_file_vertical(self, timestamp: pd.Timestamp) -> Path:
608
+ """Build the vertical data file name based on timestamp
609
+
610
+ Args:
611
+ timestamp: a timestamp
612
+
613
+ Returns:
614
+ Path: constructed path
615
+ """
616
+ pattern = "MERRA_pres_%Y%m%d.nc"
617
+ data_file = self._data_path_vertical / timestamp.strftime(pattern)
618
+ return data_file
619
+
620
+ def data_file_surface_climate(
621
+ self,
622
+ timestamp: pd.Timestamp | None = None,
623
+ dayofyear: int | None = None,
624
+ hourofday: int | None = None,
625
+ ) -> Path:
626
+ """
627
+ Returns the path to a climatology file based either on a timestamp or
628
+ the dayofyear / hourofday combination.
629
+ Args:
630
+ timestamp: A timestamp.
631
+ dayofyear: Day of the year. 1 to 366.
632
+ hourofday: Hour of the day. 0 to 23.
633
+ Returns:
634
+ Path: Path to climatology file.
635
+ """
636
+ if timestamp is not None and (
637
+ (dayofyear is not None) or (hourofday is not None)
638
+ ):
639
+ raise ValueError(
640
+ "Provide either timestamp or both dayofyear and hourofday."
641
+ )
642
+
643
+ if timestamp is not None:
644
+ dayofyear = min(timestamp.dayofyear, 365)
645
+ hourofday = timestamp.hour
646
+
647
+ file_name = f"climate_surface_doy{dayofyear:03}_hour{hourofday:02}.nc"
648
+ data_file = self._climatology_path_surface / file_name
649
+ return data_file
650
+
651
+ def data_file_vertical_climate(
652
+ self,
653
+ timestamp: pd.Timestamp | None = None,
654
+ dayofyear: int | None = None,
655
+ hourofday: int | None = None,
656
+ ) -> Path:
657
+ """Returns the path to a climatology file based either on a timestamp
658
+ or the dayofyear / hourofday combination.
659
+
660
+ Args:
661
+ timestamp: A timestamp. dayofyear: Day of the year. 1 to 366.
662
+ hourofday: Hour of the day. 0 to 23.
663
+ Returns:
664
+ Path: Path to climatology file.
665
+ """
666
+ if timestamp is not None and (
667
+ (dayofyear is not None) or (hourofday is not None)
668
+ ):
669
+ raise ValueError(
670
+ "Provide either timestamp or both dayofyear and hourofday."
671
+ )
672
+
673
+ if timestamp is not None:
674
+ dayofyear = min(timestamp.dayofyear, 365)
675
+ hourofday = timestamp.hour
676
+
677
+ file_name = f"climate_vertical_doy{dayofyear:03}_hour{hourofday:02}.nc"
678
+ data_file = self._climatology_path_vertical / file_name
679
+ return data_file
680
+
681
+ def _get_coordinates(self) -> None:
682
+ """
683
+ Obtains the coordiantes (latitudes and longitudes) from a single data
684
+ file.
685
+ """
686
+ timestamp = next(iter(self.valid_timestamps))
687
+
688
+ file = self.data_file_surface(timestamp)
689
+ with h5py.File(file, "r", libver="latest") as handle:
690
+ self.lats = lats = handle["lat"][()].astype(self.rtype)
691
+ self.lons = lons = handle["lon"][()].astype(self.rtype)
692
+
693
+ deg_to_rad = np.pi / 180
694
+ self._embed_lat = np.sin(lats * deg_to_rad).reshape(-1, 1)
695
+
696
+ self._embed_lon = np.empty((2, 1, len(lons)), dtype=self.rtype)
697
+ self._embed_lon[0, 0] = np.cos(lons * deg_to_rad)
698
+ self._embed_lon[1, 0] = np.sin(lons * deg_to_rad)
699
+
700
+ @ft.cached_property
701
+ def lats(self) -> np.ndarray:
702
+ timestamp = next(iter(self.valid_timestamps))
703
+
704
+ file = self.data_file_surface(timestamp)
705
+ with h5py.File(file, "r", libver="latest") as handle:
706
+ return handle["lat"][()].astype(self.rtype)
707
+
708
+ @ft.cached_property
709
+ def lons(self) -> np.ndarray:
710
+ timestamp = next(iter(self.valid_timestamps))
711
+
712
+ file = self.data_file_surface(timestamp)
713
+ with h5py.File(file, "r", libver="latest") as handle:
714
+ return handle["lon"][()].astype(self.rtype)
715
+
716
+ @ft.cached_property
717
+ def position_signal(self) -> np.ndarray:
718
+ """Generates the "position signal" that is part of the static
719
+ features.
720
+
721
+ Returns:
722
+ Tensor: Torch tensor of dimension (parameter, lat, lon) containing
723
+ sin(lat), cos(lon), sin(lon).
724
+ """
725
+
726
+ latitudes, longitudes = np.meshgrid(
727
+ self.lats, self.lons, indexing="ij"
728
+ )
729
+
730
+ if self.positional_encoding == "absolute":
731
+ latitudes = latitudes / 360 * 2.0 * np.pi
732
+ longitudes = longitudes / 360 * 2.0 * np.pi
733
+ sur_static = np.stack(
734
+ [np.sin(latitudes), np.cos(longitudes), np.sin(longitudes)],
735
+ axis=0,
736
+ )
737
+ else:
738
+ sur_static = np.stack([latitudes, longitudes], axis=0)
739
+
740
+ sur_static = sur_static.astype(self.rtype)
741
+
742
+ return sur_static
743
+
744
+ @ft.cached_property
745
+ def valid_timestamps(self) -> set[pd.Timestamp]:
746
+ """Generates list of valid timestamps based on available files. Only
747
+ timestamps for which both surface and vertical information is available
748
+ are considered valid.
749
+ Returns:
750
+ list: list of timestamps
751
+ """
752
+
753
+ s_glob = self._data_path_surface.glob("MERRA2_sfc_????????.nc")
754
+ s_files = [os.path.basename(f) for f in s_glob]
755
+ v_glob = self._data_path_surface.glob("MERRA_pres_????????.nc")
756
+ v_files = [os.path.basename(f) for f in v_glob]
757
+
758
+ s_re = re.compile(r"MERRA2_sfc_(\d{8}).nc\Z")
759
+ v_re = re.compile(r"MERRA_pres_(\d{8}).nc\Z")
760
+ fmt = "%Y%m%d"
761
+
762
+ s_times = {
763
+ (datetime.strptime(m[1], fmt))
764
+ for f in s_files
765
+ if (m := s_re.match(f))
766
+ }
767
+ v_times = {
768
+ (datetime.strptime(m[1], fmt))
769
+ for f in v_files
770
+ if (m := v_re.match(f))
771
+ }
772
+
773
+ times = s_times.intersection(v_times)
774
+
775
+ # Each file contains a day at 3 hour intervals
776
+ times = {
777
+ t + timedelta(hours=i) for i in range(0, 24, 3) for t in times
778
+ }
779
+
780
+ start_time, end_time = self.time_range
781
+ times = {pd.Timestamp(t) for t in times if start_time <= t <= end_time}
782
+
783
+ return times
784
+
785
+ @ft.cached_property
786
+ def valid_climate_timestamps(self) -> set[tuple[int, int]]:
787
+ """Generates list of "timestamps" (dayofyear, hourofday) for which
788
+ climatology data is present. Only instances for which surface and
789
+ vertical data is available are considered valid.
790
+ Returns:
791
+ list: List of tuples describing valid climatology instances.
792
+ """
793
+ if not self._require_clim:
794
+ return set()
795
+
796
+ s_glob = self._climatology_path_surface.glob(
797
+ "climate_surface_doy???_hour??.nc"
798
+ )
799
+ s_files = [os.path.basename(f) for f in s_glob]
800
+
801
+ v_glob = self._climatology_path_vertical.glob(
802
+ "climate_vertical_doy???_hour??.nc"
803
+ )
804
+ v_files = [os.path.basename(f) for f in v_glob]
805
+
806
+ s_re = re.compile(r"climate_surface_doy(\d{3})_hour(\d{2}).nc\Z")
807
+ v_re = re.compile(r"climate_vertical_doy(\d{3})_hour(\d{2}).nc\Z")
808
+
809
+ s_times = {
810
+ (int(m[1]), int(m[2])) for f in s_files if (m := s_re.match(f))
811
+ }
812
+ v_times = {
813
+ (int(m[1]), int(m[2])) for f in v_files if (m := v_re.match(f))
814
+ }
815
+
816
+ times = s_times.intersection(v_times)
817
+
818
+ return times
819
+
820
+ def _data_available(self, spec: SampleSpec) -> bool:
821
+ """
822
+ Checks whether data is available for a given SampleSpec object. Does so
823
+ using the internal sets with available data previously constructed. Not
824
+ by checking the file system.
825
+ Args:
826
+ spec: SampleSpec object as returned by SampleSpec.get
827
+ Returns:
828
+ bool: if data is availability.
829
+ """
830
+ valid = set(spec.times).issubset(self.valid_timestamps)
831
+
832
+ if self._require_clim:
833
+ sci = spec.climatology_info
834
+ ci = set(sci) if isinstance(sci, list) else set([sci]) # noqa: C405
835
+ valid &= ci.issubset(self.valid_climate_timestamps)
836
+
837
+ return valid
838
+
839
+ @ft.cached_property
840
+ def samples(self) -> list[tuple[pd.Timestamp, int, int]]:
841
+ """
842
+ Generates list of all valid samlpes.
843
+ Returns:
844
+ list: List of tuples (timestamp, input time, lead time).
845
+ """
846
+ valid_samples = []
847
+ dts = [(it, lt) for it in self.input_times for lt in self.lead_times]
848
+
849
+ for timestamp in sorted(self.valid_timestamps):
850
+ timestamp_samples = []
851
+ for it, lt in dts:
852
+ spec = SampleSpec.get(timestamp, -it, lt)
853
+
854
+ if self._data_available(spec):
855
+ timestamp_samples.append((timestamp, it, lt))
856
+
857
+ if timestamp_samples:
858
+ valid_samples.append(timestamp_samples)
859
+
860
+ return valid_samples
861
+
862
+ def _to_torch(
863
+ self,
864
+ data: dict[str, Tensor | list[Tensor]],
865
+ dtype: torch.dtype = torch.float32,
866
+ ) -> dict[str, Tensor | list[Tensor]]:
867
+ out = {}
868
+ for k, v in data.items():
869
+ if isinstance(v, list):
870
+ out[k] = [torch.from_numpy(x).to(dtype) for x in v]
871
+ else:
872
+ out[k] = torch.from_numpy(v).to(dtype)
873
+
874
+ return out
875
+
876
+ def _lat_roll(
877
+ self, data: dict[str, Tensor | list[Tensor]], n: int
878
+ ) -> dict[str, Tensor | list[Tensor]]:
879
+ out = {}
880
+ for k, v in data.items():
881
+ if isinstance(v, list):
882
+ out[k] = [torch.roll(x, shifts=n, dims=-1) for x in v]
883
+ else:
884
+ out[k] = torch.roll(v, shifts=n, dims=-1)
885
+
886
+ return out
887
+
888
+ def _read_static_data(
889
+ self, file: str | Path, doy: int, hod: int
890
+ ) -> np.ndarray:
891
+ with h5py.File(file, "r", libver="latest") as handle:
892
+ lats_surf = handle["lat"]
893
+ lons_surf = handle["lon"]
894
+
895
+ nll = (len(lats_surf), len(lons_surf))
896
+
897
+ npos = len(self.position_signal)
898
+ ntime = 4
899
+
900
+ nstat = npos + ntime + self._nsstat
901
+ data = np.empty((nstat, *nll), dtype=self.rtype)
902
+
903
+ for i, key in enumerate(self._sstat, start=npos + ntime):
904
+ data[i] = handle[key][()].astype(dtype=self.rtype)
905
+
906
+ # [possition signal], cos(doy), sin(doy), cos(hod), sin(hod)
907
+ data[0:npos] = self.position_signal
908
+ data[npos + 0] = np.cos(2 * np.pi * doy / 366)
909
+ data[npos + 1] = np.sin(2 * np.pi * doy / 366)
910
+ data[npos + 2] = np.cos(2 * np.pi * hod / 24)
911
+ data[npos + 3] = np.sin(2 * np.pi * hod / 24)
912
+
913
+ return data
914
+
915
+ def _read_surface(
916
+ self, tidx: int, nll: tuple[int, int], handle: h5py.File
917
+ ) -> np.ndarray:
918
+ data = np.empty((self._nsvars, *nll), dtype=self.rtype)
919
+
920
+ for i, key in enumerate(self._svars):
921
+ data[i] = handle[key][tidx][()].astype(dtype=self.rtype)
922
+
923
+ return data
924
+
925
+ def _read_levels(
926
+ self, tidx: int, nll: tuple[int, int], handle: h5py.File
927
+ ) -> np.ndarray:
928
+ lvls = handle["lev"][()]
929
+ lidx = self._level_idxs(lvls)
930
+
931
+ data = np.empty((self._nuvars, self._nlevel, *nll), dtype=self.rtype)
932
+
933
+ for i, key in enumerate(self._uvars):
934
+ data[i] = handle[key][tidx, lidx][()].astype(dtype=self.rtype)
935
+
936
+ return np.ascontiguousarray(np.flip(data, axis=1))
937
+
938
+ def _level_idxs(self, lvls):
939
+ lidx = [np.argwhere(lvls == int(lvl)).item() for lvl in self._level]
940
+ return sorted(lidx)
941
+
942
+ @staticmethod
943
+ def _date_to_tidx(date: datetime | pd.Timestamp, handle: h5py.File) -> int:
944
+ if isinstance(date, pd.Timestamp):
945
+ date = date.to_pydatetime()
946
+
947
+ time = handle["time"]
948
+
949
+ t0 = time.attrs["begin_time"][()].item()
950
+ d0 = f"{time.attrs['begin_date'][()].item()}"
951
+
952
+ offset = datetime.strptime(d0, "%Y%m%d")
953
+
954
+ times = [offset + timedelta(minutes=int(t + t0)) for t in time[()]]
955
+ return times.index(date)
956
+
957
+ def _read_data(
958
+ self, file_pair: tuple[str, str], date: datetime
959
+ ) -> dict[str, np.ndarray]:
960
+ s_file, v_file = file_pair
961
+
962
+ with h5py.File(s_file, "r", libver="latest") as shandle:
963
+ lats_surf = shandle["lat"]
964
+ lons_surf = shandle["lon"]
965
+
966
+ nll = (len(lats_surf), len(lons_surf))
967
+
968
+ tidx = self._date_to_tidx(date, shandle)
969
+
970
+ sdata = self._read_surface(tidx, nll, shandle)
971
+
972
+ with h5py.File(v_file, "r", libver="latest") as vhandle:
973
+ lats_vert = vhandle["lat"]
974
+ lons_vert = vhandle["lon"]
975
+
976
+ nll = (len(lats_vert), len(lons_vert))
977
+
978
+ tidx = self._date_to_tidx(date, vhandle)
979
+
980
+ vdata = self._read_levels(tidx, nll, vhandle)
981
+
982
+ data = {"vert": vdata, "surf": sdata}
983
+
984
+ return data
985
+
986
+ def _read_climate(
987
+ self, file_pair: tuple[str, str]
988
+ ) -> dict[str, np.ndarray]:
989
+ s_file, v_file = file_pair
990
+
991
+ with h5py.File(s_file, "r", libver="latest") as shandle:
992
+ lats_surf = shandle["lat"]
993
+ lons_surf = shandle["lon"]
994
+
995
+ nll = (len(lats_surf), len(lons_surf))
996
+
997
+ sdata = np.empty((self._nsvars, *nll), dtype=self.rtype)
998
+
999
+ for i, key in enumerate(self._svars):
1000
+ sdata[i] = shandle[key][()].astype(dtype=self.rtype)
1001
+
1002
+ with h5py.File(v_file, "r", libver="latest") as vhandle:
1003
+ lats_vert = vhandle["lat"]
1004
+ lons_vert = vhandle["lon"]
1005
+
1006
+ nll = (len(lats_vert), len(lons_vert))
1007
+
1008
+ lvls = vhandle["lev"][()]
1009
+ lidx = self._level_idxs(lvls)
1010
+
1011
+ vdata = np.empty(
1012
+ (self._nuvars, self._nlevel, *nll), dtype=self.rtype
1013
+ )
1014
+
1015
+ for i, key in enumerate(self._uvars):
1016
+ vdata[i] = vhandle[key][lidx][()].astype(dtype=self.rtype)
1017
+
1018
+ data = {
1019
+ "vert": np.ascontiguousarray(np.flip(vdata, axis=1)),
1020
+ "surf": sdata,
1021
+ }
1022
+
1023
+ return data
1024
+
1025
+ def get_data_from_sample_spec(
1026
+ self, spec: SampleSpec
1027
+ ) -> dict[str, Tensor | int | float]:
1028
+ """Loads and assembles sample data given a SampleSpec object.
1029
+
1030
+ Args:
1031
+ spec (SampleSpec): Full details regarding the data to be loaded
1032
+ Returns:
1033
+ dict: Dictionary with the following keys::
1034
+
1035
+ 'sur_static': Torch tensor of shape [parameter, lat, lon]. For
1036
+ each pixel (lat, lon), the first 7 dimensions index sin(lat),
1037
+ cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod).
1038
+ Where doy is the day of the year [1, 366] and hod the hour of
1039
+ the day [0, 23].
1040
+ 'sur_vals': Torch tensor of shape [parameter, time, lat, lon].
1041
+ 'sur_tars': Torch tensor of shape [parameter, time, lat, lon].
1042
+ 'ulv_vals': Torch tensor of shape [parameter, level, time, lat, lon].
1043
+ 'ulv_tars': Torch tensor of shape [parameter, level, time, lat, lon].
1044
+ 'sur_climate': Torch tensor of shape [parameter, lat, lon].
1045
+ 'ulv_climate': Torch tensor of shape [paramter, level, lat, lon].
1046
+ 'lead_time': Float.
1047
+ 'input_time': Float.
1048
+
1049
+ """ # noqa: E501
1050
+
1051
+ # We assemble the unique timestamps for which we need data.
1052
+ vals_required = {*spec.times}
1053
+ stat_required = {*spec.stat_times}
1054
+
1055
+ # We assemble the unique data files from which we need value data
1056
+ vals_file_map = defaultdict(list)
1057
+ for t in vals_required:
1058
+ data_files = (
1059
+ self.data_file_surface(t),
1060
+ self.data_file_vertical(t),
1061
+ )
1062
+ vals_file_map[data_files].append(t)
1063
+
1064
+ # We assemble the unique data files from which we need static data
1065
+ stat_file_map = defaultdict(list)
1066
+ for t in stat_required:
1067
+ data_files = (
1068
+ self.data_file_surface(t),
1069
+ self.data_file_vertical(t),
1070
+ )
1071
+ stat_file_map[data_files].append(t)
1072
+
1073
+ # Load the value data
1074
+ data = {}
1075
+ for data_files, times in vals_file_map.items():
1076
+ for time in times:
1077
+ data[time] = self._read_data(data_files, time)
1078
+
1079
+ # Combine times
1080
+ sample_data = {}
1081
+
1082
+ input_upl = np.stack([data[t]["vert"] for t in spec.inputs], axis=2)
1083
+ sample_data["ulv_vals"] = input_upl
1084
+
1085
+ target_upl = data[spec.target]["vert"]
1086
+ sample_data["ulv_tars"] = target_upl[:, :, None]
1087
+
1088
+ input_sur = np.stack([data[t]["surf"] for t in spec.inputs], axis=1)
1089
+ sample_data["sur_vals"] = input_sur
1090
+
1091
+ target_sur = data[spec.target]["surf"]
1092
+ sample_data["sur_tars"] = target_sur[:, None]
1093
+
1094
+ # Load the static data
1095
+ data_files, times = stat_file_map.popitem()
1096
+ time = times[0].dayofyear, times[0].hour
1097
+ sample_data["sur_static"] = self._read_static_data(
1098
+ data_files[0], *time
1099
+ )
1100
+
1101
+ # If required load the surface data
1102
+ if self._require_clim:
1103
+ ci_year, ci_hour = spec.climatology_info
1104
+
1105
+ surf_file = self.data_file_surface_climate(
1106
+ dayofyear=ci_year,
1107
+ hourofday=ci_hour,
1108
+ )
1109
+
1110
+ vert_file = self.data_file_vertical_climate(
1111
+ dayofyear=ci_year,
1112
+ hourofday=ci_hour,
1113
+ )
1114
+
1115
+ clim_data = self._read_climate((surf_file, vert_file))
1116
+
1117
+ sample_data["sur_climate"] = clim_data["surf"]
1118
+ sample_data["ulv_climate"] = clim_data["vert"]
1119
+
1120
+ # Move the data from numpy to torch
1121
+ sample_data = self._to_torch(sample_data, dtype=self.dtype)
1122
+
1123
+ # Optionally roll
1124
+ if len(self._roll_longitudes) > 0:
1125
+ roll_by = random.choice(self._roll_longitudes)
1126
+ sample_data = self._lat_roll(sample_data, roll_by)
1127
+
1128
+ # Now that we have rolled, we can add the static data
1129
+ sample_data["lead_time"] = spec.lead_time
1130
+ sample_data["input_time"] = spec.input_time
1131
+
1132
+ return sample_data
1133
+
1134
+ def get_data(
1135
+ self, timestamp: pd.Timestamp, input_time: int, lead_time: int
1136
+ ) -> dict[str, Tensor | int]:
1137
+ """
1138
+ Loads data based on timestamp and lead time.
1139
+ Args:
1140
+ timestamp: Timestamp.
1141
+ input_time: time between input samples.
1142
+ lead_time: lead time.
1143
+ Returns:
1144
+ Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',
1145
+ 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',
1146
+ 'lead_time'.
1147
+ """
1148
+ spec = SampleSpec.get(timestamp, -input_time, lead_time)
1149
+ sample_data = self.get_data_from_sample_spec(spec)
1150
+ return sample_data
1151
+
1152
+ def __getitem__(self, idx: int) -> dict[str, Tensor | int]:
1153
+ """
1154
+ Loads data based on sample index and random choice of sample.
1155
+ Args:
1156
+ idx: Sample index.
1157
+ Returns:
1158
+ Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',
1159
+ 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',
1160
+ 'lead_time', 'input_time'.
1161
+ """
1162
+ sample_set = self.samples[idx]
1163
+ timestamp, input_time, lead_time, *nsteps = random.choice(sample_set)
1164
+ sample_data = self.get_data(timestamp, input_time, lead_time)
1165
+ return sample_data
1166
+
1167
+ def __len__(self):
1168
+ return len(self.samples)
PrithviWxC/dataloaders/merra2_rollout.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools as ft
2
+ import random
3
+ from collections import defaultdict
4
+ from copy import deepcopy
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import pandas as pd
9
+ import torch
10
+ from torch import Tensor
11
+
12
+ from PrithviWxC.dataloaders.merra2 import Merra2Dataset, SampleSpec
13
+
14
+
15
+ def preproc(
16
+ batch: list[dict[str, int | float | Tensor]], padding: dict[tuple[int]]
17
+ ) -> dict[str, Tensor]:
18
+ """Prepressing function for MERRA2 Dataset
19
+
20
+ Args:
21
+ batch (dict): List of training samples, each sample should be a
22
+ dictionary with the following keys::
23
+
24
+ 'sur_static': Numpy array of shape (3, lat, lon). For each pixel (lat, lon), the first dimension indexes sin(lat), cos(lon), sin(lon).
25
+ 'sur_vals': Torch tensor of shape (parameter, time, lat, lon).
26
+ 'sur_tars': Torch tensor of shape (parameter, time, lat, lon).
27
+ 'ulv_vals': Torch tensor of shape (parameter, level, time, lat, lon).
28
+ 'ulv_tars': Torch tensor of shape (parameter, level, time, lat, lon).
29
+ 'sur_climate': Torch tensor of shape (nstep, parameter, lat, lon)
30
+ 'ulv_climate': Torch tensor of shape (nstep parameter, level, lat, lon)
31
+ 'lead_time': Integer.
32
+ 'input_time': Interger
33
+
34
+ padding: Dictionary with keys 'level', 'lat', 'lon', each of dim 2.
35
+
36
+ Returns:
37
+ Dictionary with the following keys::
38
+
39
+ 'x': [batch, time, parameter, lat, lon]
40
+ 'ys': [batch, nsteps, parameter, lat, lon]
41
+ 'static': [batch, nstep, parameter, lat, lon]
42
+ 'lead_time': [batch]
43
+ 'input_time': [batch]
44
+ 'climate (Optional)': [batch, nsteps, parameter, lat, lon]
45
+
46
+ Note:
47
+ Here, for x and ys, 'parameter' is [surface parameter, upper level,
48
+ parameter x level]. Similarly for the static information we have
49
+ [sin(lat), cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod),
50
+ ...].
51
+ """ # noqa: E501
52
+
53
+ b0 = batch[0]
54
+ nbatch = len(batch)
55
+ data_keys = set(b0.keys())
56
+
57
+ essential_keys = {
58
+ "sur_static",
59
+ "sur_vals",
60
+ "sur_tars",
61
+ "ulv_vals",
62
+ "ulv_tars",
63
+ "input_time",
64
+ "lead_time",
65
+ }
66
+
67
+ climate_keys = {
68
+ "sur_climate",
69
+ "ulv_climate",
70
+ }
71
+
72
+ all_keys = essential_keys | climate_keys
73
+
74
+ if not essential_keys.issubset(data_keys):
75
+ raise ValueError("Missing essential keys.")
76
+
77
+ if not data_keys.issubset(all_keys):
78
+ raise ValueError("Unexpected keys in batch.")
79
+
80
+ # Bring all tensors from the batch into a single tensor
81
+ upl_x = torch.empty((nbatch, *b0["ulv_vals"].shape))
82
+ upl_y = torch.empty((nbatch, *b0["ulv_tars"].shape))
83
+
84
+ sur_x = torch.empty((nbatch, *b0["sur_vals"].shape))
85
+ sur_y = torch.empty((nbatch, *b0["sur_tars"].shape))
86
+
87
+ sur_sta = torch.empty((nbatch, *b0["sur_static"].shape))
88
+
89
+ lead_time = torch.empty(
90
+ (nbatch, *b0["lead_time"].shape),
91
+ dtype=torch.float32,
92
+ )
93
+ input_time = torch.empty((nbatch,), dtype=torch.float32)
94
+
95
+ for i, rec in enumerate(batch):
96
+ sur_x[i] = torch.Tensor(rec["sur_vals"])
97
+ sur_y[i] = torch.Tensor(rec["sur_tars"])
98
+
99
+ upl_x[i] = torch.Tensor(rec["ulv_vals"])
100
+ upl_y[i] = torch.Tensor(rec["ulv_tars"])
101
+
102
+ sur_sta[i] = torch.Tensor(rec["sur_static"])
103
+
104
+ lead_time[i] = rec["lead_time"]
105
+ input_time[i] = rec["input_time"]
106
+
107
+ return_value = {
108
+ "lead_time": lead_time,
109
+ "input_time": input_time,
110
+ "target_time": torch.sum(lead_time).reshape(-1),
111
+ }
112
+
113
+ # Reshape (batch, parameter, level, time, lat, lon)
114
+ # -> (batch, time, parameter, level, lat, lon)
115
+ upl_x = upl_x.permute((0, 3, 1, 2, 4, 5))
116
+ upl_y = upl_y.permute((0, 3, 1, 2, 4, 5))
117
+
118
+ # Reshape (batch, parameter, time, lat, lon)
119
+ # -> (batch, time, parameter, lat, lon)
120
+ sur_x = sur_x.permute((0, 2, 1, 3, 4))
121
+ sur_y = sur_y.permute((0, 2, 1, 3, 4))
122
+
123
+ # Pad
124
+ padding_2d = (*padding["lon"], *padding["lat"])
125
+
126
+ def pad2d(x):
127
+ return torch.nn.functional.pad(x, padding_2d, mode="constant", value=0)
128
+
129
+ padding_3d = (*padding["lon"], *padding["lat"], *padding["level"])
130
+
131
+ def pad3d(x):
132
+ return torch.nn.functional.pad(x, padding_3d, mode="constant", value=0)
133
+
134
+ sur_x = pad2d(sur_x).contiguous()
135
+ upl_x = pad3d(upl_x).contiguous()
136
+ sur_y = pad2d(sur_y).contiguous()
137
+ upl_y = pad3d(upl_y).contiguous()
138
+ return_value["statics"] = pad2d(sur_sta).contiguous()
139
+
140
+ # We stack along the combined parameter level dimension
141
+ return_value["x"] = torch.cat(
142
+ (sur_x, upl_x.view(*upl_x.shape[:2], -1, *upl_x.shape[4:])), dim=2
143
+ )
144
+ return_value["ys"] = torch.cat(
145
+ (sur_y, upl_y.view(*upl_y.shape[:2], -1, *upl_y.shape[4:])), dim=2
146
+ )
147
+
148
+ if climate_keys.issubset(data_keys):
149
+ sur_climate = torch.empty((nbatch, *b0["sur_climate"].shape))
150
+ ulv_climate = torch.empty((nbatch, *b0["ulv_climate"].shape))
151
+ for i, rec in enumerate(batch):
152
+ sur_climate[i] = rec["sur_climate"]
153
+ ulv_climate[i] = rec["ulv_climate"]
154
+ sur_climate = pad2d(sur_climate)
155
+ ulv_climate = pad3d(ulv_climate)
156
+
157
+ ulv_climate = ulv_climate.view(
158
+ *ulv_climate.shape[:2], -1, *ulv_climate.shape[4:]
159
+ )
160
+ return_value["climates"] = torch.cat((sur_climate, ulv_climate), dim=2)
161
+
162
+ return return_value
163
+
164
+
165
+ class RolloutSpec(SampleSpec):
166
+ """
167
+ A data class to collect the information used to define a rollout sample.
168
+ """
169
+
170
+ def __init__(
171
+ self,
172
+ inputs: tuple[pd.Timestamp, pd.Timestamp],
173
+ lead_time: int,
174
+ target: pd.Timestamp,
175
+ ):
176
+ """
177
+ Args:
178
+ inputs: Tuple of timestamps. In ascending order.
179
+ lead_time: Lead time. In hours.
180
+ target: Timestamp of the target. Can be before or after the inputs.
181
+ """
182
+ super().__init__(inputs, lead_time, target)
183
+
184
+ self.dt = dt = pd.Timedelta(lead_time, unit="h")
185
+ self.inters = list(pd.date_range(inputs[-1], target, freq=dt))
186
+
187
+ self._ctimes = deepcopy(self.inters)
188
+ self.stat_times = deepcopy(self.inters)
189
+
190
+ self.stat_times.pop(-1)
191
+ self._ctimes.pop(0)
192
+ self.inters.pop(0)
193
+ self.inters.pop(-1)
194
+
195
+ self.times = [*inputs, *self.inters, target]
196
+ self.targets = self.times[2:]
197
+ self.nsteps = len(self.times) - 2
198
+
199
+ @property
200
+ def climatology_info(self) -> dict[pd.Timestamp, tuple[int, int]]:
201
+ """Returns information required to obtain climatology data.
202
+ Returns:
203
+ list: list containing required climatology info.
204
+ """
205
+ return [(min(t.dayofyear, 365), t.hour) for t in self._ctimes]
206
+
207
+ def _info_str(self) -> str:
208
+ iso_8601 = "%Y-%m-%dT%H:%M:%S"
209
+
210
+ inter_str = "\n".join(t.strftime(iso_8601) for t in self.inters)
211
+
212
+ return (
213
+ f"Issue time: {self.inputs[1].strftime(iso_8601)}\n"
214
+ f"Lead time: {self.lead_time} hours ahead\n"
215
+ f"Target time: {self.target.strftime(iso_8601)}\n"
216
+ f"Intermediate times: {inter_str}"
217
+ )
218
+
219
+ @classmethod
220
+ def get(cls, timestamp: pd.Timestamp, lead_time: int, nsteps: int):
221
+ """Given a timestamp and lead time, generates a RolloutSpec object
222
+ describing the sample further.
223
+
224
+ Args:
225
+ timestamp: Timstamp (issue time) of the sample.
226
+ lead_time: Lead time. In hours.
227
+
228
+ Returns:
229
+ SampleSpec object.
230
+ """
231
+ if lead_time > 0:
232
+ dt = pd.to_timedelta(lead_time, unit="h")
233
+ timestamp_target = timestamp + nsteps * dt
234
+ else:
235
+ raise ValueError("Rollout is only forwards")
236
+
237
+ spec = cls(
238
+ inputs=(timestamp - dt, timestamp),
239
+ lead_time=lead_time,
240
+ target=timestamp_target,
241
+ )
242
+
243
+ return spec
244
+
245
+ def __repr__(self) -> str:
246
+ return self._info_str()
247
+
248
+ def __str__(self) -> str:
249
+ return self._info_str()
250
+
251
+
252
+ class Merra2RolloutDataset(Merra2Dataset):
253
+ """Dataset class that read MERRA2 data for performing rollout.
254
+
255
+ Implementation details::
256
+
257
+ Samples stores the list of valid samples. This takes the form
258
+ ```
259
+ [
260
+ [(timestamp 1, -input_time, n_steps)],
261
+ [(timestamp 2, -input_time, n_steps)],
262
+ ]
263
+ ```
264
+ The nested list is for compatibility reasons with Merra2Dataset. Note
265
+ that input time and n_steps are always the same value. For some reason
266
+ the sign of input_time is the opposite to that in Merra2Dataset
267
+ """
268
+
269
+ input_time_len = 2
270
+
271
+ def __init__(
272
+ self,
273
+ time_range: tuple[str | pd.Timestamp, str | pd.Timestamp],
274
+ input_time: int | float | pd.Timedelta,
275
+ lead_time: int | float,
276
+ data_path_surface: str | Path,
277
+ data_path_vertical: str | Path,
278
+ climatology_path_surface: str | Path | None,
279
+ climatology_path_vertical: str | Path | None,
280
+ surface_vars: list[str],
281
+ static_surface_vars: list[str],
282
+ vertical_vars: list[str],
283
+ levels: list[float],
284
+ roll_longitudes: int = 0,
285
+ positional_encoding: str = "absolute",
286
+ ):
287
+ """
288
+ Args:
289
+ time_range: time range to consider when building dataset
290
+ input_time: requested time between inputs
291
+ lead_time: requested time to predict
292
+ data_path_surface: path of surface data directory
293
+ data_path_vertical: path of vertical data directory
294
+ climatology_path_surface: path of surface climatology data
295
+ directory
296
+ climatology_path_vertical: path of vertical climatology data
297
+ directory
298
+ surface_vars: surface variables to return
299
+ static_surface_vars: static surface variables to return
300
+ vertical_vars: vertical variables to return
301
+ levels: MERA2 vertical levels to consider
302
+ roll_longitudes: Whether and now uch to randomly roll latitudes by.
303
+ Defaults to 0.
304
+ positional_encoding: The type of possitional encodeing to use.
305
+ Defaults to "absolute".
306
+
307
+ Raises:
308
+ ValueError: If lead time is not integer multiple of input time
309
+ """
310
+
311
+ self._target_lead = lead_time
312
+
313
+ if isinstance(input_time, int) or isinstance(input_time, float):
314
+ self.timedelta_input = pd.to_timedelta(-input_time, unit="h")
315
+ else:
316
+ self.timedelta_input = -input_time
317
+
318
+ lead_times = [self.timedelta_input / pd.to_timedelta(1, unit="h")]
319
+
320
+ super().__init__(
321
+ time_range,
322
+ lead_times,
323
+ [input_time],
324
+ data_path_surface,
325
+ data_path_vertical,
326
+ climatology_path_surface,
327
+ climatology_path_vertical,
328
+ surface_vars,
329
+ static_surface_vars,
330
+ vertical_vars,
331
+ levels,
332
+ roll_longitudes,
333
+ positional_encoding,
334
+ )
335
+
336
+ nstep_float = (
337
+ pd.to_timedelta(self._target_lead, unit="h") / self.timedelta_input
338
+ )
339
+
340
+ if abs(nstep_float % 1) > 1e-5:
341
+ raise ValueError("Leadtime not multiple of input time")
342
+
343
+ self.nsteps = round(nstep_float)
344
+
345
+ @ft.cached_property
346
+ def samples(self) -> list[tuple[pd.Timestamp, int, int]]:
347
+ """Generates list of all valid samlpes.
348
+
349
+ Returns:
350
+ List of tuples (timestamp, input time, lead time).
351
+ """
352
+ valid_samples = []
353
+
354
+ for timestamp in sorted(self.valid_timestamps):
355
+ timestamp_samples = []
356
+ for lt in self.lead_times:
357
+ spec = RolloutSpec.get(timestamp, lt, self.nsteps)
358
+
359
+ if self._data_available(spec):
360
+ timestamp_samples.append(
361
+ (timestamp, self.input_times[0], lt, self.nsteps)
362
+ )
363
+
364
+ if timestamp_samples:
365
+ valid_samples.append(timestamp_samples)
366
+
367
+ return valid_samples
368
+
369
+ def get_data_from_rollout_spec(
370
+ self, spec: RolloutSpec
371
+ ) -> dict[str, Tensor | int | float]:
372
+ """Loads and assembles sample data given a RolloutSpec object.
373
+
374
+ Args:
375
+ spec (RolloutSpec): Full details regarding the data to be loaded
376
+ Returns:
377
+ dict: Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',
378
+ 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',c'lead_time',
379
+ 'input_time'. For each, the value is as follows::
380
+
381
+ {
382
+ 'sur_static': Torch tensor of shape [parameter, lat, lon]. For
383
+ each pixel (lat, lon), the first 7 dimensions index sin(lat),
384
+ cos(lon), sin(lon), cos(doy), sin(doy), cos(hod), sin(hod).
385
+ Where doy is the day of the year [1, 366] and hod the hour of
386
+ the day [0, 23].
387
+ 'sur_vals': Torch tensor of shape [parameter, time, lat, lon].
388
+ 'sur_tars': Torch tensor of shape [parameter, time, lat, lon].
389
+ 'ulv_vals': Torch tensor of shape
390
+ [parameter, level, time, lat, lon].
391
+ 'ulv_tars': Torch tensor of shape
392
+ [nsteps, parameter, level, time, lat, lon].
393
+ 'sur_climate': Torch tensor of shape
394
+ [nsteps, parameter, lat, lon].
395
+ 'ulv_climate': Torch tensor of shape
396
+ [nsteps, paramter, level, lat, lon].
397
+ 'lead_time': Float.
398
+ 'input_time': Float.
399
+ }
400
+
401
+ """
402
+
403
+ # We assemble the unique timestamps for which we need data.
404
+ vals_required = {*spec.times}
405
+ stat_required = {*spec.stat_times}
406
+
407
+ # We assemble the unique data files from which we need value data
408
+ vals_file_map = defaultdict(list)
409
+ for t in vals_required:
410
+ data_files = (
411
+ self.data_file_surface(t),
412
+ self.data_file_vertical(t),
413
+ )
414
+ vals_file_map[data_files].append(t)
415
+
416
+ # We assemble the unique data files from which we need static data
417
+ stat_file_map = defaultdict(list)
418
+ for t in stat_required:
419
+ data_files = (
420
+ self.data_file_surface(t),
421
+ self.data_file_vertical(t),
422
+ )
423
+ stat_file_map[data_files].append(t)
424
+
425
+ # Load the value data
426
+ data = {}
427
+ for data_files, times in vals_file_map.items():
428
+ for time in times:
429
+ data[time] = self._read_data(data_files, time)
430
+
431
+ # Load the static data
432
+ stat = {}
433
+ for data_files, times in stat_file_map.items():
434
+ for time in times:
435
+ hod, doy = time.hour, time.dayofyear
436
+ stat[time] = self._read_static_data(data_files[0], hod, doy)
437
+
438
+ # Combine times
439
+ sample_data = {}
440
+
441
+ input_upl = np.stack([data[t]["vert"] for t in spec.inputs], axis=2)
442
+ sample_data["ulv_vals"] = input_upl
443
+
444
+ target_upl = np.stack([data[t]["vert"] for t in spec.targets], axis=2)
445
+ sample_data["ulv_tars"] = target_upl
446
+
447
+ input_sur = np.stack([data[t]["surf"] for t in spec.inputs], axis=1)
448
+ sample_data["sur_vals"] = input_sur
449
+
450
+ target_sur = np.stack([data[t]["surf"] for t in spec.targets], axis=1)
451
+ sample_data["sur_tars"] = target_sur
452
+
453
+ # Load the static data
454
+ static = np.stack([stat[t] for t in spec.stat_times], axis=0)
455
+ sample_data["sur_static"] = static
456
+
457
+ # If required load the climate data
458
+ if self._require_clim:
459
+ clim_data = {}
460
+ for ci in spec.climatology_info:
461
+ ci_year, ci_hour = ci
462
+
463
+ surf_file = self.data_file_surface_climate(
464
+ dayofyear=ci_year,
465
+ hourofday=ci_hour,
466
+ )
467
+
468
+ vert_file = self.data_file_vertical_climate(
469
+ dayofyear=ci_year,
470
+ hourofday=ci_hour,
471
+ )
472
+
473
+ clim_data[ci] = self._read_climate((surf_file, vert_file))
474
+
475
+ clim_surf = [clim_data[ci]["surf"] for ci in spec.climatology_info]
476
+ sample_data["sur_climate"] = np.stack(clim_surf, axis=0)
477
+
478
+ clim_surf = [clim_data[ci]["vert"] for ci in spec.climatology_info]
479
+ sample_data["ulv_climate"] = np.stack(clim_surf, axis=0)
480
+
481
+ # Move the data from numpy to torch
482
+ sample_data = self._to_torch(sample_data, dtype=self.dtype)
483
+
484
+ # Optionally roll
485
+ if len(self._roll_longitudes) > 0:
486
+ roll_by = random.choice(self._roll_longitudes)
487
+ sample_data = self._lat_roll(sample_data, roll_by)
488
+
489
+ # Now that we have rolled, we can add the static data
490
+ lt = torch.tensor([spec.lead_time] * self.nsteps).to(self.dtype)
491
+ sample_data["lead_time"] = lt
492
+ sample_data["input_time"] = spec.input_time
493
+
494
+ return sample_data
495
+
496
+ def get_data(
497
+ self, timestamp: pd.Timestamp, *args, **kwargs
498
+ ) -> dict[Tensor | int]:
499
+ """Loads data based on timestamp and lead time.
500
+
501
+ Args:
502
+ timestamp: Timestamp.
503
+ Returns:
504
+ Dictionary with keys 'sur_static', 'sur_vals', 'sur_tars',
505
+ 'ulv_vals', 'ulv_tars', 'sur_climate', 'ulv_climate',
506
+ 'lead_time', 'input_time'
507
+ """
508
+ rollout_spec = RolloutSpec.get(
509
+ timestamp, self.lead_times[0], self.nsteps
510
+ )
511
+ sample_data = self.get_data_from_rollout_spec(rollout_spec)
512
+ return sample_data
PrithviWxC/model.py ADDED
@@ -0,0 +1,1321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import cached_property
2
+ from importlib.metadata import version
3
+
4
+ from torch import Tensor
5
+ from torch.utils.checkpoint import checkpoint
6
+
7
+ if version("torch") > "2.3.0":
8
+ from torch.nn.attention import SDPBackend, sdpa_kernel
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+
15
+ # DropPath code is straight from timm
16
+ # (https://huggingface.co/spaces/Roll20/pet_score/blame/main/lib/timm/models/layers/drop.py)
17
+ def drop_path(
18
+ x: Tensor,
19
+ drop_prob: float = 0.0,
20
+ training: bool = False,
21
+ scale_by_keep: bool = True,
22
+ ) -> Tensor:
23
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
24
+ residual blocks). Taken form timm.
25
+
26
+ Args:
27
+ x (Tensor): Input tensor.
28
+ drop_prob (float): Probability of dropping `x`, defaults to 0.
29
+ training (bool): Whether model is in in traingin of eval mode,
30
+ defaults to False.
31
+ scale_by_keep (bool): Whether the output should scaled by
32
+ (`1 - drop_prob`), defaults to True.
33
+ Returns:
34
+ Tensor: Tensor that may have randomly dropped with proability
35
+ `drop_path`
36
+ """
37
+ if drop_prob == 0.0 or not training:
38
+ return x
39
+ keep_prob = 1 - drop_prob
40
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1)
41
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
42
+ if keep_prob > 0.0 and scale_by_keep:
43
+ random_tensor.div_(keep_prob)
44
+ return x * random_tensor
45
+
46
+
47
+ class DropPath(nn.Module):
48
+ """
49
+ Drop paths (Stochastic Depth) per sample (when applied in main path of
50
+ residual blocks).
51
+ """
52
+
53
+ def __init__(
54
+ self, drop_prob: float | None = None, scale_by_keep: bool = True
55
+ ) -> None:
56
+ super(DropPath, self).__init__()
57
+ self.drop_prob = drop_prob
58
+ self.scale_by_keep = scale_by_keep
59
+
60
+ def forward(self, x: Tensor) -> Tensor:
61
+ """Runs drop path on input tensor
62
+
63
+ Args:
64
+ x: input
65
+
66
+ Returns:
67
+ tensor: output after drop_path
68
+ """
69
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
70
+
71
+
72
+ class Mlp(nn.Module):
73
+ """
74
+ Multi layer perceptron.
75
+ """
76
+
77
+ def __init__(
78
+ self, features: int, hidden_features: int, dropout: float = 0.0
79
+ ) -> None:
80
+ """
81
+ Args:
82
+ features: Input/output dimension.
83
+ hidden_features: Hidden dimension.
84
+ dropout: Dropout.
85
+ """
86
+ super().__init__()
87
+ self.net = nn.Sequential(
88
+ nn.Linear(features, hidden_features),
89
+ nn.GELU(),
90
+ nn.Dropout(dropout),
91
+ nn.Linear(hidden_features, features),
92
+ nn.Dropout(dropout),
93
+ )
94
+
95
+ def forward(self, x: Tensor) -> Tensor:
96
+ """
97
+ Args:
98
+ x (Tesnor): Tensor of shape [..., channel]
99
+ Returns:
100
+ Tenosr: Tensor of same shape as x.
101
+ """
102
+ return self.net(x)
103
+
104
+
105
+ class LayerNormPassThrough(nn.LayerNorm):
106
+ """Normalising layer that allows the attention mask to be passed through"""
107
+
108
+ def __init__(self, *args, **kwargs):
109
+ super().__init__(*args, **kwargs)
110
+
111
+ def forward(
112
+ self, d: tuple[Tensor, Tensor | None]
113
+ ) -> tuple[Tensor, Tensor | None]:
114
+ """Forwards function
115
+
116
+ Args:
117
+ d (tuple): tuple of the data tensor and the attention mask
118
+ Returns:
119
+ output (Tensor): normalised output data
120
+ attn_mask (Tensor): the attention mask that was passed in
121
+ """
122
+ input, attn_mask = d
123
+ output = F.layer_norm(
124
+ input, self.normalized_shape, self.weight, self.bias, self.eps
125
+ )
126
+ return output, attn_mask
127
+
128
+
129
+ class MultiheadAttention(nn.Module):
130
+ """Multihead attention layer for inputs of shape
131
+ [..., sequence, features].
132
+ """
133
+
134
+ def __init__(self, features: int, n_heads: int, dropout: float) -> None:
135
+ """
136
+ Args:
137
+ features: Number of features for inputs to the layer.
138
+ n_heads: Number of attention heads. Should be a factor of features.
139
+ (I.e. the layer uses features // n_heads.)
140
+ dropout: Dropout.
141
+ """ # noqa: E501
142
+ super().__init__()
143
+
144
+ if (features % n_heads) != 0:
145
+ raise ValueError(
146
+ f"Features '{features}' is not divisible by heads '{n_heads}'."
147
+ )
148
+
149
+ self.features = features
150
+ self.n_heads = n_heads
151
+ self.dropout = dropout
152
+
153
+ self.qkv_layer = torch.nn.Linear(features, features * 3, bias=False)
154
+ self.w_layer = torch.nn.Linear(features, features, bias=False)
155
+
156
+ def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:
157
+ """
158
+ Args:
159
+ d (tuple): tuple containing Tensor of shape [..., sequence, features] and the attention mask
160
+ Returns:
161
+ Tensor: Tensor of shape [..., sequence, features]
162
+ """ # noqa: E501
163
+ x, attn_mask = d
164
+
165
+ if not x.shape[-1] == self.features:
166
+ raise ValueError(
167
+ f"Expecting tensor with last dimension size {self.features}."
168
+ )
169
+
170
+ passenger_dims = x.shape[:-2]
171
+ B = passenger_dims.numel()
172
+ S = x.shape[-2]
173
+ C = x.shape[-1]
174
+ x = x.reshape(B, S, C)
175
+
176
+ # x [B, S, C]
177
+ # q, k, v [B, H, S, C/H]
178
+ q, k, v = (
179
+ self.qkv_layer(x)
180
+ .view(B, S, self.n_heads, 3 * (C // self.n_heads))
181
+ .transpose(1, 2)
182
+ .chunk(chunks=3, dim=3)
183
+ )
184
+
185
+ # Let us enforce either flash (A100+) or memory efficient attention.
186
+ if version("torch") > "2.3.0":
187
+ with sdpa_kernel(
188
+ [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION]
189
+ ):
190
+ # x [B, H, S, C//H]
191
+ x = F.scaled_dot_product_attention(
192
+ q, k, v, attn_mask=attn_mask, dropout_p=self.dropout
193
+ )
194
+ else:
195
+ with torch.backends.cuda.sdp_kernel(
196
+ enable_flash=True, enable_math=False, enable_mem_efficient=True
197
+ ):
198
+ # x [B, H, S, C//H]
199
+ x = F.scaled_dot_product_attention(
200
+ q, k, v, dropout_p=self.dropout
201
+ )
202
+
203
+ # x [B, S, C]
204
+ x = x.transpose(1, 2).view(B, S, C)
205
+
206
+ # x [B, S, C]
207
+ x = self.w_layer(x)
208
+
209
+ # Back to input shape
210
+ x = x.view(*passenger_dims, S, self.features)
211
+ return x
212
+
213
+
214
+ class Transformer(nn.Module):
215
+ """
216
+ Transformer for inputs of shape [..., S, features].
217
+ """
218
+
219
+ def __init__(
220
+ self,
221
+ features: int,
222
+ mlp_multiplier: int,
223
+ n_heads: int,
224
+ dropout: float,
225
+ drop_path: float,
226
+ ) -> None:
227
+ """
228
+ Args:
229
+ features: Number of features for inputs to the layer.
230
+ mlp_multiplier: Model uses features*mlp_multiplier hidden units.
231
+ n_heads: Number of attention heads. Should be a factor of features.
232
+ (I.e. the layer uses features // n_heads.) dropout: Dropout.
233
+ drop_path: DropPath.
234
+ """
235
+ super().__init__()
236
+
237
+ self.features = features
238
+ self.mlp_multiplier = mlp_multiplier
239
+ self.n_heads = n_heads
240
+ self.dropout = dropout
241
+ self.drop_path = (
242
+ DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
243
+ )
244
+
245
+ self.attention = nn.Sequential(
246
+ LayerNormPassThrough(features),
247
+ MultiheadAttention(features, n_heads, dropout),
248
+ )
249
+
250
+ self.ff = nn.Sequential(
251
+ nn.LayerNorm(features),
252
+ Mlp(
253
+ features=features,
254
+ hidden_features=features * mlp_multiplier,
255
+ dropout=dropout,
256
+ ),
257
+ )
258
+
259
+ def forward(self, d: tuple[Tensor, Tensor | None]) -> Tensor:
260
+ """
261
+ Args:
262
+ x: Tensor of shape [..., sequence, features]
263
+ Returns:
264
+ Tensor: Tensor of shape [..., sequence, features]
265
+ """
266
+ x, attn_mask = d
267
+ if not x.shape[-1] == self.features:
268
+ raise ValueError(
269
+ f"Expecting tensor with last dimension size {self.features}."
270
+ )
271
+
272
+ attention_x = self.attention(d)
273
+
274
+ x = x + self.drop_path(attention_x)
275
+ x = x + self.drop_path(self.ff(x))
276
+
277
+ return x
278
+
279
+
280
+ class _Shift(nn.Module):
281
+ """Private base class for the shifter. This allows some behaviour to be
282
+ easily handled when the shifter isn't used.
283
+ """
284
+
285
+ def __init__(self):
286
+ super().__init__()
287
+
288
+ self._shifted = False
289
+
290
+ @torch.no_grad()
291
+ def reset(self) -> None:
292
+ """
293
+ Resets the bool tracking whether the data is shifted
294
+ """
295
+ self._shifted: bool = False
296
+
297
+ def forward(self, data: Tensor) -> tuple[Tensor, dict[bool, None]]:
298
+ return data, {True: None, False: None}
299
+
300
+
301
+ class SWINShift(_Shift):
302
+ """
303
+ Handles the shifting of patches similar to how SWIN works. However if we
304
+ shift the latitudes then the poles will wrap and potentially that might be
305
+ problematic. The possition tokens should handle it but masking is safer.
306
+ """
307
+
308
+ def __init__(
309
+ self,
310
+ mu_shape: tuple[int, int],
311
+ global_shape: tuple[int, int],
312
+ local_shape: tuple[int, int],
313
+ patch_shape: tuple[int, int],
314
+ n_context_tokens: int = 2,
315
+ ) -> None:
316
+ """
317
+ Args:
318
+ mu_shape: the shape to the masking units
319
+ global_shape: number of global patches in lat and lon
320
+ local_shape: size of the local patches
321
+ patch_shape: patch size
322
+ n_context_token: number of additional context tokens at start of
323
+ _each_ local sequence
324
+ """
325
+ super().__init__()
326
+
327
+ self._mu_shape = ms = mu_shape
328
+ self._g_shape = gs = global_shape
329
+ self._l_shape = ls = local_shape
330
+ self._p_shape = ps = patch_shape
331
+ self._lat_patch = (gs[0], ls[0], gs[1], ls[1])
332
+ self._n_context_tokens = n_context_tokens
333
+
334
+ self._g_shift_to = tuple(
335
+ int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)
336
+ )
337
+ self._g_shift_from = tuple(
338
+ -int(0.5 * x / p) for x, p in zip(ms, ps, strict=False)
339
+ )
340
+
341
+ # Define the attention masks for the shifted MaxViT.
342
+ nglobal = global_shape[0] * global_shape[1]
343
+ nlocal = (
344
+ local_shape[0] * local_shape[1] + self._n_context_tokens
345
+ ) # "+ 1" for leadtime
346
+
347
+ lm = torch.ones((nglobal, 1, nlocal, nlocal), dtype=bool)
348
+ mwidth = int(0.5 * local_shape[1]) * local_shape[0]
349
+ lm[
350
+ : gs[1],
351
+ :,
352
+ self._n_context_tokens : mwidth + self._n_context_tokens,
353
+ self._n_context_tokens : mwidth + self._n_context_tokens,
354
+ ] = False
355
+ self.register_buffer("local_mask", lm)
356
+
357
+ gm = torch.ones((nlocal, 1, nglobal, nglobal), dtype=bool)
358
+ gm[: int(0.5 * ls[1]) * ls[0], :, : gs[1], : gs[1]] = False
359
+ self.register_buffer("global_mask", gm)
360
+
361
+ def _to_grid_global(self, x: Tensor) -> Tensor:
362
+ """
363
+ Shuffle and reshape the data from the global/local setting back to the
364
+ lat/lon grid setting
365
+ Args:
366
+ x: the data tensor to be shuffled.
367
+ Returns:
368
+ x: data in the global/local setting
369
+ """
370
+ nbatch, *other = x.shape
371
+
372
+ y1 = x.view(nbatch, *self._g_shape, *self._l_shape, -1)
373
+ y2 = y1.permute(0, 5, 1, 3, 2, 4).contiguous()
374
+
375
+ s = y2.shape
376
+ return y2.view((nbatch, -1, s[2] * s[3], s[4] * s[5]))
377
+
378
+ def _to_grid_local(self, x: Tensor) -> Tensor:
379
+ """
380
+ Shuffle and reshape the data from the local/global setting to the
381
+ lat/lon grid setting
382
+ Args:
383
+ x: the data tensor to be shuffled.
384
+ Returns:
385
+ x: data in the lat/lon setting.
386
+ """
387
+ x = x.transpose(2, 1).contiguous()
388
+ return self._to_grid_global(x)
389
+
390
+ def _from_grid_global(self, x: Tensor) -> Tensor:
391
+ """
392
+ Shuffle and reshape the data from the lat/lon grid to the global/local
393
+ setting
394
+ Args:
395
+ x: the data tensor to be shuffled.
396
+ Returns:
397
+ x: data in the global/local setting
398
+ """
399
+ nbatch, *other = x.shape
400
+
401
+ z1 = x.view(nbatch, -1, *self._lat_patch)
402
+ z2 = z1.permute(0, 2, 4, 3, 5, 1).contiguous()
403
+
404
+ s = z2.shape
405
+ return z2.view(nbatch, s[1] * s[2], s[3] * s[4], -1)
406
+
407
+ def _from_grid_local(self, x: Tensor) -> Tensor:
408
+ """
409
+ Shuffle and reshape the data from the lat/lon grid to the local/global
410
+ setting
411
+ Args:
412
+ x: the data tensor to be shuffled.
413
+ Returns:
414
+ x: data in the local/global setting
415
+ """
416
+ x = self._from_grid_global(x)
417
+ return x.transpose(2, 1).contiguous()
418
+
419
+ def _shift(self, x: Tensor) -> Tensor:
420
+ """
421
+ Shifts data in the gridded lat/lon setting by half the mask unit shape
422
+ Args:
423
+ x: data to be shifted
424
+ Returns:
425
+ x: either the hsifted or unshifted data
426
+ """
427
+ shift = self._g_shift_from if self._shifted else self._g_shift_to
428
+ x_shifted = torch.roll(x, shift, (-2, -1))
429
+
430
+ self._shifted = not self._shifted
431
+ return x_shifted
432
+
433
+ def _sep_lt(self, x: Tensor) -> tuple[Tensor, Tensor]:
434
+ """
435
+ Seperate off the leadtime from the local patches
436
+ Args:
437
+ x: data to have leadtime removed from
438
+ Returns:
439
+ lt: leadtime
440
+ x: data without the lead time in the local patch
441
+ """
442
+ lt_it = x[:, : self._n_context_tokens, :, :]
443
+ x_stripped = x[:, self._n_context_tokens :, :, :]
444
+
445
+ return lt_it, x_stripped
446
+
447
+ def forward(self, data: Tensor) -> tuple[Tensor, Tensor]:
448
+ """Shift or unshift the the data depending on whether the data is
449
+ already shifted, as defined by self._shifte.
450
+
451
+ Args:
452
+ data: data to be shifted
453
+ Returns:
454
+ Tensor: shifted data Tensor
455
+ """
456
+ lt, x = self._sep_lt(data)
457
+
458
+ x_grid = self._to_grid_local(x)
459
+ x_shifted = self._shift(x_grid)
460
+ x_patched = self._from_grid_local(x_shifted)
461
+
462
+ # Mask has to be repeated based on batch size
463
+ n_batch = x_grid.shape[0]
464
+ local_rep = [n_batch] + [1] * (self.local_mask.ndim - 1)
465
+ global_rep = [n_batch] + [1] * (self.global_mask.ndim - 1)
466
+
467
+ if self._shifted:
468
+ attn_mask = {
469
+ True: self.local_mask.repeat(local_rep),
470
+ False: self.global_mask.repeat(global_rep),
471
+ }
472
+ else:
473
+ attn_mask = {True: None, False: None}
474
+
475
+ return torch.cat((lt, x_patched), axis=1), attn_mask
476
+
477
+
478
+ class LocalGlobalLocalBlock(nn.Module):
479
+ """
480
+ Applies alternating block and grid attention. Given a parameter n_blocks,
481
+ the entire module contains 2*n_blocks+1 transformer blocks. The first,
482
+ third, ..., last apply local (block) attention. The second, fourth, ...
483
+ global (grid) attention.
484
+
485
+ This is heavily inspired by
486
+ Tu et al. "MaxViT: Multi-Axis Vision Transformer"
487
+ (https://arxiv.org/abs/2204.01697).
488
+ """
489
+
490
+ def __init__(
491
+ self,
492
+ features: int,
493
+ mlp_multiplier: int,
494
+ n_heads: int,
495
+ dropout: float,
496
+ n_blocks: int,
497
+ drop_path: float,
498
+ shifter: nn.Module | None = None,
499
+ checkpoint: list[int] | None = None,
500
+ ) -> None:
501
+ """
502
+ Args:
503
+ features: Number of features for inputs to the layer.
504
+ mlp_multiplier: Model uses features*mlp_multiplier hidden units.
505
+ n_heads: Number of attention heads. Should be a factor of features.
506
+ (I.e. the layer uses features // n_heads.)
507
+ dropout: Dropout.
508
+ drop_path: DropPath.
509
+ n_blocks: Number of local-global transformer pairs.
510
+ """
511
+ super().__init__()
512
+
513
+ self.features = features
514
+ self.mlp_multiplier = mlp_multiplier
515
+ self.n_heads = n_heads
516
+ self.dropout = dropout
517
+ self.drop_path = drop_path
518
+ self.n_blocks = n_blocks
519
+ self._checkpoint = checkpoint or []
520
+
521
+ if not all(0 <= c < 2 * n_blocks + 1 for c in self._checkpoint):
522
+ raise ValueError(
523
+ "Checkpoints should be 0 <= i < 2*n_blocks+1. "
524
+ f"{self._checkpoint=}."
525
+ )
526
+
527
+ self.transformers = nn.ModuleList(
528
+ [
529
+ Transformer(
530
+ features=features,
531
+ mlp_multiplier=mlp_multiplier,
532
+ n_heads=n_heads,
533
+ dropout=dropout,
534
+ drop_path=drop_path,
535
+ )
536
+ for _ in range(2 * n_blocks + 1)
537
+ ]
538
+ )
539
+
540
+ self.evaluator = [
541
+ self._checkpoint_wrapper
542
+ if i in self._checkpoint
543
+ else lambda m, x: m(x)
544
+ for i, _ in enumerate(self.transformers)
545
+ ]
546
+
547
+ self.shifter = shifter or _Shift()
548
+
549
+ @staticmethod
550
+ def _checkpoint_wrapper(
551
+ model: nn.Module, data: tuple[Tensor, Tensor | None]
552
+ ) -> Tensor:
553
+ return checkpoint(model, data, use_reentrant=False)
554
+
555
+ def forward(self, x: Tensor) -> Tensor:
556
+ """
557
+ Args:
558
+ x: Tensor of shape::
559
+
560
+ [batch, global_sequence, local_sequence, features]
561
+
562
+ Returns:
563
+ Tensor: Tensor of shape::
564
+
565
+ [batch, global_sequence, local_sequence, features]
566
+ """
567
+ if x.shape[-1] != self.features:
568
+ raise ValueError(
569
+ f"Expecting tensor with last dimension size {self.features}."
570
+ )
571
+ if x.ndim != 4:
572
+ raise ValueError(
573
+ f"Expecting tensor with exactly four dimensions. {x.shape=}."
574
+ )
575
+
576
+ self.shifter.reset()
577
+ local: bool = True
578
+ attn_mask = {True: None, False: None}
579
+
580
+ transformer_iter = zip(self.evaluator, self.transformers, strict=False)
581
+
582
+ # First local block
583
+ evaluator, transformer = next(transformer_iter)
584
+ x = evaluator(transformer, (x, attn_mask[local]))
585
+
586
+ for evaluator, transformer in transformer_iter:
587
+ local = not local
588
+ # We are making exactly 2*n_blocks transposes.
589
+ # So the output has the same shape as input.
590
+ x = x.transpose(1, 2)
591
+
592
+ x = evaluator(transformer, (x, attn_mask[local]))
593
+
594
+ if not local:
595
+ x, attn_mask = self.shifter(x)
596
+
597
+ return x
598
+
599
+
600
+ class PatchEmbed(nn.Module):
601
+ """
602
+ Patch embedding via 2D convolution.
603
+ """
604
+
605
+ def __init__(
606
+ self, patch_size: int | tuple[int, ...], channels: int, embed_dim: int
607
+ ):
608
+ super().__init__()
609
+
610
+ self.patch_size = patch_size
611
+ self.channels = channels
612
+ self.embed_dim = embed_dim
613
+
614
+ self.proj = nn.Conv2d(
615
+ channels,
616
+ embed_dim,
617
+ kernel_size=patch_size,
618
+ stride=patch_size,
619
+ bias=True,
620
+ )
621
+
622
+ def forward(self, x: Tensor) -> Tensor:
623
+ """
624
+ Args:
625
+ x: Tensor of shape [batch, channels, lat, lon].
626
+ Returns:
627
+ Tensor: Tensor with shape
628
+ [batch, embed_dim, lat//patch_size, lon//patch_size]
629
+ """
630
+
631
+ H, W = x.shape[-2:]
632
+
633
+ if W % self.patch_size[1] != 0:
634
+ raise ValueError(
635
+ f"Cannot do patch embedding for tensor of shape {x.size()}"
636
+ " with patch size {self.patch_size}. (Dimensions are BSCHW.)"
637
+ )
638
+ if H % self.patch_size[0] != 0:
639
+ raise ValueError(
640
+ f"Cannot do patch embedding for tensor of shape {x.size()}"
641
+ f" with patch size {self.patch_size}. (Dimensions are BSCHW.)"
642
+ )
643
+
644
+ x = self.proj(x)
645
+
646
+ return x
647
+
648
+
649
+ class PrithviWxCEncoderDecoder(nn.Module):
650
+ """
651
+ Hiera-MaxViT encoder/decoder code.
652
+ """
653
+
654
+ def __init__(
655
+ self,
656
+ embed_dim: int,
657
+ n_blocks: int,
658
+ mlp_multiplier: float,
659
+ n_heads: int,
660
+ dropout: float,
661
+ drop_path: float,
662
+ shifter: nn.Module | None = None,
663
+ transformer_cp: list[int] | None = None,
664
+ ) -> None:
665
+ """
666
+ Args:
667
+ embed_dim: Embedding dimension
668
+ n_blocks: Number of local-global transformer pairs.
669
+ mlp_multiplier: MLP multiplier for hidden features in feed forward
670
+ networks.
671
+ n_heads: Number of attention heads.
672
+ dropout: Dropout.
673
+ drop_path: DropPath.
674
+ """
675
+ super().__init__()
676
+
677
+ self.embed_dim = embed_dim
678
+ self.n_blocks = n_blocks
679
+ self.mlp_multiplier = mlp_multiplier
680
+ self.n_heads = n_heads
681
+ self.dropout = dropout
682
+ self._transformer_cp = transformer_cp
683
+
684
+ self.lgl_block = LocalGlobalLocalBlock(
685
+ features=embed_dim,
686
+ mlp_multiplier=mlp_multiplier,
687
+ n_heads=n_heads,
688
+ dropout=dropout,
689
+ drop_path=drop_path,
690
+ n_blocks=n_blocks,
691
+ shifter=shifter,
692
+ checkpoint=transformer_cp,
693
+ )
694
+
695
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
696
+ """
697
+ Args:
698
+ x: Tensor of shape
699
+ [batch, global sequence, local sequence, embed_dim]
700
+ Returns:
701
+ Tensor of shape
702
+ [batch, mask_unit_sequence, local_sequence, embed_dim].
703
+ Identical in shape to the input x.
704
+ """
705
+
706
+ x = self.lgl_block(x)
707
+
708
+ return x
709
+
710
+
711
+ class PrithviWxC(nn.Module):
712
+ """Encoder-decoder fusing Hiera with MaxViT. See
713
+ - Ryali et al. "Hiera: A Hierarchical Vision Transformer without the
714
+ Bells-and-Whistles" (https://arxiv.org/abs/2306.00989)
715
+ - Tu et al. "MaxViT: Multi-Axis Vision Transformer"
716
+ (https://arxiv.org/abs/2204.01697)
717
+ """
718
+
719
+ def __init__(
720
+ self,
721
+ in_channels: int,
722
+ input_size_time: int,
723
+ in_channels_static: int,
724
+ input_scalers_mu: Tensor,
725
+ input_scalers_sigma: Tensor,
726
+ input_scalers_epsilon: float,
727
+ static_input_scalers_mu: Tensor,
728
+ static_input_scalers_sigma: Tensor,
729
+ static_input_scalers_epsilon: float,
730
+ output_scalers: Tensor,
731
+ n_lats_px: int,
732
+ n_lons_px: int,
733
+ patch_size_px: tuple[int],
734
+ mask_unit_size_px: tuple[int],
735
+ mask_ratio_inputs: float,
736
+ embed_dim: int,
737
+ n_blocks_encoder: int,
738
+ n_blocks_decoder: int,
739
+ mlp_multiplier: float,
740
+ n_heads: int,
741
+ dropout: float,
742
+ drop_path: float,
743
+ parameter_dropout: float,
744
+ residual: str,
745
+ masking_mode: str,
746
+ positional_encoding: str,
747
+ decoder_shifting: bool = False,
748
+ checkpoint_encoder: list[int] | None = None,
749
+ checkpoint_decoder: list[int] | None = None,
750
+ ) -> None:
751
+ """
752
+ Args:
753
+ in_channels: Number of input channels.
754
+ input_size_time: Number of timestamps in input.
755
+ in_channels_static: Number of input channels for static data.
756
+ input_scalers_mu: Tensor of size (in_channels,). Used to rescale
757
+ input.
758
+ input_scalers_sigma: Tensor of size (in_channels,). Used to rescale
759
+ input.
760
+ input_scalers_epsilon: Float. Used to rescale input.
761
+ static_input_scalers_mu: Tensor of size (in_channels_static). Used
762
+ to rescale static inputs.
763
+ static_input_scalers_sigma: Tensor of size (in_channels_static).
764
+ Used to rescale static inputs.
765
+ static_input_scalers_epsilon: Float. Used to rescale static inputs.
766
+ output_scalers: Tensor of shape (in_channels,). Used to rescale
767
+ output.
768
+ n_lats_px: Total latitudes in data. In pixels.
769
+ n_lons_px: Total longitudes in data. In pixels.
770
+ patch_size_px: Patch size for tokenization. In pixels lat/lon.
771
+ mask_unit_size_px: Size of each mask unit. In pixels lat/lon.
772
+ mask_ratio_inputs: Masking ratio for inputs. 0 to 1.
773
+ embed_dim: Embedding dimension
774
+ n_blocks_encoder: Number of local-global transformer pairs in
775
+ encoder.
776
+ n_blocks_decoder: Number of local-global transformer pairs in
777
+ decoder.
778
+ mlp_multiplier: MLP multiplier for hidden features in feed forward
779
+ networks.
780
+ n_heads: Number of attention heads.
781
+ dropout: Dropout.
782
+ drop_path: DropPath.
783
+ parameter_dropout: Dropout applied to parameters.
784
+ residual: Indicates whether and how model should work as residual
785
+ model. Accepted values are 'climate', 'temporal' and 'none'
786
+ positional_encoding: possible values are
787
+ ['absolute' (default), 'fourier'].
788
+ 'absolute' lat lon encoded in 3 dimensions using sine and
789
+ cosine
790
+ 'fourier' lat/lon to be encoded using various frequencies
791
+ masking_mode: String ['local', 'global', 'both'] that controls the
792
+ type of masking used.
793
+ checkpoint_encoder: List of integers controlling if gradient
794
+ checkpointing is used on encoder.
795
+ Format: [] for no gradient checkpointing. [3, 7] for
796
+ checkpointing after 4th and 8th layer etc.
797
+ checkpoint_decoder: List of integers controlling if gradient
798
+ checkpointing is used on decoder.
799
+ Format: See `checkpoint_encoder`.
800
+ masking_mode: The type of masking to use
801
+ {'global', 'local', 'both'}
802
+ decoder_shifting: Whether to use swin shifting in the decoder.
803
+ """
804
+ super().__init__()
805
+
806
+ self.in_channels = in_channels
807
+ self.input_size_time = input_size_time
808
+ self.in_channels_static = in_channels_static
809
+ self.n_lats_px = n_lats_px
810
+ self.n_lons_px = n_lons_px
811
+ self.patch_size_px = patch_size_px
812
+ self.mask_unit_size_px = mask_unit_size_px
813
+ self.mask_ratio_inputs = mask_ratio_inputs
814
+ self.embed_dim = embed_dim
815
+ self.n_blocks_encoder = n_blocks_encoder
816
+ self.n_blocks_decoder = n_blocks_decoder
817
+ self.mlp_multiplier = mlp_multiplier
818
+ self.n_heads = n_heads
819
+ self.dropout = dropout
820
+ self.drop_path = drop_path
821
+ self.residual = residual
822
+ self._decoder_shift = decoder_shifting
823
+ self.positional_encoding = positional_encoding
824
+ self._checkpoint_encoder = checkpoint_encoder
825
+ self._checkpoint_decoder = checkpoint_decoder
826
+
827
+ assert self.n_lats_px % self.mask_unit_size_px[0] == 0
828
+ assert self.n_lons_px % self.mask_unit_size_px[1] == 0
829
+ assert self.mask_unit_size_px[0] % self.patch_size_px[0] == 0
830
+ assert self.mask_unit_size_px[1] % self.patch_size_px[1] == 0
831
+
832
+ if self.patch_size_px[0] != self.patch_size_px[1]:
833
+ raise NotImplementedError(
834
+ "Current pixel shuffle symmetric patches."
835
+ )
836
+
837
+ self.local_shape_mu = (
838
+ self.mask_unit_size_px[0] // self.patch_size_px[0],
839
+ self.mask_unit_size_px[1] // self.patch_size_px[1],
840
+ )
841
+ self.global_shape_mu = (
842
+ self.n_lats_px // self.mask_unit_size_px[0],
843
+ self.n_lons_px // self.mask_unit_size_px[1],
844
+ )
845
+
846
+ assert input_scalers_mu.shape == (in_channels,)
847
+ assert input_scalers_sigma.shape == (in_channels,)
848
+ assert output_scalers.shape == (in_channels,)
849
+
850
+ if self.positional_encoding != "fourier":
851
+ assert static_input_scalers_mu.shape == (in_channels_static,)
852
+ assert static_input_scalers_sigma.shape == (in_channels_static,)
853
+
854
+ # Input shape [batch, time, parameter, lat, lon]
855
+ self.input_scalers_epsilon = input_scalers_epsilon
856
+ self.register_buffer(
857
+ "input_scalers_mu", input_scalers_mu.reshape(1, 1, -1, 1, 1)
858
+ )
859
+ self.register_buffer(
860
+ "input_scalers_sigma", input_scalers_sigma.reshape(1, 1, -1, 1, 1)
861
+ )
862
+
863
+ # Static inputs shape [batch, parameter, lat, lon]
864
+ self.static_input_scalers_epsilon = static_input_scalers_epsilon
865
+ self.register_buffer(
866
+ "static_input_scalers_mu",
867
+ static_input_scalers_mu.reshape(1, -1, 1, 1),
868
+ )
869
+ self.register_buffer(
870
+ "static_input_scalers_sigma",
871
+ static_input_scalers_sigma.reshape(1, -1, 1, 1),
872
+ )
873
+
874
+ # Output shape [batch, parameter, lat, lon]
875
+ self.register_buffer(
876
+ "output_scalers", output_scalers.reshape(1, -1, 1, 1)
877
+ )
878
+
879
+ self.parameter_dropout = nn.Dropout2d(p=parameter_dropout)
880
+
881
+ self.patch_embedding = PatchEmbed(
882
+ patch_size=patch_size_px,
883
+ channels=in_channels * input_size_time,
884
+ embed_dim=embed_dim,
885
+ )
886
+
887
+ if self.residual == "climate":
888
+ self.patch_embedding_static = PatchEmbed(
889
+ patch_size=patch_size_px,
890
+ channels=in_channels + in_channels_static,
891
+ embed_dim=embed_dim,
892
+ )
893
+ else:
894
+ self.patch_embedding_static = PatchEmbed(
895
+ patch_size=patch_size_px,
896
+ channels=in_channels_static,
897
+ embed_dim=embed_dim,
898
+ )
899
+
900
+ self.input_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)
901
+ self.lead_time_embedding = nn.Linear(1, embed_dim // 4, bias=True)
902
+
903
+ self.mask_token = nn.Parameter(torch.randn(1, 1, 1, self.embed_dim))
904
+ self._nglobal_mu = np.prod(self.global_shape_mu)
905
+ self._global_idx = torch.arange(self._nglobal_mu)
906
+
907
+ self._nlocal_mu = np.prod(self.local_shape_mu)
908
+ self._local_idx = torch.arange(self._nlocal_mu)
909
+
910
+ self.encoder = PrithviWxCEncoderDecoder(
911
+ embed_dim=embed_dim,
912
+ n_blocks=n_blocks_encoder,
913
+ mlp_multiplier=mlp_multiplier,
914
+ n_heads=n_heads,
915
+ dropout=dropout,
916
+ drop_path=drop_path,
917
+ transformer_cp=checkpoint_encoder,
918
+ )
919
+
920
+ if n_blocks_decoder != 0:
921
+ if self._decoder_shift:
922
+ self.decoder_shifter = d_shifter = SWINShift(
923
+ self.mask_unit_size_px,
924
+ self.global_shape_mu,
925
+ self.local_shape_mu,
926
+ self.patch_size_px,
927
+ n_context_tokens=0,
928
+ )
929
+ else:
930
+ self.decoder_shifter = d_shifter = None
931
+
932
+ self.decoder = PrithviWxCEncoderDecoder(
933
+ embed_dim=embed_dim,
934
+ n_blocks=n_blocks_decoder,
935
+ mlp_multiplier=mlp_multiplier,
936
+ n_heads=n_heads,
937
+ dropout=dropout,
938
+ drop_path=0.0,
939
+ shifter=d_shifter,
940
+ transformer_cp=checkpoint_decoder,
941
+ )
942
+
943
+ self.unembed = nn.Linear(
944
+ self.embed_dim,
945
+ self.in_channels
946
+ * self.patch_size_px[0]
947
+ * self.patch_size_px[1],
948
+ bias=True,
949
+ )
950
+
951
+ self.masking_mode = masking_mode.lower()
952
+ match self.masking_mode:
953
+ case "local":
954
+ self.generate_mask = self._gen_mask_local
955
+ case "global":
956
+ self.generate_mask = self._gen_mask_global
957
+ case "both":
958
+ self._mask_both_local: bool = True
959
+ self.generate_mask = self._gen_mask_both
960
+ case _:
961
+ raise ValueError(
962
+ f"Masking mode '{masking_mode}' not supported"
963
+ )
964
+
965
+ def swap_masking(self) -> None:
966
+ self._mask_both_local = not self._mask_both_local
967
+
968
+ @cached_property
969
+ def n_masked_global(self):
970
+ return int(self.mask_ratio_inputs * np.prod(self.global_shape_mu))
971
+
972
+ @cached_property
973
+ def n_masked_local(self):
974
+ return int(self.mask_ratio_inputs * np.prod(self.local_shape_mu))
975
+
976
+ @staticmethod
977
+ def _shuffle_along_axis(a, axis):
978
+ idx = torch.argsort(input=torch.rand(*a.shape), dim=axis)
979
+ return torch.gather(a, dim=axis, index=idx)
980
+
981
+ def _gen_mask_local(self, sizes: tuple[int]) -> tuple[Tensor]:
982
+ """
983
+ Args:
984
+ batch_size: Number of elements in batch
985
+ Returns:
986
+ Tuple of torch tensors. [indices masked, indices unmasked].
987
+ Each of these is a tensor of shape (batch, global sequene)
988
+ """
989
+ # Identify which indices (values) should be masked
990
+
991
+ maskable_indices = self._local_idx.view(1, -1).expand(*sizes[:2], -1)
992
+
993
+ maskable_indices = self._shuffle_along_axis(maskable_indices, 2)
994
+
995
+ indices_masked = maskable_indices[:, :, : self.n_masked_local]
996
+ indices_unmasked = maskable_indices[:, :, self.n_masked_local :]
997
+
998
+ return indices_masked, indices_unmasked
999
+
1000
+ def _gen_mask_global(self, sizes: tuple[int]) -> tuple[Tensor]:
1001
+ """
1002
+ Args:
1003
+ batch_size: Number of elements in batch
1004
+ Returns:
1005
+ Tuple of torch tensors. [indices masked, indices unmasked].
1006
+ Each of these is a tensor of shape (batch, global sequene)
1007
+ """
1008
+ # Identify which indices (values) should be masked
1009
+
1010
+ maskable_indices = self._global_idx.view(1, -1).expand(*sizes[:1], -1)
1011
+
1012
+ maskable_indices = self._shuffle_along_axis(maskable_indices, 1)
1013
+
1014
+ indices_masked = maskable_indices[:, : self.n_masked_global]
1015
+ indices_unmasked = maskable_indices[:, self.n_masked_global :]
1016
+
1017
+ return indices_masked, indices_unmasked
1018
+
1019
+ def _gen_mask_both(self, sizes: tuple[int]) -> tuple[Tensor]:
1020
+ if self._mask_both_local:
1021
+ return self._gen_mask_local(sizes)
1022
+ else:
1023
+ return self._gen_mask_global(sizes)
1024
+
1025
+ @staticmethod
1026
+ def reconstruct_batch(
1027
+ idx_masked: Tensor,
1028
+ idx_unmasked: Tensor,
1029
+ data_masked: Tensor,
1030
+ data_unmasked: Tensor,
1031
+ ) -> Tensor:
1032
+ """Reconstructs a tensor along the mask unit dimension. Batched
1033
+ version.
1034
+
1035
+ Args:
1036
+ idx_masked: Tensor of shape `batch, mask unit sequence`.
1037
+ idx_unmasked: Tensor of shape `batch, mask unit sequence`.
1038
+ data_masked: Tensor of shape `batch, mask unit sequence, ...`.
1039
+ Should have same size along mask unit sequence dimension as
1040
+ idx_masked. Dimensions beyond the first two, marked here as ...
1041
+ will typically be `local_sequence, channel` or
1042
+ `channel, lat, lon`. These dimensions should agree with
1043
+ data_unmasked.
1044
+ data_unmasked: Tensor of shape `batch, mask unit sequence, ...`.
1045
+ Should have same size along mask unit sequence dimension as
1046
+ idx_unmasked. Dimensions beyond the first two, marked here as
1047
+ ... will typically be `local_sequence, channel` or `channel,
1048
+ lat, lon`. These dimensions should agree with data_masked.
1049
+ Returns:
1050
+ Tensor: Tensor of same shape as inputs data_masked and
1051
+ data_unmasked. I.e. `batch, mask unit sequence, ...`. Index for
1052
+ the total data composed of the masked and the unmasked part.
1053
+ """
1054
+ dim: int = idx_masked.ndim
1055
+
1056
+ idx_total = torch.argsort(
1057
+ torch.cat([idx_masked, idx_unmasked], dim=-1), dim=-1
1058
+ )
1059
+ idx_total = idx_total.view(
1060
+ *idx_total.shape, *[1] * (data_unmasked.ndim - dim)
1061
+ )
1062
+ idx_total = idx_total.expand(
1063
+ *idx_total.shape[:dim], *data_unmasked.shape[dim:]
1064
+ )
1065
+
1066
+ data = torch.cat([data_masked, data_unmasked], dim=dim - 1)
1067
+ data = torch.gather(data, dim=dim - 1, index=idx_total)
1068
+
1069
+ return data, idx_total
1070
+
1071
+ def fourier_pos_encoding(self, x_static: Tensor) -> Tensor:
1072
+ """
1073
+ Args
1074
+ x_static: B x C x H x W. first two channels are lat, and lon
1075
+ Returns
1076
+ Tensor: Tensor of shape B x E x H x W where E is the embedding
1077
+ dimension.
1078
+ """
1079
+
1080
+ # B x C x H x W -> B x 1 x H/P x W/P
1081
+ latitudes_patch = F.avg_pool2d(
1082
+ x_static[:, [0]],
1083
+ kernel_size=self.patch_size_px,
1084
+ stride=self.patch_size_px,
1085
+ )
1086
+ longitudes_patch = F.avg_pool2d(
1087
+ x_static[:, [1]],
1088
+ kernel_size=self.patch_size_px,
1089
+ stride=self.patch_size_px,
1090
+ )
1091
+
1092
+ modes = (
1093
+ torch.arange(self.embed_dim // 4, device=x_static.device).view(
1094
+ 1, -1, 1, 1
1095
+ )
1096
+ + 1.0
1097
+ )
1098
+ pos_encoding = torch.cat(
1099
+ (
1100
+ torch.sin(latitudes_patch * modes),
1101
+ torch.sin(longitudes_patch * modes),
1102
+ torch.cos(latitudes_patch * modes),
1103
+ torch.cos(longitudes_patch * modes),
1104
+ ),
1105
+ axis=1,
1106
+ )
1107
+
1108
+ return pos_encoding # B x E x H/P x W/P
1109
+
1110
+ def time_encoding(self, input_time, lead_time):
1111
+ """
1112
+ Args:
1113
+ input_time: Tensor of shape [batch].
1114
+ lead_time: Tensor of shape [batch].
1115
+ Returns:
1116
+ Tensor: Tensor of shape [batch, embed_dim, 1, 1]
1117
+ """
1118
+ input_time = self.input_time_embedding(input_time.view(-1, 1, 1, 1))
1119
+ lead_time = self.lead_time_embedding(lead_time.view(-1, 1, 1, 1))
1120
+
1121
+ time_encoding = torch.cat(
1122
+ (
1123
+ torch.cos(input_time),
1124
+ torch.cos(lead_time),
1125
+ torch.sin(input_time),
1126
+ torch.sin(lead_time),
1127
+ ),
1128
+ axis=3,
1129
+ )
1130
+ return time_encoding
1131
+
1132
+ def to_patching(self, x: Tensor) -> Tensor:
1133
+ """Transform data from lat/lon space to two axis patching
1134
+
1135
+ Args: ->
1136
+ x: Tesnor in lat/lon space (N, C, Nlat//P_0, Nlon//P_1)
1137
+
1138
+ Returns:
1139
+ Tensor in patch space (N, G, L, C)
1140
+ """
1141
+ n_batch = x.shape[0]
1142
+
1143
+ x = x.view(
1144
+ n_batch,
1145
+ -1,
1146
+ self.global_shape_mu[0],
1147
+ self.local_shape_mu[0],
1148
+ self.global_shape_mu[1],
1149
+ self.local_shape_mu[1],
1150
+ )
1151
+ x = x.permute(0, 2, 4, 3, 5, 1).contiguous()
1152
+
1153
+ s = x.shape
1154
+ return x.view(n_batch, s[1] * s[2], s[3] * s[4], -1)
1155
+
1156
+ def from_patching(self, x: Tensor) -> Tensor:
1157
+ """Transform data from two axis patching to lat/lon space
1158
+
1159
+ Args:
1160
+ x: Tensor in patch space with shape (N, G, L, C*P_0*P_1)
1161
+
1162
+ Returns:
1163
+ Tensor: Tensor in lat/lon space
1164
+ (N, C*P_0*P_1, Nlat//P_0, Nlon // P_1)
1165
+ """
1166
+ n_batch = x.shape[0]
1167
+
1168
+ x = x.view(
1169
+ n_batch,
1170
+ self.global_shape_mu[0],
1171
+ self.global_shape_mu[1],
1172
+ self.local_shape_mu[0],
1173
+ self.local_shape_mu[1],
1174
+ -1,
1175
+ )
1176
+ x = x.permute(0, 5, 1, 3, 2, 4).contiguous()
1177
+
1178
+ s = x.shape
1179
+ return x.view(n_batch, -1, s[2] * s[3], s[4] * s[5])
1180
+
1181
+ def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
1182
+ """
1183
+ Args:
1184
+ batch: Dictionary the following keys::
1185
+
1186
+ 'x': Tensor of shape [batch, time, parameter, lat, lon]
1187
+ 'y': Tensor of shape [batch, parameter, lat, lon]
1188
+ 'static': Tensor of shape [batch, channel_static, lat, lon]
1189
+ 'climate': Optional tensor of shape [batch, parameter, lat, lon]
1190
+ 'input_time': Tensor of shape [batch]. Or none.
1191
+ 'lead_time': Tensor of shape [batch]. Or none.
1192
+
1193
+ Returns:
1194
+ Tensor: Tensor of shape [batch, parameter, lat, lon].
1195
+ """ # noqa: E501
1196
+ x_rescaled = (batch["x"] - self.input_scalers_mu) / (
1197
+ self.input_scalers_sigma + self.input_scalers_epsilon
1198
+ )
1199
+ batch_size = x_rescaled.shape[0]
1200
+
1201
+ if self.positional_encoding == "fourier":
1202
+ x_static_pos = self.fourier_pos_encoding(batch["static"])
1203
+ x_static = (
1204
+ batch["static"][:, 2:] - self.static_input_scalers_mu[:, 3:]
1205
+ ) / (
1206
+ self.static_input_scalers_sigma[:, 3:]
1207
+ + self.static_input_scalers_epsilon
1208
+ )
1209
+ else:
1210
+ x_static = (batch["static"] - self.static_input_scalers_mu) / (
1211
+ self.static_input_scalers_sigma
1212
+ + self.static_input_scalers_epsilon
1213
+ )
1214
+
1215
+ if self.residual == "temporal":
1216
+ # We create a residual of same shape as y
1217
+ index = torch.where(
1218
+ batch["lead_time"] > 0, batch["x"].shape[1] - 1, 0
1219
+ )
1220
+ index = index.view(-1, 1, 1, 1, 1)
1221
+ index = index.expand(batch_size, 1, *batch["x"].shape[2:])
1222
+ x_hat = torch.gather(batch["x"], dim=1, index=index)
1223
+ x_hat = x_hat.squeeze(1)
1224
+ elif self.residual == "climate":
1225
+ climate_scaled = (
1226
+ batch["climate"] - self.input_scalers_mu.view(1, -1, 1, 1)
1227
+ ) / (
1228
+ self.input_scalers_sigma.view(1, -1, 1, 1)
1229
+ + self.input_scalers_epsilon
1230
+ )
1231
+
1232
+ # [batch, time, parameter, lat, lon]
1233
+ # -> [batch, time x parameter, lat, lon]
1234
+ x_rescaled = x_rescaled.flatten(1, 2)
1235
+ # Parameter dropout
1236
+ x_rescaled = self.parameter_dropout(x_rescaled)
1237
+
1238
+ x_embedded = self.patch_embedding(x_rescaled)
1239
+
1240
+ if self.residual == "climate":
1241
+ static_embedded = self.patch_embedding_static(
1242
+ torch.cat((x_static, climate_scaled), dim=1)
1243
+ )
1244
+ else:
1245
+ static_embedded = self.patch_embedding_static(x_static)
1246
+
1247
+ if self.positional_encoding == "fourier":
1248
+ static_embedded += x_static_pos
1249
+
1250
+ x_embedded = self.to_patching(x_embedded)
1251
+ static_embedded = self.to_patching(static_embedded)
1252
+
1253
+ time_encoding = self.time_encoding(
1254
+ batch["input_time"], batch["lead_time"]
1255
+ )
1256
+
1257
+ tokens = x_embedded + static_embedded + time_encoding
1258
+
1259
+ # Now we generate masks based on masking_mode
1260
+ indices_masked, indices_unmasked = self.generate_mask(
1261
+ (batch_size, self._nglobal_mu)
1262
+ )
1263
+ indices_masked = indices_masked.to(device=tokens.device)
1264
+ indices_unmasked = indices_unmasked.to(device=tokens.device)
1265
+ maskdim: int = indices_masked.ndim
1266
+
1267
+ # Unmasking
1268
+ unmask_view = (*indices_unmasked.shape, *[1] * (tokens.ndim - maskdim))
1269
+ unmasked = torch.gather(
1270
+ tokens,
1271
+ dim=maskdim - 1,
1272
+ index=indices_unmasked.view(*unmask_view).expand(
1273
+ *indices_unmasked.shape, *tokens.shape[maskdim:]
1274
+ ),
1275
+ )
1276
+
1277
+ # Encoder
1278
+ x_encoded = self.encoder(unmasked)
1279
+
1280
+ # Generate and position encode the mask tokens
1281
+ # [1, 1, 1, embed_dim]
1282
+ # -> [batch, global_seq_masked, local seq, embed_dim]
1283
+ mask_view = (*indices_masked.shape, *[1] * (tokens.ndim - maskdim))
1284
+ masking = self.mask_token.repeat(*static_embedded.shape[:3], 1)
1285
+ masked = masking + static_embedded
1286
+ masked = torch.gather(
1287
+ masked,
1288
+ dim=maskdim - 1,
1289
+ index=indices_masked.view(*mask_view).expand(
1290
+ *indices_masked.shape, *tokens.shape[maskdim:]
1291
+ ),
1292
+ )
1293
+
1294
+ recon, _ = self.reconstruct_batch(
1295
+ indices_masked, indices_unmasked, masked, x_encoded
1296
+ )
1297
+
1298
+ x_decoded = self.decoder(recon)
1299
+
1300
+ # Output: [batch, global sequence, local sequence,
1301
+ # in_channels * patch_size[0] * patch_size[1]]
1302
+ x_unembed = self.unembed(x_decoded)
1303
+
1304
+ # Reshape to [batch, global_lat, global_lon, local_lat, local_lon,
1305
+ # in_channels * patch_size[0] * patch_size[1]]
1306
+ x_out = self.from_patching(x_unembed)
1307
+
1308
+ # Pixel shuffle to [batch, in_channels, lat, lon]
1309
+ x_out = F.pixel_shuffle(x_out, self.patch_size_px[0])
1310
+
1311
+ if self.residual == "temporal":
1312
+ x_out = self.output_scalers * x_out + x_hat
1313
+ elif self.residual == "climate":
1314
+ x_out = self.output_scalers * x_out + batch["climate"]
1315
+ elif self.residual == "none":
1316
+ x_out = (
1317
+ self.output_scalers * x_out
1318
+ + self.input_scalers_mu.reshape(1, -1, 1, 1)
1319
+ )
1320
+
1321
+ return x_out
PrithviWxC/rollout.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import Tensor, nn
3
+
4
+
5
+ def rollout_iter(
6
+ nsteps: int,
7
+ model: nn.Module,
8
+ batch: dict[str, Tensor | int | float],
9
+ ) -> Tensor:
10
+ """A helper function for performing autoregressive rollout.
11
+
12
+ Args:
13
+ nsteps (int): The number of rollout steps to take
14
+ model (nn.Module): A model.
15
+ batch (dict): A data dictionary common to the Prithvi models.
16
+
17
+ Raises:
18
+ ValueError: If the number of steps isn't positive.
19
+
20
+ Returns:
21
+ Tensor: the output of the model after nsteps autoregressive iterations.
22
+ """
23
+ if nsteps < 1:
24
+ raise ValueError("'nsteps' shouold be a positive int.")
25
+
26
+ xlast = batch["x"][:, 1]
27
+ batch["lead_time"] = batch["lead_time"][..., 0]
28
+
29
+ # Save the masking ratio to be restored later
30
+ mask_ratio_tmp = model.mask_ratio_inputs
31
+
32
+ for step in range(nsteps):
33
+ # After first step, turn off masking
34
+ if step > 0:
35
+ model.mask_ratio_inputs = 0.0
36
+
37
+ batch["static"] = batch["statics"][:, step]
38
+ batch["climate"] = batch["climates"][:, step]
39
+ batch["y"] = batch["ys"][:, step]
40
+
41
+ out = model(batch)
42
+
43
+ batch["x"] = torch.cat((xlast[:, None], out[:, None]), dim=1)
44
+ xlast = out
45
+
46
+ # Restore the masking ratio
47
+ model.mask_ratio_inputs = mask_ratio_tmp
48
+
49
+ return xlast