Spaces:
Build error
Build error
# import streamlit as st | |
# x = st.slider('Select a value') | |
# st.write(x, 'squared is', x * x) | |
import streamlit as st | |
import random | |
from pathlib import Path | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import torch | |
from huggingface_hub import hf_hub_download, snapshot_download | |
import tarfile | |
import os | |
import sys | |
import yaml | |
# --- Page header -----------------------------------------------------------
st.title("PrithviWxC Model Inference")
st.write("Setting up environment...")

# --- Torch backends --------------------------------------------------------
# CUDA availability is queried once and reused for the backend flags, the
# seeding and the device selection below.
_SEED = 42
_has_cuda = torch.cuda.is_available()

torch.jit.enable_onednn_fusion(True)
# NOTE(review): indentation was lost in the captured source; the cuDNN flags
# are assumed to belong inside the CUDA branch -- confirm against the repo.
if _has_cuda:
    st.write(f"Using device: {torch.cuda.get_device_name()}")
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

# --- Seeding ---------------------------------------------------------------
# Seed every random number generator the script touches with the same value.
random.seed(_SEED)
if _has_cuda:
    torch.cuda.manual_seed(_SEED)
torch.manual_seed(_SEED)
np.random.seed(_SEED)

# --- Device ----------------------------------------------------------------
device = torch.device("cuda" if _has_cuda else "cpu")
st.write(f"Using device: {device}")
# Download, unpack and import the PrithviWxC Python package.
st.write("Downloading and setting up PrithviWxC module...")
# No force_download here: Streamlit re-executes this script on every widget
# interaction, and forcing the download would re-fetch the archive each time.
# Without it, hf_hub_download reuses the copy already present in local_dir.
module_tar_path = hf_hub_download(
    repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
    filename="PrithviWxC.tar.gz",
    local_dir=".",
)

# Presumably the archive unpacks into ./PrithviWxC (the directory that is put
# on sys.path below). If it does not, the guard is always true and the
# archive is simply extracted on every run, as before.
if not Path("./PrithviWxC").is_dir():
    # NOTE(review): extractall() trusts the member paths stored in the
    # archive. Only use it with archives from a trusted repository.
    with tarfile.open(module_tar_path, "r:gz") as tar:
        tar.extractall(path=".")

# Add the module path to sys.path exactly once; the interpreter (and hence
# sys.path) outlives individual script reruns.
_module_path = os.path.abspath("./PrithviWxC")
if _module_path not in sys.path:
    sys.path.append(_module_path)

# Now import the module (only possible after the path is set up above).
from PrithviWxC.dataloaders.merra2 import (
    Merra2Dataset,
    input_scalers,
    output_scalers,
    preproc,
    static_input_scalers,
)
from PrithviWxC.model import PrithviWxC

# Report success only after the imports have actually succeeded.
st.write("PrithviWxC module imported successfully.")
# Variable names, vertical levels and padding handed to Merra2Dataset, the
# scaler loaders and preproc further down in this script.

# Surface variable names. Order matters: the plotting code selects a channel
# by position.
surface_vars = (
    "EFLUX GWETROOT HFLUX LAI LWGAB LWGEM LWTUP PS QV2M SLP "
    "SWGNT SWTNT T2M TQI TQL TQV TS U10M V10M Z0M"
).split()

# Static surface variable names.
static_surface_vars = "FRACI FRLAND FROCEAN PHIS".split()

# Vertical variable names.
vertical_vars = "CLOUD H OMEGA PL QI QL QV T U V".split()

# Level identifiers, stored as floats.
levels = [
    float(level)
    for level in (34, 39, 41, 43, 44, 45, 48, 51, 53, 56, 63, 68, 71, 72)
]

# Per-dimension [before, after] padding passed to preproc.
padding = dict(level=[0, 0], lat=[0, -1], lon=[0, 0])
st.write("Setting up dataset parameters...")

