qq1990 commited on
Commit
8a18a38
1 Parent(s): 68341cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +329 -2
app.py CHANGED
@@ -1,4 +1,331 @@
 
 
 
 
 
 
1
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
1
+ # import streamlit as st
2
+
3
+ # x = st.slider('Select a value')
4
+ # st.write(x, 'squared is', x * x)
5
+
6
+
7
  import streamlit as st
8
+ import random
9
+ from pathlib import Path
10
+ import matplotlib.pyplot as plt
11
+ import numpy as np
12
+ import torch
13
+ from huggingface_hub import hf_hub_download, snapshot_download
14
+ import tarfile
15
+ import os
16
+ import sys
17
+ import yaml
18
+
19
+ st.title("PrithviWxC Model Inference")
20
+
21
+ st.write("Setting up environment...")
22
+
23
+ # Set up torch backends and seeds
24
+ torch.jit.enable_onednn_fusion(True)
25
+ if torch.cuda.is_available():
26
+ st.write(f"Using device: {torch.cuda.get_device_name()}")
27
+ torch.backends.cudnn.benchmark = True
28
+ torch.backends.cudnn.deterministic = True
29
+
30
+ random.seed(42)
31
+ if torch.cuda.is_available():
32
+ torch.cuda.manual_seed(42)
33
+ torch.manual_seed(42)
34
+ np.random.seed(42)
35
+
36
+ # Set device
37
+ if torch.cuda.is_available():
38
+ device = torch.device("cuda")
39
+ else:
40
+ device = torch.device("cpu")
41
+
42
+ st.write(f"Using device: {device}")
43
+
44
+ # Download and extract PrithviWxC module
45
+ st.write("Downloading and setting up PrithviWxC module...")
46
+
47
+ module_tar_path = hf_hub_download(
48
+ repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
49
+ filename="PrithviWxC.tar.gz",
50
+ local_dir=".",
51
+ force_download=True
52
+ )
53
+
54
+ with tarfile.open(module_tar_path, "r:gz") as tar:
55
+ tar.extractall(path=".")
56
+
57
+ # Add the module path to sys.path
58
+ sys.path.append(os.path.abspath("./PrithviWxC"))
59
+
60
+ st.write("PrithviWxC module imported successfully.")
61
+
62
+ # Now import the module
63
+ from PrithviWxC.dataloaders.merra2 import Merra2Dataset, input_scalers, output_scalers, static_input_scalers, preproc
64
+ from PrithviWxC.model import PrithviWxC
65
+
66
+ # Variables and times
67
+ surface_vars = [
68
+ "EFLUX",
69
+ "GWETROOT",
70
+ "HFLUX",
71
+ "LAI",
72
+ "LWGAB",
73
+ "LWGEM",
74
+ "LWTUP",
75
+ "PS",
76
+ "QV2M",
77
+ "SLP",
78
+ "SWGNT",
79
+ "SWTNT",
80
+ "T2M",
81
+ "TQI",
82
+ "TQL",
83
+ "TQV",
84
+ "TS",
85
+ "U10M",
86
+ "V10M",
87
+ "Z0M",
88
+ ]
89
+ static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
90
+ vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
91
+ levels = [
92
+ 34.0,
93
+ 39.0,
94
+ 41.0,
95
+ 43.0,
96
+ 44.0,
97
+ 45.0,
98
+ 48.0,
99
+ 51.0,
100
+ 53.0,
101
+ 56.0,
102
+ 63.0,
103
+ 68.0,
104
+ 71.0,
105
+ 72.0,
106
+ ]
107
+ padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
108
+
109
+ st.write("Setting up dataset parameters...")
110
+
111
+ # User inputs for lead times and input times
112
+ lead_time = st.number_input("Lead Time (hours)", min_value=1, max_value=24, value=6)
113
+ input_time = st.number_input("Input Time Difference (hours)", min_value=-24, max_value=0, value=-6)
114
+
115
+ lead_times = [lead_time] # This variable can be changed to change the task
116
+ input_times = [input_time] # This variable can be changed to change the task
117
+
118
+ # Data file
119
+ time_range = ("2020-01-01T00:00:00", "2020-01-01T23:59:59")
120
+
121
+ st.write("Downloading data files...")
122
+
123
+ surf_dir = Path("./merra-2")
124
+ snapshot_download(
125
+ repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
126
+ allow_patterns="merra-2/MERRA2_sfc_2020010[1].nc",
127
+ local_dir=".",
128
+ force_download=True,
129
+ )
130
+
131
+ vert_dir = Path("./merra-2")
132
+ snapshot_download(
133
+ repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
134
+ allow_patterns="merra-2/MERRA_pres_2020010[1].nc",
135
+ local_dir=".",
136
+ force_download=True,
137
+ )
138
+
139
+ # Climatology
140
+ surf_clim_dir = Path("./climatology")
141
+ snapshot_download(
142
+ repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
143
+ allow_patterns="climatology/climate_surface_doy00[1]*.nc",
144
+ local_dir=".",
145
+ force_download=True,
146
+ )
147
+
148
+ vert_clim_dir = Path("./climatology")
149
+ snapshot_download(
150
+ repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
151
+ allow_patterns="climatology/climate_vertical_doy00[1]*.nc",
152
+ local_dir=".",
153
+ force_download=True,
154
+ )
155
+
156
+ st.write("Setting positional encoding...")
157
+
158
+ positional_encoding = "fourier"
159
+
160
+ st.write("Initializing dataset...")
161
+
162
+ dataset = Merra2Dataset(
163
+ time_range=time_range,
164
+ lead_times=lead_times,
165
+ input_times=input_times,
166
+ data_path_surface=surf_dir,
167
+ data_path_vertical=vert_dir,
168
+ climatology_path_surface=surf_clim_dir,
169
+ climatology_path_vertical=vert_clim_dir,
170
+ surface_vars=surface_vars,
171
+ static_surface_vars=static_surface_vars,
172
+ vertical_vars=vertical_vars,
173
+ levels=levels,
174
+ positional_encoding=positional_encoding,
175
+ )
176
+
177
+ assert len(dataset) > 0, "There doesn't seem to be any valid data."
178
+
179
+ st.write("Loading scalers...")
180
+
181
+ surf_in_scal_path = Path("./climatology/musigma_surface.nc")
182
+ hf_hub_download(
183
+ repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
184
+ filename=f"climatology/{surf_in_scal_path.name}",
185
+ local_dir=".",
186
+ force_download=True,
187
+ )
188
+
189
+ vert_in_scal_path = Path("./climatology/musigma_vertical.nc")
190
+ hf_hub_download(
191
+ repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
192
+ filename=f"climatology/{vert_in_scal_path.name}",
193
+ local_dir=".",
194
+ force_download=True,
195
+ )
196
+
197
+ surf_out_scal_path = Path("./climatology/anomaly_variance_surface.nc")
198
+ hf_hub_download(
199
+ repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
200
+ filename=f"climatology/{surf_out_scal_path.name}",
201
+ local_dir=".",
202
+ force_download=True,
203
+ )
204
+
205
+ vert_out_scal_path = Path("./climatology/anomaly_variance_vertical.nc")
206
+ hf_hub_download(
207
+ repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
208
+ filename=f"climatology/{vert_out_scal_path.name}",
209
+ local_dir=".",
210
+ force_download=True,
211
+ )
212
+
213
+ in_mu, in_sig = input_scalers(
214
+ surface_vars,
215
+ vertical_vars,
216
+ levels,
217
+ surf_in_scal_path,
218
+ vert_in_scal_path,
219
+ )
220
+
221
+ output_sig = output_scalers(
222
+ surface_vars,
223
+ vertical_vars,
224
+ levels,
225
+ surf_out_scal_path,
226
+ vert_out_scal_path,
227
+ )
228
+
229
+ static_mu, static_sig = static_input_scalers(
230
+ surf_in_scal_path,
231
+ static_surface_vars,
232
+ )
233
+
234
+ st.write("Setting up model...")
235
+
236
+ residual = "climate"
237
+ masking_mode = "local"
238
+ decoder_shifting = True
239
+ masking_ratio = 0.99
240
+
241
+ # Load model config
242
+ hf_hub_download(
243
+ repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
244
+ filename="config.yaml",
245
+ local_dir=".",
246
+ force_download=True,
247
+ )
248
+
249
+ with open("./config.yaml", "r") as f:
250
+ config = yaml.safe_load(f)
251
+
252
+ model = PrithviWxC(
253
+ in_channels=config["params"]["in_channels"],
254
+ input_size_time=config["params"]["input_size_time"],
255
+ in_channels_static=config["params"]["in_channels_static"],
256
+ input_scalers_mu=in_mu,
257
+ input_scalers_sigma=in_sig,
258
+ input_scalers_epsilon=config["params"]["input_scalers_epsilon"],
259
+ static_input_scalers_mu=static_mu,
260
+ static_input_scalers_sigma=static_sig,
261
+ static_input_scalers_epsilon=config["params"]["static_input_scalers_epsilon"],
262
+ output_scalers=output_sig**0.5,
263
+ n_lats_px=config["params"]["n_lats_px"],
264
+ n_lons_px=config["params"]["n_lons_px"],
265
+ patch_size_px=config["params"]["patch_size_px"],
266
+ mask_unit_size_px=config["params"]["mask_unit_size_px"],
267
+ mask_ratio_inputs=masking_ratio,
268
+ embed_dim=config["params"]["embed_dim"],
269
+ n_blocks_encoder=config["params"]["n_blocks_encoder"],
270
+ n_blocks_decoder=config["params"]["n_blocks_decoder"],
271
+ mlp_multiplier=config["params"]["mlp_multiplier"],
272
+ n_heads=config["params"]["n_heads"],
273
+ dropout=config["params"]["dropout"],
274
+ drop_path=config["params"]["drop_path"],
275
+ parameter_dropout=config["params"]["parameter_dropout"],
276
+ residual=residual,
277
+ masking_mode=masking_mode,
278
+ decoder_shifting=decoder_shifting,
279
+ positional_encoding=positional_encoding,
280
+ checkpoint_encoder=[],
281
+ checkpoint_decoder=[],
282
+ )
283
+
284
+ st.write("Loading model weights...")
285
+
286
+ weights_path = Path("./weights/prithvi.wxc.2300m.v1.pt")
287
+ hf_hub_download(
288
+ repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
289
+ filename=weights_path.name,
290
+ local_dir="./weights",
291
+ force_download=True,
292
+ )
293
+
294
+ state_dict = torch.load(weights_path, map_location=device)
295
+ if "model_state" in state_dict:
296
+ state_dict = state_dict["model_state"]
297
+ model.load_state_dict(state_dict, strict=True)
298
+
299
+ model = model.to(device)
300
+
301
+ st.write("Model loaded and ready.")
302
+
303
+ if st.button("Run Inference"):
304
+ st.write("Running inference...")
305
+
306
+ data = next(iter(dataset))
307
+ batch = preproc([data], padding)
308
+
309
+ for k, v in batch.items():
310
+ if isinstance(v, torch.Tensor):
311
+ batch[k] = v.to(device)
312
+
313
+ with torch.no_grad():
314
+ model.eval()
315
+ out = model(batch)
316
+
317
+ st.write("Inference completed. Generating plot...")
318
+
319
+ t2m = out[0, 12].cpu().numpy()
320
+
321
+ lat = np.linspace(-90, 90, out.shape[-2])
322
+ lon = np.linspace(-180, 180, out.shape[-1])
323
+ X, Y = np.meshgrid(lon, lat)
324
+
325
+ fig, ax = plt.subplots()
326
+ cs = ax.contourf(X, Y, t2m, 100)
327
+ ax.set_aspect("equal")
328
+ plt.colorbar(cs)
329
+ st.pyplot(fig)
330
 
331
+ st.write("Plot generated.")