MilesCranmer commited on
Commit
b2d7f41
β€’
1 Parent(s): f89d890

Refactor sympy and export functionality

Browse files
pysr/export_latex.py CHANGED
@@ -19,7 +19,7 @@ class PreciseLatexPrinter(LatexPrinter):
19
  return super()._print_Float(reduced_float)
20
 
21
 
22
- def to_latex(expr, prec=3, full_prec=True, **settings):
23
  """Convert sympy expression to LaTeX with custom precision."""
24
  settings["full_prec"] = full_prec
25
  printer = PreciseLatexPrinter(settings=settings, prec=prec)
@@ -56,7 +56,7 @@ def generate_table_environment(columns=["equation", "complexity", "loss"]):
56
  return top_latex_table, bottom_latex_table
57
 
58
 
59
- def generate_single_table(
60
  equations: pd.DataFrame,
61
  indices: List[int] = None,
62
  precision: int = 3,
@@ -74,16 +74,16 @@ def generate_single_table(
74
  indices = range(len(equations))
75
 
76
  for i in indices:
77
- latex_equation = to_latex(
78
  equations.iloc[i]["sympy_format"],
79
  prec=precision,
80
  )
81
  complexity = str(equations.iloc[i]["complexity"])
82
- loss = to_latex(
83
  sympy.Float(equations.iloc[i]["loss"]),
84
  prec=precision,
85
  )
86
- score = to_latex(
87
  sympy.Float(equations.iloc[i]["score"]),
88
  prec=precision,
89
  )
@@ -124,7 +124,7 @@ def generate_single_table(
124
  return "\n".join([latex_top, *latex_table_content, latex_bottom])
125
 
126
 
127
- def generate_multiple_tables(
128
  equations: List[pd.DataFrame],
129
  indices: List[List[int]] = None,
130
  precision: int = 3,
@@ -135,7 +135,7 @@ def generate_multiple_tables(
135
  # TODO: Let user specify custom output variable
136
 
137
  latex_tables = [
138
- generate_single_table(
139
  equations[i],
140
  (None if not indices else indices[i]),
141
  precision=precision,
 
19
  return super()._print_Float(reduced_float)
20
 
21
 
22
+ def sympy2latex(expr, prec=3, full_prec=True, **settings):
23
  """Convert sympy expression to LaTeX with custom precision."""
24
  settings["full_prec"] = full_prec
25
  printer = PreciseLatexPrinter(settings=settings, prec=prec)
 
56
  return top_latex_table, bottom_latex_table
57
 
58
 
59
+ def sympy2latextable(
60
  equations: pd.DataFrame,
61
  indices: List[int] = None,
62
  precision: int = 3,
 
74
  indices = range(len(equations))
75
 
76
  for i in indices:
77
+ latex_equation = sympy2latex(
78
  equations.iloc[i]["sympy_format"],
79
  prec=precision,
80
  )
81
  complexity = str(equations.iloc[i]["complexity"])
82
+ loss = sympy2latex(
83
  sympy.Float(equations.iloc[i]["loss"]),
84
  prec=precision,
85
  )
86
+ score = sympy2latex(
87
  sympy.Float(equations.iloc[i]["score"]),
88
  prec=precision,
89
  )
 
124
  return "\n".join([latex_top, *latex_table_content, latex_bottom])
125
 
126
 
127
+ def sympy2multilatextable(
128
  equations: List[pd.DataFrame],
129
  indices: List[List[int]] = None,
130
  precision: int = 3,
 
135
  # TODO: Let user specify custom output variable
136
 
137
  latex_tables = [
138
+ sympy2latextable(
139
  equations[i],
140
  (None if not indices else indices[i]),
141
  precision=precision,
pysr/export_numpy.py CHANGED
@@ -6,14 +6,17 @@ import pandas as pd
6
  from sympy import lambdify
7
 
8
 
 
 
 
 
9
  class CallableEquation:
10
  """Simple wrapper for numpy lambda functions built with sympy"""
11
 
12
- def __init__(self, sympy_symbols, eqn, selection=None, variable_names=None):
13
  self._sympy = eqn
14
  self._sympy_symbols = sympy_symbols
15
  self._selection = selection
16
- self._variable_names = variable_names
17
 
18
  def __repr__(self):
19
  return f"PySRFunction(X=>{self._sympy})"
@@ -23,7 +26,7 @@ class CallableEquation:
23
  if isinstance(X, pd.DataFrame):
24
  # Lambda function takes as argument:
25
  return self._lambda(
26
- **{k: X[k].values for k in self._variable_names}
27
  ) * np.ones(expected_shape)
28
  if self._selection is not None:
29
  if X.shape[1] != len(self._selection):
 
6
  from sympy import lambdify
7
 
8
 
9
+ def sympy2numpy(eqn, sympy_symbols, *, selection=None):
10
+ return CallableEquation(eqn, sympy_symbols, selection=selection)
11
+
12
+
13
  class CallableEquation:
14
  """Simple wrapper for numpy lambda functions built with sympy"""
15
 
16
+ def __init__(self, eqn, sympy_symbols, selection=None):
17
  self._sympy = eqn
18
  self._sympy_symbols = sympy_symbols
19
  self._selection = selection
 
20
 
21
  def __repr__(self):
22
  return f"PySRFunction(X=>{self._sympy})"
 
26
  if isinstance(X, pd.DataFrame):
27
  # Lambda function takes as argument:
28
  return self._lambda(
29
+ **{k: X[k].values for k in map(str, self._sympy_symbols)}
30
  ) * np.ones(expected_shape)
31
  if self._selection is not None:
32
  if X.shape[1] != len(self._selection):
pysr/export_sympy.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Define utilities to export to sympy"""
2
+ from typing import Callable, Dict, List, Optional
3
+
4
+ import sympy
5
+ from sympy import sympify
6
+
7
+ sympy_mappings = {
8
+ "div": lambda x, y: x / y,
9
+ "mult": lambda x, y: x * y,
10
+ "sqrt": lambda x: sympy.sqrt(x),
11
+ "sqrt_abs": lambda x: sympy.sqrt(abs(x)),
12
+ "square": lambda x: x**2,
13
+ "cube": lambda x: x**3,
14
+ "plus": lambda x, y: x + y,
15
+ "sub": lambda x, y: x - y,
16
+ "neg": lambda x: -x,
17
+ "pow": lambda x, y: x**y,
18
+ "pow_abs": lambda x, y: abs(x) ** y,
19
+ "cos": sympy.cos,
20
+ "sin": sympy.sin,
21
+ "tan": sympy.tan,
22
+ "cosh": sympy.cosh,
23
+ "sinh": sympy.sinh,
24
+ "tanh": sympy.tanh,
25
+ "exp": sympy.exp,
26
+ "acos": sympy.acos,
27
+ "asin": sympy.asin,
28
+ "atan": sympy.atan,
29
+ "acosh": lambda x: sympy.acosh(x),
30
+ "acosh_abs": lambda x: sympy.acosh(abs(x) + 1),
31
+ "asinh": sympy.asinh,
32
+ "atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
33
+ "atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
34
+ "abs": abs,
35
+ "mod": sympy.Mod,
36
+ "erf": sympy.erf,
37
+ "erfc": sympy.erfc,
38
+ "log": lambda x: sympy.log(x),
39
+ "log10": lambda x: sympy.log(x, 10),
40
+ "log2": lambda x: sympy.log(x, 2),
41
+ "log1p": lambda x: sympy.log(x + 1),
42
+ "log_abs": lambda x: sympy.log(abs(x)),
43
+ "log10_abs": lambda x: sympy.log(abs(x), 10),
44
+ "log2_abs": lambda x: sympy.log(abs(x), 2),
45
+ "log1p_abs": lambda x: sympy.log(abs(x) + 1),
46
+ "floor": sympy.floor,
47
+ "ceil": sympy.ceiling,
48
+ "sign": sympy.sign,
49
+ "gamma": sympy.gamma,
50
+ }
51
+
52
+
53
+ def create_sympy_symbols(
54
+ feature_names_in: Optional[List[str]] = None,
55
+ ) -> List[sympy.Symbol]:
56
+ return [sympy.Symbol(variable) for variable in feature_names_in]
57
+
58
+
59
+ def pysr2sympy(
60
+ equation: str, *, extra_sympy_mappings: Optional[Dict[str, Callable]] = None
61
+ ) -> sympy.Expr:
62
+ local_sympy_mappings = {
63
+ **(extra_sympy_mappings if extra_sympy_mappings else {}),
64
+ **sympy_mappings,
65
+ }
66
+
67
+ return sympify(equation, locals=local_sympy_mappings)
68
+
69
+
70
+ def assert_valid_sympy_symbol(var_name: str) -> None:
71
+ if var_name in sympy_mappings or var_name in sympy.__dict__.keys():
72
+ raise ValueError(f"Variable name {var_name} is already a function name.")
pysr/sr.py CHANGED
@@ -14,15 +14,16 @@ from pathlib import Path
14
 
15
  import numpy as np
16
  import pandas as pd
17
- import sympy
18
  from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
19
  from sklearn.utils import check_array, check_consistent_length, check_random_state
20
  from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
21
- from sympy import sympify
22
 
23
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
24
- from .export_latex import generate_multiple_tables, generate_single_table, to_latex
25
- from .export_numpy import CallableEquation
 
 
 
26
  from .julia_helpers import (
27
  _escape_filename,
28
  _load_backend,
@@ -37,51 +38,6 @@ Main = None # TODO: Rename to more descriptive name like "julia_runtime"
37
 
38
  already_ran = False
39
 
40
- sympy_mappings = {
41
- "div": lambda x, y: x / y,
42
- "mult": lambda x, y: x * y,
43
- "sqrt": lambda x: sympy.sqrt(x),
44
- "sqrt_abs": lambda x: sympy.sqrt(abs(x)),
45
- "square": lambda x: x**2,
46
- "cube": lambda x: x**3,
47
- "plus": lambda x, y: x + y,
48
- "sub": lambda x, y: x - y,
49
- "neg": lambda x: -x,
50
- "pow": lambda x, y: x**y,
51
- "pow_abs": lambda x, y: abs(x) ** y,
52
- "cos": sympy.cos,
53
- "sin": sympy.sin,
54
- "tan": sympy.tan,
55
- "cosh": sympy.cosh,
56
- "sinh": sympy.sinh,
57
- "tanh": sympy.tanh,
58
- "exp": sympy.exp,
59
- "acos": sympy.acos,
60
- "asin": sympy.asin,
61
- "atan": sympy.atan,
62
- "acosh": lambda x: sympy.acosh(x),
63
- "acosh_abs": lambda x: sympy.acosh(abs(x) + 1),
64
- "asinh": sympy.asinh,
65
- "atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
66
- "atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
67
- "abs": abs,
68
- "mod": sympy.Mod,
69
- "erf": sympy.erf,
70
- "erfc": sympy.erfc,
71
- "log": lambda x: sympy.log(x),
72
- "log10": lambda x: sympy.log(x, 10),
73
- "log2": lambda x: sympy.log(x, 2),
74
- "log1p": lambda x: sympy.log(x + 1),
75
- "log_abs": lambda x: sympy.log(abs(x)),
76
- "log10_abs": lambda x: sympy.log(abs(x), 10),
77
- "log2_abs": lambda x: sympy.log(abs(x), 2),
78
- "log1p_abs": lambda x: sympy.log(abs(x) + 1),
79
- "floor": sympy.floor,
80
- "ceil": sympy.ceiling,
81
- "sign": sympy.sign,
82
- "gamma": sympy.gamma,
83
- }
84
-
85
 
86
  def pysr(X, y, weights=None, **kwargs): # pragma: no cover
87
  warnings.warn(
@@ -188,10 +144,6 @@ def _check_assertions(
188
  assert len(variable_names) == X.shape[1]
189
  # Check none of the variable names are function names:
190
  for var_name in variable_names:
191
- if var_name in sympy_mappings or var_name in sympy.__dict__.keys():
192
- raise ValueError(
193
- f"Variable name {var_name} is already a function name."
194
- )
195
  # Check if alphanumeric only:
196
  if not re.match(r"^[β‚€β‚β‚‚β‚ƒβ‚„β‚…β‚†β‚‡β‚ˆβ‚‰a-zA-Z0-9_]+$", var_name):
197
  raise ValueError(
@@ -199,6 +151,7 @@ def _check_assertions(
199
  "Only alphanumeric characters, numbers, "
200
  "and underscores are allowed."
201
  )
 
202
  if X_units is not None and len(X_units) != X.shape[1]:
203
  raise ValueError(
204
  "The number of units in `X_units` must equal the number of features in `X`."
@@ -2116,10 +2069,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2116
  if self.nout_ > 1:
2117
  output = []
2118
  for s in sympy_representation:
2119
- latex = to_latex(s, prec=precision)
2120
  output.append(latex)
2121
  return output
2122
- return to_latex(sympy_representation, prec=precision)
2123
 
2124
  def jax(self, index=None):
2125
  """
@@ -2282,53 +2235,41 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2282
  jax_format = []
2283
  if self.output_torch_format:
2284
  torch_format = []
2285
- local_sympy_mappings = {
2286
- **(self.extra_sympy_mappings if self.extra_sympy_mappings else {}),
2287
- **sympy_mappings,
2288
- }
2289
-
2290
- sympy_symbols = [
2291
- sympy.Symbol(variable) for variable in self.feature_names_in_
2292
- ]
2293
 
2294
  for _, eqn_row in output.iterrows():
2295
- eqn = sympify(eqn_row["equation"], locals=local_sympy_mappings)
 
 
 
2296
  sympy_format.append(eqn)
2297
 
2298
- # Numpy:
 
2299
  lambda_format.append(
2300
- CallableEquation(
2301
- sympy_symbols, eqn, self.selection_mask_, self.feature_names_in_
 
 
2302
  )
2303
  )
2304
 
2305
  # JAX:
2306
  if self.output_jax_format:
2307
- from .export_jax import sympy2jax
2308
-
2309
  func, params = sympy2jax(
2310
  eqn,
2311
  sympy_symbols,
2312
  selection=self.selection_mask_,
2313
- extra_jax_mappings=(
2314
- self.extra_jax_mappings if self.extra_jax_mappings else {}
2315
- ),
2316
  )
2317
  jax_format.append({"callable": func, "parameters": params})
2318
 
2319
  # Torch:
2320
  if self.output_torch_format:
2321
- from .export_torch import sympy2torch
2322
-
2323
  module = sympy2torch(
2324
  eqn,
2325
  sympy_symbols,
2326
  selection=self.selection_mask_,
2327
- extra_torch_mappings=(
2328
- self.extra_torch_mappings
2329
- if self.extra_torch_mappings
2330
- else {}
2331
- ),
2332
  )
2333
  torch_format.append(module)
2334
 
@@ -2410,17 +2351,18 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2410
  assert isinstance(indices[0], list)
2411
  assert len(indices) == self.nout_
2412
 
2413
- generator_fnc = generate_multiple_tables
 
 
2414
  else:
2415
  if indices is not None:
2416
  assert isinstance(indices, list)
2417
  assert isinstance(indices[0], int)
2418
 
2419
- generator_fnc = generate_single_table
 
 
2420
 
2421
- table_string = generator_fnc(
2422
- self.equations_, indices=indices, precision=precision, columns=columns
2423
- )
2424
  preamble_string = [
2425
  r"\usepackage{breqn}",
2426
  r"\usepackage{booktabs}",
 
14
 
15
  import numpy as np
16
  import pandas as pd
 
17
  from sklearn.base import BaseEstimator, MultiOutputMixin, RegressorMixin
18
  from sklearn.utils import check_array, check_consistent_length, check_random_state
19
  from sklearn.utils.validation import _check_feature_names_in, check_is_fitted
 
20
 
21
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
22
+ from .export_jax import sympy2jax
23
+ from .export_latex import sympy2latex, sympy2latextable, sympy2multilatextable
24
+ from .export_numpy import sympy2numpy
25
+ from .export_sympy import assert_valid_sympy_symbol, create_sympy_symbols, pysr2sympy
26
+ from .export_torch import sympy2torch
27
  from .julia_helpers import (
28
  _escape_filename,
29
  _load_backend,
 
38
 
39
  already_ran = False
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  def pysr(X, y, weights=None, **kwargs): # pragma: no cover
43
  warnings.warn(
 
144
  assert len(variable_names) == X.shape[1]
145
  # Check none of the variable names are function names:
146
  for var_name in variable_names:
 
 
 
 
147
  # Check if alphanumeric only:
148
  if not re.match(r"^[β‚€β‚β‚‚β‚ƒβ‚„β‚…β‚†β‚‡β‚ˆβ‚‰a-zA-Z0-9_]+$", var_name):
149
  raise ValueError(
 
151
  "Only alphanumeric characters, numbers, "
152
  "and underscores are allowed."
153
  )
154
+ assert_valid_sympy_symbol(var_name)
155
  if X_units is not None and len(X_units) != X.shape[1]:
156
  raise ValueError(
157
  "The number of units in `X_units` must equal the number of features in `X`."
 
2069
  if self.nout_ > 1:
2070
  output = []
2071
  for s in sympy_representation:
2072
+ latex = sympy2latex(s, prec=precision)
2073
  output.append(latex)
2074
  return output
2075
+ return sympy2latex(sympy_representation, prec=precision)
2076
 
2077
  def jax(self, index=None):
2078
  """
 
2235
  jax_format = []
2236
  if self.output_torch_format:
2237
  torch_format = []
 
 
 
 
 
 
 
 
2238
 
2239
  for _, eqn_row in output.iterrows():
2240
+ eqn = pysr2sympy(
2241
+ eqn_row["equation"],
2242
+ extra_sympy_mappings=self.extra_sympy_mappings,
2243
+ )
2244
  sympy_format.append(eqn)
2245
 
2246
+ # NumPy:
2247
+ sympy_symbols = create_sympy_symbols(self.feature_names_in_)
2248
  lambda_format.append(
2249
+ sympy2numpy(
2250
+ eqn,
2251
+ sympy_symbols,
2252
+ selection=self.selection_mask_,
2253
  )
2254
  )
2255
 
2256
  # JAX:
2257
  if self.output_jax_format:
 
 
2258
  func, params = sympy2jax(
2259
  eqn,
2260
  sympy_symbols,
2261
  selection=self.selection_mask_,
2262
+ extra_jax_mappings=self.extra_jax_mappings,
 
 
2263
  )
2264
  jax_format.append({"callable": func, "parameters": params})
2265
 
2266
  # Torch:
2267
  if self.output_torch_format:
 
 
2268
  module = sympy2torch(
2269
  eqn,
2270
  sympy_symbols,
2271
  selection=self.selection_mask_,
2272
+ extra_torch_mappings=self.extra_torch_mappings,
 
 
 
 
2273
  )
2274
  torch_format.append(module)
2275
 
 
2351
  assert isinstance(indices[0], list)
2352
  assert len(indices) == self.nout_
2353
 
2354
+ table_string = sympy2multilatextable(
2355
+ self.equations_, indices=indices, precision=precision, columns=columns
2356
+ )
2357
  else:
2358
  if indices is not None:
2359
  assert isinstance(indices, list)
2360
  assert isinstance(indices[0], int)
2361
 
2362
+ table_string = sympy2latextable(
2363
+ self.equations_, indices=indices, precision=precision, columns=columns
2364
+ )
2365
 
 
 
 
2366
  preamble_string = [
2367
  r"\usepackage{breqn}",
2368
  r"\usepackage{booktabs}",
pysr/test/test.py CHANGED
@@ -10,11 +10,10 @@ from pathlib import Path
10
  import numpy as np
11
  import pandas as pd
12
  import sympy
13
- from sklearn import model_selection
14
  from sklearn.utils.estimator_checks import check_estimator
15
 
16
  from .. import PySRRegressor, julia_helpers
17
- from ..export_latex import to_latex
18
  from ..sr import (
19
  _check_assertions,
20
  _csv_filename_to_pkl_filename,
@@ -884,23 +883,23 @@ class TestLaTeXTable(unittest.TestCase):
884
  def test_latex_float_precision(self):
885
  """Test that we can print latex expressions with custom precision"""
886
  expr = sympy.Float(4583.4485748, dps=50)
887
- self.assertEqual(to_latex(expr, prec=6), r"4583.45")
888
- self.assertEqual(to_latex(expr, prec=5), r"4583.4")
889
- self.assertEqual(to_latex(expr, prec=4), r"4583.")
890
- self.assertEqual(to_latex(expr, prec=3), r"4.58 \cdot 10^{3}")
891
- self.assertEqual(to_latex(expr, prec=2), r"4.6 \cdot 10^{3}")
892
 
893
  # Multiple numbers:
894
  x = sympy.Symbol("x")
895
  expr = x * 3232.324857384 - 1.4857485e-10
896
  self.assertEqual(
897
- to_latex(expr, prec=2), r"3.2 \cdot 10^{3} x - 1.5 \cdot 10^{-10}"
898
  )
899
  self.assertEqual(
900
- to_latex(expr, prec=3), r"3.23 \cdot 10^{3} x - 1.49 \cdot 10^{-10}"
901
  )
902
  self.assertEqual(
903
- to_latex(expr, prec=8), r"3232.3249 x - 1.4857485 \cdot 10^{-10}"
904
  )
905
 
906
  def test_latex_break_long_equation(self):
 
10
  import numpy as np
11
  import pandas as pd
12
  import sympy
 
13
  from sklearn.utils.estimator_checks import check_estimator
14
 
15
  from .. import PySRRegressor, julia_helpers
16
+ from ..export_latex import sympy2latex
17
  from ..sr import (
18
  _check_assertions,
19
  _csv_filename_to_pkl_filename,
 
883
  def test_latex_float_precision(self):
884
  """Test that we can print latex expressions with custom precision"""
885
  expr = sympy.Float(4583.4485748, dps=50)
886
+ self.assertEqual(sympy2latex(expr, prec=6), r"4583.45")
887
+ self.assertEqual(sympy2latex(expr, prec=5), r"4583.4")
888
+ self.assertEqual(sympy2latex(expr, prec=4), r"4583.")
889
+ self.assertEqual(sympy2latex(expr, prec=3), r"4.58 \cdot 10^{3}")
890
+ self.assertEqual(sympy2latex(expr, prec=2), r"4.6 \cdot 10^{3}")
891
 
892
  # Multiple numbers:
893
  x = sympy.Symbol("x")
894
  expr = x * 3232.324857384 - 1.4857485e-10
895
  self.assertEqual(
896
+ sympy2latex(expr, prec=2), r"3.2 \cdot 10^{3} x - 1.5 \cdot 10^{-10}"
897
  )
898
  self.assertEqual(
899
+ sympy2latex(expr, prec=3), r"3.23 \cdot 10^{3} x - 1.49 \cdot 10^{-10}"
900
  )
901
  self.assertEqual(
902
+ sympy2latex(expr, prec=8), r"3232.3249 x - 1.4857485 \cdot 10^{-10}"
903
  )
904
 
905
  def test_latex_break_long_equation(self):