Lennard Schober commited on
Commit
e1d92e1
·
1 Parent(s): 8226920

Store .npz files in current dir

Browse files
Files changed (1) hide show
  1. app.py +16 -25
app.py CHANGED
@@ -6,24 +6,23 @@ import numpy as np
6
  import gradio as gr
7
  import plotly.graph_objs as go
8
 
9
- # Path to the npz folder
10
- npz_folder = "npz"
11
-
12
  glob_a = -2
13
  glob_b = -2
14
  glob_c = -4
15
  glob_d = 7
16
 
17
- def clear_folder(folder_path=npz_folder):
18
- for filename in os.listdir(folder_path):
19
- file_path = os.path.join(folder_path, filename)
20
- try:
21
- if os.path.isfile(file_path) or os.path.islink(file_path):
22
- os.unlink(file_path) # Remove the file or symbolic link
23
- elif os.path.isdir(file_path):
24
- shutil.rmtree(file_path) # Remove the directory and its contents
25
- except Exception as e:
26
- print(f"Failed to delete {file_path}. Reason: {e}")
 
 
27
 
28
 
29
  def complex_heat_eq_solution(x, t, a=glob_a, b=glob_b, c=glob_c, d=glob_d, k=0.5):
@@ -41,7 +40,7 @@ def plot_heat_equation(m, approx_type):
41
  n_t = 50
42
 
43
  try:
44
- loaded_values = np.load(f"npz/{approx_type}_m{m}.npz")
45
  except:
46
  raise gr.Error(f"First train the coefficients for {approx_type} and m = {m}")
47
  alpha = loaded_values["alpha"]
@@ -138,7 +137,7 @@ def plot_errors(m, approx_type):
138
  n_t = 50
139
 
140
  try:
141
- loaded_values = np.load(f"npz/{approx_type}_m{m}.npz")
142
  except:
143
  raise gr.Error(f"First train the coefficients for {approx_type} and m = {m}")
144
  alpha = loaded_values["alpha"]
@@ -216,14 +215,6 @@ def plot_errors(m, approx_type):
216
 
217
  fig.show(config=config)
218
 
219
-
220
- # Function to get the available .npz files in the npz folder
221
- def get_available_approx_files():
222
- files = os.listdir(npz_folder)
223
- npz_files = [f for f in files if f.endswith(".npz")]
224
- return npz_files
225
-
226
-
227
  def generate_data(n_x=32, n_t=50):
228
  """Generate training data."""
229
  x = np.linspace(0, 1, n_x) # spatial points
@@ -310,7 +301,7 @@ def train_coefficients(m, kernel):
310
 
311
  # Save values to the npz folder
312
  np.savez(
313
- f"{npz_folder}/{kernel}_m{m}.npz",
314
  alpha=alpha,
315
  kernel=kernel,
316
  Phi=Phi,
@@ -465,7 +456,7 @@ def create_gradio_ui():
465
  error_button.click(
466
  fn=plot_errors, inputs=[m_slider, kernel_dropdown], outputs=None
467
  )
468
- demo.load(fn=clear_folder, inputs=None, outputs=None)
469
  demo.load(fn=plot_function, inputs=[a_slider, b_slider, c_slider, d_slider], outputs=[plot_output])
470
 
471
  return demo
 
6
  import gradio as gr
7
  import plotly.graph_objs as go
8
 
 
 
 
9
  glob_a = -2
10
  glob_b = -2
11
  glob_c = -4
12
  glob_d = 7
13
 
14
+ def clear_npz():
15
+ current_directory = os.getcwd() # Get the current working directory
16
+ for filename in os.listdir(current_directory):
17
+ if filename.endswith(".npz"): # Check if the file ends with .npz
18
+ file_path = os.path.join(current_directory, filename)
19
+ try:
20
+ if os.path.isfile(file_path) or os.path.islink(file_path):
21
+ os.unlink(file_path) # Remove the file or symbolic link
22
+ else:
23
+ print(f"Skipping {file_path}, not a file or symbolic link.")
24
+ except Exception as e:
25
+ print(f"Failed to delete {file_path}. Reason: {e}")
26
 
27
 
28
  def complex_heat_eq_solution(x, t, a=glob_a, b=glob_b, c=glob_c, d=glob_d, k=0.5):
 
40
  n_t = 50
41
 
42
  try:
43
+ loaded_values = np.load(f"{approx_type}_m{m}.npz")
44
  except:
45
  raise gr.Error(f"First train the coefficients for {approx_type} and m = {m}")
46
  alpha = loaded_values["alpha"]
 
137
  n_t = 50
138
 
139
  try:
140
+ loaded_values = np.load(f"{approx_type}_m{m}.npz")
141
  except:
142
  raise gr.Error(f"First train the coefficients for {approx_type} and m = {m}")
143
  alpha = loaded_values["alpha"]
 
215
 
216
  fig.show(config=config)
217
 
 
 
 
 
 
 
 
 
218
  def generate_data(n_x=32, n_t=50):
219
  """Generate training data."""
220
  x = np.linspace(0, 1, n_x) # spatial points
 
301
 
302
  # Save values to the npz folder
303
  np.savez(
304
+ f"{kernel}_m{m}.npz",
305
  alpha=alpha,
306
  kernel=kernel,
307
  Phi=Phi,
 
456
  error_button.click(
457
  fn=plot_errors, inputs=[m_slider, kernel_dropdown], outputs=None
458
  )
459
+ demo.load(fn=clear_npz, inputs=None, outputs=None)
460
  demo.load(fn=plot_function, inputs=[a_slider, b_slider, c_slider, d_slider], outputs=[plot_output])
461
 
462
  return demo