Spaces:
Running
Running
MilesCranmer
committed on
Commit
•
9d6017e
1
Parent(s):
f751163
Add setting for plot update rate
Browse files
- gui/app.py +91 -77
gui/app.py
CHANGED
@@ -5,6 +5,7 @@ import pandas as pd
|
|
5 |
import time
|
6 |
import multiprocessing as mp
|
7 |
from matplotlib import pyplot as plt
|
|
|
8 |
plt.ioff()
|
9 |
import tempfile
|
10 |
from typing import Optional, Union
|
@@ -18,9 +19,7 @@ empty_df = pd.DataFrame(
|
|
18 |
}
|
19 |
)
|
20 |
|
21 |
-
test_equations = [
|
22 |
-
"sin(2*x)/x + 0.1*x"
|
23 |
-
]
|
24 |
|
25 |
|
26 |
def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
|
@@ -52,7 +51,7 @@ def _greet_dispatch(
|
|
52 |
maxsize,
|
53 |
binary_operators,
|
54 |
unary_operators,
|
55 |
-
|
56 |
):
|
57 |
"""Load data, then spawn a process to run the greet function."""
|
58 |
if file_input is not None:
|
@@ -96,7 +95,6 @@ def _greet_dispatch(
|
|
96 |
maxsize=maxsize,
|
97 |
binary_operators=binary_operators,
|
98 |
unary_operators=unary_operators,
|
99 |
-
seed=seed,
|
100 |
equation_file=equation_file,
|
101 |
),
|
102 |
)
|
@@ -123,7 +121,10 @@ def _greet_dispatch(
|
|
123 |
bad_idx.append(i)
|
124 |
equations.drop(index=bad_idx, inplace=True)
|
125 |
|
126 |
-
while
|
|
|
|
|
|
|
127 |
time.sleep(0.1)
|
128 |
|
129 |
yield equations[["Complexity", "Loss", "Equation"]]
|
@@ -132,7 +133,6 @@ def _greet_dispatch(
|
|
132 |
except pd.errors.EmptyDataError:
|
133 |
pass
|
134 |
|
135 |
-
|
136 |
process.join()
|
137 |
|
138 |
|
@@ -144,7 +144,6 @@ def greet(
|
|
144 |
maxsize: int,
|
145 |
binary_operators: list,
|
146 |
unary_operators: list,
|
147 |
-
seed: int,
|
148 |
equation_file: Union[str, Path],
|
149 |
):
|
150 |
import pysr
|
@@ -180,7 +179,9 @@ def _data_layout():
|
|
180 |
label="Number of Data Points",
|
181 |
step=1,
|
182 |
)
|
183 |
-
noise_level = gr.Slider(
|
|
|
|
|
184 |
data_seed = gr.Number(value=0, label="Random Seed")
|
185 |
with gr.Tab("Upload Data"):
|
186 |
file_input = gr.File(label="Upload a CSV File")
|
@@ -199,55 +200,59 @@ def _data_layout():
|
|
199 |
|
200 |
|
201 |
def _settings_layout():
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
|
215 |
-
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
236 |
-
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
241 |
-
|
242 |
-
|
243 |
-
|
|
|
|
|
|
|
|
|
244 |
return dict(
|
245 |
binary_operators=binary_operators,
|
246 |
unary_operators=unary_operators,
|
247 |
niterations=niterations,
|
248 |
maxsize=maxsize,
|
249 |
force_run=force_run,
|
250 |
-
|
251 |
)
|
252 |
|
253 |
|
@@ -286,7 +291,7 @@ def main():
|
|
286 |
"maxsize",
|
287 |
"binary_operators",
|
288 |
"unary_operators",
|
289 |
-
"
|
290 |
]
|
291 |
],
|
292 |
outputs=blocks["df"],
|
@@ -302,7 +307,6 @@ def main():
|
|
302 |
for eqn_component in eqn_components:
|
303 |
eqn_component.change(replot, eqn_components, blocks["example_plot"])
|
304 |
|
305 |
-
|
306 |
# Update plot when dataframe is updated:
|
307 |
blocks["df"].change(
|
308 |
replot_pareto,
|
@@ -313,60 +317,70 @@ def main():
|
|
313 |
|
314 |
demo.launch(debug=True)
|
315 |
|
|
|
316 |
def replot_pareto(df, maxsize):
|
317 |
-
plt.rcParams[
|
318 |
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
|
319 |
|
320 |
-
if len(df) == 0 or
|
321 |
return fig
|
322 |
|
323 |
# Plotting the data
|
324 |
-
ax.loglog(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
325 |
|
326 |
# Set the axis limits
|
327 |
ax.set_xlim(0.5, maxsize + 1)
|
328 |
-
ytop = 2 ** (np.ceil(np.log2(df[
|
329 |
-
ybottom = 2 ** (np.floor(np.log2(df[
|
330 |
ax.set_ylim(ybottom, ytop)
|
331 |
|
332 |
-
ax.grid(True, which="both", ls="--", linewidth=0.5, color=
|
333 |
-
ax.spines[
|
334 |
-
ax.spines[
|
335 |
|
336 |
# Range-frame the plot
|
337 |
-
for direction in [
|
338 |
-
ax.spines[direction].set_position((
|
339 |
|
340 |
# Delete far ticks
|
341 |
-
ax.tick_params(axis=
|
342 |
-
ax.tick_params(axis=
|
343 |
|
344 |
-
ax.set_xlabel(
|
345 |
-
ax.set_ylabel(
|
346 |
fig.tight_layout(pad=2)
|
347 |
|
348 |
return fig
|
349 |
|
|
|
350 |
def replot(test_equation, num_points, noise_level, data_seed):
|
351 |
X, y = generate_data(test_equation, num_points, noise_level, data_seed)
|
352 |
x = X["x"]
|
353 |
|
354 |
-
plt.rcParams[
|
355 |
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
|
356 |
|
357 |
-
ax.scatter(x, y, alpha=0.7, edgecolors=
|
358 |
|
359 |
-
ax.grid(True, which="both", ls="--", linewidth=0.5, color=
|
360 |
-
ax.spines[
|
361 |
-
ax.spines[
|
362 |
|
363 |
# Range-frame the plot
|
364 |
-
for direction in [
|
365 |
-
ax.spines[direction].set_position((
|
366 |
|
367 |
# Delete far ticks
|
368 |
-
ax.tick_params(axis=
|
369 |
-
ax.tick_params(axis=
|
370 |
|
371 |
ax.set_xlabel("x")
|
372 |
ax.set_ylabel("y")
|
|
|
5 |
import time
|
6 |
import multiprocessing as mp
|
7 |
from matplotlib import pyplot as plt
|
8 |
+
|
9 |
plt.ioff()
|
10 |
import tempfile
|
11 |
from typing import Optional, Union
|
|
|
19 |
}
|
20 |
)
|
21 |
|
22 |
+
test_equations = ["sin(2*x)/x + 0.1*x"]
|
|
|
|
|
23 |
|
24 |
|
25 |
def generate_data(s: str, num_points: int, noise_level: float, data_seed: int):
|
|
|
51 |
maxsize,
|
52 |
binary_operators,
|
53 |
unary_operators,
|
54 |
+
plot_update_delay,
|
55 |
):
|
56 |
"""Load data, then spawn a process to run the greet function."""
|
57 |
if file_input is not None:
|
|
|
95 |
maxsize=maxsize,
|
96 |
binary_operators=binary_operators,
|
97 |
unary_operators=unary_operators,
|
|
|
98 |
equation_file=equation_file,
|
99 |
),
|
100 |
)
|
|
|
121 |
bad_idx.append(i)
|
122 |
equations.drop(index=bad_idx, inplace=True)
|
123 |
|
124 |
+
while (
|
125 |
+
last_yield_time is not None
|
126 |
+
and time.time() - last_yield_time < plot_update_delay
|
127 |
+
):
|
128 |
time.sleep(0.1)
|
129 |
|
130 |
yield equations[["Complexity", "Loss", "Equation"]]
|
|
|
133 |
except pd.errors.EmptyDataError:
|
134 |
pass
|
135 |
|
|
|
136 |
process.join()
|
137 |
|
138 |
|
|
|
144 |
maxsize: int,
|
145 |
binary_operators: list,
|
146 |
unary_operators: list,
|
|
|
147 |
equation_file: Union[str, Path],
|
148 |
):
|
149 |
import pysr
|
|
|
179 |
label="Number of Data Points",
|
180 |
step=1,
|
181 |
)
|
182 |
+
noise_level = gr.Slider(
|
183 |
+
minimum=0, maximum=1, value=0.05, label="Noise Level"
|
184 |
+
)
|
185 |
data_seed = gr.Number(value=0, label="Random Seed")
|
186 |
with gr.Tab("Upload Data"):
|
187 |
file_input = gr.File(label="Upload a CSV File")
|
|
|
200 |
|
201 |
|
202 |
def _settings_layout():
|
203 |
+
with gr.Tab("Basic Settings"):
|
204 |
+
binary_operators = gr.CheckboxGroup(
|
205 |
+
choices=["+", "-", "*", "/", "^"],
|
206 |
+
label="Binary Operators",
|
207 |
+
value=["+", "-", "*", "/"],
|
208 |
+
)
|
209 |
+
unary_operators = gr.CheckboxGroup(
|
210 |
+
choices=[
|
211 |
+
"sin",
|
212 |
+
"cos",
|
213 |
+
"exp",
|
214 |
+
"log",
|
215 |
+
"square",
|
216 |
+
"cube",
|
217 |
+
"sqrt",
|
218 |
+
"abs",
|
219 |
+
"tan",
|
220 |
+
],
|
221 |
+
label="Unary Operators",
|
222 |
+
value=["sin"],
|
223 |
+
)
|
224 |
+
niterations = gr.Slider(
|
225 |
+
minimum=1,
|
226 |
+
maximum=1000,
|
227 |
+
value=40,
|
228 |
+
label="Number of Iterations",
|
229 |
+
step=1,
|
230 |
+
)
|
231 |
+
maxsize = gr.Slider(
|
232 |
+
minimum=7,
|
233 |
+
maximum=35,
|
234 |
+
value=20,
|
235 |
+
label="Maximum Complexity",
|
236 |
+
step=1,
|
237 |
+
)
|
238 |
+
force_run = gr.Checkbox(
|
239 |
+
value=False,
|
240 |
+
label="Ignore Warnings",
|
241 |
+
)
|
242 |
+
with gr.Tab("Gradio Settings"):
|
243 |
+
plot_update_delay = gr.Slider(
|
244 |
+
minimum=1,
|
245 |
+
maximum=100,
|
246 |
+
value=3,
|
247 |
+
label="Plot Update Delay",
|
248 |
+
)
|
249 |
return dict(
|
250 |
binary_operators=binary_operators,
|
251 |
unary_operators=unary_operators,
|
252 |
niterations=niterations,
|
253 |
maxsize=maxsize,
|
254 |
force_run=force_run,
|
255 |
+
plot_update_delay=plot_update_delay,
|
256 |
)
|
257 |
|
258 |
|
|
|
291 |
"maxsize",
|
292 |
"binary_operators",
|
293 |
"unary_operators",
|
294 |
+
"plot_update_delay",
|
295 |
]
|
296 |
],
|
297 |
outputs=blocks["df"],
|
|
|
307 |
for eqn_component in eqn_components:
|
308 |
eqn_component.change(replot, eqn_components, blocks["example_plot"])
|
309 |
|
|
|
310 |
# Update plot when dataframe is updated:
|
311 |
blocks["df"].change(
|
312 |
replot_pareto,
|
|
|
317 |
|
318 |
demo.launch(debug=True)
|
319 |
|
320 |
+
|
321 |
def replot_pareto(df, maxsize):
|
322 |
+
plt.rcParams["font.family"] = "IBM Plex Mono"
|
323 |
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
|
324 |
|
325 |
+
if len(df) == 0 or "Equation" not in df.columns:
|
326 |
return fig
|
327 |
|
328 |
# Plotting the data
|
329 |
+
ax.loglog(
|
330 |
+
df["Complexity"],
|
331 |
+
df["Loss"],
|
332 |
+
marker="o",
|
333 |
+
linestyle="-",
|
334 |
+
color="#333f48",
|
335 |
+
linewidth=1.5,
|
336 |
+
markersize=6,
|
337 |
+
)
|
338 |
|
339 |
# Set the axis limits
|
340 |
ax.set_xlim(0.5, maxsize + 1)
|
341 |
+
ytop = 2 ** (np.ceil(np.log2(df["Loss"].max())))
|
342 |
+
ybottom = 2 ** (np.floor(np.log2(df["Loss"].min() + 1e-20)))
|
343 |
ax.set_ylim(ybottom, ytop)
|
344 |
|
345 |
+
ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
|
346 |
+
ax.spines["top"].set_visible(False)
|
347 |
+
ax.spines["right"].set_visible(False)
|
348 |
|
349 |
# Range-frame the plot
|
350 |
+
for direction in ["bottom", "left"]:
|
351 |
+
ax.spines[direction].set_position(("outward", 10))
|
352 |
|
353 |
# Delete far ticks
|
354 |
+
ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
|
355 |
+
ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
|
356 |
|
357 |
+
ax.set_xlabel("Complexity")
|
358 |
+
ax.set_ylabel("Loss")
|
359 |
fig.tight_layout(pad=2)
|
360 |
|
361 |
return fig
|
362 |
|
363 |
+
|
364 |
def replot(test_equation, num_points, noise_level, data_seed):
|
365 |
X, y = generate_data(test_equation, num_points, noise_level, data_seed)
|
366 |
x = X["x"]
|
367 |
|
368 |
+
plt.rcParams["font.family"] = "IBM Plex Mono"
|
369 |
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)
|
370 |
|
371 |
+
ax.scatter(x, y, alpha=0.7, edgecolors="w", s=50)
|
372 |
|
373 |
+
ax.grid(True, which="both", ls="--", linewidth=0.5, color="gray", alpha=0.5)
|
374 |
+
ax.spines["top"].set_visible(False)
|
375 |
+
ax.spines["right"].set_visible(False)
|
376 |
|
377 |
# Range-frame the plot
|
378 |
+
for direction in ["bottom", "left"]:
|
379 |
+
ax.spines[direction].set_position(("outward", 10))
|
380 |
|
381 |
# Delete far ticks
|
382 |
+
ax.tick_params(axis="both", which="major", labelsize=10, direction="out", length=5)
|
383 |
+
ax.tick_params(axis="both", which="minor", labelsize=8, direction="out", length=3)
|
384 |
|
385 |
ax.set_xlabel("x")
|
386 |
ax.set_ylabel("y")
|