MilesCranmer commited on
Commit
0c1c3db
2 Parent(s): bb99ca5 ce64294

Merge pull request #156 from MilesCranmer/latex-table

Browse files
Files changed (3) hide show
  1. pysr/export_latex.py +153 -0
  2. pysr/sr.py +65 -3
  3. test/test.py +234 -19
pysr/export_latex.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Functions to help export PySR equations to LaTeX."""
2
+ import sympy
3
+ from sympy.printing.latex import LatexPrinter
4
+ import pandas as pd
5
+ from typing import List
6
+ import warnings
7
+
8
+
9
+ class PreciseLatexPrinter(LatexPrinter):
10
+ """Modified SymPy printer with custom float precision."""
11
+
12
+ def __init__(self, settings=None, prec=3):
13
+ super().__init__(settings)
14
+ self.prec = prec
15
+
16
+ def _print_Float(self, expr):
17
+ # Reduce precision of float:
18
+ reduced_float = sympy.Float(expr, self.prec)
19
+ return super()._print_Float(reduced_float)
20
+
21
+
22
+ def to_latex(expr, prec=3, full_prec=True, **settings):
23
+ """Convert sympy expression to LaTeX with custom precision."""
24
+ settings["full_prec"] = full_prec
25
+ printer = PreciseLatexPrinter(settings=settings, prec=prec)
26
+ return printer.doprint(expr)
27
+
28
+
29
+ def generate_table_environment(columns=["equation", "complexity", "loss"]):
30
+ margins = "c" * len(columns)
31
+ column_map = {
32
+ "complexity": "Complexity",
33
+ "loss": "Loss",
34
+ "equation": "Equation",
35
+ "score": "Score",
36
+ }
37
+ columns = [column_map[col] for col in columns]
38
+ top_pieces = [
39
+ r"\begin{table}[h]",
40
+ r"\begin{center}",
41
+ r"\begin{tabular}{@{}" + margins + r"@{}}",
42
+ r"\toprule",
43
+ " & ".join(columns) + r" \\",
44
+ r"\midrule",
45
+ ]
46
+
47
+ bottom_pieces = [
48
+ r"\bottomrule",
49
+ r"\end{tabular}",
50
+ r"\end{center}",
51
+ r"\end{table}",
52
+ ]
53
+ top_latex_table = "\n".join(top_pieces)
54
+ bottom_latex_table = "\n".join(bottom_pieces)
55
+
56
+ return top_latex_table, bottom_latex_table
57
+
58
+
59
+ def generate_single_table(
60
+ equations: pd.DataFrame,
61
+ indices: List[int] = None,
62
+ precision: int = 3,
63
+ columns=["equation", "complexity", "loss", "score"],
64
+ max_equation_length: int = 50,
65
+ output_variable_name: str = "y",
66
+ ):
67
+ """Generate a booktabs-style LaTeX table for a single set of equations."""
68
+ assert isinstance(equations, pd.DataFrame)
69
+
70
+ latex_top, latex_bottom = generate_table_environment(columns)
71
+ latex_table_content = []
72
+
73
+ if indices is None:
74
+ indices = range(len(equations))
75
+
76
+ for i in indices:
77
+ latex_equation = to_latex(
78
+ equations.iloc[i]["sympy_format"],
79
+ prec=precision,
80
+ )
81
+ complexity = str(equations.iloc[i]["complexity"])
82
+ loss = to_latex(
83
+ sympy.Float(equations.iloc[i]["loss"]),
84
+ prec=precision,
85
+ )
86
+ score = to_latex(
87
+ sympy.Float(equations.iloc[i]["score"]),
88
+ prec=precision,
89
+ )
90
+
91
+ row_pieces = []
92
+ for col in columns:
93
+ if col == "equation":
94
+ if len(latex_equation) < max_equation_length:
95
+ row_pieces.append(
96
+ "$" + output_variable_name + " = " + latex_equation + "$"
97
+ )
98
+ else:
99
+
100
+ broken_latex_equation = " ".join(
101
+ [
102
+ r"\begin{minipage}{0.8\linewidth}",
103
+ r"\vspace{-1em}",
104
+ r"\begin{dmath*}",
105
+ output_variable_name + " = " + latex_equation,
106
+ r"\end{dmath*}",
107
+ r"\end{minipage}",
108
+ ]
109
+ )
110
+ row_pieces.append(broken_latex_equation)
111
+
112
+ elif col == "complexity":
113
+ row_pieces.append("$" + complexity + "$")
114
+ elif col == "loss":
115
+ row_pieces.append("$" + loss + "$")
116
+ elif col == "score":
117
+ row_pieces.append("$" + score + "$")
118
+ else:
119
+ raise ValueError(f"Unknown column: {col}")
120
+
121
+ latex_table_content.append(
122
+ " & ".join(row_pieces) + r" \\",
123
+ )
124
+
125
+ return "\n".join([latex_top, *latex_table_content, latex_bottom])
126
+
127
+
128
+ def generate_multiple_tables(
129
+ equations: List[pd.DataFrame],
130
+ indices: List[List[int]] = None,
131
+ precision: int = 3,
132
+ columns=["equation", "complexity", "loss", "score"],
133
+ output_variable_names: str = None,
134
+ ):
135
+ """Generate multiple latex tables for a list of equation sets."""
136
+ # TODO: Let user specify custom output variable
137
+
138
+ latex_tables = [
139
+ generate_single_table(
140
+ equations[i],
141
+ (None if not indices else indices[i]),
142
+ precision=precision,
143
+ columns=columns,
144
+ output_variable_name=(
145
+ "y_{" + str(i) + "}"
146
+ if output_variable_names is None
147
+ else output_variable_names[i]
148
+ ),
149
+ )
150
+ for i in range(len(equations))
151
+ ]
152
+
153
+ return "\n\n".join(latex_tables)
pysr/sr.py CHANGED
@@ -29,6 +29,7 @@ from .julia_helpers import (
29
  import_error_string,
30
  )
31
  from .export_numpy import CallableEquation
 
32
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
33
 
34
 
@@ -1875,7 +1876,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1875
  return [eq["sympy_format"] for eq in best_equation]
1876
  return best_equation["sympy_format"]
1877
 
1878
- def latex(self, index=None):
1879
  """
