Lennard Schober commited on
Commit
d2ec756
·
1 Parent(s): 59cb565

Init commit

Browse files
Files changed (2) hide show
  1. app.py +479 -0
  2. npz/.DS_Store +0 -0
app.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import os
4
+ import time
5
+ import plotly.graph_objs as go
6
+ import matplotlib.pyplot as plt
7
+ import shutil
8
+ from colorama import Fore
9
+
10
+ # Path to the npz folder
11
+ npz_folder = "npz"
12
+
13
+ glob_a = -2
14
+ glob_b = -2
15
+ glob_c = -4
16
+ glob_d = 7
17
+
18
+
19
+ def clear_folder(folder_path=npz_folder):
20
+ for filename in os.listdir(folder_path):
21
+ file_path = os.path.join(folder_path, filename)
22
+ try:
23
+ if os.path.isfile(file_path) or os.path.islink(file_path):
24
+ os.unlink(file_path) # Remove the file or symbolic link
25
+ elif os.path.isdir(file_path):
26
+ shutil.rmtree(file_path) # Remove the directory and its contents
27
+ except Exception as e:
28
+ print(f"Failed to delete {file_path}. Reason: {e}")
29
+
30
+
31
+ def complex_heat_eq_solution(x, t, a=glob_a, b=glob_b, c=glob_c, d=glob_d, k=0.5):
32
+ global glob_a, glob_b, glob_c, glob_d
33
+ return (
34
+ np.exp(-k * t) * np.sin(np.pi * x)
35
+ + 0.5 * np.exp(glob_a * k * t) * np.sin(glob_b * np.pi * x)
36
+ + 0.25 * np.exp(glob_c * k * t) * np.sin(glob_d * np.pi * x)
37
+ )
38
+
39
+
40
+ def plot_heat_equation(m, approx_type):
41
+ # Define grid dimensions
42
+ n_x = 32 # Fixed spatial grid resolution
43
+ n_t = 50
44
+
45
+ try:
46
+ loaded_values = np.load(f"npz/{approx_type}_m{m}.npz")
47
+ except:
48
+ raise gr.Error(f"First train the coefficients for {approx_type} and m = {m}")
49
+ alpha = loaded_values["alpha"]
50
+ Phi = loaded_values["Phi"]
51
+
52
+ # Create grids for x and t
53
+ x = np.linspace(0, 1, n_x) # Spatial grid
54
+ t = np.linspace(0, 5, n_t) # Temporal grid
55
+ X, T = np.meshgrid(x, t)
56
+
57
+ # Compute the real solution over the grid
58
+ U_real = complex_heat_eq_solution(X, T)
59
+
60
+ # Compute the selected approximation
61
+ U_approx = np.zeros_like(U_real)
62
+ for i, t_val in enumerate(t):
63
+ Phi_gff_at_t = Phi[i * n_x : (i + 1) * n_x]
64
+ U_approx[i, :] = np.dot(Phi_gff_at_t, alpha)
65
+
66
+ # Create the 3D plot with Plotly
67
+ traces = []
68
+
69
+ # Real solution surface with a distinct color (e.g., 'Viridis')
70
+ traces.append(
71
+ go.Surface(
72
+ z=U_real,
73
+ x=X,
74
+ y=T,
75
+ colorscale="Blues",
76
+ showscale=False,
77
+ name="Real Solution",
78
+ showlegend=True,
79
+ )
80
+ )
81
+
82
+ # Approximation surface with a distinct color (e.g., 'Plasma')
83
+ traces.append(
84
+ go.Surface(
85
+ z=U_approx,
86
+ x=X,
87
+ y=T,
88
+ colorscale="Reds",
89
+ reversescale=True,
90
+ showscale=False,
91
+ name=f"{approx_type} Approximation",
92
+ showlegend=True,
93
+ )
94
+ )
95
+
96
+ # Layout for the Plotly plot without controls
97
+ layout = go.Layout(
98
+ title=f"Heat Equation Approximation | Kernel = {approx_type} | m = {m}",
99
+ scene=dict(
100
+ camera=dict(
101
+ eye=dict(x=0, y=-2, z=0), # Front view
102
+ ),
103
+ xaxis_title="x",
104
+ yaxis_title="t",
105
+ zaxis_title="u",
106
+ ),
107
+ )
108
+
109
+ # Config to remove modebar buttons except the save image button
110
+ config = {
111
+ "modeBarButtonsToRemove": [
112
+ "pan",
113
+ "resetCameraLastSave",
114
+ "hoverClosest3d",
115
+ "hoverCompareCartesian",
116
+ "zoomIn",
117
+ "zoomOut",
118
+ "select2d",
119
+ "lasso2d",
120
+ "zoomIn2d",
121
+ "zoomOut2d",
122
+ "sendDataToCloud",
123
+ "zoom3d",
124
+ "orbitRotation",
125
+ "tableRotation",
126
+ ],
127
+ "displayModeBar": True, # Keep the modebar visible
128
+ "displaylogo": False, # Hide the Plotly logo
129
+ }
130
+
131
+ # Create the figure
132
+ fig = go.Figure(data=traces, layout=layout)
133
+
134
+ fig.show(config=config)
135
+
136
+
137
+ def plot_errors(m, approx_type):
138
+ # Define grid dimensions
139
+ n_x = 32 # Fixed spatial grid resolution
140
+ n_t = 50
141
+
142
+ try:
143
+ loaded_values = np.load(f"npz/{approx_type}_m{m}.npz")
144
+ except:
145
+ raise gr.Error(f"First train the coefficients for {approx_type} and m = {m}")
146
+ alpha = loaded_values["alpha"]
147
+ Phi = loaded_values["Phi"]
148
+
149
+ # Create grids for x and t
150
+ x = np.linspace(0, 1, n_x) # Spatial grid
151
+ t = np.linspace(0, 5, n_t) # Temporal grid
152
+ X, T = np.meshgrid(x, t)
153
+
154
+ # Compute the real solution over the grid
155
+ U_real = complex_heat_eq_solution(X, T)
156
+
157
+ # Compute the selected approximation
158
+ U_approx = np.zeros_like(U_real)
159
+ for i, t_val in enumerate(t):
160
+ Phi_gff_at_t = Phi[i * n_x : (i + 1) * n_x]
161
+ U_approx[i, :] = np.dot(Phi_gff_at_t, alpha)
162
+
163
+ U_err = abs(U_approx - U_real)
164
+
165
+ # Create the 3D plot with Plotly
166
+ traces = []
167
+
168
+ # Real solution surface with a distinct color (e.g., 'Viridis')
169
+ traces.append(
170
+ go.Surface(
171
+ z=U_err,
172
+ x=X,
173
+ y=T,
174
+ colorscale="Viridis",
175
+ showscale=False,
176
+ name=f"Absolute Error",
177
+ showlegend=True,
178
+ )
179
+ )
180
+
181
+ # Layout for the Plotly plot without controls
182
+ layout = go.Layout(
183
+ title=f"Heat Equation Approximation Error | Kernel = {approx_type} | m = {m}",
184
+ scene=dict(
185
+ camera=dict(
186
+ eye=dict(x=0, y=-2, z=0), # Front view
187
+ ),
188
+ xaxis_title="x",
189
+ yaxis_title="t",
190
+ zaxis_title="u",
191
+ ),
192
+ )
193
+
194
+ # Config to remove modebar buttons except the save image button
195
+ config = {
196
+ "modeBarButtonsToRemove": [
197
+ "pan",
198
+ "resetCameraLastSave",
199
+ "hoverClosest3d",
200
+ "hoverCompareCartesian",
201
+ "zoomIn",
202
+ "zoomOut",
203
+ "select2d",
204
+ "lasso2d",
205
+ "zoomIn2d",
206
+ "zoomOut2d",
207
+ "sendDataToCloud",
208
+ "zoom3d",
209
+ "orbitRotation",
210
+ "tableRotation",
211
+ ],
212
+ "displayModeBar": True, # Keep the modebar visible
213
+ "displaylogo": False, # Hide the Plotly logo
214
+ }
215
+
216
+ # Create the figure
217
+ fig = go.Figure(data=traces, layout=layout)
218
+
219
+ fig.show(config=config)
220
+
221
+
222
+ # Function to get the available .npz files in the npz folder
223
+ def get_available_approx_files():
224
+ files = os.listdir(npz_folder)
225
+ npz_files = [f for f in files if f.endswith(".npz")]
226
+ return npz_files
227
+
228
+
229
+ def generate_data(n_x=32, n_t=50):
230
+ """Generate training data."""
231
+ x = np.linspace(0, 1, n_x) # spatial points
232
+ t = np.linspace(0, 5, n_t) # temporal points
233
+ X, T = np.meshgrid(x, t)
234
+ a_train = np.c_[X.ravel(), T.ravel()] # shape (n_x * n_t, 2)
235
+ u_train = complex_heat_eq_solution(
236
+ a_train[:, 0], a_train[:, 1]
237
+ ) # shape (n_x * n_t,)
238
+ return a_train, u_train, x, t
239
+
240
+
241
+ def random_features(a, theta_j, kernel="SINE", k=0.5, t=1.0):
242
+ """Compute random features with adjustable kernel width."""
243
+ if kernel == "SINE":
244
+ return np.sin(t * np.linalg.norm(a - theta_j, axis=-1))
245
+ elif kernel == "GFF":
246
+ return np.log(np.linalg.norm(a - theta_j, axis=-1)) / (2 * np.pi)
247
+ else:
248
+ raise ValueError("Unsupported kernel type!")
249
+
250
+
251
+ def design_matrix(a, theta, kernel):
252
+ """Construct design matrix."""
253
+ return np.array([random_features(a, theta_j, kernel=kernel) for theta_j in theta]).T
254
+
255
+
256
+ def learn_coefficients(Phi, u):
257
+ """Learn coefficients alpha via least squares."""
258
+ return np.linalg.lstsq(Phi, u, rcond=None)[0]
259
+
260
+
261
+ def approximate_solution(a, alpha, theta, kernel):
262
+ """Compute the approximation."""
263
+ Phi = design_matrix(a, theta, kernel)
264
+ return Phi @ alpha
265
+
266
+
267
+ def polyfit2d(x, y, z, kx=3, ky=3, order=None):
268
+ # grid coords
269
+ x, y = np.meshgrid(x, y)
270
+ # coefficient array, up to x^kx, y^ky
271
+ coeffs = np.ones((kx + 1, ky + 1))
272
+
273
+ # solve array
274
+ a = np.zeros((coeffs.size, x.size))
275
+
276
+ # for each coefficient produce array x^i, y^j
277
+ for index, (j, i) in enumerate(np.ndindex(coeffs.shape)):
278
+ # do not include powers greater than order
279
+ if order is not None and i + j > order:
280
+ arr = np.zeros_like(x)
281
+ else:
282
+ arr = coeffs[i, j] * x**i * y**j
283
+ a[index] = arr.ravel()
284
+
285
+ # do leastsq fitting and return leastsq result
286
+ return np.linalg.lstsq(a.T, np.ravel(z), rcond=None)
287
+
288
+
289
+ def train_coefficients(m, kernel):
290
+ # Start time for training
291
+ start_time = time.time()
292
+
293
+ # Generate data
294
+ n_x, n_t = 32, 50
295
+ a_train, u_train, x, t = generate_data(n_x, n_t)
296
+
297
+ # Define random features
298
+ theta = np.column_stack(
299
+ (
300
+ np.random.uniform(-1, 1, size=m), # First dimension: [-1, 1]
301
+ np.random.uniform(-5, 5, size=m), # Second dimension: [-5, 5]
302
+ )
303
+ )
304
+
305
+ # Construct design matrix and learn coefficients
306
+ Phi = design_matrix(a_train, theta, kernel)
307
+ alpha = learn_coefficients(Phi, u_train)
308
+ # Validate and animate results
309
+ u_real = np.array([complex_heat_eq_solution(x, t_i) for t_i in t])
310
+ a_test = np.c_[np.meshgrid(x, t)[0].ravel(), np.meshgrid(x, t)[1].ravel()]
311
+ u_approx = approximate_solution(a_test, alpha, theta, kernel).reshape(n_t, n_x)
312
+
313
+ # Save values to the npz folder
314
+ np.savez(
315
+ f"{npz_folder}/{kernel}_m{m}.npz",
316
+ alpha=alpha,
317
+ kernel=kernel,
318
+ Phi=Phi,
319
+ theta=theta,
320
+ )
321
+
322
+ # Compute average error
323
+ avg_err = np.mean(np.abs(u_real - u_approx))
324
+
325
+ return f"Training completed in {time.time() - start_time:.2f} seconds. The average error is {avg_err}."
326
+
327
+
328
+ def plot_function(a, b, c, d, k=0.5):
329
+ global glob_a, glob_b, glob_c, glob_d
330
+
331
+ glob_a, glob_b, glob_c, glob_d = a, b, c, d
332
+
333
+ x = np.linspace(0, 1, 100)
334
+ t = np.linspace(0, 5, 500)
335
+ X, T = np.meshgrid(x, t) # Create the mesh grid
336
+ Z = complex_heat_eq_solution(X, T, a, b, c, d)
337
+
338
+ traces = []
339
+ traces.append(
340
+ go.Surface(
341
+ z=Z,
342
+ x=X,
343
+ y=T,
344
+ colorscale="Viridis",
345
+ showscale=False,
346
+ showlegend=False,
347
+ )
348
+ )
349
+
350
+ # Layout for the Plotly plot without controls
351
+ layout = go.Layout(
352
+ scene=dict(
353
+ camera=dict(
354
+ eye=dict(x=1.25, y=-1.75, z=0.3), # Front view
355
+ ),
356
+ xaxis_title="x",
357
+ yaxis_title="t",
358
+ zaxis_title="u",
359
+ ),
360
+ margin=dict(l=0, r=0, t=0, b=0), # Reduce margins
361
+ )
362
+
363
+ # Create the figure
364
+ fig = go.Figure(data=traces, layout=layout)
365
+
366
+ # fig.show(config=config)
367
+ fig.update_layout(
368
+ modebar_remove=[
369
+ "pan",
370
+ "resetCameraLastSave",
371
+ "hoverClosest3d",
372
+ "hoverCompareCartesian",
373
+ "zoomIn",
374
+ "zoomOut",
375
+ "select2d",
376
+ "lasso2d",
377
+ "zoomIn2d",
378
+ "zoomOut2d",
379
+ "sendDataToCloud",
380
+ "zoom3d",
381
+ "orbitRotation",
382
+ "tableRotation",
383
+ "toImage",
384
+ "resetCameraDefault3d"
385
+ ]
386
+ )
387
+
388
+ return fig
389
+
390
+
391
+ # Gradio interface
392
+ def create_gradio_ui():
393
+ # Get the initial available files
394
+ with gr.Blocks() as demo:
395
+ gr.Markdown("# Learn the Coefficients for the Heat Equation using the RFM")
396
+
397
+ # Function parameter inputs
398
+ gr.Markdown(
399
+ """
400
+ ## Function: $$u_k(x, t)\\coloneqq\\exp(-kt)\\cdot\\sin(\\pi x)+0.5\\cdot\\exp(\\textcolor{red}{a}kt)\\cdot\\sin(\\textcolor{red}{b}\\pi x)+0.25\\cdot\\exp(\\textcolor{red}{c}kt)\\cdot\\sin(\\textcolor{red}{d}\\pi x)$$
401
+
402
+ Adjust the values for <span style='color: red;'>a</span>, <span style='color: red;'>b</span>, <span style='color: red;'>c</span> and <span style='color: red;'>d</span> with the sliders below.
403
+ """
404
+ )
405
+
406
+ with gr.Row():
407
+ with gr.Column():
408
+ a_slider = gr.Slider(minimum=-10, maximum=-1, step=1, value=-2, label="a")
409
+ b_slider = gr.Slider(minimum=-10, maximum=10, step=1, value=-2, label="b")
410
+ c_slider = gr.Slider(minimum=-10, maximum=-1, step=1, value=-4, label="c")
411
+ d_slider = gr.Slider(minimum=-10, maximum=10, step=1, value=7, label="d")
412
+
413
+ plot_output = gr.Plot()
414
+
415
+ a_slider.change(
416
+ fn=plot_function,
417
+ inputs=[a_slider, b_slider, c_slider, d_slider],
418
+ outputs=[plot_output],
419
+ )
420
+ b_slider.change(
421
+ fn=plot_function,
422
+ inputs=[a_slider, b_slider, c_slider, d_slider],
423
+ outputs=[plot_output],
424
+ )
425
+ c_slider.change(
426
+ fn=plot_function,
427
+ inputs=[a_slider, b_slider, c_slider, d_slider],
428
+ outputs=[plot_output],
429
+ )
430
+ d_slider.change(
431
+ fn=plot_function,
432
+ inputs=[a_slider, b_slider, c_slider, d_slider],
433
+ outputs=[plot_output],
434
+ )
435
+
436
+ with gr.Column():
437
+ with gr.Row():
438
+ # Kernel selection and slider for m
439
+ kernel_dropdown = gr.Dropdown(
440
+ label="Choose Kernel", choices=["SINE", "GFF"], value="SINE"
441
+ )
442
+ m_slider = gr.Dropdown(
443
+ label="Number of Random Features (m)",
444
+ choices=[50, 250, 1000, 5000, 10000, 25000],
445
+ value=1000,
446
+ )
447
+
448
+ # Output to show status
449
+ output = gr.Textbox(label="Status", interactive=False)
450
+
451
+ with gr.Row():
452
+ # Button to train coefficients
453
+ train_button = gr.Button("Train Coefficients")
454
+ # Function to trigger training and update dropdown
455
+ train_button.click(
456
+ fn=train_coefficients,
457
+ inputs=[m_slider, kernel_dropdown],
458
+ outputs=output,
459
+ )
460
+ with gr.Row():
461
+ approx_button = gr.Button("Plot Approximation")
462
+ approx_button.click(
463
+ fn=plot_heat_equation, inputs=[m_slider, kernel_dropdown], outputs=None
464
+ )
465
+
466
+ error_button = gr.Button("Plot Errors")
467
+ error_button.click(
468
+ fn=plot_errors, inputs=[m_slider, kernel_dropdown], outputs=None
469
+ )
470
+ demo.load(fn=clear_folder, inputs=None, outputs=None)
471
+ demo.load(fn=plot_function, inputs=[a_slider, b_slider, c_slider, d_slider], outputs=[plot_output])
472
+
473
+ return demo
474
+
475
+
476
+ # Launch Gradio app
477
+ if __name__ == "__main__":
478
+ interface = create_gradio_ui()
479
+ interface.launch(share=False)
npz/.DS_Store ADDED
Binary file (8.2 kB). View file