Spaces:
Running
Running
MilesCranmer
commited on
Merge pull request #46 from MilesCranmer/multi-output
Browse files- Project.toml +1 -1
- pysr/sr.py +123 -61
- test/test.py +4 -3
Project.toml
CHANGED
@@ -2,5 +2,5 @@
|
|
2 |
SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
|
3 |
|
4 |
[compat]
|
5 |
-
SymbolicRegression = "0.
|
6 |
julia = "1.5"
|
|
|
2 |
SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
|
3 |
|
4 |
[compat]
|
5 |
+
SymbolicRegression = "0.6.0"
|
6 |
julia = "1.5"
|
pysr/sr.py
CHANGED
@@ -19,6 +19,8 @@ global_equation_file = 'hall_of_fame.csv'
|
|
19 |
global_n_features = None
|
20 |
global_variable_names = []
|
21 |
global_extra_sympy_mappings = {}
|
|
|
|
|
22 |
|
23 |
sympy_mappings = {
|
24 |
'div': lambda x, y : x/y,
|
@@ -276,6 +278,16 @@ def pysr(X=None, y=None, weights=None,
|
|
276 |
if X is None:
|
277 |
X, y = _using_test_input(X, test, y)
|
278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
kwargs = dict(X=X, y=y, weights=weights,
|
280 |
alpha=alpha, annealing=annealing, batchSize=batchSize,
|
281 |
batching=batching, binary_operators=binary_operators,
|
@@ -309,7 +321,8 @@ def pysr(X=None, y=None, weights=None,
|
|
309 |
constraints=constraints,
|
310 |
extra_sympy_mappings=extra_sympy_mappings,
|
311 |
julia_project=julia_project, loss=loss,
|
312 |
-
output_jax_format=output_jax_format
|
|
|
313 |
|
314 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
315 |
|
@@ -358,15 +371,20 @@ def pysr(X=None, y=None, weights=None,
|
|
358 |
|
359 |
|
360 |
|
361 |
-
def _set_globals(X, equation_file, extra_sympy_mappings, variable_names,
|
|
|
362 |
global global_n_features
|
363 |
global global_equation_file
|
364 |
global global_variable_names
|
365 |
global global_extra_sympy_mappings
|
|
|
|
|
366 |
global_n_features = X.shape[1]
|
367 |
global_equation_file = equation_file
|
368 |
global_variable_names = variable_names
|
369 |
global_extra_sympy_mappings = extra_sympy_mappings
|
|
|
|
|
370 |
|
371 |
|
372 |
def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
|
@@ -393,9 +411,7 @@ def _cmd_runner(command, **kwargs):
|
|
393 |
.replace('\\r', '\r')
|
394 |
.encode(sys.stdout.encoding, errors='replace'))
|
395 |
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
|
400 |
process.stdout.close()
|
401 |
process.wait()
|
@@ -438,17 +454,35 @@ def _create_julia_files(dataset_filename, def_datasets, hyperparam_filename, de
|
|
438 |
print(f'EquationSearch(X, y, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
|
439 |
|
440 |
|
441 |
-
def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename,
|
|
|
442 |
def_datasets = """using DelimitedFiles"""
|
443 |
np.savetxt(X_filename, X.astype(np.float32), delimiter=',')
|
444 |
-
|
|
|
|
|
|
|
445 |
if weights is not None:
|
446 |
-
|
|
|
|
|
|
|
447 |
def_datasets += f"""
|
448 |
-
X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')))
|
|
|
|
|
|
|
|
|
|
|
|
|
449 |
y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')[:, 1]"""
|
|
|
450 |
if weights is not None:
|
451 |
-
|
|
|
|
|
|
|
|
|
452 |
weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')[:, 1]"""
|
453 |
return def_datasets
|
454 |
|
@@ -656,10 +690,10 @@ def _check_assertions(X, binary_operators, unary_operators, use_custom_variable_
|
|
656 |
# Check for potential errors before they happen
|
657 |
assert len(unary_operators) + len(binary_operators) > 0
|
658 |
assert len(X.shape) == 2
|
659 |
-
assert len(y.shape)
|
660 |
assert X.shape[0] == y.shape[0]
|
661 |
if weights is not None:
|
662 |
-
assert
|
663 |
assert X.shape[0] == weights.shape[0]
|
664 |
if use_custom_variable_names:
|
665 |
assert len(variable_names) == X.shape[1]
|
@@ -693,7 +727,8 @@ def run_feature_selection(X, y, select_k_features):
|
|
693 |
return selector.get_support(indices=True)
|
694 |
|
695 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
696 |
-
extra_sympy_mappings=None, output_jax_format=False,
|
|
|
697 |
"""Get the equations from a hall of fame file. If no arguments
|
698 |
entered, the ones used previously from a call to PySR will be used."""
|
699 |
|
@@ -701,99 +736,126 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
701 |
global global_equation_file
|
702 |
global global_variable_names
|
703 |
global global_extra_sympy_mappings
|
|
|
|
|
704 |
|
705 |
if equation_file is None: equation_file = global_equation_file
|
706 |
if n_features is None: n_features = global_n_features
|
707 |
if variable_names is None: variable_names = global_variable_names
|
708 |
if extra_sympy_mappings is None: extra_sympy_mappings = global_extra_sympy_mappings
|
|
|
|
|
709 |
|
710 |
global_equation_file = equation_file
|
711 |
global_n_features = n_features
|
712 |
global_variable_names = variable_names
|
713 |
global_extra_sympy_mappings = extra_sympy_mappings
|
|
|
|
|
714 |
|
715 |
try:
|
716 |
-
|
|
|
|
|
|
|
717 |
except FileNotFoundError:
|
718 |
raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
|
719 |
|
720 |
-
|
721 |
-
lastMSE = None
|
722 |
-
lastComplexity = 0
|
723 |
-
sympy_format = []
|
724 |
-
lambda_format = []
|
725 |
-
if output_jax_format:
|
726 |
-
jax_format = []
|
727 |
-
use_custom_variable_names = (len(variable_names) != 0)
|
728 |
-
local_sympy_mappings = {
|
729 |
-
**extra_sympy_mappings,
|
730 |
-
**sympy_mappings
|
731 |
-
}
|
732 |
|
733 |
-
|
734 |
-
sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(n_features)]
|
735 |
-
else:
|
736 |
-
sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
|
737 |
|
738 |
-
|
739 |
-
|
740 |
-
|
|
|
|
|
741 |
if output_jax_format:
|
742 |
-
|
743 |
-
|
744 |
-
|
745 |
-
|
746 |
-
|
747 |
-
|
748 |
-
if lastMSE is None:
|
749 |
-
cur_score = 0.0
|
750 |
-
else:
|
751 |
-
cur_score = - np.log(curMSE/lastMSE)/(curComplexity - lastComplexity)
|
752 |
|
753 |
-
|
754 |
-
|
755 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
756 |
|
757 |
-
|
758 |
-
output['sympy_format'] = sympy_format
|
759 |
-
output['lambda_format'] = lambda_format
|
760 |
-
output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
|
761 |
-
if output_jax_format:
|
762 |
-
output_cols += ['jax_format']
|
763 |
-
output['jax_format'] = jax_format
|
764 |
|
765 |
-
|
|
|
|
|
|
|
766 |
|
767 |
def best_row(equations=None):
|
768 |
"""Return the best row of a hall of fame file using the score column.
|
769 |
By default this uses the last equation file.
|
770 |
"""
|
771 |
if equations is None: equations = get_hof()
|
772 |
-
|
773 |
-
|
|
|
|
|
774 |
|
775 |
def best_tex(equations=None):
|
776 |
"""Return the equation with the best score, in latex format
|
777 |
By default this uses the last equation file.
|
778 |
"""
|
779 |
if equations is None: equations = get_hof()
|
780 |
-
|
781 |
-
|
|
|
|
|
782 |
|
783 |
def best(equations=None):
|
784 |
"""Return the equation with the best score, in sympy format.
|
785 |
By default this uses the last equation file.
|
786 |
"""
|
787 |
if equations is None: equations = get_hof()
|
788 |
-
|
789 |
-
|
|
|
|
|
790 |
|
791 |
def best_callable(equations=None):
|
792 |
"""Return the equation with the best score, in callable format.
|
793 |
By default this uses the last equation file.
|
794 |
"""
|
795 |
if equations is None: equations = get_hof()
|
796 |
-
|
|
|
|
|
|
|
797 |
|
798 |
def _escape_filename(filename):
|
799 |
"""Turns a file into a string representation with correctly escaped backslashes"""
|
|
|
19 |
global_n_features = None
|
20 |
global_variable_names = []
|
21 |
global_extra_sympy_mappings = {}
|
22 |
+
global_multioutput = False
|
23 |
+
global_nout = 1
|
24 |
|
25 |
sympy_mappings = {
|
26 |
'div': lambda x, y : x/y,
|
|
|
278 |
if X is None:
|
279 |
X, y = _using_test_input(X, test, y)
|
280 |
|
281 |
+
if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
|
282 |
+
multioutput = False
|
283 |
+
nout = 1
|
284 |
+
y = y.reshape(-1)
|
285 |
+
elif len(y.shape) == 2:
|
286 |
+
multioutput = True
|
287 |
+
nout = y.shape[1]
|
288 |
+
else:
|
289 |
+
raise NotImplementedError("y shape not supported!")
|
290 |
+
|
291 |
kwargs = dict(X=X, y=y, weights=weights,
|
292 |
alpha=alpha, annealing=annealing, batchSize=batchSize,
|
293 |
batching=batching, binary_operators=binary_operators,
|
|
|
321 |
constraints=constraints,
|
322 |
extra_sympy_mappings=extra_sympy_mappings,
|
323 |
julia_project=julia_project, loss=loss,
|
324 |
+
output_jax_format=output_jax_format,
|
325 |
+
multioutput=multioutput, nout=nout)
|
326 |
|
327 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
328 |
|
|
|
371 |
|
372 |
|
373 |
|
374 |
+
def _set_globals(X, equation_file, extra_sympy_mappings, variable_names,
|
375 |
+
multioutput, nout, **kwargs):
|
376 |
global global_n_features
|
377 |
global global_equation_file
|
378 |
global global_variable_names
|
379 |
global global_extra_sympy_mappings
|
380 |
+
global global_multioutput
|
381 |
+
global global_nout
|
382 |
global_n_features = X.shape[1]
|
383 |
global_equation_file = equation_file
|
384 |
global_variable_names = variable_names
|
385 |
global_extra_sympy_mappings = extra_sympy_mappings
|
386 |
+
global_multioutput = multioutput
|
387 |
+
global_nout = nout
|
388 |
|
389 |
|
390 |
def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
|
|
|
411 |
.replace('\\r', '\r')
|
412 |
.encode(sys.stdout.encoding, errors='replace'))
|
413 |
|
414 |
+
sys.stdout.buffer.write(decoded_line)
|
|
|
|
|
415 |
|
416 |
process.stdout.close()
|
417 |
process.wait()
|
|
|
454 |
print(f'EquationSearch(X, y, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
|
455 |
|
456 |
|
457 |
+
def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename,
|
458 |
+
multioutput, **kwargs):
|
459 |
def_datasets = """using DelimitedFiles"""
|
460 |
np.savetxt(X_filename, X.astype(np.float32), delimiter=',')
|
461 |
+
if multioutput:
|
462 |
+
np.savetxt(y_filename, y.astype(np.float32), delimiter=',')
|
463 |
+
else:
|
464 |
+
np.savetxt(y_filename, y.reshape(-1, 1).astype(np.float32), delimiter=',')
|
465 |
if weights is not None:
|
466 |
+
if multioutput:
|
467 |
+
np.savetxt(weights_filename, weights.astype(np.float32), delimiter=',')
|
468 |
+
else:
|
469 |
+
np.savetxt(weights_filename, weights.reshape(-1, 1).astype(np.float32), delimiter=',')
|
470 |
def_datasets += f"""
|
471 |
+
X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')))"""
|
472 |
+
|
473 |
+
if multioutput:
|
474 |
+
def_datasets+= f"""
|
475 |
+
y = copy(transpose(readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')))"""
|
476 |
+
else:
|
477 |
+
def_datasets+= f"""
|
478 |
y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')[:, 1]"""
|
479 |
+
|
480 |
if weights is not None:
|
481 |
+
if multioutput:
|
482 |
+
def_datasets += f"""
|
483 |
+
weights = copy(transpose(readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')))"""
|
484 |
+
else:
|
485 |
+
def_datasets += f"""
|
486 |
weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')[:, 1]"""
|
487 |
return def_datasets
|
488 |
|
|
|
690 |
# Check for potential errors before they happen
|
691 |
assert len(unary_operators) + len(binary_operators) > 0
|
692 |
assert len(X.shape) == 2
|
693 |
+
assert len(y.shape) in [1, 2]
|
694 |
assert X.shape[0] == y.shape[0]
|
695 |
if weights is not None:
|
696 |
+
assert weights.shape == y.shape
|
697 |
assert X.shape[0] == weights.shape[0]
|
698 |
if use_custom_variable_names:
|
699 |
assert len(variable_names) == X.shape[1]
|
|
|
727 |
return selector.get_support(indices=True)
|
728 |
|
729 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
730 |
+
extra_sympy_mappings=None, output_jax_format=False,
|
731 |
+
multioutput=None, nout=None, **kwargs):
|
732 |
"""Get the equations from a hall of fame file. If no arguments
|
733 |
entered, the ones used previously from a call to PySR will be used."""
|
734 |
|
|
|
736 |
global global_equation_file
|
737 |
global global_variable_names
|
738 |
global global_extra_sympy_mappings
|
739 |
+
global global_multioutput
|
740 |
+
global global_nout
|
741 |
|
742 |
if equation_file is None: equation_file = global_equation_file
|
743 |
if n_features is None: n_features = global_n_features
|
744 |
if variable_names is None: variable_names = global_variable_names
|
745 |
if extra_sympy_mappings is None: extra_sympy_mappings = global_extra_sympy_mappings
|
746 |
+
if multioutput is None: multioutput = global_multioutput
|
747 |
+
if nout is None: nout = global_nout
|
748 |
|
749 |
global_equation_file = equation_file
|
750 |
global_n_features = n_features
|
751 |
global_variable_names = variable_names
|
752 |
global_extra_sympy_mappings = extra_sympy_mappings
|
753 |
+
global_multioutput = multioutput
|
754 |
+
global_nout = nout
|
755 |
|
756 |
try:
|
757 |
+
if multioutput:
|
758 |
+
all_outputs = [pd.read_csv(f'out{i}_' + str(equation_file) + '.bkup', sep="|") for i in range(1, nout+1)]
|
759 |
+
else:
|
760 |
+
all_outputs = [pd.read_csv(str(equation_file) + '.bkup', sep="|")]
|
761 |
except FileNotFoundError:
|
762 |
raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
|
763 |
|
764 |
+
ret_outputs = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
765 |
|
766 |
+
for output in all_outputs:
|
|
|
|
|
|
|
767 |
|
768 |
+
scores = []
|
769 |
+
lastMSE = None
|
770 |
+
lastComplexity = 0
|
771 |
+
sympy_format = []
|
772 |
+
lambda_format = []
|
773 |
if output_jax_format:
|
774 |
+
jax_format = []
|
775 |
+
use_custom_variable_names = (len(variable_names) != 0)
|
776 |
+
local_sympy_mappings = {
|
777 |
+
**extra_sympy_mappings,
|
778 |
+
**sympy_mappings
|
779 |
+
}
|
|
|
|
|
|
|
|
|
780 |
|
781 |
+
if use_custom_variable_names:
|
782 |
+
sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(n_features)]
|
783 |
+
else:
|
784 |
+
sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
|
785 |
+
|
786 |
+
for i in range(len(output)):
|
787 |
+
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
788 |
+
sympy_format.append(eqn)
|
789 |
+
if output_jax_format:
|
790 |
+
func, params = sympy2jax(eqn, sympy_symbols)
|
791 |
+
jax_format.append({'callable': func, 'parameters': params})
|
792 |
+
lambda_format.append(lambdify(sympy_symbols, eqn))
|
793 |
+
curMSE = output.loc[i, 'MSE']
|
794 |
+
curComplexity = output.loc[i, 'Complexity']
|
795 |
+
|
796 |
+
if lastMSE is None:
|
797 |
+
cur_score = 0.0
|
798 |
+
else:
|
799 |
+
cur_score = - np.log(curMSE/lastMSE)/(curComplexity - lastComplexity)
|
800 |
+
|
801 |
+
scores.append(cur_score)
|
802 |
+
lastMSE = curMSE
|
803 |
+
lastComplexity = curComplexity
|
804 |
+
|
805 |
+
output['score'] = np.array(scores)
|
806 |
+
output['sympy_format'] = sympy_format
|
807 |
+
output['lambda_format'] = lambda_format
|
808 |
+
output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
|
809 |
+
if output_jax_format:
|
810 |
+
output_cols += ['jax_format']
|
811 |
+
output['jax_format'] = jax_format
|
812 |
|
813 |
+
ret_outputs.append(output[output_cols])
|
|
|
|
|
|
|
|
|
|
|
|
|
814 |
|
815 |
+
if multioutput:
|
816 |
+
return ret_outputs
|
817 |
+
else:
|
818 |
+
return ret_outputs[0]
|
819 |
|
820 |
def best_row(equations=None):
|
821 |
"""Return the best row of a hall of fame file using the score column.
|
822 |
By default this uses the last equation file.
|
823 |
"""
|
824 |
if equations is None: equations = get_hof()
|
825 |
+
if isinstance(equations, list):
|
826 |
+
return [eq.iloc[np.argmax(eq['score'])] for eq in equations]
|
827 |
+
else:
|
828 |
+
return equations.iloc[np.argmax(equations['score'])]
|
829 |
|
830 |
def best_tex(equations=None):
|
831 |
"""Return the equation with the best score, in latex format
|
832 |
By default this uses the last equation file.
|
833 |
"""
|
834 |
if equations is None: equations = get_hof()
|
835 |
+
if isinstance(equations, list):
|
836 |
+
return [sympy.latex(best_row(eq)['sympy_format'].simplify()) for eq in equations]
|
837 |
+
else:
|
838 |
+
return sympy.latex(best_row(equations)['sympy_format'].simplify())
|
839 |
|
840 |
def best(equations=None):
|
841 |
"""Return the equation with the best score, in sympy format.
|
842 |
By default this uses the last equation file.
|
843 |
"""
|
844 |
if equations is None: equations = get_hof()
|
845 |
+
if isinstance(equations, list):
|
846 |
+
return [best_row(eq)['sympy_format'].simplify() for eq in equations]
|
847 |
+
else:
|
848 |
+
return best_row(equations)['sympy_format'].simplify()
|
849 |
|
850 |
def best_callable(equations=None):
|
851 |
"""Return the equation with the best score, in callable format.
|
852 |
By default this uses the last equation file.
|
853 |
"""
|
854 |
if equations is None: equations = get_hof()
|
855 |
+
if isinstance(equations, list):
|
856 |
+
return [best_row(eq)['lambda_format'] for eq in equations]
|
857 |
+
else:
|
858 |
+
return best_row(equations)['lambda_format']
|
859 |
|
860 |
def _escape_filename(filename):
|
861 |
"""Turns a file into a string representation with correctly escaped backslashes"""
|
test/test.py
CHANGED
@@ -17,14 +17,15 @@ equations = pysr(X, y, **default_test_kwargs)
|
|
17 |
print(equations)
|
18 |
assert equations.iloc[-1]['MSE'] < 1e-4
|
19 |
|
20 |
-
print("Test 2 - test custom operator")
|
21 |
-
y = X[:, 0]**2
|
22 |
equations = pysr(X, y,
|
23 |
unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
|
24 |
extra_sympy_mappings={'square': lambda x: x**2},
|
25 |
**default_test_kwargs)
|
26 |
print(equations)
|
27 |
-
assert equations.iloc[-1]['MSE'] < 1e-4
|
|
|
28 |
|
29 |
X = np.random.randn(100, 1)
|
30 |
y = X[:, 0] + 3.0
|
|
|
17 |
print(equations)
|
18 |
assert equations.iloc[-1]['MSE'] < 1e-4
|
19 |
|
20 |
+
print("Test 2 - test custom operator, and multiple outputs")
|
21 |
+
y = X[:, [0, 1]]**2
|
22 |
equations = pysr(X, y,
|
23 |
unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
|
24 |
extra_sympy_mappings={'square': lambda x: x**2},
|
25 |
**default_test_kwargs)
|
26 |
print(equations)
|
27 |
+
assert equations[0].iloc[-1]['MSE'] < 1e-4
|
28 |
+
assert equations[1].iloc[-1]['MSE'] < 1e-4
|
29 |
|
30 |
X = np.random.randn(100, 1)
|
31 |
y = X[:, 0] + 3.0
|