Spaces:
Running
Running
fix: symbolic numbers in jax
Browse files- pysr/export_jax.py +3 -1
- pysr/test/test_jax.py +36 -8
pysr/export_jax.py
CHANGED
@@ -55,7 +55,9 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
|
|
55 |
if issubclass(expr.func, sympy.Float):
|
56 |
parameters.append(float(expr))
|
57 |
return f"parameters[{len(parameters) - 1}]"
|
58 |
-
elif issubclass(expr.func, sympy.Rational)
|
|
|
|
|
59 |
return f"{float(expr)}"
|
60 |
elif issubclass(expr.func, sympy.Integer):
|
61 |
return f"{int(expr)}"
|
|
|
55 |
if issubclass(expr.func, sympy.Float):
|
56 |
parameters.append(float(expr))
|
57 |
return f"parameters[{len(parameters) - 1}]"
|
58 |
+
elif issubclass(expr.func, sympy.Rational) or issubclass(
|
59 |
+
expr.func, sympy.NumberSymbol
|
60 |
+
):
|
61 |
return f"{float(expr)}"
|
62 |
elif issubclass(expr.func, sympy.Integer):
|
63 |
return f"{int(expr)}"
|
pysr/test/test_jax.py
CHANGED
@@ -5,27 +5,29 @@ import numpy as np
|
|
5 |
import pandas as pd
|
6 |
import sympy
|
7 |
|
|
|
8 |
from pysr import PySRRegressor, sympy2jax
|
9 |
|
10 |
|
11 |
class TestJAX(unittest.TestCase):
|
12 |
def setUp(self):
|
13 |
np.random.seed(0)
|
|
|
|
|
|
|
14 |
|
15 |
def test_sympy2jax(self):
|
16 |
-
from jax import numpy as jnp
|
17 |
from jax import random
|
18 |
|
19 |
x, y, z = sympy.symbols("x y z")
|
20 |
cosx = 1.0 * sympy.cos(x) + y
|
21 |
key = random.PRNGKey(0)
|
22 |
X = random.normal(key, (1000, 2))
|
23 |
-
true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
|
24 |
f, params = sympy2jax(cosx, [x, y, z])
|
25 |
-
self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
|
26 |
|
27 |
def test_pipeline_pandas(self):
|
28 |
-
from jax import numpy as jnp
|
29 |
|
30 |
X = pd.DataFrame(np.random.randn(100, 10))
|
31 |
y = np.ones(X.shape[0])
|
@@ -52,14 +54,12 @@ class TestJAX(unittest.TestCase):
|
|
52 |
jformat = model.jax()
|
53 |
|
54 |
np.testing.assert_almost_equal(
|
55 |
-
np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
|
56 |
np.square(np.cos(X.values[:, 1])), # Select feature 1
|
57 |
decimal=3,
|
58 |
)
|
59 |
|
60 |
def test_pipeline(self):
|
61 |
-
from jax import numpy as jnp
|
62 |
-
|
63 |
X = np.random.randn(100, 10)
|
64 |
y = np.ones(X.shape[0])
|
65 |
model = PySRRegressor(progress=False, max_evals=10000, output_jax_format=True)
|
@@ -81,11 +81,39 @@ class TestJAX(unittest.TestCase):
|
|
81 |
jformat = model.jax()
|
82 |
|
83 |
np.testing.assert_almost_equal(
|
84 |
-
np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
|
85 |
np.square(np.cos(X[:, 1])), # Select feature 1
|
86 |
decimal=3,
|
87 |
)
|
88 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
89 |
def test_feature_selection_custom_operators(self):
|
90 |
rstate = np.random.RandomState(0)
|
91 |
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
|
|
|
5 |
import pandas as pd
|
6 |
import sympy
|
7 |
|
8 |
+
import pysr
|
9 |
from pysr import PySRRegressor, sympy2jax
|
10 |
|
11 |
|
12 |
class TestJAX(unittest.TestCase):
|
13 |
def setUp(self):
|
14 |
np.random.seed(0)
|
15 |
+
from jax import numpy as jnp
|
16 |
+
|
17 |
+
self.jnp = jnp
|
18 |
|
19 |
def test_sympy2jax(self):
|
|
|
20 |
from jax import random
|
21 |
|
22 |
x, y, z = sympy.symbols("x y z")
|
23 |
cosx = 1.0 * sympy.cos(x) + y
|
24 |
key = random.PRNGKey(0)
|
25 |
X = random.normal(key, (1000, 2))
|
26 |
+
true = 1.0 * self.jnp.cos(X[:, 0]) + X[:, 1]
|
27 |
f, params = sympy2jax(cosx, [x, y, z])
|
28 |
+
self.assertTrue(self.jnp.all(self.jnp.isclose(f(X, params), true)).item())
|
29 |
|
30 |
def test_pipeline_pandas(self):
|
|
|
31 |
|
32 |
X = pd.DataFrame(np.random.randn(100, 10))
|
33 |
y = np.ones(X.shape[0])
|
|
|
54 |
jformat = model.jax()
|
55 |
|
56 |
np.testing.assert_almost_equal(
|
57 |
+
np.array(jformat["callable"](self.jnp.array(X), jformat["parameters"])),
|
58 |
np.square(np.cos(X.values[:, 1])), # Select feature 1
|
59 |
decimal=3,
|
60 |
)
|
61 |
|
62 |
def test_pipeline(self):
|
|
|
|
|
63 |
X = np.random.randn(100, 10)
|
64 |
y = np.ones(X.shape[0])
|
65 |
model = PySRRegressor(progress=False, max_evals=10000, output_jax_format=True)
|
|
|
81 |
jformat = model.jax()
|
82 |
|
83 |
np.testing.assert_almost_equal(
|
84 |
+
np.array(jformat["callable"](self.jnp.array(X), jformat["parameters"])),
|
85 |
np.square(np.cos(X[:, 1])), # Select feature 1
|
86 |
decimal=3,
|
87 |
)
|
88 |
|
89 |
+
def test_avoid_simplification(self):
|
90 |
+
ex = pysr.export_sympy.pysr2sympy(
|
91 |
+
"square(exp(sign(0.44796443))) + 1.5 * x1",
|
92 |
+
feature_names_in=["x1"],
|
93 |
+
extra_sympy_mappings={"square": lambda x: x**2},
|
94 |
+
)
|
95 |
+
f, params = pysr.export_jax.sympy2jax(ex, [sympy.symbols("x1")])
|
96 |
+
key = np.random.RandomState(0)
|
97 |
+
X = key.randn(10, 1)
|
98 |
+
np.testing.assert_almost_equal(
|
99 |
+
np.array(f(self.jnp.array(X), params)),
|
100 |
+
np.square(np.exp(np.sign(0.44796443))) + 1.5 * X[:, 0],
|
101 |
+
decimal=3,
|
102 |
+
)
|
103 |
+
|
104 |
+
def test_issue_656(self):
|
105 |
+
import sympy
|
106 |
+
|
107 |
+
E_plus_x1 = sympy.exp(1) + sympy.symbols("x1")
|
108 |
+
f, params = pysr.export_jax.sympy2jax(E_plus_x1, [sympy.symbols("x1")])
|
109 |
+
key = np.random.RandomState(0)
|
110 |
+
X = key.randn(10, 1)
|
111 |
+
np.testing.assert_almost_equal(
|
112 |
+
np.array(f(self.jnp.array(X), params)),
|
113 |
+
np.exp(1) + X[:, 0],
|
114 |
+
decimal=3,
|
115 |
+
)
|
116 |
+
|
117 |
def test_feature_selection_custom_operators(self):
|
118 |
rstate = np.random.RandomState(0)
|
119 |
X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})
|