|
import streamlit as st |
|
import tempfile |
|
from pathlib import Path |
|
import torch |
|
import traceback |
|
import yaml |
|
|
|
|
|
def _save_upload_to_temp_file(uploaded_file, suffix=""):
    """Write a Streamlit upload to a new, uniquely named temp file.

    Uses ``NamedTemporaryFile(delete=False)`` rather than the deprecated
    ``tempfile.mktemp``: the file is created atomically, so no other process
    can claim the name between choosing it and opening it.

    Args:
        uploaded_file: Object returned by ``st.file_uploader``; only its
            ``getbuffer()`` method is used.
        suffix: Suffix for the temp file name (e.g. ``".yaml"``).

    Returns:
        Path of the written file. It is not deleted automatically.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(uploaded_file.getbuffer())
    return Path(tmp.name)


def _save_upload_keeping_name(uploaded_file):
    """Write a Streamlit upload into its own fresh temp dir, keeping its base name.

    Every upload gets a separate directory, so two uploads that share a file
    name cannot overwrite each other. Only the final component of the
    client-supplied name is used, so the file cannot land outside that dir.

    Args:
        uploaded_file: Object returned by ``st.file_uploader``; its ``name``
            attribute and ``getbuffer()`` method are used.

    Returns:
        Path of the written file. It is not deleted automatically.
    """
    target = Path(tempfile.mkdtemp()) / Path(uploaded_file.name).name
    with open(target, "wb") as f:
        f.write(uploaded_file.getbuffer())
    return target


def prithvi_config_ui():
    """Render the Prithvi configuration/upload widgets and resolve input paths.

    Uploaded climatology, config and weight files are copied to temp files so
    later code can open them by path. When no config or weights file is
    uploaded and the bundled default is missing, an error is shown and the
    Streamlit script run is halted with ``st.stop()``.

    Returns:
        A 7-tuple ``(config, uploaded_surface_files, uploaded_vertical_files,
        clim_surf_path, clim_vert_path, config_path, weights_path)``:

        - ``config``: ``{"param1": ..., "param2": ...}`` from the two inputs.
        - ``uploaded_surface_files`` / ``uploaded_vertical_files``: whatever
          ``st.file_uploader(..., accept_multiple_files=True)`` returned.
        - ``clim_surf_path`` / ``clim_vert_path``: the bundled musigma files
          when all four default scaler files exist; otherwise temp copies of
          the uploaded climatology files, or ``None`` if either is missing.
        - ``config_path``: temp copy of the uploaded YAML, else the default.
        - ``weights_path``: temp copy of the uploaded ``.pt``, else the default.

    NOTE(review): Streamlit re-executes this on every widget interaction, so
    each rerun with a file attached writes a fresh temp copy (never cleaned
    up). For large weight files consider caching the saved path in
    ``st.session_state``.
    """
    st.subheader("Prithvi Model Configuration")
    param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
    param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")
    config = {"param1": param1, "param2": param2}

    st.markdown("### Upload Data Files for Prithvi Model")
    uploaded_surface_files = st.file_uploader(
        "Upload Surface Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="surface_uploader",
    )
    uploaded_vertical_files = st.file_uploader(
        "Upload Vertical Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="vertical_uploader",
    )

    st.markdown("### Upload Climatology Files (If Missing)")
    default_clim_dir = Path("Prithvi-WxC/examples/climatology")
    surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
    vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
    surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
    vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"
    # NOTE(review): all four default files must exist, yet only the two input
    # (musigma) paths are returned and the upload fallback collects just two
    # files — confirm the out-scaler files are not needed downstream.
    clim_files_exist = all(
        path.exists()
        for path in (
            surf_in_scal_path,
            vert_in_scal_path,
            surf_out_scal_path,
            vert_out_scal_path,
        )
    )

    if clim_files_exist:
        clim_surf_path = surf_in_scal_path
        clim_vert_path = vert_in_scal_path
    else:
        st.warning("Climatology files are missing.")
        uploaded_clim_surface = st.file_uploader(
            "Upload Climatology Surface File",
            type=["nc", "netcdf"],
            key="clim_surface_uploader",
        )
        uploaded_clim_vertical = st.file_uploader(
            "Upload Climatology Vertical File",
            type=["nc", "netcdf"],
            key="clim_vertical_uploader",
        )
        if uploaded_clim_surface and uploaded_clim_vertical:
            # Separate directories: same-named uploads must not clobber each other.
            clim_surf_path = _save_upload_keeping_name(uploaded_clim_surface)
            clim_vert_path = _save_upload_keeping_name(uploaded_clim_vertical)
            st.success("Climatology files uploaded and saved.")
        else:
            st.warning("Please upload both climatology surface and vertical files.")
            clim_surf_path, clim_vert_path = None, None

    uploaded_config = st.file_uploader(
        "Upload config.yaml",
        type=["yaml", "yml"],
        key="config_uploader",
    )
    if uploaded_config:
        config_path = _save_upload_to_temp_file(uploaded_config, suffix=".yaml")
        st.success("Config.yaml uploaded and saved.")
    else:
        config_path = Path("Prithvi-WxC/examples/config.yaml")
        if not config_path.exists():
            st.error("Default config.yaml not found. Please upload a config file.")
            st.stop()

    uploaded_weights = st.file_uploader(
        "Upload Model Weights (.pt)",
        type=["pt"],
        key="weights_uploader",
    )
    if uploaded_weights:
        weights_path = _save_upload_to_temp_file(uploaded_weights, suffix=".pt")
        st.success("Model weights uploaded and saved.")
    else:
        weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
        if not weights_path.exists():
            st.error("Default model weights not found. Please upload model weights.")
            st.stop()

    return (
        config,
        uploaded_surface_files,
        uploaded_vertical_files,
        clim_surf_path,
        clim_vert_path,
        config_path,
        weights_path,
    )
|
|
|
|
|
def initialize_prithvi_model(config, config_path, weights_path, device):
    """Parse the Prithvi YAML config and return placeholder model artifacts.

    The file at ``config_path`` is read with ``yaml.safe_load``, so a missing
    or malformed config raises here. Model construction is not implemented
    yet: the parsed config is discarded, and the model and every
    normalization statistic come back as ``None``.

    Args:
        config: Parameter dict from the UI; currently unused.
        config_path: Path of the YAML config file to parse.
        weights_path: Path of the model weights; currently unused.
        device: Target device; currently unused.

    Returns:
        ``(model, in_mu, in_sig, output_sig, static_mu, static_sig)`` — all
        six entries are ``None`` for now.
    """
    with open(config_path) as config_file:
        # TODO: build the model from this config and load weights_path onto device.
        parsed_cfg = yaml.safe_load(config_file)  # noqa: F841 - intentionally unused for now

    placeholders = (None,) * 6
    return placeholders
|
|
|
|
|
def prepare_prithvi_batch(
    uploaded_surface_files,
    uploaded_vertical_files,
    clim_surf_path,
    clim_vert_path,
    device,
):
    """Assemble an input batch for the Prithvi model (not implemented yet).

    None of the arguments are read by the current body; they define the
    intended interface.

    Args:
        uploaded_surface_files: Presumably the surface-data uploads collected
            by the UI.
        uploaded_vertical_files: Presumably the vertical-data uploads
            collected by the UI.
        clim_surf_path: Surface climatology file path, or ``None``.
        clim_vert_path: Vertical climatology file path, or ``None``.
        device: Target device for the batch.

    Returns:
        Always ``None`` until batch assembly is implemented.
    """
    # TODO: assemble the batch from the uploaded data and climatology files.
    batch = None
    return batch
|
|