import streamlit as st
import torch
from aurora import Aurora, Batch, Metadata
import numpy as np
from datetime import datetime

def aurora_config_ui():
    st.subheader("Aurora Model Data Upload")
    st.markdown("### Drag and Drop Your Data Files Here")
    uploaded_files = st.file_uploader(
        "Upload Data Files for Aurora",
        accept_multiple_files=True,
        key="aurora_uploader",
        type=["nc", "netcdf", "nc4"]
    )
    return uploaded_files
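
# A minimal sketch (not part of the original app) of how an uploaded file
# could be turned into an xarray Dataset for prepare_aurora_batch below.
# Assumption: reading NetCDF from an in-memory buffer requires an engine
# that accepts file-like objects, e.g. h5netcdf for NetCDF4/HDF5 files.
def open_uploaded_dataset(uploaded_file):
    import io

    import xarray as xr

    # Streamlit's UploadedFile yields the raw bytes of the upload
    data = io.BytesIO(uploaded_file.read())
    return xr.open_dataset(data, engine="h5netcdf")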

def prepare_aurora_batch(ds):
    desired_levels = [50, 100, 150, 200, 250, 300, 400, 500, 600, 700, 850, 925, 1000]

    # Ensure that the 'lev' dimension exists
    if 'lev' not in ds.dims:
        raise ValueError("The dataset does not contain a 'lev' (pressure level) dimension.")

    # Define the _prepare function
    def _prepare(x: np.ndarray, i: int) -> torch.Tensor:
        # Select the previous (i - 6) and current (i) time steps
        selected = x[[i - 6, i]]

        # Add a batch dimension
        selected = selected[None]

        # Ensure the data is contiguous in memory
        selected = selected.copy()

        # Convert to a PyTorch tensor
        return torch.from_numpy(selected)
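
    # Shape note: after _prepare, surface fields are (1, 2, h, w) and
    # atmospheric fields (1, 2, c, h, w); the second axis holds the two
    # input time steps, matching Aurora's documented batch format.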

    # Adjust latitudes and longitudes: Aurora expects latitudes to be
    # decreasing and longitudes to lie in [0, 360)
    lat = ds.lat.values * -1
    lon = ds.lon.values + 180

    # Subset the dataset to only include the desired pressure levels
    ds_subset = ds.sel(lev=desired_levels, method="nearest")

    # Verify that all desired levels are present
    present_levels = ds_subset.lev.values
    missing_levels = set(desired_levels) - set(present_levels)
    if missing_levels:
        raise ValueError(f"The following desired pressure levels are missing in the dataset: {missing_levels}")

    # Extract pressure levels after subsetting
    lev = ds_subset.lev.values  # Pressure levels in hPa

    # Prepare surface variables at 1000 hPa
    try:
        lev_index_1000 = np.where(lev == 1000)[0][0]
    except IndexError:
        raise ValueError("1000 hPa level not found in the 'lev' dimension after subsetting.")

    T_surface = ds_subset.T.isel(lev=lev_index_1000).compute()
    U_surface = ds_subset.U.isel(lev=lev_index_1000).compute()
    V_surface = ds_subset.V.isel(lev=lev_index_1000).compute()
    SLP = ds_subset.SLP.compute()

    # Static variable: surface geopotential (select the first time index to drop the time dimension)
    PHIS = ds_subset.PHIS.isel(time=0).compute()

    # Prepare atmospheric variables for the desired pressure levels, excluding 1000 hPa
    atmos_levels = [int(level) for level in lev if level != 1000]

    T_atm = (ds_subset.T.sel(lev=atmos_levels)).compute()
    U_atm = (ds_subset.U.sel(lev=atmos_levels)).compute()
    V_atm = (ds_subset.V.sel(lev=atmos_levels)).compute()

    # Select the time index; steps i - 6 and i are used as the two input times
    num_times = ds_subset.time.size
    i = 6  # Adjust as needed (6 <= i < num_times)

    if i >= num_times or i < 6:
        raise IndexError("Time index i is out of bounds; need 6 <= i < num_times.")

    time_values = ds_subset.time.values
    current_time = np.datetime64(time_values[i]).astype('datetime64[s]').astype(datetime)

    # Prepare surface variables (approximated here by the 1000 hPa fields)
    surf_vars = {
        "2t": _prepare(T_surface.values, i),   # Two-meter temperature
        "10u": _prepare(U_surface.values, i),  # Ten-meter eastward wind
        "10v": _prepare(V_surface.values, i),  # Ten-meter northward wind
        "msl": _prepare(SLP.values, i),        # Mean sea-level pressure
    }

    # Prepare static variables (2D tensors of shape (h, w))
    static_vars = {
        "z": torch.from_numpy(PHIS.values.copy()),  # Surface geopotential
        # Add 'lsm' (land-sea mask) and 'slt' (soil type) if available and needed
    }

    # Prepare atmospheric variables
    atmos_vars = {
        "t": _prepare(T_atm.values, i),  # Temperature at the desired levels
        "u": _prepare(U_atm.values, i),  # Eastward wind at the desired levels
        "v": _prepare(V_atm.values, i),  # Northward wind at the desired levels
    }

    # Define metadata
    metadata = Metadata(
        lat=torch.from_numpy(lat.copy()),
        lon=torch.from_numpy(lon.copy()),
        time=(current_time,),
        atmos_levels=tuple(atmos_levels),  # Only the desired atmospheric levels
    )

    # Create the Batch object
    batch = Batch(
        surf_vars=surf_vars,
        static_vars=static_vars,
        atmos_vars=atmos_vars,
        metadata=metadata,
    )
    return batch

def initialize_aurora_model(device):
    model = Aurora(use_lora=False)
    # Load the pretrained checkpoint from the Hugging Face Hub
    model.load_checkpoint("microsoft/aurora", "aurora-0.25-pretrained.ckpt")
    model = model.to(device)
    return model
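
# A minimal end-to-end sketch of how the functions above could be wired
# together. The rollout import and inference pattern follow Aurora's
# published API; open_uploaded_dataset (defined above) and the step count
# are illustrative assumptions.
def run_aurora_demo():
    from aurora import rollout

    uploaded_files = aurora_config_ui()
    if not uploaded_files:
        return

    ds = open_uploaded_dataset(uploaded_files[0])
    batch = prepare_aurora_batch(ds)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = initialize_aurora_model(device)
    model.eval()

    # Roll the model forward two 6-hour steps without tracking gradients
    with torch.inference_mode():
        preds = [pred.to("cpu") for pred in rollout(model, batch, steps=2)]

    st.write(f"Produced {len(preds)} prediction steps.")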