MilesCranmer commited on
Commit
3ef2b32
1 Parent(s): fab6f87

Display output variable in table of expressions

Browse files
Files changed (2) hide show
  1. pysr/export_latex.py +12 -2
  2. test/test.py +19 -19
pysr/export_latex.py CHANGED
@@ -65,6 +65,7 @@ def generate_single_table(
65
  precision: int = 3,
66
  columns=["equation", "complexity", "loss", "score"],
67
  max_equation_length: int = 50,
 
68
  ):
69
  """Generate a booktabs-style LaTeX table for a single set of equations."""
70
  assert isinstance(equations, pd.DataFrame)
@@ -96,7 +97,9 @@ def generate_single_table(
96
  for col in columns:
97
  if col == "equation":
98
  if len(latex_equation) < max_equation_length:
99
- row_pieces.append("$" + latex_equation + "$")
 
 
100
  else:
101
  if not raised_long_equation_warning:
102
  warnings.warn(
@@ -109,7 +112,7 @@ def generate_single_table(
109
  r"\begin{minipage}{0.8\linewidth}",
110
  r"\vspace{-1em}",
111
  r"\begin{dmath*}",
112
- latex_equation,
113
  r"\end{dmath*}",
114
  r"\end{minipage}",
115
  ]
@@ -137,8 +140,10 @@ def generate_multiple_tables(
137
  indices: List[List[int]] = None,
138
  precision: int = 3,
139
  columns=["equation", "complexity", "loss", "score"],
 
140
  ):
141
  """Generate multiple latex tables for a list of equation sets."""
 
142
 
143
  latex_tables = [
144
  generate_single_table(
@@ -146,6 +151,11 @@ def generate_multiple_tables(
146
  (None if not indices else indices[i]),
147
  precision=precision,
148
  columns=columns,
 
 
 
 
 
149
  )
150
  for i in range(len(equations))
151
  ]
 
65
  precision: int = 3,
66
  columns=["equation", "complexity", "loss", "score"],
67
  max_equation_length: int = 50,
68
+ output_variable_name: str = "y",
69
  ):
70
  """Generate a booktabs-style LaTeX table for a single set of equations."""
71
  assert isinstance(equations, pd.DataFrame)
 
97
  for col in columns:
98
  if col == "equation":
99
  if len(latex_equation) < max_equation_length:
100
+ row_pieces.append(
101
+ "$" + output_variable_name + " = " + latex_equation + "$"
102
+ )
103
  else:
104
  if not raised_long_equation_warning:
105
  warnings.warn(
 
112
  r"\begin{minipage}{0.8\linewidth}",
113
  r"\vspace{-1em}",
114
  r"\begin{dmath*}",
115
+ output_variable_name + " = " + latex_equation,
116
  r"\end{dmath*}",
117
  r"\end{minipage}",
118
  ]
 
140
  indices: List[List[int]] = None,
141
  precision: int = 3,
142
  columns=["equation", "complexity", "loss", "score"],
143
+ output_variable_names: str = None,
144
  ):
145
  """Generate multiple latex tables for a list of equation sets."""
146
+ # TODO: Let user specify custom output variable
147
 
148
  latex_tables = [
149
  generate_single_table(
 
151
  (None if not indices else indices[i]),
152
  precision=precision,
153
  columns=columns,
154
+ output_variable_name=(
155
+ "y_{" + str(i) + "}"
156
+ if output_variable_names is None
157
+ else output_variable_names[i]
158
+ ),
159
  )
160
  for i in range(len(equations))
161
  ]
test/test.py CHANGED
@@ -553,9 +553,9 @@ class TestLaTeXTable(unittest.TestCase):
553
  columns=["equation", "complexity", "loss"]
554
  )
555
  middle_part = r"""
556
- $x_{0}$ & $1$ & $1.05$ \\
557
- $\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ \\
558
- $x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
559
  """
560
  true_latex_table_str = self.create_true_latex(middle_part)
561
  self.assertEqual(latex_table_str, true_latex_table_str)
@@ -565,9 +565,9 @@ class TestLaTeXTable(unittest.TestCase):
565
  precision=5, columns=["equation", "complexity", "loss"]
566
  )
567
  middle_part = r"""
568
- $x_{0}$ & $1$ & $1.0520$ \\
569
- $\cos{\left(x_{0} \right)}$ & $2$ & $0.023150$ \\
570
- $x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.1235 \cdot 10^{-15}$ \\
571
  """
572
  true_latex_table_str = self.create_true_latex(middle_part)
573
  self.assertEqual(latex_table_str, true_latex_table_str)
@@ -575,9 +575,9 @@ class TestLaTeXTable(unittest.TestCase):
575
  def test_include_score(self):
576
  latex_table_str = self.model.latex_table()
