MilesCranmer commited on
Commit
a206d6a
1 Parent(s): ef7aada

refactor(gui): gradio to use object oriented wrapper

Browse files
Files changed (3) hide show
  1. gui/app.py +162 -136
  2. gui/data.py +1 -1
  3. gui/processing.py +2 -0
gui/app.py CHANGED
@@ -1,56 +1,78 @@
 
 
1
  import gradio as gr
2
- from data import test_equations
 
 
3
  from plots import plot_example_data, plot_pareto_curve
4
  from processing import processing
5
 
6
- GLOBAL_SETTINGS = dict(theme="default")
7
-
8
 
9
- def _data_layout():
10
- with gr.Tab("Example Data"):
11
- # Plot of the example data:
12
  with gr.Row():
 
13
  with gr.Column():
14
- example_plot = gr.Plot()
15
  with gr.Column():
16
- test_equation = gr.Radio(
17
- test_equations, value=test_equations[0], label="Test Equation"
18
  )
19
- num_points = gr.Slider(
20
  minimum=10,
21
  maximum=1000,
22
  value=200,
23
  label="Number of Data Points",
24
  step=1,
25
  )
26
- noise_level = gr.Slider(
27
  minimum=0, maximum=1, value=0.05, label="Noise Level"
28
  )
29
- data_seed = gr.Number(value=0, label="Random Seed")
30
- with gr.Tab("Upload Data"):
31
- file_input = gr.File(label="Upload a CSV File")
32
- gr.Markdown(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  "The rightmost column of your CSV file will be used as the target variable."
34
  )
35
 
36
- return dict(
37
- file_input=file_input,
38
- test_equation=test_equation,
39
- num_points=num_points,
40
- noise_level=noise_level,
41
- data_seed=data_seed,
42
- example_plot=example_plot,
43
- )
44
 
45
 
46
- def _settings_layout():
47
- with gr.Tab("Basic Settings"):
48
- binary_operators = gr.CheckboxGroup(
49
  choices=["+", "-", "*", "/", "^", "max", "min", "mod", "cond"],
50
  label="Binary Operators",
51
  value=["+", "-", "*", "/"],
52
  )
