Spaces:
Running
Running
File size: 3,719 Bytes
2f38c9c 41e5fd5 a0c6429 9bfcbfa 41e5fd5 ce5b119 41e5fd5 7d4300a 2f38c9c 51a6b05 2f38c9c 7d4300a 2f38c9c 7d4300a c7187a6 fbb7cf7 4b56660 fbb7cf7 c7187a6 fbb7cf7 c7187a6 9bfcbfa b07eb2d fbb7cf7 4b56660 fbb7cf7 7d4300a b444c7e 7d4300a 9bfcbfa 7d4300a 9bfcbfa fbb7cf7 d398bf9 9bfcbfa 7d4300a 9bfcbfa ce5b119 7cda629 beaf20b 7cda629 beaf20b ce5b119 4b56660 7cda629 4b56660 7cda629 4b56660 7cda629 beaf20b ce5b119 beaf20b ce5b119 beaf20b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import unittest
import numpy as np
from pysr import sympy2jax, PySRRegressor
import pandas as pd
from jax import numpy as jnp
from jax import random
import sympy
from functools import partial
class TestJAX(unittest.TestCase):
    """Tests for exporting symbolic expressions to JAX callables."""

    def setUp(self):
        """Seed NumPy's global RNG before every test."""
        np.random.seed(0)

    @staticmethod
    def _write_fake_checkpoint():
        """Write a fixed three-equation table to ``equation_file.csv.bkup``.

        Columns are written in the order Complexity, MSE, Equation, using
        ``|`` as the separator.
        """
        # NOTE(review): the tests refresh from "equation_file.csv" while this
        # writes the ".bkup" sibling -- presumably the model reads the backup
        # file when it is present; confirm against the PySR version in use.
        column_order = "Complexity MSE Equation".split(" ")
        fake_equations = pd.DataFrame(
            {
                "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
                "MSE": [1.0, 0.1, 1e-5],
                "Complexity": [1, 2, 3],
            }
        )
        fake_equations[column_order].to_csv("equation_file.csv.bkup", sep="|")

    def test_sympy2jax(self):
        """Convert ``cos(x) + y`` with ``sympy2jax`` and compare to jax.numpy."""
        x, y, z = sympy.symbols("x y z")
        expression = 1.0 * sympy.cos(x) + y

        rng_key = random.PRNGKey(0)
        inputs = random.normal(rng_key, (1000, 2))
        expected = 1.0 * jnp.cos(inputs[:, 0]) + inputs[:, 1]

        # z is declared as a symbol but does not appear in the expression.
        jax_fn, jax_params = sympy2jax(expression, [x, y, z])
        matches = jnp.isclose(jax_fn(inputs, jax_params), expected)
        self.assertTrue(jnp.all(matches).item())

    def test_pipeline_pandas(self):
        """Fit on a DataFrame, load known equations, and check the JAX export.

        The exported callable is expected to match ``square(cos(column 1))``
        to four decimals.
        """
        features = pd.DataFrame(np.random.randn(100, 10))
        targets = np.ones(features.shape[0])
        regressor = PySRRegressor(
            progress=False,
            max_evals=10000,
            output_jax_format=True,
        )
        regressor.fit(features, targets)

        # Replace whatever the search produced with a known equation table.
        self._write_fake_checkpoint()
        regressor.refresh(checkpoint_file="equation_file.csv")

        exported = regressor.jax()
        predicted = np.array(
            exported["callable"](jnp.array(features), exported["parameters"])
        )
        expected = np.square(np.cos(features.values[:, 1]))  # Select feature 1
        np.testing.assert_almost_equal(predicted, expected, decimal=4)

    def test_pipeline(self):
        """Same as the pandas pipeline test, but with a plain NumPy array."""
        features = np.random.randn(100, 10)
        targets = np.ones(features.shape[0])
        regressor = PySRRegressor(
            progress=False, max_evals=10000, output_jax_format=True
        )
        regressor.fit(features, targets)

        # Replace whatever the search produced with a known equation table.
        self._write_fake_checkpoint()
        regressor.refresh(checkpoint_file="equation_file.csv")

        exported = regressor.jax()
        predicted = np.array(
            exported["callable"](jnp.array(features), exported["parameters"])
        )
        expected = np.square(np.cos(features[:, 1]))  # Select feature 1
        np.testing.assert_almost_equal(predicted, expected, decimal=4)

    def test_feature_selection_custom_operators(self):
        """Fit with feature selection and a custom operator, then compare outputs.

        The target is ``k15**2 + 2 * cos_approx(k20)``. Both ``model.predict``
        and the exported JAX callable must reproduce it to four decimals.
        """
        rng = np.random.RandomState(0)
        frame = pd.DataFrame({f"k{i}": rng.randn(2000) for i in range(10, 21)})

        # The same polynomial is supplied three ways below: as the operator
        # definition string, as the sympy mapping, and as the JAX mapping.
        cos_approx = lambda x: 1 - (x**2) / 2 + (x**4) / 24 + (x**6) / 720
        target = frame["k15"] ** 2 + 2 * cos_approx(frame["k20"])

        regressor = PySRRegressor(
            progress=False,
            unary_operators=["cos_approx(x) = 1 - x^2 / 2 + x^4 / 24 + x^6 / 720"],
            select_k_features=3,
            maxsize=10,
            early_stop_condition=1e-5,
            extra_sympy_mappings={"cos_approx": cos_approx},
            extra_jax_mappings={
                "cos_approx": "(lambda x: 1 - x**2 / 2 + x**4 / 24 + x**6 / 720)"
            },
            random_state=0,
            deterministic=True,
            procs=0,
            multithreading=False,
        )
        np.random.seed(0)
        regressor.fit(frame.values, target.values)

        jax_fn, jax_params = regressor.jax().values()
        jax_predict = partial(jax_fn, parameters=jax_params)

        np_output = regressor.predict(frame.values)
        jax_output = jax_predict(frame.values)

        np.testing.assert_almost_equal(target.values, np_output, decimal=4)
        np.testing.assert_almost_equal(target.values, jax_output, decimal=4)
|