577
  middle_part = r"""
578
- $x_{0}$ & $1$ & $1.05$ & $0.0$ \\
579
- $\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
580
- $x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
581
  """
582
  true_latex_table_str = self.create_true_latex(middle_part, include_score=True)
583
  self.assertEqual(latex_table_str, true_latex_table_str)
@@ -587,7 +587,7 @@ class TestLaTeXTable(unittest.TestCase):
587
  indices=[2], columns=["equation", "complexity", "loss"]
588
  )
589
  middle_part = r"""
590
- $x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
591
  """
592
  true_latex_table_str = self.create_true_latex(middle_part)
593
  self.assertEqual(latex_table_str, true_latex_table_str)
@@ -610,14 +610,14 @@ class TestLaTeXTable(unittest.TestCase):
610
  equations = [equations1, equations2]
611
  model = manually_create_model(equations)
612
  middle_part_1 = r"""
613
- $x_{0}$ & $1$ & $1.05$ & $0.0$ \\
614
- $\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
615
- $x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
616
  """
617
  middle_part_2 = r"""
618
- $x_{1}$ & $1$ & $1.32$ & $0.0$ \\
619
- $\cos{\left(x_{1} \right)}$ & $2$ & $0.0520$ & $3.23$ \\
620
- $x_{0}^{2} x_{1}$ & $5$ & $2.00 \cdot 10^{-15}$ & $10.3$ \\
621
  """
622
  true_latex_table_str = "\n\n".join(
623
  self.create_true_latex(part, include_score=True)
@@ -667,9 +667,9 @@ class TestLaTeXTable(unittest.TestCase):
667
  model = manually_create_model(equations)
668
  latex_table_str = model.latex_table()
669
  middle_part = r"""
670
- $x_{0}$ & $1$ & $1.05$ & $0.0$ \\
671
- $\cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
672
- \begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} x_{0}^{5} + x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} - 5.20 \sin{\left(2.60 x_{0} - 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
673
  """
674
  true_latex_table_str = self.create_true_latex(middle_part, include_score=True)
675
  self.assertEqual(latex_table_str, true_latex_table_str)
 
553
  columns=["equation", "complexity", "loss"]
554
  )
555
  middle_part = r"""
556
+ $y = x_{0}$ & $1$ & $1.05$ \\
557
+ $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ \\
558
+ $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
559
  """
560
  true_latex_table_str = self.create_true_latex(middle_part)
561
  self.assertEqual(latex_table_str, true_latex_table_str)
 
565
  precision=5, columns=["equation", "complexity", "loss"]
566
  )
567
  middle_part = r"""
568
+ $y = x_{0}$ & $1$ & $1.0520$ \\
569
+ $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.023150$ \\
570
+ $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.1235 \cdot 10^{-15}$ \\
571
  """
572
  true_latex_table_str = self.create_true_latex(middle_part)
573
  self.assertEqual(latex_table_str, true_latex_table_str)
 
575
  def test_include_score(self):
576
  latex_table_str = self.model.latex_table()
577
  middle_part = r"""
578
+ $y = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
579
+ $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
580
+ $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
581
  """
582
  true_latex_table_str = self.create_true_latex(middle_part, include_score=True)
583
  self.assertEqual(latex_table_str, true_latex_table_str)
 
587
  indices=[2], columns=["equation", "complexity", "loss"]
588
  )
589
  middle_part = r"""
590
+ $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
591
  """
592
  true_latex_table_str = self.create_true_latex(middle_part)
593
  self.assertEqual(latex_table_str, true_latex_table_str)
 
610
  equations = [equations1, equations2]
611
  model = manually_create_model(equations)
612
  middle_part_1 = r"""
613
+ $y_{0} = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
614
+ $y_{0} = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
615
+ $y_{0} = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
616
  """
617
  middle_part_2 = r"""
618
+ $y_{1} = x_{1}$ & $1$ & $1.32$ & $0.0$ \\
619
+ $y_{1} = \cos{\left(x_{1} \right)}$ & $2$ & $0.0520$ & $3.23$ \\
620
+ $y_{1} = x_{0}^{2} x_{1}$ & $5$ & $2.00 \cdot 10^{-15}$ & $10.3$ \\
621
  """
622
  true_latex_table_str = "\n\n".join(
623
  self.create_true_latex(part, include_score=True)
 
667
  model = manually_create_model(equations)
668
  latex_table_str = model.latex_table()
669
  middle_part = r"""
670
+ $y = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
671
+ $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
672
+ \begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0}^{5} + x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} - 5.20 \sin{\left(2.60 x_{0} - 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
673
  """
674
  true_latex_table_str = self.create_true_latex(middle_part, include_score=True)
675
  self.assertEqual(latex_table_str, true_latex_table_str)