File size: 5,564 Bytes
100edb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import streamlit as st
import tempfile
from pathlib import Path
import torch
import traceback
import yaml
# from Prithvi import PrithviWxC, Merra2Dataset, input_scalers, output_scalers, static_input_scalers, preproc
def _save_upload_to_temp_file(uploaded_file, suffix):
    """Write a Streamlit upload to a newly created temp file and return its path.

    ``tempfile.NamedTemporaryFile(delete=False)`` creates the file itself,
    unlike the deprecated ``tempfile.mktemp`` which only returns a name and
    leaves a window in which another process can create that path first.
    The caller owns the returned file and is responsible for removing it.
    """
    with tempfile.NamedTemporaryFile(mode="wb", suffix=suffix, delete=False) as tmp:
        tmp.write(uploaded_file.getbuffer())
    return Path(tmp.name)


def _save_upload_in_dir(uploaded_file, directory):
    """Write a Streamlit upload into *directory* under its original base name.

    Only the final component of the client-supplied filename is used, so a
    name containing path separators (e.g. ``"../x.nc"``) cannot place the
    file outside *directory*.
    """
    target = Path(directory) / Path(uploaded_file.name).name
    target.write_bytes(uploaded_file.getbuffer())
    return target


def prithvi_config_ui():
    """Render the Prithvi configuration widgets and collect input file paths.

    Draws Streamlit widgets for two model parameters and file uploaders for
    surface/vertical data, climatology files (only when the defaults under
    ``Prithvi-WxC/examples/climatology`` are missing), ``config.yaml`` and
    model weights. Uploaded climatology/config/weights files are written to
    temporary files. When an upload is absent, the default path under
    ``Prithvi-WxC/examples`` is used instead.

    Returns:
        tuple: ``(config, uploaded_surface_files, uploaded_vertical_files,
        clim_surf_path, clim_vert_path, config_path, weights_path)``.

        - ``config`` is ``{"param1": ..., "param2": ...}`` from the widgets.
        - The two ``uploaded_*_files`` values are returned exactly as
          ``st.file_uploader(accept_multiple_files=True)`` produced them.
        - ``clim_surf_path`` / ``clim_vert_path`` are ``Path`` objects, or
          both ``None`` when the defaults are missing and the user has not
          uploaded both replacement files.
        - ``config_path`` / ``weights_path`` are ``Path`` objects.

    If neither an upload nor the default config/weights file is available,
    an error is shown and the script run is halted via ``st.stop()``.
    """
    st.subheader("Prithvi Model Configuration")
    param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
    param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")
    config = {"param1": param1, "param2": param2}

    st.markdown("### Upload Data Files for Prithvi Model")
    uploaded_surface_files = st.file_uploader(
        "Upload Surface Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="surface_uploader",
    )
    uploaded_vertical_files = st.file_uploader(
        "Upload Vertical Data Files",
        type=["nc", "netcdf"],
        accept_multiple_files=True,
        key="vertical_uploader",
    )

    st.markdown("### Upload Climatology Files (If Missing)")
    default_clim_dir = Path("Prithvi-WxC/examples/climatology")
    surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
    vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
    surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
    vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"
    # The defaults are only used when all four scaler files are present.
    clim_files_exist = all(
        path.exists()
        for path in (
            surf_in_scal_path,
            vert_in_scal_path,
            surf_out_scal_path,
            vert_out_scal_path,
        )
    )

    if not clim_files_exist:
        st.warning("Climatology files are missing.")
        uploaded_clim_surface = st.file_uploader(
            "Upload Climatology Surface File",
            type=["nc", "netcdf"],
            key="clim_surface_uploader",
        )
        uploaded_clim_vertical = st.file_uploader(
            "Upload Climatology Vertical File",
            type=["nc", "netcdf"],
            key="clim_vertical_uploader",
        )
        if uploaded_clim_surface and uploaded_clim_vertical:
            # NOTE(review): Streamlit re-executes the script on widget
            # interaction, so a fresh temp directory is created (and never
            # removed) on every run. Consider caching it in st.session_state.
            clim_temp_dir = tempfile.mkdtemp()
            clim_surf_path = _save_upload_in_dir(uploaded_clim_surface, clim_temp_dir)
            clim_vert_path = _save_upload_in_dir(uploaded_clim_vertical, clim_temp_dir)
            st.success("Climatology files uploaded and saved.")
        else:
            st.warning("Please upload both climatology surface and vertical files.")
            clim_surf_path, clim_vert_path = None, None
    else:
        # NOTE(review): the default branch returns the musigma (input-scaler)
        # files as the climatology paths. Confirm downstream code expects that.
        clim_surf_path = surf_in_scal_path
        clim_vert_path = vert_in_scal_path

    uploaded_config = st.file_uploader(
        "Upload config.yaml",
        type=["yaml", "yml"],
        key="config_uploader",
    )
    if uploaded_config:
        config_path = _save_upload_to_temp_file(uploaded_config, ".yaml")
        st.success("Config.yaml uploaded and saved.")
    else:
        config_path = Path("Prithvi-WxC/examples/config.yaml")
        if not config_path.exists():
            st.error("Default config.yaml not found. Please upload a config file.")
            st.stop()

    uploaded_weights = st.file_uploader(
        "Upload Model Weights (.pt)",
        type=["pt"],
        key="weights_uploader",
    )
    if uploaded_weights:
        weights_path = _save_upload_to_temp_file(uploaded_weights, ".pt")
        st.success("Model weights uploaded and saved.")
    else:
        weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
        if not weights_path.exists():
            st.error("Default model weights not found. Please upload model weights.")
            st.stop()

    return (
        config,
        uploaded_surface_files,
        uploaded_vertical_files,
        clim_surf_path,
        clim_vert_path,
        config_path,
        weights_path,
    )
