MilesCranmer commited on
Commit
8a2bd53
1 Parent(s): 9d6017e

Add more advanced settings

Browse files
Files changed (1) hide show
  1. gui/app.py +102 -13
gui/app.py CHANGED
@@ -52,6 +52,16 @@ def _greet_dispatch(
52
  binary_operators,
53
  unary_operators,
54
  plot_update_delay,
 
 
 
 
 
 
 
 
 
 
55
  ):
56
  """Load data, then spawn a process to run the greet function."""
57
  if file_input is not None:
@@ -96,6 +106,16 @@ def _greet_dispatch(
96
  binary_operators=binary_operators,
97
  unary_operators=unary_operators,
98
  equation_file=equation_file,
 
 
 
 
 
 
 
 
 
 
99
  ),
100
  )
101
  process.start()
@@ -140,22 +160,14 @@ def greet(
140
  *,
141
  X,
142
  y,
143
- niterations: int,
144
- maxsize: int,
145
- binary_operators: list,
146
- unary_operators: list,
147
- equation_file: Union[str, Path],
148
  ):
149
  import pysr
150
 
151
  model = pysr.PySRRegressor(
152
  progress=False,
153
- maxsize=maxsize,
154
- niterations=niterations,
155
- binary_operators=binary_operators,
156
- unary_operators=unary_operators,
157
  timeout_in_seconds=1000,
158
- equation_file=equation_file,
159
  )
160
  model.fit(X, y)
161
 
@@ -230,15 +242,68 @@ def _settings_layout():
230
  )
231
  maxsize = gr.Slider(
232
  minimum=7,
233
- maximum=35,
234
  value=20,
235
  label="Maximum Complexity",
236
  step=1,
237
  )
238
- force_run = gr.Checkbox(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  value=False,
240
- label="Ignore Warnings",
241
  )
 
 
 
 
 
 
 
 
242
  with gr.Tab("Gradio Settings"):
243
  plot_update_delay = gr.Slider(
244
  minimum=1,
@@ -246,6 +311,10 @@ def _settings_layout():
246
  value=3,
247
  label="Plot Update Delay",
248
  )
 
 
 
 
249
  return dict(
250
  binary_operators=binary_operators,
251
  unary_operators=unary_operators,
@@ -253,6 +322,16 @@ def _settings_layout():
253
  maxsize=maxsize,
254
  force_run=force_run,
255
  plot_update_delay=plot_update_delay,
 
 
 
 
 
 
 
 
 
 
256
  )
257
 
258
 
@@ -292,6 +371,16 @@ def main():
292
  "binary_operators",
293
  "unary_operators",
294
  "plot_update_delay",
 
 
 
 
 
 
 
 
 
 
295
  ]
296
  ],
297
  outputs=blocks["df"],
 
52
  binary_operators,
53
  unary_operators,
54
  plot_update_delay,
55
+ parsimony,
56
+ populations,
57
+ population_size,
58
+ ncycles_per_iteration,
59
+ elementwise_loss,
60
+ adaptive_parsimony_scaling,
61
+ optimizer_algorithm,
62
+ optimizer_iterations,
63
+ batching,
64
+ batch_size,
65
  ):
66
  """Load data, then spawn a process to run the greet function."""
67
  if file_input is not None:
 
106
  binary_operators=binary_operators,
107
  unary_operators=unary_operators,
108
  equation_file=equation_file,
109
+ parsimony=parsimony,
110
+ populations=populations,
111
+ population_size=population_size,
112
+ ncycles_per_iteration=ncycles_per_iteration,
113
+ elementwise_loss=elementwise_loss,
114
+ adaptive_parsimony_scaling=adaptive_parsimony_scaling,
115
+ optimizer_algorithm=optimizer_algorithm,
116
+ optimizer_iterations=optimizer_iterations,
117
+ batching=batching,
118
+ batch_size=batch_size,
119
  ),
120
  )
121
  process.start()
 
160
  *,
161
  X,
