# Prithvi-Climate / app.py
# qq1990's picture
# Update app.py
# 8a18a38 verified
# import streamlit as st
# x = st.slider('Select a value')
# st.write(x, 'squared is', x * x)
import streamlit as st
import random
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
from huggingface_hub import hf_hub_download, snapshot_download
import tarfile
import os
import sys
import yaml
st.title("PrithviWxC Model Inference")
st.write("Setting up environment...")

# Torch backend configuration. The cuDNN switches are only touched when a
# CUDA device is actually present.
torch.jit.enable_onednn_fusion(True)
cuda_ready = torch.cuda.is_available()
if cuda_ready:
    st.write(f"Using device: {torch.cuda.get_device_name()}")
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

# Seed every random number generator the script relies on with the same
# fixed value: Python's `random`, torch (CPU and, if present, CUDA), NumPy.
random.seed(42)
if cuda_ready:
    torch.cuda.manual_seed(42)
torch.manual_seed(42)
np.random.seed(42)

# Device that the model and the input batch are moved to later on.
device = torch.device("cuda" if cuda_ready else "cpu")
st.write(f"Using device: {device}")
# Fetch the PrithviWxC source package from the Hugging Face Hub, unpack it
# next to this script and make it importable.
st.write("Downloading and setting up PrithviWxC module...")
# No force_download: Streamlit re-executes this script on every widget
# interaction, so forcing a fresh download would re-fetch the archive on
# each rerun. Without it the Hub client can reuse the copy already in ".".
module_tar_path = hf_hub_download(
    repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
    filename="PrithviWxC.tar.gz",
    local_dir=".",
)
with tarfile.open(module_tar_path, "r:gz") as tar:
    if hasattr(tarfile, "data_filter"):
        # The archive comes from the network; the "data" filter rejects
        # members that would escape the destination directory.
        tar.extractall(path=".", filter="data")
    else:
        # Interpreter predates extraction filters; fall back to the plain
        # extraction the script originally performed.
        tar.extractall(path=".")

# Add the module path to sys.path (once: reruns share the same process, so
# an unconditional append would grow sys.path on every interaction).
module_dir = os.path.abspath("./PrithviWxC")
if module_dir not in sys.path:
    sys.path.append(module_dir)

# These imports have to live here rather than at the top of the file because
# the package only exists on disk after the extraction above.
from PrithviWxC.dataloaders.merra2 import (
    Merra2Dataset,
    input_scalers,
    output_scalers,
    static_input_scalers,
    preproc,
)
from PrithviWxC.model import PrithviWxC

# Report success only after the imports have actually succeeded.
st.write("PrithviWxC module imported successfully.")
# Variable names, vertical levels and padding handed to the dataset, the
# scaler helpers and the batch pre-processing further down.

# Order matters: the inference section reads channel 12 of the model output
# into a variable named `t2m`, which matches the position of "T2M" here.
surface_vars = [
    "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB",
    "LWGEM", "LWTUP", "PS", "QV2M", "SLP",
    "SWGNT", "SWTNT", "T2M", "TQI", "TQL",
    "TQV", "TS", "U10M", "V10M", "Z0M",
]
static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
vertical_vars = [
    "CLOUD", "H", "OMEGA", "PL", "QI",
    "QL", "QV", "T", "U", "V",
]

# Level identifiers are stored as floats.
levels = [
    float(level)
    for level in (34, 39, 41, 43, 44, 45, 48, 51, 53, 56, 63, 68, 71, 72)
]

# Per-dimension [before, after] padding passed to `preproc`.
padding = dict(level=[0, 0], lat=[0, -1], lon=[0, 0])
st.write("Setting up dataset parameters...")

# Forecast lead time and input-time offset, both picked by the user; the
# widget labels state the unit as hours.
lead_time = st.number_input(
    "Lead Time (hours)",
    min_value=1,
    max_value=24,
    value=6,
)
input_time = st.number_input(
    "Input Time Difference (hours)",
    min_value=-24,
    max_value=0,
    value=-6,
)

# The dataset takes lists; editing these two lists changes the task.
lead_times = [lead_time]
input_times = [input_time]

# Time window of data handed to the dataset (a single day).
time_range = ("2020-01-01T00:00:00", "2020-01-01T23:59:59")
st.write("Downloading data files...")

