Spaces:
Sleeping
Sleeping
File size: 3,543 Bytes
f257e58 9a5df63 c6f5c09 f257e58 9a5df63 118c5f6 6024c83 9a5df63 6024c83 9a5df63 d1f1f2c 9a5df63 6024c83 9a5df63 f257e58 d423f0c 215a692 d423f0c f257e58 6d5ddcb 8f218cc f257e58 8f218cc f257e58 d423f0c f257e58 6d5ddcb f257e58 d423f0c c6f5c09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
"""Functions to help export PySR equations to LaTeX."""
import sympy
from sympy.printing.latex import LatexPrinter
import pandas as pd
from typing import List
class PreciseLatexPrinter(LatexPrinter):
    """SymPy ``LatexPrinter`` that rounds every Float before printing it.

    All printing other than floats is inherited unchanged from
    ``LatexPrinter``.
    """

    def __init__(self, settings=None, prec=3):
        """Create the printer.

        Args:
            settings: Printer settings, forwarded as-is to ``LatexPrinter``.
            prec: Decimal precision handed to ``sympy.Float`` when each float
                is rounded.
        """
        super().__init__(settings)
        self.prec = prec

    def _print_Float(self, expr):
        # Round to ``self.prec`` digits first, then reuse the base formatting.
        rounded = sympy.Float(expr, self.prec)
        return super()._print_Float(rounded)
def to_latex(expr, prec=3, full_prec=True, **settings):
    """Render a SymPy expression as LaTeX, rounding its floats.

    Args:
        expr: SymPy expression to print.
        prec: Decimal precision applied to every Float in ``expr``.
        full_prec: Value stored under the printer's ``full_prec`` setting.
        **settings: Any further ``LatexPrinter`` settings.

    Returns:
        The LaTeX source string produced by ``PreciseLatexPrinter``.
    """
    printer_settings = {**settings, "full_prec": full_prec}
    latex_printer = PreciseLatexPrinter(settings=printer_settings, prec=prec)
    return latex_printer.doprint(expr)
def generate_table_environment(columns=("equation", "complexity", "loss")):
    """Build the opening and closing LaTeX surrounding a results table.

    Args:
        columns: Column identifiers in display order. Each must be one of
            ``"equation"``, ``"complexity"``, ``"loss"`` or ``"score"``.
            Any iterable is accepted; it is read exactly once.

    Returns:
        A ``(top, bottom)`` pair of strings. ``top`` opens the ``table``,
        ``center`` and ``tabular`` environments and emits the header row
        followed by ``\\midrule``. ``bottom`` emits ``\\bottomrule`` and closes
        the environments. Body rows are meant to be placed between the two.
        The rule commands come from the LaTeX ``booktabs`` package.

    Raises:
        KeyError: If a column identifier is not recognized.
    """
    # Materialize once: the identifiers are consumed twice below (margins and
    # headers), which would silently produce an empty header for a generator.
    columns = list(columns)

    # Left-align the equation column; center all the others.
    margins = "".join("l" if col == "equation" else "c" for col in columns)

    column_map = {
        "complexity": "Complexity",
        "loss": "Loss",
        "equation": "Equation",
        "score": "Score",
    }
    headers = [column_map[col] for col in columns]

    top_pieces = [
        r"\begin{table}[h]",
        r"\begin{center}",
        # @{} strips the default padding at the outer edges of the tabular.
        r"\begin{tabular}{@{}" + margins + r"@{}}",
        r"\toprule",
        " & ".join(headers) + r" \\",
        r"\midrule",
    ]
    bottom_pieces = [
        r"\bottomrule",
        r"\end{tabular}",
        r"\end{center}",
        r"\end{table}",
    ]
    return "\n".join(top_pieces), "\n".join(bottom_pieces)
def generate_table(
    equations: List[pd.DataFrame],
    indices: List[List[int]],
    precision=3,
    columns=("equation", "complexity", "loss", "score"),
):
    """Render one LaTeX table per output feature.

    Args:
        equations: One DataFrame per output feature. Each must contain a
            ``"sympy_format"`` column when ``"equation"`` is requested, and a
            column of the same name for each other requested entry of
            ``columns``. Columns that are not requested are never read.
        indices: ``indices[k]`` lists the positional (``iloc``) row indices of
            ``equations[k]`` to include in table ``k``, in display order.
        precision: Decimal precision used when printing floats inside the
            equations and the loss/score values.
        columns: Columns to show, in order. Allowed values are
            ``"equation"``, ``"complexity"``, ``"loss"`` and ``"score"``.

    Returns:
        The LaTeX source of all tables, separated by blank lines. Every cell
        is wrapped in inline math (``$...$``).

    Raises:
        ValueError: If ``columns`` contains an unknown column name.
    """
    columns = list(columns)
    # Validate up front: otherwise the header builder fails first with a bare
    # KeyError and the ValueError promised here could never be raised.
    known_columns = ("equation", "complexity", "loss", "score")
    for col in columns:
        if col not in known_columns:
            raise ValueError(f"Unknown column: {col}")

    latex_top, latex_bottom = generate_table_environment(columns)

    all_latex_table_str = []
    for output_feature, index_set in enumerate(indices):
        equation_set = equations[output_feature]
        latex_table_content = []
        for i in index_set:
            row = equation_set.iloc[i]
            row_pieces = []
            # Format only the requested columns, and only for the selected
            # rows. For example, a DataFrame without "score" works when
            # "score" is not requested.
            for col in columns:
                if col == "equation":
                    piece = to_latex(row["sympy_format"], prec=precision)
                elif col == "complexity":
                    piece = str(row["complexity"])
                else:  # "loss" or "score"
                    piece = to_latex(sympy.Float(row[col]), prec=precision)
                row_pieces.append("$" + piece + "$")
            latex_table_content.append(" & ".join(row_pieces) + r" \\")
        all_latex_table_str.append(
            "\n".join([latex_top, *latex_table_content, latex_bottom])
        )
    return "\n\n".join(all_latex_table_str)
|