Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
d423f0c
1
Parent(s):
215a692
Refactor table env generator
Browse files- pysr/export_latex.py +7 -7
- pysr/sr.py +4 -9
pysr/export_latex.py
CHANGED
@@ -23,7 +23,7 @@ def to_latex(expr, prec=3, full_prec=True, **settings):
|
|
23 |
return printer.doprint(expr)
|
24 |
|
25 |
|
26 |
-
def
|
27 |
margins = "".join([("l" if col == "equation" else "c") for col in columns])
|
28 |
column_map = {
|
29 |
"complexity": "Complexity",
|
@@ -32,7 +32,7 @@ def generate_top_of_latex_table(columns=["equation", "complexity", "loss"]):
|
|
32 |
"score": "Score",
|
33 |
}
|
34 |
columns = [column_map[col] for col in columns]
|
35 |
-
|
36 |
r"\begin{table}[h]",
|
37 |
r"\begin{center}",
|
38 |
r"\begin{tabular}{@{}" + margins + r"@{}}",
|
@@ -40,14 +40,14 @@ def generate_top_of_latex_table(columns=["equation", "complexity", "loss"]):
|
|
40 |
" & ".join(columns) + r" \\",
|
41 |
r"\midrule",
|
42 |
]
|
43 |
-
return "\n".join(latex_table_pieces)
|
44 |
|
45 |
-
|
46 |
-
def generate_bottom_of_latex_table():
|
47 |
-
latex_table_pieces = [
|
48 |
r"\bottomrule",
|
49 |
r"\end{tabular}",
|
50 |
r"\end{center}",
|
51 |
r"\end{table}",
|
52 |
]
|
53 |
-
|
|
|
|
|
|
|
|
23 |
return printer.doprint(expr)
|
24 |
|
25 |
|
26 |
+
def generate_table_environment(columns=["equation", "complexity", "loss"]):
|
27 |
margins = "".join([("l" if col == "equation" else "c") for col in columns])
|
28 |
column_map = {
|
29 |
"complexity": "Complexity",
|
|
|
32 |
"score": "Score",
|
33 |
}
|
34 |
columns = [column_map[col] for col in columns]
|
35 |
+
top_pieces = [
|
36 |
r"\begin{table}[h]",
|
37 |
r"\begin{center}",
|
38 |
r"\begin{tabular}{@{}" + margins + r"@{}}",
|
|
|
40 |
" & ".join(columns) + r" \\",
|
41 |
r"\midrule",
|
42 |
]
|
|
|
43 |
|
44 |
+
bottom_pieces = [
|
|
|
|
|
45 |
r"\bottomrule",
|
46 |
r"\end{tabular}",
|
47 |
r"\end{center}",
|
48 |
r"\end{table}",
|
49 |
]
|
50 |
+
top_latex_table = "\n".join(top_pieces)
|
51 |
+
bottom_latex_table = "\n".join(bottom_pieces)
|
52 |
+
|
53 |
+
return top_latex_table, bottom_latex_table
|
pysr/sr.py
CHANGED
@@ -27,11 +27,7 @@ from .julia_helpers import (
|
|
27 |
import_error_string,
|
28 |
)
|
29 |
from .export_numpy import CallableEquation
|
30 |
-
from .export_latex import
|
31 |
-
to_latex,
|
32 |
-
generate_top_of_latex_table,
|
33 |
-
generate_bottom_of_latex_table,
|
34 |
-
)
|
35 |
from .deprecated import make_deprecated_kwargs_for_pysr_regressor
|
36 |
|
37 |
|
@@ -2037,8 +2033,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2037 |
else:
|
2038 |
indices = list(range(len(self.equations_)))
|
2039 |
|
2040 |
-
|
2041 |
-
latex_table_bottom = generate_bottom_of_latex_table()
|
2042 |
|
2043 |
equations = self.equations_
|
2044 |
|
@@ -2092,9 +2087,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2092 |
all_latex_table_str.append(
|
2093 |
"\n".join(
|
2094 |
[
|
2095 |
-
|
2096 |
*latex_table_content,
|
2097 |
-
|
2098 |
]
|
2099 |
)
|
2100 |
)
|
|
|
27 |
import_error_string,
|
28 |
)
|
29 |
from .export_numpy import CallableEquation
|
30 |
+
from .export_latex import to_latex, generate_table_environment
|
|
|
|
|
|
|
|
|
31 |
from .deprecated import make_deprecated_kwargs_for_pysr_regressor
|
32 |
|
33 |
|
|
|
2033 |
else:
|
2034 |
indices = list(range(len(self.equations_)))
|
2035 |
|
2036 |
+
latex_top, latex_bottom = generate_table_environment(columns)
|
|
|
2037 |
|
2038 |
equations = self.equations_
|
2039 |
|
|
|
2087 |
all_latex_table_str.append(
|
2088 |
"\n".join(
|
2089 |
[
|
2090 |
+
latex_top,
|
2091 |
*latex_table_content,
|
2092 |
+
latex_bottom,
|
2093 |
]
|
2094 |
)
|
2095 |
)
|