File size: 5,809 Bytes
100edb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import streamlit as st
import torch
from aurora import Aurora, Batch, Metadata
import numpy as np
from datetime import datetime
def aurora_config_ui():
    """Render the Aurora upload section and hand back the user's files.

    Draws a subheader and a markdown heading, then a multi-file uploader that
    only accepts ``.nc``, ``.netcdf`` and ``.nc4`` files.

    Returns:
        The value of ``st.file_uploader`` in multi-file mode (the uploaded
        files, if any).
    """
    st.subheader("Aurora Model Data Upload")
    st.markdown("### Drag and Drop Your Data Files Here")
    netcdf_extensions = ["nc", "netcdf", "nc4"]
    # The fixed widget key keeps the uploader's state stable across Streamlit reruns.
    return st.file_uploader(
        "Upload Data Files for Aurora",
        type=netcdf_extensions,
        key="aurora_uploader",
        accept_multiple_files=True,
    )
def prepare_aurora_batch(ds, *, time_index=6, history_step=6):
    """Build an Aurora ``Batch`` from a pressure-level xarray dataset.

    ``ds`` must provide ``time``, ``lev`` (pressure in hPa), ``lat`` and ``lon``
    coordinates and the variables ``T``, ``U``, ``V`` (defined on ``lev``),
    ``SLP`` and ``PHIS``. The 1000 hPa ``T``/``U``/``V`` fields are used as
    stand-ins for Aurora's ``2t``/``10u``/``10v`` surface variables; the other
    desired levels become the atmospheric variables.

    The grid is brought to Aurora's convention (latitudes decreasing,
    longitudes increasing in ``[0, 360)``) by reordering the data *together
    with* its coordinates, so every value keeps its geographic location.

    Args:
        ds: xarray ``Dataset`` with the coordinates and variables listed above.
        time_index: Position along ``time`` of the current step (default 6).
        history_step: Number of time positions between the previous step and
            the current step (default 6).

    Returns:
        An ``aurora.Batch``. Surface tensors are ``(1, 2, lat, lon)``,
        atmospheric tensors are ``(1, 2, level, lat, lon)`` (time steps
        ``[time_index - history_step, time_index]``), the static ``z`` is
        ``(lat, lon)`` from ``PHIS``, and the metadata time is the current step.

    Raises:
        ValueError: if ``lev`` is absent, a desired level (including 1000 hPa)
            is missing, or ``history_step`` is not positive.
        IndexError: if either requested time step lies outside ``ds.time``.
    """
    desired_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

    if "lev" not in ds.dims:
        raise ValueError("The dataset does not contain a 'lev' (pressure level) dimension.")
    if history_step < 1:
        raise ValueError(f"history_step must be a positive integer, got {history_step}.")

    def _prepare(da, dims):
        """Return ``da`` as a contiguous tensor shaped ``(1, *dims)``."""
        # transpose pins the axis order Aurora expects regardless of on-disk order;
        # [None] adds the batch axis and .copy() yields a contiguous, writable buffer.
        return torch.from_numpy(da.transpose(*dims).values[None].copy())

    # Aurora wants strictly decreasing latitudes and longitudes in [0, 360).
    # Move the data along with the coordinates: relabelling coordinate values only
    # (negating lat, adding 180 to lon) would mirror/shift every field on the grid.
    ds = ds.assign_coords(lon=ds.lon % 360).sortby("lon").sortby("lat", ascending=False)

    # Subset to the desired pressure levels and verify each one really exists
    # (method="nearest" silently substitutes a neighbour for a missing level).
    ds_subset = ds.sel(lev=desired_levels, method="nearest")
    lev = ds_subset.lev.values  # pressure levels in hPa
    missing_levels = set(desired_levels) - set(lev)
    if missing_levels:
        raise ValueError(f"The following desired pressure levels are missing in the dataset: {missing_levels}")

    try:
        lev_index_1000 = np.where(lev == 1000)[0][0]
    except IndexError:
        raise ValueError("1000 hPa level not found in the 'lev' dimension after subsetting.") from None

    # Validate the time window before loading any data.
    num_times = ds_subset.time.size
    prev_index = time_index - history_step
    if prev_index < 0 or time_index >= num_times:
        raise IndexError(
            f"Time index {time_index} with history step {history_step} is out of bounds "
            f"for {num_times} time steps."
        )

    # Only the previous and current steps are needed; selecting them first avoids
    # loading the whole time series.
    window = ds_subset.isel(time=[prev_index, time_index])
    current_time = np.datetime64(window.time.values[1]).astype("datetime64[s]").astype(datetime)

    surf_dims = ("time", "lat", "lon")
    atmos_dims = ("time", "lev", "lat", "lon")

    # Surface variables: 1000 hPa fields approximate the near-surface quantities.
    surface = window.isel(lev=lev_index_1000)
    surf_vars = {
        "2t": _prepare(surface.T, surf_dims),  # two-metre temperature (1000 hPa T)
        "10u": _prepare(surface.U, surf_dims),  # ten-metre eastward wind (1000 hPa U)
        "10v": _prepare(surface.V, surf_dims),  # ten-metre northward wind (1000 hPa V)
        "msl": _prepare(window.SLP, surf_dims),  # mean sea-level pressure
    }

    # Static variables are 2D; PHIS is read at the first time index to drop its time axis.
    static_vars = {
        "z": torch.from_numpy(ds_subset.PHIS.isel(time=0).transpose("lat", "lon").values.copy()),
        # Add 'lsm' and 'slt' here if the dataset provides them.
    }

    # Atmospheric variables on every desired level except 1000 hPa.
    atmos_levels = [int(level) for level in lev if level != 1000]
    atmos = window.sel(lev=atmos_levels)
    atmos_vars = {
        "t": _prepare(atmos.T, atmos_dims),  # temperature
        "u": _prepare(atmos.U, atmos_dims),  # eastward wind
        "v": _prepare(atmos.V, atmos_dims),  # northward wind
    }

    metadata = Metadata(
        lat=torch.from_numpy(ds_subset.lat.values.copy()),
        lon=torch.from_numpy(ds_subset.lon.values.copy()),
        time=(current_time,),
        atmos_levels=tuple(atmos_levels),
    )

    return Batch(
        surf_vars=surf_vars,
        static_vars=static_vars,
        atmos_vars=atmos_vars,
        metadata=metadata,
    )
def initialize_aurora_model(device):
    """Create a non-LoRA Aurora model with the 0.25° pretrained weights on ``device``.

    The checkpoint ``aurora-0.25-pretrained.ckpt`` from ``microsoft/aurora`` is
    always loaded. There is no fallback, so any error raised by
    ``load_checkpoint`` propagates to the caller.

    Args:
        device: Target passed to ``Module.to`` (e.g. ``"cpu"`` or ``"cuda"``).

    Returns:
        The Aurora model moved to ``device``.
    """
    aurora = Aurora(use_lora=False)
    aurora.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
    return aurora.to(device)
|