File size: 1,620 Bytes
2f38c9c
41e5fd5
a0c6429
9bfcbfa
41e5fd5
 
 
 
 
7d4300a
2f38c9c
51a6b05
 
 
2f38c9c
7d4300a
2f38c9c
 
 
 
 
 
7d4300a
9bfcbfa
b07eb2d
7d4300a
 
b444c7e
7d4300a
 
 
 
9bfcbfa
7d4300a
 
 
9bfcbfa
a0c6429
 
7d4300a
a0c6429
7d4300a
 
 
 
9bfcbfa
a0c6429
b444c7e
a0c6429
d398bf9
 
9bfcbfa
7d4300a
 
 
9bfcbfa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import os
import tempfile
import unittest
import numpy as np
from pysr import sympy2jax, PySRRegressor
import pandas as pd
from jax import numpy as jnp
from jax import random
from jax import grad
import sympy


class TestJAX(unittest.TestCase):
    def setUp(self):
        np.random.seed(0)

    def test_sympy2jax(self):
        x, y, z = sympy.symbols("x y z")
        cosx = 1.0 * sympy.cos(x) + y
        key = random.PRNGKey(0)
        X = random.normal(key, (1000, 2))
        true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
        f, params = sympy2jax(cosx, [x, y, z])
        self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())

    def test_pipeline(self):
        X = np.random.randn(100, 10)
        equations = pd.DataFrame(
            {
                "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
                "MSE": [1.0, 0.1, 1e-5],
                "Complexity": [1, 2, 3],
            }
        )

        equations["Complexity MSE Equation".split(" ")].to_csv(
            "equation_file.csv.bkup", sep="|"
        )

        model = PySRRegressor(
            equation_file="equation_file.csv",
            output_jax_format=True,
            variables_names="x1 x2 x3".split(" "),
            multioutput=False,
            nout=1,
            selection=[1, 2, 3],
        )

        model.n_features = 2
        model.using_pandas = False
        model.refresh()
        jformat = model.jax()

        np.testing.assert_almost_equal(
            np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
            np.square(np.cos(X[:, 1])),  # Select feature 1
            decimal=4,
        )