Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,6 @@
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import torch
|
5 |
-
from torch.utils.data import DataLoader
|
6 |
from huggingface_hub import snapshot_download
|
7 |
import yaml
|
8 |
import numpy as np
|
@@ -10,11 +9,9 @@ from PIL import Image
|
|
10 |
import sunpy.map
|
11 |
import sunpy.net.attrs as a
|
12 |
from sunpy.net import Fido
|
13 |
-
from sunpy.coordinates import Helioprojective
|
14 |
-
from astropy.coordinates import SkyCoord
|
15 |
from astropy.wcs import WCS
|
16 |
import astropy.units as u
|
17 |
-
from reproject import reproject_interp
|
18 |
import os
|
19 |
import warnings
|
20 |
import logging
|
@@ -27,7 +24,8 @@ from surya.models.helio_spectformer import HelioSpectFormer
|
|
27 |
from surya.utils.data import build_scalers, inverse_transform_single_channel
|
28 |
|
29 |
# --- Configuration ---
|
30 |
-
warnings.filterwarnings("ignore")
|
|
|
31 |
logging.basicConfig(level=logging.INFO)
|
32 |
logger = logging.getLogger(__name__)
|
33 |
|
@@ -94,14 +92,12 @@ def fetch_and_process_sdo_data(target_dt, progress):
|
|
94 |
config = APP_CACHE["config"]
|
95 |
img_size = config["model"]["img_size"][0]
|
96 |
|
97 |
-
# Define time windows for input and target (ground truth)
|
98 |
input_deltas = config["data"]["time_delta_input_minutes"]
|
99 |
target_delta = config["data"]["time_delta_target_minutes"][0]
|
100 |
input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
|
101 |
target_time = target_dt + datetime.timedelta(minutes=target_delta)
|
102 |
all_times = sorted(list(set(input_times + [target_time])))
|
103 |
|
104 |
-
# Download data for all required timestamps
|
105 |
data_maps = {}
|
106 |
total_downloads = len(all_times) * len(SDO_CHANNELS_MAP)
|
107 |
downloads_done = 0
|
@@ -110,53 +106,56 @@ def fetch_and_process_sdo_data(target_dt, progress):
|
|
110 |
for i, (channel, (physobs, sample)) in enumerate(SDO_CHANNELS_MAP.items()):
|
111 |
progress(downloads_done / total_downloads, desc=f"Downloading {channel} for {t.strftime('%H:%M')}...")
|
112 |
|
113 |
-
# HMI vector fields are not standard products, use LoS as a placeholder for demo
|
114 |
instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
|
115 |
if channel in ["hmi_by", "hmi_bz"]:
|
116 |
if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
|
117 |
continue
|
118 |
|
119 |
time_attr = a.Time(t - datetime.timedelta(minutes=10), t + datetime.timedelta(minutes=10))
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
if not query: raise ValueError(f"No data found for {channel} at {t}")
|
123 |
files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
|
124 |
data_maps[t][channel] = sunpy.map.Map(files[0])
|
125 |
downloads_done += 1
|
126 |
|
127 |
-
# Create target WCS for reprojection
|
128 |
output_wcs = WCS(naxis=2)
|
129 |
output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
|
130 |
output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
|
131 |
output_wcs.wcs.crval = [0, 0] * u.arcsec
|
132 |
output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
|
133 |
|
134 |
-
# Process data
|
135 |
processed_tensors = {}
|
|
|
|
|
136 |
for t, channel_maps in data_maps.items():
|
137 |
channel_tensors = []
|
138 |
for i, channel in enumerate(SDO_CHANNELS):
|
139 |
-
progress(
|
140 |
smap = channel_maps[channel]
|
141 |
|
142 |
-
# Reproject to common grid
|
143 |
reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))
|
144 |
|
145 |
-
# Normalize by exposure time and apply signed-log transform
|
146 |
exp_time = smap.meta.get('exptime', 1.0)
|
147 |
-
if exp_time <= 0: exp_time = 1.0
|
148 |
norm_data = reprojected_data / exp_time
|
149 |
|
150 |
-
# Apply the same scaling as the training pipeline
|
151 |
scaler = APP_CACHE["scalers"][channel]
|
152 |
scaled_data = scaler.transform(norm_data)
|
153 |
channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
|
|
|
154 |
processed_tensors[t] = torch.stack(channel_tensors)
|
155 |
|
156 |
-
# Assemble final input and target tensors
|
157 |
input_tensor_list = [processed_tensors[t] for t in input_times]
|
158 |
-
input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0)
|
159 |
-
target_map = data_maps[target_time]
|
160 |
last_input_map = data_maps[input_times[-1]]
|
161 |
|
162 |
return input_tensor, last_input_map, target_map
|
@@ -191,8 +190,7 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channe
|
|
191 |
mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
|
192 |
)
|
193 |
|
194 |
-
|
195 |
-
vmax = np.quantile(target_map[channel_name].data, 0.995)
|
196 |
cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
|
197 |
cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
|
198 |
|
@@ -206,7 +204,6 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channe
|
|
206 |
|
207 |
return to_pil(last_input_map[channel_name].data), to_pil(pred_slice, flip=True), to_pil(target_map[channel_name].data)
|
208 |
|
209 |
-
|
210 |
# --- 4. Gradio UI and Controllers ---
|
211 |
def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
|
212 |
try:
|
@@ -223,13 +220,12 @@ def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
|
|
223 |
|
224 |
prediction_tensor = run_inference(input_tensor)
|
225 |
|
226 |
-
# Default visualization for aia171
|
227 |
img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
|
228 |
|
229 |
status = f"Forecast complete for {target_dt.isoformat()}. Ready to explore channels."
|
230 |
logger.info(status)
|
231 |
|
232 |
-
return (last_input_map, prediction_tensor, target_map,
|
233 |
img_in, img_pred, img_target, status, gr.update(visible=True))
|
234 |
|
235 |
except Exception as e:
|
@@ -243,7 +239,6 @@ def update_visualization_controller(last_input_map, prediction_tensor, target_ma
|
|
243 |
|
244 |
|
245 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
246 |
-
# State objects to hold the data after a forecast is run
|
247 |
state_last_input = gr.State()
|
248 |
state_prediction = gr.State()
|
249 |
state_target = gr.State()
|
@@ -263,7 +258,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
263 |
|
264 |
with gr.Row():
|
265 |
datetime_input = gr.Textbox(label="Enter Forecast Start Time (YYYY-MM-DD HH:MM:SS)",
|
266 |
-
value=(datetime.datetime.now() - datetime.timedelta(hours=
|
267 |
run_button = gr.Button("🔮 Generate Forecast", variant="primary")
|
268 |
|
269 |
with gr.Group(visible=False) as results_group:
|
@@ -288,6 +283,5 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
288 |
)
|
289 |
|
290 |
if __name__ == "__main__":
|
291 |
-
# Create cache directory if it doesn't exist
|
292 |
os.makedirs("./data/sdo_cache", exist_ok=True)
|
293 |
demo.launch(debug=True)
|
|
|
2 |
|
3 |
import gradio as gr
|
4 |
import torch
|
|
|
5 |
from huggingface_hub import snapshot_download
|
6 |
import yaml
|
7 |
import numpy as np
|
|
|
9 |
import sunpy.map
|
10 |
import sunpy.net.attrs as a
|
11 |
from sunpy.net import Fido
|
|
|
|
|
12 |
from astropy.wcs import WCS
|
13 |
import astropy.units as u
|
14 |
+
from reproject import reproject_interp # Correct import statement
|
15 |
import os
|
16 |
import warnings
|
17 |
import logging
|
|
|
24 |
from surya.utils.data import build_scalers, inverse_transform_single_channel
|
25 |
|
26 |
# --- Configuration ---
|
27 |
+
warnings.filterwarnings("ignore", category=UserWarning, module='sunpy')
|
28 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
29 |
logging.basicConfig(level=logging.INFO)
|
30 |
logger = logging.getLogger(__name__)
|
31 |
|
|
|
92 |
config = APP_CACHE["config"]
|
93 |
img_size = config["model"]["img_size"][0]
|
94 |
|
|
|
95 |
input_deltas = config["data"]["time_delta_input_minutes"]
|
96 |
target_delta = config["data"]["time_delta_target_minutes"][0]
|
97 |
input_times = [target_dt + datetime.timedelta(minutes=m) for m in input_deltas]
|
98 |
target_time = target_dt + datetime.timedelta(minutes=target_delta)
|
99 |
all_times = sorted(list(set(input_times + [target_time])))
|
100 |
|
|
|
101 |
data_maps = {}
|
102 |
total_downloads = len(all_times) * len(SDO_CHANNELS_MAP)
|
103 |
downloads_done = 0
|
|
|
106 |
for i, (channel, (physobs, sample)) in enumerate(SDO_CHANNELS_MAP.items()):
|
107 |
progress(downloads_done / total_downloads, desc=f"Downloading {channel} for {t.strftime('%H:%M')}...")
|
108 |
|
|
|
109 |
instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
|
110 |
if channel in ["hmi_by", "hmi_bz"]:
|
111 |
if data_maps[t].get("hmi_bx"): data_maps[t][channel] = data_maps[t]["hmi_bx"]
|
112 |
continue
|
113 |
|
114 |
time_attr = a.Time(t - datetime.timedelta(minutes=10), t + datetime.timedelta(minutes=10))
|
115 |
+
search_query = [time_attr, physobs, sample]
|
116 |
+
# AIA and HMI queries are slightly different
|
117 |
+
if "aia" in channel:
|
118 |
+
search_query.append(a.Instrument.aia)
|
119 |
+
else:
|
120 |
+
search_query.append(a.Instrument.hmi)
|
121 |
+
|
122 |
+
query = Fido.search(*search_query)
|
123 |
|
124 |
if not query: raise ValueError(f"No data found for {channel} at {t}")
|
125 |
files = Fido.fetch(query[0, 0], path="./data/sdo_cache")
|
126 |
data_maps[t][channel] = sunpy.map.Map(files[0])
|
127 |
downloads_done += 1
|
128 |
|
|
|
129 |
output_wcs = WCS(naxis=2)
|
130 |
output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
|
131 |
output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
|
132 |
output_wcs.wcs.crval = [0, 0] * u.arcsec
|
133 |
output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
|
134 |
|
|
|
135 |
processed_tensors = {}
|
136 |
+
total_processing = len(all_times) * len(SDO_CHANNELS)
|
137 |
+
processing_done = 0
|
138 |
for t, channel_maps in data_maps.items():
|
139 |
channel_tensors = []
|
140 |
for i, channel in enumerate(SDO_CHANNELS):
|
141 |
+
progress(processing_done / total_processing, desc=f"Processing {channel} for {t.strftime('%H:%M')}...")
|
142 |
smap = channel_maps[channel]
|
143 |
|
|
|
144 |
reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))
|
145 |
|
|
|
146 |
exp_time = smap.meta.get('exptime', 1.0)
|
147 |
+
if exp_time is None or exp_time <= 0: exp_time = 1.0
|
148 |
norm_data = reprojected_data / exp_time
|
149 |
|
|
|
150 |
scaler = APP_CACHE["scalers"][channel]
|
151 |
scaled_data = scaler.transform(norm_data)
|
152 |
channel_tensors.append(torch.from_numpy(scaled_data.astype(np.float32)))
|
153 |
+
processing_done += 1
|
154 |
processed_tensors[t] = torch.stack(channel_tensors)
|
155 |
|
|
|
156 |
input_tensor_list = [processed_tensors[t] for t in input_times]
|
157 |
+
input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0)
|
158 |
+
target_map = data_maps[target_time]
|
159 |
last_input_map = data_maps[input_times[-1]]
|
160 |
|
161 |
return input_tensor, last_input_map, target_map
|
|
|
190 |
mean=means[c_idx], std=stds[c_idx], epsilon=epsilons[c_idx], sl_scale_factor=sl_scale_factors[c_idx]
|
191 |
)
|
192 |
|
193 |
+
vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
|
|
|
194 |
cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
|
195 |
cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
|
196 |
|
|
|
204 |
|
205 |
return to_pil(last_input_map[channel_name].data), to_pil(pred_slice, flip=True), to_pil(target_map[channel_name].data)
|
206 |
|
|
|
207 |
# --- 4. Gradio UI and Controllers ---
|
208 |
def forecast_controller(dt_str, progress=gr.Progress(track_tqdm=True)):
|
209 |
try:
|
|
|
220 |
|
221 |
prediction_tensor = run_inference(input_tensor)
|
222 |
|
|
|
223 |
img_in, img_pred, img_target = generate_visualization(last_input_map, prediction_tensor, target_map, "aia171")
|
224 |
|
225 |
status = f"Forecast complete for {target_dt.isoformat()}. Ready to explore channels."
|
226 |
logger.info(status)
|
227 |
|
228 |
+
return (last_input_map, prediction_tensor, target_map,
|
229 |
img_in, img_pred, img_target, status, gr.update(visible=True))
|
230 |
|
231 |
except Exception as e:
|
|
|
239 |
|
240 |
|
241 |
with gr.Blocks(theme=gr.themes.Soft()) as demo:
|
|
|
242 |
state_last_input = gr.State()
|
243 |
state_prediction = gr.State()
|
244 |
state_target = gr.State()
|
|
|
258 |
|
259 |
with gr.Row():
|
260 |
datetime_input = gr.Textbox(label="Enter Forecast Start Time (YYYY-MM-DD HH:MM:SS)",
|
261 |
+
value=(datetime.datetime.now() - datetime.timedelta(hours=3)).strftime("%Y-%m-%d %H:%M:%S"))
|
262 |
run_button = gr.Button("🔮 Generate Forecast", variant="primary")
|
263 |
|
264 |
with gr.Group(visible=False) as results_group:
|
|
|
283 |
)
|
284 |
|
285 |
if __name__ == "__main__":
|
|
|
286 |
os.makedirs("./data/sdo_cache", exist_ok=True)
|
287 |
demo.launch(debug=True)
|