MilesCranmer commited on
Commit
d398bf9
·
1 Parent(s): c9f1ebd

Add PySRRegressor versions of jax/torch tests

Browse files
Files changed (2) hide show
  1. test/test_jax.py +4 -2
  2. test/test_torch.py +6 -3
test/test_jax.py CHANGED
@@ -1,6 +1,6 @@
1
  import unittest
2
  import numpy as np
3
- from pysr import sympy2jax, get_hof
4
  import pandas as pd
5
  from jax import numpy as jnp
6
  from jax import random
@@ -46,7 +46,9 @@ class TestJAX(unittest.TestCase):
46
  selection=[1, 2, 3],
47
  )
48
 
49
- jformat = equations.iloc[-1].jax_format
 
 
50
  np.testing.assert_almost_equal(
51
  np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
52
  np.square(np.cos(X[:, 1])), # Select feature 1
 
1
  import unittest
2
  import numpy as np
3
+ from pysr import sympy2jax, get_hof, PySRRegressor
4
  import pandas as pd
5
  from jax import numpy as jnp
6
  from jax import random
 
46
  selection=[1, 2, 3],
47
  )
48
 
49
+ model = PySRRegressor()
50
+ jformat = model.jax()
51
+
52
  np.testing.assert_almost_equal(
53
  np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
54
  np.square(np.cos(X[:, 1])), # Select feature 1
test/test_torch.py CHANGED
@@ -1,7 +1,7 @@
1
  import unittest
2
  import numpy as np
3
  import pandas as pd
4
- from pysr import sympy2torch, get_hof
5
  import torch
6
  import sympy
7
 
@@ -84,7 +84,7 @@ class TestTorch(unittest.TestCase):
84
  "equation_file_custom_operator.csv.bkup", sep="|"
85
  )
86
 
87
- equations = get_hof(
88
  "equation_file_custom_operator.csv",
89
  n_features=3,
90
  variables_names="x1 x2 x3".split(" "),
@@ -96,7 +96,10 @@ class TestTorch(unittest.TestCase):
96
  selection=[0, 1, 2],
97
  )
98
 
99
- tformat = equations.iloc[-1].torch_format
 
 
 
100
  np.testing.assert_almost_equal(
101
  tformat(torch.tensor(X)).detach().numpy(),
102
  np.sin(X[:, 0]), # Selection 1st feature
 
1
  import unittest
2
  import numpy as np
3
  import pandas as pd
4
+ from pysr import sympy2torch, get_hof, PySRRegressor
5
  import torch
6
  import sympy
7
 
 
84
  "equation_file_custom_operator.csv.bkup", sep="|"
85
  )
86
 
87
+ get_hof(
88
  "equation_file_custom_operator.csv",
89
  n_features=3,
90
  variables_names="x1 x2 x3".split(" "),
 
96
  selection=[0, 1, 2],
97
  )
98
 
99
+ model = PySRRegressor()
100
+ # Will automatically use the set global state from get_hof.
101
+ tformat = model.pytorch()
102
+
103
  np.testing.assert_almost_equal(
104
  tformat(torch.tensor(X)).detach().numpy(),
105
  np.sin(X[:, 0]), # Selection 1st feature