MilesCranmer commited on
Commit
3752ba6
1 Parent(s): 0dbee97

Include necessary packages in latex_table()

Browse files
Files changed (3) hide show
  1. pysr/export_latex.py +0 -10
  2. pysr/sr.py +10 -1
  3. test/test.py +33 -5
pysr/export_latex.py CHANGED
@@ -6,9 +6,6 @@ from typing import List
6
  import warnings
7
 
8
 
9
- raised_long_equation_warning = False
10
-
11
-
12
  class PreciseLatexPrinter(LatexPrinter):
13
  """Modified SymPy printer with custom float precision."""
14
 
@@ -70,8 +67,6 @@ def generate_single_table(
70
  """Generate a booktabs-style LaTeX table for a single set of equations."""
71
  assert isinstance(equations, pd.DataFrame)
72
 
73
- global raised_long_equation_warning
74
-
75
  latex_top, latex_bottom = generate_table_environment(columns)
76
  latex_table_content = []
77
 
@@ -101,11 +96,6 @@ def generate_single_table(
101
  "$" + output_variable_name + " = " + latex_equation + "$"
102
  )
103
  else:
104
- if not raised_long_equation_warning:
105
- warnings.warn(
106
- "Please add \\usepackage{breqn} to your LaTeX preamble."
107
- )
108
- raised_long_equation_warning = True
109
 
110
  broken_latex_equation = " ".join(
111
  [
 
6
  import warnings
7
 
8
 
 
 
 
9
  class PreciseLatexPrinter(LatexPrinter):
10
  """Modified SymPy printer with custom float precision."""
11
 
 
67
  """Generate a booktabs-style LaTeX table for a single set of equations."""
68
  assert isinstance(equations, pd.DataFrame)
69
 
 
 
70
  latex_top, latex_bottom = generate_table_environment(columns)
71
  latex_table_content = []
72
 
 
96
  "$" + output_variable_name + " = " + latex_equation + "$"
97
  )
98
  else:
 
 
 
 
 
99
 
100
  broken_latex_equation = " ".join(
101
  [
pysr/sr.py CHANGED
@@ -2197,9 +2197,18 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2197
 
2198
  generator_fnc = generate_single_table
2199
 
2200
- return generator_fnc(
2201
  self.equations_, indices=indices, precision=precision, columns=columns
2202
  )
 
 
 
 
 
 
 
 
 
2203
 
2204
 
2205
  def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int:
 
2197
 
2198
  generator_fnc = generate_single_table
2199
 
2200
+ table_string = generator_fnc(
2201
  self.equations_, indices=indices, precision=precision, columns=columns
2202
  )
2203
+ preamble_string = [
2204
+ r"\usepackage{breqn}",
2205
+ r"\usepackage{booktabs}",
2206
+ r"\usepackage{tabularx}",
2207
+ "",
2208
+ "...",
2209
+ "",
2210
+ ]
2211
+ return "\n".join(preamble_string + [table_string])
2212
 
2213
 
2214
  def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int:
test/test.py CHANGED
@@ -608,6 +608,18 @@ class TestMiscellaneous(unittest.TestCase):
608
  self.assertEqual(len(exception_messages), 0)
609
 
610
 
 
 
 
 
 
 
 
 
 
 
 
 
611
  class TestLaTeXTable(unittest.TestCase):
612
  def setUp(self):
613
  equations = pd.DataFrame(
@@ -618,6 +630,7 @@ class TestLaTeXTable(unittest.TestCase):
618
  )
619
  )
620
  self.model = manually_create_model(equations)
 
621
 
622
  def create_true_latex(self, middle_part, include_score=False):
623
  if include_score:
@@ -657,7 +670,9 @@ class TestLaTeXTable(unittest.TestCase):
657
  $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ \\
658
  $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
659
  """
660
- true_latex_table_str = self.create_true_latex(middle_part)
 
 
661
  self.assertEqual(latex_table_str, true_latex_table_str)
662
 
663
  def test_other_precision(self):
@@ -669,7 +684,9 @@ class TestLaTeXTable(unittest.TestCase):
669
  $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.023150$ \\
670
  $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.1235 \cdot 10^{-15}$ \\
671
  """
672
- true_latex_table_str = self.create_true_latex(middle_part)
 
 
673
  self.assertEqual(latex_table_str, true_latex_table_str)
674
 
675
  def test_include_score(self):
@@ -679,7 +696,11 @@ class TestLaTeXTable(unittest.TestCase):
679
  $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
680
  $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
681
  """
682
- true_latex_table_str = self.create_true_latex(middle_part, include_score=True)
 
 
 
 
683
  self.assertEqual(latex_table_str, true_latex_table_str)
684
 
685
  def test_last_equation(self):
@@ -689,7 +710,9 @@ class TestLaTeXTable(unittest.TestCase):
689
  middle_part = r"""
690
  $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
691
  """
692
- true_latex_table_str = self.create_true_latex(middle_part)
 
 
693
  self.assertEqual(latex_table_str, true_latex_table_str)
694
 
695
  def test_multi_output(self):
@@ -723,6 +746,7 @@ class TestLaTeXTable(unittest.TestCase):
723
  self.create_true_latex(part, include_score=True)
724
  for part in [middle_part_1, middle_part_2]
725
  )
 
726
  latex_table_str = model.latex_table()
727
 
728
  self.assertEqual(latex_table_str, true_latex_table_str)
@@ -771,5 +795,9 @@ class TestLaTeXTable(unittest.TestCase):
771
  $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
772
  \begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0}^{5} + x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} - 5.20 \sin{\left(2.60 x_{0} - 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
773
  """
774
- true_latex_table_str = self.create_true_latex(middle_part, include_score=True)
 
 
 
 
775
  self.assertEqual(latex_table_str, true_latex_table_str)
 
608
  self.assertEqual(len(exception_messages), 0)
609
 
610
 
611
+ TRUE_PREAMBLE = "\n".join(
612
+ [
613
+ r"\usepackage{breqn}",
614
+ r"\usepackage{booktabs}",
615
+ r"\usepackage{tabularx}",
616
+ "",
617
+ "...",
618
+ "",
619
+ ]
620
+ )
621
+
622
+
623
  class TestLaTeXTable(unittest.TestCase):
624
  def setUp(self):
625
  equations = pd.DataFrame(
 
630
  )
631
  )
632
  self.model = manually_create_model(equations)
633
+ self.maxDiff = None
634
 
635
  def create_true_latex(self, middle_part, include_score=False):
636
  if include_score:
 
670
  $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ \\
671
  $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
672
  """
673
+ true_latex_table_str = (
674
+ TRUE_PREAMBLE + "\n" + self.create_true_latex(middle_part)
675
+ )
676
  self.assertEqual(latex_table_str, true_latex_table_str)
677
 
678
  def test_other_precision(self):
 
684
  $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.023150$ \\
685
  $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.1235 \cdot 10^{-15}$ \\
686
  """
687
+ true_latex_table_str = (
688
+ TRUE_PREAMBLE + "\n" + self.create_true_latex(middle_part)
689
+ )
690
  self.assertEqual(latex_table_str, true_latex_table_str)
691
 
692
  def test_include_score(self):
 
696
  $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
697
  $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
698
  """
699
+ true_latex_table_str = (
700
+ TRUE_PREAMBLE
701
+ + "\n"
702
+ + self.create_true_latex(middle_part, include_score=True)
703
+ )
704
  self.assertEqual(latex_table_str, true_latex_table_str)
705
 
706
  def test_last_equation(self):
 
710
  middle_part = r"""
711
  $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
712
  """
713
+ true_latex_table_str = (
714
+ TRUE_PREAMBLE + "\n" + self.create_true_latex(middle_part)
715
+ )
716
  self.assertEqual(latex_table_str, true_latex_table_str)
717
 
718
  def test_multi_output(self):
 
746
  self.create_true_latex(part, include_score=True)
747
  for part in [middle_part_1, middle_part_2]
748
  )
749
+ true_latex_table_str = TRUE_PREAMBLE + "\n" + true_latex_table_str
750
  latex_table_str = model.latex_table()
751
 
752
  self.assertEqual(latex_table_str, true_latex_table_str)
 
795
  $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
796
  \begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0}^{5} + x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} - 5.20 \sin{\left(2.60 x_{0} - 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
797
  """
798
+ true_latex_table_str = (
799
+ TRUE_PREAMBLE
800
+ + "\n"
801
+ + self.create_true_latex(middle_part, include_score=True)
802
+ )
803
  self.assertEqual(latex_table_str, true_latex_table_str)