MilesCranmer committed on
Commit
9d6017e
1 Parent(s): f751163

Add setting for plot update rate

Browse files
Files changed (1) hide show
  1. gui/app.py +91 -77
gui/app.py CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
5
  import time
6
  import multiprocessing as mp
7
  from matplotlib import pyplot as plt
 
8
  plt.ioff()
9
  import tempfile
10
  from typing import Optional, Union
@@ -18,9 +19,7 @@ empty_df = pd.DataFrame(
18
  }
19
  )
20
 
21
- test_equations = [
22
- "sin(2*x)/x + 0.1*x"
23
- ]
24
 
25
 
26
  def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
@@ -52,7 +51,7 @@ def _greet_dispatch(
52
  maxsize,
53
  binary_operators,
54
  unary_operators,
55
- seed,
56
  ):
57
  """Load data, then spawn a process to run the greet function."""
58
  if file_input is not None:
@@ -96,7 +95,6 @@ def _greet_dispatch(
96
  maxsize=maxsize,
97
  binary_operators=binary_operators,
98
  unary_operators=unary_operators,
99
- seed=seed,
100
  equation_file=equation_file,
101
  ),
102
  )
@@ -123,7 +121,10 @@ def _greet_dispatch(
123
  bad_idx.append(i)
124
  equations.drop(index=bad_idx, inplace=True)
125
 
126
- while last_yield_time is not None and time.time() - last_yield_time < 1:
 
 
 
127
  time.sleep(0.1)
128
 
129
  yield equations[["Complexity", "Loss", "Equation"]]
@@ -132,7 +133,6 @@ def _greet_dispatch(
132
  except pd.errors.EmptyDataError:
133
  pass
134
 
135
-
136
  process.join()
137
 
138
 
@@ -144,7 +144,6 @@ def greet(
144
  maxsize: int,
145
  binary_operators: list,
146
  unary_operators: list,
147
- seed: int,
148
  equation_file: Union[str, Path],
149
  ):
150
  import pysr
@@ -180,7 +179,9 @@ def _data_layout():
180
  label="Number of Data Points",
181
  step=1,
182
  )
183
- noise_level = gr.Slider(minimum=0, maximum=1, value=0.05, label="Noise Level")
 
 
184
  data_seed = gr.Number(value=0, label="Random Seed")
185
  with gr.Tab("Upload Data"):
186
  file_input = gr.File(label="Upload a CSV File")
@@ -199,55 +200,59 @@ def _data_layout():
199
 
200
 
201
  def _settings_layout():
202
- binary_operators = gr.CheckboxGroup(
203
- choices=["+", "-", "*", "/", "^"],
204
- label="Binary Operators",
205
- value=["+", "-", "*", "/"],
206
- )
207
- unary_operators = gr.CheckboxGroup(
208
- choices=[
209
- "sin",
210
- "cos",
211
- "exp",
212
- "log",
213
- "square",
214
- "cube",
215
- "sqrt",
216
- "abs",
217
- "tan",
218
- ],
219
- label="Unary Operators",
220
- value=["sin"],
221
- )
222
- niterations = gr.Slider(
223
- minimum=1,
224
- maximum=1000,
225
- value=40,
226
- label="Number of Iterations",
227
- step=1,
228
- )
229
- maxsize = gr.Slider(
230
- minimum=7,
231
- maximum=35,
232
- value=20,
233
- label="Maximum Complexity",
234
- step=1,
235
- )
236
- seed = gr.Number(
237
- value=0,
238
- label="Random Seed",
239
- )
240
- force_run = gr.Checkbox(
241
- value=False,
242
- label="Ignore Warnings",
243
- )
 
 
 
 
244
  return dict(
245
  binary_operators=binary_operators,
246
  unary_operators=unary_operators,
247
  niterations=niterations,
248
  maxsize=maxsize,
249
  force_run=force_run,
250
- seed=seed,
251
  )
252
 
253
 
@@ -286,7 +291,7 @@ def main():
286
  "maxsize",
287
  "binary_operators",
288
  "unary_operators",
289
- "seed",
290
  ]
291
  ],
292
  outputs=blocks["df"],
@@ -302,7 +307,6 @@ def main():
302
  for eqn_component in eqn_components:
303
  eqn_component.change(replot, eqn_components, blocks["example_plot"])
304
 
305
-
306
  # Update plot when dataframe is updated:
307
  blocks["df"].change(
308
  replot_pareto,
@@ -313,60 +317,70 @@ def main():
313
 
314
  demo.launch(debug=True)
315
 
 
316
  def replot_pareto(df, maxsize):
317
- plt.rcParams['font.family'] = 'IBM Plex Mono'
318
  fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
319
 
320
- if len(df) == 0 or 'Equation' not in df.columns:
321
  return fig
322
 
323
  # Plotting the data
324
- ax.loglog(df['Complexity'], df['Loss'], marker='o', linestyle='-', color='#333f48', linewidth=1.5, markersize=6)
 
 
 
 
 
 
 
 
325
 
326
  # Set the axis limits
327
  ax.set_xlim(0.5, maxsize + 1)
328
- ytop = 2 ** (np.ceil(np.log2(df['Loss'].max())))
329
- ybottom = 2 ** (np.floor(np.log2(df['Loss'].min() + 1e-20)))
330
  ax.set_ylim(ybottom, ytop)
331
 
332
- ax.grid(True, which="both", ls="--", linewidth=0.5, color='gray', alpha=0.5)
333
- ax.spines['top'].set_visible(False)
334
- ax.spines['right'].set_visible(False)
335
 
336
  # Range-frame the plot
337
- for direction in ['bottom', 'left']:
338
- ax.spines[direction].set_position(('outward', 10))
339
 
340
  # Delete far ticks
341
- ax.tick_params(axis='both', which='major', labelsize=10, direction='out', length=5)
342
- ax.tick_params(axis='both', which='minor', labelsize=8, direction='out', length=3)
343
 
344
- ax.set_xlabel('Complexity')
345
- ax.set_ylabel('Loss')
346
  fig.tight_layout(pad=2)
347
 
348
  return fig
349
 
 
350
  def replot(test_equation, num_points, noise_level, data_seed):
351
  X, y = generate_data(test_equation, num_points, noise_level, data_seed)
352
  x = X["x"]
353
 
354
- plt.rcParams['font.family'] = 'IBM Plex Mono'
355
  fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
356
 
357
- ax.scatter(x, y, alpha=0.7, edgecolors='w', s=50)
358
 
359
- ax.grid(True, which="both", ls="--", linewidth=0.5, color='gray', alpha=0.5)
360
- ax.spines['top'].set_visible(False)
361
- ax.spines['right'].set_visible(False)
362
 
363
  # Range-frame the plot
364
- for direction in ['bottom', 'left']:
365
- ax.spines[direction].set_position(('outward', 10))
366
 
367
  # Delete far ticks
368
- ax.tick_params(axis='both', which='major', labelsize=10, direction='out', length=5)
369
- ax.tick_params(axis='both', which='minor', labelsize=8, direction='out', length=3)
370
 
371
  ax.set_xlabel("x")
372
  ax.set_ylabel("y")
 
5
  import time
6
  import multiprocessing as mp
7
  from matplotlib import pyplot as plt
8
+
9
  plt.ioff()
10
  import tempfile
11
  from typing import Optional, Union
 
19
  }
20
  )
21
 
22
+ test_equations = ["sin(2*x)/x + 0.1*x"]
 
 
23
 
24
 
25
  def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
 
51
  maxsize,
52
  binary_operators,
53
  unary_operators,
54
+ plot_update_delay,
55
  ):
56
  """Load data, then spawn a process to run the greet function."""
57
  if file_input is not None:
 
95
  maxsize=maxsize,
96
  binary_operators=binary_operators,
97
  unary_operators=unary_operators,
 
98
  equation_file=equation_file,
99
  ),
100
  )
 
121
  bad_idx.append(i)
122
  equations.drop(index=bad_idx, inplace=True)
123
 
124
+ while (
125
+ last_yield_time is not None
126
+ and time.time() - last_yield_time < plot_update_delay
127
+ ):
128
  time.sleep(0.1)
129
 
130
  yield equations[["Complexity", "Loss", "Equation"]]
 
133
  except pd.errors.EmptyDataError:
134
  pass
135
 
 
136
  process.join()
137
 
138
 
 
144
  maxsize: int,
145
  binary_operators: list,
146
  unary_operators: list,
 
147
  equation_file: Union[str, Path],
148
  ):
149
  import pysr
 
179
  label="Number of Data Points",
180
  step=1,
181
  )
182
+ noise_level = gr.Slider(
183
+ minimum=0, maximum=1, value=0.05, label="Noise Level"
184
+ )
185
  data_seed = gr.Number(value=0, label="Random Seed")
186
  with gr.Tab("Upload Data"):
187
  file_input = gr.File(label="Upload a CSV File")
 
200
 
201
 
