Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
68b3673
1
Parent(s):
a06bfc4
Add torch format output; dont import jax/torch by default
Browse files- pysr/__init__.py +0 -2
- pysr/export_jax.py +47 -51
- pysr/export_torch.py +0 -2
- pysr/sr.py +21 -4
pysr/__init__.py
CHANGED
@@ -1,4 +1,2 @@
|
|
1 |
from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
|
2 |
from .feynman_problems import Problem, FeynmanProblem
|
3 |
-
from .export_jax import sympy2jax
|
4 |
-
from .export_torch import sympy2torch
|
|
|
1 |
from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
|
2 |
from .feynman_problems import Problem, FeynmanProblem
|
|
|
|
pysr/export_jax.py
CHANGED
@@ -2,60 +2,56 @@ import functools as ft
|
|
2 |
import sympy
|
3 |
import string
|
4 |
import random
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
from jax import numpy as jnp
|
9 |
-
from jax.scipy import special as jsp
|
10 |
|
11 |
# Special since need to reduce arguments.
|
12 |
-
|
13 |
-
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
except ImportError:
|
58 |
-
...
|
59 |
|
60 |
def sympy2jaxtext(expr, parameters, symbols_in):
|
61 |
if issubclass(expr.func, sympy.Float):
|
|
|
2 |
import sympy
|
3 |
import string
|
4 |
import random
|
5 |
+
import jax
|
6 |
+
from jax import numpy as jnp
|
7 |
+
from jax.scipy import special as jsp
|
|
|
|
|
8 |
|
9 |
# Special since need to reduce arguments.
|
10 |
+
MUL = 0
|
11 |
+
ADD = 1
|
12 |
|
13 |
+
_jnp_func_lookup = {
|
14 |
+
sympy.Mul: MUL,
|
15 |
+
sympy.Add: ADD,
|
16 |
+
sympy.div: "jnp.div",
|
17 |
+
sympy.Abs: "jnp.abs",
|
18 |
+
sympy.sign: "jnp.sign",
|
19 |
+
# Note: May raise error for ints.
|
20 |
+
sympy.ceiling: "jnp.ceil",
|
21 |
+
sympy.floor: "jnp.floor",
|
22 |
+
sympy.log: "jnp.log",
|
23 |
+
sympy.exp: "jnp.exp",
|
24 |
+
sympy.sqrt: "jnp.sqrt",
|
25 |
+
sympy.cos: "jnp.cos",
|
26 |
+
sympy.acos: "jnp.acos",
|
27 |
+
sympy.sin: "jnp.sin",
|
28 |
+
sympy.asin: "jnp.asin",
|
29 |
+
sympy.tan: "jnp.tan",
|
30 |
+
sympy.atan: "jnp.atan",
|
31 |
+
sympy.atan2: "jnp.atan2",
|
32 |
+
# Note: Also may give NaN for complex results.
|
33 |
+
sympy.cosh: "jnp.cosh",
|
34 |
+
sympy.acosh: "jnp.acosh",
|
35 |
+
sympy.sinh: "jnp.sinh",
|
36 |
+
sympy.asinh: "jnp.asinh",
|
37 |
+
sympy.tanh: "jnp.tanh",
|
38 |
+
sympy.atanh: "jnp.atanh",
|
39 |
+
sympy.Pow: "jnp.power",
|
40 |
+
sympy.re: "jnp.real",
|
41 |
+
sympy.im: "jnp.imag",
|
42 |
+
sympy.arg: "jnp.angle",
|
43 |
+
# Note: May raise error for ints and complexes
|
44 |
+
sympy.erf: "jsp.erf",
|
45 |
+
sympy.erfc: "jsp.erfc",
|
46 |
+
sympy.LessThan: "jnp.less",
|
47 |
+
sympy.GreaterThan: "jnp.greater",
|
48 |
+
sympy.And: "jnp.logical_and",
|
49 |
+
sympy.Or: "jnp.logical_or",
|
50 |
+
sympy.Not: "jnp.logical_not",
|
51 |
+
sympy.Max: "jnp.max",
|
52 |
+
sympy.Min: "jnp.min",
|
53 |
+
sympy.Mod: "jnp.mod",
|
54 |
+
}
|
|
|
|
|
55 |
|
56 |
def sympy2jaxtext(expr, parameters, symbols_in):
|
57 |
if issubclass(expr.func, sympy.Float):
|
pysr/export_torch.py
CHANGED
@@ -8,7 +8,6 @@ import functools as ft
|
|
8 |
import sympy
|
9 |
import torch
|
10 |
|
11 |
-
|
12 |
def _reduce(fn):
|
13 |
def fn_(*args):
|
14 |
return ft.reduce(fn, args)
|
@@ -67,7 +66,6 @@ _global_func_lookup = {
|
|
67 |
sympy.Determinant: torch.det,
|
68 |
}
|
69 |
|
70 |
-
|
71 |
class _Node(torch.nn.Module):
|
72 |
def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
|
73 |
super().__init__(**kwargs)
|
|
|
8 |
import sympy
|
9 |
import torch
|
10 |
|
|
|
11 |
def _reduce(fn):
|
12 |
def fn_(*args):
|
13 |
return ft.reduce(fn, args)
|
|
|
66 |
sympy.Determinant: torch.det,
|
67 |
}
|
68 |
|
|
|
69 |
class _Node(torch.nn.Module):
|
70 |
def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
|
71 |
super().__init__(**kwargs)
|
pysr/sr.py
CHANGED
@@ -13,8 +13,6 @@ import shutil
|
|
13 |
from pathlib import Path
|
14 |
from datetime import datetime
|
15 |
import warnings
|
16 |
-
from .export_jax import sympy2jax
|
17 |
-
from .export_torch import sympy2torch
|
18 |
|
19 |
global_equation_file = 'hall_of_fame.csv'
|
20 |
global_n_features = None
|
@@ -125,11 +123,12 @@ def pysr(X, y, weights=None,
|
|
125 |
update=True,
|
126 |
temp_equation_file=False,
|
127 |
output_jax_format=False,
|
|
|
128 |
optimizer_algorithm="BFGS",
|
129 |
optimizer_nrestarts=3,
|
130 |
optimize_probability=1.0,
|
131 |
-
optimizer_iterations=10
|
132 |
-
|
133 |
"""Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
|
134 |
Note: most default parameters have been tuned over several example
|
135 |
equations, but you should adjust `niterations`,
|
@@ -242,6 +241,8 @@ def pysr(X, y, weights=None,
|
|
242 |
delete_tempfiles argument.
|
243 |
:param output_jax_format: Whether to create a 'jax_format' column in the output,
|
244 |
containing jax-callable functions and the default parameters in a jax array.
|
|
|
|
|
245 |
:returns: pd.DataFrame or list, Results dataframe,
|
246 |
giving complexity, MSE, and equations (as strings), as well as functional
|
247 |
forms. If list, each element corresponds to a dataframe of equations
|
@@ -337,6 +338,7 @@ def pysr(X, y, weights=None,
|
|
337 |
extra_sympy_mappings=extra_sympy_mappings,
|
338 |
julia_project=julia_project, loss=loss,
|
339 |
output_jax_format=output_jax_format,
|
|
|
340 |
multioutput=multioutput, nout=nout)
|
341 |
|
342 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
@@ -727,6 +729,7 @@ def run_feature_selection(X, y, select_k_features):
|
|
727 |
|
728 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
729 |
extra_sympy_mappings=None, output_jax_format=False,
|
|
|
730 |
multioutput=None, nout=None, **kwargs):
|
731 |
"""Get the equations from a hall of fame file. If no arguments
|
732 |
entered, the ones used previously from a call to PySR will be used."""
|
@@ -771,6 +774,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
771 |
lambda_format = []
|
772 |
if output_jax_format:
|
773 |
jax_format = []
|
|
|
|
|
774 |
use_custom_variable_names = (len(variable_names) != 0)
|
775 |
local_sympy_mappings = {
|
776 |
**extra_sympy_mappings,
|
@@ -786,10 +791,19 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
786 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
787 |
sympy_format.append(eqn)
|
788 |
if output_jax_format:
|
|
|
789 |
func, params = sympy2jax(eqn, sympy_symbols)
|
790 |
jax_format.append({'callable': func, 'parameters': params})
|
|
|
791 |
|
792 |
lambda_format.append(CallableEquation(sympy_symbols, eqn))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
793 |
curMSE = output.loc[i, 'MSE']
|
794 |
curComplexity = output.loc[i, 'Complexity']
|
795 |
|
@@ -809,6 +823,9 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
809 |
if output_jax_format:
|
810 |
output_cols += ['jax_format']
|
811 |
output['jax_format'] = jax_format
|
|
|
|
|
|
|
812 |
|
813 |
ret_outputs.append(output[output_cols])
|
814 |
|
|
|
13 |
from pathlib import Path
|
14 |
from datetime import datetime
|
15 |
import warnings
|
|
|
|
|
16 |
|
17 |
global_equation_file = 'hall_of_fame.csv'
|
18 |
global_n_features = None
|
|
|
123 |
update=True,
|
124 |
temp_equation_file=False,
|
125 |
output_jax_format=False,
|
126 |
+
output_torch_format=False,
|
127 |
optimizer_algorithm="BFGS",
|
128 |
optimizer_nrestarts=3,
|
129 |
optimize_probability=1.0,
|
130 |
+
optimizer_iterations=10
|
131 |
+
):
|
132 |
"""Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
|
133 |
Note: most default parameters have been tuned over several example
|
134 |
equations, but you should adjust `niterations`,
|
|
|
241 |
delete_tempfiles argument.
|
242 |
:param output_jax_format: Whether to create a 'jax_format' column in the output,
|
243 |
containing jax-callable functions and the default parameters in a jax array.
|
244 |
+
:param output_torch_format: Whether to create a 'torch_format' column in the output,
|
245 |
+
containing a torch module with trainable parameters.
|
246 |
:returns: pd.DataFrame or list, Results dataframe,
|
247 |
giving complexity, MSE, and equations (as strings), as well as functional
|
248 |
forms. If list, each element corresponds to a dataframe of equations
|
|
|
338 |
extra_sympy_mappings=extra_sympy_mappings,
|
339 |
julia_project=julia_project, loss=loss,
|
340 |
output_jax_format=output_jax_format,
|
341 |
+
output_torch_format=output_torch_format,
|
342 |
multioutput=multioutput, nout=nout)
|
343 |
|
344 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
|
|
729 |
|
730 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
731 |
extra_sympy_mappings=None, output_jax_format=False,
|
732 |
+
output_torch_format=False,
|
733 |
multioutput=None, nout=None, **kwargs):
|
734 |
"""Get the equations from a hall of fame file. If no arguments
|
735 |
entered, the ones used previously from a call to PySR will be used."""
|
|
|
774 |
lambda_format = []
|
775 |
if output_jax_format:
|
776 |
jax_format = []
|
777 |
+
if output_torch_format:
|
778 |
+
torch_format = []
|
779 |
use_custom_variable_names = (len(variable_names) != 0)
|
780 |
local_sympy_mappings = {
|
781 |
**extra_sympy_mappings,
|
|
|
791 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
792 |
sympy_format.append(eqn)
|
793 |
if output_jax_format:
|
794 |
+
from .export_jax import sympy2jax
|
795 |
func, params = sympy2jax(eqn, sympy_symbols)
|
796 |
jax_format.append({'callable': func, 'parameters': params})
|
797 |
+
<<<<<<< HEAD
|
798 |
|
799 |
lambda_format.append(CallableEquation(sympy_symbols, eqn))
|
800 |
+
=======
|
801 |
+
if output_torch_format:
|
802 |
+
from .export_torch import sympy2torch
|
803 |
+
func, params = sympy2torch(eqn, sympy_symbols)
|
804 |
+
torch_format.append({'callable': func, 'parameters': params})
|
805 |
+
lambda_format.append(lambdify(sympy_symbols, eqn))
|
806 |
+
>>>>>>> 6ba697f (Add torch format output; dont import jax/torch by default)
|
807 |
curMSE = output.loc[i, 'MSE']
|
808 |
curComplexity = output.loc[i, 'Complexity']
|
809 |
|
|
|
823 |
if output_jax_format:
|
824 |
output_cols += ['jax_format']
|
825 |
output['jax_format'] = jax_format
|
826 |
+
if output_torch_format:
|
827 |
+
output_cols += ['torch_format']
|
828 |
+
output['torch_format'] = torch_format
|
829 |
|
830 |
ret_outputs.append(output[output_cols])
|
831 |
|