MilesCranmer commited on
Commit
46fdaa6
1 Parent(s): 73042d9

Automatically plot test data

Browse files
Files changed (2) hide show
  1. gui/app.py +111 -70
  2. gui/requirements.txt +1 -1
gui/app.py CHANGED
@@ -1,6 +1,5 @@
1
  import gradio as gr
2
  import numpy as np
3
- import os
4
  import pandas as pd
5
  import pysr
6
  import tempfile
@@ -14,18 +13,13 @@ empty_df = pd.DataFrame(
14
  }
15
  )
16
 
17
- test_equations = {
18
- "Complex Polynomial": "3*x^3 + 2*x^2 - x + sin(x)",
19
- "Exponential and Logarithmic": "exp(-x) + log(x+1)",
20
- "Trigonometric Polynomial": "sin(x) + cos(2*x) + tan(x/3)",
21
- "Mixed Functions": "sqrt(x)*exp(-x) + cos(pi*x)",
22
- "Rational Function": "(x^2 + 1) / (x - 2)",
23
- }
24
 
25
 
26
- def generate_data(equation: str, num_points: int, noise_level: float):
27
- x = np.linspace(-10, 10, num_points)
28
- s = test_equations[equation]
29
  for (k, v) in {
30
  "sin": "np.sin",
31
  "cos": "np.cos",
@@ -117,68 +111,115 @@ model.fit(X, y)"""
117
 
118
 
119
  def main():
120
- demo = gr.Interface(
121
- fn=greet,
122
- description="Symbolic Regression with PySR. Watch search progress by following the logs.",
123
- inputs=[
124
- gr.File(label="Upload a CSV File"),
125
- gr.Radio(list(test_equations.keys()), label="Test Equation"),
126
- gr.Slider(
127
- minimum=10,
128
- maximum=1000,
129
- value=100,
130
- label="Number of Data Points",
131
- step=1,
132
- ),
133
- gr.Slider(minimum=0, maximum=1, value=0.1, label="Noise Level"),
134
- gr.Slider(
135
- minimum=1,
136
- maximum=1000,
137
- value=40,
138
- label="Number of Iterations",
139
- step=1,
140
- ),
141
- gr.Slider(
142
- minimum=7,
143
- maximum=35,
144
- value=20,
145
- label="Maximum Complexity",
146
- step=1,
147
- ),
148
- gr.CheckboxGroup(
149
- choices=["+", "-", "*", "/", "^"],
150
- label="Binary Operators",
151
- value=["+", "-", "*", "/"],
152
- ),
153
- gr.CheckboxGroup(
154
- choices=[
155
- "sin",
156
- "cos",
157
- "exp",
158
- "log",
159
- "square",
160
- "cube",
161
- "sqrt",
162
- "abs",
163
- "tan",
164
- ],
165
- label="Unary Operators",
166
- value=[],
167
- ),
168
- gr.Checkbox(
169
- value=False,
170
- label="Ignore Warnings",
171
- ),
172
- ],
173
- outputs=[
174
- "dataframe",
175
- gr.Textbox(label="Error Log"),
176
- ],
177
- )
178
- # Add file to the demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  demo.launch()
181
 
 
 
 
 
 
182
 
183
  if __name__ == "__main__":
184
  main()
 
1
  import gradio as gr
2
  import numpy as np
 
3
  import pandas as pd
4
  import pysr
5
  import tempfile
 
13
  }
14
  )
15
 
16
+ test_equations = [
17
+ "sin(x) + cos(2*x) + tan(x/3)",
18
+ ]
 
 
 
 
19
 
20
 
21
+ def generate_data(s: str, num_points: int, noise_level: float):
22
+ x = np.linspace(0, 10, num_points)
 
23
  for (k, v) in {
24
  "sin": "np.sin",
25
  "cos": "np.cos",
 
111
 
112
 
113
  def main():
114
+ with gr.Blocks() as demo:
115
+ with gr.Row():
116
+ with gr.Column():
117
+ with gr.Row():
118
+ with gr.Tab("Example Data"):
119
+ # Plot of the example data:
120
+ example_plot = gr.ScatterPlot(
121
+ x="x",
122
+ y="y",
123
+ tooltip=["x", "y"],
124
+ x_lim=[0, 10],
125
+ y_lim=[-5, 5],
126
+ width=350,
127
+ height=300,
128
+ )
129
+ test_equation = gr.Radio(
130
+ test_equations,
131
+ value=test_equations[0],
132
+ label="Test Equation"
133
+ )
134
+ num_points = gr.Slider(
135
+ minimum=10,
136
+ maximum=1000,
137
+ value=100,
138
+ label="Number of Data Points",
139
+ step=1,
140
+ )
141
+ noise_level = gr.Slider(
142
+ minimum=0, maximum=1, value=0.1, label="Noise Level"
143
+ )
144
+ with gr.Tab("Upload Data"):
145
+ file_input = gr.File(label="Upload a CSV File")
146
+ with gr.Row():
147
+ binary_operators = gr.CheckboxGroup(
148
+ choices=["+", "-", "*", "/", "^"],
149
+ label="Binary Operators",
150
+ value=["+", "-", "*", "/"],
151
+ )
152
+ unary_operators = gr.CheckboxGroup(
153
+ choices=[
154
+ "sin",
155
+ "cos",
156
+ "exp",
157
+ "log",
158
+ "square",
159
+ "cube",
160
+ "sqrt",
161
+ "abs",
162
+ "tan",
163
+ ],
164
+ label="Unary Operators",
165
+ value=[],
166
+ )
167
+ niterations = gr.Slider(
168
+ minimum=1,
169
+ maximum=1000,
170
+ value=40,
171
+ label="Number of Iterations",
172
+ step=1,
173
+ )
174
+ maxsize = gr.Slider(
175
+ minimum=7,
176
+ maximum=35,
177
+ value=20,
178
+ label="Maximum Complexity",
179
+ step=1,
180
+ )
181
+ force_run = gr.Checkbox(
182
+ value=False,
183
+ label="Ignore Warnings",
184
+ )
185
+
186
+ with gr.Column():
187
+ with gr.Row():
188
+ df = gr.Dataframe(
189
+ headers=["Equation", "Loss", "Complexity"],
190
+ datatype=["str", "number", "number"],
191
+ )
192
+ error_log = gr.Textbox(label="Error Log")
193
+ with gr.Row():
194
+ run_button = gr.Button()
195
+
196
+ run_button.click(
197
+ greet,
198
+ inputs=[
199
+ file_input,
200
+ test_equation,
201
+ num_points,
202
+ noise_level,
203
+ niterations,
204
+ maxsize,
205
+ binary_operators,
206
+ unary_operators,
207
+ force_run,
208
+ ],
209
+ outputs=[df, error_log],
210
+ )
211
+
212
+ # Any update to the equation choice will trigger a replot:
213
+ for eqn_component in [test_equation, num_points, noise_level]:
214
+ eqn_component.change(replot, [test_equation, num_points, noise_level], example_plot)
215
 
216
  demo.launch()
217
 
218
+ def replot(test_equation, num_points, noise_level):
219
+ X, y = generate_data(test_equation, num_points, noise_level)
220
+ df = pd.DataFrame({"x": X["x"], "y": y})
221
+ return df
222
+
223
 
224
  if __name__ == "__main__":
225
  main()
gui/requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
  pysr==0.18.1
2
  numpy
3
  pandas
4
- gradio
 
1
  pysr==0.18.1
2
  numpy
3
  pandas
4
+ gradio