Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
b5b74c3
1
Parent(s):
a9c6697
[WIP] begin implementation of multi-output compatibility
Browse files- pysr/sr.py +94 -39
pysr/sr.py
CHANGED
@@ -19,6 +19,8 @@ global_equation_file = 'hall_of_fame.csv'
|
|
19 |
global_n_features = None
|
20 |
global_variable_names = []
|
21 |
global_extra_sympy_mappings = {}
|
|
|
|
|
22 |
|
23 |
sympy_mappings = {
|
24 |
'div': lambda x, y : x/y,
|
@@ -276,6 +278,15 @@ def pysr(X=None, y=None, weights=None,
|
|
276 |
if X is None:
|
277 |
X, y = _using_test_input(X, test, y)
|
278 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
kwargs = dict(X=X, y=y, weights=weights,
|
280 |
alpha=alpha, annealing=annealing, batchSize=batchSize,
|
281 |
batching=batching, binary_operators=binary_operators,
|
@@ -309,7 +320,8 @@ def pysr(X=None, y=None, weights=None,
|
|
309 |
constraints=constraints,
|
310 |
extra_sympy_mappings=extra_sympy_mappings,
|
311 |
julia_project=julia_project, loss=loss,
|
312 |
-
output_jax_format=output_jax_format
|
|
|
313 |
|
314 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
315 |
|
@@ -358,15 +370,20 @@ def pysr(X=None, y=None, weights=None,
|
|
358 |
|
359 |
|
360 |
|
361 |
-
def _set_globals(X, equation_file, extra_sympy_mappings, variable_names,
|
|
|
362 |
global global_n_features
|
363 |
global global_equation_file
|
364 |
global global_variable_names
|
365 |
global global_extra_sympy_mappings
|
|
|
|
|
366 |
global_n_features = X.shape[1]
|
367 |
global_equation_file = equation_file
|
368 |
global_variable_names = variable_names
|
369 |
global_extra_sympy_mappings = extra_sympy_mappings
|
|
|
|
|
370 |
|
371 |
|
372 |
def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
|
@@ -438,17 +455,35 @@ def _create_julia_files(dataset_filename, def_datasets, hyperparam_filename, de
|
|
438 |
print(f'EquationSearch(X, y, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
|
439 |
|
440 |
|
441 |
-
def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename,
|
|
|
442 |
def_datasets = """using DelimitedFiles"""
|
443 |
np.savetxt(X_filename, X.astype(np.float32), delimiter=',')
|
444 |
-
|
|
|
|
|
|
|
445 |
if weights is not None:
|
446 |
-
|
|
|
|
|
|
|
447 |
def_datasets += f"""
|
448 |
-
X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')))
|
|
|
|
|
|
|
|
|
|
|
|
|
449 |
y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')[:, 1]"""
|
|
|
450 |
if weights is not None:
|
451 |
-
|
|
|
|
|
|
|
|
|
452 |
weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')[:, 1]"""
|
453 |
return def_datasets
|
454 |
|
@@ -656,10 +691,10 @@ def _check_assertions(X, binary_operators, unary_operators, use_custom_variable_
|
|
656 |
# Check for potential errors before they happen
|
657 |
assert len(unary_operators) + len(binary_operators) > 0
|
658 |
assert len(X.shape) == 2
|
659 |
-
assert len(y.shape)
|
660 |
assert X.shape[0] == y.shape[0]
|
661 |
if weights is not None:
|
662 |
-
assert
|
663 |
assert X.shape[0] == weights.shape[0]
|
664 |
if use_custom_variable_names:
|
665 |
assert len(variable_names) == X.shape[1]
|
@@ -693,7 +728,8 @@ def run_feature_selection(X, y, select_k_features):
|
|
693 |
return selector.get_support(indices=True)
|
694 |
|
695 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
696 |
-
extra_sympy_mappings=None, output_jax_format=False,
|
|
|
697 |
"""Get the equations from a hall of fame file. If no arguments
|
698 |
entered, the ones used previously from a call to PySR will be used."""
|
699 |
|
@@ -701,19 +737,28 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
701 |
global global_equation_file
|
702 |
global global_variable_names
|
703 |
global global_extra_sympy_mappings
|
|
|
|
|
704 |
|
705 |
if equation_file is None: equation_file = global_equation_file
|
706 |
if n_features is None: n_features = global_n_features
|
707 |
if variable_names is None: variable_names = global_variable_names
|
708 |
if extra_sympy_mappings is None: extra_sympy_mappings = global_extra_sympy_mappings
|
|
|
|
|
709 |
|
710 |
global_equation_file = equation_file
|
711 |
global_n_features = n_features
|
712 |
global_variable_names = variable_names
|
713 |
global_extra_sympy_mappings = extra_sympy_mappings
|
|
|
|
|
714 |
|
715 |
try:
|
716 |
-
|
|
|
|
|
|
|
717 |
except FileNotFoundError:
|
718 |
raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
|
719 |
|
@@ -735,42 +780,52 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
735 |
else:
|
736 |
sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
|
737 |
|
738 |
-
|
739 |
-
|
740 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
741 |
if output_jax_format:
|
742 |
-
|
743 |
-
jax_format
|
744 |
-
lambda_format.append(lambdify(sympy_symbols, eqn))
|
745 |
-
curMSE = output.loc[i, 'MSE']
|
746 |
-
curComplexity = output.loc[i, 'Complexity']
|
747 |
-
|
748 |
-
if lastMSE is None:
|
749 |
-
cur_score = 0.0
|
750 |
-
else:
|
751 |
-
cur_score = - np.log(curMSE/lastMSE)/(curComplexity - lastComplexity)
|
752 |
|
753 |
-
|
754 |
-
lastMSE = curMSE
|
755 |
-
lastComplexity = curComplexity
|
756 |
-
|
757 |
-
output['score'] = np.array(scores)
|
758 |
-
output['sympy_format'] = sympy_format
|
759 |
-
output['lambda_format'] = lambda_format
|
760 |
-
output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
|
761 |
-
if output_jax_format:
|
762 |
-
output_cols += ['jax_format']
|
763 |
-
output['jax_format'] = jax_format
|
764 |
|
765 |
-
|
|
|
|
|
|
|
766 |
|
767 |
def best_row(equations=None):
|
768 |
"""Return the best row of a hall of fame file using the score column.
|
769 |
By default this uses the last equation file.
|
770 |
"""
|
771 |
-
if equations is None:
|
772 |
-
|
773 |
-
|
|
|
|
|
|
|
774 |
|
775 |
def best_tex(equations=None):
|
776 |
"""Return the equation with the best score, in latex format
|
|
|
19 |
global_n_features = None
|
20 |
global_variable_names = []
|
21 |
global_extra_sympy_mappings = {}
|
22 |
+
global_multioutput = False
|
23 |
+
global_nout = 1
|
24 |
|
25 |
sympy_mappings = {
|
26 |
'div': lambda x, y : x/y,
|
|
|
278 |
if X is None:
|
279 |
X, y = _using_test_input(X, test, y)
|
280 |
|
281 |
+
if len(y.shape) == 2:
|
282 |
+
multioutput = True
|
283 |
+
nout = y.shape[1]
|
284 |
+
elif len(y.shape) == 1:
|
285 |
+
multioutput = False
|
286 |
+
nout = 1
|
287 |
+
else:
|
288 |
+
raise NotImplementedError("y shape not supported!")
|
289 |
+
|
290 |
kwargs = dict(X=X, y=y, weights=weights,
|
291 |
alpha=alpha, annealing=annealing, batchSize=batchSize,
|
292 |
batching=batching, binary_operators=binary_operators,
|
|
|
320 |
constraints=constraints,
|
321 |
extra_sympy_mappings=extra_sympy_mappings,
|
322 |
julia_project=julia_project, loss=loss,
|
323 |
+
output_jax_format=output_jax_format,
|
324 |
+
multioutput=multioutput, nout=nout)
|
325 |
|
326 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
327 |
|
|
|
370 |
|
371 |
|
372 |
|
373 |
+
def _set_globals(X, equation_file, extra_sympy_mappings, variable_names,
|
374 |
+
multioutput, **kwargs):
|
375 |
global global_n_features
|
376 |
global global_equation_file
|
377 |
global global_variable_names
|
378 |
global global_extra_sympy_mappings
|
379 |
+
global global_multioutput
|
380 |
+
global global_nout
|
381 |
global_n_features = X.shape[1]
|
382 |
global_equation_file = equation_file
|
383 |
global_variable_names = variable_names
|
384 |
global_extra_sympy_mappings = extra_sympy_mappings
|
385 |
+
global_multioutput = multioutput
|
386 |
+
global_nout = nout
|
387 |
|
388 |
|
389 |
def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
|
|
|
455 |
print(f'EquationSearch(X, y, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
|
456 |
|
457 |
|
458 |
+
def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename,
|
459 |
+
multioutput, **kwargs):
|
460 |
def_datasets = """using DelimitedFiles"""
|
461 |
np.savetxt(X_filename, X.astype(np.float32), delimiter=',')
|
462 |
+
if multioutput:
|
463 |
+
np.savetxt(y_filename, y.astype(np.float32), delimiter=',')
|
464 |
+
else:
|
465 |
+
np.savetxt(y_filename, y.reshape(-1, 1).astype(np.float32), delimiter=',')
|
466 |
if weights is not None:
|
467 |
+
if multioutput:
|
468 |
+
np.savetxt(weights_filename, weights.astype(np.float32), delimiter=',')
|
469 |
+
else:
|
470 |
+
np.savetxt(weights_filename, weights.reshape(-1, 1).astype(np.float32), delimiter=',')
|
471 |
def_datasets += f"""
|
472 |
+
X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')))"""
|
473 |
+
|
474 |
+
if multioutput:
|
475 |
+
def_datasets+= f"""
|
476 |
+
y = copy(transpose(readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')))"""
|
477 |
+
else:
|
478 |
+
def_datasets+= f"""
|
479 |
y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')[:, 1]"""
|
480 |
+
|
481 |
if weights is not None:
|
482 |
+
if multioutput:
|
483 |
+
def_datasets += f"""
|
484 |
+
weights = copy(transpose(readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')))"""
|
485 |
+
else:
|
486 |
+
def_datasets += f"""
|
487 |
weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')[:, 1]"""
|
488 |
return def_datasets
|
489 |
|
|
|
691 |
# Check for potential errors before they happen
|
692 |
assert len(unary_operators) + len(binary_operators) > 0
|
693 |
assert len(X.shape) == 2
|
694 |
+
assert len(y.shape) in [1, 2]
|
695 |
assert X.shape[0] == y.shape[0]
|
696 |
if weights is not None:
|
697 |
+
assert weights.shape == y.shape
|
698 |
assert X.shape[0] == weights.shape[0]
|
699 |
if use_custom_variable_names:
|
700 |
assert len(variable_names) == X.shape[1]
|
|
|
728 |
return selector.get_support(indices=True)
|
729 |
|
730 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
731 |
+
extra_sympy_mappings=None, output_jax_format=False,
|
732 |
+
multioutput=False, nout=1, **kwargs):
|
733 |
"""Get the equations from a hall of fame file. If no arguments
|
734 |
entered, the ones used previously from a call to PySR will be used."""
|
735 |
|
|
|
737 |
global global_equation_file
|
738 |
global global_variable_names
|
739 |
global global_extra_sympy_mappings
|
740 |
+
global global_multioutput
|
741 |
+
global global_nout
|
742 |
|
743 |
if equation_file is None: equation_file = global_equation_file
|
744 |
if n_features is None: n_features = global_n_features
|
745 |
if variable_names is None: variable_names = global_variable_names
|
746 |
if extra_sympy_mappings is None: extra_sympy_mappings = global_extra_sympy_mappings
|
747 |
+
if multioutput is None: multioutput = global_multioutput
|
748 |
+
if nout is None: nout = global_nout
|
749 |
|
750 |
global_equation_file = equation_file
|
751 |
global_n_features = n_features
|
752 |
global_variable_names = variable_names
|
753 |
global_extra_sympy_mappings = extra_sympy_mappings
|
754 |
+
global_multioutput = multioutput
|
755 |
+
global_nout = nout
|
756 |
|
757 |
try:
|
758 |
+
if multioutput:
|
759 |
+
all_outputs = [pd.read_csv(f'out{i}_' + str(equation_file) + '.bkup', sep="|") for i in range(1, nout+1)]
|
760 |
+
else:
|
761 |
+
all_outputs = [pd.read_csv(str(equation_file) + '.bkup', sep="|")]
|
762 |
except FileNotFoundError:
|
763 |
raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
|
764 |
|
|
|
780 |
else:
|
781 |
sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
|
782 |
|
783 |
+
ret_outputs = []
|
784 |
+
for output in all_outputs:
|
785 |
+
for i in range(len(output)):
|
786 |
+
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
787 |
+
sympy_format.append(eqn)
|
788 |
+
if output_jax_format:
|
789 |
+
func, params = sympy2jax(eqn, sympy_symbols)
|
790 |
+
jax_format.append({'callable': func, 'parameters': params})
|
791 |
+
lambda_format.append(lambdify(sympy_symbols, eqn))
|
792 |
+
curMSE = output.loc[i, 'MSE']
|
793 |
+
curComplexity = output.loc[i, 'Complexity']
|
794 |
+
|
795 |
+
if lastMSE is None:
|
796 |
+
cur_score = 0.0
|
797 |
+
else:
|
798 |
+
cur_score = - np.log(curMSE/lastMSE)/(curComplexity - lastComplexity)
|
799 |
+
|
800 |
+
scores.append(cur_score)
|
801 |
+
lastMSE = curMSE
|
802 |
+
lastComplexity = curComplexity
|
803 |
+
|
804 |
+
output['score'] = np.array(scores)
|
805 |
+
output['sympy_format'] = sympy_format
|
806 |
+
output['lambda_format'] = lambda_format
|
807 |
+
output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
|
808 |
if output_jax_format:
|
809 |
+
output_cols += ['jax_format']
|
810 |
+
output['jax_format'] = jax_format
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
811 |
|
812 |
+
ret_outputs.append(output[output_cols])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
813 |
|
814 |
+
if multioutput:
|
815 |
+
return ret_outputs
|
816 |
+
else:
|
817 |
+
return ret_outputs[0]
|
818 |
|
819 |
def best_row(equations=None):
|
820 |
"""Return the best row of a hall of fame file using the score column.
|
821 |
By default this uses the last equation file.
|
822 |
"""
|
823 |
+
if equations is None: all_eqs = get_hof()
|
824 |
+
if isinstance(all_eqs, list):
|
825 |
+
return [equations[j].iloc[np.argmax(equations[j]['score'])] for j in range(len(all_eqs))]
|
826 |
+
else:
|
827 |
+
best_idx = np.argmax(equations['score'])
|
828 |
+
return equations.iloc[best_idx]
|
829 |
|
830 |
def best_tex(equations=None):
|
831 |
"""Return the equation with the best score, in latex format
|