Spaces:
Running
Running
MilesCranmer
commited on
Merge pull request #48 from MilesCranmer/torch-export
Browse files- .github/workflows/CI.yml +9 -3
- pysr/__init__.py +2 -1
- pysr/{export.py → export_jax.py} +76 -55
- pysr/export_torch.py +171 -0
- pysr/sr.py +31 -6
- test/test_jax.py +23 -1
- test/test_torch.py +38 -0
.github/workflows/CI.yml
CHANGED
@@ -61,17 +61,23 @@ jobs:
|
|
61 |
python setup.py install
|
62 |
- name: "Install Coverage tool"
|
63 |
run: pip install coverage coveralls
|
|
|
|
|
|
|
64 |
- name: "Install JAX"
|
65 |
if: matrix.os != 'windows-latest'
|
66 |
run: pip install jax jaxlib # (optional import)
|
67 |
shell: bash
|
68 |
-
- name: "Run tests"
|
69 |
-
run: coverage run --source=pysr --omit='*/feynman_problems.py' -m unittest test.test
|
70 |
-
shell: bash
|
71 |
- name: "Run JAX tests"
|
72 |
if: matrix.os != 'windows-latest'
|
73 |
run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_jax
|
74 |
shell: bash
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
- name: Coveralls
|
76 |
env:
|
77 |
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
|
|
61 |
python setup.py install
|
62 |
- name: "Install Coverage tool"
|
63 |
run: pip install coverage coveralls
|
64 |
+
- name: "Run tests"
|
65 |
+
run: coverage run --source=pysr --omit='*/feynman_problems.py' -m unittest test.test
|
66 |
+
shell: bash
|
67 |
- name: "Install JAX"
|
68 |
if: matrix.os != 'windows-latest'
|
69 |
run: pip install jax jaxlib # (optional import)
|
70 |
shell: bash
|
|
|
|
|
|
|
71 |
- name: "Run JAX tests"
|
72 |
if: matrix.os != 'windows-latest'
|
73 |
run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_jax
|
74 |
shell: bash
|
75 |
+
- name: "Install Torch"
|
76 |
+
run: pip install torch # (optional import)
|
77 |
+
shell: bash
|
78 |
+
- name: "Run Torch tests"
|
79 |
+
run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_torch
|
80 |
+
shell: bash
|
81 |
- name: Coveralls
|
82 |
env:
|
83 |
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
pysr/__init__.py
CHANGED
@@ -1,3 +1,4 @@
|
|
1 |
from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
|
2 |
from .feynman_problems import Problem, FeynmanProblem
|
3 |
-
from .
|
|
|
|
1 |
from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
|
2 |
from .feynman_problems import Problem, FeynmanProblem
|
3 |
+
from .export_jax import sympy2jax
|
4 |
+
from .export_torch import sympy2torch
|
pysr/{export.py → export_jax.py}
RENAMED
@@ -3,59 +3,53 @@ import sympy
|
|
3 |
import string
|
4 |
import random
|
5 |
|
6 |
-
try:
|
7 |
-
import jax
|
8 |
-
from jax import numpy as jnp
|
9 |
-
from jax.scipy import special as jsp
|
10 |
-
|
11 |
# Special since need to reduce arguments.
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
...
|
59 |
|
60 |
def sympy2jaxtext(expr, parameters, symbols_in):
|
61 |
if issubclass(expr.func, sympy.Float):
|
@@ -75,7 +69,28 @@ def sympy2jaxtext(expr, parameters, symbols_in):
|
|
75 |
else:
|
76 |
return f'{_func}({", ".join(args)})'
|
77 |
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
"""Returns a function f and its parameters;
|
80 |
the function takes an input matrix, and a list of arguments:
|
81 |
f(X, parameters)
|
@@ -146,9 +161,15 @@ def sympy2jax(equation, symbols_in):
|
|
146 |
# 3.5427954 , -2.7479894 ], dtype=float32)
|
147 |
```
|
148 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
parameters = []
|
150 |
-
functional_form_text = sympy2jaxtext(
|
151 |
-
hash_string = 'A_' + str(abs(hash(str(
|
152 |
text = f"def {hash_string}(X, parameters):\n"
|
153 |
text += " return "
|
154 |
text += functional_form_text
|
|
|
3 |
import string
|
4 |
import random
|
5 |
|
|
|
|
|
|
|
|
|
|
|
6 |
# Special since need to reduce arguments.
|
7 |
+
MUL = 0
|
8 |
+
ADD = 1
|
9 |
+
|
10 |
+
_jnp_func_lookup = {
|
11 |
+
sympy.Mul: MUL,
|
12 |
+
sympy.Add: ADD,
|
13 |
+
sympy.div: "jnp.div",
|
14 |
+
sympy.Abs: "jnp.abs",
|
15 |
+
sympy.sign: "jnp.sign",
|
16 |
+
# Note: May raise error for ints.
|
17 |
+
sympy.ceiling: "jnp.ceil",
|
18 |
+
sympy.floor: "jnp.floor",
|
19 |
+
sympy.log: "jnp.log",
|
20 |
+
sympy.exp: "jnp.exp",
|
21 |
+
sympy.sqrt: "jnp.sqrt",
|
22 |
+
sympy.cos: "jnp.cos",
|
23 |
+
sympy.acos: "jnp.acos",
|
24 |
+
sympy.sin: "jnp.sin",
|
25 |
+
sympy.asin: "jnp.asin",
|
26 |
+
sympy.tan: "jnp.tan",
|
27 |
+
sympy.atan: "jnp.atan",
|
28 |
+
sympy.atan2: "jnp.atan2",
|
29 |
+
# Note: Also may give NaN for complex results.
|
30 |
+
sympy.cosh: "jnp.cosh",
|
31 |
+
sympy.acosh: "jnp.acosh",
|
32 |
+
sympy.sinh: "jnp.sinh",
|
33 |
+
sympy.asinh: "jnp.asinh",
|
34 |
+
sympy.tanh: "jnp.tanh",
|
35 |
+
sympy.atanh: "jnp.atanh",
|
36 |
+
sympy.Pow: "jnp.power",
|
37 |
+
sympy.re: "jnp.real",
|
38 |
+
sympy.im: "jnp.imag",
|
39 |
+
sympy.arg: "jnp.angle",
|
40 |
+
# Note: May raise error for ints and complexes
|
41 |
+
sympy.erf: "jsp.erf",
|
42 |
+
sympy.erfc: "jsp.erfc",
|
43 |
+
sympy.LessThan: "jnp.less",
|
44 |
+
sympy.GreaterThan: "jnp.greater",
|
45 |
+
sympy.And: "jnp.logical_and",
|
46 |
+
sympy.Or: "jnp.logical_or",
|
47 |
+
sympy.Not: "jnp.logical_not",
|
48 |
+
sympy.Max: "jnp.max",
|
49 |
+
sympy.Min: "jnp.min",
|
50 |
+
sympy.Mod: "jnp.mod",
|
51 |
+
}
|
52 |
+
|
|
|
53 |
|
54 |
def sympy2jaxtext(expr, parameters, symbols_in):
|
55 |
if issubclass(expr.func, sympy.Float):
|
|
|
69 |
else:
|
70 |
return f'{_func}({", ".join(args)})'
|
71 |
|
72 |
+
|
73 |
+
jax_initialized = False
|
74 |
+
jax = None
|
75 |
+
jnp = None
|
76 |
+
jsp = None
|
77 |
+
|
78 |
+
def _initialize_jax():
|
79 |
+
global jax_initialized
|
80 |
+
global jax
|
81 |
+
global jnp
|
82 |
+
global jsp
|
83 |
+
|
84 |
+
if not jax_initialized:
|
85 |
+
import jax as _jax
|
86 |
+
from jax import numpy as _jnp
|
87 |
+
from jax.scipy import special as _jsp
|
88 |
+
jax = _jax
|
89 |
+
jnp = _jnp
|
90 |
+
jsp = _jsp
|
91 |
+
|
92 |
+
|
93 |
+
def sympy2jax(expression, symbols_in):
|
94 |
"""Returns a function f and its parameters;
|
95 |
the function takes an input matrix, and a list of arguments:
|
96 |
f(X, parameters)
|
|
|
161 |
# 3.5427954 , -2.7479894 ], dtype=float32)
|
162 |
```
|
163 |
"""
|
164 |
+
_initialize_jax()
|
165 |
+
global jax_initialized
|
166 |
+
global jax
|
167 |
+
global jnp
|
168 |
+
global jsp
|
169 |
+
|
170 |
parameters = []
|
171 |
+
functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
|
172 |
+
hash_string = 'A_' + str(abs(hash(str(expression) + str(symbols_in))))
|
173 |
text = f"def {hash_string}(X, parameters):\n"
|
174 |
text += " return "
|
175 |
text += functional_form_text
|
pysr/export_torch.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#####
|
2 |
+
# From https://github.com/patrick-kidger/sympytorch
|
3 |
+
# Copied here to allow PySR-specific tweaks
|
4 |
+
#####
|
5 |
+
|
6 |
+
import collections as co
|
7 |
+
import functools as ft
|
8 |
+
import sympy
|
9 |
+
|
10 |
+
def _reduce(fn):
|
11 |
+
def fn_(*args):
|
12 |
+
return ft.reduce(fn, args)
|
13 |
+
return fn_
|
14 |
+
|
15 |
+
torch_initialized = False
|
16 |
+
torch = None
|
17 |
+
_global_func_lookup = None
|
18 |
+
_Node = None
|
19 |
+
SingleSymPyModule = None
|
20 |
+
|
21 |
+
def _initialize_torch():
|
22 |
+
global torch_initialized
|
23 |
+
global torch
|
24 |
+
global _global_func_lookup
|
25 |
+
global _Node
|
26 |
+
global SingleSymPyModule
|
27 |
+
|
28 |
+
# Way to lazy load torch, only if this is called,
|
29 |
+
# but still allow this module to be loaded in __init__
|
30 |
+
if not torch_initialized:
|
31 |
+
import torch as _torch
|
32 |
+
torch = _torch
|
33 |
+
|
34 |
+
_global_func_lookup = {
|
35 |
+
sympy.Mul: _reduce(torch.mul),
|
36 |
+
sympy.Add: _reduce(torch.add),
|
37 |
+
sympy.div: torch.div,
|
38 |
+
sympy.Abs: torch.abs,
|
39 |
+
sympy.sign: torch.sign,
|
40 |
+
# Note: May raise error for ints.
|
41 |
+
sympy.ceiling: torch.ceil,
|
42 |
+
sympy.floor: torch.floor,
|
43 |
+
sympy.log: torch.log,
|
44 |
+
sympy.exp: torch.exp,
|
45 |
+
sympy.sqrt: torch.sqrt,
|
46 |
+
sympy.cos: torch.cos,
|
47 |
+
sympy.acos: torch.acos,
|
48 |
+
sympy.sin: torch.sin,
|
49 |
+
sympy.asin: torch.asin,
|
50 |
+
sympy.tan: torch.tan,
|
51 |
+
sympy.atan: torch.atan,
|
52 |
+
sympy.atan2: torch.atan2,
|
53 |
+
# Note: May give NaN for complex results.
|
54 |
+
sympy.cosh: torch.cosh,
|
55 |
+
sympy.acosh: torch.acosh,
|
56 |
+
sympy.sinh: torch.sinh,
|
57 |
+
sympy.asinh: torch.asinh,
|
58 |
+
sympy.tanh: torch.tanh,
|
59 |
+
sympy.atanh: torch.atanh,
|
60 |
+
sympy.Pow: torch.pow,
|
61 |
+
sympy.re: torch.real,
|
62 |
+
sympy.im: torch.imag,
|
63 |
+
sympy.arg: torch.angle,
|
64 |
+
# Note: May raise error for ints and complexes
|
65 |
+
sympy.erf: torch.erf,
|
66 |
+
sympy.loggamma: torch.lgamma,
|
67 |
+
sympy.Eq: torch.eq,
|
68 |
+
sympy.Ne: torch.ne,
|
69 |
+
sympy.StrictGreaterThan: torch.gt,
|
70 |
+
sympy.StrictLessThan: torch.lt,
|
71 |
+
sympy.LessThan: torch.le,
|
72 |
+
sympy.GreaterThan: torch.ge,
|
73 |
+
sympy.And: torch.logical_and,
|
74 |
+
sympy.Or: torch.logical_or,
|
75 |
+
sympy.Not: torch.logical_not,
|
76 |
+
sympy.Max: torch.max,
|
77 |
+
sympy.Min: torch.min,
|
78 |
+
# Matrices
|
79 |
+
sympy.MatAdd: torch.add,
|
80 |
+
sympy.HadamardProduct: torch.mul,
|
81 |
+
sympy.Trace: torch.trace,
|
82 |
+
# Note: May raise error for integer matrices.
|
83 |
+
sympy.Determinant: torch.det,
|
84 |
+
}
|
85 |
+
|
86 |
+
class _Node(torch.nn.Module):
|
87 |
+
"""SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
|
88 |
+
def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
|
89 |
+
super().__init__(**kwargs)
|
90 |
+
|
91 |
+
self._sympy_func = expr.func
|
92 |
+
|
93 |
+
if issubclass(expr.func, sympy.Float):
|
94 |
+
self._value = torch.nn.Parameter(torch.tensor(float(expr)))
|
95 |
+
self._torch_func = lambda: self._value
|
96 |
+
self._args = ()
|
97 |
+
elif issubclass(expr.func, sympy.UnevaluatedExpr):
|
98 |
+
if len(expr.args) != 1 or not issubclass(expr.args[0].func, sympy.Float):
|
99 |
+
raise ValueError("UnevaluatedExpr should only be used to wrap floats.")
|
100 |
+
self.register_buffer('_value', torch.tensor(float(expr.args[0])))
|
101 |
+
self._torch_func = lambda: self._value
|
102 |
+
self._args = ()
|
103 |
+
elif issubclass(expr.func, sympy.Integer):
|
104 |
+
# Can get here if expr is one of the Integer special cases,
|
105 |
+
# e.g. NegativeOne
|
106 |
+
self._value = int(expr)
|
107 |
+
self._torch_func = lambda: self._value
|
108 |
+
self._args = ()
|
109 |
+
elif issubclass(expr.func, sympy.Symbol):
|
110 |
+
self._name = expr.name
|
111 |
+
self._torch_func = lambda value: value
|
112 |
+
self._args = ((lambda memodict: memodict[expr.name]),)
|
113 |
+
else:
|
114 |
+
self._torch_func = _func_lookup[expr.func]
|
115 |
+
args = []
|
116 |
+
for arg in expr.args:
|
117 |
+
try:
|
118 |
+
arg_ = _memodict[arg]
|
119 |
+
except KeyError:
|
120 |
+
arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
|
121 |
+
_memodict[arg] = arg_
|
122 |
+
args.append(arg_)
|
123 |
+
self._args = torch.nn.ModuleList(args)
|
124 |
+
|
125 |
+
def forward(self, memodict):
|
126 |
+
args = []
|
127 |
+
for arg in self._args:
|
128 |
+
try:
|
129 |
+
arg_ = memodict[arg]
|
130 |
+
except KeyError:
|
131 |
+
arg_ = arg(memodict)
|
132 |
+
memodict[arg] = arg_
|
133 |
+
args.append(arg_)
|
134 |
+
return self._torch_func(*args)
|
135 |
+
|
136 |
+
|
137 |
+
class SingleSymPyModule(torch.nn.Module):
|
138 |
+
"""SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
|
139 |
+
def __init__(self, expression, symbols_in,
|
140 |
+
extra_funcs=None, **kwargs):
|
141 |
+
super().__init__(**kwargs)
|
142 |
+
|
143 |
+
if extra_funcs is None:
|
144 |
+
extra_funcs = {}
|
145 |
+
_func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
|
146 |
+
|
147 |
+
_memodict = {}
|
148 |
+
self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
|
149 |
+
self._expression_string = str(expression)
|
150 |
+
self.symbols_in = [str(symbol) for symbol in symbols_in]
|
151 |
+
|
152 |
+
def __repr__(self):
|
153 |
+
return f"{type(self).__name__}(expression={self._expression_string})"
|
154 |
+
|
155 |
+
def forward(self, X):
|
156 |
+
symbols = {symbol: X[:, i]
|
157 |
+
for i, symbol in enumerate(self.symbols_in)}
|
158 |
+
return self._node(symbols)
|
159 |
+
|
160 |
+
|
161 |
+
def sympy2torch(expression, symbols_in, extra_torch_mappings=None):
|
162 |
+
"""Returns a module for a given sympy expression with trainable parameters;
|
163 |
+
|
164 |
+
This function will assume the input to the module is a matrix X, where
|
165 |
+
each column corresponds to each symbol you pass in `symbols_in`.
|
166 |
+
"""
|
167 |
+
global SingleSymPyModule
|
168 |
+
|
169 |
+
_initialize_torch()
|
170 |
+
|
171 |
+
return SingleSymPyModule(expression, symbols_in, extra_funcs=extra_torch_mappings)
|
pysr/sr.py
CHANGED
@@ -13,7 +13,6 @@ import shutil
|
|
13 |
from pathlib import Path
|
14 |
from datetime import datetime
|
15 |
import warnings
|
16 |
-
from .export import sympy2jax
|
17 |
|
18 |
global_equation_file = 'hall_of_fame.csv'
|
19 |
global_n_features = None
|
@@ -103,6 +102,8 @@ def pysr(X, y, weights=None,
|
|
103 |
perturbationFactor=1.0,
|
104 |
timeout=None,
|
105 |
extra_sympy_mappings=None,
|
|
|
|
|
106 |
equation_file=None,
|
107 |
verbosity=1e9,
|
108 |
progress=True,
|
@@ -124,11 +125,12 @@ def pysr(X, y, weights=None,
|
|
124 |
update=True,
|
125 |
temp_equation_file=False,
|
126 |
output_jax_format=False,
|
|
|
127 |
optimizer_algorithm="BFGS",
|
128 |
optimizer_nrestarts=3,
|
129 |
optimize_probability=1.0,
|
130 |
-
optimizer_iterations=10
|
131 |
-
|
132 |
"""Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
|
133 |
Note: most default parameters have been tuned over several example
|
134 |
equations, but you should adjust `niterations`,
|
@@ -241,6 +243,8 @@ def pysr(X, y, weights=None,
|
|
241 |
delete_tempfiles argument.
|
242 |
:param output_jax_format: Whether to create a 'jax_format' column in the output,
|
243 |
containing jax-callable functions and the default parameters in a jax array.
|
|
|
|
|
244 |
:returns: pd.DataFrame or list, Results dataframe,
|
245 |
giving complexity, MSE, and equations (as strings), as well as functional
|
246 |
forms. If list, each element corresponds to a dataframe of equations
|
@@ -334,8 +338,11 @@ def pysr(X, y, weights=None,
|
|
334 |
weightSimplify=weightSimplify,
|
335 |
constraints=constraints,
|
336 |
extra_sympy_mappings=extra_sympy_mappings,
|
|
|
|
|
337 |
julia_project=julia_project, loss=loss,
|
338 |
output_jax_format=output_jax_format,
|
|
|
339 |
multioutput=multioutput, nout=nout)
|
340 |
|
341 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
@@ -715,10 +722,10 @@ def run_feature_selection(X, y, select_k_features):
|
|
715 |
the k most important features in X, returning indices for those
|
716 |
features as output."""
|
717 |
|
718 |
-
from sklearn.ensemble import RandomForestRegressor
|
719 |
from sklearn.feature_selection import SelectFromModel, SelectKBest
|
720 |
|
721 |
-
clf =
|
722 |
clf.fit(X, y)
|
723 |
selector = SelectFromModel(clf, threshold=-np.inf,
|
724 |
max_features=select_k_features, prefit=True)
|
@@ -726,6 +733,8 @@ def run_feature_selection(X, y, select_k_features):
|
|
726 |
|
727 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
728 |
extra_sympy_mappings=None, output_jax_format=False,
|
|
|
|
|
729 |
multioutput=None, nout=None, **kwargs):
|
730 |
"""Get the equations from a hall of fame file. If no arguments
|
731 |
entered, the ones used previously from a call to PySR will be used."""
|
@@ -770,6 +779,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
770 |
lambda_format = []
|
771 |
if output_jax_format:
|
772 |
jax_format = []
|
|
|
|
|
773 |
use_custom_variable_names = (len(variable_names) != 0)
|
774 |
local_sympy_mappings = {
|
775 |
**extra_sympy_mappings,
|
@@ -784,11 +795,22 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
784 |
for i in range(len(output)):
|
785 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
786 |
sympy_format.append(eqn)
|
|
|
|
|
|
|
|
|
|
|
787 |
if output_jax_format:
|
|
|
788 |
func, params = sympy2jax(eqn, sympy_symbols)
|
789 |
jax_format.append({'callable': func, 'parameters': params})
|
790 |
|
791 |
-
|
|
|
|
|
|
|
|
|
|
|
792 |
curMSE = output.loc[i, 'MSE']
|
793 |
curComplexity = output.loc[i, 'Complexity']
|
794 |
|
@@ -808,6 +830,9 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
808 |
if output_jax_format:
|
809 |
output_cols += ['jax_format']
|
810 |
output['jax_format'] = jax_format
|
|
|
|
|
|
|
811 |
|
812 |
ret_outputs.append(output[output_cols])
|
813 |
|
|
|
13 |
from pathlib import Path
|
14 |
from datetime import datetime
|
15 |
import warnings
|
|
|
16 |
|
17 |
global_equation_file = 'hall_of_fame.csv'
|
18 |
global_n_features = None
|
|
|
102 |
perturbationFactor=1.0,
|
103 |
timeout=None,
|
104 |
extra_sympy_mappings=None,
|
105 |
+
extra_torch_mappings=None,
|
106 |
+
extra_jax_mappings=None,
|
107 |
equation_file=None,
|
108 |
verbosity=1e9,
|
109 |
progress=True,
|
|
|
125 |
update=True,
|
126 |
temp_equation_file=False,
|
127 |
output_jax_format=False,
|
128 |
+
output_torch_format=False,
|
129 |
optimizer_algorithm="BFGS",
|
130 |
optimizer_nrestarts=3,
|
131 |
optimize_probability=1.0,
|
132 |
+
optimizer_iterations=10
|
133 |
+
):
|
134 |
"""Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
|
135 |
Note: most default parameters have been tuned over several example
|
136 |
equations, but you should adjust `niterations`,
|
|
|
243 |
delete_tempfiles argument.
|
244 |
:param output_jax_format: Whether to create a 'jax_format' column in the output,
|
245 |
containing jax-callable functions and the default parameters in a jax array.
|
246 |
+
:param output_torch_format: Whether to create a 'torch_format' column in the output,
|
247 |
+
containing a torch module with trainable parameters.
|
248 |
:returns: pd.DataFrame or list, Results dataframe,
|
249 |
giving complexity, MSE, and equations (as strings), as well as functional
|
250 |
forms. If list, each element corresponds to a dataframe of equations
|
|
|
338 |
weightSimplify=weightSimplify,
|
339 |
constraints=constraints,
|
340 |
extra_sympy_mappings=extra_sympy_mappings,
|
341 |
+
extra_jax_mappings=extra_jax_mappings,
|
342 |
+
extra_torch_mappings=extra_torch_mappings,
|
343 |
julia_project=julia_project, loss=loss,
|
344 |
output_jax_format=output_jax_format,
|
345 |
+
output_torch_format=output_torch_format,
|
346 |
multioutput=multioutput, nout=nout)
|
347 |
|
348 |
kwargs = {**_set_paths(tempdir), **kwargs}
|
|
|
722 |
the k most important features in X, returning indices for those
|
723 |
features as output."""
|
724 |
|
725 |
+
from sklearn.ensemble import RandomForestRegressor
|
726 |
from sklearn.feature_selection import SelectFromModel, SelectKBest
|
727 |
|
728 |
+
clf = RandomForestRegressor(n_estimators=100, max_depth=3, random_state=0)
|
729 |
clf.fit(X, y)
|
730 |
selector = SelectFromModel(clf, threshold=-np.inf,
|
731 |
max_features=select_k_features, prefit=True)
|
|
|
733 |
|
734 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
735 |
extra_sympy_mappings=None, output_jax_format=False,
|
736 |
+
output_torch_format=False,
|
737 |
+
extra_jax_mappings=None, extra_torch_mappings=None,
|
738 |
multioutput=None, nout=None, **kwargs):
|
739 |
"""Get the equations from a hall of fame file. If no arguments
|
740 |
entered, the ones used previously from a call to PySR will be used."""
|
|
|
779 |
lambda_format = []
|
780 |
if output_jax_format:
|
781 |
jax_format = []
|
782 |
+
if output_torch_format:
|
783 |
+
torch_format = []
|
784 |
use_custom_variable_names = (len(variable_names) != 0)
|
785 |
local_sympy_mappings = {
|
786 |
**extra_sympy_mappings,
|
|
|
795 |
for i in range(len(output)):
|
796 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
797 |
sympy_format.append(eqn)
|
798 |
+
|
799 |
+
# Numpy:
|
800 |
+
lambda_format.append(CallableEquation(sympy_symbols, eqn))
|
801 |
+
|
802 |
+
# JAX:
|
803 |
if output_jax_format:
|
804 |
+
from .export_jax import sympy2jax
|
805 |
func, params = sympy2jax(eqn, sympy_symbols)
|
806 |
jax_format.append({'callable': func, 'parameters': params})
|
807 |
|
808 |
+
# Torch:
|
809 |
+
if output_torch_format:
|
810 |
+
from .export_torch import sympy2torch
|
811 |
+
module = sympy2torch(eqn, sympy_symbols)
|
812 |
+
torch_format.append(module)
|
813 |
+
|
814 |
curMSE = output.loc[i, 'MSE']
|
815 |
curComplexity = output.loc[i, 'Complexity']
|
816 |
|
|
|
830 |
if output_jax_format:
|
831 |
output_cols += ['jax_format']
|
832 |
output['jax_format'] = jax_format
|
833 |
+
if output_torch_format:
|
834 |
+
output_cols += ['torch_format']
|
835 |
+
output['torch_format'] = torch_format
|
836 |
|
837 |
ret_outputs.append(output[output_cols])
|
838 |
|
test/test_jax.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
import unittest
|
2 |
import numpy as np
|
3 |
-
from pysr import
|
|
|
4 |
from jax import numpy as jnp
|
5 |
from jax import random
|
6 |
from jax import grad
|
@@ -15,3 +16,24 @@ class TestJAX(unittest.TestCase):
|
|
15 |
true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
|
16 |
f, params = sympy2jax(cosx, [x, y, z])
|
17 |
self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import unittest
|
2 |
import numpy as np
|
3 |
+
from pysr import sympy2jax, get_hof
|
4 |
+
import pandas as pd
|
5 |
from jax import numpy as jnp
|
6 |
from jax import random
|
7 |
from jax import grad
|
|
|
16 |
true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
|
17 |
f, params = sympy2jax(cosx, [x, y, z])
|
18 |
self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
|
19 |
+
def test_pipeline(self):
|
20 |
+
X = np.random.randn(100, 2)
|
21 |
+
equations = pd.DataFrame({
|
22 |
+
'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
|
23 |
+
'MSE': [1.0, 0.1, 1e-5],
|
24 |
+
'Complexity': [1, 2, 3]
|
25 |
+
})
|
26 |
+
|
27 |
+
equations['Complexity MSE Equation'.split(' ')].to_csv(
|
28 |
+
'equation_file.csv.bkup', sep='|')
|
29 |
+
|
30 |
+
equations = get_hof(
|
31 |
+
'equation_file.csv', n_features=2, variables_names='x0 x1'.split(' '),
|
32 |
+
extra_sympy_mappings={}, output_jax_format=True,
|
33 |
+
multioutput=False, nout=1)
|
34 |
+
|
35 |
+
jformat = equations.iloc[-1].jax_format
|
36 |
+
np.testing.assert_almost_equal(
|
37 |
+
np.array(jformat['callable'](jnp.array(X), jformat['parameters'])),
|
38 |
+
np.square(np.cos(X[:, 0]))
|
39 |
+
)
|
test/test_torch.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
from pysr import sympy2torch, get_hof
|
5 |
+
import torch
|
6 |
+
import sympy
|
7 |
+
|
8 |
+
class TestTorch(unittest.TestCase):
|
9 |
+
def test_sympy2torch(self):
|
10 |
+
x, y, z = sympy.symbols('x y z')
|
11 |
+
cosx = 1.0 * sympy.cos(x) + y
|
12 |
+
X = torch.randn((1000, 3))
|
13 |
+
true = 1.0 * torch.cos(X[:, 0]) + X[:, 1]
|
14 |
+
torch_module = sympy2torch(cosx, [x, y, z])
|
15 |
+
self.assertTrue(
|
16 |
+
np.all(np.isclose(torch_module(X).detach().numpy(), true.detach().numpy()))
|
17 |
+
)
|
18 |
+
def test_pipeline(self):
|
19 |
+
X = np.random.randn(100, 2)
|
20 |
+
equations = pd.DataFrame({
|
21 |
+
'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
|
22 |
+
'MSE': [1.0, 0.1, 1e-5],
|
23 |
+
'Complexity': [1, 2, 3]
|
24 |
+
})
|
25 |
+
|
26 |
+
equations['Complexity MSE Equation'.split(' ')].to_csv(
|
27 |
+
'equation_file.csv.bkup', sep='|')
|
28 |
+
|
29 |
+
equations = get_hof(
|
30 |
+
'equation_file.csv', n_features=2, variables_names='x0 x1'.split(' '),
|
31 |
+
extra_sympy_mappings={}, output_torch_format=True,
|
32 |
+
multioutput=False, nout=1)
|
33 |
+
|
34 |
+
tformat = equations.iloc[-1].torch_format
|
35 |
+
np.testing.assert_almost_equal(
|
36 |
+
tformat(torch.tensor(X)).detach().numpy(),
|
37 |
+
np.square(np.cos(X[:, 0]))
|
38 |
+
)
|