MilesCranmer commited on
Commit
c6b30e7
·
1 Parent(s): 38bbf68

Change latex table back to public function

Browse files
Files changed (2) hide show
  1. pysr/sr.py +4 -1
  2. test/test.py +4 -4
pysr/sr.py CHANGED
@@ -2000,7 +2000,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2000
  return ret_outputs
2001
  return ret_outputs[0]
2002
 
2003
- def _latex_table(self, indices=None, precision=3, include_score=False):
2004
  """Create a LaTeX/booktabs table for all, or some, of the equations.
2005
 
2006
  Parameters
@@ -2042,9 +2042,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2042
  equations = self.equations_
2043
 
2044
  if isinstance(indices[0], int):
 
2045
  indices = [indices]
2046
  equations = [equations]
2047
 
 
 
2048
  latex_equations = [
2049
  [to_latex(eq, prec=precision) for eq in equation_set["sympy_format"]]
2050
  for equation_set in equations
 
2000
  return ret_outputs
2001
  return ret_outputs[0]
2002
 
2003
+ def latex_table(self, indices=None, precision=3, include_score=False):
2004
  """Create a LaTeX/booktabs table for all, or some, of the equations.
2005
 
2006
  Parameters
 
2042
  equations = self.equations_
2043
 
2044
  if isinstance(indices[0], int):
2045
+ assert self.nout_ == 1, "For multiple outputs, pass a list of lists."
2046
  indices = [indices]
2047
  equations = [equations]
2048
 
2049
+ assert len(indices) == self.nout_
2050
+
2051
  latex_equations = [
2052
  [to_latex(eq, prec=precision) for eq in equation_set["sympy_format"]]
2053
  for equation_set in equations
test/test.py CHANGED
@@ -538,7 +538,7 @@ class TestLaTeXTable(unittest.TestCase):
538
  model = manually_create_model(equations)
539
 
540
  # Regular table:
541
- latex_table_str = model._latex_table()
542
  middle_part = r"""
543
  $x_{0}$ & $1$ & $1.05$ \\
544
  $\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ \\
@@ -548,7 +548,7 @@ class TestLaTeXTable(unittest.TestCase):
548
  self.assertEqual(latex_table_str, true_latex_table_str)
549
 
550
  # Different precision:
551
- latex_table_str = model._latex_table(precision=5)
552
  middle_part = r"""
553
  $x_{0}$ & $1$ & $1.0520$ \\
554
  $\cos{\left(x_{0} \right)}$ & $2$ & $0.023150$ \\
@@ -558,7 +558,7 @@ class TestLaTeXTable(unittest.TestCase):
558
  self.assertEqual(latex_table_str, self.create_true_latex(middle_part))
559
 
560
  # Including score:
561
- latex_table_str = model._latex_table(include_score=True)
562
  middle_part = r"""
563
  $x_{0}$ & $1$ & $1.05$ & $0.0$ \\
564
  $\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
@@ -568,7 +568,7 @@ class TestLaTeXTable(unittest.TestCase):
568
  self.assertEqual(latex_table_str, true_latex_table_str)
569
 
570
  # Only last equation:
571
- latex_table_str = model._latex_table(indices=[2])
572
  middle_part = r"""
573
  $x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
574
  """
 
538
  model = manually_create_model(equations)
539
 
540
  # Regular table:
541
+ latex_table_str = model.latex_table()
542
  middle_part = r"""
543
  $x_{0}$ & $1$ & $1.05$ \\
544
  $\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ \\
 
548
  self.assertEqual(latex_table_str, true_latex_table_str)
549
 
550
  # Different precision:
551
+ latex_table_str = model.latex_table(precision=5)
552
  middle_part = r"""
553
  $x_{0}$ & $1$ & $1.0520$ \\
554
  $\cos{\left(x_{0} \right)}$ & $2$ & $0.023150$ \\
 
558
  self.assertEqual(latex_table_str, self.create_true_latex(middle_part))
559
 
560
  # Including score:
561
+ latex_table_str = model.latex_table(include_score=True)
562
  middle_part = r"""
563
  $x_{0}$ & $1$ & $1.05$ & $0.0$ \\
564
  $\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
 
568
  self.assertEqual(latex_table_str, true_latex_table_str)
569
 
570
  # Only last equation:
571
+ latex_table_str = model.latex_table(indices=[2])
572
  middle_part = r"""
573
  $x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
574
  """