import logging
import random
import tempfile
import traceback
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import torch
import yaml

from Prithvi import *
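# NOTE: the star import above is assumed (per the Prithvi-WxC example code) to
# provide Merra2Dataset, preproc, input_scalers, output_scalers,
# static_input_scalers, and PrithviWxC used below; adjust if your module
# layout differs.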

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

st.set_page_config(
    page_title="MERRA2 Data Processor",
    layout="wide",
    initial_sidebar_state="expanded",
)

# NOTE: dataset_type is collected but not yet wired up; only the MERRA-2
# pipeline below is implemented.
dataset_type = st.sidebar.selectbox(
    "Select Dataset Type",
    options=["MERRA2", "GEOS5"],
    index=0,
)

st.title("MERRA2 Data Processor with PrithviWxC Model")

st.sidebar.header("Upload MERRA2 Data Files")

uploaded_surface_files = st.sidebar.file_uploader(
    "Upload Surface Data Files",
    type=["nc", "netcdf"],
    accept_multiple_files=True,
    key="surface_uploader",
)

uploaded_vertical_files = st.sidebar.file_uploader(
    "Upload Vertical Data Files",
    type=["nc", "netcdf"],
    accept_multiple_files=True,
    key="vertical_uploader",
)

uploaded_config = st.sidebar.file_uploader(
    "Upload config.yaml",
    type=["yaml", "yml"],
    key="config_uploader",
)

uploaded_weights = st.sidebar.file_uploader(
    "Upload Model Weights (.pt)",
    type=["pt"],
    key="weights_uploader",
)

st.sidebar.header("Task Configuration")

lead_times = st.sidebar.multiselect(
    "Select Lead Times (hours)",
    options=[12, 24, 36, 48],
    default=[12],
)

input_times = st.sidebar.multiselect(
    "Select Input Times (hours relative to the forecast start)",
    options=[-6, -12, -18, -24],
    default=[-6],
)

time_range_start = st.sidebar.text_input(
    "Start Time (e.g., 2020-01-01T00:00:00)",
    value="2020-01-01T00:00:00",
)

time_range_end = st.sidebar.text_input(
    "End Time (e.g., 2020-01-01T23:59:59)",
    value="2020-01-01T23:59:59",
)

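# The (start, end) strings form the ISO-8601 window handed to Merra2Dataset;
# they are not validated here, so a malformed value surfaces as a
# dataset-initialization error.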
time_range = (time_range_start, time_range_end)


def save_uploaded_files(uploaded_files, folder_name, max_size_mb=1024):
    """Persist uploaded files to a fresh temp directory, rejecting oversized ones."""
    if not uploaded_files:
        st.warning(f"No {folder_name} files uploaded.")
        return None

    for file in uploaded_files:
        if file.size > max_size_mb * 1024 * 1024:
            st.error(f"File {file.name} exceeds the maximum size of {max_size_mb} MB.")
            return None

    temp_dir = tempfile.mkdtemp()
    with st.spinner(f"Saving {folder_name} files..."):
        for uploaded_file in uploaded_files:
            file_path = Path(temp_dir) / uploaded_file.name
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
    st.success(f"Saved {len(uploaded_files)} {folder_name} files.")
    return Path(temp_dir)


surf_dir = save_uploaded_files(uploaded_surface_files, "surface")
vert_dir = save_uploaded_files(uploaded_vertical_files, "vertical")

if surf_dir:
    st.sidebar.subheader("Surface Files Uploaded:")
    for file in surf_dir.iterdir():
        st.sidebar.write(file.name)

if vert_dir:
    st.sidebar.subheader("Vertical Files Uploaded:")
    for file in vert_dir.iterdir():
        st.sidebar.write(file.name)

st.sidebar.header("Upload Climatology Files (If Missing)")

default_clim_dir = Path("Prithvi-WxC/examples/climatology")
surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"
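# In the Prithvi-WxC example layout the musigma_* files hold the input
# mean/sigma scalers and the anomaly_variance_* files the output scalers;
# the paths above assume that layout.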

clim_files_exist = all(
    [
        surf_in_scal_path.exists(),
        vert_in_scal_path.exists(),
        surf_out_scal_path.exists(),
        vert_out_scal_path.exists(),
    ]
)

if not clim_files_exist:
    st.sidebar.warning("Climatology files are missing.")
    uploaded_clim_surface = st.sidebar.file_uploader(
        "Upload Climatology Surface File",
        type=["nc", "netcdf"],
        key="clim_surface_uploader",
    )
    uploaded_clim_vertical = st.sidebar.file_uploader(
        "Upload Climatology Vertical File",
        type=["nc", "netcdf"],
        key="clim_vertical_uploader",
    )

    if uploaded_clim_surface and uploaded_clim_vertical:
        clim_temp_dir = tempfile.mkdtemp()
        clim_surf_path = Path(clim_temp_dir) / uploaded_clim_surface.name
        with open(clim_surf_path, "wb") as f:
            f.write(uploaded_clim_surface.getbuffer())
        clim_vert_path = Path(clim_temp_dir) / uploaded_clim_vertical.name
        with open(clim_vert_path, "wb") as f:
            f.write(uploaded_clim_vertical.getbuffer())
        st.success("Climatology files uploaded and saved.")
    else:
        st.warning("Please upload both climatology surface and vertical files.")
else:
    clim_surf_path = surf_in_scal_path
    clim_vert_path = vert_in_scal_path

if uploaded_config:
    # tempfile.mktemp is deprecated and race-prone; NamedTemporaryFile with
    # delete=False keeps the file around for later reads.
    with tempfile.NamedTemporaryFile(suffix=".yaml", delete=False) as f:
        f.write(uploaded_config.getbuffer())
        config_path = Path(f.name)
    st.sidebar.success("config.yaml uploaded and saved.")
else:
    config_path = Path("Prithvi-WxC/examples/config.yaml")
    if not config_path.exists():
        st.sidebar.error("Default config.yaml not found. Please upload a config file.")
        st.stop()
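# The YAML is expected to nest all model hyperparameters under a top-level
# "params" key; see the required_params check before model construction.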

if uploaded_weights:
    with tempfile.NamedTemporaryFile(suffix=".pt", delete=False) as f:
        f.write(uploaded_weights.getbuffer())
        weights_path = Path(f.name)
    st.sidebar.success("Model weights uploaded and saved.")
else:
    weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
    if not weights_path.exists():
        st.sidebar.error("Default model weights not found. Please upload model weights.")
        st.stop()


if st.sidebar.button("Run Inference"):
    torch.jit.enable_onednn_fusion(True)

    if torch.cuda.is_available():
        device = torch.device("cuda")
        st.write(f"Using device: {torch.cuda.get_device_name()}")
        # NOTE: cudnn.benchmark autotuning can make runs non-deterministic
        # even with cudnn.deterministic set; disable it if exact
        # reproducibility matters more than speed.
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        st.write("Using device: CPU")

    # Seed all RNGs so the random input masking is reproducible.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)

    surface_vars = [
        "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", "LWTUP",
        "PS", "QV2M", "SLP", "SWGNT", "SWTNT", "T2M", "TQI", "TQL", "TQV",
        "TS", "U10M", "V10M", "Z0M",
    ]
    static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
    vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
    levels = [
        34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0,
        51.0, 53.0, 56.0, 63.0, 68.0, 71.0, 72.0,
    ]
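    # NOTE: the lat padding [0, -1] below is assumed to crop one latitude row,
    # mapping MERRA-2's 361-row grid onto the model's 360 rows (as in the
    # Prithvi-WxC examples).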
    padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}

    # Reconstruction-task settings, following the Prithvi-WxC example defaults;
    # masking_ratio is passed to the model as mask_ratio_inputs.
    residual = "climate"
    masking_mode = "local"
    decoder_shifting = True
    masking_ratio = 0.99

    positional_encoding = "fourier"

    try:
        with st.spinner("Initializing dataset..."):
            if not clim_files_exist and not (uploaded_clim_surface and uploaded_clim_vertical):
                st.error("Climatology files are missing. Please upload both surface and vertical climatology files.")
                st.stop()

            # Prefer the directories the user uploaded into; fall back to the
            # bundled example data when nothing was uploaded.
            dataset = Merra2Dataset(
                time_range=time_range,
                lead_times=lead_times,
                input_times=input_times,
                data_path_surface=surf_dir or Path("Prithvi-WxC/examples/merra-2"),
                data_path_vertical=vert_dir or Path("Prithvi-WxC/examples/merra-2"),
                climatology_path_surface=Path("Prithvi-WxC/examples/climatology"),
                climatology_path_vertical=Path("Prithvi-WxC/examples/climatology"),
                surface_vars=surface_vars,
                static_surface_vars=static_surface_vars,
                vertical_vars=vertical_vars,
                levels=levels,
                positional_encoding=positional_encoding,
            )
            assert len(dataset) > 0, "There doesn't seem to be any valid data."
        st.success("Dataset initialized successfully.")
    except Exception:
        st.error("Error initializing dataset:")
        st.error(traceback.format_exc())
        st.stop()

    try:
        with st.spinner("Loading scalers..."):
            surf_in_scal_path = clim_surf_path
            vert_in_scal_path = clim_vert_path
            surf_out_scal_path = clim_surf_path.parent / "anomaly_variance_surface.nc"
            vert_out_scal_path = clim_vert_path.parent / "anomaly_variance_vertical.nc"

            if not surf_out_scal_path.exists() or not vert_out_scal_path.exists():
                st.error("Anomaly variance scaler files are missing.")
                st.stop()

            in_mu, in_sig = input_scalers(
                surface_vars,
                vertical_vars,
                levels,
                surf_in_scal_path,
                vert_in_scal_path,
            )

            output_sig = output_scalers(
                surface_vars,
                vertical_vars,
                levels,
                surf_out_scal_path,
                vert_out_scal_path,
            )

            static_mu, static_sig = static_input_scalers(
                surf_in_scal_path,
                static_surface_vars,
            )
        st.success("Scalers loaded successfully.")
    except Exception:
        st.error("Error loading scalers:")
        st.error(traceback.format_exc())
        st.stop()
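    # NOTE: output_scalers is assumed to return anomaly *variances* (matching
    # the anomaly_variance_* filenames), which is why the model below receives
    # output_sig**0.5 as its standard-deviation scalers.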

    try:
        with st.spinner("Loading configuration..."):
            with open(config_path, "r") as f:
                config = yaml.safe_load(f)

            required_params = [
                "in_channels", "input_size_time", "in_channels_static",
                "input_scalers_epsilon", "static_input_scalers_epsilon",
                "n_lats_px", "n_lons_px", "patch_size_px",
                "mask_unit_size_px", "embed_dim", "n_blocks_encoder",
                "n_blocks_decoder", "mlp_multiplier", "n_heads",
                "dropout", "drop_path", "parameter_dropout",
            ]
            missing_params = [p for p in required_params if p not in config.get("params", {})]
            if missing_params:
                st.error(f"Missing configuration parameters: {missing_params}")
                st.stop()
        st.success("Configuration loaded successfully.")
    except Exception:
        st.error("Error loading configuration:")
        st.error(traceback.format_exc())
        st.stop()

    try:
        with st.spinner("Initializing model..."):
            params = config["params"]
            model = PrithviWxC(
                in_channels=params["in_channels"],
                input_size_time=params["input_size_time"],
                in_channels_static=params["in_channels_static"],
                input_scalers_mu=in_mu,
                input_scalers_sigma=in_sig,
                input_scalers_epsilon=params["input_scalers_epsilon"],
                static_input_scalers_mu=static_mu,
                static_input_scalers_sigma=static_sig,
                static_input_scalers_epsilon=params["static_input_scalers_epsilon"],
                output_scalers=output_sig**0.5,
                n_lats_px=params["n_lats_px"],
                n_lons_px=params["n_lons_px"],
                patch_size_px=params["patch_size_px"],
                mask_unit_size_px=params["mask_unit_size_px"],
                mask_ratio_inputs=masking_ratio,
                embed_dim=params["embed_dim"],
                n_blocks_encoder=params["n_blocks_encoder"],
                n_blocks_decoder=params["n_blocks_decoder"],
                mlp_multiplier=params["mlp_multiplier"],
                n_heads=params["n_heads"],
                dropout=params["dropout"],
                drop_path=params["drop_path"],
                parameter_dropout=params["parameter_dropout"],
                residual=residual,
                masking_mode=masking_mode,
                decoder_shifting=decoder_shifting,
                positional_encoding=positional_encoding,
                checkpoint_encoder=[],
                checkpoint_decoder=[],
            )
        st.success("Model initialized successfully.")
    except Exception:
        st.error("Error initializing model:")
        st.error(traceback.format_exc())
        st.stop()

    try:
        with st.spinner("Loading model weights..."):
            state_dict = torch.load(weights_path, map_location=device)
            # Training checkpoints wrap the weights in a "model_state" key.
            if "model_state" in state_dict:
                state_dict = state_dict["model_state"]
            model.load_state_dict(state_dict, strict=True)
            model.to(device)
        st.success("Model weights loaded successfully.")
    except Exception:
        st.error("Error loading model weights:")
        st.error(traceback.format_exc())
        st.stop()

    try:
        with st.spinner("Preparing data batch..."):
            data = next(iter(dataset))
            batch = preproc([data], padding)

            # Move every tensor in the batch onto the target device.
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(device)
        st.success("Data batch prepared successfully.")
    except Exception:
        st.error("Error preparing data batch:")
        st.error(traceback.format_exc())
        st.stop()

    try:
        with st.spinner("Running model inference..."):
            # Capture the RNG state so the random input masking could be
            # reproduced later if needed.
            rng_state_1 = torch.get_rng_state()
            model.eval()
            with torch.no_grad():
                out = model(batch)
        st.success("Model inference completed successfully.")
    except Exception:
        st.error("Error during model inference:")
        st.error(traceback.format_exc())
        st.stop()

    st.header("Inference Results")
    st.write(out)
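
    # Optional quick-look plot of one output channel: a minimal sketch that
    # assumes `out` is a tensor shaped [batch, channel, lat, lon], as in the
    # Prithvi-WxC examples; channel 0 is an arbitrary choice.
    if isinstance(out, torch.Tensor) and out.dim() == 4:
        fig, ax = plt.subplots(figsize=(9, 4))
        im = ax.imshow(out[0, 0].detach().cpu().numpy(), origin="lower")
        ax.set_title("Predicted field (channel 0)")
        fig.colorbar(im, ax=ax)
        st.pyplot(fig)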

else:
    st.info("Please upload the necessary files and click 'Run Inference' to start.")