Khalid Rafiq committed
Commit f270024 · 1 Parent(s): 3b01b0f

Convert app.ipynb to app.py for faster execution

Files changed (1)
  1. app.py +374 -13
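
For reference, the removed startup shim below re-converted the notebook on every launch; this commit checks in the exported script instead. The same one-time export can be done offline with nbconvert's CLI (equivalent to the removed PythonExporter call; assumes nbconvert is installed):

    jupyter nbconvert --to script app.ipynb   # writes app.py next to the notebook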
app.py CHANGED
@@ -1,14 +1,375 @@
  import gradio as gr
- import nbformat
- from nbconvert import PythonExporter
- import subprocess
-
- def run_notebook():
-     with open("app.ipynb") as f:
-         notebook = nbformat.read(f, as_version=4)
-     python_script, _ = PythonExporter().from_notebook_node(notebook)
-     with open("app_temp.py", "w") as f:
-         f.write(python_script)
-     subprocess.run(["python", "app_temp.py"])
-
- run_notebook()
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # In[2]:
+
+
+ import time
+ import torch
+ import warnings
+ import numpy as np
  import gradio as gr
+ import matplotlib.pyplot as plt
+
+ # Import Burgers' equation components
+ from data_burgers import exact_solution as exact_solution_burgers
+ from model_io_burgers import load_model
+ from model_v2 import Encoder, Decoder, Propagator_concat as Propagator, Model
+ from LSTM_model import AE_Encoder, AE_Decoder, AE_Model, PytorchLSTM
+
+ # Import Advection-Diffusion components
+ from data_adv_dif import exact_solution as exact_solution_adv_dif
+ from model_io_adv_dif import load_model as load_model_adv_dif
+ from model_adv_dif import Encoder as Encoder2D, Decoder as Decoder2D, Propagator_concat as Propagator2D, Model as Model2D
+
+ warnings.filterwarnings("ignore")
+
+ # ========== Burgers' Equation Setup ==========
+ def get_burgers_model(input_dim, latent_dim):
+     encoder = Encoder(input_dim, latent_dim)
+     decoder = Decoder(latent_dim, input_dim)
+     propagator = Propagator(latent_dim)
+     return Model(encoder, decoder, propagator)
+
+ flexi_prop_model = get_burgers_model(128, 2)
+ checkpoint = torch.load("../1d_viscous_burgers/FlexiPropagator_2025-02-01-10-28-34_3e9656b5_best.pt", map_location='cpu')
+ flexi_prop_model.load_state_dict(checkpoint['model_state_dict'])
+ flexi_prop_model.eval()
+
+ # AE LSTM models
+ ae_encoder = AE_Encoder(128)
+ ae_decoder = AE_Decoder(2, 128)
+ ae_model = AE_Model(ae_encoder, ae_decoder)
+ lstm_model = PytorchLSTM()
+
+ ae_encoder.load_state_dict(torch.load("../1d_viscous_burgers/LSTM_model/ae_encoder_weights.pth", map_location='cpu'))
+ ae_decoder.load_state_dict(torch.load("../1d_viscous_burgers/LSTM_model/ae_decoder_weights.pth", map_location='cpu'))
+ ae_model.load_state_dict(torch.load("../1d_viscous_burgers/LSTM_model/ae_model.pth", map_location='cpu'))
+ lstm_model.load_state_dict(torch.load("../1d_viscous_burgers/LSTM_model/lstm_weights.pth", map_location='cpu'))
+
+ # ========== Helper Functions Burgers ==========
+ def exacts_equals_timewindow(t_0, Re, time_window=40):
+     dt = 2 / 500
+     solutions = [exact_solution_burgers(Re, t) for t in (t_0 + np.arange(0, time_window) * dt)]
+     solns = torch.tensor(solutions, dtype=torch.float32)[None, :, :]
+     latents = ae_encoder(solns)
+     re_normalized = Re / 1000
+     re_repeated = torch.ones(1, time_window, 1) * re_normalized
+     return torch.cat((latents, re_repeated), dim=2), latents, solns
+
+ # Precompute contour plots
+ z1_vals = np.linspace(-10, 0.5, 200)
+ z2_vals = np.linspace(5, 32, 200)
+ Z1, Z2 = np.meshgrid(z1_vals, z2_vals)
+ latent_grid = np.stack([Z1.ravel(), Z2.ravel()], axis=1)
+
+ # Convert to tensor for decoding
+ latent_tensors = torch.tensor(latent_grid, dtype=torch.float32)
+
+ # Decode latent vectors and compute properties
+ with torch.no_grad():
+     decoded_signals = flexi_prop_model.decoder(latent_tensors)
+
+ sharpness = []
+ peak_positions = []
+ x_vals = np.linspace(0, 2, decoded_signals.shape[1])
+ dx = x_vals[1] - x_vals[0]
+
+ for signal in decoded_signals.numpy():
+     grad_u = np.gradient(signal, dx)
+     sharpness.append(np.max(np.abs(grad_u)))
+     peak_positions.append(x_vals[np.argmax(signal)])
+
+ sharpness = np.array(sharpness).reshape(Z1.shape)
+ peak_positions = np.array(peak_positions).reshape(Z1.shape)
+
+ def plot_burgers_comparison(Re, tau, t_0):
+     dt = 2.0 / 500.0
+     t_final = t_0 + tau * dt
+     x_exact = exact_solution_burgers(Re, t_final)
+
+     tau_tensor, Re_tensor, xt = torch.tensor([tau]).float()[:, None], torch.tensor([Re]).float()[:, None], torch.tensor([exact_solution_burgers(Re, t_0)]).float()[:, None]
+
+     with torch.no_grad():
+         _, x_hat_tau, *_ = flexi_prop_model(xt, tau_tensor, Re_tensor)
+
+     latent_for_lstm, *_ = exacts_equals_timewindow(t_0, Re)
+     with torch.no_grad():
+         for _ in range(40, tau):
+             pred = lstm_model(latent_for_lstm)
+             pred_with_re = torch.cat((pred, torch.tensor([[Re / 1000]], dtype=torch.float32)), dim=1)
+             latent_for_lstm = torch.cat((latent_for_lstm[:, 1:, :], pred_with_re.unsqueeze(0)), dim=1)
+     final_pred_high_dim = ae_decoder(pred.unsqueeze(0))
+
+     fig, ax = plt.subplots(figsize=(9, 5))
+     ax.plot(xt.squeeze(), '--', linewidth=3, alpha=0.5, color="C0")
+     ax.plot(x_hat_tau.squeeze(), 'D', markersize=5, color="C2")
+     ax.plot(final_pred_high_dim.squeeze().detach().numpy(), '^', markersize=5, color="C1")
+     ax.plot(x_exact.squeeze(), linewidth=2, alpha=0.5, color="Black")
+     ax.set_title(f"Comparison ($t_0$={t_0:.2f} → $t_f$={t_final:.2f}), τ={tau}", fontsize=14)
+     ax.legend(["Initial", "Flexi-Prop", "AE LSTM", "True"])
+     return fig
+
+ def burgers_update(Re, tau, t0):
+     fig1 = plot_burgers_comparison(Re, tau, t0)
+
+     # Timing calculations
+     start = time.time()
+     _ = flexi_prop_model(torch.randn(1, 1, 128), torch.tensor([[tau]]), torch.tensor([[Re]]))
+     flexi_time = time.time() - start
+
+     start = time.time()
+     latent_for_lstm, _, _ = exacts_equals_timewindow(t0, Re, 40)
+     encode_time = time.time() - start
+
+     start = time.time()
+     with torch.no_grad():
+         for _ in range(40, tau):
+             pred = lstm_model(latent_for_lstm)
+             pred_with_re = torch.cat((pred, torch.tensor([[Re / 1000]], dtype=torch.float32)), dim=1)
+             latent_for_lstm = torch.cat((latent_for_lstm[:, 1:, :], pred_with_re.unsqueeze(0)), dim=1)
+     recursion_time = time.time() - start
+
+     start = time.time()
+     final_pred_high_dim = ae_decoder(pred.unsqueeze(0))
+     decode_time = time.time() - start
+
+     ae_lstm_total_time = encode_time + recursion_time + decode_time
+     time_ratio = ae_lstm_total_time / flexi_time
+
+     # Time plot
+     fig, ax = plt.subplots(figsize=(11, 6))
+     ax.bar(["Flexi-Prop", "AE LSTM (Encode)", "AE LSTM (Recursion)", "AE LSTM (Decode)", "AE LSTM (Total)"],
+            [flexi_time, encode_time, recursion_time, decode_time, ae_lstm_total_time],
+            color=["C0", "C1", "C2", "C3", "C4"])
+     ax.set_ylabel("Time (s)", fontsize=14)
+     ax.set_title("Computation Time Comparison", fontsize=14)
+     ax.grid(alpha=0.3)
+
+     # Latent space visualization
+     latent_fig = plot_latent_interpretation(Re, tau, t0)
+
+     return fig1, fig, time_ratio, latent_fig
+
+ def plot_latent_interpretation(Re, tau, t_0):
+     tau_tensor = torch.tensor([tau]).float()[:, None]
+     Re_tensor = torch.tensor([Re]).float()[:, None]
+     x_t = exact_solution_burgers(Re, t_0)
+     xt = torch.tensor([x_t]).float()[:, None]
+
+     with torch.no_grad():
+         _, _, _, _, z_tau = flexi_prop_model(xt, tau_tensor, Re_tensor)
+
+     z_tau = z_tau.squeeze().numpy()
+
+     fig, axes = plt.subplots(1, 2, figsize=(9, 3))
+
+     # Sharpness plot
+     c1 = axes[0].pcolormesh(Z1, Z2, sharpness, cmap='plasma', shading='gouraud')
+     axes[0].scatter(z_tau[0], z_tau[1], color='red', marker='o', s=50, label="Current State")
+     axes[0].set_ylabel("$Z_2$", fontsize=14)
+     axes[0].set_title("Sharpness Encoding", fontsize=14)
+     fig.colorbar(c1, ax=axes[0])
+     axes[0].legend()
+
+     # Peak-position plot
+     c2 = axes[1].pcolormesh(Z1, Z2, peak_positions, cmap='viridis', shading='gouraud')
+     axes[1].scatter(z_tau[0], z_tau[1], color='red', marker='o', s=50, label="Current State")
+     axes[1].set_title("Peak Position Encoding", fontsize=14)
+     fig.colorbar(c2, ax=axes[1], label="Peak Position")
+
+     # Remove redundant y-axis labels on the second plot for better aesthetics
+     axes[1].set_yticklabels([])
+
+     # Set a single x-axis label centered below both plots
+     fig.supxlabel("$Z_1$", fontsize=14)
+
+     return fig
+
+ # ========== Advection-Diffusion Setup ==========
+ def get_adv_dif_model(latent_dim, output_dim):
+     encoder = Encoder2D(latent_dim)
+     decoder = Decoder2D(latent_dim)
+     propagator = Propagator2D(latent_dim)
+     return Model2D(encoder, decoder, propagator)
+
+ adv_dif_model = get_adv_dif_model(3, 128)
+ adv_dif_model, _, _, _ = load_model_adv_dif(
+     "../2D_adv_dif/FlexiPropagator_2D_2025-01-30-12-11-01_0aee8fb0_best.pt",
+     adv_dif_model
+ )
+
+ def generate_3d_visualization(Re, t_0, tau):
+     dt = 2 / 500
+     t = t_0 + tau * dt
+
+     U_initial = exact_solution_adv_dif(Re, t_0)
+     U_evolved = exact_solution_adv_dif(Re, t)
+
+     if np.isnan(U_initial).any() or np.isnan(U_evolved).any():
+         return None
+
+     fig3d = plt.figure(figsize=(12, 5))
+     ax3d = fig3d.add_subplot(111, projection='3d')
+
+     x_vals = np.linspace(-2, 2, U_initial.shape[1])
+     y_vals = np.linspace(-2, 2, U_initial.shape[0])
+     X, Y = np.meshgrid(x_vals, y_vals)
+
+     surf1 = ax3d.plot_surface(X, Y, U_initial, cmap="viridis", alpha=0.6, label="Initial")
+     surf2 = ax3d.plot_surface(X, Y, U_evolved, cmap="plasma", alpha=0.8, label="Evolved")
+
+     ax3d.set_xlim(-3, 3)
+     ax3d.set_xlabel("x")
+     ax3d.set_ylabel("y")
+     ax3d.set_zlabel("u(x,y,t)")
+     ax3d.view_init(elev=25, azim=-45)
+     ax3d.set_box_aspect((2, 1, 1))
+
+     fig3d.colorbar(surf1, ax=ax3d, shrink=0.5, label="Initial")
+     fig3d.colorbar(surf2, ax=ax3d, shrink=0.5, label="Evolved")
+     ax3d.set_title(f"Solution Evolution\nInitial ($t_0$={t_0:.2f}) vs Evolved ($t_f$={t:.2f})")
+
+     plt.tight_layout()
+     plt.close(fig3d)
+     return fig3d
+
+ def adv_dif_comparison(Re, t_0, tau):
+     dt = 2 / 500
+     exact_initial = exact_solution_adv_dif(Re, t_0)
+     exact_final = exact_solution_adv_dif(Re, t_0 + tau * dt)
+
+     if np.isnan(exact_initial).any() or np.isnan(exact_final).any():
+         return None
+
+     x_in = torch.tensor(exact_initial, dtype=torch.float32)[None, None, :, :]
+     Re_in = torch.tensor([[Re]], dtype=torch.float32)
+     tau_in = torch.tensor([[tau]], dtype=torch.float32)
+
+     with torch.no_grad():
+         x_hat, x_hat_tau, *_ = adv_dif_model(x_in, tau_in, Re_in)
+
+     pred = x_hat_tau.squeeze().numpy()
+     if pred.shape != exact_final.shape:
+         return None
+
+     mse = np.square(pred - exact_final)
+
+     fig, axs = plt.subplots(1, 3, figsize=(15, 4))
+
+     for ax, (data, title) in zip(axs, [(pred, "Model Prediction"),
+                                        (exact_final, "Exact Solution"),
+                                        (mse, "MSE Error")]):
+         if title == "MSE Error":
+             im = ax.imshow(data, cmap="viridis", vmin=0, vmax=1e-2)
+             plt.colorbar(im, ax=ax, fraction=0.075)
+         else:
+             im = ax.imshow(data, cmap="jet")
+
+         ax.set_title(title)
+         ax.axis("off")
+
+     plt.tight_layout()
+     plt.close(fig)
+     return fig
+
+ def update_initial_plot(Re, t_0):
+     exact_initial = exact_solution_adv_dif(Re, t_0)
+     fig, ax = plt.subplots(figsize=(5, 5))
+     im = ax.imshow(exact_initial, cmap='jet')
+     plt.colorbar(im, ax=ax)
+     ax.set_title('Initial State')
+     return fig
+
+ # ========== Gradio Interface ==========
+ with gr.Blocks(title="Flexi-Propagator: PDE Prediction Suite") as app:
+     gr.Markdown("# Flexi-Propagator: Unified PDE Prediction Interface")
+
+     with gr.Tabs():
+         # 1D Burgers' Equation Tab
+         with gr.Tab("1D Burgers' Equation"):
+             gr.Markdown(r"""
+             ## 🚀 Flexi-Propagator: Single-Shot Prediction for Nonlinear PDEs
+
+             **Governing Equation (1D Burgers' Equation):**
+             $$
+             \frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} = \nu \frac{\partial^2 u}{\partial x^2}
+             $$
+
+             **Key Advantages:**
+             ✔️ **60-150× faster** than AE-LSTM baselines
+             ✔️ **Parametric control**: embeds system parameters in the latent space
+
+             **Physically Interpretable Latent Space - Disentanglement:**
+             <div align="left">
+             $$
+             Z_1 \text{ encodes peak location, } Z_2 \text{ predominantly encodes Re (sharpness)}
+             $$
+             </div>
+             """)
+
+             with gr.Row():
+                 with gr.Column():
+                     re_burgers = gr.Slider(425, 2350, 1040, label="Reynolds Number")
+                     tau_burgers = gr.Slider(150, 450, 315, label="Time Steps (τ)")
+                     t0_burgers = gr.Number(0.4, label="Initial Time")
+                     latent_plot = gr.Plot(label="Latent Space Dynamics")
+                 with gr.Column():
+                     burgers_plot = gr.Plot()
+                     time_plot = gr.Plot()
+                     ratio_out = gr.Number(label="Time Ratio (AE LSTM / Flexi-Prop)")
+
+             # with gr.Row():
+             #     latent_plot = gr.Plot(label="Latent Space Dynamics")
+
+             re_burgers.change(burgers_update, [re_burgers, tau_burgers, t0_burgers],
+                               [burgers_plot, time_plot, ratio_out, latent_plot])
+             tau_burgers.change(burgers_update, [re_burgers, tau_burgers, t0_burgers],
+                                [burgers_plot, time_plot, ratio_out, latent_plot])
+             t0_burgers.change(burgers_update, [re_burgers, tau_burgers, t0_burgers],
+                               [burgers_plot, time_plot, ratio_out, latent_plot])
+
+         # 2D Advection-Diffusion Tab
+         with gr.Tab("2D Advection-Diffusion"):
+             gr.Markdown(r"""
+             ## 🌪️ 2D Advection-Diffusion Visualization
+
+             **Governing Equation:**
+             $$
+             \frac{\partial u}{\partial t} + c \frac{\partial u}{\partial x} = \nu \left( \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} \right)
+             $$
+             """)
+
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     re_adv = gr.Slider(1, 10, 9, label="Reynolds Number (Re)")
+                     t0_adv = gr.Number(0.45, label="Initial Time")
+                     tau_adv = gr.Slider(150, 425, 225, label="Tau (τ)")
+                     initial_plot_adv = gr.Plot(label="Initial State")
+
+                 with gr.Column(scale=3):
+                     with gr.Row():
+                         three_d_plot_adv = gr.Plot(label="3D Evolution")
+                     with gr.Row():
+                         comparison_plots_adv = gr.Plot(label="Model Comparison")
+
+             def adv_update(Re, t0, tau):
+                 return (
+                     generate_3d_visualization(Re, t0, tau),
+                     adv_dif_comparison(Re, t0, tau),
+                     update_initial_plot(Re, t0)
+                 )
+
+             for component in [re_adv, t0_adv, tau_adv]:
+                 component.change(adv_update, [re_adv, t0_adv, tau_adv],
+                                  [three_d_plot_adv, comparison_plots_adv, initial_plot_adv])
+
+             app.load(lambda: adv_update(8, 0.35, 225),
+                      outputs=[three_d_plot_adv, comparison_plots_adv, initial_plot_adv])
+
+ app.launch()
+
+
+ # In[ ]:
+
+
+
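
A note on the timing comparison in burgers_update above: time.time() is a wall-clock timer with coarse resolution on some platforms, so the single Flexi-Prop call can read as nearly zero and skew the reported ratio. A minimal sketch of the same measurement with the standard library's monotonic high-resolution timer; the timed() helper and the commented usage are illustrative assumptions, not part of this commit:

    import time

    def timed(fn, *args, **kwargs):
        # time.perf_counter() is monotonic and high-resolution,
        # unlike time.time(), which can jump with the system clock.
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        return result, time.perf_counter() - start

    # Hypothetical usage against the models defined in app.py:
    # _, flexi_time = timed(flexi_prop_model, xt, tau_tensor, Re_tensor)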