Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
f59f827
1
Parent(s):
d398bf9
Add .latex() representation to PySRRegressor
Browse files- pysr/sklearn.py +3 -0
- test/test.py +5 -0
pysr/sklearn.py
CHANGED
@@ -100,6 +100,9 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
100 |
def sympy(self):
|
101 |
return self.get_best()["sympy_format"]
|
102 |
|
|
|
|
|
|
|
103 |
def jax(self):
|
104 |
self.equations = get_hof(output_jax_format=True)
|
105 |
return self.get_best()["jax_format"]
|
|
|
100 |
def sympy(self):
|
101 |
return self.get_best()["sympy_format"]
|
102 |
|
103 |
+
def latex(self):
|
104 |
+
return self.sympy().simplify()
|
105 |
+
|
106 |
def jax(self):
|
107 |
self.equations = get_hof(output_jax_format=True)
|
108 |
return self.get_best()["jax_format"]
|
test/test.py
CHANGED
@@ -193,13 +193,18 @@ class TestBest(unittest.TestCase):
|
|
193 |
nout=1,
|
194 |
)
|
195 |
|
|
|
|
|
|
|
196 |
def test_best(self):
|
197 |
self.assertEqual(best(self.equations), sympy.cos(sympy.Symbol("x0")) ** 2)
|
198 |
self.assertEqual(best(), sympy.cos(sympy.Symbol("x0")) ** 2)
|
|
|
199 |
|
200 |
def test_best_tex(self):
|
201 |
self.assertEqual(best_tex(self.equations), "\\cos^{2}{\\left(x_{0} \\right)}")
|
202 |
self.assertEqual(best_tex(), "\\cos^{2}{\\left(x_{0} \\right)}")
|
|
|
203 |
|
204 |
def test_best_lambda(self):
|
205 |
X = np.random.randn(10, 2)
|
|
|
193 |
nout=1,
|
194 |
)
|
195 |
|
196 |
+
self.model = PySRRegressor()
|
197 |
+
self.model.equations = self.equations
|
198 |
+
|
199 |
def test_best(self):
|
200 |
self.assertEqual(best(self.equations), sympy.cos(sympy.Symbol("x0")) ** 2)
|
201 |
self.assertEqual(best(), sympy.cos(sympy.Symbol("x0")) ** 2)
|
202 |
+
self.assertEqual(self.model.sympy(), sympy.cos(sympy.Symbol("x0")) ** 2)
|
203 |
|
204 |
def test_best_tex(self):
|
205 |
self.assertEqual(best_tex(self.equations), "\\cos^{2}{\\left(x_{0} \\right)}")
|
206 |
self.assertEqual(best_tex(), "\\cos^{2}{\\left(x_{0} \\right)}")
|
207 |
+
self.assertEqual(self.model.latex(), "\\cos^{2}{\\left(x_{0} \\right)}")
|
208 |
|
209 |
def test_best_lambda(self):
|
210 |
X = np.random.randn(10, 2)
|