1880
  Return latex representation of the equation(s) chosen by `model_selection`.
1881
 
@@ -1887,6 +1888,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1887
  the `model_selection` parameter. If there are multiple output
1888
  features, then pass a list of indices with the order the same
1889
  as the output feature.
 
 
 
1890
 
1891
  Returns
1892
  -------
@@ -1896,8 +1900,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1896
  self.refresh()
1897
  sympy_representation = self.sympy(index=index)
1898
  if self.nout_ > 1:
1899
- return [sympy.latex(s) for s in sympy_representation]
1900
- return sympy.latex(sympy_representation)
 
 
 
 
1901
 
1902
  def jax(self, index=None):
1903
  """
@@ -2147,6 +2155,60 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2147
  return ret_outputs
2148
  return ret_outputs[0]
2149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2150
 
2151
  def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int:
2152
  """
 
29
  import_error_string,
30
  )
31
  from .export_numpy import CallableEquation
32
+ from .export_latex import generate_single_table, generate_multiple_tables, to_latex
33
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
34
 
35
 
 
1876
  return [eq["sympy_format"] for eq in best_equation]
1877
  return best_equation["sympy_format"]
1878
 
1879
+ def latex(self, index=None, precision=3):
1880
  """
1881
  Return latex representation of the equation(s) chosen by `model_selection`.
1882
 
 
1888
  the `model_selection` parameter. If there are multiple output
1889
  features, then pass a list of indices with the order the same
1890
  as the output feature.
1891
+ precision : int, default=3
1892
+ The number of significant figures shown in the LaTeX
1893
+ representation.
1894
 
1895
  Returns
1896
  -------
 
1900
  self.refresh()
1901
  sympy_representation = self.sympy(index=index)
1902
  if self.nout_ > 1:
1903
+ output = []
1904
+ for s in sympy_representation:
1905
+ latex = to_latex(s, prec=precision)
1906
+ output.append(latex)
1907
+ return output
1908
+ return to_latex(sympy_representation, prec=precision)
1909
 
1910
  def jax(self, index=None):
1911
  """
 
