Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
898f500
1
Parent(s):
3772652
Add mechanism for extracting JAX functions
Browse files- pysr/sr.py +18 -4
- setup.py +1 -1
pysr/sr.py
CHANGED
@@ -12,7 +12,7 @@ import shutil
|
|
12 |
from pathlib import Path
|
13 |
from datetime import datetime
|
14 |
import warnings
|
15 |
-
|
16 |
|
17 |
global_equation_file = 'hall_of_fame.csv'
|
18 |
global_n_features = None
|
@@ -106,6 +106,7 @@ def pysr(X=None, y=None, weights=None,
|
|
106 |
user_input=True,
|
107 |
update=True,
|
108 |
temp_equation_file=False,
|
|
|
109 |
warmupMaxsize=None, #Deprecated
|
110 |
):
|
111 |
"""Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
|
@@ -216,6 +217,8 @@ def pysr(X=None, y=None, weights=None,
|
|
216 |
:param temp_equation_file: Whether to put the hall of fame file in
|
217 |
the temp directory. Deletion is then controlled with the
|
218 |
delete_tempfiles argument.
|
|
|
|
|
219 |
:returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
|
220 |
(as strings).
|
221 |
|
@@ -281,7 +284,8 @@ def pysr(X=None, y=None, weights=None,
|
|
281 |
weightSimplify=weightSimplify,
|
282 |
constraints=constraints,
|
283 |
extra_sympy_mappings=extra_sympy_mappings,
|
284 |
-
julia_project=julia_project, loss=loss
|
|
|
285 |
|
286 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
287 |
|
@@ -633,7 +637,8 @@ def run_feature_selection(X, y, select_k_features):
|
|
633 |
max_features=select_k_features, prefit=True)
|
634 |
return selector.get_support(indices=True)
|
635 |
|
636 |
-
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
|
637 |
"""Get the equations from a hall of fame file. If no arguments
|
638 |
entered, the ones used previously from a call to PySR will be used."""
|
639 |
|
@@ -663,6 +668,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None, extra_symp
|
|
663 |
lastComplexity = 0
|
664 |
sympy_format = []
|
665 |
lambda_format = []
|
|
|
|
|
666 |
use_custom_variable_names = (len(variable_names) != 0)
|
667 |
local_sympy_mappings = {
|
668 |
**extra_sympy_mappings,
|
@@ -677,6 +684,9 @@ def get_hof(equation_file=None, n_features=None, variable_names=None, extra_symp
|
|
677 |
for i in range(len(output)):
|
678 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
679 |
sympy_format.append(eqn)
|
|
|
|
|
|
|
680 |
lambda_format.append(lambdify(sympy_symbols, eqn))
|
681 |
curMSE = output.loc[i, 'MSE']
|
682 |
curComplexity = output.loc[i, 'Complexity']
|
@@ -693,8 +703,12 @@ def get_hof(equation_file=None, n_features=None, variable_names=None, extra_symp
|
|
693 |
output['score'] = np.array(scores)
|
694 |
output['sympy_format'] = sympy_format
|
695 |
output['lambda_format'] = lambda_format
|
|
|
|
|
|
|
|
|
696 |
|
697 |
-
return output[
|
698 |
|
699 |
def best_row(equations=None):
|
700 |
"""Return the best row of a hall of fame file using the score column.
|
|
|
12 |
from pathlib import Path
|
13 |
from datetime import datetime
|
14 |
import warnings
|
15 |
+
from .export import sympy2jax
|
16 |
|
17 |
global_equation_file = 'hall_of_fame.csv'
|
18 |
global_n_features = None
|
|
|
106 |
user_input=True,
|
107 |
update=True,
|
108 |
temp_equation_file=False,
|
109 |
+
output_jax_format=False,
|
110 |
warmupMaxsize=None, #Deprecated
|
111 |
):
|
112 |
"""Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
|
|
|
217 |
:param temp_equation_file: Whether to put the hall of fame file in
|
218 |
the temp directory. Deletion is then controlled with the
|
219 |
delete_tempfiles argument.
|
220 |
+
:param output_jax_format: Whether to create a 'jax_format' column in the output,
|
221 |
+
containing jax-callable functions and the default parameters in a jax array.
|
222 |
:returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
|
223 |
(as strings).
|
224 |
|
|
|
284 |
weightSimplify=weightSimplify,
|
285 |
constraints=constraints,
|
286 |
extra_sympy_mappings=extra_sympy_mappings,
|
287 |
+
julia_project=julia_project, loss=loss,
|
288 |
+
output_jax_format=output_jax_format)
|
289 |
|
290 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
291 |
|
|
|
637 |
max_features=select_k_features, prefit=True)
|
638 |
return selector.get_support(indices=True)
|
639 |
|
640 |
+
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
641 |
+
extra_sympy_mappings=None, output_jax_format=False, **kwargs):
|
642 |
"""Get the equations from a hall of fame file. If no arguments
|
643 |
entered, the ones used previously from a call to PySR will be used."""
|
644 |
|
|
|
668 |
lastComplexity = 0
|
669 |
sympy_format = []
|
670 |
lambda_format = []
|
671 |
+
if output_jax_format:
|
672 |
+
jax_format = []
|
673 |
use_custom_variable_names = (len(variable_names) != 0)
|
674 |
local_sympy_mappings = {
|
675 |
**extra_sympy_mappings,
|
|
|
684 |
for i in range(len(output)):
|
685 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
686 |
sympy_format.append(eqn)
|
687 |
+
if output_jax_format:
|
688 |
+
func, params = sympy2jax(eqn, sympy_symbols)
|
689 |
+
jax_format.append({'callable': func, 'parameters': parameters})
|
690 |
lambda_format.append(lambdify(sympy_symbols, eqn))
|
691 |
curMSE = output.loc[i, 'MSE']
|
692 |
curComplexity = output.loc[i, 'Complexity']
|
|
|
703 |
output['score'] = np.array(scores)
|
704 |
output['sympy_format'] = sympy_format
|
705 |
output['lambda_format'] = lambda_format
|
706 |
+
output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
|
707 |
+
if output_jax_format:
|
708 |
+
output_cols += 'jax_format'
|
709 |
+
output['jax_format'] = jax_format
|
710 |
|
711 |
+
return output[output_cols]
|
712 |
|
713 |
def best_row(equations=None):
|
714 |
"""Return the best row of a hall of fame file using the score column.
|
setup.py
CHANGED
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
|
|
5 |
|
6 |
setuptools.setup(
|
7 |
name="pysr", # Replace with your own username
|
8 |
-
version="0.5.
|
9 |
author="Miles Cranmer",
|
10 |
author_email="[email protected]",
|
11 |
description="Simple and efficient symbolic regression",
|
|
|
5 |
|
6 |
setuptools.setup(
|
7 |
name="pysr", # Replace with your own username
|
8 |
+
version="0.5.13",
|
9 |
author="Miles Cranmer",
|
10 |
author_email="[email protected]",
|
11 |
description="Simple and efficient symbolic regression",
|