# PySR / gui/processing.py
# Author: MilesCranmer
# Commit fd28328 (unverified): "wip on predictions from equations"
# raw · history · blame · 6.1 kB
import multiprocessing as mp
import os
import shutil
import tempfile
import time
import traceback
from pathlib import Path
from queue import Empty

import pandas as pd

from data import generate_data, read_csv
def EMPTY_DF():
    """Return a new, empty equations table.

    The frame has the columns ``Equation``, ``Loss`` and ``Complexity`` and
    no rows. A fresh object is built on every call, so callers may mutate it.
    """
    return pd.DataFrame(
        {
            "Equation": [],
            "Loss": [],
            "Complexity": [],
        }
    )
def pysr_fit(queue: mp.Queue, out_queue: mp.Queue):
    """Worker loop: run one PySR search per request received on ``queue``.

    Each request is a dict with ``"X"``, ``"y"`` and ``"kwargs"`` (extra
    keyword arguments for ``pysr.PySRRegressor``). A ``None`` request stops
    the loop.

    After every request, whether the fit succeeded or raised, ``None`` is
    put on ``out_queue``. The consumer polling that queue therefore always
    learns the job is over, and a failing job does not kill the worker.
    """
    import pysr

    while True:
        # Get the arguments from the queue, if available
        args = queue.get()
        if args is None:
            break

        try:
            model = pysr.PySRRegressor(
                progress=False,
                timeout_in_seconds=1000,
                **args["kwargs"],
            )
            model.fit(args["X"], args["y"])
        except Exception:
            # Top-level worker boundary: report the failure, but stay alive
            # so later jobs submitted to this process are still served.
            traceback.print_exc()
        finally:
            # Completion signal; without it the poller would wait forever.
            out_queue.put(None)
def pysr_predict(queue: mp.Queue, out_queue: mp.Queue):
    """Worker loop: evaluate a saved PySR model on request.

    Each request on ``queue`` is a dict with these keys:

    - ``"X"``: the inputs to predict on.
    - ``"equation_file"``: path to the hall-of-fame ``.csv``. The ``.pkl``
      file next to it is what gets loaded.
    - ``"complexity"``: the target complexity.

    The equation whose complexity is closest to the requested value is used,
    and its predictions are put on ``out_queue``. A ``None`` request stops
    the loop.

    If the model cannot be loaded yet (missing or empty files), ``None`` is
    put on ``out_queue`` instead, so every request gets exactly one reply.
    """
    import numpy as np
    import pysr

    while True:
        args = queue.get()
        if args is None:
            break

        X = args["X"]
        complexity = args["complexity"]
        equation_file = Path(args["equation_file"])
        directory, stem = equation_file.parent, equation_file.stem
        equation_file_pkl = directory / f"{stem}.pkl"
        equation_file_bkup = Path(f"{equation_file}.bkup")
        equation_file_copy = directory / f"{stem}_copy.csv"
        equation_file_pkl_copy = directory / f"{stem}_copy.pkl"

        # Work on snapshots of files the search process may still be writing.
        # TODO: See if there is way to get lock on file
        for source, target in (
            (equation_file_bkup, equation_file_copy),
            (equation_file_pkl, equation_file_pkl_copy),
        ):
            try:
                shutil.copyfile(source, target)
            except FileNotFoundError:
                # Best effort, like the shell `cp` this replaces: a missing
                # source leaves any earlier copy in place. Loading is
                # checked below.
                pass

        try:
            model = pysr.PySRRegressor.from_file(
                str(equation_file_pkl_copy), verbosity=0
            )
        except (FileNotFoundError, pd.errors.EmptyDataError):
            # Nothing loadable yet; still reply so the requester is not
            # left blocked waiting on out_queue.
            out_queue.put(None)
            continue

        # Positional index (argmin must be *called*) of the equation whose
        # complexity is nearest the requested value.
        index = int(np.abs(model.equations_.complexity - complexity).argmin())
        out_queue.put(model.predict(X, index))
class PySRProcess:
    """Handle to a started background process running ``pysr_fit``.

    Attributes:
        queue: Inbox for the worker. Each item is a dict with ``X``, ``y``
            and ``kwargs``, or ``None`` to stop the worker.
        out_queue: Outbox on which the worker puts ``None`` when a job ends.
        process: The ``multiprocessing.Process`` running the worker.
    """

    def __init__(self):
        inbox, outbox = mp.Queue(), mp.Queue()
        self.queue = inbox
        self.out_queue = outbox
        self.process = mp.Process(target=pysr_fit, args=(inbox, outbox))
        self.process.start()
class PySRReaderProcess:
    """Handle to a started background process running ``pysr_predict``.

    Attributes:
        queue: Inbox for the worker. Each item is a dict with ``X``,
            ``equation_file`` and ``complexity``, or ``None`` to stop it.
        out_queue: Outbox on which the worker puts its predictions.
        process: The ``multiprocessing.Process`` running the worker.
    """

    def __init__(self):
        requests = mp.Queue()
        replies = mp.Queue()
        self.queue, self.out_queue = requests, replies
        worker = mp.Process(target=pysr_predict, args=(requests, replies))
        self.process = worker
        worker.start()
# Process-wide handle to the background PySR fitting worker. It is created
# lazily by `processing` and replaced whenever a restart is needed.
PERSISTENT_WRITER = None


def _drain_queue(q: mp.Queue) -> None:
    """Discard every item currently available on ``q`` without blocking."""
    while True:
        try:
            q.get_nowait()
        except Empty:
            return


