broadfield-dev committed on
Commit
ca15132
·
verified ·
1 Parent(s): b02989a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -74
app.py CHANGED
@@ -10,7 +10,12 @@ import warnings
10
  import logging
11
  import datetime
12
  import matplotlib.pyplot as plt
13
- import sunpy.visualization.colormaps as sunpy_cm
 
 
 
 
 
14
  import traceback
15
  from io import BytesIO
16
  import re
@@ -26,13 +31,22 @@ logger = logging.getLogger(__name__)
26
 
27
  APP_CACHE = {}
28
 
29
- CHANNEL_TO_URL_CODE = {
30
- "aia94": "0094", "aia131": "0131", "aia171": "0171", "aia193": "0193",
31
- "aia211": "0211", "aia304": "0304", "aia335": "0335", "aia1600": "1600",
32
- "hmi_m": "HMIBC", "hmi_bx": "HMIB", "hmi_by": "HMIB",
33
- "hmi_bz": "HMIB", "hmi_v": "HMID"
 
 
 
 
 
 
 
 
 
34
  }
35
- SDO_CHANNELS = list(CHANNEL_TO_URL_CODE.keys())
36
 
37
  def setup_and_load_model():
38
  if "model" in APP_CACHE:
@@ -75,41 +89,6 @@ def setup_and_load_model():
75
  APP_CACHE["model"] = model
76
  yield "✅ Model setup complete."
77
 
78
- def find_nearest_browse_image_url(channel, target_dt):
79
- url_code = CHANNEL_TO_URL_CODE[channel]
80
- base_url = "https://sdo.gsfc.nasa.gov/assets/img/browse"
81
-
82
- for i in range(2):
83
- dt_to_try = target_dt - datetime.timedelta(days=i)
84
- dir_url = dt_to_try.strftime(f"{base_url}/%Y/%m/%d/")
85
-
86
- response = requests.get(dir_url)
87
- if response.status_code != 200:
88
- continue
89
-
90
- filenames = re.findall(r'href="(\d{8}_\d{6}_4096_' + url_code + r'\.jpg)"', response.text)
91
- if not filenames:
92
- continue
93
-
94
- best_filename = ""
95
- min_diff = float('inf')
96
-
97
- for fname in filenames:
98
- try:
99
- timestamp_str = fname.split('_')[1]
100
- img_dt = datetime.datetime.strptime(f"{dt_to_try.strftime('%Y%m%d')}{timestamp_str}", "%Y%m%d%H%M%S")
101
- diff = abs((target_dt - img_dt).total_seconds())
102
- if diff < min_diff:
103
- min_diff = diff
104
- best_filename = fname
105
- except (ValueError, IndexError):
106
- continue
107
-
108
- if best_filename:
109
- return dir_url + best_filename
110
-
111
- raise FileNotFoundError(f"Could not find any browse images for {channel} in the last 48 hours.")
112
-
113
  def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
114
  config = APP_CACHE["config"]
115
  img_size = config["model"]["img_size"]
@@ -119,36 +98,62 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
119
  target_time = target_dt + datetime.timedelta(minutes=forecast_horizon_minutes)
120
  all_times = sorted(list(set(input_times + [target_time])))
121
 
122
- images = {}
 
123
  total_fetches = len(all_times) * len(SDO_CHANNELS)
124
  fetches_done = 0
125
- yield f"Starting search for {total_fetches} data files..."
126
 
127
  for t in all_times:
128
- images[t] = {}
129
  for channel in SDO_CHANNELS:
130
  fetches_done += 1
131
- yield f"Finding [{fetches_done}/{total_fetches}]: Closest image for {channel} near {t.strftime('%Y-%m-%d %H:%M')}..."
132
 
133
- image_url = find_nearest_browse_image_url(channel, t)
134
- yield f"Downloading: {os.path.basename(image_url)}..."
135
-
136
- response = requests.get(image_url)
137
- response.raise_for_status()
138
- images[t][channel] = Image.open(BytesIO(response.content))
 
 
 
 
 
 
139
 
