MilesCranmer commited on
Commit
758e952
1 Parent(s): 4eac491

Make nicer plot for example data

Browse files
Files changed (1) hide show
  1. gui/app.py +50 -53
gui/app.py CHANGED
@@ -19,12 +19,13 @@ empty_df = pd.DataFrame(
19
  )
20
 
21
  test_equations = [
22
- "sin(x) + cos(2*x) + tan(x/3)",
23
  ]
24
 
25
 
26
  def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
27
- x = np.linspace(0, 10, num_points)
 
28
  for (k, v) in {
29
  "sin": "np.sin",
30
  "cos": "np.cos",
@@ -35,7 +36,6 @@ def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
35
  }.items():
36
  s = s.replace(k, v)
37
  y = eval(s)
38
- rstate = np.random.RandomState(data_seed)
39
  noise = rstate.normal(0, noise_level, y.shape)
40
  y_noisy = y + noise
41
  return pd.DataFrame({"x": x}), y_noisy
@@ -101,30 +101,37 @@ def _greet_dispatch(
101
  ),
102
  )
103
  process.start()
 
104
  while process.is_alive():
105
  if equation_file_bkup.exists():
106
  try:
107
  # First, copy the file to a the copy file
108
  equation_file_copy = base / "hall_of_fame_copy.csv"
109
  os.system(f"cp {equation_file_bkup} {equation_file_copy}")
110
- df = pd.read_csv(equation_file_copy)
111
  # Ensure it is pareto dominated, with more complex expressions
112
  # having higher loss. Otherwise remove those rows.
113
  # TODO: Not sure why this occurs; could be the result of a late copy?
114
- df.sort_values("Complexity", ascending=True, inplace=True)
115
- df.reset_index(inplace=True)
116
  bad_idx = []
117
  min_loss = None
118
- for i in df.index:
119
- if min_loss is None or df.loc[i, "Loss"] < min_loss:
120
- min_loss = float(df.loc[i, "Loss"])
121
  else:
122
  bad_idx.append(i)
123
- df.drop(index=bad_idx, inplace=True)
124
- yield df[["Complexity", "Loss", "Equation"]]
 
 
 
 
 
 
125
  except pd.errors.EmptyDataError:
126
  pass
127
- time.sleep(1)
128
 
129
  process.join()
130
 
@@ -163,31 +170,23 @@ def greet(
163
  def _data_layout():
164
  with gr.Tab("Example Data"):
165
  # Plot of the example data:
166
- example_plot = gr.ScatterPlot(
167
- x="x",
168
- y="y",
169
- tooltip=["x", "y"],
170
- x_lim=[0, 10],
171
- y_lim=[-5, 5],
172
- width=350,
173
- height=300,
174
- )
175
  test_equation = gr.Radio(
176
  test_equations, value=test_equations[0], label="Test Equation"
177
  )
178
  num_points = gr.Slider(
179
  minimum=10,
180
  maximum=1000,
181
- value=100,
182
  label="Number of Data Points",
183
  step=1,
184
  )
185
- noise_level = gr.Slider(minimum=0, maximum=1, value=0.1, label="Noise Level")
186
  data_seed = gr.Number(value=0, label="Random Seed")
187
  with gr.Tab("Upload Data"):
188
  file_input = gr.File(label="Upload a CSV File")
189
  gr.Markdown(
190
- "Upload a CSV file with the data to fit. The last column will be used as the target variable."
191
  )
192
 
193
  return dict(
@@ -219,7 +218,7 @@ def _settings_layout():
219
  "tan",
220
  ],
221
  label="Unary Operators",
222
- value=[],
223
  )
224
  niterations = gr.Slider(
225
  minimum=1,
@@ -304,43 +303,17 @@ def main():
304
  for eqn_component in eqn_components:
305
  eqn_component.change(replot, eqn_components, blocks["example_plot"])
306
 
 
307
  # Update plot when dataframe is updated:
308
  blocks["df"].change(
309
  replot_pareto,
310
  inputs=[blocks["df"], blocks["maxsize"]],
311
  outputs=[blocks["pareto"]],
312
  )
 
313
 
314
  demo.launch(debug=True)
315
 
316
-
317
- def replot(test_equation, num_points, noise_level, data_seed):
318
- X, y = generate_data(test_equation, num_points, noise_level, data_seed)
319
- df = pd.DataFrame({"x": X["x"], "y": y})
320
- return df
321
-
322
- def replot_pareto(df, maxsize):
323
- # Matplotlib log-log plot of loss vs complexity:
324
- fig, ax = plt.subplots(figsize=(5, 5))
325
-
326
- ax.set_xlabel('Complexity', fontsize=14)
327
- ax.set_ylabel('Loss', fontsize=14)
328
- if len(df) == 0 or 'Equation' not in df.columns:
329
- return fig
330
-
331
- ax.loglog(df['Complexity'], df['Loss'], marker='o', linestyle='-', color='b')
332
- ax.set_xlim(1, maxsize + 1)
333
- # Set ylim to next power of 2:
334
- ytop = 2 ** (np.ceil(np.log2(df['Loss'].max())))
335
- ybottom = 2 ** (np.floor(np.log2(df['Loss'].min() + 1e-20)))
336
- ax.set_ylim(ybottom, ytop)
337
- ax.grid(True, which="both", ls="--", linewidth=0.5)
338
- fig.tight_layout()
339
- ax.tick_params(axis='both', which='major', labelsize=12)
340
- ax.tick_params(axis='both', which='minor', labelsize=10)
341
-
342
- return fig
343
-
344
  def replot_pareto(df, maxsize):
345
  plt.rcParams['font.family'] = 'IBM Plex Mono'
346
  fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
@@ -375,5 +348,29 @@ def replot_pareto(df, maxsize):
375
 
376
  return fig
377
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  if __name__ == "__main__":
379
  main()
 
19
  )
