MilesCranmer commited on
Commit
a0c6429
1 Parent(s): 90d24f5

Fix JAX test

Browse files
Files changed (1) hide show
  1. test/test_jax.py +6 -7
test/test_jax.py CHANGED
@@ -1,6 +1,6 @@
1
  import unittest
2
  import numpy as np
3
- from pysr import sympy2jax, get_hof, PySRRegressor
4
  import pandas as pd
5
  from jax import numpy as jnp
6
  from jax import random
@@ -35,18 +35,17 @@ class TestJAX(unittest.TestCase):
35
  "equation_file.csv.bkup", sep="|"
36
  )
37
 
38
- equations = get_hof(
39
- "equation_file.csv",
40
- n_features=2,
41
- variables_names="x1 x2 x3".split(" "),
42
- extra_sympy_mappings={},
43
  output_jax_format=True,
 
44
  multioutput=False,
45
  nout=1,
46
  selection=[1, 2, 3],
47
  )
48
 
49
- model = PySRRegressor()
 
50
  jformat = model.jax()
51
 
52
  np.testing.assert_almost_equal(
 
1
  import unittest
2
  import numpy as np
3
+ from pysr import sympy2jax, PySRRegressor
4
  import pandas as pd
5
  from jax import numpy as jnp
6
  from jax import random
 
35
  "equation_file.csv.bkup", sep="|"
36
  )
37
 
38
+ model = PySRRegressor(
39
+ equation_file="equation_file.csv",
 
 
 
40
  output_jax_format=True,
41
+ variables_names="x1 x2 x3".split(" "),
42
  multioutput=False,
43
  nout=1,
44
  selection=[1, 2, 3],
45
  )
46
 
47
+ model.n_features = 2
48
+ model.refresh()
49
  jformat = model.jax()
50
 
51
  np.testing.assert_almost_equal(