202
  def _settings_layout():
203
+ with gr.Tab("Basic Settings"):
204
+ binary_operators = gr.CheckboxGroup(
205
+ choices=["+", "-", "*", "/", "^"],
206
+ label="Binary Operators",
207
+ value=["+", "-", "*", "/"],
208
+ )
209
+ unary_operators = gr.CheckboxGroup(
210
+ choices=[
211
+ "sin",
212
+ "cos",
213
+ "exp",
214
+ "log",
215
+ "square",
216
+ "cube",
217
+ "sqrt",
218
+ "abs",
219
+ "tan",
220
+ ],
221
+ label="Unary Operators",
222
+ value=["sin"],
223
+ )
224
+ niterations = gr.Slider(
225
+ minimum=1,
226
+ maximum=1000,
227
+ value=40,
228
+ label="Number of Iterations",
229
+ step=1,
230
+ )
231
+ maxsize = gr.Slider(
232
+ minimum=7,
233
+ maximum=35,
234
+ value=20,
235
+ label="Maximum Complexity",
236
+ step=1,
237
+ )
238
+ force_run = gr.Checkbox(
239
+ value=False,
240
+ label="Ignore Warnings",
241
+ )
242
+ with gr.Tab("Gradio Settings"):
243
+ plot_update_delay = gr.Slider(
244
+ minimum=1,
245
+ maximum=100,
246
+ value=3,
247
+ label="Plot Update Delay",
248
+ )
249
  return dict(
250
  binary_operators=binary_operators,
251
  unary_operators=unary_operators,
252
  niterations=niterations,
253
  maxsize=maxsize,
254
  force_run=force_run,
255
+ plot_update_delay=plot_update_delay,
256
  )
257
 
258
 
 
291
  "maxsize",
292
  "binary_operators",
293
  "unary_operators",
294
+ "plot_update_delay",
295
  ]
296
  ],
297
  outputs=blocks["df"],
 
307
  for eqn_component in eqn_components:
308
  eqn_component.change(replot, eqn_components, blocks["example_plot"])
309
 
 
310
  # Update plot when dataframe is updated:
311
  blocks["df"].change(
312
  replot_pareto,
 
317
 
318
  demo.launch(debug=True)
319
 
320
+
321
  def replot_pareto(df, maxsize):
322
+ plt.rcParams["font.family"] = "IBM Plex Mono"
323
  fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
324
 
325
+ if len(df) == 0 or "Equation" not in df.columns:
326
  return fig
327
 
328
  # Plotting the data
329
+ ax.loglog(
330
+ df["Complexity"],
331
+ df["Loss"],
332
+ marker="o",
333
+ linestyle="-",
334
+ color="#333f48",
335
+ linewidth=1.5,
336
+ markersize=6,
337
+ )
338
 
339
  # Set the axis limits
340
  ax.set_xlim(0.5, maxsize + 1)
341
+ ytop = 2 ** (np.ceil(np.log2(df["Loss"].max())))
342
+ ybottom = 2 ** (np.floor(np.log2(df["Loss"].min() + 1e-20)))
343
  ax.set_ylim(ybottom, ytop)
344
 
345
+ ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
346
+ ax.spines["top"].set_visible(False)
347
+ ax.spines["right"].set_visible(False)
348
 
349
  # Range-frame the plot
350
+ for direction in ["bottom", "left"]:
351
+ ax.spines[direction].set_position(("outward", 10))
352
 
353
  # Delete far ticks
354
+ ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
355
+ ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
356
 
357
+ ax.set_xlabel("Complexity")
358
+ ax.set_ylabel("Loss")
359
  fig.tight_layout(pad=2)
360
 
361
  return fig
362
 
363
+
364
  def replot(test_equation, num_points, noise_level, data_seed):
365
  X, y = generate_data(test_equation, num_points, noise_level, data_seed)
366
  x = X["x"]
367
 
368
+ plt.rcParams["font.family"] = "IBM Plex Mono"
369
  fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
370
 
371
+ ax.scatter(x, y, alpha=0.7, edgecolors="w", s=50)
372
 
373
+ ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
374
+ ax.spines["top"].set_visible(False)
375
+ ax.spines["right"].set_visible(False)
376
 
377
  # Range-frame the plot
378
+ for direction in ["bottom", "left"]:
379
+ ax.spines[direction].set_position(("outward", 10))
380
 
381
  # Delete far ticks
382
+ ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
383
+ ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
384
 
385
  ax.set_xlabel("x")
386
  ax.set_ylabel("y")