20
 
21
  test_equations = [
22
+ "sin(2*x)/x + 0.1*x"
23
  ]
24
 
25
 
26
  def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
27
+ rstate = np.random.RandomState(data_seed)
28
+ x = rstate.uniform(-10, 10, num_points)
29
  for (k, v) in {
30
  "sin": "np.sin",
31
  "cos": "np.cos",
 
36
  }.items():
37
  s = s.replace(k, v)
38
  y = eval(s)
 
39
  noise = rstate.normal(0, noise_level, y.shape)
40
  y_noisy = y + noise
41
  return pd.DataFrame({"x": x}), y_noisy
 
101
  ),
102
  )
103
  process.start()
104
+ last_yield_time = None
105
  while process.is_alive():
106
  if equation_file_bkup.exists():
107
  try:
108
  # First, copy the file to a the copy file
109
  equation_file_copy = base / "hall_of_fame_copy.csv"
110
  os.system(f"cp {equation_file_bkup} {equation_file_copy}")
111
+ equations = pd.read_csv(equation_file_copy)
112
  # Ensure it is pareto dominated, with more complex expressions
113
  # having higher loss. Otherwise remove those rows.
114
  # TODO: Not sure why this occurs; could be the result of a late copy?
115
+ equations.sort_values("Complexity", ascending=True, inplace=True)
116
+ equations.reset_index(inplace=True)
117
  bad_idx = []
118
  min_loss = None
119
+ for i in equations.index:
120
+ if min_loss is None or equations.loc[i, "Loss"] < min_loss:
121
+ min_loss = float(equations.loc[i, "Loss"])
122
  else:
123
  bad_idx.append(i)
124
+ equations.drop(index=bad_idx, inplace=True)
125
+
126
+ while last_yield_time is not None and time.time() - last_yield_time < 1:
127
+ time.sleep(0.1)
128
+
129
+ yield equations[["Complexity", "Loss", "Equation"]]
130
+
131
+ last_yield_time = time.time()
132
  except pd.errors.EmptyDataError:
133
  pass
134
+
135
 
136
  process.join()
137
 
 
170
  def _data_layout():
171
  with gr.Tab("Example Data"):
172
  # Plot of the example data:
173
+ example_plot = gr.Plot()
 
 
 
 
 
 
 
 
174
  test_equation = gr.Radio(
175
  test_equations, value=test_equations[0], label="Test Equation"
176
  )
177
  num_points = gr.Slider(
178
  minimum=10,
179
  maximum=1000,
180
+ value=200,
181
  label="Number of Data Points",
182
  step=1,
183
  )
184
+ noise_level = gr.Slider(minimum=0, maximum=1, value=0.05, label="Noise Level")
185
  data_seed = gr.Number(value=0, label="Random Seed")
186
  with gr.Tab("Upload Data"):
187
  file_input = gr.File(label="Upload a CSV File")
188
  gr.Markdown(
189
+ "The rightmost column of your CSV file be used as the target variable."
190
  )
191
 
192
  return dict(
 
218
  "tan",
219
  ],
220
  label="Unary Operators",
221
+ value=["sin"],
222
  )
223
  niterations = gr.Slider(
224
  minimum=1,
 
303
  for eqn_component in eqn_components:
304
  eqn_component.change(replot, eqn_components, blocks["example_plot"])
305
 
306
+
307
  # Update plot when dataframe is updated:
308
  blocks["df"].change(
309
  replot_pareto,
310
  inputs=[blocks["df"], blocks["maxsize"]],
311
  outputs=[blocks["pareto"]],
312
  )
313
+ demo.load(replot, eqn_components, blocks["example_plot"])
314
 
315
  demo.launch(debug=True)
316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
317
  def replot_pareto(df, maxsize):
318
  plt.rcParams['font.family'] = 'IBM Plex Mono'
319
  fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
 
348
 
349
  return fig
350
 
351
+ def replot(test_equation, num_points, noise_level, data_seed):
352
+ X, y = generate_data(test_equation, num_points, noise_level, data_seed)
353
+ x = X["x"]
354
+
355
+ plt.rcParams['font.family'] = 'IBM Plex Mono'
356
+ fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
357
+
358
+ ax.scatter(x, y, alpha=0.7, edgecolors='w', s=50)
359
+
360
+ ax.grid(True, which="major", linestyle='--', linewidth=0.5, color='gray', alpha=0.7)
361
+ ax.grid(True, which="minor", linestyle=':', linewidth=0.5, color='gray', alpha=0.5)
362
+ ax.spines['top'].set_visible(False)
363
+ ax.spines['right'].set_visible(False)
364
+ ax.spines['bottom'].set_color('gray')
365
+ ax.spines['left'].set_color('gray')
366
+ ax.tick_params(axis='both', which='major', labelsize=12, direction='out', length=6)
367
+ ax.tick_params(axis='both', which='minor', labelsize=10, direction='out', length=4)
368
+ ax.set_xlabel("x")
369
+ ax.set_ylabel("y")
370
+
371
+ fig.tight_layout()
372
+ return fig
373
+
374
+
375
  if __name__ == "__main__":
376
  main()