# Climate-ML-Foundation-Models / aurora_utils.py
# Author: qq1990 - commit 100edb4 ("init") - 5.81 kB
import streamlit as st
import torch
from aurora import Aurora, Batch, Metadata
import numpy as np
from datetime import datetime
def aurora_config_ui():
    """Render the Aurora data-upload section and return the uploaded files.

    Shows a subheader, a short markdown hint and a multi-file uploader that
    accepts NetCDF extensions only (``.nc``, ``.netcdf``, ``.nc4``).

    Returns:
        Whatever ``st.file_uploader`` returns with ``accept_multiple_files=True``
        (per Streamlit's documentation, a list of uploaded-file objects,
        possibly empty).
    """
    netcdf_extensions = ["nc", "netcdf", "nc4"]

    st.subheader("Aurora Model Data Upload")
    st.markdown("### Drag and Drop Your Data Files Here")

    files = st.file_uploader(
        "Upload Data Files for Aurora",
        accept_multiple_files=True,
        key="aurora_uploader",  # stable widget key so Streamlit keeps state across reruns
        type=netcdf_extensions,
    )
    return files
def prepare_aurora_batch(ds, *, time_index=6, history_offset=6):
    """Convert a pressure-level dataset into an Aurora ``Batch``.

    The dataset must provide ``lat``, ``lon``, ``lev`` (pressure in hPa) and
    ``time`` coordinates, the variables ``T``, ``U`` and ``V`` on ``lev``,
    ``SLP``, and ``PHIS`` with a ``time`` dimension. It is accessed through
    the xarray-style API (``.sel``, ``.isel``, ``.sortby``, ``.compute``).

    The model receives two time steps: ``time_index - history_offset``
    (previous) and ``time_index`` (current). The 1000 hPa fields stand in for
    the near-surface variables. The remaining desired levels form the
    atmospheric stack.

    Args:
        ds: Input dataset. It is not modified; reordered copies are used.
        time_index: Position of the current time step on the ``time`` axis.
            Keyword-only; defaults to 6.
        history_offset: Number of time steps between the previous and the
            current input. Keyword-only; defaults to 6.

    Returns:
        aurora.Batch: surface, static and atmospheric tensors plus metadata.

    Raises:
        ValueError: If ``lev`` is missing, a desired pressure level is absent,
            or ``history_offset`` is smaller than 1.
        IndexError: If either selected time step lies outside the dataset.
    """
    desired_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

    # Ensure that the 'lev' dimension exists
    if 'lev' not in ds.dims:
        raise ValueError("The dataset does not contain a 'lev' (pressure level) dimension.")

    # Aurora requires decreasing latitudes and longitudes in [0, 360).
    # Reorder the data together with its coordinates. Only rewriting the
    # coordinate labels (e.g. negating lat or adding 180 to lon) would leave
    # every grid row/column attached to the wrong geographic location.
    ds = ds.sortby("lat", ascending=False)
    ds = ds.assign_coords(lon=ds.lon % 360).sortby("lon")
    lat = ds.lat.values
    lon = ds.lon.values

    # Subset the dataset to the desired pressure levels. method="nearest"
    # always returns something, so verify that the exact levels were found.
    ds_subset = ds.sel(lev=desired_levels, method="nearest")
    present_levels = ds_subset.lev.values
    missing_levels = set(desired_levels) - set(present_levels)
    if missing_levels:
        raise ValueError(f"The following desired pressure levels are missing in the dataset: {missing_levels}")

    lev = ds_subset.lev.values  # Pressure levels in hPa

    # Validate the time steps before any data is loaded. A negative previous
    # index must be rejected explicitly, because NumPy fancy indexing would
    # silently wrap it around to the end of the time axis.
    if history_offset < 1:
        raise ValueError(f"history_offset must be at least 1, got {history_offset}.")
    num_times = ds_subset.time.size
    previous_index = time_index - history_offset
    if previous_index < 0 or time_index >= num_times:
        raise IndexError(
            f"Time indices ({previous_index}, {time_index}) are out of bounds "
            f"for a dataset with {num_times} time steps."
        )

    def _prepare(x: np.ndarray) -> torch.Tensor:
        """Stack the previous and current time steps of *x* and prepend a batch axis."""
        selected = x[[previous_index, time_index]][None]
        return torch.from_numpy(np.ascontiguousarray(selected))

    # Surface variables: 1000 hPa fields are used as stand-ins for the
    # near-surface fields Aurora expects (2 m temperature, 10 m winds).
    try:
        lev_index_1000 = np.where(lev == 1000)[0][0]
    except IndexError:
        raise ValueError("1000 hPa level not found in the 'lev' dimension after subsetting.") from None
    T_surface = ds_subset.T.isel(lev=lev_index_1000).compute()
    U_surface = ds_subset.U.isel(lev=lev_index_1000).compute()
    V_surface = ds_subset.V.isel(lev=lev_index_1000).compute()
    SLP = ds_subset.SLP.compute()

    # Static field: take the first time index to drop the time dimension.
    PHIS = ds_subset.PHIS.isel(time=0).compute()

    # Atmospheric variables on the desired levels, excluding 1000 hPa
    # (already consumed as the surface stand-in above).
    atmos_levels = [int(level) for level in lev if level != 1000]
    T_atm = ds_subset.T.sel(lev=atmos_levels).compute()
    U_atm = ds_subset.U.sel(lev=atmos_levels).compute()
    V_atm = ds_subset.V.sel(lev=atmos_levels).compute()

    current_time = (
        np.datetime64(ds_subset.time.values[time_index])
        .astype('datetime64[s]')
        .astype(datetime)
    )

    surf_vars = {
        "2t": _prepare(T_surface.values),   # Two-meter temperature (1000 hPa proxy)
        "10u": _prepare(U_surface.values),  # Ten-meter eastward wind (1000 hPa proxy)
        "10v": _prepare(V_surface.values),  # Ten-meter northward wind (1000 hPa proxy)
        "msl": _prepare(SLP.values),        # Mean sea-level pressure
    }

    # Static variables are 2D (lat, lon) tensors.
    # NOTE(review): only 'z' is provided; add 'lsm' and 'slt' if available
    # and required by the checkpoint in use.
    static_vars = {
        "z": torch.from_numpy(PHIS.values.copy()),  # Geopotential (h, w)
    }

    atmos_vars = {
        "t": _prepare(T_atm.values),  # Temperature at desired levels
        "u": _prepare(U_atm.values),  # Eastward wind at desired levels
        "v": _prepare(V_atm.values),  # Northward wind at desired levels
    }

    metadata = Metadata(
        lat=torch.from_numpy(lat.copy()),
        lon=torch.from_numpy(lon.copy()),
        time=(current_time,),
        atmos_levels=tuple(atmos_levels),  # Only the desired atmospheric levels
    )

    return Batch(
        surf_vars=surf_vars,
        static_vars=static_vars,
        atmos_vars=atmos_vars,
        metadata=metadata,
    )
def initialize_aurora_model(device, *, repo="microsoft/aurora",
                            checkpoint="aurora-0.25-pretrained.ckpt"):
    """Create an Aurora model, load pretrained weights and move it to *device*.

    Args:
        device: Target passed to ``torch.nn.Module.to`` (e.g. ``"cuda"`` or a
            ``torch.device``).
        repo: Repository identifier passed to ``Aurora.load_checkpoint``.
            Keyword-only; defaults to ``"microsoft/aurora"``.
        checkpoint: Checkpoint file name within *repo*. Keyword-only;
            defaults to ``"aurora-0.25-pretrained.ckpt"``.

    Returns:
        The Aurora model in evaluation mode, placed on *device*.
    """
    model = Aurora(use_lora=False)
    # Loading is unconditional: failures to fetch or read the checkpoint propagate.
    model.load_checkpoint(repo, checkpoint)
    # Inference mode, matching Aurora's documented usage; callers that want
    # to fine-tune can switch back with model.train().
    model.eval()
    return model.to(device)