Update app.py
app.py CHANGED

@@ -10,12 +10,7 @@ import warnings
 import logging
 import datetime
 import matplotlib.pyplot as plt
-import sunpy.
-import sunpy.net.attrs as a
-from sunpy.net import Fido
-from astropy.wcs import WCS
-import astropy.units as u
-from reproject import reproject_interp
+import sunpy.visualization.colormaps as sunpy_cm
 import traceback
 from io import BytesIO
 import re
@@ -31,22 +26,13 @@ logger = logging.getLogger(__name__)

 APP_CACHE = {}

-SDO_CHANNELS_MAP = {
-    "aia94": (a.Wavelength(94 * u.angstrom), a.Sample(12 * u.s)),
-    "aia131": (a.Wavelength(131 * u.angstrom), a.Sample(12 * u.s)),
-    "aia171": (a.Wavelength(171 * u.angstrom), a.Sample(12 * u.s)),
-    "aia193": (a.Wavelength(193 * u.angstrom), a.Sample(12 * u.s)),
-    "aia211": (a.Wavelength(211 * u.angstrom), a.Sample(12 * u.s)),
-    "aia304": (a.Wavelength(304 * u.angstrom), a.Sample(12 * u.s)),
-    "aia335": (a.Wavelength(335 * u.angstrom), a.Sample(12 * u.s)),
-    "aia1600": (a.Wavelength(1600 * u.angstrom), a.Sample(24 * u.s)),
-    "hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
-    "hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
-    "hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
-    "hmi_bz": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
-    "hmi_v": (a.Physobs("los_velocity"), a.Sample(45 * u.s)),
+CHANNEL_TO_URL_CODE = {
+    "aia94": "0094", "aia131": "0131", "aia171": "0171", "aia193": "0193",
+    "aia211": "0211", "aia304": "0304", "aia335": "0335", "aia1600": "1600",
+    "hmi_m": "HMIBC", "hmi_bx": "HMIB", "hmi_by": "HMIB",
+    "hmi_bz": "HMIB", "hmi_v": "HMID"
 }
-SDO_CHANNELS = list(SDO_CHANNELS_MAP.keys())
+SDO_CHANNELS = list(CHANNEL_TO_URL_CODE.keys())

 def setup_and_load_model():
     if "model" in APP_CACHE:
@@ -89,6 +75,41 @@ def setup_and_load_model():
     APP_CACHE["model"] = model
     yield "✅ Model setup complete."

+def find_nearest_browse_image_url(channel, target_dt):
+    url_code = CHANNEL_TO_URL_CODE[channel]
+    base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse"
+
+    for i in range(2):
+        dt_to_try = target_dt - datetime.timedelta(days=i)
+        dir_url = dt_to_try.strftime(f"{base_url}/%Y/%m/%d/")
+
+        response = requests.get(dir_url)
+        if response.status_code != 200:
+            continue
+
+        filenames = re.findall(r'href="(\d{8}_\d{6}_4096_' + url_code + r'\.jpg)"', response.text)
+        if not filenames:
+            continue
+
+        best_filename = ""
+        min_diff = float('inf')
+
+        for fname in filenames:
+            try:
+                timestamp_str = fname.split('_')[1]
+                img_dt = datetime.datetime.strptime(f"{dt_to_try.strftime('%Y%m%d')}{timestamp_str}", "%Y%m%d%H%M%S")
+                diff = abs((target_dt - img_dt).total_seconds())
+                if diff < min_diff:
+                    min_diff = diff
+                    best_filename = fname
+            except (ValueError, IndexError):
+                continue
+
+        if best_filename:
+            return dir_url + best_filename
+
+    raise FileNotFoundError(f"Could not find any browse images for {channel} in the last 48 hours.")
+
 def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
     config = APP_CACHE["config"]
     img_size = config["model"]["img_size"]
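
The new `find_nearest_browse_image_url` helper works by scraping the public SDO browse-image directory listing for the target day and picking the 4096px JPEG whose embedded timestamp is closest to the requested time. Below is a minimal, self-contained sketch of that lookup strategy, not the commit's code: it checks only a single day, and the `nearest_browse_image` name and the example channel code are illustrative; the directory layout and filename pattern are taken from the regex in the diff.

```python
import datetime
import re
from io import BytesIO

import requests
from PIL import Image

def nearest_browse_image(url_code: str, target_dt: datetime.datetime) -> str:
    """Return the URL of the 4096px browse JPEG closest in time to target_dt (single day only)."""
    base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse"
    dir_url = target_dt.strftime(f"{base_url}/%Y/%m/%d/")
    listing = requests.get(dir_url)
    listing.raise_for_status()
    # Filenames look like 20240101_123456_4096_0171.jpg, per the regex in the diff
    names = re.findall(r'href="(\d{8}_\d{6}_4096_' + url_code + r'\.jpg)"', listing.text)
    if not names:
        raise FileNotFoundError(f"no browse images for code {url_code} under {dir_url}")
    def gap(name: str) -> float:
        img_dt = datetime.datetime.strptime(name[:15], "%Y%m%d_%H%M%S")
        return abs((target_dt - img_dt).total_seconds())
    return dir_url + min(names, key=gap)

url = nearest_browse_image("0171", datetime.datetime.utcnow() - datetime.timedelta(hours=6))
img = Image.open(BytesIO(requests.get(url).content))
print(url, img.size)
```

The version in the diff additionally falls back one day if the current day's directory is empty and maps model channel names to filename codes through `CHANNEL_TO_URL_CODE`.
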
@@ -98,62 +119,36 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
|
|
98 |
target_time = target_dt + datetime.timedelta(minutes=forecast_horizon_minutes)
|
99 |
all_times = sorted(list(set(input_times + [target_time])))
|
100 |
|
101 |
-
|
102 |
-
last_successful_map = {} # Store the last good map for each channel
|
103 |
total_fetches = len(all_times) * len(SDO_CHANNELS)
|
104 |
fetches_done = 0
|
105 |
-
yield f"Starting search for {total_fetches}
|
106 |
|
107 |
for t in all_times:
|
108 |
-
|
109 |
for channel in SDO_CHANNELS:
|
110 |
fetches_done += 1
|
111 |
-
yield f"
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
if data_maps[t].get("hmi_bx"):
|
116 |
-
smap = data_maps[t]["hmi_bx"]
|
117 |
-
data_maps[t][channel] = smap
|
118 |
-
last_successful_map[channel] = smap
|
119 |
-
continue
|
120 |
-
|
121 |
-
physobs, sample = SDO_CHANNELS_MAP[channel]
|
122 |
-
time_attr = a.Time(t - datetime.timedelta(minutes=5), t + datetime.timedelta(minutes=5))
|
123 |
-
instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
|
124 |
-
query = Fido.search(time_attr, instrument, physobs, sample)
|
125 |
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
elif channel in last_successful_map:
|
132 |
-
# If the query fails, reuse the last successful map for this channel
|
133 |
-
yield f"⚠️ WARNING: No data for {channel} near {t}. Reusing previous image."
|
134 |
-
data_maps[t][channel] = last_successful_map[channel]
|
135 |
-
else:
|
136 |
-
# If the very first image for a channel fails, we cannot proceed.
|
137 |
-
raise ValueError(f"CRITICAL: No initial data found for {channel}. Cannot proceed.")
|
138 |
-
|
139 |
-
yield "✅ All data acquired. Starting preprocessing..."
|
140 |
-
output_wcs = WCS(naxis=2)
|
141 |
-
output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
|
142 |
-
output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
|
143 |
-
output_wcs.wcs.crval = [0, 0] * u.arcsec
|
144 |
-
output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
|
145 |
-
|
146 |
scalers_dict = APP_CACHE["scalers"]
|
147 |
processed_tensors = {}
|
148 |
-
for t,
|
149 |
channel_tensors = []
|
150 |
for i, channel in enumerate(SDO_CHANNELS):
|
151 |
-
|
152 |
-
|
|
|
153 |
|
154 |
-
|
155 |
-
|
156 |
-
norm_data = reprojected_data / exp_time
|
157 |
|
158 |
scaler = scalers_dict[channel]
|
159 |
scaled_data = scaler.transform(norm_data.reshape(-1, 1)).reshape(norm_data.shape)
|
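
The hunk above replaces the Fido/reproject pipeline with plain Pillow preprocessing: each browse JPEG is converted to grayscale, resized to the model's `img_size`, scaled per channel, and the per-timestep tensors are stacked into the model input. A rough sketch of that shape logic follows, using a throwaway `StandardScaler` in place of the Space's checkpoint-provided scalers; the `preprocess` helper, `CHANNELS` subset, `IMG_SIZE` value, and the assumption that each per-timestep tensor is stacked channel-first are all illustrative.

```python
import numpy as np
import torch
from PIL import Image
from sklearn.preprocessing import StandardScaler

IMG_SIZE = 128                              # illustrative; the app reads config["model"]["img_size"]
CHANNELS = ["aia94", "aia171", "hmi_m"]     # a small subset of SDO_CHANNELS, for brevity

def preprocess(img: Image.Image, scaler) -> torch.Tensor:
    if img.mode != 'L':                                    # browse JPEGs usually arrive as RGB
        img = img.convert('L')
    img = img.resize((IMG_SIZE, IMG_SIZE), Image.Resampling.LANCZOS)
    data = np.array(img, dtype=np.float32)
    scaled = scaler.transform(data.reshape(-1, 1)).reshape(data.shape)
    return torch.from_numpy(scaled)

# Throwaway scaler standing in for APP_CACHE["scalers"][channel]
scaler = StandardScaler().fit(np.random.rand(100, 1))

# Two input timesteps of dummy frames, one per channel
frames = [torch.stack([preprocess(Image.new('L', (512, 512)), scaler) for _ in CHANNELS])
          for _ in range(2)]                               # each entry: (C, H, W)
input_tensor = torch.stack(frames, dim=1).unsqueeze(0)     # (1, C, T, H, W)
print(input_tensor.shape)                                  # torch.Size([1, 3, 2, 128, 128])
```
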
@@ -163,10 +158,10 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
|
|
163 |
yield "✅ Preprocessing complete."
|
164 |
input_tensor_list = [processed_tensors[t] for t in input_times]
|
165 |
input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0)
|
166 |
-
|
167 |
-
|
168 |
|
169 |
-
yield (input_tensor,
|
170 |
|
171 |
def run_inference(input_tensor):
|
172 |
model = APP_CACHE["model"]
|
@@ -183,14 +178,20 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channe
|
|
183 |
if last_input_map is None: return None, None, None
|
184 |
c_idx = SDO_CHANNELS.index(channel_name)
|
185 |
|
|
|
186 |
scaler = APP_CACHE["scalers"][channel_name]
|
187 |
-
|
|
|
|
|
|
|
|
|
188 |
|
189 |
pred_slice = inverse_transform_single_channel(
|
190 |
prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
|
191 |
)
|
192 |
|
193 |
-
|
|
|
194 |
cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
|
195 |
cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
|
196 |
|
@@ -201,7 +202,7 @@
         colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
         return Image.fromarray(colored)

-    return
+    return last_input_map[channel_name], to_pil(pred_slice), target_map[channel_name]

 def forecast_controller(date_str, hour, minute, forecast_horizon):
     yield {
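
The two hunks above fix how `generate_visualization` reads its normalization parameters (as attributes on the per-channel scaler) and how each frame is rendered: values are clipped at the 99.5th-percentile of the target image and pushed through a sunpy SDO colormap. A small sketch of that rendering step follows; the clip-and-divide normalization is an assumption, since only the `vmax` quantile and the colormap lookup appear in the diff.

```python
import matplotlib.pyplot as plt
import numpy as np
import sunpy.visualization.colormaps as sunpy_cm
from PIL import Image

data = np.random.rand(256, 256).astype(np.float32) * 1000.0   # stand-in for one channel's image
vmax = np.quantile(np.nan_to_num(data), 0.995)                 # percentile clip level, as in the diff
data_norm = np.clip(data, 0, vmax) / vmax                      # assumed normalization to [0, 1]
cmap = plt.get_cmap(sunpy_cm.cmlist.get("sdoaia171", "gray"))  # same colormap lookup the diff uses
colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)   # drop alpha, convert to 8-bit RGB
Image.fromarray(colored).save("aia171_preview.png")
```
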
@@ -273,9 +274,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         """
         <div align='center'>
         # ☀️ Surya: Live Forecast Demo ☀️
-        ### A Foundation Model for Solar Dynamics
+        ### A Foundation Model for Solar Dynamics
         This demo runs NASA's **Surya**, a foundation model trained to understand the physics of the Sun.
         It looks at the Sun in 13 different channels (wavelengths of light) simultaneously to learn the complex relationships between phenomena like coronal loops, magnetic fields, and solar flares. By seeing these interconnected views, it can generate a holistic forecast of what the entire solar disk will look like in the near future.
+        <br>
+        <p style="color:red;font-weight:bold;">NOTE: This demo uses lower-quality browse images for reliability. The model was trained on high-fidelity scientific data, so forecast accuracy may vary.</p>
         </div>
         """
     )