MilesCranmer commited on
Commit
1adfa85
1 Parent(s): 4d915b2

Test best() functions

Browse files
Files changed (1) hide show
  1. test/test.py +31 -1
test/test.py CHANGED
@@ -1,7 +1,8 @@
1
  import unittest
2
  import numpy as np
3
- from pysr import pysr
4
  import sympy
 
5
 
6
  class TestPipeline(unittest.TestCase):
7
  def setUp(self):
@@ -40,3 +41,32 @@ class TestPipeline(unittest.TestCase):
40
 
41
  print(equations)
42
  self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import unittest
2
  import numpy as np
3
+ from pysr import pysr, get_hof, best, best_tex, best_callable
4
  import sympy
5
+ import pandas as pd
6
 
7
  class TestPipeline(unittest.TestCase):
8
  def setUp(self):
 
41
 
42
  print(equations)
43
  self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
44
+
45
+ class TestBest(unittest.TestCase):
46
+ def setUp(self):
47
+ equations = pd.DataFrame({
48
+ 'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
49
+ 'MSE': [1.0, 0.1, 1e-5],
50
+ 'Complexity': [1, 2, 3]
51
+ })
52
+
53
+ equations['Complexity MSE Equation'.split(' ')].to_csv(
54
+ 'equation_file.csv.bkup', sep='|')
55
+
56
+ self.equations = get_hof(
57
+ 'equation_file.csv', n_features=2,
58
+ variables_names='x0 x1'.split(' '),
59
+ extra_sympy_mappings={}, output_jax_format=False,
60
+ multioutput=False, nout=1)
61
+
62
+ def test_best(self):
63
+ self.assertEqual(best(), sympy.cos(sympy.Symbol('x0'))**2)
64
+
65
+ def test_best_tex(self):
66
+ self.assertEqual(best_tex(), '\\cos^{2}{\\left(x_{0} \\right)}')
67
+
68
+ def test_best_lambda(self):
69
+ f = best_callable()
70
+ X = np.random.randn(10, 2)
71
+ y = np.cos(X[:, 0])**2
72
+ np.testing.assert_almost_equal(f(X), y)