File size: 4,849 Bytes
f257e58
b896bd3
976f8d8
 
9a5df63
 
fab6f87
 
9a5df63
 
118c5f6
6024c83
9a5df63
 
 
 
 
6024c83
9a5df63
 
 
b896bd3
9a5df63
6024c83
 
9a5df63
f257e58
 
b896bd3
 
 
fab6f87
215a692
 
 
 
 
 
 
d423f0c
f257e58
6d5ddcb
8f218cc
f257e58
8f218cc
f257e58
 
 
d423f0c
f257e58
 
6d5ddcb
f257e58
 
d423f0c
 
 
 
c6f5c09
 
b2d7f41
c5cd4bb
b896bd3
c5cd4bb
b896bd3
2a802ab
3ef2b32
b896bd3
de4d559
c5cd4bb
 
c6f5c09
c5cd4bb
c6f5c09
c5cd4bb
b896bd3
c5cd4bb
 
b2d7f41
c5cd4bb
 
 
 
b2d7f41
c5cd4bb
 
 
b2d7f41
c5cd4bb
 
 
 
 
 
 
2a802ab
3ef2b32
 
 
2a802ab
 
 
fab6f87
2a802ab
fab6f87
3ef2b32
fab6f87
 
2a802ab
 
 
 
c5cd4bb
2a802ab
c5cd4bb
2a802ab
c5cd4bb
2a802ab
c5cd4bb
 
 
 
 
 
 
 
 
 
b2d7f41
c5cd4bb
b896bd3
c5cd4bb
b896bd3
 
 
de4d559
3ef2b32
c6f5c09
c5cd4bb
b2d7f41
c5cd4bb
 
 
 
3ef2b32
 
 
 
 
c6f5c09
c5cd4bb
 