def _pareto_filter(equations: pd.DataFrame) -> pd.DataFrame:
    """Keep only rows whose loss strictly beats every simpler row.

    Rows are sorted by ascending ``Complexity``. A row is kept when its
    ``Loss`` is lower than the minimum loss among the rows before it, and the
    first row is always kept. The result has a fresh 0..n-1 index.
    """
    equations = equations.sort_values("Complexity", ascending=True).reset_index(
        drop=True
    )
    keep = []
    min_loss = None
    for loss in equations["Loss"]:
        improves = min_loss is None or loss < min_loss
        if improves:
            min_loss = float(loss)
        keep.append(improves)
    # .loc (not []) so an empty mask still selects zero rows, not zero columns.
    return equations.loc[keep]


def _read_pareto_front(source: Path, scratch: Path):
    """Snapshot ``source`` to ``scratch`` and load it as a Pareto front.

    Returns the ``Complexity``, ``Loss`` and ``Equation`` columns after
    `_pareto_filter`. Returns ``None`` when the file does not exist yet or
    is empty.
    """
    try:
        # Read a copy rather than the live file the search may be writing.
        shutil.copyfile(source, scratch)
        equations = pd.read_csv(scratch)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        return None
    # Ensure it is pareto dominated, with more complex expressions
    # having higher loss. Otherwise remove those rows.
    # TODO: Not sure why this occurs; could be the result of a late copy?
    return _pareto_filter(equations)[["Complexity", "Loss", "Equation"]]


def processing(
    file_input,
    force_run,
    test_equation,
    num_points,
    noise_level,
    data_seed,
    niterations,
    maxsize,
    binary_operators,
    unary_operators,
    plot_update_delay,
    parsimony,
    populations,
    population_size,
    ncycles_per_iteration,
    elementwise_loss,
    adaptive_parsimony_scaling,
    optimizer_algorithm,
    optimizer_iterations,
    batching,
    batch_size,
):
    """Load data, run PySR in the background worker, and stream results.

    Data comes from ``file_input`` (via ``read_csv``) when it is given.
    Otherwise it is generated from ``test_equation`` (via ``generate_data``).
    The search is submitted to the persistent `PySRProcess`.

    While the search runs, this generator yields the current hall of fame
    roughly every 0.1 s. Each yield is a DataFrame with ``Complexity``,
    ``Loss`` and ``Equation`` columns, reduced to its Pareto front. One
    final snapshot is yielded after the worker signals completion or exits.

    NOTE(review): ``plot_update_delay`` is accepted but not used; the poll
    interval is fixed at 0.1 s.
    """
    global PERSISTENT_WRITER
    if PERSISTENT_WRITER is None:
        print("Starting PySR process")
        PERSISTENT_WRITER = PySRProcess()

    if file_input is not None:
        try:
            X, y = read_csv(file_input, force_run)
        except ValueError as e:
            # NOTE(review): this function is a generator, so this tuple only
            # becomes StopIteration.value and is never seen by code that
            # iterates it. Confirm how errors are meant to reach the UI.
            return (EMPTY_DF(), str(e))
    else:
        X, y = generate_data(test_equation, num_points, noise_level, data_seed)

    with tempfile.TemporaryDirectory() as tmpdirname:
        base = Path(tmpdirname)
        equation_file = base / "hall_of_fame.csv"
        equation_file_bkup = base / "hall_of_fame.csv.bkup"
        equation_file_copy = base / "hall_of_fame_copy.csv"

        # Check if queue is empty, if not, kill the process
        # and start a new one
        if not PERSISTENT_WRITER.queue.empty():
            print("Restarting PySR process")
            if PERSISTENT_WRITER.process.is_alive():
                PERSISTENT_WRITER.process.terminate()
                PERSISTENT_WRITER.process.join()
            PERSISTENT_WRITER = PySRProcess()

        # Drop completion signals left by earlier runs that were never
        # consumed; otherwise the polling loop below would exit immediately.
        _drain_queue(PERSISTENT_WRITER.out_queue)

        # Write these to queue instead:
        PERSISTENT_WRITER.queue.put(
            dict(
                X=X,
                y=y,
                kwargs=dict(
                    niterations=niterations,
                    maxsize=maxsize,
                    binary_operators=binary_operators,
                    unary_operators=unary_operators,
                    equation_file=equation_file,
                    parsimony=parsimony,
                    populations=populations,
                    population_size=population_size,
                    ncycles_per_iteration=ncycles_per_iteration,
                    elementwise_loss=elementwise_loss,
                    adaptive_parsimony_scaling=adaptive_parsimony_scaling,
                    optimizer_algorithm=optimizer_algorithm,
                    optimizer_iterations=optimizer_iterations,
                    batching=batching,
                    batch_size=batch_size,
                ),
            )
        )

        writer = PERSISTENT_WRITER
        # The worker puts an item on out_queue when the job is finished.
        while writer.out_queue.empty():
            if not writer.process.is_alive():
                # Worker exited without signalling; stop rather than spin forever.
                break
            equations = _read_pareto_front(equation_file_bkup, equation_file_copy)
            if equations is not None:
                yield equations
            # Sleep on every iteration, including unreadable snapshots.
            time.sleep(0.1)

        # Consume the completion signal so the next call starts clean.
        _drain_queue(writer.out_queue)

        # Report the final state of the search before the temp dir is removed.
        equations = _read_pareto_front(equation_file_bkup, equation_file_copy)
        if equations is not None:
            yield equations