MilesCranmer commited on
Commit
c6f5c09
·
1 Parent(s): d423f0c

Refactor LaTeX table to export_latex.py

Browse files
Files changed (2) hide show
  1. pysr/export_latex.py +62 -0
  2. pysr/sr.py +4 -54
pysr/export_latex.py CHANGED
@@ -1,6 +1,8 @@
1
  """Functions to help export PySR equations to LaTeX."""
2
  import sympy
3
  from sympy.printing.latex import LatexPrinter
 
 
4
 
5
 
6
  class PreciseLatexPrinter(LatexPrinter):
@@ -51,3 +53,63 @@ def generate_table_environment(columns=["equation", "complexity", "loss"]):
51
  bottom_latex_table = "\n".join(bottom_pieces)
52
 
53
  return top_latex_table, bottom_latex_table
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """Functions to help export PySR equations to LaTeX."""
2
  import sympy
3
  from sympy.printing.latex import LatexPrinter
4
+ import pandas as pd
5
+ from typing import List
6
 
7
 
8
  class PreciseLatexPrinter(LatexPrinter):
 
53
  bottom_latex_table = "\n".join(bottom_pieces)
54
 
55
  return top_latex_table, bottom_latex_table
56
+
57
+
58
+ def generate_table(
59
+ equations: List[pd.DataFrame],
60
+ indices: List[List[int]],
61
+ precision=3,
62
+ columns=["equation", "complexity", "loss", "score"],
63
+ ):
64
+ latex_top, latex_bottom = generate_table_environment(columns)
65
+
66
+ latex_equations = [
67
+ [to_latex(eq, prec=precision) for eq in equation_set["sympy_format"]]
68
+ for equation_set in equations
69
+ ]
70
+
71
+ all_latex_table_str = []
72
+
73
+ for output_feature, index_set in enumerate(indices):
74
+ latex_table_content = []
75
+ for i in index_set:
76
+ latex_equation = latex_equations[output_feature][i]
77
+ complexity = str(equations[output_feature].iloc[i]["complexity"])
78
+ loss = to_latex(
79
+ sympy.Float(equations[output_feature].iloc[i]["loss"]),
80
+ prec=precision,
81
+ )
82
+ score = to_latex(
83
+ sympy.Float(equations[output_feature].iloc[i]["score"]),
84
+ prec=precision,
85
+ )
86
+
87
+ row_pieces = []
88
+ for col in columns:
89
+ if col == "equation":
90
+ row_pieces.append(latex_equation)
91
+ elif col == "complexity":
92
+ row_pieces.append(complexity)
93
+ elif col == "loss":
94
+ row_pieces.append(loss)
95
+ elif col == "score":
96
+ row_pieces.append(score)
97
+ else:
98
+ raise ValueError(f"Unknown column: {col}")
99
+
100
+ row_pieces = ["$" + piece + "$" for piece in row_pieces]
101
+
102
+ latex_table_content.append(
103
+ " & ".join(row_pieces) + r" \\",
104
+ )
105
+
106
+ this_latex_table = "\n".join(
107
+ [
108
+ latex_top,
109
+ *latex_table_content,
110
+ latex_bottom,
111
+ ]
112
+ )
113
+ all_latex_table_str.append(this_latex_table)
114
+
115
+ return "\n\n".join(all_latex_table_str)
pysr/sr.py CHANGED
@@ -27,7 +27,7 @@ from .julia_helpers import (
27
  import_error_string,
28
  )
29
  from .export_numpy import CallableEquation
30
- from .export_latex import to_latex, generate_table_environment
31
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
32
 
33
 
@@ -2033,8 +2033,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2033
  else:
2034
  indices = list(range(len(self.equations_)))
2035
 
2036
- latex_top, latex_bottom = generate_table_environment(columns)
2037
-
2038
  equations = self.equations_
2039
 
2040
  if isinstance(indices[0], int):
@@ -2044,57 +2042,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2044
 
2045
  assert len(indices) == self.nout_
2046
 
2047
- latex_equations = [
2048
- [to_latex(eq, prec=precision) for eq in equation_set["sympy_format"]]
2049
- for equation_set in equations
2050
- ]
2051
-
2052
- all_latex_table_str = []
2053
-
2054
- for output_feature, index_set in enumerate(indices):
2055
- latex_table_content = []
2056
- for i in index_set:
2057
- latex_equation = latex_equations[output_feature][i]
2058
- complexity = str(equations[output_feature].iloc[i]["complexity"])
2059
- loss = to_latex(
2060
- sympy.Float(equations[output_feature].iloc[i]["loss"]),
2061
- prec=precision,
2062
- )
2063
- score = to_latex(
2064
- sympy.Float(equations[output_feature].iloc[i]["score"]),
2065
- prec=precision,
2066
- )
2067
-
2068
- row_pieces = []
2069
- for col in columns:
2070
- if col == "equation":
2071
- row_pieces.append(latex_equation)
2072
- elif col == "complexity":
2073
- row_pieces.append(complexity)
2074
- elif col == "loss":
2075
- row_pieces.append(loss)
2076
- elif col == "score":
2077
- row_pieces.append(score)
2078
- else:
2079
- raise ValueError(f"Unknown column: {col}")
2080
-
2081
- row_pieces = ["$" + piece + "$" for piece in row_pieces]
2082
-
2083
- latex_table_content.append(
2084
- " & ".join(row_pieces) + r" \\",
2085
- )
2086
-
2087
- all_latex_table_str.append(
2088
- "\n".join(
2089
- [
2090
- latex_top,
2091
- *latex_table_content,
2092
- latex_bottom,
2093
- ]
2094
- )
2095
- )
2096
-
2097
- return "\n\n".join(all_latex_table_str)
2098
 
2099
 
2100
  def _denoise(X, y, Xresampled=None, random_state=None):
 
27
  import_error_string,
28
  )
29
  from .export_numpy import CallableEquation
30
+ from .export_latex import to_latex, generate_table
31
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
32
 
33
 
 
2033
  else:
2034
  indices = list(range(len(self.equations_)))
2035
 
 
 
2036
  equations = self.equations_
2037
 
2038
  if isinstance(indices[0], int):
 
2042
 
2043
  assert len(indices) == self.nout_
2044
 
2045
+ return generate_table(
2046
+ equations, indices=indices, precision=precision, columns=columns
2047
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2048
 
2049
 
2050
  def _denoise(X, y, Xresampled=None, random_state=None):