c6f5c09
c5cd4bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Functions to help export PySR equations to LaTeX."""
from typing import List, Optional, Tuple

import pandas as pd
import sympy
from sympy.printing.latex import LatexPrinter


class PreciseLatexPrinter(LatexPrinter):
    """SymPy ``LatexPrinter`` that re-creates floats at a chosen precision.

    Every ``Float`` node is rebuilt via ``sympy.Float(expr, prec)`` before
    being handed to the stock LaTeX float printer, so the rendered constants
    carry the requested precision instead of the one stored in the expression.
    """

    def __init__(self, settings=None, prec=3):
        """Forward ``settings`` to ``LatexPrinter`` and remember ``prec``.

        ``prec`` is later passed as the second positional argument of
        ``sympy.Float`` whenever a float is printed.
        """
        super().__init__(settings)
        self.prec = prec

    def _print_Float(self, expr):
        """Print ``expr`` after rebuilding it as ``sympy.Float(expr, self.prec)``."""
        return super()._print_Float(sympy.Float(expr, self.prec))


def sympy2latex(expr, prec=3, full_prec=True, **settings) -> str:
    """Render a SymPy expression as LaTeX using ``PreciseLatexPrinter``.

    ``prec`` is handed to the printer as its float precision. ``full_prec``
    and any additional keyword arguments are passed to the printer as its
    ``settings`` dictionary.
    """
    printer_settings = {**settings, "full_prec": full_prec}
    return PreciseLatexPrinter(settings=printer_settings, prec=prec).doprint(expr)


def generate_table_environment(
    columns: Optional[List[str]] = None,
) -> Tuple[str, str]:
    """Build the opening and closing LaTeX of a table environment.

    The table uses one centred (``c``) column per requested entry and the
    ``\\toprule`` / ``\\midrule`` / ``\\bottomrule`` rule commands.

    Args:
        columns: Column keys to include, in order. Recognised keys are
            ``"equation"``, ``"complexity"``, ``"loss"`` and ``"score"``.
            Defaults to ``["equation", "complexity", "loss"]``.

    Returns:
        A ``(top, bottom)`` pair of strings: ``top`` opens the ``table``,
        ``center`` and ``tabular`` environments and ends with the header row
        followed by ``\\midrule``; ``bottom`` closes them again.

    Raises:
        KeyError: If ``columns`` contains a key that is not recognised.
    """
    # Use a None sentinel rather than a list literal as the default so no
    # list object is shared between calls.
    if columns is None:
        columns = ["equation", "complexity", "loss"]

    column_map = {
        "complexity": "Complexity",
        "loss": "Loss",
        "equation": "Equation",
        "score": "Score",
    }
    headers = [column_map[col] for col in columns]
    margins = "c" * len(headers)

    top_pieces = [
        r"\begin{table}[h]",
        r"\begin{center}",
        r"\begin{tabular}{@{}" + margins + r"@{}}",
        r"\toprule",
        " & ".join(headers) + r" \\",
        r"\midrule",
    ]
    bottom_pieces = [
        r"\bottomrule",
        r"\end{tabular}",
        r"\end{center}",
        r"\end{table}",
    ]

    return "\n".join(top_pieces), "\n".join(bottom_pieces)


def sympy2latextable(
    equations: pd.DataFrame,
    indices: Optional[List[int]] = None,
    precision: int = 3,
    columns: Optional[List[str]] = None,
    max_equation_length: int = 50,
    output_variable_name: str = "y",
) -> str:
    """Generate a booktabs-style LaTeX table for a single set of equations.

    Args:
        equations: DataFrame with one row per equation. Only the fields needed
            for the requested ``columns`` are read: ``"sympy_format"`` for
            ``"equation"``, and ``"complexity"``, ``"loss"``, ``"score"`` for
            the columns of the same name.
        indices: Positional (``.iloc``) row numbers to include, in order.
            Defaults to every row of ``equations``.
        precision: Precision passed to ``sympy2latex`` for the equation, loss
            and score entries.
        columns: Column keys to print, in order. Defaults to
            ``["equation", "complexity", "loss", "score"]``.
        max_equation_length: Equations whose LaTeX source has at least this
            many characters are wrapped in a ``minipage`` + ``dmath*`` block
            instead of inline math.
            NOTE(review): ``dmath*`` presumably requires the ``breqn`` LaTeX
            package in the document preamble -- confirm.
        output_variable_name: Left-hand side written before each equation.

    Returns:
        The full LaTeX source of the table as a single string.

    Raises:
        ValueError: If ``columns`` contains an unknown key (defensive guard;
            ``generate_table_environment`` looks every key up in its header
            map first and fails with ``KeyError`` on unknown ones).
    """
    assert isinstance(equations, pd.DataFrame)

    # Use a None sentinel rather than a list literal as the default so no
    # list object is shared between calls.
    if columns is None:
        columns = ["equation", "complexity", "loss", "score"]

    latex_top, latex_bottom = generate_table_environment(columns)

    if indices is None:
        # Rows are fetched with .iloc below, so the default must be positions,
        # not index labels; the two only coincide for a default RangeIndex.
        indices = list(range(len(equations)))

    latex_table_content = []
    for i in indices:
        row = equations.iloc[i]

        # Each field is read only when its column is requested, so frames
        # lacking e.g. "score" still work when "score" is not in `columns`.
        row_pieces = []
        for col in columns:
            if col == "equation":
                latex_equation = sympy2latex(row["sympy_format"], prec=precision)
                if len(latex_equation) < max_equation_length:
                    row_pieces.append(
                        "$" + output_variable_name + " = " + latex_equation + "$"
                    )
                else:
                    # Long equations go in a display-math block inside a
                    # minipage rather than inline math.
                    row_pieces.append(
                        " ".join(
                            [
                                r"\begin{minipage}{0.8\linewidth}",
                                r"\vspace{-1em}",
                                r"\begin{dmath*}",
                                output_variable_name + " = " + latex_equation,
                                r"\end{dmath*}",
                                r"\end{minipage}",
                            ]
                        )
                    )
            elif col == "complexity":
                row_pieces.append("$" + str(row["complexity"]) + "$")
            elif col in ("loss", "score"):
                value = sympy2latex(sympy.Float(row[col]), prec=precision)
                row_pieces.append("$" + value + "$")
            else:
                raise ValueError(f"Unknown column: {col}")

        latex_table_content.append(" & ".join(row_pieces) + r" \\")

    return "\n".join([latex_top, *latex_table_content, latex_bottom])


def sympy2multilatextable(
    equations: List[pd.DataFrame],
    indices: Optional[List[List[int]]] = None,
    precision: int = 3,
    columns: Optional[List[str]] = None,
    output_variable_names: Optional[List[str]] = None,
) -> str:
    """Generate multiple LaTeX tables for a list of equation sets.

    Each DataFrame in ``equations`` is rendered with ``sympy2latextable`` and
    the resulting tables are joined with a blank line between them.

    Args:
        equations: One DataFrame per output variable.
        indices: Optional per-table row selections; ``indices[i]`` is passed
            to the ``i``-th table. When ``None`` or empty, every table uses
            the default row selection of ``sympy2latextable``.
        precision: Precision forwarded to ``sympy2latextable``.
        columns: Column keys forwarded to every table. Defaults to
            ``["equation", "complexity", "loss", "score"]``.
        output_variable_names: Optional left-hand-side name per table. When
            ``None``, the ``i``-th table uses ``y_{i}``.

    Returns:
        The LaTeX source of all tables as a single string.
    """
    # Resolve the default here (not via a shared list literal in the
    # signature) so a concrete list is always forwarded to each table.
    if columns is None:
        columns = ["equation", "complexity", "loss", "score"]

    latex_tables = [
        sympy2latextable(
            table,
            (indices[i] if indices else None),
            precision=precision,
            columns=columns,
            output_variable_name=(
                "y_{" + str(i) + "}"
                if output_variable_names is None
                else output_variable_names[i]
            ),
        )
        for i, table in enumerate(equations)
    ]

    return "\n\n".join(latex_tables)