Spaces:
Sleeping
Sleeping
import streamlit as st | |
import tempfile | |
from pathlib import Path | |
import torch | |
import traceback | |
import yaml | |
# from Prithvi import PrithviWxC, Merra2Dataset, input_scalers, output_scalers, static_input_scalers, preproc | |
def _save_upload(uploaded_file, suffix=None):
    """Write a Streamlit uploaded file to a new, uniquely named temp file.

    Uses ``tempfile.NamedTemporaryFile(delete=False)`` instead of the
    deprecated ``tempfile.mktemp``: the file is created atomically, so no
    other process can claim the name between choosing it and opening it.

    Args:
        uploaded_file: Object returned by ``st.file_uploader``; only its
            ``name`` attribute and ``getbuffer()`` method are used.
        suffix: Suffix for the temp file. Defaults to the suffix of the
            uploaded file's name (e.g. ``".nc"``).

    Returns:
        Path of the written file. It is NOT deleted automatically.
    """
    # The name comes from the client: keep only its final component so an
    # absolute or directory-bearing name can never redirect the write.
    original_name = Path(uploaded_file.name).name
    if suffix is None:
        suffix = Path(original_name).suffix
    with tempfile.NamedTemporaryFile(
        delete=False,
        prefix=f"{Path(original_name).stem}_",
        suffix=suffix,
    ) as tmp:
        tmp.write(uploaded_file.getbuffer())
    return Path(tmp.name)


def prithvi_config_ui():
    """Render the Prithvi configuration / upload widgets and resolve file paths.

    Uploaded climatology, config and weight files are copied to fresh temp
    files on every call. Those copies are never deleted here.

    Returns:
        tuple: ``(config, uploaded_surface_files, uploaded_vertical_files,
        clim_surf_path, clim_vert_path, config_path, weights_path)`` where

        - ``config`` is a dict with the ``param1`` / ``param2`` widget values.
        - ``uploaded_surface_files`` / ``uploaded_vertical_files`` are
          whatever ``st.file_uploader(..., accept_multiple_files=True)``
          returned.
        - ``clim_surf_path`` / ``clim_vert_path`` are the bundled
          ``musigma_*.nc`` paths when all default climatology files exist.
          Otherwise they are the saved uploads, or ``None`` if the user
          has not uploaded both.
        - ``config_path`` / ``weights_path`` are ``Path`` objects pointing
          at the uploaded copy or the bundled default.

    Calls ``st.stop()`` when no config (or weights) file is uploaded and
    the corresponding bundled default does not exist.
    """
    st.subheader("Prithvi Model Configuration")
    param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
    param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")
    config = {"param1": param1, "param2": param2}

    st.markdown("### Upload Data Files for Prithvi Model")
    uploaded_surface_files = st.file_uploader(
        "Upload Surface Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="surface_uploader",
    )
    uploaded_vertical_files = st.file_uploader(
        "Upload Vertical Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="vertical_uploader",
    )

    st.markdown("### Upload Climatology Files (If Missing)")
    default_clim_dir = Path("Prithvi-WxC/examples/climatology")
    surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
    vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
    surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
    vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"
    # All four defaults must be present to use them, although only the two
    # musigma (input-scaler) paths are returned below.
    clim_files_exist = all(
        p.exists()
        for p in (
            surf_in_scal_path,
            vert_in_scal_path,
            surf_out_scal_path,
            vert_out_scal_path,
        )
    )
    if clim_files_exist:
        clim_surf_path = surf_in_scal_path
        clim_vert_path = vert_in_scal_path
    else:
        st.warning("Climatology files are missing.")
        uploaded_clim_surface = st.file_uploader(
            "Upload Climatology Surface File",
            type=["nc", "netcdf"],
            key="clim_surface_uploader",
        )
        uploaded_clim_vertical = st.file_uploader(
            "Upload Climatology Vertical File",
            type=["nc", "netcdf"],
            key="clim_vertical_uploader",
        )
        if uploaded_clim_surface and uploaded_clim_vertical:
            # Each upload gets its own unique temp file, so two uploads that
            # share a filename can no longer overwrite each other.
            clim_surf_path = _save_upload(uploaded_clim_surface)
            clim_vert_path = _save_upload(uploaded_clim_vertical)
            st.success("Climatology files uploaded and saved.")
        else:
            st.warning("Please upload both climatology surface and vertical files.")
            clim_surf_path, clim_vert_path = None, None

    uploaded_config = st.file_uploader(
        "Upload config.yaml",
        type=["yaml", "yml"],
        key="config_uploader",
    )
    if uploaded_config:
        config_path = _save_upload(uploaded_config, suffix=".yaml")
        st.success("Config.yaml uploaded and saved.")
    else:
        config_path = Path("Prithvi-WxC/examples/config.yaml")
        if not config_path.exists():
            st.error("Default config.yaml not found. Please upload a config file.")
            st.stop()

    uploaded_weights = st.file_uploader(
        "Upload Model Weights (.pt)",
        type=["pt"],
        key="weights_uploader",
    )
    if uploaded_weights:
        weights_path = _save_upload(uploaded_weights, suffix=".pt")
        st.success("Model weights uploaded and saved.")
    else:
        weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
        if not weights_path.exists():
            st.error("Default model weights not found. Please upload model weights.")
            st.stop()

    return (
        config,
        uploaded_surface_files,
        uploaded_vertical_files,
        clim_surf_path,
        clim_vert_path,
        config_path,
        weights_path,
    )
def initialize_prithvi_model(config, config_path, weights_path, device):
    """Load the Prithvi model and its normalisation scalers (stub).

    Today only the YAML file at ``config_path`` is actually read. The parsed
    result is not used yet. ``config``, ``weights_path`` and ``device`` are
    accepted for the future implementation but are currently ignored.

    Returns:
        A 6-tuple ``(model, in_mu, in_sig, output_sig, static_mu,
        static_sig)``; every element is ``None`` until the real loading
        logic is implemented.

    Any error raised while opening or parsing ``config_path`` propagates
    to the caller.
    """
    with open(config_path, "r") as cfg_stream:
        parsed_cfg = yaml.safe_load(cfg_stream)  # noqa: F841 - unused until TODO below

    # TODO: replace the placeholders with real loading, roughly:
    #   in_mu, in_sig = input_scalers(...)
    #   output_sig = output_scalers(...)
    #   static_mu, static_sig = static_input_scalers(...)
    #   model = PrithviWxC(**parsed_cfg["params"], ...)
    #   state_dict = torch.load(weights_path, map_location=device)
    #   model.load_state_dict(state_dict.get("model_state", state_dict), strict=True)
    #   model.to(device)
    model = in_mu = in_sig = output_sig = static_mu = static_sig = None
    return model, in_mu, in_sig, output_sig, static_mu, static_sig
def prepare_prithvi_batch(uploaded_surface_files, uploaded_vertical_files, clim_surf_path, clim_vert_path, device):
    """Build an input batch for Prithvi inference (not implemented yet).

    All arguments are currently ignored. The intended flow, sketched by the
    original author, is:

      1. construct a ``Merra2Dataset`` (presumably from the uploaded files
         and climatology paths -- confirm when implementing);
      2. take one sample and collate it with ``preproc([...], padding=...)``;
      3. move every ``torch.Tensor`` in the resulting dict onto ``device``.

    Returns:
        ``None`` until the pipeline above is implemented.
    """
    placeholder_batch = None
    return placeholder_batch