Climate-ML-Foundation-Models / aurora_utils.py
qq1990's picture
roll back
4df22b4
raw
history blame
8.49 kB
import streamlit as st
import torch
from aurora import Aurora, Batch, Metadata
import numpy as np
from datetime import datetime
def aurora_config_ui():
    """Render the Aurora data-input section of the Streamlit page.

    Shows a description of the available Aurora checkpoints and the expected
    input format, a multi-file uploader restricted to NetCDF extensions, and
    a short list of references.

    Returns:
        The value returned by ``st.file_uploader`` for the widget keyed
        ``"aurora_uploader"`` (created with ``accept_multiple_files=True``).
    """
    st.subheader("Aurora Model Data Input")

    # Detailed data description section
    st.markdown("""
**Available Models & Usage:**
Aurora provides several pretrained and fine-tuned models at 0.25° and 0.1° resolutions.
Models and weights are available through the HuggingFace repository: [microsoft/aurora](https://huggingface.co/microsoft/aurora).
**Aurora 0.25° Pretrained**
- Trained on a variety of data.
- Suitable if no fine-tuned version exists for your dataset or to fine-tune Aurora yourself.
- Use if your dataset is ERA5 at 0.25° resolution (721x1440).
**Aurora 0.25° Pretrained Small**
- A smaller version of the pretrained model for debugging purposes.
**Aurora 0.25° Fine-Tuned**
- Fine-tuned on IFS HRES T0.
- Best performance at 0.25° but should only be used for IFS HRES T0 data.
- May not give optimal results for other datasets.
**Aurora 0.1° Fine-Tuned**
- For IFS HRES T0 at 0.1° resolution (1801x3600).
- Best performing at 0.1° resolution.
- Data must match IFS HRES T0 conditions.
**Required Variables & Pressure Levels:**
For all Aurora models at these resolutions, the following inputs are required:
- **Surface-level variables:** 2t, 10u, 10v, msl
- **Static variables:** lsm, slt, z
- **Atmospheric variables:** t, u, v, q, z
- **Pressure levels (hPa):** 50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000
Latitude range should decrease from 90° to -90° (north to south), and longitude range from 0° to 360° (excluding 360°). Data should be in single precision float32.
**Data Format (Batch):**
Data should be provided as a `aurora.Batch` object:
- `surf_vars` dict with shape (b, t, h, w)
- `static_vars` dict with shape (h, w)
- `atmos_vars` dict with shape (b, t, c, h, w)
- `metadata` containing lat, lon, time, and atmos_levels.
For detailed instructions and examples, refer to the official Aurora documentation and code repository.
""")

    # File uploader for Aurora data
    st.markdown("### Upload Your Input Data Files for Aurora")
    st.markdown("Upload the NetCDF files (e.g., `.nc`, `.netcdf`, `.nc4`) containing the required variables.")
    uploaded_files = st.file_uploader(
        "Drag and drop or select multiple .nc files",
        accept_multiple_files=True,
        key="aurora_uploader",
        type=["nc", "netcdf", "nc4"]
    )

    st.markdown("---")
    st.markdown("### References & Resources")
    # The usage example passes use_lora=False so that it matches how this
    # module itself loads the pretrained checkpoint in initialize_aurora_model.
    st.markdown("""
- **HuggingFace Repository:** [microsoft/aurora](https://huggingface.co/microsoft/aurora)
- **Model Usage Examples:**
```python
from aurora import Aurora
model = Aurora(use_lora=False)
model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
```
- **API & Documentation:** Refer to the Aurora official GitHub and HuggingFace pages for detailed instructions.
""")
    return uploaded_files
