Spaces:
Running
Running
File size: 1,620 Bytes
2f38c9c 41e5fd5 a0c6429 9bfcbfa 41e5fd5 7d4300a 2f38c9c 51a6b05 2f38c9c 7d4300a 2f38c9c 7d4300a 9bfcbfa b07eb2d 7d4300a b444c7e 7d4300a 9bfcbfa 7d4300a 9bfcbfa a0c6429 7d4300a a0c6429 7d4300a 9bfcbfa a0c6429 b444c7e a0c6429 d398bf9 9bfcbfa 7d4300a 9bfcbfa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import os
import tempfile
import unittest

import numpy as np
from pysr import sympy2jax, PySRRegressor
import pandas as pd
from jax import numpy as jnp
from jax import random
from jax import grad
import sympy
class TestJAX(unittest.TestCase):
    """Tests for PySR's JAX export: ``sympy2jax`` and ``PySRRegressor.jax()``."""

    def setUp(self):
        # Seed NumPy's global RNG so the random inputs in test_pipeline are
        # reproducible across runs.
        np.random.seed(0)

    def test_sympy2jax(self):
        """A SymPy expression converted by ``sympy2jax`` matches direct JAX math."""
        x, y, z = sympy.symbols("x y z")
        cosx = 1.0 * sympy.cos(x) + y
        key = random.PRNGKey(0)
        X = random.normal(key, (1000, 2))
        true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
        f, params = sympy2jax(cosx, [x, y, z])
        # assert_allclose reports which elements mismatch on failure, unlike
        # assertTrue(all(isclose(...))). The tolerances and NaN handling are
        # set to jnp.isclose's defaults so the pass/fail criterion is unchanged.
        np.testing.assert_allclose(
            np.asarray(f(X, params)),
            np.asarray(true),
            rtol=1e-5,
            atol=1e-8,
            equal_nan=False,
        )

    def test_pipeline(self):
        """``PySRRegressor.jax()`` exports an equation loaded from a saved table.

        A hand-written equation table is written to disk, the model is
        refreshed from it (without fitting), and the exported JAX callable is
        compared against the NumPy evaluation of ``square(cos(x1))``.
        """
        X = np.random.randn(100, 10)
        equations = pd.DataFrame(
            {
                "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
                "MSE": [1.0, 0.1, 1e-5],
                "Complexity": [1, 2, 3],
            }
        )
        # Work inside a temporary directory so the test leaves no files in
        # the current working directory and cannot collide with parallel runs.
        with tempfile.TemporaryDirectory() as tmpdir:
            equation_file = os.path.join(tmpdir, "equation_file.csv")
            # The model is given `equation_file` while the table is written to
            # `equation_file + ".bkup"`; this keeps the original test's naming.
            equations["Complexity MSE Equation".split(" ")].to_csv(
                equation_file + ".bkup", sep="|"
            )
            model = PySRRegressor(
                equation_file=equation_file,
                output_jax_format=True,
                variables_names="x1 x2 x3".split(" "),
                multioutput=False,
                nout=1,
                selection=[1, 2, 3],
            )
            # NOTE(review): these attributes are presumably what fit() would
            # normally populate; they are set by hand because refresh() is
            # called without fitting. Confirm against PySRRegressor.
            model.n_features = 2
            model.using_pandas = False
            model.refresh()
            jformat = model.jax()
            np.testing.assert_almost_equal(
                np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
                # Select feature 1: x1 is expected to map to X[:, 1] through
                # `selection=[1, 2, 3]`.
                np.square(np.cos(X[:, 1])),
                decimal=4,
            )
|