MilesCranmer commited on
Commit
6a4fa2c
1 Parent(s): 97e6589

Fix issue with lambda getting redefined; add test

Browse files
Files changed (2) hide show
  1. pysr/sr.py +15 -2
  2. test/test.py +28 -4
pysr/sr.py CHANGED
@@ -61,6 +61,19 @@ sympy_mappings = {
61
  'gamma': lambda x : sympy.gamma(x),
62
  }
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  def pysr(X, y, weights=None,
65
  binary_operators=None,
66
  unary_operators=None,
@@ -774,8 +787,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
774
  if output_jax_format:
775
  func, params = sympy2jax(eqn, sympy_symbols)
776
  jax_format.append({'callable': func, 'parameters': params})
777
- tmp_lambda = lambdify(sympy_symbols, eqn)
778
- lambda_format.append(lambda X: tmp_lambda(*X.T))
779
  curMSE = output.loc[i, 'MSE']
780
  curComplexity = output.loc[i, 'Complexity']
781
 
 
61
  'gamma': lambda x : sympy.gamma(x),
62
  }
63
 
64
+ class CallableEquation(object):
65
+ """Simple wrapper for numpy lambda functions built with sympy"""
66
+ def __init__(self, sympy_symbols, eqn):
67
+ self._sympy = eqn
68
+ self._sympy_symbols = sympy_symbols
69
+ self._lambda = lambdify(sympy_symbols, eqn)
70
+
71
+ def __repr__(self):
72
+ return f"PySRFunction(X=>{self._sympy})"
73
+
74
+ def __call__(self, X):
75
+ return self._lambda(*X.T)
76
+
77
  def pysr(X, y, weights=None,
78
  binary_operators=None,
79
  unary_operators=None,
 
787
  if output_jax_format:
788
  func, params = sympy2jax(eqn, sympy_symbols)
789
  jax_format.append({'callable': func, 'parameters': params})
790
+
791
+ lambda_format.append(CallableEquation(sympy_symbols, eqn))
792
  curMSE = output.loc[i, 'MSE']
793
  curComplexity = output.loc[i, 'Complexity']
794
 
test/test.py CHANGED
@@ -1,8 +1,9 @@
1
  import unittest
2
  import numpy as np
3
- from pysr import pysr, get_hof, best, best_tex, best_callable
4
  from pysr.sr import run_feature_selection, _handle_feature_selection
5
  import sympy
 
6
  import pandas as pd
7
 
8
  class TestPipeline(unittest.TestCase):
@@ -27,12 +28,36 @@ class TestPipeline(unittest.TestCase):
27
  y = self.X[:, [0, 1]]**2
28
  equations = pysr(self.X, y,
29
  unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
30
- extra_sympy_mappings={'square': lambda x: x**2},
31
- **self.default_test_kwargs)
 
32
  print(equations)
33
  self.assertLessEqual(equations[0].iloc[-1]['MSE'], 1e-4)
34
  self.assertLessEqual(equations[1].iloc[-1]['MSE'], 1e-4)
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  def test_empty_operators_single_input(self):
37
  X = np.random.randn(100, 1)
38
  y = X[:, 0] + 3.0
@@ -40,7 +65,6 @@ class TestPipeline(unittest.TestCase):
40
  unary_operators=[], binary_operators=["plus"],
41
  **self.default_test_kwargs)
42
 
43
- print(equations)
44
  self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
45
 
46
  class TestBest(unittest.TestCase):
 
1
  import unittest
2
  import numpy as np
3
+ from pysr import pysr, get_hof, best, best_tex, best_callable, best_row
4
  from pysr.sr import run_feature_selection, _handle_feature_selection
5
  import sympy
6
+ from sympy import lambdify
7
  import pandas as pd
8
 
9
  class TestPipeline(unittest.TestCase):
 
28
  y = self.X[:, [0, 1]]**2
29
  equations = pysr(self.X, y,
30
  unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
31
+ extra_sympy_mappings={'sq': lambda x: x**2},
32
+ **self.default_test_kwargs,
33
+ procs=0)
34
  print(equations)
35
  self.assertLessEqual(equations[0].iloc[-1]['MSE'], 1e-4)
36
  self.assertLessEqual(equations[1].iloc[-1]['MSE'], 1e-4)
37
 
38
+ def test_multioutput_weighted_with_callable(self):
39
+ y = self.X[:, [0, 1]]**2
40
+ w = np.random.rand(*y.shape)
41
+ w[w < 0.5] = 0.0
42
+ w[w >= 0.5] = 1.0
43
+
44
+ # Double equation when weights are 0:
45
+ y += (1-w) * y
46
+ # Thus, pysr needs to use the weights to find the right equation!
47
+
48
+ equations = pysr(self.X, y, weights=w,
49
+ unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
50
+ extra_sympy_mappings={'sq': lambda x: x**2},
51
+ **self.default_test_kwargs,
52
+ procs=0)
53
+
54
+ np.testing.assert_almost_equal(
55
+ best_callable()[0](self.X),
56
+ self.X[:, 0]**2)
57
+ np.testing.assert_almost_equal(
58
+ best_callable()[1](self.X),
59
+ self.X[:, 1]**2)
60
+
61
  def test_empty_operators_single_input(self):
62
  X = np.random.randn(100, 1)
63
  y = X[:, 0] + 3.0
 
65
  unary_operators=[], binary_operators=["plus"],
66
  **self.default_test_kwargs)
67
 
 
68
  self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
69
 
70
  class TestBest(unittest.TestCase):