import os
import random
import logging
import tempfile
import traceback
from datetime import datetime
from pathlib import Path

import streamlit as st
import torch
import numpy as np
import yaml
import xarray as xr
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import cartopy.crs as ccrs

from Prithvi import *  # Ensure this import provides PrithviWxC, preproc, and related helpers
from aurora import Aurora, Batch, Metadata, rollout

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Function to save uploaded files to temporary files and store paths in session_state
def save_uploaded_files(uploaded_files):
    if 'temp_file_paths' not in st.session_state:
        st.session_state.temp_file_paths = []
    for uploaded_file in uploaded_files:
        suffix = os.path.splitext(uploaded_file.name)[1]
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        temp_file.write(uploaded_file.read())
        temp_file.close()
        st.session_state.temp_file_paths.append(temp_file.name)
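# Illustrative note (hypothetical paths): after uploading "a.nc" and "b.nc",
# st.session_state.temp_file_paths might hold ["/tmp/tmpab12.nc", "/tmp/tmpcd34.nc"].
# The files are created with delete=False and are never removed here, so cleanup
# is left to the OS (see the commented-out cleanup block further below).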
# Cached function to load dataset
@st.cache_resource
def load_dataset(file_paths):
    try:
        ds = xr.open_mfdataset(file_paths, combine='by_coords').load()
        return ds
    except Exception:
        st.error("Error loading dataset:")
        st.error(traceback.format_exc())
        return None
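# Rough sketch of the caching behavior (assuming the same path list is passed
# on each Streamlit re-run):
#   ds1 = load_dataset(["/tmp/a.nc"])  # opens and loads the files
#   ds2 = load_dataset(["/tmp/a.nc"])  # served from st.cache_resource, no I/O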
# Set page configuration
st.set_page_config(
    page_title="Weather Data Processor",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Create a header with two columns: one for the title and one for the model selector
header_col1, header_col2 = st.columns([4, 1])  # Adjust the ratio as needed

with header_col1:
    st.title("🌦️ Weather & Climate Data Processor and Forecaster")

with header_col2:
    st.markdown("### Select a Model")
    selected_model = st.selectbox(
        "",
        options=["Aurora", "Climax", "Prithvi", "LSTM"],
        index=0,
        key="model_selector",
        help="Select the model you want to use for processing the data.",
    )

st.write("---")  # Horizontal separator

# --- Layout: Two Columns ---
left_col, right_col = st.columns([1, 2])  # Adjust column ratios as needed
with left_col:
    st.header("🔧 Configuration")

    # --- Dynamic Configuration Based on Selected Model ---
    def get_model_configuration(model_name):
        if model_name == "Prithvi":
            st.subheader("Prithvi Model Configuration")
            # Prithvi-specific configuration inputs
            param1 = st.number_input("Prithvi Parameter 1", value=10, step=1)
            param2 = st.text_input("Prithvi Parameter 2", value="default_prithvi")
            # Add other Prithvi-specific parameters here
            config = {
                "param1": param1,
                "param2": param2,
                # Include other parameters as needed
            }

            # --- Prithvi-Specific File Uploads ---
            st.markdown("### Upload Data Files for Prithvi Model")
            # File uploader for surface data
            uploaded_surface_files = st.file_uploader(
                "Upload Surface Data Files",
                type=["nc", "netcdf"],
                accept_multiple_files=True,
                key="surface_uploader",
            )
            # File uploader for vertical data
            uploaded_vertical_files = st.file_uploader(
                "Upload Vertical Data Files",
                type=["nc", "netcdf"],
                accept_multiple_files=True,
                key="vertical_uploader",
            )
            # Handle climatology files
            st.markdown("### Upload Climatology Files (If Missing)")
            # Default climatology file paths
            default_clim_dir = Path("Prithvi-WxC/examples/climatology")
            surf_in_scal_path = default_clim_dir / "musigma_surface.nc"
            vert_in_scal_path = default_clim_dir / "musigma_vertical.nc"
            surf_out_scal_path = default_clim_dir / "anomaly_variance_surface.nc"
            vert_out_scal_path = default_clim_dir / "anomaly_variance_vertical.nc"

            # Check whether all climatology files exist
            clim_files_exist = all(
                [
                    surf_in_scal_path.exists(),
                    vert_in_scal_path.exists(),
                    surf_out_scal_path.exists(),
                    vert_out_scal_path.exists(),
                ]
            )

            # Default to None so the return statement below never references
            # undefined names when nothing has been uploaded yet
            clim_surf_path = None
            clim_vert_path = None
            if not clim_files_exist:
                st.warning("Climatology files are missing.")
                uploaded_clim_surface = st.file_uploader(
                    "Upload Climatology Surface File",
                    type=["nc", "netcdf"],
                    key="clim_surface_uploader",
                )
                uploaded_clim_vertical = st.file_uploader(
                    "Upload Climatology Vertical File",
                    type=["nc", "netcdf"],
                    key="clim_vertical_uploader",
                )
                # Persist uploaded climatology files to a temporary directory
                if uploaded_clim_surface and uploaded_clim_vertical:
                    clim_temp_dir = tempfile.mkdtemp()
                    clim_surf_path = Path(clim_temp_dir) / uploaded_clim_surface.name
                    with open(clim_surf_path, "wb") as f:
                        f.write(uploaded_clim_surface.getbuffer())
                    clim_vert_path = Path(clim_temp_dir) / uploaded_clim_vertical.name
                    with open(clim_vert_path, "wb") as f:
                        f.write(uploaded_clim_vertical.getbuffer())
                    st.success("Climatology files uploaded and saved.")
                else:
                    st.warning("Please upload both climatology surface and vertical files.")
            else:
                clim_surf_path = surf_in_scal_path
                clim_vert_path = vert_in_scal_path
            # Optional: upload config.yaml
            uploaded_config = st.file_uploader(
                "Upload config.yaml",
                type=["yaml", "yml"],
                key="config_uploader",
            )
            if uploaded_config:
                temp_config = tempfile.mktemp(suffix=".yaml")
                with open(temp_config, "wb") as f:
                    f.write(uploaded_config.getbuffer())
                config_path = Path(temp_config)
                st.success("config.yaml uploaded and saved.")
            else:
                # Fall back to the default config.yaml path
                config_path = Path("Prithvi-WxC/examples/config.yaml")
                if not config_path.exists():
                    st.error("Default config.yaml not found. Please upload a config file.")
                    st.stop()

            # Optional: upload model weights
            uploaded_weights = st.file_uploader(
                "Upload Model Weights (.pt)",
                type=["pt"],
                key="weights_uploader",
            )
            if uploaded_weights:
                temp_weights = tempfile.mktemp(suffix=".pt")
                with open(temp_weights, "wb") as f:
                    f.write(uploaded_weights.getbuffer())
                weights_path = Path(temp_weights)
                st.success("Model weights uploaded and saved.")
            else:
                # Fall back to the default weights path
                weights_path = Path("Prithvi-WxC/examples/weights/prithvi.wxc.2300m.v1.pt")
                if not weights_path.exists():
                    st.error("Default model weights not found. Please upload model weights.")
                    st.stop()

            return (
                config,
                uploaded_surface_files,
                uploaded_vertical_files,
                clim_surf_path,
                clim_vert_path,
                config_path,
                weights_path,
            )
        else:
            # For other models, provide a simple file uploader
            st.subheader(f"{model_name} Model Data Upload")
            st.markdown("### Drag and Drop Your Data Files Here")
            uploaded_files = st.file_uploader(
                f"Upload Data Files for {model_name}",
                accept_multiple_files=True,
                key=f"{model_name.lower()}_uploader",
                type=["nc", "netcdf", "nc4"],
            )
            return uploaded_files
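    # Note: get_model_configuration returns a 7-tuple in the Prithvi branch and
    # a plain list of uploaded files otherwise, so the caller below branches on
    # the selected model.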
    # Retrieve model-specific configuration and files
    if selected_model == "Prithvi":
        (
            config,
            uploaded_surface_files,
            uploaded_vertical_files,
            clim_surf_path,
            clim_vert_path,
            config_path,
            weights_path,
        ) = get_model_configuration(selected_model)
    else:
        uploaded_files = get_model_configuration(selected_model)

st.write("---")  # Horizontal separator

# --- Run Inference Button ---
if st.button("🚀 Run Inference"):
    with right_col:
        st.header("📈 Inference Progress & Visualization")

        # Initialize device
        try:
            torch.jit.enable_onednn_fusion(True)
            if torch.cuda.is_available():
                device = torch.device("cuda")
                st.write(f"Using device: **{torch.cuda.get_device_name()}**")
                torch.backends.cudnn.benchmark = True
                torch.backends.cudnn.deterministic = True
            else:
                device = torch.device("cpu")
                st.write("Using device: **CPU**")
        except Exception:
            st.error("Error initializing device:")
            st.error(traceback.format_exc())
            st.stop()

        # Set random seeds
        try:
            random.seed(42)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(42)
            torch.manual_seed(42)
            np.random.seed(42)
        except Exception:
            st.error("Error setting random seeds:")
            st.error(traceback.format_exc())
            st.stop()
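        # Reproducibility note: with the seeds above (and cudnn.deterministic),
        # stochastic pieces such as Prithvi's input masking should be repeatable
        # across runs on the same hardware.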
        # # Define variables and parameters based on dataset type
        # if dataset_type == "MERRA2":
        #     surface_vars = [
        #         "EFLUX", "GWETROOT", "HFLUX", "LAI", "LWGAB", "LWGEM", "LWTUP",
        #         "PS", "QV2M", "SLP", "SWGNT", "SWTNT", "T2M", "TQI", "TQL",
        #         "TQV", "TS", "U10M", "V10M", "Z0M",
        #     ]
        #     static_surface_vars = ["FRACI", "FRLAND", "FROCEAN", "PHIS"]
        #     vertical_vars = ["CLOUD", "H", "OMEGA", "PL", "QI", "QL", "QV", "T", "U", "V"]
        #     levels = [
        #         34.0, 39.0, 41.0, 43.0, 44.0, 45.0, 48.0, 51.0, 53.0, 56.0,
        #         63.0, 68.0, 71.0, 72.0,
        #     ]
        # elif dataset_type == "GEOS5":
        #     # Define GEOS5-specific variables
        #     surface_vars = [
        #         "GEOS5_EFLUX", "GEOS5_GWETROOT", "GEOS5_HFLUX", "GEOS5_LAI",
        #         "GEOS5_LWGAB", "GEOS5_LWGEM", "GEOS5_LWTUP", "GEOS5_PS",
        #         "GEOS5_QV2M", "GEOS5_SLP", "GEOS5_SWGNT", "GEOS5_SWTNT",
        #         "GEOS5_T2M", "GEOS5_TQI", "GEOS5_TQL", "GEOS5_TQV", "GEOS5_TS",
        #         "GEOS5_U10M", "GEOS5_V10M", "GEOS5_Z0M",
        #     ]
        #     static_surface_vars = ["GEOS5_FRACI", "GEOS5_FRLAND", "GEOS5_FROCEAN", "GEOS5_PHIS"]
        #     vertical_vars = [
        #         "GEOS5_CLOUD", "GEOS5_H", "GEOS5_OMEGA", "GEOS5_PL", "GEOS5_QI",
        #         "GEOS5_QL", "GEOS5_QV", "GEOS5_T", "GEOS5_U", "GEOS5_V",
        #     ]
        #     levels = [
        #         # Define levels specific to GEOS5 if different
        #         10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0,
        #     ]
        # else:
        #     st.error("Unsupported dataset type selected.")
        #     st.stop()
        padding = {"level": [0, 0], "lat": [0, -1], "lon": [0, 0]}
        residual = "climate"
        masking_mode = "local"
        decoder_shifting = True
        masking_ratio = 0.99
        positional_encoding = "fourier"
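        # These Prithvi-WxC run settings appear to mirror the upstream example:
        # a "climate" residual head, local masking at a 0.99 ratio, shifted
        # decoder windows, and Fourier positional encodings. The lat padding of
        # [0, -1] presumably trims one latitude row so the grid tiles evenly
        # into mask units.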
        # --- Initialize Dataset ---
        try:
            with st.spinner("Initializing dataset..."):
                if selected_model == "Prithvi":
                    pass
                    # # Validate climatology files
                    # if not clim_files_exist and not (uploaded_clim_surface and uploaded_clim_vertical):
                    #     st.error("Climatology files are missing. Please upload both climatology surface and vertical files.")
                    #     st.stop()
                    # dataset = Merra2Dataset(
                    #     time_range=time_range,
                    #     lead_times=lead_times,
                    #     input_times=input_times,
                    #     data_path_surface=surf_dir,
                    #     data_path_vertical=vert_dir,
                    #     climatology_path_surface=clim_surf_path,
                    #     climatology_path_vertical=clim_vert_path,
                    #     surface_vars=surface_vars,
                    #     static_surface_vars=static_surface_vars,
                    #     vertical_vars=vertical_vars,
                    #     levels=levels,
                    #     positional_encoding=positional_encoding,
                    # )
                    # assert len(dataset) > 0, "There doesn't seem to be any valid data."
                elif selected_model == "Aurora":
                    # TODO: temporary ingestion path; replace with a proper loader
                    if uploaded_files:
                        temp_file_paths = []  # Referenced by the commented-out cleanup below
                        try:
                            # Save each uploaded file to a temporary file
                            save_uploaded_files(uploaded_files)
                            # Open all the saved files as a single xarray dataset
                            ds = load_dataset(st.session_state.temp_file_paths)
                            if ds is not None:
                                st.success("Files successfully loaded!")
                                st.session_state.ds_subset = ds
                                # Fill missing values with per-variable means
                                ds = ds.fillna(ds.mean())
                                desired_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
                                # Ensure that the 'lev' (pressure level) dimension exists
                                if 'lev' not in ds.dims:
                                    raise ValueError("The dataset does not contain a 'lev' (pressure level) dimension.")
                                # Define the _prepare helper
                                def _prepare(x: np.ndarray, i: int) -> torch.Tensor:
                                    # Select previous and current time steps
                                    selected = x[[i - 6, i]]
                                    # Add a batch dimension
                                    selected = selected[None]
                                    # Ensure data is contiguous
                                    selected = selected.copy()
                                    # Convert to a PyTorch tensor
                                    return torch.from_numpy(selected)
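                                # Shape sketch (assuming x is (time, lat, lon) for surface
                                # fields): x[[i - 6, i]] -> (2, lat, lon), and the added
                                # batch axis gives (1, 2, lat, lon), i.e. the (b, t, h, w)
                                # layout Aurora expects for surf_vars. Vertical fields of
                                # shape (time, lev, lat, lon) come out as (1, 2, lev, lat, lon),
                                # matching atmos_vars' (b, t, c, h, w).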
                                # Adjust coordinates: Aurora expects latitudes to be
                                # decreasing and longitudes in [0, 360)
                                lat = ds.lat.values * -1
                                lon = ds.lon.values + 180
                                # Subset the dataset to only the desired pressure levels
                                ds_subset = ds.sel(lev=desired_levels, method="nearest")
                                # Verify that all desired levels are present
                                present_levels = ds_subset.lev.values
                                missing_levels = set(desired_levels) - set(present_levels)
                                if missing_levels:
                                    raise ValueError(f"The following desired pressure levels are missing in the dataset: {missing_levels}")
                                # Extract pressure levels after subsetting
                                lev = ds_subset.lev.values  # Pressure levels in hPa
                                # Surface fields are taken at 1000 hPa
                                try:
                                    lev_index_1000 = np.where(lev == 1000)[0][0]
                                except IndexError:
                                    raise ValueError("1000 hPa level not found in the 'lev' dimension after subsetting.")
                                T_surface = ds_subset.T.isel(lev=lev_index_1000).compute()
                                U_surface = ds_subset.U.isel(lev=lev_index_1000).compute()
                                V_surface = ds_subset.V.isel(lev=lev_index_1000).compute()
                                SLP = ds_subset.SLP.compute()
                                # Static variable: select the first time index to drop the time dimension
                                PHIS = ds_subset.PHIS.isel(time=0).compute()
                                # Atmospheric variables at the desired pressure levels, excluding 1000 hPa
                                atmos_levels = [int(level) for level in lev if level != 1000]
                                T_atm = ds_subset.T.sel(lev=atmos_levels).compute()
                                U_atm = ds_subset.U.sel(lev=atmos_levels).compute()
                                V_atm = ds_subset.V.sel(lev=atmos_levels).compute()
                                # Select the time index (6 <= i < num_times, since _prepare reads index i - 6)
                                num_times = ds_subset.time.size
                                i = 6  # Adjust as needed
                                if i >= num_times or i < 6:
                                    raise IndexError("Time index i is out of bounds.")
                                time_values = ds_subset.time.values
                                current_time = np.datetime64(time_values[i]).astype('datetime64[s]').astype(datetime)
                                # Prepare surface variables
                                surf_vars = {
                                    "2t": _prepare(T_surface.values, i),   # Two-meter temperature (1000 hPa T used as a proxy)
                                    "10u": _prepare(U_surface.values, i),  # Ten-meter eastward wind (1000 hPa U used as a proxy)
                                    "10v": _prepare(V_surface.values, i),  # Ten-meter northward wind (1000 hPa V used as a proxy)
                                    "msl": _prepare(SLP.values, i),        # Mean sea-level pressure
                                }
                                # Prepare static variables (2D tensors)
                                static_vars = {
                                    "z": torch.from_numpy(PHIS.values.copy()),  # Surface geopotential, shape (h, w)
                                    # Add 'lsm' and 'slt' if available and needed
                                }
                                # Prepare atmospheric variables
                                atmos_vars = {
                                    "t": _prepare(T_atm.values, i),  # Temperature at the desired levels
                                    "u": _prepare(U_atm.values, i),  # Eastward wind at the desired levels
                                    "v": _prepare(V_atm.values, i),  # Northward wind at the desired levels
                                }
                                # Define metadata
                                metadata = Metadata(
                                    lat=torch.from_numpy(lat.copy()),
                                    lon=torch.from_numpy(lon.copy()),
                                    time=(current_time,),
                                    atmos_levels=tuple(atmos_levels),  # Only the desired atmospheric levels
                                )
                                # Create the Batch object and store it for later steps
                                batch = Batch(
                                    surf_vars=surf_vars,
                                    static_vars=static_vars,
                                    atmos_vars=atmos_vars,
                                    metadata=metadata,
                                )
                                st.session_state['batch'] = batch
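                                # For reference, this Batch layout matches the aurora
                                # package's documented conventions: surf_vars (b, t, h, w),
                                # static_vars (h, w), atmos_vars (b, t, c, h, w), plus
                                # Metadata with lat/lon vectors, a time tuple, and the
                                # pressure levels of the atmospheric variables.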
                        except Exception as e:
                            st.error(f"An error occurred: {e}")
                        # finally:
                        #     # Clean up: remove the temporary files
                        #     for path in temp_file_paths:
                        #         try:
                        #             os.remove(path)
                        #         except Exception as e:
                        #             st.warning(f"Could not delete temp file {path}: {e}")
                else:
                    # For other models, implement their specific dataset initialization
                    # Placeholder: replace with actual dataset initialization
                    dataset = None  # Replace with an actual dataset
                    st.warning("Dataset initialization for this model is not implemented yet.")
                    st.stop()
            st.success("Dataset initialized successfully.")
        except Exception:
            st.error("Error initializing dataset:")
            st.error(traceback.format_exc())
            st.stop()
        # --- Load Scalers ---
        try:
            with st.spinner("Loading scalers..."):
                if selected_model == "Prithvi":
                    # NOTE: with the block below commented out, the Prithvi scalers
                    # stay undefined and model initialization further down will fail
                    # until it is restored.
                    pass
                    # # Assuming the scaler paths are the same as the climatology paths
                    # surf_in_scal_path = clim_surf_path
                    # vert_in_scal_path = clim_vert_path
                    # surf_out_scal_path = Path(clim_surf_path.parent) / "anomaly_variance_surface.nc"
                    # vert_out_scal_path = Path(clim_vert_path.parent) / "anomaly_variance_vertical.nc"
                    # # Check if the output scaler files exist
                    # if not surf_out_scal_path.exists() or not vert_out_scal_path.exists():
                    #     st.error("Anomaly variance scaler files are missing.")
                    #     st.stop()
                    # in_mu, in_sig = input_scalers(
                    #     surface_vars,
                    #     vertical_vars,
                    #     levels,
                    #     surf_in_scal_path,
                    #     vert_in_scal_path,
                    # )
                    # output_sig = output_scalers(
                    #     surface_vars,
                    #     vertical_vars,
                    #     levels,
                    #     surf_out_scal_path,
                    #     vert_out_scal_path,
                    # )
                    # static_mu, static_sig = static_input_scalers(
                    #     surf_in_scal_path,
                    #     static_surface_vars,
                    # )
                else:
                    # Load scalers for other models if applicable
                    # Placeholder: replace with actual scaler loading
                    in_mu, in_sig = None, None
                    output_sig = None
                    static_mu, static_sig = None, None
            st.success("Scalers loaded successfully.")
        except Exception:
            st.error("Error loading scalers:")
            st.error(traceback.format_exc())
            st.stop()
        # --- Load Configuration ---
        try:
            with st.spinner("Loading configuration..."):
                if selected_model == "Prithvi":
                    with open(config_path, "r") as f:
                        config = yaml.safe_load(f)
                    # Validate config
                    required_params = [
                        "in_channels", "input_size_time", "in_channels_static",
                        "input_scalers_epsilon", "static_input_scalers_epsilon",
                        "n_lats_px", "n_lons_px", "patch_size_px",
                        "mask_unit_size_px", "embed_dim", "n_blocks_encoder",
                        "n_blocks_decoder", "mlp_multiplier", "n_heads",
                        "dropout", "drop_path", "parameter_dropout",
                    ]
                    missing_params = [param for param in required_params if param not in config.get("params", {})]
                    if missing_params:
                        st.error(f"Missing configuration parameters: {missing_params}")
                        st.stop()
                else:
                    # Load configuration for other models if applicable
                    # Placeholder: replace with actual configuration loading
                    config = {}
            st.success("Configuration loaded successfully.")
        except Exception:
            st.error("Error loading configuration:")
            st.error(traceback.format_exc())
            st.stop()
        # --- Initialize the Model ---
        try:
            with st.spinner("Initializing model..."):
                if selected_model == "Prithvi":
                    model = PrithviWxC(
                        in_channels=config["params"]["in_channels"],
                        input_size_time=config["params"]["input_size_time"],
                        in_channels_static=config["params"]["in_channels_static"],
                        input_scalers_mu=in_mu,
                        input_scalers_sigma=in_sig,
                        input_scalers_epsilon=config["params"]["input_scalers_epsilon"],
                        static_input_scalers_mu=static_mu,
                        static_input_scalers_sigma=static_sig,
                        static_input_scalers_epsilon=config["params"]["static_input_scalers_epsilon"],
                        output_scalers=output_sig**0.5,
                        n_lats_px=config["params"]["n_lats_px"],
                        n_lons_px=config["params"]["n_lons_px"],
                        patch_size_px=config["params"]["patch_size_px"],
                        mask_unit_size_px=config["params"]["mask_unit_size_px"],
                        mask_ratio_inputs=masking_ratio,
                        embed_dim=config["params"]["embed_dim"],
                        n_blocks_encoder=config["params"]["n_blocks_encoder"],
                        n_blocks_decoder=config["params"]["n_blocks_decoder"],
                        mlp_multiplier=config["params"]["mlp_multiplier"],
                        n_heads=config["params"]["n_heads"],
                        dropout=config["params"]["dropout"],
                        drop_path=config["params"]["drop_path"],
                        parameter_dropout=config["params"]["parameter_dropout"],
                        residual=residual,
                        masking_mode=masking_mode,
                        decoder_shifting=decoder_shifting,
                        positional_encoding=positional_encoding,
                        checkpoint_encoder=[],
                        checkpoint_decoder=[],
                    )
                elif selected_model == "Aurora":
                    pass  # The Aurora model is constructed at inference time below
                else:
                    # Initialize other models here
                    # Placeholder: replace with actual model initialization
                    model = None
                    st.warning("Model initialization for this model is not implemented yet.")
                    st.stop()
            # model.to(device)
            st.success("Model initialized successfully.")
        except Exception:
            st.error("Error initializing model:")
            st.error(traceback.format_exc())
            st.stop()
        # --- Load Model Weights ---
        try:
            with st.spinner("Loading model weights..."):
                if selected_model == "Prithvi":
                    state_dict = torch.load(weights_path, map_location=device)
                    # Some checkpoints nest the weights under a "model_state" key
                    if "model_state" in state_dict:
                        state_dict = state_dict["model_state"]
                    model.load_state_dict(state_dict, strict=True)
                    model.to(device)
                else:
                    # Load weights for other models if applicable
                    # Placeholder: replace with actual weight loading
                    pass
            st.success("Model weights loaded successfully.")
        except Exception:
            st.error("Error loading model weights:")
            st.error(traceback.format_exc())
            st.stop()
        # --- Prepare Data Batch ---
        try:
            with st.spinner("Preparing data batch..."):
                if selected_model == "Prithvi":
                    # NOTE: requires the (currently commented-out) Merra2Dataset
                    # initialization above
                    data = next(iter(dataset))
                    batch = preproc([data], padding)
                    for k, v in batch.items():
                        if isinstance(v, torch.Tensor):
                            batch[k] = v.to(device)
                elif selected_model == "Aurora":
                    # Interpolate the batch onto the 0.25° grid that the
                    # pretrained Aurora checkpoint expects
                    batch = batch.regrid(res=0.25)
                else:
                    # Prepare the data batch for other models
                    # Placeholder: replace with actual data preparation
                    batch = None
            st.success("Data batch prepared successfully.")
        except Exception:
            st.error("Error preparing data batch:")
            st.error(traceback.format_exc())
            st.stop()
        # --- Run Inference ---
        try:
            with st.spinner("Running model inference..."):
                if selected_model == "Prithvi":
                    model.eval()
                    with torch.no_grad():
                        out = model(batch)
                elif selected_model == "Aurora":
                    model = Aurora(use_lora=False)
                    model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
                    model.eval()
                    # model = model.to("cuda")  # Uncomment if using a GPU
                    with torch.inference_mode():
                        out = [pred.to("cpu") for pred in rollout(model, batch, steps=2)]
                    model = model.to("cpu")
                    st.session_state.model = model
                else:
                    # Run inference for other models
                    # Placeholder: replace with actual inference code
                    out = torch.randn(1, 10, 180, 360)  # Dummy tensor
            st.success("Model inference completed successfully.")
            st.session_state['out'] = out
        except Exception:
            st.error("Error during model inference:")
            st.error(traceback.format_exc())
            st.stop()
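        # Rollout note: rollout() is autoregressive, feeding each prediction back
        # into the model, so steps=2 yields two consecutive forecasts (nominally
        # 6 h apart each for the 0.25° pretrained Aurora checkpoint).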
        # --- Visualization Settings ---
        st.markdown("## 📊 Visualization Settings")
        if 'out' in st.session_state and 'batch' in st.session_state and selected_model == "Prithvi":
            # Display the shape of the output tensor
            out_tensor = st.session_state['out']
            st.write(f"**Output tensor shape:** {out_tensor.shape}")
            # Ensure the output tensor has at least 4 dimensions (batch, variables, lat, lon)
            if out_tensor.ndim < 4:
                st.error("The output tensor does not have the expected number of dimensions (batch, variables, lat, lon).")
                st.stop()
            # Get the number of variables
            num_variables = out_tensor.shape[1]
            # Define variable names (update with your actual variable names)
            variable_names = [f"Variable_{i}" for i in range(num_variables)]
            # Visualization settings
            col1, col2 = st.columns(2)
            with col1:
                # Select variable to plot
                selected_variable_name = st.selectbox(
                    "Select Variable to Plot",
                    options=variable_names,
                    index=0,
                    help="Choose the variable you want to visualize.",
                )
                # Select plot type
                plot_type = st.selectbox(
                    "Select Plot Type",
                    options=["Contour", "Heatmap"],
                    index=0,
                    help="Choose the type of plot to display.",
                )
            with col2:
                # Select color map
                cmap = st.selectbox(
                    "Select Color Map",
                    options=plt.colormaps(),
                    index=plt.colormaps().index("viridis"),
                    help="Choose the color map for the plot.",
                )
                # Set the number of levels (for contour plots)
                if plot_type == "Contour":
                    num_levels = st.slider(
                        "Number of Contour Levels",
                        min_value=5,
                        max_value=100,
                        value=20,
                        step=5,
                        help="Set the number of contour levels.",
                    )
                else:
                    num_levels = None
            # Find the index of the selected variable
            variable_index = variable_names.index(selected_variable_name)
            # Extract the selected variable
            selected_variable = out_tensor[0, variable_index].cpu().numpy()
            # Generate latitude and longitude arrays
            lat = np.linspace(-90, 90, selected_variable.shape[0])
            lon = np.linspace(-180, 180, selected_variable.shape[1])
            X, Y = np.meshgrid(lon, lat)
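            # Assumption: the linspace grids above treat the output as a full
            # global lat/lon field; if the Prithvi output grid differs (e.g. the
            # padded/trimmed MERRA-2 grid), these axes should come from the
            # dataset's own coordinates instead.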
            # Plot the selected variable
            st.markdown(f"### Plot of {selected_variable_name}")
            # Matplotlib figure
            fig, ax = plt.subplots(figsize=(10, 6))
            if plot_type == "Contour":
                # Generate the contour plot
                contour = ax.contourf(X, Y, selected_variable, levels=num_levels, cmap=cmap)
            elif plot_type == "Heatmap":
                # Generate the heatmap
                contour = ax.imshow(selected_variable, extent=[-180, 180, -90, 90], cmap=cmap, origin='lower', aspect='auto')
            # Add a color bar
            cbar = plt.colorbar(contour, ax=ax)
            cbar.set_label(f'{selected_variable_name}', fontsize=12)
            # Set labels and title
            ax.set_xlabel("Longitude", fontsize=12)
            ax.set_ylabel("Latitude", fontsize=12)
            ax.set_title(f"{selected_variable_name}", fontsize=14)
            # Display the plot in Streamlit
            st.pyplot(fig)

            # Optional: provide an interactive Plotly plot
            st.markdown("#### Interactive Plot")
            if plot_type == "Contour":
                fig_plotly = go.Figure(data=go.Contour(
                    z=selected_variable,
                    x=lon,
                    y=lat,
                    colorscale=cmap,
                    contours=dict(
                        coloring='fill',
                        showlabels=True,
                        labelfont=dict(size=12, color='white'),
                    ),
                    # ncontours is a top-level Contour attribute, not a key of contours
                    ncontours=num_levels,
                ))
            elif plot_type == "Heatmap":
                fig_plotly = go.Figure(data=go.Heatmap(
                    z=selected_variable,
                    x=lon,
                    y=lat,
                    colorscale=cmap,
                ))
            fig_plotly.update_layout(
                xaxis_title="Longitude",
                yaxis_title="Latitude",
                autosize=False,
                width=800,
                height=600,
            )
            st.plotly_chart(fig_plotly)
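            # Caveat: the Plotly traces reuse the Matplotlib colormap name as a
            # Plotly colorscale. That works for shared names such as "viridis"
            # or "magma", but Matplotlib-only colormaps will make Plotly raise,
            # so a mapping/fallback may be needed here.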
        elif 'out' in st.session_state and selected_model == "Aurora" and st.session_state['out'] is not None:
            preds = st.session_state['out']
            ds_subset = st.session_state.get('ds_subset', None)
            batch = st.session_state.get('batch', None)

            # Determine the available vertical levels from the prediction tensors
            try:
                # atmos_vars["t"] has shape (batch, time, level, lat, lon)
                levels = preds[0].atmos_vars["t"].shape[2]
                level_indices = list(range(levels))
            except Exception:
                st.error("Error determining available levels:")
                st.error(traceback.format_exc())
                levels = None  # Set to None if the levels cannot be determined

            if levels is not None:
                # Slider for selecting the vertical level
                selected_level = st.slider(
                    'Select Level',
                    min_value=0,
                    max_value=levels - 1,
                    value=11,  # Default level index
                    step=1,
                    help="Select the vertical level for plotting.",
                )
                # Loop through predictions and ground truths
                for idx in range(len(preds)):
                    pred = preds[idx]
                    pred_time = pred.metadata.time[0]
                    # Display the prediction time
                    st.write(f"### Prediction Time: {pred_time}")
                    # Extract data at the selected level (converted from K to °C)
                    try:
                        pred_data = pred.atmos_vars["t"][0][0][selected_level].numpy() - 273.15
                        # NOTE: the ground truth is indexed by prediction index here;
                        # verify that this aligns with the forecast valid times
                        truth_data = ds_subset.T.isel(lev=selected_level)[idx].values - 273.15
                    except Exception:
                        st.error("Error extracting data for plotting:")
                        st.error(traceback.format_exc())
                        continue
                    # Extract latitude and longitude
                    try:
                        lat = np.array(pred.metadata.lat)  # Assuming 'lat' is 1D
                        lon = np.array(pred.metadata.lon)  # Assuming 'lon' is 1D
                    except Exception:
                        st.error("Error extracting latitude and longitude:")
                        st.error(traceback.format_exc())
                        continue
                    # Create a meshgrid for plotting
                    lon_grid, lat_grid = np.meshgrid(lon, lat)
                    # Create a Matplotlib figure with a Cartopy projection
                    # (two panels: ground truth and prediction)
                    fig, axes = plt.subplots(
                        1, 2, figsize=(12, 6),
                        subplot_kw={'projection': ccrs.PlateCarree()}
                    )
                    # Ground truth panel
                    im1 = axes[0].imshow(
                        truth_data,
                        extent=[lon.min(), lon.max(), lat.min(), lat.max()],
                        origin='lower',
                        cmap='coolwarm',
                        transform=ccrs.PlateCarree()
                    )
                    axes[0].set_title(f"Ground Truth at Level {selected_level} - {pred_time}")
                    axes[0].set_xlabel('Longitude')
                    axes[0].set_ylabel('Latitude')
                    plt.colorbar(im1, ax=axes[0], orientation='horizontal', pad=0.05)
                    # Prediction panel
                    im2 = axes[1].imshow(
                        pred_data,
                        extent=[lon.min(), lon.max(), lat.min(), lat.max()],
                        origin='lower',
                        cmap='coolwarm',
                        transform=ccrs.PlateCarree()
                    )
                    axes[1].set_title(f"Prediction at Level {selected_level} - {pred_time}")
                    axes[1].set_xlabel('Longitude')
                    axes[1].set_ylabel('Latitude')
                    plt.colorbar(im2, ax=axes[1], orientation='horizontal', pad=0.05)
                    plt.tight_layout()
                    # Display the plot in Streamlit
                    st.pyplot(fig)
            else:
                st.error("Could not determine the available levels in the data.")
        else:
            st.warning("No output available to display or visualization is not implemented for this model.")
# --- End of Inference Button ---
else:
    with right_col:
        st.header("🖥️ Visualization & Progress")
        st.info("Awaiting inference to display results.")