2155
  return ret_outputs
2156
  return ret_outputs[0]
2157
 
2158
+ def latex_table(
2159
+ self,
2160
+ indices=None,
2161
+ precision=3,
2162
+ columns=["equation", "complexity", "loss", "score"],
2163
+ ):
2164
+ """Create a LaTeX/booktabs table for all, or some, of the equations.
2165
+
2166
+ Parameters
2167
+ ----------
2168
+ indices : list[int] | list[list[int]], default=None
2169
+ If you wish to select a particular subset of equations from
2170
+ `self.equations_`, give the row numbers here. By default,
2171
+ all equations will be used. If there are multiple output
2172
+ features, then pass a list of lists.
2173
+ precision : int, default=3
2174
+ The number of significant figures shown in the LaTeX
2175
+ representations.
2176
+ columns : list[str], default=["equation", "complexity", "loss", "score"]
2177
+ Which columns to include in the table.
2178
+
2179
+ Returns
2180
+ -------
2181
+ latex_table_str : str
2182
+ A string that will render a table in LaTeX of the equations.
2183
+ """
2184
+ self.refresh()
2185
+
2186
+ if self.nout_ > 1:
2187
+ if indices is not None:
2188
+ assert isinstance(indices, list)
2189
+ assert isinstance(indices[0], list)
2190
+ assert isinstance(len(indices), self.nout_)
2191
+
2192
+ generator_fnc = generate_multiple_tables
2193
+ else:
2194
+ if indices is not None:
2195
+ assert isinstance(indices, list)
2196
+ assert isinstance(indices[0], int)
2197
+
2198
+ generator_fnc = generate_single_table
2199
+
2200
+ table_string = generator_fnc(
2201
+ self.equations_, indices=indices, precision=precision, columns=columns
2202
+ )
2203
+ preamble_string = [
2204
+ r"\usepackage{breqn}",
2205
+ r"\usepackage{booktabs}",
2206
+ "",
2207
+ "...",
2208
+ "",
2209
+ ]
2210
+ return "\n".join(preamble_string + [table_string])
2211
+
2212
 
2213
  def idx_model_selection(equations: pd.DataFrame, model_selection: str) -> int:
