import streamlit as st
import tempfile
from pathlib import Path
import torch
import traceback
import yaml

# from Prithvi import PrithviWxC, Merra2Dataset, input_scalers, output_scalers, static_input_scalers, preproc


def _save_upload_to_temp_file(uploaded_file, suffix):
    """Write a Streamlit upload to a newly created temporary file and return its path.

    ``NamedTemporaryFile(delete=False)`` creates the file atomically. The
    deprecated ``tempfile.mktemp`` only returns a name, which leaves a race
    between choosing the name and opening it.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getbuffer())
    return Path(tmp.name)


def _save_upload_to_dir(uploaded_file, directory):
    """Write a Streamlit upload into *directory* under its base file name.

    The browser-supplied name is reduced to its final path component. This
    stops a crafted name such as ``../../x.nc`` from writing outside *directory*.
    """
    target = Path(directory) / Path(uploaded_file.name).name
    with open(target, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return target


def prithvi_config_ui():
    """Render the Prithvi configuration widgets and resolve the input file paths.

    Returns:
        tuple: ``(config, uploaded_surface_files, uploaded_vertical_files,
        clim_surf_path, clim_vert_path, config_path, weights_path)``.

        - ``config``: dict with the two UI parameters.
        - ``uploaded_surface_files`` / ``uploaded_vertical_files``: whatever
          ``st.file_uploader(..., accept_multiple_files=True)`` returned.
        - ``clim_surf_path`` / ``clim_vert_path``: ``Path`` objects, or
          ``None`` when the default files are missing and the user has not
          yet uploaded both replacements.
        - ``config_path`` / ``weights_path``: either a saved copy of the
          upload or the bundled default path.

    Calls ``st.stop()`` when no config.yaml or weights file is available,
    either uploaded or as a default.

    NOTE(review): Streamlit reruns this script on every widget interaction,
    so each rerun writes the current uploads to new temporary files. Those
    files are never cleaned up here.
    """
    st.subheader("Prithvi Model Configuration")
    param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
    param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")
    config = {"param1": param1, "param2": param2}

    st.markdown("### Upload Data Files for Prithvi Model")
    uploaded_surface_files = st.file_uploader(
        "Upload Surface Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="surface_uploader",
    )
    uploaded_vertical_files = st.file_uploader(
        "Upload Vertical Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="vertical_uploader",
    )

    st.markdown("### Upload Climatology Files (If Missing)")
    default_clim_dir = Path("Prithvi-WxC/examples/climatology")
    surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
    vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
    surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
    vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"
    # NOTE(review): four files are checked, but only two can be uploaded
    # below. Confirm which files downstream loading actually needs.
    clim_files_exist = all(
        p.exists()
        for p in (surf_in_scal_path, vert_in_scal_path, surf_out_scal_path, vert_out_scal_path)
    )

    if clim_files_exist:
        clim_surf_path = surf_in_scal_path
        clim_vert_path = vert_in_scal_path
    else:
        st.warning("Climatology files are missing.")
        uploaded_clim_surface = st.file_uploader(
            "Upload Climatology Surface File",
            type=["nc", "netcdf"],
            key="clim_surface_uploader",
        )
        uploaded_clim_vertical = st.file_uploader(
            "Upload Climatology Vertical File",
            type=["nc", "netcdf"],
            key="clim_vertical_uploader",
        )
        if uploaded_clim_surface and uploaded_clim_vertical:
            clim_temp_dir = tempfile.mkdtemp()
            # Both files go into the same directory. Uploads sharing a base
            # name would overwrite each other.
            clim_surf_path = _save_upload_to_dir(uploaded_clim_surface, clim_temp_dir)
            clim_vert_path = _save_upload_to_dir(uploaded_clim_vertical, clim_temp_dir)
            st.success("Climatology files uploaded and saved.")
        else:
            st.warning("Please upload both climatology surface and vertical files.")
            clim_surf_path, clim_vert_path = None, None

    uploaded_config = st.file_uploader(
        "Upload config.yaml",
        type=["yaml", "yml"],
        key="config_uploader",
    )
    if uploaded_config:
        config_path = _save_upload_to_temp_file(uploaded_config, ".yaml")
        st.success("Config.yaml uploaded and saved.")
    else:
        config_path = Path("Prithvi-WxC/examples/config.yaml")
        if not config_path.exists():
            st.error("Default config.yaml not found. Please upload a config file.")
            st.stop()

    uploaded_weights = st.file_uploader(
        "Upload Model Weights (.pt)",
        type=["pt"],
        key="weights_uploader",
    )
    if uploaded_weights:
        weights_path = _save_upload_to_temp_file(uploaded_weights, ".pt")
        st.success("Model weights uploaded and saved.")
    else:
        weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
        if not weights_path.exists():
            st.error("Default model weights not found. Please upload model weights.")
            st.stop()

    return (
        config,
        uploaded_surface_files,
        uploaded_vertical_files,
        clim_surf_path,
        clim_vert_path,
        config_path,
        weights_path,
    )


def initialize_prithvi_model(config, config_path, weights_path, device):
    """Parse the Prithvi YAML config; building the model and scalers is not yet implemented.

    This is currently a placeholder. It reads ``config_path`` with
    ``yaml.safe_load``, so a malformed or unreadable file still raises here.
    It then returns ``None`` for every output.

    Args:
        config: UI parameters from ``prithvi_config_ui`` (not used yet).
        config_path: path to the YAML configuration file.
        weights_path: path to the model checkpoint (not used yet).
        device: target torch device (not used yet).

    Returns:
        tuple: ``(model, in_mu, in_sig, output_sig, static_mu, static_sig)``,
        all ``None`` until the loading logic is implemented.
    """
    # The config may be a user upload, so use safe_load, never yaml.load.
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)  # noqa: F841 -- used once the TODO below is implemented

    # TODO: load scalers and build the model, roughly:
    #   in_mu, in_sig = input_scalers(...)
    #   output_sig = output_scalers(...)
    #   static_mu, static_sig = static_input_scalers(...)
    #   from Prithvi import PrithviWxC
    #   model = PrithviWxC(**cfg["params"], ...)
    #   state_dict = torch.load(weights_path, map_location=device)
    #   model.load_state_dict(state_dict.get("model_state", state_dict), strict=True)
    #   model.to(device)
    model = None
    in_mu, in_sig, output_sig, static_mu, static_sig = None, None, None, None, None
    return model, in_mu, in_sig, output_sig, static_mu, static_sig


def prepare_prithvi_batch(uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, device):
    """Build an inference batch for Prithvi (not yet implemented).

    Currently a placeholder that ignores its arguments and returns ``None``.

    Args:
        uploaded_surface_files: surface data uploads from ``prithvi_config_ui``.
        uploaded_vertical_files: vertical data uploads from ``prithvi_config_ui``.
        clim_surf_path: surface climatology path, or ``None``.
        clim_vert_path: vertical climatology path, or ``None``.
        device: target torch device for the batch tensors.

    Returns:
        None for now.
    """
    # TODO: build the dataset and batch, roughly:
    #   dataset = Merra2Dataset(...)
    #   data = next(iter(dataset))
    #   batch = preproc([data], padding={...})
    #   batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
    return None