# Local directories the dataset reads from. Surface and vertical files share
# a directory, both for the MERRA-2 inputs and for the climatology.
surf_dir = Path("./merra-2")
vert_dir = Path("./merra-2")
surf_clim_dir = Path("./climatology")
vert_clim_dir = Path("./climatology")

# One snapshot call with all four patterns instead of four calls against the
# same repository. force_download is intentionally not set: Streamlit reruns
# this script on every interaction, and forcing the download would re-fetch
# every data file each time instead of reusing what is already in ".".
snapshot_download(
    repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
    allow_patterns=[
        "merra-2/MERRA2_sfc_2020010[1].nc",
        "merra-2/MERRA_pres_2020010[1].nc",
        "climatology/climate_surface_doy00[1]*.nc",
        "climatology/climate_vertical_doy00[1]*.nc",
    ],
    local_dir=".",
)
st.write("Setting positional encoding...")
# Shared by the dataset and the model constructor.
positional_encoding = "fourier"

st.write("Initializing dataset...")
dataset = Merra2Dataset(
    time_range=time_range,
    lead_times=lead_times,
    input_times=input_times,
    data_path_surface=surf_dir,
    data_path_vertical=vert_dir,
    climatology_path_surface=surf_clim_dir,
    climatology_path_vertical=vert_clim_dir,
    surface_vars=surface_vars,
    static_surface_vars=static_surface_vars,
    vertical_vars=vertical_vars,
    levels=levels,
    positional_encoding=positional_encoding,
)

# Explicit check instead of `assert`: assertions are removed under `python -O`,
# and an empty dataset is a runtime condition (it depends on the user-selected
# lead/input times), so report it in the app and halt this run cleanly.
if len(dataset) == 0:
    st.error("There doesn't seem to be any valid data.")
    st.stop()
st.write("Loading scalers...")

# Local paths of the four scaler files; each is downloaded from the
# "climatology/" folder of the Hub repository under the same file name.
surf_in_scal_path = Path("./climatology/musigma_surface.nc")
vert_in_scal_path = Path("./climatology/musigma_vertical.nc")
surf_out_scal_path = Path("./climatology/anomaly_variance_surface.nc")
vert_out_scal_path = Path("./climatology/anomaly_variance_vertical.nc")

# force_download is intentionally not set: Streamlit reruns this script on
# every interaction, and forcing the download would re-fetch all four files
# each time instead of reusing the copies already present in ".".
for scaler_path in (
    surf_in_scal_path,
    vert_in_scal_path,
    surf_out_scal_path,
    vert_out_scal_path,
):
    hf_hub_download(
        repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
        filename=f"climatology/{scaler_path.name}",
        local_dir=".",
    )

# Input scalers (mu / sigma pair) for the surface and vertical variables.
in_mu, in_sig = input_scalers(
    surface_vars,
    vertical_vars,
    levels,
    surf_in_scal_path,
    vert_in_scal_path,
)

# Output scaler, read from the anomaly-variance files.
output_sig = output_scalers(
    surface_vars,
    vertical_vars,
    levels,
    surf_out_scal_path,
    vert_out_scal_path,
)

# Scalers for the static surface variables, read from the surface input file.
static_mu, static_sig = static_input_scalers(
    surf_in_scal_path,
    static_surface_vars,
)
st.write("Setting up model...")

# Fixed model options that do not come from config.yaml.
residual = "climate"
masking_mode = "local"
decoder_shifting = True
masking_ratio = 0.99

# Local location of the checkpoint; it is downloaded into ./weights.
weights_path = Path("./weights/prithvi.wxc.2300m.v1.pt")