140
- yield "✅ All images found and downloaded. Starting preprocessing..."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
  scalers_dict = APP_CACHE["scalers"]
142
  processed_tensors = {}
143
- for t, channel_images in images.items():
144
  channel_tensors = []
145
  for i, channel in enumerate(SDO_CHANNELS):
146
- img = channel_images[channel]
147
- if img.mode != 'L':
148
- img = img.convert('L')
149
 
150
- img_resized = img.resize((img_size, img_size), Image.Resampling.LANCZOS)
151
- norm_data = np.array(img_resized, dtype=np.float32)
 
152
 
153
  scaler = scalers_dict[channel]
154
  scaled_data = scaler.transform(norm_data.reshape(-1, 1)).reshape(norm_data.shape)
@@ -158,10 +163,10 @@ def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
158
  yield "✅ Preprocessing complete."
159
  input_tensor_list = [processed_tensors[t] for t in input_times]
160
  input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0)
161
- target_image_map = images[target_time]
162
- last_input_image_map = images[input_times[-1]]
163
 
164
- yield (input_tensor, last_input_image_map, target_image_map)
165
 
166
  def run_inference(input_tensor):
167
  model = APP_CACHE["model"]
@@ -178,20 +183,14 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channe
178
  if last_input_map is None: return None, None, None
179
  c_idx = SDO_CHANNELS.index(channel_name)
180
 
181
- # *** FIX: Access the specific scaler for the channel from the dictionary ***
182
  scaler = APP_CACHE["scalers"][channel_name]
183
- # *** FIX: Access the parameters as attributes, not from to_dict() ***
184
- mean = scaler.mean
185
- std = scaler.std
186
- epsilon = scaler.epsilon
187
- sl_scale_factor = scaler.sl_scale_factor
188
 
189
  pred_slice = inverse_transform_single_channel(
190
  prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
191
  )
192
 
193
- target_img_data = np.array(target_map[channel_name])
194
- vmax = np.quantile(np.nan_to_num(target_img_data), 0.995)
195
  cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
196
  cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
197
 
