Khalid Rafiq commited on
Commit
0710b67
·
1 Parent(s): ab72d17

Rename Gradio_Overall.ipynb to app.ipynb

Browse files
Files changed (1) hide show
  1. app.ipynb +434 -0
app.ipynb ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 2,
6
+ "id": "bc0b9235-53d3-49f1-a297-e404370cd5d9",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "* Running on local URL: http://127.0.0.1:7884\n",
14
+ "\n",
15
+ "To create a public link, set `share=True` in `launch()`.\n"
16
+ ]
17
+ },
18
+ {
19
+ "data": {
20
+ "text/html": [
21
+ "<div><iframe src=\"http://127.0.0.1:7884/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
22
+ ],
23
+ "text/plain": [
24
+ "<IPython.core.display.HTML object>"
25
+ ]
26
+ },
27
+ "metadata": {},
28
+ "output_type": "display_data"
29
+ },
30
+ {
31
+ "data": {
32
+ "text/plain": []
33
+ },
34
+ "execution_count": 2,
35
+ "metadata": {},
36
+ "output_type": "execute_result"
37
+ }
38
+ ],
39
+ "source": [
40
+ "import time\n",
41
+ "import torch\n",
42
+ "import warnings\n",
43
+ "import numpy as np\n",
44
+ "import gradio as gr\n",
45
+ "import matplotlib.pyplot as plt\n",
46
+ "\n",
47
+ "# Import Burgers' equation components\n",
48
+ "from data_burgers import exact_solution as exact_solution_burgers\n",
49
+ "from model_io_burgers import load_model\n",
50
+ "from model_v2 import Encoder, Decoder, Propagator_concat as Propagator, Model\n",
51
+ "from LSTM_model import AE_Encoder, AE_Decoder, AE_Model, PytorchLSTM\n",
52
+ "\n",
53
+ "# Import Advection-Diffusion components\n",
54
+ "from data_adv_dif import exact_solution as exact_solution_adv_dif\n",
55
+ "from model_io_adv_dif import load_model as load_model_adv_dif\n",
56
+ "from model_adv_dif import Encoder as Encoder2D, Decoder as Decoder2D, Propagator_concat as Propagator2D, Model as Model2D\n",
57
+ "\n",
58
+ "warnings.filterwarnings(\"ignore\")\n",
59
+ "\n",
60
+ "# ========== Burgers' Equation Setup ==========\n",
61
+ "def get_burgers_model(input_dim, latent_dim):\n",
62
+ " encoder = Encoder(input_dim, latent_dim)\n",
63
+ " decoder = Decoder(latent_dim, input_dim)\n",
64
+ " propagator = Propagator(latent_dim)\n",
65
+ " return Model(encoder, decoder, propagator)\n",
66
+ "\n",
67
+ "flexi_prop_model = get_burgers_model(128, 2)\n",
68
+ "checkpoint = torch.load(\"../1d_viscous_burgers/FlexiPropagator_2025-02-01-10-28-34_3e9656b5_best.pt\", map_location='cpu')\n",
69
+ "flexi_prop_model.load_state_dict(checkpoint['model_state_dict'])\n",
70
+ "flexi_prop_model.eval()\n",
71
+ "\n",
72
+ "# AE LSTM models\n",
73
+ "ae_encoder = AE_Encoder(128)\n",
74
+ "ae_decoder = AE_Decoder(2, 128)\n",
75
+ "ae_model = AE_Model(ae_encoder, ae_decoder)\n",
76
+ "lstm_model = PytorchLSTM()\n",
77
+ "\n",
78
+ "ae_encoder.load_state_dict(torch.load(\"../1d_viscous_burgers/LSTM_model/ae_encoder_weights.pth\", map_location='cpu'))\n",
79
+ "ae_decoder.load_state_dict(torch.load(\"../1d_viscous_burgers/LSTM_model/ae_decoder_weights.pth\", map_location='cpu'))\n",
80
+ "ae_model.load_state_dict(torch.load(\"../1d_viscous_burgers/LSTM_model/ae_model.pth\", map_location='cpu'))\n",
81
+ "lstm_model.load_state_dict(torch.load(\"../1d_viscous_burgers/LSTM_model/lstm_weights.pth\", map_location='cpu'))\n",
82
+ "\n",
83
+ "# ========== Helper Functions Burgers ==========\n",
84
+ "def exacts_equals_timewindow(t_0, Re, time_window=40):\n",
85
+ " dt = 2 / 500\n",
86
+ " solutions = [exact_solution_burgers(Re, t) for t in (t_0 + np.arange(0, time_window) * dt)]\n",
87
+ " solns = torch.tensor(solutions, dtype=torch.float32)[None, :, :]\n",
88
+ " latents = ae_encoder(solns)\n",
89
+ " re_normalized = Re / 1000\n",
90
+ " re_repeated = torch.ones(1, time_window, 1) * re_normalized\n",
91
+ " return torch.cat((latents, re_repeated), dim=2), latents, solns\n",
92
+ "\n",
93
+ "# Precompute contour plots\n",
94
+ "z1_vals = np.linspace(-10, 0.5, 200)\n",
95
+ "z2_vals = np.linspace(5, 32, 200)\n",
96
+ "Z1, Z2 = np.meshgrid(z1_vals, z2_vals)\n",
97
+ "latent_grid = np.stack([Z1.ravel(), Z2.ravel()], axis=1)\n",
98
+ "\n",
99
+ "# Convert to tensor for decoding\n",
100
+ "latent_tensors = torch.tensor(latent_grid, dtype=torch.float32)\n",
101
+ "\n",
102
+ "# Decode latent vectors and compute properties\n",
103
+ "with torch.no_grad():\n",
104
+ " decoded_signals = flexi_prop_model.decoder(latent_tensors)\n",
105
+ "\n",
106
+ "sharpness = []\n",
107
+ "peak_positions = []\n",
108
+ "x_vals = np.linspace(0, 2, decoded_signals.shape[1])\n",
109
+ "dx = x_vals[1] - x_vals[0]\n",
110
+ "\n",
111
+ "for signal in decoded_signals.numpy():\n",
112
+ " grad_u = np.gradient(signal, dx)\n",
113
+ " sharpness.append(np.max(np.abs(grad_u)))\n",
114
+ " peak_positions.append(x_vals[np.argmax(signal)])\n",
115
+ "\n",
116
+ "sharpness = np.array(sharpness).reshape(Z1.shape)\n",
117
+ "peak_positions = np.array(peak_positions).reshape(Z1.shape)\n",
118
+ "\n",
119
+ "def plot_burgers_comparison(Re, tau, t_0):\n",
120
+ " dt = 2.0 / 500.0\n",
121
+ " t_final = t_0 + tau * dt\n",
122
+ " x_exact = exact_solution_burgers(Re, t_final)\n",
123
+ " \n",
124
+ " tau_tensor, Re_tensor, xt = torch.tensor([tau]).float()[:, None], torch.tensor([Re]).float()[:, None], torch.tensor([exact_solution_burgers(Re, t_0)]).float()[:, None]\n",
125
+ "\n",
126
+ " with torch.no_grad():\n",
127
+ " _, x_hat_tau, *_ = flexi_prop_model(xt, tau_tensor, Re_tensor)\n",
128
+ "\n",
129
+ " latent_for_lstm, *_ = exacts_equals_timewindow(t_0, Re)\n",
130
+ " with torch.no_grad():\n",
131
+ " for _ in range(40, tau):\n",
132
+ " pred = lstm_model(latent_for_lstm)\n",
133
+ " pred_with_re = torch.cat((pred, torch.tensor([[Re / 1000]], dtype=torch.float32)), dim=1)\n",
134
+ " latent_for_lstm = torch.cat((latent_for_lstm[:, 1:, :], pred_with_re.unsqueeze(0)), dim=1)\n",
135
+ " final_pred_high_dim = ae_decoder(pred.unsqueeze(0))\n",
136
+ "\n",
137
+ " fig, ax = plt.subplots(figsize=(9, 5))\n",
138
+ " ax.plot(xt.squeeze(), '--', linewidth=3, alpha=0.5, color=\"C0\")\n",
139
+ " ax.plot(x_hat_tau.squeeze(), 'D', markersize=5, color=\"C2\")\n",
140
+ " ax.plot(final_pred_high_dim.squeeze().detach().numpy(), '^', markersize=5, color=\"C1\")\n",
141
+ " ax.plot(x_exact.squeeze(), linewidth=2, alpha=0.5, color=\"Black\")\n",
142
+ " ax.set_title(f\"Comparison ($t_0$={t_0:.2f} → $t_f$={t_final:.2f}), τ={tau}\", fontsize=14)\n",
143
+ " ax.legend([\"Initial\", \"Flexi-Prop\", \"AE LSTM\", \"True\"])\n",
144
+ " return fig\n",
145
+ "\n",
146
+ "def burgers_update(Re, tau, t0):\n",
147
+ " fig1 = plot_burgers_comparison(Re, tau, t0)\n",
148
+ "\n",
149
+ " # Timing calculations\n",
150
+ " start = time.time()\n",
151
+ " _ = flexi_prop_model(torch.randn(1, 1, 128), torch.tensor([[tau]]), torch.tensor([[Re]]))\n",
152
+ " flexi_time = time.time() - start\n",
153
+ "\n",
154
+ " start = time.time()\n",
155
+ " latent_for_lstm, _, _ = exacts_equals_timewindow(t0, Re, 40)\n",
156
+ " encode_time = time.time() - start\n",
157
+ "\n",
158
+ " start = time.time()\n",
159
+ " with torch.no_grad():\n",
160
+ " for _ in range(40, tau):\n",
161
+ " pred = lstm_model(latent_for_lstm)\n",
162
+ " pred_with_re = torch.cat((pred, torch.tensor([[Re / 1000]], dtype=torch.float32)), dim=1)\n",
163
+ " latent_for_lstm = torch.cat((latent_for_lstm[:, 1:, :], pred_with_re.unsqueeze(0)), dim=1)\n",
164
+ " recursion_time = time.time() - start\n",
165
+ "\n",
166
+ " start = time.time()\n",
167
+ " final_pred_high_dim = ae_decoder(pred.unsqueeze(0))\n",
168
+ " decode_time = time.time() - start\n",
169
+ "\n",
170
+ " ae_lstm_total_time = encode_time + recursion_time + decode_time\n",
171
+ " time_ratio = ae_lstm_total_time / flexi_time\n",
172
+ "\n",
173
+ " # Time plot\n",
174
+ " fig, ax = plt.subplots(figsize=(11, 6))\n",
175
+ " ax.bar([\"Flexi-Prop\", \"AE LSTM (Encode)\", \"AE LSTM (Recursion)\", \"AE LSTM (Decode)\", \"AE LSTM (Total)\"],\n",
176
+ " [flexi_time, encode_time, recursion_time, decode_time, ae_lstm_total_time], \n",
177
+ " color=[\"C0\", \"C1\", \"C2\", \"C3\", \"C4\"])\n",
178
+ " ax.set_ylabel(\"Time (s)\", fontsize=14)\n",
179
+ " ax.set_title(\"Computation Time Comparison\", fontsize=14)\n",
180
+ " ax.grid(alpha=0.3)\n",
181
+ "\n",
182
+ " # Latent space visualization\n",
183
+ " latent_fig = plot_latent_interpretation(Re, tau, t0)\n",
184
+ "\n",
185
+ " return fig1, fig, time_ratio, latent_fig\n",
186
+ "\n",
187
+ "def plot_latent_interpretation(Re, tau, t_0):\n",
188
+ " tau_tensor = torch.tensor([tau]).float()[:, None]\n",
189
+ " Re_tensor = torch.tensor([Re]).float()[:, None]\n",
190
+ " x_t = exact_solution_burgers(Re, t_0)\n",
191
+ " xt = torch.tensor([x_t]).float()[:, None]\n",
192
+ "\n",
193
+ " with torch.no_grad():\n",
194
+ " _, _, _, _, z_tau = flexi_prop_model(xt, tau_tensor, Re_tensor)\n",
195
+ " \n",
196
+ " z_tau = z_tau.squeeze().numpy()\n",
197
+ "\n",
198
+ " fig, axes = plt.subplots(1, 2, figsize=(9, 3))\n",
199
+ "\n",
200
+ " # Sharpness Plot\n",
201
+ " c1 = axes[0].pcolormesh(Z1, Z2, sharpness, cmap='plasma', shading='gouraud')\n",
202
+ " axes[0].scatter(z_tau[0], z_tau[1], color='red', marker='o', s=50, label=\"Current State\")\n",
203
+ " axes[0].set_ylabel(\"$Z_2$\", fontsize=14)\n",
204
+ " axes[0].set_title(\"Sharpness Encoding\", fontsize=14)\n",
205
+ " fig.colorbar(c1, ax=axes[0])\n",
206
+ " axes[0].legend()\n",
207
+ "\n",
208
+ " # Peak Position Plot\n",
209
+ " c2 = axes[1].pcolormesh(Z1, Z2, peak_positions, cmap='viridis', shading='gouraud')\n",
210
+ " axes[1].scatter(z_tau[0], z_tau[1], color='red', marker='o', s=50, label=\"Current State\")\n",
211
+ " axes[1].set_title(\"Peak position Encoding\", fontsize=14)\n",
212
+ " fig.colorbar(c2, ax=axes[1], label=\"Peak Position\")\n",
213
+ " \n",
214
+ " # Remove redundant y-axis labels on the second plot for better aesthetics\n",
215
+ " axes[1].set_yticklabels([])\n",
216
+ "\n",
217
+ " # Set a single x-axis label centered below both plots\n",
218
+ " fig.supxlabel(\"$Z_1$\", fontsize=14)\n",
219
+ "\n",
220
+ " return fig\n",
221
+ "\n",
222
+ "# ========== Advection-Diffusion Setup ==========\n",
223
+ "def get_adv_dif_model(latent_dim, output_dim):\n",
224
+ " encoder = Encoder2D(latent_dim)\n",
225
+ " decoder = Decoder2D(latent_dim)\n",
226
+ " propagator = Propagator2D(latent_dim)\n",
227
+ " return Model2D(encoder, decoder, propagator)\n",
228
+ "\n",
229
+ "adv_dif_model = get_adv_dif_model(3, 128)\n",
230
+ "adv_dif_model, _, _, _ = load_model_adv_dif(\n",
231
+ " \"../2D_adv_dif/FlexiPropagator_2D_2025-01-30-12-11-01_0aee8fb0_best.pt\", \n",
232
+ " adv_dif_model\n",
233
+ ")\n",
234
+ "\n",
235
+ "def generate_3d_visualization(Re, t_0, tau):\n",
236
+ " dt = 2 / 500\n",
237
+ " t = t_0 + tau * dt\n",
238
+ "\n",
239
+ " U_initial = exact_solution_adv_dif(Re, t_0)\n",
240
+ " U_evolved = exact_solution_adv_dif(Re, t)\n",
241
+ "\n",
242
+ " if np.isnan(U_initial).any() or np.isnan(U_evolved).any():\n",
243
+ " return None\n",
244
+ "\n",
245
+ " fig3d = plt.figure(figsize=(12, 5))\n",
246
+ " ax3d = fig3d.add_subplot(111, projection='3d')\n",
247
+ "\n",
248
+ " x_vals = np.linspace(-2, 2, U_initial.shape[1])\n",
249
+ " y_vals = np.linspace(-2, 2, U_initial.shape[0])\n",
250
+ " X, Y = np.meshgrid(x_vals, y_vals)\n",
251
+ "\n",
252
+ " surf1 = ax3d.plot_surface(X, Y, U_initial, cmap=\"viridis\", alpha=0.6, label=\"Initial\")\n",
253
+ " surf2 = ax3d.plot_surface(X, Y, U_evolved, cmap=\"plasma\", alpha=0.8, label=\"Evolved\")\n",
254
+ "\n",
255
+ " ax3d.set_xlim(-3, 3)\n",
256
+ " ax3d.set_xlabel(\"x\")\n",
257
+ " ax3d.set_ylabel(\"y\")\n",
258
+ " ax3d.set_zlabel(\"u(x,y,t)\")\n",
259
+ " ax3d.view_init(elev=25, azim=-45)\n",
260
+ " ax3d.set_box_aspect((2,1,1))\n",
261
+ "\n",
262
+ " fig3d.colorbar(surf1, ax=ax3d, shrink=0.5, label=\"Initial\")\n",
263
+ " fig3d.colorbar(surf2, ax=ax3d, shrink=0.5, label=\"Evolved\")\n",
264
+ " ax3d.set_title(f\"Solution Evolution\\nInitial ($t_0$={t_0:.2f}) vs Evolved ($t_f$={t:.2f})\")\n",
265
+ "\n",
266
+ " plt.tight_layout()\n",
267
+ " plt.close(fig3d)\n",
268
+ " return fig3d\n",
269
+ "\n",
270
+ "def adv_dif_comparison(Re, t_0, tau):\n",
271
+ " dt = 2 / 500\n",
272
+ " exact_initial = exact_solution_adv_dif(Re, t_0)\n",
273
+ " exact_final = exact_solution_adv_dif(Re, t_0 + tau * dt)\n",
274
+ "\n",
275
+ " if np.isnan(exact_initial).any() or np.isnan(exact_final).any():\n",
276
+ " return None\n",
277
+ "\n",
278
+ " x_in = torch.tensor(exact_initial, dtype=torch.float32)[None, None, :, :]\n",
279
+ " Re_in = torch.tensor([[Re]], dtype=torch.float32)\n",
280
+ " tau_in = torch.tensor([[tau]], dtype=torch.float32)\n",
281
+ "\n",
282
+ " with torch.no_grad():\n",
283
+ " x_hat, x_hat_tau, *_ = adv_dif_model(x_in, tau_in, Re_in)\n",
284
+ "\n",
285
+ " pred = x_hat_tau.squeeze().numpy()\n",
286
+ " if pred.shape != exact_final.shape:\n",
287
+ " return None\n",
288
+ "\n",
289
+ " mse = np.square(pred - exact_final)\n",
290
+ "\n",
291
+ " fig, axs = plt.subplots(1, 3, figsize=(15, 4))\n",
292
+ "\n",
293
+ " for ax, (data, title) in zip(axs, [(pred, \"Model Prediction\"),\n",
294
+ " (exact_final, \"Exact Solution\"),\n",
295
+ " (mse, \"MSE Error\")]):\n",
296
+ " if title == \"MSE Error\":\n",
297
+ " im = ax.imshow(data, cmap=\"viridis\", vmin=0, vmax=1e-2)\n",
298
+ " plt.colorbar(im, ax=ax, fraction=0.075)\n",
299
+ " else:\n",
300
+ " im = ax.imshow(data, cmap=\"jet\")\n",
301
+ "\n",
302
+ " ax.set_title(title)\n",
303
+ " ax.axis(\"off\")\n",
304
+ "\n",
305
+ " plt.tight_layout()\n",
306
+ " plt.close(fig)\n",
307
+ " return fig\n",
308
+ "\n",
309
+ "def update_initial_plot(Re, t_0):\n",
310
+ " exact_initial = exact_solution_adv_dif(Re, t_0)\n",
311
+ " fig, ax = plt.subplots(figsize=(5, 5))\n",
312
+ " im = ax.imshow(exact_initial, cmap='jet')\n",
313
+ " plt.colorbar(im, ax=ax)\n",
314
+ " ax.set_title('Initial State')\n",
315
+ " return fig\n",
316
+ "\n",
317
+ "# ========== Gradio Interface ==========\n",
318
+ "with gr.Blocks(title=\"Flexi-Propagator: PDE Prediction Suite\") as app:\n",
319
+ " gr.Markdown(\"# Flexi-Propagator: Unified PDE Prediction Interface\")\n",
320
+ "\n",
321
+ " with gr.Tabs():\n",
322
+ " # 1D Burgers' Equation Tab\n",
323
+ " with gr.Tab(\"1D Burgers' Equation\"):\n",
324
+ " gr.Markdown(r\"\"\"\n",
325
+ " ## 🚀 Flexi-Propagator: Single-Shot Prediction for Nonlinear PDEs\n",
326
+ " **Governing Equation (1D Burgers' Equation):**\n",
327
+ " $$\n",
328
+ " \\frac{\\partial u}{\\partial t} + u \\frac{\\partial u}{\\partial x} = \\nu \\frac{\\partial^2 u}{\\partial x^2}\n",
329
+ " $$\n",
330
+ " **Key Advantages:** \n",
331
+ " ✔️ **60-150× faster** than AE-LSTM baselines \n",
332
+ " ✔️ **Parametric control**: Embeds system parameters in latent space \n",
333
+ " \n",
334
+ " **Physically Interpretable Latent Space - Disentanglement:** \n",
335
+ " <div align=\"left\">\n",
336
+ " $$\n",
337
+ " Z_1 \\text{ Encodes Peak Location, } Z_2 \\text{ Predominantly Encodes Re (Sharpness)}\n",
338
+ " $$\n",
339
+ " </div>\n",
340
+ "\n",
341
+ " \"\"\")\n",
342
+ " \n",
343
+ " with gr.Row():\n",
344
+ " with gr.Column():\n",
345
+ " re_burgers = gr.Slider(425, 2350, 1040, label=\"Reynolds Number\")\n",
346
+ " tau_burgers = gr.Slider(150, 450, 315, label=\"Time Steps (τ)\")\n",
347
+ " t0_burgers = gr.Number(0.4, label=\"Initial Time\")\n",
348
+ " latent_plot = gr.Plot(label=\"Latent Space Dynamics\")\n",
349
+ " with gr.Column():\n",
350
+ " burgers_plot = gr.Plot()\n",
351
+ " time_plot = gr.Plot()\n",
352
+ " ratio_out = gr.Number(label=\"Time Ratio (Flexi Prop/AE LSTM)\")\n",
353
+ " \n",
354
+ " # with gr.Row():\n",
355
+ " # latent_plot = gr.Plot(label=\"Latent Space Dynamics\")\n",
356
+ "\n",
357
+ " re_burgers.change(burgers_update, [re_burgers, tau_burgers, t0_burgers], \n",
358
+ " [burgers_plot, time_plot, ratio_out, latent_plot])\n",
359
+ " tau_burgers.change(burgers_update, [re_burgers, tau_burgers, t0_burgers], \n",
360
+ " [burgers_plot, time_plot, ratio_out, latent_plot])\n",
361
+ " t0_burgers.change(burgers_update, [re_burgers, tau_burgers, t0_burgers], \n",
362
+ " [burgers_plot, time_plot, ratio_out, latent_plot])\n",
363
+ "\n",
364
+ " # 2D Advection-Diffusion Tab\n",
365
+ " with gr.Tab(\"2D Advection-Diffusion\"):\n",
366
+ " gr.Markdown(r\"\"\"\n",
367
+ " ## 🌪️ 2D Advection-Diffusion Visualization\n",
368
+ " **Governing Equation:**\n",
369
+ " $$\n",
370
+ " \\frac{\\partial u}{\\partial t} + c \\frac{\\partial u}{\\partial x} = \\nu \\left( \\frac{\\partial^2 u}{\\partial x^2} + \\frac{\\partial^2 u}{\\partial y^2} \\right)\n",
371
+ " $$\n",
372
+ " \"\"\")\n",
373
+ " \n",
374
+ " with gr.Row():\n",
375
+ " with gr.Column(scale=1):\n",
376
+ " re_adv = gr.Slider(1, 10, 9, label=\"Reynolds Number (Re)\")\n",
377
+ " t0_adv = gr.Number(0.45, label=\"Initial Time\")\n",
378
+ " tau_adv = gr.Slider(150, 425, 225, label=\"Tau (τ)\")\n",
379
+ " initial_plot_adv = gr.Plot(label=\"Initial State\")\n",
380
+ " \n",
381
+ " with gr.Column(scale=3):\n",
382
+ " with gr.Row():\n",
383
+ " three_d_plot_adv = gr.Plot(label=\"3D Evolution\")\n",
384
+ " with gr.Row():\n",
385
+ " comparison_plots_adv = gr.Plot(label=\"Model Comparison\")\n",
386
+ "\n",
387
+ " def adv_update(Re, t0, tau):\n",
388
+ " return (\n",
389
+ " generate_3d_visualization(Re, t0, tau),\n",
390
+ " adv_dif_comparison(Re, t0, tau),\n",
391
+ " update_initial_plot(Re, t0)\n",
392
+ " )\n",
393
+ "\n",
394
+ " for component in [re_adv, t0_adv, tau_adv]:\n",
395
+ " component.change(adv_update, [re_adv, t0_adv, tau_adv], \n",
396
+ " [three_d_plot_adv, comparison_plots_adv, initial_plot_adv])\n",
397
+ "\n",
398
+ " app.load(lambda: adv_update(8, 0.35, 225), \n",
399
+ " outputs=[three_d_plot_adv, comparison_plots_adv, initial_plot_adv])\n",
400
+ "\n",
401
+ "app.launch()"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "code",
406
+ "execution_count": null,
407
+ "id": "73e3f1df-972c-4966-9216-8ce7583a5e58",
408
+ "metadata": {},
409
+ "outputs": [],
410
+ "source": []
411
+ }
412
+ ],
413
+ "metadata": {
414
+ "kernelspec": {
415
+ "display_name": "Python 3 (ipykernel)",
416
+ "language": "python",
417
+ "name": "python3"
418
+ },
419
+ "language_info": {
420
+ "codemirror_mode": {
421
+ "name": "ipython",
422
+ "version": 3
423
+ },
424
+ "file_extension": ".py",
425
+ "mimetype": "text/x-python",
426
+ "name": "python",
427
+ "nbconvert_exporter": "python",
428
+ "pygments_lexer": "ipython3",
429
+ "version": "3.10.12"
430
+ }
431
+ },
432
+ "nbformat": 4,
433
+ "nbformat_minor": 5
434
+ }