2214
  """
test/test.py CHANGED
@@ -11,6 +11,7 @@ from pysr.sr import (
11
  _csv_filename_to_pkl_filename,
12
  idx_model_selection,
13
  )
 
14
  from sklearn.utils.estimator_checks import check_estimator
15
  import sympy
16
  import pandas as pd
@@ -353,19 +354,49 @@ class TestPipeline(unittest.TestCase):
353
  np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))
354
 
355
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  class TestBest(unittest.TestCase):
357
  def setUp(self):
358
  self.rstate = np.random.RandomState(0)
359
  self.X = self.rstate.randn(10, 2)
360
  self.y = np.cos(self.X[:, 0]) ** 2
361
- self.model = PySRRegressor(
362
- progress=False,
363
- niterations=1,
364
- extra_sympy_mappings={},
365
- output_jax_format=False,
366
- model_selection="accuracy",
367
- equation_file="equation_file.csv",
368
- )
369
  equations = pd.DataFrame(
370
  {
371
  "equation": ["1.0", "cos(x0)", "square(cos(x0))"],
@@ -373,17 +404,7 @@ class TestBest(unittest.TestCase):
373
  "complexity": [1, 2, 3],
374
  }
375
  )
376
-
377
- # Set up internal parameters as if it had been fitted:
378
- self.model.equation_file_ = "equation_file.csv"
379
- self.model.nout_ = 1
380
- self.model.selection_mask_ = None
381
- self.model.feature_names_in_ = np.array(["x0", "x1"], dtype=object)
382
- equations["complexity loss equation".split(" ")].to_csv(
383
- "equation_file.csv.bkup"
384
- )
385
-
386
- self.model.refresh()
387
  self.equations_ = self.model.equations_
388
 
389
  def test_best(self):
@@ -585,3 +606,197 @@ class TestMiscellaneous(unittest.TestCase):
585
  print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
586
  # If any checks failed don't let the test pass.
587
  self.assertEqual(len(exception_messages), 0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  _csv_filename_to_pkl_filename,
12
  idx_model_selection,
13
  )
14
+ from pysr.export_latex import to_latex
15
  from sklearn.utils.estimator_checks import check_estimator
16
  import sympy
17
  import pandas as pd
 
354
  np.testing.assert_allclose(model.predict(self.X), model3.predict(self.X))
355
 
356
 
357
+ def manually_create_model(equations, feature_names=None):
358
+ if feature_names is None:
359
+ feature_names = ["x0", "x1"]
360
+
361
+ model = PySRRegressor(
362
+ progress=False,
363
+ niterations=1,
364
+ extra_sympy_mappings={},
365
+ output_jax_format=False,
366
+ model_selection="accuracy",
367
+ equation_file="equation_file.csv",
368
+ )
369
+
370
+ # Set up internal parameters as if it had been fitted:
371
+ if isinstance(equations, list):
372
+ # Multi-output.
373
+ model.equation_file_ = "equation_file.csv"
374
+ model.nout_ = len(equations)
375
+ model.selection_mask_ = None
376
+ model.feature_names_in_ = np.array(feature_names, dtype=object)
377
+ for i in range(model.nout_):
378
+ equations[i]["complexity loss equation".split(" ")].to_csv(
379
+ f"equation_file.csv.out{i+1}.bkup"
380
+ )
381
+ else:
382
+ model.equation_file_ = "equation_file.csv"
383
+ model.nout_ = 1
384
+ model.selection_mask_ = None
385
+ model.feature_names_in_ = np.array(feature_names, dtype=object)
386
+ equations["complexity loss equation".split(" ")].to_csv(
387
+ "equation_file.csv.bkup"
388
+ )
389
+
390
+ model.refresh()
391
+
392
+ return model
393
+
394
+
395
  class TestBest(unittest.TestCase):
396
  def setUp(self):
397
  self.rstate = np.random.RandomState(0)
398
  self.X = self.rstate.randn(10, 2)
399
  self.y = np.cos(self.X[:, 0]) ** 2
 
 
 
 
 
 
 
 
400
  equations = pd.DataFrame(
401
  {
402
  "equation": ["1.0", "cos(x0)", "square(cos(x0))"],
 
404
  "complexity": [1, 2, 3],
405
  }
406
  )
407
+ self.model = manually_create_model(equations)
 
 
 
 
 
 
 
 
 
 
408
  self.equations_ = self.model.equations_
409
 
410
  def test_best(self):
 
606
  print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
607
  # If any checks failed don't let the test pass.
608
  self.assertEqual(len(exception_messages), 0)
609
+
610
+
611
+ TRUE_PREAMBLE = "\n".join(
612
+ [
613
+ r"\usepackage{breqn}",
614
+ r"\usepackage{booktabs}",
615
+ "",
616
+ "...",
617
+ "",
618
+ ]
619
+ )
620
+
621
+
622
+ class TestLaTeXTable(unittest.TestCase):
623
+ def setUp(self):
624
+ equations = pd.DataFrame(
625
+ dict(
626
+ equation=["x0", "cos(x0)", "x0 + x1 - cos(x1 * x0)"],
627
+ loss=[1.052, 0.02315, 1.12347e-15],
628
+ complexity=[1, 2, 8],
629
+ )
630
+ )
631
+ self.model = manually_create_model(equations)
632
+ self.maxDiff = None
633
+
634
+ def create_true_latex(self, middle_part, include_score=False):
635
+ if include_score:
636
+ true_latex_table_str = r"""
637
+ \begin{table}[h]
638
+ \begin{center}
639
+ \begin{tabular}{@{}cccc@{}}
640
+ \toprule
641
+ Equation & Complexity & Loss & Score \\
642
+ \midrule"""
643
+ else:
644
+ true_latex_table_str = r"""
645
+ \begin{table}[h]
646
+ \begin{center}
647
+ \begin{tabular}{@{}ccc@{}}
648
+ \toprule
649
+ Equation & Complexity & Loss \\
650
+ \midrule"""
651
+ true_latex_table_str += middle_part
652
+ true_latex_table_str += r"""\bottomrule
653
+ \end{tabular}
654
+ \end{center}
655
+ \end{table}
656
+ """
657
+ # First, remove empty lines:
658
+ true_latex_table_str = "\n".join(
659
+ [line.strip() for line in true_latex_table_str.split("\n") if len(line) > 0]
660
+ )
661
+ return true_latex_table_str.strip()
662
+
663
+ def test_simple_table(self):
664
+ latex_table_str = self.model.latex_table(
665
+ columns=["equation", "complexity", "loss"]
666
+ )
667
+ middle_part = r"""
668
+ $y = x_{0}$ & $1$ & $1.05$ \\
669
+ $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ \\
670
+ $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
671
+ """
672
+ true_latex_table_str = (
673
+ TRUE_PREAMBLE + "\n" + self.create_true_latex(middle_part)
674
+ )
675
+ self.assertEqual(latex_table_str, true_latex_table_str)
676
+
677
+ def test_other_precision(self):
678
+ latex_table_str = self.model.latex_table(
679
+ precision=5, columns=["equation", "complexity", "loss"]
680
+ )
681
+ middle_part = r"""
682
+ $y = x_{0}$ & $1$ & $1.0520$ \\
683
+ $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.023150$ \\
684
+ $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.1235 \cdot 10^{-15}$ \\
685
+ """
686
+ true_latex_table_str = (
687
+ TRUE_PREAMBLE + "\n" + self.create_true_latex(middle_part)
688
+ )
689
+ self.assertEqual(latex_table_str, true_latex_table_str)
690
+
691
+ def test_include_score(self):
692
+ latex_table_str = self.model.latex_table()
693
+ middle_part = r"""
694
+ $y = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
695
+ $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
696
+ $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
697
+ """
698
+ true_latex_table_str = (
699
+ TRUE_PREAMBLE
700
+ + "\n"
701
+ + self.create_true_latex(middle_part, include_score=True)
702
+ )
703
+ self.assertEqual(latex_table_str, true_latex_table_str)
704
+
705
+ def test_last_equation(self):
706
+ latex_table_str = self.model.latex_table(
707
+ indices=[2], columns=["equation", "complexity", "loss"]
708
+ )
709
+ middle_part = r"""
710
+ $y = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ \\
711
+ """
712
+ true_latex_table_str = (
713
+ TRUE_PREAMBLE + "\n" + self.create_true_latex(middle_part)
714
+ )
715
+ self.assertEqual(latex_table_str, true_latex_table_str)
716
+
717
+ def test_multi_output(self):
718
+ equations1 = pd.DataFrame(
719
+ dict(
720
+ equation=["x0", "cos(x0)", "x0 + x1 - cos(x1 * x0)"],
721
+ loss=[1.052, 0.02315, 1.12347e-15],
722
+ complexity=[1, 2, 8],
723
+ )
724
+ )
725
+ equations2 = pd.DataFrame(
726
+ dict(
727
+ equation=["x1", "cos(x1)", "x0 * x0 * x1"],
728
+ loss=[1.32, 0.052, 2e-15],
729
+ complexity=[1, 2, 5],
730
+ )
731
+ )
732
+ equations = [equations1, equations2]
733
+ model = manually_create_model(equations)
734
+ middle_part_1 = r"""
735
+ $y_{0} = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
736
+ $y_{0} = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
737
+ $y_{0} = x_{0} + x_{1} - \cos{\left(x_{0} x_{1} \right)}$ & $8$ & $1.12 \cdot 10^{-15}$ & $5.11$ \\
738
+ """
739
+ middle_part_2 = r"""
740
+ $y_{1} = x_{1}$ & $1$ & $1.32$ & $0.0$ \\
741
+ $y_{1} = \cos{\left(x_{1} \right)}$ & $2$ & $0.0520$ & $3.23$ \\
742
+ $y_{1} = x_{0}^{2} x_{1}$ & $5$ & $2.00 \cdot 10^{-15}$ & $10.3$ \\
743
+ """
744
+ true_latex_table_str = "\n\n".join(
745
+ self.create_true_latex(part, include_score=True)
746
+ for part in [middle_part_1, middle_part_2]
747
+ )
748
+ true_latex_table_str = TRUE_PREAMBLE + "\n" + true_latex_table_str
749
+ latex_table_str = model.latex_table()
750
+
751
+ self.assertEqual(latex_table_str, true_latex_table_str)
752
+
753
+ def test_latex_float_precision(self):
754
+ """Test that we can print latex expressions with custom precision"""
755
+ expr = sympy.Float(4583.4485748, dps=50)
756
+ self.assertEqual(to_latex(expr, prec=6), r"4583.45")
757
+ self.assertEqual(to_latex(expr, prec=5), r"4583.4")
758
+ self.assertEqual(to_latex(expr, prec=4), r"4583.")
759
+ self.assertEqual(to_latex(expr, prec=3), r"4.58 \cdot 10^{3}")
760
+ self.assertEqual(to_latex(expr, prec=2), r"4.6 \cdot 10^{3}")
761
+
762
+ # Multiple numbers:
763
+ x = sympy.Symbol("x")
764
+ expr = x * 3232.324857384 - 1.4857485e-10
765
+ self.assertEqual(
766
+ to_latex(expr, prec=2), "3.2 \cdot 10^{3} x - 1.5 \cdot 10^{-10}"
767
+ )
768
+ self.assertEqual(
769
+ to_latex(expr, prec=3), "3.23 \cdot 10^{3} x - 1.49 \cdot 10^{-10}"
770
+ )
771
+ self.assertEqual(
772
+ to_latex(expr, prec=8), "3232.3249 x - 1.4857485 \cdot 10^{-10}"
773
+ )
774
+
775
+ def test_latex_break_long_equation(self):
776
+ """Test that we can break a long equation inside the table"""
777
+ long_equation = """
778
+ - cos(x1 * x0) + 3.2 * x0 - 1.2 * x1 + x1 * x1 * x1 + x0 * x0 * x0
779
+ + 5.2 * sin(0.3256 * sin(x2) - 2.6 * x0) + x0 * x0 * x0 * x0 * x0
780
+ + cos(cos(x1 * x0) + 3.2 * x0 - 1.2 * x1 + x1 * x1 * x1 + x0 * x0 * x0)
781
+ """
782
+ long_equation = "".join(long_equation.split("\n")).strip()
783
+ equations = pd.DataFrame(
784
+ dict(
785
+ equation=["x0", "cos(x0)", long_equation],
786
+ loss=[1.052, 0.02315, 1.12347e-15],
787
+ complexity=[1, 2, 30],
788
+ )
789
+ )
790
+ model = manually_create_model(equations)
791
+ latex_table_str = model.latex_table()
792
+ middle_part = r"""
793
+ $y = x_{0}$ & $1$ & $1.05$ & $0.0$ \\
794
+ $y = \cos{\left(x_{0} \right)}$ & $2$ & $0.0232$ & $3.82$ \\
795
+ \begin{minipage}{0.8\linewidth} \vspace{-1em} \begin{dmath*} y = x_{0}^{5} + x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} - 5.20 \sin{\left(2.60 x_{0} - 0.326 \sin{\left(x_{2} \right)} \right)} - \cos{\left(x_{0} x_{1} \right)} + \cos{\left(x_{0}^{3} + 3.20 x_{0} + x_{1}^{3} - 1.20 x_{1} + \cos{\left(x_{0} x_{1} \right)} \right)} \end{dmath*} \end{minipage} & $30$ & $1.12 \cdot 10^{-15}$ & $1.09$ \\
796
+ """
797
+ true_latex_table_str = (
798
+ TRUE_PREAMBLE
799
+ + "\n"
800
+ + self.create_true_latex(middle_part, include_score=True)
801
+ )
802
+ self.assertEqual(latex_table_str, true_latex_table_str)