MilesCranmer commited on
Commit
2f38c9c
1 Parent(s): 253dd65

Move testing to unittest format

Browse files
Files changed (2) hide show
  1. test/test.py +35 -31
  2. test/test_jax.py +10 -8
test/test.py CHANGED
@@ -1,38 +1,42 @@
 
1
  import numpy as np
2
  from pysr import pysr
3
  import sympy
4
- X = np.random.randn(100, 5)
5
 
6
- default_test_kwargs = dict(
7
- niterations=10,
8
- populations=4,
9
- user_input=False,
10
- annealing=True,
11
- useFrequency=False,
12
- )
 
 
 
 
 
 
 
 
 
 
13
 
14
- print("Test 1 - defaults; simple linear relation")
15
- y = X[:, 0]
16
- equations = pysr(X, y, **default_test_kwargs)
17
- print(equations)
18
- assert equations.iloc[-1]['MSE'] < 1e-4
 
 
 
 
19
 
20
- print("Test 2 - test custom operator, and multiple outputs")
21
- y = X[:, [0, 1]]**2
22
- equations = pysr(X, y,
23
- unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
24
- extra_sympy_mappings={'square': lambda x: x**2},
25
- **default_test_kwargs)
26
- print(equations)
27
- assert equations[0].iloc[-1]['MSE'] < 1e-4
28
- assert equations[1].iloc[-1]['MSE'] < 1e-4
29
 
30
- X = np.random.randn(100, 1)
31
- y = X[:, 0] + 3.0
32
- print("Test 3 - empty operator list, and single dimension input")
33
- equations = pysr(X, y,
34
- unary_operators=[], binary_operators=["plus"],
35
- **default_test_kwargs)
36
-
37
- print(equations)
38
- assert equations.iloc[-1]['MSE'] < 1e-4
 
1
+ import unittest
2
  import numpy as np
3
  from pysr import pysr
4
  import sympy
 
5
 
6
+ class TestPipeline(unittest.TestCase):
7
+ def setUp(self):
8
+ self.default_test_kwargs = dict(
9
+ niterations=10,
10
+ populations=4,
11
+ user_input=False,
12
+ annealing=True,
13
+ useFrequency=False,
14
+ )
15
+ np.random.seed(0)
16
+ self.X = np.random.randn(100, 5)
17
+
18
+ def test_linear_relation(self):
19
+ y = self.X[:, 0]
20
+ equations = pysr(self.X, y, **self.default_test_kwargs)
21
+ print(equations)
22
+ self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
23
 
24
+ def test_multioutput_custom_operator(self):
25
+ y = self.X[:, [0, 1]]**2
26
+ equations = pysr(self.X, y,
27
+ unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
28
+ extra_sympy_mappings={'square': lambda x: x**2},
29
+ **self.default_test_kwargs)
30
+ print(equations)
31
+ self.assertLessEqual(equations[0].iloc[-1]['MSE'], 1e-4)
32
+ self.assertLessEqual(equations[1].iloc[-1]['MSE'], 1e-4)
33
 
34
+ def test_empty_operators_single_input(self):
35
+ X = np.random.randn(100, 1)
36
+ y = X[:, 0] + 3.0
37
+ equations = pysr(X, y,
38
+ unary_operators=[], binary_operators=["plus"],
39
+ **self.default_test_kwargs)
 
 
 
40
 
41
+ print(equations)
42
+ self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
 
 
 
 
 
 
 
test/test_jax.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import numpy as np
2
  from pysr import pysr, sympy2jax
3
  from jax import numpy as jnp
@@ -5,11 +6,12 @@ from jax import random
5
  from jax import grad
6
  import sympy
7
 
8
- print("Test JAX 1 - test export")
9
- x, y, z = sympy.symbols('x y z')
10
- cosx = 1.0 * sympy.cos(x) + y
11
- key = random.PRNGKey(0)
12
- X = random.normal(key, (1000, 2))
13
- true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
14
- f, params = sympy2jax(cosx, [x, y, z])
15
- assert jnp.all(jnp.isclose(f(X, params), true)).item()
 
 
1
+ import unittest
2
  import numpy as np
3
  from pysr import pysr, sympy2jax
4
  from jax import numpy as jnp
 
6
  from jax import grad
7
  import sympy
8
 
9
+ class TestJAX(unittest.TestCase):
10
+ def test_sympy2jax(self):
11
+ x, y, z = sympy.symbols('x y z')
12
+ cosx = 1.0 * sympy.cos(x) + y
13
+ key = random.PRNGKey(0)
14
+ X = random.normal(key, (1000, 2))
15
+ true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
16
+ f, params = sympy2jax(cosx, [x, y, z])
17
+ self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())