File size: 3,719 Bytes
2f38c9c
41e5fd5
a0c6429
9bfcbfa
41e5fd5
 
 
ce5b119
41e5fd5
7d4300a
2f38c9c
51a6b05
 
 
2f38c9c
7d4300a
2f38c9c
 
 
 
 
 
7d4300a
c7187a6
 
fbb7cf7
 
4b56660
fbb7cf7
 
 
 
 
c7187a6
 
 
 
 
 
 
 
 
 
 
 
fbb7cf7
c7187a6
 
 
 
 
 
 
 
9bfcbfa
b07eb2d
fbb7cf7
4b56660
fbb7cf7
 
7d4300a
 
b444c7e
7d4300a
 
 
 
9bfcbfa
7d4300a
 
 
9bfcbfa
fbb7cf7
d398bf9
 
9bfcbfa
7d4300a
 
 
9bfcbfa
ce5b119
7cda629
beaf20b
 
7cda629
beaf20b
ce5b119
 
4b56660
7cda629
4b56660
7cda629
4b56660
7cda629
 
 
 
beaf20b
 
 
 
ce5b119
beaf20b
ce5b119
 
 
 
 
 
 
 
 
beaf20b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import unittest
from functools import partial

import numpy as np
import pandas as pd
import sympy
from jax import numpy as jnp
from jax import random
from pysr import sympy2jax, PySRRegressor


class TestJAX(unittest.TestCase):
    """Tests for PySR's JAX export: ``sympy2jax`` and ``PySRRegressor.jax()``."""

    # Name handed to ``model.refresh``. The hand-written equation table is
    # saved under this name plus a ``.bkup`` suffix.
    _CHECKPOINT_FILE = "equation_file.csv"

    def setUp(self):
        """Seed NumPy's global RNG so every test draws the same random data."""
        np.random.seed(0)

    @staticmethod
    def _remove_if_present(path):
        """Delete *path*, tolerating the case where it was never created."""
        try:
            os.remove(path)
        except FileNotFoundError:
            # The test failed before the file was written; nothing to undo.
            pass

    def _check_refreshed_export(self, X):
        """Fit on *X*, load a known equation table, and check the JAX output.

        After a short fit, a hand-written three-row table is written to the
        ``.bkup`` file beside ``_CHECKPOINT_FILE`` and the model is refreshed
        from ``_CHECKPOINT_FILE``. The callable returned by ``model.jax()``
        must then match ``square(cos(x1))`` -- the lowest-MSE row of that
        table -- computed with NumPy on column 1 of *X*.

        Args:
            X: Feature matrix with at least two columns, given either as a
                NumPy array or as a pandas DataFrame.
        """
        y = np.ones(X.shape[0])
        model = PySRRegressor(progress=False, max_evals=10000, output_jax_format=True)
        model.fit(X, y)

        equations = pd.DataFrame(
            {
                "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
                "MSE": [1.0, 0.1, 1e-5],
                "Complexity": [1, 2, 3],
            }
        )

        backup_path = self._CHECKPOINT_FILE + ".bkup"
        # Registered before writing so the working directory is left clean
        # whether the assertions below pass or fail.
        self.addCleanup(self._remove_if_present, backup_path)
        equations[["Complexity", "MSE", "Equation"]].to_csv(backup_path, sep="|")

        # NOTE(review): this relies on ``refresh`` picking up the ``.bkup``
        # sibling of the checkpoint file -- confirm against PySR if it changes.
        model.refresh(checkpoint_file=self._CHECKPOINT_FILE)
        jformat = model.jax()

        jax_values = jformat["callable"](jnp.array(X), jformat["parameters"])
        feature_1 = np.asarray(X)[:, 1]  # works for both ndarray and DataFrame
        np.testing.assert_almost_equal(
            np.array(jax_values),
            np.square(np.cos(feature_1)),
            decimal=4,
        )

    def test_sympy2jax(self):
        """``sympy2jax`` output must match the same expression written in jnp."""
        x, y, z = sympy.symbols("x y z")
        cosx = 1.0 * sympy.cos(x) + y
        key = random.PRNGKey(0)
        X = random.normal(key, (1000, 2))
        true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
        # ``z`` is listed as a symbol but does not appear in the expression.
        f, params = sympy2jax(cosx, [x, y, z])
        self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())

    def test_pipeline_pandas(self):
        """Run the refresh-and-export check with a pandas DataFrame input."""
        self._check_refreshed_export(pd.DataFrame(np.random.randn(100, 10)))

    def test_pipeline(self):
        """Run the refresh-and-export check with a plain NumPy array input."""
        self._check_refreshed_export(np.random.randn(100, 10))

    def test_feature_selection_custom_operators(self):
        """JAX export must agree with ``predict`` for a custom unary operator.

        The target depends on columns ``k15`` and ``k20`` only; the model is
        fitted with ``select_k_features=3`` and a user-defined ``cos_approx``
        operator that is mapped to both SymPy and JAX.
        """
        rstate = np.random.RandomState(0)
        X = pd.DataFrame({f"k{i}": rstate.randn(2000) for i in range(10, 21)})

        def cos_approx(x):
            # Not the exact cosine Taylor series (the x**6 term is added), but
            # it matches the Julia and JAX definitions passed to the model
            # below, which is all this test needs.
            return 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720

        y = X["k15"] ** 2 + 2 * cos_approx(X["k20"])

        model = PySRRegressor(
            progress=False,
            unary_operators=["cos_approx(x) = 1 - x^2 / 2 + x^4 / 24 + x^6 / 720"],
            select_k_features=3,
            maxsize=10,
            early_stop_condition=1e-5,
            extra_sympy_mappings={"cos_approx": cos_approx},
            extra_jax_mappings={
                "cos_approx": "(lambda x: 1 - x**2 / 2 + x**4 / 24 + x**6 / 720)"
            },
            random_state=0,
            deterministic=True,
            procs=0,
            multithreading=False,
        )
        # NOTE(review): re-seeded right before fitting, presumably so the fit
        # starts from a known global RNG state -- confirm it is still needed.
        np.random.seed(0)
        model.fit(X.values, y.values)

        # Look the entries up by key (as the pipeline tests do) instead of
        # relying on the ordering of the returned dict's values.
        jax_export = model.jax()
        jax_prediction = partial(
            jax_export["callable"], parameters=jax_export["parameters"]
        )

        np_output = model.predict(X.values)
        jax_output = jax_prediction(X.values)

        np.testing.assert_almost_equal(y.values, np_output, decimal=4)
        np.testing.assert_almost_equal(y.values, jax_output, decimal=4)