PySR / pysr /export_latex.py
MilesCranmer's picture
Refactor LaTeX table to export_latex.py
c6f5c09
raw
history blame
3.54 kB
"""Functions to help export PySR equations to LaTeX."""
import sympy
from sympy.printing.latex import LatexPrinter
import pandas as pd
from typing import List
class PreciseLatexPrinter(LatexPrinter):
    """SymPy LaTeX printer that prints floats at a reduced precision.

    `prec` is forwarded as the second argument of `sympy.Float` whenever a
    float is printed; `settings` is handed unchanged to `LatexPrinter`.
    """

    def __init__(self, settings=None, prec=3):
        super().__init__(settings)
        self.prec = prec

    def _print_Float(self, expr):
        # Re-create the number at the configured precision, then let the
        # stock SymPy printer handle the actual formatting.
        return super()._print_Float(sympy.Float(expr, self.prec))
def to_latex(expr, prec=3, full_prec=True, **settings):
    """Render a SymPy expression as a LaTeX string using `PreciseLatexPrinter`.

    `prec` sets the printer's float precision. `full_prec` is passed to the
    printer under the "full_prec" settings key, alongside any extra keyword
    arguments given in `settings`.
    """
    printer_settings = {**settings, "full_prec": full_prec}
    return PreciseLatexPrinter(settings=printer_settings, prec=prec).doprint(expr)
def generate_table_environment(columns=["equation", "complexity", "loss"]):
    """Build the LaTeX that opens and closes an equation table.

    The table uses `\\toprule` / `\\midrule` / `\\bottomrule` rules inside a
    centered `tabular`. The "equation" column is left-aligned; every other
    column is centered.

    Returns a `(top, bottom)` tuple of strings: `top` runs from
    `\\begin{table}` through the header row and `\\midrule`, `bottom` from
    `\\bottomrule` through `\\end{table}`. A column name that is not one of
    "equation", "complexity", "loss" or "score" raises `KeyError`.
    """
    headings = {
        "complexity": "Complexity",
        "loss": "Loss",
        "equation": "Equation",
        "score": "Score",
    }

    # Two separate passes over `columns`: alignment spec first, then titles.
    alignment = "".join("l" if name == "equation" else "c" for name in columns)
    header_row = " & ".join([headings[name] for name in columns]) + r" \\"

    opening = "\n".join(
        (
            r"\begin{table}[h]",
            r"\begin{center}",
            r"\begin{tabular}{@{}" + alignment + r"@{}}",
            r"\toprule",
            header_row,
            r"\midrule",
        )
    )
    closing = "\n".join(
        (
            r"\bottomrule",
            r"\end{tabular}",
            r"\end{center}",
            r"\end{table}",
        )
    )
    return opening, closing
def generate_table(
    equations: List[pd.DataFrame],
    indices: List[List[int]],
    precision=3,
    columns=["equation", "complexity", "loss", "score"],
):
    """Generate one LaTeX table per output feature from selected equations.

    Parameters
    ----------
    equations : list of pandas.DataFrame
        One frame per output feature. Only the columns needed for the
        requested `columns` are read: "sympy_format" (for "equation"),
        "complexity", "loss" and "score".
    indices : list of list of int
        For each output feature, the positional row indices (as used by
        `DataFrame.iloc`) of the equations to include, in output order.
    precision : int
        Float precision passed to `to_latex` for equations, loss and score.
    columns : list of str
        Columns to emit, in order. Never mutated by this function.

    Returns
    -------
    str
        The tables for all output features, separated by a blank line.
        Every cell is wrapped in `$...$`.

    Raises
    ------
    ValueError
        If a column name is not handled here. Note that
        `generate_table_environment` is called first and looks each name up
        in its own heading map, so an unknown name may surface from there.
    """
    latex_top, latex_bottom = generate_table_environment(columns)

    all_latex_table_str = []
    for output_feature, index_set in enumerate(indices):
        latex_table_content = []
        for i in index_set:
            row = equations[output_feature].iloc[i]

            # Convert cells lazily: only the selected rows and only the
            # requested columns are rendered, so frames need not carry
            # columns that are not being printed.
            row_pieces = []
            for col in columns:
                if col == "equation":
                    piece = to_latex(row["sympy_format"], prec=precision)
                elif col == "complexity":
                    piece = str(row["complexity"])
                elif col in ("loss", "score"):
                    piece = to_latex(sympy.Float(row[col]), prec=precision)
                else:
                    raise ValueError(f"Unknown column: {col}")
                row_pieces.append("$" + piece + "$")

            latex_table_content.append(" & ".join(row_pieces) + r" \\")

        all_latex_table_str.append(
            "\n".join([latex_top, *latex_table_content, latex_bottom])
        )

    return "\n\n".join(all_latex_table_str)