{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "bc0b9235-53d3-49f1-a297-e404370cd5d9", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "* Running on local URL: http://127.0.0.1:7886\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
import time
import torch
import warnings
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt

# Import Burgers' equation components
from data_burgers import exact_solution as exact_solution_burgers
from model_io_burgers import load_model
from model_v2 import Encoder, Decoder, Propagator_concat as Propagator, Model
from LSTM_model import AE_Encoder, AE_Decoder, AE_Model, PytorchLSTM

# Import Advection-Diffusion components
from data_adv_dif import exact_solution as exact_solution_adv_dif
from model_io_adv_dif import load_model as load_model_adv_dif
from model_adv_dif import Encoder as Encoder2D, Decoder as Decoder2D, Propagator_concat as Propagator2D, Model as Model2D

warnings.filterwarnings("ignore")

# ========== Burgers' Equation Setup ==========
def get_burgers_model(input_dim, latent_dim):
    """Assemble a Flexi-Propagator model for the 1-D Burgers equation.

    Args:
        input_dim: number of spatial grid points of the discretized solution.
        latent_dim: size of the latent space.

    Returns:
        A `Model` wiring encoder -> propagator -> decoder.
    """
    encoder = Encoder(input_dim, latent_dim)
    decoder = Decoder(latent_dim, input_dim)
    propagator = Propagator(latent_dim)
    return Model(encoder, decoder, propagator)

# 128 grid points, 2 latent dimensions (matches the checkpoint below).
flexi_prop_model = get_burgers_model(128, 2)
checkpoint = torch.load("./FlexiPropagator_2025-02-01-10-28-34_3e9656b5_best.pt", map_location='cpu')
flexi_prop_model.load_state_dict(checkpoint['model_state_dict'])
flexi_prop_model.eval()

# AE LSTM baseline models
ae_encoder = AE_Encoder(128)
ae_decoder = AE_Decoder(2, 128)
ae_model = AE_Model(ae_encoder, ae_decoder)
lstm_model = PytorchLSTM()

ae_encoder.load_state_dict(torch.load("./ae_encoder_weights.pth", map_location='cpu'))
ae_decoder.load_state_dict(torch.load("./ae_decoder_weights.pth", map_location='cpu'))
# Fix: these models are used only for inference, but the original code put
# only flexi_prop_model into eval mode — dropout/batch-norm layers (if any)
# would otherwise keep running in training mode during prediction and timing.
ae_encoder.eval()
ae_decoder.eval()
ae_model.load_state_dict(torch.load("./ae_model.pth", map_location='cpu'))
lstm_model.load_state_dict(torch.load("./lstm_weights.pth", map_location='cpu'))
# Fix: inference-only models belong in eval mode (the original set it only
# on flexi_prop_model).
ae_model.eval()
lstm_model.eval()

# ========== Helper Functions Burgers ==========
def exacts_equals_timewindow(t_0, Re, time_window=40):
    """Encode `time_window` consecutive exact Burgers snapshots from t_0.

    Args:
        t_0: initial time of the window.
        Re: Reynolds number of the exact solution.
        time_window: number of consecutive snapshots (LSTM context length).

    Returns:
        Tuple (latents_with_re, latents, snapshots): the latent window with
        the normalized Reynolds number (Re / 1000, the training-time scaling)
        appended as an extra channel, the raw latents, and the snapshots.
    """
    dt = 2 / 500  # time-step size used throughout the demo
    solutions = [exact_solution_burgers(Re, t) for t in (t_0 + np.arange(0, time_window) * dt)]
    solns = torch.tensor(solutions, dtype=torch.float32)[None, :, :]
    latents = ae_encoder(solns)
    re_normalized = Re / 1000
    re_repeated = torch.ones(1, time_window, 1) * re_normalized
    return torch.cat((latents, re_repeated), dim=2), latents, solns

def _lstm_rollout(latent_window, Re, tau):
    """Recursively advance the AE-LSTM latent window to step `tau`.

    The window covers 40 steps; each iteration predicts one step and slides
    the window forward. Performs max(tau - 40, 1) predictions so `pred` is
    always defined — the original inline loops raised NameError for
    tau <= 40 because `range(40, tau)` was empty.

    This helper also removes the copy-pasted rollout code that previously
    lived in both plot_burgers_comparison and burgers_update.

    Returns:
        The final latent prediction (without the Re channel).
    """
    re_channel = torch.tensor([[Re / 1000]], dtype=torch.float32)
    with torch.no_grad():
        pred = lstm_model(latent_window)
        pred_with_re = torch.cat((pred, re_channel), dim=1)
        latent_window = torch.cat((latent_window[:, 1:, :], pred_with_re.unsqueeze(0)), dim=1)
        for _ in range(41, tau):
            pred = lstm_model(latent_window)
            pred_with_re = torch.cat((pred, re_channel), dim=1)
            latent_window = torch.cat((latent_window[:, 1:, :], pred_with_re.unsqueeze(0)), dim=1)
    return pred

# Precompute latent-space contour maps: decode a dense grid of latent
# vectors once and derive sharpness / peak-position fields for plotting.
z1_vals = np.linspace(-10, 0.5, 200)
z2_vals = np.linspace(5, 32, 200)
Z1, Z2 = np.meshgrid(z1_vals, z2_vals)
latent_grid = np.stack([Z1.ravel(), Z2.ravel()], axis=1)

latent_tensors = torch.tensor(latent_grid, dtype=torch.float32)

with torch.no_grad():
    decoded_signals = flexi_prop_model.decoder(latent_tensors)

sharpness = []
peak_positions = []
x_vals = np.linspace(0, 2, decoded_signals.shape[1])
dx = x_vals[1] - x_vals[0]

for signal in decoded_signals.numpy():
    grad_u = np.gradient(signal, dx)
    sharpness.append(np.max(np.abs(grad_u)))       # max |du/dx| ~ shock sharpness
    peak_positions.append(x_vals[np.argmax(signal)])  # location of the wave peak

sharpness = np.array(sharpness).reshape(Z1.shape)
peak_positions = np.array(peak_positions).reshape(Z1.shape)

def plot_burgers_comparison(Re, tau, t_0):
    """Overlay initial state, Flexi-Prop one-shot prediction, AE-LSTM
    recursive prediction, and the exact solution at t_0 + tau*dt."""
    dt = 2.0 / 500.0
    t_final = t_0 + tau * dt
    x_exact = exact_solution_burgers(Re, t_final)

    tau_tensor = torch.tensor([tau]).float()[:, None]
    Re_tensor = torch.tensor([Re]).float()[:, None]
    xt = torch.tensor([exact_solution_burgers(Re, t_0)]).float()[:, None]

    # Single-shot Flexi-Prop prediction.
    with torch.no_grad():
        _, x_hat_tau, *_ = flexi_prop_model(xt, tau_tensor, Re_tensor)

    # Recursive AE-LSTM baseline prediction.
    latent_for_lstm, *_ = exacts_equals_timewindow(t_0, Re)
    pred = _lstm_rollout(latent_for_lstm, Re, tau)
    with torch.no_grad():
        final_pred_high_dim = ae_decoder(pred.unsqueeze(0))

    fig, ax = plt.subplots(figsize=(9, 5))
    ax.plot(xt.squeeze(), '--', linewidth=3, alpha=0.5, color="C0")
    ax.plot(x_hat_tau.squeeze(), 'D', markersize=5, color="C2")
    ax.plot(final_pred_high_dim.squeeze().detach().numpy(), '^', markersize=5, color="C1")
    ax.plot(x_exact.squeeze(), linewidth=2, alpha=0.5, color="Black")
    ax.set_title(f"Comparison ($t_0$={t_0:.2f} → $t_f$={t_final:.2f}), τ={tau}", fontsize=14)
    ax.legend(["Initial", "Flexi-Prop", "AE LSTM", "True"])
    plt.close(fig)  # Gradio renders the returned figure; close to avoid leaks
    return fig

def burgers_update(Re, tau, t0):
    """Gradio callback: comparison plot, timing bar chart, speed ratio and
    latent-space visualization for the current slider values."""
    fig1 = plot_burgers_comparison(Re, tau, t0)

    # Timing: Flexi-Prop single shot. Fix: run under no_grad like the
    # AE-LSTM rollout below, so the timing comparison is apples-to-apples.
    start = time.time()
    with torch.no_grad():
        _ = flexi_prop_model(torch.randn(1, 1, 128), torch.tensor([[tau]]), torch.tensor([[Re]]))
    flexi_time = time.time() - start

    # Timing: AE-LSTM pipeline, stage by stage.
    start = time.time()
    latent_for_lstm, _, _ = exacts_equals_timewindow(t0, Re, 40)
    encode_time = time.time() - start

    start = time.time()
    pred = _lstm_rollout(latent_for_lstm, Re, tau)
    recursion_time = time.time() - start

    start = time.time()
    with torch.no_grad():
        final_pred_high_dim = ae_decoder(pred.unsqueeze(0))
    decode_time = time.time() - start

    ae_lstm_total_time = encode_time + recursion_time + decode_time
    # Ratio > 1 means the AE-LSTM pipeline is slower than Flexi-Prop.
    time_ratio = ae_lstm_total_time / flexi_time

    # Time plot
    fig, ax = plt.subplots(figsize=(11, 6))
    ax.bar(["Flexi-Prop", "AE LSTM (Encode)", "AE LSTM (Recursion)", "AE LSTM (Decode)", "AE LSTM (Total)"],
           [flexi_time, encode_time, recursion_time, decode_time, ae_lstm_total_time],
           color=["C0", "C1", "C2", "C3", "C4"])
    ax.set_ylabel("Time (s)", fontsize=14)
    ax.set_title("Computation Time Comparison", fontsize=14)
    ax.grid(alpha=0.3)
    plt.close(fig)

    # Latent space visualization
    latent_fig = plot_latent_interpretation(Re, tau, t0)

    return fig1, fig, time_ratio, latent_fig

def plot_latent_interpretation(Re, tau, t_0):
    """Show the current latent state on the precomputed sharpness and
    peak-position maps of the 2-D latent space."""
    tau_tensor = torch.tensor([tau]).float()[:, None]
    Re_tensor = torch.tensor([Re]).float()[:, None]
    x_t = exact_solution_burgers(Re, t_0)
    xt = torch.tensor([x_t]).float()[:, None]

    with torch.no_grad():
        # Model returns (..., z_tau) — only the propagated latent is needed.
        _, _, _, _, z_tau = flexi_prop_model(xt, tau_tensor, Re_tensor)

    z_tau = z_tau.squeeze().numpy()

    fig, axes = plt.subplots(1, 2, figsize=(9, 3))

    # Sharpness map
    c1 = axes[0].pcolormesh(Z1, Z2, sharpness, cmap='plasma', shading='gouraud')
    axes[0].scatter(z_tau[0], z_tau[1], color='red', marker='o', s=50, label="Current State")
    axes[0].set_ylabel("$Z_2$", fontsize=14)
    axes[0].set_title("Sharpness Encoding", fontsize=14)
    fig.colorbar(c1, ax=axes[0])
    axes[0].legend()

    # Peak-position map
    c2 = axes[1].pcolormesh(Z1, Z2, peak_positions, cmap='viridis', shading='gouraud')
    axes[1].scatter(z_tau[0], z_tau[1], color='red', marker='o', s=50, label="Current State")
    axes[1].set_title("Peak position Encoding", fontsize=14)
    fig.colorbar(c2, ax=axes[1], label="Peak Position")

    # Remove redundant y-axis labels on the second plot for better aesthetics
    axes[1].set_yticklabels([])

    # Single x-axis label centered below both plots
    fig.supxlabel("$Z_1$", fontsize=14)

    plt.close(fig)
    return fig

# ========== Advection-Diffusion Setup ==========
def get_adv_dif_model(latent_dim, output_dim):
    """Assemble the 2-D advection-diffusion model.

    Note: `output_dim` is currently unused (the 2-D decoder infers its output
    shape from `latent_dim`); the parameter is kept for interface stability.
    """
    encoder = Encoder2D(latent_dim)
    decoder = Decoder2D(latent_dim)
    propagator = Propagator2D(latent_dim)
    return Model2D(encoder, decoder, propagator)

adv_dif_model = get_adv_dif_model(3, 128)
adv_dif_model, _, _, _ = load_model_adv_dif(
    "./FlexiPropagator_2D_2025-01-30-12-11-01_0aee8fb0_best.pt",
    adv_dif_model
)

def generate_3d_visualization(Re, t_0, tau):
    """3-D surface plot of the exact solution at t_0 and at t_0 + tau*dt.

    Returns None when the exact solution is undefined (NaNs) for the inputs.
    """
    dt = 2 / 500
    t = t_0 + tau * dt

    U_initial = exact_solution_adv_dif(Re, t_0)
    U_evolved = exact_solution_adv_dif(Re, t)

    if np.isnan(U_initial).any() or np.isnan(U_evolved).any():
        return None

    fig3d = plt.figure(figsize=(12, 5))
    ax3d = fig3d.add_subplot(111, projection='3d')

    x_vals = np.linspace(-2, 2, U_initial.shape[1])
    y_vals = np.linspace(-2, 2, U_initial.shape[0])
    X, Y = np.meshgrid(x_vals, y_vals)

    surf1 = ax3d.plot_surface(X, Y, U_initial, cmap="viridis", alpha=0.6, label="Initial")
    surf2 = ax3d.plot_surface(X, Y, U_evolved, cmap="plasma", alpha=0.8, label="Evolved")

    ax3d.set_xlim(-3, 3)
    ax3d.set_xlabel("x")
    ax3d.set_ylabel("y")
    ax3d.set_zlabel("u(x,y,t)")
    ax3d.view_init(elev=25, azim=-45)
    ax3d.set_box_aspect((2, 1, 1))

    fig3d.colorbar(surf1, ax=ax3d, shrink=0.5, label="Initial")
    fig3d.colorbar(surf2, ax=ax3d, shrink=0.5, label="Evolved")
    ax3d.set_title(f"Solution Evolution\nInitial ($t_0$={t_0:.2f}) vs Evolved ($t_f$={t:.2f})")

    plt.tight_layout()
    plt.close(fig3d)
    return fig3d

def adv_dif_comparison(Re, t_0, tau):
    """Side-by-side model prediction, exact solution, and pixelwise MSE.

    Returns None when the exact solution is undefined or the prediction
    shape does not match the exact field.
    """
    dt = 2 / 500
    exact_initial = exact_solution_adv_dif(Re, t_0)
    exact_final = exact_solution_adv_dif(Re, t_0 + tau * dt)

    if np.isnan(exact_initial).any() or np.isnan(exact_final).any():
        return None

    x_in = torch.tensor(exact_initial, dtype=torch.float32)[None, None, :, :]
    Re_in = torch.tensor([[Re]], dtype=torch.float32)
    tau_in = torch.tensor([[tau]], dtype=torch.float32)

    with torch.no_grad():
        x_hat, x_hat_tau, *_ = adv_dif_model(x_in, tau_in, Re_in)

    pred = x_hat_tau.squeeze().numpy()
    if pred.shape != exact_final.shape:
        return None

    mse = np.square(pred - exact_final)

    fig, axs = plt.subplots(1, 3, figsize=(15, 4))

    for ax, (data, title) in zip(axs, [(pred, "Model Prediction"),
                                       (exact_final, "Exact Solution"),
                                       (mse, "MSE Error")]):
        if title == "MSE Error":
            im = ax.imshow(data, cmap="viridis", vmin=0, vmax=1e-2)
            plt.colorbar(im, ax=ax, fraction=0.075)
        else:
            im = ax.imshow(data, cmap="jet")

        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.close(fig)
    return fig

def update_initial_plot(Re, t_0):
    """Heatmap of the exact 2-D solution at the initial time t_0."""
    exact_initial = exact_solution_adv_dif(Re, t_0)
    fig, ax = plt.subplots(figsize=(5, 5))
    im = ax.imshow(exact_initial, cmap='jet')
    plt.colorbar(im, ax=ax)
    ax.set_title('Initial State')
    plt.close(fig)  # consistent with the other plotting helpers
    return fig

# ========== Gradio Interface ==========
with gr.Blocks(title="Flexi-Propagator: PDE Prediction Suite") as app:
    gr.Markdown("# Flexi-Propagator: Unified PDE Prediction Interface")

    with gr.Tabs():
        # 1D Burgers' Equation Tab
        with gr.Tab("1D Burgers' Equation"):
            gr.Markdown(r"""
            ## 🚀 Flexi-Propagator: Single-Shot Prediction for Nonlinear PDEs
            **Governing Equation (1D Burgers' Equation):**
            $$
            \frac{\partial u}{\partial t} + u \frac{\partial u}{\partial x} = \nu \frac{\partial^2 u}{\partial x^2}
            $$
            **Key Advantages:**  
            ✔️ **60-150× faster** than AE-LSTM baselines  
            ✔️ **Parametric control**: Embeds system parameters in latent space  

            **Physically Interpretable Latent Space - Disentanglement:**  
            $$
            Z_1 \text{ Encodes Peak Location, } Z_2 \text{ Predominantly Encodes Re (Sharpness)}
            $$
            """)

            with gr.Row():
                with gr.Column():
                    re_burgers = gr.Slider(425, 2350, 1040, label="Reynolds Number")
                    tau_burgers = gr.Slider(150, 450, 315, label="Time Steps (τ)")
                    t0_burgers = gr.Number(0.4, label="Initial Time")
                    latent_plot = gr.Plot(label="Latent Space Dynamics")
                with gr.Column():
                    burgers_plot = gr.Plot()
                    time_plot = gr.Plot()
                    # Fix: the computed ratio is AE-LSTM time / Flexi-Prop
                    # time; the original label had the quotient inverted.
                    ratio_out = gr.Number(label="Time Ratio (AE LSTM/Flexi Prop)")

            # Same wiring for all three inputs — loop instead of the
            # copy-pasted .change(...) calls (matches the adv-dif tab style).
            for component in [re_burgers, tau_burgers, t0_burgers]:
                component.change(burgers_update, [re_burgers, tau_burgers, t0_burgers],
                                 [burgers_plot, time_plot, ratio_out, latent_plot])

        # 2D Advection-Diffusion Tab
        with gr.Tab("2D Advection-Diffusion"):
            gr.Markdown(r"""
            ## 🌪️ 2D Advection-Diffusion Visualization
            **Governing Equation:**
            $$
            \frac{\partial u}{\partial t} + c \frac{\partial u}{\partial x} = \nu \left( \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2} \right)
            $$
            """)

            with gr.Row():
                with gr.Column(scale=1):
                    re_adv = gr.Slider(1, 10, 9, label="Reynolds Number (Re)")
                    t0_adv = gr.Number(0.45, label="Initial Time")
                    tau_adv = gr.Slider(150, 425, 225, label="Tau (τ)")
                    initial_plot_adv = gr.Plot(label="Initial State")

                with gr.Column(scale=3):
                    with gr.Row():
                        three_d_plot_adv = gr.Plot(label="3D Evolution")
                    with gr.Row():
                        comparison_plots_adv = gr.Plot(label="Model Comparison")

            def adv_update(Re, t0, tau):
                """Callback returning (3-D plot, comparison plot, initial plot)."""
                return (
                    generate_3d_visualization(Re, t0, tau),
                    adv_dif_comparison(Re, t0, tau),
                    update_initial_plot(Re, t0)
                )

            for component in [re_adv, t0_adv, tau_adv]:
                component.change(adv_update, [re_adv, t0_adv, tau_adv],
                                 [three_d_plot_adv, comparison_plots_adv, initial_plot_adv])

            # Populate the adv-dif tab once on page load with default values.
            app.load(lambda: adv_update(8, 0.35, 225),
                     outputs=[three_d_plot_adv, comparison_plots_adv, initial_plot_adv])

app.launch()