# The two widgets below let the user pick the lead time and the input time
# offset; both are bounded integer inputs.
lead_time = st.number_input(
    "Lead Time (hours)",
    min_value=1,
    max_value=24,
    value=6,
)
input_time = st.number_input(
    "Input Time Difference (hours)",
    min_value=-24,
    max_value=0,
    value=-6,
)

# The dataset takes lists; changing these two lists changes the task.
lead_times, input_times = [lead_time], [input_time]
# Time window handed to the dataset.
time_range = ("2020-01-01T00:00:00", "2020-01-01T23:59:59")

st.write("Downloading data files...")

# Local directories the dataset reads from. Surface and vertical files share
# a directory, as do the two climatology sets.
surf_dir = Path("./merra-2")
vert_dir = Path("./merra-2")
surf_clim_dir = Path("./climatology")
vert_clim_dir = Path("./climatology")

# One snapshot_download call with all four patterns replaces four separate
# calls against the same repository. force_download is deliberately left at
# its default: Streamlit reruns this script on every interaction, and files
# already present in local_dir do not need to be fetched again.
snapshot_download(
    repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
    allow_patterns=[
        "merra-2/MERRA2_sfc_2020010[1].nc",
        "merra-2/MERRA_pres_2020010[1].nc",
        "climatology/climate_surface_doy00[1]*.nc",
        "climatology/climate_vertical_doy00[1]*.nc",
    ],
    local_dir=".",
)
st.write("Setting positional encoding...")
positional_encoding = "fourier"

st.write("Initializing dataset...")
# Build the dataset from the downloaded files and the user-selected
# lead/input times.
dataset = Merra2Dataset(
    time_range=time_range,
    lead_times=lead_times,
    input_times=input_times,
    data_path_surface=surf_dir,
    data_path_vertical=vert_dir,
    climatology_path_surface=surf_clim_dir,
    climatology_path_vertical=vert_clim_dir,
    surface_vars=surface_vars,
    static_surface_vars=static_surface_vars,
    vertical_vars=vertical_vars,
    levels=levels,
    positional_encoding=positional_encoding,
)

# Whether the dataset is empty depends on the user's widget choices, so this
# is input validation rather than an internal invariant. An assert would be
# stripped under `python -O`; report the problem in the UI and halt this
# script run instead.
if len(dataset) == 0:
    st.error("There doesn't seem to be any valid data.")
    st.stop()
st.write("Loading scalers...")

# Local paths of the four scaler files.
surf_in_scal_path = Path("./climatology/musigma_surface.nc")
vert_in_scal_path = Path("./climatology/musigma_vertical.nc")
surf_out_scal_path = Path("./climatology/anomaly_variance_surface.nc")
vert_out_scal_path = Path("./climatology/anomaly_variance_vertical.nc")

# Fetch each file into ./climatology. force_download is left at its default
# so that Streamlit reruns reuse the files already on disk.
for _scaler_path in (
    surf_in_scal_path,
    vert_in_scal_path,
    surf_out_scal_path,
    vert_out_scal_path,
):
    hf_hub_download(
        repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
        filename=f"climatology/{_scaler_path.name}",
        local_dir=".",
    )

# Input scalers (mu, sigma) for the surface and vertical variables.
in_mu, in_sig = input_scalers(
    surface_vars,
    vertical_vars,
    levels,
    surf_in_scal_path,
    vert_in_scal_path,
)

# Output scalers; the square root is taken when they are passed to the model.
output_sig = output_scalers(
    surface_vars,
    vertical_vars,
    levels,
    surf_out_scal_path,
    vert_out_scal_path,
)

# Scalers (mu, sigma) for the static surface variables.
static_mu, static_sig = static_input_scalers(
    surf_in_scal_path,
    static_surface_vars,
)
st.write("Setting up model...")

# Model options that are fixed by this app rather than read from config.yaml.
residual = "climate"
masking_mode = "local"
decoder_shifting = True
masking_ratio = 0.99

weights_path = Path("./weights/prithvi.wxc.2300m.v1.pt")


