Spaces:
Runtime error
Runtime error
Commit
·
b73936d
1
Parent(s):
3808ef8
Initial commit
Browse files- app.py +325 -0
- requirements.txt +17 -0
- surya/__init__.py +0 -0
- surya/datasets/__init__.py +0 -0
- surya/datasets/helio.py +524 -0
- surya/datasets/transformations.py +456 -0
- surya/models/__init__.py +0 -0
- surya/models/embedding.py +483 -0
- surya/models/flow.py +81 -0
- surya/models/helio_spectformer.py +318 -0
- surya/models/spectformer.py +305 -0
- surya/models/transformer_ls.py +369 -0
- surya/utils/__init__.py +0 -0
- surya/utils/config.py +311 -0
- surya/utils/data.py +176 -0
- surya/utils/distributed.py +313 -0
- surya/utils/log.py +110 -0
- surya/utils/misc.py +90 -0
- surya/utils/schemas.py +14 -0
app.py
ADDED
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import socket
|
2 |
+
import yaml
|
3 |
+
import logging
|
4 |
+
from dataclasses import dataclass
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import pandas as pd
|
9 |
+
import torch
|
10 |
+
import torch.nn.functional as F
|
11 |
+
from torch.utils.data import DataLoader
|
12 |
+
|
13 |
+
import matplotlib.pyplot as plt
|
14 |
+
import sunpy.visualization.colormaps as sunpy_cm
|
15 |
+
|
16 |
+
import gradio as gr
|
17 |
+
from huggingface_hub import snapshot_download
|
18 |
+
|
19 |
+
from surya.datasets.helio import HelioNetCDFDataset, inverse_transform_single_channel
|
20 |
+
from surya.models.helio_spectformer import HelioSpectFormer
|
21 |
+
from surya.utils.data import build_scalers, custom_collate_fn
|
22 |
+
|
23 |
+
logger = logging.getLogger(__name__)
|
24 |
+
|
25 |
+
SDO_CHANNELS = [
|
26 |
+
"aia94",
|
27 |
+
"aia131",
|
28 |
+
"aia171",
|
29 |
+
"aia193",
|
30 |
+
"aia211",
|
31 |
+
"aia304",
|
32 |
+
"aia335",
|
33 |
+
"aia1600",
|
34 |
+
"hmi_m",
|
35 |
+
"hmi_bx",
|
36 |
+
"hmi_by",
|
37 |
+
"hmi_bz",
|
38 |
+
"hmi_v",
|
39 |
+
]
|
40 |
+
|
41 |
+
@dataclass
|
42 |
+
class SDOImage:
|
43 |
+
channel: str
|
44 |
+
data: np.ndarray
|
45 |
+
timestamp: str
|
46 |
+
type: str
|
47 |
+
|
48 |
+
def download_data():
|
49 |
+
snapshot_download(
|
50 |
+
repo_id="nasa-ibm-ai4science/Surya-1.0",
|
51 |
+
local_dir="data/Surya-1.0",
|
52 |
+
allow_patterns=["config.yaml", "scalers.yaml", "surya.366m.v1.pt"],
|
53 |
+
token=None,
|
54 |
+
)
|
55 |
+
snapshot_download(
|
56 |
+
repo_id="nasa-ibm-ai4science/Surya-1.0_validation_data",
|
57 |
+
repo_type="dataset",
|
58 |
+
local_dir="data/Surya-1.0_validation_data",
|
59 |
+
allow_patterns="20140107_1[5-9]??.nc",
|
60 |
+
token=None,
|
61 |
+
)
|
62 |
+
|
63 |
+
def get_dataset(config, scalers) -> HelioNetCDFDataset:
|
64 |
+
dataset = HelioNetCDFDataset(
|
65 |
+
index_path="tests/test_surya_index.csv",
|
66 |
+
time_delta_input_minutes=config["data"]["time_delta_input_minutes"],
|
67 |
+
time_delta_target_minutes=config["data"]["time_delta_target_minutes"],
|
68 |
+
n_input_timestamps=len(config["data"]["time_delta_input_minutes"]),
|
69 |
+
rollout_steps=0,
|
70 |
+
channels=config["data"]["sdo_channels"],
|
71 |
+
drop_hmi_probability=config["data"]["drop_hmi_probability"],
|
72 |
+
num_mask_aia_channels=config["data"]["num_mask_aia_channels"],
|
73 |
+
use_latitude_in_learned_flow=config["data"]["use_latitude_in_learned_flow"],
|
74 |
+
scalers=scalers,
|
75 |
+
phase="valid",
|
76 |
+
pooling=config["data"]["pooling"],
|
77 |
+
random_vert_flip=config["data"]["random_vert_flip"],
|
78 |
+
)
|
79 |
+
logger.info(f"Initialized the dataset. {len(dataset)} samples.")
|
80 |
+
|
81 |
+
return dataset
|
82 |
+
|
83 |
+
def get_scalers() -> dict:
|
84 |
+
scalers_info = yaml.safe_load(open("data/Surya-1.0/scalers.yaml", "r"))
|
85 |
+
scalers = build_scalers(info=scalers_info)
|
86 |
+
logger.info("Built the scalers.")
|
87 |
+
return scalers
|
88 |
+
|
89 |
+
def get_model_from_config(config) -> HelioSpectFormer:
|
90 |
+
model = HelioSpectFormer(
|
91 |
+
img_size=config["model"]["img_size"],
|
92 |
+
patch_size=config["model"]["patch_size"],
|
93 |
+
in_chans=len(config["data"]["sdo_channels"]),
|
94 |
+
embed_dim=config["model"]["embed_dim"],
|
95 |
+
time_embedding={
|
96 |
+
"type": "linear",
|
97 |
+
"time_dim": len(config["data"]["time_delta_input_minutes"]),
|
98 |
+
},
|
99 |
+
depth=config["model"]["depth"],
|
100 |
+
n_spectral_blocks=config["model"]["n_spectral_blocks"],
|
101 |
+
num_heads=config["model"]["num_heads"],
|
102 |
+
mlp_ratio=config["model"]["mlp_ratio"],
|
103 |
+
drop_rate=config["model"]["drop_rate"],
|
104 |
+
dtype=torch.bfloat16,
|
105 |
+
window_size=config["model"]["window_size"],
|
106 |
+
dp_rank=config["model"]["dp_rank"],
|
107 |
+
learned_flow=config["model"]["learned_flow"],
|
108 |
+
use_latitude_in_learned_flow=config["model"]["learned_flow"],
|
109 |
+
init_weights=False,
|
110 |
+
checkpoint_layers=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
|
111 |
+
rpe=config["model"]["rpe"],
|
112 |
+
ensemble=config["model"]["ensemble"],
|
113 |
+
finetune=config["model"]["finetune"],
|
114 |
+
)
|
115 |
+
logger.info("Initialized the model.")
|
116 |
+
|
117 |
+
return model
|
118 |
+
|
119 |
+
def get_config() -> dict:
|
120 |
+
with open("data/Surya-1.0/config.yaml") as fp:
|
121 |
+
config = yaml.safe_load(fp)
|
122 |
+
|
123 |
+
return config
|
124 |
+
|
125 |
+
def setup():
|
126 |
+
logger.info("Loading data ...")
|
127 |
+
download_data()
|
128 |
+
config = get_config()
|
129 |
+
scalers = get_scalers()
|
130 |
+
|
131 |
+
logger.info("Initializing dataset ...")
|
132 |
+
dataset = get_dataset(config, scalers)
|
133 |
+
|
134 |
+
logger.info("Initializing model ...")
|
135 |
+
model = get_model_from_config(config)
|
136 |
+
if torch.cuda.is_available():
|
137 |
+
device = torch.cuda.current_device()
|
138 |
+
logger.info(f"GPU detected. Running the test on device {device}.")
|
139 |
+
else:
|
140 |
+
device = "cpu"
|
141 |
+
logger.warning(f"No GPU detected. Running the test on CPU.")
|
142 |
+
model.to(device)
|
143 |
+
n_parameters = sum(p.numel() for p in model.parameters()) / 1e6
|
144 |
+
logger.info(f"Surya FM: {n_parameters:.2f} M total parameters.")
|
145 |
+
path_weights = "data/Surya-1.0/surya.366m.v1.pt"
|
146 |
+
weights = torch.load(
|
147 |
+
path_weights, map_location=torch.device(device), weights_only=True
|
148 |
+
)
|
149 |
+
model.load_state_dict(weights, strict=True)
|
150 |
+
logger.info("Loaded weights.")
|
151 |
+
|
152 |
+
return dataset, model, device
|
153 |
+
|
154 |
+
def batch_step(
|
155 |
+
model: HelioSpectFormer,
|
156 |
+
sample_data: dict,
|
157 |
+
sample_metadata: dict,
|
158 |
+
device: int | str,
|
159 |
+
hours_ahead: int = 1,
|
160 |
+
) -> np.ndarray:
|
161 |
+
"""
|
162 |
+
Perform a single batch step for the given model, batch data, metadata, and device.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
model: The PyTorch model to use for prediction.
|
166 |
+
sample_data: A dictionary containing input and target data for the batch.
|
167 |
+
sample_metadata: A dictionary containing metadata for the batch, including timestamps.
|
168 |
+
device: The device to use for computation ('cpu', 'cuda' or device number).
|
169 |
+
hours_ahead: The number of steps to forecast ahead. Defaults to 1.
|
170 |
+
|
171 |
+
Returns:
|
172 |
+
np.ndarray: Output data.
|
173 |
+
"""
|
174 |
+
|
175 |
+
data_returned = []
|
176 |
+
forecast_hat = None # Initialize forecast_hat
|
177 |
+
|
178 |
+
for step in range(1, hours_ahead + 1):
|
179 |
+
if step == 1:
|
180 |
+
curr_batch = {
|
181 |
+
key: torch.from_numpy(sample_data[key]).unsqueeze(0).to(device)
|
182 |
+
for key in ["ts", "time_delta_input"]
|
183 |
+
}
|
184 |
+
else:
|
185 |
+
# Use the previous forecast_hat from the previous iteration
|
186 |
+
if forecast_hat is not None:
|
187 |
+
curr_batch["ts"] = torch.cat(
|
188 |
+
(curr_batch["ts"][:, :, 1:, ...], forecast_hat[:, :, None, ...]),
|
189 |
+
dim=2,
|
190 |
+
)
|
191 |
+
|
192 |
+
forecast_hat = model(curr_batch)
|
193 |
+
|
194 |
+
data_returned = forecast_hat.to(dtype=torch.float32).cpu().squeeze(0).numpy()
|
195 |
+
|
196 |
+
return data_returned
|
197 |
+
|
198 |
+
|
199 |
+
def run_inference(init_time_idx, plt_channel_idx, hours_ahead):
|
200 |
+
plt_channel_str = SDO_CHANNELS[plt_channel_idx]
|
201 |
+
|
202 |
+
input_timestamp_1 = dataset.valid_indices[init_time_idx]
|
203 |
+
input_timestamp_0 = input_timestamp_1 - pd.Timedelta(1, "h")
|
204 |
+
output_timestamp = input_timestamp_1 + pd.Timedelta(int(hours_ahead), "h")
|
205 |
+
|
206 |
+
input_timestamp_0 = input_timestamp_0.strftime("%Y-%m-%d %H:%M")
|
207 |
+
input_timestamp_1 = input_timestamp_1.strftime("%Y-%m-%d %H:%M")
|
208 |
+
output_timestamp = output_timestamp.strftime("%Y-%m-%d %H:%M")
|
209 |
+
|
210 |
+
sample_data, sample_metadata = dataset[init_time_idx]
|
211 |
+
with torch.no_grad():
|
212 |
+
model_output = batch_step(
|
213 |
+
model,
|
214 |
+
sample_data,
|
215 |
+
sample_metadata,
|
216 |
+
device,
|
217 |
+
hours_ahead
|
218 |
+
)
|
219 |
+
|
220 |
+
means, stds, epsilons, sl_scale_factors = dataset.transformation_inputs()
|
221 |
+
|
222 |
+
vmin = float("-inf")
|
223 |
+
vmax = float("inf")
|
224 |
+
input_image = []
|
225 |
+
for i in range(2):
|
226 |
+
input_image.append(
|
227 |
+
inverse_transform_single_channel(
|
228 |
+
sample_data["ts"][plt_channel_idx, i],
|
229 |
+
mean=means[plt_channel_idx],
|
230 |
+
std=stds[plt_channel_idx],
|
231 |
+
epsilon=epsilons[plt_channel_idx],
|
232 |
+
sl_scale_factor=sl_scale_factors[plt_channel_idx],
|
233 |
+
)
|
234 |
+
)
|
235 |
+
vmin = max(vmin, sample_data["ts"][plt_channel_idx, i].min())
|
236 |
+
#vmax = min(vmax, np.quantile(sample_data["ts"][plt_channel_idx, i], 0.99))
|
237 |
+
vmax = min(vmax, sample_data["ts"][plt_channel_idx, i].max())
|
238 |
+
|
239 |
+
if plt_channel_str.startswith("aia"):
|
240 |
+
cm_name = "sdo" + plt_channel_str
|
241 |
+
else:
|
242 |
+
cm_name = "hmimag"
|
243 |
+
|
244 |
+
input_image = [
|
245 |
+
sunpy_cm.cmlist[cm_name](
|
246 |
+
(img[::-1]-vmin) / (vmax-vmin), bytes=True
|
247 |
+
)
|
248 |
+
for img in input_image
|
249 |
+
]
|
250 |
+
|
251 |
+
output_image = inverse_transform_single_channel(
|
252 |
+
model_output[plt_channel_idx],
|
253 |
+
mean=means[plt_channel_idx],
|
254 |
+
std=stds[plt_channel_idx],
|
255 |
+
epsilon=epsilons[plt_channel_idx],
|
256 |
+
sl_scale_factor=sl_scale_factors[plt_channel_idx],
|
257 |
+
)
|
258 |
+
output_image = sunpy_cm.cmlist[cm_name](
|
259 |
+
(output_image[::-1]-vmin) / (vmax-vmin), bytes=True
|
260 |
+
)
|
261 |
+
|
262 |
+
return input_timestamp_0, input_image[0], input_timestamp_1, input_image[1], output_timestamp, output_image
|
263 |
+
|
264 |
+
logging.basicConfig(level=logging.INFO)
|
265 |
+
dataset, model, device = setup()
|
266 |
+
hostname = socket.getfqdn()
|
267 |
+
logging.info(f"Launching app on {hostname}")
|
268 |
+
|
269 |
+
with gr.Blocks() as demo:
|
270 |
+
gr.Markdown(value="# Surya 1.0 - Visual forecasting demo")
|
271 |
+
#with gr.Row():
|
272 |
+
#with gr.Column():
|
273 |
+
with gr.Row():
|
274 |
+
with gr.Column():
|
275 |
+
init_time = gr.Dropdown(
|
276 |
+
[v.strftime("%Y-%m-%d %H:%M") for v in dataset.valid_indices],
|
277 |
+
label="Initialization time",
|
278 |
+
multiselect=False,
|
279 |
+
type="index",
|
280 |
+
)
|
281 |
+
with gr.Column():
|
282 |
+
plt_channel = gr.Dropdown(
|
283 |
+
[c.upper() for c in SDO_CHANNELS],
|
284 |
+
label="SDO Band",
|
285 |
+
value="AIA94",
|
286 |
+
multiselect=False,
|
287 |
+
type="index"
|
288 |
+
)
|
289 |
+
with gr.Row():
|
290 |
+
hours_ahead = gr.Slider(minimum=1.0, maximum=6.0, step=1.0, label="Forcast step [hours ahead]")
|
291 |
+
with gr.Row():
|
292 |
+
btn = gr.Button("Run")
|
293 |
+
|
294 |
+
with gr.Row():
|
295 |
+
with gr.Column():
|
296 |
+
input_timestamp_0 = gr.Textbox(label="Input 0")
|
297 |
+
input_image_0 = gr.Image()
|
298 |
+
with gr.Column():
|
299 |
+
input_timestamp_1 = gr.Textbox(label="Input 1")
|
300 |
+
input_image_1 = gr.Image()
|
301 |
+
with gr.Column():
|
302 |
+
output_timestamp = gr.Textbox(label="Prediction")
|
303 |
+
output_image = gr.Image()
|
304 |
+
|
305 |
+
btn.click(
|
306 |
+
fn=run_inference,
|
307 |
+
inputs=[init_time, plt_channel, hours_ahead],
|
308 |
+
outputs=[input_timestamp_0, input_image_0, input_timestamp_1, input_image_1, output_timestamp, output_image]
|
309 |
+
)
|
310 |
+
|
311 |
+
with gr.Row():
|
312 |
+
gr.Examples(
|
313 |
+
examples=[
|
314 |
+
["2014-01-07 17:24", "AIA94", 2],
|
315 |
+
["2014-01-07 16:12", "AIA94", 6],
|
316 |
+
["2014-01-07 16:00", "AIA131", 1],
|
317 |
+
["2014-01-07 16:00", "HMI_M", 2],
|
318 |
+
],
|
319 |
+
fn=run_inference,
|
320 |
+
inputs=[init_time, plt_channel, hours_ahead],
|
321 |
+
outputs=[input_timestamp_0, input_image_0, input_timestamp_1, input_image_1, output_timestamp, output_image],
|
322 |
+
cache_examples=False,
|
323 |
+
)
|
324 |
+
|
325 |
+
demo.launch(server_name=hostname, server_port=None)
|
requirements.txt
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
einops==0.8.1
|
2 |
+
gradio==5.43.1
|
3 |
+
hdf5plugin==5.1.0
|
4 |
+
huggingface_hub==0.34.4
|
5 |
+
matplotlib==3.10.5
|
6 |
+
numba==0.61.2
|
7 |
+
numpy==2.3.2
|
8 |
+
packaging==25.0
|
9 |
+
pandas==2.3.1
|
10 |
+
PyYAML==6.0.2
|
11 |
+
PyYAML==6.0.2
|
12 |
+
skimage==0.0
|
13 |
+
sunpy==6.1.1
|
14 |
+
timm==1.0.19
|
15 |
+
torch==2.6.0
|
16 |
+
wandb==0.21.1
|
17 |
+
xarray==2025.3.1
|
surya/__init__.py
ADDED
File without changes
|
surya/datasets/__init__.py
ADDED
File without changes
|
surya/datasets/helio.py
ADDED
@@ -0,0 +1,524 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import random
|
4 |
+
from datetime import datetime
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import skimage.measure
|
8 |
+
import xarray as xr
|
9 |
+
import pandas as pd
|
10 |
+
from logging import Logger
|
11 |
+
from torch.utils.data import Dataset
|
12 |
+
from surya.utils.distributed import get_rank
|
13 |
+
from surya.utils.log import create_logger
|
14 |
+
from functools import cache
|
15 |
+
|
16 |
+
from numba import njit, prange
|
17 |
+
|
18 |
+
import hdf5plugin
|
19 |
+
|
20 |
+
|
21 |
+
@njit(parallel=True)
|
22 |
+
def fast_transform(data, means, stds, sl_scale_factors, epsilons):
|
23 |
+
"""
|
24 |
+
Implements signum log transform using numba for speed
|
25 |
+
Notes:
|
26 |
+
- This must reside outside the class definition from which it is called.
|
27 |
+
- We used this function during pretraining for faster data loading. On select
|
28 |
+
GPU clusters it leads to the system hanging however when data loading happens
|
29 |
+
outside the GPU thread. See below for a non-numba-enhanced version.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
data: Numpy array of shape C, H, W
|
33 |
+
means: Numpy array of shape C. Mean per channel.
|
34 |
+
stds: Numpy array of shape C. Standard deviation per channel.
|
35 |
+
sl_scale_factors: Numpy array of shape C. Signum-log scale factors.
|
36 |
+
epsilons: Numpy array of shape C. Constant to avoid zero division.
|
37 |
+
|
38 |
+
Returns:
|
39 |
+
Numpy array of shape C, H, W.
|
40 |
+
"""
|
41 |
+
C, H, W = data.shape
|
42 |
+
out = np.empty((C, H, W), dtype=np.float32)
|
43 |
+
for c in prange(C):
|
44 |
+
mean = means[c]
|
45 |
+
std = stds[c]
|
46 |
+
eps = epsilons[c]
|
47 |
+
sl_scale_factor = sl_scale_factors[c]
|
48 |
+
for i in range(H):
|
49 |
+
for j in range(W):
|
50 |
+
val = data[c, i, j]
|
51 |
+
val = val * sl_scale_factor
|
52 |
+
if val >= 0:
|
53 |
+
val = np.log1p(val)
|
54 |
+
else:
|
55 |
+
val = -np.log1p(-val)
|
56 |
+
out[c, i, j] = (val - mean) / (std + eps)
|
57 |
+
return out
|
58 |
+
|
59 |
+
def transform(
|
60 |
+
data: np.ndarray,
|
61 |
+
means: np.ndarray,
|
62 |
+
stds: np.ndarray,
|
63 |
+
sl_scale_factors: np.ndarray,
|
64 |
+
epsilons: np.ndarray
|
65 |
+
) -> np.ndarray:
|
66 |
+
"""
|
67 |
+
Implements signum log transform. Drop-in replacement for
|
68 |
+
`fast_transform` method above.
|
69 |
+
|
70 |
+
Args:
|
71 |
+
data: Numpy array of shape C, H, W
|
72 |
+
means: Numpy array of shape C. Mean per channel.
|
73 |
+
stds: Numpy array of shape C. Standard deviation per channel.
|
74 |
+
sl_scale_factors: Numpy array of shape C. Signum-log scale factors.
|
75 |
+
epsilons: Numpy array of shape C. Constant to avoid zero division.
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
Numpy array of shape C, H, W.
|
79 |
+
"""
|
80 |
+
means = means.reshape(*means.shape, 1, 1)
|
81 |
+
stds = stds.reshape(*stds.shape, 1, 1)
|
82 |
+
sl_scale_factors = sl_scale_factors.reshape(*sl_scale_factors.shape, 1, 1)
|
83 |
+
epsilons = epsilons.reshape(*epsilons.shape, 1, 1)
|
84 |
+
|
85 |
+
data = data * sl_scale_factors
|
86 |
+
data = np.sign(data) * np.log1p(np.abs(data))
|
87 |
+
data = (data - means) / (stds + epsilons)
|
88 |
+
|
89 |
+
return data
|
90 |
+
|
91 |
+
@njit(parallel=True)
|
92 |
+
def inverse_fast_transform(data, means, stds, sl_scale_factors, epsilons):
|
93 |
+
"""
|
94 |
+
Implements inverse signum log transform using numba for speed
|
95 |
+
|
96 |
+
Args:
|
97 |
+
data: Numpy array of shape C, H, W
|
98 |
+
means: Numpy array of shape C. Mean per channel.
|
99 |
+
stds: Numpy array of shape C. Standard deviation per channel.
|
100 |
+
sl_scale_factors: Numpy array of shape C. Signum-log scale factors.
|
101 |
+
epsilons: Numpy array of shape C. Constant to avoid zero division.
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
Numpy array of shape C, H, W.
|
105 |
+
"""
|
106 |
+
C, H, W = data.shape
|
107 |
+
out = np.empty((C, H, W), dtype=np.float32)
|
108 |
+
|
109 |
+
for c in prange(C):
|
110 |
+
mean = means[c]
|
111 |
+
std = stds[c]
|
112 |
+
eps = epsilons[c]
|
113 |
+
sl_scale_factor = sl_scale_factors[c]
|
114 |
+
|
115 |
+
for i in range(H):
|
116 |
+
for j in range(W):
|
117 |
+
val = data[c, i, j]
|
118 |
+
val = val * (std + eps) + mean
|
119 |
+
|
120 |
+
if val >= 0:
|
121 |
+
val = np.expm1(val)
|
122 |
+
else:
|
123 |
+
val = -np.expm1(-val)
|
124 |
+
|
125 |
+
val = val / sl_scale_factor
|
126 |
+
|
127 |
+
out[c, i, j] = val
|
128 |
+
|
129 |
+
return out
|
130 |
+
|
131 |
+
|
132 |
+
def inverse_transform_single_channel(data, mean, std, sl_scale_factor, epsilon):
|
133 |
+
"""
|
134 |
+
Implements inverse signum log transform.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
data: Numpy array of shape C, H, W
|
138 |
+
means: Numpy array of shape C. Mean per channel.
|
139 |
+
stds: Numpy array of shape C. Standard deviation per channel.
|
140 |
+
sl_scale_factors: Numpy array of shape C. Signum-log scale factors.
|
141 |
+
epsilons: Numpy array of shape C. Constant to avoid zero division.
|
142 |
+
|
143 |
+
Returns:
|
144 |
+
Numpy array of shape C, H, W.
|
145 |
+
"""
|
146 |
+
data = data * (std + epsilon) + mean
|
147 |
+
|
148 |
+
data = np.sign(data) * np.expm1(np.abs(data))
|
149 |
+
|
150 |
+
data = data / sl_scale_factor
|
151 |
+
|
152 |
+
return data
|
153 |
+
|
154 |
+
|
155 |
+
class RandomChannelMaskerTransform:
|
156 |
+
def __init__(
|
157 |
+
self, num_channels, num_mask_aia_channels, phase, drop_hmi_probability
|
158 |
+
):
|
159 |
+
"""
|
160 |
+
Initialize the RandomChannelMaskerTransform class as a transform.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
- num_channels: Total number of channels in the input (3rd dimension of
|
164 |
+
the tensor).
|
165 |
+
- num_mask_aia_channels: Number of channels to randomly mask.
|
166 |
+
"""
|
167 |
+
self.num_channels = num_channels
|
168 |
+
self.num_mask_aia_channels = num_mask_aia_channels
|
169 |
+
self.drop_hmi_probability = drop_hmi_probability
|
170 |
+
|
171 |
+
def __call__(self, input_tensor):
|
172 |
+
C, T, H, W = input_tensor.shape # Unpacking the correct 5 dimensions
|
173 |
+
|
174 |
+
# Randomly select channels to mask
|
175 |
+
channels_to_mask = random.sample(range(C), self.num_mask_aia_channels)
|
176 |
+
|
177 |
+
# Create an in-place mask of shape [1, 1, num_channels, 1, 1]
|
178 |
+
mask = torch.ones((C, 1, 1, 1))
|
179 |
+
mask[channels_to_mask, ...] = 0 # Set selected channels to zero
|
180 |
+
|
181 |
+
# Apply the mask in-place for memory efficiency
|
182 |
+
masked_tensor = input_tensor * mask # Modify input_tensor directly
|
183 |
+
|
184 |
+
if self.drop_hmi_probability > random.random():
|
185 |
+
masked_tensor[-1, ...] = 0
|
186 |
+
|
187 |
+
return masked_tensor
|
188 |
+
|
189 |
+
|
190 |
+
class HelioNetCDFDataset(Dataset):
|
191 |
+
"""
|
192 |
+
PyTorch dataset to load a curated dataset from the NASA Solar Dynamics
|
193 |
+
Observatory (SDO) mission stored as NetCDF files, with handling for variable timesteps.
|
194 |
+
|
195 |
+
Internally maintains two databases. The first is `self.index`. This takes the
|
196 |
+
form
|
197 |
+
path present
|
198 |
+
timestep
|
199 |
+
2011-01-01 00:00:00 /lustre/fs0/scratch/shared/data/2011/01/Arka_2... 1
|
200 |
+
2011-01-01 00:12:00 /lustre/fs0/scratch/shared/data/2011/01/Arka_2... 1
|
201 |
+
... ... ...
|
202 |
+
2012-11-30 23:48:00 /lustre/fs0/scratch/shared/data/2012/11/Arka_2... 1
|
203 |
+
|
204 |
+
The second is `self.valid_indices`. This is simply a list of timesteps -- entries
|
205 |
+
in the index of `self.index` -- which define valid samples. A sample is valid
|
206 |
+
when all timestamps that can be reached by entris in
|
207 |
+
time_delta_input_minutes and time_delta_target_minutes can be reached from it
|
208 |
+
are present.
|
209 |
+
"""
|
210 |
+
|
211 |
+
def __init__(
|
212 |
+
self,
|
213 |
+
index_path: str,
|
214 |
+
time_delta_input_minutes: list[int],
|
215 |
+
time_delta_target_minutes: int,
|
216 |
+
n_input_timestamps: int,
|
217 |
+
rollout_steps: int,
|
218 |
+
scalers=None,
|
219 |
+
num_mask_aia_channels: int = 0,
|
220 |
+
drop_hmi_probability: float = 0.0,
|
221 |
+
use_latitude_in_learned_flow=False,
|
222 |
+
channels: list[str] | None = None,
|
223 |
+
phase="train",
|
224 |
+
pooling: int | None = None,
|
225 |
+
random_vert_flip: bool = False,
|
226 |
+
):
|
227 |
+
self.scalers = scalers
|
228 |
+
self.phase = phase
|
229 |
+
self.channels = channels
|
230 |
+
self.num_mask_aia_channels = num_mask_aia_channels
|
231 |
+
self.drop_hmi_probability = drop_hmi_probability
|
232 |
+
self.n_input_timestamps = n_input_timestamps
|
233 |
+
self.rollout_steps = rollout_steps
|
234 |
+
self.use_latitude_in_learned_flow = use_latitude_in_learned_flow
|
235 |
+
self.pooling = pooling if pooling is not None else 1
|
236 |
+
self.random_vert_flip = random_vert_flip
|
237 |
+
|
238 |
+
if self.channels is None:
|
239 |
+
# AIA + HMI channels
|
240 |
+
self.channels = [
|
241 |
+
"0094",
|
242 |
+
"0131",
|
243 |
+
"0171",
|
244 |
+
"0193",
|
245 |
+
"0211",
|
246 |
+
"0304",
|
247 |
+
"0335",
|
248 |
+
"hmi",
|
249 |
+
]
|
250 |
+
self.in_channels = len(self.channels)
|
251 |
+
|
252 |
+
self.masker = RandomChannelMaskerTransform(
|
253 |
+
num_channels=self.in_channels,
|
254 |
+
num_mask_aia_channels=self.num_mask_aia_channels,
|
255 |
+
phase=self.phase,
|
256 |
+
drop_hmi_probability=self.drop_hmi_probability,
|
257 |
+
)
|
258 |
+
|
259 |
+
# Convert time delta to numpy timedelta64
|
260 |
+
self.time_delta_input_minutes = sorted(
|
261 |
+
np.timedelta64(t, "m") for t in time_delta_input_minutes
|
262 |
+
)
|
263 |
+
self.time_delta_target_minutes = [
|
264 |
+
np.timedelta64(iroll * time_delta_target_minutes, "m")
|
265 |
+
for iroll in range(1, rollout_steps + 2)
|
266 |
+
]
|
267 |
+
|
268 |
+
# Create the index
|
269 |
+
self.index = pd.read_csv(index_path)
|
270 |
+
self.index = self.index[self.index["present"] == 1]
|
271 |
+
self.index["timestep"] = pd.to_datetime(self.index["timestep"]).values.astype(
|
272 |
+
"datetime64[ns]"
|
273 |
+
)
|
274 |
+
self.index.set_index("timestep", inplace=True)
|
275 |
+
self.index.sort_index(inplace=True)
|
276 |
+
|
277 |
+
# Filter out rows where the sequence is not fully present
|
278 |
+
self.valid_indices = self.filter_valid_indices()
|
279 |
+
self.adjusted_length = len(self.valid_indices)
|
280 |
+
|
281 |
+
self.rank = get_rank()
|
282 |
+
self.logger: Logger | None = None
|
283 |
+
|
284 |
+
def create_logger(self):
|
285 |
+
"""
|
286 |
+
Creates a logger attached to self.logger.
|
287 |
+
The logger is identified by SLURM job ID
|
288 |
+
as well as the data processes rank and process ID.
|
289 |
+
"""
|
290 |
+
os.makedirs("logs/data", exist_ok=True)
|
291 |
+
timestamp = datetime.now().strftime("%Y%m%dT%H%M%SZ")
|
292 |
+
pid = os.getpid()
|
293 |
+
self.logger = create_logger(
|
294 |
+
output_dir="logs/data",
|
295 |
+
dist_rank=self.rank,
|
296 |
+
name=f"{timestamp}_{self.rank:>03}_data_{self.phase}_{pid}",
|
297 |
+
)
|
298 |
+
|
299 |
+
def filter_valid_indices(self):
|
300 |
+
"""
|
301 |
+
Extracts timestamps from the index of self.index that define valid
|
302 |
+
samples.
|
303 |
+
|
304 |
+
Args:
|
305 |
+
Returns:
|
306 |
+
List of timestamps.
|
307 |
+
"""
|
308 |
+
|
309 |
+
valid_indices = []
|
310 |
+
time_deltas = np.unique(
|
311 |
+
self.time_delta_input_minutes + self.time_delta_target_minutes
|
312 |
+
)
|
313 |
+
|
314 |
+
for reference_timestep in self.index.index:
|
315 |
+
required_timesteps = reference_timestep + time_deltas
|
316 |
+
|
317 |
+
if all(t in self.index.index for t in required_timesteps):
|
318 |
+
valid_indices.append(reference_timestep)
|
319 |
+
|
320 |
+
return valid_indices
|
321 |
+
|
322 |
+
def __len__(self):
|
323 |
+
return self.adjusted_length
|
324 |
+
|
325 |
+
def __getitem__(self, idx: int) -> dict:
|
326 |
+
"""
|
327 |
+
Args:
|
328 |
+
idx: Index of sample to load. (Pytorch standard.)
|
329 |
+
Returns:
|
330 |
+
Dictionary with following keys. The values are tensors with shape as follows:
|
331 |
+
ts (torch.Tensor): C, T, H, W
|
332 |
+
time_delta_input (torch.Tensor): T
|
333 |
+
input_latitude (torch.Tensor): T
|
334 |
+
forecast (torch.Tensor): C, L, H, W
|
335 |
+
lead_time_delta (torch.Tensor): L
|
336 |
+
forecast_latitude (torch.Tensor): L
|
337 |
+
C - Channels, T - Input times, H - Image height, W - Image width, L - Lead time.
|
338 |
+
"""
|
339 |
+
if self.logger is None:
|
340 |
+
self.create_logger()
|
341 |
+
self.logger.info(f"HelioNetCDFDataset of length {self.__len__()}.")
|
342 |
+
|
343 |
+
exception_counter = 0
|
344 |
+
max_exception = 100
|
345 |
+
|
346 |
+
self.logger.info(f"Starting to retrieve index {idx}.")
|
347 |
+
|
348 |
+
while True:
|
349 |
+
try:
|
350 |
+
sample = self._get_index_data(idx)
|
351 |
+
except Exception as e:
|
352 |
+
exception_counter += 1
|
353 |
+
if exception_counter >= max_exception:
|
354 |
+
raise e
|
355 |
+
|
356 |
+
reference_timestep = self.valid_indices[idx]
|
357 |
+
self.logger.warning(
|
358 |
+
f"Failed retrieving index {idx}. Timestamp {reference_timestep}. Attempt {exception_counter}."
|
359 |
+
)
|
360 |
+
|
361 |
+
idx = (idx + 1) % self.__len__()
|
362 |
+
else:
|
363 |
+
self.logger.info(f"Returning index {idx}.")
|
364 |
+
return sample
|
365 |
+
|
366 |
+
def _get_index_data(self, idx: int) -> dict:
|
367 |
+
"""
|
368 |
+
Args:
|
369 |
+
idx: Index of sample to load. (Pytorch standard.)
|
370 |
+
Returns:
|
371 |
+
Dictionary with following keys. The values are tensors with shape as follows:
|
372 |
+
ts (torch.Tensor): C, T, H, W
|
373 |
+
time_delta_input (torch.Tensor): T
|
374 |
+
input_latitude (torch.Tensor): T
|
375 |
+
forecast (torch.Tensor): C, L, H, W
|
376 |
+
lead_time_delta (torch.Tensor): L
|
377 |
+
forecast_latitude (torch.Tensor): L
|
378 |
+
C - Channels, T - Input times, H - Image height, W - Image width, L - Lead time.
|
379 |
+
"""
|
380 |
+
# start_time = time.time()
|
381 |
+
|
382 |
+
time_deltas = np.array(
|
383 |
+
sorted(
|
384 |
+
random.sample(
|
385 |
+
self.time_delta_input_minutes[:-1], self.n_input_timestamps - 1
|
386 |
+
)
|
387 |
+
)
|
388 |
+
+ [self.time_delta_input_minutes[-1]]
|
389 |
+
+ self.time_delta_target_minutes
|
390 |
+
)
|
391 |
+
reference_timestep = self.valid_indices[idx]
|
392 |
+
required_timesteps = reference_timestep + time_deltas
|
393 |
+
|
394 |
+
sequence_data = [
|
395 |
+
self.transform_data(
|
396 |
+
self.load_nc_data(
|
397 |
+
self.index.loc[timestep, "path"], timestep, self.channels
|
398 |
+
)
|
399 |
+
)
|
400 |
+
for timestep in required_timesteps
|
401 |
+
]
|
402 |
+
|
403 |
+
# Split sequence_data into inputs and target
|
404 |
+
inputs = sequence_data[: -self.rollout_steps - 1]
|
405 |
+
targets = sequence_data[-self.rollout_steps - 1 :]
|
406 |
+
|
407 |
+
stacked_inputs = np.stack(inputs, axis=1)
|
408 |
+
stacked_targets = np.stack(targets, axis=1)
|
409 |
+
|
410 |
+
timestamps_input = required_timesteps[: -self.rollout_steps - 1]
|
411 |
+
timestamps_targets = required_timesteps[-self.rollout_steps - 1 :]
|
412 |
+
|
413 |
+
if self.num_mask_aia_channels > 0 or self.drop_hmi_probability:
|
414 |
+
# assert 0 < self.num_mask_aia_channels < self.in_channels, \
|
415 |
+
# f'num_mask_aia_channels = {self.num_mask_aia_channels} should lie between 0 and {self.in_channels}'
|
416 |
+
|
417 |
+
stacked_inputs = self.masker(stacked_inputs)
|
418 |
+
|
419 |
+
time_delta_input_float = (
|
420 |
+
time_deltas[-self.rollout_steps - 2]
|
421 |
+
- time_deltas[: -self.rollout_steps - 1]
|
422 |
+
) / np.timedelta64(1, "h")
|
423 |
+
time_delta_input_float = time_delta_input_float.astype(np.float32)
|
424 |
+
|
425 |
+
lead_time_delta_float = (
|
426 |
+
time_deltas[-self.rollout_steps - 2]
|
427 |
+
- time_deltas[-self.rollout_steps - 1 :]
|
428 |
+
) / np.timedelta64(1, "h")
|
429 |
+
lead_time_delta_float = lead_time_delta_float.astype(np.float32)
|
430 |
+
|
431 |
+
# print('LocalRank', int(os.environ["LOCAL_RANK"]),
|
432 |
+
# 'GlobalRank', int(os.environ["RANK"]),
|
433 |
+
# 'worker', torch.utils.data.get_worker_info().id,
|
434 |
+
# f': Processed Input: {idx} ',time.time()- start_time)
|
435 |
+
|
436 |
+
metadata = {
|
437 |
+
"timestamps_input": timestamps_input,
|
438 |
+
"timestamps_targets": timestamps_targets,
|
439 |
+
}
|
440 |
+
|
441 |
+
if self.random_vert_flip:
|
442 |
+
if torch.bernoulli(torch.ones(()) / 2) == 1:
|
443 |
+
stacked_inputs = torch.flip(stacked_inputs, dims=-2)
|
444 |
+
stacked_targets = torch.flip(stacked_inputs, dims=-2)
|
445 |
+
|
446 |
+
if self.use_latitude_in_learned_flow:
|
447 |
+
from sunpy.coordinates.ephemeris import get_earth
|
448 |
+
|
449 |
+
sequence_latitude = [
|
450 |
+
get_earth(timestep).lat.value for timestep in required_timesteps
|
451 |
+
]
|
452 |
+
input_latitudes = sequence_latitude[: -self.rollout_steps - 1]
|
453 |
+
target_latitude = sequence_latitude[-self.rollout_steps - 1 :]
|
454 |
+
|
455 |
+
return {
|
456 |
+
"ts": stacked_inputs,
|
457 |
+
"time_delta_input": time_delta_input_float,
|
458 |
+
"input_latitudes": input_latitudes,
|
459 |
+
"forecast": stacked_targets,
|
460 |
+
"lead_time_delta": lead_time_delta_float,
|
461 |
+
"forecast_latitude": target_latitude,
|
462 |
+
}, metadata
|
463 |
+
|
464 |
+
return {
|
465 |
+
"ts": stacked_inputs,
|
466 |
+
"time_delta_input": time_delta_input_float,
|
467 |
+
"forecast": stacked_targets,
|
468 |
+
"lead_time_delta": lead_time_delta_float,
|
469 |
+
}, metadata
|
470 |
+
|
471 |
+
def load_nc_data(
|
472 |
+
self, filepath: str, timestep: pd.Timestamp, channels: list[str]
|
473 |
+
) -> np.ndarray:
|
474 |
+
"""
|
475 |
+
Args:
|
476 |
+
filepath: String or Pathlike. Points to NetCDF file to open.
|
477 |
+
timestep: Identifies timestamp to retrieve.
|
478 |
+
Returns:
|
479 |
+
Numpy array of shape (C, H, W).
|
480 |
+
"""
|
481 |
+
self.logger.info(f"Reading file {filepath}.")
|
482 |
+
|
483 |
+
with xr.open_dataset(
|
484 |
+
filepath, engine="h5netcdf", chunks=None, cache=False,
|
485 |
+
) as ds:
|
486 |
+
data = ds[channels].to_array().load().to_numpy()
|
487 |
+
|
488 |
+
return data
|
489 |
+
|
490 |
+
@cache
|
491 |
+
def transformation_inputs(self) -> (np.ndarray, np.ndarray, np.ndarray, np.ndarray):
|
492 |
+
means = np.array([self.scalers[ch].mean for ch in self.channels])
|
493 |
+
stds = np.array([self.scalers[ch].std for ch in self.channels])
|
494 |
+
epsilons = np.array([self.scalers[ch].epsilon for ch in self.channels])
|
495 |
+
sl_scale_factors = np.array(
|
496 |
+
[self.scalers[ch].sl_scale_factor for ch in self.channels]
|
497 |
+
)
|
498 |
+
|
499 |
+
return means, stds, epsilons, sl_scale_factors
|
500 |
+
|
501 |
+
def transform_data(self, data: np.ndarray) -> np.ndarray:
|
502 |
+
"""
|
503 |
+
Applies scalers.
|
504 |
+
|
505 |
+
Args:
|
506 |
+
data: Numpy array of shape (C, H, W)
|
507 |
+
Returns:
|
508 |
+
Tensor of shape (C, H, W). Data type float32.
|
509 |
+
Uses:
|
510 |
+
numba to speed up transform
|
511 |
+
tvk-srm-heliofm environment cloned from srm-heliofm with numba added
|
512 |
+
tvk_dgx_slurm.sh shell script modified to use new environment and new jobname
|
513 |
+
train_spectformer_dgx.yaml new jobname
|
514 |
+
"""
|
515 |
+
assert data.ndim == 3
|
516 |
+
|
517 |
+
if self.pooling > 1:
|
518 |
+
data = skimage.measure.block_reduce(
|
519 |
+
data, block_size=(1, self.pooling, self.pooling), func=np.mean
|
520 |
+
)
|
521 |
+
|
522 |
+
means, stds, epsilons, sl_scale_factors = self.transformation_inputs()
|
523 |
+
result_np = transform(data, means, stds, sl_scale_factors, epsilons)
|
524 |
+
return result_np
|
surya/datasets/transformations.py
ADDED
@@ -0,0 +1,456 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
from logging import info
|
3 |
+
from typing import Tuple
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import xarray as xr
|
8 |
+
|
9 |
+
|
10 |
+
class Transformation(object):
|
11 |
+
@abc.abstractmethod
|
12 |
+
def fit(self, data: xr.DataArray):
|
13 |
+
raise NotImplementedError()
|
14 |
+
|
15 |
+
@abc.abstractmethod
|
16 |
+
def transform(self, data: xr.DataArray):
|
17 |
+
raise NotImplementedError()
|
18 |
+
|
19 |
+
@abc.abstractmethod
|
20 |
+
def inverse_transform(self, data: xr.DataArray):
|
21 |
+
raise NotImplementedError()
|
22 |
+
|
23 |
+
@abc.abstractmethod
|
24 |
+
def fit_transform(self, data: xr.DataArray):
|
25 |
+
return self.fit(data).transform(data)
|
26 |
+
|
27 |
+
@abc.abstractmethod
|
28 |
+
def to_dict(self) -> dict:
|
29 |
+
raise NotImplementedError()
|
30 |
+
|
31 |
+
@staticmethod
|
32 |
+
@abc.abstractmethod
|
33 |
+
def from_dict(info: dict):
|
34 |
+
raise NotImplementedError()
|
35 |
+
|
36 |
+
@abc.abstractmethod
|
37 |
+
def reset(self):
|
38 |
+
raise NotImplementedError()
|
39 |
+
|
40 |
+
|
41 |
+
class MinMaxScaler(Transformation):
|
42 |
+
"""_summary_
|
43 |
+
Minmax scaling on the entire data
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(self, new_min=1, new_max=2):
|
47 |
+
self._is_fitted = False
|
48 |
+
self.new_min = new_min
|
49 |
+
self.new_max = new_max
|
50 |
+
self._min = None
|
51 |
+
self._max = None
|
52 |
+
|
53 |
+
@property
|
54 |
+
def min(self) -> float:
|
55 |
+
return self._min
|
56 |
+
|
57 |
+
@property
|
58 |
+
def max(self) -> float:
|
59 |
+
return self._max
|
60 |
+
|
61 |
+
@property
|
62 |
+
def is_fitted(self) -> bool:
|
63 |
+
return self._is_fitted
|
64 |
+
|
65 |
+
def fit(self, data: xr.DataArray):
|
66 |
+
if not self.is_fitted:
|
67 |
+
self._max = data.max().values
|
68 |
+
self._min = data.min().values
|
69 |
+
self._is_fitted = True
|
70 |
+
else:
|
71 |
+
info("Already fitted, skipping function.")
|
72 |
+
return self
|
73 |
+
|
74 |
+
def _transform(self, data: xr.DataArray):
|
75 |
+
return (
|
76 |
+
((data - self.min) / (self.max - self.min)) * (self.new_max - self.new_min)
|
77 |
+
) + self.new_min
|
78 |
+
|
79 |
+
def transform(self, data: xr.DataArray) -> xr.DataArray:
|
80 |
+
assert self.min is not None and self.max is not None, "You must run fit first."
|
81 |
+
|
82 |
+
data = xr.apply_ufunc(self._transform, data, dask="forbidden")
|
83 |
+
|
84 |
+
return data
|
85 |
+
|
86 |
+
def fit_transform(self, data):
|
87 |
+
self.fit(data)
|
88 |
+
return self.transform(data)
|
89 |
+
|
90 |
+
def inverse_transform(self, data):
|
91 |
+
return data * (self.max - self.min) + self.min
|
92 |
+
|
93 |
+
def to_dict(self) -> dict:
|
94 |
+
out_dict = {
|
95 |
+
"base": self.__module__,
|
96 |
+
"class": self.__class__.__name__,
|
97 |
+
"new_min": str(self.new_min),
|
98 |
+
"new_max": str(self.new_max),
|
99 |
+
"min": str(self.min),
|
100 |
+
"max": str(self.max),
|
101 |
+
"is_fitted": self.is_fitted,
|
102 |
+
}
|
103 |
+
return out_dict
|
104 |
+
|
105 |
+
@staticmethod
|
106 |
+
def from_dict(info: dict):
|
107 |
+
# with open(yaml_path, 'r') as file:
|
108 |
+
# data = yaml.load(file, Loader=yaml.SafeLoader)
|
109 |
+
out = MinMaxScaler(
|
110 |
+
new_min=np.float32(info["new_min"]), new_max=np.float32(info["new_max"])
|
111 |
+
)
|
112 |
+
out._min = np.float32(info["min"])
|
113 |
+
out._max = np.float32(info["max"])
|
114 |
+
out._is_fitted = info["is_fitted"]
|
115 |
+
return out
|
116 |
+
|
117 |
+
def reset(self):
|
118 |
+
self.__init__(self.new_min, self.new_max)
|
119 |
+
|
120 |
+
def __str__(self):
|
121 |
+
return (
|
122 |
+
f"min: {self.min}, "
|
123 |
+
f"max: {self.max}, "
|
124 |
+
f"new_max: {self.new_max}, "
|
125 |
+
f"new_min: {self.new_min}"
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
class StandardScaler(Transformation):
|
130 |
+
"""_summary_
|
131 |
+
Standard scaling on the entire data
|
132 |
+
"""
|
133 |
+
|
134 |
+
def __init__(self, epsilon=1e-8):
|
135 |
+
self.epsilon = epsilon
|
136 |
+
self._is_fitted = False
|
137 |
+
self._mean = None
|
138 |
+
self._std = None
|
139 |
+
self._min = None
|
140 |
+
self._max = None
|
141 |
+
self._sl_scale_factor = None
|
142 |
+
|
143 |
+
@property
|
144 |
+
def mean(self) -> float:
|
145 |
+
return self._mean
|
146 |
+
|
147 |
+
@property
|
148 |
+
def std(self) -> float:
|
149 |
+
return self._std
|
150 |
+
|
151 |
+
@property
|
152 |
+
def min(self) -> float:
|
153 |
+
return self._min
|
154 |
+
|
155 |
+
@property
|
156 |
+
def max(self) -> float:
|
157 |
+
return self._max
|
158 |
+
|
159 |
+
@property
|
160 |
+
def sl_scale_factor(self) -> float:
|
161 |
+
return self._sl_scale_factor
|
162 |
+
|
163 |
+
@property
|
164 |
+
def is_fitted(self) -> bool:
|
165 |
+
return self._is_fitted
|
166 |
+
|
167 |
+
def fit(self, data):
|
168 |
+
if not self.is_fitted:
|
169 |
+
self._mean = data.mean().values
|
170 |
+
self._std = data.std().values
|
171 |
+
self._min = data.min().values
|
172 |
+
self._max = data.max().values
|
173 |
+
self._is_fitted = True
|
174 |
+
else:
|
175 |
+
info("Already fitted, skipping function.")
|
176 |
+
|
177 |
+
return self
|
178 |
+
|
179 |
+
def _transform(self, data: xr.DataArray):
|
180 |
+
return (data - self.mean) / (self.std + self.epsilon)
|
181 |
+
|
182 |
+
def _signum_log_transform(self, data: xr.DataArray):
|
183 |
+
data = data * self.sl_scale_factor
|
184 |
+
return np.sign(data) * np.log1p(np.abs(data))
|
185 |
+
|
186 |
+
def signum_log_transform(self, data: xr.DataArray):
|
187 |
+
assert self.mean is not None and self.std is not None, "You must run fit first."
|
188 |
+
|
189 |
+
data = xr.apply_ufunc(self._signum_log_transform, data, dask="forbidden")
|
190 |
+
data = xr.apply_ufunc(self._transform, data, dask="forbidden")
|
191 |
+
return data
|
192 |
+
|
193 |
+
def transform(self, data: xr.DataArray):
|
194 |
+
assert self.mean is not None and self.std is not None, "You must run fit first."
|
195 |
+
|
196 |
+
data = xr.apply_ufunc(self._transform, data, dask="forbidden")
|
197 |
+
return data
|
198 |
+
|
199 |
+
def fit_transform(self, data: xr.DataArray):
|
200 |
+
self.fit(data)
|
201 |
+
return self.transform(data)
|
202 |
+
|
203 |
+
def inverse_transform(self, data):
|
204 |
+
if isinstance(data, torch.Tensor):
|
205 |
+
return data * (
|
206 |
+
torch.Tensor([self.std]).to(data.device)
|
207 |
+
+ torch.Tensor([self.epsilon]).to(data.device)
|
208 |
+
) + torch.Tensor([self.mean]).to(data.device)
|
209 |
+
else:
|
210 |
+
return data * (self.std + self.epsilon) + self.mean
|
211 |
+
|
212 |
+
def inverse_signum_log_transform(self, data):
|
213 |
+
if isinstance(data, torch.Tensor):
|
214 |
+
return (
|
215 |
+
torch.sign(data)
|
216 |
+
* torch.expm1(torch.abs(data))
|
217 |
+
/ torch.Tensor([self.sl_scale_factor]).to(data.device)
|
218 |
+
)
|
219 |
+
else:
|
220 |
+
return np.sign(data) * np.expm1(np.abs(data)) / self.sl_scale_factor
|
221 |
+
|
222 |
+
def to_dict(self) -> dict:
|
223 |
+
return {
|
224 |
+
"base": self.__module__,
|
225 |
+
"class": self.__class__.__name__,
|
226 |
+
"epsilon": str(self.epsilon),
|
227 |
+
"mean": str(self.mean),
|
228 |
+
"std": str(self.std),
|
229 |
+
"is_fitted": self.is_fitted,
|
230 |
+
"min": str(self.min),
|
231 |
+
"max": str(self.max),
|
232 |
+
"sl_scale_factor": str(self.sl_scale_factor),
|
233 |
+
}
|
234 |
+
|
235 |
+
@staticmethod
|
236 |
+
def from_dict(info: dict):
|
237 |
+
out = StandardScaler(epsilon=np.float32(info["epsilon"]))
|
238 |
+
out._mean = np.float32(info["mean"])
|
239 |
+
out._std = np.float32(info["std"])
|
240 |
+
out._is_fitted = info["is_fitted"]
|
241 |
+
out._min = np.float32(info["min"])
|
242 |
+
out._max = np.float32(info["max"])
|
243 |
+
out._sl_scale_factor = np.float32(info["sl_scale_factor"])
|
244 |
+
return out
|
245 |
+
|
246 |
+
def reset(self):
|
247 |
+
self.__init__(self.epsilon)
|
248 |
+
|
249 |
+
def __str__(self):
|
250 |
+
return f"mean: {self.mean}, " f"std: {self.std}, " f"epsilon: {self.epsilon}"
|
251 |
+
|
252 |
+
|
253 |
+
class MaskUnits2D:
|
254 |
+
"""
|
255 |
+
Transformation that takes a tuple of numpy tensors and returns a sequence of mask units. These are generally in the form `channel, dim_0, dim_1, dim_2, ...`. The returned data is largely of shape `mask unit sequence, channel, lat, lon`. Masked patches are not returned.
|
256 |
+
The return values contain sets of indices. The indices indicate which mask units where dropped (masked) or not. The 1D indexing here simply relies on flattening the 2D space of mask units. The class methods `reconstruct` and `reconstruct_batch` show how to re-assemble the entire sequence.
|
257 |
+
"""
|
258 |
+
|
259 |
+
def __init__(
|
260 |
+
self,
|
261 |
+
n_lat_mu: int,
|
262 |
+
n_lon_mu: int,
|
263 |
+
padding,
|
264 |
+
seed=None,
|
265 |
+
mask_ratio_vals: float = 0.5,
|
266 |
+
mask_ratio_tars: float = 0.0,
|
267 |
+
n_lats: int = 361,
|
268 |
+
n_lons: int = 576,
|
269 |
+
):
|
270 |
+
self.n_lat_mu = n_lat_mu
|
271 |
+
self.n_lon_mu = n_lon_mu
|
272 |
+
self.mask_ratio_vals = mask_ratio_vals
|
273 |
+
self.mask_ratio_tars = mask_ratio_tars
|
274 |
+
self.padding = padding
|
275 |
+
self.n_lats = n_lats + padding[0][0] + padding[0][1]
|
276 |
+
self.n_lons = n_lons + padding[1][0] + padding[1][1]
|
277 |
+
|
278 |
+
if self.n_lats % n_lat_mu != 0:
|
279 |
+
raise ValueError(
|
280 |
+
f"Padded latitudes {self.n_lats} are not an integer multiple of the mask unit size {n_lat_mu}."
|
281 |
+
)
|
282 |
+
if self.n_lons % n_lon_mu != 0:
|
283 |
+
raise ValueError(
|
284 |
+
f"Padded longitudes {self.n_lons} are not an integer multiple of the mask unit size {n_lon_mu}."
|
285 |
+
)
|
286 |
+
|
287 |
+
self.mask_shape = (self.n_lats // self.n_lat_mu, self.n_lons // self.n_lon_mu)
|
288 |
+
|
289 |
+
self.rng = np.random.default_rng(seed=seed)
|
290 |
+
|
291 |
+
def n_units_masked(self, mask_type="vals"):
|
292 |
+
if mask_type == "vals":
|
293 |
+
return int(self.mask_ratio_vals * np.prod(self.mask_shape))
|
294 |
+
elif mask_type == "tars":
|
295 |
+
return int(self.mask_ratio_tars * np.prod(self.mask_shape))
|
296 |
+
else:
|
297 |
+
raise ValueError(
|
298 |
+
f"`{mask_type}` not an allowed value for `mask_type`. Use `vals` or `tars`."
|
299 |
+
)
|
300 |
+
|
301 |
+
@staticmethod
|
302 |
+
def reconstruct(
|
303 |
+
idx_masked: torch.Tensor,
|
304 |
+
idx_unmasked: torch.Tensor,
|
305 |
+
data_masked: torch.Tensor,
|
306 |
+
data_unmasked: torch.Tensor,
|
307 |
+
) -> torch.Tensor:
|
308 |
+
"""
|
309 |
+
Reconstructs a tensor along the mask unit dimension. Non-batched version.
|
310 |
+
|
311 |
+
Args:
|
312 |
+
idx_masked: Tensor of shape `mask unit sequence`.
|
313 |
+
idx_unmasked: Tensor of shape `mask unit sequence`.
|
314 |
+
data_masked: Tensor of shape `mask unit sequence, ...`. Should have same size along mask unit sequence dimension as idx_masked. Dimensions beyond the first two, marked here as ... will typically be `local_sequence, channel` or `channel, lat, lon`. These dimensions should agree with data_unmasked.
|
315 |
+
data_unmasked: Tensor of shape `mask unit sequence, ...`. Should have same size along mask unit sequence dimension as idx_unmasked. Dimensions beyond the first two, marked here as ... will typically be `local_sequence, channel` or `channel, lat, lon`. These dimensions should agree with data_masked.
|
316 |
+
Returns:
|
317 |
+
Tensor of same shape as inputs data_masked and data_unmasked. I.e. `mask unit sequence, ...`.
|
318 |
+
"""
|
319 |
+
idx_total = torch.argsort(torch.cat([idx_masked, idx_unmasked], dim=0), dim=0)
|
320 |
+
idx_total = idx_total.reshape(
|
321 |
+
*idx_total.shape,
|
322 |
+
*[1 for _ in range(len(idx_total.shape), len(data_unmasked.shape))],
|
323 |
+
)
|
324 |
+
idx_total = idx_total.expand(*idx_total.shape[:1], *data_unmasked.shape[1:])
|
325 |
+
data = torch.cat([data_masked, data_unmasked], dim=0)
|
326 |
+
data = torch.gather(data, dim=0, index=idx_total)
|
327 |
+
return data
|
328 |
+
|
329 |
+
@staticmethod
|
330 |
+
def reconstruct_batch(
|
331 |
+
idx_masked: torch.Tensor,
|
332 |
+
idx_unmasked: torch.Tensor,
|
333 |
+
data_masked: torch.Tensor,
|
334 |
+
data_unmasked: torch.Tensor,
|
335 |
+
) -> torch.Tensor:
|
336 |
+
"""
|
337 |
+
Reconstructs a tensor along the mask unit dimension. Batched version.
|
338 |
+
|
339 |
+
Args:
|
340 |
+
idx_masked: Tensor of shape `batch, mask unit sequence`.
|
341 |
+
idx_unmasked: Tensor of shape `batch, mask unit sequence`.
|
342 |
+
data_masked: Tensor of shape `batch, mask unit sequence, ...`. Should have same size along mask unit sequence dimension as idx_masked. Dimensions beyond the first two, marked here as ... will typically be `local_sequence, channel` or `channel, lat, lon`. These dimensions should agree with data_unmasked.
|
343 |
+
data_unmasked: Tensor of shape `batch, mask unit sequence, ...`. Should have same size along mask unit sequence dimension as idx_unmasked. Dimensions beyond the first two, marked here as ... will typically be `local_sequence, channel` or `channel, lat, lon`. These dimensions should agree with data_masked.
|
344 |
+
Returns:
|
345 |
+
Tensor of same shape as inputs data_masked and data_unmasked. I.e. `batch, mask unit sequence, ...`.
|
346 |
+
"""
|
347 |
+
idx_total = torch.argsort(torch.cat([idx_masked, idx_unmasked], dim=1), dim=1)
|
348 |
+
idx_total = idx_total.reshape(
|
349 |
+
*idx_total.shape,
|
350 |
+
*[1 for _ in range(len(idx_total.shape), len(data_unmasked.shape))],
|
351 |
+
)
|
352 |
+
idx_total = idx_total.expand(*idx_total.shape[:2], *data_unmasked.shape[2:])
|
353 |
+
data = torch.cat([data_masked, data_unmasked], dim=1)
|
354 |
+
data = torch.gather(data, dim=1, index=idx_total)
|
355 |
+
return data
|
356 |
+
|
357 |
+
def __call__(self, data: Tuple[np.array]) -> Tuple[torch.Tensor]:
|
358 |
+
"""
|
359 |
+
Args:
|
360 |
+
data: Tuple of numpy tensors. These are interpreted as `(sur_static, ulv_static, sur_vals, ulv_vals, sur_tars, ulv_tars)`.
|
361 |
+
Returns:
|
362 |
+
Tuple of torch tensors. If the target is unmasked (`mask_ratio_tars` is zero), the tuple contains
|
363 |
+
`(static, indices_masked_vals, indices_unmaked_vals, vals, tars)`. When targets are masked as well, we are dealing with
|
364 |
+
`(static, indices_masked_vals, indices_unmaked_vals, vals, indices_masked_tars, indices_unmasked_tars, tars)`.
|
365 |
+
Their shapes are as follows:
|
366 |
+
static: mask unit sequence, channel, lat, lon
|
367 |
+
indices_masked_vals: mask unit sequence
|
368 |
+
indices_unmaked_vals: mask unit sequence
|
369 |
+
vals: mask unit sequence, channel, lat, lon
|
370 |
+
tars: mask unit sequence, channel, lat, lon
|
371 |
+
"""
|
372 |
+
sur_static, ulv_static, sur_vals, ulv_vals, sur_tars, ulv_tars = data
|
373 |
+
|
374 |
+
sur_vals, ulv_vals = np.squeeze(sur_vals, axis=1), np.squeeze(ulv_vals, axis=1)
|
375 |
+
sur_tars, ulv_tars = np.squeeze(sur_tars, axis=1), np.squeeze(ulv_tars, axis=1)
|
376 |
+
|
377 |
+
vals = np.concatenate(
|
378 |
+
[
|
379 |
+
sur_vals,
|
380 |
+
ulv_vals.reshape(
|
381 |
+
ulv_vals.shape[0] * ulv_vals.shape[1], *ulv_vals.shape[-2:]
|
382 |
+
),
|
383 |
+
],
|
384 |
+
axis=0,
|
385 |
+
)
|
386 |
+
tars = np.concatenate(
|
387 |
+
[
|
388 |
+
sur_tars,
|
389 |
+
ulv_tars.reshape(
|
390 |
+
ulv_tars.shape[0] * ulv_tars.shape[1], *ulv_tars.shape[-2:]
|
391 |
+
),
|
392 |
+
],
|
393 |
+
axis=0,
|
394 |
+
)
|
395 |
+
|
396 |
+
padding = ((0, 0), *self.padding)
|
397 |
+
static = np.pad(sur_static, padding)
|
398 |
+
vals = np.pad(vals, padding)
|
399 |
+
tars = np.pad(tars, padding)
|
400 |
+
|
401 |
+
static = static.reshape(
|
402 |
+
static.shape[0],
|
403 |
+
static.shape[-2] // self.n_lat_mu,
|
404 |
+
self.n_lat_mu,
|
405 |
+
static.shape[-1] // self.n_lon_mu,
|
406 |
+
self.n_lon_mu,
|
407 |
+
).transpose(1, 3, 0, 2, 4)
|
408 |
+
vals = vals.reshape(
|
409 |
+
vals.shape[0],
|
410 |
+
vals.shape[-2] // self.n_lat_mu,
|
411 |
+
self.n_lat_mu,
|
412 |
+
vals.shape[-1] // self.n_lon_mu,
|
413 |
+
self.n_lon_mu,
|
414 |
+
).transpose(1, 3, 0, 2, 4)
|
415 |
+
tars = tars.reshape(
|
416 |
+
tars.shape[0],
|
417 |
+
tars.shape[-2] // self.n_lat_mu,
|
418 |
+
self.n_lat_mu,
|
419 |
+
tars.shape[-1] // self.n_lon_mu,
|
420 |
+
self.n_lon_mu,
|
421 |
+
).transpose(1, 3, 0, 2, 4)
|
422 |
+
|
423 |
+
maskable_indices = np.arange(np.prod(self.mask_shape))
|
424 |
+
maskable_indices = self.rng.permutation(maskable_indices)
|
425 |
+
indices_masked_vals = maskable_indices[: self.n_units_masked()]
|
426 |
+
indices_unmasked_vals = maskable_indices[self.n_units_masked() :]
|
427 |
+
|
428 |
+
vals = vals.reshape(-1, *vals.shape[2:])[indices_unmasked_vals, :, :, :]
|
429 |
+
|
430 |
+
if self.mask_ratio_tars > 0.0:
|
431 |
+
maskable_indices = np.arange(np.prod(self.mask_shape))
|
432 |
+
maskable_indices = self.rng.permutation(maskable_indices)
|
433 |
+
indices_masked_tars = maskable_indices[: self.n_units_masked("tars")]
|
434 |
+
indices_unmasked_tars = maskable_indices[self.n_units_masked("tars") :]
|
435 |
+
|
436 |
+
tars = tars.reshape(-1, *tars.shape[2:])[indices_unmasked_tars, :, :, :]
|
437 |
+
|
438 |
+
return_value = (
|
439 |
+
torch.from_numpy(static).flatten(0, 1),
|
440 |
+
torch.from_numpy(indices_masked_vals),
|
441 |
+
torch.from_numpy(indices_unmasked_vals),
|
442 |
+
torch.from_numpy(vals),
|
443 |
+
torch.from_numpy(indices_masked_tars),
|
444 |
+
torch.from_numpy(indices_unmasked_tars),
|
445 |
+
torch.from_numpy(tars),
|
446 |
+
)
|
447 |
+
return return_value
|
448 |
+
else:
|
449 |
+
return_value = (
|
450 |
+
torch.from_numpy(static).flatten(0, 1),
|
451 |
+
torch.from_numpy(indices_masked_vals),
|
452 |
+
torch.from_numpy(indices_unmasked_vals),
|
453 |
+
torch.from_numpy(vals),
|
454 |
+
torch.from_numpy(tars).flatten(0, 1),
|
455 |
+
)
|
456 |
+
return return_value
|
surya/models/__init__.py
ADDED
File without changes
|
surya/models/embedding.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Perceiver code is based on Aurora: https://github.com/microsoft/aurora/blob/main/aurora/model/perceiver.py
|
3 |
+
|
4 |
+
Some conventions for notation:
|
5 |
+
B - Batch
|
6 |
+
T - Time
|
7 |
+
H - Height (pixel space)
|
8 |
+
W - Width (pixel space)
|
9 |
+
HT - Height (token space)
|
10 |
+
WT - Width (token space)
|
11 |
+
ST - Sequence (token space)
|
12 |
+
C - Input channels
|
13 |
+
D - Model (embedding) dimension
|
14 |
+
"""
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
from einops import rearrange
|
20 |
+
from timm.models.layers import trunc_normal_
|
21 |
+
|
22 |
+
|
23 |
+
class PatchEmbed3D(nn.Module):
|
24 |
+
"""Timeseries Image to Patch Embedding"""
|
25 |
+
|
26 |
+
def __init__(
|
27 |
+
self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, time_dim=2
|
28 |
+
):
|
29 |
+
super().__init__()
|
30 |
+
self.img_size = img_size
|
31 |
+
self.patch_size = patch_size
|
32 |
+
self.embed_dim = embed_dim
|
33 |
+
self.time_dim = time_dim
|
34 |
+
|
35 |
+
self.proj = nn.Conv2d(
|
36 |
+
in_chans * time_dim,
|
37 |
+
embed_dim,
|
38 |
+
kernel_size=(patch_size, patch_size),
|
39 |
+
stride=(patch_size, patch_size),
|
40 |
+
)
|
41 |
+
|
42 |
+
def forward(self, x):
|
43 |
+
"""
|
44 |
+
Args:
|
45 |
+
x: Tensor of shape (B, C, T, H, W)
|
46 |
+
Returns:
|
47 |
+
Tensor of shape (B, ST, D)
|
48 |
+
"""
|
49 |
+
B, C, T, H, W = x.shape
|
50 |
+
x = self.proj(x.flatten(1, 2)) # (B, C, T, H, W) -> (B, D, HT, WT)
|
51 |
+
x = rearrange(x, "B D HT WT -> B (HT WT) D") # (B, N, D)
|
52 |
+
return x
|
53 |
+
|
54 |
+
|
55 |
+
class LinearEmbedding(nn.Module):
|
56 |
+
def __init__(
|
57 |
+
self,
|
58 |
+
img_size=224,
|
59 |
+
patch_size=16,
|
60 |
+
in_chans=3,
|
61 |
+
time_dim=2,
|
62 |
+
embed_dim=768,
|
63 |
+
drop_rate=0.0,
|
64 |
+
):
|
65 |
+
super().__init__()
|
66 |
+
|
67 |
+
self.num_patches = (img_size // patch_size) ** 2
|
68 |
+
|
69 |
+
self.patch_embed = PatchEmbed3D(
|
70 |
+
img_size=img_size,
|
71 |
+
patch_size=patch_size,
|
72 |
+
in_chans=in_chans,
|
73 |
+
embed_dim=embed_dim,
|
74 |
+
time_dim=time_dim,
|
75 |
+
)
|
76 |
+
|
77 |
+
self._generate_position_encoding(img_size, patch_size, embed_dim)
|
78 |
+
|
79 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
80 |
+
|
81 |
+
def _generate_position_encoding(self, img_size, patch_size, embed_dim):
|
82 |
+
"""
|
83 |
+
Generates a positional encoding signal for the model. The generated
|
84 |
+
positional encoding signal is stored as a buffer (`self.fourier_signal`).
|
85 |
+
|
86 |
+
Args:
|
87 |
+
img_size (int): The size of the input image.
|
88 |
+
patch_size (int): The size of each patch in the image.
|
89 |
+
embed_dim (int): The embedding dimension of the model.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
None.
|
93 |
+
"""
|
94 |
+
# Generate signal of shape (C, H, W)
|
95 |
+
x = torch.linspace(0.0, 1.0, img_size // patch_size)
|
96 |
+
y = torch.linspace(0.0, 1.0, img_size // patch_size)
|
97 |
+
x, y = torch.meshgrid(x, y, indexing="xy")
|
98 |
+
fourier_signal = []
|
99 |
+
|
100 |
+
frequencies = torch.linspace(1, (img_size // patch_size) / 2.0, embed_dim // 4)
|
101 |
+
|
102 |
+
for f in frequencies:
|
103 |
+
fourier_signal.extend(
|
104 |
+
[
|
105 |
+
torch.cos(2.0 * torch.pi * f * x),
|
106 |
+
torch.sin(2.0 * torch.pi * f * x),
|
107 |
+
torch.cos(2.0 * torch.pi * f * y),
|
108 |
+
torch.sin(2.0 * torch.pi * f * y),
|
109 |
+
]
|
110 |
+
)
|
111 |
+
fourier_signal = torch.stack(fourier_signal, dim=2)
|
112 |
+
fourier_signal = rearrange(fourier_signal, "h w c -> 1 (h w) c")
|
113 |
+
self.register_buffer("pos_embed", fourier_signal)
|
114 |
+
|
115 |
+
def forward(self, x, dt):
|
116 |
+
"""
|
117 |
+
Args:
|
118 |
+
x: Tensor of shape (B, C, T, H, W).
|
119 |
+
dt: Tensor of shape (B, T). However it is not used.
|
120 |
+
Returns:
|
121 |
+
Tensor of shape (B, ST, D)
|
122 |
+
"""
|
123 |
+
x = self.patch_embed(x)
|
124 |
+
x = x + self.pos_embed
|
125 |
+
x = self.pos_drop(x)
|
126 |
+
|
127 |
+
return x
|
128 |
+
|
129 |
+
|
130 |
+
class LinearDecoder(nn.Module):
|
131 |
+
def __init__(
|
132 |
+
self,
|
133 |
+
patch_size: int,
|
134 |
+
out_chans: int,
|
135 |
+
embed_dim: int,
|
136 |
+
):
|
137 |
+
"""
|
138 |
+
Args:
|
139 |
+
patch_size: patch size
|
140 |
+
in_chans: number of iput channels
|
141 |
+
embed_dim: embedding dimension
|
142 |
+
"""
|
143 |
+
super().__init__()
|
144 |
+
|
145 |
+
self.unembed = nn.Sequential(
|
146 |
+
nn.Conv2d(
|
147 |
+
in_channels=embed_dim,
|
148 |
+
out_channels=(patch_size**2) * out_chans,
|
149 |
+
kernel_size=1,
|
150 |
+
),
|
151 |
+
nn.PixelShuffle(patch_size),
|
152 |
+
)
|
153 |
+
|
154 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
155 |
+
"""
|
156 |
+
Args:
|
157 |
+
x: Tensor of shape (B, L, D). For ensembles, we have implicitly B = (B E).
|
158 |
+
Returns:
|
159 |
+
Tensor of shape (B C H W).
|
160 |
+
Here
|
161 |
+
- C equals num_queries
|
162 |
+
- H == W == sqrt(L) x patch_size
|
163 |
+
"""
|
164 |
+
# Reshape the tokens to 2d token space: (B, C, H_token, W_token)
|
165 |
+
_, L, _ = x.shape
|
166 |
+
H_token = W_token = int(L**0.5)
|
167 |
+
x = rearrange(x, "B (H W) D -> B D H W", H=H_token, W=W_token)
|
168 |
+
|
169 |
+
# Unembed the tokens. Convolution + pixel shuffle.
|
170 |
+
x = self.unembed(x)
|
171 |
+
|
172 |
+
return x
|
173 |
+
|
174 |
+
|
175 |
+
class MLP(nn.Module):
|
176 |
+
"""A simple one-hidden-layer MLP."""
|
177 |
+
|
178 |
+
def __init__(self, dim: int, hidden_features: int, dropout: float = 0.0) -> None:
|
179 |
+
"""Initialise.
|
180 |
+
|
181 |
+
Args:
|
182 |
+
dim (int): Input dimensionality.
|
183 |
+
hidden_features (int): Width of the hidden layer.
|
184 |
+
dropout (float, optional): Drop-out rate. Defaults to no drop-out.
|
185 |
+
"""
|
186 |
+
super().__init__()
|
187 |
+
self.net = nn.Sequential(
|
188 |
+
nn.Linear(dim, hidden_features),
|
189 |
+
nn.GELU(),
|
190 |
+
nn.Linear(hidden_features, dim),
|
191 |
+
nn.Dropout(dropout),
|
192 |
+
)
|
193 |
+
|
194 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
195 |
+
"""Run the MLP."""
|
196 |
+
return self.net(x)
|
197 |
+
|
198 |
+
|
199 |
+
class PerceiverAttention(nn.Module):
|
200 |
+
"""Cross attention module from the Perceiver architecture."""
|
201 |
+
|
202 |
+
def __init__(
|
203 |
+
self,
|
204 |
+
latent_dim: int,
|
205 |
+
context_dim: int,
|
206 |
+
head_dim: int = 64,
|
207 |
+
num_heads: int = 8,
|
208 |
+
) -> None:
|
209 |
+
"""Initialise.
|
210 |
+
|
211 |
+
Args:
|
212 |
+
latent_dim (int): Dimensionality of the latent features given as input.
|
213 |
+
context_dim (int): Dimensionality of the context features also given as input.
|
214 |
+
head_dim (int): Attention head dimensionality.
|
215 |
+
num_heads (int): Number of heads.
|
216 |
+
"""
|
217 |
+
super().__init__()
|
218 |
+
self.num_heads = num_heads
|
219 |
+
self.head_dim = head_dim
|
220 |
+
self.inner_dim = head_dim * num_heads
|
221 |
+
|
222 |
+
self.to_q = nn.Linear(latent_dim, self.inner_dim, bias=False)
|
223 |
+
self.to_kv = nn.Linear(context_dim, self.inner_dim * 2, bias=False)
|
224 |
+
self.to_out = nn.Linear(self.inner_dim, latent_dim, bias=False)
|
225 |
+
|
226 |
+
def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
227 |
+
"""Run the cross-attention module.
|
228 |
+
|
229 |
+
Args:
|
230 |
+
latents (:class:`torch.Tensor`): Latent features of shape `(B, L1, Latent_D)`
|
231 |
+
where typically `L1 < L2` and `Latent_D <= Context_D`. `Latent_D` is equal to
|
232 |
+
`self.latent_dim`.
|
233 |
+
x (:class:`torch.Tensor`): Context features of shape `(B, L2, Context_D)`.
|
234 |
+
|
235 |
+
Returns:
|
236 |
+
:class:`torch.Tensor`: Latent values of shape `(B, L1, Latent_D)`.
|
237 |
+
"""
|
238 |
+
h = self.num_heads
|
239 |
+
|
240 |
+
q = self.to_q(latents) # (B, L1, D2) to (B, L1, D)
|
241 |
+
k, v = self.to_kv(x).chunk(2, dim=-1) # (B, L2, D1) to twice (B, L2, D)
|
242 |
+
q, k, v = map(lambda t: rearrange(t, "b l (h d) -> b h l d", h=h), (q, k, v))
|
243 |
+
|
244 |
+
out = F.scaled_dot_product_attention(q, k, v)
|
245 |
+
out = rearrange(out, "B H L1 D -> B L1 (H D)") # (B, L1, D)
|
246 |
+
return self.to_out(out) # (B, L1, Latent_D)
|
247 |
+
|
248 |
+
|
249 |
+
class PerceiverResampler(nn.Module):
|
250 |
+
"""Perceiver Resampler module from the Flamingo paper."""
|
251 |
+
|
252 |
+
def __init__(
|
253 |
+
self,
|
254 |
+
latent_dim: int,
|
255 |
+
context_dim: int,
|
256 |
+
depth: int = 1,
|
257 |
+
head_dim: int = 64,
|
258 |
+
num_heads: int = 16,
|
259 |
+
mlp_ratio: float = 4.0,
|
260 |
+
drop: float = 0.0,
|
261 |
+
residual_latent: bool = True,
|
262 |
+
ln_eps: float = 1e-5,
|
263 |
+
) -> None:
|
264 |
+
"""Initialise.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
latent_dim (int): Dimensionality of the latent features given as input.
|
268 |
+
context_dim (int): Dimensionality of the context features also given as input.
|
269 |
+
depth (int, optional): Number of attention layers.
|
270 |
+
head_dim (int, optional): Attention head dimensionality. Defaults to `64`.
|
271 |
+
num_heads (int, optional): Number of heads. Defaults to `16`
|
272 |
+
mlp_ratio (float, optional): Rimensionality of the hidden layer divided by that of the
|
273 |
+
input for all MLPs. Defaults to `4.0`.
|
274 |
+
drop (float, optional): Drop-out rate. Defaults to no drop-out.
|
275 |
+
residual_latent (bool, optional): Use residual attention w.r.t. the latent features.
|
276 |
+
Defaults to `True`.
|
277 |
+
ln_eps (float, optional): Epsilon in the layer normalisation layers. Defaults to
|
278 |
+
`1e-5`.
|
279 |
+
"""
|
280 |
+
super().__init__()
|
281 |
+
|
282 |
+
self.residual_latent = residual_latent
|
283 |
+
self.layers = nn.ModuleList([])
|
284 |
+
mlp_hidden_dim = int(latent_dim * mlp_ratio)
|
285 |
+
for _ in range(depth):
|
286 |
+
self.layers.append(
|
287 |
+
nn.ModuleList(
|
288 |
+
[
|
289 |
+
PerceiverAttention(
|
290 |
+
latent_dim=latent_dim,
|
291 |
+
context_dim=context_dim,
|
292 |
+
head_dim=head_dim,
|
293 |
+
num_heads=num_heads,
|
294 |
+
),
|
295 |
+
MLP(
|
296 |
+
dim=latent_dim, hidden_features=mlp_hidden_dim, dropout=drop
|
297 |
+
),
|
298 |
+
nn.LayerNorm(latent_dim, eps=ln_eps),
|
299 |
+
nn.LayerNorm(latent_dim, eps=ln_eps),
|
300 |
+
]
|
301 |
+
)
|
302 |
+
)
|
303 |
+
|
304 |
+
def forward(self, latents: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
305 |
+
"""Run the module.
|
306 |
+
|
307 |
+
Args:
|
308 |
+
latents (:class:`torch.Tensor`): Latent features of shape `(B, L1, D1)`.
|
309 |
+
x (:class:`torch.Tensor`): Context features of shape `(B, L2, D1)`.
|
310 |
+
|
311 |
+
Returns:
|
312 |
+
torch.Tensor: Latent features of shape `(B, L1, D1)`.
|
313 |
+
"""
|
314 |
+
for attn, ff, ln1, ln2 in self.layers:
|
315 |
+
# We use post-res-norm like in Swin v2 and most Transformer architectures these days.
|
316 |
+
# This empirically works better than the pre-norm used in the original Perceiver.
|
317 |
+
attn_out = ln1(attn(latents, x))
|
318 |
+
# HuggingFace suggests using non-residual attention in Perceiver might work better when
|
319 |
+
# the semantics of the query and the output are different:
|
320 |
+
#
|
321 |
+
# https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/perceiver/modeling_perceiver.py#L398
|
322 |
+
#
|
323 |
+
latents = attn_out + latents if self.residual_latent else attn_out
|
324 |
+
latents = ln2(ff(latents)) + latents
|
325 |
+
return latents
|
326 |
+
|
327 |
+
|
328 |
+
class PerceiverChannelEmbedding(nn.Module):
|
329 |
+
def __init__(
|
330 |
+
self,
|
331 |
+
in_chans: int,
|
332 |
+
img_size: int,
|
333 |
+
patch_size: int,
|
334 |
+
time_dim: int,
|
335 |
+
num_queries: int,
|
336 |
+
embed_dim: int,
|
337 |
+
drop_rate: float,
|
338 |
+
):
|
339 |
+
super().__init__()
|
340 |
+
|
341 |
+
if embed_dim % 2 != 0:
|
342 |
+
raise ValueError(
|
343 |
+
f"Temporal embeddings require `embed_dim` to be even. Currently we have {embed_dim}."
|
344 |
+
)
|
345 |
+
|
346 |
+
self.num_patches = (img_size // patch_size) ** 2
|
347 |
+
self.num_queries = num_queries
|
348 |
+
self.embed_dim = embed_dim
|
349 |
+
|
350 |
+
self.proj = nn.Conv2d(
|
351 |
+
in_channels=in_chans * time_dim,
|
352 |
+
out_channels=in_chans * embed_dim,
|
353 |
+
kernel_size=patch_size,
|
354 |
+
stride=patch_size,
|
355 |
+
groups=in_chans,
|
356 |
+
)
|
357 |
+
|
358 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, self.num_patches))
|
359 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
360 |
+
|
361 |
+
self.latent_queries = nn.Parameter(torch.zeros(1, num_queries, embed_dim))
|
362 |
+
trunc_normal_(self.latent_queries, std=0.02)
|
363 |
+
|
364 |
+
self.perceiver = PerceiverResampler(
|
365 |
+
latent_dim=embed_dim,
|
366 |
+
context_dim=embed_dim,
|
367 |
+
depth=1,
|
368 |
+
head_dim=embed_dim // 16,
|
369 |
+
num_heads=16,
|
370 |
+
mlp_ratio=4.0,
|
371 |
+
drop=0.0,
|
372 |
+
residual_latent=False,
|
373 |
+
ln_eps=1e-5,
|
374 |
+
)
|
375 |
+
|
376 |
+
self.latent_aggregation = nn.Linear(num_queries * embed_dim, embed_dim)
|
377 |
+
|
378 |
+
self.pos_drop = nn.Dropout(p=drop_rate)
|
379 |
+
|
380 |
+
def forward(self, x, dt):
|
381 |
+
"""
|
382 |
+
Args:
|
383 |
+
x: Tensor of shape (B, C, T, H, W)
|
384 |
+
dt: Tensor of shape (B, T) identifying time deltas.
|
385 |
+
Returns:
|
386 |
+
Tensor of shape (B, ST, D)
|
387 |
+
"""
|
388 |
+
B, C, T, H, W = x.shape
|
389 |
+
x = rearrange(x, "B C T H W -> B (C T) H W")
|
390 |
+
x = self.proj(x) # B (C T) H W -> B (C D) HT WT
|
391 |
+
x = x.flatten(2, 3) # B (C D) ST
|
392 |
+
ST = x.shape[2]
|
393 |
+
assert ST == self.num_patches
|
394 |
+
x = rearrange(x, "B (C D) ST -> (B C) D ST", B=B, ST=ST, C=C, D=self.embed_dim)
|
395 |
+
x = x + self.pos_embed
|
396 |
+
x = rearrange(x, "(B C) D ST -> (B ST) C D", B=B, ST=ST, C=C, D=self.embed_dim)
|
397 |
+
|
398 |
+
# ((B ST) NQ D), ((B ST) C D) -> ((B ST) NQ D)
|
399 |
+
x = self.perceiver(self.latent_queries.expand(B * ST, -1, -1), x)
|
400 |
+
x = rearrange(
|
401 |
+
x,
|
402 |
+
"(B ST) NQ D -> B ST (NQ D)",
|
403 |
+
B=B,
|
404 |
+
ST=self.num_patches,
|
405 |
+
NQ=self.num_queries,
|
406 |
+
D=self.embed_dim,
|
407 |
+
)
|
408 |
+
x = self.latent_aggregation(x) # B ST (NQ D) -> B ST D'
|
409 |
+
|
410 |
+
assert x.shape[1] == self.num_patches
|
411 |
+
assert x.shape[2] == self.embed_dim
|
412 |
+
|
413 |
+
x = self.pos_drop(x)
|
414 |
+
|
415 |
+
return x
|
416 |
+
|
417 |
+
|
418 |
+
class PerceiverDecoder(nn.Module):
|
419 |
+
def __init__(
|
420 |
+
self,
|
421 |
+
embed_dim: int,
|
422 |
+
patch_size: int,
|
423 |
+
out_chans: int,
|
424 |
+
):
|
425 |
+
"""
|
426 |
+
Args:
|
427 |
+
embed_dim: embedding dimension
|
428 |
+
patch_size: patch size
|
429 |
+
out_chans: number of output channels. This determines the number of latent queries.
|
430 |
+
drop_rate: dropout rate
|
431 |
+
"""
|
432 |
+
super().__init__()
|
433 |
+
|
434 |
+
self.embed_dim = embed_dim
|
435 |
+
self.patch_size = patch_size
|
436 |
+
self.out_chans = out_chans
|
437 |
+
|
438 |
+
self.latent_queries = nn.Parameter(torch.zeros(1, out_chans, embed_dim))
|
439 |
+
trunc_normal_(self.latent_queries, std=0.02)
|
440 |
+
|
441 |
+
self.perceiver = PerceiverResampler(
|
442 |
+
latent_dim=embed_dim,
|
443 |
+
context_dim=embed_dim,
|
444 |
+
depth=1,
|
445 |
+
head_dim=embed_dim // 16,
|
446 |
+
num_heads=16,
|
447 |
+
mlp_ratio=4.0,
|
448 |
+
drop=0.0,
|
449 |
+
residual_latent=False,
|
450 |
+
ln_eps=1e-5,
|
451 |
+
)
|
452 |
+
self.proj = nn.Conv2d(
|
453 |
+
in_channels=out_chans * embed_dim,
|
454 |
+
out_channels=out_chans * patch_size**2,
|
455 |
+
kernel_size=1,
|
456 |
+
padding=0,
|
457 |
+
groups=out_chans,
|
458 |
+
)
|
459 |
+
self.pixel_shuffle = nn.PixelShuffle(patch_size)
|
460 |
+
|
461 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
462 |
+
"""
|
463 |
+
Args:
|
464 |
+
x: Tensor of shape (B, L, D) For ensembles, we have implicitly B = (B E).
|
465 |
+
Returns:
|
466 |
+
Tensor of shape (B C H W).
|
467 |
+
Here
|
468 |
+
- C equals out_chans
|
469 |
+
- H == W == sqrt(L) x patch_size
|
470 |
+
"""
|
471 |
+
B, L, D = x.shape
|
472 |
+
H_token = W_token = int(L**0.5)
|
473 |
+
|
474 |
+
x = rearrange(x, "B L D -> (B L) 1 D")
|
475 |
+
# (B L) 1 D -> (B L) C D
|
476 |
+
x = self.perceiver(self.latent_queries.expand(B * L, -1, -1), x)
|
477 |
+
x = rearrange(x, "(B H W) C D -> B (C D) H W", H=H_token, W=W_token)
|
478 |
+
# B (C D) H_token W_token -> B (C patch_size patch_size) H_token W_token
|
479 |
+
x = self.proj(x)
|
480 |
+
# B (C patch_size patch_size) H_token W_token -> B C H W
|
481 |
+
x = self.pixel_shuffle(x)
|
482 |
+
|
483 |
+
return x
|
surya/models/flow.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
class HelioFlowModel(nn.Module):
|
7 |
+
def __init__(self, img_size=(4096, 4096), use_latitude_in_learned_flow=False):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
self.use_latitude_in_learned_flow = use_latitude_in_learned_flow
|
11 |
+
|
12 |
+
u = torch.linspace(-1, 1, img_size[0])
|
13 |
+
v = torch.linspace(-1, 1, img_size[1])
|
14 |
+
u, v = torch.meshgrid(u, v, indexing="xy")
|
15 |
+
self.register_buffer(
|
16 |
+
"grid", torch.stack((u, v), dim=2).view(1, *img_size, 2)
|
17 |
+
) # B, H, W, 2
|
18 |
+
|
19 |
+
# Higher modes can be used for explicit feature engineering for flow features.
|
20 |
+
if self.use_latitude_in_learned_flow:
|
21 |
+
higher_modes = [u, v, torch.ones_like(u)]
|
22 |
+
else:
|
23 |
+
higher_modes = [
|
24 |
+
u,
|
25 |
+
v,
|
26 |
+
]
|
27 |
+
self.register_buffer(
|
28 |
+
"higher_modes", torch.stack(higher_modes, dim=2).view(1, *img_size, -1)
|
29 |
+
)
|
30 |
+
|
31 |
+
self.flow_generator = nn.Sequential(
|
32 |
+
nn.Linear(self.higher_modes.shape[3], 128),
|
33 |
+
nn.GELU(),
|
34 |
+
nn.Linear(128, 2),
|
35 |
+
)
|
36 |
+
|
37 |
+
def forward(self, batch):
|
38 |
+
"""
|
39 |
+
Args:
|
40 |
+
batch: Dictionary containing keys `ts` and
|
41 |
+
`forecast_latitude` (optionally).
|
42 |
+
ts (torch.Tensor): B, C, T, H, W
|
43 |
+
forecast_latitude (torch.Tensor): B, L
|
44 |
+
B - Batch size, C - Channels, T - Input times, H - Image height,
|
45 |
+
W - Image width, L - Lead time.
|
46 |
+
"""
|
47 |
+
|
48 |
+
x = batch["ts"]
|
49 |
+
B, C, T, H, W = x.shape
|
50 |
+
if T == 1:
|
51 |
+
x = x[:, :, -1, :, :]
|
52 |
+
else:
|
53 |
+
# Taking the average of the last two time stamps
|
54 |
+
x = (x[:, :, -1, :, :] + x[:, :, -2, :, :]) / 2
|
55 |
+
|
56 |
+
# Flow fields have the shape B, H_out, W_out, 2
|
57 |
+
if self.use_latitude_in_learned_flow:
|
58 |
+
broadcast_lat = batch["forecast_latitude"] / 7
|
59 |
+
broadcast_lat = torch.concatenate(
|
60 |
+
[
|
61 |
+
torch.ones_like(broadcast_lat),
|
62 |
+
torch.ones_like(broadcast_lat),
|
63 |
+
broadcast_lat,
|
64 |
+
],
|
65 |
+
1,
|
66 |
+
)[:, None, None, :]
|
67 |
+
higher_modes = self.higher_modes * broadcast_lat
|
68 |
+
flow_field = self.grid + self.flow_generator(higher_modes)
|
69 |
+
else:
|
70 |
+
flow_field = self.grid + self.flow_generator(self.higher_modes)
|
71 |
+
flow_field = flow_field.expand(B, H, W, 2)
|
72 |
+
|
73 |
+
y_hat = F.grid_sample(
|
74 |
+
x,
|
75 |
+
flow_field,
|
76 |
+
mode="bilinear",
|
77 |
+
padding_mode="border", # Possible values: zeros, border, or reflection.
|
78 |
+
align_corners=False,
|
79 |
+
)
|
80 |
+
|
81 |
+
return y_hat
|
surya/models/helio_spectformer.py
ADDED
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from einops import rearrange
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from .spectformer import SpectFormer, BlockSpectralGating, BlockAttention
|
8 |
+
from .embedding import (
|
9 |
+
LinearEmbedding,
|
10 |
+
PatchEmbed3D,
|
11 |
+
PerceiverChannelEmbedding,
|
12 |
+
LinearDecoder,
|
13 |
+
PerceiverDecoder,
|
14 |
+
)
|
15 |
+
from .flow import HelioFlowModel
|
16 |
+
|
17 |
+
|
18 |
+
class HelioSpectFormer(nn.Module):
|
19 |
+
"""
|
20 |
+
A note on the ensemble capability:
|
21 |
+
Ensembles of size E are generated by setting `ensemble=E`. In this case, the forward
|
22 |
+
pass generates ensemble members after tokenization by increasing the batch dimension
|
23 |
+
B to B x E. Noise is injected in the `self.backbone` Specformer blocks. After the
|
24 |
+
backbone, ensemble members ride along implicitly in the batch dimension. (This is
|
25 |
+
mainly through the `self.unembed` pass.) An explicit ensemble dimension is only
|
26 |
+
generated at the end.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def __init__(
|
30 |
+
self,
|
31 |
+
img_size: int,
|
32 |
+
patch_size: int,
|
33 |
+
in_chans: int,
|
34 |
+
embed_dim: int,
|
35 |
+
time_embedding: dict,
|
36 |
+
depth: int,
|
37 |
+
n_spectral_blocks: int,
|
38 |
+
num_heads: int,
|
39 |
+
mlp_ratio: float,
|
40 |
+
drop_rate: float,
|
41 |
+
window_size: int,
|
42 |
+
dp_rank: int,
|
43 |
+
learned_flow: bool = False,
|
44 |
+
use_latitude_in_learned_flow: bool = False,
|
45 |
+
init_weights: bool = False,
|
46 |
+
checkpoint_layers: list[int] | None = None,
|
47 |
+
rpe: bool = False,
|
48 |
+
ensemble: int | None = None,
|
49 |
+
finetune: bool = True,
|
50 |
+
nglo: int = 0,
|
51 |
+
dtype: torch.dtype | None = None,
|
52 |
+
) -> None:
|
53 |
+
"""
|
54 |
+
Args:
|
55 |
+
img_size: input image size
|
56 |
+
patch_size: patch size
|
57 |
+
in_chans: number of iput channels
|
58 |
+
embed_dim: embeddin dimension
|
59 |
+
time_embedding: dictionary to configure temporal embedding:
|
60 |
+
`type` (str, required): indicates embedding type. `linear`, `perceiver`.
|
61 |
+
`time_dim` (int): indicates length of time dimension. required for linear embedding.
|
62 |
+
`n_queries` (int): indicates number of perceiver queries. required for perceiver.
|
63 |
+
depth: number of transformer blocks
|
64 |
+
n_spectral_blocks: number of spectral gating blocks
|
65 |
+
num_heads: Number of transformer heads
|
66 |
+
mlp_ratio: MLP ratio for transformer blocks
|
67 |
+
drop_rate: dropout rate
|
68 |
+
window_size: window size for long/short attention
|
69 |
+
dp_rank: dp rank for long/short attention
|
70 |
+
learned_flow: if true, combine learned flow model with spectformer
|
71 |
+
use_latitude_in_learned_flow: use latitudes in learned flow
|
72 |
+
init_weights: use optimized weight initialization
|
73 |
+
checkpoint_layers: indicate which layers to use for checkpointing
|
74 |
+
rpe: Use relative position encoding in Long-Short attention blocks.
|
75 |
+
ensemble: Integer indicating ensemble size or None for deterministic model.
|
76 |
+
finetune: Indicates whether to train from scrach or fine-tune the model. If set to `True`, the final output layers are removed.
|
77 |
+
nglo: Number of (additional) global tokens.
|
78 |
+
dtype: A torch data type. Not used and added only for compatibility with the remainder of the codebase.
|
79 |
+
"""
|
80 |
+
super().__init__()
|
81 |
+
|
82 |
+
self.learned_flow = learned_flow
|
83 |
+
self.patch_size = patch_size
|
84 |
+
self.embed_dim = embed_dim
|
85 |
+
self.in_chans = in_chans
|
86 |
+
self.time_embedding = time_embedding
|
87 |
+
self.ensemble = ensemble
|
88 |
+
self.finetune = finetune
|
89 |
+
self.nglo = nglo
|
90 |
+
|
91 |
+
if learned_flow:
|
92 |
+
self.learned_flow_model = HelioFlowModel(
|
93 |
+
img_size=(img_size, img_size),
|
94 |
+
use_latitude_in_learned_flow=use_latitude_in_learned_flow,
|
95 |
+
)
|
96 |
+
|
97 |
+
match time_embedding["type"]:
|
98 |
+
case "linear":
|
99 |
+
self.time_dim = time_embedding["time_dim"]
|
100 |
+
if learned_flow:
|
101 |
+
self.time_dim += 1
|
102 |
+
self.embedding = LinearEmbedding(
|
103 |
+
img_size, patch_size, in_chans, self.time_dim, embed_dim, drop_rate
|
104 |
+
)
|
105 |
+
|
106 |
+
if not self.finetune:
|
107 |
+
self.unembed = LinearDecoder(
|
108 |
+
patch_size=patch_size, out_chans=in_chans, embed_dim=embed_dim
|
109 |
+
)
|
110 |
+
case "perceiver":
|
111 |
+
self.embedding = PerceiverChannelEmbedding(
|
112 |
+
in_chans=in_chans,
|
113 |
+
img_size=img_size,
|
114 |
+
patch_size=patch_size,
|
115 |
+
time_dim=time_embedding["time_dim"],
|
116 |
+
num_queries=time_embedding["n_queries"],
|
117 |
+
embed_dim=embed_dim,
|
118 |
+
drop_rate=drop_rate,
|
119 |
+
)
|
120 |
+
if not self.finetune:
|
121 |
+
self.unembed = PerceiverDecoder(
|
122 |
+
embed_dim=embed_dim,
|
123 |
+
patch_size=patch_size,
|
124 |
+
out_chans=in_chans,
|
125 |
+
)
|
126 |
+
case _:
|
127 |
+
raise NotImplementedError(
|
128 |
+
f'Embedding {time_embedding["type"]} has not been implemented.'
|
129 |
+
)
|
130 |
+
|
131 |
+
if isinstance(depth, list):
|
132 |
+
raise NotImplementedError(
|
133 |
+
"Multi scale models are no longer supported. Depth should be a single integer."
|
134 |
+
)
|
135 |
+
self.backbone = SpectFormer(
|
136 |
+
grid_size=img_size // patch_size,
|
137 |
+
embed_dim=embed_dim,
|
138 |
+
depth=depth,
|
139 |
+
n_spectral_blocks=n_spectral_blocks,
|
140 |
+
num_heads=num_heads,
|
141 |
+
mlp_ratio=mlp_ratio,
|
142 |
+
drop_rate=drop_rate,
|
143 |
+
window_size=window_size,
|
144 |
+
dp_rank=dp_rank,
|
145 |
+
checkpoint_layers=checkpoint_layers,
|
146 |
+
rpe=rpe,
|
147 |
+
ensemble=ensemble,
|
148 |
+
nglo=nglo,
|
149 |
+
)
|
150 |
+
|
151 |
+
if init_weights:
|
152 |
+
self.apply(self._init_weights)
|
153 |
+
|
154 |
+
# @staticmethod
|
155 |
+
# def _checkpoint_wrapper(
|
156 |
+
# model: nn.Module, data: tuple[Tensor, Tensor | None]
|
157 |
+
# ) -> Tensor:
|
158 |
+
# return checkpoint(model, data, use_reentrant=False)
|
159 |
+
|
160 |
+
def _init_weights(self, module):
|
161 |
+
|
162 |
+
if self.time_embedding["type"] == "linear":
|
163 |
+
# sampling_step * embed_dim = patch_size**2 * in_chans * time_dim
|
164 |
+
sampling_step = int(
|
165 |
+
np.sqrt(
|
166 |
+
(self.patch_size**2 * self.in_chans * self.time_dim)
|
167 |
+
/ self.embed_dim
|
168 |
+
)
|
169 |
+
)
|
170 |
+
else:
|
171 |
+
sampling_step = int(
|
172 |
+
np.sqrt((self.patch_size**2 * self.in_chans) / self.embed_dim)
|
173 |
+
)
|
174 |
+
if isinstance(module, PatchEmbed3D):
|
175 |
+
torch.nn.init.zeros_(module.proj.weight)
|
176 |
+
c_out = 0
|
177 |
+
w_pool = 1.0 / sampling_step
|
178 |
+
for k in range(self.in_chans * self.time_dim):
|
179 |
+
for i in range(0, self.patch_size, sampling_step):
|
180 |
+
for j in range(0, self.patch_size, sampling_step):
|
181 |
+
module.proj.weight.data[
|
182 |
+
c_out, k, i : i + sampling_step, j : j + sampling_step
|
183 |
+
] = w_pool
|
184 |
+
c_out += 1
|
185 |
+
if module.proj.bias is not None:
|
186 |
+
module.proj.bias.data.zero_()
|
187 |
+
if isinstance(module, BlockSpectralGating):
|
188 |
+
for m in [
|
189 |
+
module.mlp.fc1,
|
190 |
+
module.mlp.fc2,
|
191 |
+
]:
|
192 |
+
# m.weight.data.normal_(mean=0.0, std=0.01)
|
193 |
+
# torch.nn.init.eye_(m.weight)
|
194 |
+
torch.nn.init.eye_(m.weight)
|
195 |
+
if m.bias is not None:
|
196 |
+
m.bias.data.zero_()
|
197 |
+
if isinstance(module, BlockAttention):
|
198 |
+
for m in [
|
199 |
+
module.mlp.fc1,
|
200 |
+
module.mlp.fc2,
|
201 |
+
]:
|
202 |
+
# torch.nn.init.eye_(m.weight)
|
203 |
+
torch.nn.init.zeros_(m.weight)
|
204 |
+
if m.bias is not None:
|
205 |
+
m.bias.data.zero_()
|
206 |
+
for m in [
|
207 |
+
module.attn.qkv,
|
208 |
+
module.attn.proj,
|
209 |
+
module.attn.to_dynamic_projection,
|
210 |
+
]:
|
211 |
+
# m.weight.data.normal_(mean=0.0, std=0.01)
|
212 |
+
# torch.nn.init.eye_(m.weight)
|
213 |
+
torch.nn.init.zeros_(m.weight)
|
214 |
+
if m.bias is not None:
|
215 |
+
m.bias.data.zero_()
|
216 |
+
if isinstance(module, torch.nn.Sequential):
|
217 |
+
if isinstance(module[1], torch.nn.PixelShuffle):
|
218 |
+
# torch.nn.init.eye_(module[0].weight.data[:,:,0,0])
|
219 |
+
torch.nn.init.zeros_(module[0].weight)
|
220 |
+
if self.time_embedding["type"] == "linear":
|
221 |
+
c_out = 0
|
222 |
+
for k in range(1, self.in_chans + 1):
|
223 |
+
for i in range(
|
224 |
+
self.patch_size**2 // (self.patch_size * sampling_step)
|
225 |
+
):
|
226 |
+
for j in range(self.patch_size):
|
227 |
+
module[0].weight.data[
|
228 |
+
c_out : c_out + sampling_step,
|
229 |
+
j + (k * self.time_dim - 1) * self.patch_size,
|
230 |
+
] = 1.0
|
231 |
+
c_out += sampling_step
|
232 |
+
else:
|
233 |
+
c_out = 0
|
234 |
+
for k in range(2 * self.in_chans):
|
235 |
+
# l = 0
|
236 |
+
for l_feat in range(self.backbone.embed_dim):
|
237 |
+
module[0].weight.data[c_out, l_feat] = 1.0
|
238 |
+
c_out += 1
|
239 |
+
if module[0].bias is not None:
|
240 |
+
module[0].bias.data.zero_()
|
241 |
+
|
242 |
+
def forward(self, batch):
|
243 |
+
"""
|
244 |
+
Args:
|
245 |
+
batch: Dictionary containing keys `ts` and `time_delta_input`.
|
246 |
+
Their values are tensors with shapes as follows.
|
247 |
+
ts: B, C, T, H, W
|
248 |
+
time_delta_input: B, T
|
249 |
+
Returns:
|
250 |
+
Tensor fo shape (B, C, H, W) for deterministic or (B, E, C, H, W) for ensemble forecasts.
|
251 |
+
"""
|
252 |
+
x = batch["ts"]
|
253 |
+
dt = batch["time_delta_input"]
|
254 |
+
B, C, T, H, W = x.shape
|
255 |
+
|
256 |
+
if self.learned_flow:
|
257 |
+
y_hat_flow = self.learned_flow_model(batch) # B, C, H, W
|
258 |
+
if any(
|
259 |
+
[param.requires_grad for param in self.learned_flow_model.parameters()]
|
260 |
+
):
|
261 |
+
return y_hat_flow
|
262 |
+
else:
|
263 |
+
x = torch.concat((x, y_hat_flow.unsqueeze(2)), dim=2) # B, C, T+1, H, W
|
264 |
+
if self.time_embedding["type"] == "perceiver":
|
265 |
+
dt = torch.cat((dt, batch["lead_time_delta"].reshape(-1, 1)), dim=1)
|
266 |
+
|
267 |
+
# embed the data
|
268 |
+
tokens = self.embedding(x, dt)
|
269 |
+
|
270 |
+
# copy tokens in case of ensemble forecast
|
271 |
+
if self.ensemble:
|
272 |
+
# B L D -> (B E) L D == BE L D
|
273 |
+
tokens = torch.repeat_interleave(tokens, repeats=self.ensemble, dim=0)
|
274 |
+
|
275 |
+
# pass the time series through the encoder
|
276 |
+
tokens = self.backbone(tokens)
|
277 |
+
|
278 |
+
if self.finetune:
|
279 |
+
return tokens
|
280 |
+
|
281 |
+
# Unembed the tokens
|
282 |
+
# BE L D -> BE C H W
|
283 |
+
forecast_hat = self.unembed(tokens)
|
284 |
+
|
285 |
+
assert forecast_hat.shape == (
|
286 |
+
B * self.ensemble if self.ensemble else B,
|
287 |
+
C,
|
288 |
+
H,
|
289 |
+
W,
|
290 |
+
), f"forecast_hat has shape {forecast_hat.shape} yet expected {(B*self.ensemble if self.ensemble else B, C, H, W)}."
|
291 |
+
|
292 |
+
if self.learned_flow:
|
293 |
+
assert y_hat_flow.shape == (
|
294 |
+
B,
|
295 |
+
C,
|
296 |
+
H,
|
297 |
+
W,
|
298 |
+
), f"y_hat_flow has shape {y_hat_flow.shape} yet expected {(B, C, H, W)}."
|
299 |
+
if self.ensemble:
|
300 |
+
y_hat_flow = torch.repeat_interleave(
|
301 |
+
y_hat_flow, repeats=self.ensemble, dim=0
|
302 |
+
)
|
303 |
+
assert y_hat_flow.shape == forecast_hat.shape
|
304 |
+
forecast_hat = forecast_hat + y_hat_flow
|
305 |
+
|
306 |
+
assert forecast_hat.shape == (
|
307 |
+
B * self.ensemble if self.ensemble else B,
|
308 |
+
C,
|
309 |
+
H,
|
310 |
+
W,
|
311 |
+
), f"forecast_hat has shape {forecast_hat.shape} yet expected {(B*self.ensemble if self.ensemble else B, C, H, W)}."
|
312 |
+
|
313 |
+
if self.ensemble:
|
314 |
+
forecast_hat = rearrange(
|
315 |
+
forecast_hat, "(B E) C H W -> B E C H W", B=B, E=self.ensemble
|
316 |
+
)
|
317 |
+
|
318 |
+
return forecast_hat
|
surya/models/spectformer.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import logging
|
3 |
+
from itertools import chain
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.utils.checkpoint import checkpoint
|
8 |
+
|
9 |
+
from timm.models.layers import DropPath, trunc_normal_
|
10 |
+
import torch.fft
|
11 |
+
|
12 |
+
from .transformer_ls import AttentionLS
|
13 |
+
|
14 |
+
_logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
class Mlp(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
in_features,
|
21 |
+
hidden_features=None,
|
22 |
+
out_features=None,
|
23 |
+
act_layer=nn.GELU,
|
24 |
+
drop=0.0,
|
25 |
+
):
|
26 |
+
super().__init__()
|
27 |
+
out_features = out_features or in_features
|
28 |
+
hidden_features = hidden_features or in_features
|
29 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
30 |
+
self.act = act_layer()
|
31 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
32 |
+
self.drop = nn.Dropout(drop)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
x = self.fc1(x)
|
36 |
+
x = self.act(x)
|
37 |
+
x = self.drop(x)
|
38 |
+
x = self.fc2(x)
|
39 |
+
x = self.drop(x)
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
class SpectralGatingNetwork(nn.Module):
|
44 |
+
def __init__(self, dim, h=14, w=8):
|
45 |
+
super().__init__()
|
46 |
+
self.complex_weight = nn.Parameter(torch.randn(h, w, dim, 2) * 0.02)
|
47 |
+
self.w = w
|
48 |
+
self.h = h
|
49 |
+
|
50 |
+
def forward(self, x, spatial_size=None):
|
51 |
+
B, N, C = x.shape # torch.Size([1, 262144, 1024])
|
52 |
+
if spatial_size is None:
|
53 |
+
a = b = int(math.sqrt(N)) # a=b=512
|
54 |
+
else:
|
55 |
+
a, b = spatial_size
|
56 |
+
|
57 |
+
x = x.view(B, a, b, C) # torch.Size([1, 512, 512, 1024])
|
58 |
+
|
59 |
+
# FROM HERE USED TO BE AUTOCAST to float32
|
60 |
+
dtype = x.dtype
|
61 |
+
x = x.to(torch.float32)
|
62 |
+
x = torch.fft.rfft2(
|
63 |
+
x, dim=(1, 2), norm="ortho"
|
64 |
+
) # torch.Size([1, 512, 257, 1024])
|
65 |
+
weight = torch.view_as_complex(
|
66 |
+
self.complex_weight.to(torch.float32)
|
67 |
+
) # torch.Size([512, 257, 1024])
|
68 |
+
x = x * weight
|
69 |
+
x = torch.fft.irfft2(
|
70 |
+
x, s=(a, b), dim=(1, 2), norm="ortho"
|
71 |
+
) # torch.Size([1, 512, 512, 1024])
|
72 |
+
x = x.to(dtype)
|
73 |
+
|
74 |
+
x = x.reshape(B, N, C) # torch.Size([1, 262144, 1024])
|
75 |
+
# UP TO HERE USED TO BE AUTOCAST to float32
|
76 |
+
|
77 |
+
return x
|
78 |
+
|
79 |
+
|
80 |
+
class BlockSpectralGating(nn.Module):
|
81 |
+
def __init__(
|
82 |
+
self,
|
83 |
+
dim,
|
84 |
+
mlp_ratio=4.0,
|
85 |
+
drop=0.0,
|
86 |
+
drop_path=0.0,
|
87 |
+
act_layer=nn.GELU,
|
88 |
+
norm_layer=nn.LayerNorm,
|
89 |
+
h=14,
|
90 |
+
w=8,
|
91 |
+
):
|
92 |
+
super().__init__()
|
93 |
+
self.norm1 = norm_layer(dim)
|
94 |
+
self.filter = SpectralGatingNetwork(dim, h=h, w=w)
|
95 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
96 |
+
self.norm2 = norm_layer(dim)
|
97 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
98 |
+
self.mlp = Mlp(
|
99 |
+
in_features=dim,
|
100 |
+
hidden_features=mlp_hidden_dim,
|
101 |
+
act_layer=act_layer,
|
102 |
+
drop=drop,
|
103 |
+
)
|
104 |
+
|
105 |
+
def forward(self, x, *args):
|
106 |
+
x = x + self.drop_path(self.mlp(self.norm2(self.filter(self.norm1(x)))))
|
107 |
+
return x
|
108 |
+
|
109 |
+
|
110 |
+
class BlockAttention(nn.Module):
|
111 |
+
def __init__(
|
112 |
+
self,
|
113 |
+
dim,
|
114 |
+
num_heads: int = 8,
|
115 |
+
mlp_ratio=4.0,
|
116 |
+
drop=0.0,
|
117 |
+
drop_path=0.0,
|
118 |
+
w=2,
|
119 |
+
dp_rank=2,
|
120 |
+
act_layer=nn.GELU,
|
121 |
+
norm_layer=nn.LayerNorm,
|
122 |
+
rpe=False,
|
123 |
+
adaLN=False,
|
124 |
+
nglo=0,
|
125 |
+
):
|
126 |
+
"""
|
127 |
+
num_heads: Attention heads. 4 for tiny, 8 for small and 12 for base
|
128 |
+
"""
|
129 |
+
super().__init__()
|
130 |
+
self.norm1 = norm_layer(dim)
|
131 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
132 |
+
self.norm2 = norm_layer(dim)
|
133 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
134 |
+
self.mlp = Mlp(
|
135 |
+
in_features=dim,
|
136 |
+
hidden_features=mlp_hidden_dim,
|
137 |
+
act_layer=act_layer,
|
138 |
+
drop=drop,
|
139 |
+
)
|
140 |
+
self.attn = AttentionLS(
|
141 |
+
dim=dim,
|
142 |
+
num_heads=num_heads,
|
143 |
+
w=w,
|
144 |
+
dp_rank=dp_rank,
|
145 |
+
nglo=nglo,
|
146 |
+
rpe=rpe,
|
147 |
+
)
|
148 |
+
|
149 |
+
if adaLN:
|
150 |
+
self.adaLN_modulation = nn.Sequential(
|
151 |
+
nn.Linear(dim, dim, bias=True),
|
152 |
+
act_layer(),
|
153 |
+
nn.Linear(dim, 6 * dim, bias=True),
|
154 |
+
)
|
155 |
+
else:
|
156 |
+
self.adaLN_modulation = None
|
157 |
+
|
158 |
+
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
|
159 |
+
if self.adaLN_modulation is not None:
|
160 |
+
(
|
161 |
+
shift_mha,
|
162 |
+
scale_mha,
|
163 |
+
gate_mha,
|
164 |
+
shift_mlp,
|
165 |
+
scale_mlp,
|
166 |
+
gate_mlp,
|
167 |
+
) = self.adaLN_modulation(c).chunk(6, dim=2)
|
168 |
+
else:
|
169 |
+
shift_mha, scale_mha, gate_mha, shift_mlp, scale_mlp, gate_mlp = 6 * (1.0,)
|
170 |
+
|
171 |
+
x = x + gate_mha * self.drop_path(
|
172 |
+
self.attn(
|
173 |
+
self.norm1(x) * scale_mha + shift_mha,
|
174 |
+
)
|
175 |
+
)
|
176 |
+
x = x + gate_mlp * self.drop_path(
|
177 |
+
self.mlp(self.norm2(x) * scale_mlp + shift_mlp)
|
178 |
+
)
|
179 |
+
|
180 |
+
return x
|
181 |
+
|
182 |
+
|
183 |
+
class SpectFormer(nn.Module):
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
grid_size: int = 224 // 16,
|
187 |
+
embed_dim=768,
|
188 |
+
depth=12,
|
189 |
+
n_spectral_blocks=4,
|
190 |
+
num_heads: int = 8,
|
191 |
+
mlp_ratio=4.0,
|
192 |
+
uniform_drop=False,
|
193 |
+
drop_rate=0.0,
|
194 |
+
drop_path_rate=0.0,
|
195 |
+
window_size=2,
|
196 |
+
dp_rank=2,
|
197 |
+
norm_layer=nn.LayerNorm,
|
198 |
+
checkpoint_layers: list[int] | None = None,
|
199 |
+
rpe=False,
|
200 |
+
ensemble: int | None = None,
|
201 |
+
nglo: int = 0,
|
202 |
+
):
|
203 |
+
"""
|
204 |
+
Args:
|
205 |
+
img_size (int, tuple): input image size
|
206 |
+
patch_size (int, tuple): patch size
|
207 |
+
embed_dim (int): embedding dimension
|
208 |
+
depth (int): depth of transformer
|
209 |
+
n_spectral_blocks (int): number of spectral gating blocks
|
210 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
211 |
+
uniform_drop (bool): true for uniform, false for linearly increasing drop path probability.
|
212 |
+
drop_rate (float): dropout rate
|
213 |
+
drop_path_rate (float): drop path (stochastic depth) rate
|
214 |
+
window_size: window size for long/short attention
|
215 |
+
dp_rank: dp rank for long/short attention
|
216 |
+
norm_layer: (nn.Module): normalization layer for attention blocks
|
217 |
+
checkpoint_layers: indicate which layers to use for checkpointing
|
218 |
+
rpe: Use relative position encoding in Long-Short attention blocks.
|
219 |
+
ensemble: Integer indicating ensemble size or None for deterministic model.
|
220 |
+
nglo: Number of (additional) global tokens.
|
221 |
+
"""
|
222 |
+
super().__init__()
|
223 |
+
self.embed_dim = embed_dim
|
224 |
+
self.n_spectral_blocks = n_spectral_blocks
|
225 |
+
self._checkpoint_layers = checkpoint_layers or []
|
226 |
+
self.ensemble = ensemble
|
227 |
+
self.nglo = nglo
|
228 |
+
|
229 |
+
h = grid_size
|
230 |
+
w = h // 2 + 1
|
231 |
+
|
232 |
+
if uniform_drop:
|
233 |
+
_logger.info(f"Using uniform droppath with expect rate {drop_path_rate}.")
|
234 |
+
dpr = [drop_path_rate for _ in range(depth)]
|
235 |
+
else:
|
236 |
+
_logger.info(
|
237 |
+
f"Using linear droppath with expect rate {drop_path_rate * 0.5}."
|
238 |
+
)
|
239 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
|
240 |
+
|
241 |
+
self.blocks_spectral_gating = nn.ModuleList()
|
242 |
+
self.blocks_attention = nn.ModuleList()
|
243 |
+
for i in range(depth):
|
244 |
+
if i < n_spectral_blocks:
|
245 |
+
layer = BlockSpectralGating(
|
246 |
+
dim=embed_dim,
|
247 |
+
mlp_ratio=mlp_ratio,
|
248 |
+
drop=drop_rate,
|
249 |
+
drop_path=dpr[i],
|
250 |
+
norm_layer=norm_layer,
|
251 |
+
h=h,
|
252 |
+
w=w,
|
253 |
+
)
|
254 |
+
self.blocks_spectral_gating.append(layer)
|
255 |
+
else:
|
256 |
+
layer = BlockAttention(
|
257 |
+
dim=embed_dim,
|
258 |
+
num_heads=num_heads,
|
259 |
+
mlp_ratio=mlp_ratio,
|
260 |
+
drop=drop_rate,
|
261 |
+
drop_path=dpr[i],
|
262 |
+
norm_layer=norm_layer,
|
263 |
+
w=window_size,
|
264 |
+
dp_rank=dp_rank,
|
265 |
+
rpe=rpe,
|
266 |
+
adaLN=True if ensemble is not None else False,
|
267 |
+
nglo=nglo,
|
268 |
+
)
|
269 |
+
self.blocks_attention.append(layer)
|
270 |
+
|
271 |
+
self.apply(self._init_weights)
|
272 |
+
|
273 |
+
def forward(self, tokens: torch.Tensor) -> torch.Tensor:
|
274 |
+
"""
|
275 |
+
Args:
|
276 |
+
tokens: Tensor of shape B, N, C for deterministic of BxE, N, C for ensemble forecast.
|
277 |
+
Returns:
|
278 |
+
Tensor of same shape as input.
|
279 |
+
"""
|
280 |
+
if self.ensemble:
|
281 |
+
BE, N, C = tokens.shape
|
282 |
+
noise = torch.randn(
|
283 |
+
size=(BE, N, C), dtype=tokens.dtype, device=tokens.device
|
284 |
+
)
|
285 |
+
else:
|
286 |
+
noise = None
|
287 |
+
|
288 |
+
for i, blk in enumerate(
|
289 |
+
chain(self.blocks_spectral_gating, self.blocks_attention)
|
290 |
+
):
|
291 |
+
if i in self._checkpoint_layers:
|
292 |
+
tokens = checkpoint(blk, tokens, noise, use_reentrant=False)
|
293 |
+
else:
|
294 |
+
tokens = blk(tokens, noise)
|
295 |
+
|
296 |
+
return tokens
|
297 |
+
|
298 |
+
def _init_weights(self, m):
|
299 |
+
if isinstance(m, nn.Linear):
|
300 |
+
trunc_normal_(m.weight, std=0.02)
|
301 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
302 |
+
nn.init.constant_(m.bias, 0)
|
303 |
+
elif isinstance(m, nn.LayerNorm):
|
304 |
+
nn.init.constant_(m.bias, 0)
|
305 |
+
nn.init.constant_(m.weight, 1.0)
|
surya/models/transformer_ls.py
ADDED
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2021 NVIDIA CORPORATION. Licensed under the MIT license.
|
2 |
+
# Written by Chen Zhu during an internship at NVIDIA, [email protected]
|
3 |
+
import math
|
4 |
+
|
5 |
+
from torch import nn
|
6 |
+
import torch
|
7 |
+
from timm.models.layers import trunc_normal_
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
|
11 |
+
class AttentionLS(nn.Module):
|
12 |
+
"""Implementation for long-short term attention.
|
13 |
+
Flexible options for using window attention, global token and dynamic projection.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
dim: input and output feature dimension.
|
17 |
+
num_heads: number of attention heads.
|
18 |
+
qkv_bias: whether to use bias for the projection of query, key and values.
|
19 |
+
qk_scale: scale factor on query and key for numerical stability.
|
20 |
+
By default, set to square root of head dimensions.
|
21 |
+
attn_drop: dropout probability for attention matrix.
|
22 |
+
proj_drop: dropout probability for the final output.
|
23 |
+
rpe: whether to use relative position encoding.
|
24 |
+
nglo: number of global tokens (e.g., CLS).
|
25 |
+
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(
|
29 |
+
self,
|
30 |
+
dim,
|
31 |
+
num_heads=8,
|
32 |
+
qkv_bias=False,
|
33 |
+
qk_scale=None,
|
34 |
+
attn_drop=0.0,
|
35 |
+
proj_drop=0.0,
|
36 |
+
rpe=False,
|
37 |
+
nglo=1,
|
38 |
+
dp_rank=2,
|
39 |
+
w=2,
|
40 |
+
):
|
41 |
+
super().__init__()
|
42 |
+
self.num_heads = num_heads
|
43 |
+
head_dim = dim // num_heads
|
44 |
+
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
45 |
+
self.scale = qk_scale or head_dim**-0.5
|
46 |
+
|
47 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
48 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
49 |
+
self.proj = nn.Linear(dim, dim)
|
50 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
51 |
+
self.nglo = nglo
|
52 |
+
|
53 |
+
# Equals to segment size (w) in the paper.
|
54 |
+
self.window_size = w
|
55 |
+
# Equals to r in the paper.
|
56 |
+
self.dp_rank = dp_rank
|
57 |
+
|
58 |
+
if self.dp_rank > 0:
|
59 |
+
self.to_dynamic_projection = nn.Linear(dim, dp_rank * num_heads)
|
60 |
+
# The LN of DualLN corresponding to dynamic projection
|
61 |
+
self.dual_ln_dp = nn.LayerNorm(dim)
|
62 |
+
# The LN of DualLN corresponding to all the tokens
|
63 |
+
self.dual_ln_full = nn.LayerNorm(dim)
|
64 |
+
|
65 |
+
# Adapted from ViL: https://github.com/microsoft/vision-longformer/blob/main/src/models/layers/longformer2d.py#L55-L100
|
66 |
+
# We only add RPE to window attention.
|
67 |
+
# Unnecessary to add bias for global tokens, since DualLN already adds biases.
|
68 |
+
self.rpe = rpe
|
69 |
+
if rpe:
|
70 |
+
# handle the boarder conditions...
|
71 |
+
w_pad = int(w * 0.5)
|
72 |
+
self.local_relative_position_bias_table = nn.Parameter(
|
73 |
+
torch.zeros(2 * (w + w_pad - 1) * (2 * w_pad + w + 1) + 1, num_heads)
|
74 |
+
)
|
75 |
+
trunc_normal_(self.local_relative_position_bias_table, std=0.02)
|
76 |
+
|
77 |
+
# get pair-wise relative position index
|
78 |
+
coords_h = torch.arange(-w_pad, w_pad + w)
|
79 |
+
coords_w = torch.arange(-w_pad, w_pad + w)
|
80 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, 2w, 2w
|
81 |
+
coords = (
|
82 |
+
coords.view(2, (w + w_pad * 2) ** 2).transpose(0, 1).unsqueeze(0)
|
83 |
+
) # 1, 4w**2, 2
|
84 |
+
q_coords_hw = torch.arange(0, w)
|
85 |
+
q_coords = torch.stack(
|
86 |
+
torch.meshgrid([q_coords_hw, q_coords_hw])
|
87 |
+
) # 2, w, w
|
88 |
+
q_coords = q_coords.view(2, w**2).transpose(0, 1).unsqueeze(1) # w**2, 1, 2
|
89 |
+
relative_coords = q_coords - coords
|
90 |
+
relative_coords += w_pad + w - 1 # shift to start from 0
|
91 |
+
relative_coords[:, :, 0] *= 2 * w_pad + w
|
92 |
+
relative_position_index = relative_coords.sum(-1) # w^2, 4w^2
|
93 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
94 |
+
|
95 |
+
def forward(self, x, nx=None, ny=None):
|
96 |
+
B, N, C = x.shape
|
97 |
+
N_feat = N - self.nglo
|
98 |
+
self.img_size = int(math.sqrt(N)) if nx is None else nx
|
99 |
+
qkv = self.qkv(x)
|
100 |
+
# query, key, value
|
101 |
+
q, k, v = qkv.chunk(3, dim=2)
|
102 |
+
q = q.mul(self.scale)
|
103 |
+
|
104 |
+
# Layer norm on the projected keys and values
|
105 |
+
k = self.dual_ln_full(k)
|
106 |
+
v = self.dual_ln_full(v)
|
107 |
+
|
108 |
+
# output size: bsz x n_heads x seqlen x d
|
109 |
+
if self.nglo > 0:
|
110 |
+
q_cls, q = q[:, : self.nglo], q[:, self.nglo :]
|
111 |
+
k_cls, k = k[:, : self.nglo], k[:, self.nglo :]
|
112 |
+
v_cls, v = v[:, : self.nglo], v[:, self.nglo :]
|
113 |
+
|
114 |
+
q_cls = q_cls.reshape(
|
115 |
+
B, self.nglo, self.num_heads, C // self.num_heads
|
116 |
+
).transpose(1, 2)
|
117 |
+
k_cls = k_cls.reshape(
|
118 |
+
B, self.nglo, self.num_heads, C // self.num_heads
|
119 |
+
).transpose(1, 2)
|
120 |
+
v_cls = v_cls.reshape(
|
121 |
+
B, self.nglo, self.num_heads, C // self.num_heads
|
122 |
+
).transpose(1, 2)
|
123 |
+
|
124 |
+
q = q.reshape(B, N_feat, self.num_heads, C // self.num_heads).transpose(1, 2)
|
125 |
+
k = k.reshape(B, N_feat, self.num_heads, C // self.num_heads).transpose(1, 2)
|
126 |
+
v = v.reshape(B, N_feat, self.num_heads, C // self.num_heads).transpose(1, 2)
|
127 |
+
|
128 |
+
# Long-range Attention (Dynamic Projection)
|
129 |
+
if self.dp_rank > 0:
|
130 |
+
# b x h x r x (l w)
|
131 |
+
# Compute the projection matrix (P_i in the paper)
|
132 |
+
c_scores = (
|
133 |
+
self.to_dynamic_projection(x[:, self.nglo :])
|
134 |
+
.transpose(1, 2)
|
135 |
+
.contiguous()
|
136 |
+
.view(B, self.num_heads, self.dp_rank, -1)
|
137 |
+
)
|
138 |
+
# c_scores = c_scores.softmax(dim=-1, dtype=torch.float32).to(x)
|
139 |
+
c_scores = c_scores.softmax(dim=-1).to(
|
140 |
+
x
|
141 |
+
) # Changed when experimenting with mixed precision (Johannes S.)
|
142 |
+
# b x h x r x d
|
143 |
+
k_lms = c_scores.matmul(k)
|
144 |
+
k_lms = k_lms.transpose(1, 2).contiguous().view(B, self.dp_rank, -1)
|
145 |
+
k_lms = (
|
146 |
+
self.dual_ln_dp(k_lms)
|
147 |
+
.view(B, self.dp_rank, self.num_heads, -1)
|
148 |
+
.contiguous()
|
149 |
+
.permute(0, 2, 3, 1)
|
150 |
+
)
|
151 |
+
# b x h x (lw) x r
|
152 |
+
dots_all = q.matmul(k_lms)
|
153 |
+
|
154 |
+
if self.window_size > 0:
|
155 |
+
# Switch the order of dimensions if using window attention.
|
156 |
+
dots_all = self.group_dots(dots_all)
|
157 |
+
else:
|
158 |
+
dots_all = None
|
159 |
+
|
160 |
+
# Short-term Attention (Window Attention)
|
161 |
+
# In our window attention, each token attends to at most (4w^2) tokens.
|
162 |
+
if self.window_size > 0:
|
163 |
+
dots_win = self.compute_window_scores(q, k)
|
164 |
+
w2 = int(self.window_size * self.window_size)
|
165 |
+
|
166 |
+
if self.rpe:
|
167 |
+
w_pad = int(0.5 * self.window_size)
|
168 |
+
local_relative_position_bias = self.local_relative_position_bias_table[
|
169 |
+
self.relative_position_index.view(-1)
|
170 |
+
].view(
|
171 |
+
1, w2, (w_pad * 2 + self.window_size) ** 2, -1
|
172 |
+
) # w^2, kv_nums,H
|
173 |
+
local_relative_position_bias = (
|
174 |
+
local_relative_position_bias.permute(0, 3, 1, 2)
|
175 |
+
.expand(B, -1, -1, -1)
|
176 |
+
.unsqueeze(2)
|
177 |
+
.unsqueeze(2)
|
178 |
+
)
|
179 |
+
|
180 |
+
dots_win += local_relative_position_bias
|
181 |
+
if dots_all is None:
|
182 |
+
dots_all = dots_win
|
183 |
+
else:
|
184 |
+
dots_all = torch.cat([dots_all, dots_win], dim=-1)
|
185 |
+
|
186 |
+
# Global token.
|
187 |
+
if self.nglo > 0:
|
188 |
+
# and compute the scores of queries on CLS
|
189 |
+
dots_q_cls = q.matmul(k_cls.transpose(-1, -2))
|
190 |
+
|
191 |
+
if self.window_size > 0:
|
192 |
+
dots_q_cls = self.group_dots(dots_q_cls)
|
193 |
+
dots_all = torch.cat([dots_all, dots_q_cls], dim=-1)
|
194 |
+
|
195 |
+
# attn = dots_all.softmax(dim=-1, dtype=torch.float32).to(x)
|
196 |
+
attn = dots_all.softmax(dim=-1).to(
|
197 |
+
x
|
198 |
+
) # Changed when experimenting with mixed precision (Johannes S.)
|
199 |
+
attn = self.attn_drop(attn)
|
200 |
+
out = 0
|
201 |
+
if self.window_size > 0:
|
202 |
+
offset = max(0, self.dp_rank)
|
203 |
+
kv_group_size = self.window_size
|
204 |
+
total_win_size = max(1, self.window_size // 2) * 2 + kv_group_size
|
205 |
+
attn_win = attn[:, :, :, :, :, offset : offset + total_win_size**2]
|
206 |
+
out += self.compute_window_pv(attn_win, v)
|
207 |
+
attn = self.ungroup_dots(attn)
|
208 |
+
|
209 |
+
# attn will be b x h x lw x n_k from now on
|
210 |
+
if self.dp_rank > 0:
|
211 |
+
attn_lm = attn[:, :, :, : self.dp_rank]
|
212 |
+
v_lms = (
|
213 |
+
# c_scores.matmul(v.float())
|
214 |
+
c_scores.matmul(
|
215 |
+
v
|
216 |
+
) # Changed when experimenting with mixed precision (Johannes S.)
|
217 |
+
.to(v)
|
218 |
+
.transpose(1, 2)
|
219 |
+
.contiguous()
|
220 |
+
.view(B, self.dp_rank, -1)
|
221 |
+
)
|
222 |
+
v_lms = (
|
223 |
+
self.dual_ln_dp(v_lms)
|
224 |
+
.view(B, self.dp_rank, self.num_heads, -1)
|
225 |
+
.contiguous()
|
226 |
+
.transpose(1, 2)
|
227 |
+
)
|
228 |
+
|
229 |
+
out += attn_lm.matmul(v_lms)
|
230 |
+
|
231 |
+
if self.nglo > 0:
|
232 |
+
attn_cls = attn[:, :, :, -self.nglo :]
|
233 |
+
out += attn_cls.matmul(
|
234 |
+
v_cls
|
235 |
+
) # Changed. Was `.mul` instead of `.matmul`. (JWS)
|
236 |
+
|
237 |
+
# b x h x 1 x lw
|
238 |
+
cls_inner = q_cls.matmul(k_cls.transpose(-1, -2))
|
239 |
+
cls_dots = q_cls.matmul(
|
240 |
+
k.transpose(-1, -2)
|
241 |
+
) # Changed. Was `out` instead of `k`. (JWS)
|
242 |
+
cls_dots = torch.cat([cls_inner, cls_dots], dim=-1)
|
243 |
+
|
244 |
+
# cls_dots = cls_dots.softmax(dim=-1, dtype=torch.float32).to(x)
|
245 |
+
cls_dots = cls_dots.softmax(dim=-1).to(
|
246 |
+
x
|
247 |
+
) # Changed when experimenting with mixed precision (Johannes S.)
|
248 |
+
cls_next = cls_dots[:, :, :, self.nglo :].matmul(
|
249 |
+
v
|
250 |
+
) # the post_cls variant # Changed. Was `out` instead of `v`. (JWS)
|
251 |
+
cls_next += cls_dots[:, :, :, : self.nglo].matmul(v_cls)
|
252 |
+
|
253 |
+
out = torch.cat([cls_next, out], dim=2)
|
254 |
+
out = out.transpose(1, 2).contiguous().view(B, N, -1)
|
255 |
+
|
256 |
+
# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
257 |
+
out = self.proj(out)
|
258 |
+
out = self.proj_drop(out)
|
259 |
+
return out
|
260 |
+
|
261 |
+
def compute_window_scores(self, q, k):
|
262 |
+
"""Compute the inner products for the window attention.
|
263 |
+
Frist, divide the query into non-overlapping windows.
|
264 |
+
Then, use torch.as_trided (implemented in self.get_overlapping_tiles) to create a view of the keys
|
265 |
+
that corresponds to the windows with at most 2x memory overhead.
|
266 |
+
Finally, compute the inner product.
|
267 |
+
"""
|
268 |
+
# q: b h (l w) d
|
269 |
+
b, h, _, d = q.shape
|
270 |
+
side_size = max(self.window_size // 2, 1)
|
271 |
+
# q_group_size: segment size
|
272 |
+
kv_width = 2 * side_size + self.window_size # assuming q_stride=1
|
273 |
+
q_n_group = self.img_size // self.window_size
|
274 |
+
q_tiles = q.reshape(
|
275 |
+
b, h, q_n_group, self.window_size, q_n_group, self.window_size, d
|
276 |
+
).permute(0, 1, 2, 4, 3, 5, 6)
|
277 |
+
# q_tiles: b x h x n_group x n_group x w^2 x d
|
278 |
+
q_tiles = q_tiles.contiguous().view(b, h, q_n_group, q_n_group, -1, d)
|
279 |
+
|
280 |
+
# k_tiles: b x h x n_group x n_group x 9w^2 x d
|
281 |
+
k_tiles = (
|
282 |
+
self.get_overlapping_tiles(k)
|
283 |
+
.contiguous()
|
284 |
+
.view(b, h, q_n_group, q_n_group, -1, d)
|
285 |
+
)
|
286 |
+
# dot_tiles: b x h x n_group x n_group x w^2 x 9w^2
|
287 |
+
dot_tiles = q_tiles.matmul(k_tiles.transpose(-1, -2))
|
288 |
+
|
289 |
+
# fill "-inf" into the zero-padding parts
|
290 |
+
dot_tiles = dot_tiles.view(b, h, q_n_group, q_n_group, -1, kv_width, kv_width)
|
291 |
+
|
292 |
+
dot_tiles[:, :, 0, :, :, :side_size].fill_(float("-inf"))
|
293 |
+
dot_tiles[:, :, -1, :, :, -side_size:].fill_(float("-inf"))
|
294 |
+
dot_tiles[:, :, :, 0, :, :, :side_size].fill_(float("-inf"))
|
295 |
+
dot_tiles[:, :, :, -1, :, :, -side_size:].fill_(float("-inf"))
|
296 |
+
|
297 |
+
dot_tiles = dot_tiles.view(b, h, q_n_group, q_n_group, -1, kv_width**2)
|
298 |
+
return dot_tiles
|
299 |
+
|
300 |
+
def get_overlapping_tiles(self, x):
|
301 |
+
"""Get overlapping tiles in the 2D spatial domain, ensuring each query computes correlation with all neighbors"""
|
302 |
+
# x: b h (l w) d
|
303 |
+
b, h, _, d = x.shape
|
304 |
+
side_size = max(self.window_size // 2, 1)
|
305 |
+
total_size = 2 * side_size + self.window_size
|
306 |
+
kv_group_size = self.window_size
|
307 |
+
kv_width = self.img_size
|
308 |
+
|
309 |
+
x = x.view(b, h, kv_width, kv_width, d)
|
310 |
+
x = F.pad(x, [0, 0, side_size, side_size, side_size, side_size], value=0)
|
311 |
+
|
312 |
+
out_shape = [
|
313 |
+
b,
|
314 |
+
h,
|
315 |
+
kv_width // kv_group_size,
|
316 |
+
kv_width // kv_group_size,
|
317 |
+
total_size,
|
318 |
+
total_size,
|
319 |
+
d,
|
320 |
+
]
|
321 |
+
in_stride = x.stride()
|
322 |
+
out_stride = [
|
323 |
+
in_stride[0],
|
324 |
+
in_stride[1],
|
325 |
+
in_stride[2] * kv_group_size,
|
326 |
+
in_stride[3] * kv_group_size,
|
327 |
+
in_stride[2],
|
328 |
+
in_stride[3],
|
329 |
+
in_stride[4],
|
330 |
+
]
|
331 |
+
|
332 |
+
# note we ignored the boundary here
|
333 |
+
return x.as_strided(size=out_shape, stride=out_stride)
|
334 |
+
|
335 |
+
def compute_window_pv(self, attn, v):
|
336 |
+
"""Compute the inner product of attention matrix and the values for the window attention."""
|
337 |
+
b, h, n_group, _, w2, n_k = attn.shape
|
338 |
+
d = v.shape[-1]
|
339 |
+
v_tiles = (
|
340 |
+
self.get_overlapping_tiles(v)
|
341 |
+
.contiguous()
|
342 |
+
.view(b, h, n_group, n_group, -1, d)
|
343 |
+
)
|
344 |
+
|
345 |
+
# b x h x n_group x n_group x w^2 x d
|
346 |
+
pv = attn.matmul(v_tiles)
|
347 |
+
# return: b x h x (lw) x d
|
348 |
+
ret = self.ungroup_dots(pv)
|
349 |
+
|
350 |
+
return ret
|
351 |
+
|
352 |
+
def group_dots(self, dots):
|
353 |
+
b, h = dots.shape[:2]
|
354 |
+
n_group = self.img_size // self.window_size
|
355 |
+
dots = dots.reshape(
|
356 |
+
b, h, n_group, self.window_size, n_group, self.window_size, -1
|
357 |
+
).permute(0, 1, 2, 4, 3, 5, 6)
|
358 |
+
dots = dots.contiguous().view(
|
359 |
+
b, h, n_group, n_group, self.window_size * self.window_size, -1
|
360 |
+
)
|
361 |
+
return dots
|
362 |
+
|
363 |
+
def ungroup_dots(self, dots):
|
364 |
+
b, h, n_group, _, _, n_keys = dots.shape
|
365 |
+
dots = dots.reshape(
|
366 |
+
b, h, n_group, n_group, self.window_size, self.window_size, -1
|
367 |
+
).permute(0, 1, 2, 4, 3, 5, 6)
|
368 |
+
dots = dots.contiguous().view(b, h, -1, n_keys)
|
369 |
+
return dots
|
surya/utils/__init__.py
ADDED
File without changes
|
surya/utils/config.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from argparse import Namespace
|
3 |
+
|
4 |
+
import yaml
|
5 |
+
|
6 |
+
|
7 |
+
class DataConfig:
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
train_data_path: str,
|
11 |
+
valid_data_path: str,
|
12 |
+
batch_size: int,
|
13 |
+
num_data_workers: int,
|
14 |
+
prefetch_factor: int,
|
15 |
+
time_delta_input_minutes: list[int],
|
16 |
+
n_input_timestamps: int | None = None,
|
17 |
+
pooling: int | None = None,
|
18 |
+
random_vert_flip: bool = False,
|
19 |
+
**kwargs,
|
20 |
+
):
|
21 |
+
self.__dict__.update(kwargs)
|
22 |
+
|
23 |
+
self.train_data_path = train_data_path
|
24 |
+
self.valid_data_path = valid_data_path
|
25 |
+
self.batch_size = batch_size
|
26 |
+
self.num_data_workers = num_data_workers
|
27 |
+
self.prefetch_factor = prefetch_factor
|
28 |
+
self.time_delta_input_minutes = sorted(time_delta_input_minutes)
|
29 |
+
self.n_input_timestamps = n_input_timestamps
|
30 |
+
self.pooling = pooling
|
31 |
+
self.random_vert_flip = random_vert_flip
|
32 |
+
|
33 |
+
if self.n_input_timestamps is None:
|
34 |
+
self.n_input_timestamps = len(self.time_delta_input_minutes)
|
35 |
+
|
36 |
+
assert (
|
37 |
+
self.n_input_timestamps > 0
|
38 |
+
), "Number of input timestamps must be greater than 0."
|
39 |
+
assert self.n_input_timestamps <= len(self.time_delta_input_minutes), (
|
40 |
+
f"Cannot sample {self.n_input_timestamps} from list of "
|
41 |
+
f"{self.time_delta_input_minutes} input timestamps."
|
42 |
+
)
|
43 |
+
|
44 |
+
def to_dict(self):
|
45 |
+
return self.__dict__
|
46 |
+
|
47 |
+
@staticmethod
|
48 |
+
def from_argparse(args: Namespace):
|
49 |
+
return DataConfig(**args.__dict__)
|
50 |
+
|
51 |
+
def __str__(self):
|
52 |
+
return (
|
53 |
+
f"Training index: {self.train_data_path}, "
|
54 |
+
f"Validation index: {self.valid_data_path}, "
|
55 |
+
)
|
56 |
+
|
57 |
+
def __repr__(self):
|
58 |
+
return (
|
59 |
+
f"Training index: {self.train_data_path}, "
|
60 |
+
f"Validation index: {self.valid_data_path}, "
|
61 |
+
)
|
62 |
+
|
63 |
+
|
64 |
+
class ModelConfig:
|
65 |
+
def __init__(
|
66 |
+
self,
|
67 |
+
# enc_num_layers: int,
|
68 |
+
# enc_num_heads: int,
|
69 |
+
# enc_embed_size: int,
|
70 |
+
# dec_num_layers: int,
|
71 |
+
# dec_num_heads: int,
|
72 |
+
# dec_embed_size: int,
|
73 |
+
# mask_ratio: float,
|
74 |
+
**kwargs,
|
75 |
+
):
|
76 |
+
self.__dict__.update(kwargs)
|
77 |
+
|
78 |
+
# self.enc_num_layers = enc_num_layers
|
79 |
+
# self.enc_num_heads = enc_num_heads
|
80 |
+
# self.enc_embed_size = enc_embed_size
|
81 |
+
# self.dec_num_layers = dec_num_layers
|
82 |
+
# self.dec_num_heads = dec_num_heads
|
83 |
+
# self.dec_embed_size = dec_embed_size
|
84 |
+
# self.mlp_ratio = 0.0
|
85 |
+
# self.mask_ratio = mask_ratio
|
86 |
+
|
87 |
+
self.__dict__.update(kwargs)
|
88 |
+
|
89 |
+
def to_dict(self):
|
90 |
+
return self.__dict__
|
91 |
+
|
92 |
+
@staticmethod
|
93 |
+
def from_argparse(args: Namespace):
|
94 |
+
return ModelConfig(**args.__dict__)
|
95 |
+
|
96 |
+
@property
|
97 |
+
def encoder_d_ff(self):
|
98 |
+
return int(self.enc_embed_size * self.mlp_ratio)
|
99 |
+
|
100 |
+
@property
|
101 |
+
def decoder_d_ff(self):
|
102 |
+
return int(self.dec_embed_size * self.mlp_ratio)
|
103 |
+
|
104 |
+
def __str__(self):
|
105 |
+
return (
|
106 |
+
f"Input channels: {self.model.in_channels}, "
|
107 |
+
f"Encoder (L, H, E): {[self.enc_num_layers, self.enc_num_heads, self.enc_embed_size]}, "
|
108 |
+
f"Decoder (L, H, E): {[self.dec_num_layers, self.dec_num_heads, self.dec_embed_size]}"
|
109 |
+
)
|
110 |
+
|
111 |
+
def __repr__(self):
|
112 |
+
return (
|
113 |
+
f"Input channels: {self.model.in_channels}, "
|
114 |
+
f"Encoder (L, H, E): {[self.enc_num_layers, self.enc_num_heads, self.enc_embed_size]}, "
|
115 |
+
f"Decoder (L, H, E): {[self.dec_num_layers, self.dec_num_heads, self.dec_embed_size]}"
|
116 |
+
)
|
117 |
+
|
118 |
+
|
119 |
+
class OptimizerConfig:
|
120 |
+
def __init__(
|
121 |
+
self,
|
122 |
+
warm_up_steps: int,
|
123 |
+
max_epochs: int,
|
124 |
+
learning_rate: float,
|
125 |
+
min_lr: float,
|
126 |
+
):
|
127 |
+
self.warm_up_steps = warm_up_steps
|
128 |
+
self.max_epochs = max_epochs
|
129 |
+
self.learning_rate = learning_rate
|
130 |
+
self.min_lr = min_lr
|
131 |
+
|
132 |
+
def to_dict(self):
|
133 |
+
return self.__dict__
|
134 |
+
|
135 |
+
@staticmethod
|
136 |
+
def from_argparse(args: Namespace):
|
137 |
+
return ModelConfig(**args.__dict__)
|
138 |
+
|
139 |
+
def __str__(self):
|
140 |
+
return (
|
141 |
+
f"Epochs: {self.max_epochs}, "
|
142 |
+
f"LR: {[self.learning_rate, self.min_lr]}, "
|
143 |
+
f"Warm up: {self.warm_up_steps},"
|
144 |
+
)
|
145 |
+
|
146 |
+
def __repr__(self):
|
147 |
+
return (
|
148 |
+
f"Epochs: {self.max_epochs}, "
|
149 |
+
f"LR: {[self.learning_rate, self.min_lr]}, "
|
150 |
+
f"Warm up: {self.warm_up_steps},"
|
151 |
+
)
|
152 |
+
|
153 |
+
|
154 |
+
class ExperimentConfig:
|
155 |
+
def __init__(
|
156 |
+
self,
|
157 |
+
job_id: str,
|
158 |
+
data_config: DataConfig,
|
159 |
+
model_config: ModelConfig,
|
160 |
+
optimizer_config: OptimizerConfig,
|
161 |
+
path_experiment: str,
|
162 |
+
parallelism: str,
|
163 |
+
from_checkpoint: str | None = None,
|
164 |
+
**kwargs,
|
165 |
+
):
|
166 |
+
# additional experiment parameters used in downstream tasks
|
167 |
+
self.__dict__.update(kwargs)
|
168 |
+
|
169 |
+
self.job_id = job_id
|
170 |
+
self.data = data_config
|
171 |
+
self.model = model_config
|
172 |
+
self.optimizer = optimizer_config
|
173 |
+
self.path_experiment = path_experiment
|
174 |
+
self.from_checkpoint = from_checkpoint
|
175 |
+
self.parallelism = parallelism
|
176 |
+
|
177 |
+
assert self.model.in_channels == len(self.data.channels), (
|
178 |
+
f"Number of model input channels ({self.model.in_channels}) must be "
|
179 |
+
f"equal to number of input variables ({len(self.data.channels)})."
|
180 |
+
)
|
181 |
+
if self.model.time_embedding["type"] == "linear":
|
182 |
+
assert (
|
183 |
+
self.model.time_embedding["time_dim"] == self.data.n_input_timestamps
|
184 |
+
), "Time dimension of linear embedding must be equal to number of input timestamps."
|
185 |
+
if self.rollout_steps > 0:
|
186 |
+
assert self.data.n_input_timestamps == len(
|
187 |
+
self.data.time_delta_input_minutes
|
188 |
+
), "Rollout does not support randomly sampled input timestamps."
|
189 |
+
|
190 |
+
metrics_channels = []
|
191 |
+
for field1, value1 in self.metrics["train_metrics_config"].items():
|
192 |
+
for field2, value2 in self.metrics["train_metrics_config"][field1].items():
|
193 |
+
if field2 == "metrics":
|
194 |
+
for metric_definition in value2:
|
195 |
+
split_metric_definition = metric_definition.split(":")
|
196 |
+
channels = (
|
197 |
+
split_metric_definition[2]
|
198 |
+
if len(split_metric_definition) > 2
|
199 |
+
else None
|
200 |
+
)
|
201 |
+
if channels is not None:
|
202 |
+
metrics_channels = metrics_channels + channels.split("...")
|
203 |
+
|
204 |
+
for field1, value1 in self.metrics["validation_metrics_config"].items():
|
205 |
+
for field2, value2 in self.metrics["validation_metrics_config"][
|
206 |
+
field1
|
207 |
+
].items():
|
208 |
+
if field2 == "metrics":
|
209 |
+
for metric_definition in value2:
|
210 |
+
split_metric_definition = metric_definition.split(":")
|
211 |
+
channels = (
|
212 |
+
split_metric_definition[2]
|
213 |
+
if len(split_metric_definition) > 2
|
214 |
+
else None
|
215 |
+
)
|
216 |
+
if channels is not None:
|
217 |
+
metrics_channels = metrics_channels + channels.replace(
|
218 |
+
"...", "&"
|
219 |
+
).split("&")
|
220 |
+
|
221 |
+
assert set(metrics_channels).issubset(self.data.channels), (
|
222 |
+
f"{set(metrics_channels).difference(self.data.channels)} "
|
223 |
+
f"not part of data input channels."
|
224 |
+
)
|
225 |
+
|
226 |
+
assert self.parallelism in [
|
227 |
+
"ddp",
|
228 |
+
"fsdp",
|
229 |
+
], 'Valid choices for `parallelism` are "ddp" and "fsdp".'
|
230 |
+
|
231 |
+
@property
|
232 |
+
def path_checkpoint(self) -> str:
|
233 |
+
if self.path_experiment == "":
|
234 |
+
return os.path.join(self.path_weights, "train", "checkpoint.pt")
|
235 |
+
else:
|
236 |
+
return os.path.join(
|
237 |
+
os.path.dirname(self.path_experiment),
|
238 |
+
"weights",
|
239 |
+
"train",
|
240 |
+
"checkpoint.pt",
|
241 |
+
)
|
242 |
+
|
243 |
+
@property
|
244 |
+
def path_weights(self) -> str:
|
245 |
+
return os.path.join(self.path_experiment, self.make_suffix_path(), "weights")
|
246 |
+
|
247 |
+
@property
|
248 |
+
def path_states(self) -> str:
|
249 |
+
return os.path.join(self.path_experiment, self.make_suffix_path(), "states")
|
250 |
+
|
251 |
+
def to_dict(self):
|
252 |
+
d = self.__dict__.copy()
|
253 |
+
d["model"] = self.model.to_dict()
|
254 |
+
d["data"] = self.data.to_dict()
|
255 |
+
|
256 |
+
return d
|
257 |
+
|
258 |
+
@staticmethod
|
259 |
+
def from_argparse(args: Namespace):
|
260 |
+
return ExperimentConfig(
|
261 |
+
data_config=DataConfig.from_argparse(args),
|
262 |
+
model_config=ModelConfig.from_argparse(args),
|
263 |
+
optimizer_config=OptimizerConfig.from_argparse(args),
|
264 |
+
**args.__dict__,
|
265 |
+
)
|
266 |
+
|
267 |
+
@staticmethod
|
268 |
+
def from_dict(params: dict):
|
269 |
+
return ExperimentConfig(
|
270 |
+
data_config=DataConfig(**params["data"]),
|
271 |
+
model_config=ModelConfig(**params["model"]),
|
272 |
+
optimizer_config=OptimizerConfig(**params["optimizer"]),
|
273 |
+
**params,
|
274 |
+
)
|
275 |
+
|
276 |
+
def make_folder_name(self) -> str:
|
277 |
+
param_folder = "wpt-c1-s1"
|
278 |
+
return param_folder
|
279 |
+
|
280 |
+
def make_suffix_path(self) -> str:
|
281 |
+
return os.path.join(self.job_id)
|
282 |
+
|
283 |
+
def __str__(self):
|
284 |
+
return (
|
285 |
+
f"ID: {self.job_id}, "
|
286 |
+
f"Epochs: {self.optimizer.max_epochs}, "
|
287 |
+
f"Batch size: {self.data.batch_size}, "
|
288 |
+
f"LR: {[self.optimizer.learning_rate, self.optimizer.min_lr]}, "
|
289 |
+
f"Warm up: {self.optimizer.warm_up_steps},"
|
290 |
+
f"DL workers: {self.data.num_data_workers},"
|
291 |
+
f"Parallelism: {self.parallelism}"
|
292 |
+
)
|
293 |
+
|
294 |
+
def __repr__(self):
|
295 |
+
return (
|
296 |
+
f"ID: {self.job_id}, "
|
297 |
+
f"Epochs: {self.optimizer.max_epochs}, "
|
298 |
+
f"Batch size: {self.data.batch_size}, "
|
299 |
+
f"LR: {[self.optimizer.learning_rate, self.optimizer.min_lr]}, "
|
300 |
+
f"Warm up: {self.optimizer.warm_up_steps},"
|
301 |
+
f"DL workers: {self.data.num_data_workers},"
|
302 |
+
f"Parallelism: {self.parallelism}"
|
303 |
+
)
|
304 |
+
|
305 |
+
|
306 |
+
def get_config(
|
307 |
+
config_path: str,
|
308 |
+
) -> ExperimentConfig:
|
309 |
+
cfg = yaml.safe_load(open(config_path, "r"))
|
310 |
+
cfg["data"]["scalers"] = yaml.safe_load(open(cfg["data"]["scalers_path"], "r"))
|
311 |
+
return ExperimentConfig.from_dict(params=cfg)
|
surya/utils/data.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from surya.datasets.transformations import Transformation, StandardScaler
|
7 |
+
from surya.utils.config import DataConfig
|
8 |
+
from surya.utils.misc import class_from_name, view_as_windows
|
9 |
+
|
10 |
+
|
11 |
+
def custom_collate_fn(batch):
|
12 |
+
"""
|
13 |
+
Custom collate function for handling batches of data and metadata in a PyTorch DataLoader.
|
14 |
+
|
15 |
+
This function separately processes the data and metadata from the input batch.
|
16 |
+
|
17 |
+
- The `data_batch` is collated using PyTorch's `default_collate`. If collation fails due to incompatible data types,
|
18 |
+
the batch is returned as-is.
|
19 |
+
|
20 |
+
- The `metadata_batch` is assumed to be a dictionary, where each key corresponds to a list of values across the batch.
|
21 |
+
Each key is collated using `default_collate`. If collation fails for a particular key, the original list of values
|
22 |
+
is retained.
|
23 |
+
|
24 |
+
Example usage for accessing collated metadata:
|
25 |
+
- `collated_metadata['timestamps_input'][batch_idx][input_time]`
|
26 |
+
- `collated_metadata['timestamps_input'][batch_idx][rollout_step]`
|
27 |
+
|
28 |
+
Args:
|
29 |
+
batch (list of tuples): Each tuple contains (data, metadata), where:
|
30 |
+
- `data` is a tensor or other data structure used for training.
|
31 |
+
- `metadata` is a dictionary containing additional information.
|
32 |
+
|
33 |
+
Returns:
|
34 |
+
tuple: (collated_data, collated_metadata)
|
35 |
+
- `collated_data`: The processed batch of data.
|
36 |
+
- `collated_metadata`: The processed batch of metadata.
|
37 |
+
"""
|
38 |
+
|
39 |
+
# Unpack batch into separate lists of data and metadata
|
40 |
+
data_batch, metadata_batch = zip(*batch)
|
41 |
+
|
42 |
+
# Attempt to collate the data batch using PyTorch's default collate function
|
43 |
+
try:
|
44 |
+
collated_data = torch.utils.data.default_collate(data_batch)
|
45 |
+
except TypeError:
|
46 |
+
# If default_collate fails (e.g., due to incompatible types), return the data batch as-is
|
47 |
+
collated_data = data_batch
|
48 |
+
|
49 |
+
# Handle metadata collation
|
50 |
+
if isinstance(metadata_batch[0], dict):
|
51 |
+
collated_metadata = {}
|
52 |
+
for key in metadata_batch[0].keys():
|
53 |
+
values = [d[key] for d in metadata_batch]
|
54 |
+
try:
|
55 |
+
# Attempt to collate values under the current key
|
56 |
+
collated_metadata[key] = torch.utils.data.default_collate(values)
|
57 |
+
except TypeError:
|
58 |
+
# If collation fails, keep the values as a list
|
59 |
+
collated_metadata[key] = values
|
60 |
+
else:
|
61 |
+
# If metadata is not a dictionary, try to collate it as a whole
|
62 |
+
try:
|
63 |
+
collated_metadata = torch.utils.data.default_collate(metadata_batch)
|
64 |
+
except TypeError:
|
65 |
+
# If collation fails, return metadata as-is
|
66 |
+
collated_metadata = metadata_batch
|
67 |
+
|
68 |
+
return collated_data, collated_metadata
|
69 |
+
|
70 |
+
|
71 |
+
def calc_num_windows(raw_size: int, win_size: int, stride: int) -> int:
|
72 |
+
return (raw_size - win_size) // stride + 1
|
73 |
+
|
74 |
+
|
75 |
+
def get_scalers_info(dataset) -> dict:
|
76 |
+
return {
|
77 |
+
k: (type(v).__module__, type(v).__name__, v.to_dict())
|
78 |
+
for k, v in dataset.scalers.items()
|
79 |
+
}
|
80 |
+
|
81 |
+
|
82 |
+
def build_scalers_pressure(info: dict) -> Dict[str, Transformation]:
|
83 |
+
ret_dict = {k: dict() for k in info.keys()}
|
84 |
+
for var_key, var_d in info.items():
|
85 |
+
for p_key, p_val in var_d.items():
|
86 |
+
ret_dict[var_key][p_key] = class_from_name(
|
87 |
+
p_val["base"], p_val["class"]
|
88 |
+
).from_dict(p_val)
|
89 |
+
return ret_dict
|
90 |
+
|
91 |
+
|
92 |
+
def build_scalers(info: dict) -> Dict[str, Transformation]:
|
93 |
+
ret_dict = {k: None for k in info.keys()}
|
94 |
+
for p_key, p_val in info.items():
|
95 |
+
ret_dict[p_key]: StandardScaler = class_from_name(
|
96 |
+
p_val["base"], p_val["class"]
|
97 |
+
).from_dict(p_val)
|
98 |
+
return ret_dict
|
99 |
+
|
100 |
+
|
101 |
+
def break_batch_5d(
|
102 |
+
data: list, lat_size: int, lon_size: int, time_steps: int
|
103 |
+
) -> np.ndarray:
|
104 |
+
"""
|
105 |
+
data: list of samples, each sample is [C, T, L, H, W]
|
106 |
+
"""
|
107 |
+
num_levels = data[0].shape[2]
|
108 |
+
num_vars = data[0].shape[0]
|
109 |
+
big_batch = np.stack(data, axis=0)
|
110 |
+
vw = view_as_windows(
|
111 |
+
big_batch,
|
112 |
+
[1, num_vars, time_steps, num_levels, lat_size, lon_size],
|
113 |
+
step=[1, num_vars, time_steps, num_levels, lat_size, lon_size],
|
114 |
+
).squeeze()
|
115 |
+
# To check if it is correctly reshaping
|
116 |
+
# idx = 30
|
117 |
+
# (big_batch[0, :, idx:idx+2, :, 40:80, 40:80]-vw[idx//2, 1, 1]).sum()
|
118 |
+
vw = vw.reshape(-1, num_vars, time_steps, num_levels, lat_size, lon_size)
|
119 |
+
# How to test:
|
120 |
+
# (big_batch[0, :, :2, :, :40, :40] - vw[0]).sum()
|
121 |
+
# (big_batch[0, :, :2, :, :40, 40:80] - vw[1]).sum()
|
122 |
+
# (big_batch[0, :, :2, :, 40:80, :40] - vw[2]).sum()
|
123 |
+
|
124 |
+
# Need to move axis because Weather model is expecting [C, L, T, H, W] instead of [C, T, L, H, W]
|
125 |
+
vw = np.moveaxis(vw, 3, 2)
|
126 |
+
vw = torch.tensor(vw, dtype=torch.float32)
|
127 |
+
return vw
|
128 |
+
|
129 |
+
|
130 |
+
def break_batch_5d_aug(data: list, cfg: DataConfig, max_batch: int = 256) -> np.ndarray:
|
131 |
+
num_levels = data[0].shape[2]
|
132 |
+
num_vars = data[0].shape[0]
|
133 |
+
big_batch = np.stack(data, axis=0)
|
134 |
+
|
135 |
+
y_step, x_step, t_step = (
|
136 |
+
cfg.patch_size_lat // 2,
|
137 |
+
cfg.patch_size_lon // 2,
|
138 |
+
cfg.patch_size_time // 2,
|
139 |
+
)
|
140 |
+
y_max = calc_num_windows(big_batch.shape[4], cfg.input_size_lat, y_step)
|
141 |
+
x_max = calc_num_windows(big_batch.shape[5], cfg.input_size_lon, x_step)
|
142 |
+
t_max = calc_num_windows(big_batch.shape[2], cfg.input_size_time, t_step)
|
143 |
+
max_batch = min(max_batch, y_max * x_max * t_max)
|
144 |
+
|
145 |
+
batch = np.empty(
|
146 |
+
(
|
147 |
+
max_batch,
|
148 |
+
num_vars,
|
149 |
+
cfg.input_size_time,
|
150 |
+
num_levels,
|
151 |
+
cfg.input_size_lat,
|
152 |
+
cfg.input_size_lon,
|
153 |
+
),
|
154 |
+
dtype=np.float32,
|
155 |
+
)
|
156 |
+
for j, i in enumerate(np.random.permutation(np.arange(max_batch))):
|
157 |
+
t, y, x = np.unravel_index(
|
158 |
+
i,
|
159 |
+
(
|
160 |
+
t_max,
|
161 |
+
y_max,
|
162 |
+
x_max,
|
163 |
+
),
|
164 |
+
)
|
165 |
+
batch[j] = big_batch[
|
166 |
+
:, # batch_id
|
167 |
+
:, # vars
|
168 |
+
t * t_step : t * t_step + cfg.input_size_time,
|
169 |
+
:, # levels
|
170 |
+
y * y_step : y * y_step + cfg.input_size_lat,
|
171 |
+
x * x_step : x * x_step + cfg.input_size_lon,
|
172 |
+
]
|
173 |
+
|
174 |
+
batch = np.moveaxis(batch, 3, 2)
|
175 |
+
batch = torch.tensor(batch, dtype=torch.float32)
|
176 |
+
return batch
|
surya/utils/distributed.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import random
|
3 |
+
from datetime import timedelta
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.backends.cudnn as cudnn
|
8 |
+
import torch.distributed as dist
|
9 |
+
import torch.nn as nn
|
10 |
+
from torch.distributed import checkpoint as dist_checkpoint
|
11 |
+
from torch.distributed import fsdp
|
12 |
+
|
13 |
+
import functools
|
14 |
+
import itertools
|
15 |
+
|
16 |
+
from torch.utils.data.distributed import DistributedSampler
|
17 |
+
from torch.utils.data import Dataset
|
18 |
+
from typing import Any, Dict, Optional
|
19 |
+
|
20 |
+
from surya.utils.schemas import TrainState
|
21 |
+
|
22 |
+
|
23 |
+
def init_dist(device: str, rank: int, world_size: int):
|
24 |
+
torch.distributed.init_process_group(
|
25 |
+
device,
|
26 |
+
init_method="env://",
|
27 |
+
world_size=world_size,
|
28 |
+
rank=rank,
|
29 |
+
timeout=timedelta(minutes=60),
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
def init_ddp(use_gpu: bool):
|
34 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
35 |
+
rank = int(os.environ["RANK"])
|
36 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
37 |
+
|
38 |
+
if use_gpu:
|
39 |
+
assert (
|
40 |
+
torch.cuda.is_available()
|
41 |
+
), "GPU requested but none was found in the system."
|
42 |
+
|
43 |
+
if use_gpu:
|
44 |
+
init_dist("nccl", rank, world_size)
|
45 |
+
torch.cuda.set_device(local_rank)
|
46 |
+
os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
|
47 |
+
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = str(1)
|
48 |
+
os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
|
49 |
+
cudnn.benchmark = True
|
50 |
+
else:
|
51 |
+
init_dist("gloo", rank, world_size)
|
52 |
+
return local_rank, rank
|
53 |
+
|
54 |
+
|
55 |
+
def set_global_seed(rank):
|
56 |
+
random.seed(42 + rank)
|
57 |
+
torch.cuda.manual_seed(42 + rank)
|
58 |
+
torch.manual_seed(42 + rank)
|
59 |
+
np.random.seed(42 + rank)
|
60 |
+
|
61 |
+
|
62 |
+
def is_dist_avail_and_initialized():
|
63 |
+
if not dist.is_available():
|
64 |
+
return False
|
65 |
+
if not dist.is_initialized():
|
66 |
+
return False
|
67 |
+
return True
|
68 |
+
|
69 |
+
|
70 |
+
def get_world_size():
|
71 |
+
if not is_dist_avail_and_initialized():
|
72 |
+
return 1
|
73 |
+
return dist.get_world_size()
|
74 |
+
|
75 |
+
|
76 |
+
def get_rank():
|
77 |
+
if not is_dist_avail_and_initialized():
|
78 |
+
return 0
|
79 |
+
return dist.get_rank()
|
80 |
+
|
81 |
+
|
82 |
+
def is_main_process():
|
83 |
+
return get_rank() == 0
|
84 |
+
|
85 |
+
|
86 |
+
# def save_model_singular(model, *args, **kwargs):
|
87 |
+
# """Stream all model parameters to rank 0 on the CPU, then pass all
|
88 |
+
# other given arguments to `torch.save` to save the model, but only on
|
89 |
+
# the root process.
|
90 |
+
# """
|
91 |
+
# save_policy = fsdp.FullStateDictConfig(
|
92 |
+
# offload_to_cpu=True, rank0_only=True)
|
93 |
+
# with fsdp.FullyShardedDataParallel.state_dict_type(
|
94 |
+
# model,
|
95 |
+
# fsdp.StateDictType.FULL_STATE_DICT,
|
96 |
+
# save_policy,
|
97 |
+
# ):
|
98 |
+
# cpu_state = model.state_dict()
|
99 |
+
# # We do *not* want to write to the same location with multiple
|
100 |
+
# # processes at the same time.
|
101 |
+
# if is_root_process():
|
102 |
+
# torch.save(cpu_state, *args, **kwargs)
|
103 |
+
|
104 |
+
|
105 |
+
def save_model(model, save_dir):
|
106 |
+
"""Obtain sharded model parameters from the GPU, then save the model
|
107 |
+
as a distributed checkpoint to the given directory. Saving a
|
108 |
+
distributed checkpoint means that the checkpoint will be split into
|
109 |
+
individual files, one for each process.
|
110 |
+
"""
|
111 |
+
state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
|
112 |
+
with fsdp.FullyShardedDataParallel.state_dict_type(
|
113 |
+
model,
|
114 |
+
fsdp.StateDictType.SHARDED_STATE_DICT,
|
115 |
+
state_dict_config,
|
116 |
+
):
|
117 |
+
cp_state_dict = {"model": model.state_dict()}
|
118 |
+
dist_checkpoint.save_state_dict(
|
119 |
+
cp_state_dict,
|
120 |
+
dist_checkpoint.FileSystemWriter(save_dir),
|
121 |
+
)
|
122 |
+
|
123 |
+
|
124 |
+
def load_model(model, load_dir):
|
125 |
+
"""Set the given model's state dictionary in-place from the given
|
126 |
+
distributed checkpoint directory.
|
127 |
+
"""
|
128 |
+
state_dict_config = fsdp.ShardedStateDictConfig(offload_to_cpu=False)
|
129 |
+
with fsdp.FullyShardedDataParallel.state_dict_type(
|
130 |
+
model,
|
131 |
+
fsdp.StateDictType.SHARDED_STATE_DICT,
|
132 |
+
state_dict_config,
|
133 |
+
):
|
134 |
+
cp_state_dict = {"model": model.state_dict()}
|
135 |
+
dist_checkpoint.load_state_dict(
|
136 |
+
cp_state_dict,
|
137 |
+
dist_checkpoint.FileSystemReader(load_dir),
|
138 |
+
)
|
139 |
+
model.load_state_dict(cp_state_dict["model"])
|
140 |
+
|
141 |
+
|
142 |
+
@functools.lru_cache(maxsize=None)
|
143 |
+
def is_root_process():
|
144 |
+
"""Return whether this process is the root process."""
|
145 |
+
return torch.distributed.get_rank() == 0
|
146 |
+
|
147 |
+
|
148 |
+
# The reason we define this is that `torch.distributed` does not
|
149 |
+
# implement it; for the global rank, there's
|
150 |
+
# `torch.distributed.get_rank()`.
|
151 |
+
@functools.lru_cache(maxsize=None)
|
152 |
+
def get_local_rank():
|
153 |
+
"""Return the local rank of this process."""
|
154 |
+
return int(os.getenv("LOCAL_RANK"))
|
155 |
+
|
156 |
+
|
157 |
+
def print0(*args, **kwargs):
|
158 |
+
"""Print something only on the root process."""
|
159 |
+
if (not dist.is_initialized()) or is_root_process():
|
160 |
+
print(*args, **kwargs)
|
161 |
+
|
162 |
+
|
163 |
+
def save_model_singular(model, save_path, parallelism, *args, **kwargs):
|
164 |
+
"""Stream all model parameters to rank 0 on the CPU, then pass all
|
165 |
+
other given arguments to `torch.save` to save the model, but only on
|
166 |
+
the root process.
|
167 |
+
"""
|
168 |
+
|
169 |
+
match parallelism:
|
170 |
+
case "fsdp":
|
171 |
+
save_policy = fsdp.FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
|
172 |
+
with fsdp.FullyShardedDataParallel.state_dict_type(
|
173 |
+
model,
|
174 |
+
fsdp.StateDictType.FULL_STATE_DICT,
|
175 |
+
save_policy,
|
176 |
+
):
|
177 |
+
cpu_state = model.state_dict()
|
178 |
+
# We do *not* want to write to the same location with multiple
|
179 |
+
# processes at the same time.
|
180 |
+
if is_main_process():
|
181 |
+
if not os.path.exists(os.path.dirname(save_path)):
|
182 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
183 |
+
torch.save(obj=cpu_state, f=save_path, *args, **kwargs)
|
184 |
+
|
185 |
+
case "ddp":
|
186 |
+
if is_main_process():
|
187 |
+
torch.save(obj=model.module.state_dict(), f=save_path, *args, **kwargs)
|
188 |
+
dist.barrier()
|
189 |
+
case _:
|
190 |
+
raise ValueError(
|
191 |
+
f'`parallelism` should be one of "ddp" and "fsdp". Got {parallelism}.'
|
192 |
+
)
|
193 |
+
|
194 |
+
|
195 |
+
def save_optim_singular(
|
196 |
+
model: nn.Module,
|
197 |
+
optimizer: torch.optim.Optimizer,
|
198 |
+
save_path: str,
|
199 |
+
parallelism: str = "fsdp",
|
200 |
+
):
|
201 |
+
match parallelism:
|
202 |
+
case "fsdp":
|
203 |
+
optim_state_dict_config = fsdp.FullOptimStateDictConfig(
|
204 |
+
offload_to_cpu=True, rank0_only=True
|
205 |
+
)
|
206 |
+
|
207 |
+
with fsdp.FullyShardedDataParallel.state_dict_type(
|
208 |
+
model,
|
209 |
+
fsdp.StateDictType.FULL_STATE_DICT,
|
210 |
+
optim_state_dict_config=optim_state_dict_config,
|
211 |
+
):
|
212 |
+
optim_state_dict = fsdp.FullyShardedDataParallel.optim_state_dict(
|
213 |
+
model, optimizer
|
214 |
+
)
|
215 |
+
|
216 |
+
if is_main_process():
|
217 |
+
if not os.path.exists(os.path.dirname(save_path)):
|
218 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
219 |
+
checkpoint = {
|
220 |
+
"optimizer_state_dict": optim_state_dict,
|
221 |
+
}
|
222 |
+
torch.save(checkpoint, f=save_path)
|
223 |
+
case "ddp":
|
224 |
+
if is_main_process():
|
225 |
+
optim_state_dict = optimizer.state_dict()
|
226 |
+
if not os.path.exists(os.path.dirname(save_path)):
|
227 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
228 |
+
torch.save(obj=optim_state_dict, f=save_path)
|
229 |
+
dist.barrier()
|
230 |
+
case _:
|
231 |
+
raise ValueError(
|
232 |
+
f'`parallelism` should be one of "ddp" and "fsdp". Got {parallelism}.'
|
233 |
+
)
|
234 |
+
|
235 |
+
|
236 |
+
def collect_optim_singular(
|
237 |
+
model: nn.Module, optimizer: torch.optim.Optimizer, parallelism: str = "fsdp"
|
238 |
+
) -> dict:
|
239 |
+
optim_state_dict = {}
|
240 |
+
match parallelism:
|
241 |
+
case "fsdp":
|
242 |
+
optim_state_dict_config = fsdp.FullOptimStateDictConfig(
|
243 |
+
offload_to_cpu=True, rank0_only=True
|
244 |
+
)
|
245 |
+
|
246 |
+
with fsdp.FullyShardedDataParallel.state_dict_type(
|
247 |
+
model,
|
248 |
+
fsdp.StateDictType.FULL_STATE_DICT,
|
249 |
+
optim_state_dict_config=optim_state_dict_config,
|
250 |
+
):
|
251 |
+
optim_state_dict = fsdp.FullyShardedDataParallel.optim_state_dict(
|
252 |
+
model, optimizer
|
253 |
+
)
|
254 |
+
|
255 |
+
case "ddp":
|
256 |
+
if is_main_process():
|
257 |
+
optim_state_dict = optimizer.state_dict()
|
258 |
+
dist.barrier()
|
259 |
+
case _:
|
260 |
+
raise ValueError(
|
261 |
+
f'`parallelism` should be one of "ddp" and "fsdp". Got {parallelism}.'
|
262 |
+
)
|
263 |
+
|
264 |
+
return optim_state_dict
|
265 |
+
|
266 |
+
|
267 |
+
def save_state_singular(states: TrainState, save_path, *args, **kwargs):
|
268 |
+
"""Stream all model parameters to rank 0 on the CPU, then pass all
|
269 |
+
other given arguments to `torch.save` to save paramters, but only on
|
270 |
+
the root process.
|
271 |
+
"""
|
272 |
+
if is_main_process():
|
273 |
+
if not os.path.exists(os.path.dirname(save_path)):
|
274 |
+
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
275 |
+
torch.save(obj=states, f=save_path, *args, **kwargs)
|
276 |
+
dist.barrier()
|
277 |
+
|
278 |
+
|
279 |
+
class StatefulDistributedSampler(DistributedSampler):
|
280 |
+
_YIELDED = "yielded"
|
281 |
+
|
282 |
+
def __init__(
|
283 |
+
self,
|
284 |
+
dataset: Dataset,
|
285 |
+
num_replicas: Optional[int] = None,
|
286 |
+
rank: Optional[int] = None,
|
287 |
+
shuffle: bool = True,
|
288 |
+
seed: int = 0,
|
289 |
+
drop_last: bool = False,
|
290 |
+
) -> None:
|
291 |
+
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
|
292 |
+
self.yielded = 0
|
293 |
+
self.next_yielded = None
|
294 |
+
|
295 |
+
def __iter__(self):
|
296 |
+
self.yielded = 0
|
297 |
+
if self.next_yielded is not None:
|
298 |
+
self.yielded = self.next_yielded
|
299 |
+
self.next_yielded = None
|
300 |
+
it = super().__iter__()
|
301 |
+
for idx in itertools.islice(it, self.yielded, None):
|
302 |
+
self.yielded += 1
|
303 |
+
yield idx
|
304 |
+
|
305 |
+
def state_dict(self) -> Dict[str, Any]:
|
306 |
+
return {self._YIELDED: self.yielded}
|
307 |
+
|
308 |
+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
|
309 |
+
if self._YIELDED not in state_dict:
|
310 |
+
raise ValueError("Invalid state_dict")
|
311 |
+
if state_dict[self._YIELDED] < 0:
|
312 |
+
raise ValueError("Cannot load state_dict with negative yielded value")
|
313 |
+
self.next_yielded = state_dict[self._YIELDED]
|
surya/utils/log.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import functools
|
2 |
+
import logging
|
3 |
+
import os
|
4 |
+
import sys
|
5 |
+
from time import time
|
6 |
+
from packaging.version import Version
|
7 |
+
import wandb
|
8 |
+
from typing import Dict, Optional, Any
|
9 |
+
|
10 |
+
|
11 |
+
if Version(wandb.__version__) < Version("0.20.0"):
|
12 |
+
WANDB_USE_SYNC = True
|
13 |
+
else:
|
14 |
+
WANDB_USE_SYNC = False
|
15 |
+
|
16 |
+
|
17 |
+
def log(
|
18 |
+
run,
|
19 |
+
data: Dict[str, Any],
|
20 |
+
step: Optional[int] = None,
|
21 |
+
commit: Optional[bool] = None,
|
22 |
+
sync: Optional[bool] = None,
|
23 |
+
) -> None:
|
24 |
+
if run is not None:
|
25 |
+
# Note: wandb changed the .log API with version 0.20.0.
|
26 |
+
# This includes: "Removed no-op sync argument from wandb.Run::log function"
|
27 |
+
# We didn't test whether sync has any function here. But since we did
|
28 |
+
# all our development with it, let's keep it here for now.
|
29 |
+
# See https://github.com/wandb/wandb/releases/tag/v0.20.0
|
30 |
+
if WANDB_USE_SYNC:
|
31 |
+
run.log(data, step, commit, sync)
|
32 |
+
else:
|
33 |
+
run.log(data, step, commit)
|
34 |
+
else:
|
35 |
+
print(data)
|
36 |
+
|
37 |
+
|
38 |
+
# See: https://github.com/microsoft/Swin-Transformer/blob/main/logger.py
|
39 |
+
# See: https://github.com/Meituan-AutoML/Twins/blob/main/logger.py
|
40 |
+
def create_logger(output_dir: str, dist_rank: int, name: str) -> logging.Logger:
|
41 |
+
# create logger
|
42 |
+
logger = logging.getLogger(name)
|
43 |
+
logger.setLevel(logging.DEBUG)
|
44 |
+
logger.propagate = False
|
45 |
+
|
46 |
+
# create formatter
|
47 |
+
fmt = "[%(asctime)s %(name)s]: %(levelname)s %(message)s"
|
48 |
+
|
49 |
+
# create console handlers
|
50 |
+
if name.endswith("main"):
|
51 |
+
console_handler = logging.StreamHandler(sys.stdout)
|
52 |
+
console_handler.setLevel(logging.INFO)
|
53 |
+
console_handler.setFormatter(
|
54 |
+
logging.Formatter(fmt=fmt, datefmt="%Y-%m-%d %H:%M:%S")
|
55 |
+
)
|
56 |
+
logger.addHandler(console_handler)
|
57 |
+
|
58 |
+
# create file handlers
|
59 |
+
file_handler = logging.FileHandler(
|
60 |
+
os.path.join(output_dir, f"{name}.log"), mode="a"
|
61 |
+
)
|
62 |
+
file_handler.setLevel(logging.DEBUG)
|
63 |
+
file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt="%Y-%m-%d %H:%M:%S"))
|
64 |
+
logger.addHandler(file_handler)
|
65 |
+
|
66 |
+
return logger
|
67 |
+
|
68 |
+
|
69 |
+
def log_decorator(logger, _func=None):
|
70 |
+
def log_decorator_info(func):
|
71 |
+
@functools.wraps(func)
|
72 |
+
def log_decorator_wrapper(*args, **kwargs):
|
73 |
+
"""Create a list of the positional arguments passed to function.
|
74 |
+
- Using repr() for string representation for each argument. repr() is similar to str() only
|
75 |
+
difference being it prints with a pair of quotes and if we calculate a value we get more
|
76 |
+
precise value than str().
|
77 |
+
"""
|
78 |
+
|
79 |
+
# py_file_caller = getframeinfo(stack()[1][0])
|
80 |
+
|
81 |
+
local_rank = os.environ.get("LOCAL_RANK", default=None)
|
82 |
+
rank = os.environ.get("LOCAL_RANK", default=None)
|
83 |
+
|
84 |
+
try:
|
85 |
+
"""log return value from the function"""
|
86 |
+
start_time = time()
|
87 |
+
value = func(*args, **kwargs)
|
88 |
+
if local_rank is None or rank is None:
|
89 |
+
logger.info(
|
90 |
+
f"Function '{func.__name__}' - Execution time: {(time() - start_time):.1f} seconds."
|
91 |
+
)
|
92 |
+
else:
|
93 |
+
logger.info(
|
94 |
+
f"Function '{func.__name__}' - Execution time: {(time() - start_time):.1f} "
|
95 |
+
f"seconds on rank {os.environ['RANK']} and local_rank {os.environ['LOCAL_RANK']}."
|
96 |
+
)
|
97 |
+
except Exception as err:
|
98 |
+
logger.error(f"Exception: {err}")
|
99 |
+
raise
|
100 |
+
return value
|
101 |
+
|
102 |
+
# Return the pointer to the function
|
103 |
+
return log_decorator_wrapper
|
104 |
+
|
105 |
+
# Decorator was called with arguments, so return a decorator function that can read and return a function
|
106 |
+
if _func is None:
|
107 |
+
return log_decorator_info
|
108 |
+
# Decorator was called without arguments, so apply the decorator to the function immediately
|
109 |
+
else:
|
110 |
+
return log_decorator_info(_func)
|
surya/utils/misc.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numbers
|
2 |
+
from logging import Logger
|
3 |
+
from time import time
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from numpy.lib.stride_tricks import as_strided
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
|
10 |
+
|
11 |
+
def view_as_windows(arr_in: np.ndarray, window_shape, step=1) -> np.ndarray:
|
12 |
+
"""Rolling window view of the input n-dimensional array.
|
13 |
+
Windows are overlapping views of the input array, with adjacent windows
|
14 |
+
shifted by a single row or column (or an index of a higher dimension).
|
15 |
+
|
16 |
+
Ref: https://github.com/scikit-image/scikit-image/blob/5e74a4a3a5149a8a14566b81a32bb15499aa3857/skimage/util/shape.py#L97-L247
|
17 |
+
Parameters
|
18 |
+
"""
|
19 |
+
|
20 |
+
# -- basic checks on arguments
|
21 |
+
if not isinstance(arr_in, np.ndarray):
|
22 |
+
raise TypeError("`arr_in` must be a numpy ndarray")
|
23 |
+
|
24 |
+
ndim = arr_in.ndim
|
25 |
+
|
26 |
+
if isinstance(window_shape, numbers.Number):
|
27 |
+
window_shape = (window_shape,) * ndim
|
28 |
+
if not (len(window_shape) == ndim):
|
29 |
+
raise ValueError("`window_shape` is incompatible with `arr_in.shape`")
|
30 |
+
|
31 |
+
if isinstance(step, numbers.Number):
|
32 |
+
if step < 1:
|
33 |
+
raise ValueError("`step` must be >= 1")
|
34 |
+
step = (step,) * ndim
|
35 |
+
if len(step) != ndim:
|
36 |
+
raise ValueError("`step` is incompatible with `arr_in.shape`")
|
37 |
+
|
38 |
+
arr_shape = np.array(arr_in.shape)
|
39 |
+
window_shape = np.array(window_shape, dtype=arr_shape.dtype)
|
40 |
+
|
41 |
+
if ((arr_shape - window_shape) < 0).any():
|
42 |
+
raise ValueError("`window_shape` is too large")
|
43 |
+
|
44 |
+
if ((window_shape - 1) < 0).any():
|
45 |
+
raise ValueError("`window_shape` is too small")
|
46 |
+
|
47 |
+
# -- build rolling window view
|
48 |
+
slices = tuple(slice(None, None, st) for st in step)
|
49 |
+
window_strides = np.array(arr_in.strides)
|
50 |
+
|
51 |
+
indexing_strides = arr_in[slices].strides
|
52 |
+
|
53 |
+
win_indices_shape = (
|
54 |
+
(np.array(arr_in.shape) - np.array(window_shape)) // np.array(step)
|
55 |
+
) + 1
|
56 |
+
|
57 |
+
new_shape = tuple(list(win_indices_shape) + list(window_shape))
|
58 |
+
strides = tuple(list(indexing_strides) + list(window_strides))
|
59 |
+
|
60 |
+
arr_out = as_strided(arr_in, shape=new_shape, strides=strides)
|
61 |
+
return arr_out
|
62 |
+
|
63 |
+
|
64 |
+
def class_from_name(module_name: str, class_name: str) -> object:
|
65 |
+
# load the module, will raise ImportError if module cannot be loaded
|
66 |
+
m = __import__(module_name, globals(), locals(), [class_name])
|
67 |
+
# get the class, will raise AttributeError if class cannot be found
|
68 |
+
c = getattr(m, class_name)
|
69 |
+
return c
|
70 |
+
|
71 |
+
|
72 |
+
@torch.no_grad()
|
73 |
+
def throughput(data_loader: DataLoader, model: torch.nn.Module, logger: Logger):
|
74 |
+
model.eval()
|
75 |
+
|
76 |
+
for idx, (images, _) in enumerate(data_loader):
|
77 |
+
images = images.cuda(non_blocking=True)
|
78 |
+
batch_size = images.shape[0]
|
79 |
+
for i in range(50):
|
80 |
+
model(images)
|
81 |
+
torch.cuda.synchronize()
|
82 |
+
logger.info("throughput averaged with 30 times")
|
83 |
+
tic1 = time()
|
84 |
+
for i in range(30):
|
85 |
+
model(images)
|
86 |
+
torch.cuda.synchronize()
|
87 |
+
tic2 = time()
|
88 |
+
logger.info(
|
89 |
+
f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}"
|
90 |
+
)
|
surya/utils/schemas.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import TypedDict, Dict, Any
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class TrainState(TypedDict):
|
6 |
+
dataloader: torch.utils.data.DataLoader
|
7 |
+
optimizer: Dict[str, Any]
|
8 |
+
scheduler: Dict[str, Any]
|
9 |
+
sampler: Any # Changed from torch.utils.data.sampler to Any
|
10 |
+
profiler: bool
|
11 |
+
epoch: int
|
12 |
+
iteration: int
|
13 |
+
loss: float
|
14 |
+
wandb_state: int
|