162
  y,
163
+ **pysr_kwargs,
 
 
 
 
164
  ):
165
  import pysr
166
 
167
  model = pysr.PySRRegressor(
168
  progress=False,
 
 
 
 
169
  timeout_in_seconds=1000,
170
+ **pysr_kwargs,
171
  )
172
  model.fit(X, y)
173
 
 
242
  )
243
  maxsize = gr.Slider(
244
  minimum=7,
245
+ maximum=100,
246
  value=20,
247
  label="Maximum Complexity",
248
  step=1,
249
  )
250
+ parsimony = gr.Number(
251
+ value=0.0032,
252
+ label="Parsimony Coefficient",
253
+ )
254
+ with gr.Tab("Advanced Settings"):
255
+ populations = gr.Slider(
256
+ minimum=2,
257
+ maximum=100,
258
+ value=15,
259
+ label="Number of Populations",
260
+ step=1,
261
+ )
262
+ population_size = gr.Slider(
263
+ minimum=2,
264
+ maximum=1000,
265
+ value=33,
266
+ label="Population Size",
267
+ step=1,
268
+ )
269
+ ncycles_per_iteration = gr.Number(
270
+ value=550,
271
+ label="Cycles per Iteration",
272
+ )
273
+ elementwise_loss = gr.Radio(
274
+ ["L2DistLoss()", "L1DistLoss()", "LogitDistLoss()", "HuberLoss()"],
275
+ value="L2DistLoss()",
276
+ label="Loss Function",
277
+ )
278
+ adaptive_parsimony_scaling = gr.Number(
279
+ value=20.0,
280
+ label="Adaptive Parsimony Scaling",
281
+ )
282
+ optimizer_algorithm = gr.Radio(
283
+ ["BFGS", "NelderMead"],
284
+ value="BFGS",
285
+ label="Optimizer Algorithm",
286
+ )
287
+ optimizer_iterations = gr.Slider(
288
+ minimum=1,
289
+ maximum=100,
290
+ value=8,
291
+ label="Optimizer Iterations",
292
+ step=1,
293
+ )
294
+ # Bool:
295
+ batching = gr.Checkbox(
296
  value=False,
297
+ label="Batching",
298
  )
299
+ batch_size = gr.Slider(
300
+ minimum=2,
301
+ maximum=1000,
302
+ value=50,
303
+ label="Batch Size",
304
+ step=1,
305
+ )
306
+
307
  with gr.Tab("Gradio Settings"):
308
  plot_update_delay = gr.Slider(
309
  minimum=1,
 
311
  value=3,
312
  label="Plot Update Delay",
313
  )
314
+ force_run = gr.Checkbox(
315
+ value=False,
316
+ label="Ignore Warnings",
317
+ )
318
  return dict(
319
  binary_operators=binary_operators,
320
  unary_operators=unary_operators,
 
322
  maxsize=maxsize,
323
  force_run=force_run,
324
  plot_update_delay=plot_update_delay,
325
+ parsimony=parsimony,
326
+ populations=populations,
327
+ population_size=population_size,
328
+ ncycles_per_iteration=ncycles_per_iteration,
329
+ elementwise_loss=elementwise_loss,
330
+ adaptive_parsimony_scaling=adaptive_parsimony_scaling,
331
+ optimizer_algorithm=optimizer_algorithm,
332
+ optimizer_iterations=optimizer_iterations,
333
+ batching=batching,
334
+ batch_size=batch_size,
335
  )
336
 
337
 
 
371
  "binary_operators",
372
  "unary_operators",
373
  "plot_update_delay",
374
+ "parsimony",
375
+ "populations",
376
+ "population_size",
377
+ "ncycles_per_iteration",
378
+ "elementwise_loss",
379
+ "adaptive_parsimony_scaling",
380
+ "optimizer_algorithm",
381
+ "optimizer_iterations",
382
+ "batching",
383
+ "batch_size",
384
  ]
385
  ],
386
  outputs=blocks["df"],