@st.cache_resource
def _load_model():
    """Build PrithviWxC from config.yaml, load its weights and move it to `device`.

    Streamlit re-executes the whole script on every widget interaction,
    including the "Run Inference" button. Caching the result means the
    config/weights download, the model construction and the state-dict load
    happen once per server process instead of once per rerun.

    Returns:
        The PrithviWxC instance with weights loaded, on `device`.
    """
    # Load model config. No force_download: an existing local copy is reused.
    hf_hub_download(
        repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
        filename="config.yaml",
        local_dir=".",
    )
    with open("./config.yaml", "r", encoding="utf-8") as f:
        params = yaml.safe_load(f)["params"]

    net = PrithviWxC(
        in_channels=params["in_channels"],
        input_size_time=params["input_size_time"],
        in_channels_static=params["in_channels_static"],
        input_scalers_mu=in_mu,
        input_scalers_sigma=in_sig,
        input_scalers_epsilon=params["input_scalers_epsilon"],
        static_input_scalers_mu=static_mu,
        static_input_scalers_sigma=static_sig,
        static_input_scalers_epsilon=params["static_input_scalers_epsilon"],
        # The scaler files are named "anomaly_variance_*"; the square root is
        # taken before handing them to the model.
        output_scalers=output_sig**0.5,
        n_lats_px=params["n_lats_px"],
        n_lons_px=params["n_lons_px"],
        patch_size_px=params["patch_size_px"],
        mask_unit_size_px=params["mask_unit_size_px"],
        mask_ratio_inputs=masking_ratio,
        embed_dim=params["embed_dim"],
        n_blocks_encoder=params["n_blocks_encoder"],
        n_blocks_decoder=params["n_blocks_decoder"],
        mlp_multiplier=params["mlp_multiplier"],
        n_heads=params["n_heads"],
        dropout=params["dropout"],
        drop_path=params["drop_path"],
        parameter_dropout=params["parameter_dropout"],
        residual=residual,
        masking_mode=masking_mode,
        decoder_shifting=decoder_shifting,
        positional_encoding=positional_encoding,
        checkpoint_encoder=[],
        checkpoint_decoder=[],
    )

    # Fetch the checkpoint into ./weights, again reusing an existing copy.
    hf_hub_download(
        repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
        filename=weights_path.name,
        local_dir="./weights",
    )
    state_dict = torch.load(weights_path, map_location=device)
    # Some checkpoints wrap the weights under a "model_state" key.
    if "model_state" in state_dict:
        state_dict = state_dict["model_state"]
    net.load_state_dict(state_dict, strict=True)
    return net.to(device)


st.write("Loading model weights...")
model = _load_model()
st.write("Model loaded and ready.")
# Run one forward pass on the first dataset sample and plot one output channel.
if st.button("Run Inference"):
    st.write("Running inference...")

    # Take the first sample and turn it into a padded batch of size one.
    data = next(iter(dataset))
    batch = preproc([data], padding)
    # Move tensor entries to the target device; other entries stay as they are.
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.to(device)

    model.eval()
    with torch.no_grad():
        out = model(batch)
    st.write("Inference completed. Generating plot...")

    # Look the channel up by name instead of hard-coding index 12.
    # NOTE(review): assumes the output channels start with the surface
    # variables in `surface_vars` order -- confirm against the model.
    t2m_channel = surface_vars.index("T2M")
    t2m = out[0, t2m_channel].cpu().numpy()

    # Evenly spaced coordinate axes sized from the output's last two dims.
    lat = np.linspace(-90, 90, out.shape[-2])
    lon = np.linspace(-180, 180, out.shape[-1])
    X, Y = np.meshgrid(lon, lat)

    fig, ax = plt.subplots()
    cs = ax.contourf(X, Y, t2m, 100)
    ax.set_aspect("equal")
    # Attach the colorbar to this figure explicitly rather than relying on
    # pyplot's implicit "current figure".
    fig.colorbar(cs, ax=ax)
    st.pyplot(fig)
    # pyplot keeps every figure alive until it is closed; without this each
    # button press would leak one figure in the long-running server process.
    plt.close(fig)
    st.write("Plot generated.")