@@ -202,7 +201,7 @@ def generate_visualization(last_input_map, prediction_tensor, target_map, channe
202
  colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
203
  return Image.fromarray(colored)
204
 
205
- return last_input_map[channel_name], to_pil(pred_slice), target_map[channel_name]
206
 
207
  def forecast_controller(date_str, hour, minute, forecast_horizon):
208
  yield {
@@ -274,11 +273,9 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
274
  """
275
  <div align='center'>
276
  # ☀️ Surya: Live Forecast Demo ☀️
277
- ### A Foundation Model for Solar Dynamics
278
  This demo runs NASA's **Surya**, a foundation model trained to understand the physics of the Sun.
279
  It looks at the Sun in 13 different channels (wavelengths of light) simultaneously to learn the complex relationships between phenomena like coronal loops, magnetic fields, and solar flares. By seeing these interconnected views, it can generate a holistic forecast of what the entire solar disk will look like in the near future.
280
- <br>
281
- <p style="color:red;font-weight:bold;">NOTE: This demo uses lower-quality browse images for reliability. The model was trained on high-fidelity scientific data, so forecast accuracy may vary.</p>
282
  </div>
283
  """
284
  )
 
10
  import logging
11
  import datetime
12
  import matplotlib.pyplot as plt
13
+ import sunpy.map
14
+ import sunpy.net.attrs as a
15
+ from sunpy.net import Fido
16
+ from astropy.wcs import WCS
17
+ import astropy.units as u
18
+ from reproject import reproject_interp
19
  import traceback
20
  from io import BytesIO
21
  import re
 
31
 
32
  APP_CACHE = {}
33
 
34
+ SDO_CHANNELS_MAP = {
35
+ "aia94": (a.Wavelength(94 * u.angstrom), a.Sample(12 * u.s)),
36
+ "aia131": (a.Wavelength(131 * u.angstrom), a.Sample(12 * u.s)),
37
+ "aia171": (a.Wavelength(171 * u.angstrom), a.Sample(12 * u.s)),
38
+ "aia193": (a.Wavelength(193 * u.angstrom), a.Sample(12 * u.s)),
39
+ "aia211": (a.Wavelength(211 * u.angstrom), a.Sample(12 * u.s)),
40
+ "aia304": (a.Wavelength(304 * u.angstrom), a.Sample(12 * u.s)),
41
+ "aia335": (a.Wavelength(335 * u.angstrom), a.Sample(12 * u.s)),
42
+ "aia1600": (a.Wavelength(1600 * u.angstrom), a.Sample(24 * u.s)),
43
+ "hmi_m": (a.Physobs("intensity"), a.Sample(45 * u.s)),
44
+ "hmi_bx": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
45
+ "hmi_by": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
46
+ "hmi_bz": (a.Physobs("los_magnetic_field"), a.Sample(720 * u.s)),
47
+ "hmi_v": (a.Physobs("los_velocity"), a.Sample(45 * u.s)),
48
  }
49
+ SDO_CHANNELS = list(SDO_CHANNELS_MAP.keys())
50
 
51
  def setup_and_load_model():
52
  if "model" in APP_CACHE:
 
89
  APP_CACHE["model"] = model
90
  yield "✅ Model setup complete."
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def fetch_and_process_sdo_data(target_dt, forecast_horizon_minutes):
93
  config = APP_CACHE["config"]
94
  img_size = config["model"]["img_size"]
 
98
  target_time = target_dt + datetime.timedelta(minutes=forecast_horizon_minutes)
99
  all_times = sorted(list(set(input_times + [target_time])))
100
 
101
+ data_maps = {}
102
+ last_successful_map = {} # Store the last good map for each channel
103
  total_fetches = len(all_times) * len(SDO_CHANNELS)
104
  fetches_done = 0
105
+ yield f"Starting search for {total_fetches} scientific data files..."
106
 
107
  for t in all_times:
108
+ data_maps[t] = {}
109
  for channel in SDO_CHANNELS:
110
  fetches_done += 1
111
+ yield f"Querying [{fetches_done}/{total_fetches}]: {channel} near {t.strftime('%Y-%m-%d %H:%M')}..."
112
 
113
+ # Handle placeholder channels by reusing hmi_bx
114
+ if channel in ["hmi_by", "hmi_bz"]:
115
+ if data_maps[t].get("hmi_bx"):
116
+ smap = data_maps[t]["hmi_bx"]
117
+ data_maps[t][channel] = smap
118
+ last_successful_map[channel] = smap
119
+ continue
120
+
121
+ physobs, sample = SDO_CHANNELS_MAP[channel]
122
+ time_attr = a.Time(t - datetime.timedelta(minutes=5), t + datetime.timedelta(minutes=5))
123
+ instrument = a.Instrument.hmi if "hmi" in channel else a.Instrument.aia
124
+ query = Fido.search(time_attr, instrument, physobs, sample)
125
 
126
+ if query:
127
+ files = Fido.fetch(query[0,0], path="./data/sdo_cache")
128
+ smap = sunpy.map.Map(files[0])
129
+ data_maps[t][channel] = smap
130
+ last_successful_map[channel] = smap # Save the good map
131
+ elif channel in last_successful_map:
132
+ # If the query fails, reuse the last successful map for this channel
133
+ yield f"⚠️ WARNING: No data for {channel} near {t}. Reusing previous image."
134
+ data_maps[t][channel] = last_successful_map[channel]
135
+ else:
136
+ # If the very first image for a channel fails, we cannot proceed.
137
+ raise ValueError(f"CRITICAL: No initial data found for {channel}. Cannot proceed.")
138
+
139
+ yield "✅ All data acquired. Starting preprocessing..."
140
+ output_wcs = WCS(naxis=2)
141
+ output_wcs.wcs.crpix = [(img_size + 1) / 2, (img_size + 1) / 2]
142
+ output_wcs.wcs.cdelt = np.array([-1.2, 1.2]) * u.arcsec
143
+ output_wcs.wcs.crval = [0, 0] * u.arcsec
144
+ output_wcs.wcs.ctype = ['HPLN-TAN', 'HPLT-TAN']
145
+
146
  scalers_dict = APP_CACHE["scalers"]
147
  processed_tensors = {}
148
+ for t, channel_maps in data_maps.items():
149
  channel_tensors = []
150
  for i, channel in enumerate(SDO_CHANNELS):
151
+ smap = channel_maps[channel]
152
+ reprojected_data, _ = reproject_interp(smap, output_wcs, shape_out=(img_size, img_size))
 
153
 
154
+ exp_time = smap.meta.get('exptime', 1.0)
155
+ if exp_time is None or exp_time <= 0: exp_time = 1.0
156
+ norm_data = reprojected_data / exp_time
157
 
158
  scaler = scalers_dict[channel]
159
  scaled_data = scaler.transform(norm_data.reshape(-1, 1)).reshape(norm_data.shape)
 
163
  yield "✅ Preprocessing complete."
164
  input_tensor_list = [processed_tensors[t] for t in input_times]
165
  input_tensor = torch.stack(input_tensor_list, dim=1).unsqueeze(0)
166
+ target_map = data_maps[target_time]
167
+ last_input_map = data_maps[input_times[-1]]
168
 
169
+ yield (input_tensor, last_input_map, target_map)
170
 
171
  def run_inference(input_tensor):
172
  model = APP_CACHE["model"]
 
183
  if last_input_map is None: return None, None, None
184
  c_idx = SDO_CHANNELS.index(channel_name)
185
 
 
186
  scaler = APP_CACHE["scalers"][channel_name]
187
+ mean, std, epsilon, sl_scale_factor = scaler.mean, scaler.std, scaler.epsilon, scaler.sl_scale_factor
 
 
 
 
188
 
189
  pred_slice = inverse_transform_single_channel(
190
  prediction_tensor[0, c_idx].numpy(), mean=mean, std=std, epsilon=epsilon, sl_scale_factor=sl_scale_factor
191
  )
192
 
193
+ vmax = np.quantile(np.nan_to_num(target_map[channel_name].data), 0.995)
 
194
  cmap_name = f"sdoaia{channel_name.replace('aia', '')}" if 'aia' in channel_name else 'hmimag'
195
  cmap = plt.get_cmap(sunpy_cm.cmlist.get(cmap_name, 'gray'))
196
 
 
201
  colored = (cmap(data_norm)[:, :, :3] * 255).astype(np.uint8)
202
  return Image.fromarray(colored)
203
 
204
+ return to_pil(last_input_map[channel_name].data), to_pil(pred_slice), to_pil(target_map[channel_name].data)
205
 
206
  def forecast_controller(date_str, hour, minute, forecast_horizon):
207
  yield {
 
273
  """
274
  <div align='center'>
275
  # ☀️ Surya: Live Forecast Demo ☀️
276
+ ### A Foundation Model for Solar Dynamics using High-Fidelity Scientific Data
277
  This demo runs NASA's **Surya**, a foundation model trained to understand the physics of the Sun.
278
  It looks at the Sun in 13 different channels (wavelengths of light) simultaneously to learn the complex relationships between phenomena like coronal loops, magnetic fields, and solar flares. By seeing these interconnected views, it can generate a holistic forecast of what the entire solar disk will look like in the near future.
 
 
279
  </div>
280
  """
281
  )