import logging
import random
import tempfile
import traceback
from pathlib import Path

import numpy as np
import streamlit as st
import torch
import yaml
from huggingface_hub import hf_hub_download, snapshot_download

# Provides Merra2Dataset, PrithviWxC, preproc, and the scaler helpers
# (input_scalers, output_scalers, static_input_scalers).
from Prithvi import *

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set page configuration
st.set_page_config(
    page_title="MERRA2 Data Processor",
    layout="wide",
    initial_sidebar_state="expanded",
)

dataset_type = st.sidebar.selectbox(
    "Select Dataset Type",
    options=["MERRA2", "GEOS5"],
    index=0,
)  # Note: dataset_type is currently informational only; it is not used below.

st.title("MERRA2 Data Processor with PrithviWxC Model")

# Sidebar for file uploads
st.sidebar.header("Upload MERRA2 Data Files")

# File uploader for surface data
uploaded_surface_files = st.sidebar.file_uploader(
    "Upload Surface Data Files",
    type=["nc", "netcdf"],
    accept_multiple_files=True,
    key="surface_uploader",
)

# File uploader for vertical data
uploaded_vertical_files = st.sidebar.file_uploader(
    "Upload Vertical Data Files",
    type=["nc", "netcdf"],
    accept_multiple_files=True,
    key="vertical_uploader",
)

# Optional: upload config.yaml
uploaded_config = st.sidebar.file_uploader(
    "Upload config.yaml",
    type=["yaml", "yml"],
    key="config_uploader",
)

# Optional: upload model weights
uploaded_weights = st.sidebar.file_uploader(
    "Upload Model Weights (.pt)",
    type=["pt"],
    key="weights_uploader",
)

# Other configurations
st.sidebar.header("Task Configuration")
lead_times = st.sidebar.multiselect(
    "Select Lead Times",
    options=[12, 24, 36, 48],
    default=[12],
)
input_times = st.sidebar.multiselect(
    "Select Input Times",
    options=[-6, -12, -18, -24],
    default=[-6],
)
time_range_start = st.sidebar.text_input(
    "Start Time (e.g., 2020-01-01T00:00:00)",
    value="2020-01-01T00:00:00",
)
time_range_end = st.sidebar.text_input(
    "End Time (e.g., 2020-01-01T23:59:59)",
    value="2020-01-01T23:59:59",
)
time_range = (time_range_start, time_range_end)


# Save uploaded files into a temporary directory, enforcing a per-file size limit.
def save_uploaded_files(uploaded_files, folder_name, max_size_mb=1024):
    if not uploaded_files:
        st.warning(f"No {folder_name} files uploaded.")
        return None
    # Validate file sizes
    for file in uploaded_files:
        if file.size > max_size_mb * 1024 * 1024:
            st.error(f"File {file.name} exceeds the maximum size of {max_size_mb} MB.")
            return None
    temp_dir = tempfile.mkdtemp()
    with st.spinner(f"Saving {folder_name} files..."):
        for uploaded_file in uploaded_files:
            file_path = Path(temp_dir) / uploaded_file.name
            with open(file_path, "wb") as f:
                f.write(uploaded_file.getbuffer())
    st.success(f"Saved {len(uploaded_files)} {folder_name} files.")
    return Path(temp_dir)


# Save uploaded files
surf_dir = save_uploaded_files(uploaded_surface_files, "surface")
vert_dir = save_uploaded_files(uploaded_vertical_files, "vertical")

# Display uploaded files
if surf_dir:
    st.sidebar.subheader("Surface Files Uploaded:")
    for file in surf_dir.iterdir():
        st.sidebar.write(file.name)
if vert_dir:
    st.sidebar.subheader("Vertical Files Uploaded:")
    for file in vert_dir.iterdir():
        st.sidebar.write(file.name)

# Handle climatology files
st.sidebar.header("Upload Climatology Files (If Missing)")
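# Expected on-disk layout when nothing is uploaded (collected from the default
# paths used throughout this script):
#
#   Prithvi-WxC/examples/
#       config.yaml
#       weights/prithvi.wxc.2300m.v1.pt
#       merra-2/            <- surface and vertical NetCDF inputs
#       climatology/
#           musigma_surface.nc
#           musigma_vertical.nc
#           anomaly_variance_surface.nc
#           anomaly_variance_vertical.nc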
# Default climatology file paths
default_clim_dir = Path("Prithvi-WxC/examples/climatology")
surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"

# Check if climatology files exist
clim_files_exist = all(
    [
        surf_in_scal_path.exists(),
        vert_in_scal_path.exists(),
        surf_out_scal_path.exists(),
        vert_out_scal_path.exists(),
    ]
)

if not clim_files_exist:
    st.sidebar.warning("Climatology files are missing.")
    uploaded_clim_surface = st.sidebar.file_uploader(
        "Upload Climatology Surface File",
        type=["nc", "netcdf"],
        key="clim_surface_uploader",
    )
    uploaded_clim_vertical = st.sidebar.file_uploader(
        "Upload Climatology Vertical File",
        type=["nc", "netcdf"],
        key="clim_vertical_uploader",
    )
    if uploaded_clim_surface and uploaded_clim_vertical:
        clim_temp_dir = tempfile.mkdtemp()
        clim_surf_path = Path(clim_temp_dir) / uploaded_clim_surface.name
        with open(clim_surf_path, "wb") as f:
            f.write(uploaded_clim_surface.getbuffer())
        clim_vert_path = Path(clim_temp_dir) / uploaded_clim_vertical.name
        with open(clim_vert_path, "wb") as f:
            f.write(uploaded_clim_vertical.getbuffer())
        st.success("Climatology files uploaded and saved.")
    else:
        st.warning("Please upload both climatology surface and vertical files.")
else:
    clim_surf_path = surf_in_scal_path
    clim_vert_path = vert_in_scal_path

# Save uploaded config.yaml, falling back to the repository default.
# (tempfile.NamedTemporaryFile replaces the deprecated, race-prone tempfile.mktemp.)
if uploaded_config:
    temp_config = tempfile.NamedTemporaryFile(suffix=".yaml", delete=False)
    temp_config.write(uploaded_config.getbuffer())
    temp_config.close()
    config_path = Path(temp_config.name)
    st.sidebar.success("config.yaml uploaded and saved.")
else:
    config_path = Path("Prithvi-WxC/examples/config.yaml")
    if not config_path.exists():
        st.sidebar.error("Default config.yaml not found. Please upload a config file.")
        st.stop()

# Save uploaded model weights, falling back to the repository default
if uploaded_weights:
    temp_weights = tempfile.NamedTemporaryFile(suffix=".pt", delete=False)
    temp_weights.write(uploaded_weights.getbuffer())
    temp_weights.close()
    weights_path = Path(temp_weights.name)
    st.sidebar.success("Model weights uploaded and saved.")
else:
    weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
    if not weights_path.exists():
        st.sidebar.error("Default model weights not found. Please upload model weights.")
        st.stop()
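# hf_hub_download/snapshot_download are imported above but never used. A
# minimal sketch for fetching the default assets from the Hugging Face Hub
# instead of bundling them with the app; the repo id and filenames below are
# assumptions based on the public Prithvi-WxC release and should be verified:
#
#   weights_path = Path(
#       hf_hub_download(
#           repo_id="ibm-nasa-geospatial/Prithvi-WxC-1.0-2300M",
#           filename="prithvi.wxc.2300m.v1.pt",
#       )
#   )
#   config_path = Path(
#       hf_hub_download(
#           repo_id="ibm-nasa-geospatial/Prithvi-WxC-1.0-2300M",
#           filename="config.yaml",
#       )
#   )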
# Button to run inference
if st.sidebar.button("Run Inference"):
    # Initialize device
    torch.jit.enable_onednn_fusion(True)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        st.write(f"Using device: {torch.cuda.get_device_name()}")
        # benchmark autotunes conv kernels; deterministic restricts the
        # choice to reproducible algorithms.
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True
    else:
        device = torch.device("cpu")
        st.write("Using device: CPU")

    # Set random seeds for reproducibility
    random.seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    torch.manual_seed(42)
    np.random.seed(42)

    # Define variables and parameters
    surface_vars = [
        "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", "LWTUP",
        "PS", "QV2M", "SLP", "SWGNT", "SWTNT", "T2M", "TQI", "TQL",
        "TQV", "TS", "U10M", "V10M", "Z0M",
    ]
    static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
    vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
    levels = [
        34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0,
        51.0, 53.0, 56.0, 63.0, 68.0, 71.0, 72.0,
    ]
    # The lat padding of -1 crops one latitude row so the grid matches the
    # model's expected size.
    padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
    residual = "climate"
    masking_mode = "local"
    decoder_shifting = True
    masking_ratio = 0.99
    positional_encoding = "fourier"

    # Initialize dataset
    try:
        with st.spinner("Initializing dataset..."):
            # Validate climatology files
            if not clim_files_exist and not (uploaded_clim_surface and uploaded_clim_vertical):
                st.error(
                    "Climatology files are missing. Please upload both surface "
                    "and vertical climatology files."
                )
                st.stop()
            # Use the uploaded data directories when available, otherwise fall
            # back to the repository example data.
            dataset = Merra2Dataset(
                time_range=time_range,
                lead_times=lead_times,
                input_times=input_times,
                data_path_surface=surf_dir or Path("Prithvi-WxC/examples/merra-2"),
                data_path_vertical=vert_dir or Path("Prithvi-WxC/examples/merra-2"),
                climatology_path_surface=Path("Prithvi-WxC/examples/climatology"),
                climatology_path_vertical=Path("Prithvi-WxC/examples/climatology"),
                surface_vars=surface_vars,
                static_surface_vars=static_surface_vars,
                vertical_vars=vertical_vars,
                levels=levels,
                positional_encoding=positional_encoding,
            )
            assert len(dataset) > 0, "There doesn't seem to be any valid data."
            st.success("Dataset initialized successfully.")
    except Exception:
        st.error("Error initializing dataset:")
        st.error(traceback.format_exc())
        st.stop()
st.success("Dataset initialized successfully.") except Exception as e: st.error("Error initializing dataset:") st.error(traceback.format_exc()) st.stop() # Load scalers try: with st.spinner("Loading scalers..."): # Assuming the scaler paths are the same as climatology paths surf_in_scal_path = clim_surf_path vert_in_scal_path = clim_vert_path surf_out_scal_path = Path(clim_surf_path.parent) / "anomaly_variance_surface.nc" vert_out_scal_path = Path(clim_vert_path.parent) / "anomaly_variance_vertical.nc" # Check if output scaler files exist if not surf_out_scal_path.exists() or not vert_out_scal_path.exists(): st.error("Anomaly variance scaler files are missing.") st.stop() in_mu, in_sig = input_scalers( surface_vars, vertical_vars, levels, surf_in_scal_path, vert_in_scal_path, ) output_sig = output_scalers( surface_vars, vertical_vars, levels, surf_out_scal_path, vert_out_scal_path, ) static_mu, static_sig = static_input_scalers( surf_in_scal_path, static_surface_vars, ) st.success("Scalers loaded successfully.") except Exception as e: st.error("Error loading scalers:") st.error(traceback.format_exc()) st.stop() # Load configuration try: with st.spinner("Loading configuration..."): with open(config_path, "r") as f: config = yaml.safe_load(f) # Validate config required_params = [ "in_channels", "input_size_time", "in_channels_static", "input_scalers_epsilon", "static_input_scalers_epsilon", "n_lats_px", "n_lons_px", "patch_size_px", "mask_unit_size_px", "embed_dim", "n_blocks_encoder", "n_blocks_decoder", "mlp_multiplier", "n_heads", "dropout", "drop_path", "parameter_dropout" ] missing_params = [param for param in required_params if param not in config.get("params", {})] if missing_params: st.error(f"Missing configuration parameters: {missing_params}") st.stop() st.success("Configuration loaded successfully.") except Exception as e: st.error("Error loading configuration:") st.error(traceback.format_exc()) st.stop() # Initialize the model try: with st.spinner("Initializing model..."): model = PrithviWxC( in_channels=config["params"]["in_channels"], input_size_time=config["params"]["input_size_time"], in_channels_static=config["params"]["in_channels_static"], input_scalers_mu=in_mu, input_scalers_sigma=in_sig, input_scalers_epsilon=config["params"]["input_scalers_epsilon"], static_input_scalers_mu=static_mu, static_input_scalers_sigma=static_sig, static_input_scalers_epsilon=config["params"]["static_input_scalers_epsilon"], output_scalers=output_sig**0.5, n_lats_px=config["params"]["n_lats_px"], n_lons_px=config["params"]["n_lons_px"], patch_size_px=config["params"]["patch_size_px"], mask_unit_size_px=config["params"]["mask_unit_size_px"], mask_ratio_inputs=masking_ratio, embed_dim=config["params"]["embed_dim"], n_blocks_encoder=config["params"]["n_blocks_encoder"], n_blocks_decoder=config["params"]["n_blocks_decoder"], mlp_multiplier=config["params"]["mlp_multiplier"], n_heads=config["params"]["n_heads"], dropout=config["params"]["dropout"], drop_path=config["params"]["drop_path"], parameter_dropout=config["params"]["parameter_dropout"], residual=residual, masking_mode=masking_mode, decoder_shifting=decoder_shifting, positional_encoding=positional_encoding, checkpoint_encoder=[], checkpoint_decoder=[], ) st.success("Model initialized successfully.") except Exception as e: st.error("Error initializing model:") st.error(traceback.format_exc()) st.stop() # Load model weights try: with st.spinner("Loading model weights..."): state_dict = torch.load(weights_path, map_location=device) if 
"model_state" in state_dict: state_dict = state_dict["model_state"] model.load_state_dict(state_dict, strict=True) model.to(device) st.success("Model weights loaded successfully.") except Exception as e: st.error("Error loading model weights:") st.error(traceback.format_exc()) st.stop() # Prepare data batch try: with st.spinner("Preparing data batch..."): data = next(iter(dataset)) batch = preproc([data], padding) for k, v in batch.items(): if isinstance(v, torch.Tensor): batch[k] = v.to(device) st.success("Data batch prepared successfully.") except Exception as e: st.error("Error preparing data batch:") st.error(traceback.format_exc()) st.stop() # Run inference try: with st.spinner("Running model inference..."): rng_state_1 = torch.get_rng_state() with torch.no_grad(): model.eval() out = model(batch) st.success("Model inference completed successfully.") except Exception as e: st.error("Error during model inference:") st.error(traceback.format_exc()) st.stop() # Display output st.header("Inference Results") st.write(out) # Adjust based on the structure of 'out' # Optionally, provide download links or visualizations # For example, if 'out' contains tensors or dataframes: # st.write("Output Tensor:", out["some_key"].cpu().numpy()) else: st.info("Please upload the necessary files and click 'Run Inference' to start.")