Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
a0c6429
1
Parent(s):
90d24f5
Fix JAX test
Browse files- test/test_jax.py +6 -7
test/test_jax.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
import unittest
|
2 |
import numpy as np
|
3 |
-
from pysr import sympy2jax,
|
4 |
import pandas as pd
|
5 |
from jax import numpy as jnp
|
6 |
from jax import random
|
@@ -35,18 +35,17 @@ class TestJAX(unittest.TestCase):
|
|
35 |
"equation_file.csv.bkup", sep="|"
|
36 |
)
|
37 |
|
38 |
-
|
39 |
-
"equation_file.csv",
|
40 |
-
n_features=2,
|
41 |
-
variables_names="x1 x2 x3".split(" "),
|
42 |
-
extra_sympy_mappings={},
|
43 |
output_jax_format=True,
|
|
|
44 |
multioutput=False,
|
45 |
nout=1,
|
46 |
selection=[1, 2, 3],
|
47 |
)
|
48 |
|
49 |
-
model =
|
|
|
50 |
jformat = model.jax()
|
51 |
|
52 |
np.testing.assert_almost_equal(
|
|
|
1 |
import unittest
|
2 |
import numpy as np
|
3 |
+
from pysr import sympy2jax, PySRRegressor
|
4 |
import pandas as pd
|
5 |
from jax import numpy as jnp
|
6 |
from jax import random
|
|
|
35 |
"equation_file.csv.bkup", sep="|"
|
36 |
)
|
37 |
|
38 |
+
model = PySRRegressor(
|
39 |
+
equation_file="equation_file.csv",
|
|
|
|
|
|
|
40 |
output_jax_format=True,
|
41 |
+
variables_names="x1 x2 x3".split(" "),
|
42 |
multioutput=False,
|
43 |
nout=1,
|
44 |
selection=[1, 2, 3],
|
45 |
)
|
46 |
|
47 |
+
model.n_features = 2
|
48 |
+
model.refresh()
|
49 |
jformat = model.jax()
|
50 |
|
51 |
np.testing.assert_almost_equal(
|