File size: 3,543 Bytes
f257e58
9a5df63
 
c6f5c09
 
f257e58
 
9a5df63
 
118c5f6
6024c83
9a5df63
 
 
 
 
6024c83
9a5df63
 
 
d1f1f2c
9a5df63
6024c83
 
9a5df63
f257e58
 
d423f0c
215a692
 
 
 
 
 
 
 
d423f0c
f257e58
6d5ddcb
8f218cc
f257e58
8f218cc
f257e58
 
 
d423f0c
f257e58
 
6d5ddcb
f257e58
 
d423f0c
 
 
 
c6f5c09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""Functions to help export PySR equations to LaTeX."""
import sympy
from sympy.printing.latex import LatexPrinter
import pandas as pd
from typing import List


class PreciseLatexPrinter(LatexPrinter):
    """SymPy LaTeX printer that re-rounds every ``Float`` before printing.

    ``settings`` is forwarded unchanged to ``LatexPrinter``. ``prec`` is
    stored on the instance and passed as the precision argument of
    ``sympy.Float`` each time a float is printed.
    """

    def __init__(self, settings=None, prec=3):
        super().__init__(settings)
        # Precision applied to floats in _print_Float below.
        self.prec = prec

    def _print_Float(self, expr):
        # Rebuild the float at the configured precision, then defer to the
        # stock LaTeX float formatting.
        rounded = sympy.Float(expr, self.prec)
        return super()._print_Float(rounded)


def to_latex(expr, prec=3, full_prec=True, **settings):
    """Render a SymPy expression as LaTeX, rounding floats to ``prec``.

    ``full_prec`` and any extra keyword ``settings`` are passed to the
    printer as ``LatexPrinter`` settings; ``prec`` goes to
    ``PreciseLatexPrinter``, which re-rounds each float before printing.
    """
    printer_settings = dict(settings, full_prec=full_prec)
    latex_printer = PreciseLatexPrinter(settings=printer_settings, prec=prec)
    return latex_printer.doprint(expr)


def generate_table_environment(columns=("equation", "complexity", "loss")):
    r"""Build the opening and closing LaTeX for a booktabs-style table.

    Args:
        columns: Column keys in display order. Each must be one of
            ``"equation"``, ``"complexity"``, ``"loss"`` or ``"score"``.
            The equation column is left-aligned and all others are centered.
            Any iterable is accepted; it is consumed exactly once.

    Returns:
        A ``(top, bottom)`` pair of strings. ``top`` opens the ``table``,
        ``center`` and ``tabular`` environments, then emits ``\toprule``,
        the header row and ``\midrule``. ``bottom`` emits ``\bottomrule``
        and closes the environments. ``\toprule``/``\midrule``/
        ``\bottomrule`` come from the LaTeX ``booktabs`` package.

    Raises:
        KeyError: If a column key is not one of the names listed above.
    """
    # Materialize once: the keys are walked twice (alignment spec and header
    # row), and a one-shot iterator would otherwise leave the header empty.
    columns = list(columns)

    column_map = {
        "complexity": "Complexity",
        "loss": "Loss",
        "equation": "Equation",
        "score": "Score",
    }
    margins = "".join("l" if col == "equation" else "c" for col in columns)
    headers = [column_map[col] for col in columns]

    top_pieces = [
        r"\begin{table}[h]",
        r"\begin{center}",
        # @{} strips the default padding at the outer table edges.
        r"\begin{tabular}{@{}" + margins + r"@{}}",
        r"\toprule",
        " & ".join(headers) + r" \\",
        r"\midrule",
    ]
    bottom_pieces = [
        r"\bottomrule",
        r"\end{tabular}",
        r"\end{center}",
        r"\end{table}",
    ]
    return "\n".join(top_pieces), "\n".join(bottom_pieces)


def generate_table(
    equations: List[pd.DataFrame],
    indices: List[List[int]],
    precision=3,
    columns=("equation", "complexity", "loss", "score"),
):
    """Render one LaTeX table per output feature, separated by blank lines.

    Args:
        equations: One DataFrame per output feature. Only the data needed
            for the requested ``columns`` is read: ``sympy_format`` for
            ``"equation"``, and ``complexity`` / ``loss`` / ``score`` for the
            columns of the same name.
        indices: ``indices[k]`` lists the positional (``iloc``) rows of
            ``equations[k]`` to show, in display order. One table is emitted
            per entry of ``indices``.
        precision: Precision forwarded to ``to_latex`` for the equation,
            loss and score cells.
        columns: Column keys to show, in order. Any iterable is accepted.

    Returns:
        The LaTeX source of all tables, joined by a blank line. Every cell
        is wrapped in ``$...$`` math delimiters.

    Raises:
        KeyError: From ``generate_table_environment`` for an unrecognised
            column key, or if a DataFrame lacks a column that was requested.
    """
    # Materialize once: ``columns`` is walked for the header and for each row.
    columns = tuple(columns)
    latex_top, latex_bottom = generate_table_environment(columns)

    def _cell(frame, i, col):
        """Return the LaTeX (without math delimiters) for one table cell."""
        # Only the requested column is read, so frames without e.g. a
        # ``score`` column work when that column is not displayed. Printing
        # happens per selected row instead of for every equation up front.
        if col == "equation":
            return to_latex(frame["sympy_format"].iloc[i], prec=precision)
        if col == "complexity":
            return str(frame["complexity"].iloc[i])
        if col in ("loss", "score"):
            return to_latex(sympy.Float(frame[col].iloc[i]), prec=precision)
        raise ValueError(f"Unknown column: {col}")

    all_latex_table_str = []
    for output_feature, index_set in enumerate(indices):
        frame = equations[output_feature]
        latex_table_content = []
        for i in index_set:
            row_pieces = ["$" + _cell(frame, i, col) + "$" for col in columns]
            latex_table_content.append(" & ".join(row_pieces) + r" \\")

        all_latex_table_str.append(
            "\n".join([latex_top, *latex_table_content, latex_bottom])
        )

    return "\n\n".join(all_latex_table_str)