53
- unary_operators = gr.CheckboxGroup(
54
  choices=[
55
  "sin",
56
  "cos",
@@ -69,58 +91,61 @@ def _settings_layout():
69
  label="Unary Operators",
70
  value=["sin"],
71
  )
72
- niterations = gr.Slider(
73
  minimum=1,
74
  maximum=1000,
75
  value=40,
76
  label="Number of Iterations",
77
  step=1,
78
  )
79
- maxsize = gr.Slider(
80
  minimum=7,
81
  maximum=100,
82
  value=20,
83
  label="Maximum Complexity",
84
  step=1,
85
  )
86
- parsimony = gr.Number(
87
  value=0.0032,
88
  label="Parsimony Coefficient",
89
  )
90
- with gr.Tab("Advanced Settings"):
91
- populations = gr.Slider(
 
 
 
92
  minimum=2,
93
  maximum=100,
94
  value=15,
95
  label="Number of Populations",
96
  step=1,
97
  )
98
- population_size = gr.Slider(
99
  minimum=2,
100
  maximum=1000,
101
  value=33,
102
  label="Population Size",
103
  step=1,
104
  )
105
- ncycles_per_iteration = gr.Number(
106
  value=550,
107
  label="Cycles per Iteration",
108
  )
109
- elementwise_loss = gr.Radio(
110
  ["L2DistLoss()", "L1DistLoss()", "LogitDistLoss()", "HuberLoss()"],
111
  value="L2DistLoss()",
112
  label="Loss Function",
113
  )
114
- adaptive_parsimony_scaling = gr.Number(
115
  value=20.0,
116
  label="Adaptive Parsimony Scaling",
117
  )
118
- optimizer_algorithm = gr.Radio(
119
  ["BFGS", "NelderMead"],
120
  value="BFGS",
121
  label="Optimizer Algorithm",
122
  )
123
- optimizer_iterations = gr.Slider(
124
  minimum=1,
125
  maximum=100,
126
  value=8,
@@ -128,11 +153,11 @@ def _settings_layout():
128
  step=1,
129
  )
130
  # Bool:
131
- batching = gr.Checkbox(
132
  value=False,
133
  label="Batching",
134
  )
135
- batch_size = gr.Slider(
136
  minimum=2,
137
  maximum=1000,
138
  value=50,
@@ -140,121 +165,122 @@ def _settings_layout():
140
  step=1,
141
  )
142
 
143
- with gr.Tab("Gradio Settings"):
144
- plot_update_delay = gr.Slider(
 
 
145
  minimum=1,
146
  maximum=100,
147
  value=3,
148
  label="Plot Update Delay",
149
  )
150
- force_run = gr.Checkbox(
151
  value=False,
152
  label="Ignore Warnings",
153
  )
154
- return dict(
155
- binary_operators=binary_operators,
156
- unary_operators=unary_operators,
157
- niterations=niterations,
158
- maxsize=maxsize,
159
- force_run=force_run,
160
- plot_update_delay=plot_update_delay,
161
- parsimony=parsimony,
162
- populations=populations,
163
- population_size=population_size,
164
- ncycles_per_iteration=ncycles_per_iteration,
165
- elementwise_loss=elementwise_loss,
166
- adaptive_parsimony_scaling=adaptive_parsimony_scaling,
167
- optimizer_algorithm=optimizer_algorithm,
168
- optimizer_iterations=optimizer_iterations,
169
- batching=batching,
170
- batch_size=batch_size,
171
- )
172
-
173
-
174
- def main():
175
- global GLOBAL_SETTINGS
176
- blocks = {}
177
- with gr.Blocks(**GLOBAL_SETTINGS) as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  with gr.Row():
179
  with gr.Column():
180
  with gr.Row():
181
- blocks = {**blocks, **_data_layout()}
182
  with gr.Row():
183
- blocks = {**blocks, **_settings_layout()}
184
-
185
  with gr.Column():
186
- with gr.Tab("Pareto Front"):
187
- blocks["pareto"] = gr.Plot()
188
- with gr.Tab("Predictions"):
189
- blocks["predictions_plot"] = gr.Plot()
190
-
191
- blocks["df"] = gr.Dataframe(
192
- headers=["complexity", "loss", "equation"],
193
- datatype=["number", "number", "str"],
194
- wrap=True,
195
- column_widths=[75, 75, 200],
196
- interactive=False,
197
- )
198
- blocks["run"] = gr.Button()
199
-
200
- blocks["run"].click(
201
- processing,
202
- inputs=[
203
- blocks[k]
204
- for k in [
205
- "file_input",
206
- "force_run",
207
- "test_equation",
208
- "num_points",
209
- "noise_level",
210
- "data_seed",
211
- "niterations",
212
- "maxsize",
213
- "binary_operators",
214
- "unary_operators",
215
- "plot_update_delay",
216
- "parsimony",
217
- "populations",
218
- "population_size",
219
- "ncycles_per_iteration",
220
- "elementwise_loss",
221
- "adaptive_parsimony_scaling",
222
- "optimizer_algorithm",
223
- "optimizer_iterations",
224
- "batching",
225
- "batch_size",
226
- ]
227
- ],
228
- outputs=[blocks["df"], blocks["predictions_plot"]],
229
- show_progress=True,
230
- )
231
-
232
- # Any update to the equation choice will trigger a plot_example_data:
233
- eqn_components = [
234
- blocks["test_equation"],
235
- blocks["num_points"],
236
- blocks["noise_level"],
237
- blocks["data_seed"],
238
- ]
239
- for eqn_component in eqn_components:
240
- eqn_component.change(
241
- plot_example_data,
242
- eqn_components,
243
- blocks["example_plot"],
244
- show_progress=False,
245
- )
246
 
247
  # Update plot when dataframe is updated:
248
- blocks["df"].change(
249
  plot_pareto_curve,
250
- inputs=[blocks["df"], blocks["maxsize"]],
251
- outputs=[blocks["pareto"]],
252
  show_progress=False,
253
  )
254
- demo.load(plot_example_data, eqn_components, blocks["example_plot"])
255
 
256
- demo.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
257
 
258
 
259
  if __name__ == "__main__":
260
- main()
 
1
+ from collections import OrderedDict
2
+
3
  import gradio as gr
4
+ import numpy as np
5
+ from data import TEST_EQUATIONS
6
+ from gradio.components.base import Component
7
  from plots import plot_example_data, plot_pareto_curve
8
  from processing import processing
9
 
 
 
10
 
11
+ class ExampleData:
12
+ def __init__(self, demo: gr.Blocks) -> None:
 
13
  with gr.Row():
14
+ # Plot of the example data:
15
  with gr.Column():
16
+ self.example_plot = gr.Plot()
17
  with gr.Column():
18
+ self.test_equation = gr.Radio(
19
+ TEST_EQUATIONS, value=TEST_EQUATIONS[0], label="Test Equation"
20
  )
21
+ self.num_points = gr.Slider(
22
  minimum=10,
23
  maximum=1000,
24
  value=200,
25
  label="Number of Data Points",
26
  step=1,
27
  )
28
+ self.noise_level = gr.Slider(
29
  minimum=0, maximum=1, value=0.05, label="Noise Level"
30
  )
31
+ self.data_seed = gr.Number(value=0, label="Random Seed")
32
+
33
+ # Set up plotting:
34
+
35
+ eqn_components = [
36
+ self.test_equation,
37
+ self.num_points,
38
+ self.noise_level,
39
+ self.data_seed,
40
+ ]
41
+ for eqn_component in eqn_components:
42
+ eqn_component.change(
43
+ plot_example_data,
44
+ eqn_components,
45
+ self.example_plot,
46
+ show_progress=False,
47
+ )
48
+
49
+ demo.load(plot_example_data, eqn_components, self.example_plot)
50
+
51
+
52
+ class UploadData:
53
+ def __init__(self) -> None:
54
+ self.file_input = gr.File(label="Upload a CSV File")
55
+ self.label = gr.Markdown(
56
  "The rightmost column of your CSV file will be used as the target variable."
57
  )
58
 
59
+
60
+ class Data:
61
+ def __init__(self, demo: gr.Blocks) -> None:
62
+ with gr.Tab("Example Data"):
63
+ self.example_data = ExampleData(demo)
64
+ with gr.Tab("Upload Data"):
65
+ self.upload_data = UploadData()
 
66
 
67
 
68
+ class BasicSettings:
69
+ def __init__(self) -> None:
70
+ self.binary_operators = gr.CheckboxGroup(
71
  choices=["+", "-", "*", "/", "^", "max", "min", "mod", "cond"],
72
  label="Binary Operators",
73
  value=["+", "-", "*", "/"],
74
  )
75
+ self.unary_operators = gr.CheckboxGroup(
76
  choices=[
77
  "sin",
78
  "cos",
 
91
  label="Unary Operators",
92
  value=["sin"],
93
  )
94
+ self.niterations = gr.Slider(
95
  minimum=1,
96
  maximum=1000,
97
  value=40,
98
  label="Number of Iterations",
99
  step=1,
100
  )
101
+ self.maxsize = gr.Slider(
102
  minimum=7,
103
  maximum=100,
104
  value=20,
105
  label="Maximum Complexity",
106
  step=1,
107
  )
108
+ self.parsimony = gr.Number(
109
  value=0.0032,
110
  label="Parsimony Coefficient",
111
  )
112
+
113
+
114
+ class AdvancedSettings:
115
+ def __init__(self) -> None:
116
+ self.populations = gr.Slider(
117
  minimum=2,
118
  maximum=100,
119
  value=15,
120
  label="Number of Populations",
121
  step=1,
122
  )
123
+ self.population_size = gr.Slider(
124
  minimum=2,
125
  maximum=1000,
126
  value=33,
127
  label="Population Size",
128
  step=1,
129
  )
130
+ self.ncycles_per_iteration = gr.Number(
131
  value=550,
132
  label="Cycles per Iteration",
133
  )
134
+ self.elementwise_loss = gr.Radio(
135
  ["L2DistLoss()", "L1DistLoss()", "LogitDistLoss()", "HuberLoss()"],
136
  value="L2DistLoss()",
137
  label="Loss Function",
138
  )
139
+ self.adaptive_parsimony_scaling = gr.Number(
140
  value=20.0,
141
  label="Adaptive Parsimony Scaling",
142
  )
143
+ self.optimizer_algorithm = gr.Radio(
144
  ["BFGS", "NelderMead"],
145
  value="BFGS",
146
  label="Optimizer Algorithm",
147
  )
148
+ self.optimizer_iterations = gr.Slider(
149
  minimum=1,
150
  maximum=100,
151
  value=8,
 
153
  step=1,
154
  )
155
  # Bool:
156
+ self.batching = gr.Checkbox(
157
  value=False,
158
  label="Batching",
159
  )
160
+ self.batch_size = gr.Slider(
161
  minimum=2,
162
  maximum=1000,
163
  value=50,
 
165
  step=1,
166
  )
167
 
168
+
169
+ class GradioSettings:
170
+ def __init__(self) -> None:
171
+ self.plot_update_delay = gr.Slider(
172
  minimum=1,
173
  maximum=100,
174
  value=3,
175
  label="Plot Update Delay",
176
  )
177
+ self.force_run = gr.Checkbox(
178
  value=False,
179
  label="Ignore Warnings",
180
  )
181
+
182
+
183
+ class Settings:
184
+ def __init__(self):
185
+ with gr.Tab("Basic Settings"):
186
+ self.basic_settings = BasicSettings()
187
+ with gr.Tab("Advanced Settings"):
188
+ self.advanced_settings = AdvancedSettings()
189
+ with gr.Tab("Gradio Settings"):
190
+ self.gradio_settings = GradioSettings()
191
+
192
+
193
+ class Results:
194
+ def __init__(self):
195
+ with gr.Tab("Pareto Front"):
196
+ self.pareto = gr.Plot()
197
+ with gr.Tab("Predictions"):
198
+ self.predictions_plot = gr.Plot()
199
+
200
+ self.df = gr.Dataframe(
201
+ headers=["complexity", "loss", "equation"],
202
+ datatype=["number", "number", "str"],
203
+ wrap=True,
204
+ column_widths=[75, 75, 200],
205
+ interactive=False,
206
+ )
207
+
208
+
209
+ def flatten_attributes(component_group, absolute_name: str, d=None) -> OrderedDict:
210
+ if d is None:
211
+ d = OrderedDict()
212
+
213
+ if not hasattr(component_group, "__dict__"):
214
+ return d
215
+
216
+ for name, elem in component_group.__dict__.items():
217
+ new_absolute_name = absolute_name + "." + name
218
+ if name.startswith("_"):
219
+ # Private attribute
220
+ continue
221
+ elif elem in component_group.__dict__.values():
222
+ # Don't duplicate any tiems
223
+ continue
224
+ elif isinstance(elem, Component):
225
+ # Only add components to dict
226
+ d[new_absolute_name] = elem
227
+ else:
228
+ d = flatten_attributes(elem, new_absolute_name, d=d)
229
+
230
+ return d
231
+
232
+
233
+ class AppInterface:
234
+ def __init__(self, demo: gr.Blocks) -> None:
235
  with gr.Row():
236
  with gr.Column():
237
  with gr.Row():
238
+ self.data = Data(demo)
239
  with gr.Row():
240
+ self.settings = Settings()
 
241
  with gr.Column():
242
+ self.results = Results()
243
+ self.run = gr.Button()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
244
 
245
  # Update plot when dataframe is updated:
246
+ self.results.df.change(
247
  plot_pareto_curve,
248
+ inputs=[self.results.df, self.settings.basic_settings.maxsize],
249
+ outputs=[self.results.pareto],
250
  show_progress=False,
251
  )
 
252
 
253
+ self.run.click(
254
+ create_processing_function(self, ignore=["df", "predictions_plot"]),
255
+ inputs=list(flatten_attributes(self, "interface").values()),
256
+ outputs=[self.results.df, self.results.predictions_plot],
257
+ show_progress=True,
258
+ )
259
+
260
+
261
+ def create_processing_function(interface: AppInterface, ignore=[]):
262
+ d = flatten_attributes(interface, "interface")
263
+ keys = [k.split(".")[-1] for k in d.keys()]
264
+ keys = [k for k in keys if k not in ignore]
265
+ _, idx, counts = np.unique(keys, return_index=True, return_counts=True)
266
+ if np.any(counts > 1):
267
+ raise AssertionError("Bad keys: " + ",".join(np.array(keys)[idx[counts > 1]]))
268
+
269
+ def f(components):
270
+ n = len(components)
271
+ assert n == len(keys)
272
+ return processing(**{keys[i]: components[i] for i in range(n)})
273
+
274
+ return f
275
+
276
+
277
+ class App:
278
+ def __init__(self, theme="default") -> None:
279
+ with gr.Blocks(theme=theme) as demo:
280
+ self.interface = AppInterface(demo)
281
+
282
+ demo.launch(debug=True)
283
 
284
 
285
  if __name__ == "__main__":
286
+ app = App()
gui/data.py CHANGED
@@ -1,7 +1,7 @@
1
  import numpy as np
2
  import pandas as pd
3
 
4
- test_equations = ["sin(2*x)/x + 0.1*x"]
5
 
6
 
7
  def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
 
1
  import numpy as np
2
  import pandas as pd
3
 
4
+ TEST_EQUATIONS = ["sin(2*x)/x + 0.1*x"]
5
 
6
 
7
  def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
gui/processing.py CHANGED
@@ -92,6 +92,7 @@ PERSISTENT_READER = None
92
 
93
 
94
  def processing(
 
95
  file_input,
96
  force_run,
97
  test_equation,
@@ -113,6 +114,7 @@ def processing(
113
  optimizer_iterations,
114
  batching,
115
  batch_size,
 
116
  ):
117
  """Load data, then spawn a process to run the greet function."""
118
  global PERSISTENT_WRITER
 
92
 
93
 
94
  def processing(
95
+ *,
96
  file_input,
97
  force_run,
98
  test_equation,
 
114
  optimizer_iterations,
115
  batching,
116
  batch_size,
117
+ **kwargs,
118
  ):
119
  """Load data, then spawn a process to run the greet function."""
120
  global PERSISTENT_WRITER