MilesCranmer committed on
Commit
97e6589
1 Parent(s): a626763

Add test for feature selection

Browse files
Files changed (1) hide show
  1. test/test.py +26 -0
test/test.py CHANGED
@@ -1,6 +1,7 @@
1
  import unittest
2
  import numpy as np
3
  from pysr import pysr, get_hof, best, best_tex, best_callable
 
4
  import sympy
5
  import pandas as pd
6
 
@@ -72,3 +73,28 @@ class TestBest(unittest.TestCase):
72
  y = np.cos(X[:, 0])**2
73
  for f in [best_callable(), best_callable(self.equations)]:
74
  np.testing.assert_almost_equal(f(X), y)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import unittest
2
  import numpy as np
3
  from pysr import pysr, get_hof, best, best_tex, best_callable
4
+ from pysr.sr import run_feature_selection, _handle_feature_selection
5
  import sympy
6
  import pandas as pd
7
 
 
73
  y = np.cos(X[:, 0])**2
74
  for f in [best_callable(), best_callable(self.equations)]:
75
  np.testing.assert_almost_equal(f(X), y)
76
+
77
+
78
class TestFeatureSelection(unittest.TestCase):
    """Tests for the feature-selection helpers imported from ``pysr.sr``."""

    def test_feature_selection(self):
        """Selecting 2 of 5 features should return the indices of columns 2 and 3."""
        np.random.seed(0)  # fixed seed so the random inputs are reproducible
        X = np.random.randn(20001, 5)
        # The target depends only on columns 2 and 3.
        y = X[:, 2]**2 + X[:, 3]**2
        selected = run_feature_selection(X, y, select_k_features=2)
        # Sort so the check does not depend on the order of the result.
        self.assertEqual(sorted(selected), [2, 3])

    def test_feature_selection_handler(self):
        """The handler should return both the reduced X and the matching names."""
        np.random.seed(0)  # fixed seed so the random inputs are reproducible
        X = np.random.randn(20000, 5)
        # The target depends only on columns 2 and 3.
        y = X[:, 2]**2 + X[:, 3]**2
        var_names = [f'x{i}' for i in range(5)]
        selected_X, selected_var_names = _handle_feature_selection(
            X, select_k_features=2,
            use_custom_variable_names=True,
            variable_names=var_names,
            y=y)
        self.assertEqual(set(selected_var_names), {'x2', 'x3'})
        # Sort within each row so the comparison does not depend on the
        # order in which the selected columns are returned.
        np.testing.assert_array_equal(
            np.sort(selected_X, axis=1),
            np.sort(X[:, [2, 3]], axis=1)
        )