@st.cache_resource(show_spinner=False)
def _load_model(_in_mu, _in_sig, _static_mu, _static_sig, _output_sig, _device):
    """Build PrithviWxC, load its checkpoint and move it to the device.

    Streamlit re-executes the whole script on every widget interaction,
    including the click that starts inference. Caching this function means
    the config download, model construction and checkpoint load happen once
    per server process instead of on every rerun.

    All parameters are prefixed with an underscore so that Streamlit does
    not try to hash them when computing the cache key.

    Args:
        _in_mu: Input scaler means.
        _in_sig: Input scaler sigmas.
        _static_mu: Static-input scaler means.
        _static_sig: Static-input scaler sigmas.
        _output_sig: Output scaler; its square root is passed to the model.
        _device: torch device the weights are mapped to and the model moved to.

    Returns:
        The PrithviWxC model with weights loaded, on ``_device``.
    """
    # Load model config. force_download is not set so an already downloaded
    # file can be reused.
    hf_hub_download(
        repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
        filename="config.yaml",
        local_dir=".",
    )
    with open("./config.yaml", "r") as f:
        params = yaml.safe_load(f)["params"]

    net = PrithviWxC(
        in_channels=params["in_channels"],
        input_size_time=params["input_size_time"],
        in_channels_static=params["in_channels_static"],
        input_scalers_mu=_in_mu,
        input_scalers_sigma=_in_sig,
        input_scalers_epsilon=params["input_scalers_epsilon"],
        static_input_scalers_mu=_static_mu,
        static_input_scalers_sigma=_static_sig,
        static_input_scalers_epsilon=params["static_input_scalers_epsilon"],
        output_scalers=_output_sig**0.5,
        n_lats_px=params["n_lats_px"],
        n_lons_px=params["n_lons_px"],
        patch_size_px=params["patch_size_px"],
        mask_unit_size_px=params["mask_unit_size_px"],
        mask_ratio_inputs=masking_ratio,
        embed_dim=params["embed_dim"],
        n_blocks_encoder=params["n_blocks_encoder"],
        n_blocks_decoder=params["n_blocks_decoder"],
        mlp_multiplier=params["mlp_multiplier"],
        n_heads=params["n_heads"],
        dropout=params["dropout"],
        drop_path=params["drop_path"],
        parameter_dropout=params["parameter_dropout"],
        residual=residual,
        masking_mode=masking_mode,
        decoder_shifting=decoder_shifting,
        positional_encoding=positional_encoding,
        checkpoint_encoder=[],
        checkpoint_decoder=[],
    )

    hf_hub_download(
        repo_id="Prithvi-WxC/prithvi.wxc.2300m.v1",
        filename=weights_path.name,
        local_dir="./weights",
    )
    # NOTE(review): torch.load unpickles the file; this is only acceptable
    # because the checkpoint comes from a trusted repository. Consider an
    # explicit `weights_only=` argument once the checkpoint layout is confirmed.
    state_dict = torch.load(weights_path, map_location=_device)
    # Some checkpoints wrap the weights under a "model_state" key.
    if "model_state" in state_dict:
        state_dict = state_dict["model_state"]
    net.load_state_dict(state_dict, strict=True)
    return net.to(_device)


st.write("Loading model weights...")
model = _load_model(in_mu, in_sig, static_mu, static_sig, output_sig, device)
st.write("Model loaded and ready.")
# Run a single forward pass on the first dataset sample and plot one output
# channel when the user presses the button.
if st.button("Run Inference"):
    st.write("Running inference...")

    # Take the first sample and turn it into a (single-element) batch.
    data = next(iter(dataset))
    batch = preproc([data], padding)
    # Move tensors to the model's device; non-tensor entries are left as is.
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.to(device)

    model.eval()
    with torch.no_grad():
        out = model(batch)
    st.write("Inference completed. Generating plot...")

    # First batch element, channel 12.
    # NOTE(review): presumably 2 m temperature, assuming output channels
    # follow the order of `surface_vars` — confirm.
    t2m = out[0, 12].cpu().numpy()

    # Axis coordinates are synthesized from the output shape.
    # NOTE(review): assumes the last two dims are (lat, lon) on an evenly
    # spaced global grid — confirm against the dataset's real coordinates.
    lat = np.linspace(-90, 90, out.shape[-2])
    lon = np.linspace(-180, 180, out.shape[-1])
    X, Y = np.meshgrid(lon, lat)

    fig, ax = plt.subplots()
    try:
        cs = ax.contourf(X, Y, t2m, 100)
        ax.set_aspect("equal")
        # Attach the colorbar to this figure explicitly instead of relying
        # on pyplot's "current figure" global state.
        fig.colorbar(cs, ax=ax)
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it created alive until it is closed;
        # the Streamlit process is long-lived, so close it after rendering.
        plt.close(fig)
    st.write("Plot generated.")