# Climate-ML-Foundation-Models / prithvi_utils.py
# Author: qq1990 — commit 100edb4 ("init")
import streamlit as st
import tempfile
from pathlib import Path
import torch
import traceback
import yaml
# from Prithvi import PrithviWxC, Merra2Dataset, input_scalers, output_scalers, static_input_scalers, preproc
def _safe_upload_name(uploaded_file, fallback):
    """Return the basename of an upload's client-supplied name, or *fallback*.

    The name comes from the browser and is untrusted. Joining an absolute name
    onto a directory with ``/`` would discard the directory, and ``..`` would
    escape it, so only the final path component is kept.
    """
    name = Path(uploaded_file.name or "").name
    if name in ("", ".."):
        return fallback
    return name


def _write_upload(uploaded_file, destination):
    """Write the bytes of a Streamlit upload to *destination*; return it as a Path."""
    destination = Path(destination)
    with open(destination, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return destination


def _save_upload_to_tempfile(uploaded_file, suffix):
    """Save an upload to a new, uniquely named temporary file; return its Path.

    ``tempfile.mkstemp`` atomically creates the file. This replaces the
    deprecated ``tempfile.mktemp``, which only returns a name and leaves a
    window in which another process can create that path first. The file is
    not deleted automatically.
    """
    fd, tmp_name = tempfile.mkstemp(suffix=suffix)
    # Opening the descriptor hands ownership to the file object, which closes it.
    with open(fd, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return Path(tmp_name)


def prithvi_config_ui():
    """Render the Prithvi configuration / upload UI and resolve input file paths.

    Returns:
        tuple: ``(config, uploaded_surface_files, uploaded_vertical_files,
        clim_surf_path, clim_vert_path, config_path, weights_path)``.

        - ``config``: dict with the two UI parameters ``param1`` and ``param2``.
        - ``uploaded_surface_files`` / ``uploaded_vertical_files``: whatever
          ``st.file_uploader(..., accept_multiple_files=True)`` returned.
        - ``clim_surf_path`` / ``clim_vert_path``: ``Path`` objects, or
          ``None`` when the default climatology is missing and both
          replacement files have not been uploaded.
        - ``config_path`` / ``weights_path``: ``Path`` objects, pointing
          either to a saved upload or to the bundled default.

    Calls ``st.stop()`` when neither an uploaded nor a default config.yaml or
    weights file is available.
    """
    st.subheader("Prithvi Model Configuration")
    param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
    param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")
    config = {"param1": param1, "param2": param2}

    st.markdown("### Upload Data Files for Prithvi Model")
    uploaded_surface_files = st.file_uploader(
        "Upload Surface Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="surface_uploader",
    )
    uploaded_vertical_files = st.file_uploader(
        "Upload Vertical Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="vertical_uploader",
    )

    st.markdown("### Upload Climatology Files (If Missing)")
    default_clim_dir = Path("Prithvi-WxC/examples/climatology")
    surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
    vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
    surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
    vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"
    clim_files_exist = all(
        p.exists()
        for p in (
            surf_in_scal_path,
            vert_in_scal_path,
            surf_out_scal_path,
            vert_out_scal_path,
        )
    )

    if clim_files_exist:
        # NOTE(review): the "climatology" paths returned here are the musigma
        # input-scaler files. The anomaly_variance files are only checked for
        # existence and are never returned. Confirm this is what the
        # downstream loader expects.
        clim_surf_path = surf_in_scal_path
        clim_vert_path = vert_in_scal_path
    else:
        st.warning("Climatology files are missing.")
        uploaded_clim_surface = st.file_uploader(
            "Upload Climatology Surface File",
            type=["nc", "netcdf"],
            key="clim_surface_uploader",
        )
        uploaded_clim_vertical = st.file_uploader(
            "Upload Climatology Vertical File",
            type=["nc", "netcdf"],
            key="clim_vertical_uploader",
        )
        if uploaded_clim_surface and uploaded_clim_vertical:
            clim_temp_dir = Path(tempfile.mkdtemp())
            # Separate sub-directories keep the original file names while
            # preventing the vertical upload from overwriting the surface one
            # when both files share a name.
            surf_dir = clim_temp_dir / "surface"
            vert_dir = clim_temp_dir / "vertical"
            surf_dir.mkdir()
            vert_dir.mkdir()
            clim_surf_path = _write_upload(
                uploaded_clim_surface,
                surf_dir / _safe_upload_name(uploaded_clim_surface, "climatology_surface.nc"),
            )
            clim_vert_path = _write_upload(
                uploaded_clim_vertical,
                vert_dir / _safe_upload_name(uploaded_clim_vertical, "climatology_vertical.nc"),
            )
            st.success("Climatology files uploaded and saved.")
        else:
            st.warning("Please upload both climatology surface and vertical files.")
            clim_surf_path, clim_vert_path = None, None

    # NOTE(review): if this function runs on every Streamlit rerun, each rerun
    # writes fresh temp copies of the uploads that are never cleaned up.
    # Consider caching or cleanup.
    uploaded_config = st.file_uploader(
        "Upload config.yaml",
        type=["yaml", "yml"],
        key="config_uploader",
    )
    if uploaded_config:
        config_path = _save_upload_to_tempfile(uploaded_config, suffix=".yaml")
        st.success("Config.yaml uploaded and saved.")
    else:
        config_path = Path("Prithvi-WxC/examples/config.yaml")
        if not config_path.exists():
            st.error("Default config.yaml not found. Please upload a config file.")
            st.stop()

    uploaded_weights = st.file_uploader(
        "Upload Model Weights (.pt)",
        type=["pt"],
        key="weights_uploader",
    )
    if uploaded_weights:
        weights_path = _save_upload_to_tempfile(uploaded_weights, suffix=".pt")
        st.success("Model weights uploaded and saved.")
    else:
        weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
        if not weights_path.exists():
            st.error("Default model weights not found. Please upload model weights.")
            st.stop()

    return (
        config,
        uploaded_surface_files,
        uploaded_vertical_files,
        clim_surf_path,
        clim_vert_path,
        config_path,
        weights_path,
    )
def initialize_prithvi_model(config, config_path, weights_path, device):
    """Load the Prithvi YAML config and (eventually) build the model and scalers.

    This is currently a placeholder. The YAML file at *config_path* is parsed,
    so a missing or malformed file still raises here. The model and scalers
    are not yet constructed, and every returned value is ``None``.

    Args:
        config: UI parameter dict (presumably the one returned by
            ``prithvi_config_ui``); not used yet.
        config_path: path (``str`` or ``Path``) to the model's YAML config.
        weights_path: path to the ``.pt`` checkpoint; not used yet.
        device: target device for the model; not used yet.

    Returns:
        tuple: ``(model, in_mu, in_sig, output_sig, static_mu, static_sig)``,
        all ``None`` until the TODO below is implemented.

    Raises:
        OSError: if *config_path* cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
    """
    # Open in binary mode so PyYAML detects the encoding (UTF-8 / UTF-16 BOM)
    # itself, instead of decoding with the platform's locale encoding.
    with open(config_path, "rb") as f:
        cfg = yaml.safe_load(f)

    # TODO: implement model and scaler construction from `cfg`. Intended steps:
    #   in_mu, in_sig = input_scalers(...)
    #   output_sig = output_scalers(...)
    #   static_mu, static_sig = static_input_scalers(...)
    #   from Prithvi import PrithviWxC
    #   model = PrithviWxC(**cfg["params"], ...)
    #   state_dict = torch.load(weights_path, map_location=device)
    #   model.load_state_dict(state_dict.get("model_state", state_dict), strict=True)
    #   model.to(device)
    del cfg  # parsed only to validate the file until the TODO above lands

    # Placeholder results until the real loading logic is implemented.
    model = None
    in_mu = in_sig = output_sig = static_mu = static_sig = None
    return model, in_mu, in_sig, output_sig, static_mu, static_sig
def prepare_prithvi_batch(
    uploaded_surface_files,
    uploaded_vertical_files,
    clim_surf_path,
    clim_vert_path,
    device,
):
    """Build the input batch for Prithvi inference (not implemented yet).

    The intent is to:

    1. Wrap the uploaded surface/vertical files and climatology paths in a
       dataset.
    2. Draw one sample and pre-process it into a batch.
    3. Move the batch's tensors to *device*.

    For now none of the arguments are read, and the function always returns
    ``None``.
    """
    # TODO: implement, roughly:
    #   dataset = Merra2Dataset(...)
    #   data = next(iter(dataset))
    #   batch = preproc([data], padding={...})
    #   batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v)
    #            for k, v in batch.items()}
    batch = None  # placeholder until the dataset pipeline is wired in
    return batch