def prepare_aurora_batch(ds, *, time_index=6, history_offset=6):
    """Build an ``aurora.Batch`` from a dataset with a ``lev`` pressure dimension.

    The dataset is reduced to the pressure levels in ``desired_levels``. The
    1000 hPa slices of ``T``, ``U`` and ``V`` are used as the surface fields
    ``2t``, ``10u`` and ``10v``; ``SLP`` is used as ``msl`` and the first time
    step of ``PHIS`` as the static ``z``. The remaining levels of ``T``, ``U``
    and ``V`` become the atmospheric variables ``t``, ``u`` and ``v``.

    Args:
        ds: Dataset exposing ``dims``, ``sel``/``isel``, the coordinates
            ``lat``, ``lon``, ``lev`` and ``time``, and the variables ``T``,
            ``U``, ``V``, ``SLP`` and ``PHIS`` (presumably an xarray Dataset;
            confirm against the caller). Time is assumed to be the leading
            axis of every time-dependent variable.
        time_index: Position along ``time`` of the current step.
        history_offset: Number of positions between the previous step and the
            current step; the previous step is ``time_index - history_offset``.

    Returns:
        A ``Batch`` holding two time steps (previous, current) per variable.

    Raises:
        ValueError: If ``ds`` has no ``lev`` dimension or lacks one of the
            desired pressure levels.
        IndexError: If the previous or current time step lies outside the
            dataset, or if ``history_offset`` is smaller than 1.
    """
    desired_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]
    surface_level = 1000

    # Ensure that the 'lev' dimension exists
    if 'lev' not in ds.dims:
        raise ValueError("The dataset does not contain a 'lev' (pressure level) dimension.")

    # NOTE(review): these two lines only relabel the axes; the data arrays are
    # not flipped or rolled to match. That is only valid if the fields are
    # later interpreted on the relabelled grid -- confirm this is intended.
    lat = ds["lat"].values * -1
    lon = ds["lon"].values + 180

    # Subset to the desired pressure levels. "nearest" never fails on its own,
    # so the explicit check below is what detects absent levels.
    ds_subset = ds.sel(lev=desired_levels, method="nearest")
    lev = ds_subset["lev"].values  # Pressure levels in hPa
    missing_levels = set(desired_levels) - set(lev)
    if missing_levels:
        raise ValueError(f"The following desired pressure levels are missing in the dataset: {missing_levels}")

    surface_matches = np.flatnonzero(lev == surface_level)
    if surface_matches.size == 0:
        raise ValueError("1000 hPa level not found in the 'lev' dimension after subsetting.")
    lev_index_1000 = int(surface_matches[0])

    # Validate both time steps before any data is loaded. Checking the
    # previous step as well prevents a negative index from silently wrapping
    # around to the end of the time axis.
    num_times = ds_subset["time"].size
    previous_index = time_index - history_offset
    if not 0 <= previous_index < time_index < num_times:
        raise IndexError(
            f"Time steps {previous_index} and {time_index} are not both valid "
            f"for a dataset with {num_times} time steps."
        )
    current_time = (
        np.datetime64(ds_subset["time"].values[time_index])
        .astype('datetime64[s]')
        .astype(datetime)
    )

    # Atmospheric levels are all desired levels except the one used as surface.
    atmos_levels = [int(level) for level in lev if level != surface_level]

    def _prepare(x: np.ndarray) -> torch.Tensor:
        # Pick (previous, current) along the leading axis, then prepend a
        # batch axis of size one.
        selected = x[[previous_index, time_index]][None]
        return torch.from_numpy(np.ascontiguousarray(selected))

    def _surface(name: str) -> torch.Tensor:
        # Item access avoids clashes between variable names and attributes
        # of the dataset object.
        return _prepare(ds_subset[name].isel(lev=lev_index_1000).compute().values)

    def _atmos(name: str) -> torch.Tensor:
        return _prepare(ds_subset[name].sel(lev=atmos_levels).compute().values)

    # Surface variables: the 1000 hPa fields stand in for the near-surface ones.
    surf_vars = {
        "2t": _surface("T"),  # Two-meter temperature (1000 hPa proxy)
        "10u": _surface("U"),  # Ten-meter eastward wind (1000 hPa proxy)
        "10v": _surface("V"),  # Ten-meter northward wind (1000 hPa proxy)
        "msl": _prepare(ds_subset["SLP"].compute().values),  # Mean sea-level pressure
    }

    # Static variables are 2D: the first time index removes the time dimension.
    # NOTE(review): the help text in this module lists 'lsm' and 'slt' as
    # required static variables; they are not provided here.
    static_vars = {
        "z": torch.from_numpy(ds_subset["PHIS"].isel(time=0).compute().values.copy()),
    }

    # NOTE(review): the help text also lists atmospheric 'q' and 'z' as
    # required; only t, u and v are provided here.
    atmos_vars = {
        "t": _atmos("T"),  # Temperature at the atmospheric levels
        "u": _atmos("U"),  # Eastward wind at the atmospheric levels
        "v": _atmos("V"),  # Northward wind at the atmospheric levels
    }

    metadata = Metadata(
        # lat/lon are fresh arrays produced by the arithmetic above, so they
        # can be wrapped without a further copy.
        lat=torch.from_numpy(lat),
        lon=torch.from_numpy(lon),
        time=(current_time,),
        atmos_levels=tuple(atmos_levels),  # Only the atmospheric levels
    )

    return Batch(
        surf_vars=surf_vars,
        static_vars=static_vars,
        atmos_vars=atmos_vars,
        metadata=metadata,
    )
def initialize_aurora_model(device):
    """Create the 0.25° pretrained Aurora model and place it on ``device``.

    The model is built with ``use_lora=False`` and the checkpoint
    ``aurora-0.25-pretrained.ckpt`` is always loaded from the
    ``microsoft/aurora`` repository.

    Args:
        device: Target passed to the model's ``.to()`` call.

    Returns:
        The result of moving the loaded model to ``device``.
    """
    aurora_net = Aurora(use_lora=False)
    aurora_net.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
    return aurora_net.to(device)