Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
c96b30c
1
Parent(s):
8c55475
Clean up global variables into single dict
Browse files- pysr/sr.py +60 -49
- test/test.py +2 -1
pysr/sr.py
CHANGED
@@ -14,12 +14,19 @@ from pathlib import Path
|
|
14 |
from datetime import datetime
|
15 |
import warnings
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
sympy_mappings = {
|
25 |
'div': lambda x, y : x/y,
|
@@ -62,16 +69,20 @@ sympy_mappings = {
|
|
62 |
|
63 |
class CallableEquation(object):
|
64 |
"""Simple wrapper for numpy lambda functions built with sympy"""
|
65 |
-
def __init__(self, sympy_symbols, eqn):
|
66 |
self._sympy = eqn
|
67 |
self._sympy_symbols = sympy_symbols
|
|
|
68 |
self._lambda = lambdify(sympy_symbols, eqn)
|
69 |
|
70 |
def __repr__(self):
|
71 |
return f"PySRFunction(X=>{self._sympy})"
|
72 |
|
73 |
def __call__(self, X):
|
74 |
-
|
|
|
|
|
|
|
75 |
|
76 |
def pysr(X, y, weights=None,
|
77 |
binary_operators=None,
|
@@ -284,7 +295,7 @@ def pysr(X, y, weights=None,
|
|
284 |
if maxsize > 40:
|
285 |
warnings.warn("Note: Using a large maxsize for the equation search will be slow and use significant memory. You should consider turning `useFrequency` to False, and perhaps use `warmupMaxsizeBy`.")
|
286 |
|
287 |
-
X, variable_names = _handle_feature_selection(
|
288 |
X, select_k_features,
|
289 |
use_custom_variable_names, variable_names, y
|
290 |
)
|
@@ -343,6 +354,7 @@ def pysr(X, y, weights=None,
|
|
343 |
julia_project=julia_project, loss=loss,
|
344 |
output_jax_format=output_jax_format,
|
345 |
output_torch_format=output_torch_format,
|
|
|
346 |
multioutput=multioutput, nout=nout)
|
347 |
|
348 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
@@ -391,21 +403,13 @@ def pysr(X, y, weights=None,
|
|
391 |
return equations
|
392 |
|
393 |
|
|
|
|
|
394 |
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
global global_variable_names
|
400 |
-
global global_extra_sympy_mappings
|
401 |
-
global global_multioutput
|
402 |
-
global global_nout
|
403 |
-
global_n_features = X.shape[1]
|
404 |
-
global_equation_file = equation_file
|
405 |
-
global_variable_names = variable_names
|
406 |
-
global_extra_sympy_mappings = extra_sympy_mappings
|
407 |
-
global_multioutput = multioutput
|
408 |
-
global_nout = nout
|
409 |
|
410 |
|
411 |
def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
|
@@ -668,7 +672,9 @@ def _handle_feature_selection(X, select_k_features, use_custom_variable_names, v
|
|
668 |
|
669 |
if use_custom_variable_names:
|
670 |
variable_names = [variable_names[selection[i]] for i in range(len(selection))]
|
671 |
-
|
|
|
|
|
672 |
|
673 |
|
674 |
def _set_paths(tempdir):
|
@@ -732,33 +738,38 @@ def run_feature_selection(X, y, select_k_features):
|
|
732 |
return selector.get_support(indices=True)
|
733 |
|
734 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
735 |
-
|
736 |
-
|
737 |
extra_jax_mappings=None, extra_torch_mappings=None,
|
738 |
multioutput=None, nout=None, **kwargs):
|
739 |
"""Get the equations from a hall of fame file. If no arguments
|
740 |
entered, the ones used previously from a call to PySR will be used."""
|
741 |
|
742 |
-
global
|
743 |
-
|
744 |
-
|
745 |
-
|
746 |
-
|
747 |
-
|
748 |
-
|
749 |
-
if
|
750 |
-
if
|
751 |
-
if
|
752 |
-
if
|
753 |
-
if
|
754 |
-
|
755 |
-
|
756 |
-
|
757 |
-
|
758 |
-
|
759 |
-
|
760 |
-
|
761 |
-
|
|
|
|
|
|
|
|
|
|
|
762 |
|
763 |
try:
|
764 |
if multioutput:
|
@@ -797,18 +808,18 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
797 |
sympy_format.append(eqn)
|
798 |
|
799 |
# Numpy:
|
800 |
-
lambda_format.append(CallableEquation(sympy_symbols, eqn))
|
801 |
|
802 |
# JAX:
|
803 |
if output_jax_format:
|
804 |
from .export_jax import sympy2jax
|
805 |
-
func, params = sympy2jax(eqn, sympy_symbols)
|
806 |
jax_format.append({'callable': func, 'parameters': params})
|
807 |
|
808 |
# Torch:
|
809 |
if output_torch_format:
|
810 |
from .export_torch import sympy2torch
|
811 |
-
module = sympy2torch(eqn, sympy_symbols)
|
812 |
torch_format.append(module)
|
813 |
|
814 |
curMSE = output.loc[i, 'MSE']
|
|
|
14 |
from datetime import datetime
|
15 |
import warnings
|
16 |
|
17 |
+
global_state = dict(
|
18 |
+
equation_file='hall_of_fame.csv',
|
19 |
+
n_features=None,
|
20 |
+
variable_names=[],
|
21 |
+
extra_sympy_mappings={},
|
22 |
+
extra_torch_mappings={},
|
23 |
+
extra_jax_mappings={},
|
24 |
+
output_jax_format=False,
|
25 |
+
output_torch_format=False,
|
26 |
+
multioutput=False,
|
27 |
+
nout=1,
|
28 |
+
selection=None
|
29 |
+
)
|
30 |
|
31 |
sympy_mappings = {
|
32 |
'div': lambda x, y : x/y,
|
|
|
69 |
|
70 |
class CallableEquation(object):
|
71 |
"""Simple wrapper for numpy lambda functions built with sympy"""
|
72 |
+
def __init__(self, sympy_symbols, eqn, selection=None):
|
73 |
self._sympy = eqn
|
74 |
self._sympy_symbols = sympy_symbols
|
75 |
+
self._selection = selection
|
76 |
self._lambda = lambdify(sympy_symbols, eqn)
|
77 |
|
78 |
def __repr__(self):
|
79 |
return f"PySRFunction(X=>{self._sympy})"
|
80 |
|
81 |
def __call__(self, X):
|
82 |
+
if self._selection is not None:
|
83 |
+
return self._lambda(*X[:, self._selection].T)
|
84 |
+
else:
|
85 |
+
return self._lambda(*X.T)
|
86 |
|
87 |
def pysr(X, y, weights=None,
|
88 |
binary_operators=None,
|
|
|
295 |
if maxsize > 40:
|
296 |
warnings.warn("Note: Using a large maxsize for the equation search will be slow and use significant memory. You should consider turning `useFrequency` to False, and perhaps use `warmupMaxsizeBy`.")
|
297 |
|
298 |
+
X, variable_names, selection = _handle_feature_selection(
|
299 |
X, select_k_features,
|
300 |
use_custom_variable_names, variable_names, y
|
301 |
)
|
|
|
354 |
julia_project=julia_project, loss=loss,
|
355 |
output_jax_format=output_jax_format,
|
356 |
output_torch_format=output_torch_format,
|
357 |
+
selection=selection,
|
358 |
multioutput=multioutput, nout=nout)
|
359 |
|
360 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
|
|
403 |
return equations
|
404 |
|
405 |
|
406 |
+
def _set_globals(X, **kwargs):
|
407 |
+
global global_state
|
408 |
|
409 |
+
global_state['n_features'] = X.shape[1]
|
410 |
+
for key, value in kwargs.items():
|
411 |
+
if key in global_state:
|
412 |
+
global_state[key] = value
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
413 |
|
414 |
|
415 |
def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
|
|
|
672 |
|
673 |
if use_custom_variable_names:
|
674 |
variable_names = [variable_names[selection[i]] for i in range(len(selection))]
|
675 |
+
else:
|
676 |
+
selection = None
|
677 |
+
return X, variable_names, selection
|
678 |
|
679 |
|
680 |
def _set_paths(tempdir):
|
|
|
738 |
return selector.get_support(indices=True)
|
739 |
|
740 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
741 |
+
output_jax_format=None, output_torch_format=None,
|
742 |
+
selection=None, extra_sympy_mappings=None,
|
743 |
extra_jax_mappings=None, extra_torch_mappings=None,
|
744 |
multioutput=None, nout=None, **kwargs):
|
745 |
"""Get the equations from a hall of fame file. If no arguments
|
746 |
entered, the ones used previously from a call to PySR will be used."""
|
747 |
|
748 |
+
global global_state
|
749 |
+
|
750 |
+
if equation_file is None: equation_file = global_state['equation_file']
|
751 |
+
if n_features is None: n_features = global_state['n_features']
|
752 |
+
if variable_names is None: variable_names = global_state['variable_names']
|
753 |
+
if extra_sympy_mappings is None: extra_sympy_mappings = global_state['extra_sympy_mappings']
|
754 |
+
if extra_jax_mappings is None: extra_jax_mappings = global_state['extra_jax_mappings']
|
755 |
+
if extra_torch_mappings is None: extra_torch_mappings = global_state['extra_torch_mappings']
|
756 |
+
if output_torch_format is None: output_torch_format = global_state['output_torch_format']
|
757 |
+
if output_jax_format is None: output_jax_format = global_state['output_jax_format']
|
758 |
+
if multioutput is None: multioutput = global_state['multioutput']
|
759 |
+
if nout is None: nout = global_state['nout']
|
760 |
+
|
761 |
+
global_state['selection'] = selection
|
762 |
+
global_state['equation_file'] = equation_file
|
763 |
+
global_state['n_features'] = n_features
|
764 |
+
global_state['variable_names'] = variable_names
|
765 |
+
global_state['extra_sympy_mappings'] = extra_sympy_mappings
|
766 |
+
global_state['extra_jax_mappings'] = extra_jax_mappings
|
767 |
+
global_state['extra_torch_mappings'] = extra_torch_mappings
|
768 |
+
global_state['output_torch_format'] = output_torch_format
|
769 |
+
global_state['output_jax_format'] = output_jax_format
|
770 |
+
global_state['multioutput'] = multioutput
|
771 |
+
global_state['nout'] = nout
|
772 |
+
global_state['selection'] = selection
|
773 |
|
774 |
try:
|
775 |
if multioutput:
|
|
|
808 |
sympy_format.append(eqn)
|
809 |
|
810 |
# Numpy:
|
811 |
+
lambda_format.append(CallableEquation(sympy_symbols, eqn, selection))
|
812 |
|
813 |
# JAX:
|
814 |
if output_jax_format:
|
815 |
from .export_jax import sympy2jax
|
816 |
+
func, params = sympy2jax(eqn, sympy_symbols, selection)
|
817 |
jax_format.append({'callable': func, 'parameters': params})
|
818 |
|
819 |
# Torch:
|
820 |
if output_torch_format:
|
821 |
from .export_torch import sympy2torch
|
822 |
+
module = sympy2torch(eqn, sympy_symbols, selection)
|
823 |
torch_format.append(module)
|
824 |
|
825 |
curMSE = output.loc[i, 'MSE']
|
test/test.py
CHANGED
@@ -112,11 +112,12 @@ class TestFeatureSelection(unittest.TestCase):
|
|
112 |
X = np.random.randn(20000, 5)
|
113 |
y = X[:, 2]**2 + X[:, 3]**2
|
114 |
var_names = [f'x{i}' for i in range(5)]
|
115 |
-
selected_X, selected_var_names = _handle_feature_selection(
|
116 |
X, select_k_features=2,
|
117 |
use_custom_variable_names=True,
|
118 |
variable_names=[f'x{i}' for i in range(5)],
|
119 |
y=y)
|
|
|
120 |
self.assertEqual(set(selected_var_names), set('x2 x3'.split(' ')))
|
121 |
np.testing.assert_array_equal(
|
122 |
np.sort(selected_X, axis=1),
|
|
|
112 |
X = np.random.randn(20000, 5)
|
113 |
y = X[:, 2]**2 + X[:, 3]**2
|
114 |
var_names = [f'x{i}' for i in range(5)]
|
115 |
+
selected_X, selected_var_names, selection = _handle_feature_selection(
|
116 |
X, select_k_features=2,
|
117 |
use_custom_variable_names=True,
|
118 |
variable_names=[f'x{i}' for i in range(5)],
|
119 |
y=y)
|
120 |
+
self.assertTrue((2 in selection) and (3 in selection))
|
121 |
self.assertEqual(set(selected_var_names), set('x2 x3'.split(' ')))
|
122 |
np.testing.assert_array_equal(
|
123 |
np.sort(selected_X, axis=1),
|