def initialize_prithvi_model(config, config_path, weights_path, device):
    """Build the Prithvi WxC model and its scalers (currently a stub).

    Args:
        config: Presumably the parameter dict returned by
            ``prithvi_config_ui``. Not used yet.
        config_path: Path to the YAML model configuration. It is opened and
            parsed.
        weights_path: Path to the model checkpoint. Not used yet.
        device: Target device for the model. Not used yet.

    Returns:
        ``(model, in_mu, in_sig, output_sig, static_mu, static_sig)``. All
        six values are ``None`` until the loading logic in the TODO below is
        implemented.

    Raises:
        OSError: if ``config_path`` cannot be opened.
        UnicodeDecodeError: if the file is not valid UTF-8.
        yaml.YAMLError: if the file is not valid YAML.
    """
    # The config is parsed even though the stub does not consume it yet, so a
    # missing or malformed file still surfaces as an exception here. YAML is
    # UTF-8 text, so pin the encoding instead of using the locale default.
    with open(config_path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)  # noqa: F841 - used once the TODO is implemented

    # TODO: implement scaler and model loading, e.g. (pseudo-code):
    #   in_mu, in_sig = input_scalers(...)
    #   output_sig = output_scalers(...)
    #   static_mu, static_sig = static_input_scalers(...)
    #   from Prithvi import PrithviWxC
    #   model = PrithviWxC(**cfg["params"], ...)
    #   state_dict = torch.load(weights_path, map_location=device)
    #   model.load_state_dict(
    #       state_dict["model_state"] if "model_state" in state_dict else state_dict,
    #       strict=True,
    #   )
    #   model.to(device)
    model = None
    in_mu, in_sig, output_sig, static_mu, static_sig = None, None, None, None, None
    return model, in_mu, in_sig, output_sig, static_mu, static_sig
def prepare_prithvi_batch(
    uploaded_surface_files,
    uploaded_vertical_files,
    clim_surf_path,
    clim_vert_path,
    device,
):
    """Assemble an inference batch for Prithvi WxC (not implemented yet).

    The planned flow, per the TODO sketch below, is:

    1. Build a ``Merra2Dataset`` from the uploaded files and climatology paths.
    2. Take one sample and run ``preproc`` on it.
    3. Move every tensor in the result to *device*.

    For now every argument is ignored.

    Returns:
        None, as a placeholder until batch construction is written.
    """
    # TODO: implement, e.g. (pseudo-code):
    #   dataset = Merra2Dataset(...)
    #   sample = next(iter(dataset))
    #   batch = preproc([sample], padding={...})
    #   batch = {key: value.to(device) if isinstance(value, torch.Tensor) else value
    #            for key, value in batch.items()}
    batch = None
    return batch
|